Full Code of lzzcd001/GShell for AI

Repository: lzzcd001/GShell
Branch: main
Commit: c2f0ba9ea01a
Files: 124
Total size: 1.2 MB

Directory structure:
gitextract_zb4dyoqx/

├── .gitignore
├── GMeshDiffusion/
│   ├── diffusion_configs/
│   │   ├── config_lower_occgrid_normalized.py
│   │   └── config_upper_occgrid_normalized.py
│   ├── lib/
│   │   ├── dataset/
│   │   │   ├── gshell_dataset.py
│   │   │   └── gshell_dataset_aug.py
│   │   └── diffusion/
│   │       ├── evaler.py
│   │       ├── likelihood.py
│   │       ├── losses.py
│   │       ├── models/
│   │       │   ├── __init__.py
│   │       │   ├── ema.py
│   │       │   ├── functional.py
│   │       │   ├── layers.py
│   │       │   ├── normalization.py
│   │       │   ├── unet3d_occgrid.py
│   │       │   └── utils.py
│   │       ├── sampling.py
│   │       ├── sde_lib.py
│   │       ├── trainer.py
│   │       ├── trainer_ddp.py
│   │       └── utils.py
│   ├── main_diffusion.py
│   ├── main_diffusion_ddp.py
│   ├── metadata/
│   │   ├── get_splits_lower.py
│   │   ├── get_splits_upper.py
│   │   ├── save_tet_info.py
│   │   └── tet_to_cubic_grid_dataset.py
│   └── scripts/
│       ├── run_eval_lower_occgrid_normalized.sh
│       ├── run_eval_upper_occgrid_normalized.sh
│       ├── run_lower_occgrid_normalized_ddp.sh
│       └── run_upper_occgrid_normalized_ddp.sh
├── README.md
├── configs/
│   ├── deepfashion_mc.json
│   ├── deepfashion_mc_256.json
│   ├── deepfashion_mc_512.json
│   ├── deepfashion_mc_80.json
│   ├── nerf_chair.json
│   ├── polycam_mc.json
│   ├── polycam_mc_128.json
│   └── polycam_mc_16samples.json
├── data/
│   └── tets/
│       └── generate_tets.py
├── dataset/
│   ├── __init__.py
│   ├── dataset.py
│   ├── dataset_deepfashion.py
│   ├── dataset_deepfashion_testset.py
│   ├── dataset_llff.py
│   ├── dataset_mesh.py
│   ├── dataset_nerf.py
│   └── dataset_nerf_colmap.py
├── denoiser/
│   └── denoiser.py
├── eval_gmeshdiffusion_generated_samples.py
├── geometry/
│   ├── embedding.py
│   ├── flexicubes_table.py
│   ├── gshell_flexicubes.py
│   ├── gshell_flexicubes_geometry.py
│   ├── gshell_tets.py
│   ├── gshell_tets_geometry.py
│   └── mlp.py
├── render/
│   ├── light.py
│   ├── material.py
│   ├── mesh.py
│   ├── mlptexture.py
│   ├── obj.py
│   ├── optixutils/
│   │   ├── __init__.py
│   │   ├── c_src/
│   │   │   ├── accessor.h
│   │   │   ├── bsdf.h
│   │   │   ├── common.h
│   │   │   ├── denoising.cu
│   │   │   ├── denoising.h
│   │   │   ├── envsampling/
│   │   │   │   ├── kernel.cu
│   │   │   │   └── params.h
│   │   │   ├── math_utils.h
│   │   │   ├── optix_wrapper.cpp
│   │   │   ├── optix_wrapper.h
│   │   │   └── torch_bindings.cpp
│   │   ├── include/
│   │   │   ├── internal/
│   │   │   │   ├── optix_7_device_impl.h
│   │   │   │   ├── optix_7_device_impl_exception.h
│   │   │   │   └── optix_7_device_impl_transformations.h
│   │   │   ├── optix.h
│   │   │   ├── optix_7_device.h
│   │   │   ├── optix_7_host.h
│   │   │   ├── optix_7_types.h
│   │   │   ├── optix_denoiser_tiling.h
│   │   │   ├── optix_device.h
│   │   │   ├── optix_function_table.h
│   │   │   ├── optix_function_table_definition.h
│   │   │   ├── optix_host.h
│   │   │   ├── optix_stack_size.h
│   │   │   ├── optix_stubs.h
│   │   │   └── optix_types.h
│   │   ├── ops.py
│   │   └── tests/
│   │       └── filter_test.py
│   ├── regularizer.py
│   ├── render.py
│   ├── renderutils/
│   │   ├── __init__.py
│   │   ├── bsdf.py
│   │   ├── c_src/
│   │   │   ├── bsdf.cu
│   │   │   ├── bsdf.h
│   │   │   ├── common.cpp
│   │   │   ├── common.h
│   │   │   ├── cubemap.cu
│   │   │   ├── cubemap.h
│   │   │   ├── loss.cu
│   │   │   ├── loss.h
│   │   │   ├── mesh.cu
│   │   │   ├── mesh.h
│   │   │   ├── normal.cu
│   │   │   ├── normal.h
│   │   │   ├── tensor.h
│   │   │   ├── torch_bindings.cpp
│   │   │   ├── vec3f.h
│   │   │   └── vec4f.h
│   │   ├── loss.py
│   │   ├── ops.py
│   │   └── tests/
│   │       ├── test_bsdf.py
│   │       ├── test_loss.py
│   │       ├── test_mesh.py
│   │       └── test_perf.py
│   ├── texture.py
│   └── util.py
├── train_gflexicubes_deepfashion.py
├── train_gflexicubes_polycam.py
├── train_gshelltet_deepfashion.py
├── train_gshelltet_polycam.py
└── train_gshelltet_synthetic.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
# lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# poetry
#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
#   in version control.
#   https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
#  and can be added to the global gitignore or merged into this file.  For a more nuclear
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

.DS_STORE


================================================
FILE: GMeshDiffusion/diffusion_configs/config_lower_occgrid_normalized.py
================================================
import ml_collections
import torch
import os


def get_config():
    config = ml_collections.ConfigDict()

    # data
    data = config.data = ml_collections.ConfigDict()
    data.root_dir = 'PLACEHOLDER'
    # data.dataset_metapath = os.path.join(data.root_dir, 'metadata/lower_res64_train.txt')
    data.num_workers = 4
    data.grid_size = 128
    data.tet_resolution = 64
    data.num_channels = 4
    data.use_occ_grid = True
    data.grid_metafile = os.path.join(data.root_dir, 'metadata/lower_res64_grid_train.txt')
    data.occgrid_metafile = os.path.join(data.root_dir, 'metadata/lower_res64_occgrid_train.txt')

    data.occ_mask_path = os.path.join(data.root_dir, 'metadata/occ_mask_res64.pt')
    data.tet_info_path = os.path.join(data.root_dir, 'metadata/tet_info.pt')

    data.filter_meta_path = None
    data.aug = True

    # training
    training = config.training = ml_collections.ConfigDict()
    training.sde = 'vpsde'
    training.continuous = False
    training.reduce_mean = True
    training.batch_size = 1 ### for DDP, global_batch_size = nproc * local_batch_size
    training.num_grad_acc_steps = 4 
    training.n_iters = 2400001
    training.snapshot_freq = 1000
    training.log_freq = 50
    ## produce samples at each snapshot.
    training.snapshot_sampling = True
    training.likelihood_weighting = False
    training.loss_type = 'l2'
    training.train_dir = "PLACEHOLDER"
    training.snapshot_freq_for_preemption = 1000
    training.gradscaler_growth_interval = 1000
    training.use_aux_loss = False


    training.compile = True # PyTorch 2.0, torch.compile
    training.enable_xformers_memory_efficient_attention = True

    # sampling
    sampling = config.sampling = ml_collections.ConfigDict()
    sampling.method = 'pc'
    sampling.predictor = 'ancestral_sampling'
    sampling.corrector = 'none'
    sampling.n_steps_each = 1
    sampling.noise_removal = True
    sampling.probability_flow = False
    sampling.snr = 0.075


    # model
    model = config.model = ml_collections.ConfigDict()
    model.name = 'unet3d_occgrid'
    model.use_occ_grid = True
    model.num_res_blocks = 2
    model.num_res_blocks_1st_layer = 2
    model.base_channels = 128
    model.ch_mult = (1, 2, 2, 4, 4, 4)
    model.down_block_types = (
        "ResBlock", "ResBlock", "ResBlock", "AttnResBlock", "ResBlock", "ResBlock"
    )
    model.up_block_types = (
       "ResBlock", "ResBlock", "AttnResBlock", "ResBlock", "ResBlock", "ResBlock"
    )
    model.scale_by_sigma = False
    model.num_scales = 1000
    model.ema_rate = 0.9999
    model.normalization = 'GroupNorm'
    model.act_fn = 'swish'
    model.attn_resolutions = (16,)
    model.resamp_with_conv = True
    model.dropout = 0.1
    model.sigma_max = 378
    model.sigma_min = 0.01
    model.beta_min = 0.1
    model.beta_max = 20.
    model.embedding_type = 'fourier'
    model.pred_type = 'noise'
    model.conditional = True

    model.feature_mask_path = os.path.join(data.root_dir, 'metadata/global_mask_res64.pt')
    model.pixcat_mask_path = os.path.join(data.root_dir, 'metadata/cat_mask_res64.pt')

    # optimization
    config.optim = optim = ml_collections.ConfigDict()
    optim.weight_decay = 1e-5
    optim.optimizer = 'AdamW'
    optim.lr = 1e-5
    optim.beta1 = 0.9
    optim.eps = 1e-8
    optim.warmup = 5000
    optim.grad_clip = 1.

    # eval
    config.eval = eval_config = ml_collections.ConfigDict()
    eval_config.batch_size = 2
    eval_config.idx = 0
    eval_config.bin_size = 30
    eval_config.eval_dir = "PLACEHOLDER"
    eval_config.ckpt_path = "PLACEHOLDER"
    

    config.seed = 42
    config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')


    return config


================================================
FILE: GMeshDiffusion/diffusion_configs/config_upper_occgrid_normalized.py
================================================
import ml_collections
import torch
import os


def get_config():
    config = ml_collections.ConfigDict()

    # data
    data = config.data = ml_collections.ConfigDict()
    data.root_dir = 'PLACEHOLDER'
    # data.dataset_metapath = os.path.join(data.root_dir, 'metadata/upper_res64_train.txt')
    data.num_workers = 4
    data.grid_size = 128
    data.tet_resolution = 64
    data.num_channels = 4
    data.use_occ_grid = True
    data.grid_metafile = os.path.join(data.root_dir, 'metadata/upper_res64_grid_train.txt')
    data.occgrid_metafile = os.path.join(data.root_dir, 'metadata/upper_res64_occgrid_train.txt')

    data.occ_mask_path = os.path.join(data.root_dir, 'metadata/occ_mask_res64.pt')
    data.tet_info_path = os.path.join(data.root_dir, 'metadata/tet_info.pt')

    data.filter_meta_path = None
    data.aug = True

    # training
    training = config.training = ml_collections.ConfigDict()
    training.sde = 'vpsde'
    training.continuous = False
    training.reduce_mean = True
    training.batch_size = 1 ### for DDP, global_batch_size = nproc * local_batch_size
    training.num_grad_acc_steps = 4 
    training.n_iters = 2400001
    training.snapshot_freq = 1000
    training.log_freq = 50
    ## produce samples at each snapshot.
    training.snapshot_sampling = True
    training.likelihood_weighting = False
    training.loss_type = 'l2'
    training.train_dir = "PLACEHOLDER"
    training.snapshot_freq_for_preemption = 1000
    training.gradscaler_growth_interval = 1000
    training.use_aux_loss = False


    training.compile = True # PyTorch 2.0, torch.compile
    training.enable_xformers_memory_efficient_attention = True

    # sampling
    sampling = config.sampling = ml_collections.ConfigDict()
    sampling.method = 'pc'
    sampling.predictor = 'ancestral_sampling'
    sampling.corrector = 'none'
    sampling.n_steps_each = 1
    sampling.noise_removal = True
    sampling.probability_flow = False
    sampling.snr = 0.075


    # model
    model = config.model = ml_collections.ConfigDict()
    model.name = 'unet3d_occgrid'
    model.use_occ_grid = True
    model.num_res_blocks = 2
    model.num_res_blocks_1st_layer = 2
    model.base_channels = 128
    model.ch_mult = (1, 2, 2, 4, 4, 4)
    model.down_block_types = (
        "ResBlock", "ResBlock", "ResBlock", "AttnResBlock", "ResBlock", "ResBlock"
    )
    model.up_block_types = (
       "ResBlock", "ResBlock", "AttnResBlock", "ResBlock", "ResBlock", "ResBlock"
    )
    model.scale_by_sigma = False
    model.num_scales = 1000
    model.ema_rate = 0.9999
    model.normalization = 'GroupNorm'
    model.act_fn = 'swish'
    model.attn_resolutions = (16,)
    model.resamp_with_conv = True
    model.dropout = 0.1
    model.sigma_max = 378
    model.sigma_min = 0.01
    model.beta_min = 0.1
    model.beta_max = 20.
    model.embedding_type = 'fourier'
    model.pred_type = 'noise'
    model.conditional = True

    model.feature_mask_path = os.path.join(data.root_dir, 'metadata/global_mask_res64_occaug_normalized_v1.pt')
    model.pixcat_mask_path = os.path.join(data.root_dir, 'metadata/cat_mask_res64_occaug_normalized_v1.pt')

    # optimization
    config.optim = optim = ml_collections.ConfigDict()
    optim.weight_decay = 1e-5
    optim.optimizer = 'AdamW'
    optim.lr = 1e-5
    optim.beta1 = 0.9
    optim.eps = 1e-8
    optim.warmup = 5000
    optim.grad_clip = 1.

    # eval
    config.eval = eval_config = ml_collections.ConfigDict()
    eval_config.batch_size = 2
    eval_config.idx = 0
    eval_config.bin_size = 30
    eval_config.eval_dir = "PLACEHOLDER"
    eval_config.ckpt_path = "PLACEHOLDER"
    

    config.seed = 42
    config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')


    return config


================================================
FILE: GMeshDiffusion/lib/dataset/gshell_dataset.py
================================================
import torch
import numpy as np
from torch.utils.data import Dataset

class GShellDataset(Dataset):
    def __init__(self, filepath_metafile, extension='pt'):
        super().__init__()
        with open(filepath_metafile, 'r') as f:
            self.filepath_list = [fpath.rstrip() for fpath in f]

        self.extension = extension
        assert self.extension in ['pt', 'npy']
    
    def __len__(self):
        return len(self.filepath_list)

    def __getitem__(self, idx):
        with torch.no_grad():
            if self.extension == 'pt':
                datum = torch.load(self.filepath_list[idx], map_location='cpu')
            else:
                datum = torch.tensor(np.load(self.filepath_list[idx]))
        return datum
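
The metafile read in `__init__` above is assumed to be a plain text file listing one tensor-file path per line; each line is `rstrip`-ped to drop the trailing newline. A minimal sketch of that parsing step, using an in-memory file in place of a real metafile (the paths below are hypothetical):

```python
import io

# Stand-in for open(filepath_metafile, 'r'): one path per line.
metafile = io.StringIO("data/grid_000.pt\ndata/grid_001.pt\n")
filepath_list = [fpath.rstrip() for fpath in metafile]
print(filepath_list)  # ['data/grid_000.pt', 'data/grid_001.pt']
```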


================================================
FILE: GMeshDiffusion/lib/dataset/gshell_dataset_aug.py
================================================
import torch
from torch.utils.data import Dataset

class GShellAugDataset(Dataset):
    def __init__(self, FLAGS, extension='pt'):
        super().__init__()
        with open(FLAGS.data.grid_metafile, 'r') as f:
            self.filepath_list = [fpath.rstrip() for fpath in f]
        with open(FLAGS.data.occgrid_metafile, 'r') as f:
            self.occ_filepath_list = [fpath.rstrip() for fpath in f]

        self.extension = extension
        self.num_channels = FLAGS.data.num_channels
        print('num_channels: ', self.num_channels)
        assert self.extension in ['pt', 'npy']
    
    def __len__(self):
        return len(self.filepath_list)

    def __getitem__(self, idx):
        with torch.no_grad():
            grid = torch.load(self.filepath_list[idx], map_location='cpu')
            try:
                occ_grid = torch.load(self.occ_filepath_list[idx], map_location='cpu')
            except Exception:
                print(self.occ_filepath_list[idx])
                raise
        return (grid[:self.num_channels], occ_grid)
    
    @staticmethod
    def collate(data):
        return {
            'grid': torch.stack([x[0] for x in data]),
            'occgrid': torch.stack([x[1] for x in data]),
        }
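
The static `collate` above stacks per-sample `(grid, occ_grid)` pairs into a batched dict. A self-contained sketch of its effect; the tensor sizes here are illustrative only (4-channel 8³ grids with 16³ occupancy grids, not the dataset's real resolutions):

```python
import torch

def collate(data):
    # Stack matching elements of each (grid, occ_grid) pair along a new batch dim.
    return {
        'grid': torch.stack([x[0] for x in data]),
        'occgrid': torch.stack([x[1] for x in data]),
    }

samples = [(torch.zeros(4, 8, 8, 8), torch.zeros(1, 16, 16, 16)) for _ in range(2)]
batch = collate(samples)
print(batch['grid'].shape)     # torch.Size([2, 4, 8, 8, 8])
print(batch['occgrid'].shape)  # torch.Size([2, 1, 16, 16, 16])
```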


================================================
FILE: GMeshDiffusion/lib/diffusion/evaler.py
================================================
import os
import sys
import numpy as np
import tqdm

import logging
from . import losses
from .models import utils as mutils
from .models.ema import ExponentialMovingAverage
from . import sde_lib
import torch
from .utils import restore_checkpoint
from . import sampling

def uncond_gen(
        config
    ):
    """
        Unconditional Generation
    """
    with torch.no_grad():
        eval_dir, ckpt_path = config.eval.eval_dir, config.eval.ckpt_path
        idx = config.eval.idx
        bin_size = config.eval.bin_size
        print(f"idx to save: {idx * bin_size} to {idx * bin_size + bin_size - 1}")
        # Create directory to eval_folder
        os.makedirs(eval_dir, exist_ok=True)

        scaler, inverse_scaler = lambda x: x, lambda x: x

        # Initialize model
        score_model = mutils.create_model(config, use_parallel=False)
        optimizer = losses.get_optimizer(config, score_model.parameters())
        ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate)
        state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0)

        # Setup SDEs
        sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)

        sampling_eps = 1e-3
        sampling_shape = (config.eval.batch_size,
                        config.data.num_channels,
                        config.data.grid_size, config.data.grid_size, config.data.grid_size)
        sampling_fn = sampling.get_sampling_fn(config, sde, sampling_shape, inverse_scaler, sampling_eps)

        assert os.path.exists(ckpt_path)
        print('ckpt path:', ckpt_path)
        state = restore_checkpoint(ckpt_path, state, device=config.device)
        ema.copy_to(score_model.parameters())

        print(f"loaded model is trained till iter {state['step'] // config.training.num_grad_acc_steps}")


        for k in range(bin_size):
            save_file_path = os.path.join(eval_dir, f"{idx * bin_size + k}")
            print(f'check: {save_file_path}')
            if os.path.exists(save_file_path + '.pt'):
                # continue
                pass
            print(f'will save to: {save_file_path}')
            samples, n = sampling_fn(score_model)
            if not isinstance(samples, tuple):
                print(samples[:, 0].unique())
                torch.save(samples, save_file_path + '.pt')
                samples = samples.cpu().numpy()
                # np.save(save_file_path, samples)
            else:
                print(samples[0][:, 0].unique())
                torch.save(samples[0], save_file_path + '.pt')
                torch.save(samples[1], save_file_path + '_occ.pt')
                # samples, occ = samples[0].cpu().numpy(), samples[1].cpu().numpy()
            # np.save(save_file_path + '.npy', samples)


def slerp(z1, z2, alpha):
    '''
        Spherical Linear Interpolation
    '''
    theta = torch.acos(torch.sum(z1 * z2) / (torch.norm(z1) * torch.norm(z2)))
    return (
            torch.sin((1 - alpha) * theta) / torch.sin(theta) * z1
            + torch.sin(alpha * theta) / torch.sin(theta) * z2
    )

def uncond_gen_interp(
        config,
        idx=0,
    ):
    """
        Generation with interpolation between initial noises
        Used for DDIM
    """
    with torch.no_grad():
        eval_dir, ckpt_path = config.eval.eval_dir, config.eval.ckpt_path
        # Create directory to eval_folder
        os.makedirs(eval_dir, exist_ok=True)

        scaler, inverse_scaler = lambda x: x, lambda x: x

        # Initialize model
        score_model = mutils.create_model(config, use_parallel=False)
        optimizer = losses.get_optimizer(config, score_model.parameters())
        ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate)
        state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0)

        # Setup SDEs
        sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)

        sampling_eps = 1e-3
        sampling_shape = (config.eval.batch_size,
                        config.data.num_channels,
                        config.data.grid_size, config.data.grid_size, config.data.grid_size)
        sampling_fn = sampling.get_sampling_fn(config, sde, sampling_shape, inverse_scaler, sampling_eps)

        assert os.path.exists(ckpt_path)
        print('ckpt path:', ckpt_path)
        state = restore_checkpoint(ckpt_path, state, device=config.device)
        ema.copy_to(score_model.parameters())

        print(f"loaded model is trained till iter {state['step'] // config.training.num_grad_acc_steps}")


        idx = config.eval.idx
        bin_size = config.eval.bin_size
        config.eval.interp_batch_size = 32
        print(f"idx to save: {idx * bin_size} to {idx * bin_size + bin_size - 1}")

        for k in range(bin_size):
            save_file_path = os.path.join(eval_dir, f"{idx * bin_size + k}")

            noise = sde.prior_sampling(
                (2, config.data.num_channels, config.data.grid_size, config.data.grid_size, config.data.grid_size)
            ).to(config.device)
        
            interp_sampling_shape = (config.eval.interp_batch_size,
                            config.data.num_channels,
                            config.data.grid_size, config.data.grid_size, config.data.grid_size)
            x0 = torch.zeros(interp_sampling_shape, device=config.device)
            x0[0] = noise[0]
            x0[-1] = noise[1]
            for i in range(1, config.eval.interp_batch_size - 1):
                x0[i] = slerp(x0[0], x0[-1], i / float(config.eval.interp_batch_size - 1))

            if config.model.use_occ_grid:
                noise_occ = sde.prior_sampling(
                    (2, 1, config.data.grid_size * 2, config.data.grid_size * 2, config.data.grid_size * 2)
                ).to(config.device)
                interp_sampling_shape = (config.eval.interp_batch_size,
                                1,
                                config.data.grid_size * 2, config.data.grid_size * 2, config.data.grid_size * 2)
                x0_occ = torch.zeros(interp_sampling_shape, device=config.device)
                x0_occ[0] = noise_occ[0]
                x0_occ[-1] = noise_occ[1]
                for i in range(1, config.eval.interp_batch_size - 1):
                    x0_occ[i] = slerp(x0_occ[0], x0_occ[-1], i / float(config.eval.interp_batch_size - 1))
            else:
                x0_occ = None

            sample_list = []
            sample_occ_list = []
            for i in tqdm.trange(config.eval.interp_batch_size):
                samples, n = sampling_fn(score_model, x0=x0[i:i+1], x0_occ=x0_occ[i:i+1])
                if not isinstance(samples, tuple):
                    # samples = samples.cpu()
                    sample_list.append(samples.cpu())
                else:
                    # samples = samples.cpu()
                    sample_list.append(samples[0].cpu())
                    sample_occ_list.append(samples[1].cpu())

            # np.save(save_file_path, np.concatenate(sample_list, axis=0))
            torch.save(torch.cat(sample_list, dim=0), save_file_path + '.pt')
            if config.model.use_occ_grid:
                torch.save(torch.cat(sample_occ_list, dim=0), save_file_path + '_occ.pt')


def cond_gen(
        config,
        save_fname='0',
    ):
    """
        Conditional Generation with partially completed dmtet from a 2.5D view (converted into a cubic grid)
    """
    with torch.no_grad():
        eval_dir, ckpt_path = config.eval.eval_dir, config.eval.ckpt_path
        # Create directory to eval_folder
        os.makedirs(eval_dir, exist_ok=True)

        scaler, inverse_scaler = lambda x: x, lambda x: x

        # Initialize model
        score_model = mutils.create_model(config)
        optimizer = losses.get_optimizer(config, score_model.parameters())
        ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate)
        state = dict(optimizer=optimizer, model=score_model, ema=ema, step=0)

        # Setup SDEs
        sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)

        resolution = config.data.image_size
        grid_mask = torch.load(f'./data/grid_mask_{resolution}.pt').view(1, 1, resolution, resolution, resolution).to("cuda")
        grid_mask = grid_mask[:, :, :config.data.input_size, :config.data.input_size, :config.data.input_size]

        sampling_eps = 1e-3
        sampling_shape = (config.eval.batch_size,
                        config.data.num_channels,
                        # resolution, resolution, resolution)
                        config.data.input_size, config.data.input_size, config.data.input_size)
        sampling_fn = sampling.get_sampling_fn(config, sde, sampling_shape, inverse_scaler, sampling_eps, grid_mask=grid_mask)

        assert os.path.exists(ckpt_path)
        print('ckpt path:', ckpt_path)
        state = restore_checkpoint(ckpt_path, state, device=config.device)
        ema.copy_to(score_model.parameters())

        print(f"loaded model is trained till iter {state['step'] // config.training.iter_size}")

        
        save_file_path = os.path.join(eval_dir, f"{save_fname}.npy")

        ### Conditional but free gradients; start from small t

        partial_dict = torch.load(config.eval.partial_dmtet_path)
        partial_sdf = partial_dict['sdf']
        partial_mask = partial_dict['vis']


        ### compute the mapping from tet indices to 3D cubic grid vertex indices
        tet_path = config.eval.tet_path
        tet = np.load(tet_path)
        vertices = torch.tensor(tet['vertices'])
        vertices_unique = vertices[:].unique()
        dx = vertices_unique[1] - vertices_unique[0]

        ind_to_coord = (torch.round(
            (vertices - vertices.min()) / dx)
        ).long()

        
        partial_sdf_grid = torch.zeros((1, 1, resolution, resolution, resolution))
        partial_sdf_grid[0, 0, ind_to_coord[:, 0], ind_to_coord[:, 1], ind_to_coord[:, 2]] = partial_sdf
        partial_mask_grid = torch.zeros((1, 1, resolution, resolution, resolution))
        partial_mask_grid[0, 0, ind_to_coord[:, 0], ind_to_coord[:, 1], ind_to_coord[:, 2]] = partial_mask.float()

        samples, n = sampling_fn(
            score_model, 
            partial=partial_sdf_grid.cuda(), 
            partial_mask=partial_mask_grid.cuda(), 
            freeze_iters=config.eval.freeze_iters
        )

        samples = samples.cpu().numpy()
        np.save(save_file_path, samples)



================================================
FILE: GMeshDiffusion/lib/diffusion/likelihood.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: skip-file
# pytype: skip-file
"""Various sampling methods."""

import torch
import numpy as np
from scipy import integrate
from .models import utils as mutils


def get_div_fn(fn):
  """Create the divergence function of `fn` using the Hutchinson-Skilling trace estimator."""

  def div_fn(x, t, eps):
    with torch.enable_grad():
      x.requires_grad_(True)
      fn_eps = torch.sum(fn(x, t) * eps)
      grad_fn_eps = torch.autograd.grad(fn_eps, x)[0]
    x.requires_grad_(False)
    return torch.sum(grad_fn_eps * eps, dim=tuple(range(1, len(x.shape))))

  return div_fn
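
The estimator above computes `eps^T J eps`, whose expectation over Rademacher probes is the trace of the Jacobian, i.e. the divergence. A quick check on a toy linear map `f(x) = x @ A.T` (so the exact divergence is `trace(A)`); `A` and the probe count are illustrative:

```python
import torch

torch.manual_seed(0)
A = torch.randn(5, 5)
fn = lambda x, t: x @ A.T  # Jacobian of fn w.r.t. x is A

def div_fn(x, t, eps):
    with torch.enable_grad():
        x.requires_grad_(True)
        fn_eps = torch.sum(fn(x, t) * eps)
        grad_fn_eps = torch.autograd.grad(fn_eps, x)[0]
    x.requires_grad_(False)
    return torch.sum(grad_fn_eps * eps, dim=tuple(range(1, len(x.shape))))

x = torch.randn(1, 5)
estimates = torch.stack([
    div_fn(x, None, torch.randint_like(x, 0, 2).float() * 2 - 1)
    for _ in range(2000)
])
print(estimates.mean().item(), torch.trace(A).item())  # the two should be close
```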


def get_likelihood_fn(sde, inverse_scaler, hutchinson_type='Rademacher',
                      rtol=1e-5, atol=1e-5, method='RK45', eps=1e-5):
  """Create a function to compute the unbiased log-likelihood estimate of a given data point.

  Args:
    sde: A `sde_lib.SDE` object that represents the forward SDE.
    inverse_scaler: The inverse data normalizer.
    hutchinson_type: "Rademacher" or "Gaussian". The type of noise for Hutchinson-Skilling trace estimator.
    rtol: A `float` number. The relative tolerance level of the black-box ODE solver.
    atol: A `float` number. The absolute tolerance level of the black-box ODE solver.
    method: A `str`. The algorithm for the black-box ODE solver.
      See documentation for `scipy.integrate.solve_ivp`.
    eps: A `float` number. The probability flow ODE is integrated to `eps` for numerical stability.

  Returns:
    A function that takes a batch of data points and returns the log-likelihoods in bits/dim,
      the latent code, and the number of function evaluations cost by computation.
  """

  def drift_fn(model, x, t):
    """The drift function of the reverse-time SDE."""
    score_fn = mutils.get_score_fn(sde, model, train=False, continuous=True)
    # Probability flow ODE is a special case of Reverse SDE
    rsde = sde.reverse(score_fn, probability_flow=True)
    return rsde.sde(x, t)[0]

  def div_fn(model, x, t, noise):
    return get_div_fn(lambda xx, tt: drift_fn(model, xx, tt))(x, t, noise)

  def likelihood_fn(model, data):
    """Compute an unbiased estimate to the log-likelihood in bits/dim.

    Args:
      model: A score model.
      data: A PyTorch tensor.

    Returns:
      bpd: A PyTorch tensor of shape [batch size]. The log-likelihoods on `data` in bits/dim.
      z: A PyTorch tensor of the same shape as `data`. The latent representation of `data` under the
        probability flow ODE.
      nfe: An integer. The number of function evaluations used for running the black-box ODE solver.
    """
    with torch.no_grad():
      shape = data.shape
      if hutchinson_type == 'Gaussian':
        epsilon = torch.randn_like(data)
      elif hutchinson_type == 'Rademacher':
        epsilon = torch.randint_like(data, low=0, high=2).float() * 2 - 1.
      else:
        raise NotImplementedError(f"Hutchinson type {hutchinson_type} unknown.")

      def ode_func(t, x):
        sample = mutils.from_flattened_numpy(x[:-shape[0]], shape).to(data.device).type(torch.float32)
        vec_t = torch.ones(sample.shape[0], device=sample.device) * t
        drift = mutils.to_flattened_numpy(drift_fn(model, sample, vec_t))
        logp_grad = mutils.to_flattened_numpy(div_fn(model, sample, vec_t, epsilon))
        return np.concatenate([drift, logp_grad], axis=0)

      init = np.concatenate([mutils.to_flattened_numpy(data), np.zeros((shape[0],))], axis=0)
      solution = integrate.solve_ivp(ode_func, (eps, sde.T), init, rtol=rtol, atol=atol, method=method)
      nfe = solution.nfev
      zp = solution.y[:, -1]
      z = mutils.from_flattened_numpy(zp[:-shape[0]], shape).to(data.device).type(torch.float32)
      delta_logp = mutils.from_flattened_numpy(zp[-shape[0]:], (shape[0],)).to(data.device).type(torch.float32)
      prior_logp = sde.prior_logp(z)
      bpd = -(prior_logp + delta_logp) / np.log(2)
      N = np.prod(shape[1:])
      bpd = bpd / N
      # A hack to convert log-likelihoods to bits/dim
      offset = 7. - inverse_scaler(-1.)
      bpd = bpd + offset
      return bpd, z, nfe

  return likelihood_fn
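The Rademacher/Gaussian `epsilon` above drives a Hutchinson-Skilling trace estimate of the drift's divergence, using the identity `E[eps^T J eps] = trace(J)`. A minimal standalone NumPy sketch of that estimator (illustrative only, not part of this repo):

```python
import numpy as np

def hutchinson_trace(matvec, dim, n_samples=5000, seed=0):
    """Estimate trace(A) from matrix-vector products only,
    via E[eps^T A eps] with Rademacher probe vectors eps."""
    rng = np.random.default_rng(seed)
    total = 0.0
    for _ in range(n_samples):
        # Rademacher probe: entries are +1 or -1 with equal probability
        eps = rng.integers(0, 2, size=dim).astype(np.float64) * 2.0 - 1.0
        total += eps @ matvec(eps)
    return total / n_samples

A = np.array([[1.0, 0.5], [0.5, 2.0]])
est = hutchinson_trace(lambda v: A @ v, dim=2)
# the exact trace is 3.0; the estimate concentrates around it
```

In `likelihood.py` the matrix-vector product is replaced by a vector-Jacobian product of the ODE drift, so the divergence is estimated with a single probe per data point.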


================================================
FILE: GMeshDiffusion/lib/diffusion/losses.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""All functions related to loss computation and optimization.
"""

import torch
import torch.optim as optim
import numpy as np
from .models import utils as mutils


def get_optimizer(config, params):
  """Returns a flax optimizer object based on `config`."""
  if config.optim.optimizer == 'Adam':
    optimizer = optim.Adam(params, lr=config.optim.lr, betas=(config.optim.beta1, 0.999), eps=config.optim.eps,
                           weight_decay=config.optim.weight_decay)
  elif config.optim.optimizer == 'AdamW':
    optimizer = optim.AdamW(params, lr=config.optim.lr, betas=(config.optim.beta1, 0.999), eps=config.optim.eps,
                           weight_decay=config.optim.weight_decay)
  else:
    raise NotImplementedError(
      f'Optimizer {config.optim.optimizer} not supported yet!')

  return optimizer


def optimization_manager(config):
  """Returns an optimize_fn based on `config`."""

  def optimize_fn(optimizer, params, step, lr=config.optim.lr,
                  warmup=config.optim.warmup,
                  grad_clip=config.optim.grad_clip,
                  gradscaler=None):
    """Optimizes with warmup and gradient clipping (disabled if negative)."""
    if warmup > 0:
      for g in optimizer.param_groups:
        g['lr'] = lr * np.minimum(step / warmup, 1.0)
    if grad_clip >= 0:
      gradscaler.unscale_(optimizer)
      torch.nn.utils.clip_grad_norm_(params, max_norm=grad_clip)
    gradscaler.step(optimizer)  # optimizer.step() is invoked through the GradScaler
    gradscaler.update()

  return optimize_fn
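The linear warmup in `optimize_fn` scales the learning rate by `min(step / warmup, 1)` before each update. A tiny standalone sketch of that schedule (the helper name is ours, for illustration):

```python
def warmup_lr(base_lr, step, warmup):
    # linear warmup: ramp from 0 to base_lr over `warmup` steps, then hold flat
    if warmup > 0:
        return base_lr * min(step / warmup, 1.0)
    return base_lr

schedule = [warmup_lr(1e-3, s, warmup=4) for s in range(0, 8, 2)]
# steps 0, 2, 4, 6 ramp as 0.0, 5e-4, 1e-3, then stay at 1e-3
```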

def get_ddpm_loss_fn(vpsde, train, loss_type='l2', pred_type='noise', use_vis_mask=False, use_occ=False, use_aux=False):
  """Legacy code to reproduce previous results on DDPM. Not recommended for new work."""


  if use_occ:
    def loss_fn(model, batch, use_mesh_reg=False, verts_discretized=None, midpoints_discretized=None, edges=None):
      batch, batch_occ = batch['grid'], batch['occgrid']
      model_fn = mutils.get_model_fn(model, train=train)
      labels = torch.randint(0, vpsde.N, (batch.shape[0],), device=batch.device)
      sqrt_alphas_cumprod = vpsde.sqrt_alphas_cumprod.to(batch.device)
      sqrt_1m_alphas_cumprod = vpsde.sqrt_1m_alphas_cumprod.to(batch.device)
      with torch.no_grad():
        noise = torch.randn_like(batch, device=batch.device)
        noise_occ = torch.randn_like(batch_occ, device=batch.device)
        perturbed_data = sqrt_alphas_cumprod[labels, None, None, None, None] * batch + \
                        sqrt_1m_alphas_cumprod[labels, None, None, None, None] * noise
        perturbed_data = perturbed_data.type(batch.dtype)
        perturbed_data_occ = sqrt_alphas_cumprod[labels, None, None, None, None] * batch_occ + \
                        sqrt_1m_alphas_cumprod[labels, None, None, None, None] * noise_occ
        perturbed_data_occ = perturbed_data_occ.type(batch_occ.dtype)


      with torch.cuda.amp.autocast(dtype=torch.bfloat16):
        pred, pred_occ = model_fn((perturbed_data, perturbed_data_occ), labels)
    
      pred, pred_occ = pred.float(), pred_occ.float()
      alphas1 = sqrt_alphas_cumprod[labels, None, None, None, None]
      alphas2 = sqrt_1m_alphas_cumprod[labels, None, None, None, None]
      if pred_type == 'noise':
        score = pred
        score_occ = pred_occ
        x0 = (perturbed_data - score * alphas2) / alphas1
        x0_occ = (perturbed_data_occ - score_occ * alphas2) / alphas1
      elif pred_type == 'x0':
        x0 = pred
        x0_occ = pred_occ
        score = (perturbed_data - x0 * alphas1) / alphas2
        score_occ = (perturbed_data_occ - pred_occ * alphas1) / alphas2
      
      # noise = noise[:, :, :score.size(2), :score.size(3), :score.size(4)] ### to accommodate change of size due to arch
      if loss_type == 'l2':
        losses = torch.square(score - noise)
        losses_occ = torch.square(score_occ - noise_occ)
        assert losses_occ.size(1) == 1
      elif loss_type == 'l1':
        raise NotImplementedError('l1 loss is not supported in the occupancy branch')
      else:
        raise NotImplementedError(f'loss type {loss_type} unknown')

      mask = model.module.feature_mask
      occ_mask = model.module.occ_mask
      if mask is None:
        raise NotImplementedError
      assert len(mask.size()) == 5
      assert mask.size(1) == losses.size(1)
      assert occ_mask.size(1) == losses_occ.size(1)
      assert losses.size(0) == losses_occ.size(0)
      losses = losses * mask
      losses_occ = losses_occ * occ_mask
      loss = (torch.sum(losses) + torch.sum(losses_occ)) / (mask.sum() + occ_mask.sum()) / losses.size(0)
        
      if use_aux:
        pred_vis = model.module.extract_vis_from_cubicgrid(x0, x0_occ)
        with torch.no_grad():
          gt_vis = model.module.extract_vis_from_cubicgrid(batch, batch_occ.view(*x0_occ.size()))
        reg_loss = (
          (pred_vis - gt_vis).pow(2).view(x0.size(0), -1).mean(dim=-1) * sqrt_alphas_cumprod[labels]
        ).mean()
      else:
        reg_loss = torch.zeros_like(loss)
      total_loss = loss + reg_loss

      return total_loss, loss, reg_loss
  else:
    def loss_fn(model, batch, use_mesh_reg=False, verts_discretized=None, midpoints_discretized=None, edges=None):
      model_fn = mutils.get_model_fn(model, train=train)
      labels = torch.randint(0, vpsde.N, (batch.shape[0],), device=batch.device)
      sqrt_alphas_cumprod = vpsde.sqrt_alphas_cumprod.to(batch.device)
      sqrt_1m_alphas_cumprod = vpsde.sqrt_1m_alphas_cumprod.to(batch.device)
      noise = torch.randn_like(batch, device=batch.device)
      perturbed_data = sqrt_alphas_cumprod[labels, None, None, None, None] * batch + \
                      sqrt_1m_alphas_cumprod[labels, None, None, None, None] * noise
      perturbed_data = perturbed_data.type(batch.dtype)

      with torch.cuda.amp.autocast(dtype=torch.bfloat16):
        pred = model_fn(perturbed_data, labels)
      pred = pred.float()
      alphas1 = sqrt_alphas_cumprod[labels, None, None, None, None]
      alphas2 = sqrt_1m_alphas_cumprod[labels, None, None, None, None]
      if pred_type == 'noise':
        score = pred
        x0 = (perturbed_data - score * alphas2) / alphas1
      elif pred_type == 'x0':
        x0 = pred
        score = (perturbed_data - x0 * alphas1) / alphas2
      
      if use_vis_mask:
        assert x0.size(0) == 1
        vis_mask = model.extract_vismask_from_cubicgrid(x0)
        # noise = noise[:, :, :score.size(2), :score.size(3), :score.size(4)] ### to accommodate change of size due to arch
        if loss_type == 'l2':
          losses = torch.square((score - noise) * vis_mask)
        elif loss_type == 'l1':
          losses = torch.abs((score - noise) * vis_mask)
        else:
          raise NotImplementedError
      else:
        if loss_type == 'l2':
          losses = torch.square(score - noise)
        elif loss_type == 'l1':
          losses = torch.abs(score - noise)
        else:
          raise NotImplementedError

      mask = model.module.feature_mask
      if mask is None:
        raise NotImplementedError
      assert len(mask.size()) == 5
      assert mask.size(1) == losses.size(1)
      losses = losses * mask
      loss = torch.sum(losses) / mask.sum() / losses.size(0)


      reg_loss = torch.zeros_like(loss)
      total_loss = loss

      return total_loss, loss, reg_loss

  return loss_fn
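Both branches above use the DDPM forward perturbation `x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps` and invert it to recover `x_0` (when `pred_type == 'noise'`) or `eps` (when `pred_type == 'x0'`). A standalone NumPy sketch of that round trip, with scalar coefficients standing in for `sqrt_alphas_cumprod` and `sqrt_1m_alphas_cumprod`:

```python
import numpy as np

rng = np.random.default_rng(0)
x0 = rng.standard_normal(8)      # clean data
noise = rng.standard_normal(8)   # the added Gaussian noise (epsilon)
a1, a2 = 0.8, 0.6                # sqrt(alpha_bar_t), sqrt(1 - alpha_bar_t); a1**2 + a2**2 == 1

xt = a1 * x0 + a2 * noise        # forward perturbation at step t
x0_rec = (xt - a2 * noise) / a1  # recover x0 from a (perfect) noise prediction
eps_rec = (xt - a1 * x0) / a2    # recover the noise from a (perfect) x0 prediction
```

With a perfect prediction both inversions are exact, which is why the two `pred_type` modes are interchangeable parameterizations of the same target.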

def get_step_fn(sde, train, optimize_fn=None, loss_type='l2', pred_type='noise', use_vis_mask=False, use_occ=False, use_aux=False):
  """Create a one-step training/evaluation function.

  Args:
    sde: An `sde_lib.SDE` object that represents the forward SDE.
    train: `True` to build a training step, `False` for an evaluation step.
    optimize_fn: An optimization function produced by `optimization_manager`.
    loss_type: 'l2' or 'l1' regression loss on the prediction.
    pred_type: 'noise' if the model predicts the added noise, 'x0' if it predicts the clean data.
    use_vis_mask: If `True`, mask the loss with a visibility mask extracted from the predicted grid.
    use_occ: If `True`, jointly diffuse an occupancy grid alongside the feature grid.
    use_aux: If `True`, add an auxiliary regression loss on visibility features.

  Returns:
    A one-step function for training or evaluation.
  """
  
  loss_fn = get_ddpm_loss_fn(sde, train, loss_type=loss_type, pred_type=pred_type, use_vis_mask=use_vis_mask, use_occ=use_occ, use_aux=use_aux)

  def step_fn(state, batch, clear_grad=True, update_param=True, gradscaler=None):
    """Running one step of training or evaluation.

    This function will undergo `jax.lax.scan` so that multiple steps can be pmapped and jit-compiled together
    for faster execution.

    Args:
      state: A dictionary of training information, containing the score model, optimizer,
       EMA status, and number of optimization steps.
      batch: A mini-batch of training/evaluation data.

    Returns:
      loss: The average loss value of this state.
    """
    model = state['model']
    if train:
      optimizer = state['optimizer']
      if clear_grad:
        optimizer.zero_grad()
      loss_total, loss_score, loss_reg = loss_fn(model, batch)
      gradscaler.scale(loss_total).backward()
      if update_param:
        optimize_fn(optimizer, model.parameters(), step=state['step'], gradscaler=gradscaler)
      state['step'] += 1
      state['ema'].update(model.parameters())
    else:
      with torch.no_grad():
        ema = state['ema']
        ema.store(model.parameters())
        ema.copy_to(model.parameters())
        loss_total, loss_score, loss_reg = loss_fn(model, batch)
        ema.restore(model.parameters())

    return {
      'loss_total': loss_total,
      'loss_score': loss_score,
      'loss_reg': loss_reg,
    }

  return step_fn

================================================
FILE: GMeshDiffusion/lib/diffusion/models/__init__.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.



================================================
FILE: GMeshDiffusion/lib/diffusion/models/ema.py
================================================
# Modified from https://raw.githubusercontent.com/fadel/pytorch_ema/master/torch_ema/ema.py

from __future__ import division
from __future__ import unicode_literals

import torch


# Partially based on: https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/moving_averages.py
class ExponentialMovingAverage:
  """
  Maintains (exponential) moving average of a set of parameters.
  """

  def __init__(self, parameters, decay, use_num_updates=True):
    """
    Args:
      parameters: Iterable of `torch.nn.Parameter`; usually the result of
        `model.parameters()`.
      decay: The exponential decay.
      use_num_updates: Whether to use number of updates when computing
        averages.
    """
    if decay < 0.0 or decay > 1.0:
      raise ValueError('Decay must be between 0 and 1')
    self.decay = decay
    self.num_updates = 0 if use_num_updates else None
    self.shadow_params = [p.clone().detach()
                          for p in parameters if p.requires_grad]
    self.collected_params = []

  def update(self, parameters):
    """
    Update currently maintained parameters.

    Call this every time the parameters are updated, such as the result of
    the `optimizer.step()` call.

    Args:
      parameters: Iterable of `torch.nn.Parameter`; usually the same set of
        parameters used to initialize this object.
    """
    decay = self.decay
    if self.num_updates is not None:
      self.num_updates += 1
      decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates))
    one_minus_decay = 1.0 - decay
    with torch.no_grad():
      parameters = [p for p in parameters if p.requires_grad]
      for s_param, param in zip(self.shadow_params, parameters):
        s_param.sub_(one_minus_decay * (s_param - param))

  def copy_to(self, parameters):
    """
    Copy current parameters into given collection of parameters.

    Args:
      parameters: Iterable of `torch.nn.Parameter`; the parameters to be
        updated with the stored moving averages.
    """
    parameters = [p for p in parameters if p.requires_grad]
    for s_param, param in zip(self.shadow_params, parameters):
      if param.requires_grad:
        param.data.copy_(s_param.data)

  def store(self, parameters):
    """
    Save the current parameters for restoring later.

    Args:
      parameters: Iterable of `torch.nn.Parameter`; the parameters to be
        temporarily stored.
    """
    self.collected_params = [param.clone() for param in parameters]

  def restore(self, parameters):
    """
    Restore the parameters stored with the `store` method.
    Useful to validate the model with EMA parameters without affecting the
    original optimization process. Store the parameters before the
    `copy_to` method. After validation (or model saving), use this to
    restore the former parameters.

    Args:
      parameters: Iterable of `torch.nn.Parameter`; the parameters to be
        updated with the stored parameters.
    """
    for c_param, param in zip(self.collected_params, parameters):
      param.data.copy_(c_param.data)

  def state_dict(self):
    return dict(decay=self.decay, num_updates=self.num_updates,
                shadow_params=self.shadow_params)

  def load_state_dict(self, state_dict, device='cuda'):
    self.decay = state_dict['decay']
    self.num_updates = state_dict['num_updates']
    self.shadow_params = state_dict['shadow_params']
    for k, _ in enumerate(self.shadow_params):
      self.shadow_params[k] = self.shadow_params[k].to(device)
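The in-place update `s.sub_((1 - decay) * (s - p))` used in `update` is algebraically the usual convex combination `s = decay * s + (1 - decay) * p`. A one-line standalone check of the equivalence:

```python
decay, s, p = 0.999, 1.0, 0.5
s_inplace = s - (1.0 - decay) * (s - p)     # the in-place form used in `update`
s_convex = decay * s + (1.0 - decay) * p    # the textbook EMA form
# both evaluate to 0.9995
```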

================================================
FILE: GMeshDiffusion/lib/diffusion/models/functional.py
================================================
#################################################################################################
# Copyright (c) 2023 Ali Hassani.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
#################################################################################################
import torch
from torch.autograd import Function
from torch.cuda.amp import custom_bwd, custom_fwd

try:
    from natten import _C
except ImportError:
    raise ImportError(
        "Failed to import NATTEN's CPP backend. "
        "This could be due to an invalid/incomplete install. "
        "Please uninstall NATTEN (pip uninstall natten) and re-install with the "
        "correct torch build: shi-labs.com/natten"
    )


def has_cuda():
    return _C.has_cuda()


def has_half():
    return _C.has_half()


def has_bfloat():
    return _C.has_bfloat()


def has_gemm():
    return _C.has_gemm()


def enable_tf32():
    return _C.set_gemm_tf32(True)


def disable_tf32():
    return _C.set_gemm_tf32(False)


def enable_tiled_na():
    return _C.set_tiled_na(True)


def disable_tiled_na():
    return _C.set_tiled_na(False)


def enable_gemm_na():
    return _C.set_gemm_na(True)


def disable_gemm_na():
    return _C.set_gemm_na(False)


class NeighborhoodAttention1DQKAutogradFunction(Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, query, key, rpb, kernel_size, dilation):
        query = query.contiguous()
        key = key.contiguous()
        attn = _C.na1d_qk_forward(query, key, rpb, kernel_size, dilation)
        ctx.save_for_backward(query, key)
        ctx.kernel_size = kernel_size
        ctx.dilation = dilation
        ctx.bias = rpb is not None
        return attn

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_out):
        outputs = _C.na1d_qk_backward(
            grad_out.contiguous(),
            ctx.saved_tensors[0],
            ctx.saved_tensors[1],
            ctx.bias,
            ctx.kernel_size,
            ctx.dilation,
        )
        d_query, d_key, d_rpb = outputs
        return d_query, d_key, d_rpb, None, None


class NeighborhoodAttention1DAVAutogradFunction(Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, attn, value, kernel_size, dilation):
        attn = attn.contiguous()
        value = value.contiguous()
        out = _C.na1d_av_forward(attn, value, kernel_size, dilation)
        ctx.save_for_backward(attn, value)
        ctx.kernel_size = kernel_size
        ctx.dilation = dilation
        return out

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_out):
        outputs = _C.na1d_av_backward(
            grad_out.contiguous(),
            ctx.saved_tensors[0],
            ctx.saved_tensors[1],
            ctx.kernel_size,
            ctx.dilation,
        )
        d_attn, d_value = outputs
        return d_attn, d_value, None, None


class NeighborhoodAttention2DQKAutogradFunction(Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, query, key, rpb, kernel_size, dilation):
        query = query.contiguous()
        key = key.contiguous()
        if rpb is not None:
            rpb = rpb.to(key.dtype)
        attn = _C.na2d_qk_forward(query, key, rpb, kernel_size, dilation)
        ctx.save_for_backward(query, key)
        ctx.kernel_size = kernel_size
        ctx.dilation = dilation
        ctx.bias = rpb is not None
        return attn

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_out):
        outputs = _C.na2d_qk_backward(
            grad_out.contiguous(),
            ctx.saved_tensors[0],
            ctx.saved_tensors[1],
            ctx.bias,
            ctx.kernel_size,
            ctx.dilation,
        )
        d_query, d_key, d_rpb = outputs
        return d_query, d_key, d_rpb, None, None


class NeighborhoodAttention2DAVAutogradFunction(Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, attn, value, kernel_size, dilation):
        attn = attn.contiguous().to(value.dtype)
        value = value.contiguous()
        out = _C.na2d_av_forward(attn, value, kernel_size, dilation)
        ctx.save_for_backward(attn, value)
        ctx.kernel_size = kernel_size
        ctx.dilation = dilation
        return out

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_out):
        outputs = _C.na2d_av_backward(
            grad_out.contiguous(),
            ctx.saved_tensors[0],
            ctx.saved_tensors[1],
            ctx.kernel_size,
            ctx.dilation,
        )
        d_attn, d_value = outputs
        return d_attn, d_value, None, None


class NeighborhoodAttention3DQKAutogradFunction(Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, query, key, rpb, kernel_size_d, kernel_size, dilation_d, dilation):
        query = query.contiguous()
        key = key.contiguous()
        attn = _C.na3d_qk_forward(
            query, key, rpb, kernel_size, dilation, kernel_size_d, dilation_d
        )
        ctx.save_for_backward(query, key)
        ctx.kernel_size_d = kernel_size_d
        ctx.kernel_size = kernel_size
        ctx.dilation_d = dilation_d
        ctx.dilation = dilation
        ctx.bias = rpb is not None
        return attn

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_out):
        outputs = _C.na3d_qk_backward(
            grad_out.contiguous(),
            ctx.saved_tensors[0],
            ctx.saved_tensors[1],
            ctx.bias,
            ctx.kernel_size,
            ctx.dilation,
            ctx.kernel_size_d,
            ctx.dilation_d,
        )
        d_query, d_key, d_rpb = outputs
        return d_query, d_key, d_rpb, None, None, None, None


class NeighborhoodAttention3DAVAutogradFunction(Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, attn, value, kernel_size_d, kernel_size, dilation_d, dilation):
        attn = attn.contiguous()
        value = value.contiguous()
        out = _C.na3d_av_forward(
            attn, value, kernel_size, dilation, kernel_size_d, dilation_d
        )
        ctx.save_for_backward(attn, value)
        ctx.kernel_size_d = kernel_size_d
        ctx.kernel_size = kernel_size
        ctx.dilation_d = dilation_d
        ctx.dilation = dilation
        return out

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_out):
        outputs = _C.na3d_av_backward(
            grad_out.contiguous(),
            ctx.saved_tensors[0],
            ctx.saved_tensors[1],
            ctx.kernel_size,
            ctx.dilation,
            ctx.kernel_size_d,
            ctx.dilation_d,
        )
        d_attn, d_value = outputs
        return d_attn, d_value, None, None, None, None


def natten1dqkrpb(query, key, rpb, kernel_size, dilation):
    return NeighborhoodAttention1DQKAutogradFunction.apply(
        query, key, rpb, kernel_size, dilation
    )


def natten1dqk(query, key, kernel_size, dilation):
    return NeighborhoodAttention1DQKAutogradFunction.apply(
        query, key, None, kernel_size, dilation
    )


def natten1dav(attn, value, kernel_size, dilation):
    return NeighborhoodAttention1DAVAutogradFunction.apply(
        attn, value, kernel_size, dilation
    )


def natten2dqkrpb(query, key, rpb, kernel_size, dilation):
    return NeighborhoodAttention2DQKAutogradFunction.apply(
        query, key, rpb, kernel_size, dilation
    )


def natten2dqk(query, key, kernel_size, dilation):
    return NeighborhoodAttention2DQKAutogradFunction.apply(
        query, key, None, kernel_size, dilation
    )


def natten2dav(attn, value, kernel_size, dilation):
    return NeighborhoodAttention2DAVAutogradFunction.apply(
        attn, value, kernel_size, dilation
    )


def natten3dqkrpb(query, key, rpb, kernel_size_d, kernel_size, dilation_d, dilation):
    return NeighborhoodAttention3DQKAutogradFunction.apply(
        query, key, rpb, kernel_size_d, kernel_size, dilation_d, dilation
    )


def natten3dqk(query, key, kernel_size_d, kernel_size, dilation_d, dilation):
    return NeighborhoodAttention3DQKAutogradFunction.apply(
        query, key, None, kernel_size_d, kernel_size, dilation_d, dilation
    )


def natten3dav(attn, value, kernel_size_d, kernel_size, dilation_d, dilation):
    return NeighborhoodAttention3DAVAutogradFunction.apply(
        attn, value, kernel_size_d, kernel_size, dilation_d, dilation
    )
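The wrappers above split neighborhood attention into a QK step (local attention weights) and an AV step (window-weighted values), both dispatched to NATTEN's compiled kernels. The windowing semantics can be sketched in pure NumPy for the 1D case; this is only an illustration of the clamped-window behavior, not the actual CUDA kernel:

```python
import numpy as np

def neighborhood_attention_1d(q, k, v, kernel_size):
    """Each query attends to a window of `kernel_size` keys centered on it,
    with windows clamped (shifted, not shrunk) at the sequence boundaries."""
    length, dim = q.shape
    radius = kernel_size // 2
    out = np.empty_like(v)
    for i in range(length):
        # clamp the window so it always stays inside [0, length)
        start = min(max(i - radius, 0), length - kernel_size)
        w = q[i] @ k[start:start + kernel_size].T / np.sqrt(dim)
        w = np.exp(w - w.max())
        w /= w.sum()  # softmax over the local window only
        out[i] = w @ v[start:start + kernel_size]
    return out

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((6, 4)) for _ in range(3))
out = neighborhood_attention_1d(q, k, v, kernel_size=3)
```

When `kernel_size` equals the sequence length, every window covers the full sequence and the sketch reduces to ordinary softmax attention.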

================================================
FILE: GMeshDiffusion/lib/diffusion/models/layers.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: skip-file
"""Common layers for defining score networks.
"""
import math
import string
from functools import partial
import torch.nn as nn
import torch
import torch.nn.functional as F
import numpy as np
from .normalization import ConditionalInstanceNorm3dPlus

class GroupNormFloat32(nn.GroupNorm):
    def forward(self, input):
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=False):
            return F.group_norm(
                input.float(), self.num_groups, self.weight, self.bias, self.eps)


def get_act_fn(act_name):
    """Get activation functions from the config file."""

    if act_name.lower() == 'elu':
        return nn.ELU()
    elif act_name.lower() == 'relu':
        return nn.ReLU()
    elif act_name.lower() == 'lrelu':
        return nn.LeakyReLU(negative_slope=0.2)
    elif act_name.lower() == 'swish' or act_name.lower() == 'silu':
        return nn.SiLU()
    else:
        raise NotImplementedError('activation function does not exist!')

def variance_scaling(scale, mode, distribution,
                     in_axis=1, out_axis=0,
                     dtype=torch.float32,
                     device='cpu'):
    """Ported from JAX. """

    def _compute_fans(shape, in_axis=1, out_axis=0):
        receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
        fan_in = shape[in_axis] * receptive_field_size
        fan_out = shape[out_axis] * receptive_field_size
        return fan_in, fan_out

    def init(shape, dtype=dtype, device=device):
        fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
        if mode == "fan_in":
            denominator = fan_in
        elif mode == "fan_out":
            denominator = fan_out
        elif mode == "fan_avg":
            denominator = (fan_in + fan_out) / 2
        else:
            raise ValueError(
                "invalid mode for variance scaling initializer: {}".format(mode))
        variance = scale / denominator
        if distribution == "normal":
            return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)
        elif distribution == "uniform":
            return (torch.rand(*shape, dtype=dtype, device=device) * 2. - 1.) * np.sqrt(3 * variance)
        else:
            raise ValueError("invalid distribution for variance scaling initializer")

    return init


def default_init(scale=1.):
    """The same initialization used in DDPM."""
    scale = 1e-10 if scale == 0 else scale
    return variance_scaling(scale, 'fan_avg', 'uniform')
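`variance_scaling` computes fan-in/fan-out by factoring the receptive-field size out of the weight shape; for a Conv3d weight of shape `(out, in, k, k, k)` this gives `fan_in = in * k**3` and `fan_out = out * k**3`. A quick standalone check of that bookkeeping (the helper name is ours):

```python
import numpy as np

def compute_fans(shape, in_axis=1, out_axis=0):
    # receptive field size = product of the remaining (kernel) dimensions
    receptive_field = np.prod(shape) / shape[in_axis] / shape[out_axis]
    return shape[in_axis] * receptive_field, shape[out_axis] * receptive_field

# a 3x3x3 Conv3d weight of shape (out=16, in=8, 3, 3, 3)
fan_in, fan_out = compute_fans((16, 8, 3, 3, 3))
# fan_in = 8 * 27 = 216, fan_out = 16 * 27 = 432
```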


class Dense(nn.Module):
    """Linear layer with `default_init`."""
    def __init__(self):
        super().__init__()


def conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1., padding=0):
    """1x1 convolution with DDPM initialization."""
    conv = nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
    nn.init.zeros_(conv.bias)
    return conv

def conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=1):
    """3x3 convolution with DDPM initialization."""
    conv = nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride, padding=padding,
                    dilation=dilation, bias=bias)
    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
    nn.init.zeros_(conv.bias)
    return conv

def conv5x5(in_planes, out_planes, stride=2, bias=True, dilation=1, init_scale=1., padding=2):
    """3x3 convolution with DDPM initialization."""
    conv = nn.Conv3d(in_planes, out_planes, kernel_size=5, stride=stride, padding=padding,
                    dilation=dilation, bias=bias)
    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
    nn.init.zeros_(conv.bias)
    return conv


def conv3x3_transposed(in_planes, out_planes, stride=2, bias=True, dilation=1, init_scale=1., padding=1):
    """3x3 convolution with DDPM initialization."""
    conv = nn.ConvTranspose3d(in_planes, out_planes, kernel_size=3, stride=stride, padding=padding,
                    dilation=dilation, bias=bias)
    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
    nn.init.zeros_(conv.bias)
    return conv

def conv5x5_transposed(in_planes, out_planes, stride=2, bias=True, dilation=1, init_scale=1., padding=2):
    """3x3 convolution with DDPM initialization."""
    conv = nn.ConvTranspose3d(in_planes, out_planes, kernel_size=5, stride=stride, padding=padding,
                    dilation=dilation, bias=bias)
    conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
    nn.init.zeros_(conv.bias)
    return conv


###########################################################################
# Functions below are ported over from the DDPM codebase:
#  https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py
###########################################################################

def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
    with torch.no_grad():
        assert len(timesteps.shape) == 1  # and timesteps.dtype == tf.int32
        half_dim = embedding_dim // 2
        # magic number 10000 is from transformers
        emb = math.log(max_positions) / (half_dim - 1)
        # emb = math.log(2.) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
        # emb = tf.range(num_embeddings, dtype=jnp.float32)[:, None] * emb[None, :]
        # emb = tf.cast(timesteps, dtype=jnp.float32)[:, None] * emb[None, :]
        emb = timesteps[:, None] * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
        if embedding_dim % 2 == 1:  # zero pad
            emb = F.pad(emb, (0, 1), mode='constant')
        assert emb.shape == (timesteps.shape[0], embedding_dim)
        return emb
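For reference, the sinusoidal embedding above can be reproduced without torch; a minimal numpy sketch of the same computation (the function name `timestep_embedding_np` is ours, not from this repo):

```python
import numpy as np

def timestep_embedding_np(timesteps, embedding_dim, max_positions=10000):
    """Sinusoidal timestep embedding, mirroring get_timestep_embedding."""
    half_dim = embedding_dim // 2
    # Log-spaced frequencies, as in the transformer positional encoding.
    freqs = np.exp(np.arange(half_dim) * -(np.log(max_positions) / (half_dim - 1)))
    args = timesteps[:, None].astype(np.float64) * freqs[None, :]
    emb = np.concatenate([np.sin(args), np.cos(args)], axis=1)
    if embedding_dim % 2 == 1:  # zero-pad odd embedding dims
        emb = np.pad(emb, ((0, 0), (0, 1)))
    return emb

emb = timestep_embedding_np(np.array([0, 10, 999]), 128)
```

At `t = 0` the sine half is all zeros and the cosine half all ones, which is a quick sanity check on the layout.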

class AttnBlock(nn.Module):
    """Channel-wise self-attention block."""
    def __init__(self, channels, num_groups=32):
        super().__init__()
        self.GroupNorm_0 = GroupNormFloat32(num_groups=num_groups, num_channels=channels, eps=1e-6)
        self.NIN_0 = conv1x1(channels, channels)
        self.NIN_1 = conv1x1(channels, channels)
        self.NIN_2 = conv1x1(channels, channels)
        self.NIN_3 = conv1x1(channels, channels, init_scale=0.)

    def forward(self, x):
        B, C, D, H, W = x.shape
        h = self.GroupNorm_0(x)
        q = self.NIN_0(h)
        k = self.NIN_1(h)
        v = self.NIN_2(h)

        # q = q.view(B, C, -1).permute(0, 2, 1)
        # k = k.view(B, C, -1).permute(0, 2, 1)
        # v = v.view(B, C, -1).permute(0, 2, 1)
        # with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=False):
        #     h = F.scaled_dot_product_attention(q.float(), k.float(), v.float())
        # h = h.permute(0, 2, 1).view(B, C, D, H, W)

        with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=False):
            w = torch.einsum('bcdhw,bckij->bdhwkij', q.float(), k.float()) * (int(C) ** (-0.5))
            w = torch.reshape(w, (B, D, H, W, D * H * W))
            w = F.softmax(w, dim=-1)
            w = torch.reshape(w, (B, D, H, W, D, H, W))
        h = torch.einsum('bdhwkij,bckij->bcdhw', w, v)

        h = self.NIN_3(h)
        return x + h
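The 7-D einsum in `AttnBlock.forward` is ordinary softmax attention over the D·H·W flattened voxels; a numpy sketch checking that equivalence (all variable names are ours):

```python
import numpy as np

rng = np.random.default_rng(0)
B, C, D, H, W = 2, 4, 3, 3, 3
q = rng.standard_normal((B, C, D, H, W))
k = rng.standard_normal((B, C, D, H, W))
v = rng.standard_normal((B, C, D, H, W))

# Einsum form used in AttnBlock.forward: logits over all voxel pairs.
w = np.einsum('bcdhw,bckij->bdhwkij', q, k) * C ** -0.5
w = w.reshape(B, D, H, W, D * H * W)
w = np.exp(w - w.max(-1, keepdims=True))
w /= w.sum(-1, keepdims=True)  # softmax over the key voxels
h_einsum = np.einsum('bdhwkij,bckij->bcdhw', w.reshape(B, D, H, W, D, H, W), v)

# Equivalent flattened form: standard attention over N = D*H*W tokens.
N = D * H * W
qf, kf, vf = (t.reshape(B, C, N) for t in (q, k, v))
logits = np.einsum('bcn,bcm->bnm', qf, kf) * C ** -0.5
a = np.exp(logits - logits.max(-1, keepdims=True))
a /= a.sum(-1, keepdims=True)
h_flat = np.einsum('bnm,bcm->bcn', a, vf).reshape(B, C, D, H, W)
```

This also explains the commented-out `scaled_dot_product_attention` variant above: it is the flattened form of the same operation.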

class Upsample(nn.Module):
    def __init__(self, channels, with_conv=False):
        super().__init__()
        if with_conv:
            self.Conv_0 = conv3x3(channels, channels)
        self.with_conv = with_conv

    def forward(self, x, temb=None):
        B, C, D, H, W = x.shape
        h = F.interpolate(x.float(), (D * 2, H * 2, W * 2), mode='nearest')
        if self.with_conv:
            h = self.Conv_0(h)
        return h


class Downsample(nn.Module):
    def __init__(self, channels, with_conv=False):
        super().__init__()
        if with_conv:
            self.Conv_0 = conv3x3(channels, channels, stride=2, padding=0)
        # Set unconditionally: forward() reads self.with_conv on both paths,
        # so assigning it only inside the `if` would raise an AttributeError
        # whenever with_conv=False.
        self.with_conv = with_conv

    def forward(self, x, temb=None):
        B, C, D, H, W = x.shape
        # Emulate 'SAME' padding
        if self.with_conv:
            x = F.pad(x, (0, 1, 0, 1, 0, 1))
            x = self.Conv_0(x)
        else:
            x = F.avg_pool3d(x, kernel_size=2, stride=2, padding=0)

        assert x.shape == (B, C, D // 2, H // 2, W // 2)
        return x
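The asymmetric `F.pad(x, (0, 1, 0, 1, 0, 1))` call pads one extra voxel on the trailing side of each spatial dim so that a stride-2, kernel-3 convolution with no built-in padding produces exactly `D // 2` outputs, matching the assertion above. A quick sketch of the size arithmetic (helper names are ours):

```python
def conv_out_size(size, kernel, stride, padding):
    # Standard convolution output-length formula.
    return (size + 2 * padding - kernel) // stride + 1

def same_downsample_size(size):
    # Pad 1 on the trailing side only (so the effective input is size + 1),
    # then convolve with kernel=3, stride=2, padding=0, as Downsample does.
    return conv_out_size(size + 1, 3, 2, 0)

halved = [same_downsample_size(d) for d in range(2, 10)]
```

Note symmetric padding would add two voxels total; padding only the trailing side is what emulates TensorFlow's 'SAME' behavior here.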


class ResBlock(nn.Module):
    """The ResNet Blocks used in DDPM."""
    def __init__(self, act_fn, in_ch, out_ch=None, temb_dim=None, conv_shortcut=False, dropout=0.1, num_groups=32):
        super().__init__()
        if out_ch is None:
            out_ch = in_ch
        self.GroupNorm_0 = GroupNormFloat32(num_groups=num_groups, num_channels=in_ch, eps=1e-6)
        self.act = act_fn
        self.Conv_0 = conv3x3(in_ch, out_ch)
        if temb_dim is not None:
            self.Dense_0 = nn.Linear(temb_dim, out_ch)
            self.Dense_0.weight.data = default_init()(self.Dense_0.weight.data.shape)
            nn.init.zeros_(self.Dense_0.bias)

        self.GroupNorm_1 = GroupNormFloat32(num_groups=num_groups, num_channels=out_ch, eps=1e-6)
        self.Dropout_0 = nn.Dropout(dropout)
        self.Conv_1 = conv3x3(out_ch, out_ch, init_scale=0.)
        if in_ch != out_ch:
            if conv_shortcut:
                self.Conv_2 = conv3x3(in_ch, out_ch)
            else:
                self.NIN_0 = conv1x1(in_ch, out_ch)
        self.out_ch = out_ch
        self.in_ch = in_ch
        self.conv_shortcut = conv_shortcut

    def forward(self, x, temb=None):
        B, C, D, H, W = x.shape
        assert C == self.in_ch
        out_ch = self.out_ch if self.out_ch else self.in_ch
        h = self.act(self.GroupNorm_0(x))
        h = self.Conv_0(h)
        # Add bias to each feature map conditioned on the time embedding
        if temb is not None:
            h += self.Dense_0(self.act(temb))[:, :, None, None, None]
        h = self.act(self.GroupNorm_1(h))
        h = self.Dropout_0(h)
        h = self.Conv_1(h)
        if C != out_ch:
            if self.conv_shortcut:
                x = self.Conv_2(x)
            else:
                x = self.NIN_0(x)
        return x + h

class AttnResBlock(ResBlock):
    """The ResNet Blocks used in DDPM."""
    def __init__(self, act_fn, in_ch, out_ch, temb_dim=None, conv_shortcut=False, dropout=0.1, num_groups=32):
        super().__init__(act_fn, in_ch, out_ch, temb_dim, conv_shortcut, dropout, num_groups=num_groups)
        self.attn_block = AttnBlock(out_ch, num_groups=num_groups)

    def forward(self, x, temb=None):
        h = super().forward(x, temb)
        return self.attn_block(h)


================================================
FILE: GMeshDiffusion/lib/diffusion/models/normalization.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Normalization layers."""
import torch.nn as nn
import torch
import functools


def get_normalization(config, conditional=False):
  """Obtain normalization modules from the config file."""
  norm = config.model.normalization
  if conditional:
    if norm == 'InstanceNorm++':
      return functools.partial(ConditionalInstanceNorm3dPlus, num_classes=config.model.num_classes)
    else:
      raise NotImplementedError(f'{norm} not implemented yet.')
  else:
    if norm == 'InstanceNorm':
      return nn.InstanceNorm3d
    elif norm == 'InstanceNorm++':
      return InstanceNorm3dPlus
    elif norm == 'VarianceNorm':
      return VarianceNorm3d
    elif norm == 'GroupNorm':
      return nn.GroupNorm
    else:
      raise ValueError('Unknown normalization: %s' % norm)


class ConditionalBatchNorm3d(nn.Module):
  def __init__(self, num_features, num_classes, bias=True):
    super().__init__()
    self.num_features = num_features
    self.bias = bias
    self.bn = nn.BatchNorm3d(num_features, affine=False)
    if self.bias:
      self.embed = nn.Embedding(num_classes, num_features * 2)
      self.embed.weight.data[:, :num_features].uniform_()  # Initialise scale uniformly in [0, 1)
      self.embed.weight.data[:, num_features:].zero_()  # Initialise bias at 0
    else:
      self.embed = nn.Embedding(num_classes, num_features)
      self.embed.weight.data.uniform_()

  def forward(self, x, y):
    out = self.bn(x)
    if self.bias:
      gamma, beta = self.embed(y).chunk(2, dim=1)
      out = gamma.view(-1, self.num_features, 1, 1, 1) * out + beta.view(-1, self.num_features, 1, 1, 1)
    else:
      gamma = self.embed(y)
      out = gamma.view(-1, self.num_features, 1, 1, 1) * out
    return out


class ConditionalInstanceNorm3d(nn.Module):
  def __init__(self, num_features, num_classes, bias=True):
    super().__init__()
    self.num_features = num_features
    self.bias = bias
    self.instance_norm = nn.InstanceNorm3d(num_features, affine=False, track_running_stats=False)
    if bias:
      self.embed = nn.Embedding(num_classes, num_features * 2)
      self.embed.weight.data[:, :num_features].uniform_()  # Initialise scale uniformly in [0, 1)
      self.embed.weight.data[:, num_features:].zero_()  # Initialise bias at 0
    else:
      self.embed = nn.Embedding(num_classes, num_features)
      self.embed.weight.data.uniform_()

  def forward(self, x, y):
    h = self.instance_norm(x)
    if self.bias:
      gamma, beta = self.embed(y).chunk(2, dim=-1)
      out = gamma.view(-1, self.num_features, 1, 1, 1) * h + beta.view(-1, self.num_features, 1, 1, 1)
    else:
      gamma = self.embed(y)
      out = gamma.view(-1, self.num_features, 1, 1, 1) * h
    return out


class ConditionalVarianceNorm3d(nn.Module):
  def __init__(self, num_features, num_classes, bias=False):
    super().__init__()
    self.num_features = num_features
    self.bias = bias
    self.embed = nn.Embedding(num_classes, num_features)
    self.embed.weight.data.normal_(1, 0.02)

  def forward(self, x, y):
    vars = torch.var(x, dim=(2, 3, 4), keepdim=True)
    h = x / torch.sqrt(vars + 1e-5)

    gamma = self.embed(y)
    out = gamma.view(-1, self.num_features, 1, 1, 1) * h
    return out


class VarianceNorm3d(nn.Module):
  def __init__(self, num_features, bias=False):
    super().__init__()
    self.num_features = num_features
    self.bias = bias
    self.alpha = nn.Parameter(torch.zeros(num_features))
    self.alpha.data.normal_(1, 0.02)

  def forward(self, x):
    vars = torch.var(x, dim=(2, 3, 4), keepdim=True)
    h = x / torch.sqrt(vars + 1e-5)

    out = self.alpha.view(-1, self.num_features, 1, 1, 1) * h
    return out


class ConditionalNoneNorm3d(nn.Module):
  def __init__(self, num_features, num_classes, bias=True):
    super().__init__()
    self.num_features = num_features
    self.bias = bias
    if bias:
      self.embed = nn.Embedding(num_classes, num_features * 2)
      self.embed.weight.data[:, :num_features].uniform_()  # Initialise scale uniformly in [0, 1)
      self.embed.weight.data[:, num_features:].zero_()  # Initialise bias at 0
    else:
      self.embed = nn.Embedding(num_classes, num_features)
      self.embed.weight.data.uniform_()

  def forward(self, x, y):
    if self.bias:
      gamma, beta = self.embed(y).chunk(2, dim=-1)
      out = gamma.view(-1, self.num_features, 1, 1, 1) * x + beta.view(-1, self.num_features, 1, 1, 1)
    else:
      gamma = self.embed(y)
      out = gamma.view(-1, self.num_features, 1, 1, 1) * x
    return out


class NoneNorm3d(nn.Module):
  def __init__(self, num_features, bias=True):
    super().__init__()

  def forward(self, x):
    return x


class InstanceNorm3dPlus(nn.Module):
  def __init__(self, num_features, bias=True):
    super().__init__()
    self.num_features = num_features
    self.bias = bias
    self.instance_norm = nn.InstanceNorm3d(num_features, affine=False, track_running_stats=False)
    self.alpha = nn.Parameter(torch.zeros(num_features))
    self.gamma = nn.Parameter(torch.zeros(num_features))
    self.alpha.data.normal_(1, 0.02)
    self.gamma.data.normal_(1, 0.02)
    if bias:
      self.beta = nn.Parameter(torch.zeros(num_features))

  def forward(self, x):
    means = torch.mean(x, dim=(2, 3, 4))
    m = torch.mean(means, dim=-1, keepdim=True)
    v = torch.var(means, dim=-1, keepdim=True)
    means = (means - m) / (torch.sqrt(v + 1e-5))
    h = self.instance_norm(x)

    if self.bias:
      h = h + means[..., None, None, None] * self.alpha[..., None, None, None]
      out = self.gamma.view(-1, self.num_features, 1, 1, 1) * h + self.beta.view(-1, self.num_features, 1, 1, 1)
    else:
      h = h + means[..., None, None, None] * self.alpha[..., None, None, None]
      out = self.gamma.view(-1, self.num_features, 1, 1, 1) * h
    return out


class ConditionalInstanceNorm3dPlus(nn.Module):
  def __init__(self, num_features, num_classes, bias=True):
    super().__init__()
    self.num_features = num_features
    self.bias = bias
    self.instance_norm = nn.InstanceNorm3d(num_features, affine=False, track_running_stats=False)
    if bias:
      self.embed = nn.Embedding(num_classes, num_features * 3)
      self.embed.weight.data[:, :2 * num_features].normal_(1, 0.02)  # Initialise scale at N(1, 0.02)
      self.embed.weight.data[:, 2 * num_features:].zero_()  # Initialise bias at 0
    else:
      self.embed = nn.Embedding(num_classes, 2 * num_features)
      self.embed.weight.data.normal_(1, 0.02)

  def forward(self, x, y):
    means = torch.mean(x, dim=(2, 3, 4))
    m = torch.mean(means, dim=-1, keepdim=True)
    v = torch.var(means, dim=-1, keepdim=True)
    means = (means - m) / (torch.sqrt(v + 1e-5))
    h = self.instance_norm(x)

    if self.bias:
      gamma, alpha, beta = self.embed(y).chunk(3, dim=-1)
      h = h + means[..., None, None, None] * alpha[..., None, None, None]
      out = gamma.view(-1, self.num_features, 1, 1, 1) * h + beta.view(-1, self.num_features, 1, 1, 1)
    else:
      gamma, alpha = self.embed(y).chunk(2, dim=-1)
      h = h + means[..., None, None, None] * alpha[..., None, None, None]
      out = gamma.view(-1, self.num_features, 1, 1, 1) * h
    return out


================================================
FILE: GMeshDiffusion/lib/diffusion/models/unet3d_occgrid.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: skip-file
"""DDPM model.

This code is the pytorch equivalent of:
https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/models/unet.py
"""
import torch
import torch.nn as nn
import functools
import numpy as np

from . import utils
from .layers import ResBlock, AttnResBlock, Upsample, Downsample, conv1x1, conv3x3, conv5x5, get_act_fn, default_init, get_timestep_embedding, GroupNormFloat32

import sys


def str_to_class(classname):
    return getattr(sys.modules[__name__], classname)


@utils.register_model(name='unet3d_occgrid')
class UNet3D(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.act_fn = get_act_fn(config.model.act_fn)
        self.nf = nf = config.model.base_channels
        data_ch = config.data.num_channels
        ch_mult = config.model.ch_mult
        feature_mask = torch.load(config.model.feature_mask_path, map_location='cpu').view(1, data_ch, 128, 128, 128)
        pixcat_mask = torch.load(config.model.pixcat_mask_path, map_location='cpu').view(1, 1, 128, 128, 128)

        occ_mask_path = config.data.occ_mask_path
        occ_mask = torch.load(occ_mask_path, map_location='cpu').view(1, 1, 256, 256, 256)


        tet_info = torch.load(config.data.tet_info_path)
        self.tet_edge_vpos = tet_info['tet_edge_vpos'].cuda()
        self.tet_edge_pix_loc = tet_info['tet_edge_pix_loc'].cuda().view(-1, 2, 3)
        # self.tet_center_loc = tet_info['tet_center_loc'].cuda()
        self.vis_edges = tet_info['vis_edges'].cuda()
        self.occ_edge_cano_order = tet_info['occ_edge_cano_order'].cuda()
        self.tet_edgenode_loc = self.tet_edge_pix_loc.float().mean(dim=1).long()
        self.occ_edge_loc = self.tet_edgenode_loc.view(-1, 6, 3)[:, self.vis_edges.view(-1)].view(-1, 2, 3)
        self.occ_node_loc = (self.occ_edge_loc.view(-1, 12, 2, 3).float().mean(dim=-2) * 2.0).long().view(-1, 3)
        print(self.tet_edgenode_loc.size(), self.vis_edges.size(), self.occ_edge_loc.size(), self.occ_node_loc.size())
        self.tet_edge_pix_loc = self.tet_edge_pix_loc.view(-1, 3)
        
        
        self.feature_mask = torch.nn.Parameter(feature_mask, requires_grad=False)
        self.pixcat_mask = torch.nn.Parameter(pixcat_mask, requires_grad=False)
        self.occ_mask = torch.nn.Parameter(occ_mask, requires_grad=False)
        self.down_block_types = config.model.down_block_types
        self.up_block_types = config.model.up_block_types
        self.num_res_blocks = config.model.num_res_blocks
        self.num_res_blocks_1st_layer = config.model.num_res_blocks_1st_layer
        resamp_with_conv = config.model.resamp_with_conv
        dropout = config.model.dropout
        assert len(self.down_block_types) == len(self.up_block_types)


        module_dict = {
            module: functools.partial(str_to_class(module), act_fn=self.act_fn, temb_dim=4 * nf, dropout=dropout)
            for module in ["ResBlock", "AttnResBlock"]
        }


        # Condition on noise levels.
        noise_temb_layers = [nn.Linear(nf, nf * 4), nn.SiLU(), nn.Linear(nf * 4, nf * 4)]
        noise_temb_layers[0].weight.data = default_init()(noise_temb_layers[0].weight.data.shape)
        nn.init.zeros_(noise_temb_layers[0].bias)
        noise_temb_layers[2].weight.data = default_init()(noise_temb_layers[2].weight.data.shape)
        nn.init.zeros_(noise_temb_layers[2].bias)
        self.noise_temb_layers = nn.Sequential(*noise_temb_layers)

        self.occ_conv = conv3x3(1, nf, stride=2, padding=1)
        self.occ_mask_conv = conv3x3(1, nf, stride=2, padding=1)

        # Downsampling block
        self.mask_layer = conv5x5(1, nf, stride=1, padding=2)
        self.input_layer = conv5x5(data_ch, nf, stride=1, padding=2)
        hs_c = [nf]
        in_ch = nf
        
        modules = []
        for i_level, down_block_type in enumerate(self.down_block_types):
            curr_block = module_dict[down_block_type]
            # Residual blocks for this resolution
            num_res_blocks = self.num_res_blocks_1st_layer if i_level == 0 else self.num_res_blocks
            for i_block in range(num_res_blocks):
                out_ch = nf * ch_mult[i_level]
                modules.append(curr_block(in_ch=in_ch, out_ch=out_ch))
                in_ch = out_ch
                hs_c.append(in_ch)
        
            if i_level != len(self.down_block_types) - 1:
                modules.append(Downsample(channels=in_ch, with_conv=resamp_with_conv))
                hs_c.append(in_ch)

        in_ch = hs_c[-1]
        modules.append(module_dict["AttnResBlock"](in_ch=in_ch, out_ch=in_ch))
        modules.append(module_dict["ResBlock"](in_ch=in_ch))

        # Upsampling block
        for i_level, up_block_type in enumerate(self.up_block_types):
            curr_block = module_dict[up_block_type]
            num_res_blocks = self.num_res_blocks_1st_layer if i_level == len(self.up_block_types) - 1 else self.num_res_blocks
            for i_block in range(num_res_blocks + 1):
                out_ch = nf * ch_mult[len(self.up_block_types) - i_level - 1]
                modules.append(curr_block(in_ch=in_ch + hs_c.pop(), out_ch=out_ch))
                in_ch = out_ch
            if i_level != len(self.up_block_types) - 1:
                modules.append(Upsample(channels=in_ch, with_conv=resamp_with_conv))

        self.all_modules = nn.ModuleList(modules)

        self.output_norm_layer = nn.Sequential(
            GroupNormFloat32(num_channels=in_ch, num_groups=32, eps=1e-6),
            nn.SiLU(),
        )
        self.output_layer = conv5x5(in_ch, data_ch, init_scale=0., stride=1, padding=2)


        self.occ_output_layer = nn.ConvTranspose3d(in_ch, 1, 4, stride=2, padding=1)

    def sequentially_call_module(self, idx, x, temb=None):
        return idx + 1, self.all_modules[idx](x, temb)

    def forward(self, x, labels):
        modules = self.all_modules

        with torch.no_grad():
            x, occ_grid = x[0], x[1]
            # Input is assumed to be in [-1, 1]; no rescaling needed.
            # (The original [0, 1] branch, `x = 2 * x - 1.`, was dead code
            # guarded by `if True` and referenced an undefined `self.centered`.)

            # Mask out unused values
            x = x * self.feature_mask

            occ_grid = occ_grid * self.occ_mask
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=False):
            # timestep/scale embedding
            timesteps = labels
            temb = get_timestep_embedding(timesteps.float(), self.nf)
            temb = self.noise_temb_layers(temb)

        # Downsampling block
        hs = [self.input_layer(x) + self.mask_layer(self.pixcat_mask) + self.occ_conv(occ_grid) + self.occ_mask_conv(self.occ_mask)]

        m_idx = 0
        for i_level in range(len(self.down_block_types)):
            num_res_blocks = self.num_res_blocks_1st_layer if i_level == 0 else self.num_res_blocks
            for i_block in range(num_res_blocks):
                m_idx, h = self.sequentially_call_module(m_idx, hs[-1], temb)
                hs.append(h)
            if i_level != len(self.down_block_types) - 1:
                m_idx, h = self.sequentially_call_module(m_idx, hs[-1])
                hs.append(h)

        h = hs[-1]
        m_idx, h = self.sequentially_call_module(m_idx, h, temb)
        m_idx, h = self.sequentially_call_module(m_idx, h, temb)

        # Upsampling block
        for i_level in range(len(self.up_block_types)):
            num_res_blocks = self.num_res_blocks_1st_layer if i_level == len(self.up_block_types) - 1 else self.num_res_blocks
            for i_block in range(num_res_blocks + 1):
                hspop = hs.pop()
                h = torch.cat([h, hspop], dim=1)
                m_idx, h = self.sequentially_call_module(m_idx, h, temb)
            if i_level != len(self.up_block_types) - 1:
                m_idx, h = self.sequentially_call_module(m_idx, h, temb)

        assert not hs
        h = self.output_norm_layer(h)
        grid = self.output_layer(h)
        grid_occ = self.occ_output_layer(h)

        # Mask out unused values
        grid = grid * self.feature_mask
        grid_occ = grid_occ * self.occ_mask

        return grid, grid_occ


================================================
FILE: GMeshDiffusion/lib/diffusion/models/utils.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""All functions and modules related to model definition.
"""

import torch
from .. import sde_lib
import numpy as np


_MODELS = {}


def register_model(cls=None, *, name=None):
  """A decorator for registering model classes."""

  def _register(cls):
    if name is None:
      local_name = cls.__name__
    else:
      local_name = name
    if local_name in _MODELS:
      raise ValueError(f'Already registered model with name: {local_name}')
    _MODELS[local_name] = cls
    return cls

  if cls is None:
    return _register
  else:
    return _register(cls)


def get_model(name):
  return _MODELS[name]


def get_sigmas(config):
  """Get sigmas --- the set of noise levels for SMLD from config files.
  Args:
    config: A ConfigDict object parsed from the config file
  Returns:
    sigmas: a numpy array of noise levels
  """
  sigmas = np.exp(
    np.linspace(np.log(config.model.sigma_max), np.log(config.model.sigma_min), config.model.num_scales))

  return sigmas


def get_ddpm_params(config):
  """Get betas and alphas --- parameters used in the original DDPM paper."""
  num_diffusion_timesteps = 1000
  # parameters need to be adapted if number of time steps differs from 1000
  beta_start = config.model.beta_min / config.model.num_scales
  beta_end = config.model.beta_max / config.model.num_scales
  betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)

  alphas = 1. - betas
  alphas_cumprod = np.cumprod(alphas, axis=0)
  sqrt_alphas_cumprod = np.sqrt(alphas_cumprod)
  sqrt_1m_alphas_cumprod = np.sqrt(1. - alphas_cumprod)

  return {
    'betas': betas,
    'alphas': alphas,
    'alphas_cumprod': alphas_cumprod,
    'sqrt_alphas_cumprod': sqrt_alphas_cumprod,
    'sqrt_1m_alphas_cumprod': sqrt_1m_alphas_cumprod,
    'beta_min': beta_start * (num_diffusion_timesteps - 1),
    'beta_max': beta_end * (num_diffusion_timesteps - 1),
    'num_diffusion_timesteps': num_diffusion_timesteps
  }
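Under the linear beta schedule above, `alphas_cumprod` decays monotonically toward zero, and the two square-root arrays satisfy `sqrt_alphas_cumprod**2 + sqrt_1m_alphas_cumprod**2 == 1` at every step. A numpy sketch of these invariants (`beta_min=0.1, beta_max=20` are illustrative values, not read from a repo config):

```python
import numpy as np

num_steps = 1000
beta_min, beta_max = 0.1, 20.0  # illustrative; the repo reads these from config
betas = np.linspace(beta_min / num_steps, beta_max / num_steps,
                    num_steps, dtype=np.float64)

alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas)
sqrt_ac = np.sqrt(alphas_cumprod)
sqrt_1m_ac = np.sqrt(1.0 - alphas_cumprod)
```

Because every beta is positive and every alpha is below 1, the cumulative product is strictly decreasing, which is what makes the forward process converge to (approximately) pure noise at the last step.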


def create_model(config, use_parallel=True, ddp=False, rank=None):
  """Create the score model."""
  model_name = config.model.name
  score_model = get_model(model_name)(config)
  if use_parallel:
    if ddp:
      score_model = score_model.to(rank)
      score_model = torch.nn.parallel.DistributedDataParallel(
        score_model, 
        find_unused_parameters=False,
        # find_unused_parameters=True,
        gradient_as_bucket_view=True,
        # static_graph=True,
        device_ids=[rank])
      # score_model = torch.compile(score_model)
      # score_model = torch.compile(score_model, fullgraph=True)
    else:
      score_model = torch.nn.DataParallel(score_model).to(config.device)
  else:
    score_model = score_model.to(config.device)
  return score_model


def get_model_fn(model, train=False):
  """Create a function to give the output of the score-based model.

  Args:
    model: The score model.
    train: `True` for training and `False` for evaluation.

  Returns:
    A model function.
  """

  def model_fn(x, labels):
    """Compute the output of the score-based model.

    Args:
      x: A mini-batch of input data.
      labels: A mini-batch of conditioning variables for time steps. Should be interpreted differently
        for different models.

    Returns:
      The model output.
    """
    model.train(mode=train)
    return model(x, labels)

  return model_fn

def get_reg_fn(model, train=False):
  """Create a function that returns the model's regularization term.

  Args:
    model: The score model.
    train: `True` for training and `False` for evaluation.

  Returns:
    A function mapping a mini-batch `x` to `model.get_reg(x)`, or zeros if the
    model does not implement `get_reg`.
  """

  def reg_fn(x):
    model.train(mode=train)
    # Explicitly check for the optional hook instead of a bare `except`,
    # which would also have silently swallowed errors raised inside get_reg.
    if hasattr(model, 'get_reg'):
      return model.get_reg(x)
    return torch.zeros_like(x, device=x.device)

  return reg_fn

def get_score_fn(sde, model, train=False, continuous=False, std_scale=True, pred_type='noise'):
  """Wraps `score_fn` so that the model output corresponds to a real time-dependent score function.

  Args:
    sde: An `sde_lib.SDE` object that represents the forward SDE.
    model: A score model.
    train: `True` for training and `False` for evaluation.
    continuous: If `True`, the score-based model is expected to directly take continuous time steps.
    std_scale: whether to scale the score function by the inverse of std. Used for DDIM sampling

  Returns:
    A score function.
  """
  model_fn = get_model_fn(model, train=train)
  reg_fn = get_reg_fn(model, train=train)

  assert not continuous
  if isinstance(sde, sde_lib.VPSDE):
    if not std_scale:
      def score_fn(x, t):
        labels = t * (sde.N - 1)
        pred = model_fn(x, labels)

        if pred_type == 'x0':
          labels = labels.long()
          alphas1 = sde.sqrt_alphas_cumprod[labels, None, None, None, None].cuda()
          alphas2 = sde.sqrt_1m_alphas_cumprod[labels, None, None, None, None].cuda()
          score = (x - pred * alphas1) / alphas2
        elif pred_type == 'noise':
          score = pred
        return score
    else:
      def score_fn(x, t):
        # For VP-trained models, t=0 corresponds to the lowest noise level
        labels = t * (sde.N - 1)
        pred = model_fn(x, labels)


        if pred_type == 'x0':
          labels = labels.long()
          alphas1 = sde.sqrt_alphas_cumprod[labels, None, None, None, None].cuda()
          alphas2 = sde.sqrt_1m_alphas_cumprod[labels, None, None, None, None].cuda()
          score = (x - pred * alphas1) / alphas2
        elif pred_type == 'noise':
          score = pred

        std = sde.sqrt_1m_alphas_cumprod.to(labels.device)[labels.long()]

        score = -score / std[:, None, None, None, None]
        return score

  else:
    raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")

  return score_fn
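The two `pred_type` branches above are consistent: if the network predicts `x0`, then `(x - sqrt(acp) * x0) / sqrt(1 - acp)` recovers exactly the noise that a `pred_type='noise'` network would output, so both branches feed the same quantity into the final `-score / std` rescaling. A numpy sketch of that identity (all names are ours):

```python
import numpy as np

rng = np.random.default_rng(0)
x0 = rng.standard_normal((4, 3))       # clean data
eps = rng.standard_normal((4, 3))      # forward-process noise
alphas_cumprod = 0.37                  # some alpha-bar_t in (0, 1)
a1 = np.sqrt(alphas_cumprod)           # plays the role of sqrt_alphas_cumprod[t]
a2 = np.sqrt(1.0 - alphas_cumprod)     # plays the role of sqrt_1m_alphas_cumprod[t]

x_t = a1 * x0 + a2 * eps               # VP-SDE forward marginal

# pred_type='x0' branch: convert a (here, perfect) x0 prediction back to noise.
eps_from_x0 = (x_t - a1 * x0) / a2
```

With a perfect `x0` prediction the recovered noise matches `eps` exactly; with a learned model, the same algebra maps its `x0` estimate to an equivalent noise estimate.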


def to_flattened_numpy(x):
  """Flatten a torch tensor `x` and convert it to numpy."""
  return x.detach().cpu().numpy().reshape((-1,))


def from_flattened_numpy(x, shape):
  """Form a torch tensor with the given `shape` from a flattened numpy array `x`."""
  return torch.from_numpy(x.reshape(shape))

================================================
FILE: GMeshDiffusion/lib/diffusion/sampling.py
================================================
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: skip-file
# pytype: skip-file
"""Various sampling methods."""
import functools

import torch
import numpy as np
import abc

from .models.utils import from_flattened_numpy, to_flattened_numpy, get_score_fn
from scipy import integrate
from . import sde_lib
from .models import utils as mutils

import logging
import tqdm

_CORRECTORS = {}
_PREDICTORS = {}


def register_predictor(cls=None, *, name=None):
    """A decorator for registering predictor classes."""

    def _register(cls):
        if name is None:
            local_name = cls.__name__
        else:
            local_name = name
        if local_name in _PREDICTORS:
            raise ValueError(f'Already registered model with name: {local_name}')
        _PREDICTORS[local_name] = cls
        return cls

    if cls is None:
        return _register
    else:
        return _register(cls)


def register_corrector(cls=None, *, name=None):
    """A decorator for registering corrector classes."""

    def _register(cls):
        if name is None:
            local_name = cls.__name__
        else:
            local_name = name
        if local_name in _CORRECTORS:
            raise ValueError(f'Already registered model with name: {local_name}')
        _CORRECTORS[local_name] = cls
        return cls

    if cls is None:
        return _register
    else:
        return _register(cls)


def get_predictor(name):
    return _PREDICTORS[name]


def get_corrector(name):
    return _CORRECTORS[name]


def get_sampling_fn(config, sde, shape, inverse_scaler, eps, grid_mask=None, return_traj=False, pred_type='noise'):
    """Create a sampling function.

    Args:
        config: A `ml_collections.ConfigDict` object that contains all configuration information.
        sde: A `sde_lib.SDE` object that represents the forward SDE.
        shape: A sequence of integers representing the expected shape of a single sample.
        inverse_scaler: The inverse data normalizer function.
        eps: A `float` number. The reverse-time SDE is only integrated to `eps` for numerical stability.

    Returns:
        A function that takes random states and a replicated training state and outputs samples with the
            trailing dimensions matching `shape`.
    """

    sampler_name = config.sampling.method
    # Probability flow ODE sampling with black-box ODE solvers
    # Predictor-Corrector sampling. Predictor-only and Corrector-only samplers are special cases.
    if sampler_name.lower() == 'pc':
        predictor = get_predictor(config.sampling.predictor.lower())
        corrector = get_corrector(config.sampling.corrector.lower())
        sampling_fn = get_pc_sampler(sde=sde,
                                    shape=shape,
                                    predictor=predictor,
                                    corrector=corrector,
                                    inverse_scaler=inverse_scaler,
                                    snr=config.sampling.snr,
                                    n_steps=config.sampling.n_steps_each,
                                    probability_flow=config.sampling.probability_flow,
                                    continuous=config.training.continuous,
                                    denoise=config.sampling.noise_removal,
                                    eps=eps,
                                    device=config.device,
                                    grid_mask=grid_mask,
                                    return_traj=return_traj,
                                    pred_type=pred_type,
                                    use_occ=config.model.use_occ_grid)
    elif sampler_name.lower() == 'ddim':
        predictor = get_predictor('ddim')
        sampling_fn = get_ddim_sampler(sde=sde,
                                    shape=shape,
                                    predictor=predictor,
                                    inverse_scaler=inverse_scaler,
                                    n_steps=config.sampling.n_steps_each,
                                    denoise=config.sampling.noise_removal,
                                    eps=eps,
                                    device=config.device,
                                    grid_mask=grid_mask,
                                    pred_type=pred_type,
                                    use_occ=config.model.use_occ_grid)
    else:
        raise ValueError(f"Sampler name {sampler_name} unknown.")

    return sampling_fn


class Predictor(abc.ABC):
    """The abstract class for a predictor algorithm."""

    def __init__(self, sde, score_fn, probability_flow=False):
        super().__init__()
        self.sde = sde
        # Compute the reverse SDE/ODE
        self.rsde = sde.reverse(score_fn, probability_flow)
        self.score_fn = score_fn

    @abc.abstractmethod
    def update_fn(self, x, t):
        """One update of the predictor.

        Args:
            x: A PyTorch tensor representing the current state
            t: A Pytorch tensor representing the current time step.

        Returns:
            x: A PyTorch tensor of the next state.
            x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising.
        """
        pass


class Corrector(abc.ABC):
    """The abstract class for a corrector algorithm."""

    def __init__(self, sde, score_fn, snr, n_steps):
        super().__init__()
        self.sde = sde
        self.score_fn = score_fn
        self.snr = snr
        self.n_steps = n_steps

    @abc.abstractmethod
    def update_fn(self, x, t):
        """One update of the corrector.

        Args:
            x: A PyTorch tensor representing the current state
            t: A PyTorch tensor representing the current time step.

        Returns:
            x: A PyTorch tensor of the next state.
            x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising.
        """
        pass


@register_predictor(name='euler_maruyama')
class EulerMaruyamaPredictor(Predictor):
    def __init__(self, sde, score_fn, probability_flow=False):
        super().__init__(sde, score_fn, probability_flow)

    def update_fn(self, x, t):
        dt = -1. / self.rsde.N
        z = torch.randn_like(x)
        drift, diffusion = self.rsde.sde(x, t)
        x_mean = x + drift * dt
        x = x_mean + diffusion[:, None, None, None, None] * np.sqrt(-dt) * z
        return x, x_mean
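The `update_fn` above is one Euler-Maruyama step of the reverse SDE: a deterministic drift step plus noise scaled by `sqrt(|dt|)`. A toy scalar illustration of the same two-part update, with a stand-in drift and diffusion (all values hypothetical):

```python
import numpy as np

# Toy scalar version of the Euler-Maruyama update in update_fn above.
rng = np.random.default_rng(0)
N = 1000
dt = -1.0 / N                      # negative: integrating in reverse time

x = rng.standard_normal()
for i in range(N):
    t = 1.0 - i / N
    drift = -0.5 * x               # stand-in for rsde.sde(x, t) drift
    diffusion = 1.0                # stand-in for the diffusion coefficient
    z = rng.standard_normal()
    x_mean = x + drift * dt                      # deterministic half-step
    x = x_mean + diffusion * np.sqrt(-dt) * z    # noise scaled by sqrt(|dt|)

assert np.isfinite(x)
```

As in the class above, `x_mean` (the state before the final noise injection) is what a denoised final sample would use.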


@register_predictor(name='reverse_diffusion')
class ReverseDiffusionPredictor(Predictor):
    def __init__(self, sde, score_fn, probability_flow=False):
        super().__init__(sde, score_fn, probability_flow)

    def update_fn(self, x, t):
        f, G = self.rsde.discretize(x, t)
        z = torch.randn_like(x)
        x_mean = x - f
        x = x_mean + G[:, None, None, None, None] * z
        return x, x_mean


@register_predictor(name='ancestral_sampling')
class AncestralSamplingPredictor(Predictor):
    """The ancestral sampling predictor. Currently only supports VE/VP SDEs."""

    def __init__(self, sde, score_fn, probability_flow=False):
        super().__init__(sde, score_fn, probability_flow)
        if not isinstance(sde, sde_lib.VPSDE):
            raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")
        assert not probability_flow, "Probability flow not supported by ancestral sampling"

    def vpsde_update_fn(self, x, t):
        sde = self.sde
        timestep = (t * (sde.N - 1) / sde.T).long()
        beta = sde.discrete_betas.to(t.device)[timestep]
        score = self.score_fn(x, t)
        x_mean = (x + beta[:, None, None, None, None] * score) / torch.sqrt(1. - beta)[:, None, None, None, None]
        noise = torch.randn_like(x)
        x = x_mean + torch.sqrt(beta)[:, None, None, None, None] * noise
        return x, x_mean

    def update_fn(self, x, t):
        if isinstance(self.sde, sde_lib.VPSDE):
            return self.vpsde_update_fn(x, t)
        else:
            raise NotImplementedError


@register_predictor(name='none')
class NonePredictor(Predictor):
    """An empty predictor that does nothing."""

    def __init__(self, sde, score_fn, probability_flow=False):
        pass

    def update_fn(self, x, t):
        return x, x

@register_predictor(name='ddim')
class DDIMPredictor(Predictor):
    def __init__(self, sde, score_fn, probability_flow=False):
        super().__init__(sde, score_fn, probability_flow)


    def update_fn(self, x, t, tprev=None):
        x, x0_pred = self.rsde.discretize_ddim(x, t, tprev=tprev)
        return x, x0_pred

@register_corrector(name='langevin')
class LangevinCorrector(Corrector):
    def __init__(self, sde, score_fn, snr, n_steps):
        super().__init__(sde, score_fn, snr, n_steps)
        if not isinstance(sde, sde_lib.VPSDE):
            raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")

    def update_fn(self, x, t):
        sde = self.sde
        score_fn = self.score_fn
        n_steps = self.n_steps
        target_snr = self.snr
        if isinstance(sde, sde_lib.VPSDE):
            timestep = (t * (sde.N - 1) / sde.T).long()
            alpha = sde.alphas.to(t.device)[timestep]
        else:
            alpha = torch.ones_like(t)

        for i in range(n_steps):
            grad = score_fn(x, t)
            noise = torch.randn_like(x)
            grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean()
            noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
            step_size = (target_snr * noise_norm / grad_norm) ** 2 * 2 * alpha
            x_mean = x + step_size[:, None, None, None, None] * grad
            x = x_mean + torch.sqrt(step_size * 2)[:, None, None, None, None] * noise

        return x, x_mean
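The corrector's step size above is calibrated so that the ratio of the gradient step's norm to the injected noise's norm equals `target_snr`. A scalar sketch of that calibration with toy values (not taken from any config in this repo):

```python
import numpy as np

# Toy check of the Langevin step-size calibration used in update_fn above.
target_snr = 0.16                  # plays the role of config.sampling.snr
alpha = 1.0                        # alphas[timestep] for VP SDEs, else 1
grad_norm, noise_norm = 4.0, 2.0   # toy batch-mean norms

step_size = (target_snr * noise_norm / grad_norm) ** 2 * 2 * alpha

# resulting ratio ||step*grad|| / ||sqrt(2*step)*noise|| recovers target_snr
snr = (step_size * grad_norm) / (np.sqrt(2 * step_size) * noise_norm)
assert np.isclose(snr, target_snr)
```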


@register_corrector(name='ald')
class AnnealedLangevinDynamics(Corrector):
    """The original annealed Langevin dynamics predictor in NCSN/NCSNv2.

    We include this corrector only for completeness. It was not directly used in our paper.
    """

    def __init__(self, sde, score_fn, snr, n_steps):
        super().__init__(sde, score_fn, snr, n_steps)
        if not isinstance(sde, sde_lib.VPSDE):
            raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")

    def update_fn(self, x, t):
        sde = self.sde
        score_fn = self.score_fn
        n_steps = self.n_steps
        target_snr = self.snr
        if isinstance(sde, sde_lib.VPSDE):
            timestep = (t * (sde.N - 1) / sde.T).long()
            alpha = sde.alphas.to(t.device)[timestep]
        else:
            alpha = torch.ones_like(t)

        std = self.sde.marginal_prob(x, t)[1]

        for i in range(n_steps):
            grad = score_fn(x, t)
            noise = torch.randn_like(x)
            step_size = (target_snr * std) ** 2 * 2 * alpha
            x_mean = x + step_size[:, None, None, None, None] * grad
            x = x_mean + noise * torch.sqrt(step_size * 2)[:, None, None, None, None]

        return x, x_mean


@register_corrector(name='none')
class NoneCorrector(Corrector):
    """An empty corrector that does nothing."""

    def __init__(self, sde, score_fn, snr, n_steps):
        pass

    def update_fn(self, x, t):
        return x, x


def shared_predictor_update_fn(x, t, sde, model, predictor, probability_flow, continuous, pred_type='noise'):
    """A wrapper that configures and returns the update function of predictors."""
    score_fn = mutils.get_score_fn(sde, model, train=False, continuous=continuous, pred_type=pred_type)
    if predictor is None:
        # Corrector-only sampler
        predictor_obj = NonePredictor(sde, score_fn, probability_flow)
    else:
        predictor_obj = predictor(sde, score_fn, probability_flow)
    return predictor_obj.update_fn(x, t)


def shared_corrector_update_fn(x, t, sde, model, corrector, continuous, snr, n_steps, pred_type='noise'):
    """A wrapper tha configures and returns the update function of correctors."""
    score_fn = mutils.get_score_fn(sde, model, train=False, continuous=continuous, pred_type=pred_type)
    if corrector is None:
        # Predictor-only sampler
        corrector_obj = NoneCorrector(sde, score_fn, snr, n_steps)
    else:
        corrector_obj = corrector(sde, score_fn, snr, n_steps)
    return corrector_obj.update_fn(x, t)


def get_pc_sampler(sde, shape, predictor, corrector, inverse_scaler, snr,
                                     n_steps=1, probability_flow=False, continuous=False,
                                     denoise=True, eps=1e-3, device='cuda', 
                                     grid_mask=None, return_traj=False, pred_type='noise', use_occ=False):
    """Create a Predictor-Corrector (PC) sampler.

    Args:
        sde: An `sde_lib.SDE` object representing the forward SDE.
        shape: A sequence of integers. The expected shape of a single sample.
        predictor: A subclass of `sampling.Predictor` representing the predictor algorithm.
        corrector: A subclass of `sampling.Corrector` representing the corrector algorithm.
        inverse_scaler: The inverse data normalizer.
        snr: A `float` number. The signal-to-noise ratio for configuring correctors.
        n_steps: An integer. The number of corrector steps per predictor update.
        probability_flow: If `True`, solve the reverse-time probability flow ODE when running the predictor.
        continuous: `True` indicates that the score model was continuously trained.
        denoise: If `True`, add one-step denoising to the final samples.
        eps: A `float` number. The reverse-time SDE and ODE are integrated to `epsilon` to avoid numerical issues.
        device: PyTorch device.

    Returns:
        A sampling function that returns samples and the number of function evaluations during sampling.
    """
    # Create predictor & corrector update functions
    predictor_update_fn = functools.partial(shared_predictor_update_fn,
                                            sde=sde,
                                            predictor=predictor,
                                            probability_flow=probability_flow,
                                            continuous=continuous,
                                            pred_type=pred_type)
    corrector_update_fn = functools.partial(shared_corrector_update_fn,
                                            sde=sde,
                                            corrector=corrector,
                                            continuous=continuous,
                                            snr=snr,
                                            n_steps=n_steps,
                                            pred_type=pred_type)

    def pc_sampler(model, 
            partial=None, partial_grid_mask=None, partial_channel=0, 
            freeze_iters=None):
        """ The PC sampler funciton.

        Args:
            model: A score model.
        Returns:
            Samples, number of function evaluations.
        """
        with torch.no_grad():

            if freeze_iters is None:
                freeze_iters = sde.N + 10  # any value > sde.N, so the partial condition is re-imposed at every step
            timesteps = torch.linspace(sde.T, eps, sde.N, device=device)

            mask = model.feature_mask

            if pred_type == 'noise':
                def compute_xzero(sde, model, x, t, grid_mask_input):
                    # Convert the noise prediction into an x0 estimate via the DDPM coefficients
                    timestep_int = (t * (sde.N - 1) / sde.T).long()
                    alphas1 = sde.sqrt_alphas_cumprod[timestep_int].cuda()
                    alphas2 = sde.sqrt_1m_alphas_cumprod[timestep_int].cuda()
                    score_pred = model(x, t * torch.ones(shape[0], device=x.device))
                    x0_pred = (x - alphas2 * score_pred) / alphas1
                    x0_pred = x0_pred.clamp(-1, 1)
                    return x0_pred * grid_mask_input
            elif pred_type == 'x0':
                def compute_xzero(sde, model, x, t, grid_mask_input):
                    # The model directly predicts x0
                    x0_pred = model(x, t * torch.ones(shape[0], device=x.device))
                    return x0_pred * grid_mask_input

        
            # Initial sample
            x = sde.prior_sampling(shape).to(device)
            assert len(x.size()) == 5

            traj_buffer = []
        
            if partial is not None:
                assert len(partial.size()) == 5
                t = timesteps[0]
                vec_t = torch.ones(shape[0], device=t.device) * t
                x[:, partial_channel] = partial[:, partial_channel] * grid_mask[:, partial_channel]

                # Diffuse the observed values to the starting noise level t = T
                partial_mean, partial_std = sde.marginal_prob(x, vec_t)
                # partial_mean[:, partial_channel] is 4-D, so broadcast std over three trailing dims
                sampled_update = partial_mean[:, partial_channel] + partial_std[:, None, None, None] * torch.randn_like(partial_mean[:, partial_channel], device=partial_std.device)
                x[:, partial_channel] = (
                    x[:, partial_channel] * (1 - partial_grid_mask[:, partial_channel])
                    + sampled_update * partial_grid_mask[:, partial_channel]
                ) * grid_mask[:, partial_channel]


            if partial is not None:
                x_mean = x
                for i in tqdm.trange(sde.N):
                    t = timesteps[i]
                    vec_t = torch.ones(shape[0], device=t.device) * t

                    x, x_mean = corrector_update_fn(x, vec_t, model=model)
                    x, x_mean = x * grid_mask, x_mean * grid_mask
                    x, x_mean = predictor_update_fn(x, vec_t, model=model)
                    x, x_mean = x * grid_mask, x_mean * grid_mask


                    if i != sde.N - 1 and i < freeze_iters:
                        # Re-impose the partial condition on the observed entries
                        x[:, partial_channel] = (x[:, partial_channel] * (1 - partial_grid_mask[:, partial_channel]) + partial[:, partial_channel] * partial_grid_mask[:, partial_channel]) * grid_mask[:, partial_channel]
                        x_mean[:, partial_channel] = (x_mean[:, partial_channel] * (1 - partial_grid_mask[:, partial_channel]) + partial[:, partial_channel] * partial_grid_mask[:, partial_channel]) * grid_mask[:, partial_channel]

                        ### add noise to the condition x0_star
                        partial_mean, partial_std = sde.marginal_prob(x, timesteps[i] * torch.ones(shape[0], device=t.device))
                        sampled_update = partial_mean[:, partial_channel] + partial_std[:, None, None, None] * torch.randn_like(partial_mean[:, partial_channel], device=partial_std.device)
                        x[:, partial_channel] = (
                            x[:, partial_channel] * (1 - partial_grid_mask[:, partial_channel])
                            + sampled_update * partial_grid_mask[:, partial_channel]
                        ) * grid_mask[:, partial_channel]
                        x_mean[:, partial_channel] = x[:, partial_channel]

            else:

                for i in tqdm.trange(sde.N - 1):
                    t = timesteps[i]

                    with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=False):
                        vec_t = torch.ones(shape[0], device=t.device) * t
                        x, x_mean = corrector_update_fn(x, vec_t, model=model)
                        x, x_mean = x * mask, x_mean * mask
                        x, x_mean = predictor_update_fn(x, vec_t, model=model)
                        x, x_mean = x * mask, x_mean * mask
                        logging.debug('sample range: [%f, %f]', x.min().item(), x.max().item())

                    if return_traj and i >= 700 and i % 10 == 0:
                        traj_buffer.append(compute_xzero(sde, model, x, t, grid_mask))

            if return_traj:
                return traj_buffer, sde.N * (n_steps + 1)
            return inverse_scaler(x_mean * mask if denoise else x * mask), sde.N * (n_steps + 1)

    return pc_sampler

def ddim_predictor_update_fn(x, t, tprev, sde, model, predictor, probability_flow, continuous, pred_type='noise'):
    """A wrapper that configures and returns the update function of predictors."""
    assert not continuous
    score_fn = mutils.get_score_fn(sde, model, train=False, continuous=False, std_scale=False, pred_type=pred_type)
    if predictor is None:
        # Corrector-only sampler
        predictor_obj = NonePredictor(sde, score_fn, probability_flow)
    else:
        predictor_obj = predictor(sde, score_fn, probability_flow)
    return predictor_obj.update_fn(x, t, tprev)

def get_ddim_sampler(sde, shape, predictor, inverse_scaler, n_steps=1,
                    denoise=False, eps=1e-3, device='cuda', grid_mask=None, pred_type='noise', use_occ=False):
    """Probability flow ODE sampler with the black-box ODE solver.

    Args:
        sde: An `sde_lib.SDE` object that represents the forward SDE.
        shape: A sequence of integers. The expected shape of a single sample.
        inverse_scaler: The inverse data normalizer.
        denoise: If `True`, add one-step denoising to final samples.
        eps: A `float` number. The reverse-time SDE/ODE will be integrated to `eps` for numerical stability.
        device: PyTorch device.

    Returns:
        A sampling function that returns samples and the number of function evaluations during sampling.
    """

    predictor_update_fn = functools.partial(ddim_predictor_update_fn,
                                            sde=sde,
                                            predictor=predictor,
                                            probability_flow=False,
                                            continuous=False,
                                            pred_type=pred_type)

    def ddim_sampler(model, schedule='quad', num_steps=100, x0=None, x0_occ=None,
            partial=None, partial_grid_mask=None, partial_channel=0):
        """ The PC sampler funciton.

        Args:
            model: A score model.
        Returns:
            Samples, number of function evaluations.
        """
        with torch.no_grad():
            logging.debug('DDIM sampling on device: %s', device)
            if x0 is not None:
                x = x0.to(device)
            else:
                # Initial sample
                x = sde.prior_sampling(shape).to(device)
            
            mask = model.feature_mask
            if use_occ:
                occ_mask = model.occ_mask.float()
                if x0_occ is not None:
                    x_occ = x0_occ.to(device)
                else:
                    # Initial sample
                    x_occ = sde.prior_sampling((x.size(0), 1, x.size(2)*2, x.size(3)*2, x.size(4)*2)).to(device)
        
                x = (x.float() * mask, x_occ.float() * occ_mask)

            if partial is not None:
                x[:, partial_channel] = x[:, partial_channel] * (1 - partial_grid_mask) + partial * partial_grid_mask

            if schedule == 'uniform':
                skip = sde.N // num_steps
                seq = range(0, sde.N, skip)
            elif schedule == 'quad':
                # Quadratic spacing: steps cluster near t = 0
                seq = np.linspace(0, np.sqrt(sde.N * 0.8), num_steps) ** 2
                seq = [int(s) for s in list(seq)]
            else:
                raise ValueError(f'Unknown schedule: {schedule}')

            timesteps = torch.tensor(seq, device=device) / sde.N

            for i in tqdm.tqdm(list(reversed(range(1, len(timesteps)))), leave=False):
                t = timesteps[i]
                tprev = timesteps[i - 1]
                vec_t = torch.ones(1, device=t.device) * t
                vec_tprev = torch.ones(1, device=t.device) * tprev
                with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=False):
                    x, x0_pred = predictor_update_fn(x, vec_t, model=model, tprev=vec_tprev)
                    if use_occ:
                        x = (x[0] * mask, x[1] * occ_mask)
                        x0_pred = (x0_pred[0] * mask, x0_pred[1] * occ_mask)
                        # print(x[0].min(), x[0].max())
                    else:
                        x, x0_pred = x * mask, x0_pred * mask
                        # print(x.min(), x.max())
                if partial is not None:
                    x[:, partial_channel] = x[:, partial_channel] * (1 - partial_grid_mask) + partial * partial_grid_mask
                    x0_pred[:, partial_channel] = x0_pred[:, partial_channel] * (1 - partial_grid_mask) + partial * partial_grid_mask

            if use_occ:
                return (
                    inverse_scaler(x0_pred[0] * mask if denoise else x[0] * mask),
                    inverse_scaler(x0_pred[1] * occ_mask if denoise else x[1] * occ_mask)
                ), sde.N * (n_steps + 1)
            else:
                return inverse_scaler(x0_pred * mask if denoise else x * mask), sde.N * (n_steps + 1)
    return ddim_sampler
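The `schedule` branch inside `ddim_sampler` selects a subsequence of the `sde.N` discrete timesteps; `'quad'` spaces them quadratically so steps cluster near t = 0. A standalone sketch of the two schedules, assuming `N = 1000` and 100 steps as in the defaults above:

```python
import numpy as np

N, num_steps = 1000, 100

# 'uniform': every (N // num_steps)-th step
uniform = list(range(0, N, N // num_steps))

# 'quad': quadratic spacing, denser near t = 0, as in ddim_sampler above
quad = [int(s) for s in np.linspace(0, np.sqrt(N * 0.8), num_steps) ** 2]

assert len(uniform) == num_steps
assert quad[0] == 0 and quad[-1] < N
# early steps are spaced more finely than late ones
assert quad[1] - quad[0] <= quad[-1] - quad[-2]
```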


================================================
FILE: GMeshDiffusion/lib/diffusion/sde_lib.py
================================================
"""Abstract SDE classes, Reverse SDE, and VE/VP SDEs."""
import abc
import torch
import numpy as np
import torch.nn.functional as F
import time


class SDE(abc.ABC):
  """SDE abstract class. Functions are designed for a mini-batch of inputs."""

  def __init__(self, N):
    """Construct an SDE.

    Args:
      N: number of discretization time steps.
    """
    super().__init__()
    self.N = N

  @property
  @abc.abstractmethod
  def T(self):
    """End time of the SDE."""
    pass

  @abc.abstractmethod
  def sde(self, x, t):
    pass

  @abc.abstractmethod
  def marginal_prob(self, x, t):
    """Parameters to determine the marginal distribution of the SDE, $p_t(x)$."""
    pass

  @abc.abstractmethod
  def prior_sampling(self, shape):
    """Generate one sample from the prior distribution, $p_T(x)$."""
    pass

  @abc.abstractmethod
  def prior_logp(self, z):
    """Compute log-density of the prior distribution.

    Useful for computing the log-likelihood via probability flow ODE.

    Args:
      z: latent code
    Returns:
      log probability density
    """
    pass

  def discretize(self, x, t):
    """Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i.

    Useful for reverse diffusion sampling and probability flow sampling.
    Defaults to Euler-Maruyama discretization.

    Args:
      x: a torch tensor
      t: a torch float representing the time step (from 0 to `self.T`)

    Returns:
      f, G
    """
    dt = 1 / self.N
    drift, diffusion = self.sde(x, t)
    f = drift * dt
    G = diffusion * torch.sqrt(torch.tensor(dt, device=t.device))
    return f, G

  def reverse(self, score_fn, probability_flow=False):
    """Create the reverse-time SDE/ODE.

    Args:
      score_fn: A time-dependent score-based model that takes x and t and returns the score.
      probability_flow: If `True`, create the reverse-time ODE used for probability flow sampling.
    """
    N = self.N
    T = self.T
    sde_fn = self.sde
    discretize_fn = self.discretize
    sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
    sqrt_1m_alphas_cumprod = self.sqrt_1m_alphas_cumprod

    # Build the class for reverse-time SDE.
    class RSDE(self.__class__):
      def __init__(self):
        self.N = N
        self.probability_flow = probability_flow

      @property
      def T(self):
        return T

      def sde(self, x, t):
        """Create the drift and diffusion functions for the reverse SDE/ODE."""
        drift, diffusion = sde_fn(x, t)
        score = score_fn(x, t)
        drift = drift - diffusion[:, None, None, None, None] ** 2 * score * (0.5 if self.probability_flow else 1.)
        # Set the diffusion function to zero for ODEs.
        diffusion = 0. if self.probability_flow else diffusion
        return drift, diffusion

      def discretize(self, x, t):
        """Create discretized iteration rules for the reverse diffusion sampler."""
        f, G = discretize_fn(x, t)
        rev_f = f - G[:, None, None, None, None] ** 2 * score_fn(x, t) * (0.5 if self.probability_flow else 1.)
        rev_G = torch.zeros_like(G) if self.probability_flow else G
        return rev_f, rev_G

      def discretize_ddim(self, x, t, tprev=None, encode=False):
        """DDPM discretization."""
        timestep = (t * (N - 1) / T).long()
        timestep_prev = (tprev * (N - 1) / T).long()

        if isinstance(x, torch.Tensor):
          score = score_fn(x.float(), t.float())

          # alphas1prev_div_alphas1 = torch.exp(log_diff)
          alphas1 = sqrt_alphas_cumprod[timestep].cuda()[:, None, None, None, None]
          alphas2 = sqrt_1m_alphas_cumprod[timestep].cuda()[:, None, None, None, None]
          alphas1_prev = sqrt_alphas_cumprod[timestep_prev].cuda()[:, None, None, None, None]
          alphas2_prev = sqrt_1m_alphas_cumprod[timestep_prev].cuda()[:, None, None, None, None]
          alphas1prev_div_alphas1 = alphas1_prev.double() / alphas1.double()
          alphas2prev_div_alphas2 = alphas2_prev.double() / alphas2.double()


          x0_pred_scaled = (x.double() - alphas2.double() * score.double())
          use_clip = False
          if use_clip:
            # raise NotImplementedError
            x0_pred_scaled = x0_pred_scaled.clamp(-alphas1[0].squeeze(), alphas1[0].squeeze())
          score_scaled_t = x - x0_pred_scaled
          x0_pred = x0_pred_scaled / alphas1

          x_new = (
            alphas1prev_div_alphas1.double() * x + 
            (-alphas1prev_div_alphas1 + alphas2prev_div_alphas2.double()) * score_scaled_t.double()
          )
          return x_new, x0_pred
        else:
          score, score_occ = score_fn(x, t.float())
          x, x_occ = x

          alphas1 = sqrt_alphas_cumprod[timestep].cuda()[:, None, None, None, None]
          alphas2 = sqrt_1m_alphas_cumprod[timestep].cuda()[:, None, None, None, None]
          alphas1_prev = sqrt_alphas_cumprod[timestep_prev].cuda()[:, None, None, None, None]
          alphas2_prev = sqrt_1m_alphas_cumprod[timestep_prev].cuda()[:, None, None, None, None]
          alphas1prev_div_alphas1 = alphas1_prev / alphas1
          alphas2prev_div_alphas2 = alphas2_prev / alphas2

          x0_pred_scaled = (x - alphas2 * score)
          x0_occ_pred_scaled = (x_occ - alphas2 * score_occ)
          use_clip = False
          if use_clip:
            x0_pred_scaled = x0_pred_scaled.clamp(-alphas1[0].squeeze(), alphas1[0].squeeze())
            x0_occ_pred_scaled = x0_occ_pred_scaled.clamp(-alphas1[0].squeeze(), alphas1[0].squeeze())
          score_scaled_t = x - x0_pred_scaled
          x0_pred = x0_pred_scaled / alphas1
          score_occ_scaled_t = x_occ - x0_occ_pred_scaled
          x0_occ_pred = x0_occ_pred_scaled / alphas1

          x_new = (
            alphas1prev_div_alphas1 * x + 
            (-alphas1prev_div_alphas1 + alphas2prev_div_alphas2) * score_scaled_t
          )
          x_occ_new = (
            alphas1prev_div_alphas1 * x_occ + 
            (-alphas1prev_div_alphas1 + alphas2prev_div_alphas2) * score_occ_scaled_t
          )
          return (x_new, x_occ_new), (x0_pred, x0_occ_pred)


      def discretize_conditional_ddpm(self, x, t, tprev=None, condition_func=None, condition=None):
        """Conditional DDPM discretization."""
        timestep = (t * (N - 1) / T).long()
        timestep_prev = (tprev * (N - 1) / T).long()

        score = score_fn(x.float(), t.float())

        # alphas1prev_div_alphas1 = torch.exp(log_diff)
        alphas1 = sqrt_alphas_cumprod[timestep].cuda()[:, None, None, None, None]
        alphas2 = sqrt_1m_alphas_cumprod[timestep].cuda()[:, None, None, None, None]
        alphas1_prev = sqrt_alphas_cumprod[timestep_prev].cuda()[:, None, None, None, None]
        alphas1prev_div_alphas1 = alphas1_prev.double() / alphas1.double()

        x0_pred_scaled = (x.double() - alphas2.double() * score.double())
        x0_pred_scaled = x0_pred_scaled.clamp(-alphas1[0].squeeze(), alphas1[0].squeeze())
        x0_pred = x0_pred_scaled / alphas1

        if condition is None:
          condition_update = 0
        else:
          if (t - 0.99).mean() < 1e-3:
            x = x0_pred
          condition_update = condition_func(x.float(), condition)

        x_new = (
          x - alphas1prev_div_alphas1.double() * condition_update
        )
        return x_new, x0_pred


    return RSDE()


class VPSDE(SDE):
  def __init__(self, beta_min=0.1, beta_max=20, N=1000):
    """Construct a Variance Preserving SDE.

    Args:
      beta_min: value of beta(0)
      beta_max: value of beta(1)
      N: number of discretization steps
    """
    super().__init__(N)
    self.beta_0 = beta_min
    self.beta_1 = beta_max
    self.N = N
    self.discrete_betas = torch.linspace(beta_min / N, beta_max / N, N).cuda()
    self.alphas = 1. - self.discrete_betas
    self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
    self.alphas_cumprod_ext = torch.cat([torch.tensor([1.0 - 1e-4]).cuda(), self.alphas_cumprod], dim=0)

    self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
    self.sqrt_1m_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)

  @property
  def T(self):
    return 1

  def sde(self, x, t):
    beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0)
    drift = -0.5 * beta_t[:, None, None, None, None] * x
    diffusion = torch.sqrt(beta_t)
    return drift, diffusion

  def marginal_prob(self, x, t):
    log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
    mean = torch.exp(log_mean_coeff[:, None, None, None, None]) * x
    std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff))
    return mean, std

  def prior_sampling(self, shape):
    return torch.randn(*shape)

  def prior_logp(self, z):
    shape = z.shape
    N = np.prod(shape[1:])
    logps = -N / 2. * np.log(2 * np.pi) - torch.sum(z ** 2, dim=(1, 2, 3, 4)) / 2.
    return logps

  def discretize(self, x, t):
    """DDPM discretization."""
    timestep = (t * (self.N - 1) / self.T).long()
    beta = self.discrete_betas.to(x.device)[timestep]
    alpha = self.alphas.to(x.device)[timestep]
    sqrt_beta = torch.sqrt(beta)
    f = torch.sqrt(alpha)[:, None, None, None, None] * x - x
    G = sqrt_beta
    return f, G
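VPSDE maintains both a discrete schedule (`discrete_betas`, `alphas_cumprod`) and a continuous marginal (`marginal_prob`); for the default beta_min=0.1, beta_max=20, N=1000 the two should agree closely. A pure-Python consistency check of that correspondence (no torch required; tolerances are illustrative):

```python
import math

# sqrt(alphas_cumprod[k]) should approximate the continuous mean coefficient
# exp(-0.25 t^2 (beta_1 - beta_0) - 0.5 t beta_0) at t = (k + 1) / N,
# since log(1 - beta) ~ -beta for the small per-step betas used here.
beta_min, beta_max, N = 0.1, 20.0, 1000

betas = [beta_min / N + (beta_max / N - beta_min / N) * i / (N - 1) for i in range(N)]
log_alphas_cumprod = 0.0
max_rel_err = 0.0
for k in range(N):
    log_alphas_cumprod += math.log(1.0 - betas[k])
    t = (k + 1) / N
    cont = math.exp(-0.25 * t * t * (beta_max - beta_min) - 0.5 * t * beta_min)
    disc = math.exp(0.5 * log_alphas_cumprod)
    max_rel_err = max(max_rel_err, abs(disc - cont) / cont)

# The discrete and continuous schedules agree to within a few percent.
assert max_rel_err < 0.05
```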

================================================
FILE: GMeshDiffusion/lib/diffusion/trainer.py
================================================
import os
import sys
import numpy as np

import logging
# Keep the import below for registering all model definitions
from .models import unet3d_occgrid

from . import losses
from .models import utils as mutils
from .models.ema import ExponentialMovingAverage
from . import sde_lib
import torch
from torch.utils import tensorboard
from .utils import save_checkpoint, restore_checkpoint
from ..dataset.gshell_dataset import GShellDataset
from ..dataset.gshell_dataset_aug import GShellAugDataset


def train(config):
    """Runs the training pipeline.

    Args:
        config: Configuration to use. `config.training.train_dir` is the working
            directory for checkpoints and TensorBoard summaries; if it contains a
            checkpoint, training resumes from the latest one.
    """

    workdir = config.training.train_dir
    # Create directories for experimental logs
    logging.info("working dir: {:s}".format(workdir))


    tb_dir = os.path.join(workdir, "tensorboard")
    writer = tensorboard.SummaryWriter(tb_dir)

    # Initialize model.
    score_model = mutils.create_model(config)
    ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate)
    optimizer = losses.get_optimizer(config, score_model.parameters())
    gradscaler = torch.cuda.amp.GradScaler(enabled=True)

    state = dict(optimizer=optimizer, model=score_model, ema=ema, gradscaler=gradscaler, step=0)


    # Create checkpoints directory
    checkpoint_dir = os.path.join(workdir, "checkpoints")
    # Intermediate checkpoints to resume training after pre-emption in cloud environments
    checkpoint_meta_dir = os.path.join(workdir, "checkpoints-meta", "checkpoint.pth")
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(os.path.dirname(checkpoint_meta_dir), exist_ok=True)

    # Resume training when intermediate checkpoints are detected
    state = restore_checkpoint(checkpoint_meta_dir, state, config.device)
    initial_step = int(state['step'])

    print(f"work dir: {workdir}")

    
    try:
        use_occ_grid = config.data.use_occ_grid
    except AttributeError:
        use_occ_grid = False
    if use_occ_grid:
        train_dataset = GShellAugDataset(config)
    else:
        train_dataset = GShellDataset(config.data.dataset_metapath)


    collate_fn = getattr(train_dataset, 'collate', None)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, 
        batch_size=config.training.batch_size, 
        shuffle=True,
        num_workers=config.data.num_workers,
        collate_fn=collate_fn,
        pin_memory=True
    )

    data_iter = iter(train_loader)

    print("data loader set")

    # Setup SDEs
    sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)

    # Build one-step training and evaluation functions
    optimize_fn = losses.optimization_manager(config)
    try:
        use_vis_mask = config.model.use_vis_mask
    except AttributeError:
        use_vis_mask = False
    print('use_vis_mask', use_vis_mask)
    train_step_fn = losses.get_step_fn(sde, train=True, optimize_fn=optimize_fn,
                                        loss_type=config.training.loss_type,
                                        pred_type=config.model.pred_type,
                                        use_vis_mask=use_vis_mask,
                                        use_occ=use_occ_grid,
                                        use_aux=config.training.use_aux_loss)

    num_train_steps = config.training.n_iters

    logging.info("Starting training loop at step %d." % (initial_step // config.training.num_grad_acc_steps,))


    iter_size = config.training.num_grad_acc_steps
    for step in range(initial_step // iter_size, num_train_steps + 1):
        tmp_loss_dict = {
            'loss_total': 0.0,
            'loss_score': 0.0,
            'loss_reg': 0.0,
        }
        for step_inner in range(iter_size):
            try:
                # batch, batch_mask = next(data_iter)
                batch = next(data_iter)
            except StopIteration:
                # StopIteration is thrown if dataset ends
                # reinitialize data loader 
                data_iter = iter(train_loader)
                batch = next(data_iter)

            
            if isinstance(batch, dict):
                for k in batch:
                    batch[k] = batch[k].to('cuda', non_blocking=False)
            else:
                batch = batch.to('cuda', non_blocking=False)

            # Execute one training step
            clear_grad_flag = (step_inner == 0)
            update_param_flag = (step_inner == iter_size - 1)
            loss_dict = train_step_fn(state, batch, clear_grad=clear_grad_flag, update_param=update_param_flag, gradscaler=gradscaler)
            for key in loss_dict:
                tmp_loss_dict[key] += loss_dict[key].item() / iter_size

            # print(torch.cuda.memory_summary())

        if step % config.training.log_freq == 0:
            # logging.info("step: %d, training_loss: %.5e" % (step, tmp_loss))
            logging.info(
                "step: %d, loss_total: %.5e, loss_score: %.5e, loss_reg: %.5e" 
                % (step, tmp_loss_dict['loss_total'], tmp_loss_dict['loss_score'], tmp_loss_dict['loss_reg'])
            )
            sys.stdout.flush()
            writer.add_scalar("loss_total", tmp_loss_dict['loss_total'], step)
            writer.add_scalar("loss_score", tmp_loss_dict['loss_score'], step)
            writer.add_scalar("loss_reg", tmp_loss_dict['loss_reg'], step)

        # Save a temporary checkpoint to resume training after pre-emption periodically
        if step != 0 and step % config.training.snapshot_freq_for_preemption == 0:
            logging.info(f"save meta at iter {step}")
            save_checkpoint(checkpoint_meta_dir, state)

        # Save a checkpoint periodically and generate samples if needed
        if (step != 0 and step % config.training.snapshot_freq == 0) or step == num_train_steps:
            logging.info(f"save model: {step}-th")
            save_checkpoint(os.path.join(checkpoint_dir, f'checkpoint_{step}.pth'), state)
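The inner loop above implements gradient accumulation: gradients are cleared only on the first micro-step (`clear_grad`), the optimizer steps only on the last (`update_param`), and each micro-loss contributes 1/iter_size to the logged value. A torch-free sketch of that control flow (the function name is hypothetical):

```python
# Gradient-accumulation control flow used by the training loop above:
# clear gradients on the first micro-step, step the optimizer on the last,
# and average the per-micro-step losses. Pure-Python stand-ins for torch ops.
def run_accumulation_step(micro_losses):
    iter_size = len(micro_losses)
    events = []          # (clear_grad, update_param) flags per micro-step
    total_loss = 0.0
    for step_inner, loss in enumerate(micro_losses):
        clear_grad = (step_inner == 0)
        update_param = (step_inner == iter_size - 1)
        events.append((clear_grad, update_param))
        total_loss += loss / iter_size
    return events, total_loss

events, total = run_accumulation_step([1.0, 2.0, 3.0, 6.0])
assert events[0] == (True, False)   # gradients cleared once, at the start
assert events[-1] == (False, True)  # optimizer stepped once, at the end
assert all(e == (False, False) for e in events[1:-1])
assert abs(total - 3.0) < 1e-12     # mean of the micro-step losses
```

With `iter_size == 1` both flags are set on the single micro-step, which is why the non-accumulating case needs no special path.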


================================================
FILE: GMeshDiffusion/lib/diffusion/trainer_ddp.py
================================================
import os
import sys
import numpy as np

import logging
# Keep the import below for registering all model definitions
from .models import unet3d_occgrid

from . import losses
from .models import utils as mutils
from .models.ema import ExponentialMovingAverage
from . import sde_lib
import torch
from torch.utils import tensorboard
from .utils import save_checkpoint, restore_checkpoint
from ..dataset.gshell_dataset import GShellDataset
from ..dataset.gshell_dataset_aug import GShellAugDataset

import torch.distributed as dist

def train(config):
    """Runs the DDP training pipeline.

    Args:
        config: Configuration to use. `config.training.train_dir` is the working
            directory for checkpoints and TensorBoard summaries; if it contains a
            checkpoint, training resumes from the latest one.
    """
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)
    device = torch.device("cuda", rank)
    print(f"Start running basic DDP example on rank {rank}.")

    # Determine the world size and the local device for this rank
    world_size = dist.get_world_size()
    device_id = rank % torch.cuda.device_count()

    workdir = config.training.train_dir
    # Create directories for experimental logs
    logging.info("working dir: {:s}".format(workdir))


    tb_dir = os.path.join(workdir, "tensorboard")
    writer = tensorboard.SummaryWriter(tb_dir)

    # Initialize model.
    score_model = mutils.create_model(config, ddp=True, rank=rank)
    ema = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate)
    optimizer = losses.get_optimizer(config, score_model.parameters())
    gradscaler = torch.cuda.amp.GradScaler(growth_interval=config.training.gradscaler_growth_interval)

    state = dict(optimizer=optimizer, model=score_model, ema=ema, gradscaler=gradscaler, step=0)


    # Create checkpoints directory
    checkpoint_dir = os.path.join(workdir, "checkpoints")
    # Intermediate checkpoints to resume training after pre-emption in cloud environments
    checkpoint_meta_dir = os.path.join(workdir, "checkpoints-meta", "checkpoint.pth")
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(os.path.dirname(checkpoint_meta_dir), exist_ok=True)

    # Resume training when intermediate checkpoints are detected
    state = restore_checkpoint(checkpoint_meta_dir, state, config.device, rank=rank)
    initial_step = int(state['step'])

    print(f"work dir: {workdir}")

    try:
        use_occ_grid = config.data.use_occ_grid
    except AttributeError:
        use_occ_grid = False
    if use_occ_grid:
        train_dataset = GShellAugDataset(config)
    else:
        train_dataset = GShellDataset(config.data.dataset_metapath)

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=world_size,
        rank=rank
    )

    collate_fn = getattr(train_dataset, 'collate', None)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, 
        batch_size=config.training.batch_size, 
        num_workers=config.data.num_workers,
        # pin_memory=True,
        sampler=train_sampler,
        collate_fn=collate_fn
    )

    data_iter = iter(train_loader)

    print("data loader set")

    # Setup SDEs
    sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales)

    # Build one-step training and evaluation functions
    optimize_fn = losses.optimization_manager(config)
    try:
        use_vis_mask = config.model.use_vis_mask
    except AttributeError:
        use_vis_mask = False
    print('use_vis_mask', use_vis_mask)
    train_step_fn = losses.get_step_fn(sde, train=True, optimize_fn=optimize_fn,
                                        loss_type=config.training.loss_type,
                                        pred_type=config.model.pred_type,
                                        use_vis_mask=use_vis_mask,
                                        use_occ=use_occ_grid,
                                        use_aux=config.training.use_aux_loss)

    num_train_steps = config.training.n_iters

    # In case there are multiple hosts (e.g., TPU pods), only log to host 0
    logging.info("Starting training loop at step %d." % (initial_step // config.training.num_grad_acc_steps,))

    iter_size = config.training.num_grad_acc_steps
    epoch = 0
    train_sampler.set_epoch(epoch)
    for step in range(initial_step // iter_size, num_train_steps + 1):
        tmp_loss_dict = {
            'loss_total': 0.0,
            'loss_score': 0.0,
            'loss_reg': 0.0,
        }
        for step_inner in range(iter_size):
            try:
                # batch, batch_mask = next(data_iter)
                batch = next(data_iter)
            except StopIteration:
                # StopIteration is thrown if dataset ends
                # reinitialize data loader 
                epoch += 1
                train_sampler.set_epoch(epoch)
                data_iter = iter(train_loader)
                batch = next(data_iter)

            if isinstance(batch, dict):
                for k in batch:
                    batch[k] = batch[k].to(rank, non_blocking=False)
            else:
                batch = batch.to(rank, non_blocking=False)

            # Execute one training step
            clear_grad_flag = (step_inner == 0)
            update_param_flag = (step_inner == iter_size - 1)
            if not update_param_flag:
                with score_model.no_sync():
                    loss_dict = train_step_fn(state, batch, clear_grad=clear_grad_flag, update_param=update_param_flag, gradscaler=gradscaler)
            else:
                loss_dict = train_step_fn(state, batch, clear_grad=clear_grad_flag, update_param=update_param_flag, gradscaler=gradscaler)
            for key in loss_dict:
                tmp_loss_dict[key] += loss_dict[key].item() / iter_size

            # print(torch.cuda.memory_summary())

        if step % config.training.log_freq == 0:
            loss = tmp_loss_dict['loss_total']
            loss = torch.tensor(loss / world_size).to(rank)

            # logging.info("step: %d, training_loss: %.5e" % (step, tmp_loss))
            dist.reduce(loss, dst=0, op=dist.ReduceOp.SUM)
            if rank == 0:
                loss = loss.item()
                logging.info("step: %d, loss: %.5e, scale: %.5e" % (step, loss, gradscaler.get_scale()))
                sys.stdout.flush()
                writer.add_scalar("loss", loss, step)

        if rank == 0:
            # Save a temporary checkpoint to resume training after pre-emption periodically
            if step != 0 and step % config.training.snapshot_freq_for_preemption == 0:
                logging.info(f"save meta at iter {step}")
                save_checkpoint(checkpoint_meta_dir, state)

            # Save a checkpoint periodically and generate samples if needed
            if (step != 0 and step % config.training.snapshot_freq == 0) or step == num_train_steps:
                logging.info(f"save model: {step}-th")
                save_checkpoint(os.path.join(checkpoint_dir, f'checkpoint_{step}.pth'), state)

    dist.destroy_process_group()

================================================
FILE: GMeshDiffusion/lib/diffusion/utils.py
================================================
import torch
import os
import logging

from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present


def restore_checkpoint(ckpt_dir, state, device, strict=False, rank=None):
  if not os.path.exists(ckpt_dir):
    os.makedirs(os.path.dirname(ckpt_dir), exist_ok=True)
    logging.warning(f"No checkpoint found at {ckpt_dir}. "
                    f"Returned the same state as input")
    if strict:
      raise FileNotFoundError(f"No checkpoint found at {ckpt_dir}")
    return state
  else:
    if rank is not None:
      device = f"cuda:{rank}"
    # loaded_state = torch.load(ckpt_dir, map_location=device)
    loaded_state = torch.load(ckpt_dir, map_location='cpu')
    state['optimizer'].load_state_dict(loaded_state['optimizer'])
    try:
      state['model'].load_state_dict(loaded_state['model'], strict=False)
    except RuntimeError:
      # Strip the 'module.' prefix left by DDP-wrapped models before retrying
      consume_prefix_in_state_dict_if_present(loaded_state['model'], 'module.')
      state['model'].load_state_dict(loaded_state['model'], strict=False)
    state['ema'].load_state_dict(loaded_state['ema'], device=device)
    state['step'] = loaded_state['step']
    state['model'].to(device)
    try:
      state['gradscaler'].load_state_dict(loaded_state['gradscaler'])
    except KeyError:
      # Older checkpoints may not contain a saved GradScaler state
      pass
    torch.cuda.empty_cache()
    return state


def save_checkpoint(ckpt_dir, state):
  saved_state = {
    'optimizer': state['optimizer'].state_dict(),
    'model': state['model'].state_dict(),
    'ema': state['ema'].state_dict(),
    'step': state['step'],
    'gradscaler': state['gradscaler'].state_dict()
  }
  torch.save(saved_state, ckpt_dir)
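`save_checkpoint` writes straight to the final path, so a preemption mid-write can corrupt the checkpoints-meta file that `restore_checkpoint` later reads. A common hardening, shown here as a sketch rather than repo code (pickle stands in for `torch.save`), writes to a temporary file and atomically renames it into place:

```python
import os
import pickle
import tempfile

def save_checkpoint_atomic(ckpt_path, state):
    # Write to a temp file in the same directory, then os.replace() it into
    # place; os.replace is atomic on POSIX, so readers never see a partial file.
    dirname = os.path.dirname(ckpt_path) or "."
    fd, tmp_path = tempfile.mkstemp(dir=dirname, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            pickle.dump(state, f)  # torch.save(state, f) in the real trainer
        os.replace(tmp_path, ckpt_path)
    except BaseException:
        os.unlink(tmp_path)
        raise

path = os.path.join(tempfile.gettempdir(), "ckpt_demo.pth")
save_checkpoint_atomic(path, {"step": 7})
with open(path, "rb") as f:
    assert pickle.load(f)["step"] == 7
os.unlink(path)
```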

================================================
FILE: GMeshDiffusion/main_diffusion.py
================================================
"""Training and evaluation"""

from absl import app
from absl import flags
from ml_collections.config_flags import config_flags

import lib.diffusion.trainer as trainer
import lib.diffusion.evaler as evaler


FLAGS = flags.FLAGS

config_flags.DEFINE_config_file(
    "config", None, "diffusion configs", lock_config=False)
flags.DEFINE_enum("mode", None, ["train", "uncond_gen", "cond_gen", "uncond_gen_interp"], "Running mode")
flags.mark_flags_as_required(["config", "mode"])


def main(argv):
    if FLAGS.mode == 'train':
        trainer.train(FLAGS.config)
    elif FLAGS.mode == 'uncond_gen':
        evaler.uncond_gen(FLAGS.config)
    elif FLAGS.mode == 'uncond_gen_interp':
        evaler.uncond_gen_interp(FLAGS.config)
    elif FLAGS.mode == 'cond_gen':
        evaler.cond_gen(FLAGS.config)

if __name__ == "__main__":
  app.run(main)


================================================
FILE: GMeshDiffusion/main_diffusion_ddp.py
================================================
"""Training and evaluation"""

from absl import app
from absl import flags
from ml_collections.config_flags import config_flags

import lib.diffusion.trainer_ddp as trainer
import lib.diffusion.evaler as evaler




FLAGS = flags.FLAGS

config_flags.DEFINE_config_file(
    "config", None, "diffusion configs", lock_config=False)
flags.DEFINE_enum("mode", None, ["train", "uncond_gen", "cond_gen", "uncond_gen_interp"], "Running mode")
flags.mark_flags_as_required(["config", "mode"])

def main(argv):
    print(FLAGS.config)
    if FLAGS.mode == 'train':
        trainer.train(FLAGS.config)
    else:
        raise NotImplementedError(f"mode '{FLAGS.mode}' is only supported by main_diffusion.py")

if __name__ == "__main__":
  app.run(main)


================================================
FILE: GMeshDiffusion/metadata/get_splits_lower.py
================================================
import os
import random

random.seed(42)

split_ratio = 0.9
data_root = 'PLACEHOLDER'
grid_root = os.path.join(data_root, 'grid')
occgrid_root = os.path.join(data_root, 'grid_aug')
data_path_list = sorted([os.path.join(grid_root, fpath) for fpath in os.listdir(grid_root)])

random.shuffle(data_path_list)

n_train = int(len(data_path_list) * split_ratio)
train_list = data_path_list[:n_train]
test_list = data_path_list[n_train:]

with open('lower_res64_grid_train.txt', 'w') as f:
    f.write('\n'.join(train_list))

with open('lower_res64_grid_test.txt', 'w') as f:
    f.write('\n'.join(test_list))


occgrid_train_list = [os.path.join(occgrid_root, os.path.basename(x)) for x in train_list]
occgrid_test_list = [os.path.join(occgrid_root, os.path.basename(x)) for x in test_list]

with open('lower_res64_occgrid_train.txt', 'w') as f:
    f.write('\n'.join(occgrid_train_list))

with open('lower_res64_occgrid_test.txt', 'w') as f:
    f.write('\n'.join(occgrid_test_list))
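Because the file list is sorted before the seeded shuffle, the 90/10 split above is reproducible across runs and machines. A self-contained check of that property with dummy file names (`split_files` is a hypothetical helper mirroring the script's logic, using a `random.Random` instance instead of the module-level seed):

```python
import random

def split_files(names, seed=42, split_ratio=0.9):
    # Sort first so the shuffle sees a canonical order, then shuffle with a
    # fixed seed: the resulting train/test split is identical across runs.
    paths = sorted(names)
    rng = random.Random(seed)
    rng.shuffle(paths)
    n_train = int(len(paths) * split_ratio)
    return paths[:n_train], paths[n_train:]

names = ["grid_%03d.pt" % i for i in range(20)]
train_a, test_a = split_files(names)
train_b, test_b = split_files(list(reversed(names)))  # input order is irrelevant
assert (train_a, test_a) == (train_b, test_b)
assert len(train_a) == 18 and len(test_a) == 2
assert sorted(train_a + test_a) == sorted(names)
```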



================================================
FILE: GMeshDiffusion/metadata/get_splits_upper.py
================================================
import os
import random

random.seed(42)

split_ratio = 0.9
data_root = 'PLACEHOLDER'
grid_root = os.path.join(data_root, 'grid')
occgrid_root = os.path.join(data_root, 'grid_aug')
data_path_list = sorted([os.path.join(grid_root, fpath) for fpath in os.listdir(grid_root)])

random.shuffle(data_path_list)

n_train = int(len(data_path_list) * split_ratio)
train_list = data_path_list[:n_train]
test_list = data_path_list[n_train:]

with open('upper_res64_grid_train.txt', 'w') as f:
    f.write('\n'.join(train_list))

with open('upper_res64_grid_test.txt', 'w') as f:
    f.write('\n'.join(test_list))


occgrid_train_list = [os.path.join(occgrid_root, os.path.basename(x)) for x in train_list]
occgrid_test_list = [os.path.join(occgrid_root, os.path.basename(x)) for x in test_list]

with open('upper_res64_occgrid_train.txt', 'w') as f:
    f.write('\n'.join(occgrid_train_list))

with open('upper_res64_occgrid_test.txt', 'w') as f:
    f.write('\n'.join(occgrid_test_list))



================================================
FILE: GMeshDiffusion/metadata/save_tet_info.py
================================================
'''
    Storing tet-grid related meta-info into a single file
'''

import numpy as np
import torch
import os
import tqdm
import argparse

from itertools import combinations


def tet_to_grids(vertices, values_list, grid_size):
    grid = torch.zeros(12, grid_size, grid_size, grid_size, device=vertices.device)
    with torch.no_grad():
        for k, values in enumerate(values_list):
            if k == 0:
                grid[k, vertices[:, 0], vertices[:, 1], vertices[:, 2]] = values.squeeze()
            else:
                grid[1:4, vertices[:, 0], vertices[:, 1], vertices[:, 2]] = values.transpose(0, 1)
    return grid

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='nvdiffrec')
    parser.add_argument('-res', '--resolution', type=int)
    parser.add_argument('-r', '--root', type=str)
    parser.add_argument('-s', '--source', type=str)
    parser.add_argument('-t', '--target', type=str)
    FLAGS = parser.parse_args()

    tet_path = f'./tets/{FLAGS.resolution}_tets_cropped_reordered.npz'
    tet = np.load(tet_path)
    vertices = torch.tensor(tet['vertices']).cuda()
    indices = torch.tensor(tet['indices']).long().cuda()

    edges = torch.tensor(tet['edges']).long().cuda()
    tet_edges = torch.tensor(tet['tet_edges']).long().view(-1, 2).cuda()

    vertices_unique = vertices.unique()
    dx = vertices_unique[1] - vertices_unique[0]
    dx = dx / 2.0  # denser grid
    vertices_discretized = (
        ((vertices - vertices.min()) / dx)
    ).long()

    midpoints = (vertices_discretized[edges[:, 0]] + vertices_discretized[edges[:, 1]]) / 2.0
    midpoints_discretized = midpoints.long()

    tet_verts = vertices_discretized[indices.view(-1)].view(-1, 4, 3)
    tet_center = tet_verts.float().mean(dim=1)
    tet_center_discretized = tet_center.long()


    edge_ind_list = [[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]
    msdf_tetedges = []
    msdf_from_tetverts = []
    for i in range(5):
        for j in range(i+1, 6):
            if (edge_ind_list[i][0] == edge_ind_list[j][0]
                or edge_ind_list[i][0] == edge_ind_list[j][1]
                or edge_ind_list[i][1] == edge_ind_list[j][0]
                or edge_ind_list[i][1] == edge_ind_list[j][1]
            ):
                msdf_tetedges.append(i)
                msdf_tetedges.append(j)
                msdf_from_tetverts.extend([edge_ind_list[i][0], edge_ind_list[i][1], edge_ind_list[j][0], edge_ind_list[j][1]])
    msdf_tetedges = torch.tensor(msdf_tetedges)
    msdf_from_tetverts = torch.tensor(msdf_from_tetverts)
    print(msdf_tetedges)
    print(msdf_tetedges.size())



    tet_edges = tet_edges.view(-1, 2)
    msdf_tetedges = msdf_tetedges.view(-1)
    tet_edgenodes_pos = (vertices_discretized[tet_edges[:, 0]] + vertices_discretized[tet_edges[:, 1]]) / 2.0
    tet_edgenodes_pos = tet_edgenodes_pos.view(-1, 6, 2)
    occ_edge_pos = tet_edgenodes_pos[:, msdf_tetedges].view(-1, 12, 2, 3)
    

    edge_twopoint_order = torch.sign(occ_edge_pos[:, :, 0, :] - occ_edge_pos[:, :, 1, :])
    edge_twopoint_order_binary_code = (edge_twopoint_order * torch.tensor([16, 4, 1], device=edge_twopoint_order.device).view(1, 1, -1)).sum(dim=-1)
    edge_twopoint_order_binary_code = torch.stack([edge_twopoint_order_binary_code, -edge_twopoint_order_binary_code], dim=-1)
    _, edge_twopoint_order = edge_twopoint_order_binary_code.sort(dim=-1)

    occ_edge_cano_order = torch.arange(2).view(1, 1, 2).expand(occ_edge_pos.size(0), 12, 2).cuda()
    occ_edge_cano_order = torch.gather(
        input=occ_edge_cano_order,
        dim=-1,
        index=edge_twopoint_order
    )
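The canonical ordering above hinges on encoding the per-axis sign of the endpoint difference, a triple in {-1, 0, 1}^3, into one integer with weights [16, 4, 1]; since |4b + c| <= 5 < 16 and |c| <= 1 < 4, the code 16a + 4b + c never collides, and negating it corresponds to swapping the two endpoints. A standalone check of both properties:

```python
from itertools import product

# The weights [16, 4, 1] make code = 16*a + 4*b + c unique over
# (a, b, c) in {-1, 0, 1}^3: |4*b + c| <= 5 < 16 and |c| <= 1 < 4,
# so no two sign triples collide. This is what lets the script sort
# edge endpoints into a canonical order by the scalar code alone.
codes = {16 * a + 4 * b + c for a, b, c in product((-1, 0, 1), repeat=3)}
assert len(codes) == 27  # 3^3 distinct sign triples -> 27 distinct codes

# Swapping the endpoints flips every sign, hence negates the code, so
# sorting [code, -code] orders the two endpoints consistently either way.
for a, b, c in product((-1, 0, 1), repeat=3):
    assert 16 * (-a) + 4 * (-b) + (-c) == -(16 * a + 4 * b + c)
```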

    tet_edges = tet_edges.view(-1)

    torch.save({
        'tet_v_pos': vertices,
        'tet_edge_vpos': vertices[tet_edges].view(-1, 2, 3),
        'tet_edge_pix_loc': vertices_discretized[tet_edges].view(-1, 2, 3),
        'tet_center_loc': tet_center_discretized,
        'msdf_edges': msdf_tetedges.view(12, 2),
        'occ_edge_cano_order': occ_edge_cano_order
    }, 'tet_info.pt')


================================================
FILE: GMeshDiffusion/metadata/tet_to_cubic_grid_dataset.py
================================================
import numpy as np
import torch
import os
import tqdm
import argparse

def tet_to_grids(vertices, values_list, grid_size):
    grid = torch.zeros(4, grid_size, grid_size, grid_size, device=vertices.device)
    with torch.no_grad():
        for k, values in enumerate(values_list):
            if k == 0:
                grid[k, vertices[:, 0], vertices[:, 1], vertices[:, 2]] = values.squeeze()
            else:
                grid[1:4, vertices[:, 0], vertices[:, 1], vertices[:, 2]] = values.transpose(0, 1)
    return grid

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='nvdiffrec')
    parser.add_argument('-res', '--resolution', type=int)
    parser.add_argument('-ss', '--split-size', type=int, default=int(1e8))
    parser.add_argument('-ind', '--index', type=int)
    parser.add_argument('-r', '--root', type=str)
    parser.add_argument('-s', '--source', type=str)
    parser.add_argument('-t', '--target', type=str)
    FLAGS = parser.parse_args()

    tet_path = f'./tets/{FLAGS.resolution}_tets_cropped_reordered.npz'
    tet = np.load(tet_path)
    vertices = torch.tensor(tet['vertices']).cuda()
    indices = torch.tensor(tet['indices']).long().cuda()

    edges = torch.tensor(tet['edges']).long().cuda()
    tet_edges = torch.tensor(tet['tet_edges']).long().view(-1, 2).cuda()
    
    vertices_unique = vertices.unique()
    dx = vertices_unique[1] - vertices_unique[0]
    dx = dx / 2.0  # denser grid
    vertices_discretized = (
        ((vertices - vertices.min()) / dx)
    ).long()

    print(vertices_discretized.size())
    midpoints = (vertices_discretized[edges[:, 0]] + vertices_discretized[edges[:, 1]]) / 2.0
    midpoints_discretized = midpoints.long()

    tet_verts = vertices_discretized[indices.view(-1)].view(-1, 4, 3)
    tet_center = tet_verts.float().mean(dim=1)
    tet_center_discretized = tet_center.long()


    global_mask = torch.zeros(4, FLAGS.resolution * 2, FLAGS.resolution * 2, FLAGS.resolution * 2).cuda()
    cat_mask = torch.zeros(FLAGS.resolution * 2, FLAGS.resolution * 2, FLAGS.resolution * 2).cuda()
    global_mask[:4, vertices_discretized[:, 0], vertices_discretized[:, 1], vertices_discretized[:, 2]] += 1.0
    cat_mask[vertices_discretized[:, 0], vertices_discretized[:, 1], vertices_discretized[:, 2]] = 1
    global_mask[0, midpoints_discretized[:, 0], midpoints_discretized[:, 1], midpoints_discretized[:, 2]] += 1.0
    cat_mask[midpoints_discretized[:, 0], midpoints_discretized[:, 1], midpoints_discretized[:, 2]] = -1


    torch.save(global_mask, f'global_mask_res{FLAGS.resolution}.pt')
    torch.save(cat_mask, f'cat_mask_res{FLAGS.resolution}.pt')

    save_folder = FLAGS.root

    grid_folder_base = os.path.join(save_folder, FLAGS.target)
    os.makedirs(grid_folder_base, exist_ok=True)

    print(grid_folder_base)

    edge_ind_list = [[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]
    msdf_tetedges = []
    msdf_from_tetverts = []
    for i in range(5):
        for j in range(i+1, 6):
            if (edge_ind_list[i][0] == edge_ind_list[j][0]
                or edge_ind_list[i][0] == edge_ind_list[j][1]
                or edge_ind_list[i][1] == edge_ind_list[j][0]
                or edge_ind_list[i][1] == edge_ind_list[j][1]
            ):
                msdf_tetedges.append(i)
                msdf_tetedges.append(j)
                msdf_from_tetverts.extend([edge_ind_list[i][0], edge_ind_list[i][1], edge_ind_list[j][0], edge_ind_list[j][1]])
    msdf_tetedges = torch.tensor(msdf_tetedges)
    msdf_from_tetverts = torch.tensor(msdf_from_tetverts)
    print(msdf_tetedges)
    print(msdf_tetedges.size())



    occgrid_mask_already_saved = False
    tets_folder = os.path.join(save_folder, FLAGS.source)

    with torch.no_grad():
        for k in tqdm.trange(FLAGS.split_size):
            global_index = k + FLAGS.index * FLAGS.split_size
            tet_path = os.path.join(tets_folder, 'dmt_dict_{:05d}.pt'.format(global_index))
            if os.path.exists(os.path.join(grid_folder_base, 'grid_{:05d}.pt'.format(global_index))):
                pass  # change to `continue` to skip grids that were already exported
            try:
                if os.path.exists(tet_path):
                    tet = torch.load(tet_path, map_location="cuda")

                    sdf = tet['sdf'].view(-1, 1)
                    msdf = tet['msdf'].view(-1, 1)
                    deform = tet['deform']


                    ### resetting sdfs and offsets of all non-mesh-generating tet vertices
                    tet_edges = tet_edges.view(-1, 2)
                    tet_edge_mask = ((torch.sign(sdf[tet_edges[:, 0]]) - torch.sign(sdf[tet_edges[:, 1]])) != 0).bool().squeeze(-1).view(-1, 6)
                    tet_sdf_coeff = (
                        torch.abs(sdf[tet_edges[:, 0]]) 
                        / (torch.abs(sdf[tet_edges[:, 0]] - sdf[tet_edges[:, 1]]) + 1e-10)
                    ).squeeze(-1)
                    tet_sdf_coeff = tet_sdf_coeff.view(-1, 1)
                    midpoint_msdf_tet = msdf[tet_edges[:, 0]] * (1 - tet_sdf_coeff) + msdf[tet_edges[:, 1]] * tet_sdf_coeff
                    midpoint_msdf_tet = midpoint_msdf_tet.view(-1, 6)
                    tet_mask = ((midpoint_msdf_tet > 0) & tet_edge_mask).sum(dim=-1).bool()
                    vert_mask = torch.zeros_like(sdf.squeeze())
                    vert_mask[indices[tet_mask].view(-1)] = 1.0
                    vert_mask = ~vert_mask.bool()
                    msdf[vert_mask] = -1.0
                    deform[vert_mask] = 0.0

                    tet_nonallnegmsdf = (torch.sign(msdf[indices.view(-1)].view(-1, 4)).sum(dim=-1) != -4)
                    vert_mask_nonallnegmsdf = torch.zeros_like(sdf.squeeze())
                    vert_mask_nonallnegmsdf[indices[tet_nonallnegmsdf].view(-1)] = 1.0
                    vert_mask_nonallnegmsdf = ~vert_mask_nonallnegmsdf.bool()
                    sdf[vert_mask_nonallnegmsdf] = 1.0
                    

                    

                    mask = (
                        (torch.sign(sdf[edges[:, 0]]) - torch.sign(sdf[edges[:, 1]]) != 0).bool()
                    )

                    nan_mask = (
                            ((torch.sign(sdf[edges[:, 0]]) + torch.sign(sdf[edges[:, 1]])) == 2)
                            | ((torch.sign(sdf[edges[:, 0]]) + torch.sign(sdf[edges[:, 1]])) == -2) 
                        ).bool().squeeze(-1)

                    original_sdf_coeff = torch.abs(sdf[edges[:, 0]]) / (torch.abs(sdf[edges[:, 0]] - sdf[edges[:, 1]]) + 1e-10)


                    original_sdf_coeff[nan_mask] = torch.nan

                    normalized_sdf_coeff = ((original_sdf_coeff - 0.5) * 2.0)
                    normalized_sdf_coeff = torch.nan_to_num(normalized_sdf_coeff)
                    assert torch.all(normalized_sdf_coeff.abs() <= 1.0)


                    sdf_sign = torch.sign(sdf)
                    sdf_sign[sdf_sign == 0] = 1

                    midpoint_msdf = msdf[edges[:, 0]] * (1 - original_sdf_coeff.view(-1, 1)) + msdf[edges[:, 1]] * original_sdf_coeff.view(-1, 1)
                    midpoint_msdf_sign = torch.sign(midpoint_msdf)
                    midpoint_msdf_sign[midpoint_msdf_sign == 0] = -1
                    midpoint_msdf_sign = midpoint_msdf_sign * mask - (1.0 - mask.float())

                    ############################ Occ Grid ############################


                    tet_edges = tet_edges.view(-1, 2)
                    tet_edge_mask = ((torch.sign(sdf[tet_edges[:, 0]]) - torch.sign(sdf[tet_edges[:, 1]])) != 0).bool().squeeze(-1).view(-1, 6)
                    tet_sdf_coeff = (
                        torch.abs(sdf[tet_edges[:, 0]]) 
                        / (torch.abs(sdf[tet_edges[:, 0]] - sdf[tet_edges[:, 1]]) + 1e-10)
                    ).squeeze(-1)
                    tet_sdf_coeff = tet_sdf_coeff * tet_edge_mask.view(-1)
                    tet_sdf_coeff = tet_sdf_coeff.view(-1, 1)
                    nan_mask = (
                            ((torch.sign(sdf[tet_edges[:, 0]]) + torch.sign(sdf[tet_edges[:, 1]])) == 2)
                            | ((torch.sign(sdf[tet_edges[:, 0]]) + torch.sign(sdf[tet_edges[:, 1]])) == -2) 
                        ).bool().squeeze(-1)
                    tet_sdf_coeff[nan_mask] = torch.nan
                    midpoint_msdf_tet = msdf[tet_edges[:, 0]] * (1 - tet_sdf_coeff) + msdf[tet_edges[:, 1]] * tet_sdf_coeff
                    midpoint_msdf_tet = midpoint_msdf_tet.view(-1, 6)
                    inscribed_edge_twopoint_msdf = midpoint_msdf_tet[:, msdf_tetedges.view(-1)].view(-1, 12, 2)

                    assert ((
                        (tet_edges.view(-1, 6, 2)[:, msdf_tetedges.view(-1), :].view(-1, 24, 2).sum(dim=-1)) - indices[:, msdf_from_tetverts].view(-1, 24, 2).sum(dim=-1)
                    ).sum().item() == 0)

                    assert msdf_tetedges.view(-1).size(0) == 24
                    inscribed_tet_fourpoint_pos = vertices_discretized[indices[:, msdf_from_tetverts].view(-1)].view(-1, 12, 4, 3).to(torch.float64)
                    inscribed_edge_twopoint_pos = inscribed_tet_fourpoint_pos.view(-1, 12, 2, 2, 3).mean(dim=-2)
                    occgrid_loc = inscribed_edge_twopoint_pos.mean(dim=-2)
                    occgrid_loc = (occgrid_loc * 2).to(torch.int64).view(-1, 3)


                    edge_twopoint_order = torch.sign(inscribed_edge_twopoint_pos[:, :, 0, :] - inscribed_edge_twopoint_pos[:, :, 1, :])
                    edge_twopoint_order_binary_code = (edge_twopoint_order * torch.tensor([16, 4, 1], device=edge_twopoint_order.device).view(1, 1, -1)).sum(dim=-1)
                    edge_twopoint_order_binary_code = torch.stack([edge_twopoint_order_binary_code, -edge_twopoint_order_binary_code], dim=-1)
                    _, edge_twopoint_order = edge_twopoint_order_binary_code.sort(dim=-1)

                    inscribed_edge_twopoint_msdf = torch.gather(
                        input=inscribed_edge_twopoint_msdf,
                        dim=-1,
                        index=edge_twopoint_order
                    )

                    mask_msdf = (
                        ((inscribed_edge_twopoint_msdf[:, :, 0] > 0) & (inscribed_edge_twopoint_msdf[:, :, 1] <= 0)) |
                        ((inscribed_edge_twopoint_msdf[:, :, 0] <= 0) & (inscribed_edge_twopoint_msdf[:, :, 1] > 0)) 
                    )
                    msdf_coeff_12 = (
                        torch.abs(inscribed_edge_twopoint_msdf[:, :, 0]) 
                        / (
                            torch.abs(inscribed_edge_twopoint_msdf[:, :, 0] - inscribed_edge_twopoint_msdf[:, :, 1])
                            + 1e-10
                        )
                    )

                    msdf_coeff_12 = (msdf_coeff_12 - 0.5) * 2.0 * mask_msdf
                    msdf_coeff_12 = torch.nan_to_num(msdf_coeff_12)

                    occ_grid = torch.zeros(256, 256, 256, dtype=torch.float, device=msdf_coeff_12.device)
                    occ_grid[occgrid_loc[:, 0], occgrid_loc[:, 1], occgrid_loc[:, 2]] = msdf_coeff_12.view(-1).to(torch.float)

                    if not occgrid_mask_already_saved:
                        occ_grid_mask = torch.zeros(256, 256, 256, dtype=torch.float, device=msdf_coeff_12.device)
                        occ_grid_mask[occgrid_loc[:, 0], occgrid_loc[:, 1], occgrid_loc[:, 2]] = 1
                        torch.save(occ_grid_mask, f'occ_mask_res{FLAGS.resolution}.pt')
                        occgrid_mask_already_saved = True



                    # #################

                    torch.cuda.empty_cache()
                    grid = torch.zeros(4, FLAGS.resolution * 2, FLAGS.resolution * 2, FLAGS.resolution * 2).cuda()
                    grid[0, vertices_discretized[:, 0], vertices_discretized[:, 1], vertices_discretized[:, 2]] = sdf_sign.squeeze()
                    grid[1:4, vertices_discretized[:, 0], vertices_discretized[:, 1], vertices_discretized[:, 2]] = deform.transpose(0, 1)
                    grid[0, midpoints_dicretized[:, 0], midpoints_dicretized[:, 1], midpoints_dicretized[:, 2]] = midpoint_msdf_sign.squeeze()

                    assert grid.abs().max() <= 1

                    save_path = os.path.join(grid_folder_base, 'grid_{:05d}.pt'.format(global_index))
                    torch.save(grid, save_path)

                    save_path = os.path.join(grid_folder_base, 'occgrid_{:05d}.pt'.format(global_index))
                    torch.save(occ_grid, save_path)
                
            except:
                raise

================================================
FILE: GMeshDiffusion/scripts/run_eval_lower_occgrid_normalized.sh
================================================
python main_diffusion.py --mode uncond_gen --config diffusion_configs/config_lower_occgrid_normalized.py \
--config.eval.eval_dir=$EVAL_DIR \
--config.data.root_dir=$REPO_ROOT_DIR \
--config.sampling.method=ddim \
--config.eval.ckpt_path=$CKPT_PATH \
--config.eval.bin_size=30 \
--config.eval.idx $1

================================================
FILE: GMeshDiffusion/scripts/run_eval_upper_occgrid_normalized.sh
================================================
python main_diffusion.py --mode uncond_gen --config diffusion_configs/config_upper_occgrid_normalized.py \
--config.eval.eval_dir=$EVAL_DIR \
--config.data.root_dir=$REPO_ROOT_DIR \
--config.sampling.method=ddim \
--config.eval.ckpt_path=$CKPT_PATH \
--config.eval.bin_size=10 \
--config.eval.idx $1

================================================
FILE: GMeshDiffusion/scripts/run_lower_occgrid_normalized_ddp.sh
================================================
torchrun --nnodes=1 --nproc_per_node=8 main_diffusion_ddp.py --mode=train --config=diffusion_configs/config_lower_occgrid_normalized.py \
--config.training.train_dir=$SAVE_DIR --config.data.root_dir=$REPO_ROOT_DIR

================================================
FILE: GMeshDiffusion/scripts/run_upper_occgrid_normalized_ddp.sh
================================================
torchrun --nnodes=1 --nproc_per_node=8 main_diffusion_ddp.py --mode=train --config=diffusion_configs/config_upper_occgrid_normalized.py \
--config.training.train_dir=$SAVE_DIR --config.data.root_dir=$REPO_ROOT_DIR


================================================
FILE: README.md
================================================
<div align="center">
  <img src="assets/gshell_logo.png" width="900"/>
</div>

# Ghost on the Shell: An Expressive Representation of General 3D Shapes


<div align="center">
  <img src="assets/teaser.png" width="900"/>
</div>

## Introduction

This is the official implementation of our paper (ICLR 2024 oral) "Ghost on the Shell: An Expressive Representation of General 3D Shapes" (G-Shell).

G-Shell is a generic and differentiable representation for both watertight and non-watertight meshes. It enables 1) efficient and robust rasterization-based multiview reconstruction and 2) template-free generation of non-watertight meshes.

Please refer to [our project page](https://gshell3d.github.io) and [our paper](https://gshell3d.github.io/static/paper/gshell.pdf) for more details.


## Getting Started

### Requirements


- Python >= 3.8
- CUDA 11.8
- PyTorch == 1.13.1

(Conda installation recommended)

#### Reconstruction

Run the following

```
pip install ninja imageio PyOpenGL glfw xatlas gdown
pip install git+https://github.com/NVlabs/nvdiffrast/
pip install --global-option="--no-networks" git+https://github.com/NVlabs/tiny-cuda-nn#subdirectory=bindings/torch
```

Follow the instructions [here](https://github.com/NVIDIAGameWorks/kaolin/) to install kaolin.

Download the tet-grid files ([res128](https://drive.google.com/file/d/1u5FzpuY_BOAg8-g9lRwvah7mbCBOfNVg/view?usp=sharing), [res256](https://drive.google.com/file/d/1JnFoPEGcTLFJ7OHSWrI72h1H9_yOxUP6/view?usp=sharing)) & [res64 for G-MeshDiffusion](https://drive.google.com/file/d/1YQuU4D-0q8kwrzEfla3hGzBg4erBhand/view?usp=drive_link) to the `data/tets` folder under the root directory. Alternatively, you may follow https://github.com/crawforddoran/quartet and `data/tets/generate_tets.py` to create the tet-grid files.
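
The `.npz` tet-grid files store two arrays, `vertices` and `indices` (this is the layout written by `data/tets/generate_tets.py`). A quick way to sanity-check a file, shown here on a toy single-tet grid rather than a real download:

```python
import numpy as np

# Build a toy tet grid with the same layout that
# data/tets/generate_tets.py writes (keys: 'vertices', 'indices').
# A real file such as data/tets/256_tets.npz can be inspected the same way.
vertices = np.array([[0.0, 0.0, 0.0],
                     [1.0, 0.0, 0.0],
                     [0.0, 1.0, 0.0],
                     [0.0, 0.0, 1.0]])
indices = np.array([[0, 1, 2, 3]])  # one tetrahedron, 4 vertex indices
np.savez_compressed('toy_tets.npz', vertices=vertices, indices=indices)

tets = np.load('toy_tets.npz')
print(sorted(tets.files))      # ['indices', 'vertices']
print(tets['vertices'].shape)  # (num_vertices, 3)
print(tets['indices'].shape)   # (num_tets, 4)
```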

#### Generation

Install the following

- Pytorch3D
- ml_collections

## To-dos

- [x] Code for reconstruction
- [ ] DeepFashion3D multiview image dataset for metallic surfaces
- [x] Code for generative models
- [ ] Code for DeepFashion3D dataset preparation
- [ ] Evaluation code for generative models

## Reconstruction

### Datasets

#### DeepFashion3D mesh dataset

We provide ground-truth images (rendered under realistic environment light with Blender) for 9 instances in [DeepFashion3D-v2 dataset](https://github.com/GAP-LAB-CUHK-SZ/deepFashion3D). The download links for the raw meshes can be found in their repo.

Non-metallic material: [training data](https://drive.google.com/file/d/1LwBqLYzamFLyBIiNpD6kEkvySrq2nruG/view?usp=sharing), [test data](https://drive.google.com/file/d/1-47dH_yJrUzKVdKbJenslpdyHwDI6QVo/view?usp=sharing)


#### NeRF synthetic dataset

Download the [NeRF synthetic dataset archive](https://drive.google.com/uc?export=download&id=18JxhpWD-4ZmuFKLzKlAw-w5PpzZxXOcG) and unzip it into the `data/` folder.

#### Hat dataset

Download link: https://drive.google.com/file/d/18UmT1NM5wJQ-ZM-rtUXJHXkDc-ba-xVk/view?usp=sharing

RGB images, segmentation masks and the corresponding camera poses are included. Alternatively, you may choose to 1) generate the camera poses with COLMAP and 2) create binary segmentation masks by yourself.
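
If you create the masks yourself, a minimal sketch of binarizing an RGBA alpha channel into a segmentation mask (the zero threshold mirrors how the dataset loaders here binarize alpha with `torch.sign`; adapt it to your capture pipeline):

```python
import numpy as np

def alpha_to_binary_mask(rgba, threshold=0.0):
    """Turn an H x W x 4 RGBA image into a {0, 1} foreground mask.

    Any pixel with alpha above `threshold` is treated as foreground.
    """
    assert rgba.ndim == 3 and rgba.shape[-1] == 4
    return (rgba[..., 3] > threshold).astype(np.float32)

# Toy 2x2 image: only the top-left pixel is opaque.
rgba = np.zeros((2, 2, 4), dtype=np.float32)
rgba[0, 0, 3] = 1.0
mask = alpha_to_binary_mask(rgba)
print(mask)  # top-left pixel is foreground, all others background
```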

### Training

#### DeepFashion3D-v2 instances

The mesh instances' IDs are [30, 92, 117, 133, 164, 320, 448, 522, 591]. To reconstruct the `$INDEX`-th mesh (0-8) in the list using tet-based G-Shell, run

```
  python train_gshelltet_deepfashion.py --config configs/deepfashion_mc_256.json --index $INDEX --trainset_path $TRAINSET_PATH --testset_path $TESTSET_PATH --o $OUTPUT_PATH
```

For FlexiCubes + G-Shell, run

```
  python train_gflexicubes_deepfashion.py --config configs/deepfashion_mc_80.json --index $INDEX --trainset_path $TRAINSET_PATH --testset_path $TESTSET_PATH --o $OUTPUT_PATH
```
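
For reference, the `--index`-to-ID correspondence is just positional indexing into the list above:

```python
# The 9 DeepFashion3D-v2 instance IDs, in the order used by --index.
DEEPFASHION_IDS = [30, 92, 117, 133, 164, 320, 448, 522, 591]

def instance_id(index):
    """Map the --index argument (0-8) to the DeepFashion3D-v2 mesh ID."""
    return DEEPFASHION_IDS[index]

print(instance_id(0), instance_id(8))  # 30 591
```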

#### Synthetic data

```
  python train_gshelltet_synthetic.py --config configs/nerf_chair.json --o $OUTPUT_PATH
```

#### Hat data

```
  python train_gshelltet_polycam.py --config configs/polycam_mc_128.json --trainset_path $TRAINSET_PATH --o $OUTPUT_PATH
```

#### On config files

You may consider modifying the following, depending on your needs:

- `gshell_grid`: the G-Shell grid size. For tet-based G-Shell, please make sure the corresponding tet-grid file exists under `data/tets` (e.g., `256_tets.npz`). Otherwise, follow https://github.com/crawforddoran/quartet and `data/tets/generate_tets.py` to generate the desired tet-grid file.
- `n_samples`: the number of MC samples for light rays per rasterized pixel. Higher values give better quality at the cost of memory and speed.
- `batch_size`: the number of views sampled in each iteration.
- `iteration`: the total number of iterations.
- `kd_min`, `kd_max`, etc.: the min/max of the corresponding PBR material parameter.
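
The configs are plain JSON, so they are easy to tweak programmatically. A sketch with a trimmed-down config (only a few representative keys from `configs/deepfashion_mc_256.json` are shown):

```python
import json

# A trimmed-down config in the same format as configs/deepfashion_mc_256.json.
config_text = '''
{
    "iter": 5000,
    "batch": 2,
    "n_samples": 24,
    "gshell_grid": 256,
    "ks_min": [0, 0.001, 0.0],
    "ks_max": [0, 1.0, 1.0]
}
'''
config = json.loads(config_text)
# e.g. halve the MC sample count to trade quality for speed/memory:
config['n_samples'] //= 2
print(config['gshell_grid'], config['n_samples'])  # 256 12
```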






## Generation

### Preparation

Download the info files for the underlying tet grids, along with the binary masks indicating which locations in the cubic grids store useful values, from [tet_info.pt](https://drive.google.com/file/d/19Dw_hOpcVHazpm2_1qA7T7j5xABOUWxv/view?usp=drive_link), [global_mask_res64.pt](https://drive.google.com/file/d/1mlSnu23_u08HH5aO3x5z1V9GzguFzoiT/view?usp=drive_link), [cat_mask_res64.pt](https://drive.google.com/file/d/11Bm4CQX-y1X7R47AfQQz20s7oP6AbbNK/view?usp=drive_link) and [occ_mask_res64.pt](https://drive.google.com/file/d/1qEqqLfZe633GdVkj5MGOCON_kf0l4e4G/view?usp=drive_link). Put these files under `GMeshDiffusion/metadata/`.

#### For inference

Download the pretrained models for upper-body and lower-body garments [here](https://huggingface.co/lzzcd001/GMeshDiffusion-Models).

#### For training

1) Download the processed Cloth3D garment dataset (for upper-body & lower-body garments) from [link](https://huggingface.co/datasets/lzzcd001/Cloth3D-GShell-Dataset). Alternatively, you may create a grid dataset for your own objects by a) normalizing your datapoints (re-centering and re-scaling the meshes), b) fitting G-Shell representations, and c) turning these representations into cubic grids by running `GMeshDiffusion/metadata/tet_to_cubic_grid_dataset.py`.

2) Run `GMeshDiffusion/metadata/get_splits_lower.py` and/or `GMeshDiffusion/metadata/get_splits_upper.py` to generate lists of training and test datapoints.
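
Step a) above can be sketched as below; note this is only an illustration of re-centering and re-scaling a vertex array (the exact normalization convention used for the released dataset may differ):

```python
import numpy as np

def normalize_vertices(verts, scale=0.9):
    """Re-center a (N, 3) vertex array at the origin and re-scale it so the
    longest bounding-box edge fits in [-scale, scale].

    Illustrative only: the centering/scaling convention for the released
    Cloth3D grids may differ.
    """
    vmin, vmax = verts.min(axis=0), verts.max(axis=0)
    center = (vmin + vmax) / 2.0        # bounding-box center
    extent = (vmax - vmin).max()        # longest bounding-box edge
    return (verts - center) / extent * 2.0 * scale

verts = np.array([[0.0, 0.0, 0.0], [2.0, 4.0, 0.0], [2.0, 0.0, 1.0]])
out = normalize_vertices(verts)
print(out.min(), out.max())  # -0.9 0.9
```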


### Inference

1. Modify the batch size in the config files in `GMeshDiffusion/diffusion_configs/`, and enter the desired directories and values (for model checkpoints, where to store generated samples, etc.) in the scripts under `GMeshDiffusion/scripts`.
2. Run the eval scripts in `GMeshDiffusion/scripts`.
3. Run `eval_gmeshdiffusion_generated_samples.py` to extract triangular meshes.

### Training

1. Enter the desired directories (for model checkpoints and where to store generated samples) in `GMeshDiffusion/scripts`.
2. Modify the config files if necessary.
3. Run the training scripts in `GMeshDiffusion/scripts`.


## Citation

If you find our work useful to your research, please consider citing:

```
@inproceedings{liu2024gshell,
    title={Ghost on the Shell: An Expressive Representation of General 3D Shapes},
    author={Liu, Zhen and Feng, Yao and Xiu, Yuliang and Liu, Weiyang and Paull, Liam and Black, Michael J. and Sch{\"o}lkopf, Bernhard},
    booktitle={The Twelfth International Conference on Learning Representations},
    year={2024},
}
```


## Acknowledgement

We sincerely thank the authors of [Nvdiffrecmc](https://github.com/NVlabs/nvdiffrecmc), [FlexiCubes](https://github.com/nv-tlabs/FlexiCubes) and https://github.com/yang-song/score_sde_pytorch for sharing their code. Our repo is adapted from [MeshDiffusion](https://github.com/lzzcd001/MeshDiffusion/).


================================================
FILE: configs/deepfashion_mc.json
================================================
{
    "ref_mesh": "data/spot/spot.obj",
    "random_textures": true,
    "iter": 5000,
    "save_interval": 100,
    "texture_res": [ 1024, 1024 ],
    "train_res": [1024, 1024],
    "batch": 2,
    "learning_rate": [0.03, 0.005],
    "ks_min" : [0, 0.001, 0.0],
    "ks_max" : [0, 1.0, 1.0],
    "envlight": "data/irrmaps/aerodynamics_workshop_2k.hdr",
    "lock_pos" : false,
    "display": [{"latlong" : true}],
    "background" : "white",
    "denoiser": "bilateral",
    "n_samples" : 24,
    "env_scale" : 2.0,
    "gshell_grid" : 128,
    "validate" : true,
    "laplace_scale" : 6000,
    "boxscale": [1, 1, 1],
    "aabb": [-1, -1, -1, 1, 1, 1]
}

================================================
FILE: configs/deepfashion_mc_256.json
================================================
{
    "ref_mesh": "data/spot/spot.obj",
    "random_textures": true,
    "iter": 5000,
    "save_interval": 100,
    "texture_res": [ 1024, 1024 ],
    "train_res": [1024, 1024],
    "batch": 2,
    "learning_rate": [0.03, 0.005],
    "ks_min" : [0, 0.001, 0.0],
    "ks_max" : [0, 1.0, 1.0],
    "envlight": "data/irrmaps/aerodynamics_workshop_2k.hdr",
    "lock_pos" : false,
    "display": [{"latlong" : true}],
    "background" : "white",
    "denoiser": "bilateral",
    "n_samples" : 24,
    "env_scale" : 2.0,
    "gshell_grid" : 256,
    "validate" : true,
    "laplace_scale" : 6000,
    "boxscale": [1, 1, 1],
    "aabb": [-1, -1, -1, 1, 1, 1]
}

================================================
FILE: configs/deepfashion_mc_512.json
================================================
{
    "ref_mesh": "data/spot/spot.obj",
    "random_textures": true,
    "iter": 5000,
    "save_interval": 100,
    "texture_res": [ 1024, 1024 ],
    "train_res": [1024, 1024],
    "batch": 2,
    "learning_rate": [0.03, 0.005],
    "ks_min" : [0, 0.001, 0.0],
    "ks_max" : [0, 1.0, 1.0],
    "envlight": "data/irrmaps/aerodynamics_workshop_2k.hdr",
    "lock_pos" : false,
    "display": [{"latlong" : true}],
    "background" : "white",
    "denoiser": "bilateral",
    "n_samples" : 12,
    "env_scale" : 2.0,
    "gshell_grid" : 512,
    "validate" : true,
    "laplace_scale" : 6000,
    "boxscale": [1, 1, 1],
    "aabb": [-1, -1, -1, 1, 1, 1]
}

================================================
FILE: configs/deepfashion_mc_80.json
================================================
{
    "ref_mesh": "data/spot/spot.obj",
    "random_textures": true,
    "iter": 5000,
    "save_interval": 100,
    "texture_res": [ 1024, 1024 ],
    "train_res": [1024, 1024],
    "batch": 2,
    "learning_rate": [0.03, 0.005],
    "ks_min" : [0, 0.001, 0.0],
    "ks_max" : [0, 1.0, 1.0],
    "envlight": "data/irrmaps/aerodynamics_workshop_2k.hdr",
    "lock_pos" : false,
    "display": [{"latlong" : true}],
    "background" : "white",
    "denoiser": "bilateral",
    "n_samples" : 24,
    "env_scale" : 2.0,
    "gshell_grid" : 80,
    "validate" : true,
    "laplace_scale" : 6000,
    "boxscale": [1, 1, 1],
    "aabb": [-1, -1, -1, 1, 1, 1]
}

================================================
FILE: configs/nerf_chair.json
================================================
{
    "ref_mesh": "data/nerf_synthetic/chair",
    "random_textures": true,
    "iter": 5000,
    "save_interval": 100,
    "texture_res": [ 1024, 1024 ],
    "train_res": [800, 800],
    "batch": 2,
    "learning_rate": [0.03, 0.005],
    "gshell_grid" : 128,
    "mesh_scale" : 2.1,
    "validate" : true,
    "n_samples" : 8,
    "denoiser" : "bilateral",
    "display": [{"latlong" : true}, {"bsdf" : "kd"}, {"bsdf" : "ks"}, {"bsdf" : "normal"}],
    "background" : "white",
    "boxscale": [1, 1, 1],
    "aabb": [-1, -1, -1, 1, 1, 1]
}

================================================
FILE: configs/polycam_mc.json
================================================
{
    "ref_mesh": "data/spot/spot.obj",
    "random_textures": true,
    "iter": 5000,
    "save_interval": 100,
    "texture_res": [ 1024, 1024 ],
    "train_res": [768, 1024],
    "batch": 2,
    "learning_rate": [0.03, 0.005],
    "ks_min" : [0, 0.001, 0.0],
    "ks_max" : [0, 1.0, 1.0],
    "envlight": "data/irrmaps/aerodynamics_workshop_2k.hdr",
    "lock_pos" : false,
    "display": [{"latlong" : true}],
    "background" : "white",
    "denoiser": "bilateral",
    "n_samples" : 8,
    "env_scale" : 2.0,
    "gshell_grid" : 256,
    "validate" : true,
    "laplace_scale" : 6000,
    "boxscale": [1, 1, 1],
    "aabb": [-1, -1, -1, 1, 1, 1]
}

================================================
FILE: configs/polycam_mc_128.json
================================================
{
    "ref_mesh": "data/spot/spot.obj",
    "random_textures": true,
    "iter": 5000,
    "save_interval": 100,
    "texture_res": [ 1024, 1024 ],
    "train_res": [768, 1024],
    "batch": 2,
    "learning_rate": [0.03, 0.005],
    "ks_min" : [0, 0.001, 0.0],
    "ks_max" : [0, 1.0, 1.0],
    "envlight": "data/irrmaps/aerodynamics_workshop_2k.hdr",
    "lock_pos" : false,
    "display": [{"latlong" : true}],
    "background" : "white",
    "denoiser": "bilateral",
    "n_samples" : 8,
    "env_scale" : 2.0,
    "gshell_grid" : 128,
    "validate" : true,
    "laplace_scale" : 6000,
    "boxscale": [1, 1, 1],
    "aabb": [-1, -1, -1, 1, 1, 1]
}

================================================
FILE: configs/polycam_mc_16samples.json
================================================
{
    "ref_mesh": "data/spot/spot.obj",
    "random_textures": true,
    "iter": 5000,
    "save_interval": 100,
    "texture_res": [ 1024, 1024 ],
    "train_res": [768, 1024],
    "batch": 2,
    "learning_rate": [0.03, 0.005],
    "ks_min" : [0, 0.001, 0.0],
    "ks_max" : [0, 1.0, 1.0],
    "envlight": "data/irrmaps/aerodynamics_workshop_2k.hdr",
    "lock_pos" : false,
    "display": [{"latlong" : true}],
    "background" : "white",
    "denoiser": "bilateral",
    "n_samples" : 16,
    "env_scale" : 2.0,
    "gshell_grid" : 256,
    "validate" : true,
    "laplace_scale" : 6000,
    "boxscale": [1, 1, 1],
    "aabb": [-1, -1, -1, 1, 1, 1]
}

================================================
FILE: data/tets/generate_tets.py
================================================
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction, 
# disclosure or distribution of this material and related documentation 
# without an express license agreement from NVIDIA CORPORATION or 
# its affiliates is strictly prohibited.

import os
import numpy as np


'''
This code segment shows how to use Quartet: https://github.com/crawforddoran/quartet, 
to generate a tet grid 
1) Download, compile and run Quartet as described in the link above. Example usage `quartet meshes/cube.obj 0.5 cube_5.tet`
2) Run the function below to generate a file `cube_32_tet.tet`
'''

def generate_tetrahedron_grid_file(res=32, root='..'):
    frac = 1.0 / res
    command = 'cd %s/quartet; ' % (root) + \
                './quartet meshes/cube.obj %f meshes/cube_%d_tet.tet -s meshes/cube_boundary_%d.obj' % (frac, res, res)
    os.system(command)


'''
This code segment shows how to convert from a quartet .tet file to compressed npz file
'''
def convert_from_quartet_to_npz(quartetfile = 'cube_32_tet.tet', npzfile = '32_tets'):

    file1 = open(quartetfile, 'r')
    header = file1.readline()
    numvertices = int(header.split(" ")[1])
    numtets     = int(header.split(" ")[2])
    print(numvertices, numtets)

    # load vertices
    vertices = np.loadtxt(quartetfile, skiprows=1, max_rows=numvertices)
    print(vertices.shape)

    # load indices
    indices = np.loadtxt(quartetfile, dtype=int, skiprows=1+numvertices, max_rows=numtets)
    print(indices.shape)

    np.savez_compressed(npzfile, vertices=vertices, indices=indices)

================================================
FILE: dataset/__init__.py
================================================
# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

from .dataset import Dataset
from .dataset_mesh import DatasetMesh
from .dataset_nerf import DatasetNERF
from .dataset_llff import DatasetLLFF

================================================
FILE: dataset/dataset.py
================================================
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction, 
# disclosure or distribution of this material and related documentation 
# without an express license agreement from NVIDIA CORPORATION or 
# its affiliates is strictly prohibited.

import torch

class Dataset(torch.utils.data.Dataset):
    """Basic dataset interface"""
    def __init__(self): 
        super().__init__()

    def __len__(self):
        raise NotImplementedError

    def __getitem__(self):
        raise NotImplementedError

    def collate(self, batch):
        iter_res, iter_spp = batch[0]['resolution'], batch[0]['spp']
        return {
            'mv' : torch.cat(list([item['mv'] for item in batch]), dim=0),
            'mvp' : torch.cat(list([item['mvp'] for item in batch]), dim=0),
            'campos' : torch.cat(list([item['campos'] for item in batch]), dim=0),
            'resolution' : iter_res,
            'spp' : iter_spp,
            'img' : torch.cat(list([item['img'] for item in batch]), dim=0) if 'img' in batch[0] else None,
            'img_second' : torch.cat(list([item['img_second'] for item in batch]), dim=0) if 'img_second' in batch[0] else None,
            'invdepth' : torch.cat(list([item['invdepth'] for item in batch]), dim=0) if 'invdepth' in batch[0] else None,
            'invdepth_second' : torch.cat(list([item['invdepth_second'] for item in batch]), dim=0) if 'invdepth_second' in batch[0] else None,
            'envlight_transform': torch.cat(list([item['envlight_transform'] for item in batch]), dim=0) if 'envlight_transform' in batch[0] and batch[0]['envlight_transform'] is not None else None,
        }

================================================
FILE: dataset/dataset_deepfashion.py
================================================
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction, 
# disclosure or distribution of this material and related documentation 
# without an express license agreement from NVIDIA CORPORATION or 
# its affiliates is strictly prohibited.

import os
import glob
import json

import torch
import numpy as np

from render import util

from .dataset import Dataset

import cv2 as cv

# This function is borrowed from IDR: https://github.com/lioryariv/idr
def load_K_Rt_from_P(filename, P=None):
    if P is None:
        lines = open(filename).read().splitlines()
        if len(lines) == 4:
            lines = lines[1:]
        lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
        P = np.asarray(lines).astype(np.float32).squeeze()

    out = cv.decomposeProjectionMatrix(P)
    K = out[0]
    R = out[1]
    t = out[2]

    K = K / K[2, 2]
    intrinsics = np.eye(4)
    intrinsics[:3, :3] = K


    pose = np.eye(4, dtype=np.float32)
    pose[:3, :3] = R.transpose()
    pose[:3, 3] = (t[:3] / t[3])[:, 0]

    return intrinsics, pose

def _load_img(path):
    img = util.load_image_raw(path)
    if img.dtype != np.float32: # LDR image
        img = torch.tensor(img / 255, dtype=torch.float32)
        img[..., 0:3] = util.srgb_to_rgb(img[..., 0:3])
    else:
        img = torch.tensor(img, dtype=torch.float32)
    return img



class DatasetDeepFashion(Dataset):
    def __init__(self, base_dir, FLAGS, examples=None):
        self.FLAGS = FLAGS
        self.examples = examples
        self.base_dir = base_dir

        # Load config / transforms
        self.n_images = 72 ### hardcoded

        self.fovy               = np.deg2rad(60)
        self.proj_mtx = util.perspective(
            self.fovy, self.FLAGS.display_res[1] / self.FLAGS.display_res[0], self.FLAGS.cam_near_far[0], self.FLAGS.cam_near_far[1]
        )



        camera_dict = np.load(os.path.join(self.base_dir, 'cameras_sphere.npz'))

        # world_mat is a projection matrix from world to image
        self.world_mats_np = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)]
        self.scale_mats_np = []


        # scale_mat: used for coordinate normalization, we assume the scene to render is inside a unit sphere at origin.
        self.scale_mats_np = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)]
        self.intrinsics_all = []
        self.pose_all = []

        for scale_mat, world_mat in zip(self.scale_mats_np, self.world_mats_np):
            P = world_mat @ scale_mat
            P = P[:3, :4]
            intrinsics, pose = load_K_Rt_from_P(None, P)
            self.intrinsics_all.append(torch.from_numpy(intrinsics).float())
            self.pose_all.append(torch.from_numpy(pose).float())

        # Determine resolution & aspect ratio
        self.resolution = _load_img(os.path.join(self.base_dir, '{:03d}.png'.format(0))).shape[0:2]
        self.aspect = self.resolution[1] / self.resolution[0]

        if self.FLAGS.local_rank == 0:
            print("DatasetDeepFashion: %d images with shape [%d, %d]" % (self.n_images, self.resolution[0], self.resolution[1]))

    def _parse_frame(self, idx):
        # Load image data and modelview matrix
        img    = _load_img(os.path.join(self.base_dir, '{:03d}.png'.format(idx)))
        img[:,:,:3] = img[:,:,:3] * img[:,:,3:]
        img[:,:,3] = torch.sign(img[:,:,3])
        assert img.size(-1) == 4

        flip_mat = torch.tensor([
            [ 1,  0,  0,  0],
            [ 0, -1,  0,  0],
            [ 0,  0, -1,  0],
            [ 0,  0,  0,  1]
        ], dtype=torch.float)

        mv = flip_mat @ torch.linalg.inv(self.pose_all[idx])
        campos = torch.linalg.inv(mv)[:3, 3]
        mvp = self.proj_mtx @ mv

        return img[None, ...].cuda(), mv[None, ...].cuda(), mvp[None, ...].cuda(), campos[None, ...].cuda() # Add batch dimension

    def __len__(self):
        return self.n_images if self.examples is None else self.examples

    def __getitem__(self, itr):
        iter_res = self.FLAGS.train_res
        
        img      = []

        img, mv, mvp, campos = self._parse_frame(itr % self.n_images)

        return {
            'mv' : mv,
            'mvp' : mvp,
            'campos' : campos,
            'resolution' : iter_res,
            'spp' : self.FLAGS.spp,
            'img' : img
        }


================================================
FILE: dataset/dataset_deepfashion_testset.py
================================================
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction, 
# disclosure or distribution of this material and related documentation 
# without an express license agreement from NVIDIA CORPORATION or 
# its affiliates is strictly prohibited.

import os
import glob
import json

import torch
import numpy as np

from render import util

from .dataset import Dataset

import cv2 as cv

# This function is borrowed from IDR: https://github.com/lioryariv/idr
def load_K_Rt_from_P(filename, P=None):
    if P is None:
        lines = open(filename).read().splitlines()
        if len(lines) == 4:
            lines = lines[1:]
        lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
        P = np.asarray(lines).astype(np.float32).squeeze()

    out = cv.decomposeProjectionMatrix(P)
    K = out[0]
    R = out[1]
    t = out[2]

    K = K / K[2, 2]
    intrinsics = np.eye(4)
    intrinsics[:3, :3] = K

    pose = np.eye(4, dtype=np.float32)
    pose[:3, :3] = R.transpose()
    pose[:3, 3] = (t[:3] / t[3])[:, 0]

    return intrinsics, pose

def _load_img(path):
    img = util.load_image_raw(path)
    if img.dtype != np.float32: # LDR image
        img = torch.tensor(img / 255, dtype=torch.float32)
        img[..., 0:3] = util.srgb_to_rgb(img[..., 0:3])
    else:
        img = torch.tensor(img, dtype=torch.float32)
    return img


def _load_mask(path):
    img = util.load_image_raw(path)
    if img.dtype != np.float32: # LDR image
        img = torch.tensor(img / 255, dtype=torch.float32)
    else:
        img = torch.tensor(img, dtype=torch.float32)
    return img


class DatasetDeepFashionTestset(Dataset):
    def __init__(self, base_dir, FLAGS, examples=None):
        self.FLAGS = FLAGS
        self.examples = examples
        self.base_dir = base_dir

        # Load config / transforms
        self.n_images = 200 ### hardcoded


        proj_mtx_all = np.load(os.path.join(self.base_dir, 'proj_mtx_all.npy'))
        self.intrinsics_all = []
        self.pose_all = []


        self.fovy               = np.deg2rad(60)
        self.proj_mtx = util.perspective(
            self.fovy, self.FLAGS.display_res[1] / self.FLAGS.display_res[0], self.FLAGS.cam_near_far[0], self.FLAGS.cam_near_far[1]
        )

        for i in range(proj_mtx_all.shape[0]):
            P = proj_mtx_all[i]
            P = P[:3, :4]
            intrinsics, pose = load_K_Rt_from_P(None, P)
            self.intrinsics_all.append(torch.from_numpy(intrinsics).float())
            self.pose_all.append(torch.from_numpy(pose).float())

        # Determine resolution & aspect ratio
        self.resolution = _load_img(os.path.join(self.base_dir, '{:03d}.png'.format(0))).shape[0:2]
        self.aspect = self.resolution[1] / self.resolution[0]

        if self.FLAGS.local_rank == 0:
            print("DatasetNERF: %d images with shape [%d, %d]" % (self.n_images, self.resolution[0], self.resolution[1]))

    def _parse_frame(self, idx):
        # Load image data and modelview matrix
        img    = _load_img(os.path.join(self.base_dir, '{:03d}.png'.format(idx)))
        assert img.size(-1) == 4

        flip_mat = torch.tensor([
            [ 1,  0,  0,  0],
            [ 0, -1,  0,  0],
            [ 0,  0, -1,  0],
            [ 0,  0,  0,  1]
        ], dtype=torch.float)

        mv = flip_mat @ torch.linalg.inv(self.pose_all[idx])
        campos = torch.linalg.inv(mv)[:3, 3]
        mvp = self.proj_mtx @ mv

        return img[None, ...].cuda(), mv[None, ...].cuda(), mvp[None, ...].cuda(), campos[None, ...].cuda() # Add batch dimension

    def __len__(self):
        return self.n_images if self.examples is None else self.examples

    def __getitem__(self, itr):
        iter_res = self.FLAGS.train_res

        img, mv, mvp, campos = self._parse_frame(itr % self.n_images)

        return {
            'mv' : mv,
            'mvp' : mvp,
            'campos' : campos,
            'resolution' : iter_res,
            'spp' : self.FLAGS.spp,
            'img' : img
        }
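`load_K_Rt_from_P` above delegates the actual factorization to `cv.decomposeProjectionMatrix`; the invariant it relies on can be checked without OpenCV. A numpy sketch (with hypothetical `K`, rotation `R`, and camera center `C`) showing that for `P = K [R | t]` with `t = -R C`, the camera center is exactly the right null vector of the 3x4 projection matrix, which is what makes `pose[:3, 3]` the camera position:

```python
import numpy as np

# Hypothetical intrinsics: focal 500, principal point (320, 240).
K = np.array([[500.0,   0.0, 320.0],
              [  0.0, 500.0, 240.0],
              [  0.0,   0.0,   1.0]])

# Hypothetical world-to-camera rotation (about X) and camera center.
theta = 0.2
R = np.array([[1.0, 0.0,            0.0],
              [0.0, np.cos(theta), -np.sin(theta)],
              [0.0, np.sin(theta),  np.cos(theta)]])
C = np.array([0.1, -0.3, 4.0])          # camera center in world coords
t = -R @ C                              # so that [R | t] maps C to the origin

P = K @ np.hstack([R, t[:, None]])      # 3x4 projection matrix

# The camera center is the homogeneous null space of P: P @ [C; 1] = 0.
assert np.allclose(P @ np.append(C, 1.0), 0.0)
```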


================================================
FILE: dataset/dataset_llff.py
================================================
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction, 
# disclosure or distribution of this material and related documentation 
# without an express license agreement from NVIDIA CORPORATION or 
# its affiliates is strictly prohibited.

import os
import glob

import torch
import numpy as np

from render import util

from .dataset import Dataset

def _load_mask(fn):
    img = torch.tensor(util.load_image(fn), dtype=torch.float32)
    if len(img.shape) == 2:
        img = img[..., None].repeat(1, 1, 3)
    return img

def _load_img(fn):
    img = util.load_image_raw(fn)
    if img.dtype != np.float32: # LDR image
        img = torch.tensor(img / 255, dtype=torch.float32)
        img[..., 0:3] = util.srgb_to_rgb(img[..., 0:3])
    else:
        img = torch.tensor(img, dtype=torch.float32)
    return img

###############################################################################
# LLFF datasets (real world camera lightfields)
###############################################################################

class DatasetLLFF(Dataset):
    def __init__(self, base_dir, FLAGS, examples=None):
        self.FLAGS = FLAGS
        self.base_dir = base_dir
        self.examples = examples

        # Enumerate all image files and get resolution
        all_img = [f for f in sorted(glob.glob(os.path.join(self.base_dir, "images", "*"))) if f.lower().endswith(('png', 'jpg', 'jpeg'))]
        self.resolution = _load_img(all_img[0]).shape[0:2]

        # Load camera poses
        poses_bounds = np.load(os.path.join(self.base_dir, 'poses_bounds.npy'))
        
        poses        = poses_bounds[:, :-2].reshape([-1, 3, 5]).transpose([1,2,0])
        poses        = np.concatenate([poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1) # Taken from nerf, swizzles from LLFF to expected coordinate system
        poses        = np.moveaxis(poses, -1, 0).astype(np.float32)
        
        lcol         = np.array([0,0,0,1], dtype=np.float32)[None, None, :].repeat(poses.shape[0], 0)
        self.imvs    = torch.tensor(np.concatenate((poses[:, :, 0:4], lcol), axis=1), dtype=torch.float32)
        self.aspect  = self.resolution[1] / self.resolution[0] # width / height
        self.fovy    = util.focal_length_to_fovy(poses[:, 2, 4], poses[:, 0, 4])

        # Recenter scene so lookat position is origin
        center                = util.lines_focal(self.imvs[..., :3, 3], -self.imvs[..., :3, 2])
        self.imvs[..., :3, 3] = self.imvs[..., :3, 3] - center[None, ...]

        if self.FLAGS.local_rank == 0:
            print("DatasetLLFF: %d images with shape [%d, %d]" % (len(all_img), self.resolution[0], self.resolution[1]))
            print("DatasetLLFF: auto-centering at %s" % (center.cpu().numpy()))

        # Pre-load from disc to avoid slow png parsing
        if self.FLAGS.pre_load:
            self.preloaded_data = []
            for i in range(self.imvs.shape[0]):
                self.preloaded_data += [self._parse_frame(i)]

    def _parse_frame(self, idx):
        all_img  = [f for f in sorted(glob.glob(os.path.join(self.base_dir, "images", "*"))) if f.lower().endswith(('png', 'jpg', 'jpeg'))]
        all_mask = [f for f in sorted(glob.glob(os.path.join(self.base_dir, "masks", "*"))) if f.lower().endswith(('png', 'jpg', 'jpeg'))]
        assert len(all_img) == self.imvs.shape[0] and len(all_mask) == self.imvs.shape[0]

        # Load image+mask data
        img  = _load_img(all_img[idx])
        mask = _load_mask(all_mask[idx])
        img  = torch.cat((img, mask[..., 0:1]), dim=-1)

        # Setup transforms
        proj   = util.perspective(self.fovy[idx, ...], self.aspect, self.FLAGS.cam_near_far[0], self.FLAGS.cam_near_far[1])
        mv     = torch.linalg.inv(self.imvs[idx, ...])
        campos = torch.linalg.inv(mv)[:3, 3]
        mvp    = proj @ mv

        return img[None, ...], mv[None, ...], mvp[None, ...], campos[None, ...] # Add batch dimension

    def __len__(self):
        return self.imvs.shape[0] if self.examples is None else self.examples

    def __getitem__(self, itr):
        if self.FLAGS.pre_load:
            img, mv, mvp, campos = self.preloaded_data[itr % self.imvs.shape[0]]
        else:
            img, mv, mvp, campos = self._parse_frame(itr % self.imvs.shape[0])

        return {
            'mv' : mv,
            'mvp' : mvp,
            'campos' : campos,
            'resolution' : self.resolution,
            'spp' : self.FLAGS.spp,
            'img' : img
        }
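The `poses_bounds.npy` unpacking in `DatasetLLFF.__init__` can be traced on synthetic data: each row is 17 numbers, a flattened 3x5 per-camera matrix (`[R | t | hwf]`) plus two depth bounds. A sketch (random placeholder values) replicating the reshape, the swizzle from LLFF's (down, right, back) axes to the expected (right, up, back) convention, and the append of the homogeneous `[0, 0, 0, 1]` row to form 4x4 camera-to-world matrices:

```python
import numpy as np

n = 4
# Synthetic stand-in for np.load('poses_bounds.npy'): shape (N, 17).
poses_bounds = np.random.randn(n, 17).astype(np.float32)

# (N, 15) -> (N, 3, 5) -> (3, 5, N), as in the dataset code.
poses = poses_bounds[:, :-2].reshape([-1, 3, 5]).transpose([1, 2, 0])
# Swizzle columns: new col0 = old col1 (right), new col1 = -old col0 (up).
poses = np.concatenate([poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1)
poses = np.moveaxis(poses, -1, 0).astype(np.float32)   # back to (N, 3, 5)

# Append the homogeneous row to the first four columns of each pose.
lcol = np.array([0, 0, 0, 1], dtype=np.float32)[None, None, :].repeat(n, 0)
imvs = np.concatenate((poses[:, :, 0:4], lcol), axis=1)

assert imvs.shape == (n, 4, 4)
assert np.allclose(imvs[:, 3, :], [0, 0, 0, 1])
```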


================================================
FILE: dataset/dataset_mesh.py
================================================
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction, 
# disclosure or distribution of this material and related documentation 
# without an express license agreement from NVIDIA CORPORATION or 
# its affiliates is strictly prohibited.

import numpy as np
import torch

from render import util
from render import mesh
from render import render
from render import light

from .dataset import Dataset

###############################################################################
# Reference dataset using mesh & rendering
###############################################################################

class DatasetMesh(Dataset):

    def __init__(self, ref_mesh, glctx, cam_radius, FLAGS, validate=False, mesh_center=None):
        # Init 
        self.glctx              = glctx
        self.cam_radius         = cam_radius
        self.FLAGS              = FLAGS
        self.validate           = validate
        self.fovy               = np.deg2rad(45)
        self.aspect             = FLAGS.train_res[1] / FLAGS.train_res[0]
        self.random_lgt         = FLAGS.random_lgt
        self.camera_lgt         = False

        self.mesh_center = mesh_center

        if self.FLAGS.local_rank == 0:
            print("DatasetMesh: ref mesh has %d triangles and %d vertices" % (ref_mesh.t_pos_idx.shape[0], ref_mesh.v_pos.shape[0]))

        # Sanity test training texture resolution
        ref_texture_res = np.maximum(ref_mesh.material['kd'].getRes(), ref_mesh.material['ks'].getRes())
        if 'normal' in ref_mesh.material:
            ref_texture_res = np.maximum(ref_texture_res, ref_mesh.material['normal'].getRes())
        if self.FLAGS.local_rank == 0 and (FLAGS.texture_res[0] < ref_texture_res[0] or FLAGS.texture_res[1] < ref_texture_res[1]):
            print("---> WARNING: Picked a texture resolution lower than the reference mesh [%d, %d] < [%d, %d]" % (FLAGS.texture_res[0], FLAGS.texture_res[1], ref_texture_res[0], ref_texture_res[1]))

        # Load environment map texture
        self.envlight = light.load_env(FLAGS.envmap, scale=FLAGS.env_scale)
        
        self.ref_mesh = mesh.compute_tangents(ref_mesh)

    def _rotate_scene(self, itr):
        proj_mtx = util.perspective(self.fovy, self.FLAGS.display_res[1] / self.FLAGS.display_res[0], self.FLAGS.cam_near_far[0], self.FLAGS.cam_near_far[1])

        # Smooth rotation for display.
        ang    = (itr / 50) * np.pi * 2
        mv     = util.translate(0, 0, -self.cam_radius) @ (util.rotate_x(-0.4) @ util.rotate_y(ang))
        mvp    = proj_mtx @ mv
        campos = torch.linalg.inv(mv)[:3, 3]

        return mv[None, ...].cuda(), mvp[None, ...].cuda(), campos[None, ...].cuda(), self.FLAGS.display_res, self.FLAGS.spp

    def _random_scene(self):
        # ==============================================================================================
        #  Setup projection matrix
        # ==============================================================================================
        iter_res = self.FLAGS.train_res
        proj_mtx = util.perspective(self.fovy, iter_res[1] / iter_res[0], self.FLAGS.cam_near_far[0], self.FLAGS.cam_near_far[1])

        # ==============================================================================================
        #  Random camera & light position
        # ==============================================================================================

        # Random rotation/translation matrix for optimization.
        if self.mesh_center is not None:
            mv     = (
                util.translate(-self.mesh_center[0], -self.mesh_center
SYMBOL INDEX (975 symbols across 88 files)

FILE: GMeshDiffusion/diffusion_configs/config_lower_occgrid_normalized.py
  function get_config (line 6) | def get_config():

FILE: GMeshDiffusion/diffusion_configs/config_upper_occgrid_normalized.py
  function get_config (line 6) | def get_config():

FILE: GMeshDiffusion/lib/dataset/gshell_dataset.py
  class GShellDataset (line 5) | class GShellDataset(Dataset):
    method __init__ (line 6) | def __init__(self, filepath_metafile, extension='pt'):
    method __len__ (line 14) | def __len__(self):
    method __getitem__ (line 17) | def __getitem__(self, idx):

FILE: GMeshDiffusion/lib/dataset/gshell_dataset_aug.py
  class GShellAugDataset (line 4) | class GShellAugDataset(Dataset):
    method __init__ (line 5) | def __init__(self, FLAGS, extension='pt'):
    method __len__ (line 17) | def __len__(self):
    method __getitem__ (line 20) | def __getitem__(self, idx):
    method collate (line 31) | def collate(data):

FILE: GMeshDiffusion/lib/diffusion/evaler.py
  function uncond_gen (line 15) | def uncond_gen(
  function slerp (line 78) | def slerp(z1, z2, alpha):
  function uncond_gen_interp (line 88) | def uncond_gen_interp(
  function cond_gen (line 183) | def cond_gen(

FILE: GMeshDiffusion/lib/diffusion/likelihood.py
  function get_div_fn (line 26) | def get_div_fn(fn):
  function get_likelihood_fn (line 40) | def get_likelihood_fn(sde, inverse_scaler, hutchinson_type='Rademacher',

FILE: GMeshDiffusion/lib/diffusion/losses.py
  function get_optimizer (line 25) | def get_optimizer(config, params):
  function optimization_manager (line 40) | def optimization_manager(config):
  function get_ddpm_loss_fn (line 60) | def get_ddpm_loss_fn(vpsde, train, loss_type='l2', pred_type='noise', us...
  function get_step_fn (line 194) | def get_step_fn(sde, train, optimize_fn=None, loss_type='l2', pred_type=...

FILE: GMeshDiffusion/lib/diffusion/models/ema.py
  class ExponentialMovingAverage (line 10) | class ExponentialMovingAverage:
    method __init__ (line 15) | def __init__(self, parameters, decay, use_num_updates=True):
    method update (line 32) | def update(self, parameters):
    method copy_to (line 54) | def copy_to(self, parameters):
    method store (line 67) | def store(self, parameters):
    method restore (line 77) | def restore(self, parameters):
    method state_dict (line 92) | def state_dict(self):
    method load_state_dict (line 96) | def load_state_dict(self, state_dict, device='cuda'):

FILE: GMeshDiffusion/lib/diffusion/models/functional.py
  function has_cuda (line 39) | def has_cuda():
  function has_half (line 43) | def has_half():
  function has_bfloat (line 47) | def has_bfloat():
  function has_gemm (line 51) | def has_gemm():
  function enable_tf32 (line 55) | def enable_tf32():
  function disable_tf32 (line 59) | def disable_tf32():
  function enable_tiled_na (line 63) | def enable_tiled_na():
  function disable_tiled_na (line 67) | def disable_tiled_na():
  function enable_gemm_na (line 71) | def enable_gemm_na():
  function disable_gemm_na (line 75) | def disable_gemm_na():
  class NeighborhoodAttention1DQKAutogradFunction (line 79) | class NeighborhoodAttention1DQKAutogradFunction(Function):
    method forward (line 82) | def forward(ctx, query, key, rpb, kernel_size, dilation):
    method backward (line 94) | def backward(ctx, grad_out):
  class NeighborhoodAttention1DAVAutogradFunction (line 107) | class NeighborhoodAttention1DAVAutogradFunction(Function):
    method forward (line 110) | def forward(ctx, attn, value, kernel_size, dilation):
    method backward (line 121) | def backward(ctx, grad_out):
  class NeighborhoodAttention2DQKAutogradFunction (line 133) | class NeighborhoodAttention2DQKAutogradFunction(Function):
    method forward (line 136) | def forward(ctx, query, key, rpb, kernel_size, dilation):
    method backward (line 150) | def backward(ctx, grad_out):
  class NeighborhoodAttention2DAVAutogradFunction (line 163) | class NeighborhoodAttention2DAVAutogradFunction(Function):
    method forward (line 166) | def forward(ctx, attn, value, kernel_size, dilation):
    method backward (line 177) | def backward(ctx, grad_out):
  class NeighborhoodAttention3DQKAutogradFunction (line 189) | class NeighborhoodAttention3DQKAutogradFunction(Function):
    method forward (line 192) | def forward(ctx, query, key, rpb, kernel_size_d, kernel_size, dilation...
    method backward (line 208) | def backward(ctx, grad_out):
  class NeighborhoodAttention3DAVAutogradFunction (line 223) | class NeighborhoodAttention3DAVAutogradFunction(Function):
    method forward (line 226) | def forward(ctx, attn, value, kernel_size_d, kernel_size, dilation_d, ...
    method backward (line 241) | def backward(ctx, grad_out):
  function natten1dqkrpb (line 255) | def natten1dqkrpb(query, key, rpb, kernel_size, dilation):
  function natten1dqk (line 261) | def natten1dqk(query, key, kernel_size, dilation):
  function natten1dav (line 267) | def natten1dav(attn, value, kernel_size, dilation):
  function natten2dqkrpb (line 273) | def natten2dqkrpb(query, key, rpb, kernel_size, dilation):
  function natten2dqk (line 279) | def natten2dqk(query, key, kernel_size, dilation):
  function natten2dav (line 285) | def natten2dav(attn, value, kernel_size, dilation):
  function natten3dqkrpb (line 291) | def natten3dqkrpb(query, key, rpb, kernel_size_d, kernel_size, dilation_...
  function natten3dqk (line 297) | def natten3dqk(query, key, kernel_size_d, kernel_size, dilation_d, dilat...
  function natten3dav (line 303) | def natten3dav(attn, value, kernel_size_d, kernel_size, dilation_d, dila...

FILE: GMeshDiffusion/lib/diffusion/models/layers.py
  class GroupNormFloat32 (line 28) | class GroupNormFloat32(nn.GroupNorm):
    method forward (line 29) | def forward(self, input):
  function get_act_fn (line 35) | def get_act_fn(act_name):
  function variance_scaling (line 49) | def variance_scaling(scale, mode, distribution,
  function default_init (line 83) | def default_init(scale=1.):
  class Dense (line 89) | class Dense(nn.Module):
    method __init__ (line 91) | def __init__(self):
  function conv1x1 (line 95) | def conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1., p...
  function conv3x3 (line 102) | def conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init...
  function conv5x5 (line 110) | def conv5x5(in_planes, out_planes, stride=2, bias=True, dilation=1, init...
  function conv3x3_transposed (line 119) | def conv3x3_transposed(in_planes, out_planes, stride=2, bias=True, dilat...
  function conv5x5_transposed (line 127) | def conv5x5_transposed(in_planes, out_planes, stride=2, bias=True, dilat...
  function get_timestep_embedding (line 141) | def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
  class AttnBlock (line 158) | class AttnBlock(nn.Module):
    method __init__ (line 160) | def __init__(self, channels, num_groups=32):
    method forward (line 168) | def forward(self, x):
  class Upsample (line 192) | class Upsample(nn.Module):
    method __init__ (line 193) | def __init__(self, channels, with_conv=False):
    method forward (line 199) | def forward(self, x, temb=None):
  class Downsample (line 207) | class Downsample(nn.Module):
    method __init__ (line 208) | def __init__(self, channels, with_conv=False):
    method forward (line 214) | def forward(self, x, temb=None):
  class ResBlock (line 227) | class ResBlock(nn.Module):
    method __init__ (line 229) | def __init__(self, act_fn, in_ch, out_ch=None, temb_dim=None, conv_sho...
    method forward (line 253) | def forward(self, x, temb=None):
  class AttnResBlock (line 272) | class AttnResBlock(ResBlock):
    method __init__ (line 274) | def __init__(self, act_fn, in_ch, out_ch, temb_dim=None, conv_shortcut...
    method forward (line 278) | def forward(self, x, temb=None):

FILE: GMeshDiffusion/lib/diffusion/models/normalization.py
  function get_normalization (line 22) | def get_normalization(config, conditional=False):
  class ConditionalBatchNorm3d (line 43) | class ConditionalBatchNorm3d(nn.Module):
    method __init__ (line 44) | def __init__(self, num_features, num_classes, bias=True):
    method forward (line 57) | def forward(self, x, y):
  class ConditionalInstanceNorm3d (line 68) | class ConditionalInstanceNorm3d(nn.Module):
    method __init__ (line 69) | def __init__(self, num_features, num_classes, bias=True):
    method forward (line 82) | def forward(self, x, y):
  class ConditionalVarianceNorm3d (line 93) | class ConditionalVarianceNorm3d(nn.Module):
    method __init__ (line 94) | def __init__(self, num_features, num_classes, bias=False):
    method forward (line 101) | def forward(self, x, y):
  class VarianceNorm3d (line 110) | class VarianceNorm3d(nn.Module):
    method __init__ (line 111) | def __init__(self, num_features, bias=False):
    method forward (line 118) | def forward(self, x):
  class ConditionalNoneNorm3d (line 126) | class ConditionalNoneNorm3d(nn.Module):
    method __init__ (line 127) | def __init__(self, num_features, num_classes, bias=True):
    method forward (line 139) | def forward(self, x, y):
  class NoneNorm3d (line 149) | class NoneNorm3d(nn.Module):
    method __init__ (line 150) | def __init__(self, num_features, bias=True):
    method forward (line 153) | def forward(self, x):
  class InstanceNorm3dPlus (line 157) | class InstanceNorm3dPlus(nn.Module):
    method __init__ (line 158) | def __init__(self, num_features, bias=True):
    method forward (line 170) | def forward(self, x):
  class ConditionalInstanceNorm3dPlus (line 186) | class ConditionalInstanceNorm3dPlus(nn.Module):
    method __init__ (line 187) | def __init__(self, num_features, num_classes, bias=True):
    method forward (line 200) | def forward(self, x, y):

FILE: GMeshDiffusion/lib/diffusion/models/unet3d_occgrid.py
  function str_to_class (line 33) | def str_to_class(classname):
  class UNet3D (line 38) | class UNet3D(nn.Module):
    method __init__ (line 39) | def __init__(self, config):
    method sequentially_call_module (line 142) | def sequentially_call_module(self, idx, x, temb=None):
    method forward (line 145) | def forward(self, x, labels):

FILE: GMeshDiffusion/lib/diffusion/models/utils.py
  function register_model (line 27) | def register_model(cls=None, *, name=None):
  function get_model (line 46) | def get_model(name):
  function get_sigmas (line 50) | def get_sigmas(config):
  function get_ddpm_params (line 63) | def get_ddpm_params(config):
  function create_model (line 88) | def create_model(config, use_parallel=True, ddp=False, rank=None):
  function get_model_fn (line 111) | def get_model_fn(model, train=False):
  function get_reg_fn (line 142) | def get_reg_fn(model, train=False):
  function get_score_fn (line 179) | def get_score_fn(sde, model, train=False, continuous=False, std_scale=Tr...
  function to_flattened_numpy (line 236) | def to_flattened_numpy(x):
  function from_flattened_numpy (line 241) | def from_flattened_numpy(x, shape):

FILE: GMeshDiffusion/lib/diffusion/sampling.py
  function register_predictor (line 37) | def register_predictor(cls=None, *, name=None):
  function register_corrector (line 56) | def register_corrector(cls=None, *, name=None):
  function get_predictor (line 75) | def get_predictor(name):
  function get_corrector (line 79) | def get_corrector(name):
  function get_sampling_fn (line 83) | def get_sampling_fn(config, sde, shape, inverse_scaler, eps, grid_mask=N...
  class Predictor (line 139) | class Predictor(abc.ABC):
    method __init__ (line 142) | def __init__(self, sde, score_fn, probability_flow=False):
    method update_fn (line 150) | def update_fn(self, x, t):
  class Corrector (line 164) | class Corrector(abc.ABC):
    method __init__ (line 167) | def __init__(self, sde, score_fn, snr, n_steps):
    method update_fn (line 175) | def update_fn(self, x, t):
  class EulerMaruyamaPredictor (line 190) | class EulerMaruyamaPredictor(Predictor):
    method __init__ (line 191) | def __init__(self, sde, score_fn, probability_flow=False):
    method update_fn (line 194) | def update_fn(self, x, t):
  class ReverseDiffusionPredictor (line 204) | class ReverseDiffusionPredictor(Predictor):
    method __init__ (line 205) | def __init__(self, sde, score_fn, probability_flow=False):
    method update_fn (line 208) | def update_fn(self, x, t):
  class AncestralSamplingPredictor (line 217) | class AncestralSamplingPredictor(Predictor):
    method __init__ (line 220) | def __init__(self, sde, score_fn, probability_flow=False):
    method vpsde_update_fn (line 226) | def vpsde_update_fn(self, x, t):
    method update_fn (line 236) | def update_fn(self, x, t):
  class NonePredictor (line 244) | class NonePredictor(Predictor):
    method __init__ (line 247) | def __init__(self, sde, score_fn, probability_flow=False):
    method update_fn (line 250) | def update_fn(self, x, t):
  class DDIMPredictor (line 254) | class DDIMPredictor(Predictor):
    method __init__ (line 255) | def __init__(self, sde, score_fn, probability_flow=False):
    method update_fn (line 259) | def update_fn(self, x, t, tprev=None):
  class LangevinCorrector (line 264) | class LangevinCorrector(Corrector):
    method __init__ (line 265) | def __init__(self, sde, score_fn, snr, n_steps):
    method update_fn (line 270) | def update_fn(self, x, t):
  class AnnealedLangevinDynamics (line 294) | class AnnealedLangevinDynamics(Corrector):
    method __init__ (line 300) | def __init__(self, sde, score_fn, snr, n_steps):
    method update_fn (line 305) | def update_fn(self, x, t):
  class NoneCorrector (line 329) | class NoneCorrector(Corrector):
    method __init__ (line 332) | def __init__(self, sde, score_fn, snr, n_steps):
    method update_fn (line 335) | def update_fn(self, x, t):
  function shared_predictor_update_fn (line 339) | def shared_predictor_update_fn(x, t, sde, model, predictor, probability_...
  function shared_corrector_update_fn (line 350) | def shared_corrector_update_fn(x, t, sde, model, corrector, continuous, ...
  function get_pc_sampler (line 361) | def get_pc_sampler(sde, shape, predictor, corrector, inverse_scaler, snr,
  function ddim_predictor_update_fn (line 508) | def ddim_predictor_update_fn(x, t, tprev, sde, model, predictor, probabi...
  function get_ddim_sampler (line 519) | def get_ddim_sampler(sde, shape, predictor, inverse_scaler, n_steps=1,

FILE: GMeshDiffusion/lib/diffusion/sde_lib.py
  class SDE (line 9) | class SDE(abc.ABC):
    method __init__ (line 12) | def __init__(self, N):
    method T (line 23) | def T(self):
    method sde (line 28) | def sde(self, x, t):
    method marginal_prob (line 32) | def marginal_prob(self, x, t):
    method prior_sampling (line 37) | def prior_sampling(self, shape):
    method prior_logp (line 42) | def prior_logp(self, z):
    method discretize (line 54) | def discretize(self, x, t):
    method reverse (line 73) | def reverse(self, score_fn, probability_flow=False):
  class VPSDE (line 209) | class VPSDE(SDE):
    method __init__ (line 210) | def __init__(self, beta_min=0.1, beta_max=20, N=1000):
    method T (line 234) | def T(self):
    method sde (line 237) | def sde(self, x, t):
    method marginal_prob (line 243) | def marginal_prob(self, x, t):
    method prior_sampling (line 249) | def prior_sampling(self, shape):
    method prior_logp (line 252) | def prior_logp(self, z):
    method discretize (line 258) | def discretize(self, x, t):

FILE: GMeshDiffusion/lib/diffusion/trainer.py
  function train (line 20) | def train(config):

FILE: GMeshDiffusion/lib/diffusion/trainer_ddp.py
  function train (line 22) | def train(config):

FILE: GMeshDiffusion/lib/diffusion/utils.py
  function restore_checkpoint (line 6) | def restore_checkpoint(ckpt_dir, state, device, strict=False, rank=None):
  function save_checkpoint (line 38) | def save_checkpoint(ckpt_dir, state):

FILE: GMeshDiffusion/main_diffusion.py
  function main (line 19) | def main(argv):

FILE: GMeshDiffusion/main_diffusion_ddp.py
  function main (line 20) | def main(argv):

FILE: GMeshDiffusion/metadata/save_tet_info.py
  function tet_to_grids (line 14) | def tet_to_grids(vertices, values_list, grid_size):

FILE: GMeshDiffusion/metadata/tet_to_cubic_grid_dataset.py
  function tet_to_grids (line 7) | def tet_to_grids(vertices, values_list, grid_size):
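  The name and signature of `tet_to_grids(vertices, values_list, grid_size)` suggest scattering per-vertex tetrahedral-grid values onto a dense cubic grid for the diffusion model. A hypothetical nearest-cell version (assuming vertices in [-0.5, 0.5]^3; the actual semantics in this repo may differ):

  ```python
  import numpy as np

  def tet_to_grids(vertices, values_list, grid_size):
      """Scatter per-vertex tet values onto a cubic grid (hypothetical sketch)."""
      # map vertex coordinates from [-0.5, 0.5]^3 to integer cell indices
      idx = np.clip(((vertices + 0.5) * grid_size).astype(int), 0, grid_size - 1)
      flat = idx[:, 0] * grid_size**2 + idx[:, 1] * grid_size + idx[:, 2]
      grids = []
      for values in values_list:
          grid = np.zeros(grid_size**3, dtype=values.dtype)
          grid[flat] = values          # last write wins for shared cells
          grids.append(grid.reshape(grid_size, grid_size, grid_size))
      return grids
  ```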

FILE: data/tets/generate_tets.py
  function generate_tetrahedron_grid_file (line 21) | def generate_tetrahedron_grid_file(res=32, root='..'):
  function convert_from_quartet_to_npz (line 31) | def convert_from_quartet_to_npz(quartetfile = 'cube_32_tet.tet', npzfile...

FILE: dataset/dataset.py
  class Dataset (line 12) | class Dataset(torch.utils.data.Dataset):
    method __init__ (line 14) | def __init__(self):
    method __len__ (line 17) | def __len__(self):
    method __getitem__ (line 20) | def __getitem__(self):
    method collate (line 23) | def collate(self, batch):

FILE: dataset/dataset_deepfashion.py
  function load_K_Rt_from_P (line 24) | def load_K_Rt_from_P(filename, P=None):
  function _load_img (line 48) | def _load_img(path):
  class DatasetDeepFashion (line 59) | class DatasetDeepFashion(Dataset):
    method __init__ (line 60) | def __init__(self, base_dir, FLAGS, examples=None):
    method _parse_frame (line 101) | def _parse_frame(self, idx):
    method __len__ (line 121) | def __len__(self):
    method __getitem__ (line 124) | def __getitem__(self, itr):

FILE: dataset/dataset_deepfashion_testset.py
  function load_K_Rt_from_P (line 24) | def load_K_Rt_from_P(filename, P=None):
  function _load_img (line 48) | def _load_img(path):
  function _load_mask (line 58) | def _load_mask(path):
  class DatasetDeepFashionTestset (line 67) | class DatasetDeepFashionTestset(Dataset):
    method __init__ (line 68) | def __init__(self, base_dir, FLAGS, examples=None):
    method _parse_frame (line 101) | def _parse_frame(self, idx):
    method __len__ (line 119) | def __len__(self):
    method __getitem__ (line 122) | def __getitem__(self, itr):

FILE: dataset/dataset_llff.py
  function _load_mask (line 20) | def _load_mask(fn):
  function _load_img (line 26) | def _load_img(fn):
  class DatasetLLFF (line 39) | class DatasetLLFF(Dataset):
    method __init__ (line 40) | def __init__(self, base_dir, FLAGS, examples=None):
    method _parse_frame (line 75) | def _parse_frame(self, idx):
    method __len__ (line 93) | def __len__(self):
    method __getitem__ (line 96) | def __getitem__(self, itr):

FILE: dataset/dataset_mesh.py
  class DatasetMesh (line 24) | class DatasetMesh(Dataset):
    method __init__ (line 26) | def __init__(self, ref_mesh, glctx, cam_radius, FLAGS, validate=False,...
    method _rotate_scene (line 54) | def _rotate_scene(self, itr):
    method _random_scene (line 65) | def _random_scene(self):
    method __len__ (line 89) | def __len__(self):
    method __getitem__ (line 92) | def __getitem__(self, itr):

FILE: dataset/dataset_nerf.py
  function _load_img (line 25) | def _load_img(path):
  class DatasetNERF (line 36) | class DatasetNERF(Dataset):
    method __init__ (line 37) | def __init__(self, cfg_path, FLAGS, examples=None):
    method _parse_frame (line 59) | def _parse_frame(self, cfg, idx):
    method __len__ (line 73) | def __len__(self):
    method __getitem__ (line 76) | def __getitem__(self, itr):

FILE: dataset/dataset_nerf_colmap.py
  function _load_img (line 25) | def _load_img(path):
  class DatasetNERF (line 34) | class DatasetNERF(Dataset):
    method __init__ (line 35) | def __init__(self, cfg_path, FLAGS, examples=None):
    method _parse_frame (line 57) | def _parse_frame(self, cfg, idx):
    method __len__ (line 73) | def __len__(self):
    method __getitem__ (line 76) | def __getitem__(self, itr):

FILE: denoiser/denoiser.py
  class BilateralDenoiser (line 21) | class BilateralDenoiser(torch.nn.Module):
    method __init__ (line 22) | def __init__(self, influence=1.0):
    method set_influence (line 26) | def set_influence(self, factor):
    method forward (line 31) | def forward(self, input):

FILE: geometry/embedding.py
  class Embedding (line 4) | class Embedding(nn.Module):
    method __init__ (line 5) | def __init__(self, in_channels, N_freqs, logscale=True):
    method forward (line 21) | def forward(self, x):
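  The signature `Embedding(in_channels, N_freqs, logscale=True)` matches the NeRF-style frequency (positional) encoding. Under that assumption, a minimal numpy sketch (not this module's code):

  ```python
  import numpy as np

  def positional_embed(x, n_freqs=6, logscale=True):
      """NeRF-style frequency embedding (sketch under an assumed interface).

      Maps each input channel to [x, sin(f x), cos(f x)] for frequencies
      f = 2^k, k = 0..n_freqs-1 (log-spaced when logscale=True), so
      in_channels -> in_channels * (2 * n_freqs + 1) output channels.
      """
      if logscale:
          freqs = 2.0 ** np.arange(n_freqs)
      else:
          freqs = np.linspace(1.0, 2.0 ** (n_freqs - 1), n_freqs)
      out = [x]
      for f in freqs:
          out.append(np.sin(f * x))
          out.append(np.cos(f * x))
      return np.concatenate(out, axis=-1)

  emb = positional_embed(np.zeros((5, 3)))  # 3 channels -> 3 * (2*6 + 1) = 39
  ```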

FILE: geometry/gshell_flexicubes.py
  class GShellFlexiCubes (line 16) | class GShellFlexiCubes:
    method __init__ (line 67) | def __init__(self, device="cuda", qef_reg_scale=1e-3, weight_scale=0.99):
    method construct_voxel_grid (line 103) | def construct_voxel_grid(self, res):
    method __call__ (line 136) | def __call__(self, x_nx3, s_n, nu_n, cube_fx8, res, beta_fx12=None, al...
    method _compute_reg_loss (line 232) | def _compute_reg_loss(self, vd, ue, edge_group_to_vd, vd_num_edges):
    method _normalize_weights (line 242) | def _normalize_weights(self, beta_fx12, alpha_fx8, gamma_f, surf_cubes):
    method _get_case_id (line 266) | def _get_case_id(self, occ_fx8, surf_cubes, res):
    method _identify_surf_edges (line 309) | def _identify_surf_edges(self, s_n, cube_fx8, surf_cubes):
    method _identify_surf_cubes (line 334) | def _identify_surf_cubes(self, s_n, cube_fx8):
    method _linear_interp (line 345) | def _linear_interp(self, edges_weight, edges_x):
    method _linear_interp_nonan (line 357) | def _linear_interp_nonan(self, edges_weight, edges_x):
    method _solve_vd_QEF (line 373) | def _solve_vd_QEF(self, p_bxnx3, norm_bxnx3, c_bx3=None):
    method _compute_vd (line 387) | def _compute_vd(self, x_nx3, surf_cubes_fx8, surf_edges, s_n, nu_n, ca...
    method _triangulate (line 487) | def _triangulate(self, s_n, surf_edges, vd, nu_d, nu_d_stopvgd, vd_gam...
    method _triangulate_msdf (line 554) | def _triangulate_msdf(self, vertices, faces, nu_n, nu_n_stopvgd):
    method _tetrahedralize (line 593) | def _tetrahedralize(

FILE: geometry/gshell_flexicubes_geometry.py
  function compute_sdf_reg_loss (line 33) | def compute_sdf_reg_loss(sdf, all_edges):
  class GShellFlexiCubesGeometry (line 45) | class GShellFlexiCubesGeometry(torch.nn.Module):
    method __init__ (line 46) | def __init__(self, grid_res, scale, FLAGS):
    method generate_edges (line 111) | def generate_edges(self):
    method getAABB (line 120) | def getAABB(self):
    method clamp_deform (line 124) | def clamp_deform(self):
    method map_uv2 (line 130) | def map_uv2(self, faces):
    method map_uv (line 136) | def map_uv(self, face_gidx, max_idx):
    method getMesh (line 166) | def getMesh(self, material, _training=False):
    method render (line 210) | def render(self, glctx, target, lgt, opt_material, bsdf=None, denoiser...
    method tick (line 237) | def tick(self, glctx, target, lgt, opt_material, loss_fn, iteration, d...

FILE: geometry/gshell_tets.py
  function auto_normals (line 9) | def auto_normals(v_pos, t_pos_idx):
  function compute_tangents (line 40) | def compute_tangents(v_pos, v_tex, v_nrm, t_pos_idx, t_tex_idx, t_nrm_idx):
  class GShell_Tets (line 80) | class GShell_Tets:
    method __init__ (line 81) | def __init__(self):
    method sort_edges (line 200) | def sort_edges(self, edges_ex2):
    method map_uv (line 210) | def map_uv(self, face_gidx, max_idx):
    method __call__ (line 245) | def __call__(self, pos_nx3, sdf_n, msdf_n, tet_fx4, output_watertight_...
    method marching_from_auggrid (line 447) | def marching_from_auggrid(self, pos_nx3, sdf_n, tet_fx4,

FILE: geometry/gshell_tets_geometry.py
  function compute_sdf_reg_loss (line 33) | def compute_sdf_reg_loss(sdf, all_edges):
  class GShellTetsGeometry (line 45) | class GShellTetsGeometry(torch.nn.Module):
    method __init__ (line 46) | def __init__(self, grid_res, scale, FLAGS, offset=None, tet_init_file=...
    method generate_edges (line 149) | def generate_edges(self):
    method getAABB (line 158) | def getAABB(self):
    method clamp_deform (line 162) | def clamp_deform(self):
    method getMesh_from_augmented_grid_withocc (line 167) | def getMesh_from_augmented_grid_withocc(self, material, sdf_sign, sdf_...
    method getMesh (line 191) | def getMesh(self, material):
    method render (line 230) | def render(self, glctx, target, lgt, opt_material, bsdf=None, denoiser...
    method tick (line 257) | def tick(self, glctx, target, lgt, opt_material, loss_fn, iteration, d...

FILE: geometry/mlp.py
  class MLP (line 7) | class MLP(nn.Module):
    method __init__ (line 8) | def __init__(self, n_freq=6, d_hidden=128, d_out=1, n_hidden=3, skip_i...
    method forward (line 32) | def forward(self, x):

FILE: render/light.py
  class EnvironmentLight (line 21) | class EnvironmentLight:
    method __init__ (line 27) | def __init__(self, base):
    method xfm (line 34) | def xfm(self, mtx):
    method parameters (line 37) | def parameters(self):
    method clone (line 40) | def clone(self):
    method clamp_ (line 43) | def clamp_(self, min=None, max=None):
    method update_pdf (line 46) | def update_pdf(self):
    method generate_image (line 62) | def generate_image(self, res):
  function _load_env_hdr (line 71) | def _load_env_hdr(fn, scale=1.0, res=None, trainable=False):
  function load_env (line 86) | def load_env(fn, scale=1.0, res=None, trainable=False):
  function save_env_map (line 93) | def save_env_map(fn, light):
  function create_trainable_env_rnd (line 102) | def create_trainable_env_rnd(base_res, scale=0.5, bias=0.25):

FILE: render/material.py
  function load_mtl (line 21) | def load_mtl(fn, clear_ks=True):
  function save_mtl (line 72) | def save_mtl(fn, material):
  function create_trainable (line 99) | def create_trainable(material):
  function get_parameters (line 106) | def get_parameters(material):
  function _upscale_replicate (line 117) | def _upscale_replicate(x, full_res):
  function merge_materials (line 122) | def merge_materials(materials, texcoords, tfaces, mfaces):

FILE: render/mesh.py
  class Mesh (line 20) | class Mesh:
    method __init__ (line 21) | def __init__(self, v_pos=None, t_pos_idx=None, v_nrm=None, t_nrm_idx=N...
    method copy_none (line 35) | def copy_none(self, other):
    method clone (line 55) | def clone(self):
  function load_mesh (line 79) | def load_mesh(filename, mtl_override=None, mtl_default=None, mtl_type_ov...
  function aabb (line 88) | def aabb(mesh):
  function aabb_clean (line 94) | def aabb_clean(mesh):
  function compute_edges (line 101) | def compute_edges(attr_idx, return_inverse=False):
  function compute_edge_to_face_mapping (line 123) | def compute_edge_to_face_mapping(attr_idx, return_inverse=False):
  function unit_size (line 158) | def unit_size(mesh):
  function center_by_reference (line 170) | def center_by_reference(base_mesh, ref_aabb, scale):
  function center_by_reference_noscale (line 177) | def center_by_reference_noscale(base_mesh, ref_aabb, scale=None):
  function center_with_global_aabb (line 183) | def center_with_global_aabb(base_mesh, ref_aabb, scale):
  function center_with_global_aabb_perdim (line 191) | def center_with_global_aabb_perdim(base_mesh, ref_aabb, scale):
  function scale_with_global_aabb (line 198) | def scale_with_global_aabb(base_mesh, ref_aabb, scale):
  function scale_with_global_aabb_perdim (line 204) | def scale_with_global_aabb_perdim(base_mesh, ref_aabb, scale):
  function auto_normals (line 212) | def auto_normals(imesh):
  function compute_tangents (line 243) | def compute_tangents(imesh, v_tng=None):

FILE: render/mlptexture.py
  class _MLP (line 18) | class _MLP(torch.nn.Module):
    method __init__ (line 19) | def __init__(self, cfg, loss_scale=1.0):
    method forward (line 33) | def forward(self, x):
    method _init_weights (line 37) | def _init_weights(m):
  class MLPTexture3D (line 47) | class MLPTexture3D(torch.nn.Module):
    method __init__ (line 48) | def __init__(self, AABB, channels = 3, internal_dims = 32, hidden = 2,...
    method sample (line 87) | def sample(self, texc):
    method clamp_ (line 101) | def clamp_(self):
    method cleanup (line 104) | def cleanup(self):

FILE: render/obj.py
  function _find_mat (line 21) | def _find_mat(materials, name):
  function load_obj (line 31) | def load_obj(filename, clear_ks=True, mtl_override=None, mtl_default=Non...
  function write_obj (line 143) | def write_obj(folder, mesh, save_material=True):

FILE: render/optixutils/c_src/accessor.h
  type PtrType (line 122) | typedef T* PtrType;
  type PtrType (line 128) | typedef T* __restrict__ PtrType;
  function size (line 150) | C10_HOST_DEVICE index_t size(index_t i) const {
  function data (line 153) | C10_HOST_DEVICE PtrType data() {
  function operator[] (line 184) | PtrTraits, index_t> operator[](index_t i) const {
  function operator[] (line 199) | C10_HOST_DEVICE T & operator[](index_t i) {
  function operator[] (line 203) | C10_HOST_DEVICE const T & operator[](index_t i) const {
  function GenericPackedTensorAccessorBase (line 222) | C10_HOST GenericPackedTensorAccessorBase() {}
  function GenericPackedTensorAccessorBase (line 224) | C10_HOST GenericPackedTensorAccessorBase(
  function stride (line 246) | C10_HOST_DEVICE index_t stride(index_t i) const {
  function size (line 249) | C10_HOST_DEVICE index_t size(index_t i) const {
  function data (line 252) | C10_HOST_DEVICE PtrType data() {
  function GenericPackedTensorAccessor (line 270) | C10_HOST GenericPackedTensorAccessor() : GenericPackedTensorAccessorBase...
  function operator[] (line 323) | C10_DEVICE T & operator[](index_t i) {
  function operator[] (line 326) | C10_DEVICE const T& operator[](index_t i) const {

FILE: render/optixutils/c_src/bsdf.h
  function fwdLambert (line 21) | __device__ inline float fwdLambert(const float3 nrm, const float3 wi)
  function bwdLambert (line 26) | __device__ inline void bwdLambert(const float3 nrm, const float3 wi, flo...
  function fwdFresnelSchlick (line 35) | __device__ inline float fwdFresnelSchlick(const float f0, const float f9...
  function bwdFresnelSchlick (line 42) | __device__ inline void bwdFresnelSchlick(const float f0, const float f90...
  function fwdFresnelSchlick (line 54) | __device__ inline float3 fwdFresnelSchlick(const float3 f0, const float3...
  function bwdFresnelSchlick (line 61) | __device__ inline void bwdFresnelSchlick(const float3 f0, const float3 f...
  function fwdNdfGGX (line 76) | __device__ inline float fwdNdfGGX(const float alphaSqr, const float cosT...
  function bwdNdfGGX (line 83) | __device__ inline void bwdNdfGGX(const float alphaSqr, const float cosTh...
  function fwdLambdaGGX (line 98) | __device__ inline float fwdLambdaGGX(const float alphaSqr, const float c...
  function bwdLambdaGGX (line 107) | __device__ inline void bwdLambdaGGX(const float alphaSqr, const float co...
  function fwdMaskingSmithGGXCorrelated (line 122) | __device__ inline float fwdMaskingSmithGGXCorrelated(const float alphaSq...
  function bwdMaskingSmithGGXCorrelated (line 129) | __device__ inline void bwdMaskingSmithGGXCorrelated(const float alphaSqr...
  function fwdPbrSpecular (line 144) | __device__ float3 fwdPbrSpecular(const float3 col, const float3 nrm, con...
  function bwdPbrSpecular (line 164) | __device__ void bwdPbrSpecular(
  function fwdPbrBSDF (line 222) | __device__ void fwdPbrBSDF(const float3 kd, const float3 arm, const floa...
  function bwdPbrBSDF (line 238) | __device__ void bwdPbrBSDF(
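  The `fwdFresnelSchlick`/`bwdFresnelSchlick` pairs in bsdf.h implement Schlick's approximation to Fresnel reflectance. The forward formula is standard; a generic Python sketch of it (not a transcription of this CUDA code):

  ```python
  def fresnel_schlick(f0, f90, cos_theta):
      """Schlick's approximation to the Fresnel term (sketch).

      Interpolates between reflectance at normal incidence (f0) and grazing
      incidence (f90) with a (1 - cos_theta)^5 falloff.
      """
      c = max(0.0, min(1.0, cos_theta))   # clamp to valid cosine range
      return f0 + (f90 - f0) * (1.0 - c) ** 5
  ```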

FILE: render/optixutils/c_src/common.h
  function fetch3 (line 14) | __device__ inline float3 fetch3(const T &tensor, U idx, Args... args) {
  function fetch3 (line 17) | inline float3 fetch3(const T &tensor) {
  function fetch2 (line 22) | __device__ inline float2 fetch2(const T &tensor, U idx, Args... args) {
  function fetch2 (line 25) | inline float2 fetch2(const T &tensor) {

FILE: render/optixutils/c_src/denoising.h
  type BilateralDenoiserParams (line 12) | struct BilateralDenoiserParams

FILE: render/optixutils/c_src/envsampling/params.h
  type EnvSamplingParams (line 11) | struct EnvSamplingParams

FILE: render/optixutils/c_src/math_utils.h
  function clamp (line 13) | __inline__ T clamp(T x, T _min, T _max) { return min(_max, max(_min, x)); }
  function make_float3 (line 14) | static __device__ inline float3 make_float3(float a) { return make_float...
  function sum (line 80) | static __device__ inline float sum(float3 a)
  function dot (line 85) | static __device__ inline float dot(float3 a, float3 b) { return a.x * b....
  function bwd_dot (line 87) | static __device__ inline void bwd_dot(float3 a, float3 b, float3& d_a, f...
  function luminance (line 93) | static __device__ inline float luminance(const float3 rgb)
  function cross (line 98) | static __device__ inline float3 cross(float3 a, float3 b)
  function bwd_cross (line 107) | static __device__ inline void bwd_cross(float3 a, float3 b, float3 &d_a,...
  function reflect (line 118) | static __device__ inline float3 reflect(float3 x, float3 n)
  function bwd_reflect (line 123) | static __device__ inline void bwd_reflect(float3 x, float3 n, float3& d_...
  function safe_normalize (line 134) | static __device__ inline float3 safe_normalize(float3 v)
  function bwd_safe_normalize (line 140) | static __device__ inline void bwd_safe_normalize(const float3 v, float3&...
  function branchlessONB (line 155) | static __device__ inline void branchlessONB(const float3 &n, float3 &b1,...

FILE: render/optixutils/c_src/optix_wrapper.cpp
  function context_log_cb (line 43) | static void context_log_cb( unsigned int level, const char* tag, const c...
  function readSourceFile (line 49) | static bool readSourceFile( std::string& str, const std::string& filename )
  function getCuStringFromFile (line 63) | static void getCuStringFromFile( std::string& cu, const char* filename )
  function getPtxFromCuString (line 73) | static void getPtxFromCuString( std::string& ptx, const char* include_di...
  type SbtRecord (line 144) | struct SbtRecord
  function createPipeline (line 149) | void createPipeline(const OptixDeviceContext context, const std::string&...

FILE: render/optixutils/c_src/optix_wrapper.h
  type OptiXState (line 17) | struct OptiXState
  type OptiXStateWrapper (line 30) | class OptiXStateWrapper

FILE: render/optixutils/c_src/torch_bindings.cpp
  function optix_build_bvh (line 37) | void optix_build_bvh(OptiXStateWrapper& stateWrapper,torch::Tensor grid_...
  function packed_accessor32 (line 118) | PackedTensorAccessor32<T, N> packed_accessor32(torch::Tensor tensor)
  function env_shade_fwd (line 123) | std::tuple<torch::Tensor, torch::Tensor> env_shade_fwd(
  function env_shade_bwd (line 190) | std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, t...
  function bilateral_denoiser_fwd (line 274) | torch::Tensor bilateral_denoiser_fwd(torch::Tensor col, torch::Tensor nr...
  function bilateral_denoiser_bwd (line 297) | torch::Tensor bilateral_denoiser_bwd(torch::Tensor col, torch::Tensor nr...
  function PYBIND11_MODULE (line 321) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

FILE: render/optixutils/include/internal/optix_7_device_impl.h
  function optixTrace (line 39) | void optixTrace( OptixTraversableHandle handle,
  function optixTrace (line 78) | void optixTrace( OptixTraversableHandle handle,
  function optixTrace (line 118) | void optixTrace( OptixTraversableHandle handle,
  function optixTrace (line 159) | void optixTrace( OptixTraversableHandle handle,
  function optixTrace (line 201) | void optixTrace( OptixTraversableHandle handle,
  function optixTrace (line 244) | void optixTrace( OptixTraversableHandle handle,
  function optixTrace (line 288) | void optixTrace( OptixTraversableHandle handle,
  function optixTrace (line 333) | void optixTrace( OptixTraversableHandle handle,
  function optixTrace (line 379) | void optixTrace( OptixTraversableHandle handle,
  function optixSetPayload_0 (line 427) | void optixSetPayload_0( unsigned int p )
  function optixSetPayload_1 (line 432) | void optixSetPayload_1( unsigned int p )
  function optixSetPayload_2 (line 437) | void optixSetPayload_2( unsigned int p )
  function optixSetPayload_3 (line 442) | void optixSetPayload_3( unsigned int p )
  function optixSetPayload_4 (line 447) | void optixSetPayload_4( unsigned int p )
  function optixSetPayload_5 (line 452) | void optixSetPayload_5( unsigned int p )
  function optixSetPayload_6 (line 457) | void optixSetPayload_6( unsigned int p )
  function optixSetPayload_7 (line 462) | void optixSetPayload_7( unsigned int p )
  function optixGetPayload_0 (line 468) | unsigned int optixGetPayload_0()
  function optixGetPayload_1 (line 475) | unsigned int optixGetPayload_1()
  function optixGetPayload_2 (line 482) | unsigned int optixGetPayload_2()
  function optixGetPayload_3 (line 489) | unsigned int optixGetPayload_3()
  function optixGetPayload_4 (line 496) | unsigned int optixGetPayload_4()
  function optixGetPayload_5 (line 503) | unsigned int optixGetPayload_5()
  function optixGetPayload_6 (line 510) | unsigned int optixGetPayload_6()
  function optixGetPayload_7 (line 517) | unsigned int optixGetPayload_7()
  function optixUndefinedValue (line 525) | unsigned int optixUndefinedValue()
  function optixGetWorldRayOrigin (line 532) | float3 optixGetWorldRayOrigin()
  function optixGetWorldRayDirection (line 541) | float3 optixGetWorldRayDirection()
  function optixGetObjectRayOrigin (line 550) | float3 optixGetObjectRayOrigin()
  function optixGetObjectRayDirection (line 559) | float3 optixGetObjectRayDirection()
  function optixGetRayTmin (line 568) | float optixGetRayTmin()
  function optixGetRayTmax (line 575) | float optixGetRayTmax()
  function optixGetRayTime (line 582) | float optixGetRayTime()
  function optixGetRayFlags (line 589) | unsigned int optixGetRayFlags()
  function optixGetRayVisibilityMask (line 596) | unsigned int optixGetRayVisibilityMask()
  function optixGetInstanceTraversableFromIAS (line 603) | OptixTraversableHandle optixGetInstanceTraversableFromIAS( OptixTraversa...
  function optixGetTriangleVertexData (line 613) | void optixGetTriangleVertexData( OptixTraversableHandle gas,
  function optixGetLinearCurveVertexData (line 627) | void optixGetLinearCurveVertexData( OptixTraversableHandle gas,
  function optixGetQuadraticBSplineVertexData (line 641) | void optixGetQuadraticBSplineVertexData( OptixTraversableHandle gas,
  function optixGetCubicBSplineVertexData (line 656) | void optixGetCubicBSplineVertexData( OptixTraversableHandle gas,
  function optixGetGASTraversableHandle (line 673) | OptixTraversableHandle optixGetGASTraversableHandle()
  function optixGetGASMotionTimeBegin (line 680) | float optixGetGASMotionTimeBegin( OptixTraversableHandle handle )
  function optixGetGASMotionTimeEnd (line 687) | float optixGetGASMotionTimeEnd( OptixTraversableHandle handle )
  function optixGetGASMotionStepCount (line 694) | unsigned int optixGetGASMotionStepCount( OptixTraversableHandle handle )
  function optixGetWorldToObjectTransformMatrix (line 701) | void optixGetWorldToObjectTransformMatrix( float m[12] )
  function optixGetObjectToWorldTransformMatrix (line 736) | void optixGetObjectToWorldTransformMatrix( float m[12] )
  function optixTransformPointFromWorldToObjectSpace (line 771) | float3 optixTransformPointFromWorldToObjectSpace( float3 point )
  function optixTransformVectorFromWorldToObjectSpace (line 781) | float3 optixTransformVectorFromWorldToObjectSpace( float3 vec )
  function optixTransformNormalFromWorldToObjectSpace (line 791) | float3 optixTransformNormalFromWorldToObjectSpace( float3 normal )
  function optixTransformPointFromObjectToWorldSpace (line 801) | float3 optixTransformPointFromObjectToWorldSpace( float3 point )
  function optixTransformVectorFromObjectToWorldSpace (line 811) | float3 optixTransformVectorFromObjectToWorldSpace( float3 vec )
  function optixTransformNormalFromObjectToWorldSpace (line 821) | float3 optixTransformNormalFromObjectToWorldSpace( float3 normal )
  function optixGetTransformListSize (line 831) | unsigned int optixGetTransformListSize()
  function optixGetTransformListHandle (line 838) | OptixTraversableHandle optixGetTransformListHandle( unsigned int index )
  function optixGetTransformTypeFromHandle (line 845) | OptixTransformType optixGetTransformTypeFromHandle( OptixTraversableHand...
  function optixGetStaticTransformFromHandle (line 852) | const OptixStaticTransform* optixGetStaticTransformFromHandle( OptixTrav...
  function optixGetSRTMotionTransformFromHandle (line 859) | const OptixSRTMotionTransform* optixGetSRTMotionTransformFromHandle( Opt...
  function optixGetMatrixMotionTransformFromHandle (line 866) | const OptixMatrixMotionTransform* optixGetMatrixMotionTransformFromHandl...
  function optixGetInstanceIdFromHandle (line 873) | unsigned int optixGetInstanceIdFromHandle( OptixTraversableHandle handle )
  function optixGetInstanceChildFromHandle (line 880) | OptixTraversableHandle optixGetInstanceChildFromHandle( OptixTraversable...
  function optixGetInstanceTransformFromHandle (line 887) | const float4* optixGetInstanceTransformFromHandle( OptixTraversableHandl...
  function optixGetInstanceInverseTransformFromHandle (line 894) | const float4* optixGetInstanceInverseTransformFromHandle( OptixTraversab...
  function optixReportIntersection (line 901) | bool optixReportIntersection( float hitT, unsigned int hitKind )
  function optixReportIntersection (line 913) | bool optixReportIntersection( float hitT, unsigned int hitKind, unsigned...
  function optixReportIntersection (line 925) | bool optixReportIntersection( float hitT, unsigned int hitKind, unsigned...
  function optixReportIntersection (line 937) | bool optixReportIntersection( float hitT, unsigned int hitKind, unsigned...
  function optixReportIntersection (line 949) | bool optixReportIntersection( float        hitT,
  function optixReportIntersection (line 966) | bool optixReportIntersection( float        hitT,
  function optixReportIntersection (line 984) | bool optixReportIntersection( float        hitT,
  function optixReportIntersection (line 1003) | bool optixReportIntersection( float        hitT,
  function optixReportIntersection (line 1023) | bool optixReportIntersection( float        hitT,
  function optixGetAttribute_0 (line 1049) | unsigned int optixGetAttribute_0()
  function optixGetAttribute_1 (line 1054) | unsigned int optixGetAttribute_1()
  function optixGetAttribute_2 (line 1059) | unsigned int optixGetAttribute_2()
  function optixGetAttribute_3 (line 1064) | unsigned int optixGetAttribute_3()
  function optixGetAttribute_4 (line 1069) | unsigned int optixGetAttribute_4()
  function optixGetAttribute_5 (line 1074) | unsigned int optixGetAttribute_5()
  function optixGetAttribute_6 (line 1079) | unsigned int optixGetAttribute_6()
  function optixGetAttribute_7 (line 1084) | unsigned int optixGetAttribute_7()
  function optixTerminateRay (line 1091) | void optixTerminateRay()
  function optixIgnoreIntersection (line 1096) | void optixIgnoreIntersection()
  function optixGetPrimitiveIndex (line 1101) | unsigned int optixGetPrimitiveIndex()
  function optixGetSbtGASIndex (line 1108) | unsigned int optixGetSbtGASIndex()
  function optixGetInstanceId (line 1115) | unsigned int optixGetInstanceId()
  function optixGetInstanceIndex (line 1122) | unsigned int optixGetInstanceIndex()
  function optixGetHitKind (line 1129) | unsigned int optixGetHitKind()
  function optixGetPrimitiveType (line 1136) | OptixPrimitiveType optixGetPrimitiveType(unsigned int hitKind)
  function optixIsBackFaceHit (line 1143) | bool optixIsBackFaceHit( unsigned int hitKind )
  function optixIsFrontFaceHit (line 1150) | bool optixIsFrontFaceHit( unsigned int hitKind )
  function optixGetPrimitiveType (line 1156) | OptixPrimitiveType optixGetPrimitiveType()
  function optixIsBackFaceHit (line 1161) | bool optixIsBackFaceHit()
  function optixIsFrontFaceHit (line 1166) | bool optixIsFrontFaceHit()
  function optixIsTriangleHit (line 1171) | bool optixIsTriangleHit()
  function optixIsTriangleFrontFaceHit (line 1176) | bool optixIsTriangleFrontFaceHit()
  function optixIsTriangleBackFaceHit (line 1181) | bool optixIsTriangleBackFaceHit()
  function optixGetCurveParameter (line 1186) | float optixGetCurveParameter()
  function optixGetTriangleBarycentrics (line 1191) | float2 optixGetTriangleBarycentrics()
  function optixGetLaunchIndex (line 1198) | uint3 optixGetLaunchIndex()
  function optixGetLaunchDimensions (line 1207) | uint3 optixGetLaunchDimensions()
  function optixGetSbtDataPointer (line 1216) | CUdeviceptr optixGetSbtDataPointer()
  function optixThrowException (line 1223) | void optixThrowException( int exceptionCode )
  function optixThrowException (line 1232) | void optixThrowException( int exceptionCode, unsigned int exceptionDetai...
  function optixThrowException (line 1241) | void optixThrowException( int exceptionCode, unsigned int exceptionDetai...
  function optixThrowException (line 1250) | void optixThrowException( int exceptionCode, unsigned int exceptionDetai...
  function optixThrowException (line 1259) | void optixThrowException( int exceptionCode, unsigned int exceptionDetai...
  function optixThrowException (line 1268) | void optixThrowException( int exceptionCode, unsigned int exceptionDetai...
  function optixThrowException (line 1277) | void optixThrowException( int exceptionCode, unsigned int exceptionDetai...
  function optixThrowException (line 1286) | void optixThrowException( int exceptionCode, unsigned int exceptionDetai...
  function optixThrowException (line 1295) | void optixThrowException( int exceptionCode, unsigned int exceptionDetai...
  function optixGetExceptionCode (line 1304) | int optixGetExceptionCode()
  function optixGetExceptionDetail_0 (line 1316) | unsigned int optixGetExceptionDetail_0()
  function optixGetExceptionDetail_1 (line 1321) | unsigned int optixGetExceptionDetail_1()
  function optixGetExceptionDetail_2 (line 1326) | unsigned int optixGetExceptionDetail_2()
  function optixGetExceptionDetail_3 (line 1331) | unsigned int optixGetExceptionDetail_3()
  function optixGetExceptionDetail_4 (line 1336) | unsigned int optixGetExceptionDetail_4()
  function optixGetExceptionDetail_5 (line 1341) | unsigned int optixGetExceptionDetail_5()
  function optixGetExceptionDetail_6 (line 1346) | unsigned int optixGetExceptionDetail_6()
  function optixGetExceptionDetail_7 (line 1351) | unsigned int optixGetExceptionDetail_7()
  function optixGetExceptionInvalidTraversable (line 1358) | OptixTraversableHandle optixGetExceptionInvalidTraversable()
  function optixGetExceptionInvalidSbtOffset (line 1365) | int optixGetExceptionInvalidSbtOffset()
  function optixGetExceptionInvalidRay (line 1372) | OptixInvalidRayExceptionDetails optixGetExceptionInvalidRay()
  function optixGetExceptionParameterMismatch (line 1388) | OptixParameterMismatchExceptionDetails optixGetExceptionParameterMismatch()
  function optixDirectCall (line 1411) | ReturnT optixDirectCall( unsigned int sbtIndex, ArgTypes... args )
  function optixContinuationCall (line 1421) | ReturnT optixContinuationCall( unsigned int sbtIndex, ArgTypes... args )
  function optixTexFootprint2D (line 1431) | uint4 optixTexFootprint2D( unsigned long long tex, unsigned int texInfo,...
  function optixTexFootprint2DGrad (line 1447) | uint4 optixTexFootprint2DGrad( unsigned long long tex,
  function uint4 (line 1474) | uint4

FILE: render/optixutils/include/internal/optix_7_device_impl_exception.h
  function namespace (line 40) | namespace optix_impl {

FILE: render/optixutils/include/internal/optix_7_device_impl_transformations.h
  function namespace (line 36) | namespace optix_impl {

FILE: render/optixutils/include/optix_7_types.h
  type CUdeviceptr (line 50) | typedef unsigned long long CUdeviceptr;
  type CUdeviceptr (line 53) | typedef unsigned int CUdeviceptr;
  type OptixDeviceContext_t (line 57) | struct OptixDeviceContext_t
  type OptixModule_t (line 60) | struct OptixModule_t
  type OptixProgramGroup_t (line 63) | struct OptixProgramGroup_t
  type OptixPipeline_t (line 66) | struct OptixPipeline_t
  type OptixDenoiser_t (line 69) | struct OptixDenoiser_t
  type OptixTraversableHandle (line 72) | typedef unsigned long long OptixTraversableHandle;
  type OptixVisibilityMask (line 75) | typedef unsigned int OptixVisibilityMask;
  type OptixResult (line 112) | typedef enum OptixResult
  type OptixDeviceProperty (line 156) | typedef enum OptixDeviceProperty
  type OptixDeviceContextValidationMode (line 225) | typedef enum OptixDeviceContextValidationMode
  type OptixDeviceContextOptions (line 234) | typedef struct OptixDeviceContextOptions
  type OptixGeometryFlags (line 249) | typedef enum OptixGeometryFlags
  type OptixHitKind (line 269) | typedef enum OptixHitKind
  type OptixIndicesFormat (line 278) | typedef enum OptixIndicesFormat
  type OptixVertexFormat (line 289) | typedef enum OptixVertexFormat
  type OptixTransformFormat (line 301) | typedef enum OptixTransformFormat
  type OptixBuildInputTriangleArray (line 310) | typedef struct OptixBuildInputTriangleArray
  type OptixPrimitiveType (line 381) | typedef enum OptixPrimitiveType
  type OptixPrimitiveTypeFlags (line 398) | typedef enum OptixPrimitiveTypeFlags
  type OptixBuildInputCurveArray (line 429) | typedef struct OptixBuildInputCurveArray
  type OptixAabb (line 480) | typedef struct OptixAabb
  type OptixBuildInputCustomPrimitiveArray (line 493) | typedef struct OptixBuildInputCustomPrimitiveArray
  type OptixBuildInputInstanceArray (line 538) | typedef struct OptixBuildInputInstanceArray
  type OptixBuildInputType (line 556) | typedef enum OptixBuildInputType
  type OptixBuildInput (line 575) | typedef struct OptixBuildInput
  type OptixInstanceFlags (line 606) | typedef enum OptixInstanceFlags
  type OptixInstance (line 638) | typedef struct OptixInstance
  type OptixBuildFlags (line 668) | typedef enum OptixBuildFlags
  type OptixBuildOperation (line 706) | typedef enum OptixBuildOperation
  type OptixMotionFlags (line 717) | typedef enum OptixMotionFlags
  type OptixMotionOptions (line 728) | typedef struct OptixMotionOptions
  type OptixAccelBuildOptions (line 747) | typedef struct OptixAccelBuildOptions
  type OptixAccelBufferSizes (line 767) | typedef struct OptixAccelBufferSizes
  type OptixAccelPropertyType (line 787) | typedef enum OptixAccelPropertyType
  type OptixAccelEmitDesc (line 799) | typedef struct OptixAccelEmitDesc
  type OptixAccelRelocationInfo (line 811) | typedef struct OptixAccelRelocationInfo
  type OptixStaticTransform (line 822) | typedef struct OptixStaticTransform
  type OptixMatrixMotionTransform (line 862) | typedef struct OptixMatrixMotionTransform
  type OptixSRTData (line 907) | typedef struct OptixSRTData
  type OptixSRTMotionTransform (line 944) | typedef struct OptixSRTMotionTransform
  type OptixTraversableType (line 967) | typedef enum OptixTraversableType
  type OptixPixelFormat (line 980) | typedef enum OptixPixelFormat
  type OptixImage2D (line 995) | typedef struct OptixImage2D
  type OptixDenoiserModelKind (line 1015) | typedef enum OptixDenoiserModelKind
  type OptixDenoiserOptions (line 1034) | typedef struct OptixDenoiserOptions
  type OptixDenoiserGuideLayer (line 1046) | typedef struct OptixDenoiserGuideLayer
  type OptixDenoiserLayer (line 1061) | typedef struct OptixDenoiserLayer
  type OptixDenoiserParams (line 1078) | typedef struct OptixDenoiserParams
  type OptixDenoiserSizes (line 1104) | typedef struct OptixDenoiserSizes
  type OptixRayFlags (line 1116) | typedef enum OptixRayFlags
  type OptixTransformType (line 1172) | typedef enum OptixTransformType
  type OptixTraversableGraphFlags (line 1183) | typedef enum OptixTraversableGraphFlags
  type OptixCompileOptimizationLevel (line 1204) | typedef enum OptixCompileOptimizationLevel
  type OptixCompileDebugLevel (line 1221) | typedef enum OptixCompileDebugLevel
  type OptixModuleCompileBoundValueEntry (line 1268) | typedef struct OptixModuleCompileBoundValueEntry {
  type OptixModuleCompileOptions (line 1280) | typedef struct OptixModuleCompileOptions
  type OptixProgramGroupKind (line 1302) | typedef enum OptixProgramGroupKind
  type OptixProgramGroupFlags (line 1326) | typedef enum OptixProgramGroupFlags
  type OptixProgramGroupSingleModule (line 1338) | typedef struct OptixProgramGroupSingleModule
  type OptixProgramGroupHitgroup (line 1351) | typedef struct OptixProgramGroupHitgroup
  type OptixProgramGroupCallables (line 1372) | typedef struct OptixProgramGroupCallables
  type OptixProgramGroupDesc (line 1385) | typedef struct OptixProgramGroupDesc
  type OptixProgramGroupOptions (line 1411) | typedef struct OptixProgramGroupOptions
  type OptixExceptionCodes (line 1418) | typedef enum OptixExceptionCodes
  type OptixExceptionFlags (line 1521) | typedef enum OptixExceptionFlags
  type OptixPipelineCompileOptions (line 1545) | typedef struct OptixPipelineCompileOptions
  type OptixPipelineLinkOptions (line 1581) | typedef struct OptixPipelineLinkOptions
  type OptixShaderBindingTable (line 1594) | typedef struct OptixShaderBindingTable
  type OptixStackSizes (line 1634) | typedef struct OptixStackSizes
  type OptixQueryFunctionTableOptions (line 1654) | typedef enum OptixQueryFunctionTableOptions
  type OptixQueryFunctionTable_t (line 1662) | typedef OptixResult( OptixQueryFunctionTable_t )( int          abiId,
  type OptixBuiltinISOptions (line 1673) | typedef struct OptixBuiltinISOptions
  type OptixInvalidRayExceptionDetails (line 1685) | typedef struct OptixInvalidRayExceptionDetails
  type OptixParameterMismatchExceptionDetails (line 1700) | typedef struct OptixParameterMismatchExceptionDetails

FILE: render/optixutils/include/optix_denoiser_tiling.h
  type OptixUtilDenoiserImageTile (line 54) | struct OptixUtilDenoiserImageTile
  function optixUtilGetPixelStride (line 73) | inline unsigned int optixUtilGetPixelStride( const OptixImage2D& image )
  function optixUtilDenoiserSplitImage (line 118) | inline OptixResult optixUtilDenoiserSplitImage(
  function optixUtilDenoiserInvokeTiled (line 206) | inline OptixResult optixUtilDenoiserInvokeTiled(

FILE: render/optixutils/include/optix_function_table.h
  type OptixFunctionTable (line 55) | typedef struct OptixFunctionTable

FILE: render/optixutils/include/optix_stack_size.h
  function optixUtilAccumulateStackSizes (line 52) | inline OptixResult optixUtilAccumulateStackSizes( OptixProgramGroup prog...
  function optixUtilComputeStackSizes (line 86) | inline OptixResult optixUtilComputeStackSizes( const OptixStackSizes* st...
  function optixUtilComputeStackSizesDCSplit (line 151) | inline OptixResult optixUtilComputeStackSizesDCSplit( const OptixStackSi...
  function optixUtilComputeStackSizesCssCCTree (line 212) | inline OptixResult optixUtilComputeStackSizesCssCCTree( const OptixStack...
  function optixUtilComputeStackSizesSimplePathTracer (line 263) | inline OptixResult optixUtilComputeStackSizesSimplePathTracer( OptixProg...

FILE: render/optixutils/include/optix_stubs.h
  function optixInitWithHandle (line 188) | inline OptixResult optixInitWithHandle( void** handlePtr )
  function optixInit (line 224) | inline OptixResult optixInit( void )
  function optixUninitWithHandle (line 235) | inline OptixResult optixUninitWithHandle( void* handle )
  function optixDeviceContextCreate (line 318) | inline OptixResult optixDeviceContextCreate( CUcontext fromContext, cons...
  function optixDeviceContextDestroy (line 323) | inline OptixResult optixDeviceContextDestroy( OptixDeviceContext context )
  function optixDeviceContextGetProperty (line 328) | inline OptixResult optixDeviceContextGetProperty( OptixDeviceContext con...
  function optixDeviceContextSetLogCallback (line 333) | inline OptixResult optixDeviceContextSetLogCallback( OptixDeviceContext ...
  function optixDeviceContextSetCacheEnabled (line 341) | inline OptixResult optixDeviceContextSetCacheEnabled( OptixDeviceContext...
  function optixDeviceContextSetCacheLocation (line 346) | inline OptixResult optixDeviceContextSetCacheLocation( OptixDeviceContex...
  function optixDeviceContextSetCacheDatabaseSizes (line 351) | inline OptixResult optixDeviceContextSetCacheDatabaseSizes( OptixDeviceC...
  function optixDeviceContextGetCacheEnabled (line 356) | inline OptixResult optixDeviceContextGetCacheEnabled( OptixDeviceContext...
  function optixDeviceContextGetCacheLocation (line 361) | inline OptixResult optixDeviceContextGetCacheLocation( OptixDeviceContex...
  function optixDeviceContextGetCacheDatabaseSizes (line 366) | inline OptixResult optixDeviceContextGetCacheDatabaseSizes( OptixDeviceC...
  function optixModuleCreateFromPTX (line 371) | inline OptixResult optixModuleCreateFromPTX( OptixDeviceContext         ...
  function optixModuleDestroy (line 384) | inline OptixResult optixModuleDestroy( OptixModule module )
  function optixBuiltinISModuleGet (line 389) | inline OptixResult optixBuiltinISModuleGet( OptixDeviceContext          ...
  function optixProgramGroupCreate (line 399) | inline OptixResult optixProgramGroupCreate( OptixDeviceContext          ...
  function optixProgramGroupDestroy (line 411) | inline OptixResult optixProgramGroupDestroy( OptixProgramGroup programGr...
  function optixProgramGroupGetStackSize (line 416) | inline OptixResult optixProgramGroupGetStackSize( OptixProgramGroup prog...
  function optixPipelineCreate (line 421) | inline OptixResult optixPipelineCreate( OptixDeviceContext              ...
  function optixPipelineDestroy (line 434) | inline OptixResult optixPipelineDestroy( OptixPipeline pipeline )
  function optixPipelineSetStackSize (line 439) | inline OptixResult optixPipelineSetStackSize( OptixPipeline pipeline,
  function optixAccelComputeMemoryUsage (line 449) | inline OptixResult optixAccelComputeMemoryUsage( OptixDeviceContext     ...
  function optixAccelBuild (line 458) | inline OptixResult optixAccelBuild( OptixDeviceContext            context,
  function optixAccelGetRelocationInfo (line 477) | inline OptixResult optixAccelGetRelocationInfo( OptixDeviceContext conte...
  function optixAccelCheckRelocationCompatibility (line 483) | inline OptixResult optixAccelCheckRelocationCompatibility( OptixDeviceCo...
  function optixAccelRelocate (line 488) | inline OptixResult optixAccelRelocate( OptixDeviceContext              c...
  function optixAccelCompact (line 501) | inline OptixResult optixAccelCompact( OptixDeviceContext      context,
  function optixConvertPointerToTraversableHandle (line 511) | inline OptixResult optixConvertPointerToTraversableHandle( OptixDeviceCo...
  function optixSbtRecordPackHeader (line 519) | inline OptixResult optixSbtRecordPackHeader( OptixProgramGroup programGr...
  function optixLaunch (line 524) | inline OptixResult optixLaunch( OptixPipeline                  pipeline,
  function optixDenoiserCreate (line 536) | inline OptixResult optixDenoiserCreate( OptixDeviceContext context, Opti...
  function optixDenoiserCreateWithUserModel (line 541) | inline OptixResult optixDenoiserCreateWithUserModel( OptixDeviceContext ...
  function optixDenoiserDestroy (line 546) | inline OptixResult optixDenoiserDestroy( OptixDenoiser handle )
  function optixDenoiserComputeMemoryResources (line 551) | inline OptixResult optixDenoiserComputeMemoryResources( const OptixDenoi...
  function optixDenoiserSetup (line 559) | inline OptixResult optixDenoiserSetup( OptixDenoiser denoiser,
  function optixDenoiserInvoke (line 572) | inline OptixResult optixDenoiserInvoke( OptixDenoiser                   ...
  function optixDenoiserComputeIntensity (line 590) | inline OptixResult optixDenoiserComputeIntensity( OptixDenoiser       ha...
  function optixDenoiserComputeAverageColor (line 600) | inline OptixResult optixDenoiserComputeAverageColor( OptixDenoiser      ...

FILE: render/optixutils/ops.py
  function find_cl_path (line 25) | def find_cl_path():
  class _optix_env_shade_func (line 81) | class _optix_env_shade_func(torch.autograd.Function):
    method forward (line 85) | def forward(ctx, optix_ctx, mask, ro, gb_pos, gb_normal, gb_view_pos, ...
    method backward (line 101) | def backward(ctx, diff_grad, spec_grad):
  class _bilateral_denoiser_func (line 110) | class _bilateral_denoiser_func(torch.autograd.Function):
    method forward (line 112) | def forward(ctx, col, nrm, zdz, sigma):
    method backward (line 119) | def backward(ctx, out_grad):
  class OptiXContext (line 128) | class OptiXContext:
    method __init__ (line 129) | def __init__(self):
  function optix_build_bvh (line 133) | def optix_build_bvh(optix_ctx, verts, tris, rebuild):
  function optix_env_shade (line 141) | def optix_env_shade(optix_ctx, mask, ro, gb_pos, gb_normal, gb_view_pos,...
  function bilateral_denoiser (line 145) | def bilateral_denoiser(col, nrm, zdz, sigma):
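
  NOTE: `bilateral_denoiser` above wraps a CUDA kernel in a custom autograd function. The core idea — filter weights that combine spatial proximity with guide-signal similarity — can be sketched in plain Python for a 1-D signal. This is an illustrative sketch only; the names `bilateral_1d`, `sigma_s`, and `sigma_r` are not from the repository, and the real op additionally uses normals and depth derivatives (`nrm`, `zdz`) as guides.

```python
import math

def bilateral_1d(signal, sigma_s=2.0, sigma_r=0.1, radius=4):
    """Toy 1-D bilateral filter: spatial Gaussian x range Gaussian, normalized.

    Smooths within flat regions while leaving large jumps (edges) intact,
    because the range term suppresses weights across big value differences.
    """
    out = []
    for i, center in enumerate(signal):
        wsum, acc = 0.0, 0.0
        for j in range(max(0, i - radius), min(len(signal), i + radius + 1)):
            w_spatial = math.exp(-((i - j) ** 2) / (2.0 * sigma_s ** 2))
            w_range = math.exp(-((signal[j] - center) ** 2) / (2.0 * sigma_r ** 2))
            w = w_spatial * w_range
            wsum += w
            acc += w * signal[j]
        out.append(acc / wsum)
    return out
```

  A constant signal passes through unchanged, while a step edge is preserved when `sigma_r` is small relative to the step height.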

FILE: render/optixutils/tests/filter_test.py
  function length (line 22) | def length(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor:
  function safe_normalize (line 25) | def safe_normalize(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor:
  function dot (line 28) | def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
  class BilateralDenoiser (line 31) | class BilateralDenoiser(torch.nn.Module):
    method __init__ (line 32) | def __init__(self, sigma=1.0):
    method set_sigma (line 36) | def set_sigma(self, sigma):
    method forward (line 41) | def forward(self, input):
  function relative_loss (line 76) | def relative_loss(name, ref, cuda):
  function test_filter (line 84) | def test_filter():

FILE: render/regularizer.py
  function luma (line 16) | def luma(x):
  function value (line 18) | def value(x):
  function chroma_loss (line 21) | def chroma_loss(kd, color_ref, lambda_chroma):
  function shading_loss (line 28) | def shading_loss(diffuse_light, specular_light, color_ref, lambda_diffus...
  function material_smoothness_grad (line 46) | def material_smoothness_grad(kd_grad, ks_grad, nrm_grad, lambda_kd=0.25,...
  function image_grad (line 56) | def image_grad(buf, std=0.01):
  function avg_edge_length (line 68) | def avg_edge_length(v_pos, t_pos_idx):
  function laplace_regularizer_const (line 77) | def laplace_regularizer_const(v_pos, t_pos_idx):
  function normal_consistency (line 101) | def normal_consistency(v_pos, t_pos_idx):

FILE: render/render.py
  function interpolate (line 25) | def interpolate(attr, rast, attr_idx, rast_db=None):
  function shade (line 31) | def shade(
  function render_layer (line 199) | def render_layer(
  function render_mesh (line 325) | def render_mesh(
  function render_uv (line 449) | def render_uv(ctx, mesh, resolution, mlp_texture):

FILE: render/renderutils/bsdf.py
  function _dot (line 19) | def _dot(x, y):
  function _reflect (line 22) | def _reflect(x, n):
  function _safe_normalize (line 25) | def _safe_normalize(x):
  function _bend_normal (line 28) | def _bend_normal(view_vec, smooth_nrm, geom_nrm, two_sided_shading):
  function _perturb_normal (line 38) | def _perturb_normal(perturbed_nrm, smooth_nrm, smooth_tng, opengl):
  function bsdf_prepare_shading_normal (line 46) | def bsdf_prepare_shading_normal(pos, view_pos, perturbed_nrm, smooth_nrm...
  function bsdf_lambert (line 57) | def bsdf_lambert(nrm, wi):
  function bsdf_frostbite (line 64) | def bsdf_frostbite(nrm, wi, wo, linearRoughness):
  function bsdf_phong (line 85) | def bsdf_phong(nrm, wo, wi, N):
  function bsdf_fresnel_shlick (line 96) | def bsdf_fresnel_shlick(f0, f90, cosTheta):
  function bsdf_ndf_ggx (line 100) | def bsdf_ndf_ggx(alphaSqr, cosTheta):
  function bsdf_lambda_ggx (line 105) | def bsdf_lambda_ggx(alphaSqr, cosTheta):
  function bsdf_masking_smith_ggx_correlated (line 112) | def bsdf_masking_smith_ggx_correlated(alphaSqr, cosThetaI, cosThetaO):
  function bsdf_pbr_specular (line 117) | def bsdf_pbr_specular(col, nrm, wo, wi, alpha, min_roughness=0.08):
  function bsdf_pbr (line 136) | def bsdf_pbr(kd, arm, pos, nrm, view_pos, light_pos, min_roughness, BSDF):
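
  NOTE: `bsdf_lambert` and `bsdf_ndf_ggx` follow the standard Lambert and GGX definitions; a scalar sketch of those two formulas is below (the versions in `bsdf.py` operate on batched torch tensors, so these stand-alone helpers are illustrative, not the repository's API).

```python
import math

def lambert(n_dot_wi):
    """Lambertian diffuse term: max(n . wi, 0) / pi."""
    return max(n_dot_wi, 0.0) / math.pi

def ndf_ggx(alpha_sqr, cos_theta):
    """GGX normal distribution: D = a^2 / (pi * (cos^2(theta) * (a^2 - 1) + 1)^2)."""
    d = cos_theta * cos_theta * (alpha_sqr - 1.0) + 1.0
    return alpha_sqr / (math.pi * d * d)
```

  At normal incidence with `alpha_sqr = 1` (fully rough), both terms reduce to `1 / pi`, which is a quick sanity check on the formulas.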

FILE: render/renderutils/c_src/bsdf.h
  type LambertKernelParams (line 16) | struct LambertKernelParams
  type FrostbiteDiffuseKernelParams (line 24) | struct FrostbiteDiffuseKernelParams
  type FresnelShlickKernelParams (line 34) | struct FresnelShlickKernelParams
  type NdfGGXParams (line 43) | struct NdfGGXParams
  type MaskingSmithParams (line 51) | struct MaskingSmithParams
  type PbrSpecular (line 60) | struct PbrSpecular
  type PbrBSDF (line 72) | struct PbrBSDF

FILE: render/renderutils/c_src/common.cpp
  function getLaunchBlockSize (line 18) | dim3 getLaunchBlockSize(int maxWidth, int maxHeight, dim3 dims)
  function getWarpSize (line 56) | dim3 getWarpSize(dim3 blockSize)
  function getLaunchGridSize (line 65) | dim3 getLaunchGridSize(dim3 blockSize, dim3 dims)

FILE: render/renderutils/c_src/common.h
  function getWarpSize (line 29) | static inline dim3 getWarpSize(dim3 blockSize)
  function clamp (line 38) | __device__ static inline float clamp(float val, float mn, float mx) { re...

FILE: render/renderutils/c_src/cubemap.h
  type DiffuseCubemapKernelParams (line 16) | struct DiffuseCubemapKernelParams
  type SpecularCubemapKernelParams (line 23) | struct SpecularCubemapKernelParams
  type SpecularBoundsKernelParams (line 33) | struct SpecularBoundsKernelParams

FILE: render/renderutils/c_src/loss.h
  type TonemapperType (line 16) | enum TonemapperType
  type LossType (line 22) | enum LossType
  type LossKernelParams (line 30) | struct LossKernelParams

FILE: render/renderutils/c_src/mesh.h
  type XfmKernelParams (line 16) | struct XfmKernelParams

FILE: render/renderutils/c_src/normal.h
  type PrepareShadingNormalKernelParams (line 16) | struct PrepareShadingNormalKernelParams

FILE: render/renderutils/c_src/tensor.h
  type Tensor (line 20) | struct Tensor
  function store (line 65) | inline void store(unsigned int x, unsigned int y, unsigned int z, float ...
  function store (line 70) | __device__ inline void store(unsigned int x, unsigned int y, unsigned in...
  function store_grad (line 79) | __device__ inline void store_grad(unsigned int x, unsigned int y, unsign...
  function store_grad (line 84) | __device__ inline void store_grad(unsigned int x, unsigned int y, unsign...

FILE: render/renderutils/c_src/torch_bindings.cpp
  function update_grid (line 100) | void update_grid(dim3 &gridSize, torch::Tensor x)
  function update_grid (line 108) | void update_grid(dim3& gridSize, torch::Tensor x, Ts&&... vs)
  function make_cuda_tensor (line 116) | Tensor make_cuda_tensor(torch::Tensor val)
  function make_cuda_tensor (line 130) | Tensor make_cuda_tensor(torch::Tensor val, dim3 outDims, torch::Tensor* ...
  function prepare_shading_normal_fwd (line 161) | torch::Tensor prepare_shading_normal_fwd(torch::Tensor pos, torch::Tenso...
  function prepare_shading_normal_bwd (line 203) | std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, t...
  function lambert_fwd (line 237) | torch::Tensor lambert_fwd(torch::Tensor nrm, torch::Tensor wi, bool fp16)
  function lambert_bwd (line 268) | std::tuple<torch::Tensor, torch::Tensor> lambert_bwd(torch::Tensor nrm, ...
  function frostbite_fwd (line 295) | torch::Tensor frostbite_fwd(torch::Tensor nrm, torch::Tensor wi, torch::...
  function frostbite_bwd (line 330) | std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> f...
  function fresnel_shlick_fwd (line 359) | torch::Tensor fresnel_shlick_fwd(torch::Tensor f0, torch::Tensor f90, to...
  function fresnel_shlick_bwd (line 392) | std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> fresnel_shlick_b...
  function ndf_ggx_fwd (line 420) | torch::Tensor ndf_ggx_fwd(torch::Tensor alphaSqr, torch::Tensor cosTheta...
  function ndf_ggx_bwd (line 451) | std::tuple<torch::Tensor, torch::Tensor> ndf_ggx_bwd(torch::Tensor alpha...
  function lambda_ggx_fwd (line 478) | torch::Tensor lambda_ggx_fwd(torch::Tensor alphaSqr, torch::Tensor cosTh...
  function lambda_ggx_bwd (line 509) | std::tuple<torch::Tensor, torch::Tensor> lambda_ggx_bwd(torch::Tensor al...
  function masking_smith_fwd (line 536) | torch::Tensor masking_smith_fwd(torch::Tensor alphaSqr, torch::Tensor co...
  function masking_smith_bwd (line 569) | std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> masking_smith_bw...
  function pbr_specular_fwd (line 597) | torch::Tensor pbr_specular_fwd(torch::Tensor col, torch::Tensor nrm, tor...
  function pbr_specular_bwd (line 635) | std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, t...
  function pbr_bsdf_fwd (line 666) | torch::Tensor pbr_bsdf_fwd(torch::Tensor kd, torch::Tensor arm, torch::T...
  function pbr_bsdf_bwd (line 707) | std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, t...
  function diffuse_cubemap_fwd (line 740) | torch::Tensor diffuse_cubemap_fwd(torch::Tensor cubemap)
  function diffuse_cubemap_bwd (line 769) | torch::Tensor diffuse_cubemap_bwd(torch::Tensor cubemap, torch::Tensor g...
  function specular_bounds (line 799) | torch::Tensor specular_bounds(int resolution, float costheta_cutoff)
  function specular_cubemap_fwd (line 826) | torch::Tensor specular_cubemap_fwd(torch::Tensor cubemap, torch::Tensor ...
  function specular_cubemap_bwd (line 859) | torch::Tensor specular_cubemap_bwd(torch::Tensor cubemap, torch::Tensor ...
  function strToLoss (line 895) | LossType strToLoss(std::string str)
  function image_loss_fwd (line 907) | torch::Tensor image_loss_fwd(torch::Tensor img, torch::Tensor target, st...
  function image_loss_bwd (line 941) | std::tuple<torch::Tensor, torch::Tensor> image_loss_bwd(torch::Tensor im...
  function xfm_fwd (line 971) | torch::Tensor xfm_fwd(torch::Tensor points, torch::Tensor matrix, bool i...
  function xfm_bwd (line 1006) | torch::Tensor xfm_bwd(torch::Tensor points, torch::Tensor matrix, torch:...
  function PYBIND11_MODULE (line 1034) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

FILE: render/renderutils/c_src/vec3f.h
  type vec3f (line 14) | struct vec3f
  function sum (line 38) | __device__ static inline float sum(vec3f a)
  function cross (line 43) | __device__ static inline vec3f cross(vec3f a, vec3f b)
  function bwdCross (line 52) | __device__ static inline void bwdCross(vec3f a, vec3f b, vec3f &d_a, vec...
  function dot (line 63) | __device__ static inline float dot(vec3f a, vec3f b)
  function bwdDot (line 68) | __device__ static inline void bwdDot(vec3f a, vec3f b, vec3f& d_a, vec3f...
  function reflect (line 74) | __device__ static inline vec3f reflect(vec3f x, vec3f n)
  function bwdReflect (line 79) | __device__ static inline void bwdReflect(vec3f x, vec3f n, vec3f& d_x, v...
  function safeNormalize (line 90) | __device__ static inline vec3f safeNormalize(vec3f v)
  function bwdSafeNormalize (line 96) | __device__ static inline void bwdSafeNormalize(const vec3f v, vec3f& d_v...

FILE: render/renderutils/c_src/vec4f.h
  type vec4f (line 14) | struct vec4f

FILE: render/renderutils/loss.py
  function _tonemap_srgb (line 16) | def _tonemap_srgb(f, exposure=5):
  function _SMAPE (line 20) | def _SMAPE(img, target, eps=0.01):
  function _RELMSE (line 25) | def _RELMSE(img, target, eps=0.1):
  function image_loss_fn (line 30) | def image_loss_fn(img, target, loss, tonemapper):
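
  NOTE: the names `_SMAPE` and `_RELMSE` suggest per-pixel relative error metrics of the standard forms sketched below. These scalar helpers (`smape_px`, `relmse_px`) are illustrative assumptions, not the repository's tensor implementations; the exact epsilon placement and any normalization factors in `loss.py` may differ.

```python
def smape_px(i, t, eps=0.01):
    """Symmetric relative absolute error, per pixel: |i - t| / (|i| + |t| + eps)."""
    return abs(i - t) / (abs(i) + abs(t) + eps)

def relmse_px(i, t, eps=0.1):
    """Relative squared error, per pixel: (i - t)^2 / (t^2 + eps)."""
    return (i - t) ** 2 / (t * t + eps)
```

  Both metrics vanish when prediction and target agree, and the epsilon keeps dark pixels from dominating the loss.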

FILE: render/renderutils/ops.py
  function _get_plugin (line 23) | def _get_plugin():
  class _fresnel_shlick_func (line 92) | class _fresnel_shlick_func(torch.autograd.Function):
    method forward (line 94) | def forward(ctx, f0, f90, cosTheta):
    method backward (line 100) | def backward(ctx, dout):
  function _fresnel_shlick (line 104) | def _fresnel_shlick(f0, f90, cosTheta, use_python=False):
  class _ndf_ggx_func (line 115) | class _ndf_ggx_func(torch.autograd.Function):
    method forward (line 117) | def forward(ctx, alphaSqr, cosTheta):
    method backward (line 123) | def backward(ctx, dout):
  function _ndf_ggx (line 127) | def _ndf_ggx(alphaSqr, cosTheta, use_python=False):
  class _lambda_ggx_func (line 137) | class _lambda_ggx_func(torch.autograd.Function):
    method forward (line 139) | def forward(ctx, alphaSqr, cosTheta):
    method backward (line 145) | def backward(ctx, dout):
  function _lambda_ggx (line 149) | def _lambda_ggx(alphaSqr, cosTheta, use_python=False):
  class _masking_smith_func (line 159) | class _masking_smith_func(torch.autograd.Function):
    method forward (line 161) | def forward(ctx, alphaSqr, cosThetaI, cosThetaO):
    method backward (line 167) | def backward(ctx, dout):
  function _masking_smith (line 171) | def _masking_smith(alphaSqr, cosThetaI, cosThetaO, use_python=False):
  class _prepare_shading_normal_func (line 184) | class _prepare_shading_normal_func(torch.autograd.Function):
    method forward (line 186) | def forward(ctx, pos, view_pos, perturbed_nrm, smooth_nrm, smooth_tng,...
    method backward (line 193) | def backward(ctx, dout):
  function prepare_shading_normal (line 197) | def prepare_shading_normal(pos, view_pos, perturbed_nrm, smooth_nrm, smo...
  class _lambert_func (line 235) | class _lambert_func(torch.autograd.Function):
    method forward (line 237) | def forward(ctx, nrm, wi):
    method backward (line 243) | def backward(ctx, dout):
  function lambert (line 247) | def lambert(nrm, wi, use_python=False):
  class _frostbite_diffuse_func (line 269) | class _frostbite_diffuse_func(torch.autograd.Function):
    method forward (line 271) | def forward(ctx, nrm, wi, wo, linearRoughness):
    method backward (line 277) | def backward(ctx, dout):
  function frostbite_diffuse (line 281) | def frostbite_diffuse(nrm, wi, wo, linearRoughness, use_python=False):
  class _pbr_specular_func (line 305) | class _pbr_specular_func(torch.autograd.Function):
    method forward (line 307) | def forward(ctx, col, nrm, wo, wi, alpha, min_roughness):
    method backward (line 314) | def backward(ctx, dout):
  function pbr_specular (line 318) | def pbr_specular(col, nrm, wo, wi, alpha, min_roughness=0.08, use_python...
  class _pbr_bsdf_func (line 344) | class _pbr_bsdf_func(torch.autograd.Function):
    method forward (line 346) | def forward(ctx, kd, arm, pos, nrm, view_pos, light_pos, min_roughness...
    method backward (line 354) | def backward(ctx, dout):
  function pbr_bsdf (line 358) | def pbr_bsdf(kd, arm, pos, nrm, view_pos, light_pos, min_roughness=0.08,...
  class _diffuse_cubemap_func (line 394) | class _diffuse_cubemap_func(torch.autograd.Function):
    method forward (line 396) | def forward(ctx, cubemap):
    method backward (line 402) | def backward(ctx, dout):
  function diffuse_cubemap (line 407) | def diffuse_cubemap(cubemap, use_python=False):
  class _specular_cubemap (line 416) | class _specular_cubemap(torch.autograd.Function):
    method forward (line 418) | def forward(ctx, cubemap, roughness, costheta_cutoff, bounds):
    method backward (line 425) | def backward(ctx, dout):
  function __ndfBounds (line 431) | def __ndfBounds(res, roughness, cutoff):
  function specular_cubemap (line 449) | def specular_cubemap(cubemap, roughness, cutoff=0.99, use_python=False):
  class _image_loss_func (line 466) | class _image_loss_func(torch.autograd.Function):
    method forward (line 468) | def forward(ctx, img, target, loss, tonemapper):
    method backward (line 475) | def backward(ctx, dout):
  function image_loss (line 479) | def image_loss(img, target, loss='l1', tonemapper='none', use_python=Fal...
  class _xfm_func (line 506) | class _xfm_func(torch.autograd.Function):
    method forward (line 508) | def forward(ctx, points, matrix, isPoints):
    method backward (line 514) | def backward(ctx, dout):
  function xfm_points (line 518) | def xfm_points(points, matrix, use_python=False):
  function xfm_vectors (line 536) | def xfm_vectors(vectors, matrix, use_python=False):

FILE: render/renderutils/tests/test_bsdf.py
  function relative_loss (line 20) | def relative_loss(name, ref, cuda):
  function test_normal (line 25) | def test_normal():
  function test_schlick (line 59) | def test_schlick():
  function test_ndf_ggx (line 85) | def test_ndf_ggx():
  function test_lambda_ggx (line 109) | def test_lambda_ggx():
  function test_masking_smith (line 132) | def test_masking_smith():
  function test_lambert (line 157) | def test_lambert():
  function test_frostbite (line 179) | def test_frostbite():
  function test_pbr_specular (line 207) | def test_pbr_specular():
  function test_pbr_bsdf (line 244) | def test_pbr_bsdf(bsdf):

FILE: render/renderutils/tests/test_loss.py
  function tonemap_srgb (line 20) | def tonemap_srgb(f):
  function l1 (line 23) | def l1(output, target):
  function relative_loss (line 30) | def relative_loss(name, ref, cuda):
  function test_loss (line 35) | def test_loss(loss, tonemapper):

FILE: render/renderutils/tests/test_mesh.py
  function tonemap_srgb (line 23) | def tonemap_srgb(f):
  function l1 (line 26) | def l1(output, target):
  function relative_loss (line 33) | def relative_loss(name, ref, cuda):
  function test_xfm_points (line 38) | def test_xfm_points():
  function test_xfm_vectors (line 58) | def test_xfm_vectors():

FILE: render/renderutils/tests/test_perf.py
  function test_bsdf (line 19) | def test_bsdf(BATCH, RES, ITR):

FILE: render/texture.py
  class texture2d_mip (line 20) | class texture2d_mip(torch.autograd.Function):
    method forward (line 22) | def forward(ctx, texture):
    method backward (line 26) | def backward(ctx, dout):
  class Texture2D (line 38) | class Texture2D:
    method __init__ (line 41) | def __init__(self, init, min_max=None):
    method sample (line 57) | def sample(self, texc, texc_deriv, filter_mode='linear-mipmap-linear'):
    method getRes (line 70) | def getRes(self):
    method getChannels (line 73) | def getChannels(self):
    method getMips (line 76) | def getMips(self):
    method parameters (line 82) | def parameters(self):
    method clamp_ (line 86) | def clamp_(self):
    method normalize_ (line 93) | def normalize_(self):
  function create_trainable (line 103) | def create_trainable(init, res=None, auto_mipmaps=True, min_max=None):
  function srgb_to_rgb (line 137) | def srgb_to_rgb(texture):
  function rgb_to_srgb (line 140) | def rgb_to_srgb(texture):
  function _load_mip2D (line 147) | def _load_mip2D(fn, lambda_fn=None, channels=None):
  function load_texture2D (line 155) | def load_texture2D(fn, lambda_fn=None, channels=None):
  function _save_mip2D (line 165) | def _save_mip2D(fn, mip, mipidx, lambda_fn):
  function save_texture2D (line 177) | def save_texture2D(fn, tex, lambda_fn=None):
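
  NOTE: `srgb_to_rgb` and `rgb_to_srgb` in `texture.py` apply the standard sRGB transfer function per channel; a scalar sketch of that piecewise transfer and its inverse is below (the repository's versions operate on torch tensors).

```python
def rgb_to_srgb(f):
    """Linear -> sRGB: linear segment near black, gamma 1/2.4 elsewhere."""
    return 12.92 * f if f <= 0.0031308 else 1.055 * f ** (1.0 / 2.4) - 0.055

def srgb_to_rgb(f):
    """sRGB -> linear, exact inverse of the above."""
    return f / 12.92 if f <= 0.04045 else ((f + 0.055) / 1.055) ** 2.4
```

  The two thresholds (0.0031308 and 0.04045) are chosen so the linear and power segments meet, making the round trip an identity.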

FILE: render/util.py
  function dot (line 19) | def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
  function reflect (line 22) | def reflect(x: torch.Tensor, n: torch.Tensor) -> torch.Tensor:
  function length (line 25) | def length(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor:
  function safe_normalize (line 28) | def safe_normalize(x: torch.Tensor, eps: float =1e-20) -> torch.Tensor:
  function to_hvec (line 31) | def to_hvec(x: torch.Tensor, w: float) -> torch.Tensor:
  function ycocg2rgb (line 34) | def ycocg2rgb(ycocg):
  function hsv2rgb (line 41) | def hsv2rgb(image): # Based on https://kornia.readthedocs.io/en/latest/_...
  function pixel_grid (line 61) | def pixel_grid(width, height, center_x = 0.5, center_y = 0.5):
  function dilate (line 70) | def dilate(x, x_avg, mask, N):
  function _rgb_to_srgb (line 94) | def _rgb_to_srgb(f: torch.Tensor) -> torch.Tensor:
  function rgb_to_srgb (line 97) | def rgb_to_srgb(f: torch.Tensor) -> torch.Tensor:
  function _srgb_to_rgb (line 103) | def _srgb_to_rgb(f: torch.Tensor) -> torch.Tensor:
  function srgb_to_rgb (line 106) | def srgb_to_rgb(f: torch.Tensor) -> torch.Tensor:
  function reinhard (line 112) | def reinhard(f: torch.Tensor) -> torch.Tensor:
  function mse_to_psnr (line 122) | def mse_to_psnr(mse):
  function psnr_to_mse (line 126) | def psnr_to_mse(psnr):
  function get_miplevels (line 134) | def get_miplevels(texture: np.ndarray) -> float:
  function tex_2d (line 138) | def tex_2d(tex_map : torch.Tensor, coords : torch.Tensor, filter='neares...
  function cube_to_dir (line 149) | def cube_to_dir(s, x, y):
  function latlong_to_cubemap (line 158) | def latlong_to_cubemap(latlong_map, res):
  function cubemap_to_latlong (line 173) | def cubemap_to_latlong(cubemap, res):
  function scale_img_hwc (line 192) | def scale_img_hwc(x : torch.Tensor, size, mag='bilinear', min='area') ->...
  function scale_img_nhwc (line 195) | def scale_img_nhwc(x  : torch.Tensor, size, mag='bilinear', min='area') ...
  function avg_pool_nhwc (line 207) | def avg_pool_nhwc(x  : torch.Tensor, size) -> torch.Tensor:
  function segment_sum (line 216) | def segment_sum(data: torch.Tensor, segment_ids: torch.Tensor) -> torch....
  function fovx_to_fovy (line 235) | def fovx_to_fovy(fovx, aspect):
  function focal_length_to_fovy (line 238) | def focal_length_to_fovy(focal_length, sensor_height):
  function perspective (line 242) | def perspective(fovy=0.7854, aspect=1.0, n=0.1, f=1000.0, device=None):
  function perspective_offcenter (line 250) | def perspective_offcenter(fovy, fraction, rx, ry, aspect=1.0, n=0.1, f=1...
  function translate (line 274) | def translate(x, y, z, device=None):
  function rotate_x (line 280) | def rotate_x(a, device=None):
  function rotate_y (line 287) | def rotate_y(a, device=None):
  function rotate_z (line 294) | def rotate_z(a, device=None):
  function scale (line 301) | def scale(s, device=None):
  function lookAt (line 307) | def lookAt(eye, at, up):
  function random_rotation_translation (line 324) | def random_rotation_translation(t, device=None):
  function random_rotation (line 335) | def random_rotation(device=None):
  function lines_focal (line 350) | def lines_focal(o, d):
  function cosine_sample (line 361) | def cosine_sample(N, size=None):
  function bilinear_downsample (line 396) | def bilinear_downsample(x : torch.tensor) -> torch.Tensor:
  function bilinear_downsample (line 406) | def bilinear_downsample(x : torch.tensor, spp) -> torch.Tensor:
  function init_glfw (line 422) | def init_glfw():
  function display_image (line 440) | def display_image(image, title=None):
  function save_image (line 483) | def save_image(fn, x : np.ndarray) -> np.ndarray:
  function save_image_raw (line 492) | def save_image_raw(fn, x : np.ndarray):
  function load_image_raw (line 499) | def load_image_raw(fn) -> np.ndarray:
  function load_image (line 502) | def load_image(fn) -> np.ndarray:
  function time_to_text (line 511) | def time_to_text(x):
  function checkerboard (line 521) | def checkerboard(res, checker_size) -> np.ndarray:
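The colorspace and metric helpers indexed above for render/util.py can be sketched as follows. This is a hedged, scalar-only illustration based on the standard sRGB transfer function and PSNR definition; the repository's actual implementations operate on torch tensors and may differ in thresholds or broadcasting details.

```python
import math

# Scalar sketches of the render/util.py helpers indexed above.
# Formulas follow the standard sRGB piecewise transfer function and the
# usual PSNR definition for signals normalized to [0, 1].

def rgb_to_srgb(f: float) -> float:
    """Linear RGB -> sRGB for one channel value in [0, 1]."""
    if f <= 0.0031308:
        return f * 12.92
    return 1.055 * (f ** (1.0 / 2.4)) - 0.055

def srgb_to_rgb(f: float) -> float:
    """sRGB -> linear RGB (inverse of rgb_to_srgb)."""
    if f <= 0.04045:
        return f / 12.92
    return ((f + 0.055) / 1.055) ** 2.4

def mse_to_psnr(mse: float) -> float:
    """PSNR in dB from mean squared error."""
    return -10.0 * math.log10(mse)

def psnr_to_mse(psnr: float) -> float:
    """Inverse of mse_to_psnr."""
    return 10.0 ** (-psnr / 10.0)
```

The two conversion pairs are exact inverses of each other, which is a convenient sanity check when porting them to tensor code.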

FILE: train_gflexicubes_deepfashion.py
  function createLoss (line 52) | def createLoss(FLAGS):
  function prepare_batch (line 71) | def prepare_batch(target, bg_type='black'):
  function xatlas_uvmap (line 101) | def xatlas_uvmap(glctx, geometry, mat, FLAGS):
  function initial_guess_material (line 155) | def initial_guess_material(geometry, mlp, FLAGS, init_mat=None):
  function initial_guess_material_knownkskd (line 172) | def initial_guess_material_knownkskd(geometry, mlp, FLAGS, init_mat=None):
  function validate_itr (line 190) | def validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS, deno...
  function validate (line 237) | def validate(glctx, geometry, opt_material, lgt, dataset_validate, out_d...
  function optimize_mesh (line 288) | def optimize_mesh(

FILE: train_gflexicubes_polycam.py
  function createLoss (line 52) | def createLoss(FLAGS):
  function prepare_batch (line 71) | def prepare_batch(target, bg_type='black'):
  function xatlas_uvmap (line 101) | def xatlas_uvmap(glctx, geometry, mat, FLAGS):
  function initial_guess_material (line 155) | def initial_guess_material(geometry, mlp, FLAGS, init_mat=None):
  function initial_guess_material_knownkskd (line 172) | def initial_guess_material_knownkskd(geometry, mlp, FLAGS, init_mat=None):
  function validate_itr (line 190) | def validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS, deno...
  function validate (line 209) | def validate(glctx, geometry, opt_material, lgt, dataset_validate, out_d...
  function optimize_mesh (line 260) | def optimize_mesh(

FILE: train_gshelltet_deepfashion.py
  function createLoss (line 52) | def createLoss(FLAGS):
  function prepare_batch (line 71) | def prepare_batch(target, bg_type='black'):
  function xatlas_uvmap (line 101) | def xatlas_uvmap(glctx, geometry, mat, FLAGS):
  function initial_guess_material (line 155) | def initial_guess_material(geometry, mlp, FLAGS, init_mat=None):
  function initial_guess_material_knownkskd (line 172) | def initial_guess_material_knownkskd(geometry, mlp, FLAGS, init_mat=None):
  function validate_itr (line 190) | def validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS, deno...
  function validate (line 227) | def validate(glctx, geometry, opt_material, lgt, dataset_validate, out_d...
  function optimize_mesh (line 278) | def optimize_mesh(

FILE: train_gshelltet_polycam.py
  function createLoss (line 52) | def createLoss(FLAGS):
  function prepare_batch (line 71) | def prepare_batch(target, bg_type='black'):
  function xatlas_uvmap (line 101) | def xatlas_uvmap(glctx, geometry, mat, FLAGS):
  function initial_guess_material (line 155) | def initial_guess_material(geometry, mlp, FLAGS, init_mat=None):
  function initial_guess_material_knownkskd (line 172) | def initial_guess_material_knownkskd(geometry, mlp, FLAGS, init_mat=None):
  function validate_itr (line 190) | def validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS, deno...
  function validate (line 209) | def validate(glctx, geometry, opt_material, lgt, dataset_validate, out_d...
  function optimize_mesh (line 260) | def optimize_mesh(

FILE: train_gshelltet_synthetic.py
  function createLoss (line 51) | def createLoss(FLAGS):
  function prepare_batch (line 70) | def prepare_batch(target, bg_type='black'):
  function xatlas_uvmap (line 100) | def xatlas_uvmap(glctx, geometry, mat, FLAGS):
  function initial_guess_material (line 154) | def initial_guess_material(geometry, mlp, FLAGS, init_mat=None):
  function initial_guess_material_knownkskd (line 171) | def initial_guess_material_knownkskd(geometry, mlp, FLAGS, init_mat=None):
  function validate_itr (line 189) | def validate_itr(glctx, target, geometry, opt_material, lgt, FLAGS, deno...
  function validate (line 212) | def validate(glctx, geometry, opt_material, lgt, dataset_validate, out_d...
  function optimize_mesh (line 263) | def optimize_mesh(
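All five training scripts above share the same entry points (createLoss, prepare_batch, validate, optimize_mesh). The createLoss pattern can be sketched as a small dispatcher keyed on the configured loss name; the names and formulas below are assumptions modeled on nvdiffrec-style pipelines, not the scripts' exact variants.

```python
import math

# Hedged sketch of a createLoss-style dispatcher. The loss names ("mse",
# "logl1") and their formulas are illustrative assumptions; consult the
# training scripts for the actual FLAGS.loss options.
def create_loss(loss_name: str):
    """Return a per-pixel loss function selected by name."""
    if loss_name == "mse":
        return lambda img, ref: (img - ref) ** 2
    if loss_name == "logl1":
        # L1 in log space downweights high-dynamic-range outliers.
        return lambda img, ref: abs(math.log(img + 1.0) - math.log(ref + 1.0))
    raise ValueError(f"Unknown loss: {loss_name}")
```

Returning a closure lets optimize_mesh stay agnostic to which loss was configured.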
Condensed preview — 124 files, each showing its path, character count, and a content snippet. Download the .json file or copy it to get the full structured content (1,318K chars).
[
  {
    "path": ".gitignore",
    "chars": 3091,
    "preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packagi"
  },
  {
    "path": "GMeshDiffusion/diffusion_configs/config_lower_occgrid_normalized.py",
    "chars": 3733,
    "preview": "import ml_collections\nimport torch\nimport os\n\n\ndef get_config():\n    config = ml_collections.ConfigDict()\n\n    # data\n  "
  },
  {
    "path": "GMeshDiffusion/diffusion_configs/config_upper_occgrid_normalized.py",
    "chars": 3775,
    "preview": "import ml_collections\nimport torch\nimport os\n\n\ndef get_config():\n    config = ml_collections.ConfigDict()\n\n    # data\n  "
  },
  {
    "path": "GMeshDiffusion/lib/dataset/gshell_dataset.py",
    "chars": 741,
    "preview": "import torch\nimport numpy as np\nfrom torch.utils.data import Dataset\n\nclass GShellDataset(Dataset):\n    def __init__(sel"
  },
  {
    "path": "GMeshDiffusion/lib/dataset/gshell_dataset_aug.py",
    "chars": 1231,
    "preview": "import torch\nfrom torch.utils.data import Dataset\n\nclass GShellAugDataset(Dataset):\n    def __init__(self, FLAGS, extens"
  },
  {
    "path": "GMeshDiffusion/lib/diffusion/evaler.py",
    "chars": 10808,
    "preview": "import os\nimport sys\nimport numpy as np\nimport tqdm\n\nimport logging\nfrom . import losses\nfrom .models import utils as mu"
  },
  {
    "path": "GMeshDiffusion/lib/diffusion/likelihood.py",
    "chars": 4714,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "GMeshDiffusion/lib/diffusion/losses.py",
    "chars": 10283,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "GMeshDiffusion/lib/diffusion/models/__init__.py",
    "chars": 608,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "GMeshDiffusion/lib/diffusion/models/ema.py",
    "chars": 3672,
    "preview": "# Modified from https://raw.githubusercontent.com/fadel/pytorch_ema/master/torch_ema/ema.py\n\nfrom __future__ import divi"
  },
  {
    "path": "GMeshDiffusion/lib/diffusion/models/functional.py",
    "chars": 9332,
    "preview": "#################################################################################################\n# Copyright (c) 2023 A"
  },
  {
    "path": "GMeshDiffusion/lib/diffusion/models/layers.py",
    "chars": 11260,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "GMeshDiffusion/lib/diffusion/models/normalization.py",
    "chars": 7768,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "GMeshDiffusion/lib/diffusion/models/unet3d_occgrid.py",
    "chars": 8884,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "GMeshDiffusion/lib/diffusion/models/utils.py",
    "chars": 7257,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "GMeshDiffusion/lib/diffusion/sampling.py",
    "chars": 26124,
    "preview": "# coding=utf-8\n# Copyright 2020 The Google Research Authors.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Li"
  },
  {
    "path": "GMeshDiffusion/lib/diffusion/sde_lib.py",
    "chars": 9305,
    "preview": "\"\"\"Abstract SDE classes, Reverse SDE, and VE/VP SDEs.\"\"\"\nimport abc\nimport torch\nimport numpy as np\nimport torch.nn.func"
  },
  {
    "path": "GMeshDiffusion/lib/diffusion/trainer.py",
    "chars": 6286,
    "preview": "import os\nimport sys\nimport numpy as np\n\nimport logging\n# Keep the import below for registering all model definitions\nfr"
  },
  {
    "path": "GMeshDiffusion/lib/diffusion/trainer_ddp.py",
    "chars": 7278,
    "preview": "import os\nimport sys\nimport numpy as np\n\nimport logging\n# Keep the import below for registering all model definitions\nfr"
  },
  {
    "path": "GMeshDiffusion/lib/diffusion/utils.py",
    "chars": 1503,
    "preview": "import torch\nimport os\nimport logging\n\n\ndef restore_checkpoint(ckpt_dir, state, device, strict=False, rank=None):\n  if n"
  },
  {
    "path": "GMeshDiffusion/main_diffusion.py",
    "chars": 847,
    "preview": "\"\"\"Training and evaluation\"\"\"\n\nfrom absl import app\nfrom absl import flags\nfrom ml_collections.config_flags import confi"
  },
  {
    "path": "GMeshDiffusion/main_diffusion_ddp.py",
    "chars": 635,
    "preview": "\"\"\"Training and evaluation\"\"\"\n\nfrom absl import app\nfrom absl import flags\nfrom ml_collections.config_flags import confi"
  },
  {
    "path": "GMeshDiffusion/metadata/get_splits_lower.py",
    "chars": 974,
    "preview": "import os\nimport random\n\nrandom.seed(42)\n\nsplit_ratio = 0.9\ndata_root = 'PLACEHOLDER'\ngrid_root = os.path.join(data_root"
  },
  {
    "path": "GMeshDiffusion/metadata/get_splits_upper.py",
    "chars": 974,
    "preview": "import os\nimport random\n\nrandom.seed(42)\n\nsplit_ratio = 0.9\ndata_root = 'PLACEHOLDER'\ngrid_root = os.path.join(data_root"
  },
  {
    "path": "GMeshDiffusion/metadata/save_tet_info.py",
    "chars": 4023,
    "preview": "'''\n    Storing tet-grid related meta-info into a single file\n'''\n\nimport numpy as np\nimport torch\nimport os\nimport tqdm"
  },
  {
    "path": "GMeshDiffusion/metadata/tet_to_cubic_grid_dataset.py",
    "chars": 12533,
    "preview": "import numpy as np\nimport torch\nimport os\nimport tqdm\nimport argparse\n\ndef tet_to_grids(vertices, values_list, grid_size"
  },
  {
    "path": "GMeshDiffusion/scripts/run_eval_lower_occgrid_normalized.sh",
    "chars": 299,
    "preview": "python main_diffusion.py --mode uncond_gen --config diffusion_configs/config_lower_occgrid_normalized.py \\\n--config.eval"
  },
  {
    "path": "GMeshDiffusion/scripts/run_eval_upper_occgrid_normalized.sh",
    "chars": 299,
    "preview": "python main_diffusion.py --mode uncond_gen --config diffusion_configs/config_upper_occgrid_normalized.py \\\n--config.eval"
  },
  {
    "path": "GMeshDiffusion/scripts/run_lower_occgrid_normalized_ddp.sh",
    "chars": 213,
    "preview": "torchrun --nnodes=1 --nproc_per_node=8 main_diffusion_ddp.py --mode=train --config=diffusion_configs/config_lower_occgri"
  },
  {
    "path": "GMeshDiffusion/scripts/run_upper_occgrid_normalized_ddp.sh",
    "chars": 214,
    "preview": "torchrun --nnodes=1 --nproc_per_node=8 main_diffusion_ddp.py --mode=train --config=diffusion_configs/config_upper_occgri"
  },
  {
    "path": "README.md",
    "chars": 7662,
    "preview": "<div align=\"center\">\n  <img src=\"assets/gshell_logo.png\" width=\"900\"/>\n</div>\n\n# Ghost on the Shell: An Expressive Repre"
  },
  {
    "path": "configs/deepfashion_mc.json",
    "chars": 655,
    "preview": "{\n    \"ref_mesh\": \"data/spot/spot.obj\",\n    \"random_textures\": true,\n    \"iter\": 5000,\n    \"save_interval\": 100,\n    \"te"
  },
  {
    "path": "configs/deepfashion_mc_256.json",
    "chars": 655,
    "preview": "{\n    \"ref_mesh\": \"data/spot/spot.obj\",\n    \"random_textures\": true,\n    \"iter\": 5000,\n    \"save_interval\": 100,\n    \"te"
  },
  {
    "path": "configs/deepfashion_mc_512.json",
    "chars": 679,
    "preview": "{\n    \"ref_mesh\": \"data/spot/spot.obj\",\n    \"random_textures\": true,\n    \"iter\": 5000,\n    \"save_interval\": 100,\n    \"te"
  },
  {
    "path": "configs/deepfashion_mc_80.json",
    "chars": 654,
    "preview": "{\n    \"ref_mesh\": \"data/spot/spot.obj\",\n    \"random_textures\": true,\n    \"iter\": 5000,\n    \"save_interval\": 100,\n    \"te"
  },
  {
    "path": "configs/nerf_chair.json",
    "chars": 541,
    "preview": "{\n    \"ref_mesh\": \"data/nerf_synthetic/chair\",\n    \"random_textures\": true,\n    \"iter\": 5000,\n    \"save_interval\": 100,\n"
  },
  {
    "path": "configs/polycam_mc.json",
    "chars": 653,
    "preview": "{\n    \"ref_mesh\": \"data/spot/spot.obj\",\n    \"random_textures\": true,\n    \"iter\": 5000,\n    \"save_interval\": 100,\n    \"te"
  },
  {
    "path": "configs/polycam_mc_128.json",
    "chars": 653,
    "preview": "{\n    \"ref_mesh\": \"data/spot/spot.obj\",\n    \"random_textures\": true,\n    \"iter\": 5000,\n    \"save_interval\": 100,\n    \"te"
  },
  {
    "path": "configs/polycam_mc_16samples.json",
    "chars": 654,
    "preview": "{\n    \"ref_mesh\": \"data/spot/spot.obj\",\n    \"random_textures\": true,\n    \"iter\": 5000,\n    \"save_interval\": 100,\n    \"te"
  },
  {
    "path": "data/tets/generate_tets.py",
    "chars": 1775,
    "preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
  },
  {
    "path": "dataset/__init__.py",
    "chars": 572,
    "preview": "# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all "
  },
  {
    "path": "dataset/dataset.py",
    "chars": 1866,
    "preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
  },
  {
    "path": "dataset/dataset_deepfashion.py",
    "chars": 4627,
    "preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
  },
  {
    "path": "dataset/dataset_deepfashion_testset.py",
    "chars": 4300,
    "preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
  },
  {
    "path": "dataset/dataset_llff.py",
    "chars": 4839,
    "preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
  },
  {
    "path": "dataset/dataset_mesh.py",
    "chars": 5842,
    "preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
  },
  {
    "path": "dataset/dataset_nerf.py",
    "chars": 3623,
    "preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
  },
  {
    "path": "dataset/dataset_nerf_colmap.py",
    "chars": 3738,
    "preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
  },
  {
    "path": "denoiser/denoiser.py",
    "chars": 1210,
    "preview": "import os\n\nimport torch\nimport numpy as np\nimport math\n\nfrom render import util\nif \"TWOSIDED_TEXTURE\" not in os.environ "
  },
  {
    "path": "eval_gmeshdiffusion_generated_samples.py",
    "chars": 9619,
    "preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
  },
  {
    "path": "geometry/embedding.py",
    "chars": 1200,
    "preview": "import torch\nfrom torch import nn\n\nclass Embedding(nn.Module):\n    def __init__(self, in_channels, N_freqs, logscale=Tru"
  },
  {
    "path": "geometry/flexicubes_table.py",
    "chars": 41249,
    "preview": "# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.\n#\n# NVIDIA CORPORATION & AFFILIATES and its"
  },
  {
    "path": "geometry/gshell_flexicubes.py",
    "chars": 40870,
    "preview": "# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.\n#\n# NVIDIA CORPORATION & AFFILIATES and its"
  },
  {
    "path": "geometry/gshell_flexicubes_geometry.py",
    "chars": 16375,
    "preview": "# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.\n#\n# NVIDIA CORPORATION & AFFILIATES and its"
  },
  {
    "path": "geometry/gshell_tets.py",
    "chars": 31153,
    "preview": "import numpy as np\nimport torch\n\nfrom render import util\n\n##############################################################"
  },
  {
    "path": "geometry/gshell_tets_geometry.py",
    "chars": 17765,
    "preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
  },
  {
    "path": "geometry/mlp.py",
    "chars": 1372,
    "preview": "import torch\nimport torch.nn as nn\nimport numpy as np\n\nfrom .embedding import Embedding\n\nclass MLP(nn.Module):\n    def _"
  },
  {
    "path": "render/light.py",
    "chars": 4278,
    "preview": "# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all "
  },
  {
    "path": "render/material.py",
    "chars": 7200,
    "preview": "# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all "
  },
  {
    "path": "render/mesh.py",
    "chars": 12068,
    "preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
  },
  {
    "path": "render/mlptexture.py",
    "chars": 4650,
    "preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
  },
  {
    "path": "render/obj.py",
    "chars": 7794,
    "preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
  },
  {
    "path": "render/optixutils/__init__.py",
    "chars": 601,
    "preview": "# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all "
  },
  {
    "path": "render/optixutils/c_src/accessor.h",
    "chars": 13669,
    "preview": "// Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n//\n// NVIDIA CORPORATION and its licensors retain a"
  },
  {
    "path": "render/optixutils/c_src/bsdf.h",
    "chars": 11222,
    "preview": "// Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n//\n// NVIDIA CORPORATION and its licensors retain a"
  },
  {
    "path": "render/optixutils/c_src/common.h",
    "chars": 5311,
    "preview": "// Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n//\n// NVIDIA CORPORATION and its licensors retain a"
  },
  {
    "path": "render/optixutils/c_src/denoising.cu",
    "chars": 5103,
    "preview": "// Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n//\n// NVIDIA CORPORATION and its licensors retain a"
  },
  {
    "path": "render/optixutils/c_src/denoising.h",
    "chars": 790,
    "preview": "// Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n//\n// NVIDIA CORPORATION and its licensors retain a"
  },
  {
    "path": "render/optixutils/c_src/envsampling/kernel.cu",
    "chars": 20280,
    "preview": "// Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n//\n// NVIDIA CORPORATION and its licensors retain a"
  },
  {
    "path": "render/optixutils/c_src/envsampling/params.h",
    "chars": 2054,
    "preview": "// Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n//\n// NVIDIA CORPORATION and its licensors retain a"
  },
  {
    "path": "render/optixutils/c_src/math_utils.h",
    "chars": 12168,
    "preview": "// Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n//\n// NVIDIA CORPORATION and its licensors retain a"
  },
  {
    "path": "render/optixutils/c_src/optix_wrapper.cpp",
    "chars": 13359,
    "preview": "// Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n//\n// NVIDIA CORPORATION and its licensors retain a"
  },
  {
    "path": "render/optixutils/c_src/optix_wrapper.h",
    "chars": 1088,
    "preview": "// Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n//\n// NVIDIA CORPORATION and its licensors retain a"
  },
  {
    "path": "render/optixutils/c_src/torch_bindings.cpp",
    "chars": 14563,
    "preview": "// Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n//\n// NVIDIA CORPORATION and its licensors retain a"
  },
  {
    "path": "render/optixutils/include/internal/optix_7_device_impl.h",
    "chars": 73134,
    "preview": "/*\n* Copyright (c) 2021 NVIDIA Corporation.  All rights reserved.\n*\n* NVIDIA Corporation and its licensors retain all in"
  },
  {
    "path": "render/optixutils/include/internal/optix_7_device_impl_exception.h",
    "chars": 15383,
    "preview": "/*\n* Copyright (c) 2021 NVIDIA Corporation.  All rights reserved.\n*\n* NVIDIA Corporation and its licensors retain all in"
  },
  {
    "path": "render/optixutils/include/internal/optix_7_device_impl_transformations.h",
    "chars": 17987,
    "preview": "/*\n* Copyright (c) 2021 NVIDIA Corporation.  All rights reserved.\n*\n* NVIDIA Corporation and its licensors retain all in"
  },
  {
    "path": "render/optixutils/include/optix.h",
    "chars": 1716,
    "preview": "\n/*\n * Copyright (c) 2021 NVIDIA Corporation.  All rights reserved.\n *\n * NVIDIA Corporation and its licensors retain al"
  },
  {
    "path": "render/optixutils/include/optix_7_device.h",
    "chars": 60451,
    "preview": "/*\n* Copyright (c) 2021 NVIDIA Corporation.  All rights reserved.\n*\n* NVIDIA Corporation and its licensors retain all in"
  },
  {
    "path": "render/optixutils/include/optix_7_host.h",
    "chars": 44487,
    "preview": "/*\n * Copyright (c) 2021 NVIDIA Corporation.  All rights reserved.\n *\n * NVIDIA Corporation and its licensors retain all"
  },
  {
    "path": "render/optixutils/include/optix_7_types.h",
    "chars": 70493,
    "preview": "\n/*\n * Copyright (c) 2021 NVIDIA Corporation.  All rights reserved.\n *\n * NVIDIA Corporation and its licensors retain al"
  },
  {
    "path": "render/optixutils/include/optix_denoiser_tiling.h",
    "chars": 13614,
    "preview": "/*\n * Copyright (c) 2021 NVIDIA Corporation.  All rights reserved.\n *\n * Redistribution and use in source and binary for"
  },
  {
    "path": "render/optixutils/include/optix_device.h",
    "chars": 2129,
    "preview": "\n/*\n * Copyright (c) 2021 NVIDIA Corporation.  All rights reserved.\n *\n * NVIDIA Corporation and its licensors retain al"
  },
  {
    "path": "render/optixutils/include/optix_function_table.h",
    "chars": 16836,
    "preview": "/*\n * Copyright (c) 2021 NVIDIA Corporation.  All rights reserved.\n *\n * NVIDIA Corporation and its licensors retain all"
  },
  {
    "path": "render/optixutils/include/optix_function_table_definition.h",
    "chars": 1827,
    "preview": "/*\n * Copyright (c) 2021 NVIDIA Corporation.  All rights reserved.\n *\n * NVIDIA Corporation and its licensors retain all"
  },
  {
    "path": "render/optixutils/include/optix_host.h",
    "chars": 1661,
    "preview": "\n/*\n * Copyright (c) 2021 NVIDIA Corporation.  All rights reserved.\n *\n * NVIDIA Corporation and its licensors retain al"
  },
  {
    "path": "render/optixutils/include/optix_stack_size.h",
    "chars": 17447,
    "preview": "/*\n * Copyright (c) 2021 NVIDIA Corporation.  All rights reserved.\n *\n * Redistribution and use in source and binary for"
  },
  {
    "path": "render/optixutils/include/optix_stubs.h",
    "chars": 28668,
    "preview": "/*\n * Copyright (c) 2021 NVIDIA Corporation.  All rights reserved.\n *\n * Redistribution and use in source and binary for"
  },
  {
    "path": "render/optixutils/include/optix_types.h",
    "chars": 1777,
    "preview": "/*\n * Copyright (c) 2021 NVIDIA Corporation.  All rights reserved.\n *\n * NVIDIA Corporation and its licensors retain all"
  },
  {
    "path": "render/optixutils/ops.py",
    "chars": 7155,
    "preview": "# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all "
  },
  {
    "path": "render/optixutils/tests/filter_test.py",
    "chars": 4260,
    "preview": "# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all "
  },
  {
    "path": "render/regularizer.py",
    "chars": 5832,
    "preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
  },
  {
    "path": "render/render.py",
    "chars": 21371,
    "preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n#\n# NVIDIA CORPORATION, its affiliates a"
  },
  {
    "path": "render/renderutils/__init__.py",
    "chars": 942,
    "preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
  },
  {
    "path": "render/renderutils/bsdf.py",
    "chars": 6266,
    "preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
  },
  {
    "path": "render/renderutils/c_src/bsdf.cu",
    "chars": 26451,
    "preview": "/*\n * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n *\n * NVIDIA CORPORATION, its affili"
  },
  {
    "path": "render/renderutils/c_src/bsdf.h",
    "chars": 1564,
    "preview": "/*\n * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n *\n * NVIDIA CORPORATION, its affili"
  },
  {
    "path": "render/renderutils/c_src/common.cpp",
    "chars": 2285,
    "preview": "/*\n * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n *\n * NVIDIA CORPORATION, its affili"
  },
  {
    "path": "render/renderutils/c_src/common.h",
    "chars": 1218,
    "preview": "/*\n * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n *\n * NVIDIA CORPORATION, its affili"
  },
  {
    "path": "render/renderutils/c_src/cubemap.cu",
    "chars": 13015,
    "preview": "/*\n * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n *\n * NVIDIA CORPORATION, its affili"
  },
  {
    "path": "render/renderutils/c_src/cubemap.h",
    "chars": 907,
    "preview": "/*\n * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n *\n * NVIDIA CORPORATION, its affili"
  },
  {
    "path": "render/renderutils/c_src/loss.cu",
    "chars": 7355,
    "preview": "/*\n * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n *\n * NVIDIA CORPORATION, its affili"
  },
  {
    "path": "render/renderutils/c_src/loss.h",
    "chars": 896,
    "preview": "/*\n * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n *\n * NVIDIA CORPORATION, its affili"
  },
  {
    "path": "render/renderutils/c_src/mesh.cu",
    "chars": 3973,
    "preview": "/*\n * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n *\n * NVIDIA CORPORATION, its affili"
  },
  {
    "path": "render/renderutils/c_src/mesh.h",
    "chars": 696,
    "preview": "/*\n * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n *\n * NVIDIA CORPORATION, its affili"
  },
  {
    "path": "render/renderutils/c_src/normal.cu",
    "chars": 7463,
    "preview": "/*\n * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n *\n * NVIDIA CORPORATION, its affili"
  },
  {
    "path": "render/renderutils/c_src/normal.h",
    "chars": 786,
    "preview": "/*\n * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n *\n * NVIDIA CORPORATION, its affili"
  },
  {
    "path": "render/renderutils/c_src/tensor.h",
    "chars": 4408,
    "preview": "/*\n * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n *\n * NVIDIA CORPORATION, its affili"
  },
  {
    "path": "render/renderutils/c_src/torch_bindings.cpp",
    "chars": 41998,
    "preview": "/*\n * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n *\n * NVIDIA CORPORATION, its affili"
  },
  {
    "path": "render/renderutils/c_src/vec3f.h",
    "chars": 4430,
    "preview": "/*\n * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n *\n * NVIDIA CORPORATION, its affili"
  },
  {
    "path": "render/renderutils/c_src/vec4f.h",
    "chars": 831,
    "preview": "/*\n * Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n *\n * NVIDIA CORPORATION, its affili"
  },
  {
    "path": "render/renderutils/loss.py",
    "chars": 1675,
    "preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
  },
  {
    "path": "render/renderutils/ops.py",
    "chars": 22087,
    "preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
  },
  {
    "path": "render/renderutils/tests/test_bsdf.py",
    "chars": 13752,
    "preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
  },
  {
    "path": "render/renderutils/tests/test_loss.py",
    "chars": 2225,
    "preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
  },
  {
    "path": "render/renderutils/tests/test_mesh.py",
    "chars": 3433,
    "preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
  },
  {
    "path": "render/renderutils/tests/test_perf.py",
    "chars": 2342,
    "preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
  },
  {
    "path": "render/texture.py",
    "chars": 7853,
    "preview": "# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all "
  },
  {
    "path": "render/util.py",
    "chars": 22370,
    "preview": "# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all "
  },
  {
    "path": "train_gflexicubes_deepfashion.py",
    "chars": 35437,
    "preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
  },
  {
    "path": "train_gflexicubes_polycam.py",
    "chars": 32260,
    "preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
  },
  {
    "path": "train_gshelltet_deepfashion.py",
    "chars": 34461,
    "preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
  },
  {
    "path": "train_gshelltet_polycam.py",
    "chars": 32617,
    "preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
  },
  {
    "path": "train_gshelltet_synthetic.py",
    "chars": 33225,
    "preview": "# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n#\n# NVIDIA CORPORATION, its affiliates "
  }
]
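The condensed listing above is plain JSON, so it can be loaded and filtered with the standard library. The field names (path, chars, preview) match the entries printed above; the two-entry sample below is illustrative, not the full listing.

```python
import json

# Parse a small sample in the same shape as the condensed file listing
# above: a JSON array of {"path", "chars", "preview"} objects.
listing = json.loads("""
[
  {"path": "README.md", "chars": 7662, "preview": "<div align=..."},
  {"path": "render/util.py", "chars": 22370, "preview": "# Copyright ..."}
]
""")

# Filter to Python sources and total up the character counts.
python_files = [entry["path"] for entry in listing if entry["path"].endswith(".py")]
total_chars = sum(entry["chars"] for entry in listing)
```

The same two lines applied to the downloaded .json file recover per-extension breakdowns or size totals for the whole repository.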

About this extraction

This page contains the full source code of the lzzcd001/GShell GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 124 files (1.2 MB), approximately 348.0k tokens, and a symbol index with 975 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.
