Repository: 1zb/GeomDist Branch: master Commit: bf593d307676 Files: 13 Total size: 54.2 KB Directory structure: gitextract_qju22int/ ├── .gitignore ├── README.md ├── engine.py ├── eval.py ├── infer.py ├── inverese.py ├── main.py ├── models.py ├── normalize.py ├── points.py └── util/ ├── lr_decay.py ├── lr_sched.py └── misc.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ *.ipynb *.ply .ipynb_checkpoints test*.* *.stl vis.py *.obj __pycache__ *_dataset output/ backup.py shapes/ samples/ ================================================ FILE: README.md ================================================ # Geometry Distributions ### [Project Page](https://1zb.github.io/GeomDist/) | [Paper (arXiv)](https://arxiv.org/abs/2411.16076) ### :bullettrain_front: Training ``` torchrun --nproc_per_node=4 main.py --blr 5e-7 --output_dir output/loong --log_dir output/loong --data_path shapes/loong.obj ``` ### :balloon: Inference ``` python infer.py --pth output/loong/checkpoint-999.pth --target Gaussian --num-steps 64 --output samples/loong --N 10000000 ``` ### :floppy_disk: Datasets https://huggingface.co/datasets/Zbalpha/shapes ### :briefcase: Checkpoints https://huggingface.co/Zbalpha/geom_dist_ckpt ## :e-mail: Contact Contact [Biao Zhang](mailto:biao.zhang@kaust.edu.sa) ([@1zb](https://github.com/1zb)) if you have any further questions. This repository is for academic research use only. 
## :blue_book: Citation arxiv ```bibtex @article{zhang2024geometry, title={Geometry Distributions}, author={Zhang, Biao and Ren, Jing and Wonka, Peter}, journal={arXiv preprint arXiv:2411.16076}, year={2024} } ``` ICCV ``` @InProceedings{Zhang_2025_ICCV, author = {Zhang, Biao and Ren, Jing and Wonka, Peter}, title = {Geometry Distributions}, booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)}, month = {October}, year = {2025}, pages = {1495-1505} } ``` ================================================ FILE: engine.py ================================================ # -------------------------------------------------------- # References: # MAE: https://github.com/facebookresearch/mae # DeiT: https://github.com/facebookresearch/deit # BEiT: https://github.com/microsoft/unilm/tree/master/beit # -------------------------------------------------------- import math import sys from typing import Iterable import torch import torch.nn.functional as F import numpy as np import util.misc as misc import util.lr_sched as lr_sched from torch.autograd import Variable from math import exp from einops import rearrange, repeat import trimesh from PIL import Image def train_one_epoch(model: torch.nn.Module, data_loader, optimizer: torch.optim.Optimizer, criterion, device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, log_writer=None, args=None): model.train(True) metric_logger = misc.MetricLogger(delimiter=" ") metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) header = 'Epoch: [{}]'.format(epoch) print_freq = 20 accum_iter = args.accum_iter optimizer.zero_grad() if log_writer is not None: print('log_dir: {}'.format(log_writer.log_dir)) print(data_loader) noise = None if isinstance(data_loader, dict): obj_file = data_loader['obj_file'] batch_size = data_loader['batch_size'] if obj_file is not None: if obj_file.endswith('.obj'): mesh = trimesh.load(obj_file) if data_loader['texture_path'] is not 
None: img = Image.open(data_loader['texture_path']) material = trimesh.visual.texture.SimpleMaterial(image=img) assert mesh.visual.uv is not None texture = trimesh.visual.TextureVisuals(mesh.visual.uv, image=img, material=material) mesh = trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces, visual=texture, process=False) samples, _, colors = trimesh.sample.sample_surface(mesh, 2048*64*4*64, sample_color=True) colors = colors[:, :3] # remove alpha colors = (colors.astype(np.float32) / 255.0 - 0.5) / np.sqrt(1/12) # [-1, 1] samples = np.concatenate([samples, colors], axis=1) else: samples, _ = trimesh.sample.sample_surface(mesh, 2048*64*4*64) else: samples = trimesh.load(obj_file).vertices else: if data_loader['primitive'] == 'sphere': n = torch.randn(2048*64*4*64, 3) n = torch.nn.functional.normalize(n, dim=1) samples = n / np.sqrt(1/3) samples = samples.numpy() elif data_loader['primitive'] == 'plane': samples = torch.rand(2048*64*4*64, 3) - 0.5 samples[:, 2] = 0 samples = (samples - 0) / np.sqrt(2/9*2*0.5**3) samples = samples.numpy() elif data_loader['primitive'] == 'volume': samples = (torch.rand(2048*64*4*64, 3) - 0.5) / np.sqrt(1/12) samples = samples.numpy() elif data_loader['primitive'] == 'gaussian': samples = np.random.randn(2048*64*4*64, 3).astype(np.float32) else: raise NotImplementedError if data_loader['noise_mesh'] is not None: noise, _ = trimesh.sample.sample_surface(trimesh.load(data_loader['noise_mesh']), 2048*64*4*64) else: noise = None samples = samples.astype(np.float32)# - 0.12 data_loader = range(data_loader['epoch_size']) for data_iter_step, batch in enumerate(metric_logger.log_every(data_loader, print_freq, header)): # we use a per iteration (instead of per epoch) lr scheduler if data_iter_step % accum_iter == 0: lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args) if isinstance(batch, int): ind = np.random.default_rng().choice(samples.shape[0], batch_size, replace=True) xyz = samples[ind] xyz = 
torch.from_numpy(xyz).float().to(device, non_blocking=True) else: xyz = batch.to(device, non_blocking=True) with torch.cuda.amp.autocast(enabled=False): if noise is not None: ind = np.random.default_rng().choice(noise.shape[0], batch_size, replace=True) init_noise = noise[ind] init_noise = torch.from_numpy(init_noise).float().to(device, non_blocking=True) else: init_noise = None loss = criterion(model, xyz, init_noise=init_noise) loss_value = loss.item() if not math.isfinite(loss_value): print("Loss is {}, stopping training".format(loss_value)) sys.exit(1) loss /= accum_iter loss_scaler(loss, optimizer, clip_grad=max_norm, parameters=model.parameters(), create_graph=False, update_grad=(data_iter_step + 1) % accum_iter == 0) if (data_iter_step + 1) % accum_iter == 0: optimizer.zero_grad() torch.cuda.synchronize() metric_logger.update(loss=loss_value) min_lr = 10. max_lr = 0. for group in optimizer.param_groups: min_lr = min(min_lr, group["lr"]) max_lr = max(max_lr, group["lr"]) metric_logger.update(lr=max_lr) loss_value_reduce = misc.all_reduce_mean(loss_value) if log_writer is not None and (data_iter_step + 1) % accum_iter == 0: """ We use epoch_1000x as the x-axis in tensorboard. This calibrates different curves when batch size changes. 
""" epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) log_writer.add_scalar('loss', loss_value_reduce, epoch_1000x) log_writer.add_scalar('lr', max_lr, epoch_1000x) # gather the stats from all processes metric_logger.synchronize_between_processes() print("Averaged stats:", metric_logger) return {k: meter.global_avg for k, meter in metric_logger.meters.items()} ================================================ FILE: eval.py ================================================ import trimesh from scipy.spatial import cKDTree as KDTree import numpy as np import argparse parser = argparse.ArgumentParser() parser.add_argument('--ply', required=True, type=str) parser.add_argument('--reference', required=True, type=str) parser.add_argument('--scale', required=True, type=str) args = parser.parse_args() scale = np.load(args.scale) prediction = trimesh.load(args.ply).vertices * scale reference = trimesh.load(args.reference).vertices * scale tree = KDTree(prediction) dist, _ = tree.query(reference) d1 = dist gt_to_gen_chamfer = np.mean(dist) gt_to_gen_chamfer_sq = np.mean(np.square(dist)) tree = KDTree(reference) dist, _ = tree.query(prediction) d2 = dist gen_to_gt_chamfer = np.mean(dist) gen_to_gt_chamfer_sq = np.mean(np.square(dist)) cd = gt_to_gen_chamfer + gen_to_gt_chamfer print(cd) ================================================ FILE: infer.py ================================================ import argparse from pathlib import Path import os import torch import trimesh from models import EDMPrecond torch.manual_seed(0) import numpy as np np.random.seed(0) import random random.seed(0) # The flag below controls whether to allow TF32 on matmul. This flag defaults to False # in PyTorch 1.12 and later. torch.backends.cuda.matmul.allow_tf32 = True # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True. 
torch.backends.cudnn.allow_tf32 = True parser = argparse.ArgumentParser('Inference', add_help=False) parser.add_argument('--pth', required=True, type=str) parser.add_argument('--texture', action='store_true') parser.add_argument('--target', default='Gaussian', type=str) parser.add_argument('--N', default=1000000, type=int) parser.add_argument('--num-steps', default=64, type=int) parser.add_argument('--noise_mesh', default=None, type=str) parser.add_argument('--output', required=True, type=str) parser.add_argument('--intermediate', action='store_true') parser.add_argument('--depth', default=6, type=int) parser.set_defaults(texture=False) parser.set_defaults(intermediate=False) args = parser.parse_args() Path(args.output).mkdir(parents=True, exist_ok=True) if args.texture: model = EDMPrecond(channels=6, depth=args.depth).cuda() else: model = EDMPrecond(depth=args.depth).cuda() model.load_state_dict(torch.load(args.pth, map_location='cpu')['model'], strict=True) if args.target == 'Gaussian': noise = torch.randn(args.N, 3).cuda() elif args.target == 'Uniform': noise = (torch.rand(args.N, 3).cuda() - 0.5) / np.sqrt(1/12) elif args.target == 'Sphere': n = torch.randn(args.N, 3).cuda() n = torch.nn.functional.normalize(n, dim=1) noise = n / np.sqrt(1/3) elif args.target == 'Mesh': assert args.noise_mesh is not None noise, _ = trimesh.sample.sample_surface(trimesh.load(args.noise_mesh), args.N) noise = torch.from_numpy(noise).float().cuda() else: raise NotImplementedError if args.texture: color = (torch.rand(args.N, 3).cuda() - 0.5) / np.sqrt(1/12) noise = torch.cat([noise, color], dim=1) sample, intermediate_steps = model.sample(batch_seeds=noise, num_steps=args.num_steps) if args.texture: sample = sample.detach().cpu().numpy() vertices, colors = sample[:, :3], sample[:, 3:] colors = (colors * np.sqrt(1/12) + 0.5) * 255.0 colors = np.concatenate([colors, np.ones_like(colors[:, 0:1]) * 255.0], axis=1).astype(np.uint8) # alpha channel trimesh.PointCloud(vertices, 
colors).export(os.path.join(args.output, 'sample.ply')) if args.intermediate: for i, s in enumerate(intermediate_steps): vertices, colors = s[:, :3], s[:, 3:] colors = (colors * np.sqrt(1/12) + 0.5) * 255.0 colors = np.concatenate([colors, np.ones_like(colors[:, 0:1]) * 255.0], axis=1).astype(np.uint8) # alpha channel trimesh.PointCloud(vertices, colors).export(os.path.join(args.output, 'sample-{:03d}.ply'.format(i))) else: trimesh.PointCloud(sample.detach().cpu().numpy()).export(os.path.join(args.output, 'sample.ply')) if args.intermediate: for i, s in enumerate(intermediate_steps): trimesh.PointCloud(s).export(os.path.join(args.output, 'sample-{:03d}.ply'.format(i))) ================================================ FILE: inverese.py ================================================ import argparse import torch import trimesh from models import EDMPrecond torch.manual_seed(0) import numpy as np np.random.seed(0) import random random.seed(0) # The flag below controls whether to allow TF32 on matmul. This flag defaults to False # in PyTorch 1.12 and later. torch.backends.cuda.matmul.allow_tf32 = True # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True. 
torch.backends.cudnn.allow_tf32 = True parser = argparse.ArgumentParser('Inference', add_help=False) parser.add_argument('--pth', default='output/lamp_cube/checkpoint-0.pth', type=str) parser.add_argument('--texture', action='store_true') parser.add_argument('--N', default=1000000, type=int) parser.add_argument('--num-steps', default=64, type=int) parser.add_argument('--noise_mesh', default=None, type=str) parser.add_argument('--data_path', default='shapes/Jellyfish_lamp_part_A__B_normalized.obj', type=str) parser.set_defaults(texture=False) args = parser.parse_args() if args.texture: model = EDMPrecond(channels=6).cuda() else: model = EDMPrecond().cuda() mesh = trimesh.load(args.data_path) samples, _ = trimesh.sample.sample_surface(mesh, args.N) samples = samples.astype(np.float32) samples = torch.from_numpy(samples).float().cuda() model.load_state_dict(torch.load(args.pth, map_location='cpu')['model'], strict=True) sample, intermediate_steps = model.inverse(samples=samples, num_steps=args.num_steps) if args.texture: sample = sample.detach().cpu().numpy() vertices, colors = sample[:, :3], sample[:, 3:] colors = (colors * np.sqrt(1/12) + 0.5) * 255.0 colors = np.concatenate([colors, np.ones_like(colors[:, 0:1]) * 255.0], axis=1).astype(np.uint8) # alpha channel trimesh.PointCloud(vertices, colors).export('sample.ply') for i, s in enumerate(intermediate_steps): vertices, colors = s[:, :3], s[:, 3:] colors = (colors * np.sqrt(1/12) + 0.5) * 255.0 colors = np.concatenate([colors, np.ones_like(colors[:, 0:1]) * 255.0], axis=1).astype(np.uint8) # alpha channel trimesh.PointCloud(vertices, colors).export('sample-{:03d}.ply'.format(i)) else: trimesh.PointCloud(sample.detach().cpu().numpy()).export('sample.ply') for i, s in enumerate(intermediate_steps): trimesh.PointCloud(s).export('sample-{:03d}.ply'.format(i)) ================================================ FILE: main.py ================================================ import argparse import datetime import json import 
numpy as np import os import time from pathlib import Path import torch import torch.backends.cudnn as cudnn from torch.utils.tensorboard import SummaryWriter torch.set_num_threads(8) import util.lr_decay as lrd import util.misc as misc from util.misc import NativeScalerWithGradNormCount as NativeScaler import models as models from models import EDMLoss from engine import train_one_epoch from points import Points def get_args_parser(): parser = argparse.ArgumentParser('Train', add_help=False) parser.add_argument('--batch_size', default=2048*64*2, type=int, help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus') parser.add_argument('--epochs', default=1000, type=int) parser.add_argument('--accum_iter', default=1, type=int, help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)') # Model parameters parser.add_argument('--model', default='EDMPrecond', type=str, metavar='MODEL', help='Name of model to train') parser.add_argument('--depth', default=6, type=int, metavar='MODEL') # Optimizer parameters parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM', help='Clip gradient norm (default: None, no clipping)') parser.add_argument('--weight_decay', type=float, default=0.05, help='weight decay (default: 0.05)') parser.add_argument('--lr', type=float, default=None, metavar='LR', help='learning rate (absolute lr)') parser.add_argument('--blr', type=float, default=5e-7, metavar='LR', help='base learning rate: absolute_lr = base_lr * total_batch_size / 256') parser.add_argument('--layer_decay', type=float, default=0.75, help='layer-wise lr decay from ELECTRA/BEiT') parser.add_argument('--min_lr', type=float, default=5e-7, metavar='LR', help='lower lr bound for cyclic schedulers that hit 0') parser.add_argument('--warmup_epochs', type=int, default=1, metavar='N', help='epochs to warmup LR') # Dataset parameters parser.add_argument('--target', default='Gaussian', type=str, ) 
parser.add_argument('--data_path', default='shapes/Jellyfish_lamp_part_A__B_normalized.obj', type=str, help='dataset path') parser.add_argument('--texture_path', default=None, type=str, help='dataset path') parser.add_argument('--noise_mesh', default=None, type=str, help='dataset path') parser.add_argument('--output_dir', default='./output/', help='path where to save, empty for no saving') parser.add_argument('--log_dir', default='./output/', help='path where to tensorboard log') parser.add_argument('--device', default='cuda', help='device to use for training / testing') parser.add_argument('--seed', default=0, type=int) parser.add_argument('--resume', default='', help='resume from checkpoint') parser.add_argument('--start_epoch', default=0, type=int, metavar='N', help='start epoch') parser.add_argument('--eval', action='store_true', help='Perform evaluation only') parser.add_argument('--dist_eval', action='store_true', default=False, help='Enabling distributed evaluation (recommended during training for faster monitor') parser.add_argument('--num_workers', default=32, type=int) parser.add_argument('--pin_mem', action='store_true', help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem') parser.set_defaults(pin_mem=True) # distributed training parameters parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') parser.add_argument('--local_rank', default=-1, type=int) parser.add_argument('--dist_on_itp', action='store_true') parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') return parser def main(args): # os.environ['CUDA_LAUNCH_BLOCKING'] = '1' misc.init_distributed_mode(args) print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__)))) print("{}".format(args).replace(', ', ',\n')) device = torch.device(args.device) # fix the seed for reproducibility seed = 
args.seed + misc.get_rank() torch.manual_seed(seed) np.random.seed(seed) cudnn.benchmark = True cudnn.deterministic=True # The flag below controls whether to allow TF32 on matmul. This flag defaults to False # in PyTorch 1.12 and later. torch.backends.cuda.matmul.allow_tf32 = True # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True. torch.backends.cudnn.allow_tf32 = True if True: num_tasks = misc.get_world_size() global_rank = misc.get_rank() neural_rendering_resolution = 128 if args.data_path.endswith('.obj') or args.data_path.endswith('.ply'): data_loader_train = { 'obj_file': args.data_path, 'batch_size': args.batch_size, 'epoch_size': 512, 'texture_path': args.texture_path, } if args.noise_mesh is not None: data_loader_train['noise_mesh'] = args.noise_mesh else: data_loader_train['noise_mesh'] = None elif 'sphere' in args.data_path or 'plane' in args.data_path or 'volume' in args.data_path: data_loader_train = { 'obj_file': None, 'primitive': args.data_path, 'batch_size': args.batch_size, 'epoch_size': 512, 'texture_path': args.texture_path, } if args.noise_mesh is not None: data_loader_train['noise_mesh'] = args.noise_mesh else: data_loader_train['noise_mesh'] = None else: raise NotImplementedError print(data_loader_train) if global_rank == 0 and args.log_dir is not None and not args.eval: os.makedirs(args.log_dir, exist_ok=True) log_writer = SummaryWriter(log_dir=args.log_dir) else: log_writer = None criterion = EDMLoss(dist=args.target) model = models.__dict__[args.model](channels=3 if args.texture_path is None else 6, depth=args.depth) model.to(device) model_without_ddp = model n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) print("Model = %s" % str(model_without_ddp)) print('number of params (M): %.2f' % (n_parameters / 1.e6)) eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() if args.lr is None: # only base_lr is specified args.lr = args.blr * eff_batch_size / 128 
print("base lr: %.2e" % (args.lr * 128 / eff_batch_size)) print("actual lr: %.2e" % args.lr) print("accumulate grad iterations: %d" % args.accum_iter) print("effective batch size: %d" % eff_batch_size) if args.distributed: model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False) model_without_ddp = model.module optimizer = torch.optim.AdamW(model_without_ddp.parameters(), lr=args.lr) loss_scaler = NativeScaler() misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler) print(f"Start training for {args.epochs} epochs") start_time = time.time() max_iou = 0.0 for epoch in range(args.start_epoch, args.epochs): # if args.distributed and args.data_path.endswith('.ply'): # data_loader_train.sampler.set_epoch(epoch) train_stats = train_one_epoch( model, data_loader_train, optimizer, criterion, device, epoch, loss_scaler, args.clip_grad, log_writer=log_writer, args=args ) if args.output_dir and (epoch % 5 == 0 or epoch + 1 == args.epochs): misc.save_model( args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler, epoch=epoch) if epoch % 1 == 0 or epoch + 1 == args.epochs: log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, # **{f'test_{k}': v for k, v in test_stats.items()}, 'epoch': epoch, 'n_parameters': n_parameters} else: log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 'epoch': epoch, 'n_parameters': n_parameters} if args.output_dir and misc.is_main_process(): if log_writer is not None: log_writer.flush() with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f: f.write(json.dumps(log_stats) + "\n") total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print('Training time {}'.format(total_time_str)) if __name__ == '__main__': args = get_args_parser() args = args.parse_args() if args.output_dir: 
Path(args.output_dir).mkdir(parents=True, exist_ok=True) main(args) ================================================ FILE: models.py ================================================ import torch import torch.nn as nn import math import numpy as np import torch.nn.functional import trimesh def modulate(x, shift, scale): return x * (1 + scale) + shift class TimestepEmbedder(nn.Module): """ Embeds scalar timesteps into vector representations. """ def __init__(self, hidden_size, frequency_embedding_size=256): super().__init__() self.mlp = nn.Sequential( nn.Linear(frequency_embedding_size, hidden_size, bias=True), nn.SiLU(), nn.Linear(hidden_size, hidden_size, bias=True), ) self.frequency_embedding_size = frequency_embedding_size @staticmethod def timestep_embedding(t, dim, max_period=10000): """ Create sinusoidal timestep embeddings. :param t: a 1-D Tensor of N indices, one per batch element. These may be fractional. :param dim: the dimension of the output. :param max_period: controls the minimum frequency of the embeddings. :return: an (N, D) Tensor of positional embeddings. 
""" # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py half = dim // 2 freqs = torch.exp( -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half ).to(device=t.device) args = t[:, None].float() * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) return embedding def forward(self, t): t_freq = self.timestep_embedding(t, self.frequency_embedding_size) t_emb = self.mlp(t_freq) return t_emb class MPFourier(torch.nn.Module): def __init__(self, num_channels, bandwidth=1): super().__init__() self.register_buffer('freqs', 2 * np.pi * torch.randn(num_channels) * bandwidth) self.register_buffer('phases', 2 * np.pi * torch.rand(num_channels)) def forward(self, x): y = x.to(torch.float32) y = y.ger(self.freqs.to(torch.float32)) y = y + self.phases.to(torch.float32) y = y.cos() * np.sqrt(2) return y.to(x.dtype) def normalize(x, dim=None, eps=1e-4): if dim is None: dim = list(range(1, x.ndim)) norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32) norm = torch.add(eps, norm, alpha=np.sqrt(norm.numel() / x.numel())) return x / norm.to(x.dtype) def mp_silu(x): return torch.nn.functional.silu(x) / 0.596 def mp_sum(a, b, t=0.5): # print(a.mean(), a.std(), b.mean(), b.std()) return a.lerp(b, t) / np.sqrt((1 - t) ** 2 + t ** 2) class MPConv(torch.nn.Module): def __init__(self, in_channels, out_channels, kernel): super().__init__() self.out_channels = out_channels self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, *kernel)) def forward(self, x, gain=1): w = self.weight.to(torch.float32) if self.training: with torch.no_grad(): self.weight.copy_(normalize(w)) # forced weight normalization w = normalize(w) # traditional weight normalization w = w * (gain / np.sqrt(w[0].numel())) # magnitude-preserving scaling w = w.to(x.dtype) if w.ndim == 2: return x @ w.t() assert w.ndim == 4 
return torch.nn.functional.conv2d(x, w, padding=(w.shape[-1]//2,)) class PointEmbed(nn.Module): def __init__(self, hidden_dim=48, dim=128, other_dim=0): super().__init__() assert hidden_dim % 6 == 0 self.embedding_dim = hidden_dim e = torch.pow(2, torch.arange(self.embedding_dim // 6)).float() * np.pi e = torch.stack([ torch.cat([e, torch.zeros(self.embedding_dim // 6), torch.zeros(self.embedding_dim // 6)]), torch.cat([torch.zeros(self.embedding_dim // 6), e, torch.zeros(self.embedding_dim // 6)]), torch.cat([torch.zeros(self.embedding_dim // 6), torch.zeros(self.embedding_dim // 6), e]), ]) self.register_buffer('basis', e) # 3 x 16 # self.mlp = nn.Linear(self.embedding_dim+3, dim)/ self.mlp = MPConv(self.embedding_dim+3+other_dim, dim, kernel=[]) @staticmethod def embed(input, basis): # print(input.shape, basis.shape) projections = torch.einsum('nd,de->ne', input, basis) embeddings = torch.cat([projections.sin(), projections.cos()], dim=1) return embeddings def forward(self, input): # input: N x 3 if input.shape[1] != 3: input, others = input[:, :3], input[:, 3:] else: others = None if others is None: embed = self.mlp(torch.cat([self.embed(input, self.basis), input], dim=1)) # N x C else: embed = self.mlp(torch.cat([self.embed(input, self.basis), input, others], dim=1)) return embed class Network(nn.Module): def __init__( self, channels = 3, hidden_size = 256, depth = 6, ): super().__init__() self.emb_fourier = MPFourier(hidden_size) self.emb_noise = MPConv(hidden_size, hidden_size, kernel=[]) self.x_embedder = PointEmbed(dim=hidden_size, other_dim=channels-3) self.gains = nn.ParameterList([ torch.nn.Parameter(torch.zeros([])) for _ in range(depth) ]) ## self.layers = nn.ModuleList([ nn.ModuleList([ MPConv(hidden_size, hidden_size, []), MPConv(hidden_size, hidden_size, []), MPConv(hidden_size, 1 * hidden_size, []), ]) for _ in range(depth) ]) self.final_emb_gain = torch.nn.Parameter(torch.zeros([])) self.final_out_gain = torch.nn.Parameter(torch.zeros([])) 
self.final_layer = nn.ModuleList([ MPConv(hidden_size, hidden_size, []), MPConv(hidden_size, channels, []), MPConv(hidden_size, hidden_size, []), ]) self.res_balance = 0.3 def forward(self, x, t): x = self.x_embedder(x) if t.shape[0] == 1: t = t.repeat(x.shape[0]) t = mp_silu(self.emb_noise(self.emb_fourier(t))) for (x_proj_pre, x_proj_post, emb_linear), emb_gain in zip(self.layers, self.gains): c = emb_linear(t, gain=emb_gain) + 1 x = normalize(x) y = x_proj_pre(mp_silu(x)) y = mp_silu(y * c.to(y.dtype)) y = x_proj_post(y) x = mp_sum(x, y, t=self.res_balance) x_proj_pre, x_proj_post, emb_linear = self.final_layer c = emb_linear(t, gain=self.final_emb_gain) + 1 y = x_proj_pre(mp_silu(normalize(x))) y = mp_silu(y * c.to(y.dtype)) out = x_proj_post(y, gain=self.final_out_gain) return out class EDMPrecond(torch.nn.Module): def __init__(self, channels = 3, use_fp16 = False, sigma_min = 0, sigma_max = float('inf'), sigma_data = 1, depth = 6, ): super().__init__() self.use_fp16 = use_fp16 self.sigma_min = sigma_min self.sigma_max = sigma_max self.sigma_data = sigma_data self.model = Network(channels=channels, hidden_size=512, depth=depth) def forward(self, x, sigma, force_fp32=False, **model_kwargs): x = x.to(torch.float32) sigma = sigma.to(torch.float32).reshape(-1, 1) dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32 c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2).sqrt() c_in = 1 / (self.sigma_data ** 2 + sigma ** 2).sqrt() c_noise = sigma.log() / 4 F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), **model_kwargs) assert F_x.dtype == dtype D_x = c_skip * x + c_out * F_x.to(torch.float32) return D_x def round_sigma(self, sigma): return torch.as_tensor(sigma) @torch.no_grad() def sample(self, cond=None, batch_seeds=None, channels=3, num_steps=18): device = batch_seeds.device batch_size = batch_seeds.shape[0] rnd = None 
        # --- tail of a sampling method; its `def` line lies above this chunk ---
        # Treat the per-batch seeds' noise as the initial latents and integrate
        # the EDM probability-flow ODE to obtain surface points.
        points = batch_seeds
        latents = points.float().to(device)
        points = edm_sampler(self, latents, cond, num_steps=num_steps)
        return points

    @torch.no_grad()
    def inverse(self, cond=None, samples=None, channels=3, num_steps=18):
        """Map surface samples back toward the prior by integrating the
        probability-flow ODE in reverse (see inverse_edm_sampler).

        NOTE(review): `channels` is accepted but unused here — presumably kept
        for signature symmetry with the forward sampler; confirm with callers.
        """
        return inverse_edm_sampler(self, samples, cond, num_steps=num_steps)


class StackedRandomGenerator:
    """Per-sample reproducible RNG: element b of every returned batch is drawn
    from its own torch.Generator seeded with seeds[b], so a sample's noise is
    independent of batch composition."""

    def __init__(self, device, seeds):
        super().__init__()
        # Seeds are reduced mod 2**32 to stay within manual_seed's accepted range.
        self.generators = [torch.Generator(device).manual_seed(int(seed) % (1 << 32)) for seed in seeds]

    def randn(self, size, **kwargs):
        # One generator per batch element; size[0] must match the seed count.
        assert size[0] == len(self.generators)
        return torch.stack([torch.randn(size[1:], generator=gen, **kwargs) for gen in self.generators])

    def randn_like(self, input):
        return self.randn(input.shape, dtype=input.dtype, layout=input.layout, device=input.device)

    def randint(self, *args, size, **kwargs):
        assert size[0] == len(self.generators)
        return torch.stack([torch.randint(*args, size=size[1:], generator=gen, **kwargs) for gen in self.generators])


def edm_sampler(
    net, latents, class_labels=None, randn_like=torch.randn_like,
    num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
    S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
):
    """Deterministic 2nd-order (Heun) sampler following the EDM formulation.

    Integrates from sigma_max down to 0 over `num_steps` rho-spaced noise
    levels. Returns (x_next, outputs) where `outputs` collects a snapshot of
    the trajectory after every step, each normalized by sqrt(1 + t^2) and
    moved to CPU numpy.

    NOTE(review): stochastic churn is disabled by the assert below; the gamma
    branch is therefore dead code kept from the reference sampler.
    """
    # disable S_churn
    assert S_churn==0
    # Adjust noise levels based on what's supported by the network.
    sigma_min = max(sigma_min, net.sigma_min)
    sigma_max = min(sigma_max, net.sigma_max)

    # Time step discretization (Karras rho-schedule), computed in float64.
    step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
    t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0

    # Main sampling loop. Latents are scaled to the first (largest) noise level.
    x_next = latents.to(torch.float64) * t_steps[0]
    outputs = []
    outputs.append((x_next / t_steps[0]).detach().cpu().numpy())
    print(t_steps[0])  # debug print left as-is
    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
        print(t_cur, t_next)  # debug print left as-is
        x_cur = x_next

        # Increase noise temporarily. With S_churn == 0 this is a no-op:
        # gamma is 0, so t_hat rounds t_cur and the sqrt term is 0.
        gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
        t_hat = net.round_sigma(t_cur + gamma * t_cur)
        x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)
        # x_hat = x_cur
        # NOTE(review): the source dump is ambiguous about whether the next
        # assignment was active or commented; it is reconstructed here as
        # active, which simply pins t_hat back to t_cur (churn is disabled).
        t_hat = t_cur

        # Euler step.
        denoised = net(x_hat, t_hat, class_labels).to(torch.float64)
        d_cur = (x_hat - denoised) / t_hat
        x_next = x_hat + (t_next - t_hat) * d_cur

        # Apply 2nd order correction (Heun: average of the two slopes),
        # skipped on the final step where t_next == 0 would divide by zero.
        if i < num_steps - 1:
            denoised = net(x_next, t_next, class_labels).to(torch.float64)
            d_prime = (x_next - denoised) / t_next
            x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)

        outputs.append((x_next / (1+t_next**2).sqrt()).detach().cpu().numpy())
    return x_next, outputs


def inverse_edm_sampler(
    net, latents, class_labels=None, randn_like=torch.randn_like,
    num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
    S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
):
    """Reverse counterpart of edm_sampler: integrates the probability-flow ODE
    from near-zero noise up to sigma_max, mapping data samples back toward the
    prior. Returns (x_next, None) — the second element mirrors edm_sampler's
    `outputs` slot but snapshot collection is disabled here.
    """
    # disable S_churn
    assert S_churn==0
    # Adjust noise levels based on what's supported by the network.
    sigma_min = max(sigma_min, net.sigma_min)
    sigma_max = min(sigma_max, net.sigma_max)

    # Time step discretization: same rho-schedule, but the terminal level is
    # 1e-8 instead of 0 (avoids division by zero) and the order is flipped so
    # integration runs from small sigma to large.
    step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
    t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])+1e-8]) # t_N = 0
    t_steps = torch.flip(t_steps, [0])#[1:]

    # Main sampling loop. Unlike the forward sampler, latents are NOT scaled
    # by the first noise level (they are data samples, not unit noise).
    x_next = latents.to(torch.float64)# * t_steps[0]
    # outputs = []
    outputs = None
    # outputs.append((x_next / t_steps[0]).detach().cpu().numpy())
    print(t_steps[0])  # debug prints left as-is
    print(x_next.mean(), x_next.std())
    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
        # print('steps', t_cur, t_next)
        x_cur = x_next
        # print('cur', (x_cur / t_cur).mean(), (x_cur / t_cur).std())

        # Increase noise temporarily — then immediately overwritten below, so
        # these two lines are effectively dead (churn disabled by the assert).
        gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
        t_hat = net.round_sigma(t_cur + gamma * t_cur)
        x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)
        x_hat = x_cur
        t_hat = t_cur

        # Euler step.
        denoised = net(x_hat, t_hat, class_labels).to(torch.float64)
        d_cur = (x_hat - denoised) / t_hat
        x_next = x_hat + (t_next - t_hat) * d_cur

        # Apply 2nd order correction.
        if i < num_steps - 1:
            denoised = net(x_next, t_next, class_labels).to(torch.float64)
            d_prime = (x_next - denoised) / t_next
            x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)

        print('next', (x_next / (1+t_next**2).sqrt()).mean(), (x_next / (1+t_next**2).sqrt()).std())
        # outputs.append((x_next / (1+t_next**2).sqrt()).detach().cpu().numpy())
    # Final normalization by sqrt(1 + t^2) so the result matches the prior scale.
    x_next = x_next / (1+t_next**2).sqrt()

    return x_next, outputs


class EDMLoss:
    """EDM denoising loss with a configurable noise source.

    Draws a per-sample sigma from a log-normal (P_mean, P_std), corrupts the
    input with noise from the chosen distribution `dist`, and weights the
    squared denoising error by (sigma^2 + sigma_data^2) / (sigma*sigma_data)^2.
    """

    def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=1, dist='Gaussian'):
        self.P_mean = P_mean      # mean of log-sigma
        self.P_std = P_std        # std of log-sigma
        self.sigma_data = sigma_data  # assumed data std for loss weighting
        self.dist = dist          # noise family: Gaussian/Uniform/Sphere/Mesh

    def __call__(self, net, inputs, labels=None, augment_pipe=None, init_noise=None):
        # One sigma per batch element, log-normally distributed.
        rnd_normal = torch.randn([inputs.shape[0],], device=inputs.device)
        sigma = (rnd_normal * self.P_std + self.P_mean).exp()
        weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2
        y, augment_labels = augment_pipe(inputs) if augment_pipe is not None else (inputs, None)
        if self.dist == 'Gaussian':
            # Gaussian noise on the first 3 (coordinate) channels; any extra
            # channels (e.g. color) get unit-variance uniform noise instead.
            n = torch.randn_like(y[:, :3]) * sigma[:, None]
            if y.shape[1] != 3:
                # (U-0.5)/sqrt(1/12) has unit variance.
                c = (torch.rand_like(y[:, 3:]) - 0.5) / np.sqrt(1/12) * sigma[:, None]
                n = torch.cat([n, c], dim=1)
        elif self.dist == 'Uniform':
            n = (torch.rand_like(y) - 0.5) / np.sqrt(1/12) * sigma[:, None]
        elif self.dist == 'Sphere':
            # Uniform directions on the unit sphere, rescaled to unit variance
            # per coordinate (a unit sphere point has per-axis variance 1/3).
            n = torch.randn_like(y[:, :3])
            n = torch.nn.functional.normalize(n, dim=1)
            n /= np.sqrt(1/3)
            n = n * sigma[:, None]
        elif self.dist == "Mesh":
            # Caller supplies the base noise (e.g. sampled from a mesh prior).
            assert init_noise is not None
            n = init_noise * sigma[:, None]
        else:
            raise NotImplementedError
        D_yn = net(y + n, sigma)
        loss = weight[:, None] * ((D_yn - y) ** 2)
        return loss.mean()

# ================================================
# FILE: normalize.py
# ================================================
# Script: center/scale a mesh so that densely-sampled surface points have
# zero mean and unit std, then export it.

import argparse
import trimesh
import math
import glob
import numpy as np

parser = argparse.ArgumentParser('Inference', add_help=False)
parser.add_argument('--path', required=True, type=str)
parser.add_argument('--output', required=True, type=str)
args = parser.parse_args()

model = trimesh.load(args.path, process=False)

def normalize_meshes(mesh):
    """Normalize `mesh` in place: bbox-center, scale into [-0.99, 0.99], then
    re-center/re-scale by the statistics of 10M sampled surface points."""
    # Center on the bounding-box midpoint, then fit into the unit cube.
    mesh.vertices -= (mesh.vertices.max(axis=0) + mesh.vertices.min(axis=0)) / 2
    scale = (1 / np.abs(mesh.vertices).max()) * 0.99
    mesh.vertices *= scale

    points, _ = trimesh.sample.sample_surface(mesh, 10000000)
    # NOTE(review): points.mean() / points.std() are scalars over ALL
    # coordinates, not per-axis — presumably intentional (isotropic
    # normalization), but confirm; mean(axis=0) would center each axis.
    mesh.vertices -= points.mean()
    mesh.vertices /= points.std()
    return mesh

model = normalize_meshes(model)

# angle = math.pi / 2
# direction = [1, 0, 0]
# center = [0, 0, 0]
# rot_matrix = trimesh.transformations.rotation_matrix(angle, direction, center)
# model.apply_transform(rot_matrix)

model.export(args.output)

# ================================================
# FILE: points.py
# ================================================

import trimesh
import numpy as np
import os

import torch
from torch.utils import data

class Points(data.Dataset):
    """Dataset over the vertices of a point-cloud/mesh file: one 3D point per item."""

    def __init__(self, ply_path):
        points = trimesh.load(ply_path).vertices
        # self.points = np.array(points)
        # if os.path.exists('test.npy'):
        #     points = np.load('test.npy')
        # else:
        #     points, _ = trimesh.sample.sample_surface(trimesh.load(ply_path), 50000000*5)
        #     np.save('test.npy', points)
        # Vertices kept as a single tensor; assumed already normalized upstream
        # (see normalize.py) — the print below is a sanity check of that.
        self.points = torch.from_numpy(points)# - 0.12
        print(self.points.std(), self.points.mean())

    def __len__(self):
        return self.points.shape[0]# * 16

    def __getitem__(self, idx):
        # idx = idx % self.points.shape[0]
        return self.points[idx]

# ================================================
# FILE: util/lr_decay.py
# ================================================
# --------------------------------------------------------
# References:
# MAE:
# https://github.com/facebookresearch/mae
# DeiT: https://github.com/facebookresearch/deit
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
# --------------------------------------------------------

import json


def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=(), layer_decay=.75):
    """Build optimizer parameter groups with layer-wise lr decay.

    Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58

    Args:
        model: module exposing ``model.blocks`` (the transformer blocks) and
            ``named_parameters()``.
        weight_decay: decay applied to >=2D parameters not in the exclude list.
        no_weight_decay_list: parameter names exempt from weight decay.
            (Fixed: was a mutable default ``[]``; an immutable tuple is safe
            to share across calls and membership-tests identically.)
        layer_decay: multiplicative lr decay per layer; deeper layers get
            larger lr scales.

    Returns:
        list of param-group dicts, each with ``lr_scale``, ``weight_decay``
        and ``params`` keys, suitable for passing to a torch optimizer.
    """
    param_group_names = {}
    param_groups = {}

    # +1 accounts for the embedding "layer" before the first block.
    num_layers = len(model.blocks) + 1

    # layer_scales[i] = layer_decay ** (num_layers - i): deepest layer gets 1.0.
    layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1))

    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue

        # no decay: all 1D parameters (biases, norms) and model-specific ones
        if p.ndim == 1 or n in no_weight_decay_list:
            g_decay = "no_decay"
            this_decay = 0.
        else:
            g_decay = "decay"
            this_decay = weight_decay

        layer_id = get_layer_id_for_vit(n, num_layers)
        group_name = "layer_%d_%s" % (layer_id, g_decay)

        if group_name not in param_group_names:
            this_scale = layer_scales[layer_id]

            # Parallel dict keyed by names only, kept for the debug dump below.
            param_group_names[group_name] = {
                "lr_scale": this_scale,
                "weight_decay": this_decay,
                "params": [],
            }
            param_groups[group_name] = {
                "lr_scale": this_scale,
                "weight_decay": this_decay,
                "params": [],
            }

        param_group_names[group_name]["params"].append(n)
        param_groups[group_name]["params"].append(p)

    # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))

    return list(param_groups.values())


def get_layer_id_for_vit(name, num_layers):
    """Assign a ViT parameter name to a layer id in [0, num_layers].

    Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33

    Embedding-level parameters map to 0, ``blocks.<k>.*`` maps to k+1, and
    everything else (e.g. the head) maps to num_layers.
    """
    if name in ['cls_token', 'pos_embed']:
        return 0
    elif name.startswith('patch_embed'):
        return 0
    elif name.startswith('blocks'):
        return int(name.split('.')[1]) + 1
    else:
        return num_layers

# ================================================
# FILE: util/lr_sched.py
# ================================================
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math

def adjust_learning_rate(optimizer, epoch, args):
    """Decay the learning rate with half-cycle cosine after warmup.

    Linear warmup for `args.warmup_epochs`, then cosine decay from `args.lr`
    down to `args.min_lr`. Groups carrying an ``lr_scale`` (see
    param_groups_lrd) get the scaled lr. Returns the unscaled base lr.
    """
    if epoch < args.warmup_epochs:
        lr = args.lr * epoch / args.warmup_epochs
    else:
        lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
            (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
    for param_group in optimizer.param_groups:
        if "lr_scale" in param_group:
            param_group["lr"] = lr * param_group["lr_scale"]
        else:
            param_group["lr"] = lr
    return lr

# ================================================
# FILE: util/misc.py
# ================================================
# --------------------------------------------------------
# References:
# MAE: https://github.com/facebookresearch/mae
# DeiT: https://github.com/facebookresearch/deit
# BEiT: https://github.com/microsoft/unilm/tree/master/beit
# --------------------------------------------------------

import builtins
import datetime
import os
import time
from collections import defaultdict, deque
from pathlib import Path

import torch
import torch.distributed as dist

# torch >= 2 moved `inf` out of the private torch._six shim.
# NOTE(review): this version check compares only the first character of the
# version string — fine for 1.x/2.x but would misread a hypothetical "10.x".
if torch.__version__[0] == '2':
    from torch import inf
else:
    from torch._six import inf


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.5f} ({global_avg:.5f})"
        self.deque = deque(maxlen=window_size)  # sliding window of raw values
        self.total = 0.0                        # running sum over ALL updates
        self.count = 0                          # running count over ALL updates
        self.fmt = fmt

    def update(self, value, n=1):
        # `n` weights the value in the global average but the window keeps
        # only the single value.
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        # Only count/total (i.e. global_avg) are all-reduced across ranks;
        # window statistics stay local.
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        # Most recent raw value.
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)


class MetricLogger(object):
    """Collection of named SmoothedValue meters plus a timed logging iterator."""

    def __init__(self, delimiter="\t"):
        # Meters auto-create on first update via defaultdict.
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if v is None:
                continue
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Expose meters as attributes (logger.loss), falling back to __dict__.
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield items from `iterable`, printing timing/meter stats every
        `print_freq` iterations (and on the last one). Requires len(iterable).
        """
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')  # time per full iteration
        data_time = SmoothedValue(fmt='{avg:.4f}')  # time waiting for data
        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        if torch.cuda.is_available():
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, len(iterable), eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / len(iterable)))


def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    # Monkey-patches builtins.print process-wide: non-master ranks become
    # silent; master prints with a timestamp prefix.
    builtin_print = builtins.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        # NOTE(review): `force` is computed but unused — the `or force` clause
        # below is commented out, so force=True cannot re-enable printing.
        force = force or (get_world_size() > 8)
        if is_master:# or force:
            now = datetime.datetime.now().time()
            builtin_print('[{}] '.format(now), end='')  # print with time stamp
            builtin_print(*args, **kwargs)

    builtins.print = print


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    # 1 when not running distributed.
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    # 0 when not running distributed.
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    # torch.save only on rank 0 to avoid concurrent writes.
    if is_main_process():
        torch.save(*args, **kwargs)


def init_distributed_mode(args):
    """Initialize torch.distributed from whichever launcher set the env:
    OMPI (--dist_on_itp), torchrun/torch.distributed (RANK/WORLD_SIZE), or
    SLURM. Falls back to single-process mode otherwise. Mutates `args` with
    rank/world_size/gpu/distributed/dist_backend and patches print().
    """
    if args.dist_on_itp:
        args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
        args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
        args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
        args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])
        # Re-export in the torchrun-style variable names for downstream code.
        os.environ['LOCAL_RANK'] = str(args.gpu)
        os.environ['RANK'] = str(args.rank)
        os.environ['WORLD_SIZE'] = str(args.world_size)
        # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
    elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        setup_for_distributed(is_master=True)  # hack
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}, gpu {}'.format(
        args.rank, args.dist_url, args.gpu), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)


class NativeScalerWithGradNormCount:
    """Wraps torch.cuda.amp.GradScaler: scaled backward, optional gradient
    clipping, and returns the grad norm (None when update_grad=False)."""

    state_dict_key = "amp_scaler"

    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()

    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if update_grad:
            if clip_grad is not None:
                assert parameters is not None
                self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
                norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
            else:
                # Still unscale so the reported norm is in true gradient units.
                self._scaler.unscale_(optimizer)
                norm = get_grad_norm_(parameters)
            self._scaler.step(optimizer)
            self._scaler.update()
        else:
            # Gradient accumulation step: no optimizer step, no norm.
            norm = None
        return norm

    def state_dict(self):
        return self._scaler.state_dict()

    def load_state_dict(self, state_dict):
        self._scaler.load_state_dict(state_dict)


def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
    """Total norm of the gradients of `parameters` (tensor or iterable),
    mirroring clip_grad_norm_'s norm computation without clipping.
    Parameters without gradients are skipped; returns 0 if none remain.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if len(parameters) == 0:
        return torch.tensor(0.)
    device = parameters[0].grad.device
    if norm_type == inf:
        total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
    else:
        total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
    return total_norm


def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler):
    """Save a checkpoint (model/optimizer/scaler/epoch/args) on the master
    rank; falls back to deepspeed-style model.save_checkpoint when no scaler."""
    output_dir = Path(args.output_dir)
    epoch_name = str(epoch)
    if loss_scaler is not None:
        checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
        for checkpoint_path in checkpoint_paths:
            to_save = {
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
                'scaler': loss_scaler.state_dict(),
                'args': args,
            }

            save_on_master(to_save, checkpoint_path)
    else:
        client_state = {'epoch': epoch}
        model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state)


def load_model(args, model_without_ddp, optimizer, loss_scaler):
    """Resume from args.resume (local path or https URL): restores model
    weights always; optimizer/epoch/scaler only when present and not in eval
    mode. Mutates args.start_epoch."""
    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        print("Resume checkpoint %s" % args.resume)
        if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
            optimizer.load_state_dict(checkpoint['optimizer'])
            args.start_epoch = checkpoint['epoch'] + 1
            if 'scaler' in checkpoint:
                loss_scaler.load_state_dict(checkpoint['scaler'])
            print("With optim & sched!")


def all_reduce_mean(x):
    """Average a scalar across all distributed ranks; identity when world_size==1."""
    world_size = get_world_size()
    if world_size > 1:
        x_reduce = torch.tensor(x).cuda()
        dist.all_reduce(x_reduce)
        x_reduce /= world_size
        return x_reduce.item()
    else:
        return x