[
  {
    "path": ".gitignore",
    "content": ".DS_Store\ndebug*\ndatasets/\ncheckpoints/\nstyle_features/\nresults/\nbuild/\ndist/\ntorch.egg-info/\n*/**/__pycache__\ntorch/version.py\ntorch/csrc/generic/TensorMethods.cpp\ntorch/lib/*.so*\ntorch/lib/*.dylib*\ntorch/lib/*.h\ntorch/lib/build\ntorch/lib/tmp_install\ntorch/lib/include\ntorch/lib/torch_shm_manager\ntorch/csrc/cudnn/cuDNN.cpp\ntorch/csrc/nn/THNN.cwrap\ntorch/csrc/nn/THNN.cpp\ntorch/csrc/nn/THCUNN.cwrap\ntorch/csrc/nn/THCUNN.cpp\ntorch/csrc/nn/THNN_generic.cwrap\ntorch/csrc/nn/THNN_generic.cpp\ntorch/csrc/nn/THNN_generic.h\ndocs/src/**/*\ntest/data/legacy_modules.t7\ntest/data/gpu_tensors.pt\ntest/htmlcov\ntest/.coverage\n*/*.pyc\n*/**/*.pyc\n*/**/**/*.pyc\n*/**/**/**/*.pyc\n*/**/**/**/**/*.pyc\n*/*.so*\n*/**/*.so*\n*/**/*.dylib*\ntest/data/legacy_serialized.pt\n*~\n.idea\n"
  },
  {
    "path": "data/__init__.py",
    "content": "\"\"\"This package includes all the modules related to data loading and preprocessing\n\n To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.\n You need to implement four functions:\n    -- <__init__>:                      initialize the class, first call BaseDataset.__init__(self, opt).\n    -- <__len__>:                       return the size of dataset.\n    -- <__getitem__>:                   get a data point from data loader.\n    -- <modify_commandline_options>:    (optionally) add dataset-specific options and set default options.\n\nNow you can use the dataset class by specifying flag '--dataset_mode dummy'.\nSee our template dataset class 'template_dataset.py' for more details.\n\"\"\"\nimport importlib\nimport torch.utils.data\nfrom data.base_dataset import BaseDataset\n\n\ndef find_dataset_using_name(dataset_name):\n    \"\"\"Import the module \"data/[dataset_name]_dataset.py\".\n\n    In the file, the class called DatasetNameDataset() will\n    be instantiated. It has to be a subclass of BaseDataset,\n    and it is case-insensitive.\n    \"\"\"\n    dataset_filename = \"data.\" + dataset_name + \"_dataset\"\n    datasetlib = importlib.import_module(dataset_filename)\n\n    dataset = None\n    target_dataset_name = dataset_name.replace('_', '') + 'dataset'\n    for name, cls in datasetlib.__dict__.items():\n        if name.lower() == target_dataset_name.lower() \\\n           and issubclass(cls, BaseDataset):\n            dataset = cls\n\n    if dataset is None:\n        raise NotImplementedError(\"In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase.\" % (dataset_filename, target_dataset_name))\n\n    return dataset\n\n\ndef get_option_setter(dataset_name):\n    \"\"\"Return the static method <modify_commandline_options> of the dataset class.\"\"\"\n    dataset_class = find_dataset_using_name(dataset_name)\n    return dataset_class.modify_commandline_options\n\n\ndef create_dataset(opt):\n    \"\"\"Create a dataset given the option.\n\n    This function wraps the class CustomDatasetDataLoader.\n        This is the main interface between this package and 'train.py'/'test.py'\n\n    Example:\n        >>> from data import create_dataset\n        >>> dataset = create_dataset(opt)\n    \"\"\"\n    data_loader = CustomDatasetDataLoader(opt)\n    dataset = data_loader.load_data()\n    return dataset\n\n\nclass CustomDatasetDataLoader():\n    \"\"\"Wrapper class of Dataset class that performs multi-threaded data loading\"\"\"\n\n    def __init__(self, opt):\n        \"\"\"Initialize this class\n\n        Step 1: create a dataset instance given the name [dataset_mode]\n        Step 2: create a multi-threaded data loader.\n        \"\"\"\n        self.opt = opt\n        dataset_class = find_dataset_using_name(opt.dataset_mode)\n        self.dataset = dataset_class(opt)\n        print(\"dataset [%s] was created\" % type(self.dataset).__name__)\n        self.dataloader = torch.utils.data.DataLoader(\n            self.dataset,\n            batch_size=opt.batch_size,\n            shuffle=not opt.serial_batches,\n            num_workers=int(opt.num_threads))\n\n    def load_data(self):\n        return self\n\n    def __len__(self):\n        \"\"\"Return the number of data in the dataset\"\"\"\n        return min(len(self.dataset), self.opt.max_dataset_size)\n\n    def __iter__(self):\n        \"\"\"Return a batch of data\"\"\"\n  
      for i, data in enumerate(self.dataloader):\n            if i * self.opt.batch_size >= self.opt.max_dataset_size:\n                break\n            yield data\n"
  },
  {
    "path": "data/base_dataset.py",
    "content": "\"\"\"This module implements an abstract base class (ABC) 'BaseDataset' for datasets.\n\nIt also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.\n\"\"\"\nimport random\nimport numpy as np\nimport torch.utils.data as data\nfrom PIL import Image\nimport torchvision.transforms as transforms\nfrom abc import ABCMeta, abstractmethod\n\n\nclass BaseDataset(data.Dataset):\n    __metaclass__ = ABCMeta\n    \"\"\"This class is an abstract base class (ABC) for datasets.\n\n    To create a subclass, you need to implement the following four functions:\n    -- <__init__>:                      initialize the class, first call BaseDataset.__init__(self, opt).\n    -- <__len__>:                       return the size of dataset.\n    -- <__getitem__>:                   get a data point.\n    -- <modify_commandline_options>:    (optionally) add dataset-specific options and set default options.\n    \"\"\"\n\n    def __init__(self, opt):\n        \"\"\"Initialize the class; save the options in the class\n\n        Parameters:\n            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions\n        \"\"\"\n        self.opt = opt\n        self.root = opt.dataroot\n\n    @staticmethod\n    def modify_commandline_options(parser, is_train):\n        \"\"\"Add new dataset-specific options, and rewrite default values for existing options.\n\n        Parameters:\n            parser          -- original option parser\n            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.\n\n        Returns:\n            the modified parser.\n        \"\"\"\n        return parser\n\n    @abstractmethod\n    def __len__(self):\n        \"\"\"Return the total number of images in the dataset.\"\"\"\n        return 0\n\n    @abstractmethod\n    def __getitem__(self, index):\n        \"\"\"Return a data point and its metadata information.\n\n        Parameters:\n            index - - a random integer for data indexing\n\n        Returns:\n            a dictionary of data with their names. 
It usually contains the data itself and its metadata information.\n        \"\"\"\n        pass\n\n\ndef get_params(opt, size):\n    w, h = size\n    new_h = h\n    new_w = w\n    if opt.preprocess == 'resize_and_crop':\n        new_h = new_w = opt.load_size\n    elif opt.preprocess == 'scale_width_and_crop':\n        new_w = opt.load_size\n        new_h = opt.load_size * h // w\n\n    x = random.randint(0, np.maximum(0, new_w - opt.crop_size))\n    y = random.randint(0, np.maximum(0, new_h - opt.crop_size))\n\n    flip = random.random() > 0.5\n\n    return {'crop_pos': (x, y), 'flip': flip}\n\n\ndef get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):\n    transform_list = []\n    if grayscale:\n        transform_list.append(transforms.Grayscale(1))\n    if 'resize' in opt.preprocess:\n        osize = [opt.load_size, opt.load_size]\n        transform_list.append(transforms.Resize(osize, method))\n    elif 'scale_width' in opt.preprocess:\n        transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method)))\n\n    if 'crop' in opt.preprocess:\n        if params is None:\n            transform_list.append(transforms.RandomCrop(opt.crop_size))\n        else:\n            transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))\n\n    if opt.preprocess == 'none':\n        transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))\n\n    if not opt.no_flip:\n        if params is None:\n            transform_list.append(transforms.RandomHorizontalFlip())\n        elif params['flip']:\n            transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))\n\n    if convert:\n        transform_list += [transforms.ToTensor()]\n        if grayscale:\n            transform_list += [transforms.Normalize((0.5,), (0.5,))]\n        else:\n            transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]\n    return transforms.Compose(transform_list)\n\n\ndef get_transform_mask(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):\n    transform_list = []\n    if grayscale:\n        transform_list.append(transforms.Grayscale(1))\n    if 'resize' in opt.preprocess:\n        osize = [opt.load_size, opt.load_size]\n        transform_list.append(transforms.Resize(osize, method))\n    elif 'scale_width' in opt.preprocess:\n        transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method)))\n\n    if 'crop' in opt.preprocess:\n        if params is None:\n            transform_list.append(transforms.RandomCrop(opt.crop_size))\n        else:\n            transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))\n\n    if opt.preprocess == 'none':\n        transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))\n\n    if not opt.no_flip:\n        if params is None:\n            transform_list.append(transforms.RandomHorizontalFlip())\n        elif params['flip']:\n            transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))\n\n    if convert:\n        transform_list += [transforms.ToTensor()]\n    return transforms.Compose(transform_list)\n\n\ndef __make_power_2(img, base, method=Image.BICUBIC):\n    ow, oh = img.size\n    h = int(round(oh / base) * base)\n    w = int(round(ow / base) * base)\n    if (h == oh) and (w == ow):\n        return img\n\n    
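# Size is not a multiple of base: warn once, then resize.\n    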
__print_size_warning(ow, oh, w, h)\n    return img.resize((w, h), method)\n\n\ndef __scale_width(img, target_width, method=Image.BICUBIC):\n    ow, oh = img.size\n    if (ow == target_width):\n        return img\n    w = target_width\n    h = int(target_width * oh / ow)\n    return img.resize((w, h), method)\n\n\ndef __crop(img, pos, size):\n    ow, oh = img.size\n    x1, y1 = pos\n    tw = th = size\n    if (ow > tw or oh > th):\n        return img.crop((x1, y1, x1 + tw, y1 + th))\n    return img\n\n\ndef __flip(img, flip):\n    if flip:\n        return img.transpose(Image.FLIP_LEFT_RIGHT)\n    return img\n\n\ndef __print_size_warning(ow, oh, w, h):\n    \"\"\"Print warning information about image size (only print once)\"\"\"\n    if not hasattr(__print_size_warning, 'has_printed'):\n        print(\"The image size needs to be a multiple of 4. \"\n              \"The loaded image size was (%d, %d), so it was adjusted to \"\n              \"(%d, %d). This adjustment will be done to all images \"\n              \"whose sizes are not multiples of 4\" % (ow, oh, w, h))\n        __print_size_warning.has_printed = True\n
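\n\n# Typical use from a dataset subclass (sketch; opt must carry preprocess, load_size,\n# crop_size and no_flip, as defined by the option parser):\n#   params = get_params(opt, img.size)                            # fix crop/flip once\n#   img_t = get_transform(opt, params)(img)                       # in [-1, 1]\n#   mask_t = get_transform_mask(opt, params, grayscale=1)(mask)   # in [0, 1]; same crop/flip as img_t\n"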
  },
  {
    "path": "data/image_folder.py",
    "content": "\"\"\"A modified image folder class\n\nWe modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)\nso that this class can load images from both current directory and its subdirectories.\n\"\"\"\n\nimport torch.utils.data as data\n\nfrom PIL import Image\nimport os\nimport os.path\n\nIMG_EXTENSIONS = [\n    '.jpg', '.JPG', '.jpeg', '.JPEG',\n    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',\n]\n\n\ndef is_image_file(filename):\n    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)\n\n\ndef make_dataset(dir, max_dataset_size=float(\"inf\")):\n    images = []\n    assert os.path.isdir(dir), '%s is not a valid directory' % dir\n\n    for root, _, fnames in sorted(os.walk(dir)):\n        for fname in fnames:\n            if is_image_file(fname):\n                path = os.path.join(root, fname)\n                images.append(path)\n    return images[:min(max_dataset_size, len(images))]\n\n\ndef default_loader(path):\n    return Image.open(path).convert('RGB')\n\n\nclass ImageFolder(data.Dataset):\n\n    def __init__(self, root, transform=None, return_paths=False,\n                 loader=default_loader):\n        imgs = make_dataset(root)\n        if len(imgs) == 0:\n            raise(RuntimeError(\"Found 0 images in: \" + root + \"\\n\"\n                               \"Supported image extensions are: \" +\n                               \",\".join(IMG_EXTENSIONS)))\n\n        self.root = root\n        self.imgs = imgs\n        self.transform = transform\n        self.return_paths = return_paths\n        self.loader = loader\n\n    def __getitem__(self, index):\n        path = self.imgs[index]\n        img = self.loader(path)\n        if self.transform is not None:\n            img = self.transform(img)\n        if self.return_paths:\n            return img, path\n        else:\n            return img\n\n    def __len__(self):\n        return len(self.imgs)\n"
  },
  {
    "path": "data/single_dataset.py",
    "content": "from data.base_dataset import BaseDataset, get_transform, get_params, get_transform_mask\nfrom data.image_folder import make_dataset\nfrom PIL import Image\nimport torch\nimport os\n\n\nclass SingleDataset(BaseDataset):\n    \"\"\"This dataset class can load a set of images specified by the path --dataroot /path/to/data.\n\n    It can be used for generating CycleGAN results only for one side with the model option '-model test'.\n    \"\"\"\n\n    def __init__(self, opt):\n        \"\"\"Initialize this dataset class.\n\n        Parameters:\n            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions\n        \"\"\"\n        BaseDataset.__init__(self, opt)\n        if os.path.exists(opt.dataroot):\n            self.A_paths = sorted(make_dataset(opt.dataroot, opt.max_dataset_size))\n        else:\n            imglistA = 'datasets/list/%s/%s.txt' % (opt.phase+'A', opt.dataroot)\n            self.A_paths = sorted(open(imglistA, 'r').read().splitlines())\n        self.input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc\n\n    def __getitem__(self, index):\n        \"\"\"Return a data point and its metadata information.\n\n        Parameters:\n            index - - a random integer for data indexing\n\n        Returns a dictionary that contains A and A_paths\n            A(tensor) - - an image in one domain\n            A_paths(str) - - the path of the image\n        \"\"\"\n        A_path = self.A_paths[index]\n        A_img = Image.open(A_path).convert('RGB')\n        self.opt.W, self.opt.H = A_img.size\n        transform_params_A = get_params(self.opt, A_img.size)\n        A = get_transform(self.opt, transform_params_A, grayscale=(self.input_nc == 1))(A_img)\n        item = {'A': A, 'A_paths': A_path}\n\n        if self.opt.style_control:\n            if self.opt.sinput == 'sind':\n                B_style = torch.Tensor([0.,0.,0.])\n                B_style[self.opt.sind] = 1.\n            elif self.opt.sinput == 'svec':\n                ss = self.opt.svec.split(',')\n                B_style = torch.Tensor([float(ss[0]),float(ss[1]),float(ss[2])])\n            elif self.opt.sinput == 'simg':\n                self.featureloc = os.path.join('style_features/styles2/', self.opt.sfeature_mode)\n                B_style = np.load(self.featureloc, self.opt.simg[:-4]+'.npy')\n            B_style = B_style.view(3, 1, 1)\n            B_style = B_style.repeat(1, 128, 128)\n            item['B_style'] = B_style\n\n        return item\n\n    def __len__(self):\n        \"\"\"Return the total number of images in the dataset.\"\"\"\n        return len(self.A_paths)\n"
  },
  {
    "path": "data/unaligned_mask_stylecls_dataset.py",
    "content": "import os.path\nfrom data.base_dataset import BaseDataset, get_params, get_transform, get_transform_mask\nfrom data.image_folder import make_dataset\nfrom PIL import Image\nimport random\nimport torch\nimport torchvision.transforms as transforms\nimport numpy as np\nimport pdb\n\n\nclass UnalignedMaskStyleClsDataset(BaseDataset):\n    def __init__(self, opt):\n        BaseDataset.__init__(self, opt)\n\n        self.dir_A = os.path.join(opt.dataroot, opt.phase + '/A')  # create a path '/path/to/data/trainA'\n        self.dir_B = os.path.join(opt.dataroot, opt.phase + '/B')  # create a path '/path/to/data/trainB'\n\n        self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size))   # load images from '/path/to/data/trainA'\n        self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size))    # load images from '/path/to/data/trainB'\n\n        self.A_size = len(self.A_paths)  # get the size of dataset A\n        self.B_size = len(self.B_paths)  # get the size of dataset B\n        print(\"A size:\", self.A_size)\n        print(\"B size:\", self.B_size)\n        btoA = self.opt.direction == 'BtoA'\n        self.input_nc = self.opt.output_nc if btoA else self.opt.input_nc       # get the number of channels of input image\n        self.output_nc = self.opt.input_nc if btoA else self.opt.output_nc      # get the number of channels of output image\n\n        self.auxdir_A = os.path.join(opt.dataroot, \"%s/A\" % opt.phase)\n        self.auxdir_B = os.path.join(opt.dataroot, \"%s/B\" % opt.phase)\n\n\n    def __getitem__(self, index):\n        A_path = self.A_paths[index % self.A_size]  # make sure index is within then range\n        if self.opt.serial_batches:   # make sure index is within then range\n            index_B = index % self.B_size\n        else:   # randomize the index for domain B to avoid fixed pairs.\n            index_B = random.randint(0, self.B_size - 1)\n        B_path = self.B_paths[index_B]\n        A_img = Image.open(A_path).convert('RGB')\n        B_img = Image.open(B_path).convert('RGB')\n\n        basenA = os.path.basename(A_path)\n        A_mask_img = Image.open(os.path.join(self.auxdir_A+'_nose',basenA))\n        basenB = os.path.basename(B_path)\n        B_mask_img = Image.open(os.path.join(self.auxdir_B+'_nose',basenB))\n        if self.opt.use_eye_mask:\n            A_maske_img = Image.open(os.path.join(self.auxdir_A+'_eyes',basenA))\n            B_maske_img = Image.open(os.path.join(self.auxdir_B+'_eyes',basenB))\n        if self.opt.use_lip_mask:\n            A_maskl_img = Image.open(os.path.join(self.auxdir_A+'_lips',basenA))\n            B_maskl_img = Image.open(os.path.join(self.auxdir_B+'_lips',basenB))\n\n        # apply image transformation\n        transform_params_A = get_params(self.opt, A_img.size)\n        transform_params_B = get_params(self.opt, B_img.size)\n        A = get_transform(self.opt, transform_params_A, grayscale=(self.input_nc == 1))(A_img)\n        B = get_transform(self.opt, transform_params_B, grayscale=(self.output_nc == 1))(B_img)\n        A_mask = get_transform_mask(self.opt, transform_params_A, grayscale=1)(A_mask_img)\n        B_mask = get_transform_mask(self.opt, transform_params_B, grayscale=1)(B_mask_img)\n        if self.opt.use_eye_mask:\n            A_maske = get_transform_mask(self.opt, transform_params_A, grayscale=1)(A_maske_img)\n            B_maske = get_transform_mask(self.opt, transform_params_B, grayscale=1)(B_maske_img)\n        if self.opt.use_lip_mask:\n            A_maskl = 
get_transform_mask(self.opt, transform_params_A, grayscale=1)(A_maskl_img)\n            B_maskl = get_transform_mask(self.opt, transform_params_B, grayscale=1)(B_maskl_img)\n\n        item = {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path, 'A_mask': A_mask, 'B_mask': B_mask}\n        if self.opt.use_eye_mask:\n            item['A_maske'] = A_maske\n            item['B_maske'] = B_maske\n        if self.opt.use_lip_mask:\n            item['A_maskl'] = A_maskl\n            item['B_maskl'] = B_maskl\n\n        softmax = np.load(os.path.join(self.auxdir_B + '_feat', basenB[:-4] + '.npy'))\n        softmax = torch.Tensor(softmax)\n        maxv, index = torch.max(softmax, 0)\n        B_label = index\n        if self.opt.sfeature_mode.endswith('_softmax'):\n            if self.opt.one_hot:\n                B_style = torch.Tensor([0., 0., 0.])\n                B_style[index] = 1.\n            else:\n                B_style = softmax\n            B_style = B_style.view(3, 1, 1)\n            B_style = B_style.repeat(1, 128, 128)\n        elif self.opt.sfeature_mode == 'domain':\n            B_style = B_label\n        item['B_style'] = B_style\n        item['B_label'] = B_label\n        if self.opt.isTrain and self.opt.style_loss_with_weight:\n            item['B_style0'] = softmax\n\n        return item\n\n    def __len__(self):\n        return max(self.A_size, self.B_size)\n
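\n\n# B_style shapes by mode: '*_softmax' -> a 3x128x128 map (one-hot if --one_hot, else\n# the VGG softmax probabilities); 'domain' -> a scalar class index (same as B_label).\n"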
  },
  {
    "path": "models/__init__.py",
    "content": "\"\"\"This package contains modules related to objective functions, optimizations, and network architectures.\n\nTo add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.\nYou need to implement the following five functions:\n    -- <__init__>:                      initialize the class; first call BaseModel.__init__(self, opt).\n    -- <set_input>:                     unpack data from dataset and apply preprocessing.\n    -- <forward>:                       produce intermediate results.\n    -- <optimize_parameters>:           calculate loss, gradients, and update network weights.\n    -- <modify_commandline_options>:    (optionally) add model-specific options and set default options.\n\nIn the function <__init__>, you need to define four lists:\n    -- self.loss_names (str list):          specify the training losses that you want to plot and save.\n    -- self.model_names (str list):         define networks used in our training.\n    -- self.visual_names (str list):        specify the images that you want to display and save.\n    -- self.optimizers (optimizer list):    define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage.\n\nNow you can use the model class by specifying flag '--model dummy'.\nSee our template model class 'template_model.py' for more details.\n\"\"\"\n\nimport importlib\nfrom models.base_model import BaseModel\n\n\ndef find_model_using_name(model_name):\n    \"\"\"Import the module \"models/[model_name]_model.py\".\n\n    In the file, the class called DatasetNameModel() will\n    be instantiated. It has to be a subclass of BaseModel,\n    and it is case-insensitive.\n    \"\"\"\n    model_filename = \"models.\" + model_name + \"_model\"\n    modellib = importlib.import_module(model_filename)\n    model = None\n    target_model_name = model_name.replace('_', '') + 'model'\n    for name, cls in modellib.__dict__.items():\n        if name.lower() == target_model_name.lower() \\\n           and issubclass(cls, BaseModel):\n            model = cls\n\n    if model is None:\n        print(\"In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase.\" % (model_filename, target_model_name))\n        exit(0)\n\n    return model\n\n\ndef get_option_setter(model_name):\n    \"\"\"Return the static method <modify_commandline_options> of the model class.\"\"\"\n    model_class = find_model_using_name(model_name)\n    return model_class.modify_commandline_options\n\n\ndef create_model(opt):\n    \"\"\"Create a model given the option.\n\n    This function warps the class CustomDatasetDataLoader.\n    This is the main interface between this package and 'train.py'/'test.py'\n\n    Example:\n        >>> from models import create_model\n        >>> model = create_model(opt)\n    \"\"\"\n    model = find_model_using_name(opt.model)\n    instance = model(opt)\n    print(\"model [%s] was created\" % type(instance).__name__)\n    return instance\n"
  },
  {
    "path": "models/asymmetric_cycle_gan_cls_model.py",
    "content": "import torch\nimport itertools\nfrom util.image_pool import ImagePool\nfrom .base_model import BaseModel\nfrom . import networks\nimport models.dist_model as dm # numpy==1.14.3\nimport torchvision.transforms as transforms\nimport os\n\ndef truncate(fake_B,a=127.5):#[-1,1]\n    return ((fake_B+1)*a).int().float()/a-1\n\nclass AsymmetricCycleGANClsModel(BaseModel):\n    @staticmethod\n    def modify_commandline_options(parser, is_train=True):\n        parser.set_defaults(no_dropout=True)  # default CycleGAN did not use dropout\n        parser.set_defaults(dataset_mode='unaligned_mask_stylecls')\n        parser.add_argument('--netda', type=str, default='basic_cls')\n        parser.add_argument('--netga', type=str, default='resnet_style2_9blocks', help='net arch for netG_A')\n        parser.add_argument('--model0_res', type=int, default=0, help='number of resblocks in model0 (before insert style)')\n        parser.add_argument('--model1_res', type=int, default=0, help='number of resblocks in model1 (after insert style, before 2 column merge)')\n        if is_train:\n            parser.add_argument('--lambda_A', type=float, default=5.0, help='weight for cycle loss (A -> B -> A)')\n            parser.add_argument('--lambda_B', type=float, default=5.0, help='weight for cycle loss (B -> A -> B)')\n            parser.add_argument('--lambda_identity', type=float, default=0, help='use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1')\n            parser.add_argument('--ntrunc_trunc', type=int, default=1, help='whether use both non-trunc version and trunc version')\n            parser.add_argument('--trunc_a', type=float, default=31.875, help='multiply which value to round when trunc')\n            parser.add_argument('--lambda_A_trunc', type=float, default=5.0, help='weight for cycle loss for trunc')\n            parser.add_argument('--hed_pretrained_mode', type=str, default='./checkpoints/network-bsds500.pytorch', help='path to the pretrained hed model')\n            parser.add_argument('--lambda_G_A_l', type=float, default=0.5, help='weight for local GAN loss in G')\n            parser.add_argument('--style_loss_with_weight', type=int, default=1, help='whether multiply prob in style loss')\n        # for masks\n        parser.add_argument('--use_mask', type=int, default=1, help='whether use mask for special face region')\n        parser.add_argument('--use_eye_mask', type=int, default=1, help='whether use mask for special face region')\n        parser.add_argument('--use_lip_mask', type=int, default=1, help='whether use mask for special face region')\n        parser.add_argument('--mask_type', type=int, default=3, help='use mask type, 0 outside black, 1 outside white')\n        # for style control\n        parser.add_argument('--style_control', type=int, default=1, help='use style_control')\n        parser.add_argument('--sfeature_mode', type=str, default='1vgg19_softmax', help='vgg19 softmax as feature')\n        parser.add_argument('--one_hot', type=int, default=0, help='use one-hot for style code')\n\n        return parser\n\n    def __init__(self, opt):\n        BaseModel.__init__(self, opt)\n        # specify the training losses you want to print out. 
The training/test scripts will call <BaseModel.get_current_losses>\n        self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B']\n        # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>\n        visual_names_A = ['real_A', 'fake_B', 'rec_A']\n        visual_names_B = ['real_B', 'fake_A', 'rec_B']\n        if self.isTrain and self.opt.lambda_identity > 0.0:  # if identity loss is used, we also visualize idt_B=G_B(A) and idt_A=G_A(B)\n            visual_names_A.append('idt_B')\n            visual_names_B.append('idt_A')\n        if self.isTrain:\n            visual_names_A.append('real_A_hed')\n            visual_names_A.append('rec_A_hed')\n        if self.isTrain and self.opt.ntrunc_trunc:\n            visual_names_A.append('rec_At')\n            visual_names_A.append('rec_At_hed')\n            self.loss_names = ['D_A', 'G_A', 'cycle_A', 'cycle_A2', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B', 'G']\n        if self.isTrain and self.opt.use_mask:\n            visual_names_A.append('fake_B_l')\n            visual_names_A.append('real_B_l')\n            self.loss_names += ['D_A_l', 'G_A_l']\n        if self.isTrain and self.opt.use_eye_mask:\n            visual_names_A.append('fake_B_le')\n            visual_names_A.append('real_B_le')\n            self.loss_names += ['D_A_le', 'G_A_le']\n        if self.isTrain and self.opt.use_lip_mask:\n            visual_names_A.append('fake_B_ll')\n            visual_names_A.append('real_B_ll')\n            self.loss_names += ['D_A_ll', 'G_A_ll']\n        if not self.isTrain and self.opt.use_mask:\n            visual_names_A.append('fake_B_l')\n            visual_names_A.append('real_B_l')\n        if not self.isTrain and self.opt.use_eye_mask:\n            visual_names_A.append('fake_B_le')\n            visual_names_A.append('real_B_le')\n        if not self.isTrain and self.opt.use_lip_mask:\n            visual_names_A.append('fake_B_ll')\n            visual_names_A.append('real_B_ll')\n        self.loss_names += ['D_A_cls', 'G_A_cls']\n\n        self.visual_names = visual_names_A + visual_names_B  # combine visualizations for A and B\n        print(self.visual_names)\n        # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>.\n        if self.isTrain:\n            self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']\n            if self.opt.use_mask:\n                self.model_names += ['D_A_l']\n            if self.opt.use_eye_mask:\n                self.model_names += ['D_A_le']\n            if self.opt.use_lip_mask:\n                self.model_names += ['D_A_ll']\n        else:  # during test time, only load Gs\n            self.model_names = ['G_A', 'G_B']\n\n        # define networks (both Generators and discriminators)\n        # The naming is different from those used in the paper.\n        # Code (vs. 
paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)\n        if not self.opt.style_control:\n            self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,\n                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)\n        else:\n            print(opt.netga)\n            print('model0_res', opt.model0_res)\n            print('model1_res', opt.model1_res)\n            self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netga, opt.norm,\n                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.model0_res, opt.model1_res)\n        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.norm,\n                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)\n\n        if self.isTrain:  # define discriminators\n            self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netda,\n                                            opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids, n_class=3)\n            self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD,\n                                            opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)\n            if self.opt.use_mask:\n                if self.opt.mask_type in [2, 3]:\n                    output_nc = opt.output_nc + 1\n                else:\n                    output_nc = opt.output_nc\n                self.netD_A_l = networks.define_D(output_nc, opt.ndf, opt.netD,\n                                            opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)\n            if self.opt.use_eye_mask:\n                if self.opt.mask_type in [2, 3]:\n                    output_nc = opt.output_nc + 1\n                else:\n                    output_nc = opt.output_nc\n                self.netD_A_le = networks.define_D(output_nc, opt.ndf, opt.netD,\n                                            opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)\n            if self.opt.use_lip_mask:\n                if self.opt.mask_type in [2, 3]:\n                    output_nc = opt.output_nc + 1\n                else:\n                    output_nc = opt.output_nc\n                self.netD_A_ll = networks.define_D(output_nc, opt.ndf, opt.netD,\n                                            opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)\n        \n        if not self.isTrain:\n            self.criterionGAN = networks.GANLoss('lsgan').to(self.device)\n\n        if self.isTrain:\n            if opt.lambda_identity > 0.0:  # only works when input and output images have the same number of channels\n                assert(opt.input_nc == opt.output_nc)\n            self.fake_A_pool = ImagePool(opt.pool_size)  # create image buffer to store previously generated images\n            self.fake_B_pool = ImagePool(opt.pool_size)  # create image buffer to store previously generated images\n            # define loss functions\n            self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)  # define GAN loss.\n            self.criterionCycle = torch.nn.L1Loss()\n            self.criterionIdt = torch.nn.L1Loss()\n            self.criterionCls = torch.nn.CrossEntropyLoss()\n            self.criterionCls2 = torch.nn.CrossEntropyLoss(reduction='none')\n            # initialize optimizers; schedulers will be automatically 
created by function <BaseModel.setup>.\n            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))\n            if not self.opt.use_mask:\n                self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))\n            elif not self.opt.use_eye_mask:\n                D_params = list(self.netD_A.parameters()) + list(self.netD_B.parameters()) + list(self.netD_A_l.parameters())\n                self.optimizer_D = torch.optim.Adam(D_params, lr=opt.lr, betas=(opt.beta1, 0.999))\n            elif not self.opt.use_lip_mask:\n                D_params = list(self.netD_A.parameters()) + list(self.netD_B.parameters()) + list(self.netD_A_l.parameters()) + list(self.netD_A_le.parameters())\n                self.optimizer_D = torch.optim.Adam(D_params, lr=opt.lr, betas=(opt.beta1, 0.999))\n            else:\n                D_params = list(self.netD_A.parameters()) + list(self.netD_B.parameters()) + list(self.netD_A_l.parameters()) + list(self.netD_A_le.parameters()) + list(self.netD_A_ll.parameters())\n                self.optimizer_D = torch.optim.Adam(D_params, lr=opt.lr, betas=(opt.beta1, 0.999))\n            self.optimizers.append(self.optimizer_G)\n            self.optimizers.append(self.optimizer_D)\n\n            self.lpips = dm.DistModel(opt,model='net-lin',net='alex',use_gpu=True)\n\n            self.hed = networks.define_HED(init_weights_=opt.hed_pretrained_mode, gpu_ids_=self.opt.gpu_ids_p)\n            self.set_requires_grad(self.hed, False)\n\n\n    def set_input(self, input):\n        \"\"\"Unpack input data from the dataloader and perform necessary pre-processing steps.\n\n        Parameters:\n            input (dict): include the data itself and its metadata information.\n\n        The option 'direction' can be used to swap domain A and domain B.\n        \"\"\"\n        AtoB = self.opt.direction == 'AtoB'\n        self.real_A = input['A' if AtoB else 'B'].to(self.device)\n        self.real_B = input['B' if AtoB else 'A'].to(self.device)\n        self.image_paths = input['A_paths' if AtoB else 'B_paths']\n        if self.opt.use_mask:\n            self.A_mask = input['A_mask'].to(self.device)\n            self.B_mask = input['B_mask'].to(self.device)\n        if self.opt.use_eye_mask:\n            self.A_maske = input['A_maske'].to(self.device)\n            self.B_maske = input['B_maske'].to(self.device)\n        if self.opt.use_lip_mask:\n            self.A_maskl = input['A_maskl'].to(self.device)\n            self.B_maskl = input['B_maskl'].to(self.device)\n        if self.opt.style_control:\n            self.real_B_style = input['B_style'].to(self.device)\n            self.real_B_label = input['B_label'].to(self.device)\n        if self.opt.isTrain and self.opt.style_loss_with_weight:\n            self.real_B_style0 = input['B_style0'].to(self.device)\n            self.zero = torch.zeros(self.real_B_label.size(),dtype=torch.int64).to(self.device)\n            self.one = torch.ones(self.real_B_label.size(),dtype=torch.int64).to(self.device)\n            self.two = 2*torch.ones(self.real_B_label.size(),dtype=torch.int64).to(self.device)\n\n    def forward(self):\n        \"\"\"Run forward pass; called by both functions <optimize_parameters> and <test>.\"\"\"\n        if not self.opt.style_control:\n            self.fake_B = self.netG_A(self.real_A)  # G_A(A)\n        else:\n            
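# style_control: G_A also takes the 3x128x128 style code prepared by the dataset.\n            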
#print(torch.mean(self.real_B_style,(2,3)),'style_control')\n            self.fake_B = self.netG_A(self.real_A, self.real_B_style)\n        self.rec_A = self.netG_B(self.fake_B)   # G_B(G_A(A))\n        self.fake_A = self.netG_B(self.real_B)  # G_B(B)\n        if not self.opt.style_control:\n            self.rec_B = self.netG_A(self.fake_A)   # G_A(G_B(B))\n        else:\n            #print(torch.mean(self.real_B_style,(2,3)),'style_control')\n            self.rec_B = self.netG_A(self.fake_A, self.real_B_style) # -- cycle_B loss\n\n        if self.opt.use_mask:\n            self.fake_B_l = self.masked(self.fake_B,self.A_mask)\n            self.real_B_l = self.masked(self.real_B,self.B_mask)\n        if self.opt.use_eye_mask:\n            self.fake_B_le = self.masked(self.fake_B,self.A_maske)\n            self.real_B_le = self.masked(self.real_B,self.B_maske)\n        if self.opt.use_lip_mask:\n            self.fake_B_ll = self.masked(self.fake_B,self.A_maskl)\n            self.real_B_ll = self.masked(self.real_B,self.B_maskl)\n\n    def backward_D_basic(self, netD, real, fake):\n        \"\"\"Calculate GAN loss for the discriminator\n\n        Parameters:\n            netD (network)      -- the discriminator D\n            real (tensor array) -- real images\n            fake (tensor array) -- images generated by a generator\n\n        Return the discriminator loss.\n        We also call loss_D.backward() to calculate the gradients.\n        \"\"\"\n        # Real\n        pred_real = netD(real)\n        loss_D_real = self.criterionGAN(pred_real, True)\n        # Fake\n        pred_fake = netD(fake.detach())\n        loss_D_fake = self.criterionGAN(pred_fake, False)\n        # Combined loss and calculate gradients\n        loss_D = (loss_D_real + loss_D_fake) * 0.5\n        loss_D.backward()\n        return loss_D\n    \n    def backward_D_basic_cls(self, netD, real, fake):\n        # Real\n        pred_real, pred_real_cls = netD(real)\n        loss_D_real = self.criterionGAN(pred_real, True)\n        if not self.opt.style_loss_with_weight:\n            loss_D_real_cls = self.criterionCls(pred_real_cls, self.real_B_label)\n        else:\n            loss_D_real_cls = torch.mean(self.real_B_style0[:,0] * self.criterionCls2(pred_real_cls, self.zero) + self.real_B_style0[:,1] * self.criterionCls2(pred_real_cls, self.one) + self.real_B_style0[:,2] * self.criterionCls2(pred_real_cls, self.two))\n        # Fake\n        pred_fake, pred_fake_cls = netD(fake.detach())\n        loss_D_fake = self.criterionGAN(pred_fake, False)\n        if not self.opt.style_loss_with_weight:\n            loss_D_fake_cls = self.criterionCls(pred_fake_cls, self.real_B_label)\n        else:\n            loss_D_fake_cls = torch.mean(self.real_B_style0[:,0] * self.criterionCls2(pred_fake_cls, self.zero) + self.real_B_style0[:,1] * self.criterionCls2(pred_fake_cls, self.one) + self.real_B_style0[:,2] * self.criterionCls2(pred_fake_cls, self.two))\n        # Combined loss and calculate gradients\n        loss_D = (loss_D_real + loss_D_fake) * 0.5\n        loss_D_cls = (loss_D_real_cls + loss_D_fake_cls) * 0.5\n        loss_D_total = loss_D + loss_D_cls\n        loss_D_total.backward()\n        return loss_D, loss_D_cls\n\n    def backward_D_A(self):\n        \"\"\"Calculate GAN loss for discriminator D_A\"\"\"\n        fake_B = self.fake_B_pool.query(self.fake_B)\n        self.loss_D_A, self.loss_D_A_cls = self.backward_D_basic_cls(self.netD_A, self.real_B, fake_B)\n    \n    def backward_D_A_l(self):\n        
\"\"\"Calculate GAN loss for discriminator D_A_l\"\"\"\n        fake_B = self.fake_B_pool.query(self.fake_B)\n        self.loss_D_A_l = self.backward_D_basic(self.netD_A_l, self.masked(self.real_B,self.B_mask), self.masked(fake_B,self.A_mask))\n\n    def backward_D_A_le(self):\n        \"\"\"Calculate GAN loss for discriminator D_A_le\"\"\"\n        fake_B = self.fake_B_pool.query(self.fake_B)\n        self.loss_D_A_le = self.backward_D_basic(self.netD_A_le, self.masked(self.real_B,self.B_maske), self.masked(fake_B,self.A_maske))\n    \n    def backward_D_A_ll(self):\n        \"\"\"Calculate GAN loss for discriminator D_A_ll\"\"\"\n        fake_B = self.fake_B_pool.query(self.fake_B)\n        self.loss_D_A_ll = self.backward_D_basic(self.netD_A_ll, self.masked(self.real_B,self.B_maskl), self.masked(fake_B,self.A_maskl))\n\n    def backward_D_B(self):\n        \"\"\"Calculate GAN loss for discriminator D_B\"\"\"\n        fake_A = self.fake_A_pool.query(self.fake_A)\n        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)\n    \n    def update_process(self, epoch):\n        self.process = (epoch - 1) / float(self.opt.niter_decay + self.opt.niter)\n\n    def backward_G(self):\n        \"\"\"Calculate the loss for generators G_A and G_B\"\"\"\n        lambda_idt = self.opt.lambda_identity\n        lambda_G_A_l = self.opt.lambda_G_A_l\n        lambda_A = self.opt.lambda_A\n        lambda_B = self.opt.lambda_B\n        lambda_A_trunc = self.opt.lambda_A_trunc\n        if self.opt.ntrunc_trunc:\n            lambda_A = lambda_A * (1 - self.process * 0.9)\n            lambda_A_trunc = lambda_A_trunc * self.process * 0.9\n        self.lambda_As = [lambda_A, lambda_A_trunc]\n        # Identity loss\n        if lambda_idt > 0:\n            # G_A should be identity if real_B is fed: ||G_A(B) - B||\n            self.idt_A = self.netG_A(self.real_B)\n            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt\n            # G_B should be identity if real_A is fed: ||G_B(A) - A||\n            self.idt_B = self.netG_B(self.real_A)\n            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt\n        else:\n            self.loss_idt_A = 0\n            self.loss_idt_B = 0\n\n        # GAN loss D_A(G_A(A))\n        pred_fake, pred_fake_cls = self.netD_A(self.fake_B)\n        self.loss_G_A = self.criterionGAN(pred_fake, True)\n        if not self.opt.style_loss_with_weight:\n            self.loss_G_A_cls = self.criterionCls(pred_fake_cls, self.real_B_label)\n        else:\n            self.loss_G_A_cls = torch.mean(self.real_B_style0[:,0] * self.criterionCls2(pred_fake_cls, self.zero) + self.real_B_style0[:,1] * self.criterionCls2(pred_fake_cls, self.one) + self.real_B_style0[:,2] * self.criterionCls2(pred_fake_cls, self.two))\n        if self.opt.use_mask:\n            self.loss_G_A_l = self.criterionGAN(self.netD_A_l(self.fake_B_l), True) * lambda_G_A_l\n        if self.opt.use_eye_mask:\n            self.loss_G_A_le = self.criterionGAN(self.netD_A_le(self.fake_B_le), True) * lambda_G_A_l\n        if self.opt.use_lip_mask:\n            self.loss_G_A_ll = self.criterionGAN(self.netD_A_ll(self.fake_B_ll), True) * lambda_G_A_l\n        # GAN loss D_B(G_B(B))\n        self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)\n\n        # Forward cycle loss  LPIPS( HED(G_B(G_A(A))), HED(A))\n        ts = self.real_A.shape\n        gpu_p = self.opt.gpu_ids_p[0]\n        gpu = self.opt.gpu_ids[0]\n        
rec_A_hed = (self.hed(self.rec_A.cuda(gpu_p)/2+0.5)-0.5)*2\n        real_A_hed = (self.hed(self.real_A.cuda(gpu_p)/2+0.5)-0.5)*2\n        self.loss_cycle_A = (self.lpips.forward_pair(rec_A_hed.expand(ts), real_A_hed.expand(ts)).mean()).cuda(gpu) * lambda_A\n        self.rec_A_hed = rec_A_hed\n        self.real_A_hed = real_A_hed\n        if self.opt.ntrunc_trunc:\n            self.rec_At = self.netG_B(truncate(self.fake_B, self.opt.trunc_a))\n            rec_At_hed = (self.hed(self.rec_At.cuda(gpu_p)/2+0.5)-0.5)*2\n            self.loss_cycle_A2 = (self.lpips.forward_pair(rec_At_hed.expand(ts), real_A_hed.expand(ts)).mean()).cuda(gpu) * lambda_A_trunc\n            self.rec_At_hed = rec_At_hed\n\n        # Backward cycle loss || G_A(G_B(B)) - B||\n        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B\n\n        # combined loss and calculate gradients\n        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B\n        if getattr(self, 'loss_cycle_A2', -1) != -1:\n            self.loss_G = self.loss_G + self.loss_cycle_A2\n        if getattr(self, 'loss_G_A_l', -1) != -1:\n            self.loss_G = self.loss_G + self.loss_G_A_l\n        if getattr(self, 'loss_G_A_le', -1) != -1:\n            self.loss_G = self.loss_G + self.loss_G_A_le\n        if getattr(self, 'loss_G_A_ll', -1) != -1:\n            self.loss_G = self.loss_G + self.loss_G_A_ll\n        if getattr(self, 'loss_G_A_cls', -1) != -1:\n            self.loss_G = self.loss_G + self.loss_G_A_cls\n        self.loss_G.backward()\n\n    def optimize_parameters(self):\n        \"\"\"Calculate losses, gradients, and update network weights; called in every training iteration\"\"\"\n        # forward\n        self.forward()      # compute fake images and reconstruction images.\n        # G_A and G_B\n        self.set_requires_grad([self.netD_A, self.netD_B], False)  # Ds require no gradients when optimizing Gs\n        if self.opt.use_mask:\n            self.set_requires_grad([self.netD_A_l], False)\n        if self.opt.use_eye_mask:\n            self.set_requires_grad([self.netD_A_le], False)\n        if self.opt.use_lip_mask:\n            self.set_requires_grad([self.netD_A_ll], False)\n        self.optimizer_G.zero_grad()  # set G_A and G_B's gradients to zero\n        self.backward_G()             # calculate gradients for G_A and G_B\n        self.optimizer_G.step()       # update G_A and G_B's weights\n        # D_A and D_B\n        self.set_requires_grad([self.netD_A, self.netD_B], True)\n        if self.opt.use_mask:\n            self.set_requires_grad([self.netD_A_l], True)\n        if self.opt.use_eye_mask:\n            self.set_requires_grad([self.netD_A_le], True)\n        if self.opt.use_lip_mask:\n            self.set_requires_grad([self.netD_A_ll], True)\n        self.optimizer_D.zero_grad()   # set D_A and D_B's gradients to zero\n        self.backward_D_A()      # calculate gradients for D_A\n        if self.opt.use_mask:\n            self.backward_D_A_l()    # calculate gradients for D_A_l\n        if self.opt.use_eye_mask:\n            self.backward_D_A_le()   # calculate gradients for D_A_le\n        if self.opt.use_lip_mask:\n            self.backward_D_A_ll()   # calculate gradients for D_A_ll\n        self.backward_D_B()      # calculate gradients for D_B\n        self.optimizer_D.step()  # update D_A and D_B's weights\n
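\n\n# Note: truncate() above quantizes fake_B to steps of 1/trunc_a in [-1, 1]; the extra\n# 'trunc' cycle loss (loss_cycle_A2) is computed on rec_At = G_B(truncate(fake_B)), and\n# its weight ramps up with training progress while the plain cycle weight ramps down.\n"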
  },
  {
    "path": "models/base_model.py",
    "content": "import os\nimport torch\nfrom collections import OrderedDict\nfrom abc import ABCMeta, abstractmethod\nfrom . import networks\n\n\nclass BaseModel():\n    __metaclass__ = ABCMeta\n    \"\"\"This class is an abstract base class (ABC) for models.\n    To create a subclass, you need to implement the following five functions:\n        -- <__init__>:                      initialize the class; first call BaseModel.__init__(self, opt).\n        -- <set_input>:                     unpack data from dataset and apply preprocessing.\n        -- <forward>:                       produce intermediate results.\n        -- <optimize_parameters>:           calculate losses, gradients, and update network weights.\n        -- <modify_commandline_options>:    (optionally) add model-specific options and set default options.\n    \"\"\"\n\n    def __init__(self, opt):\n        \"\"\"Initialize the BaseModel class.\n\n        Parameters:\n            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions\n\n        When creating your custom class, you need to implement your own initialization.\n        In this fucntion, you should first call <BaseModel.__init__(self, opt)>\n        Then, you need to define four lists:\n            -- self.loss_names (str list):          specify the training losses that you want to plot and save.\n            -- self.model_names (str list):         specify the images that you want to display and save.\n            -- self.visual_names (str list):        define networks used in our training.\n            -- self.optimizers (optimizer list):    define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.\n        \"\"\"\n        self.opt = opt\n        self.gpu_ids = opt.gpu_ids\n        self.isTrain = opt.isTrain\n        self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')  # get device name: CPU or GPU\n        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)  # save all the checkpoints to save_dir\n        if opt.preprocess != 'scale_width':  # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.\n            torch.backends.cudnn.benchmark = True\n        self.loss_names = []\n        self.model_names = []\n        self.visual_names = []\n        self.optimizers = []\n        self.image_paths = []\n        self.metric = 0  # used for learning rate policy 'plateau'\n\n    @staticmethod\n    def modify_commandline_options(parser, is_train):\n        \"\"\"Add new model-specific options, and rewrite default values for existing options.\n\n        Parameters:\n            parser          -- original option parser\n            is_train (bool) -- whether training phase or test phase. 
You can use this flag to add training-specific or test-specific options.\n\n        Returns:\n            the modified parser.\n        \"\"\"\n        return parser\n\n    @abstractmethod\n    def set_input(self, input):\n        \"\"\"Unpack input data from the dataloader and perform necessary pre-processing steps.\n\n        Parameters:\n            input (dict): includes the data itself and its metadata information.\n        \"\"\"\n        pass\n\n    @abstractmethod\n    def forward(self):\n        \"\"\"Run forward pass; called by both functions <optimize_parameters> and <test>.\"\"\"\n        pass\n\n    @abstractmethod\n    def optimize_parameters(self):\n        \"\"\"Calculate losses, gradients, and update network weights; called in every training iteration\"\"\"\n        pass\n\n    def setup(self, opt):\n        \"\"\"Load and print networks; create schedulers\n\n        Parameters:\n            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions\n        \"\"\"\n        if self.isTrain:\n            self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]\n        if not self.isTrain or opt.continue_train:\n            load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch\n            self.load_networks(load_suffix)\n        self.print_networks(opt.verbose)\n\n    def eval(self):\n        \"\"\"Make models eval mode during test time\"\"\"\n        for name in self.model_names:\n            if isinstance(name, str):\n                net = getattr(self, 'net' + name)\n                net.eval()\n\n    def test(self):\n        \"\"\"Forward function used in test time.\n\n        This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop\n        It also calls <compute_visuals> to produce additional visualization results\n        \"\"\"\n        with torch.no_grad():\n            self.forward()\n            self.compute_visuals()\n\n    def compute_visuals(self):\n        \"\"\"Calculate additional output images for visdom and HTML visualization\"\"\"\n        pass\n\n    def get_image_paths(self):\n        \"\"\" Return image paths that are used to load current data\"\"\"\n        return self.image_paths\n\n    def update_learning_rate(self):\n        \"\"\"Update learning rates for all the networks; called at the end of every epoch\"\"\"\n        for scheduler in self.schedulers:\n            if self.opt.lr_policy == 'plateau':\n                scheduler.step(self.metric)\n            else:\n                scheduler.step()\n\n        lr = self.optimizers[0].param_groups[0]['lr']\n        print('learning rate = %.7f' % lr)\n\n    def get_current_visuals(self):\n        \"\"\"Return visualization images. train.py will display these images with visdom, and save the images to a HTML\"\"\"\n        visual_ret = OrderedDict()\n        for name in self.visual_names:\n            if isinstance(name, str):\n                visual_ret[name] = getattr(self, name)\n        return visual_ret\n\n    def get_current_losses(self):\n        \"\"\"Return traning losses / errors. train.py will print out these errors on console, and save them to a file\"\"\"\n        errors_ret = OrderedDict()\n        for name in self.loss_names:\n            if isinstance(name, str):\n                errors_ret[name] = float(getattr(self, 'loss_' + name))  # float(...) 
works for both scalar tensor and float number\n        return errors_ret\n\n    def save_networks(self, epoch):\n        \"\"\"Save all the networks to the disk.\n\n        Parameters:\n            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)\n        \"\"\"\n        for name in self.model_names:\n            if isinstance(name, str):\n                save_filename = '%s_net_%s.pth' % (epoch, name)\n                save_path = os.path.join(self.save_dir, save_filename)\n                net = getattr(self, 'net' + name)\n\n                if len(self.gpu_ids) > 0 and torch.cuda.is_available():\n                    torch.save(net.module.cpu().state_dict(), save_path)\n                    net.cuda(self.gpu_ids[0])\n                else:\n                    torch.save(net.cpu().state_dict(), save_path)\n\n    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):\n        \"\"\"Fix InstanceNorm checkpoints incompatibility (prior to 0.4)\"\"\"\n        key = keys[i]\n        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer\n            if module.__class__.__name__.startswith('InstanceNorm') and \\\n                    (key == 'running_mean' or key == 'running_var'):\n                if getattr(module, key) is None:\n                    state_dict.pop('.'.join(keys))\n            if module.__class__.__name__.startswith('InstanceNorm') and \\\n               (key == 'num_batches_tracked'):\n                state_dict.pop('.'.join(keys))\n        else:\n            self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)\n\n    def load_networks(self, epoch):\n        \"\"\"Load all the networks from the disk.\n\n        Parameters:\n            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)\n        \"\"\"\n        for name in self.model_names:\n            if isinstance(name, str):\n                load_filename = '%s_net_%s.pth' % (epoch, name)\n                load_path = os.path.join(self.save_dir, load_filename)\n                net = getattr(self, 'net' + name)\n                if isinstance(net, torch.nn.DataParallel):\n                    net = net.module\n                print('loading the model from %s' % load_path)\n                # if you are using PyTorch newer than 0.4 (e.g., built from\n                # GitHub source), you can remove str() on self.device\n                state_dict = torch.load(load_path, map_location=str(self.device))\n                if hasattr(state_dict, '_metadata'):\n                    del state_dict._metadata\n\n                # patch InstanceNorm checkpoints prior to 0.4\n                for key in list(state_dict.keys()):  # need to copy keys here because we mutate in loop\n                    self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))\n                net.load_state_dict(state_dict)\n\n    def print_networks(self, verbose):\n        \"\"\"Print the total number of parameters in the network and (if verbose) network architecture\n\n        Parameters:\n            verbose (bool) -- if verbose: print the network architecture\n        \"\"\"\n        print('---------- Networks initialized -------------')\n        for name in self.model_names:\n            if isinstance(name, str):\n                net = getattr(self, 'net' + name)\n                num_params = 0\n                for param in net.parameters():\n                    num_params += param.numel()\n                if 
verbose:\n                    print(net)\n                print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))\n        print('-----------------------------------------------')\n\n    def set_requires_grad(self, nets, requires_grad=False):
\n        \"\"\"Set requires_grad=False for all the networks to avoid unnecessary computations\n        Parameters:\n            nets (network list)   -- a list of networks\n            requires_grad (bool)  -- whether the networks require gradients or not\n        \"\"\"
\n        if not isinstance(nets, list):\n            nets = [nets]\n        for net in nets:\n            if net is not None:\n                for param in net.parameters():\n                    param.requires_grad = requires_grad\n    
\n    # ===========================================================================================================\n    def masked(self, A, mask):\n        \"\"\"Apply a binary mask to an image tensor A (scaled to [-1, 1]) according to opt.mask_type.\"\"\"
\n        if self.opt.mask_type == 0:  # keep A inside the mask; the outside becomes -1 (black)\n            return (A/2+0.5)*mask*2-1
\n        elif self.opt.mask_type == 1:  # keep A inside the mask; the outside becomes +1 (white)\n            return ((A/2+0.5)*mask+1-mask)*2-1
\n        elif self.opt.mask_type == 2:  # leave A unchanged and append the mask as extra channel(s)\n            return torch.cat((A, mask), 1)
\n        elif self.opt.mask_type == 3:  # white-fill as in type 1, then append the mask as extra channel(s)\n            masked = ((A/2+0.5)*mask+1-mask)*2-1\n            return torch.cat((masked, mask), 1)
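\n\n    # Usage sketch (editor's illustration; the shapes are assumptions, not enforced here):\n    #   A    : (N, C, H, W) image tensor scaled to [-1, 1]\n    #   mask : (N, 1, H, W) binary tensor, 1 inside the region of interest
\n    # With mask_type 0/1 the output keeps A inside the mask and fills the outside\n    # with black (-1) / white (+1) respectively; with mask_type 2/3 the mask is\n    # concatenated along the channel dimension, so downstream networks must expect\n    # the extra channel(s).\n"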
  },
  {
    "path": "models/dist_model.py",
    "content": "\nfrom __future__ import absolute_import\n\nimport sys\nsys.path.append('..')\nsys.path.append('.')\nimport numpy as np\nimport torch\nfrom torch import nn\nfrom collections import OrderedDict\nfrom torch.autograd import Variable\nfrom .base_model import BaseModel\nfrom scipy.ndimage import zoom\nimport skimage.transform\n\nfrom . import networks_basic as networks\n# from PerceptualSimilarity.util import util\nfrom util import util\n\nclass DistModel(BaseModel):\n    def name(self):\n        return self.model_name\n\n    def __init__(self, opt, model='net-lin', net='alex', pnet_rand=False, pnet_tune=False, model_path=None, colorspace='Lab', use_gpu=True, printNet=False, spatial=False, spatial_shape=None, spatial_order=1, spatial_factor=None, is_train=False, lr=.0001, beta1=0.5, version='0.1'):\n        '''\n        INPUTS\n            model - ['net-lin'] for linearly calibrated network\n                    ['net'] for off-the-shelf network\n                    ['L2'] for L2 distance in Lab colorspace\n                    ['SSIM'] for ssim in RGB colorspace\n            net - ['squeeze','alex','vgg']\n            model_path - if None, will look in weights/[NET_NAME].pth\n            colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM\n            use_gpu - bool - whether or not to use a GPU\n            printNet - bool - whether or not to print network architecture out\n            spatial - bool - whether to output an array containing varying distances across spatial dimensions\n            spatial_shape - if given, output spatial shape. if None then spatial shape is determined automatically via spatial_factor (see below).\n            spatial_factor - if given, specifies upsampling factor relative to the largest spatial extent of a convolutional layer. 
if None then resized to size of input images.\n            spatial_order - spline order of filter for upsampling in spatial mode, by default 1 (bilinear).\n            is_train - bool - [True] for training mode\n            lr - float - initial learning rate\n            beta1 - float - initial momentum term for adam\n            version - 0.1 for latest, 0.0 was original\n        '''\n        BaseModel.__init__(self, opt)\n\n        self.model = model\n        self.net = net\n        self.use_gpu = use_gpu\n        self.is_train = is_train\n        self.spatial = spatial\n        self.spatial_shape = spatial_shape\n        self.spatial_order = spatial_order\n        self.spatial_factor = spatial_factor\n\n        self.model_name = '%s [%s]'%(model,net)\n        if(self.model == 'net-lin'): # pretrained net + linear layer\n            #self.device = torch.device('cuda:{}'.format(opt.gpu_ids[0])) if opt.gpu_ids else torch.device('cpu')\n            self.device = torch.device('cuda:{}'.format(opt.gpu_ids_p[0])) if opt.gpu_ids_p else torch.device('cpu')\n            self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net,use_dropout=True,spatial=spatial,version=version,lpips=True).to(self.device)\n            kw = {}\n            \n            if not use_gpu:\n                kw['map_location'] = 'cpu'\n            if(model_path is None):\n                import inspect\n                #model_path = os.path.abspath(os.path.join(inspect.getfile(self.initialize), '..', '..', 'weights/v%s/%s.pth'%(version,net)))\n                model_path = './checkpoints/weights/v%s/%s.pth'%(version,net)\n\n            if(not is_train):\n                print('Loading model from: %s'%model_path)\n                #self.net.load_state_dict(torch.load(model_path, **kw))\n                state_dict = torch.load(model_path, map_location=str(self.device))\n                self.net.load_state_dict(state_dict, strict=False)\n\n        elif(self.model=='net'): # pretrained network\n            assert not self.spatial, 'spatial argument not supported yet for uncalibrated networks'\n            self.net = networks.PNet(use_gpu=use_gpu,pnet_type=net,device=self.device)\n            self.is_fake_net = True\n        elif(self.model in ['L2','l2']):\n            self.net = networks.L2(use_gpu=use_gpu,colorspace=colorspace,device=self.device) # not really a network, only for testing\n            self.model_name = 'L2'\n        elif(self.model in ['DSSIM','dssim','SSIM','ssim']):\n            self.net = networks.DSSIM(use_gpu=use_gpu,colorspace=colorspace,device=self.device)\n            self.model_name = 'SSIM'\n        else:\n            raise ValueError(\"Model [%s] not recognized.\" % self.model)\n\n        self.parameters = list(self.net.parameters())\n\n        if self.is_train: # training mode\n            # extra network on top to go from distances (d0,d1) => predicted human judgment (h*)\n            self.rankLoss = networks.BCERankingLoss(use_gpu=use_gpu,device=self.device)\n            self.parameters+=self.rankLoss.parameters\n            self.lr = lr\n            self.old_lr = lr\n            self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999))\n        else: # test mode\n            self.net.eval()\n\n        if(printNet):\n            print('---------- Networks initialized -------------')\n            networks.print_network(self.net)\n            print('-----------------------------------------------')\n\n    def 
forward_pair(self,in1,in2,retPerLayer=False):\n        if(retPerLayer):\n            return self.net.forward(in1,in2, retPerLayer=True)\n        else:\n            return self.net.forward(in1,in2)\n\n    def forward(self, in0, in1, retNumpy=False):\n        ''' Function computes the distance between image patches in0 and in1\n        INPUTS\n            in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1]\n            retNumpy - [False] to return as torch.Tensor, [True] to return as numpy array\n        OUTPUT\n            computed distances between in0 and in1\n        '''\n\n        self.input_ref = in0\n        self.input_p0 = in1\n\n        self.var_ref = Variable(self.input_ref,requires_grad=True)\n        self.var_p0 = Variable(self.input_p0,requires_grad=True)\n\n        self.d0 = self.forward_pair(self.var_ref, self.var_p0)\n        self.loss_total = self.d0\n\n        def convert_output(d0):\n            if(retNumpy):\n                ans = d0.cpu().data.numpy()\n                if not self.spatial:\n                    ans = ans.flatten()\n                else:\n                    assert(ans.shape[0] == 1 and len(ans.shape) == 4)\n                    return ans[0,...].transpose([1, 2, 0])                  # Reshape to usual numpy image format: (height, width, channels)\n                return ans\n            else:\n                return d0\n\n        if self.spatial:\n            L = [convert_output(x) for x in self.d0]\n            spatial_shape = self.spatial_shape\n            if spatial_shape is None:\n                if(self.spatial_factor is None):\n                    spatial_shape = (in0.size()[2],in0.size()[3])\n                else:\n                    spatial_shape = (max([x.shape[0] for x in L])*self.spatial_factor, max([x.shape[1] for x in L])*self.spatial_factor)\n            \n            L = [skimage.transform.resize(x, spatial_shape, order=self.spatial_order, mode='edge') for x in L]\n            \n            L = np.mean(np.concatenate(L, 2) * len(L), 2)\n            return L\n        else:\n            return convert_output(self.d0)\n\n    # ***** TRAINING FUNCTIONS *****\n    def optimize_parameters(self):\n        self.forward_train()\n        self.optimizer_net.zero_grad()\n        self.backward_train()\n        self.optimizer_net.step()\n        self.clamp_weights()\n\n    def clamp_weights(self):\n        for module in self.net.modules():\n            if(hasattr(module, 'weight') and module.kernel_size==(1,1)):\n                module.weight.data = torch.clamp(module.weight.data,min=0)\n\n    def set_input(self, data):\n        self.input_ref = data['ref']\n        self.input_p0 = data['p0']\n        self.input_p1 = data['p1']\n        self.input_judge = data['judge']\n\n        if(self.use_gpu):\n            self.input_ref = self.input_ref.cuda(self.device)\n            self.input_p0 = self.input_p0.cuda(self.device)\n            self.input_p1 = self.input_p1.cuda(self.device)\n            self.input_judge = self.input_judge.cuda(self.device)\n\n        self.var_ref = Variable(self.input_ref,requires_grad=True)\n        self.var_p0 = Variable(self.input_p0,requires_grad=True)\n        self.var_p1 = Variable(self.input_p1,requires_grad=True)\n\n    def forward_train(self): # run forward pass\n        self.d0 = self.forward_pair(self.var_ref, self.var_p0)\n        self.d1 = self.forward_pair(self.var_ref, self.var_p1)\n        self.acc_r = self.compute_accuracy(self.d0,self.d1,self.input_judge)\n\n        # var_judge\n        
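# 'judge' lies in [0, 1] (the fraction of humans preferring p1); it is mapped to [-1, 1] below for the ranking loss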
\n        self.var_judge = Variable(1.*self.input_judge).view(self.d0.size())\n\n        self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge*2.-1.)\n        return self.loss_total\n\n    def backward_train(self):\n        torch.mean(self.loss_total).backward()
\n\n    def compute_accuracy(self,d0,d1,judge):\n        ''' d0, d1 are Variables, judge is a Tensor '''\n        d1_lt_d0 = (d1<d0).cpu().data.numpy().flatten()\n        judge_per = judge.cpu().numpy().flatten()\n        return d1_lt_d0*judge_per + (1-d1_lt_d0)*(1-judge_per)
\n\n    def get_current_errors(self):\n        retDict = OrderedDict([('loss_total', self.loss_total.data.cpu().numpy()),\n                            ('acc_r', self.acc_r)])\n\n        for key in retDict.keys():\n            retDict[key] = np.mean(retDict[key])\n\n        return retDict
\n\n    def get_current_visuals(self):\n        zoom_factor = 256/self.var_ref.data.size()[2]\n\n        ref_img = util.tensor2im(self.var_ref.data)\n        p0_img = util.tensor2im(self.var_p0.data)\n        p1_img = util.tensor2im(self.var_p1.data)
\n\n        ref_img_vis = zoom(ref_img,[zoom_factor, zoom_factor, 1],order=0)\n        p0_img_vis = zoom(p0_img,[zoom_factor, zoom_factor, 1],order=0)\n        p1_img_vis = zoom(p1_img,[zoom_factor, zoom_factor, 1],order=0)
\n\n        return OrderedDict([('ref', ref_img_vis),\n                            ('p0', p0_img_vis),\n                            ('p1', p1_img_vis)])\n\n    def save(self, path, label):\n        self.save_network(self.net, path, '', label)\n        self.save_network(self.rankLoss.net, path, 'rank', label)
\n\n    def update_learning_rate(self,nepoch_decay):\n        lrd = self.lr / nepoch_decay\n        lr = self.old_lr - lrd\n\n        for param_group in self.optimizer_net.param_groups:\n            param_group['lr'] = lr\n\n        print('update lr decay: %f -> %f' % (self.old_lr, lr))\n        self.old_lr = lr
\n\n\n\ndef score_2afc_dataset(data_loader,func):\n    ''' Function computes Two Alternative Forced Choice (2AFC) score using\n        distance function 'func' in dataset 'data_loader'\n    INPUTS
\n        data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside\n        func - callable distance function - calling d=func(in0,in1) should take 2\n            pytorch tensors with shape Nx3xXxY, and return numpy array of length N\n    OUTPUTS
\n        [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators\n        [1] - dictionary with following elements\n            d0s,d1s - N arrays containing distances from the reference patch to the two perturbed patches
\n            gts - N array in [0,1], preferred patch selected by human evaluators\n                (closer to \"0\" for left patch p0, \"1\" for right patch p1,\n                \"0.6\" means 60pct people preferred right patch, 40pct preferred left)
\n            scores - N array in [0,1], corresponding to what percentage function agreed with humans\n    CONSTS\n        N - number of test triplets in data_loader\n    '''\n\n    d0s = []\n    d1s = []\n    gts = []
\n\n    # bar = pb.ProgressBar(max_value=data_loader.load_data().__len__())\n    for (i,data) in enumerate(data_loader.load_data()):\n        d0s+=func(data['ref'],data['p0']).tolist()\n        d1s+=func(data['ref'],data['p1']).tolist()\n        gts+=data['judge'].cpu().numpy().flatten().tolist()\n        # bar.update(i)
\n\n    d0s = np.array(d0s)\n    d1s = np.array(d1s)\n    gts = np.array(gts)
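\n    # agreement with humans: the metric earns (1 - gts) credit when it prefers p0 (d0s < d1s), gts when it prefers p1, and half credit on ties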
\n    scores = (d0s<d1s)*(1.-gts) + (d1s<d0s)*gts + (d1s==d0s)*.5\n\n    return(np.mean(scores), dict(d0s=d0s,d1s=d1s,gts=gts,scores=scores))\n\ndef score_jnd_dataset(data_loader,func):
\n    ''' Function computes JND score using distance function 'func' in dataset 'data_loader'\n    INPUTS\n        data_loader - CustomDatasetDataLoader object - contains a JNDDataset inside\n        func - callable distance function - calling d=func(in0,in1) should take 2\n            pytorch tensors with shape Nx3xXxY, and return numpy array of length N\n    OUTPUTS
\n        [0] - JND score in [0,1], mAP score (area under precision-recall curve)\n        [1] - dictionary with following elements\n            ds - N array containing distances between the two patches shown to human evaluators\n            sames - N array containing fraction of people who thought the two patches were identical\n    CONSTS\n        N - number of test triplets in data_loader\n    '''
\n\n    ds = []\n    gts = []\n\n    # bar = pb.ProgressBar(max_value=data_loader.load_data().__len__())\n    for (i,data) in enumerate(data_loader.load_data()):\n        ds+=func(data['p0'],data['p1']).tolist()\n        gts+=data['same'].cpu().numpy().flatten().tolist()\n        # bar.update(i)
\n\n    sames = np.array(gts)\n    ds = np.array(ds)\n\n    sorted_inds = np.argsort(ds)\n    ds_sorted = ds[sorted_inds]\n    sames_sorted = sames[sorted_inds]
\n\n    TPs = np.cumsum(sames_sorted)\n    FPs = np.cumsum(1-sames_sorted)\n    FNs = np.sum(sames_sorted)-TPs\n\n    precs = TPs/(TPs+FPs)\n    recs = TPs/(TPs+FNs)\n    score = util.voc_ap(recs,precs)\n\n    return(score, dict(ds=ds,sames=sames))\n
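\n# Editor's smoke test for score_2afc_dataset (a sketch, not part of the original\n# pipeline). Because of the relative imports above, run it as a module from the\n# repo root, e.g. 'python -m models.dist_model'. A tiny in-memory \"loader\" and a\n# plain L2 distance stand in for the CustomDatasetDataLoader / LPIPS-style func\n# used by the real evaluation code.
\nif __name__ == '__main__':\n    class _FakeLoader(object):\n        def __init__(self, batches):\n            self.batches = batches\n\n        def load_data(self):\n            return self.batches
\n\n    def _l2(in0, in1):  # per-sample MSE, returned as a numpy array of length N\n        return ((in0 - in1) ** 2).view(in0.size(0), -1).mean(1).cpu().numpy()
\n\n    batch = {'ref': torch.rand(2, 3, 8, 8), 'p0': torch.rand(2, 3, 8, 8),\n             'p1': torch.rand(2, 3, 8, 8), 'judge': torch.rand(2, 1)}\n    score, _ = score_2afc_dataset(_FakeLoader([batch]), _l2)\n    print('2AFC score on random data (should hover around 0.5): %.3f' % score)\n"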
  },
  {
    "path": "models/networks.py",
    "content": "import torch\nimport torch.nn as nn\nfrom torch.nn import init\nimport functools\nfrom torch.optim import lr_scheduler\n\n\n###############################################################################\n# Helper Functions\n###############################################################################\n\n\nclass Identity(nn.Module):\n\tdef forward(self, x):\n\t\treturn x\n\n\ndef get_norm_layer(norm_type='instance'):\n\t\"\"\"Return a normalization layer\n\n\tParameters:\n\t\tnorm_type (str) -- the name of the normalization layer: batch | instance | none\n\n\tFor BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev).\n\tFor InstanceNorm, we do not use learnable affine parameters. We do not track running statistics.\n\t\"\"\"\n\tif norm_type == 'batch':\n\t\tnorm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)\n\telif norm_type == 'instance':\n\t\tnorm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)\n\telif norm_type == 'none':\n\t\tnorm_layer = lambda x: Identity()\n\telse:\n\t\traise NotImplementedError('normalization layer [%s] is not found' % norm_type)\n\treturn norm_layer\n\n\ndef get_scheduler(optimizer, opt):\n\t\"\"\"Return a learning rate scheduler\n\n\tParameters:\n\t\toptimizer          -- the optimizer of the network\n\t\topt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions．　\n\t\t\t\t\t\t\t  opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine\n\n\tFor 'linear', we keep the same learning rate for the first <opt.niter> epochs\n\tand linearly decay the rate to zero over the next <opt.niter_decay> epochs.\n\tFor other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.\n\tSee https://pytorch.org/docs/stable/optim.html for more details.\n\t\"\"\"\n\tif opt.lr_policy == 'linear':\n\t\tdef lambda_rule(epoch):\n\t\t\tlr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)\n\t\t\treturn lr_l\n\t\tscheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)\n\telif opt.lr_policy == 'step':\n\t\tscheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)\n\telif opt.lr_policy == 'plateau':\n\t\tscheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)\n\telif opt.lr_policy == 'cosine':\n\t\tscheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0)\n\telse:\n\t\treturn NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy)\n\treturn scheduler\n\n\ndef init_weights(net, init_type='normal', init_gain=0.02):\n\t\"\"\"Initialize network weights.\n\n\tParameters:\n\t\tnet (network)   -- network to be initialized\n\t\tinit_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal\n\t\tinit_gain (float)    -- scaling factor for normal, xavier and orthogonal.\n\n\tWe use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might\n\twork better for some applications. 
Feel free to try it yourself.\n\t\"\"\"\n\tdef init_func(m):  # define the initialization function\n\t\tclassname = m.__class__.__name__\n\t\tif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):\n\t\t\tif init_type == 'normal':\n\t\t\t\tinit.normal_(m.weight.data, 0.0, init_gain)\n\t\t\telif init_type == 'xavier':\n\t\t\t\tinit.xavier_normal_(m.weight.data, gain=init_gain)\n\t\t\telif init_type == 'kaiming':\n\t\t\t\tinit.kaiming_normal_(m.weight.data, a=0, mode='fan_in')\n\t\t\telif init_type == 'orthogonal':\n\t\t\t\tinit.orthogonal_(m.weight.data, gain=init_gain)\n\t\t\telse:\n\t\t\t\traise NotImplementedError('initialization method [%s] is not implemented' % init_type)\n\t\t\tif hasattr(m, 'bias') and m.bias is not None:\n\t\t\t\tinit.constant_(m.bias.data, 0.0)
\n\t\telif classname.find('BatchNorm2d') != -1:  # BatchNorm Layer's weight is not a matrix; only normal distribution applies.\n\t\t\tinit.normal_(m.weight.data, 1.0, init_gain)\n\t\t\tinit.constant_(m.bias.data, 0.0)\n\n\tprint('initialize network with %s' % init_type)\n\tnet.apply(init_func)  # apply the initialization function <init_func>
\n\n\ndef init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):\n\t\"\"\"Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights\n\tParameters:\n\t\tnet (network)      -- the network to be initialized\n\t\tinit_type (str)    -- the name of an initialization method: normal | xavier | kaiming | orthogonal\n\t\tinit_gain (float)  -- scaling factor for normal, xavier and orthogonal.\n\t\tgpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
\n\n\tReturn an initialized network.\n\t\"\"\"\n\tif len(gpu_ids) > 0:\n\t\tassert(torch.cuda.is_available())\n\t\tnet.to(gpu_ids[0])\n\t\tnet = torch.nn.DataParallel(net, gpu_ids)  # multi-GPUs\n\tinit_weights(net, init_type, init_gain=init_gain)\n\treturn net
\n\n\ndef define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[], model0_res=0, model1_res=0, extra_channel=3):\n\t\"\"\"Create a generator\n\n\tParameters:\n\t\tinput_nc (int) -- the number of channels in input images\n\t\toutput_nc (int) -- the number of channels in output images\n\t\tngf (int) -- the number of filters in the last conv layer\n\t\tnetG (str) -- the architecture's name: resnet_9blocks | resnet_6blocks | unet_256 | unet_128\n\t\tnorm (str) -- the name of normalization layers used in the network: batch | instance | none\n\t\tuse_dropout (bool) -- if use dropout layers.\n\t\tinit_type (str)    -- the name of our initialization method.\n\t\tinit_gain (float)  -- scaling factor for normal, xavier and orthogonal.\n\t\tgpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
\n\n\tReturns a generator\n\n\tOur current implementation provides two types of generators:\n\t\tU-Net: [unet_128] (for 128x128 input images) and [unet_256] (for 256x256 input images)\n\t\tThe original U-Net paper: https://arxiv.org/abs/1505.04597
\n\n\t\tResnet-based generator: [resnet_6blocks] (with 6 Resnet blocks) and [resnet_9blocks] (with 9 Resnet blocks)\n\t\tResnet-based generator consists of several Resnet blocks between a few downsampling/upsampling operations.\n\t\tWe adapt Torch code from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style).\n\n\n\tThe generator has been initialized by <init_net>. 
It uses ReLU for non-linearity.\n\t\"\"\"\n\tnet = None\n\tnorm_layer = get_norm_layer(norm_type=norm)
\n\n\tif netG == 'resnet_9blocks':\n\t\tnet = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9)\n\telif netG == 'resnet_style2_9blocks':\n\t\tnet = ResnetStyle2Generator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9, model0_res=model0_res, extra_channel=extra_channel)\n\telif netG == 'resnet_6blocks':\n\t\tnet = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6)\n\telif netG == 'unet_128':\n\t\tnet = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)\n\telif netG == 'unet_256':\n\t\tnet = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)\n\telse:\n\t\traise NotImplementedError('Generator model name [%s] is not recognized' % netG)\n\treturn init_net(net, init_type, init_gain, gpu_ids)
\n\n\ndef define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[], n_class=3):\n\t\"\"\"Create a discriminator\n\n\tParameters:\n\t\tinput_nc (int)     -- the number of channels in input images\n\t\tndf (int)          -- the number of filters in the first conv layer\n\t\tnetD (str)         -- the architecture's name: basic | n_layers | pixel\n\t\tn_layers_D (int)   -- the number of conv layers in the discriminator; effective when netD=='n_layers'\n\t\tnorm (str)         -- the type of normalization layers used in the network.\n\t\tinit_type (str)    -- the name of the initialization method.\n\t\tinit_gain (float)  -- scaling factor for normal, xavier and orthogonal.\n\t\tgpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
\n\n\tReturns a discriminator\n\n\tOur current implementation provides three types of discriminators:\n\t\t[basic]: 'PatchGAN' classifier described in the original pix2pix paper.\n\t\tIt can classify whether 70×70 overlapping patches are real or fake.\n\t\tSuch a patch-level discriminator architecture has fewer parameters\n\t\tthan a full-image discriminator and can work on arbitrarily-sized images\n\t\tin a fully convolutional fashion.
\n\n\t\t[n_layers]: With this mode, you can specify the number of conv layers in the discriminator\n\t\twith the parameter <n_layers_D> (default=3, as used in [basic] (PatchGAN)).\n\n\t\t[pixel]: 1x1 PixelGAN discriminator can classify whether a pixel is real or not.\n\t\tIt encourages greater color diversity but has no effect on spatial statistics.\n\n\tThe discriminator has been initialized by <init_net>. 
It uses LeakyReLU for non-linearity.\n\t\"\"\"\n\tnet = None\n\tnorm_layer = get_norm_layer(norm_type=norm)
\n\n\tif netD == 'basic':  # default PatchGAN classifier\n\t\tnet = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer)\n\telif netD == 'basic_cls':\n\t\tnet = NLayerDiscriminatorCls(input_nc, ndf, n_layers=3, n_class=3, norm_layer=norm_layer)\n\telif netD == 'n_layers':  # more options\n\t\tnet = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer)\n\telif netD == 'pixel':     # classify if each pixel is real or fake\n\t\tnet = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)\n\telse:\n\t\traise NotImplementedError('Discriminator model name [%s] is not recognized' % netD)\n\treturn init_net(net, init_type, init_gain, gpu_ids)
\n\n\ndef define_HED(init_weights_, gpu_ids_=[]):\n\tnet = HED()\n\n\tif len(gpu_ids_) > 0:\n\t\tassert(torch.cuda.is_available())\n\t\tnet.to(gpu_ids_[0])\n\t\tnet = torch.nn.DataParallel(net, gpu_ids_)  # multi-GPUs\n\t
\n\tif init_weights_ is not None:\n\t\tdevice = torch.device('cuda:{}'.format(gpu_ids_[0])) if gpu_ids_ else torch.device('cpu')\n\t\tprint('Loading model from: %s'%init_weights_)\n\t\tstate_dict = torch.load(init_weights_, map_location=str(device))\n\t\tif isinstance(net, torch.nn.DataParallel):\n\t\t\tnet.module.load_state_dict(state_dict)\n\t\telse:\n\t\t\tnet.load_state_dict(state_dict)\n\t\tprint('loaded the weights successfully')\n\n\treturn net
\n\n##############################################################################\n# Classes\n##############################################################################\nclass GANLoss(nn.Module):\n\t\"\"\"Define different GAN objectives.\n\n\tThe GANLoss class abstracts away the need to create the target label tensor\n\tthat has the same size as the input.\n\t\"\"\"
\n\n\tdef __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):\n\t\t\"\"\" Initialize the GANLoss class.\n\n\t\tParameters:\n\t\t\tgan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp.\n\t\t\ttarget_real_label (float) - - label for a real image\n\t\t\ttarget_fake_label (float) - - label for a fake image\n\n\t\tNote: Do not use sigmoid as the last layer of Discriminator.\n\t\tLSGAN needs no sigmoid. 
vanilla GANs will handle it with BCEWithLogitsLoss.\n\t\t\"\"\"\n\t\tsuper(GANLoss, self).__init__()\n\t\tself.register_buffer('real_label', torch.tensor(target_real_label))\n\t\tself.register_buffer('fake_label', torch.tensor(target_fake_label))\n\t\tself.gan_mode = gan_mode\n\t\tif gan_mode == 'lsgan':  # used by CycleGAN\n\t\t\tself.loss = nn.MSELoss()\n\t\telif gan_mode == 'vanilla':\n\t\t\tself.loss = nn.BCEWithLogitsLoss()\n\t\telif gan_mode in ['wgangp']:\n\t\t\tself.loss = None\n\t\telse:\n\t\t\traise NotImplementedError('gan mode %s not implemented' % gan_mode)
\n\n\tdef get_target_tensor(self, prediction, target_is_real):\n\t\t\"\"\"Create label tensors with the same size as the input.\n\n\t\tParameters:\n\t\t\tprediction (tensor) - - typically the prediction from a discriminator\n\t\t\ttarget_is_real (bool) - - if the ground truth label is for real images or fake images\n\n\t\tReturns:\n\t\t\tA label tensor filled with ground truth label, and with the size of the input\n\t\t\"\"\"\n\n\t\tif target_is_real:\n\t\t\ttarget_tensor = self.real_label\n\t\telse:\n\t\t\ttarget_tensor = self.fake_label\n\t\treturn target_tensor.expand_as(prediction)
\n\n\tdef __call__(self, prediction, target_is_real):\n\t\t\"\"\"Calculate loss given Discriminator's output and ground truth labels.\n\n\t\tParameters:\n\t\t\tprediction (tensor) - - typically the prediction output from a discriminator\n\t\t\ttarget_is_real (bool) - - if the ground truth label is for real images or fake images\n\n\t\tReturns:\n\t\t\tthe calculated loss.\n\t\t\"\"\"\n\t\tif self.gan_mode in ['lsgan', 'vanilla']:\n\t\t\ttarget_tensor = self.get_target_tensor(prediction, target_is_real)\n\t\t\tloss = self.loss(prediction, target_tensor)\n\t\telif self.gan_mode == 'wgangp':\n\t\t\tif target_is_real:\n\t\t\t\tloss = -prediction.mean()\n\t\t\telse:\n\t\t\t\tloss = prediction.mean()\n\t\treturn loss
\n\n\ndef cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0):\n\t\"\"\"Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028\n\n\tArguments:\n\t\tnetD (network)              -- discriminator network\n\t\treal_data (tensor array)    -- real images\n\t\tfake_data (tensor array)    -- generated images from the generator\n\t\tdevice (str)                -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')\n\t\ttype (str)                  -- if we mix real and fake data or not [real | fake | mixed].\n\t\tconstant (float)            -- the constant used in the formula (||gradient||_2 - constant)^2\n\t\tlambda_gp (float)           -- weight for this loss\n\n\tReturns the gradient penalty loss\n\t\"\"\"
\n\tif lambda_gp > 0.0:\n\t\tif type == 'real':   # either use real images, fake images, or a linear interpolation of the two\n\t\t\tinterpolatesv = real_data\n\t\telif type == 'fake':\n\t\t\tinterpolatesv = fake_data\n\t\telif type == 'mixed':\n\t\t\talpha = torch.rand(real_data.shape[0], 1, device=device)\n\t\t\talpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape)\n\t\t\tinterpolatesv = alpha * real_data + ((1 - alpha) * fake_data)\n\t\telse:\n\t\t\traise NotImplementedError('{} not implemented'.format(type))\n\t\tinterpolatesv.requires_grad_(True)\n\t\tdisc_interpolates = netD(interpolatesv)\n\t\tgradients = torch.autograd.grad(outputs=disc_interpolates, 
inputs=interpolatesv,\n\t\t\t\t\t\t\t\t\t\tgrad_outputs=torch.ones(disc_interpolates.size()).to(device),\n\t\t\t\t\t\t\t\t\t\tcreate_graph=True, retain_graph=True, only_inputs=True)\n\t\tgradients = gradients[0].view(real_data.size(0), -1)  # flatten the data\n\t\tgradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp        # added eps\n\t\treturn gradient_penalty, gradients\n\telse:\n\t\treturn 0.0, None
\n\n\nclass ResnetGenerator(nn.Module):\n\t\"\"\"Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.\n\n\tWe adapt Torch code and ideas from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style)\n\t\"\"\"
\n\n\tdef __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):\n\t\t\"\"\"Construct a Resnet-based generator\n\n\t\tParameters:\n\t\t\tinput_nc (int)      -- the number of channels in input images\n\t\t\toutput_nc (int)     -- the number of channels in output images\n\t\t\tngf (int)           -- the number of filters in the last conv layer\n\t\t\tnorm_layer          -- normalization layer\n\t\t\tuse_dropout (bool)  -- if use dropout layers\n\t\t\tn_blocks (int)      -- the number of ResNet blocks\n\t\t\tpadding_type (str)  -- the name of padding layer in conv layers: reflect | replicate | zero\n\t\t\"\"\"
\n\t\tassert(n_blocks >= 0)\n\t\tsuper(ResnetGenerator, self).__init__()\n\t\tif type(norm_layer) == functools.partial:\n\t\t\tuse_bias = norm_layer.func == nn.InstanceNorm2d\n\t\telse:\n\t\t\tuse_bias = norm_layer == nn.InstanceNorm2d
\n\n\t\tmodel = [nn.ReflectionPad2d(3),\n\t\t\t\t nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),\n\t\t\t\t norm_layer(ngf),\n\t\t\t\t nn.ReLU(True)]
\n\n\t\tn_downsampling = 2\n\t\tfor i in range(n_downsampling):  # add downsampling layers\n\t\t\tmult = 2 ** i\n\t\t\tmodel += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),\n\t\t\t\t\t  norm_layer(ngf * mult * 2),\n\t\t\t\t\t  nn.ReLU(True)]
\n\n\t\tmult = 2 ** n_downsampling\n\t\tfor i in range(n_blocks):       # add ResNet blocks\n\n\t\t\tmodel += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
\n\n\t\tfor i in range(n_downsampling):  # add upsampling layers\n\t\t\tmult = 2 ** (n_downsampling - i)\n\t\t\tmodel += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),\n\t\t\t\t\t\t\t\t\t\t kernel_size=3, stride=2,\n\t\t\t\t\t\t\t\t\t\t padding=1, output_padding=1,\n\t\t\t\t\t\t\t\t\t\t bias=use_bias),\n\t\t\t\t\t  norm_layer(int(ngf * mult / 2)),\n\t\t\t\t\t  nn.ReLU(True)]\n\t\tmodel += [nn.ReflectionPad2d(3)]\n\t\tmodel += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]\n\t\tmodel += [nn.Tanh()]\n\n\t\tself.model = nn.Sequential(*model)
\n\n\tdef forward(self, input):\n\t\t\"\"\"Standard forward\"\"\"\n\t\treturn self.model(input)
\n\nclass ResnetStyle2Generator(nn.Module):\n\tdef __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect', extra_channel=3, model0_res=0):\n\t\t\"\"\"Construct a Resnet-based generator\n\n\t\tParameters:\n\t\t\tinput_nc (int)      -- the number of channels in input images\n\t\t\toutput_nc (int)     -- the number of channels in output images\n\t\t\tngf (int)           -- the number of filters in the last conv layer\n\t\t\tnorm_layer          -- normalization layer\n\t\t\tuse_dropout 
(bool)  -- if use dropout layers\n\t\t\tn_blocks (int)      -- the number of ResNet blocks\n\t\t\tpadding_type (str)  -- the name of padding layer in conv layers: reflect | replicate | zero\n\t\t\textra_channel (int) -- the number of extra channels concatenated to the downsampled features before the remaining blocks\n\t\t\tmodel0_res (int)    -- the number of ResNet blocks applied before the extra channels are concatenated\n\t\t\"\"\"
\n\t\tassert(n_blocks >= 0)\n\t\tsuper(ResnetStyle2Generator, self).__init__()\n\t\tif type(norm_layer) == functools.partial:\n\t\t\tuse_bias = norm_layer.func == nn.InstanceNorm2d\n\t\telse:\n\t\t\tuse_bias = norm_layer == nn.InstanceNorm2d
\n\n\t\tmodel0 = [nn.ReflectionPad2d(3),\n\t\t\t\t nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),\n\t\t\t\t norm_layer(ngf),\n\t\t\t\t nn.ReLU(True)]
\n\n\t\tn_downsampling = 2\n\t\tfor i in range(n_downsampling):  # add downsampling layers\n\t\t\tmult = 2 ** i\n\t\t\tmodel0 += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),\n\t\t\t\t\t  norm_layer(ngf * mult * 2),\n\t\t\t\t\t  nn.ReLU(True)]\n\t\t
\n\t\tmult = 2 ** n_downsampling\n\t\tfor i in range(model0_res):       # add ResNet blocks\n\t\t\tmodel0 += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
\n\n\t\tmodel = []\n\t\tmodel += [nn.Conv2d(ngf * mult + extra_channel, ngf * mult, kernel_size=3, stride=1, padding=1, bias=use_bias),\n\t\t\t\t\t  norm_layer(ngf * mult),\n\t\t\t\t\t  nn.ReLU(True)]
\n\n\t\tfor i in range(n_blocks-model0_res):       # add ResNet blocks\n\t\t\tmodel += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
\n\n\t\tfor i in range(n_downsampling):  # add upsampling layers\n\t\t\tmult = 2 ** (n_downsampling - i)\n\t\t\tmodel += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),\n\t\t\t\t\t\t\t\t\t\t kernel_size=3, stride=2,\n\t\t\t\t\t\t\t\t\t\t padding=1, output_padding=1,\n\t\t\t\t\t\t\t\t\t\t bias=use_bias),\n\t\t\t\t\t  norm_layer(int(ngf * mult / 2)),\n\t\t\t\t\t  nn.ReLU(True)]\n\t\tmodel += [nn.ReflectionPad2d(3)]\n\t\tmodel += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]\n\t\tmodel += [nn.Tanh()]
\n\n\t\tself.model0 = nn.Sequential(*model0)\n\t\tself.model = nn.Sequential(*model)\n\t\t#print(list(self.modules()))
\n\n\tdef forward(self, input1, input2):\n\t\t\"\"\"Forward with an extra conditioning map: input2 is concatenated to the downsampled features of input1\"\"\"\n\t\tf1 = self.model0(input1)\n\t\ty1 = torch.cat([f1, input2], 1)\n\t\treturn self.model(y1)
\n\nclass ResnetBlock(nn.Module):\n\t\"\"\"Define a Resnet block\"\"\"\n\n\tdef __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias, kernel=3):\n\t\t\"\"\"Initialize the Resnet block\n\n\t\tA resnet block is a conv block with skip connections\n\t\tWe construct a conv block with build_conv_block function,\n\t\tand implement skip connections in <forward> function.\n\t\tOriginal Resnet paper: https://arxiv.org/pdf/1512.03385.pdf\n\t\t\"\"\"\n\t\tsuper(ResnetBlock, self).__init__()\n\t\tself.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias, kernel)
\n\n\tdef build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias, kernel=3):\n\t\t\"\"\"Construct a convolutional block.\n\n\t\tParameters:\n\t\t\tdim (int)           -- the number of channels in the conv layer.\n\t\t\tpadding_type (str)  -- the name of padding layer: reflect | replicate | zero\n\t\t\tnorm_layer          -- normalization layer\n\t\t\tuse_dropout (bool)  -- if use dropout layers.\n\t\t\tuse_bias (bool)     -- if the conv layer uses bias or not\n\n\t\tReturns a conv block (with a conv layer, a normalization layer, and a non-linearity layer 
(ReLU))\n\t\t\"\"\"\n\t\tconv_block = []\n\t\tp = 0\n\t\tpad = int((kernel-1)/2)\n\t\tif padding_type == 'reflect':  # by default\n\t\t\tconv_block += [nn.ReflectionPad2d(pad)]\n\t\telif padding_type == 'replicate':\n\t\t\tconv_block += [nn.ReplicationPad2d(pad)]\n\t\telif padding_type == 'zero':\n\t\t\tp = pad\n\t\telse:\n\t\t\traise NotImplementedError('padding [%s] is not implemented' % padding_type)
\n\n\t\tconv_block += [nn.Conv2d(dim, dim, kernel_size=kernel, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]\n\t\tif use_dropout:\n\t\t\tconv_block += [nn.Dropout(0.5)]
\n\n\t\tp = 0\n\t\tif padding_type == 'reflect':\n\t\t\tconv_block += [nn.ReflectionPad2d(pad)]\n\t\telif padding_type == 'replicate':\n\t\t\tconv_block += [nn.ReplicationPad2d(pad)]\n\t\telif padding_type == 'zero':\n\t\t\tp = pad\n\t\telse:\n\t\t\traise NotImplementedError('padding [%s] is not implemented' % padding_type)\n\t\tconv_block += [nn.Conv2d(dim, dim, kernel_size=kernel, padding=p, bias=use_bias), norm_layer(dim)]\n\n\t\treturn nn.Sequential(*conv_block)
\n\n\tdef forward(self, x):\n\t\t\"\"\"Forward function (with skip connections)\"\"\"\n\t\tout = x + self.conv_block(x)  # add skip connections\n\t\treturn out
\n\n\nclass UnetGenerator(nn.Module):\n\t\"\"\"Create a Unet-based generator\"\"\"\n\n\tdef __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):\n\t\t\"\"\"Construct a Unet generator\n\t\tParameters:\n\t\t\tinput_nc (int)  -- the number of channels in input images\n\t\t\toutput_nc (int) -- the number of channels in output images\n\t\t\tnum_downs (int) -- the number of downsamplings in UNet. For example, if |num_downs| == 7,\n\t\t\t\t\t\t\t\tan image of size 128x128 will become of size 1x1 at the bottleneck\n\t\t\tngf (int)       -- the number of filters in the last conv layer\n\t\t\tnorm_layer      -- normalization layer
\n\n\t\tWe construct the U-Net from the innermost layer to the outermost layer.\n\t\tIt is a recursive process.\n\t\t\"\"\"\n\t\tsuper(UnetGenerator, self).__init__()\n\t\t# construct unet structure\n\t\tunet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)  # add the innermost layer\n\t\tfor i in range(num_downs - 5):          # add intermediate layers with ngf * 8 filters\n\t\t\tunet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)\n\t\t# gradually reduce the number of filters from ngf * 8 to ngf\n\t\tunet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)\n\t\tunet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)\n\t\tunet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)\n\t\tself.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)  # add the outermost layer
\n\n\tdef forward(self, input):\n\t\t\"\"\"Standard forward\"\"\"\n\t\treturn self.model(input)
\n\n\nclass UnetSkipConnectionBlock(nn.Module):\n\t\"\"\"Defines the Unet submodule with skip connection.\n\t\tX -------------------identity----------------------\n\t\t|-- downsampling -- |submodule| -- upsampling --|\n\t\"\"\"\n\n\tdef __init__(self, outer_nc, inner_nc, input_nc=None,\n\t\t\t\t submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, 
use_dropout=False):\n\t\t\"\"\"Construct a Unet submodule with skip connections.\n\n\t\tParameters:\n\t\t\touter_nc (int) -- the number of filters in the outer conv layer\n\t\t\tinner_nc (int) -- the number of filters in the inner conv layer\n\t\t\tinput_nc (int) -- the number of channels in input images/features\n\t\t\tsubmodule (UnetSkipConnectionBlock) -- previously defined submodules\n\t\t\toutermost (bool)    -- if this module is the outermost module\n\t\t\tinnermost (bool)    -- if this module is the innermost module\n\t\t\tnorm_layer          -- normalization layer\n\t\t\tuse_dropout (bool)  -- if use dropout layers.\n\t\t\"\"\"
\n\t\tsuper(UnetSkipConnectionBlock, self).__init__()\n\t\tself.outermost = outermost\n\t\tif type(norm_layer) == functools.partial:\n\t\t\tuse_bias = norm_layer.func == nn.InstanceNorm2d\n\t\telse:\n\t\t\tuse_bias = norm_layer == nn.InstanceNorm2d\n\t\tif input_nc is None:\n\t\t\tinput_nc = outer_nc\n\t\tdownconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,\n\t\t\t\t\t\t\t stride=2, padding=1, bias=use_bias)\n\t\tdownrelu = nn.LeakyReLU(0.2, True)\n\t\tdownnorm = norm_layer(inner_nc)\n\t\tuprelu = nn.ReLU(True)\n\t\tupnorm = norm_layer(outer_nc)
\n\n\t\tif outermost:\n\t\t\tupconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,\n\t\t\t\t\t\t\t\t\t\tkernel_size=4, stride=2,\n\t\t\t\t\t\t\t\t\t\tpadding=1)\n\t\t\tdown = [downconv]\n\t\t\tup = [uprelu, upconv, nn.Tanh()]\n\t\t\tmodel = down + [submodule] + up\n\t\telif innermost:\n\t\t\tupconv = nn.ConvTranspose2d(inner_nc, outer_nc,\n\t\t\t\t\t\t\t\t\t\tkernel_size=4, stride=2,\n\t\t\t\t\t\t\t\t\t\tpadding=1, bias=use_bias)\n\t\t\tdown = [downrelu, downconv]\n\t\t\tup = [uprelu, upconv, upnorm]\n\t\t\tmodel = down + up\n\t\telse:\n\t\t\tupconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,\n\t\t\t\t\t\t\t\t\t\tkernel_size=4, stride=2,\n\t\t\t\t\t\t\t\t\t\tpadding=1, bias=use_bias)\n\t\t\tdown = [downrelu, downconv, downnorm]\n\t\t\tup = [uprelu, upconv, upnorm]
\n\n\t\t\tif use_dropout:\n\t\t\t\tmodel = down + [submodule] + up + [nn.Dropout(0.5)]\n\t\t\telse:\n\t\t\t\tmodel = down + [submodule] + up\n\n\t\tself.model = nn.Sequential(*model)
\n\n\tdef forward(self, x):\n\t\tif self.outermost:\n\t\t\treturn self.model(x)\n\t\telse:   # add skip connections\n\t\t\treturn torch.cat([x, self.model(x)], 1)
\n\n\nclass NLayerDiscriminator(nn.Module):\n\t\"\"\"Defines a PatchGAN discriminator\"\"\"\n\n\tdef __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):\n\t\t\"\"\"Construct a PatchGAN discriminator\n\n\t\tParameters:\n\t\t\tinput_nc (int)  -- the number of channels in input images\n\t\t\tndf (int)       -- the number of filters in the last conv layer\n\t\t\tn_layers (int)  -- the number of conv layers in the discriminator\n\t\t\tnorm_layer      -- normalization layer\n\t\t\"\"\"\n\t\tsuper(NLayerDiscriminator, self).__init__()\n\t\tif type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters\n\t\t\tuse_bias = norm_layer.func != nn.BatchNorm2d\n\t\telse:\n\t\t\tuse_bias = norm_layer != nn.BatchNorm2d
\n\n\t\tkw = 4\n\t\tpadw = 1\n\t\tsequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]\n\t\tnf_mult = 1\n\t\tnf_mult_prev = 1\n\t\tfor n in range(1, n_layers):  # gradually increase the number of filters\n\t\t\tnf_mult_prev = nf_mult\n\t\t\tnf_mult = min(2 ** n, 8)\n\t\t\tsequence += [\n\t\t\t\tnn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, 
bias=use_bias),\n\t\t\t\tnorm_layer(ndf * nf_mult),\n\t\t\t\tnn.LeakyReLU(0.2, True)\n\t\t\t]\n\n\t\tnf_mult_prev = nf_mult\n\t\tnf_mult = min(2 ** n_layers, 8)\n\t\tsequence += [\n\t\t\tnn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),\n\t\t\tnorm_layer(ndf * nf_mult),\n\t\t\tnn.LeakyReLU(0.2, True)\n\t\t]\n\n\t\tsequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map\n\t\tself.model = nn.Sequential(*sequence)
\n\n\tdef forward(self, input):\n\t\t\"\"\"Standard forward.\"\"\"\n\t\treturn self.model(input)
\n\n\nclass NLayerDiscriminatorCls(nn.Module):\n\t\"\"\"Defines a PatchGAN discriminator with an auxiliary classification head\"\"\"\n\n\tdef __init__(self, input_nc, ndf=64, n_layers=3, n_class=3, norm_layer=nn.BatchNorm2d):\n\t\t\"\"\"Construct a PatchGAN discriminator with an auxiliary classifier\n\n\t\tParameters:\n\t\t\tinput_nc (int)  -- the number of channels in input images\n\t\t\tndf (int)       -- the number of filters in the last conv layer\n\t\t\tn_layers (int)  -- the number of conv layers in the discriminator\n\t\t\tn_class (int)   -- the number of classes predicted by the auxiliary head\n\t\t\tnorm_layer      -- normalization layer\n\t\t\"\"\"\n\t\tsuper(NLayerDiscriminatorCls, self).__init__()\n\t\tif type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters\n\t\t\tuse_bias = norm_layer.func != nn.BatchNorm2d\n\t\telse:\n\t\t\tuse_bias = norm_layer != nn.BatchNorm2d
\n\n\t\tkw = 4\n\t\tpadw = 1\n\t\tsequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]\n\t\tnf_mult = 1\n\t\tnf_mult_prev = 1\n\t\tfor n in range(1, n_layers):  # gradually increase the number of filters\n\t\t\tnf_mult_prev = nf_mult\n\t\t\tnf_mult = min(2 ** n, 8)\n\t\t\tsequence += [\n\t\t\t\tnn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),\n\t\t\t\tnorm_layer(ndf * nf_mult),\n\t\t\t\tnn.LeakyReLU(0.2, True)\n\t\t\t]
\n\n\t\tnf_mult_prev = nf_mult\n\t\tnf_mult = min(2 ** n_layers, 8)\n\t\tsequence1 = [\n\t\t\tnn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),\n\t\t\tnorm_layer(ndf * nf_mult),\n\t\t\tnn.LeakyReLU(0.2, True)\n\t\t]\n\t\tsequence1 += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
\n\n\t\tsequence2 = [\n\t\t\tnn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),\n\t\t\tnorm_layer(ndf * nf_mult),\n\t\t\tnn.LeakyReLU(0.2, True)\n\t\t]\n\t\tsequence2 += [\n\t\t\tnn.Conv2d(ndf * nf_mult, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),\n\t\t\tnorm_layer(ndf * nf_mult),\n\t\t\tnn.LeakyReLU(0.2, True)\n\t\t]\n\t\tsequence2 += [\n\t\t\tnn.Conv2d(ndf * nf_mult, n_class, kernel_size=16, stride=1, padding=0, bias=use_bias)]
\n\n\n\t\tself.model0 = nn.Sequential(*sequence)\n\t\tself.model1 = nn.Sequential(*sequence1)\n\t\tself.model2 = nn.Sequential(*sequence2)\n\t\t#print(list(self.modules()))
\n\n\tdef forward(self, input):\n\t\t\"\"\"Return both the patch-level real/fake map and the class logits.\"\"\"\n\t\tfeat = self.model0(input)\n\t\t# patchGAN output (1 * 62 * 62)\n\t\tpatch = self.model1(feat)\n\t\t# class output (3 * 1 * 1)\n\t\tclassl = self.model2(feat)\n\t\treturn patch, classl.view(classl.size(0), -1)
\n\n\nclass PixelDiscriminator(nn.Module):\n\t\"\"\"Defines a 1x1 PatchGAN discriminator (pixelGAN)\"\"\"\n\n\tdef __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):\n\t\t\"\"\"Construct a 1x1 PatchGAN discriminator\n\n\t\tParameters:\n\t\t\tinput_nc (int)  -- 
the number of channels in input images\n\t\t\tndf (int)       -- the number of filters in the last conv layer\n\t\t\tnorm_layer      -- normalization layer\n\t\t\"\"\"\n\t\tsuper(PixelDiscriminator, self).__init__()\n\t\tif type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters\n\t\t\tuse_bias = norm_layer.func != nn.BatchNorm2d\n\t\telse:\n\t\t\tuse_bias = norm_layer != nn.BatchNorm2d
\n\n\t\tself.net = [\n\t\t\tnn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),\n\t\t\tnn.LeakyReLU(0.2, True),\n\t\t\tnn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),\n\t\t\tnorm_layer(ndf * 2),\n\t\t\tnn.LeakyReLU(0.2, True),\n\t\t\tnn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]\n\n\t\tself.net = nn.Sequential(*self.net)
\n\n\tdef forward(self, input):\n\t\t\"\"\"Standard forward.\"\"\"\n\t\treturn self.net(input)
\n\n\nclass HED(nn.Module):\n\tdef __init__(self):\n\t\tsuper(HED, self).__init__()\n\n\t\tself.moduleVggOne = nn.Sequential(\n\t\t\tnn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1),\n\t\t\tnn.ReLU(inplace=False),\n\t\t\tnn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),\n\t\t\tnn.ReLU(inplace=False)\n\t\t)
\n\n\t\tself.moduleVggTwo = nn.Sequential(\n\t\t\tnn.MaxPool2d(kernel_size=2, stride=2),\n\t\t\tnn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),\n\t\t\tnn.ReLU(inplace=False),\n\t\t\tnn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),\n\t\t\tnn.ReLU(inplace=False)\n\t\t)
\n\n\t\tself.moduleVggThr = nn.Sequential(\n\t\t\tnn.MaxPool2d(kernel_size=2, stride=2),\n\t\t\tnn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),\n\t\t\tnn.ReLU(inplace=False),\n\t\t\tnn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),\n\t\t\tnn.ReLU(inplace=False),\n\t\t\tnn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),\n\t\t\tnn.ReLU(inplace=False)\n\t\t)
\n\n\t\tself.moduleVggFou = nn.Sequential(\n\t\t\tnn.MaxPool2d(kernel_size=2, stride=2),\n\t\t\tnn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),\n\t\t\tnn.ReLU(inplace=False),\n\t\t\tnn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),\n\t\t\tnn.ReLU(inplace=False),\n\t\t\tnn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),\n\t\t\tnn.ReLU(inplace=False)\n\t\t)
\n\n\t\tself.moduleVggFiv = nn.Sequential(\n\t\t\tnn.MaxPool2d(kernel_size=2, stride=2),\n\t\t\tnn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),\n\t\t\tnn.ReLU(inplace=False),\n\t\t\tnn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),\n\t\t\tnn.ReLU(inplace=False),\n\t\t\tnn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),\n\t\t\tnn.ReLU(inplace=False)\n\t\t)
\n\n\t\tself.moduleScoreOne = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1, stride=1, padding=0)\n\t\tself.moduleScoreTwo = nn.Conv2d(in_channels=128, out_channels=1, kernel_size=1, stride=1, padding=0)\n\t\tself.moduleScoreThr = nn.Conv2d(in_channels=256, out_channels=1, kernel_size=1, stride=1, padding=0)\n\t\tself.moduleScoreFou = nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1, stride=1, padding=0)\n\t\tself.moduleScoreFiv = nn.Conv2d(in_channels=512, out_channels=1, kernel_size=1, stride=1, padding=0)\n\n\t\tself.moduleCombine = 
nn.Sequential(\n\t\t\tnn.Conv2d(in_channels=5, out_channels=1, kernel_size=1, stride=1, padding=0),\n\t\t\tnn.Sigmoid()\n\t\t)\n\t\t
\n\tdef forward(self, tensorInput):\n\t\t# HED expects Caffe-style BGR input with the ImageNet channel means (in [0, 255]) subtracted\n\t\ttensorBlue = (tensorInput[:, 2:3, :, :] * 255.0) - 104.00698793\n\t\ttensorGreen = (tensorInput[:, 1:2, :, :] * 255.0) - 116.66876762\n\t\ttensorRed = (tensorInput[:, 0:1, :, :] * 255.0) - 122.67891434\n\t\t
\n\t\ttensorInput = torch.cat([ tensorBlue, tensorGreen, tensorRed ], 1)\n\t\t\n\t\ttensorVggOne = self.moduleVggOne(tensorInput)\n\t\ttensorVggTwo = self.moduleVggTwo(tensorVggOne)\n\t\ttensorVggThr = self.moduleVggThr(tensorVggTwo)\n\t\ttensorVggFou = self.moduleVggFou(tensorVggThr)\n\t\ttensorVggFiv = self.moduleVggFiv(tensorVggFou)\n\t\t
\n\t\ttensorScoreOne = self.moduleScoreOne(tensorVggOne)\n\t\ttensorScoreTwo = self.moduleScoreTwo(tensorVggTwo)\n\t\ttensorScoreThr = self.moduleScoreThr(tensorVggThr)\n\t\ttensorScoreFou = self.moduleScoreFou(tensorVggFou)\n\t\ttensorScoreFiv = self.moduleScoreFiv(tensorVggFiv)\n\t\t
\n\t\t# upsample each side output back to the input resolution before fusing\n\t\ttensorScoreOne = nn.functional.interpolate(input=tensorScoreOne, size=(tensorInput.size(2), tensorInput.size(3)), mode='bilinear', align_corners=False)\n\t\ttensorScoreTwo = nn.functional.interpolate(input=tensorScoreTwo, size=(tensorInput.size(2), tensorInput.size(3)), mode='bilinear', align_corners=False)\n\t\ttensorScoreThr = nn.functional.interpolate(input=tensorScoreThr, size=(tensorInput.size(2), tensorInput.size(3)), mode='bilinear', align_corners=False)\n\t\ttensorScoreFou = nn.functional.interpolate(input=tensorScoreFou, size=(tensorInput.size(2), tensorInput.size(3)), mode='bilinear', align_corners=False)\n\t\ttensorScoreFiv = nn.functional.interpolate(input=tensorScoreFiv, size=(tensorInput.size(2), tensorInput.size(3)), mode='bilinear', align_corners=False)\n\t\t
\n\t\treturn self.moduleCombine(torch.cat([ tensorScoreOne, tensorScoreTwo, tensorScoreThr, tensorScoreFou, tensorScoreFiv ], 1))\n
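\n# A smoke-test sketch added for illustration (not part of the training code):\n# build a small generator/discriminator pair with assumed hyper-parameters and\n# push a random tensor through them on the CPU to check output shapes.
\nif __name__ == '__main__':\n\tG = define_G(input_nc=3, output_nc=3, ngf=64, netG='resnet_6blocks', norm='instance')\n\tD = define_D(input_nc=3, ndf=64, netD='basic', norm='instance')\n\tx = torch.randn(1, 3, 64, 64)\n\tprint('G output:', G(x).shape)  # expected: torch.Size([1, 3, 64, 64])\n\tprint('D output:', D(x).shape)  # a patch map, here torch.Size([1, 1, 6, 6])\n"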
  },
  {
    "path": "models/networks_basic.py",
    "content": "\nfrom __future__ import absolute_import\n\nimport sys\nimport torch\nimport torch.nn as nn\nimport torch.nn.init as init\nfrom torch.autograd import Variable\nimport numpy as np\nfrom pdb import set_trace as st\nfrom skimage import color\nfrom IPython import embed\nfrom . import pretrained_networks as pn\n\nfrom util import util\n\ndef spatial_average(in_tens, keepdim=True):\n    return in_tens.mean([2,3],keepdim=keepdim)\n\ndef upsample(in_tens, out_H=64): # assumes scale factor is same for H and W\n    in_H = in_tens.shape[2]\n    scale_factor = 1.*out_H/in_H\n\n    return nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)(in_tens)\n\n# Learned perceptual metric\nclass PNetLin(nn.Module):\n    def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, spatial=False, version='0.1', lpips=True):\n        super(PNetLin, self).__init__()\n\n        self.pnet_type = pnet_type\n        self.pnet_tune = pnet_tune\n        self.pnet_rand = pnet_rand\n        self.spatial = spatial\n        self.lpips = lpips\n        self.version = version\n        self.scaling_layer = ScalingLayer()\n\n        if(self.pnet_type in ['vgg','vgg16']):\n            net_type = pn.vgg16\n            self.chns = [64,128,256,512,512]\n        elif(self.pnet_type=='alex'):\n            net_type = pn.alexnet\n            self.chns = [64,192,384,256,256]\n        elif(self.pnet_type=='squeeze'):\n            net_type = pn.squeezenet\n            self.chns = [64,128,256,384,384,512,512]\n        self.L = len(self.chns)\n\n        self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)\n\n        if(lpips):\n            self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)\n            self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)\n            self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)\n            self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)\n            self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)\n            self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4]\n            if(self.pnet_type=='squeeze'): # 7 layers for squeezenet\n                self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)\n                self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)\n                self.lins+=[self.lin5,self.lin6]\n\n    def forward(self, in0, in1, retPerLayer=False):\n        # v0.0 - original release had a bug, where input was not scaled\n        in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1)\n        outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)\n        feats0, feats1, diffs = {}, {}, {}\n\n        for kk in range(self.L):\n            feats0[kk], feats1[kk] = util.normalize_tensor(outs0[kk]), util.normalize_tensor(outs1[kk])\n            diffs[kk] = (feats0[kk]-feats1[kk])**2\n\n        if(self.lpips):\n            if(self.spatial):\n                res = [upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2]) for kk in range(self.L)]\n            else:\n                res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)]\n        else:\n            if(self.spatial):\n                res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_H=in0.shape[2]) for kk in range(self.L)]\n            else:\n                res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), 
keepdim=True) for kk in range(self.L)]\n\n        val = res[0]\n        for l in range(1,self.L):\n            val += res[l]\n        \n        if(retPerLayer):\n            return (val, res)\n        else:\n            return val\n\nclass ScalingLayer(nn.Module):\n    def __init__(self):\n        super(ScalingLayer, self).__init__()\n        self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188])[None,:,None,None])\n        self.register_buffer('scale', torch.Tensor([.458,.448,.450])[None,:,None,None])\n\n    def forward(self, inp):\n        return (inp - self.shift.to(inp.device)) / self.scale.to(inp.device)\n\n\nclass NetLinLayer(nn.Module):\n    ''' A single linear layer which does a 1x1 conv '''\n    def __init__(self, chn_in, chn_out=1, use_dropout=False):\n        super(NetLinLayer, self).__init__()\n\n        layers = [nn.Dropout(),] if(use_dropout) else []\n        layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),]\n        self.model = nn.Sequential(*layers)\n\n\nclass Dist2LogitLayer(nn.Module):\n    ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) '''\n    def __init__(self, chn_mid=32, use_sigmoid=True):\n        super(Dist2LogitLayer, self).__init__()\n\n        layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),]\n        layers += [nn.LeakyReLU(0.2,True),]\n        layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),]\n        layers += [nn.LeakyReLU(0.2,True),]\n        layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),]\n        if(use_sigmoid):\n            layers += [nn.Sigmoid(),]\n        self.model = nn.Sequential(*layers)\n\n    def forward(self,d0,d1,eps=0.1):\n        return self.model.forward(torch.cat((d0,d1,d0-d1,d0/(d1+eps),d1/(d0+eps)),dim=1))\n\nclass BCERankingLoss(nn.Module):\n    def __init__(self, chn_mid=32):\n        super(BCERankingLoss, self).__init__()\n        self.net = Dist2LogitLayer(chn_mid=chn_mid)\n        # self.parameters = list(self.net.parameters())\n        self.loss = torch.nn.BCELoss()\n\n    def forward(self, d0, d1, judge):\n        per = (judge+1.)/2.\n        self.logit = self.net.forward(d0,d1)\n        return self.loss(self.logit, per)\n\n# L2, DSSIM metrics\nclass FakeNet(nn.Module):\n    def __init__(self, use_gpu=True, colorspace='Lab'):\n        super(FakeNet, self).__init__()\n        self.use_gpu = use_gpu\n        self.colorspace=colorspace\n\nclass L2(FakeNet):\n\n    def forward(self, in0, in1, retPerLayer=None):\n        assert(in0.size()[0]==1) # currently only supports batchSize 1\n\n        if(self.colorspace=='RGB'):\n            (N,C,X,Y) = in0.size()\n            value = torch.mean(torch.mean(torch.mean((in0-in1)**2,dim=1).view(N,1,X,Y),dim=2).view(N,1,1,Y),dim=3).view(N)\n            return value\n        elif(self.colorspace=='Lab'):\n            value = util.l2(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)), \n                util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')\n            ret_var = Variable( torch.Tensor((value,) ) )\n            if(self.use_gpu):\n                ret_var = ret_var.cuda()\n            return ret_var\n\nclass DSSIM(FakeNet):\n\n    def forward(self, in0, in1, retPerLayer=None):\n        assert(in0.size()[0]==1) # currently only supports batchSize 1\n\n        if(self.colorspace=='RGB'):\n            value = util.dssim(1.*util.tensor2im(in0.data), 1.*util.tensor2im(in1.data), 
range=255.).astype('float')\n        elif(self.colorspace=='Lab'):\n            value = util.dssim(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)), \n                util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')\n        ret_var = Variable(torch.Tensor((value,)))\n        if(self.use_gpu):\n            ret_var = ret_var.cuda()\n        return ret_var\n\ndef print_network(net):\n    num_params = 0\n    for param in net.parameters():\n        num_params += param.numel()\n    print('Network',net)\n    print('Total number of parameters: %d' % num_params)\n
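\nif __name__ == '__main__':\n    # Minimal usage sketch (illustrative only): compare two random images with the\n    # learned metric. The NetLinLayer weights here are freshly initialized rather\n    # than loaded from a calibrated checkpoint, so the printed value is only a\n    # wiring check, not a meaningful LPIPS score. Inputs are expected in [-1,1].\n    # Run from the repo root, e.g. `python -m models.networks_basic`.\n    model = PNetLin(pnet_type='vgg', spatial=False, lpips=True)\n    ex0 = torch.rand(1, 3, 64, 64) * 2 - 1\n    ex1 = torch.rand(1, 3, 64, 64) * 2 - 1\n    with torch.no_grad():\n        dist = model(ex0, ex1)\n    print('distance: %.4f' % dist.item())\n"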
  },
  {
    "path": "models/pretrained_networks.py",
    "content": "from collections import namedtuple\nimport torch\nfrom torchvision import models\nfrom IPython import embed\n\nclass squeezenet(torch.nn.Module):\n    def __init__(self, requires_grad=False, pretrained=True):\n        super(squeezenet, self).__init__()\n        pretrained_features = models.squeezenet1_1(pretrained=pretrained).features\n        self.slice1 = torch.nn.Sequential()\n        self.slice2 = torch.nn.Sequential()\n        self.slice3 = torch.nn.Sequential()\n        self.slice4 = torch.nn.Sequential()\n        self.slice5 = torch.nn.Sequential()\n        self.slice6 = torch.nn.Sequential()\n        self.slice7 = torch.nn.Sequential()\n        self.N_slices = 7\n        for x in range(2):\n            self.slice1.add_module(str(x), pretrained_features[x])\n        for x in range(2,5):\n            self.slice2.add_module(str(x), pretrained_features[x])\n        for x in range(5, 8):\n            self.slice3.add_module(str(x), pretrained_features[x])\n        for x in range(8, 10):\n            self.slice4.add_module(str(x), pretrained_features[x])\n        for x in range(10, 11):\n            self.slice5.add_module(str(x), pretrained_features[x])\n        for x in range(11, 12):\n            self.slice6.add_module(str(x), pretrained_features[x])\n        for x in range(12, 13):\n            self.slice7.add_module(str(x), pretrained_features[x])\n        if not requires_grad:\n            for param in self.parameters():\n                param.requires_grad = False\n\n    def forward(self, X):\n        h = self.slice1(X)\n        h_relu1 = h\n        h = self.slice2(h)\n        h_relu2 = h\n        h = self.slice3(h)\n        h_relu3 = h\n        h = self.slice4(h)\n        h_relu4 = h\n        h = self.slice5(h)\n        h_relu5 = h\n        h = self.slice6(h)\n        h_relu6 = h\n        h = self.slice7(h)\n        h_relu7 = h\n        vgg_outputs = namedtuple(\"SqueezeOutputs\", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7'])\n        out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7)\n\n        return out\n\n\nclass alexnet(torch.nn.Module):\n    def __init__(self, requires_grad=False, pretrained=True):\n        super(alexnet, self).__init__()\n        alexnet_pretrained_features = models.alexnet(pretrained=pretrained).features\n        self.slice1 = torch.nn.Sequential()\n        self.slice2 = torch.nn.Sequential()\n        self.slice3 = torch.nn.Sequential()\n        self.slice4 = torch.nn.Sequential()\n        self.slice5 = torch.nn.Sequential()\n        self.N_slices = 5\n        for x in range(2):\n            self.slice1.add_module(str(x), alexnet_pretrained_features[x])\n        for x in range(2, 5):\n            self.slice2.add_module(str(x), alexnet_pretrained_features[x])\n        for x in range(5, 8):\n            self.slice3.add_module(str(x), alexnet_pretrained_features[x])\n        for x in range(8, 10):\n            self.slice4.add_module(str(x), alexnet_pretrained_features[x])\n        for x in range(10, 12):\n            self.slice5.add_module(str(x), alexnet_pretrained_features[x])\n        if not requires_grad:\n            for param in self.parameters():\n                param.requires_grad = False\n\n    def forward(self, X):\n        h = self.slice1(X)\n        h_relu1 = h\n        h = self.slice2(h)\n        h_relu2 = h\n        h = self.slice3(h)\n        h_relu3 = h\n        h = self.slice4(h)\n        h_relu4 = h\n        h = self.slice5(h)\n        h_relu5 = h\n        alexnet_outputs = 
namedtuple(\"AlexnetOutputs\", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5'])\n        out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)\n\n        return out\n\nclass vgg16(torch.nn.Module):\n    def __init__(self, requires_grad=False, pretrained=True):\n        super(vgg16, self).__init__()\n        vgg_pretrained_features = models.vgg16(pretrained=pretrained).features\n        self.slice1 = torch.nn.Sequential()\n        self.slice2 = torch.nn.Sequential()\n        self.slice3 = torch.nn.Sequential()\n        self.slice4 = torch.nn.Sequential()\n        self.slice5 = torch.nn.Sequential()\n        self.N_slices = 5\n        for x in range(4):\n            self.slice1.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(4, 9):\n            self.slice2.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(9, 16):\n            self.slice3.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(16, 23):\n            self.slice4.add_module(str(x), vgg_pretrained_features[x])\n        for x in range(23, 30):\n            self.slice5.add_module(str(x), vgg_pretrained_features[x])\n        if not requires_grad:\n            for param in self.parameters():\n                param.requires_grad = False\n\n    def forward(self, X):\n        h = self.slice1(X)\n        h_relu1_2 = h\n        h = self.slice2(h)\n        h_relu2_2 = h\n        h = self.slice3(h)\n        h_relu3_3 = h\n        h = self.slice4(h)\n        h_relu4_3 = h\n        h = self.slice5(h)\n        h_relu5_3 = h\n        vgg_outputs = namedtuple(\"VggOutputs\", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])\n        out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)\n\n        return out\n\n\n\nclass resnet(torch.nn.Module):\n    def __init__(self, requires_grad=False, pretrained=True, num=18):\n        super(resnet, self).__init__()\n        if(num==18):\n            self.net = models.resnet18(pretrained=pretrained)\n        elif(num==34):\n            self.net = models.resnet34(pretrained=pretrained)\n        elif(num==50):\n            self.net = models.resnet50(pretrained=pretrained)\n        elif(num==101):\n            self.net = models.resnet101(pretrained=pretrained)\n        elif(num==152):\n            self.net = models.resnet152(pretrained=pretrained)\n        self.N_slices = 5\n\n        self.conv1 = self.net.conv1\n        self.bn1 = self.net.bn1\n        self.relu = self.net.relu\n        self.maxpool = self.net.maxpool\n        self.layer1 = self.net.layer1\n        self.layer2 = self.net.layer2\n        self.layer3 = self.net.layer3\n        self.layer4 = self.net.layer4\n\n    def forward(self, X):\n        h = self.conv1(X)\n        h = self.bn1(h)\n        h = self.relu(h)\n        h_relu1 = h\n        h = self.maxpool(h)\n        h = self.layer1(h)\n        h_conv2 = h\n        h = self.layer2(h)\n        h_conv3 = h\n        h = self.layer3(h)\n        h_conv4 = h\n        h = self.layer4(h)\n        h_conv5 = h\n\n        outputs = namedtuple(\"Outputs\", ['relu1','conv2','conv3','conv4','conv5'])\n        out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)\n\n        return out\n"
  },
  {
    "path": "models/test_3styles_model.py",
    "content": "from .base_model import BaseModel\nfrom . import networks\nimport torch\n\nclass Test3StylesModel(BaseModel):\n    \"\"\" This TesteModel can be used to generate CycleGAN results for only one direction.\n    This model will automatically set '--dataset_mode single', which only loads the images from one collection.\n\n    See the test instruction for more details.\n    \"\"\"\n    @staticmethod\n    def modify_commandline_options(parser, is_train=True):\n        \"\"\"Add new dataset-specific options, and rewrite default values for existing options.\n\n        Parameters:\n            parser          -- original option parser\n            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.\n\n        Returns:\n            the modified parser.\n\n        The model can only be used during test time. It requires '--dataset_mode single'.\n        You need to specify the network using the option '--model_suffix'.\n        \"\"\"\n        assert not is_train, 'TestModel cannot be used during training time'\n        parser.set_defaults(dataset_mode='single')\n        parser.add_argument('--style_control', type=int, default=0, help='not set style_vec in dataset')\n        parser.add_argument('--netga', type=str, default='resnet_style2_9blocks', help='net arch for netG_A')\n        parser.add_argument('--model0_res', type=int, default=0, help='number of resblocks in model0')\n        parser.add_argument('--model1_res', type=int, default=0, help='number of resblocks in model1 (after insert style, before 2 column merge)')\n\n        return parser\n\n    def __init__(self, opt):\n        assert(not opt.isTrain)\n        BaseModel.__init__(self, opt)\n        # specify the training losses you want to print out. The training/test scripts  will call <BaseModel.get_current_losses>\n        self.loss_names = []\n        # specify the images you want to save/display. The training/test scripts  will call <BaseModel.get_current_visuals>\n        self.visual_names = ['real', 'fake1', 'fake2', 'fake3']\n        # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>\n        self.model_names = ['G_A']  # only generator is needed.\n        print(opt.netga)\n        print('model0_res', opt.model0_res)\n        print('model1_res', opt.model1_res)\n        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netga, opt.norm,\n                                    not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.model0_res, opt.model1_res)\n        \n        setattr(self, 'netG_A', self.netG)  # store netG in self.\n\n    def set_input(self, input):\n        self.real = input['A'].to(self.device)\n        self.image_paths = input['A_paths']\n        self.style1 = torch.Tensor([1, 0, 0]).view(3, 1, 1).repeat(1, 1, 128, 128).to(self.device)\n        self.style2 = torch.Tensor([0, 1, 0]).view(3, 1, 1).repeat(1, 1, 128, 128).to(self.device)\n        self.style3 = torch.Tensor([0, 0, 1]).view(3, 1, 1).repeat(1, 1, 128, 128).to(self.device)\n\n    def forward(self):\n        \"\"\"Run forward pass.\"\"\"\n        self.fake1 = self.netG(self.real, self.style1)\n        self.fake2 = self.netG(self.real, self.style2)\n        self.fake3 = self.netG(self.real, self.style3)\n\n    def optimize_parameters(self):\n        \"\"\"No optimization for test model.\"\"\"\n        pass\n"
  },
  {
    "path": "models/test_model.py",
    "content": "from .base_model import BaseModel\nfrom . import networks\nimport torch\n\nclass TestModel(BaseModel):\n    \"\"\" This TesteModel can be used to generate CycleGAN results for only one direction.\n    This model will automatically set '--dataset_mode single', which only loads the images from one collection.\n\n    See the test instruction for more details.\n    \"\"\"\n    @staticmethod\n    def modify_commandline_options(parser, is_train=True):\n        \"\"\"Add new dataset-specific options, and rewrite default values for existing options.\n\n        Parameters:\n            parser          -- original option parser\n            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.\n\n        Returns:\n            the modified parser.\n\n        The model can only be used during test time. It requires '--dataset_mode single'.\n        You need to specify the network using the option '--model_suffix'.\n        \"\"\"\n        assert not is_train, 'TestModel cannot be used during training time'\n        parser.set_defaults(dataset_mode='single')\n        parser.add_argument('--model_suffix', type=str, default='', help='In checkpoints_dir, [epoch]_net_G[model_suffix].pth will be loaded as the generator.')\n        parser.add_argument('--style_control', type=int, default=1, help='use style_control')\n        parser.add_argument('--sfeature_mode', type=str, default='vgg19_softmax', help='vgg19 softmax as feature')\n        parser.add_argument('--sinput', type=str, default='sind', help='use which one for style input')\n        parser.add_argument('--sind', type=int, default=0, help='one hot for sfeature')\n        parser.add_argument('--svec', type=str, default='1,0,0', help='3-dim vec')\n        parser.add_argument('--simg', type=str, default='Yann_Legendre-053', help='drawing example for style')\n        parser.add_argument('--netga', type=str, default='resnet_style2_9blocks', help='net arch for netG_A')\n        parser.add_argument('--model0_res', type=int, default=0, help='number of resblocks in model0')\n        parser.add_argument('--model1_res', type=int, default=0, help='number of resblocks in model1 (after insert style, before 2 column merge)')\n\n        return parser\n\n    def __init__(self, opt):\n        \"\"\"Initialize the pix2pix class.\n\n        Parameters:\n            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions\n        \"\"\"\n        assert(not opt.isTrain)\n        BaseModel.__init__(self, opt)\n        # specify the training losses you want to print out. The training/test scripts  will call <BaseModel.get_current_losses>\n        self.loss_names = []\n        # specify the images you want to save/display. The training/test scripts  will call <BaseModel.get_current_visuals>\n        self.visual_names = ['real', 'fake', 'rec']\n        # specify the models you want to save to the disk. 
The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>\n        self.model_names = ['G' + opt.model_suffix, 'G_B']  # only generator is needed.\n        if not self.opt.style_control:\n            self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG,\n                                      opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)\n        else:\n            print(opt.netga)\n            print('model0_res', opt.model0_res)\n            print('model1_res', opt.model1_res)\n            self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netga, opt.norm,\n                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.model0_res, opt.model1_res)\n        \n        self.netGB = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG,\n                                      opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)\n        # assigns the model to self.netG_[suffix] so that it can be loaded\n        # please see <BaseModel.load_networks>\n        setattr(self, 'netG' + opt.model_suffix, self.netG)  # store netG in self.\n        setattr(self, 'netG_B', self.netGB)  # store netGB in self.\n\n    def set_input(self, input):\n        \"\"\"Unpack input data from the dataloader and perform necessary pre-processing steps.\n\n        Parameters:\n            input: a dictionary that contains the data itself and its metadata information.\n\n        We need to use 'single_dataset' dataset mode. It only load images from one domain.\n        \"\"\"\n        self.real = input['A'].to(self.device)\n        self.image_paths = input['A_paths']\n        if self.opt.style_control:\n            self.style = input['B_style']\n\n    def forward(self):\n        \"\"\"Run forward pass.\"\"\"\n        if not self.opt.style_control:\n            self.fake = self.netG(self.real)  # G(real)\n        else:\n            print(torch.mean(self.style,(2,3)),'style_control')\n            self.fake = self.netG(self.real, self.style)\n        self.rec = self.netG_B(self.fake)\n\n    def optimize_parameters(self):\n        \"\"\"No optimization for test model.\"\"\"\n        pass\n"
  },
  {
    "path": "options/__init__.py",
    "content": "\"\"\"This package options includes option modules: training options, test options, and basic options (used in both training and test).\"\"\"\n"
  },
  {
    "path": "options/base_options.py",
    "content": "import argparse\nimport os\nfrom util import util\nimport torch\nimport models\nimport data\n\n\nclass BaseOptions():\n    \"\"\"This class defines options used during both training and test time.\n\n    It also implements several helper functions such as parsing, printing, and saving the options.\n    It also gathers additional options defined in <modify_commandline_options> functions in both dataset class and model class.\n    \"\"\"\n\n    def __init__(self):\n        \"\"\"Reset the class; indicates the class hasn't been initailized\"\"\"\n        self.initialized = False\n\n    def initialize(self, parser):\n        \"\"\"Define the common options that are used in both training and test.\"\"\"\n        # basic parameters\n        parser.add_argument('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')\n        parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')\n        parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0  0,1,2, 0,2. use -1 for CPU')\n        parser.add_argument('--gpu_ids_p', type=str, default='0', help='gpu ids for pretrained auxiliary models: e.g. 0  0,1,2, 0,2. use -1 for CPU')\n        parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')\n        # model parameters\n        parser.add_argument('--model', type=str, default='cycle_gan', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]')\n        parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')\n        parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')\n        parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer')\n        parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer')\n        parser.add_argument('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')\n        parser.add_argument('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]')\n        parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')\n        parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]')\n        parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]')\n        parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')\n        parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator')\n        # dataset parameters\n        parser.add_argument('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. 
[unaligned | aligned | single | colorization]')\n        parser.add_argument('--direction', type=str, default='AtoB', help='AtoB or BtoA')\n        parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')\n        parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')\n        parser.add_argument('--batch_size', type=int, default=1, help='input batch size')\n        parser.add_argument('--load_size', type=int, default=286, help='scale images to this size')\n        parser.add_argument('--crop_size', type=int, default=256, help='then crop to this size')\n        parser.add_argument('--max_dataset_size', type=int, default=float(\"inf\"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')\n        parser.add_argument('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]')\n        parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')\n        parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML')\n        # additional parameters\n        parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')\n        parser.add_argument('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]')\n        parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')\n        parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')\n        self.initialized = True\n        return parser\n\n    def gather_options(self):\n        \"\"\"Initialize our parser with basic options(only once).\n        Add additional model-specific and dataset-specific options.\n        These options are defined in the <modify_commandline_options> function\n        in model and dataset classes.\n        \"\"\"\n        if not self.initialized:  # check if it has been initialized\n            parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)\n            parser = self.initialize(parser)\n\n        # get the basic options\n        opt, _ = parser.parse_known_args()\n\n        # modify model-related parser options\n        model_name = opt.model\n        model_option_setter = models.get_option_setter(model_name)\n        parser = model_option_setter(parser, self.isTrain)\n        opt, _ = parser.parse_known_args()  # parse again with new defaults\n\n        # modify dataset-related parser options\n        dataset_name = opt.dataset_mode\n        dataset_option_setter = data.get_option_setter(dataset_name)\n        parser = dataset_option_setter(parser, self.isTrain)\n\n        # save and return the parser\n        self.parser = parser\n        return parser.parse_args()\n\n    def print_options(self, opt):\n        \"\"\"Print and save options\n\n        It will print both current options and default values(if different).\n        It will save options into a text file / [checkpoints_dir] / opt.txt\n        
\"\"\"\n        message = ''\n        message += '----------------- Options ---------------\\n'\n        for k, v in sorted(vars(opt).items()):\n            comment = ''\n            default = self.parser.get_default(k)\n            if v != default:\n                comment = '\\t[default: %s]' % str(default)\n            message += '{:>25}: {:<30}{}\\n'.format(str(k), str(v), comment)\n        message += '----------------- End -------------------'\n        print(message)\n\n        # save to the disk\n        expr_dir = os.path.join(opt.checkpoints_dir, opt.name)\n        util.mkdirs(expr_dir)\n        file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))\n        with open(file_name, 'wt') as opt_file:\n            opt_file.write(message)\n            opt_file.write('\\n')\n\n    def parse(self):\n        \"\"\"Parse our options, create checkpoints directory suffix, and set up gpu device.\"\"\"\n        opt = self.gather_options()\n        opt.isTrain = self.isTrain   # train or test\n\n        # process opt.suffix\n        if opt.suffix:\n            suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''\n            opt.name = opt.name + suffix\n\n        self.print_options(opt)\n\n        # set gpu ids\n        str_ids = opt.gpu_ids.split(',')\n        opt.gpu_ids = []\n        for str_id in str_ids:\n            id = int(str_id)\n            if id >= 0:\n                opt.gpu_ids.append(id)\n        if len(opt.gpu_ids) > 0:\n            torch.cuda.set_device(opt.gpu_ids[0])\n        \n        # set gpu ids\n        str_ids = opt.gpu_ids_p.split(',')\n        opt.gpu_ids_p = []\n        for str_id in str_ids:\n            id = int(str_id)\n            if id >= 0:\n                opt.gpu_ids_p.append(id)\n\n        self.opt = opt\n        return self.opt\n"
  },
  {
    "path": "options/test_options.py",
    "content": "from .base_options import BaseOptions\n\n\nclass TestOptions(BaseOptions):\n    \"\"\"This class includes test options.\n\n    It also includes shared options defined in BaseOptions.\n    \"\"\"\n\n    def initialize(self, parser):\n        parser = BaseOptions.initialize(self, parser)  # define shared options\n        parser.add_argument('--ntest', type=int, default=float(\"inf\"), help='# of test examples.')\n        parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')\n        parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')\n        parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')\n        # Dropout and Batchnorm has different behavioir during training and test.\n        parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')\n        parser.add_argument('--num_test', type=int, default=50, help='how many test images to run')\n        parser.add_argument('--imagefolder', type=str, default='images', help='subfolder to save images')\n        # rewrite devalue values\n        parser.set_defaults(model='test')\n        # To avoid cropping, the load_size should be the same as crop_size\n        parser.set_defaults(load_size=parser.get_default('crop_size'))\n        self.isTrain = False\n        return parser\n"
  },
  {
    "path": "options/train_options.py",
    "content": "from .base_options import BaseOptions\n\n\nclass TrainOptions(BaseOptions):\n    \"\"\"This class includes training options.\n\n    It also includes shared options defined in BaseOptions.\n    \"\"\"\n\n    def initialize(self, parser):\n        parser = BaseOptions.initialize(self, parser)\n        # visdom and HTML visualization parameters\n        parser.add_argument('--display_freq', type=int, default=400, help='frequency of showing training results on screen')\n        parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.')\n        parser.add_argument('--display_id', type=int, default=1, help='window id of the web display')\n        parser.add_argument('--display_server', type=str, default=\"http://localhost\", help='visdom server of the web display')\n        parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is \"main\")')\n        parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')\n        parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')\n        parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')\n        parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')\n        # network saving and loading parameters\n        parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')\n        parser.add_argument('--save_epoch_freq', type=int, default=10, help='frequency of saving checkpoints at the end of epochs')\n        parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration')\n        parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')\n        parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')\n        parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')\n        # training parameters\n        parser.add_argument('--n_epochs', type=int, default=200, help='the end epoch count')\n        parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate')\n        parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')\n        parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')\n        parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')\n        parser.add_argument('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')\n        parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')\n        parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy. 
[linear | step | plateau | cosine]')\n        parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')\n\n        self.isTrain = True\n        return parser\n"
  },
  {
    "path": "portrait_drawing_resources.md",
    "content": "\n- Charles Burns (style1): https://www.pinterest.co.uk/johns59/charles-burns-fan-club/\n- Yann Legendre (style1): http://www.yannlegendre.com/project/portraits/\n- Kathryn Rathke (style2):\nhttps://www.kathrynrathke.com/\n- Vectorportal (style3): https://www.pinterest.co.uk/vectorportal/celebrity-vector-illustrations/"
  },
  {
    "path": "preprocess/face_align_512.m",
    "content": "function [trans_img]=face_align_512(impath,facial5point,savedir)\n% align the faces by similarity transformation.\n% using 5 facial landmarks: 2 eyes, nose, 2 mouth corners.\n%   impath: path to image\n%   facial5point: 5x2 size, 5 facial landmark positions, detected by MTCNN\n%   savedir: savedir for cropped image and transformed facial landmarks\n\n%% alignment settings\nimgSize = [512,512];\ncoord5point = [180,230;\n    300,230;\n    240,301;\n    186,365.6;\n    294,365.6];%480x480\ncoord5point = (coord5point-240)/560 * 512 + 256;\n\n%% face alignment\n\n% load and align, resize image to imgSize\nimg      = imread(impath);\nfacial5point = double(facial5point);\ntransf   = cp2tform(facial5point, coord5point, 'similarity');\ntrans_img  = imtransform(img, transf, 'XData', [1 imgSize(2)],...\n                                    'YData', [1 imgSize(1)],...\n                                    'Size', imgSize,...\n                                    'FillValues', [255;255;255]);\ntrans_facial5point = round(tformfwd(transf,facial5point));\n\n\n%% save results\nif ~exist(savedir,'dir')\n    mkdir(savedir)\nend\n[~,name,~] = fileparts(impath);\n% save trans_img\nimwrite(trans_img, fullfile(savedir,[name,'_resized.png']));\nfprintf('write aligned image to %s\\n',fullfile(savedir,[name,'_resized.png']));\n\n%% show results\nimshow(trans_img); hold on;\nplot(trans_facial5point(:,1),trans_facial5point(:,2),'b');\nplot(trans_facial5point(:,1),trans_facial5point(:,2),'r+');\n\nend"
  },
  {
    "path": "preprocess/readme.md",
    "content": "## Preprocessing steps\n\nDuring training, face photos and drawings are aligned and have nose,eyes,lips mask detected. \n\nDuring test, the alignment step is optional and the masks are not needed.\n\n### 1. Align, resize, crop images to 512x512\n\nAll training and testing images in our model are aligned using facial landmarks. And landmarks after alignment are needed in our code.\n\n- First, 5 facial landmark for a face photo need to be detected (we detect using [MTCNN](https://github.com/kpzhang93/MTCNN_face_detection_alignment)(MTCNNv1)).\n\n- Then, we provide a matlab function in `face_align_512.m` to align, resize and crop face photos (and corresponding drawings) to 512x512. Call this function in MATLAB to align the image to 512x512.\nFor example, for `ia_selfie_10515.jpg` in `example` dir, 5 detected facial landmark is saved in `example/ia_selfie_10515_facial5point.mat`. Call following in MATLAB:\n```bash\nload('example/ia_selfie_10515_facial5point.mat');\n[trans_img]=face_align_512('example/ia_selfie_10515.jpg',facial5point,'example');\n```\n\nThis will align the image and output aligned image  in `example` folder.\nSee `face_align_512.m` for more instructions.\n\n\n### 2. Prepare nose,eyes,lips masks\n\nIn our work, we use the face parsing network in https://github.com/cientgu/Mask_Guided_Portrait_Editing to get nose,eyes,lips regions and then dilate the regions to make them cover these facial features (some examples are shown in `example` folder).\n\n- The background masks need to be copied to `datasets/portrait_drawing/train/A(B)(_eyes)(_lips)`, and has the **same filename** with aligned face photos.  \n"
  },
  {
    "path": "readme.md",
    "content": "\n# Unpaired Portrait Drawing Generation via Asymmetric Cycle Mapping\n\nWe provide PyTorch implementations for our CVPR 2020 paper \"Unpaired Portrait Drawing Generation via Asymmetric Cycle Mapping\". [paper](https://openaccess.thecvf.com/content_CVPR_2020/papers/Yi_Unpaired_Portrait_Drawing_Generation_via_Asymmetric_Cycle_Mapping_CVPR_2020_paper.pdf), [suppl](https://openaccess.thecvf.com/content_CVPR_2020/supplemental/Yi_Unpaired_Portrait_Drawing_CVPR_2020_supplemental.pdf).\n\nThis project generates multi-style artistic portrait drawings from face photos using a GAN-based model.\n\n[[Jittor implementation]](https://github.com/yiranran/Unpaired-Portrait-Drawing-Jittor)\n\n\n## Our Proposed Framework\n \n<img src = 'imgs/architecture.jpg'>\n\n## Sample Results\nFrom left to right: input, output(style1), output(style2), output(style3)\n<img src = 'imgs/results.jpg'>\n\n## Prerequisites\n- Linux or macOS\n- Python 3\n- CPU or NVIDIA GPU + CUDA CuDNN\n\n\n## Installation\n- To install the dependencies, run\n```bash\npip install -r requirements.txt\n```\n\n## Colab\nA colab demo is [here](https://colab.research.google.com/drive/1U1fPXD1JukuKPOrhGMX1iaJC-d8_RUYr).\n\n## Test steps (apply a pretrained model)\n\n- 1. Download pre-trained models from [BaiduYun](https://pan.baidu.com/s/1_9Fy8mRpTQp6AvqhHsfQAQ)(extract code:c9h7) or [GoogleDrive](https://drive.google.com/drive/folders/1FzOcdlMYhvK_nyLCe8wnwotMphhIoiYt?usp=sharing) and rename the folder to `checkpoints`.\n\n- 2. Test for example photos: generate artistic portrait drawings for example photos in the folder `./examples` using\n``` bash\n# with GPU\npython test_seq_style.py\n# without GPU\npython test_seq_style.py --gpu -1\n```\nThe test results will be saved to a html file here: `./results/pretrained/test_200/index3styles.html`.\nThe result images are saved in `./results/pretrained/test_200/images3styles`,\nwhere `real`, `fake1`, `fake2`, `fake3` correspond to input face photo, style1 drawing, style2 drawing, style3 drawing respectively.\n\n<img src = 'imgs/how_to_crop.jpg'>\n\n- 3. To test on your own photos: First use an image editor to crop the face region of your photo (or use an optional preprocess [here](preprocess/readme.md)). Then specify the folder that contains test photos using option `--dataroot`, specify save folder name using option `--savefolder` and run the above command again:\n\n``` bash\n# with GPU\npython test_seq_style.py --dataroot [input_folder] --savefolder [save_folder_name]\n# without GPU\npython test_seq_style.py --gpu -1 --dataroot [input_folder] --savefolder [save_folder_name]\n# E.g.\npython test_seq_style.py --gpu -1 --dataroot ./imgs/test1 --savefolder 3styles_test1\n```\nThe test results will be saved to a html file here: `./results/pretrained/test_200/index[save_folder_name].html`.\nThe result images are saved in `./results/pretrained/test_200/images[save_folder_name]`.\nAn example html screenshot is shown below:\n<img src = 'imgs/result_html.jpg'>\n\nYou can contact email yr16@mails.tsinghua.edu.cn for any questions.\n\n## Train steps\n\n- 1. Prepare for the dataset: 1) download face photos and portrait drawings from internet (e.g. [resources](portrait_drawing_resources.md)). 2) align, crop photos and drawings & 3) prepare nose, eyes, lips masks according to [preprocess instructions](preprocess/readme.md). 
\nA subset of our training set is [here](https://drive.google.com/file/d/1OSMOR3-uhGkoPwPFRNychJSNrpSak_23/view?usp=sharing).\n\n- 3. Train our model\n``` bash\nsh ./scripts/train.sh\n```\nModels are saved in the folder `checkpoints/formal` (following the `--name formal` option in the script).\n\n\n## Citation\nIf you use this code for your research, please cite our paper.\n\n```\n@inproceedings{YiLLR20,\n  title     = {Unpaired Portrait Drawing Generation via Asymmetric Cycle Mapping},\n  author    = {Yi, Ran and Liu, Yong-Jin and Lai, Yu-Kun and Rosin, Paul L},\n  booktitle = {{IEEE} Conference on Computer Vision and Pattern Recognition (CVPR '20)},\n  pages     = {8214--8222},\n  year      = {2020}\n}\n```\n\n## Acknowledgments\nOur code is inspired by [pytorch-CycleGAN-and-pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix).\n"
  },
  {
    "path": "requirements.txt",
    "content": "torch==1.2.0\ntorchvision==0.4.0\ndominate==2.4.0\nvisdom==0.1.8.9\nscipy==1.1.0\nnumpy==1.16.4\nPillow==6.2.1\nopencv-python==4.1.0.25"
  },
  {
    "path": "scripts/train.sh",
    "content": "set -ex\npython train.py --dataroot ./datasets/portrait_drawing --name formal --model asymmetric_cycle_gan_cls --output_nc 1 --load_size 572 --crop_size 512 --lr 0.000015 --dataset_mode unaligned_mask_stylecls --display_env asymmetric_trainset --gpu_ids 0 --gpu_ids_p 0 --niter 100 --niter_decay 200 --n_epochs 200"
  },
  {
    "path": "test.py",
    "content": "\"\"\"General-purpose test script for image-to-image translation.\n\nOnce you have trained your model with train.py, you can use this script to test the model.\nIt will load a saved model from --checkpoints_dir and save the results to --results_dir.\n\nIt first creates model and dataset given the option. It will hard-code some parameters.\nIt then runs inference for --num_test images and save results to an HTML file.\n\nExample (You need to train models first or download pre-trained models from our website):\n    Test a CycleGAN model (both sides):\n        python test.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan\n\n    Test a CycleGAN model (one side only):\n        python test.py --dataroot datasets/horse2zebra/testA --name horse2zebra_pretrained --model test --no_dropout\n\n    The option '--model test' is used for generating CycleGAN results only for one side.\n    This option will automatically set '--dataset_mode single', which only loads the images from one set.\n    On the contrary, using '--model cycle_gan' requires loading and generating results in both directions,\n    which is sometimes unnecessary. The results will be saved at ./results/.\n    Use '--results_dir <directory_path_to_save_result>' to specify the results directory.\n\n    Test a pix2pix model:\n        python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA\n\nSee options/base_options.py and options/test_options.py for more test options.\nSee training and test tips at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/tips.md\nSee frequently asked questions at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/qa.md\n\"\"\"\nimport os\nfrom options.test_options import TestOptions\nfrom data import create_dataset\nfrom models import create_model\nfrom util.visualizer import save_images\nfrom util import html\n\n\nif __name__ == '__main__':\n    opt = TestOptions().parse()  # get test options\n    # hard-code some parameters for test\n    opt.num_threads = 0   # test code only supports num_threads = 1\n    opt.batch_size = 1    # test code only supports batch_size = 1\n    opt.serial_batches = True  # disable data shuffling; comment this line if results on randomly chosen images are needed.\n    opt.no_flip = True    # no flip; comment this line if results on flipped images are needed.\n    opt.display_id = -1   # no visdom display; the test code saves the results to a HTML file.\n    dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options\n    model = create_model(opt)      # create a model given opt.model and other options\n    model.setup(opt)               # regular setup: load and print networks; create schedulers\n    # create a website\n    web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.epoch))  # define the website directory\n    #webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.epoch))\n    webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.epoch), refresh=0, folder=opt.imagefolder)\n    # test with eval mode. This only affects layers like batchnorm and dropout.\n    # For [pix2pix]: we use batchnorm and dropout in the original pix2pix. 
You can experiment with and without eval() mode.\n    # For [CycleGAN]: It should not affect CycleGAN as CycleGAN uses instancenorm without dropout.\n    if opt.eval:\n        model.eval()\n    for i, data in enumerate(dataset):\n        if i >= opt.num_test:  # only apply our model to opt.num_test images.\n            break\n        model.set_input(data)  # unpack data from data loader\n        model.test()           # run inference\n        visuals = model.get_current_visuals()  # get image results\n        img_path = model.get_image_paths()     # get image paths\n        if i % 5 == 0:  # print a progress message every 5 images\n            print('processing (%04d)-th image... %s' % (i, img_path))\n        save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize, W=opt.W, H=opt.H)\n    webpage.save()  # save the HTML\n"
  },
  {
    "path": "test_seq_style.py",
    "content": "import os\nimport argparse\n\ndef opts():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('-g','--gpu', default = '0', type = str, help = 'gpu ids, -1 for cpu, default is 0.')\n    parser.add_argument('-d','--dataroot', default = './examples', type = str, help = 'the input folder that contains test face photos, default is ./examples')\n    parser.add_argument('-s','--savefolder', default = '3styles', type = str, help = 'the name of save folder that contains result images, default is 3styles')\n    return parser.parse_args()\n\nif __name__ == '__main__':\n    opt = opts()\n    exp = 'pretrained'\n    imgsize = 512\n    epoch = '200'\n    dataroot = opt.dataroot\n    gpu_id = opt.gpu\n\n    # test 3 styles in one pass\n    savefolder = 'images'+opt.savefolder\n    os.system('python3 test.py --dataroot %s --name %s --model test_3styles --output_nc 1 --no_dropout --num_test 1000 --epoch %s --imagefolder %s --crop_size %d --load_size %d --gpu_ids %s' % (dataroot,exp,epoch,savefolder,imgsize,imgsize,gpu_id))\n    print('check ./results/%s/test_%s/index%s.html'%(exp,epoch,savefolder[6:]))\n    print('saved to ./results/%s/test_%s/%s'%(exp,epoch,savefolder))\n\n    # test 3 styles separately\n    '''\n    for vec in [[1,0,0],[0,1,0],[0,0,1]]:\n        #1,0,0 for style1; 0,1,0 for style2; 0,0,1 for style3\n        svec = '%d,%d,%d' % (vec[0],vec[1],vec[2])\n        savefolder = 'imagesstyle%d-%d-%d'%(vec[0],vec[1],vec[2])\n        print('results/%s/test_%s/index%s.html'%(exp,epoch,savefolder[6:]))\n        os.system('python3 test.py --dataroot %s --name %s --model test --output_nc 1 --no_dropout --model_suffix _A --num_test 1000 --epoch %s --imagefolder %s --sinput svec --svec %s --crop_size %d --load_size %d --gpu_ids %s' % (dataroot,exp,epoch,savefolder,svec,imgsize,imgsize,gpu_id))\n    '''\n        "
  },
  {
    "path": "train.py",
    "content": "\"\"\"General-purpose training script for image-to-image translation.\n\nThis script works for various models (with option '--model': e.g., pix2pix, cyclegan, colorization) and\ndifferent datasets (with option '--dataset_mode': e.g., aligned, unaligned, single, colorization).\nYou need to specify the dataset ('--dataroot'), experiment name ('--name'), and model ('--model').\n\nIt first creates model, dataset, and visualizer given the option.\nIt then does standard network training. During the training, it also visualize/save the images, print/save the loss plot, and save models.\nThe script supports continue/resume training. Use '--continue_train' to resume your previous training.\n\nExample:\n    Train a CycleGAN model:\n        python train.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan\n    Train a pix2pix model:\n        python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA\n\nSee options/base_options.py and options/train_options.py for more training options.\nSee training and test tips at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/tips.md\nSee frequently asked questions at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/qa.md\n\"\"\"\nimport time\nfrom options.train_options import TrainOptions\nfrom data import create_dataset\nfrom models import create_model\nfrom util.visualizer import Visualizer\nimport pdb\n\nif __name__ == '__main__':\n    start = time.time()\n    opt = TrainOptions().parse()   # get training options\n    dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options\n    dataset_size = len(dataset)    # get the number of images in the dataset.\n    print('The number of training images = %d' % dataset_size)\n\n    model = create_model(opt)      # create a model given opt.model and other options\n    model.setup(opt)               # regular setup: load and print networks; create schedulers\n    visualizer = Visualizer(opt)   # create a visualizer that display/save images and plots\n    total_iters = 0                # the total number of training iterations\n\n    for epoch in range(opt.epoch_count, opt.n_epochs + 1):    # outer loop for different epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>\n        epoch_start_time = time.time()  # timer for entire epoch\n        iter_data_time = time.time()    # timer for data loading per iteration\n        epoch_iter = 0                  # the number of training iterations in current epoch, reset to 0 every epoch\n        model.update_process(epoch)\n\n        for i, data in enumerate(dataset):  # inner loop within one epoch\n            iter_start_time = time.time()  # timer for computation per iteration\n            if total_iters % opt.print_freq == 0:\n                t_data = iter_start_time - iter_data_time\n            visualizer.reset()\n            total_iters += opt.batch_size\n            epoch_iter += opt.batch_size\n            model.set_input(data)         # unpack data from dataset and apply preprocessing\n            model.optimize_parameters()   # calculate loss functions, get gradients, update network weights\n            \n            if total_iters % opt.display_freq == 0:   # display images on visdom and save images to a HTML file\n                save_result = total_iters % opt.update_html_freq == 0\n                model.compute_visuals()\n                
visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)\n\n            if total_iters % opt.print_freq == 0:    # print training losses and save logging information to the disk\n                losses = model.get_current_losses()\n                t_comp = (time.time() - iter_start_time) / opt.batch_size\n                if opt.model == 'cycle_gan':\n                    processes = [model.process] + model.lambda_As\n                    visualizer.print_current_losses_process(epoch, epoch_iter, losses, t_comp, t_data, processes)\n                else:\n                    visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data)\n                if opt.display_id > 0:\n                    visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses)\n\n            if total_iters % opt.save_latest_freq == 0:   # cache our latest model every <save_latest_freq> iterations\n                print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))\n                save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest'\n                model.save_networks(save_suffix)\n\n            iter_data_time = time.time()\n        if epoch % opt.save_epoch_freq == 0:              # cache our model every <save_epoch_freq> epochs\n            print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))\n            model.save_networks('latest')\n            model.save_networks(epoch)\n\n        print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.n_epochs, time.time() - epoch_start_time))  # the outer loop runs to opt.n_epochs\n        model.update_learning_rate()                     # update learning rates at the end of every epoch.\n    \n    print('Total Time Taken: %d sec' % (time.time() - start))"
  },
  {
    "path": "util/__init__.py",
    "content": "\"\"\"This package includes a miscellaneous collection of useful helper functions.\"\"\"\n"
  },
  {
    "path": "util/get_data.py",
    "content": "from __future__ import print_function\nimport os\nimport tarfile\nimport requests\nfrom warnings import warn\nfrom zipfile import ZipFile\nfrom bs4 import BeautifulSoup\nfrom os.path import abspath, isdir, join, basename\n\n\nclass GetData(object):\n    \"\"\"A Python script for downloading CycleGAN or pix2pix datasets.\n\n    Parameters:\n        technique (str) -- One of: 'cyclegan' or 'pix2pix'.\n        verbose (bool)  -- If True, print additional information.\n\n    Examples:\n        >>> from util.get_data import GetData\n        >>> gd = GetData(technique='cyclegan')\n        >>> new_data_path = gd.get(save_path='./datasets')  # options will be displayed.\n\n    Alternatively, You can use bash scripts: 'scripts/download_pix2pix_model.sh'\n    and 'scripts/download_cyclegan_model.sh'.\n    \"\"\"\n\n    def __init__(self, technique='cyclegan', verbose=True):\n        url_dict = {\n            'pix2pix': 'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/',\n            'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets'\n        }\n        self.url = url_dict.get(technique.lower())\n        self._verbose = verbose\n\n    def _print(self, text):\n        if self._verbose:\n            print(text)\n\n    @staticmethod\n    def _get_options(r):\n        soup = BeautifulSoup(r.text, 'lxml')\n        options = [h.text for h in soup.find_all('a', href=True)\n                   if h.text.endswith(('.zip', 'tar.gz'))]\n        return options\n\n    def _present_options(self):\n        r = requests.get(self.url)\n        options = self._get_options(r)\n        print('Options:\\n')\n        for i, o in enumerate(options):\n            print(\"{0}: {1}\".format(i, o))\n        choice = input(\"\\nPlease enter the number of the \"\n                       \"dataset above you wish to download:\")\n        return options[int(choice)]\n\n    def _download_data(self, dataset_url, save_path):\n        if not isdir(save_path):\n            os.makedirs(save_path)\n\n        base = basename(dataset_url)\n        temp_save_path = join(save_path, base)\n\n        with open(temp_save_path, \"wb\") as f:\n            r = requests.get(dataset_url)\n            f.write(r.content)\n\n        if base.endswith('.tar.gz'):\n            obj = tarfile.open(temp_save_path)\n        elif base.endswith('.zip'):\n            obj = ZipFile(temp_save_path, 'r')\n        else:\n            raise ValueError(\"Unknown File Type: {0}.\".format(base))\n\n        self._print(\"Unpacking Data...\")\n        obj.extractall(save_path)\n        obj.close()\n        os.remove(temp_save_path)\n\n    def get(self, save_path, dataset=None):\n        \"\"\"\n\n        Download a dataset.\n\n        Parameters:\n            save_path (str) -- A directory to save the data to.\n            dataset (str)   -- (optional). A specific dataset to download.\n                            Note: this must include the file extension.\n                            If None, options will be presented for you\n                            to choose from.\n\n        Returns:\n            save_path_full (str) -- the absolute path to the downloaded data.\n\n        \"\"\"\n        if dataset is None:\n            selected_dataset = self._present_options()\n        else:\n            selected_dataset = dataset\n\n        save_path_full = join(save_path, selected_dataset.split('.')[0])\n\n        if isdir(save_path_full):\n            warn(\"\\n'{0}' already exists. 
Voiding Download.\".format(\n                save_path_full))\n        else:\n            self._print('Downloading Data...')\n            url = \"{0}/{1}\".format(self.url, selected_dataset)\n            self._download_data(url, save_path=save_path)\n\n        return abspath(save_path_full)\n"
  },
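  {
    "path": "examples/get_data_demo.py",
    "content": "\"\"\"A minimal, illustrative usage sketch for util.get_data.GetData.\n\nAssumptions: the archive name 'horse2zebra.zip' is listed on the CycleGAN download\npage, and the 'requests', 'bs4', and 'lxml' dependencies are installed. Pass\ndataset=None instead to pick from the listed options interactively.\n\"\"\"\nfrom util.get_data import GetData\n\nif __name__ == '__main__':\n    gd = GetData(technique='cyclegan', verbose=True)\n    # Non-interactive download: the dataset name must include the file extension.\n    data_path = gd.get(save_path='./datasets', dataset='horse2zebra.zip')  # assumed archive name\n    print('Dataset unpacked to: %s' % data_path)\n"
  },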
  {
    "path": "util/html.py",
    "content": "import dominate\nfrom dominate.tags import meta, h3, table, tr, td, p, a, img, br\nimport os\n\n\nclass HTML:\n    \"\"\"This HTML class allows us to save images and write texts into a single HTML file.\n\n     It consists of functions such as <add_header> (add a text header to the HTML file),\n     <add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).\n     It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.\n    \"\"\"\n\n    def __init__(self, web_dir, title, refresh=0, folder='images'):\n        \"\"\"Initialize the HTML classes\n\n        Parameters:\n            web_dir (str) -- a directory that stores the webpage. HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir/images/\n            title (str)   -- the webpage name\n            refresh (int) -- how often the website refresh itself; if 0; no refreshing\n        \"\"\"\n        self.title = title\n        self.web_dir = web_dir\n        #self.img_dir = os.path.join(self.web_dir, 'images')\n        self.img_dir = os.path.join(self.web_dir, folder)\n        self.folder = folder\n        if not os.path.exists(self.web_dir):\n            os.makedirs(self.web_dir)\n        if not os.path.exists(self.img_dir):\n            os.makedirs(self.img_dir)\n\n        self.doc = dominate.document(title=title)\n        if refresh > 0:\n            with self.doc.head:\n                meta(http_equiv=\"refresh\", content=str(refresh))\n\n    def get_image_dir(self):\n        \"\"\"Return the directory that stores images\"\"\"\n        return self.img_dir\n\n    def add_header(self, text):\n        \"\"\"Insert a header to the HTML file\n\n        Parameters:\n            text (str) -- the header text\n        \"\"\"\n        with self.doc:\n            h3(text)\n\n    def add_images(self, ims, txts, links, width=400):\n        \"\"\"add images to the HTML file\n\n        Parameters:\n            ims (str list)   -- a list of image paths\n            txts (str list)  -- a list of image names shown on the website\n            links (str list) --  a list of hyperref links; when you click an image, it will redirect you to a new page\n        \"\"\"\n        self.t = table(border=1, style=\"table-layout: fixed;\")  # Insert a table\n        self.doc.add(self.t)\n        with self.t:\n            with tr():\n                for im, txt, link in zip(ims, txts, links):\n                    with td(style=\"word-wrap: break-word;\", halign=\"center\", valign=\"top\"):\n                        with p():\n                            with a(href=os.path.join('images', link)):\n                                #img(style=\"width:%dpx\" % width, src=os.path.join('images', im))\n                                img(style=\"width:%dpx\" % width, src=os.path.join(self.folder, im))\n                            br()\n                            p(txt)\n\n    def save(self):\n        \"\"\"save the current content to the HMTL file\"\"\"\n        #html_file = '%s/index.html' % self.web_dir\n        name = self.folder[6:] if self.folder[:6] == 'images' else self.folder\n        html_file = '%s/index%s.html' % (self.web_dir, name)\n        f = open(html_file, 'wt')\n        f.write(self.doc.render())\n        f.close()\n\n\nif __name__ == '__main__':  # we show an example usage here.\n    html = HTML('web/', 'test_html')\n    html.add_header('hello world')\n\n    ims, txts, links = [], [], []\n    for n in range(4):\n   
     ims.append('image_%d.png' % n)\n        txts.append('text_%d' % n)\n        links.append('image_%d.png' % n)\n    html.add_images(ims, txts, links)\n    html.save()\n"
  },
  {
    "path": "util/image_pool.py",
    "content": "import random\nimport torch\n\n\nclass ImagePool():\n    \"\"\"This class implements an image buffer that stores previously generated images.\n\n    This buffer enables us to update discriminators using a history of generated images\n    rather than the ones produced by the latest generators.\n    \"\"\"\n\n    def __init__(self, pool_size):\n        \"\"\"Initialize the ImagePool class\n\n        Parameters:\n            pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created\n        \"\"\"\n        self.pool_size = pool_size\n        if self.pool_size > 0:  # create an empty pool\n            self.num_imgs = 0\n            self.images = []\n\n    def query(self, images):\n        \"\"\"Return an image from the pool.\n\n        Parameters:\n            images: the latest generated images from the generator\n\n        Returns images from the buffer.\n\n        By 50/100, the buffer will return input images.\n        By 50/100, the buffer will return images previously stored in the buffer,\n        and insert the current images to the buffer.\n        \"\"\"\n        if self.pool_size == 0:  # if the buffer size is 0, do nothing\n            return images\n        return_images = []\n        for image in images:\n            image = torch.unsqueeze(image.data, 0)\n            if self.num_imgs < self.pool_size:   # if the buffer is not full; keep inserting current images to the buffer\n                self.num_imgs = self.num_imgs + 1\n                self.images.append(image)\n                return_images.append(image)\n            else:\n                p = random.uniform(0, 1)\n                if p > 0.5:  # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer\n                    random_id = random.randint(0, self.pool_size - 1)  # randint is inclusive\n                    tmp = self.images[random_id].clone()\n                    self.images[random_id] = image\n                    return_images.append(tmp)\n                else:       # by another 50% chance, the buffer will return the current image\n                    return_images.append(image)\n        return_images = torch.cat(return_images, 0)   # collect all the images and return\n        return return_images\n"
  },
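  {
    "path": "examples/image_pool_demo.py",
    "content": "\"\"\"A minimal, illustrative usage sketch for util.image_pool.ImagePool.\n\nAssumptions: the batch shape (4, 3, 64, 64) and pool_size=50 are arbitrary demo\nchoices; in training, query() would receive the generator's fake images before\neach discriminator update.\n\"\"\"\nimport torch\nfrom util.image_pool import ImagePool\n\nif __name__ == '__main__':\n    pool = ImagePool(pool_size=50)  # a buffer of up to 50 previously generated images\n    for step in range(3):\n        fake_images = torch.randn(4, 3, 64, 64)  # a batch of 4 fake RGB images\n        # While the buffer is not full, inputs are stored and returned unchanged;\n        # once full, each returned image is either the input or a previously\n        # stored image (50% chance each).\n        history_batch = pool.query(fake_images)\n        print(step, history_batch.shape)  # torch.Size([4, 3, 64, 64])\n"
  },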
  {
    "path": "util/util.py",
    "content": "\"\"\"This module contains simple helper functions \"\"\"\nfrom __future__ import print_function\nimport torch\nimport numpy as np\nfrom PIL import Image\nimport os\nimport pdb\nfrom scipy.io import savemat\n\n\ndef tensor2im(input_image, imtype=np.uint8):\n    \"\"\"\"Converts a Tensor array into a numpy image array.\n\n    Parameters:\n        input_image (tensor) --  the input image tensor array\n        imtype (type)        --  the desired type of the converted numpy array\n    \"\"\"\n    if not isinstance(input_image, np.ndarray):\n        if isinstance(input_image, torch.Tensor):  # get the data from a variable\n            image_tensor = input_image.data\n        else:\n            return input_image\n        #pdb.set_trace()\n        image_numpy = image_tensor[0].cpu().float().numpy()  # convert it into a numpy array\n        if image_numpy.shape[0] == 1:  # grayscale to RGB\n            image_numpy = np.tile(image_numpy, (3, 1, 1))\n        elif image_numpy.shape[0] == 2:\n            image_numpy = np.concatenate([image_numpy, image_numpy[1:2,:,:]], 0)\n        image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0  # post-processing: tranpose and scaling\n    else:  # if it is a numpy array, do nothing\n        image_numpy = input_image\n    return image_numpy.astype(imtype)\n    #return np.round(image_numpy).astype(imtype),image_numpy\n\n\ndef diagnose_network(net, name='network'):\n    \"\"\"Calculate and print the mean of average absolute(gradients)\n\n    Parameters:\n        net (torch network) -- Torch network\n        name (str) -- the name of the network\n    \"\"\"\n    mean = 0.0\n    count = 0\n    for param in net.parameters():\n        if param.grad is not None:\n            mean += torch.mean(torch.abs(param.grad.data))\n            count += 1\n    if count > 0:\n        mean = mean / count\n    print(name)\n    print(mean)\n\n\ndef save_image(image_numpy, image_path):\n    \"\"\"Save a numpy image to the disk\n\n    Parameters:\n        image_numpy (numpy array) -- input numpy array\n        image_path (str)          -- the path of the image\n    \"\"\"\n    image_pil = Image.fromarray(image_numpy)\n    #pdb.set_trace()\n    image_pil.save(image_path)\n\n\ndef print_numpy(x, val=True, shp=False):\n    \"\"\"Print the mean, min, max, median, std, and size of a numpy array\n\n    Parameters:\n        val (bool) -- if print the values of the numpy array\n        shp (bool) -- if print the shape of the numpy array\n    \"\"\"\n    x = x.astype(np.float64)\n    if shp:\n        print('shape,', x.shape)\n    if val:\n        x = x.flatten()\n        print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (\n            np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))\n\n\ndef mkdirs(paths):\n    \"\"\"create empty directories if they don't exist\n\n    Parameters:\n        paths (str list) -- a list of directory paths\n    \"\"\"\n    if isinstance(paths, list) and not isinstance(paths, str):\n        for path in paths:\n            mkdir(path)\n    else:\n        mkdir(paths)\n\n\ndef mkdir(path):\n    \"\"\"create a single empty directory if it didn't exist\n\n    Parameters:\n        path (str) -- a single directory path\n    \"\"\"\n    if not os.path.exists(path):\n        os.makedirs(path)\n\ndef normalize_tensor(in_feat,eps=1e-10):\n    norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True))\n    return in_feat/(norm_factor+eps)"
  },
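  {
    "path": "examples/util_helpers_demo.py",
    "content": "\"\"\"A minimal, illustrative usage sketch for the helpers in util.util.\n\nAssumptions: the fake generator output and the './results/demo.png' path are\narbitrary demo choices; tensor2im expects a (N, C, H, W) tensor in [-1, 1],\nas produced by a tanh output layer.\n\"\"\"\nimport torch\nfrom util import util\n\nif __name__ == '__main__':\n    fake = torch.rand(1, 3, 64, 64) * 2 - 1  # fake network output scaled to [-1, 1]\n    im = util.tensor2im(fake)                # (64, 64, 3) uint8 numpy array in [0, 255]\n    util.print_numpy(im, val=True, shp=True)\n    util.mkdirs(['./results'])               # create the output directory if needed\n    util.save_image(im, './results/demo.png')\n"
  },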
  {
    "path": "util/visualizer.py",
    "content": "import numpy as np\nimport os\nimport sys\nimport ntpath\nimport time\nfrom . import util, html\nfrom subprocess import Popen, PIPE\nfrom PIL import Image\n\nif sys.version_info[0] == 2:\n    VisdomExceptionBase = Exception\nelse:\n    VisdomExceptionBase = ConnectionError\n\n\ndef save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256, W=None, H=None):\n    \"\"\"Save images to the disk.\n\n    Parameters:\n        webpage (the HTML class) -- the HTML webpage class that stores these imaegs (see html.py for more details)\n        visuals (OrderedDict)    -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs\n        image_path (str)         -- the string is used to create image paths\n        aspect_ratio (float)     -- the aspect ratio of saved images\n        width (int)              -- the images will be resized to width x width\n\n    This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.\n    \"\"\"\n    image_dir = webpage.get_image_dir()\n    short_path = ntpath.basename(image_path[0])\n    name = os.path.splitext(short_path)[0]\n\n    webpage.add_header(name)\n    ims, txts, links = [], [], []\n\n    for label, im_data in visuals.items():\n        ## tensor to im\n        im = util.tensor2im(im_data)\n        image_name = '%s_%s.png' % (name, label)\n        save_path = os.path.join(image_dir, image_name)\n        h, w, _ = im.shape\n        if W is not None and H is not None and (W != w or H != h):\n            im = np.array(Image.fromarray(im).resize((W, H), Image.BICUBIC))\n        else:\n            if aspect_ratio > 1.0:\n                im = np.array(Image.fromarray(im).resize((int(w * aspect_ratio), h), Image.BICUBIC))\n            if aspect_ratio < 1.0:\n                im = np.array(Image.fromarray(im).resize((w, int(h / aspect_ratio)), Image.BICUBIC))\n        util.save_image(im, save_path)\n\n        ims.append(image_name)\n        txts.append(label)\n        links.append(image_name)\n    webpage.add_images(ims, txts, links, width=width)\n\n\nclass Visualizer():\n    \"\"\"This class includes several functions that can display/save images and print/save logging information.\n\n    It uses a Python library 'visdom' for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images.\n    \"\"\"\n\n    def __init__(self, opt):\n        \"\"\"Initialize the Visualizer class\n\n        Parameters:\n            opt -- stores all the experiment flags; needs to be a subclass of BaseOptions\n        Step 1: Cache the training/test options\n        Step 2: connect to a visdom server\n        Step 3: create an HTML object for saveing HTML filters\n        Step 4: create a logging file to store training losses\n        \"\"\"\n        self.opt = opt  # cache the option\n        self.display_id = opt.display_id\n        self.use_html = opt.isTrain and not opt.no_html\n        self.win_size = opt.display_winsize\n        self.name = opt.name\n        self.port = opt.display_port\n        self.saved = False\n        if self.display_id > 0:  # connect to a visdom server given <display_port> and <display_server>\n            import visdom\n            self.ncols = opt.display_ncols\n            self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env)\n            if not self.vis.check_connection():\n                self.create_visdom_connections()\n\n        if self.use_html:  # create an HTML object at 
<checkpoints_dir>/web/; images will be saved under <checkpoints_dir>/web/images/\n            self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')\n            self.img_dir = os.path.join(self.web_dir, 'images')\n            print('create web directory %s...' % self.web_dir)\n            util.mkdirs([self.web_dir, self.img_dir])\n        # create a logging file to store training losses\n        self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')\n        with open(self.log_name, \"a\") as log_file:\n            now = time.strftime(\"%c\")\n            log_file.write('================ Training Loss (%s) ================\\n' % now)\n\n    def reset(self):\n        \"\"\"Reset the self.saved status\"\"\"\n        self.saved = False\n\n    def create_visdom_connections(self):\n        \"\"\"If the program could not connect to Visdom server, this function will start a new server at port < self.port > \"\"\"\n        cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port\n        print('\\n\\nCould not connect to Visdom server. \\n Trying to start a server....')\n        print('Command: %s' % cmd)\n        Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)\n\n    def display_current_results(self, visuals, epoch, save_result):\n        \"\"\"Display current results on visdom; save current results to an HTML file.\n\n        Parameters:\n            visuals (OrderedDict) - - dictionary of images to display or save\n            epoch (int) - - the current epoch\n            save_result (bool) - - if save the current results to an HTML file\n        \"\"\"\n        if self.display_id > 0:  # show images in the browser using visdom\n            ncols = self.ncols\n            if ncols > 0:        # show all the images in one visdom panel\n                ncols = min(ncols, len(visuals))\n                h, w = next(iter(visuals.values())).shape[:2]\n                table_css = \"\"\"<style>\n                        table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center}\n                        table td {width: % dpx; height: % dpx; padding: 4px; outline: 4px solid black}\n                        </style>\"\"\" % (w, h)  # create a table css\n                # create a table of images.\n                title = self.name\n                label_html = ''\n                label_html_row = ''\n                images = []\n                idx = 0\n                for label, image in visuals.items():\n                    image_numpy = util.tensor2im(image)\n                    label_html_row += '<td>%s</td>' % label\n                    images.append(image_numpy.transpose([2, 0, 1]))\n                    idx += 1\n                    if idx % ncols == 0:\n                        label_html += '<tr>%s</tr>' % label_html_row\n                        label_html_row = ''\n                white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255\n                while idx % ncols != 0:\n                    images.append(white_image)\n                    label_html_row += '<td></td>'\n                    idx += 1\n                if label_html_row != '':\n                    label_html += '<tr>%s</tr>' % label_html_row\n                try:\n                    self.vis.images(images, nrow=ncols, win=self.display_id + 1,\n                                    padding=2, opts=dict(title=title + ' images'))\n                    label_html = '<table>%s</table>' % label_html\n                    
self.vis.text(table_css + label_html, win=self.display_id + 2,\n                                  opts=dict(title=title + ' labels'))\n                except VisdomExceptionBase:\n                    self.create_visdom_connections()\n\n            else:     # show each image in a separate visdom panel;\n                idx = 1\n                try:\n                    for label, image in visuals.items():\n                        image_numpy = util.tensor2im(image)\n                        self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label),\n                                       win=self.display_id + idx)\n                        idx += 1\n                except VisdomExceptionBase:\n                    self.create_visdom_connections()\n\n        if self.use_html and (save_result or not self.saved):  # save images to an HTML file if they haven't been saved.\n            self.saved = True\n            # save images to the disk\n            for label, image in visuals.items():\n                image_numpy = util.tensor2im(image)\n                img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))\n                util.save_image(image_numpy, img_path)\n\n            # update website\n            webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1)\n            for n in range(epoch, 0, -1):\n                webpage.add_header('epoch [%d]' % n)\n                ims, txts, links = [], [], []\n\n                for label, image_numpy in visuals.items():\n                    image_numpy = util.tensor2im(image)\n                    img_path = 'epoch%.3d_%s.png' % (n, label)\n                    ims.append(img_path)\n                    txts.append(label)\n                    links.append(img_path)\n                webpage.add_images(ims, txts, links, width=self.win_size)\n            webpage.save()\n\n    def plot_current_losses(self, epoch, counter_ratio, losses):\n        \"\"\"display the current losses on visdom display: dictionary of error labels and values\n\n        Parameters:\n            epoch (int)           -- current epoch\n            counter_ratio (float) -- progress (percentage) in the current epoch, between 0 to 1\n            losses (OrderedDict)  -- training losses stored in the format of (name, float) pairs\n        \"\"\"\n        if not hasattr(self, 'plot_data'):\n            self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())}\n        self.plot_data['X'].append(epoch + counter_ratio)\n        self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']])\n        #X = np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1)\n        #Y = np.array(self.plot_data['Y'])\n        try:\n            self.vis.line(\n                X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),\n                Y=np.array(self.plot_data['Y']),\n                opts={\n                    'title': self.name + ' loss over time',\n                    'legend': self.plot_data['legend'],\n                    'xlabel': 'epoch',\n                    'ylabel': 'loss'},\n                win=self.display_id)\n        except VisdomExceptionBase:\n            self.create_visdom_connections()\n\n    # losses: same format as |losses| of plot_current_losses\n    def print_current_losses(self, epoch, iters, losses, t_comp, t_data):\n        \"\"\"print current losses on console; also save the losses to the disk\n\n        Parameters:\n            epoch (int) -- 
current epoch\n            iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)\n            losses (OrderedDict) -- training losses stored in the format of (name, float) pairs\n            t_comp (float) -- computational time per data point (normalized by batch_size)\n            t_data (float) -- data loading time per data point (normalized by batch_size)\n        \"\"\"\n        message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)\n        for k, v in losses.items():\n            message += '%s: %.3f ' % (k, v)\n\n        print(message)  # print the message\n        with open(self.log_name, \"a\") as log_file:\n            log_file.write('%s\\n' % message)  # save the message\n    \n    # losses: same format as |losses| of plot_current_losses\n    def print_current_losses_process(self, epoch, iters, losses, t_comp, t_data, processes):\n        \"\"\"print current losses on console; also save the losses to the disk\n\n        Parameters:\n            epoch (int) -- current epoch\n            iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)\n            losses (OrderedDict) -- training losses stored in the format of (name, float) pairs\n            t_comp (float) -- computational time per data point (normalized by batch_size)\n            t_data (float) -- data loading time per data point (normalized by batch_size)\n        \"\"\"\n        message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)\n        message += '[process: %.3f, non_trunc: %.3f, trunc: %.3f] ' % (processes[0], processes[1], processes[2])\n        for k, v in losses.items():\n            message += '%s: %.3f ' % (k, v)\n\n        print(message)  # print the message\n        with open(self.log_name, \"a\") as log_file:\n            log_file.write('%s\\n' % message)  # save the message\n"
  },
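  {
    "path": "examples/save_images_demo.py",
    "content": "\"\"\"A minimal, illustrative usage sketch for util.visualizer.save_images.\n\nAssumptions: the 'web/' output directory, the visuals dictionary, and the fake\nimage path are arbitrary demo choices; in testing, 'visuals' would come from\nmodel.get_current_visuals() and 'image_path' from model.get_image_paths().\n\"\"\"\nfrom collections import OrderedDict\nimport torch\nfrom util import html\nfrom util.visualizer import save_images\n\nif __name__ == '__main__':\n    webpage = html.HTML('web/', 'save_images demo')\n    visuals = OrderedDict([('real_A', torch.rand(1, 3, 64, 64) * 2 - 1),\n                           ('fake_B', torch.rand(1, 3, 64, 64) * 2 - 1)])\n    # image_path is a list; its first element determines the saved file names.\n    save_images(webpage, visuals, ['datasets/demo/sample.jpg'], width=256)\n    webpage.save()  # writes web/index.html with one row of images\n"
  }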
]