Repository: yiranran/Unpaired-Portrait-Drawing
Branch: master
Commit: b67591912a3e
Files: 35
Total size: 183.1 KB
Directory structure:
gitextract_uyxvzmrx/
├── .gitignore
├── data/
│ ├── __init__.py
│ ├── base_dataset.py
│ ├── image_folder.py
│ ├── single_dataset.py
│ └── unaligned_mask_stylecls_dataset.py
├── models/
│ ├── __init__.py
│ ├── asymmetric_cycle_gan_cls_model.py
│ ├── base_model.py
│ ├── dist_model.py
│ ├── networks.py
│ ├── networks_basic.py
│ ├── pretrained_networks.py
│ ├── test_3styles_model.py
│ └── test_model.py
├── options/
│ ├── __init__.py
│ ├── base_options.py
│ ├── test_options.py
│ └── train_options.py
├── portrait_drawing_resources.md
├── preprocess/
│ ├── example/
│ │ └── ia_selfie_10515_facial5point.mat
│ ├── face_align_512.m
│ └── readme.md
├── readme.md
├── requirements.txt
├── scripts/
│ └── train.sh
├── test.py
├── test_seq_style.py
├── train.py
└── util/
├── __init__.py
├── get_data.py
├── html.py
├── image_pool.py
├── util.py
└── visualizer.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
.DS_Store
debug*
datasets/
checkpoints/
style_features/
results/
build/
dist/
torch.egg-info/
*/**/__pycache__
torch/version.py
torch/csrc/generic/TensorMethods.cpp
torch/lib/*.so*
torch/lib/*.dylib*
torch/lib/*.h
torch/lib/build
torch/lib/tmp_install
torch/lib/include
torch/lib/torch_shm_manager
torch/csrc/cudnn/cuDNN.cpp
torch/csrc/nn/THNN.cwrap
torch/csrc/nn/THNN.cpp
torch/csrc/nn/THCUNN.cwrap
torch/csrc/nn/THCUNN.cpp
torch/csrc/nn/THNN_generic.cwrap
torch/csrc/nn/THNN_generic.cpp
torch/csrc/nn/THNN_generic.h
docs/src/**/*
test/data/legacy_modules.t7
test/data/gpu_tensors.pt
test/htmlcov
test/.coverage
*/*.pyc
*/**/*.pyc
*/**/**/*.pyc
*/**/**/**/*.pyc
*/**/**/**/**/*.pyc
*/*.so*
*/**/*.so*
*/**/*.dylib*
test/data/legacy_serialized.pt
*~
.idea
================================================
FILE: data/__init__.py
================================================
"""This package includes all the modules related to data loading and preprocessing
To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
You need to implement four functions:
-- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
-- <__len__>: return the size of dataset.
-- <__getitem__>: get a data point from data loader.
-- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
See our template dataset class 'template_dataset.py' for more details.
"""
import importlib
import torch.utils.data
from data.base_dataset import BaseDataset
def find_dataset_using_name(dataset_name):
"""Import the module "data/[dataset_name]_dataset.py".
In the file, the class called DatasetNameDataset() will
be instantiated. It has to be a subclass of BaseDataset,
and it is case-insensitive.
"""
dataset_filename = "data." + dataset_name + "_dataset"
datasetlib = importlib.import_module(dataset_filename)
dataset = None
target_dataset_name = dataset_name.replace('_', '') + 'dataset'
for name, cls in datasetlib.__dict__.items():
if name.lower() == target_dataset_name.lower() \
and issubclass(cls, BaseDataset):
dataset = cls
if dataset is None:
raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
return dataset
def get_option_setter(dataset_name):
"""Return the static method <modify_commandline_options> of the dataset class."""
dataset_class = find_dataset_using_name(dataset_name)
return dataset_class.modify_commandline_options
def create_dataset(opt):
"""Create a dataset given the option.
This function wraps the class CustomDatasetDataLoader.
This is the main interface between this package and 'train.py'/'test.py'
Example:
>>> from data import create_dataset
>>> dataset = create_dataset(opt)
"""
data_loader = CustomDatasetDataLoader(opt)
dataset = data_loader.load_data()
return dataset
class CustomDatasetDataLoader():
"""Wrapper class of Dataset class that performs multi-threaded data loading"""
def __init__(self, opt):
"""Initialize this class
Step 1: create a dataset instance given the name [dataset_mode]
Step 2: create a multi-threaded data loader.
"""
self.opt = opt
dataset_class = find_dataset_using_name(opt.dataset_mode)
self.dataset = dataset_class(opt)
print("dataset [%s] was created" % type(self.dataset).__name__)
self.dataloader = torch.utils.data.DataLoader(
self.dataset,
batch_size=opt.batch_size,
shuffle=not opt.serial_batches,
num_workers=int(opt.num_threads))
def load_data(self):
return self
def __len__(self):
"""Return the number of data in the dataset"""
return min(len(self.dataset), self.opt.max_dataset_size)
def __iter__(self):
"""Return a batch of data"""
for i, data in enumerate(self.dataloader):
if i * self.opt.batch_size >= self.opt.max_dataset_size:
break
yield data
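# Typical usage (a sketch; 'opt' must provide dataset_mode, batch_size,
# serial_batches, num_threads and max_dataset_size, as parsed in options/):
#
#     dataset = create_dataset(opt)        # returns the CustomDatasetDataLoader
#     for i, data in enumerate(dataset):   # yields dicts of batched tensors
#         model.set_input(data)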
================================================
FILE: data/base_dataset.py
================================================
"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
"""
import random
import numpy as np
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
from abc import ABCMeta, abstractmethod
class BaseDataset(data.Dataset):
__metaclass__ = ABCMeta
"""This class is an abstract base class (ABC) for datasets.
To create a subclass, you need to implement the following four functions:
-- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
-- <__len__>: return the size of dataset.
-- <__getitem__>: get a data point.
-- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
"""
def __init__(self, opt):
"""Initialize the class; save the options in the class
Parameters:
opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
"""
self.opt = opt
self.root = opt.dataroot
@staticmethod
def modify_commandline_options(parser, is_train):
"""Add new dataset-specific options, and rewrite default values for existing options.
Parameters:
parser -- original option parser
is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
Returns:
the modified parser.
"""
return parser
@abstractmethod
def __len__(self):
"""Return the total number of images in the dataset."""
return 0
@abstractmethod
def __getitem__(self, index):
"""Return a data point and its metadata information.
Parameters:
            index -- a random integer for data indexing
Returns:
        a dictionary of data with their names. It usually contains the data itself and its metadata information.
"""
pass
def get_params(opt, size):
w, h = size
new_h = h
new_w = w
if opt.preprocess == 'resize_and_crop':
new_h = new_w = opt.load_size
elif opt.preprocess == 'scale_width_and_crop':
new_w = opt.load_size
new_h = opt.load_size * h // w
x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
y = random.randint(0, np.maximum(0, new_h - opt.crop_size))
flip = random.random() > 0.5
return {'crop_pos': (x, y), 'flip': flip}
def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
transform_list = []
if grayscale:
transform_list.append(transforms.Grayscale(1))
if 'resize' in opt.preprocess:
osize = [opt.load_size, opt.load_size]
transform_list.append(transforms.Resize(osize, method))
elif 'scale_width' in opt.preprocess:
transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method)))
if 'crop' in opt.preprocess:
if params is None:
transform_list.append(transforms.RandomCrop(opt.crop_size))
else:
transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))
if opt.preprocess == 'none':
transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))
if not opt.no_flip:
if params is None:
transform_list.append(transforms.RandomHorizontalFlip())
elif params['flip']:
transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
if convert:
transform_list += [transforms.ToTensor()]
if grayscale:
transform_list += [transforms.Normalize((0.5,), (0.5,))]
else:
transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
return transforms.Compose(transform_list)
def get_transform_mask(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
transform_list = []
if grayscale:
transform_list.append(transforms.Grayscale(1))
if 'resize' in opt.preprocess:
osize = [opt.load_size, opt.load_size]
transform_list.append(transforms.Resize(osize, method))
elif 'scale_width' in opt.preprocess:
transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method)))
if 'crop' in opt.preprocess:
if params is None:
transform_list.append(transforms.RandomCrop(opt.crop_size))
else:
transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))
if opt.preprocess == 'none':
transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))
if not opt.no_flip:
if params is None:
transform_list.append(transforms.RandomHorizontalFlip())
elif params['flip']:
transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
if convert:
transform_list += [transforms.ToTensor()]
return transforms.Compose(transform_list)
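# Splitting get_params from the transform builders lets an image and its mask
# share one random crop/flip. A sketch mirroring the usage in
# data/unaligned_mask_stylecls_dataset.py:
#
#     params = get_params(opt, A_img.size)
#     A = get_transform(opt, params)(A_img)                               # normalized to [-1, 1]
#     A_mask = get_transform_mask(opt, params, grayscale=1)(A_mask_img)   # kept in [0, 1]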
def __make_power_2(img, base, method=Image.BICUBIC):
ow, oh = img.size
h = int(round(oh / base) * base)
w = int(round(ow / base) * base)
if (h == oh) and (w == ow):
return img
__print_size_warning(ow, oh, w, h)
return img.resize((w, h), method)
def __scale_width(img, target_width, method=Image.BICUBIC):
ow, oh = img.size
if (ow == target_width):
return img
w = target_width
h = int(target_width * oh / ow)
return img.resize((w, h), method)
def __crop(img, pos, size):
ow, oh = img.size
x1, y1 = pos
tw = th = size
if (ow > tw or oh > th):
return img.crop((x1, y1, x1 + tw, y1 + th))
return img
def __flip(img, flip):
if flip:
return img.transpose(Image.FLIP_LEFT_RIGHT)
return img
def __print_size_warning(ow, oh, w, h):
"""Print warning information about image size(only print once)"""
if not hasattr(__print_size_warning, 'has_printed'):
print("The image size needs to be a multiple of 4. "
"The loaded image size was (%d, %d), so it was adjusted to "
"(%d, %d). This adjustment will be done to all images "
"whose sizes are not multiples of 4" % (ow, oh, w, h))
__print_size_warning.has_printed = True
================================================
FILE: data/image_folder.py
================================================
"""A modified image folder class
We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
so that this class can load images from both current directory and its subdirectories.
"""
import torch.utils.data as data
from PIL import Image
import os
import os.path
IMG_EXTENSIONS = [
'.jpg', '.JPG', '.jpeg', '.JPEG',
'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]
def is_image_file(filename):
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
def make_dataset(dir, max_dataset_size=float("inf")):
images = []
assert os.path.isdir(dir), '%s is not a valid directory' % dir
for root, _, fnames in sorted(os.walk(dir)):
for fname in fnames:
if is_image_file(fname):
path = os.path.join(root, fname)
images.append(path)
return images[:min(max_dataset_size, len(images))]
def default_loader(path):
return Image.open(path).convert('RGB')
class ImageFolder(data.Dataset):
def __init__(self, root, transform=None, return_paths=False,
loader=default_loader):
imgs = make_dataset(root)
if len(imgs) == 0:
raise(RuntimeError("Found 0 images in: " + root + "\n"
"Supported image extensions are: " +
",".join(IMG_EXTENSIONS)))
self.root = root
self.imgs = imgs
self.transform = transform
self.return_paths = return_paths
self.loader = loader
def __getitem__(self, index):
path = self.imgs[index]
img = self.loader(path)
if self.transform is not None:
img = self.transform(img)
if self.return_paths:
return img, path
else:
return img
def __len__(self):
return len(self.imgs)
================================================
FILE: data/single_dataset.py
================================================
from data.base_dataset import BaseDataset, get_transform, get_params, get_transform_mask
from data.image_folder import make_dataset
from PIL import Image
import torch
import os
import numpy as np
class SingleDataset(BaseDataset):
"""This dataset class can load a set of images specified by the path --dataroot /path/to/data.
It can be used for generating CycleGAN results only for one side with the model option '-model test'.
"""
def __init__(self, opt):
"""Initialize this dataset class.
Parameters:
opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
"""
BaseDataset.__init__(self, opt)
if os.path.exists(opt.dataroot):
self.A_paths = sorted(make_dataset(opt.dataroot, opt.max_dataset_size))
else:
imglistA = 'datasets/list/%s/%s.txt' % (opt.phase+'A', opt.dataroot)
self.A_paths = sorted(open(imglistA, 'r').read().splitlines())
self.input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc
def __getitem__(self, index):
"""Return a data point and its metadata information.
Parameters:
            index -- a random integer for data indexing
        Returns a dictionary that contains A and A_paths
            A (tensor) -- an image in one domain
            A_paths (str) -- the path of the image
"""
A_path = self.A_paths[index]
A_img = Image.open(A_path).convert('RGB')
self.opt.W, self.opt.H = A_img.size
transform_params_A = get_params(self.opt, A_img.size)
A = get_transform(self.opt, transform_params_A, grayscale=(self.input_nc == 1))(A_img)
item = {'A': A, 'A_paths': A_path}
if self.opt.style_control:
if self.opt.sinput == 'sind':
B_style = torch.Tensor([0.,0.,0.])
B_style[self.opt.sind] = 1.
elif self.opt.sinput == 'svec':
ss = self.opt.svec.split(',')
B_style = torch.Tensor([float(ss[0]),float(ss[1]),float(ss[2])])
elif self.opt.sinput == 'simg':
self.featureloc = os.path.join('style_features/styles2/', self.opt.sfeature_mode)
                B_style = torch.Tensor(np.load(os.path.join(self.featureloc, self.opt.simg[:-4] + '.npy')))
B_style = B_style.view(3, 1, 1)
B_style = B_style.repeat(1, 128, 128)
item['B_style'] = B_style
return item
def __len__(self):
"""Return the total number of images in the dataset."""
return len(self.A_paths)
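# Style-code inputs handled in __getitem__ above (a sketch; --sinput, --sind,
# --svec and --simg are read from self.opt and are expected to be defined in
# the options package, not shown here):
#
#   --sinput sind --sind 1            -> one-hot style code [0, 1, 0]
#   --sinput svec --svec 0.6,0.3,0.1  -> soft weights over the three styles
#   --sinput simg --simg face.png     -> precomputed feature loaded from
#                                        style_features/styles2/<sfeature_mode>/face.npy
#
# In each case the 3-vector is broadcast to a 3x128x128 'B_style' tensor.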
================================================
FILE: data/unaligned_mask_stylecls_dataset.py
================================================
import os.path
from data.base_dataset import BaseDataset, get_params, get_transform, get_transform_mask
from data.image_folder import make_dataset
from PIL import Image
import random
import torch
import torchvision.transforms as transforms
import numpy as np
import pdb
class UnalignedMaskStyleClsDataset(BaseDataset):
def __init__(self, opt):
BaseDataset.__init__(self, opt)
self.dir_A = os.path.join(opt.dataroot, opt.phase + '/A') # create a path '/path/to/data/trainA'
self.dir_B = os.path.join(opt.dataroot, opt.phase + '/B') # create a path '/path/to/data/trainB'
self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size)) # load images from '/path/to/data/trainA'
self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size)) # load images from '/path/to/data/trainB'
self.A_size = len(self.A_paths) # get the size of dataset A
self.B_size = len(self.B_paths) # get the size of dataset B
print("A size:", self.A_size)
print("B size:", self.B_size)
btoA = self.opt.direction == 'BtoA'
self.input_nc = self.opt.output_nc if btoA else self.opt.input_nc # get the number of channels of input image
self.output_nc = self.opt.input_nc if btoA else self.opt.output_nc # get the number of channels of output image
self.auxdir_A = os.path.join(opt.dataroot, "%s/A" % opt.phase)
self.auxdir_B = os.path.join(opt.dataroot, "%s/B" % opt.phase)
def __getitem__(self, index):
        A_path = self.A_paths[index % self.A_size]  # make sure index is within the range
        if self.opt.serial_batches:   # use the same index for domain B to keep pairs fixed
index_B = index % self.B_size
else: # randomize the index for domain B to avoid fixed pairs.
index_B = random.randint(0, self.B_size - 1)
B_path = self.B_paths[index_B]
A_img = Image.open(A_path).convert('RGB')
B_img = Image.open(B_path).convert('RGB')
basenA = os.path.basename(A_path)
A_mask_img = Image.open(os.path.join(self.auxdir_A+'_nose',basenA))
basenB = os.path.basename(B_path)
B_mask_img = Image.open(os.path.join(self.auxdir_B+'_nose',basenB))
if self.opt.use_eye_mask:
A_maske_img = Image.open(os.path.join(self.auxdir_A+'_eyes',basenA))
B_maske_img = Image.open(os.path.join(self.auxdir_B+'_eyes',basenB))
if self.opt.use_lip_mask:
A_maskl_img = Image.open(os.path.join(self.auxdir_A+'_lips',basenA))
B_maskl_img = Image.open(os.path.join(self.auxdir_B+'_lips',basenB))
# apply image transformation
transform_params_A = get_params(self.opt, A_img.size)
transform_params_B = get_params(self.opt, B_img.size)
A = get_transform(self.opt, transform_params_A, grayscale=(self.input_nc == 1))(A_img)
B = get_transform(self.opt, transform_params_B, grayscale=(self.output_nc == 1))(B_img)
A_mask = get_transform_mask(self.opt, transform_params_A, grayscale=1)(A_mask_img)
B_mask = get_transform_mask(self.opt, transform_params_B, grayscale=1)(B_mask_img)
if self.opt.use_eye_mask:
A_maske = get_transform_mask(self.opt, transform_params_A, grayscale=1)(A_maske_img)
B_maske = get_transform_mask(self.opt, transform_params_B, grayscale=1)(B_maske_img)
if self.opt.use_lip_mask:
A_maskl = get_transform_mask(self.opt, transform_params_A, grayscale=1)(A_maskl_img)
B_maskl = get_transform_mask(self.opt, transform_params_B, grayscale=1)(B_maskl_img)
item = {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path, 'A_mask': A_mask, 'B_mask': B_mask}
if self.opt.use_eye_mask:
item['A_maske'] = A_maske
item['B_maske'] = B_maske
if self.opt.use_lip_mask:
item['A_maskl'] = A_maskl
item['B_maskl'] = B_maskl
softmax = np.load(os.path.join(self.auxdir_B+'_feat',basenB[:-4]+'.npy'))
softmax = torch.Tensor(softmax)
        maxv, argmax = torch.max(softmax, 0)  # renamed from 'index' to avoid shadowing the __getitem__ argument
        B_label = argmax
        if self.opt.sfeature_mode.endswith('_softmax'):
            if self.opt.one_hot:
                B_style = torch.Tensor([0., 0., 0.])
                B_style[argmax] = 1.
else:
B_style = softmax
B_style = B_style.view(3, 1, 1)
B_style = B_style.repeat(1, 128, 128)
elif self.opt.sfeature_mode == 'domain':
B_style = B_label
item['B_style'] = B_style
item['B_label'] = B_label
if self.opt.isTrain and self.opt.style_loss_with_weight:
item['B_style0'] = softmax
return item
def __len__(self):
return max(self.A_size, self.B_size)
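# Expected directory layout, inferred from the paths used above (<phase> is
# e.g. 'train'; mask and feature files must share basenames with their images):
#
#   <dataroot>/<phase>/A                        photos
#   <dataroot>/<phase>/B                        drawings
#   <dataroot>/<phase>/A_nose, A_eyes, A_lips   face-region masks for A
#   <dataroot>/<phase>/B_nose, B_eyes, B_lips   face-region masks for B
#   <dataroot>/<phase>/B_feat                   per-drawing style softmax, one <basename>.npy per image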
================================================
FILE: models/__init__.py
================================================
"""This package contains modules related to objective functions, optimizations, and network architectures.
To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
You need to implement the following five functions:
-- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
-- <set_input>: unpack data from dataset and apply preprocessing.
-- <forward>: produce intermediate results.
-- <optimize_parameters>: calculate loss, gradients, and update network weights.
-- <modify_commandline_options>: (optionally) add model-specific options and set default options.
In the function <__init__>, you need to define four lists:
-- self.loss_names (str list): specify the training losses that you want to plot and save.
-- self.model_names (str list): define networks used in our training.
-- self.visual_names (str list): specify the images that you want to display and save.
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
Now you can use the model class by specifying flag '--model dummy'.
See our template model class 'template_model.py' for more details.
"""
import importlib
from models.base_model import BaseModel
def find_model_using_name(model_name):
"""Import the module "models/[model_name]_model.py".
    In the file, the class called ModelNameModel() will
be instantiated. It has to be a subclass of BaseModel,
and it is case-insensitive.
"""
model_filename = "models." + model_name + "_model"
modellib = importlib.import_module(model_filename)
model = None
target_model_name = model_name.replace('_', '') + 'model'
for name, cls in modellib.__dict__.items():
if name.lower() == target_model_name.lower() \
and issubclass(cls, BaseModel):
model = cls
    if model is None:
        raise NotImplementedError("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
return model
def get_option_setter(model_name):
"""Return the static method <modify_commandline_options> of the model class."""
model_class = find_model_using_name(model_name)
return model_class.modify_commandline_options
def create_model(opt):
"""Create a model given the option.
    This function wraps the model class.
This is the main interface between this package and 'train.py'/'test.py'
Example:
>>> from models import create_model
>>> model = create_model(opt)
"""
model = find_model_using_name(opt.model)
instance = model(opt)
print("model [%s] was created" % type(instance).__name__)
return instance
================================================
FILE: models/asymmetric_cycle_gan_cls_model.py
================================================
import torch
import itertools
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks
import models.dist_model as dm # numpy==1.14.3
import torchvision.transforms as transforms
import os
def truncate(fake_B, a=127.5):
    # Quantize a tensor in [-1, 1]: map to [0, 2a], truncate to integers, map back.
    # a=127.5 reproduces 8-bit (256-level) quantization; the default --trunc_a of
    # 31.875 (= 255/8) coarsens this to 64 levels.
    return ((fake_B + 1) * a).int().float() / a - 1
class AsymmetricCycleGANClsModel(BaseModel):
@staticmethod
def modify_commandline_options(parser, is_train=True):
parser.set_defaults(no_dropout=True) # default CycleGAN did not use dropout
parser.set_defaults(dataset_mode='unaligned_mask_stylecls')
        parser.add_argument('--netda', type=str, default='basic_cls', help='net arch for netD_A (basic_cls adds a style classification head)')
parser.add_argument('--netga', type=str, default='resnet_style2_9blocks', help='net arch for netG_A')
parser.add_argument('--model0_res', type=int, default=0, help='number of resblocks in model0 (before insert style)')
parser.add_argument('--model1_res', type=int, default=0, help='number of resblocks in model1 (after insert style, before 2 column merge)')
if is_train:
parser.add_argument('--lambda_A', type=float, default=5.0, help='weight for cycle loss (A -> B -> A)')
parser.add_argument('--lambda_B', type=float, default=5.0, help='weight for cycle loss (B -> A -> B)')
parser.add_argument('--lambda_identity', type=float, default=0, help='use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1')
            parser.add_argument('--ntrunc_trunc', type=int, default=1, help='whether to use both the non-truncated and truncated cycle losses')
            parser.add_argument('--trunc_a', type=float, default=31.875, help='quantization scale used by truncate() for the truncated cycle loss')
parser.add_argument('--lambda_A_trunc', type=float, default=5.0, help='weight for cycle loss for trunc')
parser.add_argument('--hed_pretrained_mode', type=str, default='./checkpoints/network-bsds500.pytorch', help='path to the pretrained hed model')
parser.add_argument('--lambda_G_A_l', type=float, default=0.5, help='weight for local GAN loss in G')
            parser.add_argument('--style_loss_with_weight', type=int, default=1, help='whether to weight the style classification loss by the style softmax probabilities')
# for masks
            parser.add_argument('--use_mask', type=int, default=1, help='whether to use a mask for the nose region')
            parser.add_argument('--use_eye_mask', type=int, default=1, help='whether to use a mask for the eye region')
            parser.add_argument('--use_lip_mask', type=int, default=1, help='whether to use a mask for the lip region')
            parser.add_argument('--mask_type', type=int, default=3, help='mask type: 0 black outside mask, 1 white outside mask, 2 append mask channel, 3 white outside mask + append mask channel')
# for style control
            parser.add_argument('--style_control', type=int, default=1, help='whether to condition the generator on a style code')
parser.add_argument('--sfeature_mode', type=str, default='1vgg19_softmax', help='vgg19 softmax as feature')
parser.add_argument('--one_hot', type=int, default=0, help='use one-hot for style code')
return parser
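    # Example invocation (a sketch; the actual recipe lives in scripts/train.sh,
    # and the dataroot/name values below are placeholders):
    #
    #   python train.py --dataroot ./datasets/portrait_drawing --name pdraw \
    #       --model asymmetric_cycle_gan_cls
    #
    # '--model asymmetric_cycle_gan_cls' resolves to this class through
    # models.find_model_using_name, which then applies the defaults above.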
def __init__(self, opt):
BaseModel.__init__(self, opt)
# specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B']
# specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
visual_names_A = ['real_A', 'fake_B', 'rec_A']
visual_names_B = ['real_B', 'fake_A', 'rec_B']
        if self.isTrain and self.opt.lambda_identity > 0.0:  # if identity loss is used, we also visualize idt_A=G_A(B) and idt_B=G_B(A)
visual_names_A.append('idt_B')
visual_names_B.append('idt_A')
if self.isTrain:
visual_names_A.append('real_A_hed')
visual_names_A.append('rec_A_hed')
if self.isTrain and self.opt.ntrunc_trunc:
visual_names_A.append('rec_At')
visual_names_A.append('rec_At_hed')
self.loss_names = ['D_A', 'G_A', 'cycle_A', 'cycle_A2', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B', 'G']
if self.isTrain and self.opt.use_mask:
visual_names_A.append('fake_B_l')
visual_names_A.append('real_B_l')
self.loss_names += ['D_A_l', 'G_A_l']
if self.isTrain and self.opt.use_eye_mask:
visual_names_A.append('fake_B_le')
visual_names_A.append('real_B_le')
self.loss_names += ['D_A_le', 'G_A_le']
if self.isTrain and self.opt.use_lip_mask:
visual_names_A.append('fake_B_ll')
visual_names_A.append('real_B_ll')
self.loss_names += ['D_A_ll', 'G_A_ll']
if not self.isTrain and self.opt.use_mask:
visual_names_A.append('fake_B_l')
visual_names_A.append('real_B_l')
if not self.isTrain and self.opt.use_eye_mask:
visual_names_A.append('fake_B_le')
visual_names_A.append('real_B_le')
if not self.isTrain and self.opt.use_lip_mask:
visual_names_A.append('fake_B_ll')
visual_names_A.append('real_B_ll')
self.loss_names += ['D_A_cls','G_A_cls']
self.visual_names = visual_names_A + visual_names_B # combine visualizations for A and B
print(self.visual_names)
# specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>.
if self.isTrain:
self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
if self.opt.use_mask:
self.model_names += ['D_A_l']
if self.opt.use_eye_mask:
self.model_names += ['D_A_le']
if self.opt.use_lip_mask:
self.model_names += ['D_A_ll']
else: # during test time, only load Gs
self.model_names = ['G_A', 'G_B']
# define networks (both Generators and discriminators)
# The naming is different from those used in the paper.
# Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
if not self.opt.style_control:
self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
else:
print(opt.netga)
print('model0_res', opt.model0_res)
print('model1_res', opt.model1_res)
self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netga, opt.norm,
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.model0_res, opt.model1_res)
self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.norm,
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
if self.isTrain: # define discriminators
self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netda,
opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids, n_class=3)
self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
if self.opt.use_mask:
if self.opt.mask_type in [2, 3]:
output_nc = opt.output_nc + 1
else:
output_nc = opt.output_nc
self.netD_A_l = networks.define_D(output_nc, opt.ndf, opt.netD,
opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
if self.opt.use_eye_mask:
if self.opt.mask_type in [2, 3]:
output_nc = opt.output_nc + 1
else:
output_nc = opt.output_nc
self.netD_A_le = networks.define_D(output_nc, opt.ndf, opt.netD,
opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
if self.opt.use_lip_mask:
if self.opt.mask_type in [2, 3]:
output_nc = opt.output_nc + 1
else:
output_nc = opt.output_nc
self.netD_A_ll = networks.define_D(output_nc, opt.ndf, opt.netD,
opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
if not self.isTrain:
self.criterionGAN = networks.GANLoss('lsgan').to(self.device)
if self.isTrain:
if opt.lambda_identity > 0.0: # only works when input and output images have the same number of channels
assert(opt.input_nc == opt.output_nc)
self.fake_A_pool = ImagePool(opt.pool_size) # create image buffer to store previously generated images
self.fake_B_pool = ImagePool(opt.pool_size) # create image buffer to store previously generated images
# define loss functions
self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device) # define GAN loss.
self.criterionCycle = torch.nn.L1Loss()
self.criterionIdt = torch.nn.L1Loss()
self.criterionCls = torch.nn.CrossEntropyLoss()
self.criterionCls2 = torch.nn.CrossEntropyLoss(reduction='none')
# initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))
if not self.opt.use_mask:
self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))
elif not self.opt.use_eye_mask:
D_params = list(self.netD_A.parameters()) + list(self.netD_B.parameters()) + list(self.netD_A_l.parameters())
self.optimizer_D = torch.optim.Adam(D_params, lr=opt.lr, betas=(opt.beta1, 0.999))
elif not self.opt.use_lip_mask:
D_params = list(self.netD_A.parameters()) + list(self.netD_B.parameters()) + list(self.netD_A_l.parameters()) + list(self.netD_A_le.parameters())
self.optimizer_D = torch.optim.Adam(D_params, lr=opt.lr, betas=(opt.beta1, 0.999))
else:
D_params = list(self.netD_A.parameters()) + list(self.netD_B.parameters()) + list(self.netD_A_l.parameters()) + list(self.netD_A_le.parameters()) + list(self.netD_A_ll.parameters())
self.optimizer_D = torch.optim.Adam(D_params, lr=opt.lr, betas=(opt.beta1, 0.999))
self.optimizers.append(self.optimizer_G)
self.optimizers.append(self.optimizer_D)
self.lpips = dm.DistModel(opt,model='net-lin',net='alex',use_gpu=True)
self.hed = networks.define_HED(init_weights_=opt.hed_pretrained_mode, gpu_ids_=self.opt.gpu_ids_p)
self.set_requires_grad(self.hed, False)
def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Parameters:
input (dict): include the data itself and its metadata information.
The option 'direction' can be used to swap domain A and domain B.
"""
AtoB = self.opt.direction == 'AtoB'
self.real_A = input['A' if AtoB else 'B'].to(self.device)
self.real_B = input['B' if AtoB else 'A'].to(self.device)
self.image_paths = input['A_paths' if AtoB else 'B_paths']
if self.opt.use_mask:
self.A_mask = input['A_mask'].to(self.device)
self.B_mask = input['B_mask'].to(self.device)
if self.opt.use_eye_mask:
self.A_maske = input['A_maske'].to(self.device)
self.B_maske = input['B_maske'].to(self.device)
if self.opt.use_lip_mask:
self.A_maskl = input['A_maskl'].to(self.device)
self.B_maskl = input['B_maskl'].to(self.device)
if self.opt.style_control:
self.real_B_style = input['B_style'].to(self.device)
self.real_B_label = input['B_label'].to(self.device)
if self.opt.isTrain and self.opt.style_loss_with_weight:
self.real_B_style0 = input['B_style0'].to(self.device)
self.zero = torch.zeros(self.real_B_label.size(),dtype=torch.int64).to(self.device)
self.one = torch.ones(self.real_B_label.size(),dtype=torch.int64).to(self.device)
self.two = 2*torch.ones(self.real_B_label.size(),dtype=torch.int64).to(self.device)
def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
if not self.opt.style_control:
self.fake_B = self.netG_A(self.real_A) # G_A(A)
else:
#print(torch.mean(self.real_B_style,(2,3)),'style_control')
self.fake_B = self.netG_A(self.real_A, self.real_B_style)
self.rec_A = self.netG_B(self.fake_B) # G_B(G_A(A))
self.fake_A = self.netG_B(self.real_B) # G_B(B)
if not self.opt.style_control:
self.rec_B = self.netG_A(self.fake_A) # G_A(G_B(B))
else:
#print(torch.mean(self.real_B_style,(2,3)),'style_control')
self.rec_B = self.netG_A(self.fake_A, self.real_B_style) # -- cycle_B loss
if self.opt.use_mask:
self.fake_B_l = self.masked(self.fake_B,self.A_mask)
self.real_B_l = self.masked(self.real_B,self.B_mask)
if self.opt.use_eye_mask:
self.fake_B_le = self.masked(self.fake_B,self.A_maske)
self.real_B_le = self.masked(self.real_B,self.B_maske)
if self.opt.use_lip_mask:
self.fake_B_ll = self.masked(self.fake_B,self.A_maskl)
self.real_B_ll = self.masked(self.real_B,self.B_maskl)
def backward_D_basic(self, netD, real, fake):
"""Calculate GAN loss for the discriminator
Parameters:
netD (network) -- the discriminator D
real (tensor array) -- real images
fake (tensor array) -- images generated by a generator
Return the discriminator loss.
We also call loss_D.backward() to calculate the gradients.
"""
# Real
pred_real = netD(real)
loss_D_real = self.criterionGAN(pred_real, True)
# Fake
pred_fake = netD(fake.detach())
loss_D_fake = self.criterionGAN(pred_fake, False)
# Combined loss and calculate gradients
loss_D = (loss_D_real + loss_D_fake) * 0.5
loss_D.backward()
return loss_D
def backward_D_basic_cls(self, netD, real, fake):
# Real
pred_real, pred_real_cls = netD(real)
loss_D_real = self.criterionGAN(pred_real, True)
if not self.opt.style_loss_with_weight:
loss_D_real_cls = self.criterionCls(pred_real_cls, self.real_B_label)
else:
loss_D_real_cls = torch.mean(self.real_B_style0[:,0] * self.criterionCls2(pred_real_cls, self.zero) + self.real_B_style0[:,1] * self.criterionCls2(pred_real_cls, self.one) + self.real_B_style0[:,2] * self.criterionCls2(pred_real_cls, self.two))
# Fake
pred_fake, pred_fake_cls = netD(fake.detach())
loss_D_fake = self.criterionGAN(pred_fake, False)
if not self.opt.style_loss_with_weight:
loss_D_fake_cls = self.criterionCls(pred_fake_cls, self.real_B_label)
else:
loss_D_fake_cls = torch.mean(self.real_B_style0[:,0] * self.criterionCls2(pred_fake_cls, self.zero) + self.real_B_style0[:,1] * self.criterionCls2(pred_fake_cls, self.one) + self.real_B_style0[:,2] * self.criterionCls2(pred_fake_cls, self.two))
# Combined loss and calculate gradients
loss_D = (loss_D_real + loss_D_fake) * 0.5
loss_D_cls = (loss_D_real_cls + loss_D_fake_cls) * 0.5
loss_D_total = loss_D + loss_D_cls
loss_D_total.backward()
return loss_D, loss_D_cls
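    # When style_loss_with_weight is set, the class term above is the expected
    # per-class cross-entropy under the soft style label B_style0, i.e.
    # sum_k p_k * CE(pred, k), which is exactly the cross-entropy H(p, q)
    # between the soft label p and the predicted distribution q, rather than
    # the loss against the argmax label alone.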
def backward_D_A(self):
"""Calculate GAN loss for discriminator D_A"""
fake_B = self.fake_B_pool.query(self.fake_B)
self.loss_D_A, self.loss_D_A_cls = self.backward_D_basic_cls(self.netD_A, self.real_B, fake_B)
def backward_D_A_l(self):
"""Calculate GAN loss for discriminator D_A_l"""
fake_B = self.fake_B_pool.query(self.fake_B)
self.loss_D_A_l = self.backward_D_basic(self.netD_A_l, self.masked(self.real_B,self.B_mask), self.masked(fake_B,self.A_mask))
def backward_D_A_le(self):
"""Calculate GAN loss for discriminator D_A_le"""
fake_B = self.fake_B_pool.query(self.fake_B)
self.loss_D_A_le = self.backward_D_basic(self.netD_A_le, self.masked(self.real_B,self.B_maske), self.masked(fake_B,self.A_maske))
def backward_D_A_ll(self):
"""Calculate GAN loss for discriminator D_A_ll"""
fake_B = self.fake_B_pool.query(self.fake_B)
self.loss_D_A_ll = self.backward_D_basic(self.netD_A_ll, self.masked(self.real_B,self.B_maskl), self.masked(fake_B,self.A_maskl))
def backward_D_B(self):
"""Calculate GAN loss for discriminator D_B"""
fake_A = self.fake_A_pool.query(self.fake_A)
self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
def update_process(self, epoch):
self.process = (epoch - 1) / float(self.opt.niter_decay + self.opt.niter)
def backward_G(self):
"""Calculate the loss for generators G_A and G_B"""
lambda_idt = self.opt.lambda_identity
lambda_G_A_l = self.opt.lambda_G_A_l
lambda_A = self.opt.lambda_A
lambda_B = self.opt.lambda_B
lambda_A_trunc = self.opt.lambda_A_trunc
if self.opt.ntrunc_trunc:
lambda_A = lambda_A * (1 - self.process * 0.9)
lambda_A_trunc = lambda_A_trunc * self.process * 0.9
self.lambda_As = [lambda_A, lambda_A_trunc]
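            # Schedule: as training progresses (self.process goes 0 -> 1), weight
            # shifts linearly from the plain cycle loss (lambda_A) towards the
            # truncated one (lambda_A_trunc); 10% of the original lambda_A remains
            # at the end of training.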
# Identity loss
if lambda_idt > 0:
# G_A should be identity if real_B is fed: ||G_A(B) - B||
self.idt_A = self.netG_A(self.real_B)
self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
# G_B should be identity if real_A is fed: ||G_B(A) - A||
self.idt_B = self.netG_B(self.real_A)
self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
else:
self.loss_idt_A = 0
self.loss_idt_B = 0
# GAN loss D_A(G_A(A))
pred_fake, pred_fake_cls = self.netD_A(self.fake_B)
self.loss_G_A = self.criterionGAN(pred_fake, True)
if not self.opt.style_loss_with_weight:
self.loss_G_A_cls = self.criterionCls(pred_fake_cls, self.real_B_label)
else:
self.loss_G_A_cls = torch.mean(self.real_B_style0[:,0] * self.criterionCls2(pred_fake_cls, self.zero) + self.real_B_style0[:,1] * self.criterionCls2(pred_fake_cls, self.one) + self.real_B_style0[:,2] * self.criterionCls2(pred_fake_cls, self.two))
if self.opt.use_mask:
self.loss_G_A_l = self.criterionGAN(self.netD_A_l(self.fake_B_l), True) * lambda_G_A_l
if self.opt.use_eye_mask:
self.loss_G_A_le = self.criterionGAN(self.netD_A_le(self.fake_B_le), True) * lambda_G_A_l
if self.opt.use_lip_mask:
self.loss_G_A_ll = self.criterionGAN(self.netD_A_ll(self.fake_B_ll), True) * lambda_G_A_l
# GAN loss D_B(G_B(B))
self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
# Forward cycle loss LPIPS( HED(G_B(G_A(A))), HED(A))
ts = self.real_A.shape
gpu_p = self.opt.gpu_ids_p[0]
gpu = self.opt.gpu_ids[0]
rec_A_hed = (self.hed(self.rec_A.cuda(gpu_p)/2+0.5)-0.5)*2
real_A_hed = (self.hed(self.real_A.cuda(gpu_p)/2+0.5)-0.5)*2
self.loss_cycle_A = (self.lpips.forward_pair(rec_A_hed.expand(ts), real_A_hed.expand(ts)).mean()).cuda(gpu) * lambda_A
self.rec_A_hed = rec_A_hed
self.real_A_hed = real_A_hed
if self.opt.ntrunc_trunc:
self.rec_At = self.netG_B(truncate(self.fake_B,self.opt.trunc_a))
rec_At_hed = (self.hed(self.rec_At.cuda(gpu_p)/2+0.5)-0.5)*2
self.loss_cycle_A2 = (self.lpips.forward_pair(rec_At_hed.expand(ts), real_A_hed.expand(ts)).mean()).cuda(gpu) * lambda_A_trunc
self.rec_At_hed = rec_At_hed
# Backward cycle loss || G_A(G_B(B)) - B||
self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
# combined loss and calculate gradients
self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
if getattr(self,'loss_cycle_A2',-1) != -1:
self.loss_G = self.loss_G + self.loss_cycle_A2
if getattr(self,'loss_G_A_l',-1) != -1:
self.loss_G = self.loss_G + self.loss_G_A_l
if getattr(self,'loss_G_A_le',-1) != -1:
self.loss_G = self.loss_G + self.loss_G_A_le
if getattr(self,'loss_G_A_ll',-1) != -1:
self.loss_G = self.loss_G + self.loss_G_A_ll
if getattr(self,'loss_G_A_cls',-1) != -1:
self.loss_G = self.loss_G + self.loss_G_A_cls
self.loss_G.backward()
def optimize_parameters(self):
"""Calculate losses, gradients, and update network weights; called in every training iteration"""
# forward
self.forward() # compute fake images and reconstruction images.
# G_A and G_B
self.set_requires_grad([self.netD_A, self.netD_B], False) # Ds require no gradients when optimizing Gs
if self.opt.use_mask:
self.set_requires_grad([self.netD_A_l], False)
if self.opt.use_eye_mask:
self.set_requires_grad([self.netD_A_le], False)
if self.opt.use_lip_mask:
self.set_requires_grad([self.netD_A_ll], False)
self.optimizer_G.zero_grad() # set G_A and G_B's gradients to zero
self.backward_G() # calculate gradients for G_A and G_B
self.optimizer_G.step() # update G_A and G_B's weights
# D_A and D_B
self.set_requires_grad([self.netD_A, self.netD_B], True)
if self.opt.use_mask:
self.set_requires_grad([self.netD_A_l], True)
if self.opt.use_eye_mask:
self.set_requires_grad([self.netD_A_le], True)
if self.opt.use_lip_mask:
self.set_requires_grad([self.netD_A_ll], True)
self.optimizer_D.zero_grad() # set D_A and D_B's gradients to zero
self.backward_D_A() # calculate gradients for D_A
if self.opt.use_mask:
self.backward_D_A_l()# calculate gradients for D_A_l
if self.opt.use_eye_mask:
self.backward_D_A_le()# calculate gradients for D_A_le
if self.opt.use_lip_mask:
self.backward_D_A_ll()# calculate gradients for D_A_ll
        self.backward_D_B()      # calculate gradients for D_B
self.optimizer_D.step() # update D_A and D_B's weights
================================================
FILE: models/base_model.py
================================================
import os
import torch
from collections import OrderedDict
from abc import ABCMeta, abstractmethod
from . import networks
class BaseModel():
__metaclass__ = ABCMeta
"""This class is an abstract base class (ABC) for models.
To create a subclass, you need to implement the following five functions:
-- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
-- <set_input>: unpack data from dataset and apply preprocessing.
-- <forward>: produce intermediate results.
-- <optimize_parameters>: calculate losses, gradients, and update network weights.
-- <modify_commandline_options>: (optionally) add model-specific options and set default options.
"""
def __init__(self, opt):
"""Initialize the BaseModel class.
Parameters:
opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
When creating your custom class, you need to implement your own initialization.
        In this function, you should first call <BaseModel.__init__(self, opt)>
        Then, you need to define four lists:
            -- self.loss_names (str list): specify the training losses that you want to plot and save.
            -- self.model_names (str list): define networks used in our training.
            -- self.visual_names (str list): specify the images that you want to display and save.
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
"""
self.opt = opt
self.gpu_ids = opt.gpu_ids
self.isTrain = opt.isTrain
self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU
self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir
if opt.preprocess != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
torch.backends.cudnn.benchmark = True
self.loss_names = []
self.model_names = []
self.visual_names = []
self.optimizers = []
self.image_paths = []
self.metric = 0 # used for learning rate policy 'plateau'
@staticmethod
def modify_commandline_options(parser, is_train):
"""Add new model-specific options, and rewrite default values for existing options.
Parameters:
parser -- original option parser
is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
Returns:
the modified parser.
"""
return parser
@abstractmethod
def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Parameters:
input (dict): includes the data itself and its metadata information.
"""
pass
@abstractmethod
def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
pass
@abstractmethod
def optimize_parameters(self):
"""Calculate losses, gradients, and update network weights; called in every training iteration"""
pass
def setup(self, opt):
"""Load and print networks; create schedulers
Parameters:
opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
"""
if self.isTrain:
self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
if not self.isTrain or opt.continue_train:
load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch
self.load_networks(load_suffix)
self.print_networks(opt.verbose)
def eval(self):
"""Make models eval mode during test time"""
for name in self.model_names:
if isinstance(name, str):
net = getattr(self, 'net' + name)
net.eval()
def test(self):
"""Forward function used in test time.
This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
It also calls <compute_visuals> to produce additional visualization results
"""
with torch.no_grad():
self.forward()
self.compute_visuals()
def compute_visuals(self):
"""Calculate additional output images for visdom and HTML visualization"""
pass
def get_image_paths(self):
""" Return image paths that are used to load current data"""
return self.image_paths
def update_learning_rate(self):
"""Update learning rates for all the networks; called at the end of every epoch"""
for scheduler in self.schedulers:
if self.opt.lr_policy == 'plateau':
scheduler.step(self.metric)
else:
scheduler.step()
lr = self.optimizers[0].param_groups[0]['lr']
print('learning rate = %.7f' % lr)
def get_current_visuals(self):
"""Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
visual_ret = OrderedDict()
for name in self.visual_names:
if isinstance(name, str):
visual_ret[name] = getattr(self, name)
return visual_ret
def get_current_losses(self):
"""Return traning losses / errors. train.py will print out these errors on console, and save them to a file"""
errors_ret = OrderedDict()
for name in self.loss_names:
if isinstance(name, str):
errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number
return errors_ret
def save_networks(self, epoch):
"""Save all the networks to the disk.
Parameters:
epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
"""
for name in self.model_names:
if isinstance(name, str):
save_filename = '%s_net_%s.pth' % (epoch, name)
save_path = os.path.join(self.save_dir, save_filename)
net = getattr(self, 'net' + name)
if len(self.gpu_ids) > 0 and torch.cuda.is_available():
torch.save(net.module.cpu().state_dict(), save_path)
net.cuda(self.gpu_ids[0])
else:
torch.save(net.cpu().state_dict(), save_path)
def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
"""Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
key = keys[i]
if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
if module.__class__.__name__.startswith('InstanceNorm') and \
(key == 'running_mean' or key == 'running_var'):
if getattr(module, key) is None:
state_dict.pop('.'.join(keys))
if module.__class__.__name__.startswith('InstanceNorm') and \
(key == 'num_batches_tracked'):
state_dict.pop('.'.join(keys))
else:
self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
def load_networks(self, epoch):
"""Load all the networks from the disk.
Parameters:
epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
"""
for name in self.model_names:
if isinstance(name, str):
load_filename = '%s_net_%s.pth' % (epoch, name)
load_path = os.path.join(self.save_dir, load_filename)
net = getattr(self, 'net' + name)
if isinstance(net, torch.nn.DataParallel):
net = net.module
print('loading the model from %s' % load_path)
# if you are using PyTorch newer than 0.4 (e.g., built from
# GitHub source), you can remove str() on self.device
state_dict = torch.load(load_path, map_location=str(self.device))
if hasattr(state_dict, '_metadata'):
del state_dict._metadata
# patch InstanceNorm checkpoints prior to 0.4
for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
net.load_state_dict(state_dict)
def print_networks(self, verbose):
"""Print the total number of parameters in the network and (if verbose) network architecture
Parameters:
verbose (bool) -- if verbose: print the network architecture
"""
print('---------- Networks initialized -------------')
for name in self.model_names:
if isinstance(name, str):
net = getattr(self, 'net' + name)
num_params = 0
for param in net.parameters():
num_params += param.numel()
if verbose:
print(net)
print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
print('-----------------------------------------------')
def set_requires_grad(self, nets, requires_grad=False):
"""Set requies_grad=Fasle for all the networks to avoid unnecessary computations
Parameters:
nets (network list) -- a list of networks
requires_grad (bool) -- whether the networks require gradients or not
"""
if not isinstance(nets, list):
nets = [nets]
for net in nets:
if net is not None:
for param in net.parameters():
param.requires_grad = requires_grad
# ===========================================================================================================
def masked(self, A,mask):
if self.opt.mask_type == 0:
return (A/2+0.5)*mask*2-1
elif self.opt.mask_type == 1:
return ((A/2+0.5)*mask+1-mask)*2-1
elif self.opt.mask_type == 2:
return torch.cat((A, mask), 1)
elif self.opt.mask_type == 3:
masked = ((A/2+0.5)*mask+1-mask)*2-1
return torch.cat((masked, mask), 1)
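    # mask_type semantics (cf. --mask_type): 0 = black outside the mask,
    # 1 = white outside the mask, 2 = append the mask as an extra channel,
    # 3 = white outside the mask and append the mask channel. Types 2 and 3
    # are why the local discriminators are built with output_nc + 1 input
    # channels.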
================================================
FILE: models/dist_model.py
================================================
from __future__ import absolute_import
import sys
sys.path.append('..')
sys.path.append('.')
import numpy as np
import torch
from torch import nn
from collections import OrderedDict
from torch.autograd import Variable
from .base_model import BaseModel
from scipy.ndimage import zoom
import skimage.transform
from . import networks_basic as networks
# from PerceptualSimilarity.util import util
from util import util
class DistModel(BaseModel):
def name(self):
return self.model_name
def __init__(self, opt, model='net-lin', net='alex', pnet_rand=False, pnet_tune=False, model_path=None, colorspace='Lab', use_gpu=True, printNet=False, spatial=False, spatial_shape=None, spatial_order=1, spatial_factor=None, is_train=False, lr=.0001, beta1=0.5, version='0.1'):
'''
INPUTS
model - ['net-lin'] for linearly calibrated network
['net'] for off-the-shelf network
['L2'] for L2 distance in Lab colorspace
['SSIM'] for ssim in RGB colorspace
net - ['squeeze','alex','vgg']
model_path - if None, will look in weights/[NET_NAME].pth
colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM
use_gpu - bool - whether or not to use a GPU
printNet - bool - whether or not to print network architecture out
spatial - bool - whether to output an array containing varying distances across spatial dimensions
spatial_shape - if given, output spatial shape. if None then spatial shape is determined automatically via spatial_factor (see below).
spatial_factor - if given, specifies upsampling factor relative to the largest spatial extent of a convolutional layer. if None then resized to size of input images.
spatial_order - spline order of filter for upsampling in spatial mode, by default 1 (bilinear).
is_train - bool - [True] for training mode
lr - float - initial learning rate
beta1 - float - initial momentum term for adam
version - 0.1 for latest, 0.0 was original
'''
BaseModel.__init__(self, opt)
self.model = model
self.net = net
self.use_gpu = use_gpu
self.is_train = is_train
self.spatial = spatial
self.spatial_shape = spatial_shape
self.spatial_order = spatial_order
self.spatial_factor = spatial_factor
self.model_name = '%s [%s]'%(model,net)
if(self.model == 'net-lin'): # pretrained net + linear layer
#self.device = torch.device('cuda:{}'.format(opt.gpu_ids[0])) if opt.gpu_ids else torch.device('cpu')
self.device = torch.device('cuda:{}'.format(opt.gpu_ids_p[0])) if opt.gpu_ids_p else torch.device('cpu')
self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net,use_dropout=True,spatial=spatial,version=version,lpips=True).to(self.device)
kw = {}
if not use_gpu:
kw['map_location'] = 'cpu'
if(model_path is None):
import inspect
#model_path = os.path.abspath(os.path.join(inspect.getfile(self.initialize), '..', '..', 'weights/v%s/%s.pth'%(version,net)))
model_path = './checkpoints/weights/v%s/%s.pth'%(version,net)
if(not is_train):
print('Loading model from: %s'%model_path)
#self.net.load_state_dict(torch.load(model_path, **kw))
state_dict = torch.load(model_path, map_location=str(self.device))
self.net.load_state_dict(state_dict, strict=False)
elif(self.model=='net'): # pretrained network
assert not self.spatial, 'spatial argument not supported yet for uncalibrated networks'
self.net = networks.PNet(use_gpu=use_gpu,pnet_type=net,device=self.device)
self.is_fake_net = True
elif(self.model in ['L2','l2']):
self.net = networks.L2(use_gpu=use_gpu,colorspace=colorspace,device=self.device) # not really a network, only for testing
self.model_name = 'L2'
elif(self.model in ['DSSIM','dssim','SSIM','ssim']):
self.net = networks.DSSIM(use_gpu=use_gpu,colorspace=colorspace,device=self.device)
self.model_name = 'SSIM'
else:
raise ValueError("Model [%s] not recognized." % self.model)
self.parameters = list(self.net.parameters())
if self.is_train: # training mode
# extra network on top to go from distances (d0,d1) => predicted human judgment (h*)
self.rankLoss = networks.BCERankingLoss(use_gpu=use_gpu,device=self.device)
self.parameters+=self.rankLoss.parameters
self.lr = lr
self.old_lr = lr
self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999))
else: # test mode
self.net.eval()
if(printNet):
print('---------- Networks initialized -------------')
networks.print_network(self.net)
print('-----------------------------------------------')
def forward_pair(self,in1,in2,retPerLayer=False):
if(retPerLayer):
return self.net.forward(in1,in2, retPerLayer=True)
else:
return self.net.forward(in1,in2)
def forward(self, in0, in1, retNumpy=False):
''' Function computes the distance between image patches in0 and in1
INPUTS
in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1]
retNumpy - [False] to return as torch.Tensor, [True] to return as numpy array
OUTPUT
computed distances between in0 and in1
'''
self.input_ref = in0
self.input_p0 = in1
self.var_ref = Variable(self.input_ref,requires_grad=True)
self.var_p0 = Variable(self.input_p0,requires_grad=True)
self.d0 = self.forward_pair(self.var_ref, self.var_p0)
self.loss_total = self.d0
def convert_output(d0):
if(retNumpy):
ans = d0.cpu().data.numpy()
if not self.spatial:
ans = ans.flatten()
else:
assert(ans.shape[0] == 1 and len(ans.shape) == 4)
return ans[0,...].transpose([1, 2, 0]) # Reshape to usual numpy image format: (height, width, channels)
return ans
else:
return d0
if self.spatial:
L = [convert_output(x) for x in self.d0]
spatial_shape = self.spatial_shape
if spatial_shape is None:
if(self.spatial_factor is None):
spatial_shape = (in0.size()[2],in0.size()[3])
else:
spatial_shape = (max([x.shape[0] for x in L])*self.spatial_factor, max([x.shape[1] for x in L])*self.spatial_factor)
L = [skimage.transform.resize(x, spatial_shape, order=self.spatial_order, mode='edge') for x in L]
L = np.mean(np.concatenate(L, 2) * len(L), 2)
return L
else:
return convert_output(self.d0)
# ***** TRAINING FUNCTIONS *****
def optimize_parameters(self):
self.forward_train()
self.optimizer_net.zero_grad()
self.backward_train()
self.optimizer_net.step()
self.clamp_weights()
def clamp_weights(self):
for module in self.net.modules():
if(hasattr(module, 'weight') and module.kernel_size==(1,1)):
module.weight.data = torch.clamp(module.weight.data,min=0)
def set_input(self, data):
self.input_ref = data['ref']
self.input_p0 = data['p0']
self.input_p1 = data['p1']
self.input_judge = data['judge']
if(self.use_gpu):
self.input_ref = self.input_ref.cuda(self.device)
self.input_p0 = self.input_p0.cuda(self.device)
self.input_p1 = self.input_p1.cuda(self.device)
self.input_judge = self.input_judge.cuda(self.device)
self.var_ref = Variable(self.input_ref,requires_grad=True)
self.var_p0 = Variable(self.input_p0,requires_grad=True)
self.var_p1 = Variable(self.input_p1,requires_grad=True)
def forward_train(self): # run forward pass
self.d0 = self.forward_pair(self.var_ref, self.var_p0)
self.d1 = self.forward_pair(self.var_ref, self.var_p1)
self.acc_r = self.compute_accuracy(self.d0,self.d1,self.input_judge)
# var_judge
self.var_judge = Variable(1.*self.input_judge).view(self.d0.size())
self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge*2.-1.)
return self.loss_total
def backward_train(self):
torch.mean(self.loss_total).backward()
def compute_accuracy(self,d0,d1,judge):
''' d0, d1 are Variables, judge is a Tensor '''
d1_lt_d0 = (d1<d0).cpu().data.numpy().flatten()
judge_per = judge.cpu().numpy().flatten()
return d1_lt_d0*judge_per + (1-d1_lt_d0)*(1-judge_per)
def get_current_errors(self):
retDict = OrderedDict([('loss_total', self.loss_total.data.cpu().numpy()),
('acc_r', self.acc_r)])
for key in retDict.keys():
retDict[key] = np.mean(retDict[key])
return retDict
def get_current_visuals(self):
zoom_factor = 256/self.var_ref.data.size()[2]
ref_img = util.tensor2im(self.var_ref.data)
p0_img = util.tensor2im(self.var_p0.data)
p1_img = util.tensor2im(self.var_p1.data)
ref_img_vis = zoom(ref_img,[zoom_factor, zoom_factor, 1],order=0)
p0_img_vis = zoom(p0_img,[zoom_factor, zoom_factor, 1],order=0)
p1_img_vis = zoom(p1_img,[zoom_factor, zoom_factor, 1],order=0)
return OrderedDict([('ref', ref_img_vis),
('p0', p0_img_vis),
('p1', p1_img_vis)])
def save(self, path, label):
self.save_network(self.net, path, '', label)
self.save_network(self.rankLoss.net, path, 'rank', label)
def update_learning_rate(self,nepoch_decay):
lrd = self.lr / nepoch_decay
lr = self.old_lr - lrd
for param_group in self.optimizer_net.param_groups:
param_group['lr'] = lr
print('update lr decay: %f -> %f' % (self.old_lr, lr))
self.old_lr = lr
def score_2afc_dataset(data_loader,func):
''' Function computes Two Alternative Forced Choice (2AFC) score using
distance function 'func' in dataset 'data_loader'
INPUTS
data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside
func - callable distance function - calling d=func(in0,in1) should take 2
pytorch tensors with shape Nx3xXxY, and return numpy array of length N
OUTPUTS
[0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators
[1] - dictionary with following elements
d0s,d1s - N arrays containing distances between reference patch to perturbed patches
gts - N array in [0,1], preferred patch selected by human evaluators
(closer to "0" for left patch p0, "1" for right patch p1,
"0.6" means 60pct people preferred right patch, 40pct preferred left)
scores - N array in [0,1], corresponding to what percentage function agreed with humans
CONSTS
N - number of test triplets in data_loader
'''
d0s = []
d1s = []
gts = []
# bar = pb.ProgressBar(max_value=data_loader.load_data().__len__())
for (i,data) in enumerate(data_loader.load_data()):
d0s+=func(data['ref'],data['p0']).tolist()
d1s+=func(data['ref'],data['p1']).tolist()
gts+=data['judge'].cpu().numpy().flatten().tolist()
# bar.update(i)
d0s = np.array(d0s)
d1s = np.array(d1s)
gts = np.array(gts)
scores = (d0s<d1s)*(1.-gts) + (d1s<d0s)*gts + (d1s==d0s)*.5
return(np.mean(scores), dict(d0s=d0s,d1s=d1s,gts=gts,scores=scores))
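# Hedged usage sketch for the 2AFC score above: 'model' stands for an already-
# initialized instance of the distance model defined in this file, and 'loader'
# for a CustomDatasetDataLoader wrapping a TwoAFCDataset (both are placeholders):
#   score, stats = score_2afc_dataset(
#       loader, lambda in0, in1: model.forward(in0, in1, retNumpy=True))
#   # score: fraction of triplets where the metric agrees with human judgments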
def score_jnd_dataset(data_loader,func):
''' Function computes JND score using distance function 'func' in dataset 'data_loader'
INPUTS
data_loader - CustomDatasetDataLoader object - contains a JNDDataset inside
func - callable distance function - calling d=func(in0,in1) should take 2
pytorch tensors with shape Nx3xXxY, and return numpy array of length N
OUTPUTS
[0] - JND score in [0,1], mAP score (area under precision-recall curve)
[1] - dictionary with following elements
ds - N array containing distances between two patches shown to human evaluator
sames - N array containing fraction of people who thought the two patches were identical
CONSTS
N - number of test triplets in data_loader
'''
ds = []
gts = []
# bar = pb.ProgressBar(max_value=data_loader.load_data().__len__())
for (i,data) in enumerate(data_loader.load_data()):
ds+=func(data['p0'],data['p1']).tolist()
gts+=data['same'].cpu().numpy().flatten().tolist()
# bar.update(i)
sames = np.array(gts)
ds = np.array(ds)
sorted_inds = np.argsort(ds)
ds_sorted = ds[sorted_inds]
sames_sorted = sames[sorted_inds]
TPs = np.cumsum(sames_sorted)
FPs = np.cumsum(1-sames_sorted)
FNs = np.sum(sames_sorted)-TPs
precs = TPs/(TPs+FPs)
recs = TPs/(TPs+FNs)
score = util.voc_ap(recs,precs)
return(score, dict(ds=ds,sames=sames))
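# Worked example (illustrative numbers, not repo data) of the precision/recall
# sweep above:
#   ds = [0.1, 0.4, 0.2], sames = [1, 0, 1]
#   sorted by distance -> sames_sorted = [1, 1, 0]
#   TPs = [1, 2, 2], FPs = [0, 0, 1], FNs = [1, 0, 0]
#   precs = [1.0, 1.0, 2/3], recs = [0.5, 1.0, 1.0]
#   util.voc_ap then integrates precision over recall to give the mAP score.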
================================================
FILE: models/networks.py
================================================
import torch
import torch.nn as nn
from torch.nn import init
import functools
from torch.optim import lr_scheduler
###############################################################################
# Helper Functions
###############################################################################
class Identity(nn.Module):
def forward(self, x):
return x
def get_norm_layer(norm_type='instance'):
"""Return a normalization layer
Parameters:
norm_type (str) -- the name of the normalization layer: batch | instance | none
For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).
For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.
"""
if norm_type == 'batch':
norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
elif norm_type == 'instance':
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
elif norm_type == 'none':
norm_layer = lambda x: Identity()
else:
raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
return norm_layer
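# Minimal sketch of how this factory is consumed elsewhere in the file
# (channel count and shapes are illustrative):
#   norm_layer = get_norm_layer('instance')   # functools.partial around nn.InstanceNorm2d
#   layer = norm_layer(64)                    # InstanceNorm2d over 64 channels
#   y = layer(torch.randn(1, 64, 128, 128))   # output has the same shape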
def get_scheduler(optimizer, opt):
"""Return a learning rate scheduler
Parameters:
optimizer -- the optimizer of the network
opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
For 'linear', we keep the same learning rate for the first <opt.niter> epochs
and linearly decay the rate to zero over the next <opt.niter_decay> epochs.
For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
See https://pytorch.org/docs/stable/optim.html for more details.
"""
if opt.lr_policy == 'linear':
def lambda_rule(epoch):
lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
return lr_l
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
elif opt.lr_policy == 'step':
scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
elif opt.lr_policy == 'plateau':
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
elif opt.lr_policy == 'cosine':
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)
else:
raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
return scheduler
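# Worked example for the 'linear' policy above (illustrative flag values:
# epoch_count=1, niter=100, niter_decay=100):
#   lambda_rule(100) = 1.0 - max(0, 100 + 1 - 100) / 101 = 1.0 - 1/101  (decay begins)
#   lambda_rule(200) = 1.0 - max(0, 200 + 1 - 100) / 101 = 0.0          (fully decayed)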
def init_weights(net, init_type='normal', init_gain=0.02):
"""Initialize network weights.
Parameters:
net (network) -- network to be initialized
init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
init_gain (float) -- scaling factor for normal, xavier and orthogonal.
We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
work better for some applications. Feel free to try yourself.
"""
def init_func(m): # define the initialization function
classname = m.__class__.__name__
if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
if init_type == 'normal':
init.normal_(m.weight.data, 0.0, init_gain)
elif init_type == 'xavier':
init.xavier_normal_(m.weight.data, gain=init_gain)
elif init_type == 'kaiming':
init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif init_type == 'orthogonal':
init.orthogonal_(m.weight.data, gain=init_gain)
else:
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
if hasattr(m, 'bias') and m.bias is not None:
init.constant_(m.bias.data, 0.0)
elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
init.normal_(m.weight.data, 1.0, init_gain)
init.constant_(m.bias.data, 0.0)
print('initialize network with %s' % init_type)
net.apply(init_func) # apply the initialization function <init_func>
def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
"""Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights
Parameters:
net (network) -- the network to be initialized
init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
init_gain (float) -- scaling factor for normal, xavier and orthogonal.
gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
Return an initialized network.
"""
if len(gpu_ids) > 0:
assert(torch.cuda.is_available())
net.to(gpu_ids[0])
net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs
init_weights(net, init_type, init_gain=init_gain)
return net
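# Sketch of the typical call chain for the networks in this file, on CPU
# (arguments are illustrative defaults):
#   net = ResnetGenerator(3, 3, ngf=64, norm_layer=get_norm_layer('instance'))
#   net = init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[])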
def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[], model0_res=0, model1_res=0, extra_channel=3):
"""Create a generator
Parameters:
input_nc (int) -- the number of channels in input images
output_nc (int) -- the number of channels in output images
ngf (int) -- the number of filters in the last conv layer
netG (str) -- the architecture's name: resnet_9blocks | resnet_6blocks | unet_256 | unet_128
norm (str) -- the name of normalization layers used in the network: batch | instance | none
use_dropout (bool) -- if use dropout layers.
init_type (str) -- the name of our initialization method.
init_gain (float) -- scaling factor for normal, xavier and orthogonal.
gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
Returns a generator
Our current implementation provides two types of generators:
U-Net: [unet_128] (for 128x128 input images) and [unet_256] (for 256x256 input images)
The original U-Net paper: https://arxiv.org/abs/1505.04597
Resnet-based generator: [resnet_6blocks] (with 6 Resnet blocks) and [resnet_9blocks] (with 9 Resnet blocks)
Resnet-based generator consists of several Resnet blocks between a few downsampling/upsampling operations.
We adapt Torch code from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style).
The generator has been initialized by <init_net>. It uses ReLU for non-linearity.
"""
net = None
norm_layer = get_norm_layer(norm_type=norm)
if netG == 'resnet_9blocks':
net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9)
elif netG == 'resnet_style2_9blocks':
net = ResnetStyle2Generator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9, model0_res=model0_res, extra_channel=extra_channel)
elif netG == 'resnet_6blocks':
net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6)
elif netG == 'unet_128':
net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
elif netG == 'unet_256':
net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
else:
raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
return init_net(net, init_type, init_gain, gpu_ids)
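# Hedged example of building the style-conditioned generator used by this repo's
# test models (a 512x512 input and a 3-channel one-hot style map, as in
# models/test_3styles_model.py; all values illustrative):
#   netG = define_G(3, 3, 64, 'resnet_style2_9blocks', norm='instance',
#                   use_dropout=False, gpu_ids=[], model0_res=0, extra_channel=3)
#   style = torch.Tensor([1, 0, 0]).view(3, 1, 1).repeat(1, 1, 128, 128)  # (1, 3, 128, 128)
#   drawing = netG(torch.randn(1, 3, 512, 512), style)                    # output in [-1, 1]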
def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[], n_class=3):
"""Create a discriminator
Parameters:
input_nc (int) -- the number of channels in input images
ndf (int) -- the number of filters in the first conv layer
netD (str) -- the architecture's name: basic | n_layers | pixel
n_layers_D (int) -- the number of conv layers in the discriminator; effective when netD=='n_layers'
norm (str) -- the type of normalization layers used in the network.
init_type (str) -- the name of the initialization method.
init_gain (float) -- scaling factor for normal, xavier and orthogonal.
gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
Returns a discriminator
Our current implementation provides three types of discriminators:
[basic]: 'PatchGAN' classifier described in the original pix2pix paper.
It can classify whether 70×70 overlapping patches are real or fake.
Such a patch-level discriminator architecture has fewer parameters
than a full-image discriminator and can work on arbitrarily-sized images
in a fully convolutional fashion.
[n_layers]: With this mode, you can specify the number of conv layers in the discriminator
with the parameter <n_layers_D> (default=3, as used in [basic] (PatchGAN)).
[pixel]: 1x1 PixelGAN discriminator can classify whether a pixel is real or not.
It encourages greater color diversity but has no effect on spatial statistics.
The discriminator has been initialized by <init_net>. It uses Leaky ReLU for non-linearity.
"""
net = None
norm_layer = get_norm_layer(norm_type=norm)
if netD == 'basic': # default PatchGAN classifier
net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer)
elif netD == 'basic_cls':
net = NLayerDiscriminatorCls(input_nc, ndf, n_layers=3, n_class=3, norm_layer=norm_layer)
elif netD == 'n_layers': # more options
net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer)
elif netD == 'pixel': # classify if each pixel is real or fake
net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
else:
raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
return init_net(net, init_type, init_gain, gpu_ids)
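# Hedged sketch of the two-headed 'basic_cls' discriminator this repo adds; for
# a 512x512 input it returns a 62x62 real/fake patch map plus per-image
# style-class logits (sizes follow NLayerDiscriminatorCls below):
#   netD = define_D(3, 64, 'basic_cls', norm='instance', gpu_ids=[])
#   patch, cls = netD(torch.randn(1, 3, 512, 512))   # patch: (1, 1, 62, 62), cls: (1, 3)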
def define_HED(init_weights_, gpu_ids_=[]):
net = HED()
if len(gpu_ids_) > 0:
assert(torch.cuda.is_available())
net.to(gpu_ids_[0])
net = torch.nn.DataParallel(net, gpu_ids_) # multi-GPUs
if not init_weights_ == None:
device = torch.device('cuda:{}'.format(gpu_ids_[0])) if gpu_ids_ else torch.device('cpu')
print('Loading model from: %s'%init_weights_)
state_dict = torch.load(init_weights_, map_location=str(device))
if isinstance(net, torch.nn.DataParallel):
net.module.load_state_dict(state_dict)
else:
net.load_state_dict(state_dict)
print('load the weights successfully')
return net
##############################################################################
# Classes
##############################################################################
class GANLoss(nn.Module):
"""Define different GAN objectives.
The GANLoss class abstracts away the need to create the target label tensor
that has the same size as the input.
"""
def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
""" Initialize the GANLoss class.
Parameters:
gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
target_real_label (bool) - - label for a real image
target_fake_label (bool) - - label of a fake image
Note: Do not use sigmoid as the last layer of Discriminator.
LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
"""
super(GANLoss, self).__init__()
self.register_buffer('real_label', torch.tensor(target_real_label))
self.register_buffer('fake_label', torch.tensor(target_fake_label))
self.gan_mode = gan_mode
if gan_mode == 'lsgan':#cyclegan
self.loss = nn.MSELoss()
elif gan_mode == 'vanilla':
self.loss = nn.BCEWithLogitsLoss()
elif gan_mode in ['wgangp']:
self.loss = None
else:
raise NotImplementedError('gan mode %s not implemented' % gan_mode)
def get_target_tensor(self, prediction, target_is_real):
"""Create label tensors with the same size as the input.
Parameters:
prediction (tensor) - - typically the prediction from a discriminator
target_is_real (bool) - - if the ground truth label is for real images or fake images
Returns:
A label tensor filled with ground truth label, and with the size of the input
"""
if target_is_real:
target_tensor = self.real_label
else:
target_tensor = self.fake_label
return target_tensor.expand_as(prediction)
def __call__(self, prediction, target_is_real):
"""Calculate loss given Discriminator's output and grount truth labels.
Parameters:
prediction (tensor) - - tpyically the prediction output from a discriminator
target_is_real (bool) - - if the ground truth label is for real images or fake images
Returns:
the calculated loss.
"""
if self.gan_mode in ['lsgan', 'vanilla']:
target_tensor = self.get_target_tensor(prediction, target_is_real)
loss = self.loss(prediction, target_tensor)
elif self.gan_mode == 'wgangp':
if target_is_real:
loss = -prediction.mean()
else:
loss = prediction.mean()
return loss
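# Hedged usage sketch for GANLoss in 'lsgan' mode (netD and fake are placeholders):
#   criterionGAN = GANLoss('lsgan')
#   pred_fake = netD(fake.detach())
#   loss_D_fake = criterionGAN(pred_fake, False)   # MSE against an all-zeros target map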
def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
"""Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028
Arguments:
netD (network) -- discriminator network
real_data (tensor array) -- real images
fake_data (tensor array) -- generated images from the generator
device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
type (str) -- if we mix real and fake data or not [real | fake | mixed].
constant (float) -- the constant used in the formula (||gradient||_2 - constant)^2
lambda_gp (float) -- weight for this loss
Returns the gradient penalty loss
"""
if lambda_gp > 0.0:
if type == 'real': # either use real images, fake images, or a linear interpolation of two.
interpolatesv = real_data
elif type == 'fake':
interpolatesv = fake_data
elif type == 'mixed':
alpha = torch.rand(real_data.shape[0], 1, device=device)
alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape)
interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
else:
raise NotImplementedError('{} not implemented'.format(type))
interpolatesv.requires_grad_(True)
disc_interpolates = netD(interpolatesv)
gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv,
grad_outputs=torch.ones(disc_interpolates.size()).to(device),
create_graph=True, retain_graph=True, only_inputs=True)
gradients = gradients[0].view(real_data.size(0), -1) # flatten the data
gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp # added eps
return gradient_penalty, gradients
else:
return 0.0, None
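# Sketch of the WGAN-GP discriminator objective this helper supports
# (real/fake/device are placeholders):
#   gp, _ = cal_gradient_penalty(netD, real, fake, device, type='mixed', lambda_gp=10.0)
#   loss_D = -netD(real).mean() + netD(fake).mean() + gp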
class ResnetGenerator(nn.Module):
"""Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
"""
def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
"""Construct a Resnet-based generator
Parameters:
input_nc (int) -- the number of channels in input images
output_nc (int) -- the number of channels in output images
ngf (int) -- the number of filters in the last conv layer
norm_layer -- normalization layer
use_dropout (bool) -- if use dropout layers
n_blocks (int) -- the number of ResNet blocks
padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
"""
assert(n_blocks >= 0)
super(ResnetGenerator, self).__init__()
if type(norm_layer) == functools.partial:
use_bias = norm_layer.func == nn.InstanceNorm2d
else:
use_bias = norm_layer == nn.InstanceNorm2d
model = [nn.ReflectionPad2d(3),
nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
norm_layer(ngf),
nn.ReLU(True)]
n_downsampling = 2
for i in range(n_downsampling): # add downsampling layers
mult = 2 ** i
model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
norm_layer(ngf * mult * 2),
nn.ReLU(True)]
mult = 2 ** n_downsampling
for i in range(n_blocks): # add ResNet blocks
model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
for i in range(n_downsampling): # add upsampling layers
mult = 2 ** (n_downsampling - i)
model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
kernel_size=3, stride=2,
padding=1, output_padding=1,
bias=use_bias),
norm_layer(int(ngf * mult / 2)),
nn.ReLU(True)]
model += [nn.ReflectionPad2d(3)]
model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
model += [nn.Tanh()]
self.model = nn.Sequential(*model)
def forward(self, input):
"""Standard forward"""
return self.model(input)
class ResnetStyle2Generator(nn.Module):
def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', extra_channel=3, model0_res=0):
"""Construct a Resnet-based generator
Parameters:
input_nc (int) -- the number of channels in input images
output_nc (int) -- the number of channels in output images
ngf (int) -- the number of filters in the last conv layer
norm_layer -- normalization layer
use_dropout (bool) -- if use dropout layers
n_blocks (int) -- the number of ResNet blocks
padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
"""
assert(n_blocks >= 0)
super(ResnetStyle2Generator, self).__init__()
if type(norm_layer) == functools.partial:
use_bias = norm_layer.func == nn.InstanceNorm2d
else:
use_bias = norm_layer == nn.InstanceNorm2d
model0 = [nn.ReflectionPad2d(3),
nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
norm_layer(ngf),
nn.ReLU(True)]
n_downsampling = 2
for i in range(n_downsampling): # add downsampling layers
mult = 2 ** i
model0 += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
norm_layer(ngf * mult * 2),
nn.ReLU(True)]
mult = 2 ** n_downsampling
for i in range(model0_res): # add ResNet blocks
model0 += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
model = []
model += [nn.Conv2d(ngf * mult + extra_channel, ngf * mult, kernel_size=3, stride=1, padding=1, bias=use_bias),
norm_layer(ngf * mult),
nn.ReLU(True)]
for i in range(n_blocks-model0_res): # add ResNet blocks
model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
for i in range(n_downsampling): # add upsampling layers
mult = 2 ** (n_downsampling - i)
model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
kernel_size=3, stride=2,
padding=1, output_padding=1,
bias=use_bias),
norm_layer(int(ngf * mult / 2)),
nn.ReLU(True)]
model += [nn.ReflectionPad2d(3)]
model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
model += [nn.Tanh()]
self.model0 = nn.Sequential(*model0)
self.model = nn.Sequential(*model)
#print(list(self.modules()))
def forward(self, input1, input2):
"""Standard forward"""
f1 = self.model0(input1)
y1 = torch.cat([f1, input2], 1)
return self.model(y1)
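# Shape note for the two-input forward above: with a 512x512 RGB input and the
# default ngf=64, model0 downsamples twice, so f1 is (N, 256, 128, 128); input2
# must then be an (N, extra_channel, 128, 128) style map (cf. the one-hot maps
# in models/test_3styles_model.py), giving y1 259 channels when extra_channel=3.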
class ResnetBlock(nn.Module):
"""Define a Resnet block"""
def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias, kernel=3):
"""Initialize the Resnet block
A resnet block is a conv block with skip connections
We construct a conv block with build_conv_block function,
and implement skip connections in <forward> function.
Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
"""
super(ResnetBlock, self).__init__()
self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias, kernel)
def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias, kernel=3):
"""Construct a convolutional block.
Parameters:
dim (int) -- the number of channels in the conv layer.
padding_type (str) -- the name of padding layer: reflect | replicate | zero
norm_layer -- normalization layer
use_dropout (bool) -- if use dropout layers.
use_bias (bool) -- if the conv layer uses bias or not
Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU))
"""
conv_block = []
p = 0
pad = int((kernel-1)/2)
if padding_type == 'reflect':#by default
conv_block += [nn.ReflectionPad2d(pad)]
elif padding_type == 'replicate':
conv_block += [nn.ReplicationPad2d(pad)]
elif padding_type == 'zero':
p = pad
else:
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
conv_block += [nn.Conv2d(dim, dim, kernel_size=kernel, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
if use_dropout:
conv_block += [nn.Dropout(0.5)]
p = 0
if padding_type == 'reflect':
conv_block += [nn.ReflectionPad2d(pad)]
elif padding_type == 'replicate':
conv_block += [nn.ReplicationPad2d(pad)]
elif padding_type == 'zero':
p = pad
else:
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
conv_block += [nn.Conv2d(dim, dim, kernel_size=kernel, padding=p, bias=use_bias), norm_layer(dim)]
return nn.Sequential(*conv_block)
def forward(self, x):
"""Forward function (with skip connections)"""
out = x + self.conv_block(x) # add skip connections
return out
class UnetGenerator(nn.Module):
"""Create a Unet-based generator"""
def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
"""Construct a Unet generator
Parameters:
input_nc (int) -- the number of channels in input images
output_nc (int) -- the number of channels in output images
num_downs (int) -- the number of downsamplings in UNet. For example, if |num_downs| == 7,
an image of size 128x128 will become of size 1x1 at the bottleneck
ngf (int) -- the number of filters in the last conv layer
norm_layer -- normalization layer
We construct the U-Net from the innermost layer to the outermost layer.
It is a recursive process.
"""
super(UnetGenerator, self).__init__()
# construct unet structure
unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer
for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters
unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
# gradually reduce the number of filters from ngf * 8 to ngf
unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer
def forward(self, input):
"""Standard forward"""
return self.model(input)
class UnetSkipConnectionBlock(nn.Module):
"""Defines the Unet submodule with skip connection.
X -------------------identity----------------------
|-- downsampling -- |submodule| -- upsampling --|
"""
def __init__(self, outer_nc, inner_nc, input_nc=None,
submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
"""Construct a Unet submodule with skip connections.
Parameters:
outer_nc (int) -- the number of filters in the outer conv layer
inner_nc (int) -- the number of filters in the inner conv layer
input_nc (int) -- the number of channels in input images/features
submodule (UnetSkipConnectionBlock) -- previously defined submodules
outermost (bool) -- if this module is the outermost module
innermost (bool) -- if this module is the innermost module
norm_layer -- normalization layer
use_dropout (bool) -- if use dropout layers.
"""
super(UnetSkipConnectionBlock, self).__init__()
self.outermost = outermost
if type(norm_layer) == functools.partial:
use_bias = norm_layer.func == nn.InstanceNorm2d
else:
use_bias = norm_layer == nn.InstanceNorm2d
if input_nc is None:
input_nc = outer_nc
downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
stride=2, padding=1, bias=use_bias)
downrelu = nn.LeakyReLU(0.2, True)
downnorm = norm_layer(inner_nc)
uprelu = nn.ReLU(True)
upnorm = norm_layer(outer_nc)
if outermost:
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
kernel_size=4, stride=2,
padding=1)
down = [downconv]
up = [uprelu, upconv, nn.Tanh()]
model = down + [submodule] + up
elif innermost:
upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
kernel_size=4, stride=2,
padding=1, bias=use_bias)
down = [downrelu, downconv]
up = [uprelu, upconv, upnorm]
model = down + up
else:
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
kernel_size=4, stride=2,
padding=1, bias=use_bias)
down = [downrelu, downconv, downnorm]
up = [uprelu, upconv, upnorm]
if use_dropout:
model = down + [submodule] + up + [nn.Dropout(0.5)]
else:
model = down + [submodule] + up
self.model = nn.Sequential(*model)
def forward(self, x):
if self.outermost:
return self.model(x)
else: # add skip connections
return torch.cat([x, self.model(x)], 1)
class NLayerDiscriminator(nn.Module):
"""Defines a PatchGAN discriminator"""
def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
"""Construct a PatchGAN discriminator
Parameters:
input_nc (int) -- the number of channels in input images
ndf (int) -- the number of filters in the last conv layer
n_layers (int) -- the number of conv layers in the discriminator
norm_layer -- normalization layer
"""
super(NLayerDiscriminator, self).__init__()
if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
use_bias = norm_layer.func != nn.BatchNorm2d
else:
use_bias = norm_layer != nn.BatchNorm2d
kw = 4
padw = 1
sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
nf_mult = 1
nf_mult_prev = 1
for n in range(1, n_layers): # gradually increase the number of filters
nf_mult_prev = nf_mult
nf_mult = min(2 ** n, 8)
sequence += [
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
]
nf_mult_prev = nf_mult
nf_mult = min(2 ** n_layers, 8)
sequence += [
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
]
sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
self.model = nn.Sequential(*sequence)
def forward(self, input):
"""Standard forward."""
return self.model(input)
class NLayerDiscriminatorCls(nn.Module):
"""Defines a PatchGAN discriminator"""
def __init__(self, input_nc, ndf=64, n_layers=3, n_class=3, norm_layer=nn.BatchNorm2d):
"""Construct a PatchGAN discriminator
Parameters:
input_nc (int) -- the number of channels in input images
ndf (int) -- the number of filters in the last conv layer
n_layers (int) -- the number of conv layers in the discriminator
norm_layer -- normalization layer
"""
super(NLayerDiscriminatorCls, self).__init__()
if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
use_bias = norm_layer.func != nn.BatchNorm2d
else:
use_bias = norm_layer != nn.BatchNorm2d
kw = 4
padw = 1
sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
nf_mult = 1
nf_mult_prev = 1
for n in range(1, n_layers): # gradually increase the number of filters
nf_mult_prev = nf_mult
nf_mult = min(2 ** n, 8)
sequence += [
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
]
nf_mult_prev = nf_mult
nf_mult = min(2 ** n_layers, 8)
sequence1 = [
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
]
sequence1 += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
sequence2 = [
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
]
sequence2 += [
nn.Conv2d(ndf * nf_mult, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
]
sequence2 += [
nn.Conv2d(ndf * nf_mult, n_class, kernel_size=16, stride=1, padding=0, bias=use_bias)]
self.model0 = nn.Sequential(*sequence)
self.model1 = nn.Sequential(*sequence1)
self.model2 = nn.Sequential(*sequence2)
print(list(self.modules()))
def forward(self, input):
"""Standard forward."""
feat = self.model0(input)
# patchGAN output (1 * 62 * 62)
patch = self.model1(feat)
# class output (3 * 1 * 1)
classl = self.model2(feat)
return patch, classl.view(classl.size(0), -1)
class PixelDiscriminator(nn.Module):
"""Defines a 1x1 PatchGAN discriminator (pixelGAN)"""
def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
"""Construct a 1x1 PatchGAN discriminator
Parameters:
input_nc (int) -- the number of channels in input images
ndf (int) -- the number of filters in the last conv layer
norm_layer -- normalization layer
"""
super(PixelDiscriminator, self).__init__()
if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
use_bias = norm_layer.func != nn.InstanceNorm2d
else:
use_bias = norm_layer != nn.InstanceNorm2d
self.net = [
nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
nn.LeakyReLU(0.2, True),
nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
norm_layer(ndf * 2),
nn.LeakyReLU(0.2, True),
nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]
self.net = nn.Sequential(*self.net)
def forward(self, input):
"""Standard forward."""
return self.net(input)
class HED(nn.Module):
def __init__(self):
super(HED, self).__init__()
self.moduleVggOne = nn.Sequential(
nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=False),
nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=False)
)
self.moduleVggTwo = nn.Sequential(
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=False),
nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=False)
)
self.moduleVggThr = nn.Sequential(
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=False),
nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=False),
nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=False)
)
self.moduleVggFou = nn.Sequential(
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=False),
nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=False),
nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=False)
)
self.moduleVggFiv = nn.Sequential(
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=False),
nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=False),
nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=False)
)
self.moduleScoreOne = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1, stride=1, padding=0)
self.moduleScoreTwo = nn.Conv2d(in_channels=128, out_channels=1, kernel_size=1, stride=1, padding=0)
self.moduleScoreThr = nn.Conv2d(in_channels=256, out_channels=1, kernel_size=1, stride=1, padding=0)
self.moduleScoreFou = nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1, stride=1, padding=0)
self.moduleScoreFiv = nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1, stride=1, padding=0)
self.moduleCombine = nn.Sequential(
nn.Conv2d(in_channels=5, out_channels=1, kernel_size=1, stride=1, padding=0),
nn.Sigmoid()
)
def forward(self, tensorInput):
tensorBlue = (tensorInput[:, 2:3, :, :] * 255.0) - 104.00698793
tensorGreen = (tensorInput[:, 1:2, :, :] * 255.0) - 116.66876762
tensorRed = (tensorInput[:, 0:1, :, :] * 255.0) - 122.67891434
tensorInput = torch.cat([ tensorBlue, tensorGreen, tensorRed ], 1)
tensorVggOne = self.moduleVggOne(tensorInput)
tensorVggTwo = self.moduleVggTwo(tensorVggOne)
tensorVggThr = self.moduleVggThr(tensorVggTwo)
tensorVggFou = self.moduleVggFou(tensorVggThr)
tensorVggFiv = self.moduleVggFiv(tensorVggFou)
tensorScoreOne = self.moduleScoreOne(tensorVggOne)
tensorScoreTwo = self.moduleScoreTwo(tensorVggTwo)
tensorScoreThr = self.moduleScoreThr(tensorVggThr)
tensorScoreFou = self.moduleScoreFou(tensorVggFou)
tensorScoreFiv = self.moduleScoreFiv(tensorVggFiv)
tensorScoreOne = nn.functional.interpolate(input=tensorScoreOne, size=(tensorInput.size(2), tensorInput.size(3)), mode='bilinear', align_corners=False)
tensorScoreTwo = nn.functional.interpolate(input=tensorScoreTwo, size=(tensorInput.size(2), tensorInput.size(3)), mode='bilinear', align_corners=False)
tensorScoreThr = nn.functional.interpolate(input=tensorScoreThr, size=(tensorInput.size(2), tensorInput.size(3)), mode='bilinear', align_corners=False)
tensorScoreFou = nn.functional.interpolate(input=tensorScoreFou, size=(tensorInput.size(2), tensorInput.size(3)), mode='bilinear', align_corners=False)
tensorScoreFiv = nn.functional.interpolate(input=tensorScoreFiv, size=(tensorInput.size(2), tensorInput.size(3)), mode='bilinear', align_corners=False)
return self.moduleCombine(torch.cat([ tensorScoreOne, tensorScoreTwo, tensorScoreThr, tensorScoreFou, tensorScoreFiv ], 1))
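# Hedged sketch: HED takes RGB in [0, 1] (the forward pass itself converts to
# the Caffe-style BGR, mean-subtracted layout) and returns a one-channel edge
# probability map at the input resolution:
#   hed = define_HED(init_weights_=None, gpu_ids_=[])   # weights omitted for illustration
#   edges = hed(torch.rand(1, 3, 512, 512))             # (1, 1, 512, 512), values in [0, 1]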
================================================
FILE: models/networks_basic.py
================================================
from __future__ import absolute_import
import sys
import torch
import torch.nn as nn
import torch.nn.init as init
from torch.autograd import Variable
import numpy as np
from pdb import set_trace as st
from skimage import color
from IPython import embed
from . import pretrained_networks as pn
from util import util
def spatial_average(in_tens, keepdim=True):
return in_tens.mean([2,3],keepdim=keepdim)
def upsample(in_tens, out_H=64): # assumes scale factor is same for H and W
in_H = in_tens.shape[2]
scale_factor = 1.*out_H/in_H
return nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)(in_tens)
# Learned perceptual metric
class PNetLin(nn.Module):
def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, spatial=False, version='0.1', lpips=True):
super(PNetLin, self).__init__()
self.pnet_type = pnet_type
self.pnet_tune = pnet_tune
self.pnet_rand = pnet_rand
self.spatial = spatial
self.lpips = lpips
self.version = version
self.scaling_layer = ScalingLayer()
if(self.pnet_type in ['vgg','vgg16']):
net_type = pn.vgg16
self.chns = [64,128,256,512,512]
elif(self.pnet_type=='alex'):
net_type = pn.alexnet
self.chns = [64,192,384,256,256]
elif(self.pnet_type=='squeeze'):
net_type = pn.squeezenet
self.chns = [64,128,256,384,384,512,512]
self.L = len(self.chns)
self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)
if(lpips):
self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4]
if(self.pnet_type=='squeeze'): # 7 layers for squeezenet
self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
self.lins+=[self.lin5,self.lin6]
def forward(self, in0, in1, retPerLayer=False):
# v0.0 - original release had a bug, where input was not scaled
in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1)
outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
feats0, feats1, diffs = {}, {}, {}
for kk in range(self.L):
feats0[kk], feats1[kk] = util.normalize_tensor(outs0[kk]), util.normalize_tensor(outs1[kk])
diffs[kk] = (feats0[kk]-feats1[kk])**2
if(self.lpips):
if(self.spatial):
res = [upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2]) for kk in range(self.L)]
else:
res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)]
else:
if(self.spatial):
res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_H=in0.shape[2]) for kk in range(self.L)]
else:
res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)]
val = res[0]
for l in range(1,self.L):
val += res[l]
if(retPerLayer):
return (val, res)
else:
return val
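# Hedged usage sketch for the calibrated metric above (in practice the lin
# layers' 1x1-conv weights are loaded from a checkpoint by the dist model):
#   metric = PNetLin(pnet_type='alex', spatial=False)
#   d = metric(img0, img1)   # img0/img1: (N, 3, H, W) in [-1, 1]; d: (N, 1, 1, 1)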
class ScalingLayer(nn.Module):
def __init__(self):
super(ScalingLayer, self).__init__()
self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188])[None,:,None,None])
self.register_buffer('scale', torch.Tensor([.458,.448,.450])[None,:,None,None])
def forward(self, inp):
return (inp - self.shift.to(inp.device)) / self.scale.to(inp.device)
class NetLinLayer(nn.Module):
''' A single linear layer which does a 1x1 conv '''
def __init__(self, chn_in, chn_out=1, use_dropout=False):
super(NetLinLayer, self).__init__()
layers = [nn.Dropout(),] if(use_dropout) else []
layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),]
self.model = nn.Sequential(*layers)
class Dist2LogitLayer(nn.Module):
''' takes 2 distances, puts them through fc layers, and outputs a value in [0,1] (if use_sigmoid is True) '''
def __init__(self, chn_mid=32, use_sigmoid=True):
super(Dist2LogitLayer, self).__init__()
layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),]
layers += [nn.LeakyReLU(0.2,True),]
layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),]
layers += [nn.LeakyReLU(0.2,True),]
layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),]
if(use_sigmoid):
layers += [nn.Sigmoid(),]
self.model = nn.Sequential(*layers)
def forward(self,d0,d1,eps=0.1):
return self.model.forward(torch.cat((d0,d1,d0-d1,d0/(d1+eps),d1/(d0+eps)),dim=1))
class BCERankingLoss(nn.Module):
def __init__(self, chn_mid=32):
super(BCERankingLoss, self).__init__()
self.net = Dist2LogitLayer(chn_mid=chn_mid)
# self.parameters = list(self.net.parameters())
self.loss = torch.nn.BCELoss()
def forward(self, d0, d1, judge):
per = (judge+1.)/2.
self.logit = self.net.forward(d0,d1)
return self.loss(self.logit, per)
# L2, DSSIM metrics
class FakeNet(nn.Module):
def __init__(self, use_gpu=True, colorspace='Lab'):
super(FakeNet, self).__init__()
self.use_gpu = use_gpu
self.colorspace=colorspace
class L2(FakeNet):
def forward(self, in0, in1, retPerLayer=None):
assert(in0.size()[0]==1) # currently only supports batchSize 1
if(self.colorspace=='RGB'):
(N,C,X,Y) = in0.size()
value = torch.mean(torch.mean(torch.mean((in0-in1)**2,dim=1).view(N,1,X,Y),dim=2).view(N,1,1,Y),dim=3).view(N)
return value
elif(self.colorspace=='Lab'):
value = util.l2(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)),
util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
ret_var = Variable( torch.Tensor((value,) ) )
if(self.use_gpu):
ret_var = ret_var.cuda()
return ret_var
class DSSIM(FakeNet):
def forward(self, in0, in1, retPerLayer=None):
assert(in0.size()[0]==1) # currently only supports batchSize 1
if(self.colorspace=='RGB'):
value = util.dssim(1.*util.tensor2im(in0.data), 1.*util.tensor2im(in1.data), range=255.).astype('float')
elif(self.colorspace=='Lab'):
value = util.dssim(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)),
util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
ret_var = Variable( torch.Tensor((value,) ) )
if(self.use_gpu):
ret_var = ret_var.cuda()
return ret_var
def print_network(net):
num_params = 0
for param in net.parameters():
num_params += param.numel()
print('Network',net)
print('Total number of parameters: %d' % num_params)
================================================
FILE: models/pretrained_networks.py
================================================
from collections import namedtuple
import torch
from torchvision import models
from IPython import embed
class squeezenet(torch.nn.Module):
def __init__(self, requires_grad=False, pretrained=True):
super(squeezenet, self).__init__()
pretrained_features = models.squeezenet1_1(pretrained=pretrained).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
self.slice6 = torch.nn.Sequential()
self.slice7 = torch.nn.Sequential()
self.N_slices = 7
for x in range(2):
self.slice1.add_module(str(x), pretrained_features[x])
for x in range(2,5):
self.slice2.add_module(str(x), pretrained_features[x])
for x in range(5, 8):
self.slice3.add_module(str(x), pretrained_features[x])
for x in range(8, 10):
self.slice4.add_module(str(x), pretrained_features[x])
for x in range(10, 11):
self.slice5.add_module(str(x), pretrained_features[x])
for x in range(11, 12):
self.slice6.add_module(str(x), pretrained_features[x])
for x in range(12, 13):
self.slice7.add_module(str(x), pretrained_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
h = self.slice1(X)
h_relu1 = h
h = self.slice2(h)
h_relu2 = h
h = self.slice3(h)
h_relu3 = h
h = self.slice4(h)
h_relu4 = h
h = self.slice5(h)
h_relu5 = h
h = self.slice6(h)
h_relu6 = h
h = self.slice7(h)
h_relu7 = h
vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7'])
out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7)
return out
class alexnet(torch.nn.Module):
def __init__(self, requires_grad=False, pretrained=True):
super(alexnet, self).__init__()
alexnet_pretrained_features = models.alexnet(pretrained=pretrained).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
self.N_slices = 5
for x in range(2):
self.slice1.add_module(str(x), alexnet_pretrained_features[x])
for x in range(2, 5):
self.slice2.add_module(str(x), alexnet_pretrained_features[x])
for x in range(5, 8):
self.slice3.add_module(str(x), alexnet_pretrained_features[x])
for x in range(8, 10):
self.slice4.add_module(str(x), alexnet_pretrained_features[x])
for x in range(10, 12):
self.slice5.add_module(str(x), alexnet_pretrained_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
h = self.slice1(X)
h_relu1 = h
h = self.slice2(h)
h_relu2 = h
h = self.slice3(h)
h_relu3 = h
h = self.slice4(h)
h_relu4 = h
h = self.slice5(h)
h_relu5 = h
alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5'])
out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)
return out
class vgg16(torch.nn.Module):
def __init__(self, requires_grad=False, pretrained=True):
super(vgg16, self).__init__()
vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
self.N_slices = 5
for x in range(4):
self.slice1.add_module(str(x), vgg_pretrained_features[x])
for x in range(4, 9):
self.slice2.add_module(str(x), vgg_pretrained_features[x])
for x in range(9, 16):
self.slice3.add_module(str(x), vgg_pretrained_features[x])
for x in range(16, 23):
self.slice4.add_module(str(x), vgg_pretrained_features[x])
for x in range(23, 30):
self.slice5.add_module(str(x), vgg_pretrained_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
h = self.slice1(X)
h_relu1_2 = h
h = self.slice2(h)
h_relu2_2 = h
h = self.slice3(h)
h_relu3_3 = h
h = self.slice4(h)
h_relu4_3 = h
h = self.slice5(h)
h_relu5_3 = h
vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
return out
class resnet(torch.nn.Module):
def __init__(self, requires_grad=False, pretrained=True, num=18):
super(resnet, self).__init__()
if(num==18):
self.net = models.resnet18(pretrained=pretrained)
elif(num==34):
self.net = models.resnet34(pretrained=pretrained)
elif(num==50):
self.net = models.resnet50(pretrained=pretrained)
elif(num==101):
self.net = models.resnet101(pretrained=pretrained)
elif(num==152):
self.net = models.resnet152(pretrained=pretrained)
self.N_slices = 5
self.conv1 = self.net.conv1
self.bn1 = self.net.bn1
self.relu = self.net.relu
self.maxpool = self.net.maxpool
self.layer1 = self.net.layer1
self.layer2 = self.net.layer2
self.layer3 = self.net.layer3
self.layer4 = self.net.layer4
def forward(self, X):
h = self.conv1(X)
h = self.bn1(h)
h = self.relu(h)
h_relu1 = h
h = self.maxpool(h)
h = self.layer1(h)
h_conv2 = h
h = self.layer2(h)
h_conv3 = h
h = self.layer3(h)
h_conv4 = h
h = self.layer4(h)
h_conv5 = h
outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5'])
out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)
return out
================================================
FILE: models/test_3styles_model.py
================================================
from .base_model import BaseModel
from . import networks
import torch
class Test3StylesModel(BaseModel):
""" This TesteModel can be used to generate CycleGAN results for only one direction.
This model will automatically set '--dataset_mode single', which only loads the images from one collection.
See the test instruction for more details.
"""
@staticmethod
def modify_commandline_options(parser, is_train=True):
"""Add new dataset-specific options, and rewrite default values for existing options.
Parameters:
parser -- original option parser
is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
Returns:
the modified parser.
The model can only be used during test time. It requires '--dataset_mode single'.
You need to specify the generator architecture using the option '--netga'.
"""
assert not is_train, 'TestModel cannot be used during training time'
parser.set_defaults(dataset_mode='single')
parser.add_argument('--style_control', type=int, default=0, help='not set style_vec in dataset')
parser.add_argument('--netga', type=str, default='resnet_style2_9blocks', help='net arch for netG_A')
parser.add_argument('--model0_res', type=int, default=0, help='number of resblocks in model0')
parser.add_argument('--model1_res', type=int, default=0, help='number of resblocks in model1 (after insert style, before 2 column merge)')
return parser
def __init__(self, opt):
assert(not opt.isTrain)
BaseModel.__init__(self, opt)
# specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
self.loss_names = []
# specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
self.visual_names = ['real', 'fake1', 'fake2', 'fake3']
# specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
self.model_names = ['G_A'] # only generator is needed.
print(opt.netga)
print('model0_res', opt.model0_res)
print('model1_res', opt.model1_res)
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netga, opt.norm,
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.model0_res, opt.model1_res)
setattr(self, 'netG_A', self.netG) # store netG in self.
def set_input(self, input):
self.real = input['A'].to(self.device)
self.image_paths = input['A_paths']
self.style1 = torch.Tensor([1, 0, 0]).view(3, 1, 1).repeat(1, 1, 128, 128).to(self.device)
self.style2 = torch.Tensor([0, 1, 0]).view(3, 1, 1).repeat(1, 1, 128, 128).to(self.device)
self.style3 = torch.Tensor([0, 0, 1]).view(3, 1, 1).repeat(1, 1, 128, 128).to(self.device)
def forward(self):
"""Run forward pass."""
self.fake1 = self.netG(self.real, self.style1)
self.fake2 = self.netG(self.real, self.style2)
self.fake3 = self.netG(self.real, self.style3)
def optimize_parameters(self):
"""No optimization for test model."""
pass
================================================
FILE: models/test_model.py
================================================
from .base_model import BaseModel
from . import networks
import torch
class TestModel(BaseModel):
""" This TesteModel can be used to generate CycleGAN results for only one direction.
This model will automatically set '--dataset_mode single', which only loads the images from one collection.
See the test instruction for more details.
"""
@staticmethod
def modify_commandline_options(parser, is_train=True):
"""Add new dataset-specific options, and rewrite default values for existing options.
Parameters:
parser -- original option parser
is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
Returns:
the modified parser.
The model can only be used during test time. It requires '--dataset_mode single'.
You need to specify the network using the option '--model_suffix'.
"""
assert not is_train, 'TestModel cannot be used during training time'
parser.set_defaults(dataset_mode='single')
parser.add_argument('--model_suffix', type=str, default='', help='In checkpoints_dir, [epoch]_net_G[model_suffix].pth will be loaded as the generator.')
parser.add_argument('--style_control', type=int, default=1, help='use style_control')
parser.add_argument('--sfeature_mode', type=str, default='vgg19_softmax', help='vgg19 softmax as feature')
parser.add_argument('--sinput', type=str, default='sind', help='use which one for style input')
parser.add_argument('--sind', type=int, default=0, help='one hot for sfeature')
parser.add_argument('--svec', type=str, default='1,0,0', help='3-dim vec')
parser.add_argument('--simg', type=str, default='Yann_Legendre-053', help='drawing example for style')
parser.add_argument('--netga', type=str, default='resnet_style2_9blocks', help='net arch for netG_A')
parser.add_argument('--model0_res', type=int, default=0, help='number of resblocks in model0')
parser.add_argument('--model1_res', type=int, default=0, help='number of resblocks in model1 (after insert style, before 2 column merge)')
return parser
def __init__(self, opt):
"""Initialize the pix2pix class.
Parameters:
opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
"""
assert(not opt.isTrain)
BaseModel.__init__(self, opt)
# specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
self.loss_names = []
# specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
self.visual_names = ['real', 'fake', 'rec']
# specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
self.model_names = ['G' + opt.model_suffix, 'G_B'] # only generator is needed.
if not self.opt.style_control:
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG,
opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
else:
print(opt.netga)
print('model0_res', opt.model0_res)
print('model1_res', opt.model1_res)
self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netga, opt.norm,
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.model0_res, opt.model1_res)
self.netGB = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG,
opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
# assigns the model to self.netG_[suffix] so that it can be loaded
# please see <BaseModel.load_networks>
setattr(self, 'netG' + opt.model_suffix, self.netG) # store netG in self.
setattr(self, 'netG_B', self.netGB) # store netGB in self.
def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Parameters:
input: a dictionary that contains the data itself and its metadata information.
We need to use the 'single_dataset' dataset mode. It only loads images from one domain.
"""
self.real = input['A'].to(self.device)
self.image_paths = input['A_paths']
if self.opt.style_control:
self.style = input['B_style']
def forward(self):
"""Run forward pass."""
if not self.opt.style_control:
self.fake = self.netG(self.real) # G(real)
else:
print(torch.mean(self.style,(2,3)),'style_control')
self.fake = self.netG(self.real, self.style)
self.rec = self.netG_B(self.fake)
def optimize_parameters(self):
"""No optimization for test model."""
pass
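# Usage sketch (an illustration, mirroring the commented-out command in test_seq_style.py):
#   python test.py --dataroot ./examples --name pretrained --model test --output_nc 1 \
#       --no_dropout --model_suffix _A --sinput svec --svec 0,1,0
# With --style_control 1 (the default) the generator additionally receives the 3-dim
# style feature; here '--svec 0,1,0' is the one-hot vector selecting style2.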
================================================
FILE: options/__init__.py
================================================
"""This package options includes option modules: training options, test options, and basic options (used in both training and test)."""
================================================
FILE: options/base_options.py
================================================
import argparse
import os
from util import util
import torch
import models
import data
class BaseOptions():
"""This class defines options used during both training and test time.
It also implements several helper functions such as parsing, printing, and saving the options.
It also gathers additional options defined in <modify_commandline_options> functions in both dataset class and model class.
"""
def __init__(self):
"""Reset the class; indicates the class hasn't been initailized"""
self.initialized = False
def initialize(self, parser):
"""Define the common options that are used in both training and test."""
# basic parameters
parser.add_argument('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')
parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
parser.add_argument('--gpu_ids_p', type=str, default='0', help='gpu ids for pretrained auxiliary models: e.g. 0 0,1,2, 0,2. use -1 for CPU')
parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
# model parameters
parser.add_argument('--model', type=str, default='cycle_gan', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]')
parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')
parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')
parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')
parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')
parser.add_argument('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')
parser.add_argument('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]')
parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')
parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]')
parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')
parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator')
# dataset parameters
parser.add_argument('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')
parser.add_argument('--direction', type=str, default='AtoB', help='AtoB or BtoA')
parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')
parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
parser.add_argument('--load_size', type=int, default=286, help='scale images to this size')
parser.add_argument('--crop_size', type=int, default=256, help='then crop to this size')
parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
parser.add_argument('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')
parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')
# additional parameters
parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
parser.add_argument('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')
parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')
self.initialized = True
return parser
def gather_options(self):
"""Initialize our parser with basic options(only once).
Add additional model-specific and dataset-specific options.
These options are defined in the <modify_commandline_options> function
in model and dataset classes.
"""
if not self.initialized: # check if it has been initialized
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser = self.initialize(parser)
# get the basic options
opt, _ = parser.parse_known_args()
# modify model-related parser options
model_name = opt.model
model_option_setter = models.get_option_setter(model_name)
parser = model_option_setter(parser, self.isTrain)
opt, _ = parser.parse_known_args() # parse again with new defaults
# modify dataset-related parser options
dataset_name = opt.dataset_mode
dataset_option_setter = data.get_option_setter(dataset_name)
parser = dataset_option_setter(parser, self.isTrain)
# save and return the parser
self.parser = parser
return parser.parse_args()
def print_options(self, opt):
"""Print and save options
It will print both the current options and the default values (if different).
It will save options into a text file / [checkpoints_dir] / opt.txt
"""
message = ''
message += '----------------- Options ---------------\n'
for k, v in sorted(vars(opt).items()):
comment = ''
default = self.parser.get_default(k)
if v != default:
comment = '\t[default: %s]' % str(default)
message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
message += '----------------- End -------------------'
print(message)
# save to the disk
expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
util.mkdirs(expr_dir)
file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))
with open(file_name, 'wt') as opt_file:
opt_file.write(message)
opt_file.write('\n')
def parse(self):
"""Parse our options, create checkpoints directory suffix, and set up gpu device."""
opt = self.gather_options()
opt.isTrain = self.isTrain # train or test
# process opt.suffix
if opt.suffix:
suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
opt.name = opt.name + suffix
self.print_options(opt)
# set gpu ids
str_ids = opt.gpu_ids.split(',')
opt.gpu_ids = []
for str_id in str_ids:
id = int(str_id)
if id >= 0:
opt.gpu_ids.append(id)
if len(opt.gpu_ids) > 0:
torch.cuda.set_device(opt.gpu_ids[0])
# set gpu ids for pretrained auxiliary models
str_ids = opt.gpu_ids_p.split(',')
opt.gpu_ids_p = []
for str_id in str_ids:
id = int(str_id)
if id >= 0:
opt.gpu_ids_p.append(id)
self.opt = opt
return self.opt
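# Usage sketch (an illustration): TrainOptions and TestOptions subclass BaseOptions and
# set self.isTrain in initialize(); entry points such as train.py then call, e.g.,
#   opt = TrainOptions().parse()  # parses sys.argv and saves options under checkpoints_dir
#   opt.gpu_ids                   # '0,1' has been converted to [0, 1]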
================================================
FILE: options/test_options.py
================================================
from .base_options import BaseOptions
class TestOptions(BaseOptions):
"""This class includes test options.
It also includes shared options defined in BaseOptions.
"""
def initialize(self, parser):
parser = BaseOptions.initialize(self, parser) # define shared options
parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
# Dropout and BatchNorm have different behavior during training and test.
parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')
parser.add_argument('--num_test', type=int, default=50, help='how many test images to run')
parser.add_argument('--imagefolder', type=str, default='images', help='subfolder to save images')
# rewrite default values
parser.set_defaults(model='test')
# To avoid cropping, the load_size should be the same as crop_size
parser.set_defaults(load_size=parser.get_default('crop_size'))
self.isTrain = False
return parser
================================================
FILE: options/train_options.py
================================================
from .base_options import BaseOptions
class TrainOptions(BaseOptions):
"""This class includes training options.
It also includes shared options defined in BaseOptions.
"""
def initialize(self, parser):
parser = BaseOptions.initialize(self, parser)
# visdom and HTML visualization parameters
parser.add_argument('--display_freq', type=int, default=400, help='frequency of showing training results on screen')
parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
parser.add_argument('--display_id', type=int, default=1, help='window id of the web display')
parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")')
parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')
parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
# network saving and loading parameters
parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
parser.add_argument('--save_epoch_freq', type=int, default=10, help='frequency of saving checkpoints at the end of epochs')
parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration')
parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
# training parameters
parser.add_argument('--n_epochs', type=int, default=200, help='the end epoch count')
parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate')
parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')
parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
parser.add_argument('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')
parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')
parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy. [linear | step | plateau | cosine]')
parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
self.isTrain = True
return parser
================================================
FILE: portrait_drawing_resources.md
================================================
- Charles Burns (style1): https://www.pinterest.co.uk/johns59/charles-burns-fan-club/
- Yann Legendre (style1): http://www.yannlegendre.com/project/portraits/
- Kathryn Rathke (style2): https://www.kathrynrathke.com/
- Vectorportal (style3): https://www.pinterest.co.uk/vectorportal/celebrity-vector-illustrations/
================================================
FILE: preprocess/face_align_512.m
================================================
function [trans_img]=face_align_512(impath,facial5point,savedir)
% align the faces by similarity transformation.
% using 5 facial landmarks: 2 eyes, nose, 2 mouth corners.
% impath: path to image
% facial5point: 5x2 size, 5 facial landmark positions, detected by MTCNN
% savedir: savedir for cropped image and transformed facial landmarks
%% alignment settings
imgSize = [512,512];
coord5point = [180,230;
300,230;
240,301;
186,365.6;
294,365.6];%480x480
coord5point = (coord5point-240)/560 * 512 + 256;
%% face alignment
% load and align, resize image to imgSize
img = imread(impath);
facial5point = double(facial5point);
transf = cp2tform(facial5point, coord5point, 'similarity');
trans_img = imtransform(img, transf, 'XData', [1 imgSize(2)],...
'YData', [1 imgSize(1)],...
'Size', imgSize,...
'FillValues', [255;255;255]);
trans_facial5point = round(tformfwd(transf,facial5point));
%% save results
if ~exist(savedir,'dir')
mkdir(savedir)
end
[~,name,~] = fileparts(impath);
% save trans_img
imwrite(trans_img, fullfile(savedir,[name,'_resized.png']));
fprintf('write aligned image to %s\n',fullfile(savedir,[name,'_resized.png']));
%% show results
imshow(trans_img); hold on;
plot(trans_facial5point(:,1),trans_facial5point(:,2),'b');
plot(trans_facial5point(:,1),trans_facial5point(:,2),'r+');
end
================================================
FILE: preprocess/readme.md
================================================
## Preprocessing steps
During training, face photos and drawings are aligned, and nose, eyes, lips masks are detected for them.
During test, the alignment step is optional and the masks are not needed.
### 1. Align, resize, crop images to 512x512
All training and testing images in our model are aligned using facial landmarks, and the landmark positions after alignment are needed by our code.
- First, 5 facial landmarks need to be detected for each face photo (we detect them using [MTCNN](https://github.com/kpzhang93/MTCNN_face_detection_alignment)(MTCNNv1)).
- Then, call the MATLAB function we provide in `face_align_512.m` to align, resize and crop a face photo (and the corresponding drawing) to 512x512.
For example, for `ia_selfie_10515.jpg` in the `example` dir, the 5 detected facial landmarks are saved in `example/ia_selfie_10515_facial5point.mat`. Call the following in MATLAB:
```matlab
load('example/ia_selfie_10515_facial5point.mat');
[trans_img]=face_align_512('example/ia_selfie_10515.jpg',facial5point,'example');
```
This will align the image and write the aligned result to the `example` folder.
See `face_align_512.m` for more instructions.
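For reference, the same alignment can be approximated in Python. This is a minimal sketch, not the authors' code: it assumes `scikit-image` is installed (it is not in `requirements.txt`) and that the `.mat` file stores a 5x2 array named `facial5point`; results may differ from MATLAB by a sub-pixel coordinate-origin offset.
```python
import numpy as np
from scipy.io import loadmat
from PIL import Image
from skimage import transform as tf

def face_align_512_py(impath, facial5point):
    """Align a face photo to 512x512 with a similarity transform on 5 landmarks."""
    # Reference landmark positions, rescaled exactly as in face_align_512.m.
    dst = np.array([[180, 230], [300, 230], [240, 301],
                    [186, 365.6], [294, 365.6]], dtype=np.float64)
    dst = (dst - 240) / 560 * 512 + 256
    t = tf.SimilarityTransform()
    t.estimate(np.asarray(facial5point, dtype=np.float64), dst)  # detected -> reference
    img = np.asarray(Image.open(impath).convert('RGB'))
    warped = tf.warp(img, t.inverse, output_shape=(512, 512), cval=1.0)  # white fill
    return Image.fromarray((warped * 255).astype(np.uint8))

pts = loadmat('example/ia_selfie_10515_facial5point.mat')['facial5point']
face_align_512_py('example/ia_selfie_10515.jpg', pts).save('example/ia_selfie_10515_resized.png')
```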
### 2. Prepare nose,eyes,lips masks
In our work, we use the face parsing network in https://github.com/cientgu/Mask_Guided_Portrait_Editing to get the nose, eyes, lips regions, and then dilate these regions so that they fully cover the facial features (some examples are shown in the `example` folder); the dilation step is sketched below.
- The masks need to be copied to `datasets/portrait_drawing/train/A(B)(_eyes)(_lips)` and must have the **same filenames** as the aligned face photos.
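For reference, the dilation step can be done with `opencv-python` (already in `requirements.txt`). This is a minimal sketch; the file name and kernel size below are illustrative assumptions, not values from the paper:
```python
import cv2
import numpy as np

mask = cv2.imread('nose_mask.png', cv2.IMREAD_GRAYSCALE)  # 0/255 region from the face parser
kernel = np.ones((15, 15), np.uint8)                      # dilation footprint (tune per dataset)
dilated = cv2.dilate(mask, kernel, iterations=1)          # grow the region to cover the feature
cv2.imwrite('nose_mask_dilated.png', dilated)
```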
================================================
FILE: readme.md
================================================
# Unpaired Portrait Drawing Generation via Asymmetric Cycle Mapping
We provide PyTorch implementations for our CVPR 2020 paper "Unpaired Portrait Drawing Generation via Asymmetric Cycle Mapping". [paper](https://openaccess.thecvf.com/content_CVPR_2020/papers/Yi_Unpaired_Portrait_Drawing_Generation_via_Asymmetric_Cycle_Mapping_CVPR_2020_paper.pdf), [suppl](https://openaccess.thecvf.com/content_CVPR_2020/supplemental/Yi_Unpaired_Portrait_Drawing_CVPR_2020_supplemental.pdf).
This project generates multi-style artistic portrait drawings from face photos using a GAN-based model.
[[Jittor implementation]](https://github.com/yiranran/Unpaired-Portrait-Drawing-Jittor)
## Our Proposed Framework
<img src = 'imgs/architecture.jpg'>
## Sample Results
From left to right: input, output(style1), output(style2), output(style3)
<img src = 'imgs/results.jpg'>
## Prerequisites
- Linux or macOS
- Python 3
- CPU or NVIDIA GPU + CUDA CuDNN
## Installation
- To install the dependencies, run
```bash
pip install -r requirements.txt
```
## Colab
A colab demo is [here](https://colab.research.google.com/drive/1U1fPXD1JukuKPOrhGMX1iaJC-d8_RUYr).
## Test steps (apply a pretrained model)
- 1. Download pre-trained models from [BaiduYun](https://pan.baidu.com/s/1_9Fy8mRpTQp6AvqhHsfQAQ)(extract code:c9h7) or [GoogleDrive](https://drive.google.com/drive/folders/1FzOcdlMYhvK_nyLCe8wnwotMphhIoiYt?usp=sharing) and rename the folder to `checkpoints`.
- 2. Test on the example photos: generate artistic portrait drawings for the example photos in the folder `./examples` using
``` bash
# with GPU
python test_seq_style.py
# without GPU
python test_seq_style.py --gpu -1
```
The test results will be saved to an HTML file here: `./results/pretrained/test_200/index3styles.html`.
The result images are saved in `./results/pretrained/test_200/images3styles`,
where `real`, `fake1`, `fake2`, `fake3` correspond to input face photo, style1 drawing, style2 drawing, style3 drawing respectively.
<img src = 'imgs/how_to_crop.jpg'>
- 3. To test on your own photos: first use an image editor to crop the face region of your photo (or use the optional preprocessing steps [here](preprocess/readme.md)). Then specify the folder that contains the test photos with option `--dataroot`, specify the save folder name with option `--savefolder`, and run the above command again:
``` bash
# with GPU
python test_seq_style.py --dataroot [input_folder] --savefolder [save_folder_name]
# without GPU
python test_seq_style.py --gpu -1 --dataroot [input_folder] --savefolder [save_folder_name]
# E.g.
python test_seq_style.py --gpu -1 --dataroot ./imgs/test1 --savefolder 3styles_test1
```
The test results will be saved to an HTML file here: `./results/pretrained/test_200/index[save_folder_name].html`.
The result images are saved in `./results/pretrained/test_200/images[save_folder_name]`.
An example html screenshot is shown below:
<img src = 'imgs/result_html.jpg'>
You can contact email yr16@mails.tsinghua.edu.cn for any questions.
## Train steps
- 1. Prepare the dataset: 1) download face photos and portrait drawings from the Internet (e.g. [resources](portrait_drawing_resources.md)); 2) align and crop the photos and drawings, and 3) prepare nose, eyes, lips masks, according to the [preprocess instructions](preprocess/readme.md); 4) put the aligned photos under `./datasets/portrait_drawing/train/A`, the aligned drawings under `./datasets/portrait_drawing/train/B`, and the masks under `A_nose`, `A_eyes`, `A_lips`, `B_nose`, `B_eyes`, `B_lips` respectively.
- 2. Train a 3-class style classifier and extract the 3-dim style feature of each drawing (as described in the paper), then save the style feature of each drawing in the training set in `.npy` format in the folder `./datasets/portrait_drawing/train/B_feat` (a sketch of the expected layout is shown after these steps).
A subset of our training set is [here](https://drive.google.com/file/d/1OSMOR3-uhGkoPwPFRNychJSNrpSak_23/view?usp=sharing).
- 3. Train our model
``` bash
sh ./scripts/train.sh
```
Models are saved in the folder `checkpoints/[name]`, e.g. `checkpoints/formal` for the `--name` used in `scripts/train.sh`.
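For step 2, a minimal sketch of writing one such feature file is shown below. It assumes the 3-class classifier outputs raw logits per drawing; the classifier itself is not part of this repository:
```python
import os
import numpy as np

def save_style_feature(image_name, logits, feat_dir='./datasets/portrait_drawing/train/B_feat'):
    """Save a 3-dim softmax style feature as <image_name>.npy, matching the expected layout."""
    os.makedirs(feat_dir, exist_ok=True)
    logits = np.asarray(logits, dtype=np.float64)
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()  # softmax over the 3 style classes
    np.save(os.path.join(feat_dir, os.path.splitext(image_name)[0] + '.npy'), probs)

save_style_feature('Yann_Legendre-053.png', [2.1, 0.3, -1.0])  # hypothetical logits
```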
## Citation
If you use this code for your research, please cite our paper.
```
@inproceedings{YiLLR20,
title = {Unpaired Portrait Drawing Generation via Asymmetric Cycle Mapping},
author = {Yi, Ran and Liu, Yong-Jin and Lai, Yu-Kun and Rosin, Paul L},
booktitle = {{IEEE} Conference on Computer Vision and Pattern Recognition (CVPR '20)},
pages = {8214--8222},
year = {2020}
}
```
## Acknowledgments
Our code is inspired by [pytorch-CycleGAN-and-pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix).
================================================
FILE: requirements.txt
================================================
torch==1.2.0
torchvision==0.4.0
dominate==2.4.0
visdom==0.1.8.9
scipy==1.1.0
numpy==1.16.4
Pillow==6.2.1
opencv-python==4.1.0.25
================================================
FILE: scripts/train.sh
================================================
set -ex
python train.py --dataroot ./datasets/portrait_drawing --name formal --model asymmetric_cycle_gan_cls --output_nc 1 --load_size 572 --crop_size 512 --lr 0.000015 --dataset_mode unaligned_mask_stylecls --display_env asymmetric_trainset --gpu_ids 0 --gpu_ids_p 0 --niter 100 --niter_decay 200 --n_epochs 200
================================================
FILE: test.py
================================================
"""General-purpose test script for image-to-image translation.
Once you have trained your model with train.py, you can use this script to test the model.
It will load a saved model from --checkpoints_dir and save the results to --results_dir.
It first creates model and dataset given the option. It will hard-code some parameters.
It then runs inference for --num_test images and save results to an HTML file.
Example (You need to train models first or download pre-trained models from our website):
Test a CycleGAN model (both sides):
python test.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan
Test a CycleGAN model (one side only):
python test.py --dataroot datasets/horse2zebra/testA --name horse2zebra_pretrained --model test --no_dropout
The option '--model test' is used for generating CycleGAN results only for one side.
This option will automatically set '--dataset_mode single', which only loads the images from one set.
On the contrary, using '--model cycle_gan' requires loading and generating results in both directions,
which is sometimes unnecessary. The results will be saved at ./results/.
Use '--results_dir <directory_path_to_save_result>' to specify the results directory.
Test a pix2pix model:
python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA
See options/base_options.py and options/test_options.py for more test options.
See training and test tips at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/tips.md
See frequently asked questions at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/qa.md
"""
import os
from options.test_options import TestOptions
from data import create_dataset
from models import create_model
from util.visualizer import save_images
from util import html
if __name__ == '__main__':
opt = TestOptions().parse() # get test options
# hard-code some parameters for test
opt.num_threads = 0   # test code only supports num_threads = 0
opt.batch_size = 1 # test code only supports batch_size = 1
opt.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed.
opt.no_flip = True # no flip; comment this line if results on flipped images are needed.
opt.display_id = -1 # no visdom display; the test code saves the results to a HTML file.
dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options
model = create_model(opt) # create a model given opt.model and other options
model.setup(opt) # regular setup: load and print networks; create schedulers
# create a website
web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.epoch)) # define the website directory
#webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.epoch))
webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.epoch), refresh=0, folder=opt.imagefolder)
# test with eval mode. This only affects layers like batchnorm and dropout.
# For [pix2pix]: we use batchnorm and dropout in the original pix2pix. You can experiment it with and without eval() mode.
# For [CycleGAN]: It should not affect CycleGAN as CycleGAN uses instancenorm without dropout.
if opt.eval:
model.eval()
for i, data in enumerate(dataset):
if i >= opt.num_test: # only apply our model to opt.num_test images.
break
model.set_input(data) # unpack data from data loader
model.test() # run inference
visuals = model.get_current_visuals() # get image results
img_path = model.get_image_paths() # get image paths
if i % 5 == 0: # save images to an HTML file
print('processing (%04d)-th image... %s' % (i, img_path))
save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize, W=opt.W, H=opt.H)
webpage.save() # save the HTML
================================================
FILE: test_seq_style.py
================================================
import os
import argparse
def opts():
parser = argparse.ArgumentParser()
parser.add_argument('-g','--gpu', default = '0', type = str, help = 'gpu ids, -1 for cpu, default is 0.')
parser.add_argument('-d','--dataroot', default = './examples', type = str, help = 'the input folder that contains test face photos, default is ./examples')
parser.add_argument('-s','--savefolder', default = '3styles', type = str, help = 'the name of save folder that contains result images, default is 3styles')
return parser.parse_args()
if __name__ == '__main__':
opt = opts()
exp = 'pretrained'
imgsize = 512
epoch = '200'
dataroot = opt.dataroot
gpu_id = opt.gpu
# test 3 styles in one pass
savefolder = 'images'+opt.savefolder
os.system('python3 test.py --dataroot %s --name %s --model test_3styles --output_nc 1 --no_dropout --num_test 1000 --epoch %s --imagefolder %s --crop_size %d --load_size %d --gpu_ids %s' % (dataroot,exp,epoch,savefolder,imgsize,imgsize,gpu_id))
print('check ./results/%s/test_%s/index%s.html'%(exp,epoch,savefolder[6:]))
print('saved to ./results/%s/test_%s/%s'%(exp,epoch,savefolder))
# test 3 styles separately
'''
for vec in [[1,0,0],[0,1,0],[0,0,1]]:
#1,0,0 for style1; 0,1,0 for style2; 0,0,1 for style3
svec = '%d,%d,%d' % (vec[0],vec[1],vec[2])
savefolder = 'imagesstyle%d-%d-%d'%(vec[0],vec[1],vec[2])
print('results/%s/test_%s/index%s.html'%(exp,epoch,savefolder[6:]))
os.system('python3 test.py --dataroot %s --name %s --model test --output_nc 1 --no_dropout --model_suffix _A --num_test 1000 --epoch %s --imagefolder %s --sinput svec --svec %s --crop_size %d --load_size %d --gpu_ids %s' % (dataroot,exp,epoch,savefolder,svec,imgsize,imgsize,gpu_id))
'''
================================================
FILE: train.py
================================================
"""General-purpose training script for image-to-image translation.
This script works for various models (with option '--model': e.g., pix2pix, cyclegan, colorization) and
different datasets (with option '--dataset_mode': e.g., aligned, unaligned, single, colorization).
You need to specify the dataset ('--dataroot'), experiment name ('--name'), and model ('--model').
It first creates model, dataset, and visualizer given the option.
It then does standard network training. During the training, it also visualize/save the images, print/save the loss plot, and save models.
The script supports continue/resume training. Use '--continue_train' to resume your previous training.
Example:
Train a CycleGAN model:
python train.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan
Train a pix2pix model:
python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA
See options/base_options.py and options/train_options.py for more training options.
See training and test tips at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/tips.md
See frequently asked questions at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/qa.md
"""
import time
from options.train_options import TrainOptions
from data import create_dataset
from models import create_model
from util.visualizer import Visualizer
import pdb
if __name__ == '__main__':
start = time.time()
opt = TrainOptions().parse() # get training options
dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options
dataset_size = len(dataset) # get the number of images in the dataset.
print('The number of training images = %d' % dataset_size)
model = create_model(opt) # create a model given opt.model and other options
model.setup(opt) # regular setup: load and print networks; create schedulers
visualizer = Visualizer(opt) # create a visualizer that display/save images and plots
total_iters = 0 # the total number of training iterations
for epoch in range(opt.epoch_count, opt.n_epochs + 1): # outer loop for different epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>
epoch_start_time = time.time() # timer for entire epoch
iter_data_time = time.time() # timer for data loading per iteration
epoch_iter = 0 # the number of training iterations in current epoch, reset to 0 every epoch
model.update_process(epoch)
for i, data in enumerate(dataset): # inner loop within one epoch
iter_start_time = time.time() # timer for computation per iteration
if total_iters % opt.print_freq == 0:
t_data = iter_start_time - iter_data_time
visualizer.reset()
total_iters += opt.batch_size
epoch_iter += opt.batch_size
model.set_input(data) # unpack data from dataset and apply preprocessing
model.optimize_parameters() # calculate loss functions, get gradients, update network weights
if total_iters % opt.display_freq == 0: # display images on visdom and save images to a HTML file
save_result = total_iters % opt.update_html_freq == 0
model.compute_visuals()
visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)
if total_iters % opt.print_freq == 0: # print training losses and save logging information to the disk
losses = model.get_current_losses()
t_comp = (time.time() - iter_start_time) / opt.batch_size
if opt.model == 'cycle_gan':
processes = [model.process] + model.lambda_As
visualizer.print_current_losses_process(epoch, epoch_iter, losses, t_comp, t_data, processes)
else:
visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data)
if opt.display_id > 0:
visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses)
if total_iters % opt.save_latest_freq == 0: # cache our latest model every <save_latest_freq> iterations
print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))
save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest'
model.save_networks(save_suffix)
iter_data_time = time.time()
if epoch % opt.save_epoch_freq == 0: # cache our model every <save_epoch_freq> epochs
print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))
model.save_networks('latest')
model.save_networks(epoch)
print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.n_epochs, time.time() - epoch_start_time))  # the loop runs to opt.n_epochs
model.update_learning_rate() # update learning rates at the end of every epoch.
print('Total Time Taken: %d sec' % (time.time() - start))
================================================
FILE: util/__init__.py
================================================
"""This package includes a miscellaneous collection of useful helper functions."""
================================================
FILE: util/get_data.py
================================================
from __future__ import print_function
import os
import tarfile
import requests
from warnings import warn
from zipfile import ZipFile
from bs4 import BeautifulSoup
from os.path import abspath, isdir, join, basename
class GetData(object):
"""A Python script for downloading CycleGAN or pix2pix datasets.
Parameters:
technique (str) -- One of: 'cyclegan' or 'pix2pix'.
verbose (bool) -- If True, print additional information.
Examples:
>>> from util.get_data import GetData
>>> gd = GetData(technique='cyclegan')
>>> new_data_path = gd.get(save_path='./datasets') # options will be displayed.
Alternatively, you can use the bash scripts 'scripts/download_pix2pix_model.sh'
and 'scripts/download_cyclegan_model.sh'.
"""
def __init__(self, technique='cyclegan', verbose=True):
url_dict = {
'pix2pix': 'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/',
'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets'
}
self.url = url_dict.get(technique.lower())
self._verbose = verbose
def _print(self, text):
if self._verbose:
print(text)
@staticmethod
def _get_options(r):
soup = BeautifulSoup(r.text, 'lxml')
options = [h.text for h in soup.find_all('a', href=True)
if h.text.endswith(('.zip', 'tar.gz'))]
return options
def _present_options(self):
r = requests.get(self.url)
options = self._get_options(r)
print('Options:\n')
for i, o in enumerate(options):
print("{0}: {1}".format(i, o))
choice = input("\nPlease enter the number of the "
"dataset above you wish to download:")
return options[int(choice)]
def _download_data(self, dataset_url, save_path):
if not isdir(save_path):
os.makedirs(save_path)
base = basename(dataset_url)
temp_save_path = join(save_path, base)
with open(temp_save_path, "wb") as f:
r = requests.get(dataset_url)
f.write(r.content)
if base.endswith('.tar.gz'):
obj = tarfile.open(temp_save_path)
elif base.endswith('.zip'):
obj = ZipFile(temp_save_path, 'r')
else:
raise ValueError("Unknown File Type: {0}.".format(base))
self._print("Unpacking Data...")
obj.extractall(save_path)
obj.close()
os.remove(temp_save_path)
def get(self, save_path, dataset=None):
"""
Download a dataset.
Parameters:
save_path (str) -- A directory to save the data to.
dataset (str) -- (optional). A specific dataset to download.
Note: this must include the file extension.
If None, options will be presented for you
to choose from.
Returns:
save_path_full (str) -- the absolute path to the downloaded data.
"""
if dataset is None:
selected_dataset = self._present_options()
else:
selected_dataset = dataset
save_path_full = join(save_path, selected_dataset.split('.')[0])
if isdir(save_path_full):
warn("\n'{0}' already exists. Voiding Download.".format(
save_path_full))
else:
self._print('Downloading Data...')
url = "{0}/{1}".format(self.url, selected_dataset)
self._download_data(url, save_path=save_path)
return abspath(save_path_full)
================================================
FILE: util/html.py
================================================
import dominate
from dominate.tags import meta, h3, table, tr, td, p, a, img, br
import os
class HTML:
"""This HTML class allows us to save images and write texts into a single HTML file.
It consists of functions such as <add_header> (add a text header to the HTML file),
<add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).
It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
"""
def __init__(self, web_dir, title, refresh=0, folder='images'):
"""Initialize the HTML classes
Parameters:
web_dir (str) -- a directory that stores the webpage. HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir/images/
title (str) -- the webpage name
refresh (int) -- how often the website refresh itself; if 0; no refreshing
"""
self.title = title
self.web_dir = web_dir
#self.img_dir = os.path.join(self.web_dir, 'images')
self.img_dir = os.path.join(self.web_dir, folder)
self.folder = folder
if not os.path.exists(self.web_dir):
os.makedirs(self.web_dir)
if not os.path.exists(self.img_dir):
os.makedirs(self.img_dir)
self.doc = dominate.document(title=title)
if refresh > 0:
with self.doc.head:
meta(http_equiv="refresh", content=str(refresh))
def get_image_dir(self):
"""Return the directory that stores images"""
return self.img_dir
def add_header(self, text):
"""Insert a header to the HTML file
Parameters:
text (str) -- the header text
"""
with self.doc:
h3(text)
def add_images(self, ims, txts, links, width=400):
"""add images to the HTML file
Parameters:
ims (str list) -- a list of image paths
txts (str list) -- a list of image names shown on the website
links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
"""
self.t = table(border=1, style="table-layout: fixed;") # Insert a table
self.doc.add(self.t)
with self.t:
with tr():
for im, txt, link in zip(ims, txts, links):
with td(style="word-wrap: break-word;", halign="center", valign="top"):
with p():
with a(href=os.path.join(self.folder, link)):  # link into the same folder the images are saved to
#img(style="width:%dpx" % width, src=os.path.join('images', im))
img(style="width:%dpx" % width, src=os.path.join(self.folder, im))
br()
p(txt)
def save(self):
"""save the current content to the HMTL file"""
#html_file = '%s/index.html' % self.web_dir
name = self.folder[6:] if self.folder[:6] == 'images' else self.folder
html_file = '%s/index%s.html' % (self.web_dir, name)
f = open(html_file, 'wt')
f.write(self.doc.render())
f.close()
if __name__ == '__main__': # we show an example usage here.
html = HTML('web/', 'test_html')
html.add_header('hello world')
ims, txts, links = [], [], []
for n in range(4):
ims.append('image_%d.png' % n)
txts.append('text_%d' % n)
links.append('image_%d.png' % n)
html.add_images(ims, txts, links)
html.save()
================================================
FILE: util/image_pool.py
================================================
import random
import torch
class ImagePool():
"""This class implements an image buffer that stores previously generated images.
This buffer enables us to update discriminators using a history of generated images
rather than the ones produced by the latest generators.
"""
def __init__(self, pool_size):
"""Initialize the ImagePool class
Parameters:
pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created
"""
self.pool_size = pool_size
if self.pool_size > 0: # create an empty pool
self.num_imgs = 0
self.images = []
def query(self, images):
"""Return an image from the pool.
Parameters:
images: the latest generated images from the generator
Returns images from the buffer.
With 50% probability, the buffer will return the input images.
With 50% probability, the buffer will return images previously stored in the buffer,
and insert the current images into the buffer.
"""
if self.pool_size == 0: # if the buffer size is 0, do nothing
return images
return_images = []
for image in images:
image = torch.unsqueeze(image.data, 0)
if self.num_imgs < self.pool_size: # if the buffer is not full; keep inserting current images to the buffer
self.num_imgs = self.num_imgs + 1
self.images.append(image)
return_images.append(image)
else:
p = random.uniform(0, 1)
if p > 0.5: # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer
random_id = random.randint(0, self.pool_size - 1) # randint is inclusive
tmp = self.images[random_id].clone()
self.images[random_id] = image
return_images.append(tmp)
else: # by another 50% chance, the buffer will return the current image
return_images.append(image)
return_images = torch.cat(return_images, 0) # collect all the images and return
return return_images
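# Usage sketch (an illustration; not used by the training scripts): query() wraps newly
# generated images before they are fed to the discriminator.
if __name__ == '__main__':
    pool = ImagePool(pool_size=50)
    for _ in range(3):
        fake = torch.randn(2, 1, 8, 8)  # stand-in for a batch of generator outputs
        fake_for_D = pool.query(fake)   # the new fakes or buffered ones (50/50 once full)
        print(fake_for_D.shape)         # torch.Size([2, 1, 8, 8])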
================================================
FILE: util/util.py
================================================
"""This module contains simple helper functions """
from __future__ import print_function
import torch
import numpy as np
from PIL import Image
import os
import pdb
from scipy.io import savemat
def tensor2im(input_image, imtype=np.uint8):
""""Converts a Tensor array into a numpy image array.
Parameters:
input_image (tensor) -- the input image tensor array
imtype (type) -- the desired type of the converted numpy array
"""
if not isinstance(input_image, np.ndarray):
if isinstance(input_image, torch.Tensor): # get the data from a variable
image_tensor = input_image.data
else:
return input_image
#pdb.set_trace()
image_numpy = image_tensor[0].cpu().float().numpy() # convert it into a numpy array
if image_numpy.shape[0] == 1: # grayscale to RGB
image_numpy = np.tile(image_numpy, (3, 1, 1))
elif image_numpy.shape[0] == 2:
image_numpy = np.concatenate([image_numpy, image_numpy[1:2,:,:]], 0)
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 # post-processing: transpose and scaling
else: # if it is a numpy array, do nothing
image_numpy = input_image
return image_numpy.astype(imtype)
#return np.round(image_numpy).astype(imtype),image_numpy
def diagnose_network(net, name='network'):
"""Calculate and print the mean of average absolute(gradients)
Parameters:
net (torch network) -- Torch network
name (str) -- the name of the network
"""
mean = 0.0
count = 0
for param in net.parameters():
if param.grad is not None:
mean += torch.mean(torch.abs(param.grad.data))
count += 1
if count > 0:
mean = mean / count
print(name)
print(mean)
def save_image(image_numpy, image_path):
"""Save a numpy image to the disk
Parameters:
image_numpy (numpy array) -- input numpy array
image_path (str) -- the path of the image
"""
image_pil = Image.fromarray(image_numpy)
#pdb.set_trace()
image_pil.save(image_path)
def print_numpy(x, val=True, shp=False):
"""Print the mean, min, max, median, std, and size of a numpy array
Parameters:
val (bool) -- whether to print the values of the numpy array
shp (bool) -- whether to print the shape of the numpy array
"""
x = x.astype(np.float64)
if shp:
print('shape,', x.shape)
if val:
x = x.flatten()
print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
def mkdirs(paths):
"""create empty directories if they don't exist
Parameters:
paths (str list) -- a list of directory paths
"""
if isinstance(paths, list) and not isinstance(paths, str):
for path in paths:
mkdir(path)
else:
mkdir(paths)
def mkdir(path):
"""create a single empty directory if it didn't exist
Parameters:
path (str) -- a single directory path
"""
if not os.path.exists(path):
os.makedirs(path)
def normalize_tensor(in_feat,eps=1e-10):
norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True))
return in_feat/(norm_factor+eps)
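# Quick sanity check of tensor2im (an illustration; not used by the training scripts):
if __name__ == '__main__':
    t = torch.zeros(1, 1, 4, 4)  # one grayscale image with values in [-1, 1]
    out = tensor2im(t)
    print(out.shape, out[0, 0])  # (4, 4, 3) [127 127 127]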
================================================
FILE: util/visualizer.py
================================================
import numpy as np
import os
import sys
import ntpath
import time
from . import util, html
from subprocess import Popen, PIPE
from PIL import Image
if sys.version_info[0] == 2:
VisdomExceptionBase = Exception
else:
VisdomExceptionBase = ConnectionError
def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256, W=None, H=None):
"""Save images to the disk.
Parameters:
webpage (the HTML class) -- the HTML webpage class that stores these images (see html.py for more details)
visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy)) pairs
image_path (str) -- the string is used to create image paths
aspect_ratio (float) -- the aspect ratio of saved images
width (int) -- the display width of the images in the HTML file
W, H (int) -- if both are given, saved images are resized to exactly W x H
This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.
"""
image_dir = webpage.get_image_dir()
short_path = ntpath.basename(image_path[0])
name = os.path.splitext(short_path)[0]
webpage.add_header(name)
ims, txts, links = [], [], []
for label, im_data in visuals.items():
## tensor to im
im = util.tensor2im(im_data)
image_name = '%s_%s.png' % (name, label)
save_path = os.path.join(image_dir, image_name)
h, w, _ = im.shape
if W is not None and H is not None and (W != w or H != h):
im = np.array(Image.fromarray(im).resize((W, H), Image.BICUBIC))
else:
if aspect_ratio > 1.0:
im = np.array(Image.fromarray(im).resize((int(w * aspect_ratio), h), Image.BICUBIC))
if aspect_ratio < 1.0:
im = np.array(Image.fromarray(im).resize((w, int(h / aspect_ratio)), Image.BICUBIC))
util.save_image(im, save_path)
ims.append(image_name)
txts.append(label)
links.append(image_name)
webpage.add_images(ims, txts, links, width=width)
class Visualizer():
"""This class includes several functions that can display/save images and print/save logging information.
It uses a Python library 'visdom' for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images.
"""
def __init__(self, opt):
"""Initialize the Visualizer class
Parameters:
opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
Step 1: Cache the training/test options
Step 2: connect to a visdom server
Step 3: create an HTML object for saving HTML files
Step 4: create a logging file to store training losses
"""
self.opt = opt # cache the option
self.display_id = opt.display_id
self.use_html = opt.isTrain and not opt.no_html
self.win_size = opt.display_winsize
self.name = opt.name
self.port = opt.display_port
self.saved = False
if self.display_id > 0: # connect to a visdom server given <display_port> and <display_server>
import visdom
self.ncols = opt.display_ncols
self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env)
if not self.vis.check_connection():
self.create_visdom_connections()
if self.use_html: # create an HTML object at <checkpoints_dir>/web/; images will be saved under <checkpoints_dir>/web/images/
self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
self.img_dir = os.path.join(self.web_dir, 'images')
print('create web directory %s...' % self.web_dir)
util.mkdirs([self.web_dir, self.img_dir])
# create a logging file to store training losses
self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
with open(self.log_name, "a") as log_file:
now = time.strftime("%c")
log_file.write('================ Training Loss (%s) ================\n' % now)
def reset(self):
"""Reset the self.saved status"""
self.saved = False
def create_visdom_connections(self):
"""If the program could not connect to Visdom server, this function will start a new server at port < self.port > """
cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port
print('\n\nCould not connect to Visdom server. \n Trying to start a server....')
print('Command: %s' % cmd)
Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)
def display_current_results(self, visuals, epoch, save_result):
"""Display current results on visdom; save current results to an HTML file.
Parameters:
visuals (OrderedDict) - - dictionary of images to display or save
epoch (int) - - the current epoch
save_result (bool) - - whether to save the current results to an HTML file
"""
if self.display_id > 0: # show images in the browser using visdom
ncols = self.ncols
if ncols > 0: # show all the images in one visdom panel
ncols = min(ncols, len(visuals))
h, w = next(iter(visuals.values())).shape[:2]
table_css = """<style>
table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center}
table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
</style>""" % (w, h) # create a table css
# create a table of images.
title = self.name
label_html = ''
label_html_row = ''
images = []
idx = 0
for label, image in visuals.items():
image_numpy = util.tensor2im(image)
label_html_row += '<td>%s</td>' % label
images.append(image_numpy.transpose([2, 0, 1]))
idx += 1
if idx % ncols == 0:
label_html += '<tr>%s</tr>' % label_html_row
label_html_row = ''
white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
while idx % ncols != 0:
images.append(white_image)
label_html_row += '<td></td>'
idx += 1
if label_html_row != '':
label_html += '<tr>%s</tr>' % label_html_row
try:
self.vis.images(images, nrow=ncols, win=self.display_id + 1,
padding=2, opts=dict(title=title + ' images'))
label_html = '<table>%s</table>' % label_html
self.vis.text(table_css + label_html, win=self.display_id + 2,
opts=dict(title=title + ' labels'))
except VisdomExceptionBase:
self.create_visdom_connections()
else: # show each image in a separate visdom panel;
idx = 1
try:
for label, image in visuals.items():
image_numpy = util.tensor2im(image)
self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label),
win=self.display_id + idx)
idx += 1
except VisdomExceptionBase:
self.create_visdom_connections()
if self.use_html and (save_result or not self.saved): # save images to an HTML file if they haven't been saved.
self.saved = True
# save images to the disk
for label, image in visuals.items():
image_numpy = util.tensor2im(image)
img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
util.save_image(image_numpy, img_path)
# update website
webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1)
for n in range(epoch, 0, -1):
webpage.add_header('epoch [%d]' % n)
ims, txts, links = [], [], []
for label, image_numpy in visuals.items():
image_numpy = util.tensor2im(image_numpy)  # convert the stored visual for this label
img_path = 'epoch%.3d_%s.png' % (n, label)
ims.append(img_path)
txts.append(label)
links.append(img_path)
webpage.add_images(ims, txts, links, width=self.win_size)
webpage.save()
def plot_current_losses(self, epoch, counter_ratio, losses):
"""display the current losses on visdom display: dictionary of error labels and values
Parameters:
epoch (int) -- current epoch
counter_ratio (float) -- progress (percentage) in the current epoch, between 0 and 1
losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
"""
if not hasattr(self, 'plot_data'):
self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())}
self.plot_data['X'].append(epoch + counter_ratio)
self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']])
#X = np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1)
#Y = np.array(self.plot_data['Y'])
try:
self.vis.line(
X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
Y=np.array(self.plot_data['Y']),
opts={
'title': self.name + ' loss over time',
'legend': self.plot_data['legend'],
'xlabel': 'epoch',
'ylabel': 'loss'},
win=self.display_id)
except VisdomExceptionBase:
self.create_visdom_connections()
# losses: same format as |losses| of plot_current_losses
def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
"""print current losses on console; also save the losses to the disk
Parameters:
epoch (int) -- current epoch
iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)
losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
t_comp (float) -- computational time per data point (normalized by batch_size)
t_data (float) -- data loading time per data point (normalized by batch_size)
"""
message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
for k, v in losses.items():
message += '%s: %.3f ' % (k, v)
print(message) # print the message
with open(self.log_name, "a") as log_file:
log_file.write('%s\n' % message) # save the message
# losses: same format as |losses| of plot_current_losses
def print_current_losses_process(self, epoch, iters, losses, t_comp, t_data, processes):
"""print current losses on console; also save the losses to the disk
Parameters:
epoch (int) -- current epoch
iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)
losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
t_comp (float) -- computational time per data point (normalized by batch_size)
t_data (float) -- data loading time per data point (normalized by batch_size)
"""
message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
message += '[process: %.3f, non_trunc: %.3f, trunc: %.3f] ' % (processes[0], processes[1], processes[2])
for k, v in losses.items():
message += '%s: %.3f ' % (k, v)
print(message) # print the message
with open(self.log_name, "a") as log_file:
log_file.write('%s\n' % message) # save the message
SYMBOL INDEX (224 symbols across 23 files)
FILE: data/__init__.py
function find_dataset_using_name (line 18) | def find_dataset_using_name(dataset_name):
function get_option_setter (line 41) | def get_option_setter(dataset_name):
function create_dataset (line 47) | def create_dataset(opt):
class CustomDatasetDataLoader (line 62) | class CustomDatasetDataLoader():
method __init__ (line 65) | def __init__(self, opt):
method load_data (line 81) | def load_data(self):
method __len__ (line 84) | def __len__(self):
method __iter__ (line 88) | def __iter__(self):
FILE: data/base_dataset.py
class BaseDataset (line 13) | class BaseDataset(data.Dataset):
method __init__ (line 24) | def __init__(self, opt):
method modify_commandline_options (line 34) | def modify_commandline_options(parser, is_train):
method __len__ (line 47) | def __len__(self):
method __getitem__ (line 52) | def __getitem__(self, index):
function get_params (line 64) | def get_params(opt, size):
function get_transform (line 82) | def get_transform(opt, params=None, grayscale=False, method=Image.BICUBI...
function get_transform_mask (line 115) | def get_transform_mask(opt, params=None, grayscale=False, method=Image.B...
function __make_power_2 (line 144) | def __make_power_2(img, base, method=Image.BICUBIC):
function __scale_width (line 155) | def __scale_width(img, target_width, method=Image.BICUBIC):
function __crop (line 164) | def __crop(img, pos, size):
function __flip (line 173) | def __flip(img, flip):
function __print_size_warning (line 179) | def __print_size_warning(ow, oh, w, h):
FILE: data/image_folder.py
function is_image_file (line 19) | def is_image_file(filename):
function make_dataset (line 23) | def make_dataset(dir, max_dataset_size=float("inf")):
function default_loader (line 35) | def default_loader(path):
class ImageFolder (line 39) | class ImageFolder(data.Dataset):
method __init__ (line 41) | def __init__(self, root, transform=None, return_paths=False,
method __getitem__ (line 55) | def __getitem__(self, index):
method __len__ (line 65) | def __len__(self):
FILE: data/single_dataset.py
class SingleDataset (line 8) | class SingleDataset(BaseDataset):
method __init__ (line 14) | def __init__(self, opt):
method __getitem__ (line 28) | def __getitem__(self, index):
method __len__ (line 61) | def __len__(self):
FILE: data/unaligned_mask_stylecls_dataset.py
class UnalignedMaskStyleClsDataset (line 12) | class UnalignedMaskStyleClsDataset(BaseDataset):
method __init__ (line 13) | def __init__(self, opt):
method __getitem__ (line 34) | def __getitem__(self, index):
method __len__ (line 98) | def __len__(self):
FILE: models/__init__.py
function find_model_using_name (line 25) | def find_model_using_name(model_name):
function get_option_setter (line 48) | def get_option_setter(model_name):
function create_model (line 54) | def create_model(opt):
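create_model mirrors the dataset factory: find_model_using_name resolves opt.model to a BaseModel subclass (for example 'asymmetric_cycle_gan_cls' to AsymmetricCycleGANClsModel), and the caller then runs setup to build schedulers or restore checkpoints. Sketch, assuming a parsed opt:

from models import create_model

model = create_model(opt)   # e.g. opt.model == 'asymmetric_cycle_gan_cls'
model.setup(opt)            # schedulers in train mode, weight loading in test mode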
FILE: models/asymmetric_cycle_gan_cls_model.py
function truncate (line 10) | def truncate(fake_B,a=127.5):#[-1,1]
class AsymmetricCycleGANClsModel (line 13) | class AsymmetricCycleGANClsModel(BaseModel):
method modify_commandline_options (line 15) | def modify_commandline_options(parser, is_train=True):
method __init__ (line 44) | def __init__(self, opt):
method set_input (line 176) | def set_input(self, input):
method forward (line 206) | def forward(self):
method backward_D_basic (line 231) | def backward_D_basic(self, netD, real, fake):
method backward_D_basic_cls (line 253) | def backward_D_basic_cls(self, netD, real, fake):
method backward_D_A (line 275) | def backward_D_A(self):
method backward_D_A_l (line 280) | def backward_D_A_l(self):
method backward_D_A_le (line 285) | def backward_D_A_le(self):
method backward_D_A_ll (line 290) | def backward_D_A_ll(self):
method backward_D_B (line 295) | def backward_D_B(self):
method update_process (line 300) | def update_process(self, epoch):
method backward_G (line 303) | def backward_G(self):
method optimize_parameters (line 374) | def optimize_parameters(self):
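The truncate helper at the top of the file quantizes a generated drawing from continuous values onto the 256 discrete 8-bit pixel levels. This blocks the well-known CycleGAN failure mode of hiding reconstruction information in imperceptible low-amplitude signals, which matters here because the photo-to-drawing direction is deliberately lossy. The body below is a plausible reconstruction from the signature and its [-1,1] comment, not a verbatim copy:

import torch

def truncate(fake_B, a=127.5):  # fake_B in [-1, 1]
    # map to [0, 255], round to integer pixel levels, map back to [-1, 1]
    return torch.round((fake_B + 1.0) * a) / a - 1.0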
FILE: models/base_model.py
class BaseModel (line 8) | class BaseModel():
method __init__ (line 19) | def __init__(self, opt):
method modify_commandline_options (line 48) | def modify_commandline_options(parser, is_train):
method set_input (line 61) | def set_input(self, input):
method forward (line 70) | def forward(self):
method optimize_parameters (line 75) | def optimize_parameters(self):
method setup (line 79) | def setup(self, opt):
method eval (line 92) | def eval(self):
method test (line 99) | def test(self):
method compute_visuals (line 109) | def compute_visuals(self):
method get_image_paths (line 113) | def get_image_paths(self):
method update_learning_rate (line 117) | def update_learning_rate(self):
method get_current_visuals (line 128) | def get_current_visuals(self):
method get_current_losses (line 136) | def get_current_losses(self):
method save_networks (line 144) | def save_networks(self, epoch):
method __patch_instance_norm_state_dict (line 162) | def __patch_instance_norm_state_dict(self, state_dict, module, keys, i...
method load_networks (line 176) | def load_networks(self, epoch):
method print_networks (line 201) | def print_networks(self, verbose):
method set_requires_grad (line 219) | def set_requires_grad(self, nets, requires_grad=False):
method masked (line 233) | def masked(self, A,mask):
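In combination, this API gives the test-time loop used by test.py: setup restores saved networks when opt.isTrain is False, and test wraps forward in torch.no_grad(). A sketch, with opt and dataset produced by the factories shown earlier:

model = create_model(opt)
model.setup(opt)                          # restores saved generator weights at test time
if opt.eval:
    model.eval()                          # BatchNorm/Dropout to eval behaviour
for i, data in enumerate(dataset):
    model.set_input(data)                 # unpack the dataset dict onto the model
    model.test()                          # forward() under torch.no_grad()
    visuals = model.get_current_visuals() # OrderedDict of name -> image tensor
    img_path = model.get_image_paths()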
FILE: models/dist_model.py
class DistModel (line 20) | class DistModel(BaseModel):
method name (line 21) | def name(self):
method __init__ (line 24) | def __init__(self, opt, model='net-lin', net='alex', pnet_rand=False, ...
method forward_pair (line 106) | def forward_pair(self,in1,in2,retPerLayer=False):
method forward (line 112) | def forward(self, in0, in1, retNumpy=False):
method optimize_parameters (line 159) | def optimize_parameters(self):
method clamp_weights (line 166) | def clamp_weights(self):
method set_input (line 171) | def set_input(self, data):
method forward_train (line 187) | def forward_train(self): # run forward pass
method backward_train (line 198) | def backward_train(self):
method compute_accuracy (line 201) | def compute_accuracy(self,d0,d1,judge):
method get_current_errors (line 207) | def get_current_errors(self):
method get_current_visuals (line 216) | def get_current_visuals(self):
method save (line 231) | def save(self, path, label):
method update_learning_rate (line 235) | def update_learning_rate(self,nepoch_decay):
function score_2afc_dataset (line 247) | def score_2afc_dataset(data_loader,func):
function score_jnd_dataset (line 284) | def score_jnd_dataset(data_loader,func):
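DistModel is the bundled LPIPS-style perceptual metric; the 'net-lin' configuration places learned linear weights on AlexNet or VGG features, and forward returns a per-pair distance for batches in [-1, 1]. A sketch under the assumption that the opt object it receives is the same one the rest of the code carries; the exact construction arguments should be checked against the full signature:

import torch
from models.dist_model import DistModel

dm = DistModel(opt, model='net-lin', net='alex')  # learned-linear LPIPS on AlexNet features
x0 = torch.rand(1, 3, 256, 256) * 2 - 1           # inputs expected in [-1, 1]
x1 = torch.rand(1, 3, 256, 256) * 2 - 1
dist = dm.forward(x0, x1, retNumpy=True)          # smaller = perceptually closer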
FILE: models/networks.py
class Identity (line 13) | class Identity(nn.Module):
method forward (line 14) | def forward(self, x):
function get_norm_layer (line 18) | def get_norm_layer(norm_type='instance'):
function get_scheduler (line 38) | def get_scheduler(optimizer, opt):
function init_weights (line 67) | def init_weights(net, init_type='normal', init_gain=0.02):
function init_net (line 101) | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
function define_G (line 119) | def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=F...
function define_D (line 164) | def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type=...
function define_HED (line 210) | def define_HED(init_weights_, gpu_ids_=[]):
class GANLoss (line 233) | class GANLoss(nn.Module):
method __init__ (line 240) | def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=...
method get_target_tensor (line 264) | def get_target_tensor(self, prediction, target_is_real):
method __call__ (line 281) | def __call__(self, prediction, target_is_real):
function cal_gradient_penalty (line 302) | def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed...
class ResnetGenerator (line 339) | class ResnetGenerator(nn.Module):
method __init__ (line 345) | def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNor...
method forward (line 395) | def forward(self, input):
class ResnetStyle2Generator (line 399) | class ResnetStyle2Generator(nn.Module):
method __init__ (line 400) | def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNor...
method forward (line 459) | def forward(self, input1, input2):
class ResnetBlock (line 465) | class ResnetBlock(nn.Module):
method __init__ (line 468) | def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bia...
method build_conv_block (line 479) | def build_conv_block(self, dim, padding_type, norm_layer, use_dropout,...
method forward (line 520) | def forward(self, x):
class UnetGenerator (line 526) | class UnetGenerator(nn.Module):
method __init__ (line 529) | def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=...
method forward (line 553) | def forward(self, input):
class UnetSkipConnectionBlock (line 558) | class UnetSkipConnectionBlock(nn.Module):
method __init__ (line 564) | def __init__(self, outer_nc, inner_nc, input_nc=None,
method forward (line 621) | def forward(self, x):
class NLayerDiscriminator (line 628) | class NLayerDiscriminator(nn.Module):
method __init__ (line 631) | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNo...
method forward (line 671) | def forward(self, input):
class NLayerDiscriminatorCls (line 676) | class NLayerDiscriminatorCls(nn.Module):
method __init__ (line 679) | def __init__(self, input_nc, ndf=64, n_layers=3, n_class=3, norm_layer...
method forward (line 736) | def forward(self, input):
class PixelDiscriminator (line 746) | class PixelDiscriminator(nn.Module):
method __init__ (line 749) | def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
method forward (line 773) | def forward(self, input):
class HED (line 778) | class HED(nn.Module):
method __init__ (line 779) | def __init__(self):
method forward (line 838) | def forward(self, tensorInput):
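define_G and define_D are the network factories: each picks an architecture by name, initializes it through init_net, and moves it to the listed GPUs; GANLoss then hides the difference between lsgan, vanilla, and wgangp targets. The values below are illustrative ('resnet_9blocks' is the classic generator key; the style-conditioned ResnetStyle2Generator is selected by a repo-specific key, and define_G may take extra keyword arguments for it):

import torch
from models import networks

netG = networks.define_G(3, 1, 64, 'resnet_9blocks', norm='instance', gpu_ids=[])
netD = networks.define_D(1, 64, 'basic', n_layers_D=3, norm='instance', gpu_ids=[])
criterionGAN = networks.GANLoss('lsgan')

real_drawing = torch.rand(1, 1, 256, 256) * 2 - 1  # stand-in single-channel drawing
pred_real = netD(real_drawing)                     # patch-level predictions
loss_D_real = criterionGAN(pred_real, True)        # target tensor built internally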
FILE: models/networks_basic.py
function spatial_average (line 17) | def spatial_average(in_tens, keepdim=True):
function upsample (line 20) | def upsample(in_tens, out_H=64): # assumes scale factor is same for H and W
class PNetLin (line 27) | class PNetLin(nn.Module):
method __init__ (line 28) | def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, ...
method forward (line 64) | def forward(self, in0, in1, retPerLayer=False):
class ScalingLayer (line 94) | class ScalingLayer(nn.Module):
method __init__ (line 95) | def __init__(self):
method forward (line 100) | def forward(self, inp):
class NetLinLayer (line 104) | class NetLinLayer(nn.Module):
method __init__ (line 106) | def __init__(self, chn_in, chn_out=1, use_dropout=False):
class Dist2LogitLayer (line 114) | class Dist2LogitLayer(nn.Module):
method __init__ (line 116) | def __init__(self, chn_mid=32, use_sigmoid=True):
method forward (line 128) | def forward(self,d0,d1,eps=0.1):
class BCERankingLoss (line 131) | class BCERankingLoss(nn.Module):
method __init__ (line 132) | def __init__(self, chn_mid=32):
method forward (line 138) | def forward(self, d0, d1, judge):
class FakeNet (line 144) | class FakeNet(nn.Module):
method __init__ (line 145) | def __init__(self, use_gpu=True, colorspace='Lab'):
class L2 (line 150) | class L2(FakeNet):
method forward (line 152) | def forward(self, in0, in1, retPerLayer=None):
class DSSIM (line 167) | class DSSIM(FakeNet):
method forward (line 169) | def forward(self, in0, in1, retPerLayer=None):
function print_network (line 182) | def print_network(net):
FILE: models/pretrained_networks.py
class squeezenet (line 6) | class squeezenet(torch.nn.Module):
method __init__ (line 7) | def __init__(self, requires_grad=False, pretrained=True):
method forward (line 36) | def forward(self, X):
class alexnet (line 57) | class alexnet(torch.nn.Module):
method __init__ (line 58) | def __init__(self, requires_grad=False, pretrained=True):
method forward (line 81) | def forward(self, X):
class vgg16 (line 97) | class vgg16(torch.nn.Module):
method __init__ (line 98) | def __init__(self, requires_grad=False, pretrained=True):
method forward (line 121) | def forward(self, X):
class resnet (line 139) | class resnet(torch.nn.Module):
method __init__ (line 140) | def __init__(self, requires_grad=False, pretrained=True, num=18):
method forward (line 163) | def forward(self, X):
FILE: models/test_3styles_model.py
class Test3StylesModel (line 5) | class Test3StylesModel(BaseModel):
method modify_commandline_options (line 12) | def modify_commandline_options(parser, is_train=True):
method __init__ (line 34) | def __init__(self, opt):
method set_input (line 51) | def set_input(self, input):
method forward (line 58) | def forward(self):
method optimize_parameters (line 64) | def optimize_parameters(self):
FILE: models/test_model.py
class TestModel (line 5) | class TestModel(BaseModel):
method modify_commandline_options (line 12) | def modify_commandline_options(parser, is_train=True):
method __init__ (line 40) | def __init__(self, opt):
method set_input (line 71) | def set_input(self, input):
method forward (line 84) | def forward(self):
method optimize_parameters (line 93) | def optimize_parameters(self):
FILE: options/base_options.py
class BaseOptions (line 9) | class BaseOptions():
method __init__ (line 16) | def __init__(self):
method initialize (line 20) | def initialize(self, parser):
method gather_options (line 61) | def gather_options(self):
method print_options (line 89) | def print_options(self, opt):
method parse (line 114) | def parse(self):
FILE: options/test_options.py
class TestOptions (line 4) | class TestOptions(BaseOptions):
method initialize (line 10) | def initialize(self, parser):
FILE: options/train_options.py
class TrainOptions (line 4) | class TrainOptions(BaseOptions):
method initialize (line 10) | def initialize(self, parser):
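Both subclasses only extend initialize; the work happens in BaseOptions.parse, which calls gather_options so that the chosen model and dataset classes can inject their own flags through the two get_option_setter hooks before the final argparse pass, then prints and saves the result and sets isTrain. Entry-point usage:

from options.train_options import TrainOptions
from options.test_options import TestOptions

opt = TrainOptions().parse()    # opt.isTrain == True; includes model/dataset extras
# opt = TestOptions().parse()   # test-time counterpart with its own defaults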
FILE: test_seq_style.py
function opts (line 4) | def opts():
FILE: util/get_data.py
class GetData (line 11) | class GetData(object):
method __init__ (line 27) | def __init__(self, technique='cyclegan', verbose=True):
method _print (line 35) | def _print(self, text):
method _get_options (line 40) | def _get_options(r):
method _present_options (line 46) | def _present_options(self):
method _download_data (line 56) | def _download_data(self, dataset_url, save_path):
method get (line 79) | def get(self, save_path, dataset=None):
FILE: util/html.py
class HTML (line 6) | class HTML:
method __init__ (line 14) | def __init__(self, web_dir, title, refresh=0, folder='images'):
method get_image_dir (line 37) | def get_image_dir(self):
method add_header (line 41) | def add_header(self, text):
method add_images (line 50) | def add_images(self, ims, txts, links, width=400):
method save (line 71) | def save(self):
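HTML is a thin dominate wrapper the visualizer uses to write an index page of result images. A self-contained sketch (directory and file names are placeholders; image paths are taken relative to web_dir/<folder>):

from util.html import HTML

page = HTML('web', 'test results')        # creates web/ and web/images/
page.add_header('example outputs')
page.add_images(['a.png', 'b.png'],       # image files inside web/images/
                ['input', 'drawing'],     # captions under each image
                ['a.png', 'b.png'])       # click-through link targets
page.save()                               # writes web/index.html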
FILE: util/image_pool.py
class ImagePool (line 5) | class ImagePool():
method __init__ (line 12) | def __init__(self, pool_size):
method query (line 23) | def query(self, images):
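ImagePool is the standard history buffer for GAN training: once full, query returns each incoming fake image or a randomly stored older one with equal probability, which keeps the discriminator from overfitting to the latest generator outputs. Typical call site, sketched with a random stand-in batch:

import torch
from util.image_pool import ImagePool

pool = ImagePool(pool_size=50)        # pool_size 0 disables the buffer
fake = torch.rand(1, 1, 256, 256)     # stand-in for a freshly generated drawing
fake_for_D = pool.query(fake)         # this batch, or an older one from the pool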
FILE: util/util.py
function tensor2im (line 11) | def tensor2im(input_image, imtype=np.uint8):
function diagnose_network (line 36) | def diagnose_network(net, name='network'):
function save_image (line 55) | def save_image(image_numpy, image_path):
function print_numpy (line 67) | def print_numpy(x, val=True, shp=False):
function mkdirs (line 83) | def mkdirs(paths):
function mkdir (line 96) | def mkdir(path):
function normalize_tensor (line 105) | def normalize_tensor(in_feat,eps=1e-10):
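tensor2im and save_image bridge network output and files: tensor2im takes the first image of an (N, C, H, W) tensor in [-1, 1] and returns an HWC uint8 array in [0, 255]. Sketch with a random stand-in and a hypothetical output folder:

import torch
from util.util import tensor2im, save_image, mkdir

fake = torch.rand(1, 3, 256, 256) * 2 - 1   # stand-in network output in [-1, 1]
arr = tensor2im(fake)                        # HWC uint8 array
mkdir('results_demo')                        # hypothetical output folder
save_image(arr, 'results_demo/fake.png')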
FILE: util/visualizer.py
function save_images (line 16) | def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=25...
class Visualizer (line 56) | class Visualizer():
method __init__ (line 62) | def __init__(self, opt):
method reset (line 97) | def reset(self):
method create_visdom_connections (line 101) | def create_visdom_connections(self):
method display_current_results (line 108) | def display_current_results(self, visuals, epoch, save_result):
method plot_current_losses (line 189) | def plot_current_losses(self, epoch, counter_ratio, losses):
method print_current_losses (line 217) | def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
method print_current_losses_process (line 236) | def print_current_losses_process(self, epoch, iters, losses, t_comp, t...
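Visualizer ties these utilities together during training: it pushes current images to visdom and/or an HTML page and logs losses to console and file. The per-iteration calls made by train.py, schematically (model and opt as above; a visdom server must be running if opt.display_id > 0; the counters and timings here are placeholders):

from util.visualizer import Visualizer

visualizer = Visualizer(opt)
visualizer.reset()                            # once per epoch
epoch, iters = 1, 100                         # placeholder counters
t_comp, t_data = 0.1, 0.01                    # placeholder timings
visualizer.display_current_results(model.get_current_visuals(), epoch, True)
losses = model.get_current_losses()
visualizer.print_current_losses(epoch, iters, losses, t_comp, t_data)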
Condensed preview: 35 files, each entry listing the file path, its character count, and a short snippet of its content (the full structured content runs to 196K chars).
[
{
"path": ".gitignore",
"chars": 756,
"preview": ".DS_Store\ndebug*\ndatasets/\ncheckpoints/\nstyle_features/\nresults/\nbuild/\ndist/\ntorch.egg-info/\n*/**/__pycache__\ntorch/ver"
},
{
"path": "data/__init__.py",
"chars": 3554,
"preview": "\"\"\"This package includes all the modules related to data loading and preprocessing\n\n To add a custom dataset class calle"
},
{
"path": "data/base_dataset.py",
"chars": 6615,
"preview": "\"\"\"This module implements an abstract base class (ABC) 'BaseDataset' for datasets.\n\nIt also includes common transformati"
},
{
"path": "data/image_folder.py",
"chars": 1893,
"preview": "\"\"\"A modified image folder class\n\nWe modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/ma"
},
{
"path": "data/single_dataset.py",
"chars": 2583,
"preview": "from data.base_dataset import BaseDataset, get_transform, get_params, get_transform_mask\nfrom data.image_folder import m"
},
{
"path": "data/unaligned_mask_stylecls_dataset.py",
"chars": 4851,
"preview": "import os.path\nfrom data.base_dataset import BaseDataset, get_params, get_transform, get_transform_mask\nfrom data.image_"
},
{
"path": "models/__init__.py",
"chars": 3072,
"preview": "\"\"\"This package contains modules related to objective functions, optimizations, and network architectures.\n\nTo add a cus"
},
{
"path": "models/asymmetric_cycle_gan_cls_model.py",
"chars": 23569,
"preview": "import torch\nimport itertools\nfrom util.image_pool import ImagePool\nfrom .base_model import BaseModel\nfrom . import netw"
},
{
"path": "models/base_model.py",
"chars": 10885,
"preview": "import os\nimport torch\nfrom collections import OrderedDict\nfrom abc import ABCMeta, abstractmethod\nfrom . import network"
},
{
"path": "models/dist_model.py",
"chars": 13695,
"preview": "\nfrom __future__ import absolute_import\n\nimport sys\nsys.path.append('..')\nsys.path.append('.')\nimport numpy as np\nimport"
},
{
"path": "models/networks.py",
"chars": 35426,
"preview": "import torch\nimport torch.nn as nn\nfrom torch.nn import init\nimport functools\nfrom torch.optim import lr_scheduler\n\n\n###"
},
{
"path": "models/networks_basic.py",
"chars": 7514,
"preview": "\nfrom __future__ import absolute_import\n\nimport sys\nimport torch\nimport torch.nn as nn\nimport torch.nn.init as init\nfrom"
},
{
"path": "models/pretrained_networks.py",
"chars": 6559,
"preview": "from collections import namedtuple\nimport torch\nfrom torchvision import models\nfrom IPython import embed\n\nclass squeezen"
},
{
"path": "models/test_3styles_model.py",
"chars": 3411,
"preview": "from .base_model import BaseModel\nfrom . import networks\nimport torch\n\nclass Test3StylesModel(BaseModel):\n \"\"\" This T"
},
{
"path": "models/test_model.py",
"chars": 5123,
"preview": "from .base_model import BaseModel\nfrom . import networks\nimport torch\n\nclass TestModel(BaseModel):\n \"\"\" This TesteMod"
},
{
"path": "options/__init__.py",
"chars": 136,
"preview": "\"\"\"This package options includes option modules: training options, test options, and basic options (used in both trainin"
},
{
"path": "options/base_options.py",
"chars": 8437,
"preview": "import argparse\nimport os\nfrom util import util\nimport torch\nimport models\nimport data\n\n\nclass BaseOptions():\n \"\"\"Thi"
},
{
"path": "options/test_options.py",
"chars": 1363,
"preview": "from .base_options import BaseOptions\n\n\nclass TestOptions(BaseOptions):\n \"\"\"This class includes test options.\n\n It"
},
{
"path": "options/train_options.py",
"chars": 3516,
"preview": "from .base_options import BaseOptions\n\n\nclass TrainOptions(BaseOptions):\n \"\"\"This class includes training options.\n\n "
},
{
"path": "portrait_drawing_resources.md",
"chars": 315,
"preview": "\n- Charles Burns (style1): https://www.pinterest.co.uk/johns59/charles-burns-fan-club/\n- Yann Legendre (style1): http://"
},
{
"path": "preprocess/face_align_512.m",
"chars": 1450,
"preview": "function [trans_img]=face_align_512(impath,facial5point,savedir)\n% align the faces by similarity transformation.\n% using"
},
{
"path": "preprocess/readme.md",
"chars": 1612,
"preview": "## Preprocessing steps\n\nDuring training, face photos and drawings are aligned and have nose,eyes,lips mask detected. \n\nD"
},
{
"path": "readme.md",
"chars": 4506,
"preview": "\n# Unpaired Portrait Drawing Generation via Asymmetric Cycle Mapping\n\nWe provide PyTorch implementations for our CVPR 20"
},
{
"path": "requirements.txt",
"chars": 128,
"preview": "torch==1.2.0\ntorchvision==0.4.0\ndominate==2.4.0\nvisdom==0.1.8.9\nscipy==1.1.0\nnumpy==1.16.4\nPillow==6.2.1\nopencv-python=="
},
{
"path": "scripts/train.sh",
"chars": 313,
"preview": "set -ex\npython train.py --dataroot ./datasets/portrait_drawing --name formal --model asymmetric_cycle_gan_cls --output_n"
},
{
"path": "test.py",
"chars": 4143,
"preview": "\"\"\"General-purpose test script for image-to-image translation.\n\nOnce you have trained your model with train.py, you can "
},
{
"path": "test_seq_style.py",
"chars": 1809,
"preview": "import os\nimport argparse\n\ndef opts():\n parser = argparse.ArgumentParser()\n parser.add_argument('-g','--gpu', defa"
},
{
"path": "train.py",
"chars": 5219,
"preview": "\"\"\"General-purpose training script for image-to-image translation.\n\nThis script works for various models (with option '-"
},
{
"path": "util/__init__.py",
"chars": 83,
"preview": "\"\"\"This package includes a miscellaneous collection of useful helper functions.\"\"\"\n"
},
{
"path": "util/get_data.py",
"chars": 3639,
"preview": "from __future__ import print_function\nimport os\nimport tarfile\nimport requests\nfrom warnings import warn\nfrom zipfile im"
},
{
"path": "util/html.py",
"chars": 3569,
"preview": "import dominate\nfrom dominate.tags import meta, h3, table, tr, td, p, a, img, br\nimport os\n\n\nclass HTML:\n \"\"\"This HTM"
},
{
"path": "util/image_pool.py",
"chars": 2226,
"preview": "import random\nimport torch\n\n\nclass ImagePool():\n \"\"\"This class implements an image buffer that stores previously gene"
},
{
"path": "util/util.py",
"chars": 3326,
"preview": "\"\"\"This module contains simple helper functions \"\"\"\nfrom __future__ import print_function\nimport torch\nimport numpy as n"
},
{
"path": "util/visualizer.py",
"chars": 12238,
"preview": "import numpy as np\nimport os\nimport sys\nimport ntpath\nimport time\nfrom . import util, html\nfrom subprocess import Popen,"
}
]
// ... and 1 more file without a text preview (the binary preprocess/example/ia_selfie_10515_facial5point.mat)
About this extraction
This page contains the full source code of the yiranran/Unpaired-Portrait-Drawing GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction covers 35 files (183.1 KB, approximately 47.7k tokens) and includes a symbol index of 224 extracted functions, classes, methods, constants, and types.
Extracted by GitExtract, built by Nikandr Surkov.