Repository: NVIDIA/pix2pixHD
Branch: master
Commit: 14b3b3c7fff4
Files: 44
Total size: 115.1 KB
Directory structure:
gitextract_tm06yt2o/
├── .gitignore
├── LICENSE.txt
├── README.md
├── _config.yml
├── data/
│ ├── __init__.py
│ ├── aligned_dataset.py
│ ├── base_data_loader.py
│ ├── base_dataset.py
│ ├── custom_dataset_data_loader.py
│ ├── data_loader.py
│ └── image_folder.py
├── encode_features.py
├── models/
│ ├── __init__.py
│ ├── base_model.py
│ ├── models.py
│ ├── networks.py
│ ├── pix2pixHD_model.py
│ └── ui_model.py
├── options/
│ ├── __init__.py
│ ├── base_options.py
│ ├── test_options.py
│ └── train_options.py
├── precompute_feature_maps.py
├── run_engine.py
├── scripts/
│ ├── test_1024p.sh
│ ├── test_1024p_feat.sh
│ ├── test_512p.sh
│ ├── test_512p_feat.sh
│ ├── train_1024p_12G.sh
│ ├── train_1024p_24G.sh
│ ├── train_1024p_feat_12G.sh
│ ├── train_1024p_feat_24G.sh
│ ├── train_512p.sh
│ ├── train_512p_feat.sh
│ ├── train_512p_fp16.sh
│ ├── train_512p_fp16_multigpu.sh
│ └── train_512p_multigpu.sh
├── test.py
├── train.py
└── util/
├── __init__.py
├── html.py
├── image_pool.py
├── util.py
└── visualizer.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
debug*
checkpoints/
results/
build/
dist/
torch.egg-info/
*/**/__pycache__
torch/version.py
torch/csrc/generic/TensorMethods.cpp
torch/lib/*.so*
torch/lib/*.dylib*
torch/lib/*.h
torch/lib/build
torch/lib/tmp_install
torch/lib/include
torch/lib/torch_shm_manager
torch/csrc/cudnn/cuDNN.cpp
torch/csrc/nn/THNN.cwrap
torch/csrc/nn/THNN.cpp
torch/csrc/nn/THCUNN.cwrap
torch/csrc/nn/THCUNN.cpp
torch/csrc/nn/THNN_generic.cwrap
torch/csrc/nn/THNN_generic.cpp
torch/csrc/nn/THNN_generic.h
docs/src/**/*
test/data/legacy_modules.t7
test/data/gpu_tensors.pt
test/htmlcov
test/.coverage
*/*.pyc
*/**/*.pyc
*/**/**/*.pyc
*/**/**/**/*.pyc
*/**/**/**/**/*.pyc
*/*.so*
*/**/*.so*
*/**/*.dylib*
test/data/legacy_serialized.pt
*.DS_Store
*~
================================================
FILE: LICENSE.txt
================================================
Copyright (C) 2019 NVIDIA Corporation. Ting-Chun Wang, Ming-Yu Liu, Jun-Yan Zhu.
BSD License. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING ALL
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR ANY PARTICULAR PURPOSE.
IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL
DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING
OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
--------------------------- LICENSE FOR pytorch-CycleGAN-and-pix2pix ----------------
Copyright (c) 2017, Jun-Yan Zhu and Taesung Park
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
================================================
FILE: README.md
================================================
# pix2pixHD
### [Project](https://tcwang0509.github.io/pix2pixHD/) | [Youtube](https://youtu.be/3AIpPlzM_qs) | [Paper](https://arxiv.org/pdf/1711.11585.pdf)
Pytorch implementation of our method for high-resolution (e.g. 2048x1024) photorealistic image-to-image translation. It can be used for turning semantic label maps into photo-realistic images or synthesizing portraits from face label maps.
[High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs](https://tcwang0509.github.io/pix2pixHD/)
[Ting-Chun Wang](https://tcwang0509.github.io/)<sup>1</sup>, [Ming-Yu Liu](http://mingyuliu.net/)<sup>1</sup>, [Jun-Yan Zhu](http://people.eecs.berkeley.edu/~junyanz/)<sup>2</sup>, Andrew Tao<sup>1</sup>, [Jan Kautz](http://jankautz.com/)<sup>1</sup>, [Bryan Catanzaro](http://catanzaro.name/)<sup>1</sup>
<sup>1</sup>NVIDIA Corporation, <sup>2</sup>UC Berkeley
In CVPR 2018.
## Image-to-image translation at 2k/1k resolution
- Our label-to-streetview results
- Interactive editing results
- Additional streetview results
- Label-to-face and interactive editing results
- Our editing interface
## Prerequisites
- Linux or macOS
- Python 2 or 3
- NVIDIA GPU (11G memory or larger) + CUDA cuDNN
## Getting Started
### Installation
- Install PyTorch and dependencies from http://pytorch.org
- Install python libraries [dominate](https://github.com/Knio/dominate).
```bash
pip install dominate
```
- Clone this repo:
```bash
git clone https://github.com/NVIDIA/pix2pixHD
cd pix2pixHD
```
### Testing
- A few example Cityscapes test images are included in the `datasets` folder.
- Please download the pre-trained Cityscapes model from [here](https://drive.google.com/file/d/1OR-2aEPHOxZKuoOV34DvQxreqGCSLcW9/view?usp=drive_link) (google drive link), and put it under `./checkpoints/label2city_1024p/`
- Test the model (`bash ./scripts/test_1024p.sh`):
```bash
#!./scripts/test_1024p.sh
python test.py --name label2city_1024p --netG local --ngf 32 --resize_or_crop none
```
The test results will be saved to a html file here: `./results/label2city_1024p/test_latest/index.html`.
More example scripts can be found in the `scripts` directory.
### Dataset
- We use the Cityscapes dataset. To train a model on the full dataset, please download it from the [official website](https://www.cityscapes-dataset.com/) (registration required).
After downloading, please put it under the `datasets` folder in the same way the example images are provided.
### Training
- Train a model at 1024 x 512 resolution (`bash ./scripts/train_512p.sh`):
```bash
#!./scripts/train_512p.sh
python train.py --name label2city_512p
```
- To view training results, please check out the intermediate results in `./checkpoints/label2city_512p/web/index.html`.
If you have tensorflow installed, you can see tensorboard logs in `./checkpoints/label2city_512p/logs` by adding `--tf_log` to the training scripts.
### Multi-GPU training
- Train a model using multiple GPUs (`bash ./scripts/train_512p_multigpu.sh`):
```bash
#!./scripts/train_512p_multigpu.sh
python train.py --name label2city_512p --batchSize 8 --gpu_ids 0,1,2,3,4,5,6,7
```
Note: this is not tested and we trained our model using single GPU only. Please use at your own discretion.
### Training with Automatic Mixed Precision (AMP) for faster speed
- To train with mixed precision support, please first install apex from: https://github.com/NVIDIA/apex
- You can then train the model by adding `--fp16`. For example,
```bash
#!./scripts/train_512p_fp16.sh
python -m torch.distributed.launch train.py --name label2city_512p --fp16
```
In our test case, it trains about 80% faster with AMP on a Volta machine.
### Training at full resolution
- Training at full resolution (2048 x 1024) requires a GPU with 24G memory (`bash ./scripts/train_1024p_24G.sh`), or 16G memory if using mixed precision (AMP).
- If only GPUs with 12G memory are available, please use the 12G script (`bash ./scripts/train_1024p_12G.sh`), which will crop the images during training. Performance is not guaranteed using this script.
### Training with your own dataset
- If you want to train with your own dataset, please generate label maps as single-channel images whose pixel values correspond to the object labels (i.e. 0,1,...,N-1, where N is the number of labels). This is because we need to generate one-hot vectors from the label maps. Please also specify `--label_nc N` during both training and testing.
- If your input is not a label map, please just specify `--label_nc 0` which will directly use the RGB colors as input. The folders should then be named `train_A`, `train_B` instead of `train_label`, `train_img`, where the goal is to translate images from A to B.
- If you don't have instance maps or don't want to use them, please specify `--no_instance`.
- The default setting for preprocessing is `scale_width`, which will scale the width of all training images to `opt.loadSize` (1024) while keeping the aspect ratio. If you want a different setting, please change it by using the `--resize_or_crop` option. For example, `scale_width_and_crop` first resizes the image to have width `opt.loadSize` and then does random cropping of size `(opt.fineSize, opt.fineSize)`. `crop` skips the resizing step and only performs random cropping. If you don't want any preprocessing, please specify `none`, which will do nothing other than making sure the image is divisible by 32.
## More Training/Test Details
- Flags: see `options/train_options.py` and `options/base_options.py` for all the training flags; see `options/test_options.py` and `options/base_options.py` for all the test flags.
- Instance map: we take in both label maps and instance maps as input. If you don't want to use instance maps, please specify the flag `--no_instance`.
## Citation
If you find this useful for your research, please use the following.
```
@inproceedings{wang2018pix2pixHD,
title={High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs},
author={Ting-Chun Wang and Ming-Yu Liu and Jun-Yan Zhu and Andrew Tao and Jan Kautz and Bryan Catanzaro},
booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
year={2018}
}
```
## Acknowledgments
This code borrows heavily from [pytorch-CycleGAN-and-pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix).
================================================
FILE: _config.yml
================================================
theme: jekyll-theme-minimal
================================================
FILE: data/__init__.py
================================================
================================================
FILE: data/aligned_dataset.py
================================================
import os.path
from data.base_dataset import BaseDataset, get_params, get_transform, normalize
from data.image_folder import make_dataset
from PIL import Image
class AlignedDataset(BaseDataset):
    """Dataset of aligned label maps (A), real images (B) and optional instance/feature maps."""

    def initialize(self, opt):
        """Collect the sorted file list of every input stream enabled in ``opt``."""
        self.opt = opt
        self.root = opt.dataroot
        rgb_input = self.opt.label_nc == 0

        ### input A (label maps)
        suffix_A = '_A' if rgb_input else '_label'
        self.dir_A = os.path.join(opt.dataroot, opt.phase + suffix_A)
        self.A_paths = sorted(make_dataset(self.dir_A))

        ### input B (real images)
        if opt.isTrain or opt.use_encoded_image:
            suffix_B = '_B' if rgb_input else '_img'
            self.dir_B = os.path.join(opt.dataroot, opt.phase + suffix_B)
            self.B_paths = sorted(make_dataset(self.dir_B))

        ### instance maps
        if not opt.no_instance:
            self.dir_inst = os.path.join(opt.dataroot, opt.phase + '_inst')
            self.inst_paths = sorted(make_dataset(self.dir_inst))

        ### load precomputed instance-wise encoded features
        if opt.load_features:
            self.dir_feat = os.path.join(opt.dataroot, opt.phase + '_feat')
            print('----------- loading features from %s ----------' % self.dir_feat)
            self.feat_paths = sorted(make_dataset(self.dir_feat))

        self.dataset_size = len(self.A_paths)

    def __getitem__(self, index):
        """Load sample ``index`` as a dict with keys label, inst, image, feat and path.

        Streams that are disabled in the options are returned as the integer 0.
        """
        opt = self.opt

        ### input A (label maps)
        A_path = self.A_paths[index]
        A = Image.open(A_path)
        # One set of random crop/flip parameters is shared by every stream of this sample.
        params = get_params(opt, A.size)
        if opt.label_nc != 0:
            # Label ids must survive resizing untouched: nearest-neighbour, no normalization,
            # then undo the 1/255 scaling applied by ToTensor.
            transform_A = get_transform(opt, params, method=Image.NEAREST, normalize=False)
            A_tensor = transform_A(A) * 255.0
        else:
            transform_A = get_transform(opt, params)
            A_tensor = transform_A(A.convert('RGB'))

        B_tensor = 0
        inst_tensor = 0
        feat_tensor = 0

        ### input B (real images)
        if opt.isTrain or opt.use_encoded_image:
            real_image = Image.open(self.B_paths[index]).convert('RGB')
            B_tensor = get_transform(opt, params)(real_image)

        ### if using instance maps
        if not opt.no_instance:
            inst_tensor = transform_A(Image.open(self.inst_paths[index]))

            if opt.load_features:
                feat_image = Image.open(self.feat_paths[index]).convert('RGB')
                feat_tensor = normalize()(transform_A(feat_image))

        return {'label': A_tensor, 'inst': inst_tensor, 'image': B_tensor,
                'feat': feat_tensor, 'path': A_path}

    def __len__(self):
        """Number of samples, rounded down to a whole number of batches."""
        batch = self.opt.batchSize
        return (len(self.A_paths) // batch) * batch

    def name(self):
        """Return the dataset's display name."""
        return 'AlignedDataset'
================================================
FILE: data/base_data_loader.py
================================================
class BaseDataLoader():
    """Minimal base class for data-loader wrappers; subclasses override ``load_data``."""

    def __init__(self):
        pass

    def initialize(self, opt):
        """Remember the options object on ``self.opt``."""
        self.opt = opt

    def load_data(self):
        """Return the wrapped loader; the base class has none, so this is ``None``.

        The method takes ``self`` so that it can be called on an instance like
        the overriding implementations in subclasses.
        """
        return None
================================================
FILE: data/base_dataset.py
================================================
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import random
class BaseDataset(data.Dataset):
    """Common base for the project's datasets; subclasses fill in ``initialize``."""

    def __init__(self):
        super(BaseDataset, self).__init__()

    def initialize(self, opt):
        """No-op hook; subclasses configure themselves from ``opt`` here."""
        return None

    def name(self):
        """Return the dataset's display name."""
        return 'BaseDataset'
def get_params(opt, size):
    """Draw the random crop offset and flip decision for one sample.

    Args:
        opt: options object; reads ``resize_or_crop``, ``loadSize`` and ``fineSize``.
        size: ``(width, height)`` of the source image.

    Returns:
        ``{'crop_pos': (x, y), 'flip': bool}``. The crop offset is drawn so that a
        ``fineSize`` window fits inside the image as it will be after resizing.
    """
    orig_w, orig_h = size
    mode = opt.resize_or_crop
    if mode == 'resize_and_crop':
        scaled_w = scaled_h = opt.loadSize
    elif mode == 'scale_width_and_crop':
        scaled_w = opt.loadSize
        scaled_h = opt.loadSize * orig_h // orig_w
    else:
        scaled_w, scaled_h = orig_w, orig_h

    # Clamp at zero so images smaller than fineSize always get offset 0.
    crop_x = random.randint(0, max(0, scaled_w - opt.fineSize))
    crop_y = random.randint(0, max(0, scaled_h - opt.fineSize))
    do_flip = random.random() > 0.5
    return {'crop_pos': (crop_x, crop_y), 'flip': do_flip}
def get_transform(opt, params, method=Image.BICUBIC, normalize=True):
    """Build the preprocessing pipeline selected by ``opt.resize_or_crop``.

    Args:
        opt: options object; reads ``resize_or_crop``, ``loadSize``, ``fineSize``,
            ``isTrain``, ``no_flip`` and, for mode ``'none'``, ``n_downsample_global``,
            ``netG`` and ``n_local_enhancers``.
        params: dict with ``'crop_pos'`` and ``'flip'`` entries (see ``get_params``).
        method: PIL resampling filter used by every resize step.
        normalize: when true, append a 3-channel Normalize with mean and std 0.5.

    Returns:
        A ``transforms.Compose`` that ends in ``ToTensor`` (plus Normalize if requested).
    """
    transform_list = []
    if 'resize' in opt.resize_or_crop:
        osize = [opt.loadSize, opt.loadSize]
        # ``transforms.Scale`` was renamed to ``Resize`` and is absent from recent
        # torchvision releases; fall back to it only on versions without ``Resize``.
        resize_cls = getattr(transforms, 'Resize', None) or transforms.Scale
        transform_list.append(resize_cls(osize, method))
    elif 'scale_width' in opt.resize_or_crop:
        transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method)))

    if 'crop' in opt.resize_or_crop:
        transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize)))

    if opt.resize_or_crop == 'none':
        # Without any resizing, snap both sides to a multiple of the generator's
        # total downsampling factor.
        base = float(2 ** opt.n_downsample_global)
        if opt.netG == 'local':
            base *= (2 ** opt.n_local_enhancers)
        transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))

    if opt.isTrain and not opt.no_flip:
        transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))

    transform_list += [transforms.ToTensor()]

    if normalize:
        transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
                                                (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)
def normalize():
    """Return a 3-channel Normalize transform with mean 0.5 and std 0.5 per channel."""
    mean = (0.5, 0.5, 0.5)
    std = (0.5, 0.5, 0.5)
    return transforms.Normalize(mean, std)
def __make_power_2(img, base, method=Image.BICUBIC):
    """Resize ``img`` so that both sides are the nearest multiple of ``base``.

    The image is returned unchanged when it already has such a size.
    """
    orig_w, orig_h = img.size
    new_w = int(round(orig_w / base) * base)
    new_h = int(round(orig_h / base) * base)
    if (new_w, new_h) == (orig_w, orig_h):
        return img
    return img.resize((new_w, new_h), method)
def __scale_width(img, target_width, method=Image.BICUBIC):
    """Resize ``img`` to ``target_width``, scaling the height by the same ratio."""
    orig_w, orig_h = img.size
    if orig_w != target_width:
        new_h = int(target_width * orig_h / orig_w)
        img = img.resize((target_width, new_h), method)
    return img
def __crop(img, pos, size):
    """Cut a ``size`` x ``size`` window whose top-left corner is ``pos``.

    Images that are no larger than ``size`` on both sides are returned as-is.
    """
    width, height = img.size
    left, top = pos
    if width <= size and height <= size:
        return img
    return img.crop((left, top, left + size, top + size))
def __flip(img, flip):
    """Mirror ``img`` left-to-right when ``flip`` is truthy; otherwise return it unchanged."""
    return img.transpose(Image.FLIP_LEFT_RIGHT) if flip else img
================================================
FILE: data/custom_dataset_data_loader.py
================================================
import torch.utils.data
from data.base_data_loader import BaseDataLoader
def CreateDataset(opt):
    """Create an ``AlignedDataset``, announce it, and initialize it from ``opt``."""
    from data.aligned_dataset import AlignedDataset

    aligned = AlignedDataset()
    print("dataset [%s] was created" % (aligned.name()))
    aligned.initialize(opt)
    return aligned
class CustomDatasetDataLoader(BaseDataLoader):
    """Wraps the project dataset in a ``torch.utils.data.DataLoader``."""

    def name(self):
        """Return the loader's display name."""
        return 'CustomDatasetDataLoader'

    def initialize(self, opt):
        """Create the dataset and its DataLoader from the given options."""
        BaseDataLoader.initialize(self, opt)
        self.dataset = CreateDataset(opt)
        loader_settings = dict(
            batch_size=opt.batchSize,
            shuffle=not opt.serial_batches,
            num_workers=int(opt.nThreads),
        )
        self.dataloader = torch.utils.data.DataLoader(self.dataset, **loader_settings)

    def load_data(self):
        """Return the underlying DataLoader."""
        return self.dataloader

    def __len__(self):
        """Dataset length, capped at ``opt.max_dataset_size``."""
        limit = self.opt.max_dataset_size
        return min(len(self.dataset), limit)
================================================
FILE: data/data_loader.py
================================================
def CreateDataLoader(opt):
    """Create a ``CustomDatasetDataLoader``, print its name, and initialize it from ``opt``."""
    from data.custom_dataset_data_loader import CustomDatasetDataLoader

    loader = CustomDatasetDataLoader()
    print(loader.name())
    loader.initialize(opt)
    return loader
================================================
FILE: data/image_folder.py
================================================
###############################################################################
# Code from
# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
# Modified the original code so that it also loads images from the current
# directory as well as the subdirectories
###############################################################################
import torch.utils.data as data
from PIL import Image
import os
# File extensions treated as images. Matching in ``is_image_file`` is
# case-insensitive, so each extension only needs to be listed once.
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff'
]


def is_image_file(filename):
    """Return True when ``filename`` ends with one of ``IMG_EXTENSIONS``.

    The comparison ignores case, so ``x.TIFF`` or ``x.Png`` are accepted even
    though only one spelling of those extensions is listed.
    """
    lowered = filename.lower()
    return any(lowered.endswith(extension.lower()) for extension in IMG_EXTENSIONS)
def make_dataset(dir):
    """Return the paths of all image files under ``dir``, including subdirectories.

    Directories are visited in sorted order; file names within a directory keep
    the order reported by ``os.walk``.
    """
    assert os.path.isdir(dir), '%s is not a valid directory' % dir
    return [
        os.path.join(root, fname)
        for root, _, fnames in sorted(os.walk(dir))
        for fname in fnames
        if is_image_file(fname)
    ]
def default_loader(path):
    """Open the image at ``path`` and return it converted to RGB."""
    image = Image.open(path)
    return image.convert('RGB')
class ImageFolder(data.Dataset):
    """Dataset over every image found under a root directory."""

    def __init__(self, root, transform=None, return_paths=False,
                 loader=default_loader):
        """Index the images under ``root``; raise RuntimeError when none are found."""
        found = make_dataset(root)
        if not found:
            message = ("Found 0 images in: " + root + "\n"
                       "Supported image extensions are: " +
                       ",".join(IMG_EXTENSIONS))
            raise RuntimeError(message)

        self.root = root
        self.imgs = found
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        """Load image ``index``, apply the transform, and optionally return its path too."""
        path = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        return (img, path) if self.return_paths else img

    def __len__(self):
        """Number of indexed images."""
        return len(self.imgs)
================================================
FILE: encode_features.py
================================================
from options.train_options import TrainOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
import numpy as np
import os
# Script body: encode every image of the dataset into per-label feature rows,
# save them, then cluster the rows of each label with k-means and save the
# cluster centers.
opt = TrainOptions().parse()
opt.nThreads = 1
opt.batchSize = 1
opt.serial_batches = True
opt.no_flip = True
opt.instance_feat = True
opt.continue_train = True

name = 'features'
save_path = os.path.join(opt.checkpoints_dir, opt.name)

############ Initialize #########
data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
dataset_size = len(data_loader)
model = create_model(opt)
# create_model wraps the model in DataParallel only for some option
# combinations (training, GPUs present, no fp16); use the bare model otherwise.
encoder = getattr(model, 'module', model)

########### Encode features ###########
reencode = True
if reencode:
    features = {}
    for label in range(opt.label_nc):
        features[label] = np.zeros((0, opt.feat_num+1))
    for i, data in enumerate(dataset):
        feat = encoder.encode_features(data['image'], data['inst'])
        for label in range(opt.label_nc):
            features[label] = np.append(features[label], feat[label], axis=0)

        print('%d / %d images' % (i+1, dataset_size))
    save_name = os.path.join(save_path, name + '.npy')
    np.save(save_name, features)

############## Clustering ###########
load_name = os.path.join(save_path, name + '.npy')
# The file holds a pickled dict, which NumPy refuses to load unless
# allow_pickle is enabled. NOTE: unpickling executes arbitrary code, so only
# load feature files produced by this script.
features = np.load(load_name, allow_pickle=True).item()
from sklearn.cluster import KMeans
centers = {}
for label in range(opt.label_nc):
    feat = features[label]
    # Keep rows whose last column exceeds 0.5 and drop that column.
    feat = feat[feat[:,-1] > 0.5, :-1]
    if feat.shape[0]:
        # k-means cannot use more clusters than there are samples.
        n_clusters = min(feat.shape[0], opt.n_clusters)
        kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(feat)
        centers[label] = kmeans.cluster_centers_
save_name = os.path.join(save_path, name + '_clustered_%03d.npy' % opt.n_clusters)
np.save(save_name, centers)
print('saving to %s' % save_name)
================================================
FILE: models/__init__.py
================================================
================================================
FILE: models/base_model.py
================================================
import os
import torch
import sys
class BaseModel(torch.nn.Module):
    """Base class for models: stores common options and provides checkpoint helpers."""

    def name(self):
        """Return the model's display name."""
        return 'BaseModel'

    def initialize(self, opt):
        """Copy the commonly used options onto the instance."""
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)

    def set_input(self, input):
        """Store ``input`` on the instance."""
        self.input = input

    def forward(self):
        pass

    # used in test time, no backprop
    def test(self):
        pass

    def get_image_paths(self):
        pass

    def optimize_parameters(self):
        pass

    def get_current_visuals(self):
        """Return the stored input."""
        return self.input

    def get_current_errors(self):
        """Return an empty dict; subclasses may report their losses here."""
        return {}

    def save(self, label):
        pass

    # helper saving function that can be used by subclasses
    def save_network(self, network, network_label, epoch_label, gpu_ids):
        """Write ``network``'s state dict to ``<epoch_label>_net_<network_label>.pth``.

        The network is moved to the CPU for saving and moved back to CUDA
        afterwards when ``gpu_ids`` is non-empty and CUDA is available.
        """
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        torch.save(network.cpu().state_dict(), save_path)
        if len(gpu_ids) and torch.cuda.is_available():
            network.cuda()

    # helper loading function that can be used by subclasses
    def load_network(self, network, network_label, epoch_label, save_dir=''):
        """Load ``<epoch_label>_net_<network_label>.pth`` into ``network``.

        Args:
            network: module to load the weights into.
            network_label: short network name used in the file name.
            epoch_label: epoch tag used in the file name.
            save_dir: checkpoint directory; defaults to ``self.save_dir``.

        Raises:
            IOError: if the checkpoint is missing and ``network_label`` is ``'G'``.

        A missing checkpoint for any other network is only reported. When the
        checkpoint does not match the network exactly, the layers that match
        by name and size are loaded and the remaining ones are listed.
        """
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        if not save_dir:
            save_dir = self.save_dir
        save_path = os.path.join(save_dir, save_filename)
        if not os.path.isfile(save_path):
            print('%s not exists yet!' % save_path)
            if network_label == 'G':
                raise IOError('Generator must exist!')
            return

        pretrained_dict = torch.load(save_path)
        try:
            network.load_state_dict(pretrained_dict)
            return
        except Exception:
            # The checkpoint does not match the network exactly; try partial loading below.
            pass

        model_dict = network.state_dict()
        # Drop checkpoint entries the network does not have.
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        try:
            network.load_state_dict(pretrained_dict)
            if self.opt.verbose:
                print('Pretrained network %s has excessive layers; Only loading layers that are used' % network_label)
        except Exception:
            print('Pretrained network %s has fewer layers; The following are not initialized:' % network_label)
            for k, v in pretrained_dict.items():
                if v.size() == model_dict[k].size():
                    model_dict[k] = v

            # Report the top-level submodules that kept their current weights.
            not_initialized = set()
            for k, v in model_dict.items():
                if k not in pretrained_dict or v.size() != pretrained_dict[k].size():
                    not_initialized.add(k.split('.')[0])

            print(sorted(not_initialized))
            network.load_state_dict(model_dict)

    def update_learning_rate(self):
        """No-op hook; subclasses adjust their optimizers' learning rates here."""
        pass
================================================
FILE: models/models.py
================================================
import torch
def create_model(opt):
    """Instantiate and initialize the model selected by ``opt.model``.

    ``'pix2pixHD'`` yields the training or the inference model depending on
    ``opt.isTrain``; any other value yields the UI model. The model is wrapped
    in ``DataParallel`` only when training on GPUs without fp16.
    """
    if opt.model != 'pix2pixHD':
        from .ui_model import UIModel
        model = UIModel()
    else:
        from .pix2pixHD_model import Pix2PixHDModel, InferenceModel
        model = Pix2PixHDModel() if opt.isTrain else InferenceModel()

    model.initialize(opt)
    if opt.verbose:
        print("model [%s] was created" % (model.name()))

    wrap_parallel = opt.isTrain and len(opt.gpu_ids) and not opt.fp16
    if wrap_parallel:
        model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)

    return model
================================================
FILE: models/networks.py
================================================
import torch
import torch.nn as nn
import functools
from torch.autograd import Variable
import numpy as np
###############################################################################
# Functions
###############################################################################
def weights_init(m):
    """Re-initialize ``m`` in place according to its class name.

    Classes whose name contains ``Conv`` get weights drawn from N(0, 0.02);
    ``BatchNorm2d`` classes get weights from N(1, 0.02) and zero bias. Other
    modules are left untouched.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        m.weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm2d' in layer_kind:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
def get_norm_layer(norm_type='instance'):
    """Return a constructor for the requested 2-D normalization layer.

    ``'batch'`` gives affine ``BatchNorm2d``; ``'instance'`` gives non-affine
    ``InstanceNorm2d``. Any other value raises ``NotImplementedError``.
    """
    if norm_type == 'instance':
        return functools.partial(nn.InstanceNorm2d, affine=False)
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d, affine=True)
    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
def define_G(input_nc, output_nc, ngf, netG, n_downsample_global=3, n_blocks_global=9, n_local_enhancers=1,
             n_blocks_local=3, norm='instance', gpu_ids=[]):
    """Build, print and initialize the generator named by ``netG``.

    Args:
        input_nc: number of input channels.
        output_nc: number of output channels.
        ngf: base number of generator filters.
        netG: ``'global'``, ``'local'`` or ``'encoder'``.
        n_downsample_global, n_blocks_global: settings of the global generator.
        n_local_enhancers, n_blocks_local: settings used only by ``'local'``.
        norm: normalization type understood by ``get_norm_layer``.
        gpu_ids: when non-empty, the network is moved to the first listed GPU.

    Returns:
        The network after ``weights_init`` has been applied to it.

    Raises:
        NotImplementedError: for an unknown ``netG`` or ``norm`` value.
    """
    norm_layer = get_norm_layer(norm_type=norm)
    if netG == 'global':
        netG = GlobalGenerator(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global, norm_layer)
    elif netG == 'local':
        netG = LocalEnhancer(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global,
                             n_local_enhancers, n_blocks_local, norm_layer)
    elif netG == 'encoder':
        netG = Encoder(input_nc, output_nc, ngf, n_downsample_global, norm_layer)
    else:
        # Raising a plain string is itself a TypeError; raise a real exception.
        raise NotImplementedError('generator not implemented!')
    print(netG)
    if len(gpu_ids) > 0:
        assert(torch.cuda.is_available())
        netG.cuda(gpu_ids[0])
    netG.apply(weights_init)
    return netG
def define_D(input_nc, ndf, n_layers_D, norm='instance', use_sigmoid=False, num_D=1, getIntermFeat=False, gpu_ids=[]):
    """Build, print and initialize the multiscale discriminator.

    When ``gpu_ids`` is non-empty the network is moved to the first listed GPU
    before ``weights_init`` is applied.
    """
    netD = MultiscaleDiscriminator(input_nc, ndf, n_layers_D, get_norm_layer(norm_type=norm),
                                   use_sigmoid, num_D, getIntermFeat)
    print(netD)
    use_gpu = len(gpu_ids) > 0
    if use_gpu:
        assert torch.cuda.is_available()
        netD.cuda(gpu_ids[0])
    netD.apply(weights_init)
    return netD
def print_network(net):
    """Print a network followed by its total parameter count.

    When ``net`` is a list, only its first element is reported.
    """
    target = net[0] if isinstance(net, list) else net
    total = sum(param.numel() for param in target.parameters())
    print(target)
    print('Total number of parameters: %d' % total)
##############################################################################
# Losses
##############################################################################
class GANLoss(nn.Module):
    """GAN objective that builds and caches the real/fake target tensors itself.

    Args:
        use_lsgan: use ``MSELoss`` when true, ``BCELoss`` otherwise.
        target_real_label: fill value of the target for real samples.
        target_fake_label: fill value of the target for fake samples.
        tensor: tensor constructor used to allocate the targets.
    """

    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.FloatTensor):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.real_label_var = None
        self.fake_label_var = None
        self.Tensor = tensor
        if use_lsgan:
            self.loss = nn.MSELoss()
        else:
            self.loss = nn.BCELoss()

    def get_target_tensor(self, input, target_is_real):
        """Return a constant target shaped like ``input``, reusing the cache when possible.

        The cache is keyed on the full shape: a cached target with the same
        number of elements but a different shape cannot be compared against
        ``input`` element-wise, so it is rebuilt.
        """
        if target_is_real:
            stale = (self.real_label_var is None or
                     self.real_label_var.size() != input.size())
            if stale:
                real_tensor = self.Tensor(input.size()).fill_(self.real_label)
                self.real_label_var = Variable(real_tensor, requires_grad=False)
            return self.real_label_var

        stale = (self.fake_label_var is None or
                 self.fake_label_var.size() != input.size())
        if stale:
            fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
            self.fake_label_var = Variable(fake_tensor, requires_grad=False)
        return self.fake_label_var

    def __call__(self, input, target_is_real):
        """Compute the loss on the last entry of each prediction list.

        ``input`` is either one list of tensors or a list of such lists; in
        the latter case the per-list losses are summed.
        """
        if isinstance(input[0], list):
            loss = 0
            for input_i in input:
                pred = input_i[-1]
                target_tensor = self.get_target_tensor(pred, target_is_real)
                loss += self.loss(pred, target_tensor)
            return loss

        target_tensor = self.get_target_tensor(input[-1], target_is_real)
        return self.loss(input[-1], target_tensor)
class VGGLoss(nn.Module):
    """Weighted sum of L1 distances between the VGG-19 feature maps of two inputs."""

    def __init__(self, gpu_ids):
        super(VGGLoss, self).__init__()
        self.vgg = Vgg19().cuda()
        self.criterion = nn.L1Loss()
        # One weight per feature map returned by the VGG network.
        self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]

    def forward(self, x, y):
        """Return the weighted feature loss; the features of ``y`` are detached."""
        feats_x = self.vgg(x)
        feats_y = self.vgg(y)
        total = 0
        for idx, feat_x in enumerate(feats_x):
            layer_loss = self.criterion(feat_x, feats_y[idx].detach())
            total += self.weights[idx] * layer_loss
        return total
##############################################################################
# Generator
##############################################################################
class LocalEnhancer(nn.Module):
    """Generator made of a truncated GlobalGenerator plus local enhancer branches.

    The global part is stored as ``self.model``. Enhancer ``n`` (1-based) is
    stored as ``model<n>_1`` (downsampling front end) and ``model<n>_2``
    (residual blocks, upsampling and, for the last enhancer, the output conv).
    """

    def __init__(self, input_nc, output_nc, ngf=32, n_downsample_global=3, n_blocks_global=9,
                 n_local_enhancers=1, n_blocks_local=3, norm_layer=nn.BatchNorm2d, padding_type='reflect'):
        super(LocalEnhancer, self).__init__()
        self.n_local_enhancers = n_local_enhancers

        ###### global generator model #####
        # The global generator gets twice as many base filters per enhancer level.
        ngf_global = ngf * (2**n_local_enhancers)
        model_global = GlobalGenerator(input_nc, output_nc, ngf_global, n_downsample_global, n_blocks_global, norm_layer).model
        # The dropped modules are the global generator's last three layers:
        # reflection padding, the 7x7 output convolution and Tanh.
        model_global = [model_global[i] for i in range(len(model_global)-3)] # get rid of final convolution layers
        self.model = nn.Sequential(*model_global)

        ###### local enhancer layers #####
        for n in range(1, n_local_enhancers+1):
            ### downsample
            # Filter count halves with each enhancer; the last one uses ngf.
            ngf_global = ngf * (2**(n_local_enhancers-n))
            model_downsample = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf_global, kernel_size=7, padding=0),
                                norm_layer(ngf_global), nn.ReLU(True),
                                nn.Conv2d(ngf_global, ngf_global * 2, kernel_size=3, stride=2, padding=1),
                                norm_layer(ngf_global * 2), nn.ReLU(True)]
            ### residual blocks
            model_upsample = []
            for i in range(n_blocks_local):
                model_upsample += [ResnetBlock(ngf_global * 2, padding_type=padding_type, norm_layer=norm_layer)]

            ### upsample
            model_upsample += [nn.ConvTranspose2d(ngf_global * 2, ngf_global, kernel_size=3, stride=2, padding=1, output_padding=1),
                               norm_layer(ngf_global), nn.ReLU(True)]

            ### final convolution
            # Only the last enhancer maps features to the output channels.
            if n == n_local_enhancers:
                model_upsample += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]

            # Register the two halves under names that forward() looks up again.
            setattr(self, 'model'+str(n)+'_1', nn.Sequential(*model_downsample))
            setattr(self, 'model'+str(n)+'_2', nn.Sequential(*model_upsample))

        # Stride-2 average pooling used to build the input pyramid in forward().
        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)

    def forward(self, input):
        """Run the global generator on the most downsampled input, then refine through each enhancer."""
        ### create input pyramid
        # input_downsampled[k] is the input after k rounds of self.downsample.
        input_downsampled = [input]
        for i in range(self.n_local_enhancers):
            input_downsampled.append(self.downsample(input_downsampled[-1]))

        ### output at coarsest level
        output_prev = self.model(input_downsampled[-1])
        ### build up one layer at a time
        for n_local_enhancers in range(1, self.n_local_enhancers+1):
            model_downsample = getattr(self, 'model'+str(n_local_enhancers)+'_1')
            model_upsample = getattr(self, 'model'+str(n_local_enhancers)+'_2')
            # Enhancer n consumes pyramid level (n_local_enhancers - n); the last one sees the full input.
            input_i = input_downsampled[self.n_local_enhancers-n_local_enhancers]
            # Add the enhancer's downsampled features to the previous output before upsampling.
            output_prev = model_upsample(model_downsample(input_i) + output_prev)
        return output_prev
class GlobalGenerator(nn.Module):
    """Generator: 7x7 stem, strided-conv downsampling, residual blocks, transposed-conv upsampling, 7x7 output conv with Tanh."""

    def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d,
                 padding_type='reflect'):
        assert n_blocks >= 0
        super(GlobalGenerator, self).__init__()
        # A single ReLU instance is reused at every activation position.
        relu = nn.ReLU(True)

        layers = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), norm_layer(ngf), relu]

        ### downsample: channels double at every level
        for level in range(n_downsampling):
            in_ch = ngf * 2**level
            out_ch = in_ch * 2
            layers.extend([nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1),
                           norm_layer(out_ch), relu])

        ### resnet blocks at the narrowest resolution
        bottleneck_ch = ngf * 2**n_downsampling
        for _ in range(n_blocks):
            layers.append(ResnetBlock(bottleneck_ch, padding_type=padding_type, activation=relu, norm_layer=norm_layer))

        ### upsample: channels halve at every level
        for level in range(n_downsampling):
            scale = 2**(n_downsampling - level)
            in_ch = ngf * scale
            out_ch = int(ngf * scale / 2)
            layers.extend([nn.ConvTranspose2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1, output_padding=1),
                           norm_layer(out_ch), relu])

        layers.extend([nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()])
        self.model = nn.Sequential(*layers)

    def forward(self, input):
        """Apply the sequential model to ``input``."""
        return self.model(input)
# Define a resnet block
class ResnetBlock(nn.Module):
    """Residual block: ``x + conv_block(x)`` with two padded 3x3 convolutions."""

    def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout)

    @staticmethod
    def _padded_conv(dim, padding_type):
        """Return the layers for one 3x3 convolution using the requested padding scheme."""
        if padding_type == 'reflect':
            return [nn.ReflectionPad2d(1), nn.Conv2d(dim, dim, kernel_size=3, padding=0)]
        if padding_type == 'replicate':
            return [nn.ReplicationPad2d(1), nn.Conv2d(dim, dim, kernel_size=3, padding=0)]
        if padding_type == 'zero':
            # zero padding is handled by the convolution itself
            return [nn.Conv2d(dim, dim, kernel_size=3, padding=1)]
        raise NotImplementedError('padding [%s] is not implemented' % padding_type)

    def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout):
        """Assemble pad/conv/norm/activation(/dropout) followed by pad/conv/norm.

        Raises:
            NotImplementedError: if ``padding_type`` is not 'reflect',
                'replicate' or 'zero'.
        """
        # first stage: convolution, normalisation, activation, optional dropout
        stages = self._padded_conv(dim, padding_type)
        stages.append(norm_layer(dim))
        stages.append(activation)
        if use_dropout:
            stages.append(nn.Dropout(0.5))

        # second stage: convolution and normalisation only
        stages.extend(self._padded_conv(dim, padding_type))
        stages.append(norm_layer(dim))
        return nn.Sequential(*stages)

    def forward(self, x):
        """Return the input plus the output of the convolutional branch."""
        return x + self.conv_block(x)
class Encoder(nn.Module):
    """Encoder-decoder network followed by instance-wise average pooling.

    ``forward`` runs the convolutional network and then replaces, per batch
    element and per output channel, every pixel of an instance by the mean of
    that channel over the instance's pixels.
    """

    def __init__(self, input_nc, output_nc, ngf=32, n_downsampling=4, norm_layer=nn.BatchNorm2d):
        super(Encoder, self).__init__()
        self.output_nc = output_nc

        model = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
                 norm_layer(ngf), nn.ReLU(True)]
        ### downsample
        for i in range(n_downsampling):
            mult = 2**i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
                      norm_layer(ngf * mult * 2), nn.ReLU(True)]

        ### upsample
        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1),
                      norm_layer(int(ngf * mult / 2)), nn.ReLU(True)]

        model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
        self.model = nn.Sequential(*model)

    def forward(self, input, inst):
        """Encode ``input`` and average the features over each instance in ``inst``.

        Args:
            input: batch fed to the convolutional network.
            inst: instance-id map; it is indexed as a 4-D tensor whose channel
                index is offset by the feature channel, so a single-channel
                map is expected (assumption from the indexing below).

        Returns:
            A tensor shaped like the network output in which every instance
            region holds its per-channel mean.
        """
        outputs = self.model(input)

        # instance-wise average pooling
        outputs_mean = outputs.clone()
        inst_list = np.unique(inst.cpu().numpy().astype(int))
        for i in inst_list:
            for b in range(input.size()[0]):
                indices = (inst[b:b+1] == int(i)).nonzero()  # n x 4
                # The id list is collected over the whole batch, so a given
                # instance may be absent from this element. Skip it rather
                # than averaging an empty selection.
                if indices.numel() == 0:
                    continue
                batch_idx = indices[:, 0] + b
                rows, cols = indices[:, 2], indices[:, 3]
                for j in range(self.output_nc):
                    chan_idx = indices[:, 1] + j
                    output_ins = outputs[batch_idx, chan_idx, rows, cols]
                    mean_feat = torch.mean(output_ins).expand_as(output_ins)
                    outputs_mean[batch_idx, chan_idx, rows, cols] = mean_feat
        return outputs_mean
class MultiscaleDiscriminator(nn.Module):
    """Set of ``num_D`` ``NLayerDiscriminator`` networks applied at successively halved resolutions.

    With ``getIntermFeat`` the per-layer sub-modules of each discriminator are
    registered individually (``scale<i>_layer<j>``) so that every intermediate
    activation can be returned; otherwise each discriminator is registered as
    one module (``layer<i>``).
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,
                 use_sigmoid=False, num_D=3, getIntermFeat=False):
        super(MultiscaleDiscriminator, self).__init__()
        self.num_D = num_D
        self.n_layers = n_layers
        self.getIntermFeat = getIntermFeat

        for scale in range(num_D):
            netD = NLayerDiscriminator(input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat)
            if not getIntermFeat:
                setattr(self, 'layer%d' % scale, netD.model)
                continue
            for depth in range(n_layers + 2):
                setattr(self, 'scale%d_layer%d' % (scale, depth), getattr(netD, 'model%d' % depth))

        self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)

    def singleD_forward(self, model, input):
        """Run one discriminator and return its output(s) as a list.

        With ``getIntermFeat`` ``model`` is a sequence of stages and the list
        holds the activation after every stage; otherwise ``model`` is a single
        callable and the list has one element.
        """
        if not self.getIntermFeat:
            return [model(input)]
        activations = []
        current = input
        for position in range(len(model)):
            current = model[position](current)
            activations.append(current)
        return activations

    def forward(self, input):
        """Return one result list per scale; the input is downsampled between scales."""
        num_D = self.num_D
        result = []
        input_downsampled = input
        for step in range(num_D):
            # discriminators are used in reverse registration order
            scale = num_D - 1 - step
            if self.getIntermFeat:
                model = [getattr(self, 'scale%d_layer%d' % (scale, depth))
                         for depth in range(self.n_layers + 2)]
            else:
                model = getattr(self, 'layer%d' % scale)
            result.append(self.singleD_forward(model, input_downsampled))
            if step != (num_D - 1):
                input_downsampled = self.downsample(input_downsampled)
        return result
# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator built from ``n_layers + 2`` groups of layers.

    Groups: one stride-2 conv + LeakyReLU, ``n_layers - 1`` stride-2
    conv/norm/LeakyReLU groups, one stride-1 conv/norm/LeakyReLU group and a
    final stride-1 conv producing a single channel. When ``use_sigmoid`` is set
    the ``Sigmoid`` is appended to that final group, so the number of groups is
    always ``n_layers + 2``. This matters with ``getIntermFeat``: ``forward``
    (and ``MultiscaleDiscriminator``) only iterate over ``n_layers + 2``
    groups, so a Sigmoid stored as an extra group would never be applied.
    ``Sigmoid`` has no parameters, so state-dict keys are unaffected.
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=False):
        super(NLayerDiscriminator, self).__init__()
        self.getIntermFeat = getIntermFeat
        self.n_layers = n_layers

        kw = 4
        padw = int(np.ceil((kw-1.0)/2))
        sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]]

        nf = ndf
        for n in range(1, n_layers):
            nf_prev = nf
            nf = min(nf * 2, 512)  # channel count is capped at 512
            sequence += [[
                nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
                norm_layer(nf), nn.LeakyReLU(0.2, True)
            ]]

        nf_prev = nf
        nf = min(nf * 2, 512)
        sequence += [[
            nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
            norm_layer(nf),
            nn.LeakyReLU(0.2, True)
        ]]

        final_group = [nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]
        if use_sigmoid:
            # keep the Sigmoid inside the last group so it is always executed
            final_group.append(nn.Sigmoid())
        sequence.append(final_group)

        if getIntermFeat:
            # one sub-module per group so intermediate activations can be read
            for n in range(len(sequence)):
                setattr(self, 'model'+str(n), nn.Sequential(*sequence[n]))
        else:
            sequence_stream = []
            for n in range(len(sequence)):
                sequence_stream += sequence[n]
            self.model = nn.Sequential(*sequence_stream)

    def forward(self, input):
        """Return the prediction map, or the list of all group outputs with ``getIntermFeat``."""
        if self.getIntermFeat:
            res = [input]
            for n in range(self.n_layers+2):
                model = getattr(self, 'model'+str(n))
                res.append(model(res[-1]))
            return res[1:]
        else:
            return self.model(input)
from torchvision import models
class Vgg19(torch.nn.Module):
    """Pretrained torchvision VGG-19 feature stack split into five consecutive slices.

    ``forward`` returns the activation at the end of every slice. Unless
    ``requires_grad`` is set, all parameters are frozen.
    """

    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()
        vgg_pretrained_features = models.vgg19(pretrained=True).features
        # slice<k> holds feature layers boundaries[k-1] .. boundaries[k]-1,
        # registered under their original indices
        boundaries = (0, 2, 7, 12, 21, 30)
        for number in range(1, len(boundaries)):
            stage = torch.nn.Sequential()
            for x in range(boundaries[number - 1], boundaries[number]):
                stage.add_module(str(x), vgg_pretrained_features[x])
            setattr(self, 'slice%d' % number, stage)
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        """Return the list of the five slice outputs, in network order."""
        out = []
        activation = X
        for number in range(1, 6):
            activation = getattr(self, 'slice%d' % number)(activation)
            out.append(activation)
        return out
================================================
FILE: models/pix2pixHD_model.py
================================================
import numpy as np
import torch
import os
from torch.autograd import Variable
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks
class Pix2PixHDModel(BaseModel):
def name(self):
    """Return the identifier string of this model class."""
    return 'Pix2PixHDModel'
def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss):
    """Build a function that keeps only the enabled entries of the five loss slots.

    The slots are, in order: G_GAN, G_GAN_Feat, G_VGG, D_real, D_fake. The
    GAN and discriminator slots are always kept; the feature-matching and VGG
    slots are kept according to the two flags.
    """
    keep = (True, use_gan_feat_loss, use_vgg_loss, True, True)

    def loss_filter(g_gan, g_gan_feat, g_vgg, d_real, d_fake):
        slots = (g_gan, g_gan_feat, g_vgg, d_real, d_fake)
        selected = []
        for value, enabled in zip(slots, keep):
            if enabled:
                selected.append(value)
        return selected

    return loss_filter
def initialize(self, opt):
    """Create the networks and, for training, the losses and optimizers.

    Builds the generator, the discriminator (training only) and the feature
    encoder (only when features are used and not loaded from disk), optionally
    restores saved weights, and sets up the loss functions and Adam optimizers.

    Raises:
        NotImplementedError: if an image pool is requested together with more
            than one GPU.
    """
    BaseModel.initialize(self, opt)
    if opt.resize_or_crop != 'none' or not opt.isTrain: # when training at full res this causes OOM
        torch.backends.cudnn.benchmark = True
    self.isTrain = opt.isTrain
    self.use_features = opt.instance_feat or opt.label_feat
    self.gen_features = self.use_features and not self.opt.load_features
    input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc

    ##### define networks
    # Generator network: label channels (+1 edge channel) (+ feature channels)
    netG_input_nc = input_nc
    if not opt.no_instance:
        netG_input_nc += 1
    if self.use_features:
        netG_input_nc += opt.feat_num
    self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG,
                                  opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers,
                                  opt.n_blocks_local, opt.norm, gpu_ids=self.gpu_ids)

    # Discriminator network: sees the conditioning input concatenated with an image
    if self.isTrain:
        use_sigmoid = opt.no_lsgan
        netD_input_nc = input_nc + opt.output_nc
        if not opt.no_instance:
            netD_input_nc += 1
        self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt.norm, use_sigmoid,
                                      opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids)

    ### Encoder network
    if self.gen_features:
        self.netE = networks.define_G(opt.output_nc, opt.feat_num, opt.nef, 'encoder',
                                      opt.n_downsample_E, norm=opt.norm, gpu_ids=self.gpu_ids)
    if self.opt.verbose:
        print('---------- Networks initialized -------------')

    # load networks
    if not self.isTrain or opt.continue_train or opt.load_pretrain:
        pretrained_path = '' if not self.isTrain else opt.load_pretrain
        self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path)
        if self.isTrain:
            self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path)
        if self.gen_features:
            self.load_network(self.netE, 'E', opt.which_epoch, pretrained_path)

    # set loss functions and optimizers
    if self.isTrain:
        if opt.pool_size > 0 and (len(self.gpu_ids)) > 1:
            raise NotImplementedError("Fake Pool Not Implemented for MultiGPU")
        self.fake_pool = ImagePool(opt.pool_size)
        self.old_lr = opt.lr

        # define loss functions
        self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss, not opt.no_vgg_loss)

        self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
        self.criterionFeat = torch.nn.L1Loss()
        if not opt.no_vgg_loss:
            self.criterionVGG = networks.VGGLoss(self.gpu_ids)

        # Names so we can breakout loss
        self.loss_names = self.loss_filter('G_GAN','G_GAN_Feat','G_VGG','D_real', 'D_fake')

        # initialize optimizers
        # optimizer G
        if opt.niter_fix_global > 0:
            # Only parameters whose name starts with 'model<n_local_enhancers>'
            # are optimised at first. The builtin set replaces the former
            # version check that fell back to the long-removed ``sets`` module.
            finetune_list = set()

            params_dict = dict(self.netG.named_parameters())
            params = []
            for key, value in params_dict.items():
                if key.startswith('model' + str(opt.n_local_enhancers)):
                    params += [value]
                    finetune_list.add(key.split('.')[0])
            print('------------- Only training the local enhancer network (for %d epochs) ------------' % opt.niter_fix_global)
            print('The layers that are finetuned are ', sorted(finetune_list))
        else:
            params = list(self.netG.parameters())
        if self.gen_features:
            params += list(self.netE.parameters())
        self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))

        # optimizer D
        params = list(self.netD.parameters())
        self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False):
    """Move the inputs to the GPU and assemble the conditioning tensor.

    When ``opt.label_nc`` is non-zero the label map is turned into a one-hot
    tensor; unless ``opt.no_instance`` is set, the edge map computed from the
    instance map is appended as an extra channel.

    Returns:
        ``(input_label, inst_map, real_image, feat_map)``.
    """
    opt = self.opt
    if opt.label_nc != 0:
        # create one-hot vector for label map
        dims = label_map.size()
        one_hot_shape = torch.Size((dims[0], opt.label_nc, dims[2], dims[3]))
        one_hot = torch.cuda.FloatTensor(one_hot_shape).zero_()
        input_label = one_hot.scatter_(1, label_map.data.long().cuda(), 1.0)
        if opt.data_type == 16:
            input_label = input_label.half()
    else:
        # the label map is used as is
        input_label = label_map.data.cuda()

    # get edges from instance map
    if not opt.no_instance:
        inst_map = inst_map.data.cuda()
        edge_map = self.get_edges(inst_map)
        input_label = torch.cat((input_label, edge_map), dim=1)
    input_label = Variable(input_label, volatile=infer)

    # real images for training
    if real_image is not None:
        real_image = Variable(real_image.data.cuda())

    if self.use_features:
        # get precomputed feature maps
        if opt.load_features:
            feat_map = Variable(feat_map.data.cuda())
        # with label features the label map takes the role of the instance map
        if opt.label_feat:
            inst_map = label_map.cuda()

    return input_label, inst_map, real_image, feat_map
def discriminate(self, input_label, test_image, use_pool=False):
    """Run the discriminator on ``input_label`` concatenated with the detached ``test_image``.

    With ``use_pool`` the concatenated pair is first passed through
    ``self.fake_pool.query``.
    """
    pair = torch.cat((input_label, test_image.detach()), dim=1)
    if use_pool:
        pair = self.fake_pool.query(pair)
    return self.netD.forward(pair)
def forward(self, label, inst, image, feat, infer=False):
    """Generate a fake image and compute the generator and discriminator losses.

    Returns:
        A two-element list: the losses selected by ``self.loss_filter`` (slot
        order G_GAN, G_GAN_Feat, G_VGG, D_real, D_fake), and the generated
        image when ``infer`` is true, otherwise ``None``.
    """
    opt = self.opt

    # Encode Inputs
    input_label, inst_map, real_image, feat_map = self.encode_input(label, inst, image, feat)

    # Fake Generation
    input_concat = input_label
    if self.use_features:
        if not opt.load_features:
            # features are computed from the real image instead of being loaded
            feat_map = self.netE.forward(real_image, inst_map)
        input_concat = torch.cat((input_label, feat_map), dim=1)
    fake_image = self.netG.forward(input_concat)

    # Fake Detection and Loss
    pred_fake_pool = self.discriminate(input_label, fake_image, use_pool=True)
    loss_D_fake = self.criterionGAN(pred_fake_pool, False)

    # Real Detection and Loss
    pred_real = self.discriminate(input_label, real_image)
    loss_D_real = self.criterionGAN(pred_real, True)

    # GAN loss (Fake Passability Loss); the fake image is not detached here
    pred_fake = self.netD.forward(torch.cat((input_label, fake_image), dim=1))
    loss_G_GAN = self.criterionGAN(pred_fake, True)

    # GAN feature matching loss over every output but the last of each discriminator
    loss_G_GAN_Feat = 0
    if not opt.no_ganFeat_loss:
        feat_weights = 4.0 / (opt.n_layers_D + 1)
        D_weights = 1.0 / opt.num_D
        scale = D_weights * feat_weights
        for d_index in range(opt.num_D):
            fake_feats = pred_fake[d_index]
            real_feats = pred_real[d_index]
            for f_index in range(len(fake_feats) - 1):
                distance = self.criterionFeat(fake_feats[f_index], real_feats[f_index].detach())
                loss_G_GAN_Feat += scale * distance * opt.lambda_feat

    # VGG feature matching loss
    loss_G_VGG = 0
    if not opt.no_vgg_loss:
        loss_G_VGG = self.criterionVGG(fake_image, real_image) * opt.lambda_feat

    # Only return the fake_B image if necessary to save BW
    losses = self.loss_filter(loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_D_real, loss_D_fake)
    return [losses, fake_image if infer else None]
def inference(self, label, inst, image=None):
    """Generate an image for ``label`` / ``inst`` without tracking gradients.

    Args:
        label: label map tensor.
        inst: instance map tensor.
        image: real image; needed only when features are used and
            ``opt.use_encoded_image`` is set.

    Returns:
        The generator output.

    Raises:
        ValueError: if ``opt.use_encoded_image`` is set while features are in
            use but no ``image`` was given.
    """
    # Encode Inputs
    image = Variable(image) if image is not None else None
    input_label, inst_map, real_image, _ = self.encode_input(Variable(label), Variable(inst), image, infer=True)

    # Fake Generation
    if self.use_features:
        if self.opt.use_encoded_image:
            if real_image is None:
                raise ValueError('use_encoded_image requires a real image to encode')
            # encode the real image to get feature map
            feat_map = self.netE.forward(real_image, inst_map)
        else:
            # sample clusters from precomputed features
            feat_map = self.sample_features(inst_map)
        input_concat = torch.cat((input_label, feat_map), dim=1)
    else:
        input_concat = input_label

    # Disable autograd whenever the running torch provides ``no_grad``. The
    # previous check on the version string only matched 0.4.x, so every later
    # release built a graph during inference.
    if hasattr(torch, 'no_grad'):
        with torch.no_grad():
            fake_image = self.netG.forward(input_concat)
    else:
        fake_image = self.netG.forward(input_concat)
    return fake_image
def sample_features(self, inst):
    """Build a feature map by drawing one random cluster centre per instance.

    The clusters are read from ``<checkpoints_dir>/<name>/<cluster_path>``,
    a pickled dict saved with ``np.save`` that maps a label to an array with
    one row per cluster. Instances whose label has no entry keep zeros.

    Args:
        inst: 4-D instance map; ids of 1000 or more are mapped to their label
            with ``id // 1000``.

    Returns:
        A tensor of shape ``(N, opt.feat_num, H, W)``, converted to half
        precision when ``opt.data_type == 16``.
    """
    # read precomputed feature clusters
    cluster_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, self.opt.cluster_path)
    # NOTE: the file stores a pickled dict, which NumPy only loads when
    # allow_pickle=True (the default became False in NumPy 1.16.3).
    # Unpickling executes arbitrary code, so only load trusted cluster files.
    features_clustered = np.load(cluster_path, encoding='latin1', allow_pickle=True).item()

    # randomly sample from the feature clusters
    inst_np = inst.cpu().numpy().astype(int)
    # start from zeros so that labels without clusters do not keep whatever
    # happened to be in the freshly allocated memory
    feat_map = self.Tensor(inst.size()[0], self.opt.feat_num, inst.size()[2], inst.size()[3]).zero_()
    for i in np.unique(inst_np):
        label = i if i < 1000 else i//1000
        if label in features_clustered:
            feat = features_clustered[label]
            cluster_idx = np.random.randint(0, feat.shape[0])

            idx = (inst == int(i)).nonzero()
            for k in range(self.opt.feat_num):
                feat_map[idx[:,0], idx[:,1] + k, idx[:,2], idx[:,3]] = feat[cluster_idx, k]
    if self.opt.data_type==16:
        feat_map = feat_map.half()
    return feat_map
def encode_features(self, image, inst):
    """Encode ``image`` and collect one feature row per instance found in ``inst``.

    For every instance id the encoder output is read at the middle entry of
    the instance's pixel list. Each row holds the ``opt.feat_num`` feature
    values followed by ``pixel_count / (h * w // 32)``.

    Returns:
        A dict mapping each label in ``range(opt.label_nc)`` to an array of
        shape ``(rows, feat_num + 1)``.
    """
    image = Variable(image.cuda(), volatile=True)
    feat_num = self.opt.feat_num
    h, w = inst.size()[2], inst.size()[3]
    block_num = 32
    # ``volatile`` is ignored from torch 0.4 on; use no_grad there so that
    # no autograd graph is kept for the encoder pass.
    if hasattr(torch, 'no_grad'):
        with torch.no_grad():
            feat_map = self.netE.forward(image, inst.cuda())
    else:
        feat_map = self.netE.forward(image, inst.cuda())
    inst_np = inst.cpu().numpy().astype(int)
    feature = {}
    for i in range(self.opt.label_nc):
        feature[i] = np.zeros((0, feat_num+1))
    for i in np.unique(inst_np):
        label = i if i < 1000 else i//1000
        idx = (inst == int(i)).nonzero()
        num = idx.size()[0]
        # representative pixel: the middle entry of the instance's index list
        idx = idx[num//2,:]
        val = np.zeros((1, feat_num+1))
        for k in range(feat_num):
            entry = feat_map[idx[0], idx[1] + k, idx[2], idx[3]]
            # 0-dim tensors (torch >= 0.4) cannot be indexed with [0]; fall
            # back to the old accessor only when ``item`` is unavailable.
            val[0, k] = entry.item() if hasattr(entry, 'item') else entry.data[0]
        val[0, feat_num] = float(num) / (h * w // block_num)
        feature[label] = np.append(feature[label], val, axis=0)
    return feature
def get_edges(self, t):
    """Return a 0/1 map marking pixels whose left/right/up/down neighbour differs in ``t``.

    Args:
        t: 4-D tensor; neighbours are compared along the last two dimensions.

    Returns:
        A tensor of the same shape as ``t`` on the same device, in half
        precision when ``opt.data_type == 16`` and float otherwise.
    """
    # Allocate on the device of ``t`` instead of hard-coding CUDA.
    edge = t.new(t.size()).zero_().byte()
    # Compare each pixel with its horizontal and vertical neighbour once.
    # ``.byte()`` keeps the OR below between equal dtypes on torch versions
    # where comparisons return bool tensors.
    horizontal = (t[:,:,:,1:] != t[:,:,:,:-1]).byte()
    vertical = (t[:,:,1:,:] != t[:,:,:-1,:]).byte()
    # both pixels on either side of a boundary are marked
    edge[:,:,:,1:] = edge[:,:,:,1:] | horizontal
    edge[:,:,:,:-1] = edge[:,:,:,:-1] | horizontal
    edge[:,:,1:,:] = edge[:,:,1:,:] | vertical
    edge[:,:,:-1,:] = edge[:,:,:-1,:] | vertical
    if self.opt.data_type==16:
        return edge.half()
    else:
        return edge.float()
def save(self, which_epoch):
    """Save the generator, the discriminator and, if features are generated, the encoder."""
    labels = ['G', 'D']
    if self.gen_features:
        labels.append('E')
    for label in labels:
        # the networks are stored as self.netG / self.netD / self.netE
        self.save_network(getattr(self, 'net' + label), label, which_epoch, self.gpu_ids)
def update_fixed_params(self):
    """Replace the generator optimizer with one covering all generator parameters.

    Called after the global generator has been kept fixed for a number of
    iterations, so that it is fine-tuned as well from now on.
    """
    trainable = list(self.netG.parameters())
    if self.gen_features:
        trainable.extend(self.netE.parameters())
    adam_betas = (self.opt.beta1, 0.999)
    self.optimizer_G = torch.optim.Adam(trainable, lr=self.opt.lr, betas=adam_betas)
    if self.opt.verbose:
        print('------------ Now also finetuning global generator -----------')
def update_learning_rate(self):
    """Lower the learning rate of both optimizers by ``opt.lr / opt.niter_decay``."""
    step = self.opt.lr / self.opt.niter_decay
    lr = self.old_lr - step
    # discriminator first, then generator
    for optimizer in (self.optimizer_D, self.optimizer_G):
        for group in optimizer.param_groups:
            group['lr'] = lr
    if self.opt.verbose:
        print('update learning rate: %f -> %f' % (self.old_lr, lr))
    self.old_lr = lr
class InferenceModel(Pix2PixHDModel):
    """Pix2PixHDModel whose ``forward`` runs ``inference`` on a ``(label, inst)`` pair."""

    def forward(self, inp):
        """Unpack ``inp`` into label and instance map and return ``self.inference`` of them."""
        # inp must unpack into exactly two items
        label, inst = inp
        return self.inference(label, inst)
================================================
FILE: models/ui_model.py
================================================
import torch
from torch.autograd import Variable
from collections import OrderedDict
import numpy as np
import os
from PIL import Image
import util.util as util
from .base_model import BaseModel
from . import networks
class UIModel(BaseModel):
def name(self):
    """Return the identifier string of this model class."""
    return 'UIModel'
def initialize(self, opt):
    """Create the generator for interactive use and load its saved weights.

    Only valid outside training (asserted).
    """
    assert(not opt.isTrain)
    BaseModel.initialize(self, opt)
    self.use_features = opt.instance_feat or opt.label_feat

    # input channels: labels, one edge channel unless instances are disabled,
    # and the feature channels when features are used
    edge_channels = 0 if opt.no_instance else 1
    feature_channels = opt.feat_num if self.use_features else 0
    netG_input_nc = opt.label_nc + edge_channels + feature_channels

    self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG,
                                  opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers,
                                  opt.n_blocks_local, opt.norm, gpu_ids=self.gpu_ids)
    self.load_network(self.netG, 'G', opt.which_epoch)

    print('---------- Networks initialized -------------')
def toTensor(self, img, normalize=False):
    """Convert an image to a float tensor of shape ``(1, channels, height, width)``.

    Args:
        img: array-convertible image exposing ``size`` as ``(width, height)``
            and ``mode`` with one character per channel (as PIL images do).
        normalize: when true, map values from ``[0, 255]`` to ``[-1, 1]``.

    Returns:
        A float tensor holding the raw values, or the normalised values.
    """
    # np.asarray copies only when needed. ``np.array(..., copy=False)``
    # raises on NumPy 2 whenever the dtype conversion requires a copy.
    pixels = np.asarray(img, dtype=np.int32)
    tensor = torch.from_numpy(pixels)
    tensor = tensor.view(1, img.size[1], img.size[0], len(img.mode))
    # (1, H, W, C) -> (1, C, H, W)
    tensor = tensor.permute(0, 3, 1, 2).contiguous()
    if normalize:
        return (tensor.float()/255.0 - 0.5) / 0.5
    return tensor.float()
def load_image(self, label_path, inst_path, feat_path):
    """Load a label map, an instance map and feature clusters, and build the network input.

    Sets ``label_map``, ``input_label``, ``net_input``, ``features_clustered``,
    ``object_map``, ``feat_map`` and ``cluster_indices`` (plus ``inst_map`` and
    ``edge_map`` unless ``opt.no_instance``), and stores ``*_original`` copies
    of the editable maps.

    Args:
        label_path: path of the label image; converted to mode 'L' when the
            path contains 'face'.
        inst_path: path of the instance image (read unless ``opt.no_instance``).
        feat_path: path of the ``.npy`` file with the pickled cluster dict.
    """
    opt = self.opt
    # read label map
    label_img = Image.open(label_path)
    if label_path.find('face') != -1:
        label_img = label_img.convert('L')
    ow, oh = label_img.size
    # resize to width opt.loadSize, keeping the aspect ratio
    w = opt.loadSize
    h = int(w * oh / ow)
    label_img = label_img.resize((w, h), Image.NEAREST)
    label_map = self.toTensor(label_img)

    # onehot vector input for label map
    self.label_map = label_map.cuda()
    oneHot_size = (1, opt.label_nc, h, w)
    input_label = self.Tensor(torch.Size(oneHot_size)).zero_()
    self.input_label = input_label.scatter_(1, label_map.long().cuda(), 1.0)

    # read instance map
    if not opt.no_instance:
        inst_img = Image.open(inst_path)
        inst_img = inst_img.resize((w, h), Image.NEAREST)
        self.inst_map = self.toTensor(inst_img).cuda()
        self.edge_map = self.get_edges(self.inst_map)
        self.net_input = Variable(torch.cat((self.input_label, self.edge_map), dim=1), volatile=True)
    else:
        self.net_input = Variable(self.input_label, volatile=True)

    # NOTE: the file stores a pickled dict, which NumPy only loads when
    # allow_pickle=True (the default became False in NumPy 1.16.3);
    # encoding='latin1' matches how Pix2PixHDModel.sample_features reads it.
    # Unpickling executes arbitrary code, so only load trusted files.
    self.features_clustered = np.load(feat_path, encoding='latin1', allow_pickle=True).item()
    self.object_map = self.inst_map if opt.instance_feat else self.label_map

    object_np = self.object_map.cpu().numpy().astype(int)
    self.feat_map = self.Tensor(1, opt.feat_num, h, w).zero_()
    self.cluster_indices = np.zeros(self.opt.label_nc, np.uint8)
    for i in np.unique(object_np):
        label = i if i < 1000 else i//1000
        if label in self.features_clustered:
            feat = self.features_clustered[label]
            # seeding with the object id makes the drawn cluster reproducible
            np.random.seed(i+1)
            cluster_idx = np.random.randint(0, feat.shape[0])
            self.cluster_indices[label] = cluster_idx
            idx = (self.object_map == i).nonzero()
            self.set_features(idx, feat, cluster_idx)

    # keep pristine copies of the editable maps
    self.net_input_original = self.net_input.clone()
    self.label_map_original = self.label_map.clone()
    self.feat_map_original = self.feat_map.clone()
    if not opt.no_instance:
        self.inst_map_original = self.inst_map.clone()
def reset(self):
    """Restore the current and previous maps from the stored ``*_original`` copies."""
    restored = ['net_input', 'label_map', 'feat_map']
    if not self.opt.no_instance:
        restored.append('inst_map')
    for attr in restored:
        # the current and the "previous" state share one fresh clone
        snapshot = getattr(self, attr + '_original').clone()
        setattr(self, attr, snapshot)
        setattr(self, attr + '_prev', snapshot)
    if self.opt.instance_feat:
        self.object_map = self.inst_map
    else:
        self.object_map = self.label_map
def undo(self):
    """Make the stored ``*_prev`` maps the current ones again."""
    reverted = ['net_input', 'label_map', 'feat_map']
    if not self.opt.no_instance:
        reverted.append('inst_map')
    for attr in reverted:
        setattr(self, attr, getattr(self, attr + '_prev'))
    if self.opt.instance_feat:
        self.object_map = self.inst_map
    else:
        self.object_map = self.label_map
# get boundary map from instance map
def get_edges(self, t):
    """Return a float 0/1 map marking pixels whose left/right/up/down neighbour differs in ``t``.

    Args:
        t: 4-D tensor; neighbours are compared along the last two dimensions.

    Returns:
        A float tensor of the same shape as ``t`` on the same device.
    """
    # Allocate on the device of ``t`` instead of hard-coding CUDA.
    edge = t.new(t.size()).zero_().byte()
    # Compare each pixel with its horizontal and vertical neighbour once.
    # ``.byte()`` keeps the OR below between equal dtypes on torch versions
    # where comparisons return bool tensors.
    horizontal = (t[:,:,:,1:] != t[:,:,:,:-1]).byte()
    vertical = (t[:,:,1:,:] != t[:,:,:-1,:]).byte()
    # both pixels on either side of a boundary are marked
    edge[:,:,:,1:] = edge[:,:,:,1:] | horizontal
    edge[:,:,:,:-1] = edge[:,:,:,:-1] | horizontal
    edge[:,:,1:,:] = edge[:,:,1:,:] | vertical
    edge[:,:,:-1,:] = edge[:,:,:-1,:] | vertical
    return edge.float()
# change the label at the source position to the label at the target position
def change_labels(self, click_src, click_tgt):
    """Give the instance under ``click_src`` the label found under ``click_tgt`` and re-render.

    Updates ``label_map``, ``inst_map``, ``net_input`` and ``feat_map`` in
    place (after backing up the current state) and stores the regenerated
    image in ``self.fake_image``.

    Args:
        click_src: ``(y, x)`` position inside the instance to relabel.
        click_tgt: ``(y, x)`` position whose label and instance id are used.
    """
    y_src, x_src = click_src[0], click_src[1]
    y_tgt, x_tgt = click_tgt[0], click_tgt[1]
    label_src = int(self.label_map[0, 0, y_src, x_src])
    inst_src = self.inst_map[0, 0, y_src, x_src]
    label_tgt = int(self.label_map[0, 0, y_tgt, x_tgt])
    inst_tgt = self.inst_map[0, 0, y_tgt, x_tgt]

    # every pixel that belongs to the clicked source instance
    idx_src = (self.inst_map == inst_src).nonzero()
    # need to change 3 things: label map, instance map, and feature map
    # NOTE(review): `.shape` of an n x 4 index tensor is a non-empty tuple and
    # therefore truthy even for n == 0 — confirm whether an emptiness check was intended
    if idx_src.shape:
        # backup current maps
        self.backup_current_state()

        # change both the label map and the network input
        self.label_map[idx_src[:,0], idx_src[:,1], idx_src[:,2], idx_src[:,3]] = label_tgt
        # switch the one-hot channel from the source label to the target label
        self.net_input[idx_src[:,0], idx_src[:,1] + label_src, idx_src[:,2], idx_src[:,3]] = 0
        self.net_input[idx_src[:,0], idx_src[:,1] + label_tgt, idx_src[:,2], idx_src[:,3]] = 1

        # update the instance map (and the network input)
        if inst_tgt > 1000:
            # if different instances have different ids, give the new object a new id
            tgt_indices = (self.inst_map > label_tgt * 1000) & (self.inst_map < (label_tgt+1) * 1000)
            inst_tgt = self.inst_map[tgt_indices].max() + 1
        self.inst_map[idx_src[:,0], idx_src[:,1], idx_src[:,2], idx_src[:,3]] = inst_tgt
        # the last input channel holds the edge map
        self.net_input[:,-1,:,:] = self.get_edges(self.inst_map)

        # also copy the source features to the target position
        # (the lookup runs after the relabelling above, so it also matches the
        # pixels that were just given inst_tgt)
        idx_tgt = (self.inst_map == inst_tgt).nonzero()
        if idx_tgt.shape:
            self.copy_features(idx_src, idx_tgt[0,:])

    self.fake_image = util.tensor2im(self.single_forward(self.net_input, self.feat_map))
# add strokes of target label in the image
def add_strokes(self, click_src, label_tgt, bw, save):
    """Paint a ``bw`` x ``bw`` square of label ``label_tgt`` around ``click_src`` and re-render.

    Args:
        click_src: ``(y, x)`` centre of the brush.
        label_tgt: label written into the label map and the instance map.
        bw: brush width in pixels.
        save: when true, back up the current state before editing.
    """
    # get the region of the new strokes (bw is the brush width)
    size = self.net_input.size()
    h, w = size[2], size[3]
    # one (batch, channel, y, x) row per brush pixel; coordinates are clamped to the image
    idx_src = torch.LongTensor(bw**2, 4).fill_(0)
    for i in range(bw):
        idx_src[i*bw:(i+1)*bw, 2] = min(h-1, max(0, click_src[0]-bw//2 + i))
        for j in range(bw):
            idx_src[i*bw+j, 3] = min(w-1, max(0, click_src[1]-bw//2 + j))
    idx_src = idx_src.cuda()

    # again, need to update 3 things
    # NOTE(review): `.shape` of a 2-D tensor is a non-empty tuple and therefore always truthy
    if idx_src.shape:
        # backup current maps
        if save:
            self.backup_current_state()

        # update the label map (and the network input) in the stroke region
        self.label_map[idx_src[:,0], idx_src[:,1], idx_src[:,2], idx_src[:,3]] = label_tgt
        # clear every one-hot label channel, then set the target one
        for k in range(self.opt.label_nc):
            self.net_input[idx_src[:,0], idx_src[:,1] + k, idx_src[:,2], idx_src[:,3]] = 0
        self.net_input[idx_src[:,0], idx_src[:,1] + label_tgt, idx_src[:,2], idx_src[:,3]] = 1

        # update the instance map (and the network input)
        self.inst_map[idx_src[:,0], idx_src[:,1], idx_src[:,2], idx_src[:,3]] = label_tgt
        # the last input channel holds the edge map
        self.net_input[:,-1,:,:] = self.get_edges(self.inst_map)

        # also update the features if available
        if self.opt.instance_feat:
            feat = self.features_clustered[label_tgt]
            #np.random.seed(label_tgt+1)
            #cluster_idx = np.random.randint(0, feat.shape[0])
            # reuse the cluster index recorded for this label
            cluster_idx = self.cluster_indices[label_tgt]
            self.set_features(idx_src, feat, cluster_idx)

    self.fake_image = util.tensor2im(self.single_forward(self.net_input, self.feat_map))
# add an object to the clicked position with selected style
def add_objects(self, click_src, label_tgt, mask, style_id=0):
    """Stamp ``mask`` at ``click_src`` as an object of label ``label_tgt`` and re-render.

    Args:
        click_src: ``(y, x)`` offset added to the mask coordinates.
        label_tgt: label written into the label map and the instance map.
        mask: array transposed here from axis order (0, 1, 2) to (2, 0, 1);
            presumably height x width x 1 — TODO confirm against the caller.
        style_id: row of ``self.feat`` used for the new object's features.
    """
    y, x = click_src[0], click_src[1]
    # add a leading batch axis so the non-zero indices have four columns
    mask = np.transpose(mask, (2, 0, 1))[np.newaxis,...]
    idx_src = torch.from_numpy(mask).cuda().nonzero()
    # shift the mask coordinates to the clicked position
    idx_src[:,2] += y
    idx_src[:,3] += x

    # backup current maps
    self.backup_current_state()

    # update label map
    self.label_map[idx_src[:,0], idx_src[:,1], idx_src[:,2], idx_src[:,3]] = label_tgt
    # clear every one-hot label channel, then set the target one
    for k in range(self.opt.label_nc):
        self.net_input[idx_src[:,0], idx_src[:,1] + k, idx_src[:,2], idx_src[:,3]] = 0
    self.net_input[idx_src[:,0], idx_src[:,1] + label_tgt, idx_src[:,2], idx_src[:,3]] = 1

    # update instance map
    self.inst_map[idx_src[:,0], idx_src[:,1], idx_src[:,2], idx_src[:,3]] = label_tgt
    # the last input channel holds the edge map
    self.net_input[:,-1,:,:] = self.get_edges(self.inst_map)

    # update feature map
    # NOTE(review): self.feat is assigned in style_forward; it must have been
    # set by an earlier click before this method is used
    self.set_features(idx_src, self.feat, style_id)

    self.fake_image = util.tensor2im(self.single_forward(self.net_input, self.feat_map))
def single_forward(self, net_input, feat_map):
    """Run the generator on ``net_input`` concatenated with ``feat_map`` along dim 1.

    Returns the raw data of the single image when the batch size is 1, and of
    the whole batch otherwise.
    """
    generator_input = torch.cat((net_input, feat_map), dim=1)
    fake_image = self.netG.forward(generator_input)
    batch_size = fake_image.size()[0]
    return fake_image.data[0] if batch_size == 1 else fake_image.data
# generate all outputs for different styles
def style_forward(self, click_pt, style_id=-1):
    """Render the image, or the style variants of the object under ``click_pt``.

    Args:
        click_pt: ``(y, x)`` position of the clicked object, or ``None`` to
            re-render the current state (crop and mask are then cleared).
        style_id: ``-1`` renders ``opt.multiple_output`` cropped previews, one
            per cluster row, into the list ``self.fake_image``; any other
            value applies that cluster row to the object, records it in
            ``self.cluster_indices`` and renders the full image.
    """
    if click_pt is None:
        self.fake_image = util.tensor2im(self.single_forward(self.net_input, self.feat_map))
        self.crop = None
        self.mask = None
    else:
        instToChange = int(self.object_map[0, 0, click_pt[0], click_pt[1]])
        self.instToChange = instToChange
        # ids of 1000 or more are mapped to their label with id // 1000
        label = instToChange if instToChange < 1000 else instToChange//1000
        self.feat = self.features_clustered[label]
        self.fake_image = []
        self.mask = self.object_map == instToChange
        idx = self.mask.nonzero()
        # sets self.crop and crops self.mask to that region
        self.get_crop_region(idx)
        # NOTE(review): `.size()` of a 2-D tensor is a non-empty tuple and therefore always truthy
        if idx.size():
            if style_id == -1:
                (min_y, min_x, max_y, max_x) = self.crop
                ### original
                for cluster_idx in range(self.opt.multiple_output):
                    self.set_features(idx, self.feat, cluster_idx)
                    fake_image = self.single_forward(self.net_input, self.feat_map)
                    fake_image = util.tensor2im(fake_image[:,min_y:max_y,min_x:max_x])
                    self.fake_image.append(fake_image)
                # The string literal below is disabled code kept as a no-op expression statement.
                """### To speed up previewing different style results, either crop or downsample the label maps
if instToChange > 1000:
(min_y, min_x, max_y, max_x) = self.crop
### crop
_, _, h, w = self.net_input.size()
offset = 512
y_start, x_start = max(0, min_y-offset), max(0, min_x-offset)
y_end, x_end = min(h, (max_y + offset)), min(w, (max_x + offset))
y_region = slice(y_start, y_start+(y_end-y_start)//16*16)
x_region = slice(x_start, x_start+(x_end-x_start)//16*16)
net_input = self.net_input[:,:,y_region,x_region]
for cluster_idx in range(self.opt.multiple_output):
self.set_features(idx, self.feat, cluster_idx)
fake_image = self.single_forward(net_input, self.feat_map[:,:,y_region,x_region])
fake_image = util.tensor2im(fake_image[:,min_y-y_start:max_y-y_start,min_x-x_start:max_x-x_start])
self.fake_image.append(fake_image)
else:
### downsample
(min_y, min_x, max_y, max_x) = [crop//2 for crop in self.crop]
net_input = self.net_input[:,:,::2,::2]
size = net_input.size()
net_input_batch = net_input.expand(self.opt.multiple_output, size[1], size[2], size[3])
for cluster_idx in range(self.opt.multiple_output):
self.set_features(idx, self.feat, cluster_idx)
feat_map = self.feat_map[:,:,::2,::2]
if cluster_idx == 0:
feat_map_batch = feat_map
else:
feat_map_batch = torch.cat((feat_map_batch, feat_map), dim=0)
fake_image_batch = self.single_forward(net_input_batch, feat_map_batch)
for i in range(self.opt.multiple_output):
self.fake_image.append(util.tensor2im(fake_image_batch[i,:,min_y:max_y,min_x:max_x]))"""
            else:
                self.set_features(idx, self.feat, style_id)
                self.cluster_indices[label] = style_id
                self.fake_image = util.tensor2im(self.single_forward(self.net_input, self.feat_map))
def backup_current_state(self):
    """Store clones of the network input and the label, instance and feature maps as ``*_prev``."""
    for attr in ('net_input', 'label_map', 'inst_map', 'feat_map'):
        setattr(self, attr + '_prev', getattr(self, attr).clone())
# crop the ROI and get the mask of the object
def get_crop_region(self, idx):
    """Compute the crop box around the indexed pixels and crop ``self.mask`` to it.

    The box is the bounding box of ``idx`` (columns 2 and 3 are y and x). An
    extent smaller than 128 pixels is replaced by a 128-pixel window centred
    on it and clamped to the image. The result is stored in ``self.crop`` as
    ``(min_y, min_x, max_y, max_x)``.
    """
    size = self.net_input.size()
    h, w = size[2], size[3]
    crop_min = 128

    def widen(low, high, limit):
        # enlarge a too-small extent to crop_min pixels, staying inside the image
        if high - low < crop_min:
            low = max(0, (high + low) // 2 - crop_min // 2)
            high = min(limit - 1, low + crop_min)
        return low, high

    ys, xs = idx[:,2], idx[:,3]
    min_y, max_y = widen(ys.min(), ys.max(), h)
    min_x, max_x = widen(xs.min(), xs.max(), w)

    self.crop = (min_y, min_x, max_y, max_x)
    self.mask = self.mask[:,:, min_y:max_y, min_x:max_x]
# update the feature map once a new object is added or the label is changed
def update_features(self, cluster_idx, mask=None, click_pt=None):
    """Write cluster row ``cluster_idx`` of ``self.feat`` into the feature map.

    The current feature map is first saved as ``feat_map_prev``. Without
    ``mask`` the pixels of the object ``self.instToChange`` are updated;
    with ``mask`` the mask is shifted to ``click_pt`` and its non-zero pixels
    are updated.
    """
    self.feat_map_prev = self.feat_map.clone()

    if mask is None:
        # changing the label of an existing object
        idx = (self.object_map == self.instToChange).nonzero()
    else:
        # adding a new object
        y, x = click_pt[0], click_pt[1]
        mask = np.transpose(mask, (2,0,1))[np.newaxis,...]
        idx = torch.from_numpy(mask).cuda().nonzero()
        idx[:,2] += y
        idx[:,3] += x

    # update feature map
    self.set_features(idx, self.feat, cluster_idx)
# set the class features to the target feature
def set_features(self, idx, feat, cluster_idx):
    """Fill every feature channel at the indexed pixels with row ``cluster_idx`` of ``feat``."""
    batch, base_channel, rows, cols = idx[:,0], idx[:,1], idx[:,2], idx[:,3]
    for channel in range(self.opt.feat_num):
        self.feat_map[batch, base_channel + channel, rows, cols] = feat[cluster_idx, channel]
# copy the features at the target position to the source position
def copy_features(self, idx_src, idx_tgt):
for k in range(self.opt.feat_num):
val = self.feat_map[idx_tgt[0], idx_tgt[1] + k, idx_tgt[2], idx_tgt[3]]
self.feat_map[idx_src[:,0], idx_src[:,1] + k, idx_src[:,2], idx_src[:,3]] = val
def get_current_visuals(self, getLabel=False):
    """Return an ordered dict with the current fake image, the mask and optionally the label image.

    The mask, when present, is converted from the first batch element to a
    ``uint8`` array with the channel axis moved last.
    """
    mask = self.mask
    if mask is not None:
        mask = np.transpose(mask[0].cpu().float().numpy(), (1,2,0)).astype(np.uint8)

    visuals = OrderedDict()
    visuals['fake_image'] = self.fake_image
    visuals['mask'] = mask
    if getLabel: # only output label map if needed to save bandwidth
        visuals['label'] = util.tensor2label(self.net_input.data[0], self.opt.label_nc)
    return visuals
================================================
FILE: options/__init__.py
================================================
================================================
FILE: options/base_options.py
================================================
import argparse
import os
from util import util
import torch
class BaseOptions():
def __init__(self):
    """Create the argument parser and mark the options as not yet initialized."""
    # parser on which initialize() registers the command-line arguments
    self.parser = argparse.ArgumentParser()
    # presumably switched on once initialize() has run — TODO confirm against the rest of the class
    self.initialized = False
def initialize(self):
# experiment specifics
self.parser.add_argument('--name', type=str, default='label2city', help='name of the experiment. It decides where to store samples and models')
self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
self.parser.add_argument('--model', type=str, default='pix2pixHD', help='which model to use')
self.parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization')
self.parser.add_argument('--use_dropout', action='store_true', help='use dropout for the generator')
self.parser.add_argument('--data_type', default=32, type=int, choices=[8, 16, 32], help="Supported data type i.e. 8, 16, 32 bit")
self.parser.add_argument('--verbose', action='store_true', default=False, help='toggles verbose')
self.parser.add_argument('--fp16', action='store_true', default=False, help='train with AMP')
self.parser.add_argument('--local_rank', type=int, default=0, help='local rank for distributed training')
# input/output sizes
self.parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
self.parser.add_argument('--loadSize', type=int, default=1024, help='scale images to this size')
self.parser.add_argument('--fineSize', type=int, default=512, help='then crop to this size')
self.parser.add_argument('--label_nc', type=int, default=35, help='# of input label channels')
self.parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels')
self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')
# for setting inputs
self.parser.add_argument('--dataroot', type=str, default='./datasets/cityscapes/')
self.parser.add_argument('--resize_or_crop', type=str, default='scale_width', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]')
self.parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
self.parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data argumentation')
self.parser.add_argument('--nThreads', default=2, type=int, help='# threads for loading data')
self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
# for displays
self.parser.add_argument('--display_winsize', type=int, default=512, help='display window size')
self.parser.add_argument('--tf_log', action='store_true', help='if specified, use tensorboard logging. Requires tensorflow installed')
# for generator
self.parser.add_argument('--netG', type=str, default='global', help='selects model to use for netG')
self.parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
self.parser.add_argument('--n_downsample_global', type=int, default=4, help='number of downsampling layers in netG')
self.parser.add_argument('--n_blocks_global', type=int, default=9, help='number of residual blocks in the global generator network')
self.parser.add_argument('--n_blocks_local', type=int, default=3, help='number of residual blocks in the local enhancer network')
self.parser.add_argument('--n_local_enhancers', type=int, default=1, help='number of local enhancers to use')
self.parser.add_argument('--niter_fix_global', type=int, default=0, help='number of epochs that we only train the outmost local enhancer')
# for instance-wise features
self.parser.add_argument('--no_instance', action='store_true', help='if specified, do *not* add instance map as input')
self.parser.add_argument('--instance_feat', action='store_true', help='if specified, add encoded instance features as input')
self.parser.add_argument('--label_feat', action='store_true', help='if specified, add encoded label features as input')
self.parser.add_argument('--feat_num', type=int, default=3, help='vector length for encoded features')
self.parser.add_argument('--load_features', action='store_true', help='if specified, load precomputed feature maps')
self.parser.add_argument('--n_downsample_E', type=int, default=4, help='# of downsampling layers in encoder')
self.parser.add_argument('--nef', type=int, default=16, help='# of encoder filters in the first conv layer')
self.parser.add_argument('--n_clusters', type=int, default=10, help='number of clusters for features')
self.initialized = True
def parse(self, save=True):
if not self.initialized:
self.initialize()
self.opt = self.parser.parse_args()
self.opt.isTrain = self.isTrain # train or test
str_ids = self.opt.gpu_ids.split(',')
self.opt.gpu_ids = []
for str_id in str_ids:
id = int(str_id)
if id >= 0:
self.opt.gpu_ids.append(id)
# set gpu ids
if len(self.opt.gpu_ids) > 0:
torch.cuda.set_device(self.opt.gpu_ids[0])
args = vars(self.opt)
print('------------ Options -------------')
for k, v in sorted(args.items()):
print('%s: %s' % (str(k), str(v)))
print('-------------- End ----------------')
# save to the disk
expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name)
util.mkdirs(expr_dir)
if save and not self.opt.continue_train:
file_name = os.path.join(expr_dir, 'opt.txt')
with open(file_name, 'wt') as opt_file:
opt_file.write('------------ Options -------------\n')
for k, v in sorted(args.items()):
opt_file.write('%s: %s\n' % (str(k), str(v)))
opt_file.write('-------------- End ----------------\n')
return self.opt
================================================
FILE: options/test_options.py
================================================
from .base_options import BaseOptions
class TestOptions(BaseOptions):
    """Options used only at test time, added on top of the shared ones."""

    def initialize(self):
        """Register the test-only arguments and mark the options as non-training."""
        BaseOptions.initialize(self)
        add = self.parser.add_argument

        add('--ntest', type=int, default=float("inf"), help='# of test examples.')
        add('--results_dir', type=str, default='./results/', help='saves results here.')
        add('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
        add('--phase', type=str, default='test', help='train, val, test, etc')
        add('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
        add('--how_many', type=int, default=50, help='how many test images to run')
        add('--cluster_path', type=str, default='features_clustered_010.npy', help='the path for clustered results of encoded features')
        add('--use_encoded_image', action='store_true', help='if specified, encode the real image to get the feature map')

        # ONNX export / TensorRT execution
        add("--export_onnx", type=str, help="export ONNX model to a given file")
        add("--engine", type=str, help="run serialized TRT engine")
        add("--onnx", type=str, help="run ONNX model via TRT")

        self.isTrain = False
================================================
FILE: options/train_options.py
================================================
from .base_options import BaseOptions
class TrainOptions(BaseOptions):
    """Options used only at training time, added on top of the shared ones."""

    def initialize(self):
        """Register the training-only arguments and mark the options as training."""
        BaseOptions.initialize(self)
        add = self.parser.add_argument

        # for displays
        add('--display_freq', type=int, default=100, help='frequency of showing training results on screen')
        add('--print_freq', type=int, default=100, help='frequency of showing training results on console')
        add('--save_latest_freq', type=int, default=1000, help='frequency of saving the latest results')
        add('--save_epoch_freq', type=int, default=10, help='frequency of saving checkpoints at the end of epochs')
        add('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
        add('--debug', action='store_true', help='only do one epoch and displays at each iteration')

        # for training
        add('--continue_train', action='store_true', help='continue training: load the latest model')
        add('--load_pretrain', type=str, default='', help='load the pretrained model from the specified location')
        add('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
        add('--phase', type=str, default='train', help='train, val, test, etc')
        add('--niter', type=int, default=100, help='# of iter at starting learning rate')
        add('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')
        add('--beta1', type=float, default=0.5, help='momentum term of adam')
        add('--lr', type=float, default=0.0002, help='initial learning rate for adam')

        # for discriminators
        add('--num_D', type=int, default=2, help='number of discriminators to use')
        add('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers')
        add('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
        add('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss')
        add('--no_ganFeat_loss', action='store_true', help='if specified, do *not* use discriminator feature matching loss')
        add('--no_vgg_loss', action='store_true', help='if specified, do *not* use VGG feature matching loss')
        add('--no_lsgan', action='store_true', help='do *not* use least square GAN, if false, use vanilla GAN')
        add('--pool_size', type=int, default=0, help='the size of image buffer that stores previously generated images')

        self.isTrain = True
================================================
FILE: precompute_feature_maps.py
================================================
from options.train_options import TrainOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
import os
import util.util as util
from torch.autograd import Variable
import torch.nn as nn
# Script: run the trained encoder over the dataset and save one upsampled
# feature map per image, for later use when training at 1024p.
import torch  # imported here because this script needs torch.no_grad()

opt = TrainOptions().parse()
# process the images one at a time, in order and without augmentation
opt.nThreads = 1
opt.batchSize = 1
opt.serial_batches = True
opt.no_flip = True
opt.instance_feat = True

############ Initialize #########
data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
dataset_size = len(data_loader)
model = create_model(opt)
util.mkdirs(os.path.join(opt.dataroot, opt.phase + '_feat'))

# Derive the output location from the phase so that it matches the
# directory created above (the original hard-coded 'train').
label_dir = '/%s_label/' % opt.phase
feat_dir = '/%s_feat/' % opt.phase
upsample = nn.Upsample(scale_factor=2, mode='nearest')

######## Save precomputed feature maps for 1024p training #######
for i, data in enumerate(dataset):
    print('%d / %d images' % (i+1, dataset_size))
    # `volatile=True` no longer disables autograd; no_grad() is the
    # supported way to run inference without building a graph.
    with torch.no_grad():
        feat_map = model.module.netE.forward(data['image'].cuda(), data['inst'].cuda())
        feat_map = upsample(feat_map)
    image_numpy = util.tensor2im(feat_map.data[0])
    save_path = data['path'][0].replace(label_dir, feat_dir)
    util.save_image(image_numpy, save_path)
================================================
FILE: run_engine.py
================================================
import os
import sys
from random import randint
import numpy as np
import tensorrt
try:
from PIL import Image
import pycuda.driver as cuda
import pycuda.gpuarray as gpuarray
import pycuda.autoinit
import argparse
except ImportError as err:
sys.stderr.write("""ERROR: failed to import module ({})
Please make sure you have pycuda and the example dependencies installed.
https://wiki.tiker.net/PyCuda/Installation/Linux
pip(3) install tensorrt[examples]
""".format(err))
exit(1)
try:
import tensorrt as trt
from tensorrt.parsers import caffeparser
from tensorrt.parsers import onnxparser
except ImportError as err:
sys.stderr.write("""ERROR: failed to import module ({})
Please make sure you have the TensorRT Library installed
and accessible in your LD_LIBRARY_PATH
""".format(err))
exit(1)
# Module-wide TensorRT console logger at INFO severity; it is handed to the
# engine loader and to the inference builder further down in this file.
G_LOGGER = trt.infer.ConsoleLogger(trt.infer.LogSeverity.INFO)
class Profiler(trt.infer.Profiler):
    """
    Example implementation of a per-layer profiler.

    Accumulates the reported time of every layer across executions and
    prints the per-iteration average, dividing by ``timing_iter``.
    """

    def __init__(self, timing_iter):
        trt.infer.Profiler.__init__(self)
        self.timing_iterations = timing_iter
        # list of (layer name, accumulated milliseconds), in first-seen order
        self.profile = []

    def report_layer_time(self, layerName, ms):
        """Add ``ms`` to the running total of ``layerName``."""
        for position, (seen_name, seen_ms) in enumerate(self.profile):
            if seen_name == layerName:
                self.profile[position] = (seen_name, seen_ms + ms)
                return
        # first report for this layer
        self.profile.append((layerName, ms))

    def print_layer_times(self):
        """Print the average time per layer and the overall average per iteration."""
        iterations = self.timing_iterations
        accumulated = 0
        for layer_name, layer_ms in self.profile:
            print("{:40.40} {:4.3f}ms".format(layer_name, layer_ms / iterations))
            accumulated += layer_ms
        print("Time over all layers: {:4.2f} ms per iteration".format(accumulated / iterations))
def get_input_output_names(trt_engine):
    """Print a summary of every engine binding and return their names.

    The names are returned in binding-index order, inputs and outputs
    together in one list.
    """
    binding_names = []
    for index in range(trt_engine.get_nb_bindings()):
        dims = trt_engine.get_binding_dimensions(index).to_DimsCHW()
        binding_name = trt_engine.get_binding_name(index)
        dtype = trt_engine.get_binding_data_type(index)

        role = "input" if trt_engine.binding_is_input(index) else "output"
        binding_names.append(binding_name)
        print("Found {}: ".format(role), binding_name)

        print("shape={} , {} , {}".format(dims.C(), dims.H(), dims.W()))
        print("dtype=" + str(dtype))
    return binding_names
def create_memory(engine, name, buf, mem, batchsize, inp, inp_idx):
    """Allocate and fill the device buffer for the binding called ``name``.

    Input bindings are filled from ``inp[inp_idx]``; any other binding is
    filled with random float32 values. The device pointer is inserted into
    ``buf`` at the binding index and the allocation is appended to ``mem``.

    Returns ``inp_idx``, advanced by one when an input was consumed.
    Raises AttributeError when ``name`` is not a binding of ``engine``.
    """
    slot = engine.get_binding_index(name)
    if slot == -1:
        raise AttributeError("Not a valid binding")
    print("Binding: name={}, bindingIndex={}".format(name, str(slot)))

    chw = engine.get_binding_dimensions(slot).to_DimsCHW()
    element_count = chw.C() * chw.H() * chw.W() * batchsize

    if not engine.binding_is_input(slot):
        host_data = np.random.uniform(0.0, 255.0, element_count).astype(np.dtype('f4'))
    else:
        host_data = inp[inp_idx]
        inp_idx += 1

    # the allocation is sized at 4 bytes per element
    device_ptr = cuda.mem_alloc(element_count * 4)
    cuda.memcpy_htod(device_ptr, host_data)
    buf.insert(slot, int(device_ptr))
    mem.append(device_ptr)
    return inp_idx
#Run inference on device
def time_inference(engine, batch_size, inp, timing_iterations=500):
    """Execute ``engine`` repeatedly and print the average time per layer.

    Args:
        engine: engine whose bindings are allocated and executed.
        batch_size: batch size passed to every execution.
        inp: sequence of host inputs, consumed in binding order.
        timing_iterations: number of executions to average over. The
            original loop iterated over ``range(iter)``, i.e. the builtin
            ``iter``, which raises TypeError; the count now matches the
            value given to the profiler.
    """
    bindings = []
    mem = []  # keeps the device allocations referenced while the engine runs
    inp_idx = 0
    for io in get_input_output_names(engine):
        inp_idx = create_memory(engine, io, bindings, mem,
                                batch_size, inp, inp_idx)

    context = engine.create_execution_context()
    try:
        g_prof = Profiler(timing_iterations)
        context.set_profiler(g_prof)
        for _ in range(timing_iterations):
            context.execute(batch_size, bindings)
        g_prof.print_layer_times()
    finally:
        # release the execution context even if an execution fails
        context.destroy()
def convert_to_datatype(v):
    """Map a bit depth (8, 16 or 32) to the matching TensorRT data type.

    Any other value prints an error message and falls back to INT8.
    """
    if v == 32:
        return trt.infer.DataType.FLOAT
    if v == 16:
        return trt.infer.DataType.HALF
    if v != 8:
        print("ERROR: Invalid model data type bit depth: " + str(v))
    return trt.infer.DataType.INT8
def run_trt_engine(engine_file, bs, it):
    """Load a serialized engine from ``engine_file`` and time it.

    ``bs`` is the batch size; ``it`` is forwarded unchanged as the input
    list of ``time_inference``. Nothing is returned.
    """
    loaded_engine = trt.utils.load_engine(G_LOGGER, engine_file)
    time_inference(loaded_engine, bs, it)
def run_onnx(onnx_file, data_type, bs, inp, max_batch_size=None, max_workspace_size=1 << 30):
    """Parse an ONNX model, build an engine from it and time inference.

    Args:
        onnx_file: path of the ONNX model.
        data_type: bit depth (8, 16 or 32) selecting the model data type.
        bs: batch size used for execution.
        inp: sequence of host inputs forwarded to ``time_inference``.
        max_batch_size: largest batch the builder must support; defaults
            to ``bs``. The original read an undefined global of this name.
        max_workspace_size: builder workspace limit in bytes. The original
            read an undefined global; the 1 GiB default is a placeholder
            chosen here and may need tuning for large models.
    """
    if max_batch_size is None:
        max_batch_size = bs

    # Create onnx_config
    apex = onnxparser.create_onnxconfig()
    apex.set_model_file_name(onnx_file)
    apex.set_model_dtype(convert_to_datatype(data_type))

    # create parser
    trt_parser = onnxparser.create_onnxparser(apex)
    assert(trt_parser)
    data_type = apex.get_model_dtype()
    onnx_filename = apex.get_model_file_name()
    trt_parser.parse(onnx_filename, data_type)
    trt_parser.report_parsing_info()
    trt_parser.convert_to_trtnetwork()
    trt_network = trt_parser.get_trtnetwork()
    assert(trt_network)

    # create infer builder
    trt_builder = trt.infer.create_infer_builder(G_LOGGER)
    trt_builder.set_max_batch_size(max_batch_size)
    trt_builder.set_max_workspace_size(max_workspace_size)

    if (apex.get_model_dtype() == trt.infer.DataType_kHALF):
        print("------------------- Running FP16 -----------------------------")
        trt_builder.set_half2_mode(True)
    elif (apex.get_model_dtype() == trt.infer.DataType_kINT8):
        print("------------------- Running INT8 -----------------------------")
        trt_builder.set_int8_mode(True)
    else:
        print("------------------- Running FP32 -----------------------------")

    print("----- Builder is Done -----")
    print("----- Creating Engine -----")
    trt_engine = trt_builder.build_cuda_engine(trt_network)
    print("----- Engine is built -----")
    # the original passed an undefined name `engine` here
    time_inference(trt_engine, bs, inp)
================================================
FILE: scripts/test_1024p.sh
================================================
#!/bin/bash
################################ Testing ################################
# labels only
# "$@" is quoted so that extra arguments containing spaces are forwarded intact.
python test.py --name label2city_1024p --netG local --ngf 32 --resize_or_crop none "$@"
================================================
FILE: scripts/test_1024p_feat.sh
================================================
################################ Testing ################################
# first precompute and cluster all features
python encode_features.py --name label2city_1024p_feat --netG local --ngf 32 --resize_or_crop none;
# use instance-wise features
# (the option is --netG; the previous '---netG' is not a recognised argument)
python test.py --name label2city_1024p_feat --netG local --ngf 32 --resize_or_crop none --instance_feat
================================================
FILE: scripts/test_512p.sh
================================================
################################ Testing ################################
# labels only
# Runs test.py for the experiment named label2city_512p; every other option keeps its default.
python test.py --name label2city_512p
================================================
FILE: scripts/test_512p_feat.sh
================================================
################################ Testing ################################
# first precompute and cluster all features
# (both commands use the same experiment name, label2city_512p_feat)
python encode_features.py --name label2city_512p_feat;
# use instance-wise features
python test.py --name label2city_512p_feat --instance_feat
================================================
FILE: scripts/train_1024p_12G.sh
================================================
############## Train at 2048 x 1024 resolution, starting from a trained 1024 x 512 model #############
##### For GPUs with 12G memory (not tested)
# Labels only
python train.py --name label2city_1024p \
    --netG local --ngf 32 --num_D 3 \
    --load_pretrain checkpoints/label2city_512p/ \
    --niter_fix_global 20 \
    --resize_or_crop crop --fineSize 1024
================================================
FILE: scripts/train_1024p_24G.sh
================================================
############## Train at 2048 x 1024 resolution, starting from a trained 1024 x 512 model #############
######## For GPUs with 24G memory
# Labels only
python train.py --name label2city_1024p \
    --netG local --ngf 32 --num_D 3 \
    --load_pretrain checkpoints/label2city_512p/ \
    --niter 50 --niter_decay 50 --niter_fix_global 10 \
    --resize_or_crop none
================================================
FILE: scripts/train_1024p_feat_12G.sh
================================================
############## Train at 2048 x 1024 resolution, starting from a trained 1024 x 512 model #############
##### For GPUs with 12G memory (not tested)
# Step 1: precompute the feature maps and save them
python precompute_feature_maps.py --name label2city_512p_feat;
# Step 2: train with instances and encoded features
python train.py --name label2city_1024p_feat \
    --netG local --ngf 32 --num_D 3 \
    --load_pretrain checkpoints/label2city_512p_feat/ \
    --niter_fix_global 20 \
    --resize_or_crop crop --fineSize 896 \
    --instance_feat --load_features
================================================
FILE: scripts/train_1024p_feat_24G.sh
================================================
############## Train at 2048 x 1024 resolution, starting from a trained 1024 x 512 model #############
######## For GPUs with 24G memory
# Step 1: precompute the feature maps and save them
python precompute_feature_maps.py --name label2city_512p_feat;
# Step 2: train with instances and encoded features
python train.py --name label2city_1024p_feat \
    --netG local --ngf 32 --num_D 3 \
    --load_pretrain checkpoints/label2city_512p_feat/ \
    --niter 50 --niter_decay 50 --niter_fix_global 10 \
    --resize_or_crop none \
    --instance_feat --load_features
================================================
FILE: scripts/train_512p.sh
================================================
### Using labels only
# Trains the experiment named label2city_512p with every other option at its default.
python train.py --name label2city_512p
================================================
FILE: scripts/train_512p_feat.sh
================================================
### Adding instances and encoded features
# --instance_feat adds encoded instance features as generator input.
python train.py --name label2city_512p_feat --instance_feat
================================================
FILE: scripts/train_512p_fp16.sh
================================================
### Using labels only
# --fp16 trains with AMP; the script is started through torch.distributed.launch.
python -m torch.distributed.launch train.py --name label2city_512p --fp16
================================================
FILE: scripts/train_512p_fp16_multigpu.sh
================================================
######## Multi-GPU training example #######
# AMP training (--fp16) with batch size 8 on GPUs 0-7, started through torch.distributed.launch.
python -m torch.distributed.launch train.py --name label2city_512p --batchSize 8 --gpu_ids 0,1,2,3,4,5,6,7 --fp16
================================================
FILE: scripts/train_512p_multigpu.sh
================================================
######## Multi-GPU training example #######
# Batch size 8 on GPUs 0-7.
python train.py --name label2city_512p --batchSize 8 --gpu_ids 0,1,2,3,4,5,6,7
================================================
FILE: test.py
================================================
import os
from collections import OrderedDict
from torch.autograd import Variable
from options.test_options import TestOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
import util.util as util
from util.visualizer import Visualizer
from util import html
import torch
# Script: generate images for the test set with the PyTorch model, a
# serialized TensorRT engine or an ONNX model, and collect them in an HTML page.
opt = TestOptions().parse(save=False)
opt.nThreads = 1   # test code only supports nThreads = 1
opt.batchSize = 1  # test code only supports batchSize = 1
opt.serial_batches = True  # no shuffle
opt.no_flip = True  # no flip

data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
visualizer = Visualizer(opt)
# create website
web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))

# test
if not opt.engine and not opt.onnx:
    model = create_model(opt)
    if opt.data_type == 16:
        model.half()
    elif opt.data_type == 8:
        model.type(torch.uint8)

    if opt.verbose:
        print(model)
else:
    # the TensorRT helpers are only imported when they are actually requested
    from run_engine import run_trt_engine, run_onnx

minibatch = 1  # constant for the whole run, so set once outside the loop

for i, data in enumerate(dataset):
    if i >= opt.how_many:
        break
    if opt.data_type == 16:
        data['label'] = data['label'].half()
        data['inst'] = data['inst'].half()
    elif opt.data_type == 8:
        # tensors have no `.uint8()` method; `.byte()` is the uint8 cast
        data['label'] = data['label'].byte()
        data['inst'] = data['inst'].byte()
    if opt.export_onnx:
        print ("Exporting to ONNX: ", opt.export_onnx)
        assert opt.export_onnx.endswith("onnx"), "Export model file should end with .onnx"
        torch.onnx.export(model, [data['label'], data['inst']],
                          opt.export_onnx, verbose=True)
        exit(0)
    if opt.engine:
        generated = run_trt_engine(opt.engine, minibatch, [data['label'], data['inst']])
    elif opt.onnx:
        generated = run_onnx(opt.onnx, opt.data_type, minibatch, [data['label'], data['inst']])
    else:
        generated = model.inference(data['label'], data['inst'], data['image'])

    visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
                           ('synthesized_image', util.tensor2im(generated.data[0]))])
    img_path = data['path']
    print('process image... %s' % img_path)
    visualizer.save_images(webpage, visuals, img_path)

webpage.save()
================================================
FILE: train.py
================================================
import time
import os
import numpy as np
import torch
from torch.autograd import Variable
from collections import OrderedDict
from subprocess import call
import fractions
def lcm(a, b):
    """Return the least common multiple of ``a`` and ``b`` as an int.

    Returns 0 when either argument is 0. ``fractions.gcd`` no longer exists
    in Python 3.9+, so ``math.gcd`` is preferred, and floor division keeps
    the result an integer (true division produced a float).
    """
    if not (a and b):
        return 0
    try:
        from math import gcd
    except ImportError:  # interpreters without math.gcd
        from fractions import gcd
    return abs(a * b) // gcd(a, b)
from options.train_options import TrainOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
import util.util as util
from util.visualizer import Visualizer
# Training setup: parse options, restore progress, build data/model/optimizers
# and initialise the step counters used by the training loop.
opt = TrainOptions().parse()
iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
if opt.continue_train:
    try:
        start_epoch, epoch_iter = np.loadtxt(iter_path, delimiter=',', dtype=int)
    except (OSError, ValueError):
        # progress file missing or unreadable: start from the beginning.
        # (A bare `except:` here would also swallow KeyboardInterrupt.)
        start_epoch, epoch_iter = 1, 0
    print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter))
else:
    start_epoch, epoch_iter = 1, 0

# the step counters advance by batchSize, so make print_freq a multiple of it
opt.print_freq = lcm(opt.print_freq, opt.batchSize)
if opt.debug:
    opt.display_freq = 1
    opt.print_freq = 1
    opt.niter = 1
    opt.niter_decay = 0
    opt.max_dataset_size = 10

data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
dataset_size = len(data_loader)
print('#training images = %d' % dataset_size)

model = create_model(opt)
visualizer = Visualizer(opt)
if opt.fp16:
    from apex import amp
    model, [optimizer_G, optimizer_D] = amp.initialize(model, [model.optimizer_G, model.optimizer_D], opt_level='O1')
    model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
else:
    optimizer_G, optimizer_D = model.module.optimizer_G, model.module.optimizer_D

total_steps = (start_epoch-1) * dataset_size + epoch_iter

# remainders of the starting step count; the loop fires its periodic actions
# whenever total_steps reaches the same remainder again
display_delta = total_steps % opt.display_freq
print_delta = total_steps % opt.print_freq
save_delta = total_steps % opt.save_latest_freq
# Main training loop: epochs run from start_epoch up to niter + niter_decay.
for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
    epoch_start_time = time.time()
    # only the first (possibly resumed) epoch keeps its starting offset;
    # later epochs wrap the counter back into the dataset range
    if epoch != start_epoch:
        epoch_iter = epoch_iter % dataset_size
    for i, data in enumerate(dataset, start=epoch_iter):
        # restart the timer at the beginning of each print interval
        if total_steps % opt.print_freq == print_delta:
            iter_start_time = time.time()
        # both counters advance by the batch size on every iteration
        total_steps += opt.batchSize
        epoch_iter += opt.batchSize

        # whether to collect output images
        save_fake = total_steps % opt.display_freq == display_delta

        ############## Forward Pass ######################
        losses, generated = model(Variable(data['label']), Variable(data['inst']),
            Variable(data['image']), Variable(data['feat']), infer=save_fake)

        # sum per device losses
        # (entries that are plain ints are passed through unchanged)
        losses = [ torch.mean(x) if not isinstance(x, int) else x for x in losses ]
        loss_dict = dict(zip(model.module.loss_names, losses))

        # calculate final loss scalar
        loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5
        loss_G = loss_dict['G_GAN'] + loss_dict.get('G_GAN_Feat',0) + loss_dict.get('G_VGG',0)

        ############### Backward Pass ####################
        # update generator weights
        optimizer_G.zero_grad()
        if opt.fp16:
            with amp.scale_loss(loss_G, optimizer_G) as scaled_loss: scaled_loss.backward()
        else:
            loss_G.backward()
        optimizer_G.step()

        # update discriminator weights
        optimizer_D.zero_grad()
        if opt.fp16:
            with amp.scale_loss(loss_D, optimizer_D) as scaled_loss: scaled_loss.backward()
        else:
            loss_D.backward()
        optimizer_D.step()

        ############## Display results and errors ##########
        ### print out errors
        if total_steps % opt.print_freq == print_delta:
            errors = {k: v.data.item() if not isinstance(v, int) else v for k, v in loss_dict.items()}
            # elapsed time since the interval started, divided by print_freq
            t = (time.time() - iter_start_time) / opt.print_freq
            visualizer.print_current_errors(epoch, epoch_iter, errors, t)
            visualizer.plot_current_errors(errors, total_steps)
            #call(["nvidia-smi", "--format=csv", "--query-gpu=memory.used,memory.free"])

        ### display output images
        if save_fake:
            visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)),
                                   ('synthesized_image', util.tensor2im(generated.data[0])),
                                   ('real_image', util.tensor2im(data['image'][0]))])
            visualizer.display_current_results(visuals, epoch, total_steps)

        ### save latest model
        if total_steps % opt.save_latest_freq == save_delta:
            print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps))
            model.module.save('latest')
            # record the position so that --continue_train can resume from here
            np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d')

        # stop once the whole dataset has been consumed in this epoch
        if epoch_iter >= dataset_size:
            break

    # end of epoch
    # NOTE(review): iter_end_time is assigned but not read anywhere in this script
    iter_end_time = time.time()
    print('End of epoch %d / %d \t Time Taken: %d sec' %
          (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))

    ### save model for this epoch
    if epoch % opt.save_epoch_freq == 0:
        print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
        model.module.save('latest')
        model.module.save(epoch)
        # a resumed run starts at the beginning of the next epoch
        np.savetxt(iter_path, (epoch+1, 0), delimiter=',', fmt='%d')

    ### instead of only training the local enhancer, train the entire network after certain iterations
    if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global):
        model.module.update_fixed_params()

    ### linearly decay learning rate after certain iterations
    if epoch > opt.niter:
        model.module.update_learning_rate()
================================================
FILE: util/__init__.py
================================================
================================================
FILE: util/html.py
================================================
import dominate
from dominate.tags import *
import os
class HTML:
    """Small helper that builds an HTML page of image tables with dominate.

    The page is written to ``<web_dir>/index.html`` and refers to images in
    ``<web_dir>/images``; both directories are created on construction.
    """

    def __init__(self, web_dir, title, refresh=0):
        """Create the output directories and an empty document.

        A positive ``refresh`` adds a meta refresh tag with that value.
        """
        self.title = title
        self.web_dir = web_dir
        self.img_dir = os.path.join(self.web_dir, 'images')
        for directory in (self.web_dir, self.img_dir):
            if not os.path.exists(directory):
                os.makedirs(directory)

        self.doc = dominate.document(title=title)
        if refresh > 0:
            with self.doc.head:
                meta(http_equiv="refresh", content=str(refresh))

    def get_image_dir(self):
        """Return the directory in which the page expects its images."""
        return self.img_dir

    def add_header(self, str):
        """Append an ``h3`` header with the given text to the document."""
        with self.doc:
            h3(str)

    def add_table(self, border=1):
        """Append a new table to the document and make it the current one."""
        self.t = table(border=border, style="table-layout: fixed;")
        self.doc.add(self.t)

    def add_images(self, ims, txts, links, width=512):
        """Add one table row with an image, caption and link per entry.

        ``ims``, ``txts`` and ``links`` are iterated in lockstep; image and
        link paths are taken relative to the ``images`` directory.
        """
        self.add_table()
        with self.t:
            with tr():
                for im, txt, link in zip(ims, txts, links):
                    with td(style="word-wrap: break-word;", halign="center", valign="top"):
                        with p():
                            with a(href=os.path.join('images', link)):
                                img(style="width:%dpx" % (width), src=os.path.join('images', im))
                            br()
                            p(txt)

    def save(self):
        """Render the document and write it to ``<web_dir>/index.html``."""
        html_file = '%s/index.html' % self.web_dir
        # the context manager closes the file even if rendering or writing fails
        with open(html_file, 'wt') as f:
            f.write(self.doc.render())
if __name__ == '__main__':
    # Small demo: a page with a header and a row of four linked images.
    html = HTML('web/', 'test_html')
    html.add_header('hello world')

    sample_ids = range(4)
    ims = ['image_%d.jpg' % n for n in sample_ids]
    txts = ['text_%d' % n for n in sample_ids]
    links = ['image_%d.jpg' % n for n in sample_ids]

    html.add_images(ims, txts, links)
    html.save()
================================================
FILE: util/image_pool.py
================================================
import random
import torch
from torch.autograd import Variable
class ImagePool():
    """Buffer of previously seen images.

    With ``pool_size`` 0 the pool is a pass-through. Otherwise it first
    fills up with the queried images and afterwards, with probability one
    half per image, hands back a stored image and keeps the new one instead.
    """

    def __init__(self, pool_size):
        self.pool_size = pool_size
        if pool_size > 0:
            self.num_imgs = 0
            self.images = []

    def query(self, images):
        """Return a batch the same size as ``images``, possibly mixing in stored ones."""
        if self.pool_size == 0:
            return images

        selected = []
        for sample in images.data:
            sample = torch.unsqueeze(sample, 0)

            if self.num_imgs < self.pool_size:
                # the pool is still filling up: store and return the new image
                self.num_imgs += 1
                self.images.append(sample)
                selected.append(sample)
                continue

            if random.uniform(0, 1) > 0.5:
                # swap: return a stored image and keep the new one in its place
                slot = random.randint(0, self.pool_size - 1)
                stored = self.images[slot].clone()
                self.images[slot] = sample
                selected.append(stored)
            else:
                selected.append(sample)

        return Variable(torch.cat(selected, 0))
================================================
FILE: util/util.py
================================================
from __future__ import print_function
import torch
import numpy as np
from PIL import Image
import numpy as np
import os
# Converts a Tensor into a Numpy array
# |imtype|: the desired type of the converted numpy array
def tensor2im(image_tensor, imtype=np.uint8, normalize=True):
    """Convert an image tensor (or a list of them) into a numpy image.

    The first axis of the tensor is moved last. With ``normalize`` the
    values are mapped as ``(x + 1) / 2 * 255``, otherwise as ``x * 255``;
    the result is clipped to [0, 255] and cast to ``imtype``. Images with
    one channel or more than three keep only their first channel.
    """
    if isinstance(image_tensor, list):
        return [tensor2im(item, imtype, normalize) for item in image_tensor]

    channels_last = np.transpose(image_tensor.cpu().float().numpy(), (1, 2, 0))
    if normalize:
        scaled = (channels_last + 1) / 2.0 * 255.0
    else:
        scaled = channels_last * 255.0
    scaled = np.clip(scaled, 0, 255)

    n_channels = scaled.shape[2]
    if n_channels == 1 or n_channels > 3:
        scaled = scaled[:, :, 0]
    return scaled.astype(imtype)
# Converts a one-hot tensor into a colorful label map
def tensor2label(label_tensor, n_label, imtype=np.uint8):
    """Render a label tensor as a color image.

    Args:
        label_tensor: label tensor; when its first dimension is larger than
            one it is reduced to the index of the maximum along that
            dimension before coloring.
        n_label: number of labels passed to ``Colorize``; 0 means the tensor
            is handed to ``tensor2im`` instead.
        imtype: numpy dtype of the returned array.

    Returns:
        The colorized map transposed to channel-last order, as ``imtype``.
    """
    if n_label == 0:
        return tensor2im(label_tensor, imtype)

    labels = label_tensor.cpu().float()
    if labels.size()[0] > 1:
        # Keep only the winning index along dimension 0.
        _, labels = labels.max(0, keepdim=True)

    colored = Colorize(n_label)(labels)
    channel_last = np.transpose(colored.numpy(), (1, 2, 0))
    return channel_last.astype(imtype)
def save_image(image_numpy, image_path):
    """Build a PIL image from ``image_numpy`` and save it to ``image_path``."""
    Image.fromarray(image_numpy).save(image_path)
def mkdirs(paths):
    """Create one directory, or every directory in a list, via ``mkdir``."""
    # Anything that is not a (non-string) list is treated as a single path.
    if not isinstance(paths, list) or isinstance(paths, str):
        mkdir(paths)
        return
    for single_path in paths:
        mkdir(single_path)
def mkdir(path):
    """Create directory ``path`` (including parents) if it does not exist.

    The directory is created first and an already-existing path is tolerated
    afterwards, instead of checking and then creating: with several
    processes starting at once, another process may create the directory
    between the check and the ``makedirs`` call.

    Raises:
        OSError: if creation fails and ``path`` still does not exist.
    """
    try:
        os.makedirs(path)
    except OSError:
        # Same outcome as before when the path is already present;
        # any other failure is reported to the caller.
        if not os.path.exists(path):
            raise
###############################################################################
# Code from
# https://github.com/ycszen/pytorch-seg/blob/master/transform.py
# Modified so it complies with the Cityscapes label map colors
###############################################################################
def uint82bin(n, count=8):
    """Return the lowest ``count`` bits of integer ``n`` as a '0'/'1' string,
    most significant bit first."""
    digits = []
    for shift in reversed(range(count)):
        digits.append(str((n >> shift) & 1))
    return ''.join(digits)
def labelcolormap(N):
    """Return an (N, 3) uint8 array holding one RGB color per label id.

    For ``N == 35`` a fixed table (the cityscape palette) is returned.
    For any other ``N`` the colors are generated: the label id is consumed
    three bits at a time, and each group contributes one bit to the red,
    green and blue values, starting from the most significant bit.

    Args:
        N: number of labels.

    Returns:
        ``numpy.ndarray`` of shape (N, 3) and dtype ``uint8``.
    """
    if N == 35:  # cityscape
        cmap = np.array([(  0,   0,   0), (  0,   0,   0), (  0,   0,   0), (  0,   0,   0), (  0,   0,   0), (111,  74,   0), ( 81,   0,  81),
                         (128,  64, 128), (244,  35, 232), (250, 170, 160), (230, 150, 140), ( 70,  70,  70), (102, 102, 156), (190, 153, 153),
                         (180, 165, 180), (150, 100, 100), (150, 120,  90), (153, 153, 153), (153, 153, 153), (250, 170,  30), (220, 220,   0),
                         (107, 142,  35), (152, 251, 152), ( 70, 130, 180), (220,  20,  60), (255,   0,   0), (  0,   0, 142), (  0,   0,  70),
                         (  0,  60, 100), (  0,   0,  90), (  0,   0, 110), (  0,  80, 100), (  0,   0, 230), (119,  11,  32), (  0,   0, 142)],
                        dtype=np.uint8)
    else:
        cmap = np.zeros((N, 3), dtype=np.uint8)
        for i in range(N):
            r, g, b = 0, 0, 0
            remaining = i
            for j in range(7):
                # Bits 0/1/2 of the current group feed r/g/b at bit 7-j.
                shift = 7 - j
                r |= (remaining & 1) << shift
                g |= ((remaining >> 1) & 1) << shift
                b |= ((remaining >> 2) & 1) << shift
                remaining >>= 3
            cmap[i] = (r, g, b)
    return cmap
class Colorize(object):
    """Callable that paints a label map with the colors of ``labelcolormap``."""

    def __init__(self, n=35):
        """Load the first ``n`` rows of the ``n``-label color table as a tensor."""
        palette = labelcolormap(n)
        self.cmap = torch.from_numpy(palette[:n])

    def __call__(self, gray_image):
        """Return a 3 x H x W byte tensor colored from ``gray_image[0]``.

        H and W are taken from dimensions 1 and 2 of ``gray_image``. Pixels
        whose value matches no row index of the color table stay zero.
        """
        dims = gray_image.size()
        color_image = torch.zeros(3, dims[1], dims[2], dtype=torch.uint8)

        label_plane = gray_image[0]
        for label in range(len(self.cmap)):
            mask = (label == label_plane).cpu()
            color = self.cmap[label]
            for channel in range(3):
                color_image[channel][mask] = color[channel]

        return color_image
================================================
FILE: util/visualizer.py
================================================
import numpy as np
import os
import ntpath
import time
from . import util
from . import html
import scipy.misc
try:
from StringIO import StringIO # Python 2.7
except ImportError:
from io import BytesIO # Python 3.x
class Visualizer():
    """Reports training progress.

    Depending on the options it writes TensorBoard summaries (``tf_log``),
    maintains an HTML page of saved result images (training without
    ``no_html``), and always appends loss lines to ``loss_log.txt`` under
    ``<checkpoints_dir>/<name>``.
    """

    def __init__(self, opt):
        """Read the display options from ``opt`` and prepare the outputs.

        Uses ``opt.tf_log``, ``opt.isTrain``, ``opt.no_html``,
        ``opt.display_winsize``, ``opt.name`` and ``opt.checkpoints_dir``.
        Appends a timestamped header line to the loss log.
        """
        self.tf_log = opt.tf_log
        self.use_html = opt.isTrain and not opt.no_html
        self.win_size = opt.display_winsize
        self.name = opt.name

        if self.tf_log:
            # Imported here so TensorFlow is only needed when tf_log is set.
            import tensorflow as tf
            self.tf = tf
            self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, 'logs')
            self.writer = tf.summary.FileWriter(self.log_dir)

        if self.use_html:
            self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
            self.img_dir = os.path.join(self.web_dir, 'images')
            print('create web directory %s...' % self.web_dir)
            util.mkdirs([self.web_dir, self.img_dir])

        self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
        with open(self.log_name, "a") as log_file:
            now = time.strftime("%c")
            log_file.write('================ Training Loss (%s) ================\n' % now)

    @staticmethod
    def _epoch_image_entries(visuals, epoch):
        """List (caption, file name, image) for every image of ``visuals``.

        A list-valued entry yields one item per element, with the element
        index appended to both the caption and the file name.
        """
        entries = []
        for label, image_numpy in visuals.items():
            if isinstance(image_numpy, list):
                for i, image in enumerate(image_numpy):
                    file_name = 'epoch%.3d_%s_%d.jpg' % (epoch, label, i)
                    entries.append((label + str(i), file_name, image))
            else:
                file_name = 'epoch%.3d_%s.jpg' % (epoch, label)
                entries.append((label, file_name, image_numpy))
        return entries

    def display_current_results(self, visuals, epoch, step):
        """Publish the current result images.

        Args:
            visuals: dictionary mapping a label to an image to display or
                save (the HTML path also accepts a list of images per label).
            epoch: current epoch; used in saved file names and as the newest
                section of the HTML page.
            step: global step attached to the TensorBoard summary.
        """
        if self.tf_log:  # show images in tensorboard output
            img_summaries = []
            for label, image_numpy in visuals.items():
                # Write the image to an in-memory buffer. StringIO is only
                # bound when the Python 2 import at the top of the file
                # succeeded; otherwise the name is missing and BytesIO is used.
                try:
                    s = StringIO()
                except NameError:
                    s = BytesIO()
                # NOTE(review): scipy.misc.toimage was removed in SciPy 1.2;
                # this path needs an older SciPy -- confirm pinned version.
                scipy.misc.toimage(image_numpy).save(s, format="jpeg")
                # Create an Image object
                img_sum = self.tf.Summary.Image(encoded_image_string=s.getvalue(),
                                                height=image_numpy.shape[0],
                                                width=image_numpy.shape[1])
                # Create a Summary value
                img_summaries.append(self.tf.Summary.Value(tag=label, image=img_sum))

            # Create and write Summary
            summary = self.tf.Summary(value=img_summaries)
            self.writer.add_summary(summary, step)

        if self.use_html:  # save images to a html file
            for _, file_name, image in self._epoch_image_entries(visuals, epoch):
                util.save_image(image, os.path.join(self.img_dir, file_name))

            # update website: one section per epoch, newest first. Older
            # sections reuse the current labels to rebuild their file names.
            webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=30)
            for n in range(epoch, 0, -1):
                webpage.add_header('epoch [%d]' % n)
                entries = self._epoch_image_entries(visuals, n)
                txts = [caption for caption, _, _ in entries]
                ims = [file_name for _, file_name, _ in entries]
                links = list(ims)

                if len(ims) < 10:
                    webpage.add_images(ims, txts, links, width=self.win_size)
                else:
                    # Split long rows into two roughly equal halves.
                    num = int(round(len(ims) / 2.0))
                    webpage.add_images(ims[:num], txts[:num], links[:num], width=self.win_size)
                    webpage.add_images(ims[num:], txts[num:], links[num:], width=self.win_size)
            webpage.save()

    def plot_current_errors(self, errors, step):
        """Write each entry of ``errors`` (label -> value) as a TensorBoard
        scalar at ``step``; does nothing unless ``tf_log`` is enabled."""
        if not self.tf_log:
            return
        for tag, value in errors.items():
            summary = self.tf.Summary(value=[self.tf.Summary.Value(tag=tag, simple_value=value)])
            self.writer.add_summary(summary, step)

    def print_current_errors(self, epoch, i, errors, t):
        """Print the non-zero entries of ``errors`` and append the same line
        to the loss log.

        Args:
            epoch: current epoch number.
            i: iteration count shown as ``iters``.
            errors: dictionary of error labels and values.
            t: time value shown with three decimals.
        """
        message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t)
        message += ''.join('%s: %.3f ' % (k, v) for k, v in errors.items() if v != 0)

        print(message)
        with open(self.log_name, "a") as log_file:
            log_file.write('%s\n' % message)

    def save_images(self, webpage, visuals, image_path):
        """Save every image of ``visuals`` to disk and add them to ``webpage``.

        Files are written to ``webpage.get_image_dir()`` and named
        ``<name>_<label>.jpg``, where ``<name>`` is the base name of
        ``image_path[0]`` without its extension.
        """
        image_dir = webpage.get_image_dir()
        short_path = ntpath.basename(image_path[0])
        name = os.path.splitext(short_path)[0]

        webpage.add_header(name)
        ims = []
        txts = []
        links = []

        for label, image_numpy in visuals.items():
            image_name = '%s_%s.jpg' % (name, label)
            util.save_image(image_numpy, os.path.join(image_dir, image_name))

            ims.append(image_name)
            txts.append(label)
            links.append(image_name)
        webpage.add_images(ims, txts, links, width=self.win_size)