Showing preview only (1,299K chars total). Download the full file or copy to clipboard to get everything.
Repository: lzzcd001/GShell
Branch: main
Commit: c2f0ba9ea01a
Files: 124
Total size: 1.2 MB
Directory structure:
gitextract_zb4dyoqx/
├── .gitignore
├── GMeshDiffusion/
│ ├── diffusion_configs/
│ │ ├── config_lower_occgrid_normalized.py
│ │ └── config_upper_occgrid_normalized.py
│ ├── lib/
│ │ ├── dataset/
│ │ │ ├── gshell_dataset.py
│ │ │ └── gshell_dataset_aug.py
│ │ └── diffusion/
│ │ ├── evaler.py
│ │ ├── likelihood.py
│ │ ├── losses.py
│ │ ├── models/
│ │ │ ├── __init__.py
│ │ │ ├── ema.py
│ │ │ ├── functional.py
│ │ │ ├── layers.py
│ │ │ ├── normalization.py
│ │ │ ├── unet3d_occgrid.py
│ │ │ └── utils.py
│ │ ├── sampling.py
│ │ ├── sde_lib.py
│ │ ├── trainer.py
│ │ ├── trainer_ddp.py
│ │ └── utils.py
│ ├── main_diffusion.py
│ ├── main_diffusion_ddp.py
│ ├── metadata/
│ │ ├── get_splits_lower.py
│ │ ├── get_splits_upper.py
│ │ ├── save_tet_info.py
│ │ └── tet_to_cubic_grid_dataset.py
│ └── scripts/
│ ├── run_eval_lower_occgrid_normalized.sh
│ ├── run_eval_upper_occgrid_normalized.sh
│ ├── run_lower_occgrid_normalized_ddp.sh
│ └── run_upper_occgrid_normalized_ddp.sh
├── README.md
├── configs/
│ ├── deepfashion_mc.json
│ ├── deepfashion_mc_256.json
│ ├── deepfashion_mc_512.json
│ ├── deepfashion_mc_80.json
│ ├── nerf_chair.json
│ ├── polycam_mc.json
│ ├── polycam_mc_128.json
│ └── polycam_mc_16samples.json
├── data/
│ └── tets/
│ └── generate_tets.py
├── dataset/
│ ├── __init__.py
│ ├── dataset.py
│ ├── dataset_deepfashion.py
│ ├── dataset_deepfashion_testset.py
│ ├── dataset_llff.py
│ ├── dataset_mesh.py
│ ├── dataset_nerf.py
│ └── dataset_nerf_colmap.py
├── denoiser/
│ └── denoiser.py
├── eval_gmeshdiffusion_generated_samples.py
├── geometry/
│ ├── embedding.py
│ ├── flexicubes_table.py
│ ├── gshell_flexicubes.py
│ ├── gshell_flexicubes_geometry.py
│ ├── gshell_tets.py
│ ├── gshell_tets_geometry.py
│ └── mlp.py
├── render/
│ ├── light.py
│ ├── material.py
│ ├── mesh.py
│ ├── mlptexture.py
│ ├── obj.py
│ ├── optixutils/
│ │ ├── __init__.py
│ │ ├── c_src/
│ │ │ ├── accessor.h
│ │ │ ├── bsdf.h
│ │ │ ├── common.h
│ │ │ ├── denoising.cu
│ │ │ ├── denoising.h
│ │ │ ├── envsampling/
│ │ │ │ ├── kernel.cu
│ │ │ │ └── params.h
│ │ │ ├── math_utils.h
│ │ │ ├── optix_wrapper.cpp
│ │ │ ├── optix_wrapper.h
│ │ │ └── torch_bindings.cpp
│ │ ├── include/
│ │ │ ├── internal/
│ │ │ │ ├── optix_7_device_impl.h
│ │ │ │ ├── optix_7_device_impl_exception.h
│ │ │ │ └── optix_7_device_impl_transformations.h
│ │ │ ├── optix.h
│ │ │ ├── optix_7_device.h
│ │ │ ├── optix_7_host.h
│ │ │ ├── optix_7_types.h
│ │ │ ├── optix_denoiser_tiling.h
│ │ │ ├── optix_device.h
│ │ │ ├── optix_function_table.h
│ │ │ ├── optix_function_table_definition.h
│ │ │ ├── optix_host.h
│ │ │ ├── optix_stack_size.h
│ │ │ ├── optix_stubs.h
│ │ │ └── optix_types.h
│ │ ├── ops.py
│ │ └── tests/
│ │ └── filter_test.py
│ ├── regularizer.py
│ ├── render.py
│ ├── renderutils/
│ │ ├── __init__.py
│ │ ├── bsdf.py
│ │ ├── c_src/
│ │ │ ├── bsdf.cu
│ │ │ ├── bsdf.h
│ │ │ ├── common.cpp
│ │ │ ├── common.h
│ │ │ ├── cubemap.cu
│ │ │ ├── cubemap.h
│ │ │ ├── loss.cu
│ │ │ ├── loss.h
│ │ │ ├── mesh.cu
│ │ │ ├── mesh.h
│ │ │ ├── normal.cu
│ │ │ ├── normal.h
│ │ │ ├── tensor.h
│ │ │ ├── torch_bindings.cpp
│ │ │ ├── vec3f.h
│ │ │ └── vec4f.h
│ │ ├── loss.py
│ │ ├── ops.py
│ │ └── tests/
│ │ ├── test_bsdf.py
│ │ ├── test_loss.py
│ │ ├── test_mesh.py
│ │ └── test_perf.py
│ ├── texture.py
│ └── util.py
├── train_gflexicubes_deepfashion.py
├── train_gflexicubes_polycam.py
├── train_gshelltet_deepfashion.py
├── train_gshelltet_polycam.py
└── train_gshelltet_synthetic.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
# lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.DS_Store
================================================
FILE: GMeshDiffusion/diffusion_configs/config_lower_occgrid_normalized.py
================================================
import ml_collections
import torch
import os
def get_config():
    """Build the config for the 'lower' occupancy-grid diffusion model.

    Returns an `ml_collections.ConfigDict` with `data`, `training`,
    `sampling`, `model`, `optim` and `eval` sections plus top-level `seed`
    and `device`. Entries set to "PLACEHOLDER" (dataset root, training
    directory, eval directory, checkpoint path) must be overridden before
    use. Every metadata path below is joined onto the dataset root when this
    function runs, so overriding `data.root_dir` afterwards does not move them.
    """
    root = 'PLACEHOLDER'

    def metadata_path(name):
        # Auxiliary files are looked up in `metadata/` under the dataset root.
        return os.path.join(root, 'metadata/' + name)

    cfg = ml_collections.ConfigDict()

    def add_section(name):
        # Create an empty sub-config, attach it to `cfg` and hand it back.
        section = ml_collections.ConfigDict()
        setattr(cfg, name, section)
        return section

    # ---- data ----------------------------------------------------------
    d = add_section('data')
    d.root_dir = root
    d.num_workers = 4
    d.grid_size = 128
    d.tet_resolution = 64
    d.num_channels = 4
    d.use_occ_grid = True
    d.grid_metafile = metadata_path('lower_res64_grid_train.txt')
    d.occgrid_metafile = metadata_path('lower_res64_occgrid_train.txt')
    d.occ_mask_path = metadata_path('occ_mask_res64.pt')
    d.tet_info_path = metadata_path('tet_info.pt')
    d.filter_meta_path = None
    d.aug = True

    # ---- training ------------------------------------------------------
    t = add_section('training')
    t.sde = 'vpsde'
    t.continuous = False
    t.reduce_mean = True
    t.batch_size = 1  # per process; under DDP the global batch is nproc * batch_size
    t.num_grad_acc_steps = 4
    t.n_iters = 2400001
    t.snapshot_freq = 1000
    t.log_freq = 50
    t.snapshot_sampling = True  # draw samples at every snapshot
    t.likelihood_weighting = False
    t.loss_type = 'l2'
    t.train_dir = "PLACEHOLDER"
    t.snapshot_freq_for_preemption = 1000
    t.gradscaler_growth_interval = 1000
    t.use_aux_loss = False
    t.compile = True  # torch.compile (PyTorch 2.0+)
    t.enable_xformers_memory_efficient_attention = True

    # ---- sampling ------------------------------------------------------
    s = add_section('sampling')
    s.method = 'pc'
    s.predictor = 'ancestral_sampling'
    s.corrector = 'none'
    s.n_steps_each = 1
    s.noise_removal = True
    s.probability_flow = False
    s.snr = 0.075

    # ---- model ---------------------------------------------------------
    m = add_section('model')
    m.name = 'unet3d_occgrid'
    m.use_occ_grid = True
    m.num_res_blocks = 2
    m.num_res_blocks_1st_layer = 2
    m.base_channels = 128
    m.ch_mult = (1, 2, 2, 4, 4, 4)
    m.down_block_types = (
        "ResBlock", "ResBlock", "ResBlock", "AttnResBlock", "ResBlock", "ResBlock"
    )
    m.up_block_types = (
        "ResBlock", "ResBlock", "AttnResBlock", "ResBlock", "ResBlock", "ResBlock"
    )
    m.scale_by_sigma = False
    m.num_scales = 1000
    m.ema_rate = 0.9999
    m.normalization = 'GroupNorm'
    m.act_fn = 'swish'
    m.attn_resolutions = (16,)
    m.resamp_with_conv = True
    m.dropout = 0.1
    m.sigma_max = 378
    m.sigma_min = 0.01
    m.beta_min = 0.1
    m.beta_max = 20.
    m.embedding_type = 'fourier'
    m.pred_type = 'noise'
    m.conditional = True
    m.feature_mask_path = metadata_path('global_mask_res64.pt')
    m.pixcat_mask_path = metadata_path('cat_mask_res64.pt')

    # ---- optimization --------------------------------------------------
    opt = add_section('optim')
    opt.weight_decay = 1e-5
    opt.optimizer = 'AdamW'
    opt.lr = 1e-5
    opt.beta1 = 0.9
    opt.eps = 1e-8
    opt.warmup = 5000
    opt.grad_clip = 1.

    # ---- evaluation ----------------------------------------------------
    ev = add_section('eval')
    ev.batch_size = 2
    ev.idx = 0
    ev.bin_size = 30
    ev.eval_dir = "PLACEHOLDER"
    ev.ckpt_path = "PLACEHOLDER"

    cfg.seed = 42
    cfg.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    return cfg
================================================
FILE: GMeshDiffusion/diffusion_configs/config_upper_occgrid_normalized.py
================================================
import ml_collections
import torch
import os
def get_config():
    """Build the config for the 'upper' occupancy-grid diffusion model.

    Returns an `ml_collections.ConfigDict` with `data`, `training`,
    `sampling`, `model`, `optim` and `eval` sections plus top-level `seed`
    and `device`. Entries set to "PLACEHOLDER" (dataset root, training
    directory, eval directory, checkpoint path) must be overridden before
    use. Every metadata path below is joined onto the dataset root when this
    function runs, so overriding `data.root_dir` afterwards does not move them.
    """
    root = 'PLACEHOLDER'

    def metadata_path(name):
        # Auxiliary files are looked up in `metadata/` under the dataset root.
        return os.path.join(root, 'metadata/' + name)

    cfg = ml_collections.ConfigDict()

    def add_section(name):
        # Create an empty sub-config, attach it to `cfg` and hand it back.
        section = ml_collections.ConfigDict()
        setattr(cfg, name, section)
        return section

    # ---- data ----------------------------------------------------------
    d = add_section('data')
    d.root_dir = root
    d.num_workers = 4
    d.grid_size = 128
    d.tet_resolution = 64
    d.num_channels = 4
    d.use_occ_grid = True
    d.grid_metafile = metadata_path('upper_res64_grid_train.txt')
    d.occgrid_metafile = metadata_path('upper_res64_occgrid_train.txt')
    d.occ_mask_path = metadata_path('occ_mask_res64.pt')
    d.tet_info_path = metadata_path('tet_info.pt')
    d.filter_meta_path = None
    d.aug = True

    # ---- training ------------------------------------------------------
    t = add_section('training')
    t.sde = 'vpsde'
    t.continuous = False
    t.reduce_mean = True
    t.batch_size = 1  # per process; under DDP the global batch is nproc * batch_size
    t.num_grad_acc_steps = 4
    t.n_iters = 2400001
    t.snapshot_freq = 1000
    t.log_freq = 50
    t.snapshot_sampling = True  # draw samples at every snapshot
    t.likelihood_weighting = False
    t.loss_type = 'l2'
    t.train_dir = "PLACEHOLDER"
    t.snapshot_freq_for_preemption = 1000
    t.gradscaler_growth_interval = 1000
    t.use_aux_loss = False
    t.compile = True  # torch.compile (PyTorch 2.0+)
    t.enable_xformers_memory_efficient_attention = True

    # ---- sampling ------------------------------------------------------
    s = add_section('sampling')
    s.method = 'pc'
    s.predictor = 'ancestral_sampling'
    s.corrector = 'none'
    s.n_steps_each = 1
    s.noise_removal = True
    s.probability_flow = False
    s.snr = 0.075

    # ---- model ---------------------------------------------------------
    m = add_section('model')
    m.name = 'unet3d_occgrid'
    m.use_occ_grid = True
    m.num_res_blocks = 2
    m.num_res_blocks_1st_layer = 2
    m.base_channels = 128
    m.ch_mult = (1, 2, 2, 4, 4, 4)
    m.down_block_types = (
        "ResBlock", "ResBlock", "ResBlock", "AttnResBlock", "ResBlock", "ResBlock"
    )
    m.up_block_types = (
        "ResBlock", "ResBlock", "AttnResBlock", "ResBlock", "ResBlock", "ResBlock"
    )
    m.scale_by_sigma = False
    m.num_scales = 1000
    m.ema_rate = 0.9999
    m.normalization = 'GroupNorm'
    m.act_fn = 'swish'
    m.attn_resolutions = (16,)
    m.resamp_with_conv = True
    m.dropout = 0.1
    m.sigma_max = 378
    m.sigma_min = 0.01
    m.beta_min = 0.1
    m.beta_max = 20.
    m.embedding_type = 'fourier'
    m.pred_type = 'noise'
    m.conditional = True
    m.feature_mask_path = metadata_path('global_mask_res64_occaug_normalized_v1.pt')
    m.pixcat_mask_path = metadata_path('cat_mask_res64_occaug_normalized_v1.pt')

    # ---- optimization --------------------------------------------------
    opt = add_section('optim')
    opt.weight_decay = 1e-5
    opt.optimizer = 'AdamW'
    opt.lr = 1e-5
    opt.beta1 = 0.9
    opt.eps = 1e-8
    opt.warmup = 5000
    opt.grad_clip = 1.

    # ---- evaluation ----------------------------------------------------
    ev = add_section('eval')
    ev.batch_size = 2
    ev.idx = 0
    ev.bin_size = 30
    ev.eval_dir = "PLACEHOLDER"
    ev.ckpt_path = "PLACEHOLDER"

    cfg.seed = 42
    cfg.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    return cfg
================================================
FILE: GMeshDiffusion/lib/dataset/gshell_dataset.py
================================================
import torch
import numpy as np
from torch.utils.data import Dataset
class GShellDataset(Dataset):
    """Dataset returning one tensor per path listed in a text metafile.

    Args:
        filepath_metafile: path to a text file holding one data-file path per
            line; trailing whitespace (including the newline) is stripped.
        extension: 'pt' (read with `torch.load` onto the CPU) or 'npy'
            (read with `np.load` and wrapped in `torch.tensor`). Any other
            value fails an assertion after the metafile has been read.
    """

    def __init__(self, filepath_metafile, extension='pt'):
        super().__init__()
        with open(filepath_metafile, 'r') as listing:
            self.filepath_list = list(map(str.rstrip, listing))
        self.extension = extension
        assert self.extension in ['pt', 'npy']

    def __len__(self):
        # One item per listed path.
        return len(self.filepath_list)

    def __getitem__(self, idx):
        with torch.no_grad():
            path = self.filepath_list[idx]
            if self.extension != 'pt':
                # Anything other than 'pt' is read as a NumPy .npy file.
                return torch.tensor(np.load(path))
            return torch.load(path, map_location='cpu')
================================================
FILE: GMeshDiffusion/lib/dataset/gshell_dataset_aug.py
================================================
import torch
from torch.utils.data import Dataset
class GShellAugDataset(Dataset):
    """Paired dataset of (grid, occupancy-grid) tensors listed in two metafiles.

    Item `i` is built from line `i` of `FLAGS.data.grid_metafile` and line `i`
    of `FLAGS.data.occgrid_metafile`.

    Args:
        FLAGS: config object; reads `FLAGS.data.grid_metafile`,
            `FLAGS.data.occgrid_metafile` (text files with one path per line,
            trailing whitespace stripped) and `FLAGS.data.num_channels`
            (number of leading grid channels kept per item).
        extension: 'pt' (default, loaded with `torch.load` onto the CPU) or
            'npy' (loaded with `np.load` and wrapped in `torch.tensor`).
            Applies to both the grid and the occupancy files.
    """

    def __init__(self, FLAGS, extension='pt'):
        super().__init__()
        with open(FLAGS.data.grid_metafile, 'r') as f:
            self.filepath_list = [fpath.rstrip() for fpath in f]
        with open(FLAGS.data.occgrid_metafile, 'r') as f:
            self.occ_filepath_list = [fpath.rstrip() for fpath in f]
        self.extension = extension
        self.num_channels = FLAGS.data.num_channels
        print('num_channels: ', self.num_channels)
        assert self.extension in ['pt', 'npy']

    def __len__(self):
        # Length follows the grid list; the occupancy list is indexed in
        # lockstep and is assumed to be at least as long.
        return len(self.filepath_list)

    def _load(self, path):
        """Load one file according to `self.extension`, as a CPU tensor."""
        if self.extension == 'npy':
            import numpy as np  # local import: only needed for .npy data
            return torch.tensor(np.load(path))
        return torch.load(path, map_location='cpu')

    def __getitem__(self, idx):
        """Return `(grid[:num_channels], occ_grid)` for item `idx`."""
        with torch.no_grad():
            grid = self._load(self.filepath_list[idx])
            occ_path = self.occ_filepath_list[idx]
            try:
                occ_grid = self._load(occ_path)
            except Exception:
                # Report which occupancy file failed, then propagate unchanged.
                print(occ_path)
                raise
        return (grid[:self.num_channels], occ_grid)

    @staticmethod
    def collate(data):
        """Stack a list of `(grid, occ_grid)` items into a batch dict."""
        return {
            'grid': torch.stack([x[0] for x in data]),
            'occgrid': torch.stack([x[1] for x in data]),
        }
================================================
FILE: GMeshDiffusion/lib/diffusion/evaler.py
================================================
import os
import sys
import numpy as np
import tqdm
import logging
from . import losses
from .models import utils as mutils
from .models.ema import ExponentialMovingAverage
from . import sde_lib
import torch
from .utils import restore_checkpoint
from . import sampling
def uncond_gen(
    config,
    skip_existing=False,
):
    """Unconditionally sample one bin of grids from an EMA checkpoint and save them.

    For k in [idx * bin_size, idx * bin_size + bin_size), where
    idx = config.eval.idx and bin_size = config.eval.bin_size, a batch is
    sampled and written to `<config.eval.eval_dir>/<k>.pt`. When the sampler
    returns a tuple, element 0 goes to `<k>.pt` and element 1 to `<k>_occ.pt`.

    Args:
        config: experiment config; reads the `eval`, `model`, `data` and
            `training` sections and `config.device`.
        skip_existing: if True, indices whose `<k>.pt` already exists are
            skipped. Defaults to False, which regenerates and overwrites
            them (the previous behaviour).

    Raises:
        FileNotFoundError: if `config.eval.ckpt_path` does not exist.
    """
    with torch.no_grad():
        eval_dir, ckpt_path = config.eval.eval_dir, config.eval.ckpt_path
        idx = config.eval.idx
        bin_size = config.eval.bin_size
        print(f"idx to save: {idx * bin_size} to {idx * bin_size + bin_size - 1}")

        os.makedirs(eval_dir, exist_ok=True)
        inverse_scaler = lambda x: x  # samples are used without rescaling

        # The optimizer is only needed so the checkpoint's state dict can be restored.
        score_model = mutils.create_model(config, use_parallel=False)
        optimizer = losses.get_optimizer(config, score_model.parameters())
        ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate)
        state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0)

        sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
        sampling_eps = 1e-3
        sampling_shape = (config.eval.batch_size,
                          config.data.num_channels,
                          config.data.grid_size, config.data.grid_size, config.data.grid_size)
        sampling_fn = sampling.get_sampling_fn(config, sde, sampling_shape, inverse_scaler, sampling_eps)

        if not os.path.exists(ckpt_path):
            raise FileNotFoundError(f"checkpoint not found: {ckpt_path}")
        print('ckpt path:', ckpt_path)
        state = restore_checkpoint(ckpt_path, state, device=config.device)
        # Sample with the EMA weights.
        ema.copy_to(score_model.parameters())
        print(f"loaded model is trained till iter {state['step'] // config.training.num_grad_acc_steps}")

        for k in range(bin_size):
            save_file_path = os.path.join(eval_dir, f"{idx * bin_size + k}")
            print(f'check: {save_file_path}')
            if skip_existing and os.path.exists(save_file_path + '.pt'):
                continue
            print(f'will save to: {save_file_path}')
            samples, _ = sampling_fn(score_model)
            if isinstance(samples, tuple):
                print(samples[0][:, 0].unique())
                torch.save(samples[0], save_file_path + '.pt')
                torch.save(samples[1], save_file_path + '_occ.pt')
            else:
                print(samples[:, 0].unique())
                torch.save(samples, save_file_path + '.pt')
def slerp(z1, z2, alpha, eps=1e-7):
    '''
    Spherical linear interpolation between tensors `z1` and `z2`.

    The angle is measured between the two tensors viewed as flat vectors.
    `alpha=0` returns `z1` and `alpha=1` returns `z2`.

    Args:
        z1, z2: tensors of the same shape.
        alpha: interpolation weight (typically in [0, 1]).
        eps: if sin(theta) falls below this, the inputs are treated as
            (anti-)parallel and plain linear interpolation is used. This
            avoids the 0/0 that spherical interpolation would otherwise produce.

    Returns:
        Tensor with the shape of `z1`.
    '''
    cos_theta = torch.sum(z1 * z2) / (torch.norm(z1) * torch.norm(z2))
    # Rounding can push the cosine slightly outside [-1, 1]; acos would return NaN.
    theta = torch.acos(torch.clamp(cos_theta, -1.0, 1.0))
    sin_theta = torch.sin(theta)
    if sin_theta.abs() < eps:
        return (1 - alpha) * z1 + alpha * z2
    return (
        torch.sin((1 - alpha) * theta) / sin_theta * z1
        + torch.sin(alpha * theta) / sin_theta * z2
    )
def uncond_gen_interp(
    config,
    idx=0,
    interp_batch_size=32,
):
    """
    Generate samples along spherical interpolations between pairs of initial noises.

    Intended for DDIM-style sampling. For each of `config.eval.bin_size` runs,
    two prior noises are drawn. `interp_batch_size` starting points are built
    by slerp between them, with both endpoints included. Each starting point
    is sampled individually, and the concatenated results are saved to
    `<eval_dir>/<k>.pt`. When `config.model.use_occ_grid` is set, the same is
    done for a 2x-resolution, single-channel occupancy noise, saved to
    `<eval_dir>/<k>_occ.pt`.

    Args:
        config: experiment config; reads the `eval`, `model`, `data` and
            `training` sections and `config.device`. The value of
            `interp_batch_size` is also written to
            `config.eval.interp_batch_size`.
        idx: ignored, kept for backward compatibility; the bin index is
            read from `config.eval.idx`.
        interp_batch_size: number of interpolation points per run (>= 2).

    Raises:
        ValueError: if `interp_batch_size < 2`.
        FileNotFoundError: if `config.eval.ckpt_path` does not exist.
    """
    if interp_batch_size < 2:
        raise ValueError(f"interp_batch_size must be >= 2, got {interp_batch_size}")

    def _interp_starts(endpoints):
        # Batch whose first/last entries are the two endpoints and whose
        # interior entries are slerp'd between them at evenly spaced alphas.
        starts = torch.zeros((interp_batch_size,) + tuple(endpoints.shape[1:]), device=config.device)
        starts[0] = endpoints[0]
        starts[-1] = endpoints[1]
        for i in range(1, interp_batch_size - 1):
            starts[i] = slerp(starts[0], starts[-1], i / float(interp_batch_size - 1))
        return starts

    with torch.no_grad():
        eval_dir, ckpt_path = config.eval.eval_dir, config.eval.ckpt_path
        os.makedirs(eval_dir, exist_ok=True)
        inverse_scaler = lambda x: x  # samples are used without rescaling

        score_model = mutils.create_model(config, use_parallel=False)
        optimizer = losses.get_optimizer(config, score_model.parameters())
        ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate)
        state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0)

        sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
        sampling_eps = 1e-3
        sampling_shape = (config.eval.batch_size,
                          config.data.num_channels,
                          config.data.grid_size, config.data.grid_size, config.data.grid_size)
        sampling_fn = sampling.get_sampling_fn(config, sde, sampling_shape, inverse_scaler, sampling_eps)

        if not os.path.exists(ckpt_path):
            raise FileNotFoundError(f"checkpoint not found: {ckpt_path}")
        print('ckpt path:', ckpt_path)
        state = restore_checkpoint(ckpt_path, state, device=config.device)
        ema.copy_to(score_model.parameters())
        print(f"loaded model is trained till iter {state['step'] // config.training.num_grad_acc_steps}")

        # NOTE: the `idx` argument is intentionally overridden by the config value.
        idx = config.eval.idx
        bin_size = config.eval.bin_size
        config.eval.interp_batch_size = interp_batch_size
        print(f"idx to save: {idx * bin_size} to {idx * bin_size + bin_size - 1}")

        grid_tail = (config.data.num_channels,) + (config.data.grid_size,) * 3
        occ_tail = (1,) + (config.data.grid_size * 2,) * 3
        for k in range(bin_size):
            save_file_path = os.path.join(eval_dir, f"{idx * bin_size + k}")
            noise = sde.prior_sampling((2,) + grid_tail).to(config.device)
            x0 = _interp_starts(noise)
            if config.model.use_occ_grid:
                noise_occ = sde.prior_sampling((2,) + occ_tail).to(config.device)
                x0_occ = _interp_starts(noise_occ)
            else:
                x0_occ = None

            sample_list = []
            sample_occ_list = []
            for i in tqdm.trange(interp_batch_size):
                # Slicing None would raise, so only slice when occupancy is used.
                occ_start = None if x0_occ is None else x0_occ[i:i + 1]
                samples, _ = sampling_fn(score_model, x0=x0[i:i + 1], x0_occ=occ_start)
                if isinstance(samples, tuple):
                    sample_list.append(samples[0].cpu())
                    sample_occ_list.append(samples[1].cpu())
                else:
                    sample_list.append(samples.cpu())
            torch.save(torch.cat(sample_list, dim=0), save_file_path + '.pt')
            if config.model.use_occ_grid:
                torch.save(torch.cat(sample_occ_list, dim=0), save_file_path + '_occ.pt')
def cond_gen(
    config,
    save_fname='0',
):
    """
    Conditional generation from a partially observed DMTet (2.5D view converted to a cubic grid).

    The partial SDF (`'sdf'`) and visibility mask (`'vis'`) loaded from
    `config.eval.partial_dmtet_path` are scattered onto a cubic grid. The grid
    coordinates come from the tetrahedral-grid vertices stored at
    `config.eval.tet_path`. Both grids are passed to the sampler, and the
    samples are saved with `np.save` to `<eval_dir>/<save_fname>.npy`.

    NOTE(review): this path reads `config.data.image_size`,
    `config.data.input_size`, `config.eval.partial_dmtet_path`,
    `config.eval.tet_path` and `config.eval.freeze_iters`. The occgrid
    configs in `diffusion_configs/` do not define these, so confirm which
    config layout this function targets.

    Raises:
        FileNotFoundError: if `config.eval.ckpt_path` does not exist.
    """
    with torch.no_grad():
        eval_dir, ckpt_path = config.eval.eval_dir, config.eval.ckpt_path
        os.makedirs(eval_dir, exist_ok=True)
        inverse_scaler = lambda x: x  # samples are used without rescaling
        device = config.device

        score_model = mutils.create_model(config)
        optimizer = losses.get_optimizer(config, score_model.parameters())
        ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate)
        state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0)

        sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
        resolution = config.data.image_size
        input_size = config.data.input_size
        # Load on CPU, then move to the same device the model is restored on.
        grid_mask = torch.load(f'./data/grid_mask_{resolution}.pt', map_location='cpu')
        grid_mask = grid_mask.view(1, 1, resolution, resolution, resolution).to(device)
        grid_mask = grid_mask[:, :, :input_size, :input_size, :input_size]

        sampling_eps = 1e-3
        sampling_shape = (config.eval.batch_size,
                          config.data.num_channels,
                          input_size, input_size, input_size)
        sampling_fn = sampling.get_sampling_fn(config, sde, sampling_shape, inverse_scaler, sampling_eps, grid_mask=grid_mask)

        if not os.path.exists(ckpt_path):
            raise FileNotFoundError(f"checkpoint not found: {ckpt_path}")
        print('ckpt path:', ckpt_path)
        state = restore_checkpoint(ckpt_path, state, device=device)
        ema.copy_to(score_model.parameters())
        # Use the same accumulation field as the other generators; fall back to
        # `iter_size`, which this function originally read.
        acc_steps = getattr(config.training, 'num_grad_acc_steps', None)
        if acc_steps is None:
            acc_steps = config.training.iter_size
        print(f"loaded model is trained till iter {state['step'] // acc_steps}")

        save_file_path = os.path.join(eval_dir, f"{save_fname}.npy")

        # Conditional but free gradients; start from small t.
        # Loaded on CPU so the values can be scattered into the CPU grids below.
        partial_dict = torch.load(config.eval.partial_dmtet_path, map_location='cpu')
        partial_sdf = partial_dict['sdf']
        partial_mask = partial_dict['vis']

        # Map tet vertex indices to integer cubic-grid coordinates. The spacing
        # dx is the gap between the two smallest distinct coordinate values
        # (assumes a uniformly spaced vertex lattice).
        tet = np.load(config.eval.tet_path)
        vertices = torch.tensor(tet['vertices'])
        vertices_unique = vertices.unique()
        dx = vertices_unique[1] - vertices_unique[0]
        ind_to_coord = torch.round((vertices - vertices.min()) / dx).long()

        grid_shape = (1, 1, resolution, resolution, resolution)
        coords = (ind_to_coord[:, 0], ind_to_coord[:, 1], ind_to_coord[:, 2])
        partial_sdf_grid = torch.zeros(grid_shape)
        partial_sdf_grid[(0, 0) + coords] = partial_sdf
        partial_mask_grid = torch.zeros(grid_shape)
        partial_mask_grid[(0, 0) + coords] = partial_mask.float()

        samples, _ = sampling_fn(
            score_model,
            partial=partial_sdf_grid.to(device),
            partial_mask=partial_mask_grid.to(device),
            freeze_iters=config.eval.freeze_iters,
        )
        np.save(save_file_path, samples.cpu().numpy())
================================================
FILE: GMeshDiffusion/lib/diffusion/likelihood.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: skip-file
# pytype: skip-file
"""Various sampling methods."""
import torch
import numpy as np
from scipy import integrate
from .models import utils as mutils
def get_div_fn(fn):
    """Wrap `fn` into a Hutchinson-Skilling divergence estimator.

    The returned `div_fn(x, t, eps)` computes `sum(eps * d(fn(x, t) . eps)/dx)`
    over every non-batch dimension, i.e. one `eps^T J eps` estimate per batch
    element. `x` has its `requires_grad` flag switched on for the
    vector-Jacobian product and switched off again afterwards.
    """
    def div_fn(x, t, eps):
        with torch.enable_grad():
            x.requires_grad_(True)
            projected = (fn(x, t) * eps).sum()
            (vjp,) = torch.autograd.grad(projected, x)
        x.requires_grad_(False)
        non_batch_dims = tuple(range(1, x.dim()))
        return (vjp * eps).sum(dim=non_batch_dims)

    return div_fn
def get_likelihood_fn(sde, inverse_scaler, hutchinson_type='Rademacher',
                      rtol=1e-5, atol=1e-5, method='RK45', eps=1e-5):
    """Create a function to compute the unbiased log-likelihood estimate of a given data point.

    The returned function integrates the probability-flow ODE of `sde` with
    `scipy.integrate.solve_ivp`. It accumulates the divergence of the drift
    (estimated with the Hutchinson-Skilling trick) alongside the sample.

    Args:
      sde: A `sde_lib.SDE` object that represents the forward SDE.
      inverse_scaler: The inverse data normalizer. Only used for the bits/dim
        offset `7. - inverse_scaler(-1.)`.
      hutchinson_type: "Rademacher" or "Gaussian". The type of noise for Hutchinson-Skilling trace estimator.
      rtol: A `float` number. The relative tolerance level of the black-box ODE solver.
      atol: A `float` number. The absolute tolerance level of the black-box ODE solver.
      method: A `str`. The algorithm for the black-box ODE solver.
        See documentation for `scipy.integrate.solve_ivp`.
      eps: A `float` number. The probability flow ODE is integrated from `eps` (not 0) to `sde.T`
        for numerical stability.

    Returns:
      A function that takes a model and a batch of data points and returns the log-likelihoods
      in bits/dim, the latent code, and the number of function evaluations cost by computation.
    """

    def drift_fn(model, x, t):
        """The drift function of the reverse-time SDE."""
        # The score-function wrapper is rebuilt on every call.
        score_fn = mutils.get_score_fn(sde, model, train=False, continuous=True)
        # Probability flow ODE is a special case of Reverse SDE
        rsde = sde.reverse(score_fn, probability_flow=True)
        return rsde.sde(x, t)[0]

    def div_fn(model, x, t, noise):
        # Hutchinson-Skilling estimate of the divergence of the drift, using `noise` as the probe.
        return get_div_fn(lambda xx, tt: drift_fn(model, xx, tt))(x, t, noise)

    def likelihood_fn(model, data):
        """Compute an unbiased estimate to the log-likelihood in bits/dim.

        Args:
          model: A score model.
          data: A PyTorch tensor.

        Returns:
          bpd: A PyTorch tensor of shape [batch size]. The log-likelihoods on `data` in bits/dim.
          z: A PyTorch tensor of the same shape as `data`. The latent representation of `data` under the
            probability flow ODE.
          nfe: An integer. The number of function evaluations used for running the black-box ODE solver.
        """
        with torch.no_grad():
            shape = data.shape
            # One fixed probe vector is drawn per call and reused at every ODE step.
            if hutchinson_type == 'Gaussian':
                epsilon = torch.randn_like(data)
            elif hutchinson_type == 'Rademacher':
                # Values in {-1, +1}.
                epsilon = torch.randint_like(data, low=0, high=2).float() * 2 - 1.
            else:
                raise NotImplementedError(f"Hutchinson type {hutchinson_type} unknown.")

            # ODE state layout (flat numpy vector): the flattened sample followed
            # by one log-density accumulator per batch element (last shape[0] entries).
            def ode_func(t, x):
                sample = mutils.from_flattened_numpy(x[:-shape[0]], shape).to(data.device).type(torch.float32)
                vec_t = torch.ones(sample.shape[0], device=sample.device) * t
                drift = mutils.to_flattened_numpy(drift_fn(model, sample, vec_t))
                logp_grad = mutils.to_flattened_numpy(div_fn(model, sample, vec_t, epsilon))
                return np.concatenate([drift, logp_grad], axis=0)

            # Start from the data with zero accumulated log-density change.
            init = np.concatenate([mutils.to_flattened_numpy(data), np.zeros((shape[0],))], axis=0)
            solution = integrate.solve_ivp(ode_func, (eps, sde.T), init, rtol=rtol, atol=atol, method=method)
            nfe = solution.nfev
            # Final state at t = sde.T; unpack with the same layout as ode_func.
            zp = solution.y[:, -1]
            z = mutils.from_flattened_numpy(zp[:-shape[0]], shape).to(data.device).type(torch.float32)
            delta_logp = mutils.from_flattened_numpy(zp[-shape[0]:], (shape[0],)).to(data.device).type(torch.float32)
            prior_logp = sde.prior_logp(z)
            # Convert nats to bits and normalize by the number of dimensions per example.
            bpd = -(prior_logp + delta_logp) / np.log(2)
            N = np.prod(shape[1:])
            bpd = bpd / N
            # A hack to convert log-likelihoods to bits/dim
            # (with an identity inverse_scaler this offset equals 8).
            offset = 7. - inverse_scaler(-1.)
            bpd = bpd + offset
            return bpd, z, nfe

    return likelihood_fn
================================================
FILE: GMeshDiffusion/lib/diffusion/losses.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""All functions related to loss computation and optimization.
"""
import torch
import torch.optim as optim
import numpy as np
from .models import utils as mutils
def get_optimizer(config, params):
    """Return the torch optimizer named by `config.optim.optimizer`.

    'Adam' and 'AdamW' are supported. Both receive `lr`, `betas=(beta1, 0.999)`,
    `eps` and `weight_decay` from `config.optim`.

    Raises:
        NotImplementedError: for any other optimizer name.
    """
    factories = {'Adam': optim.Adam, 'AdamW': optim.AdamW}
    hp = config.optim
    factory = factories.get(hp.optimizer)
    if factory is None:
        raise NotImplementedError(
            f'Optimizer {config.optim.optimizer} not supported yet!')
    return factory(params, lr=hp.lr, betas=(hp.beta1, 0.999), eps=hp.eps,
                   weight_decay=hp.weight_decay)
def optimization_manager(config):
    """Return an `optimize_fn` closure whose defaults come from `config.optim`.

    `lr`, `warmup` and `grad_clip` defaults are read from `config.optim` once,
    when this function is called.
    """
    def optimize_fn(optimizer, params, step, lr=config.optim.lr,
                    warmup=config.optim.warmup,
                    grad_clip=config.optim.grad_clip,
                    gradscaler=None):
        """Take one optimizer step with linear warmup and gradient-norm clipping.

        Args:
            optimizer: torch optimizer; every param group's lr is overwritten
                during warmup.
            params: parameters whose gradient norm is clipped.
            step: current step; the lr is `lr * min(step / warmup, 1)`.
            lr: peak learning rate.
            warmup: warmup length in steps; <= 0 disables the lr update.
            grad_clip: max gradient norm; negative disables clipping.
            gradscaler: optional AMP gradient scaler. When given, gradients are
                unscaled before clipping, and stepping goes through
                `gradscaler.step` / `gradscaler.update`. When None,
                `optimizer.step()` is called directly.
        """
        if warmup > 0:
            for g in optimizer.param_groups:
                g['lr'] = lr * np.minimum(step / warmup, 1.0)
        if grad_clip >= 0:
            if gradscaler is not None:
                # Scaled gradients must be unscaled before their norm is clipped.
                gradscaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(params, max_norm=grad_clip)
        if gradscaler is not None:
            gradscaler.step(optimizer)
            gradscaler.update()
        else:
            optimizer.step()
    return optimize_fn
def get_ddpm_loss_fn(vpsde, train, loss_type='l2', pred_type='noise', use_vis_mask=False, use_occ=False, use_aux=False):
    """Build the DDPM denoising loss.

    Legacy code to reproduce previous results on DDPM. Not recommended for new work.

    Args:
      vpsde: SDE object providing `N` (number of discrete timesteps) and the
        `sqrt_alphas_cumprod` / `sqrt_1m_alphas_cumprod` tables indexed by timestep.
      train: forwarded to `mutils.get_model_fn`.
      loss_type: 'l2' or 'l1'. Only 'l2' is supported when `use_occ` is True.
      pred_type: 'noise' if the network predicts the added noise, 'x0' if it
        predicts the clean sample; either way a noise estimate is compared to the
        true noise.
      use_vis_mask: (grid-only branch) multiply the error by a visibility mask
        derived from the x0 estimate; requires batch size 1.
      use_occ: if True, `batch` is a dict with 'grid' and 'occgrid' tensors that
        are noised with the same timesteps and denoised jointly.
      use_aux: (occupancy branch only) add a visibility regression term weighted
        by `sqrt_alphas_cumprod` of each sample's timestep.

    Returns:
      `loss_fn(model, batch, ...)` returning `(total_loss, loss, reg_loss)`. The
      extra keyword arguments of `loss_fn` are accepted for compatibility but unused.

    `loss_fn` raises NotImplementedError for an unsupported `loss_type` or
    `pred_type`, or when the model's `feature_mask` is None.
    """

    def _split_prediction(pred, perturbed, alphas1, alphas2):
        # Turn the raw network output into (noise estimate, x0 estimate) using
        # perturbed = alphas1 * x0 + alphas2 * noise.
        if pred_type == 'noise':
            return pred, (perturbed - pred * alphas2) / alphas1
        if pred_type == 'x0':
            return (perturbed - pred * alphas1) / alphas2, pred
        raise NotImplementedError(f'Unsupported pred_type: {pred_type}')

    if use_occ:
        def loss_fn(model, batch, use_mesh_reg=False, verts_discretiezd=None, midpoints_discretiezd=None, edges=None):
            batch, batch_occ = batch['grid'], batch['occgrid']
            model_fn = mutils.get_model_fn(model, train=train)
            labels = torch.randint(0, vpsde.N, (batch.shape[0],), device=batch.device)
            sqrt_alphas_cumprod = vpsde.sqrt_alphas_cumprod.to(batch.device)
            sqrt_1m_alphas_cumprod = vpsde.sqrt_1m_alphas_cumprod.to(batch.device)
            # Per-sample coefficients, broadcast over the 4 trailing grid dims.
            alphas1 = sqrt_alphas_cumprod[labels, None, None, None, None]
            alphas2 = sqrt_1m_alphas_cumprod[labels, None, None, None, None]

            # Noise the grid and the occupancy grid with the same timesteps.
            with torch.no_grad():
                noise = torch.randn_like(batch, device=batch.device)
                noise_occ = torch.randn_like(batch_occ, device=batch.device)
                perturbed_data = (alphas1 * batch + alphas2 * noise).type(batch.dtype)
                perturbed_data_occ = (alphas1 * batch_occ + alphas2 * noise_occ).type(batch_occ.dtype)

            with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                pred, pred_occ = model_fn((perturbed_data, perturbed_data_occ), labels)
            pred, pred_occ = pred.float(), pred_occ.float()

            score, x0 = _split_prediction(pred, perturbed_data, alphas1, alphas2)
            score_occ, x0_occ = _split_prediction(pred_occ, perturbed_data_occ, alphas1, alphas2)

            if loss_type != 'l2':
                # Only the squared error is implemented for the joint grid/occupancy objective.
                raise NotImplementedError(f'loss_type {loss_type!r} is not supported with use_occ=True')
            losses = torch.square(score - noise)
            losses_occ = torch.square(score_occ - noise_occ)
            assert losses_occ.size(1) == 1

            # NOTE(review): masks are read from `model.module`, so the model is
            # presumably wrapped (e.g. DDP) -- confirm against the caller.
            mask = model.module.feature_mask
            occ_mask = model.module.occ_mask
            if mask is None:
                raise NotImplementedError('an unmasked loss is not implemented')
            assert len(mask.size()) == 5
            assert mask.size(1) == losses.size(1)
            assert occ_mask.size(1) == losses_occ.size(1)
            assert losses.size(0) == losses_occ.size(0)
            losses = losses * mask
            losses_occ = losses_occ * occ_mask
            # Masked mean over both grids, divided by the batch size (assumes the
            # masks are shared across the batch -- TODO confirm).
            loss = (torch.sum(losses) + torch.sum(losses_occ)) / (mask.sum() + occ_mask.sum()) / losses.size(0)

            if use_aux:
                pred_vis = model.module.extract_vis_from_cubicgrid(x0, x0_occ)
                with torch.no_grad():
                    gt_vis = model.module.extract_vis_from_cubicgrid(batch, batch_occ.view(*x0_occ.size()))
                # Per-sample MSE, down-weighted for noisier timesteps.
                reg_loss = (
                    (pred_vis - gt_vis).pow(2).view(x0.size(0), -1).mean(dim=-1) * sqrt_alphas_cumprod[labels]
                ).mean()
            else:
                reg_loss = torch.zeros_like(loss)
            total_loss = loss + reg_loss
            return total_loss, loss, reg_loss
    else:
        def loss_fn(model, batch, use_mesh_reg=False, verts_discretiezd=None, midpoints_discretiezd=None, edges=None):
            model_fn = mutils.get_model_fn(model, train=train)
            labels = torch.randint(0, vpsde.N, (batch.shape[0],), device=batch.device)
            sqrt_alphas_cumprod = vpsde.sqrt_alphas_cumprod.to(batch.device)
            sqrt_1m_alphas_cumprod = vpsde.sqrt_1m_alphas_cumprod.to(batch.device)
            noise = torch.randn_like(batch, device=batch.device)
            alphas1 = sqrt_alphas_cumprod[labels, None, None, None, None]
            alphas2 = sqrt_1m_alphas_cumprod[labels, None, None, None, None]
            perturbed_data = (alphas1 * batch + alphas2 * noise).type(batch.dtype)

            with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                pred = model_fn(perturbed_data, labels)
            pred = pred.float()

            score, x0 = _split_prediction(pred, perturbed_data, alphas1, alphas2)

            error = score - noise
            if use_vis_mask:
                assert x0.size(0) == 1
                # NOTE(review): called on `model`, whereas the masks below use
                # `model.module` -- confirm which object defines this method.
                vis_mask = model.extract_vismask_from_cubicgrid(x0)
                error = error * vis_mask
            if loss_type == 'l2':
                losses = torch.square(error)
            elif loss_type == 'l1':
                losses = torch.abs(error)
            else:
                raise NotImplementedError(f'Unsupported loss_type: {loss_type}')

            mask = model.module.feature_mask
            if mask is None:
                raise NotImplementedError('an unmasked loss is not implemented')
            assert len(mask.size()) == 5
            assert mask.size(1) == losses.size(1)
            losses = losses * mask
            loss = torch.sum(losses) / mask.sum() / losses.size(0)
            reg_loss = torch.zeros_like(loss)
            return loss, loss, reg_loss
    return loss_fn
def get_step_fn(sde, train, optimize_fn=None, loss_type='l2', pred_type='noise', use_vis_mask=False, use_occ=False, use_aux=False):
    """Create a one-step training/evaluation function.

    Args:
      sde: the forward SDE object, passed to `get_ddpm_loss_fn`.
      train: if True the step back-propagates and (optionally) updates parameters;
        otherwise it evaluates the loss with EMA weights under `torch.no_grad()`.
      optimize_fn: an optimization function (e.g. from `optimization_manager`);
        required when `train` is True and parameters are updated.
      loss_type, pred_type, use_vis_mask, use_occ, use_aux: forwarded to
        `get_ddpm_loss_fn`.

    Returns:
      `step_fn(state, batch, clear_grad=True, update_param=True, gradscaler=None)`.
    """
    loss_fn = get_ddpm_loss_fn(sde, train, loss_type=loss_type, pred_type=pred_type, use_vis_mask=use_vis_mask, use_occ=use_occ, use_aux=use_aux)

    def step_fn(state, batch, clear_grad=True, update_param=True, gradscaler=None):
        """Run one step of training or evaluation.

        Args:
          state: dict with 'model' and 'ema', plus 'optimizer' and 'step' when training.
          batch: a mini-batch of training/evaluation data.
          clear_grad: zero gradients before the backward pass; pass False to
            accumulate gradients over several calls.
          update_param: run `optimize_fn`, increment `state['step']` and update the EMA.
          gradscaler: optional AMP grad scaler; when None the loss is
            back-propagated without scaling.

        Returns:
          dict with 'loss_total', 'loss_score' and 'loss_reg'.
        """
        model = state['model']
        if train:
            optimizer = state['optimizer']
            if clear_grad:
                optimizer.zero_grad()
            loss_total, loss_score, loss_reg = loss_fn(model, batch)
            if gradscaler is None:
                loss_total.backward()
            else:
                gradscaler.scale(loss_total).backward()
            if update_param:
                optimize_fn(optimizer, model.parameters(), step=state['step'], gradscaler=gradscaler)
                state['step'] += 1
                state['ema'].update(model.parameters())
        else:
            with torch.no_grad():
                ema = state['ema']
                ema.store(model.parameters())
                ema.copy_to(model.parameters())
                try:
                    loss_total, loss_score, loss_reg = loss_fn(model, batch)
                finally:
                    # Always put the raw (non-EMA) weights back, even if evaluation fails.
                    ema.restore(model.parameters())
        return {
            'loss_total': loss_total,
            'loss_score': loss_score,
            'loss_reg': loss_reg,
        }
    return step_fn
================================================
FILE: GMeshDiffusion/lib/diffusion/models/__init__.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
================================================
FILE: GMeshDiffusion/lib/diffusion/models/ema.py
================================================
# Modified from https://raw.githubusercontent.com/fadel/pytorch_ema/master/torch_ema/ema.py
from __future__ import division
from __future__ import unicode_literals
import torch
# Partially based on: https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/moving_averages.py
class ExponentialMovingAverage:
    """
    Maintains (exponential) moving average of a set of parameters.

    Only parameters with `requires_grad=True` are tracked by `__init__`, `update`
    and `copy_to`; `store`/`restore` operate on every parameter passed in, so pass
    the same parameter iterable to both.
    """

    def __init__(self, parameters, decay, use_num_updates=True):
        """
        Args:
          parameters: Iterable of `torch.nn.Parameter`; usually the result of
            `model.parameters()`.
          decay: The exponential decay, in [0, 1].
          use_num_updates: Whether to use number of updates when computing
            averages (lowers the effective decay during the first updates).

        Raises:
          ValueError: if `decay` is outside [0, 1].
        """
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')
        self.decay = decay
        self.num_updates = 0 if use_num_updates else None
        self.shadow_params = [p.clone().detach()
                              for p in parameters if p.requires_grad]
        self.collected_params = []

    def update(self, parameters):
        """
        Update currently maintained parameters.

        Call this every time the parameters are updated, such as the result of
        the `optimizer.step()` call.

        Args:
          parameters: Iterable of `torch.nn.Parameter`; usually the same set of
            parameters used to initialize this object.
        """
        decay = self.decay
        if self.num_updates is not None:
            self.num_updates += 1
            # Smaller effective decay early on so the average follows the weights quickly.
            decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates))
        one_minus_decay = 1.0 - decay
        with torch.no_grad():
            trainable = [p for p in parameters if p.requires_grad]
            for s_param, param in zip(self.shadow_params, trainable):
                # s <- decay * s + (1 - decay) * p, done in place.
                s_param.sub_(one_minus_decay * (s_param - param))

    def copy_to(self, parameters):
        """
        Copy current parameters into given collection of parameters.

        Args:
          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            updated with the stored moving averages.
        """
        trainable = [p for p in parameters if p.requires_grad]
        for s_param, param in zip(self.shadow_params, trainable):
            param.data.copy_(s_param.data)

    def store(self, parameters):
        """
        Save the current parameters for restoring later.

        Args:
          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            temporarily stored.
        """
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.

        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.

        Args:
          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)

    def state_dict(self):
        """Return decay, update count and the (live, not copied) shadow tensors."""
        return dict(decay=self.decay, num_updates=self.num_updates,
                    shadow_params=self.shadow_params)

    def load_state_dict(self, state_dict, device='cuda'):
        """Load a dict produced by `state_dict`, moving shadow tensors to `device`."""
        self.decay = state_dict['decay']
        self.num_updates = state_dict['num_updates']
        # Build a fresh list so the caller's `state_dict` is not mutated in place.
        self.shadow_params = [p.to(device) for p in state_dict['shadow_params']]
================================================
FILE: GMeshDiffusion/lib/diffusion/models/functional.py
================================================
#################################################################################################
# Copyright (c) 2023 Ali Hassani.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
#################################################################################################
import torch
from torch.autograd import Function
from torch.cuda.amp import custom_bwd, custom_fwd
# NATTEN's compiled extension is required; re-raise with installation guidance,
# chaining the original ImportError so its cause stays in the traceback.
try:
    from natten import _C
except ImportError as err:
    raise ImportError(
        "Failed to import NATTEN's CPP backend. "
        "This could be due to an invalid/incomplete install. "
        "Please uninstall NATTEN (pip uninstall natten) and re-install with the"
        " correct torch build: "
        "shi-labs.com/natten"
    ) from err
def has_cuda():
    """Query NATTEN's compiled backend via `_C.has_cuda()`."""
    return _C.has_cuda()


def has_half():
    """Query NATTEN's compiled backend via `_C.has_half()`."""
    return _C.has_half()


def has_bfloat():
    """Query NATTEN's compiled backend via `_C.has_bfloat()`."""
    return _C.has_bfloat()


def has_gemm():
    """Query NATTEN's compiled backend via `_C.has_gemm()`."""
    return _C.has_gemm()


def enable_tf32():
    """Switch the backend's GEMM TF32 flag on (`_C.set_gemm_tf32(True)`)."""
    return _C.set_gemm_tf32(True)


def disable_tf32():
    """Switch the backend's GEMM TF32 flag off (`_C.set_gemm_tf32(False)`)."""
    return _C.set_gemm_tf32(False)


def enable_tiled_na():
    """Switch the backend's tiled-NA flag on (`_C.set_tiled_na(True)`)."""
    return _C.set_tiled_na(True)


def disable_tiled_na():
    """Switch the backend's tiled-NA flag off (`_C.set_tiled_na(False)`)."""
    return _C.set_tiled_na(False)


def enable_gemm_na():
    """Switch the backend's GEMM-NA flag on (`_C.set_gemm_na(True)`)."""
    return _C.set_gemm_na(True)


def disable_gemm_na():
    """Switch the backend's GEMM-NA flag off (`_C.set_gemm_na(False)`)."""
    return _C.set_gemm_na(False)
class NeighborhoodAttention1DQKAutogradFunction(Function):
    """Autograd bridge for NATTEN's 1-D query-key kernels (`_C.na1d_qk_*`).

    `rpb` (optional bias) may be None. Backward returns gradients for query, key
    and rpb, and None for the integer hyper-parameters.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, query, key, rpb, kernel_size, dilation):
        query = query.contiguous()
        key = key.contiguous()
        if rpb is not None:
            # Align the bias dtype with the keys (as the 2-D variant does) so
            # mixed-precision inputs do not reach the kernel with mismatched dtypes.
            rpb = rpb.to(key.dtype)
        attn = _C.na1d_qk_forward(query, key, rpb, kernel_size, dilation)
        ctx.save_for_backward(query, key)
        ctx.kernel_size = kernel_size
        ctx.dilation = dilation
        ctx.bias = rpb is not None
        return attn

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_out):
        outputs = _C.na1d_qk_backward(
            grad_out.contiguous(),
            ctx.saved_tensors[0],
            ctx.saved_tensors[1],
            ctx.bias,
            ctx.kernel_size,
            ctx.dilation,
        )
        d_query, d_key, d_rpb = outputs
        return d_query, d_key, d_rpb, None, None
class NeighborhoodAttention1DAVAutogradFunction(Function):
    """Autograd bridge for NATTEN's 1-D attention-times-value kernels (`_C.na1d_av_*`).

    Backward returns gradients for attn and value, and None for the integer
    hyper-parameters.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, attn, value, kernel_size, dilation):
        # Align the attention dtype with the values (as the 2-D variant does).
        attn = attn.contiguous().to(value.dtype)
        value = value.contiguous()
        out = _C.na1d_av_forward(attn, value, kernel_size, dilation)
        ctx.save_for_backward(attn, value)
        ctx.kernel_size = kernel_size
        ctx.dilation = dilation
        return out

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_out):
        outputs = _C.na1d_av_backward(
            grad_out.contiguous(),
            ctx.saved_tensors[0],
            ctx.saved_tensors[1],
            ctx.kernel_size,
            ctx.dilation,
        )
        d_attn, d_value = outputs
        return d_attn, d_value, None, None
class NeighborhoodAttention2DQKAutogradFunction(Function):
    """Autograd bridge for NATTEN's 2-D query-key kernels (`_C.na2d_qk_*`).

    The optional bias `rpb` is cast to the key dtype before the kernel call.
    Backward yields gradients for query, key and rpb; the integer
    hyper-parameters get None.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, query, key, rpb, kernel_size, dilation):
        q = query.contiguous()
        k = key.contiguous()
        bias = None if rpb is None else rpb.to(k.dtype)
        scores = _C.na2d_qk_forward(q, k, bias, kernel_size, dilation)
        ctx.save_for_backward(q, k)
        ctx.kernel_size = kernel_size
        ctx.dilation = dilation
        ctx.bias = bias is not None
        return scores

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_out):
        q, k = ctx.saved_tensors
        d_query, d_key, d_rpb = _C.na2d_qk_backward(
            grad_out.contiguous(), q, k, ctx.bias, ctx.kernel_size, ctx.dilation
        )
        return d_query, d_key, d_rpb, None, None
class NeighborhoodAttention2DAVAutogradFunction(Function):
    """Autograd bridge for NATTEN's 2-D attention-times-value kernels (`_C.na2d_av_*`).

    The attention tensor is cast to the value dtype before the kernel call.
    Backward yields gradients for attn and value; the integer hyper-parameters
    get None.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, attn, value, kernel_size, dilation):
        weights = attn.contiguous().to(value.dtype)
        values = value.contiguous()
        result = _C.na2d_av_forward(weights, values, kernel_size, dilation)
        ctx.save_for_backward(weights, values)
        ctx.kernel_size = kernel_size
        ctx.dilation = dilation
        return result

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_out):
        weights, values = ctx.saved_tensors
        d_attn, d_value = _C.na2d_av_backward(
            grad_out.contiguous(), weights, values, ctx.kernel_size, ctx.dilation
        )
        return d_attn, d_value, None, None
class NeighborhoodAttention3DQKAutogradFunction(Function):
    """Autograd bridge for NATTEN's 3-D query-key kernels (`_C.na3d_qk_*`).

    `kernel_size_d`/`dilation_d` are the depth-axis hyper-parameters, passed to
    the kernel after `kernel_size`/`dilation`. Backward returns gradients for
    query, key and rpb, and None for the four integer hyper-parameters.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, query, key, rpb, kernel_size_d, kernel_size, dilation_d, dilation):
        query = query.contiguous()
        key = key.contiguous()
        if rpb is not None:
            # Align the bias dtype with the keys (as the 2-D variant does) so
            # mixed-precision inputs do not reach the kernel with mismatched dtypes.
            rpb = rpb.to(key.dtype)
        attn = _C.na3d_qk_forward(
            query, key, rpb, kernel_size, dilation, kernel_size_d, dilation_d
        )
        ctx.save_for_backward(query, key)
        ctx.kernel_size_d = kernel_size_d
        ctx.kernel_size = kernel_size
        ctx.dilation_d = dilation_d
        ctx.dilation = dilation
        ctx.bias = rpb is not None
        return attn

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_out):
        outputs = _C.na3d_qk_backward(
            grad_out.contiguous(),
            ctx.saved_tensors[0],
            ctx.saved_tensors[1],
            ctx.bias,
            ctx.kernel_size,
            ctx.dilation,
            ctx.kernel_size_d,
            ctx.dilation_d,
        )
        d_query, d_key, d_rpb = outputs
        return d_query, d_key, d_rpb, None, None, None, None
class NeighborhoodAttention3DAVAutogradFunction(Function):
    """Autograd bridge for NATTEN's 3-D attention-times-value kernels (`_C.na3d_av_*`).

    Backward returns gradients for attn and value, and None for the four integer
    hyper-parameters.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, attn, value, kernel_size_d, kernel_size, dilation_d, dilation):
        # Align the attention dtype with the values (as the 2-D variant does).
        attn = attn.contiguous().to(value.dtype)
        value = value.contiguous()
        out = _C.na3d_av_forward(
            attn, value, kernel_size, dilation, kernel_size_d, dilation_d
        )
        ctx.save_for_backward(attn, value)
        ctx.kernel_size_d = kernel_size_d
        ctx.kernel_size = kernel_size
        ctx.dilation_d = dilation_d
        ctx.dilation = dilation
        return out

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_out):
        outputs = _C.na3d_av_backward(
            grad_out.contiguous(),
            ctx.saved_tensors[0],
            ctx.saved_tensors[1],
            ctx.kernel_size,
            ctx.dilation,
            ctx.kernel_size_d,
            ctx.dilation_d,
        )
        d_attn, d_value = outputs
        return d_attn, d_value, None, None, None, None
def natten1dqkrpb(query, key, rpb, kernel_size, dilation):
    """1-D neighborhood query-key product with bias `rpb`."""
    apply_qk = NeighborhoodAttention1DQKAutogradFunction.apply
    return apply_qk(query, key, rpb, kernel_size, dilation)


def natten1dqk(query, key, kernel_size, dilation):
    """1-D neighborhood query-key product without bias."""
    return natten1dqkrpb(query, key, None, kernel_size, dilation)


def natten1dav(attn, value, kernel_size, dilation):
    """1-D neighborhood attention applied to `value`."""
    apply_av = NeighborhoodAttention1DAVAutogradFunction.apply
    return apply_av(attn, value, kernel_size, dilation)
def natten2dqkrpb(query, key, rpb, kernel_size, dilation):
    """2-D neighborhood query-key product with bias `rpb`."""
    apply_qk = NeighborhoodAttention2DQKAutogradFunction.apply
    return apply_qk(query, key, rpb, kernel_size, dilation)


def natten2dqk(query, key, kernel_size, dilation):
    """2-D neighborhood query-key product without bias."""
    return natten2dqkrpb(query, key, None, kernel_size, dilation)


def natten2dav(attn, value, kernel_size, dilation):
    """2-D neighborhood attention applied to `value`."""
    apply_av = NeighborhoodAttention2DAVAutogradFunction.apply
    return apply_av(attn, value, kernel_size, dilation)
def natten3dqkrpb(query, key, rpb, kernel_size_d, kernel_size, dilation_d, dilation):
    """3-D neighborhood query-key product with bias `rpb` (depth params given separately)."""
    apply_qk = NeighborhoodAttention3DQKAutogradFunction.apply
    return apply_qk(query, key, rpb, kernel_size_d, kernel_size, dilation_d, dilation)


def natten3dqk(query, key, kernel_size_d, kernel_size, dilation_d, dilation):
    """3-D neighborhood query-key product without bias."""
    return natten3dqkrpb(query, key, None, kernel_size_d, kernel_size, dilation_d, dilation)


def natten3dav(attn, value, kernel_size_d, kernel_size, dilation_d, dilation):
    """3-D neighborhood attention applied to `value`."""
    apply_av = NeighborhoodAttention3DAVAutogradFunction.apply
    return apply_av(attn, value, kernel_size_d, kernel_size, dilation_d, dilation)
================================================
FILE: GMeshDiffusion/lib/diffusion/models/layers.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: skip-file
"""Common layers for defining score networks.
"""
import math
import string
from functools import partial
import torch.nn as nn
import torch
import torch.nn.functional as F
import numpy as np
from .normalization import ConditionalInstanceNorm3dPlus
class GroupNormFloat32(nn.GroupNorm):
def forward(self, input):
with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=False):
return F.group_norm(
input.float(), self.num_groups, self.weight, self.bias, self.eps)
def get_act_fn(act_name):
    """Return a new activation module for a (case-insensitive) config name.

    Accepted names: 'elu', 'relu', 'lrelu' (slope 0.2), 'swish'/'silu'.
    """
    key = act_name.lower()
    if key in ('swish', 'silu'):
        return nn.SiLU()
    factories = {
        'elu': nn.ELU,
        'relu': nn.ReLU,
        'lrelu': lambda: nn.LeakyReLU(negative_slope=0.2),
    }
    if key not in factories:
        raise NotImplementedError('activation function does not exist!')
    return factories[key]()
def variance_scaling(scale, mode, distribution,
                     in_axis=1, out_axis=0,
                     dtype=torch.float32,
                     device='cpu'):
    """Ported from JAX.

    Returns `init(shape, dtype=..., device=...)`, which samples a tensor with
    variance `scale / denominator`. The denominator is fan_in, fan_out or their
    mean depending on `mode`. `distribution` is 'normal' or 'uniform'; invalid
    values raise ValueError when `init` is called.
    """

    def _compute_fans(shape, in_axis=1, out_axis=0):
        # Elements per (in, out) channel pair, e.g. the kernel volume for convs.
        receptive = np.prod(shape) / shape[in_axis] / shape[out_axis]
        return shape[in_axis] * receptive, shape[out_axis] * receptive

    def init(shape, dtype=dtype, device=device):
        fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
        if mode == "fan_avg":
            denominator = (fan_in + fan_out) / 2
        elif mode == "fan_out":
            denominator = fan_out
        elif mode == "fan_in":
            denominator = fan_in
        else:
            raise ValueError(
                "invalid mode for variance scaling initializer: {}".format(mode))
        variance = scale / denominator
        if distribution == "normal":
            return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)
        if distribution == "uniform":
            # U(-a, a) has variance a^2 / 3, hence a = sqrt(3 * variance).
            unit = torch.rand(*shape, dtype=dtype, device=device) * 2. - 1.
            return unit * np.sqrt(3 * variance)
        raise ValueError("invalid distribution for variance scaling initializer")

    return init
def default_init(scale=1.):
    """The same initialization used in DDPM: fan-avg uniform variance scaling.

    A `scale` of exactly 0 is replaced by 1e-10, giving a near-zero init.
    """
    effective_scale = scale if scale != 0 else 1e-10
    return variance_scaling(effective_scale, 'fan_avg', 'uniform')
class Dense(nn.Module):
    """Empty placeholder module: it defines no layers or parameters."""

    def __init__(self):
        super(Dense, self).__init__()
def _ddpm_init_(conv, init_scale):
    """Apply DDPM weight init to `conv` in place, zero its bias if present, and return it."""
    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
    # `bias=False` leaves conv.bias as None; zeros_(None) would crash.
    if conv.bias is not None:
        nn.init.zeros_(conv.bias)
    return conv


def conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1., padding=0):
    """1x1x1 3-D convolution with DDPM initialization."""
    conv = nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
    return _ddpm_init_(conv, init_scale)


def conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=1):
    """3x3x3 3-D convolution with DDPM initialization."""
    conv = nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride, padding=padding,
                     dilation=dilation, bias=bias)
    return _ddpm_init_(conv, init_scale)


def conv5x5(in_planes, out_planes, stride=2, bias=True, dilation=1, init_scale=1., padding=2):
    """5x5x5 3-D convolution (default stride 2) with DDPM initialization."""
    conv = nn.Conv3d(in_planes, out_planes, kernel_size=5, stride=stride, padding=padding,
                     dilation=dilation, bias=bias)
    return _ddpm_init_(conv, init_scale)


def conv3x3_transposed(in_planes, out_planes, stride=2, bias=True, dilation=1, init_scale=1., padding=1):
    """3x3x3 transposed 3-D convolution (default stride 2) with DDPM initialization."""
    conv = nn.ConvTranspose3d(in_planes, out_planes, kernel_size=3, stride=stride, padding=padding,
                              dilation=dilation, bias=bias)
    return _ddpm_init_(conv, init_scale)


def conv5x5_transposed(in_planes, out_planes, stride=2, bias=True, dilation=1, init_scale=1., padding=2):
    """5x5x5 transposed 3-D convolution (default stride 2) with DDPM initialization."""
    conv = nn.ConvTranspose3d(in_planes, out_planes, kernel_size=5, stride=stride, padding=padding,
                              dilation=dilation, bias=bias)
    return _ddpm_init_(conv, init_scale)
###########################################################################
# Functions below are ported over from the DDPM codebase:
# https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py
###########################################################################
def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
    """Sinusoidal timestep embedding, as in DDPM / Transformers.

    Args:
      timesteps: 1-D tensor of timesteps.
      embedding_dim: output width; an odd width gets one trailing zero column.
      max_positions: base of the geometric frequency progression.

    Returns:
      Tensor of shape (len(timesteps), embedding_dim): sines of the scaled
      timesteps in the first half, cosines in the second.
    """
    with torch.no_grad():
        assert len(timesteps.shape) == 1
        half_dim = embedding_dim // 2
        # Frequencies decay geometrically from 1 to 1/max_positions. The max()
        # guard avoids a ZeroDivisionError when half_dim == 1 (embedding_dim 2 or 3);
        # a single frequency of 1 is used then.
        log_step = math.log(max_positions) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -log_step)
        args = timesteps[:, None] * freqs[None, :]
        emb = torch.cat([torch.sin(args), torch.cos(args)], dim=1)
        if embedding_dim % 2 == 1:  # zero pad
            emb = F.pad(emb, (0, 1), mode='constant')
        assert emb.shape == (timesteps.shape[0], embedding_dim)
        return emb
class AttnBlock(nn.Module):
    """Single-head self-attention across all spatial positions of a (B, C, D, H, W) map.

    Queries, keys and values come from 1x1x1 convs over the group-normalised input.
    The output projection (`NIN_3`) has a near-zero init, and the result is added
    residually to the input.
    """

    def __init__(self, channels, num_groups=32):
        super().__init__()
        self.GroupNorm_0 = GroupNormFloat32(num_groups=num_groups, num_channels=channels, eps=1e-6)
        self.NIN_0 = conv1x1(channels, channels)
        self.NIN_1 = conv1x1(channels, channels)
        self.NIN_2 = conv1x1(channels, channels)
        self.NIN_3 = conv1x1(channels, channels, init_scale=0.)

    def forward(self, x):
        B, C, D, H, W = x.shape
        h = self.GroupNorm_0(x)
        q = self.NIN_0(h)
        k = self.NIN_1(h)
        v = self.NIN_2(h)
        # Attention is computed entirely in float32 with CUDA autocast disabled.
        # `v` may come out of the 1x1 conv in reduced precision under autocast,
        # so it is cast too; both einsum operands then share a dtype.
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=False):
            w = torch.einsum('bcdhw,bckij->bdhwkij', q.float(), k.float()) * (int(C) ** (-0.5))
            w = torch.reshape(w, (B, D, H, W, D * H * W))
            w = F.softmax(w, dim=-1)
            w = torch.reshape(w, (B, D, H, W, D, H, W))
            h = torch.einsum('bdhwkij,bckij->bcdhw', w, v.float())
        h = self.NIN_3(h)
        return x + h
class Upsample(nn.Module):
    """2x nearest-neighbour upsampling of a 5-D tensor, optionally followed by a 3x3x3 conv.

    The input is cast to float32 before interpolation.
    """

    def __init__(self, channels, with_conv=False):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            self.Conv_0 = conv3x3(channels, channels)

    def forward(self, x, temb=None):
        _, _, depth, height, width = x.shape
        target = (2 * depth, 2 * height, 2 * width)
        out = F.interpolate(x.float(), target, mode='nearest')
        return self.Conv_0(out) if self.with_conv else out
class Downsample(nn.Module):
    """Halve each spatial dim of a 5-D tensor by a strided 3x3x3 conv or 2x2x2 average pooling."""

    def __init__(self, channels, with_conv=False):
        super().__init__()
        self.with_conv = with_conv
        if with_conv:
            self.Conv_0 = conv3x3(channels, channels, stride=2, padding=0)

    def forward(self, x, temb=None):
        batch, channels, depth, height, width = x.shape
        if not self.with_conv:
            out = F.avg_pool3d(x, kernel_size=2, stride=2, padding=0)
        else:
            # One voxel of padding at the far end of each spatial axis emulates 'SAME' padding.
            out = self.Conv_0(F.pad(x, (0, 1, 0, 1, 0, 1)))
        assert out.shape == (batch, channels, depth // 2, height // 2, width // 2)
        return out
class ResBlock(nn.Module):
    """The ResNet Blocks used in DDPM.

    The residual branch is GroupNorm -> act -> 3x3x3 conv, then an optional
    per-channel bias from the time embedding, then GroupNorm -> act -> dropout ->
    3x3x3 conv with near-zero init. It is added to the input, which is projected
    first when the channel counts differ: with a 3x3x3 conv if `conv_shortcut`,
    otherwise with a 1x1x1 conv.
    """

    def __init__(self, act_fn, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False, dropout=0.1, num_groups=32):
        super().__init__()
        out_ch = in_ch if out_ch is None else out_ch
        # Submodules are created in a fixed order: it fixes parameters() order and
        # the RNG draws of each initializer.
        self.GroupNorm_0 = GroupNormFloat32(num_groups=num_groups, num_channels=in_ch, eps=1e-6)
        self.act = act_fn
        self.Conv_0 = conv3x3(in_ch, out_ch)
        if temb_dim is not None:
            self.Dense_0 = nn.Linear(temb_dim, out_ch)
            self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
            nn.init.zeros_(self.Dense_0.bias)
        self.GroupNorm_1 = GroupNormFloat32(num_groups=num_groups, num_channels=out_ch, eps=1e-6)
        self.Dropout_0 = nn.Dropout(dropout)
        self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=0.)
        if in_ch != out_ch:
            if conv_shortcut:
                self.Conv_2 = conv3x3(in_ch, out_ch)
            else:
                self.NIN_0 = conv1x1(in_ch, out_ch)
        self.out_ch = out_ch
        self.in_ch = in_ch
        self.conv_shortcut = conv_shortcut

    def forward(self, x, temb=None):
        _, channels, _, _, _ = x.shape
        assert channels == self.in_ch
        target_ch = self.out_ch or self.in_ch
        h = self.Conv_0(self.act(self.GroupNorm_0(x)))
        if temb is not None:
            # Broadcast the projected embedding over the spatial dims.
            h += self.Dense_0(self.act(temb))[:, :, None, None, None]
        h = self.Conv_1(self.Dropout_0(self.act(self.GroupNorm_1(h))))
        if channels == target_ch:
            return x + h
        skip = self.Conv_2(x) if self.conv_shortcut else self.NIN_0(x)
        return skip + h
class AttnResBlock(ResBlock):
    """A DDPM `ResBlock` whose output is passed through an `AttnBlock`."""

    def __init__(self, act_fn, in_ch, out_ch, temb_dim=None, conv_shortcut=False, dropout=0.1, num_groups=32):
        super().__init__(act_fn, in_ch, out_ch=out_ch, temb_dim=temb_dim,
                         conv_shortcut=conv_shortcut, dropout=dropout, num_groups=num_groups)
        self.attn_block = AttnBlock(out_ch, num_groups=num_groups)

    def forward(self, x, temb=None):
        return self.attn_block(super().forward(x, temb))
================================================
FILE: GMeshDiffusion/lib/diffusion/models/normalization.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Normalization layers."""
import torch.nn as nn
import torch
import functools
def get_normalization(config, conditional=False):
    """Return the normalization class (or partial) named by `config.model.normalization`.

    With `conditional=True` only 'InstanceNorm++' is supported; anything else
    raises NotImplementedError. Unknown unconditional names raise ValueError.
    """
    norm = config.model.normalization
    if conditional:
        if norm != 'InstanceNorm++':
            raise NotImplementedError(f'{norm} not implemented yet.')
        return functools.partial(ConditionalInstanceNorm3dPlus, num_classes=config.model.num_classes)
    if norm == 'InstanceNorm':
        return nn.InstanceNorm3d
    if norm == 'InstanceNorm++':
        return InstanceNorm3dPlus
    if norm == 'VarianceNorm':
        return VarianceNorm3d
    if norm == 'GroupNorm':
        return nn.GroupNorm
    raise ValueError('Unknown normalization: %s' % norm)
class ConditionalBatchNorm3d(nn.Module):
    """BatchNorm3d without affine, followed by a per-class scale and optional shift.

    The class embedding holds `[scale | shift]` when `bias` is True, otherwise
    only the scale. Scales start as U(0, 1) samples; shifts start at 0.
    """

    def __init__(self, num_features, num_classes, bias=True):
        super().__init__()
        self.num_features = num_features
        self.bias = bias
        self.bn = nn.BatchNorm3d(num_features, affine=False)
        width = 2 * num_features if bias else num_features
        self.embed = nn.Embedding(num_classes, width)
        table = self.embed.weight.data
        if bias:
            table[:, :num_features].uniform_()
            table[:, num_features:].zero_()
        else:
            table.uniform_()

    def forward(self, x, y):
        normed = self.bn(x)
        shape = (-1, self.num_features, 1, 1, 1)
        if not self.bias:
            return self.embed(y).view(*shape) * normed
        gamma, beta = self.embed(y).chunk(2, dim=1)
        return gamma.view(*shape) * normed + beta.view(*shape)
class ConditionalInstanceNorm3d(nn.Module):
    """InstanceNorm3d without affine, followed by a per-class scale and optional shift.

    The class embedding holds `[scale | shift]` when `bias` is True, otherwise
    only the scale. Scales start as U(0, 1) samples; shifts start at 0.
    """

    def __init__(self, num_features, num_classes, bias=True):
        super().__init__()
        self.num_features = num_features
        self.bias = bias
        self.instance_norm = nn.InstanceNorm3d(num_features, affine=False, track_running_stats=False)
        width = 2 * num_features if bias else num_features
        self.embed = nn.Embedding(num_classes, width)
        table = self.embed.weight.data
        if bias:
            table[:, :num_features].uniform_()
            table[:, num_features:].zero_()
        else:
            table.uniform_()

    def forward(self, x, y):
        normed = self.instance_norm(x)
        shape = (-1, self.num_features, 1, 1, 1)
        if not self.bias:
            return self.embed(y).view(*shape) * normed
        gamma, beta = self.embed(y).chunk(2, dim=-1)
        return gamma.view(*shape) * normed + beta.view(*shape)
class ConditionalVarianceNorm3d(nn.Module):
    """Divide by the per-channel spatial std (no mean removal), then apply a per-class scale.

    Scales start near 1 (N(1, 0.02)). `bias` is stored but not used.
    """

    def __init__(self, num_features, num_classes, bias=False):
        super().__init__()
        self.num_features = num_features
        self.bias = bias
        self.embed = nn.Embedding(num_classes, num_features)
        self.embed.weight.data.normal_(1, 0.02)

    def forward(self, x, y):
        spatial_var = torch.var(x, dim=(2, 3, 4), keepdim=True)
        normed = x / torch.sqrt(spatial_var + 1e-5)
        scale = self.embed(y).view(-1, self.num_features, 1, 1, 1)
        return scale * normed
class VarianceNorm3d(nn.Module):
    """Divide by the per-channel spatial std and apply a learned channel scale.

    The scale ``alpha`` has one entry per channel and is initialised from
    N(1, 0.02). No mean is subtracted.

    Args:
      num_features: channel count ``C`` of the input.
      bias: stored on the module only; ``forward`` does not use it.
    """

    def __init__(self, num_features, bias=False):
        super().__init__()
        self.num_features = num_features
        self.bias = bias
        self.alpha = nn.Parameter(torch.zeros(num_features))
        self.alpha.data.normal_(1, 0.02)

    def forward(self, x):
        spatial_var = x.var(dim=(2, 3, 4), keepdim=True)
        rescaled = x / (spatial_var + 1e-5).sqrt()
        channel_scale = self.alpha.view(-1, self.num_features, 1, 1, 1)
        return channel_scale * rescaled
class ConditionalNoneNorm3d(nn.Module):
    """Apply a per-label channel scale (and optional shift) without normalizing.

    Args:
      num_features: channel count ``C`` of the input.
      num_classes: number of rows in the label embedding.
      bias: if True, the embedding also stores a per-label shift.
    """

    def __init__(self, num_features, num_classes, bias=True):
        super().__init__()
        self.num_features = num_features
        self.bias = bias
        # Row layout: [scale | shift] with bias, [scale] without.
        self.embed = nn.Embedding(num_classes, 2 * num_features if bias else num_features)
        # Scales start from U[0, 1) (uniform_), shifts from zero.
        self.embed.weight.data[:, :num_features].uniform_()
        if bias:
            self.embed.weight.data[:, num_features:].zero_()

    def forward(self, x, y):
        per_label = self.embed(y)
        bshape = (-1, self.num_features, 1, 1, 1)
        if self.bias:
            scale, shift = per_label.chunk(2, dim=-1)
            return scale.view(bshape) * x + shift.view(bshape)
        return per_label.view(bshape) * x
class NoneNorm3d(nn.Module):
    """Identity layer used where a normalization layer is optional.

    ``num_features`` and ``bias`` are accepted only so the constructor matches
    the other normalization classes; they are ignored.
    """

    def __init__(self, num_features, bias=True):
        super().__init__()

    def forward(self, x):
        # Pass the input through unchanged (same object, no copy).
        return x
class InstanceNorm3dPlus(nn.Module):
    """Instance norm that re-injects standardized channel means (NCSNv2 style).

    Instance normalization discards each channel's spatial mean. This layer
    standardizes those means across channels and adds them back with a
    learned per-channel weight ``alpha``. It then applies a learned scale
    ``gamma`` and, when ``bias`` is True, a shift ``beta``.

    Args:
      num_features: channel count ``C`` of the input.
      bias: if True, learn the additive shift ``beta`` (initialised to zero).
    """

    def __init__(self, num_features, bias=True):
        super().__init__()
        self.num_features = num_features
        self.bias = bias
        self.instance_norm = nn.InstanceNorm3d(num_features, affine=False, track_running_stats=False)
        self.alpha = nn.Parameter(torch.zeros(num_features))
        self.gamma = nn.Parameter(torch.zeros(num_features))
        # Both weights start near 1: N(1, 0.02).
        self.alpha.data.normal_(1, 0.02)
        self.gamma.data.normal_(1, 0.02)
        if bias:
            self.beta = nn.Parameter(torch.zeros(num_features))

    def forward(self, x):
        # Spatial mean per (sample, channel), then standardized over channels.
        chan_means = x.mean(dim=(2, 3, 4))
        mu = chan_means.mean(dim=-1, keepdim=True)
        var = chan_means.var(dim=-1, keepdim=True)
        chan_means = (chan_means - mu) / torch.sqrt(var + 1e-5)

        h = self.instance_norm(x)
        h = h + chan_means[..., None, None, None] * self.alpha[..., None, None, None]
        out = self.gamma.view(-1, self.num_features, 1, 1, 1) * h
        if self.bias:
            out = out + self.beta.view(-1, self.num_features, 1, 1, 1)
        return out
class ConditionalInstanceNorm3dPlus(nn.Module):
    """Label-conditioned variant of InstanceNorm3dPlus.

    ``gamma`` (scale), ``alpha`` (weight on the standardized channel means)
    and, when ``bias`` is True, ``beta`` (shift) are read per label from
    ``self.embed``. Row layout is ``[gamma | alpha | beta]``, or
    ``[gamma | alpha]`` without bias.

    Args:
      num_features: channel count ``C`` of the input.
      num_classes: number of rows in the label embedding.
      bias: if True, the embedding also stores a per-label shift.
    """

    def __init__(self, num_features, num_classes, bias=True):
        super().__init__()
        self.num_features = num_features
        self.bias = bias
        self.instance_norm = nn.InstanceNorm3d(num_features, affine=False, track_running_stats=False)
        self.embed = nn.Embedding(num_classes, (3 if bias else 2) * num_features)
        # gamma and alpha start at N(1, 0.02); beta (if present) at zero.
        self.embed.weight.data[:, :2 * num_features].normal_(1, 0.02)
        if bias:
            self.embed.weight.data[:, 2 * num_features:].zero_()

    def forward(self, x, y):
        # Spatial mean per (sample, channel), then standardized over channels.
        chan_means = x.mean(dim=(2, 3, 4))
        mu = chan_means.mean(dim=-1, keepdim=True)
        var = chan_means.var(dim=-1, keepdim=True)
        chan_means = (chan_means - mu) / torch.sqrt(var + 1e-5)

        h = self.instance_norm(x)
        pieces = self.embed(y).chunk(3 if self.bias else 2, dim=-1)
        gamma, alpha = pieces[0], pieces[1]
        h = h + chan_means[..., None, None, None] * alpha[..., None, None, None]
        out = gamma.view(-1, self.num_features, 1, 1, 1) * h
        if self.bias:
            out = out + pieces[2].view(-1, self.num_features, 1, 1, 1)
        return out
================================================
FILE: GMeshDiffusion/lib/diffusion/models/unet3d_occgrid.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: skip-file
"""DDPM model.
This code is the pytorch equivalent of:
https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/models/unet.py
"""
import torch
import torch.nn as nn
import functools
import numpy as np
from . import utils
from .layers import ResBlock, AttnResBlock, Upsample, Downsample, conv1x1, conv3x3, conv5x5, get_act_fn, default_init, get_timestep_embedding, GroupNormFloat32
import sys
def str_to_class(classname):
    """Return the attribute called ``classname`` from this module's namespace.

    Used to turn block-type names from the config into the imported layer
    classes. Raises AttributeError if no such name exists in the module.
    """
    this_module = sys.modules[__name__]
    return getattr(this_module, classname)
@utils.register_model(name='unet3d_occgrid')
class UNet3D(nn.Module):
    """3D U-Net score model over a feature grid paired with an occupancy grid.

    ``forward`` takes ``x = (features, occ_grid)`` and returns
    ``(grid, grid_occ)``.
    - Feature tensors are multiplied by ``feature_mask``, loaded as shape
      ``(1, data_ch, 128, 128, 128)``.
    - Occupancy tensors are multiplied by ``occ_mask``, loaded as
      ``(1, 1, 256, 256, 256)``.
    - The occupancy grid enters through a stride-2 conv and is summed with
      the feature stem, so the U-Net body runs at the feature-grid
      resolution.
    - Occupancy leaves through a stride-2 transposed conv.

    Args:
      config: ConfigDict with ``model.*`` (architecture, mask paths) and
        ``data.*`` (``num_channels``, ``occ_mask_path``, ``tet_info_path``).
    """

    def __init__(self, config):
        super().__init__()
        self.act_fn = get_act_fn(config.model.act_fn)
        self.nf = nf = config.model.base_channels
        data_ch = config.data.num_channels
        ch_mult = config.model.ch_mult

        feature_mask = torch.load(config.model.feature_mask_path, map_location='cpu').view(1, data_ch, 128, 128, 128)
        pixcat_mask = torch.load(config.model.pixcat_mask_path, map_location='cpu').view(1, 1, 128, 128, 128)
        occ_mask = torch.load(config.data.occ_mask_path, map_location='cpu').view(1, 1, 256, 256, 256)

        # Tetrahedral-grid index tables.
        # NOTE(review): these are plain tensor attributes placed on the current
        # CUDA device with .cuda(). They are not registered buffers, so they do
        # not follow .to(device) and are not part of the state dict. forward()
        # below does not read them; other code may.
        tet_info = torch.load(config.data.tet_info_path)
        self.tet_edge_vpos = tet_info['tet_edge_vpos'].cuda()
        # Presumably (num_edges, 2 endpoints, 3 coords) -- TODO confirm.
        edge_pix_loc = tet_info['tet_edge_pix_loc'].cuda().view(-1, 2, 3)
        self.vis_edges = tet_info['vis_edges'].cuda()
        self.occ_edge_cano_order = tet_info['occ_edge_cano_order'].cuda()
        # Mean of each row's two points, truncated to integer coordinates.
        self.tet_edgenode_loc = edge_pix_loc.float().mean(dim=1).long()
        # Group the points 6 per row (6 edges per tet, presumably), keep the
        # columns selected by vis_edges, then regroup in pairs.
        self.occ_edge_loc = self.tet_edgenode_loc.view(-1, 6, 3)[:, self.vis_edges.view(-1)].view(-1, 2, 3)
        # Twice the mean of each pair, processed in groups of 12 pairs.
        self.occ_node_loc = (self.occ_edge_loc.view(-1, 12, 2, 3).float().mean(dim=-2) * 2.0).long().view(-1, 3)
        self.tet_edge_pix_loc = edge_pix_loc.view(-1, 3)

        # Frozen Parameters: they move with the module and are saved in
        # checkpoints, but receive no gradient.
        self.feature_mask = torch.nn.Parameter(feature_mask, requires_grad=False)
        self.pixcat_mask = torch.nn.Parameter(pixcat_mask, requires_grad=False)
        self.occ_mask = torch.nn.Parameter(occ_mask, requires_grad=False)

        self.down_block_types = config.model.down_block_types
        self.up_block_types = config.model.up_block_types
        self.num_res_blocks = config.model.num_res_blocks
        self.num_res_blocks_1st_layer = config.model.num_res_blocks_1st_layer
        resamp_with_conv = config.model.resamp_with_conv
        dropout = config.model.dropout
        assert len(self.down_block_types) == len(self.up_block_types)

        module_dict = {
            module: functools.partial(str_to_class(module), act_fn=self.act_fn, temb_dim=4 * nf, dropout=dropout)
            for module in ["ResBlock", "AttnResBlock"]
        }

        # Timestep embedding MLP: nf -> 4nf -> 4nf.
        noise_temb_layers = [nn.Linear(nf, nf * 4), nn.SiLU(), nn.Linear(nf * 4, nf * 4)]
        noise_temb_layers[0].weight.data = default_init()(noise_temb_layers[0].weight.data.shape)
        nn.init.zeros_(noise_temb_layers[0].bias)
        noise_temb_layers[2].weight.data = default_init()(noise_temb_layers[2].weight.data.shape)
        nn.init.zeros_(noise_temb_layers[2].bias)
        self.noise_temb_layers = nn.Sequential(*noise_temb_layers)

        # Input stems; their outputs are summed in forward().
        self.occ_conv = conv3x3(1, nf, stride=2, padding=1)
        self.occ_mask_conv = conv3x3(1, nf, stride=2, padding=1)
        self.mask_layer = conv5x5(1, nf, stride=1, padding=2)
        self.input_layer = conv5x5(data_ch, nf, stride=1, padding=2)

        # Encoder. hs_c records the channel count of every skip connection,
        # in the same order forward() pushes activations onto `hs`.
        hs_c = [nf]
        in_ch = nf
        modules = []
        for i_level, down_block_type in enumerate(self.down_block_types):
            curr_block = module_dict[down_block_type]
            num_res_blocks = self.num_res_blocks_1st_layer if i_level == 0 else self.num_res_blocks
            for _ in range(num_res_blocks):
                out_ch = nf * ch_mult[i_level]
                modules.append(curr_block(in_ch=in_ch, out_ch=out_ch))
                in_ch = out_ch
                hs_c.append(in_ch)
            if i_level != len(self.down_block_types) - 1:
                modules.append(Downsample(channels=in_ch, with_conv=resamp_with_conv))
                hs_c.append(in_ch)

        # Bottleneck.
        in_ch = hs_c[-1]
        modules.append(module_dict["AttnResBlock"](in_ch=in_ch, out_ch=in_ch))
        modules.append(module_dict["ResBlock"](in_ch=in_ch))

        # Decoder: each block concatenates one skip connection (popped from hs_c).
        for i_level, up_block_type in enumerate(self.up_block_types):
            curr_block = module_dict[up_block_type]
            num_res_blocks = self.num_res_blocks_1st_layer if i_level == len(self.up_block_types) - 1 else self.num_res_blocks
            for _ in range(num_res_blocks + 1):
                out_ch = nf * ch_mult[len(self.up_block_types) - i_level - 1]
                modules.append(curr_block(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
                in_ch = out_ch
            if i_level != len(self.up_block_types) - 1:
                modules.append(Upsample(channels=in_ch, with_conv=resamp_with_conv))

        self.all_modules = nn.ModuleList(modules)
        self.output_norm_layer = nn.Sequential(
            GroupNormFloat32(num_channels=in_ch, num_groups=32, eps=1e-6),
            nn.SiLU(),
        )
        self.output_layer = conv5x5(in_ch, data_ch, init_scale=0., stride=1, padding=2)
        # Stride-2 transposed conv maps back up to the occupancy-grid resolution.
        self.occ_output_layer = nn.ConvTranspose3d(in_ch, 1, 4, stride=2, padding=1)

    def sequentially_call_module(self, idx, x, temb=None):
        """Call ``self.all_modules[idx](x, temb)`` and return ``(idx + 1, output)``."""
        return idx + 1, self.all_modules[idx](x, temb)

    def forward(self, x, labels):
        """Predict the (feature, occupancy) outputs for a noisy input pair.

        Args:
          x: indexable pair ``(features, occ_grid)``; inputs are used as given
            (assumed already in [-1, 1]).
          labels: per-sample time-step labels, passed (as float) to
            ``get_timestep_embedding``.

        Returns:
          ``(grid, grid_occ)``, each multiplied by its mask.
        """
        with torch.no_grad():
            x, occ_grid = x[0], x[1]
            # Zero out cells that are unused by the representation.
            x = x * self.feature_mask
            occ_grid = occ_grid * self.occ_mask

        # Explicitly run in full precision even under an outer autocast.
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=False):
            temb = get_timestep_embedding(labels.float(), self.nf)
            temb = self.noise_temb_layers(temb)

            # Encoder; every intermediate activation is kept as a skip connection.
            hs = [self.input_layer(x) + self.mask_layer(self.pixcat_mask)
                  + self.occ_conv(occ_grid) + self.occ_mask_conv(self.occ_mask)]
            m_idx = 0
            for i_level in range(len(self.down_block_types)):
                num_res_blocks = self.num_res_blocks_1st_layer if i_level == 0 else self.num_res_blocks
                for _ in range(num_res_blocks):
                    m_idx, h = self.sequentially_call_module(m_idx, hs[-1], temb)
                    hs.append(h)
                if i_level != len(self.down_block_types) - 1:
                    m_idx, h = self.sequentially_call_module(m_idx, hs[-1])
                    hs.append(h)

            # Bottleneck (attention block, then residual block).
            h = hs[-1]
            m_idx, h = self.sequentially_call_module(m_idx, h, temb)
            m_idx, h = self.sequentially_call_module(m_idx, h, temb)

            # Decoder.
            for i_level in range(len(self.up_block_types)):
                num_res_blocks = self.num_res_blocks_1st_layer if i_level == len(self.up_block_types) - 1 else self.num_res_blocks
                for _ in range(num_res_blocks + 1):
                    h = torch.cat([h, hs.pop()], dim=1)
                    m_idx, h = self.sequentially_call_module(m_idx, h, temb)
                if i_level != len(self.up_block_types) - 1:
                    m_idx, h = self.sequentially_call_module(m_idx, h, temb)
            assert not hs

            h = self.output_norm_layer(h)
            grid = self.output_layer(h)
            grid_occ = self.occ_output_layer(h)

            # Mask out unused values.
            grid = grid * self.feature_mask
            grid_occ = grid_occ * self.occ_mask
        return grid, grid_occ
================================================
FILE: GMeshDiffusion/lib/diffusion/models/utils.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""All functions and modules related to model definition.
"""
import torch
from .. import sde_lib
import numpy as np
_MODELS = {}


def register_model(cls=None, *, name=None):
    """Register a model class in the module-level ``_MODELS`` registry.

    Usable bare (``@register_model``) or with an explicit key
    (``@register_model(name='foo')``). Without ``name`` the class's
    ``__name__`` is used as the key. The class itself is returned unchanged.

    Raises:
      ValueError: if the key is already registered.
    """

    def _register(model_cls):
        key = model_cls.__name__ if name is None else name
        if key in _MODELS:
            raise ValueError(f'Already registered model with name: {key}')
        _MODELS[key] = model_cls
        return model_cls

    return _register if cls is None else _register(cls)
def get_model(name):
    """Return the model class registered under ``name``.

    Raises:
      KeyError: if no model was registered with that name.
    """
    model_cls = _MODELS[name]
    return model_cls
def get_sigmas(config):
    """Return the SMLD noise levels described by ``config``.

    The levels are ``config.model.num_scales`` values, evenly spaced in log
    space from ``sigma_max`` down to ``sigma_min``.

    Args:
      config: ConfigDict-like object exposing ``config.model.sigma_max``,
        ``sigma_min`` and ``num_scales``.

    Returns:
      A NumPy array of noise levels, largest first.
    """
    log_high = np.log(config.model.sigma_max)
    log_low = np.log(config.model.sigma_min)
    log_levels = np.linspace(log_high, log_low, config.model.num_scales)
    return np.exp(log_levels)
def get_ddpm_params(config):
    """Return the linear beta schedule of the original DDPM paper.

    The schedule always has 1000 steps. Betas run linearly from
    ``beta_min / num_scales`` to ``beta_max / num_scales``. The returned
    ``beta_min``/``beta_max`` entries are those endpoints multiplied by 999.
    NOTE: the step count is fixed at 1000 regardless of ``num_scales``; the
    parameters need adapting if a different number of steps is used.

    Args:
      config: ConfigDict-like object exposing ``config.model.beta_min``,
        ``beta_max`` and ``num_scales``.

    Returns:
      dict with float64 NumPy arrays ``betas``, ``alphas``,
      ``alphas_cumprod``, ``sqrt_alphas_cumprod``, ``sqrt_1m_alphas_cumprod``,
      plus ``beta_min``, ``beta_max`` and ``num_diffusion_timesteps``.
    """
    n_steps = 1000
    scale = config.model.num_scales
    beta_first = config.model.beta_min / scale
    beta_last = config.model.beta_max / scale
    betas = np.linspace(beta_first, beta_last, n_steps, dtype=np.float64)
    alphas = 1. - betas
    alphas_cumprod = alphas.cumprod(axis=0)
    return {
        'betas': betas,
        'alphas': alphas,
        'alphas_cumprod': alphas_cumprod,
        'sqrt_alphas_cumprod': np.sqrt(alphas_cumprod),
        'sqrt_1m_alphas_cumprod': np.sqrt(1. - alphas_cumprod),
        'beta_min': beta_first * (n_steps - 1),
        'beta_max': beta_last * (n_steps - 1),
        'num_diffusion_timesteps': n_steps,
    }
def create_model(config, use_parallel=True, ddp=False, rank=None):
    """Instantiate the registered score model named by ``config.model.name``.

    Args:
      config: ConfigDict providing ``model.name`` and ``device``.
      use_parallel: if True, wrap the model for multi-GPU execution.
      ddp: when ``use_parallel`` is True, use DistributedDataParallel on
        ``rank`` instead of DataParallel.
      rank: device id for the DDP replica (only used when ``ddp`` is True).

    Returns:
      The (possibly wrapped) model, moved to its target device.
    """
    model_cls = get_model(config.model.name)
    score_model = model_cls(config)

    if not use_parallel:
        return score_model.to(config.device)
    if not ddp:
        return torch.nn.DataParallel(score_model).to(config.device)

    score_model = score_model.to(rank)
    return torch.nn.parallel.DistributedDataParallel(
        score_model,
        find_unused_parameters=False,
        gradient_as_bucket_view=True,
        device_ids=[rank],
    )
def get_model_fn(model, train=False):
    """Build a closure that runs ``model`` in train or eval mode.

    Args:
      model: the score model.
      train: if True the closure calls ``model.train()`` before every
        forward pass; otherwise it calls ``model.eval()``.

    Returns:
      ``model_fn(x, labels)``, which sets the mode and returns
      ``model(x, labels)``.
    """

    def model_fn(x, labels):
        """Set the configured mode, then return ``model(x, labels)``.

        ``labels`` are the per-sample conditioning values for time steps; how
        they are interpreted depends on the model.
        """
        (model.train if train else model.eval)()
        return model(x, labels)

    return model_fn
def get_reg_fn(model, train=False):
    """Build a closure returning the model's regularization term, if it has one.

    If ``model`` exposes a ``get_reg`` method, the closure returns
    ``model.get_reg(x)``. Otherwise it returns ``torch.zeros_like(x)``.

    Args:
      model: the score model.
      train: if True the closure calls ``model.train()`` before evaluating;
        otherwise it calls ``model.eval()``.

    Returns:
      ``reg_fn(x)``.
    """

    def model_fn(x):
        """Set the configured mode and return the regularizer for ``x``."""
        if train:
            model.train()
        else:
            model.eval()
        # A model without a regularizer contributes zero. Resolving the method
        # explicitly (instead of a bare ``except:``) stops genuine errors raised
        # inside ``get_reg`` -- and KeyboardInterrupt/SystemExit -- from being
        # silently turned into a zero regularizer.
        get_reg = getattr(model, 'get_reg', None)
        if get_reg is None:
            return torch.zeros_like(x)
        return get_reg(x)

    return model_fn
def get_score_fn(sde, model, train=False, continuous=False, std_scale=True, pred_type='noise'):
    """Wrap ``model`` so that its output can be used as a score for a VP SDE.

    The continuous time ``t`` is mapped to the discrete label
    ``t * (sde.N - 1)`` before the model is called. For VP-trained models,
    t=0 corresponds to the lowest noise level.

    Args:
      sde: an ``sde_lib.VPSDE``; other SDE classes are rejected.
      model: the score model.
      train: run the model in train mode (True) or eval mode (False).
      continuous: must be False; continuous-time models are not supported.
      std_scale: if True, the noise prediction ``eps`` is turned into
        ``-eps / sqrt(1 - alpha_bar_t)``. If False, ``eps`` is returned as is
        (used for DDIM sampling).
      pred_type: ``'noise'`` if the model predicts ``eps`` directly, or
        ``'x0'`` if it predicts the clean sample. In the x0 case the
        prediction is converted to ``eps`` via
        ``(x - sqrt(alpha_bar_t) * x0) / sqrt(1 - alpha_bar_t)``.

    Returns:
      ``score_fn(x, t)``.

    Raises:
      NotImplementedError: if ``continuous`` is True or ``sde`` is not a VPSDE.
      ValueError: if ``pred_type`` is neither ``'noise'`` nor ``'x0'``.
    """
    # Validate up front: an unknown pred_type used to surface only at call
    # time as an UnboundLocalError.
    if continuous:
        raise NotImplementedError('Continuous-time score models are not supported.')
    if not isinstance(sde, sde_lib.VPSDE):
        raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")
    if pred_type not in ('noise', 'x0'):
        raise ValueError(f"Unknown pred_type {pred_type!r}; expected 'noise' or 'x0'.")

    model_fn = get_model_fn(model, train=train)

    def _per_sample(table, idx):
        # Gather one schedule value per batch element on the labels' device and
        # shape it to broadcast against 5-D (B, C, D, H, W) grids.
        return table.to(idx.device)[idx].reshape(-1, 1, 1, 1, 1)

    def score_fn(x, t):
        labels = t * (sde.N - 1)
        pred = model_fn(x, labels)
        idx = labels.long()
        if pred_type == 'x0':
            # Convert the x0 prediction into the equivalent noise prediction.
            eps = (x - pred * _per_sample(sde.sqrt_alphas_cumprod, idx)) / _per_sample(sde.sqrt_1m_alphas_cumprod, idx)
        else:
            eps = pred
        if not std_scale:
            return eps
        return -eps / _per_sample(sde.sqrt_1m_alphas_cumprod, idx)

    return score_fn
def to_flattened_numpy(x):
    """Detach the torch tensor ``x``, move it to the CPU and return it as a 1-D NumPy array."""
    host_array = x.detach().cpu().numpy()
    return host_array.reshape(-1)
def from_flattened_numpy(x, shape):
    """Reshape the flat NumPy array ``x`` to ``shape`` and wrap it as a torch tensor.

    ``torch.from_numpy`` is used, so the result shares memory with the
    reshaped array.
    """
    reshaped = x.reshape(shape)
    return torch.from_numpy(reshaped)
================================================
FILE: GMeshDiffusion/lib/diffusion/sampling.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: skip-file
# pytype: skip-file
"""Various sampling methods."""
import functools
import torch
import numpy as np
import abc
from .models.utils import from_flattened_numpy, to_flattened_numpy, get_score_fn
from scipy import integrate
from . import sde_lib
from .models import utils as mutils
import logging
import tqdm
_CORRECTORS = {}
_PREDICTORS = {}


def register_predictor(cls=None, *, name=None):
    """Register a predictor class in ``_PREDICTORS``.

    Usable bare (``@register_predictor``) or with an explicit key
    (``@register_predictor(name='ddim')``). Without ``name`` the class's
    ``__name__`` is used. The class is returned unchanged.

    Raises:
      ValueError: if the key is already registered as a predictor.
    """

    def _register(predictor_cls):
        key = predictor_cls.__name__ if name is None else name
        if key in _PREDICTORS:
            raise ValueError(f'Already registered predictor with name: {key}')
        _PREDICTORS[key] = predictor_cls
        return predictor_cls

    return _register if cls is None else _register(cls)
def register_corrector(cls=None, *, name=None):
    """Register a corrector class in ``_CORRECTORS``.

    Usable bare (``@register_corrector``) or with an explicit key
    (``@register_corrector(name='langevin')``). Without ``name`` the class's
    ``__name__`` is used. The class is returned unchanged.

    Raises:
      ValueError: if the key is already registered as a corrector.
    """

    def _register(corrector_cls):
        key = corrector_cls.__name__ if name is None else name
        if key in _CORRECTORS:
            raise ValueError(f'Already registered corrector with name: {key}')
        _CORRECTORS[key] = corrector_cls
        return corrector_cls

    return _register if cls is None else _register(cls)
def get_predictor(name):
    """Return the predictor class registered under ``name`` (KeyError if unknown)."""
    predictor_cls = _PREDICTORS[name]
    return predictor_cls
def get_corrector(name):
    """Return the corrector class registered under ``name`` (KeyError if unknown)."""
    corrector_cls = _CORRECTORS[name]
    return corrector_cls
def get_sampling_fn(config, sde, shape, inverse_scaler, eps, grid_mask=None, return_traj=False, pred_type='noise'):
    """Create a sampling function selected by ``config.sampling.method``.

    Supported methods (case-insensitive):
      * ``'pc'``: predictor-corrector sampling via ``get_pc_sampler``.
        Predictor-only and corrector-only samplers are special cases.
      * ``'ddim'``: DDIM sampling via ``get_ddim_sampler``.

    Args:
      config: A `ml_collections.ConfigDict` object that contains all configuration information.
      sde: A `sde_lib.SDE` object that represents the forward SDE.
      shape: A sequence of integers representing the expected shape of a single sample.
      inverse_scaler: The inverse data normalizer function.
      eps: A `float` number. The reverse-time SDE is only integrated to `eps` for numerical stability.
      grid_mask: mask forwarded to the chosen sampler.
      return_traj: forwarded to the PC sampler only.
      pred_type: model output parameterization forwarded to the sampler.

    Returns:
      The sampling function built by the selected sampler factory.

    Raises:
      ValueError: if the sampler name is unknown.
    """
    sampler_name = config.sampling.method
    method = sampler_name.lower()

    if method == 'pc':
        # Keyword arguments are evaluated in order, so the predictor and
        # corrector lookups still happen before the other config reads.
        return get_pc_sampler(sde=sde,
                              shape=shape,
                              predictor=get_predictor(config.sampling.predictor.lower()),
                              corrector=get_corrector(config.sampling.corrector.lower()),
                              inverse_scaler=inverse_scaler,
                              snr=config.sampling.snr,
                              n_steps=config.sampling.n_steps_each,
                              probability_flow=config.sampling.probability_flow,
                              continuous=config.training.continuous,
                              denoise=config.sampling.noise_removal,
                              eps=eps,
                              device=config.device,
                              grid_mask=grid_mask,
                              return_traj=return_traj,
                              pred_type=pred_type,
                              use_occ=config.model.use_occ_grid)

    if method == 'ddim':
        ddim_predictor = get_predictor('ddim')
        return get_ddim_sampler(sde=sde,
                                shape=shape,
                                predictor=ddim_predictor,
                                inverse_scaler=inverse_scaler,
                                n_steps=config.sampling.n_steps_each,
                                denoise=config.sampling.noise_removal,
                                eps=eps,
                                device=config.device,
                                grid_mask=grid_mask,
                                pred_type=pred_type,
                                use_occ=config.model.use_occ_grid)

    raise ValueError(f"Sampler name {sampler_name} unknown.")
class Predictor(abc.ABC):
    """Abstract base class for the predictor step of a PC sampler.

    Stores the forward SDE and the score function, together with the
    reverse-time SDE/ODE built by ``sde.reverse(score_fn, probability_flow)``.
    """

    def __init__(self, sde, score_fn, probability_flow=False):
        super().__init__()
        self.sde = sde
        # Reverse-time SDE/ODE used by concrete update rules.
        self.rsde = sde.reverse(score_fn, probability_flow)
        self.score_fn = score_fn

    @abc.abstractmethod
    def update_fn(self, x, t):
        """Perform one predictor update.

        Args:
          x: A PyTorch tensor representing the current state.
          t: A PyTorch tensor representing the current time step.

        Returns:
          x: A PyTorch tensor of the next state.
          x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising.
        """
class Corrector(abc.ABC):
    """Abstract base class for the corrector step of a PC sampler.

    Stores the SDE, the score function, the target signal-to-noise ratio
    ``snr`` and the number of corrector iterations ``n_steps``.
    """

    def __init__(self, sde, score_fn, snr, n_steps):
        super().__init__()
        self.sde, self.score_fn = sde, score_fn
        self.snr, self.n_steps = snr, n_steps

    @abc.abstractmethod
    def update_fn(self, x, t):
        """Perform one corrector update.

        Args:
          x: A PyTorch tensor representing the current state.
          t: A PyTorch tensor representing the current time step.

        Returns:
          x: A PyTorch tensor of the next state.
          x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising.
        """
@register_predictor(name='euler_maruyama')
class EulerMaruyamaPredictor(Predictor):
    """Euler-Maruyama discretization of the reverse-time SDE."""

    def __init__(self, sde, score_fn, probability_flow=False):
        super().__init__(sde, score_fn, probability_flow)

    def update_fn(self, x, t):
        """Take one step of size ``1 / rsde.N`` backwards in time.

        Returns ``(x_next, x_mean)``. ``x_mean`` is the drift-only update and
        ``x_next`` adds Gaussian noise scaled by ``diffusion * sqrt(|dt|)``.
        """
        step = -1. / self.rsde.N
        noise = torch.randn_like(x)
        drift, diffusion = self.rsde.sde(x, t)
        x_mean = x + drift * step
        noise_scale = diffusion[:, None, None, None, None] * np.sqrt(-step)
        return x_mean + noise_scale * noise, x_mean
@register_predictor(name='reverse_diffusion')
class ReverseDiffusionPredictor(Predictor):
    """Predictor that uses the reverse SDE's own discretization."""

    def __init__(self, sde, score_fn, probability_flow=False):
        super().__init__(sde, score_fn, probability_flow)

    def update_fn(self, x, t):
        """Return ``(x_mean + G * z, x_mean)`` with ``x_mean = x - f``.

        ``f`` and ``G`` come from ``self.rsde.discretize(x, t)``, and ``z`` is
        standard Gaussian noise.
        """
        f, G = self.rsde.discretize(x, t)
        noise = torch.randn_like(x)
        x_mean = x - f
        return x_mean + G[:, None, None, None, None] * noise, x_mean
@register_predictor(name='ancestral_sampling')
class AncestralSamplingPredictor(Predictor):
    """The ancestral sampling predictor. Currently only supports VP SDEs."""

    def __init__(self, sde, score_fn, probability_flow=False):
        super().__init__(sde, score_fn, probability_flow)
        if not isinstance(sde, sde_lib.VPSDE):
            raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")
        assert not probability_flow, "Probability flow not supported by ancestral sampling"

    def vpsde_update_fn(self, x, t):
        """One ancestral step using the discrete beta at the current time step.

        The step computes ``x_mean = (x + beta * score) / sqrt(1 - beta)`` and
        ``x_next = x_mean + sqrt(beta) * z``, with ``z`` standard Gaussian.
        """
        sde = self.sde
        step_idx = (t * (sde.N - 1) / sde.T).long()
        beta = sde.discrete_betas.to(t.device)[step_idx]
        score = self.score_fn(x, t)
        x_mean = (x + beta[:, None, None, None, None] * score) / torch.sqrt(1. - beta)[:, None, None, None, None]
        noise = torch.randn_like(x)
        return x_mean + torch.sqrt(beta)[:, None, None, None, None] * noise, x_mean

    def update_fn(self, x, t):
        if not isinstance(self.sde, sde_lib.VPSDE):
            raise NotImplementedError
        return self.vpsde_update_fn(x, t)
@register_predictor(name='none')
class NonePredictor(Predictor):
    """No-op predictor. It is used when a sampler runs correctors only.

    The base-class initializer is intentionally skipped, so ``sde.reverse``
    is never called.
    """

    def __init__(self, sde, score_fn, probability_flow=False):
        pass

    def update_fn(self, x, t):
        # Leave the state unchanged; the "mean" is the state itself.
        return x, x
@register_predictor(name='ddim')
class DDIMPredictor(Predictor):
    """DDIM step, delegated to ``self.rsde.discretize_ddim``."""

    def __init__(self, sde, score_fn, probability_flow=False):
        super().__init__(sde, score_fn, probability_flow)

    def update_fn(self, x, t, tprev=None):
        """Return ``(x_next, x0_pred)`` from ``rsde.discretize_ddim(x, t, tprev=tprev)``."""
        return self.rsde.discretize_ddim(x, t, tprev=tprev)
@register_corrector(name='langevin')
class LangevinCorrector(Corrector):
    """Langevin MCMC corrector with an SNR-based step size (VP SDE only)."""

    def __init__(self, sde, score_fn, snr, n_steps):
        super().__init__(sde, score_fn, snr, n_steps)
        if not isinstance(sde, sde_lib.VPSDE):
            raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")

    def update_fn(self, x, t):
        """Run ``self.n_steps`` Langevin steps at time ``t``.

        Each step uses
        ``step_size = (snr * ||noise|| / ||grad||) ** 2 * 2 * alpha``, with
        norms averaged over the batch and ``alpha`` read from ``sde.alphas``
        at the current discrete time step.

        Returns:
          ``(x, x_mean)`` after the last step.
        """
        sde = self.sde
        if isinstance(sde, sde_lib.VPSDE):
            step_idx = (t * (sde.N - 1) / sde.T).long()
            alpha = sde.alphas.to(t.device)[step_idx]
        else:
            alpha = torch.ones_like(t)

        for _ in range(self.n_steps):
            grad = self.score_fn(x, t)
            noise = torch.randn_like(x)
            grad_norm = grad.reshape(grad.shape[0], -1).norm(dim=-1).mean()
            noise_norm = noise.reshape(noise.shape[0], -1).norm(dim=-1).mean()
            step_size = (self.snr * noise_norm / grad_norm) ** 2 * 2 * alpha
            x_mean = x + step_size[:, None, None, None, None] * grad
            x = x_mean + torch.sqrt(step_size * 2)[:, None, None, None, None] * noise
        return x, x_mean
@register_corrector(name='ald')
class AnnealedLangevinDynamics(Corrector):
    """The original annealed Langevin dynamics corrector from NCSN/NCSNv2.

    It is included only for completeness. The step size is derived from the
    marginal std of the SDE rather than from the score norm.
    """

    def __init__(self, sde, score_fn, snr, n_steps):
        super().__init__(sde, score_fn, snr, n_steps)
        if not isinstance(sde, sde_lib.VPSDE):
            raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")

    def update_fn(self, x, t):
        """Run ``self.n_steps`` annealed Langevin steps at time ``t``.

        Each step uses ``step_size = (snr * std) ** 2 * 2 * alpha``, where
        ``std`` is the second output of ``sde.marginal_prob(x, t)``.

        Returns:
          ``(x, x_mean)`` after the last step.
        """
        sde = self.sde
        if isinstance(sde, sde_lib.VPSDE):
            step_idx = (t * (sde.N - 1) / sde.T).long()
            alpha = sde.alphas.to(t.device)[step_idx]
        else:
            alpha = torch.ones_like(t)

        _, std = self.sde.marginal_prob(x, t)[:2]
        for _ in range(self.n_steps):
            grad = self.score_fn(x, t)
            noise = torch.randn_like(x)
            step_size = (self.snr * std) ** 2 * 2 * alpha
            x_mean = x + step_size[:, None, None, None, None] * grad
            x = x_mean + noise * torch.sqrt(step_size * 2)[:, None, None, None, None]
        return x, x_mean
@register_corrector(name='none')
class NoneCorrector(Corrector):
    """No-op corrector. It is used when a sampler runs predictors only.

    The base-class initializer is intentionally skipped; no state is stored.
    """

    def __init__(self, sde, score_fn, snr, n_steps):
        pass

    def update_fn(self, x, t):
        # Leave the state unchanged; the "mean" is the state itself.
        return x, x
def shared_predictor_update_fn(x, t, sde, model, predictor, probability_flow, continuous, pred_type='noise'):
    """Build a predictor for ``model`` and apply one update to ``(x, t)``.

    A fresh eval-mode score function is built on every call. If ``predictor``
    is None, ``NonePredictor`` is used (corrector-only sampling).

    Returns:
      The ``(x, x_mean)`` pair from the predictor's ``update_fn``.
    """
    score_fn = mutils.get_score_fn(sde, model, train=False, continuous=continuous, pred_type=pred_type)
    predictor_cls = NonePredictor if predictor is None else predictor
    predictor_obj = predictor_cls(sde, score_fn, probability_flow)
    return predictor_obj.update_fn(x, t)
def shared_corrector_update_fn(x, t, sde, model, corrector, continuous, snr, n_steps, pred_type='noise'):
    """Build a corrector for ``model`` and apply its update to ``(x, t)``.

    A fresh eval-mode score function is built on every call. If ``corrector``
    is None, ``NoneCorrector`` is used (predictor-only sampling).

    Returns:
      The ``(x, x_mean)`` pair from the corrector's ``update_fn``.
    """
    score_fn = mutils.get_score_fn(sde, model, train=False, continuous=continuous, pred_type=pred_type)
    corrector_cls = NoneCorrector if corrector is None else corrector
    corrector_obj = corrector_cls(sde, score_fn, snr, n_steps)
    return corrector_obj.update_fn(x, t)
def get_pc_sampler(sde, shape, predictor, corrector, inverse_scaler, snr,
n_steps=1, probability_flow=False, continuous=False,
denoise=True, eps=1e-3, device='cuda',
grid_mask=None, return_traj=False, pred_type='noise', use_occ=False):
"""Create a Predictor-Corrector (PC) sampler.
Args:
sde: An `sde_lib.SDE` object representing the forward SDE.
shape: A sequence of integers. The expected shape of a single sample.
predictor: A subclass of `sampling.Predictor` representing the predictor algorithm.
corrector: A subclass of `sampling.Corrector` representing the corrector algorithm.
inverse_scaler: The inverse data normalizer.
snr: A `float` number. The signal-to-noise ratio for configuring correctors.
n_steps: An integer. The number of corrector steps per predictor update.
probability_flow: If `True`, solve the reverse-time probability flow ODE when running the predictor.
continuous: `True` indicates that the score model was continuously trained.
denoise: If `True`, add one-step denoising to the final samples.
eps: A `float` number. The reverse-time SDE and ODE are integrated to `epsilon` to avoid numerical issues.
device: PyTorch device.
Returns:
A sampling function that returns samples and the number of function evaluations during sampling.
"""
# Create predictor & corrector update functions
predictor_update_fn = functools.partial(shared_predictor_update_fn,
sde=sde,
predictor=predictor,
probability_flow=probability_flow,
continuous=continuous,
pred_type=pred_type)
corrector_update_fn = functools.partial(shared_corrector_update_fn,
sde=sde,
corrector=corrector,
continuous=continuous,
snr=snr,
n_steps=n_steps,
pred_type=pred_type)
def pc_sampler(model,
partial=None, partial_grid_mask=None, partial_channel=0,
freeze_iters=None):
""" The PC sampler funciton.
Args:
model: A score model.
Returns:
Samples, number of function evaluations.
"""
with torch.no_grad():
if freeze_iters is None:
freeze_iters = sde.N + 10 # just some randomly large number greater than sde.N
timesteps = torch.linspace(sde.T, eps, sde.N, device=device)
mask = model.feature_mask
if pred_type == 'noise':
def compute_xzero(sde, model, x, t, grid_mask_input):
timestep_int = (t * (sde.N - 1) / sde.T).long()
alphas1 = sde.sqrt_alphas_cumprod[timestep_int].cuda()
alphas2 = sde.sqrt_1m_alphas_cumprod[timestep_int].cuda()
alphas1_prev = sde.sqrt_alphas_cumprod[timestep_int - 1].cuda()
alphas2_prev = sde.sqrt_1m_alphas_cumprod[timestep_int - 1].cuda()
score_pred = model(x, t * torch.ones(shape[0], device=x.device))
x0_pred_scaled = (x - alphas2 * score_pred)
x0_pred = x0_pred_scaled / alphas1
x0_pred = x0_pred.clamp(-1, 1)
return x0_pred * grid_mask_input
elif pred_type == 'x0':
def compute_xzero(sde, model, x, t, grid_mask_input):
timestep_int = (t * (sde.N - 1) / sde.T).long()
alphas1 = sde.sqrt_alphas_cumprod[timestep_int].cuda()
alphas2 = sde.sqrt_1m_alphas_cumprod[timestep_int].cuda()
alphas1_prev = sde.sqrt_alphas_cumprod[timestep_int - 1].cuda()
alphas2_prev = sde.sqrt_1m_alphas_cumprod[timestep_int - 1].cuda()
x0_pred = model(x, t * torch.ones(shape[0], device=x.device))
return x0_pred * grid_mask_input
# Initial sample
x = sde.prior_sampling(shape).to(device)
assert len(x.size()) == 5
traj_buffer = []
if partial is not None:
assert len(partial.size()) == 5
t = timesteps[0]
vec_t = torch.ones(shape[0], device=t.device) * t
x[:, partial_channel] = partial[:, partial_channel] * grid_mask[:, partial_channel]
partial_mean, partial_std = sde.marginal_prob(x, vec_t)
sampled_update = partial_mean[:, partial_channel] + partial_std[:, None, None, None, None] * torch.randn_like(partial_mean[:, partial_channel], device=partial_std.device)
x[:, partial_channel] = (
x[:, partial_channel] * (1 - partial_mask[:, partial_channel])
+ sampled_update[:, partial_channel] * partial_mask[:, partial_channel]
) * grid_mask[:, partial_channel]
if partial is not None:
x_mean = x
for i in tqdm.trange(sde.N):
t = timesteps[i]
vec_t = torch.ones(shape[0], device=t.device) * t
x, x_mean = corrector_update_fn(x, vec_t, model=model)
x, x_mean = x * grid_mask, x_mean * grid_mask
x, x_mean = predictor_update_fn(x, vec_t, model=model)
x, x_mean = x * grid_mask, x_mean * grid_mask
if i != sde.N - 1 and i < freeze_iters:
x[:, partial_channel] = (x[:, partial_channel] * (1 - partial_mask[:, partial_channel]) + partial[:, partial_channel] * partial_mask[:, partial_channel]) * grid_mask[:, partial_channel]
x_mean[:, partial_channel] = (x_mean[:, partial_channel] * (1 - partial_mask[:, partial_channel]) + partial[:, partial_channel] * partial_mask[:, partial_channel]) * grid_mask[:, partial_channel]
### add noise to the condition x0_star
partial_mean, partial_std = sde.marginal_prob(x, timesteps[i] * torch.ones(shape[0], device=t.device))
sampled_update = partial_mean[:, partial_channel] + partial_std[:, None, None, None] * torch.randn_like(partial_mean[:, partial_channel], device=partial_std.device)
x[:, partial_channel] = (
x[:, partial_channel] * (1 - partial_mask[:, partial_channel])
+ sampled_update * partial_mask[:, partial_channel]
) * grid_mask[:, partial_channel]
x_mean[:, partial_channel] = x[:, partial_channel]
else:
for i in tqdm.trange(sde.N - 1):
t = timesteps[i]
with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=False):
vec_t = torch.ones(shape[0], device=t.device) * t
x, x_mean = corrector_update_fn(x, vec_t, model=model)
x, x_mean = x * mask, x_mean * mask
x, x_mean = predictor_update_fn(x, vec_t, model=model)
x, x_mean = x * mask, x_mean * mask
print(x.min(), x.max())
if return_traj and i >= 700 and i % 10 == 0:
traj_buffer.append(compute_xzero(sde, model, x, t, grid_mask))
if return_traj:
return traj_buffer, sde.N * (n_steps + 1)
return inverse_scaler(x_mean * mask if denoise else x * mask), sde.N * (n_steps + 1)
return pc_sampler
def ddim_predictor_update_fn(x, t, tprev, sde, model, predictor, probability_flow, continuous, pred_type='noise'):
    """Perform one DDIM predictor step from time `t` to time `tprev`.

    A discrete-time score function is built from `model`. Then the predictor
    class (or `NonePredictor` when `predictor` is None) is instantiated with it,
    and its `update_fn(x, t, tprev)` result is returned unchanged.
    """
    # Only discretely-trained models are supported by the DDIM path.
    assert not continuous
    score_fn = mutils.get_score_fn(sde, model, train=False, continuous=False, std_scale=False, pred_type=pred_type)
    predictor_cls = NonePredictor if predictor is None else predictor
    step = predictor_cls(sde, score_fn, probability_flow)
    return step.update_fn(x, t, tprev)
def get_ddim_sampler(sde, shape, predictor, inverse_scaler, n_steps=1,
                     denoise=False, eps=1e-3, device='cuda', grid_mask=None, pred_type='noise', use_occ=False):
    """Create a deterministic DDIM sampler.

    Args:
      sde: An `sde_lib.SDE` object that represents the forward SDE.
      shape: A sequence of integers. The shape of the sampled batch (passed to
        `sde.prior_sampling`).
      predictor: Predictor class forwarded to `ddim_predictor_update_fn`, or
        `None` to use `NonePredictor`.
      inverse_scaler: The inverse data normalizer applied to the result.
      n_steps: Only used for the returned evaluation count.
      denoise: If `True`, return the last x0 prediction instead of the last iterate.
      eps: Unused; the timestep grid is built from `schedule`/`num_steps`.
        Kept for interface compatibility.
      device: PyTorch device.
      grid_mask: Unused; the model's `feature_mask` (and `occ_mask`) are used.
        Kept for interface compatibility.
      pred_type: Model parameterization forwarded to the score function.
      use_occ: If `True`, jointly sample the feature grid and an occupancy grid
        with twice its spatial resolution.

    Returns:
      A sampling function that returns samples and an evaluation count.
    """
    predictor_update_fn = functools.partial(ddim_predictor_update_fn,
                                            sde=sde,
                                            predictor=predictor,
                                            probability_flow=False,
                                            continuous=False,
                                            pred_type=pred_type)

    def _timestep_grid(schedule, num_steps):
        """Return the increasing normalized timesteps visited by the sampler."""
        if schedule == 'uniform':
            seq = range(0, sde.N, sde.N // num_steps)
        elif schedule == 'quad':
            # Quadratically spaced integer steps up to 80% of the chain.
            seq = np.linspace(0, np.sqrt(sde.N * 0.8), num_steps) ** 2
        else:
            raise ValueError(f"Unknown DDIM schedule: {schedule!r}")
        return torch.tensor([int(s) for s in seq], device=device) / sde.N

    def _paste_partial(sample, partial, partial_mask, partial_channel):
        """Overwrite the known region of `partial_channel` with `partial`, in place.

        For the joint (feature, occupancy) tuple only the feature grid is conditioned.
        """
        feat = sample[0] if isinstance(sample, tuple) else sample
        feat[:, partial_channel] = feat[:, partial_channel] * (1 - partial_mask) + partial * partial_mask

    def ddim_sampler(model, schedule='quad', num_steps=100, x0=None, x0_occ=None,
                     partial=None, partial_grid_mask=None, partial_channel=0):
        """Run DDIM sampling.

        Args:
          model: A score model exposing `feature_mask` (and `occ_mask` when
            `use_occ`).
          schedule: 'uniform' or 'quad' spacing of the visited timesteps.
          num_steps: Number of timesteps in the grid (at least 2).
          x0, x0_occ: Optional initial noise; drawn from the prior when omitted.
          partial: Known values for channel `partial_channel`. They are pasted
            into the sample wherever `partial_grid_mask` is 1.
          partial_grid_mask: Mask of the known region; required with `partial`.
          partial_channel: Feature channel that `partial` conditions.

        Returns:
          Samples (a tuple of feature and occupancy grids when `use_occ`) and
          the evaluation count `sde.N * (n_steps + 1)`. That count mirrors the
          PC sampler's accounting; it is not the number of DDIM steps taken.
        """
        if num_steps < 2:
            raise ValueError("num_steps must be at least 2")
        if partial is not None and partial_grid_mask is None:
            raise ValueError("`partial_grid_mask` is required when `partial` is given")
        with torch.no_grad():
            mask = model.feature_mask
            x = (x0 if x0 is not None else sde.prior_sampling(shape)).to(device)
            if use_occ:
                occ_mask = model.occ_mask.float()
                if x0_occ is None:
                    occ_shape = (x.size(0), 1, x.size(2) * 2, x.size(3) * 2, x.size(4) * 2)
                    x0_occ = sde.prior_sampling(occ_shape)
                x = (x.float() * mask, x0_occ.to(device).float() * occ_mask)
            if partial is not None:
                _paste_partial(x, partial, partial_grid_mask, partial_channel)

            timesteps = _timestep_grid(schedule, num_steps)
            for i in tqdm.tqdm(list(reversed(range(1, len(timesteps)))), leave=False):
                t = timesteps[i]
                tprev = timesteps[i - 1]
                vec_t = torch.ones(1, device=t.device) * t
                vec_tprev = torch.ones(1, device=t.device) * tprev
                with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=False):
                    x, x0_pred = predictor_update_fn(x, vec_t, model=model, tprev=vec_tprev)
                # Keep entries outside the valid grid at zero.
                if use_occ:
                    x = (x[0] * mask, x[1] * occ_mask)
                    x0_pred = (x0_pred[0] * mask, x0_pred[1] * occ_mask)
                else:
                    x, x0_pred = x * mask, x0_pred * mask
                if partial is not None:
                    _paste_partial(x, partial, partial_grid_mask, partial_channel)
                    _paste_partial(x0_pred, partial, partial_grid_mask, partial_channel)

            final = x0_pred if denoise else x
            nfe = sde.N * (n_steps + 1)
            if use_occ:
                return (inverse_scaler(final[0] * mask), inverse_scaler(final[1] * occ_mask)), nfe
            return inverse_scaler(final * mask), nfe

    return ddim_sampler
================================================
FILE: GMeshDiffusion/lib/diffusion/sde_lib.py
================================================
"""Abstract SDE classes, Reverse SDE, and VE/VP SDEs."""
import abc
import torch
import numpy as np
import torch.nn.functional as F
import time
class SDE(abc.ABC):
    """SDE abstract class. Functions are designed for a mini-batch of inputs."""

    def __init__(self, N):
        """Construct an SDE.

        Args:
          N: number of discretization time steps.
        """
        super().__init__()
        self.N = N

    @property
    @abc.abstractmethod
    def T(self):
        """End time of the SDE."""
        pass

    @abc.abstractmethod
    def sde(self, x, t):
        """Return the (drift, diffusion) coefficients of the forward SDE at (x, t)."""
        pass

    @abc.abstractmethod
    def marginal_prob(self, x, t):
        """Parameters to determine the marginal distribution of the SDE, $p_t(x)$."""
        pass

    @abc.abstractmethod
    def prior_sampling(self, shape):
        """Generate one sample from the prior distribution, $p_T(x)$."""
        pass

    @abc.abstractmethod
    def prior_logp(self, z):
        """Compute log-density of the prior distribution.

        Useful for computing the log-likelihood via probability flow ODE.

        Args:
          z: latent code
        Returns:
          log probability density
        """
        pass

    def discretize(self, x, t):
        """Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i.

        Useful for reverse diffusion sampling and probability flow sampling.
        Defaults to Euler-Maruyama discretization.

        Args:
          x: a torch tensor
          t: a torch float representing the time step (from 0 to `self.T`)

        Returns:
          f, G
        """
        dt = 1 / self.N
        drift, diffusion = self.sde(x, t)
        f = drift * dt
        G = diffusion * torch.sqrt(torch.tensor(dt, device=t.device))
        return f, G

    def reverse(self, score_fn, probability_flow=False):
        """Create the reverse-time SDE/ODE.

        Args:
          score_fn: A time-dependent score-based model that takes x and t and returns the score.
          probability_flow: If `True`, create the reverse-time ODE used for probability flow sampling.
        """
        N = self.N
        T = self.T
        sde_fn = self.sde
        discretize_fn = self.discretize
        # Discrete schedule tables (e.g. set by VPSDE). Only the DDIM and
        # conditional-DDPM updates need them, so other SDEs may omit them.
        sqrt_alphas_cumprod = getattr(self, 'sqrt_alphas_cumprod', None)
        sqrt_1m_alphas_cumprod = getattr(self, 'sqrt_1m_alphas_cumprod', None)

        def _coef(table, timestep, device):
            """Gather per-sample schedule values on `device`, shaped (B, 1, 1, 1, 1)."""
            return table[timestep].to(device)[:, None, None, None, None]

        # Build the class for reverse-time SDE.
        class RSDE(self.__class__):
            def __init__(self):
                self.N = N
                self.probability_flow = probability_flow

            @property
            def T(self):
                return T

            def sde(self, x, t):
                """Create the drift and diffusion functions for the reverse SDE/ODE."""
                drift, diffusion = sde_fn(x, t)
                score = score_fn(x, t)
                drift = drift - diffusion[:, None, None, None, None] ** 2 * score * (0.5 if self.probability_flow else 1.)
                # Set the diffusion function to zero for ODEs.
                diffusion = 0. if self.probability_flow else diffusion
                return drift, diffusion

            def discretize(self, x, t):
                """Create discretized iteration rules for the reverse diffusion sampler."""
                f, G = discretize_fn(x, t)
                rev_f = f - G[:, None, None, None, None] ** 2 * score_fn(x, t) * (0.5 if self.probability_flow else 1.)
                rev_G = torch.zeros_like(G) if self.probability_flow else G
                return rev_f, rev_G

            def discretize_ddim(self, x, t, tprev=None, encode=False):
                """Deterministic DDIM update from time `t` to time `tprev`.

                `x` is a tensor or a (features, occupancy) tuple. Returns the new
                iterate and the x0 prediction (as tuples in the tuple case).
                The tensor branch computes in float64.
                """
                timestep = (t * (N - 1) / T).long()
                timestep_prev = (tprev * (N - 1) / T).long()
                if isinstance(x, torch.Tensor):
                    score = score_fn(x.float(), t.float())
                    alphas1 = _coef(sqrt_alphas_cumprod, timestep, x.device)
                    alphas2 = _coef(sqrt_1m_alphas_cumprod, timestep, x.device)
                    alphas1_prev = _coef(sqrt_alphas_cumprod, timestep_prev, x.device)
                    alphas2_prev = _coef(sqrt_1m_alphas_cumprod, timestep_prev, x.device)
                    alphas1prev_div_alphas1 = alphas1_prev.double() / alphas1.double()
                    alphas2prev_div_alphas2 = alphas2_prev.double() / alphas2.double()
                    x0_pred_scaled = x.double() - alphas2.double() * score.double()
                    score_scaled_t = x - x0_pred_scaled
                    x0_pred = x0_pred_scaled / alphas1
                    x_new = (
                        alphas1prev_div_alphas1 * x +
                        (-alphas1prev_div_alphas1 + alphas2prev_div_alphas2) * score_scaled_t.double()
                    )
                    return x_new, x0_pred

                score, score_occ = score_fn(x, t.float())
                x, x_occ = x
                device = x.device
                alphas1 = _coef(sqrt_alphas_cumprod, timestep, device)
                alphas2 = _coef(sqrt_1m_alphas_cumprod, timestep, device)
                alphas1_prev = _coef(sqrt_alphas_cumprod, timestep_prev, device)
                alphas2_prev = _coef(sqrt_1m_alphas_cumprod, timestep_prev, device)
                alphas1prev_div_alphas1 = alphas1_prev / alphas1
                alphas2prev_div_alphas2 = alphas2_prev / alphas2
                x0_pred_scaled = x - alphas2 * score
                x0_occ_pred_scaled = x_occ - alphas2 * score_occ
                score_scaled_t = x - x0_pred_scaled
                x0_pred = x0_pred_scaled / alphas1
                score_occ_scaled_t = x_occ - x0_occ_pred_scaled
                x0_occ_pred = x0_occ_pred_scaled / alphas1
                x_new = (
                    alphas1prev_div_alphas1 * x +
                    (-alphas1prev_div_alphas1 + alphas2prev_div_alphas2) * score_scaled_t
                )
                x_occ_new = (
                    alphas1prev_div_alphas1 * x_occ +
                    (-alphas1prev_div_alphas1 + alphas2prev_div_alphas2) * score_occ_scaled_t
                )
                return (x_new, x_occ_new), (x0_pred, x0_occ_pred)

            def discretize_conditional_ddpm(self, x, t, tprev=None, condition_func=None, condition=False):
                """Apply a condition-guided correction to `x`.

                If no `condition_func` is given, or `condition` is None or False,
                no correction is applied and `x` is returned (as float64).
                Returns the new iterate and the clamped x0 prediction.
                """
                timestep = (t * (N - 1) / T).long()
                timestep_prev = (tprev * (N - 1) / T).long()
                score = score_fn(x.float(), t.float())
                alphas1 = _coef(sqrt_alphas_cumprod, timestep, x.device)
                alphas2 = _coef(sqrt_1m_alphas_cumprod, timestep, x.device)
                alphas1_prev = _coef(sqrt_alphas_cumprod, timestep_prev, x.device)
                alphas1prev_div_alphas1 = alphas1_prev.double() / alphas1.double()
                x0_pred_scaled = x.double() - alphas2.double() * score.double()
                x0_pred_scaled = x0_pred_scaled.clamp(-alphas1[0].squeeze(), alphas1[0].squeeze())
                x0_pred = x0_pred_scaled / alphas1
                # The default `condition=False` must mean "unconditioned". The old
                # `is None` test called `condition_func` (None by default) here.
                if condition_func is None or condition is None or condition is False:
                    condition_update = 0
                else:
                    # NOTE(review): this holds for every t below ~0.991, not only
                    # near 0.99; if a closeness test was intended it needs abs() — confirm.
                    if (t - 0.99).mean() < 1e-3:
                        x = x0_pred
                    condition_update = condition_func(x.float(), condition)
                x_new = x - alphas1prev_div_alphas1 * condition_update
                return x_new, x0_pred

        return RSDE()
class VPSDE(SDE):
    """Variance Preserving SDE with a linear beta schedule and its DDPM discretization."""

    def __init__(self, beta_min=0.1, beta_max=20, N=1000, device='cuda'):
        """Construct a Variance Preserving SDE.

        Args:
          beta_min: value of beta(0)
          beta_max: value of beta(1)
          N: number of discretization steps
          device: device holding the discrete schedule tables. Defaults to
            'cuda' as before; pass 'cpu' to use the SDE without a GPU.
        """
        super().__init__(N)
        self.beta_0 = beta_min
        self.beta_1 = beta_max
        # Per-step betas of the discrete chain and the tables derived from them.
        self.discrete_betas = torch.linspace(beta_min / N, beta_max / N, N).to(device)
        self.alphas = 1. - self.discrete_betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # Cumulative products prefixed with 1 - 1e-4 (presumably standing in for
        # the value before the first step — confirm with callers).
        self.alphas_cumprod_ext = torch.cat(
            [torch.tensor([1.0 - 1e-4]).to(device), self.alphas_cumprod], dim=0)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_1m_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)

    @staticmethod
    def _per_sample(v, x):
        """Reshape per-sample values `v` so they broadcast against batched `x`."""
        return v.reshape(-1, *([1] * (x.dim() - 1)))

    @property
    def T(self):
        return 1

    def sde(self, x, t):
        """Return drift -0.5 * beta(t) * x and diffusion sqrt(beta(t))."""
        beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0)
        drift = -0.5 * self._per_sample(beta_t, x) * x
        diffusion = torch.sqrt(beta_t)
        return drift, diffusion

    def marginal_prob(self, x, t):
        """Return the mean and per-sample std of p_t(x_t | x_0 = x)."""
        log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
        mean = torch.exp(self._per_sample(log_mean_coeff, x)) * x
        std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff))
        return mean, std

    def prior_sampling(self, shape):
        return torch.randn(*shape)

    def prior_logp(self, z):
        """Standard-normal log-density of each sample in the batch `z`."""
        shape = z.shape
        N = np.prod(shape[1:])
        logps = -N / 2. * np.log(2 * np.pi) - torch.sum(z ** 2, dim=tuple(range(1, z.dim()))) / 2.
        return logps

    def discretize(self, x, t):
        """DDPM discretization."""
        timestep = (t * (self.N - 1) / self.T).long()
        beta = self.discrete_betas.to(x.device)[timestep]
        alpha = self.alphas.to(x.device)[timestep]
        f = self._per_sample(torch.sqrt(alpha), x) * x - x
        G = torch.sqrt(beta)
        return f, G
================================================
FILE: GMeshDiffusion/lib/diffusion/trainer.py
================================================
import os
import sys
import numpy as np
import logging
# Keep the import below for registering all model definitions
from .models import unet3d, unet3d_occgrid, unet3d_tet_aware, unet3d_occgrid_v2, unet3d_meshdiffusion
from . import losses
from .models import utils as mutils
from .models.ema import ExponentialMovingAverage
from . import sde_lib
import torch
from torch.utils import tensorboard
from .utils import save_checkpoint, restore_checkpoint
from ..dataset.gshell_dataset import GShellDataset
from ..dataset.gshell_dataset_aug import GShellAugDataset
def train(config):
    """Run the single-process training pipeline.

    Args:
      config: Configuration to use. Checkpoints and TensorBoard summaries are
        written under `config.training.train_dir`. If that directory already
        holds an intermediate checkpoint (`checkpoints-meta/checkpoint.pth`),
        training resumes from it.
    """
    workdir = config.training.train_dir
    # Create directories for experimental logs
    logging.info("working dir: {:s}".format(workdir))
    tb_dir = os.path.join(workdir, "tensorboard")
    writer = tensorboard.SummaryWriter(tb_dir)
    try:
        # Initialize model.
        score_model = mutils.create_model(config)
        ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate)
        optimizer = losses.get_optimizer(config, score_model.parameters())
        gradscaler = torch.cuda.amp.GradScaler(enabled=True)
        state = dict(optimizer=optimizer, model=score_model, ema=ema, gradscaler=gradscaler, step=0)

        checkpoint_dir = os.path.join(workdir, "checkpoints")
        # Intermediate checkpoint to resume training after pre-emption.
        checkpoint_meta_dir = os.path.join(workdir, "checkpoints-meta", "checkpoint.pth")
        os.makedirs(checkpoint_dir, exist_ok=True)
        os.makedirs(os.path.dirname(checkpoint_meta_dir), exist_ok=True)
        state = restore_checkpoint(checkpoint_meta_dir, state, config.device)
        initial_step = int(state['step'])
        print(f"work dir: {workdir}")

        # Optional config field: a missing attribute means "disabled".
        use_occ_grid = getattr(config.data, 'use_occ_grid', False)
        if use_occ_grid:
            train_dataset = GShellAugDataset(config)
        else:
            train_dataset = GShellDataset(config.data.dataset_metapath)
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=config.training.batch_size,
            shuffle=True,
            num_workers=config.data.num_workers,
            # Datasets without a custom `collate` use the default collation.
            collate_fn=getattr(train_dataset, 'collate', None),
            pin_memory=True,
        )
        data_iter = iter(train_loader)
        print("data loader set")

        # Setup SDEs
        sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
        # Build one-step training function
        optimize_fn = losses.optimization_manager(config)
        use_vis_mask = getattr(config.model, 'use_vis_mask', False)
        print('use_vis_mask', use_vis_mask)
        train_step_fn = losses.get_step_fn(sde, train=True, optimize_fn=optimize_fn,
                                           loss_type=config.training.loss_type,
                                           pred_type=config.model.pred_type,
                                           use_vis_mask=use_vis_mask,
                                           use_occ=use_occ_grid,
                                           use_aux=config.training.use_aux_loss)

        num_train_steps = config.training.n_iters
        iter_size = config.training.num_grad_acc_steps
        logging.info("Starting training loop at step %d." % (initial_step // iter_size,))

        for step in range(initial_step // iter_size, num_train_steps + 1):
            tmp_loss_dict = {
                'loss_total': 0.0,
                'loss_score': 0.0,
                'loss_reg': 0.0,
            }
            # Gradient accumulation: clear grads on the first micro-batch and
            # update parameters on the last one.
            for step_inner in range(iter_size):
                try:
                    batch = next(data_iter)
                except StopIteration:
                    # The loader is exhausted at the end of each epoch; start a new pass.
                    data_iter = iter(train_loader)
                    batch = next(data_iter)
                if isinstance(batch, dict):
                    for k in batch:
                        batch[k] = batch[k].to('cuda', non_blocking=False)
                else:
                    batch = batch.to('cuda', non_blocking=False)
                clear_grad_flag = (step_inner == 0)
                update_param_flag = (step_inner == iter_size - 1)
                loss_dict = train_step_fn(state, batch, clear_grad=clear_grad_flag, update_param=update_param_flag, gradscaler=gradscaler)
                for key in loss_dict:
                    tmp_loss_dict[key] = tmp_loss_dict.get(key, 0.0) + loss_dict[key].item() / iter_size

            if step % config.training.log_freq == 0:
                logging.info(
                    "step: %d, loss_total: %.5e, loss_score: %.5e, loss_reg: %.5e"
                    % (step, tmp_loss_dict['loss_total'], tmp_loss_dict['loss_score'], tmp_loss_dict['loss_reg'])
                )
                sys.stdout.flush()
                writer.add_scalar("loss_total", tmp_loss_dict['loss_total'], step)
                writer.add_scalar("loss_score", tmp_loss_dict['loss_score'], step)
                writer.add_scalar("loss_reg", tmp_loss_dict['loss_reg'], step)

            # Periodically save a temporary checkpoint to resume after pre-emption.
            if step != 0 and step % config.training.snapshot_freq_for_preemption == 0:
                logging.info(f"save meta at iter {step}")
                save_checkpoint(checkpoint_meta_dir, state)

            # Periodically (and at the final step) save a numbered checkpoint.
            if (step != 0 and step % config.training.snapshot_freq == 0) or step == num_train_steps:
                logging.info(f"save model: {step}-th")
                save_checkpoint(os.path.join(checkpoint_dir, f'checkpoint_{step}.pth'), state)
    finally:
        # Flush pending summaries even if training stops with an exception.
        writer.close()
================================================
FILE: GMeshDiffusion/lib/diffusion/trainer_ddp.py
================================================
import os
import sys
import numpy as np
import logging
# Keep the import below for registering all model definitions
from .models import unet3d, unet3d_occgrid, unet3d_tet_aware, unet3d_occgrid_v2, unet3d_meshdiffusion
from . import losses
from .models import utils as mutils
from .models.ema import ExponentialMovingAverage
from . import sde_lib
import torch
from torch.utils import tensorboard
from .utils import save_checkpoint, restore_checkpoint
from ..dataset.gshell_dataset import GShellDataset
from ..dataset.gshell_dataset_aug import GShellAugDataset
from .lion.lion import Lion
import torch.distributed as dist
def train(config):
    """Run the DDP training pipeline (one process per GPU, NCCL backend).

    Args:
      config: Configuration to use. Checkpoints and TensorBoard summaries are
        written under `config.training.train_dir`. If that directory already
        holds an intermediate checkpoint (`checkpoints-meta/checkpoint.pth`),
        training resumes from it.
    """
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    # Number of participating processes; it may differ from the number of visible GPUs.
    world_size = dist.get_world_size()
    # Local GPU index for this process (equal to `rank` on a single node).
    device_id = rank % torch.cuda.device_count()
    torch.cuda.set_device(device_id)
    device = torch.device("cuda", device_id)
    print(f"Start running basic DDP example on rank {rank}.")
    is_main_process = rank == 0

    workdir = config.training.train_dir
    logging.info("working dir: {:s}".format(workdir))
    # Only rank 0 writes summaries, so only rank 0 opens an event file.
    writer = tensorboard.SummaryWriter(os.path.join(workdir, "tensorboard")) if is_main_process else None
    try:
        # Initialize model.
        # NOTE(review): `rank` is forwarded unchanged; if create_model uses it as
        # a CUDA device index, multi-node runs need `device_id` instead — confirm.
        score_model = mutils.create_model(config, ddp=True, rank=rank)
        ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate)
        optimizer = losses.get_optimizer(config, score_model.parameters())
        gradscaler = torch.cuda.amp.GradScaler(growth_interval=config.training.gradscaler_growth_interval)
        state = dict(optimizer=optimizer, model=score_model, ema=ema, gradscaler=gradscaler, step=0)

        checkpoint_dir = os.path.join(workdir, "checkpoints")
        # Intermediate checkpoint to resume training after pre-emption.
        checkpoint_meta_dir = os.path.join(workdir, "checkpoints-meta", "checkpoint.pth")
        os.makedirs(checkpoint_dir, exist_ok=True)
        os.makedirs(os.path.dirname(checkpoint_meta_dir), exist_ok=True)
        # restore_checkpoint turns `rank` into the map device "cuda:<rank>", so pass the local GPU index.
        state = restore_checkpoint(checkpoint_meta_dir, state, config.device, rank=device_id)
        initial_step = int(state['step'])
        print(f"work dir: {workdir}")

        # Optional config field: a missing attribute means "disabled".
        use_occ_grid = getattr(config.data, 'use_occ_grid', False)
        if use_occ_grid:
            train_dataset = GShellAugDataset(config)
        else:
            train_dataset = GShellDataset(config.data.dataset_metapath)
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset,
            num_replicas=world_size,
            rank=rank
        )
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=config.training.batch_size,
            num_workers=config.data.num_workers,
            sampler=train_sampler,
            # Datasets without a custom `collate` use the default collation.
            collate_fn=getattr(train_dataset, 'collate', None),
        )
        epoch = 0
        train_sampler.set_epoch(epoch)
        data_iter = iter(train_loader)
        print("data loader set")

        # Setup SDEs
        sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)
        # Build one-step training function
        optimize_fn = losses.optimization_manager(config)
        use_vis_mask = getattr(config.model, 'use_vis_mask', False)
        print('use_vis_mask', use_vis_mask)
        train_step_fn = losses.get_step_fn(sde, train=True, optimize_fn=optimize_fn,
                                           loss_type=config.training.loss_type,
                                           pred_type=config.model.pred_type,
                                           use_vis_mask=use_vis_mask,
                                           use_occ=use_occ_grid,
                                           use_aux=config.training.use_aux_loss)

        num_train_steps = config.training.n_iters
        iter_size = config.training.num_grad_acc_steps
        logging.info("Starting training loop at step %d." % (initial_step // iter_size,))

        for step in range(initial_step // iter_size, num_train_steps + 1):
            tmp_loss_dict = {
                'loss_total': 0.0,
                'loss_score': 0.0,
                'loss_reg': 0.0,
            }
            for step_inner in range(iter_size):
                try:
                    batch = next(data_iter)
                except StopIteration:
                    # End of this rank's shard: reshuffle with a new epoch seed.
                    epoch += 1
                    train_sampler.set_epoch(epoch)
                    data_iter = iter(train_loader)
                    batch = next(data_iter)
                if isinstance(batch, dict):
                    for k in batch:
                        batch[k] = batch[k].to(device, non_blocking=False)
                else:
                    batch = batch.to(device, non_blocking=False)
                clear_grad_flag = (step_inner == 0)
                update_param_flag = (step_inner == iter_size - 1)
                if not update_param_flag:
                    # Skip the gradient all-reduce on accumulation-only micro-batches.
                    with score_model.no_sync():
                        loss_dict = train_step_fn(state, batch, clear_grad=clear_grad_flag, update_param=update_param_flag, gradscaler=gradscaler)
                else:
                    loss_dict = train_step_fn(state, batch, clear_grad=clear_grad_flag, update_param=update_param_flag, gradscaler=gradscaler)
                for key in loss_dict:
                    tmp_loss_dict[key] = tmp_loss_dict.get(key, 0.0) + loss_dict[key].item() / iter_size

            if step % config.training.log_freq == 0:
                # Mean loss over ranks: each rank contributes loss / world_size.
                loss = torch.tensor(tmp_loss_dict['loss_total'] / world_size, device=device)
                dist.reduce(loss, dst=0, op=dist.ReduceOp.SUM)
                if is_main_process:
                    loss = loss.item()
                    logging.info("step: %d, loss: %.5e, scale: %.5e" % (step, loss, gradscaler.get_scale()))
                    sys.stdout.flush()
                    writer.add_scalar("loss", loss, step)

            if is_main_process:
                # Periodically save a temporary checkpoint to resume after pre-emption.
                if step != 0 and step % config.training.snapshot_freq_for_preemption == 0:
                    logging.info(f"save meta at iter {step}")
                    save_checkpoint(checkpoint_meta_dir, state)
                # Periodically (and at the final step) save a numbered checkpoint.
                if (step != 0 and step % config.training.snapshot_freq == 0) or step == num_train_steps:
                    logging.info(f"save model: {step}-th")
                    save_checkpoint(os.path.join(checkpoint_dir, f'checkpoint_{step}.pth'), state)
    finally:
        if writer is not None:
            writer.close()
    dist.destroy_process_group()
================================================
FILE: GMeshDiffusion/lib/diffusion/utils.py
================================================
import torch
import os
import logging
def restore_checkpoint(ckpt_dir, state, device, strict=False, rank=None):
    """Load a training checkpoint into `state` in place.

    Args:
      ckpt_dir: Path to the checkpoint *file* (despite the name).
      state: Dict with 'optimizer', 'model', 'ema', 'gradscaler' and 'step' entries.
      device: Device the model is moved to (and passed to the EMA loader).
      strict: If True, a missing checkpoint raises FileNotFoundError. Otherwise
        a warning is logged and `state` is returned unchanged.
      rank: If given, overrides `device` with f"cuda:{rank}".

    Returns:
      `state`, updated from the checkpoint when one exists.

    Raises:
      FileNotFoundError: if the checkpoint is missing and `strict` is True.
    """
    from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

    if not os.path.exists(ckpt_dir):
        parent = os.path.dirname(ckpt_dir)
        if parent:  # os.makedirs('') would raise for a bare filename
            os.makedirs(parent, exist_ok=True)
        logging.warning(f"No checkpoint found at {ckpt_dir}. "
                        f"Returned the same state as input")
        if strict:
            raise FileNotFoundError(f"No checkpoint found at {ckpt_dir}")
        return state

    if rank is not None:
        device = f"cuda:{rank}"
    # Deserialize on CPU; the model is moved to `device` below.
    loaded_state = torch.load(ckpt_dir, map_location='cpu')
    state['optimizer'].load_state_dict(loaded_state['optimizer'])

    model = state['model']
    model_state = loaded_state['model']
    # Checkpoints written from a DDP-wrapped model prefix every key with
    # 'module.'. Loading those into an unwrapped model with strict=False would
    # silently skip every weight, so strip the prefix first.
    if not any(k.startswith('module.') for k in model.state_dict()):
        consume_prefix_in_state_dict_if_present(model_state, 'module.')
    model.load_state_dict(model_state, strict=False)

    state['ema'].load_state_dict(loaded_state['ema'], device=device)
    state['step'] = loaded_state['step']
    model.to(device)
    try:
        state['gradscaler'].load_state_dict(loaded_state['gradscaler'])
    except Exception as err:
        # Best effort: older checkpoints may lack scaler state, or hold an
        # empty one saved from a disabled scaler.
        logging.warning(f"Could not restore gradscaler state: {err}")
    torch.cuda.empty_cache()
    return state
def save_checkpoint(ckpt_dir, state):
    """Serialize the training `state` to the file `ckpt_dir`.

    Saves the optimizer, model, EMA and gradscaler state dicts plus the step
    counter. The data is first written to `<ckpt_dir>.tmp` and then moved into
    place with `os.replace`. A job pre-empted mid-write therefore never leaves
    a truncated file at `ckpt_dir`, which would break resuming.
    """
    saved_state = {
        'optimizer': state['optimizer'].state_dict(),
        'model': state['model'].state_dict(),
        'ema': state['ema'].state_dict(),
        'step': state['step'],
        'gradscaler': state['gradscaler'].state_dict()
    }
    tmp_path = f"{ckpt_dir}.tmp"
    torch.save(saved_state, tmp_path)
    os.replace(tmp_path, ckpt_dir)
================================================
FILE: GMeshDiffusion/main_diffusion.py
================================================
"""Training and evaluation"""
from absl import app
from absl import flags
from ml_collections.config_flags import config_flags
import lib.diffusion.trainer as trainer
import lib.diffusion.evaler as evaler
# Parsed absl command-line flags.
FLAGS = flags.FLAGS
# `--config`: path to an ml_collections config file. lock_config=False leaves
# the loaded config unlocked.
config_flags.DEFINE_config_file(
    "config", None, "diffusion configs", lock_config=False)
# `--mode`: which pipeline `main` runs; only the listed values are accepted.
flags.DEFINE_enum("mode", None, ["train", "uncond_gen", "cond_gen", "uncond_gen_interp"], "Running mode")
flags.mark_flags_as_required(["config", "mode"])
def main(argv):
    """Dispatch to the trainer or evaler entry point selected by `--mode`."""
    # (module, function name) per mode. The attribute is looked up only for
    # the selected mode.
    dispatch = {
        'train': (trainer, 'train'),
        'uncond_gen': (evaler, 'uncond_gen'),
        'uncond_gen_interp': (evaler, 'uncond_gen_interp'),
        'cond_gen': (evaler, 'cond_gen'),
    }
    target = dispatch.get(FLAGS.mode)
    if target is None:
        return
    module, fn_name = target
    getattr(module, fn_name)(FLAGS.config)


if __name__ == "__main__":
    app.run(main)
================================================
FILE: GMeshDiffusion/main_diffusion_ddp.py
================================================
"""Training and evaluation"""
from absl import app
from absl import flags
from ml_collections.config_flags import config_flags
import lib.diffusion.trainer_ddp as trainer
import lib.diffusion.evaler as evaler
# Parsed absl command-line flags.
FLAGS = flags.FLAGS
# `--config`: path to an ml_collections config file. lock_config=False leaves
# the loaded config unlocked.
config_flags.DEFINE_config_file(
    "config", None, "diffusion configs", lock_config=False)
# `--mode`: read by `main`; only the listed values are accepted.
flags.DEFINE_enum("mode", None, ["train", "uncond_gen", "cond_gen", "uncond_gen_interp"], "Running mode")
flags.mark_flags_as_required(["config", "mode"])
def main(argv):
    """Print the config, then run the pipeline selected by `--mode`.

    'train' runs the DDP trainer. The generation modes accepted by the
    `--mode` enum are forwarded to `evaler`; previously they silently did
    nothing.
    """
    print(FLAGS.config)
    if FLAGS.mode == 'train':
        trainer.train(FLAGS.config)
    elif FLAGS.mode == 'uncond_gen':
        evaler.uncond_gen(FLAGS.config)
    elif FLAGS.mode == 'uncond_gen_interp':
        evaler.uncond_gen_interp(FLAGS.config)
    elif FLAGS.mode == 'cond_gen':
        evaler.cond_gen(FLAGS.config)


if __name__ == "__main__":
    app.run(main)
================================================
FILE: GMeshDiffusion/metadata/get_splits_lower.py
================================================
import os
import random
# Fixed seed so the train/test split is reproducible.
random.seed(42)
split_ratio = 0.9
data_root = 'PLACEHOLDER'
grid_root = os.path.join(data_root, 'grid')
occgrid_root = os.path.join(data_root, 'grid_aug')


def _write_split(path, paths):
    """Write `paths` to `path`, one per line (no trailing newline)."""
    with open(path, 'w') as f:
        f.write('\n'.join(paths))


# The per-shape grid files live in `grid_root`. Listing `data_root` itself
# would also pick up its 'grid' and 'grid_aug' sub-directories as samples.
data_path_list = sorted(os.path.join(grid_root, fname) for fname in os.listdir(grid_root))
random.shuffle(data_path_list)
n_train = int(len(data_path_list) * split_ratio)
train_list = data_path_list[:n_train]
test_list = data_path_list[n_train:]
_write_split('lower_res64_grid_train.txt', train_list)
_write_split('lower_res64_grid_test.txt', test_list)

# Occupancy grids share file names with the grid files but live in `occgrid_root`.
occgrid_train_list = [os.path.join(occgrid_root, os.path.basename(p)) for p in train_list]
occgrid_test_list = [os.path.join(occgrid_root, os.path.basename(p)) for p in test_list]
_write_split('lower_res64_occgrid_train.txt', occgrid_train_list)
_write_split('lower_res64_occgrid_test.txt', occgrid_test_list)
================================================
FILE: GMeshDiffusion/metadata/get_splits_upper.py
================================================
import os
import random
# Fixed seed so the train/test split is reproducible.
random.seed(42)
split_ratio = 0.9
data_root = 'PLACEHOLDER'
grid_root = os.path.join(data_root, 'grid')
occgrid_root = os.path.join(data_root, 'grid_aug')


def _write_split(path, paths):
    """Write `paths` to `path`, one per line (no trailing newline)."""
    with open(path, 'w') as f:
        f.write('\n'.join(paths))


# The per-shape grid files live in `grid_root`. Listing `data_root` itself
# would also pick up its 'grid' and 'grid_aug' sub-directories as samples.
data_path_list = sorted(os.path.join(grid_root, fname) for fname in os.listdir(grid_root))
random.shuffle(data_path_list)
n_train = int(len(data_path_list) * split_ratio)
train_list = data_path_list[:n_train]
test_list = data_path_list[n_train:]
_write_split('upper_res64_grid_train.txt', train_list)
_write_split('upper_res64_grid_test.txt', test_list)

# Occupancy grids share file names with the grid files but live in `occgrid_root`.
occgrid_train_list = [os.path.join(occgrid_root, os.path.basename(p)) for p in train_list]
occgrid_test_list = [os.path.join(occgrid_root, os.path.basename(p)) for p in test_list]
_write_split('upper_res64_occgrid_train.txt', occgrid_train_list)
_write_split('upper_res64_occgrid_test.txt', occgrid_test_list)
================================================
FILE: GMeshDiffusion/metadata/save_tet_info.py
================================================
'''
Storing tet-grid related meta-info into a single file
'''
import numpy as np
import torch
import os
import tqdm
import argparse
from itertools import combinations
def tet_to_grids(vertices, values_list, grid_size):
    """Scatter per-vertex values into a dense (12, G, G, G) grid.

    ``vertices`` holds integer lattice coordinates (one row per vertex, columns x/y/z).
    ``values_list[0]`` is squeezed and written to channel 0. Every later entry is
    transposed and written to channels 1-3, so a later entry overwrites an earlier one.
    Channels 4-11 stay zero.
    """
    grid = torch.zeros(12, grid_size, grid_size, grid_size, device=vertices.device)
    with torch.no_grad():
        ix, iy, iz = vertices[:, 0], vertices[:, 1], vertices[:, 2]
        for k, values in enumerate(values_list):
            target = slice(1, 4) if k else 0
            payload = values.transpose(0, 1) if k else values.squeeze()
            grid[target, ix, iy, iz] = payload
    return grid
def shared_vertex_edge_pairs(edge_ind_list=None):
    """Return the flattened pairs (i, j), i < j, of tet edges that share a vertex.

    ``edge_ind_list`` gives the local vertex ids of each tet edge. The default is the six
    edges [[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]], which yields 12 pairs flattened
    as [i0, j0, i1, j1, ...] (24 entries). Pairs follow itertools.combinations order, the
    same order as the nested i < j loops previously used here.
    """
    if edge_ind_list is None:
        edge_ind_list = [[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]
    return [
        idx
        for i, j in combinations(range(len(edge_ind_list)), 2)
        if set(edge_ind_list[i]) & set(edge_ind_list[j])
        for idx in (i, j)
    ]


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='nvdiffrec')
    parser.add_argument('-res', '--resolution', type=int)
    # -r/-s/-t are parsed for CLI compatibility but not used by this script.
    parser.add_argument('-r', '--root', type=str)
    parser.add_argument('-s', '--source', type=str)
    parser.add_argument('-t', '--target', type=str)
    FLAGS = parser.parse_args()

    tet_path = f'./tets/{FLAGS.resolution}_tets_cropped_reordered.npz'
    tet = np.load(tet_path)
    vertices = torch.tensor(tet['vertices']).cuda()
    indices = torch.tensor(tet['indices']).long().cuda()
    # Global vertex ids of the tet edges; assumed grouped as six consecutive rows per tet
    # (the view(-1, 6, ...) below relies on it).
    tet_edges = torch.tensor(tet['tet_edges']).long().view(-1, 2).cuda()

    # Discretize onto a lattice with half the tet-grid spacing ("denser grid"). Assumes a
    # uniform grid whose spacing is the gap between the two smallest distinct coordinates.
    vertices_unique = vertices[:].unique()
    dx = (vertices_unique[1] - vertices_unique[0]) / 2.0
    vertices_discretized = ((vertices - vertices.min()) / dx).long()

    tet_verts = vertices_discretized[indices.view(-1)].view(-1, 4, 3)
    tet_center_discretized = tet_verts.float().mean(dim=1).long()

    # Flattened (i, j) pairs of local tet-edge ids sharing a vertex: 12 pairs, 24 entries.
    msdf_tetedges = torch.tensor(shared_vertex_edge_pairs())
    print(msdf_tetedges)
    print(msdf_tetedges.size())

    # Lattice position of every tet-edge midpoint, grouped per tet: (num_tets, 6, 3).
    # This used to be view(-1, 6, 2), which regrouped the xyz coordinates into pairs, so
    # occ_edge_pos held scrambled coordinates (and the view failed for an odd tet count).
    tet_edgenodes_pos = (vertices_discretized[tet_edges[:, 0]] + vertices_discretized[tet_edges[:, 1]]) / 2.0
    tet_edgenodes_pos = tet_edgenodes_pos.view(-1, 6, 3)
    # (num_tets, 12, 2, 3): the two edge midpoints joined by each vertex-sharing edge pair.
    occ_edge_pos = tet_edgenodes_pos[:, msdf_tetedges].view(-1, 12, 2, 3)

    # Canonical endpoint order, mirroring the ordering applied in tet_to_cubic_grid_dataset.py:
    # weight the per-axis sign of (p0 - p1) by [16, 4, 1] and argsort [code, -code].
    edge_sign = torch.sign(occ_edge_pos[:, :, 0, :] - occ_edge_pos[:, :, 1, :])
    axis_weights = torch.tensor([16, 4, 1], device=edge_sign.device).view(1, 1, -1)
    order_code = (edge_sign * axis_weights).sum(dim=-1)
    order_code = torch.stack([order_code, -order_code], dim=-1)
    _, edge_twopoint_order = order_code.sort(dim=-1)
    occ_edge_cano_order = torch.arange(2).view(1, 1, 2).expand(occ_edge_pos.size(0), 12, 2).cuda()
    occ_edge_cano_order = torch.gather(input=occ_edge_cano_order, dim=-1, index=edge_twopoint_order)

    tet_edges = tet_edges.view(-1)
    torch.save({
        'tet_v_pos': vertices,
        'tet_edge_vpos': vertices[tet_edges].view(-1, 2, 3),
        'tet_edge_pix_loc': vertices_discretized[tet_edges].view(-1, 2, 3),
        'tet_center_loc': tet_center_discretized,
        'msdf_edges': msdf_tetedges.view(12, 2),
        'occ_edge_cano_order': occ_edge_cano_order,
    }, 'tet_info.pt')
================================================
FILE: GMeshDiffusion/metadata/tet_to_cubic_grid_dataset.py
================================================
import numpy as np
import torch
import os
import tqdm
import argparse
def tet_to_grids(vertices, values_list, grid_size):
    """Scatter per-vertex values into a dense (4, G, G, G) grid.

    ``vertices`` holds integer lattice coordinates (columns x/y/z). The first entry of
    ``values_list`` is squeezed into channel 0. Each later entry is transposed into
    channels 1-3, with later entries overwriting earlier ones.
    """
    coords = tuple(vertices[:, axis] for axis in range(3))
    grid = torch.zeros(4, grid_size, grid_size, grid_size, device=vertices.device)
    with torch.no_grad():
        for position, values in enumerate(values_list):
            if position == 0:
                grid[(0,) + coords] = values.squeeze()
            else:
                grid[(slice(1, 4),) + coords] = values.transpose(0, 1)
    return grid
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='nvdiffrec')
    parser.add_argument('-res', '--resolution', type=int)
    parser.add_argument('-ss', '--split-size', type=int, default=int(1e8))
    # Defaults to 0 so that omitting -ind no longer fails on `None * split_size`.
    parser.add_argument('-ind', '--index', type=int, default=0)
    parser.add_argument('-r', '--root', type=str)
    parser.add_argument('-s', '--source', type=str)
    parser.add_argument('-t', '--target', type=str)
    FLAGS = parser.parse_args()

    tet_path = f'./tets/{FLAGS.resolution}_tets_cropped_reordered.npz'
    tet = np.load(tet_path)
    vertices = torch.tensor(tet['vertices']).cuda()
    indices = torch.tensor(tet['indices']).long().cuda()
    edges = torch.tensor(tet['edges']).long().cuda()
    tet_edges = torch.tensor(tet['tet_edges']).long().view(-1, 2).cuda()

    # Side length of the cubic grid holding per-vertex and per-edge-midpoint values.
    grid_res = FLAGS.resolution * 2
    # Side length of the occupancy grid. occgrid_loc (below) is twice a position on the
    # grid_res lattice, so it needs twice that resolution; for res 64 this equals the
    # 256 that used to be hard-coded.
    occ_res = grid_res * 2

    # Discretize onto a lattice with half the tet-grid spacing ("denser grid"). Assumes a
    # uniform tet grid whose spacing is the gap between the two smallest distinct coordinates.
    vertices_unique = vertices[:].unique()
    dx = (vertices_unique[1] - vertices_unique[0]) / 2.0
    vertices_discretized = ((vertices - vertices.min()) / dx).long()
    print(vertices_discretized.size())
    midpoints = (vertices_discretized[edges[:, 0]] + vertices_discretized[edges[:, 1]]) / 2.0
    midpoints_dicretized = midpoints.long()

    # global_mask: channels 0-3 get +1 at every vertex location; channel 0 additionally
    # gets +1 at every edge-midpoint location.
    # cat_mask: set to 1 at vertex locations, then to -1 at edge-midpoint locations.
    global_mask = torch.zeros(4, grid_res, grid_res, grid_res).cuda()
    cat_mask = torch.zeros(grid_res, grid_res, grid_res).cuda()
    global_mask[:4, vertices_discretized[:, 0], vertices_discretized[:, 1], vertices_discretized[:, 2]] += 1.0
    cat_mask[vertices_discretized[:, 0], vertices_discretized[:, 1], vertices_discretized[:, 2]] = 1
    global_mask[0, midpoints_dicretized[:, 0], midpoints_dicretized[:, 1], midpoints_dicretized[:, 2]] += 1.0
    cat_mask[midpoints_dicretized[:, 0], midpoints_dicretized[:, 1], midpoints_dicretized[:, 2]] = -1
    torch.save(global_mask, f'global_mask_res{FLAGS.resolution}.pt')
    torch.save(cat_mask, f'cat_mask_res{FLAGS.resolution}.pt')

    save_folder = FLAGS.root
    grid_folder_base = os.path.join(save_folder, FLAGS.target)
    os.makedirs(grid_folder_base, exist_ok=True)
    print(grid_folder_base)

    # For every pair of tet edges sharing a vertex (12 pairs per tet), record the two local
    # edge ids (msdf_tetedges) and the four local vertex ids of those edges (msdf_from_tetverts).
    edge_ind_list = [[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]
    msdf_tetedges = []
    msdf_from_tetverts = []
    for i in range(5):
        for j in range(i + 1, 6):
            if set(edge_ind_list[i]) & set(edge_ind_list[j]):
                msdf_tetedges.extend([i, j])
                msdf_from_tetverts.extend(edge_ind_list[i] + edge_ind_list[j])
    msdf_tetedges = torch.tensor(msdf_tetedges)
    msdf_from_tetverts = torch.tensor(msdf_from_tetverts)
    print(msdf_tetedges)
    print(msdf_tetedges.size())

    occgrid_mask_already_saved = False
    tets_folder = os.path.join(save_folder, FLAGS.source)
    with torch.no_grad():
        for k in tqdm.trange(FLAGS.split_size):
            global_index = k + FLAGS.index * FLAGS.split_size
            tet_path = os.path.join(tets_folder, 'dmt_dict_{:05d}.pt'.format(global_index))
            # Indices without a fitted file are skipped; existing outputs are recomputed.
            if not os.path.exists(tet_path):
                continue
            tet = torch.load(tet_path, map_location="cuda")
            sdf = tet['sdf'].view(-1, 1)
            msdf = tet['msdf'].view(-1, 1)
            deform = tet['deform']

            ### resetting sdfs and offsets of all non-mesh-generating tet vertices
            # A tet is kept if one of its SDF-sign-changing edges has a positive mSDF at the
            # linearly interpolated zero crossing; vertices of no kept tet get msdf = -1 and
            # deform = 0.
            tet_edge_mask = ((torch.sign(sdf[tet_edges[:, 0]]) - torch.sign(sdf[tet_edges[:, 1]])) != 0).bool().squeeze(-1).view(-1, 6)
            tet_sdf_coeff = (
                torch.abs(sdf[tet_edges[:, 0]])
                / (torch.abs(sdf[tet_edges[:, 0]] - sdf[tet_edges[:, 1]]) + 1e-10)
            ).squeeze(-1)
            tet_sdf_coeff = tet_sdf_coeff.view(-1, 1)
            midpoint_msdf_tet = msdf[tet_edges[:, 0]] * (1 - tet_sdf_coeff) + msdf[tet_edges[:, 1]] * tet_sdf_coeff
            midpoint_msdf_tet = midpoint_msdf_tet.view(-1, 6)
            tet_mask = ((midpoint_msdf_tet > 0) & tet_edge_mask).sum(dim=-1).bool()
            vert_mask = torch.zeros_like(sdf.squeeze())
            vert_mask[indices[tet_mask].view(-1)] = 1.0
            vert_mask = ~vert_mask.bool()
            msdf[vert_mask] = -1.0
            deform[vert_mask] = 0.0
            # Vertices that belong only to tets whose four mSDF signs are all negative get sdf = +1.
            tet_nonallnegmsdf = (torch.sign(msdf[indices.view(-1)].view(-1, 4)).sum(dim=-1) != -4)
            vert_mask_nonallnegmsdf = torch.zeros_like(sdf.squeeze())
            vert_mask_nonallnegmsdf[indices[tet_nonallnegmsdf].view(-1)] = 1.0
            vert_mask_nonallnegmsdf = ~vert_mask_nonallnegmsdf.bool()
            sdf[vert_mask_nonallnegmsdf] = 1.0

            # Per-edge SDF sign change and zero-crossing coefficient (NaN where the signs agree).
            mask = (
                (torch.sign(sdf[edges[:, 0]]) - torch.sign(sdf[edges[:, 1]]) != 0).bool()
            )
            nan_mask = (
                ((torch.sign(sdf[edges[:, 0]]) + torch.sign(sdf[edges[:, 1]])) == 2)
                | ((torch.sign(sdf[edges[:, 0]]) + torch.sign(sdf[edges[:, 1]])) == -2)
            ).bool().squeeze(-1)
            original_sdf_coeff = torch.abs(sdf[edges[:, 0]]) / (torch.abs(sdf[edges[:, 0]] - sdf[edges[:, 1]]) + 1e-10)
            original_sdf_coeff[nan_mask] = torch.nan
            # Normalized to [-1, 1]; only used for the sanity check below.
            normalized_sdf_coeff = ((original_sdf_coeff - 0.5) * 2.0)
            normalized_sdf_coeff = torch.nan_to_num(normalized_sdf_coeff)
            assert torch.all(normalized_sdf_coeff.abs() <= 1.0)
            sdf_sign = torch.sign(sdf)
            sdf_sign[sdf_sign == 0] = 1
            midpoint_msdf = msdf[edges[:, 0]] * (1 - original_sdf_coeff.view(-1, 1)) + msdf[edges[:, 1]] * original_sdf_coeff.view(-1, 1)
            midpoint_msdf_sign = torch.sign(midpoint_msdf)
            midpoint_msdf_sign[midpoint_msdf_sign == 0] = -1
            # Edges without an SDF sign change are set to -1.
            midpoint_msdf_sign = midpoint_msdf_sign * mask - (1.0 - mask.float())

            ############################ Occ Grid ############################
            tet_edge_mask = ((torch.sign(sdf[tet_edges[:, 0]]) - torch.sign(sdf[tet_edges[:, 1]])) != 0).bool().squeeze(-1).view(-1, 6)
            tet_sdf_coeff = (
                torch.abs(sdf[tet_edges[:, 0]])
                / (torch.abs(sdf[tet_edges[:, 0]] - sdf[tet_edges[:, 1]]) + 1e-10)
            ).squeeze(-1)
            tet_sdf_coeff = tet_sdf_coeff * tet_edge_mask.view(-1)
            tet_sdf_coeff = tet_sdf_coeff.view(-1, 1)
            nan_mask = (
                ((torch.sign(sdf[tet_edges[:, 0]]) + torch.sign(sdf[tet_edges[:, 1]])) == 2)
                | ((torch.sign(sdf[tet_edges[:, 0]]) + torch.sign(sdf[tet_edges[:, 1]])) == -2)
            ).bool().squeeze(-1)
            tet_sdf_coeff[nan_mask] = torch.nan
            midpoint_msdf_tet = msdf[tet_edges[:, 0]] * (1 - tet_sdf_coeff) + msdf[tet_edges[:, 1]] * tet_sdf_coeff
            midpoint_msdf_tet = midpoint_msdf_tet.view(-1, 6)
            inscribed_edge_twopoint_msdf = midpoint_msdf_tet[:, msdf_tetedges.view(-1)].view(-1, 12, 2)
            # Sanity checks: the edge pairs selected through tet_edges match those selected
            # through the tet vertex indices, and there are 12 pairs (24 entries).
            assert ((
                (tet_edges.view(-1, 6, 2)[:, msdf_tetedges.view(-1), :].view(-1, 24, 2).sum(dim=-1)) - indices[:, msdf_from_tetverts].view(-1, 24, 2).sum(dim=-1)
            ).sum().item() == 0)
            assert msdf_tetedges.view(-1).size(0) == 24
            inscribed_tet_fourpoint_pos = vertices_discretized[indices[:, msdf_from_tetverts].view(-1)].view(-1, 12, 4, 3).to(torch.float64)
            inscribed_edge_twopoint_pos = inscribed_tet_fourpoint_pos.view(-1, 12, 2, 2, 3).mean(dim=-2)
            occgrid_loc = inscribed_edge_twopoint_pos.mean(dim=-2)
            occgrid_loc = (occgrid_loc * 2).to(torch.int64).view(-1, 3)
            # Canonical endpoint order: weight the per-axis sign of (p0 - p1) by [16, 4, 1]
            # and argsort [code, -code].
            edge_twopoint_order = torch.sign(inscribed_edge_twopoint_pos[:, :, 0, :] - inscribed_edge_twopoint_pos[:, :, 1, :])
            edge_twopoint_order_binary_code = (edge_twopoint_order * torch.tensor([16, 4, 1], device=edge_twopoint_order.device).view(1, 1, -1)).sum(dim=-1)
            edge_twopoint_order_binary_code = torch.stack([edge_twopoint_order_binary_code, -edge_twopoint_order_binary_code], dim=-1)
            _, edge_twopoint_order = edge_twopoint_order_binary_code.sort(dim=-1)
            inscribed_edge_twopoint_msdf = torch.gather(
                input=inscribed_edge_twopoint_msdf,
                dim=-1,
                index=edge_twopoint_order
            )
            mask_msdf = (
                ((inscribed_edge_twopoint_msdf[:, :, 0] > 0) & (inscribed_edge_twopoint_msdf[:, :, 1] <= 0)) |
                ((inscribed_edge_twopoint_msdf[:, :, 0] <= 0) & (inscribed_edge_twopoint_msdf[:, :, 1] > 0))
            )
            msdf_coeff_12 = (
                torch.abs(inscribed_edge_twopoint_msdf[:, :, 0])
                / (
                    torch.abs(inscribed_edge_twopoint_msdf[:, :, 0] - inscribed_edge_twopoint_msdf[:, :, 1])
                    + 1e-10
                )
            )
            msdf_coeff_12 = (msdf_coeff_12 - 0.5) * 2.0 * mask_msdf
            msdf_coeff_12 = torch.nan_to_num(msdf_coeff_12)
            occ_grid = torch.zeros(occ_res, occ_res, occ_res, dtype=torch.float, device=msdf_coeff_12.device)
            occ_grid[occgrid_loc[:, 0], occgrid_loc[:, 1], occgrid_loc[:, 2]] = msdf_coeff_12.view(-1).to(torch.float)
            if not occgrid_mask_already_saved:
                # The occupancy locations depend only on the tet grid, so the mask is saved once.
                occ_grid_mask = torch.zeros(occ_res, occ_res, occ_res, dtype=torch.float, device=msdf_coeff_12.device)
                occ_grid_mask[occgrid_loc[:, 0], occgrid_loc[:, 1], occgrid_loc[:, 2]] = 1
                torch.save(occ_grid_mask, f'occ_mask_res{FLAGS.resolution}.pt')
                occgrid_mask_already_saved = True

            torch.cuda.empty_cache()
            # Channel 0: SDF sign at vertices and mSDF sign at edge midpoints; channels 1-3: deformation.
            grid = torch.zeros(4, grid_res, grid_res, grid_res).cuda()
            grid[0, vertices_discretized[:, 0], vertices_discretized[:, 1], vertices_discretized[:, 2]] = sdf_sign.squeeze()
            grid[1:4, vertices_discretized[:, 0], vertices_discretized[:, 1], vertices_discretized[:, 2]] = deform.transpose(0, 1)
            grid[0, midpoints_dicretized[:, 0], midpoints_dicretized[:, 1], midpoints_dicretized[:, 2]] = midpoint_msdf_sign.squeeze()
            assert grid.abs().max() <= 1
            save_path = os.path.join(grid_folder_base, 'grid_{:05d}.pt'.format(global_index))
            torch.save(grid, save_path)
            save_path = os.path.join(grid_folder_base, 'occgrid_{:05d}.pt'.format(global_index))
            torch.save(occ_grid, save_path)
================================================
FILE: GMeshDiffusion/scripts/run_eval_lower_occgrid_normalized.sh
================================================
#!/usr/bin/env bash
# Unconditional generation (DDIM sampling) with the lower-body-garment model.
# Usage: bash scripts/run_eval_lower_occgrid_normalized.sh <idx>
# Required environment variables: EVAL_DIR, REPO_ROOT_DIR, CKPT_PATH.
# Paths are relative: run from the GMeshDiffusion/ directory.
set -euo pipefail
: "${EVAL_DIR:?EVAL_DIR must be set}"
: "${REPO_ROOT_DIR:?REPO_ROOT_DIR must be set}"
: "${CKPT_PATH:?CKPT_PATH must be set}"
IDX="${1:?usage: $0 <idx>}"

python main_diffusion.py --mode uncond_gen --config diffusion_configs/config_lower_occgrid_normalized.py \
    --config.eval.eval_dir="$EVAL_DIR" \
    --config.data.root_dir="$REPO_ROOT_DIR" \
    --config.sampling.method=ddim \
    --config.eval.ckpt_path="$CKPT_PATH" \
    --config.eval.bin_size=30 \
    --config.eval.idx "$IDX"
================================================
FILE: GMeshDiffusion/scripts/run_eval_upper_occgrid_normalized.sh
================================================
#!/usr/bin/env bash
# Unconditional generation (DDIM sampling) with the upper-body-garment model.
# Usage: bash scripts/run_eval_upper_occgrid_normalized.sh <idx>
# Required environment variables: EVAL_DIR, REPO_ROOT_DIR, CKPT_PATH.
# Paths are relative: run from the GMeshDiffusion/ directory.
set -euo pipefail
: "${EVAL_DIR:?EVAL_DIR must be set}"
: "${REPO_ROOT_DIR:?REPO_ROOT_DIR must be set}"
: "${CKPT_PATH:?CKPT_PATH must be set}"
IDX="${1:?usage: $0 <idx>}"

python main_diffusion.py --mode uncond_gen --config diffusion_configs/config_upper_occgrid_normalized.py \
    --config.eval.eval_dir="$EVAL_DIR" \
    --config.data.root_dir="$REPO_ROOT_DIR" \
    --config.sampling.method=ddim \
    --config.eval.ckpt_path="$CKPT_PATH" \
    --config.eval.bin_size=10 \
    --config.eval.idx "$IDX"
================================================
FILE: GMeshDiffusion/scripts/run_lower_occgrid_normalized_ddp.sh
================================================
#!/usr/bin/env bash
# Train the lower-body-garment model with DDP on a single node.
# Required environment variables: SAVE_DIR (training output directory), REPO_ROOT_DIR.
# Optional: NPROC_PER_NODE (processes per node, default 8).
# Paths are relative: run from the GMeshDiffusion/ directory.
set -euo pipefail
: "${SAVE_DIR:?SAVE_DIR must be set}"
: "${REPO_ROOT_DIR:?REPO_ROOT_DIR must be set}"

torchrun --nnodes=1 --nproc_per_node="${NPROC_PER_NODE:-8}" main_diffusion_ddp.py --mode=train --config=diffusion_configs/config_lower_occgrid_normalized.py \
    --config.training.train_dir="$SAVE_DIR" --config.data.root_dir="$REPO_ROOT_DIR"
================================================
FILE: GMeshDiffusion/scripts/run_upper_occgrid_normalized_ddp.sh
================================================
#!/usr/bin/env bash
# Train the upper-body-garment model with DDP on a single node.
# Required environment variables: SAVE_DIR (training output directory), REPO_ROOT_DIR.
# Optional: NPROC_PER_NODE (processes per node, default 8).
# Paths are relative: run from the GMeshDiffusion/ directory.
set -euo pipefail
: "${SAVE_DIR:?SAVE_DIR must be set}"
: "${REPO_ROOT_DIR:?REPO_ROOT_DIR must be set}"

torchrun --nnodes=1 --nproc_per_node="${NPROC_PER_NODE:-8}" main_diffusion_ddp.py --mode=train --config=diffusion_configs/config_upper_occgrid_normalized.py \
    --config.training.train_dir="$SAVE_DIR" --config.data.root_dir="$REPO_ROOT_DIR"
================================================
FILE: README.md
================================================
<div align="center">
<img src="assets/gshell_logo.png" width="900"/>
</div>
# Ghost on the Shell: An Expressive Representation of General 3D Shapes
<div align="center">
<img src="assets/teaser.png" width="900"/>
</div>
## Introduction
This is the official implementation of our paper (ICLR 2024 oral) "Ghost on the Shell: An Expressive Representation of General 3D Shapes" (G-Shell).
G-Shell is a generic and differentiable representation for both watertight and non-watertight meshes. It enables 1) efficient and robust rasterization-based multiview reconstruction and 2) template-free generation of non-watertight meshes.
Please refer to [our project page](https://gshell3d.github.io) and [our paper](https://gshell3d.github.io/static/paper/gshell.pdf) for more details.
## Getting Started
### Requirements
- Python >= 3.8
- CUDA 11.8
- PyTorch == 1.13.1
(Conda installation recommended)
#### Reconstruction
Run the following
```
pip install ninja imageio PyOpenGL glfw xatlas gdown
pip install git+https://github.com/NVlabs/nvdiffrast/
pip install --global-option="--no-networks" git+https://github.com/NVlabs/tiny-cuda-nn#subdirectory=bindings/torch
```
Follow the instructions [here](https://github.com/NVIDIAGameWorks/kaolin/) to install kaolin.
Download the tet-grid files ([res128](https://drive.google.com/file/d/1u5FzpuY_BOAg8-g9lRwvah7mbCBOfNVg/view?usp=sharing), [res256](https://drive.google.com/file/d/1JnFoPEGcTLFJ7OHSWrI72h1H9_yOxUP6/view?usp=sharing)) & [res64 for G-MeshDiffusion](https://drive.google.com/file/d/1YQuU4D-0q8kwrzEfla3hGzBg4erBhand/view?usp=drive_link) to `data/tets` folder under the root directory. Alternatively, you may follow https://github.com/crawforddoran/quartet and `data/tets/generate_tets.py` to create the tet-grid files.
#### Generation
Install the following
- Pytorch3D
- ml_collections
## To-dos
- [x] Code for reconstruction
- [ ] DeepFashion3D multiview image dataset for metallic surfaces
- [x] Code for generative models
- [ ] Code for DeepFashion3D dataset preparation
- [ ] Evaluation code for generative models
## Reconstruction
### Datasets
#### DeepFashion3D mesh dataset
We provide ground-truth images (rendered under realistic environment light with Blender) for 9 instances in [DeepFashion3D-v2 dataset](https://github.com/GAP-LAB-CUHK-SZ/deepFashion3D). The download links for the raw meshes can be found in their repo.
non-metallic material: [training data](https://drive.google.com/file/d/1LwBqLYzamFLyBIiNpD6kEkvySrq2nruG/view?usp=sharing), [test data](https://drive.google.com/file/d/1-47dH_yJrUzKVdKbJenslpdyHwDI6QVo/view?usp=sharing)
#### NeRF synthetic dataset
Download the [NeRF synthetic dataset archive](https://drive.google.com/uc?export=download&id=18JxhpWD-4ZmuFKLzKlAw-w5PpzZxXOcG) and unzip it into the `data/` folder.
#### Hat dataset
Download link: https://drive.google.com/file/d/18UmT1NM5wJQ-ZM-rtUXJHXkDc-ba-xVk/view?usp=sharing
RGB images, segmentation masks and the corresponding camera poses are included. Alternatively, you may choose to 1) generate the camera poses with COLMAP and 2) create binary segmentation masks by yourself.
### Training
#### DeepFashion3D-v2 instances
The mesh instances' IDs are [30, 92, 117, 133, 164, 320, 448, 522, 591]. To reconstruct the `$INDEX`-th mesh (0-8) in the list using tet-based G-Shell, run
```
python train_gshelltet_deepfashion.py --config configs/deepfashion_mc_256.json --index $INDEX --trainset_path $TRAINSET_PATH --testset_path $TESTSET_PATH --o $OUTPUT_PATH
```
For FlexiCubes + G-Shell, run
```
python train_gflexicubes_deepfashion.py --config configs/deepfashion_mc_80.json --index $INDEX --trainset_path $TRAINSET_PATH --testset_path $TESTSET_PATH --o $OUTPUT_PATH
```
#### Synthetic data
```
python train_gshelltet_synthetic.py --config configs/nerf_chair.json --o $OUTPUT_PATH
```
#### Hat data
```
python train_gshelltet_polycam.py --config configs/polycam_mc_128.json --trainset_path $TRAINSET_PATH --o $OUTPUT_PATH
```
```
python train_gshelltet_polycam.py --config configs/polycam_mc_128.json --trainset_path $TRAINSET_PATH --o $OUTPUT_PATH
```
#### On config files
You may want to modify the following options, depending on your needs:
- `gshell_grid`: the G-Shell grid size. For tet-based G-Shell, please make sure the corresponding tet-grid file exists under `data/tets` (e.g., `256_tets.npz`). Otherwise, follow https://github.com/crawforddoran/quartet and `data/tets/generate_tets.py` to generate the desired tet-grid file.
- `n_samples`: the number of MC samples for light rays per rasterized pixel. The higher the better (at a cost of memory and speed).
- `batch`: the number of views sampled in each iteration.
- `iter`: the total number of iterations.
- `kd_min`, `kd_max`, etc: the min/max of the corresponding PBR material parameter.
## Generation
### Preparation
Download the info files for the underlying tet grids and the binary masks indicating which locations in the cubic grids store useful values: [tet_info.pt](https://drive.google.com/file/d/19Dw_hOpcVHazpm2_1qA7T7j5xABOUWxv/view?usp=drive_link), [global_mask_res64.pt](https://drive.google.com/file/d/1mlSnu23_u08HH5aO3x5z1V9GzguFzoiT/view?usp=drive_link), [cat_mask_res64.pt](https://drive.google.com/file/d/11Bm4CQX-y1X7R47AfQQz20s7oP6AbbNK/view?usp=drive_link) and [occ_mask_res64.pt](https://drive.google.com/file/d/1qEqqLfZe633GdVkj5MGOCON_kf0l4e4G/view?usp=drive_link). Put these files under `GMeshDiffusion/metadata/`.
#### For inference
Download the pretrained models for upper-body and lower-body garments [here](https://huggingface.co/lzzcd001/GMeshDiffusion-Models).
#### For training
1) Download the processed Cloth3D garment dataset (for upper-body & lower-body garments) from [link](https://huggingface.co/datasets/lzzcd001/Cloth3D-GShell-Dataset). Alternatively, you may create a grid dataset for your own objects by a) normalizing your data points by re-centering and re-scaling the meshes, b) fitting G-Shell representations, and c) converting these representations into cubic grids by running `GMeshDiffusion/metadata/tet_to_cubic_grid_dataset.py`.
2) Run `GMeshDiffusion/metadata/get_splits_lower.py` and/or `GMeshDiffusion/metadata/get_splits_upper.py` to generate lists of training and test datapoints.
### Inference
1. Modify the batch size in the config files in `GMeshDiffusion/diffusion_configs/`, and enter the desired directories and values (for model checkpoints, where to store generated samples, etc.) in `GMeshDiffusion/scripts` (the eval scripts read `EVAL_DIR`, `REPO_ROOT_DIR` and `CKPT_PATH` from the environment).
2. Run the eval scripts in `GMeshDiffusion/scripts`.
3. Run `eval_gmeshdiffusion_generated_samples.py` to extract triangular meshes.
### Training
1. Set the desired directories (for model checkpoints and where to store generated samples) for the scripts in `GMeshDiffusion/scripts` (the training scripts read `SAVE_DIR` and `REPO_ROOT_DIR` from the environment).
2. Modify the config files if necessary.
3. Run the training scripts in `GMeshDiffusion/scripts`.
## Citation
If you find our work useful to your research, please consider citing:
```
@inproceedings{liu2024gshell,
title={Ghost on the Shell: An Expressive Representation of General 3D Shapes},
author={Liu, Zhen and Feng, Yao and Xiu, Yuliang and Liu, Weiyang and Paull, Liam and Black, Michael J. and Sch{\"o}lkopf, Bernhard},
booktitle={The Twelfth International Conference on Learning Representations},
year={2024},
}
```
## Acknowledgement
We sincerely thank the authors of [Nvdiffrecmc](https://github.com/NVlabs/nvdiffrecmc), [FlexiCubes](https://github.com/nv-tlabs/FlexiCubes) and https://github.com/yang-song/score_sde_pytorch for sharing their code. Our repo is adapted from [MeshDiffusion](https://github.com/lzzcd001/MeshDiffusion/).
================================================
FILE: configs/deepfashion_mc.json
================================================
{
"ref_mesh": "data/spot/spot.obj",
"random_textures": true,
"iter": 5000,
"save_interval": 100,
"texture_res": [ 1024, 1024 ],
"train_res": [1024, 1024],
"batch": 2,
"learning_rate": [0.03, 0.005],
"ks_min" : [0, 0.001, 0.0],
"ks_max" : [0, 1.0, 1.0],
"envlight": "data/irrmaps/aerodynamics_workshop_2k.hdr",
"lock_pos" : false,
"display": [{"latlong" : true}],
"background" : "white",
"denoiser": "bilateral",
"n_samples" : 24,
"env_scale" : 2.0,
"gshell_grid" : 128,
"validate" : true,
"laplace_scale" : 6000,
"boxscale": [1, 1, 1],
"aabb": [-1, -1, -1, 1, 1, 1]
}
================================================
FILE: configs/deepfashion_mc_256.json
================================================
{
"ref_mesh": "data/spot/spot.obj",
"random_textures": true,
"iter": 5000,
"save_interval": 100,
"texture_res": [ 1024, 1024 ],
"train_res": [1024, 1024],
"batch": 2,
"learning_rate": [0.03, 0.005],
"ks_min" : [0, 0.001, 0.0],
"ks_max" : [0, 1.0, 1.0],
"envlight": "data/irrmaps/aerodynamics_workshop_2k.hdr",
"lock_pos" : false,
"display": [{"latlong" : true}],
"background" : "white",
"denoiser": "bilateral",
"n_samples" : 24,
"env_scale" : 2.0,
"gshell_grid" : 256,
"validate" : true,
"laplace_scale" : 6000,
"boxscale": [1, 1, 1],
"aabb": [-1, -1, -1, 1, 1, 1]
}
================================================
FILE: configs/deepfashion_mc_512.json
================================================
{
"ref_mesh": "data/spot/spot.obj",
"random_textures": true,
"iter": 5000,
"save_interval": 100,
"texture_res": [ 1024, 1024 ],
"train_res": [1024, 1024],
"batch": 2,
"learning_rate": [0.03, 0.005],
"ks_min" : [0, 0.001, 0.0],
"ks_max" : [0, 1.0, 1.0],
"envlight": "data/irrmaps/aerodynamics_workshop_2k.hdr",
"validate" : false,
"lock_pos" : false,
"display": [{"latlong" : true}],
"background" : "white",
"denoiser": "bilateral",
"n_samples" : 12,
"env_scale" : 2.0,
"gshell_grid" : 512,
"validate" : true,
"laplace_scale" : 6000,
"boxscale": [1, 1, 1],
"aabb": [-1, -1, -1, 1, 1, 1]
}
================================================
FILE: configs/deepfashion_mc_80.json
================================================
{
"ref_mesh": "data/spot/spot.obj",
"random_textures": true,
"iter": 5000,
"save_interval": 100,
"texture_res": [ 1024, 1024 ],
"train_res": [1024, 1024],
"batch": 2,
"learning_rate": [0.03, 0.005],
"ks_min" : [0, 0.001, 0.0],
"ks_max" : [0, 1.0, 1.0],
"envlight": "data/irrmaps/aerodynamics_workshop_2k.hdr",
"lock_pos" : false,
"display": [{"latlong" : true}],
"background" : "white",
"denoiser": "bilateral",
"n_samples" : 24,
"env_scale" : 2.0,
"gshell_grid" : 80,
"validate" : true,
"laplace_scale" : 6000,
"boxscale": [1, 1, 1],
"aabb": [-1, -1, -1, 1, 1, 1]
}
================================================
FILE: configs/nerf_chair.json
================================================
{
"ref_mesh": "data/nerf_synthetic/chair",
"random_textures": true,
"iter": 5000,
"save_interval": 100,
"texture_res": [ 1024, 1024 ],
"train_res": [800, 800],
"batch": 2,
"learning_rate": [0.03, 0.005],
"gshell_grid" : 128,
"mesh_scale" : 2.1,
"validate" : true,
"n_samples" : 8,
"denoiser" : "bilateral",
"display": [{"latlong" : true}, {"bsdf" : "kd"}, {"bsdf" : "ks"}, {"bsdf" : "normal"}],
"background" : "white",
"boxscale": [1, 1, 1],
"aabb": [-1, -1, -1, 1, 1, 1]
}
================================================
FILE: configs/polycam_mc.json
================================================
{
"ref_mesh": "data/spot/spot.obj",
"random_textures": true,
"iter": 5000,
"save_interval": 100,
"texture_res": [ 1024, 1024 ],
"train_res": [768, 1024],
"batch": 2,
"learning_rate": [0.03, 0.005],
"ks_min" : [0, 0.001, 0.0],
"ks_max" : [0, 1.0, 1.0],
"envlight": "data/irrmaps/aerodynamics_workshop_2k.hdr",
"lock_pos" : false,
"display": [{"latlong" : true}],
"background" : "white",
"denoiser": "bilateral",
"n_samples" : 8,
"env_scale" : 2.0,
"gshell_grid" : 256,
"validate" : true,
"laplace_scale" : 6000,
"boxscale": [1, 1, 1],
"aabb": [-1, -1, -1, 1, 1, 1]
}
================================================
FILE: configs/polycam_mc_128.json
================================================
{
"ref_mesh": "data/spot/spot.obj",
"random_textures": true,
"iter": 5000,
"save_interval": 100,
"texture_res": [ 1024, 1024 ],
"train_res": [768, 1024],
"batch": 2,
"learning_rate": [0.03, 0.005],
"ks_min" : [0, 0.001, 0.0],
"ks_max" : [0, 1.0, 1.0],
"envlight": "data/irrmaps/aerodynamics_workshop_2k.hdr",
"lock_pos" : false,
"display": [{"latlong" : true}],
"background" : "white",
"denoiser": "bilateral",
"n_samples" : 8,
"env_scale" : 2.0,
"gshell_grid" : 128,
"validate" : true,
"laplace_scale" : 6000,
"boxscale": [1, 1, 1],
"aabb": [-1, -1, -1, 1, 1, 1]
}
================================================
FILE: configs/polycam_mc_16samples.json
================================================
{
"ref_mesh": "data/spot/spot.obj",
"random_textures": true,
"iter": 5000,
"save_interval": 100,
"texture_res": [ 1024, 1024 ],
"train_res": [768, 1024],
"batch": 2,
"learning_rate": [0.03, 0.005],
"ks_min" : [0, 0.001, 0.0],
"ks_max" : [0, 1.0, 1.0],
"envlight": "data/irrmaps/aerodynamics_workshop_2k.hdr",
"lock_pos" : false,
"display": [{"latlong" : true}],
"background" : "white",
"denoiser": "bilateral",
"n_samples" : 16,
"env_scale" : 2.0,
"gshell_grid" : 256,
"validate" : true,
"laplace_scale" : 6000,
"boxscale": [1, 1, 1],
"aabb": [-1, -1, -1, 1, 1, 1]
}
================================================
FILE: data/tets/generate_tets.py
================================================
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import os
import numpy as np
'''
This code segment shows how to use Quartet: https://github.com/crawforddoran/quartet,
to generate a tet grid
1) Download, compile and run Quartet as described in the link above. Example usage `quartet meshes/cube.obj 0.5 cube_5.tet`
2) Run the function below to generate a file `cube_32_tet.tet`
'''
def generate_tetrahedron_grid_file(res=32, root='..'):
    """Run Quartet on ``meshes/cube.obj`` with cell size ``1/res``.

    Runs ``./quartet`` inside ``<root>/quartet`` and writes
    ``meshes/cube_<res>_tet.tet`` plus the boundary mesh
    ``meshes/cube_boundary_<res>.obj`` there.

    The file name now uses an integer ``res`` (``cube_32_tet.tet``), matching the
    default of ``convert_from_quartet_to_npz``. The previous ``%f`` produced
    ``cube_32.000000_tet.tet``.

    Raises:
        subprocess.CalledProcessError: if Quartet exits with a non-zero status.
    """
    import subprocess

    frac = 1.0 / res
    subprocess.run(
        [
            './quartet', 'meshes/cube.obj',
            # %.10g keeps e.g. 1/256 = 0.00390625 exact; '%f' rounded it to 0.003906.
            '%.10g' % frac,
            'meshes/cube_%d_tet.tet' % res,
            '-s', 'meshes/cube_boundary_%d.obj' % res,
        ],
        # Equivalent to the former `cd <root>/quartet; ./quartet ...`, without a shell.
        cwd=os.path.join(root, 'quartet'),
        check=True,
    )
'''
This code segment shows how to convert from a quartet .tet file to compressed npz file
'''
def convert_from_quartet_to_npz(quartetfile='cube_32_tet.tet', npzfile='32_tets'):
    """Convert a Quartet ``.tet`` file into a compressed ``.npz`` archive.

    File layout:
        - Header line: the 2nd and 3rd whitespace-separated fields are the vertex and
          tet counts.
        - The next ``numvertices`` rows are vertex coordinates.
        - The following ``numtets`` rows are integer vertex indices.

    Saves arrays ``vertices`` and ``indices`` to *npzfile*. numpy appends ``.npz`` if
    it is missing.
    """
    # Read only the header; the context manager closes the handle (it used to leak).
    with open(quartetfile, 'r') as f:
        header = f.readline().split()
    numvertices = int(header[1])
    numtets = int(header[2])
    print(numvertices, numtets)
    # load vertices; ndmin=2 keeps shape (1, k) if a section has a single row
    # (np.loadtxt squeezes it to 1-D otherwise).
    vertices = np.loadtxt(quartetfile, skiprows=1, max_rows=numvertices, ndmin=2)
    print(vertices.shape)
    # load indices
    indices = np.loadtxt(quartetfile, dtype=int, skiprows=1 + numvertices, max_rows=numtets, ndmin=2)
    print(indices.shape)
    np.savez_compressed(npzfile, vertices=vertices, indices=indices)
================================================
FILE: dataset/__init__.py
================================================
# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
from .dataset import Dataset
from .dataset_mesh import DatasetMesh
from .dataset_nerf import DatasetNERF
from .dataset_llff import DatasetLLFF
================================================
FILE: dataset/dataset.py
================================================
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import torch
class Dataset(torch.utils.data.Dataset):
    """Basic dataset interface.

    Subclasses must implement ``__len__`` and ``__getitem__``. ``collate`` merges a
    list of per-item dicts into a single batch dict.
    """

    def __init__(self):
        super().__init__()

    def __len__(self):
        raise NotImplementedError

    def __getitem__(self, itemidx=None):
        # Takes the item index like any map-style dataset; the default keeps the
        # old zero-argument call signature valid.
        raise NotImplementedError

    def collate(self, batch):
        """Merge a list of per-item dicts into one batch dict.

        Args:
            batch: non-empty list of dicts. Each has tensors under 'mv', 'mvp' and
                'campos', and optionally under 'img', 'img_second', 'invdepth',
                'invdepth_second' and 'envlight_transform'.

        Returns:
            dict: tensors concatenated along dim 0.
                - 'resolution' and 'spp' are taken from the first item.
                - An optional key is None when absent from the first item.
                - 'envlight_transform' is also None when the first item's value is None.
        """
        first = batch[0]

        def cat(key):
            return torch.cat([item[key] for item in batch], dim=0)

        def cat_if_present(key):
            return cat(key) if key in first else None

        return {
            'mv': cat('mv'),
            'mvp': cat('mvp'),
            'campos': cat('campos'),
            'resolution': first['resolution'],
            'spp': first['spp'],
            'img': cat_if_present('img'),
            'img_second': cat_if_present('img_second'),
            'invdepth': cat_if_present('invdepth'),
            'invdepth_second': cat_if_present('invdepth_second'),
            # Previously tested `'envlight_transform' in batch` (membership in the list
            # of item dicts), which is always False, so the transform was always dropped.
            'envlight_transform': (
                cat('envlight_transform')
                if first.get('envlight_transform') is not None
                else None
            ),
        }
================================================
FILE: dataset/dataset_deepfashion.py
================================================
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import os
import glob
import json
import torch
import numpy as np
from render import util
from .dataset import Dataset
import cv2 as cv
# This function is borrowed from IDR: https://github.com/lioryariv/idr
def load_K_Rt_from_P(filename, P=None):
    """Decompose a projection matrix into a 4x4 intrinsics matrix and a 4x4 pose.

    Borrowed from IDR: https://github.com/lioryariv/idr

    Args:
        filename: Text file holding the projection matrix. It is read only
            when ``P`` is None. If the file has exactly 4 lines, the first one
            is skipped. The first four whitespace-separated values of each
            remaining line form one row of ``P``.
        P: Projection matrix passed directly to ``cv.decomposeProjectionMatrix``.

    Returns:
        (intrinsics, pose):
        - ``intrinsics`` is a 4x4 float64 matrix whose upper-left 3x3 block
          is the decomposed camera matrix, normalised so ``K[2, 2] == 1``.
        - ``pose`` is a 4x4 float32 matrix holding the transposed rotation and
          the dehomogenised camera position returned by the decomposition.
    """
    if P is None:
        # Context manager so the file handle is closed promptly.
        with open(filename) as f:
            lines = f.read().splitlines()
        if len(lines) == 4:
            lines = lines[1:]
        # split() tolerates repeated/trailing whitespace, unlike split(" ").
        lines = [line.split()[:4] for line in lines]
        P = np.asarray(lines).astype(np.float32).squeeze()
    out = cv.decomposeProjectionMatrix(P)
    K = out[0]
    R = out[1]
    t = out[2]
    K = K / K[2, 2]
    intrinsics = np.eye(4)
    intrinsics[:3, :3] = K
    pose = np.eye(4, dtype=np.float32)
    # Transposing the rotation from P = K[R|t] gives the camera-to-world rotation.
    pose[:3, :3] = R.transpose()
    # t is the homogeneous (4x1) camera position; divide by its last component.
    pose[:3, 3] = (t[:3] / t[3])[:, 0]
    return intrinsics, pose
def _load_img(path):
    """Read an image file and return it as a float32 torch tensor.

    Images whose dtype is not float32 are treated as LDR. They are divided by
    255 and their first three channels are converted with
    ``util.srgb_to_rgb``. Float32 images are wrapped without modification.
    """
    raw = util.load_image_raw(path)
    if raw.dtype == np.float32:
        return torch.tensor(raw, dtype=torch.float32)
    # LDR image: normalise, then linearise the colour channels.
    img = torch.tensor(raw / 255, dtype=torch.float32)
    img[..., 0:3] = util.srgb_to_rgb(img[..., 0:3])
    return img
class DatasetDeepFashion(Dataset):
    """Multi-view RGBA dataset whose cameras come from ``cameras_sphere.npz``.

    ``base_dir`` must contain ``cameras_sphere.npz`` (with entries
    ``world_mat_<i>`` and ``scale_mat_<i>``) and images named ``000.png``,
    ``001.png``, ... . The view count is hard-coded to 72.

    Args:
        base_dir: Directory holding the camera file and images.
        FLAGS: Run configuration. Reads ``display_res``, ``cam_near_far``,
            ``local_rank``, ``train_res`` and ``spp``.
        examples: Optional override for ``len(self)``. Item indices wrap
            around the views via ``itr % n_images``.
    """

    def __init__(self, base_dir, FLAGS, examples=None):
        self.FLAGS = FLAGS
        self.examples = examples
        self.base_dir = base_dir

        # Load config / transforms
        self.n_images = 72 ### hardcoded
        self.fovy = np.deg2rad(60)
        self.proj_mtx = util.perspective(
            self.fovy, self.FLAGS.display_res[1] / self.FLAGS.display_res[0], self.FLAGS.cam_near_far[0], self.FLAGS.cam_near_far[1]
        )

        camera_dict = np.load(os.path.join(self.base_dir, 'cameras_sphere.npz'))
        # world_mat is a projection matrix from world to image
        self.world_mats_np = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)]
        # scale_mat: used for coordinate normalization, we assume the scene to render is inside a unit sphere at origin.
        self.scale_mats_np = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)]

        self.intrinsics_all = []
        self.pose_all = []
        for scale_mat, world_mat in zip(self.scale_mats_np, self.world_mats_np):
            P = (world_mat @ scale_mat)[:3, :4]
            intrinsics, pose = load_K_Rt_from_P(None, P)
            self.intrinsics_all.append(torch.from_numpy(intrinsics).float())
            self.pose_all.append(torch.from_numpy(pose).float())

        # Determine resolution & aspect ratio
        self.resolution = _load_img(self._image_path(0)).shape[0:2]
        self.aspect = self.resolution[1] / self.resolution[0]

        if self.FLAGS.local_rank == 0:
            # Log under this class's own name (previously a copy-pasted "DatasetNERF").
            print("DatasetDeepFashion: %d images with shape [%d, %d]" % (self.n_images, self.resolution[0], self.resolution[1]))

    def _image_path(self, idx):
        """Return the path of view ``idx`` (file name zero-padded to three digits)."""
        return os.path.join(self.base_dir, '{:03d}.png'.format(idx))

    def _parse_frame(self, idx):
        """Load view ``idx`` and build its camera matrices.

        Returns ``(img, mv, mvp, campos)``. Each has a leading batch dimension
        of 1 and is moved to CUDA.
        """
        # Load image data and modelview matrix
        img = _load_img(self._image_path(idx))
        # Check for the alpha channel *before* indexing it, so an image without
        # one fails on this assertion rather than on the arithmetic below.
        assert img.size(-1) == 4
        # Premultiply colour by alpha, then binarise alpha with sign()
        # (0 stays 0, any positive value becomes 1).
        img[:,:,:3] = img[:,:,:3] * img[:,:,3:]
        img[:,:,3] = torch.sign(img[:,:,3])

        # Negates the camera-space y and z axes. This presumably converts the
        # decomposed (OpenCV-style) pose to the renderer's convention; confirm.
        flip_mat = torch.tensor([
            [ 1, 0, 0, 0],
            [ 0, -1, 0, 0],
            [ 0, 0, -1, 0],
            [ 0, 0, 0, 1]
        ], dtype=torch.float)
        mv = flip_mat @ torch.linalg.inv(self.pose_all[idx])
        campos = torch.linalg.inv(mv)[:3, 3]
        mvp = self.proj_mtx @ mv
        return img[None, ...].cuda(), mv[None, ...].cuda(), mvp[None, ...].cuda(), campos[None, ...].cuda() # Add batch dimension

    def __len__(self):
        return self.n_images if self.examples is None else self.examples

    def __getitem__(self, itr):
        img, mv, mvp, campos = self._parse_frame(itr % self.n_images)
        return {
            'mv' : mv,
            'mvp' : mvp,
            'campos' : campos,
            'resolution' : self.FLAGS.train_res,
            'spp' : self.FLAGS.spp,
            'img' : img
        }
================================================
FILE: dataset/dataset_deepfashion_testset.py
================================================
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import os
import glob
import json
import torch
import numpy as np
from render import util
from .dataset import Dataset
import cv2 as cv
# This function is borrowed from IDR: https://github.com/lioryariv/idr
def load_K_Rt_from_P(filename, P=None):
    """Decompose a projection matrix into a 4x4 intrinsics matrix and a 4x4 pose.

    Borrowed from IDR: https://github.com/lioryariv/idr

    Args:
        filename: Text file holding the projection matrix. It is read only
            when ``P`` is None. If the file has exactly 4 lines, the first one
            is skipped. The first four whitespace-separated values of each
            remaining line form one row of ``P``.
        P: Projection matrix passed directly to ``cv.decomposeProjectionMatrix``.

    Returns:
        (intrinsics, pose):
        - ``intrinsics`` is a 4x4 float64 matrix whose upper-left 3x3 block
          is the decomposed camera matrix, normalised so ``K[2, 2] == 1``.
        - ``pose`` is a 4x4 float32 matrix holding the transposed rotation and
          the dehomogenised camera position returned by the decomposition.
    """
    if P is None:
        # Context manager so the file handle is closed promptly.
        with open(filename) as f:
            lines = f.read().splitlines()
        if len(lines) == 4:
            lines = lines[1:]
        # split() tolerates repeated/trailing whitespace, unlike split(" ").
        lines = [line.split()[:4] for line in lines]
        P = np.asarray(lines).astype(np.float32).squeeze()
    out = cv.decomposeProjectionMatrix(P)
    K = out[0]
    R = out[1]
    t = out[2]
    K = K / K[2, 2]
    intrinsics = np.eye(4)
    intrinsics[:3, :3] = K
    pose = np.eye(4, dtype=np.float32)
    # Transposing the rotation from P = K[R|t] gives the camera-to-world rotation.
    pose[:3, :3] = R.transpose()
    # t is the homogeneous (4x1) camera position; divide by its last component.
    pose[:3, 3] = (t[:3] / t[3])[:, 0]
    return intrinsics, pose
def _load_img(path):
    """Read an image file and return it as a float32 torch tensor.

    Images whose dtype is not float32 are treated as LDR. They are divided by
    255 and their first three channels are converted with
    ``util.srgb_to_rgb``. Float32 images are wrapped without modification.
    """
    raw = util.load_image_raw(path)
    if raw.dtype == np.float32:
        return torch.tensor(raw, dtype=torch.float32)
    # LDR image: normalise, then linearise the colour channels.
    img = torch.tensor(raw / 255, dtype=torch.float32)
    img[..., 0:3] = util.srgb_to_rgb(img[..., 0:3])
    return img
def _load_mask(path):
    """Read an image file as a float32 torch tensor without any colour conversion.

    Images whose dtype is not float32 are divided by 255. Float32 images are
    wrapped unchanged.
    """
    raw = util.load_image_raw(path)
    scaled = raw if raw.dtype == np.float32 else raw / 255
    return torch.tensor(scaled, dtype=torch.float32)
class DatasetDeepFashionTestset(Dataset):
    """Multi-view RGBA test set whose cameras come from ``proj_mtx_all.npy``.

    ``base_dir`` must contain ``proj_mtx_all.npy`` (one projection matrix per
    view) and images named ``000.png``, ``001.png``, ... . The image is
    returned as loaded; colour is not premultiplied by alpha. The view count
    is hard-coded to 200.

    NOTE(review): ``n_images`` does not depend on ``proj_mtx_all.shape[0]``.
    If the file holds fewer than 200 matrices, ``pose_all[idx]`` will raise
    for the missing views.

    Args:
        base_dir: Directory holding the projection matrices and images.
        FLAGS: Run configuration. Reads ``display_res``, ``cam_near_far``,
            ``local_rank``, ``train_res`` and ``spp``.
        examples: Optional override for ``len(self)``. Item indices wrap
            around the views via ``itr % n_images``.
    """

    def __init__(self, base_dir, FLAGS, examples=None):
        self.FLAGS = FLAGS
        self.examples = examples
        self.base_dir = base_dir

        # Load config / transforms
        self.n_images = 200 ### hardcoded
        proj_mtx_all = np.load(os.path.join(self.base_dir, 'proj_mtx_all.npy'))
        self.intrinsics_all = []
        self.pose_all = []
        self.fovy = np.deg2rad(60)
        self.proj_mtx = util.perspective(
            self.fovy, self.FLAGS.display_res[1] / self.FLAGS.display_res[0], self.FLAGS.cam_near_far[0], self.FLAGS.cam_near_far[1]
        )

        for P in proj_mtx_all:
            intrinsics, pose = load_K_Rt_from_P(None, P[:3, :4])
            self.intrinsics_all.append(torch.from_numpy(intrinsics).float())
            self.pose_all.append(torch.from_numpy(pose).float())

        # Determine resolution & aspect ratio
        self.resolution = _load_img(self._image_path(0)).shape[0:2]
        self.aspect = self.resolution[1] / self.resolution[0]

        if self.FLAGS.local_rank == 0:
            # Log under this class's own name (previously a copy-pasted "DatasetNERF").
            print("DatasetDeepFashionTestset: %d images with shape [%d, %d]" % (self.n_images, self.resolution[0], self.resolution[1]))

    def _image_path(self, idx):
        """Return the path of view ``idx`` (file name zero-padded to three digits)."""
        return os.path.join(self.base_dir, '{:03d}.png'.format(idx))

    def _parse_frame(self, idx):
        """Load view ``idx`` and build its camera matrices.

        Returns ``(img, mv, mvp, campos)``. Each has a leading batch dimension
        of 1 and is moved to CUDA.
        """
        # Load image data and modelview matrix
        img = _load_img(self._image_path(idx))
        assert img.size(-1) == 4

        # Negates the camera-space y and z axes. This presumably converts the
        # decomposed (OpenCV-style) pose to the renderer's convention; confirm.
        flip_mat = torch.tensor([
            [ 1, 0, 0, 0],
            [ 0, -1, 0, 0],
            [ 0, 0, -1, 0],
            [ 0, 0, 0, 1]
        ], dtype=torch.float)
        mv = flip_mat @ torch.linalg.inv(self.pose_all[idx])
        campos = torch.linalg.inv(mv)[:3, 3]
        mvp = self.proj_mtx @ mv
        return img[None, ...].cuda(), mv[None, ...].cuda(), mvp[None, ...].cuda(), campos[None, ...].cuda() # Add batch dimension

    def __len__(self):
        return self.n_images if self.examples is None else self.examples

    def __getitem__(self, itr):
        img, mv, mvp, campos = self._parse_frame(itr % self.n_images)
        return {
            'mv' : mv,
            'mvp' : mvp,
            'campos' : campos,
            'resolution' : self.FLAGS.train_res,
            'spp' : self.FLAGS.spp,
            'img' : img
        }
================================================
FILE: dataset/dataset_llff.py
================================================
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import os
import glob
import torch
import numpy as np
from render import util
from .dataset import Dataset
def _load_mask(fn):
    """Load a mask image as a float32 torch tensor.

    A 2-D (single-channel) mask is repeated into an H x W x 3 tensor. Masks
    of any other rank are returned unchanged.
    """
    mask = torch.tensor(util.load_image(fn), dtype=torch.float32)
    if mask.dim() != 2:
        return mask
    return mask.unsqueeze(-1).repeat(1, 1, 3)
def _load_img(fn):
    """Read an image file and return it as a float32 torch tensor.

    Images whose dtype is not float32 are treated as LDR. They are divided by
    255 and their first three channels are converted with
    ``util.srgb_to_rgb``. Float32 images are wrapped without modification.
    """
    raw = util.load_image_raw(fn)
    if raw.dtype == np.float32:
        return torch.tensor(raw, dtype=torch.float32)
    # LDR image: normalise, then linearise the colour channels.
    img = torch.tensor(raw / 255, dtype=torch.float32)
    img[..., 0:3] = util.srgb_to_rgb(img[..., 0:3])
    return img
###############################################################################
# LLFF datasets (real world camera lightfields)
###############################################################################
class DatasetLLFF(Dataset):
    """LLFF-style capture: ``images/``, ``masks/`` and ``poses_bounds.npy`` in ``base_dir``.

    Each item holds the loaded image with the mask's first channel appended
    as an extra channel, plus the view's modelview/MVP matrices and camera
    position. The image and mask file lists are enumerated once, at
    construction time.

    Args:
        base_dir: Capture directory.
        FLAGS: Run configuration. Reads ``cam_near_far``, ``local_rank``,
            ``pre_load`` and ``spp``.
        examples: Optional override for ``len(self)``. Item indices wrap
            around the views.
    """

    @staticmethod
    def _list_images(folder):
        """Return sorted paths in ``folder`` whose names end (case-insensitively) in png/jpg/jpeg."""
        return [f for f in sorted(glob.glob(os.path.join(folder, "*"))) if f.lower().endswith(('png', 'jpg', 'jpeg'))]

    def __init__(self, base_dir, FLAGS, examples=None):
        self.FLAGS = FLAGS
        self.base_dir = base_dir
        self.examples = examples

        # Enumerate image and mask files once. _parse_frame used to re-glob and
        # sort both directories on every call, which is quadratic over an epoch.
        self.img_paths = self._list_images(os.path.join(self.base_dir, "images"))
        self.mask_paths = self._list_images(os.path.join(self.base_dir, "masks"))
        self.resolution = _load_img(self.img_paths[0]).shape[0:2]

        # Load camera poses. Each row of poses_bounds holds a 3x5 pose block
        # followed by 2 trailing values (dropped here).
        poses_bounds = np.load(os.path.join(self.base_dir, 'poses_bounds.npy'))
        poses = poses_bounds[:, :-2].reshape([-1, 3, 5]).transpose([1,2,0])
        poses = np.concatenate([poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1) # Taken from nerf, swizzles from LLFF to expected coordinate system
        poses = np.moveaxis(poses, -1, 0).astype(np.float32)
        # Append a [0, 0, 0, 1] row to turn each 3x4 pose into a 4x4 matrix;
        # these are inverted in _parse_frame to obtain the modelview matrix.
        lcol = np.array([0,0,0,1], dtype=np.float32)[None, None, :].repeat(poses.shape[0], 0)
        self.imvs = torch.tensor(np.concatenate((poses[:, :, 0:4], lcol), axis=1), dtype=torch.float32)
        self.aspect = self.resolution[1] / self.resolution[0] # width / height
        # Per-view fovy from pose column 4 (presumably the LLFF [H, W, focal]
        # column; verify against util.focal_length_to_fovy).
        self.fovy = util.focal_length_to_fovy(poses[:, 2, 4], poses[:, 0, 4])

        # Recenter scene so lookat position is origin
        center = util.lines_focal(self.imvs[..., :3, 3], -self.imvs[..., :3, 2])
        self.imvs[..., :3, 3] = self.imvs[..., :3, 3] - center[None, ...]
        if self.FLAGS.local_rank == 0:
            print("DatasetLLFF: %d images with shape [%d, %d]" % (len(self.img_paths), self.resolution[0], self.resolution[1]))
            print("DatasetLLFF: auto-centering at %s" % (center.cpu().numpy()))

        # Pre-load from disc to avoid slow png parsing
        if self.FLAGS.pre_load:
            self.preloaded_data = [self._parse_frame(i) for i in range(self.imvs.shape[0])]

    def _parse_frame(self, idx):
        """Load view ``idx`` and build its camera matrices.

        Returns ``(img, mv, mvp, campos)``, each with a leading batch dimension
        of 1. Nothing is moved to CUDA here.
        """
        # Every view needs exactly one image and one mask.
        assert len(self.img_paths) == self.imvs.shape[0] and len(self.mask_paths) == self.imvs.shape[0]

        # Load image+mask data
        img = _load_img(self.img_paths[idx])
        mask = _load_mask(self.mask_paths[idx])
        img = torch.cat((img, mask[..., 0:1]), dim=-1)

        # Setup transforms
        proj = util.perspective(self.fovy[idx, ...], self.aspect, self.FLAGS.cam_near_far[0], self.FLAGS.cam_near_far[1])
        mv = torch.linalg.inv(self.imvs[idx, ...])
        campos = torch.linalg.inv(mv)[:3, 3]
        mvp = proj @ mv

        return img[None, ...], mv[None, ...], mvp[None, ...], campos[None, ...] # Add batch dimension

    def __len__(self):
        return self.imvs.shape[0] if self.examples is None else self.examples

    def __getitem__(self, itr):
        frame = itr % self.imvs.shape[0]
        if self.FLAGS.pre_load:
            img, mv, mvp, campos = self.preloaded_data[frame]
        else:
            img, mv, mvp, campos = self._parse_frame(frame)

        return {
            'mv' : mv,
            'mvp' : mvp,
            'campos' : campos,
            'resolution' : self.resolution,
            'spp' : self.FLAGS.spp,
            'img' : img
        }
================================================
FILE: dataset/dataset_mesh.py
================================================
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import numpy as np
import torch
from render import util
from render import mesh
from render import render
from render import light
from .dataset import Dataset
###############################################################################
# Reference dataset using mesh & rendering
###############################################################################
class DatasetMesh(Dataset):
def __init__(self, ref_mesh, glctx, cam_radius, FLAGS, validate=False, mesh_center=None):
# Init
self.glctx = glctx
self.cam_radius = cam_radius
self.FLAGS = FLAGS
self.validate = validate
self.fovy = np.deg2rad(45)
self.aspect = FLAGS.train_res[1] / FLAGS.train_res[0]
self.random_lgt = FLAGS.random_lgt
self.camera_lgt = False
self.mesh_center = mesh_center
if self.FLAGS.local_rank == 0:
print("DatasetMesh: ref mesh has %d triangles and %d vertices" % (ref_mesh.t_pos_idx.shape[0], ref_mesh.v_pos.shape[0]))
# Sanity test training texture resolution
ref_texture_res = np.maximum(ref_mesh.material['kd'].getRes(), ref_mesh.material['ks'].getRes())
if 'normal' in ref_mesh.material:
ref_texture_res = np.maximum(ref_texture_res, ref_mesh.material['normal'].getRes())
if self.FLAGS.local_rank == 0 and FLAGS.texture_res[0] < ref_texture_res[0] or FLAGS.texture_res[1] < ref_texture_res[1]:
print("---> WARNING: Picked a texture resolution lower than the reference mesh [%d, %d] < [%d, %d]" % (FLAGS.texture_res[0], FLAGS.texture_res[1], ref_texture_res[0], ref_texture_res[1]))
# Load environment map texture
self.envlight = light.load_env(FLAGS.envmap, scale=FLAGS.env_scale)
self.ref_mesh = mesh.compute_tangents(ref_mesh)
def _rotate_scene(self, itr):
proj_mtx = util.perspective(self.fovy, self.FLAGS.display_res[1] / self.FLAGS.display_res[0], self.FLAGS.cam_near_far[0], self.FLAGS.cam_near_far[1])
# Smooth rotation for display.
ang = (itr / 50) * np.pi * 2
mv = util.translate(0, 0, -self.cam_radius) @ (util.rotate_x(-0.4) @ util.rotate_y(ang))
mvp = proj_mtx @ mv
campos = torch.linalg.inv(mv)[:3, 3]
return mv[None, ...].cuda(), mvp[None, ...].cuda(), campos[None, ...].cuda(), self.FLAGS.display_res, self.FLAGS.spp
def _random_scene(self):
# ==============================================================================================
# Setup projection matrix
# ==============================================================================================
iter_res = self.FLAGS.train_res
proj_mtx = util.perspective(self.fovy, iter_res[1] / iter_res[0], self.FLAGS.cam_near_far[0], self.FLAGS.cam_near_far[1])
# ==============================================================================================
# Random camera & light position
# ==============================================================================================
# Random rotation/translation matrix for optimization.
if self.mesh_center is not None:
mv = (
util.translate(-self.mesh_center[0], -self.mesh_center
gitextract_zb4dyoqx/ ├── .gitignore ├── GMeshDiffusion/ │ ├── diffusion_configs/ │ │ ├── config_lower_occgrid_normalized.py │ │ └── config_upper_occgrid_normalized.py │ ├── lib/ │ │ ├── dataset/ │ │ │ ├── gshell_dataset.py │ │ │ └── gshell_dataset_aug.py │ │ └── diffusion/ │ │ ├── evaler.py │ │ ├── likelihood.py │ │ ├── losses.py │ │ ├── models/ │ │ │ ├── __init__.py │ │ │ ├── ema.py │ │ │ ├── functional.py │ │ │ ├── layers.py │ │ │ ├── normalization.py │ │ │ ├── unet3d_occgrid.py │ │ │ └── utils.py │ │ ├── sampling.py │ │ ├── sde_lib.py │ │ ├── trainer.py │ │ ├── trainer_ddp.py │ │ └── utils.py │ ├── main_diffusion.py │ ├── main_diffusion_ddp.py │ ├── metadata/ │ │ ├── get_splits_lower.py │ │ ├── get_splits_upper.py │ │ ├── save_tet_info.py │ │ └── tet_to_cubic_grid_dataset.py │ └── scripts/ │ ├── run_eval_lower_occgrid_normalized.sh │ ├── run_eval_upper_occgrid_normalized.sh │ ├── run_lower_occgrid_normalized_ddp.sh │ └── run_upper_occgrid_normalized_ddp.sh ├── README.md ├── configs/ │ ├── deepfashion_mc.json │ ├── deepfashion_mc_256.json │ ├── deepfashion_mc_512.json │ ├── deepfashion_mc_80.json │ ├── nerf_chair.json │ ├── polycam_mc.json │ ├── polycam_mc_128.json │ └── polycam_mc_16samples.json ├── data/ │ └── tets/ │ └── generate_tets.py ├── dataset/ │ ├── __init__.py │ ├── dataset.py │ ├── dataset_deepfashion.py │ ├── dataset_deepfashion_testset.py │ ├── dataset_llff.py │ ├── dataset_mesh.py │ ├── dataset_nerf.py │ └── dataset_nerf_colmap.py ├── denoiser/ │ └── denoiser.py ├── eval_gmeshdiffusion_generated_samples.py ├── geometry/ │ ├── embedding.py │ ├── flexicubes_table.py │ ├── gshell_flexicubes.py │ ├── gshell_flexicubes_geometry.py │ ├── gshell_tets.py │ ├── gshell_tets_geometry.py │ └── mlp.py ├── render/ │ ├── light.py │ ├── material.py │ ├── mesh.py │ ├── mlptexture.py │ ├── obj.py │ ├── optixutils/ │ │ ├── __init__.py │ │ ├── c_src/ │ │ │ ├── accessor.h │ │ │ ├── bsdf.h │ │ │ ├── common.h │ │ │ ├── denoising.cu │ │ │ ├── denoising.h │ │ │ ├── 
envsampling/ │ │ │ │ ├── kernel.cu │ │ │ │ └── params.h │ │ │ ├── math_utils.h │ │ │ ├── optix_wrapper.cpp │ │ │ ├── optix_wrapper.h │ │ │ └── torch_bindings.cpp │ │ ├── include/ │ │ │ ├── internal/ │ │ │ │ ├── optix_7_device_impl.h │ │ │ │ ├── optix_7_device_impl_exception.h │ │ │ │ └── optix_7_device_impl_transformations.h │ │ │ ├── optix.h │ │ │ ├── optix_7_device.h │ │ │ ├── optix_7_host.h │ │ │ ├── optix_7_types.h │ │ │ ├── optix_denoiser_tiling.h │ │ │ ├── optix_device.h │ │ │ ├── optix_function_table.h │ │ │ ├── optix_function_table_definition.h │ │ │ ├── optix_host.h │ │ │ ├── optix_stack_size.h │ │ │ ├── optix_stubs.h │ │ │ └── optix_types.h │ │ ├── ops.py │ │ └── tests/ │ │ └── filter_test.py │ ├── regularizer.py │ ├── render.py │ ├── renderutils/ │ │ ├── __init__.py │ │ ├── bsdf.py │ │ ├── c_src/ │ │ │ ├── bsdf.cu │ │ │ ├── bsdf.h │ │ │ ├── common.cpp │ │ │ ├── common.h │ │ │ ├── cubemap.cu │ │ │ ├── cubemap.h │ │ │ ├── loss.cu │ │ │ ├── loss.h │ │ │ ├── mesh.cu │ │ │ ├── mesh.h │ │ │ ├── normal.cu │ │ │ ├── normal.h │ │ │ ├── tensor.h │ │ │ ├── torch_bindings.cpp │ │ │ ├── vec3f.h │ │ │ └── vec4f.h │ │ ├── loss.py │ │ ├── ops.py │ │ └── tests/ │ │ ├── test_bsdf.py │ │ ├── test_loss.py │ │ ├── test_mesh.py │ │ └── test_perf.py │ ├── texture.py │ └── util.py ├── train_gflexicubes_deepfashion.py ├── train_gflexicubes_polycam.py ├── train_gshelltet_deepfashion.py ├── train_gshelltet_polycam.py └── train_gshelltet_synthetic.py
SYMBOL INDEX (975 symbols across 88 files)
FILE: GMeshDiffusion/diffusion_configs/config_lower_occgrid_normalized.py
function get_config (line 6) | def get_config():
FILE: GMeshDiffusion/diffusion_configs/config_upper_occgrid_normalized.py
function get_config (line 6) | def get_config():
FILE: GMeshDiffusion/lib/dataset/gshell_dataset.py
class GShellDataset (line 5) | class GShellDataset(Dataset):
method __init__ (line 6) | def __init__(self, filepath_metafile, extension='pt'):
method __len__ (line 14) | def __len__(self):
method __getitem__ (line 17) | def __getitem__(self, idx):
FILE: GMeshDiffusion/lib/dataset/gshell_dataset_aug.py
class GShellAugDataset (line 4) | class GShellAugDataset(Dataset):
method __init__ (line 5) | def __init__(self, FLAGS, extension='pt'):
method __len__ (line 17) | def __len__(self):
method __getitem__ (line 20) | def __getitem__(self, idx):
method collate (line 31) | def collate(data):
FILE: GMeshDiffusion/lib/diffusion/evaler.py
function uncond_gen (line 15) | def uncond_gen(
function slerp (line 78) | def slerp(z1, z2, alpha):
function uncond_gen_interp (line 88) | def uncond_gen_interp(
function cond_gen (line 183) | def cond_gen(
FILE: GMeshDiffusion/lib/diffusion/likelihood.py
function get_div_fn (line 26) | def get_div_fn(fn):
function get_likelihood_fn (line 40) | def get_likelihood_fn(sde, inverse_scaler, hutchinson_type='Rademacher',
FILE: GMeshDiffusion/lib/diffusion/losses.py
function get_optimizer (line 25) | def get_optimizer(config, params):
function optimization_manager (line 40) | def optimization_manager(config):
function get_ddpm_loss_fn (line 60) | def get_ddpm_loss_fn(vpsde, train, loss_type='l2', pred_type='noise', us...
function get_step_fn (line 194) | def get_step_fn(sde, train, optimize_fn=None, loss_type='l2', pred_type=...
FILE: GMeshDiffusion/lib/diffusion/models/ema.py
class ExponentialMovingAverage (line 10) | class ExponentialMovingAverage:
method __init__ (line 15) | def __init__(self, parameters, decay, use_num_updates=True):
method update (line 32) | def update(self, parameters):
method copy_to (line 54) | def copy_to(self, parameters):
method store (line 67) | def store(self, parameters):
method restore (line 77) | def restore(self, parameters):
method state_dict (line 92) | def state_dict(self):
method load_state_dict (line 96) | def load_state_dict(self, state_dict, device='cuda'):
FILE: GMeshDiffusion/lib/diffusion/models/functional.py
function has_cuda (line 39) | def has_cuda():
function has_half (line 43) | def has_half():
function has_bfloat (line 47) | def has_bfloat():
function has_gemm (line 51) | def has_gemm():
function enable_tf32 (line 55) | def enable_tf32():
function disable_tf32 (line 59) | def disable_tf32():
function enable_tiled_na (line 63) | def enable_tiled_na():
function disable_tiled_na (line 67) | def disable_tiled_na():
function enable_gemm_na (line 71) | def enable_gemm_na():
function disable_gemm_na (line 75) | def disable_gemm_na():
class NeighborhoodAttention1DQKAutogradFunction (line 79) | class NeighborhoodAttention1DQKAutogradFunction(Function):
method forward (line 82) | def forward(ctx, query, key, rpb, kernel_size, dilation):
method backward (line 94) | def backward(ctx, grad_out):
class NeighborhoodAttention1DAVAutogradFunction (line 107) | class NeighborhoodAttention1DAVAutogradFunction(Function):
method forward (line 110) | def forward(ctx, attn, value, kernel_size, dilation):
method backward (line 121) | def backward(ctx, grad_out):
class NeighborhoodAttention2DQKAutogradFunction (line 133) | class NeighborhoodAttention2DQKAutogradFunction(Function):
method forward (line 136) | def forward(ctx, query, key, rpb, kernel_size, dilation):
method backward (line 150) | def backward(ctx, grad_out):
class NeighborhoodAttention2DAVAutogradFunction (line 163) | class NeighborhoodAttention2DAVAutogradFunction(Function):
method forward (line 166) | def forward(ctx, attn, value, kernel_size, dilation):
method backward (line 177) | def backward(ctx, grad_out):
class NeighborhoodAttention3DQKAutogradFunction (line 189) | class NeighborhoodAttention3DQKAutogradFunction(Function):
method forward (line 192) | def forward(ctx, query, key, rpb, kernel_size_d, kernel_size, dilation...
method backward (line 208) | def backward(ctx, grad_out):
class NeighborhoodAttention3DAVAutogradFunction (line 223) | class NeighborhoodAttention3DAVAutogradFunction(Function):
method forward (line 226) | def forward(ctx, attn, value, kernel_size_d, kernel_size, dilation_d, ...
method backward (line 241) | def backward(ctx, grad_out):
function natten1dqkrpb (line 255) | def natten1dqkrpb(query, key, rpb, kernel_size, dilation):
function natten1dqk (line 261) | def natten1dqk(query, key, kernel_size, dilation):
function natten1dav (line 267) | def natten1dav(attn, value, kernel_size, dilation):
function natten2dqkrpb (line 273) | def natten2dqkrpb(query, key, rpb, kernel_size, dilation):
function natten2dqk (line 279) | def natten2dqk(query, key, kernel_size, dilation):
function natten2dav (line 285) | def natten2dav(attn, value, kernel_size, dilation):
function natten3dqkrpb (line 291) | def natten3dqkrpb(query, key, rpb, kernel_size_d, kernel_size, dilation_...
function natten3dqk (line 297) | def natten3dqk(query, key, kernel_size_d, kernel_size, dilation_d, dilat...
function natten3dav (line 303) | def natten3dav(attn, value, kernel_size_d, kernel_size, dilation_d, dila...
FILE: GMeshDiffusion/lib/diffusion/models/layers.py
class GroupNormFloat32 (line 28) | class GroupNormFloat32(nn.GroupNorm):
method forward (line 29) | def forward(self, input):
function get_act_fn (line 35) | def get_act_fn(act_name):
function variance_scaling (line 49) | def variance_scaling(scale, mode, distribution,
function default_init (line 83) | def default_init(scale=1.):
class Dense (line 89) | class Dense(nn.Module):
method __init__ (line 91) | def __init__(self):
function conv1x1 (line 95) | def conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1., p...
function conv3x3 (line 102) | def conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init...
function conv5x5 (line 110) | def conv5x5(in_planes, out_planes, stride=2, bias=True, dilation=1, init...
function conv3x3_transposed (line 119) | def conv3x3_transposed(in_planes, out_planes, stride=2, bias=True, dilat...
function conv5x5_transposed (line 127) | def conv5x5_transposed(in_planes, out_planes, stride=2, bias=True, dilat...
function get_timestep_embedding (line 141) | def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
class AttnBlock (line 158) | class AttnBlock(nn.Module):
method __init__ (line 160) | def __init__(self, channels, num_groups=32):
method forward (line 168) | def forward(self, x):
class Upsample (line 192) | class Upsample(nn.Module):
method __init__ (line 193) | def __init__(self, channels, with_conv=False):
method forward (line 199) | def forward(self, x, temb=None):
class Downsample (line 207) | class Downsample(nn.Module):
method __init__ (line 208) | def __init__(self, channels, with_conv=False):
method forward (line 214) | def forward(self, x, temb=None):
class ResBlock (line 227) | class ResBlock(nn.Module):
method __init__ (line 229) | def __init__(self, act_fn, in_ch, out_ch=None, temb_dim=None, conv_sho...
method forward (line 253) | def forward(self, x, temb=None):
class AttnResBlock (line 272) | class AttnResBlock(ResBlock):
method __init__ (line 274) | def __init__(self, act_fn, in_ch, out_ch, temb_dim=None, conv_shortcut...
method forward (line 278) | def forward(self, x, temb=None):
FILE: GMeshDiffusion/lib/diffusion/models/normalization.py
function get_normalization (line 22) | def get_normalization(config, conditional=False):
class ConditionalBatchNorm3d (line 43) | class ConditionalBatchNorm3d(nn.Module):
method __init__ (line 44) | def __init__(self, num_features, num_classes, bias=True):
method forward (line 57) | def forward(self, x, y):
class ConditionalInstanceNorm3d (line 68) | class ConditionalInstanceNorm3d(nn.Module):
method __init__ (line 69) | def __init__(self, num_features, num_classes, bias=True):
method forward (line 82) | def forward(self, x, y):
class ConditionalVarianceNorm3d (line 93) | class ConditionalVarianceNorm3d(nn.Module):
method __init__ (line 94) | def __init__(self, num_features, num_classes, bias=False):
method forward (line 101) | def forward(self, x, y):
class VarianceNorm3d (line 110) | class VarianceNorm3d(nn.Module):
method __init__ (line 111) | def __init__(self, num_features, bias=False):
method forward (line 118) | def forward(self, x):
class ConditionalNoneNorm3d (line 126) | class ConditionalNoneNorm3d(nn.Module):
method __init__ (line 127) | def __init__(self, num_features, num_classes, bias=True):
method forward (line 139) | def forward(self, x, y):
class NoneNorm3d (line 149) | class NoneNorm3d(nn.Module):
method __init__ (line 150) | def __init__(self, num_features, bias=True):
method forward (line 153) | def forward(self, x):
class InstanceNorm3dPlus (line 157) | class InstanceNorm3dPlus(nn.Module):
method __init__ (line 158) | def __init__(self, num_features, bias=True):
method forward (line 170) | def forward(self, x):
class ConditionalInstanceNorm3dPlus (line 186) | class ConditionalInstanceNorm3dPlus(nn.Module):
method __init__ (line 187) | def __init__(self, num_features, num_classes, bias=True):
method forward (line 200) | def forward(self, x, y):
FILE: GMeshDiffusion/lib/diffusion/models/unet3d_occgrid.py
function str_to_class (line 33) | def str_to_class(classname):
class UNet3D (line 38) | class UNet3D(nn.Module):
method __init__ (line 39) | def __init__(self, config):
method sequentially_call_module (line 142) | def sequentially_call_module(self, idx, x, temb=None):
method forward (line 145) | def forward(self, x, labels):
FILE: GMeshDiffusion/lib/diffusion/models/utils.py
function register_model (line 27) | def register_model(cls=None, *, name=None):
function get_model (line 46) | def get_model(name):
function get_sigmas (line 50) | def get_sigmas(config):
function get_ddpm_params (line 63) | def get_ddpm_params(config):
function create_model (line 88) | def create_model(config, use_parallel=True, ddp=False, rank=None):
function get_model_fn (line 111) | def get_model_fn(model, train=False):
function get_reg_fn (line 142) | def get_reg_fn(model, train=False):
function get_score_fn (line 179) | def get_score_fn(sde, model, train=False, continuous=False, std_scale=Tr...
function to_flattened_numpy (line 236) | def to_flattened_numpy(x):
function from_flattened_numpy (line 241) | def from_flattened_numpy(x, shape):
FILE: GMeshDiffusion/lib/diffusion/sampling.py
function register_predictor (line 37) | def register_predictor(cls=None, *, name=None):
function register_corrector (line 56) | def register_corrector(cls=None, *, name=None):
function get_predictor (line 75) | def get_predictor(name):
function get_corrector (line 79) | def get_corrector(name):
function get_sampling_fn (line 83) | def get_sampling_fn(config, sde, shape, inverse_scaler, eps, grid_mask=N...
class Predictor (line 139) | class Predictor(abc.ABC):
method __init__ (line 142) | def __init__(self, sde, score_fn, probability_flow=False):
method update_fn (line 150) | def update_fn(self, x, t):
class Corrector (line 164) | class Corrector(abc.ABC):
method __init__ (line 167) | def __init__(self, sde, score_fn, snr, n_steps):
method update_fn (line 175) | def update_fn(self, x, t):
class EulerMaruyamaPredictor (line 190) | class EulerMaruyamaPredictor(Predictor):
method __init__ (line 191) | def __init__(self, sde, score_fn, probability_flow=False):
method update_fn (line 194) | def update_fn(self, x, t):
class ReverseDiffusionPredictor (line 204) | class ReverseDiffusionPredictor(Predictor):
method __init__ (line 205) | def __init__(self, sde, score_fn, probability_flow=False):
method update_fn (line 208) | def update_fn(self, x, t):
class AncestralSamplingPredictor (line 217) | class AncestralSamplingPredictor(Predictor):
method __init__ (line 220) | def __init__(self, sde, score_fn, probability_flow=False):
method vpsde_update_fn (line 226) | def vpsde_update_fn(self, x, t):
method update_fn (line 236) | def update_fn(self, x, t):
class NonePredictor (line 244) | class NonePredictor(Predictor):
method __init__ (line 247) | def __init__(self, sde, score_fn, probability_flow=False):
method update_fn (line 250) | def update_fn(self, x, t):
class DDIMPredictor (line 254) | class DDIMPredictor(Predictor):
method __init__ (line 255) | def __init__(self, sde, score_fn, probability_flow=False):
method update_fn (line 259) | def update_fn(self, x, t, tprev=None):
class LangevinCorrector (line 264) | class LangevinCorrector(Corrector):
method __init__ (line 265) | def __init__(self, sde, score_fn, snr, n_steps):
method update_fn (line 270) | def update_fn(self, x, t):
class AnnealedLangevinDynamics (line 294) | class AnnealedLangevinDynamics(Corrector):
method __init__ (line 300) | def __init__(self, sde, score_fn, snr, n_steps):
method update_fn (line 305) | def update_fn(self, x, t):
class NoneCorrector (line 329) | class NoneCorrector(Corrector):
method __init__ (line 332) | def __init__(self, sde, score_fn, snr, n_steps):
method update_fn (line 335) | def update_fn(self, x, t):
function shared_predictor_update_fn (line 339) | def shared_predictor_update_fn(x, t, sde, model, predictor, probability_...
function shared_corrector_update_fn (line 350) | def shared_corrector_update_fn(x, t, sde, model, corrector, continuous, ...
function get_pc_sampler (line 361) | def get_pc_sampler(sde, shape, predictor, corrector, inverse_scaler, snr,
function ddim_predictor_update_fn (line 508) | def ddim_predictor_update_fn(x, t, tprev, sde, model, predictor, probabi...
function get_ddim_sampler (line 519) | def get_ddim_sampler(sde, shape, predictor, inverse_scaler, n_steps=1,
FILE: GMeshDiffusion/lib/diffusion/sde_lib.py
class SDE (line 9) | class SDE(abc.ABC):
method __init__ (line 12) | def __init__(self, N):
method T (line 23) | def T(self):
method sde (line 28) | def sde(self, x, t):
method marginal_prob (line 32) | def marginal_prob(self, x, t):
method prior_sampling (line 37) | def prior_sampling(self, shape):
method prior_logp (line 42) | def prior_logp(self, z):
method discretize (line 54) | def discretize(self, x, t):
method reverse (line 73) | def reverse(self, score_fn, probability_flow=False):
class VPSDE (line 209) | class VPSDE(SDE):
method __init__ (line 210) | def __init__(self, beta_min=0.1, beta_max=20, N=1000):
method T (line 234) | def T(self):
method sde (line 237) | def sde(self, x, t):
method marginal_prob (line 243) | def marginal_prob(self, x, t):
method prior_sampling (line 249) | def prior_sampling(self, shape):
method prior_logp (line 252) | def prior_logp(self, z):
method discretize (line 258) | def discretize(self, x, t):
FILE: GMeshDiffusion/lib/diffusion/trainer.py
function train (line 20) | def train(config):
FILE: GMeshDiffusion/lib/diffusion/trainer_ddp.py
function train (line 22) | def train(config):
FILE: GMeshDiffusion/lib/diffusion/utils.py
function restore_checkpoint (line 6) | def restore_checkpoint(ckpt_dir, state, device, strict=False, rank=None):
function save_checkpoint (line 38) | def save_checkpoint(ckpt_dir, state):
FILE: GMeshDiffusion/main_diffusion.py
function main (line 19) | def main(argv):
FILE: GMeshDiffusion/main_diffusion_ddp.py
function main (line 20) | def main(argv):
FILE: GMeshDiffusion/metadata/save_tet_info.py
function tet_to_grids (line 14) | def tet_to_grids(vertices, values_list, grid_size):
FILE: GMeshDiffusion/metadata/tet_to_cubic_grid_dataset.py
function tet_to_grids (line 7) | def tet_to_grids(vertices, values_list, grid_size):
FILE: data/tets/generate_tets.py
function generate_tetrahedron_grid_file (line 21) | def generate_tetrahedron_grid_file(res=32, root='..'):
function convert_from_quartet_to_npz (line 31) | def convert_from_quartet_to_npz(quartetfile = 'cube_32_tet.tet', npzfile...
FILE: dataset/dataset.py
class Dataset (line 12) | class Dataset(torch.utils.data.Dataset):
method __init__ (line 14) | def __init__(self):
method __len__ (line 17) | def __len__(self):
method __getitem__ (line 20) | def __getitem__(self):
method collate (line 23) | def collate(self, batch):
FILE: dataset/dataset_deepfashion.py
function load_K_Rt_from_P (line 24) | def load_K_Rt_from_P(filename, P=None):
function _load_img (line 48) | def _load_img(path):
class DatasetDeepFashion (line 59) | class DatasetDeepFashion(Dataset):
method __init__ (line 60) | def __init__(self, base_dir, FLAGS, examples=None):
method _parse_frame (line 101) | def _parse_frame(self, idx):
method __len__ (line 121) | def __len__(self):
method __getitem__ (line 124) | def __getitem__(self, itr):
FILE: dataset/dataset_deepfashion_testset.py
function load_K_Rt_from_P (line 24) | def load_K_Rt_from_P(filename, P=None):
function _load_img (line 48) | def _load_img(path):
function _load_mask (line 58) | def _load_mask(path):
class DatasetDeepFashionTestset (line 67) | class DatasetDeepFashionTestset(Dataset):
method __init__ (line 68) | def __init__(self, base_dir, FLAGS, examples=None):
method _parse_frame (line 101) | def _parse_frame(self, idx):
method __len__ (line 119) | def __len__(self):
method __getitem__ (line 122) | def __getitem__(self, itr):
FILE: dataset/dataset_llff.py
function _load_mask (line 20) | def _load_mask(fn):
function _load_img (line 26) | def _load_img(fn):
class DatasetLLFF (line 39) | class DatasetLLFF(Dataset):
method __init__ (line 40) | def __init__(self, base_dir, FLAGS, examples=None):
method _parse_frame (line 75) | def _parse_frame(self, idx):
method __len__ (line 93) | def __len__(self):
method __getitem__ (line 96) | def __getitem__(self, itr):
FILE: dataset/dataset_mesh.py
class DatasetMesh (line 24) | class DatasetMesh(Dataset):
method __init__ (line 26) | def __init__(self, ref_mesh, glctx, cam_radius, FLAGS, validate=False,...
method _rotate_scene (line 54) | def _rotate_scene(self, itr):
method _random_scene (line 65) | def _random_scene(self):
method __len__ (line 89) | def __len__(self):
method __getitem__ (line 92) | def __getitem__(self, itr):
FILE: dataset/dataset_nerf.py
function _load_img (line 25) | def _load_img(path):
class DatasetNERF (line 36) | class DatasetNERF(Dataset):
method __init__ (line 37) | def __init__(self, cfg_path, FLAGS, examples=None):
method _parse_frame (line 59) | def _parse_frame(self, cfg, idx):
method __len__ (line 73) | def __len__(self):
method __getitem__ (line 76) | def __getitem__(self, itr):
FILE: dataset/dataset_nerf_colmap.py
function _load_img (line 25) | def _load_img(path):
class DatasetNERF (line 34) | class DatasetNERF(Dataset):
method __init__ (line 35) | def __init__(self, cfg_path, FLAGS, examples=None):
method _parse_frame (line 57) | def _parse_frame(self, cfg, idx):
method __len__ (line 73) | def __len__(self):
method __getitem__ (line 76) | def __getitem__(self, itr):
FILE: denoiser/denoiser.py
class BilateralDenoiser (line 21) | class BilateralDenoiser(torch.nn.Module):
method __init__ (line 22) | def __init__(self, influence=1.0):
method set_influence (line 26) | def set_influence(self, factor):
method forward (line 31) | def forward(self, input):
FILE: geometry/embedding.py
class Embedding (line 4) | class Embedding(nn.Module):
method __init__ (line 5) | def __init__(self, in_channels, N_freqs, logscale=True):
method forward (line 21) | def forward(self, x):
FILE: geometry/gshell_flexicubes.py
class GShellFlexiCubes (line 16) | class GShellFlexiCubes:
method __init__ (line 67) | def __init__(self, device="cuda", qef_reg_scale=1e-3, weight_scale=0.99):
method construct_voxel_grid (line 103) | def construct_voxel_grid(self, res):
method __call__ (line 136) | def __call__(self, x_nx3, s_n, nu_n, cube_fx8, res, beta_fx12=None, al...
method _compute_reg_loss (line 232) | def _compute_reg_loss(self, vd, ue, edge_group_to_vd, vd_num_edges):
method _normalize_weights (line 242) | def _normalize_weights(self, beta_fx12, alpha_fx8, gamma_f, surf_cubes):
method _get_case_id (line 266) | def _get_case_id(self, occ_fx8, surf_cubes, res):
method _identify_surf_edges (line 309) | def _identify_surf_edges(self, s_n, cube_fx8, surf_cubes):
method _identify_surf_cubes (line 334) | def _identify_surf_cubes(self, s_n, cube_fx8):
method _linear_interp (line 345) | def _linear_interp(self, edges_weight, edges_x):
method _linear_interp_nonan (line 357) | def _linear_interp_nonan(self, edges_weight, edges_x):
method _solve_vd_QEF (line 373) | def _solve_vd_QEF(self, p_bxnx3, norm_bxnx3, c_bx3=None):
method _compute_vd (line 387) | def _compute_vd(self, x_nx3, surf_cubes_fx8, surf_edges, s_n, nu_n, ca...
method _triangulate (line 487) | def _triangulate(self, s_n, surf_edges, vd, nu_d, nu_d_stopvgd, vd_gam...
method _triangulate_msdf (line 554) | def _triangulate_msdf(self, vertices, faces, nu_n, nu_n_stopvgd):
method _tetrahedralize (line 593) | def _tetrahedralize(
FILE: geometry/gshell_flexicubes_geometry.py
function compute_sdf_reg_loss (line 33) | def compute_sdf_reg_loss(sdf, all_edges):
class GShellFlexiCubesGeometry (line 45) | class GShellFlexiCubesGeometry(torch.nn.Module):
method __init__ (line 46) | def __init__(self, grid_res, scale, FLAGS):
method generate_edges (line 111) | def generate_edges(self):
method getAABB (line 120) | def getAABB(self):
method clamp_deform (line 124) | def clamp_deform(self):
method map_uv2 (line 130) | def map_uv2(self, faces):
method map_uv (line 136) | def map_uv(self, face_gidx, max_idx):
method getMesh (line 166) | def getMesh(self, material, _training=False):
method render (line 210) | def render(self, glctx, target, lgt, opt_material, bsdf=None, denoiser...
method tick (line 237) | def tick(self, glctx, target, lgt, opt_material, loss_fn, iteration, d...
FILE: geometry/gshell_tets.py
function auto_normals (line 9) | def auto_normals(v_pos, t_pos_idx):
function compute_tangents (line 40) | def compute_tangents(v_pos, v_tex, v_nrm, t_pos_idx, t_tex_idx, t_nrm_idx):
class GShell_Tets (line 80) | class GShell_Tets:
method __init__ (line 81) | def __init__(self):
method sort_edges (line 200) | def sort_edges(self, edges_ex2):
method map_uv (line 210) | def map_uv(self, face_gidx, max_idx):
method __call__ (line 245) | def __call__(self, pos_nx3, sdf_n, msdf_n, tet_fx4, output_watertight_...
method marching_from_auggrid (line 447) | def marching_from_auggrid(self, pos_nx3, sdf_n, tet_fx4,
FILE: geometry/gshell_tets_geometry.py
function compute_sdf_reg_loss (line 33) | def compute_sdf_reg_loss(sdf, all_edges):
class GShellTetsGeometry (line 45) | class GShellTetsGeometry(torch.nn.Module):
method __init__ (line 46) | def __init__(self, grid_res, scale, FLAGS, offset=None, tet_init_file=...
method generate_edges (line 149) | def generate_edges(self):
method getAABB (line 158) | def getAABB(self):
method clamp_deform (line 162) | def clamp_deform(self):
method getMesh_from_augmented_grid_withocc (line 167) | def getMesh_from_augmented_grid_withocc(self, material, sdf_sign, sdf_...
method getMesh (line 191) | def getMesh(self, material):
method render (line 230) | def render(self, glctx, target, lgt, opt_material, bsdf=None, denoiser...
method tick (line 257) | def tick(self, glctx, target, lgt, opt_material, loss_fn, iteration, d...
FILE: geometry/mlp.py
class MLP (line 7) | class MLP(nn.Module):
method __init__ (line 8) | def __init__(self, n_freq=6, d_hidden=128, d_out=1, n_hidden=3, skip_i...
method forward (line 32) | def forward(self, x):
FILE: render/light.py
class EnvironmentLight (line 21) | class EnvironmentLight:
method __init__ (line 27) | def __init__(self, base):
method xfm (line 34) | def xfm(self, mtx):
method parameters (line 37) | def parameters(self):
method clone (line 40) | def clone(self):
method clamp_ (line 43) | def clamp_(self, min=None, max=None):
method update_pdf (line 46) | def update_pdf(self):
method generate_image (line 62) | def generate_image(self, res):
function _load_env_hdr (line 71) | def _load_env_hdr(fn, scale=1.0, res=None, trainable=False):
function load_env (line 86) | def load_env(fn, scale=1.0, res=None, trainable=False):
function save_env_map (line 93) | def save_env_map(fn, light):
function create_trainable_env_rnd (line 102) | def create_trainable_env_rnd(base_res, scale=0.5, bias=0.25):
FILE: render/material.py
function load_mtl (line 21) | def load_mtl(fn, clear_ks=True):
function save_mtl (line 72) | def save_mtl(fn, material):
function create_trainable (line 99) | def create_trainable(material):
function get_parameters (line 106) | def get_parameters(material):
function _upscale_replicate (line 117) | def _upscale_replicate(x, full_res):
function merge_materials (line 122) | def merge_materials(materials, texcoords, tfaces, mfaces):
FILE: render/mesh.py
class Mesh (line 20) | class Mesh:
method __init__ (line 21) | def __init__(self, v_pos=None, t_pos_idx=None, v_nrm=None, t_nrm_idx=N...
method copy_none (line 35) | def copy_none(self, other):
method clone (line 55) | def clone(self):
function load_mesh (line 79) | def load_mesh(filename, mtl_override=None, mtl_default=None, mtl_type_ov...
function aabb (line 88) | def aabb(mesh):
function aabb_clean (line 94) | def aabb_clean(mesh):
function compute_edges (line 101) | def compute_edges(attr_idx, return_inverse=False):
function compute_edge_to_face_mapping (line 123) | def compute_edge_to_face_mapping(attr_idx, return_inverse=False):
function unit_size (line 158) | def unit_size(mesh):
function center_by_reference (line 170) | def center_by_reference(base_mesh, ref_aabb, scale):
function center_by_reference_noscale (line 177) | def center_by_reference_noscale(base_mesh, ref_aabb, scale=None):
function center_with_global_aabb (line 183) | def center_with_global_aabb(base_mesh, ref_aabb, scale):
function center_with_global_aabb_perdim (line 191) | def center_with_global_aabb_perdim(base_mesh, ref_aabb, scale):
function scale_with_global_aabb (line 198) | def scale_with_global_aabb(base_mesh, ref_aabb, scale):
function scale_with_global_aabb_perdim (line 204) | def scale_with_global_aabb_perdim(base_mesh, ref_aabb, scale):
function auto_normals (line 212) | def auto_normals(imesh):
function compute_tangents (line 243) | def compute_tangents(imesh, v_tng=None):
FILE: render/mlptexture.py
class _MLP (line 18) | class _MLP(torch.nn.Module):
method __init__ (line 19) | def __init__(self, cfg, loss_scale=1.0):
method forward (line 33) | def forward(self, x):
method _init_weights (line 37) | def _init_weights(m):
class MLPTexture3D (line 47) | class MLPTexture3D(torch.nn.Module):
method __init__ (line 48) | def __init__(self, AABB, channels = 3, internal_dims = 32, hidden = 2,...
method sample (line 87) | def sample(self, texc):
method clamp_ (line 101) | def clamp_(self):
method cleanup (line 104) | def cleanup(self):
FILE: render/obj.py
function _find_mat (line 21) | def _find_mat(materials, name):
function load_obj (line 31) | def load_obj(filename, clear_ks=True, mtl_override=None, mtl_default=Non...
function write_obj (line 143) | def write_obj(folder, mesh, save_material=True):
FILE: render/optixutils/c_src/accessor.h
type T (line 122) | typedef T* PtrType;
type T (line 128) | typedef T* __restrict__ PtrType;
function C10_HOST_DEVICE (line 150) | C10_HOST_DEVICE index_t size(index_t i) const {
function C10_HOST_DEVICE (line 153) | C10_HOST_DEVICE PtrType data() {
function PtrTraits (line 184) | PtrTraits, index_t> operator[](index_t i) const {
function C10_HOST_DEVICE (line 199) | C10_HOST_DEVICE T & operator[](index_t i) {
function C10_HOST_DEVICE (line 203) | C10_HOST_DEVICE const T & operator[](index_t i) const {
function C10_HOST (line 222) | C10_HOST GenericPackedTensorAccessorBase() {}
function C10_HOST (line 224) | C10_HOST GenericPackedTensorAccessorBase(
function C10_HOST_DEVICE (line 246) | C10_HOST_DEVICE index_t stride(index_t i) const {
function C10_HOST_DEVICE (line 249) | C10_HOST_DEVICE index_t size(index_t i) const {
function C10_HOST_DEVICE (line 252) | C10_HOST_DEVICE PtrType data() {
function C10_HOST (line 270) | C10_HOST GenericPackedTensorAccessor() : GenericPackedTensorAccessorBase...
function C10_DEVICE (line 323) | C10_DEVICE T & operator[](index_t i) {
function C10_DEVICE (line 326) | C10_DEVICE const T& operator[](index_t i) const {
FILE: render/optixutils/c_src/bsdf.h
function __device__ (line 21) | __device__ inline float fwdLambert(const float3 nrm, const float3 wi)
function __device__ (line 26) | __device__ inline void bwdLambert(const float3 nrm, const float3 wi, flo...
function __device__ (line 35) | __device__ inline float fwdFresnelSchlick(const float f0, const float f9...
function __device__ (line 42) | __device__ inline void bwdFresnelSchlick(const float f0, const float f90...
function __device__ (line 54) | __device__ inline float3 fwdFresnelSchlick(const float3 f0, const float3...
function __device__ (line 61) | __device__ inline void bwdFresnelSchlick(const float3 f0, const float3 f...
function __device__ (line 76) | __device__ inline float fwdNdfGGX(const float alphaSqr, const float cosT...
function __device__ (line 83) | __device__ inline void bwdNdfGGX(const float alphaSqr, const float cosTh...
function __device__ (line 98) | __device__ inline float fwdLambdaGGX(const float alphaSqr, const float c...
function __device__ (line 107) | __device__ inline void bwdLambdaGGX(const float alphaSqr, const float co...
function __device__ (line 122) | __device__ inline float fwdMaskingSmithGGXCorrelated(const float alphaSq...
function __device__ (line 129) | __device__ inline void bwdMaskingSmithGGXCorrelated(const float alphaSqr...
function __device__ (line 144) | __device__ float3 fwdPbrSpecular(const float3 col, const float3 nrm, con...
function __device__ (line 164) | __device__ void bwdPbrSpecular(
function __device__ (line 222) | __device__ void fwdPbrBSDF(const float3 kd, const float3 arm, const floa...
function __device__ (line 238) | __device__ void bwdPbrBSDF(
FILE: render/optixutils/c_src/common.h
function __device__ (line 14) | __device__ inline float3 fetch3(const T &tensor, U idx, Args... args) {
function float3 (line 17) | inline float3 fetch3(const T &tensor) {
function __device__ (line 22) | __device__ inline float2 fetch2(const T &tensor, U idx, Args... args) {
function float2 (line 25) | inline float2 fetch2(const T &tensor) {
FILE: render/optixutils/c_src/denoising.h
type BilateralDenoiserParams (line 12) | struct BilateralDenoiserParams
FILE: render/optixutils/c_src/envsampling/params.h
type EnvSamplingParams (line 11) | struct EnvSamplingParams
FILE: render/optixutils/c_src/math_utils.h
function T (line 13) | __inline__ T clamp(T x, T _min, T _max) { return min(_max, max(_min, x)); }
function __device__ (line 14) | static __device__ inline float3 make_float3(float a) { return make_float...
function __device__ (line 80) | static __device__ inline float sum(float3 a)
function __device__ (line 85) | static __device__ inline float dot(float3 a, float3 b) { return a.x * b....
function __device__ (line 87) | static __device__ inline void bwd_dot(float3 a, float3 b, float3& d_a, f...
function __device__ (line 93) | static __device__ inline float luminance(const float3 rgb)
function __device__ (line 98) | static __device__ inline float3 cross(float3 a, float3 b)
function __device__ (line 107) | static __device__ inline void bwd_cross(float3 a, float3 b, float3 &d_a,...
function __device__ (line 118) | static __device__ inline float3 reflect(float3 x, float3 n)
function __device__ (line 123) | static __device__ inline void bwd_reflect(float3 x, float3 n, float3& d_...
function __device__ (line 134) | static __device__ inline float3 safe_normalize(float3 v)
function __device__ (line 140) | static __device__ inline void bwd_safe_normalize(const float3 v, float3&...
function __device__ (line 155) | static __device__ inline void branchlessONB(const float3 &n, float3 &b1,...
FILE: render/optixutils/c_src/optix_wrapper.cpp
function context_log_cb (line 43) | static void context_log_cb( unsigned int level, const char* tag, const c...
function readSourceFile (line 49) | static bool readSourceFile( std::string& str, const std::string& filename )
function getCuStringFromFile (line 63) | static void getCuStringFromFile( std::string& cu, const char* filename )
function getPtxFromCuString (line 73) | static void getPtxFromCuString( std::string& ptx, const char* include_di...
type SbtRecord (line 144) | struct SbtRecord
function createPipeline (line 149) | void createPipeline(const OptixDeviceContext context, const std::string&...
FILE: render/optixutils/c_src/optix_wrapper.h
type OptiXState (line 17) | struct OptiXState
function class (line 30) | class OptiXStateWrapper
FILE: render/optixutils/c_src/torch_bindings.cpp
function optix_build_bvh (line 37) | void optix_build_bvh(OptiXStateWrapper& stateWrapper,torch::Tensor grid_...
function packed_accessor32 (line 118) | PackedTensorAccessor32<T, N> packed_accessor32(torch::Tensor tensor)
function env_shade_fwd (line 123) | std::tuple<torch::Tensor, torch::Tensor> env_shade_fwd(
function env_shade_bwd (line 190) | std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, t...
function bilateral_denoiser_fwd (line 274) | torch::Tensor bilateral_denoiser_fwd(torch::Tensor col, torch::Tensor nr...
function bilateral_denoiser_bwd (line 297) | torch::Tensor bilateral_denoiser_bwd(torch::Tensor col, torch::Tensor nr...
function PYBIND11_MODULE (line 321) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: render/optixutils/include/internal/optix_7_device_impl.h
function optixTrace (line 39) | void optixTrace( OptixTraversableHandle handle,
function optixTrace (line 78) | void optixTrace( OptixTraversableHandle handle,
function optixTrace (line 118) | void optixTrace( OptixTraversableHandle handle,
function optixTrace (line 159) | void optixTrace( OptixTraversableHandle handle,
function optixTrace (line 201) | void optixTrace( OptixTraversableHandle handle,
function optixTrace (line 244) | void optixTrace( OptixTraversableHandle handle,
function optixTrace (line 288) | void optixTrace( OptixTraversableHandle handle,
function optixTrace (line 333) | void optixTrace( OptixTraversableHandle handle,
function optixTrace (line 379) | void optixTrace( OptixTraversableHandle handle,
function optixSetPayload_0 (line 427) | void optixSetPayload_0( unsigned int p )
function optixSetPayload_1 (line 432) | void optixSetPayload_1( unsigned int p )
function optixSetPayload_2 (line 437) | void optixSetPayload_2( unsigned int p )
function optixSetPayload_3 (line 442) | void optixSetPayload_3( unsigned int p )
function optixSetPayload_4 (line 447) | void optixSetPayload_4( unsigned int p )
function optixSetPayload_5 (line 452) | void optixSetPayload_5( unsigned int p )
function optixSetPayload_6 (line 457) | void optixSetPayload_6( unsigned int p )
function optixSetPayload_7 (line 462) | void optixSetPayload_7( unsigned int p )
function optixGetPayload_0 (line 468) | unsigned int optixGetPayload_0()
function optixGetPayload_1 (line 475) | unsigned int optixGetPayload_1()
function optixGetPayload_2 (line 482) | unsigned int optixGetPayload_2()
function optixGetPayload_3 (line 489) | unsigned int optixGetPayload_3()
function optixGetPayload_4 (line 496) | unsigned int optixGetPayload_4()
function optixGetPayload_5 (line 503) | unsigned int optixGetPayload_5()
function optixGetPayload_6 (line 510) | unsigned int optixGetPayload_6()
function optixGetPayload_7 (line 517) | unsigned int optixGetPayload_7()
function optixUndefinedValue (line 525) | unsigned int optixUndefinedValue()
function float3 (line 532) | float3 optixGetWorldRayOrigin()
function float3 (line 541) | float3 optixGetWorldRayDirection()
function float3 (line 550) | float3 optixGetObjectRayOrigin()
function float3 (line 559) | float3 optixGetObjectRayDirection()
function optixGetRayTmin (line 568) | float optixGetRayTmin()
function optixGetRayTmax (line 575) | float optixGetRayTmax()
function optixGetRayTime (line 582) | float optixGetRayTime()
function optixGetRayFlags (line 589) | unsigned int optixGetRayFlags()
function optixGetRayVisibilityMask (line 596) | unsigned int optixGetRayVisibilityMask()
function OptixTraversableHandle (line 603) | OptixTraversableHandle optixGetInstanceTraversableFromIAS( OptixTraversa...
function optixGetTriangleVertexData (line 613) | void optixGetTriangleVertexData( OptixTraversableHandle gas,
function optixGetLinearCurveVertexData (line 627) | void optixGetLinearCurveVertexData( OptixTraversableHandle gas,
function optixGetQuadraticBSplineVertexData (line 641) | void optixGetQuadraticBSplineVertexData( OptixTraversableHandle gas,
function optixGetCubicBSplineVertexData (line 656) | void optixGetCubicBSplineVertexData( OptixTraversableHandle gas,
function OptixTraversableHandle (line 673) | OptixTraversableHandle optixGetGASTraversableHandle()
function optixGetGASMotionTimeBegin (line 680) | float optixGetGASMotionTimeBegin( OptixTraversableHandle handle )
function optixGetGASMotionTimeEnd (line 687) | float optixGetGASMotionTimeEnd( OptixTraversableHandle handle )
function optixGetGASMotionStepCount (line 694) | unsigned int optixGetGASMotionStepCount( OptixTraversableHandle handle )
function optixGetWorldToObjectTransformMatrix (line 701) | void optixGetWorldToObjectTransformMatrix( float m[12] )
function optixGetObjectToWorldTransformMatrix (line 736) | void optixGetObjectToWorldTransformMatrix( float m[12] )
function float3 (line 771) | float3 optixTransformPointFromWorldToObjectSpace( float3 point )
function float3 (line 781) | float3 optixTransformVectorFromWorldToObjectSpace( float3 vec )
function float3 (line 791) | float3 optixTransformNormalFromWorldToObjectSpace( float3 normal )
function float3 (line 801) | float3 optixTransformPointFromObjectToWorldSpace( float3 point )
function float3 (line 811) | float3 optixTransformVectorFromObjectToWorldSpace( float3 vec )
function float3 (line 821) | float3 optixTransformNormalFromObjectToWorldSpace( float3 normal )
function optixGetTransformListSize (line 831) | unsigned int optixGetTransformListSize()
function OptixTraversableHandle (line 838) | OptixTraversableHandle optixGetTransformListHandle( unsigned int index )
function OptixTransformType (line 845) | OptixTransformType optixGetTransformTypeFromHandle( OptixTraversableHand...
function OptixStaticTransform (line 852) | const OptixStaticTransform* optixGetStaticTransformFromHandle( OptixTrav...
function OptixSRTMotionTransform (line 859) | const OptixSRTMotionTransform* optixGetSRTMotionTransformFromHandle( Opt...
function OptixMatrixMotionTransform (line 866) | const OptixMatrixMotionTransform* optixGetMatrixMotionTransformFromHandl...
function optixGetInstanceIdFromHandle (line 873) | unsigned int optixGetInstanceIdFromHandle( OptixTraversableHandle handle )
function OptixTraversableHandle (line 880) | OptixTraversableHandle optixGetInstanceChildFromHandle( OptixTraversable...
function float4 (line 887) | const float4* optixGetInstanceTransformFromHandle( OptixTraversableHandl...
function float4 (line 894) | const float4* optixGetInstanceInverseTransformFromHandle( OptixTraversab...
function optixReportIntersection (line 901) | bool optixReportIntersection( float hitT, unsigned int hitKind )
function optixReportIntersection (line 913) | bool optixReportIntersection( float hitT, unsigned int hitKind, unsigned...
function optixReportIntersection (line 925) | bool optixReportIntersection( float hitT, unsigned int hitKind, unsigned...
function optixReportIntersection (line 937) | bool optixReportIntersection( float hitT, unsigned int hitKind, unsigned...
function optixReportIntersection (line 949) | bool optixReportIntersection( float hitT,
function optixReportIntersection (line 966) | bool optixReportIntersection( float hitT,
function optixReportIntersection (line 984) | bool optixReportIntersection( float hitT,
function optixReportIntersection (line 1003) | bool optixReportIntersection( float hitT,
function optixReportIntersection (line 1023) | bool optixReportIntersection( float hitT,
function optixGetAttribute_0 (line 1049) | unsigned int optixGetAttribute_0()
function optixGetAttribute_1 (line 1054) | unsigned int optixGetAttribute_1()
function optixGetAttribute_2 (line 1059) | unsigned int optixGetAttribute_2()
function optixGetAttribute_3 (line 1064) | unsigned int optixGetAttribute_3()
function optixGetAttribute_4 (line 1069) | unsigned int optixGetAttribute_4()
function optixGetAttribute_5 (line 1074) | unsigned int optixGetAttribute_5()
function optixGetAttribute_6 (line 1079) | unsigned int optixGetAttribute_6()
function optixGetAttribute_7 (line 1084) | unsigned int optixGetAttribute_7()
function optixTerminateRay (line 1091) | void optixTerminateRay()
function optixIgnoreIntersection (line 1096) | void optixIgnoreIntersection()
function optixGetPrimitiveIndex (line 1101) | unsigned int optixGetPrimitiveIndex()
function optixGetSbtGASIndex (line 1108) | unsigned int optixGetSbtGASIndex()
function optixGetInstanceId (line 1115) | unsigned int optixGetInstanceId()
function optixGetInstanceIndex (line 1122) | unsigned int optixGetInstanceIndex()
function optixGetHitKind (line 1129) | unsigned int optixGetHitKind()
function OptixPrimitiveType (line 1136) | OptixPrimitiveType optixGetPrimitiveType(unsigned int hitKind)
function optixIsBackFaceHit (line 1143) | bool optixIsBackFaceHit( unsigned int hitKind )
function optixIsFrontFaceHit (line 1150) | bool optixIsFrontFaceHit( unsigned int hitKind )
function OptixPrimitiveType (line 1156) | OptixPrimitiveType optixGetPrimitiveType()
function optixIsBackFaceHit (line 1161) | bool optixIsBackFaceHit()
function optixIsFrontFaceHit (line 1166) | bool optixIsFrontFaceHit()
function optixIsTriangleHit (line 1171) | bool optixIsTriangleHit()
function optixIsTriangleFrontFaceHit (line 1176) | bool optixIsTriangleFrontFaceHit()
function optixIsTriangleBackFaceHit (line 1181) | bool optixIsTriangleBackFaceHit()
function optixGetCurveParameter (line 1186) | float optixGetCurveParameter()
function float2 (line 1191) | float2 optixGetTriangleBarycentrics()
function uint3 (line 1198) | uint3 optixGetLaunchIndex()
function uint3 (line 1207) | uint3 optixGetLaunchDimensions()
function CUdeviceptr (line 1216) | CUdeviceptr optixGetSbtDataPointer()
function optixThrowException (line 1223) | void optixThrowException( int exceptionCode )
function optixThrowException (line 1232) | void optixThrowException( int exceptionCode, unsigned int exceptionDetai...
function optixThrowException (line 1241) | void optixThrowException( int exceptionCode, unsigned int exceptionDetai...
function optixThrowException (line 1250) | void optixThrowException( int exceptionCode, unsigned int exceptionDetai...
function optixThrowException (line 1259) | void optixThrowException( int exceptionCode, unsigned int exceptionDetai...
function optixThrowException (line 1268) | void optixThrowException( int exceptionCode, unsigned int exceptionDetai...
function optixThrowException (line 1277) | void optixThrowException( int exceptionCode, unsigned int exceptionDetai...
function optixThrowException (line 1286) | void optixThrowException( int exceptionCode, unsigned int exceptionDetai...
function optixThrowException (line 1295) | void optixThrowException( int exceptionCode, unsigned int exceptionDetai...
function optixGetExceptionCode (line 1304) | int optixGetExceptionCode()
function optixGetExceptionDetail_0 (line 1316) | unsigned int optixGetExceptionDetail_0()
function optixGetExceptionDetail_1 (line 1321) | unsigned int optixGetExceptionDetail_1()
function optixGetExceptionDetail_2 (line 1326) | unsigned int optixGetExceptionDetail_2()
function optixGetExceptionDetail_3 (line 1331) | unsigned int optixGetExceptionDetail_3()
function optixGetExceptionDetail_4 (line 1336) | unsigned int optixGetExceptionDetail_4()
function optixGetExceptionDetail_5 (line 1341) | unsigned int optixGetExceptionDetail_5()
function optixGetExceptionDetail_6 (line 1346) | unsigned int optixGetExceptionDetail_6()
function optixGetExceptionDetail_7 (line 1351) | unsigned int optixGetExceptionDetail_7()
function OptixTraversableHandle (line 1358) | OptixTraversableHandle optixGetExceptionInvalidTraversable()
function optixGetExceptionInvalidSbtOffset (line 1365) | int optixGetExceptionInvalidSbtOffset()
function OptixInvalidRayExceptionDetails (line 1372) | OptixInvalidRayExceptionDetails optixGetExceptionInvalidRay()
function OptixParameterMismatchExceptionDetails (line 1388) | OptixParameterMismatchExceptionDetails optixGetExceptionParameterMismatch()
function ReturnT (line 1411) | ReturnT optixDirectCall( unsigned int sbtIndex, ArgTypes... args )
function ReturnT (line 1421) | ReturnT optixContinuationCall( unsigned int sbtIndex, ArgTypes... args )
function uint4 (line 1431) | uint4 optixTexFootprint2D( unsigned long long tex, unsigned int texInfo,...
function uint4 (line 1447) | uint4 optixTexFootprint2DGrad( unsigned long long tex,
function uint4 (line 1474) | uint4
FILE: render/optixutils/include/internal/optix_7_device_impl_exception.h
function namespace (line 40) | namespace optix_impl {
FILE: render/optixutils/include/internal/optix_7_device_impl_transformations.h
function namespace (line 36) | namespace optix_impl {
FILE: render/optixutils/include/optix_7_types.h
type CUdeviceptr (line 50) | typedef unsigned long long CUdeviceptr;
type CUdeviceptr (line 53) | typedef unsigned int CUdeviceptr;
type OptixDeviceContext_t (line 57) | struct OptixDeviceContext_t
type OptixModule_t (line 60) | struct OptixModule_t
type OptixProgramGroup_t (line 63) | struct OptixProgramGroup_t
type OptixPipeline_t (line 66) | struct OptixPipeline_t
type OptixDenoiser_t (line 69) | struct OptixDenoiser_t
type OptixTraversableHandle (line 72) | typedef unsigned long long OptixTraversableHandle;
type OptixVisibilityMask (line 75) | typedef unsigned int OptixVisibilityMask;
type OptixResult (line 112) | typedef enum OptixResult
type OptixDeviceProperty (line 156) | typedef enum OptixDeviceProperty
type OptixDeviceContextValidationMode (line 225) | typedef enum OptixDeviceContextValidationMode
type OptixDeviceContextOptions (line 234) | typedef struct OptixDeviceContextOptions
type OptixGeometryFlags (line 249) | typedef enum OptixGeometryFlags
type OptixHitKind (line 269) | typedef enum OptixHitKind
type OptixIndicesFormat (line 278) | typedef enum OptixIndicesFormat
type OptixVertexFormat (line 289) | typedef enum OptixVertexFormat
type OptixTransformFormat (line 301) | typedef enum OptixTransformFormat
type OptixBuildInputTriangleArray (line 310) | typedef struct OptixBuildInputTriangleArray
type OptixPrimitiveType (line 381) | typedef enum OptixPrimitiveType
type OptixPrimitiveTypeFlags (line 398) | typedef enum OptixPrimitiveTypeFlags
type OptixBuildInputCurveArray (line 429) | typedef struct OptixBuildInputCurveArray
type OptixAabb (line 480) | typedef struct OptixAabb
type OptixBuildInputCustomPrimitiveArray (line 493) | typedef struct OptixBuildInputCustomPrimitiveArray
type OptixBuildInputInstanceArray (line 538) | typedef struct OptixBuildInputInstanceArray
type OptixBuildInputType (line 556) | typedef enum OptixBuildInputType
type OptixBuildInput (line 575) | typedef struct OptixBuildInput
type OptixInstanceFlags (line 606) | typedef enum OptixInstanceFlags
type OptixInstance (line 638) | typedef struct OptixInstance
type OptixBuildFlags (line 668) | typedef enum OptixBuildFlags
type OptixBuildOperation (line 706) | typedef enum OptixBuildOperation
type OptixMotionFlags (line 717) | typedef enum OptixMotionFlags
type OptixMotionOptions (line 728) | typedef struct OptixMotionOptions
type OptixAccelBuildOptions (line 747) | typedef struct OptixAccelBuildOptions
type OptixAccelBufferSizes (line 767) | typedef struct OptixAccelBufferSizes
type OptixAccelPropertyType (line 787) | typedef enum OptixAccelPropertyType
type OptixAccelEmitDesc (line 799) | typedef struct OptixAccelEmitDesc
type OptixAccelRelocationInfo (line 811) | typedef struct OptixAccelRelocationInfo
type OptixStaticTransform (line 822) | typedef struct OptixStaticTransform
type OptixMatrixMotionTransform (line 862) | typedef struct OptixMatrixMotionTransform
type OptixSRTData (line 907) | typedef struct OptixSRTData
type OptixSRTMotionTransform (line 944) | typedef struct OptixSRTMotionTransform
type OptixTraversableType (line 967) | typedef enum OptixTraversableType
type OptixPixelFormat (line 980) | typedef enum OptixPixelFormat
type OptixImage2D (line 995) | typedef struct OptixImage2D
type OptixDenoiserModelKind (line 1015) | typedef enum OptixDenoiserModelKind
type OptixDenoiserOptions (line 1034) | typedef struct OptixDenoiserOptions
type OptixDenoiserGuideLayer (line 1046) | typedef struct OptixDenoiserGuideLayer
type OptixDenoiserLayer (line 1061) | typedef struct OptixDenoiserLayer
type OptixDenoiserParams (line 1078) | typedef struct OptixDenoiserParams
type OptixDenoiserSizes (line 1104) | typedef struct OptixDenoiserSizes
type OptixRayFlags (line 1116) | typedef enum OptixRayFlags
type OptixTransformType (line 1172) | typedef enum OptixTransformType
type OptixTraversableGraphFlags (line 1183) | typedef enum OptixTraversableGraphFlags
type OptixCompileOptimizationLevel (line 1204) | typedef enum OptixCompileOptimizationLevel
type OptixCompileDebugLevel (line 1221) | typedef enum OptixCompileDebugLevel
type OptixModuleCompileBoundValueEntry (line 1268) | typedef struct OptixModuleCompileBoundValueEntry {
type OptixModuleCompileOptions (line 1280) | typedef struct OptixModuleCompileOptions
type OptixProgramGroupKind (line 1302) | typedef enum OptixProgramGroupKind
type OptixProgramGroupFlags (line 1326) | typedef enum OptixProgramGroupFlags
type OptixProgramGroupSingleModule (line 1338) | typedef struct OptixProgramGroupSingleModule
type OptixProgramGroupHitgroup (line 1351) | typedef struct OptixProgramGroupHitgroup
type OptixProgramGroupCallables (line 1372) | typedef struct OptixProgramGroupCallables
type OptixProgramGroupDesc (line 1385) | typedef struct OptixProgramGroupDesc
type OptixProgramGroupOptions (line 1411) | typedef struct OptixProgramGroupOptions
type OptixExceptionCodes (line 1418) | typedef enum OptixExceptionCodes
type OptixExceptionFlags (line 1521) | typedef enum OptixExceptionFlags
type OptixPipelineCompileOptions (line 1545) | typedef struct OptixPipelineCompileOptions
type OptixPipelineLinkOptions (line 1581) | typedef struct OptixPipelineLinkOptions
type OptixShaderBindingTable (line 1594) | typedef struct OptixShaderBindingTable
type OptixStackSizes (line 1634) | typedef struct OptixStackSizes
type OptixQueryFunctionTableOptions (line 1654) | typedef enum OptixQueryFunctionTableOptions
type OptixResult (line 1662) | typedef OptixResult( OptixQueryFunctionTable_t )( int abiId,
type OptixBuiltinISOptions (line 1673) | typedef struct OptixBuiltinISOptions
type OptixInvalidRayExceptionDetails (line 1685) | typedef struct OptixInvalidRayExceptionDetails
type OptixParameterMismatchExceptionDetails (line 1700) | typedef struct OptixParameterMismatchExceptionDetails
FILE: render/optixutils/include/optix_denoiser_tiling.h
type OptixUtilDenoiserImageTile (line 54) | struct OptixUtilDenoiserImageTile
function optixUtilGetPixelStride (line 73) | inline unsigned int optixUtilGetPixelStride( const OptixImage2D& image )
function OptixResult (line 118) | inline OptixResult optixUtilDenoiserSplitImage(
function OptixResult (line 206) | inline OptixResult optixUtilDenoiserInvokeTiled(
FILE: render/optixutils/include/optix_function_table.h
type OptixFunctionTable (line 55) | typedef struct OptixFunctionTable
FILE: render/optixutils/include/optix_stack_size.h
function OptixResult (line 52) | inline OptixResult optixUtilAccumulateStackSizes( OptixProgramGroup prog...
function OptixResult (line 86) | inline OptixResult optixUtilComputeStackSizes( const OptixStackSizes* st...
function OptixResult (line 151) | inline OptixResult optixUtilComputeStackSizesDCSplit( const OptixStackSi...
function OptixResult (line 212) | inline OptixResult optixUtilComputeStackSizesCssCCTree( const OptixStack...
function OptixResult (line 263) | inline OptixResult optixUtilComputeStackSizesSimplePathTracer( OptixProg...
FILE: render/optixutils/include/optix_stubs.h
function OptixResult (line 188) | inline OptixResult optixInitWithHandle( void** handlePtr )
function OptixResult (line 224) | inline OptixResult optixInit( void )
function OptixResult (line 235) | inline OptixResult optixUninitWithHandle( void* handle )
function OptixResult (line 318) | inline OptixResult optixDeviceContextCreate( CUcontext fromContext, cons...
function OptixResult (line 323) | inline OptixResult optixDeviceContextDestroy( OptixDeviceContext context )
function OptixResult (line 328) | inline OptixResult optixDeviceContextGetProperty( OptixDeviceContext con...
function OptixResult (line 333) | inline OptixResult optixDeviceContextSetLogCallback( OptixDeviceContext ...
function OptixResult (line 341) | inline OptixResult optixDeviceContextSetCacheEnabled( OptixDeviceContext...
function OptixResult (line 346) | inline OptixResult optixDeviceContextSetCacheLocation( OptixDeviceContex...
function OptixResult (line 351) | inline OptixResult optixDeviceContextSetCacheDatabaseSizes( OptixDeviceC...
function OptixResult (line 356) | inline OptixResult optixDeviceContextGetCacheEnabled( OptixDeviceContext...
function OptixResult (line 361) | inline OptixResult optixDeviceContextGetCacheLocation( OptixDeviceContex...
function OptixResult (line 366) | inline OptixResult optixDeviceContextGetCacheDatabaseSizes( OptixDeviceC...
function OptixResult (line 371) | inline OptixResult optixModuleCreateFromPTX( OptixDeviceContext ...
function OptixResult (line 384) | inline OptixResult optixModuleDestroy( OptixModule module )
function OptixResult (line 389) | inline OptixResult optixBuiltinISModuleGet( OptixDeviceContext ...
function OptixResult (line 399) | inline OptixResult optixProgramGroupCreate( OptixDeviceContext ...
function OptixResult (line 411) | inline OptixResult optixProgramGroupDestroy( OptixProgramGroup programGr...
function OptixResult (line 416) | inline OptixResult optixProgramGroupGetStackSize( OptixProgramGroup prog...
function OptixResult (line 421) | inline OptixResult optixPipelineCreate( OptixDeviceContext ...
function OptixResult (line 434) | inline OptixResult optixPipelineDestroy( OptixPipeline pipeline )
function OptixResult (line 439) | inline OptixResult optixPipelineSetStackSize( OptixPipeline pipeline,
function OptixResult (line 449) | inline OptixResult optixAccelComputeMemoryUsage( OptixDeviceContext ...
function OptixResult (line 458) | inline OptixResult optixAccelBuild( OptixDeviceContext context,
function OptixResult (line 477) | inline OptixResult optixAccelGetRelocationInfo( OptixDeviceContext conte...
function OptixResult (line 483) | inline OptixResult optixAccelCheckRelocationCompatibility( OptixDeviceCo...
function OptixResult (line 488) | inline OptixResult optixAccelRelocate( OptixDeviceContext c...
function OptixResult (line 501) | inline OptixResult optixAccelCompact( OptixDeviceContext context,
function OptixResult (line 511) | inline OptixResult optixConvertPointerToTraversableHandle( OptixDeviceCo...
function OptixResult (line 519) | inline OptixResult optixSbtRecordPackHeader( OptixProgramGroup programGr...
function OptixResult (line 524) | inline OptixResult optixLaunch( OptixPipeline pipeline,
function OptixResult (line 536) | inline OptixResult optixDenoiserCreate( OptixDeviceContext context, Opti...
function OptixResult (line 541) | inline OptixResult optixDenoiserCreateWithUserModel( OptixDeviceContext ...
function OptixResult (line 546) | inline OptixResult optixDenoiserDestroy( OptixDenoiser handle )
function OptixResult (line 551) | inline OptixResult optixDenoiserComputeMemoryResources( const OptixDenoi...
function OptixResult (line 559) | inline OptixResult optixDenoiserSetup( OptixDenoiser denoiser,
function OptixResult (line 572) | inline OptixResult optixDenoiserInvoke( OptixDenoiser ...
function OptixResult (line 590) | inline OptixResult optixDenoiserComputeIntensity( OptixDenoiser ha...
function OptixResult (line 600) | inline OptixResult optixDenoiserComputeAverageColor( OptixDenoiser ...
FILE: render/optixutils/ops.py
function find_cl_path (line 25) | def find_cl_path():
class _optix_env_shade_func (line 81) | class _optix_env_shade_func(torch.autograd.Function):
method forward (line 85) | def forward(ctx, optix_ctx, mask, ro, gb_pos, gb_normal, gb_view_pos, ...
method backward (line 101) | def backward(ctx, diff_grad, spec_grad):
class _bilateral_denoiser_func (line 110) | class _bilateral_denoiser_func(torch.autograd.Function):
method forward (line 112) | def forward(ctx, col, nrm, zdz, sigma):
method backward (line 119) | def backward(ctx, out_grad):
class OptiXContext (line 128) | class OptiXContext:
method __init__ (line 129) | def __init__(self):
function optix_build_bvh (line 133) | def optix_build_bvh(optix_ctx, verts, tris, rebuild):
function optix_env_shade (line 141) | def optix_env_shade(optix_ctx, mask, ro, gb_pos, gb_normal, gb_view_pos,...
function bilateral_denoiser (line 145) | def bilateral_denoiser(col, nrm, zdz, sigma):
FILE: render/optixutils/tests/filter_test.py
function length (line 22) | def length(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor:
function safe_normalize (line 25) | def safe_normalize(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor:
function dot (line 28) | def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
class BilateralDenoiser (line 31) | class BilateralDenoiser(torch.nn.Module):
method __init__ (line 32) | def __init__(self, sigma=1.0):
method set_sigma (line 36) | def set_sigma(self, sigma):
method forward (line 41) | def forward(self, input):
function relative_loss (line 76) | def relative_loss(name, ref, cuda):
function test_filter (line 84) | def test_filter():
FILE: render/regularizer.py
function luma (line 16) | def luma(x):
function value (line 18) | def value(x):
function chroma_loss (line 21) | def chroma_loss(kd, color_ref, lambda_chroma):
function shading_loss (line 28) | def shading_loss(diffuse_light, specular_light, color_ref, lambda_diffus...
function material_smoothness_grad (line 46) | def material_smoothness_grad(kd_grad, ks_grad, nrm_grad, lambda_kd=0.25,...
function image_grad (line 56) | def image_grad(buf, std=0.01):
function avg_edge_length (line 68) | def avg_edge_length(v_pos, t_pos_idx):
function laplace_regularizer_const (line 77) | def laplace_regularizer_const(v_pos, t_pos_idx):
function normal_consistency (line 101) | def normal_consistency(v_pos, t_pos_idx):
FILE: render/render.py
function interpolate (line 25) | def interpolate(attr, rast, attr_idx, rast_db=None):
function shade (line 31) | def shade(
function render_layer (line 199) | def render_layer(
function render_mesh (line 325) | def render_mesh(
function render_uv (line 449) | def render_uv(ctx, mesh, resolution, mlp_texture):
FILE: render/renderutils/bsdf.py
function _dot (line 19) | def _dot(x, y):
function _reflect (line 22) | def _reflect(x, n):
function _safe_normalize (line 25) | def _safe_normalize(x):
function _bend_normal (line 28) | def _bend_normal(view_vec, smooth_nrm, geom_nrm, two_sided_shading):
function _perturb_normal (line 38) | def _perturb_normal(perturbed_nrm, smooth_nrm, smooth_tng, opengl):
function bsdf_prepare_shading_normal (line 46) | def bsdf_prepare_shading_normal(pos, view_pos, perturbed_nrm, smooth_nrm...
function bsdf_lambert (line 57) | def bsdf_lambert(nrm, wi):
function bsdf_frostbite (line 64) | def bsdf_frostbite(nrm, wi, wo, linearRoughness):
function bsdf_phong (line 85) | def bsdf_phong(nrm, wo, wi, N):
function bsdf_fresnel_shlick (line 96) | def bsdf_fresnel_shlick(f0, f90, cosTheta):
function bsdf_ndf_ggx (line 100) | def bsdf_ndf_ggx(alphaSqr, cosTheta):
function bsdf_lambda_ggx (line 105) | def bsdf_lambda_ggx(alphaSqr, cosTheta):
function bsdf_masking_smith_ggx_correlated (line 112) | def bsdf_masking_smith_ggx_correlated(alphaSqr, cosThetaI, cosThetaO):
function bsdf_pbr_specular (line 117) | def bsdf_pbr_specular(col, nrm, wo, wi, alpha, min_roughness=0.08):
function bsdf_pbr (line 136) | def bsdf_pbr(kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF):
FILE: render/renderutils/c_src/bsdf.h
type LambertKernelParams (line 16) | struct LambertKernelParams
type FrostbiteDiffuseKernelParams (line 24) | struct FrostbiteDiffuseKernelParams
type FresnelShlickKernelParams (line 34) | struct FresnelShlickKernelParams
type NdfGGXParams (line 43) | struct NdfGGXParams
type MaskingSmithParams (line 51) | struct MaskingSmithParams
type PbrSpecular (line 60) | struct PbrSpecular
type PbrBSDF (line 72) | struct PbrBSDF
FILE: render/renderutils/c_src/common.cpp
function dim3 (line 18) | dim3 getLaunchBlockSize(int maxWidth, int maxHeight, dim3 dims)
function dim3 (line 56) | dim3 getWarpSize(dim3 blockSize)
function dim3 (line 65) | dim3 getLaunchGridSize(dim3 blockSize, dim3 dims)
FILE: render/renderutils/c_src/common.h
function dim3 (line 29) | static inline dim3 getWarpSize(dim3 blockSize)
function __device__ (line 38) | __device__ static inline float clamp(float val, float mn, float mx) { re...
FILE: render/renderutils/c_src/cubemap.h
type DiffuseCubemapKernelParams (line 16) | struct DiffuseCubemapKernelParams
type SpecularCubemapKernelParams (line 23) | struct SpecularCubemapKernelParams
type SpecularBoundsKernelParams (line 33) | struct SpecularBoundsKernelParams
FILE: render/renderutils/c_src/loss.h
type TonemapperType (line 16) | enum TonemapperType
type LossType (line 22) | enum LossType
type LossKernelParams (line 30) | struct LossKernelParams
FILE: render/renderutils/c_src/mesh.h
type XfmKernelParams (line 16) | struct XfmKernelParams
FILE: render/renderutils/c_src/normal.h
type PrepareShadingNormalKernelParams (line 16) | struct PrepareShadingNormalKernelParams
FILE: render/renderutils/c_src/tensor.h
type Tensor (line 20) | struct Tensor
function store (line 65) | inline void store(unsigned int x, unsigned int y, unsigned int z, float ...
function __device__ (line 70) | __device__ inline void store(unsigned int x, unsigned int y, unsigned in...
function __device__ (line 79) | __device__ inline void store_grad(unsigned int x, unsigned int y, unsign...
function __device__ (line 84) | __device__ inline void store_grad(unsigned int x, unsigned int y, unsign...
FILE: render/renderutils/c_src/torch_bindings.cpp
function update_grid (line 100) | void update_grid(dim3 &gridSize, torch::Tensor x)
function update_grid (line 108) | void update_grid(dim3& gridSize, torch::Tensor x, Ts&&... vs)
function Tensor (line 116) | Tensor make_cuda_tensor(torch::Tensor val)
function Tensor (line 130) | Tensor make_cuda_tensor(torch::Tensor val, dim3 outDims, torch::Tensor* ...
function prepare_shading_normal_fwd (line 161) | torch::Tensor prepare_shading_normal_fwd(torch::Tensor pos, torch::Tenso...
function prepare_shading_normal_bwd (line 203) | std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, t...
function lambert_fwd (line 237) | torch::Tensor lambert_fwd(torch::Tensor nrm, torch::Tensor wi, bool fp16)
function lambert_bwd (line 268) | std::tuple<torch::Tensor, torch::Tensor> lambert_bwd(torch::Tensor nrm, ...
function frostbite_fwd (line 295) | torch::Tensor frostbite_fwd(torch::Tensor nrm, torch::Tensor wi, torch::...
function frostbite_bwd (line 330) | std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> f...
function fresnel_shlick_fwd (line 359) | torch::Tensor fresnel_shlick_fwd(torch::Tensor f0, torch::Tensor f90, to...
function fresnel_shlick_bwd (line 392) | std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> fresnel_shlick_b...
function ndf_ggx_fwd (line 420) | torch::Tensor ndf_ggx_fwd(torch::Tensor alphaSqr, torch::Tensor cosTheta...
function ndf_ggx_bwd (line 451) | std::tuple<torch::Tensor, torch::Tensor> ndf_ggx_bwd(torch::Tensor alpha...
function lambda_ggx_fwd (line 478) | torch::Tensor lambda_ggx_fwd(torch::Tensor alphaSqr, torch::Tensor cosTh...
function lambda_ggx_bwd (line 509) | std::tuple<torch::Tensor, torch::Tensor> lambda_ggx_bwd(torch::Tensor al...
function masking_smith_fwd (line 536) | torch::Tensor masking_smith_fwd(torch::Tensor alphaSqr, torch::Tensor co...
function masking_smith_bwd (line 569) | std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> masking_smith_bw...
function pbr_specular_fwd (line 597) | torch::Tensor pbr_specular_fwd(torch::Tensor col, torch::Tensor nrm, tor...
function pbr_specular_bwd (line 635) | std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, t...
function pbr_bsdf_fwd (line 666) | torch::Tensor pbr_bsdf_fwd(torch::Tensor kd, torch::Tensor arm, torch::T...
function pbr_bsdf_bwd (line 707) | std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, t...
function diffuse_cubemap_fwd (line 740) | torch::Tensor diffuse_cubemap_fwd(torch::Tensor cubemap)
function diffuse_cubemap_bwd (line 769) | torch::Tensor diffuse_cubemap_bwd(torch::Tensor cubemap, torch::Tensor g...
function specular_bounds (line 799) | torch::Tensor specular_bounds(int resolution, float costheta_cutoff)
function specular_cubemap_fwd (line 826) | torch::Tensor specular_cubemap_fwd(torch::Tensor cubemap, torch::Tensor ...
function specular_cubemap_bwd (line 859) | torch::Tensor specular_cubemap_bwd(torch::Tensor cubemap, torch::Tensor ...
function LossType (line 895) | LossType strToLoss(std::string str)
function image_loss_fwd (line 907) | torch::Tensor image_loss_fwd(torch::Tensor img, torch::Tensor target, st...
function image_loss_bwd (line 941) | std::tuple<torch::Tensor, torch::Tensor> image_loss_bwd(torch::Tensor im...
function xfm_fwd (line 971) | torch::Tensor xfm_fwd(torch::Tensor points, torch::Tensor matrix, bool i...
function xfm_bwd (line 1006) | torch::Tensor xfm_bwd(torch::Tensor points, torch::Tensor matrix, torch:...
function PYBIND11_MODULE (line 1034) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: render/renderutils/c_src/vec3f.h
type vec3f (line 14) | struct vec3f
function __device__ (line 38) | __device__ static inline float sum(vec3f a)
function __device__ (line 43) | __device__ static inline vec3f cross(vec3f a, vec3f b)
function __device__ (line 52) | __device__ static inline void bwdCross(vec3f a, vec3f b, vec3f &d_a, vec...
function __device__ (line 63) | __device__ static inline float dot(vec3f a, vec3f b)
function __device__ (line 68) | __device__ static inline void bwdDot(vec3f a, vec3f b, vec3f& d_a, vec3f...
function __device__ (line 74) | __device__ static inline vec3f reflect(vec3f x, vec3f n)
function __device__ (line 79) | __device__ static inline void bwdReflect(vec3f x, vec3f n, vec3f& d_x, v...
function __device__ (line 90) | __device__ static inline vec3f safeNormalize(vec3f v)
function __device__ (line 96) | __device__ static inline void bwdSafeNormalize(const vec3f v, vec3f& d_v...
FILE: render/renderutils/c_src/vec4f.h
type vec4f (line 14) | struct vec4f
FILE: render/renderutils/loss.py
function _tonemap_srgb (line 16) | def _tonemap_srgb(f, exposure=5):
function _SMAPE (line 20) | def _SMAPE(img, target, eps=0.01):
function _RELMSE (line 25) | def _RELMSE(img, target, eps=0.1):
function image_loss_fn (line 30) | def image_loss_fn(img, target, loss, tonemapper):
FILE: render/renderutils/ops.py
function _get_plugin (line 23) | def _get_plugin():
class _fresnel_shlick_func (line 92) | class _fresnel_shlick_func(torch.autograd.Function):
method forward (line 94) | def forward(ctx, f0, f90, cosTheta):
method backward (line 100) | def backward(ctx, dout):
function _fresnel_shlick (line 104) | def _fresnel_shlick(f0, f90, cosTheta, use_python=False):
class _ndf_ggx_func (line 115) | class _ndf_ggx_func(torch.autograd.Function):
method forward (line 117) | def forward(ctx, alphaSqr, cosTheta):
method backward (line 123) | def backward(ctx, dout):
function _ndf_ggx (line 127) | def _ndf_ggx(alphaSqr, cosTheta, use_python=False):
class _lambda_ggx_func (line 137) | class _lambda_ggx_func(torch.autograd.Function):
method forward (line 139) | def forward(ctx, alphaSqr, cosTheta):
method backward (line 145) | def backward(ctx, dout):
function _lambda_ggx (line 149) | def _lambda_ggx(alphaSqr, cosTheta, use_python=False):
class _masking_smith_func (line 159) | class _masking_smith_func(torch.autograd.Function):
method forward (line 161) | def forward(ctx, alphaSqr, cosThetaI, cosThetaO):
method backward (line 167) | def backward(ctx, dout):
function _masking_smith (line 171) | def _masking_smith(alphaSqr, cosThetaI, cosThetaO, use_python=False):
class _prepare_shading_normal_func (line 184) | class _prepare_shading_normal_func(torch.autograd.Function):
method forward (line 186) | def forward(ctx, pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng,...
method backward (line 193) | def backward(ctx, dout):
function prepare_shading_normal (line 197) | def prepare_shading_normal(pos, view_pos, perturbed_nrm, smooth_nrm, smo...
class _lambert_func (line 235) | class _lambert_func(torch.autograd.Function):
method forward (line 237) | def forward(ctx, nrm, wi):
method backward (line 243) | def backward(ctx, dout):
function lambert (line 247) | def lambert(nrm, wi, use_python=False):
class _frostbite_diffuse_func (line 269) | class _frostbite_diffuse_func(torch.autograd.Function):
method forward (line 271) | def forward(ctx, nrm, wi, wo, linearRoughness):
method backward (line 277) | def backward(ctx, dout):
function frostbite_diffuse (line 281) | def frostbite_diffuse(nrm, wi, wo, linearRoughness, use_python=False):
class _pbr_specular_func (line 305) | class _pbr_specular_func(torch.autograd.Function):
method forward (line 307) | def forward(ctx, col, nrm, wo, wi, alpha, min_roughness):
method backward (line 314) | def backward(ctx, dout):
function pbr_specular (line 318) | def pbr_specular(col, nrm, wo, wi, alpha, min_roughness=0.08, use_python...
class _pbr_bsdf_func (line 344) | class _pbr_bsdf_func(torch.autograd.Function):
method forward (line 346) | def forward(ctx, kd, arm, pos, nrm, view_pos, light_pos, min_roughness...
method backward (line 354) | def backward(ctx, dout):
function pbr_bsdf (line 358) | def pbr_bsdf(kd, arm, pos, nrm, view_pos, light_pos, min_roughness=0.08,...
class _diffuse_cubemap_func (line 394) | class _diffuse_cubemap_func(torch.autograd.Function):
method forward (line 396) | def forward(ctx, cubemap):
method backward (line 402) | def backward(ctx, dout):
function diffuse_cubemap (line 407) | def diffuse_cubemap(cubemap, use_python=False):
class _specular_cubemap (line 416) | class _specular_cubemap(torch.autograd.Function):
method forward (line 418) | def forward(ctx, cubemap, roughness, costheta_cutoff, bounds):
method backward (line 425) | def backward(ctx, dout):
function __ndfBounds (line 431) | def __ndfBounds(res, roughness, cutoff):
function specular_cubemap (line 449) | def specular_cubemap(cubemap, roughness, cutoff=0.99, use_python=False):
class _image_loss_func (line 466) | class _image_loss_func(torch.autograd.Function):
method forward (line 468) | def forward(ctx, img, target, loss, tonemapper):
method backward (line 475) | def backward(ctx, dout):
function image_loss (line 479) | def image_loss(img, target, loss='l1', tonemapper='none', use_python=Fal...
class _xfm_func (line 506) | class _xfm_func(torch.autograd.Function):
method forward (line 508) | def forward(ctx, points, matrix, isPoints):
method backward (line 514) | def backward(ctx, dout):
function xfm_points (line 518) | def xfm_points(points, matrix, use_python=False):
function xfm_vectors (line 536) | def xfm_vectors(vectors, matrix, use_python=False):
FILE: render/renderutils/tests/test_bsdf.py
function relative_loss (line 20) | def relative_loss(name, ref, cuda):
function test_normal (line 25) | def test_normal():
function test_schlick (line 59) | def test_schlick():
function test_ndf_ggx (line 85) | def test_ndf_ggx():
function test_lambda_ggx (line 109) | def test_lambda_ggx():
function test_masking_smith (line 132) | def test_masking_smith():
function test_lambert (line 157) | def test_lambert():
function test_frostbite (line 179) | def test_frostbite():
function test_pbr_specular (line 207) | def test_pbr_specular():
function test_pbr_bsdf (line 244) | def test_pbr_bsdf(bsdf):
FILE: render/renderutils/tests/test_loss.py
function tonemap_srgb (line 20) | def tonemap_srgb(f):
function l1 (line 23) | def l1(output, target):
function relative_loss (line 30) | def relative_loss(name, ref, cuda):
function test_loss (line 35) | def test_loss(loss, tonemapper):
FILE: render/renderutils/tests/test_mesh.py
function tonemap_srgb (line 23) | def tonemap_srgb(f):
function l1 (line 26) | def l1(output, target):
function relative_loss (line 33) | def relative_loss(name, ref, cuda):
function test_xfm_points (line 38) | def test_xfm_points():
function test_xfm_vectors (line 58) | def test_xfm_vectors():
FILE: render/renderutils/tests/test_perf.py
function test_bsdf (line 19) | def test_bsdf(BATCH, RES, ITR):
FILE: render/texture.py
class texture2d_mip (line 20) | class texture2d_mip(torch.autograd.Function):
method forward (line 22) | def forward(ctx, texture):
method backward (line 26) | def backward(ctx, dout):
class Texture2D (line 38) | class Texture2D:
method __init__ (line 41) | def __init__(self, init, min_max=None):
method sample (line 57) | def sample(self, texc, texc_deriv, filter_mode='linear-mipmap-linear'):
method getRes (line 70) | def getRes(self):
method getChannels (line 73) | def getChannels(self):
method getMips (line 76) | def getMips(self):
method parameters (line 82) | def parameters(self):
method clamp_ (line 86) | def clamp_(self):
method normalize_ (line 93) | def normalize_(self):
function create_trainable (line 103) | def create_trainable(init, res=None, auto_mipmaps=True, min_max=None):
function srgb_to_rgb (line 137) | def srgb_to_rgb(texture):
function rgb_to_srgb (line 140) | def rgb_to_srgb(texture):
function _load_mip2D (line 147) | def _load_mip2D(fn, lambda_fn=None, channels=None):
function load_texture2D (line 155) | def load_texture2D(fn, lambda_fn=None, channels=None):
function _save_mip2D (line 165) | def _save_mip2D(fn, mip, mipidx, lambda_fn):
function save_texture2D (line 177) | def save_texture2D(fn, tex, lambda_fn=None):
FILE: render/util.py
function dot (line 19) | def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
function reflect (line 22) | def reflect(x: torch.Tensor, n: torch.Tensor) -> torch.Tensor:
function length (line 25) | def length(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor:
function safe_normalize (line 28) | def safe_normalize(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor:
function to_hvec (line 31) | def to_hvec(x: torch.Tensor, w: float) -> torch.Tensor:
function ycocg2rgb (line 34) | def ycocg2rgb(ycocg):
function hsv2rgb (line 41) | def hsv2rgb(image): # Based on https://kornia.readthedocs.io/en/latest/_...
function pixel_grid (line 61) | def pixel_grid(width, height, center_x = 0.5, center_y = 0.5):
function dilate (line 70) | def dilate(x, x_avg, mask, N):
function _rgb_to_srgb (line 94) | def _rgb_to_srgb(f: torch.Tensor) -> torch.Tensor:
function rgb_to_srgb (line 97) | def rgb_to_srgb(f: torch.Tensor) -> torch.Tensor:
function _srgb_to_rgb (line 103) | def _srgb_to_rgb(f: torch.Tensor) -> torch.Tensor:
function srgb_to_rgb (line 106) | def srgb_to_rgb(f: torch.Tensor) -> torch.Tensor:
function reinhard (line 112) | def reinhard(f: torch.Tensor) -> torch.Tensor:
function mse_to_psnr (line 122) | def mse_to_psnr(mse):
function psnr_to_mse (line 126) | def psnr_to_mse(psnr):
function get_miplevels (line 134) | def get_miplevels(texture: np.ndarray) -> float:
function tex_2d (line 138) | def tex_2d(tex_map : torch.Tensor, coords : torch.Tensor, filter='neares...
function cube_to_dir (line 149) | def cube_to_dir(s, x, y):
function latlong_to_cubemap (line 158) | def latlong_to_cubemap(latlong_map, res):
function cubemap_to_latlong (line 173) | def cubemap_to_latlong(cubemap, res):
function scale_img_hwc (line 192) | def scale_img_hwc(x : torch.Tensor, size, mag='bilinear', min='area') ->...
function scale_img_nhwc (line 195) | def scale_img_nhwc(x : torch.Tensor, size, mag='bilinear', min='area') ...
function avg_pool_nhwc (line 207) | def avg_pool_nhwc(x : torch.Tensor, size) -> torch.Tensor:
function segment_sum (line 216) | def segment_sum(data: torch.Tensor, segment_ids: torch.Tensor) -> torch....
function fovx_to_fovy (line 235) | def fovx_to_fovy(fovx, aspect):
function focal_length_to_fovy (line 238) | def focal_length_to_fovy(focal_length, sensor_height):
function perspective (line 242) | def perspective(fovy=0.7854, aspect=1.0, n=0.1, f=1000.0, device=None):
function perspective_offcenter (line 250) | def perspective_offcenter(fovy, fraction, rx, ry, aspect=1.0, n=0.1, f=1...
function translate (line 274) | def translate(x, y, z, device=None):
function rotate_x (line 280) | def rotate_x(a, device=None):
function rotate_y (line 287) | def rotate_y(a, device=None):
function rotate_z (line 294) | def rotate_z(a, device=None):
function scale (line 301) | def scale(s, device=None):
function lookAt (line 307) | def lookAt(eye, at, up):
function random_rotation_translation (line 324) | def random_rotation_translation(t, device=None):
function random_rotation (line 335) | def random_rotation(device=None):
function lines_focal (line 350) | def lines_focal(o, d):
function cosine_sample (line 361) | def cosine_sample(N, size=None):
function bilinear_downsample (line 396) | def bilinear_downsample(x : torch.tensor) -> torch.Tensor:
function bilinear_downsample (line 406) | def bilinear_downsample(x : torch.tensor, spp) -> torch.Tensor:
function init_glfw (line 422) | def init_glfw():
function display_image (line 440) | def display_image(image, title=None):
function save_image (line 483) | def save_image(fn, x : np.ndarray) -> np.ndarray:
function save_image_raw (line 492) | def save_image_raw(fn, x : np.ndarray):
function load_image_raw (line 499) | def load_image_raw(fn) -> np.ndarray:
function load_image (line 502) | def load_image(fn) -> np.ndarray:
function time_to_text (line 511) | def time_to_text(x):
function checkerboard (line 521) | def checkerboard(res, checker_size) -> np.ndarray:
FILE: train_gflexicubes_deepfashion.py
function createLoss (line 52) | def createLoss(FLAGS):
function prepare_batch (line 71) | def prepare_batch(target, bg_type='black'):
function xatlas_uvmap (line 101) | def xatlas_uvmap(glctx, geometry, mat, FLAGS):
function initial_guess_material (line 155) | def initial_guess_material(geometry, mlp, FLAGS, init_mat=None):
function initial_guess_material_knownkskd (line 172) | def initial_guess_material_knownkskd(geometry, mlp, FLAGS, init_mat=None):
function validate_itr (line 190) | def validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS, deno...
function validate (line 237) | def validate(glctx, geometry, opt_material, lgt, dataset_validate, out_d...
function optimize_mesh (line 288) | def optimize_mesh(
FILE: train_gflexicubes_polycam.py
function createLoss (line 52) | def createLoss(FLAGS):
function prepare_batch (line 71) | def prepare_batch(target, bg_type='black'):
function xatlas_uvmap (line 101) | def xatlas_uvmap(glctx, geometry, mat, FLAGS):
function initial_guess_material (line 155) | def initial_guess_material(geometry, mlp, FLAGS, init_mat=None):
function initial_guess_material_knownkskd (line 172) | def initial_guess_material_knownkskd(geometry, mlp, FLAGS, init_mat=None):
function validate_itr (line 190) | def validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS, deno...
function validate (line 209) | def validate(glctx, geometry, opt_material, lgt, dataset_validate, out_d...
function optimize_mesh (line 260) | def optimize_mesh(
FILE: train_gshelltet_deepfashion.py
function createLoss (line 52) | def createLoss(FLAGS):
function prepare_batch (line 71) | def prepare_batch(target, bg_type='black'):
function xatlas_uvmap (line 101) | def xatlas_uvmap(glctx, geometry, mat, FLAGS):
function initial_guess_material (line 155) | def initial_guess_material(geometry, mlp, FLAGS, init_mat=None):
function initial_guess_material_knownkskd (line 172) | def initial_guess_material_knownkskd(geometry, mlp, FLAGS, init_mat=None):
function validate_itr (line 190) | def validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS, deno...
function validate (line 227) | def validate(glctx, geometry, opt_material, lgt, dataset_validate, out_d...
function optimize_mesh (line 278) | def optimize_mesh(
FILE: train_gshelltet_polycam.py
function createLoss (line 52) | def createLoss(FLAGS):
function prepare_batch (line 71) | def prepare_batch(target, bg_type='black'):
function xatlas_uvmap (line 101) | def xatlas_uvmap(glctx, geometry, mat, FLAGS):
function initial_guess_material (line 155) | def initial_guess_material(geometry, mlp, FLAGS, init_mat=None):
function initial_guess_material_knownkskd (line 172) | def initial_guess_material_knownkskd(geometry, mlp, FLAGS, init_mat=None):
function validate_itr (line 190) | def validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS, deno...
function validate (line 209) | def validate(glctx, geometry, opt_material, lgt, dataset_validate, out_d...
function optimize_mesh (line 260) | def optimize_mesh(
FILE: train_gshelltet_synthetic.py
function createLoss (line 51) | def createLoss(FLAGS):
function prepare_batch (line 70) | def prepare_batch(target, bg_type='black'):
function xatlas_uvmap (line 100) | def xatlas_uvmap(glctx, geometry, mat, FLAGS):
function initial_guess_material (line 154) | def initial_guess_material(geometry, mlp, FLAGS, init_mat=None):
function initial_guess_material_knownkskd (line 171) | def initial_guess_material_knownkskd(geometry, mlp, FLAGS, init_mat=None):
function validate_itr (line 189) | def validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS, deno...
function validate (line 212) | def validate(glctx, geometry, opt_material, lgt, dataset_validate, out_d...
function optimize_mesh (line 263) | def optimize_mesh(
Condensed preview — 124 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (1,318K chars).
[
{
"path": ".gitignore",
"chars": 3091,
"preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packagi"
},
{
"path": "GMeshDiffusion/diffusion_configs/config_lower_occgrid_normalized.py",
"chars": 3733,
"preview": "import ml_collections\nimport torch\nimport os\n\n\ndef get_config():\n config = ml_collections.ConfigDict()\n\n # data\n "
},
{
"path": "GMeshDiffusion/diffusion_configs/config_upper_occgrid_normalized.py",
"chars": 3775,
"preview": "import ml_collections\nimport torch\nimport os\n\n\ndef get_config():\n config = ml_collections.ConfigDict()\n\n # data\n "
},
{
"path": "GMeshDiffusion/lib/dataset/gshell_dataset.py",
"chars": 741,
"preview": "import torch\nimport numpy as np\nfrom torch.utils.data import Dataset\n\nclass GShellDataset(Dataset):\n def __init__(sel"
},
{
"path": "GMeshDiffusion/lib/dataset/gshell_dataset_aug.py",
"chars": 1231,
"preview": "import torch\nfrom torch.utils.data import Dataset\n\nclass GShellAugDataset(Dataset):\n def __init__(self, FLAGS, extens"
},
{
"path": "GMeshDiffusion/lib/diffusion/evaler.py",
"chars": 10808,
"preview": "import os\nimport sys\nimport numpy as np\nimport tqdm\n\nimport logging\nfrom . import losses\nfrom .models import utils as mu"
},
{
"path": "GMeshDiffusion/lib/diffusion/likelihood.py",
"chars": 4714,
"preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
},
{
"path": "GMeshDiffusion/lib/diffusion/losses.py",
"chars": 10283,
"preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
},
{
"path": "GMeshDiffusion/lib/diffusion/models/__init__.py",
"chars": 608,
"preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
},
{
"path": "GMeshDiffusion/lib/diffusion/models/ema.py",
"chars": 3672,
"preview": "# Modified from https://raw.githubusercontent.com/fadel/pytorch_ema/master/torch_ema/ema.py\n\nfrom __future__ import divi"
},
{
"path": "GMeshDiffusion/lib/diffusion/models/functional.py",
"chars": 9332,
"preview": "#################################################################################################\n# Copyright (c) 2023 A"
},
{
"path": "GMeshDiffusion/lib/diffusion/models/layers.py",
"chars": 11260,
"preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
},
{
"path": "GMeshDiffusion/lib/diffusion/models/normalization.py",
"chars": 7768,
"preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
},
{
"path": "GMeshDiffusion/lib/diffusion/models/unet3d_occgrid.py",
"chars": 8884,
"preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
},
{
"path": "GMeshDiffusion/lib/diffusion/models/utils.py",
"chars": 7257,
"preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
},
{
"path": "GMeshDiffusion/lib/diffusion/sampling.py",
"chars": 26124,
"preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
},
{
"path": "GMeshDiffusion/lib/diffusion/sde_lib.py",
"chars": 9305,
"preview": "\"\"\"Abstract SDE classes, Reverse SDE, and VE/VP SDEs.\"\"\"\nimport abc\nimport torch\nimport numpy as np\nimport torch.nn.func"
},
{
"path": "GMeshDiffusion/lib/diffusion/trainer.py",
"chars": 6286,
"preview": "import os\nimport sys\nimport numpy as np\n\nimport logging\n# Keep the import below for registering all model definitions\nfr"
},
{
"path": "GMeshDiffusion/lib/diffusion/trainer_ddp.py",
"chars": 7278,
"preview": "import os\nimport sys\nimport numpy as np\n\nimport logging\n# Keep the import below for registering all model definitions\nfr"
},
{
"path": "GMeshDiffusion/lib/diffusion/utils.py",
"chars": 1503,
"preview": "import torch\nimport os\nimport logging\n\n\ndef restore_checkpoint(ckpt_dir, state, device, strict=False, rank=None):\n if n"
},
{
"path": "GMeshDiffusion/main_diffusion.py",
"chars": 847,
"preview": "\"\"\"Training and evaluation\"\"\"\n\nfrom absl import app\nfrom absl import flags\nfrom ml_collections.config_flags import confi"
},
{
"path": "GMeshDiffusion/main_diffusion_ddp.py",
"chars": 635,
"preview": "\"\"\"Training and evaluation\"\"\"\n\nfrom absl import app\nfrom absl import flags\nfrom ml_collections.config_flags import confi"
},
{
"path": "GMeshDiffusion/metadata/get_splits_lower.py",
"chars": 974,
"preview": "import os\nimport random\n\nrandom.seed(42)\n\nsplit_ratio = 0.9\ndata_root = 'PLACEHOLDER'\ngrid_root = os.path.join(data_root"
},
{
"path": "GMeshDiffusion/metadata/get_splits_upper.py",
"chars": 974,
"preview": "import os\nimport random\n\nrandom.seed(42)\n\nsplit_ratio = 0.9\ndata_root = 'PLACEHOLDER'\ngrid_root = os.path.join(data_root"
},
{
"path": "GMeshDiffusion/metadata/save_tet_info.py",
"chars": 4023,
"preview": "'''\n Storing tet-grid related meta-info into a single file\n'''\n\nimport numpy as np\nimport torch\nimport os\nimport tqdm"
},
{
"path": "GMeshDiffusion/metadata/tet_to_cubic_grid_dataset.py",
"chars": 12533,
"preview": "import numpy as np\nimport torch\nimport os\nimport tqdm\nimport argparse\n\ndef tet_to_grids(vertices, values_list, grid_size"
},
{
"path": "GMeshDiffusion/scripts/run_eval_lower_occgrid_normalized.sh",
"chars": 299,
"preview": "python main_diffusion.py --mode uncond_gen --config diffusion_configs/config_lower_occgrid_normalized.py \\\n--config.eval"
},
{
"path": "GMeshDiffusion/scripts/run_eval_upper_occgrid_normalized.sh",
"chars": 299,
"preview": "python main_diffusion.py --mode uncond_gen --config diffusion_configs/config_upper_occgrid_normalized.py \\\n--config.eval"
},
{
"path": "GMeshDiffusion/scripts/run_lower_occgrid_normalized_ddp.sh",
"chars": 213,
"preview": "torchrun --nnodes=1 --nproc_per_node=8 main_diffusion_ddp.py --mode=train --config=diffusion_configs/config_lower_occgri"
},
{
"path": "GMeshDiffusion/scripts/run_upper_occgrid_normalized_ddp.sh",
"chars": 214,
"preview": "torchrun --nnodes=1 --nproc_per_node=8 main_diffusion_ddp.py --mode=train --config=diffusion_configs/config_upper_occgri"
},
{
"path": "README.md",
"chars": 7662,
"preview": "<div align=\"center\">\n <img src=\"assets/gshell_logo.png\" width=\"900\"/>\n</div>\n\n# Ghost on the Shell: An Expressive Repre"
},
{
"path": "configs/deepfashion_mc.json",
"chars": 655,
"preview": "{\n \"ref_mesh\": \"data/spot/spot.obj\",\n \"random_textures\": true,\n \"iter\": 5000,\n \"save_interval\": 100,\n \"te"
},
{
"path": "configs/deepfashion_mc_256.json",
"chars": 655,
"preview": "{\n \"ref_mesh\": \"data/spot/spot.obj\",\n \"random_textures\": true,\n \"iter\": 5000,\n \"save_interval\": 100,\n \"te"
},
{
"path": "configs/deepfashion_mc_512.json",
"chars": 679,
"preview": "{\n \"ref_mesh\": \"data/spot/spot.obj\",\n \"random_textures\": true,\n \"iter\": 5000,\n \"save_interval\": 100,\n \"te"
},
{
"path": "configs/deepfashion_mc_80.json",
"chars": 654,
"preview": "{\n \"ref_mesh\": \"data/spot/spot.obj\",\n \"random_textures\": true,\n \"iter\": 5000,\n \"save_interval\": 100,\n \"te"
},
{
"path": "configs/nerf_chair.json",
"chars": 541,
"preview": "{\n \"ref_mesh\": \"data/nerf_synthetic/chair\",\n \"random_textures\": true,\n \"iter\": 5000,\n \"save_interval\": 100,\n"
},
{
"path": "configs/polycam_mc.json",
"chars": 653,
"preview": "{\n \"ref_mesh\": \"data/spot/spot.obj\",\n \"random_textures\": true,\n \"iter\": 5000,\n \"save_interval\": 100,\n \"te"
},
{
"path": "configs/polycam_mc_128.json",
"chars": 653,
"preview": "{\n \"ref_mesh\": \"data/spot/spot.obj\",\n \"random_textures\": true,\n \"iter\": 5000,\n \"save_interval\": 100,\n \"te"
},
{
"path": "configs/polycam_mc_16samples.json",
"chars": 654,
"preview": "{\n \"ref_mesh\": \"data/spot/spot.obj\",\n \"random_textures\": true,\n \"iter\": 5000,\n \"save_interval\": 100,\n \"te"
},
{
"path": "data/tets/generate_tets.py",
"chars": 1775,
"preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
},
{
"path": "dataset/__init__.py",
"chars": 572,
"preview": "# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all "
},
{
"path": "dataset/dataset.py",
"chars": 1866,
"preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
},
{
"path": "dataset/dataset_deepfashion.py",
"chars": 4627,
"preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
},
{
"path": "dataset/dataset_deepfashion_testset.py",
"chars": 4300,
"preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
},
{
"path": "dataset/dataset_llff.py",
"chars": 4839,
"preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
},
{
"path": "dataset/dataset_mesh.py",
"chars": 5842,
"preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
},
{
"path": "dataset/dataset_nerf.py",
"chars": 3623,
"preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
},
{
"path": "dataset/dataset_nerf_colmap.py",
"chars": 3738,
"preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
},
{
"path": "denoiser/denoiser.py",
"chars": 1210,
"preview": "import os\n\nimport torch\nimport numpy as np\nimport math\n\nfrom render import util\nif \"TWOSIDED_TEXTURE\" not in os.environ "
},
{
"path": "eval_gmeshdiffusion_generated_samples.py",
"chars": 9619,
"preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
},
{
"path": "geometry/embedding.py",
"chars": 1200,
"preview": "import torch\nfrom torch import nn\n\nclass Embedding(nn.Module):\n def __init__(self, in_channels, N_freqs, logscale=Tru"
},
{
"path": "geometry/flexicubes_table.py",
"chars": 41249,
"preview": "# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n#\n# NVIDIA CORPORATION & AFFILIATES and its"
},
{
"path": "geometry/gshell_flexicubes.py",
"chars": 40870,
"preview": "# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n#\n# NVIDIA CORPORATION & AFFILIATES and its"
},
{
"path": "geometry/gshell_flexicubes_geometry.py",
"chars": 16375,
"preview": "# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n#\n# NVIDIA CORPORATION & AFFILIATES and its"
},
{
"path": "geometry/gshell_tets.py",
"chars": 31153,
"preview": "import numpy as np\nimport torch\n\nfrom render import util\n\n##############################################################"
},
{
"path": "geometry/gshell_tets_geometry.py",
"chars": 17765,
"preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
},
{
"path": "geometry/mlp.py",
"chars": 1372,
"preview": "import torch\nimport torch.nn as nn\nimport numpy as np\n\nfrom .embedding import Embedding\n\nclass MLP(nn.Module):\n def _"
},
{
"path": "render/light.py",
"chars": 4278,
"preview": "# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all "
},
{
"path": "render/material.py",
"chars": 7200,
"preview": "# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all "
},
{
"path": "render/mesh.py",
"chars": 12068,
"preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
},
{
"path": "render/mlptexture.py",
"chars": 4650,
"preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
},
{
"path": "render/obj.py",
"chars": 7794,
"preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
},
{
"path": "render/optixutils/__init__.py",
"chars": 601,
"preview": "# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all "
},
{
"path": "render/optixutils/c_src/accessor.h",
"chars": 13669,
"preview": "// Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n//\n// NVIDIA CORPORATION and its licensors retain a"
},
{
"path": "render/optixutils/c_src/bsdf.h",
"chars": 11222,
"preview": "// Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n//\n// NVIDIA CORPORATION and its licensors retain a"
},
{
"path": "render/optixutils/c_src/common.h",
"chars": 5311,
"preview": "// Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n//\n// NVIDIA CORPORATION and its licensors retain a"
},
{
"path": "render/optixutils/c_src/denoising.cu",
"chars": 5103,
"preview": "// Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n//\n// NVIDIA CORPORATION and its licensors retain a"
},
{
"path": "render/optixutils/c_src/denoising.h",
"chars": 790,
"preview": "// Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n//\n// NVIDIA CORPORATION and its licensors retain a"
},
{
"path": "render/optixutils/c_src/envsampling/kernel.cu",
"chars": 20280,
"preview": "// Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n//\n// NVIDIA CORPORATION and its licensors retain a"
},
{
"path": "render/optixutils/c_src/envsampling/params.h",
"chars": 2054,
"preview": "// Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n//\n// NVIDIA CORPORATION and its licensors retain a"
},
{
"path": "render/optixutils/c_src/math_utils.h",
"chars": 12168,
"preview": "// Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n//\n// NVIDIA CORPORATION and its licensors retain a"
},
{
"path": "render/optixutils/c_src/optix_wrapper.cpp",
"chars": 13359,
"preview": "// Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n//\n// NVIDIA CORPORATION and its licensors retain a"
},
{
"path": "render/optixutils/c_src/optix_wrapper.h",
"chars": 1088,
"preview": "// Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n//\n// NVIDIA CORPORATION and its licensors retain a"
},
{
"path": "render/optixutils/c_src/torch_bindings.cpp",
"chars": 14563,
"preview": "// Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n//\n// NVIDIA CORPORATION and its licensors retain a"
},
{
"path": "render/optixutils/include/internal/optix_7_device_impl.h",
"chars": 73134,
"preview": "/*\n* Copyright (c) 2021 NVIDIA Corporation. All rights reserved.\n*\n* NVIDIA Corporation and its licensors retain all in"
},
{
"path": "render/optixutils/include/internal/optix_7_device_impl_exception.h",
"chars": 15383,
"preview": "/*\n* Copyright (c) 2021 NVIDIA Corporation. All rights reserved.\n*\n* NVIDIA Corporation and its licensors retain all in"
},
{
"path": "render/optixutils/include/internal/optix_7_device_impl_transformations.h",
"chars": 17987,
"preview": "/*\n* Copyright (c) 2021 NVIDIA Corporation. All rights reserved.\n*\n* NVIDIA Corporation and its licensors retain all in"
},
{
"path": "render/optixutils/include/optix.h",
"chars": 1716,
"preview": "\n/*\n * Copyright (c) 2021 NVIDIA Corporation. All rights reserved.\n *\n * NVIDIA Corporation and its licensors retain al"
},
{
"path": "render/optixutils/include/optix_7_device.h",
"chars": 60451,
"preview": "/*\n* Copyright (c) 2021 NVIDIA Corporation. All rights reserved.\n*\n* NVIDIA Corporation and its licensors retain all in"
},
{
"path": "render/optixutils/include/optix_7_host.h",
"chars": 44487,
"preview": "/*\n * Copyright (c) 2021 NVIDIA Corporation. All rights reserved.\n *\n * NVIDIA Corporation and its licensors retain all"
},
{
"path": "render/optixutils/include/optix_7_types.h",
"chars": 70493,
"preview": "\n/*\n * Copyright (c) 2021 NVIDIA Corporation. All rights reserved.\n *\n * NVIDIA Corporation and its licensors retain al"
},
{
"path": "render/optixutils/include/optix_denoiser_tiling.h",
"chars": 13614,
"preview": "/*\n * Copyright (c) 2021 NVIDIA Corporation. All rights reserved.\n *\n * Redistribution and use in source and binary for"
},
{
"path": "render/optixutils/include/optix_device.h",
"chars": 2129,
"preview": "\n/*\n * Copyright (c) 2021 NVIDIA Corporation. All rights reserved.\n *\n * NVIDIA Corporation and its licensors retain al"
},
{
"path": "render/optixutils/include/optix_function_table.h",
"chars": 16836,
"preview": "/*\n * Copyright (c) 2021 NVIDIA Corporation. All rights reserved.\n *\n * NVIDIA Corporation and its licensors retain all"
},
{
"path": "render/optixutils/include/optix_function_table_definition.h",
"chars": 1827,
"preview": "/*\n * Copyright (c) 2021 NVIDIA Corporation. All rights reserved.\n *\n * NVIDIA Corporation and its licensors retain all"
},
{
"path": "render/optixutils/include/optix_host.h",
"chars": 1661,
"preview": "\n/*\n * Copyright (c) 2021 NVIDIA Corporation. All rights reserved.\n *\n * NVIDIA Corporation and its licensors retain al"
},
{
"path": "render/optixutils/include/optix_stack_size.h",
"chars": 17447,
"preview": "/*\n * Copyright (c) 2021 NVIDIA Corporation. All rights reserved.\n *\n * Redistribution and use in source and binary for"
},
{
"path": "render/optixutils/include/optix_stubs.h",
"chars": 28668,
"preview": "/*\n * Copyright (c) 2021 NVIDIA Corporation. All rights reserved.\n *\n * Redistribution and use in source and binary for"
},
{
"path": "render/optixutils/include/optix_types.h",
"chars": 1777,
"preview": "/*\n * Copyright (c) 2021 NVIDIA Corporation. All rights reserved.\n *\n * NVIDIA Corporation and its licensors retain all"
},
{
"path": "render/optixutils/ops.py",
"chars": 7155,
"preview": "# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all "
},
{
"path": "render/optixutils/tests/filter_test.py",
"chars": 4260,
"preview": "# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all "
},
{
"path": "render/regularizer.py",
"chars": 5832,
"preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
},
{
"path": "render/render.py",
"chars": 21371,
"preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n#\n# NVIDIA CORPORATION, its affiliates a"
},
{
"path": "render/renderutils/__init__.py",
"chars": 942,
"preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
},
{
"path": "render/renderutils/bsdf.py",
"chars": 6266,
"preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
},
{
"path": "render/renderutils/c_src/bsdf.cu",
"chars": 26451,
"preview": "/*\n * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n *\n * NVIDIA CORPORATION, its affili"
},
{
"path": "render/renderutils/c_src/bsdf.h",
"chars": 1564,
"preview": "/*\n * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n *\n * NVIDIA CORPORATION, its affili"
},
{
"path": "render/renderutils/c_src/common.cpp",
"chars": 2285,
"preview": "/*\n * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n *\n * NVIDIA CORPORATION, its affili"
},
{
"path": "render/renderutils/c_src/common.h",
"chars": 1218,
"preview": "/*\n * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n *\n * NVIDIA CORPORATION, its affili"
},
{
"path": "render/renderutils/c_src/cubemap.cu",
"chars": 13015,
"preview": "/*\n * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n *\n * NVIDIA CORPORATION, its affili"
},
{
"path": "render/renderutils/c_src/cubemap.h",
"chars": 907,
"preview": "/*\n * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n *\n * NVIDIA CORPORATION, its affili"
},
{
"path": "render/renderutils/c_src/loss.cu",
"chars": 7355,
"preview": "/*\n * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n *\n * NVIDIA CORPORATION, its affili"
},
{
"path": "render/renderutils/c_src/loss.h",
"chars": 896,
"preview": "/*\n * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n *\n * NVIDIA CORPORATION, its affili"
},
{
"path": "render/renderutils/c_src/mesh.cu",
"chars": 3973,
"preview": "/*\n * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n *\n * NVIDIA CORPORATION, its affili"
},
{
"path": "render/renderutils/c_src/mesh.h",
"chars": 696,
"preview": "/*\n * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n *\n * NVIDIA CORPORATION, its affili"
},
{
"path": "render/renderutils/c_src/normal.cu",
"chars": 7463,
"preview": "/*\n * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n *\n * NVIDIA CORPORATION, its affili"
},
{
"path": "render/renderutils/c_src/normal.h",
"chars": 786,
"preview": "/*\n * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n *\n * NVIDIA CORPORATION, its affili"
},
{
"path": "render/renderutils/c_src/tensor.h",
"chars": 4408,
"preview": "/*\n * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n *\n * NVIDIA CORPORATION, its affili"
},
{
"path": "render/renderutils/c_src/torch_bindings.cpp",
"chars": 41998,
"preview": "/*\n * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n *\n * NVIDIA CORPORATION, its affili"
},
{
"path": "render/renderutils/c_src/vec3f.h",
"chars": 4430,
"preview": "/*\n * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n *\n * NVIDIA CORPORATION, its affili"
},
{
"path": "render/renderutils/c_src/vec4f.h",
"chars": 831,
"preview": "/*\n * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n *\n * NVIDIA CORPORATION, its affili"
},
{
"path": "render/renderutils/loss.py",
"chars": 1675,
"preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
},
{
"path": "render/renderutils/ops.py",
"chars": 22087,
"preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
},
{
"path": "render/renderutils/tests/test_bsdf.py",
"chars": 13752,
"preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
},
{
"path": "render/renderutils/tests/test_loss.py",
"chars": 2225,
"preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
},
{
"path": "render/renderutils/tests/test_mesh.py",
"chars": 3433,
"preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
},
{
"path": "render/renderutils/tests/test_perf.py",
"chars": 2342,
"preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
},
{
"path": "render/texture.py",
"chars": 7853,
"preview": "# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all "
},
{
"path": "render/util.py",
"chars": 22370,
"preview": "# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all "
},
{
"path": "train_gflexicubes_deepfashion.py",
"chars": 35437,
"preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
},
{
"path": "train_gflexicubes_polycam.py",
"chars": 32260,
"preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
},
{
"path": "train_gshelltet_deepfashion.py",
"chars": 34461,
"preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
},
{
"path": "train_gshelltet_polycam.py",
"chars": 32617,
"preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
},
{
"path": "train_gshelltet_synthetic.py",
"chars": 33225,
"preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
}
]
About this extraction
This page contains the full source code of the lzzcd001/GShell GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 124 files (1.2 MB), approximately 348.0k tokens, and a symbol index with 975 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.