[
  {
    "path": ".gitignore",
    "content": "datasets/\ncheckpoints/\nresults/\n*.png\n*/**/__pycache__\n*/*.pyc\n*/**/*.pyc\n*/**/**/*.pyc\n*/**/**/**/*.pyc\n*/**/**/**/**/*.pyc\n*/*.so*\n*/**/*.so*\n*/**/*.dylib*\n*~\n"
  },
  {
    "path": "LICENSE",
    "content": "Copyright (c) 2017, Asha Anoosheh\nAll rights reserved.\n\nRedistribution and use in source and binary forms, with or without\nmodification, are permitted provided that the following conditions are met:\n\n* Redistributions of source code must retain the above copyright notice, this\n  list of conditions and the following disclaimer.\n\n* Redistributions in binary form must reproduce the above copyright notice,\n  this list of conditions and the following disclaimer in the documentation\n  and/or other materials provided with the distribution.\n\nTHIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\nAND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\nIMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE\nDISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE\nFOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL\nDAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR\nSERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER\nCAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,\nOR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE\nOF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.\n"
  },
  {
    "path": "README.md",
    "content": "\n# ComboGAN\n\nThis is our ongoing PyTorch implementation of ComboGAN.\nCode was written by [Asha Anoosheh](https://github.com/aanoosheh) (built upon [CycleGAN](https://github.com/junyanz/CycleGAN)).\n\n\n#### [[ComboGAN Paper]](https://arxiv.org/pdf/1712.06909.pdf)\n<img src=\"img/Inference.png\" width=420/>\n\n\nIf you use this code for your research, please cite:\n\nComboGAN: Unrestrained Scalability for Image Domain Translation\n[Asha Anoosheh](http://ashaanoosheh.com), [Eirikur Agustsson](https://relational.github.io/), [Radu Timofte](http://www.vision.ee.ethz.ch/~timofter/), [Luc van Gool](https://www.vision.ee.ethz.ch/en/members/get_member.cgi?id=1)\nIn arXiv, 2017.\n\n\n<br><br>\n<img src='img/Paintings.png' align=\"center\" width=900>\n<br><br>\n\n\n## Prerequisites\n- Linux or macOS\n- Python 3\n- CPU or NVIDIA GPU + CUDA cuDNN\n\n## Getting Started\n### Installation\n- Install PyTorch and its dependencies from http://pytorch.org\n- Install torchvision from source:\n```bash\ngit clone https://github.com/pytorch/vision\ncd vision\npython setup.py install\n```\n- Install the Python libraries [visdom](https://github.com/facebookresearch/visdom) and [dominate](https://github.com/Knio/dominate):\n```bash\npip install visdom\npip install dominate\n```\n- Clone this repo:\n```bash\ngit clone https://github.com/AAnoosheh/ComboGAN.git\ncd ComboGAN\n```\n\n### ComboGAN training\nOur prepared datasets can be downloaded using `./datasets/download_dataset.sh <dataset_name>`.\n\nA pretrained model for the 14-painters dataset can be found [HERE](https://www.dropbox.com/s/t8s6x0bu52d73s0/paint14_pretrained.zip?dl=0). 
Place it under `./checkpoints/` and test using the instructions below, with the arguments `--name paint14_pretrained --dataroot ./datasets/painters_14 --n_domains 14 --which_epoch 1150`.\n\nExample running scripts can be found in the `scripts` directory.\n\n- Train a model:\n```\npython train.py --name <experiment_name> --dataroot ./datasets/<your_dataset> --n_domains <N> --niter <num_epochs_constant_LR> --niter_decay <num_epochs_decaying_LR>\n```\nCheckpoints are saved to `./checkpoints/<experiment_name>/` by default.\n- Fine-tune or resume training:\n```\npython train.py --continue_train --which_epoch <checkpoint_number_to_load> --name <experiment_name> --dataroot ./datasets/<your_dataset> --n_domains <N> --niter <num_epochs_constant_LR> --niter_decay <num_epochs_decaying_LR>\n```\n- Test the model:\n```\npython test.py --phase test --name <experiment_name> --dataroot ./datasets/<your_dataset> --n_domains <N> --which_epoch <checkpoint_number_to_load> --serial_test\n```\nThe test results are saved to an HTML file at `./results/<experiment_name>/<epoch_number>/index.html`.\n\n\n\n## Training/Testing Details\n- Flags: see `options/train_options.py` for training-specific flags, `options/test_options.py` for test-specific flags, and `options/base_options.py` for all common flags.\n- Dataset format: the data directory (given by `--dataroot`) should contain one subfolder per domain, named `train*/` and `test*/`; these are loaded in alphabetical order. (Note that a folder named `train10` would be loaded before `train2`, and all checkpoints and results are ordered accordingly.)\n- CPU/GPU (default `--gpu_ids 0`): set `--gpu_ids -1` to use CPU mode; set `--gpu_ids 0,1,2` for multi-GPU mode. You need a large batch size (e.g. `--batchSize 32`) to benefit from multiple GPUs.\n- Visualization: during training, the current results and loss plots can be viewed using two methods. 
First, if you set `--display_id` > 0, the results and loss plots appear on a local graphics web server launched by [visdom](https://github.com/facebookresearch/visdom). This requires `visdom` to be installed and a server started with the command `python -m visdom.server`. The default server URL is `http://localhost:8097`. `display_id` corresponds to the window ID that is displayed on the `visdom` server. The `visdom` display is turned on by default; to avoid the extra overhead of communicating with `visdom`, set `--display_id 0`. Second, the intermediate results are also saved to `./checkpoints/<experiment_name>/web/index.html`. To avoid this, set the `--no_html` flag.\n- Preprocessing: images can be resized and cropped in different ways using the `--resize_or_crop` option. The default option `'resize_and_crop'` resizes the image so that its shorter side is `opt.loadSize` (keeping the aspect ratio) and, during training, takes a random crop of size `(opt.fineSize, opt.fineSize)`. `'crop'` skips the resizing step and only performs random cropping during training. Note that `get_transform` in `data/base_dataset.py` only checks whether the option contains `resize` or `crop`, so `'scale_width'` applies no resizing or cropping and `'scale_width_and_crop'` behaves like `'crop'`. Random horizontal flipping is also applied during training unless `--no_flip` is set.\n\n\nNOTE: one should **not** expect ComboGAN to work on just any combination of input and output datasets (e.g. `dogs<->houses`). We find it works better when the datasets share similar visual content. For example, `landscape painting<->landscape photographs` works much better than `portrait painting<->landscape photographs`.\n"
  },
  {
    "path": "data/__init__.py",
    "content": ""
  },
  {
    "path": "data/base_dataset.py",
    "content": "import torch.utils.data as data\nfrom PIL import Image\nimport torchvision.transforms as transforms\n\nclass BaseDataset(data.Dataset):\n    def __init__(self):\n        super(BaseDataset, self).__init__()\n\n    def name(self):\n        return 'BaseDataset'\n\n    def initialize(self, opt):\n        pass\n\ndef get_transform(opt):\n    transform_list = []\n    if 'resize' in opt.resize_or_crop:\n        transform_list.append(transforms.Resize(opt.loadSize, Image.BICUBIC))\n\n    if opt.isTrain:\n        if 'crop' in opt.resize_or_crop:\n            transform_list.append(transforms.RandomCrop(opt.fineSize))\n        if not opt.no_flip:\n            transform_list.append(transforms.RandomHorizontalFlip())\n\n    transform_list += [transforms.ToTensor(),\n                       transforms.Normalize((0.5, 0.5, 0.5),\n                                            (0.5, 0.5, 0.5))]\n    return transforms.Compose(transform_list)\n"
  },
  {
    "path": "data/data_loader.py",
    "content": "import torch.utils.data\nfrom data.unaligned_dataset import UnalignedDataset\n\n\nclass DataLoader():\n    def name(self):\n        return 'DataLoader'\n\n    def __init__(self, opt):\n        self.opt = opt\n        self.dataset = UnalignedDataset(opt)\n        self.dataloader = torch.utils.data.DataLoader(\n            self.dataset,\n            batch_size=opt.batchSize,\n            num_workers=int(opt.nThreads))\n\n    def __len__(self):\n        return min(len(self.dataset), self.opt.max_dataset_size)\n\n    def __iter__(self):\n        for i, data in enumerate(self.dataloader):\n            # i counts batches, while max_dataset_size counts samples\n            if i * self.opt.batchSize >= self.opt.max_dataset_size:\n                break\n            yield data\n\n"
  },
  {
    "path": "data/image_folder.py",
    "content": "###############################################################################\n# Code from\n# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py\n# Modified the original code so that it also loads images from the current\n# directory as well as the subdirectories\n###############################################################################\n\nimport torch.utils.data as data\n\nfrom PIL import Image\nimport os\nimport os.path\n\nIMG_EXTENSIONS = [\n    '.jpg', '.JPG', '.jpeg', '.JPEG',\n    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',\n]\n\n\ndef is_image_file(filename):\n    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)\n\n\ndef make_dataset(dir):\n    images = []\n    assert os.path.isdir(dir), '%s is not a valid directory' % dir\n\n    for root, _, fnames in sorted(os.walk(dir)):\n        for fname in fnames:\n            if is_image_file(fname):\n                path = os.path.join(root, fname)\n                images.append(path)\n\n    return images\n\n\ndef default_loader(path):\n    return Image.open(path).convert('RGB')\n\n\nclass ImageFolder(data.Dataset):\n\n    def __init__(self, root, transform=None, return_paths=False,\n                 loader=default_loader):\n        imgs = make_dataset(root)\n        if len(imgs) == 0:\n            raise(RuntimeError(\"Found 0 images in: \" + root + \"\\n\"\n                               \"Supported image extensions are: \" +\n                               \",\".join(IMG_EXTENSIONS)))\n\n        self.root = root\n        self.imgs = imgs\n        self.transform = transform\n        self.return_paths = return_paths\n        self.loader = loader\n\n    def __getitem__(self, index):\n        path = self.imgs[index]\n        img = self.loader(path)\n        if self.transform is not None:\n            img = self.transform(img)\n        if self.return_paths:\n            return img, path\n        else:\n            return img\n\n    def 
__len__(self):\n        return len(self.imgs)\n"
  },
  {
    "path": "data/unaligned_dataset.py",
    "content": "import os.path, glob\nfrom data.base_dataset import BaseDataset, get_transform\nfrom data.image_folder import make_dataset\nfrom PIL import Image\nimport random\n\nclass UnalignedDataset(BaseDataset):\n    def __init__(self, opt):\n        super(UnalignedDataset, self).__init__()\n        self.opt = opt\n        self.transform = get_transform(opt)\n\n        # Each subfolder matching <phase>* is one domain, loaded in sorted order\n        datapath = os.path.join(opt.dataroot, opt.phase + '*')\n        self.dirs = sorted(glob.glob(datapath))\n\n        self.paths = [sorted(make_dataset(d)) for d in self.dirs]\n        self.sizes = [len(p) for p in self.paths]\n\n    def load_image(self, dom, idx):\n        path = self.paths[dom][idx]\n        img = Image.open(path).convert('RGB')\n        img = self.transform(img)\n        return img, path\n\n    def __getitem__(self, index):\n        if not self.opt.isTrain:\n            if self.opt.serial_test:\n                # Map the flat index to (domain, index within that domain)\n                for d,s in enumerate(self.sizes):\n                    if index < s:\n                        DA = d; break\n                    index -= s\n                index_A = index\n            else:\n                # Cycle through the domains, taking a random image from each\n                DA = index % len(self.dirs)\n                index_A = random.randint(0, self.sizes[DA] - 1)\n        else:\n            # Choose two distinct domains to perform a pass on\n            DA, DB = random.sample(range(len(self.dirs)), 2)\n            index_A = random.randint(0, self.sizes[DA] - 1)\n\n        A_img, A_path = self.load_image(DA, index_A)\n        bundle = {'A': A_img, 'DA': DA, 'path': A_path}\n\n        if self.opt.isTrain:\n            index_B = random.randint(0, self.sizes[DB] - 1)\n            B_img, _ = self.load_image(DB, index_B)\n            bundle.update( {'B': B_img, 'DB': DB} )\n\n        return bundle\n\n    def __len__(self):\n        # Training draws random pairs; at test time, iterate over the total image count\n        if self.opt.isTrain:\n            return max(self.sizes)\n        return sum(self.sizes)\n\n    def name(self):\n        return 'UnalignedDataset'\n"
  },
  {
    "path": "models/__init__.py",
    "content": ""
  },
  {
    "path": "models/base_model.py",
    "content": "import os\nimport torch\n\n\nclass BaseModel():\n    def name(self):\n        return 'BaseModel'\n\n    def __init__(self, opt):\n        self.opt = opt\n        self.gpu_ids = opt.gpu_ids\n        self.isTrain = opt.isTrain\n        self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor\n        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)\n\n    def set_input(self, input):\n        self.input = input\n\n    def forward(self):\n        pass\n\n    # used at test time, no backprop\n    def test(self):\n        pass\n\n    def get_image_paths(self):\n        pass\n\n    def optimize_parameters(self):\n        pass\n\n    def get_current_visuals(self):\n        return self.input\n\n    def get_current_errors(self):\n        return {}\n\n    def save(self, label):\n        pass\n\n    # helper saving function that can be used by subclasses\n    def save_network(self, network, network_label, epoch, gpu_ids):\n        save_filename = '%d_net_%s' % (epoch, network_label)\n        save_path = os.path.join(self.save_dir, save_filename)\n        network.save(save_path)\n        # Plexer.save() moves the networks to the CPU, so move them back to the GPU\n        if gpu_ids and torch.cuda.is_available():\n            network.cuda(gpu_ids[0])\n\n    # helper loading function that can be used by subclasses\n    def load_network(self, network, network_label, epoch):\n        save_filename = '%d_net_%s' % (epoch, network_label)\n        save_path = os.path.join(self.save_dir, save_filename)\n        network.load(save_path)\n\n    def update_learning_rate(self):\n        pass\n"
  },
  {
    "path": "models/combogan_model.py",
    "content": "import numpy as np\nimport torch\nfrom collections import OrderedDict\nimport util.util as util\nfrom util.image_pool import ImagePool\nfrom .base_model import BaseModel\nfrom . import networks\n\n\nclass ComboGANModel(BaseModel):\n    def name(self):\n        return 'ComboGANModel'\n\n    def __init__(self, opt):\n        super(ComboGANModel, self).__init__(opt)\n\n        self.n_domains = opt.n_domains\n        self.DA, self.DB = None, None\n\n        self.real_A = self.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize)\n        self.real_B = self.Tensor(opt.batchSize, opt.output_nc, opt.fineSize, opt.fineSize)\n\n        # load/define networks\n        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,\n                                      opt.netG_n_blocks, opt.netG_n_shared,\n                                      self.n_domains, opt.norm, opt.use_dropout, self.gpu_ids)\n        if self.isTrain:\n            blur_fn = lambda x : torch.nn.functional.conv2d(x, self.Tensor(util.gkern_2d()), groups=3, padding=2)\n            self.netD = networks.define_D(opt.output_nc, opt.ndf, opt.netD_n_layers,\n                                          self.n_domains, blur_fn, opt.norm, self.gpu_ids)\n\n        if not self.isTrain or opt.continue_train:\n            which_epoch = opt.which_epoch\n            self.load_network(self.netG, 'G', which_epoch)\n            if self.isTrain:\n                self.load_network(self.netD, 'D', which_epoch)\n\n        if self.isTrain:\n            self.fake_pools = [ImagePool(opt.pool_size) for _ in range(self.n_domains)]\n            # define loss functions\n            self.L1 = torch.nn.SmoothL1Loss()\n            self.downsample = torch.nn.AvgPool2d(3, stride=2)\n            self.criterionCycle = self.L1\n            self.criterionIdt = lambda y,t : self.L1(self.downsample(y), self.downsample(t))\n            self.criterionLatent = lambda y,t : self.L1(y, t.detach())\n            
self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)\n            # initialize optimizers\n            self.netG.init_optimizers(torch.optim.Adam, opt.lr, (opt.beta1, 0.999))\n            self.netD.init_optimizers(torch.optim.Adam, opt.lr, (opt.beta1, 0.999))\n            # initialize loss storage\n            self.loss_D, self.loss_G = [0]*self.n_domains, [0]*self.n_domains\n            self.loss_cycle = [0]*self.n_domains\n            # initialize loss multipliers\n            self.lambda_cyc, self.lambda_enc = opt.lambda_cycle, (0 * opt.lambda_latent)\n            self.lambda_idt, self.lambda_fwd = opt.lambda_identity, opt.lambda_forward\n\n        print('---------- Networks initialized -------------')\n        print(self.netG)\n        if self.isTrain:\n            print(self.netD)\n        print('-----------------------------------------------')\n\n    def set_input(self, input):\n        input_A = input['A']\n        self.real_A.resize_(input_A.size()).copy_(input_A)\n        self.DA = input['DA'][0]\n        if self.isTrain:\n            input_B = input['B']\n            self.real_B.resize_(input_B.size()).copy_(input_B)\n            self.DB = input['DB'][0]\n        self.image_paths = input['path']\n\n    def test(self):\n        with torch.no_grad():\n            self.visuals = [self.real_A]\n            self.labels = ['real_%d' % self.DA]\n\n            # cache encoding to not repeat it everytime\n            encoded = self.netG.encode(self.real_A, self.DA)\n            for d in range(self.n_domains):\n                if d == self.DA and not self.opt.autoencode:\n                    continue\n                fake = self.netG.decode(encoded, d)\n                self.visuals.append( fake )\n                self.labels.append( 'fake_%d' % d )\n                if self.opt.reconstruct:\n                    rec = self.netG.forward(fake, d, self.DA)\n                    self.visuals.append( rec )\n                    
self.labels.append( 'rec_%d' % d )\n\n    def get_image_paths(self):\n        return self.image_paths\n\n    def backward_D_basic(self, real, fake, domain):\n        # Real\n        pred_real = self.netD.forward(real, domain)\n        loss_D_real = self.criterionGAN(pred_real, True)\n        # Fake\n        pred_fake = self.netD.forward(fake.detach(), domain)\n        loss_D_fake = self.criterionGAN(pred_fake, False)\n        # Combined loss\n        loss_D = (loss_D_real + loss_D_fake) * 0.5\n        # backward\n        loss_D.backward()\n        return loss_D\n\n    def backward_D(self):\n        #D_A\n        fake_B = self.fake_pools[self.DB].query(self.fake_B)\n        self.loss_D[self.DA] = self.backward_D_basic(self.real_B, fake_B, self.DB)\n        #D_B\n        fake_A = self.fake_pools[self.DA].query(self.fake_A)\n        self.loss_D[self.DB] = self.backward_D_basic(self.real_A, fake_A, self.DA)\n\n    def backward_G(self):\n        encoded_A = self.netG.encode(self.real_A, self.DA)\n        encoded_B = self.netG.encode(self.real_B, self.DB)\n\n        # Optional identity \"autoencode\" loss\n        if self.lambda_idt > 0:\n            # Same encoder and decoder should recreate image\n            idt_A = self.netG.decode(encoded_A, self.DA)\n            loss_idt_A = self.criterionIdt(idt_A, self.real_A)\n            idt_B = self.netG.decode(encoded_B, self.DB)\n            loss_idt_B = self.criterionIdt(idt_B, self.real_B)\n        else:\n            loss_idt_A, loss_idt_B = 0, 0\n\n        # GAN loss\n        # D_A(G_A(A))\n        self.fake_B = self.netG.decode(encoded_A, self.DB)\n        pred_fake = self.netD.forward(self.fake_B, self.DB)\n        self.loss_G[self.DA] = self.criterionGAN(pred_fake, True)\n        # D_B(G_B(B))\n        self.fake_A = self.netG.decode(encoded_B, self.DA)\n        pred_fake = self.netD.forward(self.fake_A, self.DA)\n        self.loss_G[self.DB] = self.criterionGAN(pred_fake, True)\n        # Forward cycle loss\n        
rec_encoded_A = self.netG.encode(self.fake_B, self.DB)\n        self.rec_A = self.netG.decode(rec_encoded_A, self.DA)\n        self.loss_cycle[self.DA] = self.criterionCycle(self.rec_A, self.real_A)\n        # Backward cycle loss\n        rec_encoded_B = self.netG.encode(self.fake_A, self.DA)\n        self.rec_B = self.netG.decode(rec_encoded_B, self.DB)\n        self.loss_cycle[self.DB] = self.criterionCycle(self.rec_B, self.real_B)\n\n        # Optional cycle loss on encoding space\n        if self.lambda_enc > 0:\n            loss_enc_A = self.criterionLatent(rec_encoded_A, encoded_A)\n            loss_enc_B = self.criterionLatent(rec_encoded_B, encoded_B)\n        else:\n            loss_enc_A, loss_enc_B = 0, 0\n\n        # Optional loss on downsampled image before and after\n        if self.lambda_fwd > 0:\n            loss_fwd_A = self.criterionIdt(self.fake_B, self.real_A)\n            loss_fwd_B = self.criterionIdt(self.fake_A, self.real_B)\n        else:\n            loss_fwd_A, loss_fwd_B = 0, 0\n\n        # combined loss\n        loss_G = self.loss_G[self.DA] + self.loss_G[self.DB] + \\\n                 (self.loss_cycle[self.DA] + self.loss_cycle[self.DB]) * self.lambda_cyc + \\\n                 (loss_idt_A + loss_idt_B) * self.lambda_idt + \\\n                 (loss_enc_A + loss_enc_B) * self.lambda_enc + \\\n                 (loss_fwd_A + loss_fwd_B) * self.lambda_fwd\n        loss_G.backward()\n\n    def optimize_parameters(self):\n        # G_A and G_B\n        self.netG.zero_grads(self.DA, self.DB)\n        self.backward_G()\n        self.netG.step_grads(self.DA, self.DB)\n        # D_A and D_B\n        self.netD.zero_grads(self.DA, self.DB)\n        self.backward_D()\n        self.netD.step_grads(self.DA, self.DB)\n\n    def get_current_errors(self):\n        extract = lambda l: [(i if type(i) is int or type(i) is float else i.item()) for i in l]\n        D_losses, G_losses, cyc_losses = extract(self.loss_D), extract(self.loss_G), 
extract(self.loss_cycle)\n        return OrderedDict([('D', D_losses), ('G', G_losses), ('Cyc', cyc_losses)])\n\n    def get_current_visuals(self, testing=False):\n        if not testing:\n            self.visuals = [self.real_A, self.fake_B, self.rec_A, self.real_B, self.fake_A, self.rec_B]\n            self.labels = ['real_A', 'fake_B', 'rec_A', 'real_B', 'fake_A', 'rec_B']\n        images = [util.tensor2im(v.data) for v in self.visuals]\n        return OrderedDict(zip(self.labels, images))\n\n    def save(self, label):\n        self.save_network(self.netG, 'G', label, self.gpu_ids)\n        self.save_network(self.netD, 'D', label, self.gpu_ids)\n\n    def update_hyperparams(self, curr_iter):\n        if curr_iter > self.opt.niter:\n            decay_frac = (curr_iter - self.opt.niter) / self.opt.niter_decay\n            new_lr = self.opt.lr * (1 - decay_frac)\n            self.netG.update_lr(new_lr)\n            self.netD.update_lr(new_lr)\n            print('updated learning rate: %f' % new_lr)\n\n        if self.opt.lambda_latent > 0:\n            decay_frac = curr_iter / (self.opt.niter + self.opt.niter_decay)\n            self.lambda_enc = self.opt.lambda_latent * decay_frac\n"
  },
  {
    "path": "models/networks.py",
    "content": "import torch\nimport torch.nn as nn\nfrom torch.nn import init\nimport functools, itertools\nimport numpy as np\n\n\n\n\ndef weights_init(m):\n    classname = m.__class__.__name__\n    if classname.find('Conv') != -1:\n        m.weight.data.normal_(0.0, 0.02)\n        if hasattr(m.bias, 'data'):\n            m.bias.data.fill_(0)\n    elif classname.find('BatchNorm2d') != -1:\n        m.weight.data.normal_(1.0, 0.02)\n        m.bias.data.fill_(0)\n\n\ndef get_norm_layer(norm_type='instance'):\n    if norm_type == 'batch':\n        return functools.partial(nn.BatchNorm2d, affine=True)\n    elif norm_type == 'instance':\n        return functools.partial(nn.InstanceNorm2d, affine=False)\n    else:\n        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)\n\n\ndef define_G(input_nc, output_nc, ngf, n_blocks, n_blocks_shared, n_domains, norm='batch', use_dropout=False, gpu_ids=[]):\n    norm_layer = get_norm_layer(norm_type=norm)\n    if type(norm_layer) == functools.partial:\n        use_bias = norm_layer.func == nn.InstanceNorm2d\n    else:\n        use_bias = norm_layer == nn.InstanceNorm2d\n\n    n_blocks -= n_blocks_shared\n    n_blocks_enc = n_blocks // 2\n    n_blocks_dec = n_blocks - n_blocks_enc\n\n    dup_args = (ngf, norm_layer, use_dropout, gpu_ids, use_bias)\n    enc_args = (input_nc, n_blocks_enc) + dup_args\n    dec_args = (output_nc, n_blocks_dec) + dup_args\n\n    if n_blocks_shared > 0:\n        n_blocks_shdec = n_blocks_shared // 2\n        n_blocks_shenc = n_blocks_shared - n_blocks_shdec\n        shenc_args = (n_domains, n_blocks_shenc) + dup_args\n        shdec_args = (n_domains, n_blocks_shdec) + dup_args\n        plex_netG = G_Plexer(n_domains, ResnetGenEncoder, enc_args, ResnetGenDecoder, dec_args, ResnetGenShared, shenc_args, shdec_args)\n    else:\n        plex_netG = G_Plexer(n_domains, ResnetGenEncoder, enc_args, ResnetGenDecoder, dec_args)\n\n    if len(gpu_ids) > 0:\n        
assert(torch.cuda.is_available())\n        plex_netG.cuda(gpu_ids[0])\n\n    plex_netG.apply(weights_init)\n    return plex_netG\n\n\ndef define_D(input_nc, ndf, netD_n_layers, n_domains, blur_fn, norm='batch', gpu_ids=[]):\n    norm_layer = get_norm_layer(norm_type=norm)\n\n    model_args = (input_nc, ndf, netD_n_layers, blur_fn, norm_layer, gpu_ids)\n    plex_netD = D_Plexer(n_domains, NLayerDiscriminator, model_args)\n\n    if len(gpu_ids) > 0:\n        assert(torch.cuda.is_available())\n        plex_netD.cuda(gpu_ids[0])\n\n    plex_netD.apply(weights_init)\n    return plex_netD\n\n\n##############################################################################\n# Classes\n##############################################################################\n\n\n# Defines the GAN loss which uses either LSGAN or the regular GAN.\n# When LSGAN is used, it is basically same as MSELoss,\n# but it abstracts away the need to create the target label tensor\n# that has the same size as the input\nclass GANLoss(nn.Module):\n    def __init__(self, use_lsgan=True, tensor=torch.FloatTensor):\n        super(GANLoss, self).__init__()\n        self.Tensor = tensor\n        self.labels_real, self.labels_fake = None, None\n        self.preloss = nn.Sigmoid() if not use_lsgan else None\n        self.loss = nn.MSELoss() if use_lsgan else nn.BCELoss()\n\n    def get_target_tensor(self, inputs, is_real):\n        if self.labels_real is None or self.labels_real[0].numel() != inputs[0].numel():\n            self.labels_real = [ self.Tensor(input.size()).fill_(1.0) for input in inputs ]\n            self.labels_fake = [ self.Tensor(input.size()).fill_(0.0) for input in inputs ]\n        if is_real:\n            return self.labels_real\n        return self.labels_fake\n\n    def __call__(self, inputs, is_real):\n        labels = self.get_target_tensor(inputs, is_real)\n        if self.preloss is not None:\n            inputs = [self.preloss(input) for input in inputs]\n        losses = 
[self.loss(input, label) for input, label in zip(inputs, labels)]\n        multipliers = list(range(1, len(inputs)+1));  multipliers[-1] += 1\n        losses = [m*l for m,l in zip(multipliers, losses)]\n        return sum(losses) / (sum(multipliers) * len(losses))\n\n\n# Defines the generator that consists of Resnet blocks between a few\n# downsampling/upsampling operations.\n# Code and idea originally from Justin Johnson's architecture.\n# https://github.com/jcjohnson/fast-neural-style/\nclass ResnetGenEncoder(nn.Module):\n    def __init__(self, input_nc, n_blocks=4, ngf=64, norm_layer=nn.BatchNorm2d,\n                 use_dropout=False, gpu_ids=[], use_bias=False, padding_type='reflect'):\n        assert(n_blocks >= 0)\n        super(ResnetGenEncoder, self).__init__()\n        self.gpu_ids = gpu_ids\n\n        model = [nn.ReflectionPad2d(3),\n                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0,\n                           bias=use_bias),\n                 norm_layer(ngf),\n                 nn.PReLU()]\n\n        n_downsampling = 2\n        for i in range(n_downsampling):\n            mult = 2**i\n            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,\n                                stride=2, padding=1, bias=use_bias),\n                      norm_layer(ngf * mult * 2),\n                      nn.PReLU()]\n\n        mult = 2**n_downsampling\n        for _ in range(n_blocks):\n            model += [ResnetBlock(ngf * mult, norm_layer=norm_layer,\n                                  use_dropout=use_dropout, use_bias=use_bias, padding_type=padding_type)]\n\n        self.model = nn.Sequential(*model)\n\n    def forward(self, input):\n        if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):\n            return nn.parallel.data_parallel(self.model, input, self.gpu_ids)\n        return self.model(input)\n\nclass ResnetGenShared(nn.Module):\n    def __init__(self, n_domains, n_blocks=2, ngf=64, norm_layer=nn.BatchNorm2d,\n 
                use_dropout=False, gpu_ids=[], use_bias=False, padding_type='reflect'):\n        assert(n_blocks >= 0)\n        super(ResnetGenShared, self).__init__()\n        self.gpu_ids = gpu_ids\n\n        model = []\n        n_downsampling = 2\n        mult = 2**n_downsampling\n\n        for _ in range(n_blocks):\n            model += [ResnetBlock(ngf * mult, norm_layer=norm_layer, n_domains=n_domains,\n                                  use_dropout=use_dropout, use_bias=use_bias, padding_type=padding_type)]\n\n        self.model = SequentialContext(n_domains, *model)\n\n    def forward(self, input, domain):\n        if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):\n            return nn.parallel.data_parallel(self.model, (input, domain), self.gpu_ids)\n        return self.model(input, domain)\n\nclass ResnetGenDecoder(nn.Module):\n    def __init__(self, output_nc, n_blocks=5, ngf=64, norm_layer=nn.BatchNorm2d,\n                 use_dropout=False, gpu_ids=[], use_bias=False, padding_type='reflect'):\n        assert(n_blocks >= 0)\n        super(ResnetGenDecoder, self).__init__()\n        self.gpu_ids = gpu_ids\n\n        model = []\n        n_downsampling = 2\n        mult = 2**n_downsampling\n\n        for _ in range(n_blocks):\n            model += [ResnetBlock(ngf * mult, norm_layer=norm_layer,\n                                  use_dropout=use_dropout, use_bias=use_bias, padding_type=padding_type)]\n\n        for i in range(n_downsampling):\n            mult = 2**(n_downsampling - i)\n            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),\n                                         kernel_size=4, stride=2,\n                                         padding=1, output_padding=0,\n                                         bias=use_bias),\n                      norm_layer(int(ngf * mult / 2)),\n                      nn.PReLU()]\n\n        model += [nn.ReflectionPad2d(3),\n                  nn.Conv2d(ngf, output_nc, 
kernel_size=7, padding=0),\n                  nn.Tanh()]\n\n        self.model = nn.Sequential(*model)\n\n    def forward(self, input):\n        if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):\n            return nn.parallel.data_parallel(self.model, input, self.gpu_ids)\n        return self.model(input)\n\n\n# Define a resnet block\nclass ResnetBlock(nn.Module):\n    def __init__(self, dim, norm_layer, use_dropout, use_bias, padding_type='reflect', n_domains=0):\n        super(ResnetBlock, self).__init__()\n\n        conv_block = []\n        p = 0\n        if padding_type == 'reflect':\n            conv_block += [nn.ReflectionPad2d(1)]\n        elif padding_type == 'replicate':\n            conv_block += [nn.ReplicationPad2d(1)]\n        elif padding_type == 'zero':\n            p = 1\n        else:\n            raise NotImplementedError('padding [%s] is not implemented' % padding_type)\n\n        conv_block += [nn.Conv2d(dim + n_domains, dim, kernel_size=3, padding=p, bias=use_bias),\n                       norm_layer(dim),\n                       nn.PReLU()]\n        if use_dropout:\n            conv_block += [nn.Dropout(0.5)]\n\n        p = 0\n        if padding_type == 'reflect':\n            conv_block += [nn.ReflectionPad2d(1)]\n        elif padding_type == 'replicate':\n            conv_block += [nn.ReplicationPad2d(1)]\n        elif padding_type == 'zero':\n            p = 1\n        else:\n            raise NotImplementedError('padding [%s] is not implemented' % padding_type)\n        conv_block += [nn.Conv2d(dim + n_domains, dim, kernel_size=3, padding=p, bias=use_bias),\n                       norm_layer(dim)]\n\n        self.conv_block = SequentialContext(n_domains, *conv_block)\n\n    def forward(self, input):\n        if isinstance(input, tuple):\n            return input[0] + self.conv_block(*input)\n        return input + self.conv_block(input)\n\n\n# Defines the PatchGAN discriminator with the specified arguments.\nclass 
NLayerDiscriminator(nn.Module):\n    def __init__(self, input_nc, ndf=64, n_layers=3, blur_fn=None, norm_layer=nn.BatchNorm2d, gpu_ids=[]):\n        super(NLayerDiscriminator, self).__init__()\n        self.gpu_ids = gpu_ids\n        self.blur_fn = blur_fn\n        self.gray_fn = lambda x: (.299*x[:,0,:,:] + .587*x[:,1,:,:] + .114*x[:,2,:,:]).unsqueeze_(1)\n\n        self.model_gray = self.model(1, ndf, n_layers, norm_layer)\n        self.model_rgb = self.model(input_nc, ndf, n_layers, norm_layer)\n\n    def model(self, input_nc, ndf, n_layers, norm_layer):\n        if type(norm_layer) == functools.partial:\n            use_bias = norm_layer.func == nn.InstanceNorm2d\n        else:\n            use_bias = norm_layer == nn.InstanceNorm2d\n\n        kw = 4\n        padw = int(np.ceil((kw-1)/2))\n        sequences = [[\n            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),\n            nn.PReLU()\n        ]]\n\n        nf_mult = 1\n        nf_mult_prev = 1\n        for n in range(1, n_layers):\n            nf_mult_prev = nf_mult\n            nf_mult = min(2**n, 8)\n            sequences += [[\n                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult + 1,\n                          kernel_size=kw, stride=2, padding=padw, bias=use_bias),\n                norm_layer(ndf * nf_mult + 1),\n                nn.PReLU()\n            ]]\n\n        nf_mult_prev = nf_mult\n        nf_mult = min(2**n_layers, 8)\n        sequences += [[\n            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,\n                      kernel_size=kw, stride=1, padding=padw, bias=use_bias),\n            norm_layer(ndf * nf_mult),\n            nn.PReLU(),\n            \\\n            nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)\n        ]]\n\n        return SequentialOutput(*sequences)\n\n    def forward(self, input):\n        luminance, blurred_rgb = self.gray_fn(input), self.blur_fn(input)\n\n        if len(self.gpu_ids) and isinstance(input.data, 
torch.cuda.FloatTensor):\n            outs1 = nn.parallel.data_parallel(self.model_gray, luminance, self.gpu_ids)\n            outs2 = nn.parallel.data_parallel(self.model_rgb, blurred_rgb, self.gpu_ids)\n        else:\n            outs1 = self.model_gray(luminance)\n            outs2 = self.model_rgb(blurred_rgb)\n        return [torch.cat([o1,o2], 1) for o1,o2 in zip(outs1, outs2)]\n\n\nclass Plexer(nn.Module):\n    def __init__(self):\n        super(Plexer, self).__init__()\n\n    def apply(self, func):\n        for net in self.networks:\n            net.apply(func)\n\n    def cuda(self, device_id):\n        for net in self.networks:\n            net.cuda(device_id)\n\n    def init_optimizers(self, opt, lr, betas):\n        self.optimizers = [opt(net.parameters(), lr=lr, betas=betas) \\\n                           for net in self.networks]\n\n    def zero_grads(self, dom_a, dom_b):\n        self.optimizers[dom_a].zero_grad()\n        self.optimizers[dom_b].zero_grad()\n\n    def step_grads(self, dom_a, dom_b):\n        self.optimizers[dom_a].step()\n        self.optimizers[dom_b].step()\n\n    def update_lr(self, new_lr):\n        for opt in self.optimizers:\n            for param_group in opt.param_groups:\n                param_group['lr'] = new_lr\n\n    def save(self, save_path):\n        for i, net in enumerate(self.networks):\n            filename = save_path + ('%d.pth' % i)\n            torch.save(net.cpu().state_dict(), filename)\n\n    def load(self, save_path):\n        for i, net in enumerate(self.networks):\n            filename = save_path + ('%d.pth' % i)\n            net.load_state_dict(torch.load(filename))\n\nclass G_Plexer(Plexer):\n    def __init__(self, n_domains, encoder, enc_args, decoder, dec_args,\n                 block=None, shenc_args=None, shdec_args=None):\n        super(G_Plexer, self).__init__()\n        self.encoders = [encoder(*enc_args) for _ in range(n_domains)]\n        self.decoders = [decoder(*dec_args) for _ in 
range(n_domains)]\n\n        self.sharing = block is not None\n        if self.sharing:\n            self.shared_encoder = block(*shenc_args)\n            self.shared_decoder = block(*shdec_args)\n            self.encoders.append( self.shared_encoder )\n            self.decoders.append( self.shared_decoder )\n        self.networks = self.encoders + self.decoders\n\n    def init_optimizers(self, opt, lr, betas):\n        self.optimizers = []\n        for enc, dec in zip(self.encoders, self.decoders):\n            params = itertools.chain(enc.parameters(), dec.parameters())\n            self.optimizers.append( opt(params, lr=lr, betas=betas) )\n\n    def forward(self, input, in_domain, out_domain):\n        encoded = self.encode(input, in_domain)\n        return self.decode(encoded, out_domain)\n\n    def encode(self, input, domain):\n        output = self.encoders[domain].forward(input)\n        if self.sharing:\n            return self.shared_encoder.forward(output, domain)\n        return output\n\n    def decode(self, input, domain):\n        if self.sharing:\n            input = self.shared_decoder.forward(input, domain)\n        return self.decoders[domain].forward(input)\n\n    def zero_grads(self, dom_a, dom_b):\n        self.optimizers[dom_a].zero_grad()\n        if self.sharing:\n            self.optimizers[-1].zero_grad()\n        self.optimizers[dom_b].zero_grad()\n\n    def step_grads(self, dom_a, dom_b):\n        self.optimizers[dom_a].step()\n        if self.sharing:\n            self.optimizers[-1].step()\n        self.optimizers[dom_b].step()\n\n    def __repr__(self):\n        e, d = self.encoders[0], self.decoders[0]\n        e_params = sum([p.numel() for p in e.parameters()])\n        d_params = sum([p.numel() for p in d.parameters()])\n        return repr(e) +'\\n'+ repr(d) +'\\n'+ \\\n            'Created %d Encoder-Decoder pairs' % len(self.encoders) +'\\n'+ \\\n            'Number of parameters per Encoder: %d' % e_params +'\\n'+ \\\n          
  'Number of parameters per Decoder: %d' % d_params\n\nclass D_Plexer(Plexer):\n    def __init__(self, n_domains, model, model_args):\n        super(D_Plexer, self).__init__()\n        self.networks = [model(*model_args) for _ in range(n_domains)]\n\n    def forward(self, input, domain):\n        discriminator = self.networks[domain]\n        return discriminator.forward(input)\n\n    def __repr__(self):\n        t = self.networks[0]\n        t_params = sum([p.numel() for p in t.parameters()])\n        return repr(t) +'\\n'+ \\\n            'Created %d Discriminators' % len(self.networks) +'\\n'+ \\\n            'Number of parameters per Discriminator: %d' % t_params\n\n\nclass SequentialContext(nn.Sequential):\n    def __init__(self, n_classes, *args):\n        super(SequentialContext, self).__init__(*args)\n        self.n_classes = n_classes\n        self.context_var = None\n\n    def prepare_context(self, input, domain):\n        if self.context_var is None or self.context_var.size()[-2:] != input.size()[-2:]:\n            tensor = torch.cuda.FloatTensor if isinstance(input.data, torch.cuda.FloatTensor) \\\n                     else torch.FloatTensor\n            self.context_var = tensor(*((1, self.n_classes) + input.size()[-2:]))\n\n        self.context_var.data.fill_(-1.0)\n        self.context_var.data[:,domain,:,:] = 1.0\n        return self.context_var\n\n    def forward(self, *input):\n        if self.n_classes < 2 or len(input) < 2:\n            return super(SequentialContext, self).forward(input[0])\n        x, domain = input\n\n        for module in self._modules.values():\n            if 'Conv' in module.__class__.__name__:\n                context_var = self.prepare_context(x, domain)\n                x = torch.cat([x, context_var], dim=1)\n            elif 'Block' in module.__class__.__name__:\n                x = (x,) + input[1:]\n            x = module(x)\n        return x\n\nclass SequentialOutput(nn.Sequential):\n    def __init__(self, *args):\n        args = [nn.Sequential(*arg) for arg in args]\n        super(SequentialOutput, self).__init__(*args)\n\n    def forward(self, input):\n        predictions = []\n        layers = self._modules.values()\n        for i, module in enumerate(layers):\n            output = module(input)\n            if i == 0:\n                input = output;  continue\n            predictions.append( output[:,-1,:,:] )\n            if i != len(layers) - 1:\n                input = output[:,:-1,:,:]\n        return predictions\n"
  },
  {
    "path": "options/__init__.py",
    "content": ""
  },
  {
    "path": "options/base_options.py",
    "content": "import argparse\nimport os\nfrom util import util\nimport torch\n\nclass BaseOptions():\n    def __init__(self):\n        self.parser = argparse.ArgumentParser()\n        self.initialized = False\n\n    def initialize(self):\n        self.parser.add_argument('--name', required=True, type=str, help='name of the experiment. It decides where to store samples and models')\n        self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')\n\n        self.parser.add_argument('--dataroot', required=True, type=str, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')\n        self.parser.add_argument('--n_domains', required=True, type=int, help='Number of domains to transfer among')\n\n        self.parser.add_argument('--max_dataset_size', type=int, default=float(\"inf\"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')\n        self.parser.add_argument('--resize_or_crop', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize|resize_and_crop|crop]')\n        self.parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')\n\n        self.parser.add_argument('--loadSize', type=int, default=286, help='scale images to this size')\n        self.parser.add_argument('--fineSize', type=int, default=256, help='then crop to this size')\n\n        self.parser.add_argument('--batchSize', type=int, default=1, help='input batch size')\n        self.parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels')\n        self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')\n\n        self.parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')\n        self.parser.add_argument('--ndf', 
type=int, default=64, help='# of discrim filters in first conv layer')\n        self.parser.add_argument('--netG_n_blocks', type=int, default=9, help='number of residual blocks to use for netG')\n        self.parser.add_argument('--netG_n_shared', type=int, default=0, help='number of blocks to use for netG shared center module')\n        self.parser.add_argument('--netD_n_layers', type=int, default=4, help='number of layers to use for netD')\n\n        self.parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization')\n        self.parser.add_argument('--use_dropout', action='store_true', help='insert dropout for the generator')\n\n        self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0  0,1,2, 0,2. use -1 for CPU')\n        self.parser.add_argument('--nThreads', default=2, type=int, help='# threads for loading data')\n\n        self.parser.add_argument('--display_id', type=int, default=0, help='window id of the web display (set >0 to use visdom)')\n        self.parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')\n        self.parser.add_argument('--display_winsize', type=int, default=256, help='display window size')\n        self.parser.add_argument('--display_single_pane_ncols', type=int, default=0, help='if positive, display all images in a single visdom web panel with this many images per row.')\n\n        self.initialized = True\n\n    def parse(self):\n        if not self.initialized:\n            self.initialize()\n        self.opt = self.parser.parse_args()\n        self.opt.isTrain = self.isTrain   # train or test\n\n        str_ids = self.opt.gpu_ids.split(',')\n        self.opt.gpu_ids = []\n        for str_id in str_ids:\n            id = int(str_id)\n            if id >= 0:\n                self.opt.gpu_ids.append(id)\n\n        # set gpu ids\n        if len(self.opt.gpu_ids) > 0:\n            torch.cuda.set_device(self.opt.gpu_ids[0])\n\n        args = vars(self.opt)\n\n        print('------------ Options -------------')\n        for k, v in sorted(args.items()):\n            print('%s: %s' % (str(k), str(v)))\n        print('-------------- End ----------------')\n\n        # save to the disk\n        expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name)\n        util.mkdirs(expr_dir)\n        file_name = os.path.join(expr_dir, 'opt.txt')\n        with open(file_name, 'wt') as opt_file:\n            opt_file.write('------------ Options -------------\\n')\n            for k, v in sorted(args.items()):\n                opt_file.write('%s: %s\\n' % (str(k), str(v)))\n            opt_file.write('-------------- End ----------------\\n')\n        return self.opt\n"
  },
  {
    "path": "options/test_options.py",
    "content": "from .base_options import BaseOptions\n\n\nclass TestOptions(BaseOptions):\n    def initialize(self):\n        BaseOptions.initialize(self)\n        self.isTrain = False\n\n        self.parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')\n        self.parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')\n\n        self.parser.add_argument('--which_epoch', required=True, type=int, help='which epoch to load for inference')\n        self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc (determines name of folder to load from)')\n\n        self.parser.add_argument('--how_many', type=int, default=50, help='how many test images to run (if serial_test not enabled)')\n        self.parser.add_argument('--serial_test', action='store_true', help='read each image once from folders in sequential order')\n\n        self.parser.add_argument('--autoencode', action='store_true', help='translate images back into their own domain')\n        self.parser.add_argument('--reconstruct', action='store_true', help='do reconstructions of images during testing')\n\n        self.parser.add_argument('--show_matrix', action='store_true', help='visualize images in a matrix format as well')\n"
  },
  {
    "path": "options/train_options.py",
    "content": "from .base_options import BaseOptions\n\n\nclass TrainOptions(BaseOptions):\n    def initialize(self):\n        BaseOptions.initialize(self)\n        self.isTrain = True\n\n        self.parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')\n        self.parser.add_argument('--which_epoch', type=int, default=0, help='which epoch to load if continuing training')\n        self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc (determines name of folder to load from)')\n\n        self.parser.add_argument('--niter', required=True, type=int, help='# of epochs at starting learning rate (try 50*n_domains)')\n        self.parser.add_argument('--niter_decay', required=True, type=int, help='# of epochs to linearly decay learning rate to zero (try 50*n_domains)')\n\n        self.parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for ADAM')\n        self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of ADAM')\n\n        self.parser.add_argument('--lambda_cycle', type=float, default=10.0, help='weight for cycle loss (A -> B -> A)')\n        self.parser.add_argument('--lambda_identity', type=float, default=0.0, help='weight for identity \"autoencode\" mapping (A -> A)')\n        self.parser.add_argument('--lambda_latent', type=float, default=0.0, help='weight for latent-space loss (A -> z -> B -> z)')\n        self.parser.add_argument('--lambda_forward', type=float, default=0.0, help='weight for forward loss (A -> B; try 0.2)')\n\n        self.parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')\n        self.parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen')\n        self.parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on 
console')\n\n        self.parser.add_argument('--no_lsgan', action='store_true', help='use vanilla discriminator in place of least-squares one')\n        self.parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')\n        self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')\n"
  },
  {
    "path": "scripts/continue_combogan.sh",
    "content": "python train.py  \\\n    --dataroot ./datasets/alps  \\\n    --name alps_combogan  \\\n    --continue_train  \\\n    --which_epoch 117  \\\n    --n_domains 4  \\\n    --niter 200  \\\n    --niter_decay 200  \\\n    --lambda_identity 0.0  \\\n    --lambda_forward 0.0\n"
  },
  {
    "path": "scripts/test_combogan.sh",
    "content": "python test.py  \\\n    --phase test  \\\n    --dataroot ./datasets/alps  \\\n    --name alps_combogan  \\\n    --n_domains 4  \\\n    --which_epoch 400  \\\n    --show_matrix\n"
  },
  {
    "path": "scripts/train_combogan.sh",
    "content": "python train.py  \\\n    --dataroot ./datasets/alps  \\\n    --name alps_combogan  \\\n    --n_domains 4  \\\n    --niter 200  \\\n    --niter_decay 200  \\\n    --lambda_identity 0.0  \\\n    --lambda_forward 0.0\n"
  },
  {
    "path": "test.py",
    "content": "import time\nimport os\nfrom options.test_options import TestOptions\nfrom data.data_loader import DataLoader\nfrom models.combogan_model import ComboGANModel\nfrom util.visualizer import Visualizer\nfrom util import html\n\n\nopt = TestOptions().parse()\nopt.nThreads = 1   # test code only supports nThreads = 1\nopt.batchSize = 1  # test code only supports batchSize = 1\n\ndataset = DataLoader(opt)\nmodel = ComboGANModel(opt)\nvisualizer = Visualizer(opt)\n# create website\nweb_dir = os.path.join(opt.results_dir, opt.name, '%s_%d' % (opt.phase, opt.which_epoch))\nwebpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %d' % (opt.name, opt.phase, opt.which_epoch))\n# store images for matrix visualization\nvis_buffer = []\n\n# test\nfor i, data in enumerate(dataset):\n    if not opt.serial_test and i >= opt.how_many:\n        break\n    model.set_input(data)\n    model.test()\n    visuals = model.get_current_visuals(testing=True)\n    img_path = model.get_image_paths()\n    print('process image... %s' % img_path)\n    visualizer.save_images(webpage, visuals, img_path)\n\n    if opt.show_matrix:\n        vis_buffer.append(visuals)\n        if (i+1) % opt.n_domains == 0:\n            save_path = os.path.join(web_dir, 'mat_%d.png' % (i//opt.n_domains))\n            visualizer.save_image_matrix(vis_buffer, save_path)\n            vis_buffer.clear()\n\nwebpage.save()\n\n"
  },
  {
    "path": "train.py",
    "content": "import time\nfrom options.train_options import TrainOptions\nfrom data.data_loader import DataLoader\nfrom models.combogan_model import ComboGANModel\nfrom util.visualizer import Visualizer\n\n\nopt = TrainOptions().parse()\ndataset = DataLoader(opt)\ndataset_size = len(dataset)\nprint('# training images = %d' % dataset_size)\nmodel = ComboGANModel(opt)\nvisualizer = Visualizer(opt)\ntotal_steps = 0\n\n# Update initially if continuing\nif opt.which_epoch > 0:\n    model.update_hyperparams(opt.which_epoch)\n\nfor epoch in range(opt.which_epoch + 1, opt.niter + opt.niter_decay + 1):\n    epoch_start_time = time.time()\n    epoch_iter = 0\n    for i, data in enumerate(dataset):\n        iter_start_time = time.time()\n        total_steps += opt.batchSize\n        epoch_iter += opt.batchSize\n        model.set_input(data)\n        model.optimize_parameters()\n\n        if total_steps % opt.display_freq == 0:\n            visualizer.display_current_results(model.get_current_visuals(), epoch)\n\n        if total_steps % opt.print_freq == 0:\n            errors = model.get_current_errors()\n            t = (time.time() - iter_start_time) / opt.batchSize\n            visualizer.print_current_errors(epoch, epoch_iter, errors, t)\n            if opt.display_id > 0:\n                visualizer.plot_current_errors(epoch, float(epoch_iter)/dataset_size, opt, errors)\n\n    if epoch % opt.save_epoch_freq == 0:\n        print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))\n        model.save(epoch)\n\n    print('End of epoch %d / %d \\t Time Taken: %d sec' %\n          (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))\n\n    model.update_hyperparams(epoch)\n"
  },
  {
    "path": "util/__init__.py",
    "content": ""
  },
  {
    "path": "util/get_data.py",
    "content": "from __future__ import print_function\nimport os\nimport tarfile\nimport requests\nfrom warnings import warn\nfrom zipfile import ZipFile\nfrom bs4 import BeautifulSoup\nfrom os.path import abspath, isdir, join, basename\n\n\nclass GetData(object):\n    \"\"\"\n\n    Download CycleGAN or Pix2Pix Data.\n\n    Args:\n        technique : str\n            One of: 'cyclegan' or 'pix2pix'.\n        verbose : bool\n            If True, print additional information.\n\n    Examples:\n        >>> from util.get_data import GetData\n        >>> gd = GetData(technique='cyclegan')\n        >>> new_data_path = gd.get(save_path='./datasets')  # options will be displayed.\n\n    \"\"\"\n\n    def __init__(self, technique='cyclegan', verbose=True):\n        url_dict = {\n            'pix2pix': 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets',\n            'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets'\n        }\n        self.url = url_dict.get(technique.lower())\n        self._verbose = verbose\n\n    def _print(self, text):\n        if self._verbose:\n            print(text)\n\n    @staticmethod\n    def _get_options(r):\n        soup = BeautifulSoup(r.text, 'lxml')\n        options = [h.text for h in soup.find_all('a', href=True)\n                   if h.text.endswith(('.zip', 'tar.gz'))]\n        return options\n\n    def _present_options(self):\n        r = requests.get(self.url)\n        options = self._get_options(r)\n        print('Options:\\n')\n        for i, o in enumerate(options):\n            print(\"{0}: {1}\".format(i, o))\n        choice = input(\"\\nPlease enter the number of the \"\n                       \"dataset above you wish to download:\")\n        return options[int(choice)]\n\n    def _download_data(self, dataset_url, save_path):\n        if not isdir(save_path):\n            os.makedirs(save_path)\n\n        base = basename(dataset_url)\n        temp_save_path = join(save_path, 
base)\n\n        with open(temp_save_path, \"wb\") as f:\n            r = requests.get(dataset_url)\n            f.write(r.content)\n\n        if base.endswith('.tar.gz'):\n            obj = tarfile.open(temp_save_path)\n        elif base.endswith('.zip'):\n            obj = ZipFile(temp_save_path, 'r')\n        else:\n            raise ValueError(\"Unknown File Type: {0}.\".format(base))\n\n        self._print(\"Unpacking Data...\")\n        obj.extractall(save_path)\n        obj.close()\n        os.remove(temp_save_path)\n\n    def get(self, save_path, dataset=None):\n        \"\"\"\n\n        Download a dataset.\n\n        Args:\n            save_path : str\n                A directory to save the data to.\n            dataset : str, optional\n                A specific dataset to download.\n                Note: this must include the file extension.\n                If None, options will be presented for you\n                to choose from.\n\n        Returns:\n            save_path_full : str\n                The absolute path to the downloaded data.\n\n        \"\"\"\n        if dataset is None:\n            selected_dataset = self._present_options()\n        else:\n            selected_dataset = dataset\n\n        save_path_full = join(save_path, selected_dataset.split('.')[0])\n\n        if isdir(save_path_full):\n            warn(\"\\n'{0}' already exists. Voiding Download.\".format(\n                save_path_full))\n        else:\n            self._print('Downloading Data...')\n            url = \"{0}/{1}\".format(self.url, selected_dataset)\n            self._download_data(url, save_path=save_path)\n\n        return abspath(save_path_full)\n"
  },
  {
    "path": "util/html.py",
    "content": "import dominate\nfrom dominate.tags import *\nimport os\n\n\nclass HTML:\n    def __init__(self, web_dir, title, reflesh=0):\n        self.title = title\n        self.web_dir = web_dir\n        self.img_dir = os.path.join(self.web_dir, 'images')\n        if not os.path.exists(self.web_dir):\n            os.makedirs(self.web_dir)\n        if not os.path.exists(self.img_dir):\n            os.makedirs(self.img_dir)\n        # print(self.img_dir)\n\n        self.doc = dominate.document(title=title)\n        if reflesh > 0:\n            with self.doc.head:\n                meta(http_equiv=\"refresh\", content=str(reflesh))\n\n    def get_image_dir(self):\n        return self.img_dir\n\n    def add_header(self, str):\n        with self.doc:\n            h3(str)\n\n    def add_table(self, border=1):\n        self.t = table(border=border, style=\"table-layout: fixed;\")\n        self.doc.add(self.t)\n\n    def add_images(self, ims, txts, links, width=400):\n        self.add_table()\n        with self.t:\n            with tr():\n                for im, txt, link in zip(ims, txts, links):\n                    with td(style=\"word-wrap: break-word;\", halign=\"center\", valign=\"top\"):\n                        with p():\n                            with a(href=os.path.join('images', link)):\n                                img(style=\"width:%dpx\" % width, src=os.path.join('images', im))\n                            br()\n                            p(txt)\n\n    def save(self):\n        html_file = '%s/index.html' % self.web_dir\n        with open(html_file, 'wt') as f:\n            f.write(self.doc.render())\n\n\nif __name__ == '__main__':\n    html = HTML('web/', 'test_html')\n    html.add_header('hello world')\n\n    ims = []\n    txts = []\n    links = []\n    for n in range(4):\n        ims.append('image_%d.png' % n)\n        txts.append('text_%d' % n)\n        links.append('image_%d.png' % n)\n    html.add_images(ims, txts, links)\n    html.save()\n"
  },
  {
    "path": "util/image_pool.py",
    "content": "import random\nimport numpy as np\nimport torch\nfrom torch.autograd import Variable\nclass ImagePool():\n    def __init__(self, pool_size):\n        self.pool_size = pool_size\n        if self.pool_size > 0:\n            self.num_imgs = 0\n            self.images = []\n\n    def query(self, images):\n        if self.pool_size == 0:\n            return images\n        return_images = []\n        for image in images.data:\n            image = torch.unsqueeze(image, 0)\n            if self.num_imgs < self.pool_size:\n                self.num_imgs = self.num_imgs + 1\n                self.images.append(image)\n                return_images.append(image)\n            else:\n                p = random.uniform(0, 1)\n                if p > 0.5:\n                    random_id = random.randint(0, self.pool_size-1)\n                    tmp = self.images[random_id].clone()\n                    self.images[random_id] = image\n                    return_images.append(tmp)\n                else:\n                    return_images.append(image)\n        return_images = Variable(torch.cat(return_images, 0))\n        return return_images\n"
  },
  {
    "path": "util/png.py",
    "content": "import struct\nimport zlib\n\ndef encode(buf, width, height):\n  \"\"\" buf: must be bytes or a bytearray in py3, a regular string in py2. formatted RGBRGB... \"\"\"\n  assert (width * height * 3 == len(buf))\n  bpp = 3\n\n  def raw_data():\n    # reverse the vertical line order and add null bytes at the start\n    row_bytes = width * bpp\n    for row_start in range((height - 1) * width * bpp, -1, -row_bytes):\n      yield b'\\x00'\n      yield buf[row_start:row_start + row_bytes]\n\n  def chunk(tag, data):\n    return [\n        struct.pack(\"!I\", len(data)),\n        tag,\n        data,\n        struct.pack(\"!I\", 0xFFFFFFFF & zlib.crc32(data, zlib.crc32(tag)))\n      ]\n\n  SIGNATURE = b'\\x89PNG\\r\\n\\x1a\\n'\n  COLOR_TYPE_RGB = 2\n  COLOR_TYPE_RGBA = 6\n  bit_depth = 8\n  return b''.join(\n      [ SIGNATURE ] +\n      chunk(b'IHDR', struct.pack(\"!2I5B\", width, height, bit_depth, COLOR_TYPE_RGB, 0, 0, 0)) +\n      chunk(b'IDAT', zlib.compress(b''.join(raw_data()), 9)) +\n      chunk(b'IEND', b'')\n    )\n"
  },
  {
    "path": "util/util.py",
    "content": "from __future__ import print_function\nimport torch\nimport numpy as np\nfrom scipy.ndimage import gaussian_filter\nfrom PIL import Image\nimport inspect, re\nimport os\n\n# Converts a Tensor into a Numpy array\n# |imtype|: the desired type of the converted numpy array\ndef tensor2im(image_tensor, imtype=np.uint8):\n    image_numpy = image_tensor[0].cpu().float().numpy()\n    image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0\n    return image_numpy.astype(imtype)\n\ndef gkern_2d(size=5, sigma=3):\n    # Create 2D gaussian kernel\n    dirac = np.zeros((size, size))\n    dirac[size//2, size//2] = 1\n    mask = gaussian_filter(dirac, sigma)\n    # Stack to shape (3, 1, size, size): one kernel per RGB channel, suitable for a per-channel (groups=3) torch conv2d\n    return np.stack([np.expand_dims(mask, axis=0)] * 3)\n\n\ndef diagnose_network(net, name='network'):\n    mean = 0.0\n    count = 0\n    for param in net.parameters():\n        if param.grad is not None:\n            mean += torch.mean(torch.abs(param.grad.data))\n            count += 1\n    if count > 0:\n        mean = mean / count\n    print(name)\n    print(mean)\n\n\ndef save_image(image_numpy, image_path):\n    image_pil = Image.fromarray(image_numpy)\n    image_pil.save(image_path)\n\ndef info(object, spacing=10, collapse=1):\n    \"\"\"Print methods and doc strings.\n    Takes module, class, list, dictionary, or string.\"\"\"\n    methodList = [e for e in dir(object) if callable(getattr(object, e))]\n    processFunc = collapse and (lambda s: \" \".join(s.split())) or (lambda s: s)\n    print( \"\\n\".join([\"%s %s\" %\n                     (method.ljust(spacing),\n                      processFunc(str(getattr(object, method).__doc__)))\n                     for method in methodList]) )\n\ndef varname(p):\n    for line in inspect.getframeinfo(inspect.currentframe().f_back)[3]:\n        m = re.search(r'\\bvarname\\s*\\(\\s*([A-Za-z_][A-Za-z0-9_]*)\\s*\\)', line)\n        if m:\n            return m.group(1)\n\ndef print_numpy(x, val=True, shp=False):\n    x = x.astype(np.float64)\n    if shp:\n        print('shape,', x.shape)\n    if val:\n        x = x.flatten()\n        print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (\n            np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))\n\n\ndef mkdirs(paths):\n    if isinstance(paths, list) and not isinstance(paths, str):\n        for path in paths:\n            mkdir(path)\n    else:\n        mkdir(paths)\n\n\ndef mkdir(path):\n    if not os.path.exists(path):\n        os.makedirs(path)\n"
  },
  {
    "path": "util/visualizer.py",
    "content": "import numpy as np\nimport os\nimport ntpath\nimport time\nfrom . import util\nfrom . import html\n\nclass Visualizer():\n    def __init__(self, opt):\n        # self.opt = opt\n        self.display_id = opt.display_id\n        self.use_html = opt.isTrain and not opt.no_html\n        self.win_size = opt.display_winsize\n        self.name = opt.name\n        if self.display_id > 0:\n            import visdom\n            self.vis = visdom.Visdom(port = opt.display_port)\n            self.display_single_pane_ncols = opt.display_single_pane_ncols\n\n        if self.use_html:\n            self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')\n            self.img_dir = os.path.join(self.web_dir, 'images')\n            print('create web directory %s...' % self.web_dir)\n            util.mkdirs([self.web_dir, self.img_dir])\n        self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')\n        with open(self.log_name, \"a\") as log_file:\n            now = time.strftime(\"%c\")\n            log_file.write('================ Training Loss (%s) ================\\n' % now)\n\n    # |visuals|: dictionary of images to display or save\n    def display_current_results(self, visuals, epoch):\n        if self.display_id > 0: # show images in the browser\n            if self.display_single_pane_ncols > 0:\n                h, w = next(iter(visuals.values())).shape[:2]\n                table_css = \"\"\"<style>\n    table {border-collapse: separate; border-spacing:4px; white-space:nowrap; text-align:center}\n    table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}\n</style>\"\"\" % (w, h)\n                ncols = self.display_single_pane_ncols\n                title = self.name\n                label_html = ''\n                label_html_row = ''\n                nrows = int(np.ceil(len(visuals.items()) / ncols))\n                images = []\n                idx = 0\n                for label, image_numpy in visuals.items():\n                    label_html_row += '<td>%s</td>' % label\n                    images.append(image_numpy.transpose([2, 0, 1]))\n                    idx += 1\n                    if idx % ncols == 0:\n                        label_html += '<tr>%s</tr>' % label_html_row\n                        label_html_row = ''\n                white_image = np.ones_like(image_numpy.transpose([2, 0, 1]))*255\n                while idx % ncols != 0:\n                    images.append(white_image)\n                    label_html_row += '<td></td>'\n                    idx += 1\n                if label_html_row != '':\n                    label_html += '<tr>%s</tr>' % label_html_row\n                # pane col = image row\n                self.vis.images(images, nrow=ncols, win=self.display_id + 1,\n                                padding=2, opts=dict(title=title + ' images'))\n                label_html = '<table>%s</table>' % label_html\n                self.vis.text(table_css + label_html, win = self.display_id + 2,\n                              opts=dict(title=title + ' labels'))\n            else:\n                idx = 1\n                for label, image_numpy in visuals.items():\n                    #image_numpy = np.flipud(image_numpy)\n                    self.vis.image(image_numpy.transpose([2,0,1]), opts=dict(title=label),\n                                       win=self.display_id + idx)\n                    idx += 1\n\n        if self.use_html: # save images to a html file\n            for label, image_numpy in visuals.items():\n                img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))\n                util.save_image(image_numpy, img_path)\n            # update website\n            webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, reflesh=1)\n            for n in range(epoch, 0, -1):\n                webpage.add_header('epoch [%d]' % n)\n                ims = []\n                txts = []\n                links = []\n\n                for label, image_numpy in visuals.items():\n                    img_path = 'epoch%.3d_%s.png' % (n, label)\n                    ims.append(img_path)\n                    txts.append(label)\n                    links.append(img_path)\n                webpage.add_images(ims, txts, links, width=self.win_size)\n            webpage.save()\n\n    # errors: dictionary of error labels and values\n    def plot_current_errors(self, epoch, counter_ratio, opt, errors):\n        if not hasattr(self, 'plot_data'):\n            self.plot_data = {'X':[],'Y':[], 'legend':list(errors.keys())}\n        self.plot_data['X'].append(epoch + counter_ratio)\n        self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']])\n        self.vis.line(\n            X=np.stack([np.array(self.plot_data['X'])]*len(self.plot_data['legend']),1),\n            Y=np.array(self.plot_data['Y']),\n            opts={\n                'title': self.name + ' loss over time',\n                'legend': self.plot_data['legend'],\n                'xlabel': 'epoch',\n                'ylabel': 'loss'},\n            win=self.display_id)\n\n    # errors: same format as |errors| of plot_current_errors\n    def print_current_errors(self, epoch, i, errors, t):\n        message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t)\n        for k, v in errors.items():\n            v = ['%.3f' % iv for iv in v]\n            message += k + ': ' + ', '.join(v) + ' | '\n\n        print(message)\n        with open(self.log_name, \"a\") as log_file:\n            log_file.write('%s\\n' % message)\n\n    # save each image in |visuals| to disk and add it to |webpage|\n    def save_images(self, webpage, visuals, image_path):\n        image_dir = webpage.get_image_dir()\n        short_path = ntpath.basename(image_path[0])\n        name = os.path.splitext(short_path)[0]\n\n        webpage.add_header(name)\n        ims = []\n        txts = []\n        links = []\n\n        for label, image_numpy in visuals.items():\n            image_name = '%s_%s.png' % (name, label)\n            save_path = os.path.join(image_dir, image_name)\n            util.save_image(image_numpy, save_path)\n\n            ims.append(image_name)\n            txts.append(label)\n            links.append(image_name)\n        webpage.add_images(ims, txts, links, width=self.win_size)\n\n    def save_image_matrix(self, visuals_list, save_path):\n        images_list = []\n        get_domain = lambda x: x.split('_')[-1]\n\n        for visuals in visuals_list:\n            pairs = list(visuals.items())\n            real_label, real_img = pairs[0]\n            real_dom = get_domain(real_label)\n\n            for label, img in pairs:\n                if 'fake' not in label:\n                    continue\n                if get_domain(label) == real_dom:\n                    images_list.append(real_img)\n                else:\n                    images_list.append(img)\n\n        immat = self.stack_images(images_list)\n        util.save_image(immat, save_path)\n\n    # tile a list of images into an n x n grid (the list length must be a perfect square)\n    def stack_images(self, list_np_images):\n        n = int(np.ceil(np.sqrt(len(list_np_images))))\n\n        # add a 3-pixel border around each image: green on the diagonal, black elsewhere\n        for i, im in enumerate(list_np_images):\n            val = 255 if i%n == i//n else 0\n            r_pad = np.pad(im[:,:,0], (3,3), mode='constant', constant_values=0)\n            g_pad = np.pad(im[:,:,1], (3,3), mode='constant', constant_values=val)\n            b_pad = np.pad(im[:,:,2], (3,3), mode='constant', constant_values=0)\n            list_np_images[i] = np.stack([r_pad,g_pad,b_pad], axis=2)\n\n        data = np.array(list_np_images)\n        data = data.reshape((n, n) + data.shape[1:]).transpose((0, 2, 1, 3) + tuple(range(4, data.ndim + 1)))\n        data = data.reshape((n * data.shape[1], n * data.shape[3]) + data.shape[4:])\n        return data\n"
  }
]