Full Code of guanyingc/SDPS-Net for AI

Repository: guanyingc/SDPS-Net
Branch: master
Commit: b84ca3aaa5cd
Files: 41
Total size: 100.9 KB

Directory structure:
gitextract_g3mq9mkn/

├── .gitignore
├── LICENSE.txt
├── README.md
├── data/
│   └── .gitignore
├── datasets/
│   ├── UPS_Custom_Dataset.py
│   ├── UPS_DiLiGenT_main.py
│   ├── UPS_Synth_Dataset.py
│   ├── __init__.py
│   ├── custom_data_loader.py
│   ├── pms_transforms.py
│   └── util.py
├── eval/
│   ├── run_stage1.py
│   └── run_stage2.py
├── main_stage1.py
├── main_stage2.py
├── models/
│   ├── LCNet.py
│   ├── NENet.py
│   ├── __init__.py
│   ├── custom_model.py
│   ├── model_utils.py
│   └── solver_utils.py
├── options/
│   ├── __init__.py
│   ├── base_opts.py
│   ├── run_model_opts.py
│   ├── stage1_opts.py
│   └── stage2_opts.py
├── scripts/
│   ├── DiLiGenT_objects.txt
│   ├── cropDiLiGenTData.py
│   ├── download_pretrained_models.sh
│   ├── download_synthetic_datasets.sh
│   └── prepare_diligent_dataset.sh
├── test_stage1.py
├── test_stage2.py
├── train_stage1.py
├── train_stage2.py
└── utils/
    ├── __init__.py
    ├── eval_utils.py
    ├── logger.py
    ├── recorders.py
    ├── time_utils.py
    └── utils.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
*.jpg
*.png
tags
*.pyc
*.pyo


================================================
FILE: LICENSE.txt
================================================
MIT License

Copyright (c) 2018 Guanying Chen

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README.md
================================================
# SDPS-Net
**[SDPS-Net: Self-calibrating Deep Photometric Stereo Networks, CVPR 2019 (Oral)](http://guanyingc.github.io/SDPS-Net/)**.
<br>
[Guanying Chen](https://guanyingc.github.io), [Kai Han](http://www.hankai.org/), [Boxin Shi](http://alumni.media.mit.edu/~shiboxin/), [Yasuyuki Matsushita](http://www-infobiz.ist.osaka-u.ac.jp/en/member/matsushita/), [Kwan-Yee K. Wong](http://i.cs.hku.hk/~kykwong/)
<br>

This paper addresses the problem of learning-based _uncalibrated_ photometric stereo for non-Lambertian surfaces.
<br>
<p align="center">
    <img src='data/images/buddha.gif' height="250" >
    <img src='data/images/GT.png' height="250" >
</p>

### _Changelog_
- July 28, 2019: We have updated the code to support Python 3.7 + PyTorch 1.1.0. To run the previous version (Python 2.7 + PyTorch 0.4.0), please check out the `python2.7` branch first (`git checkout python2.7`).

## Dependencies
SDPS-Net is implemented in [PyTorch](https://pytorch.org/) and tested on Ubuntu (14.04 and 16.04). Please install PyTorch first by following the official instructions.

- Python 3.7 
- PyTorch (version 1.1.0)
- torchvision
- CUDA 9.0 (skip this if you only have CPUs)
- numpy
- scipy
- scikit-image 

We highly recommend using Anaconda and creating a new environment to run this code.
```shell
# Create a new python3.7 environment named py3.7
conda create -n py3.7 python=3.7

# Activate the created environment
source activate py3.7

# Example commands for installing the dependencies 
conda install pytorch torchvision cudatoolkit=9.0 -c pytorch
conda install -c anaconda scipy 
conda install -c anaconda scikit-image 

# Download this code
git clone https://github.com/guanyingc/SDPS-Net.git
cd SDPS-Net
```
## Overview
We provide:
- Trained models
    - LCNet for lighting calibration from input images
    - NENet for normal estimation from input images and estimated lightings
- Code to test on the DiLiGenT main dataset
- Full code to train a new model, including code for debugging, visualization, and logging

## Testing
### Download the trained models
```shell
sh scripts/download_pretrained_models.sh
```
If the above command does not work, please manually download the trained models from BaiduYun ([LCNet and NENet](https://pan.baidu.com/s/10huOyPkfDSkDUK23_j4y1w?pwd=i5ha)) and put them in `./data/models/`.

### Test SDPS-Net on the DiLiGenT main dataset
```shell
# Prepare the DiLiGenT main dataset
sh scripts/prepare_diligent_dataset.sh
# This command will first download and unzip the DiLiGenT dataset, and then center-crop
# the original images based on the object mask with a margin of 15 pixels.

# Test SDPS-Net on the DiLiGenT main dataset using all 96 images
CUDA_VISIBLE_DEVICES=0 python eval/run_stage2.py --retrain data/models/LCNet_CVPR2019.pth.tar --retrain_s2 data/models/NENet_CVPR2019.pth.tar
# Please check the outputs in data/models/

# If you only have CPUs, please add the argument "--cuda" to disable GPU usage
python eval/run_stage2.py --cuda --retrain data/models/LCNet_CVPR2019.pth.tar --retrain_s2 data/models/NENet_CVPR2019.pth.tar
```

### Test SDPS-Net on your own dataset
You have two options for testing our method on your own dataset. The first is to implement a customized Dataset class to load your data, which should not be difficult. Please refer to `datasets/UPS_DiLiGenT_main.py` for an example that loads the DiLiGenT main dataset.
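To make the expected interface concrete, here is a minimal, hypothetical stand-in (the class name `MiniPSDataset` and the synthetic in-memory data are invented for illustration) that mirrors the dictionary the repository's loaders return; a real implementation should subclass `torch.utils.data.Dataset` and convert the arrays to tensors, as the provided loaders do.

```python
import numpy as np

# Hypothetical minimal sketch of a photometric-stereo Dataset, for
# illustration only. It mirrors the dict returned by the repository's
# loaders: 'img' stacks the n input RGB images along the channel axis,
# 'dirs'/'ints' hold per-image light directions and intensities, and
# 'mask'/'normal' are per-pixel maps. The real class subclasses
# torch.utils.data.Dataset and converts these arrays to tensors.
class MiniPSDataset:
    def __init__(self, n_objs=2, n_imgs=4, h=8, w=8):
        self.n_objs, self.n_imgs, self.h, self.w = n_objs, n_imgs, h, w

    def __getitem__(self, index):
        h, w, n = self.h, self.w, self.n_imgs
        img = np.random.rand(h, w, 3 * n).astype(np.float32)  # n images, channel-stacked
        dirs = np.zeros((n, 3), np.float32)
        dirs[:, 2] = 1.0                                      # frontal-light placeholder
        ints = np.ones((n, 3), np.float32)                    # unit intensities
        mask = np.ones((h, w, 1), np.float32)
        normal = np.zeros((h, w, 3), np.float32)              # unknown ground truth
        return {'img': img, 'dirs': dirs, 'ints': ints, 'mask': mask, 'normal': normal}

    def __len__(self):
        return self.n_objs
```

With real data, `__getitem__` would read the images, mask, and light files from disk for object `index`, as in the provided loaders.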

If you don't want to implement your own Dataset class, you can try our `datasets/UPS_Custom_Dataset.py`. However, you first have to arrange your dataset in the same format as `data/ToyPSDataset/`. Then you can run the following commands.
```shell
CUDA_VISIBLE_DEVICES=0 python eval/run_stage2.py --retrain data/models/LCNet_CVPR2019.pth.tar --retrain_s2 data/models/NENet_CVPR2019.pth.tar --benchmark UPS_Custom_Dataset --bm_dir /path/to/your/dataset

# To test SDPS-Net on the ToyPSDataset, simply run
CUDA_VISIBLE_DEVICES=0 python eval/run_stage2.py --retrain data/models/LCNet_CVPR2019.pth.tar --retrain_s2 data/models/NENet_CVPR2019.pth.tar --benchmark UPS_Custom_Dataset --bm_dir data/ToyPSDataset/
# Please check the outputs in data/models/
```
You may find the input arguments in `options/run_model_opts.py` (particularly `--have_l_dirs`, `--have_l_ints`, and `--have_gt_n`) useful when testing on your own dataset.

## Training
We use the publicly available synthetic [PS Blobby and Sculpture datasets](https://github.com/guanyingc/PS-FCN) for training.
To train a new SDPS-Net model, please follow these steps:

### Download the training data
```shell
# The total size of the zipped synthetic datasets is 4.7 + 19 = 23.7 GB,
# and it takes some time to download and unzip them.
sh scripts/download_synthetic_datasets.sh
```
If the above command does not work, please manually download the training datasets from BaiduYun ([PS Sculpture Dataset and PS Blobby Dataset](https://pan.baidu.com/s/1WUVu9ibIBh4wM1shTXBuNw?pwd=snyc)) and put them in `./data/datasets/`.

### First stage: train Lighting Calibration Network (LCNet)
```shell
# Train LCNet on synthetic datasets using 32 input images
CUDA_VISIBLE_DEVICES=0 python main_stage1.py --in_img_num 32
# Please refer to options/base_opts.py and options/stage1_opts.py for more options

# You can find checkpoints and results in data/logdir/
# It takes about 20 hours to train LCNet on a single Titan X Pascal GPU.
```
### Second stage: train Normal Estimation Network (NENet)
```shell
# Train NENet on synthetic datasets using 32 input images
CUDA_VISIBLE_DEVICES=0 python main_stage2.py --in_img_num 32 --retrain data/logdir/path/to/checkpointDirOfLCNet/checkpoint20.pth.tar
# Please refer to options/base_opts.py and options/stage2_opts.py for more options

# You can find checkpoints and results in data/logdir/
# It takes about 26 hours to train NENet on a single Titan X Pascal GPU.
```

## FAQ

#### Q1: How to test SDPS-Net on other datasets?
- You can implement a customized Dataset class to load your data, or use the provided `datasets/UPS_Custom_Dataset.py` class. In the latter case, you first have to arrange your dataset in the same format as `data/ToyPSDataset/`. Precomputed results on the DiLiGenT main dataset, Gourd\&Apple dataset, Light Stage Dataset, and Synthetic Test dataset are available upon request.

#### Q2: What should I do if I have problem in running your code?
- Please create an issue if you encounter errors when running the code; a detailed bug report is always welcome.

#### Q3: Could I run your code only using CPUs?
- Yes. Simply append `--cuda` to your command to disable GPU usage. Testing on the DiLiGenT benchmark with CPUs is still bearable (it should take less than 20 minutes), but training on CPUs is EXTREMELY SLOW!

## Citation
If you find this code or the provided models useful in your research, please consider citing:
```
@inproceedings{chen2019SDPS_Net,
  title={SDPS-Net: Self-calibrating Deep Photometric Stereo Networks},
  author={Chen, Guanying and Han, Kai and Shi, Boxin and Matsushita, Yasuyuki and Wong, Kwan-Yee~K.},
  booktitle={CVPR},
  year={2019}
}
```


================================================
FILE: data/.gitignore
================================================
*
!.gitignore


================================================
FILE: datasets/UPS_Custom_Dataset.py
================================================
from __future__ import division
import os
import numpy as np
import scipy.io as sio
from imageio import imread

import torch
import torch.utils.data as data

from datasets import pms_transforms
from . import util
np.random.seed(0)

class UPS_Custom_Dataset(data.Dataset):
    def __init__(self, args, split='train'):
        self.root   = os.path.join(args.bm_dir)
        self.split  = split
        self.args   = args
        self.objs   = util.readList(os.path.join(self.root, 'objects.txt'), sort=False)
        args.log.printWrite('[%s Data] \t%d objs. Root: %s' % (split, len(self.objs), self.root))

    def _getMask(self, obj):
        mask = imread(os.path.join(self.root, obj, 'mask.png'))
        if mask.ndim > 2: mask = mask[:,:,0]
        mask = mask.reshape(mask.shape[0], mask.shape[1], 1)
        return mask / 255.0

    def __getitem__(self, index):
        obj   = self.objs[index]
        names  = util.readList(os.path.join(self.root, obj, 'names.txt'))
        img_list   = [os.path.join(self.root, obj, names[i]) for i in range(len(names))]

        if self.args.have_l_dirs:
            dirs = np.genfromtxt(os.path.join(self.root, obj, 'light_directions.txt'))
        else:
            dirs = np.zeros((len(names), 3))
            dirs[:,2] = 1
        
        if self.args.have_l_ints:
            ints = np.genfromtxt(os.path.join(self.root, obj, 'light_intensities.txt'))
        else:
            ints = np.ones((len(names), 3))

        imgs = []
        for idx, img_name in enumerate(img_list):
            img = imread(img_name).astype(np.float32) / 255.0
            imgs.append(img)
        img = np.concatenate(imgs, 2)
        h, w, c = img.shape

        if self.args.have_gt_n:
            normal_path = os.path.join(self.root, obj, 'normal.png')
            normal = imread(normal_path).astype(np.float32) / 255.0 * 2 - 1
        else:
            normal = np.zeros((h, w, 3))

        mask = self._getMask(obj)
        img  = img * mask.repeat(img.shape[2], 2)

        item = {'normal': normal, 'img': img, 'mask': mask}

        downsample = 4 
        for k in item.keys():
            item[k] = pms_transforms.imgSizeToFactorOfK(item[k], downsample)

        for k in item.keys(): 
            item[k] = pms_transforms.arrayToTensor(item[k])

        item['dirs'] = torch.from_numpy(dirs).view(-1, 1, 1).float()
        item['ints'] = torch.from_numpy(ints).view(-1, 1, 1).float()
        item['obj'] = obj
        item['path'] = os.path.join(self.root, obj)
        return item

    def __len__(self):
        return len(self.objs)


================================================
FILE: datasets/UPS_DiLiGenT_main.py
================================================
from __future__ import division
import os
import numpy as np
import scipy.io as sio
#from scipy.ndimage import imread
from imageio import imread

import torch
import torch.utils.data as data

from datasets import pms_transforms
from . import util
np.random.seed(0)

class UPS_DiLiGenT_main(data.Dataset):
    def __init__(self, args, split='train'):
        self.root  = os.path.join(args.bm_dir)
        self.split = split
        self.args  = args
        self.objs  = util.readList(os.path.join(self.root, 'objects.txt'), sort=False)
        self.names = util.readList(os.path.join(self.root, 'names.txt'),   sort=False)
        self.l_dir = util.light_source_directions()
        args.log.printWrite('[%s Data] \t%d objs %d lights. Root: %s' % 
                (split, len(self.objs), len(self.names), self.root))
        self.ints = {}
        ints_name = 'light_intensities.txt'
        print('Files for intensity: %s' % (ints_name))
        for obj in self.objs:
            self.ints[obj] = np.genfromtxt(os.path.join(self.root, obj, ints_name))

    def _getMask(self, obj):
        mask = imread(os.path.join(self.root, obj, 'mask.png'))
        if mask.ndim > 2: mask = mask[:,:,0]
        mask = mask.reshape(mask.shape[0], mask.shape[1], 1)
        return mask / 255.0

    def __getitem__(self, index):
        np.random.seed(index)
        obj = self.objs[index]
        select_idx = range(len(self.names))

        img_list = [os.path.join(self.root, obj, self.names[i]) for i in select_idx]
        ints = [np.diag(1 / self.ints[obj][i]) for i in select_idx]
        dirs = self.l_dir[select_idx]

        normal_path = os.path.join(self.root, obj, 'Normal_gt.mat')
        normal = sio.loadmat(normal_path)['Normal_gt']

        imgs = []
        for idx, img_name in enumerate(img_list):
            img = imread(img_name).astype(np.float32) / 255.0
            if self.args.in_light and not self.args.int_aug:
                img = np.dot(img, ints[idx])
            imgs.append(img)
        img = np.concatenate(imgs, 2)

        mask = self._getMask(obj)
        if self.args.test_resc:
            img, normal = pms_transforms.rescale(img, normal, [self.args.test_h, self.args.test_w])
            mask = pms_transforms.rescaleSingle(mask, [self.args.test_h, self.args.test_w])

        img = img * mask.repeat(img.shape[2], 2)

        norm = np.sqrt((normal * normal).sum(2, keepdims=True))
        normal = normal / (norm + 1e-10)

        item = {'normal': normal, 'img': img, 'mask': mask}

        downsample = 4 
        for k in item.keys():
            item[k] = pms_transforms.imgSizeToFactorOfK(item[k], downsample)

        for k in item.keys(): 
            item[k] = pms_transforms.arrayToTensor(item[k])

        item['dirs'] = torch.from_numpy(dirs).view(-1, 1, 1).float()
        item['ints'] = torch.from_numpy(self.ints[obj][select_idx]).view(-1, 1, 1).float()

        item['obj'] = obj
        item['path'] = os.path.join(self.root, obj)
        return item

    def __len__(self):
        return len(self.objs)


================================================
FILE: datasets/UPS_Synth_Dataset.py
================================================
from __future__ import division
import os
import numpy as np
#from scipy.ndimage import imread
from imageio import imread

import torch
import torch.utils.data as data

from datasets import pms_transforms
from . import util
np.random.seed(0)

class UPS_Synth_Dataset(data.Dataset):
    def __init__(self, args, root, split='train'):
        self.root  = os.path.join(root)
        self.split = split
        self.args  = args
        self.shape_list = util.readList(os.path.join(self.root, split + args.l_suffix))

    def _getInputPath(self, index):
        shape, mtrl = self.shape_list[index].split('/')
        normal_path = os.path.join(self.root, 'Images', shape, shape + '_normal.png')
        img_dir     = os.path.join(self.root, 'Images', self.shape_list[index])
        img_list    = util.readList(os.path.join(img_dir, '%s_%s.txt' % (shape, mtrl)))

        data = np.genfromtxt(img_list, dtype='str', delimiter=' ')
        select_idx = np.random.permutation(data.shape[0])[:self.args.in_img_num]
        idxs = ['%03d' % (idx) for idx in select_idx]
        data = data[select_idx, :]
        imgs = [os.path.join(img_dir, img) for img in data[:, 0]]
        dirs = data[:, 1:4].astype(np.float32)
        return normal_path, imgs, dirs

    def __getitem__(self, index):
        normal_path, img_list, dirs = self._getInputPath(index)
        normal = imread(normal_path).astype(np.float32) / 255.0 * 2 - 1
        imgs   =  []
        for i in img_list:
            img = imread(i).astype(np.float32) / 255.0
            imgs.append(img)
        img = np.concatenate(imgs, 2)

        h, w, c = img.shape
        crop_h, crop_w = self.args.crop_h, self.args.crop_w
        if self.args.rescale and not (crop_h == h):
            sc_h = np.random.randint(crop_h, h) if self.args.rand_sc else self.args.scale_h
            sc_w = np.random.randint(crop_w, w) if self.args.rand_sc else self.args.scale_w
            img, normal = pms_transforms.rescale(img, normal, [sc_h, sc_w])

        if self.args.crop:
            img, normal = pms_transforms.randomCrop(img, normal, [crop_h, crop_w])

        if self.args.color_aug:
            img = img * np.random.uniform(1, self.args.color_ratio)

        if self.args.int_aug:
            ints = pms_transforms.getIntensity(len(imgs))
            img  = np.dot(img, np.diag(ints.reshape(-1)))
        else:
            ints = np.ones(c)

        if self.args.noise_aug:
            img = pms_transforms.randomNoiseAug(img, self.args.noise)

        mask   = pms_transforms.normalToMask(normal)
        normal = normal * mask.repeat(3, 2) 
        norm   = np.sqrt((normal * normal).sum(2, keepdims=True))
        normal = normal / (norm + 1e-10) # Rescale normal to unit length

        item = {'normal': normal, 'img': img, 'mask': mask}
        for k in item.keys(): 
            item[k] = pms_transforms.arrayToTensor(item[k])

        item['dirs'] = torch.from_numpy(dirs).view(-1, 1, 1).float()
        item['ints'] = torch.from_numpy(ints).view(-1, 1, 1).float()
        return item

    def __len__(self):
        return len(self.shape_list)


================================================
FILE: datasets/__init__.py
================================================
#from .PMS_dataset_v1 import PMS_dataset
#from .PMS_dataset_v2 import PMS_data_v2
#from .DiLiGenT import DiLiGenT
#__all__ = ('PMS_dataset', 'DiLiGenT', 'PMS_data_v2')


================================================
FILE: datasets/custom_data_loader.py
================================================
import torch.utils.data

def customDataloader(args):
    args.log.printWrite("=> fetching img pairs in %s" % (args.data_dir))
    datasets = __import__('datasets.' + args.dataset)
    dataset_file = getattr(datasets, args.dataset)
    train_set = getattr(dataset_file, args.dataset)(args, args.data_dir, 'train')
    val_set   = getattr(dataset_file, args.dataset)(args, args.data_dir, 'val')

    if args.concat_data:
        args.log.printWrite('****** Using concat data ******')
        args.log.printWrite("=> fetching img pairs in '{}'".format(args.data_dir2))
        train_set2 = getattr(dataset_file, args.dataset)(args, args.data_dir2, 'train')
        val_set2   = getattr(dataset_file, args.dataset)(args, args.data_dir2, 'val')
        train_set  = torch.utils.data.ConcatDataset([train_set, train_set2])
        val_set    = torch.utils.data.ConcatDataset([val_set,   val_set2])

    args.log.printWrite('Found Data:\t %d Train and %d Val' % (len(train_set), len(val_set)))
    args.log.printWrite('\t Train Batch: %d, Val Batch: %d' % (args.batch, args.val_batch))

    train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch,
        num_workers=args.workers, pin_memory=args.cuda, shuffle=True)
    test_loader  = torch.utils.data.DataLoader(val_set , batch_size=args.val_batch,
        num_workers=args.workers, pin_memory=args.cuda, shuffle=False)
    return train_loader, test_loader

def benchmarkLoader(args):
    args.log.printWrite("=> fetching img pairs in 'data/%s'" % (args.benchmark))
    datasets = __import__('datasets.' + args.benchmark)
    dataset_file = getattr(datasets, args.benchmark)
    test_set = getattr(dataset_file, args.benchmark)(args, 'test')

    args.log.printWrite('Found Benchmark Data: %d samples' % (len(test_set)))
    args.log.printWrite('\t Test Batch %d' % (args.test_batch))

    test_loader = torch.utils.data.DataLoader(test_set, batch_size=args.test_batch,
        num_workers=args.workers, pin_memory=args.cuda, shuffle=False)
    return test_loader
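The two loaders above resolve dataset modules by name with `__import__`. One subtlety worth noting: `__import__('a.b')` returns the top-level package `a`, which is why the code follows it with `getattr`. A small stand-alone sketch using the standard library (`os.path` stands in for `datasets.<name>`):

```python
# __import__('pkg.sub') imports pkg.sub but returns the top-level package,
# so the submodule (and then the class/function) must be fetched via getattr.
# Demonstrated with the standard library instead of the datasets package.
pkg = __import__('os.path')        # returns the 'os' package, not 'os.path'
submodule = getattr(pkg, 'path')   # the submodule, analogous to getattr(datasets, args.dataset)
join = getattr(submodule, 'join')  # an attribute of the submodule, analogous to the Dataset class
print(pkg.__name__)  # os
```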


================================================
FILE: datasets/pms_transforms.py
================================================
import torch
import random
import numpy as np
from skimage.transform import resize
random.seed(0)
np.random.seed(0)

def arrayToTensor(array):
    if array is None:
        return array
    array = np.transpose(array, (2, 0, 1))
    tensor = torch.from_numpy(array)
    return tensor.float()

def normalToMask(normal, thres=1e-2):
    """
    Due to the numerical precision of uint8, [0, 0, 0] will save as [127, 127, 127] in gt normal,
    When we load the data and rescale normal by N / 255 * 2 - 1, [127, 127, 127] becomes 
    [-0.003927, -0.003927, -0.003927]
    """
    mask = (np.square(normal).sum(2, keepdims=True) > thres).astype(np.float32)
    return mask
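A quick stand-alone check of the quantization issue the docstring describes (the foreground pixel value here is illustrative):

```python
import numpy as np

# A background normal saved as uint8 [127, 127, 127] decodes to about
# [-0.0039, -0.0039, -0.0039] under N / 255 * 2 - 1; its squared norm
# (~4.6e-05) falls below thres=1e-2, so normalToMask masks it out, while
# any real (unit-length) normal has squared norm close to 1.
bg = np.array([127, 127, 127], np.float32) / 255.0 * 2 - 1
fg = np.array([128, 128, 220], np.float32) / 255.0 * 2 - 1  # illustrative foreground pixel
print(np.square(bg).sum())  # ~4.6e-05, below the threshold
print(np.square(fg).sum())  # well above the threshold
```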

def imgSizeToFactorOfK(img, k):
    if img.shape[0] % k == 0 and img.shape[1] % k == 0:
        return img
    # Modulo again so that an already-divisible dimension gets zero padding
    pad_h, pad_w = (k - img.shape[0] % k) % k, (k - img.shape[1] % k) % k
    img = np.pad(img, ((0, pad_h), (0, pad_w), (0,0)), 
            'constant', constant_values=((0,0),(0,0),(0,0)))
    return img
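The padding behavior can be verified in isolation (a stand-alone re-computation of the same logic, not an import of this module):

```python
import numpy as np

# Zero-pad a 30x30 image on the bottom/right to the next multiple of k=4
# (i.e. 32x32), which keeps spatial sizes divisible through the network's
# downsampling stages. The trailing "% k" leaves an already-divisible
# dimension unpadded.
img = np.ones((30, 30, 3), np.float32)
k = 4
pad_h = (k - img.shape[0] % k) % k
pad_w = (k - img.shape[1] % k) % k
padded = np.pad(img, ((0, pad_h), (0, pad_w), (0, 0)), 'constant')
print(padded.shape)  # (32, 32, 3)
```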

def randomCrop(inputs, target, size):
    h, w, _ = inputs.shape
    c_h, c_w = size
    if h == c_h and w == c_w:
        return inputs, target
    x1 = random.randint(0, w - c_w)
    y1 = random.randint(0, h - c_h)
    inputs = inputs[y1: y1 + c_h, x1: x1 + c_w]
    target = target[y1: y1 + c_h, x1: x1 + c_w]
    return inputs, target

def centerCrop(inputs, size):
    h, w, _ = inputs.shape
    c_h, c_w = size
    if h != c_h or w != c_w:
        x1 = int(w / 2 - c_w / 2)
        y1 = int(h / 2 - c_h / 2)
        inputs = inputs[y1: y1 + c_h, x1: x1 + c_w]
    return inputs

def rescale(inputs, target, size):
    in_h, in_w, _ = inputs.shape
    h, w = size
    if h != in_h or w != in_w:
        inputs = resize(inputs, size, order=1, mode='reflect')
        target = resize(target, size, order=1, mode='reflect')
    return inputs, target

def rescaleSingle(inputs, size, order=1):
    in_h, in_w, _ = inputs.shape
    h, w = size
    if h != in_h or w != in_w:
        inputs = resize(inputs, size, order=order, mode='reflect')
    return inputs

def randomNoiseAug(inputs, noise_level=0.05):
    noise = np.random.random(inputs.shape)
    noise = (noise - 0.5) * noise_level
    inputs += noise
    return inputs

def getIntensity(num):
    intensity = np.random.random((num, 1)) * 1.8 + 0.2
    color = np.ones((1, 3)) # Uniform color
    intens = (intensity.repeat(3, 1) * color)
    return intens
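The sampling range of `getIntensity` can be checked stand-alone (the same computation re-done here rather than imported):

```python
import numpy as np

# Intensities are drawn uniformly from [0.2, 2.0) and replicated across the
# three color channels ("uniform color"), so every row has equal R, G, B.
np.random.seed(0)
intensity = np.random.random((5, 1)) * 1.8 + 0.2
intens = intensity.repeat(3, 1) * np.ones((1, 3))
print(intens.shape)  # (5, 3)
```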


================================================
FILE: datasets/util.py
================================================
import numpy as np
import re

def atoi(text):
    return int(text) if text.isdigit() else text

def natural_keys(text):
    '''
    alist.sort(key=natural_keys) sorts in human order
    http://nedbatchelder.com/blog/200712/human_sorting.html
    (See Toothy's implementation in the comments)
    '''
    return [atoi(c) for c in re.split(r'(\d+)', text)]

def readList(list_path,ignore_head=False, sort=True):
    lists = []
    with open(list_path) as f:
        lists = f.read().splitlines()
    if ignore_head:
        lists = lists[1:]
    if sort:
        lists.sort(key=natural_keys)
    return lists
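A stand-alone demonstration of the natural sorting that `readList` applies by default (same regex-based key, re-defined here):

```python
import re

# Natural ("human") sort: numeric runs inside filenames compare as integers,
# so img2.png sorts before img10.png, unlike plain lexicographic order.
def natural_keys(text):
    return [int(c) if c.isdigit() else c for c in re.split(r'(\d+)', text)]

names = ['img10.png', 'img2.png', 'img1.png']
print(sorted(names))                    # ['img1.png', 'img10.png', 'img2.png']
print(sorted(names, key=natural_keys))  # ['img1.png', 'img2.png', 'img10.png']
```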

def light_source_directions():
    """
    Below matrix is from DiLiGenT.
    :return: light source direction matrix. [light_num, 3]
    :rtype: np.ndarray
    """
    L = np.array([[-0.06059872, -0.44839055, 0.8917812],
                  [-0.05939919, -0.33739538, 0.93948714],
                  [-0.05710194, -0.21230722, 0.97553319],
                  [-0.05360061, -0.07800089, 0.99551134],
                  [-0.04919816, 0.05869781, 0.99706274],
                  [-0.04399823, 0.19019233, 0.98076044],
                  [-0.03839991, 0.31049925, 0.9497977],
                  [-0.03280081, 0.41611025, 0.90872238],
                  [-0.18449839, -0.43989616, 0.87889232],
                  [-0.18870114, -0.32950199, 0.92510557],
                  [-0.1901994, -0.20549935, 0.95999698],
                  [-0.18849605, -0.07269848, 0.97937948],
                  [-0.18329657, 0.06229884, 0.98108166],
                  [-0.17500445, 0.19220488, 0.96562453],
                  [-0.16449474, 0.31129005, 0.93597008],
                  [-0.15270716, 0.4160195, 0.89644202],
                  [-0.30139786, -0.42509698, 0.85349393],
                  [-0.31020115, -0.31660118, 0.89640333],
                  [-0.31489186, -0.19549495, 0.92877599],
                  [-0.31450962, -0.06640203, 0.94692897],
                  [-0.30880699, 0.06470146, 0.94892147],
                  [-0.2981084, 0.19100538, 0.93522635],
                  [-0.28359251, 0.30729189, 0.90837601],
                  [-0.26670649, 0.41020998, 0.87212122],
                  [-0.40709586, -0.40559588, 0.81839168],
                  [-0.41919869, -0.29999906, 0.85689732],
                  [-0.42618633, -0.18329412, 0.88587159],
                  [-0.42691512, -0.05950211, 0.90233197],
                  [-0.42090385, 0.0659006, 0.90470827],
                  [-0.40860354, 0.18720162, 0.89330773],
                  [-0.39141794, 0.29941372, 0.87013988],
                  [-0.3707838, 0.39958255, 0.83836338],
                  [-0.499596, -0.38319693, 0.77689378],
                  [-0.51360334, -0.28130183, 0.81060526],
                  [-0.52190667, -0.16990217, 0.83591069],
                  [-0.52326874, -0.05249686, 0.85054918],
                  [-0.51720021, 0.06620003, 0.85330035],
                  [-0.50428312, 0.18139393, 0.84427174],
                  [-0.48561334, 0.28870793, 0.82512267],
                  [-0.46289771, 0.38549809, 0.79819605],
                  [-0.57853599, -0.35932235, 0.73224555],
                  [-0.59329349, -0.26189713, 0.76119165],
                  [-0.60202327, -0.15630604, 0.78303027],
                  [-0.6037003, -0.04570002, 0.7959004],
                  [-0.59781529, 0.06590169, 0.79892043],
                  [-0.58486953, 0.17439091, 0.79215873],
                  [-0.56588359, 0.27639198, 0.77677747],
                  [-0.54241965, 0.36921337, 0.75462733],
                  [0.05220076, -0.43870637, 0.89711304],
                  [0.05199786, -0.33138635, 0.9420612],
                  [0.05109826, -0.20999284, 0.97636672],
                  [0.04919919, -0.07869871, 0.99568366],
                  [0.04640163, 0.05630197, 0.99733494],
                  [0.04279892, 0.18779527, 0.98127529],
                  [0.03870043, 0.30950341, 0.95011048],
                  [0.03440055, 0.41730662, 0.90811441],
                  [0.17290651, -0.43181626, 0.88523333],
                  [0.17839998, -0.32509996, 0.92869988],
                  [0.18160174, -0.20480196, 0.96180921],
                  [0.18200745, -0.07490306, 0.98044012],
                  [0.17919505, 0.05849838, 0.98207285],
                  [0.17329685, 0.18839658, 0.96668244],
                  [0.1649036, 0.30880674, 0.93672045],
                  [0.1549931, 0.41578148, 0.89616009],
                  [0.28720483, -0.41910705, 0.8613145],
                  [0.29740177, -0.31410186, 0.90160535],
                  [0.30420604, -0.1965039, 0.9321185],
                  [0.30640529, -0.07010121, 0.94931639],
                  [0.30361153, 0.05950226, 0.95093613],
                  [0.29588748, 0.18589214, 0.93696036],
                  [0.28409783, 0.30349768, 0.90949304],
                  [0.26939905, 0.40849857, 0.87209694],
                  [0.39120402, -0.40190413, 0.8279085],
                  [0.40481085, -0.29960803, 0.86392315],
                  [0.41411685, -0.18590756, 0.89103626],
                  [0.41769724, -0.06449957, 0.906294],
                  [0.41498764, 0.05959822, 0.90787296],
                  [0.40607977, 0.18089099, 0.89575537],
                  [0.39179226, 0.29439419, 0.87168279],
                  [0.37379609, 0.39649585, 0.83849122],
                  [0.48278794, -0.38169046, 0.78818031],
                  [0.49848546, -0.28279175, 0.8194761],
                  [0.50918069, -0.1740934, 0.84286803],
                  [0.51360856, -0.05870098, 0.85601427],
                  [0.51097962, 0.05899765, 0.8575658],
                  [0.50151639, 0.17420569, 0.84742769],
                  [0.48600297, 0.28260173, 0.82700506],
                  [0.46600106, 0.38110087, 0.79850181],
                  [0.56150442, -0.35990283, 0.74510586],
                  [0.57807114, -0.26498677, 0.77176147],
                  [0.58933134, -0.1617086, 0.7915421],
                  [0.59407609, -0.05289787, 0.80266769],
                  [0.59157958, 0.057798, 0.80417224],
                  [0.58198189, 0.16649482, 0.79597523],
                  [0.56620006, 0.26940003, 0.77900008],
                  [0.54551481, 0.36380988, 0.7550205]], dtype=float)
    return L
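A small sanity check on this matrix: each of the 96 DiLiGenT light directions is approximately unit length with a positive z-component (facing the camera). Verified here on the first and last rows copied from the matrix above:

```python
import numpy as np

# First and last rows of the DiLiGenT light direction matrix above; each
# direction should be unit length and point toward the camera (z > 0).
L = np.array([[-0.06059872, -0.44839055, 0.8917812],
              [0.54551481, 0.36380988, 0.7550205]])
norms = np.sqrt((L * L).sum(1))
print(np.allclose(norms, 1.0, atol=1e-4))  # True
print((L[:, 2] > 0).all())                 # True
```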


================================================
FILE: eval/run_stage1.py
================================================
import torch, sys
sys.path.append('.')

from datasets import custom_data_loader
from options  import run_model_opts
from models   import custom_model
from utils    import logger, recorders

import test_stage1 as test_utils

args = run_model_opts.RunModelOpts().parse()
log  = logger.Logger(args)

def main(args):
    test_loader = custom_data_loader.benchmarkLoader(args)
    model    = custom_model.buildModel(args)
    recorder = recorders.Records(args.log_dir)
    test_utils.test(args, 'test', test_loader, model, log, 1, recorder)
    log.plotCurves(recorder, 'test')

if __name__ == '__main__':
    torch.manual_seed(args.seed)
    main(args)


================================================
FILE: eval/run_stage2.py
================================================
import torch, sys
sys.path.append('.')

from datasets import custom_data_loader
from options  import run_model_opts
from models   import custom_model
from utils    import logger, recorders

import test_stage2 as test_utils

args = run_model_opts.RunModelOpts().parse()
args.stage2    = True
args.test_resc = False
log  = logger.Logger(args)

def main(args):
    test_loader = custom_data_loader.benchmarkLoader(args)
    model = custom_model.buildModel(args)
    model_s2 = custom_model.buildModelStage2(args)
    models = [model, model_s2]

    recorder = recorders.Records(args.log_dir)
    test_utils.test(args, 'test', test_loader, models, log, 1, recorder)
    log.plotCurves(recorder, 'test')

if __name__ == '__main__':
    torch.manual_seed(args.seed)
    main(args)


================================================
FILE: main_stage1.py
================================================
import torch
from options  import stage1_opts
from utils    import logger, recorders
from datasets import custom_data_loader
from models   import custom_model, solver_utils, model_utils

import train_stage1 as train_utils
import test_stage1  as test_utils

args = stage1_opts.TrainOpts().parse()
log  = logger.Logger(args)

def main(args):
    model = custom_model.buildModel(args)
    optimizer, scheduler, records = solver_utils.configOptimizer(args, model)
    criterion = solver_utils.Stage1ClsCrit(args)
    recorder  = recorders.Records(args.log_dir, records)

    train_loader, val_loader = custom_data_loader.customDataloader(args)

    for epoch in range(args.start_epoch, args.epochs+1):
        scheduler.step()
        recorder.insertRecord('train', 'lr', epoch, scheduler.get_lr()[0])

        train_utils.train(args, train_loader, model, criterion, optimizer, log, epoch, recorder)
        if epoch % args.save_intv == 0: 
            model_utils.saveCheckpoint(args.cp_dir, epoch, model, optimizer, recorder.records, args)
        log.plotCurves(recorder, 'train')

        if epoch % args.val_intv == 0:
            test_utils.test(args, 'val', val_loader, model, log, epoch, recorder)
            log.plotCurves(recorder, 'val')

if __name__ == '__main__':
    torch.manual_seed(args.seed)
    main(args)


================================================
FILE: main_stage2.py
================================================
import torch
from options  import stage2_opts
from utils    import logger, recorders
from datasets import custom_data_loader
from models   import custom_model, solver_utils, model_utils

import train_stage2 as train_utils
import test_stage2 as test_utils

args = stage2_opts.TrainOpts().parse()
log  = logger.Logger(args)

def main(args):
    model = custom_model.buildModel(args)
    model_s2 = custom_model.buildModelStage2(args)
    models = [model, model_s2]

    optimizer, scheduler, records = solver_utils.configOptimizer(args, model_s2)
    optimizers = [optimizer, -1]
    criterion = solver_utils.Stage2Crit(args)
    recorder  = recorders.Records(args.log_dir, records)

    train_loader, val_loader = custom_data_loader.customDataloader(args)

    for epoch in range(args.start_epoch, args.epochs+1):
        scheduler.step()

        recorder.insertRecord('train', 'lr', epoch, scheduler.get_lr()[0])

        train_utils.train(args, train_loader, models, criterion, optimizers, log, epoch, recorder)
        if epoch % args.save_intv == 0: 
            model_utils.saveCheckpoint(args.cp_dir, epoch, model_s2, optimizer, recorder.records, args)
        log.plotCurves(recorder, 'train')

        if epoch % args.val_intv == 0:
            test_utils.test(args, 'val', val_loader, models, log, epoch, recorder)
            log.plotCurves(recorder, 'val')

if __name__ == '__main__':
    torch.manual_seed(args.seed)
    main(args)


================================================
FILE: models/LCNet.py
================================================
import torch
import torch.nn as nn
from torch.nn.init import kaiming_normal_
from . import model_utils
from utils import eval_utils

# Classification
class FeatExtractor(nn.Module):
    def __init__(self, batchNorm, c_in, c_out=256):
        super(FeatExtractor, self).__init__()
        self.conv1 = model_utils.conv(batchNorm, c_in, 64,    k=3, stride=2, pad=1)
        self.conv2 = model_utils.conv(batchNorm, 64,   128,   k=3, stride=2, pad=1)
        self.conv3 = model_utils.conv(batchNorm, 128,  128,   k=3, stride=1, pad=1)
        self.conv4 = model_utils.conv(batchNorm, 128,  128,   k=3, stride=2, pad=1)
        self.conv5 = model_utils.conv(batchNorm, 128,  128,   k=3, stride=1, pad=1)
        self.conv6 = model_utils.conv(batchNorm, 128,  256,   k=3, stride=2, pad=1)
        self.conv7 = model_utils.conv(batchNorm, 256,  256,   k=3, stride=1, pad=1)

    def forward(self, inputs):
        out = self.conv1(inputs)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.conv5(out)
        out = self.conv6(out)
        out = self.conv7(out)
        return out

class Classifier(nn.Module):
    def __init__(self, batchNorm, c_in, other):
        super(Classifier, self).__init__()
        self.conv1 = model_utils.conv(batchNorm, 512,  256, k=3, stride=1, pad=1)
        self.conv2 = model_utils.conv(batchNorm, 256,  256, k=3, stride=2, pad=1)
        self.conv3 = model_utils.conv(batchNorm, 256,  256, k=3, stride=2, pad=1)
        self.conv4 = model_utils.conv(batchNorm, 256,  256, k=3, stride=2, pad=1)
        self.other = other
        
        self.dir_x_est = nn.Sequential(
                    model_utils.conv(batchNorm, 256, 64,  k=1, stride=1, pad=0),
                    model_utils.outputConv(64, other['dirs_cls'], k=1, stride=1, pad=0))

        self.dir_y_est = nn.Sequential(
                    model_utils.conv(batchNorm, 256, 64,  k=1, stride=1, pad=0),
                    model_utils.outputConv(64, other['dirs_cls'], k=1, stride=1, pad=0))

        self.int_est = nn.Sequential(
                    model_utils.conv(batchNorm, 256, 64,  k=1, stride=1, pad=0),
                    model_utils.outputConv(64, other['ints_cls'], k=1, stride=1, pad=0))

    def forward(self, inputs):
        out = self.conv1(inputs)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.conv4(out)
        outputs = {}
        if self.other['s1_est_d']:
            outputs['dir_x'] = self.dir_x_est(out)
            outputs['dir_y'] = self.dir_y_est(out)
        if self.other['s1_est_i']:
            outputs['ints'] = self.int_est(out)
        return outputs

class LCNet(nn.Module):
    def __init__(self, fuse_type='max', batchNorm=False, c_in=3, other={}):
        super(LCNet, self).__init__()
        # note: FeatExtractor ignores its c_out argument and always outputs 256
        # channels; Classifier likewise hard-codes a 512-channel input (per-image
        # feature concatenated with the fused feature).
        self.featExtractor = FeatExtractor(batchNorm, c_in, 128)
        self.classifier = Classifier(batchNorm, 256, other)
        self.c_in      = c_in
        self.fuse_type = fuse_type
        self.other     = other

        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
                kaiming_normal_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def prepareInputs(self, x):
        n, c, h, w = x[0].shape
        t_h, t_w = self.other['test_h'], self.other['test_w']
        if (h == t_h and w == t_w):
            imgs = x[0] 
        else:
            print('Rescaling images: from %dx%d to %dx%d' % (h, w, t_h, t_w))
            imgs = torch.nn.functional.interpolate(x[0], size=(t_h, t_w), mode='bilinear', align_corners=False)

        inputs = list(torch.split(imgs, 3, 1))
        idx = 1
        if self.other['in_light']:
            light = torch.split(x[idx], 3, 1)
            for i in range(len(inputs)):
                inputs[i] = torch.cat([inputs[i], light[i]], 1)
            idx += 1
        if self.other['in_mask']:
            mask = x[idx]
            if mask.shape[2] != inputs[0].shape[2] or mask.shape[3] != inputs[0].shape[3]:
                mask = torch.nn.functional.interpolate(mask, size=(t_h, t_w), mode='bilinear', align_corners=False)
            for i in range(len(inputs)):
                inputs[i] = torch.cat([inputs[i], mask], 1)
            idx += 1
        return inputs

    def fuseFeatures(self, feats, fuse_type):
        if fuse_type == 'mean':
            feat_fused = torch.stack(feats, 1).mean(1)
        elif fuse_type == 'max':
            feat_fused, _ = torch.stack(feats, 1).max(1)
        return feat_fused

    def convertMidDirs(self, pred):
        _, x_idx = pred['dirs_x'].data.max(1)
        _, y_idx = pred['dirs_y'].data.max(1)
        dirs = eval_utils.SphericalClassToDirs(x_idx, y_idx, self.other['dirs_cls'])
        return dirs

    def convertMidIntens(self, pred, img_num):
        _, idx = pred['ints'].data.max(1)
        ints = eval_utils.ClassToLightInts(idx, self.other['ints_cls'])
        ints = ints.view(-1, 1).repeat(1, 3)
        ints = torch.cat(torch.split(ints, ints.shape[0] // img_num, 0), 1)
        return ints

    def forward(self, x):
        inputs = self.prepareInputs(x)
        feats = []
        for i in range(len(inputs)):
            out_feat = self.featExtractor(inputs[i])
            feats.append(out_feat)
        feat_fused = self.fuseFeatures(feats, self.fuse_type)

        l_dirs_x, l_dirs_y, l_ints = [], [], []
        for i in range(len(inputs)):
            net_input = torch.cat([feats[i], feat_fused], 1)
            outputs = self.classifier(net_input)
            if self.other['s1_est_d']:
                l_dirs_x.append(outputs['dir_x'])
                l_dirs_y.append(outputs['dir_y'])
            if self.other['s1_est_i']:
                l_ints.append(outputs['ints'])

        pred = {}
        if self.other['s1_est_d']:
            pred['dirs_x'] = torch.cat(l_dirs_x, 0).squeeze()
            pred['dirs_y'] = torch.cat(l_dirs_y, 0).squeeze()
            pred['dirs']   = self.convertMidDirs(pred)
        if self.other['s1_est_i']:
            pred['ints'] = torch.cat(l_ints, 0).squeeze()
            if pred['ints'].ndimension() == 1:
                pred['ints'] = pred['ints'].view(1, -1)
            pred['intens'] = self.convertMidIntens(pred, len(inputs))
        return pred
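LCNet's `fuseFeatures` aggregates per-image features with an order-agnostic max (or mean) over the image axis, which is what makes the network insensitive to the ordering of its input images. A minimal standalone sketch of that fusion step (made-up tensor shapes, not repo code):

```python
import torch

def fuse_features(feats, fuse_type='max'):
    # feats: list of [n, c, h, w] tensors, one per input image
    stacked = torch.stack(feats, 1)      # -> [n, num_imgs, c, h, w]
    if fuse_type == 'mean':
        return stacked.mean(1)
    elif fuse_type == 'max':
        return stacked.max(1)[0]
    raise ValueError('unknown fuse_type: %s' % fuse_type)

torch.manual_seed(0)
feats = [torch.rand(2, 4, 8, 8) for _ in range(5)]
fused = fuse_features(feats, 'max')
# permuting the image order leaves the max-fused feature unchanged
fused_perm = fuse_features(feats[::-1], 'max')
print(torch.equal(fused, fused_perm))  # True
```

The same permutation invariance holds for mean fusion up to floating-point summation order.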


================================================
FILE: models/NENet.py
================================================
import torch
import torch.nn as nn
from torch.nn.init import kaiming_normal_
from . import model_utils

class FeatExtractor(nn.Module):
    def __init__(self, batchNorm=False, c_in=3, other={}):
        super(FeatExtractor, self).__init__()
        self.other = other
        self.conv1 = model_utils.conv(batchNorm, c_in, 64,  k=3, stride=1, pad=1)
        self.conv2 = model_utils.conv(batchNorm, 64,   128, k=3, stride=2, pad=1)
        self.conv3 = model_utils.conv(batchNorm, 128,  128, k=3, stride=1, pad=1)
        self.conv4 = model_utils.conv(batchNorm, 128,  256, k=3, stride=2, pad=1)
        self.conv5 = model_utils.conv(batchNorm, 256,  256, k=3, stride=1, pad=1)
        self.conv6 = model_utils.deconv(256, 128)
        self.conv7 = model_utils.conv(batchNorm, 128, 128, k=3, stride=1, pad=1)

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.conv5(out)
        out = self.conv6(out)
        out_feat = self.conv7(out)
        n, c, h, w = out_feat.data.shape
        out_feat   = out_feat.view(-1)
        return out_feat, [n, c, h, w]

class Regressor(nn.Module):
    def __init__(self, batchNorm=False, other={}): 
        super(Regressor, self).__init__()
        self.other   = other
        self.deconv1 = model_utils.conv(batchNorm, 128, 128,  k=3, stride=1, pad=1)
        self.deconv2 = model_utils.conv(batchNorm, 128, 128,  k=3, stride=1, pad=1)
        self.deconv3 = model_utils.deconv(128, 64)
        self.est_normal = self._make_output(64, 3, k=3, stride=1, pad=1)

    def _make_output(self, cin, cout, k=3, stride=1, pad=1):
        return nn.Sequential(
               nn.Conv2d(cin, cout, kernel_size=k, stride=stride, padding=pad, bias=False))

    def forward(self, x, shape):
        x      = x.view(shape[0], shape[1], shape[2], shape[3])
        out    = self.deconv1(x)
        out    = self.deconv2(out)
        out    = self.deconv3(out)
        normal = self.est_normal(out)
        normal = torch.nn.functional.normalize(normal, 2, 1)
        return normal

class NENet(nn.Module):
    def __init__(self, fuse_type='max', batchNorm=False, c_in=3, other={}):
        super(NENet, self).__init__()
        self.extractor = FeatExtractor(batchNorm, c_in, other)
        self.regressor = Regressor(batchNorm, other)
        self.c_in      = c_in
        self.fuse_type = fuse_type
        self.other = other

        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
                kaiming_normal_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def prepareInputs(self, x):
        imgs = torch.split(x[0], 3, 1)
        idx = 1
        if self.other['in_light']: idx += 1
        if self.other['in_mask']:  idx += 1
        dirs = torch.split(x[idx]['dirs'], x[0].shape[0], 0)
        ints = torch.split(x[idx]['intens'], 3, 1)
        
        s2_inputs = []
        for i in range(len(imgs)):
            n, c, h, w = imgs[i].shape
            l_dir = dirs[i] if dirs[i].dim() == 4 else dirs[i].view(n, -1, 1, 1)
            l_int = torch.diag(1.0 / (ints[i].contiguous().view(-1)+1e-8))
            img   = imgs[i].contiguous().view(n * c, h * w)
            img   = torch.mm(l_int, img).view(n, c, h, w)
            img_light = torch.cat([img, l_dir.expand_as(img)], 1)
            s2_inputs.append(img_light)
        return s2_inputs

    def forward(self, x):
        inputs = self.prepareInputs(x)
        feats = torch.Tensor()
        for i in range(len(inputs)):
            feat, shape = self.extractor(inputs[i])
            if i == 0:
                feats = feat
            else:
                if self.fuse_type == 'mean':
                    feats = torch.stack([feats, feat], 1).sum(1)
                elif self.fuse_type == 'max':
                    feats, _ = torch.stack([feats, feat], 1).max(1)
        if self.fuse_type == 'mean':
            # the loop above accumulates a running sum; divide by image count
            feats = feats / len(inputs)
        feat_fused = feats
        normal = self.regressor(feat_fused, shape)
        pred = {}
        pred['n'] = normal
        return pred


================================================
FILE: models/__init__.py
================================================


================================================
FILE: models/custom_model.py
================================================
from . import model_utils
import torch

def buildModel(args):
    print('Creating Model %s' % (args.model))
    in_c = model_utils.getInputChanel(args)
    other = {
            'img_num':  args.in_img_num, 
            'test_h':   args.test_h,   'test_w':   args.test_w,
            'in_mask':  args.in_mask,  'in_light': args.in_light, 
            'dirs_cls': args.dirs_cls, 'ints_cls': args.ints_cls,
            's1_est_d': args.s1_est_d, 's1_est_i': args.s1_est_i, 's1_est_n': args.s1_est_n, 
            }
    models = __import__('models.' + args.model)
    model_file = getattr(models, args.model)
    model = getattr(model_file, args.model)(args.fuse_type, args.use_BN, in_c, other)

    if args.cuda: model = model.cuda()

    if args.retrain: 
        args.log.printWrite("=> using pre-trained model '{}'".format(args.retrain))
        model_utils.loadCheckpoint(args.retrain, model, cuda=args.cuda)

    if args.resume:
        args.log.printWrite("=> Resume loading checkpoint '{}'".format(args.resume))
        model_utils.loadCheckpoint(args.resume, model, cuda=args.cuda)
    print(model)
    args.log.printWrite("=> Model Parameters: %d" % (model_utils.get_n_params(model)))
    return model

def buildModelStage2(args):
    print('Creating Stage2 Model %s' % (args.model_s2))
    in_c = 6 if args.s2_in_light else 3
    other = {
            'img_num':  args.in_img_num,
            'in_mask':  args.in_mask,  'in_light': args.in_light, 
            'dirs_cls': args.dirs_cls, 'ints_cls': args.ints_cls,
            }
    models = __import__('models.' + args.model_s2)
    model_file = getattr(models, args.model_s2)
    model = getattr(model_file, args.model_s2)(args.fuse_type, args.use_BN, in_c, other)

    if args.cuda: model = model.cuda()

    if args.retrain_s2: 
        args.log.printWrite("=> using pre-trained model_s2 '{}'".format(args.retrain_s2))
        model_utils.loadCheckpoint(args.retrain_s2, model, cuda=args.cuda)

    print(model)
    args.log.printWrite("=> Stage2 Model Parameters: %d" % (model_utils.get_n_params(model)))
    return model
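Both builders resolve the model class dynamically from its name: `__import__('models.X')` returns the top-level `models` package, and two `getattr` calls walk down to the submodule and then to the class of the same name. The same pattern, sketched with stdlib modules and `importlib` for clarity (hypothetical helper, not repo code):

```python
import importlib

def build_by_name(package, module_name, attr_name):
    # equivalent to:
    #   pkg = __import__(package + '.' + module_name)
    #   getattr(getattr(pkg, module_name), attr_name)
    module = importlib.import_module(package + '.' + module_name)
    return getattr(module, attr_name)

# e.g. resolve os.path.join purely from strings
join = build_by_name('os', 'path', 'join')
print(join('a', 'b'))  # 'a/b' on POSIX
```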


================================================
FILE: models/model_utils.py
================================================
import os
import torch
import torch.nn as nn

def getInput(args, data):
    input_list = [data['img']]
    if args.in_light: input_list.append(data['dirs'])
    if args.in_mask:  input_list.append(data['m'])
    return input_list

def parseData(args, sample, timer=None, split='train'):
    img, normal, mask = sample['img'], sample['normal'], sample['mask']
    ints = sample['ints']
    if args.in_light:
        dirs = sample['dirs'].expand_as(img)
    else: # predict lighting, prepare ground truth
        n, c, h, w = sample['dirs'].shape
        dirs_split = torch.split(sample['dirs'].view(n, c), 3, 1)
        dirs = torch.cat(dirs_split, 0)
    if timer: timer.updateTime('ToCPU')
    if args.cuda:
        img, normal, mask = img.cuda(), normal.cuda(), mask.cuda()
        dirs, ints = dirs.cuda(), ints.cuda()
        if timer: timer.updateTime('ToGPU')
    data = {'img': img, 'n': normal, 'm': mask, 'dirs': dirs, 'ints': ints}
    return data 

def getInputChanel(args):
    args.log.printWrite('[Network Input] Color image as input')
    c_in = 3
    if args.in_light:
        args.log.printWrite('[Network Input] Adding Light direction as input')
        c_in += 3
    if args.in_mask:
        args.log.printWrite('[Network Input] Adding Mask as input')
        c_in += 1
    args.log.printWrite('[Network Input] Input channel: {}'.format(c_in))
    return c_in

def get_n_params(model):
    # total number of parameters across all parameter tensors
    total = 0
    for p in model.parameters():
        count = 1
        for s in p.size():
            count *= s
        total += count
    return total

def loadCheckpoint(path, model, cuda=True):
    if cuda:
        checkpoint = torch.load(path)
    else:
        checkpoint = torch.load(path, map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpoint['state_dict'])

def saveCheckpoint(save_path, epoch=-1, model=None, optimizer=None, records=None, args=None):
    state   = {'state_dict': model.state_dict(), 'model': args.model}
    records = {'epoch': epoch, 'optimizer':optimizer.state_dict(), 'records': records} # 'args': args}
    torch.save(state,   os.path.join(save_path, 'checkp_{}.pth.tar'.format(epoch)))
    torch.save(records, os.path.join(save_path, 'checkp_{}_rec.pth.tar'.format(epoch)))

def conv_ReLU(batchNorm, cin, cout, k=3, stride=1, pad=-1):
    pad = pad if pad >= 0 else (k - 1) // 2
    if batchNorm:
        print('=> convolutional layer with batchnorm')
        return nn.Sequential(
                nn.Conv2d(cin, cout, kernel_size=k, stride=stride, padding=pad, bias=False),
                nn.BatchNorm2d(cout),
                nn.ReLU(inplace=True)
                )
    else:
        return nn.Sequential(
                nn.Conv2d(cin, cout, kernel_size=k, stride=stride, padding=pad, bias=True),
                nn.ReLU(inplace=True)
                )

def conv(batchNorm, cin, cout, k=3, stride=1, pad=-1):
    pad = pad if pad >= 0 else (k - 1) // 2
    if batchNorm:
        print('=> convolutional layer with batchnorm')
        return nn.Sequential(
                nn.Conv2d(cin, cout, kernel_size=k, stride=stride, padding=pad, bias=False),
                nn.BatchNorm2d(cout),
                nn.LeakyReLU(0.1, inplace=True)
                )
    else:
        return nn.Sequential(
                nn.Conv2d(cin, cout, kernel_size=k, stride=stride, padding=pad, bias=True),
                nn.LeakyReLU(0.1, inplace=True)
                )

def outputConv(cin, cout, k=3, stride=1, pad=1):
    return nn.Sequential(
            nn.Conv2d(cin, cout, kernel_size=k, stride=stride, padding=pad, bias=True))

def deconv(cin, cout):
    return nn.Sequential(
            nn.ConvTranspose2d(cin, cout, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.1, inplace=True)
            )

def upconv(cin, cout):
    return nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(cin, cout, kernel_size=3, stride=1, padding=1, bias=False),
            nn.LeakyReLU(0.1, inplace=True)
            )
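`get_n_params` above multiplies out each parameter tensor's shape and sums the results. A quick sanity check of that counting logic against a layer whose parameter count is known in advance (standalone sketch, not repo code):

```python
import torch.nn as nn

def count_params(model):
    # same logic as model_utils.get_n_params
    total = 0
    for p in model.parameters():
        n = 1
        for s in p.size():
            n *= s
        total += n
    return total

conv = nn.Conv2d(3, 8, kernel_size=3, bias=True)
# weight: 8 * 3 * 3 * 3 = 216, bias: 8  ->  224 parameters
print(count_params(conv))  # 224
```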


================================================
FILE: models/solver_utils.py
================================================
import torch
import os
from utils import eval_utils

class Stage1ClsCrit(object): # First Stage, Light classification criterion
    def __init__(self, args):
        print('==> Using Stage1ClsCrit for lighting classification')
        self.s1_est_d = args.s1_est_d
        self.s1_est_i = args.s1_est_i
        self.ints_cls, self.dirs_cls = args.ints_cls, args.dirs_cls
        self.setupLightCrit(args)

    def setupLightCrit(self, args):
        args.log.printWrite('=> Using light criterion')
        if self.s1_est_d:
            self.dir_w = args.dir_w
            self.dirs_x_crit = torch.nn.CrossEntropyLoss()
            self.dirs_y_crit = torch.nn.CrossEntropyLoss()
            if args.cuda: 
                self.dirs_x_crit = self.dirs_x_crit.cuda()
                self.dirs_y_crit = self.dirs_y_crit.cuda()
        if self.s1_est_i:
            self.ints_w = args.ints_w
            self.ints_crit = torch.nn.CrossEntropyLoss()
            if args.cuda: self.ints_crit = self.ints_crit.cuda()

    def forward(self, output, target):
        self.loss = 0
        out_loss = {}
        if self.s1_est_d:
            est_dir_x, est_dir_y = output['dirs_x'], output['dirs_y']
            gt_dir_x, gt_dir_y = eval_utils.SphericalDirsToClass(target['dirs'], self.dirs_cls)
        
            dirs_x_loss = self.dirs_x_crit(est_dir_x, gt_dir_x)
            dirs_y_loss = self.dirs_y_crit(est_dir_y, gt_dir_y)

            out_loss['D_x_loss'] = dirs_x_loss.item()
            out_loss['D_y_loss'] = dirs_y_loss.item()
            self.loss += self.dir_w * (dirs_x_loss + dirs_y_loss)

        if self.s1_est_i:
            est_intens = output['ints']
            gt_ints = target['ints'].squeeze()[:, 0: target['ints'].shape[1]:3]
            gt_ints = torch.cat(torch.split(gt_ints, 1, 1), 0)
            gt_intens = eval_utils.LightIntsToClass(gt_ints, self.ints_cls)
            ints_loss = self.ints_crit(est_intens, gt_intens)
            out_loss['I_loss'] = ints_loss.item()
            self.loss += self.ints_w * ints_loss
        return out_loss
     
    def backward(self):
        self.loss.backward()

class Stage2Crit(object): # Second stage
    def __init__(self, args):
        self.s2_est_n = args.s2_est_n 
        self.s2_est_d = args.s2_est_d
        self.s2_est_i = args.s2_est_i
        self.setupLightCrit(args)
        if self.s2_est_n:
            self.setupNormalCrit(args)

    def setupLightCrit(self, args):
        args.log.printWrite('=> Using light criterion')
        if self.s2_est_d:
            self.dir_w = args.dir_w
            self.dirs_crit = torch.nn.CosineEmbeddingLoss()
            if args.cuda: self.dirs_crit = self.dirs_crit.cuda()
        if self.s2_est_i:
            self.ints_w = args.ints_w
            self.ints_crit = torch.nn.MSELoss()
            if args.cuda: self.ints_crit = self.ints_crit.cuda()

    def setupNormalCrit(self, args):
        args.log.printWrite('=> Using {} criterion for normal'.format(args.normal_loss))
        self.normal_w = args.normal_w
        if args.normal_loss == 'mse':
            self.n_crit = torch.nn.MSELoss()
        elif args.normal_loss == 'cos':
            self.n_crit = torch.nn.CosineEmbeddingLoss()
        else:
            raise Exception("=> Unknown Criterion '{}'".format(args.normal_loss))
        if args.cuda:
            self.n_crit = self.n_crit.cuda()

    def forward(self, output, target):
        self.loss = 0
        out_loss = {}

        if self.s2_est_d:
            d_est, d_tar = output['dirs'], target['dirs']
            d_num = d_tar.nelement() // d_tar.shape[1]
            if not hasattr(self, 'l_flag') or d_num != self.l_flag.nelement():
                self.l_flag = d_tar.data.new().resize_(d_num).fill_(1)
            dirs_loss = self.dirs_crit(d_est.squeeze(), d_tar, self.l_flag)
            self.loss += self.dir_w * dirs_loss
            out_loss['D_loss'] = dirs_loss.item()

        if self.s2_est_i:
            i_est, i_tar = output['ints'], target['ints']
            ints_loss  = self.ints_crit(i_est, i_tar.squeeze())
            self.loss += self.ints_w * ints_loss
            out_loss['I_loss'] = ints_loss.item()

        if self.s2_est_n:
            n_est, n_tar = output['n'], target['n']
            n_num = n_tar.nelement() // n_tar.shape[1]
            if not hasattr(self, 'n_flag') or n_num != self.n_flag.nelement():
                self.n_flag = n_tar.data.new().resize_(n_num).fill_(1)
            self.out_reshape = n_est.permute(0, 2, 3, 1).contiguous().view(-1, 3)
            self.gt_reshape  = n_tar.permute(0, 2, 3, 1).contiguous().view(-1, 3)
            normal_loss      = self.n_crit(self.out_reshape, self.gt_reshape, self.n_flag)
            self.loss += self.normal_w * normal_loss 
            out_loss['N_loss'] = normal_loss.item()
        return out_loss

    def backward(self):
        self.loss.backward()

def getOptimizer(args, params):
    args.log.printWrite('=> Using %s solver for optimization' % (args.solver))
    if args.solver == 'adam':
        optimizer = torch.optim.Adam(params, args.init_lr, betas=(args.beta_1, args.beta_2))
    elif args.solver == 'sgd':
        optimizer = torch.optim.SGD(params, args.init_lr, momentum=args.momentum)
    else:
        raise Exception("=> Unknown Optimizer %s" % (args.solver))
    return optimizer

def getLrScheduler(args, optimizer):
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, 
            milestones=args.milestones, gamma=args.lr_decay, last_epoch=args.start_epoch-2)
    return scheduler

def loadRecords(path, model, optimizer):
    records = None
    if os.path.isfile(path):
        records = torch.load(path[:-8] + '_rec' + path[-8:])
        optimizer.load_state_dict(records['optimizer'])
        start_epoch = records['epoch'] + 1
        records = records['records']
        print("=> loaded Records")
    else:
        raise Exception("=> no checkpoint found at '{}'".format(path))
    return records, start_epoch

def configOptimizer(args, model):
    records = None
    optimizer = getOptimizer(args, model.parameters())
    if args.resume:
        args.log.printWrite("=> Resume loading checkpoint '{}'".format(args.resume))
        records, start_epoch = loadRecords(args.resume, model, optimizer)
        args.start_epoch = start_epoch
    scheduler = getLrScheduler(args, optimizer)
    return optimizer, scheduler, records
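`getLrScheduler` wires up `MultiStepLR` so the learning rate is multiplied by `lr_decay` at each epoch listed in `milestones`; `last_epoch=args.start_epoch-2` lets a resumed run pick up the schedule where it left off (the training loops call `scheduler.step()` at the top of each epoch, the pre-1.1 PyTorch convention). A small sketch of the resulting decay (made-up milestones and rates):

```python
import torch

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([param], lr=0.1)
# halve the learning rate at epochs 3 and 5
sched = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=[3, 5], gamma=0.5)

lrs = []
for epoch in range(6):
    lrs.append(opt.param_groups[0]['lr'])
    sched.step()
print(lrs)  # [0.1, 0.1, 0.1, 0.05, 0.05, 0.025]
```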


================================================
FILE: options/__init__.py
================================================


================================================
FILE: options/base_opts.py
================================================
import argparse
import os
import torch

class BaseOpts(object):
    def __init__(self):
        self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    def initialize(self):
        #### Training Dataset ####
        self.parser.add_argument('--dataset',     default='UPS_Synth_Dataset')
        self.parser.add_argument('--data_dir',    default='data/datasets/PS_Blobby_Dataset')
        self.parser.add_argument('--data_dir2',   default='data/datasets/PS_Sculpture_Dataset')
        self.parser.add_argument('--concat_data', default=True, action='store_false')
        self.parser.add_argument('--l_suffix',    default='_mtrl.txt')

        #### Training Data and Preprocessing Arguments ####
        self.parser.add_argument('--rescale',     default=True,  action='store_false')
        self.parser.add_argument('--rand_sc',     default=True,  action='store_false')
        self.parser.add_argument('--scale_h',     default=128,   type=int)
        self.parser.add_argument('--scale_w',     default=128,   type=int)
        self.parser.add_argument('--crop',        default=True,  action='store_false')
        self.parser.add_argument('--crop_h',      default=128,   type=int)
        self.parser.add_argument('--crop_w',      default=128,   type=int)
        self.parser.add_argument('--test_h',      default=128,   type=int)
        self.parser.add_argument('--test_w',      default=128,   type=int)
        self.parser.add_argument('--test_resc',   default=True,  action='store_false')
        self.parser.add_argument('--int_aug',     default=True,  action='store_false')
        self.parser.add_argument('--noise_aug',   default=True,  action='store_false')
        self.parser.add_argument('--noise',       default=0.05,  type=float)
        self.parser.add_argument('--color_aug',   default=True,  action='store_false')
        self.parser.add_argument('--color_ratio', default=3,     type=float)
        self.parser.add_argument('--normalize',   default=False, action='store_true')

        #### Device Arguments ####
        self.parser.add_argument('--cuda',        default=True,  action='store_false')
        self.parser.add_argument('--multi_gpu',   default=False, action='store_true')
        self.parser.add_argument('--time_sync',   default=False, action='store_true')
        self.parser.add_argument('--workers',     default=8,     type=int)
        self.parser.add_argument('--seed',        default=0,     type=int)

        #### Stage 1 Model Arguments ####
        self.parser.add_argument('--dirs_cls',    default=36,    type=int)
        self.parser.add_argument('--ints_cls',    default=20,    type=int)
        self.parser.add_argument('--dir_int',     default=False, action='store_true')
        self.parser.add_argument('--model',       default='LCNet')
        self.parser.add_argument('--fuse_type',   default='max')
        self.parser.add_argument('--in_img_num',  default=32,    type=int)
        self.parser.add_argument('--s1_est_n',    default=False, action='store_true')
        self.parser.add_argument('--s1_est_d',    default=True,  action='store_false')
        self.parser.add_argument('--s1_est_i',    default=True,  action='store_false')
        self.parser.add_argument('--in_light',    default=False, action='store_true')
        self.parser.add_argument('--in_mask',     default=True,  action='store_false')
        self.parser.add_argument('--use_BN',      default=False, action='store_true')
        self.parser.add_argument('--resume',      default=None)
        self.parser.add_argument('--retrain',     default=None)
        self.parser.add_argument('--save_intv',   default=1,     type=int)

        #### Stage 2 Model Arguments ####
        self.parser.add_argument('--stage2',      default=False, action='store_true')
        self.parser.add_argument('--model_s2',    default='NENet')
        self.parser.add_argument('--retrain_s2',  default=None)
        self.parser.add_argument('--s2_est_n',    default=True,  action='store_false')
        self.parser.add_argument('--s2_est_i',    default=False, action='store_true')
        self.parser.add_argument('--s2_est_d',    default=False, action='store_true')
        self.parser.add_argument('--s2_in_light', default=True,  action='store_false')

        #### Displaying Arguments ####
        self.parser.add_argument('--train_disp',    default=20,  type=int)
        self.parser.add_argument('--train_save',    default=200, type=int)
        self.parser.add_argument('--val_intv',      default=1,   type=int)
        self.parser.add_argument('--val_disp',      default=1,   type=int)
        self.parser.add_argument('--val_save',      default=1,   type=int)
        self.parser.add_argument('--max_train_iter',default=-1,  type=int)
        self.parser.add_argument('--max_val_iter',  default=-1,  type=int)
        self.parser.add_argument('--max_test_iter', default=-1,  type=int)
        self.parser.add_argument('--train_save_n',  default=4,   type=int)
        self.parser.add_argument('--test_save_n',   default=4,   type=int)

        #### Log Arguments ####
        self.parser.add_argument('--save_root',  default='data/logdir/')
        self.parser.add_argument('--item',       default='CVPR2019')
        self.parser.add_argument('--suffix',     default=None)
        self.parser.add_argument('--debug',      default=False, action='store_true')
        self.parser.add_argument('--make_dir',   default=True,  action='store_false')
        self.parser.add_argument('--save_split', default=False, action='store_true')

    def setDefault(self):
        if self.args.debug:
            self.args.train_disp = 1
            self.args.train_save = 1
            self.args.max_train_iter = 4 
            self.args.max_val_iter = 4
            self.args.max_test_iter = 4
            self.args.test_intv = 1

    def collectInfo(self):
        self.args.str_keys  = [
                'model', 'fuse_type', 'solver'
                ]
        self.args.val_keys  = [
                'batch', 'scale_h', 'crop_h', 'init_lr', 'normal_w', 
                'dir_w', 'ints_w', 'in_img_num', 'dirs_cls', 'ints_cls'
                ]
        self.args.bool_keys = [
                'use_BN', 'in_light', 'in_mask', 's1_est_n', 's1_est_d', 's1_est_i', 
                'color_aug', 'int_aug', 'concat_data', 'retrain', 'resume', 'stage2', 
                ] 

    def parse(self):
        self.args = self.parser.parse_args()
        return self.args
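Note on the flag conventions above: options such as `--use_BN` use `default=False, action='store_true'` (off unless passed), while options such as `--in_mask` use `default=True, action='store_false'` (on unless passed, so supplying the flag disables the feature). A minimal standalone sketch of the two patterns, not part of the repo:

```python
import argparse

# default=False + store_true turns a feature ON when the flag is passed;
# default=True + store_false turns a feature OFF when the flag is passed.
parser = argparse.ArgumentParser()
parser.add_argument('--use_BN',  default=False, action='store_true')
parser.add_argument('--in_mask', default=True,  action='store_false')

defaults = parser.parse_args([])            # no flags: defaults apply
print(defaults.use_BN, defaults.in_mask)    # False True

flipped = parser.parse_args(['--use_BN', '--in_mask'])  # both flags flip
print(flipped.use_BN, flipped.in_mask)      # True False
```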


================================================
FILE: options/run_model_opts.py
================================================
from .base_opts import BaseOpts
class RunModelOpts(BaseOpts):
    def __init__(self):
        super(RunModelOpts, self).__init__()
        self.initialize()

    def initialize(self):
        BaseOpts.initialize(self)
        #### Testing Dataset Arguments #### 
        self.parser.add_argument('--run_model',  default=True, action='store_false')
        self.parser.add_argument('--benchmark',  default='UPS_DiLiGenT_main')
        self.parser.add_argument('--bm_dir',     default='data/datasets/DiLiGenT/pmsData_crop')
        self.parser.add_argument('--epochs',     default=1,   type=int)
        self.parser.add_argument('--test_batch', default=1,   type=int)
        self.parser.add_argument('--test_disp',  default=1,   type=int)
        self.parser.add_argument('--test_save',  default=1,   type=int)
        
        ### For UPS_Custom_Dataset.py, to test your own dataset
        self.parser.add_argument('--have_l_dirs', default=False, action='store_true', help='Have light directions?')
        self.parser.add_argument('--have_l_ints', default=False, action='store_true', help='Have light intensities?')
        self.parser.add_argument('--have_gt_n',   default=False, action='store_true', help='Have GT surface normals?')

    def collectInfo(self):
        self.args.str_keys  = ['model', 'model_s2', 'benchmark', 'fuse_type']
        self.args.val_keys  = ['in_img_num', 'test_h', 'test_w']
        self.args.bool_keys = ['int_aug', 'test_resc']

    def setDefault(self):
        self.collectInfo()

    def parse(self):
        BaseOpts.parse(self)
        self.setDefault()
        return self.args


================================================
FILE: options/stage1_opts.py
================================================
from .base_opts import BaseOpts
class TrainOpts(BaseOpts):
    def __init__(self):
        super(TrainOpts, self).__init__()
        self.initialize()

    def initialize(self):
        BaseOpts.initialize(self)
        #### Training Arguments ####
        self.parser.add_argument('--solver',      default='adam', help='adam|sgd')
        self.parser.add_argument('--milestones',  default=[5, 10, 15, 20, 25], nargs='+', type=int)
        self.parser.add_argument('--start_epoch', default=1,      type=int)
        self.parser.add_argument('--epochs',      default=20,     type=int)
        self.parser.add_argument('--batch',       default=32,     type=int)
        self.parser.add_argument('--val_batch',   default=8,      type=int)
        self.parser.add_argument('--init_lr',     default=0.0005, type=float)
        self.parser.add_argument('--lr_decay',    default=0.5,    type=float)
        self.parser.add_argument('--beta_1',      default=0.9,    type=float, help='adam')
        self.parser.add_argument('--beta_2',      default=0.999,  type=float, help='adam')
        self.parser.add_argument('--momentum',    default=0.9,    type=float, help='sgd')
        self.parser.add_argument('--w_decay',     default=4e-4,   type=float)

        #### Loss Arguments ####
        self.parser.add_argument('--normal_loss', default='cos',  help='cos|mse')
        self.parser.add_argument('--normal_w',    default=1,      type=float)
        self.parser.add_argument('--dir_loss',    default='cos',  help='cos|mse')
        self.parser.add_argument('--dir_w',       default=1,      type=float)
        self.parser.add_argument('--ints_loss',   default='mse',  help='l1|mse')
        self.parser.add_argument('--ints_w',      default=1,      type=float)

    def collectInfo(self): 
        BaseOpts.collectInfo(self)
        self.args.str_keys  += [
                'dir_loss',
                ]
        self.args.val_keys  += [
                ]
        self.args.bool_keys += [
                ] 

    def setDefault(self):
        BaseOpts.setDefault(self)
        if self.args.test_h != self.args.crop_h:
            self.args.test_h, self.args.test_w = self.args.crop_h, self.args.crop_w
        self.collectInfo()

    def parse(self):
        BaseOpts.parse(self)
        self.setDefault()
        return self.args


================================================
FILE: options/stage2_opts.py
================================================
from .base_opts import BaseOpts
class TrainOpts(BaseOpts):
    def __init__(self):
        super(TrainOpts, self).__init__()
        self.initialize()

    def initialize(self):
        BaseOpts.initialize(self)
        #### Training Arguments ####
        self.parser.add_argument('--solver',      default='adam', help='adam|sgd')
        self.parser.add_argument('--milestones',  default=[2, 4, 6, 8, 10], nargs='+', type=int)
        self.parser.add_argument('--start_epoch', default=1,      type=int)
        self.parser.add_argument('--epochs',      default=10,     type=int)
        self.parser.add_argument('--batch',       default=16,     type=int)
        self.parser.add_argument('--val_batch',   default=8,      type=int)
        self.parser.add_argument('--init_lr',     default=0.0005, type=float)
        self.parser.add_argument('--lr_decay',    default=0.5,    type=float)
        self.parser.add_argument('--beta_1',      default=0.9,    type=float, help='adam')
        self.parser.add_argument('--beta_2',      default=0.999,  type=float, help='adam')
        self.parser.add_argument('--momentum',    default=0.9,    type=float, help='sgd')
        self.parser.add_argument('--w_decay',     default=4e-4,   type=float)

        #### Loss Arguments ####
        self.parser.add_argument('--normal_loss', default='cos',  help='cos|mse')
        self.parser.add_argument('--normal_w',    default=1,      type=float)
        self.parser.add_argument('--dir_loss',    default='mse',  help='cos|mse')
        self.parser.add_argument('--dir_w',       default=1,      type=float)
        self.parser.add_argument('--ints_loss',   default='mse',  help='mse')
        self.parser.add_argument('--ints_w',      default=1,      type=float)

    def collectInfo(self): 
        BaseOpts.collectInfo(self)
        self.args.str_keys += [
                'model_s2'
                ]
        self.args.val_keys  += [
                ]
        self.args.bool_keys += [
                ]

    def setDefault(self):
        BaseOpts.setDefault(self)
        self.args.stage2    = True
        self.args.test_resc = False
        self.collectInfo()

    def parse(self):
        BaseOpts.parse(self)
        self.setDefault()
        return self.args


================================================
FILE: scripts/DiLiGenT_objects.txt
================================================
ballPNG
catPNG
pot1PNG
bearPNG
pot2PNG
buddhaPNG
gobletPNG
readingPNG
cowPNG
harvestPNG


================================================
FILE: scripts/cropDiLiGenTData.py
================================================
import os, argparse, sys, shutil, glob
import numpy as np
from imageio import imread, imsave
import scipy.io as sio

root_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '../')
sys.path.append(root_path)
from utils import utils

parser = argparse.ArgumentParser()
parser.add_argument('--input_dir',  default='data/datasets/DiLiGenT/pmsData')
parser.add_argument('--obj_list',   default='objects.txt')
parser.add_argument('--suffix',     default='crop')
parser.add_argument('--file_ext',   default='.png')
parser.add_argument('--normal_name',default='Normal_gt.png')
parser.add_argument('--mask_name',  default='mask.png')
parser.add_argument('--n_key',      default='Normal_gt')
parser.add_argument('--pad',        default=15, type=int)
args = parser.parse_args()

def getSaveDir():
    dirName  = os.path.dirname(args.input_dir)
    save_dir = '%s_%s' % (args.input_dir, args.suffix) 
    utils.makeFile(save_dir)
    print('Output dir: %s\n' % save_dir)
    return save_dir

def getBBoxCompact(mask):
    index = np.where(mask != 0)
    t, b, l , r = index[0].min(), index[0].max(), index[1].min(), index[1].max()
    h, w = b - t + 2 * args.pad, r - l + 2 * args.pad
    t = max(0, t - args.pad)
    b = t + h 
    l = max(0, l - args.pad)
    r = l + w 
    if h % 4 != 0: 
        pad = 4 - h % 4
        b += pad; h += pad
    if w % 4 != 0: 
        pad = 4 - w % 4
        r += pad; w += pad
    return l, r, t, b, h, w

def loadMaskNormal(d):
    mask   = imread(os.path.join(args.input_dir, d, args.mask_name))
    try:
        normal = imread(os.path.join(args.input_dir, d, args.normal_name))
    except IOError:
        normal = imread(os.path.join(args.input_dir, d, 'normal_gt.png'))
    n_mat  = sio.loadmat(os.path.join(args.input_dir, d, 'Normal_gt.mat'))[args.n_key]
    h, w, c = normal.shape
    print('Processing Objects: %s' % d, mask.shape)
    if mask.ndim < 3:
        mask = mask.reshape(h, w, 1).repeat(3, 2)
    return mask, normal, n_mat

def copyTXT(d):
    txt = glob.glob(os.path.join(args.input_dir, '*.txt'))
    for t in txt:
        name = os.path.basename(t)
        shutil.copy(t, os.path.join(args.save_dir, name))

    txt = glob.glob(os.path.join(args.input_dir, d, '*.txt'))
    for t in txt:
        name = os.path.basename(t)
        shutil.copy(t, os.path.join(args.save_dir, d, name))

if __name__ == '__main__':
    print('Input dir: %s\n' % args.input_dir)
    args.save_dir = getSaveDir()

    dir_list  = utils.readList(os.path.join(args.input_dir, args.obj_list))
    name_list = utils.readList(os.path.join(args.input_dir, 'names.txt'))
    max_h, max_w = 0, 0
    crop_list = open(os.path.join(args.save_dir, 'crop.txt'), 'w')
    for d in dir_list:
        utils.makeFile(os.path.join(args.save_dir, d))
        mask, normal, n_mat = loadMaskNormal(d)
        l, r, t, b, h, w = getBBoxCompact(mask[:,:,0] / 255)
        crop_list.write('%d %d %d %d %d %d\n' % (mask.shape[0], mask.shape[1], l, r, t, b))

        max_h = h if h > max_h else max_h
        max_w = w if w > max_w else max_w
        print('\t BBox L %d R %d T %d B %d, H:%d W:%d, Padded: %d %d' % 
                (l, r, t, b, h, w, r - l, b - t)) 
        imsave(os.path.join(args.save_dir, d, args.mask_name), mask[t:b, l:r, :])
        imsave(os.path.join(args.save_dir, d, args.normal_name), normal[t:b, l:r, :])
        sio.savemat(os.path.join(args.save_dir, d, 'Normal_gt.mat'), 
                {args.n_key: n_mat[t:b, l:r, :]} ,do_compression=True)
        copyTXT(d)
        intens = np.genfromtxt(os.path.join(args.input_dir, d, 'light_intensities.txt'))
        for idx, name in enumerate(name_list):
            #print('Process img %d/%d' % (idx+1, len(name_list)))
            img = imread(os.path.join(args.input_dir, d, name))
            img = img[t:b, l:r, :]
            imsave(os.path.join(args.save_dir, d, name), img.astype(np.uint8))
    print('Max H %d, Max W %d' % (max_h, max_w))
    crop_list.close()
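`getBBoxCompact` above pads the mask's tight bounding box by `--pad` pixels on each side, then grows the crop so its height and width are multiples of 4. A toy numpy sketch of that rounding logic (not repo code, using a made-up mask):

```python
import numpy as np

# Sketch of getBBoxCompact's rounding: pad the tight bounding box,
# then grow it so height and width are multiples of 4.
def bbox_multiple_of_4(mask, pad=15):
    ys, xs = np.where(mask != 0)
    t, b, l, r = ys.min(), ys.max(), xs.min(), xs.max()
    h, w = b - t + 2 * pad, r - l + 2 * pad
    t, l = max(0, t - pad), max(0, l - pad)
    if h % 4: h += 4 - h % 4   # round height up to a multiple of 4
    if w % 4: w += 4 - w % 4   # round width up to a multiple of 4
    return t, t + h, l, l + w  # top, bottom, left, right of the crop

mask = np.zeros((100, 100), dtype=np.uint8)
mask[40:50, 30:45] = 1                     # 10x15 foreground blob
t, b, l, r = bbox_multiple_of_4(mask, pad=3)
print(b - t, r - l)                        # 16 20, both multiples of 4
```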


================================================
FILE: scripts/download_pretrained_models.sh
================================================
path="data/models/"
mkdir -p $path
cd $path

# Download pre-trained model
for model in "LCNet_CVPR2019.pth.tar" "NENet_CVPR2019.pth.tar"; do
    wget http://www.visionlab.cs.hku.hk/data/SDPS-Net/models/${model}
done

# Back to root directory
cd ../../


================================================
FILE: scripts/download_synthetic_datasets.sh
================================================
mkdir -p data/datasets
cd data/datasets

# Download Synthetic dataset
for dataset in "PS_Sculpture_Dataset.tgz" "PS_Blobby_Dataset.tgz"; do
    echo "Downloading $dataset"
    wget http://www.visionlab.cs.hku.hk/data/PS-FCN/datasets/$dataset
    tar -xvf $dataset
    rm $dataset
done

# Back to root directory
cd ../../



================================================
FILE: scripts/prepare_diligent_dataset.sh
================================================
mkdir -p data/datasets
cd data/datasets

## Download real testing dataset
url="https://www.dropbox.com/s/hdnbh526tyvv68i/DiLiGenT.zip?dl=0"
name="DiLiGenT"

wget $url -O ${name}.zip
unzip ${name}.zip
rm ${name}.zip

cd ${name}/pmsData/
cp ballPNG/filenames.txt names.txt

# Back to root directory
cd ../../../../
cp scripts/DiLiGenT_objects.txt data/datasets/${name}/pmsData/objects.txt
python scripts/cropDiLiGenTData.py


================================================
FILE: test_stage1.py
================================================
import os
import torch
from models import model_utils
from utils import eval_utils, time_utils 
import numpy as np

def get_itervals(args, split):
    if split not in ['train', 'val', 'test']:
        split = 'test'
    args_var = vars(args)
    disp_intv = args_var[split+'_disp']
    save_intv = args_var[split+'_save']
    stop_iters = args_var['max_'+split+'_iter']
    return disp_intv, save_intv, stop_iters

def test(args, split, loader, model, log, epoch, recorder):
    model.eval()
    log.printWrite('---- Start %s Epoch %d: %d batches ----' % (split, epoch, len(loader)))
    timer = time_utils.Timer(args.time_sync);

    disp_intv, save_intv, stop_iters = get_itervals(args, split)
    res = []
    with torch.no_grad():
        for i, sample in enumerate(loader):
            data = model_utils.parseData(args, sample, timer, split)
            input = model_utils.getInput(args, data)

            pred = model(input); timer.updateTime('Forward')

            recorder, iter_res, error = prepareRes(args, data, pred, recorder, log, split)

            res.append(iter_res)
            iters = i + 1
            if iters % disp_intv == 0:
                opt = {'split':split, 'epoch':epoch, 'iters':iters, 'batch':len(loader), 
                        'timer':timer, 'recorder': recorder}
                log.printItersSummary(opt)

            if iters % save_intv == 0:
                results, nrow = prepareSave(args, data, pred)
                log.saveImgResults(results, split, epoch, iters, nrow=nrow, error=error)
                log.plotCurves(recorder, split, epoch=epoch, intv=disp_intv)

            if stop_iters > 0 and iters >= stop_iters: break
    res = np.vstack([np.array(res), np.array(res).mean(0)])
    save_name = '%s_res.txt' % (args.suffix)
    np.savetxt(os.path.join(args.log_dir, split, save_name), res, fmt='%.2f')
    opt = {'split': split, 'epoch': epoch, 'recorder': recorder}
    log.printEpochSummary(opt)

def prepareRes(args, data, pred, recorder, log, split):
    data_batch = args.val_batch if split == 'val' else args.test_batch
    iter_res = []
    error = ''
    if args.s1_est_d:
        l_acc, data['dir_err'] = eval_utils.calDirsAcc(data['dirs'].data, pred['dirs'].data, data_batch)
        recorder.updateIter(split, l_acc.keys(), l_acc.values())
        iter_res.append(l_acc['l_err_mean'])
        error += 'D_%.3f-' % (l_acc['l_err_mean']) 
    if args.s1_est_i:
        int_acc, data['int_err'] = eval_utils.calIntsAcc(data['ints'].data, pred['intens'].data, data_batch)
        recorder.updateIter(split, int_acc.keys(), int_acc.values())
        iter_res.append(int_acc['ints_ratio'])
        error += 'I_%.3f-' % (int_acc['ints_ratio'])

    if args.s1_est_n:
        acc, error_map = eval_utils.calNormalAcc(data['n'].data, pred['n'].data, data['m'].data)
        recorder.updateIter(split, acc.keys(), acc.values())
        iter_res.append(acc['n_err_mean'])
        error += 'N_%.3f-' % (acc['n_err_mean'])
        data['error_map'] = error_map['angular_map']

    return recorder, iter_res, error


def prepareSave(args, data, pred):
    results = [data['img'].data, data['m'].data, (data['n'].data+1)/2]
    if args.s1_est_n:
        pred_n = (pred['n'].data + 1) / 2
        masked_pred = pred_n * data['m'].data.expand_as(pred['n'].data)
        res_n = [pred_n, masked_pred, data['error_map']]
        results += res_n

    nrow = data['img'].shape[0]
    return results, nrow
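`get_itervals` above fetches per-split settings by composing the attribute name from the split string via `vars(args)`, with unknown split names falling back to the test settings. A small sketch using a `Namespace` with hypothetical values:

```python
from argparse import Namespace

# Hypothetical per-split settings, mirroring the <split>_disp / <split>_save /
# max_<split>_iter naming convention used by the options modules.
args = Namespace(val_disp=1, val_save=1, max_val_iter=-1,
                 test_disp=1, test_save=1, max_test_iter=8)

def get_itervals(args, split):
    if split not in ['train', 'val', 'test']:
        split = 'test'                      # unknown splits use test settings
    args_var = vars(args)                   # Namespace -> dict of attributes
    return (args_var[split + '_disp'],
            args_var[split + '_save'],
            args_var['max_' + split + '_iter'])

print(get_itervals(args, 'benchmark'))      # -> (1, 1, 8), the test settings
```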


================================================
FILE: test_stage2.py
================================================
import os
import torch
from models import model_utils
from utils import eval_utils, time_utils 
import numpy as np

def get_itervals(args, split):
    if split not in ['train', 'val', 'test']:
        split = 'test'
    args_var = vars(args)
    disp_intv = args_var[split+'_disp']
    save_intv = args_var[split+'_save']
    stop_iters = args_var['max_'+split+'_iter']
    return disp_intv, save_intv, stop_iters

def test(args, split, loader, models, log, epoch, recorder):
    models[0].eval()
    models[1].eval()
    log.printWrite('---- Start %s Epoch %d: %d batches ----' % (split, epoch, len(loader)))
    timer = time_utils.Timer(args.time_sync);

    disp_intv, save_intv, stop_iters = get_itervals(args, split)
    res = []
    with torch.no_grad():
        for i, sample in enumerate(loader):
            data = model_utils.parseData(args, sample, timer, split)
            input = model_utils.getInput(args, data)

            pred_c = models[0](input); timer.updateTime('Forward')
            input.append(pred_c)
            pred = models[1](input); timer.updateTime('Forward')

            recorder, iter_res, error = prepareRes(args, data, pred_c, pred, recorder, log, split)

            res.append(iter_res)
            iters = i + 1
            if iters % disp_intv == 0:
                opt = {'split':split, 'epoch':epoch, 'iters':iters, 'batch':len(loader), 
                        'timer':timer, 'recorder': recorder}
                log.printItersSummary(opt)

            if iters % save_intv == 0:
                results, nrow = prepareSave(args, data, pred_c, pred)
                log.saveImgResults(results, split, epoch, iters, nrow=nrow, error='')
                log.plotCurves(recorder, split, epoch=epoch, intv=disp_intv)

            if stop_iters > 0 and iters >= stop_iters: break
    res = np.vstack([np.array(res), np.array(res).mean(0)])
    save_name = '%s_res.txt' % (args.suffix)
    np.savetxt(os.path.join(args.log_dir, split, save_name), res, fmt='%.2f')
    if res.ndim > 1:
        for i in range(res.shape[1]):
            save_name = '%s_%d_res.txt' % (args.suffix, i)
            np.savetxt(os.path.join(args.log_dir, split, save_name), res[:,i], fmt='%.3f')

    opt = {'split': split, 'epoch': epoch, 'recorder': recorder}
    log.printEpochSummary(opt)

def prepareRes(args, data, pred_c, pred, recorder, log, split):
    data_batch = args.val_batch if split == 'val' else args.test_batch
    iter_res = []
    error = ''
    if args.s1_est_d:
        l_acc, data['dir_err'] = eval_utils.calDirsAcc(data['dirs'].data, pred_c['dirs'].data, data_batch)
        recorder.updateIter(split, l_acc.keys(), l_acc.values())
        iter_res.append(l_acc['l_err_mean'])
        error += 'D_%.3f-' % (l_acc['l_err_mean']) 
    if args.s1_est_i:
        int_acc, data['int_err'] = eval_utils.calIntsAcc(data['ints'].data, pred_c['intens'].data, data_batch)
        recorder.updateIter(split, int_acc.keys(), int_acc.values())
        iter_res.append(int_acc['ints_ratio'])
        error += 'I_%.3f-' % (int_acc['ints_ratio'])

    if args.s2_est_n:
        acc, error_map = eval_utils.calNormalAcc(data['n'].data, pred['n'].data, data['m'].data)
        recorder.updateIter(split, acc.keys(), acc.values())
        iter_res.append(acc['n_err_mean'])
        error += 'N_%.3f-' % (acc['n_err_mean'])
        data['error_map'] = error_map['angular_map']
    return recorder, iter_res, error

def prepareSave(args, data, pred_c, pred):
    results = [data['img'].data, data['m'].data, (data['n'].data+1) / 2]
    if args.s2_est_n:
        pred_n = (pred['n'].data + 1) / 2
        masked_pred = pred_n * data['m'].data.expand_as(pred['n'].data)
        res_n = [masked_pred, data['error_map']]
        results += res_n

    nrow = data['img'].shape[0]
    return results, nrow


================================================
FILE: train_stage1.py
================================================
from models import model_utils
from utils  import eval_utils, time_utils

def train(args, loader, model, criterion, optimizer, log, epoch, recorder):
    model.train()
    log.printWrite('---- Start Training Epoch %d: %d batches ----' % (epoch, len(loader)))
    timer = time_utils.Timer(args.time_sync);

    for i, sample in enumerate(loader):
        data = model_utils.parseData(args, sample, timer, 'train')
        input = model_utils.getInput(args, data)

        pred = model(input); timer.updateTime('Forward')

        optimizer.zero_grad()
        loss = criterion.forward(pred, data); timer.updateTime('Crit')
        criterion.backward(); timer.updateTime('Backward')

        recorder.updateIter('train', loss.keys(), loss.values())

        optimizer.step(); timer.updateTime('Solver')

        iters = i + 1
        if iters % args.train_disp == 0:
            opt = {'split':'train', 'epoch':epoch, 'iters':iters, 'batch':len(loader), 
                    'timer':timer, 'recorder': recorder}
            log.printItersSummary(opt)

        if iters % args.train_save == 0:
            results, recorder, nrow = prepareSave(args, data, pred, recorder, log) 
            log.saveImgResults(results, 'train', epoch, iters, nrow=nrow)
            log.plotCurves(recorder, 'train', epoch=epoch, intv=args.train_disp)

        if args.max_train_iter > 0 and iters >= args.max_train_iter: break
    opt = {'split': 'train', 'epoch': epoch, 'recorder': recorder}
    log.printEpochSummary(opt)

def prepareSave(args, data, pred, recorder, log):
    results = [data['img'].data, data['m'].data, (data['n'].data+1)/2]
    if args.s1_est_d:
        l_acc, data['dir_err'] = eval_utils.calDirsAcc(data['dirs'].data, pred['dirs'].data, args.batch)
        recorder.updateIter('train', l_acc.keys(), l_acc.values())

    if args.s1_est_i:
        int_acc, data['int_err'] = eval_utils.calIntsAcc(data['ints'].data, pred['intens'].data, args.batch)
        recorder.updateIter('train', int_acc.keys(), int_acc.values())

    if args.s1_est_n:
        acc, error_map = eval_utils.calNormalAcc(data['n'].data, pred['n'].data, data['m'].data)
        pred_n = (pred['n'].data + 1) / 2
        masked_pred = pred_n * data['m'].data.expand_as(pred['n'].data)
        res_n = [pred_n, masked_pred, error_map['angular_map']]
        results += res_n
        recorder.updateIter('train', acc.keys(), acc.values())
    nrow = data['img'].shape[0] if data['img'].shape[0] <= 32 else 32
    return results, recorder, nrow


================================================
FILE: train_stage2.py
================================================
import torch
from models import model_utils
from utils  import eval_utils, time_utils

def train(args, loader, models, criterion, optimizers, log, epoch, recorder):
    models[1].train()
    models[0].eval()
    optimizer, optimizer_c = optimizers
    log.printWrite('---- Start Training Epoch %d: %d batches ----' % (epoch, len(loader)))
    timer = time_utils.Timer(args.time_sync);

    for i, sample in enumerate(loader):
        data = model_utils.parseData(args, sample, timer, 'train')
        input = model_utils.getInput(args, data)
        with torch.no_grad():
            pred_c = models[0](input); 
        input.append(pred_c)
        pred = models[1](input); timer.updateTime('Forward')

        optimizer.zero_grad()

        loss = criterion.forward(pred, data); timer.updateTime('Crit')
        criterion.backward(); timer.updateTime('Backward')

        recorder.updateIter('train', loss.keys(), loss.values())

        optimizer.step(); timer.updateTime('Solver')

        iters = i + 1
        if iters % args.train_disp == 0:
            opt = {'split':'train', 'epoch':epoch, 'iters':iters, 'batch':len(loader), 
                    'timer':timer, 'recorder': recorder}
            log.printItersSummary(opt)

        if iters % args.train_save == 0:
            results, recorder, nrow = prepareSave(args, data, pred_c, pred, recorder, log) 
            log.saveImgResults(results, 'train', epoch, iters, nrow=nrow)
            log.plotCurves(recorder, 'train', epoch=epoch, intv=args.train_disp)

        if args.max_train_iter > 0 and iters >= args.max_train_iter: break
    opt = {'split': 'train', 'epoch': epoch, 'recorder': recorder}
    log.printEpochSummary(opt)

def prepareSave(args, data, pred_c, pred, recorder, log):
    input_var, mask_var = data['img'], data['m']
    results = [input_var.data, mask_var.data, (data['n'].data+1)/2]
    if args.s1_est_d:
        l_acc, data['dir_err'] = eval_utils.calDirsAcc(data['dirs'].data, pred_c['dirs'].data, args.batch)
        recorder.updateIter('train', l_acc.keys(), l_acc.values())
    if args.s1_est_i:
        int_acc, data['int_err'] = eval_utils.calIntsAcc(data['ints'].data, pred_c['intens'].data, args.batch)
        recorder.updateIter('train', int_acc.keys(), int_acc.values())

    if args.s2_est_n:
        acc, error_map = eval_utils.calNormalAcc(data['n'].data, pred['n'].data, mask_var.data)
        pred_n = (pred['n'].data + 1) / 2
        masked_pred = pred_n * mask_var.data.expand_as(pred['n'].data)
        res_n = [masked_pred, error_map['angular_map']]
        results += res_n
        recorder.updateIter('train', acc.keys(), acc.values())

    nrow = input_var.shape[0] if input_var.shape[0] <= 32 else 32
    return results, recorder, nrow
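The training loop above keeps the stage-1 lighting model frozen (`models[0].eval()` plus `torch.no_grad()`) and hands its prediction to the trainable stage-2 normal model by appending it to the shared input list. A torch-free sketch of that cascade, using toy stand-in models rather than the real LCNet/NENet:

```python
# Toy stand-ins for the frozen stage-1 and trainable stage-2 models.
def stage1(inputs):                 # stands in for the frozen LCNet
    return {'dirs': 'est_dirs', 'intens': 'est_ints'}

def stage2(inputs):                 # stands in for the trainable NENet
    pred_c = inputs[-1]             # lighting estimate appended by the caller
    return {'n': 'normals_from_' + pred_c['dirs']}

inputs = ['img', 'mask']
pred_c = stage1(inputs)             # run with gradients disabled in the repo
inputs.append(pred_c)               # hand the lighting estimate to stage 2
pred = stage2(inputs)
print(pred['n'])                    # -> normals_from_est_dirs
```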


================================================
FILE: utils/__init__.py
================================================


================================================
FILE: utils/eval_utils.py
================================================
import torch
import math
import numpy as np
from matplotlib import cm

def colorMap(diff):
    thres = 90
    diff_norm = np.clip(diff, 0, thres) / thres
    diff_cm = torch.from_numpy(cm.jet(diff_norm.numpy()))[:,:,:, :3]
    return diff_cm.permute(0,3,1,2).clone().float()

def calDirsAcc(gt_l, pred_l, data_batch=1):
    n, c = gt_l.shape
    pred_l = pred_l.view(n, c)
    dot_product = (gt_l * pred_l).sum(1).clamp(-1, 1)
    
    angular_err = torch.acos(dot_product) * 180.0 / math.pi
    l_err_mean  = angular_err.mean()
    return {'l_err_mean': l_err_mean.item()}, angular_err.squeeze()

def calIntsAcc(gt_i, pred_i, data_batch=1):
    n, c, h, w = gt_i.shape
    pred_i  = pred_i.view(n, c, h, w)
    ref_int = gt_i[:, :3].repeat(1, gt_i.shape[1] // 3, 1, 1)
    gt_i  = gt_i / ref_int
    # least-squares fit of a single scale factor between pred_i and gt_i
    # (torch.gels is deprecated in newer PyTorch; torch.linalg.lstsq replaces it)
    scale = torch.gels(gt_i.view(-1, 1), pred_i.view(-1, 1))
    ints_ratio = (gt_i - scale[0][0] * pred_i).abs() / (gt_i + 1e-8)
    ints_error = torch.stack(ints_ratio.split(3, 1), 1).mean(2)
    return {'ints_ratio': ints_ratio.mean().item()}, ints_error.squeeze()
    
def calNormalAcc(gt_n, pred_n, mask=None):
    """Tensor Dim: NxCxHxW"""
    dot_product = (gt_n * pred_n).sum(1).clamp(-1,1)
    error_map   = torch.acos(dot_product) # [0, pi]
    angular_map = error_map * 180.0 / math.pi
    angular_map = angular_map * mask.narrow(1, 0, 1).squeeze(1)

    valid = mask.narrow(1, 0, 1).sum()
    ang_valid  = angular_map[mask.narrow(1, 0, 1).squeeze(1).byte()]
    n_err_mean = ang_valid.sum() / valid
    n_err_med  = ang_valid.median()
    n_acc_11   = (ang_valid < 11.25).sum().float() / valid
    n_acc_30   = (ang_valid < 30).sum().float() / valid
    n_acc_45   = (ang_valid < 45).sum().float() / valid

    angular_map = colorMap(angular_map.cpu().squeeze(1))
    value = {'n_err_mean': n_err_mean.item(), 
            'n_acc_11': n_acc_11.item(), 'n_acc_30': n_acc_30.item(), 'n_acc_45': n_acc_45.item()}
    angular_error_map = {'angular_map': angular_map}
    return value, angular_error_map

def SphericalDirsToClass(dirs, cls_num):
    theta = torch.atan(dirs[:,0] / (dirs[:,2] + 1e-8)) 
    denom = torch.sqrt(dirs[:,0] * dirs[:,0] + dirs[:,2] * dirs[:,2])
    phi = torch.atan(dirs[:,1] / (denom + 1e-8))
    theta = theta / np.pi * 180
    phi   = phi / np.pi * 180
    azimuth = ((theta + 90.0) / 180 * cls_num).clamp(0, cls_num-1).long()
    elevate = ((phi   + 90.0) / 180 * cls_num).clamp(0, cls_num-1).long()
    return azimuth, elevate

def SphericalClassToDirs(x_cls, y_cls, cls_num):
    theta = (x_cls.float() + 0.5) / cls_num * 180 - 90
    phi   = (y_cls.float() + 0.5) / cls_num * 180 - 90
    neg_x = theta < 0
    neg_y = phi < 0
    theta = theta.clamp(-90, 90) / 180.0 * np.pi
    phi   = phi.clamp(-90, 90) / 180.0 * np.pi

    tan2_phi   = pow(torch.tan(phi), 2)
    tan2_theta = pow(torch.tan(theta), 2)
    y = torch.sqrt(tan2_phi / (1 + tan2_phi))
    y[neg_y] = y[neg_y] * -1
    #y = torch.sin(phi)
    z = torch.sqrt((1 - y * y) / (1 + tan2_theta))
    x = z * torch.tan(theta)
    dirs = torch.stack([x,y,z], 1)
    dirs = dirs / dirs.norm(p=2, dim=1, keepdim=True)
    return dirs
    
def LightIntsToClass(ints, cls_num):
    ints = (ints - 0.2) / 1.8
    ints = (ints * cls_num).clamp(0, cls_num-1).long()
    return ints.view(-1)

def ClassToLightInts(cls, cls_num):
    ints = (cls.float() + 0.5) / cls_num * 1.8 + 0.2
    ints = ints.clamp(0.2 , 2.0)
    return ints
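`LightIntsToClass`/`ClassToLightInts` above quantize intensities in [0.2, 2.0] into `cls_num` uniform bins and decode each class to its bin center, so a round trip lands within half a bin width of the input. A pure-Python sketch of the same arithmetic (the repo versions operate on torch tensors):

```python
# Scalar sketch of the intensity quantization: [0.2, 2.0] is split into
# cls_num uniform bins and decoding returns the bin center.
def ints_to_class(v, cls_num):
    c = int((v - 0.2) / 1.8 * cls_num)
    return min(max(c, 0), cls_num - 1)      # clamp to a valid bin index

def class_to_ints(c, cls_num):
    return (c + 0.5) / cls_num * 1.8 + 0.2  # decode to the bin center

cls_num = 20                                # matches the --ints_cls default
half_bin = 1.8 / cls_num / 2                # worst-case round-trip error
for v in [0.2, 0.77, 1.3, 1.99]:
    roundtrip = class_to_ints(ints_to_class(v, cls_num), cls_num)
    assert abs(roundtrip - v) <= half_bin + 1e-9
```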


================================================
FILE: utils/logger.py
================================================
import datetime, time, os
import numpy as np
import torch
import torchvision.utils as vutils
import scipy.io as sio
from . import utils

import matplotlib; matplotlib.use('agg')
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.font_manager import FontProperties
fontP = FontProperties()
fontP.set_size('small')
plt.rcParams["figure.figsize"] = (5,8)

class Logger(object):
    def __init__(self, args):
        self.times = {'init': time.time()}
        if args.make_dir:
            self._setupDirs(args)
        self.args = args
        args.log  = self
        self.printArgs()

    def printArgs(self):
        strs = '------------ Options -------------\n'
        strs += '{}'.format(utils.dictToString(vars(self.args)))
        strs += '-------------- End ----------------\n'
        self.printWrite(strs)

    def _addArguments(self, args):
        info = ''
        if hasattr(args, 'run_model') and args.run_model:
            info += '_run_model,%s' % os.path.basename(args.retrain).split('.')[0]
        arg_var  = vars(args)
        for k in args.str_keys:  
            info = '{0},{1}'.format(info, arg_var[k])
        for k in args.val_keys:  
            var_key = k[:2] + '_' + k[-1]
            info = '{0},{1}-{2}'.format(info, var_key, arg_var[k])
        for k in args.bool_keys: 
            info = '{0},{1}'.format(info, k) if arg_var[k] else info 
        return info

    def _setupDirs(self, args):
        date_now = datetime.datetime.now()
        dir_name = '%d-%d' % (date_now.month, date_now.day)
        dir_name += (',%s' % args.suffix) if args.suffix else ''
        dir_name += self._addArguments(args) 
        dir_name += ',DEBUG' if args.debug else ''

        self._checkPath(args, dir_name)
        file_dir = os.path.join(args.log_dir, '%s,%s' % (dir_name, date_now.strftime('%H:%M:%S')))
        self.log_fie = open(file_dir, 'w')
        return 

    def _checkPath(self, args, dir_name):
        if hasattr(args, 'run_model') and args.run_model:
            log_root = os.path.join(os.path.dirname(args.retrain), dir_name)
            args.log_dir = log_root
            sub_dirs = ['test']
        else:
            if args.resume and os.path.isfile(args.resume):
                log_root = os.path.join(os.path.dirname(os.path.dirname(args.resume)), dir_name)
            else:
                if args.debug:
                    dir_name = 'DEBUG/' + dir_name
                log_root = os.path.join(args.save_root, args.dataset, args.item, dir_name)
            args.log_dir = os.path.join(log_root, 'logdir')
            args.cp_dir  = os.path.join(log_root, 'checkpointdir')
            utils.makeFiles([args.log_dir, args.cp_dir])
            sub_dirs = ['train', 'val'] 
        for sub_dir in sub_dirs:
            utils.makeFiles([os.path.join(args.log_dir, sub_dir, 'Images')])

    def printWrite(self, strs):
        print('%s' % strs)
        if self.args.make_dir:
            self.log_fie.write('%s\n' % strs)
            self.log_fie.flush()

    def getTimeInfo(self, epoch, iters, batch):
        time_elapsed = (time.time() - self.times['init']) / 3600.0
        total_iters  = (self.args.epochs - self.args.start_epoch + 1) * batch
        cur_iters    = (epoch - self.args.start_epoch) * batch + iters
        time_total   = time_elapsed * (float(total_iters) / max(cur_iters, 1)) # guard against div-by-zero on the first iter
        return time_elapsed, time_total

    def printItersSummary(self, opt):
        epoch, iters, batch = opt['epoch'], opt['iters'], opt['batch']
        strs = ' | {}'.format(str.upper(opt['split']))
        strs += ' Iter [{}/{}] Epoch [{}/{}]'.format(iters, batch, epoch, self.args.epochs)
        if opt['split'] == 'train': 
            time_elapsed, time_total = self.getTimeInfo(epoch, iters, batch) # Buggy for test
            strs += ' Clock [{:.2f}h/{:.2f}h]'.format(time_elapsed, time_total)
            strs += ' LR [{}]'.format(opt['recorder'].records[opt['split']]['lr'][epoch][0])
        self.printWrite(strs)
        if 'timer' in opt.keys(): 
            self.printWrite(opt['timer'].timeToString())
        if 'recorder' in opt.keys(): 
            self.printWrite(opt['recorder'].iterRecToString(opt['split'], epoch))

    def printEpochSummary(self, opt):
        split = opt['split']
        epoch = opt['epoch']
        self.printWrite('---------- {} Epoch {} Summary -----------'.format(str.upper(split), epoch))
        self.printWrite(opt['recorder'].epochRecToString(split, epoch))

    def convertToSameSize(self, t_list):
        shape = (t_list[0].shape[0], 3, t_list[0].shape[2], t_list[0].shape[3])
        for i, tensor in enumerate(t_list):
            n, c, h, w = tensor.shape
            if c != shape[1]: # expand single-channel maps to 3 channels
                t_list[i] = tensor.expand((n, shape[1], h, w))
            if h == shape[2] and w == shape[3]:
                continue
            # resize the (possibly channel-expanded) tensor; F.upsample is deprecated
            t_list[i] = torch.nn.functional.interpolate(
                t_list[i], [shape[2], shape[3]], mode='bilinear', align_corners=False)
        return t_list

    def getSaveDir(self, split, epoch):
        save_dir = os.path.join(self.args.log_dir, split, 'Images')
        run_model = hasattr(self.args, 'run_model') and self.args.run_model
        if not run_model and epoch > 0:
            save_dir = os.path.join(save_dir, str(epoch))
        utils.makeFile(save_dir)
        return save_dir

    def splitMulitChannel(self, t_list, max_save_n = 8):
        new_list = []
        for tensor in t_list:
            if tensor.shape[1] > 3:
                num = 3
                new_list += torch.split(tensor, num, 1)[:max_save_n]
            else:
                new_list.append(tensor)
        return new_list

    def saveSplit(self, res, save_prefix):
        n, c, h, w = res.shape
        for i in range(n):
            vutils.save_image(res[i], save_prefix + '_%d.png' % (i))

    def saveImgResults(self, results, split, epoch, iters, nrow, error=''):
        max_save_n = self.args.test_save_n if split == 'test' else self.args.train_save_n
        res = [img.cpu() for img in results]
        res = self.splitMulitChannel(res, max_save_n)
        res = torch.cat(self.convertToSameSize(res))
        save_dir = self.getSaveDir(split, epoch)
        save_prefix = os.path.join(save_dir, '%d_%d' % (epoch, iters))
        save_prefix += ('_%s' % error) if error != '' else ''
        if self.args.save_split: 
            self.saveSplit(res, save_prefix)
        else:
            vutils.save_image(res, save_prefix + '_out.png', nrow=nrow)

    def plotCurves(self, recorder, split='train', epoch=-1, intv=1):
        dict_of_array = recorder.recordToDictOfArray(split, epoch, intv)
        save_dir = os.path.join(self.args.log_dir, split)
        if epoch < 0:
            save_dir = self.args.log_dir
            save_name = '%s_Summary.png' % (split)
        else:
            save_name = '%s_epoch_%d.png' % (split, epoch)

        classes = ['loss', 'acc', 'err', 'lr', 'ratio']
        classes = utils.checkIfInList(classes, dict_of_array.keys())
        if len(classes) == 0: return

        for idx, c in enumerate(classes):
            plt.subplot(len(classes), 1, idx+1)
            plt.grid()
            legends = []
            for k in dict_of_array.keys():
                if (c in k.lower()) and not k.endswith('_x'):
                    plt.plot(dict_of_array[k+'_x'], dict_of_array[k])
                    legends.append(k)
            if len(legends) != 0:
                plt.legend(legends, bbox_to_anchor=(0.5, 1.05), loc='upper center', 
                            ncol=len(legends), prop=fontP)
                plt.title(c)
                if epoch < 0: plt.xlabel('Epoch') 
                else: plt.xlabel('Iters')
        plt.tight_layout()
        plt.savefig(os.path.join(save_dir, save_name))
        plt.clf()
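`plotCurves` decides which metrics share a subplot purely by substring matching against a fixed list of metric classes. The grouping logic can be exercised in isolation — the key names below are made up for illustration:

```python
# Group record keys into subplots by metric-class substring, as plotCurves does.
keys = ['tr_dirs_loss', 'tr_ints_loss', 'tr_n_err_mean', 'tr_lr']
classes = ['loss', 'acc', 'err', 'lr', 'ratio']

groups = {}
for c in classes:
    members = [k for k in keys if c in k.lower() and not k.endswith('_x')]
    if members:          # classes with no matching metric get no subplot
        groups[c] = members
```

Each non-empty group becomes one subplot; the companion `*_x` arrays produced by `recordToDictOfArray` carry the x-axis values and are excluded from the legend.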


================================================
FILE: utils/recorders.py
================================================
from collections import OrderedDict
import numpy as np

class Records(object):
    """
    records:  split (train/val) -> metric (loss/acc/...) -> epoch -> [v1, v2, ...]
    iter_rec: split (train/val) -> metric -> [v1, v2, ...], cleared each display interval
    """
    def __init__(self, log_dir, records=None):
        if records is None:
            self.records = OrderedDict()
        else:
            self.records = records
        self.iter_rec = OrderedDict()
        self.log_dir  = log_dir
        self.classes = ['loss', 'acc', 'err', 'ratio']

    def resetIter(self):
        self.iter_rec.clear()

    def checkDict(self, a_dict, key, sub_type='dict'):
        if key not in a_dict.keys():
            if sub_type == 'dict':
                a_dict[key] = OrderedDict()
            if sub_type == 'list':
                a_dict[key] = []

    def updateIter(self, split, keys, values):
        self.checkDict(self.iter_rec, split, 'dict')
        for k, v in zip(keys, values):
            self.checkDict(self.iter_rec[split], k, 'list')
            self.iter_rec[split][k].append(v)

    def saveIterRecord(self, epoch, reset=True):
        for s in self.iter_rec.keys(): # s stands for split
            self.checkDict(self.records, s, 'dict')
            for k in self.iter_rec[s].keys():
                self.checkDict(self.records[s], k, 'dict')
                self.checkDict(self.records[s][k], epoch, 'list')
                self.records[s][k][epoch].append(np.mean(self.iter_rec[s][k]))
        if reset: 
            self.resetIter()

    def insertRecord(self, split, key, epoch, value):
        self.checkDict(self.records, split, 'dict')
        self.checkDict(self.records[split], key, 'dict')
        self.checkDict(self.records[split][key], epoch, 'list')
        self.records[split][key][epoch].append(value)

    def iterRecToString(self, split, epoch):
        rec_strs = ''
        for c in self.classes:
            strs = ''
            for k in self.iter_rec[split].keys():
                if (c in k.lower()):
                    strs += '{}: {:.3f}| '.format(k, np.mean(self.iter_rec[split][k]))
            if strs != '':
                rec_strs += '\t [{}] {}\n'.format(c.upper(), strs)
        self.saveIterRecord(epoch)
        return rec_strs

    def epochRecToString(self, split, epoch):
        rec_strs = ''
        for c in self.classes:
            strs = ''
            for k in self.records[split].keys():
                if (c in k.lower()) and (epoch in self.records[split][k].keys()):
                    strs += '{}: {:.3f}| '.format(k, np.mean(self.records[split][k][epoch]))
            if strs != '':
                rec_strs += '\t [{}] {}\n'.format(c.upper(), strs)
        return rec_strs

    def recordToDictOfArray(self, splits, epoch=-1, intv=1):
        if len(self.records) == 0: return {}
        if type(splits) == str: splits = [splits]

        dict_of_array = OrderedDict()
        for split in splits:
            for k in self.records[split].keys():
                y_array, x_array = [], []
                if epoch < 0:
                    for ep in self.records[split][k].keys():
                        y_array.append(np.mean(self.records[split][k][ep]))
                        x_array.append(ep)
                else:
                    if epoch in self.records[split][k].keys():
                        y_array = np.array(self.records[split][k][epoch])
                        x_array = np.linspace(intv, intv*len(y_array), len(y_array))
                dict_of_array[split[0] + split[-1] + '_' + k]      = y_array
                dict_of_array[split[0] + split[-1] + '_' + k+'_x'] = x_array
        return dict_of_array
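`Records` is essentially an autovivifying nested dictionary, split -> metric -> epoch -> list of values, with `checkDict` creating missing levels on demand. The same pattern can be written compactly with `setdefault` (the metric name here is hypothetical):

```python
from collections import OrderedDict

records = OrderedDict()  # split -> metric -> epoch -> [values]

def insert_record(split, key, epoch, value):
    # setdefault plays the role of Records.checkDict: create missing levels
    records.setdefault(split, OrderedDict()) \
           .setdefault(key, OrderedDict()) \
           .setdefault(epoch, []).append(value)

insert_record('train', 'n_err_mean', 1, 30.0)
insert_record('train', 'n_err_mean', 1, 20.0)
values = records['train']['n_err_mean'][1]
epoch_mean = sum(values) / len(values)
```

Averaging the per-epoch list is exactly what `epochRecToString` and `recordToDictOfArray` do when summarizing an epoch.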


================================================
FILE: utils/time_utils.py
================================================
import time
import torch
from collections import OrderedDict

class Timer(object):
    def __init__(self, cuda_sync=False):
        self.timer = OrderedDict()
        self.cuda_sync = cuda_sync
        self.startTimer()

    def startTimer(self):
        self.iter_start = time.time()
        self.disp_start = time.time()

    def resetTimer(self):
        self.iter_start = time.time()
        self.disp_start = time.time()
        for key in self.timer.keys(): self.timer[key].reset()

    def updateTime(self, key):
        if key not in self.timer.keys(): self.timer[key] = AverageMeter()
        if self.cuda_sync: torch.cuda.synchronize()
        self.timer[key].update(time.time() - self.iter_start)
        self.iter_start = time.time()

    def timeToString(self, reset=True):
        strs = '\t [Time %.3fs] ' % (time.time() - self.disp_start)
        for key in self.timer.keys():
            if self.timer[key].sum < 1e-4: continue
            strs += '%s: %.3fs| ' % (key, self.timer[key].sum)
        self.resetTimer()
        return strs

class AverageMeter(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.sum = 0
        self.count = 0
        self.avg = 0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __repr__(self):
        return '%.3f' % (self.avg)
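`AverageMeter` keeps a weighted running mean: `n` is the weight (e.g. a batch size), so the average is over samples rather than over calls. A standalone round of updates:

```python
class AverageMeter:
    # same bookkeeping as the class above, reproduced so it runs standalone
    def __init__(self):
        self.sum, self.count, self.avg = 0, 0, 0

    def update(self, val, n=1):
        self.sum += val * n   # val was observed n times (e.g. a batch of n)
        self.count += n
        self.avg = self.sum / self.count

meter = AverageMeter()
meter.update(2.0, n=3)  # a batch of 3 samples averaging 2.0
meter.update(4.0)       # one more sample
```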


================================================
FILE: utils/utils.py
================================================
import os
import re
import numpy as np
from imageio import imread, imsave

def makeFile(f):
    if not os.path.exists(f):
        os.makedirs(f)
    #else:  raise Exception('Rendered image directory %s is already existed!!!' % directory)

def makeFiles(f_list):
    for f in f_list:
        makeFile(f)

def emptyFile(name):
    with open(name, 'w') as f:
        f.write(' ')

def dictToString(dicts, start='\t', end='\n'):
    strs = '' 
    for k, v in sorted(dicts.items()):
        strs += '%s%s: %s%s' % (start, str(k), str(v), end) 
    return strs

def checkIfInList(list1, list2):
    contains = []
    for l1 in list1:
        for l2 in list2:
            if l1 in l2.lower():
                contains.append(l1)
                break
    return contains

def atoi(text):
    return int(text) if text.isdigit() else text

def natural_keys(text):
    '''
    alist.sort(key=natural_keys) sorts in human order
    http://nedbatchelder.com/blog/200712/human_sorting.html
    (See Toothy's implementation in the comments)
    '''
    return [atoi(c) for c in re.split(r'(\d+)', text)]

def readList(list_path, ignore_head=False, sort=False):
    lists = []
    with open(list_path) as f:
        lists = f.read().splitlines()
    if ignore_head:
        lists = lists[1:]
    if sort:
        lists.sort(key=natural_keys)
    return lists
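`natural_keys` (note that it depends on the `re` module) makes embedded integers compare numerically instead of lexicographically, which is what you want when sorting frame filenames:

```python
import re

def atoi(text):
    return int(text) if text.isdigit() else text

def natural_keys(text):
    # split on digit runs: 'img10.png' -> ['img', '10', '.png'] -> ['img', 10, '.png']
    return [atoi(c) for c in re.split(r'(\d+)', text)]

names = ['img10.png', 'img2.png', 'img1.png']
names.sort(key=natural_keys)
```

A plain lexicographic sort would put 'img10.png' before 'img2.png'.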
SYMBOL INDEX (179 symbols across 29 files)

FILE: datasets/UPS_Custom_Dataset.py
  class UPS_Custom_Dataset (line 14) | class UPS_Custom_Dataset(data.Dataset):
    method __init__ (line 15) | def __init__(self, args, split='train'):
    method _getMask (line 22) | def _getMask(self, obj):
    method __getitem__ (line 28) | def __getitem__(self, index):
    method __len__ (line 75) | def __len__(self):

FILE: datasets/UPS_DiLiGenT_main.py
  class UPS_DiLiGenT_main (line 15) | class UPS_DiLiGenT_main(data.Dataset):
    method __init__ (line 16) | def __init__(self, args, split='train'):
    method _getMask (line 31) | def _getMask(self, obj):
    method __getitem__ (line 37) | def __getitem__(self, index):
    method __len__ (line 83) | def __len__(self):

FILE: datasets/UPS_Synth_Dataset.py
  class UPS_Synth_Dataset (line 14) | class UPS_Synth_Dataset(data.Dataset):
    method __init__ (line 15) | def __init__(self, args, root, split='train'):
    method _getInputPath (line 21) | def _getInputPath(self, index):
    method __getitem__ (line 35) | def __getitem__(self, index):
    method __len__ (line 79) | def __len__(self):

FILE: datasets/custom_data_loader.py
  function customDataloader (line 3) | def customDataloader(args):
  function benchmarkLoader (line 27) | def benchmarkLoader(args):

FILE: datasets/pms_transforms.py
  function arrayToTensor (line 8) | def arrayToTensor(array):
  function normalToMask (line 15) | def normalToMask(normal, thres=1e-2):
  function imgSizeToFactorOfK (line 24) | def imgSizeToFactorOfK(img, k):
  function randomCrop (line 32) | def randomCrop(inputs, target, size):
  function centerCrop (line 43) | def centerCrop(inputs, size):
  function rescale (line 52) | def rescale(inputs, target, size):
  function rescaleSingle (line 60) | def rescaleSingle(inputs, size, order=1):
  function randomNoiseAug (line 67) | def randomNoiseAug(inputs, noise_level=0.05):
  function getIntensity (line 73) | def getIntensity(num):

FILE: datasets/util.py
  function atoi (line 4) | def atoi(text):
  function natural_keys (line 7) | def natural_keys(text):
  function readList (line 15) | def readList(list_path,ignore_head=False, sort=True):
  function light_source_directions (line 25) | def light_source_directions():

FILE: eval/run_stage1.py
  function main (line 14) | def main(args):

FILE: eval/run_stage2.py
  function main (line 16) | def main(args):

FILE: main_stage1.py
  function main (line 13) | def main(args):

FILE: main_stage2.py
  function main (line 13) | def main(args):

FILE: models/LCNet.py
  class FeatExtractor (line 8) | class FeatExtractor(nn.Module):
    method __init__ (line 9) | def __init__(self, batchNorm, c_in, c_out=256):
    method forward (line 19) | def forward(self, inputs):
  class Classifier (line 29) | class Classifier(nn.Module):
    method __init__ (line 30) | def __init__(self, batchNorm, c_in, other):
    method forward (line 50) | def forward(self, inputs):
  class LCNet (line 63) | class LCNet(nn.Module):
    method __init__ (line 64) | def __init__(self, fuse_type='max', batchNorm=False, c_in=3, other={}):
    method prepareInputs (line 81) | def prepareInputs(self, x):
    method fuseFeatures (line 106) | def fuseFeatures(self, feats, fuse_type):
    method convertMidDirs (line 113) | def convertMidDirs(self, pred):
    method convertMidIntens (line 119) | def convertMidIntens(self, pred, img_num):
    method forward (line 126) | def forward(self, x):

FILE: models/NENet.py
  class FeatExtractor (line 6) | class FeatExtractor(nn.Module):
    method __init__ (line 7) | def __init__(self, batchNorm=False, c_in=3, other={}):
    method forward (line 18) | def forward(self, x):
  class Regressor (line 30) | class Regressor(nn.Module):
    method __init__ (line 31) | def __init__(self, batchNorm=False, other={}):
    method _make_output (line 40) | def _make_output(self, cin, cout, k=3, stride=1, pad=1):
    method forward (line 44) | def forward(self, x, shape):
  class NENet (line 53) | class NENet(nn.Module):
    method __init__ (line 54) | def __init__(self, fuse_type='max', batchNorm=False, c_in=3, other={}):
    method prepareInputs (line 71) | def prepareInputs(self, x):
    method forward (line 90) | def forward(self, x):

FILE: models/custom_model.py
  function buildModel (line 4) | def buildModel(args):
  function buildModelStage2 (line 31) | def buildModelStage2(args):

FILE: models/model_utils.py
  function getInput (line 5) | def getInput(args, data):
  function parseData (line 11) | def parseData(args, sample, timer=None, split='train'):
  function getInputChanel (line 28) | def getInputChanel(args):
  function get_n_params (line 40) | def get_n_params(model):
  function loadCheckpoint (line 49) | def loadCheckpoint(path, model, cuda=True):
  function saveCheckpoint (line 56) | def saveCheckpoint(save_path, epoch=-1, model=None, optimizer=None, reco...
  function conv_ReLU (line 62) | def conv_ReLU(batchNorm, cin, cout, k=3, stride=1, pad=-1):
  function conv (line 77) | def conv(batchNorm, cin, cout, k=3, stride=1, pad=-1):
  function outputConv (line 92) | def outputConv(cin, cout, k=3, stride=1, pad=1):
  function deconv (line 96) | def deconv(cin, cout):
  function upconv (line 102) | def upconv(cin, cout):

FILE: models/solver_utils.py
  class Stage1ClsCrit (line 5) | class Stage1ClsCrit(object): # First Stage, Light classification criterion
    method __init__ (line 6) | def __init__(self, args):
    method setupLightCrit (line 13) | def setupLightCrit(self, args):
    method forward (line 27) | def forward(self, output, target):
    method backward (line 51) | def backward(self):
  class Stage2Crit (line 54) | class Stage2Crit(object): # Second stage
    method __init__ (line 55) | def __init__(self, args):
    method setupLightCrit (line 63) | def setupLightCrit(self, args):
    method setupNormalCrit (line 74) | def setupNormalCrit(self, args):
    method forward (line 86) | def forward(self, output, target):
    method backward (line 117) | def backward(self):
  function getOptimizer (line 120) | def getOptimizer(args, params):
  function getLrScheduler (line 130) | def getLrScheduler(args, optimizer):
  function loadRecords (line 135) | def loadRecords(path, model, optimizer):
  function configOptimizer (line 147) | def configOptimizer(args, model):

FILE: options/base_opts.py
  class BaseOpts (line 5) | class BaseOpts(object):
    method __init__ (line 6) | def __init__(self):
    method initialize (line 9) | def initialize(self):
    method setDefault (line 88) | def setDefault(self):
    method collectInfo (line 96) | def collectInfo(self):
    method parse (line 109) | def parse(self):

FILE: options/run_model_opts.py
  class RunModelOpts (line 2) | class RunModelOpts(BaseOpts):
    method __init__ (line 3) | def __init__(self):
    method initialize (line 7) | def initialize(self):
    method collectInfo (line 23) | def collectInfo(self):
    method setDefault (line 28) | def setDefault(self):
    method parse (line 31) | def parse(self):

FILE: options/stage1_opts.py
  class TrainOpts (line 2) | class TrainOpts(BaseOpts):
    method __init__ (line 3) | def __init__(self):
    method initialize (line 7) | def initialize(self):
    method collectInfo (line 31) | def collectInfo(self):
    method setDefault (line 41) | def setDefault(self):
    method parse (line 47) | def parse(self):

FILE: options/stage2_opts.py
  class TrainOpts (line 2) | class TrainOpts(BaseOpts):
    method __init__ (line 3) | def __init__(self):
    method initialize (line 7) | def initialize(self):
    method collectInfo (line 31) | def collectInfo(self):
    method setDefault (line 41) | def setDefault(self):
    method parse (line 47) | def parse(self):

FILE: scripts/cropDiLiGenTData.py
  function getSaveDir (line 21) | def getSaveDir():
  function getBBoxCompact (line 28) | def getBBoxCompact(mask):
  function loadMaskNormal (line 44) | def loadMaskNormal(d):
  function copyTXT (line 57) | def copyTXT(d):

FILE: test_stage1.py
  function get_itervals (line 7) | def get_itervals(args, split):
  function test (line 16) | def test(args, split, loader, model, log, epoch, recorder):
  function prepareRes (line 51) | def prepareRes(args, data, pred, recorder, log, split):
  function prepareSave (line 76) | def prepareSave(args, data, pred):

FILE: test_stage2.py
  function get_itervals (line 7) | def get_itervals(args, split):
  function test (line 16) | def test(args, split, loader, models, log, epoch, recorder):
  function prepareRes (line 59) | def prepareRes(args, data, pred_c, pred, recorder, log, split):
  function prepareSave (line 82) | def prepareSave(args, data, pred_c, pred):

FILE: train_stage1.py
  function train (line 4) | def train(args, loader, model, criterion, optimizer, log, epoch, recorder):
  function prepareSave (line 39) | def prepareSave(args, data, pred, recorder, log):

FILE: train_stage2.py
  function train (line 5) | def train(args, loader, models, criterion, optimizers, log, epoch, recor...
  function prepareSave (line 45) | def prepareSave(args, data, pred_c, pred, recorder, log):

FILE: utils/eval_utils.py
  function colorMap (line 6) | def colorMap(diff):
  function calDirsAcc (line 12) | def calDirsAcc(gt_l, pred_l, data_batch=1):
  function calIntsAcc (line 21) | def calIntsAcc(gt_i, pred_i, data_batch=1):
  function calNormalAcc (line 31) | def calNormalAcc(gt_n, pred_n, mask=None):
  function SphericalDirsToClass (line 52) | def SphericalDirsToClass(dirs, cls_num):
  function SphericalClassToDirs (line 62) | def SphericalClassToDirs(x_cls, y_cls, cls_num):
  function LightIntsToClass (line 81) | def LightIntsToClass(ints, cls_num):
  function ClassToLightInts (line 86) | def ClassToLightInts(cls, cls_num):

FILE: utils/logger.py
  class Logger (line 16) | class Logger(object):
    method __init__ (line 17) | def __init__(self, args):
    method printArgs (line 25) | def printArgs(self):
    method _addArguments (line 31) | def _addArguments(self, args):
    method _setupDirs (line 45) | def _setupDirs(self, args):
    method _checkPath (line 57) | def _checkPath(self, args, dir_name):
    method printWrite (line 76) | def printWrite(self, strs):
    method getTimeInfo (line 82) | def getTimeInfo(self, epoch, iters, batch):
    method printItersSummary (line 89) | def printItersSummary(self, opt):
    method printEpochSummary (line 103) | def printEpochSummary(self, opt):
    method convertToSameSize (line 109) | def convertToSameSize(self, t_list):
    method getSaveDir (line 120) | def getSaveDir(self, split, epoch):
    method splitMulitChannel (line 128) | def splitMulitChannel(self, t_list, max_save_n = 8):
    method saveSplit (line 138) | def saveSplit(self, res, save_prefix):
    method saveImgResults (line 143) | def saveImgResults(self, results, split, epoch, iters, nrow, error=''):
    method plotCurves (line 156) | def plotCurves(self, recorder, split='train', epoch=-1, intv=1):

FILE: utils/recorders.py
  class Records (line 4) | class Records(object):
    method __init__ (line 9) | def __init__(self, log_dir, records=None):
    method resetIter (line 18) | def resetIter(self):
    method checkDict (line 21) | def checkDict(self, a_dict, key, sub_type='dict'):
    method updateIter (line 28) | def updateIter(self, split, keys, values):
    method saveIterRecord (line 34) | def saveIterRecord(self, epoch, reset=True):
    method insertRecord (line 44) | def insertRecord(self, split, key, epoch, value):
    method iterRecToString (line 50) | def iterRecToString(self, split, epoch):
    method epochRecToString (line 62) | def epochRecToString(self, split, epoch):
    method recordToDictOfArray (line 73) | def recordToDictOfArray(self, splits, epoch=-1, intv=1):

FILE: utils/time_utils.py
  class Timer (line 5) | class Timer(object):
    method __init__ (line 6) | def __init__(self, cuda_sync=False):
    method startTimer (line 11) | def startTimer(self):
    method resetTimer (line 15) | def resetTimer(self):
    method updateTime (line 20) | def updateTime(self, key):
    method timeToString (line 26) | def timeToString(self, reset=True):
  class AverageMeter (line 34) | class AverageMeter(object):
    method __init__ (line 35) | def __init__(self):
    method reset (line 38) | def reset(self):
    method update (line 43) | def update(self, val, n=1):
    method __repr__ (line 48) | def __repr__(self):

FILE: utils/utils.py
  function makeFile (line 5) | def makeFile(f):
  function makeFiles (line 10) | def makeFiles(f_list):
  function emptyFile (line 14) | def emptyFile(name):
  function dictToString (line 18) | def dictToString(dicts, start='\t', end='\n'):
  function checkIfInList (line 24) | def checkIfInList(list1, list2):
  function atoi (line 33) | def atoi(text):
  function natural_keys (line 36) | def natural_keys(text):
  function readList (line 44) | def readList(list_path,ignore_head=False, sort=False):
    "chars": 3454,
    "preview": "import os\nimport torch\nfrom models import model_utils\nfrom utils import eval_utils, time_utils \nimport numpy as np\n\ndef "
  },
  {
    "path": "test_stage2.py",
    "chars": 3819,
    "preview": "import os\nimport torch\nfrom models import model_utils\nfrom utils import eval_utils, time_utils \nimport numpy as np\n\ndef "
  },
  {
    "path": "train_stage1.py",
    "chars": 2524,
    "preview": "from models import model_utils\nfrom utils  import eval_utils, time_utils\n\ndef train(args, loader, model, criterion, opti"
  },
  {
    "path": "train_stage2.py",
    "chars": 2759,
    "preview": "import torch\nfrom models import model_utils\nfrom utils  import eval_utils, time_utils\n\ndef train(args, loader, models, c"
  },
  {
    "path": "utils/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "utils/eval_utils.py",
    "chars": 3421,
    "preview": "import torch\nimport math\nimport numpy as np\nfrom matplotlib import cm\n\ndef colorMap(diff):\n    thres = 90\n    diff_norm "
  },
  {
    "path": "utils/logger.py",
    "chars": 7800,
    "preview": "import datetime, time, os\nimport numpy as np\nimport torch\nimport torchvision.utils as vutils\nimport scipy.io as sio\nfrom"
  },
  {
    "path": "utils/recorders.py",
    "chars": 3646,
    "preview": "from collections import OrderedDict\nimport numpy as np\n\nclass Records(object):\n    \"\"\"\n    Records->Train,Val->Loss,Accu"
  },
  {
    "path": "utils/time_utils.py",
    "chars": 1401,
    "preview": "import time\nimport torch\nfrom collections import OrderedDict\n\nclass Timer(object):\n    def __init__(self, cuda_sync=Fals"
  },
  {
    "path": "utils/utils.py",
    "chars": 1336,
    "preview": "import os \nimport numpy as np\nfrom imageio import imread, imsave\n\ndef makeFile(f):\n    if not os.path.exists(f):\n       "
  }
]
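The first (truncated) preview in the listing above shows the start of a natural-sort helper, `atoi` / `natural_keys` (the file path for that entry is cut off in this extraction and is left unfilled). As a minimal sketch of how such helpers are conventionally completed and used with `sorted` — the body of `natural_keys` beyond the truncated preview is an assumption, not the repository's verbatim code:

```python
import re

def atoi(text):
    # Convert runs of digits to ints so "10" compares numerically, not lexically
    return int(text) if text.isdigit() else text

def natural_keys(text):
    # Split on digit runs; assumed completion beyond the truncated preview
    return [atoi(chunk) for chunk in re.split(r'(\d+)', text)]

names = ['img10.png', 'img2.png', 'img1.png']
print(sorted(names, key=natural_keys))  # ['img1.png', 'img2.png', 'img10.png']
```

This pattern is commonly used in photometric-stereo data loaders to order image filenames like `img2.png` before `img10.png`, which a plain lexicographic sort would reverse.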

About this extraction

This page contains the full source code of the guanyingc/SDPS-Net GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction covers 41 files (100.9 KB, approximately 29.1k tokens) and includes a symbol index of 179 extracted functions, classes, methods, constants, and types. The output can be pasted into any AI tool that accepts text input, copied to the clipboard, or downloaded as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.
