Repository: guanyingc/SDPS-Net
Branch: master
Commit: b84ca3aaa5cd
Files: 41
Total size: 100.9 KB
Directory structure:
gitextract_g3mq9mkn/
├── .gitignore
├── LICENSE.txt
├── README.md
├── data/
│ └── .gitignore
├── datasets/
│ ├── UPS_Custom_Dataset.py
│ ├── UPS_DiLiGenT_main.py
│ ├── UPS_Synth_Dataset.py
│ ├── __init__.py
│ ├── custom_data_loader.py
│ ├── pms_transforms.py
│ └── util.py
├── eval/
│ ├── run_stage1.py
│ └── run_stage2.py
├── main_stage1.py
├── main_stage2.py
├── models/
│ ├── LCNet.py
│ ├── NENet.py
│ ├── __init__.py
│ ├── custom_model.py
│ ├── model_utils.py
│ └── solver_utils.py
├── options/
│ ├── __init__.py
│ ├── base_opts.py
│ ├── run_model_opts.py
│ ├── stage1_opts.py
│ └── stage2_opts.py
├── scripts/
│ ├── DiLiGenT_objects.txt
│ ├── cropDiLiGenTData.py
│ ├── download_pretrained_models.sh
│ ├── download_synthetic_datasets.sh
│ └── prepare_diligent_dataset.sh
├── test_stage1.py
├── test_stage2.py
├── train_stage1.py
├── train_stage2.py
└── utils/
├── __init__.py
├── eval_utils.py
├── logger.py
├── recorders.py
├── time_utils.py
└── utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
*.jpg
*.png
tags
*.pyc
*.pyo
================================================
FILE: LICENSE.txt
================================================
MIT License
Copyright (c) 2018 Guanying Chen
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# SDPS-Net
**[SDPS-Net: Self-calibrating Deep Photometric Stereo Networks, CVPR 2019 (Oral)](http://guanyingc.github.io/SDPS-Net/)**.
[Guanying Chen](https://guanyingc.github.io), [Kai Han](http://www.hankai.org/), [Boxin Shi](http://alumni.media.mit.edu/~shiboxin/), [Yasuyuki Matsushita](http://www-infobiz.ist.osaka-u.ac.jp/en/member/matsushita/), [Kwan-Yee K. Wong](http://i.cs.hku.hk/~kykwong/)
This paper addresses the problem of learning-based _uncalibrated_ photometric stereo for non-Lambertian surfaces.
### _Changelog_
- July 28, 2019: We have updated the code to support Python 3.7 + PyTorch 1.10. To run the previous version (Python 2.7 + PyTorch 0.40), please first check out the `python2.7` branch (`git checkout python2.7`).
## Dependencies
SDPS-Net is implemented in [PyTorch](https://pytorch.org/) and tested on Ubuntu (14.04 and 16.04). Please install PyTorch first, following the official instructions.
- Python 3.7
- PyTorch (version = 1.10)
- torchvision
- CUDA-9.0 # Skip this if you only have CPUs in your computer
- numpy
- scipy
- scikit-image
We highly recommend using Anaconda and creating a new environment to run this code.
```shell
# Create a new python3.7 environment named py3.7
conda create -n py3.7 python=3.7
# Activate the created environment
source activate py3.7
# Example commands for installing the dependencies
conda install pytorch torchvision cudatoolkit=9.0 -c pytorch
conda install -c anaconda scipy
conda install -c anaconda scikit-image
# Download this code
git clone https://github.com/guanyingc/SDPS-Net.git
cd SDPS-Net
```
## Overview
We provide:
- Trained models
- LCNet for lighting calibration from input images
- NENet for normal estimation from input images and estimated lightings.
- Code to test on DiLiGenT main dataset
- Full code to train a new model, including codes for debugging, visualization and logging.
## Testing
### Download the trained models
```shell
sh scripts/download_pretrained_models.sh
```
If the above command does not work, please manually download the trained models from BaiduYun ([LCNet and NENet](https://pan.baidu.com/s/10huOyPkfDSkDUK23_j4y1w?pwd=i5ha)) and put them in `./data/models/`.
### Test SDPS-Net on the DiLiGenT main dataset
```shell
# Prepare the DiLiGenT main dataset
sh scripts/prepare_diligent_dataset.sh
# This command first downloads and unzips the DiLiGenT dataset, and then center-crops
# the original images based on the object mask with a margin of 15 pixels.
# Test SDPS-Net on the DiLiGenT main dataset using all 96 images
CUDA_VISIBLE_DEVICES=0 python eval/run_stage2.py --retrain data/models/LCNet_CVPR2019.pth.tar --retrain_s2 data/models/NENet_CVPR2019.pth.tar
# Please check the outputs in data/models/
# If you only have CPUs, add the argument "--cuda" to disable GPU usage
python eval/run_stage2.py --cuda --retrain data/models/LCNet_CVPR2019.pth.tar --retrain_s2 data/models/NENet_CVPR2019.pth.tar
```
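The mask-based center crop performed by `scripts/prepare_diligent_dataset.sh` (via `scripts/cropDiLiGenTData.py`) can be sketched roughly as follows. This is a hypothetical helper for illustration, not the repository's exact implementation:

```python
import numpy as np

def crop_by_mask(img, mask, margin=15):
    # Find the bounding box of the non-zero mask region,
    # expand it by `margin` pixels on each side (clamped to the
    # image bounds), and crop the image accordingly.
    ys, xs = np.nonzero(mask)
    y0 = max(ys.min() - margin, 0)
    y1 = min(ys.max() + margin + 1, mask.shape[0])
    x0 = max(xs.min() - margin, 0)
    x1 = min(xs.max() + margin + 1, mask.shape[1])
    return img[y0:y1, x0:x1]
```

Applying the same crop window to every image of an object keeps the per-pixel correspondence across the light directions intact.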
### Test SDPS-Net on your own dataset
You have two options to test our method on your own dataset. The first option is to implement a customized Dataset class to load your data, which should not be difficult; please refer to `datasets/UPS_DiLiGenT_main.py` for an example that loads the DiLiGenT main dataset.
If you don't want to implement your own Dataset class, you can try `datasets/UPS_Custom_Dataset.py`. However, you have to first arrange your dataset in the same format as `data/ToyPSDataset/`. Then you can call the following commands.
```shell
CUDA_VISIBLE_DEVICES=0 python eval/run_stage2.py --retrain data/models/LCNet_CVPR2019.pth.tar --retrain_s2 data/models/NENet_CVPR2019.pth.tar --benchmark UPS_Custom_Dataset --bm_dir /path/to/your/dataset
# To test SDPS-Net on the ToyPSDataset, simply run
CUDA_VISIBLE_DEVICES=0 python eval/run_stage2.py --retrain data/models/LCNet_CVPR2019.pth.tar --retrain_s2 data/models/NENet_CVPR2019.pth.tar --benchmark UPS_Custom_Dataset --bm_dir data/ToyPSDataset/
# Please check the outputs in data/models/
```
You may find the input arguments in `options/run_model_opts.py` (particularly `--have_l_dirs`, `--have_l_ints`, and `--have_gt_n`) useful when testing on your own dataset.
## Training
We adopted the publicly available synthetic [PS Blobby and Sculpture datasets](https://github.com/guanyingc/PS-FCN) for training.
To train a new SDPS-Net model, please follow these steps:
### Download the training data
```shell
# The total size of the zipped synthetic datasets is 4.7 + 19 = 23.7 GB,
# and it takes some time to download and unzip them.
sh scripts/download_synthetic_datasets.sh
```
If the above command does not work, please manually download the training datasets from BaiduYun ([PS Sculpture Dataset and PS Blobby Dataset](https://pan.baidu.com/s/1WUVu9ibIBh4wM1shTXBuNw?pwd=snyc)) and put them in `./data/datasets/`.
### First stage: train Lighting Calibration Network (LCNet)
```shell
# Train LCNet on synthetic datasets using 32 input images
CUDA_VISIBLE_DEVICES=0 python main_stage1.py --in_img_num 32
# Please refer to options/base_opts.py and options/stage1_opts.py for more options
# You can find checkpoints and results in data/logdir/
# It takes about 20 hours to train LCNet on a single Titan X Pascal GPU.
```
### Second stage: train Normal Estimation Network (NENet)
```shell
# Train NENet on synthetic datasets using 32 input images
CUDA_VISIBLE_DEVICES=0 python main_stage2.py --in_img_num 32 --retrain data/logdir/path/to/checkpointDirOfLCNet/checkpoint20.pth.tar
# Please refer to options/base_opts.py and options/stage2_opts.py for more options
# You can find checkpoints and results in data/logdir/
# It takes about 26 hours to train NENet on a single Titan X Pascal GPU.
```
## FAQ
#### Q1: How to test SDPS-Net on other datasets?
- You can implement a customized Dataset class to load your data, or use the provided `datasets/UPS_Custom_Dataset.py` Dataset class; in the latter case, you have to first arrange your dataset in the same format as `data/ToyPSDataset/`. Precomputed results on the DiLiGenT main dataset, Gourd\&Apple dataset, Light Stage Dataset and Synthetic Test dataset are available upon request.
#### Q2: What should I do if I have problems running your code?
- Please create an issue if you encounter errors when running the code, and feel free to submit a bug report.
#### Q3: Could I run your code only using CPUs?
- The good news is that you can simply append `--cuda` to your command to disable GPU usage. The running time for testing on the DiLiGenT benchmark using CPUs is still bearable (should be less than 20 minutes). However, training on CPUs is EXTREMELY SLOW!
## Citation
If you find this code or the provided models useful in your research, please consider citing:
```
@inproceedings{chen2019SDPS_Net,
title={SDPS-Net: Self-calibrating Deep Photometric Stereo Networks},
author={Chen, Guanying and Han, Kai and Shi, Boxin and Matsushita, Yasuyuki and Wong, Kwan-Yee~K.},
booktitle={CVPR},
year={2019}
}
```
================================================
FILE: data/.gitignore
================================================
*
!.gitignore
================================================
FILE: datasets/UPS_Custom_Dataset.py
================================================
from __future__ import division
import os
import numpy as np
import scipy.io as sio
from imageio import imread
import torch
import torch.utils.data as data
from datasets import pms_transforms
from . import util
np.random.seed(0)
class UPS_Custom_Dataset(data.Dataset):
def __init__(self, args, split='train'):
self.root = os.path.join(args.bm_dir)
self.split = split
self.args = args
self.objs = util.readList(os.path.join(self.root, 'objects.txt'), sort=False)
args.log.printWrite('[%s Data] \t%d objs. Root: %s' % (split, len(self.objs), self.root))
def _getMask(self, obj):
mask = imread(os.path.join(self.root, obj, 'mask.png'))
if mask.ndim > 2: mask = mask[:,:,0]
mask = mask.reshape(mask.shape[0], mask.shape[1], 1)
return mask / 255.0
def __getitem__(self, index):
obj = self.objs[index]
names = util.readList(os.path.join(self.root, obj, 'names.txt'))
img_list = [os.path.join(self.root, obj, names[i]) for i in range(len(names))]
if self.args.have_l_dirs:
dirs = np.genfromtxt(os.path.join(self.root, obj, 'light_directions.txt'))
else:
dirs = np.zeros((len(names), 3))
dirs[:,2] = 1
if self.args.have_l_ints:
ints = np.genfromtxt(os.path.join(self.root, obj, 'light_intensities.txt'))
else:
ints = np.ones((len(names), 3))
imgs = []
for idx, img_name in enumerate(img_list):
img = imread(img_name).astype(np.float32) / 255.0
imgs.append(img)
img = np.concatenate(imgs, 2)
h, w, c = img.shape
if self.args.have_gt_n:
normal_path = os.path.join(self.root, obj, 'normal.png')
normal = imread(normal_path).astype(np.float32) / 255.0 * 2 - 1
else:
normal = np.zeros((h, w, 3))
mask = self._getMask(obj)
img = img * mask.repeat(img.shape[2], 2)
item = {'normal': normal, 'img': img, 'mask': mask}
downsample = 4
for k in item.keys():
item[k] = pms_transforms.imgSizeToFactorOfK(item[k], downsample)
for k in item.keys():
item[k] = pms_transforms.arrayToTensor(item[k])
item['dirs'] = torch.from_numpy(dirs).view(-1, 1, 1).float()
item['ints'] = torch.from_numpy(ints).view(-1, 1, 1).float()
item['obj'] = obj
item['path'] = os.path.join(self.root, obj)
return item
def __len__(self):
return len(self.objs)
================================================
FILE: datasets/UPS_DiLiGenT_main.py
================================================
from __future__ import division
import os
import numpy as np
import scipy.io as sio
#from scipy.ndimage import imread
from imageio import imread
import torch
import torch.utils.data as data
from datasets import pms_transforms
from . import util
np.random.seed(0)
class UPS_DiLiGenT_main(data.Dataset):
def __init__(self, args, split='train'):
self.root = os.path.join(args.bm_dir)
self.split = split
self.args = args
self.objs = util.readList(os.path.join(self.root, 'objects.txt'), sort=False)
self.names = util.readList(os.path.join(self.root, 'names.txt'), sort=False)
self.l_dir = util.light_source_directions()
args.log.printWrite('[%s Data] \t%d objs %d lights. Root: %s' %
(split, len(self.objs), len(self.names), self.root))
self.ints = {}
ints_name = 'light_intensities.txt'
print('Files for intensity: %s' % (ints_name))
for obj in self.objs:
self.ints[obj] = np.genfromtxt(os.path.join(self.root, obj, ints_name))
def _getMask(self, obj):
mask = imread(os.path.join(self.root, obj, 'mask.png'))
if mask.ndim > 2: mask = mask[:,:,0]
mask = mask.reshape(mask.shape[0], mask.shape[1], 1)
return mask / 255.0
def __getitem__(self, index):
np.random.seed(index)
obj = self.objs[index]
select_idx = range(len(self.names))
img_list = [os.path.join(self.root, obj, self.names[i]) for i in select_idx]
ints = [np.diag(1 / self.ints[obj][i]) for i in select_idx]
dirs = self.l_dir[select_idx]
normal_path = os.path.join(self.root, obj, 'Normal_gt.mat')
normal = sio.loadmat(normal_path)['Normal_gt']
imgs = []
for idx, img_name in enumerate(img_list):
img = imread(img_name).astype(np.float32) / 255.0
if self.args.in_light and not self.args.int_aug:
img = np.dot(img, ints[idx])
imgs.append(img)
img = np.concatenate(imgs, 2)
mask = self._getMask(obj)
if self.args.test_resc:
img, normal = pms_transforms.rescale(img, normal, [self.args.test_h, self.args.test_w])
mask = pms_transforms.rescaleSingle(mask, [self.args.test_h, self.args.test_w])
img = img * mask.repeat(img.shape[2], 2)
norm = np.sqrt((normal * normal).sum(2, keepdims=True))
normal = normal / (norm + 1e-10)
item = {'normal': normal, 'img': img, 'mask': mask}
downsample = 4
for k in item.keys():
item[k] = pms_transforms.imgSizeToFactorOfK(item[k], downsample)
for k in item.keys():
item[k] = pms_transforms.arrayToTensor(item[k])
item['dirs'] = torch.from_numpy(dirs).view(-1, 1, 1).float()
item['ints'] = torch.from_numpy(self.ints[obj][select_idx]).view(-1, 1, 1).float()
item['obj'] = obj
item['path'] = os.path.join(self.root, obj)
return item
def __len__(self):
return len(self.objs)
================================================
FILE: datasets/UPS_Synth_Dataset.py
================================================
from __future__ import division
import os
import numpy as np
#from scipy.ndimage import imread
from imageio import imread
import torch
import torch.utils.data as data
from datasets import pms_transforms
from . import util
np.random.seed(0)
class UPS_Synth_Dataset(data.Dataset):
def __init__(self, args, root, split='train'):
self.root = os.path.join(root)
self.split = split
self.args = args
self.shape_list = util.readList(os.path.join(self.root, split + args.l_suffix))
def _getInputPath(self, index):
shape, mtrl = self.shape_list[index].split('/')
normal_path = os.path.join(self.root, 'Images', shape, shape + '_normal.png')
img_dir = os.path.join(self.root, 'Images', self.shape_list[index])
img_list = util.readList(os.path.join(img_dir, '%s_%s.txt' % (shape, mtrl)))
data = np.genfromtxt(img_list, dtype='str', delimiter=' ')
select_idx = np.random.permutation(data.shape[0])[:self.args.in_img_num]
idxs = ['%03d' % (idx) for idx in select_idx]
data = data[select_idx, :]
imgs = [os.path.join(img_dir, img) for img in data[:, 0]]
dirs = data[:, 1:4].astype(np.float32)
return normal_path, imgs, dirs
def __getitem__(self, index):
normal_path, img_list, dirs = self._getInputPath(index)
normal = imread(normal_path).astype(np.float32) / 255.0 * 2 - 1
imgs = []
for i in img_list:
img = imread(i).astype(np.float32) / 255.0
imgs.append(img)
img = np.concatenate(imgs, 2)
h, w, c = img.shape
crop_h, crop_w = self.args.crop_h, self.args.crop_w
if self.args.rescale and not (crop_h == h):
sc_h = np.random.randint(crop_h, h) if self.args.rand_sc else self.args.scale_h
sc_w = np.random.randint(crop_w, w) if self.args.rand_sc else self.args.scale_w
img, normal = pms_transforms.rescale(img, normal, [sc_h, sc_w])
if self.args.crop:
img, normal = pms_transforms.randomCrop(img, normal, [crop_h, crop_w])
if self.args.color_aug:
img = img * np.random.uniform(1, self.args.color_ratio)
if self.args.int_aug:
ints = pms_transforms.getIntensity(len(imgs))
img = np.dot(img, np.diag(ints.reshape(-1)))
else:
ints = np.ones(c)
if self.args.noise_aug:
img = pms_transforms.randomNoiseAug(img, self.args.noise)
mask = pms_transforms.normalToMask(normal)
normal = normal * mask.repeat(3, 2)
norm = np.sqrt((normal * normal).sum(2, keepdims=True))
normal = normal / (norm + 1e-10) # Normalize normals to unit length
item = {'normal': normal, 'img': img, 'mask': mask}
for k in item.keys():
item[k] = pms_transforms.arrayToTensor(item[k])
item['dirs'] = torch.from_numpy(dirs).view(-1, 1, 1).float()
item['ints'] = torch.from_numpy(ints).view(-1, 1, 1).float()
return item
def __len__(self):
return len(self.shape_list)
================================================
FILE: datasets/__init__.py
================================================
#from .PMS_dataset_v1 import PMS_dataset
#from .PMS_dataset_v2 import PMS_data_v2
#from .DiLiGenT import DiLiGenT
#__all__ = ('PMS_dataset', 'DiLiGenT', 'PMS_data_v2')
================================================
FILE: datasets/custom_data_loader.py
================================================
import torch.utils.data
def customDataloader(args):
args.log.printWrite("=> fetching img pairs in %s" % (args.data_dir))
datasets = __import__('datasets.' + args.dataset)
dataset_file = getattr(datasets, args.dataset)
train_set = getattr(dataset_file, args.dataset)(args, args.data_dir, 'train')
val_set = getattr(dataset_file, args.dataset)(args, args.data_dir, 'val')
if args.concat_data:
args.log.printWrite('****** Using concat data ******')
args.log.printWrite("=> fetching img pairs in '{}'".format(args.data_dir2))
train_set2 = getattr(dataset_file, args.dataset)(args, args.data_dir2, 'train')
val_set2 = getattr(dataset_file, args.dataset)(args, args.data_dir2, 'val')
train_set = torch.utils.data.ConcatDataset([train_set, train_set2])
val_set = torch.utils.data.ConcatDataset([val_set, val_set2])
args.log.printWrite('Found Data:\t %d Train and %d Val' % (len(train_set), len(val_set)))
args.log.printWrite('\t Train Batch: %d, Val Batch: %d' % (args.batch, args.val_batch))
train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch,
num_workers=args.workers, pin_memory=args.cuda, shuffle=True)
test_loader = torch.utils.data.DataLoader(val_set , batch_size=args.val_batch,
num_workers=args.workers, pin_memory=args.cuda, shuffle=False)
return train_loader, test_loader
def benchmarkLoader(args):
args.log.printWrite("=> fetching img pairs in 'data/%s'" % (args.benchmark))
datasets = __import__('datasets.' + args.benchmark)
dataset_file = getattr(datasets, args.benchmark)
test_set = getattr(dataset_file, args.benchmark)(args, 'test')
args.log.printWrite('Found Benchmark Data: %d samples' % (len(test_set)))
args.log.printWrite('\t Test Batch %d' % (args.test_batch))
test_loader = torch.utils.data.DataLoader(test_set, batch_size=args.test_batch,
num_workers=args.workers, pin_memory=args.cuda, shuffle=False)
return test_loader
================================================
FILE: datasets/pms_transforms.py
================================================
import torch
import random
import numpy as np
from skimage.transform import resize
random.seed(0)
np.random.seed(0)
def arrayToTensor(array):
if array is None:
return array
array = np.transpose(array, (2, 0, 1))
tensor = torch.from_numpy(array)
return tensor.float()
def normalToMask(normal, thres=1e-2):
"""
Due to the numerical precision of uint8, [0, 0, 0] is saved as [127, 127, 127] in the gt normal map.
When we load the data and rescale the normal by N / 255 * 2 - 1, [127, 127, 127] becomes
[-0.003927, -0.003927, -0.003927]
"""
mask = (np.square(normal).sum(2, keepdims=True) > thres).astype(np.float32)
return mask
def imgSizeToFactorOfK(img, k):
if img.shape[0] % k == 0 and img.shape[1] % k == 0:
return img
pad_h, pad_w = k - img.shape[0] % k, k - img.shape[1] % k
img = np.pad(img, ((0, pad_h), (0, pad_w), (0,0)),
'constant', constant_values=((0,0),(0,0),(0,0)))
return img
def randomCrop(inputs, target, size):
h, w, _ = inputs.shape
c_h, c_w = size
if h == c_h and w == c_w:
return inputs, target
x1 = random.randint(0, w - c_w)
y1 = random.randint(0, h - c_h)
inputs = inputs[y1: y1 + c_h, x1: x1 + c_w]
target = target[y1: y1 + c_h, x1: x1 + c_w]
return inputs, target
def centerCrop(inputs, size):
h, w, _ = inputs.shape
c_h, c_w = size
if h != c_h or w != c_w:
x1 = int(w / 2 - c_w / 2)
y1 = int(h / 2 - c_h / 2)
inputs = inputs[y1: y1 + c_h, x1: x1 + c_w]
return inputs
def rescale(inputs, target, size):
in_h, in_w, _ = inputs.shape
h, w = size
if h != in_h or w != in_w:
inputs = resize(inputs, size, order=1, mode='reflect')
target = resize(target, size, order=1, mode='reflect')
return inputs, target
def rescaleSingle(inputs, size, order=1):
in_h, in_w, _ = inputs.shape
h, w = size
if h != in_h or w != in_w:
inputs = resize(inputs, size, order=order, mode='reflect')
return inputs
def randomNoiseAug(inputs, noise_level=0.05):
noise = np.random.random(inputs.shape)
noise = (noise - 0.5) * noise_level
inputs += noise
return inputs
def getIntensity(num):
intensity = np.random.random((num, 1)) * 1.8 + 0.2
color = np.ones((1, 3)) # Uniform color
intens = (intensity.repeat(3, 1) * color)
return intens
================================================
FILE: datasets/util.py
================================================
import numpy as np
import re
def atoi(text):
return int(text) if text.isdigit() else text
def natural_keys(text):
'''
alist.sort(key=natural_keys) sorts in human order
http://nedbatchelder.com/blog/200712/human_sorting.html
(See Toothy's implementation in the comments)
'''
return [ atoi(c) for c in re.split(r'(\d+)', text) ]
def readList(list_path,ignore_head=False, sort=True):
lists = []
with open(list_path) as f:
lists = f.read().splitlines()
if ignore_head:
lists = lists[1:]
if sort:
lists.sort(key=natural_keys)
return lists
def light_source_directions():
"""
Below matrix is from DiLiGenT.
:return: light source direction matrix. [light_num, 3]
:rtype: np.ndarray
"""
L = np.array([[-0.06059872, -0.44839055, 0.8917812],
[-0.05939919, -0.33739538, 0.93948714],
[-0.05710194, -0.21230722, 0.97553319],
[-0.05360061, -0.07800089, 0.99551134],
[-0.04919816, 0.05869781, 0.99706274],
[-0.04399823, 0.19019233, 0.98076044],
[-0.03839991, 0.31049925, 0.9497977],
[-0.03280081, 0.41611025, 0.90872238],
[-0.18449839, -0.43989616, 0.87889232],
[-0.18870114, -0.32950199, 0.92510557],
[-0.1901994, -0.20549935, 0.95999698],
[-0.18849605, -0.07269848, 0.97937948],
[-0.18329657, 0.06229884, 0.98108166],
[-0.17500445, 0.19220488, 0.96562453],
[-0.16449474, 0.31129005, 0.93597008],
[-0.15270716, 0.4160195, 0.89644202],
[-0.30139786, -0.42509698, 0.85349393],
[-0.31020115, -0.31660118, 0.89640333],
[-0.31489186, -0.19549495, 0.92877599],
[-0.31450962, -0.06640203, 0.94692897],
[-0.30880699, 0.06470146, 0.94892147],
[-0.2981084, 0.19100538, 0.93522635],
[-0.28359251, 0.30729189, 0.90837601],
[-0.26670649, 0.41020998, 0.87212122],
[-0.40709586, -0.40559588, 0.81839168],
[-0.41919869, -0.29999906, 0.85689732],
[-0.42618633, -0.18329412, 0.88587159],
[-0.42691512, -0.05950211, 0.90233197],
[-0.42090385, 0.0659006, 0.90470827],
[-0.40860354, 0.18720162, 0.89330773],
[-0.39141794, 0.29941372, 0.87013988],
[-0.3707838, 0.39958255, 0.83836338],
[-0.499596, -0.38319693, 0.77689378],
[-0.51360334, -0.28130183, 0.81060526],
[-0.52190667, -0.16990217, 0.83591069],
[-0.52326874, -0.05249686, 0.85054918],
[-0.51720021, 0.06620003, 0.85330035],
[-0.50428312, 0.18139393, 0.84427174],
[-0.48561334, 0.28870793, 0.82512267],
[-0.46289771, 0.38549809, 0.79819605],
[-0.57853599, -0.35932235, 0.73224555],
[-0.59329349, -0.26189713, 0.76119165],
[-0.60202327, -0.15630604, 0.78303027],
[-0.6037003, -0.04570002, 0.7959004],
[-0.59781529, 0.06590169, 0.79892043],
[-0.58486953, 0.17439091, 0.79215873],
[-0.56588359, 0.27639198, 0.77677747],
[-0.54241965, 0.36921337, 0.75462733],
[0.05220076, -0.43870637, 0.89711304],
[0.05199786, -0.33138635, 0.9420612],
[0.05109826, -0.20999284, 0.97636672],
[0.04919919, -0.07869871, 0.99568366],
[0.04640163, 0.05630197, 0.99733494],
[0.04279892, 0.18779527, 0.98127529],
[0.03870043, 0.30950341, 0.95011048],
[0.03440055, 0.41730662, 0.90811441],
[0.17290651, -0.43181626, 0.88523333],
[0.17839998, -0.32509996, 0.92869988],
[0.18160174, -0.20480196, 0.96180921],
[0.18200745, -0.07490306, 0.98044012],
[0.17919505, 0.05849838, 0.98207285],
[0.17329685, 0.18839658, 0.96668244],
[0.1649036, 0.30880674, 0.93672045],
[0.1549931, 0.41578148, 0.89616009],
[0.28720483, -0.41910705, 0.8613145],
[0.29740177, -0.31410186, 0.90160535],
[0.30420604, -0.1965039, 0.9321185],
[0.30640529, -0.07010121, 0.94931639],
[0.30361153, 0.05950226, 0.95093613],
[0.29588748, 0.18589214, 0.93696036],
[0.28409783, 0.30349768, 0.90949304],
[0.26939905, 0.40849857, 0.87209694],
[0.39120402, -0.40190413, 0.8279085],
[0.40481085, -0.29960803, 0.86392315],
[0.41411685, -0.18590756, 0.89103626],
[0.41769724, -0.06449957, 0.906294],
[0.41498764, 0.05959822, 0.90787296],
[0.40607977, 0.18089099, 0.89575537],
[0.39179226, 0.29439419, 0.87168279],
[0.37379609, 0.39649585, 0.83849122],
[0.48278794, -0.38169046, 0.78818031],
[0.49848546, -0.28279175, 0.8194761],
[0.50918069, -0.1740934, 0.84286803],
[0.51360856, -0.05870098, 0.85601427],
[0.51097962, 0.05899765, 0.8575658],
[0.50151639, 0.17420569, 0.84742769],
[0.48600297, 0.28260173, 0.82700506],
[0.46600106, 0.38110087, 0.79850181],
[0.56150442, -0.35990283, 0.74510586],
[0.57807114, -0.26498677, 0.77176147],
[0.58933134, -0.1617086, 0.7915421],
[0.59407609, -0.05289787, 0.80266769],
[0.59157958, 0.057798, 0.80417224],
[0.58198189, 0.16649482, 0.79597523],
[0.56620006, 0.26940003, 0.77900008],
[0.54551481, 0.36380988, 0.7550205]], dtype=float)
return L
================================================
FILE: eval/run_stage1.py
================================================
import torch, sys
sys.path.append('.')
from datasets import custom_data_loader
from options import run_model_opts
from models import custom_model
from utils import logger, recorders
import test_stage1 as test_utils
args = run_model_opts.RunModelOpts().parse()
log = logger.Logger(args)
def main(args):
test_loader = custom_data_loader.benchmarkLoader(args)
model = custom_model.buildModel(args)
recorder = recorders.Records(args.log_dir)
test_utils.test(args, 'test', test_loader, model, log, 1, recorder)
log.plotCurves(recorder, 'test')
if __name__ == '__main__':
torch.manual_seed(args.seed)
main(args)
================================================
FILE: eval/run_stage2.py
================================================
import torch, sys
sys.path.append('.')
from datasets import custom_data_loader
from options import run_model_opts
from models import custom_model
from utils import logger, recorders
import test_stage2 as test_utils
args = run_model_opts.RunModelOpts().parse()
args.stage2 = True
args.test_resc = False
log = logger.Logger(args)
def main(args):
test_loader = custom_data_loader.benchmarkLoader(args)
model = custom_model.buildModel(args)
model_s2 = custom_model.buildModelStage2(args)
models = [model, model_s2]
recorder = recorders.Records(args.log_dir)
test_utils.test(args, 'test', test_loader, models, log, 1, recorder)
log.plotCurves(recorder, 'test')
if __name__ == '__main__':
torch.manual_seed(args.seed)
main(args)
================================================
FILE: main_stage1.py
================================================
import torch
from options import stage1_opts
from utils import logger, recorders
from datasets import custom_data_loader
from models import custom_model, solver_utils, model_utils
import train_stage1 as train_utils
import test_stage1 as test_utils
args = stage1_opts.TrainOpts().parse()
log = logger.Logger(args)
def main(args):
model = custom_model.buildModel(args)
optimizer, scheduler, records = solver_utils.configOptimizer(args, model)
criterion = solver_utils.Stage1ClsCrit(args)
recorder = recorders.Records(args.log_dir, records)
train_loader, val_loader = custom_data_loader.customDataloader(args)
for epoch in range(args.start_epoch, args.epochs+1):
scheduler.step()
recorder.insertRecord('train', 'lr', epoch, scheduler.get_lr()[0])
train_utils.train(args, train_loader, model, criterion, optimizer, log, epoch, recorder)
if epoch % args.save_intv == 0:
model_utils.saveCheckpoint(args.cp_dir, epoch, model, optimizer, recorder.records, args)
log.plotCurves(recorder, 'train')
if epoch % args.val_intv == 0:
test_utils.test(args, 'val', val_loader, model, log, epoch, recorder)
log.plotCurves(recorder, 'val')
if __name__ == '__main__':
torch.manual_seed(args.seed)
main(args)
================================================
FILE: main_stage2.py
================================================
import torch
from options import stage2_opts
from utils import logger, recorders
from datasets import custom_data_loader
from models import custom_model, solver_utils, model_utils
import train_stage2 as train_utils
import test_stage2 as test_utils
args = stage2_opts.TrainOpts().parse()
log = logger.Logger(args)
def main(args):
model = custom_model.buildModel(args)
model_s2 = custom_model.buildModelStage2(args)
models = [model, model_s2]
optimizer, scheduler, records = solver_utils.configOptimizer(args, model_s2)
optimizers = [optimizer, -1]
criterion = solver_utils.Stage2Crit(args)
recorder = recorders.Records(args.log_dir, records)
train_loader, val_loader = custom_data_loader.customDataloader(args)
for epoch in range(args.start_epoch, args.epochs+1):
scheduler.step()
recorder.insertRecord('train', 'lr', epoch, scheduler.get_lr()[0])
train_utils.train(args, train_loader, models, criterion, optimizers, log, epoch, recorder)
if epoch % args.save_intv == 0:
model_utils.saveCheckpoint(args.cp_dir, epoch, model_s2, optimizer, recorder.records, args)
log.plotCurves(recorder, 'train')
if epoch % args.val_intv == 0:
test_utils.test(args, 'val', val_loader, models, log, epoch, recorder)
log.plotCurves(recorder, 'val')
if __name__ == '__main__':
torch.manual_seed(args.seed)
main(args)
================================================
FILE: models/LCNet.py
================================================
import torch
import torch.nn as nn
from torch.nn.init import kaiming_normal_
from . import model_utils
from utils import eval_utils
# Classification
class FeatExtractor(nn.Module):
def __init__(self, batchNorm, c_in, c_out=256):
super(FeatExtractor, self).__init__()
self.conv1 = model_utils.conv(batchNorm, c_in, 64, k=3, stride=2, pad=1)
self.conv2 = model_utils.conv(batchNorm, 64, 128, k=3, stride=2, pad=1)
self.conv3 = model_utils.conv(batchNorm, 128, 128, k=3, stride=1, pad=1)
self.conv4 = model_utils.conv(batchNorm, 128, 128, k=3, stride=2, pad=1)
self.conv5 = model_utils.conv(batchNorm, 128, 128, k=3, stride=1, pad=1)
self.conv6 = model_utils.conv(batchNorm, 128, 256, k=3, stride=2, pad=1)
self.conv7 = model_utils.conv(batchNorm, 256, 256, k=3, stride=1, pad=1)
def forward(self, inputs):
out = self.conv1(inputs)
out = self.conv2(out)
out = self.conv3(out)
out = self.conv4(out)
out = self.conv5(out)
out = self.conv6(out)
out = self.conv7(out)
return out
class Classifier(nn.Module):
def __init__(self, batchNorm, c_in, other):
super(Classifier, self).__init__()
self.conv1 = model_utils.conv(batchNorm, 512, 256, k=3, stride=1, pad=1)
self.conv2 = model_utils.conv(batchNorm, 256, 256, k=3, stride=2, pad=1)
self.conv3 = model_utils.conv(batchNorm, 256, 256, k=3, stride=2, pad=1)
self.conv4 = model_utils.conv(batchNorm, 256, 256, k=3, stride=2, pad=1)
self.other = other
self.dir_x_est = nn.Sequential(
model_utils.conv(batchNorm, 256, 64, k=1, stride=1, pad=0),
model_utils.outputConv(64, other['dirs_cls'], k=1, stride=1, pad=0))
self.dir_y_est = nn.Sequential(
model_utils.conv(batchNorm, 256, 64, k=1, stride=1, pad=0),
model_utils.outputConv(64, other['dirs_cls'], k=1, stride=1, pad=0))
self.int_est = nn.Sequential(
model_utils.conv(batchNorm, 256, 64, k=1, stride=1, pad=0),
model_utils.outputConv(64, other['ints_cls'], k=1, stride=1, pad=0))
def forward(self, inputs):
out = self.conv1(inputs)
out = self.conv2(out)
out = self.conv3(out)
out = self.conv4(out)
outputs = {}
if self.other['s1_est_d']:
outputs['dir_x'] = self.dir_x_est(out)
outputs['dir_y'] = self.dir_y_est(out)
if self.other['s1_est_i']:
outputs['ints'] = self.int_est(out)
return outputs
class LCNet(nn.Module):
def __init__(self, fuse_type='max', batchNorm=False, c_in=3, other={}):
super(LCNet, self).__init__()
self.featExtractor = FeatExtractor(batchNorm, c_in, 128)
self.classifier = Classifier(batchNorm, 256, other)
self.c_in = c_in
self.fuse_type = fuse_type
self.other = other
for m in self.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
kaiming_normal_(m.weight.data)
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def prepareInputs(self, x):
n, c, h, w = x[0].shape
t_h, t_w = self.other['test_h'], self.other['test_w']
if (h == t_h and w == t_w):
imgs = x[0]
else:
            print('Rescaling images: from %dx%d to %dx%d' % (h, w, t_h, t_w))
            imgs = torch.nn.functional.interpolate(x[0], size=(t_h, t_w), mode='bilinear', align_corners=False)
inputs = list(torch.split(imgs, 3, 1))
idx = 1
if self.other['in_light']:
light = torch.split(x[idx], 3, 1)
for i in range(len(inputs)):
inputs[i] = torch.cat([inputs[i], light[i]], 1)
idx += 1
if self.other['in_mask']:
mask = x[idx]
if mask.shape[2] != inputs[0].shape[2] or mask.shape[3] != inputs[0].shape[3]:
                mask = torch.nn.functional.interpolate(mask, size=(t_h, t_w), mode='bilinear', align_corners=False)
for i in range(len(inputs)):
inputs[i] = torch.cat([inputs[i], mask], 1)
idx += 1
return inputs
def fuseFeatures(self, feats, fuse_type):
if fuse_type == 'mean':
feat_fused = torch.stack(feats, 1).mean(1)
elif fuse_type == 'max':
feat_fused, _ = torch.stack(feats, 1).max(1)
return feat_fused
def convertMidDirs(self, pred):
_, x_idx = pred['dirs_x'].data.max(1)
_, y_idx = pred['dirs_y'].data.max(1)
dirs = eval_utils.SphericalClassToDirs(x_idx, y_idx, self.other['dirs_cls'])
return dirs
def convertMidIntens(self, pred, img_num):
_, idx = pred['ints'].data.max(1)
ints = eval_utils.ClassToLightInts(idx, self.other['ints_cls'])
ints = ints.view(-1, 1).repeat(1, 3)
ints = torch.cat(torch.split(ints, ints.shape[0] // img_num, 0), 1)
return ints
def forward(self, x):
inputs = self.prepareInputs(x)
feats = []
for i in range(len(inputs)):
out_feat = self.featExtractor(inputs[i])
feats.append(out_feat)
feat_fused = self.fuseFeatures(feats, self.fuse_type)
l_dirs_x, l_dirs_y, l_ints = [], [], []
for i in range(len(inputs)):
net_input = torch.cat([feats[i], feat_fused], 1)
outputs = self.classifier(net_input)
if self.other['s1_est_d']:
l_dirs_x.append(outputs['dir_x'])
l_dirs_y.append(outputs['dir_y'])
if self.other['s1_est_i']:
l_ints.append(outputs['ints'])
pred = {}
if self.other['s1_est_d']:
pred['dirs_x'] = torch.cat(l_dirs_x, 0).squeeze()
pred['dirs_y'] = torch.cat(l_dirs_y, 0).squeeze()
pred['dirs'] = self.convertMidDirs(pred)
if self.other['s1_est_i']:
pred['ints'] = torch.cat(l_ints, 0).squeeze()
if pred['ints'].ndimension() == 1:
pred['ints'] = pred['ints'].view(1, -1)
pred['intens'] = self.convertMidIntens(pred, len(inputs))
return pred
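A minimal standalone sketch (plain Python, not repository code) of the order-agnostic fusion that `fuseFeatures` performs: per-image features are combined elementwise with max or mean, so the prediction is independent of how many images are given and of their order. The `fuse` helper below is hypothetical, for illustration only.

```python
# Hypothetical sketch of LCNet's order-agnostic feature fusion.
def fuse(feature_maps, fuse_type='max'):
    """feature_maps: list (one per image) of equal-length lists of floats."""
    if fuse_type == 'max':
        # elementwise max across images, like torch.stack(feats, 1).max(1)
        return [max(vals) for vals in zip(*feature_maps)]
    elif fuse_type == 'mean':
        # elementwise mean across images, like torch.stack(feats, 1).mean(1)
        return [sum(vals) / len(feature_maps) for vals in zip(*feature_maps)]
    raise ValueError('unknown fuse_type: %s' % fuse_type)

feats = [[0.1, 0.9], [0.5, 0.2], [0.3, 0.8]]
# reordering the input images leaves the fused feature unchanged
assert fuse(feats, 'max') == fuse(list(reversed(feats)), 'max')
```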
================================================
FILE: models/NENet.py
================================================
import torch
import torch.nn as nn
from torch.nn.init import kaiming_normal_
from . import model_utils
class FeatExtractor(nn.Module):
def __init__(self, batchNorm=False, c_in=3, other={}):
super(FeatExtractor, self).__init__()
self.other = other
self.conv1 = model_utils.conv(batchNorm, c_in, 64, k=3, stride=1, pad=1)
self.conv2 = model_utils.conv(batchNorm, 64, 128, k=3, stride=2, pad=1)
self.conv3 = model_utils.conv(batchNorm, 128, 128, k=3, stride=1, pad=1)
self.conv4 = model_utils.conv(batchNorm, 128, 256, k=3, stride=2, pad=1)
self.conv5 = model_utils.conv(batchNorm, 256, 256, k=3, stride=1, pad=1)
self.conv6 = model_utils.deconv(256, 128)
self.conv7 = model_utils.conv(batchNorm, 128, 128, k=3, stride=1, pad=1)
def forward(self, x):
out = self.conv1(x)
out = self.conv2(out)
out = self.conv3(out)
out = self.conv4(out)
out = self.conv5(out)
out = self.conv6(out)
out_feat = self.conv7(out)
n, c, h, w = out_feat.data.shape
out_feat = out_feat.view(-1)
return out_feat, [n, c, h, w]
class Regressor(nn.Module):
def __init__(self, batchNorm=False, other={}):
super(Regressor, self).__init__()
self.other = other
self.deconv1 = model_utils.conv(batchNorm, 128, 128, k=3, stride=1, pad=1)
self.deconv2 = model_utils.conv(batchNorm, 128, 128, k=3, stride=1, pad=1)
self.deconv3 = model_utils.deconv(128, 64)
self.est_normal = self._make_output(64, 3, k=3, stride=1, pad=1)
self.other = other
def _make_output(self, cin, cout, k=3, stride=1, pad=1):
return nn.Sequential(
nn.Conv2d(cin, cout, kernel_size=k, stride=stride, padding=pad, bias=False))
def forward(self, x, shape):
x = x.view(shape[0], shape[1], shape[2], shape[3])
out = self.deconv1(x)
out = self.deconv2(out)
out = self.deconv3(out)
normal = self.est_normal(out)
normal = torch.nn.functional.normalize(normal, 2, 1)
return normal
class NENet(nn.Module):
def __init__(self, fuse_type='max', batchNorm=False, c_in=3, other={}):
super(NENet, self).__init__()
self.extractor = FeatExtractor(batchNorm, c_in, other)
self.regressor = Regressor(batchNorm, other)
self.c_in = c_in
self.fuse_type = fuse_type
self.other = other
for m in self.modules():
if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
kaiming_normal_(m.weight.data)
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def prepareInputs(self, x):
imgs = torch.split(x[0], 3, 1)
idx = 1
if self.other['in_light']: idx += 1
if self.other['in_mask']: idx += 1
dirs = torch.split(x[idx]['dirs'], x[0].shape[0], 0)
ints = torch.split(x[idx]['intens'], 3, 1)
s2_inputs = []
for i in range(len(imgs)):
n, c, h, w = imgs[i].shape
l_dir = dirs[i] if dirs[i].dim() == 4 else dirs[i].view(n, -1, 1, 1)
l_int = torch.diag(1.0 / (ints[i].contiguous().view(-1)+1e-8))
img = imgs[i].contiguous().view(n * c, h * w)
img = torch.mm(l_int, img).view(n, c, h, w)
img_light = torch.cat([img, l_dir.expand_as(img)], 1)
s2_inputs.append(img_light)
return s2_inputs
def forward(self, x):
inputs = self.prepareInputs(x)
feats = torch.Tensor()
for i in range(len(inputs)):
feat, shape = self.extractor(inputs[i])
if i == 0:
feats = feat
else:
if self.fuse_type == 'mean':
feats = torch.stack([feats, feat], 1).sum(1)
elif self.fuse_type == 'max':
feats, _ = torch.stack([feats, feat], 1).max(1)
if self.fuse_type == 'mean':
            feats = feats / len(inputs)  # average the running sum over all inputs
feat_fused = feats
normal = self.regressor(feat_fused, shape)
pred = {}
pred['n'] = normal
return pred
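A standalone sketch (plain Python, hypothetical helper, not repository code) of what the diagonal matrix `l_int` in `prepareInputs` implements: each image channel is divided by its corresponding light intensity before being fed to the normal estimation network.

```python
# Hypothetical sketch of NENet's per-channel intensity normalization.
def normalize_by_intensity(channels, intens, eps=1e-8):
    """channels: list of per-channel pixel lists; intens: one scalar per channel.

    Mirrors torch.mm(l_int, img) where l_int = diag(1 / (intens + eps)).
    """
    return [[v / (i + eps) for v in ch] for ch, i in zip(channels, intens)]

out = normalize_by_intensity([[1.0, 0.5], [0.2]], [0.5, 2.0])
# first channel is scaled by 1/0.5, second by 1/2.0
```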
================================================
FILE: models/__init__.py
================================================
================================================
FILE: models/custom_model.py
================================================
from . import model_utils
import torch
def buildModel(args):
print('Creating Model %s' % (args.model))
in_c = model_utils.getInputChanel(args)
other = {
'img_num': args.in_img_num,
'test_h': args.test_h, 'test_w': args.test_w,
'in_mask': args.in_mask, 'in_light': args.in_light,
'dirs_cls': args.dirs_cls, 'ints_cls': args.ints_cls,
's1_est_d': args.s1_est_d, 's1_est_i': args.s1_est_i, 's1_est_n': args.s1_est_n,
}
models = __import__('models.' + args.model)
model_file = getattr(models, args.model)
model = getattr(model_file, args.model)(args.fuse_type, args.use_BN, in_c, other)
if args.cuda: model = model.cuda()
if args.retrain:
args.log.printWrite("=> using pre-trained model '{}'".format(args.retrain))
model_utils.loadCheckpoint(args.retrain, model, cuda=args.cuda)
if args.resume:
args.log.printWrite("=> Resume loading checkpoint '{}'".format(args.resume))
model_utils.loadCheckpoint(args.resume, model, cuda=args.cuda)
print(model)
args.log.printWrite("=> Model Parameters: %d" % (model_utils.get_n_params(model)))
return model
def buildModelStage2(args):
print('Creating Stage2 Model %s' % (args.model_s2))
in_c = 6 if args.s2_in_light else 3
other = {
'img_num': args.in_img_num,
'in_mask': args.in_mask, 'in_light': args.in_light,
'dirs_cls': args.dirs_cls, 'ints_cls': args.ints_cls,
}
models = __import__('models.' + args.model_s2)
model_file = getattr(models, args.model_s2)
model = getattr(model_file, args.model_s2)(args.fuse_type, args.use_BN, in_c, other)
if args.cuda: model = model.cuda()
if args.retrain_s2:
args.log.printWrite("=> using pre-trained model_s2 '{}'".format(args.retrain_s2))
model_utils.loadCheckpoint(args.retrain_s2, model, cuda=args.cuda)
print(model)
args.log.printWrite("=> Stage2 Model Parameters: %d" % (model_utils.get_n_params(model)))
return model
================================================
FILE: models/model_utils.py
================================================
import os
import torch
import torch.nn as nn
def getInput(args, data):
input_list = [data['img']]
if args.in_light: input_list.append(data['dirs'])
if args.in_mask: input_list.append(data['m'])
return input_list
def parseData(args, sample, timer=None, split='train'):
img, normal, mask = sample['img'], sample['normal'], sample['mask']
ints = sample['ints']
if args.in_light:
dirs = sample['dirs'].expand_as(img)
else: # predict lighting, prepare ground truth
n, c, h, w = sample['dirs'].shape
dirs_split = torch.split(sample['dirs'].view(n, c), 3, 1)
dirs = torch.cat(dirs_split, 0)
if timer: timer.updateTime('ToCPU')
if args.cuda:
img, normal, mask = img.cuda(), normal.cuda(), mask.cuda()
dirs, ints = dirs.cuda(), ints.cuda()
if timer: timer.updateTime('ToGPU')
data = {'img': img, 'n': normal, 'm': mask, 'dirs': dirs, 'ints': ints}
return data
def getInputChanel(args):
args.log.printWrite('[Network Input] Color image as input')
c_in = 3
if args.in_light:
args.log.printWrite('[Network Input] Adding Light direction as input')
c_in += 3
if args.in_mask:
args.log.printWrite('[Network Input] Adding Mask as input')
c_in += 1
args.log.printWrite('[Network Input] Input channel: {}'.format(c_in))
return c_in
def get_n_params(model):
    # total number of learnable parameter elements
    return sum(p.numel() for p in model.parameters())
def loadCheckpoint(path, model, cuda=True):
if cuda:
checkpoint = torch.load(path)
else:
checkpoint = torch.load(path, map_location=lambda storage, loc: storage)
model.load_state_dict(checkpoint['state_dict'])
def saveCheckpoint(save_path, epoch=-1, model=None, optimizer=None, records=None, args=None):
state = {'state_dict': model.state_dict(), 'model': args.model}
records = {'epoch': epoch, 'optimizer':optimizer.state_dict(), 'records': records} # 'args': args}
torch.save(state, os.path.join(save_path, 'checkp_{}.pth.tar'.format(epoch)))
torch.save(records, os.path.join(save_path, 'checkp_{}_rec.pth.tar'.format(epoch)))
def conv_ReLU(batchNorm, cin, cout, k=3, stride=1, pad=-1):
pad = pad if pad >= 0 else (k - 1) // 2
if batchNorm:
        print('=> convolutional layer with batchnorm')
return nn.Sequential(
nn.Conv2d(cin, cout, kernel_size=k, stride=stride, padding=pad, bias=False),
nn.BatchNorm2d(cout),
nn.ReLU(inplace=True)
)
else:
return nn.Sequential(
nn.Conv2d(cin, cout, kernel_size=k, stride=stride, padding=pad, bias=True),
nn.ReLU(inplace=True)
)
def conv(batchNorm, cin, cout, k=3, stride=1, pad=-1):
pad = pad if pad >= 0 else (k - 1) // 2
if batchNorm:
        print('=> convolutional layer with batchnorm')
return nn.Sequential(
nn.Conv2d(cin, cout, kernel_size=k, stride=stride, padding=pad, bias=False),
nn.BatchNorm2d(cout),
nn.LeakyReLU(0.1, inplace=True)
)
else:
return nn.Sequential(
nn.Conv2d(cin, cout, kernel_size=k, stride=stride, padding=pad, bias=True),
nn.LeakyReLU(0.1, inplace=True)
)
def outputConv(cin, cout, k=3, stride=1, pad=1):
return nn.Sequential(
nn.Conv2d(cin, cout, kernel_size=k, stride=stride, padding=pad, bias=True))
def deconv(cin, cout):
return nn.Sequential(
nn.ConvTranspose2d(cin, cout, kernel_size=4, stride=2, padding=1, bias=False),
nn.LeakyReLU(0.1, inplace=True)
)
def upconv(cin, cout):
return nn.Sequential(
nn.Upsample(scale_factor=2, mode='bilinear'),
nn.Conv2d(cin, cout, kernel_size=3, stride=1, padding=1, bias=False),
nn.LeakyReLU(0.1, inplace=True)
)
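For reference, the spatial sizes the `conv`/`deconv` helpers above produce follow the standard convolution and transposed-convolution arithmetic. A standalone sketch (hypothetical helpers, assuming square inputs and the k/stride/pad values used in this file):

```python
def conv_out(h, k=3, stride=1, pad=1):
    # standard convolution output size: floor((h + 2*pad - k) / stride) + 1
    return (h + 2 * pad - k) // stride + 1

def deconv_out(h, k=4, stride=2, pad=1):
    # transposed convolution output size: (h - 1)*stride - 2*pad + k
    return (h - 1) * stride - 2 * pad + k

h = 128
assert conv_out(h, k=3, stride=2, pad=1) == 64   # stride-2 conv halves the size
assert conv_out(64, k=3, stride=1, pad=1) == 64  # stride-1 conv keeps the size
assert deconv_out(64) == 128                     # deconv(...) doubles it back
```

This is why the k=4/stride=2/pad=1 `deconv` exactly undoes one k=3/stride=2/pad=1 `conv` for even input sizes.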
================================================
FILE: models/solver_utils.py
================================================
import torch
import os
from utils import eval_utils
class Stage1ClsCrit(object): # First Stage, Light classification criterion
def __init__(self, args):
print('==> Using Stage1ClsCrit for lighting classification')
self.s1_est_d = args.s1_est_d
self.s1_est_i = args.s1_est_i
self.ints_cls, self.dirs_cls = args.ints_cls, args.dirs_cls
self.setupLightCrit(args)
def setupLightCrit(self, args):
args.log.printWrite('=> Using light criterion')
if self.s1_est_d:
self.dir_w = args.dir_w
self.dirs_x_crit = torch.nn.CrossEntropyLoss()
self.dirs_y_crit = torch.nn.CrossEntropyLoss()
if args.cuda:
self.dirs_x_crit = self.dirs_x_crit.cuda()
self.dirs_y_crit = self.dirs_y_crit.cuda()
if self.s1_est_i:
self.ints_w = args.ints_w
self.ints_crit = torch.nn.CrossEntropyLoss()
if args.cuda: self.ints_crit = self.ints_crit.cuda()
def forward(self, output, target):
self.loss = 0
out_loss = {}
if self.s1_est_d:
est_dir_x, est_dir_y = output['dirs_x'], output['dirs_y']
gt_dir_x, gt_dir_y = eval_utils.SphericalDirsToClass(target['dirs'], self.dirs_cls)
dirs_x_loss = self.dirs_x_crit(est_dir_x, gt_dir_x)
dirs_y_loss = self.dirs_y_crit(est_dir_y, gt_dir_y)
out_loss['D_x_loss'] = dirs_x_loss.item()
out_loss['D_y_loss'] = dirs_y_loss.item()
self.loss += self.dir_w * (dirs_x_loss + dirs_y_loss)
if self.s1_est_i:
est_intens = output['ints']
gt_ints = target['ints'].squeeze()[:, 0: target['ints'].shape[1]:3]
gt_ints = torch.cat(torch.split(gt_ints, 1, 1), 0)
gt_intens = eval_utils.LightIntsToClass(gt_ints, self.ints_cls)
ints_loss = self.ints_crit(est_intens, gt_intens)
out_loss['I_loss'] = ints_loss.item()
self.loss += self.ints_w * ints_loss
return out_loss
def backward(self):
self.loss.backward()
class Stage2Crit(object): # Second stage
def __init__(self, args):
self.s2_est_n = args.s2_est_n
self.s2_est_d = args.s2_est_d
self.s2_est_i = args.s2_est_i
self.setupLightCrit(args)
if self.s2_est_n:
self.setupNormalCrit(args)
def setupLightCrit(self, args):
args.log.printWrite('=> Using light criterion')
if self.s2_est_d:
self.dir_w = args.dir_w
self.dirs_crit = torch.nn.CosineEmbeddingLoss()
if args.cuda: self.dirs_crit = self.dirs_crit.cuda()
if self.s2_est_i:
self.ints_w = args.ints_w
self.ints_crit = torch.nn.MSELoss()
if args.cuda: self.ints_crit = self.ints_crit.cuda()
def setupNormalCrit(self, args):
args.log.printWrite('=> Using {} for criterion normal'.format(args.normal_loss))
self.normal_w = args.normal_w
if args.normal_loss == 'mse':
self.n_crit = torch.nn.MSELoss()
elif args.normal_loss == 'cos':
self.n_crit = torch.nn.CosineEmbeddingLoss()
else:
raise Exception("=> Unknown Criterion '{}'".format(args.normal_loss))
if args.cuda:
self.n_crit = self.n_crit.cuda()
def forward(self, output, target):
self.loss = 0
out_loss = {}
if self.s2_est_d:
d_est, d_tar = output['dirs'], target['dirs']
d_num = d_tar.nelement() // d_tar.shape[1]
if not hasattr(self, 'l_flag') or d_num != self.l_flag.nelement():
self.l_flag = d_tar.data.new().resize_(d_num).fill_(1)
dirs_loss = self.dirs_crit(d_est.squeeze(), d_tar, self.l_flag)
self.loss += self.dir_w * dirs_loss
out_loss['D_loss'] = dirs_loss.item()
if self.s2_est_i:
i_est, i_tar = output['ints'], target['ints']
ints_loss = self.ints_crit(i_est, i_tar.squeeze())
self.loss += self.ints_w * ints_loss
out_loss['I_loss'] = ints_loss.item()
if self.s2_est_n:
n_est, n_tar = output['n'], target['n']
n_num = n_tar.nelement() // n_tar.shape[1]
if not hasattr(self, 'n_flag') or n_num != self.n_flag.nelement():
self.n_flag = n_tar.data.new().resize_(n_num).fill_(1)
self.out_reshape = n_est.permute(0, 2, 3, 1).contiguous().view(-1, 3)
self.gt_reshape = n_tar.permute(0, 2, 3, 1).contiguous().view(-1, 3)
normal_loss = self.n_crit(self.out_reshape, self.gt_reshape, self.n_flag)
self.loss += self.normal_w * normal_loss
out_loss['N_loss'] = normal_loss.item()
return out_loss
def backward(self):
self.loss.backward()
def getOptimizer(args, params):
args.log.printWrite('=> Using %s solver for optimization' % (args.solver))
if args.solver == 'adam':
optimizer = torch.optim.Adam(params, args.init_lr, betas=(args.beta_1, args.beta_2))
elif args.solver == 'sgd':
optimizer = torch.optim.SGD(params, args.init_lr, momentum=args.momentum)
else:
raise Exception("=> Unknown Optimizer %s" % (args.solver))
return optimizer
def getLrScheduler(args, optimizer):
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
milestones=args.milestones, gamma=args.lr_decay, last_epoch=args.start_epoch-2)
return scheduler
def loadRecords(path, model, optimizer):
records = None
if os.path.isfile(path):
records = torch.load(path[:-8] + '_rec' + path[-8:])
optimizer.load_state_dict(records['optimizer'])
start_epoch = records['epoch'] + 1
records = records['records']
print("=> loaded Records")
else:
raise Exception("=> no checkpoint found at '{}'".format(path))
return records, start_epoch
def configOptimizer(args, model):
records = None
optimizer = getOptimizer(args, model.parameters())
if args.resume:
args.log.printWrite("=> Resume loading checkpoint '{}'".format(args.resume))
records, start_epoch = loadRecords(args.resume, model, optimizer)
args.start_epoch = start_epoch
scheduler = getLrScheduler(args, optimizer)
return optimizer, scheduler, records
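A standalone sketch (hypothetical helper, not repository code) of the learning-rate decay that `getLrScheduler`'s MultiStepLR produces with the stage-1 defaults (init_lr=0.0005, gamma=0.5, milestones [5, 10, 15, 20, 25]); this is the decay rule, not MultiStepLR's exact internal bookkeeping:

```python
def lr_at_epoch(epoch, init_lr=0.0005, milestones=(5, 10, 15, 20, 25), gamma=0.5):
    # MultiStepLR: lr is multiplied by gamma once for each milestone passed
    passed = sum(1 for m in milestones if epoch >= m)
    return init_lr * (gamma ** passed)

lr_at_epoch(1)   # 0.0005  (no milestone reached)
lr_at_epoch(5)   # 0.00025 (first milestone)
lr_at_epoch(12)  # 0.000125 (two milestones passed)
```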
================================================
FILE: options/__init__.py
================================================
================================================
FILE: options/base_opts.py
================================================
import argparse
import os
import torch
class BaseOpts(object):
def __init__(self):
self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
def initialize(self):
#### Trainining Dataset ####
self.parser.add_argument('--dataset', default='UPS_Synth_Dataset')
self.parser.add_argument('--data_dir', default='data/datasets/PS_Blobby_Dataset')
self.parser.add_argument('--data_dir2', default='data/datasets/PS_Sculpture_Dataset')
self.parser.add_argument('--concat_data', default=True, action='store_false')
self.parser.add_argument('--l_suffix', default='_mtrl.txt')
#### Training Data and Preprocessing Arguments ####
self.parser.add_argument('--rescale', default=True, action='store_false')
self.parser.add_argument('--rand_sc', default=True, action='store_false')
self.parser.add_argument('--scale_h', default=128, type=int)
self.parser.add_argument('--scale_w', default=128, type=int)
self.parser.add_argument('--crop', default=True, action='store_false')
self.parser.add_argument('--crop_h', default=128, type=int)
self.parser.add_argument('--crop_w', default=128, type=int)
self.parser.add_argument('--test_h', default=128, type=int)
self.parser.add_argument('--test_w', default=128, type=int)
self.parser.add_argument('--test_resc', default=True, action='store_false')
self.parser.add_argument('--int_aug', default=True, action='store_false')
self.parser.add_argument('--noise_aug', default=True, action='store_false')
self.parser.add_argument('--noise', default=0.05, type=float)
self.parser.add_argument('--color_aug', default=True, action='store_false')
self.parser.add_argument('--color_ratio', default=3, type=float)
self.parser.add_argument('--normalize', default=False, action='store_true')
#### Device Arguments ####
self.parser.add_argument('--cuda', default=True, action='store_false')
self.parser.add_argument('--multi_gpu', default=False, action='store_true')
self.parser.add_argument('--time_sync', default=False, action='store_true')
self.parser.add_argument('--workers', default=8, type=int)
self.parser.add_argument('--seed', default=0, type=int)
#### Stage 1 Model Arguments ####
self.parser.add_argument('--dirs_cls', default=36, type=int)
self.parser.add_argument('--ints_cls', default=20, type=int)
self.parser.add_argument('--dir_int', default=False, action='store_true')
self.parser.add_argument('--model', default='LCNet')
self.parser.add_argument('--fuse_type', default='max')
self.parser.add_argument('--in_img_num', default=32, type=int)
self.parser.add_argument('--s1_est_n', default=False, action='store_true')
self.parser.add_argument('--s1_est_d', default=True, action='store_false')
self.parser.add_argument('--s1_est_i', default=True, action='store_false')
self.parser.add_argument('--in_light', default=False, action='store_true')
self.parser.add_argument('--in_mask', default=True, action='store_false')
self.parser.add_argument('--use_BN', default=False, action='store_true')
self.parser.add_argument('--resume', default=None)
self.parser.add_argument('--retrain', default=None)
self.parser.add_argument('--save_intv', default=1, type=int)
#### Stage 2 Model Arguments ####
self.parser.add_argument('--stage2', default=False, action='store_true')
self.parser.add_argument('--model_s2', default='NENet')
self.parser.add_argument('--retrain_s2', default=None)
self.parser.add_argument('--s2_est_n', default=True, action='store_false')
self.parser.add_argument('--s2_est_i', default=False, action='store_true')
self.parser.add_argument('--s2_est_d', default=False, action='store_true')
self.parser.add_argument('--s2_in_light', default=True, action='store_false')
#### Displaying Arguments ####
self.parser.add_argument('--train_disp', default=20, type=int)
self.parser.add_argument('--train_save', default=200, type=int)
self.parser.add_argument('--val_intv', default=1, type=int)
self.parser.add_argument('--val_disp', default=1, type=int)
self.parser.add_argument('--val_save', default=1, type=int)
self.parser.add_argument('--max_train_iter',default=-1, type=int)
self.parser.add_argument('--max_val_iter', default=-1, type=int)
self.parser.add_argument('--max_test_iter', default=-1, type=int)
self.parser.add_argument('--train_save_n', default=4, type=int)
self.parser.add_argument('--test_save_n', default=4, type=int)
#### Log Arguments ####
self.parser.add_argument('--save_root', default='data/logdir/')
self.parser.add_argument('--item', default='CVPR2019')
self.parser.add_argument('--suffix', default=None)
self.parser.add_argument('--debug', default=False, action='store_true')
self.parser.add_argument('--make_dir', default=True, action='store_false')
self.parser.add_argument('--save_split', default=False, action='store_true')
def setDefault(self):
if self.args.debug:
self.args.train_disp = 1
self.args.train_save = 1
self.args.max_train_iter = 4
self.args.max_val_iter = 4
self.args.max_test_iter = 4
self.args.test_intv = 1
def collectInfo(self):
self.args.str_keys = [
'model', 'fuse_type', 'solver'
]
self.args.val_keys = [
'batch', 'scale_h', 'crop_h', 'init_lr', 'normal_w',
'dir_w', 'ints_w', 'in_img_num', 'dirs_cls', 'ints_cls'
]
self.args.bool_keys = [
'use_BN', 'in_light', 'in_mask', 's1_est_n', 's1_est_d', 's1_est_i',
'color_aug', 'int_aug', 'concat_data', 'retrain', 'resume', 'stage2',
]
def parse(self):
self.args = self.parser.parse_args()
return self.args
================================================
FILE: options/run_model_opts.py
================================================
from .base_opts import BaseOpts
class RunModelOpts(BaseOpts):
def __init__(self):
super(RunModelOpts, self).__init__()
self.initialize()
def initialize(self):
BaseOpts.initialize(self)
#### Testing Dataset Arguments ####
self.parser.add_argument('--run_model', default=True, action='store_false')
self.parser.add_argument('--benchmark', default='UPS_DiLiGenT_main')
self.parser.add_argument('--bm_dir', default='data/datasets/DiLiGenT/pmsData_crop')
self.parser.add_argument('--epochs', default=1, type=int)
self.parser.add_argument('--test_batch', default=1, type=int)
self.parser.add_argument('--test_disp', default=1, type=int)
self.parser.add_argument('--test_save', default=1, type=int)
        ### For UPS_Custom_Dataset.py, to test your own dataset
self.parser.add_argument('--have_l_dirs', default=False, action='store_true', help='Have light directions?')
self.parser.add_argument('--have_l_ints', default=False, action='store_true', help='Have light intensities?')
self.parser.add_argument('--have_gt_n', default=False, action='store_true', help='Have GT surface normals?')
def collectInfo(self):
self.args.str_keys = ['model', 'model_s2', 'benchmark', 'fuse_type']
self.args.val_keys = ['in_img_num', 'test_h', 'test_w']
self.args.bool_keys = ['int_aug', 'test_resc']
def setDefault(self):
self.collectInfo()
def parse(self):
BaseOpts.parse(self)
self.setDefault()
return self.args
================================================
FILE: options/stage1_opts.py
================================================
from .base_opts import BaseOpts
class TrainOpts(BaseOpts):
def __init__(self):
super(TrainOpts, self).__init__()
self.initialize()
def initialize(self):
BaseOpts.initialize(self)
#### Training Arguments ####
self.parser.add_argument('--solver', default='adam', help='adam|sgd')
self.parser.add_argument('--milestones', default=[5, 10, 15, 20, 25], nargs='+', type=int)
self.parser.add_argument('--start_epoch', default=1, type=int)
self.parser.add_argument('--epochs', default=20, type=int)
self.parser.add_argument('--batch', default=32, type=int)
self.parser.add_argument('--val_batch', default=8, type=int)
self.parser.add_argument('--init_lr', default=0.0005, type=float)
self.parser.add_argument('--lr_decay', default=0.5, type=float)
self.parser.add_argument('--beta_1', default=0.9, type=float, help='adam')
self.parser.add_argument('--beta_2', default=0.999, type=float, help='adam')
self.parser.add_argument('--momentum', default=0.9, type=float, help='sgd')
self.parser.add_argument('--w_decay', default=4e-4, type=float)
#### Loss Arguments ####
self.parser.add_argument('--normal_loss', default='cos', help='cos|mse')
self.parser.add_argument('--normal_w', default=1, type=float)
self.parser.add_argument('--dir_loss', default='cos', help='cos|mse')
self.parser.add_argument('--dir_w', default=1, type=float)
self.parser.add_argument('--ints_loss', default='mse', help='l1|mse')
self.parser.add_argument('--ints_w', default=1, type=float)
def collectInfo(self):
BaseOpts.collectInfo(self)
self.args.str_keys += [
'dir_loss',
]
self.args.val_keys += [
]
self.args.bool_keys += [
]
def setDefault(self):
BaseOpts.setDefault(self)
if self.args.test_h != self.args.crop_h:
self.args.test_h, self.args.test_w = self.args.crop_h, self.args.crop_w
self.collectInfo()
def parse(self):
BaseOpts.parse(self)
self.setDefault()
return self.args
================================================
FILE: options/stage2_opts.py
================================================
from .base_opts import BaseOpts
class TrainOpts(BaseOpts):
def __init__(self):
super(TrainOpts, self).__init__()
self.initialize()
def initialize(self):
BaseOpts.initialize(self)
#### Training Arguments ####
self.parser.add_argument('--solver', default='adam', help='adam|sgd')
self.parser.add_argument('--milestones', default=[2, 4, 6, 8, 10], nargs='+', type=int)
self.parser.add_argument('--start_epoch', default=1, type=int)
self.parser.add_argument('--epochs', default=10, type=int)
self.parser.add_argument('--batch', default=16, type=int)
self.parser.add_argument('--val_batch', default=8, type=int)
self.parser.add_argument('--init_lr', default=0.0005, type=float)
self.parser.add_argument('--lr_decay', default=0.5, type=float)
self.parser.add_argument('--beta_1', default=0.9, type=float, help='adam')
self.parser.add_argument('--beta_2', default=0.999, type=float, help='adam')
self.parser.add_argument('--momentum', default=0.9, type=float, help='sgd')
self.parser.add_argument('--w_decay', default=4e-4, type=float)
#### Loss Arguments ####
self.parser.add_argument('--normal_loss', default='cos', help='cos|mse')
self.parser.add_argument('--normal_w', default=1, type=float)
self.parser.add_argument('--dir_loss', default='mse', help='cos|mse')
self.parser.add_argument('--dir_w', default=1, type=float)
self.parser.add_argument('--ints_loss', default='mse', help='mse')
self.parser.add_argument('--ints_w', default=1, type=float)
def collectInfo(self):
BaseOpts.collectInfo(self)
self.args.str_keys += [
'model_s2'
]
self.args.val_keys += [
]
self.args.bool_keys += [
]
def setDefault(self):
BaseOpts.setDefault(self)
self.args.stage2 = True
self.args.test_resc = False
self.collectInfo()
def parse(self):
BaseOpts.parse(self)
self.setDefault()
return self.args
================================================
FILE: scripts/DiLiGenT_objects.txt
================================================
ballPNG
catPNG
pot1PNG
bearPNG
pot2PNG
buddhaPNG
gobletPNG
readingPNG
cowPNG
harvestPNG
================================================
FILE: scripts/cropDiLiGenTData.py
================================================
import os, argparse, sys, shutil, glob
import numpy as np
from imageio import imread, imsave
import scipy.io as sio
root_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '../')
sys.path.append(root_path)
from utils import utils
parser = argparse.ArgumentParser()
parser.add_argument('--input_dir', default='data/datasets/DiLiGenT/pmsData')
parser.add_argument('--obj_list', default='objects.txt')
parser.add_argument('--suffix', default='crop')
parser.add_argument('--file_ext', default='.png')
parser.add_argument('--normal_name',default='Normal_gt.png')
parser.add_argument('--mask_name', default='mask.png')
parser.add_argument('--n_key', default='Normal_gt')
parser.add_argument('--pad', default=15, type=int)
args = parser.parse_args()
def getSaveDir():
save_dir = '%s_%s' % (args.input_dir, args.suffix)
utils.makeFile(save_dir)
print('Output dir: %s\n' % save_dir)
return save_dir
def getBBoxCompact(mask):
index = np.where(mask != 0)
    t, b, l, r = index[0].min(), index[0].max(), index[1].min(), index[1].max()
h, w = b - t + 2 * args.pad, r - l + 2 * args.pad
t = max(0, t - args.pad)
b = t + h
l = max(0, l - args.pad)
r = l + w
if h % 4 != 0:
pad = 4 - h % 4
b += pad; h += pad
if w % 4 != 0:
pad = 4 - w % 4
r += pad; w += pad
return l, r, t, b, h, w
def loadMaskNormal(d):
mask = imread(os.path.join(args.input_dir, d, args.mask_name))
try:
normal = imread(os.path.join(args.input_dir, d, args.normal_name))
except IOError:
normal = imread(os.path.join(args.input_dir, d, 'normal_gt.png'))
n_mat = sio.loadmat(os.path.join(args.input_dir, d, 'Normal_gt.mat'))[args.n_key]
h, w, c = normal.shape
print('Processing Objects: %s' % d, mask.shape)
if mask.ndim < 3:
mask = mask.reshape(h, w, 1).repeat(3, 2)
return mask, normal, n_mat
def copyTXT(d):
txt = glob.glob(os.path.join(args.input_dir, '*.txt'))
for t in txt:
name = os.path.basename(t)
shutil.copy(t, os.path.join(args.save_dir, name))
txt = glob.glob(os.path.join(args.input_dir, d, '*.txt'))
for t in txt:
name = os.path.basename(t)
shutil.copy(t, os.path.join(args.save_dir, d, name))
if __name__ == '__main__':
print('Input dir: %s\n' % args.input_dir)
args.save_dir = getSaveDir()
dir_list = utils.readList(os.path.join(args.input_dir, args.obj_list))
name_list = utils.readList(os.path.join(args.input_dir, 'names.txt'))
max_h, max_w = 0, 0
crop_list = open(os.path.join(args.save_dir, 'crop.txt'), 'w')
for d in dir_list:
utils.makeFile(os.path.join(args.save_dir, d))
mask, normal, n_mat = loadMaskNormal(d)
l, r, t, b, h, w = getBBoxCompact(mask[:,:,0] / 255)
crop_list.write('%d %d %d %d %d %d\n' % (mask.shape[0], mask.shape[1], l, r, t, b))
max_h = h if h > max_h else max_h
max_w = w if w > max_w else max_w
print('\t BBox L %d R %d T %d B %d, H:%d W:%d, Padded: %d %d' %
(l, r, t, b, h, w, r - l, b - t))
imsave(os.path.join(args.save_dir, d, args.mask_name), mask[t:b, l:r, :])
imsave(os.path.join(args.save_dir, d, args.normal_name), normal[t:b, l:r, :])
sio.savemat(os.path.join(args.save_dir, d, 'Normal_gt.mat'),
{args.n_key: n_mat[t:b, l:r, :]} ,do_compression=True)
copyTXT(d)
intens = np.genfromtxt(os.path.join(args.input_dir, d, 'light_intensities.txt'))
for idx, name in enumerate(name_list):
#print('Process img %d/%d' % (idx+1, len(name_list)))
img = imread(os.path.join(args.input_dir, d, name))
img = img[t:b, l:r, :]
imsave(os.path.join(args.save_dir, d, name), img.astype(np.uint8))
    print('Max H %d, Max W %d' % (max_h, max_w))
crop_list.close()
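The crop computed by `getBBoxCompact` above (tight bounding box, padded, then grown so height and width are multiples of 4, presumably to survive the network's downsampling) can be sketched as a self-contained pure-NumPy function. The `bbox_compact` name and the toy mask below are illustrative, not part of the repository:

```python
import numpy as np

def bbox_compact(mask, pad=15):
    # Tight bbox around the non-zero mask region, padded on each side,
    # then grown so height/width are divisible by 4 (as in getBBoxCompact).
    ys, xs = np.where(mask != 0)
    t, b, l, r = ys.min(), ys.max(), xs.min(), xs.max()
    h, w = b - t + 2 * pad, r - l + 2 * pad
    t = max(0, t - pad); b = t + h
    l = max(0, l - pad); r = l + w
    if h % 4:
        d = 4 - h % 4; b += d; h += d
    if w % 4:
        d = 4 - w % 4; r += d; w += d
    return l, r, t, b, h, w

# Toy 64x64 mask with a 20x20 foreground square.
mask = np.zeros((64, 64), np.uint8)
mask[20:40, 10:30] = 1
l, r, t, b, h, w = bbox_compact(mask, pad=2)
```

Note that, like the original, the function clamps `t` and `l` at the image border but keeps the padded `h`/`w`, so the box simply shifts when the object touches an edge.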
================================================
FILE: scripts/download_pretrained_models.sh
================================================
path="data/models/"
mkdir -p $path
cd $path
# Download pre-trained model
for model in "LCNet_CVPR2019.pth.tar" "NENet_CVPR2019.pth.tar"; do
wget http://www.visionlab.cs.hku.hk/data/SDPS-Net/models/${model}
done
# Back to root directory
cd ../../
================================================
FILE: scripts/download_synthetic_datasets.sh
================================================
mkdir -p data/datasets
cd data/datasets
# Download Synthetic dataset
for dataset in "PS_Sculpture_Dataset.tgz" "PS_Blobby_Dataset.tgz"; do
echo "Downloading $dataset"
wget http://www.visionlab.cs.hku.hk/data/PS-FCN/datasets/$dataset
tar -xvf $dataset
rm $dataset
done
# Back to root directory
cd ../../
================================================
FILE: scripts/prepare_diligent_dataset.sh
================================================
mkdir -p data/datasets
cd data/datasets
## Download real testing dataset
url="https://www.dropbox.com/s/hdnbh526tyvv68i/DiLiGenT.zip?dl=0"
name="DiLiGenT"
wget $url -O ${name}.zip
unzip ${name}.zip
rm ${name}.zip
cd ${name}/pmsData/
cp ballPNG/filenames.txt names.txt
# Back to root directory
cd ../../../../
cp scripts/DiLiGenT_objects.txt data/datasets/${name}/pmsData/objects.txt
python scripts/cropDiLiGenTData.py
================================================
FILE: test_stage1.py
================================================
import os
import torch
from models import model_utils
from utils import eval_utils, time_utils
import numpy as np
def get_itervals(args, split):
if split not in ['train', 'val', 'test']:
split = 'test'
args_var = vars(args)
disp_intv = args_var[split+'_disp']
save_intv = args_var[split+'_save']
stop_iters = args_var['max_'+split+'_iter']
return disp_intv, save_intv, stop_iters
def test(args, split, loader, model, log, epoch, recorder):
model.eval()
log.printWrite('---- Start %s Epoch %d: %d batches ----' % (split, epoch, len(loader)))
timer = time_utils.Timer(args.time_sync);
disp_intv, save_intv, stop_iters = get_itervals(args, split)
res = []
with torch.no_grad():
for i, sample in enumerate(loader):
data = model_utils.parseData(args, sample, timer, split)
input = model_utils.getInput(args, data)
pred = model(input); timer.updateTime('Forward')
            recorder, iter_res, error = prepareRes(args, data, pred, recorder, log, split)
res.append(iter_res)
iters = i + 1
if iters % disp_intv == 0:
opt = {'split':split, 'epoch':epoch, 'iters':iters, 'batch':len(loader),
'timer':timer, 'recorder': recorder}
log.printItersSummary(opt)
if iters % save_intv == 0:
results, nrow = prepareSave(args, data, pred)
log.saveImgResults(results, split, epoch, iters, nrow=nrow, error=error)
log.plotCurves(recorder, split, epoch=epoch, intv=disp_intv)
if stop_iters > 0 and iters >= stop_iters: break
res = np.vstack([np.array(res), np.array(res).mean(0)])
save_name = '%s_res.txt' % (args.suffix)
np.savetxt(os.path.join(args.log_dir, split, save_name), res, fmt='%.2f')
opt = {'split': split, 'epoch': epoch, 'recorder': recorder}
log.printEpochSummary(opt)
def prepareRes(args, data, pred, recorder, log, split):
data_batch = args.val_batch if split == 'val' else args.test_batch
iter_res = []
error = ''
if args.s1_est_d:
l_acc, data['dir_err'] = eval_utils.calDirsAcc(data['dirs'].data, pred['dirs'].data, data_batch)
recorder.updateIter(split, l_acc.keys(), l_acc.values())
iter_res.append(l_acc['l_err_mean'])
error += 'D_%.3f-' % (l_acc['l_err_mean'])
if args.s1_est_i:
int_acc, data['int_err'] = eval_utils.calIntsAcc(data['ints'].data, pred['intens'].data, data_batch)
recorder.updateIter(split, int_acc.keys(), int_acc.values())
iter_res.append(int_acc['ints_ratio'])
error += 'I_%.3f-' % (int_acc['ints_ratio'])
if args.s1_est_n:
acc, error_map = eval_utils.calNormalAcc(data['n'].data, pred['n'].data, data['m'].data)
recorder.updateIter(split, acc.keys(), acc.values())
iter_res.append(acc['n_err_mean'])
error += 'N_%.3f-' % (acc['n_err_mean'])
data['error_map'] = error_map['angular_map']
return recorder, iter_res, error
def prepareSave(args, data, pred):
results = [data['img'].data, data['m'].data, (data['n'].data+1)/2]
if args.s1_est_n:
pred_n = (pred['n'].data + 1) / 2
masked_pred = pred_n * data['m'].data.expand_as(pred['n'].data)
res_n = [pred_n, masked_pred, data['error_map']]
results += res_n
nrow = data['img'].shape[0]
return results, nrow
================================================
FILE: test_stage2.py
================================================
import os
import torch
from models import model_utils
from utils import eval_utils, time_utils
import numpy as np
def get_itervals(args, split):
if split not in ['train', 'val', 'test']:
split = 'test'
args_var = vars(args)
disp_intv = args_var[split+'_disp']
save_intv = args_var[split+'_save']
stop_iters = args_var['max_'+split+'_iter']
return disp_intv, save_intv, stop_iters
def test(args, split, loader, models, log, epoch, recorder):
models[0].eval()
models[1].eval()
log.printWrite('---- Start %s Epoch %d: %d batches ----' % (split, epoch, len(loader)))
timer = time_utils.Timer(args.time_sync);
disp_intv, save_intv, stop_iters = get_itervals(args, split)
res = []
with torch.no_grad():
for i, sample in enumerate(loader):
data = model_utils.parseData(args, sample, timer, split)
input = model_utils.getInput(args, data)
pred_c = models[0](input); timer.updateTime('Forward')
input.append(pred_c)
pred = models[1](input); timer.updateTime('Forward')
            recorder, iter_res, error = prepareRes(args, data, pred_c, pred, recorder, log, split)
res.append(iter_res)
iters = i + 1
if iters % disp_intv == 0:
opt = {'split':split, 'epoch':epoch, 'iters':iters, 'batch':len(loader),
'timer':timer, 'recorder': recorder}
log.printItersSummary(opt)
if iters % save_intv == 0:
results, nrow = prepareSave(args, data, pred_c, pred)
log.saveImgResults(results, split, epoch, iters, nrow=nrow, error='')
log.plotCurves(recorder, split, epoch=epoch, intv=disp_intv)
if stop_iters > 0 and iters >= stop_iters: break
res = np.vstack([np.array(res), np.array(res).mean(0)])
save_name = '%s_res.txt' % (args.suffix)
np.savetxt(os.path.join(args.log_dir, split, save_name), res, fmt='%.2f')
if res.ndim > 1:
for i in range(res.shape[1]):
save_name = '%s_%d_res.txt' % (args.suffix, i)
np.savetxt(os.path.join(args.log_dir, split, save_name), res[:,i], fmt='%.3f')
opt = {'split': split, 'epoch': epoch, 'recorder': recorder}
log.printEpochSummary(opt)
def prepareRes(args, data, pred_c, pred, recorder, log, split):
data_batch = args.val_batch if split == 'val' else args.test_batch
iter_res = []
error = ''
if args.s1_est_d:
l_acc, data['dir_err'] = eval_utils.calDirsAcc(data['dirs'].data, pred_c['dirs'].data, data_batch)
recorder.updateIter(split, l_acc.keys(), l_acc.values())
iter_res.append(l_acc['l_err_mean'])
error += 'D_%.3f-' % (l_acc['l_err_mean'])
if args.s1_est_i:
int_acc, data['int_err'] = eval_utils.calIntsAcc(data['ints'].data, pred_c['intens'].data, data_batch)
recorder.updateIter(split, int_acc.keys(), int_acc.values())
iter_res.append(int_acc['ints_ratio'])
error += 'I_%.3f-' % (int_acc['ints_ratio'])
if args.s2_est_n:
acc, error_map = eval_utils.calNormalAcc(data['n'].data, pred['n'].data, data['m'].data)
recorder.updateIter(split, acc.keys(), acc.values())
iter_res.append(acc['n_err_mean'])
error += 'N_%.3f-' % (acc['n_err_mean'])
data['error_map'] = error_map['angular_map']
return recorder, iter_res, error
def prepareSave(args, data, pred_c, pred):
results = [data['img'].data, data['m'].data, (data['n'].data+1) / 2]
if args.s2_est_n:
pred_n = (pred['n'].data + 1) / 2
masked_pred = pred_n * data['m'].data.expand_as(pred['n'].data)
res_n = [masked_pred, data['error_map']]
results += res_n
nrow = data['img'].shape[0]
return results, nrow
================================================
FILE: train_stage1.py
================================================
from models import model_utils
from utils import eval_utils, time_utils
def train(args, loader, model, criterion, optimizer, log, epoch, recorder):
model.train()
log.printWrite('---- Start Training Epoch %d: %d batches ----' % (epoch, len(loader)))
timer = time_utils.Timer(args.time_sync);
for i, sample in enumerate(loader):
data = model_utils.parseData(args, sample, timer, 'train')
input = model_utils.getInput(args, data)
pred = model(input); timer.updateTime('Forward')
optimizer.zero_grad()
loss = criterion.forward(pred, data);
timer.updateTime('Crit');
criterion.backward(); timer.updateTime('Backward')
recorder.updateIter('train', loss.keys(), loss.values())
optimizer.step(); timer.updateTime('Solver')
iters = i + 1
if iters % args.train_disp == 0:
opt = {'split':'train', 'epoch':epoch, 'iters':iters, 'batch':len(loader),
'timer':timer, 'recorder': recorder}
log.printItersSummary(opt)
if iters % args.train_save == 0:
results, recorder, nrow = prepareSave(args, data, pred, recorder, log)
log.saveImgResults(results, 'train', epoch, iters, nrow=nrow)
log.plotCurves(recorder, 'train', epoch=epoch, intv=args.train_disp)
if args.max_train_iter > 0 and iters >= args.max_train_iter: break
opt = {'split': 'train', 'epoch': epoch, 'recorder': recorder}
log.printEpochSummary(opt)
def prepareSave(args, data, pred, recorder, log):
results = [data['img'].data, data['m'].data, (data['n'].data+1)/2]
if args.s1_est_d:
l_acc, data['dir_err'] = eval_utils.calDirsAcc(data['dirs'].data, pred['dirs'].data, args.batch)
recorder.updateIter('train', l_acc.keys(), l_acc.values())
if args.s1_est_i:
int_acc, data['int_err'] = eval_utils.calIntsAcc(data['ints'].data, pred['intens'].data, args.batch)
recorder.updateIter('train', int_acc.keys(), int_acc.values())
if args.s1_est_n:
acc, error_map = eval_utils.calNormalAcc(data['n'].data, pred['n'].data, data['m'].data)
pred_n = (pred['n'].data + 1) / 2
masked_pred = pred_n * data['m'].data.expand_as(pred['n'].data)
res_n = [pred_n, masked_pred, error_map['angular_map']]
results += res_n
recorder.updateIter('train', acc.keys(), acc.values())
nrow = data['img'].shape[0] if data['img'].shape[0] <= 32 else 32
return results, recorder, nrow
================================================
FILE: train_stage2.py
================================================
import torch
from models import model_utils
from utils import eval_utils, time_utils
def train(args, loader, models, criterion, optimizers, log, epoch, recorder):
models[1].train()
models[0].eval()
optimizer, optimizer_c = optimizers
log.printWrite('---- Start Training Epoch %d: %d batches ----' % (epoch, len(loader)))
timer = time_utils.Timer(args.time_sync);
for i, sample in enumerate(loader):
data = model_utils.parseData(args, sample, timer, 'train')
input = model_utils.getInput(args, data)
with torch.no_grad():
pred_c = models[0](input);
input.append(pred_c)
pred = models[1](input); timer.updateTime('Forward')
optimizer.zero_grad()
loss = criterion.forward(pred, data);
timer.updateTime('Crit');
criterion.backward(); timer.updateTime('Backward')
recorder.updateIter('train', loss.keys(), loss.values())
optimizer.step(); timer.updateTime('Solver')
iters = i + 1
if iters % args.train_disp == 0:
opt = {'split':'train', 'epoch':epoch, 'iters':iters, 'batch':len(loader),
'timer':timer, 'recorder': recorder}
log.printItersSummary(opt)
if iters % args.train_save == 0:
results, recorder, nrow = prepareSave(args, data, pred_c, pred, recorder, log)
log.saveImgResults(results, 'train', epoch, iters, nrow=nrow)
log.plotCurves(recorder, 'train', epoch=epoch, intv=args.train_disp)
if args.max_train_iter > 0 and iters >= args.max_train_iter: break
opt = {'split': 'train', 'epoch': epoch, 'recorder': recorder}
log.printEpochSummary(opt)
def prepareSave(args, data, pred_c, pred, recorder, log):
input_var, mask_var = data['img'], data['m']
results = [input_var.data, mask_var.data, (data['n'].data+1)/2]
if args.s1_est_d:
l_acc, data['dir_err'] = eval_utils.calDirsAcc(data['dirs'].data, pred_c['dirs'].data, args.batch)
recorder.updateIter('train', l_acc.keys(), l_acc.values())
if args.s1_est_i:
int_acc, data['int_err'] = eval_utils.calIntsAcc(data['ints'].data, pred_c['intens'].data, args.batch)
recorder.updateIter('train', int_acc.keys(), int_acc.values())
if args.s2_est_n:
acc, error_map = eval_utils.calNormalAcc(data['n'].data, pred['n'].data, mask_var.data)
pred_n = (pred['n'].data + 1) / 2
masked_pred = pred_n * mask_var.data.expand_as(pred['n'].data)
res_n = [masked_pred, error_map['angular_map']]
results += res_n
recorder.updateIter('train', acc.keys(), acc.values())
nrow = input_var.shape[0] if input_var.shape[0] <= 32 else 32
return results, recorder, nrow
================================================
FILE: utils/__init__.py
================================================
================================================
FILE: utils/eval_utils.py
================================================
import torch
import math
import numpy as np
from matplotlib import cm
def colorMap(diff):
thres = 90
diff_norm = np.clip(diff, 0, thres) / thres
diff_cm = torch.from_numpy(cm.jet(diff_norm.numpy()))[:,:,:, :3]
return diff_cm.permute(0,3,1,2).clone().float()
def calDirsAcc(gt_l, pred_l, data_batch=1):
n, c = gt_l.shape
pred_l = pred_l.view(n, c)
dot_product = (gt_l * pred_l).sum(1).clamp(-1, 1)
angular_err = torch.acos(dot_product) * 180.0 / math.pi
l_err_mean = angular_err.mean()
return {'l_err_mean': l_err_mean.item()}, angular_err.squeeze()
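`calDirsAcc` reduces to a clamped dot product followed by `acos`. A dependency-free sketch of the same per-direction angular error (the function name is illustrative):

```python
import math

def angular_error_deg(a, b):
    # Angular error between two unit direction vectors, as in calDirsAcc:
    # clamp the dot product to [-1, 1] before acos to avoid NaN from
    # floating-point drift, then convert radians to degrees.
    dot = max(-1.0, min(1.0, sum(x * y for x, y in zip(a, b))))
    return math.degrees(math.acos(dot))

err = angular_error_deg((0.0, 0.0, 1.0), (0.0, 1.0, 0.0))  # 90.0
```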
def calIntsAcc(gt_i, pred_i, data_batch=1):
n, c, h, w = gt_i.shape
pred_i = pred_i.view(n, c, h, w)
ref_int = gt_i[:, :3].repeat(1, gt_i.shape[1] // 3, 1, 1)
gt_i = gt_i / ref_int
    scale = torch.gels(gt_i.view(-1, 1), pred_i.view(-1, 1))  # least-squares scale fit (torch.lstsq in newer PyTorch)
ints_ratio = (gt_i - scale[0][0] * pred_i).abs() / (gt_i + 1e-8)
ints_error = torch.stack(ints_ratio.split(3, 1), 1).mean(2)
return {'ints_ratio': ints_ratio.mean().item()}, ints_error.squeeze()
def calNormalAcc(gt_n, pred_n, mask=None):
"""Tensor Dim: NxCxHxW"""
dot_product = (gt_n * pred_n).sum(1).clamp(-1,1)
error_map = torch.acos(dot_product) # [-pi, pi]
angular_map = error_map * 180.0 / math.pi
angular_map = angular_map * mask.narrow(1, 0, 1).squeeze(1)
valid = mask.narrow(1, 0, 1).sum()
ang_valid = angular_map[mask.narrow(1, 0, 1).squeeze(1).byte()]
n_err_mean = ang_valid.sum() / valid
n_err_med = ang_valid.median()
n_acc_11 = (ang_valid < 11.25).sum().float() / valid
n_acc_30 = (ang_valid < 30).sum().float() / valid
n_acc_45 = (ang_valid < 45).sum().float() / valid
angular_map = colorMap(angular_map.cpu().squeeze(1))
value = {'n_err_mean': n_err_mean.item(),
'n_acc_11': n_acc_11.item(), 'n_acc_30': n_acc_30.item(), 'n_acc_45': n_acc_45.item()}
angular_error_map = {'angular_map': angular_map}
return value, angular_error_map
def SphericalDirsToClass(dirs, cls_num):
theta = torch.atan(dirs[:,0] / (dirs[:,2] + 1e-8))
denom = torch.sqrt(dirs[:,0] * dirs[:,0] + dirs[:,2] * dirs[:,2])
phi = torch.atan(dirs[:,1] / (denom + 1e-8))
theta = theta / np.pi * 180
phi = phi / np.pi * 180
azimuth = ((theta + 90.0) / 180 * cls_num).clamp(0, cls_num-1).long()
elevate = ((phi + 90.0) / 180 * cls_num).clamp(0, cls_num-1).long()
return azimuth, elevate
def SphericalClassToDirs(x_cls, y_cls, cls_num):
theta = (x_cls.float() + 0.5) / cls_num * 180 - 90
phi = (y_cls.float() + 0.5) / cls_num * 180 - 90
neg_x = theta < 0
neg_y = phi < 0
theta = theta.clamp(-90, 90) / 180.0 * np.pi
phi = phi.clamp(-90, 90) / 180.0 * np.pi
tan2_phi = pow(torch.tan(phi), 2)
tan2_theta = pow(torch.tan(theta), 2)
y = torch.sqrt(tan2_phi / (1 + tan2_phi))
y[neg_y] = y[neg_y] * -1
#y = torch.sin(phi)
z = torch.sqrt((1 - y * y) / (1 + tan2_theta))
x = z * torch.tan(theta)
dirs = torch.stack([x,y,z], 1)
dirs = dirs / dirs.norm(p=2, dim=1, keepdim=True)
return dirs
def LightIntsToClass(ints, cls_num):
ints = (ints - 0.2) / 1.8
ints = (ints * cls_num).clamp(0, cls_num-1).long()
return ints.view(-1)
def ClassToLightInts(cls, cls_num):
ints = (cls.float() + 0.5) / cls_num * 1.8 + 0.2
    ints = ints.clamp(0.2, 2.0)
return ints
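`LightIntsToClass` and `ClassToLightInts` form a quantize/dequantize pair over the intensity range [0.2, 2.0]. A pure-Python mirror (function names illustrative, torch-free) shows the round-trip error is bounded by half a bin width, i.e. `1.8 / cls_num / 2`:

```python
def ints_to_class(intensity, cls_num):
    # Mirror of LightIntsToClass: map [0.2, 2.0] linearly onto cls_num bins.
    normalized = (intensity - 0.2) / 1.8
    return min(max(int(normalized * cls_num), 0), cls_num - 1)

def class_to_ints(cls, cls_num):
    # Mirror of ClassToLightInts: bin centre mapped back into [0.2, 2.0].
    return min(max((cls + 0.5) / cls_num * 1.8 + 0.2, 0.2), 2.0)

cls = ints_to_class(1.0, 20)        # bin index for intensity 1.0
recovered = class_to_ints(cls, 20)  # bin centre, within 0.045 of 1.0
```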
================================================
FILE: utils/logger.py
================================================
import datetime, time, os
import numpy as np
import torch
import torchvision.utils as vutils
import scipy.io as sio
from . import utils
import matplotlib; matplotlib.use('agg')
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.font_manager import FontProperties
fontP = FontProperties()
fontP.set_size('small')
plt.rcParams["figure.figsize"] = (5,8)
class Logger(object):
def __init__(self, args):
self.times = {'init': time.time()}
if args.make_dir:
self._setupDirs(args)
self.args = args
args.log = self
self.printArgs()
def printArgs(self):
strs = '------------ Options -------------\n'
strs += '{}'.format(utils.dictToString(vars(self.args)))
strs += '-------------- End ----------------\n'
self.printWrite(strs)
def _addArguments(self, args):
info = ''
if hasattr(args, 'run_model') and args.run_model:
info += '_run_model,%s' % os.path.basename(args.retrain).split('.')[0]
arg_var = vars(args)
for k in args.str_keys:
info = '{0},{1}'.format(info, arg_var[k])
for k in args.val_keys:
var_key = k[:2] + '_' + k[-1]
info = '{0},{1}-{2}'.format(info, var_key, arg_var[k])
for k in args.bool_keys:
info = '{0},{1}'.format(info, k) if arg_var[k] else info
return info
def _setupDirs(self, args):
date_now = datetime.datetime.now()
dir_name = '%d-%d' % (date_now.month, date_now.day)
dir_name += (',%s' % args.suffix) if args.suffix else ''
dir_name += self._addArguments(args)
dir_name += ',DEBUG' if args.debug else ''
self._checkPath(args, dir_name)
file_dir = os.path.join(args.log_dir, '%s,%s' % (dir_name, date_now.strftime('%H:%M:%S')))
        self.log_file = open(file_dir, 'w')
        return

    def _checkPath(self, args, dir_name):
        if hasattr(args, 'run_model') and args.run_model:
            log_root = os.path.join(os.path.dirname(args.retrain), dir_name)
            args.log_dir = log_root
            sub_dirs = ['test']
        else:
            if args.resume and os.path.isfile(args.resume):
                log_root = os.path.join(os.path.dirname(os.path.dirname(args.resume)), dir_name)
            else:
                if args.debug:
                    dir_name = 'DEBUG/' + dir_name
                log_root = os.path.join(args.save_root, args.dataset, args.item, dir_name)
            args.log_dir = os.path.join(log_root, 'logdir')
            args.cp_dir = os.path.join(log_root, 'checkpointdir')
            utils.makeFiles([args.log_dir, args.cp_dir])
            sub_dirs = ['train', 'val']
        for sub_dir in sub_dirs:
            utils.makeFiles([os.path.join(args.log_dir, sub_dir, 'Images')])

    def printWrite(self, strs):
        print('%s' % strs)
        if self.args.make_dir:
            self.log_file.write('%s\n' % strs)
            self.log_file.flush()
def getTimeInfo(self, epoch, iters, batch):
time_elapsed = (time.time() - self.times['init']) / 3600.0
total_iters = (self.args.epochs - self.args.start_epoch + 1) * batch
cur_iters = (epoch - self.args.start_epoch) * batch + iters
time_total = time_elapsed * (float(total_iters) / cur_iters)
return time_elapsed, time_total
def printItersSummary(self, opt):
epoch, iters, batch = opt['epoch'], opt['iters'], opt['batch']
strs = ' | {}'.format(str.upper(opt['split']))
strs += ' Iter [{}/{}] Epoch [{}/{}]'.format(iters, batch, epoch, self.args.epochs)
if opt['split'] == 'train':
time_elapsed, time_total = self.getTimeInfo(epoch, iters, batch) # Buggy for test
strs += ' Clock [{:.2f}h/{:.2f}h]'.format(time_elapsed, time_total)
strs += ' LR [{}]'.format(opt['recorder'].records[opt['split']]['lr'][epoch][0])
self.printWrite(strs)
if 'timer' in opt.keys():
self.printWrite(opt['timer'].timeToString())
if 'recorder' in opt.keys():
self.printWrite(opt['recorder'].iterRecToString(opt['split'], epoch))
def printEpochSummary(self, opt):
split = opt['split']
epoch = opt['epoch']
self.printWrite('---------- {} Epoch {} Summary -----------'.format(str.upper(split), epoch))
self.printWrite(opt['recorder'].epochRecToString(split, epoch))
def convertToSameSize(self, t_list):
shape = (t_list[0].shape[0], 3, t_list[0].shape[2], t_list[0].shape[3])
for i, tensor in enumerate(t_list):
n, c, h, w = tensor.shape
if tensor.shape[1] != shape[1]: # check channel
t_list[i] = tensor.expand((n, shape[1], h, w))
if h == shape[2] and w == shape[3]:
continue
            t_list[i] = torch.nn.functional.interpolate(t_list[i], [shape[2], shape[3]], mode='bilinear', align_corners=False)
return t_list
def getSaveDir(self, split, epoch):
save_dir = os.path.join(self.args.log_dir, split, 'Images')
run_model = hasattr(self.args, 'run_model') and self.args.run_model
if not run_model and epoch > 0:
save_dir = os.path.join(save_dir, str(epoch))
utils.makeFile(save_dir)
return save_dir
    def splitMultiChannel(self, t_list, max_save_n=8):
        new_list = []
        for tensor in t_list:
            if tensor.shape[1] > 3:
                num = 3
                new_list += torch.split(tensor, num, 1)[:max_save_n]
            else:
                new_list.append(tensor)
        return new_list

    def saveSplit(self, res, save_prefix):
        n, c, h, w = res.shape
        for i in range(n):
            vutils.save_image(res[i], save_prefix + '_%d.png' % (i))

    def saveImgResults(self, results, split, epoch, iters, nrow, error=''):
        max_save_n = self.args.test_save_n if split == 'test' else self.args.train_save_n
        res = [img.cpu() for img in results]
        res = self.splitMultiChannel(res, max_save_n)
res = torch.cat(self.convertToSameSize(res))
save_dir = self.getSaveDir(split, epoch)
save_prefix = os.path.join(save_dir, '%d_%d' % (epoch, iters))
save_prefix += ('_%s' % error) if error != '' else ''
if self.args.save_split:
self.saveSplit(res, save_prefix)
else:
vutils.save_image(res, save_prefix + '_out.png', nrow=nrow)
def plotCurves(self, recorder, split='train', epoch=-1, intv=1):
dict_of_array = recorder.recordToDictOfArray(split, epoch, intv)
save_dir = os.path.join(self.args.log_dir, split)
if epoch < 0:
save_dir = self.args.log_dir
save_name = '%s_Summary.png' % (split)
else:
save_name = '%s_epoch_%d.png' % (split, epoch)
classes = ['loss', 'acc', 'err', 'lr', 'ratio']
classes = utils.checkIfInList(classes, dict_of_array.keys())
if len(classes) == 0: return
for idx, c in enumerate(classes):
plt.subplot(len(classes), 1, idx+1)
plt.grid()
legends = []
for k in dict_of_array.keys():
if (c in k.lower()) and not k.endswith('_x'):
plt.plot(dict_of_array[k+'_x'], dict_of_array[k])
legends.append(k)
if len(legends) != 0:
plt.legend(legends, bbox_to_anchor=(0.5, 1.05), loc='upper center',
ncol=len(legends), prop=fontP)
plt.title(c)
if epoch < 0: plt.xlabel('Epoch')
else: plt.xlabel('Iters')
plt.tight_layout()
plt.savefig(os.path.join(save_dir, save_name))
plt.clf()
================================================
FILE: utils/recorders.py
================================================
from collections import OrderedDict
import numpy as np
class Records(object):
"""
Records->Train,Val->Loss,Accuracy->Epoch1,2,3->[v1,v2]
IterRecords->Train,Val->Loss, Accuracy,->[v1,v2]
"""
def __init__(self, log_dir, records=None):
        if records is None:
self.records = OrderedDict()
else:
self.records = records
self.iter_rec = OrderedDict()
self.log_dir = log_dir
self.classes = ['loss', 'acc', 'err', 'ratio']
def resetIter(self):
self.iter_rec.clear()
def checkDict(self, a_dict, key, sub_type='dict'):
if key not in a_dict.keys():
if sub_type == 'dict':
a_dict[key] = OrderedDict()
if sub_type == 'list':
a_dict[key] = []
def updateIter(self, split, keys, values):
self.checkDict(self.iter_rec, split, 'dict')
for k, v in zip(keys, values):
self.checkDict(self.iter_rec[split], k, 'list')
self.iter_rec[split][k].append(v)
def saveIterRecord(self, epoch, reset=True):
for s in self.iter_rec.keys(): # s stands for split
self.checkDict(self.records, s, 'dict')
for k in self.iter_rec[s].keys():
self.checkDict(self.records[s], k, 'dict')
self.checkDict(self.records[s][k], epoch, 'list')
self.records[s][k][epoch].append(np.mean(self.iter_rec[s][k]))
if reset:
self.resetIter()
def insertRecord(self, split, key, epoch, value):
self.checkDict(self.records, split, 'dict')
self.checkDict(self.records[split], key, 'dict')
self.checkDict(self.records[split][key], epoch, 'list')
self.records[split][key][epoch].append(value)
def iterRecToString(self, split, epoch):
rec_strs = ''
for c in self.classes:
strs = ''
for k in self.iter_rec[split].keys():
if (c in k.lower()):
strs += '{}: {:.3f}| '.format(k, np.mean(self.iter_rec[split][k]))
if strs != '':
rec_strs += '\t [{}] {}\n'.format(c.upper(), strs)
self.saveIterRecord(epoch)
return rec_strs
def epochRecToString(self, split, epoch):
rec_strs = ''
for c in self.classes:
strs = ''
for k in self.records[split].keys():
if (c in k.lower()) and (epoch in self.records[split][k].keys()):
strs += '{}: {:.3f}| '.format(k, np.mean(self.records[split][k][epoch]))
if strs != '':
rec_strs += '\t [{}] {}\n'.format(c.upper(), strs)
return rec_strs
def recordToDictOfArray(self, splits, epoch=-1, intv=1):
if len(self.records) == 0: return {}
if type(splits) == str: splits = [splits]
dict_of_array = OrderedDict()
for split in splits:
for k in self.records[split].keys():
y_array, x_array = [], []
if epoch < 0:
for ep in self.records[split][k].keys():
y_array.append(np.mean(self.records[split][k][ep]))
x_array.append(ep)
else:
if epoch in self.records[split][k].keys():
y_array = np.array(self.records[split][k][epoch])
x_array = np.linspace(intv, intv*len(y_array), len(y_array))
dict_of_array[split[0] + split[-1] + '_' + k] = y_array
dict_of_array[split[0] + split[-1] + '_' + k+'_x'] = x_array
return dict_of_array
================================================
FILE: utils/time_utils.py
================================================
import time
import torch
from collections import OrderedDict
class Timer(object):
def __init__(self, cuda_sync=False):
self.timer = OrderedDict()
self.cuda_sync = cuda_sync
self.startTimer()
def startTimer(self):
self.iter_start = time.time()
self.disp_start = time.time()
def resetTimer(self):
self.iter_start = time.time()
self.disp_start = time.time()
for key in self.timer.keys(): self.timer[key].reset()
def updateTime(self, key):
if key not in self.timer.keys(): self.timer[key] = AverageMeter()
if self.cuda_sync: torch.cuda.synchronize()
self.timer[key].update(time.time() - self.iter_start)
self.iter_start = time.time()
def timeToString(self, reset=True):
strs = '\t [Time %.3fs] ' % (time.time() - self.disp_start)
for key in self.timer.keys():
if self.timer[key].sum < 1e-4: continue
strs += '%s: %.3fs| ' % (key, self.timer[key].sum)
self.resetTimer()
return strs
class AverageMeter(object):
def __init__(self):
self.reset()
def reset(self):
self.sum = 0
self.count = 0
self.avg = 0
def update(self, val, n=1):
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def __repr__(self):
return '%.3f' % (self.avg)
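A minimal usage sketch of the running-average pattern `AverageMeter` implements (the class is re-declared here so the snippet stands alone; the values are illustrative):

```python
class AverageMeter:
    """Minimal mirror of utils/time_utils.AverageMeter."""
    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        # Accumulate a value observed n times and refresh the mean.
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

meter = AverageMeter()
for v in [0.2, 0.4, 0.6]:
    meter.update(v)
print('%.3f' % meter.avg)  # prints 0.400
```

`Timer.updateTime` keeps one such meter per key, so `timeToString` can report the accumulated time of each stage ('Forward', 'Crit', 'Backward', ...) between display intervals.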
================================================
FILE: utils/utils.py
================================================
import os
import re
import numpy as np
from imageio import imread, imsave
def makeFile(f):
if not os.path.exists(f):
os.makedirs(f)
    #else: raise Exception('Rendered image directory %s already exists!' % f)
def makeFiles(f_list):
for f in f_list:
makeFile(f)
def emptyFile(name):
with open(name, 'w') as f:
f.write(' ')
def dictToString(dicts, start='\t', end='\n'):
strs = ''
for k, v in sorted(dicts.items()):
strs += '%s%s: %s%s' % (start, str(k), str(v), end)
return strs
def checkIfInList(list1, list2):
contains = []
for l1 in list1:
for l2 in list2:
if l1 in l2.lower():
contains.append(l1)
break
return contains
def atoi(text):
return int(text) if text.isdigit() else text
def natural_keys(text):
'''
alist.sort(key=natural_keys) sorts in human order
http://nedbatchelder.com/blog/200712/human_sorting.html
(See Toothy's implementation in the comments)
'''
    return [atoi(c) for c in re.split(r'(\d+)', text)]
def readList(list_path,ignore_head=False, sort=False):
lists = []
with open(list_path) as f:
lists = f.read().splitlines()
if ignore_head:
lists = lists[1:]
if sort:
lists.sort(key=natural_keys)
return lists
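The natural-sort helpers above can be exercised standalone; a self-contained sketch of the same human-order sort, with illustrative file names:

```python
import re

def atoi(text):
    return int(text) if text.isdigit() else text

def natural_keys(text):
    # Split digit runs into ints so 'img10.png' sorts after 'img2.png'
    # numerically instead of lexicographically.
    return [atoi(c) for c in re.split(r'(\d+)', text)]

names = ['img10.png', 'img2.png', 'img1.png']
print(sorted(names, key=natural_keys))  # ['img1.png', 'img2.png', 'img10.png']
```

This is what `readList(..., sort=True)` relies on when image file lists are not already in capture order.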