[
  {
    "path": ".gitignore",
    "content": "*.jpg\n*.png\ntags\n*.pyc\n*.pyo\n"
  },
  {
    "path": "LICENSE.txt",
    "content": "MIT License\n\nCopyright (c) 2018 Guanying Chen\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# SDPS-Net\n**[SDPS-Net: Self-calibrating Deep Photometric Stereo Networks, CVPR 2019 (Oral)](http://guanyingc.github.io/SDPS-Net/)**.\n<br>\n[Guanying Chen](https://guanyingc.github.io), [Kai Han](http://www.hankai.org/), [Boxin Shi](http://alumni.media.mit.edu/~shiboxin/), [Yasuyuki Matsushita](http://www-infobiz.ist.osaka-u.ac.jp/en/member/matsushita/), [Kwan-Yee K. Wong](http://i.cs.hku.hk/~kykwong/)\n<br>\n\nThis paper addresses the problem of learning based _uncalibrated_ photometric stereo for non-Lambertian surface.\n<br>\n<p align=\"center\">\n    <img src='data/images/buddha.gif' height=\"250\" >\n    <img src='data/images/GT.png' height=\"250\" >\n</p>\n\n### _Changelog_\n- July 28, 2019: We have already updated the code to support Python 3.7 + PyTorch 1.10. To run the previous version (Python 2.7 + PyTorch 0.40), please checkout to `python2.7` branch first (by `git checkout python2.7`).\n\n## Dependencies\nSDPS-Net is implemented in [PyTorch](https://pytorch.org/) and tested with Ubuntu (14.04 and 16.04), please install PyTorch first following the official instruction. 
\n\n- Python 3.7 \n- PyTorch (version = 1.10)\n- torchvision\n- CUDA-9.0 # Skip this if you only have CPUs in your computer\n- numpy\n- scipy\n- scikit-image \n\nYou are highly recommended to use Anaconda and create a new environment to run this code.\n```shell\n# Create a new python3.7 environment named py3.7\nconda create -n py3.7 python=3.7\n\n# Activate the created environment\nsource activate py3.7\n\n# Example commands for installing the dependencies \nconda install pytorch torchvision cudatoolkit=9.0 -c pytorch\nconda install -c anaconda scipy \nconda install -c anaconda scikit-image \n\n# Download this code\ngit clone https://github.com/guanyingc/SDPS-Net.git\ncd SDPS-Net\n```\n## Overview\nWe provide:\n- Trained models\n    - LCNet for lighting calibration from input images\n    - NENet for normal estimation from input images and estimated lightings.\n- Code to test on DiLiGenT main dataset\n- Full code to train a new model, including codes for debugging, visualization and logging.\n\n## Testing\n### Download the trained models\n```\nsh scripts/download_pretrained_models.sh\n```\nIf the above command is not working, please manually download the trained models from BaiduYun ([LCNet and NENet](https://pan.baidu.com/s/10huOyPkfDSkDUK23_j4y1w?pwd=i5ha)) and put them in `./data/models/`.\n\n### Test SDPS-Net on the DiLiGenT main dataset\n```shell\n# Prepare the DiLiGenT main dataset\nsh scripts/prepare_diligent_dataset.sh\n# This command will first download and unzip the DiLiGenT dataset, and then centered crop \n# the original images based on the object mask with a margin size of 15 pixels.\n\n# Test SDPS-Net on DiLiGenT main dataset using all of the 96 image\nCUDA_VISIBLE_DEVICES=0 python eval/run_stage2.py --retrain data/models/LCNet_CVPR2019.pth.tar --retrain_s2 data/models/NENet_CVPR2019.pth.tar\n# Please check the outputs in data/models/\n\n# If you only have CPUs, please add the argument \"--cuda\" to disable the usage of GPU\npython eval/run_stage2.py 
--cuda --retrain data/models/LCNet_CVPR2019.pth.tar --retrain_s2 data/models/NENet_CVPR2019.pth.tar\n```\n\n### Test SDPS-Net on your own dataset\nYou have two options to test our method on your dataset. In the first option, you have to implement a customized Dataset class to load your data, which should not be difficult. Please refer to `datasets/UPS_DiLiGenT_main.py` for an example that loads the DiLiGenT main dataset.\n\nIf you don't want to implement your own Dataset class, you may try our `datasets/UPS_Custom_Dataset.py`. However, you have to first arrange your dataset in the same format as the `data/ToyPSDataset/`. Then you can call the following commands.\n```shell\nCUDA_VISIBLE_DEVICES=0 python eval/run_stage2.py --retrain data/models/LCNet_CVPR2019.pth.tar --retrain_s2 data/models/NENet_CVPR2019.pth.tar --benchmark UPS_Custom_Dataset --bm_dir /path/to/your/dataset\n\n# To test SDPS-Net on the ToyPSDataset, simply run\nCUDA_VISIBLE_DEVICES=0 python eval/run_stage2.py --retrain data/models/LCNet_CVPR2019.pth.tar --retrain_s2 data/models/NENet_CVPR2019.pth.tar --benchmark UPS_Custom_Dataset --bm_dir data/ToyPSDataset/\n# Please check the outputs in data/models/\n```\nYou may find input arguments in `run_model_opts.py` (particularly `--have_l_dirs`, `--have_l_ints`, and `--have_gt_n`) useful when testing your own dataset.\n\n## Training\nWe adopted the publicly available synthetic [PS Blobby and Sculpture datasets](https://github.com/guanyingc/PS-FCN) for training.\nTo train a new SDPS-Net model, please follow the following steps:\n\n### Download the training data\n```shell\n# The total size of the zipped synthetic datasets is 4.7+19=23.7 GB \n# and it takes some times to download and unzip the datasets.\nsh scripts/download_synthetic_datasets.sh\n```\nIf the above command is not working, please manually download the training datasets from BaiduYun ([PS Sculpture Dataset and PS Blobby Dataset](https://pan.baidu.com/s/1WUVu9ibIBh4wM1shTXBuNw?pwd=snyc) and put 
them in `./data/datasets/`.\n\n### First stage: train Lighting Calibration Network (LCNet)\n```shell\n# Train LCNet on synthetic datasets using 32 input images\nCUDA_VISIBLE_DEVICES=0 python main_stage1.py --in_img_num 32\n# Please refer to options/base_opt.py and options/stage1_opt.py for more options\n\n# You can find checkpoints and results in data/logdir/\n# It takes about 20 hours to train LCNet on a single Titan X Pascal GPU.\n```\n### Second stage: train Normal Estimation Network (NENet)\n```shell\n# Train NENet on synthetic datasets using 32 input images\nCUDA_VISIBLE_DEVICES=0 python main_stage2.py --in_img_num 32 --retrain data/logdir/path/to/checkpointDirOfLCNet/checkpoint20.pth.tar\n# Please refer to options/base_opt.py and options/stage2_opt.py for more options\n\n# You can find checkpoints and results in data/logdir/\n# It takes about 26 hours to train NENet on a single Titan X Pascal GPU.\n```\n\n## FAQ\n\n#### Q1: How to test SDPS-Net on other datasets?\n- You can implement a customized Dataset class to load your data. You may also use the provided `datasets/UPS_Custom_Dataset.py` Dataset class to load your data. However, you have to first arrange your dataset in the same format as the `data/ToyPSDataset/`. Precomputed results on DiLiGenT main dataset, Gourd\\&Apple dataset, Light Stage Dataset and Synthetic Test dataset are available upon request.\n\n#### Q2: What should I do if I have problem in running your code?\n- Please create an issue if you encounter errors when trying to run the code. Please also feel free to submit a bug report.\n\n#### Q3: Could I run your code only using CPUs?\n- The good news is that you can simply append `--cuda` in your command to disable the usage of GPU. The running time for the testing on DiLiGenT benchmark using CPUs is still bearable (should be less than 20 minutes). However, it is EXTREMELY SLOW for training! 
\n\n## Citation\nIf you find this code or the provided models useful in your research, please consider cite: \n```\n@inproceedings{chen2019SDPS_Net,\n  title={SDPS-Net: Self-calibrating Deep Photometric Stereo Networks},\n  author={Chen, Guanying and Han, Kai and Shi, Boxin and Matsushita, Yasuyuki and Wong, Kwan-Yee~K.},\n  booktitle={CVPR},\n  year={2019}\n}\n```\n"
  },
  {
    "path": "data/.gitignore",
    "content": "*\n!.gitignore\n"
  },
  {
    "path": "datasets/UPS_Custom_Dataset.py",
    "content": "from __future__ import division\nimport os\nimport numpy as np\nimport scipy.io as sio\nfrom imageio import imread\n\nimport torch\nimport torch.utils.data as data\n\nfrom datasets import pms_transforms\nfrom . import util\nnp.random.seed(0)\n\nclass UPS_Custom_Dataset(data.Dataset):\n    def __init__(self, args, split='train'):\n        self.root   = os.path.join(args.bm_dir)\n        self.split  = split\n        self.args   = args\n        self.objs   = util.readList(os.path.join(self.root, 'objects.txt'), sort=False)\n        args.log.printWrite('[%s Data] \\t%d objs. Root: %s' % (split, len(self.objs), self.root))\n\n    def _getMask(self, obj):\n        mask = imread(os.path.join(self.root, obj, 'mask.png'))\n        if mask.ndim > 2: mask = mask[:,:,0]\n        mask = mask.reshape(mask.shape[0], mask.shape[1], 1)\n        return mask / 255.0\n\n    def __getitem__(self, index):\n        obj   = self.objs[index]\n        names  = util.readList(os.path.join(self.root, obj, 'names.txt'))\n        img_list   = [os.path.join(self.root, obj, names[i]) for i in range(len(names))]\n\n        if self.args.have_l_dirs:\n            dirs = np.genfromtxt(os.path.join(self.root, obj, 'light_directions.txt'))\n        else:\n            dirs = np.zeros((len(names), 3))\n            dirs[:,2] = 1\n        \n        if self.args.have_l_ints:\n            ints = np.genfromtxt(os.path.join(self.root, obj, 'light_intensities.txt'))\n        else:\n            ints = np.ones((len(names), 3))\n\n        imgs = []\n        for idx, img_name in enumerate(img_list):\n            img = imread(img_name).astype(np.float32) / 255.0\n            imgs.append(img)\n        img = np.concatenate(imgs, 2)\n        h, w, c = img.shape\n\n        if self.args.have_gt_n:\n            normal_path = os.path.join(self.root, obj, 'normal.png')\n            normal = imread(normal_path).astype(np.float32) / 255.0 * 2 - 1\n        else:\n            normal = np.zeros((h, w, 3))\n\n      
  mask = self._getMask(obj)\n        img  = img * mask.repeat(img.shape[2], 2)\n\n        item = {'normal': normal, 'img': img, 'mask': mask}\n\n        downsample = 4 \n        for k in item.keys():\n            item[k] = pms_transforms.imgSizeToFactorOfK(item[k], downsample)\n\n        for k in item.keys(): \n            item[k] = pms_transforms.arrayToTensor(item[k])\n\n        item['dirs'] = torch.from_numpy(dirs).view(-1, 1, 1).float()\n        item['ints'] = torch.from_numpy(ints).view(-1, 1, 1).float()\n        item['obj'] = obj\n        item['path'] = os.path.join(self.root, obj)\n        return item\n\n    def __len__(self):\n        return len(self.objs)\n"
  },
  {
    "path": "datasets/UPS_DiLiGenT_main.py",
    "content": "from __future__ import division\nimport os\nimport numpy as np\nimport scipy.io as sio\n#from scipy.ndimage import imread\nfrom imageio import imread\n\nimport torch\nimport torch.utils.data as data\n\nfrom datasets import pms_transforms\nfrom . import util\nnp.random.seed(0)\n\nclass UPS_DiLiGenT_main(data.Dataset):\n    def __init__(self, args, split='train'):\n        self.root  = os.path.join(args.bm_dir)\n        self.split = split\n        self.args  = args\n        self.objs  = util.readList(os.path.join(self.root, 'objects.txt'), sort=False)\n        self.names = util.readList(os.path.join(self.root, 'names.txt'),   sort=False)\n        self.l_dir = util.light_source_directions()\n        args.log.printWrite('[%s Data] \\t%d objs %d lights. Root: %s' % \n                (split, len(self.objs), len(self.names), self.root))\n        self.ints = {}\n        ints_name = 'light_intensities.txt'\n        print('Files for intensity: %s' % (ints_name))\n        for obj in self.objs:\n            self.ints[obj] = np.genfromtxt(os.path.join(self.root, obj, ints_name))\n\n    def _getMask(self, obj):\n        mask = imread(os.path.join(self.root, obj, 'mask.png'))\n        if mask.ndim > 2: mask = mask[:,:,0]\n        mask = mask.reshape(mask.shape[0], mask.shape[1], 1)\n        return mask / 255.0\n\n    def __getitem__(self, index):\n        np.random.seed(index)\n        obj = self.objs[index]\n        select_idx = range(len(self.names))\n\n        img_list = [os.path.join(self.root, obj, self.names[i]) for i in select_idx]\n        ints = [np.diag(1 / self.ints[obj][i]) for i in select_idx]\n        dirs = self.l_dir[select_idx]\n\n        normal_path = os.path.join(self.root, obj, 'Normal_gt.mat')\n        normal = sio.loadmat(normal_path)['Normal_gt']\n\n        imgs = []\n        for idx, img_name in enumerate(img_list):\n            img = imread(img_name).astype(np.float32) / 255.0\n            if self.args.in_light and not self.args.int_aug:\n 
               img = np.dot(img, ints[idx])\n            imgs.append(img)\n        img = np.concatenate(imgs, 2)\n\n        mask = self._getMask(obj)\n        if self.args.test_resc:\n            img, normal = pms_transforms.rescale(img, normal, [self.args.test_h, self.args.test_w])\n            mask = pms_transforms.rescaleSingle(mask, [self.args.test_h, self.args.test_w])\n\n        img = img * mask.repeat(img.shape[2], 2)\n\n        norm = np.sqrt((normal * normal).sum(2, keepdims=True))\n        normal = normal / (norm + 1e-10)\n\n        item = {'normal': normal, 'img': img, 'mask': mask}\n\n        downsample = 4 \n        for k in item.keys():\n            item[k] = pms_transforms.imgSizeToFactorOfK(item[k], downsample)\n\n        for k in item.keys(): \n            item[k] = pms_transforms.arrayToTensor(item[k])\n\n        item['dirs'] = torch.from_numpy(dirs).view(-1, 1, 1).float()\n        item['ints'] = torch.from_numpy(self.ints[obj][select_idx]).view(-1, 1, 1).float()\n\n        item['obj'] = obj\n        item['path'] = os.path.join(self.root, obj)\n        return item\n\n    def __len__(self):\n        return len(self.objs)\n"
  },
  {
    "path": "datasets/UPS_Synth_Dataset.py",
    "content": "from __future__ import division\nimport os\nimport numpy as np\n#from scipy.ndimage import imread\nfrom imageio import imread\n\nimport torch\nimport torch.utils.data as data\n\nfrom datasets import pms_transforms\nfrom . import util\nnp.random.seed(0)\n\nclass UPS_Synth_Dataset(data.Dataset):\n    def __init__(self, args, root, split='train'):\n        self.root  = os.path.join(root)\n        self.split = split\n        self.args  = args\n        self.shape_list = util.readList(os.path.join(self.root, split + args.l_suffix))\n\n    def _getInputPath(self, index):\n        shape, mtrl = self.shape_list[index].split('/')\n        normal_path = os.path.join(self.root, 'Images', shape, shape + '_normal.png')\n        img_dir     = os.path.join(self.root, 'Images', self.shape_list[index])\n        img_list    = util.readList(os.path.join(img_dir, '%s_%s.txt' % (shape, mtrl)))\n\n        data = np.genfromtxt(img_list, dtype='str', delimiter=' ')\n        select_idx = np.random.permutation(data.shape[0])[:self.args.in_img_num]\n        idxs = ['%03d' % (idx) for idx in select_idx]\n        data = data[select_idx, :]\n        imgs = [os.path.join(img_dir, img) for img in data[:, 0]]\n        dirs = data[:, 1:4].astype(np.float32)\n        return normal_path, imgs, dirs\n\n    def __getitem__(self, index):\n        normal_path, img_list, dirs = self._getInputPath(index)\n        normal = imread(normal_path).astype(np.float32) / 255.0 * 2 - 1\n        imgs   =  []\n        for i in img_list:\n            img = imread(i).astype(np.float32) / 255.0\n            imgs.append(img)\n        img = np.concatenate(imgs, 2)\n\n        h, w, c = img.shape\n        crop_h, crop_w = self.args.crop_h, self.args.crop_w\n        if self.args.rescale and not (crop_h == h):\n            sc_h = np.random.randint(crop_h, h) if self.args.rand_sc else self.args.scale_h\n            sc_w = np.random.randint(crop_w, w) if self.args.rand_sc else self.args.scale_w\n            img, 
normal = pms_transforms.rescale(img, normal, [sc_h, sc_w])\n\n        if self.args.crop:\n            img, normal = pms_transforms.randomCrop(img, normal, [crop_h, crop_w])\n\n        if self.args.color_aug:\n            img = img * np.random.uniform(1, self.args.color_ratio)\n\n        if self.args.int_aug:\n            ints = pms_transforms.getIntensity(len(imgs))\n            img  = np.dot(img, np.diag(ints.reshape(-1)))\n        else:\n            ints = np.ones(c)\n\n        if self.args.noise_aug:\n            img = pms_transforms.randomNoiseAug(img, self.args.noise)\n\n        mask   = pms_transforms.normalToMask(normal)\n        normal = normal * mask.repeat(3, 2) \n        norm   = np.sqrt((normal * normal).sum(2, keepdims=True))\n        normal = normal / (norm + 1e-10) # Rescale normal to unit length\n\n        item = {'normal': normal, 'img': img, 'mask': mask}\n        for k in item.keys(): \n            item[k] = pms_transforms.arrayToTensor(item[k])\n\n        item['dirs'] = torch.from_numpy(dirs).view(-1, 1, 1).float()\n        item['ints'] = torch.from_numpy(ints).view(-1, 1, 1).float()\n        return item\n\n    def __len__(self):\n        return len(self.shape_list)\n"
  },
  {
    "path": "datasets/__init__.py",
    "content": "#from .PMS_dataset_v1 import PMS_dataset\n#from .PMS_dataset_v2 import PMS_data_v2\n#from .DiLiGenT import DiLiGenT\n#__all__ = ('PMS_dataset', 'DiLiGenT', 'PMS_data_v2')\n"
  },
  {
    "path": "datasets/custom_data_loader.py",
    "content": "import torch.utils.data\n\ndef customDataloader(args):\n    args.log.printWrite(\"=> fetching img pairs in %s\" % (args.data_dir))\n    datasets = __import__('datasets.' + args.dataset)\n    dataset_file = getattr(datasets, args.dataset)\n    train_set = getattr(dataset_file, args.dataset)(args, args.data_dir, 'train')\n    val_set   = getattr(dataset_file, args.dataset)(args, args.data_dir, 'val')\n\n    if args.concat_data:\n        args.log.printWrite('****** Using cocnat data ******')\n        args.log.printWrite(\"=> fetching img pairs in '{}'\".format(args.data_dir2))\n        train_set2 = getattr(dataset_file, args.dataset)(args, args.data_dir2, 'train')\n        val_set2   = getattr(dataset_file, args.dataset)(args, args.data_dir2, 'val')\n        train_set  = torch.utils.data.ConcatDataset([train_set, train_set2])\n        val_set    = torch.utils.data.ConcatDataset([val_set,   val_set2])\n\n    args.log.printWrite('Found Data:\\t %d Train and %d Val' % (len(train_set), len(val_set)))\n    args.log.printWrite('\\t Train Batch: %d, Val Batch: %d' % (args.batch, args.val_batch))\n\n    train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch,\n        num_workers=args.workers, pin_memory=args.cuda, shuffle=True)\n    test_loader  = torch.utils.data.DataLoader(val_set , batch_size=args.val_batch,\n        num_workers=args.workers, pin_memory=args.cuda, shuffle=False)\n    return train_loader, test_loader\n\ndef benchmarkLoader(args):\n    args.log.printWrite(\"=> fetching img pairs in 'data/%s'\" % (args.benchmark))\n    datasets = __import__('datasets.' 
+ args.benchmark)\n    dataset_file = getattr(datasets, args.benchmark)\n    test_set = getattr(dataset_file, args.benchmark)(args, 'test')\n\n    args.log.printWrite('Found Benchmark Data: %d samples' % (len(test_set)))\n    args.log.printWrite('\\t Test Batch %d' % (args.test_batch))\n\n    test_loader = torch.utils.data.DataLoader(test_set, batch_size=args.test_batch,\n        num_workers=args.workers, pin_memory=args.cuda, shuffle=False)\n    return test_loader\n"
  },
  {
    "path": "datasets/pms_transforms.py",
    "content": "import torch\nimport random\nimport numpy as np\nfrom skimage.transform import resize\nrandom.seed(0)\nnp.random.seed(0)\n\ndef arrayToTensor(array):\n    if array is None:\n        return array\n    array = np.transpose(array, (2, 0, 1))\n    tensor = torch.from_numpy(array)\n    return tensor.float()\n\ndef normalToMask(normal, thres=1e-2):\n    \"\"\"\n    Due to the numerical precision of uint8, [0, 0, 0] will save as [127, 127, 127] in gt normal,\n    When we load the data and rescale normal by N / 255 * 2 - 1, [127, 127, 127] becomes \n    [-0.003927, -0.003927, -0.003927]\n    \"\"\"\n    mask = (np.square(normal).sum(2, keepdims=True) > thres).astype(np.float32)\n    return mask\n\ndef imgSizeToFactorOfK(img, k):\n    if img.shape[0] % k == 0 and img.shape[1] % k == 0:\n        return img\n    pad_h, pad_w = k - img.shape[0] % k, k - img.shape[1] % k\n    img = np.pad(img, ((0, pad_h), (0, pad_w), (0,0)), \n            'constant', constant_values=((0,0),(0,0),(0,0)))\n    return img\n\ndef randomCrop(inputs, target, size):\n    h, w, _ = inputs.shape\n    c_h, c_w = size\n    if h == c_h and w == c_w:\n        return inputs, target\n    x1 = random.randint(0, w - c_w)\n    y1 = random.randint(0, h - c_h)\n    inputs = inputs[y1: y1 + c_h, x1: x1 + c_w]\n    target = target[y1: y1 + c_h, x1: x1 + c_w]\n    return inputs, target\n\ndef centerCrop(inputs, size):\n    h, w, _ = inputs.shape\n    c_h, c_w = size\n    if h != c_h or w != c_w:\n        x1 = int(w / 2 - c_w / 2)\n        y1 = int(h / 2 - c_h / 2)\n        inputs = inputs[y1: y1 + c_h, x1: x1 + c_w]\n    return inputs\n\ndef rescale(inputs, target, size):\n    in_h, in_w, _ = inputs.shape\n    h, w = size\n    if h != in_h or w != in_w:\n        inputs = resize(inputs, size, order=1, mode='reflect')\n        target = resize(target, size, order=1, mode='reflect')\n    return inputs, target\n\ndef rescaleSingle(inputs, size, order=1):\n    in_h, in_w, _ = inputs.shape\n    h, w = 
size\n    if h != in_h or w != in_w:\n        inputs = resize(inputs, size, order=order, mode='reflect')\n    return inputs\n\ndef randomNoiseAug(inputs, noise_level=0.05):\n    noise = np.random.random(inputs.shape)\n    noise = (noise - 0.5) * noise_level\n    inputs += noise\n    return inputs\n\ndef getIntensity(num):\n    intensity = np.random.random((num, 1)) * 1.8 + 0.2\n    color = np.ones((1, 3)) # Uniform color\n    intens = (intensity.repeat(3, 1) * color)\n    return intens\n"
  },
  {
    "path": "datasets/util.py",
    "content": "import numpy as np\nimport re\n\ndef atoi(text):\n    return int(text) if text.isdigit() else text\n\ndef natural_keys(text):\n    '''\n    alist.sort(key=natural_keys) sorts in human order\n    http://nedbatchelder.com/blog/200712/human_sorting.html\n    (See Toothy's implementation in the comments)\n    '''\n    return [ atoi(c) for c in re.split('(\\d+)', text) ]\n\ndef readList(list_path,ignore_head=False, sort=True):\n    lists = []\n    with open(list_path) as f:\n        lists = f.read().splitlines()\n    if ignore_head:\n        lists = lists[1:]\n    if sort:\n        lists.sort(key=natural_keys)\n    return lists\n\ndef light_source_directions():\n    \"\"\"\n    Below matrix is from DiLiGenT.\n    :return: light source direction matrix. [light_num, 3]\n    :rtype: np.ndarray\n    \"\"\"\n    L = np.array([[-0.06059872, -0.44839055, 0.8917812],\n                  [-0.05939919, -0.33739538, 0.93948714],\n                  [-0.05710194, -0.21230722, 0.97553319],\n                  [-0.05360061, -0.07800089, 0.99551134],\n                  [-0.04919816, 0.05869781, 0.99706274],\n                  [-0.04399823, 0.19019233, 0.98076044],\n                  [-0.03839991, 0.31049925, 0.9497977],\n                  [-0.03280081, 0.41611025, 0.90872238],\n                  [-0.18449839, -0.43989616, 0.87889232],\n                  [-0.18870114, -0.32950199, 0.92510557],\n                  [-0.1901994, -0.20549935, 0.95999698],\n                  [-0.18849605, -0.07269848, 0.97937948],\n                  [-0.18329657, 0.06229884, 0.98108166],\n                  [-0.17500445, 0.19220488, 0.96562453],\n                  [-0.16449474, 0.31129005, 0.93597008],\n                  [-0.15270716, 0.4160195, 0.89644202],\n                  [-0.30139786, -0.42509698, 0.85349393],\n                  [-0.31020115, -0.31660118, 0.89640333],\n                  [-0.31489186, -0.19549495, 0.92877599],\n                  [-0.31450962, -0.06640203, 0.94692897],\n      
            [-0.30880699, 0.06470146, 0.94892147],\n                  [-0.2981084, 0.19100538, 0.93522635],\n                  [-0.28359251, 0.30729189, 0.90837601],\n                  [-0.26670649, 0.41020998, 0.87212122],\n                  [-0.40709586, -0.40559588, 0.81839168],\n                  [-0.41919869, -0.29999906, 0.85689732],\n                  [-0.42618633, -0.18329412, 0.88587159],\n                  [-0.42691512, -0.05950211, 0.90233197],\n                  [-0.42090385, 0.0659006, 0.90470827],\n                  [-0.40860354, 0.18720162, 0.89330773],\n                  [-0.39141794, 0.29941372, 0.87013988],\n                  [-0.3707838, 0.39958255, 0.83836338],\n                  [-0.499596, -0.38319693, 0.77689378],\n                  [-0.51360334, -0.28130183, 0.81060526],\n                  [-0.52190667, -0.16990217, 0.83591069],\n                  [-0.52326874, -0.05249686, 0.85054918],\n                  [-0.51720021, 0.06620003, 0.85330035],\n                  [-0.50428312, 0.18139393, 0.84427174],\n                  [-0.48561334, 0.28870793, 0.82512267],\n                  [-0.46289771, 0.38549809, 0.79819605],\n                  [-0.57853599, -0.35932235, 0.73224555],\n                  [-0.59329349, -0.26189713, 0.76119165],\n                  [-0.60202327, -0.15630604, 0.78303027],\n                  [-0.6037003, -0.04570002, 0.7959004],\n                  [-0.59781529, 0.06590169, 0.79892043],\n                  [-0.58486953, 0.17439091, 0.79215873],\n                  [-0.56588359, 0.27639198, 0.77677747],\n                  [-0.54241965, 0.36921337, 0.75462733],\n                  [0.05220076, -0.43870637, 0.89711304],\n                  [0.05199786, -0.33138635, 0.9420612],\n                  [0.05109826, -0.20999284, 0.97636672],\n                  [0.04919919, -0.07869871, 0.99568366],\n                  [0.04640163, 0.05630197, 0.99733494],\n                  [0.04279892, 0.18779527, 0.98127529],\n                  [0.03870043, 
0.30950341, 0.95011048],\n                  [0.03440055, 0.41730662, 0.90811441],\n                  [0.17290651, -0.43181626, 0.88523333],\n                  [0.17839998, -0.32509996, 0.92869988],\n                  [0.18160174, -0.20480196, 0.96180921],\n                  [0.18200745, -0.07490306, 0.98044012],\n                  [0.17919505, 0.05849838, 0.98207285],\n                  [0.17329685, 0.18839658, 0.96668244],\n                  [0.1649036, 0.30880674, 0.93672045],\n                  [0.1549931, 0.41578148, 0.89616009],\n                  [0.28720483, -0.41910705, 0.8613145],\n                  [0.29740177, -0.31410186, 0.90160535],\n                  [0.30420604, -0.1965039, 0.9321185],\n                  [0.30640529, -0.07010121, 0.94931639],\n                  [0.30361153, 0.05950226, 0.95093613],\n                  [0.29588748, 0.18589214, 0.93696036],\n                  [0.28409783, 0.30349768, 0.90949304],\n                  [0.26939905, 0.40849857, 0.87209694],\n                  [0.39120402, -0.40190413, 0.8279085],\n                  [0.40481085, -0.29960803, 0.86392315],\n                  [0.41411685, -0.18590756, 0.89103626],\n                  [0.41769724, -0.06449957, 0.906294],\n                  [0.41498764, 0.05959822, 0.90787296],\n                  [0.40607977, 0.18089099, 0.89575537],\n                  [0.39179226, 0.29439419, 0.87168279],\n                  [0.37379609, 0.39649585, 0.83849122],\n                  [0.48278794, -0.38169046, 0.78818031],\n                  [0.49848546, -0.28279175, 0.8194761],\n                  [0.50918069, -0.1740934, 0.84286803],\n                  [0.51360856, -0.05870098, 0.85601427],\n                  [0.51097962, 0.05899765, 0.8575658],\n                  [0.50151639, 0.17420569, 0.84742769],\n                  [0.48600297, 0.28260173, 0.82700506],\n                  [0.46600106, 0.38110087, 0.79850181],\n                  [0.56150442, -0.35990283, 0.74510586],\n                  
[0.57807114, -0.26498677, 0.77176147],\n                  [0.58933134, -0.1617086, 0.7915421],\n                  [0.59407609, -0.05289787, 0.80266769],\n                  [0.59157958, 0.057798, 0.80417224],\n                  [0.58198189, 0.16649482, 0.79597523],\n                  [0.56620006, 0.26940003, 0.77900008],\n                  [0.54551481, 0.36380988, 0.7550205]], dtype=float)\n    return L\n"
  },
  {
    "path": "eval/run_stage1.py",
    "content": "import torch, sys\nsys.path.append('.')\n\nfrom datasets import custom_data_loader\nfrom options  import run_model_opts\nfrom models   import custom_model\nfrom utils    import logger, recorders\n\nimport test_stage1 as test_utils\n\nargs = run_model_opts.RunModelOpts().parse()\nlog  = logger.Logger(args)\n\ndef main(args):\n    test_loader = custom_data_loader.benchmarkLoader(args)\n    model    = custom_model.buildModel(args)\n    recorder = recorders.Records(args.log_dir)\n    test_utils.test(args, 'test', test_loader, model, log, 1, recorder)\n    log.plotCurves(recorder, 'test')\n\nif __name__ == '__main__':\n    torch.manual_seed(args.seed)\n    main(args)\n"
  },
  {
    "path": "eval/run_stage2.py",
    "content": "import torch, sys\nsys.path.append('.')\n\nfrom datasets import custom_data_loader\nfrom options  import run_model_opts\nfrom models   import custom_model\nfrom utils    import logger, recorders\n\nimport test_stage2 as test_utils\n\nargs = run_model_opts.RunModelOpts().parse()\nargs.stage2    = True\nargs.test_resc = False\nlog  = logger.Logger(args)\n\ndef main(args):\n    test_loader = custom_data_loader.benchmarkLoader(args)\n    model = custom_model.buildModel(args)\n    model_s2 = custom_model.buildModelStage2(args)\n    models = [model, model_s2]\n\n    recorder = recorders.Records(args.log_dir)\n    test_utils.test(args, 'test', test_loader, models, log, 1, recorder)\n    log.plotCurves(recorder, 'test')\n\nif __name__ == '__main__':\n    torch.manual_seed(args.seed)\n    main(args)\n"
  },
  {
    "path": "main_stage1.py",
    "content": "import torch\nfrom options  import stage1_opts\nfrom utils    import logger, recorders\nfrom datasets import custom_data_loader\nfrom models   import custom_model, solver_utils, model_utils\n\nimport train_stage1 as train_utils\nimport test_stage1  as test_utils\n\nargs = stage1_opts.TrainOpts().parse()\nlog  = logger.Logger(args)\n\ndef main(args):\n    model = custom_model.buildModel(args)\n    optimizer, scheduler, records = solver_utils.configOptimizer(args, model)\n    criterion = solver_utils.Stage1ClsCrit(args)\n    recorder  = recorders.Records(args.log_dir, records)\n\n    train_loader, val_loader = custom_data_loader.customDataloader(args)\n\n    for epoch in range(args.start_epoch, args.epochs+1):\n        scheduler.step()\n        recorder.insertRecord('train', 'lr', epoch, scheduler.get_lr()[0])\n\n        train_utils.train(args, train_loader, model, criterion, optimizer, log, epoch, recorder)\n        if epoch % args.save_intv == 0: \n            model_utils.saveCheckpoint(args.cp_dir, epoch, model, optimizer, recorder.records, args)\n        log.plotCurves(recorder, 'train')\n\n        if epoch % args.val_intv == 0:\n            test_utils.test(args, 'val', val_loader, model, log, epoch, recorder)\n            log.plotCurves(recorder, 'val')\n\nif __name__ == '__main__':\n    torch.manual_seed(args.seed)\n    main(args)\n"
  },
  {
    "path": "main_stage2.py",
    "content": "import torch\nfrom options  import stage2_opts\nfrom utils    import logger, recorders\nfrom datasets import custom_data_loader\nfrom models   import custom_model, solver_utils, model_utils\n\nimport train_stage2 as train_utils\nimport test_stage2 as test_utils\n\nargs = stage2_opts.TrainOpts().parse()\nlog  = logger.Logger(args)\n\ndef main(args):\n    model = custom_model.buildModel(args)\n    model_s2 = custom_model.buildModelStage2(args)\n    models = [model, model_s2]\n\n    optimizer, scheduler, records = solver_utils.configOptimizer(args, model_s2)\n    optimizers = [optimizer, -1]\n    criterion = solver_utils.Stage2Crit(args)\n    recorder  = recorders.Records(args.log_dir, records)\n\n    train_loader, val_loader = custom_data_loader.customDataloader(args)\n\n    for epoch in range(args.start_epoch, args.epochs+1):\n        scheduler.step()\n\n        recorder.insertRecord('train', 'lr', epoch, scheduler.get_lr()[0])\n\n        train_utils.train(args, train_loader, models, criterion, optimizers, log, epoch, recorder)\n        if epoch % args.save_intv == 0: \n            model_utils.saveCheckpoint(args.cp_dir, epoch, model_s2, optimizer, recorder.records, args)\n        log.plotCurves(recorder, 'train')\n\n        if epoch % args.val_intv == 0:\n            test_utils.test(args, 'val', val_loader, models, log, epoch, recorder)\n            log.plotCurves(recorder, 'val')\n\nif __name__ == '__main__':\n    torch.manual_seed(args.seed)\n    main(args)\n"
  },
  {
    "path": "models/LCNet.py",
    "content": "import torch\nimport torch.nn as nn\nfrom torch.nn.init import kaiming_normal_\nfrom . import model_utils\nfrom utils import eval_utils\n\n# Classification\nclass FeatExtractor(nn.Module):\n    def __init__(self, batchNorm, c_in, c_out=256):\n        super(FeatExtractor, self).__init__()\n        self.conv1 = model_utils.conv(batchNorm, c_in, 64,    k=3, stride=2, pad=1)\n        self.conv2 = model_utils.conv(batchNorm, 64,   128,   k=3, stride=2, pad=1)\n        self.conv3 = model_utils.conv(batchNorm, 128,  128,   k=3, stride=1, pad=1)\n        self.conv4 = model_utils.conv(batchNorm, 128,  128,   k=3, stride=2, pad=1)\n        self.conv5 = model_utils.conv(batchNorm, 128,  128,   k=3, stride=1, pad=1)\n        self.conv6 = model_utils.conv(batchNorm, 128,  256,   k=3, stride=2, pad=1)\n        self.conv7 = model_utils.conv(batchNorm, 256,  256,   k=3, stride=1, pad=1)\n\n    def forward(self, inputs):\n        out = self.conv1(inputs)\n        out = self.conv2(out)\n        out = self.conv3(out)\n        out = self.conv4(out)\n        out = self.conv5(out)\n        out = self.conv6(out)\n        out = self.conv7(out)\n        return out\n\nclass Classifier(nn.Module):\n    def __init__(self, batchNorm, c_in, other):\n        super(Classifier, self).__init__()\n        self.conv1 = model_utils.conv(batchNorm, 512,  256, k=3, stride=1, pad=1)\n        self.conv2 = model_utils.conv(batchNorm, 256,  256, k=3, stride=2, pad=1)\n        self.conv3 = model_utils.conv(batchNorm, 256,  256, k=3, stride=2, pad=1)\n        self.conv4 = model_utils.conv(batchNorm, 256,  256, k=3, stride=2, pad=1)\n        self.other = other\n        \n        self.dir_x_est = nn.Sequential(\n                    model_utils.conv(batchNorm, 256, 64,  k=1, stride=1, pad=0),\n                    model_utils.outputConv(64, other['dirs_cls'], k=1, stride=1, pad=0))\n\n        self.dir_y_est = nn.Sequential(\n                    model_utils.conv(batchNorm, 256, 64,  k=1, stride=1, 
pad=0),\n                    model_utils.outputConv(64, other['dirs_cls'], k=1, stride=1, pad=0))\n\n        self.int_est = nn.Sequential(\n                    model_utils.conv(batchNorm, 256, 64,  k=1, stride=1, pad=0),\n                    model_utils.outputConv(64, other['ints_cls'], k=1, stride=1, pad=0))\n\n    def forward(self, inputs):\n        out = self.conv1(inputs)\n        out = self.conv2(out)\n        out = self.conv3(out)\n        out = self.conv4(out)\n        outputs = {}\n        if self.other['s1_est_d']:\n            outputs['dir_x'] = self.dir_x_est(out)\n            outputs['dir_y'] = self.dir_y_est(out)\n        if self.other['s1_est_i']:\n            outputs['ints'] = self.int_est(out)\n        return outputs\n\nclass LCNet(nn.Module):\n    def __init__(self, fuse_type='max', batchNorm=False, c_in=3, other={}):\n        super(LCNet, self).__init__()\n        self.featExtractor = FeatExtractor(batchNorm, c_in, 128)\n        self.classifier = Classifier(batchNorm, 256, other)\n        self.c_in      = c_in\n        self.fuse_type = fuse_type\n        self.other     = other\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):\n                kaiming_normal_(m.weight.data)\n                if m.bias is not None:\n                    m.bias.data.zero_()\n            elif isinstance(m, nn.BatchNorm2d):\n                m.weight.data.fill_(1)\n                m.bias.data.zero_()\n\n    def prepareInputs(self, x):\n        n, c, h, w = x[0].shape\n        t_h, t_w = self.other['test_h'], self.other['test_w']\n        if (h == t_h and w == t_w):\n            imgs = x[0] \n        else:\n            print('Rescaling images: from %dX%d to %dX%d' % (h, w, t_h, t_w))\n            imgs = torch.nn.functional.upsample(x[0], size=(t_h, t_w), mode='bilinear')\n\n        inputs = list(torch.split(imgs, 3, 1))\n        idx = 1\n        if self.other['in_light']:\n            light = 
torch.split(x[idx], 3, 1)\n            for i in range(len(inputs)):\n                inputs[i] = torch.cat([inputs[i], light[i]], 1)\n            idx += 1\n        if self.other['in_mask']:\n            mask = x[idx]\n            if mask.shape[2] != inputs[0].shape[2] or mask.shape[3] != inputs[0].shape[3]:\n                mask = torch.nn.functional.upsample(mask, size=(t_h, t_w), mode='bilinear')\n            for i in range(len(inputs)):\n                inputs[i] = torch.cat([inputs[i], mask], 1)\n            idx += 1\n        return inputs\n\n    def fuseFeatures(self, feats, fuse_type):\n        if fuse_type == 'mean':\n            feat_fused = torch.stack(feats, 1).mean(1)\n        elif fuse_type == 'max':\n            feat_fused, _ = torch.stack(feats, 1).max(1)\n        return feat_fused\n\n    def convertMidDirs(self, pred):\n        _, x_idx = pred['dirs_x'].data.max(1)\n        _, y_idx = pred['dirs_y'].data.max(1)\n        dirs = eval_utils.SphericalClassToDirs(x_idx, y_idx, self.other['dirs_cls'])\n        return dirs\n\n    def convertMidIntens(self, pred, img_num):\n        _, idx = pred['ints'].data.max(1)\n        ints = eval_utils.ClassToLightInts(idx, self.other['ints_cls'])\n        ints = ints.view(-1, 1).repeat(1, 3)\n        ints = torch.cat(torch.split(ints, ints.shape[0] // img_num, 0), 1)\n        return ints\n\n    def forward(self, x):\n        inputs = self.prepareInputs(x)\n        feats = []\n        for i in range(len(inputs)):\n            out_feat = self.featExtractor(inputs[i])\n            shape    = out_feat.data.shape\n            feats.append(out_feat)\n        feat_fused = self.fuseFeatures(feats, self.fuse_type)\n\n        l_dirs_x, l_dirs_y, l_ints = [], [], []\n        for i in range(len(inputs)):\n            net_input = torch.cat([feats[i], feat_fused], 1)\n            outputs = self.classifier(net_input)\n            if self.other['s1_est_d']:\n                l_dirs_x.append(outputs['dir_x'])\n                
l_dirs_y.append(outputs['dir_y'])\n            if self.other['s1_est_i']:\n                l_ints.append(outputs['ints'])\n\n        pred = {}\n        if self.other['s1_est_d']:\n            pred['dirs_x'] = torch.cat(l_dirs_x, 0).squeeze()\n            pred['dirs_y'] = torch.cat(l_dirs_y, 0).squeeze()\n            pred['dirs']   = self.convertMidDirs(pred)\n        if self.other['s1_est_i']:\n            pred['ints'] = torch.cat(l_ints, 0).squeeze()\n            if pred['ints'].ndimension() == 1:\n                pred['ints'] = pred['ints'].view(1, -1)\n            pred['intens'] = self.convertMidIntens(pred, len(inputs))\n        return pred\n"
  },
  {
    "path": "models/NENet.py",
    "content": "import torch\nimport torch.nn as nn\nfrom torch.nn.init import kaiming_normal_\nfrom . import model_utils\n\nclass FeatExtractor(nn.Module):\n    def __init__(self, batchNorm=False, c_in=3, other={}):\n        super(FeatExtractor, self).__init__()\n        self.other = other\n        self.conv1 = model_utils.conv(batchNorm, c_in, 64,  k=3, stride=1, pad=1)\n        self.conv2 = model_utils.conv(batchNorm, 64,   128, k=3, stride=2, pad=1)\n        self.conv3 = model_utils.conv(batchNorm, 128,  128, k=3, stride=1, pad=1)\n        self.conv4 = model_utils.conv(batchNorm, 128,  256, k=3, stride=2, pad=1)\n        self.conv5 = model_utils.conv(batchNorm, 256,  256, k=3, stride=1, pad=1)\n        self.conv6 = model_utils.deconv(256, 128)\n        self.conv7 = model_utils.conv(batchNorm, 128, 128, k=3, stride=1, pad=1)\n\n    def forward(self, x):\n        out = self.conv1(x)\n        out = self.conv2(out)\n        out = self.conv3(out)\n        out = self.conv4(out)\n        out = self.conv5(out)\n        out = self.conv6(out)\n        out_feat = self.conv7(out)\n        n, c, h, w = out_feat.data.shape\n        out_feat   = out_feat.view(-1)\n        return out_feat, [n, c, h, w]\n\nclass Regressor(nn.Module):\n    def __init__(self, batchNorm=False, other={}): \n        super(Regressor, self).__init__()\n        self.other   = other\n        self.deconv1 = model_utils.conv(batchNorm, 128, 128,  k=3, stride=1, pad=1)\n        self.deconv2 = model_utils.conv(batchNorm, 128, 128,  k=3, stride=1, pad=1)\n        self.deconv3 = model_utils.deconv(128, 64)\n        self.est_normal = self._make_output(64, 3, k=3, stride=1, pad=1)\n        self.other   = other\n\n    def _make_output(self, cin, cout, k=3, stride=1, pad=1):\n        return nn.Sequential(\n               nn.Conv2d(cin, cout, kernel_size=k, stride=stride, padding=pad, bias=False))\n\n    def forward(self, x, shape):\n        x      = x.view(shape[0], shape[1], shape[2], shape[3])\n        out    = 
self.deconv1(x)\n        out    = self.deconv2(out)\n        out    = self.deconv3(out)\n        normal = self.est_normal(out)\n        normal = torch.nn.functional.normalize(normal, 2, 1)\n        return normal\n\nclass NENet(nn.Module):\n    def __init__(self, fuse_type='max', batchNorm=False, c_in=3, other={}):\n        super(NENet, self).__init__()\n        self.extractor = FeatExtractor(batchNorm, c_in, other)\n        self.regressor = Regressor(batchNorm, other)\n        self.c_in      = c_in\n        self.fuse_type = fuse_type\n        self.other = other\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):\n                kaiming_normal_(m.weight.data)\n                if m.bias is not None:\n                    m.bias.data.zero_()\n            elif isinstance(m, nn.BatchNorm2d):\n                m.weight.data.fill_(1)\n                m.bias.data.zero_()\n\n    def prepareInputs(self, x):\n        imgs = torch.split(x[0], 3, 1)\n        idx = 1\n        if self.other['in_light']: idx += 1\n        if self.other['in_mask']:  idx += 1\n        dirs = torch.split(x[idx]['dirs'], x[0].shape[0], 0)\n        ints = torch.split(x[idx]['intens'], 3, 1)\n        \n        s2_inputs = []\n        for i in range(len(imgs)):\n            n, c, h, w = imgs[i].shape\n            l_dir = dirs[i] if dirs[i].dim() == 4 else dirs[i].view(n, -1, 1, 1)\n            l_int = torch.diag(1.0 / (ints[i].contiguous().view(-1)+1e-8))\n            img   = imgs[i].contiguous().view(n * c, h * w)\n            img   = torch.mm(l_int, img).view(n, c, h, w)\n            img_light = torch.cat([img, l_dir.expand_as(img)], 1)\n            s2_inputs.append(img_light)\n        return s2_inputs\n\n    def forward(self, x):\n        inputs = self.prepareInputs(x)\n        feats = torch.Tensor()\n        for i in range(len(inputs)):\n            feat, shape = self.extractor(inputs[i])\n            if i == 0:\n                feats = 
feat\n            else:\n                if self.fuse_type == 'mean':\n                    feats = torch.stack([feats, feat], 1).sum(1)\n                elif self.fuse_type == 'max':\n                    feats, _ = torch.stack([feats, feat], 1).max(1)\n        if self.fuse_type == 'mean':\n            feats = feats / len(inputs)\n        feat_fused = feats\n        normal = self.regressor(feat_fused, shape)\n        pred = {}\n        pred['n'] = normal\n        return pred\n"
  },
  {
    "path": "models/__init__.py",
    "content": ""
  },
  {
    "path": "models/custom_model.py",
    "content": "from . import model_utils\nimport torch\n\ndef buildModel(args):\n    print('Creating Model %s' % (args.model))\n    in_c = model_utils.getInputChanel(args)\n    other = {\n            'img_num':  args.in_img_num, \n            'test_h':   args.test_h,   'test_w':   args.test_w,\n            'in_mask':  args.in_mask,  'in_light': args.in_light, \n            'dirs_cls': args.dirs_cls, 'ints_cls': args.ints_cls,\n            's1_est_d': args.s1_est_d, 's1_est_i': args.s1_est_i, 's1_est_n': args.s1_est_n, \n            }\n    models = __import__('models.' + args.model)\n    model_file = getattr(models, args.model)\n    model = getattr(model_file, args.model)(args.fuse_type, args.use_BN, in_c, other)\n\n    if args.cuda: model = model.cuda()\n\n    if args.retrain: \n        args.log.printWrite(\"=> using pre-trained model '{}'\".format(args.retrain))\n        model_utils.loadCheckpoint(args.retrain, model, cuda=args.cuda)\n\n    if args.resume:\n        args.log.printWrite(\"=> Resume loading checkpoint '{}'\".format(args.resume))\n        model_utils.loadCheckpoint(args.resume, model, cuda=args.cuda)\n    print(model)\n    args.log.printWrite(\"=> Model Parameters: %d\" % (model_utils.get_n_params(model)))\n    return model\n\ndef buildModelStage2(args):\n    print('Creating Stage2 Model %s' % (args.model_s2))\n    in_c = 6 if args.s2_in_light else 3\n    other = {\n            'img_num':  args.in_img_num,\n            'in_mask':  args.in_mask,  'in_light': args.in_light, \n            'dirs_cls': args.dirs_cls, 'ints_cls': args.ints_cls,\n            }\n    models = __import__('models.' 
+ args.model_s2)\n    model_file = getattr(models, args.model_s2)\n    model = getattr(model_file, args.model_s2)(args.fuse_type, args.use_BN, in_c, other)\n\n    if args.cuda: model = model.cuda()\n\n    if args.retrain_s2: \n        args.log.printWrite(\"=> using pre-trained model_s2 '{}'\".format(args.retrain_s2))\n        model_utils.loadCheckpoint(args.retrain_s2, model, cuda=args.cuda)\n\n    print(model)\n    args.log.printWrite(\"=> Stage2 Model Parameters: %d\" % (model_utils.get_n_params(model)))\n    return model\n"
  },
  {
    "path": "models/model_utils.py",
    "content": "import os\nimport torch\nimport torch.nn as nn\n\ndef getInput(args, data):\n    input_list = [data['img']]\n    if args.in_light: input_list.append(data['dirs'])\n    if args.in_mask:  input_list.append(data['m'])\n    return input_list\n\ndef parseData(args, sample, timer=None, split='train'):\n    img, normal, mask = sample['img'], sample['normal'], sample['mask']\n    ints = sample['ints']\n    if args.in_light:\n        dirs = sample['dirs'].expand_as(img)\n    else: # predict lighting, prepare ground truth\n        n, c, h, w = sample['dirs'].shape\n        dirs_split = torch.split(sample['dirs'].view(n, c), 3, 1)\n        dirs = torch.cat(dirs_split, 0)\n    if timer: timer.updateTime('ToCPU')\n    if args.cuda:\n        img, normal, mask = img.cuda(), normal.cuda(), mask.cuda()\n        dirs, ints = dirs.cuda(), ints.cuda()\n        if timer: timer.updateTime('ToGPU')\n    data = {'img': img, 'n': normal, 'm': mask, 'dirs': dirs, 'ints': ints}\n    return data \n\ndef getInputChanel(args):\n    args.log.printWrite('[Network Input] Color image as input')\n    c_in = 3\n    if args.in_light:\n        args.log.printWrite('[Network Input] Adding Light direction as input')\n        c_in += 3\n    if args.in_mask:\n        args.log.printWrite('[Network Input] Adding Mask as input')\n        c_in += 1\n    args.log.printWrite('[Network Input] Input channel: {}'.format(c_in))\n    return c_in\n\ndef get_n_params(model):\n    pp = 0\n    for p in list(model.parameters()):\n        nn = 1\n        for s in list(p.size()):\n            nn = nn * s\n        pp += nn\n    return pp\n\ndef loadCheckpoint(path, model, cuda=True):\n    if cuda:\n        checkpoint = torch.load(path)\n    else:\n        checkpoint = torch.load(path, map_location=lambda storage, loc: storage)\n    model.load_state_dict(checkpoint['state_dict'])\n\ndef saveCheckpoint(save_path, epoch=-1, model=None, optimizer=None, records=None, args=None):\n    state   = {'state_dict': 
model.state_dict(), 'model': args.model}\n    records = {'epoch': epoch, 'optimizer':optimizer.state_dict(), 'records': records} # 'args': args}\n    torch.save(state,   os.path.join(save_path, 'checkp_{}.pth.tar'.format(epoch)))\n    torch.save(records, os.path.join(save_path, 'checkp_{}_rec.pth.tar'.format(epoch)))\n\ndef conv_ReLU(batchNorm, cin, cout, k=3, stride=1, pad=-1):\n    pad = pad if pad >= 0 else (k - 1) // 2\n    if batchNorm:\n        print('=> convolutional layer with batchnorm')\n        return nn.Sequential(\n                nn.Conv2d(cin, cout, kernel_size=k, stride=stride, padding=pad, bias=False),\n                nn.BatchNorm2d(cout),\n                nn.ReLU(inplace=True)\n                )\n    else:\n        return nn.Sequential(\n                nn.Conv2d(cin, cout, kernel_size=k, stride=stride, padding=pad, bias=True),\n                nn.ReLU(inplace=True)\n                )\n\ndef conv(batchNorm, cin, cout, k=3, stride=1, pad=-1):\n    pad = pad if pad >= 0 else (k - 1) // 2\n    if batchNorm:\n        print('=> convolutional layer with batchnorm')\n        return nn.Sequential(\n                nn.Conv2d(cin, cout, kernel_size=k, stride=stride, padding=pad, bias=False),\n                nn.BatchNorm2d(cout),\n                nn.LeakyReLU(0.1, inplace=True)\n                )\n    else:\n        return nn.Sequential(\n                nn.Conv2d(cin, cout, kernel_size=k, stride=stride, padding=pad, bias=True),\n                nn.LeakyReLU(0.1, inplace=True)\n                )\n\ndef outputConv(cin, cout, k=3, stride=1, pad=1):\n    return nn.Sequential(\n            nn.Conv2d(cin, cout, kernel_size=k, stride=stride, padding=pad, bias=True))\n\ndef deconv(cin, cout):\n    return nn.Sequential(\n            nn.ConvTranspose2d(cin, cout, kernel_size=4, stride=2, padding=1, bias=False),\n            nn.LeakyReLU(0.1, inplace=True)\n            )\n\ndef upconv(cin, cout):\n    return nn.Sequential(\n            nn.Upsample(scale_factor=2, 
mode='bilinear'),\n            nn.Conv2d(cin, cout, kernel_size=3, stride=1, padding=1, bias=False),\n            nn.LeakyReLU(0.1, inplace=True)\n            )\n"
  },
  {
    "path": "models/solver_utils.py",
    "content": "import torch\nimport os\nfrom utils import eval_utils\n\nclass Stage1ClsCrit(object): # First Stage, Light classification criterion\n    def __init__(self, args):\n        print('==> Using Stage1ClsCrit for lighting classification')\n        self.s1_est_d = args.s1_est_d\n        self.s1_est_i = args.s1_est_i\n        self.ints_cls, self.dirs_cls = args.ints_cls, args.dirs_cls\n        self.setupLightCrit(args)\n\n    def setupLightCrit(self, args):\n        args.log.printWrite('=> Using light criterion')\n        if self.s1_est_d:\n            self.dir_w = args.dir_w\n            self.dirs_x_crit = torch.nn.CrossEntropyLoss()\n            self.dirs_y_crit = torch.nn.CrossEntropyLoss()\n            if args.cuda: \n                self.dirs_x_crit = self.dirs_x_crit.cuda()\n                self.dirs_y_crit = self.dirs_y_crit.cuda()\n        if self.s1_est_i:\n            self.ints_w = args.ints_w\n            self.ints_crit = torch.nn.CrossEntropyLoss()\n            if args.cuda: self.ints_crit = self.ints_crit.cuda()\n\n    def forward(self, output, target):\n        self.loss = 0\n        out_loss = {}\n        if self.s1_est_d:\n            est_dir_x, est_dir_y = output['dirs_x'], output['dirs_y']\n            gt_dir_x, gt_dir_y = eval_utils.SphericalDirsToClass(target['dirs'], self.dirs_cls)\n        \n            dirs_x_loss = self.dirs_x_crit(est_dir_x, gt_dir_x)\n            dirs_y_loss = self.dirs_y_crit(est_dir_y, gt_dir_y)\n\n            out_loss['D_x_loss'] = dirs_x_loss.item()\n            out_loss['D_y_loss'] = dirs_y_loss.item()\n            self.loss += self.dir_w * (dirs_x_loss + dirs_y_loss)\n\n        if self.s1_est_i:\n            est_intens = output['ints']\n            gt_ints = target['ints'].squeeze()[:, 0: target['ints'].shape[1]:3]\n            gt_ints = torch.cat(torch.split(gt_ints, 1, 1), 0)\n            gt_intens = eval_utils.LightIntsToClass(gt_ints, self.ints_cls)\n            ints_loss = self.ints_crit(est_intens, 
gt_intens)\n            out_loss['I_loss'] = ints_loss.item()\n            self.loss += self.ints_w * ints_loss\n        return out_loss\n     \n    def backward(self):\n        self.loss.backward()\n\nclass Stage2Crit(object): # Second stage\n    def __init__(self, args):\n        self.s2_est_n = args.s2_est_n \n        self.s2_est_d = args.s2_est_d\n        self.s2_est_i = args.s2_est_i\n        self.setupLightCrit(args)\n        if self.s2_est_n:\n            self.setupNormalCrit(args)\n\n    def setupLightCrit(self, args):\n        args.log.printWrite('=> Using light criterion')\n        if self.s2_est_d:\n            self.dir_w = args.dir_w\n            self.dirs_crit = torch.nn.CosineEmbeddingLoss()\n            if args.cuda: self.dirs_crit = self.dirs_crit.cuda()\n        if self.s2_est_i:\n            self.ints_w = args.ints_w\n            self.ints_crit = torch.nn.MSELoss()\n            if args.cuda: self.ints_crit = self.ints_crit.cuda()\n\n    def setupNormalCrit(self, args):\n        args.log.printWrite('=> Using {} for criterion normal'.format(args.normal_loss))\n        self.normal_w = args.normal_w\n        if args.normal_loss == 'mse':\n            self.n_crit = torch.nn.MSELoss()\n        elif args.normal_loss == 'cos':\n            self.n_crit = torch.nn.CosineEmbeddingLoss()\n        else:\n            raise Exception(\"=> Unknown Criterion '{}'\".format(args.normal_loss))\n        if args.cuda:\n            self.n_crit = self.n_crit.cuda()\n\n    def forward(self, output, target):\n        self.loss = 0\n        out_loss = {}\n\n        if self.s2_est_d:\n            d_est, d_tar = output['dirs'], target['dirs']\n            d_num = d_tar.nelement() // d_tar.shape[1]\n            if not hasattr(self, 'l_flag') or d_num != self.l_flag.nelement():\n                self.l_flag = d_tar.data.new().resize_(d_num).fill_(1)\n            dirs_loss = self.dirs_crit(d_est.squeeze(), d_tar, self.l_flag)\n            self.loss += self.dir_w * dirs_loss\n     
       out_loss['D_loss'] = dirs_loss.item()\n\n        if self.s2_est_i:\n            i_est, i_tar = output['ints'], target['ints']\n            ints_loss  = self.ints_crit(i_est, i_tar.squeeze())\n            self.loss += self.ints_w * ints_loss\n            out_loss['I_loss'] = ints_loss.item()\n\n        if self.s2_est_n:\n            n_est, n_tar = output['n'], target['n']\n            n_num = n_tar.nelement() // n_tar.shape[1]\n            if not hasattr(self, 'n_flag') or n_num != self.n_flag.nelement():\n                self.n_flag = n_tar.data.new().resize_(n_num).fill_(1)\n            self.out_reshape = n_est.permute(0, 2, 3, 1).contiguous().view(-1, 3)\n            self.gt_reshape  = n_tar.permute(0, 2, 3, 1).contiguous().view(-1, 3)\n            normal_loss      = self.n_crit(self.out_reshape, self.gt_reshape, self.n_flag)\n            self.loss += self.normal_w * normal_loss \n            out_loss['N_loss'] = normal_loss.item()\n        return out_loss\n\n    def backward(self):\n        self.loss.backward()\n\ndef getOptimizer(args, params):\n    args.log.printWrite('=> Using %s solver for optimization' % (args.solver))\n    if args.solver == 'adam':\n        optimizer = torch.optim.Adam(params, args.init_lr, betas=(args.beta_1, args.beta_2))\n    elif args.solver == 'sgd':\n        optimizer = torch.optim.SGD(params, args.init_lr, momentum=args.momentum)\n    else:\n        raise Exception(\"=> Unknown Optimizer %s\" % (args.solver))\n    return optimizer\n\ndef getLrScheduler(args, optimizer):\n    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, \n            milestones=args.milestones, gamma=args.lr_decay, last_epoch=args.start_epoch-2)\n    return scheduler\n\ndef loadRecords(path, model, optimizer):\n    records = None\n    if os.path.isfile(path):\n        records = torch.load(path[:-8] + '_rec' + path[-8:])\n        optimizer.load_state_dict(records['optimizer'])\n        start_epoch = records['epoch'] + 1\n        records = 
records['records']\n        print(\"=> loaded Records\")\n    else:\n        raise Exception(\"=> no checkpoint found at '{}'\".format(path))\n    return records, start_epoch\n\ndef configOptimizer(args, model):\n    records = None\n    optimizer = getOptimizer(args, model.parameters())\n    if args.resume:\n        args.log.printWrite(\"=> Resume loading checkpoint '{}'\".format(args.resume))\n        records, start_epoch = loadRecords(args.resume, model, optimizer)\n        args.start_epoch = start_epoch\n    scheduler = getLrScheduler(args, optimizer)\n    return optimizer, scheduler, records\n"
  },
  {
    "path": "options/__init__.py",
    "content": ""
  },
  {
    "path": "options/base_opts.py",
"content": "import argparse\nimport os\nimport torch\n\nclass BaseOpts(object):\n    def __init__(self):\n        self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)\n\n    def initialize(self):\n        #### Training Dataset ####\n        self.parser.add_argument('--dataset',     default='UPS_Synth_Dataset')\n        self.parser.add_argument('--data_dir',    default='data/datasets/PS_Blobby_Dataset')\n        self.parser.add_argument('--data_dir2',   default='data/datasets/PS_Sculpture_Dataset')\n        self.parser.add_argument('--concat_data', default=True, action='store_false')\n        self.parser.add_argument('--l_suffix',    default='_mtrl.txt')\n\n        #### Training Data and Preprocessing Arguments ####\n        self.parser.add_argument('--rescale',     default=True,  action='store_false')\n        self.parser.add_argument('--rand_sc',     default=True,  action='store_false')\n        self.parser.add_argument('--scale_h',     default=128,   type=int)\n        self.parser.add_argument('--scale_w',     default=128,   type=int)\n        self.parser.add_argument('--crop',        default=True,  action='store_false')\n        self.parser.add_argument('--crop_h',      default=128,   type=int)\n        self.parser.add_argument('--crop_w',      default=128,   type=int)\n        self.parser.add_argument('--test_h',      default=128,   type=int)\n        self.parser.add_argument('--test_w',      default=128,   type=int)\n        self.parser.add_argument('--test_resc',   default=True,  action='store_false')\n        self.parser.add_argument('--int_aug',     default=True,  action='store_false')\n        self.parser.add_argument('--noise_aug',   default=True,  action='store_false')\n        self.parser.add_argument('--noise',       default=0.05,  type=float)\n        self.parser.add_argument('--color_aug',   default=True,  action='store_false')\n        self.parser.add_argument('--color_ratio', default=3,     
type=float)\n        self.parser.add_argument('--normalize',   default=False, action='store_true')\n\n        #### Device Arguments ####\n        self.parser.add_argument('--cuda',        default=True,  action='store_false')\n        self.parser.add_argument('--multi_gpu',   default=False, action='store_true')\n        self.parser.add_argument('--time_sync',   default=False, action='store_true')\n        self.parser.add_argument('--workers',     default=8,     type=int)\n        self.parser.add_argument('--seed',        default=0,     type=int)\n\n        #### Stage 1 Model Arguments ####\n        self.parser.add_argument('--dirs_cls',    default=36,    type=int)\n        self.parser.add_argument('--ints_cls',    default=20,    type=int)\n        self.parser.add_argument('--dir_int',     default=False, action='store_true')\n        self.parser.add_argument('--model',       default='LCNet')\n        self.parser.add_argument('--fuse_type',   default='max')\n        self.parser.add_argument('--in_img_num',  default=32,    type=int)\n        self.parser.add_argument('--s1_est_n',    default=False, action='store_true')\n        self.parser.add_argument('--s1_est_d',    default=True,  action='store_false')\n        self.parser.add_argument('--s1_est_i',    default=True,  action='store_false')\n        self.parser.add_argument('--in_light',    default=False, action='store_true')\n        self.parser.add_argument('--in_mask',     default=True,  action='store_false')\n        self.parser.add_argument('--use_BN',      default=False, action='store_true')\n        self.parser.add_argument('--resume',      default=None)\n        self.parser.add_argument('--retrain',     default=None)\n        self.parser.add_argument('--save_intv',   default=1,     type=int)\n\n        #### Stage 2 Model Arguments ####\n        self.parser.add_argument('--stage2',      default=False, action='store_true')\n        self.parser.add_argument('--model_s2',    default='NENet')\n        
self.parser.add_argument('--retrain_s2',  default=None)\n        self.parser.add_argument('--s2_est_n',    default=True,  action='store_false')\n        self.parser.add_argument('--s2_est_i',    default=False, action='store_true')\n        self.parser.add_argument('--s2_est_d',    default=False, action='store_true')\n        self.parser.add_argument('--s2_in_light', default=True,  action='store_false')\n\n        #### Displaying Arguments ####\n        self.parser.add_argument('--train_disp',    default=20,  type=int)\n        self.parser.add_argument('--train_save',    default=200, type=int)\n        self.parser.add_argument('--val_intv',      default=1,   type=int)\n        self.parser.add_argument('--val_disp',      default=1,   type=int)\n        self.parser.add_argument('--val_save',      default=1,   type=int)\n        self.parser.add_argument('--max_train_iter',default=-1,  type=int)\n        self.parser.add_argument('--max_val_iter',  default=-1,  type=int)\n        self.parser.add_argument('--max_test_iter', default=-1,  type=int)\n        self.parser.add_argument('--train_save_n',  default=4,   type=int)\n        self.parser.add_argument('--test_save_n',   default=4,   type=int)\n\n        #### Log Arguments ####\n        self.parser.add_argument('--save_root',  default='data/logdir/')\n        self.parser.add_argument('--item',       default='CVPR2019')\n        self.parser.add_argument('--suffix',     default=None)\n        self.parser.add_argument('--debug',      default=False, action='store_true')\n        self.parser.add_argument('--make_dir',   default=True,  action='store_false')\n        self.parser.add_argument('--save_split', default=False, action='store_true')\n\n    def setDefault(self):\n        if self.args.debug:\n            self.args.train_disp = 1\n            self.args.train_save = 1\n            self.args.max_train_iter = 4 \n            self.args.max_val_iter = 4\n            self.args.max_test_iter = 4\n            
self.args.test_intv = 1\n\n    def collectInfo(self):\n        self.args.str_keys  = [\n                'model', 'fuse_type', 'solver'\n                ]\n        self.args.val_keys  = [\n                'batch', 'scale_h', 'crop_h', 'init_lr', 'normal_w', \n                'dir_w', 'ints_w', 'in_img_num', 'dirs_cls', 'ints_cls'\n                ]\n        self.args.bool_keys = [\n                'use_BN', 'in_light', 'in_mask', 's1_est_n', 's1_est_d', 's1_est_i', \n                'color_aug', 'int_aug', 'concat_data', 'retrain', 'resume', 'stage2', \n                ] \n\n    def parse(self):\n        self.args = self.parser.parse_args()\n        return self.args\n"
  },
  {
    "path": "options/run_model_opts.py",
    "content": "from .base_opts import BaseOpts\nclass RunModelOpts(BaseOpts):\n    def __init__(self):\n        super(RunModelOpts, self).__init__()\n        self.initialize()\n\n    def initialize(self):\n        BaseOpts.initialize(self)\n        #### Testing Dataset Arguments #### \n        self.parser.add_argument('--run_model',  default=True, action='store_false')\n        self.parser.add_argument('--benchmark',  default='UPS_DiLiGenT_main')\n        self.parser.add_argument('--bm_dir',     default='data/datasets/DiLiGenT/pmsData_crop')\n        self.parser.add_argument('--epochs',     default=1,   type=int)\n        self.parser.add_argument('--test_batch', default=1,   type=int)\n        self.parser.add_argument('--test_disp',  default=1,   type=int)\n        self.parser.add_argument('--test_save',  default=1,   type=int)\n        \n        ### For UPS_Custom_Datast.py to test you own dataset\n        self.parser.add_argument('--have_l_dirs', default=False, action='store_true', help='Have light directions?')\n        self.parser.add_argument('--have_l_ints', default=False, action='store_true', help='Have light intensities?')\n        self.parser.add_argument('--have_gt_n',   default=False, action='store_true', help='Have GT surface normals?')\n\n    def collectInfo(self):\n        self.args.str_keys  = ['model', 'model_s2', 'benchmark', 'fuse_type']\n        self.args.val_keys  = ['in_img_num', 'test_h', 'test_w']\n        self.args.bool_keys = ['int_aug', 'test_resc']\n\n    def setDefault(self):\n        self.collectInfo()\n\n    def parse(self):\n        BaseOpts.parse(self)\n        self.setDefault()\n        return self.args\n"
  },
  {
    "path": "options/stage1_opts.py",
    "content": "from .base_opts import BaseOpts\nclass TrainOpts(BaseOpts):\n    def __init__(self):\n        super(TrainOpts, self).__init__()\n        self.initialize()\n\n    def initialize(self):\n        BaseOpts.initialize(self)\n        #### Training Arguments ####\n        self.parser.add_argument('--solver',      default='adam', help='adam|sgd')\n        self.parser.add_argument('--milestones',  default=[5, 10, 15, 20, 25], nargs='+', type=int)\n        self.parser.add_argument('--start_epoch', default=1,      type=int)\n        self.parser.add_argument('--epochs',      default=20,     type=int)\n        self.parser.add_argument('--batch',       default=32,     type=int)\n        self.parser.add_argument('--val_batch',   default=8,      type=int)\n        self.parser.add_argument('--init_lr',     default=0.0005, type=float)\n        self.parser.add_argument('--lr_decay',    default=0.5,    type=float)\n        self.parser.add_argument('--beta_1',      default=0.9,    type=float, help='adam')\n        self.parser.add_argument('--beta_2',      default=0.999,  type=float, help='adam')\n        self.parser.add_argument('--momentum',    default=0.9,    type=float, help='sgd')\n        self.parser.add_argument('--w_decay',     default=4e-4,   type=float)\n\n        #### Loss Arguments ####\n        self.parser.add_argument('--normal_loss', default='cos',  help='cos|mse')\n        self.parser.add_argument('--normal_w',    default=1,      type=float)\n        self.parser.add_argument('--dir_loss',    default='cos',  help='cos|mse')\n        self.parser.add_argument('--dir_w',       default=1,      type=float)\n        self.parser.add_argument('--ints_loss',   default='mse',  help='l1|mse')\n        self.parser.add_argument('--ints_w',      default=1,      type=float)\n\n    def collectInfo(self): \n        BaseOpts.collectInfo(self)\n        self.args.str_keys  += [\n                'dir_loss',\n                ]\n        self.args.val_keys  += [\n                
]\n        self.args.bool_keys += [\n                ] \n\n    def setDefault(self):\n        BaseOpts.setDefault(self)\n        if self.args.test_h != self.args.crop_h:\n            self.args.test_h, self.args.test_w = self.args.crop_h, self.args.crop_w\n        self.collectInfo()\n\n    def parse(self):\n        BaseOpts.parse(self)\n        self.setDefault()\n        return self.args\n"
  },
  {
    "path": "options/stage2_opts.py",
    "content": "from .base_opts import BaseOpts\nclass TrainOpts(BaseOpts):\n    def __init__(self):\n        super(TrainOpts, self).__init__()\n        self.initialize()\n\n    def initialize(self):\n        BaseOpts.initialize(self)\n        #### Training Arguments ####\n        self.parser.add_argument('--solver',      default='adam', help='adam|sgd')\n        self.parser.add_argument('--milestones',  default=[2, 4, 6, 8, 10], nargs='+', type=int)\n        self.parser.add_argument('--start_epoch', default=1,      type=int)\n        self.parser.add_argument('--epochs',      default=10,     type=int)\n        self.parser.add_argument('--batch',       default=16,     type=int)\n        self.parser.add_argument('--val_batch',   default=8,      type=int)\n        self.parser.add_argument('--init_lr',     default=0.0005, type=float)\n        self.parser.add_argument('--lr_decay',    default=0.5,    type=float)\n        self.parser.add_argument('--beta_1',      default=0.9,    type=float, help='adam')\n        self.parser.add_argument('--beta_2',      default=0.999,  type=float, help='adam')\n        self.parser.add_argument('--momentum',    default=0.9,    type=float, help='sgd')\n        self.parser.add_argument('--w_decay',     default=4e-4,   type=float)\n\n        #### Loss Arguments ####\n        self.parser.add_argument('--normal_loss', default='cos',  help='cos|mse')\n        self.parser.add_argument('--normal_w',    default=1,      type=float)\n        self.parser.add_argument('--dir_loss',    default='mse',  help='cos|mse')\n        self.parser.add_argument('--dir_w',       default=1,      type=float)\n        self.parser.add_argument('--ints_loss',   default='mse',  help='mse')\n        self.parser.add_argument('--ints_w',      default=1,      type=float)\n\n    def collectInfo(self): \n        BaseOpts.collectInfo(self)\n        self.args.str_keys += [\n                'model_s2'\n                ]\n        self.args.val_keys  += [\n                ]\n      
  self.args.bool_keys += [\n                ]\n\n    def setDefault(self):\n        BaseOpts.setDefault(self)\n        self.args.stage2    = True\n        self.args.test_resc = False\n        self.collectInfo()\n\n    def parse(self):\n        BaseOpts.parse(self)\n        self.setDefault()\n        return self.args\n"
  },
  {
    "path": "scripts/DiLiGenT_objects.txt",
    "content": "ballPNG\ncatPNG\npot1PNG\nbearPNG\npot2PNG\nbuddhaPNG\ngobletPNG\nreadingPNG\ncowPNG\nharvestPNG\n"
  },
  {
    "path": "scripts/cropDiLiGenTData.py",
    "content": "import os, argparse, sys, shutil, glob\nimport numpy as np\nfrom imageio import imread, imsave\nimport scipy.io as sio\n\nroot_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '../')\nsys.path.append(root_path)\nfrom utils import utils\n\nparser = argparse.ArgumentParser()\nparser.add_argument('--input_dir',  default='data/datasets/DiLiGenT/pmsData')\nparser.add_argument('--obj_list',   default='objects.txt')\nparser.add_argument('--suffix',     default='crop')\nparser.add_argument('--file_ext',   default='.png')\nparser.add_argument('--normal_name',default='Normal_gt.png')\nparser.add_argument('--mask_name',  default='mask.png')\nparser.add_argument('--n_key',      default='Normal_gt')\nparser.add_argument('--pad',        default=15, type=int)\nargs = parser.parse_args()\n\ndef getSaveDir():\n    dirName  = os.path.dirname(args.input_dir)\n    save_dir = '%s_%s' % (args.input_dir, args.suffix) \n    utils.makeFile(save_dir)\n    print('Output dir: %s\\n' % save_dir)\n    return save_dir\n\ndef getBBoxCompact(mask):\n    index = np.where(mask != 0)\n    t, b, l , r = index[0].min(), index[0].max(), index[1].min(), index[1].max()\n    h, w = b - t + 2 * args.pad, r - l + 2 * args.pad\n    t = max(0, t - args.pad)\n    b = t + h \n    l = max(0, l - args.pad)\n    r = l + w \n    if h % 4 != 0: \n        pad = 4 - h % 4\n        b += pad; h += pad\n    if w % 4 != 0: \n        pad = 4 - w % 4\n        r += pad; w += pad\n    return l, r, t, b, h, w\n\ndef loadMaskNormal(d):\n    mask   = imread(os.path.join(args.input_dir, d, args.mask_name))\n    try:\n        normal = imread(os.path.join(args.input_dir, d, args.normal_name))\n    except IOError:\n        normal = imread(os.path.join(args.input_dir, d, 'normal_gt.png'))\n    n_mat  = sio.loadmat(os.path.join(args.input_dir, d, 'Normal_gt.mat'))[args.n_key]\n    h, w, c = normal.shape\n    print('Processing Objects: %s' % d, mask.shape)\n    if mask.ndim < 3:\n        mask = 
mask.reshape(h, w, 1).repeat(3, 2)\n    return mask, normal, n_mat\n\ndef copyTXT(d):\n    txt = glob.glob(os.path.join(args.input_dir, '*.txt'))\n    for t in txt:\n        name = os.path.basename(t)\n        shutil.copy(t, os.path.join(args.save_dir, name))\n\n    txt = glob.glob(os.path.join(args.input_dir, d, '*.txt'))\n    for t in txt:\n        name = os.path.basename(t)\n        shutil.copy(t, os.path.join(args.save_dir, d, name))\n\nif __name__ == '__main__':\n    print('Input dir: %s\\n' % args.input_dir)\n    args.save_dir = getSaveDir()\n\n    dir_list  = utils.readList(os.path.join(args.input_dir, args.obj_list))\n    name_list = utils.readList(os.path.join(args.input_dir, 'names.txt'))\n    max_h, max_w = 0, 0\n    crop_list = open(os.path.join(args.save_dir, 'crop.txt'), 'w')\n    for d in dir_list:\n        utils.makeFile(os.path.join(args.save_dir, d))\n        mask, normal, n_mat = loadMaskNormal(d)\n        l, r, t, b, h, w = getBBoxCompact(mask[:,:,0] / 255)\n        crop_list.write('%d %d %d %d %d %d\\n' % (mask.shape[0], mask.shape[1], l, r, t, b))\n\n        max_h = h if h > max_h else max_h\n        max_w = w if w > max_w else max_w\n        print('\\t BBox L %d R %d T %d B %d, H:%d W:%d, Padded: %d %d' % \n                (l, r, t, b, h, w, r - l, b - t)) \n        imsave(os.path.join(args.save_dir, d, args.mask_name), mask[t:b, l:r, :])\n        imsave(os.path.join(args.save_dir, d, args.normal_name), normal[t:b, l:r, :])\n        sio.savemat(os.path.join(args.save_dir, d, 'Normal_gt.mat'), \n                {args.n_key: n_mat[t:b, l:r, :]} ,do_compression=True)\n        copyTXT(d)\n        intens = np.genfromtxt(os.path.join(args.input_dir, d, 'light_intensities.txt'))\n        for idx, name in enumerate(name_list):\n            #print('Process img %d/%d' % (idx+1, len(name_list)))\n            img = imread(os.path.join(args.input_dir, d, name))\n            img = img[t:b, l:r, :]\n            imsave(os.path.join(args.save_dir, d, name), 
img.astype(np.uint8))\n    print('Max H %d, Max W %d' % (max_h, max_w))\n    crop_list.close()\n"
  },
  {
    "path": "scripts/download_pretrained_models.sh",
    "content": "path=\"data/models/\"\nmkdir -p $path\ncd $path\n\n# Download pre-trained model\nfor model in \"LCNet_CVPR2019.pth.tar\" \"NENet_CVPR2019.pth.tar\"; do\n    wget http://www.visionlab.cs.hku.hk/data/SDPS-Net/models/${model}\ndone\n\n# Back to root directory\ncd ../\n"
  },
  {
    "path": "scripts/download_synthetic_datasets.sh",
    "content": "mkdir -p data/datasets\ncd data/datasets\n\n# Download Synthetic dataset\nfor dataset in \"PS_Sculpture_Dataset.tgz\" \"PS_Blobby_Dataset.tgz\"; do\n    echo \"Downloading $dataset\"\n    wget http://www.visionlab.cs.hku.hk/data/PS-FCN/datasets/$dataset\n    tar -xvf $dataset\n    rm $dataset\ndone\n\n# Back to root directory\ncd ../../\n\n"
  },
  {
    "path": "scripts/prepare_diligent_dataset.sh",
    "content": "mkdir -p data/datasets\ncd data/datasets\n\n## Download real testing dataset\nurl=\"https://www.dropbox.com/s/hdnbh526tyvv68i/DiLiGenT.zip?dl=0\"\nname=\"DiLiGenT\"\n\nwget $url -O ${name}.zip\nunzip ${name}.zip\nrm ${name}.zip\n\ncd ${name}/pmsData/\ncp ballPNG/filenames.txt names.txt\n\n# Back to root directory\ncd ../../../../\ncp scripts/DiLiGenT_objects.txt data/datasets/${name}/pmsData/objects.txt\npython scripts/cropDiLiGenTData.py\n"
  },
  {
    "path": "test_stage1.py",
    "content": "import os\nimport torch\nfrom models import model_utils\nfrom utils import eval_utils, time_utils \nimport numpy as np\n\ndef get_itervals(args, split):\n    if split not in ['train', 'val', 'test']:\n        split = 'test'\n    args_var = vars(args)\n    disp_intv = args_var[split+'_disp']\n    save_intv = args_var[split+'_save']\n    stop_iters = args_var['max_'+split+'_iter']\n    return disp_intv, save_intv, stop_iters\n\ndef test(args, split, loader, model, log, epoch, recorder):\n    model.eval()\n    log.printWrite('---- Start %s Epoch %d: %d batches ----' % (split, epoch, len(loader)))\n    timer = time_utils.Timer(args.time_sync);\n\n    disp_intv, save_intv, stop_iters = get_itervals(args, split)\n    res = []\n    with torch.no_grad():\n        for i, sample in enumerate(loader):\n            data = model_utils.parseData(args, sample, timer, split)\n            input = model_utils.getInput(args, data)\n\n            pred = model(input); timer.updateTime('Forward')\n\n            recoder, iter_res, error = prepareRes(args, data, pred, recorder, log, split)\n\n            res.append(iter_res)\n            iters = i + 1\n            if iters % disp_intv == 0:\n                opt = {'split':split, 'epoch':epoch, 'iters':iters, 'batch':len(loader), \n                        'timer':timer, 'recorder': recorder}\n                log.printItersSummary(opt)\n\n            if iters % save_intv == 0:\n                results, nrow = prepareSave(args, data, pred)\n                log.saveImgResults(results, split, epoch, iters, nrow=nrow, error=error)\n                log.plotCurves(recorder, split, epoch=epoch, intv=disp_intv)\n\n            if stop_iters > 0 and iters >= stop_iters: break\n    res = np.vstack([np.array(res), np.array(res).mean(0)])\n    save_name = '%s_res.txt' % (args.suffix)\n    np.savetxt(os.path.join(args.log_dir, split, save_name), res, fmt='%.2f')\n    opt = {'split': split, 'epoch': epoch, 'recorder': recorder}\n    
log.printEpochSummary(opt)\n\ndef prepareRes(args, data, pred, recorder, log, split):\n    data_batch = args.val_batch if split == 'val' else args.test_batch\n    iter_res = []\n    error = ''\n    if args.s1_est_d:\n        l_acc, data['dir_err'] = eval_utils.calDirsAcc(data['dirs'].data, pred['dirs'].data, data_batch)\n        recorder.updateIter(split, l_acc.keys(), l_acc.values())\n        iter_res.append(l_acc['l_err_mean'])\n        error += 'D_%.3f-' % (l_acc['l_err_mean']) \n    if args.s1_est_i:\n        int_acc, data['int_err'] = eval_utils.calIntsAcc(data['ints'].data, pred['intens'].data, data_batch)\n        recorder.updateIter(split, int_acc.keys(), int_acc.values())\n        iter_res.append(int_acc['ints_ratio'])\n        error += 'I_%.3f-' % (int_acc['ints_ratio'])\n\n    if args.s1_est_n:\n        acc, error_map = eval_utils.calNormalAcc(data['n'].data, pred['n'].data, data['m'].data)\n        recorder.updateIter(split, acc.keys(), acc.values())\n        iter_res.append(acc['n_err_mean'])\n        error += 'N_%.3f-' % (acc['n_err_mean'])\n        data['error_map'] = error_map['angular_map']\n\n    return recorder, iter_res, error\n\n\ndef prepareSave(args, data, pred):\n    results = [data['img'].data, data['m'].data, (data['n'].data+1)/2]\n    if args.s1_est_n:\n        pred_n = (pred['n'].data + 1) / 2\n        masked_pred = pred_n * data['m'].data.expand_as(pred['n'].data)\n        res_n = [pred_n, masked_pred, data['error_map']]\n        results += res_n\n\n    nrow = data['img'].shape[0]\n    return results, nrow\n"
  },
  {
    "path": "test_stage2.py",
    "content": "import os\nimport torch\nfrom models import model_utils\nfrom utils import eval_utils, time_utils \nimport numpy as np\n\ndef get_itervals(args, split):\n    if split not in ['train', 'val', 'test']:\n        split = 'test'\n    args_var = vars(args)\n    disp_intv = args_var[split+'_disp']\n    save_intv = args_var[split+'_save']\n    stop_iters = args_var['max_'+split+'_iter']\n    return disp_intv, save_intv, stop_iters\n\ndef test(args, split, loader, models, log, epoch, recorder):\n    models[0].eval()\n    models[1].eval()\n    log.printWrite('---- Start %s Epoch %d: %d batches ----' % (split, epoch, len(loader)))\n    timer = time_utils.Timer(args.time_sync);\n\n    disp_intv, save_intv, stop_iters = get_itervals(args, split)\n    res = []\n    with torch.no_grad():\n        for i, sample in enumerate(loader):\n            data = model_utils.parseData(args, sample, timer, split)\n            input = model_utils.getInput(args, data)\n\n            pred_c = models[0](input); timer.updateTime('Forward')\n            input.append(pred_c)\n            pred = models[1](input); timer.updateTime('Forward')\n\n            recoder, iter_res, error = prepareRes(args, data, pred_c, pred, recorder, log, split)\n\n            res.append(iter_res)\n            iters = i + 1\n            if iters % disp_intv == 0:\n                opt = {'split':split, 'epoch':epoch, 'iters':iters, 'batch':len(loader), \n                        'timer':timer, 'recorder': recorder}\n                log.printItersSummary(opt)\n\n            if iters % save_intv == 0:\n                results, nrow = prepareSave(args, data, pred_c, pred)\n                log.saveImgResults(results, split, epoch, iters, nrow=nrow, error='')\n                log.plotCurves(recorder, split, epoch=epoch, intv=disp_intv)\n\n            if stop_iters > 0 and iters >= stop_iters: break\n    res = np.vstack([np.array(res), np.array(res).mean(0)])\n    save_name = '%s_res.txt' % (args.suffix)\n    
np.savetxt(os.path.join(args.log_dir, split, save_name), res, fmt='%.2f')\n    if res.ndim > 1:\n        for i in range(res.shape[1]):\n            save_name = '%s_%d_res.txt' % (args.suffix, i)\n            np.savetxt(os.path.join(args.log_dir, split, save_name), res[:,i], fmt='%.3f')\n\n    opt = {'split': split, 'epoch': epoch, 'recorder': recorder}\n    log.printEpochSummary(opt)\n\ndef prepareRes(args, data, pred_c, pred, recorder, log, split):\n    data_batch = args.val_batch if split == 'val' else args.test_batch\n    iter_res = []\n    error = ''\n    if args.s1_est_d:\n        l_acc, data['dir_err'] = eval_utils.calDirsAcc(data['dirs'].data, pred_c['dirs'].data, data_batch)\n        recorder.updateIter(split, l_acc.keys(), l_acc.values())\n        iter_res.append(l_acc['l_err_mean'])\n        error += 'D_%.3f-' % (l_acc['l_err_mean']) \n    if args.s1_est_i:\n        int_acc, data['int_err'] = eval_utils.calIntsAcc(data['ints'].data, pred_c['intens'].data, data_batch)\n        recorder.updateIter(split, int_acc.keys(), int_acc.values())\n        iter_res.append(int_acc['ints_ratio'])\n        error += 'I_%.3f-' % (int_acc['ints_ratio'])\n\n    if args.s2_est_n:\n        acc, error_map = eval_utils.calNormalAcc(data['n'].data, pred['n'].data, data['m'].data)\n        recorder.updateIter(split, acc.keys(), acc.values())\n        iter_res.append(acc['n_err_mean'])\n        error += 'N_%.3f-' % (acc['n_err_mean'])\n        data['error_map'] = error_map['angular_map']\n    return recorder, iter_res, error\n\ndef prepareSave(args, data, pred_c, pred):\n    results = [data['img'].data, data['m'].data, (data['n'].data+1) / 2]\n    if args.s2_est_n:\n        pred_n = (pred['n'].data + 1) / 2\n        masked_pred = pred_n * data['m'].data.expand_as(pred['n'].data)\n        res_n = [masked_pred, data['error_map']]\n        results += res_n\n\n    nrow = data['img'].shape[0]\n    return results, nrow\n"
  },
  {
    "path": "train_stage1.py",
    "content": "from models import model_utils\nfrom utils  import eval_utils, time_utils\n\ndef train(args, loader, model, criterion, optimizer, log, epoch, recorder):\n    model.train()\n    log.printWrite('---- Start Training Epoch %d: %d batches ----' % (epoch, len(loader)))\n    timer = time_utils.Timer(args.time_sync);\n\n    for i, sample in enumerate(loader):\n        data = model_utils.parseData(args, sample, timer, 'train')\n        input = model_utils.getInput(args, data)\n\n        pred = model(input); timer.updateTime('Forward')\n\n        optimizer.zero_grad()\n        loss = criterion.forward(pred, data); \n        timer.updateTime('Crit');\n        criterion.backward(); timer.updateTime('Backward')\n\n        recorder.updateIter('train', loss.keys(), loss.values())\n\n        optimizer.step(); timer.updateTime('Solver')\n\n        iters = i + 1\n        if iters % args.train_disp == 0:\n            opt = {'split':'train', 'epoch':epoch, 'iters':iters, 'batch':len(loader), \n                    'timer':timer, 'recorder': recorder}\n            log.printItersSummary(opt)\n\n        if iters % args.train_save == 0:\n            results, recorder, nrow = prepareSave(args, data, pred, recorder, log) \n            log.saveImgResults(results, 'train', epoch, iters, nrow=nrow)\n            log.plotCurves(recorder, 'train', epoch=epoch, intv=args.train_disp)\n\n        if args.max_train_iter > 0 and iters >= args.max_train_iter: break\n    opt = {'split': 'train', 'epoch': epoch, 'recorder': recorder}\n    log.printEpochSummary(opt)\n\ndef prepareSave(args, data, pred, recorder, log):\n    results = [data['img'].data, data['m'].data, (data['n'].data+1)/2]\n    if args.s1_est_d:\n        l_acc, data['dir_err'] = eval_utils.calDirsAcc(data['dirs'].data, pred['dirs'].data, args.batch)\n        recorder.updateIter('train', l_acc.keys(), l_acc.values())\n\n    if args.s1_est_i:\n        int_acc, data['int_err'] = eval_utils.calIntsAcc(data['ints'].data, 
pred['intens'].data, args.batch)\n        recorder.updateIter('train', int_acc.keys(), int_acc.values())\n\n    if args.s1_est_n:\n        acc, error_map = eval_utils.calNormalAcc(data['n'].data, pred['n'].data, data['m'].data)\n        pred_n = (pred['n'].data + 1) / 2\n        masked_pred = pred_n * data['m'].data.expand_as(pred['n'].data)\n        res_n = [pred_n, masked_pred, error_map['angular_map']]\n        results += res_n\n        recorder.updateIter('train', acc.keys(), acc.values())\n    nrow = data['img'].shape[0] if data['img'].shape[0] <= 32 else 32\n    return results, recorder, nrow\n"
  },
  {
    "path": "train_stage2.py",
    "content": "import torch\nfrom models import model_utils\nfrom utils  import eval_utils, time_utils\n\ndef train(args, loader, models, criterion, optimizers, log, epoch, recorder):\n    models[1].train()\n    models[0].eval()\n    optimizer, optimizer_c = optimizers\n    log.printWrite('---- Start Training Epoch %d: %d batches ----' % (epoch, len(loader)))\n    timer = time_utils.Timer(args.time_sync);\n\n    for i, sample in enumerate(loader):\n        data = model_utils.parseData(args, sample, timer, 'train')\n        input = model_utils.getInput(args, data)\n        with torch.no_grad():\n            pred_c = models[0](input); \n        input.append(pred_c)\n        pred = models[1](input); timer.updateTime('Forward')\n\n        optimizer.zero_grad()\n\n        loss = criterion.forward(pred, data); \n        timer.updateTime('Crit');\n        criterion.backward(); timer.updateTime('Backward')\n\n        recorder.updateIter('train', loss.keys(), loss.values())\n\n        optimizer.step(); timer.updateTime('Solver')\n\n        iters = i + 1\n        if iters % args.train_disp == 0:\n            opt = {'split':'train', 'epoch':epoch, 'iters':iters, 'batch':len(loader), \n                    'timer':timer, 'recorder': recorder}\n            log.printItersSummary(opt)\n\n        if iters % args.train_save == 0:\n            results, recorder, nrow = prepareSave(args, data, pred_c, pred, recorder, log) \n            log.saveImgResults(results, 'train', epoch, iters, nrow=nrow)\n            log.plotCurves(recorder, 'train', epoch=epoch, intv=args.train_disp)\n\n        if args.max_train_iter > 0 and iters >= args.max_train_iter: break\n    opt = {'split': 'train', 'epoch': epoch, 'recorder': recorder}\n    log.printEpochSummary(opt)\n\ndef prepareSave(args, data, pred_c, pred, recorder, log):\n    input_var, mask_var = data['img'], data['m']\n    results = [input_var.data, mask_var.data, (data['n'].data+1)/2]\n    if args.s1_est_d:\n        l_acc, data['dir_err'] = 
eval_utils.calDirsAcc(data['dirs'].data, pred_c['dirs'].data, args.batch)\n        recorder.updateIter('train', l_acc.keys(), l_acc.values())\n    if args.s1_est_i:\n        int_acc, data['int_err'] = eval_utils.calIntsAcc(data['ints'].data, pred_c['intens'].data, args.batch)\n        recorder.updateIter('train', int_acc.keys(), int_acc.values())\n\n    if args.s2_est_n:\n        acc, error_map = eval_utils.calNormalAcc(data['n'].data, pred['n'].data, mask_var.data)\n        pred_n = (pred['n'].data + 1) / 2\n        masked_pred = pred_n * mask_var.data.expand_as(pred['n'].data)\n        res_n = [masked_pred, error_map['angular_map']]\n        results += res_n\n        recorder.updateIter('train', acc.keys(), acc.values())\n\n    nrow = input_var.shape[0] if input_var.shape[0] <= 32 else 32\n    return results, recorder, nrow\n"
  },
  {
    "path": "utils/__init__.py",
    "content": ""
  },
  {
    "path": "utils/eval_utils.py",
    "content": "import torch\nimport math\nimport numpy as np\nfrom matplotlib import cm\n\ndef colorMap(diff):\n    thres = 90\n    diff_norm = np.clip(diff, 0, thres) / thres\n    diff_cm = torch.from_numpy(cm.jet(diff_norm.numpy()))[:,:,:, :3]\n    return diff_cm.permute(0,3,1,2).clone().float()\n\ndef calDirsAcc(gt_l, pred_l, data_batch=1):\n    n, c = gt_l.shape\n    pred_l = pred_l.view(n, c)\n    dot_product = (gt_l * pred_l).sum(1).clamp(-1, 1)\n    \n    angular_err = torch.acos(dot_product) * 180.0 / math.pi\n    l_err_mean  = angular_err.mean()\n    return {'l_err_mean': l_err_mean.item()}, angular_err.squeeze()\n\ndef calIntsAcc(gt_i, pred_i, data_batch=1):\n    n, c, h, w = gt_i.shape\n    pred_i  = pred_i.view(n, c, h, w)\n    ref_int = gt_i[:, :3].repeat(1, gt_i.shape[1] // 3, 1, 1)\n    gt_i  = gt_i / ref_int\n    scale = torch.gels(gt_i.view(-1, 1), pred_i.view(-1, 1))\n    ints_ratio = (gt_i - scale[0][0] * pred_i).abs() / (gt_i + 1e-8)\n    ints_error = torch.stack(ints_ratio.split(3, 1), 1).mean(2)\n    return {'ints_ratio': ints_ratio.mean().item()}, ints_error.squeeze()\n    \ndef calNormalAcc(gt_n, pred_n, mask=None):\n    \"\"\"Tensor Dim: NxCxHxW\"\"\"\n    dot_product = (gt_n * pred_n).sum(1).clamp(-1,1)\n    error_map   = torch.acos(dot_product) # [-pi, pi]\n    angular_map = error_map * 180.0 / math.pi\n    angular_map = angular_map * mask.narrow(1, 0, 1).squeeze(1)\n\n    valid = mask.narrow(1, 0, 1).sum()\n    ang_valid  = angular_map[mask.narrow(1, 0, 1).squeeze(1).byte()]\n    n_err_mean = ang_valid.sum() / valid\n    n_err_med  = ang_valid.median()\n    n_acc_11   = (ang_valid < 11.25).sum().float() / valid\n    n_acc_30   = (ang_valid < 30).sum().float() / valid\n    n_acc_45   = (ang_valid < 45).sum().float() / valid\n\n    angular_map = colorMap(angular_map.cpu().squeeze(1))\n    value = {'n_err_mean': n_err_mean.item(), \n            'n_acc_11': n_acc_11.item(), 'n_acc_30': n_acc_30.item(), 'n_acc_45': n_acc_45.item()}\n    
angular_error_map = {'angular_map': angular_map}\n    return value, angular_error_map\n\ndef SphericalDirsToClass(dirs, cls_num):\n    theta = torch.atan(dirs[:,0] / (dirs[:,2] + 1e-8)) \n    denom = torch.sqrt(dirs[:,0] * dirs[:,0] + dirs[:,2] * dirs[:,2])\n    phi = torch.atan(dirs[:,1] / (denom + 1e-8))\n    theta = theta / np.pi * 180\n    phi   = phi / np.pi * 180\n    azimuth = ((theta + 90.0) / 180 * cls_num).clamp(0, cls_num-1).long()\n    elevate = ((phi   + 90.0) / 180 * cls_num).clamp(0, cls_num-1).long()\n    return azimuth, elevate\n\ndef SphericalClassToDirs(x_cls, y_cls, cls_num):\n    theta = (x_cls.float() + 0.5) / cls_num * 180 - 90\n    phi   = (y_cls.float() + 0.5) / cls_num * 180 - 90\n    neg_x = theta < 0\n    neg_y = phi < 0\n    theta = theta.clamp(-90, 90) / 180.0 * np.pi\n    phi   = phi.clamp(-90, 90) / 180.0 * np.pi\n\n    tan2_phi   = pow(torch.tan(phi), 2)\n    tan2_theta = pow(torch.tan(theta), 2)\n    y = torch.sqrt(tan2_phi / (1 + tan2_phi))\n    y[neg_y] = y[neg_y] * -1\n    #y = torch.sin(phi)\n    z = torch.sqrt((1 - y * y) / (1 + tan2_theta))\n    x = z * torch.tan(theta)\n    dirs = torch.stack([x,y,z], 1)\n    dirs = dirs / dirs.norm(p=2, dim=1, keepdim=True)\n    return dirs\n    \ndef LightIntsToClass(ints, cls_num):\n    ints = (ints - 0.2) / 1.8\n    ints = (ints * cls_num).clamp(0, cls_num-1).long()\n    return ints.view(-1)\n\ndef ClassToLightInts(cls, cls_num):\n    ints = (cls.float() + 0.5) / cls_num * 1.8 + 0.2\n    ints = ints.clamp(0.2 , 2.0)\n    return ints\n"
  },
  {
    "path": "utils/logger.py",
    "content": "import datetime, time, os\nimport numpy as np\nimport torch\nimport torchvision.utils as vutils\nimport scipy.io as sio\nfrom . import utils\n\nimport matplotlib; matplotlib.use('agg')\nimport matplotlib.pyplot as plt\nfrom matplotlib import cm\nfrom matplotlib.font_manager import FontProperties\nfontP = FontProperties()\nfontP.set_size('small')\nplt.rcParams[\"figure.figsize\"] = (5,8)\n\nclass Logger(object):\n    def __init__(self, args):\n        self.times = {'init': time.time()}\n        if args.make_dir:\n            self._setupDirs(args)\n        self.args = args\n        args.log  = self\n        self.printArgs()\n\n    def printArgs(self):\n        strs = '------------ Options -------------\\n'\n        strs += '{}'.format(utils.dictToString(vars(self.args)))\n        strs += '-------------- End ----------------\\n'\n        self.printWrite(strs)\n\n    def _addArguments(self, args):\n        info = ''\n        if hasattr(args, 'run_model') and args.run_model:\n            info += '_run_model,%s' % os.path.basename(args.retrain).split('.')[0]\n        arg_var  = vars(args)\n        for k in args.str_keys:  \n            info = '{0},{1}'.format(info, arg_var[k])\n        for k in args.val_keys:  \n            var_key = k[:2] + '_' + k[-1]\n            info = '{0},{1}-{2}'.format(info, var_key, arg_var[k])\n        for k in args.bool_keys: \n            info = '{0},{1}'.format(info, k) if arg_var[k] else info \n        return info\n\n    def _setupDirs(self, args):\n        date_now = datetime.datetime.now()\n        dir_name = '%d-%d' % (date_now.month, date_now.day)\n        dir_name += (',%s' % args.suffix) if args.suffix else ''\n        dir_name += self._addArguments(args) \n        dir_name += ',DEBUG' if args.debug else ''\n\n        self._checkPath(args, dir_name)\n        file_dir = os.path.join(args.log_dir, '%s,%s' % (dir_name, date_now.strftime('%H:%M:%S')))\n        self.log_fie = open(file_dir, 'w')\n        return \n\n    def 
_checkPath(self, args, dir_name):\n        if hasattr(args, 'run_model') and args.run_model:\n            log_root = os.path.join(os.path.dirname(args.retrain), dir_name)\n            args.log_dir = log_root\n            sub_dirs = ['test']\n        else:\n            if args.resume and os.path.isfile(args.resume):\n                log_root = os.path.join(os.path.dirname(os.path.dirname(args.resume)), dir_name)\n            else:\n                if args.debug:\n                    dir_name = 'DEBUG/' + dir_name\n                log_root = os.path.join(args.save_root, args.dataset, args.item, dir_name)\n            args.log_dir = os.path.join(log_root, 'logdir')\n            args.cp_dir  = os.path.join(log_root, 'checkpointdir')\n            utils.makeFiles([args.log_dir, args.cp_dir])\n            sub_dirs = ['train', 'val']\n        for sub_dir in sub_dirs:\n            utils.makeFiles([os.path.join(args.log_dir, sub_dir, 'Images')])\n\n    def printWrite(self, strs):\n        print(strs)\n        if self.args.make_dir:\n            self.log_fie.write('%s\\n' % strs)\n            self.log_fie.flush()\n\n    def getTimeInfo(self, epoch, iters, batch):\n        time_elapsed = (time.time() - self.times['init']) / 3600.0\n        total_iters  = (self.args.epochs - self.args.start_epoch + 1) * batch\n        cur_iters    = (epoch - self.args.start_epoch) * batch + iters\n        time_total   = time_elapsed * (float(total_iters) / cur_iters)\n        return time_elapsed, time_total\n\n    def printItersSummary(self, opt):\n        epoch, iters, batch = opt['epoch'], opt['iters'], opt['batch']\n        strs = ' | {}'.format(opt['split'].upper())\n        strs += ' Iter [{}/{}] Epoch [{}/{}]'.format(iters, batch, epoch, self.args.epochs)\n        if opt['split'] == 'train':\n            time_elapsed, time_total = self.getTimeInfo(epoch, iters, batch) # Buggy for test\n            strs += ' Clock [{:.2f}h/{:.2f}h]'.format(time_elapsed, time_total)\n            strs += ' LR [{}]'.format(opt['recorder'].records[opt['split']]['lr'][epoch][0])\n        self.printWrite(strs)\n        if 'timer' in opt:\n            self.printWrite(opt['timer'].timeToString())\n        if 'recorder' in opt:\n            self.printWrite(opt['recorder'].iterRecToString(opt['split'], epoch))\n\n    def printEpochSummary(self, opt):\n        split = opt['split']\n        epoch = opt['epoch']\n        self.printWrite('---------- {} Epoch {} Summary -----------'.format(split.upper(), epoch))\n        self.printWrite(opt['recorder'].epochRecToString(split, epoch))\n\n    def convertToSameSize(self, t_list):\n        shape = (t_list[0].shape[0], 3, t_list[0].shape[2], t_list[0].shape[3])\n        for i, tensor in enumerate(t_list):\n            n, c, h, w = tensor.shape\n            if tensor.shape[1] != shape[1]: # expand channel dim to 3\n                t_list[i] = tensor.expand((n, shape[1], h, w))\n            if h == shape[2] and w == shape[3]:\n                continue\n            # resize the (possibly channel-expanded) tensor, not the original\n            t_list[i] = torch.nn.functional.interpolate(t_list[i], [shape[2], shape[3]], mode='bilinear', align_corners=False)\n        return t_list\n\n    def getSaveDir(self, split, epoch):\n        save_dir = os.path.join(self.args.log_dir, split, 'Images')\n        run_model = hasattr(self.args, 'run_model') and self.args.run_model\n        if not run_model and epoch > 0:\n            save_dir = os.path.join(save_dir, str(epoch))\n        utils.makeFile(save_dir)\n        return save_dir\n\n    def splitMultiChannel(self, t_list, max_save_n=8):\n        new_list = []\n        for tensor in t_list:\n            if tensor.shape[1] > 3:\n                new_list += torch.split(tensor, 3, 1)[:max_save_n]\n            else:\n                new_list.append(tensor)\n        return new_list\n\n    def saveSplit(self, res, save_prefix):\n        n, c, h, w = res.shape\n        for i in range(n):\n            vutils.save_image(res[i], save_prefix + '_%d.png' % i)\n\n    def saveImgResults(self, results, split, epoch, iters, nrow, error=''):\n        max_save_n = self.args.test_save_n if split == 'test' else self.args.train_save_n\n        res = [img.cpu() for img in results]\n        res = self.splitMultiChannel(res, max_save_n)\n        res = torch.cat(self.convertToSameSize(res))\n        save_dir = self.getSaveDir(split, epoch)\n        save_prefix = os.path.join(save_dir, '%d_%d' % (epoch, iters))\n        save_prefix += ('_%s' % error) if error != '' else ''\n        if self.args.save_split:\n            self.saveSplit(res, save_prefix)\n        else:\n            vutils.save_image(res, save_prefix + '_out.png', nrow=nrow)\n\n    def plotCurves(self, recorder, split='train', epoch=-1, intv=1):\n        dict_of_array = recorder.recordToDictOfArray(split, epoch, intv)\n        save_dir = os.path.join(self.args.log_dir, split)\n        if epoch < 0:\n            save_dir = self.args.log_dir\n            save_name = '%s_Summary.png' % split\n        else:\n            save_name = '%s_epoch_%d.png' % (split, epoch)\n\n        classes = ['loss', 'acc', 'err', 'lr', 'ratio']\n        classes = utils.checkIfInList(classes, dict_of_array.keys())\n        if len(classes) == 0: return\n\n        for idx, c in enumerate(classes):\n            plt.subplot(len(classes), 1, idx+1)\n            plt.grid()\n            legends = []\n            for k in dict_of_array.keys():\n                if (c in k.lower()) and not k.endswith('_x'):\n                    plt.plot(dict_of_array[k+'_x'], dict_of_array[k])\n                    legends.append(k)\n            if len(legends) != 0:\n                plt.legend(legends, bbox_to_anchor=(0.5, 1.05), loc='upper center',\n                            ncol=len(legends), prop=fontP)\n                plt.title(c)\n                if epoch < 0: plt.xlabel('Epoch')\n                else: plt.xlabel('Iters')\n        plt.tight_layout()\n        plt.savefig(os.path.join(save_dir, save_name))\n        plt.clf()\n"
  },
  {
    "path": "utils/recorders.py",
    "content": "from collections import OrderedDict\nimport numpy as np\n\nclass Records(object):\n    \"\"\"\n    Records->Train,Val->Loss,Accuracy->Epoch1,2,3->[v1,v2]\n    IterRecords->Train,Val->Loss, Accuracy,->[v1,v2]\n    \"\"\"\n    def __init__(self, log_dir, records=None):\n        if records == None:\n            self.records = OrderedDict()\n        else:\n            self.records = records\n        self.iter_rec = OrderedDict()\n        self.log_dir  = log_dir\n        self.classes = ['loss', 'acc', 'err', 'ratio']\n\n    def resetIter(self):\n        self.iter_rec.clear()\n\n    def checkDict(self, a_dict, key, sub_type='dict'):\n        if key not in a_dict.keys():\n            if sub_type == 'dict':\n                a_dict[key] = OrderedDict()\n            if sub_type == 'list':\n                a_dict[key] = []\n\n    def updateIter(self, split, keys, values):\n        self.checkDict(self.iter_rec, split, 'dict')\n        for k, v in zip(keys, values):\n            self.checkDict(self.iter_rec[split], k, 'list')\n            self.iter_rec[split][k].append(v)\n\n    def saveIterRecord(self, epoch, reset=True):\n        for s in self.iter_rec.keys(): # s stands for split\n            self.checkDict(self.records, s, 'dict')\n            for k in self.iter_rec[s].keys():\n                self.checkDict(self.records[s], k, 'dict')\n                self.checkDict(self.records[s][k], epoch, 'list')\n                self.records[s][k][epoch].append(np.mean(self.iter_rec[s][k]))\n        if reset: \n            self.resetIter()\n\n    def insertRecord(self, split, key, epoch, value):\n        self.checkDict(self.records, split, 'dict')\n        self.checkDict(self.records[split], key, 'dict')\n        self.checkDict(self.records[split][key], epoch, 'list')\n        self.records[split][key][epoch].append(value)\n\n    def iterRecToString(self, split, epoch):\n        rec_strs = ''\n        for c in self.classes:\n            strs = ''\n            for k in 
self.iter_rec[split].keys():\n                if (c in k.lower()):\n                    strs += '{}: {:.3f}| '.format(k, np.mean(self.iter_rec[split][k]))\n            if strs != '':\n                rec_strs += '\\t [{}] {}\\n'.format(c.upper(), strs)\n        self.saveIterRecord(epoch)\n        return rec_strs\n\n    def epochRecToString(self, split, epoch):\n        rec_strs = ''\n        for c in self.classes:\n            strs = ''\n            for k in self.records[split].keys():\n                if (c in k.lower()) and (epoch in self.records[split][k].keys()):\n                    strs += '{}: {:.3f}| '.format(k, np.mean(self.records[split][k][epoch]))\n            if strs != '':\n                rec_strs += '\\t [{}] {}\\n'.format(c.upper(), strs)\n        return rec_strs\n\n    def recordToDictOfArray(self, splits, epoch=-1, intv=1):\n        if len(self.records) == 0: return {}\n        if type(splits) == str: splits = [splits]\n\n        dict_of_array = OrderedDict()\n        for split in splits:\n            for k in self.records[split].keys():\n                y_array, x_array = [], []\n                if epoch < 0:\n                    for ep in self.records[split][k].keys():\n                        y_array.append(np.mean(self.records[split][k][ep]))\n                        x_array.append(ep)\n                else:\n                    if epoch in self.records[split][k].keys():\n                        y_array = np.array(self.records[split][k][epoch])\n                        x_array = np.linspace(intv, intv*len(y_array), len(y_array))\n                dict_of_array[split[0] + split[-1] + '_' + k]      = y_array\n                dict_of_array[split[0] + split[-1] + '_' + k+'_x'] = x_array\n        return dict_of_array\n"
  },
  {
    "path": "utils/time_utils.py",
    "content": "import time\nimport torch\nfrom collections import OrderedDict\n\nclass Timer(object):\n    def __init__(self, cuda_sync=False):\n        self.timer = OrderedDict()\n        self.cuda_sync = cuda_sync\n        self.startTimer()\n\n    def startTimer(self):\n        self.iter_start = time.time()\n        self.disp_start = time.time()\n\n    def resetTimer(self):\n        self.iter_start = time.time()\n        self.disp_start = time.time()\n        for key in self.timer.keys(): self.timer[key].reset()\n\n    def updateTime(self, key):\n        if key not in self.timer.keys(): self.timer[key] = AverageMeter()\n        if self.cuda_sync: torch.cuda.synchronize()\n        self.timer[key].update(time.time() - self.iter_start)\n        self.iter_start = time.time()\n\n    def timeToString(self, reset=True):\n        strs = '\\t [Time %.3fs] ' % (time.time() - self.disp_start)\n        for key in self.timer.keys():\n            if self.timer[key].sum < 1e-4: continue\n            strs += '%s: %.3fs| ' % (key, self.timer[key].sum)\n        self.resetTimer()\n        return strs\n\nclass AverageMeter(object):\n    def __init__(self):\n        self.reset()\n\n    def reset(self):\n        self.sum = 0\n        self.count = 0\n        self.avg = 0\n\n    def update(self, val, n=1):\n        self.sum += val * n\n        self.count += n\n        self.avg = self.sum / self.count\n\n    def __repr__(self):\n        return '%.3f' % (self.avg)\n"
  },
  {
    "path": "utils/utils.py",
    "content": "import os \nimport numpy as np\nfrom imageio import imread, imsave\n\ndef makeFile(f):\n    if not os.path.exists(f):\n        os.makedirs(f)\n    #else:  raise Exception('Rendered image directory %s is already existed!!!' % directory)\n\ndef makeFiles(f_list):\n    for f in f_list:\n        makeFile(f)\n\ndef emptyFile(name):\n    with open(name, 'w') as f:\n        f.write(' ')\n\ndef dictToString(dicts, start='\\t', end='\\n'):\n    strs = '' \n    for k, v in sorted(dicts.items()):\n        strs += '%s%s: %s%s' % (start, str(k), str(v), end) \n    return strs\n\ndef checkIfInList(list1, list2):\n    contains = []\n    for l1 in list1:\n        for l2 in list2:\n            if l1 in l2.lower():\n                contains.append(l1)\n                break\n    return contains\n\ndef atoi(text):\n    return int(text) if text.isdigit() else text\n\ndef natural_keys(text):\n    '''\n    alist.sort(key=natural_keys) sorts in human order\n    http://nedbatchelder.com/blog/200712/human_sorting.html\n    (See Toothy's implementation in the comments)\n    '''\n    return [ atoi(c) for c in re.split('(\\d+)', text) ]\n\ndef readList(list_path,ignore_head=False, sort=False):\n    lists = []\n    with open(list_path) as f:\n        lists = f.read().splitlines()\n    if ignore_head:\n        lists = lists[1:]\n    if sort:\n        lists.sort(key=natural_keys)\n    return lists\n"
  }
]