Repository: yiranran/Unpaired-Portrait-Drawing Branch: master Commit: b67591912a3e Files: 35 Total size: 183.1 KB Directory structure: gitextract_uyxvzmrx/ ├── .gitignore ├── data/ │ ├── __init__.py │ ├── base_dataset.py │ ├── image_folder.py │ ├── single_dataset.py │ └── unaligned_mask_stylecls_dataset.py ├── models/ │ ├── __init__.py │ ├── asymmetric_cycle_gan_cls_model.py │ ├── base_model.py │ ├── dist_model.py │ ├── networks.py │ ├── networks_basic.py │ ├── pretrained_networks.py │ ├── test_3styles_model.py │ └── test_model.py ├── options/ │ ├── __init__.py │ ├── base_options.py │ ├── test_options.py │ └── train_options.py ├── portrait_drawing_resources.md ├── preprocess/ │ ├── example/ │ │ └── ia_selfie_10515_facial5point.mat │ ├── face_align_512.m │ └── readme.md ├── readme.md ├── requirements.txt ├── scripts/ │ └── train.sh ├── test.py ├── test_seq_style.py ├── train.py └── util/ ├── __init__.py ├── get_data.py ├── html.py ├── image_pool.py ├── util.py └── visualizer.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ .DS_Store debug* datasets/ checkpoints/ style_features/ results/ build/ dist/ torch.egg-info/ */**/__pycache__ torch/version.py torch/csrc/generic/TensorMethods.cpp torch/lib/*.so* torch/lib/*.dylib* torch/lib/*.h torch/lib/build torch/lib/tmp_install torch/lib/include torch/lib/torch_shm_manager torch/csrc/cudnn/cuDNN.cpp torch/csrc/nn/THNN.cwrap torch/csrc/nn/THNN.cpp torch/csrc/nn/THCUNN.cwrap torch/csrc/nn/THCUNN.cpp torch/csrc/nn/THNN_generic.cwrap torch/csrc/nn/THNN_generic.cpp torch/csrc/nn/THNN_generic.h docs/src/**/* test/data/legacy_modules.t7 test/data/gpu_tensors.pt test/htmlcov test/.coverage */*.pyc */**/*.pyc */**/**/*.pyc */**/**/**/*.pyc */**/**/**/**/*.pyc */*.so* */**/*.so* */**/*.dylib* test/data/legacy_serialized.pt *~ .idea ================================================ FILE: data/__init__.py ================================================ """This package includes all the modules related to data loading and preprocessing To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset. You need to implement four functions: -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). -- <__len__>: return the size of dataset. -- <__getitem__>: get a data point from data loader. -- : (optionally) add dataset-specific options and set default options. Now you can use the dataset class by specifying flag '--dataset_mode dummy'. See our template dataset class 'template_dataset.py' for more details. """ import importlib import torch.utils.data from data.base_dataset import BaseDataset def find_dataset_using_name(dataset_name): """Import the module "data/[dataset_name]_dataset.py". In the file, the class called DatasetNameDataset() will be instantiated. It has to be a subclass of BaseDataset, and it is case-insensitive. """ dataset_filename = "data." 
+ dataset_name + "_dataset" datasetlib = importlib.import_module(dataset_filename) dataset = None target_dataset_name = dataset_name.replace('_', '') + 'dataset' for name, cls in datasetlib.__dict__.items(): if name.lower() == target_dataset_name.lower() \ and issubclass(cls, BaseDataset): dataset = cls if dataset is None: raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name)) return dataset def get_option_setter(dataset_name): """Return the static method of the dataset class.""" dataset_class = find_dataset_using_name(dataset_name) return dataset_class.modify_commandline_options def create_dataset(opt): """Create a dataset given the option. This function wraps the class CustomDatasetDataLoader. This is the main interface between this package and 'train.py'/'test.py' Example: >>> from data import create_dataset >>> dataset = create_dataset(opt) """ data_loader = CustomDatasetDataLoader(opt) dataset = data_loader.load_data() return dataset class CustomDatasetDataLoader(): """Wrapper class of Dataset class that performs multi-threaded data loading""" def __init__(self, opt): """Initialize this class Step 1: create a dataset instance given the name [dataset_mode] Step 2: create a multi-threaded data loader. """ self.opt = opt dataset_class = find_dataset_using_name(opt.dataset_mode) self.dataset = dataset_class(opt) print("dataset [%s] was created" % type(self.dataset).__name__) self.dataloader = torch.utils.data.DataLoader( self.dataset, batch_size=opt.batch_size, shuffle=not opt.serial_batches, num_workers=int(opt.num_threads)) def load_data(self): return self def __len__(self): """Return the number of data in the dataset""" return min(len(self.dataset), self.opt.max_dataset_size) def __iter__(self): """Return a batch of data""" for i, data in enumerate(self.dataloader): if i * self.opt.batch_size >= self.opt.max_dataset_size: break yield data ================================================ FILE: data/base_dataset.py ================================================ """This module implements an abstract base class (ABC) 'BaseDataset' for datasets. It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses. """ import random import numpy as np import torch.utils.data as data from PIL import Image import torchvision.transforms as transforms from abc import ABCMeta, abstractmethod class BaseDataset(data.Dataset): __metaclass__ = ABCMeta """This class is an abstract base class (ABC) for datasets. To create a subclass, you need to implement the following four functions: -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). -- <__len__>: return the size of dataset. -- <__getitem__>: get a data point. -- : (optionally) add dataset-specific options and set default options. """ def __init__(self, opt): """Initialize the class; save the options in the class Parameters: opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions """ self.opt = opt self.root = opt.dataroot @staticmethod def modify_commandline_options(parser, is_train): """Add new dataset-specific options, and rewrite default values for existing options. Parameters: parser -- original option parser is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. Returns: the modified parser. 
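        An illustrative (hypothetical) override in a subclass, mirroring the
        template dataset mentioned above:
            parser.add_argument('--new_dataset_option', type=float, default=1.0, help='a dataset-specific option')
            return parser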
""" return parser @abstractmethod def __len__(self): """Return the total number of images in the dataset.""" return 0 @abstractmethod def __getitem__(self, index): """Return a data point and its metadata information. Parameters: index - - a random integer for data indexing Returns: a dictionary of data with their names. It ususally contains the data itself and its metadata information. """ pass def get_params(opt, size): w, h = size new_h = h new_w = w if opt.preprocess == 'resize_and_crop': new_h = new_w = opt.load_size elif opt.preprocess == 'scale_width_and_crop': new_w = opt.load_size new_h = opt.load_size * h // w x = random.randint(0, np.maximum(0, new_w - opt.crop_size)) y = random.randint(0, np.maximum(0, new_h - opt.crop_size)) flip = random.random() > 0.5 return {'crop_pos': (x, y), 'flip': flip} def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True): transform_list = [] if grayscale: transform_list.append(transforms.Grayscale(1)) if 'resize' in opt.preprocess: osize = [opt.load_size, opt.load_size] transform_list.append(transforms.Resize(osize, method)) elif 'scale_width' in opt.preprocess: transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method))) if 'crop' in opt.preprocess: if params is None: transform_list.append(transforms.RandomCrop(opt.crop_size)) else: transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size))) if opt.preprocess == 'none': transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method))) if not opt.no_flip: if params is None: transform_list.append(transforms.RandomHorizontalFlip()) elif params['flip']: transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) if convert: transform_list += [transforms.ToTensor()] if grayscale: transform_list += [transforms.Normalize((0.5,), (0.5,))] else: transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] return transforms.Compose(transform_list) def get_transform_mask(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True): transform_list = [] if grayscale: transform_list.append(transforms.Grayscale(1)) if 'resize' in opt.preprocess: osize = [opt.load_size, opt.load_size] transform_list.append(transforms.Resize(osize, method)) elif 'scale_width' in opt.preprocess: transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method))) if 'crop' in opt.preprocess: if params is None: transform_list.append(transforms.RandomCrop(opt.crop_size)) else: transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size))) if opt.preprocess == 'none': transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method))) if not opt.no_flip: if params is None: transform_list.append(transforms.RandomHorizontalFlip()) elif params['flip']: transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) if convert: transform_list += [transforms.ToTensor()] return transforms.Compose(transform_list) def __make_power_2(img, base, method=Image.BICUBIC): ow, oh = img.size h = int(round(oh / base) * base) w = int(round(ow / base) * base) if (h == oh) and (w == ow): return img __print_size_warning(ow, oh, w, h) return img.resize((w, h), method) def __scale_width(img, target_width, method=Image.BICUBIC): ow, oh = img.size if (ow == target_width): return img w = target_width h = int(target_width * oh / ow) return img.resize((w, h), method) def 
__crop(img, pos, size):
    ow, oh = img.size
    x1, y1 = pos
    tw = th = size
    if (ow > tw or oh > th):
        return img.crop((x1, y1, x1 + tw, y1 + th))
    return img


def __flip(img, flip):
    if flip:
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    return img


def __print_size_warning(ow, oh, w, h):
    """Print warning information about image size (only printed once)"""
    if not hasattr(__print_size_warning, 'has_printed'):
        print("The image size needs to be a multiple of 4. "
              "The loaded image size was (%d, %d), so it was adjusted to "
              "(%d, %d). This adjustment will be done to all images "
              "whose sizes are not multiples of 4" % (ow, oh, w, h))
        __print_size_warning.has_printed = True


================================================
FILE: data/image_folder.py
================================================
"""A modified image folder class

We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
so that this class can load images from both the current directory and its subdirectories.
"""
import torch.utils.data as data

from PIL import Image
import os
import os.path

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset(dir, max_dataset_size=float("inf")):
    images = []
    assert os.path.isdir(dir), '%s is not a valid directory' % dir

    for root, _, fnames in sorted(os.walk(dir)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                images.append(path)
    return images[:min(max_dataset_size, len(images))]


def default_loader(path):
    return Image.open(path).convert('RGB')


class ImageFolder(data.Dataset):

    def __init__(self, root, transform=None, return_paths=False,
                 loader=default_loader):
        imgs = make_dataset(root)
        if len(imgs) == 0:
            raise(RuntimeError("Found 0 images in: " + root + "\n"
                               "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        path = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.return_paths:
            return img, path
        else:
            return img

    def __len__(self):
        return len(self.imgs)


================================================
FILE: data/single_dataset.py
================================================
from data.base_dataset import BaseDataset, get_transform, get_params, get_transform_mask
from data.image_folder import make_dataset
from PIL import Image
import numpy as np  # needed by the 'simg' style-input branch below
import torch
import os


class SingleDataset(BaseDataset):
    """This dataset class can load a set of images specified by the path --dataroot /path/to/data.

    It can be used for generating CycleGAN results only for one side with the model option '--model test'.
    """

    def __init__(self, opt):
        """Initialize this dataset class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        BaseDataset.__init__(self, opt)
        if os.path.exists(opt.dataroot):
            self.A_paths = sorted(make_dataset(opt.dataroot, opt.max_dataset_size))
        else:
            imglistA = 'datasets/list/%s/%s.txt' % (opt.phase+'A', opt.dataroot)
            self.A_paths = sorted(open(imglistA, 'r').read().splitlines())
        self.input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc

    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index -- a random integer for data indexing

        Returns a dictionary that contains A and A_paths
            A (tensor) -- an image in one domain
            A_paths (str) -- the path of the image
        """
        A_path = self.A_paths[index]
        A_img = Image.open(A_path).convert('RGB')
        self.opt.W, self.opt.H = A_img.size
        transform_params_A = get_params(self.opt, A_img.size)
        A = get_transform(self.opt, transform_params_A, grayscale=(self.input_nc == 1))(A_img)
        item = {'A': A, 'A_paths': A_path}
        if self.opt.style_control:
            if self.opt.sinput == 'sind':  # style index -> one-hot code
                B_style = torch.Tensor([0., 0., 0.])
                B_style[self.opt.sind] = 1.
            elif self.opt.sinput == 'svec':  # explicit 3-d style vector
                ss = self.opt.svec.split(',')
                B_style = torch.Tensor([float(ss[0]), float(ss[1]), float(ss[2])])
            elif self.opt.sinput == 'simg':  # style code precomputed from a reference drawing
                self.featureloc = os.path.join('style_features/styles2/', self.opt.sfeature_mode)
                B_style = np.load(os.path.join(self.featureloc, self.opt.simg[:-4]+'.npy'))
                B_style = torch.Tensor(B_style)
            B_style = B_style.view(3, 1, 1)
            B_style = B_style.repeat(1, 128, 128)  # broadcast the 3-d style code to a 3x128x128 map
            item['B_style'] = B_style
        return item

    def __len__(self):
        """Return the total number of images in the dataset."""
        return len(self.A_paths)


================================================
FILE: data/unaligned_mask_stylecls_dataset.py
================================================
import os.path
from data.base_dataset import BaseDataset, get_params, get_transform, get_transform_mask
from data.image_folder import make_dataset
from PIL import Image
import random
import torch
import torchvision.transforms as transforms
import numpy as np
import pdb


class UnalignedMaskStyleClsDataset(BaseDataset):

    def __init__(self, opt):
        BaseDataset.__init__(self, opt)
        self.dir_A = os.path.join(opt.dataroot, opt.phase + '/A')  # create a path '/path/to/data/trainA'
        self.dir_B = os.path.join(opt.dataroot, opt.phase + '/B')  # create a path '/path/to/data/trainB'
        self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size))  # load images from '/path/to/data/trainA'
        self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size))  # load images from '/path/to/data/trainB'
        self.A_size = len(self.A_paths)  # get the size of dataset A
        self.B_size = len(self.B_paths)  # get the size of dataset B
        print("A size:", self.A_size)
        print("B size:", self.B_size)
        btoA = self.opt.direction == 'BtoA'
        self.input_nc = self.opt.output_nc if btoA else self.opt.input_nc  # get the number of channels of input image
        self.output_nc = self.opt.input_nc if btoA else self.opt.output_nc  # get the number of channels of output image
        self.auxdir_A = os.path.join(opt.dataroot, "%s/A" % opt.phase)
        self.auxdir_B = os.path.join(opt.dataroot, "%s/B" % opt.phase)

    def __getitem__(self, index):
        A_path = self.A_paths[index % self.A_size]  # make sure index is within the range
        if self.opt.serial_batches:  # make sure index is within the range
            index_B = index % self.B_size
        else:  # randomize the index for domain B to avoid fixed pairs.
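            # Because A is indexed sequentially while B is drawn uniformly at random,
            # the two domains never form fixed (A, B) pairs -- each photo is matched
            # with a different drawing on every pass, which is what makes the training
            # data "unaligned".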
index_B = random.randint(0, self.B_size - 1) B_path = self.B_paths[index_B] A_img = Image.open(A_path).convert('RGB') B_img = Image.open(B_path).convert('RGB') basenA = os.path.basename(A_path) A_mask_img = Image.open(os.path.join(self.auxdir_A+'_nose',basenA)) basenB = os.path.basename(B_path) B_mask_img = Image.open(os.path.join(self.auxdir_B+'_nose',basenB)) if self.opt.use_eye_mask: A_maske_img = Image.open(os.path.join(self.auxdir_A+'_eyes',basenA)) B_maske_img = Image.open(os.path.join(self.auxdir_B+'_eyes',basenB)) if self.opt.use_lip_mask: A_maskl_img = Image.open(os.path.join(self.auxdir_A+'_lips',basenA)) B_maskl_img = Image.open(os.path.join(self.auxdir_B+'_lips',basenB)) # apply image transformation transform_params_A = get_params(self.opt, A_img.size) transform_params_B = get_params(self.opt, B_img.size) A = get_transform(self.opt, transform_params_A, grayscale=(self.input_nc == 1))(A_img) B = get_transform(self.opt, transform_params_B, grayscale=(self.output_nc == 1))(B_img) A_mask = get_transform_mask(self.opt, transform_params_A, grayscale=1)(A_mask_img) B_mask = get_transform_mask(self.opt, transform_params_B, grayscale=1)(B_mask_img) if self.opt.use_eye_mask: A_maske = get_transform_mask(self.opt, transform_params_A, grayscale=1)(A_maske_img) B_maske = get_transform_mask(self.opt, transform_params_B, grayscale=1)(B_maske_img) if self.opt.use_lip_mask: A_maskl = get_transform_mask(self.opt, transform_params_A, grayscale=1)(A_maskl_img) B_maskl = get_transform_mask(self.opt, transform_params_B, grayscale=1)(B_maskl_img) item = {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path, 'A_mask': A_mask, 'B_mask': B_mask} if self.opt.use_eye_mask: item['A_maske'] = A_maske item['B_maske'] = B_maske if self.opt.use_lip_mask: item['A_maskl'] = A_maskl item['B_maskl'] = B_maskl softmax = np.load(os.path.join(self.auxdir_B+'_feat',basenB[:-4]+'.npy')) softmax = torch.Tensor(softmax) [maxv,index] = torch.max(softmax,0) B_label = index if len(self.opt.sfeature_mode) >= 8 and self.opt.sfeature_mode[-8:] == '_softmax': if self.opt.one_hot: B_style = torch.Tensor([0.,0.,0.]) B_style[index] = 1. else: B_style = softmax B_style = B_style.view(3, 1, 1) B_style = B_style.repeat(1, 128, 128) elif self.opt.sfeature_mode == 'domain': B_style = B_label item['B_style'] = B_style item['B_label'] = B_label if self.opt.isTrain and self.opt.style_loss_with_weight: item['B_style0'] = softmax return item def __len__(self): return max(self.A_size, self.B_size) ================================================ FILE: models/__init__.py ================================================ """This package contains modules related to objective functions, optimizations, and network architectures. To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel. You need to implement the following five functions: -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). -- : unpack data from dataset and apply preprocessing. -- : produce intermediate results. -- : calculate loss, gradients, and update network weights. -- : (optionally) add model-specific options and set default options. In the function <__init__>, you need to define four lists: -- self.loss_names (str list): specify the training losses that you want to plot and save. -- self.model_names (str list): define networks used in our training. -- self.visual_names (str list): specify the images that you want to display and save. 
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network.
   If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.

Now you can use the model class by specifying flag '--model dummy'.
See our template model class 'template_model.py' for more details.
"""

import importlib
from models.base_model import BaseModel


def find_model_using_name(model_name):
    """Import the module "models/[model_name]_model.py".

    In the file, a class called [ModelName]Model() will be instantiated. It has to be
    a subclass of BaseModel, and it is case-insensitive.
    """
    model_filename = "models." + model_name + "_model"
    modellib = importlib.import_module(model_filename)
    model = None
    target_model_name = model_name.replace('_', '') + 'model'
    for name, cls in modellib.__dict__.items():
        if name.lower() == target_model_name.lower() \
           and issubclass(cls, BaseModel):
            model = cls

    if model is None:
        print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
        exit(0)

    return model


def get_option_setter(model_name):
    """Return the static method <modify_commandline_options> of the model class."""
    model_class = find_model_using_name(model_name)
    return model_class.modify_commandline_options


def create_model(opt):
    """Create a model given the option.

    This is the main interface between this package and 'train.py'/'test.py'.

    Example:
        >>> from models import create_model
        >>> model = create_model(opt)
    """
    model = find_model_using_name(opt.model)
    instance = model(opt)
    print("model [%s] was created" % type(instance).__name__)
    return instance


================================================
FILE: models/asymmetric_cycle_gan_cls_model.py
================================================
import torch
import itertools
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks
import models.dist_model as dm  # numpy==1.14.3
import torchvision.transforms as transforms
import os


def truncate(fake_B, a=127.5):  # [-1,1]
    return ((fake_B+1)*a).int().float()/a-1


class AsymmetricCycleGANClsModel(BaseModel):
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        parser.set_defaults(no_dropout=True)  # default CycleGAN did not use dropout
        parser.set_defaults(dataset_mode='unaligned_mask_stylecls')
        parser.add_argument('--netda', type=str, default='basic_cls')
        parser.add_argument('--netga', type=str, default='resnet_style2_9blocks', help='net arch for netG_A')
        parser.add_argument('--model0_res', type=int, default=0, help='number of resblocks in model0 (before insert style)')
        parser.add_argument('--model1_res', type=int, default=0, help='number of resblocks in model1 (after insert style, before 2 column merge)')
        if is_train:
            parser.add_argument('--lambda_A', type=float, default=5.0, help='weight for cycle loss (A -> B -> A)')
            parser.add_argument('--lambda_B', type=float, default=5.0, help='weight for cycle loss (B -> A -> B)')
            parser.add_argument('--lambda_identity', type=float, default=0, help='use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss.
For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1') parser.add_argument('--ntrunc_trunc', type=int, default=1, help='whether use both non-trunc version and trunc version') parser.add_argument('--trunc_a', type=float, default=31.875, help='multiply which value to round when trunc') parser.add_argument('--lambda_A_trunc', type=float, default=5.0, help='weight for cycle loss for trunc') parser.add_argument('--hed_pretrained_mode', type=str, default='./checkpoints/network-bsds500.pytorch', help='path to the pretrained hed model') parser.add_argument('--lambda_G_A_l', type=float, default=0.5, help='weight for local GAN loss in G') parser.add_argument('--style_loss_with_weight', type=int, default=1, help='whether multiply prob in style loss') # for masks parser.add_argument('--use_mask', type=int, default=1, help='whether use mask for special face region') parser.add_argument('--use_eye_mask', type=int, default=1, help='whether use mask for special face region') parser.add_argument('--use_lip_mask', type=int, default=1, help='whether use mask for special face region') parser.add_argument('--mask_type', type=int, default=3, help='use mask type, 0 outside black, 1 outside white') # for style control parser.add_argument('--style_control', type=int, default=1, help='use style_control') parser.add_argument('--sfeature_mode', type=str, default='1vgg19_softmax', help='vgg19 softmax as feature') parser.add_argument('--one_hot', type=int, default=0, help='use one-hot for style code') return parser def __init__(self, opt): BaseModel.__init__(self, opt) # specify the training losses you want to print out. The training/test scripts will call self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B'] # specify the images you want to save/display. The training/test scripts will call visual_names_A = ['real_A', 'fake_B', 'rec_A'] visual_names_B = ['real_B', 'fake_A', 'rec_B'] if self.isTrain and self.opt.lambda_identity > 0.0: # if identity loss is used, we also visualize idt_B=G_A(B) ad idt_A=G_A(B) visual_names_A.append('idt_B') visual_names_B.append('idt_A') if self.isTrain: visual_names_A.append('real_A_hed') visual_names_A.append('rec_A_hed') if self.isTrain and self.opt.ntrunc_trunc: visual_names_A.append('rec_At') visual_names_A.append('rec_At_hed') self.loss_names = ['D_A', 'G_A', 'cycle_A', 'cycle_A2', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B', 'G'] if self.isTrain and self.opt.use_mask: visual_names_A.append('fake_B_l') visual_names_A.append('real_B_l') self.loss_names += ['D_A_l', 'G_A_l'] if self.isTrain and self.opt.use_eye_mask: visual_names_A.append('fake_B_le') visual_names_A.append('real_B_le') self.loss_names += ['D_A_le', 'G_A_le'] if self.isTrain and self.opt.use_lip_mask: visual_names_A.append('fake_B_ll') visual_names_A.append('real_B_ll') self.loss_names += ['D_A_ll', 'G_A_ll'] if not self.isTrain and self.opt.use_mask: visual_names_A.append('fake_B_l') visual_names_A.append('real_B_l') if not self.isTrain and self.opt.use_eye_mask: visual_names_A.append('fake_B_le') visual_names_A.append('real_B_le') if not self.isTrain and self.opt.use_lip_mask: visual_names_A.append('fake_B_ll') visual_names_A.append('real_B_ll') self.loss_names += ['D_A_cls','G_A_cls'] self.visual_names = visual_names_A + visual_names_B # combine visualizations for A and B print(self.visual_names) # specify the models you want to save to the disk. The training/test scripts will call and . 
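        # Only the two generators are needed at test time; the global discriminators and
        # the optional local (nose/eye/lip) discriminators exist purely to supply training
        # signal, so they are registered -- and therefore saved -- only when isTrain is set.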
if self.isTrain: self.model_names = ['G_A', 'G_B', 'D_A', 'D_B'] if self.opt.use_mask: self.model_names += ['D_A_l'] if self.opt.use_eye_mask: self.model_names += ['D_A_le'] if self.opt.use_lip_mask: self.model_names += ['D_A_ll'] else: # during test time, only load Gs self.model_names = ['G_A', 'G_B'] # define networks (both Generators and discriminators) # The naming is different from those used in the paper. # Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X) if not self.opt.style_control: self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids) else: print(opt.netga) print('model0_res', opt.model0_res) print('model1_res', opt.model1_res) self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netga, opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.model0_res, opt.model1_res) self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids) if self.isTrain: # define discriminators self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netda, opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids, n_class=3) self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids) if self.opt.use_mask: if self.opt.mask_type in [2, 3]: output_nc = opt.output_nc + 1 else: output_nc = opt.output_nc self.netD_A_l = networks.define_D(output_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids) if self.opt.use_eye_mask: if self.opt.mask_type in [2, 3]: output_nc = opt.output_nc + 1 else: output_nc = opt.output_nc self.netD_A_le = networks.define_D(output_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids) if self.opt.use_lip_mask: if self.opt.mask_type in [2, 3]: output_nc = opt.output_nc + 1 else: output_nc = opt.output_nc self.netD_A_ll = networks.define_D(output_nc, opt.ndf, opt.netD, opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids) if not self.isTrain: self.criterionGAN = networks.GANLoss('lsgan').to(self.device) if self.isTrain: if opt.lambda_identity > 0.0: # only works when input and output images have the same number of channels assert(opt.input_nc == opt.output_nc) self.fake_A_pool = ImagePool(opt.pool_size) # create image buffer to store previously generated images self.fake_B_pool = ImagePool(opt.pool_size) # create image buffer to store previously generated images # define loss functions self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device) # define GAN loss. self.criterionCycle = torch.nn.L1Loss() self.criterionIdt = torch.nn.L1Loss() self.criterionCls = torch.nn.CrossEntropyLoss() self.criterionCls2 = torch.nn.CrossEntropyLoss(reduction='none') # initialize optimizers; schedulers will be automatically created by function . 
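        # Both generators share one Adam optimizer (their parameters are chained together),
        # and all active discriminators share a second one; the elif ladder below simply
        # widens the discriminator parameter list as the mask flags are enabled.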
self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999)) if not self.opt.use_mask: self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999)) elif not self.opt.use_eye_mask: D_params = list(self.netD_A.parameters()) + list(self.netD_B.parameters()) + list(self.netD_A_l.parameters()) self.optimizer_D = torch.optim.Adam(D_params, lr=opt.lr, betas=(opt.beta1, 0.999)) elif not self.opt.use_lip_mask: D_params = list(self.netD_A.parameters()) + list(self.netD_B.parameters()) + list(self.netD_A_l.parameters()) + list(self.netD_A_le.parameters()) self.optimizer_D = torch.optim.Adam(D_params, lr=opt.lr, betas=(opt.beta1, 0.999)) else: D_params = list(self.netD_A.parameters()) + list(self.netD_B.parameters()) + list(self.netD_A_l.parameters()) + list(self.netD_A_le.parameters()) + list(self.netD_A_ll.parameters()) self.optimizer_D = torch.optim.Adam(D_params, lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizers.append(self.optimizer_G) self.optimizers.append(self.optimizer_D) self.lpips = dm.DistModel(opt,model='net-lin',net='alex',use_gpu=True) self.hed = networks.define_HED(init_weights_=opt.hed_pretrained_mode, gpu_ids_=self.opt.gpu_ids_p) self.set_requires_grad(self.hed, False) def set_input(self, input): """Unpack input data from the dataloader and perform necessary pre-processing steps. Parameters: input (dict): include the data itself and its metadata information. The option 'direction' can be used to swap domain A and domain B. """ AtoB = self.opt.direction == 'AtoB' self.real_A = input['A' if AtoB else 'B'].to(self.device) self.real_B = input['B' if AtoB else 'A'].to(self.device) self.image_paths = input['A_paths' if AtoB else 'B_paths'] if self.opt.use_mask: self.A_mask = input['A_mask'].to(self.device) self.B_mask = input['B_mask'].to(self.device) if self.opt.use_eye_mask: self.A_maske = input['A_maske'].to(self.device) self.B_maske = input['B_maske'].to(self.device) if self.opt.use_lip_mask: self.A_maskl = input['A_maskl'].to(self.device) self.B_maskl = input['B_maskl'].to(self.device) if self.opt.style_control: self.real_B_style = input['B_style'].to(self.device) self.real_B_label = input['B_label'].to(self.device) if self.opt.isTrain and self.opt.style_loss_with_weight: self.real_B_style0 = input['B_style0'].to(self.device) self.zero = torch.zeros(self.real_B_label.size(),dtype=torch.int64).to(self.device) self.one = torch.ones(self.real_B_label.size(),dtype=torch.int64).to(self.device) self.two = 2*torch.ones(self.real_B_label.size(),dtype=torch.int64).to(self.device) def forward(self): """Run forward pass; called by both functions and .""" if not self.opt.style_control: self.fake_B = self.netG_A(self.real_A) # G_A(A) else: #print(torch.mean(self.real_B_style,(2,3)),'style_control') self.fake_B = self.netG_A(self.real_A, self.real_B_style) self.rec_A = self.netG_B(self.fake_B) # G_B(G_A(A)) self.fake_A = self.netG_B(self.real_B) # G_B(B) if not self.opt.style_control: self.rec_B = self.netG_A(self.fake_A) # G_A(G_B(B)) else: #print(torch.mean(self.real_B_style,(2,3)),'style_control') self.rec_B = self.netG_A(self.fake_A, self.real_B_style) # -- cycle_B loss if self.opt.use_mask: self.fake_B_l = self.masked(self.fake_B,self.A_mask) self.real_B_l = self.masked(self.real_B,self.B_mask) if self.opt.use_eye_mask: self.fake_B_le = self.masked(self.fake_B,self.A_maske) self.real_B_le = self.masked(self.real_B,self.B_maske) if 
self.opt.use_lip_mask: self.fake_B_ll = self.masked(self.fake_B,self.A_maskl) self.real_B_ll = self.masked(self.real_B,self.B_maskl) def backward_D_basic(self, netD, real, fake): """Calculate GAN loss for the discriminator Parameters: netD (network) -- the discriminator D real (tensor array) -- real images fake (tensor array) -- images generated by a generator Return the discriminator loss. We also call loss_D.backward() to calculate the gradients. """ # Real pred_real = netD(real) loss_D_real = self.criterionGAN(pred_real, True) # Fake pred_fake = netD(fake.detach()) loss_D_fake = self.criterionGAN(pred_fake, False) # Combined loss and calculate gradients loss_D = (loss_D_real + loss_D_fake) * 0.5 loss_D.backward() return loss_D def backward_D_basic_cls(self, netD, real, fake): # Real pred_real, pred_real_cls = netD(real) loss_D_real = self.criterionGAN(pred_real, True) if not self.opt.style_loss_with_weight: loss_D_real_cls = self.criterionCls(pred_real_cls, self.real_B_label) else: loss_D_real_cls = torch.mean(self.real_B_style0[:,0] * self.criterionCls2(pred_real_cls, self.zero) + self.real_B_style0[:,1] * self.criterionCls2(pred_real_cls, self.one) + self.real_B_style0[:,2] * self.criterionCls2(pred_real_cls, self.two)) # Fake pred_fake, pred_fake_cls = netD(fake.detach()) loss_D_fake = self.criterionGAN(pred_fake, False) if not self.opt.style_loss_with_weight: loss_D_fake_cls = self.criterionCls(pred_fake_cls, self.real_B_label) else: loss_D_fake_cls = torch.mean(self.real_B_style0[:,0] * self.criterionCls2(pred_fake_cls, self.zero) + self.real_B_style0[:,1] * self.criterionCls2(pred_fake_cls, self.one) + self.real_B_style0[:,2] * self.criterionCls2(pred_fake_cls, self.two)) # Combined loss and calculate gradients loss_D = (loss_D_real + loss_D_fake) * 0.5 loss_D_cls = (loss_D_real_cls + loss_D_fake_cls) * 0.5 loss_D_total = loss_D + loss_D_cls loss_D_total.backward() return loss_D, loss_D_cls def backward_D_A(self): """Calculate GAN loss for discriminator D_A""" fake_B = self.fake_B_pool.query(self.fake_B) self.loss_D_A, self.loss_D_A_cls = self.backward_D_basic_cls(self.netD_A, self.real_B, fake_B) def backward_D_A_l(self): """Calculate GAN loss for discriminator D_A_l""" fake_B = self.fake_B_pool.query(self.fake_B) self.loss_D_A_l = self.backward_D_basic(self.netD_A_l, self.masked(self.real_B,self.B_mask), self.masked(fake_B,self.A_mask)) def backward_D_A_le(self): """Calculate GAN loss for discriminator D_A_le""" fake_B = self.fake_B_pool.query(self.fake_B) self.loss_D_A_le = self.backward_D_basic(self.netD_A_le, self.masked(self.real_B,self.B_maske), self.masked(fake_B,self.A_maske)) def backward_D_A_ll(self): """Calculate GAN loss for discriminator D_A_ll""" fake_B = self.fake_B_pool.query(self.fake_B) self.loss_D_A_ll = self.backward_D_basic(self.netD_A_ll, self.masked(self.real_B,self.B_maskl), self.masked(fake_B,self.A_maskl)) def backward_D_B(self): """Calculate GAN loss for discriminator D_B""" fake_A = self.fake_A_pool.query(self.fake_A) self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A) def update_process(self, epoch): self.process = (epoch - 1) / float(self.opt.niter_decay + self.opt.niter) def backward_G(self): """Calculate the loss for generators G_A and G_B""" lambda_idt = self.opt.lambda_identity lambda_G_A_l = self.opt.lambda_G_A_l lambda_A = self.opt.lambda_A lambda_B = self.opt.lambda_B lambda_A_trunc = self.opt.lambda_A_trunc if self.opt.ntrunc_trunc: lambda_A = lambda_A * (1 - self.process * 0.9) lambda_A_trunc = lambda_A_trunc * 
self.process * 0.9 self.lambda_As = [lambda_A, lambda_A_trunc] # Identity loss if lambda_idt > 0: # G_A should be identity if real_B is fed: ||G_A(B) - B|| self.idt_A = self.netG_A(self.real_B) self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt # G_B should be identity if real_A is fed: ||G_B(A) - A|| self.idt_B = self.netG_B(self.real_A) self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt else: self.loss_idt_A = 0 self.loss_idt_B = 0 # GAN loss D_A(G_A(A)) pred_fake, pred_fake_cls = self.netD_A(self.fake_B) self.loss_G_A = self.criterionGAN(pred_fake, True) if not self.opt.style_loss_with_weight: self.loss_G_A_cls = self.criterionCls(pred_fake_cls, self.real_B_label) else: self.loss_G_A_cls = torch.mean(self.real_B_style0[:,0] * self.criterionCls2(pred_fake_cls, self.zero) + self.real_B_style0[:,1] * self.criterionCls2(pred_fake_cls, self.one) + self.real_B_style0[:,2] * self.criterionCls2(pred_fake_cls, self.two)) if self.opt.use_mask: self.loss_G_A_l = self.criterionGAN(self.netD_A_l(self.fake_B_l), True) * lambda_G_A_l if self.opt.use_eye_mask: self.loss_G_A_le = self.criterionGAN(self.netD_A_le(self.fake_B_le), True) * lambda_G_A_l if self.opt.use_lip_mask: self.loss_G_A_ll = self.criterionGAN(self.netD_A_ll(self.fake_B_ll), True) * lambda_G_A_l # GAN loss D_B(G_B(B)) self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True) # Forward cycle loss LPIPS( HED(G_B(G_A(A))), HED(A)) ts = self.real_A.shape gpu_p = self.opt.gpu_ids_p[0] gpu = self.opt.gpu_ids[0] rec_A_hed = (self.hed(self.rec_A.cuda(gpu_p)/2+0.5)-0.5)*2 real_A_hed = (self.hed(self.real_A.cuda(gpu_p)/2+0.5)-0.5)*2 self.loss_cycle_A = (self.lpips.forward_pair(rec_A_hed.expand(ts), real_A_hed.expand(ts)).mean()).cuda(gpu) * lambda_A self.rec_A_hed = rec_A_hed self.real_A_hed = real_A_hed if self.opt.ntrunc_trunc: self.rec_At = self.netG_B(truncate(self.fake_B,self.opt.trunc_a)) rec_At_hed = (self.hed(self.rec_At.cuda(gpu_p)/2+0.5)-0.5)*2 self.loss_cycle_A2 = (self.lpips.forward_pair(rec_At_hed.expand(ts), real_A_hed.expand(ts)).mean()).cuda(gpu) * lambda_A_trunc self.rec_At_hed = rec_At_hed # Backward cycle loss || G_A(G_B(B)) - B|| self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B # combined loss and calculate gradients self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B if getattr(self,'loss_cycle_A2',-1) != -1: self.loss_G = self.loss_G + self.loss_cycle_A2 if getattr(self,'loss_G_A_l',-1) != -1: self.loss_G = self.loss_G + self.loss_G_A_l if getattr(self,'loss_G_A_le',-1) != -1: self.loss_G = self.loss_G + self.loss_G_A_le if getattr(self,'loss_G_A_ll',-1) != -1: self.loss_G = self.loss_G + self.loss_G_A_ll if getattr(self,'loss_G_A_cls',-1) != -1: self.loss_G = self.loss_G + self.loss_G_A_cls self.loss_G.backward() def optimize_parameters(self): """Calculate losses, gradients, and update network weights; called in every training iteration""" # forward self.forward() # compute fake images and reconstruction images. 
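        # Standard CycleGAN-style alternation: first update the generators while every
        # discriminator is frozen, then unfreeze the discriminators and update them
        # against fake images drawn from the image pools.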
# G_A and G_B
        self.set_requires_grad([self.netD_A, self.netD_B], False)  # Ds require no gradients when optimizing Gs
        if self.opt.use_mask:
            self.set_requires_grad([self.netD_A_l], False)
        if self.opt.use_eye_mask:
            self.set_requires_grad([self.netD_A_le], False)
        if self.opt.use_lip_mask:
            self.set_requires_grad([self.netD_A_ll], False)
        self.optimizer_G.zero_grad()  # set G_A and G_B's gradients to zero
        self.backward_G()             # calculate gradients for G_A and G_B
        self.optimizer_G.step()       # update G_A and G_B's weights
        # D_A and D_B
        self.set_requires_grad([self.netD_A, self.netD_B], True)
        if self.opt.use_mask:
            self.set_requires_grad([self.netD_A_l], True)
        if self.opt.use_eye_mask:
            self.set_requires_grad([self.netD_A_le], True)
        if self.opt.use_lip_mask:
            self.set_requires_grad([self.netD_A_ll], True)
        self.optimizer_D.zero_grad()  # set D_A and D_B's gradients to zero
        self.backward_D_A()           # calculate gradients for D_A
        if self.opt.use_mask:
            self.backward_D_A_l()     # calculate gradients for D_A_l
        if self.opt.use_eye_mask:
            self.backward_D_A_le()    # calculate gradients for D_A_le
        if self.opt.use_lip_mask:
            self.backward_D_A_ll()    # calculate gradients for D_A_ll
        self.backward_D_B()           # calculate gradients for D_B
        self.optimizer_D.step()       # update D_A and D_B's weights


================================================
FILE: models/base_model.py
================================================
import os
import torch
from collections import OrderedDict
from abc import ABCMeta, abstractmethod
from . import networks


class BaseModel():
    __metaclass__ = ABCMeta
    """This class is an abstract base class (ABC) for models.
    To create a subclass, you need to implement the following five functions:
        -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
        -- <set_input>: unpack data from dataset and apply preprocessing.
        -- <forward>: produce intermediate results.
        -- <optimize_parameters>: calculate losses, gradients, and update network weights.
        -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
    """

    def __init__(self, opt):
        """Initialize the BaseModel class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions

        When creating your custom class, you need to implement your own initialization.
        In this function, you should first call <BaseModel.__init__(self, opt)>.
        Then, you need to define four lists:
            -- self.loss_names (str list): specify the training losses that you want to plot and save.
            -- self.model_names (str list): define networks used in our training.
            -- self.visual_names (str list): specify the images that you want to display and save.
            -- self.optimizers (optimizer list): define and initialize optimizers. You can define one
               optimizer for each network. If two networks are updated at the same time, you can use
               itertools.chain to group them. See cycle_gan_model.py for an example.
        """
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')  # get device name: CPU or GPU
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)  # save all the checkpoints to save_dir
        if opt.preprocess != 'scale_width':  # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
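            # (benchmark mode auto-tunes cuDNN convolution algorithms, which only pays
            # off when input sizes stay constant across iterations)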
torch.backends.cudnn.benchmark = True self.loss_names = [] self.model_names = [] self.visual_names = [] self.optimizers = [] self.image_paths = [] self.metric = 0 # used for learning rate policy 'plateau' @staticmethod def modify_commandline_options(parser, is_train): """Add new model-specific options, and rewrite default values for existing options. Parameters: parser -- original option parser is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. Returns: the modified parser. """ return parser @abstractmethod def set_input(self, input): """Unpack input data from the dataloader and perform necessary pre-processing steps. Parameters: input (dict): includes the data itself and its metadata information. """ pass @abstractmethod def forward(self): """Run forward pass; called by both functions and .""" pass @abstractmethod def optimize_parameters(self): """Calculate losses, gradients, and update network weights; called in every training iteration""" pass def setup(self, opt): """Load and print networks; create schedulers Parameters: opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions """ if self.isTrain: self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] if not self.isTrain or opt.continue_train: load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch self.load_networks(load_suffix) self.print_networks(opt.verbose) def eval(self): """Make models eval mode during test time""" for name in self.model_names: if isinstance(name, str): net = getattr(self, 'net' + name) net.eval() def test(self): """Forward function used in test time. This function wraps function in no_grad() so we don't save intermediate steps for backprop It also calls to produce additional visualization results """ with torch.no_grad(): self.forward() self.compute_visuals() def compute_visuals(self): """Calculate additional output images for visdom and HTML visualization""" pass def get_image_paths(self): """ Return image paths that are used to load current data""" return self.image_paths def update_learning_rate(self): """Update learning rates for all the networks; called at the end of every epoch""" for scheduler in self.schedulers: if self.opt.lr_policy == 'plateau': scheduler.step(self.metric) else: scheduler.step() lr = self.optimizers[0].param_groups[0]['lr'] print('learning rate = %.7f' % lr) def get_current_visuals(self): """Return visualization images. train.py will display these images with visdom, and save the images to a HTML""" visual_ret = OrderedDict() for name in self.visual_names: if isinstance(name, str): visual_ret[name] = getattr(self, name) return visual_ret def get_current_losses(self): """Return traning losses / errors. train.py will print out these errors on console, and save them to a file""" errors_ret = OrderedDict() for name in self.loss_names: if isinstance(name, str): errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number return errors_ret def save_networks(self, epoch): """Save all the networks to the disk. 
Parameters: epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) """ for name in self.model_names: if isinstance(name, str): save_filename = '%s_net_%s.pth' % (epoch, name) save_path = os.path.join(self.save_dir, save_filename) net = getattr(self, 'net' + name) if len(self.gpu_ids) > 0 and torch.cuda.is_available(): torch.save(net.module.cpu().state_dict(), save_path) net.cuda(self.gpu_ids[0]) else: torch.save(net.cpu().state_dict(), save_path) def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)""" key = keys[i] if i + 1 == len(keys): # at the end, pointing to a parameter/buffer if module.__class__.__name__.startswith('InstanceNorm') and \ (key == 'running_mean' or key == 'running_var'): if getattr(module, key) is None: state_dict.pop('.'.join(keys)) if module.__class__.__name__.startswith('InstanceNorm') and \ (key == 'num_batches_tracked'): state_dict.pop('.'.join(keys)) else: self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1) def load_networks(self, epoch): """Load all the networks from the disk. Parameters: epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) """ for name in self.model_names: if isinstance(name, str): load_filename = '%s_net_%s.pth' % (epoch, name) load_path = os.path.join(self.save_dir, load_filename) net = getattr(self, 'net' + name) if isinstance(net, torch.nn.DataParallel): net = net.module print('loading the model from %s' % load_path) # if you are using PyTorch newer than 0.4 (e.g., built from # GitHub source), you can remove str() on self.device state_dict = torch.load(load_path, map_location=str(self.device)) if hasattr(state_dict, '_metadata'): del state_dict._metadata # patch InstanceNorm checkpoints prior to 0.4 for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop self.__patch_instance_norm_state_dict(state_dict, net, key.split('.')) net.load_state_dict(state_dict) def print_networks(self, verbose): """Print the total number of parameters in the network and (if verbose) network architecture Parameters: verbose (bool) -- if verbose: print the network architecture """ print('---------- Networks initialized -------------') for name in self.model_names: if isinstance(name, str): net = getattr(self, 'net' + name) num_params = 0 for param in net.parameters(): num_params += param.numel() if verbose: print(net) print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)) print('-----------------------------------------------') def set_requires_grad(self, nets, requires_grad=False): """Set requies_grad=Fasle for all the networks to avoid unnecessary computations Parameters: nets (network list) -- a list of networks requires_grad (bool) -- whether the networks require gradients or not """ if not isinstance(nets, list): nets = [nets] for net in nets: if net is not None: for param in net.parameters(): param.requires_grad = requires_grad # =========================================================================================================== def masked(self, A,mask): if self.opt.mask_type == 0: return (A/2+0.5)*mask*2-1 elif self.opt.mask_type == 1: return ((A/2+0.5)*mask+1-mask)*2-1 elif self.opt.mask_type == 2: return torch.cat((A, mask), 1) elif self.opt.mask_type == 3: masked = ((A/2+0.5)*mask+1-mask)*2-1 return torch.cat((masked, mask), 1) ================================================ FILE: models/dist_model.py 
================================================ from __future__ import absolute_import import sys sys.path.append('..') sys.path.append('.') import numpy as np import torch from torch import nn from collections import OrderedDict from torch.autograd import Variable from .base_model import BaseModel from scipy.ndimage import zoom import skimage.transform from . import networks_basic as networks # from PerceptualSimilarity.util import util from util import util class DistModel(BaseModel): def name(self): return self.model_name def __init__(self, opt, model='net-lin', net='alex', pnet_rand=False, pnet_tune=False, model_path=None, colorspace='Lab', use_gpu=True, printNet=False, spatial=False, spatial_shape=None, spatial_order=1, spatial_factor=None, is_train=False, lr=.0001, beta1=0.5, version='0.1'): ''' INPUTS model - ['net-lin'] for linearly calibrated network ['net'] for off-the-shelf network ['L2'] for L2 distance in Lab colorspace ['SSIM'] for ssim in RGB colorspace net - ['squeeze','alex','vgg'] model_path - if None, will look in weights/[NET_NAME].pth colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM use_gpu - bool - whether or not to use a GPU printNet - bool - whether or not to print network architecture out spatial - bool - whether to output an array containing varying distances across spatial dimensions spatial_shape - if given, output spatial shape. if None then spatial shape is determined automatically via spatial_factor (see below). spatial_factor - if given, specifies upsampling factor relative to the largest spatial extent of a convolutional layer. if None then resized to size of input images. spatial_order - spline order of filter for upsampling in spatial mode, by default 1 (bilinear). is_train - bool - [True] for training mode lr - float - initial learning rate beta1 - float - initial momentum term for adam version - 0.1 for latest, 0.0 was original ''' BaseModel.__init__(self, opt) self.model = model self.net = net self.use_gpu = use_gpu self.is_train = is_train self.spatial = spatial self.spatial_shape = spatial_shape self.spatial_order = spatial_order self.spatial_factor = spatial_factor self.model_name = '%s [%s]'%(model,net) if(self.model == 'net-lin'): # pretrained net + linear layer #self.device = torch.device('cuda:{}'.format(opt.gpu_ids[0])) if opt.gpu_ids else torch.device('cpu') self.device = torch.device('cuda:{}'.format(opt.gpu_ids_p[0])) if opt.gpu_ids_p else torch.device('cpu') self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net,use_dropout=True,spatial=spatial,version=version,lpips=True).to(self.device) kw = {} if not use_gpu: kw['map_location'] = 'cpu' if(model_path is None): import inspect #model_path = os.path.abspath(os.path.join(inspect.getfile(self.initialize), '..', '..', 'weights/v%s/%s.pth'%(version,net))) model_path = './checkpoints/weights/v%s/%s.pth'%(version,net) if(not is_train): print('Loading model from: %s'%model_path) #self.net.load_state_dict(torch.load(model_path, **kw)) state_dict = torch.load(model_path, map_location=str(self.device)) self.net.load_state_dict(state_dict, strict=False) elif(self.model=='net'): # pretrained network assert not self.spatial, 'spatial argument not supported yet for uncalibrated networks' self.net = networks.PNet(use_gpu=use_gpu,pnet_type=net,device=self.device) self.is_fake_net = True elif(self.model in ['L2','l2']): self.net = networks.L2(use_gpu=use_gpu,colorspace=colorspace,device=self.device) # not really a network, only for testing self.model_name = 'L2' 
elif(self.model in ['DSSIM','dssim','SSIM','ssim']): self.net = networks.DSSIM(use_gpu=use_gpu,colorspace=colorspace,device=self.device) self.model_name = 'SSIM' else: raise ValueError("Model [%s] not recognized." % self.model) self.parameters = list(self.net.parameters()) if self.is_train: # training mode # extra network on top to go from distances (d0,d1) => predicted human judgment (h*) self.rankLoss = networks.BCERankingLoss(use_gpu=use_gpu,device=self.device) self.parameters+=self.rankLoss.parameters self.lr = lr self.old_lr = lr self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999)) else: # test mode self.net.eval() if(printNet): print('---------- Networks initialized -------------') networks.print_network(self.net) print('-----------------------------------------------') def forward_pair(self,in1,in2,retPerLayer=False): if(retPerLayer): return self.net.forward(in1,in2, retPerLayer=True) else: return self.net.forward(in1,in2) def forward(self, in0, in1, retNumpy=False): ''' Function computes the distance between image patches in0 and in1 INPUTS in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1] retNumpy - [False] to return as torch.Tensor, [True] to return as numpy array OUTPUT computed distances between in0 and in1 ''' self.input_ref = in0 self.input_p0 = in1 self.var_ref = Variable(self.input_ref,requires_grad=True) self.var_p0 = Variable(self.input_p0,requires_grad=True) self.d0 = self.forward_pair(self.var_ref, self.var_p0) self.loss_total = self.d0 def convert_output(d0): if(retNumpy): ans = d0.cpu().data.numpy() if not self.spatial: ans = ans.flatten() else: assert(ans.shape[0] == 1 and len(ans.shape) == 4) return ans[0,...].transpose([1, 2, 0]) # Reshape to usual numpy image format: (height, width, channels) return ans else: return d0 if self.spatial: L = [convert_output(x) for x in self.d0] spatial_shape = self.spatial_shape if spatial_shape is None: if(self.spatial_factor is None): spatial_shape = (in0.size()[2],in0.size()[3]) else: spatial_shape = (max([x.shape[0] for x in L])*self.spatial_factor, max([x.shape[1] for x in L])*self.spatial_factor) L = [skimage.transform.resize(x, spatial_shape, order=self.spatial_order, mode='edge') for x in L] L = np.mean(np.concatenate(L, 2) * len(L), 2) return L else: return convert_output(self.d0) # ***** TRAINING FUNCTIONS ***** def optimize_parameters(self): self.forward_train() self.optimizer_net.zero_grad() self.backward_train() self.optimizer_net.step() self.clamp_weights() def clamp_weights(self): for module in self.net.modules(): if(hasattr(module, 'weight') and module.kernel_size==(1,1)): module.weight.data = torch.clamp(module.weight.data,min=0) def set_input(self, data): self.input_ref = data['ref'] self.input_p0 = data['p0'] self.input_p1 = data['p1'] self.input_judge = data['judge'] if(self.use_gpu): self.input_ref = self.input_ref.cuda(self.device) self.input_p0 = self.input_p0.cuda(self.device) self.input_p1 = self.input_p1.cuda(self.device) self.input_judge = self.input_judge.cuda(self.device) self.var_ref = Variable(self.input_ref,requires_grad=True) self.var_p0 = Variable(self.input_p0,requires_grad=True) self.var_p1 = Variable(self.input_p1,requires_grad=True) def forward_train(self): # run forward pass self.d0 = self.forward_pair(self.var_ref, self.var_p0) self.d1 = self.forward_pair(self.var_ref, self.var_p1) self.acc_r = self.compute_accuracy(self.d0,self.d1,self.input_judge) # var_judge self.var_judge = Variable(1.*self.input_judge).view(self.d0.size()) 
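        # judge is the fraction of human raters who preferred patch p1 (in [0,1]);
        # the '*2.-1.' on the next line maps it to [-1,1] for the BCE ranking loss.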
self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge*2.-1.) return self.loss_total def backward_train(self): torch.mean(self.loss_total).backward() def compute_accuracy(self,d0,d1,judge): ''' d0, d1 are Variables, judge is a Tensor ''' d1_lt_d0 = (d1<d0).cpu().data.numpy().flatten() judge_per = judge.cpu().numpy().flatten() return d1_lt_d0*judge_per + (1-d1_lt_d0)*(1-judge_per) def get_current_errors(self): retDict = OrderedDict([('loss_total', self.loss_total.data.cpu().numpy()), ('acc_r', self.acc_r)]) for key in retDict.keys(): retDict[key] = np.mean(retDict[key]) return retDict def get_current_visuals(self): zoom_factor = 256/self.var_ref.data.size()[2] ref_img = util.tensor2im(self.var_ref.data) p0_img = util.tensor2im(self.var_p0.data) p1_img = util.tensor2im(self.var_p1.data) ref_img_vis = zoom(ref_img,[zoom_factor, zoom_factor, 1],order=0) p0_img_vis = zoom(p0_img,[zoom_factor, zoom_factor, 1],order=0) p1_img_vis = zoom(p1_img,[zoom_factor, zoom_factor, 1],order=0) return OrderedDict([('ref', ref_img_vis), ('p0', p0_img_vis), ('p1', p1_img_vis)]) def save(self, path, label): if(self.use_gpu): self.save_network(self.net.module, path, '', label) else: self.save_network(self.net, path, '', label) def update_learning_rate(self,nepoch_decay): lrd = self.lr / nepoch_decay lr = self.old_lr - lrd for param_group in self.optimizer_net.param_groups: param_group['lr'] = lr print('update lr [%s] decay: %f -> %f' % (type,self.old_lr, lr)) self.old_lr = lr def score_2afc_dataset(data_loader,func): ''' Function computes Two Alternative Forced Choice (2AFC) score using distance function 'func' in dataset 'data_loader' INPUTS data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside func - callable distance function - calling d=func(in0,in1) should take 2 pytorch tensors with shape Nx3xXxY, and return numpy array of length N OUTPUTS [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators [1] - dictionary with following elements d0s,d1s - N arrays containing distances between reference patch to perturbed patches gts - N array in [0,1], preferred patch selected by human evaluators (closer to "0" for left patch p0, "1" for right patch p1, "0.6" means 60pct people preferred right patch, 40pct preferred left) scores - N array in [0,1], corresponding to what percentage function agreed with humans CONSTS N - number of test triplets in data_loader ''' d0s = [] d1s = [] gts = [] # bar = pb.ProgressBar(max_value=data_loader.load_data().__len__()) for (i,data) in enumerate(data_loader.load_data()): d0s+=func(data['ref'],data['p0']).tolist() d1s+=func(data['ref'],data['p1']).tolist() gts+=data['judge'].cpu().numpy().flatten().tolist() # bar.update(i) d0s = np.array(d0s) d1s = np.array(d1s) gts = np.array(gts) scores = (d0s<d1s)*(1.-gts) + (d1s<d0s)*gts + (d1s==d0s)*.5 return(np.mean(scores), dict(d0s=d0s,d1s=d1s,gts=gts,scores=scores)) ================================================ FILE: models/networks.py ================================================ import torch import torch.nn as nn from torch.nn import init import functools from torch.optim import lr_scheduler def get_norm_layer(norm_type='instance'): """Return a normalization layer Parameters: norm_type (str) -- the name of the normalization layer: batch | instance | none For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev). For InstanceNorm, we do not use learnable affine parameters and do not track running statistics. """ if norm_type == 'batch': norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True) elif norm_type == 'instance': norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False) elif norm_type == 'none': norm_layer = None else: raise NotImplementedError('normalization layer [%s] is not found' % norm_type) return norm_layer def get_scheduler(optimizer, opt): """Return a learning rate scheduler Parameters: optimizer -- the optimizer of the network opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions. opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine For 'linear', we keep the same learning rate for the first <opt.niter> epochs and linearly decay the rate to zero over the next <opt.niter_decay> epochs. For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers. See https://pytorch.org/docs/stable/optim.html for more details. """ if opt.lr_policy == 'linear': def lambda_rule(epoch): lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1) return lr_l scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) elif opt.lr_policy == 'step': scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) elif opt.lr_policy == 'plateau': scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) elif opt.lr_policy == 'cosine': scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0) else: raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy) return scheduler def init_weights(net, init_type='normal', init_gain=0.02): """Initialize network weights. Parameters: net (network) -- network to be initialized init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal init_gain (float) -- scaling factor for normal, xavier and orthogonal. We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might work better for some applications. Feel free to try yourself.
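Example (an illustrative sketch added for clarity, not part of the original docstring; assumes a small CPU-only module):
>>> net = nn.Sequential(nn.Conv2d(3, 64, 3), nn.BatchNorm2d(64))
>>> init_weights(net, init_type='normal', init_gain=0.02) # prints 'initialize network with normal'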
""" def init_func(m): # define the initialization function classname = m.__class__.__name__ if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): if init_type == 'normal': init.normal_(m.weight.data, 0.0, init_gain) elif init_type == 'xavier': init.xavier_normal_(m.weight.data, gain=init_gain) elif init_type == 'kaiming': init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') elif init_type == 'orthogonal': init.orthogonal_(m.weight.data, gain=init_gain) else: raise NotImplementedError('initialization method [%s] is not implemented' % init_type) if hasattr(m, 'bias') and m.bias is not None: init.constant_(m.bias.data, 0.0) elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies. init.normal_(m.weight.data, 1.0, init_gain) init.constant_(m.bias.data, 0.0) print('initialize network with %s' % init_type) net.apply(init_func) # apply the initialization function def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]): """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights Parameters: net (network) -- the network to be initialized init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal gain (float) -- scaling factor for normal, xavier and orthogonal. gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 Return an initialized network. """ if len(gpu_ids) > 0: assert(torch.cuda.is_available()) net.to(gpu_ids[0]) net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs init_weights(net, init_type, init_gain=init_gain) return net def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[], model0_res=0, model1_res=0, extra_channel=3): """Create a generator Parameters: input_nc (int) -- the number of channels in input images output_nc (int) -- the number of channels in output images ngf (int) -- the number of filters in the last conv layer netG (str) -- the architecture's name: resnet_9blocks | resnet_6blocks | unet_256 | unet_128 norm (str) -- the name of normalization layers used in the network: batch | instance | none use_dropout (bool) -- if use dropout layers. init_type (str) -- the name of our initialization method. init_gain (float) -- scaling factor for normal, xavier and orthogonal. gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 Returns a generator Our current implementation provides two types of generators: U-Net: [unet_128] (for 128x128 input images) and [unet_256] (for 256x256 input images) The original U-Net paper: https://arxiv.org/abs/1505.04597 Resnet-based generator: [resnet_6blocks] (with 6 Resnet blocks) and [resnet_9blocks] (with 9 Resnet blocks) Resnet-based generator consists of several Resnet blocks between a few downsampling/upsampling operations. We adapt Torch code from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style). The generator has been initialized by . It uses RELU for non-linearity. 
""" net = None norm_layer = get_norm_layer(norm_type=norm) if netG == 'resnet_9blocks': net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9) elif netG == 'resnet_style2_9blocks': net = ResnetStyle2Generator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9, model0_res=model0_res, extra_channel=extra_channel) elif netG == 'resnet_6blocks': net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6) elif netG == 'unet_128': net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout) elif netG == 'unet_256': net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout) else: raise NotImplementedError('Generator model name [%s] is not recognized' % netG) return init_net(net, init_type, init_gain, gpu_ids) def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[], n_class=3): """Create a discriminator Parameters: input_nc (int) -- the number of channels in input images ndf (int) -- the number of filters in the first conv layer netD (str) -- the architecture's name: basic | n_layers | pixel n_layers_D (int) -- the number of conv layers in the discriminator; effective when netD=='n_layers' norm (str) -- the type of normalization layers used in the network. init_type (str) -- the name of the initialization method. init_gain (float) -- scaling factor for normal, xavier and orthogonal. gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 Returns a discriminator Our current implementation provides three types of discriminators: [basic]: 'PatchGAN' classifier described in the original pix2pix paper. It can classify whether 70×70 overlapping patches are real or fake. Such a patch-level discriminator architecture has fewer parameters than a full-image discriminator and can work on arbitrarily-sized images in a fully convolutional fashion. [n_layers]: With this mode, you cna specify the number of conv layers in the discriminator with the parameter (default=3 as used in [basic] (PatchGAN).) [pixel]: 1x1 PixelGAN discriminator can classify whether a pixel is real or not. It encourages greater color diversity but has no effect on spatial statistics. The discriminator has been initialized by . It uses Leakly RELU for non-linearity. 
""" net = None norm_layer = get_norm_layer(norm_type=norm) if netD == 'basic': # default PatchGAN classifier net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer) elif netD == 'basic_cls': net = NLayerDiscriminatorCls(input_nc, ndf, n_layers=3, n_class=3, norm_layer=norm_layer) elif netD == 'n_layers': # more options net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer) elif netD == 'pixel': # classify if each pixel is real or fake net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer) else: raise NotImplementedError('Discriminator model name [%s] is not recognized' % net) return init_net(net, init_type, init_gain, gpu_ids) def define_HED(init_weights_, gpu_ids_=[]): net = HED() if len(gpu_ids_) > 0: assert(torch.cuda.is_available()) net.to(gpu_ids_[0]) net = torch.nn.DataParallel(net, gpu_ids_) # multi-GPUs if not init_weights_ == None: device = torch.device('cuda:{}'.format(gpu_ids_[0])) if gpu_ids_ else torch.device('cpu') print('Loading model from: %s'%init_weights_) state_dict = torch.load(init_weights_, map_location=str(device)) if isinstance(net, torch.nn.DataParallel): net.module.load_state_dict(state_dict) else: net.load_state_dict(state_dict) print('load the weights successfully') return net ############################################################################## # Classes ############################################################################## class GANLoss(nn.Module): """Define different GAN objectives. The GANLoss class abstracts away the need to create the target label tensor that has the same size as the input. """ def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0): """ Initialize the GANLoss class. Parameters: gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp. target_real_label (bool) - - label for a real image target_fake_label (bool) - - label of a fake image Note: Do not use sigmoid as the last layer of Discriminator. LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss. """ super(GANLoss, self).__init__() self.register_buffer('real_label', torch.tensor(target_real_label)) self.register_buffer('fake_label', torch.tensor(target_fake_label)) self.gan_mode = gan_mode if gan_mode == 'lsgan':#cyclegan self.loss = nn.MSELoss() elif gan_mode == 'vanilla': self.loss = nn.BCEWithLogitsLoss() elif gan_mode in ['wgangp']: self.loss = None else: raise NotImplementedError('gan mode %s not implemented' % gan_mode) def get_target_tensor(self, prediction, target_is_real): """Create label tensors with the same size as the input. Parameters: prediction (tensor) - - tpyically the prediction from a discriminator target_is_real (bool) - - if the ground truth label is for real images or fake images Returns: A label tensor filled with ground truth label, and with the size of the input """ if target_is_real: target_tensor = self.real_label else: target_tensor = self.fake_label return target_tensor.expand_as(prediction) def __call__(self, prediction, target_is_real): """Calculate loss given Discriminator's output and grount truth labels. Parameters: prediction (tensor) - - tpyically the prediction output from a discriminator target_is_real (bool) - - if the ground truth label is for real images or fake images Returns: the calculated loss. 
""" if self.gan_mode in ['lsgan', 'vanilla']: target_tensor = self.get_target_tensor(prediction, target_is_real) loss = self.loss(prediction, target_tensor) elif self.gan_mode == 'wgangp': if target_is_real: loss = -prediction.mean() else: loss = prediction.mean() return loss def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0): """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028 Arguments: netD (network) -- discriminator network real_data (tensor array) -- real images fake_data (tensor array) -- generated images from the generator device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') type (str) -- if we mix real and fake data or not [real | fake | mixed]. constant (float) -- the constant used in formula ( | |gradient||_2 - constant)^2 lambda_gp (float) -- weight for this loss Returns the gradient penalty loss """ if lambda_gp > 0.0: if type == 'real': # either use real images, fake images, or a linear interpolation of two. interpolatesv = real_data elif type == 'fake': interpolatesv = fake_data elif type == 'mixed': alpha = torch.rand(real_data.shape[0], 1, device=device) alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape) interpolatesv = alpha * real_data + ((1 - alpha) * fake_data) else: raise NotImplementedError('{} not implemented'.format(type)) interpolatesv.requires_grad_(True) disc_interpolates = netD(interpolatesv) gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv, grad_outputs=torch.ones(disc_interpolates.size()).to(device), create_graph=True, retain_graph=True, only_inputs=True) gradients = gradients[0].view(real_data.size(0), -1) # flat the data gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp # added eps return gradient_penalty, gradients else: return 0.0, None class ResnetGenerator(nn.Module): """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations. 
We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style) """ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'): """Construct a Resnet-based generator Parameters: input_nc (int) -- the number of channels in input images output_nc (int) -- the number of channels in output images ngf (int) -- the number of filters in the last conv layer norm_layer -- normalization layer use_dropout (bool) -- if use dropout layers n_blocks (int) -- the number of ResNet blocks padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero """ assert(n_blocks >= 0) super(ResnetGenerator, self).__init__() if type(norm_layer) == functools.partial: use_bias = norm_layer.func == nn.InstanceNorm2d else: use_bias = norm_layer == nn.InstanceNorm2d model = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias), norm_layer(ngf), nn.ReLU(True)] n_downsampling = 2 for i in range(n_downsampling): # add downsampling layers mult = 2 ** i model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias), norm_layer(ngf * mult * 2), nn.ReLU(True)] mult = 2 ** n_downsampling for i in range(n_blocks): # add ResNet blocks model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)] for i in range(n_downsampling): # add upsampling layers mult = 2 ** (n_downsampling - i) model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1, bias=use_bias), norm_layer(int(ngf * mult / 2)), nn.ReLU(True)] model += [nn.ReflectionPad2d(3)] model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] model += [nn.Tanh()] self.model = nn.Sequential(*model) def forward(self, input): """Standard forward""" return self.model(input) class ResnetStyle2Generator(nn.Module): def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', extra_channel=3, model0_res=0): """Construct a Resnet-based generator Parameters: input_nc (int) -- the number of channels in input images output_nc (int) -- the number of channels in output images ngf (int) -- the number of filters in the last conv layer norm_layer -- normalization layer use_dropout (bool) -- if use dropout layers n_blocks (int) -- the number of ResNet blocks padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero """ assert(n_blocks >= 0) super(ResnetStyle2Generator, self).__init__() if type(norm_layer) == functools.partial: use_bias = norm_layer.func == nn.InstanceNorm2d else: use_bias = norm_layer == nn.InstanceNorm2d model0 = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias), norm_layer(ngf), nn.ReLU(True)] n_downsampling = 2 for i in range(n_downsampling): # add downsampling layers mult = 2 ** i model0 += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias), norm_layer(ngf * mult * 2), nn.ReLU(True)] mult = 2 ** n_downsampling for i in range(model0_res): # add ResNet blocks model0 += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)] model = [] model += [nn.Conv2d(ngf * mult + extra_channel, ngf * mult, kernel_size=3, stride=1, padding=1, bias=use_bias), norm_layer(ngf * mult), 
nn.ReLU(True)] for i in range(n_blocks-model0_res): # add ResNet blocks model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)] for i in range(n_downsampling): # add upsampling layers mult = 2 ** (n_downsampling - i) model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1, bias=use_bias), norm_layer(int(ngf * mult / 2)), nn.ReLU(True)] model += [nn.ReflectionPad2d(3)] model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] model += [nn.Tanh()] self.model0 = nn.Sequential(*model0) self.model = nn.Sequential(*model) #print(list(self.modules())) def forward(self, input1, input2): """Standard forward""" f1 = self.model0(input1) y1 = torch.cat([f1, input2], 1) return self.model(y1) class ResnetBlock(nn.Module): """Define a Resnet block""" def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias, kernel=3): """Initialize the Resnet block A resnet block is a conv block with skip connections We construct a conv block with build_conv_block function, and implement skip connections in <forward> function. Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf """ super(ResnetBlock, self).__init__() self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias, kernel) def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias, kernel=3): """Construct a convolutional block. Parameters: dim (int) -- the number of channels in the conv layer. padding_type (str) -- the name of padding layer: reflect | replicate | zero norm_layer -- normalization layer use_dropout (bool) -- if use dropout layers. use_bias (bool) -- if the conv layer uses bias or not Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU)) """ conv_block = [] p = 0 pad = int((kernel-1)/2) if padding_type == 'reflect': # by default conv_block += [nn.ReflectionPad2d(pad)] elif padding_type == 'replicate': conv_block += [nn.ReplicationPad2d(pad)] elif padding_type == 'zero': p = pad else: raise NotImplementedError('padding [%s] is not implemented' % padding_type) conv_block += [nn.Conv2d(dim, dim, kernel_size=kernel, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)] if use_dropout: conv_block += [nn.Dropout(0.5)] p = 0 if padding_type == 'reflect': conv_block += [nn.ReflectionPad2d(pad)] elif padding_type == 'replicate': conv_block += [nn.ReplicationPad2d(pad)] elif padding_type == 'zero': p = pad else: raise NotImplementedError('padding [%s] is not implemented' % padding_type) conv_block += [nn.Conv2d(dim, dim, kernel_size=kernel, padding=p, bias=use_bias), norm_layer(dim)] return nn.Sequential(*conv_block) def forward(self, x): """Forward function (with skip connections)""" out = x + self.conv_block(x) # add skip connections return out class UnetGenerator(nn.Module): """Create a Unet-based generator""" def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False): """Construct a Unet generator Parameters: input_nc (int) -- the number of channels in input images output_nc (int) -- the number of channels in output images num_downs (int) -- the number of downsamplings in UNet. For example, if |num_downs| == 7, an image of size 128x128 will become of size 1x1 at the bottleneck ngf (int) -- the number of filters in the last conv layer norm_layer -- normalization layer We construct the U-Net from the innermost layer to the outermost layer.
It is a recursive process. """ super(UnetGenerator, self).__init__() # construct unet structure unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout) # gradually reduce the number of filters from ngf * 8 to ngf unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer) unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer) unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer) self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer def forward(self, input): """Standard forward""" return self.model(input) class UnetSkipConnectionBlock(nn.Module): """Defines the Unet submodule with skip connection. X -------------------identity---------------------- |-- downsampling -- |submodule| -- upsampling --| """ def __init__(self, outer_nc, inner_nc, input_nc=None, submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False): """Construct a Unet submodule with skip connections. Parameters: outer_nc (int) -- the number of filters in the outer conv layer inner_nc (int) -- the number of filters in the inner conv layer input_nc (int) -- the number of channels in input images/features submodule (UnetSkipConnectionBlock) -- previously defined submodules outermost (bool) -- if this module is the outermost module innermost (bool) -- if this module is the innermost module norm_layer -- normalization layer use_dropout (bool) -- if use dropout layers.
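Example (an illustrative sketch, not part of the original docstring): blocks nest recursively, innermost first:
>>> inner = UnetSkipConnectionBlock(512, 512, innermost=True)
>>> mid = UnetSkipConnectionBlock(256, 512, submodule=inner)
>>> out = mid(torch.randn(1, 256, 64, 64)) # the skip connection doubles the channels: (1, 512, 64, 64)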
""" super(UnetSkipConnectionBlock, self).__init__() self.outermost = outermost if type(norm_layer) == functools.partial: use_bias = norm_layer.func == nn.InstanceNorm2d else: use_bias = norm_layer == nn.InstanceNorm2d if input_nc is None: input_nc = outer_nc downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias) downrelu = nn.LeakyReLU(0.2, True) downnorm = norm_layer(inner_nc) uprelu = nn.ReLU(True) upnorm = norm_layer(outer_nc) if outermost: upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1) down = [downconv] up = [uprelu, upconv, nn.Tanh()] model = down + [submodule] + up elif innermost: upconv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias) down = [downrelu, downconv] up = [uprelu, upconv, upnorm] model = down + up else: upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias) down = [downrelu, downconv, downnorm] up = [uprelu, upconv, upnorm] if use_dropout: model = down + [submodule] + up + [nn.Dropout(0.5)] else: model = down + [submodule] + up self.model = nn.Sequential(*model) def forward(self, x): if self.outermost: return self.model(x) else: # add skip connections return torch.cat([x, self.model(x)], 1) class NLayerDiscriminator(nn.Module): """Defines a PatchGAN discriminator""" def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d): """Construct a PatchGAN discriminator Parameters: input_nc (int) -- the number of channels in input images ndf (int) -- the number of filters in the last conv layer n_layers (int) -- the number of conv layers in the discriminator norm_layer -- normalization layer """ super(NLayerDiscriminator, self).__init__() if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters use_bias = norm_layer.func != nn.BatchNorm2d else: use_bias = norm_layer != nn.BatchNorm2d kw = 4 padw = 1 sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] nf_mult = 1 nf_mult_prev = 1 for n in range(1, n_layers): # gradually increase the number of filters nf_mult_prev = nf_mult nf_mult = min(2 ** n, 8) sequence += [ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), norm_layer(ndf * nf_mult), nn.LeakyReLU(0.2, True) ] nf_mult_prev = nf_mult nf_mult = min(2 ** n_layers, 8) sequence += [ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), norm_layer(ndf * nf_mult), nn.LeakyReLU(0.2, True) ] sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map self.model = nn.Sequential(*sequence) def forward(self, input): """Standard forward.""" return self.model(input) class NLayerDiscriminatorCls(nn.Module): """Defines a PatchGAN discriminator""" def __init__(self, input_nc, ndf=64, n_layers=3, n_class=3, norm_layer=nn.BatchNorm2d): """Construct a PatchGAN discriminator Parameters: input_nc (int) -- the number of channels in input images ndf (int) -- the number of filters in the last conv layer n_layers (int) -- the number of conv layers in the discriminator norm_layer -- normalization layer """ super(NLayerDiscriminatorCls, self).__init__() if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters use_bias = norm_layer.func != nn.BatchNorm2d else: use_bias = norm_layer != nn.BatchNorm2d kw = 4 padw = 1 sequence = 
[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] nf_mult = 1 nf_mult_prev = 1 for n in range(1, n_layers): # gradually increase the number of filters nf_mult_prev = nf_mult nf_mult = min(2 ** n, 8) sequence += [ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), norm_layer(ndf * nf_mult), nn.LeakyReLU(0.2, True) ] nf_mult_prev = nf_mult nf_mult = min(2 ** n_layers, 8) sequence1 = [ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), norm_layer(ndf * nf_mult), nn.LeakyReLU(0.2, True) ] sequence1 += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map sequence2 = [ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), norm_layer(ndf * nf_mult), nn.LeakyReLU(0.2, True) ] sequence2 += [ nn.Conv2d(ndf * nf_mult, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), norm_layer(ndf * nf_mult), nn.LeakyReLU(0.2, True) ] sequence2 += [ nn.Conv2d(ndf * nf_mult, n_class, kernel_size=16, stride=1, padding=0, bias=use_bias)] self.model0 = nn.Sequential(*sequence) self.model1 = nn.Sequential(*sequence1) self.model2 = nn.Sequential(*sequence2) print(list(self.modules())) def forward(self, input): """Standard forward.""" feat = self.model0(input) # patchGAN output (1 * 62 * 62) patch = self.model1(feat) # class output (3 * 1 * 1) classl = self.model2(feat) return patch, classl.view(classl.size(0), -1) class PixelDiscriminator(nn.Module): """Defines a 1x1 PatchGAN discriminator (pixelGAN)""" def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d): """Construct a 1x1 PatchGAN discriminator Parameters: input_nc (int) -- the number of channels in input images ndf (int) -- the number of filters in the last conv layer norm_layer -- normalization layer """ super(PixelDiscriminator, self).__init__() if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters use_bias = norm_layer.func != nn.InstanceNorm2d else: use_bias = norm_layer != nn.InstanceNorm2d self.net = [ nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0), nn.LeakyReLU(0.2, True), nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias), norm_layer(ndf * 2), nn.LeakyReLU(0.2, True), nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)] self.net = nn.Sequential(*self.net) def forward(self, input): """Standard forward.""" return self.net(input) class HED(nn.Module): def __init__(self): super(HED, self).__init__() self.moduleVggOne = nn.Sequential( nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=False), nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=False) ) self.moduleVggTwo = nn.Sequential( nn.MaxPool2d(kernel_size=2, stride=2), nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=False), nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=False) ) self.moduleVggThr = nn.Sequential( nn.MaxPool2d(kernel_size=2, stride=2), nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=False), nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=False), nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1), 
nn.ReLU(inplace=False) ) self.moduleVggFou = nn.Sequential( nn.MaxPool2d(kernel_size=2, stride=2), nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=False), nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=False), nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=False) ) self.moduleVggFiv = nn.Sequential( nn.MaxPool2d(kernel_size=2, stride=2), nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=False), nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=False), nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=False) ) self.moduleScoreOne = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1, stride=1, padding=0) self.moduleScoreTwo = nn.Conv2d(in_channels=128, out_channels=1, kernel_size=1, stride=1, padding=0) self.moduleScoreThr = nn.Conv2d(in_channels=256, out_channels=1, kernel_size=1, stride=1, padding=0) self.moduleScoreFou = nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1, stride=1, padding=0) self.moduleScoreFiv = nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1, stride=1, padding=0) self.moduleCombine = nn.Sequential( nn.Conv2d(in_channels=5, out_channels=1, kernel_size=1, stride=1, padding=0), nn.Sigmoid() ) def forward(self, tensorInput): tensorBlue = (tensorInput[:, 2:3, :, :] * 255.0) - 104.00698793 tensorGreen = (tensorInput[:, 1:2, :, :] * 255.0) - 116.66876762 tensorRed = (tensorInput[:, 0:1, :, :] * 255.0) - 122.67891434 tensorInput = torch.cat([ tensorBlue, tensorGreen, tensorRed ], 1) tensorVggOne = self.moduleVggOne(tensorInput) tensorVggTwo = self.moduleVggTwo(tensorVggOne) tensorVggThr = self.moduleVggThr(tensorVggTwo) tensorVggFou = self.moduleVggFou(tensorVggThr) tensorVggFiv = self.moduleVggFiv(tensorVggFou) tensorScoreOne = self.moduleScoreOne(tensorVggOne) tensorScoreTwo = self.moduleScoreTwo(tensorVggTwo) tensorScoreThr = self.moduleScoreThr(tensorVggThr) tensorScoreFou = self.moduleScoreFou(tensorVggFou) tensorScoreFiv = self.moduleScoreFiv(tensorVggFiv) tensorScoreOne = nn.functional.interpolate(input=tensorScoreOne, size=(tensorInput.size(2), tensorInput.size(3)), mode='bilinear', align_corners=False) tensorScoreTwo = nn.functional.interpolate(input=tensorScoreTwo, size=(tensorInput.size(2), tensorInput.size(3)), mode='bilinear', align_corners=False) tensorScoreThr = nn.functional.interpolate(input=tensorScoreThr, size=(tensorInput.size(2), tensorInput.size(3)), mode='bilinear', align_corners=False) tensorScoreFou = nn.functional.interpolate(input=tensorScoreFou, size=(tensorInput.size(2), tensorInput.size(3)), mode='bilinear', align_corners=False) tensorScoreFiv = nn.functional.interpolate(input=tensorScoreFiv, size=(tensorInput.size(2), tensorInput.size(3)), mode='bilinear', align_corners=False) return self.moduleCombine(torch.cat([ tensorScoreOne, tensorScoreTwo, tensorScoreThr, tensorScoreFou, tensorScoreFiv ], 1)) ================================================ FILE: models/networks_basic.py ================================================ from __future__ import absolute_import import sys import torch import torch.nn as nn import torch.nn.init as init from torch.autograd import Variable import numpy as np from pdb import set_trace as st from skimage import color from IPython import embed from . 
import pretrained_networks as pn from util import util def spatial_average(in_tens, keepdim=True): return in_tens.mean([2,3],keepdim=keepdim) def upsample(in_tens, out_H=64): # assumes scale factor is same for H and W in_H = in_tens.shape[2] scale_factor = 1.*out_H/in_H return nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)(in_tens) # Learned perceptual metric class PNetLin(nn.Module): def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, spatial=False, version='0.1', lpips=True): super(PNetLin, self).__init__() self.pnet_type = pnet_type self.pnet_tune = pnet_tune self.pnet_rand = pnet_rand self.spatial = spatial self.lpips = lpips self.version = version self.scaling_layer = ScalingLayer() if(self.pnet_type in ['vgg','vgg16']): net_type = pn.vgg16 self.chns = [64,128,256,512,512] elif(self.pnet_type=='alex'): net_type = pn.alexnet self.chns = [64,192,384,256,256] elif(self.pnet_type=='squeeze'): net_type = pn.squeezenet self.chns = [64,128,256,384,384,512,512] self.L = len(self.chns) self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune) if(lpips): self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4] if(self.pnet_type=='squeeze'): # 7 layers for squeezenet self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout) self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout) self.lins+=[self.lin5,self.lin6] def forward(self, in0, in1, retPerLayer=False): # v0.0 - original release had a bug, where input was not scaled in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1) outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input) feats0, feats1, diffs = {}, {}, {} for kk in range(self.L): feats0[kk], feats1[kk] = util.normalize_tensor(outs0[kk]), util.normalize_tensor(outs1[kk]) diffs[kk] = (feats0[kk]-feats1[kk])**2 if(self.lpips): if(self.spatial): res = [upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2]) for kk in range(self.L)] else: res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)] else: if(self.spatial): res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_H=in0.shape[2]) for kk in range(self.L)] else: res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)] val = res[0] for l in range(1,self.L): val += res[l] if(retPerLayer): return (val, res) else: return val class ScalingLayer(nn.Module): def __init__(self): super(ScalingLayer, self).__init__() self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188])[None,:,None,None]) self.register_buffer('scale', torch.Tensor([.458,.448,.450])[None,:,None,None]) def forward(self, inp): return (inp - self.shift.to(inp.device)) / self.scale.to(inp.device) class NetLinLayer(nn.Module): ''' A single linear layer which does a 1x1 conv ''' def __init__(self, chn_in, chn_out=1, use_dropout=False): super(NetLinLayer, self).__init__() layers = [nn.Dropout(),] if(use_dropout) else [] layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),] self.model = nn.Sequential(*layers) class Dist2LogitLayer(nn.Module): ''' takes 2 distances, puts through fc layers, 
spits out value between [0,1] (if use_sigmoid is True) ''' def __init__(self, chn_mid=32, use_sigmoid=True): super(Dist2LogitLayer, self).__init__() layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),] layers += [nn.LeakyReLU(0.2,True),] layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),] layers += [nn.LeakyReLU(0.2,True),] layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),] if(use_sigmoid): layers += [nn.Sigmoid(),] self.model = nn.Sequential(*layers) def forward(self,d0,d1,eps=0.1): return self.model.forward(torch.cat((d0,d1,d0-d1,d0/(d1+eps),d1/(d0+eps)),dim=1)) class BCERankingLoss(nn.Module): def __init__(self, chn_mid=32): super(BCERankingLoss, self).__init__() self.net = Dist2LogitLayer(chn_mid=chn_mid) # self.parameters = list(self.net.parameters()) self.loss = torch.nn.BCELoss() def forward(self, d0, d1, judge): per = (judge+1.)/2. self.logit = self.net.forward(d0,d1) return self.loss(self.logit, per) # L2, DSSIM metrics class FakeNet(nn.Module): def __init__(self, use_gpu=True, colorspace='Lab'): super(FakeNet, self).__init__() self.use_gpu = use_gpu self.colorspace=colorspace class L2(FakeNet): def forward(self, in0, in1, retPerLayer=None): assert(in0.size()[0]==1) # currently only supports batchSize 1 if(self.colorspace=='RGB'): (N,C,X,Y) = in0.size() value = torch.mean(torch.mean(torch.mean((in0-in1)**2,dim=1).view(N,1,X,Y),dim=2).view(N,1,1,Y),dim=3).view(N) return value elif(self.colorspace=='Lab'): value = util.l2(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)), util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float') ret_var = Variable( torch.Tensor((value,) ) ) if(self.use_gpu): ret_var = ret_var.cuda() return ret_var class DSSIM(FakeNet): def forward(self, in0, in1, retPerLayer=None): assert(in0.size()[0]==1) # currently only supports batchSize 1 if(self.colorspace=='RGB'): value = util.dssim(1.*util.tensor2im(in0.data), 1.*util.tensor2im(in1.data), range=255.).astype('float') elif(self.colorspace=='Lab'): value = util.dssim(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)), util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float') ret_var = Variable( torch.Tensor((value,) ) ) if(self.use_gpu): ret_var = ret_var.cuda() return ret_var def print_network(net): num_params = 0 for param in net.parameters(): num_params += param.numel() print('Network',net) print('Total number of parameters: %d' % num_params) ================================================ FILE: models/pretrained_networks.py ================================================ from collections import namedtuple import torch from torchvision import models from IPython import embed class squeezenet(torch.nn.Module): def __init__(self, requires_grad=False, pretrained=True): super(squeezenet, self).__init__() pretrained_features = models.squeezenet1_1(pretrained=pretrained).features self.slice1 = torch.nn.Sequential() self.slice2 = torch.nn.Sequential() self.slice3 = torch.nn.Sequential() self.slice4 = torch.nn.Sequential() self.slice5 = torch.nn.Sequential() self.slice6 = torch.nn.Sequential() self.slice7 = torch.nn.Sequential() self.N_slices = 7 for x in range(2): self.slice1.add_module(str(x), pretrained_features[x]) for x in range(2,5): self.slice2.add_module(str(x), pretrained_features[x]) for x in range(5, 8): self.slice3.add_module(str(x), pretrained_features[x]) for x in range(8, 10): self.slice4.add_module(str(x), pretrained_features[x]) for x in range(10, 11): 
self.slice5.add_module(str(x), pretrained_features[x]) for x in range(11, 12): self.slice6.add_module(str(x), pretrained_features[x]) for x in range(12, 13): self.slice7.add_module(str(x), pretrained_features[x]) if not requires_grad: for param in self.parameters(): param.requires_grad = False def forward(self, X): h = self.slice1(X) h_relu1 = h h = self.slice2(h) h_relu2 = h h = self.slice3(h) h_relu3 = h h = self.slice4(h) h_relu4 = h h = self.slice5(h) h_relu5 = h h = self.slice6(h) h_relu6 = h h = self.slice7(h) h_relu7 = h vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7']) out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7) return out class alexnet(torch.nn.Module): def __init__(self, requires_grad=False, pretrained=True): super(alexnet, self).__init__() alexnet_pretrained_features = models.alexnet(pretrained=pretrained).features self.slice1 = torch.nn.Sequential() self.slice2 = torch.nn.Sequential() self.slice3 = torch.nn.Sequential() self.slice4 = torch.nn.Sequential() self.slice5 = torch.nn.Sequential() self.N_slices = 5 for x in range(2): self.slice1.add_module(str(x), alexnet_pretrained_features[x]) for x in range(2, 5): self.slice2.add_module(str(x), alexnet_pretrained_features[x]) for x in range(5, 8): self.slice3.add_module(str(x), alexnet_pretrained_features[x]) for x in range(8, 10): self.slice4.add_module(str(x), alexnet_pretrained_features[x]) for x in range(10, 12): self.slice5.add_module(str(x), alexnet_pretrained_features[x]) if not requires_grad: for param in self.parameters(): param.requires_grad = False def forward(self, X): h = self.slice1(X) h_relu1 = h h = self.slice2(h) h_relu2 = h h = self.slice3(h) h_relu3 = h h = self.slice4(h) h_relu4 = h h = self.slice5(h) h_relu5 = h alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5']) out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) return out class vgg16(torch.nn.Module): def __init__(self, requires_grad=False, pretrained=True): super(vgg16, self).__init__() vgg_pretrained_features = models.vgg16(pretrained=pretrained).features self.slice1 = torch.nn.Sequential() self.slice2 = torch.nn.Sequential() self.slice3 = torch.nn.Sequential() self.slice4 = torch.nn.Sequential() self.slice5 = torch.nn.Sequential() self.N_slices = 5 for x in range(4): self.slice1.add_module(str(x), vgg_pretrained_features[x]) for x in range(4, 9): self.slice2.add_module(str(x), vgg_pretrained_features[x]) for x in range(9, 16): self.slice3.add_module(str(x), vgg_pretrained_features[x]) for x in range(16, 23): self.slice4.add_module(str(x), vgg_pretrained_features[x]) for x in range(23, 30): self.slice5.add_module(str(x), vgg_pretrained_features[x]) if not requires_grad: for param in self.parameters(): param.requires_grad = False def forward(self, X): h = self.slice1(X) h_relu1_2 = h h = self.slice2(h) h_relu2_2 = h h = self.slice3(h) h_relu3_3 = h h = self.slice4(h) h_relu4_3 = h h = self.slice5(h) h_relu5_3 = h vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) return out class resnet(torch.nn.Module): def __init__(self, requires_grad=False, pretrained=True, num=18): super(resnet, self).__init__() if(num==18): self.net = models.resnet18(pretrained=pretrained) elif(num==34): self.net = models.resnet34(pretrained=pretrained) elif(num==50): self.net = models.resnet50(pretrained=pretrained) elif(num==101): 
self.net = models.resnet101(pretrained=pretrained) elif(num==152): self.net = models.resnet152(pretrained=pretrained) self.N_slices = 5 self.conv1 = self.net.conv1 self.bn1 = self.net.bn1 self.relu = self.net.relu self.maxpool = self.net.maxpool self.layer1 = self.net.layer1 self.layer2 = self.net.layer2 self.layer3 = self.net.layer3 self.layer4 = self.net.layer4 def forward(self, X): h = self.conv1(X) h = self.bn1(h) h = self.relu(h) h_relu1 = h h = self.maxpool(h) h = self.layer1(h) h_conv2 = h h = self.layer2(h) h_conv3 = h h = self.layer3(h) h_conv4 = h h = self.layer4(h) h_conv5 = h outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5']) out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5) return out ================================================ FILE: models/test_3styles_model.py ================================================ from .base_model import BaseModel from . import networks import torch class Test3StylesModel(BaseModel): """ This TestModel can be used to generate CycleGAN results for only one direction. This model will automatically set '--dataset_mode single', which only loads the images from one collection. See the test instruction for more details. """ @staticmethod def modify_commandline_options(parser, is_train=True): """Add new dataset-specific options, and rewrite default values for existing options. Parameters: parser -- original option parser is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. Returns: the modified parser. The model can only be used during test time. It requires '--dataset_mode single'. You need to specify the network using the option '--model_suffix'. """ assert not is_train, 'TestModel cannot be used during training time' parser.set_defaults(dataset_mode='single') parser.add_argument('--style_control', type=int, default=0, help='not set style_vec in dataset') parser.add_argument('--netga', type=str, default='resnet_style2_9blocks', help='net arch for netG_A') parser.add_argument('--model0_res', type=int, default=0, help='number of resblocks in model0') parser.add_argument('--model1_res', type=int, default=0, help='number of resblocks in model1 (after insert style, before 2 column merge)') return parser def __init__(self, opt): assert(not opt.isTrain) BaseModel.__init__(self, opt) # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses> self.loss_names = [] # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals> self.visual_names = ['real', 'fake1', 'fake2', 'fake3'] # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks> self.model_names = ['G_A'] # only generator is needed. print(opt.netga) print('model0_res', opt.model0_res) print('model1_res', opt.model1_res) self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netga, opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.model0_res, opt.model1_res) setattr(self, 'netG_A', self.netG) # store netG in self.
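# Illustrative note (not part of the original file): set_input() below tiles a one-hot style code over the generator's 128x128 bottleneck, e.g. torch.Tensor([1, 0, 0]).view(3, 1, 1).repeat(1, 1, 128, 128) yields a (1, 3, 128, 128) map selecting the first of the three drawing styles; forward() then renders the same photo once per style.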
def set_input(self, input): self.real = input['A'].to(self.device) self.image_paths = input['A_paths'] self.style1 = torch.Tensor([1, 0, 0]).view(3, 1, 1).repeat(1, 1, 128, 128).to(self.device) self.style2 = torch.Tensor([0, 1, 0]).view(3, 1, 1).repeat(1, 1, 128, 128).to(self.device) self.style3 = torch.Tensor([0, 0, 1]).view(3, 1, 1).repeat(1, 1, 128, 128).to(self.device) def forward(self): """Run forward pass.""" self.fake1 = self.netG(self.real, self.style1) self.fake2 = self.netG(self.real, self.style2) self.fake3 = self.netG(self.real, self.style3) def optimize_parameters(self): """No optimization for test model.""" pass ================================================ FILE: models/test_model.py ================================================ from .base_model import BaseModel from . import networks import torch class TestModel(BaseModel): """ This TestModel can be used to generate CycleGAN results for only one direction. This model will automatically set '--dataset_mode single', which only loads the images from one collection. See the test instruction for more details. """ @staticmethod def modify_commandline_options(parser, is_train=True): """Add new dataset-specific options, and rewrite default values for existing options. Parameters: parser -- original option parser is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. Returns: the modified parser. The model can only be used during test time. It requires '--dataset_mode single'. You need to specify the network using the option '--model_suffix'. """ assert not is_train, 'TestModel cannot be used during training time' parser.set_defaults(dataset_mode='single') parser.add_argument('--model_suffix', type=str, default='', help='In checkpoints_dir, [epoch]_net_G[model_suffix].pth will be loaded as the generator.') parser.add_argument('--style_control', type=int, default=1, help='use style_control') parser.add_argument('--sfeature_mode', type=str, default='vgg19_softmax', help='vgg19 softmax as feature') parser.add_argument('--sinput', type=str, default='sind', help='use which one for style input') parser.add_argument('--sind', type=int, default=0, help='one hot for sfeature') parser.add_argument('--svec', type=str, default='1,0,0', help='3-dim vec') parser.add_argument('--simg', type=str, default='Yann_Legendre-053', help='drawing example for style') parser.add_argument('--netga', type=str, default='resnet_style2_9blocks', help='net arch for netG_A') parser.add_argument('--model0_res', type=int, default=0, help='number of resblocks in model0') parser.add_argument('--model1_res', type=int, default=0, help='number of resblocks in model1 (after insert style, before 2 column merge)') return parser def __init__(self, opt): """Initialize the TestModel class. Parameters: opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions """ assert(not opt.isTrain) BaseModel.__init__(self, opt) # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses> self.loss_names = [] # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals> self.visual_names = ['real', 'fake', 'rec'] # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks> self.model_names = ['G' + opt.model_suffix, 'G_B'] # only generator is needed.
if not self.opt.style_control: self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids) else: print(opt.netga) print('model0_res', opt.model0_res) print('model1_res', opt.model1_res) self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netga, opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.model0_res, opt.model1_res) self.netGB = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids) # assigns the model to self.netG_[suffix] so that it can be loaded # please see <BaseModel.load_networks> setattr(self, 'netG' + opt.model_suffix, self.netG) # store netG in self. setattr(self, 'netG_B', self.netGB) # store netGB in self. def set_input(self, input): """Unpack input data from the dataloader and perform necessary pre-processing steps. Parameters: input: a dictionary that contains the data itself and its metadata information. We need to use 'single_dataset' dataset mode. It only loads images from one domain. """ self.real = input['A'].to(self.device) self.image_paths = input['A_paths'] if self.opt.style_control: self.style = input['B_style'] def forward(self): """Run forward pass.""" if not self.opt.style_control: self.fake = self.netG(self.real) # G(real) else: print(torch.mean(self.style,(2,3)),'style_control') self.fake = self.netG(self.real, self.style) self.rec = self.netG_B(self.fake) def optimize_parameters(self): """No optimization for test model.""" pass ================================================ FILE: options/__init__.py ================================================ """This package options includes option modules: training options, test options, and basic options (used in both training and test).""" ================================================ FILE: options/base_options.py ================================================ import argparse import os from util import util import torch import models import data class BaseOptions(): """This class defines options used during both training and test time. It also implements several helper functions such as parsing, printing, and saving the options. It also gathers additional options defined in <modify_commandline_options> functions in both dataset class and model class. """ def __init__(self): """Reset the class; indicates the class hasn't been initialized""" self.initialized = False def initialize(self, parser): """Define the common options that are used in both training and test.""" # basic parameters parser.add_argument('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)') parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models') parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') parser.add_argument('--gpu_ids_p', type=str, default='0', help='gpu ids for pretrained auxiliary models: e.g. 0 0,1,2, 0,2. use -1 for CPU') parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') # model parameters parser.add_argument('--model', type=str, default='cycle_gan', help='chooses which model to use.
[cycle_gan | pix2pix | test | colorization]') parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale') parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale') parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer') parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer') parser.add_argument('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator') parser.add_argument('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]') parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers') parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]') parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]') parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.') parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator') # dataset parameters parser.add_argument('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]') parser.add_argument('--direction', type=str, default='AtoB', help='AtoB or BtoA') parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data') parser.add_argument('--batch_size', type=int, default=1, help='input batch size') parser.add_argument('--load_size', type=int, default=286, help='scale images to this size') parser.add_argument('--crop_size', type=int, default=256, help='then crop to this size') parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.') parser.add_argument('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]') parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation') parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML') # additional parameters parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') parser.add_argument('--load_iter', type=int, default='0', help='which iteration to load? 
if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]') parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}') self.initialized = True return parser def gather_options(self): """Initialize our parser with basic options (only once). Add additional model-specific and dataset-specific options. These options are defined in the <modify_commandline_options> function in model and dataset classes. """ if not self.initialized: # check if it has been initialized parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser = self.initialize(parser) # get the basic options opt, _ = parser.parse_known_args() # modify model-related parser options model_name = opt.model model_option_setter = models.get_option_setter(model_name) parser = model_option_setter(parser, self.isTrain) opt, _ = parser.parse_known_args() # parse again with new defaults # modify dataset-related parser options dataset_name = opt.dataset_mode dataset_option_setter = data.get_option_setter(dataset_name) parser = dataset_option_setter(parser, self.isTrain) # save and return the parser self.parser = parser return parser.parse_args() def print_options(self, opt): """Print and save options It will print both current options and default values (if different). It will save options into a text file / [checkpoints_dir] / opt.txt """ message = '' message += '----------------- Options ---------------\n' for k, v in sorted(vars(opt).items()): comment = '' default = self.parser.get_default(k) if v != default: comment = '\t[default: %s]' % str(default) message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) message += '----------------- End -------------------' print(message) # save to the disk expr_dir = os.path.join(opt.checkpoints_dir, opt.name) util.mkdirs(expr_dir) file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase)) with open(file_name, 'wt') as opt_file: opt_file.write(message) opt_file.write('\n') def parse(self): """Parse our options, create checkpoints directory suffix, and set up gpu device.""" opt = self.gather_options() opt.isTrain = self.isTrain # train or test # process opt.suffix if opt.suffix: suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' opt.name = opt.name + suffix self.print_options(opt) # set gpu ids str_ids = opt.gpu_ids.split(',') opt.gpu_ids = [] for str_id in str_ids: id = int(str_id) if id >= 0: opt.gpu_ids.append(id) if len(opt.gpu_ids) > 0: torch.cuda.set_device(opt.gpu_ids[0]) # set gpu ids for the pretrained auxiliary models str_ids = opt.gpu_ids_p.split(',') opt.gpu_ids_p = [] for str_id in str_ids: id = int(str_id) if id >= 0: opt.gpu_ids_p.append(id) self.opt = opt return self.opt
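For reference, the two-pass parsing in gather_options above works because every model and dataset class exposes a static modify_commandline_options hook. A minimal sketch of that pattern (DummyModel and the --style_dim flag are hypothetical, purely for illustration):

```python
import argparse

class DummyModel:
    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Register model-specific flags and retune shared defaults (illustrative)."""
        parser.add_argument('--style_dim', type=int, default=3,
                            help='(hypothetical) dimension of the style code')
        if is_train:
            parser.set_defaults(pool_size=50)  # override an existing default
        return parser

parser = argparse.ArgumentParser()
parser.add_argument('--pool_size', type=int, default=0)
parser = DummyModel.modify_commandline_options(parser, is_train=True)
print(parser.parse_args([]))  # Namespace(pool_size=50, style_dim=3)
```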
""" def initialize(self, parser): parser = BaseOptions.initialize(self, parser) # define shared options parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.') parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.') parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images') parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') # Dropout and Batchnorm has different behavioir during training and test. parser.add_argument('--eval', action='store_true', help='use eval mode during test time.') parser.add_argument('--num_test', type=int, default=50, help='how many test images to run') parser.add_argument('--imagefolder', type=str, default='images', help='subfolder to save images') # rewrite devalue values parser.set_defaults(model='test') # To avoid cropping, the load_size should be the same as crop_size parser.set_defaults(load_size=parser.get_default('crop_size')) self.isTrain = False return parser ================================================ FILE: options/train_options.py ================================================ from .base_options import BaseOptions class TrainOptions(BaseOptions): """This class includes training options. It also includes shared options defined in BaseOptions. """ def initialize(self, parser): parser = BaseOptions.initialize(self, parser) # visdom and HTML visualization parameters parser.add_argument('--display_freq', type=int, default=400, help='frequency of showing training results on screen') parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.') parser.add_argument('--display_id', type=int, default=1, help='window id of the web display') parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display') parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")') parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display') parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html') parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') # network saving and loading parameters parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results') parser.add_argument('--save_epoch_freq', type=int, default=10, help='frequency of saving checkpoints at the end of epochs') parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration') parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...') parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') # training parameters parser.add_argument('--n_epochs', type=int, default=200, help='the end epoch count') parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate') parser.add_argument('--niter_decay', type=int, default=100, help='# of 
iter to linearly decay learning rate to zero') parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam') parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam') parser.add_argument('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla | lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.') parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images') parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy. [linear | step | plateau | cosine]') parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations') self.isTrain = True return parser
================================================ FILE: portrait_drawing_resources.md ================================================
- Charles Burns (style1): https://www.pinterest.co.uk/johns59/charles-burns-fan-club/ - Yann Legendre (style1): http://www.yannlegendre.com/project/portraits/ - Kathryn Rathke (style2): https://www.kathrynrathke.com/ - Vectorportal (style3): https://www.pinterest.co.uk/vectorportal/celebrity-vector-illustrations/
================================================ FILE: preprocess/face_align_512.m ================================================
function [trans_img]=face_align_512(impath,facial5point,savedir) % align the faces by similarity transformation, % using 5 facial landmarks: 2 eyes, nose, 2 mouth corners. % impath: path to image % facial5point: 5x2 matrix, 5 facial landmark positions, detected by MTCNN % savedir: save dir for the cropped image and transformed facial landmarks %% alignment settings imgSize = [512,512]; coord5point = [180,230; 300,230; 240,301; 186,365.6; 294,365.6];%480x480 coord5point = (coord5point-240)/560 * 512 + 256; %% face alignment % load and align, resize image to imgSize img = imread(impath); facial5point = double(facial5point); transf = cp2tform(facial5point, coord5point, 'similarity'); trans_img = imtransform(img, transf, 'XData', [1 imgSize(2)],... 'YData', [1 imgSize(1)],... 'Size', imgSize,... 'FillValues', [255;255;255]); trans_facial5point = round(tformfwd(transf,facial5point)); %% save results if ~exist(savedir,'dir') mkdir(savedir) end [~,name,~] = fileparts(impath); % save trans_img imwrite(trans_img, fullfile(savedir,[name,'_resized.png'])); fprintf('write aligned image to %s\n',fullfile(savedir,[name,'_resized.png'])); %% show results imshow(trans_img); hold on; plot(trans_facial5point(:,1),trans_facial5point(:,2),'b'); plot(trans_facial5point(:,1),trans_facial5point(:,2),'r+'); end
================================================ FILE: preprocess/readme.md ================================================
## Preprocessing steps During training, face photos and drawings are aligned, and nose, eyes, lips masks are prepared for them. During test, the alignment step is optional and the masks are not needed. ### 1. Align, resize, crop images to 512x512 All training and test images used by our model are aligned using facial landmarks, and the landmarks after alignment are needed by our code. - First, 5 facial landmarks need to be detected for each face photo (we detect them using [MTCNN](https://github.com/kpzhang93/MTCNN_face_detection_alignment) (MTCNNv1)). - Then, we provide a MATLAB function, `face_align_512.m`, to align, resize and crop face photos (and corresponding drawings) to 512x512. Call this function in MATLAB to align an image to 512x512. For example, for `ia_selfie_10515.jpg` in the `example` dir, the 5 detected facial landmarks are saved in `example/ia_selfie_10515_facial5point.mat`. Call the following in MATLAB: ```matlab load('example/ia_selfie_10515_facial5point.mat'); [trans_img]=face_align_512('example/ia_selfie_10515.jpg',facial5point,'example'); ``` This will align the image and write the aligned image to the `example` folder. See `face_align_512.m` for more instructions.
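If MATLAB is not available, the alignment can be approximated in Python. The sketch below mirrors the reference coordinates in `face_align_512.m`, but the use of scikit-image (not listed in `requirements.txt`) is our own assumption, so treat it as an illustration rather than the reference implementation:

```python
# Python approximation of face_align_512.m (assumes scikit-image is installed).
# The reference coordinates mirror the MATLAB script: a 480x480 template rescaled to 512x512.
import numpy as np
from skimage import io, transform

def face_align_512(impath, facial5point, savepath):
    coord5point = np.array([[180, 230], [300, 230], [240, 301],
                            [186, 365.6], [294, 365.6]], dtype=np.float64)
    coord5point = (coord5point - 240) / 560 * 512 + 256  # rescale to the 512x512 frame
    tform = transform.SimilarityTransform()
    tform.estimate(np.asarray(facial5point, dtype=np.float64), coord5point)
    img = io.imread(impath)
    # warp() expects a map from output coords to input coords, hence tform.inverse;
    # cval=1.0 fills the background with white, like FillValues in the MATLAB version.
    aligned = transform.warp(img, tform.inverse, output_shape=(512, 512), cval=1.0)
    io.imsave(savepath, (aligned * 255).astype(np.uint8))
```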
### 2. Prepare nose, eyes, lips masks In our work, we use the face parsing network from https://github.com/cientgu/Mask_Guided_Portrait_Editing to get the nose, eyes, lips regions, and then dilate those regions so that they fully cover the corresponding facial features (some examples are shown in the `example` folder). - The masks need to be copied to `datasets/portrait_drawing/train/A_nose`, `A_eyes`, `A_lips` (and likewise `B_nose`, `B_eyes`, `B_lips` for drawings), and must have the **same filenames** as the aligned face photos (drawings).
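The face parsing network itself is external to this repo; once it produces a binary mask for a region, the dilation step can be done with OpenCV (which is already in `requirements.txt`). A minimal sketch, where the file paths, elliptical kernel size, and iteration count are illustrative assumptions to be tuned until the mask covers the feature:

```python
# Dilate a binary facial-feature mask so it fully covers the feature.
# Paths and the 11x11 elliptical kernel are illustrative, not fixed by the repo.
import cv2

mask = cv2.imread('example/mask_nose.png', cv2.IMREAD_GRAYSCALE)  # hypothetical path
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (11, 11))
dilated = cv2.dilate(mask, kernel, iterations=1)
cv2.imwrite('example/mask_nose_dilated.png', dilated)
```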
================================================ FILE: readme.md ================================================
# Unpaired Portrait Drawing Generation via Asymmetric Cycle Mapping We provide PyTorch implementations for our CVPR 2020 paper "Unpaired Portrait Drawing Generation via Asymmetric Cycle Mapping". [paper](https://openaccess.thecvf.com/content_CVPR_2020/papers/Yi_Unpaired_Portrait_Drawing_Generation_via_Asymmetric_Cycle_Mapping_CVPR_2020_paper.pdf), [suppl](https://openaccess.thecvf.com/content_CVPR_2020/supplemental/Yi_Unpaired_Portrait_Drawing_CVPR_2020_supplemental.pdf). This project generates multi-style artistic portrait drawings from face photos using a GAN-based model. [[Jittor implementation]](https://github.com/yiranran/Unpaired-Portrait-Drawing-Jittor) ## Our Proposed Framework ## Sample Results From left to right: input, output(style1), output(style2), output(style3) ## Prerequisites - Linux or macOS - Python 3 - CPU or NVIDIA GPU + CUDA CuDNN ## Installation - To install the dependencies, run ```bash pip install -r requirements.txt ``` ## Colab A colab demo is [here](https://colab.research.google.com/drive/1U1fPXD1JukuKPOrhGMX1iaJC-d8_RUYr). ## Test steps (apply a pretrained model) - 1. Download pre-trained models from [BaiduYun](https://pan.baidu.com/s/1_9Fy8mRpTQp6AvqhHsfQAQ) (extract code: c9h7) or [GoogleDrive](https://drive.google.com/drive/folders/1FzOcdlMYhvK_nyLCe8wnwotMphhIoiYt?usp=sharing) and rename the folder to `checkpoints`. - 2. Test for example photos: generate artistic portrait drawings for the example photos in the folder `./examples` using ``` bash # with GPU python test_seq_style.py # without GPU python test_seq_style.py --gpu -1 ``` The test results will be saved to an HTML file here: `./results/pretrained/test_200/index3styles.html`. The result images are saved in `./results/pretrained/test_200/images3styles`, where `real`, `fake1`, `fake2`, `fake3` correspond to the input face photo, style1 drawing, style2 drawing, style3 drawing respectively. - 3. To test on your own photos: First use an image editor to crop the face region of your photo (or use an optional preprocess [here](preprocess/readme.md)). Then specify the folder that contains test photos using option `--dataroot`, specify the save folder name using option `--savefolder` and run the above command again: ``` bash # with GPU python test_seq_style.py --dataroot [input_folder] --savefolder [save_folder_name] # without GPU python test_seq_style.py --gpu -1 --dataroot [input_folder] --savefolder [save_folder_name] # E.g. python test_seq_style.py --gpu -1 --dataroot ./imgs/test1 --savefolder 3styles_test1 ``` The test results will be saved to an HTML file here: `./results/pretrained/test_200/index[save_folder_name].html`. The result images are saved in `./results/pretrained/test_200/images[save_folder_name]`. An example HTML screenshot is shown below: You can contact yr16@mails.tsinghua.edu.cn with any questions. ## Train steps - 1. Prepare the dataset: 1) download face photos and portrait drawings from the internet (e.g. [resources](portrait_drawing_resources.md)); 2) align and crop photos and drawings; 3) prepare nose, eyes, lips masks according to the [preprocess instructions](preprocess/readme.md); 4) put aligned photos under `./datasets/portrait_drawing/train/A`, aligned drawings under `./datasets/portrait_drawing/train/B`, and masks under `A_nose`,`A_eyes`,`A_lips`,`B_nose`,`B_eyes`,`B_lips` respectively. - 2. Train a 3-class style classifier and extract the 3-dim style feature (according to the paper), then save the style feature of each drawing in the training set in .npy format in the folder `./datasets/portrait_drawing/train/B_feat` (see the format sketch after this list). A subset of our training set is [here](https://drive.google.com/file/d/1OSMOR3-uhGkoPwPFRNychJSNrpSak_23/view?usp=sharing). - 3. Train our model ``` bash sh ./scripts/train.sh ``` Models are saved in the folder `checkpoints/portrait_drawing`.
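The style classifier in step 2 is described in the paper rather than in this repo, so the sketch below (referenced from step 2 above) only illustrates the expected output format; `style_classifier` is a placeholder, and saving one `.npy` per drawing under a matching filename is our reading of how `B_feat` is consumed:

```python
# Save a 3-dim style feature per training drawing in ./datasets/portrait_drawing/train/B_feat.
# 'style_classifier' is a placeholder for the 3-class classifier from the paper.
import os
import numpy as np

drawing_dir = './datasets/portrait_drawing/train/B'
feat_dir = './datasets/portrait_drawing/train/B_feat'
os.makedirs(feat_dir, exist_ok=True)
for fname in os.listdir(drawing_dir):
    feat = style_classifier(os.path.join(drawing_dir, fname))  # placeholder: length-3 feature
    np.save(os.path.join(feat_dir, os.path.splitext(fname)[0] + '.npy'),
            np.asarray(feat, dtype=np.float32))
```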
## Citation If you use this code for your research, please cite our paper. ``` @inproceedings{YiLLR20, title = {Unpaired Portrait Drawing Generation via Asymmetric Cycle Mapping}, author = {Yi, Ran and Liu, Yong-Jin and Lai, Yu-Kun and Rosin, Paul L}, booktitle = {{IEEE} Conference on Computer Vision and Pattern Recognition (CVPR '20)}, pages = {8214--8222}, year = {2020} } ``` ## Acknowledgments Our code is inspired by [pytorch-CycleGAN-and-pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix).
================================================ FILE: requirements.txt ================================================
torch==1.2.0 torchvision==0.4.0 dominate==2.4.0 visdom==0.1.8.9 scipy==1.1.0 numpy==1.16.4 Pillow==6.2.1 opencv-python==4.1.0.25
================================================ FILE: scripts/train.sh ================================================
set -ex python train.py --dataroot ./datasets/portrait_drawing --name formal --model asymmetric_cycle_gan_cls --output_nc 1 --load_size 572 --crop_size 512 --lr 0.000015 --dataset_mode unaligned_mask_stylecls --display_env asymmetric_trainset --gpu_ids 0 --gpu_ids_p 0 --niter 100 --niter_decay 200 --n_epochs 200
================================================ FILE: test.py ================================================
"""General-purpose test script for image-to-image translation. Once you have trained your model with train.py, you can use this script to test the model. It will load a saved model from --checkpoints_dir and save the results to --results_dir. It first creates the model and dataset given the option. It will hard-code some parameters. It then runs inference for --num_test images and saves the results to an HTML file. Example (You need to train models first or download pre-trained models from our website): Test a CycleGAN model (both sides): python test.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan Test a CycleGAN model (one side only): python test.py --dataroot datasets/horse2zebra/testA --name horse2zebra_pretrained --model test --no_dropout The option '--model test' is used for generating CycleGAN results only for one side. This option will automatically set '--dataset_mode single', which only loads the images from one set. On the contrary, using '--model cycle_gan' requires loading and generating results in both directions, which is sometimes unnecessary. The results will be saved at ./results/. Use '--results_dir <directory_path_to_save_result>' to specify the results directory. Test a pix2pix model: python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA See options/base_options.py and options/test_options.py for more test options. See training and test tips at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/tips.md See frequently asked questions at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/qa.md """ import os from options.test_options import TestOptions from data import create_dataset from models import create_model from util.visualizer import save_images from util import html if __name__ == '__main__': opt = TestOptions().parse() # get test options # hard-code some parameters for test opt.num_threads = 0 # test code only supports num_threads = 0 opt.batch_size = 1 # test code only supports batch_size = 1 opt.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed. opt.no_flip = True # no flip; comment this line if results on flipped images are needed. opt.display_id = -1 # no visdom display; the test code saves the results to an HTML file. dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options model = create_model(opt) # create a model given opt.model and other options model.setup(opt) # regular setup: load and print networks; create schedulers # create a website web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.epoch)) # define the website directory #webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.epoch)) webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.epoch), refresh=0, folder=opt.imagefolder) # test with eval mode. This only affects layers like batchnorm and dropout. # For [pix2pix]: we use batchnorm and dropout in the original pix2pix. You can experiment it with and without eval() mode. # For [CycleGAN]: It should not affect CycleGAN as CycleGAN uses instancenorm without dropout. if opt.eval: model.eval() for i, data in enumerate(dataset): if i >= opt.num_test: # only apply our model to opt.num_test images. break model.set_input(data) # unpack data from data loader model.test() # run inference visuals = model.get_current_visuals() # get image results img_path = model.get_image_paths() # get image paths if i % 5 == 0: # print a progress message every 5 images print('processing (%04d)-th image... 
%s' % (i, img_path)) save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize, W=opt.W, H=opt.H) webpage.save() # save the HTML
================================================ FILE: test_seq_style.py ================================================
import os import argparse def opts(): parser = argparse.ArgumentParser() parser.add_argument('-g','--gpu', default = '0', type = str, help = 'gpu ids, -1 for cpu, default is 0.') parser.add_argument('-d','--dataroot', default = './examples', type = str, help = 'the input folder that contains test face photos, default is ./examples') parser.add_argument('-s','--savefolder', default = '3styles', type = str, help = 'the name of save folder that contains result images, default is 3styles') return parser.parse_args() if __name__ == '__main__': opt = opts() exp = 'pretrained' imgsize = 512 epoch = '200' dataroot = opt.dataroot gpu_id = opt.gpu # test 3 styles in one pass savefolder = 'images'+opt.savefolder os.system('python3 test.py --dataroot %s --name %s --model test_3styles --output_nc 1 --no_dropout --num_test 1000 --epoch %s --imagefolder %s --crop_size %d --load_size %d --gpu_ids %s' % (dataroot,exp,epoch,savefolder,imgsize,imgsize,gpu_id)) print('check ./results/%s/test_%s/index%s.html'%(exp,epoch,savefolder[6:])) print('saved to ./results/%s/test_%s/%s'%(exp,epoch,savefolder)) # test 3 styles separately ''' for vec in [[1,0,0],[0,1,0],[0,0,1]]: #1,0,0 for style1; 0,1,0 for style2; 0,0,1 for style3 svec = '%d,%d,%d' % (vec[0],vec[1],vec[2]) savefolder = 'imagesstyle%d-%d-%d'%(vec[0],vec[1],vec[2]) print('results/%s/test_%s/index%s.html'%(exp,epoch,savefolder[6:])) os.system('python3 test.py --dataroot %s --name %s --model test --output_nc 1 --no_dropout --model_suffix _A --num_test 1000 --epoch %s --imagefolder %s --sinput svec --svec %s --crop_size %d --load_size %d --gpu_ids %s' % (dataroot,exp,epoch,savefolder,svec,imgsize,imgsize,gpu_id)) '''
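For reference, each `--svec` value in the commented-out block selects one style as a 3-dim one-hot vector; inside the model it arrives as a constant spatial map whose per-channel mean is what `forward` in `test_model.py` prints. A minimal sketch of that expansion (the 64x64 map size is an illustrative assumption):

```python
# Expand a 3-dim one-hot style vector into a constant spatial map, the layout
# implied by test_model.py printing torch.mean(self.style, (2, 3)).
import torch

svec = [1, 0, 0]  # style1; [0, 1, 0] -> style2, [0, 0, 1] -> style3
style = torch.tensor(svec, dtype=torch.float32).view(1, 3, 1, 1).expand(1, 3, 64, 64)
print(torch.mean(style, (2, 3)))  # tensor([[1., 0., 0.]])
```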
================================================ FILE: train.py ================================================
"""General-purpose training script for image-to-image translation. This script works for various models (with option '--model': e.g., pix2pix, cyclegan, colorization) and different datasets (with option '--dataset_mode': e.g., aligned, unaligned, single, colorization). You need to specify the dataset ('--dataroot'), experiment name ('--name'), and model ('--model'). It first creates the model, dataset, and visualizer given the option. It then does standard network training. During training, it also visualizes/saves the images, prints/saves the loss plot, and saves models. The script supports continuing/resuming training. Use '--continue_train' to resume your previous training. Example: Train a CycleGAN model: python train.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan Train a pix2pix model: python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA See options/base_options.py and options/train_options.py for more training options. See training and test tips at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/tips.md See frequently asked questions at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/qa.md """ import time from options.train_options import TrainOptions from data import create_dataset from models import create_model from util.visualizer import Visualizer import pdb if __name__ == '__main__': start = time.time() opt = TrainOptions().parse() # get training options dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options dataset_size = len(dataset) # get the number of images in the dataset. print('The number of training images = %d' % dataset_size) model = create_model(opt) # create a model given opt.model and other options model.setup(opt) # regular setup: load and print networks; create schedulers visualizer = Visualizer(opt) # create a visualizer that displays/saves images and plots total_iters = 0 # the total number of training iterations for epoch in range(opt.epoch_count, opt.n_epochs + 1): # outer loop for different epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq> epoch_start_time = time.time() # timer for entire epoch iter_data_time = time.time() # timer for data loading per iteration epoch_iter = 0 # the number of training iterations in current epoch, reset to 0 every epoch model.update_process(epoch) for i, data in enumerate(dataset): # inner loop within one epoch iter_start_time = time.time() # timer for computation per iteration if total_iters % opt.print_freq == 0: t_data = iter_start_time - iter_data_time visualizer.reset() total_iters += opt.batch_size epoch_iter += opt.batch_size model.set_input(data) # unpack data from dataset and apply preprocessing model.optimize_parameters() # calculate loss functions, get gradients, update network weights if total_iters % opt.display_freq == 0: # display images on visdom and save images to a HTML file save_result = total_iters % opt.update_html_freq == 0 model.compute_visuals() visualizer.display_current_results(model.get_current_visuals(), epoch, save_result) if total_iters % opt.print_freq == 0: # print training losses and save logging information to the disk losses = model.get_current_losses() t_comp = (time.time() - iter_start_time) / opt.batch_size if opt.model == 'cycle_gan': processes = [model.process] + model.lambda_As visualizer.print_current_losses_process(epoch, epoch_iter, losses, t_comp, t_data, processes) else: visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data) if opt.display_id > 0: visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses) if total_iters % opt.save_latest_freq == 0: # cache our latest model every <save_latest_freq> iterations print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters)) save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest' model.save_networks(save_suffix) iter_data_time = time.time() if epoch % opt.save_epoch_freq == 0: # cache our model every <save_epoch_freq> epochs print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters)) model.save_networks('latest') model.save_networks(epoch) print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.n_epochs, time.time() - epoch_start_time)) model.update_learning_rate() # update learning rates at the end of every epoch. 
print('Total Time Taken: %d sec' % (time.time() - start)) ================================================ FILE: util/__init__.py ================================================ """This package includes a miscellaneous collection of useful helper functions.""" ================================================ FILE: util/get_data.py ================================================ from __future__ import print_function import os import tarfile import requests from warnings import warn from zipfile import ZipFile from bs4 import BeautifulSoup from os.path import abspath, isdir, join, basename class GetData(object): """A Python script for downloading CycleGAN or pix2pix datasets. Parameters: technique (str) -- One of: 'cyclegan' or 'pix2pix'. verbose (bool) -- If True, print additional information. Examples: >>> from util.get_data import GetData >>> gd = GetData(technique='cyclegan') >>> new_data_path = gd.get(save_path='./datasets') # options will be displayed. Alternatively, You can use bash scripts: 'scripts/download_pix2pix_model.sh' and 'scripts/download_cyclegan_model.sh'. """ def __init__(self, technique='cyclegan', verbose=True): url_dict = { 'pix2pix': 'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/', 'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets' } self.url = url_dict.get(technique.lower()) self._verbose = verbose def _print(self, text): if self._verbose: print(text) @staticmethod def _get_options(r): soup = BeautifulSoup(r.text, 'lxml') options = [h.text for h in soup.find_all('a', href=True) if h.text.endswith(('.zip', 'tar.gz'))] return options def _present_options(self): r = requests.get(self.url) options = self._get_options(r) print('Options:\n') for i, o in enumerate(options): print("{0}: {1}".format(i, o)) choice = input("\nPlease enter the number of the " "dataset above you wish to download:") return options[int(choice)] def _download_data(self, dataset_url, save_path): if not isdir(save_path): os.makedirs(save_path) base = basename(dataset_url) temp_save_path = join(save_path, base) with open(temp_save_path, "wb") as f: r = requests.get(dataset_url) f.write(r.content) if base.endswith('.tar.gz'): obj = tarfile.open(temp_save_path) elif base.endswith('.zip'): obj = ZipFile(temp_save_path, 'r') else: raise ValueError("Unknown File Type: {0}.".format(base)) self._print("Unpacking Data...") obj.extractall(save_path) obj.close() os.remove(temp_save_path) def get(self, save_path, dataset=None): """ Download a dataset. Parameters: save_path (str) -- A directory to save the data to. dataset (str) -- (optional). A specific dataset to download. Note: this must include the file extension. If None, options will be presented for you to choose from. Returns: save_path_full (str) -- the absolute path to the downloaded data. """ if dataset is None: selected_dataset = self._present_options() else: selected_dataset = dataset save_path_full = join(save_path, selected_dataset.split('.')[0]) if isdir(save_path_full): warn("\n'{0}' already exists. Voiding Download.".format( save_path_full)) else: self._print('Downloading Data...') url = "{0}/{1}".format(self.url, selected_dataset) self._download_data(url, save_path=save_path) return abspath(save_path_full) ================================================ FILE: util/html.py ================================================ import dominate from dominate.tags import meta, h3, table, tr, td, p, a, img, br import os class HTML: """This HTML class allows us to save images and write texts into a single HTML file. 
It consists of functions such as <add_header> (add a text header to the HTML file), <add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk). It is based on 'dominate', a Python library for creating and manipulating HTML documents using a DOM API. """ def __init__(self, web_dir, title, refresh=0, folder='images'): """Initialize the HTML classes Parameters: web_dir (str) -- a directory that stores the webpage. HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir>/<folder>/ title (str) -- the webpage name refresh (int) -- how often the website refreshes itself; if 0, no refreshing folder (str) -- the subfolder under <web_dir> where images are saved """ self.title = title self.web_dir = web_dir self.img_dir = os.path.join(self.web_dir, folder) self.folder = folder if not os.path.exists(self.web_dir): os.makedirs(self.web_dir) if not os.path.exists(self.img_dir): os.makedirs(self.img_dir) self.doc = dominate.document(title=title) if refresh > 0: with self.doc.head: meta(http_equiv="refresh", content=str(refresh)) def get_image_dir(self): """Return the directory that stores images""" return self.img_dir def add_header(self, text): """Insert a header to the HTML file Parameters: text (str) -- the header text """ with self.doc: h3(text) def add_images(self, ims, txts, links, width=400): """add images to the HTML file Parameters: ims (str list) -- a list of image paths txts (str list) -- a list of image names shown on the website links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page """ self.t = table(border=1, style="table-layout: fixed;") # Insert a table self.doc.add(self.t) with self.t: with tr(): for im, txt, link in zip(ims, txts, links): with td(style="word-wrap: break-word;", halign="center", valign="top"): with p(): with a(href=os.path.join(self.folder, link)): # link into the same folder the images were saved to #img(style="width:%dpx" % width, src=os.path.join('images', im)) img(style="width:%dpx" % width, src=os.path.join(self.folder, im)) br() p(txt) def save(self): """save the current content to the HTML file""" #html_file = '%s/index.html' % self.web_dir name = self.folder[6:] if self.folder[:6] == 'images' else self.folder html_file = '%s/index%s.html' % (self.web_dir, name) f = open(html_file, 'wt') f.write(self.doc.render()) f.close() if __name__ == '__main__': # we show an example usage here. html = HTML('web/', 'test_html') html.add_header('hello world') ims, txts, links = [], [], [] for n in range(4): ims.append('image_%d.png' % n) txts.append('text_%d' % n) links.append('image_%d.png' % n) html.add_images(ims, txts, links) html.save()
================================================ FILE: util/image_pool.py ================================================
import random import torch class ImagePool(): """This class implements an image buffer that stores previously generated images. This buffer enables us to update discriminators using a history of generated images rather than the ones produced by the latest generators. """ def __init__(self, pool_size): """Initialize the ImagePool class Parameters: pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created """ self.pool_size = pool_size if self.pool_size > 0: # create an empty pool self.num_imgs = 0 self.images = [] def query(self, images): """Return an image from the pool. Parameters: images: the latest generated images from the generator Returns images from the buffer. With probability 0.5 the buffer returns the input images; with probability 0.5 it returns images previously stored in the buffer, and inserts the current images into the buffer. 
""" if self.pool_size == 0: # if the buffer size is 0, do nothing return images return_images = [] for image in images: image = torch.unsqueeze(image.data, 0) if self.num_imgs < self.pool_size: # if the buffer is not full; keep inserting current images to the buffer self.num_imgs = self.num_imgs + 1 self.images.append(image) return_images.append(image) else: p = random.uniform(0, 1) if p > 0.5: # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer random_id = random.randint(0, self.pool_size - 1) # randint is inclusive tmp = self.images[random_id].clone() self.images[random_id] = image return_images.append(tmp) else: # by another 50% chance, the buffer will return the current image return_images.append(image) return_images = torch.cat(return_images, 0) # collect all the images and return return return_images ================================================ FILE: util/util.py ================================================ """This module contains simple helper functions """ from __future__ import print_function import torch import numpy as np from PIL import Image import os import pdb from scipy.io import savemat def tensor2im(input_image, imtype=np.uint8): """"Converts a Tensor array into a numpy image array. Parameters: input_image (tensor) -- the input image tensor array imtype (type) -- the desired type of the converted numpy array """ if not isinstance(input_image, np.ndarray): if isinstance(input_image, torch.Tensor): # get the data from a variable image_tensor = input_image.data else: return input_image #pdb.set_trace() image_numpy = image_tensor[0].cpu().float().numpy() # convert it into a numpy array if image_numpy.shape[0] == 1: # grayscale to RGB image_numpy = np.tile(image_numpy, (3, 1, 1)) elif image_numpy.shape[0] == 2: image_numpy = np.concatenate([image_numpy, image_numpy[1:2,:,:]], 0) image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 # post-processing: tranpose and scaling else: # if it is a numpy array, do nothing image_numpy = input_image return image_numpy.astype(imtype) #return np.round(image_numpy).astype(imtype),image_numpy def diagnose_network(net, name='network'): """Calculate and print the mean of average absolute(gradients) Parameters: net (torch network) -- Torch network name (str) -- the name of the network """ mean = 0.0 count = 0 for param in net.parameters(): if param.grad is not None: mean += torch.mean(torch.abs(param.grad.data)) count += 1 if count > 0: mean = mean / count print(name) print(mean) def save_image(image_numpy, image_path): """Save a numpy image to the disk Parameters: image_numpy (numpy array) -- input numpy array image_path (str) -- the path of the image """ image_pil = Image.fromarray(image_numpy) #pdb.set_trace() image_pil.save(image_path) def print_numpy(x, val=True, shp=False): """Print the mean, min, max, median, std, and size of a numpy array Parameters: val (bool) -- if print the values of the numpy array shp (bool) -- if print the shape of the numpy array """ x = x.astype(np.float64) if shp: print('shape,', x.shape) if val: x = x.flatten() print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) def mkdirs(paths): """create empty directories if they don't exist Parameters: paths (str list) -- a list of directory paths """ if isinstance(paths, list) and not isinstance(paths, str): for path in paths: mkdir(path) else: mkdir(paths) def mkdir(path): """create a single empty 
================================================ FILE: util/util.py ================================================
"""This module contains simple helper functions """ from __future__ import print_function import torch import numpy as np from PIL import Image import os import pdb from scipy.io import savemat def tensor2im(input_image, imtype=np.uint8): """Converts a Tensor array into a numpy image array. Parameters: input_image (tensor) -- the input image tensor array imtype (type) -- the desired type of the converted numpy array """ if not isinstance(input_image, np.ndarray): if isinstance(input_image, torch.Tensor): # get the data from a variable image_tensor = input_image.data else: return input_image #pdb.set_trace() image_numpy = image_tensor[0].cpu().float().numpy() # convert it into a numpy array if image_numpy.shape[0] == 1: # grayscale to RGB image_numpy = np.tile(image_numpy, (3, 1, 1)) elif image_numpy.shape[0] == 2: image_numpy = np.concatenate([image_numpy, image_numpy[1:2,:,:]], 0) image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 # post-processing: transpose and scaling else: # if it is a numpy array, do nothing image_numpy = input_image return image_numpy.astype(imtype) #return np.round(image_numpy).astype(imtype),image_numpy def diagnose_network(net, name='network'): """Calculate and print the mean of average absolute(gradients) Parameters: net (torch network) -- Torch network name (str) -- the name of the network """ mean = 0.0 count = 0 for param in net.parameters(): if param.grad is not None: mean += torch.mean(torch.abs(param.grad.data)) count += 1 if count > 0: mean = mean / count print(name) print(mean) def save_image(image_numpy, image_path): """Save a numpy image to the disk Parameters: image_numpy (numpy array) -- input numpy array image_path (str) -- the path of the image """ image_pil = Image.fromarray(image_numpy) #pdb.set_trace() image_pil.save(image_path) def print_numpy(x, val=True, shp=False): """Print the mean, min, max, median, std, and size of a numpy array Parameters: val (bool) -- if print the values of the numpy array shp (bool) -- if print the shape of the numpy array """ x = x.astype(np.float64) if shp: print('shape,', x.shape) if val: x = x.flatten() print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) def mkdirs(paths): """create empty directories if they don't exist Parameters: paths (str list) -- a list of directory paths """ if isinstance(paths, list) and not isinstance(paths, str): for path in paths: mkdir(path) else: mkdir(paths) def mkdir(path): """create a single empty directory if it doesn't exist Parameters: path (str) -- a single directory path """ if not os.path.exists(path): os.makedirs(path) def normalize_tensor(in_feat,eps=1e-10): norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True)) return in_feat/(norm_factor+eps)
================================================ FILE: util/visualizer.py ================================================
import numpy as np import os import sys import ntpath import time from . import util, html from subprocess import Popen, PIPE from PIL import Image if sys.version_info[0] == 2: VisdomExceptionBase = Exception else: VisdomExceptionBase = ConnectionError def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256, W=None, H=None): """Save images to the disk. Parameters: webpage (the HTML class) -- the HTML webpage class that stores these images (see html.py for more details) visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy)) pairs image_path (str) -- the string is used to create image paths aspect_ratio (float) -- the aspect ratio of saved images width (int) -- the images will be resized to width x width This function will save images stored in 'visuals' to the HTML file specified by 'webpage'. """ image_dir = webpage.get_image_dir() short_path = ntpath.basename(image_path[0]) name = os.path.splitext(short_path)[0] webpage.add_header(name) ims, txts, links = [], [], [] for label, im_data in visuals.items(): ## tensor to im im = util.tensor2im(im_data) image_name = '%s_%s.png' % (name, label) save_path = os.path.join(image_dir, image_name) h, w, _ = im.shape if W is not None and H is not None and (W != w or H != h): im = np.array(Image.fromarray(im).resize((W, H), Image.BICUBIC)) else: if aspect_ratio > 1.0: im = np.array(Image.fromarray(im).resize((int(w * aspect_ratio), h), Image.BICUBIC)) if aspect_ratio < 1.0: im = np.array(Image.fromarray(im).resize((w, int(h / aspect_ratio)), Image.BICUBIC)) util.save_image(im, save_path) ims.append(image_name) txts.append(label) links.append(image_name) webpage.add_images(ims, txts, links, width=width) class Visualizer(): """This class includes several functions that can display/save images and print/save logging information. It uses a Python library 'visdom' for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images. """ def __init__(self, opt): """Initialize the Visualizer class Parameters: opt -- stores all the experiment flags; needs to be a subclass of BaseOptions Step 1: Cache the training/test options Step 2: connect to a visdom server Step 3: create an HTML object for saving HTML files Step 4: create a logging file to store training losses """ self.opt = opt # cache the option self.display_id = opt.display_id self.use_html = opt.isTrain and not opt.no_html self.win_size = opt.display_winsize self.name = opt.name self.port = opt.display_port self.saved = False if self.display_id > 0: # connect to a visdom server given <display_port> and <server> import visdom self.ncols = opt.display_ncols self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env) if not self.vis.check_connection(): self.create_visdom_connections() if self.use_html: # create an HTML object at <checkpoints_dir>/web/; images will be saved under <checkpoints_dir>/web/images/ self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') self.img_dir = os.path.join(self.web_dir, 'images') print('create web directory %s...' 
% self.web_dir) util.mkdirs([self.web_dir, self.img_dir]) # create a logging file to store training losses self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') with open(self.log_name, "a") as log_file: now = time.strftime("%c") log_file.write('================ Training Loss (%s) ================\n' % now) def reset(self): """Reset the self.saved status""" self.saved = False def create_visdom_connections(self): """If the program could not connect to Visdom server, this function will start a new server at port <self.port> """ cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port print('\n\nCould not connect to Visdom server. \n Trying to start a server....') print('Command: %s' % cmd) Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE) def display_current_results(self, visuals, epoch, save_result): """Display current results on visdom; save current results to an HTML file. Parameters: visuals (OrderedDict) -- dictionary of images to display or save epoch (int) -- the current epoch save_result (bool) -- whether to save the current results to an HTML file """ if self.display_id > 0: # show images in the browser using visdom ncols = self.ncols if ncols > 0: # show all the images in one visdom panel ncols = min(ncols, len(visuals)) h, w = next(iter(visuals.values())).shape[:2] table_css = """<style> table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center} table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black} </style>""" % (w, h) # create a table css # create a table of images. title = self.name label_html = '' label_html_row = '' images = [] idx = 0 for label, image in visuals.items(): image_numpy = util.tensor2im(image) label_html_row += '<td>%s</td>' % label images.append(image_numpy.transpose([2, 0, 1])) idx += 1 if idx % ncols == 0: label_html += '<tr>%s</tr>' % label_html_row label_html_row = '' white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255 while idx % ncols != 0: images.append(white_image) label_html_row += '<td></td>' idx += 1 if label_html_row != '': label_html += '<tr>%s</tr>' % label_html_row try: self.vis.images(images, nrow=ncols, win=self.display_id + 1, padding=2, opts=dict(title=title + ' images')) label_html = '<table>%s</table>
' % label_html self.vis.text(table_css + label_html, win=self.display_id + 2, opts=dict(title=title + ' labels')) except VisdomExceptionBase: self.create_visdom_connections() else: # show each image in a separate visdom panel; idx = 1 try: for label, image in visuals.items(): image_numpy = util.tensor2im(image) self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label), win=self.display_id + idx) idx += 1 except VisdomExceptionBase: self.create_visdom_connections() if self.use_html and (save_result or not self.saved): # save images to an HTML file if they haven't been saved. self.saved = True # save images to the disk for label, image in visuals.items(): image_numpy = util.tensor2im(image) img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label)) util.save_image(image_numpy, img_path) # update website webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1) for n in range(epoch, 0, -1): webpage.add_header('epoch [%d]' % n) ims, txts, links = [], [], [] for label, image in visuals.items(): image_numpy = util.tensor2im(image) img_path = 'epoch%.3d_%s.png' % (n, label) ims.append(img_path) txts.append(label) links.append(img_path) webpage.add_images(ims, txts, links, width=self.win_size) webpage.save() def plot_current_losses(self, epoch, counter_ratio, losses): """display the current losses on visdom display: dictionary of error labels and values Parameters: epoch (int) -- current epoch counter_ratio (float) -- progress (percentage) in the current epoch, between 0 and 1 losses (OrderedDict) -- training losses stored in the format of (name, float) pairs """ if not hasattr(self, 'plot_data'): self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())} self.plot_data['X'].append(epoch + counter_ratio) self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']]) #X = np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1) #Y = np.array(self.plot_data['Y']) try: self.vis.line( X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1), Y=np.array(self.plot_data['Y']), opts={ 'title': self.name + ' loss over time', 'legend': self.plot_data['legend'], 'xlabel': 'epoch', 'ylabel': 'loss'}, win=self.display_id) except VisdomExceptionBase: self.create_visdom_connections() # losses: same format as |losses| of plot_current_losses def print_current_losses(self, epoch, iters, losses, t_comp, t_data): """print current losses on console; also save the losses to the disk Parameters: epoch (int) -- current epoch iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch) losses (OrderedDict) -- training losses stored in the format of (name, float) pairs t_comp (float) -- computational time per data point (normalized by batch_size) t_data (float) -- data loading time per data point (normalized by batch_size) """ message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data) for k, v in losses.items(): message += '%s: %.3f ' % (k, v) print(message) # print the message with open(self.log_name, "a") as log_file: log_file.write('%s\n' % message) # save the message # losses: same format as |losses| of plot_current_losses def print_current_losses_process(self, epoch, iters, losses, t_comp, t_data, processes): """print current losses on console; also save the losses to the disk Parameters: epoch (int) -- current epoch iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch) losses 
(OrderedDict) -- training losses stored in the format of (name, float) pairs t_comp (float) -- computational time per data point (normalized by batch_size) t_data (float) -- data loading time per data point (normalized by batch_size) """ message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data) message += '[process: %.3f, non_trunc: %.3f, trunc: %.3f] ' % (processes[0], processes[1], processes[2]) for k, v in losses.items(): message += '%s: %.3f ' % (k, v) print(message) # print the message with open(self.log_name, "a") as log_file: log_file.write('%s\n' % message) # save the message