[
  {
    "path": ".gitignore",
    "content": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packaging\n.Python\nenv/\nbuild/\ndevelop-eggs/\ndist/\ndownloads/\neggs/\n.eggs/\nlib/\nlib64/\nparts/\nsdist/\nvar/\nwheels/\n*.egg-info/\n.installed.cfg\n*.egg\n\n# PyInstaller\n#  Usually these files are written by a python script from a template\n#  before PyInstaller builds the exe, so as to inject date/other infos into it.\n*.manifest\n*.spec\n\n# Installer logs\npip-log.txt\npip-delete-this-directory.txt\n\n# Unit test / coverage reports\nhtmlcov/\n.tox/\n.coverage\n.coverage.*\n.cache\nnosetests.xml\ncoverage.xml\n*.cover\n.hypothesis/\n\n# Translations\n*.mo\n*.pot\n\n# Django stuff:\n*.log\nlocal_settings.py\n\n# Flask stuff:\ninstance/\n.webassets-cache\n\n# Scrapy stuff:\n.scrapy\n\n# Sphinx documentation\ndocs/_build/\n\n# PyBuilder\ntarget/\n\n# Jupyter Notebook\n.ipynb_checkpoints\n\n# pyenv\n.python-version\n\n# celery beat schedule file\ncelerybeat-schedule\n\n# SageMath parsed files\n*.sage.py\n\n# dotenv\n.env\n\n# virtualenv\n.venv\nvenv/\nENV/\n\n# Spyder project settings\n.spyderproject\n.spyproject\n\n# Rope project settings\n.ropeproject\n\n# mkdocs documentation\n/site\n\n# mypy\n.mypy_cache/\n\n# macOS\n.DS_Store\n\n# Datasets\ndatasets/\n\n# Output\noutput/"
  },
  {
    "path": ".style.yapf",
    "content": "[style]\nbased_on_style = pep8\ncolumn_limit = 119\nspaces_before_comment = 4\nsplit_before_logical_operator = True\nuse_tabs = False"
  },
  {
    "path": ".yapfignore",
    "content": "config.py\nmodels/decoder.py\nmodels/encoder.py\nmodels/merger.py\nmodels/refiner.py\n"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2018 Haozhe Xie\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# Pix2Vox\n\n[![Quality Gate Status](https://sonarcloud.io/api/project_badges/measure?project=hzxie_Pix2Vox&metric=alert_status)](https://sonarcloud.io/summary/new_code?id=hzxie_Pix2Vox)\n[![codefactor badge](https://www.codefactor.io/repository/github/hzxie/Pix2Vox/badge)](https://www.codefactor.io/repository/github/hzxie/Pix2Vox)\n\nThis repository contains the source code for the paper [Pix2Vox: Context-aware 3D Reconstruction from Single and Multi-view Images](https://arxiv.org/abs/1901.11153). The follow-up work [Pix2Vox++: Multi-scale Context-aware 3D Object Reconstruction from Single and Multiple Images](https://arxiv.org/abs/2006.12250) has been published in *International Journal of Computer Vision (IJCV)*.\n\n![Overview](https://www.infinitescript.com/projects/Pix2Vox/Pix2Vox-Overview.jpg)\n\n## Cite this work\n\n```\n@inproceedings{xie2019pix2vox,\n  title={Pix2Vox: Context-aware 3D Reconstruction from Single and Multi-view Images},\n  author={Xie, Haozhe and \n          Yao, Hongxun and \n          Sun, Xiaoshuai and \n          Zhou, Shangchen and \n          Zhang, Shengping},\n  booktitle={ICCV},\n  year={2019}\n}\n```\n\n## Datasets\n\nWe use the [ShapeNet](https://www.shapenet.org/) and [Pix3D](http://pix3d.csail.mit.edu/) datasets in our experiments, which are available below:\n\n- ShapeNet rendering images: http://cvgl.stanford.edu/data2/ShapeNetRendering.tgz\n- ShapeNet voxelized models: http://cvgl.stanford.edu/data2/ShapeNetVox32.tgz\n- Pix3D images & voxelized models: http://pix3d.csail.mit.edu/data/pix3d.zip\n\n## Pretrained Models\n\nThe pretrained models on ShapeNet are available as follows:\n\n- [Pix2Vox-A](https://gateway.infinitescript.com/?fileName=Pix2Vox-A-ShapeNet.pth) (457.0 MB)\n- [Pix2Vox-F](https://gateway.infinitescript.com/?fileName=Pix2Vox-F-ShapeNet.pth) (29.8 MB)\n\n## Prerequisites\n\n#### Clone the Code Repository\n\n```\ngit clone https://github.com/hzxie/Pix2Vox.git\n```\n\n#### Install Python 
Dependencies\n\n```\ncd Pix2Vox\npip install -r requirements.txt\n```\n\n#### Update Settings in `config.py`\n\nYou need to update the file paths of the datasets:\n\n```\n__C.DATASETS.SHAPENET.RENDERING_PATH        = '/path/to/Datasets/ShapeNet/ShapeNetRendering/%s/%s/rendering/%02d.png'\n__C.DATASETS.SHAPENET.VOXEL_PATH            = '/path/to/Datasets/ShapeNet/ShapeNetVox32/%s/%s/model.binvox'\n__C.DATASETS.PASCAL3D.ANNOTATION_PATH       = '/path/to/Datasets/PASCAL3D/Annotations/%s_imagenet/%s.mat'\n__C.DATASETS.PASCAL3D.RENDERING_PATH        = '/path/to/Datasets/PASCAL3D/Images/%s_imagenet/%s.JPEG'\n__C.DATASETS.PASCAL3D.VOXEL_PATH            = '/path/to/Datasets/PASCAL3D/CAD/%s/%02d.binvox'\n__C.DATASETS.PIX3D.ANNOTATION_PATH          = '/path/to/Datasets/Pix3D/pix3d.json'\n__C.DATASETS.PIX3D.RENDERING_PATH           = '/path/to/Datasets/Pix3D/img/%s/%s.%s'\n__C.DATASETS.PIX3D.VOXEL_PATH               = '/path/to/Datasets/Pix3D/model/%s/%s/%s.binvox'\n```\n\n## Get Started\n\nTo train Pix2Vox, you can simply use the following command:\n\n```\npython3 runner.py\n```\n\nTo test Pix2Vox, you can use the following command:\n\n```\npython3 runner.py --test --weights=/path/to/pretrained/model.pth\n```\n\nIf you want to train/test Pix2Vox-F, you need to check out the `Pix2Vox-F` branch first.\n\n```\ngit checkout -b Pix2Vox-F origin/Pix2Vox-F\n```\n\n## License\n\nThis project is open-sourced under the MIT license.\n"
  },
  {
    "path": "config.py",
    "content": "# -*- coding: utf-8 -*-\n#\n# Developed by Haozhe Xie <cshzxie@gmail.com>\n\nfrom easydict import EasyDict as edict\n\n__C                                         = edict()\ncfg                                         = __C\n\n#\n# Dataset Config\n#\n__C.DATASETS                                = edict()\n__C.DATASETS.SHAPENET                       = edict()\n__C.DATASETS.SHAPENET.TAXONOMY_FILE_PATH    = './datasets/ShapeNet.json'\n# __C.DATASETS.SHAPENET.TAXONOMY_FILE_PATH  = './datasets/PascalShapeNet.json'\n__C.DATASETS.SHAPENET.RENDERING_PATH        = '/home/hzxie/Datasets/ShapeNet/ShapeNetRendering/%s/%s/rendering/%02d.png'\n# __C.DATASETS.SHAPENET.RENDERING_PATH      = '/home/hzxie/Datasets/ShapeNet/PascalShapeNetRendering/%s/%s/render_%04d.jpg'\n__C.DATASETS.SHAPENET.VOXEL_PATH            = '/home/hzxie/Datasets/ShapeNet/ShapeNetVox32/%s/%s/model.binvox'\n__C.DATASETS.PASCAL3D                       = edict()\n__C.DATASETS.PASCAL3D.TAXONOMY_FILE_PATH    = './datasets/Pascal3D.json'\n__C.DATASETS.PASCAL3D.ANNOTATION_PATH       = '/home/hzxie/Datasets/PASCAL3D/Annotations/%s_imagenet/%s.mat'\n__C.DATASETS.PASCAL3D.RENDERING_PATH        = '/home/hzxie/Datasets/PASCAL3D/Images/%s_imagenet/%s.JPEG'\n__C.DATASETS.PASCAL3D.VOXEL_PATH            = '/home/hzxie/Datasets/PASCAL3D/CAD/%s/%02d.binvox'\n__C.DATASETS.PIX3D                          = edict()\n__C.DATASETS.PIX3D.TAXONOMY_FILE_PATH       = './datasets/Pix3D.json'\n__C.DATASETS.PIX3D.ANNOTATION_PATH          = '/home/hzxie/Datasets/Pix3D/pix3d.json'\n__C.DATASETS.PIX3D.RENDERING_PATH           = '/home/hzxie/Datasets/Pix3D/img/%s/%s.%s'\n__C.DATASETS.PIX3D.VOXEL_PATH               = '/home/hzxie/Datasets/Pix3D/model/%s/%s/%s.binvox'\n\n#\n# Dataset\n#\n__C.DATASET                                 = edict()\n__C.DATASET.MEAN                            = [0.5, 0.5, 0.5]\n__C.DATASET.STD                             = [0.5, 0.5, 0.5]\n__C.DATASET.TRAIN_DATASET                   = 
'ShapeNet'\n__C.DATASET.TEST_DATASET                    = 'ShapeNet'\n# __C.DATASET.TEST_DATASET                  = 'Pascal3D'\n# __C.DATASET.TEST_DATASET                  = 'Pix3D'\n\n#\n# Common\n#\n__C.CONST                                   = edict()\n__C.CONST.DEVICE                            = '0'\n__C.CONST.RNG_SEED                          = 0\n__C.CONST.IMG_W                             = 224       # Image width for input\n__C.CONST.IMG_H                             = 224       # Image height for input\n__C.CONST.N_VOX                             = 32\n__C.CONST.BATCH_SIZE                        = 64\n__C.CONST.N_VIEWS_RENDERING                 = 1         # Dummy property for Pascal 3D\n__C.CONST.CROP_IMG_W                        = 128       # Dummy property for Pascal 3D\n__C.CONST.CROP_IMG_H                        = 128       # Dummy property for Pascal 3D\n\n#\n# Directories\n#\n__C.DIR                                     = edict()\n__C.DIR.OUT_PATH                            = './output'\n__C.DIR.RANDOM_BG_PATH                      = '/home/hzxie/Datasets/SUN2012/JPEGImages'\n\n#\n# Network\n#\n__C.NETWORK                                 = edict()\n__C.NETWORK.LEAKY_VALUE                     = .2\n__C.NETWORK.TCONV_USE_BIAS                  = False\n__C.NETWORK.USE_REFINER                     = True\n__C.NETWORK.USE_MERGER                      = True\n\n#\n# Training\n#\n__C.TRAIN                                   = edict()\n__C.TRAIN.RESUME_TRAIN                      = False\n__C.TRAIN.NUM_WORKER                        = 4             # number of data workers\n__C.TRAIN.NUM_EPOCHES                       = 250\n__C.TRAIN.BRIGHTNESS                        = .4\n__C.TRAIN.CONTRAST                          = .4\n__C.TRAIN.SATURATION                        = .4\n__C.TRAIN.NOISE_STD                         = .1\n__C.TRAIN.RANDOM_BG_COLOR_RANGE             = [[225, 255], [225, 255], [225, 255]]\n__C.TRAIN.POLICY                            = 'adam'        # 
available options: sgd, adam\n__C.TRAIN.EPOCH_START_USE_REFINER           = 0\n__C.TRAIN.EPOCH_START_USE_MERGER            = 0\n__C.TRAIN.ENCODER_LEARNING_RATE             = 1e-3\n__C.TRAIN.DECODER_LEARNING_RATE             = 1e-3\n__C.TRAIN.REFINER_LEARNING_RATE             = 1e-3\n__C.TRAIN.MERGER_LEARNING_RATE              = 1e-4\n__C.TRAIN.ENCODER_LR_MILESTONES             = [150]\n__C.TRAIN.DECODER_LR_MILESTONES             = [150]\n__C.TRAIN.REFINER_LR_MILESTONES             = [150]\n__C.TRAIN.MERGER_LR_MILESTONES              = [150]\n__C.TRAIN.BETAS                             = (.9, .999)\n__C.TRAIN.MOMENTUM                          = .9\n__C.TRAIN.GAMMA                             = .5\n__C.TRAIN.SAVE_FREQ                         = 10            # checkpoints are saved every SAVE_FREQ epochs\n__C.TRAIN.UPDATE_N_VIEWS_RENDERING          = False\n\n#\n# Testing options\n#\n__C.TEST                                    = edict()\n__C.TEST.RANDOM_BG_COLOR_RANGE              = [[240, 240], [240, 240], [240, 240]]\n__C.TEST.VOXEL_THRESH                       = [.2, .3, .4, .5]\n"
  },
  {
    "path": "core/__init__.py",
    "content": ""
  },
  {
    "path": "core/test.py",
    "content": "# -*- coding: utf-8 -*-\n#\n# Developed by Haozhe Xie <cshzxie@gmail.com>\n\nimport json\nimport numpy as np\nimport os\nimport torch\nimport torch.backends.cudnn\nimport torch.utils.data\n\nimport utils.binvox_visualization\nimport utils.data_loaders\nimport utils.data_transforms\nimport utils.network_utils\n\nfrom datetime import datetime as dt\n\nfrom models.encoder import Encoder\nfrom models.decoder import Decoder\nfrom models.refiner import Refiner\nfrom models.merger import Merger\n\n\ndef test_net(cfg,\n             epoch_idx=-1,\n             output_dir=None,\n             test_data_loader=None,\n             test_writer=None,\n             encoder=None,\n             decoder=None,\n             refiner=None,\n             merger=None):\n    # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use\n    torch.backends.cudnn.benchmark = True\n\n    # Load taxonomies of dataset\n    taxonomies = []\n    with open(cfg.DATASETS[cfg.DATASET.TEST_DATASET.upper()].TAXONOMY_FILE_PATH, encoding='utf-8') as file:\n        taxonomies = json.loads(file.read())\n    taxonomies = {t['taxonomy_id']: t for t in taxonomies}\n\n    # Set up data loader\n    if test_data_loader is None:\n        # Set up data augmentation\n        IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W\n        CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W\n        test_transforms = utils.data_transforms.Compose([\n            utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),\n            utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),\n            utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN, std=cfg.DATASET.STD),\n            utils.data_transforms.ToTensor(),\n        ])\n\n        dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[cfg.DATASET.TEST_DATASET](cfg)\n        test_data_loader = torch.utils.data.DataLoader(dataset=dataset_loader.get_dataset(\n            utils.data_loaders.DatasetType.TEST, 
cfg.CONST.N_VIEWS_RENDERING, test_transforms),\n                                                       batch_size=1,\n                                                       num_workers=1,\n                                                       pin_memory=True,\n                                                       shuffle=False)\n\n    # Set up networks\n    if decoder is None or encoder is None:\n        encoder = Encoder(cfg)\n        decoder = Decoder(cfg)\n        refiner = Refiner(cfg)\n        merger = Merger(cfg)\n\n        if torch.cuda.is_available():\n            encoder = torch.nn.DataParallel(encoder).cuda()\n            decoder = torch.nn.DataParallel(decoder).cuda()\n            refiner = torch.nn.DataParallel(refiner).cuda()\n            merger = torch.nn.DataParallel(merger).cuda()\n\n        print('[INFO] %s Loading weights from %s ...' % (dt.now(), cfg.CONST.WEIGHTS))\n        checkpoint = torch.load(cfg.CONST.WEIGHTS)\n        epoch_idx = checkpoint['epoch_idx']\n        encoder.load_state_dict(checkpoint['encoder_state_dict'])\n        decoder.load_state_dict(checkpoint['decoder_state_dict'])\n\n        if cfg.NETWORK.USE_REFINER:\n            refiner.load_state_dict(checkpoint['refiner_state_dict'])\n        if cfg.NETWORK.USE_MERGER:\n            merger.load_state_dict(checkpoint['merger_state_dict'])\n\n    # Set up loss functions\n    bce_loss = torch.nn.BCELoss()\n\n    # Testing loop\n    n_samples = len(test_data_loader)\n    test_iou = dict()\n    encoder_losses = utils.network_utils.AverageMeter()\n    refiner_losses = utils.network_utils.AverageMeter()\n\n    # Switch models to evaluation mode\n    encoder.eval()\n    decoder.eval()\n    refiner.eval()\n    merger.eval()\n\n    for sample_idx, (taxonomy_id, sample_name, rendering_images, ground_truth_volume) in enumerate(test_data_loader):\n        taxonomy_id = taxonomy_id[0] if isinstance(taxonomy_id[0], str) else taxonomy_id[0].item()\n        sample_name = sample_name[0]\n\n       
 with torch.no_grad():\n            # Get data from data loader\n            rendering_images = utils.network_utils.var_or_cuda(rendering_images)\n            ground_truth_volume = utils.network_utils.var_or_cuda(ground_truth_volume)\n\n            # Test the encoder, decoder, refiner and merger\n            image_features = encoder(rendering_images)\n            raw_features, generated_volume = decoder(image_features)\n\n            if cfg.NETWORK.USE_MERGER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_MERGER:\n                generated_volume = merger(raw_features, generated_volume)\n            else:\n                generated_volume = torch.mean(generated_volume, dim=1)\n            encoder_loss = bce_loss(generated_volume, ground_truth_volume) * 10\n\n            if cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:\n                generated_volume = refiner(generated_volume)\n                refiner_loss = bce_loss(generated_volume, ground_truth_volume) * 10\n            else:\n                refiner_loss = encoder_loss\n\n            # Append loss and accuracy to average metrics\n            encoder_losses.update(encoder_loss.item())\n            refiner_losses.update(refiner_loss.item())\n\n            # IoU per sample\n            sample_iou = []\n            for th in cfg.TEST.VOXEL_THRESH:\n                _volume = torch.ge(generated_volume, th).float()\n                intersection = torch.sum(_volume.mul(ground_truth_volume)).float()\n                union = torch.sum(torch.ge(_volume.add(ground_truth_volume), 1)).float()\n                sample_iou.append((intersection / union).item())\n\n            # IoU per taxonomy\n            if taxonomy_id not in test_iou:\n                test_iou[taxonomy_id] = {'n_samples': 0, 'iou': []}\n            test_iou[taxonomy_id]['n_samples'] += 1\n            test_iou[taxonomy_id]['iou'].append(sample_iou)\n\n            # Append generated volumes to TensorBoard\n            if output_dir 
and sample_idx < 3:\n                img_dir = output_dir % 'images'\n                # Volume Visualization\n                gv = generated_volume.cpu().numpy()\n                rendering_views = utils.binvox_visualization.get_volume_views(gv, os.path.join(img_dir, 'test'),\n                                                                              epoch_idx)\n                test_writer.add_image('Test Sample#%02d/Volume Reconstructed' % sample_idx, rendering_views, epoch_idx)\n                gtv = ground_truth_volume.cpu().numpy()\n                rendering_views = utils.binvox_visualization.get_volume_views(gtv, os.path.join(img_dir, 'test'),\n                                                                              epoch_idx)\n                test_writer.add_image('Test Sample#%02d/Volume GroundTruth' % sample_idx, rendering_views, epoch_idx)\n\n            # Print sample loss and IoU\n            print('[INFO] %s Test[%d/%d] Taxonomy = %s Sample = %s EDLoss = %.4f RLoss = %.4f IoU = %s' %\n                  (dt.now(), sample_idx + 1, n_samples, taxonomy_id, sample_name, encoder_loss.item(),\n                   refiner_loss.item(), ['%.4f' % si for si in sample_iou]))\n\n    # Output testing results\n    mean_iou = []\n    for taxonomy_id in test_iou:\n        test_iou[taxonomy_id]['iou'] = np.mean(test_iou[taxonomy_id]['iou'], axis=0)\n        mean_iou.append(test_iou[taxonomy_id]['iou'] * test_iou[taxonomy_id]['n_samples'])\n    mean_iou = np.sum(mean_iou, axis=0) / n_samples\n\n    # Print header\n    print('============================ TEST RESULTS ============================')\n    print('Taxonomy', end='\\t')\n    print('#Sample', end='\\t')\n    print('Baseline', end='\\t')\n    for th in cfg.TEST.VOXEL_THRESH:\n        print('t=%.2f' % th, end='\\t')\n    print()\n    # Print body\n    for taxonomy_id in test_iou:\n        print('%s' % taxonomies[taxonomy_id]['taxonomy_name'].ljust(8), end='\\t')\n        print('%d' % 
test_iou[taxonomy_id]['n_samples'], end='\\t')\n        if 'baseline' in taxonomies[taxonomy_id]:\n            print('%.4f' % taxonomies[taxonomy_id]['baseline']['%d-view' % cfg.CONST.N_VIEWS_RENDERING], end='\\t\\t')\n        else:\n            print('N/A', end='\\t\\t')\n\n        for ti in test_iou[taxonomy_id]['iou']:\n            print('%.4f' % ti, end='\\t')\n        print()\n    # Print mean IoU for each threshold\n    print('Overall ', end='\\t\\t\\t\\t')\n    for mi in mean_iou:\n        print('%.4f' % mi, end='\\t')\n    print('\\n')\n\n    # Add testing results to TensorBoard\n    max_iou = np.max(mean_iou)\n    if test_writer is not None:\n        test_writer.add_scalar('EncoderDecoder/EpochLoss', encoder_losses.avg, epoch_idx)\n        test_writer.add_scalar('Refiner/EpochLoss', refiner_losses.avg, epoch_idx)\n        test_writer.add_scalar('Refiner/IoU', max_iou, epoch_idx)\n\n    return max_iou\n"
  },
  {
    "path": "core/train.py",
    "content": "# -*- coding: utf-8 -*-\n#\n# Developed by Haozhe Xie <cshzxie@gmail.com>\n\nimport os\nimport random\nimport torch\nimport torch.backends.cudnn\nimport torch.utils.data\n\nimport utils.binvox_visualization\nimport utils.data_loaders\nimport utils.data_transforms\nimport utils.network_utils\n\nfrom datetime import datetime as dt\nfrom tensorboardX import SummaryWriter\nfrom time import time\n\nfrom core.test import test_net\nfrom models.encoder import Encoder\nfrom models.decoder import Decoder\nfrom models.refiner import Refiner\nfrom models.merger import Merger\n\n\ndef train_net(cfg):\n    # Enable the inbuilt cudnn auto-tuner to find the best algorithm to use\n    torch.backends.cudnn.benchmark = True\n\n    # Set up data augmentation\n    IMG_SIZE = cfg.CONST.IMG_H, cfg.CONST.IMG_W\n    CROP_SIZE = cfg.CONST.CROP_IMG_H, cfg.CONST.CROP_IMG_W\n    train_transforms = utils.data_transforms.Compose([\n        utils.data_transforms.RandomCrop(IMG_SIZE, CROP_SIZE),\n        utils.data_transforms.RandomBackground(cfg.TRAIN.RANDOM_BG_COLOR_RANGE),\n        utils.data_transforms.ColorJitter(cfg.TRAIN.BRIGHTNESS, cfg.TRAIN.CONTRAST, cfg.TRAIN.SATURATION),\n        utils.data_transforms.RandomNoise(cfg.TRAIN.NOISE_STD),\n        utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN, std=cfg.DATASET.STD),\n        utils.data_transforms.RandomFlip(),\n        utils.data_transforms.RandomPermuteRGB(),\n        utils.data_transforms.ToTensor(),\n    ])\n    val_transforms = utils.data_transforms.Compose([\n        utils.data_transforms.CenterCrop(IMG_SIZE, CROP_SIZE),\n        utils.data_transforms.RandomBackground(cfg.TEST.RANDOM_BG_COLOR_RANGE),\n        utils.data_transforms.Normalize(mean=cfg.DATASET.MEAN, std=cfg.DATASET.STD),\n        utils.data_transforms.ToTensor(),\n    ])\n\n    # Set up data loader\n    train_dataset_loader = utils.data_loaders.DATASET_LOADER_MAPPING[cfg.DATASET.TRAIN_DATASET](cfg)\n    val_dataset_loader = 
utils.data_loaders.DATASET_LOADER_MAPPING[cfg.DATASET.TEST_DATASET](cfg)\n    train_data_loader = torch.utils.data.DataLoader(dataset=train_dataset_loader.get_dataset(\n        utils.data_loaders.DatasetType.TRAIN, cfg.CONST.N_VIEWS_RENDERING, train_transforms),\n                                                    batch_size=cfg.CONST.BATCH_SIZE,\n                                                    num_workers=cfg.TRAIN.NUM_WORKER,\n                                                    pin_memory=True,\n                                                    shuffle=True,\n                                                    drop_last=True)\n    val_data_loader = torch.utils.data.DataLoader(dataset=val_dataset_loader.get_dataset(\n        utils.data_loaders.DatasetType.VAL, cfg.CONST.N_VIEWS_RENDERING, val_transforms),\n                                                  batch_size=1,\n                                                  num_workers=1,\n                                                  pin_memory=True,\n                                                  shuffle=False)\n\n    # Set up networks\n    encoder = Encoder(cfg)\n    decoder = Decoder(cfg)\n    refiner = Refiner(cfg)\n    merger = Merger(cfg)\n    print('[DEBUG] %s Parameters in Encoder: %d.' % (dt.now(), utils.network_utils.count_parameters(encoder)))\n    print('[DEBUG] %s Parameters in Decoder: %d.' % (dt.now(), utils.network_utils.count_parameters(decoder)))\n    print('[DEBUG] %s Parameters in Refiner: %d.' % (dt.now(), utils.network_utils.count_parameters(refiner)))\n    print('[DEBUG] %s Parameters in Merger: %d.' 
% (dt.now(), utils.network_utils.count_parameters(merger)))\n\n    # Initialize weights of networks\n    encoder.apply(utils.network_utils.init_weights)\n    decoder.apply(utils.network_utils.init_weights)\n    refiner.apply(utils.network_utils.init_weights)\n    merger.apply(utils.network_utils.init_weights)\n\n    # Set up solver\n    if cfg.TRAIN.POLICY == 'adam':\n        encoder_solver = torch.optim.Adam(filter(lambda p: p.requires_grad, encoder.parameters()),\n                                          lr=cfg.TRAIN.ENCODER_LEARNING_RATE,\n                                          betas=cfg.TRAIN.BETAS)\n        decoder_solver = torch.optim.Adam(decoder.parameters(),\n                                          lr=cfg.TRAIN.DECODER_LEARNING_RATE,\n                                          betas=cfg.TRAIN.BETAS)\n        refiner_solver = torch.optim.Adam(refiner.parameters(),\n                                          lr=cfg.TRAIN.REFINER_LEARNING_RATE,\n                                          betas=cfg.TRAIN.BETAS)\n        merger_solver = torch.optim.Adam(merger.parameters(), lr=cfg.TRAIN.MERGER_LEARNING_RATE, betas=cfg.TRAIN.BETAS)\n    elif cfg.TRAIN.POLICY == 'sgd':\n        encoder_solver = torch.optim.SGD(filter(lambda p: p.requires_grad, encoder.parameters()),\n                                         lr=cfg.TRAIN.ENCODER_LEARNING_RATE,\n                                         momentum=cfg.TRAIN.MOMENTUM)\n        decoder_solver = torch.optim.SGD(decoder.parameters(),\n                                         lr=cfg.TRAIN.DECODER_LEARNING_RATE,\n                                         momentum=cfg.TRAIN.MOMENTUM)\n        refiner_solver = torch.optim.SGD(refiner.parameters(),\n                                         lr=cfg.TRAIN.REFINER_LEARNING_RATE,\n                                         momentum=cfg.TRAIN.MOMENTUM)\n        merger_solver = torch.optim.SGD(merger.parameters(),\n                                        
lr=cfg.TRAIN.MERGER_LEARNING_RATE,\n                                        momentum=cfg.TRAIN.MOMENTUM)\n    else:\n        raise Exception('[FATAL] %s Unknown optimizer %s.' % (dt.now(), cfg.TRAIN.POLICY))\n\n    # Set up learning rate scheduler to decay learning rates dynamically\n    encoder_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(encoder_solver,\n                                                                milestones=cfg.TRAIN.ENCODER_LR_MILESTONES,\n                                                                gamma=cfg.TRAIN.GAMMA)\n    decoder_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(decoder_solver,\n                                                                milestones=cfg.TRAIN.DECODER_LR_MILESTONES,\n                                                                gamma=cfg.TRAIN.GAMMA)\n    refiner_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(refiner_solver,\n                                                                milestones=cfg.TRAIN.REFINER_LR_MILESTONES,\n                                                                gamma=cfg.TRAIN.GAMMA)\n    merger_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(merger_solver,\n                                                               milestones=cfg.TRAIN.MERGER_LR_MILESTONES,\n                                                               gamma=cfg.TRAIN.GAMMA)\n\n    if torch.cuda.is_available():\n        encoder = torch.nn.DataParallel(encoder).cuda()\n        decoder = torch.nn.DataParallel(decoder).cuda()\n        refiner = torch.nn.DataParallel(refiner).cuda()\n        merger = torch.nn.DataParallel(merger).cuda()\n\n    # Set up loss functions\n    bce_loss = torch.nn.BCELoss()\n\n    # Load pretrained model if exists\n    init_epoch = 0\n    best_iou = -1\n    best_epoch = -1\n    if 'WEIGHTS' in cfg.CONST and cfg.TRAIN.RESUME_TRAIN:\n        print('[INFO] %s Recovering from %s ...' 
% (dt.now(), cfg.CONST.WEIGHTS))\n        checkpoint = torch.load(cfg.CONST.WEIGHTS)\n        init_epoch = checkpoint['epoch_idx']\n        best_iou = checkpoint['best_iou']\n        best_epoch = checkpoint['best_epoch']\n\n        encoder.load_state_dict(checkpoint['encoder_state_dict'])\n        decoder.load_state_dict(checkpoint['decoder_state_dict'])\n        if cfg.NETWORK.USE_REFINER:\n            refiner.load_state_dict(checkpoint['refiner_state_dict'])\n        if cfg.NETWORK.USE_MERGER:\n            merger.load_state_dict(checkpoint['merger_state_dict'])\n\n        print('[INFO] %s Recovery complete. Current epoch #%d, Best IoU = %.4f at epoch #%d.' %\n              (dt.now(), init_epoch, best_iou, best_epoch))\n\n    # Summary writer for TensorBoard\n    output_dir = os.path.join(cfg.DIR.OUT_PATH, '%s', dt.now().isoformat())\n    log_dir = output_dir % 'logs'\n    ckpt_dir = output_dir % 'checkpoints'\n    train_writer = SummaryWriter(os.path.join(log_dir, 'train'))\n    val_writer = SummaryWriter(os.path.join(log_dir, 'test'))\n\n    # Training loop\n    for epoch_idx in range(init_epoch, cfg.TRAIN.NUM_EPOCHES):\n        # Tick / tock\n        epoch_start_time = time()\n\n        # Batch average metrics\n        batch_time = utils.network_utils.AverageMeter()\n        data_time = utils.network_utils.AverageMeter()\n        encoder_losses = utils.network_utils.AverageMeter()\n        refiner_losses = utils.network_utils.AverageMeter()\n\n        # Switch models to training mode\n        encoder.train()\n        decoder.train()\n        merger.train()\n        refiner.train()\n\n        batch_end_time = time()\n        n_batches = len(train_data_loader)\n        for batch_idx, (taxonomy_names, sample_names, rendering_images,\n                        ground_truth_volumes) in enumerate(train_data_loader):\n            # Measure data time\n            data_time.update(time() - batch_end_time)\n\n            # Get data from data loader\n
            rendering_images = utils.network_utils.var_or_cuda(rendering_images)\n            ground_truth_volumes = utils.network_utils.var_or_cuda(ground_truth_volumes)\n\n            # Train the encoder, decoder, refiner, and merger\n            image_features = encoder(rendering_images)\n            raw_features, generated_volumes = decoder(image_features)\n\n            if cfg.NETWORK.USE_MERGER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_MERGER:\n                generated_volumes = merger(raw_features, generated_volumes)\n            else:\n                generated_volumes = torch.mean(generated_volumes, dim=1)\n            encoder_loss = bce_loss(generated_volumes, ground_truth_volumes) * 10\n\n            if cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:\n                generated_volumes = refiner(generated_volumes)\n                refiner_loss = bce_loss(generated_volumes, ground_truth_volumes) * 10\n            else:\n                refiner_loss = encoder_loss\n\n            # Gradient descent\n            encoder.zero_grad()\n            decoder.zero_grad()\n            refiner.zero_grad()\n            merger.zero_grad()\n\n            if cfg.NETWORK.USE_REFINER and epoch_idx >= cfg.TRAIN.EPOCH_START_USE_REFINER:\n                encoder_loss.backward(retain_graph=True)\n                refiner_loss.backward()\n            else:\n                encoder_loss.backward()\n\n            encoder_solver.step()\n            decoder_solver.step()\n            refiner_solver.step()\n            merger_solver.step()\n\n            # Append loss to average metrics\n            encoder_losses.update(encoder_loss.item())\n            refiner_losses.update(refiner_loss.item())\n            # Append loss to TensorBoard\n            n_itr = epoch_idx * n_batches + batch_idx\n            train_writer.add_scalar('EncoderDecoder/BatchLoss', encoder_loss.item(), n_itr)\n            train_writer.add_scalar('Refiner/BatchLoss', refiner_loss.item(), 
n_itr)\n\n            # Tick / tock\n            batch_time.update(time() - batch_end_time)\n            batch_end_time = time()\n            print(\n                '[INFO] %s [Epoch %d/%d][Batch %d/%d] BatchTime = %.3f (s) DataTime = %.3f (s) EDLoss = %.4f RLoss = %.4f'\n                % (dt.now(), epoch_idx + 1, cfg.TRAIN.NUM_EPOCHES, batch_idx + 1, n_batches, batch_time.val,\n                   data_time.val, encoder_loss.item(), refiner_loss.item()))\n\n        # Append epoch loss to TensorBoard\n        train_writer.add_scalar('EncoderDecoder/EpochLoss', encoder_losses.avg, epoch_idx + 1)\n        train_writer.add_scalar('Refiner/EpochLoss', refiner_losses.avg, epoch_idx + 1)\n\n        # Adjust learning rate\n        encoder_lr_scheduler.step()\n        decoder_lr_scheduler.step()\n        refiner_lr_scheduler.step()\n        merger_lr_scheduler.step()\n\n        # Tick / tock\n        epoch_end_time = time()\n        print('[INFO] %s Epoch [%d/%d] EpochTime = %.3f (s) EDLoss = %.4f RLoss = %.4f' %\n              (dt.now(), epoch_idx + 1, cfg.TRAIN.NUM_EPOCHES, epoch_end_time - epoch_start_time, encoder_losses.avg,\n               refiner_losses.avg))\n\n        # Update Rendering Views\n        if cfg.TRAIN.UPDATE_N_VIEWS_RENDERING:\n            n_views_rendering = random.randint(1, cfg.CONST.N_VIEWS_RENDERING)\n            train_data_loader.dataset.set_n_views_rendering(n_views_rendering)\n            print('[INFO] %s Epoch [%d/%d] Update #RenderingViews to %d' %\n                  (dt.now(), epoch_idx + 2, cfg.TRAIN.NUM_EPOCHES, n_views_rendering))\n\n        # Validate the training models\n        iou = test_net(cfg, epoch_idx + 1, output_dir, val_data_loader, val_writer, encoder, decoder, refiner, merger)\n\n        # Save weights to file\n        if (epoch_idx + 1) % cfg.TRAIN.SAVE_FREQ == 0:\n            if not os.path.exists(ckpt_dir):\n                os.makedirs(ckpt_dir)\n\n            utils.network_utils.save_checkpoints(cfg, 
os.path.join(ckpt_dir, 'ckpt-epoch-%04d.pth' % (epoch_idx + 1)),\n                                                 epoch_idx + 1, encoder, encoder_solver, decoder, decoder_solver,\n                                                 refiner, refiner_solver, merger, merger_solver, best_iou, best_epoch)\n        if iou > best_iou:\n            if not os.path.exists(ckpt_dir):\n                os.makedirs(ckpt_dir)\n\n            best_iou = iou\n            best_epoch = epoch_idx + 1\n            utils.network_utils.save_checkpoints(cfg, os.path.join(ckpt_dir, 'best-ckpt.pth'), epoch_idx + 1, encoder,\n                                                 encoder_solver, decoder, decoder_solver, refiner, refiner_solver,\n                                                 merger, merger_solver, best_iou, best_epoch)\n\n    # Close SummaryWriter for TensorBoard\n    train_writer.close()\n    val_writer.close()\n"
  },
  {
    "path": "models/__init__.py",
    "content": ""
  },
  {
    "path": "models/decoder.py",
    "content": "# -*- coding: utf-8 -*-\n#\n# Developed by Haozhe Xie <cshzxie@gmail.com>\n\nimport torch\n\n\nclass Decoder(torch.nn.Module):\n    def __init__(self, cfg):\n        super(Decoder, self).__init__()\n        self.cfg = cfg\n\n        # Layer Definition\n        self.layer1 = torch.nn.Sequential(\n            torch.nn.ConvTranspose3d(2048, 512, kernel_size=4, stride=2, bias=cfg.NETWORK.TCONV_USE_BIAS, padding=1),\n            torch.nn.BatchNorm3d(512),\n            torch.nn.ReLU()\n        )\n        self.layer2 = torch.nn.Sequential(\n            torch.nn.ConvTranspose3d(512, 128, kernel_size=4, stride=2, bias=cfg.NETWORK.TCONV_USE_BIAS, padding=1),\n            torch.nn.BatchNorm3d(128),\n            torch.nn.ReLU()\n        )\n        self.layer3 = torch.nn.Sequential(\n            torch.nn.ConvTranspose3d(128, 32, kernel_size=4, stride=2, bias=cfg.NETWORK.TCONV_USE_BIAS, padding=1),\n            torch.nn.BatchNorm3d(32),\n            torch.nn.ReLU()\n        )\n        self.layer4 = torch.nn.Sequential(\n            torch.nn.ConvTranspose3d(32, 8, kernel_size=4, stride=2, bias=cfg.NETWORK.TCONV_USE_BIAS, padding=1),\n            torch.nn.BatchNorm3d(8),\n            torch.nn.ReLU()\n        )\n        self.layer5 = torch.nn.Sequential(\n            torch.nn.ConvTranspose3d(8, 1, kernel_size=1, bias=cfg.NETWORK.TCONV_USE_BIAS),\n            torch.nn.Sigmoid()\n        )\n\n    def forward(self, image_features):\n        image_features = image_features.permute(1, 0, 2, 3, 4).contiguous()\n        image_features = torch.split(image_features, 1, dim=0)\n        gen_volumes = []\n        raw_features = []\n\n        for features in image_features:\n            gen_volume = features.view(-1, 2048, 2, 2, 2)\n            # print(gen_volume.size())   # torch.Size([batch_size, 2048, 2, 2, 2])\n            gen_volume = self.layer1(gen_volume)\n            # print(gen_volume.size())   # torch.Size([batch_size, 512, 4, 4, 4])\n            gen_volume = 
self.layer2(gen_volume)\n            # print(gen_volume.size())   # torch.Size([batch_size, 128, 8, 8, 8])\n            gen_volume = self.layer3(gen_volume)\n            # print(gen_volume.size())   # torch.Size([batch_size, 32, 16, 16, 16])\n            gen_volume = self.layer4(gen_volume)\n            raw_feature = gen_volume\n            # print(gen_volume.size())   # torch.Size([batch_size, 8, 32, 32, 32])\n            gen_volume = self.layer5(gen_volume)\n            # print(gen_volume.size())   # torch.Size([batch_size, 1, 32, 32, 32])\n            raw_feature = torch.cat((raw_feature, gen_volume), dim=1)\n            # print(raw_feature.size())  # torch.Size([batch_size, 9, 32, 32, 32])\n\n            gen_volumes.append(torch.squeeze(gen_volume, dim=1))\n            raw_features.append(raw_feature)\n\n        gen_volumes = torch.stack(gen_volumes).permute(1, 0, 2, 3, 4).contiguous()\n        raw_features = torch.stack(raw_features).permute(1, 0, 2, 3, 4, 5).contiguous()\n        # print(gen_volumes.size())      # torch.Size([batch_size, n_views, 32, 32, 32])\n        # print(raw_features.size())     # torch.Size([batch_size, n_views, 9, 32, 32, 32])\n        return raw_features, gen_volumes\n"
  },
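The size comments in the decoder's `forward()` follow directly from the `ConvTranspose3d` output-size formula: the four kernel-4 / stride-2 / padding-1 layers each double the grid, taking the reshaped 2³ feature cube up to the 32³ volume, and the final kernel-1 layer leaves the spatial size unchanged. A standalone sanity check of that arithmetic (not part of the repo):

```python
def tconv3d_out_size(d_in, kernel_size, stride=1, padding=0):
    # PyTorch ConvTranspose3d output size per axis (output_padding and dilation omitted):
    # D_out = (D_in - 1) * stride - 2 * padding + kernel_size
    return (d_in - 1) * stride - 2 * padding + kernel_size

d = 2                                    # decoder input grid: 2048 x 2 x 2 x 2
for _ in range(4):                       # layer1..layer4 all use k=4, s=2, p=1
    d = tconv3d_out_size(d, kernel_size=4, stride=2, padding=1)
assert d == 32                           # matches the 32^3 comments above
assert tconv3d_out_size(d, kernel_size=1) == d    # layer5 (k=1) keeps the size
```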
  {
    "path": "models/encoder.py",
    "content": "# -*- coding: utf-8 -*-\n#\n# Developed by Haozhe Xie <cshzxie@gmail.com>\n#\n# References:\n# - https://github.com/shawnxu1318/MVCNN-Multi-View-Convolutional-Neural-Networks/blob/master/mvcnn.py\n\nimport torch\nimport torchvision.models\n\n\nclass Encoder(torch.nn.Module):\n    def __init__(self, cfg):\n        super(Encoder, self).__init__()\n        self.cfg = cfg\n\n        # Layer Definition\n        vgg16_bn = torchvision.models.vgg16_bn(pretrained=True)\n        self.vgg = torch.nn.Sequential(*list(vgg16_bn.features.children()))[:27]\n        self.layer1 = torch.nn.Sequential(\n            torch.nn.Conv2d(512, 512, kernel_size=3),\n            torch.nn.BatchNorm2d(512),\n            torch.nn.ELU(),\n        )\n        self.layer2 = torch.nn.Sequential(\n            torch.nn.Conv2d(512, 512, kernel_size=3),\n            torch.nn.BatchNorm2d(512),\n            torch.nn.ELU(),\n            torch.nn.MaxPool2d(kernel_size=3)\n        )\n        self.layer3 = torch.nn.Sequential(\n            torch.nn.Conv2d(512, 256, kernel_size=1),\n            torch.nn.BatchNorm2d(256),\n            torch.nn.ELU()\n        )\n\n        # Don't update params in VGG16\n        for param in vgg16_bn.parameters():\n            param.requires_grad = False\n\n    def forward(self, rendering_images):\n        # print(rendering_images.size())  # torch.Size([batch_size, n_views, img_c, img_h, img_w])\n        rendering_images = rendering_images.permute(1, 0, 2, 3, 4).contiguous()\n        rendering_images = torch.split(rendering_images, 1, dim=0)\n        image_features = []\n\n        for img in rendering_images:\n            features = self.vgg(img.squeeze(dim=0))\n            # print(features.size())    # torch.Size([batch_size, 512, 28, 28])\n            features = self.layer1(features)\n            # print(features.size())    # torch.Size([batch_size, 512, 26, 26])\n            features = self.layer2(features)\n            # print(features.size())    # 
torch.Size([batch_size, 512, 24, 24])\n            features = self.layer3(features)\n            # print(features.size())    # torch.Size([batch_size, 256, 8, 8])\n            image_features.append(features)\n\n        image_features = torch.stack(image_features).permute(1, 0, 2, 3, 4).contiguous()\n        # print(image_features.size())  # torch.Size([batch_size, n_views, 256, 8, 8])\n        return image_features\n"
  },
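The size comments in the encoder likewise follow from standard convolution arithmetic, assuming the usual 224×224 input crops: `vgg16_bn.features[:27]` contains three 2×2 max-pools (224 → 28), the two unpadded 3×3 convolutions shave the maps to 26 and 24, and `MaxPool2d(kernel_size=3)` (stride defaults to the kernel size) brings them to 8. A quick check, not part of the repo:

```python
def conv_out(d, k, s=1, p=0):
    # Conv2d / MaxPool2d output size per axis: floor((d + 2p - k) / s) + 1
    return (d + 2 * p - k) // s + 1

d = 224 // 8               # three 2x2 max-pools inside vgg16_bn.features[:27] -> 28
d = conv_out(d, k=3)       # layer1 conv, no padding -> 26
d = conv_out(d, k=3)       # layer2 conv, no padding -> 24
d = conv_out(d, k=3, s=3)  # layer2 MaxPool2d(kernel_size=3) -> 8
d = conv_out(d, k=1)       # layer3 1x1 conv keeps spatial size -> 8
assert d == 8              # matches the [batch_size, 256, 8, 8] comment
```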
  {
    "path": "models/merger.py",
    "content": "# -*- coding: utf-8 -*-\n#\n# Developed by Haozhe Xie <cshzxie@gmail.com>\n\nimport torch\n\n\nclass Merger(torch.nn.Module):\n    def __init__(self, cfg):\n        super(Merger, self).__init__()\n        self.cfg = cfg\n\n        # Layer Definition\n        self.layer1 = torch.nn.Sequential(\n            torch.nn.Conv3d(9, 16, kernel_size=3, padding=1),\n            torch.nn.BatchNorm3d(16),\n            torch.nn.LeakyReLU(cfg.NETWORK.LEAKY_VALUE)\n        )\n        self.layer2 = torch.nn.Sequential(\n            torch.nn.Conv3d(16, 8, kernel_size=3, padding=1),\n            torch.nn.BatchNorm3d(8),\n            torch.nn.LeakyReLU(cfg.NETWORK.LEAKY_VALUE)\n        )\n        self.layer3 = torch.nn.Sequential(\n            torch.nn.Conv3d(8, 4, kernel_size=3, padding=1),\n            torch.nn.BatchNorm3d(4),\n            torch.nn.LeakyReLU(cfg.NETWORK.LEAKY_VALUE)\n        )\n        self.layer4 = torch.nn.Sequential(\n            torch.nn.Conv3d(4, 2, kernel_size=3, padding=1),\n            torch.nn.BatchNorm3d(2),\n            torch.nn.LeakyReLU(cfg.NETWORK.LEAKY_VALUE)\n        )\n        self.layer5 = torch.nn.Sequential(\n            torch.nn.Conv3d(2, 1, kernel_size=3, padding=1),\n            torch.nn.BatchNorm3d(1),\n            torch.nn.LeakyReLU(cfg.NETWORK.LEAKY_VALUE)\n        )\n\n    def forward(self, raw_features, coarse_volumes):\n        n_views_rendering = coarse_volumes.size(1)\n        raw_features = torch.split(raw_features, 1, dim=1)\n        volume_weights = []\n\n        for i in range(n_views_rendering):\n            raw_feature = torch.squeeze(raw_features[i], dim=1)\n            # print(raw_feature.size())       # torch.Size([batch_size, 9, 32, 32, 32])\n\n            volume_weight = self.layer1(raw_feature)\n            # print(volume_weight.size())     # torch.Size([batch_size, 16, 32, 32, 32])\n            volume_weight = self.layer2(volume_weight)\n            # print(volume_weight.size())     # 
torch.Size([batch_size, 8, 32, 32, 32])\n            volume_weight = self.layer3(volume_weight)\n            # print(volume_weight.size())     # torch.Size([batch_size, 4, 32, 32, 32])\n            volume_weight = self.layer4(volume_weight)\n            # print(volume_weight.size())     # torch.Size([batch_size, 2, 32, 32, 32])\n            volume_weight = self.layer5(volume_weight)\n            # print(volume_weight.size())     # torch.Size([batch_size, 1, 32, 32, 32])\n\n            volume_weight = torch.squeeze(volume_weight, dim=1)\n            # print(volume_weight.size())     # torch.Size([batch_size, 32, 32, 32])\n            volume_weights.append(volume_weight)\n\n        volume_weights = torch.stack(volume_weights).permute(1, 0, 2, 3, 4).contiguous()\n        volume_weights = torch.softmax(volume_weights, dim=1)\n        # print(volume_weights.size())        # torch.Size([batch_size, n_views, 32, 32, 32])\n        # print(coarse_volumes.size())        # torch.Size([batch_size, n_views, 32, 32, 32])\n        coarse_volumes = coarse_volumes * volume_weights\n        coarse_volumes = torch.sum(coarse_volumes, dim=1)\n\n        return torch.clamp(coarse_volumes, min=0, max=1)\n"
  },
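The merger scores every voxel of every view, softmaxes the scores across the view axis, and takes the weighted sum of the per-view coarse volumes, so the output is a per-voxel convex combination of the views. A numpy sketch of the same fusion (toy 4³ grid and 3 views instead of n_views × 32³, chosen only to keep the example small; not part of the repo):

```python
import numpy as np

rng = np.random.default_rng(0)
scores = rng.normal(size=(3, 4, 4, 4))     # per-view, per-voxel scores (layer5 output)
coarse = rng.uniform(size=(3, 4, 4, 4))    # per-view coarse occupancies in [0, 1]

# Softmax across the view axis, as torch.softmax(volume_weights, dim=1) does
e = np.exp(scores - scores.max(axis=0, keepdims=True))
weights = e / e.sum(axis=0, keepdims=True)

# Weighted sum over views; the clip mirrors torch.clamp(..., min=0, max=1)
fused = np.clip((coarse * weights).sum(axis=0), 0.0, 1.0)

assert np.allclose(weights.sum(axis=0), 1.0)    # a convex combination per voxel
assert fused.shape == (4, 4, 4)
```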
  {
    "path": "models/refiner.py",
    "content": "# -*- coding: utf-8 -*-\n#\n# Developed by Haozhe Xie <cshzxie@gmail.com>\n\nimport torch\n\n\nclass Refiner(torch.nn.Module):\n    def __init__(self, cfg):\n        super(Refiner, self).__init__()\n        self.cfg = cfg\n\n        # Layer Definition\n        self.layer1 = torch.nn.Sequential(\n            torch.nn.Conv3d(1, 32, kernel_size=4, padding=2),\n            torch.nn.BatchNorm3d(32),\n            torch.nn.LeakyReLU(cfg.NETWORK.LEAKY_VALUE),\n            torch.nn.MaxPool3d(kernel_size=2)\n        )\n        self.layer2 = torch.nn.Sequential(\n            torch.nn.Conv3d(32, 64, kernel_size=4, padding=2),\n            torch.nn.BatchNorm3d(64),\n            torch.nn.LeakyReLU(cfg.NETWORK.LEAKY_VALUE),\n            torch.nn.MaxPool3d(kernel_size=2)\n        )\n        self.layer3 = torch.nn.Sequential(\n            torch.nn.Conv3d(64, 128, kernel_size=4, padding=2),\n            torch.nn.BatchNorm3d(128),\n            torch.nn.LeakyReLU(cfg.NETWORK.LEAKY_VALUE),\n            torch.nn.MaxPool3d(kernel_size=2)\n        )\n        self.layer4 = torch.nn.Sequential(\n            torch.nn.Linear(8192, 2048),\n            torch.nn.ReLU()\n        )\n        self.layer5 = torch.nn.Sequential(\n            torch.nn.Linear(2048, 8192),\n            torch.nn.ReLU()\n        )\n        self.layer6 = torch.nn.Sequential(\n            torch.nn.ConvTranspose3d(128, 64, kernel_size=4, stride=2, bias=cfg.NETWORK.TCONV_USE_BIAS, padding=1),\n            torch.nn.BatchNorm3d(64),\n            torch.nn.ReLU()\n        )\n        self.layer7 = torch.nn.Sequential(\n            torch.nn.ConvTranspose3d(64, 32, kernel_size=4, stride=2, bias=cfg.NETWORK.TCONV_USE_BIAS, padding=1),\n            torch.nn.BatchNorm3d(32),\n            torch.nn.ReLU()\n        )\n        self.layer8 = torch.nn.Sequential(\n            torch.nn.ConvTranspose3d(32, 1, kernel_size=4, stride=2, bias=cfg.NETWORK.TCONV_USE_BIAS, padding=1),\n            torch.nn.Sigmoid()\n        )\n\n    
def forward(self, coarse_volumes):\n        volumes_32_l = coarse_volumes.view((-1, 1, self.cfg.CONST.N_VOX, self.cfg.CONST.N_VOX, self.cfg.CONST.N_VOX))\n        # print(volumes_32_l.size())       # torch.Size([batch_size, 1, 32, 32, 32])\n        volumes_16_l = self.layer1(volumes_32_l)\n        # print(volumes_16_l.size())       # torch.Size([batch_size, 32, 16, 16, 16])\n        volumes_8_l = self.layer2(volumes_16_l)\n        # print(volumes_8_l.size())        # torch.Size([batch_size, 64, 8, 8, 8])\n        volumes_4_l = self.layer3(volumes_8_l)\n        # print(volumes_4_l.size())        # torch.Size([batch_size, 128, 4, 4, 4])\n        flatten_features = self.layer4(volumes_4_l.view(-1, 8192))\n        # print(flatten_features.size())   # torch.Size([batch_size, 2048])\n        flatten_features = self.layer5(flatten_features)\n        # print(flatten_features.size())   # torch.Size([batch_size, 8192])\n        volumes_4_r = volumes_4_l + flatten_features.view(-1, 128, 4, 4, 4)\n        # print(volumes_4_r.size())        # torch.Size([batch_size, 128, 4, 4, 4])\n        volumes_8_r = volumes_8_l + self.layer6(volumes_4_r)\n        # print(volumes_8_r.size())        # torch.Size([batch_size, 64, 8, 8, 8])\n        volumes_16_r = volumes_16_l + self.layer7(volumes_8_r)\n        # print(volumes_16_r.size())       # torch.Size([batch_size, 32, 16, 16, 16])\n        volumes_32_r = (volumes_32_l + self.layer8(volumes_16_r)) * 0.5\n        # print(volumes_32_r.size())       # torch.Size([batch_size, 1, 32, 32, 32])\n\n        return volumes_32_r.view((-1, self.cfg.CONST.N_VOX, self.cfg.CONST.N_VOX, self.cfg.CONST.N_VOX))\n"
  },
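The refiner's `Linear(8192, 2048)` bottleneck follows from the shape arithmetic of the three downsampling stages: each `Conv3d(kernel_size=4, padding=2)` grows the side by one, and the stride-2 max-pool floors it back to half, taking a 32³ volume down to 4³ with 128 channels. A quick check of that arithmetic (not part of the repo):

```python
def conv_out(d, k, s=1, p=0):
    # Conv3d / MaxPool3d output size per axis: floor((d + 2p - k) / s) + 1
    return (d + 2 * p - k) // s + 1

side = 32
for _ in range(3):                      # layer1..layer3
    side = conv_out(side, k=4, p=2)     # conv grows the side: 32 -> 33, 16 -> 17, 8 -> 9
    side = conv_out(side, k=2, s=2)     # MaxPool3d(2) floors it: 33 -> 16, 17 -> 8, 9 -> 4
flat = 128 * side ** 3                  # bottleneck channels x spatial volume
assert (side, flat) == (4, 8192)        # matches Linear(8192, 2048) in layer4
```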
  {
    "path": "requirements.txt",
    "content": "easydict\nmatplotlib\nnumpy\nopencv-python\nscipy\ntorch\ntorchvision\ntensorboardX\n"
  },
  {
    "path": "runner.py",
    "content": "#!/usr/bin/python3\n# -*- coding: utf-8 -*-\n#\n# Developed by Haozhe Xie <cshzxie@gmail.com>\n\nimport logging\nimport matplotlib\nimport multiprocessing as mp\nimport numpy as np\nimport os\nimport sys\n# Fix problem: no $DISPLAY environment variable\nmatplotlib.use('Agg')\n\nfrom argparse import ArgumentParser\nfrom datetime import datetime as dt\nfrom pprint import pprint\n\nfrom config import cfg\nfrom core.train import train_net\nfrom core.test import test_net\n\n\ndef get_args_from_command_line():\n    parser = ArgumentParser(description='Parser of Runner of Pix2Vox')\n    parser.add_argument('--gpu',\n                        dest='gpu_id',\n                        help='GPU device id to use [cuda0]',\n                        default=cfg.CONST.DEVICE,\n                        type=str)\n    parser.add_argument('--rand', dest='randomize', help='Randomize (do not use a fixed seed)', action='store_true')\n    parser.add_argument('--test', dest='test', help='Test neural networks', action='store_true')\n    parser.add_argument('--batch-size',\n                        dest='batch_size',\n                        help='Batch size',\n                        default=cfg.CONST.BATCH_SIZE,\n                        type=int)\n    parser.add_argument('--epoch', dest='epoch', help='Number of epochs', default=cfg.TRAIN.NUM_EPOCHES, type=int)\n    parser.add_argument('--weights', dest='weights', help='Initialize network from the weights file', default=None)\n    parser.add_argument('--out', dest='out_path', help='Set output path', default=cfg.DIR.OUT_PATH)\n    args = parser.parse_args()\n    return args\n\n\ndef main():\n    # Get args from command line\n    args = get_args_from_command_line()\n\n    if args.gpu_id is not None:\n        cfg.CONST.DEVICE = args.gpu_id\n    if not args.randomize:\n        np.random.seed(cfg.CONST.RNG_SEED)\n    if args.batch_size is not None:\n        cfg.CONST.BATCH_SIZE = args.batch_size\n    if args.epoch is not None:\n        cfg.TRAIN.NUM_EPOCHES = args.epoch\n    if args.out_path is not None:\n        cfg.DIR.OUT_PATH = args.out_path\n    if args.weights is not None:\n        cfg.CONST.WEIGHTS = args.weights\n        if not args.test:\n            cfg.TRAIN.RESUME_TRAIN = True\n\n    # Print config\n    print('Use config:')\n    pprint(cfg)\n\n    # Set GPU to use\n    if type(cfg.CONST.DEVICE) == str:\n        os.environ[\"CUDA_VISIBLE_DEVICES\"] = cfg.CONST.DEVICE\n\n    # Start train/test process\n    if not args.test:\n        train_net(cfg)\n    else:\n        if 'WEIGHTS' in cfg.CONST and os.path.exists(cfg.CONST.WEIGHTS):\n            test_net(cfg)\n        else:\n            print('[FATAL] %s Please specify the file path of checkpoint.' % (dt.now()))\n            sys.exit(2)\n\n\nif __name__ == '__main__':\n    # Check python version\n    if sys.version_info < (3, 0):\n        raise Exception(\"Please follow the installation instruction on 'https://github.com/hzxie/Pix2Vox'\")\n\n    # Setup logger\n    mp.log_to_stderr()\n    logger = mp.get_logger()\n    logger.setLevel(logging.INFO)\n\n    main()\n"
  },
  {
    "path": "utils/__init__.py",
    "content": ""
  },
  {
    "path": "utils/binvox_converter.py",
    "content": "#!/usr/bin/python3\n# -*- coding: utf-8 -*-\n#\n# Developed by Haozhe Xie <cshzxie@gmail.com>\n#\n# This script is used to convert OFF format to binvox.\n# Please make sure that you have `binvox` installed.\n# You can get it at http://www.patrickmin.com/binvox/\n\nimport numpy as np\nimport os\nimport subprocess\nimport sys\n\nfrom datetime import datetime as dt\nfrom glob import glob\n\nimport binvox_rw\n\n\ndef main():\n    if not len(sys.argv) == 2:\n        print('python binvox_converter.py input_file_folder')\n        sys.exit(1)\n\n    input_file_folder = sys.argv[1]\n    if not os.path.exists(input_file_folder) or not os.path.isdir(input_file_folder):\n        print('[ERROR] Input folder does not exist!')\n        sys.exit(2)\n\n    N_VOX = 32\n    MESH_EXTENSION = '*.off'\n\n    folder_path = os.path.join(input_file_folder, MESH_EXTENSION)\n    mesh_files = glob(folder_path)\n\n    for m_file in mesh_files:\n        # glob() already returns paths prefixed with the input folder,\n        # so do not join them with the folder again\n        file_path = m_file\n        file_name, ext = os.path.splitext(m_file)\n        binvox_file_path = '%s.binvox' % file_name\n\n        if os.path.exists(binvox_file_path):\n            print('[WARN] %s File: %s exists. It will be overwritten.' 
% (dt.now(), binvox_file_path))\n            os.remove(binvox_file_path)\n\n        print('[INFO] %s Processing file: %s' % (dt.now(), file_path))\n        rc = subprocess.call(['binvox', '-d', str(N_VOX), '-e', '-cb', '-rotx', '-rotx', '-rotx', '-rotz', m_file])\n        if not rc == 0:\n            print('[WARN] %s Failed to convert file: %s' % (dt.now(), m_file))\n            continue\n\n        with open(binvox_file_path, 'rb') as file:\n            v = binvox_rw.read_as_3d_array(file)\n\n        v.data = np.transpose(v.data, (2, 0, 1))\n        with open(binvox_file_path, 'wb') as file:\n            binvox_rw.write(v, file)\n\n\nif __name__ == '__main__':\n    return_code = subprocess.call(['which', 'binvox'], stdout=subprocess.PIPE)\n    if return_code == 0:\n        main()\n    else:\n        print('[FATAL] %s Please make sure you have binvox installed.' % dt.now())\n"
  },
  {
    "path": "utils/binvox_rw.py",
    "content": "#  Copyright (C) 2012 Daniel Maturana\n#  This file is part of binvox-rw-py.\n#\n#  binvox-rw-py is free software: you can redistribute it and/or modify\n#  it under the terms of the GNU General Public License as published by\n#  the Free Software Foundation, either version 3 of the License, or\n#  (at your option) any later version.\n#\n#  binvox-rw-py is distributed in the hope that it will be useful,\n#  but WITHOUT ANY WARRANTY; without even the implied warranty of\n#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the\n#  GNU General Public License for more details.\n#\n#  You should have received a copy of the GNU General Public License\n#  along with binvox-rw-py. If not, see <http://www.gnu.org/licenses/>.\n#\n\"\"\"\nBinvox to Numpy and back.\n\n\n>>> import numpy as np\n>>> import binvox_rw\n>>> with open('chair.binvox', 'rb') as f:\n...     m1 = binvox_rw.read_as_3d_array(f)\n...\n>>> m1.dims\n[32, 32, 32]\n>>> m1.scale\n41.133000000000003\n>>> m1.translate\n[0.0, 0.0, 0.0]\n>>> with open('chair_out.binvox', 'wb') as f:\n...     m1.write(f)\n...\n>>> with open('chair_out.binvox', 'rb') as f:\n...     m2 = binvox_rw.read_as_3d_array(f)\n...\n>>> m1.dims == m2.dims\nTrue\n>>> m1.scale == m2.scale\nTrue\n>>> m1.translate == m2.translate\nTrue\n>>> np.all(m1.data == m2.data)\nTrue\n\n>>> with open('chair.binvox', 'rb') as f:\n...     md = binvox_rw.read_as_3d_array(f)\n...\n>>> with open('chair.binvox', 'rb') as f:\n...     
ms = binvox_rw.read_as_coord_array(f)\n...\n>>> data_ds = binvox_rw.dense_to_sparse(md.data)\n>>> data_sd = binvox_rw.sparse_to_dense(ms.data, 32)\n>>> np.all(data_sd == md.data)\nTrue\n>>> # the ordering of elements returned by numpy.nonzero changes with axis\n>>> # ordering, so to compare for equality we first lexically sort the voxels.\n>>> np.all(ms.data[:, np.lexsort(ms.data)] == data_ds[:, np.lexsort(data_ds)])\nTrue\n\"\"\"\n\nimport numpy as np\n\n\nclass Voxels(object):\n    \"\"\" Holds a binvox model.\n    data is either a three-dimensional numpy boolean array (dense representation)\n    or a two-dimensional numpy float array (coordinate representation).\n\n    dims, translate and scale are the model metadata.\n\n    dims are the voxel dimensions, e.g. [32, 32, 32] for a 32x32x32 model.\n\n    scale and translate relate the voxels to the original model coordinates.\n\n    To translate voxel coordinates i, j, k to original coordinates x, y, z:\n\n    x_n = (i+.5)/dims[0]\n    y_n = (j+.5)/dims[1]\n    z_n = (k+.5)/dims[2]\n    x = scale*x_n + translate[0]\n    y = scale*y_n + translate[1]\n    z = scale*z_n + translate[2]\n\n    \"\"\"\n    def __init__(self, data, dims, translate, scale, axis_order):\n        self.data = data\n        self.dims = dims\n        self.translate = translate\n        self.scale = scale\n        assert (axis_order in ('xzy', 'xyz'))\n        self.axis_order = axis_order\n\n    def clone(self):\n        data = self.data.copy()\n        dims = self.dims[:]\n        translate = self.translate[:]\n        return Voxels(data, dims, translate, self.scale, self.axis_order)\n\n    def write(self, fp):\n        write(self, fp)\n\n\ndef read_header(fp):\n    \"\"\" Read binvox header. 
Mostly meant for internal use.\n    \"\"\"\n    line = fp.readline().strip()\n    if not line.startswith(b'#binvox'):\n        raise IOError('[ERROR] Not a binvox file')\n    dims = list(map(int, fp.readline().strip().split(b' ')[1:]))\n    translate = list(map(float, fp.readline().strip().split(b' ')[1:]))\n    scale = list(map(float, fp.readline().strip().split(b' ')[1:]))[0]\n    fp.readline()\n    return dims, translate, scale\n\n\ndef read_as_3d_array(fp, fix_coords=True):\n    \"\"\" Read binary binvox format as array.\n\n    Returns the model with accompanying metadata.\n\n    Voxels are stored in a three-dimensional numpy array, which is simple and\n    direct, but may use a lot of memory for large models. (Storage requirements\n    are 8*(d^3) bytes, where d is the dimensions of the binvox model. Numpy\n    boolean arrays use a byte per element).\n\n    Doesn't do any checks on input except for the '#binvox' line.\n    \"\"\"\n    dims, translate, scale = read_header(fp)\n    raw_data = np.frombuffer(fp.read(), dtype=np.uint8)\n    # if just using reshape() on the raw data:\n    # indexing the array as array[i,j,k], the indices map into the\n    # coords as:\n    # i -> x\n    # j -> z\n    # k -> y\n    # if fix_coords is true, then data is rearranged so that\n    # mapping is\n    # i -> x\n    # j -> y\n    # k -> z\n    values, counts = raw_data[::2], raw_data[1::2]\n    data = np.repeat(values, counts).astype(np.int32)\n    data = data.reshape(dims)\n    if fix_coords:\n        # xzy to xyz TODO the right thing\n        data = np.transpose(data, (0, 2, 1))\n        axis_order = 'xyz'\n    else:\n        axis_order = 'xzy'\n    return Voxels(data, dims, translate, scale, axis_order)\n\n\ndef read_as_coord_array(fp, fix_coords=True):\n    \"\"\" Read binary binvox format as coordinates.\n\n    Returns binvox model with voxels in a \"coordinate\" representation, i.e.  an\n    3 x N array where N is the number of nonzero voxels. 
Each column\n    corresponds to a nonzero voxel and the 3 rows are the (x, z, y) coordinates\n    of the voxel.  (The odd ordering is due to the way binvox format lays out\n    data).  Note that coordinates refer to the binvox voxels, without any\n    scaling or translation.\n\n    Use this to save memory if your model is very sparse (mostly empty).\n\n    Doesn't do any checks on input except for the '#binvox' line.\n    \"\"\"\n    dims, translate, scale = read_header(fp)\n    raw_data = np.frombuffer(fp.read(), dtype=np.uint8)\n\n    values, counts = raw_data[::2], raw_data[1::2]\n\n    # sz = np.prod(dims)\n    # index, end_index = 0, 0\n    end_indices = np.cumsum(counts)\n    indices = np.concatenate(([0], end_indices[:-1])).astype(end_indices.dtype)\n\n    values = values.astype(bool)\n    indices = indices[values]\n    end_indices = end_indices[values]\n\n    nz_voxels = []\n    for index, end_index in zip(indices, end_indices):\n        nz_voxels.extend(range(index, end_index))\n    nz_voxels = np.array(nz_voxels)\n    # TODO are these dims correct?\n    # according to docs,\n    # index = x * wxh + z * width + y; // wxh = width * height = d * d\n\n    # use integer division: the '/' operator yields floats in Python 3\n    x = nz_voxels // (dims[0] * dims[1])\n    zwpy = nz_voxels % (dims[0] * dims[1])    # z*w + y\n    z = zwpy // dims[0]\n    y = zwpy % dims[0]\n    if fix_coords:\n        data = np.vstack((x, y, z))\n        axis_order = 'xyz'\n    else:\n        data = np.vstack((x, z, y))\n        axis_order = 'xzy'\n\n    #return Voxels(data, dims, translate, scale, axis_order)\n    return Voxels(np.ascontiguousarray(data), dims, translate, scale, axis_order)\n\n\ndef dense_to_sparse(voxel_data, dtype=int):\n    \"\"\" From dense representation to sparse (coordinate) representation.\n    No coordinate reordering.\n    \"\"\"\n    if voxel_data.ndim != 3:\n        raise ValueError('[ERROR] voxel_data is wrong shape; should be 3D array.')\n    return np.asarray(np.nonzero(voxel_data), dtype)\n\n\ndef sparse_to_dense(voxel_data, dims, dtype=bool):\n    if voxel_data.ndim != 2 or voxel_data.shape[0] != 3:\n        raise ValueError('[ERROR] voxel_data is wrong shape; should be 3xN array.')\n    if np.isscalar(dims):\n        dims = [dims] * 3\n    dims = np.atleast_2d(dims).T\n    # truncate to integers\n    xyz = voxel_data.astype(int)\n    # discard voxels that fall outside dims\n    valid_ix = ~np.any((xyz < 0) | (xyz >= dims), 0)\n    xyz = xyz[:, valid_ix]\n    out = np.zeros(dims.flatten(), dtype=dtype)\n    out[tuple(xyz)] = True\n    return out\n\n\n#def get_linear_index(x, y, z, dims):\n#\"\"\" Assuming xzy order. (y increasing fastest.\n#TODO ensure this is right when dims are not all same\n#\"\"\"\n#return x*(dims[1]*dims[2]) + z*dims[1] + y\n\n\ndef write(voxel_model, fp):\n    \"\"\" Write binary binvox format.\n\n    Note that when saving a model in sparse (coordinate) format, it is first\n    converted to dense format.\n\n    Doesn't check if the model is 'sane'.\n\n    \"\"\"\n    if voxel_model.data.ndim == 2:\n        # TODO avoid conversion to dense\n        dense_voxel_data = sparse_to_dense(voxel_model.data, voxel_model.dims).astype(int)\n    else:\n        dense_voxel_data = voxel_model.data.astype(int)\n\n    file_header = [\n        '#binvox 1\\n',\n        'dim %s\\n' % ' '.join(map(str, voxel_model.dims)),\n        'translate %s\\n' % ' '.join(map(str, voxel_model.translate)),\n        'scale %s\\n' % str(voxel_model.scale), 'data\\n'\n    ]\n\n    for fh in file_header:\n        fp.write(fh.encode('latin-1'))\n\n    if voxel_model.axis_order not in ('xzy', 'xyz'):\n        raise ValueError('[ERROR] Unsupported voxel model axis order')\n\n    if voxel_model.axis_order == 'xzy':\n        voxels_flat = dense_voxel_data.flatten()\n    elif voxel_model.axis_order == 'xyz':\n        voxels_flat = np.transpose(dense_voxel_data, (0, 2, 1)).flatten()\n\n    # keep a sort of state machine for writing run length encoding\n    state = 
voxels_flat[0]\n    ctr = 0\n    for c in voxels_flat:\n        if c == state:\n            ctr += 1\n            # if ctr hits max, dump\n            if ctr == 255:\n                fp.write(chr(state).encode('latin-1'))\n                fp.write(chr(ctr).encode('latin-1'))\n                ctr = 0\n        else:\n            # if switch state, dump\n            fp.write(chr(state).encode('latin-1'))\n            fp.write(chr(ctr).encode('latin-1'))\n            state = c\n            ctr = 1\n    # flush out remainders\n    if ctr > 0:\n        fp.write(chr(state).encode('latin-1'))\n        fp.write(chr(ctr).encode('latin-1'))\n\n\nif __name__ == '__main__':\n    import doctest\n    doctest.testmod()\n"
  },
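The `data` section of a binvox file is run-length encoded as (value, run-length) byte pairs with runs capped at 255, which is exactly the state machine `write()` implements. A standalone pure-Python sketch of the scheme (toy data, not part of the module), with a round-trip check:

```python
def rle_encode(flat):
    # binvox-style run-length encoding: (value, count) pairs with runs
    # capped at 255, mirroring the state machine in binvox_rw.write()
    pairs = []
    state, ctr = flat[0], 0
    for c in flat:
        if c == state:
            ctr += 1
            if ctr == 255:              # runs longer than 255 are split, as in the format
                pairs.append((state, ctr))
                ctr = 0
        else:
            pairs.append((state, ctr))
            state, ctr = c, 1
    if ctr > 0:
        pairs.append((state, ctr))
    return pairs


def rle_decode(pairs):
    out = []
    for value, count in pairs:
        out.extend([value] * count)
    return out


voxels = [0] * 300 + [1] * 5 + [0] * 2    # a toy flattened voxel array
assert rle_decode(rle_encode(voxels)) == voxels
```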
  {
    "path": "utils/binvox_visualization.py",
    "content": "# -*- coding: utf-8 -*-\n#\n# Developed by Haozhe Xie <cshzxie@gmail.com>\n\nimport cv2\nimport matplotlib.pyplot as plt\nimport os\n\nfrom mpl_toolkits.mplot3d import Axes3D    # noqa: F401 - registers the '3d' projection\n\n\ndef get_volume_views(volume, save_dir, n_itr):\n    if not os.path.exists(save_dir):\n        os.makedirs(save_dir)\n\n    volume = volume.squeeze() >= 0.5\n    fig = plt.figure()\n    # fig.gca(projection=...) was removed in newer Matplotlib releases\n    ax = fig.add_subplot(projection='3d')\n    ax.set_aspect('equal')\n    ax.voxels(volume, edgecolor=\"k\")\n\n    save_path = os.path.join(save_dir, 'voxels-%06d.png' % n_itr)\n    plt.savefig(save_path, bbox_inches='tight')\n    plt.close()\n    return cv2.imread(save_path)\n"
  },
  {
    "path": "utils/data_loaders.py",
    "content": "# -*- coding: utf-8 -*-\n#\n# Developed by Haozhe Xie <cshzxie@gmail.com>\n\nimport cv2\nimport json\nimport numpy as np\nimport os\nimport random\nimport scipy.io\nimport scipy.ndimage\nimport sys\nimport torch.utils.data.dataset\n\nfrom datetime import datetime as dt\nfrom enum import Enum, unique\n\nimport utils.binvox_rw\n\n\n@unique\nclass DatasetType(Enum):\n    TRAIN = 0\n    TEST = 1\n    VAL = 2\n\n\n# //////////////////////////////// = End of DatasetType Class Definition = ///////////////////////////////// #\n\n\nclass ShapeNetDataset(torch.utils.data.dataset.Dataset):\n    \"\"\"ShapeNetDataset class used for PyTorch DataLoader\"\"\"\n    def __init__(self, dataset_type, file_list, n_views_rendering, transforms=None):\n        self.dataset_type = dataset_type\n        self.file_list = file_list\n        self.transforms = transforms\n        self.n_views_rendering = n_views_rendering\n\n    def __len__(self):\n        return len(self.file_list)\n\n    def __getitem__(self, idx):\n        taxonomy_name, sample_name, rendering_images, volume = self.get_datum(idx)\n\n        if self.transforms:\n            rendering_images = self.transforms(rendering_images)\n\n        return taxonomy_name, sample_name, rendering_images, volume\n\n    def set_n_views_rendering(self, n_views_rendering):\n        self.n_views_rendering = n_views_rendering\n\n    def get_datum(self, idx):\n        taxonomy_name = self.file_list[idx]['taxonomy_name']\n        sample_name = self.file_list[idx]['sample_name']\n        rendering_image_paths = self.file_list[idx]['rendering_images']\n        volume_path = self.file_list[idx]['volume']\n\n        # Get data of rendering images\n        if self.dataset_type == DatasetType.TRAIN:\n            selected_rendering_image_paths = [\n                rendering_image_paths[i]\n                for i in random.sample(range(len(rendering_image_paths)), self.n_views_rendering)\n            ]\n        else:\n            
selected_rendering_image_paths = [rendering_image_paths[i] for i in range(self.n_views_rendering)]\n\n        rendering_images = []\n        for image_path in selected_rendering_image_paths:\n            rendering_image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.\n            if len(rendering_image.shape) < 3:\n                print('[FATAL] %s It seems that there is something wrong with the image file %s' %\n                      (dt.now(), image_path))\n                sys.exit(2)\n\n            rendering_images.append(rendering_image)\n\n        # Get data of volume\n        _, suffix = os.path.splitext(volume_path)\n\n        if suffix == '.mat':\n            volume = scipy.io.loadmat(volume_path)\n            volume = volume['Volume'].astype(np.float32)\n        elif suffix == '.binvox':\n            with open(volume_path, 'rb') as f:\n                volume = utils.binvox_rw.read_as_3d_array(f)\n                volume = volume.data.astype(np.float32)\n\n        return taxonomy_name, sample_name, np.asarray(rendering_images), volume\n\n\n# //////////////////////////////// = End of ShapeNetDataset Class Definition = ///////////////////////////////// #\n\n\nclass ShapeNetDataLoader:\n    def __init__(self, cfg):\n        self.dataset_taxonomy = None\n        self.rendering_image_path_template = cfg.DATASETS.SHAPENET.RENDERING_PATH\n        self.volume_path_template = cfg.DATASETS.SHAPENET.VOXEL_PATH\n\n        # Load all taxonomies of the dataset\n        with open(cfg.DATASETS.SHAPENET.TAXONOMY_FILE_PATH, encoding='utf-8') as file:\n            self.dataset_taxonomy = json.loads(file.read())\n\n    def get_dataset(self, dataset_type, n_views_rendering, transforms=None):\n        files = []\n\n        # Load data for each category\n        for taxonomy in self.dataset_taxonomy:\n            taxonomy_folder_name = taxonomy['taxonomy_id']\n            print('[INFO] %s Collecting files of Taxonomy[ID=%s, Name=%s]' %\n                  
(dt.now(), taxonomy['taxonomy_id'], taxonomy['taxonomy_name']))\n            samples = []\n            if dataset_type == DatasetType.TRAIN:\n                samples = taxonomy['train']\n            elif dataset_type == DatasetType.TEST:\n                samples = taxonomy['test']\n            elif dataset_type == DatasetType.VAL:\n                samples = taxonomy['val']\n\n            files.extend(self.get_files_of_taxonomy(taxonomy_folder_name, samples))\n\n        print('[INFO] %s Finished collecting files of the dataset. Total files: %d.' % (dt.now(), len(files)))\n        return ShapeNetDataset(dataset_type, files, n_views_rendering, transforms)\n\n    def get_files_of_taxonomy(self, taxonomy_folder_name, samples):\n        files_of_taxonomy = []\n\n        for sample_idx, sample_name in enumerate(samples):\n            # Get file path of volumes\n            volume_file_path = self.volume_path_template % (taxonomy_folder_name, sample_name)\n            if not os.path.exists(volume_file_path):\n                print('[WARN] %s Ignoring sample %s/%s since the volume file does not exist.' 
%\n                      (dt.now(), taxonomy_folder_name, sample_name))\n                continue\n\n            # Get file list of rendering images\n            img_file_path = self.rendering_image_path_template % (taxonomy_folder_name, sample_name, 0)\n            img_folder = os.path.dirname(img_file_path)\n            total_views = len(os.listdir(img_folder))\n            rendering_image_indexes = range(total_views)\n            rendering_images_file_path = []\n            for image_idx in rendering_image_indexes:\n                img_file_path = self.rendering_image_path_template % (taxonomy_folder_name, sample_name, image_idx)\n                if not os.path.exists(img_file_path):\n                    continue\n\n                rendering_images_file_path.append(img_file_path)\n\n            if len(rendering_images_file_path) == 0:\n                print('[WARN] %s Ignoring sample %s/%s since no image files exist.' %\n                      (dt.now(), taxonomy_folder_name, sample_name))\n                continue\n\n            # Append to the list of rendering images\n            files_of_taxonomy.append({\n                'taxonomy_name': taxonomy_folder_name,\n                'sample_name': sample_name,\n                'rendering_images': rendering_images_file_path,\n                'volume': volume_file_path,\n            })\n\n            # Report the progress of reading dataset\n            # if sample_idx % 500 == 499 or sample_idx == n_samples - 1:\n            #     print('[INFO] %s Collecting %d of %d' % (dt.now(), sample_idx + 1, n_samples))\n\n        return files_of_taxonomy\n\n\n# /////////////////////////////// = End of ShapeNetDataLoader Class Definition = /////////////////////////////// #\n\n\nclass Pascal3dDataset(torch.utils.data.dataset.Dataset):\n    \"\"\"Pascal3D class used for PyTorch DataLoader\"\"\"\n    def __init__(self, file_list, transforms=None):\n        self.file_list = file_list\n        self.transforms = transforms\n\n    def 
__len__(self):\n        return len(self.file_list)\n\n    def __getitem__(self, idx):\n        taxonomy_name, sample_name, rendering_images, volume, bounding_box = self.get_datum(idx)\n\n        if self.transforms:\n            rendering_images = self.transforms(rendering_images, bounding_box)\n\n        return taxonomy_name, sample_name, rendering_images, volume\n\n    def get_datum(self, idx):\n        taxonomy_name = self.file_list[idx]['taxonomy_name']\n        sample_name = self.file_list[idx]['sample_name']\n        rendering_image_path = self.file_list[idx]['rendering_image']\n        bounding_box = self.file_list[idx]['bounding_box']\n        volume_path = self.file_list[idx]['volume']\n\n        # Get data of rendering images\n        rendering_image = cv2.imread(rendering_image_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.\n\n        if len(rendering_image.shape) < 3:\n            print('[WARN] %s It seems the image file %s is grayscale.' % (dt.now(), rendering_image_path))\n            rendering_image = np.stack((rendering_image, ) * 3, -1)\n\n        # Get data of volume\n        with open(volume_path, 'rb') as f:\n            volume = utils.binvox_rw.read_as_3d_array(f)\n            volume = volume.data.astype(np.float32)\n\n        return taxonomy_name, sample_name, np.asarray([rendering_image]), volume, bounding_box\n\n\n# //////////////////////////////// = End of Pascal3dDataset Class Definition = ///////////////////////////////// #\n\n\nclass Pascal3dDataLoader:\n    def __init__(self, cfg):\n        self.dataset_taxonomy = None\n        self.volume_path_template = cfg.DATASETS.PASCAL3D.VOXEL_PATH\n        self.annotation_path_template = cfg.DATASETS.PASCAL3D.ANNOTATION_PATH\n        self.rendering_image_path_template = cfg.DATASETS.PASCAL3D.RENDERING_PATH\n\n        # Load all taxonomies of the dataset\n        with open(cfg.DATASETS.PASCAL3D.TAXONOMY_FILE_PATH, encoding='utf-8') as file:\n            self.dataset_taxonomy = 
json.loads(file.read())\n\n    def get_dataset(self, dataset_type, n_views_rendering, transforms=None):\n        files = []\n\n        # Load data for each category\n        for taxonomy in self.dataset_taxonomy:\n            taxonomy_name = taxonomy['taxonomy_name']\n            print('[INFO] %s Collecting files of Taxonomy[Name=%s]' % (dt.now(), taxonomy_name))\n\n            samples = []\n            if dataset_type == DatasetType.TRAIN:\n                samples = taxonomy['train']\n            elif dataset_type == DatasetType.TEST:\n                samples = taxonomy['test']\n            elif dataset_type == DatasetType.VAL:\n                samples = taxonomy['test']    # Pascal3D has no separate val split; reuse the test split\n\n            files.extend(self.get_files_of_taxonomy(taxonomy_name, samples))\n\n        print('[INFO] %s Finished collecting files of the dataset. Total files: %d.' % (dt.now(), len(files)))\n        return Pascal3dDataset(files, transforms)\n\n    def get_files_of_taxonomy(self, taxonomy_name, samples):\n        files_of_taxonomy = []\n\n        for sample_idx, sample_name in enumerate(samples):\n            # Get file list of rendering images\n            rendering_image_file_path = self.rendering_image_path_template % (taxonomy_name, sample_name)\n            # if not os.path.exists(rendering_image_file_path):\n            #     continue\n\n            # Get image annotations\n            annotations_file_path = self.annotation_path_template % (taxonomy_name, sample_name)\n            annotations_mat = scipy.io.loadmat(annotations_file_path, squeeze_me=True, struct_as_record=False)\n            img_width, img_height, _ = annotations_mat['record'].imgsize\n            annotations = annotations_mat['record'].objects\n\n            cad_index = -1\n            bbox = None\n            if isinstance(annotations, np.ndarray):\n                max_bbox_area = -1\n\n                for i in range(len(annotations)):\n                    _cad_index = annotations[i].cad_index\n                    _bbox = 
annotations[i].__dict__['bbox']\n\n                    bbox_xmin = _bbox[0]\n                    bbox_ymin = _bbox[1]\n                    bbox_xmax = _bbox[2]\n                    bbox_ymax = _bbox[3]\n                    _bbox_area = (bbox_xmax - bbox_xmin) * (bbox_ymax - bbox_ymin)\n\n                    if _bbox_area > max_bbox_area:\n                        bbox = _bbox\n                        cad_index = _cad_index\n                        max_bbox_area = _bbox_area\n            else:\n                cad_index = annotations.cad_index\n                bbox = annotations.bbox\n\n            # Convert the coordinates of bounding boxes to percentages\n            bbox = [bbox[0] / img_width, bbox[1] / img_height, bbox[2] / img_width, bbox[3] / img_height]\n            # Get file path of volumes\n            volume_file_path = self.volume_path_template % (taxonomy_name, cad_index)\n            if not os.path.exists(volume_file_path):\n                print('[WARN] %s Ignoring sample %s/%s since the volume file does not exist.' 
%\n                      (dt.now(), taxonomy_name, sample_name))\n                continue\n\n            # Append to the list of rendering images\n            files_of_taxonomy.append({\n                'taxonomy_name': taxonomy_name,\n                'sample_name': sample_name,\n                'rendering_image': rendering_image_file_path,\n                'bounding_box': bbox,\n                'volume': volume_file_path,\n            })\n\n        return files_of_taxonomy\n\n\n# /////////////////////////////// = End of Pascal3dDataLoader Class Definition = /////////////////////////////// #\n\n\nclass Pix3dDataset(torch.utils.data.dataset.Dataset):\n    \"\"\"Pix3D class used for PyTorch DataLoader\"\"\"\n    def __init__(self, file_list, transforms=None):\n        self.file_list = file_list\n        self.transforms = transforms\n\n    def __len__(self):\n        return len(self.file_list)\n\n    def __getitem__(self, idx):\n        taxonomy_name, sample_name, rendering_images, volume, bounding_box = self.get_datum(idx)\n\n        if self.transforms:\n            rendering_images = self.transforms(rendering_images, bounding_box)\n\n        return taxonomy_name, sample_name, rendering_images, volume\n\n    def get_datum(self, idx):\n        taxonomy_name = self.file_list[idx]['taxonomy_name']\n        sample_name = self.file_list[idx]['sample_name']\n        rendering_image_path = self.file_list[idx]['rendering_image']\n        bounding_box = self.file_list[idx]['bounding_box']\n        volume_path = self.file_list[idx]['volume']\n\n        # Get data of rendering images\n        rendering_image = cv2.imread(rendering_image_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.\n\n        if len(rendering_image.shape) < 3:\n            print('[WARN] %s It seems the image file %s is grayscale.' 
% (dt.now(), rendering_image_path))\n            rendering_image = np.stack((rendering_image, ) * 3, -1)\n\n        # Get data of volume\n        with open(volume_path, 'rb') as f:\n            volume = utils.binvox_rw.read_as_3d_array(f)\n            volume = volume.data.astype(np.float32)\n\n        return taxonomy_name, sample_name, np.asarray([rendering_image]), volume, bounding_box\n\n\n# //////////////////////////////// = End of Pix3dDataset Class Definition = ///////////////////////////////// #\n\n\nclass Pix3dDataLoader:\n    def __init__(self, cfg):\n        self.dataset_taxonomy = None\n        self.annotations = dict()\n        self.volume_path_template = cfg.DATASETS.PIX3D.VOXEL_PATH\n        self.rendering_image_path_template = cfg.DATASETS.PIX3D.RENDERING_PATH\n\n        # Load all taxonomies of the dataset\n        with open(cfg.DATASETS.PIX3D.TAXONOMY_FILE_PATH, encoding='utf-8') as file:\n            self.dataset_taxonomy = json.loads(file.read())\n\n        # Load all annotations of the dataset\n        _annotations = None\n        with open(cfg.DATASETS.PIX3D.ANNOTATION_PATH, encoding='utf-8') as file:\n            _annotations = json.loads(file.read())\n\n        for anno in _annotations:\n            filename, _ = os.path.splitext(anno['img'])\n            anno_key = filename[4:]    # strip the leading 'img/' from the annotated file path\n            self.annotations[anno_key] = anno\n\n    def get_dataset(self, dataset_type, n_views_rendering, transforms=None):\n        files = []\n\n        # Load data for each category\n        for taxonomy in self.dataset_taxonomy:\n            taxonomy_name = taxonomy['taxonomy_name']\n            print('[INFO] %s Collecting files of Taxonomy[Name=%s]' % (dt.now(), taxonomy_name))\n\n            samples = []\n            if dataset_type == DatasetType.TRAIN:\n                samples = taxonomy['train']\n            elif dataset_type == DatasetType.TEST:\n                samples = taxonomy['test']\n            elif dataset_type == DatasetType.VAL:\n              
  samples = taxonomy['test']    # Pix3D has no separate val split; reuse the test split\n\n            files.extend(self.get_files_of_taxonomy(taxonomy_name, samples))\n\n        print('[INFO] %s Finished collecting files of the dataset. Total files: %d.' % (dt.now(), len(files)))\n        return Pix3dDataset(files, transforms)\n\n    def get_files_of_taxonomy(self, taxonomy_name, samples):\n        files_of_taxonomy = []\n\n        for sample_idx, sample_name in enumerate(samples):\n            # Get image annotations\n            anno_key = '%s/%s' % (taxonomy_name, sample_name)\n            annotations = self.annotations[anno_key]\n\n            # Get file list of rendering images\n            _, img_file_suffix = os.path.splitext(annotations['img'])\n            rendering_image_file_path = self.rendering_image_path_template % (taxonomy_name, sample_name,\n                                                                              img_file_suffix[1:])\n\n            # Get the bounding box of the image\n            img_width, img_height = annotations['img_size']\n            bbox = [\n                annotations['bbox'][0] / img_width,\n                annotations['bbox'][1] / img_height,\n                annotations['bbox'][2] / img_width,\n                annotations['bbox'][3] / img_height\n            ]  # yapf: disable\n            model_name_parts = annotations['voxel'].split('/')\n            model_name = model_name_parts[2]\n            volume_file_name = model_name_parts[3][:-4].replace('voxel', 'model')\n\n            # Get file path of volumes\n            volume_file_path = self.volume_path_template % (taxonomy_name, model_name, volume_file_name)\n            if not os.path.exists(volume_file_path):\n                print('[WARN] %s Ignoring sample %s/%s since the volume file does not exist.' 
%\n                      (dt.now(), taxonomy_name, sample_name))\n                continue\n\n            # Append to the list of rendering images\n            files_of_taxonomy.append({\n                'taxonomy_name': taxonomy_name,\n                'sample_name': sample_name,\n                'rendering_image': rendering_image_file_path,\n                'bounding_box': bbox,\n                'volume': volume_file_path,\n            })\n\n        return files_of_taxonomy\n\n\n# /////////////////////////////// = End of Pix3dDataLoader Class Definition = /////////////////////////////// #\n\nDATASET_LOADER_MAPPING = {\n    'ShapeNet': ShapeNetDataLoader,\n    'Pascal3D': Pascal3dDataLoader,\n    'Pix3D': Pix3dDataLoader\n}  # yapf: disable\n"
  },
  {
    "path": "utils/data_transforms.py",
    "content": "# -*- coding: utf-8 -*-\n#\n# Developed by Haozhe Xie <cshzxie@gmail.com>\n# References:\n# - https://github.com/xiumingzhang/GenRe-ShapeHD\n\nimport cv2\n# import matplotlib.pyplot as plt\n# import matplotlib.patches as patches\nimport numpy as np\nimport os\nimport random\nimport torch\n\n\nclass Compose(object):\n    \"\"\" Composes several transforms together.\n    For example:\n    >>> transforms.Compose([\n    >>>     transforms.RandomBackground(),\n    >>>     transforms.CenterCrop(127, 127, 3),\n    >>>  ])\n    \"\"\"\n    def __init__(self, transforms):\n        self.transforms = transforms\n\n    def __call__(self, rendering_images, bounding_box=None):\n        for t in self.transforms:\n            if t.__class__.__name__ == 'RandomCrop' or t.__class__.__name__ == 'CenterCrop':\n                rendering_images = t(rendering_images, bounding_box)\n            else:\n                rendering_images = t(rendering_images)\n\n        return rendering_images\n\n\nclass ToTensor(object):\n    \"\"\"\n    Convert a PIL Image or numpy.ndarray to tensor.\n    Converts a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].\n    \"\"\"\n    def __call__(self, rendering_images):\n        assert (isinstance(rendering_images, np.ndarray))\n        array = np.transpose(rendering_images, (0, 3, 1, 2))\n        # handle numpy array\n        tensor = torch.from_numpy(array)\n\n        # put it from HWC to CHW format\n        return tensor.float()\n\n\nclass Normalize(object):\n    def __init__(self, mean, std):\n        self.mean = mean\n        self.std = std\n\n    def __call__(self, rendering_images):\n        assert (isinstance(rendering_images, np.ndarray))\n        rendering_images -= self.mean\n        rendering_images /= self.std\n\n        return rendering_images\n\n\nclass RandomPermuteRGB(object):\n    def __call__(self, rendering_images):\n        assert 
(isinstance(rendering_images, np.ndarray))\n\n        random_permutation = np.random.permutation(3)\n        for img_idx, img in enumerate(rendering_images):\n            rendering_images[img_idx] = img[..., random_permutation]\n\n        return rendering_images\n\n\nclass CenterCrop(object):\n    def __init__(self, img_size, crop_size):\n        \"\"\"Set the height and width before and after cropping\"\"\"\n        self.img_size_h = img_size[0]\n        self.img_size_w = img_size[1]\n        self.crop_size_h = crop_size[0]\n        self.crop_size_w = crop_size[1]\n\n    def __call__(self, rendering_images, bounding_box=None):\n        if len(rendering_images) == 0:\n            return rendering_images\n\n        crop_size_c = rendering_images[0].shape[2]\n        processed_images = np.empty(shape=(0, self.img_size_h, self.img_size_w, crop_size_c))\n        for img_idx, img in enumerate(rendering_images):\n            img_height, img_width, _ = img.shape\n\n            if bounding_box is not None:\n                bounding_box = [\n                    bounding_box[0] * img_width,\n                    bounding_box[1] * img_height,\n                    bounding_box[2] * img_width,\n                    bounding_box[3] * img_height\n                ]  # yapf: disable\n\n                # Calculate the size of bounding boxes\n                bbox_width = bounding_box[2] - bounding_box[0]\n                bbox_height = bounding_box[3] - bounding_box[1]\n                bbox_x_mid = (bounding_box[2] + bounding_box[0]) * .5\n                bbox_y_mid = (bounding_box[3] + bounding_box[1]) * .5\n\n                # Make the crop area as a square\n                square_object_size = max(bbox_width, bbox_height)\n                x_left = int(bbox_x_mid - square_object_size * .5)\n                x_right = int(bbox_x_mid + square_object_size * .5)\n                y_top = int(bbox_y_mid - square_object_size * .5)\n                y_bottom = int(bbox_y_mid + 
square_object_size * .5)\n\n                # If the crop position is out of the image, fix it with padding\n                pad_x_left = 0\n                if x_left < 0:\n                    pad_x_left = -x_left\n                    x_left = 0\n                pad_x_right = 0\n                if x_right >= img_width:\n                    pad_x_right = x_right - img_width + 1\n                    x_right = img_width - 1\n                pad_y_top = 0\n                if y_top < 0:\n                    pad_y_top = -y_top\n                    y_top = 0\n                pad_y_bottom = 0\n                if y_bottom >= img_height:\n                    pad_y_bottom = y_bottom - img_height + 1\n                    y_bottom = img_height - 1\n\n                # Padding the image and resize the image\n                processed_image = np.pad(img[y_top:y_bottom + 1, x_left:x_right + 1],\n                                         ((pad_y_top, pad_y_bottom), (pad_x_left, pad_x_right), (0, 0)),\n                                         mode='edge')\n                processed_image = cv2.resize(processed_image, (self.img_size_w, self.img_size_h))\n            else:\n                if img_height > self.crop_size_h and img_width > self.crop_size_w:\n                    x_left = int(img_width - self.crop_size_w) // 2\n                    x_right = int(x_left + self.crop_size_w)\n                    y_top = int(img_height - self.crop_size_h) // 2\n                    y_bottom = int(y_top + self.crop_size_h)\n                else:\n                    x_left = 0\n                    x_right = img_width\n                    y_top = 0\n                    y_bottom = img_height\n\n                processed_image = cv2.resize(img[y_top:y_bottom, x_left:x_right], (self.img_size_w, self.img_size_h))\n\n            processed_images = np.append(processed_images, [processed_image], axis=0)\n            # Debug\n            # fig = plt.figure()\n            # ax1 = fig.add_subplot(1, 2, 1)\n 
           # ax1.imshow(img)\n            # if not bounding_box is None:\n            #     rect = patches.Rectangle((bounding_box[0], bounding_box[1]),\n            #                              bbox_width,\n            #                              bbox_height,\n            #                              linewidth=1,\n            #                              edgecolor='r',\n            #                              facecolor='none')\n            #     ax1.add_patch(rect)\n            # ax2 = fig.add_subplot(1, 2, 2)\n            # ax2.imshow(processed_image)\n            # plt.show()\n        return processed_images\n\n\nclass RandomCrop(object):\n    def __init__(self, img_size, crop_size):\n        \"\"\"Set the height and width before and after cropping\"\"\"\n        self.img_size_h = img_size[0]\n        self.img_size_w = img_size[1]\n        self.crop_size_h = crop_size[0]\n        self.crop_size_w = crop_size[1]\n\n    def __call__(self, rendering_images, bounding_box=None):\n        if len(rendering_images) == 0:\n            return rendering_images\n\n        crop_size_c = rendering_images[0].shape[2]\n        processed_images = np.empty(shape=(0, self.img_size_h, self.img_size_w, crop_size_c))\n        for img_idx, img in enumerate(rendering_images):\n            img_height, img_width, _ = img.shape\n\n            if bounding_box is not None:\n                bounding_box = [\n                    bounding_box[0] * img_width,\n                    bounding_box[1] * img_height,\n                    bounding_box[2] * img_width,\n                    bounding_box[3] * img_height\n                ]  # yapf: disable\n\n                # Calculate the size of bounding boxes\n                bbox_width = bounding_box[2] - bounding_box[0]\n                bbox_height = bounding_box[3] - bounding_box[1]\n                bbox_x_mid = (bounding_box[2] + bounding_box[0]) * .5\n                bbox_y_mid = (bounding_box[3] + bounding_box[1]) * .5\n\n              
  # Make the crop area as a square\n                square_object_size = max(bbox_width, bbox_height)\n                square_object_size = square_object_size * random.uniform(0.8, 1.2)\n\n                x_left = int(bbox_x_mid - square_object_size * random.uniform(.4, .6))\n                x_right = int(bbox_x_mid + square_object_size * random.uniform(.4, .6))\n                y_top = int(bbox_y_mid - square_object_size * random.uniform(.4, .6))\n                y_bottom = int(bbox_y_mid + square_object_size * random.uniform(.4, .6))\n\n                # If the crop position is out of the image, fix it with padding\n                pad_x_left = 0\n                if x_left < 0:\n                    pad_x_left = -x_left\n                    x_left = 0\n                pad_x_right = 0\n                if x_right >= img_width:\n                    pad_x_right = x_right - img_width + 1\n                    x_right = img_width - 1\n                pad_y_top = 0\n                if y_top < 0:\n                    pad_y_top = -y_top\n                    y_top = 0\n                pad_y_bottom = 0\n                if y_bottom >= img_height:\n                    pad_y_bottom = y_bottom - img_height + 1\n                    y_bottom = img_height - 1\n\n                # Padding the image and resize the image\n                processed_image = np.pad(img[y_top:y_bottom + 1, x_left:x_right + 1],\n                                         ((pad_y_top, pad_y_bottom), (pad_x_left, pad_x_right), (0, 0)),\n                                         mode='edge')\n                processed_image = cv2.resize(processed_image, (self.img_size_w, self.img_size_h))\n            else:\n                if img_height > self.crop_size_h and img_width > self.crop_size_w:\n                    x_left = int(img_width - self.crop_size_w) // 2\n                    x_right = int(x_left + self.crop_size_w)\n                    y_top = int(img_height - self.crop_size_h) // 2\n                    
y_bottom = int(y_top + self.crop_size_h)\n                else:\n                    x_left = 0\n                    x_right = img_width\n                    y_top = 0\n                    y_bottom = img_height\n\n                processed_image = cv2.resize(img[y_top:y_bottom, x_left:x_right], (self.img_size_w, self.img_size_h))\n\n            processed_images = np.append(processed_images, [processed_image], axis=0)\n\n        return processed_images\n\n\nclass RandomFlip(object):\n    def __call__(self, rendering_images):\n        assert (isinstance(rendering_images, np.ndarray))\n\n        for img_idx, img in enumerate(rendering_images):\n            if random.randint(0, 1):\n                rendering_images[img_idx] = np.fliplr(img)\n\n        return rendering_images\n\n\nclass ColorJitter(object):\n    def __init__(self, brightness, contrast, saturation):\n        self.brightness = brightness\n        self.contrast = contrast\n        self.saturation = saturation\n\n    def __call__(self, rendering_images):\n        if len(rendering_images) == 0:\n            return rendering_images\n\n        # Allocate new space for storing processed images\n        img_height, img_width, img_channels = rendering_images[0].shape\n        processed_images = np.empty(shape=(0, img_height, img_width, img_channels))\n\n        # Randomize the value of changing brightness, contrast, and saturation\n        brightness = 1 + np.random.uniform(low=-self.brightness, high=self.brightness)\n        contrast = 1 + np.random.uniform(low=-self.contrast, high=self.contrast)\n        saturation = 1 + np.random.uniform(low=-self.saturation, high=self.saturation)\n\n        # Randomize the order of changing brightness, contrast, and saturation\n        attr_names = ['brightness', 'contrast', 'saturation']\n        attr_values = [brightness, contrast, saturation]    # The value of changing attrs\n        attr_indexes = np.array(range(len(attr_names)))    # The order of changing attrs\n        
np.random.shuffle(attr_indexes)\n\n        for img_idx, img in enumerate(rendering_images):\n            processed_image = img\n            for idx in attr_indexes:\n                processed_image = self._adjust_image_attr(processed_image, attr_names[idx], attr_values[idx])\n\n            processed_images = np.append(processed_images, [processed_image], axis=0)\n            # print('ColorJitter', np.mean(ori_img), np.mean(processed_image))\n            # fig = plt.figure(figsize=(8, 4))\n            # ax1 = fig.add_subplot(1, 2, 1)\n            # ax1.imshow(ori_img)\n            # ax2 = fig.add_subplot(1, 2, 2)\n            # ax2.imshow(processed_image)\n            # plt.show()\n        return processed_images\n\n    def _adjust_image_attr(self, img, attr_name, attr_value):\n        \"\"\"\n        Adjust or randomize the specified attribute of the image\n\n        Args:\n            img: Image in BGR format\n                Numpy array of shape (h, w, 3)\n            attr_name: Image attribute to adjust or randomize\n                       'brightness', 'saturation', or 'contrast'\n            attr_value: the alpha for blending is randomly drawn from [1 - d, 1 + d]\n\n        Returns:\n            Output image in BGR format\n            Numpy array of the same shape as input\n        \"\"\"\n        gs = self._bgr_to_gray(img)\n\n        if attr_name == 'contrast':\n            img = self._alpha_blend(img, np.mean(gs[:, :, 0]), attr_value)\n        elif attr_name == 'saturation':\n            img = self._alpha_blend(img, gs, attr_value)\n        elif attr_name == 'brightness':\n            img = self._alpha_blend(img, 0, attr_value)\n        else:\n            raise NotImplementedError(attr_name)\n        return img\n\n    def _bgr_to_gray(self, bgr):\n        \"\"\"\n        Convert a BGR image to a grayscale image\n            Differences from cv2.cvtColor():\n                1. Input image can be float\n                2. 
Output image has three repeated channels, rather than a single channel\n\n        Args:\n            bgr: Image in BGR format\n                 Numpy array of shape (h, w, 3)\n\n        Returns:\n            gs: Grayscale image\n                Numpy array of the same shape as input; the three channels are the same\n        \"\"\"\n        ch = 0.114 * bgr[:, :, 0] + 0.587 * bgr[:, :, 1] + 0.299 * bgr[:, :, 2]\n        gs = np.dstack((ch, ch, ch))\n        return gs\n\n    def _alpha_blend(self, im1, im2, alpha):\n        \"\"\"\n        Alpha blending of two images or one image and a scalar\n\n        Args:\n            im1, im2: Image or scalar\n                Numpy array and a scalar or two numpy arrays of the same shape\n            alpha: Weight of im1\n                Float, usually ranging from 0 to 1\n\n        Returns:\n            im_blend: Blended image -- alpha * im1 + (1 - alpha) * im2\n                Numpy array of the same shape as input image\n        \"\"\"\n        im_blend = alpha * im1 + (1 - alpha) * im2\n        return im_blend\n\n\nclass RandomNoise(object):\n    def __init__(self,\n                 noise_std,\n                 eigvals=(0.2175, 0.0188, 0.0045),\n                 eigvecs=((-0.5675, 0.7192, 0.4009), (-0.5808, -0.0045, -0.8140), (-0.5836, -0.6948, 0.4203))):\n        self.noise_std = noise_std\n        self.eigvals = np.array(eigvals)\n        self.eigvecs = np.array(eigvecs)\n\n    def __call__(self, rendering_images):\n        alpha = np.random.normal(loc=0, scale=self.noise_std, size=3)\n        noise_rgb = \\\n            np.sum(\n                np.multiply(\n                    np.multiply(\n                        self.eigvecs,\n                        np.tile(alpha, (3, 1))\n                    ),\n                    np.tile(self.eigvals, (3, 1))\n                ),\n                axis=1\n            )\n\n        # Allocate new space for storing processed images\n        img_height, img_width, img_channels = 
rendering_images[0].shape\n        assert (img_channels == 3), \"Please use RandomBackground to normalize image channels\"\n        processed_images = np.empty(shape=(0, img_height, img_width, img_channels))\n\n        for img_idx, img in enumerate(rendering_images):\n            processed_image = img[:, :, ::-1].copy()    # BGR -> RGB; copy to avoid mutating the input images in place\n            for i in range(img_channels):\n                processed_image[:, :, i] += noise_rgb[i]\n\n            processed_image = processed_image[:, :, ::-1]    # RGB -> BGR\n            processed_images = np.append(processed_images, [processed_image], axis=0)\n            # from copy import deepcopy\n            # ori_img = deepcopy(img)\n            # print(noise_rgb, np.mean(processed_image), np.mean(ori_img))\n            # print('RandomNoise', np.mean(ori_img), np.mean(processed_image))\n            # fig = plt.figure(figsize=(8, 4))\n            # ax1 = fig.add_subplot(1, 2, 1)\n            # ax1.imshow(ori_img)\n            # ax2 = fig.add_subplot(1, 2, 2)\n            # ax2.imshow(processed_image)\n            # plt.show()\n        return processed_images\n\n\nclass RandomBackground(object):\n    def __init__(self, random_bg_color_range, random_bg_folder_path=None):\n        self.random_bg_color_range = random_bg_color_range\n        self.random_bg_files = []\n        if random_bg_folder_path is not None:\n            self.random_bg_files = os.listdir(random_bg_folder_path)\n            self.random_bg_files = [os.path.join(random_bg_folder_path, rbf) for rbf in self.random_bg_files]\n\n    def __call__(self, rendering_images):\n        if len(rendering_images) == 0:\n            return rendering_images\n\n        img_height, img_width, img_channels = rendering_images[0].shape\n        # A background can only be added when the image has an alpha channel\n        if not img_channels == 4:\n            return rendering_images\n\n        # Generate random background\n        r, g, b = np.array([\n            
np.random.randint(self.random_bg_color_range[i][0], self.random_bg_color_range[i][1] + 1) for i in range(3)\n        ]) / 255.\n\n        random_bg = None\n        if len(self.random_bg_files) > 0:\n            random_bg_file_path = random.choice(self.random_bg_files)\n            random_bg = cv2.imread(random_bg_file_path).astype(np.float32) / 255.\n            # Make sure the background image matches the size of the renderings\n            random_bg = cv2.resize(random_bg, (img_width, img_height))\n\n        # Apply a random background to the transparent pixels\n        processed_images = np.empty(shape=(0, img_height, img_width, img_channels - 1))\n        for img_idx, img in enumerate(rendering_images):\n            # Mask of fully transparent (background) pixels\n            alpha = (np.expand_dims(img[:, :, 3], axis=2) == 0).astype(np.float32)\n            img = img[:, :, :3]\n            bg_color = random_bg if random.randint(0, 1) and random_bg is not None else np.array([[[r, g, b]]])\n            img = alpha * bg_color + (1 - alpha) * img\n\n            processed_images = np.append(processed_images, [img], axis=0)\n\n        return processed_images\n"
  },
  {
    "path": "utils/dataset_analyzer.py",
    "content": "#!/usr/bin/python3\n# -*- coding: utf-8 -*-\n#\n# Developed by Haozhe Xie <cshzxie@gmail.com>\n\nimport imageio    # scipy.ndimage.imread was removed in SciPy 1.2\nimport numpy as np\nimport os\nimport sys\n\nfrom datetime import datetime as dt\nfrom fnmatch import fnmatch\nfrom queue import Queue\n\n\ndef main():\n    if len(sys.argv) != 2:\n        print('Usage: python dataset_analyzer.py input_file_folder')\n        sys.exit(1)\n\n    input_file_folder = sys.argv[1]\n    if not os.path.exists(input_file_folder) or not os.path.isdir(input_file_folder):\n        print('[ERROR] Input folder does not exist!')\n        sys.exit(2)\n\n    FILE_NAME_PATTERN = '*.JPEG'\n    folders_to_explore = Queue()\n    folders_to_explore.put(input_file_folder)\n\n    total_files = 0\n    mean = np.asarray([0., 0., 0.])\n    var = np.asarray([0., 0., 0.])\n    while not folders_to_explore.empty():\n        current_folder = folders_to_explore.get()\n\n        if not os.path.exists(current_folder) or not os.path.isdir(current_folder):\n            print('[WARN] %s Ignore folder: %s' % (dt.now(), current_folder))\n            continue\n\n        print('[INFO] %s Listing files in folder: %s' % (dt.now(), current_folder))\n        n_folders = 0\n        n_files = 0\n        files = os.listdir(current_folder)\n        for file_name in files:\n            file_path = os.path.join(current_folder, file_name)\n            if os.path.isdir(file_path):\n                n_folders += 1\n                folders_to_explore.put(file_path)\n            elif os.path.isfile(file_path) and fnmatch(file_name, FILE_NAME_PATTERN):\n                n_files += 1\n                total_files += 1\n\n                img = imageio.imread(file_path)\n                img_mean = np.mean(img, axis=(0, 1))\n                img_var = np.var(img, axis=(0, 1))\n                mean += img_mean\n                var += img_var\n        print('[INFO] %s %d folders found, %d files found.' % (dt.now(), n_folders, n_files))\n    # Average the per-image variances before taking the square root\n    print('[INFO] %s Mean = %s, Std = %s' % (dt.now(), mean / total_files, np.sqrt(var / total_files)))\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "utils/network_utils.py",
    "content": "# -*- coding: utf-8 -*-\n#\n# Developed by Haozhe Xie <cshzxie@gmail.com>\n\nimport torch\n\nfrom datetime import datetime as dt\n\n\ndef var_or_cuda(x):\n    # Move the tensor to the GPU when one is available\n    if torch.cuda.is_available():\n        x = x.cuda(non_blocking=True)\n\n    return x\n\n\ndef init_weights(m):\n    # Kaiming init for conv layers, unit/zero init for batch norm, Gaussian init for linear layers\n    if type(m) in (torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose3d):\n        torch.nn.init.kaiming_normal_(m.weight)\n        if m.bias is not None:\n            torch.nn.init.constant_(m.bias, 0)\n    elif type(m) in (torch.nn.BatchNorm2d, torch.nn.BatchNorm3d):\n        torch.nn.init.constant_(m.weight, 1)\n        torch.nn.init.constant_(m.bias, 0)\n    elif type(m) == torch.nn.Linear:\n        torch.nn.init.normal_(m.weight, 0, 0.01)\n        torch.nn.init.constant_(m.bias, 0)\n\n\ndef save_checkpoints(cfg, file_path, epoch_idx, encoder, encoder_solver, decoder, decoder_solver, refiner,\n                     refiner_solver, merger, merger_solver, best_iou, best_epoch):\n    print('[INFO] %s Saving checkpoint to %s ...' 
% (dt.now(), file_path))\n    checkpoint = {\n        'epoch_idx': epoch_idx,\n        'best_iou': best_iou,\n        'best_epoch': best_epoch,\n        'encoder_state_dict': encoder.state_dict(),\n        'encoder_solver_state_dict': encoder_solver.state_dict(),\n        'decoder_state_dict': decoder.state_dict(),\n        'decoder_solver_state_dict': decoder_solver.state_dict()\n    }\n\n    if cfg.NETWORK.USE_REFINER:\n        checkpoint['refiner_state_dict'] = refiner.state_dict()\n        checkpoint['refiner_solver_state_dict'] = refiner_solver.state_dict()\n    if cfg.NETWORK.USE_MERGER:\n        checkpoint['merger_state_dict'] = merger.state_dict()\n        checkpoint['merger_solver_state_dict'] = merger_solver.state_dict()\n\n    torch.save(checkpoint, file_path)\n\n\ndef count_parameters(model):\n    return sum(p.numel() for p in model.parameters())\n\n\nclass AverageMeter(object):\n    \"\"\"Computes and stores the average and current value\"\"\"\n    def __init__(self):\n        self.reset()\n\n    def reset(self):\n        self.val = 0\n        self.avg = 0\n        self.sum = 0\n        self.count = 0\n\n    def update(self, val, n=1):\n        self.val = val\n        self.sum += val * n\n        self.count += n\n        self.avg = self.sum / self.count\n"
  }
]