[
  {
    "path": ".gitignore",
    "content": "*.ipynb\n*.ply\n.ipynb_checkpoints\ntest*.*\n*.stl\nvis.py\n*.obj\n__pycache__\n*_dataset\noutput/\nbackup.py\nshapes/\nsamples/"
  },
  {
    "path": "README.md",
    "content": "# Geometry Distributions\n\n### [Project Page](https://1zb.github.io/GeomDist/) | [Paper (arXiv)](https://arxiv.org/abs/2411.16076)\n\n### :bullettrain_front: Training\n\n```\ntorchrun --nproc_per_node=4 main.py --blr 5e-7 --output_dir output/loong --log_dir output/loong --data_path shapes/loong.obj\n```\n\n### :balloon: Inference\n\n```\npython infer.py --pth output/loong/checkpoint-999.pth --target Gaussian --num-steps 64 --output samples/loong --N 10000000\n```\n\n### :floppy_disk: Datasets\nhttps://huggingface.co/datasets/Zbalpha/shapes\n\n### :briefcase: Checkpoints\nhttps://huggingface.co/Zbalpha/geom_dist_ckpt\n\n## :e-mail: Contact\n\nContact [Biao Zhang](mailto:biao.zhang@kaust.edu.sa) ([@1zb](https://github.com/1zb)) if you have any further questions. This repository is for academic research use only.\n\n## :blue_book: Citation\n\narXiv\n```bibtex\n@article{zhang2024geometry,\n  title={Geometry Distributions},\n  author={Zhang, Biao and Ren, Jing and Wonka, Peter},\n  journal={arXiv preprint arXiv:2411.16076},\n  year={2024}\n}\n```\n\nICCV\n```bibtex\n@InProceedings{Zhang_2025_ICCV,\n    author    = {Zhang, Biao and Ren, Jing and Wonka, Peter},\n    title     = {Geometry Distributions},\n    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},\n    month     = {October},\n    year      = {2025},\n    pages     = {1495--1505}\n}\n```\n"
  },
  {
    "path": "engine.py",
    "content": "# --------------------------------------------------------\n# References:\n# MAE: https://github.com/facebookresearch/mae\n# DeiT: https://github.com/facebookresearch/deit\n# BEiT: https://github.com/microsoft/unilm/tree/master/beit\n# --------------------------------------------------------\n\nimport math\nimport sys\nfrom typing import Iterable\n\nimport torch\nimport torch.nn.functional as F\n\nimport numpy as np\n\nimport util.misc as misc\nimport util.lr_sched as lr_sched\n\nfrom torch.autograd import Variable\nfrom math import exp\n\nfrom einops import rearrange, repeat\n\nimport trimesh\n\nfrom PIL import Image\n\ndef train_one_epoch(model: torch.nn.Module,\n                    data_loader, optimizer: torch.optim.Optimizer,\n                    criterion,\n                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,\n                    log_writer=None, args=None):\n    model.train(True)\n    metric_logger = misc.MetricLogger(delimiter=\"  \")\n    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))\n    header = 'Epoch: [{}]'.format(epoch)\n    print_freq = 20\n    \n    accum_iter = args.accum_iter\n\n    optimizer.zero_grad()\n\n    if log_writer is not None:\n        print('log_dir: {}'.format(log_writer.log_dir))\n    \n    print(data_loader)\n\n    noise = None\n\n    if isinstance(data_loader, dict):\n        obj_file = data_loader['obj_file']\n        batch_size = data_loader['batch_size']\n\n        if obj_file is not None:\n            if obj_file.endswith('.obj'):\n                mesh = trimesh.load(obj_file)\n                if data_loader['texture_path'] is not None:\n                    img = Image.open(data_loader['texture_path'])\n                    material = trimesh.visual.texture.SimpleMaterial(image=img)\n                    assert mesh.visual.uv is not None\n                    texture = trimesh.visual.TextureVisuals(mesh.visual.uv, image=img, 
material=material)\n                    mesh = trimesh.Trimesh(vertices=mesh.vertices, faces=mesh.faces, visual=texture, process=False)\n\n                    samples, _, colors = trimesh.sample.sample_surface(mesh,  2048*64*4*64, sample_color=True)\n                    colors = colors[:, :3] # remove alpha\n                    colors = (colors.astype(np.float32) / 255.0 - 0.5)  / np.sqrt(1/12) # [-1, 1]\n                    samples = np.concatenate([samples, colors], axis=1)\n                else:\n                    samples, _ = trimesh.sample.sample_surface(mesh,  2048*64*4*64)\n            else:\n                samples = trimesh.load(obj_file).vertices\n\n        else:\n            if data_loader['primitive'] == 'sphere':\n                n = torch.randn(2048*64*4*64, 3)\n                n = torch.nn.functional.normalize(n, dim=1)\n                samples = n / np.sqrt(1/3)\n                samples = samples.numpy()\n            elif data_loader['primitive'] == 'plane':\n                samples = torch.rand(2048*64*4*64, 3) - 0.5\n                samples[:, 2] = 0\n                samples = (samples - 0) / np.sqrt(2/9*2*0.5**3)\n                samples = samples.numpy()\n            elif data_loader['primitive'] == 'volume':\n                samples = (torch.rand(2048*64*4*64, 3) - 0.5) / np.sqrt(1/12) \n                samples = samples.numpy()\n            elif data_loader['primitive'] == 'gaussian':\n                samples = np.random.randn(2048*64*4*64, 3).astype(np.float32)\n            else:\n                raise NotImplementedError\n\n        if data_loader['noise_mesh'] is not None:\n            noise, _ = trimesh.sample.sample_surface(trimesh.load(data_loader['noise_mesh']),  2048*64*4*64)\n        else:\n            noise = None\n\n        samples = samples.astype(np.float32)# - 0.12\n        data_loader = range(data_loader['epoch_size'])\n\n    for data_iter_step, batch in enumerate(metric_logger.log_every(data_loader, print_freq, header)):\n\n   
     # we use a per iteration (instead of per epoch) lr scheduler\n        if data_iter_step % accum_iter == 0:\n            lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)\n\n        if isinstance(batch, int):\n            ind = np.random.default_rng().choice(samples.shape[0], batch_size, replace=True)\n            xyz = samples[ind]\n            xyz = torch.from_numpy(xyz).float().to(device, non_blocking=True)\n        else:\n            xyz = batch.to(device, non_blocking=True)\n\n        with torch.cuda.amp.autocast(enabled=False):\n            if noise is not None:\n                ind = np.random.default_rng().choice(noise.shape[0], batch_size, replace=True)\n                init_noise = noise[ind]\n                init_noise = torch.from_numpy(init_noise).float().to(device, non_blocking=True)\n            else:\n                init_noise = None\n            loss = criterion(model, xyz, init_noise=init_noise)\n            \n        loss_value = loss.item()\n\n        if not math.isfinite(loss_value):\n            print(\"Loss is {}, stopping training\".format(loss_value))\n            sys.exit(1)\n\n        loss /= accum_iter\n        loss_scaler(loss, optimizer, clip_grad=max_norm,\n                    parameters=model.parameters(), create_graph=False,\n                    update_grad=(data_iter_step + 1) % accum_iter == 0)\n        if (data_iter_step + 1) % accum_iter == 0:\n            optimizer.zero_grad()\n\n        torch.cuda.synchronize()\n\n        metric_logger.update(loss=loss_value)\n\n        min_lr = 10.\n        max_lr = 0.\n        for group in optimizer.param_groups:\n            min_lr = min(min_lr, group[\"lr\"])\n            max_lr = max(max_lr, group[\"lr\"])\n\n        metric_logger.update(lr=max_lr)\n\n        loss_value_reduce = misc.all_reduce_mean(loss_value)\n        if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:\n            \"\"\" We use epoch_1000x as the x-axis in 
tensorboard.\n            This calibrates different curves when batch size changes.\n            \"\"\"\n            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)\n            log_writer.add_scalar('loss', loss_value_reduce, epoch_1000x)\n            log_writer.add_scalar('lr', max_lr, epoch_1000x)\n\n    # gather the stats from all processes\n    metric_logger.synchronize_between_processes()\n    print(\"Averaged stats:\", metric_logger)\n    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}\n"
  },
  {
    "path": "eval.py",
    "content": "\nimport trimesh\nfrom scipy.spatial import cKDTree as KDTree\nimport numpy as np\n\nimport argparse\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--ply', required=True, type=str)\nparser.add_argument('--reference', required=True, type=str)\nparser.add_argument('--scale', required=True, type=str)\nargs = parser.parse_args()\n\nscale = np.load(args.scale)\n\nprediction = trimesh.load(args.ply).vertices * scale\nreference = trimesh.load(args.reference).vertices * scale\n\ntree = KDTree(prediction)\ndist, _ = tree.query(reference)\nd1 = dist\ngt_to_gen_chamfer = np.mean(dist)\ngt_to_gen_chamfer_sq = np.mean(np.square(dist))\n\ntree = KDTree(reference)\ndist, _ = tree.query(prediction)\nd2 = dist\ngen_to_gt_chamfer = np.mean(dist)\ngen_to_gt_chamfer_sq = np.mean(np.square(dist))\n\ncd = gt_to_gen_chamfer + gen_to_gt_chamfer\nprint(cd)"
  },
  {
    "path": "infer.py",
    "content": "import argparse \nfrom pathlib import Path\nimport os\n\nimport torch\n\nimport trimesh\n\nfrom models import EDMPrecond\n\ntorch.manual_seed(0)\n\nimport numpy as np\nnp.random.seed(0)\n\nimport random\nrandom.seed(0)\n\n# The flag below controls whether to allow TF32 on matmul. This flag defaults to False\n# in PyTorch 1.12 and later.\ntorch.backends.cuda.matmul.allow_tf32 = True\n\n# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.\ntorch.backends.cudnn.allow_tf32 = True\n\nparser = argparse.ArgumentParser('Inference', add_help=False)\nparser.add_argument('--pth', required=True, type=str)\nparser.add_argument('--texture', action='store_true')\nparser.add_argument('--target', default='Gaussian', type=str)\nparser.add_argument('--N', default=1000000, type=int)\nparser.add_argument('--num-steps', default=64, type=int)\nparser.add_argument('--noise_mesh', default=None, type=str)\nparser.add_argument('--output', required=True, type=str)\nparser.add_argument('--intermediate', action='store_true')\nparser.add_argument('--depth', default=6, type=int)\nparser.set_defaults(texture=False)\nparser.set_defaults(intermediate=False)\n\nargs = parser.parse_args()\n\nPath(args.output).mkdir(parents=True, exist_ok=True)\n\nif args.texture:\n    model = EDMPrecond(channels=6, depth=args.depth).cuda()\nelse:\n    model = EDMPrecond(depth=args.depth).cuda()\n\nmodel.load_state_dict(torch.load(args.pth, map_location='cpu')['model'], strict=True)\n\nif args.target == 'Gaussian':\n    noise = torch.randn(args.N, 3).cuda()\nelif args.target == 'Uniform':\n    noise = (torch.rand(args.N, 3).cuda() - 0.5) / np.sqrt(1/12)\nelif args.target == 'Sphere':\n    n = torch.randn(args.N, 3).cuda()\n    n = torch.nn.functional.normalize(n, dim=1)\n    noise = n / np.sqrt(1/3)\nelif args.target == 'Mesh':\n    assert args.noise_mesh is not None\n    noise, _ = trimesh.sample.sample_surface(trimesh.load(args.noise_mesh), args.N)\n    noise = 
torch.from_numpy(noise).float().cuda()\nelse:\n    raise NotImplementedError\n\nif args.texture:\n    color = (torch.rand(args.N, 3).cuda() - 0.5) / np.sqrt(1/12)\n    noise = torch.cat([noise, color], dim=1)\n\nsample, intermediate_steps = model.sample(batch_seeds=noise, num_steps=args.num_steps)\n\nif args.texture:\n    sample = sample.detach().cpu().numpy()\n    vertices, colors = sample[:, :3], sample[:, 3:]\n    colors = (colors * np.sqrt(1/12) + 0.5) * 255.0\n    colors = np.concatenate([colors, np.ones_like(colors[:, 0:1]) * 255.0], axis=1).astype(np.uint8) # alpha channel\n    trimesh.PointCloud(vertices, colors).export(os.path.join(args.output, 'sample.ply'))\n\n    if args.intermediate:\n        for i, s in enumerate(intermediate_steps):\n            vertices, colors = s[:, :3], s[:, 3:]\n            colors = (colors * np.sqrt(1/12) + 0.5) * 255.0\n            colors = np.concatenate([colors, np.ones_like(colors[:, 0:1]) * 255.0], axis=1).astype(np.uint8) # alpha channel\n\n            trimesh.PointCloud(vertices, colors).export(os.path.join(args.output, 'sample-{:03d}.ply'.format(i)))\n\nelse:\n    trimesh.PointCloud(sample.detach().cpu().numpy()).export(os.path.join(args.output, 'sample.ply'))\n\n    if args.intermediate:\n        for i, s in enumerate(intermediate_steps):\n            trimesh.PointCloud(s).export(os.path.join(args.output, 'sample-{:03d}.ply'.format(i)))\n"
  },
  {
    "path": "inverese.py",
    "content": "import argparse \n\nimport torch\n\nimport trimesh\n\nfrom models import EDMPrecond\n\ntorch.manual_seed(0)\n\nimport numpy as np\nnp.random.seed(0)\n\nimport random\nrandom.seed(0)\n\n# The flag below controls whether to allow TF32 on matmul. This flag defaults to False\n# in PyTorch 1.12 and later.\ntorch.backends.cuda.matmul.allow_tf32 = True\n\n# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.\ntorch.backends.cudnn.allow_tf32 = True\n\nparser = argparse.ArgumentParser('Inference', add_help=False)\nparser.add_argument('--pth', default='output/lamp_cube/checkpoint-0.pth', type=str)\nparser.add_argument('--texture', action='store_true')\nparser.add_argument('--N', default=1000000, type=int)\nparser.add_argument('--num-steps', default=64, type=int)\nparser.add_argument('--noise_mesh', default=None, type=str)\nparser.add_argument('--data_path', default='shapes/Jellyfish_lamp_part_A__B_normalized.obj', type=str)\nparser.set_defaults(texture=False)\n\nargs = parser.parse_args()\n\nif args.texture:\n    model = EDMPrecond(channels=6).cuda()\nelse:\n    model = EDMPrecond().cuda()\n\nmesh = trimesh.load(args.data_path)\nsamples, _ = trimesh.sample.sample_surface(mesh,  args.N)\nsamples = samples.astype(np.float32)\nsamples = torch.from_numpy(samples).float().cuda()\n\n    \nmodel.load_state_dict(torch.load(args.pth, map_location='cpu')['model'], strict=True)\n\nsample, intermediate_steps = model.inverse(samples=samples, num_steps=args.num_steps)\n\nif args.texture:\n    sample = sample.detach().cpu().numpy()\n    vertices, colors = sample[:, :3], sample[:, 3:]\n    colors = (colors * np.sqrt(1/12) + 0.5) * 255.0\n    colors = np.concatenate([colors, np.ones_like(colors[:, 0:1]) * 255.0], axis=1).astype(np.uint8) # alpha channel\n    trimesh.PointCloud(vertices, colors).export('sample.ply')\n\n    for i, s in enumerate(intermediate_steps):\n        vertices, colors = s[:, :3], s[:, 3:]\n        colors = (colors * 
np.sqrt(1/12) + 0.5) * 255.0\n        colors = np.concatenate([colors, np.ones_like(colors[:, 0:1]) * 255.0], axis=1).astype(np.uint8) # alpha channel\n\n        trimesh.PointCloud(vertices, colors).export('sample-{:03d}.ply'.format(i))\n\nelse:\n    trimesh.PointCloud(sample.detach().cpu().numpy()).export('sample.ply')\n\n    for i, s in enumerate(intermediate_steps):\n        trimesh.PointCloud(s).export('sample-{:03d}.ply'.format(i))\n"
  },
  {
    "path": "main.py",
    "content": "import argparse\nimport datetime\nimport json\nimport numpy as np\nimport os\nimport time\nfrom pathlib import Path\n\nimport torch\nimport torch.backends.cudnn as cudnn\nfrom torch.utils.tensorboard import SummaryWriter\n\ntorch.set_num_threads(8)\nimport util.lr_decay as lrd\nimport util.misc as misc\nfrom util.misc import NativeScalerWithGradNormCount as NativeScaler\n\nimport models as models\nfrom models import EDMLoss\n\nfrom engine import train_one_epoch\n\nfrom points import Points\n\n\ndef get_args_parser():\n    parser = argparse.ArgumentParser('Train', add_help=False)\n    parser.add_argument('--batch_size', default=2048*64*2, type=int,\n                        help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus')\n    parser.add_argument('--epochs', default=1000, type=int)\n    parser.add_argument('--accum_iter', default=1, type=int,\n                        help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')\n    \n\n    # Model parameters\n    parser.add_argument('--model', default='EDMPrecond', type=str, metavar='MODEL',\n                        help='Name of model to train')\n    parser.add_argument('--depth', default=6, type=int, metavar='MODEL')\n\n    # Optimizer parameters\n    parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',\n                        help='Clip gradient norm (default: None, no clipping)')\n    parser.add_argument('--weight_decay', type=float, default=0.05,\n                        help='weight decay (default: 0.05)')\n\n    parser.add_argument('--lr', type=float, default=None, metavar='LR',\n                        help='learning rate (absolute lr)')\n    parser.add_argument('--blr', type=float, default=5e-7, metavar='LR',\n                        help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')\n    parser.add_argument('--layer_decay', type=float, default=0.75,\n              
          help='layer-wise lr decay from ELECTRA/BEiT')\n\n    parser.add_argument('--min_lr', type=float, default=5e-7, metavar='LR',\n                        help='lower lr bound for cyclic schedulers that hit 0')\n\n    parser.add_argument('--warmup_epochs', type=int, default=1, metavar='N',\n                        help='epochs to warmup LR')\n\n    # Dataset parameters\n    parser.add_argument('--target', default='Gaussian', type=str, )\n    parser.add_argument('--data_path', default='shapes/Jellyfish_lamp_part_A__B_normalized.obj', type=str,\n                        help='dataset path')\n\n    parser.add_argument('--texture_path', default=None, type=str,\n                        help='dataset path')\n\n    parser.add_argument('--noise_mesh', default=None, type=str,\n                        help='dataset path')\n     \n    parser.add_argument('--output_dir', default='./output/',\n                        help='path where to save, empty for no saving')\n    parser.add_argument('--log_dir', default='./output/',\n                        help='path where to tensorboard log')\n    parser.add_argument('--device', default='cuda',\n                        help='device to use for training / testing')\n    parser.add_argument('--seed', default=0, type=int)\n    parser.add_argument('--resume', default='',\n                        help='resume from checkpoint')\n\n    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',\n                        help='start epoch')\n    parser.add_argument('--eval', action='store_true',\n                        help='Perform evaluation only')\n    parser.add_argument('--dist_eval', action='store_true', default=False,\n                        help='Enabling distributed evaluation (recommended during training for faster monitor')\n    parser.add_argument('--num_workers', default=32, type=int)\n    parser.add_argument('--pin_mem', action='store_true',\n                        help='Pin CPU memory in DataLoader for more 
efficient (sometimes) transfer to GPU.')\n    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')\n    parser.set_defaults(pin_mem=True)\n\n    # distributed training parameters\n    parser.add_argument('--world_size', default=1, type=int,\n                        help='number of distributed processes')\n    parser.add_argument('--local_rank', default=-1, type=int)\n    parser.add_argument('--dist_on_itp', action='store_true')\n    parser.add_argument('--dist_url', default='env://',\n                        help='url used to set up distributed training')\n\n    return parser\n\ndef main(args):\n\n    # os.environ['CUDA_LAUNCH_BLOCKING'] = '1'\n\n    misc.init_distributed_mode(args)\n\n    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))\n    print(\"{}\".format(args).replace(', ', ',\\n'))\n\n    device = torch.device(args.device)\n\n    # fix the seed for reproducibility\n    seed = args.seed + misc.get_rank()\n    torch.manual_seed(seed)\n    np.random.seed(seed)\n\n    cudnn.benchmark = True\n    cudnn.deterministic=True\n\n    # The flag below controls whether to allow TF32 on matmul. This flag defaults to False\n    # in PyTorch 1.12 and later.\n    torch.backends.cuda.matmul.allow_tf32 = True\n\n    # The flag below controls whether to allow TF32 on cuDNN. 
This flag defaults to True.\n    torch.backends.cudnn.allow_tf32 = True\n\n    if True:\n        num_tasks = misc.get_world_size()\n        global_rank = misc.get_rank()\n\n\n    neural_rendering_resolution = 128\n    if args.data_path.endswith('.obj') or args.data_path.endswith('.ply'):\n        data_loader_train = {\n            'obj_file': args.data_path,\n            'batch_size': args.batch_size,\n            'epoch_size': 512,\n            'texture_path': args.texture_path,\n        }\n        if args.noise_mesh is not None:\n            data_loader_train['noise_mesh'] = args.noise_mesh\n        else:\n            data_loader_train['noise_mesh'] = None\n    elif 'sphere' in args.data_path or 'plane' in args.data_path or 'volume' in args.data_path:\n        data_loader_train = {\n            'obj_file': None,\n            'primitive': args.data_path,\n            'batch_size': args.batch_size,\n            'epoch_size': 512,\n            'texture_path': args.texture_path,\n        }\n        if args.noise_mesh is not None:\n            data_loader_train['noise_mesh'] = args.noise_mesh\n        else:\n            data_loader_train['noise_mesh'] = None\n    else:\n        raise NotImplementedError\n    print(data_loader_train)\n\n    if global_rank == 0 and args.log_dir is not None and not args.eval:\n        os.makedirs(args.log_dir, exist_ok=True)\n        log_writer = SummaryWriter(log_dir=args.log_dir)\n    else:\n        log_writer = None\n\n\n    criterion = EDMLoss(dist=args.target)\n    \n    model = models.__dict__[args.model](channels=3 if args.texture_path is None else 6, depth=args.depth)\n    model.to(device)\n\n    model_without_ddp = model\n    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)\n\n    print(\"Model = %s\" % str(model_without_ddp))\n    print('number of params (M): %.2f' % (n_parameters / 1.e6))\n\n    eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()\n    \n    if args.lr is None: 
 # only base_lr is specified\n        args.lr = args.blr * eff_batch_size / 128\n\n    print(\"base lr: %.2e\" % (args.lr * 128 / eff_batch_size))\n    print(\"actual lr: %.2e\" % args.lr)\n\n    print(\"accumulate grad iterations: %d\" % args.accum_iter)\n    print(\"effective batch size: %d\" % eff_batch_size)\n\n    if args.distributed:\n        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False)\n        model_without_ddp = model.module\n\n    optimizer = torch.optim.AdamW(model_without_ddp.parameters(), lr=args.lr)\n    loss_scaler = NativeScaler()\n\n    misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)\n\n    print(f\"Start training for {args.epochs} epochs\")\n    start_time = time.time()\n    max_iou = 0.0\n    for epoch in range(args.start_epoch, args.epochs):\n        # if args.distributed and args.data_path.endswith('.ply'):\n        #     data_loader_train.sampler.set_epoch(epoch)\n\n        train_stats = train_one_epoch(\n            model, data_loader_train,\n            optimizer, criterion, device, epoch, loss_scaler,\n            args.clip_grad,\n            log_writer=log_writer,\n            args=args\n        )\n        if args.output_dir and (epoch % 5 == 0 or epoch + 1 == args.epochs):\n            misc.save_model(\n                args=args, model=model, model_without_ddp=model_without_ddp, optimizer=optimizer,\n                loss_scaler=loss_scaler, epoch=epoch)\n\n        if epoch % 1 == 0 or epoch + 1 == args.epochs:\n\n            log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},\n                            # **{f'test_{k}': v for k, v in test_stats.items()},\n                            'epoch': epoch,\n                            'n_parameters': n_parameters}\n        else:\n            log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},\n                            'epoch': epoch,\n     
                       'n_parameters': n_parameters}\n\n        if args.output_dir and misc.is_main_process():\n            if log_writer is not None:\n                log_writer.flush()\n            with open(os.path.join(args.output_dir, \"log.txt\"), mode=\"a\", encoding=\"utf-8\") as f:\n                f.write(json.dumps(log_stats) + \"\\n\")\n\n\n            \n    total_time = time.time() - start_time\n    total_time_str = str(datetime.timedelta(seconds=int(total_time)))\n    print('Training time {}'.format(total_time_str))\n\nif __name__ == '__main__':\n    args = get_args_parser()\n    args = args.parse_args()\n    if args.output_dir:\n        Path(args.output_dir).mkdir(parents=True, exist_ok=True)\n    main(args)\n"
  },
  {
    "path": "models.py",
    "content": "import torch\nimport torch.nn as nn\n\nimport math\n\nimport numpy as np\n\nimport torch.nn.functional\nimport trimesh\n\n\ndef modulate(x, shift, scale):\n    return x * (1 + scale) + shift\n\nclass TimestepEmbedder(nn.Module):\n    \"\"\"\n    Embeds scalar timesteps into vector representations.\n    \"\"\"\n    def __init__(self, hidden_size, frequency_embedding_size=256):\n        super().__init__()\n        self.mlp = nn.Sequential(\n            nn.Linear(frequency_embedding_size, hidden_size, bias=True),\n            nn.SiLU(),\n            nn.Linear(hidden_size, hidden_size, bias=True),\n        )\n        self.frequency_embedding_size = frequency_embedding_size\n\n    @staticmethod\n    def timestep_embedding(t, dim, max_period=10000):\n        \"\"\"\n        Create sinusoidal timestep embeddings.\n        :param t: a 1-D Tensor of N indices, one per batch element.\n                          These may be fractional.\n        :param dim: the dimension of the output.\n        :param max_period: controls the minimum frequency of the embeddings.\n        :return: an (N, D) Tensor of positional embeddings.\n        \"\"\"\n        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py\n        half = dim // 2\n        freqs = torch.exp(\n            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half\n        ).to(device=t.device)\n        args = t[:, None].float() * freqs[None]\n        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)\n        if dim % 2:\n            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)\n        return embedding\n\n    def forward(self, t):\n        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)\n        t_emb = self.mlp(t_freq)\n        return t_emb\n\nclass MPFourier(torch.nn.Module):\n    def __init__(self, num_channels, bandwidth=1):\n        super().__init__()\n        
self.register_buffer('freqs', 2 * np.pi * torch.randn(num_channels) * bandwidth)\n        self.register_buffer('phases', 2 * np.pi * torch.rand(num_channels))\n\n    def forward(self, x):\n        y = x.to(torch.float32)\n        y = y.ger(self.freqs.to(torch.float32))\n        y = y + self.phases.to(torch.float32)\n        y = y.cos() * np.sqrt(2)\n        return y.to(x.dtype)\n    \n\n\ndef normalize(x, dim=None, eps=1e-4):\n    if dim is None:\n        dim = list(range(1, x.ndim))\n    norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)\n    norm = torch.add(eps, norm, alpha=np.sqrt(norm.numel() / x.numel()))\n    return x / norm.to(x.dtype)\n\ndef mp_silu(x):\n    return torch.nn.functional.silu(x) / 0.596\n\ndef mp_sum(a, b, t=0.5):\n    # print(a.mean(), a.std(), b.mean(), b.std())\n    return a.lerp(b, t) / np.sqrt((1 - t) ** 2 + t ** 2)\n\nclass MPConv(torch.nn.Module):\n    def __init__(self, in_channels, out_channels, kernel):\n        super().__init__()\n        self.out_channels = out_channels\n        self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, *kernel))\n\n    def forward(self, x, gain=1):\n        w = self.weight.to(torch.float32)\n        if self.training:\n            with torch.no_grad():\n                self.weight.copy_(normalize(w)) # forced weight normalization\n        w = normalize(w) # traditional weight normalization\n        w = w * (gain / np.sqrt(w[0].numel())) # magnitude-preserving scaling\n        w = w.to(x.dtype)\n        if w.ndim == 2:\n            return x @ w.t()\n        assert w.ndim == 4\n        return torch.nn.functional.conv2d(x, w, padding=(w.shape[-1]//2,))\n\nclass PointEmbed(nn.Module):\n    def __init__(self, hidden_dim=48, dim=128, other_dim=0):\n        super().__init__()\n\n        assert hidden_dim % 6 == 0\n\n        self.embedding_dim = hidden_dim\n        e = torch.pow(2, torch.arange(self.embedding_dim // 6)).float() * np.pi\n        e = torch.stack([\n  
          torch.cat([e, torch.zeros(self.embedding_dim // 6),\n                        torch.zeros(self.embedding_dim // 6)]),\n            torch.cat([torch.zeros(self.embedding_dim // 6), e,\n                        torch.zeros(self.embedding_dim // 6)]),\n            torch.cat([torch.zeros(self.embedding_dim // 6),\n                        torch.zeros(self.embedding_dim // 6), e]),\n        ])\n        self.register_buffer('basis', e)  # 3 x 16\n\n        # self.mlp = nn.Linear(self.embedding_dim+3, dim)/\n        self.mlp = MPConv(self.embedding_dim+3+other_dim, dim, kernel=[])\n\n    @staticmethod\n    def embed(input, basis):\n        # print(input.shape, basis.shape)\n        projections = torch.einsum('nd,de->ne', input, basis)\n        embeddings = torch.cat([projections.sin(), projections.cos()], dim=1)\n        return embeddings\n    \n    def forward(self, input):\n        # input: N x 3\n        if input.shape[1] != 3:\n            input, others = input[:, :3], input[:, 3:]\n        else:\n            others = None\n        \n        if others is None:\n            embed = self.mlp(torch.cat([self.embed(input, self.basis), input], dim=1)) # N x C\n        else:\n            embed = self.mlp(torch.cat([self.embed(input, self.basis), input, others], dim=1))\n        return embed\n\n\nclass Network(nn.Module):\n    def __init__(\n        self,\n        channels = 3,\n        hidden_size = 256,\n        depth = 6,\n    ):\n        super().__init__()\n\n        self.emb_fourier = MPFourier(hidden_size)\n        self.emb_noise = MPConv(hidden_size, hidden_size, kernel=[])\n\n        self.x_embedder = PointEmbed(dim=hidden_size, other_dim=channels-3)\n\n        self.gains = nn.ParameterList([\n            torch.nn.Parameter(torch.zeros([])) for _ in range(depth)\n        ])\n        ##\n        self.layers = nn.ModuleList([\n            nn.ModuleList([\n                MPConv(hidden_size, hidden_size, []),\n                MPConv(hidden_size, hidden_size, 
[]),\n                MPConv(hidden_size, 1 * hidden_size, []),\n            ]) for _ in range(depth)\n        ])\n\n\n        self.final_emb_gain = torch.nn.Parameter(torch.zeros([]))\n        self.final_out_gain = torch.nn.Parameter(torch.zeros([]))\n        self.final_layer = nn.ModuleList([\n            MPConv(hidden_size, hidden_size, []),\n            MPConv(hidden_size, channels, []),\n            MPConv(hidden_size, hidden_size, []),\n        ])\n\n        self.res_balance = 0.3\n\n\n    def forward(self, x, t):\n        x = self.x_embedder(x)\n\n        if t.shape[0] == 1:\n            t = t.repeat(x.shape[0])\n\n        t = mp_silu(self.emb_noise(self.emb_fourier(t)))\n\n        for (x_proj_pre, x_proj_post, emb_linear), emb_gain in zip(self.layers, self.gains):\n\n            c = emb_linear(t, gain=emb_gain) + 1\n\n            x = normalize(x)\n            y = x_proj_pre(mp_silu(x))\n            y = mp_silu(y * c.to(y.dtype))\n            y = x_proj_post(y)\n            x = mp_sum(x, y, t=self.res_balance)\n\n        x_proj_pre, x_proj_post, emb_linear = self.final_layer\n        c = emb_linear(t, gain=self.final_emb_gain) + 1\n        y = x_proj_pre(mp_silu(normalize(x)))\n        y = mp_silu(y * c.to(y.dtype))\n        out = x_proj_post(y, gain=self.final_out_gain)\n    \n        return out\n\nclass EDMPrecond(torch.nn.Module):\n    def __init__(self,\n        channels = 3, \n        use_fp16 = False,\n        sigma_min = 0,\n        sigma_max = float('inf'),\n        sigma_data  = 1,\n        depth = 6,\n    ):\n        super().__init__()\n\n        self.use_fp16 = use_fp16\n        self.sigma_min = sigma_min\n        self.sigma_max = sigma_max\n\n        self.sigma_data = sigma_data\n        self.model = Network(channels=channels, hidden_size=512, depth=depth)\n\n    def forward(self, x, sigma, force_fp32=False, **model_kwargs):\n\n        x = x.to(torch.float32)\n        sigma = sigma.to(torch.float32).reshape(-1, 1)\n        dtype = torch.float16 
if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32\n\n        c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)\n        c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2).sqrt()\n        c_in = 1 / (self.sigma_data ** 2 + sigma ** 2).sqrt()\n        c_noise = sigma.log() / 4\n    \n        F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), **model_kwargs)\n        assert F_x.dtype == dtype\n        D_x = c_skip * x + c_out * F_x.to(torch.float32)\n\n        return D_x\n\n    def round_sigma(self, sigma):\n        return torch.as_tensor(sigma)\n\n    @torch.no_grad()\n    def sample(self, cond=None, batch_seeds=None, channels=3, num_steps=18):\n\n        device = batch_seeds.device\n        batch_size = batch_seeds.shape[0]\n\n        rnd = None\n        points = batch_seeds\n\n        latents = points.float().to(device)\n\n        points = edm_sampler(self, latents, cond, num_steps=num_steps)\n        return points\n\n    @torch.no_grad()\n    def inverse(self, cond=None, samples=None, channels=3, num_steps=18):\n        return inverse_edm_sampler(self, samples, cond, num_steps=num_steps)\n\n\nclass StackedRandomGenerator:\n    def __init__(self, device, seeds):\n        super().__init__()\n        self.generators = [torch.Generator(device).manual_seed(int(seed) % (1 << 32)) for seed in seeds]\n\n    def randn(self, size, **kwargs):\n        assert size[0] == len(self.generators)\n        return torch.stack([torch.randn(size[1:], generator=gen, **kwargs) for gen in self.generators])\n\n    def randn_like(self, input):\n        return self.randn(input.shape, dtype=input.dtype, layout=input.layout, device=input.device)\n\n    def randint(self, *args, size, **kwargs):\n        assert size[0] == len(self.generators)\n        return torch.stack([torch.randint(*args, size=size[1:], generator=gen, **kwargs) for gen in self.generators])\n\ndef edm_sampler(\n    net, latents, class_labels=None, 
randn_like=torch.randn_like,\n    num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,\n    S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,\n):  \n    # disable S_churn\n    assert S_churn==0\n\n    # Adjust noise levels based on what's supported by the network.\n    sigma_min = max(sigma_min, net.sigma_min)\n    sigma_max = min(sigma_max, net.sigma_max)\n\n    # Time step discretization.\n    step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)\n    t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho\n    t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0\n\n    # Main sampling loop.\n    x_next = latents.to(torch.float64) * t_steps[0]\n    outputs = []\n    outputs.append((x_next / t_steps[0]).detach().cpu().numpy())\n    print(t_steps[0])\n    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1\n        print(t_cur, t_next)\n        x_cur = x_next\n\n        # Increase noise temporarily.\n        gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0\n        t_hat = net.round_sigma(t_cur + gamma * t_cur)\n        x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)\n        # x_hat = x_cur\n        t_hat = t_cur\n\n        # Euler step.\n        denoised = net(x_hat, t_hat, class_labels).to(torch.float64)\n        d_cur = (x_hat - denoised) / t_hat\n        x_next = x_hat + (t_next - t_hat) * d_cur\n\n        # Apply 2nd order correction.\n        if i < num_steps - 1:\n            denoised = net(x_next, t_next, class_labels).to(torch.float64)\n            d_prime = (x_next - denoised) / t_next\n            x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)\n        outputs.append((x_next / (1+t_next**2).sqrt()).detach().cpu().numpy())\n    return x_next, outputs\n\ndef inverse_edm_sampler(\n    net, latents, 
class_labels=None, randn_like=torch.randn_like,\n    num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,\n    S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,\n):  \n    # disable S_churn\n    assert S_churn==0\n\n    # Adjust noise levels based on what's supported by the network.\n    sigma_min = max(sigma_min, net.sigma_min)\n    sigma_max = min(sigma_max, net.sigma_max)\n\n    # Time step discretization.\n    step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)\n    t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho\n    t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])+1e-8]) # t_N = 0\n    t_steps = torch.flip(t_steps, [0])#[1:]\n\n    # Main sampling loop.\n    x_next = latents.to(torch.float64)# * t_steps[0]\n\n    # outputs = []\n    outputs = None\n    # outputs.append((x_next / t_steps[0]).detach().cpu().numpy())\n\n    print(t_steps[0])\n    print(x_next.mean(), x_next.std())\n    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1\n        # print('steps', t_cur, t_next)\n        x_cur = x_next\n        # print('cur', (x_cur / t_cur).mean(), (x_cur / t_cur).std())\n\n        # Increase noise temporarily.\n        gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0\n        t_hat = net.round_sigma(t_cur + gamma * t_cur)\n        x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)\n        x_hat = x_cur\n        t_hat = t_cur\n\n        # Euler step.\n        denoised = net(x_hat, t_hat, class_labels).to(torch.float64)\n        d_cur = (x_hat - denoised) / t_hat\n        x_next = x_hat + (t_next - t_hat) * d_cur\n\n        # Apply 2nd order correction.\n        if i < num_steps - 1:\n            denoised = net(x_next, t_next, class_labels).to(torch.float64)\n            d_prime = (x_next - denoised) / t_next\n            x_next = x_hat 
+ (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)\n\n        print('next', (x_next / (1+t_next**2).sqrt()).mean(), (x_next / (1+t_next**2).sqrt()).std())\n\n        # outputs.append((x_next / (1+t_next**2).sqrt()).detach().cpu().numpy())\n    x_next = x_next / (1+t_next**2).sqrt()\n    return x_next, outputs\n\nclass EDMLoss:\n    def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=1, dist='Gaussian'):\n        self.P_mean = P_mean\n        self.P_std = P_std\n        self.sigma_data = sigma_data\n\n        self.dist = dist\n\n    def __call__(self, net, inputs, labels=None, augment_pipe=None, init_noise=None):\n        rnd_normal = torch.randn([inputs.shape[0],], device=inputs.device)\n\n        sigma = (rnd_normal * self.P_std + self.P_mean).exp()\n        weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2\n        y, augment_labels = augment_pipe(inputs) if augment_pipe is not None else (inputs, None)\n\n        if self.dist == 'Gaussian':\n            n = torch.randn_like(y[:, :3]) * sigma[:, None]\n            if y.shape[1] != 3:\n                c = (torch.rand_like(y[:, 3:]) - 0.5) / np.sqrt(1/12) * sigma[:, None]\n                n = torch.cat([n, c], dim=1)\n        elif self.dist == 'Uniform':\n            n = (torch.rand_like(y) - 0.5) / np.sqrt(1/12) * sigma[:, None]\n        elif self.dist == 'Sphere':\n            n = torch.randn_like(y[:, :3])\n            n = torch.nn.functional.normalize(n, dim=1)\n            n /= np.sqrt(1/3)\n            n = n * sigma[:, None]\n\n        elif self.dist == \"Mesh\":\n            assert init_noise is not None\n            n = init_noise * sigma[:, None]\n        else:\n            raise NotImplementedError\n\n        D_yn = net(y + n, sigma)\n\n        loss = weight[:, None] * ((D_yn - y) ** 2)\n        return loss.mean()"
  },
  {
    "path": "normalize.py",
    "content": "import argparse \n\nimport trimesh\n\nimport math\n\nimport glob\n\nimport numpy as np\n\nparser = argparse.ArgumentParser('Inference', add_help=False)\nparser.add_argument('--path', required=True, type=str)\nparser.add_argument('--output', required=True, type=str)\n\nargs = parser.parse_args()\n\nmodel = trimesh.load(args.path, process=False)\n\ndef normalize_meshes(mesh):\n    mesh.vertices -= (mesh.vertices.max(axis=0) + mesh.vertices.min(axis=0)) / 2\n\n    scale = (1 / np.abs(mesh.vertices).max()) * 0.99\n\n    mesh.vertices *= scale\n\n    points, _ = trimesh.sample.sample_surface(mesh, 10000000)\n\n    mesh.vertices -= points.mean()\n    mesh.vertices /= points.std()\n\n    return mesh\n\nmodel = normalize_meshes(model)\n\n# angle = math.pi / 2\n# direction = [1, 0, 0]\n# center = [0, 0, 0]\n\n# rot_matrix = trimesh.transformations.rotation_matrix(angle, direction, center)\n\n# model.apply_transform(rot_matrix)\n\nmodel.export(args.output)"
  },
  {
    "path": "points.py",
    "content": "import trimesh\n\nimport numpy as np\nimport os\n\nimport torch\nfrom torch.utils import data\n\nclass Points(data.Dataset):\n    def __init__(self, ply_path):\n        points = trimesh.load(ply_path).vertices\n        # self.points = np.array(points)\n        # if os.path.exists('test.npy'):\n        #     points = np.load('test.npy')\n        # else:\n        #     points, _ = trimesh.sample.sample_surface(trimesh.load(ply_path), 50000000*5)\n        #     np.save('test.npy', points)\n        self.points = torch.from_numpy(points)# - 0.12\n        print(self.points.std(), self.points.mean())\n\n    def __len__(self):\n        return self.points.shape[0]# * 16\n\n    def __getitem__(self, idx):\n        # idx = idx % self.points.shape[0]\n        return self.points[idx]"
  },
  {
    "path": "util/lr_decay.py",
    "content": "# --------------------------------------------------------\n# References:\n# MAE: https://github.com/facebookresearch/mae\n# DeiT: https://github.com/facebookresearch/deit\n# BEiT: https://github.com/microsoft/unilm/tree/master/beit\n# --------------------------------------------------------\n\nimport json\n\n\ndef param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75):\n    \"\"\"\n    Parameter groups for layer-wise lr decay\n    Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58\n    \"\"\"\n    param_group_names = {}\n    param_groups = {}\n\n    num_layers = len(model.blocks) + 1\n\n    layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1))\n\n    for n, p in model.named_parameters():\n        if not p.requires_grad:\n            continue\n\n        # no decay: all 1D parameters and model specific ones\n        if p.ndim == 1 or n in no_weight_decay_list:\n            g_decay = \"no_decay\"\n            this_decay = 0.\n        else:\n            g_decay = \"decay\"\n            this_decay = weight_decay\n            \n        layer_id = get_layer_id_for_vit(n, num_layers)\n        group_name = \"layer_%d_%s\" % (layer_id, g_decay)\n\n        if group_name not in param_group_names:\n            this_scale = layer_scales[layer_id]\n\n            param_group_names[group_name] = {\n                \"lr_scale\": this_scale,\n                \"weight_decay\": this_decay,\n                \"params\": [],\n            }\n            param_groups[group_name] = {\n                \"lr_scale\": this_scale,\n                \"weight_decay\": this_decay,\n                \"params\": [],\n            }\n\n        param_group_names[group_name][\"params\"].append(n)\n        param_groups[group_name][\"params\"].append(p)\n\n    # print(\"parameter groups: \\n%s\" % json.dumps(param_group_names, indent=2))\n\n    return list(param_groups.values())\n\n\ndef 
get_layer_id_for_vit(name, num_layers):\n    \"\"\"\n    Assign a parameter with its layer id\n    Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33\n    \"\"\"\n    if name in ['cls_token', 'pos_embed']:\n        return 0\n    elif name.startswith('patch_embed'):\n        return 0\n    elif name.startswith('blocks'):\n        return int(name.split('.')[1]) + 1\n    else:\n        return num_layers"
  },
  {
    "path": "util/lr_sched.py",
    "content": "# Copyright (c) Meta Platforms, Inc. and affiliates.\n# All rights reserved.\n\n# This source code is licensed under the license found in the\n# LICENSE file in the root directory of this source tree.\n\nimport math\n\ndef adjust_learning_rate(optimizer, epoch, args):\n    \"\"\"Decay the learning rate with half-cycle cosine after warmup\"\"\"\n    if epoch < args.warmup_epochs:\n        lr = args.lr * epoch / args.warmup_epochs \n    else:\n        lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \\\n            (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))\n    for param_group in optimizer.param_groups:\n        if \"lr_scale\" in param_group:\n            param_group[\"lr\"] = lr * param_group[\"lr_scale\"]\n        else:\n            param_group[\"lr\"] = lr\n    return lr"
  },
  {
    "path": "util/misc.py",
    "content": "# --------------------------------------------------------\n# References:\n# MAE: https://github.com/facebookresearch/mae\n# DeiT: https://github.com/facebookresearch/deit\n# BEiT: https://github.com/microsoft/unilm/tree/master/beit\n# --------------------------------------------------------\n\nimport builtins\nimport datetime\nimport os\nimport time\nfrom collections import defaultdict, deque\nfrom pathlib import Path\n\nimport torch\nimport torch.distributed as dist\n\nif torch.__version__[0] == '2':\n    from torch import inf\nelse:\n    from torch._six import inf\n\n\nclass SmoothedValue(object):\n    \"\"\"Track a series of values and provide access to smoothed values over a\n    window or the global series average.\n    \"\"\"\n\n    def __init__(self, window_size=20, fmt=None):\n        if fmt is None:\n            fmt = \"{median:.5f} ({global_avg:.5f})\"\n        self.deque = deque(maxlen=window_size)\n        self.total = 0.0\n        self.count = 0\n        self.fmt = fmt\n\n    def update(self, value, n=1):\n        self.deque.append(value)\n        self.count += n\n        self.total += value * n\n\n    def synchronize_between_processes(self):\n        \"\"\"\n        Warning: does not synchronize the deque!\n        \"\"\"\n        if not is_dist_avail_and_initialized():\n            return\n        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')\n        dist.barrier()\n        dist.all_reduce(t)\n        t = t.tolist()\n        self.count = int(t[0])\n        self.total = t[1]\n\n    @property\n    def median(self):\n        d = torch.tensor(list(self.deque))\n        return d.median().item()\n\n    @property\n    def avg(self):\n        d = torch.tensor(list(self.deque), dtype=torch.float32)\n        return d.mean().item()\n\n    @property\n    def global_avg(self):\n        return self.total / self.count\n\n    @property\n    def max(self):\n        return max(self.deque)\n\n    @property\n    def 
value(self):\n        return self.deque[-1]\n\n    def __str__(self):\n        return self.fmt.format(\n            median=self.median,\n            avg=self.avg,\n            global_avg=self.global_avg,\n            max=self.max,\n            value=self.value)\n\n\nclass MetricLogger(object):\n    def __init__(self, delimiter=\"\\t\"):\n        self.meters = defaultdict(SmoothedValue)\n        self.delimiter = delimiter\n\n    def update(self, **kwargs):\n        for k, v in kwargs.items():\n            if v is None:\n                continue\n            if isinstance(v, torch.Tensor):\n                v = v.item()\n            assert isinstance(v, (float, int))\n            self.meters[k].update(v)\n\n    def __getattr__(self, attr):\n        if attr in self.meters:\n            return self.meters[attr]\n        if attr in self.__dict__:\n            return self.__dict__[attr]\n        raise AttributeError(\"'{}' object has no attribute '{}'\".format(\n            type(self).__name__, attr))\n\n    def __str__(self):\n        loss_str = []\n        for name, meter in self.meters.items():\n            loss_str.append(\n                \"{}: {}\".format(name, str(meter))\n            )\n        return self.delimiter.join(loss_str)\n\n    def synchronize_between_processes(self):\n        for meter in self.meters.values():\n            meter.synchronize_between_processes()\n\n    def add_meter(self, name, meter):\n        self.meters[name] = meter\n\n    def log_every(self, iterable, print_freq, header=None):\n        i = 0\n        if not header:\n            header = ''\n        start_time = time.time()\n        end = time.time()\n        iter_time = SmoothedValue(fmt='{avg:.4f}')\n        data_time = SmoothedValue(fmt='{avg:.4f}')\n        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'\n        log_msg = [\n            header,\n            '[{0' + space_fmt + '}/{1}]',\n            'eta: {eta}',\n            '{meters}',\n            'time: {time}',\n        
    'data: {data}'\n        ]\n        if torch.cuda.is_available():\n            log_msg.append('max mem: {memory:.0f}')\n        log_msg = self.delimiter.join(log_msg)\n        MB = 1024.0 * 1024.0\n        for obj in iterable:\n            data_time.update(time.time() - end)\n            yield obj\n            iter_time.update(time.time() - end)\n            if i % print_freq == 0 or i == len(iterable) - 1:\n                eta_seconds = iter_time.global_avg * (len(iterable) - i)\n                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))\n                if torch.cuda.is_available():\n                    print(log_msg.format(\n                        i, len(iterable), eta=eta_string,\n                        meters=str(self),\n                        time=str(iter_time), data=str(data_time),\n                        memory=torch.cuda.max_memory_allocated() / MB))\n                else:\n                    print(log_msg.format(\n                        i, len(iterable), eta=eta_string,\n                        meters=str(self),\n                        time=str(iter_time), data=str(data_time)))\n            i += 1\n            end = time.time()\n        total_time = time.time() - start_time\n        total_time_str = str(datetime.timedelta(seconds=int(total_time)))\n        print('{} Total time: {} ({:.4f} s / it)'.format(\n            header, total_time_str, total_time / len(iterable)))\n\n\ndef setup_for_distributed(is_master):\n    \"\"\"\n    This function disables printing when not in master process\n    \"\"\"\n    builtin_print = builtins.print\n\n    def print(*args, **kwargs):\n        force = kwargs.pop('force', False)\n        force = force or (get_world_size() > 8)\n        if is_master:# or force:\n            now = datetime.datetime.now().time()\n            builtin_print('[{}] '.format(now), end='')  # print with time stamp\n            builtin_print(*args, **kwargs)\n\n    builtins.print = print\n\n\ndef 
is_dist_avail_and_initialized():\n    if not dist.is_available():\n        return False\n    if not dist.is_initialized():\n        return False\n    return True\n\n\ndef get_world_size():\n    if not is_dist_avail_and_initialized():\n        return 1\n    return dist.get_world_size()\n\n\ndef get_rank():\n    if not is_dist_avail_and_initialized():\n        return 0\n    return dist.get_rank()\n\n\ndef is_main_process():\n    return get_rank() == 0\n\n\ndef save_on_master(*args, **kwargs):\n    if is_main_process():\n        torch.save(*args, **kwargs)\n\n\ndef init_distributed_mode(args):\n    if args.dist_on_itp:\n        args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])\n        args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])\n        args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])\n        args.dist_url = \"tcp://%s:%s\" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT'])\n        os.environ['LOCAL_RANK'] = str(args.gpu)\n        os.environ['RANK'] = str(args.rank)\n        os.environ['WORLD_SIZE'] = str(args.world_size)\n        # [\"RANK\", \"WORLD_SIZE\", \"MASTER_ADDR\", \"MASTER_PORT\", \"LOCAL_RANK\"]\n    elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:\n        args.rank = int(os.environ[\"RANK\"])\n        args.world_size = int(os.environ['WORLD_SIZE'])\n        args.gpu = int(os.environ['LOCAL_RANK'])\n    elif 'SLURM_PROCID' in os.environ:\n        args.rank = int(os.environ['SLURM_PROCID'])\n        args.gpu = args.rank % torch.cuda.device_count()\n    else:\n        print('Not using distributed mode')\n        setup_for_distributed(is_master=True)  # hack\n        args.distributed = False\n        return\n\n    args.distributed = True\n\n    torch.cuda.set_device(args.gpu)\n    args.dist_backend = 'nccl'\n    print('| distributed init (rank {}): {}, gpu {}'.format(\n        args.rank, args.dist_url, args.gpu), flush=True)\n    torch.distributed.init_process_group(backend=args.dist_backend, 
init_method=args.dist_url,\n                                         world_size=args.world_size, rank=args.rank)\n    torch.distributed.barrier()\n    setup_for_distributed(args.rank == 0)\n\n\nclass NativeScalerWithGradNormCount:\n    state_dict_key = \"amp_scaler\"\n\n    def __init__(self):\n        self._scaler = torch.cuda.amp.GradScaler()\n\n    def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):\n        self._scaler.scale(loss).backward(create_graph=create_graph)\n        if update_grad:\n            if clip_grad is not None:\n                assert parameters is not None\n                self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place\n                norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)\n            else:\n                self._scaler.unscale_(optimizer)\n                norm = get_grad_norm_(parameters)\n            self._scaler.step(optimizer)\n            self._scaler.update()\n        else:\n            norm = None\n        return norm\n\n    def state_dict(self):\n        return self._scaler.state_dict()\n\n    def load_state_dict(self, state_dict):\n        self._scaler.load_state_dict(state_dict)\n\n\ndef get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:\n    if isinstance(parameters, torch.Tensor):\n        parameters = [parameters]\n    parameters = [p for p in parameters if p.grad is not None]\n    norm_type = float(norm_type)\n    if len(parameters) == 0:\n        return torch.tensor(0.)\n    device = parameters[0].grad.device\n    if norm_type == inf:\n        total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)\n    else:\n        total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)\n    return total_norm\n\n\ndef save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler):\n    output_dir = 
Path(args.output_dir)\n    epoch_name = str(epoch)\n    if loss_scaler is not None:\n        checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]\n        for checkpoint_path in checkpoint_paths:\n            to_save = {\n                'model': model_without_ddp.state_dict(),\n                'optimizer': optimizer.state_dict(),\n                'epoch': epoch,\n                'scaler': loss_scaler.state_dict(),\n                'args': args,\n            }\n\n            save_on_master(to_save, checkpoint_path)\n    else:\n        client_state = {'epoch': epoch}\n        model.save_checkpoint(save_dir=args.output_dir, tag=\"checkpoint-%s\" % epoch_name, client_state=client_state)\n\n\ndef load_model(args, model_without_ddp, optimizer, loss_scaler):\n    if args.resume:\n        if args.resume.startswith('https'):\n            checkpoint = torch.hub.load_state_dict_from_url(\n                args.resume, map_location='cpu', check_hash=True)\n        else:\n            checkpoint = torch.load(args.resume, map_location='cpu')\n        model_without_ddp.load_state_dict(checkpoint['model'])\n        print(\"Resume checkpoint %s\" % args.resume)\n        if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):\n            optimizer.load_state_dict(checkpoint['optimizer'])\n            args.start_epoch = checkpoint['epoch'] + 1\n            if 'scaler' in checkpoint:\n                loss_scaler.load_state_dict(checkpoint['scaler'])\n            print(\"With optim & sched!\")\n\n\ndef all_reduce_mean(x):\n    world_size = get_world_size()\n    if world_size > 1:\n        x_reduce = torch.tensor(x).cuda()\n        dist.all_reduce(x_reduce)\n        x_reduce /= world_size\n        return x_reduce.item()\n    else:\n        return x\n"
  }
]