master b67591912a3e cached
35 files
183.1 KB
47.7k tokens
224 symbols
1 requests
Download .txt
Repository: yiranran/Unpaired-Portrait-Drawing
Branch: master
Commit: b67591912a3e
Files: 35
Total size: 183.1 KB

Directory structure:
gitextract_uyxvzmrx/

├── .gitignore
├── data/
│   ├── __init__.py
│   ├── base_dataset.py
│   ├── image_folder.py
│   ├── single_dataset.py
│   └── unaligned_mask_stylecls_dataset.py
├── models/
│   ├── __init__.py
│   ├── asymmetric_cycle_gan_cls_model.py
│   ├── base_model.py
│   ├── dist_model.py
│   ├── networks.py
│   ├── networks_basic.py
│   ├── pretrained_networks.py
│   ├── test_3styles_model.py
│   └── test_model.py
├── options/
│   ├── __init__.py
│   ├── base_options.py
│   ├── test_options.py
│   └── train_options.py
├── portrait_drawing_resources.md
├── preprocess/
│   ├── example/
│   │   └── ia_selfie_10515_facial5point.mat
│   ├── face_align_512.m
│   └── readme.md
├── readme.md
├── requirements.txt
├── scripts/
│   └── train.sh
├── test.py
├── test_seq_style.py
├── train.py
└── util/
    ├── __init__.py
    ├── get_data.py
    ├── html.py
    ├── image_pool.py
    ├── util.py
    └── visualizer.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
.DS_Store
debug*
datasets/
checkpoints/
style_features/
results/
build/
dist/
torch.egg-info/
*/**/__pycache__
torch/version.py
torch/csrc/generic/TensorMethods.cpp
torch/lib/*.so*
torch/lib/*.dylib*
torch/lib/*.h
torch/lib/build
torch/lib/tmp_install
torch/lib/include
torch/lib/torch_shm_manager
torch/csrc/cudnn/cuDNN.cpp
torch/csrc/nn/THNN.cwrap
torch/csrc/nn/THNN.cpp
torch/csrc/nn/THCUNN.cwrap
torch/csrc/nn/THCUNN.cpp
torch/csrc/nn/THNN_generic.cwrap
torch/csrc/nn/THNN_generic.cpp
torch/csrc/nn/THNN_generic.h
docs/src/**/*
test/data/legacy_modules.t7
test/data/gpu_tensors.pt
test/htmlcov
test/.coverage
*/*.pyc
*/**/*.pyc
*/**/**/*.pyc
*/**/**/**/*.pyc
*/**/**/**/**/*.pyc
*/*.so*
*/**/*.so*
*/**/*.dylib*
test/data/legacy_serialized.pt
*~
.idea


================================================
FILE: data/__init__.py
================================================
"""This package includes all the modules related to data loading and preprocessing

 To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
 You need to implement four functions:
    -- <__init__>:                      initialize the class, first call BaseDataset.__init__(self, opt).
    -- <__len__>:                       return the size of dataset.
    -- <__getitem__>:                   get a data point from data loader.
    -- <modify_commandline_options>:    (optionally) add dataset-specific options and set default options.

Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
See our template dataset class 'template_dataset.py' for more details.
"""
import importlib
import torch.utils.data
from data.base_dataset import BaseDataset


def find_dataset_using_name(dataset_name):
    """Import the module "data/[dataset_name]_dataset.py" and return its dataset class.

    The module is searched for a class whose name, compared
    case-insensitively, equals ``dataset_name`` with underscores removed
    plus the suffix "dataset", and which is a subclass of BaseDataset.
    If several attributes match, the last one found is returned.

    Parameters:
        dataset_name (str) -- dataset identifier, e.g. 'single'

    Returns:
        the matching dataset class (not an instance).

    Raises:
        NotImplementedError -- if the module defines no matching subclass.
    """
    dataset_filename = "data." + dataset_name + "_dataset"
    datasetlib = importlib.import_module(dataset_filename)

    dataset = None
    target_dataset_name = dataset_name.replace('_', '') + 'dataset'
    wanted = target_dataset_name.lower()
    for name, cls in datasetlib.__dict__.items():
        # Only real classes may be passed to issubclass(); a function or
        # constant that happens to carry the target name is skipped instead
        # of raising TypeError.
        if name.lower() == wanted \
           and isinstance(cls, type) \
           and issubclass(cls, BaseDataset):
            dataset = cls

    if dataset is None:
        raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))

    return dataset


def get_option_setter(dataset_name):
    """Resolve the dataset class for `dataset_name` and return its <modify_commandline_options> static method."""
    return find_dataset_using_name(dataset_name).modify_commandline_options


def create_dataset(opt):
    """Build the data loading object described by the options.

    A CustomDatasetDataLoader is constructed from `opt` and the result of
    its load_data() call is returned.
    This is the main interface between this package and 'train.py'/'test.py'

    Example:
        >>> from data import create_dataset
        >>> dataset = create_dataset(opt)
    """
    return CustomDatasetDataLoader(opt).load_data()


class CustomDatasetDataLoader():
    """Pairs a dataset instance with a torch DataLoader and caps iteration at opt.max_dataset_size."""

    def __init__(self, opt):
        """Create the dataset and its loader.

        The dataset class is looked up from opt.dataset_mode and instantiated
        with `opt`; the DataLoader uses opt.batch_size, shuffles unless
        opt.serial_batches is set, and runs with opt.num_threads workers.
        """
        self.opt = opt
        self.dataset = find_dataset_using_name(opt.dataset_mode)(opt)
        print("dataset [%s] was created" % type(self.dataset).__name__)
        loader_settings = dict(
            batch_size=opt.batch_size,
            shuffle=not opt.serial_batches,
            num_workers=int(opt.num_threads),
        )
        self.dataloader = torch.utils.data.DataLoader(self.dataset, **loader_settings)

    def load_data(self):
        """Return this wrapper itself; it is the iterable that callers loop over."""
        return self

    def __len__(self):
        """Return the dataset size, limited by opt.max_dataset_size."""
        total = len(self.dataset)
        cap = self.opt.max_dataset_size
        return min(total, cap)

    def __iter__(self):
        """Yield batches until the number of samples already served reaches opt.max_dataset_size."""
        for batch_index, batch in enumerate(self.dataloader):
            served = batch_index * self.opt.batch_size
            if served >= self.opt.max_dataset_size:
                return
            yield batch


================================================
FILE: data/base_dataset.py
================================================
"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets.

It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
"""
import random
import numpy as np
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
from abc import ABCMeta, abstractmethod


class BaseDataset(data.Dataset):
    """This class is an abstract base class (ABC) for datasets.

    To create a subclass, you need to implement the following four functions:
    -- <__init__>:                      initialize the class, first call BaseDataset.__init__(self, opt).
    -- <__len__>:                       return the size of dataset.
    -- <__getitem__>:                   get a data point.
    -- <modify_commandline_options>:    (optionally) add dataset-specific options and set default options.
    """

    # The docstring must be the first statement of the class body to become
    # BaseDataset.__doc__, hence it precedes this assignment.
    # NOTE: `__metaclass__` is the Python 2 spelling; Python 3 ignores this
    # attribute, so the @abstractmethod markers below are informational and
    # are not enforced at instantiation time. Kept unchanged so existing
    # subclasses keep working.
    __metaclass__ = ABCMeta

    def __init__(self, opt):
        """Initialize the class; save the options in the class

        Parameters:
            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        self.opt = opt
        self.root = opt.dataroot

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add new dataset-specific options, and rewrite default values for existing options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser (the base implementation returns it untouched).
        """
        return parser

    @abstractmethod
    def __len__(self):
        """Return the total number of images in the dataset (0 in the base class)."""
        return 0

    @abstractmethod
    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index -- an integer for data indexing

        Returns:
            a dictionary of data with their names. It usually contains the data itself and its metadata information.
            The base implementation returns None.
        """
        pass


def get_params(opt, size):
    """Draw the random crop offset and flip decision for one image.

    Parameters:
        opt  -- options object; reads opt.preprocess, opt.load_size and opt.crop_size
        size -- (width, height) of the source image

    Returns:
        a dict {'crop_pos': (x, y), 'flip': bool}. The offsets are drawn
        uniformly from the range left over after removing opt.crop_size from
        the image size implied by opt.preprocess (never below 0).
    """
    src_w, src_h = size
    mode = opt.preprocess
    if mode == 'resize_and_crop':
        scaled_w = scaled_h = opt.load_size
    elif mode == 'scale_width_and_crop':
        scaled_w = opt.load_size
        scaled_h = opt.load_size * src_h // src_w
    else:
        scaled_w, scaled_h = src_w, src_h

    # Random draws happen in the fixed order x, y, flip.
    crop_x = random.randint(0, max(0, scaled_w - opt.crop_size))
    crop_y = random.randint(0, max(0, scaled_h - opt.crop_size))
    do_flip = random.random() > 0.5

    return {'crop_pos': (crop_x, crop_y), 'flip': do_flip}


def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
    """Assemble the preprocessing pipeline for an image.

    Parameters:
        opt       -- options object; reads preprocess, load_size, crop_size and no_flip
        params    -- dict with 'crop_pos' and 'flip' (as produced by get_params) to make
                     crop/flip deterministic; None selects the random torchvision transforms
        grayscale -- if true, convert to a single channel first
        method    -- interpolation used by the resizing steps
        convert   -- if true, append ToTensor and a Normalize with mean 0.5 / std 0.5 per channel

    Returns:
        a transforms.Compose of the selected steps, in the order:
        grayscale, resize/scale-width, crop, size adjustment to a multiple of 4
        (only for preprocess == 'none'), flip, tensor conversion.
    """
    steps = []
    add = steps.append

    if grayscale:
        add(transforms.Grayscale(1))

    if 'resize' in opt.preprocess:
        add(transforms.Resize([opt.load_size, opt.load_size], method))
    elif 'scale_width' in opt.preprocess:
        add(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method)))

    if 'crop' in opt.preprocess:
        if params is not None:
            add(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))
        else:
            add(transforms.RandomCrop(opt.crop_size))

    if opt.preprocess == 'none':
        add(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))

    if not opt.no_flip:
        if params is None:
            add(transforms.RandomHorizontalFlip())
        elif params['flip']:
            add(transforms.Lambda(lambda img: __flip(img, params['flip'])))

    if convert:
        add(transforms.ToTensor())
        # One 0.5 entry per channel: a single channel for grayscale, three otherwise.
        if grayscale:
            add(transforms.Normalize((0.5,), (0.5,)))
        else:
            add(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
    return transforms.Compose(steps)

def get_transform_mask(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
    """Assemble the preprocessing pipeline for a mask image.

    Builds the same geometric steps as get_transform (grayscale, resize or
    scale-width, crop, size adjustment to a multiple of 4, flip) but, when
    `convert` is true, ends with ToTensor only -- no Normalize is appended.

    Parameters:
        opt       -- options object; reads preprocess, load_size, crop_size and no_flip
        params    -- dict with 'crop_pos' and 'flip' for deterministic crop/flip, or None
                     to use the random torchvision transforms
        grayscale -- if true, convert to a single channel first
        method    -- interpolation used by the resizing steps
        convert   -- if true, append ToTensor

    Returns:
        a transforms.Compose of the selected steps.
    """
    pipeline = []
    preprocess = opt.preprocess

    if grayscale:
        pipeline.append(transforms.Grayscale(1))

    if 'resize' in preprocess:
        target = [opt.load_size, opt.load_size]
        pipeline.append(transforms.Resize(target, method))
    elif 'scale_width' in preprocess:
        pipeline.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method)))

    if 'crop' in preprocess:
        cropper = (
            transforms.RandomCrop(opt.crop_size)
            if params is None
            else transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size))
        )
        pipeline.append(cropper)

    if preprocess == 'none':
        pipeline.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))

    if not opt.no_flip:
        if params is None:
            pipeline.append(transforms.RandomHorizontalFlip())
        elif params['flip']:
            pipeline.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))

    if convert:
        pipeline.append(transforms.ToTensor())
    return transforms.Compose(pipeline)

def __make_power_2(img, base, method=Image.BICUBIC):
    """Resize `img` so that width and height are each rounded to a multiple of `base`.

    The image is returned unchanged when both sides already are multiples
    of `base`; otherwise the size warning helper is called and a resized
    copy is returned.
    """
    orig_w, orig_h = img.size
    new_h = int(round(orig_h / base) * base)
    new_w = int(round(orig_w / base) * base)
    if (new_h, new_w) != (orig_h, orig_w):
        __print_size_warning(orig_w, orig_h, new_w, new_h)
        return img.resize((new_w, new_h), method)
    return img


def __scale_width(img, target_width, method=Image.BICUBIC):
    """Resize `img` to `target_width`, scaling the height by the same factor.

    An image whose width already equals `target_width` is returned as is.
    """
    cur_w, cur_h = img.size
    if cur_w != target_width:
        scaled_h = int(target_width * cur_h / cur_w)
        img = img.resize((target_width, scaled_h), method)
    return img


def __crop(img, pos, size):
    """Cut a `size` x `size` square out of `img` with its top-left corner at `pos`.

    Images that are no larger than `size` in both dimensions are returned
    unchanged.
    """
    width, height = img.size
    left, top = pos
    if width <= size and height <= size:
        return img
    return img.crop((left, top, left + size, top + size))


def __flip(img, flip):
    """Return `img` mirrored left-to-right when `flip` is true, otherwise `img` itself."""
    return img.transpose(Image.FLIP_LEFT_RIGHT) if flip else img


def __print_size_warning(ow, oh, w, h):
    """Report an automatic image size adjustment; only the first call prints.

    A `has_printed` attribute stored on the function itself remembers that
    the message was already shown.
    """
    if hasattr(__print_size_warning, 'has_printed'):
        return
    message = ("The image size needs to be a multiple of 4. "
               "The loaded image size was (%d, %d), so it was adjusted to "
               "(%d, %d). This adjustment will be done to all images "
               "whose sizes are not multiples of 4" % (ow, oh, w, h))
    print(message)
    __print_size_warning.has_printed = True


================================================
FILE: data/image_folder.py
================================================
"""A modified image folder class

We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
so that this class can load images from both current directory and its subdirectories.
"""

import torch.utils.data as data

from PIL import Image
import os
import os.path

# File name suffixes accepted as images; only the exact lower-case and
# upper-case spellings listed here are recognised.
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    """Return True when `filename` ends with one of IMG_EXTENSIONS."""
    return filename.endswith(tuple(IMG_EXTENSIONS))


def make_dataset(dir, max_dataset_size=float("inf")):
    """Collect the paths of all image files below `dir`, including subdirectories.

    Parameters:
        dir              -- root directory to scan; must exist
        max_dataset_size -- upper bound on the number of returned paths
                            (default: no limit)

    Returns:
        a list of image paths, ordered by directory and then by file name,
        truncated to at most `max_dataset_size` entries.

    Raises:
        AssertionError -- if `dir` is not a directory.
    """
    assert os.path.isdir(dir), '%s is not a valid directory' % dir

    images = []
    for root, _, fnames in sorted(os.walk(dir)):
        # os.walk reports file names in arbitrary, filesystem-dependent order.
        # Sorting them makes both the listing and the subset kept by the
        # truncation below reproducible across machines and runs.
        for fname in sorted(fnames):
            if is_image_file(fname):
                images.append(os.path.join(root, fname))

    # Slice only when a limit applies; int() keeps the slice valid when the
    # limit is given as a float (the default is float("inf")).
    if max_dataset_size < len(images):
        images = images[:int(max_dataset_size)]
    return images


def default_loader(path):
    """Open the image stored at `path` and return it converted to RGB mode."""
    image = Image.open(path)
    return image.convert('RGB')


class ImageFolder(data.Dataset):
    """Dataset over every image file found under a root directory and its subdirectories."""

    def __init__(self, root, transform=None, return_paths=False,
                 loader=default_loader):
        """Scan `root` for images and store the loading configuration.

        Parameters:
            root         -- directory handed to make_dataset
            transform    -- optional callable applied to each loaded image
            return_paths -- if true, __getitem__ returns (image, path)
            loader       -- callable turning a path into an image

        Raises:
            RuntimeError -- if no image file is found under `root`.
        """
        found = make_dataset(root)
        if not found:
            message = ("Found 0 images in: " + root + "\n"
                       "Supported image extensions are: " +
                       ",".join(IMG_EXTENSIONS))
            raise RuntimeError(message)

        self.root = root
        self.imgs = found
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        """Load the image at `index`, apply the transform if any, and return it (with its path when return_paths is set)."""
        path = self.imgs[index]
        image = self.loader(path)
        if self.transform is not None:
            image = self.transform(image)
        return (image, path) if self.return_paths else image

    def __len__(self):
        """Return the number of images found at construction time."""
        return len(self.imgs)


================================================
FILE: data/single_dataset.py
================================================
from data.base_dataset import BaseDataset, get_transform, get_params, get_transform_mask
from data.image_folder import make_dataset
from PIL import Image
import torch
import os


class SingleDataset(BaseDataset):
    """This dataset class can load a set of images specified by the path --dataroot /path/to/data.

    It can be used for generating CycleGAN results only for one side with the model option '-model test'.
    When opt.style_control is set, every item additionally carries a style map 'B_style'.
    """

    def __init__(self, opt):
        """Initialize this dataset class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions

        If opt.dataroot exists on disk it is scanned for images; otherwise it
        is treated as the name of a list file
        'datasets/list/<phase>A/<dataroot>.txt' holding one image path per line.
        """
        BaseDataset.__init__(self, opt)
        if os.path.exists(opt.dataroot):
            self.A_paths = sorted(make_dataset(opt.dataroot, opt.max_dataset_size))
        else:
            imglistA = 'datasets/list/%s/%s.txt' % (opt.phase+'A', opt.dataroot)
            # Context manager so the list file is closed deterministically.
            with open(imglistA, 'r') as list_file:
                self.A_paths = sorted(list_file.read().splitlines())
        self.input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc

    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index -- an integer for data indexing

        Returns a dictionary that contains A and A_paths
            A (tensor)       -- an image in one domain
            A_paths (str)    -- the path of the image
            B_style (tensor) -- style map, only present when opt.style_control is set
        """
        A_path = self.A_paths[index]
        A_img = Image.open(A_path).convert('RGB')
        # The size of the most recently loaded image is recorded on the
        # shared options object.
        self.opt.W, self.opt.H = A_img.size
        transform_params_A = get_params(self.opt, A_img.size)
        A = get_transform(self.opt, transform_params_A, grayscale=(self.input_nc == 1))(A_img)
        item = {'A': A, 'A_paths': A_path}

        if self.opt.style_control:
            item['B_style'] = self._make_style_map()

        return item

    def _make_style_map(self):
        """Build the (3, 128, 128) style tensor selected by opt.sinput.

        Raises:
            ValueError -- if opt.sinput is not 'sind', 'svec' or 'simg'.
        """
        sinput = self.opt.sinput
        if sinput == 'sind':
            # One-hot vector with the entry at index opt.sind set.
            B_style = torch.Tensor([0., 0., 0.])
            B_style[self.opt.sind] = 1.
        elif sinput == 'svec':
            # Comma-separated string; the first three values are used.
            ss = self.opt.svec.split(',')
            B_style = torch.Tensor([float(ss[0]), float(ss[1]), float(ss[2])])
        elif sinput == 'simg':
            # Local import: this module does not import numpy at file level.
            import numpy as np
            self.featureloc = os.path.join('style_features/styles2/', self.opt.sfeature_mode)
            # The feature file sits inside the feature directory and is named
            # after the style image with its 4-character extension replaced
            # by '.npy'. np.load takes a single path; the array must become a
            # tensor before .view() can be applied.
            # NOTE(review): assumes the .npy file holds 3 values -- confirm.
            feature_path = os.path.join(self.featureloc, self.opt.simg[:-4] + '.npy')
            B_style = torch.Tensor(np.load(feature_path))
        else:
            raise ValueError("unsupported opt.sinput value: %r (expected 'sind', 'svec' or 'simg')" % (sinput,))
        # Broadcast the 3-vector over a 128x128 spatial grid.
        B_style = B_style.view(3, 1, 1)
        return B_style.repeat(1, 128, 128)

    def __len__(self):
        """Return the total number of images in the dataset."""
        return len(self.A_paths)


================================================
FILE: data/unaligned_mask_stylecls_dataset.py
================================================
import os.path
from data.base_dataset import BaseDataset, get_params, get_transform, get_transform_mask
from data.image_folder import make_dataset
from PIL import Image
import random
import torch
import torchvision.transforms as transforms
import numpy as np
import pdb


class UnalignedMaskStyleClsDataset(BaseDataset):
    """Unaligned two-domain dataset with facial-region masks and a style label for domain B.

    Images are read from '<dataroot>/<phase>/A' and '<dataroot>/<phase>/B'.
    Auxiliary files are looked up by image base name in sibling directories:
    '<phase>/A_nose', '<phase>/B_nose' (always), '_eyes' / '_lips' variants
    (when opt.use_eye_mask / opt.use_lip_mask are set) and '<phase>/B_feat'
    holding one '.npy' style feature per B image.
    """

    def __init__(self, opt):
        """Collect the image paths of both domains and store channel counts.

        Parameters:
            opt (Option class) -- stores all the experiment flags
        """
        BaseDataset.__init__(self, opt)

        self.dir_A = os.path.join(opt.dataroot, opt.phase + '/A')  # e.g. '/path/to/data/train/A'
        self.dir_B = os.path.join(opt.dataroot, opt.phase + '/B')  # e.g. '/path/to/data/train/B'

        self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size))   # images of domain A
        self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size))   # images of domain B

        self.A_size = len(self.A_paths)  # get the size of dataset A
        self.B_size = len(self.B_paths)  # get the size of dataset B
        print("A size:", self.A_size)
        print("B size:", self.B_size)
        btoA = self.opt.direction == 'BtoA'
        self.input_nc = self.opt.output_nc if btoA else self.opt.input_nc       # get the number of channels of input image
        self.output_nc = self.opt.input_nc if btoA else self.opt.output_nc      # get the number of channels of output image

        # Prefixes for the auxiliary directories; a suffix such as '_nose' or
        # '_feat' is appended at lookup time.
        self.auxdir_A = os.path.join(opt.dataroot, "%s/A" % opt.phase)
        self.auxdir_B = os.path.join(opt.dataroot, "%s/B" % opt.phase)

    def __getitem__(self, index):
        """Return one A image, one B image, their masks and the style data of B.

        Parameters:
            index -- an integer for data indexing

        Returns:
            a dict with 'A', 'B', 'A_paths', 'B_paths', 'A_mask', 'B_mask',
            'B_style', 'B_label'; plus 'A_maske'/'B_maske' and
            'A_maskl'/'B_maskl' when the eye / lip masks are enabled, and
            'B_style0' when training with opt.style_loss_with_weight.

        Raises:
            ValueError -- if opt.sfeature_mode neither ends with '_softmax'
                          nor equals 'domain'.
        """
        A_path = self.A_paths[index % self.A_size]  # wrap the index into the range of A
        if self.opt.serial_batches:   # deterministic pairing
            index_B = index % self.B_size
        else:   # randomize the index for domain B to avoid fixed pairs.
            index_B = random.randint(0, self.B_size - 1)
        B_path = self.B_paths[index_B]
        A_img = Image.open(A_path).convert('RGB')
        B_img = Image.open(B_path).convert('RGB')

        basenA = os.path.basename(A_path)
        basenB = os.path.basename(B_path)

        # (item key suffix, directory suffix) of every mask to load; the nose
        # mask is always used, eyes and lips are optional.
        mask_kinds = [('mask', '_nose')]
        if self.opt.use_eye_mask:
            mask_kinds.append(('maske', '_eyes'))
        if self.opt.use_lip_mask:
            mask_kinds.append(('maskl', '_lips'))
        mask_imgs = [
            (key,
             Image.open(os.path.join(self.auxdir_A + suffix, basenA)),
             Image.open(os.path.join(self.auxdir_B + suffix, basenB)))
            for key, suffix in mask_kinds
        ]

        # apply image transformation; each mask reuses the crop/flip
        # parameters of its image so that both stay aligned
        transform_params_A = get_params(self.opt, A_img.size)
        transform_params_B = get_params(self.opt, B_img.size)
        A = get_transform(self.opt, transform_params_A, grayscale=(self.input_nc == 1))(A_img)
        B = get_transform(self.opt, transform_params_B, grayscale=(self.output_nc == 1))(B_img)

        item = {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path}
        for key, A_mask_img, B_mask_img in mask_imgs:
            item['A_' + key] = get_transform_mask(self.opt, transform_params_A, grayscale=1)(A_mask_img)
            item['B_' + key] = get_transform_mask(self.opt, transform_params_B, grayscale=1)(B_mask_img)

        # Style feature of the B image: file named after the image with its
        # 4-character extension replaced by '.npy'.
        # NOTE(review): presumably a length-3 softmax over styles -- confirm.
        softmax = np.load(os.path.join(self.auxdir_B+'_feat', basenB[:-4]+'.npy'))
        softmax = torch.Tensor(softmax)
        # Index of the largest entry; kept in its own name so the `index`
        # parameter is not shadowed.
        B_label = torch.max(softmax, 0)[1]

        sfeature_mode = self.opt.sfeature_mode
        if sfeature_mode.endswith('_softmax'):
            if self.opt.one_hot:
                B_style = torch.Tensor([0., 0., 0.])
                B_style[B_label] = 1.
            else:
                B_style = softmax
            # Broadcast the 3-vector over a 128x128 spatial grid.
            B_style = B_style.view(3, 1, 1)
            B_style = B_style.repeat(1, 128, 128)
        elif sfeature_mode == 'domain':
            B_style = B_label
        else:
            raise ValueError("unsupported opt.sfeature_mode value: %r (expected 'domain' or a name ending in '_softmax')" % (sfeature_mode,))
        item['B_style'] = B_style
        item['B_label'] = B_label
        if self.opt.isTrain and self.opt.style_loss_with_weight:
            item['B_style0'] = softmax

        return item

    def __len__(self):
        """Return the size of the larger of the two domains."""
        return max(self.A_size, self.B_size)


================================================
FILE: models/__init__.py
================================================
"""This package contains modules related to objective functions, optimizations, and network architectures.

To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
You need to implement the following five functions:
    -- <__init__>:                      initialize the class; first call BaseModel.__init__(self, opt).
    -- <set_input>:                     unpack data from dataset and apply preprocessing.
    -- <forward>:                       produce intermediate results.
    -- <optimize_parameters>:           calculate loss, gradients, and update network weights.
    -- <modify_commandline_options>:    (optionally) add model-specific options and set default options.

In the function <__init__>, you need to define four lists:
    -- self.loss_names (str list):          specify the training losses that you want to plot and save.
    -- self.model_names (str list):         define networks used in our training.
    -- self.visual_names (str list):        specify the images that you want to display and save.
    -- self.optimizers (optimizer list):    define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for a usage example.

Now you can use the model class by specifying flag '--model dummy'.
See our template model class 'template_model.py' for more details.
"""

import importlib
from models.base_model import BaseModel


def find_model_using_name(model_name):
    """Import the module "models/[model_name]_model.py" and return its model class.

    The module is searched for a class whose name, compared
    case-insensitively, equals ``model_name`` with underscores removed plus
    the suffix "model", and which is a subclass of BaseModel. If several
    attributes match, the last one found is returned.

    Parameters:
        model_name (str) -- model identifier, e.g. 'test'

    Returns:
        the matching model class (not an instance).

    Raises:
        SystemExit -- with exit status 1, after printing a message, if the
                      module defines no matching subclass.
    """
    model_filename = "models." + model_name + "_model"
    modellib = importlib.import_module(model_filename)
    model = None
    target_model_name = model_name.replace('_', '') + 'model'
    wanted = target_model_name.lower()
    for name, cls in modellib.__dict__.items():
        # Only real classes may be passed to issubclass(); a function or
        # constant that happens to carry the target name is skipped instead
        # of raising TypeError.
        if name.lower() == wanted \
           and isinstance(cls, type) \
           and issubclass(cls, BaseModel):
            model = cls

    if model is None:
        print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
        # A missing model is a failure, so terminate with a non-zero status
        # rather than reporting success to the calling shell.
        raise SystemExit(1)

    return model


def get_option_setter(model_name):
    """Resolve the model class for `model_name` and return its <modify_commandline_options> static method."""
    return find_model_using_name(model_name).modify_commandline_options


def create_model(opt):
    """Instantiate the model selected by opt.model.

    The model class is looked up with find_model_using_name, constructed
    with `opt`, announced on stdout and returned.
    This is the main interface between this package and 'train.py'/'test.py'

    Example:
        >>> from models import create_model
        >>> model = create_model(opt)
    """
    model_class = find_model_using_name(opt.model)
    created = model_class(opt)
    print("model [%s] was created" % type(created).__name__)
    return created


================================================
FILE: models/asymmetric_cycle_gan_cls_model.py
================================================
import torch
import itertools
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks
import models.dist_model as dm # numpy==1.14.3
import torchvision.transforms as transforms
import os

def truncate(fake_B,a=127.5):#[-1,1]
    """Quantize a tensor by scaling it, dropping the fractional part and scaling back.

    Computes ((fake_B + 1) * a) cast to int and back to float, divided by
    `a`, minus 1. With the default a=127.5 an input in [-1, 1] is scaled to
    [0, 255] before the integer cast.
    """
    scaled = (fake_B + 1) * a
    quantized = scaled.int().float()
    return quantized / a - 1

class AsymmetricCycleGANClsModel(BaseModel):
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Register the options of this model on `parser` and return it.

        Parameters:
            parser          -- option parser to extend
            is_train (bool) -- when true, the training-only options (loss weights,
                               truncation settings, HED checkpoint path, ...) are added too

        Returns:
            the same parser, with new defaults and arguments.
        """
        # default CycleGAN did not use dropout
        parser.set_defaults(no_dropout=True, dataset_mode='unaligned_mask_stylecls')

        region_help = 'whether use mask for special face region'
        # each entry: (flag, type, default, help)
        network_options = [
            ('--netda', str, 'basic_cls', None),
            ('--netga', str, 'resnet_style2_9blocks', 'net arch for netG_A'),
            ('--model0_res', int, 0, 'number of resblocks in model0 (before insert style)'),
            ('--model1_res', int, 0, 'number of resblocks in model1 (after insert style, before 2 column merge)'),
        ]
        training_options = [
            ('--lambda_A', float, 5.0, 'weight for cycle loss (A -> B -> A)'),
            ('--lambda_B', float, 5.0, 'weight for cycle loss (B -> A -> B)'),
            ('--lambda_identity', float, 0, 'use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1'),
            ('--ntrunc_trunc', int, 1, 'whether use both non-trunc version and trunc version'),
            ('--trunc_a', float, 31.875, 'multiply which value to round when trunc'),
            ('--lambda_A_trunc', float, 5.0, 'weight for cycle loss for trunc'),
            ('--hed_pretrained_mode', str, './checkpoints/network-bsds500.pytorch', 'path to the pretrained hed model'),
            ('--lambda_G_A_l', float, 0.5, 'weight for local GAN loss in G'),
            ('--style_loss_with_weight', int, 1, 'whether multiply prob in style loss'),
        ]
        mask_options = [
            ('--use_mask', int, 1, region_help),
            ('--use_eye_mask', int, 1, region_help),
            ('--use_lip_mask', int, 1, region_help),
            ('--mask_type', int, 3, 'use mask type, 0 outside black, 1 outside white'),
        ]
        style_options = [
            ('--style_control', int, 1, 'use style_control'),
            ('--sfeature_mode', str, '1vgg19_softmax', 'vgg19 softmax as feature'),
            ('--one_hot', int, 0, 'use one-hot for style code'),
        ]

        # registration order is kept: networks, (training), masks, style control
        selected = list(network_options)
        if is_train:
            selected += training_options
        selected += mask_options + style_options
        for flag, kind, default, text in selected:
            parser.add_argument(flag, type=kind, default=default, help=text)

        return parser

    def __init__(self, opt):
        """Build the generators, discriminators, loss functions and optimizers.

        Parameters:
            opt (Option class) -- experiment flags; the options registered by
                                  <modify_commandline_options> are read here.

        In training mode this also creates the image pools, the LPIPS distance
        model and the HED network used by <backward_G>. At test time only the
        two generators are created and listed in self.model_names.
        """
        BaseModel.__init__(self, opt)
        # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
        if self.isTrain and self.opt.ntrunc_trunc:
            self.loss_names = ['D_A', 'G_A', 'cycle_A', 'cycle_A2', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B', 'G']
        else:
            self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B']
        # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
        visual_names_A = ['real_A', 'fake_B', 'rec_A']
        visual_names_B = ['real_B', 'fake_A', 'rec_B']
        if self.isTrain and self.opt.lambda_identity > 0.0:  # if identity loss is used, idt_B and idt_A are visualized too
            visual_names_A.append('idt_B')
            visual_names_B.append('idt_A')
        if self.isTrain:
            visual_names_A += ['real_A_hed', 'rec_A_hed']
            if self.opt.ntrunc_trunc:
                visual_names_A += ['rec_At', 'rec_At_hed']
        # The masked crops are visualized in both phases; their losses exist only while training.
        region_flags = [(self.opt.use_mask, 'l'), (self.opt.use_eye_mask, 'le'), (self.opt.use_lip_mask, 'll')]
        for enabled, suffix in region_flags:
            if not enabled:
                continue
            visual_names_A += ['fake_B_' + suffix, 'real_B_' + suffix]
            if self.isTrain:
                self.loss_names += ['D_A_' + suffix, 'G_A_' + suffix]
        self.loss_names += ['D_A_cls','G_A_cls']

        self.visual_names = visual_names_A + visual_names_B  # combine visualizations for A and B
        print(self.visual_names)
        # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>.
        if self.isTrain:
            self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
            for enabled, suffix in region_flags:
                if enabled:
                    self.model_names += ['D_A_' + suffix]
        else:  # during test time, only load Gs
            self.model_names = ['G_A', 'G_B']

        # define networks (both Generators and discriminators)
        # The naming is different from those used in the paper.
        # Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        if not self.opt.style_control:
            self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
        else:
            print(opt.netga)
            print('model0_res', opt.model0_res)
            print('model1_res', opt.model1_res)
            self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netga, opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.model0_res, opt.model1_res)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)

        if self.isTrain:  # define discriminators
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netda,
                                            opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids, n_class=3)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
            # mask types 2 and 3 concatenate the mask as one extra input channel (see <masked>)
            if self.opt.mask_type in [2, 3]:
                local_nc = opt.output_nc + 1
            else:
                local_nc = opt.output_nc
            if self.opt.use_mask:
                self.netD_A_l = networks.define_D(local_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
            if self.opt.use_eye_mask:
                self.netD_A_le = networks.define_D(local_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
            if self.opt.use_lip_mask:
                self.netD_A_ll = networks.define_D(local_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)

        if not self.isTrain:
            self.criterionGAN = networks.GANLoss('lsgan').to(self.device)

        if self.isTrain:
            if opt.lambda_identity > 0.0:  # only works when input and output images have the same number of channels
                assert(opt.input_nc == opt.output_nc)
            self.fake_A_pool = ImagePool(opt.pool_size)  # create image buffer to store previously generated images
            self.fake_B_pool = ImagePool(opt.pool_size)  # create image buffer to store previously generated images
            # define loss functions
            self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)  # define GAN loss.
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            self.criterionCls = torch.nn.CrossEntropyLoss()
            self.criterionCls2 = torch.nn.CrossEntropyLoss(reduction='none')
            # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))
            # Every discriminator that was created must be optimized. Each mask flag is
            # checked on its own so that, e.g., use_mask=0 with use_eye_mask=1 still
            # trains netD_A_le instead of leaving it at its initial weights.
            D_params = list(self.netD_A.parameters()) + list(self.netD_B.parameters())
            if self.opt.use_mask:
                D_params += list(self.netD_A_l.parameters())
            if self.opt.use_eye_mask:
                D_params += list(self.netD_A_le.parameters())
            if self.opt.use_lip_mask:
                D_params += list(self.netD_A_ll.parameters())
            self.optimizer_D = torch.optim.Adam(D_params, lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

            self.lpips = dm.DistModel(opt,model='net-lin',net='alex',use_gpu=True)

            # HED is loaded from opt.hed_pretrained_mode and kept frozen
            self.hed = networks.define_HED(init_weights_=opt.hed_pretrained_mode, gpu_ids_=self.opt.gpu_ids_p)
            self.set_requires_grad(self.hed, False)


    def set_input(self, input):
        """Move one batch from the dataloader onto self.device and store it on the model.

        Parameters:
            input (dict): the batch; keys read here are 'A', 'B', the matching
                          '*_paths' entry, and -- depending on the options -- the
                          mask, style and label entries.

        The option 'direction' can be used to swap domain A and domain B.
        """
        device = self.device
        if self.opt.direction == 'AtoB':
            src_key, dst_key = 'A', 'B'
        else:
            src_key, dst_key = 'B', 'A'
        self.real_A = input[src_key].to(device)
        self.real_B = input[dst_key].to(device)
        self.image_paths = input[src_key + '_paths']

        # (option flag, key suffix): stores input['A_<suffix>'] / input['B_<suffix>'] as attributes of the same name
        mask_sources = (
            (self.opt.use_mask, 'mask'),
            (self.opt.use_eye_mask, 'maske'),
            (self.opt.use_lip_mask, 'maskl'),
        )
        for enabled, suffix in mask_sources:
            if enabled:
                for side in ('A_', 'B_'):
                    setattr(self, side + suffix, input[side + suffix].to(device))

        if self.opt.style_control:
            self.real_B_style = input['B_style'].to(device)
            self.real_B_label = input['B_label'].to(device)

        if self.opt.isTrain and self.opt.style_loss_with_weight:
            self.real_B_style0 = input['B_style0'].to(device)
            # constant class-index tensors (0, 1, 2) shaped like the label batch
            label_shape = self.real_B_label.size()
            self.zero = torch.zeros(label_shape, dtype=torch.int64).to(device)
            self.one = torch.ones(label_shape, dtype=torch.int64).to(device)
            self.two = 2*torch.ones(label_shape, dtype=torch.int64).to(device)

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        if not self.opt.style_control:
            self.fake_B = self.netG_A(self.real_A)  # G_A(A)
        else:
            #print(torch.mean(self.real_B_style,(2,3)),'style_control')
            self.fake_B = self.netG_A(self.real_A, self.real_B_style)
        self.rec_A = self.netG_B(self.fake_B)   # G_B(G_A(A))
        self.fake_A = self.netG_B(self.real_B)  # G_B(B)
        if not self.opt.style_control:
            self.rec_B = self.netG_A(self.fake_A)   # G_A(G_B(B))
        else:
            #print(torch.mean(self.real_B_style,(2,3)),'style_control')
            self.rec_B = self.netG_A(self.fake_A, self.real_B_style) # -- cycle_B loss

        if self.opt.use_mask:
            self.fake_B_l = self.masked(self.fake_B,self.A_mask)
            self.real_B_l = self.masked(self.real_B,self.B_mask)
        if self.opt.use_eye_mask:
            self.fake_B_le = self.masked(self.fake_B,self.A_maske)
            self.real_B_le = self.masked(self.real_B,self.B_maske)
        if self.opt.use_lip_mask:
            self.fake_B_ll = self.masked(self.fake_B,self.A_maskl)
            self.real_B_ll = self.masked(self.real_B,self.B_maskl)

    def backward_D_basic(self, netD, real, fake):
        """Compute the GAN loss of one discriminator and backpropagate it.

        Parameters:
            netD (network)      -- the discriminator D
            real (tensor array) -- real images
            fake (tensor array) -- images generated by a generator

        Returns the discriminator loss: the average of the real and the fake term.
        loss_D.backward() is called here, so gradients are populated on return.
        """
        real_term = self.criterionGAN(netD(real), True)
        # detach: the generator that produced `fake` gets no gradient from this loss
        fake_term = self.criterionGAN(netD(fake.detach()), False)
        loss_D = (real_term + fake_term) * 0.5
        loss_D.backward()
        return loss_D
    
    def backward_D_basic_cls(self, netD, real, fake):
        """Compute GAN and classification losses of a two-output discriminator and backpropagate.

        Parameters:
            netD (network)      -- discriminator returning (real/fake prediction, class prediction)
            real (tensor array) -- real images
            fake (tensor array) -- generated images (detached before use)

        Returns (loss_D, loss_D_cls); backward() is called on their sum.
        """
        def classification_loss(pred_cls):
            if not self.opt.style_loss_with_weight:
                return self.criterionCls(pred_cls, self.real_B_label)
            # per-sample cross entropy against class 0 / 1 / 2, each weighted by the
            # matching column of real_B_style0, then averaged
            weights = self.real_B_style0
            return torch.mean(weights[:,0] * self.criterionCls2(pred_cls, self.zero) + weights[:,1] * self.criterionCls2(pred_cls, self.one) + weights[:,2] * self.criterionCls2(pred_cls, self.two))

        # Real
        real_score, real_cls = netD(real)
        gan_real = self.criterionGAN(real_score, True)
        cls_real = classification_loss(real_cls)
        # Fake
        fake_score, fake_cls = netD(fake.detach())
        gan_fake = self.criterionGAN(fake_score, False)
        cls_fake = classification_loss(fake_cls)
        # Combined loss and calculate gradients
        loss_D = (gan_real + gan_fake) * 0.5
        loss_D_cls = (cls_real + cls_fake) * 0.5
        (loss_D + loss_D_cls).backward()
        return loss_D, loss_D_cls

    def backward_D_A(self):
        """Calculate GAN and classification losses for discriminator D_A"""
        pooled_fake = self.fake_B_pool.query(self.fake_B)
        losses = self.backward_D_basic_cls(self.netD_A, self.real_B, pooled_fake)
        self.loss_D_A, self.loss_D_A_cls = losses
    
    def backward_D_A_l(self):
        """Calculate GAN loss for discriminator D_A_l on inputs masked with A_mask / B_mask"""
        pooled_fake = self.fake_B_pool.query(self.fake_B)
        real_region = self.masked(self.real_B, self.B_mask)
        fake_region = self.masked(pooled_fake, self.A_mask)
        self.loss_D_A_l = self.backward_D_basic(self.netD_A_l, real_region, fake_region)

    def backward_D_A_le(self):
        """Calculate GAN loss for discriminator D_A_le on inputs masked with A_maske / B_maske"""
        pooled_fake = self.fake_B_pool.query(self.fake_B)
        real_region = self.masked(self.real_B, self.B_maske)
        fake_region = self.masked(pooled_fake, self.A_maske)
        self.loss_D_A_le = self.backward_D_basic(self.netD_A_le, real_region, fake_region)
    
    def backward_D_A_ll(self):
        """Calculate GAN loss for discriminator D_A_ll on inputs masked with A_maskl / B_maskl"""
        pooled_fake = self.fake_B_pool.query(self.fake_B)
        real_region = self.masked(self.real_B, self.B_maskl)
        fake_region = self.masked(pooled_fake, self.A_maskl)
        self.loss_D_A_ll = self.backward_D_basic(self.netD_A_ll, real_region, fake_region)

    def backward_D_B(self):
        """Calculate GAN loss for discriminator D_B (real_A against pooled fake_A)"""
        pooled_fake = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, pooled_fake)
    
    def update_process(self, epoch):
        """Store in self.process the ratio (epoch - 1) / (niter + niter_decay).

        <backward_G> reads self.process to re-weight the two forward cycle losses.
        """
        total_epochs = float(self.opt.niter_decay + self.opt.niter)
        self.process = (epoch - 1) / total_epochs

    def backward_G(self):
        """Calculate the loss for generators G_A and G_B and backpropagate it.

        The total (self.loss_G) is the sum of: the GAN losses of both generators,
        the classification loss of D_A on fake_B, the local GAN losses for every
        enabled mask, the forward cycle loss (LPIPS between HED outputs), its
        truncated variant when opt.ntrunc_trunc is set, the L1 backward cycle
        loss and the optional identity losses.

        Requires self.process to have been set by <update_process> when
        opt.ntrunc_trunc is enabled.
        """
        lambda_idt = self.opt.lambda_identity
        lambda_G_A_l = self.opt.lambda_G_A_l
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        lambda_A_trunc = self.opt.lambda_A_trunc
        if self.opt.ntrunc_trunc:
            # as self.process grows, weight moves from the plain cycle loss to the truncated one
            lambda_A = lambda_A * (1 - self.process * 0.9)
            lambda_A_trunc = lambda_A_trunc * self.process * 0.9
        self.lambda_As = [lambda_A, lambda_A_trunc]
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed: ||G_A(B) - B||
            self.idt_A = self.netG_A(self.real_B)
            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed: ||G_B(A) - A||
            self.idt_B = self.netG_B(self.real_A)
            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss D_A(G_A(A))
        pred_fake, pred_fake_cls = self.netD_A(self.fake_B)
        self.loss_G_A = self.criterionGAN(pred_fake, True)
        if not self.opt.style_loss_with_weight:
            self.loss_G_A_cls = self.criterionCls(pred_fake_cls, self.real_B_label)
        else:
            # cross entropy against each class index, weighted by the columns of real_B_style0
            self.loss_G_A_cls = torch.mean(self.real_B_style0[:,0] * self.criterionCls2(pred_fake_cls, self.zero) + self.real_B_style0[:,1] * self.criterionCls2(pred_fake_cls, self.one) + self.real_B_style0[:,2] * self.criterionCls2(pred_fake_cls, self.two))
        if self.opt.use_mask:
            self.loss_G_A_l = self.criterionGAN(self.netD_A_l(self.fake_B_l), True) * lambda_G_A_l
        if self.opt.use_eye_mask:
            self.loss_G_A_le = self.criterionGAN(self.netD_A_le(self.fake_B_le), True) * lambda_G_A_l
        if self.opt.use_lip_mask:
            self.loss_G_A_ll = self.criterionGAN(self.netD_A_ll(self.fake_B_ll), True) * lambda_G_A_l
        # GAN loss D_B(G_B(B))
        self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)

        # Forward cycle loss  LPIPS( HED(G_B(G_A(A))), HED(A))
        ts = self.real_A.shape
        gpu_p = self.opt.gpu_ids_p[0]
        gpu = self.opt.gpu_ids[0]
        # images are rescaled with /2+0.5 before HED and its output with (x-0.5)*2 afterwards
        rec_A_hed = (self.hed(self.rec_A.cuda(gpu_p)/2+0.5)-0.5)*2
        real_A_hed = (self.hed(self.real_A.cuda(gpu_p)/2+0.5)-0.5)*2
        self.loss_cycle_A = (self.lpips.forward_pair(rec_A_hed.expand(ts), real_A_hed.expand(ts)).mean()).cuda(gpu) * lambda_A
        self.rec_A_hed = rec_A_hed
        self.real_A_hed = real_A_hed
        if self.opt.ntrunc_trunc:
            # same cycle loss, but reconstructing from the quantized fake_B
            self.rec_At = self.netG_B(truncate(self.fake_B,self.opt.trunc_a))
            rec_At_hed = (self.hed(self.rec_At.cuda(gpu_p)/2+0.5)-0.5)*2
            self.loss_cycle_A2 = (self.lpips.forward_pair(rec_At_hed.expand(ts), real_A_hed.expand(ts)).mean()).cuda(gpu) * lambda_A_trunc
            self.rec_At_hed = rec_At_hed

        # Backward cycle loss || G_A(G_B(B)) - B||
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B

        # combined loss and calculate gradients
        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
        # Optional terms are gated by the same option flags that create them above,
        # rather than by comparing the loss tensors against a sentinel value.
        if self.opt.ntrunc_trunc:
            self.loss_G = self.loss_G + self.loss_cycle_A2
        if self.opt.use_mask:
            self.loss_G = self.loss_G + self.loss_G_A_l
        if self.opt.use_eye_mask:
            self.loss_G = self.loss_G + self.loss_G_A_le
        if self.opt.use_lip_mask:
            self.loss_G = self.loss_G + self.loss_G_A_ll
        self.loss_G = self.loss_G + self.loss_G_A_cls
        self.loss_G.backward()

    def optimize_parameters(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration

        One generator step (discriminators frozen) is followed by one
        discriminator step covering D_A, the enabled masked discriminators and D_B.
        """
        opt = self.opt
        discriminators = [self.netD_A, self.netD_B]
        if opt.use_mask:
            discriminators.append(self.netD_A_l)
        if opt.use_eye_mask:
            discriminators.append(self.netD_A_le)
        if opt.use_lip_mask:
            discriminators.append(self.netD_A_ll)

        # forward: compute fake images and reconstruction images.
        self.forward()

        # G_A and G_B -- Ds require no gradients when optimizing Gs
        self.set_requires_grad(discriminators, False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

        # all discriminators
        self.set_requires_grad(discriminators, True)
        self.optimizer_D.zero_grad()
        self.backward_D_A()
        if opt.use_mask:
            self.backward_D_A_l()
        if opt.use_eye_mask:
            self.backward_D_A_le()
        if opt.use_lip_mask:
            self.backward_D_A_ll()
        self.backward_D_B()
        self.optimizer_D.step()


================================================
FILE: models/base_model.py
================================================
import os
import torch
from collections import OrderedDict
from abc import ABCMeta, abstractmethod
from . import networks


class BaseModel():
    __metaclass__ = ABCMeta
    """This class is an abstract base class (ABC) for models.
    To create a subclass, you need to implement the following five functions:
        -- <__init__>:                      initialize the class; first call BaseModel.__init__(self, opt).
        -- <set_input>:                     unpack data from dataset and apply preprocessing.
        -- <forward>:                       produce intermediate results.
        -- <optimize_parameters>:           calculate losses, gradients, and update network weights.
        -- <modify_commandline_options>:    (optionally) add model-specific options and set default options.
    """

    def __init__(self, opt):
        """Initialize the BaseModel class.

        Parameters:
            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions

        When creating your custom class, you need to implement your own initialization.
        In this fucntion, you should first call <BaseModel.__init__(self, opt)>
        Then, you need to define four lists:
            -- self.loss_names (str list):          specify the training losses that you want to plot and save.
            -- self.model_names (str list):         specify the images that you want to display and save.
            -- self.visual_names (str list):        define networks used in our training.
            -- self.optimizers (optimizer list):    define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
        """
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')  # get device name: CPU or GPU
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)  # save all the checkpoints to save_dir
        if opt.preprocess != 'scale_width':  # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
            torch.backends.cudnn.benchmark = True
        self.loss_names = []
        self.model_names = []
        self.visual_names = []
        self.optimizers = []
        self.image_paths = []
        self.metric = 0  # used for learning rate policy 'plateau'

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add new model-specific options, and rewrite default values for existing options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.
        """
        return parser

    @abstractmethod
    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input (dict): includes the data itself and its metadata information.
        """
        pass

    @abstractmethod
    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        pass

    @abstractmethod
    def optimize_parameters(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        pass

    def setup(self, opt):
        """Load and print networks; create schedulers

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        if self.isTrain:
            self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
        if not self.isTrain or opt.continue_train:
            load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch
            self.load_networks(load_suffix)
        self.print_networks(opt.verbose)

    def eval(self):
        """Make models eval mode during test time"""
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                net.eval()

    def test(self):
        """Forward function used in test time.

        This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
        It also calls <compute_visuals> to produce additional visualization results
        """
        with torch.no_grad():
            self.forward()
            self.compute_visuals()

    def compute_visuals(self):
        """Calculate additional output images for visdom and HTML visualization"""
        pass

    def get_image_paths(self):
        """ Return image paths that are used to load current data"""
        return self.image_paths

    def update_learning_rate(self):
        """Update learning rates for all the networks; called at the end of every epoch"""
        for scheduler in self.schedulers:
            if self.opt.lr_policy == 'plateau':
                scheduler.step(self.metric)
            else:
                scheduler.step()

        lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate = %.7f' % lr)

    def get_current_visuals(self):
        """Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
        visual_ret = OrderedDict()
        for name in self.visual_names:
            if isinstance(name, str):
                visual_ret[name] = getattr(self, name)
        return visual_ret

    def get_current_losses(self):
        """Return traning losses / errors. train.py will print out these errors on console, and save them to a file"""
        errors_ret = OrderedDict()
        for name in self.loss_names:
            if isinstance(name, str):
                errors_ret[name] = float(getattr(self, 'loss_' + name))  # float(...) works for both scalar tensor and float number
        return errors_ret

    def save_networks(self, epoch):
        """Save all the networks to the disk.

        Parameters:
            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
        """
        for name in self.model_names:
            if isinstance(name, str):
                save_filename = '%s_net_%s.pth' % (epoch, name)
                save_path = os.path.join(self.save_dir, save_filename)
                net = getattr(self, 'net' + name)

                if len(self.gpu_ids) > 0 and torch.cuda.is_available():
                    torch.save(net.module.cpu().state_dict(), save_path)
                    net.cuda(self.gpu_ids[0])
                else:
                    torch.save(net.cpu().state_dict(), save_path)

    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
        """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
        key = keys[i]
        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
            if module.__class__.__name__.startswith('InstanceNorm') and \
                    (key == 'running_mean' or key == 'running_var'):
                if getattr(module, key) is None:
                    state_dict.pop('.'.join(keys))
            if module.__class__.__name__.startswith('InstanceNorm') and \
               (key == 'num_batches_tracked'):
                state_dict.pop('.'.join(keys))
        else:
            self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)

    def load_networks(self, epoch):
        """Load all the networks from the disk.

        Parameters:
            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
        """
        for name in self.model_names:
            if isinstance(name, str):
                load_filename = '%s_net_%s.pth' % (epoch, name)
                load_path = os.path.join(self.save_dir, load_filename)
                net = getattr(self, 'net' + name)
                if isinstance(net, torch.nn.DataParallel):
                    net = net.module
                print('loading the model from %s' % load_path)
                # if you are using PyTorch newer than 0.4 (e.g., built from
                # GitHub source), you can remove str() on self.device
                state_dict = torch.load(load_path, map_location=str(self.device))
                if hasattr(state_dict, '_metadata'):
                    del state_dict._metadata

                # patch InstanceNorm checkpoints prior to 0.4
                for key in list(state_dict.keys()):  # need to copy keys here because we mutate in loop
                    self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
                net.load_state_dict(state_dict)

    def print_networks(self, verbose):
        """Print the total number of parameters in the network and (if verbose) network architecture

        Parameters:
            verbose (bool) -- if verbose: print the network architecture
        """
        print('---------- Networks initialized -------------')
        for name in self.model_names:
            if isinstance(name, str):
                net = getattr(self, 'net' + name)
                num_params = 0
                for param in net.parameters():
                    num_params += param.numel()
                if verbose:
                    print(net)
                print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
        print('-----------------------------------------------')

    def set_requires_grad(self, nets, requires_grad=False):
        """Set requies_grad=Fasle for all the networks to avoid unnecessary computations
        Parameters:
            nets (network list)   -- a list of networks
            requires_grad (bool)  -- whether the networks require gradients or not
        """
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad
    
    # ===========================================================================================================
    def masked(self, A,mask):
        if self.opt.mask_type == 0:
            return (A/2+0.5)*mask*2-1
        elif self.opt.mask_type == 1:
            return ((A/2+0.5)*mask+1-mask)*2-1
        elif self.opt.mask_type == 2:
            return torch.cat((A, mask), 1)
        elif self.opt.mask_type == 3:
            masked = ((A/2+0.5)*mask+1-mask)*2-1
            return torch.cat((masked, mask), 1)

================================================
FILE: models/dist_model.py
================================================

from __future__ import absolute_import

import sys
sys.path.append('..')
sys.path.append('.')
import numpy as np
import torch
from torch import nn
from collections import OrderedDict
from torch.autograd import Variable
from .base_model import BaseModel
from scipy.ndimage import zoom
import skimage.transform

from . import networks_basic as networks
# from PerceptualSimilarity.util import util
from util import util

class DistModel(BaseModel):
    def name(self):
        """Return the model identifier stored in self.model_name (set in __init__)."""
        return self.model_name

    def __init__(self, opt, model='net-lin', net='alex', pnet_rand=False, pnet_tune=False, model_path=None, colorspace='Lab', use_gpu=True, printNet=False, spatial=False, spatial_shape=None, spatial_order=1, spatial_factor=None, is_train=False, lr=.0001, beta1=0.5, version='0.1'):
        '''
        Build the distance network selected by `model` and either load its
        weights (test mode) or create the ranking loss and optimizer (training mode).

        INPUTS
            opt - options object passed to BaseModel.__init__; for 'net-lin', opt.gpu_ids_p selects the device
                  (first id of the list, or the CPU when the list is empty)
            model - ['net-lin'] for linearly calibrated network
                    ['net'] for off-the-shelf network
                    ['L2'] for L2 distance in Lab colorspace
                    ['SSIM'] for ssim in RGB colorspace
            net - ['squeeze','alex','vgg']
            pnet_rand, pnet_tune - forwarded to networks.PNetLin ('net-lin' only)
            model_path - if None, './checkpoints/weights/v[VERSION]/[NET_NAME].pth' is used ('net-lin' only)
            colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM
            use_gpu - bool - forwarded to the non-'net-lin' networks and the ranking loss; also read by set_input
            printNet - bool - whether or not to print network architecture out
            spatial - bool - whether to output an array containing varying distances across spatial dimensions
            spatial_shape - if given, output spatial shape. if None then spatial shape is determined automatically via spatial_factor (see below).
            spatial_factor - if given, specifies upsampling factor relative to the largest spatial extent of a convolutional layer. if None then resized to size of input images.
            spatial_order - spline order of filter for upsampling in spatial mode, by default 1 (bilinear).
            is_train - bool - [True] for training mode
            lr - float - initial learning rate
            beta1 - float - initial momentum term for adam
            version - 0.1 for latest, 0.0 was original
        RAISES
            ValueError - if `model` is not one of the recognized names
        '''
        BaseModel.__init__(self, opt)

        self.model = model
        self.net = net
        self.use_gpu = use_gpu
        self.is_train = is_train
        self.spatial = spatial
        self.spatial_shape = spatial_shape
        self.spatial_order = spatial_order
        self.spatial_factor = spatial_factor

        self.model_name = '%s [%s]'%(model,net)
        # NOTE(review): self.device is assigned here only for 'net-lin'; the other
        # branches and the training setup read self.device, so they presumably rely
        # on BaseModel.__init__ having set it - confirm.
        if(self.model == 'net-lin'): # pretrained net + linear layer
            self.device = torch.device('cuda:{}'.format(opt.gpu_ids_p[0])) if opt.gpu_ids_p else torch.device('cpu')
            self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net,use_dropout=True,spatial=spatial,version=version,lpips=True).to(self.device)

            if(model_path is None):
                model_path = './checkpoints/weights/v%s/%s.pth'%(version,net)

            if(not is_train):
                print('Loading model from: %s'%model_path)
                state_dict = torch.load(model_path, map_location=str(self.device))
                # strict=False: checkpoint keys that do not match the network are not treated as errors
                self.net.load_state_dict(state_dict, strict=False)

        elif(self.model=='net'): # pretrained network
            assert not self.spatial, 'spatial argument not supported yet for uncalibrated networks'
            self.net = networks.PNet(use_gpu=use_gpu,pnet_type=net,device=self.device)
            self.is_fake_net = True
        elif(self.model in ['L2','l2']):
            self.net = networks.L2(use_gpu=use_gpu,colorspace=colorspace,device=self.device) # not really a network, only for testing
            self.model_name = 'L2'
        elif(self.model in ['DSSIM','dssim','SSIM','ssim']):
            self.net = networks.DSSIM(use_gpu=use_gpu,colorspace=colorspace,device=self.device)
            self.model_name = 'SSIM'
        else:
            raise ValueError("Model [%s] not recognized." % self.model)

        self.parameters = list(self.net.parameters())

        if self.is_train: # training mode
            # extra network on top to go from distances (d0,d1) => predicted human judgment (h*)
            self.rankLoss = networks.BCERankingLoss(use_gpu=use_gpu,device=self.device)
            self.parameters+=self.rankLoss.parameters
            self.lr = lr
            self.old_lr = lr
            self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999))
        else: # test mode
            self.net.eval()

        if(printNet):
            print('---------- Networks initialized -------------')
            networks.print_network(self.net)
            print('-----------------------------------------------')

    def forward_pair(self,in1,in2,retPerLayer=False):
        """Run self.net on the pair (in1, in2).

        The retPerLayer=True keyword is passed to the network only when
        retPerLayer is truthy; otherwise the network is called with the two
        inputs alone.
        """
        if not retPerLayer:
            return self.net.forward(in1,in2)
        return self.net.forward(in1,in2, retPerLayer=True)

    def forward(self, in0, in1, retNumpy=False):
        ''' Compute the distance between image patches in0 and in1
        INPUTS
            in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1]
            retNumpy - [False] to return as torch.Tensor, [True] to return as numpy array
        OUTPUT
            computed distances between in0 and in1; in spatial mode, a single
            map built from the per-layer outputs resized to a common shape
        '''
        self.input_ref, self.input_p0 = in0, in1

        self.var_ref = Variable(self.input_ref,requires_grad=True)
        self.var_p0 = Variable(self.input_p0,requires_grad=True)

        self.d0 = self.forward_pair(self.var_ref, self.var_p0)
        self.loss_total = self.d0

        def to_output(dist):
            # Without retNumpy the network output is handed back untouched.
            if not retNumpy:
                return dist
            arr = dist.cpu().data.numpy()
            if self.spatial:
                assert(arr.shape[0] == 1 and len(arr.shape) == 4)
                # Reshape to usual numpy image format: (height, width, channels)
                return arr[0,...].transpose([1, 2, 0])
            return arr.flatten()

        if not self.spatial:
            return to_output(self.d0)

        # Spatial mode: self.d0 is iterated as one entry per layer.
        layers = [to_output(x) for x in self.d0]
        out_shape = self.spatial_shape
        if out_shape is None:
            if self.spatial_factor is None:
                out_shape = (in0.size()[2],in0.size()[3])
            else:
                factor = self.spatial_factor
                out_shape = (max([x.shape[0] for x in layers])*factor, max([x.shape[1] for x in layers])*factor)

        layers = [skimage.transform.resize(x, out_shape, order=self.spatial_order, mode='edge') for x in layers]

        # mean over the concatenated channel axis, scaled by the number of layers
        return np.mean(np.concatenate(layers, 2) * len(layers), 2)

    # ***** TRAINING FUNCTIONS *****
    def optimize_parameters(self):
        """Run one training step.

        Order: forward pass, gradient reset, backward pass, optimizer step,
        then clamp_weights() on the updated network.
        """
        self.forward_train()
        self.optimizer_net.zero_grad()
        self.backward_train()
        self.optimizer_net.step()
        self.clamp_weights()

    def clamp_weights(self):
        """Clamp the weights of every 1x1-kernel layer in self.net to be >= 0.

        Only modules that have both a `weight` attribute and a `kernel_size`
        equal to (1,1) are touched. Modules with a weight but no kernel_size
        (e.g. nn.Linear or normalization layers) are skipped rather than
        raising AttributeError.
        """
        for module in self.net.modules():
            if hasattr(module, 'weight') and getattr(module, 'kernel_size', None) == (1,1):
                module.weight.data = torch.clamp(module.weight.data,min=0)

    def set_input(self, data):
        """Unpack a training batch into input tensors and gradient-enabled Variables.

        data must provide the keys 'ref', 'p0', 'p1' and 'judge'; each is stored
        as self.input_<key>. When self.use_gpu is set, all four are moved with
        .cuda(self.device). Variables (self.var_<key>) are created for 'ref',
        'p0' and 'p1' only; input_judge is left as-is.
        """
        all_keys = ('ref', 'p0', 'p1', 'judge')
        for key in all_keys:
            setattr(self, 'input_' + key, data[key])

        if self.use_gpu:
            for key in all_keys:
                attr = 'input_' + key
                setattr(self, attr, getattr(self, attr).cuda(self.device))

        for key in ('ref', 'p0', 'p1'):
            setattr(self, 'var_' + key, Variable(getattr(self, 'input_' + key),requires_grad=True))

    def forward_train(self): # run forward pass
        """Run the training forward pass and return the ranking loss.

        Stores d0 (ref vs p0), d1 (ref vs p1), acc_r, var_judge and loss_total
        on the instance. The judgment is reshaped to d0's size and mapped
        through x*2-1 before being passed to the ranking loss.
        """
        ref = self.var_ref
        self.d0 = self.forward_pair(ref, self.var_p0)
        self.d1 = self.forward_pair(ref, self.var_p1)
        self.acc_r = self.compute_accuracy(self.d0,self.d1,self.input_judge)

        judge = Variable(1.*self.input_judge).view(self.d0.size())
        self.var_judge = judge

        loss = self.rankLoss.forward(self.d0, self.d1, judge*2.-1.)
        self.loss_total = loss
        return loss

    def backward_train(self):
        """Backpropagate the mean of self.loss_total."""
        mean_loss = torch.mean(self.loss_total)
        mean_loss.backward()

    def compute_accuracy(self,d0,d1,judge):
        ''' Per-sample agreement between the distance ordering and the judgment.

        d0, d1 are Variables, judge is a Tensor. Where d1 < d0 the score is
        judge; elsewhere it is 1 - judge. Returns a flat numpy array.
        '''
        p1_closer = (d1<d0).cpu().data.numpy().flatten()
        human = judge.cpu().numpy().flatten()
        agree_when_p1 = p1_closer*human
        agree_when_p0 = (1-p1_closer)*(1-human)
        return agree_when_p1 + agree_when_p0

    def get_current_errors(self):
        """Return an OrderedDict with the mean of loss_total and the mean of acc_r."""
        raw = (('loss_total', self.loss_total.data.cpu().numpy()),
               ('acc_r', self.acc_r))
        return OrderedDict((key, np.mean(value)) for key, value in raw)

    def get_current_visuals(self):
        """Return the ref / p0 / p1 inputs as zoomed images in an OrderedDict.

        Each Variable is converted with util.tensor2im and then zoomed along its
        first two axes by 256 / var_ref.size()[2] with spline order 0.
        """
        scale = 256/self.var_ref.data.size()[2]
        zoom_spec = [scale, scale, 1]

        sources = (self.var_ref, self.var_p0, self.var_p1)
        images = [util.tensor2im(var.data) for var in sources]
        zoomed = [zoom(img,zoom_spec,order=0) for img in images]

        return OrderedDict(zip(('ref', 'p0', 'p1'), zoomed))

    def save(self, path, label):
        """Save the distance network and the ranking-loss network via self.save_network.

        path and label are forwarded unchanged; the second positional argument
        to save_network is '' for the distance net and 'rank' for the ranking net.
        """
        # NOTE(review): self.rankLoss is only created in __init__ when is_train is
        # set, so this presumably is meant for training mode only - confirm.
        self.save_network(self.net, path, '', label)
        self.save_network(self.rankLoss.net, path, 'rank', label)

    def update_learning_rate(self,nepoch_decay):
        """Lower the learning rate by one linear-decay step and apply it.

        The step size is self.lr / nepoch_decay (self.lr is the initial rate);
        it is subtracted from self.old_lr, written to every param group of
        self.optimizer_net, and stored back in self.old_lr.
        """
        lrd = self.lr / nepoch_decay
        lr = self.old_lr - lrd

        for param_group in self.optimizer_net.param_groups:
            param_group['lr'] = lr

        # Log the model name; the builtin `type` was interpolated here before,
        # which printed "<class 'type'>".
        print('update lr [%s] decay: %f -> %f' % (self.model_name,self.old_lr, lr))
        self.old_lr = lr



def score_2afc_dataset(data_loader,func):
    ''' Compute the Two Alternative Forced Choice (2AFC) score of distance
        function 'func' over the batches yielded by 'data_loader'
    INPUTS
        data_loader - object whose load_data() yields dicts with keys
            'ref', 'p0', 'p1' and 'judge'
        func - callable distance function - calling d=func(in0,in1) should take 2
            pytorch tensors with shape Nx3xXxY, and return numpy array of length N
    OUTPUTS
        [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators
        [1] - dictionary with following elements
            d0s,d1s - N arrays containing distances between reference patch to perturbed patches
            gts - N array in [0,1], preferred patch selected by human evaluators
                (closer to "0" for left patch p0, "1" for right patch p1,
                "0.6" means 60pct people preferred right patch, 40pct preferred left)
            scores - N array in [0,1], corresponding to what percentage function agreed with humans
    CONSTS
        N - number of test triplets in data_loader
    '''
    ref_to_p0 = []
    ref_to_p1 = []
    judgments = []

    for batch in data_loader.load_data():
        ref_to_p0.extend(func(batch['ref'],batch['p0']).tolist())
        ref_to_p1.extend(func(batch['ref'],batch['p1']).tolist())
        judgments.extend(batch['judge'].cpu().numpy().flatten().tolist())

    d0s = np.array(ref_to_p0)
    d1s = np.array(ref_to_p1)
    gts = np.array(judgments)

    # credit 1-gt when p0 is closer, gt when p1 is closer, and one half on ties
    p0_closer = d0s<d1s
    p1_closer = d1s<d0s
    tied = d1s==d0s
    scores = p0_closer*(1.-gts) + p1_closer*gts + tied*.5

    return(np.mean(scores), dict(d0s=d0s,d1s=d1s,gts=gts,scores=scores))

def score_jnd_dataset(data_loader,func):
    ''' Compute the JND score of distance function 'func' over the batches
        yielded by 'data_loader'
    INPUTS
        data_loader - object whose load_data() yields dicts with keys
            'p0', 'p1' and 'same'
        func - callable distance function - calling d=func(in0,in1) should take 2
            pytorch tensors with shape Nx3xXxY, and return numpy array of length N
    OUTPUTS
        [0] - JND score, computed by util.voc_ap from the recall/precision curve
            obtained by sweeping samples in order of increasing distance
        [1] - dictionary with following elements
            ds - N array containing distances between two patches shown to human evaluator
            sames - N array containing fraction of people who thought the two patches were identical
    CONSTS
        N - number of test samples in data_loader
    '''
    distances = []
    same_votes = []

    for batch in data_loader.load_data():
        distances.extend(func(batch['p0'],batch['p1']).tolist())
        same_votes.extend(batch['same'].cpu().numpy().flatten().tolist())

    sames = np.array(same_votes)
    ds = np.array(distances)

    # walk the samples from smallest to largest distance
    order = np.argsort(ds)
    sames_sorted = sames[order]

    TPs = np.cumsum(sames_sorted)
    FPs = np.cumsum(1-sames_sorted)
    FNs = np.sum(sames_sorted)-TPs

    precs = TPs/(TPs+FPs)
    recs = TPs/(TPs+FNs)

    return(util.voc_ap(recs,precs), dict(ds=ds,sames=sames))


================================================
FILE: models/networks.py
================================================
import torch
import torch.nn as nn
from torch.nn import init
import functools
from torch.optim import lr_scheduler


###############################################################################
# Helper Functions
###############################################################################


class Identity(nn.Module):
	"""Pass-through module: forward returns its input unchanged."""
	def forward(self, x):
		return x


def get_norm_layer(norm_type='instance'):
	"""Return a factory for the requested normalization layer

	Parameters:
		norm_type (str) -- the name of the normalization layer: batch | instance | none

	'batch' gives BatchNorm2d with learnable affine parameters and running statistics.
	'instance' gives InstanceNorm2d without affine parameters or running statistics.
	'none' gives a callable that ignores its argument and returns an Identity module.
	Any other name raises NotImplementedError.
	"""
	if norm_type == 'batch':
		return functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
	if norm_type == 'instance':
		return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
	if norm_type == 'none':
		return lambda x: Identity()
	raise NotImplementedError('normalization layer [%s] is not found' % norm_type)


def get_scheduler(optimizer, opt):
	"""Return a learning rate scheduler

	Parameters:
		optimizer          -- the optimizer of the network
		opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
		                      opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine

	For 'linear', we keep the same learning rate for the first <opt.niter> epochs
	and linearly decay the rate to zero over the next <opt.niter_decay> epochs.
	For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
	See https://pytorch.org/docs/stable/optim.html for more details.

	Raises:
		NotImplementedError -- if opt.lr_policy is not one of the supported names
	"""
	if opt.lr_policy == 'linear':
		def lambda_rule(epoch):
			# multiplier stays at 1.0 until epoch + epoch_count exceeds niter, then drops linearly
			lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
			return lr_l
		scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
	elif opt.lr_policy == 'step':
		scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
	elif opt.lr_policy == 'plateau':
		scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
	elif opt.lr_policy == 'cosine':
		scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)
	else:
		# Raise (not return) the error, with the policy name formatted into the message.
		raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
	return scheduler


def init_weights(net, init_type='normal', init_gain=0.02):
	"""Initialize network weights in place.

	Parameters:
		net (network)     -- network to be initialized
		init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
		init_gain (float) -- scaling factor for normal, xavier and orthogonal.

	Modules whose class name contains 'Conv' or 'Linear' get their weight drawn
	with the chosen method and any bias set to zero. BatchNorm2d modules get
	weights from N(1.0, init_gain) and a zero bias. An unknown init_type raises
	NotImplementedError once such a Conv/Linear module is reached.
	"""
	def init_func(m):  # applied to every submodule by net.apply
		cls_name = m.__class__.__name__
		conv_or_linear = 'Conv' in cls_name or 'Linear' in cls_name
		if hasattr(m, 'weight') and conv_or_linear:
			if init_type == 'normal':
				init.normal_(m.weight.data, 0.0, init_gain)
			elif init_type == 'xavier':
				init.xavier_normal_(m.weight.data, gain=init_gain)
			elif init_type == 'kaiming':
				init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
			elif init_type == 'orthogonal':
				init.orthogonal_(m.weight.data, gain=init_gain)
			else:
				raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
			bias = getattr(m, 'bias', None)
			if bias is not None:
				init.constant_(bias.data, 0.0)
		elif 'BatchNorm2d' in cls_name:
			# BatchNorm weight is a vector, so only the normal distribution is used here
			init.normal_(m.weight.data, 1.0, init_gain)
			init.constant_(m.bias.data, 0.0)

	print('initialize network with %s' % init_type)
	net.apply(init_func)


def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
	"""Place a network on its device(s), then initialize its weights.

	Parameters:
		net (network)      -- the network to be initialized
		init_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal
		init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
		gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2

	With a non-empty gpu_ids the network is moved to gpu_ids[0] and wrapped in
	DataParallel before initialization. Returns the initialized network (the
	DataParallel wrapper in that case).
	"""
	wants_gpu = len(gpu_ids) > 0
	if wants_gpu:
		assert(torch.cuda.is_available())
		net.to(gpu_ids[0])
		net = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs
	init_weights(net, init_type, init_gain=init_gain)
	return net


def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[], model0_res=0, model1_res=0, extra_channel=3):
	"""Create a generator

	Parameters:
		input_nc (int) -- the number of channels in input images
		output_nc (int) -- the number of channels in output images
		ngf (int) -- the number of filters in the last conv layer
		netG (str) -- the architecture's name: resnet_9blocks | resnet_style2_9blocks | resnet_6blocks | unet_256 | unet_128
		norm (str) -- the name of normalization layers used in the network: batch | instance | none
		use_dropout (bool) -- if use dropout layers.
		init_type (str)    -- the name of our initialization method.
		init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
		gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
		model0_res (int)   -- forwarded to ResnetStyle2Generator; effective when netG=='resnet_style2_9blocks'
		model1_res (int)   -- accepted but not used by this function
		extra_channel (int) -- forwarded to ResnetStyle2Generator; effective when netG=='resnet_style2_9blocks'

	Returns a generator; raises NotImplementedError for an unknown netG.

	The available generator families are:
		U-Net: [unet_128] (for 128x128 input images) and [unet_256] (for 256x256 input images)
		The original U-Net paper: https://arxiv.org/abs/1505.04597

		Resnet-based generator: [resnet_6blocks] (with 6 Resnet blocks) and [resnet_9blocks] (with 9 Resnet blocks)
		Resnet-based generator consists of several Resnet blocks between a few downsampling/upsampling operations.
		We adapt Torch code from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style).

	The generator has been initialized by <init_net>. It uses RELU for non-linearity.
	"""
	norm_layer = get_norm_layer(norm_type=norm)

	if netG in ('resnet_9blocks', 'resnet_6blocks'):
		block_count = 9 if netG == 'resnet_9blocks' else 6
		net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=block_count)
	elif netG == 'resnet_style2_9blocks':
		net = ResnetStyle2Generator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9, model0_res=model0_res, extra_channel=extra_channel)
	elif netG in ('unet_128', 'unet_256'):
		# third positional argument: 7 for unet_128, 8 for unet_256
		num_downs = 7 if netG == 'unet_128' else 8
		net = UnetGenerator(input_nc, output_nc, num_downs, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
	else:
		raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
	return init_net(net, init_type, init_gain, gpu_ids)


def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[], n_class=3):
	"""Create a discriminator

	Parameters:
		input_nc (int)     -- the number of channels in input images
		ndf (int)          -- the number of filters in the first conv layer
		netD (str)         -- the architecture's name: basic | basic_cls | n_layers | pixel
		n_layers_D (int)   -- the number of conv layers in the discriminator; effective when netD=='n_layers'
		norm (str)         -- the type of normalization layers used in the network.
		init_type (str)    -- the name of the initialization method.
		init_gain (float)  -- scaling factor for normal, xavier and orthogonal.
		gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
		n_class (int)      -- passed as n_class to NLayerDiscriminatorCls; effective when netD=='basic_cls'

	Returns a discriminator; raises NotImplementedError for an unknown netD.

	The available discriminators are:
		[basic]: 'PatchGAN' classifier described in the original pix2pix paper.
		It can classify whether 70x70 overlapping patches are real or fake.
		Such a patch-level discriminator architecture has fewer parameters
		than a full-image discriminator and can work on arbitrarily-sized images
		in a fully convolutional fashion.

		[basic_cls]: built by NLayerDiscriminatorCls with 3 layers and <n_class> classes.

		[n_layers]: With this mode, you can specify the number of conv layers in the discriminator
		with the parameter <n_layers_D> (default=3 as used in [basic] (PatchGAN).)

		[pixel]: 1x1 PixelGAN discriminator can classify whether a pixel is real or not.
		It encourages greater color diversity but has no effect on spatial statistics.

	The discriminator has been initialized by <init_net>. It uses Leaky RELU for non-linearity.
	"""
	norm_layer = get_norm_layer(norm_type=norm)

	if netD == 'basic':  # default PatchGAN classifier
		net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer)
	elif netD == 'basic_cls':
		# honor the caller's n_class (default 3) instead of a hard-coded 3
		net = NLayerDiscriminatorCls(input_nc, ndf, n_layers=3, n_class=n_class, norm_layer=norm_layer)
	elif netD == 'n_layers':  # more options
		net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer)
	elif netD == 'pixel':     # classify if each pixel is real or fake
		net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
	else:
		# report the requested name; the local `net` is still unset at this point
		raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)
	return init_net(net, init_type, init_gain, gpu_ids)


def define_HED(init_weights_, gpu_ids_=[]):
	"""Create an HED network, optionally on GPU(s), and optionally load weights.

	Parameters:
		init_weights_ (str or None) -- checkpoint path passed to torch.load; None skips loading
		gpu_ids_ (int list)         -- GPU ids; when non-empty the network is moved to gpu_ids_[0]
		                               and wrapped in DataParallel

	Returns the network (the DataParallel wrapper when gpu_ids_ is non-empty).
	"""
	net = HED()

	if len(gpu_ids_) > 0:
		assert(torch.cuda.is_available())
		net.to(gpu_ids_[0])
		net = torch.nn.DataParallel(net, gpu_ids_)  # multi-GPUs

	if init_weights_ is not None:
		device = torch.device('cuda:{}'.format(gpu_ids_[0])) if gpu_ids_ else torch.device('cpu')
		print('Loading model from: %s'%init_weights_)
		state_dict = torch.load(init_weights_, map_location=str(device))
		# The state dict is loaded into the bare module, i.e. net.module under DataParallel.
		target = net.module if isinstance(net, torch.nn.DataParallel) else net
		target.load_state_dict(state_dict)
		print('load the weights successfully')

	return net

##############################################################################
# Classes
##############################################################################
class GANLoss(nn.Module):
	"""GAN objectives behind a single callable.

	The class builds the target label tensor itself, so callers only pass the
	discriminator prediction and whether the target is "real".
	"""

	def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
		""" Initialize the GANLoss class.

		Parameters:
			gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.
			target_real_label (float) - - label value for a real image
			target_fake_label (float) - - label value for a fake image

		Raises NotImplementedError for any other gan_mode.

		Note: Do not use sigmoid as the last layer of Discriminator.
		LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
		"""
		super(GANLoss, self).__init__()
		self.register_buffer('real_label', torch.tensor(target_real_label))
		self.register_buffer('fake_label', torch.tensor(target_fake_label))
		self.gan_mode = gan_mode
		if gan_mode == 'lsgan':
			criterion = nn.MSELoss()
		elif gan_mode == 'vanilla':
			criterion = nn.BCEWithLogitsLoss()
		elif gan_mode == 'wgangp':
			# wgangp needs no criterion: the loss is the signed mean of the prediction
			criterion = None
		else:
			raise NotImplementedError('gan mode %s not implemented' % gan_mode)
		self.loss = criterion

	def get_target_tensor(self, prediction, target_is_real):
		"""Create a label tensor with the same size as the input.

		Parameters:
			prediction (tensor) - - typically the prediction from a discriminator
			target_is_real (bool) - - if the ground truth label is for real images or fake images

		Returns:
			The real or fake label buffer expanded to the size of prediction
		"""
		label = self.real_label if target_is_real else self.fake_label
		return label.expand_as(prediction)

	def __call__(self, prediction, target_is_real):
		"""Calculate loss given Discriminator's output and ground truth labels.

		Parameters:
			prediction (tensor) - - typically the prediction output from a discriminator
			target_is_real (bool) - - if the ground truth label is for real images or fake images

		Returns:
			the calculated loss.
		"""
		mode = self.gan_mode
		if mode in ['lsgan', 'vanilla']:
			target = self.get_target_tensor(prediction, target_is_real)
			loss = self.loss(prediction, target)
		elif mode == 'wgangp':
			loss = -prediction.mean() if target_is_real else prediction.mean()
		return loss


def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):
	"""Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028

	Arguments:
		netD (network)              -- discriminator network
		real_data (tensor array)    -- real images
		fake_data (tensor array)    -- generated images from the generator
		device (str)                -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
		type (str)                  -- if we mix real and fake data or not [real | fake | mixed].
		constant (float)            -- the constant used in formula ( | |gradient||_2 - constant)^2
		lambda_gp (float)           -- weight for this loss

	Returns (gradient penalty, flattened gradients), or (0.0, None) when
	lambda_gp is not greater than zero. An unknown type raises NotImplementedError.
	Note that requires_grad_(True) is applied to the evaluated tensor, which for
	'real' / 'fake' is the caller's own tensor.
	"""
	if not lambda_gp > 0.0:
		return 0.0, None

	# Pick the points at which the discriminator gradient is evaluated.
	if type == 'real':
		samples = real_data
	elif type == 'fake':
		samples = fake_data
	elif type == 'mixed':
		# one random mixing weight per sample, broadcast to the full sample shape
		batch = real_data.shape[0]
		alpha = torch.rand(batch, 1, device=device)
		alpha = alpha.expand(batch, real_data.nelement() // batch).contiguous().view(*real_data.shape)
		samples = alpha * real_data + ((1 - alpha) * fake_data)
	else:
		raise NotImplementedError('{} not implemented'.format(type))

	samples.requires_grad_(True)
	critic_out = netD(samples)
	grads = torch.autograd.grad(outputs=critic_out, inputs=samples,
		grad_outputs=torch.ones(critic_out.size()).to(device),
		create_graph=True, retain_graph=True, only_inputs=True)
	flat = grads[0].view(real_data.size(0), -1)  # one row per sample
	# the small eps is added before taking the norm
	penalty = (((flat + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp
	return penalty, flat


class ResnetGenerator(nn.Module):
	"""Resnet-based generator: Resnet blocks placed between downsampling and upsampling stages.

	We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
	"""

	def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
		"""Construct a Resnet-based generator

		Parameters:
			input_nc (int)      -- the number of channels in input images
			output_nc (int)     -- the number of channels in output images
			ngf (int)           -- the number of filters in the last conv layer
			norm_layer          -- normalization layer
			use_dropout (bool)  -- if use dropout layers
			n_blocks (int)      -- the number of ResNet blocks
			padding_type (str)  -- the name of padding layer in conv layers: reflect | replicate | zero
		"""
		assert(n_blocks >= 0)
		super(ResnetGenerator, self).__init__()
		# Convolutions get a bias term only when the normalization is InstanceNorm2d.
		norm_cls = norm_layer.func if type(norm_layer) == functools.partial else norm_layer
		use_bias = norm_cls == nn.InstanceNorm2d

		n_downsampling = 2

		# stem: 7x7 convolution on a reflection-padded input
		layers = [nn.ReflectionPad2d(3),
			nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
			norm_layer(ngf),
			nn.ReLU(True)]

		# downsampling: each stride-2 convolution doubles the channel count
		for stage in range(n_downsampling):
			in_ch = ngf * 2 ** stage
			layers.extend([
				nn.Conv2d(in_ch, in_ch * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
				norm_layer(in_ch * 2),
				nn.ReLU(True)])

		# ResNet blocks at the widest channel count
		bottleneck_ch = ngf * 2 ** n_downsampling
		for _ in range(n_blocks):
			layers.append(ResnetBlock(bottleneck_ch, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias))

		# upsampling: each transposed convolution halves the channel count
		for stage in range(n_downsampling):
			in_ch = ngf * 2 ** (n_downsampling - stage)
			out_ch = int(in_ch / 2)
			layers.extend([
				nn.ConvTranspose2d(in_ch, out_ch,
					kernel_size=3, stride=2,
					padding=1, output_padding=1,
					bias=use_bias),
				norm_layer(out_ch),
				nn.ReLU(True)])

		# head: 7x7 convolution to output_nc channels followed by Tanh
		layers.extend([nn.ReflectionPad2d(3),
			nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
			nn.Tanh()])

		self.model = nn.Sequential(*layers)

	def forward(self, input):
		"""Standard forward"""
		return self.model(input)

class ResnetStyle2Generator(nn.Module):
	"""Resnet-based generator that takes a second input at the bottleneck

	self.model0 encodes input1; its output is concatenated with input2 along
	the channel axis and the result is decoded by self.model.
	"""

	def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', extra_channel=3, model0_res=0):
		"""Construct the encoder (model0) and the decoder (model)

		Parameters:
			input_nc (int)      -- channels of the first input
			output_nc (int)     -- channels of the output images
			ngf (int)           -- width of the stem conv; doubled by each downsampling conv
			norm_layer          -- normalization layer class, or a functools.partial wrapping one
			use_dropout (bool)  -- forwarded to every ResnetBlock
			n_blocks (int)      -- total number of ResnetBlocks (asserted to be >= 0)
			padding_type (str)  -- padding used inside the ResnetBlocks: reflect | replicate | zero
			extra_channel (int) -- channels of the second input joined to the encoder output
			model0_res (int)    -- how many ResnetBlocks go into the encoder; the decoder
								   gets n_blocks - model0_res of them
		"""
		assert(n_blocks >= 0)
		super(ResnetStyle2Generator, self).__init__()
		# Convs carry a bias only when the normalization layer is InstanceNorm2d.
		norm_cls = norm_layer.func if type(norm_layer) == functools.partial else norm_layer
		use_bias = norm_cls == nn.InstanceNorm2d

		n_downsampling = 2
		bottleneck = ngf * 2 ** n_downsampling

		def res_block():
			"""One ResnetBlock at bottleneck width."""
			return ResnetBlock(bottleneck, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)

		# ---- encoder: stem, downsampling, optional residual blocks ----
		encoder = [nn.ReflectionPad2d(3)]
		encoder.append(nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias))
		encoder.append(norm_layer(ngf))
		encoder.append(nn.ReLU(True))
		for step in range(n_downsampling):
			width = ngf * 2 ** step
			encoder.append(nn.Conv2d(width, width * 2, kernel_size=3, stride=2, padding=1, bias=use_bias))
			encoder.append(norm_layer(width * 2))
			encoder.append(nn.ReLU(True))
		for _ in range(model0_res):
			encoder.append(res_block())

		# ---- decoder: fuse the extra channels, remaining blocks, upsampling, head ----
		decoder = [nn.Conv2d(bottleneck + extra_channel, bottleneck, kernel_size=3, stride=1, padding=1, bias=use_bias)]
		decoder.append(norm_layer(bottleneck))
		decoder.append(nn.ReLU(True))
		for _ in range(n_blocks-model0_res):
			decoder.append(res_block())
		for step in range(n_downsampling):
			width = ngf * 2 ** (n_downsampling - step)
			half_width = int(width / 2)
			decoder.append(nn.ConvTranspose2d(width, half_width,
											  kernel_size=3, stride=2,
											  padding=1, output_padding=1,
											  bias=use_bias))
			decoder.append(norm_layer(half_width))
			decoder.append(nn.ReLU(True))
		decoder.extend([nn.ReflectionPad2d(3),
						nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
						nn.Tanh()])

		self.model0 = nn.Sequential(*encoder)
		self.model = nn.Sequential(*decoder)

	def forward(self, input1, input2):
		"""Encode input1, append input2 on the channel axis and decode

		input2 has to agree with the encoder output in every dimension except
		dim 1, because the two are joined with torch.cat along dim 1.
		"""
		encoded = self.model0(input1)
		fused = torch.cat([encoded, input2], 1)
		return self.model(fused)

class ResnetBlock(nn.Module):
	"""Residual block: two conv + norm stages wrapped by an identity shortcut"""

	def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias, kernel=3):
		"""Create the block and store its conv stack in self.conv_block

		All arguments are handed unchanged to build_conv_block; the shortcut
		itself is added in <forward>.
		Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
		"""
		super(ResnetBlock, self).__init__()
		self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias, kernel)

	def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias, kernel=3):
		"""Build the residual branch as an nn.Sequential.

		Parameters:
			dim (int)           -- channels going in and out of both convs
			padding_type (str)  -- reflect | replicate (separate padding layer) or zero (padding done by the conv)
			norm_layer          -- normalization layer
			use_dropout (bool)  -- insert Dropout(0.5) after the first ReLU
			use_bias (bool)     -- whether the convs have a bias
			kernel (int)        -- conv kernel size; the padding is int((kernel-1)/2)

		Returns [pad,] conv, norm, ReLU, [dropout,] [pad,] conv, norm.
		Raises NotImplementedError for any other padding_type.
		"""
		pad = int((kernel-1)/2)

		def padded_conv():
			"""Return the layers for one padded conv (padding layer first, if any)."""
			if padding_type == 'reflect':#by default
				return [nn.ReflectionPad2d(pad),
						nn.Conv2d(dim, dim, kernel_size=kernel, padding=0, bias=use_bias)]
			if padding_type == 'replicate':
				return [nn.ReplicationPad2d(pad),
						nn.Conv2d(dim, dim, kernel_size=kernel, padding=0, bias=use_bias)]
			if padding_type == 'zero':
				return [nn.Conv2d(dim, dim, kernel_size=kernel, padding=pad, bias=use_bias)]
			raise NotImplementedError('padding [%s] is not implemented' % padding_type)

		stages = padded_conv()
		stages.append(norm_layer(dim))
		stages.append(nn.ReLU(True))
		if use_dropout:
			stages.append(nn.Dropout(0.5))
		stages.extend(padded_conv())
		stages.append(norm_layer(dim))

		return nn.Sequential(*stages)

	def forward(self, x):
		"""Return the input plus the output of the residual branch"""
		residual = self.conv_block(x)
		return x + residual


class UnetGenerator(nn.Module):
	"""U-Net generator assembled from nested UnetSkipConnectionBlock modules"""

	def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
		"""Build the U-Net from the innermost block outwards

		Parameters:
			input_nc (int)     -- channels of the input images
			output_nc (int)    -- channels of the output images
			num_downs (int)    -- number of downsamplings; e.g. with num_downs == 7 a
								  128x128 image is reduced to 1x1 at the bottleneck
			ngf (int)          -- filters of the outermost conv; inner levels use 2x, 4x and 8x
			norm_layer         -- normalization layer
			use_dropout (bool) -- applied only to the intermediate ngf * 8 blocks

		Five blocks are always created (innermost, three width-reducing blocks
		and the outermost one); num_downs - 5 extra ngf * 8 blocks are inserted
		next to the innermost block.
		"""
		super(UnetGenerator, self).__init__()
		widest = ngf * 8

		# innermost level: no submodule
		block = UnetSkipConnectionBlock(widest, widest, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)

		# intermediate levels that keep the widest filter count
		for _ in range(num_downs - 5):
			block = UnetSkipConnectionBlock(widest, widest, input_nc=None, submodule=block, norm_layer=norm_layer, use_dropout=use_dropout)

		# levels that step the filter count down: 8x -> 4x -> 2x -> 1x ngf
		for outer_mult in (4, 2, 1):
			block = UnetSkipConnectionBlock(ngf * outer_mult, ngf * outer_mult * 2, input_nc=None, submodule=block, norm_layer=norm_layer)

		# outermost level: maps input_nc in and output_nc out
		self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=block, outermost=True, norm_layer=norm_layer)

	def forward(self, input):
		"""Run the input through the nested blocks"""
		output = self.model(input)
		return output


class UnetSkipConnectionBlock(nn.Module):
	"""One U-Net level: downsample, run the nested submodule, upsample.
		X -------------------identity----------------------
		|-- downsampling -- |submodule| -- upsampling --|
	"""

	def __init__(self, outer_nc, inner_nc, input_nc=None,
				 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
		"""Build one level of the U-Net.

		Parameters:
			outer_nc (int) -- filters produced by the upsampling conv of this level
			inner_nc (int) -- filters produced by the downsampling conv of this level
			input_nc (int) -- channels entering this level; defaults to outer_nc
			submodule (UnetSkipConnectionBlock) -- the next inner level, built beforehand
			outermost (bool)    -- outermost level: no norm, Tanh output, no skip concat
			innermost (bool)    -- innermost level: no submodule
			norm_layer          -- normalization layer
			use_dropout (bool)  -- append Dropout(0.5); only honoured by intermediate levels
		"""
		super(UnetSkipConnectionBlock, self).__init__()
		self.outermost = outermost
		# Convs carry a bias only when the normalization layer is InstanceNorm2d.
		norm_cls = norm_layer.func if type(norm_layer) == functools.partial else norm_layer
		use_bias = norm_cls == nn.InstanceNorm2d
		if input_nc is None:
			input_nc = outer_nc

		# Shared pieces are created up front, whichever branch ends up using them.
		conv_down = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
							  stride=2, padding=1, bias=use_bias)
		act_down = nn.LeakyReLU(0.2, True)
		norm_down = norm_layer(inner_nc)
		act_up = nn.ReLU(True)
		norm_up = norm_layer(outer_nc)

		if outermost:
			# submodule output is [x, model(x)] concatenated, hence inner_nc * 2 inputs
			conv_up = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
										 kernel_size=4, stride=2,
										 padding=1)
			layers = [conv_down, submodule, act_up, conv_up, nn.Tanh()]
		elif innermost:
			# nothing nested, so the upsampling conv sees inner_nc channels only
			conv_up = nn.ConvTranspose2d(inner_nc, outer_nc,
										 kernel_size=4, stride=2,
										 padding=1, bias=use_bias)
			layers = [act_down, conv_down, act_up, conv_up, norm_up]
		else:
			conv_up = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
										 kernel_size=4, stride=2,
										 padding=1, bias=use_bias)
			layers = [act_down, conv_down, norm_down, submodule, act_up, conv_up, norm_up]
			if use_dropout:
				layers.append(nn.Dropout(0.5))

		self.model = nn.Sequential(*layers)

	def forward(self, x):
		"""Return model(x) for the outermost level, otherwise x concatenated with model(x) on dim 1"""
		result = self.model(x)
		if self.outermost:
			return result
		# add skip connections
		return torch.cat([x, result], 1)


class NLayerDiscriminator(nn.Module):
	"""PatchGAN discriminator: a conv stack ending in a 1-channel prediction map"""

	def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
		"""Build the conv stack and store it in self.model

		Parameters:
			input_nc (int)  -- channels of the input images
			ndf (int)       -- filters of the first conv; later convs use up to 8x as many
			n_layers (int)  -- number of stride-2 convs (the first one included)
			norm_layer      -- normalization layer class, or a functools.partial wrapping one
		"""
		super(NLayerDiscriminator, self).__init__()
		# no need to use bias as BatchNorm2d has affine parameters
		norm_cls = norm_layer.func if type(norm_layer) == functools.partial else norm_layer
		use_bias = norm_cls != nn.BatchNorm2d

		kw = 4
		padw = 1

		def stage(in_mult, out_mult, stride):
			"""conv -> norm -> LeakyReLU going from ndf * in_mult to ndf * out_mult channels."""
			return [
				nn.Conv2d(ndf * in_mult, ndf * out_mult, kernel_size=kw, stride=stride, padding=padw, bias=use_bias),
				norm_layer(ndf * out_mult),
				nn.LeakyReLU(0.2, True)
			]

		layers = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]

		# stride-2 stages: the channel multiplier doubles each time, capped at 8
		mult = 1
		for n in range(1, n_layers):
			prev_mult, mult = mult, min(2 ** n, 8)
			layers.extend(stage(prev_mult, mult, 2))

		# one more stage at stride 1
		prev_mult, mult = mult, min(2 ** n_layers, 8)
		layers.extend(stage(prev_mult, mult, 1))

		# output 1 channel prediction map
		layers.append(nn.Conv2d(ndf * mult, 1, kernel_size=kw, stride=1, padding=padw))
		self.model = nn.Sequential(*layers)

	def forward(self, input):
		"""Return the patch prediction map for the input batch."""
		prediction = self.model(input)
		return prediction


class NLayerDiscriminatorCls(nn.Module):
	"""Defines a PatchGAN discriminator with an additional classification head

	A shared trunk (model0) feeds two heads: model1 produces the 1-channel
	patch prediction map, model2 produces n_class values per image.
	"""

	def __init__(self, input_nc, ndf=64, n_layers=3, n_class=3, norm_layer=nn.BatchNorm2d):
		"""Construct the shared trunk and the two heads

		Parameters:
			input_nc (int)  -- the number of channels in input images
			ndf (int)       -- the number of filters in the first conv layer
			n_layers (int)  -- the number of stride-2 conv layers in the shared trunk
			n_class (int)   -- the number of output channels of the classification head
			norm_layer      -- normalization layer

		The classification head applies two further stride-2 convs and then a
		16x16 conv without padding, so it only yields one n_class vector per
		image when the map entering that conv is 16x16 (with n_layers=3 that
		is a 512x512 input).
		"""
		super(NLayerDiscriminatorCls, self).__init__()
		if isinstance(norm_layer, functools.partial):  # no need to use bias as BatchNorm2d has affine parameters
			use_bias = norm_layer.func != nn.BatchNorm2d
		else:
			use_bias = norm_layer != nn.BatchNorm2d

		kw = 4
		padw = 1
		sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
		nf_mult = 1
		nf_mult_prev = 1
		for n in range(1, n_layers):  # gradually increase the number of filters
			nf_mult_prev = nf_mult
			nf_mult = min(2 ** n, 8)
			sequence += [
				nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
				norm_layer(ndf * nf_mult),
				nn.LeakyReLU(0.2, True)
			]

		nf_mult_prev = nf_mult
		nf_mult = min(2 ** n_layers, 8)

		# head 1: patch prediction map (stride-1 convs only)
		sequence1 = [
			nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
			norm_layer(ndf * nf_mult),
			nn.LeakyReLU(0.2, True)
		]
		sequence1 += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map

		# head 2: classification (two stride-2 convs, then one 16x16 conv)
		sequence2 = [
			nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
			norm_layer(ndf * nf_mult),
			nn.LeakyReLU(0.2, True)
		]
		sequence2 += [
			nn.Conv2d(ndf * nf_mult, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
			norm_layer(ndf * nf_mult),
			nn.LeakyReLU(0.2, True)
		]
		# NOTE: this last conv also follows use_bias although no norm layer comes
		# after it; left as is so that existing checkpoints keep loading.
		sequence2 += [
			nn.Conv2d(ndf * nf_mult, n_class, kernel_size=16, stride=1, padding=0, bias=use_bias)]

		self.model0 = nn.Sequential(*sequence)
		self.model1 = nn.Sequential(*sequence1)
		self.model2 = nn.Sequential(*sequence2)
		# The unconditional debug dump of every submodule that used to be printed
		# here was removed; print the instance explicitly when it is needed.

	def forward(self, input):
		"""Return (patch map, class output flattened to (batch, -1))."""
		feat = self.model0(input)
		# patchGAN output (1 * 62 * 62) for a 512x512 input with n_layers=3
		patch = self.model1(feat)
		# class output (n_class * 1 * 1) when the map entering the last conv is 16x16
		classl = self.model2(feat)
		return patch, classl.view(classl.size(0), -1)


class PixelDiscriminator(nn.Module):
	"""Defines a 1x1 PatchGAN discriminator (pixelGAN)"""

	def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
		"""Construct a 1x1 PatchGAN discriminator

		Parameters:
			input_nc (int)  -- the number of channels in input images
			ndf (int)       -- the number of filters in the first conv layer
			norm_layer      -- normalization layer

		The bias test used to be inverted here (bias with BatchNorm2d, none with
		InstanceNorm2d). It now follows the rule of the other discriminators in
		this file. Checkpoints saved with the old layout have different bias
		entries and need load_state_dict(..., strict=False).
		"""
		super(PixelDiscriminator, self).__init__()
		if isinstance(norm_layer, functools.partial):  # no need to use bias as BatchNorm2d has affine parameters
			use_bias = norm_layer.func != nn.BatchNorm2d
		else:
			use_bias = norm_layer != nn.BatchNorm2d

		# every conv is 1x1, so each output pixel depends on one input pixel only
		layers = [
			nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
			nn.LeakyReLU(0.2, True),
			nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
			norm_layer(ndf * 2),
			nn.LeakyReLU(0.2, True),
			nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]

		self.net = nn.Sequential(*layers)

	def forward(self, input):
		"""Standard forward."""
		return self.net(input)


class HED(nn.Module):
	"""Five VGG-style conv stages with a 1x1 score conv each, fused into one sigmoid map"""

	def __init__(self):
		"""Create the five conv stages, the five score convs and the fusion layer"""
		super(HED, self).__init__()

		def vgg_stage(in_channels, out_channels, n_convs, pool):
			"""[MaxPool,] then n_convs x (3x3 conv, ReLU); only the first conv changes the width."""
			layers = [nn.MaxPool2d(kernel_size=2, stride=2)] if pool else []
			for idx in range(n_convs):
				conv_in = in_channels if idx == 0 else out_channels
				layers.append(nn.Conv2d(in_channels=conv_in, out_channels=out_channels, kernel_size=3, stride=1, padding=1))
				layers.append(nn.ReLU(inplace=False))
			return nn.Sequential(*layers)

		def score_head(in_channels):
			"""1x1 conv reducing a stage output to a single channel."""
			return nn.Conv2d(in_channels=in_channels, out_channels=1, kernel_size=1, stride=1, padding=0)

		self.moduleVggOne = vgg_stage(3, 64, 2, pool=False)
		self.moduleVggTwo = vgg_stage(64, 128, 2, pool=True)
		self.moduleVggThr = vgg_stage(128, 256, 3, pool=True)
		self.moduleVggFou = vgg_stage(256, 512, 3, pool=True)
		self.moduleVggFiv = vgg_stage(512, 512, 3, pool=True)

		self.moduleScoreOne = score_head(64)
		self.moduleScoreTwo = score_head(128)
		self.moduleScoreThr = score_head(256)
		self.moduleScoreFou = score_head(512)
		self.moduleScoreFiv = score_head(512)

		self.moduleCombine = nn.Sequential(
			nn.Conv2d(in_channels=5, out_channels=1, kernel_size=1, stride=1, padding=0),
			nn.Sigmoid()
		)

	def forward(self, tensorInput):
		"""Return a 1-channel map in [0, 1] with the spatial size of the input

		Channels 2, 1, 0 of the input are taken in that order, multiplied by 255
		and shifted by a per-channel constant before entering the first stage.
		NOTE(review): this looks like an RGB image in [0, 1] being converted to
		mean-subtracted BGR -- confirm against the callers.
		"""
		channel_means = ((2, 104.00698793), (1, 116.66876762), (0, 122.67891434))
		planes = [(tensorInput[:, ch:ch + 1, :, :] * 255.0) - mean for ch, mean in channel_means]
		tensorInput = torch.cat(planes, 1)

		out_size = (tensorInput.size(2), tensorInput.size(3))
		stages = (self.moduleVggOne, self.moduleVggTwo, self.moduleVggThr, self.moduleVggFou, self.moduleVggFiv)
		heads = (self.moduleScoreOne, self.moduleScoreTwo, self.moduleScoreThr, self.moduleScoreFou, self.moduleScoreFiv)

		features = tensorInput
		scores = []
		for stage, head in zip(stages, heads):
			features = stage(features)
			# every score map is brought back to the input resolution before fusion
			scores.append(nn.functional.interpolate(input=head(features), size=out_size, mode='bilinear', align_corners=False))

		return self.moduleCombine(torch.cat(scores, 1))



================================================
FILE: models/networks_basic.py
================================================

from __future__ import absolute_import

import sys
import torch
import torch.nn as nn
import torch.nn.init as init
from torch.autograd import Variable
import numpy as np
from pdb import set_trace as st
from skimage import color
from IPython import embed
from . import pretrained_networks as pn

from util import util

def spatial_average(in_tens, keepdim=True):
    """Return the mean of in_tens over dims 2 and 3; keepdim is passed on to Tensor.mean."""
    spatial_dims = [2, 3]
    return in_tens.mean(spatial_dims, keepdim=keepdim)

def upsample(in_tens, out_H=64): # assumes scale factor is same for H and W
    """Bilinearly rescale in_tens by out_H / in_tens.shape[2] (one factor for both spatial dims)."""
    factor = 1. * out_H / in_tens.shape[2]
    resize = nn.Upsample(scale_factor=factor, mode='bilinear', align_corners=False)
    return resize(in_tens)

# Learned perceptual metric
class PNetLin(nn.Module):
    """Distance between two image batches computed from backbone features.

    Both inputs go through the same backbone (self.net). For each returned
    layer the features are normalized with util.normalize_tensor and
    squared-differenced, each difference is reduced to one channel, and the
    per-layer results are added up.
    """
    def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, spatial=False, version='0.1', lpips=True):
        """Create the backbone and, when lpips is set, one NetLinLayer per feature layer.

        Parameters:
            pnet_type (str)    -- backbone: 'vgg' | 'vgg16' | 'alex' | 'squeeze'
            pnet_rand (bool)   -- if True the backbone is created with pretrained=False
            pnet_tune (bool)   -- passed to the backbone as requires_grad
            use_dropout (bool) -- forwarded to every NetLinLayer
            spatial (bool)     -- if True forward returns maps upsampled to the input
                                  height instead of spatial averages
            version (str)      -- '0.1' runs the inputs through ScalingLayer first;
                                  any other value feeds them unchanged
            lpips (bool)       -- if True the differences go through the NetLinLayer
                                  1x1 convs, otherwise they are summed over channels

        Raises:
            ValueError -- if pnet_type is not one of the names listed above
        """
        super(PNetLin, self).__init__()

        self.pnet_type = pnet_type
        self.pnet_tune = pnet_tune
        self.pnet_rand = pnet_rand
        self.spatial = spatial
        self.lpips = lpips
        self.version = version
        self.scaling_layer = ScalingLayer()

        if(self.pnet_type in ['vgg','vgg16']):
            net_type = pn.vgg16
            self.chns = [64,128,256,512,512]
        elif(self.pnet_type=='alex'):
            net_type = pn.alexnet
            self.chns = [64,192,384,256,256]
        elif(self.pnet_type=='squeeze'):
            net_type = pn.squeezenet
            self.chns = [64,128,256,384,384,512,512]
        else:
            # Fail here with a clear message rather than a few lines later with an
            # AttributeError on self.chns.
            raise ValueError('unsupported pnet_type [%s]' % self.pnet_type)
        self.L = len(self.chns)

        self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)

        if(lpips):
            self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
            self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
            self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
            self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
            self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
            self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4]
            if(self.pnet_type=='squeeze'): # 7 layers for squeezenet
                self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
                self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
                self.lins+=[self.lin5,self.lin6]

    def forward(self, in0, in1, retPerLayer=False):
        """Return the summed distance, or (sum, per-layer list) when retPerLayer is True.

        Parameters:
            in0, in1            -- the two batches to compare
            retPerLayer (bool)  -- also return the list of per-layer results
        """
        # v0.0 - original release had a bug, where input was not scaled
        in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1)
        outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
        feats0, feats1, diffs = {}, {}, {}

        for kk in range(self.L):
            feats0[kk], feats1[kk] = util.normalize_tensor(outs0[kk]), util.normalize_tensor(outs1[kk])
            diffs[kk] = (feats0[kk]-feats1[kk])**2

        if(self.lpips):
            if(self.spatial):
                res = [upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2]) for kk in range(self.L)]
            else:
                res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)]
        else:
            if(self.spatial):
                res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_H=in0.shape[2]) for kk in range(self.L)]
            else:
                res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)]

        # Accumulate out of place: an in-place "+=" would write the running total
        # into res[0], so the per-layer list returned below would no longer hold
        # the layer-0 value.
        val = res[0]
        for l in range(1,self.L):
            val = val + res[l]

        if(retPerLayer):
            return (val, res)
        else:
            return val

class ScalingLayer(nn.Module):
    """Shift and scale dim 1 of the input with fixed per-channel constants (three channels)."""
    def __init__(self):
        super(ScalingLayer, self).__init__()
        # Registered as buffers: part of state_dict, not trainable parameters.
        constants = (
            ('shift', [-.030,-.088,-.188]),
            ('scale', [.458,.448,.450]),
        )
        for name, values in constants:
            self.register_buffer(name, torch.Tensor(values)[None,:,None,None])

    def forward(self, inp):
        """Return (inp - shift) / scale with both buffers moved to inp's device."""
        shift = self.shift.to(inp.device)
        scale = self.scale.to(inp.device)
        return (inp - shift) / scale


class NetLinLayer(nn.Module):
    ''' A single linear layer which does a 1x1 conv '''
    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        '''Build self.model: an optional nn.Dropout followed by a bias-free 1x1 conv.

        Parameters:
            chn_in (int)       -- input channels of the conv
            chn_out (int)      -- output channels of the conv
            use_dropout (bool) -- put nn.Dropout() in front of the conv
        '''
        super(NetLinLayer, self).__init__()

        layers = [nn.Dropout(),] if(use_dropout) else []
        layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        '''Apply self.model, so the layer can be called directly instead of only through .model.'''
        return self.model(x)


class Dist2LogitLayer(nn.Module):
    ''' maps a pair of distances through 1x1 convs to a single value, squashed to [0,1] when use_sigmoid is True '''
    def __init__(self, chn_mid=32, use_sigmoid=True):
        '''Build the conv stack.

        Parameters:
            chn_mid (int)      -- width of the two hidden 1x1 convs
            use_sigmoid (bool) -- finish with nn.Sigmoid
        '''
        super(Dist2LogitLayer, self).__init__()

        # 5 input channels: see the tensors concatenated in forward
        stack = [
            nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),
            nn.LeakyReLU(0.2,True),
            nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),
            nn.LeakyReLU(0.2,True),
            nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),
        ]
        if(use_sigmoid):
            stack.append(nn.Sigmoid())
        self.model = nn.Sequential(*stack)

    def forward(self,d0,d1,eps=0.1):
        '''Feed [d0, d1, d0-d1, d0/(d1+eps), d1/(d0+eps)], stacked on dim 1, through the stack.'''
        difference = d0-d1
        ratio_01 = d0/(d1+eps)
        ratio_10 = d1/(d0+eps)
        features = torch.cat((d0,d1,difference,ratio_01,ratio_10),dim=1)
        return self.model.forward(features)

class BCERankingLoss(nn.Module):
    """Binary cross-entropy between the Dist2LogitLayer output and a rescaled judgement."""
    def __init__(self, chn_mid=32):
        """Create the Dist2LogitLayer (hidden width chn_mid) and the BCE criterion."""
        super(BCERankingLoss, self).__init__()
        self.net = Dist2LogitLayer(chn_mid=chn_mid)
        self.loss = torch.nn.BCELoss()

    def forward(self, d0, d1, judge):
        """Return BCELoss(net(d0, d1), (judge + 1) / 2).

        The network output is also kept in self.logit.
        NOTE(review): the rescaling suggests judge lies in [-1, 1] -- confirm with the caller.
        """
        target = (judge+1.)/2.
        self.logit = self.net.forward(d0,d1)
        return self.loss(self.logit, target)

# L2, DSSIM metrics
class FakeNet(nn.Module):
    """Parameter-free base module that only stores the options of the classical metrics."""
    def __init__(self, use_gpu=True, colorspace='Lab'):
        """Remember whether results go to the GPU and which colorspace to compare in."""
        super(FakeNet, self).__init__()
        self.colorspace = colorspace
        self.use_gpu = use_gpu

class L2(FakeNet):
    """Squared-error metric between two single-image batches."""

    def forward(self, in0, in1, retPerLayer=None):
        """Return the L2 value for colorspace 'RGB' or 'Lab'.

        'RGB' averages the squared difference over dim 1, then dim 2, then dim 3
        and returns a tensor of shape (N,). 'Lab' delegates to util.l2 on the
        Lab conversions with range=100 and returns a one-element tensor, moved
        to the GPU when use_gpu is set. Any other colorspace falls through and
        returns None. retPerLayer is accepted but unused.
        """
        assert(in0.size()[0]==1) # currently only supports batchSize 1

        if(self.colorspace=='RGB'):
            (N,C,X,Y) = in0.size()
            squared = (in0-in1)**2
            over_channels = torch.mean(squared,dim=1).view(N,1,X,Y)
            over_rows = torch.mean(over_channels,dim=2).view(N,1,1,Y)
            return torch.mean(over_rows,dim=3).view(N)
        elif(self.colorspace=='Lab'):
            lab0 = util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False))
            lab1 = util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False))
            value = util.l2(lab0, lab1, range=100.).astype('float')
            ret_var = Variable( torch.Tensor((value,) ) )
            return ret_var.cuda() if(self.use_gpu) else ret_var

class DSSIM(FakeNet):
    """DSSIM metric between two single-image batches, computed by util.dssim."""

    def forward(self, in0, in1, retPerLayer=None):
        """Return the util.dssim value as a one-element tensor.

        'RGB' compares the util.tensor2im images with range=255, 'Lab' compares
        the Lab conversions with range=100. The result is moved to the GPU when
        use_gpu is set. No value is computed for any other colorspace, so the
        call fails with an unbound-variable error. retPerLayer is unused.
        """
        assert(in0.size()[0]==1) # currently only supports batchSize 1

        if(self.colorspace=='RGB'):
            img0 = 1.*util.tensor2im(in0.data)
            img1 = 1.*util.tensor2im(in1.data)
            value = util.dssim(img0, img1, range=255.).astype('float')
        elif(self.colorspace=='Lab'):
            lab0 = util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False))
            lab1 = util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False))
            value = util.dssim(lab0, lab1, range=100.).astype('float')
        ret_var = Variable( torch.Tensor((value,) ) )
        return ret_var.cuda() if(self.use_gpu) else ret_var

def print_network(net):
    """Print the network followed by the total element count of its parameters."""
    num_params = sum(param.numel() for param in net.parameters())
    print('Network',net)
    print('Total number of parameters: %d' % num_params)


================================================
FILE: models/pretrained_networks.py
================================================
from collections import namedtuple
import torch
from torchvision import models
from IPython import embed

class squeezenet(torch.nn.Module):
    """torchvision squeezenet1_1 feature layers split into seven consecutive slices."""
    def __init__(self, requires_grad=False, pretrained=True):
        """Wrap features[0:13] into slice1..slice7.

        Parameters:
            requires_grad (bool) -- when False every parameter is frozen
            pretrained (bool)    -- forwarded to models.squeezenet1_1
        """
        super(squeezenet, self).__init__()
        pretrained_features = models.squeezenet1_1(pretrained=pretrained).features
        # (start, stop) feature indices of slice1 .. slice7; layers keep their original index as name
        boundaries = [(0, 2), (2, 5), (5, 8), (8, 10), (10, 11), (11, 12), (12, 13)]
        for number, (start, stop) in enumerate(boundaries, 1):
            stage = torch.nn.Sequential()
            for x in range(start, stop):
                stage.add_module(str(x), pretrained_features[x])
            setattr(self, 'slice%d' % number, stage)
        self.N_slices = 7
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        """Return the outputs of the seven slices as a SqueezeOutputs namedtuple."""
        stages = (self.slice1, self.slice2, self.slice3, self.slice4,
                  self.slice5, self.slice6, self.slice7)
        activations = []
        h = X
        for stage in stages:
            h = stage(h)
            activations.append(h)
        vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7'])
        out = vgg_outputs(*activations)

        return out


class alexnet(torch.nn.Module):
    """torchvision AlexNet feature layers split into five consecutive slices."""
    def __init__(self, requires_grad=False, pretrained=True):
        """Wrap features[0:12] into slice1..slice5.

        Parameters:
            requires_grad (bool) -- when False every parameter is frozen
            pretrained (bool)    -- forwarded to models.alexnet
        """
        super(alexnet, self).__init__()
        alexnet_pretrained_features = models.alexnet(pretrained=pretrained).features
        # slice k holds features[edges[k-1]:edges[k]]; layers keep their original index as name
        edges = [0, 2, 5, 8, 10, 12]
        for number in range(1, len(edges)):
            stage = torch.nn.Sequential()
            for x in range(edges[number - 1], edges[number]):
                stage.add_module(str(x), alexnet_pretrained_features[x])
            setattr(self, 'slice%d' % number, stage)
        self.N_slices = 5
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        """Return the outputs of the five slices as an AlexnetOutputs namedtuple."""
        h_relu1 = self.slice1(X)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        h_relu4 = self.slice4(h_relu3)
        h_relu5 = self.slice5(h_relu4)
        alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5'])
        return alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)

class vgg16(torch.nn.Module):
    """torchvision VGG16 feature layers split into five consecutive slices."""
    def __init__(self, requires_grad=False, pretrained=True):
        """Wrap features[0:30] into slice1..slice5.

        Parameters:
            requires_grad (bool) -- when False every parameter is frozen
            pretrained (bool)    -- forwarded to models.vgg16
        """
        super(vgg16, self).__init__()
        vgg_pretrained_features = models.vgg16(pretrained=pretrained).features

        def take(start, stop):
            """Sequential of features[start:stop]; layers keep their original index as name."""
            stage = torch.nn.Sequential()
            for x in range(start, stop):
                stage.add_module(str(x), vgg_pretrained_features[x])
            return stage

        self.slice1 = take(0, 4)
        self.slice2 = take(4, 9)
        self.slice3 = take(9, 16)
        self.slice4 = take(16, 23)
        self.slice5 = take(23, 30)
        self.N_slices = 5
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        """Return the outputs of the five slices as a VggOutputs namedtuple."""
        h_relu1_2 = self.slice1(X)
        h_relu2_2 = self.slice2(h_relu1_2)
        h_relu3_3 = self.slice3(h_relu2_2)
        h_relu4_3 = self.slice4(h_relu3_3)
        h_relu5_3 = self.slice5(h_relu4_3)
        vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
        return vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)



class resnet(torch.nn.Module):
    """ResNet feature extractor exposing five intermediate activations.

    Wraps one of torchvision's ``resnet18/34/50/101/152`` models and returns
    the activation after the first ReLU plus the outputs of ``layer1`` ..
    ``layer4``.
    """

    # depths for which torchvision provides a ``resnet<num>`` constructor used here
    _SUPPORTED_DEPTHS = (18, 34, 50, 101, 152)

    def __init__(self, requires_grad=False, pretrained=True, num=18):
        """Build the wrapped network.

        Parameters:
            requires_grad (bool) -- if False, every parameter of this module is frozen
                                    (same convention as the vgg16 wrapper)
            pretrained (bool)    -- forwarded to the torchvision constructor
            num (int)            -- network depth; one of 18, 34, 50, 101, 152

        Raises:
            ValueError -- if ``num`` is not a supported depth
        """
        super(resnet, self).__init__()
        # Fail early with a clear message instead of an AttributeError on
        # ``self.net`` further down.
        if num not in self._SUPPORTED_DEPTHS:
            raise ValueError(
                'unsupported resnet depth %r; expected one of %s'
                % (num, ', '.join(str(d) for d in self._SUPPORTED_DEPTHS)))
        self.net = getattr(models, 'resnet%d' % num)(pretrained=pretrained)
        self.N_slices = 5

        self.conv1 = self.net.conv1
        self.bn1 = self.net.bn1
        self.relu = self.net.relu
        self.maxpool = self.net.maxpool
        self.layer1 = self.net.layer1
        self.layer2 = self.net.layer2
        self.layer3 = self.net.layer3
        self.layer4 = self.net.layer4

        # Honor the flag, which was previously accepted but never applied.
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        """Return an ``Outputs`` namedtuple (relu1, conv2, conv3, conv4, conv5).

        ``relu1`` is taken before max-pooling; ``conv2`` .. ``conv5`` are the
        outputs of ``layer1`` .. ``layer4``.
        """
        h = self.conv1(X)
        h = self.bn1(h)
        h = self.relu(h)
        h_relu1 = h
        h = self.maxpool(h)
        h = self.layer1(h)
        h_conv2 = h
        h = self.layer2(h)
        h_conv3 = h
        h = self.layer3(h)
        h_conv4 = h
        h = self.layer4(h)
        h_conv5 = h

        outputs = namedtuple("Outputs", ['relu1', 'conv2', 'conv3', 'conv4', 'conv5'])
        return outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)


================================================
FILE: models/test_3styles_model.py
================================================
from .base_model import BaseModel
from . import networks
import torch

class Test3StylesModel(BaseModel):
    """Test-only model that renders each input with three one-hot style codes.

    It builds a single generator (registered as ``netG_A``) and, for every
    input, produces ``fake1``, ``fake2`` and ``fake3`` by pairing the input
    with the style tensors for [1,0,0], [0,1,0] and [0,0,1].
    The model sets '--dataset_mode single' by default.
    """
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Register the options of this model and override shared defaults.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- must be False; the model is test-only

        Returns:
            the modified parser.
        """
        assert not is_train, 'TestModel cannot be used during training time'
        parser.set_defaults(dataset_mode='single')
        add = parser.add_argument
        add('--style_control', type=int, default=0, help='not set style_vec in dataset')
        add('--netga', type=str, default='resnet_style2_9blocks', help='net arch for netG_A')
        add('--model0_res', type=int, default=0, help='number of resblocks in model0')
        add('--model1_res', type=int, default=0, help='number of resblocks in model1 (after insert style, before 2 column merge)')
        return parser

    def __init__(self, opt):
        """Create the generator described by ``opt``.

        Parameters:
            opt -- parsed options; ``opt.isTrain`` must be False
        """
        assert(not opt.isTrain)
        super(Test3StylesModel, self).__init__(opt)
        # no losses are reported at test time
        self.loss_names = []
        # images exposed for saving/display
        self.visual_names = ['real', 'fake1', 'fake2', 'fake3']
        # networks to load from disk: only the generator
        self.model_names = ['G_A']
        print(opt.netga)
        print('model0_res', opt.model0_res)
        print('model1_res', opt.model1_res)
        use_dropout = not opt.no_dropout
        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netga, opt.norm,
                                      use_dropout, opt.init_type, opt.init_gain, self.gpu_ids,
                                      opt.model0_res, opt.model1_res)
        # expose the generator under the name listed in model_names
        self.netG_A = self.netG

    def set_input(self, input):
        """Store the input image, its paths and the three style tensors.

        Parameters:
            input -- dict providing 'A' (image tensor) and 'A_paths'
        """
        self.real = input['A'].to(self.device)
        self.image_paths = input['A_paths']
        # one-hot codes broadcast to tensors of shape (1, 3, 128, 128)
        one_hots = ([1, 0, 0], [0, 1, 0], [0, 0, 1])
        styles = [
            torch.Tensor(code).view(3, 1, 1).repeat(1, 1, 128, 128).to(self.device)
            for code in one_hots
        ]
        self.style1, self.style2, self.style3 = styles

    def forward(self):
        """Generate one output per style code."""
        self.fake1 = self.netG(self.real, self.style1)
        self.fake2 = self.netG(self.real, self.style2)
        self.fake3 = self.netG(self.real, self.style3)

    def optimize_parameters(self):
        """Nothing to optimize: this model is only used for testing."""
        pass


================================================
FILE: models/test_model.py
================================================
from .base_model import BaseModel
from . import networks
import torch

class TestModel(BaseModel):
    """Test-only model that runs a generator in one direction and reconstructs the result.

    It builds the forward generator (registered as ``netG<model_suffix>``) and
    a backward generator (registered as ``netG_B``). ``forward`` produces
    ``fake`` from ``real`` and ``rec`` from ``fake``.
    This model sets '--dataset_mode single' by default, which only loads the
    images from one collection.

    See the test instruction for more details.
    """
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        """Add new dataset-specific options, and rewrite default values for existing options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.

        The model can only be used during test time. It requires '--dataset_mode single'.
        You need to specify the network using the option '--model_suffix'.
        """
        assert not is_train, 'TestModel cannot be used during training time'
        parser.set_defaults(dataset_mode='single')
        parser.add_argument('--model_suffix', type=str, default='', help='In checkpoints_dir, [epoch]_net_G[model_suffix].pth will be loaded as the generator.')
        parser.add_argument('--style_control', type=int, default=1, help='use style_control')
        parser.add_argument('--sfeature_mode', type=str, default='vgg19_softmax', help='vgg19 softmax as feature')
        parser.add_argument('--sinput', type=str, default='sind', help='use which one for style input')
        parser.add_argument('--sind', type=int, default=0, help='one hot for sfeature')
        parser.add_argument('--svec', type=str, default='1,0,0', help='3-dim vec')
        parser.add_argument('--simg', type=str, default='Yann_Legendre-053', help='drawing example for style')
        parser.add_argument('--netga', type=str, default='resnet_style2_9blocks', help='net arch for netG_A')
        parser.add_argument('--model0_res', type=int, default=0, help='number of resblocks in model0')
        parser.add_argument('--model1_res', type=int, default=0, help='number of resblocks in model1 (after insert style, before 2 column merge)')

        return parser

    def __init__(self, opt):
        """Initialize the test model.

        Parameters:
            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        assert(not opt.isTrain)
        BaseModel.__init__(self, opt)
        # specify the training losses you want to print out. The training/test scripts  will call <BaseModel.get_current_losses>
        self.loss_names = []
        # specify the images you want to save/display. The training/test scripts  will call <BaseModel.get_current_visuals>
        self.visual_names = ['real', 'fake', 'rec']
        # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
        self.model_names = ['G' + opt.model_suffix, 'G_B']  # forward generator and backward (reconstruction) generator
        if not self.opt.style_control:
            # plain generator: takes only the image
            self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG,
                                      opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
        else:
            # style-controlled generator: takes the image and a style tensor
            print(opt.netga)
            print('model0_res', opt.model0_res)
            print('model1_res', opt.model1_res)
            self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netga, opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.model0_res, opt.model1_res)

        # backward generator maps output_nc channels back to input_nc channels
        self.netGB = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG,
                                      opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
        # assigns the model to self.netG_[suffix] so that it can be loaded
        # please see <BaseModel.load_networks>
        setattr(self, 'netG' + opt.model_suffix, self.netG)  # store netG in self.
        setattr(self, 'netG_B', self.netGB)  # store netGB in self.

    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input: a dictionary that contains the data itself and its metadata information.
                   It must provide 'A' and 'A_paths', plus 'B_style' when
                   style control is enabled.

        We need to use 'single_dataset' dataset mode. It only load images from one domain.
        """
        self.real = input['A'].to(self.device)
        self.image_paths = input['A_paths']
        if self.opt.style_control:
            # keep the style tensor on the same device as the image it is paired with
            self.style = input['B_style'].to(self.device)

    def forward(self):
        """Run forward pass: generate ``fake`` from ``real``, then ``rec`` from ``fake``."""
        if not self.opt.style_control:
            self.fake = self.netG(self.real)  # G(real)
        else:
            # log the spatial mean of the style tensor (reduces dims 2 and 3)
            print(torch.mean(self.style,(2,3)),'style_control')
            self.fake = self.netG(self.real, self.style)
        self.rec = self.netG_B(self.fake)

    def optimize_parameters(self):
        """No optimization for test model."""
        pass


================================================
FILE: options/__init__.py
================================================
"""This package options includes option modules: training options, test options, and basic options (used in both training and test)."""


================================================
FILE: options/base_options.py
================================================
import argparse
import os
from util import util
import torch
import models
import data


class BaseOptions():
    """This class defines options used during both training and test time.

    It also implements several helper functions such as parsing, printing, and saving the options.
    It also gathers additional options defined in <modify_commandline_options> functions in both dataset class and model class.

    Subclasses are expected to set ``self.isTrain`` in their ``initialize``
    override; ``gather_options`` and ``parse`` read it.
    """

    def __init__(self):
        """Reset the class; indicates the class hasn't been initialized"""
        self.initialized = False

    def initialize(self, parser):
        """Define the common options that are used in both training and test.

        Parameters:
            parser -- an ``argparse.ArgumentParser`` to extend

        Returns:
            the same parser, with the shared options added.
        """
        # basic parameters
        parser.add_argument('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
        parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')
        parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0  0,1,2, 0,2. use -1 for CPU')
        parser.add_argument('--gpu_ids_p', type=str, default='0', help='gpu ids for pretrained auxiliary models: e.g. 0  0,1,2, 0,2. use -1 for CPU')
        parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
        # model parameters
        parser.add_argument('--model', type=str, default='cycle_gan', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]')
        parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')
        parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')
        parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')
        parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')
        parser.add_argument('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')
        parser.add_argument('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]')
        parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')
        parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]')
        parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')
        parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
        parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator')
        # dataset parameters
        parser.add_argument('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')
        parser.add_argument('--direction', type=str, default='AtoB', help='AtoB or BtoA')
        parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
        parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')
        parser.add_argument('--batch_size', type=int, default=1, help='input batch size')
        parser.add_argument('--load_size', type=int, default=286, help='scale images to this size')
        parser.add_argument('--crop_size', type=int, default=256, help='then crop to this size')
        parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
        parser.add_argument('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')
        parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
        parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')
        # additional parameters
        parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
        # The default is an int (not the string '0') so that print_options,
        # which compares the parsed value with the raw default, does not flag
        # an untouched --load_iter as changed.
        parser.add_argument('--load_iter', type=int, default=0, help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')
        parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
        parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')
        self.initialized = True
        return parser

    def gather_options(self):
        """Initialize our parser with basic options(only once).
        Add additional model-specific and dataset-specific options.
        These options are defined in the <modify_commandline_options> function
        in model and dataset classes.

        Returns:
            the namespace produced by parsing the command line with the full parser.
        """
        # A parser completed by an earlier call already contains the model and
        # dataset options; reuse it rather than rebuilding (re-adding the same
        # arguments to it would be rejected by argparse).
        cached_parser = getattr(self, 'parser', None)
        if self.initialized and cached_parser is not None:
            return cached_parser.parse_args()

        parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        parser = self.initialize(parser)

        # get the basic options
        opt, _ = parser.parse_known_args()

        # modify model-related parser options
        model_name = opt.model
        model_option_setter = models.get_option_setter(model_name)
        parser = model_option_setter(parser, self.isTrain)
        opt, _ = parser.parse_known_args()  # parse again with new defaults

        # modify dataset-related parser options
        dataset_name = opt.dataset_mode
        dataset_option_setter = data.get_option_setter(dataset_name)
        parser = dataset_option_setter(parser, self.isTrain)

        # save and return the parser
        self.parser = parser
        return parser.parse_args()

    def print_options(self, opt):
        """Print and save options

        It will print both current options and default values(if different).
        It will save options into a text file / [checkpoints_dir] / [name] / [phase]_opt.txt
        """
        message = ''
        message += '----------------- Options ---------------\n'
        for k, v in sorted(vars(opt).items()):
            comment = ''
            default = self.parser.get_default(k)
            if v != default:
                comment = '\t[default: %s]' % str(default)
            message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
        message += '----------------- End -------------------'
        print(message)

        # save to the disk
        expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
        util.mkdirs(expr_dir)
        file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))
        with open(file_name, 'wt') as opt_file:
            opt_file.write(message)
            opt_file.write('\n')

    @staticmethod
    def _parse_gpu_ids(spec):
        """Turn a comma-separated id string such as '0,2' into a list of non-negative ints.

        Negative ids (e.g. '-1' for CPU) and blank tokens are dropped.
        """
        gpu_ids = []
        for token in spec.split(','):
            token = token.strip()
            if not token:
                continue
            gpu_id = int(token)
            if gpu_id >= 0:
                gpu_ids.append(gpu_id)
        return gpu_ids

    def parse(self):
        """Parse our options, create checkpoints directory suffix, and set up gpu device.

        Returns:
            the parsed options, with ``isTrain`` set and ``gpu_ids`` /
            ``gpu_ids_p`` converted from strings to lists of ints.
        """
        opt = self.gather_options()
        opt.isTrain = self.isTrain   # train or test

        # process opt.suffix: the template may reference any option by name
        if opt.suffix:
            opt.name = opt.name + '_' + opt.suffix.format(**vars(opt))

        self.print_options(opt)

        # gpu ids of the main networks; the first one becomes the current CUDA device
        opt.gpu_ids = self._parse_gpu_ids(opt.gpu_ids)
        if opt.gpu_ids:
            torch.cuda.set_device(opt.gpu_ids[0])

        # gpu ids for the pretrained auxiliary models
        opt.gpu_ids_p = self._parse_gpu_ids(opt.gpu_ids_p)

        self.opt = opt
        return self.opt


================================================
FILE: options/test_options.py
================================================
from .base_options import BaseOptions


class TestOptions(BaseOptions):
    """Options that only apply at test time.

    The shared options defined in BaseOptions are included as well.
    """

    def initialize(self, parser):
        """Add the test-time options to ``parser`` and mark this option set as test mode."""
        parser = super(TestOptions, self).initialize(parser)  # shared options first
        add = parser.add_argument
        add('--ntest', type=int, default=float("inf"), help='# of test examples.')
        add('--results_dir', type=str, default='./results/', help='saves results here.')
        add('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
        add('--phase', type=str, default='test', help='train, val, test, etc')
        # Dropout and Batchnorm behave differently during training and test.
        add('--eval', action='store_true', help='use eval mode during test time.')
        add('--num_test', type=int, default=50, help='how many test images to run')
        add('--imagefolder', type=str, default='images', help='subfolder to save images')
        # Override shared defaults: use the test model, and make load_size
        # equal to crop_size so that no cropping happens.
        parser.set_defaults(model='test', load_size=parser.get_default('crop_size'))
        self.isTrain = False
        return parser


================================================
FILE: options/train_options.py
================================================
from .base_options import BaseOptions


class TrainOptions(BaseOptions):
    """Options that only apply at training time.

    The shared options defined in BaseOptions are included as well.
    """

    def initialize(self, parser):
        """Add the training options to ``parser`` and mark this option set as train mode."""
        parser = super(TrainOptions, self).initialize(parser)
        add = parser.add_argument
        # visdom and HTML visualization parameters
        add('--display_freq', type=int, default=400, help='frequency of showing training results on screen')
        add('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
        add('--display_id', type=int, default=1, help='window id of the web display')
        add('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
        add('--display_env', type=str, default='main', help='visdom display environment name (default is "main")')
        add('--display_port', type=int, default=8097, help='visdom port of the web display')
        add('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')
        add('--print_freq', type=int, default=100, help='frequency of showing training results on console')
        add('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
        # network saving and loading parameters
        add('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
        add('--save_epoch_freq', type=int, default=10, help='frequency of saving checkpoints at the end of epochs')
        add('--save_by_iter', action='store_true', help='whether saves model by iteration')
        add('--continue_train', action='store_true', help='continue training: load the latest model')
        add('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
        add('--phase', type=str, default='train', help='train, val, test, etc')
        # training parameters
        add('--n_epochs', type=int, default=200, help='the end epoch count')
        add('--niter', type=int, default=100, help='# of iter at starting learning rate')
        add('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')
        add('--beta1', type=float, default=0.5, help='momentum term of adam')
        add('--lr', type=float, default=0.0002, help='initial learning rate for adam')
        add('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')
        add('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')
        add('--lr_policy', type=str, default='linear', help='learning rate policy. [linear | step | plateau | cosine]')
        add('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')

        self.isTrain = True
        return parser


================================================
FILE: portrait_drawing_resources.md
================================================

- Charles Burns (style1): https://www.pinterest.co.uk/johns59/charles-burns-fan-club/
- Yann Legendre (style1): http://www.yannlegendre.com/project/portraits/
- Kathryn Rathke (style2):
https://www.kathrynrathke.com/
- Vectorportal (style3): https://www.pinterest.co.uk/vectorportal/celebrity-vector-illustrations/

================================================
FILE: preprocess/face_align_512.m
================================================
function [trans_img]=face_align_512(impath,facial5point,savedir)
% FACE_ALIGN_512  Align a face image to 512x512 with a similarity transform.
%   The transform is estimated from 5 facial landmarks (2 eyes, nose,
%   2 mouth corners) to a fixed reference layout.
%
%   impath:       path to the input image
%   facial5point: 5x2 landmark positions (detected by MTCNN)
%   savedir:      directory the aligned image is written to as
%                 <image name>_resized.png (created if missing)
%   trans_img:    the aligned 512x512 image

%% alignment settings
outSize = [512,512];
% reference landmark layout, defined in a 480x480 frame
refPoints480 = [180,230;
    300,230;
    240,301;
    186,365.6;
    294,365.6];
% re-center the layout and rescale it onto the 512x512 canvas
refPoints = (refPoints480-240)/560 * 512 + 256;

%% face alignment

% estimate the similarity transform and warp the image onto the canvas;
% pixels outside the source image are filled with white
img       = imread(impath);
srcPoints = double(facial5point);
tform     = cp2tform(srcPoints, refPoints, 'similarity');
trans_img = imtransform(img, tform, 'XData', [1 outSize(2)],...
                                    'YData', [1 outSize(1)],...
                                    'Size', outSize,...
                                    'FillValues', [255;255;255]);
% landmarks mapped into the aligned image (only used for the preview below)
alignedPoints = round(tformfwd(tform,srcPoints));


%% save results
if ~exist(savedir,'dir')
    mkdir(savedir)
end
[~,name,~] = fileparts(impath);
outPath = fullfile(savedir,[name,'_resized.png']);
imwrite(trans_img, outPath);
fprintf('write aligned image to %s\n',outPath);

%% show results
imshow(trans_img); hold on;
plot(alignedPoints(:,1),alignedPoints(:,2),'b');
plot(alignedPoints(:,1),alignedPoints(:,2),'r+');

end

================================================
FILE: preprocess/readme.md
================================================
## Preprocessing steps

During training, face photos and drawings are aligned, and their nose, eyes, and lips masks are detected.

During test, the alignment step is optional and the masks are not needed.

### 1. Align, resize, crop images to 512x512

All training and testing images in our model are aligned using facial landmarks, and the landmarks after alignment are needed in our code.

- First, the 5 facial landmarks of a face photo need to be detected (we detect them using [MTCNN](https://github.com/kpzhang93/MTCNN_face_detection_alignment) (MTCNNv1)).

- Then, we provide a matlab function in `face_align_512.m` to align, resize and crop face photos (and corresponding drawings) to 512x512. Call this function in MATLAB to align the image to 512x512.
For example, for `ia_selfie_10515.jpg` in the `example` dir, the 5 detected facial landmarks are saved in `example/ia_selfie_10515_facial5point.mat`. Run the following in MATLAB:
```matlab
load('example/ia_selfie_10515_facial5point.mat');
[trans_img]=face_align_512('example/ia_selfie_10515.jpg',facial5point,'example');
```

This will align the image and write the aligned image to the `example` folder.
See `face_align_512.m` for more instructions.


### 2. Prepare nose,eyes,lips masks

In our work, we use the face parsing network in https://github.com/cientgu/Mask_Guided_Portrait_Editing to get nose,eyes,lips regions and then dilate the regions to make them cover these facial features (some examples are shown in `example` folder).

- The background masks need to be copied to `datasets/portrait_drawing/train/A(B)(_eyes)(_lips)`, and must have the **same filename** as the corresponding aligned face photos.


================================================
FILE: readme.md
================================================

# Unpaired Portrait Drawing Generation via Asymmetric Cycle Mapping

We provide PyTorch implementations for our CVPR 2020 paper "Unpaired Portrait Drawing Generation via Asymmetric Cycle Mapping". [paper](https://openaccess.thecvf.com/content_CVPR_2020/papers/Yi_Unpaired_Portrait_Drawing_Generation_via_Asymmetric_Cycle_Mapping_CVPR_2020_paper.pdf), [suppl](https://openaccess.thecvf.com/content_CVPR_2020/supplemental/Yi_Unpaired_Portrait_Drawing_CVPR_2020_supplemental.pdf).

This project generates multi-style artistic portrait drawings from face photos using a GAN-based model.

[[Jittor implementation]](https://github.com/yiranran/Unpaired-Portrait-Drawing-Jittor)


## Our Proposed Framework
 
<img src = 'imgs/architecture.jpg'>

## Sample Results
From left to right: input, output(style1), output(style2), output(style3)
<img src = 'imgs/results.jpg'>

## Prerequisites
- Linux or macOS
- Python 3
- CPU or NVIDIA GPU + CUDA CuDNN


## Installation
- To install the dependencies, run
```bash
pip install -r requirements.txt
```

## Colab
A colab demo is [here](https://colab.research.google.com/drive/1U1fPXD1JukuKPOrhGMX1iaJC-d8_RUYr).

## Test steps (apply a pretrained model)

- 1. Download pre-trained models from [BaiduYun](https://pan.baidu.com/s/1_9Fy8mRpTQp6AvqhHsfQAQ)(extract code:c9h7) or [GoogleDrive](https://drive.google.com/drive/folders/1FzOcdlMYhvK_nyLCe8wnwotMphhIoiYt?usp=sharing) and rename the folder to `checkpoints`.

- 2. Test for example photos: generate artistic portrait drawings for example photos in the folder `./examples` using
``` bash
# with GPU
python test_seq_style.py
# without GPU
python test_seq_style.py --gpu -1
```
The test results will be saved to a html file here: `./results/pretrained/test_200/index3styles.html`.
The result images are saved in `./results/pretrained/test_200/images3styles`,
where `real`, `fake1`, `fake2`, `fake3` correspond to input face photo, style1 drawing, style2 drawing, style3 drawing respectively.

<img src = 'imgs/how_to_crop.jpg'>

- 3. To test on your own photos: First use an image editor to crop the face region of your photo (or use an optional preprocess [here](preprocess/readme.md)). Then specify the folder that contains test photos using option `--dataroot`, specify save folder name using option `--savefolder` and run the above command again:

``` bash
# with GPU
python test_seq_style.py --dataroot [input_folder] --savefolder [save_folder_name]
# without GPU
python test_seq_style.py --gpu -1 --dataroot [input_folder] --savefolder [save_folder_name]
# E.g.
python test_seq_style.py --gpu -1 --dataroot ./imgs/test1 --savefolder 3styles_test1
```
The test results will be saved to a html file here: `./results/pretrained/test_200/index[save_folder_name].html`.
The result images are saved in `./results/pretrained/test_200/images[save_folder_name]`.
An example html screenshot is shown below:
<img src = 'imgs/result_html.jpg'>

For any questions, please email yr16@mails.tsinghua.edu.cn.

## Train steps

- 1. Prepare the dataset: 1) download face photos and portrait drawings from the internet (e.g. [resources](portrait_drawing_resources.md)); 2) align and crop the photos and drawings, and 3) prepare nose, eyes, lips masks according to the [preprocess instructions](preprocess/readme.md); 4) put aligned photos under `./datasets/portrait_drawing/train/A`, aligned drawings under `./datasets/portrait_drawing/train/B`, and masks under `A_nose`,`A_eyes`,`A_lips`,`B_nose`,`B_eyes`,`B_lips` respectively.

- 2. Train a 3-class style classifier and extract the 3-dim style feature (as described in the paper). Then save the style feature of each drawing in the training set in .npy format, in the folder `./datasets/portrait_drawing/train/B_feat`.

A subset of our training set is [here](https://drive.google.com/file/d/1OSMOR3-uhGkoPwPFRNychJSNrpSak_23/view?usp=sharing).

- 3. Train our model
``` bash
sh ./scripts/train.sh
```
Models are saved in folder checkpoints/portrait_drawing


## Citation
If you use this code for your research, please cite our paper.

```
@inproceedings{YiLLR20,
  title     = {Unpaired Portrait Drawing Generation via Asymmetric Cycle Mapping},
  author    = {Yi, Ran and Liu, Yong-Jin and Lai, Yu-Kun and Rosin, Paul L},
  booktitle = {{IEEE} Conference on Computer Vision and Pattern Recognition (CVPR '20)},
  pages     = {8214--8222},
  year      = {2020}
}
```

## Acknowledgments
Our code is inspired by [pytorch-CycleGAN-and-pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix).


================================================
FILE: requirements.txt
================================================
torch==1.2.0
torchvision==0.4.0
dominate==2.4.0
visdom==0.1.8.9
scipy==1.1.0
numpy==1.16.4
Pillow==6.2.1
opencv-python==4.1.0.25

================================================
FILE: scripts/train.sh
================================================
# Launch training of the asymmetric_cycle_gan_cls model through train.py.
# -e: stop at the first command that fails; -x: echo each command before it runs.
set -ex
python train.py --dataroot ./datasets/portrait_drawing --name formal --model asymmetric_cycle_gan_cls --output_nc 1 --load_size 572 --crop_size 512 --lr 0.000015 --dataset_mode unaligned_mask_stylecls --display_env asymmetric_trainset --gpu_ids 0 --gpu_ids_p 0 --niter 100 --niter_decay 200 --n_epochs 200

================================================
FILE: test.py
================================================
"""General-purpose test script for image-to-image translation.

Once you have trained your model with train.py, you can use this script to test the model.
It will load a saved model from --checkpoints_dir and save the results to --results_dir.

It first creates model and dataset given the option. It will hard-code some parameters.
It then runs inference for --num_test images and save results to an HTML file.

Example (You need to train models first or download pre-trained models from our website):
    Test a CycleGAN model (both sides):
        python test.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan

    Test a CycleGAN model (one side only):
        python test.py --dataroot datasets/horse2zebra/testA --name horse2zebra_pretrained --model test --no_dropout

    The option '--model test' is used for generating CycleGAN results only for one side.
    This option will automatically set '--dataset_mode single', which only loads the images from one set.
    On the contrary, using '--model cycle_gan' requires loading and generating results in both directions,
    which is sometimes unnecessary. The results will be saved at ./results/.
    Use '--results_dir <directory_path_to_save_result>' to specify the results directory.

    Test a pix2pix model:
        python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA

See options/base_options.py and options/test_options.py for more test options.
See training and test tips at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/tips.md
See frequently asked questions at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/qa.md
"""
import os
from options.test_options import TestOptions
from data import create_dataset
from models import create_model
from util.visualizer import save_images
from util import html


# Script entry point: build the dataset and model from the parsed test options,
# run inference on at most opt.num_test items and write the results to an HTML page.
if __name__ == '__main__':
    opt = TestOptions().parse()  # get test options
    # hard-code some parameters for test
    # NOTE(review): the value set here is 0 although the upstream comment said
    # "only supports num_threads = 1" -- presumably 0 means "load in the main process"; confirm in data/.
    opt.num_threads = 0
    opt.batch_size = 1    # test code only supports batch_size = 1
    opt.serial_batches = True  # disable data shuffling; comment this line if results on randomly chosen images are needed.
    opt.no_flip = True    # no flip; comment this line if results on flipped images are needed.
    opt.display_id = -1   # no visdom display; the test code saves the results to a HTML file.
    dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options
    model = create_model(opt)      # create a model given opt.model and other options
    model.setup(opt)               # regular setup: load and print networks; create schedulers
    # create a website
    web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.epoch))  # define the website directory
    # opt.imagefolder selects the image sub-folder passed to html.HTML as `folder`.
    webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.epoch), refresh=0, folder=opt.imagefolder)
    # test with eval mode. This only affects layers like batchnorm and dropout.
    # For [pix2pix]: we use batchnorm and dropout in the original pix2pix. You can experiment it with and without eval() mode.
    # For [CycleGAN]: It should not affect CycleGAN as CycleGAN uses instancenorm without dropout.
    if opt.eval:
        model.eval()
    for i, data in enumerate(dataset):
        if i >= opt.num_test:  # only apply our model to opt.num_test images.
            break
        model.set_input(data)  # unpack data from data loader
        model.test()           # run inference
        visuals = model.get_current_visuals()  # get image results
        img_path = model.get_image_paths()     # get image paths
        if i % 5 == 0:  # progress message every 5th item (images are saved for every item below)
            print('processing (%04d)-th image... %s' % (i, img_path))
        # every item's visuals are written to disk and appended to the page
        save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize, W=opt.W, H=opt.H)
    webpage.save()  # save the HTML


================================================
FILE: test_seq_style.py
================================================
import os
import argparse

def opts():
    """Parse the command-line options of the 3-style test launcher.

    Returns:
        argparse.Namespace with the string attributes ``gpu``, ``dataroot``
        and ``savefolder``.
    """
    # (flags, default value, help text) for every supported option
    option_table = (
        (('-g', '--gpu'), '0',
         'gpu ids, -1 for cpu, default is 0.'),
        (('-d', '--dataroot'), './examples',
         'the input folder that contains test face photos, default is ./examples'),
        (('-s', '--savefolder'), '3styles',
         'the name of save folder that contains result images, default is 3styles'),
    )
    cli = argparse.ArgumentParser()
    for flags, default_value, help_text in option_table:
        cli.add_argument(*flags, default=default_value, type=str, help=help_text)
    return cli.parse_args()

# Script entry point: run test.py once with the pretrained 3-style model on the
# photos in --dataroot and report where the results were written.
if __name__ == '__main__':
    # Imported locally so this launcher block is self-contained.
    import subprocess
    import sys

    opt = opts()
    exp = 'pretrained'
    imgsize = 512
    epoch = '200'
    dataroot = opt.dataroot
    gpu_id = opt.gpu

    # test 3 styles in one pass
    savefolder = 'images'+opt.savefolder
    # Pass the arguments as a list (no shell) so that folders containing spaces
    # or shell metacharacters reach test.py unchanged, and reuse the interpreter
    # that is running this script instead of assuming a `python3` on PATH.
    command = [
        sys.executable or 'python3', 'test.py',
        '--dataroot', dataroot,
        '--name', exp,
        '--model', 'test_3styles',
        '--output_nc', '1',
        '--no_dropout',
        '--num_test', '1000',
        '--epoch', epoch,
        '--imagefolder', savefolder,
        '--crop_size', str(imgsize),
        '--load_size', str(imgsize),
        '--gpu_ids', gpu_id,
    ]
    returncode = subprocess.call(command)
    if returncode != 0:
        # Do not point the user at result files that were never written.
        sys.exit('test.py failed with exit code %d' % returncode)
    print('check ./results/%s/test_%s/index%s.html'%(exp,epoch,savefolder[6:]))
    print('saved to ./results/%s/test_%s/%s'%(exp,epoch,savefolder))

    # test 3 styles separately
    '''
    for vec in [[1,0,0],[0,1,0],[0,0,1]]:
        #1,0,0 for style1; 0,1,0 for style2; 0,0,1 for style3
        svec = '%d,%d,%d' % (vec[0],vec[1],vec[2])
        savefolder = 'imagesstyle%d-%d-%d'%(vec[0],vec[1],vec[2])
        print('results/%s/test_%s/index%s.html'%(exp,epoch,savefolder[6:]))
        os.system('python3 test.py --dataroot %s --name %s --model test --output_nc 1 --no_dropout --model_suffix _A --num_test 1000 --epoch %s --imagefolder %s --sinput svec --svec %s --crop_size %d --load_size %d --gpu_ids %s' % (dataroot,exp,epoch,savefolder,svec,imgsize,imgsize,gpu_id))
    '''
        

================================================
FILE: train.py
================================================
"""General-purpose training script for image-to-image translation.

This script works for various models (with option '--model': e.g., pix2pix, cyclegan, colorization) and
different datasets (with option '--dataset_mode': e.g., aligned, unaligned, single, colorization).
You need to specify the dataset ('--dataroot'), experiment name ('--name'), and model ('--model').

It first creates model, dataset, and visualizer given the option.
It then does standard network training. During the training, it also visualize/save the images, print/save the loss plot, and save models.
The script supports continue/resume training. Use '--continue_train' to resume your previous training.

Example:
    Train a CycleGAN model:
        python train.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan
    Train a pix2pix model:
        python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA

See options/base_options.py and options/train_options.py for more training options.
See training and test tips at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/tips.md
See frequently asked questions at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/qa.md
"""
import time
from options.train_options import TrainOptions
from data import create_dataset
from models import create_model
from util.visualizer import Visualizer
import pdb

# Script entry point: build dataset, model and visualizer from the parsed training
# options, then run the epoch/iteration loop with periodic display, logging and checkpointing.
if __name__ == '__main__':
    start = time.time()
    opt = TrainOptions().parse()   # get training options
    dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options
    dataset_size = len(dataset)    # get the number of images in the dataset.
    print('The number of training images = %d' % dataset_size)

    model = create_model(opt)      # create a model given opt.model and other options
    model.setup(opt)               # regular setup: load and print networks; create schedulers
    visualizer = Visualizer(opt)   # create a visualizer that display/save images and plots
    total_iters = 0                # the total number of training iterations

    # The loop runs from opt.epoch_count up to and including opt.n_epochs.
    for epoch in range(opt.epoch_count, opt.n_epochs + 1):    # outer loop for different epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>
        epoch_start_time = time.time()  # timer for entire epoch
        iter_data_time = time.time()    # timer for data loading per iteration
        epoch_iter = 0                  # the number of training iterations in current epoch, reset to 0 every epoch
        model.update_process(epoch)     # model-specific per-epoch hook (defined in the model class)

        for i, data in enumerate(dataset):  # inner loop within one epoch
            iter_start_time = time.time()  # timer for computation per iteration
            # t_data is first assigned on the very first iteration (total_iters == 0),
            # so it is always defined by the time it is printed below.
            if total_iters % opt.print_freq == 0:
                t_data = iter_start_time - iter_data_time
            visualizer.reset()
            total_iters += opt.batch_size
            epoch_iter += opt.batch_size
            model.set_input(data)         # unpack data from dataset and apply preprocessing
            model.optimize_parameters()   # calculate loss functions, get gradients, update network weights

            if total_iters % opt.display_freq == 0:   # display images on visdom and save images to a HTML file
                save_result = total_iters % opt.update_html_freq == 0
                model.compute_visuals()
                visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)

            if total_iters % opt.print_freq == 0:    # print training losses and save logging information to the disk
                losses = model.get_current_losses()
                t_comp = (time.time() - iter_start_time) / opt.batch_size
                # only the 'cycle_gan' model additionally logs model.process and model.lambda_As
                if opt.model == 'cycle_gan':
                    processes = [model.process] + model.lambda_As
                    visualizer.print_current_losses_process(epoch, epoch_iter, losses, t_comp, t_data, processes)
                else:
                    visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data)
                if opt.display_id > 0:
                    visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses)

            if total_iters % opt.save_latest_freq == 0:   # cache our latest model every <save_latest_freq> iterations
                print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))
                save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest'
                model.save_networks(save_suffix)

            iter_data_time = time.time()
        if epoch % opt.save_epoch_freq == 0:              # cache our model every <save_epoch_freq> epochs
            print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))
            model.save_networks('latest')
            model.save_networks(epoch)

        # NOTE(review): the total printed here is opt.niter + opt.niter_decay, while the loop
        # above stops at opt.n_epochs; the two can differ -- confirm which one is intended.
        print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
        model.update_learning_rate()                     # update learning rates at the end of every epoch.

    print('Total Time Taken: %d sec' % (time.time() - start))

================================================
FILE: util/__init__.py
================================================
"""This package includes a miscellaneous collection of useful helper functions."""


================================================
FILE: util/get_data.py
================================================
from __future__ import print_function
import os
import tarfile
import requests
from warnings import warn
from zipfile import ZipFile
from bs4 import BeautifulSoup
from os.path import abspath, isdir, join, basename


class GetData(object):
    """A Python script for downloading CycleGAN or pix2pix datasets.

    Parameters:
        technique (str) -- One of: 'cyclegan' or 'pix2pix'.
        verbose (bool)  -- If True, print additional information.

    Examples:
        >>> from util.get_data import GetData
        >>> gd = GetData(technique='cyclegan')
        >>> new_data_path = gd.get(save_path='./datasets')  # options will be displayed.

    Alternatively, You can use bash scripts: 'scripts/download_pix2pix_model.sh'
    and 'scripts/download_cyclegan_model.sh'.
    """

    def __init__(self, technique='cyclegan', verbose=True):
        url_dict = {
            'pix2pix': 'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/',
            'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets'
        }
        # None for an unknown technique; reported as a ValueError once a URL is needed.
        self.url = url_dict.get(technique.lower())
        self._verbose = verbose

    def _print(self, text):
        """Print `text` only when the instance was created with verbose=True."""
        if self._verbose:
            print(text)

    def _require_url(self):
        """Return the dataset index URL, or raise ValueError for an unknown technique."""
        if self.url is None:
            raise ValueError("Unknown technique: expected 'cyclegan' or 'pix2pix'.")
        return self.url

    @staticmethod
    def _get_options(r):
        """Return the link texts of all archive links ('.zip' / 'tar.gz') in response `r`."""
        soup = BeautifulSoup(r.text, 'lxml')
        return [h.text for h in soup.find_all('a', href=True)
                if h.text.endswith(('.zip', 'tar.gz'))]

    def _present_options(self):
        """List the downloadable datasets and return the one chosen interactively."""
        r = requests.get(self._require_url())
        r.raise_for_status()  # do not parse an HTTP error page as a dataset listing
        options = self._get_options(r)
        print('Options:\n')
        for i, o in enumerate(options):
            print("{0}: {1}".format(i, o))
        choice = input("\nPlease enter the number of the "
                       "dataset above you wish to download:")
        return options[int(choice)]

    def _download_data(self, dataset_url, save_path):
        """Download the archive at `dataset_url` and unpack it into `save_path`.

        Raises:
            ValueError -- if the archive is neither '.tar.gz' nor '.zip'
                          (checked before anything is downloaded).
            requests.HTTPError -- if the server answers with an error status.
        """
        base = basename(dataset_url)
        if base.endswith('.tar.gz'):
            open_archive = tarfile.open
        elif base.endswith('.zip'):
            open_archive = ZipFile
        else:
            raise ValueError("Unknown File Type: {0}.".format(base))

        if not isdir(save_path):
            os.makedirs(save_path)
        temp_save_path = join(save_path, base)

        # Fetch first, so a failed request never leaves an empty/garbage archive behind.
        r = requests.get(dataset_url)
        r.raise_for_status()
        with open(temp_save_path, "wb") as f:
            f.write(r.content)

        try:
            self._print("Unpacking Data...")
            # NOTE: extractall() writes members using the paths stored in the
            # archive; only use this with archives from a trusted server.
            with open_archive(temp_save_path) as obj:
                obj.extractall(save_path)
        finally:
            # Remove the temporary archive whether or not extraction succeeded.
            os.remove(temp_save_path)

    def get(self, save_path, dataset=None):
        """

        Download a dataset.

        Parameters:
            save_path (str) -- A directory to save the data to.
            dataset (str)   -- (optional). A specific dataset to download.
                            Note: this must include the file extension.
                            If None, options will be presented for you
                            to choose from.

        Returns:
            save_path_full (str) -- the absolute path to the downloaded data.

        Raises:
            ValueError -- if a download is needed but the technique given to
                          the constructor has no known URL.

        """
        if dataset is None:
            selected_dataset = self._present_options()
        else:
            selected_dataset = dataset

        # The target folder is the dataset name up to its first '.'.
        save_path_full = join(save_path, selected_dataset.split('.')[0])

        if isdir(save_path_full):
            warn("\n'{0}' already exists. Voiding Download.".format(
                save_path_full))
        else:
            url = "{0}/{1}".format(self._require_url(), selected_dataset)
            self._print('Downloading Data...')
            self._download_data(url, save_path=save_path)

        return abspath(save_path_full)


================================================
FILE: util/html.py
================================================
import dominate
from dominate.tags import meta, h3, table, tr, td, p, a, img, br
import os


class HTML:
    """This HTML class allows us to save images and write texts into a single HTML file.

     It consists of functions such as <add_header> (add a text header to the HTML file),
     <add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).
     It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
    """

    def __init__(self, web_dir, title, refresh=0, folder='images'):
        """Initialize the HTML classes

        Parameters:
            web_dir (str) -- a directory that stores the webpage. The HTML file is written by <save> to
                             <web_dir>/index<name>.html (see <save>); images are stored in <web_dir>/<folder>/
            title (str)   -- the webpage name
            refresh (int) -- how often the website refresh itself; if 0; no refreshing
            folder (str)  -- name of the image sub-directory inside web_dir
        """
        self.title = title
        self.web_dir = web_dir
        self.img_dir = os.path.join(self.web_dir, folder)
        self.folder = folder
        if not os.path.exists(self.web_dir):
            os.makedirs(self.web_dir)
        if not os.path.exists(self.img_dir):
            os.makedirs(self.img_dir)

        self.doc = dominate.document(title=title)
        if refresh > 0:
            with self.doc.head:
                meta(http_equiv="refresh", content=str(refresh))

    def get_image_dir(self):
        """Return the directory that stores images"""
        return self.img_dir

    def add_header(self, text):
        """Insert a header to the HTML file

        Parameters:
            text (str) -- the header text
        """
        with self.doc:
            h3(text)

    def add_images(self, ims, txts, links, width=400):
        """add images to the HTML file

        Parameters:
            ims (str list)   -- a list of image paths, relative to the image folder
            txts (str list)  -- a list of image names shown on the website
            links (str list) -- a list of hyperref links, relative to the image folder; when you click an image, it will redirect you to a new page
            width (int)      -- the displayed width of each image in pixels
        """
        self.t = table(border=1, style="table-layout: fixed;")  # Insert a table
        self.doc.add(self.t)
        with self.t:
            with tr():
                for im, txt, link in zip(ims, txts, links):
                    with td(style="word-wrap: break-word;", halign="center", valign="top"):
                        with p():
                            # Link and thumbnail must point into the same folder; the link
                            # used to be hard-coded to 'images' and broke for custom folders.
                            with a(href=os.path.join(self.folder, link)):
                                img(style="width:%dpx" % width, src=os.path.join(self.folder, im))
                            br()
                            p(txt)

    def save(self):
        """save the current content to the HTML file

        The file is <web_dir>/index<name>.html, where <name> is the image folder
        name with a leading 'images' prefix removed (so the default folder
        'images' yields index.html).
        """
        name = self.folder[6:] if self.folder.startswith('images') else self.folder
        html_file = '%s/index%s.html' % (self.web_dir, name)
        # the context manager closes the file even if rendering or writing fails
        with open(html_file, 'wt') as f:
            f.write(self.doc.render())


if __name__ == '__main__':  # small demo: one header and one row of four images
    demo_page = HTML('web/', 'test_html')
    demo_page.add_header('hello world')

    indices = range(4)
    file_names = ['image_%d.png' % idx for idx in indices]
    captions = ['text_%d' % idx for idx in indices]
    # each image links to itself; pass an independent list like the original did
    demo_page.add_images(file_names, captions, list(file_names))
    demo_page.save()


================================================
FILE: util/image_pool.py
================================================
import random
import torch


class ImagePool():
    """History buffer of previously generated images.

    Discriminators can be updated with a mix of freshly generated images and
    images drawn from this history instead of only the newest generator output.
    """

    def __init__(self, pool_size):
        """Create the pool.

        Parameters:
            pool_size (int) -- capacity of the buffer; with pool_size=0 no buffer is created
        """
        self.pool_size = pool_size
        if pool_size > 0:  # start with an empty buffer
            self.num_imgs = 0
            self.images = []

    def query(self, images):
        """Exchange a batch of new images against the pool.

        Parameters:
            images: the latest generated images from the generator

        Returns a batch with one entry per input image.

        While the buffer is not full, every image is stored and returned as is.
        Afterwards each image is, with probability 1/2, returned unchanged;
        otherwise it replaces a randomly chosen stored image and that stored
        image is returned instead.
        """
        if self.pool_size == 0:  # buffering disabled: pass the batch through
            return images

        selected = []
        for sample in images:
            current = sample.data.unsqueeze(0)
            if self.num_imgs < self.pool_size:
                # still filling up: remember the image and hand it back
                self.images.append(current)
                self.num_imgs += 1
                selected.append(current)
                continue
            if random.uniform(0, 1) > 0.5:
                # swap: return an old image, keep the new one in its slot
                slot = random.randint(0, self.pool_size - 1)  # randint is inclusive
                previous = self.images[slot].clone()
                self.images[slot] = current
                selected.append(previous)
            else:
                selected.append(current)
        return torch.cat(selected, 0)  # stack the chosen images along the batch axis


================================================
FILE: util/util.py
================================================
"""This module contains simple helper functions """
from __future__ import print_function
import torch
import numpy as np
from PIL import Image
import os
import pdb
from scipy.io import savemat


def tensor2im(input_image, imtype=np.uint8):
    """Convert an image tensor into a numpy image array.

    Parameters:
        input_image (tensor) --  the input image tensor array
        imtype (type)        --  the desired type of the converted numpy array

    A numpy array is only cast to imtype. Any other non-tensor input is
    returned unchanged. For a tensor, its first element is taken, 1- and
    2-channel data are expanded to 3 channels, the channel axis is moved last
    and the values are mapped with (x + 1) / 2 * 255 before the cast.
    """
    if isinstance(input_image, np.ndarray):
        # already a numpy image: nothing to convert besides the dtype
        return input_image.astype(imtype)
    if not isinstance(input_image, torch.Tensor):
        return input_image

    channels_first = input_image.data[0].cpu().float().numpy()
    n_channels = channels_first.shape[0]
    if n_channels == 1:
        # grayscale to RGB
        channels_first = np.tile(channels_first, (3, 1, 1))
    elif n_channels == 2:
        # repeat the second channel to obtain three channels
        channels_first = np.concatenate([channels_first, channels_first[1:2, :, :]], 0)

    channels_last = np.transpose(channels_first, (1, 2, 0))
    scaled = (channels_last + 1) / 2.0 * 255.0
    return scaled.astype(imtype)


def diagnose_network(net, name='network'):
    """Print the network name and the average of the per-parameter mean absolute gradients.

    Parameters:
        net (torch network) -- Torch network
        name (str) -- the name of the network

    Parameters without a gradient are skipped; if none has a gradient, 0.0 is printed.
    """
    grad_means = [torch.mean(torch.abs(param.grad.data))
                  for param in net.parameters()
                  if param.grad is not None]
    mean = sum(grad_means, 0.0)
    if grad_means:
        mean = mean / len(grad_means)
    print(name)
    print(mean)


def save_image(image_numpy, image_path):
    """Write a numpy image to disk via PIL.

    Parameters:
        image_numpy (numpy array) -- input numpy array
        image_path (str)          -- the path of the image
    """
    Image.fromarray(image_numpy).save(image_path)


def print_numpy(x, val=True, shp=False):
    """Print summary statistics (mean, min, max, median, std) and/or the shape of a numpy array.

    Parameters:
        x (numpy array) -- the array to describe; it is cast to float64 first
        val (bool) -- if print the values of the numpy array
        shp (bool) -- if print the shape of the numpy array
    """
    values = x.astype(np.float64)
    if shp:
        print('shape,', values.shape)
    if not val:
        return
    flat = values.flatten()
    stats = (np.mean(flat), np.min(flat), np.max(flat), np.median(flat), np.std(flat))
    print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % stats)


def mkdirs(paths):
    """Create every given directory that does not exist yet.

    Parameters:
        paths (str list) -- a list of directory paths; a single non-list value is treated as one path
    """
    # a list is never a str, so checking for list alone is equivalent
    if not isinstance(paths, list):
        mkdir(paths)
        return
    for single_path in paths:
        mkdir(single_path)


def mkdir(path):
    """Create one directory (including missing parents) unless the path already exists.

    Parameters:
        path (str) -- a single directory path
    """
    if os.path.exists(path):
        return
    os.makedirs(path)

def normalize_tensor(in_feat,eps=1e-10):
    """Divide a tensor by its L2 norm taken along dimension 1.

    Parameters:
        in_feat (tensor) -- input tensor with at least two dimensions
        eps (float)      -- added to the norm to avoid a division by zero
    """
    squared_sum = (in_feat ** 2).sum(dim=1, keepdim=True)
    l2_norm = squared_sum.sqrt()
    return in_feat / (l2_norm + eps)

================================================
FILE: util/visualizer.py
================================================
import numpy as np
import os
import sys
import ntpath
import time
from . import util, html
from subprocess import Popen, PIPE
from PIL import Image

if sys.version_info[0] == 2:
    VisdomExceptionBase = Exception
else:
    VisdomExceptionBase = ConnectionError


def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256, W=None, H=None):
    """Write every visual to the webpage's image folder and add one row for them to the page.

    Parameters:
        webpage (the HTML class) -- the HTML webpage class that stores these images (see html.py for more details)
        visuals (OrderedDict)    -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs
        image_path (str list)    -- the base name of its first entry (without extension) prefixes every saved file
        aspect_ratio (float)     -- the aspect ratio of saved images; only used when W/H do not trigger a resize
        width (int)              -- the display width passed on to webpage.add_images
        W, H (int or None)       -- if both are given and differ from the image size, the image is resized to W x H

    Each image is saved as '<name>_<label>.png'.
    """
    target_dir = webpage.get_image_dir()
    name = os.path.splitext(ntpath.basename(image_path[0]))[0]
    webpage.add_header(name)

    file_names, labels = [], []
    for label, im_data in visuals.items():
        im = util.tensor2im(im_data)
        h, w, _ = im.shape

        # pick the output size (PIL order: width, height), or None to keep the image as is
        new_size = None
        if W is not None and H is not None and (W != w or H != h):
            new_size = (W, H)
        elif aspect_ratio > 1.0:
            new_size = (int(w * aspect_ratio), h)
        elif aspect_ratio < 1.0:
            new_size = (w, int(h / aspect_ratio))
        if new_size is not None:
            im = np.array(Image.fromarray(im).resize(new_size, Image.BICUBIC))

        image_name = '%s_%s.png' % (name, label)
        util.save_image(im, os.path.join(target_dir, image_name))
        file_names.append(image_name)
        labels.append(label)

    # every image links to itself
    webpage.add_images(file_names, labels, list(file_names), width=width)


class Visualizer():
    """This class includes several functions that can display/save images and print/save logging information.

    It uses a Python library 'visdom' for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images.
    """

    def __init__(self, opt):
        """Initialize the Visualizer class

        Parameters:
            opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
        Step 1: Cache the training/test options
        Step 2: connect to a visdom server (only when opt.display_id > 0)
        Step 3: create the directories used for saving HTML files (only when HTML output is enabled)
        Step 4: create a logging file to store training losses
        """
        self.opt = opt  # cache the option
        self.display_id = opt.display_id
        self.use_html = opt.isTrain and not opt.no_html
        self.win_size = opt.display_winsize
        self.name = opt.name
        self.port = opt.display_port
        self.saved = False
        if self.display_id > 0:  # connect to a visdom server given <display_port> and <display_server>
            import visdom
            self.ncols = opt.display_ncols
            self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env)
            if not self.vis.check_connection():
                self.create_visdom_connections()

        if self.use_html:  # create an HTML object at <checkpoints_dir>/web/; images will be saved under <checkpoints_dir>/web/images/
            self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
            self.img_dir = os.path.join(self.web_dir, 'images')
            print('create web directory %s...' % self.web_dir)
            util.mkdirs([self.web_dir, self.img_dir])
        # create a logging file to store training losses
        self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
        with open(self.log_name, "a") as log_file:
            now = time.strftime("%c")
            log_file.write('================ Training Loss (%s) ================\n' % now)

    def reset(self):
        """Reset the self.saved status"""
        self.saved = False

    def create_visdom_connections(self):
        """If the program could not connect to Visdom server, this function will start a new server at port < self.port > """
        # NOTE(review): '&>' is bash redirection syntax while shell=True runs the
        # platform's default shell -- confirm this behaves as intended where it is deployed.
        cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port
        print('\n\nCould not connect to Visdom server. \n Trying to start a server....')
        print('Command: %s' % cmd)
        Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)

    def display_current_results(self, visuals, epoch, save_result):
        """Display current results on visdom; save current results to an HTML file.

        Parameters:
            visuals (OrderedDict) - - dictionary of images to display or save
            epoch (int) - - the current epoch
            save_result (bool) - - if save the current results to an HTML file

        The visdom part runs only when display_id > 0. The HTML part runs only when
        HTML output is enabled and either save_result is set or nothing has been
        saved since the last reset().
        """
        if self.display_id > 0:  # show images in the browser using visdom
            ncols = self.ncols
            if ncols > 0:        # show all the images in one visdom panel
                ncols = min(ncols, len(visuals))
                # create a table of images.
                title = self.name
                label_html = ''
                label_html_row = ''
                images = []
                idx = 0
                for label, image in visuals.items():
                    image_numpy = util.tensor2im(image)
                    label_html_row += '<td>%s</td>' % label
                    images.append(image_numpy.transpose([2, 0, 1]))
                    idx += 1
                    if idx % ncols == 0:
                        label_html += '<tr>%s</tr>' % label_html_row
                        label_html_row = ''
                # Size the table cells from the first *converted* image. The entries of
                # `images` are channel-first, so height and width are axes 1 and 2.
                # (Reading shape[:2] of the raw visual is wrong for batched tensors.)
                h, w = images[0].shape[1:3]
                table_css = """<style>
                        table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center}
                        table td {width: % dpx; height: % dpx; padding: 4px; outline: 4px solid black}
                        </style>""" % (w, h)  # create a table css
                # pad the last row with white images so the grid is rectangular
                white_image = np.ones_like(images[-1]) * 255
                while idx % ncols != 0:
                    images.append(white_image)
                    label_html_row += '<td></td>'
                    idx += 1
                if label_html_row != '':
                    label_html += '<tr>%s</tr>' % label_html_row
                try:
                    self.vis.images(images, nrow=ncols, win=self.display_id + 1,
                                    padding=2, opts=dict(title=title + ' images'))
                    label_html = '<table>%s</table>' % label_html
                    self.vis.text(table_css + label_html, win=self.display_id + 2,
                                  opts=dict(title=title + ' labels'))
                except VisdomExceptionBase:
                    self.create_visdom_connections()

            else:     # show each image in a separate visdom panel;
                idx = 1
                try:
                    for label, image in visuals.items():
                        image_numpy = util.tensor2im(image)
                        self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label),
                                       win=self.display_id + idx)
                        idx += 1
                except VisdomExceptionBase:
                    self.create_visdom_connections()

        if self.use_html and (save_result or not self.saved):  # save images to an HTML file if they haven't been saved.
            self.saved = True
            # save images to the disk
            for label, image in visuals.items():
                image_numpy = util.tensor2im(image)
                img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
                util.save_image(image_numpy, img_path)

            # update website: rebuild the index page, newest epoch first
            webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1)
            for n in range(epoch, 0, -1):
                webpage.add_header('epoch [%d]' % n)
                ims, txts, links = [], [], []

                # only the labels are needed here: the page references image files
                # by name, so no image data has to be converted again
                for label in visuals:
                    img_path = 'epoch%.3d_%s.png' % (n, label)
                    ims.append(img_path)
                    txts.append(label)
                    links.append(img_path)
                webpage.add_images(ims, txts, links, width=self.win_size)
            webpage.save()

    def plot_current_losses(self, epoch, counter_ratio, losses):
        """display the current losses on visdom display: dictionary of error labels and values

        Parameters:
            epoch (int)           -- current epoch
            counter_ratio (float) -- progress (percentage) in the current epoch, between 0 to 1
            losses (OrderedDict)  -- training losses stored in the format of (name, float) pairs

        Does nothing when visdom display is disabled (display_id <= 0), because no
        visdom connection is created in that case.
        """
        if self.display_id <= 0:
            return
        if not hasattr(self, 'plot_data'):
            # the legend is fixed by the first call; later calls are read in that key order
            self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())}
        self.plot_data['X'].append(epoch + counter_ratio)
        self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']])
        try:
            self.vis.line(
                X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
                Y=np.array(self.plot_data['Y']),
                opts={
                    'title': self.name + ' loss over time',
                    'legend': self.plot_data['legend'],
                    'xlabel': 'epoch',
                    'ylabel': 'loss'},
                win=self.display_id)
        except VisdomExceptionBase:
            self.create_visdom_connections()

    def _log_message(self, message):
        """Print `message` on the console and append it as one line to the loss log file."""
        print(message)  # print the message
        with open(self.log_name, "a") as log_file:
            log_file.write('%s\n' % message)  # save the message

    # losses: same format as |losses| of plot_current_losses
    def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
        """print current losses on console; also save the losses to the disk

        Parameters:
            epoch (int) -- current epoch
            iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)
            losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
            t_comp (float) -- computational time per data point (normalized by batch_size)
            t_data (float) -- data loading time per data point (normalized by batch_size)
        """
        message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
        message += ''.join('%s: %.3f ' % (k, v) for k, v in losses.items())
        self._log_message(message)

    # losses: same format as |losses| of plot_current_losses
    def print_current_losses_process(self, epoch, iters, losses, t_comp, t_data, processes):
        """print current losses plus three progress values on console; also save them to the disk

        Parameters:
            epoch (int) -- current epoch
            iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)
            losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
            t_comp (float) -- computational time per data point (normalized by batch_size)
            t_data (float) -- data loading time per data point (normalized by batch_size)
            processes (sequence) -- its first three items are logged as 'process', 'non_trunc' and 'trunc'
        """
        message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
        message += '[process: %.3f, non_trunc: %.3f, trunc: %.3f] ' % (processes[0], processes[1], processes[2])
        message += ''.join('%s: %.3f ' % (k, v) for k, v in losses.items())
        self._log_message(message)
Download .txt
gitextract_uyxvzmrx/

├── .gitignore
├── data/
│   ├── __init__.py
│   ├── base_dataset.py
│   ├── image_folder.py
│   ├── single_dataset.py
│   └── unaligned_mask_stylecls_dataset.py
├── models/
│   ├── __init__.py
│   ├── asymmetric_cycle_gan_cls_model.py
│   ├── base_model.py
│   ├── dist_model.py
│   ├── networks.py
│   ├── networks_basic.py
│   ├── pretrained_networks.py
│   ├── test_3styles_model.py
│   └── test_model.py
├── options/
│   ├── __init__.py
│   ├── base_options.py
│   ├── test_options.py
│   └── train_options.py
├── portrait_drawing_resources.md
├── preprocess/
│   ├── example/
│   │   └── ia_selfie_10515_facial5point.mat
│   ├── face_align_512.m
│   └── readme.md
├── readme.md
├── requirements.txt
├── scripts/
│   └── train.sh
├── test.py
├── test_seq_style.py
├── train.py
└── util/
    ├── __init__.py
    ├── get_data.py
    ├── html.py
    ├── image_pool.py
    ├── util.py
    └── visualizer.py
Download .txt
SYMBOL INDEX (224 symbols across 23 files)

FILE: data/__init__.py
  function find_dataset_using_name (line 18) | def find_dataset_using_name(dataset_name):
  function get_option_setter (line 41) | def get_option_setter(dataset_name):
  function create_dataset (line 47) | def create_dataset(opt):
  class CustomDatasetDataLoader (line 62) | class CustomDatasetDataLoader():
    method __init__ (line 65) | def __init__(self, opt):
    method load_data (line 81) | def load_data(self):
    method __len__ (line 84) | def __len__(self):
    method __iter__ (line 88) | def __iter__(self):

FILE: data/base_dataset.py
  class BaseDataset (line 13) | class BaseDataset(data.Dataset):
    method __init__ (line 24) | def __init__(self, opt):
    method modify_commandline_options (line 34) | def modify_commandline_options(parser, is_train):
    method __len__ (line 47) | def __len__(self):
    method __getitem__ (line 52) | def __getitem__(self, index):
  function get_params (line 64) | def get_params(opt, size):
  function get_transform (line 82) | def get_transform(opt, params=None, grayscale=False, method=Image.BICUBI...
  function get_transform_mask (line 115) | def get_transform_mask(opt, params=None, grayscale=False, method=Image.B...
  function __make_power_2 (line 144) | def __make_power_2(img, base, method=Image.BICUBIC):
  function __scale_width (line 155) | def __scale_width(img, target_width, method=Image.BICUBIC):
  function __crop (line 164) | def __crop(img, pos, size):
  function __flip (line 173) | def __flip(img, flip):
  function __print_size_warning (line 179) | def __print_size_warning(ow, oh, w, h):

FILE: data/image_folder.py
  function is_image_file (line 19) | def is_image_file(filename):
  function make_dataset (line 23) | def make_dataset(dir, max_dataset_size=float("inf")):
  function default_loader (line 35) | def default_loader(path):
  class ImageFolder (line 39) | class ImageFolder(data.Dataset):
    method __init__ (line 41) | def __init__(self, root, transform=None, return_paths=False,
    method __getitem__ (line 55) | def __getitem__(self, index):
    method __len__ (line 65) | def __len__(self):

FILE: data/single_dataset.py
  class SingleDataset (line 8) | class SingleDataset(BaseDataset):
    method __init__ (line 14) | def __init__(self, opt):
    method __getitem__ (line 28) | def __getitem__(self, index):
    method __len__ (line 61) | def __len__(self):

FILE: data/unaligned_mask_stylecls_dataset.py
  class UnalignedMaskStyleClsDataset (line 12) | class UnalignedMaskStyleClsDataset(BaseDataset):
    method __init__ (line 13) | def __init__(self, opt):
    method __getitem__ (line 34) | def __getitem__(self, index):
    method __len__ (line 98) | def __len__(self):

FILE: models/__init__.py
  function find_model_using_name (line 25) | def find_model_using_name(model_name):
  function get_option_setter (line 48) | def get_option_setter(model_name):
  function create_model (line 54) | def create_model(opt):

FILE: models/asymmetric_cycle_gan_cls_model.py
  function truncate (line 10) | def truncate(fake_B,a=127.5):#[-1,1]
  class AsymmetricCycleGANClsModel (line 13) | class AsymmetricCycleGANClsModel(BaseModel):
    method modify_commandline_options (line 15) | def modify_commandline_options(parser, is_train=True):
    method __init__ (line 44) | def __init__(self, opt):
    method set_input (line 176) | def set_input(self, input):
    method forward (line 206) | def forward(self):
    method backward_D_basic (line 231) | def backward_D_basic(self, netD, real, fake):
    method backward_D_basic_cls (line 253) | def backward_D_basic_cls(self, netD, real, fake):
    method backward_D_A (line 275) | def backward_D_A(self):
    method backward_D_A_l (line 280) | def backward_D_A_l(self):
    method backward_D_A_le (line 285) | def backward_D_A_le(self):
    method backward_D_A_ll (line 290) | def backward_D_A_ll(self):
    method backward_D_B (line 295) | def backward_D_B(self):
    method update_process (line 300) | def update_process(self, epoch):
    method backward_G (line 303) | def backward_G(self):
    method optimize_parameters (line 374) | def optimize_parameters(self):

FILE: models/base_model.py
  class BaseModel (line 8) | class BaseModel():
    method __init__ (line 19) | def __init__(self, opt):
    method modify_commandline_options (line 48) | def modify_commandline_options(parser, is_train):
    method set_input (line 61) | def set_input(self, input):
    method forward (line 70) | def forward(self):
    method optimize_parameters (line 75) | def optimize_parameters(self):
    method setup (line 79) | def setup(self, opt):
    method eval (line 92) | def eval(self):
    method test (line 99) | def test(self):
    method compute_visuals (line 109) | def compute_visuals(self):
    method get_image_paths (line 113) | def get_image_paths(self):
    method update_learning_rate (line 117) | def update_learning_rate(self):
    method get_current_visuals (line 128) | def get_current_visuals(self):
    method get_current_losses (line 136) | def get_current_losses(self):
    method save_networks (line 144) | def save_networks(self, epoch):
    method __patch_instance_norm_state_dict (line 162) | def __patch_instance_norm_state_dict(self, state_dict, module, keys, i...
    method load_networks (line 176) | def load_networks(self, epoch):
    method print_networks (line 201) | def print_networks(self, verbose):
    method set_requires_grad (line 219) | def set_requires_grad(self, nets, requires_grad=False):
    method masked (line 233) | def masked(self, A,mask):

FILE: models/dist_model.py
  class DistModel (line 20) | class DistModel(BaseModel):
    method name (line 21) | def name(self):
    method __init__ (line 24) | def __init__(self, opt, model='net-lin', net='alex', pnet_rand=False, ...
    method forward_pair (line 106) | def forward_pair(self,in1,in2,retPerLayer=False):
    method forward (line 112) | def forward(self, in0, in1, retNumpy=False):
    method optimize_parameters (line 159) | def optimize_parameters(self):
    method clamp_weights (line 166) | def clamp_weights(self):
    method set_input (line 171) | def set_input(self, data):
    method forward_train (line 187) | def forward_train(self): # run forward pass
    method backward_train (line 198) | def backward_train(self):
    method compute_accuracy (line 201) | def compute_accuracy(self,d0,d1,judge):
    method get_current_errors (line 207) | def get_current_errors(self):
    method get_current_visuals (line 216) | def get_current_visuals(self):
    method save (line 231) | def save(self, path, label):
    method update_learning_rate (line 235) | def update_learning_rate(self,nepoch_decay):
  function score_2afc_dataset (line 247) | def score_2afc_dataset(data_loader,func):
  function score_jnd_dataset (line 284) | def score_jnd_dataset(data_loader,func):

FILE: models/networks.py
  class Identity (line 13) | class Identity(nn.Module):
    method forward (line 14) | def forward(self, x):
  function get_norm_layer (line 18) | def get_norm_layer(norm_type='instance'):
  function get_scheduler (line 38) | def get_scheduler(optimizer, opt):
  function init_weights (line 67) | def init_weights(net, init_type='normal', init_gain=0.02):
  function init_net (line 101) | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
  function define_G (line 119) | def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=F...
  function define_D (line 164) | def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type=...
  function define_HED (line 210) | def define_HED(init_weights_, gpu_ids_=[]):
  class GANLoss (line 233) | class GANLoss(nn.Module):
    method __init__ (line 240) | def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=...
    method get_target_tensor (line 264) | def get_target_tensor(self, prediction, target_is_real):
    method __call__ (line 281) | def __call__(self, prediction, target_is_real):
  function cal_gradient_penalty (line 302) | def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed...
  class ResnetGenerator (line 339) | class ResnetGenerator(nn.Module):
    method __init__ (line 345) | def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNor...
    method forward (line 395) | def forward(self, input):
  class ResnetStyle2Generator (line 399) | class ResnetStyle2Generator(nn.Module):
    method __init__ (line 400) | def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNor...
    method forward (line 459) | def forward(self, input1, input2):
  class ResnetBlock (line 465) | class ResnetBlock(nn.Module):
    method __init__ (line 468) | def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bia...
    method build_conv_block (line 479) | def build_conv_block(self, dim, padding_type, norm_layer, use_dropout,...
    method forward (line 520) | def forward(self, x):
  class UnetGenerator (line 526) | class UnetGenerator(nn.Module):
    method __init__ (line 529) | def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=...
    method forward (line 553) | def forward(self, input):
  class UnetSkipConnectionBlock (line 558) | class UnetSkipConnectionBlock(nn.Module):
    method __init__ (line 564) | def __init__(self, outer_nc, inner_nc, input_nc=None,
    method forward (line 621) | def forward(self, x):
  class NLayerDiscriminator (line 628) | class NLayerDiscriminator(nn.Module):
    method __init__ (line 631) | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNo...
    method forward (line 671) | def forward(self, input):
  class NLayerDiscriminatorCls (line 676) | class NLayerDiscriminatorCls(nn.Module):
    method __init__ (line 679) | def __init__(self, input_nc, ndf=64, n_layers=3, n_class=3, norm_layer...
    method forward (line 736) | def forward(self, input):
  class PixelDiscriminator (line 746) | class PixelDiscriminator(nn.Module):
    method __init__ (line 749) | def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
    method forward (line 773) | def forward(self, input):
  class HED (line 778) | class HED(nn.Module):
    method __init__ (line 779) | def __init__(self):
    method forward (line 838) | def forward(self, tensorInput):

FILE: models/networks_basic.py
  function spatial_average (line 17) | def spatial_average(in_tens, keepdim=True):
  function upsample (line 20) | def upsample(in_tens, out_H=64): # assumes scale factor is same for H and W
  class PNetLin (line 27) | class PNetLin(nn.Module):
    method __init__ (line 28) | def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, ...
    method forward (line 64) | def forward(self, in0, in1, retPerLayer=False):
  class ScalingLayer (line 94) | class ScalingLayer(nn.Module):
    method __init__ (line 95) | def __init__(self):
    method forward (line 100) | def forward(self, inp):
  class NetLinLayer (line 104) | class NetLinLayer(nn.Module):
    method __init__ (line 106) | def __init__(self, chn_in, chn_out=1, use_dropout=False):
  class Dist2LogitLayer (line 114) | class Dist2LogitLayer(nn.Module):
    method __init__ (line 116) | def __init__(self, chn_mid=32, use_sigmoid=True):
    method forward (line 128) | def forward(self,d0,d1,eps=0.1):
  class BCERankingLoss (line 131) | class BCERankingLoss(nn.Module):
    method __init__ (line 132) | def __init__(self, chn_mid=32):
    method forward (line 138) | def forward(self, d0, d1, judge):
  class FakeNet (line 144) | class FakeNet(nn.Module):
    method __init__ (line 145) | def __init__(self, use_gpu=True, colorspace='Lab'):
  class L2 (line 150) | class L2(FakeNet):
    method forward (line 152) | def forward(self, in0, in1, retPerLayer=None):
  class DSSIM (line 167) | class DSSIM(FakeNet):
    method forward (line 169) | def forward(self, in0, in1, retPerLayer=None):
  function print_network (line 182) | def print_network(net):

FILE: models/pretrained_networks.py
  class squeezenet (line 6) | class squeezenet(torch.nn.Module):
    method __init__ (line 7) | def __init__(self, requires_grad=False, pretrained=True):
    method forward (line 36) | def forward(self, X):
  class alexnet (line 57) | class alexnet(torch.nn.Module):
    method __init__ (line 58) | def __init__(self, requires_grad=False, pretrained=True):
    method forward (line 81) | def forward(self, X):
  class vgg16 (line 97) | class vgg16(torch.nn.Module):
    method __init__ (line 98) | def __init__(self, requires_grad=False, pretrained=True):
    method forward (line 121) | def forward(self, X):
  class resnet (line 139) | class resnet(torch.nn.Module):
    method __init__ (line 140) | def __init__(self, requires_grad=False, pretrained=True, num=18):
    method forward (line 163) | def forward(self, X):

FILE: models/test_3styles_model.py
  class Test3StylesModel (line 5) | class Test3StylesModel(BaseModel):
    method modify_commandline_options (line 12) | def modify_commandline_options(parser, is_train=True):
    method __init__ (line 34) | def __init__(self, opt):
    method set_input (line 51) | def set_input(self, input):
    method forward (line 58) | def forward(self):
    method optimize_parameters (line 64) | def optimize_parameters(self):

FILE: models/test_model.py
  class TestModel (line 5) | class TestModel(BaseModel):
    method modify_commandline_options (line 12) | def modify_commandline_options(parser, is_train=True):
    method __init__ (line 40) | def __init__(self, opt):
    method set_input (line 71) | def set_input(self, input):
    method forward (line 84) | def forward(self):
    method optimize_parameters (line 93) | def optimize_parameters(self):

FILE: options/base_options.py
  class BaseOptions (line 9) | class BaseOptions():
    method __init__ (line 16) | def __init__(self):
    method initialize (line 20) | def initialize(self, parser):
    method gather_options (line 61) | def gather_options(self):
    method print_options (line 89) | def print_options(self, opt):
    method parse (line 114) | def parse(self):

FILE: options/test_options.py
  class TestOptions (line 4) | class TestOptions(BaseOptions):
    method initialize (line 10) | def initialize(self, parser):

FILE: options/train_options.py
  class TrainOptions (line 4) | class TrainOptions(BaseOptions):
    method initialize (line 10) | def initialize(self, parser):

FILE: test_seq_style.py
  function opts (line 4) | def opts():

FILE: util/get_data.py
  class GetData (line 11) | class GetData(object):
    method __init__ (line 27) | def __init__(self, technique='cyclegan', verbose=True):
    method _print (line 35) | def _print(self, text):
    method _get_options (line 40) | def _get_options(r):
    method _present_options (line 46) | def _present_options(self):
    method _download_data (line 56) | def _download_data(self, dataset_url, save_path):
    method get (line 79) | def get(self, save_path, dataset=None):

FILE: util/html.py
  class HTML (line 6) | class HTML:
    method __init__ (line 14) | def __init__(self, web_dir, title, refresh=0, folder='images'):
    method get_image_dir (line 37) | def get_image_dir(self):
    method add_header (line 41) | def add_header(self, text):
    method add_images (line 50) | def add_images(self, ims, txts, links, width=400):
    method save (line 71) | def save(self):

FILE: util/image_pool.py
  class ImagePool (line 5) | class ImagePool():
    method __init__ (line 12) | def __init__(self, pool_size):
    method query (line 23) | def query(self, images):

FILE: util/util.py
  function tensor2im (line 11) | def tensor2im(input_image, imtype=np.uint8):
  function diagnose_network (line 36) | def diagnose_network(net, name='network'):
  function save_image (line 55) | def save_image(image_numpy, image_path):
  function print_numpy (line 67) | def print_numpy(x, val=True, shp=False):
  function mkdirs (line 83) | def mkdirs(paths):
  function mkdir (line 96) | def mkdir(path):
  function normalize_tensor (line 105) | def normalize_tensor(in_feat,eps=1e-10):

FILE: util/visualizer.py
  function save_images (line 16) | def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=25...
  class Visualizer (line 56) | class Visualizer():
    method __init__ (line 62) | def __init__(self, opt):
    method reset (line 97) | def reset(self):
    method create_visdom_connections (line 101) | def create_visdom_connections(self):
    method display_current_results (line 108) | def display_current_results(self, visuals, epoch, save_result):
    method plot_current_losses (line 189) | def plot_current_losses(self, epoch, counter_ratio, losses):
    method print_current_losses (line 217) | def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
    method print_current_losses_process (line 236) | def print_current_losses_process(self, epoch, iters, losses, t_comp, t...
Condensed preview — 35 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (196K chars).
[
  {
    "path": ".gitignore",
    "chars": 756,
    "preview": ".DS_Store\ndebug*\ndatasets/\ncheckpoints/\nstyle_features/\nresults/\nbuild/\ndist/\ntorch.egg-info/\n*/**/__pycache__\ntorch/ver"
  },
  {
    "path": "data/__init__.py",
    "chars": 3554,
    "preview": "\"\"\"This package includes all the modules related to data loading and preprocessing\n\n To add a custom dataset class calle"
  },
  {
    "path": "data/base_dataset.py",
    "chars": 6615,
    "preview": "\"\"\"This module implements an abstract base class (ABC) 'BaseDataset' for datasets.\n\nIt also includes common transformati"
  },
  {
    "path": "data/image_folder.py",
    "chars": 1893,
    "preview": "\"\"\"A modified image folder class\n\nWe modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/ma"
  },
  {
    "path": "data/single_dataset.py",
    "chars": 2583,
    "preview": "from data.base_dataset import BaseDataset, get_transform, get_params, get_transform_mask\nfrom data.image_folder import m"
  },
  {
    "path": "data/unaligned_mask_stylecls_dataset.py",
    "chars": 4851,
    "preview": "import os.path\nfrom data.base_dataset import BaseDataset, get_params, get_transform, get_transform_mask\nfrom data.image_"
  },
  {
    "path": "models/__init__.py",
    "chars": 3072,
    "preview": "\"\"\"This package contains modules related to objective functions, optimizations, and network architectures.\n\nTo add a cus"
  },
  {
    "path": "models/asymmetric_cycle_gan_cls_model.py",
    "chars": 23569,
    "preview": "import torch\nimport itertools\nfrom util.image_pool import ImagePool\nfrom .base_model import BaseModel\nfrom . import netw"
  },
  {
    "path": "models/base_model.py",
    "chars": 10885,
    "preview": "import os\nimport torch\nfrom collections import OrderedDict\nfrom abc import ABCMeta, abstractmethod\nfrom . import network"
  },
  {
    "path": "models/dist_model.py",
    "chars": 13695,
    "preview": "\nfrom __future__ import absolute_import\n\nimport sys\nsys.path.append('..')\nsys.path.append('.')\nimport numpy as np\nimport"
  },
  {
    "path": "models/networks.py",
    "chars": 35426,
    "preview": "import torch\nimport torch.nn as nn\nfrom torch.nn import init\nimport functools\nfrom torch.optim import lr_scheduler\n\n\n###"
  },
  {
    "path": "models/networks_basic.py",
    "chars": 7514,
    "preview": "\nfrom __future__ import absolute_import\n\nimport sys\nimport torch\nimport torch.nn as nn\nimport torch.nn.init as init\nfrom"
  },
  {
    "path": "models/pretrained_networks.py",
    "chars": 6559,
    "preview": "from collections import namedtuple\nimport torch\nfrom torchvision import models\nfrom IPython import embed\n\nclass squeezen"
  },
  {
    "path": "models/test_3styles_model.py",
    "chars": 3411,
    "preview": "from .base_model import BaseModel\nfrom . import networks\nimport torch\n\nclass Test3StylesModel(BaseModel):\n    \"\"\" This T"
  },
  {
    "path": "models/test_model.py",
    "chars": 5123,
    "preview": "from .base_model import BaseModel\nfrom . import networks\nimport torch\n\nclass TestModel(BaseModel):\n    \"\"\" This TesteMod"
  },
  {
    "path": "options/__init__.py",
    "chars": 136,
    "preview": "\"\"\"This package options includes option modules: training options, test options, and basic options (used in both trainin"
  },
  {
    "path": "options/base_options.py",
    "chars": 8437,
    "preview": "import argparse\nimport os\nfrom util import util\nimport torch\nimport models\nimport data\n\n\nclass BaseOptions():\n    \"\"\"Thi"
  },
  {
    "path": "options/test_options.py",
    "chars": 1363,
    "preview": "from .base_options import BaseOptions\n\n\nclass TestOptions(BaseOptions):\n    \"\"\"This class includes test options.\n\n    It"
  },
  {
    "path": "options/train_options.py",
    "chars": 3516,
    "preview": "from .base_options import BaseOptions\n\n\nclass TrainOptions(BaseOptions):\n    \"\"\"This class includes training options.\n\n "
  },
  {
    "path": "portrait_drawing_resources.md",
    "chars": 315,
    "preview": "\n- Charles Burns (style1): https://www.pinterest.co.uk/johns59/charles-burns-fan-club/\n- Yann Legendre (style1): http://"
  },
  {
    "path": "preprocess/face_align_512.m",
    "chars": 1450,
    "preview": "function [trans_img]=face_align_512(impath,facial5point,savedir)\n% align the faces by similarity transformation.\n% using"
  },
  {
    "path": "preprocess/readme.md",
    "chars": 1612,
    "preview": "## Preprocessing steps\n\nDuring training, face photos and drawings are aligned and have nose,eyes,lips mask detected. \n\nD"
  },
  {
    "path": "readme.md",
    "chars": 4506,
    "preview": "\n# Unpaired Portrait Drawing Generation via Asymmetric Cycle Mapping\n\nWe provide PyTorch implementations for our CVPR 20"
  },
  {
    "path": "requirements.txt",
    "chars": 128,
    "preview": "torch==1.2.0\ntorchvision==0.4.0\ndominate==2.4.0\nvisdom==0.1.8.9\nscipy==1.1.0\nnumpy==1.16.4\nPillow==6.2.1\nopencv-python=="
  },
  {
    "path": "scripts/train.sh",
    "chars": 313,
    "preview": "set -ex\npython train.py --dataroot ./datasets/portrait_drawing --name formal --model asymmetric_cycle_gan_cls --output_n"
  },
  {
    "path": "test.py",
    "chars": 4143,
    "preview": "\"\"\"General-purpose test script for image-to-image translation.\n\nOnce you have trained your model with train.py, you can "
  },
  {
    "path": "test_seq_style.py",
    "chars": 1809,
    "preview": "import os\nimport argparse\n\ndef opts():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('-g','--gpu', defa"
  },
  {
    "path": "train.py",
    "chars": 5219,
    "preview": "\"\"\"General-purpose training script for image-to-image translation.\n\nThis script works for various models (with option '-"
  },
  {
    "path": "util/__init__.py",
    "chars": 83,
    "preview": "\"\"\"This package includes a miscellaneous collection of useful helper functions.\"\"\"\n"
  },
  {
    "path": "util/get_data.py",
    "chars": 3639,
    "preview": "from __future__ import print_function\nimport os\nimport tarfile\nimport requests\nfrom warnings import warn\nfrom zipfile im"
  },
  {
    "path": "util/html.py",
    "chars": 3569,
    "preview": "import dominate\nfrom dominate.tags import meta, h3, table, tr, td, p, a, img, br\nimport os\n\n\nclass HTML:\n    \"\"\"This HTM"
  },
  {
    "path": "util/image_pool.py",
    "chars": 2226,
    "preview": "import random\nimport torch\n\n\nclass ImagePool():\n    \"\"\"This class implements an image buffer that stores previously gene"
  },
  {
    "path": "util/util.py",
    "chars": 3326,
    "preview": "\"\"\"This module contains simple helper functions \"\"\"\nfrom __future__ import print_function\nimport torch\nimport numpy as n"
  },
  {
    "path": "util/visualizer.py",
    "chars": 12238,
    "preview": "import numpy as np\nimport os\nimport sys\nimport ntpath\nimport time\nfrom . import util, html\nfrom subprocess import Popen,"
  }
]

// ... and 1 more files (download for full content)

About this extraction

This page contains the full source code of the yiranran/Unpaired-Portrait-Drawing GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 35 files (183.1 KB), approximately 47.7k tokens, and a symbol index with 224 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!