Repository: guanyingc/SDPS-Net Branch: master Commit: b84ca3aaa5cd Files: 41 Total size: 100.9 KB Directory structure: gitextract_g3mq9mkn/ ├── .gitignore ├── LICENSE.txt ├── README.md ├── data/ │ └── .gitignore ├── datasets/ │ ├── UPS_Custom_Dataset.py │ ├── UPS_DiLiGenT_main.py │ ├── UPS_Synth_Dataset.py │ ├── __init__.py │ ├── custom_data_loader.py │ ├── pms_transforms.py │ └── util.py ├── eval/ │ ├── run_stage1.py │ └── run_stage2.py ├── main_stage1.py ├── main_stage2.py ├── models/ │ ├── LCNet.py │ ├── NENet.py │ ├── __init__.py │ ├── custom_model.py │ ├── model_utils.py │ └── solver_utils.py ├── options/ │ ├── __init__.py │ ├── base_opts.py │ ├── run_model_opts.py │ ├── stage1_opts.py │ └── stage2_opts.py ├── scripts/ │ ├── DiLiGenT_objects.txt │ ├── cropDiLiGenTData.py │ ├── download_pretrained_models.sh │ ├── download_synthetic_datasets.sh │ └── prepare_diligent_dataset.sh ├── test_stage1.py ├── test_stage2.py ├── train_stage1.py ├── train_stage2.py └── utils/ ├── __init__.py ├── eval_utils.py ├── logger.py ├── recorders.py ├── time_utils.py └── utils.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ *.jpg *.png tags *.pyc *.pyo ================================================ FILE: LICENSE.txt ================================================ MIT License Copyright (c) 2018 Guanying Chen Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be 
included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================ # SDPS-Net **[SDPS-Net: Self-calibrating Deep Photometric Stereo Networks, CVPR 2019 (Oral)](http://guanyingc.github.io/SDPS-Net/)**.
[Guanying Chen](https://guanyingc.github.io), [Kai Han](http://www.hankai.org/), [Boxin Shi](http://alumni.media.mit.edu/~shiboxin/), [Yasuyuki Matsushita](http://www-infobiz.ist.osaka-u.ac.jp/en/member/matsushita/), [Kwan-Yee K. Wong](http://i.cs.hku.hk/~kykwong/)
This paper addresses the problem of learning-based _uncalibrated_ photometric stereo for non-Lambertian surfaces.

### _Changelog_ - July 28, 2019: We have already updated the code to support Python 3.7 + PyTorch 1.10. To run the previous version (Python 2.7 + PyTorch 0.40), please checkout to `python2.7` branch first (by `git checkout python2.7`). ## Dependencies SDPS-Net is implemented in [PyTorch](https://pytorch.org/) and tested with Ubuntu (14.04 and 16.04), please install PyTorch first following the official instruction. - Python 3.7 - PyTorch (version = 1.10) - torchvision - CUDA-9.0 # Skip this if you only have CPUs in your computer - numpy - scipy - scikit-image You are highly recommended to use Anaconda and create a new environment to run this code. ```shell # Create a new python3.7 environment named py3.7 conda create -n py3.7 python=3.7 # Activate the created environment source activate py3.7 # Example commands for installing the dependencies conda install pytorch torchvision cudatoolkit=9.0 -c pytorch conda install -c anaconda scipy conda install -c anaconda scikit-image # Download this code git clone https://github.com/guanyingc/SDPS-Net.git cd SDPS-Net ``` ## Overview We provide: - Trained models - LCNet for lighting calibration from input images - NENet for normal estimation from input images and estimated lightings. - Code to test on DiLiGenT main dataset - Full code to train a new model, including codes for debugging, visualization and logging. ## Testing ### Download the trained models ``` sh scripts/download_pretrained_models.sh ``` If the above command is not working, please manually download the trained models from BaiduYun ([LCNet and NENet](https://pan.baidu.com/s/10huOyPkfDSkDUK23_j4y1w?pwd=i5ha)) and put them in `./data/models/`. ### Test SDPS-Net on the DiLiGenT main dataset ```shell # Prepare the DiLiGenT main dataset sh scripts/prepare_diligent_dataset.sh # This command will first download and unzip the DiLiGenT dataset, and then centered crop # the original images based on the object mask with a margin size of 15 pixels. 
# Test SDPS-Net on DiLiGenT main dataset using all of the 96 images
To train a new SDPS-Net model, please follow the following steps: ### Download the training data ```shell # The total size of the zipped synthetic datasets is 4.7+19=23.7 GB # and it takes some times to download and unzip the datasets. sh scripts/download_synthetic_datasets.sh ``` If the above command is not working, please manually download the training datasets from BaiduYun ([PS Sculpture Dataset and PS Blobby Dataset](https://pan.baidu.com/s/1WUVu9ibIBh4wM1shTXBuNw?pwd=snyc) and put them in `./data/datasets/`. ### First stage: train Lighting Calibration Network (LCNet) ```shell # Train LCNet on synthetic datasets using 32 input images CUDA_VISIBLE_DEVICES=0 python main_stage1.py --in_img_num 32 # Please refer to options/base_opt.py and options/stage1_opt.py for more options # You can find checkpoints and results in data/logdir/ # It takes about 20 hours to train LCNet on a single Titan X Pascal GPU. ``` ### Second stage: train Normal Estimation Network (NENet) ```shell # Train NENet on synthetic datasets using 32 input images CUDA_VISIBLE_DEVICES=0 python main_stage2.py --in_img_num 32 --retrain data/logdir/path/to/checkpointDirOfLCNet/checkpoint20.pth.tar # Please refer to options/base_opt.py and options/stage2_opt.py for more options # You can find checkpoints and results in data/logdir/ # It takes about 26 hours to train NENet on a single Titan X Pascal GPU. ``` ## FAQ #### Q1: How to test SDPS-Net on other datasets? - You can implement a customized Dataset class to load your data. You may also use the provided `datasets/UPS_Custom_Dataset.py` Dataset class to load your data. However, you have to first arrange your dataset in the same format as the `data/ToyPSDataset/`. Precomputed results on DiLiGenT main dataset, Gourd\&Apple dataset, Light Stage Dataset and Synthetic Test dataset are available upon request. #### Q2: What should I do if I have problem in running your code? - Please create an issue if you encounter errors when trying to run the code. 
Please also feel free to submit a bug report. #### Q3: Could I run your code only using CPUs? - The good news is that you can simply append `--cuda` in your command to disable the usage of GPU. The running time for the testing on DiLiGenT benchmark using CPUs is still bearable (should be less than 20 minutes). However, it is EXTREMELY SLOW for training! ## Citation If you find this code or the provided models useful in your research, please consider cite: ``` @inproceedings{chen2019SDPS_Net, title={SDPS-Net: Self-calibrating Deep Photometric Stereo Networks}, author={Chen, Guanying and Han, Kai and Shi, Boxin and Matsushita, Yasuyuki and Wong, Kwan-Yee~K.}, booktitle={CVPR}, year={2019} } ``` ================================================ FILE: data/.gitignore ================================================ * !.gitignore ================================================ FILE: datasets/UPS_Custom_Dataset.py ================================================ from __future__ import division import os import numpy as np import scipy.io as sio from imageio import imread import torch import torch.utils.data as data from datasets import pms_transforms from . import util np.random.seed(0) class UPS_Custom_Dataset(data.Dataset): def __init__(self, args, split='train'): self.root = os.path.join(args.bm_dir) self.split = split self.args = args self.objs = util.readList(os.path.join(self.root, 'objects.txt'), sort=False) args.log.printWrite('[%s Data] \t%d objs. 
Root: %s' % (split, len(self.objs), self.root)) def _getMask(self, obj): mask = imread(os.path.join(self.root, obj, 'mask.png')) if mask.ndim > 2: mask = mask[:,:,0] mask = mask.reshape(mask.shape[0], mask.shape[1], 1) return mask / 255.0 def __getitem__(self, index): obj = self.objs[index] names = util.readList(os.path.join(self.root, obj, 'names.txt')) img_list = [os.path.join(self.root, obj, names[i]) for i in range(len(names))] if self.args.have_l_dirs: dirs = np.genfromtxt(os.path.join(self.root, obj, 'light_directions.txt')) else: dirs = np.zeros((len(names), 3)) dirs[:,2] = 1 if self.args.have_l_ints: ints = np.genfromtxt(os.path.join(self.root, obj, 'light_intensities.txt')) else: ints = np.ones((len(names), 3)) imgs = [] for idx, img_name in enumerate(img_list): img = imread(img_name).astype(np.float32) / 255.0 imgs.append(img) img = np.concatenate(imgs, 2) h, w, c = img.shape if self.args.have_gt_n: normal_path = os.path.join(self.root, obj, 'normal.png') normal = imread(normal_path).astype(np.float32) / 255.0 * 2 - 1 else: normal = np.zeros((h, w, 3)) mask = self._getMask(obj) img = img * mask.repeat(img.shape[2], 2) item = {'normal': normal, 'img': img, 'mask': mask} downsample = 4 for k in item.keys(): item[k] = pms_transforms.imgSizeToFactorOfK(item[k], downsample) for k in item.keys(): item[k] = pms_transforms.arrayToTensor(item[k]) item['dirs'] = torch.from_numpy(dirs).view(-1, 1, 1).float() item['ints'] = torch.from_numpy(ints).view(-1, 1, 1).float() item['obj'] = obj item['path'] = os.path.join(self.root, obj) return item def __len__(self): return len(self.objs) ================================================ FILE: datasets/UPS_DiLiGenT_main.py ================================================ from __future__ import division import os import numpy as np import scipy.io as sio #from scipy.ndimage import imread from imageio import imread import torch import torch.utils.data as data from datasets import pms_transforms from . 
import util np.random.seed(0) class UPS_DiLiGenT_main(data.Dataset): def __init__(self, args, split='train'): self.root = os.path.join(args.bm_dir) self.split = split self.args = args self.objs = util.readList(os.path.join(self.root, 'objects.txt'), sort=False) self.names = util.readList(os.path.join(self.root, 'names.txt'), sort=False) self.l_dir = util.light_source_directions() args.log.printWrite('[%s Data] \t%d objs %d lights. Root: %s' % (split, len(self.objs), len(self.names), self.root)) self.ints = {} ints_name = 'light_intensities.txt' print('Files for intensity: %s' % (ints_name)) for obj in self.objs: self.ints[obj] = np.genfromtxt(os.path.join(self.root, obj, ints_name)) def _getMask(self, obj): mask = imread(os.path.join(self.root, obj, 'mask.png')) if mask.ndim > 2: mask = mask[:,:,0] mask = mask.reshape(mask.shape[0], mask.shape[1], 1) return mask / 255.0 def __getitem__(self, index): np.random.seed(index) obj = self.objs[index] select_idx = range(len(self.names)) img_list = [os.path.join(self.root, obj, self.names[i]) for i in select_idx] ints = [np.diag(1 / self.ints[obj][i]) for i in select_idx] dirs = self.l_dir[select_idx] normal_path = os.path.join(self.root, obj, 'Normal_gt.mat') normal = sio.loadmat(normal_path)['Normal_gt'] imgs = [] for idx, img_name in enumerate(img_list): img = imread(img_name).astype(np.float32) / 255.0 if self.args.in_light and not self.args.int_aug: img = np.dot(img, ints[idx]) imgs.append(img) img = np.concatenate(imgs, 2) mask = self._getMask(obj) if self.args.test_resc: img, normal = pms_transforms.rescale(img, normal, [self.args.test_h, self.args.test_w]) mask = pms_transforms.rescaleSingle(mask, [self.args.test_h, self.args.test_w]) img = img * mask.repeat(img.shape[2], 2) norm = np.sqrt((normal * normal).sum(2, keepdims=True)) normal = normal / (norm + 1e-10) item = {'normal': normal, 'img': img, 'mask': mask} downsample = 4 for k in item.keys(): item[k] = pms_transforms.imgSizeToFactorOfK(item[k], downsample) 
for k in item.keys(): item[k] = pms_transforms.arrayToTensor(item[k]) item['dirs'] = torch.from_numpy(dirs).view(-1, 1, 1).float() item['ints'] = torch.from_numpy(self.ints[obj][select_idx]).view(-1, 1, 1).float() item['obj'] = obj item['path'] = os.path.join(self.root, obj) return item def __len__(self): return len(self.objs) ================================================ FILE: datasets/UPS_Synth_Dataset.py ================================================ from __future__ import division import os import numpy as np #from scipy.ndimage import imread from imageio import imread import torch import torch.utils.data as data from datasets import pms_transforms from . import util np.random.seed(0) class UPS_Synth_Dataset(data.Dataset): def __init__(self, args, root, split='train'): self.root = os.path.join(root) self.split = split self.args = args self.shape_list = util.readList(os.path.join(self.root, split + args.l_suffix)) def _getInputPath(self, index): shape, mtrl = self.shape_list[index].split('/') normal_path = os.path.join(self.root, 'Images', shape, shape + '_normal.png') img_dir = os.path.join(self.root, 'Images', self.shape_list[index]) img_list = util.readList(os.path.join(img_dir, '%s_%s.txt' % (shape, mtrl))) data = np.genfromtxt(img_list, dtype='str', delimiter=' ') select_idx = np.random.permutation(data.shape[0])[:self.args.in_img_num] idxs = ['%03d' % (idx) for idx in select_idx] data = data[select_idx, :] imgs = [os.path.join(img_dir, img) for img in data[:, 0]] dirs = data[:, 1:4].astype(np.float32) return normal_path, imgs, dirs def __getitem__(self, index): normal_path, img_list, dirs = self._getInputPath(index) normal = imread(normal_path).astype(np.float32) / 255.0 * 2 - 1 imgs = [] for i in img_list: img = imread(i).astype(np.float32) / 255.0 imgs.append(img) img = np.concatenate(imgs, 2) h, w, c = img.shape crop_h, crop_w = self.args.crop_h, self.args.crop_w if self.args.rescale and not (crop_h == h): sc_h = np.random.randint(crop_h, h) if 
self.args.rand_sc else self.args.scale_h sc_w = np.random.randint(crop_w, w) if self.args.rand_sc else self.args.scale_w img, normal = pms_transforms.rescale(img, normal, [sc_h, sc_w]) if self.args.crop: img, normal = pms_transforms.randomCrop(img, normal, [crop_h, crop_w]) if self.args.color_aug: img = img * np.random.uniform(1, self.args.color_ratio) if self.args.int_aug: ints = pms_transforms.getIntensity(len(imgs)) img = np.dot(img, np.diag(ints.reshape(-1))) else: ints = np.ones(c) if self.args.noise_aug: img = pms_transforms.randomNoiseAug(img, self.args.noise) mask = pms_transforms.normalToMask(normal) normal = normal * mask.repeat(3, 2) norm = np.sqrt((normal * normal).sum(2, keepdims=True)) normal = normal / (norm + 1e-10) # Rescale normal to unit length item = {'normal': normal, 'img': img, 'mask': mask} for k in item.keys(): item[k] = pms_transforms.arrayToTensor(item[k]) item['dirs'] = torch.from_numpy(dirs).view(-1, 1, 1).float() item['ints'] = torch.from_numpy(ints).view(-1, 1, 1).float() return item def __len__(self): return len(self.shape_list) ================================================ FILE: datasets/__init__.py ================================================ #from .PMS_dataset_v1 import PMS_dataset #from .PMS_dataset_v2 import PMS_data_v2 #from .DiLiGenT import DiLiGenT #__all__ = ('PMS_dataset', 'DiLiGenT', 'PMS_data_v2') ================================================ FILE: datasets/custom_data_loader.py ================================================ import torch.utils.data def customDataloader(args): args.log.printWrite("=> fetching img pairs in %s" % (args.data_dir)) datasets = __import__('datasets.' 
+ args.dataset) dataset_file = getattr(datasets, args.dataset) train_set = getattr(dataset_file, args.dataset)(args, args.data_dir, 'train') val_set = getattr(dataset_file, args.dataset)(args, args.data_dir, 'val') if args.concat_data: args.log.printWrite('****** Using cocnat data ******') args.log.printWrite("=> fetching img pairs in '{}'".format(args.data_dir2)) train_set2 = getattr(dataset_file, args.dataset)(args, args.data_dir2, 'train') val_set2 = getattr(dataset_file, args.dataset)(args, args.data_dir2, 'val') train_set = torch.utils.data.ConcatDataset([train_set, train_set2]) val_set = torch.utils.data.ConcatDataset([val_set, val_set2]) args.log.printWrite('Found Data:\t %d Train and %d Val' % (len(train_set), len(val_set))) args.log.printWrite('\t Train Batch: %d, Val Batch: %d' % (args.batch, args.val_batch)) train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch, num_workers=args.workers, pin_memory=args.cuda, shuffle=True) test_loader = torch.utils.data.DataLoader(val_set , batch_size=args.val_batch, num_workers=args.workers, pin_memory=args.cuda, shuffle=False) return train_loader, test_loader def benchmarkLoader(args): args.log.printWrite("=> fetching img pairs in 'data/%s'" % (args.benchmark)) datasets = __import__('datasets.' 
+ args.benchmark) dataset_file = getattr(datasets, args.benchmark) test_set = getattr(dataset_file, args.benchmark)(args, 'test') args.log.printWrite('Found Benchmark Data: %d samples' % (len(test_set))) args.log.printWrite('\t Test Batch %d' % (args.test_batch)) test_loader = torch.utils.data.DataLoader(test_set, batch_size=args.test_batch, num_workers=args.workers, pin_memory=args.cuda, shuffle=False) return test_loader ================================================ FILE: datasets/pms_transforms.py ================================================ import torch import random import numpy as np from skimage.transform import resize random.seed(0) np.random.seed(0) def arrayToTensor(array): if array is None: return array array = np.transpose(array, (2, 0, 1)) tensor = torch.from_numpy(array) return tensor.float() def normalToMask(normal, thres=1e-2): """ Due to the numerical precision of uint8, [0, 0, 0] will save as [127, 127, 127] in gt normal, When we load the data and rescale normal by N / 255 * 2 - 1, [127, 127, 127] becomes [-0.003927, -0.003927, -0.003927] """ mask = (np.square(normal).sum(2, keepdims=True) > thres).astype(np.float32) return mask def imgSizeToFactorOfK(img, k): if img.shape[0] % k == 0 and img.shape[1] % k == 0: return img pad_h, pad_w = k - img.shape[0] % k, k - img.shape[1] % k img = np.pad(img, ((0, pad_h), (0, pad_w), (0,0)), 'constant', constant_values=((0,0),(0,0),(0,0))) return img def randomCrop(inputs, target, size): h, w, _ = inputs.shape c_h, c_w = size if h == c_h and w == c_w: return inputs, target x1 = random.randint(0, w - c_w) y1 = random.randint(0, h - c_h) inputs = inputs[y1: y1 + c_h, x1: x1 + c_w] target = target[y1: y1 + c_h, x1: x1 + c_w] return inputs, target def centerCrop(inputs, size): h, w, _ = inputs.shape c_h, c_w = size if h != c_h or w != c_w: x1 = int(w / 2 - c_w / 2) y1 = int(h / 2 - c_h / 2) inputs = inputs[y1: y1 + c_h, x1: x1 + c_w] return inputs def rescale(inputs, target, size): in_h, in_w, _ = 
inputs.shape h, w = size if h != in_h or w != in_w: inputs = resize(inputs, size, order=1, mode='reflect') target = resize(target, size, order=1, mode='reflect') return inputs, target def rescaleSingle(inputs, size, order=1): in_h, in_w, _ = inputs.shape h, w = size if h != in_h or w != in_w: inputs = resize(inputs, size, order=order, mode='reflect') return inputs def randomNoiseAug(inputs, noise_level=0.05): noise = np.random.random(inputs.shape) noise = (noise - 0.5) * noise_level inputs += noise return inputs def getIntensity(num): intensity = np.random.random((num, 1)) * 1.8 + 0.2 color = np.ones((1, 3)) # Uniform color intens = (intensity.repeat(3, 1) * color) return intens ================================================ FILE: datasets/util.py ================================================ import numpy as np import re def atoi(text): return int(text) if text.isdigit() else text def natural_keys(text): ''' alist.sort(key=natural_keys) sorts in human order http://nedbatchelder.com/blog/200712/human_sorting.html (See Toothy's implementation in the comments) ''' return [ atoi(c) for c in re.split('(\d+)', text) ] def readList(list_path,ignore_head=False, sort=True): lists = [] with open(list_path) as f: lists = f.read().splitlines() if ignore_head: lists = lists[1:] if sort: lists.sort(key=natural_keys) return lists def light_source_directions(): """ Below matrix is from DiLiGenT. :return: light source direction matrix. 
[light_num, 3] :rtype: np.ndarray """ L = np.array([[-0.06059872, -0.44839055, 0.8917812], [-0.05939919, -0.33739538, 0.93948714], [-0.05710194, -0.21230722, 0.97553319], [-0.05360061, -0.07800089, 0.99551134], [-0.04919816, 0.05869781, 0.99706274], [-0.04399823, 0.19019233, 0.98076044], [-0.03839991, 0.31049925, 0.9497977], [-0.03280081, 0.41611025, 0.90872238], [-0.18449839, -0.43989616, 0.87889232], [-0.18870114, -0.32950199, 0.92510557], [-0.1901994, -0.20549935, 0.95999698], [-0.18849605, -0.07269848, 0.97937948], [-0.18329657, 0.06229884, 0.98108166], [-0.17500445, 0.19220488, 0.96562453], [-0.16449474, 0.31129005, 0.93597008], [-0.15270716, 0.4160195, 0.89644202], [-0.30139786, -0.42509698, 0.85349393], [-0.31020115, -0.31660118, 0.89640333], [-0.31489186, -0.19549495, 0.92877599], [-0.31450962, -0.06640203, 0.94692897], [-0.30880699, 0.06470146, 0.94892147], [-0.2981084, 0.19100538, 0.93522635], [-0.28359251, 0.30729189, 0.90837601], [-0.26670649, 0.41020998, 0.87212122], [-0.40709586, -0.40559588, 0.81839168], [-0.41919869, -0.29999906, 0.85689732], [-0.42618633, -0.18329412, 0.88587159], [-0.42691512, -0.05950211, 0.90233197], [-0.42090385, 0.0659006, 0.90470827], [-0.40860354, 0.18720162, 0.89330773], [-0.39141794, 0.29941372, 0.87013988], [-0.3707838, 0.39958255, 0.83836338], [-0.499596, -0.38319693, 0.77689378], [-0.51360334, -0.28130183, 0.81060526], [-0.52190667, -0.16990217, 0.83591069], [-0.52326874, -0.05249686, 0.85054918], [-0.51720021, 0.06620003, 0.85330035], [-0.50428312, 0.18139393, 0.84427174], [-0.48561334, 0.28870793, 0.82512267], [-0.46289771, 0.38549809, 0.79819605], [-0.57853599, -0.35932235, 0.73224555], [-0.59329349, -0.26189713, 0.76119165], [-0.60202327, -0.15630604, 0.78303027], [-0.6037003, -0.04570002, 0.7959004], [-0.59781529, 0.06590169, 0.79892043], [-0.58486953, 0.17439091, 0.79215873], [-0.56588359, 0.27639198, 0.77677747], [-0.54241965, 0.36921337, 0.75462733], [0.05220076, -0.43870637, 0.89711304], [0.05199786, 
-0.33138635, 0.9420612], [0.05109826, -0.20999284, 0.97636672], [0.04919919, -0.07869871, 0.99568366], [0.04640163, 0.05630197, 0.99733494], [0.04279892, 0.18779527, 0.98127529], [0.03870043, 0.30950341, 0.95011048], [0.03440055, 0.41730662, 0.90811441], [0.17290651, -0.43181626, 0.88523333], [0.17839998, -0.32509996, 0.92869988], [0.18160174, -0.20480196, 0.96180921], [0.18200745, -0.07490306, 0.98044012], [0.17919505, 0.05849838, 0.98207285], [0.17329685, 0.18839658, 0.96668244], [0.1649036, 0.30880674, 0.93672045], [0.1549931, 0.41578148, 0.89616009], [0.28720483, -0.41910705, 0.8613145], [0.29740177, -0.31410186, 0.90160535], [0.30420604, -0.1965039, 0.9321185], [0.30640529, -0.07010121, 0.94931639], [0.30361153, 0.05950226, 0.95093613], [0.29588748, 0.18589214, 0.93696036], [0.28409783, 0.30349768, 0.90949304], [0.26939905, 0.40849857, 0.87209694], [0.39120402, -0.40190413, 0.8279085], [0.40481085, -0.29960803, 0.86392315], [0.41411685, -0.18590756, 0.89103626], [0.41769724, -0.06449957, 0.906294], [0.41498764, 0.05959822, 0.90787296], [0.40607977, 0.18089099, 0.89575537], [0.39179226, 0.29439419, 0.87168279], [0.37379609, 0.39649585, 0.83849122], [0.48278794, -0.38169046, 0.78818031], [0.49848546, -0.28279175, 0.8194761], [0.50918069, -0.1740934, 0.84286803], [0.51360856, -0.05870098, 0.85601427], [0.51097962, 0.05899765, 0.8575658], [0.50151639, 0.17420569, 0.84742769], [0.48600297, 0.28260173, 0.82700506], [0.46600106, 0.38110087, 0.79850181], [0.56150442, -0.35990283, 0.74510586], [0.57807114, -0.26498677, 0.77176147], [0.58933134, -0.1617086, 0.7915421], [0.59407609, -0.05289787, 0.80266769], [0.59157958, 0.057798, 0.80417224], [0.58198189, 0.16649482, 0.79597523], [0.56620006, 0.26940003, 0.77900008], [0.54551481, 0.36380988, 0.7550205]], dtype=float) return L ================================================ FILE: eval/run_stage1.py ================================================ import torch, sys sys.path.append('.') from datasets import 
custom_data_loader from options import run_model_opts from models import custom_model from utils import logger, recorders import test_stage1 as test_utils args = run_model_opts.RunModelOpts().parse() log = logger.Logger(args) def main(args): test_loader = custom_data_loader.benchmarkLoader(args) model = custom_model.buildModel(args) recorder = recorders.Records(args.log_dir) test_utils.test(args, 'test', test_loader, model, log, 1, recorder) log.plotCurves(recorder, 'test') if __name__ == '__main__': torch.manual_seed(args.seed) main(args) ================================================ FILE: eval/run_stage2.py ================================================ import torch, sys sys.path.append('.') from datasets import custom_data_loader from options import run_model_opts from models import custom_model from utils import logger, recorders import test_stage2 as test_utils args = run_model_opts.RunModelOpts().parse() args.stage2 = True args.test_resc = False log = logger.Logger(args) def main(args): test_loader = custom_data_loader.benchmarkLoader(args) model = custom_model.buildModel(args) model_s2 = custom_model.buildModelStage2(args) models = [model, model_s2] recorder = recorders.Records(args.log_dir) test_utils.test(args, 'test', test_loader, models, log, 1, recorder) log.plotCurves(recorder, 'test') if __name__ == '__main__': torch.manual_seed(args.seed) main(args) ================================================ FILE: main_stage1.py ================================================ import torch from options import stage1_opts from utils import logger, recorders from datasets import custom_data_loader from models import custom_model, solver_utils, model_utils import train_stage1 as train_utils import test_stage1 as test_utils args = stage1_opts.TrainOpts().parse() log = logger.Logger(args) def main(args): model = custom_model.buildModel(args) optimizer, scheduler, records = solver_utils.configOptimizer(args, model) criterion = solver_utils.Stage1ClsCrit(args) 
recorder = recorders.Records(args.log_dir, records) train_loader, val_loader = custom_data_loader.customDataloader(args) for epoch in range(args.start_epoch, args.epochs+1): scheduler.step() recorder.insertRecord('train', 'lr', epoch, scheduler.get_lr()[0]) train_utils.train(args, train_loader, model, criterion, optimizer, log, epoch, recorder) if epoch % args.save_intv == 0: model_utils.saveCheckpoint(args.cp_dir, epoch, model, optimizer, recorder.records, args) log.plotCurves(recorder, 'train') if epoch % args.val_intv == 0: test_utils.test(args, 'val', val_loader, model, log, epoch, recorder) log.plotCurves(recorder, 'val') if __name__ == '__main__': torch.manual_seed(args.seed) main(args) ================================================ FILE: main_stage2.py ================================================ import torch from options import stage2_opts from utils import logger, recorders from datasets import custom_data_loader from models import custom_model, solver_utils, model_utils import train_stage2 as train_utils import test_stage2 as test_utils args = stage2_opts.TrainOpts().parse() log = logger.Logger(args) def main(args): model = custom_model.buildModel(args) model_s2 = custom_model.buildModelStage2(args) models = [model, model_s2] optimizer, scheduler, records = solver_utils.configOptimizer(args, model_s2) optimizers = [optimizer, -1] criterion = solver_utils.Stage2Crit(args) recorder = recorders.Records(args.log_dir, records) train_loader, val_loader = custom_data_loader.customDataloader(args) for epoch in range(args.start_epoch, args.epochs+1): scheduler.step() recorder.insertRecord('train', 'lr', epoch, scheduler.get_lr()[0]) train_utils.train(args, train_loader, models, criterion, optimizers, log, epoch, recorder) if epoch % args.save_intv == 0: model_utils.saveCheckpoint(args.cp_dir, epoch, model_s2, optimizer, recorder.records, args) log.plotCurves(recorder, 'train') if epoch % args.val_intv == 0: test_utils.test(args, 'val', val_loader, models, 
log, epoch, recorder) log.plotCurves(recorder, 'val') if __name__ == '__main__': torch.manual_seed(args.seed) main(args) ================================================ FILE: models/LCNet.py ================================================ import torch import torch.nn as nn from torch.nn.init import kaiming_normal_ from . import model_utils from utils import eval_utils # Classification class FeatExtractor(nn.Module): def __init__(self, batchNorm, c_in, c_out=256): super(FeatExtractor, self).__init__() self.conv1 = model_utils.conv(batchNorm, c_in, 64, k=3, stride=2, pad=1) self.conv2 = model_utils.conv(batchNorm, 64, 128, k=3, stride=2, pad=1) self.conv3 = model_utils.conv(batchNorm, 128, 128, k=3, stride=1, pad=1) self.conv4 = model_utils.conv(batchNorm, 128, 128, k=3, stride=2, pad=1) self.conv5 = model_utils.conv(batchNorm, 128, 128, k=3, stride=1, pad=1) self.conv6 = model_utils.conv(batchNorm, 128, 256, k=3, stride=2, pad=1) self.conv7 = model_utils.conv(batchNorm, 256, 256, k=3, stride=1, pad=1) def forward(self, inputs): out = self.conv1(inputs) out = self.conv2(out) out = self.conv3(out) out = self.conv4(out) out = self.conv5(out) out = self.conv6(out) out = self.conv7(out) return out class Classifier(nn.Module): def __init__(self, batchNorm, c_in, other): super(Classifier, self).__init__() self.conv1 = model_utils.conv(batchNorm, 512, 256, k=3, stride=1, pad=1) self.conv2 = model_utils.conv(batchNorm, 256, 256, k=3, stride=2, pad=1) self.conv3 = model_utils.conv(batchNorm, 256, 256, k=3, stride=2, pad=1) self.conv4 = model_utils.conv(batchNorm, 256, 256, k=3, stride=2, pad=1) self.other = other self.dir_x_est = nn.Sequential( model_utils.conv(batchNorm, 256, 64, k=1, stride=1, pad=0), model_utils.outputConv(64, other['dirs_cls'], k=1, stride=1, pad=0)) self.dir_y_est = nn.Sequential( model_utils.conv(batchNorm, 256, 64, k=1, stride=1, pad=0), model_utils.outputConv(64, other['dirs_cls'], k=1, stride=1, pad=0)) self.int_est = nn.Sequential( 
model_utils.conv(batchNorm, 256, 64, k=1, stride=1, pad=0), model_utils.outputConv(64, other['ints_cls'], k=1, stride=1, pad=0)) def forward(self, inputs): out = self.conv1(inputs) out = self.conv2(out) out = self.conv3(out) out = self.conv4(out) outputs = {} if self.other['s1_est_d']: outputs['dir_x'] = self.dir_x_est(out) outputs['dir_y'] = self.dir_y_est(out) if self.other['s1_est_i']: outputs['ints'] = self.int_est(out) return outputs class LCNet(nn.Module): def __init__(self, fuse_type='max', batchNorm=False, c_in=3, other={}): super(LCNet, self).__init__() self.featExtractor = FeatExtractor(batchNorm, c_in, 128) self.classifier = Classifier(batchNorm, 256, other) self.c_in = c_in self.fuse_type = fuse_type self.other = other for m in self.modules(): if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): kaiming_normal_(m.weight.data) if m.bias is not None: m.bias.data.zero_() elif isinstance(m, nn.BatchNorm2d): m.weight.data.fill_(1) m.bias.data.zero_() def prepareInputs(self, x): n, c, h, w = x[0].shape t_h, t_w = self.other['test_h'], self.other['test_w'] if (h == t_h and w == t_w): imgs = x[0] else: print('Rescaling images: from %dX%d to %dX%d' % (h, w, t_h, t_w)) imgs = torch.nn.functional.upsample(x[0], size=(t_h, t_w), mode='bilinear') inputs = list(torch.split(imgs, 3, 1)) idx = 1 if self.other['in_light']: light = torch.split(x[idx], 3, 1) for i in range(len(inputs)): inputs[i] = torch.cat([inputs[i], light[i]], 1) idx += 1 if self.other['in_mask']: mask = x[idx] if mask.shape[2] != inputs[0].shape[2] or mask.shape[3] != inputs[0].shape[3]: mask = torch.nn.functional.upsample(mask, size=(t_h, t_w), mode='bilinear') for i in range(len(inputs)): inputs[i] = torch.cat([inputs[i], mask], 1) idx += 1 return inputs def fuseFeatures(self, feats, fuse_type): if fuse_type == 'mean': feat_fused = torch.stack(feats, 1).mean(1) elif fuse_type == 'max': feat_fused, _ = torch.stack(feats, 1).max(1) return feat_fused def convertMidDirs(self, pred): _, 
x_idx = pred['dirs_x'].data.max(1) _, y_idx = pred['dirs_y'].data.max(1) dirs = eval_utils.SphericalClassToDirs(x_idx, y_idx, self.other['dirs_cls']) return dirs def convertMidIntens(self, pred, img_num): _, idx = pred['ints'].data.max(1) ints = eval_utils.ClassToLightInts(idx, self.other['ints_cls']) ints = ints.view(-1, 1).repeat(1, 3) ints = torch.cat(torch.split(ints, ints.shape[0] // img_num, 0), 1) return ints def forward(self, x): inputs = self.prepareInputs(x) feats = [] for i in range(len(inputs)): out_feat = self.featExtractor(inputs[i]) shape = out_feat.data.shape feats.append(out_feat) feat_fused = self.fuseFeatures(feats, self.fuse_type) l_dirs_x, l_dirs_y, l_ints = [], [], [] for i in range(len(inputs)): net_input = torch.cat([feats[i], feat_fused], 1) outputs = self.classifier(net_input) if self.other['s1_est_d']: l_dirs_x.append(outputs['dir_x']) l_dirs_y.append(outputs['dir_y']) if self.other['s1_est_i']: l_ints.append(outputs['ints']) pred = {} if self.other['s1_est_d']: pred['dirs_x'] = torch.cat(l_dirs_x, 0).squeeze() pred['dirs_y'] = torch.cat(l_dirs_y, 0).squeeze() pred['dirs'] = self.convertMidDirs(pred) if self.other['s1_est_i']: pred['ints'] = torch.cat(l_ints, 0).squeeze() if pred['ints'].ndimension() == 1: pred['ints'] = pred['ints'].view(1, -1) pred['intens'] = self.convertMidIntens(pred, len(inputs)) return pred ================================================ FILE: models/NENet.py ================================================ import torch import torch.nn as nn from torch.nn.init import kaiming_normal_ from . 
import model_utils class FeatExtractor(nn.Module): def __init__(self, batchNorm=False, c_in=3, other={}): super(FeatExtractor, self).__init__() self.other = other self.conv1 = model_utils.conv(batchNorm, c_in, 64, k=3, stride=1, pad=1) self.conv2 = model_utils.conv(batchNorm, 64, 128, k=3, stride=2, pad=1) self.conv3 = model_utils.conv(batchNorm, 128, 128, k=3, stride=1, pad=1) self.conv4 = model_utils.conv(batchNorm, 128, 256, k=3, stride=2, pad=1) self.conv5 = model_utils.conv(batchNorm, 256, 256, k=3, stride=1, pad=1) self.conv6 = model_utils.deconv(256, 128) self.conv7 = model_utils.conv(batchNorm, 128, 128, k=3, stride=1, pad=1) def forward(self, x): out = self.conv1(x) out = self.conv2(out) out = self.conv3(out) out = self.conv4(out) out = self.conv5(out) out = self.conv6(out) out_feat = self.conv7(out) n, c, h, w = out_feat.data.shape out_feat = out_feat.view(-1) return out_feat, [n, c, h, w] class Regressor(nn.Module): def __init__(self, batchNorm=False, other={}): super(Regressor, self).__init__() self.other = other self.deconv1 = model_utils.conv(batchNorm, 128, 128, k=3, stride=1, pad=1) self.deconv2 = model_utils.conv(batchNorm, 128, 128, k=3, stride=1, pad=1) self.deconv3 = model_utils.deconv(128, 64) self.est_normal = self._make_output(64, 3, k=3, stride=1, pad=1) self.other = other def _make_output(self, cin, cout, k=3, stride=1, pad=1): return nn.Sequential( nn.Conv2d(cin, cout, kernel_size=k, stride=stride, padding=pad, bias=False)) def forward(self, x, shape): x = x.view(shape[0], shape[1], shape[2], shape[3]) out = self.deconv1(x) out = self.deconv2(out) out = self.deconv3(out) normal = self.est_normal(out) normal = torch.nn.functional.normalize(normal, 2, 1) return normal class NENet(nn.Module): def __init__(self, fuse_type='max', batchNorm=False, c_in=3, other={}): super(NENet, self).__init__() self.extractor = FeatExtractor(batchNorm, c_in, other) self.regressor = Regressor(batchNorm, other) self.c_in = c_in self.fuse_type = fuse_type 
self.other = other for m in self.modules(): if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): kaiming_normal_(m.weight.data) if m.bias is not None: m.bias.data.zero_() elif isinstance(m, nn.BatchNorm2d): m.weight.data.fill_(1) m.bias.data.zero_() def prepareInputs(self, x): imgs = torch.split(x[0], 3, 1) idx = 1 if self.other['in_light']: idx += 1 if self.other['in_mask']: idx += 1 dirs = torch.split(x[idx]['dirs'], x[0].shape[0], 0) ints = torch.split(x[idx]['intens'], 3, 1) s2_inputs = [] for i in range(len(imgs)): n, c, h, w = imgs[i].shape l_dir = dirs[i] if dirs[i].dim() == 4 else dirs[i].view(n, -1, 1, 1) l_int = torch.diag(1.0 / (ints[i].contiguous().view(-1)+1e-8)) img = imgs[i].contiguous().view(n * c, h * w) img = torch.mm(l_int, img).view(n, c, h, w) img_light = torch.cat([img, l_dir.expand_as(img)], 1) s2_inputs.append(img_light) return s2_inputs def forward(self, x): inputs = self.prepareInputs(x) feats = torch.Tensor() for i in range(len(inputs)): feat, shape = self.extractor(inputs[i]) if i == 0: feats = feat else: if self.fuse_type == 'mean': feats = torch.stack([feats, feat], 1).sum(1) elif self.fuse_type == 'max': feats, _ = torch.stack([feats, feat], 1).max(1) if self.fuse_type == 'mean': feats = feats / len(img_split) feat_fused = feats normal = self.regressor(feat_fused, shape) pred = {} pred['n'] = normal return pred ================================================ FILE: models/__init__.py ================================================ ================================================ FILE: models/custom_model.py ================================================ from . 
import model_utils import torch def buildModel(args): print('Creating Model %s' % (args.model)) in_c = model_utils.getInputChanel(args) other = { 'img_num': args.in_img_num, 'test_h': args.test_h, 'test_w': args.test_w, 'in_mask': args.in_mask, 'in_light': args.in_light, 'dirs_cls': args.dirs_cls, 'ints_cls': args.ints_cls, 's1_est_d': args.s1_est_d, 's1_est_i': args.s1_est_i, 's1_est_n': args.s1_est_n, } models = __import__('models.' + args.model) model_file = getattr(models, args.model) model = getattr(model_file, args.model)(args.fuse_type, args.use_BN, in_c, other) if args.cuda: model = model.cuda() if args.retrain: args.log.printWrite("=> using pre-trained model '{}'".format(args.retrain)) model_utils.loadCheckpoint(args.retrain, model, cuda=args.cuda) if args.resume: args.log.printWrite("=> Resume loading checkpoint '{}'".format(args.resume)) model_utils.loadCheckpoint(args.resume, model, cuda=args.cuda) print(model) args.log.printWrite("=> Model Parameters: %d" % (model_utils.get_n_params(model))) return model def buildModelStage2(args): print('Creating Stage2 Model %s' % (args.model_s2)) in_c = 6 if args.s2_in_light else 3 other = { 'img_num': args.in_img_num, 'in_mask': args.in_mask, 'in_light': args.in_light, 'dirs_cls': args.dirs_cls, 'ints_cls': args.ints_cls, } models = __import__('models.' 
+ args.model_s2) model_file = getattr(models, args.model_s2) model = getattr(model_file, args.model_s2)(args.fuse_type, args.use_BN, in_c, other) if args.cuda: model = model.cuda() if args.retrain_s2: args.log.printWrite("=> using pre-trained model_s2 '{}'".format(args.retrain_s2)) model_utils.loadCheckpoint(args.retrain_s2, model, cuda=args.cuda) print(model) args.log.printWrite("=> Stage2 Model Parameters: %d" % (model_utils.get_n_params(model))) return model ================================================ FILE: models/model_utils.py ================================================ import os import torch import torch.nn as nn def getInput(args, data): input_list = [data['img']] if args.in_light: input_list.append(data['dirs']) if args.in_mask: input_list.append(data['m']) return input_list def parseData(args, sample, timer=None, split='train'): img, normal, mask = sample['img'], sample['normal'], sample['mask'] ints = sample['ints'] if args.in_light: dirs = sample['dirs'].expand_as(img) else: # predict lighting, prepare ground truth n, c, h, w = sample['dirs'].shape dirs_split = torch.split(sample['dirs'].view(n, c), 3, 1) dirs = torch.cat(dirs_split, 0) if timer: timer.updateTime('ToCPU') if args.cuda: img, normal, mask = img.cuda(), normal.cuda(), mask.cuda() dirs, ints = dirs.cuda(), ints.cuda() if timer: timer.updateTime('ToGPU') data = {'img': img, 'n': normal, 'm': mask, 'dirs': dirs, 'ints': ints} return data def getInputChanel(args): args.log.printWrite('[Network Input] Color image as input') c_in = 3 if args.in_light: args.log.printWrite('[Network Input] Adding Light direction as input') c_in += 3 if args.in_mask: args.log.printWrite('[Network Input] Adding Mask as input') c_in += 1 args.log.printWrite('[Network Input] Input channel: {}'.format(c_in)) return c_in def get_n_params(model): pp = 0 for p in list(model.parameters()): nn = 1 for s in list(p.size()): nn = nn * s pp += nn return pp def loadCheckpoint(path, model, cuda=True): if cuda: checkpoint 
= torch.load(path)
    else:
        # Remap GPU-trained weights onto CPU storage.
        checkpoint = torch.load(path, map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpoint['state_dict'])

def saveCheckpoint(save_path, epoch=-1, model=None, optimizer=None, records=None, args=None):
    """Save weights and training records as two separate files.

    Writes 'checkp_<epoch>.pth.tar' (model state + model name) and
    'checkp_<epoch>_rec.pth.tar' (epoch, optimizer state, recorder records).
    """
    state   = {'state_dict': model.state_dict(), 'model': args.model}
    records = {'epoch': epoch, 'optimizer':optimizer.state_dict(), 'records': records} # 'args': args}
    torch.save(state,   os.path.join(save_path, 'checkp_{}.pth.tar'.format(epoch)))
    torch.save(records, os.path.join(save_path, 'checkp_{}_rec.pth.tar'.format(epoch)))

def conv_ReLU(batchNorm, cin, cout, k=3, stride=1, pad=-1):
    """Conv2d + optional BatchNorm + ReLU; pad < 0 selects 'same' padding for odd k."""
    pad = pad if pad >= 0 else (k - 1) // 2
    if batchNorm:
        print('=> convolutional layer with bachnorm')
        # Conv bias is dropped when BatchNorm follows (BN's shift subsumes it).
        return nn.Sequential(
                nn.Conv2d(cin, cout, kernel_size=k, stride=stride, padding=pad, bias=False),
                nn.BatchNorm2d(cout),
                nn.ReLU(inplace=True)
                )
    else:
        return nn.Sequential(
                nn.Conv2d(cin, cout, kernel_size=k, stride=stride, padding=pad, bias=True),
                nn.ReLU(inplace=True)
                )

def conv(batchNorm, cin, cout, k=3, stride=1, pad=-1):
    """Conv2d + optional BatchNorm + LeakyReLU(0.1); the default block in this repo."""
    pad = pad if pad >= 0 else (k - 1) // 2
    if batchNorm:
        print('=> convolutional layer with bachnorm')
        return nn.Sequential(
                nn.Conv2d(cin, cout, kernel_size=k, stride=stride, padding=pad, bias=False),
                nn.BatchNorm2d(cout),
                nn.LeakyReLU(0.1, inplace=True)
                )
    else:
        return nn.Sequential(
                nn.Conv2d(cin, cout, kernel_size=k, stride=stride, padding=pad, bias=True),
                nn.LeakyReLU(0.1, inplace=True)
                )

def outputConv(cin, cout, k=3, stride=1, pad=1):
    """Plain Conv2d with no activation, used as a prediction head."""
    return nn.Sequential(
           nn.Conv2d(cin, cout, kernel_size=k, stride=stride, padding=pad, bias=True))

def deconv(cin, cout):
    """2x upsampling via transposed conv (k=4, s=2, p=1) + LeakyReLU(0.1)."""
    return nn.Sequential(
            nn.ConvTranspose2d(cin, cout, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.1, inplace=True)
            )

def upconv(cin, cout):
    """2x bilinear upsample followed by a 3x3 conv + LeakyReLU(0.1)."""
    return nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear'),
            nn.Conv2d(cin, cout, kernel_size=3, stride=1, padding=1, bias=False),
            nn.LeakyReLU(0.1, inplace=True)
            )
================================================
FILE: models/solver_utils.py ================================================
import torch
import os
from utils import eval_utils

class Stage1ClsCrit(object): # First Stage, Light classification criterion
    """Cross-entropy loss over discretized light directions and intensities."""
    def __init__(self, args):
        print('==> Using Stage1ClsCrit for lighting classification')
        self.s1_est_d = args.s1_est_d
        self.s1_est_i = args.s1_est_i
        self.ints_cls, self.dirs_cls = args.ints_cls, args.dirs_cls
        self.setupLightCrit(args)

    def setupLightCrit(self, args):
        # Separate classifier losses for the x/y direction bins and intensity bins.
        args.log.printWrite('=> Using light criterion')
        if self.s1_est_d:
            self.dir_w = args.dir_w
            self.dirs_x_crit = torch.nn.CrossEntropyLoss()
            self.dirs_y_crit = torch.nn.CrossEntropyLoss()
            if args.cuda:
                self.dirs_x_crit = self.dirs_x_crit.cuda()
                self.dirs_y_crit = self.dirs_y_crit.cuda()
        if self.s1_est_i:
            self.ints_w = args.ints_w
            self.ints_crit = torch.nn.CrossEntropyLoss()
            if args.cuda:
                self.ints_crit = self.ints_crit.cuda()

    def forward(self, output, target):
        """Accumulate weighted losses in self.loss; return per-term scalars for logging."""
        self.loss = 0
        out_loss = {}
        if self.s1_est_d:
            est_dir_x, est_dir_y = output['dirs_x'], output['dirs_y']
            # Convert GT unit directions to per-axis class indices.
            gt_dir_x, gt_dir_y = eval_utils.SphericalDirsToClass(target['dirs'], self.dirs_cls)
            dirs_x_loss = self.dirs_x_crit(est_dir_x, gt_dir_x)
            dirs_y_loss = self.dirs_y_crit(est_dir_y, gt_dir_y)
            out_loss['D_x_loss'] = dirs_x_loss.item()
            out_loss['D_y_loss'] = dirs_y_loss.item()
            self.loss += self.dir_w * (dirs_x_loss + dirs_y_loss)
        if self.s1_est_i:
            est_intens = output['ints']
            # Every third column picks one channel per light (intensities are
            # stored as repeated RGB triplets); splits are stacked along batch.
            gt_ints = target['ints'].squeeze()[:, 0: target['ints'].shape[1]:3]
            gt_ints = torch.cat(torch.split(gt_ints, 1, 1), 0)
            gt_intens = eval_utils.LightIntsToClass(gt_ints, self.ints_cls)
            ints_loss = self.ints_crit(est_intens, gt_intens)
            out_loss['I_loss'] = ints_loss.item()
            self.loss += self.ints_w * ints_loss
        return out_loss

    def backward(self):
        # Backprop the total accumulated in forward().
        self.loss.backward()

class Stage2Crit(object): # Second stage
    """Regression losses for normals and (optionally) refined lighting."""
    def __init__(self, args):
        self.s2_est_n = args.s2_est_n
        self.s2_est_d = args.s2_est_d
        self.s2_est_i = args.s2_est_i
        self.setupLightCrit(args)
        if self.s2_est_n:
            self.setupNormalCrit(args)

    def setupLightCrit(self, args):
        args.log.printWrite('=> Using light criterion')
        if self.s2_est_d:
            self.dir_w = args.dir_w
            self.dirs_crit = torch.nn.CosineEmbeddingLoss()
            if args.cuda:
                self.dirs_crit = self.dirs_crit.cuda()
        if self.s2_est_i:
            self.ints_w = args.ints_w
            self.ints_crit = torch.nn.MSELoss()
            if args.cuda:
                self.ints_crit = self.ints_crit.cuda()

    def setupNormalCrit(self, args):
        args.log.printWrite('=> Using {} for criterion normal'.format(args.normal_loss))
        self.normal_w = args.normal_w
        if args.normal_loss == 'mse':
            self.n_crit = torch.nn.MSELoss()
        elif args.normal_loss == 'cos':
            self.n_crit = torch.nn.CosineEmbeddingLoss()
        else:
            raise Exception("=> Unknown Criterion '{}'".format(args.normal_loss))
        if args.cuda:
            self.n_crit = self.n_crit.cuda()

    def forward(self, output, target):
        """Accumulate weighted losses in self.loss; return per-term scalars for logging."""
        self.loss = 0
        out_loss = {}
        if self.s2_est_d:
            d_est, d_tar = output['dirs'], target['dirs']
            d_num = d_tar.nelement() // d_tar.shape[1]
            # Cache the all-ones flag tensor required by CosineEmbeddingLoss
            # (a target of +1 maximizes cosine similarity); rebuilt only when
            # the effective batch size changes.
            if not hasattr(self, 'l_flag') or d_num != self.l_flag.nelement():
                self.l_flag = d_tar.data.new().resize_(d_num).fill_(1)
            dirs_loss = self.dirs_crit(d_est.squeeze(), d_tar, self.l_flag)
            self.loss += self.dir_w * dirs_loss
            out_loss['D_loss'] = dirs_loss.item()
        if self.s2_est_i:
            i_est, i_tar = output['ints'], target['ints']
            ints_loss = self.ints_crit(i_est, i_tar.squeeze())
            self.loss += self.ints_w * ints_loss
            out_loss['I_loss'] = ints_loss.item()
        if self.s2_est_n:
            n_est, n_tar = output['n'], target['n']
            n_num = n_tar.nelement() // n_tar.shape[1]
            if not hasattr(self, 'n_flag') or n_num != self.n_flag.nelement():
                self.n_flag = n_tar.data.new().resize_(n_num).fill_(1)
            # NCHW -> per-pixel rows of 3 so the cosine loss compares unit normals.
            self.out_reshape = n_est.permute(0, 2, 3, 1).contiguous().view(-1, 3)
            self.gt_reshape  = n_tar.permute(0, 2, 3, 1).contiguous().view(-1, 3)
            normal_loss = self.n_crit(self.out_reshape, self.gt_reshape, self.n_flag)
            self.loss += self.normal_w * normal_loss
            out_loss['N_loss'] = normal_loss.item()
        return out_loss

    def backward(self):
        # Backprop the total accumulated in forward().
        self.loss.backward()

def getOptimizer(args, params):
args.log.printWrite('=> Using %s solver for optimization' % (args.solver)) if args.solver == 'adam': optimizer = torch.optim.Adam(params, args.init_lr, betas=(args.beta_1, args.beta_2)) elif args.solver == 'sgd': optimizer = torch.optim.SGD(params, args.init_lr, momentum=args.momentum) else: raise Exception("=> Unknown Optimizer %s" % (args.solver)) return optimizer def getLrScheduler(args, optimizer): scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.milestones, gamma=args.lr_decay, last_epoch=args.start_epoch-2) return scheduler def loadRecords(path, model, optimizer): records = None if os.path.isfile(path): records = torch.load(path[:-8] + '_rec' + path[-8:]) optimizer.load_state_dict(records['optimizer']) start_epoch = records['epoch'] + 1 records = records['records'] print("=> loaded Records") else: raise Exception("=> no checkpoint found at '{}'".format(path)) return records, start_epoch def configOptimizer(args, model): records = None optimizer = getOptimizer(args, model.parameters()) if args.resume: args.log.printWrite("=> Resume loading checkpoint '{}'".format(args.resume)) records, start_epoch = loadRecords(args.resume, model, optimizer) args.start_epoch = start_epoch scheduler = getLrScheduler(args, optimizer) return optimizer, scheduler, records ================================================ FILE: options/__init__.py ================================================ ================================================ FILE: options/base_opts.py ================================================ import argparse import os import torch class BaseOpts(object): def __init__(self): self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) def initialize(self): #### Trainining Dataset #### self.parser.add_argument('--dataset', default='UPS_Synth_Dataset') self.parser.add_argument('--data_dir', default='data/datasets/PS_Blobby_Dataset') self.parser.add_argument('--data_dir2', 
default='data/datasets/PS_Sculpture_Dataset') self.parser.add_argument('--concat_data', default=True, action='store_false') self.parser.add_argument('--l_suffix', default='_mtrl.txt') #### Training Data and Preprocessing Arguments #### self.parser.add_argument('--rescale', default=True, action='store_false') self.parser.add_argument('--rand_sc', default=True, action='store_false') self.parser.add_argument('--scale_h', default=128, type=int) self.parser.add_argument('--scale_w', default=128, type=int) self.parser.add_argument('--crop', default=True, action='store_false') self.parser.add_argument('--crop_h', default=128, type=int) self.parser.add_argument('--crop_w', default=128, type=int) self.parser.add_argument('--test_h', default=128, type=int) self.parser.add_argument('--test_w', default=128, type=int) self.parser.add_argument('--test_resc', default=True, action='store_false') self.parser.add_argument('--int_aug', default=True, action='store_false') self.parser.add_argument('--noise_aug', default=True, action='store_false') self.parser.add_argument('--noise', default=0.05, type=float) self.parser.add_argument('--color_aug', default=True, action='store_false') self.parser.add_argument('--color_ratio', default=3, type=float) self.parser.add_argument('--normalize', default=False, action='store_true') #### Device Arguments #### self.parser.add_argument('--cuda', default=True, action='store_false') self.parser.add_argument('--multi_gpu', default=False, action='store_true') self.parser.add_argument('--time_sync', default=False, action='store_true') self.parser.add_argument('--workers', default=8, type=int) self.parser.add_argument('--seed', default=0, type=int) #### Stage 1 Model Arguments #### self.parser.add_argument('--dirs_cls', default=36, type=int) self.parser.add_argument('--ints_cls', default=20, type=int) self.parser.add_argument('--dir_int', default=False, action='store_true') self.parser.add_argument('--model', default='LCNet') 
self.parser.add_argument('--fuse_type', default='max') self.parser.add_argument('--in_img_num', default=32, type=int) self.parser.add_argument('--s1_est_n', default=False, action='store_true') self.parser.add_argument('--s1_est_d', default=True, action='store_false') self.parser.add_argument('--s1_est_i', default=True, action='store_false') self.parser.add_argument('--in_light', default=False, action='store_true') self.parser.add_argument('--in_mask', default=True, action='store_false') self.parser.add_argument('--use_BN', default=False, action='store_true') self.parser.add_argument('--resume', default=None) self.parser.add_argument('--retrain', default=None) self.parser.add_argument('--save_intv', default=1, type=int) #### Stage 2 Model Arguments #### self.parser.add_argument('--stage2', default=False, action='store_true') self.parser.add_argument('--model_s2', default='NENet') self.parser.add_argument('--retrain_s2', default=None) self.parser.add_argument('--s2_est_n', default=True, action='store_false') self.parser.add_argument('--s2_est_i', default=False, action='store_true') self.parser.add_argument('--s2_est_d', default=False, action='store_true') self.parser.add_argument('--s2_in_light', default=True, action='store_false') #### Displaying Arguments #### self.parser.add_argument('--train_disp', default=20, type=int) self.parser.add_argument('--train_save', default=200, type=int) self.parser.add_argument('--val_intv', default=1, type=int) self.parser.add_argument('--val_disp', default=1, type=int) self.parser.add_argument('--val_save', default=1, type=int) self.parser.add_argument('--max_train_iter',default=-1, type=int) self.parser.add_argument('--max_val_iter', default=-1, type=int) self.parser.add_argument('--max_test_iter', default=-1, type=int) self.parser.add_argument('--train_save_n', default=4, type=int) self.parser.add_argument('--test_save_n', default=4, type=int) #### Log Arguments #### self.parser.add_argument('--save_root', default='data/logdir/') 
self.parser.add_argument('--item', default='CVPR2019') self.parser.add_argument('--suffix', default=None) self.parser.add_argument('--debug', default=False, action='store_true') self.parser.add_argument('--make_dir', default=True, action='store_false') self.parser.add_argument('--save_split', default=False, action='store_true') def setDefault(self): if self.args.debug: self.args.train_disp = 1 self.args.train_save = 1 self.args.max_train_iter = 4 self.args.max_val_iter = 4 self.args.max_test_iter = 4 self.args.test_intv = 1 def collectInfo(self): self.args.str_keys = [ 'model', 'fuse_type', 'solver' ] self.args.val_keys = [ 'batch', 'scale_h', 'crop_h', 'init_lr', 'normal_w', 'dir_w', 'ints_w', 'in_img_num', 'dirs_cls', 'ints_cls' ] self.args.bool_keys = [ 'use_BN', 'in_light', 'in_mask', 's1_est_n', 's1_est_d', 's1_est_i', 'color_aug', 'int_aug', 'concat_data', 'retrain', 'resume', 'stage2', ] def parse(self): self.args = self.parser.parse_args() return self.args ================================================ FILE: options/run_model_opts.py ================================================ from .base_opts import BaseOpts class RunModelOpts(BaseOpts): def __init__(self): super(RunModelOpts, self).__init__() self.initialize() def initialize(self): BaseOpts.initialize(self) #### Testing Dataset Arguments #### self.parser.add_argument('--run_model', default=True, action='store_false') self.parser.add_argument('--benchmark', default='UPS_DiLiGenT_main') self.parser.add_argument('--bm_dir', default='data/datasets/DiLiGenT/pmsData_crop') self.parser.add_argument('--epochs', default=1, type=int) self.parser.add_argument('--test_batch', default=1, type=int) self.parser.add_argument('--test_disp', default=1, type=int) self.parser.add_argument('--test_save', default=1, type=int) ### For UPS_Custom_Datast.py to test you own dataset self.parser.add_argument('--have_l_dirs', default=False, action='store_true', help='Have light directions?') 
self.parser.add_argument('--have_l_ints', default=False, action='store_true', help='Have light intensities?') self.parser.add_argument('--have_gt_n', default=False, action='store_true', help='Have GT surface normals?') def collectInfo(self): self.args.str_keys = ['model', 'model_s2', 'benchmark', 'fuse_type'] self.args.val_keys = ['in_img_num', 'test_h', 'test_w'] self.args.bool_keys = ['int_aug', 'test_resc'] def setDefault(self): self.collectInfo() def parse(self): BaseOpts.parse(self) self.setDefault() return self.args ================================================ FILE: options/stage1_opts.py ================================================ from .base_opts import BaseOpts class TrainOpts(BaseOpts): def __init__(self): super(TrainOpts, self).__init__() self.initialize() def initialize(self): BaseOpts.initialize(self) #### Training Arguments #### self.parser.add_argument('--solver', default='adam', help='adam|sgd') self.parser.add_argument('--milestones', default=[5, 10, 15, 20, 25], nargs='+', type=int) self.parser.add_argument('--start_epoch', default=1, type=int) self.parser.add_argument('--epochs', default=20, type=int) self.parser.add_argument('--batch', default=32, type=int) self.parser.add_argument('--val_batch', default=8, type=int) self.parser.add_argument('--init_lr', default=0.0005, type=float) self.parser.add_argument('--lr_decay', default=0.5, type=float) self.parser.add_argument('--beta_1', default=0.9, type=float, help='adam') self.parser.add_argument('--beta_2', default=0.999, type=float, help='adam') self.parser.add_argument('--momentum', default=0.9, type=float, help='sgd') self.parser.add_argument('--w_decay', default=4e-4, type=float) #### Loss Arguments #### self.parser.add_argument('--normal_loss', default='cos', help='cos|mse') self.parser.add_argument('--normal_w', default=1, type=float) self.parser.add_argument('--dir_loss', default='cos', help='cos|mse') self.parser.add_argument('--dir_w', default=1, type=float) 
self.parser.add_argument('--ints_loss', default='mse', help='l1|mse') self.parser.add_argument('--ints_w', default=1, type=float) def collectInfo(self): BaseOpts.collectInfo(self) self.args.str_keys += [ 'dir_loss', ] self.args.val_keys += [ ] self.args.bool_keys += [ ] def setDefault(self): BaseOpts.setDefault(self) if self.args.test_h != self.args.crop_h: self.args.test_h, self.args.test_w = self.args.crop_h, self.args.crop_w self.collectInfo() def parse(self): BaseOpts.parse(self) self.setDefault() return self.args ================================================ FILE: options/stage2_opts.py ================================================ from .base_opts import BaseOpts class TrainOpts(BaseOpts): def __init__(self): super(TrainOpts, self).__init__() self.initialize() def initialize(self): BaseOpts.initialize(self) #### Training Arguments #### self.parser.add_argument('--solver', default='adam', help='adam|sgd') self.parser.add_argument('--milestones', default=[2, 4, 6, 8, 10], nargs='+', type=int) self.parser.add_argument('--start_epoch', default=1, type=int) self.parser.add_argument('--epochs', default=10, type=int) self.parser.add_argument('--batch', default=16, type=int) self.parser.add_argument('--val_batch', default=8, type=int) self.parser.add_argument('--init_lr', default=0.0005, type=float) self.parser.add_argument('--lr_decay', default=0.5, type=float) self.parser.add_argument('--beta_1', default=0.9, type=float, help='adam') self.parser.add_argument('--beta_2', default=0.999, type=float, help='adam') self.parser.add_argument('--momentum', default=0.9, type=float, help='sgd') self.parser.add_argument('--w_decay', default=4e-4, type=float) #### Loss Arguments #### self.parser.add_argument('--normal_loss', default='cos', help='cos|mse') self.parser.add_argument('--normal_w', default=1, type=float) self.parser.add_argument('--dir_loss', default='mse', help='cos|mse') self.parser.add_argument('--dir_w', default=1, type=float) 
self.parser.add_argument('--ints_loss', default='mse', help='mse') self.parser.add_argument('--ints_w', default=1, type=float) def collectInfo(self): BaseOpts.collectInfo(self) self.args.str_keys += [ 'model_s2' ] self.args.val_keys += [ ] self.args.bool_keys += [ ] def setDefault(self): BaseOpts.setDefault(self) self.args.stage2 = True self.args.test_resc = False self.collectInfo() def parse(self): BaseOpts.parse(self) self.setDefault() return self.args ================================================ FILE: scripts/DiLiGenT_objects.txt ================================================ ballPNG catPNG pot1PNG bearPNG pot2PNG buddhaPNG gobletPNG readingPNG cowPNG harvestPNG ================================================ FILE: scripts/cropDiLiGenTData.py ================================================ import os, argparse, sys, shutil, glob import numpy as np from imageio import imread, imsave import scipy.io as sio root_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '../') sys.path.append(root_path) from utils import utils parser = argparse.ArgumentParser() parser.add_argument('--input_dir', default='data/datasets/DiLiGenT/pmsData') parser.add_argument('--obj_list', default='objects.txt') parser.add_argument('--suffix', default='crop') parser.add_argument('--file_ext', default='.png') parser.add_argument('--normal_name',default='Normal_gt.png') parser.add_argument('--mask_name', default='mask.png') parser.add_argument('--n_key', default='Normal_gt') parser.add_argument('--pad', default=15, type=int) args = parser.parse_args() def getSaveDir(): dirName = os.path.dirname(args.input_dir) save_dir = '%s_%s' % (args.input_dir, args.suffix) utils.makeFile(save_dir) print('Output dir: %s\n' % save_dir) return save_dir def getBBoxCompact(mask): index = np.where(mask != 0) t, b, l , r = index[0].min(), index[0].max(), index[1].min(), index[1].max() h, w = b - t + 2 * args.pad, r - l + 2 * args.pad t = max(0, t - args.pad) b = t + h l = max(0, l - args.pad) r 
= l + w if h % 4 != 0: pad = 4 - h % 4 b += pad; h += pad if w % 4 != 0: pad = 4 - w % 4 r += pad; w += pad return l, r, t, b, h, w def loadMaskNormal(d): mask = imread(os.path.join(args.input_dir, d, args.mask_name)) try: normal = imread(os.path.join(args.input_dir, d, args.normal_name)) except IOError: normal = imread(os.path.join(args.input_dir, d, 'normal_gt.png')) n_mat = sio.loadmat(os.path.join(args.input_dir, d, 'Normal_gt.mat'))[args.n_key] h, w, c = normal.shape print('Processing Objects: %s' % d, mask.shape) if mask.ndim < 3: mask = mask.reshape(h, w, 1).repeat(3, 2) return mask, normal, n_mat def copyTXT(d): txt = glob.glob(os.path.join(args.input_dir, '*.txt')) for t in txt: name = os.path.basename(t) shutil.copy(t, os.path.join(args.save_dir, name)) txt = glob.glob(os.path.join(args.input_dir, d, '*.txt')) for t in txt: name = os.path.basename(t) shutil.copy(t, os.path.join(args.save_dir, d, name)) if __name__ == '__main__': print('Input dir: %s\n' % args.input_dir) args.save_dir = getSaveDir() dir_list = utils.readList(os.path.join(args.input_dir, args.obj_list)) name_list = utils.readList(os.path.join(args.input_dir, 'names.txt')) max_h, max_w = 0, 0 crop_list = open(os.path.join(args.save_dir, 'crop.txt'), 'w') for d in dir_list: utils.makeFile(os.path.join(args.save_dir, d)) mask, normal, n_mat = loadMaskNormal(d) l, r, t, b, h, w = getBBoxCompact(mask[:,:,0] / 255) crop_list.write('%d %d %d %d %d %d\n' % (mask.shape[0], mask.shape[1], l, r, t, b)) max_h = h if h > max_h else max_h max_w = w if w > max_w else max_w print('\t BBox L %d R %d T %d B %d, H:%d W:%d, Padded: %d %d' % (l, r, t, b, h, w, r - l, b - t)) imsave(os.path.join(args.save_dir, d, args.mask_name), mask[t:b, l:r, :]) imsave(os.path.join(args.save_dir, d, args.normal_name), normal[t:b, l:r, :]) sio.savemat(os.path.join(args.save_dir, d, 'Normal_gt.mat'), {args.n_key: n_mat[t:b, l:r, :]} ,do_compression=True) copyTXT(d) intens = np.genfromtxt(os.path.join(args.input_dir, d, 
'light_intensities.txt')) for idx, name in enumerate(name_list): #print('Process img %d/%d' % (idx+1, len(name_list))) img = imread(os.path.join(args.input_dir, d, name)) img = img[t:b, l:r, :] imsave(os.path.join(args.save_dir, d, name), img.astype(np.uint8)) print('Max H %d, Max %d' % (max_h, max_w)) crop_list.close() ================================================ FILE: scripts/download_pretrained_models.sh ================================================ path="data/models/" mkdir -p $path cd $path # Download pre-trained model for model in "LCNet_CVPR2019.pth.tar" "NENet_CVPR2019.pth.tar"; do wget http://www.visionlab.cs.hku.hk/data/SDPS-Net/models/${model} done # Back to root directory cd ../ ================================================ FILE: scripts/download_synthetic_datasets.sh ================================================ mkdir -p data/datasets cd data/datasets # Download Synthetic dataset for dataset in "PS_Sculpture_Dataset.tgz" "PS_Blobby_Dataset.tgz"; do echo "Downloading $dataset" wget http://www.visionlab.cs.hku.hk/data/PS-FCN/datasets/$dataset tar -xvf $dataset rm $dataset done # Back to root directory cd ../../ ================================================ FILE: scripts/prepare_diligent_dataset.sh ================================================ mkdir -p data/datasets cd data/datasets ## Download real testing dataset url="https://www.dropbox.com/s/hdnbh526tyvv68i/DiLiGenT.zip?dl=0" name="DiLiGenT" wget $url -O ${name}.zip unzip ${name}.zip rm ${name}.zip cd ${name}/pmsData/ cp ballPNG/filenames.txt names.txt # Back to root directory cd ../../../../ cp scripts/DiLiGenT_objects.txt data/datasets/${name}/pmsData/objects.txt python scripts/cropDiLiGenTData.py ================================================ FILE: test_stage1.py ================================================ import os import torch from models import model_utils from utils import eval_utils, time_utils import numpy as np def get_itervals(args, split): if split not in 
['train', 'val', 'test']: split = 'test' args_var = vars(args) disp_intv = args_var[split+'_disp'] save_intv = args_var[split+'_save'] stop_iters = args_var['max_'+split+'_iter'] return disp_intv, save_intv, stop_iters def test(args, split, loader, model, log, epoch, recorder): model.eval() log.printWrite('---- Start %s Epoch %d: %d batches ----' % (split, epoch, len(loader))) timer = time_utils.Timer(args.time_sync); disp_intv, save_intv, stop_iters = get_itervals(args, split) res = [] with torch.no_grad(): for i, sample in enumerate(loader): data = model_utils.parseData(args, sample, timer, split) input = model_utils.getInput(args, data) pred = model(input); timer.updateTime('Forward') recoder, iter_res, error = prepareRes(args, data, pred, recorder, log, split) res.append(iter_res) iters = i + 1 if iters % disp_intv == 0: opt = {'split':split, 'epoch':epoch, 'iters':iters, 'batch':len(loader), 'timer':timer, 'recorder': recorder} log.printItersSummary(opt) if iters % save_intv == 0: results, nrow = prepareSave(args, data, pred) log.saveImgResults(results, split, epoch, iters, nrow=nrow, error=error) log.plotCurves(recorder, split, epoch=epoch, intv=disp_intv) if stop_iters > 0 and iters >= stop_iters: break res = np.vstack([np.array(res), np.array(res).mean(0)]) save_name = '%s_res.txt' % (args.suffix) np.savetxt(os.path.join(args.log_dir, split, save_name), res, fmt='%.2f') opt = {'split': split, 'epoch': epoch, 'recorder': recorder} log.printEpochSummary(opt) def prepareRes(args, data, pred, recorder, log, split): data_batch = args.val_batch if split == 'val' else args.test_batch iter_res = [] error = '' if args.s1_est_d: l_acc, data['dir_err'] = eval_utils.calDirsAcc(data['dirs'].data, pred['dirs'].data, data_batch) recorder.updateIter(split, l_acc.keys(), l_acc.values()) iter_res.append(l_acc['l_err_mean']) error += 'D_%.3f-' % (l_acc['l_err_mean']) if args.s1_est_i: int_acc, data['int_err'] = eval_utils.calIntsAcc(data['ints'].data, pred['intens'].data, 
data_batch) recorder.updateIter(split, int_acc.keys(), int_acc.values()) iter_res.append(int_acc['ints_ratio']) error += 'I_%.3f-' % (int_acc['ints_ratio']) if args.s1_est_n: acc, error_map = eval_utils.calNormalAcc(data['n'].data, pred['n'].data, data['m'].data) recorder.updateIter(split, acc.keys(), acc.values()) iter_res.append(acc['n_err_mean']) error += 'N_%.3f-' % (acc['n_err_mean']) data['error_map'] = error_map['angular_map'] return recorder, iter_res, error def prepareSave(args, data, pred): results = [data['img'].data, data['m'].data, (data['n'].data+1)/2] if args.s1_est_n: pred_n = (pred['n'].data + 1) / 2 masked_pred = pred_n * data['m'].data.expand_as(pred['n'].data) res_n = [pred_n, masked_pred, data['error_map']] results += res_n nrow = data['img'].shape[0] return results, nrow ================================================ FILE: test_stage2.py ================================================ import os import torch from models import model_utils from utils import eval_utils, time_utils import numpy as np def get_itervals(args, split): if split not in ['train', 'val', 'test']: split = 'test' args_var = vars(args) disp_intv = args_var[split+'_disp'] save_intv = args_var[split+'_save'] stop_iters = args_var['max_'+split+'_iter'] return disp_intv, save_intv, stop_iters def test(args, split, loader, models, log, epoch, recorder): models[0].eval() models[1].eval() log.printWrite('---- Start %s Epoch %d: %d batches ----' % (split, epoch, len(loader))) timer = time_utils.Timer(args.time_sync); disp_intv, save_intv, stop_iters = get_itervals(args, split) res = [] with torch.no_grad(): for i, sample in enumerate(loader): data = model_utils.parseData(args, sample, timer, split) input = model_utils.getInput(args, data) pred_c = models[0](input); timer.updateTime('Forward') input.append(pred_c) pred = models[1](input); timer.updateTime('Forward') recoder, iter_res, error = prepareRes(args, data, pred_c, pred, recorder, log, split) res.append(iter_res) iters = i 
+ 1 if iters % disp_intv == 0: opt = {'split':split, 'epoch':epoch, 'iters':iters, 'batch':len(loader), 'timer':timer, 'recorder': recorder} log.printItersSummary(opt) if iters % save_intv == 0: results, nrow = prepareSave(args, data, pred_c, pred) log.saveImgResults(results, split, epoch, iters, nrow=nrow, error='') log.plotCurves(recorder, split, epoch=epoch, intv=disp_intv) if stop_iters > 0 and iters >= stop_iters: break res = np.vstack([np.array(res), np.array(res).mean(0)]) save_name = '%s_res.txt' % (args.suffix) np.savetxt(os.path.join(args.log_dir, split, save_name), res, fmt='%.2f') if res.ndim > 1: for i in range(res.shape[1]): save_name = '%s_%d_res.txt' % (args.suffix, i) np.savetxt(os.path.join(args.log_dir, split, save_name), res[:,i], fmt='%.3f') opt = {'split': split, 'epoch': epoch, 'recorder': recorder} log.printEpochSummary(opt) def prepareRes(args, data, pred_c, pred, recorder, log, split): data_batch = args.val_batch if split == 'val' else args.test_batch iter_res = [] error = '' if args.s1_est_d: l_acc, data['dir_err'] = eval_utils.calDirsAcc(data['dirs'].data, pred_c['dirs'].data, data_batch) recorder.updateIter(split, l_acc.keys(), l_acc.values()) iter_res.append(l_acc['l_err_mean']) error += 'D_%.3f-' % (l_acc['l_err_mean']) if args.s1_est_i: int_acc, data['int_err'] = eval_utils.calIntsAcc(data['ints'].data, pred_c['intens'].data, data_batch) recorder.updateIter(split, int_acc.keys(), int_acc.values()) iter_res.append(int_acc['ints_ratio']) error += 'I_%.3f-' % (int_acc['ints_ratio']) if args.s2_est_n: acc, error_map = eval_utils.calNormalAcc(data['n'].data, pred['n'].data, data['m'].data) recorder.updateIter(split, acc.keys(), acc.values()) iter_res.append(acc['n_err_mean']) error += 'N_%.3f-' % (acc['n_err_mean']) data['error_map'] = error_map['angular_map'] return recorder, iter_res, error def prepareSave(args, data, pred_c, pred): results = [data['img'].data, data['m'].data, (data['n'].data+1) / 2] if args.s2_est_n: pred_n = 
(pred['n'].data + 1) / 2 masked_pred = pred_n * data['m'].data.expand_as(pred['n'].data) res_n = [masked_pred, data['error_map']] results += res_n nrow = data['img'].shape[0] return results, nrow ================================================ FILE: train_stage1.py ================================================ from models import model_utils from utils import eval_utils, time_utils def train(args, loader, model, criterion, optimizer, log, epoch, recorder): model.train() log.printWrite('---- Start Training Epoch %d: %d batches ----' % (epoch, len(loader))) timer = time_utils.Timer(args.time_sync); for i, sample in enumerate(loader): data = model_utils.parseData(args, sample, timer, 'train') input = model_utils.getInput(args, data) pred = model(input); timer.updateTime('Forward') optimizer.zero_grad() loss = criterion.forward(pred, data); timer.updateTime('Crit'); criterion.backward(); timer.updateTime('Backward') recorder.updateIter('train', loss.keys(), loss.values()) optimizer.step(); timer.updateTime('Solver') iters = i + 1 if iters % args.train_disp == 0: opt = {'split':'train', 'epoch':epoch, 'iters':iters, 'batch':len(loader), 'timer':timer, 'recorder': recorder} log.printItersSummary(opt) if iters % args.train_save == 0: results, recorder, nrow = prepareSave(args, data, pred, recorder, log) log.saveImgResults(results, 'train', epoch, iters, nrow=nrow) log.plotCurves(recorder, 'train', epoch=epoch, intv=args.train_disp) if args.max_train_iter > 0 and iters >= args.max_train_iter: break opt = {'split': 'train', 'epoch': epoch, 'recorder': recorder} log.printEpochSummary(opt) def prepareSave(args, data, pred, recorder, log): results = [data['img'].data, data['m'].data, (data['n'].data+1)/2] if args.s1_est_d: l_acc, data['dir_err'] = eval_utils.calDirsAcc(data['dirs'].data, pred['dirs'].data, args.batch) recorder.updateIter('train', l_acc.keys(), l_acc.values()) if args.s1_est_i: int_acc, data['int_err'] = eval_utils.calIntsAcc(data['ints'].data, 
pred['intens'].data, args.batch) recorder.updateIter('train', int_acc.keys(), int_acc.values()) if args.s1_est_n: acc, error_map = eval_utils.calNormalAcc(data['n'].data, pred['n'].data, data['m'].data) pred_n = (pred['n'].data + 1) / 2 masked_pred = pred_n * data['m'].data.expand_as(pred['n'].data) res_n = [pred_n, masked_pred, error_map['angular_map']] results += res_n recorder.updateIter('train', acc.keys(), acc.values()) nrow = data['img'].shape[0] if data['img'].shape[0] <= 32 else 32 return results, recorder, nrow ================================================ FILE: train_stage2.py ================================================ import torch from models import model_utils from utils import eval_utils, time_utils def train(args, loader, models, criterion, optimizers, log, epoch, recorder): models[1].train() models[0].eval() optimizer, optimizer_c = optimizers log.printWrite('---- Start Training Epoch %d: %d batches ----' % (epoch, len(loader))) timer = time_utils.Timer(args.time_sync); for i, sample in enumerate(loader): data = model_utils.parseData(args, sample, timer, 'train') input = model_utils.getInput(args, data) with torch.no_grad(): pred_c = models[0](input); input.append(pred_c) pred = models[1](input); timer.updateTime('Forward') optimizer.zero_grad() loss = criterion.forward(pred, data); timer.updateTime('Crit'); criterion.backward(); timer.updateTime('Backward') recorder.updateIter('train', loss.keys(), loss.values()) optimizer.step(); timer.updateTime('Solver') iters = i + 1 if iters % args.train_disp == 0: opt = {'split':'train', 'epoch':epoch, 'iters':iters, 'batch':len(loader), 'timer':timer, 'recorder': recorder} log.printItersSummary(opt) if iters % args.train_save == 0: results, recorder, nrow = prepareSave(args, data, pred_c, pred, recorder, log) log.saveImgResults(results, 'train', epoch, iters, nrow=nrow) log.plotCurves(recorder, 'train', epoch=epoch, intv=args.train_disp) if args.max_train_iter > 0 and iters >= args.max_train_iter: 
break opt = {'split': 'train', 'epoch': epoch, 'recorder': recorder} log.printEpochSummary(opt) def prepareSave(args, data, pred_c, pred, recorder, log): input_var, mask_var = data['img'], data['m'] results = [input_var.data, mask_var.data, (data['n'].data+1)/2] if args.s1_est_d: l_acc, data['dir_err'] = eval_utils.calDirsAcc(data['dirs'].data, pred_c['dirs'].data, args.batch) recorder.updateIter('train', l_acc.keys(), l_acc.values()) if args.s1_est_i: int_acc, data['int_err'] = eval_utils.calIntsAcc(data['ints'].data, pred_c['intens'].data, args.batch) recorder.updateIter('train', int_acc.keys(), int_acc.values()) if args.s2_est_n: acc, error_map = eval_utils.calNormalAcc(data['n'].data, pred['n'].data, mask_var.data) pred_n = (pred['n'].data + 1) / 2 masked_pred = pred_n * mask_var.data.expand_as(pred['n'].data) res_n = [masked_pred, error_map['angular_map']] results += res_n recorder.updateIter('train', acc.keys(), acc.values()) nrow = input_var.shape[0] if input_var.shape[0] <= 32 else 32 return results, recorder, nrow ================================================ FILE: utils/__init__.py ================================================ ================================================ FILE: utils/eval_utils.py ================================================ import torch import math import numpy as np from matplotlib import cm def colorMap(diff): thres = 90 diff_norm = np.clip(diff, 0, thres) / thres diff_cm = torch.from_numpy(cm.jet(diff_norm.numpy()))[:,:,:, :3] return diff_cm.permute(0,3,1,2).clone().float() def calDirsAcc(gt_l, pred_l, data_batch=1): n, c = gt_l.shape pred_l = pred_l.view(n, c) dot_product = (gt_l * pred_l).sum(1).clamp(-1, 1) angular_err = torch.acos(dot_product) * 180.0 / math.pi l_err_mean = angular_err.mean() return {'l_err_mean': l_err_mean.item()}, angular_err.squeeze() def calIntsAcc(gt_i, pred_i, data_batch=1): n, c, h, w = gt_i.shape pred_i = pred_i.view(n, c, h, w) ref_int = gt_i[:, :3].repeat(1, gt_i.shape[1] // 3, 1, 1) gt_i = 
gt_i / ref_int scale = torch.gels(gt_i.view(-1, 1), pred_i.view(-1, 1)) ints_ratio = (gt_i - scale[0][0] * pred_i).abs() / (gt_i + 1e-8) ints_error = torch.stack(ints_ratio.split(3, 1), 1).mean(2) return {'ints_ratio': ints_ratio.mean().item()}, ints_error.squeeze() def calNormalAcc(gt_n, pred_n, mask=None): """Tensor Dim: NxCxHxW""" dot_product = (gt_n * pred_n).sum(1).clamp(-1,1) error_map = torch.acos(dot_product) # [-pi, pi] angular_map = error_map * 180.0 / math.pi angular_map = angular_map * mask.narrow(1, 0, 1).squeeze(1) valid = mask.narrow(1, 0, 1).sum() ang_valid = angular_map[mask.narrow(1, 0, 1).squeeze(1).byte()] n_err_mean = ang_valid.sum() / valid n_err_med = ang_valid.median() n_acc_11 = (ang_valid < 11.25).sum().float() / valid n_acc_30 = (ang_valid < 30).sum().float() / valid n_acc_45 = (ang_valid < 45).sum().float() / valid angular_map = colorMap(angular_map.cpu().squeeze(1)) value = {'n_err_mean': n_err_mean.item(), 'n_acc_11': n_acc_11.item(), 'n_acc_30': n_acc_30.item(), 'n_acc_45': n_acc_45.item()} angular_error_map = {'angular_map': angular_map} return value, angular_error_map def SphericalDirsToClass(dirs, cls_num): theta = torch.atan(dirs[:,0] / (dirs[:,2] + 1e-8)) denom = torch.sqrt(dirs[:,0] * dirs[:,0] + dirs[:,2] * dirs[:,2]) phi = torch.atan(dirs[:,1] / (denom + 1e-8)) theta = theta / np.pi * 180 phi = phi / np.pi * 180 azimuth = ((theta + 90.0) / 180 * cls_num).clamp(0, cls_num-1).long() elevate = ((phi + 90.0) / 180 * cls_num).clamp(0, cls_num-1).long() return azimuth, elevate def SphericalClassToDirs(x_cls, y_cls, cls_num): theta = (x_cls.float() + 0.5) / cls_num * 180 - 90 phi = (y_cls.float() + 0.5) / cls_num * 180 - 90 neg_x = theta < 0 neg_y = phi < 0 theta = theta.clamp(-90, 90) / 180.0 * np.pi phi = phi.clamp(-90, 90) / 180.0 * np.pi tan2_phi = pow(torch.tan(phi), 2) tan2_theta = pow(torch.tan(theta), 2) y = torch.sqrt(tan2_phi / (1 + tan2_phi)) y[neg_y] = y[neg_y] * -1 #y = torch.sin(phi) z = torch.sqrt((1 - y * y) / (1 + 
tan2_theta)) x = z * torch.tan(theta) dirs = torch.stack([x,y,z], 1) dirs = dirs / dirs.norm(p=2, dim=1, keepdim=True) return dirs def LightIntsToClass(ints, cls_num): ints = (ints - 0.2) / 1.8 ints = (ints * cls_num).clamp(0, cls_num-1).long() return ints.view(-1) def ClassToLightInts(cls, cls_num): ints = (cls.float() + 0.5) / cls_num * 1.8 + 0.2 ints = ints.clamp(0.2 , 2.0) return ints ================================================ FILE: utils/logger.py ================================================ import datetime, time, os import numpy as np import torch import torchvision.utils as vutils import scipy.io as sio from . import utils import matplotlib; matplotlib.use('agg') import matplotlib.pyplot as plt from matplotlib import cm from matplotlib.font_manager import FontProperties fontP = FontProperties() fontP.set_size('small') plt.rcParams["figure.figsize"] = (5,8) class Logger(object): def __init__(self, args): self.times = {'init': time.time()} if args.make_dir: self._setupDirs(args) self.args = args args.log = self self.printArgs() def printArgs(self): strs = '------------ Options -------------\n' strs += '{}'.format(utils.dictToString(vars(self.args))) strs += '-------------- End ----------------\n' self.printWrite(strs) def _addArguments(self, args): info = '' if hasattr(args, 'run_model') and args.run_model: info += '_run_model,%s' % os.path.basename(args.retrain).split('.')[0] arg_var = vars(args) for k in args.str_keys: info = '{0},{1}'.format(info, arg_var[k]) for k in args.val_keys: var_key = k[:2] + '_' + k[-1] info = '{0},{1}-{2}'.format(info, var_key, arg_var[k]) for k in args.bool_keys: info = '{0},{1}'.format(info, k) if arg_var[k] else info return info def _setupDirs(self, args): date_now = datetime.datetime.now() dir_name = '%d-%d' % (date_now.month, date_now.day) dir_name += (',%s' % args.suffix) if args.suffix else '' dir_name += self._addArguments(args) dir_name += ',DEBUG' if args.debug else '' self._checkPath(args, dir_name) file_dir = 
os.path.join(args.log_dir, '%s,%s' % (dir_name, date_now.strftime('%H:%M:%S'))) self.log_fie = open(file_dir, 'w') return def _checkPath(self, args, dir_name): if hasattr(args, 'run_model') and args.run_model: log_root = os.path.join(os.path.dirname(args.retrain), dir_name) args.log_dir = log_root sub_dirs = ['test'] else: if args.resume and os.path.isfile(args.resume): log_root = os.path.join(os.path.dirname(os.path.dirname(args.resume)), dir_name) else: if args.debug: dir_name = 'DEBUG/' + dir_name log_root = os.path.join(args.save_root, args.dataset, args.item, dir_name) args.log_dir = os.path.join(log_root, 'logdir') args.cp_dir = os.path.join(log_root, 'checkpointdir') utils.makeFiles([args.log_dir, args.cp_dir]) sub_dirs = ['train', 'val'] for sub_dir in sub_dirs: utils.makeFiles([os.path.join(args.log_dir, sub_dir, 'Images')]) def printWrite(self, strs): print('%s' % strs) if self.args.make_dir: self.log_fie.write('%s\n' % strs) self.log_fie.flush() def getTimeInfo(self, epoch, iters, batch): time_elapsed = (time.time() - self.times['init']) / 3600.0 total_iters = (self.args.epochs - self.args.start_epoch + 1) * batch cur_iters = (epoch - self.args.start_epoch) * batch + iters time_total = time_elapsed * (float(total_iters) / cur_iters) return time_elapsed, time_total def printItersSummary(self, opt): epoch, iters, batch = opt['epoch'], opt['iters'], opt['batch'] strs = ' | {}'.format(str.upper(opt['split'])) strs += ' Iter [{}/{}] Epoch [{}/{}]'.format(iters, batch, epoch, self.args.epochs) if opt['split'] == 'train': time_elapsed, time_total = self.getTimeInfo(epoch, iters, batch) # Buggy for test strs += ' Clock [{:.2f}h/{:.2f}h]'.format(time_elapsed, time_total) strs += ' LR [{}]'.format(opt['recorder'].records[opt['split']]['lr'][epoch][0]) self.printWrite(strs) if 'timer' in opt.keys(): self.printWrite(opt['timer'].timeToString()) if 'recorder' in opt.keys(): self.printWrite(opt['recorder'].iterRecToString(opt['split'], epoch)) def 
printEpochSummary(self, opt): split = opt['split'] epoch = opt['epoch'] self.printWrite('---------- {} Epoch {} Summary -----------'.format(str.upper(split), epoch)) self.printWrite(opt['recorder'].epochRecToString(split, epoch)) def convertToSameSize(self, t_list): shape = (t_list[0].shape[0], 3, t_list[0].shape[2], t_list[0].shape[3]) for i, tensor in enumerate(t_list): n, c, h, w = tensor.shape if tensor.shape[1] != shape[1]: # check channel t_list[i] = tensor.expand((n, shape[1], h, w)) if h == shape[2] and w == shape[3]: continue t_list[i] = torch.nn.functional.upsample(tensor, [shape[2], shape[3]], mode='bilinear') return t_list def getSaveDir(self, split, epoch): save_dir = os.path.join(self.args.log_dir, split, 'Images') run_model = hasattr(self.args, 'run_model') and self.args.run_model if not run_model and epoch > 0: save_dir = os.path.join(save_dir, str(epoch)) utils.makeFile(save_dir) return save_dir def splitMulitChannel(self, t_list, max_save_n = 8): new_list = [] for tensor in t_list: if tensor.shape[1] > 3: num = 3 new_list += torch.split(tensor, num, 1)[:max_save_n] else: new_list.append(tensor) return new_list def saveSplit(self, res, save_prefix): n, c, h, w = res.shape for i in range(n): vutils.save_image(res[i], save_prefix + '_%d.png' % (i)) def saveImgResults(self, results, split, epoch, iters, nrow, error=''): max_save_n = self.args.test_save_n if split == 'test' else self.args.train_save_n res = [img.cpu() for img in results] res = self.splitMulitChannel(res, max_save_n) res = torch.cat(self.convertToSameSize(res)) save_dir = self.getSaveDir(split, epoch) save_prefix = os.path.join(save_dir, '%d_%d' % (epoch, iters)) save_prefix += ('_%s' % error) if error != '' else '' if self.args.save_split: self.saveSplit(res, save_prefix) else: vutils.save_image(res, save_prefix + '_out.png', nrow=nrow) def plotCurves(self, recorder, split='train', epoch=-1, intv=1): dict_of_array = recorder.recordToDictOfArray(split, epoch, intv) save_dir = 
os.path.join(self.args.log_dir, split) if epoch < 0: save_dir = self.args.log_dir save_name = '%s_Summary.png' % (split) else: save_name = '%s_epoch_%d.png' % (split, epoch) classes = ['loss', 'acc', 'err', 'lr', 'ratio'] classes = utils.checkIfInList(classes, dict_of_array.keys()) if len(classes) == 0: return for idx, c in enumerate(classes): plt.subplot(len(classes), 1, idx+1) plt.grid() legends = [] for k in dict_of_array.keys(): if (c in k.lower()) and not k.endswith('_x'): plt.plot(dict_of_array[k+'_x'], dict_of_array[k]) legends.append(k) if len(legends) != 0: plt.legend(legends, bbox_to_anchor=(0.5, 1.05), loc='upper center', ncol=len(legends), prop=fontP) plt.title(c) if epoch < 0: plt.xlabel('Epoch') else: plt.xlabel('Iters') plt.tight_layout() plt.savefig(os.path.join(save_dir, save_name)) plt.clf() ================================================ FILE: utils/recorders.py ================================================ from collections import OrderedDict import numpy as np class Records(object): """ Records->Train,Val->Loss,Accuracy->Epoch1,2,3->[v1,v2] IterRecords->Train,Val->Loss, Accuracy,->[v1,v2] """ def __init__(self, log_dir, records=None): if records == None: self.records = OrderedDict() else: self.records = records self.iter_rec = OrderedDict() self.log_dir = log_dir self.classes = ['loss', 'acc', 'err', 'ratio'] def resetIter(self): self.iter_rec.clear() def checkDict(self, a_dict, key, sub_type='dict'): if key not in a_dict.keys(): if sub_type == 'dict': a_dict[key] = OrderedDict() if sub_type == 'list': a_dict[key] = [] def updateIter(self, split, keys, values): self.checkDict(self.iter_rec, split, 'dict') for k, v in zip(keys, values): self.checkDict(self.iter_rec[split], k, 'list') self.iter_rec[split][k].append(v) def saveIterRecord(self, epoch, reset=True): for s in self.iter_rec.keys(): # s stands for split self.checkDict(self.records, s, 'dict') for k in self.iter_rec[s].keys(): self.checkDict(self.records[s], k, 'dict') 
self.checkDict(self.records[s][k], epoch, 'list') self.records[s][k][epoch].append(np.mean(self.iter_rec[s][k])) if reset: self.resetIter() def insertRecord(self, split, key, epoch, value): self.checkDict(self.records, split, 'dict') self.checkDict(self.records[split], key, 'dict') self.checkDict(self.records[split][key], epoch, 'list') self.records[split][key][epoch].append(value) def iterRecToString(self, split, epoch): rec_strs = '' for c in self.classes: strs = '' for k in self.iter_rec[split].keys(): if (c in k.lower()): strs += '{}: {:.3f}| '.format(k, np.mean(self.iter_rec[split][k])) if strs != '': rec_strs += '\t [{}] {}\n'.format(c.upper(), strs) self.saveIterRecord(epoch) return rec_strs def epochRecToString(self, split, epoch): rec_strs = '' for c in self.classes: strs = '' for k in self.records[split].keys(): if (c in k.lower()) and (epoch in self.records[split][k].keys()): strs += '{}: {:.3f}| '.format(k, np.mean(self.records[split][k][epoch])) if strs != '': rec_strs += '\t [{}] {}\n'.format(c.upper(), strs) return rec_strs def recordToDictOfArray(self, splits, epoch=-1, intv=1): if len(self.records) == 0: return {} if type(splits) == str: splits = [splits] dict_of_array = OrderedDict() for split in splits: for k in self.records[split].keys(): y_array, x_array = [], [] if epoch < 0: for ep in self.records[split][k].keys(): y_array.append(np.mean(self.records[split][k][ep])) x_array.append(ep) else: if epoch in self.records[split][k].keys(): y_array = np.array(self.records[split][k][epoch]) x_array = np.linspace(intv, intv*len(y_array), len(y_array)) dict_of_array[split[0] + split[-1] + '_' + k] = y_array dict_of_array[split[0] + split[-1] + '_' + k+'_x'] = x_array return dict_of_array ================================================ FILE: utils/time_utils.py ================================================ import time import torch from collections import OrderedDict class Timer(object): def __init__(self, cuda_sync=False): self.timer = OrderedDict() 
self.cuda_sync = cuda_sync self.startTimer() def startTimer(self): self.iter_start = time.time() self.disp_start = time.time() def resetTimer(self): self.iter_start = time.time() self.disp_start = time.time() for key in self.timer.keys(): self.timer[key].reset() def updateTime(self, key): if key not in self.timer.keys(): self.timer[key] = AverageMeter() if self.cuda_sync: torch.cuda.synchronize() self.timer[key].update(time.time() - self.iter_start) self.iter_start = time.time() def timeToString(self, reset=True): strs = '\t [Time %.3fs] ' % (time.time() - self.disp_start) for key in self.timer.keys(): if self.timer[key].sum < 1e-4: continue strs += '%s: %.3fs| ' % (key, self.timer[key].sum) self.resetTimer() return strs class AverageMeter(object): def __init__(self): self.reset() def reset(self): self.sum = 0 self.count = 0 self.avg = 0 def update(self, val, n=1): self.sum += val * n self.count += n self.avg = self.sum / self.count def __repr__(self): return '%.3f' % (self.avg) ================================================ FILE: utils/utils.py ================================================ import os import numpy as np from imageio import imread, imsave def makeFile(f): if not os.path.exists(f): os.makedirs(f) #else: raise Exception('Rendered image directory %s is already existed!!!' 
% directory) def makeFiles(f_list): for f in f_list: makeFile(f) def emptyFile(name): with open(name, 'w') as f: f.write(' ') def dictToString(dicts, start='\t', end='\n'): strs = '' for k, v in sorted(dicts.items()): strs += '%s%s: %s%s' % (start, str(k), str(v), end) return strs def checkIfInList(list1, list2): contains = [] for l1 in list1: for l2 in list2: if l1 in l2.lower(): contains.append(l1) break return contains def atoi(text): return int(text) if text.isdigit() else text def natural_keys(text): ''' alist.sort(key=natural_keys) sorts in human order http://nedbatchelder.com/blog/200712/human_sorting.html (See Toothy's implementation in the comments) ''' return [ atoi(c) for c in re.split('(\d+)', text) ] def readList(list_path,ignore_head=False, sort=False): lists = [] with open(list_path) as f: lists = f.read().splitlines() if ignore_head: lists = lists[1:] if sort: lists.sort(key=natural_keys) return lists