Repository: 1zb/GeomDist
Branch: master
Commit: bf593d307676
Files: 13
Total size: 54.2 KB
Directory structure:
gitextract_qju22int/
├── .gitignore
├── README.md
├── engine.py
├── eval.py
├── infer.py
├── inverese.py
├── main.py
├── models.py
├── normalize.py
├── points.py
└── util/
├── lr_decay.py
├── lr_sched.py
└── misc.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
*.ipynb
*.ply
.ipynb_checkpoints
test*.*
*.stl
vis.py
*.obj
__pycache__
*_dataset
output/
backup.py
shapes/
samples/
================================================
FILE: README.md
================================================
# Geometry Distributions
### [Project Page](https://1zb.github.io/GeomDist/) | [Paper (arXiv)](https://arxiv.org/abs/2411.16076)
### :bullettrain_front: Training
```
torchrun --nproc_per_node=4 main.py --blr 5e-7 --output_dir output/loong --log_dir output/loong --data_path shapes/loong.obj
```
### :balloon: Inference
```
python infer.py --pth output/loong/checkpoint-999.pth --target Gaussian --num-steps 64 --output samples/loong --N 10000000
```
### :floppy_disk: Datasets
https://huggingface.co/datasets/Zbalpha/shapes
### :briefcase: Checkpoints
https://huggingface.co/Zbalpha/geom_dist_ckpt
## :e-mail: Contact
Contact [Biao Zhang](mailto:biao.zhang@kaust.edu.sa) ([@1zb](https://github.com/1zb)) if you have any further questions. This repository is for academic research use only.
## :blue_book: Citation
arXiv preprint:
```bibtex
@article{zhang2024geometry,
title={Geometry Distributions},
author={Zhang, Biao and Ren, Jing and Wonka, Peter},
journal={arXiv preprint arXiv:2411.16076},
year={2024}
}
```
ICCV version:
```bibtex
@InProceedings{Zhang_2025_ICCV,
author = {Zhang, Biao and Ren, Jing and Wonka, Peter},
title = {Geometry Distributions},
booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
month = {October},
year = {2025},
pages = {1495-1505}
}
```
================================================
FILE: engine.py
================================================
# --------------------------------------------------------
# References:
# MAE: https://github.com/facebookresearch/mae
# DeiT: https://github.com/facebookresearch/deit
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
# --------------------------------------------------------
import math
import sys
from typing import Iterable
import torch
import torch.nn.functional as F
import numpy as np
import util.misc as misc
import util.lr_sched as lr_sched
from torch.autograd import Variable
from math import exp
from einops import rearrange, repeat
import trimesh
from PIL import Image
def train_one_epoch(model: torch.nn.Module,
                    data_loader, optimizer: torch.optim.Optimizer,
                    criterion,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                    log_writer=None, args=None):
    """Run one training epoch and return averaged metrics.

    ``data_loader`` is either a normal iterable of point batches, or a dict
    describing an on-the-fly surface sampler with keys ``'obj_file'``,
    ``'batch_size'``, ``'epoch_size'``, ``'texture_path'``, ``'noise_mesh'``
    (and ``'primitive'`` when ``'obj_file'`` is None).  In the dict case a
    large pool of surface points is sampled once up front, and every step
    draws a random subset of ``batch_size`` points from it.

    Returns a dict mapping meter name (e.g. 'loss', 'lr') to its global
    average over the epoch.
    """
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 20
    accum_iter = args.accum_iter
    optimizer.zero_grad()
    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))
    print(data_loader)
    noise = None
    if isinstance(data_loader, dict):
        obj_file = data_loader['obj_file']
        batch_size = data_loader['batch_size']
        if obj_file is not None:
            if obj_file.endswith('.obj'):
                mesh = trimesh.load(obj_file)
                if data_loader['texture_path'] is not None:
                    # Textured mesh: sample surface points together with per-point colors.
                    img = Image.open(data_loader['texture_path'])
                    material = trimesh.visual.texture.SimpleMaterial(image=img)
                    assert mesh.visual.uv is not None
                    texture = trimesh.visual.TextureVisuals(mesh.visual.uv, image=img, material=material)
                    mesh = trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces, visual=texture, process=False)
                    samples, _, colors = trimesh.sample.sample_surface(mesh, 2048*64*4*64, sample_color=True)
                    colors = colors[:, :3] # remove alpha
                    # map 0..255 colors to zero-mean, ~unit-variance values
                    colors = (colors.astype(np.float32) / 255.0 - 0.5) / np.sqrt(1/12) # [-1, 1]
                    samples = np.concatenate([samples, colors], axis=1)
                else:
                    samples, _ = trimesh.sample.sample_surface(mesh, 2048*64*4*64)
            else:
                # non-.obj input (e.g. .ply point cloud): use vertices directly
                samples = trimesh.load(obj_file).vertices
        else:
            # Analytic primitives; each is scaled to (approximately) unit variance.
            if data_loader['primitive'] == 'sphere':
                n = torch.randn(2048*64*4*64, 3)
                n = torch.nn.functional.normalize(n, dim=1)
                samples = n / np.sqrt(1/3)
                samples = samples.numpy()
            elif data_loader['primitive'] == 'plane':
                samples = torch.rand(2048*64*4*64, 3) - 0.5
                samples[:, 2] = 0  # flatten onto the z = 0 plane
                samples = (samples - 0) / np.sqrt(2/9*2*0.5**3)
                samples = samples.numpy()
            elif data_loader['primitive'] == 'volume':
                samples = (torch.rand(2048*64*4*64, 3) - 0.5) / np.sqrt(1/12)
                samples = samples.numpy()
            elif data_loader['primitive'] == 'gaussian':
                samples = np.random.randn(2048*64*4*64, 3).astype(np.float32)
            else:
                raise NotImplementedError
        if data_loader['noise_mesh'] is not None:
            # Optionally draw the diffusion "noise" end of the mapping from a mesh surface.
            noise, _ = trimesh.sample.sample_surface(trimesh.load(data_loader['noise_mesh']), 2048*64*4*64)
        else:
            noise = None
        samples = samples.astype(np.float32)# - 0.12
        # Replace the dict with a step counter; batches are drawn from `samples`.
        data_loader = range(data_loader['epoch_size'])
    for data_iter_step, batch in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        # we use a per iteration (instead of per epoch) lr scheduler
        if data_iter_step % accum_iter == 0:
            lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
        if isinstance(batch, int):
            # dict-style loader: `batch` is only a step index; subsample the pool
            ind = np.random.default_rng().choice(samples.shape[0], batch_size, replace=True)
            xyz = samples[ind]
            xyz = torch.from_numpy(xyz).float().to(device, non_blocking=True)
        else:
            xyz = batch.to(device, non_blocking=True)
        with torch.cuda.amp.autocast(enabled=False):
            if noise is not None:
                ind = np.random.default_rng().choice(noise.shape[0], batch_size, replace=True)
                init_noise = noise[ind]
                init_noise = torch.from_numpy(init_noise).float().to(device, non_blocking=True)
            else:
                init_noise = None
            loss = criterion(model, xyz, init_noise=init_noise)
        loss_value = loss.item()
        if not math.isfinite(loss_value):
            # abort on NaN/inf loss rather than corrupting the checkpoint
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)
        loss /= accum_iter
        # loss_scaler backpropagates and steps the optimizer every accum_iter iterations
        loss_scaler(loss, optimizer, clip_grad=max_norm,
                    parameters=model.parameters(), create_graph=False,
                    update_grad=(data_iter_step + 1) % accum_iter == 0)
        if (data_iter_step + 1) % accum_iter == 0:
            optimizer.zero_grad()
        torch.cuda.synchronize()
        metric_logger.update(loss=loss_value)
        min_lr = 10.
        max_lr = 0.
        for group in optimizer.param_groups:
            min_lr = min(min_lr, group["lr"])
            max_lr = max(max_lr, group["lr"])
        metric_logger.update(lr=max_lr)
        loss_value_reduce = misc.all_reduce_mean(loss_value)
        if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
            """ We use epoch_1000x as the x-axis in tensorboard.
            This calibrates different curves when batch size changes.
            """
            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
            log_writer.add_scalar('loss', loss_value_reduce, epoch_1000x)
            log_writer.add_scalar('lr', max_lr, epoch_1000x)
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
================================================
FILE: eval.py
================================================
"""Compute the symmetric (L1) Chamfer distance between a predicted point
cloud and a reference one, after rescaling both by a stored scale factor.

Usage:
    python eval.py --ply pred.ply --reference ref.ply --scale scale.npy
"""
import trimesh
from scipy.spatial import cKDTree as KDTree
import numpy as np
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--ply', required=True, type=str)        # predicted point cloud (.ply)
parser.add_argument('--reference', required=True, type=str)  # ground-truth point cloud
parser.add_argument('--scale', required=True, type=str)      # path to a numpy scale file
args = parser.parse_args()

# Undo the normalization so distances are measured in the original units.
scale = np.load(args.scale)
prediction = trimesh.load(args.ply).vertices * scale
reference = trimesh.load(args.reference).vertices * scale

# Nearest-neighbor distances: reference -> prediction.
tree = KDTree(prediction)
dist, _ = tree.query(reference)
d1 = dist
gt_to_gen_chamfer = np.mean(dist)
gt_to_gen_chamfer_sq = np.mean(np.square(dist))

# Nearest-neighbor distances: prediction -> reference.
tree = KDTree(reference)
dist, _ = tree.query(prediction)
d2 = dist
gen_to_gt_chamfer = np.mean(dist)
gen_to_gt_chamfer_sq = np.mean(np.square(dist))

# Symmetric Chamfer distance (sum of both mean distances).
cd = gt_to_gen_chamfer + gen_to_gt_chamfer
print(cd)
================================================
FILE: infer.py
================================================
"""Sample a point cloud from a trained geometry-distribution model.

Loads an EDMPrecond checkpoint, draws N starting points from the chosen
prior ('Gaussian', 'Uniform', 'Sphere', or a 'Mesh' surface), runs the EDM
sampler for --num-steps steps, and writes sample.ply (plus per-step clouds
when --intermediate is set) into --output.
"""
import argparse
from pathlib import Path
import os

import torch
import trimesh

from models import EDMPrecond

torch.manual_seed(0)
import numpy as np
np.random.seed(0)
import random
random.seed(0)

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True
# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

parser = argparse.ArgumentParser('Inference', add_help=False)
parser.add_argument('--pth', required=True, type=str)          # checkpoint path
parser.add_argument('--texture', action='store_true')          # model also predicts RGB
parser.add_argument('--target', default='Gaussian', type=str)  # prior distribution
parser.add_argument('--N', default=1000000, type=int)          # number of points to sample
parser.add_argument('--num-steps', default=64, type=int)       # sampler steps
parser.add_argument('--noise_mesh', default=None, type=str)    # mesh prior (when --target Mesh)
parser.add_argument('--output', required=True, type=str)
parser.add_argument('--intermediate', action='store_true')     # also dump per-step clouds
parser.add_argument('--depth', default=6, type=int)
parser.set_defaults(texture=False)
parser.set_defaults(intermediate=False)
args = parser.parse_args()

Path(args.output).mkdir(parents=True, exist_ok=True)

# 6 channels = xyz + rgb when the model was trained with a texture.
if args.texture:
    model = EDMPrecond(channels=6, depth=args.depth).cuda()
else:
    model = EDMPrecond(depth=args.depth).cuda()
model.load_state_dict(torch.load(args.pth, map_location='cpu')['model'], strict=True)

# Draw N starting points from the chosen prior; each case is scaled to
# (approximately) unit variance to match training.
if args.target == 'Gaussian':
    noise = torch.randn(args.N, 3).cuda()
elif args.target == 'Uniform':
    noise = (torch.rand(args.N, 3).cuda() - 0.5) / np.sqrt(1/12)
elif args.target == 'Sphere':
    n = torch.randn(args.N, 3).cuda()
    n = torch.nn.functional.normalize(n, dim=1)
    noise = n / np.sqrt(1/3)
elif args.target == 'Mesh':
    assert args.noise_mesh is not None
    noise, _ = trimesh.sample.sample_surface(trimesh.load(args.noise_mesh), args.N)
    noise = torch.from_numpy(noise).float().cuda()
else:
    raise NotImplementedError

if args.texture:
    # random colors as the starting point for the rgb channels
    color = (torch.rand(args.N, 3).cuda() - 0.5) / np.sqrt(1/12)
    noise = torch.cat([noise, color], dim=1)

sample, intermediate_steps = model.sample(batch_seeds=noise, num_steps=args.num_steps)

if args.texture:
    sample = sample.detach().cpu().numpy()
    vertices, colors = sample[:, :3], sample[:, 3:]
    # undo color normalization back to the 0..255 range
    colors = (colors * np.sqrt(1/12) + 0.5) * 255.0
    colors = np.concatenate([colors, np.ones_like(colors[:, 0:1]) * 255.0], axis=1).astype(np.uint8) # alpha channel
    trimesh.PointCloud(vertices, colors).export(os.path.join(args.output, 'sample.ply'))
    if args.intermediate:
        for i, s in enumerate(intermediate_steps):
            vertices, colors = s[:, :3], s[:, 3:]
            colors = (colors * np.sqrt(1/12) + 0.5) * 255.0
            colors = np.concatenate([colors, np.ones_like(colors[:, 0:1]) * 255.0], axis=1).astype(np.uint8) # alpha channel
            trimesh.PointCloud(vertices, colors).export(os.path.join(args.output, 'sample-{:03d}.ply'.format(i)))
else:
    trimesh.PointCloud(sample.detach().cpu().numpy()).export(os.path.join(args.output, 'sample.ply'))
    if args.intermediate:
        for i, s in enumerate(intermediate_steps):
            trimesh.PointCloud(s).export(os.path.join(args.output, 'sample-{:03d}.ply'.format(i)))
================================================
FILE: inverese.py
================================================
"""Inverse pass: map surface samples of a mesh back toward the prior.

Loads a trained EDMPrecond checkpoint, samples N points from the surface of
--data_path, and runs the sampler in reverse (model.inverse), exporting the
result(s) as .ply files in the working directory.
"""
import argparse

import torch
import trimesh

from models import EDMPrecond

torch.manual_seed(0)
import numpy as np
np.random.seed(0)
import random
random.seed(0)

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True
# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

parser = argparse.ArgumentParser('Inference', add_help=False)
parser.add_argument('--pth', default='output/lamp_cube/checkpoint-0.pth', type=str)  # checkpoint path
parser.add_argument('--texture', action='store_true')
parser.add_argument('--N', default=1000000, type=int)          # number of surface samples
parser.add_argument('--num-steps', default=64, type=int)       # sampler steps
parser.add_argument('--noise_mesh', default=None, type=str)
parser.add_argument('--data_path', default='shapes/Jellyfish_lamp_part_A__B_normalized.obj', type=str)
parser.set_defaults(texture=False)
args = parser.parse_args()

if args.texture:
    model = EDMPrecond(channels=6).cuda()
else:
    model = EDMPrecond().cuda()

# Sample the input shape's surface; these are the "data" points to invert.
mesh = trimesh.load(args.data_path)
samples, _ = trimesh.sample.sample_surface(mesh, args.N)
samples = samples.astype(np.float32)
samples = torch.from_numpy(samples).float().cuda()

model.load_state_dict(torch.load(args.pth, map_location='cpu')['model'], strict=True)
# NOTE(review): inverse_edm_sampler keeps its snapshot list disabled
# (outputs = None), so intermediate_steps may be None here and the loops
# below would fail — confirm before relying on the per-step exports.
sample, intermediate_steps = model.inverse(samples=samples, num_steps=args.num_steps)

if args.texture:
    sample = sample.detach().cpu().numpy()
    vertices, colors = sample[:, :3], sample[:, 3:]
    # undo color normalization back to the 0..255 range
    colors = (colors * np.sqrt(1/12) + 0.5) * 255.0
    colors = np.concatenate([colors, np.ones_like(colors[:, 0:1]) * 255.0], axis=1).astype(np.uint8) # alpha channel
    trimesh.PointCloud(vertices, colors).export('sample.ply')
    for i, s in enumerate(intermediate_steps):
        vertices, colors = s[:, :3], s[:, 3:]
        colors = (colors * np.sqrt(1/12) + 0.5) * 255.0
        colors = np.concatenate([colors, np.ones_like(colors[:, 0:1]) * 255.0], axis=1).astype(np.uint8) # alpha channel
        trimesh.PointCloud(vertices, colors).export('sample-{:03d}.ply'.format(i))
else:
    trimesh.PointCloud(sample.detach().cpu().numpy()).export('sample.ply')
    for i, s in enumerate(intermediate_steps):
        trimesh.PointCloud(s).export('sample-{:03d}.ply'.format(i))
================================================
FILE: main.py
================================================
import argparse
import datetime
import json
import numpy as np
import os
import time
from pathlib import Path
import torch
import torch.backends.cudnn as cudnn
from torch.utils.tensorboard import SummaryWriter
torch.set_num_threads(8)
import util.lr_decay as lrd
import util.misc as misc
from util.misc import NativeScalerWithGradNormCount as NativeScaler
import models as models
from models import EDMLoss
from engine import train_one_epoch
from points import Points
def get_args_parser():
    """Build the CLI argument parser for training.

    ``add_help`` is disabled so this parser can be composed with others.
    Returns the configured ``argparse.ArgumentParser``.
    """
    p = argparse.ArgumentParser('Train', add_help=False)

    # --- training schedule --------------------------------------------------
    p.add_argument('--batch_size', default=2048*64*2, type=int,
                   help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus')
    p.add_argument('--epochs', default=1000, type=int)
    p.add_argument('--accum_iter', default=1, type=int,
                   help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')

    # --- model --------------------------------------------------------------
    p.add_argument('--model', default='EDMPrecond', type=str, metavar='MODEL',
                   help='Name of model to train')
    p.add_argument('--depth', default=6, type=int, metavar='MODEL')

    # --- optimizer ----------------------------------------------------------
    p.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
                   help='Clip gradient norm (default: None, no clipping)')
    p.add_argument('--weight_decay', type=float, default=0.05,
                   help='weight decay (default: 0.05)')
    p.add_argument('--lr', type=float, default=None, metavar='LR',
                   help='learning rate (absolute lr)')
    p.add_argument('--blr', type=float, default=5e-7, metavar='LR',
                   help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
    p.add_argument('--layer_decay', type=float, default=0.75,
                   help='layer-wise lr decay from ELECTRA/BEiT')
    p.add_argument('--min_lr', type=float, default=5e-7, metavar='LR',
                   help='lower lr bound for cyclic schedulers that hit 0')
    p.add_argument('--warmup_epochs', type=int, default=1, metavar='N',
                   help='epochs to warmup LR')

    # --- dataset ------------------------------------------------------------
    p.add_argument('--target', default='Gaussian', type=str, )
    p.add_argument('--data_path', default='shapes/Jellyfish_lamp_part_A__B_normalized.obj', type=str,
                   help='dataset path')
    p.add_argument('--texture_path', default=None, type=str,
                   help='dataset path')
    p.add_argument('--noise_mesh', default=None, type=str,
                   help='dataset path')

    # --- bookkeeping / runtime ----------------------------------------------
    p.add_argument('--output_dir', default='./output/',
                   help='path where to save, empty for no saving')
    p.add_argument('--log_dir', default='./output/',
                   help='path where to tensorboard log')
    p.add_argument('--device', default='cuda',
                   help='device to use for training / testing')
    p.add_argument('--seed', default=0, type=int)
    p.add_argument('--resume', default='',
                   help='resume from checkpoint')
    p.add_argument('--start_epoch', default=0, type=int, metavar='N',
                   help='start epoch')
    p.add_argument('--eval', action='store_true',
                   help='Perform evaluation only')
    p.add_argument('--dist_eval', action='store_true', default=False,
                   help='Enabling distributed evaluation (recommended during training for faster monitor')
    p.add_argument('--num_workers', default=32, type=int)
    p.add_argument('--pin_mem', action='store_true',
                   help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    p.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
    p.set_defaults(pin_mem=True)

    # --- distributed training -----------------------------------------------
    p.add_argument('--world_size', default=1, type=int,
                   help='number of distributed processes')
    p.add_argument('--local_rank', default=-1, type=int)
    p.add_argument('--dist_on_itp', action='store_true')
    p.add_argument('--dist_url', default='env://',
                   help='url used to set up distributed training')

    return p
def main(args):
    """Entry point: set up (optionally distributed) training, build the
    point-sampling "data loader" dict, then run the epoch loop with periodic
    checkpointing (every 5 epochs) and tensorboard/json logging."""
    # os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
    misc.init_distributed_mode(args)
    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
    print("{}".format(args).replace(', ', ',\n'))
    device = torch.device(args.device)
    # fix the seed for reproducibility (different per rank)
    seed = args.seed + misc.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = True
    cudnn.deterministic=True
    # The flag below controls whether to allow TF32 on matmul. This flag defaults to False
    # in PyTorch 1.12 and later.
    torch.backends.cuda.matmul.allow_tf32 = True
    # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
    torch.backends.cudnn.allow_tf32 = True
    if True:
        num_tasks = misc.get_world_size()
        global_rank = misc.get_rank()
        neural_rendering_resolution = 128
        # The "data loader" is a plain dict telling train_one_epoch how to
        # sample points: from a mesh/point-cloud file, or from an analytic
        # primitive named in --data_path ('sphere' / 'plane' / 'volume').
        if args.data_path.endswith('.obj') or args.data_path.endswith('.ply'):
            data_loader_train = {
                'obj_file': args.data_path,
                'batch_size': args.batch_size,
                'epoch_size': 512,
                'texture_path': args.texture_path,
            }
            if args.noise_mesh is not None:
                data_loader_train['noise_mesh'] = args.noise_mesh
            else:
                data_loader_train['noise_mesh'] = None
        elif 'sphere' in args.data_path or 'plane' in args.data_path or 'volume' in args.data_path:
            data_loader_train = {
                'obj_file': None,
                'primitive': args.data_path,
                'batch_size': args.batch_size,
                'epoch_size': 512,
                'texture_path': args.texture_path,
            }
            if args.noise_mesh is not None:
                data_loader_train['noise_mesh'] = args.noise_mesh
            else:
                data_loader_train['noise_mesh'] = None
        else:
            raise NotImplementedError
    print(data_loader_train)
    # tensorboard writer only on the main process
    if global_rank == 0 and args.log_dir is not None and not args.eval:
        os.makedirs(args.log_dir, exist_ok=True)
        log_writer = SummaryWriter(log_dir=args.log_dir)
    else:
        log_writer = None
    criterion = EDMLoss(dist=args.target)
    # 6 channels (xyz + rgb) when a texture is supplied, otherwise plain xyz
    model = models.__dict__[args.model](channels=3 if args.texture_path is None else 6, depth=args.depth)
    model.to(device)
    model_without_ddp = model
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("Model = %s" % str(model_without_ddp))
    print('number of params (M): %.2f' % (n_parameters / 1.e6))
    eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
    if args.lr is None: # only base_lr is specified
        args.lr = args.blr * eff_batch_size / 128
    print("base lr: %.2e" % (args.lr * 128 / eff_batch_size))
    print("actual lr: %.2e" % args.lr)
    print("accumulate grad iterations: %d" % args.accum_iter)
    print("effective batch size: %d" % eff_batch_size)
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False)
        model_without_ddp = model.module
    optimizer = torch.optim.AdamW(model_without_ddp.parameters(), lr=args.lr)
    loss_scaler = NativeScaler()
    # resume model/optimizer/scaler state when args.resume is set
    misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    max_iou = 0.0
    for epoch in range(args.start_epoch, args.epochs):
        # if args.distributed and args.data_path.endswith('.ply'):
        #     data_loader_train.sampler.set_epoch(epoch)
        train_stats = train_one_epoch(
            model, data_loader_train,
            optimizer, criterion, device, epoch, loss_scaler,
            args.clip_grad,
            log_writer=log_writer,
            args=args
        )
        # checkpoint every 5 epochs and at the very end
        if args.output_dir and (epoch % 5 == 0 or epoch + 1 == args.epochs):
            misc.save_model(
                args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,
                loss_scaler=loss_scaler, epoch=epoch)
        if epoch % 1 == 0 or epoch + 1 == args.epochs:
            log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                         # **{f'test_{k}': v for k, v in test_stats.items()},
                         'epoch': epoch,
                         'n_parameters': n_parameters}
        else:
            log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                         'epoch': epoch,
                         'n_parameters': n_parameters}
        if args.output_dir and misc.is_main_process():
            if log_writer is not None:
                log_writer.flush()
            with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
                f.write(json.dumps(log_stats) + "\n")
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
if __name__ == '__main__':
    # Parse CLI arguments, ensure the output directory exists, then train.
    args = get_args_parser()
    args = args.parse_args()
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args)
================================================
FILE: models.py
================================================
import torch
import torch.nn as nn
import math
import numpy as np
import torch.nn.functional
import trimesh
def modulate(x, shift, scale):
    """Affine (scale-and-shift) modulation: returns ``x * (1 + scale) + shift``."""
    scaled = x * (scale + 1)
    return scaled + shift
class TimestepEmbedder(nn.Module):
    """Maps scalar diffusion timesteps to ``hidden_size``-dim vectors.

    Sinusoidal features of the timestep are fed through a two-layer SiLU MLP.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        # Keep the Sequential layout so checkpoint state_dict keys match.
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.

        :param t: a 1-D Tensor of N indices (may be fractional), one per batch element.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, dim) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half_dim = dim // 2
        exponents = torch.arange(start=0, end=half_dim, dtype=torch.float32) / half_dim
        freqs = torch.exp(-math.log(max_period) * exponents).to(device=t.device)
        phases = t[:, None].float() * freqs[None]
        emb = torch.cat([torch.cos(phases), torch.sin(phases)], dim=-1)
        if dim % 2:
            # odd output dim: pad with one zero column
            emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
        return emb

    def forward(self, t):
        feats = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(feats)
class MPFourier(torch.nn.Module):
    """Random Fourier feature embedding with a magnitude-preserving sqrt(2)
    gain (EDM2-style). Frequencies and phases are fixed random buffers drawn
    at construction time."""

    def __init__(self, num_channels, bandwidth=1):
        super().__init__()
        # Buffer names and RNG draw order must stay identical for checkpoints.
        self.register_buffer('freqs', 2 * np.pi * torch.randn(num_channels) * bandwidth)
        self.register_buffer('phases', 2 * np.pi * torch.rand(num_channels))

    def forward(self, x):
        # Outer product of inputs with frequencies, then phase-shifted cosine.
        v = x.to(torch.float32)
        v = v.ger(self.freqs.to(torch.float32))
        v = v + self.phases.to(torch.float32)
        v = v.cos() * np.sqrt(2)  # sqrt(2) keeps unit variance
        return v.to(x.dtype)
def normalize(x, dim=None, eps=1e-4):
    """Scale ``x`` to roughly unit magnitude along ``dim``.

    When ``dim`` is None, all dimensions except the first (batch) are reduced.
    The norm is rescaled by sqrt(norm.numel() / x.numel()) so the result has
    ~unit RMS per element, and ``eps`` is added for numerical stability.
    """
    reduce_dims = list(range(1, x.ndim)) if dim is None else dim
    magnitude = torch.linalg.vector_norm(x, dim=reduce_dims, keepdim=True, dtype=torch.float32)
    magnitude = eps + magnitude * np.sqrt(magnitude.numel() / x.numel())
    return x / magnitude.to(x.dtype)
def mp_silu(x):
    """Magnitude-preserving SiLU: plain SiLU divided by 0.596 so that a
    unit-variance input keeps approximately unit variance."""
    activated = torch.nn.functional.silu(x)
    return activated / 0.596
def mp_sum(a, b, t=0.5):
    """Magnitude-preserving interpolation of ``a`` and ``b``.

    Computes lerp(a, b, t) and divides by sqrt((1-t)^2 + t^2) so that the
    variance of the mix matches the inputs' (assuming they are independent
    with equal variance).
    """
    gain = np.sqrt((1 - t) ** 2 + t ** 2)
    return a.lerp(b, t) / gain
class MPConv(torch.nn.Module):
    """Magnitude-preserving linear / conv layer (EDM2-style).

    With ``kernel=[]`` the weight is 2-D and the layer acts as a plain linear
    map; with a 2-D kernel it acts as a same-padded conv2d. The weight is
    normalized on every forward pass and scaled by gain / sqrt(fan_in).
    """

    def __init__(self, in_channels, out_channels, kernel):
        super().__init__()
        self.out_channels = out_channels
        self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, *kernel))

    def forward(self, x, gain=1):
        w32 = self.weight.to(torch.float32)
        if self.training:
            # Forced weight normalization: keep the stored parameter itself
            # normalized so optimizer updates cannot grow its magnitude.
            with torch.no_grad():
                self.weight.copy_(normalize(w32))
        # Traditional weight normalization + magnitude-preserving fan-in scaling.
        scaled = normalize(w32) * (gain / np.sqrt(w32[0].numel()))
        scaled = scaled.to(x.dtype)
        if scaled.ndim == 2:
            # kernel=[] case: behave as a linear layer
            return x @ scaled.t()
        assert scaled.ndim == 4
        return torch.nn.functional.conv2d(x, scaled, padding=(scaled.shape[-1]//2,))
class PointEmbed(nn.Module):
    """Sinusoidal positional embedding for 3-D points, projected to ``dim``.

    ``hidden_dim`` is the total number of sin/cos features and must be
    divisible by 6 (sin + cos for each of the 3 coordinate axes).  Extra
    per-point channels (e.g. colors) beyond xyz are passed through and
    concatenated before the final MPConv projection (``other_dim`` of them).
    """
    def __init__(self, hidden_dim=48, dim=128, other_dim=0):
        super().__init__()
        assert hidden_dim % 6 == 0
        self.embedding_dim = hidden_dim
        # Power-of-two frequencies 2^k * pi, arranged block-diagonally so each
        # coordinate axis is projected onto its own set of frequencies.
        e = torch.pow(2, torch.arange(self.embedding_dim // 6)).float() * np.pi
        e = torch.stack([
            torch.cat([e, torch.zeros(self.embedding_dim // 6),
                       torch.zeros(self.embedding_dim // 6)]),
            torch.cat([torch.zeros(self.embedding_dim // 6), e,
                       torch.zeros(self.embedding_dim // 6)]),
            torch.cat([torch.zeros(self.embedding_dim // 6),
                       torch.zeros(self.embedding_dim // 6), e]),
        ])
        self.register_buffer('basis', e) # 3 x (embedding_dim // 2)
        # self.mlp = nn.Linear(self.embedding_dim+3, dim)/
        self.mlp = MPConv(self.embedding_dim+3+other_dim, dim, kernel=[])
    @staticmethod
    def embed(input, basis):
        # Project points onto the fixed frequency basis, then take sin and cos.
        projections = torch.einsum('nd,de->ne', input, basis)
        embeddings = torch.cat([projections.sin(), projections.cos()], dim=1)
        return embeddings
    def forward(self, input):
        # input: N x 3, or N x (3 + other_dim) when extra channels are present
        if input.shape[1] != 3:
            input, others = input[:, :3], input[:, 3:]
        else:
            others = None
        if others is None:
            embed = self.mlp(torch.cat([self.embed(input, self.basis), input], dim=1)) # N x C
        else:
            # extra channels skip the sinusoidal encoding and are appended raw
            embed = self.mlp(torch.cat([self.embed(input, self.basis), input, others], dim=1))
        return embed
class Network(nn.Module):
    """Magnitude-preserving MLP denoiser.

    Embeds each point (plus optional extra channels) with PointEmbed and the
    noise level with MPFourier, then applies ``depth`` residual blocks whose
    activations are modulated by the noise embedding, followed by a final
    modulated projection back to ``channels`` outputs.
    """
    def __init__(
        self,
        channels = 3,
        hidden_size = 256,
        depth = 6,
    ):
        super().__init__()
        self.emb_fourier = MPFourier(hidden_size)
        self.emb_noise = MPConv(hidden_size, hidden_size, kernel=[])
        self.x_embedder = PointEmbed(dim=hidden_size, other_dim=channels-3)
        # one learnable scalar gain per residual block, initialized to zero
        self.gains = nn.ParameterList([
            torch.nn.Parameter(torch.zeros([])) for _ in range(depth)
        ])
        ##
        # each block: [pre-projection, post-projection, noise-conditioning linear]
        self.layers = nn.ModuleList([
            nn.ModuleList([
                MPConv(hidden_size, hidden_size, []),
                MPConv(hidden_size, hidden_size, []),
                MPConv(hidden_size, 1 * hidden_size, []),
            ]) for _ in range(depth)
        ])
        self.final_emb_gain = torch.nn.Parameter(torch.zeros([]))
        self.final_out_gain = torch.nn.Parameter(torch.zeros([]))
        # final block: [pre-projection, output projection, noise-conditioning linear]
        self.final_layer = nn.ModuleList([
            MPConv(hidden_size, hidden_size, []),
            MPConv(hidden_size, channels, []),
            MPConv(hidden_size, hidden_size, []),
        ])
        # lerp weight between the residual branch and the skip path in mp_sum
        self.res_balance = 0.3
    def forward(self, x, t):
        # x: N x channels points; t: per-point noise levels (or a single one)
        x = self.x_embedder(x)
        if t.shape[0] == 1:
            # broadcast a single noise level across the whole batch
            t = t.repeat(x.shape[0])
        t = mp_silu(self.emb_noise(self.emb_fourier(t)))
        for (x_proj_pre, x_proj_post, emb_linear), emb_gain in zip(self.layers, self.gains):
            # noise-conditioned multiplicative modulation (zero-init gain => c starts at 1)
            c = emb_linear(t, gain=emb_gain) + 1
            x = normalize(x)
            y = x_proj_pre(mp_silu(x))
            y = mp_silu(y * c.to(y.dtype))
            y = x_proj_post(y)
            # magnitude-preserving residual connection
            x = mp_sum(x, y, t=self.res_balance)
        x_proj_pre, x_proj_post, emb_linear = self.final_layer
        c = emb_linear(t, gain=self.final_emb_gain) + 1
        y = x_proj_pre(mp_silu(normalize(x)))
        y = mp_silu(y * c.to(y.dtype))
        out = x_proj_post(y, gain=self.final_out_gain)
        return out
class EDMPrecond(torch.nn.Module):
    """EDM preconditioning wrapper (Karras et al. 2022) around ``Network``.

    ``forward(x, sigma)`` returns the denoised prediction D(x; sigma) using
    the standard c_skip / c_out / c_in / c_noise preconditioning with
    sigma_data = 1.
    """
    def __init__(self,
        channels = 3,
        use_fp16 = False,
        sigma_min = 0,
        sigma_max = float('inf'),
        sigma_data = 1,
        depth = 6,
    ):
        super().__init__()
        self.use_fp16 = use_fp16
        self.sigma_min = sigma_min       # lowest supported noise level
        self.sigma_max = sigma_max       # highest supported noise level
        self.sigma_data = sigma_data     # assumed std of the data distribution
        self.model = Network(channels=channels, hidden_size=512, depth=depth)
    def forward(self, x, sigma, force_fp32=False, **model_kwargs):
        # NOTE(review): edm_sampler calls net(x, t, class_labels); that third
        # positional argument lands in this force_fp32 slot — it is always
        # None there, so behavior is unchanged, but worth confirming.
        x = x.to(torch.float32)
        sigma = sigma.to(torch.float32).reshape(-1, 1)
        dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32
        # EDM preconditioning coefficients.
        c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
        c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2).sqrt()
        c_in = 1 / (self.sigma_data ** 2 + sigma ** 2).sqrt()
        c_noise = sigma.log() / 4
        F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), **model_kwargs)
        assert F_x.dtype == dtype
        D_x = c_skip * x + c_out * F_x.to(torch.float32)
        return D_x
    def round_sigma(self, sigma):
        # Hook for quantizing sigma; identity here (just tensor conversion).
        return torch.as_tensor(sigma)
    @torch.no_grad()
    def sample(self, cond=None, batch_seeds=None, channels=3, num_steps=18):
        """Run the EDM sampler starting from prior points ``batch_seeds``.
        Returns (final points, list of per-step numpy snapshots)."""
        device = batch_seeds.device
        batch_size = batch_seeds.shape[0]
        rnd = None
        points = batch_seeds
        latents = points.float().to(device)
        points = edm_sampler(self, latents, cond, num_steps=num_steps)
        return points
    @torch.no_grad()
    def inverse(self, cond=None, samples=None, channels=3, num_steps=18):
        # Map data samples back toward the prior by integrating in reverse.
        return inverse_edm_sampler(self, samples, cond, num_steps=num_steps)
class StackedRandomGenerator:
    """One ``torch.Generator`` per batch element, so every sample has its own
    reproducible random stream keyed by its seed."""

    def __init__(self, device, seeds):
        super().__init__()
        self.generators = [
            torch.Generator(device).manual_seed(int(s) % (1 << 32)) for s in seeds
        ]

    def randn(self, size, **kwargs):
        # size[0] must equal the number of per-sample generators
        assert size[0] == len(self.generators)
        draws = [torch.randn(size[1:], generator=g, **kwargs) for g in self.generators]
        return torch.stack(draws)

    def randn_like(self, input):
        return self.randn(input.shape, dtype=input.dtype, layout=input.layout, device=input.device)

    def randint(self, *args, size, **kwargs):
        assert size[0] == len(self.generators)
        draws = [torch.randint(*args, size=size[1:], generator=g, **kwargs) for g in self.generators]
        return torch.stack(draws)
def edm_sampler(
    net, latents, class_labels=None, randn_like=torch.randn_like,
    num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
    S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
):
    """Deterministic 2nd-order (Heun) EDM sampler.

    Integrates the probability-flow ODE from ``sigma_max`` down to 0 starting
    from ``latents`` (scaled by the initial noise level).  Returns the final
    sample and a list of per-step numpy snapshots, each rescaled by
    1/sqrt(1 + t^2) for visualization (the returned final sample is NOT
    rescaled).
    """
    # disable S_churn
    assert S_churn==0
    # Adjust noise levels based on what's supported by the network.
    sigma_min = max(sigma_min, net.sigma_min)
    sigma_max = min(sigma_max, net.sigma_max)
    # Time step discretization (Karras rho-schedule).
    step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
    t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0
    # Main sampling loop.
    x_next = latents.to(torch.float64) * t_steps[0]
    outputs = []
    outputs.append((x_next / t_steps[0]).detach().cpu().numpy())
    print(t_steps[0])
    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
        print(t_cur, t_next)
        x_cur = x_next
        # Increase noise temporarily.  With S_churn forced to 0 above, gamma
        # is always 0, so no noise is actually added and t_hat == t_cur.
        gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
        t_hat = net.round_sigma(t_cur + gamma * t_cur)
        x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)
        # x_hat = x_cur
        t_hat = t_cur
        # Euler step.
        denoised = net(x_hat, t_hat, class_labels).to(torch.float64)
        d_cur = (x_hat - denoised) / t_hat
        x_next = x_hat + (t_next - t_hat) * d_cur
        # Apply 2nd order correction (skipped on the last step, where t_next = 0).
        if i < num_steps - 1:
            denoised = net(x_next, t_next, class_labels).to(torch.float64)
            d_prime = (x_next - denoised) / t_next
            x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
        outputs.append((x_next / (1+t_next**2).sqrt()).detach().cpu().numpy())
    return x_next, outputs
def inverse_edm_sampler(
    net, latents, class_labels=None, randn_like=torch.randn_like,
    num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
    S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
):
    """Inverse EDM sampler: transport data samples back to the noise distribution.

    Runs the same Heun integration as `edm_sampler` but along the reversed
    sigma schedule (from ~0 up to sigma_max), so `latents` here are data
    points rather than noise. The terminal 0 of the schedule is replaced by
    1e-8 to avoid dividing by zero after the flip, and the final state is
    divided by sqrt(1 + sigma_max^2) to bring it back to ~unit variance.

    Returns:
        (x_next, outputs): the normalized float64 latent tensor and None
        (kept for interface symmetry with `edm_sampler`).

    Fixes vs. the original: removed debug prints and commented-out debris;
    removed a dead computation that drew `randn_like(x_cur)` only to have
    `x_hat` immediately overwritten with `x_cur`, wasting the draw and
    perturbing the global RNG state.
    """
    # Stochastic churn is not supported by this implementation.
    assert S_churn == 0

    # Clamp the sigma range to what the network supports.
    sigma_min = max(sigma_min, net.sigma_min)
    sigma_max = min(sigma_max, net.sigma_max)

    # Karras rho-schedule, terminal 0 replaced by 1e-8, then reversed so we
    # integrate from (almost) clean data up to sigma_max.
    step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
    t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1]) + 1e-8])
    t_steps = torch.flip(t_steps, [0])

    # Main (reversed) integration loop.
    x_next = latents.to(torch.float64)
    outputs = None
    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):  # 0, ..., N-1
        x_hat = x_next
        t_hat = t_cur

        # Euler step.
        denoised = net(x_hat, t_hat, class_labels).to(torch.float64)
        d_cur = (x_hat - denoised) / t_hat
        x_next = x_hat + (t_next - t_hat) * d_cur

        # 2nd order (Heun) correction; the last step is skipped, mirroring
        # the forward sampler.
        if i < num_steps - 1:
            denoised = net(x_next, t_next, class_labels).to(torch.float64)
            d_prime = (x_next - denoised) / t_next
            x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)

    # Normalize by the terminal sigma so the result is ~unit variance.
    x_next = x_next / (1 + t_next ** 2).sqrt()
    return x_next, outputs
class EDMLoss:
    """EDM denoising loss with a selectable initial-noise distribution."""

    def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=1, dist='Gaussian'):
        # P_mean/P_std parameterize the log-normal sigma sampler; sigma_data
        # sets the loss weighting; dist selects the noise family.
        self.P_mean = P_mean
        self.P_std = P_std
        self.sigma_data = sigma_data
        self.dist = dist

    def __call__(self, net, inputs, labels=None, augment_pipe=None, init_noise=None):
        """Return the mean weighted denoising loss for one batch."""
        # One noise level per sample, drawn log-normally.
        rnd = torch.randn([inputs.shape[0],], device=inputs.device)
        sigma = (rnd * self.P_std + self.P_mean).exp()
        weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2

        if augment_pipe is None:
            y, augment_labels = inputs, None
        else:
            y, augment_labels = augment_pipe(inputs)

        if self.dist == 'Gaussian':
            n = torch.randn_like(y[:, :3]) * sigma[:, None]
            if y.shape[1] != 3:
                # Extra (non-xyz) channels get unit-variance uniform noise.
                c = (torch.rand_like(y[:, 3:]) - 0.5) / np.sqrt(1/12) * sigma[:, None]
                n = torch.cat([n, c], dim=1)
        elif self.dist == 'Uniform':
            n = (torch.rand_like(y) - 0.5) / np.sqrt(1/12) * sigma[:, None]
        elif self.dist == 'Sphere':
            # Points on the unit sphere, rescaled to unit per-axis variance.
            n = torch.randn_like(y[:, :3])
            n = torch.nn.functional.normalize(n, dim=1)
            n /= np.sqrt(1/3)
            n = n * sigma[:, None]
        elif self.dist == "Mesh":
            assert init_noise is not None
            n = init_noise * sigma[:, None]
        else:
            raise NotImplementedError

        D_yn = net(y + n, sigma)
        loss = weight[:, None] * ((D_yn - y) ** 2)
        return loss.mean()
================================================
FILE: normalize.py
================================================
import argparse
import trimesh
import math
import glob
import numpy as np

# CLI: normalize a mesh so that points sampled on its surface have zero mean
# and unit standard deviation (see normalize_meshes below).
parser = argparse.ArgumentParser('Inference', add_help=False)
parser.add_argument('--path', required=True, type=str)    # input mesh file
parser.add_argument('--output', required=True, type=str)  # output path for the normalized mesh
args = parser.parse_args()

# process=False keeps the vertices exactly as stored on disk (no merging/cleanup).
model = trimesh.load(args.path, process=False)
def normalize_meshes(mesh):
    """Center and rescale `mesh` in place so surface samples are ~N(0, 1).

    First fits the mesh into the unit cube (bounding-box centering with a
    0.99 safety margin), then samples 10M surface points and shifts/scales
    the vertices so those samples have zero mean and unit standard deviation.

    Returns the (mutated) mesh for convenience.
    """
    # Bounding-box centering and uniform scaling into [-0.99, 0.99].
    mesh.vertices -= (mesh.vertices.max(axis=0) + mesh.vertices.min(axis=0)) / 2
    scale = (1 / np.abs(mesh.vertices).max()) * 0.99
    mesh.vertices *= scale

    # Use statistics of the surface distribution, not of the (possibly
    # unevenly tessellated) vertex set.
    points, _ = trimesh.sample.sample_surface(mesh, 10000000)
    # Bug fix: center per axis. The original subtracted points.mean() — a
    # single scalar over all coordinates — which does not center the shape.
    mesh.vertices -= points.mean(axis=0)
    # A single global std keeps the aspect ratio (uniform scaling).
    mesh.vertices /= points.std()
    return mesh
model = normalize_meshes(model)

# Optional manual step (disabled): rotate the model before export.
# angle = math.pi / 2
# direction = [1, 0, 0]
# center = [0, 0, 0]
# rot_matrix = trimesh.transformations.rotation_matrix(angle, direction, center)
# model.apply_transform(rot_matrix)

model.export(args.output)
================================================
FILE: points.py
================================================
import trimesh
import numpy as np
import os
import torch
from torch.utils import data
class Points(data.Dataset):
    """Dataset over the vertices of a mesh / point-cloud file.

    Each item is one point (a row of the loaded vertex tensor); the length
    is the number of vertices in the file.
    """

    def __init__(self, ply_path):
        vertices = trimesh.load(ply_path).vertices
        self.points = torch.from_numpy(vertices)
        # Report the raw statistics of the loaded cloud.
        print(self.points.std(), self.points.mean())

    def __len__(self):
        return self.points.shape[0]

    def __getitem__(self, idx):
        return self.points[idx]
================================================
FILE: util/lr_decay.py
================================================
# --------------------------------------------------------
# References:
# MAE: https://github.com/facebookresearch/mae
# DeiT: https://github.com/facebookresearch/deit
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
# --------------------------------------------------------
import json
def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=(), layer_decay=.75):
    """
    Parameter groups for layer-wise lr decay
    Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58

    Returns a list of optimizer parameter-group dicts, one per
    (layer, decay/no_decay) combination, each carrying an "lr_scale".

    Fixes vs. the original: the mutable default argument `[]` is replaced by
    an (equivalent, read-only) tuple, and the shadow `param_group_names`
    dict — maintained in parallel only for a commented-out debug print — is
    dropped.
    """
    param_groups = {}
    num_layers = len(model.blocks) + 1
    # layer_scales[i] multiplies the base lr for layer i (deeper -> closer to 1).
    layer_scales = [layer_decay ** (num_layers - i) for i in range(num_layers + 1)]

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        # no decay: all 1D parameters (biases, norms) and model-specific ones
        if param.ndim == 1 or name in no_weight_decay_list:
            g_decay = "no_decay"
            this_decay = 0.
        else:
            g_decay = "decay"
            this_decay = weight_decay

        layer_id = get_layer_id_for_vit(name, num_layers)
        group_name = "layer_%d_%s" % (layer_id, g_decay)

        if group_name not in param_groups:
            param_groups[group_name] = {
                "lr_scale": layer_scales[layer_id],
                "weight_decay": this_decay,
                "params": [],
            }
        param_groups[group_name]["params"].append(param)

    return list(param_groups.values())
def get_layer_id_for_vit(name, num_layers):
    """
    Assign a parameter with its layer id
    Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
    """
    # Embedding-level parameters belong to layer 0.
    if name in ('cls_token', 'pos_embed') or name.startswith('patch_embed'):
        return 0
    # Transformer blocks: "blocks.<idx>.<...>" maps to idx + 1.
    if name.startswith('blocks'):
        return int(name.split('.')[1]) + 1
    # Everything else (e.g. the head) goes to the final layer.
    return num_layers
================================================
FILE: util/lr_sched.py
================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
def adjust_learning_rate(optimizer, epoch, args):
    """Decay the learning rate with half-cycle cosine after warmup"""
    if epoch < args.warmup_epochs:
        # Linear warmup from 0 up to args.lr.
        lr = args.lr * epoch / args.warmup_epochs
    else:
        # Half-cycle cosine decay from args.lr down to args.min_lr.
        lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
            (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
    for group in optimizer.param_groups:
        # Groups built with layer-wise decay carry an "lr_scale" multiplier.
        group["lr"] = lr * group.get("lr_scale", 1.0)
    return lr
================================================
FILE: util/misc.py
================================================
# --------------------------------------------------------
# References:
# MAE: https://github.com/facebookresearch/mae
# DeiT: https://github.com/facebookresearch/deit
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
# --------------------------------------------------------
import builtins
import datetime
import os
import time
from collections import defaultdict, deque
from pathlib import Path
import torch
import torch.distributed as dist
if torch.__version__[0] == '2':
from torch import inf
else:
from torch._six import inf
class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        # Default display: windowed median plus global average.
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = "{median:.5f} ({global_avg:.5f})" if fmt is None else fmt

    def update(self, value, n=1):
        # The window stores individual values; the running totals weight by n.
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        stats = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(stats)
        stats = stats.tolist()
        self.count = int(stats[0])
        self.total = stats[1]

    @property
    def median(self):
        # Median over the current window.
        return torch.tensor(list(self.deque)).median().item()

    @property
    def avg(self):
        # Mean over the current window.
        return torch.tensor(list(self.deque), dtype=torch.float32).mean().item()

    @property
    def global_avg(self):
        # Mean over the entire series, including values evicted from the window.
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)
class MetricLogger(object):
    """Collects named SmoothedValue meters and pretty-prints them during a loop."""

    def __init__(self, delimiter="\t"):
        # Unknown meter names get a default SmoothedValue on first access.
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Update one meter per keyword argument; tensors are unwrapped via .item()."""
        for k, v in kwargs.items():
            if v is None:
                continue
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Expose meters as attributes (e.g. logger.loss); only called when
        # normal attribute lookup fails.
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        # "name: <meter>" entries joined by the configured delimiter.
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        # Reduce counts/totals across ranks (window contents stay local).
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield items from `iterable`, printing progress/ETA/meters every
        `print_freq` iterations (and on the final one)."""
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        # Pad the iteration counter to the width of the total count.
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        if torch.cuda.is_available():
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            # Time spent waiting on the data source for this step.
            data_time.update(time.time() - end)
            yield obj
            # Full step time: data wait plus the caller's work between yields.
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / len(iterable)))
def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    builtin_print = builtins.print

    def print(*args, **kwargs):
        # 'force' is accepted (and popped) for API compatibility, but the
        # forced-print path is currently disabled — only the master prints.
        force = kwargs.pop('force', False)
        force = force or (get_world_size() > 8)
        if is_master:# or force:
            now = datetime.datetime.now().time()
            builtin_print('[{}] '.format(now), end='')  # print with time stamp
            builtin_print(*args, **kwargs)

    # Replace the global print so all modules pick up the gated version.
    builtins.print = print
def is_dist_avail_and_initialized():
    """True only when torch.distributed is both available and initialized."""
    return dist.is_available() and dist.is_initialized()
def get_world_size():
    """Number of processes in the group, or 1 when not running distributed."""
    if is_dist_avail_and_initialized():
        return dist.get_world_size()
    return 1
def get_rank():
    """Rank of this process, or 0 when not running distributed."""
    if is_dist_avail_and_initialized():
        return dist.get_rank()
    return 0
def is_main_process():
    """True on the rank-0 process (or when not distributed at all)."""
    rank = get_rank()
    return rank == 0
def save_on_master(*args, **kwargs):
    """torch.save, executed only on the main process (no-op elsewhere)."""
    if not is_main_process():
        return
    torch.save(*args, **kwargs)
def init_distributed_mode(args):
    """Initialize torch.distributed from whichever launcher set the environment.

    Mutates `args` in place (rank, world_size, gpu, distributed, dist_backend)
    and joins an NCCL process group. Falls back to single-process mode when no
    known launcher environment is detected.
    """
    if args.dist_on_itp:
        # OpenMPI-style launcher: translate OMPI_* vars into the standard ones.
        args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
        args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
        args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
        args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
        os.environ['LOCAL_RANK'] = str(args.gpu)
        os.environ['RANK'] = str(args.rank)
        os.environ['WORLD_SIZE'] = str(args.world_size)
        # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
    elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        # torchrun / torch.distributed.launch environment.
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        # SLURM: derive the local GPU index from the global rank.
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        setup_for_distributed(is_master=True)  # hack
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}, gpu {}'.format(
        args.rank, args.dist_url, args.gpu), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    # Gate printing so only rank 0 logs from here on.
    setup_for_distributed(args.rank == 0)
class NativeScalerWithGradNormCount:
    """Wraps torch.cuda.amp.GradScaler: scaled backward, optional gradient
    clipping, optimizer step, and reports the gradient norm for logging."""
    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
        # Backward on the scaled loss to avoid fp16 gradient underflow.
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if update_grad:
            if clip_grad is not None:
                assert parameters is not None
                self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
                norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
            else:
                # Still unscale so the reported norm is in true (unscaled) units.
                self._scaler.unscale_(optimizer)
                norm = get_grad_norm_(parameters)
            self._scaler.step(optimizer)
            self._scaler.update()
        else:
            # Gradient-accumulation step: no optimizer update, no norm reported.
            norm = None
        return norm

    def state_dict(self):
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        self._scaler.load_state_dict(state_dict)
def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
    """Total gradient norm over `parameters`; params without gradients are skipped."""
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if not grads:
        # Nothing to measure: report a zero norm.
        return torch.tensor(0.)
    device = grads[0].device
    if norm_type == inf:
        # Infinity norm: largest absolute gradient entry anywhere.
        return max(g.detach().abs().max().to(device) for g in grads)
    # p-norm of the per-tensor p-norms equals the global p-norm.
    per_tensor = [torch.norm(g.detach(), norm_type).to(device) for g in grads]
    return torch.norm(torch.stack(per_tensor), norm_type)
def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler):
    """Write a training checkpoint for `epoch` under args.output_dir."""
    output_dir = Path(args.output_dir)
    epoch_name = str(epoch)
    if loss_scaler is None:
        # No native scaler: the (wrapped) model handles checkpointing itself
        # (DeepSpeed-style API).
        client_state = {'epoch': epoch}
        model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state)
        return
    # Standard path: serialize everything needed to resume training;
    # only rank 0 actually writes the file.
    to_save = {
        'model': model_without_ddp.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
        'scaler': loss_scaler.state_dict(),
        'args': args,
    }
    save_on_master(to_save, output_dir / ('checkpoint-%s.pth' % epoch_name))
def load_model(args, model_without_ddp, optimizer, loss_scaler):
    """Resume model/optimizer/scaler state from args.resume (path or URL), if set."""
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            # NOTE(review): torch.load unpickles arbitrary objects — only
            # resume from trusted checkpoint files.
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        print("Resume checkpoint %s" % args.resume)
        # Restore optimizer/epoch/scaler only for training runs (not eval).
        if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
            optimizer.load_state_dict(checkpoint['optimizer'])
            args.start_epoch = checkpoint['epoch'] + 1
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])
            print("With optim & sched!")
def all_reduce_mean(x):
    """Average a scalar across all processes; pass-through when not distributed."""
    world_size = get_world_size()
    if world_size <= 1:
        return x
    reduced = torch.tensor(x).cuda()
    dist.all_reduce(reduced)
    reduced /= world_size
    return reduced.item()
gitextract_qju22int/
├── .gitignore
├── README.md
├── engine.py
├── eval.py
├── infer.py
├── inverese.py
├── main.py
├── models.py
├── normalize.py
├── points.py
└── util/
├── lr_decay.py
├── lr_sched.py
└── misc.py
SYMBOL INDEX (82 symbols across 8 files)
FILE: engine.py
function train_one_epoch (line 29) | def train_one_epoch(model: torch.nn.Module,
FILE: main.py
function get_args_parser (line 26) | def get_args_parser():
function main (line 102) | def main(args):
FILE: models.py
function modulate (line 12) | def modulate(x, shift, scale):
class TimestepEmbedder (line 15) | class TimestepEmbedder(nn.Module):
method __init__ (line 19) | def __init__(self, hidden_size, frequency_embedding_size=256):
method timestep_embedding (line 29) | def timestep_embedding(t, dim, max_period=10000):
method forward (line 49) | def forward(self, t):
class MPFourier (line 54) | class MPFourier(torch.nn.Module):
method __init__ (line 55) | def __init__(self, num_channels, bandwidth=1):
method forward (line 60) | def forward(self, x):
function normalize (line 69) | def normalize(x, dim=None, eps=1e-4):
function mp_silu (line 76) | def mp_silu(x):
function mp_sum (line 79) | def mp_sum(a, b, t=0.5):
class MPConv (line 83) | class MPConv(torch.nn.Module):
method __init__ (line 84) | def __init__(self, in_channels, out_channels, kernel):
method forward (line 89) | def forward(self, x, gain=1):
class PointEmbed (line 102) | class PointEmbed(nn.Module):
method __init__ (line 103) | def __init__(self, hidden_dim=48, dim=128, other_dim=0):
method embed (line 124) | def embed(input, basis):
method forward (line 130) | def forward(self, input):
class Network (line 144) | class Network(nn.Module):
method __init__ (line 145) | def __init__(
method forward (line 182) | def forward(self, x, t):
class EDMPrecond (line 208) | class EDMPrecond(torch.nn.Module):
method __init__ (line 209) | def __init__(self,
method forward (line 226) | def forward(self, x, sigma, force_fp32=False, **model_kwargs):
method round_sigma (line 243) | def round_sigma(self, sigma):
method sample (line 247) | def sample(self, cond=None, batch_seeds=None, channels=3, num_steps=18):
method inverse (line 261) | def inverse(self, cond=None, samples=None, channels=3, num_steps=18):
class StackedRandomGenerator (line 265) | class StackedRandomGenerator:
method __init__ (line 266) | def __init__(self, device, seeds):
method randn (line 270) | def randn(self, size, **kwargs):
method randn_like (line 274) | def randn_like(self, input):
method randint (line 277) | def randint(self, *args, size, **kwargs):
function edm_sampler (line 281) | def edm_sampler(
function inverse_edm_sampler (line 327) | def inverse_edm_sampler(
class EDMLoss (line 383) | class EDMLoss:
method __init__ (line 384) | def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=1, dist='Gaussia...
method __call__ (line 391) | def __call__(self, net, inputs, labels=None, augment_pipe=None, init_n...
FILE: normalize.py
function normalize_meshes (line 19) | def normalize_meshes(mesh):
FILE: points.py
class Points (line 9) | class Points(data.Dataset):
method __init__ (line 10) | def __init__(self, ply_path):
method __len__ (line 21) | def __len__(self):
method __getitem__ (line 24) | def __getitem__(self, idx):
FILE: util/lr_decay.py
function param_groups_lrd (line 11) | def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], ...
function get_layer_id_for_vit (line 60) | def get_layer_id_for_vit(name, num_layers):
FILE: util/lr_sched.py
function adjust_learning_rate (line 9) | def adjust_learning_rate(optimizer, epoch, args):
FILE: util/misc.py
class SmoothedValue (line 24) | class SmoothedValue(object):
method __init__ (line 29) | def __init__(self, window_size=20, fmt=None):
method update (line 37) | def update(self, value, n=1):
method synchronize_between_processes (line 42) | def synchronize_between_processes(self):
method median (line 56) | def median(self):
method avg (line 61) | def avg(self):
method global_avg (line 66) | def global_avg(self):
method max (line 70) | def max(self):
method value (line 74) | def value(self):
method __str__ (line 77) | def __str__(self):
class MetricLogger (line 86) | class MetricLogger(object):
method __init__ (line 87) | def __init__(self, delimiter="\t"):
method update (line 91) | def update(self, **kwargs):
method __getattr__ (line 100) | def __getattr__(self, attr):
method __str__ (line 108) | def __str__(self):
method synchronize_between_processes (line 116) | def synchronize_between_processes(self):
method add_meter (line 120) | def add_meter(self, name, meter):
method log_every (line 123) | def log_every(self, iterable, print_freq, header=None):
function setup_for_distributed (line 170) | def setup_for_distributed(is_master):
function is_dist_avail_and_initialized (line 187) | def is_dist_avail_and_initialized():
function get_world_size (line 195) | def get_world_size():
function get_rank (line 201) | def get_rank():
function is_main_process (line 207) | def is_main_process():
function save_on_master (line 211) | def save_on_master(*args, **kwargs):
function init_distributed_mode (line 216) | def init_distributed_mode(args):
class NativeScalerWithGradNormCount (line 251) | class NativeScalerWithGradNormCount:
method __init__ (line 254) | def __init__(self):
method __call__ (line 257) | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, c...
method state_dict (line 273) | def state_dict(self):
method load_state_dict (line 276) | def load_state_dict(self, state_dict):
function get_grad_norm_ (line 280) | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
function save_model (line 295) | def save_model(args, epoch, model, model_without_ddp, optimizer, loss_sc...
function load_model (line 315) | def load_model(args, model_without_ddp, optimizer, loss_scaler):
function all_reduce_mean (line 332) | def all_reduce_mean(x):
Condensed preview — 13 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (58K chars).
[
{
"path": ".gitignore",
"chars": 116,
"preview": "*.ipynb\n*.ply\n.ipynb_checkpoints\ntest*.*\n*.stl\nvis.py\n*.obj\n__pycache__\n*_dataset\noutput/\nbackup.py\nshapes/\nsamples/"
},
{
"path": "README.md",
"chars": 1352,
"preview": "# Geometry Distributions\n\n### [Project Page](https://1zb.github.io/GeomDist/) | [Paper (arXiv)](https://arxiv.org/abs/24"
},
{
"path": "engine.py",
"chars": 6335,
"preview": "# --------------------------------------------------------\n# References:\n# MAE: https://github.com/facebookresearch/mae\n"
},
{
"path": "eval.py",
"chars": 825,
"preview": "\nimport trimesh\nfrom scipy.spatial import cKDTree as KDTree\nimport numpy as np\n\nimport argparse\n\nparser = argparse.Argum"
},
{
"path": "infer.py",
"chars": 3292,
"preview": "import argparse \nfrom pathlib import Path\nimport os\n\nimport torch\n\nimport trimesh\n\nfrom models import EDMPrecond\n\ntorch."
},
{
"path": "inverese.py",
"chars": 2351,
"preview": "import argparse \n\nimport torch\n\nimport trimesh\n\nfrom models import EDMPrecond\n\ntorch.manual_seed(0)\n\nimport numpy as np\n"
},
{
"path": "main.py",
"chars": 9770,
"preview": "import argparse\nimport datetime\nimport json\nimport numpy as np\nimport os\nimport time\nfrom pathlib import Path\n\nimport to"
},
{
"path": "models.py",
"chars": 15335,
"preview": "import torch\nimport torch.nn as nn\n\nimport math\n\nimport numpy as np\n\nimport torch.nn.functional\nimport trimesh\n\n\ndef mod"
},
{
"path": "normalize.py",
"chars": 918,
"preview": "import argparse \n\nimport trimesh\n\nimport math\n\nimport glob\n\nimport numpy as np\n\nparser = argparse.ArgumentParser('Infere"
},
{
"path": "points.py",
"chars": 756,
"preview": "import trimesh\n\nimport numpy as np\nimport os\n\nimport torch\nfrom torch.utils import data\n\nclass Points(data.Dataset):\n "
},
{
"path": "util/lr_decay.py",
"chars": 2307,
"preview": "# --------------------------------------------------------\n# References:\n# MAE: https://github.com/facebookresearch/mae\n"
},
{
"path": "util/lr_sched.py",
"chars": 801,
"preview": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the li"
},
{
"path": "util/misc.py",
"chars": 11362,
"preview": "# --------------------------------------------------------\n# References:\n# MAE: https://github.com/facebookresearch/mae\n"
}
]
About this extraction
This page contains the full source code of the 1zb/GeomDist GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 13 files (54.2 KB), approximately 13.9k tokens, and a symbol index with 82 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.