Repository: sagiebenaim/OneShotTranslation
Branch: master
Commit: 6f790ae5f4eb
Files: 43
Total size: 155.3 KB

Directory structure:
gitextract_3ffpn665/

├── LICENSE
├── README.md
├── drawing_and_style_transfer/
│   ├── data/
│   │   ├── __init__.py
│   │   ├── aligned_dataset.py
│   │   ├── base_data_loader.py
│   │   ├── base_dataset.py
│   │   ├── image_folder.py
│   │   ├── single_dataset.py
│   │   └── unaligned_dataset.py
│   ├── datasets/
│   │   ├── combine_A_and_B.py
│   │   ├── download_cyclegan_dataset.sh
│   │   └── make_dataset_aligned.py
│   ├── environment.yml
│   ├── models/
│   │   ├── __init__.py
│   │   ├── autoencoder_model.py
│   │   ├── base_model.py
│   │   ├── networks.py
│   │   ├── ost.py
│   │   └── test_model.py
│   ├── options/
│   │   ├── __init__.py
│   │   ├── base_options.py
│   │   ├── test_options.py
│   │   └── train_options.py
│   ├── scripts/
│   │   ├── test_ost.sh
│   │   ├── train_autoencoder.sh
│   │   └── train_ost.sh
│   ├── test.py
│   ├── train.py
│   └── util/
│       ├── __init__.py
│       ├── get_data.py
│       ├── html.py
│       ├── image_pool.py
│       ├── util.py
│       └── visualizer.py
└── mnist_to_svhn/
    ├── data_loader.py
    ├── download.sh
    ├── main_autoencoder.py
    ├── main_mnist_to_svhn.py
    ├── main_svhn_to_mnist.py
    ├── model.py
    ├── solver_autoencoder.py
    ├── solver_mnist_to_svhn.py
    └── solver_svhn_to_mnist.py

================================================
FILE CONTENTS
================================================

================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2017 

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------- LICENSE FOR mnist-svhn-transfer ---------

MIT License

Copyright (c) 2017

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


--------------------------- LICENSE FOR pytorch-CycleGAN-and-pix2pix ---------

Copyright (c) 2017, Jun-Yan Zhu and Taesung Park
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


================================================
FILE: README.md
================================================
# PyTorch implementation of One-Shot Unsupervised Cross Domain Translation ([arXiv](https://arxiv.org/abs/1806.06029))

Prerequisites
--------------
- Python 3.6
- PyTorch 0.4
- NumPy/SciPy/pandas
- Progressbar
- OpenCV
- [visdom](https://github.com/facebookresearch/visdom)
- [dominate](https://github.com/Knio/dominate)

## MNIST-to-SVHN and SVHN-to-MNIST

To train the autoencoder for both MNIST and SVHN (in the mnist_to_svhn folder):
python main_autoencoder.py --use_augmentation=True

To train OST for MNIST to SVHN:
python main_mnist_to_svhn.py --pretrained_g=True --save_models_and_samples=True --use_augmentation=True --one_way_cycle=True --freeze_shared=False

To train OST for SVHN to MNIST:
python main_svhn_to_mnist.py --pretrained_g=True --save_models_and_samples=True --use_augmentation=True --one_way_cycle=True --freeze_shared=False

## Drawing and Style Transfer Tasks

### Download Dataset

To download a dataset (in the drawing_and_style_transfer folder):
bash datasets/download_cyclegan_dataset.sh $DATASET_NAME
where DATASET_NAME is one of (facades, cityscapes, maps, monet2photo, summer2winter_yosemite)
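
The dataset code in this repository reads images from subfolders named after the phase and the domain (for example `trainA` and `trainB`), and `datasets/make_dataset_aligned.py` expects `testA`, `testB`, `trainA` and `trainB`. A minimal sketch for checking that layout after unpacking is given below. `check_cyclegan_layout` is a hypothetical helper and not part of the repository, and the demonstration runs on a temporary directory rather than a real download.

```python
import os
import tempfile

# Subfolders assumed to exist in an unpacked CycleGAN-style dataset.
EXPECTED_SUBDIRS = ("trainA", "trainB", "testA", "testB")


def check_cyclegan_layout(dataset_dir):
    """Return the expected subfolders that are missing from dataset_dir."""
    return [name for name in EXPECTED_SUBDIRS
            if not os.path.isdir(os.path.join(dataset_dir, name))]


# Demonstration on a throwaway directory that deliberately lacks testB.
with tempfile.TemporaryDirectory() as root:
    for name in ("trainA", "trainB", "testA"):
        os.makedirs(os.path.join(root, name))
    missing = check_cyclegan_layout(root)

print(missing)
```

An empty list means all four subfolders are present.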

### Train Autoencoder

To train the autoencoder for facades (in the drawing_and_style_transfer folder):
python train.py --dataroot=./datasets/facades/trainB --name=facades_autoencoder --model=autoencoder --dataset_mode=single --no_dropout --n_downsampling=2 --num_unshared=2

In the reverse direction (images of facades):
python train.py --dataroot=./datasets/facades/trainA --name=facades_autoencoder_reverse --model=autoencoder --dataset_mode=single --no_dropout --n_downsampling=2 --num_unshared=2
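
The `--n_downsampling` and `--num_unshared` flags decide how the encoder and decoder are divided between the domain-specific and the shared networks. The sketch below reproduces the index arithmetic of `set_encoders_and_decoders` in `models/autoencoder_model.py` as a standalone function. `layer_split` and the dictionary keys are hypothetical names used only for illustration; how `networks.define_ED` interprets each `(start, end)` pair is not covered here.

```python
def layer_split(n_downsampling, num_unshared):
    """Mirror the start/end index arithmetic of set_encoders_and_decoders.

    Returns the (start, end) pairs handed to the unshared and shared
    encoder/decoder builders.
    """
    start_unshared = 0
    start_shared = num_unshared
    end_shared = n_downsampling
    start_dec_shared = start_unshared
    end_dec_shared = start_unshared + (end_shared - start_shared)
    return {
        "enc_unshared": (start_unshared, num_unshared),
        "enc_shared": (start_shared, end_shared),
        "dec_shared": (start_dec_shared, end_dec_shared),
        "dec_unshared": (end_dec_shared, n_downsampling),
    }


# Values used in the commands above.
split = layer_split(n_downsampling=2, num_unshared=2)
for part, bounds in split.items():
    print(part, bounds)
```

With `--n_downsampling=2 --num_unshared=2`, both shared ranges have equal start and end indices, while the unshared ranges span `(0, 2)`.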

### Train OST

To train OST for images to facades:
python train.py --dataroot=./datasets/facades/ --name=facades_ost --load_dir=facades_autoencoder --model=ost --no_dropout --n_downsampling=2 --num_unshared=2 --start=0 --max_items_A=1

To train OST for facades to images (reverse direction):
python train.py --dataroot=./datasets/facades/ --name=facades_ost_reverse --load_dir=facades_autoencoder_reverse --model=ost --no_dropout --n_downsampling=2 --num_unshared=2 --start=0 --max_items_A=1 --A='B' --B='A'
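
The `--start` and `--max_items_A` flags select which domain-A images are used. With `--start=0 --max_items_A=1`, only one image is kept, which is the one-shot setting. The sketch below mirrors the slicing performed by `make_dataset` in `data/image_folder.py` on a plain list of file names. `select_items` and the file names are hypothetical.

```python
def select_items(paths, max_items=-1, start=0):
    """Mirror make_dataset: sort and slice only when max_items >= 0."""
    if max_items >= 0:
        return sorted(paths)[start:start + max_items]
    return paths


paths = ["3.jpg", "1.jpg", "2.jpg"]

# One-shot setting: a single image, the first in sorted order.
one_shot = select_items(paths, max_items=1, start=0)
print(one_shot)

# A negative max_items (the default) keeps every image.
everything = select_items(paths)
print(len(everything))
```

Changing `start` picks a different single sample without touching the files on disk.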

To visualize losses, run `python -m visdom.server`.

### Test OST

To test OST for images to facades:
python test.py --dataroot=./datasets/facades/ --name=facades_ost --model=ost --no_dropout --n_downsampling=2 --num_unshared=2 --start=0 --max_items_A=1

To test OST for facades to images (reverse direction):
python test.py --dataroot=./datasets/facades/ --name=facades_ost_reverse --model=ost --no_dropout --n_downsampling=2 --num_unshared=2 --start=0 --max_items_A=1 --A='B' --B='A'
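
The `--A` and `--B` flags decide which subfolder acts as the source domain and which as the target. The sketch below mirrors how `UnalignedDataset.initialize` joins the data root, the phase and these flags into two directory paths. `domain_dirs` is a hypothetical helper, and the defaults `phase='train'`, `A='A'` and `B='B'` are assumptions, since the option definitions are not shown here.

```python
import os


def domain_dirs(dataroot, phase="train", A="A", B="B"):
    """Build the two image directories the way UnalignedDataset does."""
    dir_A = os.path.join(dataroot, phase + A)
    dir_B = os.path.join(dataroot, phase + B)
    return dir_A, dir_B


# Forward direction with the assumed defaults.
forward = domain_dirs("./datasets/facades/")
# Reverse direction, as requested with --A='B' --B='A'.
reverse = domain_dirs("./datasets/facades/", A="B", B="A")

print(forward)
print(reverse)
```

Swapping the two flags exchanges the roles of the folders, so the same dataset directory serves both translation directions.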

### Options
Additional scripts for other datasets are at ./drawing_and_style_transfer/scripts

Options are at ./drawing_and_style_transfer/options

## Reference
If you found this code useful, please cite the following paper:
```
@inproceedings{Benaim2018OneShotUC,
  title={One-Shot Unsupervised Cross Domain Translation},
  author={Sagie Benaim and Lior Wolf},
  booktitle={NeurIPS},
  year={2018}
}
```



================================================
FILE: drawing_and_style_transfer/data/__init__.py
================================================
import torch.utils.data
from data.base_data_loader import BaseDataLoader


def CreateDataLoader(opt):
    data_loader = CustomDatasetDataLoader()
    print(data_loader.name())
    data_loader.initialize(opt)
    return data_loader


def CreateDataset(opt):
    if opt.dataset_mode == 'aligned':
        from data.aligned_dataset import AlignedDataset
        dataset = AlignedDataset()
    elif opt.dataset_mode == 'unaligned':
        from data.unaligned_dataset import UnalignedDataset
        dataset = UnalignedDataset()
    elif opt.dataset_mode == 'single':
        from data.single_dataset import SingleDataset
        dataset = SingleDataset()
    else:
        raise ValueError("Dataset [%s] not recognized." % opt.dataset_mode)

    print("dataset [%s] was created" % (dataset.name()))
    dataset.initialize(opt)
    return dataset


class CustomDatasetDataLoader(BaseDataLoader):
    def name(self):
        return 'CustomDatasetDataLoader'

    def initialize(self, opt):
        BaseDataLoader.initialize(self, opt)
        self.dataset = CreateDataset(opt)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batchSize,
            shuffle=not opt.serial_batches,
            num_workers=int(opt.nThreads))

    def load_data(self):
        return self

    def __len__(self):
        return len(self.dataset)

    def __iter__(self):
        for i, data in enumerate(self.dataloader):
            yield data


================================================
FILE: drawing_and_style_transfer/data/aligned_dataset.py
================================================
import os.path
import random
import torchvision.transforms as transforms
import torch
from data.base_dataset import BaseDataset
from data.image_folder import make_dataset
from PIL import Image


class AlignedDataset(BaseDataset):
    def initialize(self, opt):
        self.opt = opt
        self.root = opt.dataroot
        self.dir_AB = os.path.join(opt.dataroot, opt.phase)
        self.AB_paths = sorted(make_dataset(self.dir_AB))
        assert (opt.resize_or_crop == 'resize_and_crop')

    def __getitem__(self, index):
        AB_path = self.AB_paths[index]
        AB = Image.open(AB_path).convert('RGB')
        w, h = AB.size
        w2 = int(w / 2)
        A = AB.crop((0, 0, w2, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC)
        B = AB.crop((w2, 0, w, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC)
        A = transforms.ToTensor()(A)
        B = transforms.ToTensor()(B)
        w_offset = random.randint(0, max(0, self.opt.loadSize - self.opt.fineSize - 1))
        h_offset = random.randint(0, max(0, self.opt.loadSize - self.opt.fineSize - 1))

        A = A[:, h_offset:h_offset + self.opt.fineSize, w_offset:w_offset + self.opt.fineSize]
        B = B[:, h_offset:h_offset + self.opt.fineSize, w_offset:w_offset + self.opt.fineSize]

        A = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(A)
        B = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(B)

        if self.opt.which_direction == 'BtoA':
            input_nc = self.opt.output_nc
            output_nc = self.opt.input_nc
        else:
            input_nc = self.opt.input_nc
            output_nc = self.opt.output_nc

        if (not self.opt.no_flip) and random.random() < 0.5:
            idx = [i for i in range(A.size(2) - 1, -1, -1)]
            idx = torch.LongTensor(idx)
            A = A.index_select(2, idx)
            B = B.index_select(2, idx)

        if input_nc == 1:  # RGB to gray
            tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
            A = tmp.unsqueeze(0)

        if output_nc == 1:  # RGB to gray
            tmp = B[0, ...] * 0.299 + B[1, ...] * 0.587 + B[2, ...] * 0.114
            B = tmp.unsqueeze(0)

        return {'A': A, 'B': B,
                'A_paths': AB_path, 'B_paths': AB_path}

    def __len__(self):
        return len(self.AB_paths)

    def name(self):
        return 'AlignedDataset'


================================================
FILE: drawing_and_style_transfer/data/base_data_loader.py
================================================
class BaseDataLoader():
    def __init__(self):
        pass

    def initialize(self, opt):
        self.opt = opt

    def load_data(self):
        return None


================================================
FILE: drawing_and_style_transfer/data/base_dataset.py
================================================
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms


class BaseDataset(data.Dataset):
    def __init__(self):
        super(BaseDataset, self).__init__()

    def name(self):
        return 'BaseDataset'

    def initialize(self, opt):
        pass


def get_transform(opt):
    transform_list = []
    if opt.resize_or_crop == 'resize_and_crop':
        osize = [opt.loadSize, opt.loadSize]
        transform_list.append(transforms.Resize(osize, Image.BICUBIC))
        transform_list.append(transforms.RandomCrop(opt.fineSize))
    elif opt.resize_or_crop == 'crop':
        transform_list.append(transforms.RandomCrop(opt.fineSize))
    elif opt.resize_or_crop == 'scale_width':
        transform_list.append(transforms.Lambda(
            lambda img: __scale_width(img, opt.fineSize)))
    elif opt.resize_or_crop == 'scale_width_and_crop':
        transform_list.append(transforms.Lambda(
            lambda img: __scale_width(img, opt.loadSize)))
        transform_list.append(transforms.RandomCrop(opt.fineSize))

    if opt.isTrain and not opt.no_flip_and_rotation:
        # Default augmentations as in paper
        transform_list.append(transforms.RandomHorizontalFlip())
        transform_list.append(transforms.RandomRotation(opt.rotation_degree))

    transform_list += [transforms.ToTensor(),
                       transforms.Normalize((0.5, 0.5, 0.5),
                                            (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)


def __scale_width(img, target_width):
    ow, oh = img.size
    if (ow == target_width):
        return img
    w = target_width
    h = int(target_width * oh / ow)
    return img.resize((w, h), Image.BICUBIC)


================================================
FILE: drawing_and_style_transfer/data/image_folder.py
================================================
###############################################################################
# Code from
# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
# Modified the original code so that it also loads images from the current
# directory as well as the subdirectories
###############################################################################

import torch.utils.data as data

from PIL import Image
import os
import os.path

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset(dir, max_items=-1, start=0):
    images = []
    assert os.path.isdir(dir), '%s is not a valid directory' % dir

    for root, _, fnames in sorted(os.walk(dir)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                images.append(path)

    if max_items >= 0:
        return sorted(images)[start:start + max_items]
    return images


def default_loader(path):
    return Image.open(path).convert('RGB')


class ImageFolder(data.Dataset):
    def __init__(self, root, transform=None, return_paths=False,
                 loader=default_loader):
        imgs = make_dataset(root)
        if len(imgs) == 0:
            raise RuntimeError("Found 0 images in: " + root + "\n"
                               "Supported image extensions are: " +
                               ",".join(IMG_EXTENSIONS))

        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        path = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.return_paths:
            return img, path
        else:
            return img

    def __len__(self):
        return len(self.imgs)


================================================
FILE: drawing_and_style_transfer/data/single_dataset.py
================================================
import os.path
from data.base_dataset import BaseDataset, get_transform
from data.image_folder import make_dataset
from PIL import Image


class SingleDataset(BaseDataset):
    def initialize(self, opt):
        self.opt = opt
        self.root = opt.dataroot
        self.dir_A = os.path.join(opt.dataroot)

        self.A_paths = make_dataset(self.dir_A)

        self.A_paths = sorted(self.A_paths)

        self.transform = get_transform(opt)

    def __getitem__(self, index):
        A_path = self.A_paths[index]
        A_img = Image.open(A_path).convert('RGB')
        A = self.transform(A_img)
        if self.opt.which_direction == 'BtoA':
            input_nc = self.opt.output_nc
        else:
            input_nc = self.opt.input_nc

        if input_nc == 1:  # RGB to gray
            tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
            A = tmp.unsqueeze(0)

        return {'A': A, 'A_paths': A_path}

    def __len__(self):
        return len(self.A_paths)

    def name(self):
        return 'SingleImageDataset'


================================================
FILE: drawing_and_style_transfer/data/unaligned_dataset.py
================================================
import os.path
from data.base_dataset import BaseDataset, get_transform
from data.image_folder import make_dataset
from PIL import Image
import random


class UnalignedDataset(BaseDataset):
    def initialize(self, opt):
        self.opt = opt
        self.root = opt.dataroot
        self.dir_A = os.path.join(opt.dataroot, opt.phase + opt.A)
        self.dir_B = os.path.join(opt.dataroot, opt.phase + opt.B)
        self.A_paths = make_dataset(self.dir_A, max_items=opt.max_items_A, start=opt.start)
        self.B_paths = make_dataset(self.dir_B, max_items=opt.max_items_B, start=opt.start)

        self.A_paths = sorted(self.A_paths)
        self.B_paths = sorted(self.B_paths)
        self.A_size = len(self.A_paths)
        self.B_size = len(self.B_paths)
        self.transform = get_transform(opt)

    def __getitem__(self, index):
        A_path = self.A_paths[index % self.A_size]
        if self.opt.serial_batches:
            index_B = index % self.B_size
        else:
            index_B = random.randint(0, self.B_size - 1)
        B_path = self.B_paths[index_B]
        # print('(A, B) = (%d, %d)' % (index_A, index_B))
        A_img = Image.open(A_path).convert('RGB')
        B_img = Image.open(B_path).convert('RGB')

        A = self.transform(A_img)
        B = self.transform(B_img)
        if self.opt.which_direction == 'BtoA':
            input_nc = self.opt.output_nc
            output_nc = self.opt.input_nc
        else:
            input_nc = self.opt.input_nc
            output_nc = self.opt.output_nc

        if input_nc == 1:  # RGB to gray
            tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
            A = tmp.unsqueeze(0)

        if output_nc == 1:  # RGB to gray
            tmp = B[0, ...] * 0.299 + B[1, ...] * 0.587 + B[2, ...] * 0.114
            B = tmp.unsqueeze(0)
        return {'A': A, 'B': B,
                'A_paths': A_path, 'B_paths': B_path}

    def __len__(self):
        return max(self.A_size, self.B_size)

    def name(self):
        return 'UnalignedDataset'


================================================
FILE: drawing_and_style_transfer/datasets/combine_A_and_B.py
================================================
import os
import numpy as np
import cv2
import argparse

parser = argparse.ArgumentParser('create image pairs')
parser.add_argument('--fold_A', dest='fold_A', help='input directory for image A', type=str, default='../dataset/50kshoes_edges')
parser.add_argument('--fold_B', dest='fold_B', help='input directory for image B', type=str, default='../dataset/50kshoes_jpg')
parser.add_argument('--fold_AB', dest='fold_AB', help='output directory', type=str, default='../dataset/test_AB')
parser.add_argument('--num_imgs', dest='num_imgs', help='number of images',type=int, default=1000000)
parser.add_argument('--use_AB', dest='use_AB', help='if true: (0001_A, 0001_B) to (0001_AB)',action='store_true')
args = parser.parse_args()

for arg in vars(args):
    print('[%s] = ' % arg,  getattr(args, arg))

splits = os.listdir(args.fold_A)

for sp in splits:
    img_fold_A = os.path.join(args.fold_A, sp)
    img_fold_B = os.path.join(args.fold_B, sp)
    img_list = os.listdir(img_fold_A)
    if args.use_AB:
        img_list = [img_path for img_path in img_list if '_A.' in img_path]

    num_imgs = min(args.num_imgs, len(img_list))
    print('split = %s, use %d/%d images' % (sp, num_imgs, len(img_list)))
    img_fold_AB = os.path.join(args.fold_AB, sp)
    if not os.path.isdir(img_fold_AB):
        os.makedirs(img_fold_AB)
    print('split = %s, number of images = %d' % (sp, num_imgs))
    for n in range(num_imgs):
        name_A = img_list[n]
        path_A = os.path.join(img_fold_A, name_A)
        if args.use_AB:
            name_B = name_A.replace('_A.', '_B.')
        else:
            name_B = name_A
        path_B = os.path.join(img_fold_B, name_B)
        if os.path.isfile(path_A) and os.path.isfile(path_B):
            name_AB = name_A
            if args.use_AB:
                name_AB = name_AB.replace('_A.', '.') # remove _A
            path_AB = os.path.join(img_fold_AB, name_AB)
            im_A = cv2.imread(path_A, cv2.IMREAD_COLOR)
            im_B = cv2.imread(path_B, cv2.IMREAD_COLOR)
            im_AB = np.concatenate([im_A, im_B], 1)
            cv2.imwrite(path_AB, im_AB)


================================================
FILE: drawing_and_style_transfer/datasets/download_cyclegan_dataset.sh
================================================
FILE=$1

if [[ $FILE != "ae_photos" && $FILE != "apple2orange" && $FILE != "summer2winter_yosemite" && $FILE != "horse2zebra" && $FILE != "monet2photo" && $FILE != "cezanne2photo" && $FILE != "ukiyoe2photo" && $FILE != "vangogh2photo" && $FILE != "maps" && $FILE != "cityscapes" && $FILE != "facades" && $FILE != "iphone2dslr_flower" ]]; then
    echo "Available datasets are: apple2orange, summer2winter_yosemite, horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, ae_photos"
    exit 1
fi

URL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/$FILE.zip
ZIP_FILE=./datasets/$FILE.zip
TARGET_DIR=./datasets/$FILE/
wget -N "$URL" -O "$ZIP_FILE"
mkdir -p "$TARGET_DIR"
unzip $ZIP_FILE -d ./datasets/
rm $ZIP_FILE


================================================
FILE: drawing_and_style_transfer/datasets/make_dataset_aligned.py
================================================
import os

from PIL import Image


def get_file_paths(folder):
    image_file_paths = []
    for root, dirs, filenames in os.walk(folder):
        filenames = sorted(filenames)
        for filename in filenames:
            input_path = os.path.abspath(root)
            file_path = os.path.join(input_path, filename)
            if filename.endswith('.png') or filename.endswith('.jpg'):
                image_file_paths.append(file_path)

        break  # prevent descending into subfolders
    return image_file_paths


def align_images(a_file_paths, b_file_paths, target_path):
    if not os.path.exists(target_path):
        os.makedirs(target_path)

    for i in range(len(a_file_paths)):
        img_a = Image.open(a_file_paths[i])
        img_b = Image.open(b_file_paths[i])
        assert(img_a.size == img_b.size)

        aligned_image = Image.new("RGB", (img_a.size[0] * 2, img_a.size[1]))
        aligned_image.paste(img_a, (0, 0))
        aligned_image.paste(img_b, (img_a.size[0], 0))
        aligned_image.save(os.path.join(target_path, '{:04d}.jpg'.format(i)))


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--dataset-path',
        dest='dataset_path',
        help='Which folder to process (it should have subfolders testA, testB, trainA and trainB)'
    )
    args = parser.parse_args()

    dataset_folder = args.dataset_path
    print(dataset_folder)

    test_a_path = os.path.join(dataset_folder, 'testA')
    test_b_path = os.path.join(dataset_folder, 'testB')
    test_a_file_paths = get_file_paths(test_a_path)
    test_b_file_paths = get_file_paths(test_b_path)
    assert(len(test_a_file_paths) == len(test_b_file_paths))
    test_path = os.path.join(dataset_folder, 'test')

    train_a_path = os.path.join(dataset_folder, 'trainA')
    train_b_path = os.path.join(dataset_folder, 'trainB')
    train_a_file_paths = get_file_paths(train_a_path)
    train_b_file_paths = get_file_paths(train_b_path)
    assert(len(train_a_file_paths) == len(train_b_file_paths))
    train_path = os.path.join(dataset_folder, 'train')

    align_images(test_a_file_paths, test_b_file_paths, test_path)
    align_images(train_a_file_paths, train_b_file_paths, train_path)


================================================
FILE: drawing_and_style_transfer/environment.yml
================================================
name: OST
channels:
- peterjc123
- defaults
dependencies:
- python=3.6.5
- pytorch=0.4.0
- scipy
- pip:
  - dominate==2.3.1
  - git+https://github.com/pytorch/vision.git
  - Pillow==5.0.0
  - numpy==1.14.1
  - visdom==0.1.7


================================================
FILE: drawing_and_style_transfer/models/__init__.py
================================================
def create_model(opt):
    print(opt.model)
    if opt.model == 'ost':
        assert (opt.dataset_mode == 'unaligned')
        from .ost import OSTModel
        model = OSTModel()
    elif opt.model == 'autoencoder':
        assert (opt.dataset_mode == 'single')
        from .autoencoder_model import AutoEncoderModel
        model = AutoEncoderModel()
    elif opt.model == 'test':
        assert (opt.dataset_mode == 'single')
        from .test_model import TestModel
        model = TestModel()
    else:
        raise NotImplementedError('model [%s] not implemented.' % opt.model)
    model.initialize(opt)
    print("model [%s] was created" % (model.name()))
    return model


================================================
FILE: drawing_and_style_transfer/models/autoencoder_model.py
================================================
import torch
from collections import OrderedDict
from torch.autograd import Variable
import itertools
import util.util as util
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks


class AutoEncoderModel(BaseModel):
    def name(self):
        return 'AutoEncoderModel'

    def set_encoders_and_decoders(self, opt):
        n_downsampling = opt.n_downsampling
        start_unshared = 0
        num_unshared = opt.num_unshared
        start_shared = num_unshared
        end_shared = n_downsampling
        start_dec_shared = start_unshared
        end_dec_shared = start_unshared + (end_shared - start_shared)
        start_dec_unshared = end_dec_shared
        end_dec_unshared = n_downsampling

        num_res_blocks_unshared = opt.num_res_blocks_unshared
        n_res_blocks_shared = opt.num_res_blocks_shared

        self.netEnc_b, self.netDec_b = networks.define_ED(opt.input_nc, opt.output_nc,
                                                          opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout,
                                                          opt.init_type, self.gpu_ids,
                                                          n_blocks_encoder=num_res_blocks_unshared,
                                                          n_blocks_decoder=num_res_blocks_unshared,
                                                          start=start_unshared,
                                                          end=num_unshared, n_downsampling=n_downsampling,
                                                          input_layer=True,
                                                          output_layer=True, start_dec=start_dec_unshared,
                                                          end_dec=end_dec_unshared)

        self.netEnc_shared, self.netDec_shared = networks.define_ED(opt.input_nc, opt.output_nc,
                                                                    opt.ngf, opt.which_model_netG, opt.norm,
                                                                    not opt.no_dropout,
                                                                    opt.init_type, self.gpu_ids,
                                                                    n_blocks_encoder=n_res_blocks_shared,
                                                                    n_blocks_decoder=n_res_blocks_shared,
                                                                    start=start_shared, n_downsampling=n_downsampling,
                                                                    end=end_shared,
                                                                    input_layer=False,
                                                                    output_layer=False, start_dec=start_dec_shared,
                                                                    end_dec=end_dec_shared)

    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        self.set_encoders_and_decoders(opt)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD = networks.define_D(opt.output_nc, opt.ndf,
                                          opt.which_model_netD,
                                          opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)

        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.netEnc_b, 'Enc_b', which_epoch)
            self.load_network(self.netDec_b, 'Dec_b', which_epoch)
            self.load_network(self.netEnc_shared, 'Enc_shared', which_epoch)
            self.load_network(self.netDec_shared, 'Dec_shared', which_epoch)
            if self.isTrain:
                self.load_network(self.netD, 'D', which_epoch)

        if self.isTrain:
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionIdt = torch.nn.L1Loss()

            # initialize optimizers
            self.optimizer_Enc = torch.optim.Adam(
                itertools.chain(self.netEnc_b.parameters(), self.netEnc_shared.parameters()),
                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_Dec = torch.optim.Adam(
                itertools.chain(self.netDec_b.parameters(), self.netDec_shared.parameters()),
                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

            self.optimizers = []
            self.schedulers = []
            self.optimizers.append(self.optimizer_Enc)
            self.optimizers.append(self.optimizer_Dec)
            self.optimizers.append(self.optimizer_D)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

        print('---------- Networks initialized -------------')
        networks.print_network(self.netEnc_b)
        networks.print_network(self.netDec_b)
        networks.print_network(self.netEnc_shared)
        networks.print_network(self.netDec_shared)
        if self.isTrain:
            networks.print_network(self.netD)
        print('-----------------------------------------------')

    def set_input(self, input):
        # single_dataset stores its images under key 'A'; this model treats them as domain B
        input_B = input['A']
        if len(self.gpu_ids) > 0:
            input_B = input_B.cuda(self.gpu_ids[0], async=True)
        self.input_B = input_B
        # single_dataset likewise stores the image paths under 'A_paths'
        self.image_paths = input['A_paths']

    def forward(self):
        self.real_B = Variable(self.input_B)

    def netEnc(self, x):
        return self.netEnc_shared(self.netEnc_b(x))

    def netDec(self, x):
        return self.netDec_b(self.netDec_shared(x))

    def test(self):
        real_B = Variable(self.input_B, volatile=True)
        fake_B = self.netDec(self.netEnc(real_B))
        self.fake_B = fake_B.data

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def backward_D_basic(self, netD, real, fake):
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D(self):
        fake_B = self.fake_B_pool.query(self.fake_B)
        loss_D = self.backward_D_basic(self.netD, self.real_B, fake_B)
        self.loss_D = loss_D.data[0]

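    def _compute_kl(self, mu):
        # Not a full KL divergence: this is the mean of mu^2. For a unit-variance
        # Gaussian, KL(N(mu, I) || N(0, I)) = 0.5 * ||mu||^2, so this matches that
        # term up to a constant factor.
        mu_2 = torch.pow(mu, 2)
        encoding_loss = torch.mean(mu_2)
        return encoding_loss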

    def backward_G(self):
        lambda_B = self.opt.lambda_B

        # GAN loss D_B(G_B(B))
        enc_b = self.netEnc(self.real_B)
        fake_B = self.netDec(enc_b)
        pred_fake = self.netD(fake_B)
        loss_Gan = self.criterionGAN(pred_fake, True)
        loss_idt_B = self.criterionIdt(fake_B, self.real_B) * lambda_B
        loss_kl_B = self.opt.kl_lambda * self._compute_kl(enc_b)

        # combined loss
        loss_G = loss_Gan + loss_idt_B + loss_kl_B
        loss_G.backward()

        self.fake_B = fake_B.data
        self.loss_Gan = loss_Gan.data[0]
        self.loss_idt_B = loss_idt_B.data[0]

    def optimize_parameters(self):
        # forward
        self.forward()

        # G
        self.optimizer_Enc.zero_grad()
        self.optimizer_Dec.zero_grad()
        self.backward_G()
        self.optimizer_Enc.step()
        self.optimizer_Dec.step()

        # D
        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()

    def get_current_errors(self):
        ret_errors = OrderedDict([('D', self.loss_D), ('G_B', self.loss_Gan), ('Idt_B', self.loss_idt_B)])
        return ret_errors

    def get_current_visuals(self):
        real_B = util.tensor2im(self.input_B)
        fake_B = util.tensor2im(self.fake_B)
        ret_visuals = OrderedDict([('real_B', real_B), ('fake_B', fake_B), ])
        return ret_visuals

    def save(self, label):
        self.save_network(self.netEnc_b, 'Enc_b', label, self.gpu_ids)
        self.save_network(self.netDec_b, 'Dec_b', label, self.gpu_ids)
        self.save_network(self.netEnc_shared, 'Enc_shared', label, self.gpu_ids)
        self.save_network(self.netDec_shared, 'Dec_shared', label, self.gpu_ids)
        self.save_network(self.netD, 'D', label, self.gpu_ids)


================================================
FILE: drawing_and_style_transfer/models/base_model.py
================================================
import os
import torch


class BaseModel(object):
    def name(self):
        return 'BaseModel'

    def initialize(self, opt):
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
        self.load_dir = os.path.join(opt.checkpoints_dir, opt.load_dir)
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)

    def set_input(self, input):
        self.input = input

    def forward(self):
        pass

    # used at test time, no backprop
    def test(self):
        pass

    def get_image_paths(self):
        pass

    def optimize_parameters(self):
        pass

    def get_current_visuals(self):
        return self.input

    def get_current_errors(self):
        return {}

    def save(self, label):
        pass

    # helper saving function that can be used by subclasses
    def save_network(self, network, network_label, epoch_label, gpu_ids):
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        torch.save(network.cpu().state_dict(), save_path)
        if len(gpu_ids) and torch.cuda.is_available():
            network.cuda(gpu_ids[0])

    # helper loading function that can be used by subclasses
    def load_network(self, network, network_label, epoch_label):
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.load_dir, save_filename)
        network.load_state_dict(torch.load(save_path))

    # update learning rate (called once every epoch)
    def update_learning_rate(self):
        for scheduler in self.schedulers:
            scheduler.step()
        lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate = %.7f' % lr)

    def as_np(self, data):
        return data.cpu().data.numpy()


================================================
FILE: drawing_and_style_transfer/models/networks.py
================================================
import torch
import torch.nn as nn
from torch.nn import init
import functools
from torch.autograd import Variable
from torch.optim import lr_scheduler


###############################################################################
# Functions
###############################################################################

class pixel_norm(nn.Module):
    def forward(self, x, epsilon=1e-8):
        return x * torch.rsqrt(torch.mean(x.pow(2), dim=1, keepdim=True) + epsilon)


def weights_init_normal(m):
    classname = m.__class__.__name__
    # print(classname)
    if classname.find('Conv') != -1:
        init.normal(m.weight.data, 0.0, 0.02)
    elif classname.find('Linear') != -1:
        init.normal(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm2d') != -1:
        init.normal(m.weight.data, 1.0, 0.02)
        init.constant(m.bias.data, 0.0)


def weights_init_xavier(m):
    classname = m.__class__.__name__
    # print(classname)
    if classname.find('Conv') != -1:
        init.xavier_normal(m.weight.data, gain=0.02)
    elif classname.find('Linear') != -1:
        init.xavier_normal(m.weight.data, gain=0.02)
    elif classname.find('BatchNorm2d') != -1:
        init.normal(m.weight.data, 1.0, 0.02)
        init.constant(m.bias.data, 0.0)


def weights_init_kaiming(m):
    classname = m.__class__.__name__
    # print(classname)
    if classname.find('Conv') != -1:
        init.kaiming_normal(m.weight.data, a=0, mode='fan_in')
    elif classname.find('Linear') != -1:
        init.kaiming_normal(m.weight.data, a=0, mode='fan_in')
    elif classname.find('BatchNorm2d') != -1:
        init.normal(m.weight.data, 1.0, 0.02)
        init.constant(m.bias.data, 0.0)


def weights_init_orthogonal(m):
    classname = m.__class__.__name__
    # print(classname)
    if classname.find('Conv') != -1:
        init.orthogonal(m.weight.data, gain=1)
    elif classname.find('Linear') != -1:
        init.orthogonal(m.weight.data, gain=1)
    elif classname.find('BatchNorm2d') != -1:
        init.normal(m.weight.data, 1.0, 0.02)
        init.constant(m.bias.data, 0.0)


def init_weights(net, init_type='normal'):
    print('initialization method [%s]' % init_type)
    if init_type == 'normal':
        net.apply(weights_init_normal)
    elif init_type == 'xavier':
        net.apply(weights_init_xavier)
    elif init_type == 'kaiming':
        net.apply(weights_init_kaiming)
    elif init_type == 'orthogonal':
        net.apply(weights_init_orthogonal)
    else:
        raise NotImplementedError('initialization method [%s] is not implemented' % init_type)


def get_norm_layer(norm_type='instance'):
    if norm_type == 'batch':
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
    elif norm_type == 'instance':
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
    elif norm_type == 'none':
        norm_layer = None
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
    return norm_layer


def get_scheduler(optimizer, opt):
    if opt.lr_policy == 'lambda':
        def lambda_rule(epoch):
            lr_l = 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
            return lr_l

        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    else:
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
    return scheduler
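

# Illustrative sketch only (not part of the original repository): a standalone
# copy of the arithmetic in lambda_rule above, handy for checking a schedule
# by hand. The function name and its explicit arguments are hypothetical; they
# stand in for opt.epoch_count, opt.niter and opt.niter_decay.
def _sketch_lambda_lr_multiplier(epoch, epoch_count, niter, niter_decay):
    # The multiplier stays at 1.0 until epoch + 1 + epoch_count exceeds niter,
    # then decreases linearly over the following niter_decay epochs.
    return 1.0 - max(0, epoch + 1 + epoch_count - niter) / float(niter_decay + 1)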


def define_ED(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, init_type='normal',
              gpu_ids=[], n_downsampling=2, start=0, end=2, input_layer=True, output_layer=True, n_blocks_encoder=9,
              n_blocks_decoder=9, start_dec=0, end_dec=1):
    use_gpu = len(gpu_ids) > 0
    norm_layer = get_norm_layer(norm_type=norm)

    if use_gpu:
        assert (torch.cuda.is_available())

    # Both resnet variants build the same encoder/decoder pair here: the block
    # counts come from n_blocks_encoder / n_blocks_decoder, not from the name.
    if which_model_netG in ('resnet_9blocks', 'resnet_6blocks'):
        netE = ResnetEncoder(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout,
                             n_blocks=n_blocks_encoder, gpu_ids=gpu_ids, n_downsampling=n_downsampling, start=start,
                             end=end, input_layer=input_layer)
        # named netDec (decoder) to avoid confusion with a discriminator netD
        netDec = ResnetDecoder(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout,
                               n_blocks=n_blocks_decoder, gpu_ids=gpu_ids, n_downsampling=n_downsampling, end=end_dec,
                               start=start_dec, output_layer=output_layer)
    else:
        raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)
    if use_gpu:
        netE.cuda(gpu_ids[0])
        netDec.cuda(gpu_ids[0])
    init_weights(netE, init_type=init_type)
    init_weights(netDec, init_type=init_type)
    return netE, netDec


def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, init_type='normal',
             gpu_ids=[]):
    netG = None
    use_gpu = len(gpu_ids) > 0
    norm_layer = get_norm_layer(norm_type=norm)

    if use_gpu:
        assert (torch.cuda.is_available())

    if which_model_netG == 'resnet_9blocks':
        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9,
                               gpu_ids=gpu_ids)
    elif which_model_netG == 'resnet_6blocks':
        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6,
                               gpu_ids=gpu_ids)
    elif which_model_netG == 'unet_128':
        netG = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout,
                             gpu_ids=gpu_ids)
    elif which_model_netG == 'unet_256':
        netG = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout,
                             gpu_ids=gpu_ids)
    else:
        raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)
    if len(gpu_ids) > 0:
        netG.cuda(gpu_ids[0])
    init_weights(netG, init_type=init_type)
    return netG


def define_D(input_nc, ndf, which_model_netD,
             n_layers_D=3, norm='batch', use_sigmoid=False, init_type='normal', gpu_ids=[]):
    use_gpu = len(gpu_ids) > 0
    norm_layer = get_norm_layer(norm_type=norm)

    if use_gpu:
        assert (torch.cuda.is_available())
    if which_model_netD == 'basic':
        netD = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid,
                                   gpu_ids=gpu_ids)
    elif which_model_netD == 'n_layers':
        netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid,
                                   gpu_ids=gpu_ids)
    elif which_model_netD == 'pixel':
        netD = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
    else:
        raise NotImplementedError('Discriminator model name [%s] is not recognized' %
                                  which_model_netD)
    if use_gpu:
        netD.cuda(gpu_ids[0])
    init_weights(netD, init_type=init_type)
    return netD


def print_network(net):
    num_params = 0
    for param in net.parameters():
        num_params += param.numel()
    print(net)
    print('Total number of parameters: %d' % num_params)


##############################################################################
# Classes
##############################################################################


# Defines the GAN loss, which uses either LSGAN (MSELoss) or the regular GAN (BCELoss).
# When LSGAN is used, it is basically the same as MSELoss,
# but it abstracts away the need to create a target label tensor
# that has the same size as the input.
class GANLoss(nn.Module):
    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.FloatTensor):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.real_label_var = None
        self.fake_label_var = None
        self.Tensor = tensor
        if use_lsgan:
            self.loss = nn.MSELoss()
        else:
            self.loss = nn.BCELoss()

    def get_target_tensor(self, input, target_is_real):
        target_tensor = None
        if target_is_real:
            create_label = ((self.real_label_var is None) or
                            (self.real_label_var.numel() != input.numel()))
            if create_label:
                real_tensor = self.Tensor(input.size()).fill_(self.real_label)
                self.real_label_var = Variable(real_tensor, requires_grad=False)
            target_tensor = self.real_label_var
        else:
            create_label = ((self.fake_label_var is None) or
                            (self.fake_label_var.numel() != input.numel()))
            if create_label:
                fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
                self.fake_label_var = Variable(fake_tensor, requires_grad=False)
            target_tensor = self.fake_label_var
        return target_tensor

    def __call__(self, input, target_is_real):
        target_tensor = self.get_target_tensor(input, target_is_real)
        return self.loss(input, target_tensor)


# Defines the encoder half of the Resnet generator: an optional input layer,
# the downsampling convolutions with indices in [start, end), then n_blocks Resnet blocks.
# Code and idea originally from Justin Johnson's architecture.
# https://github.com/jcjohnson/fast-neural-style/
class ResnetEncoder(nn.Module):
    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False,
                 gpu_ids=[], padding_type='reflect', n_downsampling=2, start=0, end=2, input_layer=True, n_blocks=6):
        assert (n_blocks >= 0)
        super(ResnetEncoder, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf
        self.gpu_ids = gpu_ids
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        model = []
        if input_layer:
            model = [nn.ReflectionPad2d(3),
                     nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0,
                               bias=use_bias),
                     norm_layer(ngf),
                     nn.ReLU(True)]

        for i in range(start, end):
            mult = 2 ** i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                                stride=2, padding=1, bias=use_bias),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]

        mult = 2 ** n_downsampling
        for i in range(n_blocks):
            model += [
                ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout,
                            use_bias=use_bias)]

        self.model = nn.Sequential(*model)

    def forward(self, input):
        if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
            return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
        else:
            return self.model(input)


class ResnetDecoder(nn.Module):
    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False,
                 gpu_ids=[], padding_type='reflect', n_downsampling=2, start=0, end=2, output_layer=True, n_blocks=6):
        assert (n_blocks >= 0)
        super(ResnetDecoder, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf
        self.gpu_ids = gpu_ids
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        model = []
        mult = 2 ** n_downsampling
        for i in range(n_blocks):
            model += [
                ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout,
                            use_bias=use_bias)]
        for i in range(start, end):
            mult = 2 ** (n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1,
                                         bias=use_bias),
                      norm_layer(int(ngf * mult / 2)),
                      nn.ReLU(True)]

        if output_layer:
            model += [nn.ReflectionPad2d(3)]
            model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
            model += [nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, input):
        if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
            return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
        else:
            return self.model(input)


# Defines the generator that consists of Resnet blocks between a few
# downsampling/upsampling operations.
# Code and idea originally from Justin Johnson's architecture.
# https://github.com/jcjohnson/fast-neural-style/
class ResnetGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6,
                 gpu_ids=[], padding_type='reflect'):
        assert (n_blocks >= 0)
        super(ResnetGenerator, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf
        self.gpu_ids = gpu_ids
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0,
                           bias=use_bias),
                 norm_layer(ngf),
                 nn.ReLU(True)]

        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2 ** i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                                stride=2, padding=1, bias=use_bias),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]

        mult = 2 ** n_downsampling
        for i in range(n_blocks):
            model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout,
                                  use_bias=use_bias)]

        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1,
                                         bias=use_bias),
                      norm_layer(int(ngf * mult / 2)),
                      nn.ReLU(True)]
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, input):
        if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
            return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
        else:
            return self.model(input)


# Define a resnet block
class ResnetBlock(nn.Module):
    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        conv_block = []
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
                       norm_layer(dim),
                       nn.ReLU(True)]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
                       norm_layer(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        out = x + self.conv_block(x)
        return out


# Defines the Unet generator.
# |num_downs|: number of downsamplings in UNet. For example,
# if |num_downs| == 7, an image of size 128x128 will become 1x1
# at the bottleneck
class UnetGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, num_downs, ngf=64,
                 norm_layer=nn.BatchNorm2d, use_dropout=False, gpu_ids=[]):
        super(UnetGenerator, self).__init__()
        self.gpu_ids = gpu_ids

        # construct unet structure
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer,
                                             innermost=True)
        for i in range(num_downs - 5):
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block,
                                                 norm_layer=norm_layer, use_dropout=use_dropout)
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block,
                                             norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block,
                                             norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True,
                                             norm_layer=norm_layer)

        self.model = unet_block

    def forward(self, input):
        if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
            return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
        else:
            return self.model(input)


# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
#   |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]

            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:
            return torch.cat([x, self.model(x)], 1)


# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(nn.Module):
    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, gpu_ids=[]):
        super(NLayerDiscriminator, self).__init__()
        self.gpu_ids = gpu_ids
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        kw = 4
        padw = 1
        sequence = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True)
        ]

        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                          kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                      kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]

        if use_sigmoid:
            sequence += [nn.Sigmoid()]

        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor):
            return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
        else:
            return self.model(input)
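

# Illustrative sketch only (not part of the original repository): computes the
# receptive field of one output unit of the convolution stack built in
# NLayerDiscriminator above, assuming the layer layout shown there (n_layers
# stride-2 convolutions followed by two stride-1 convolutions, all with kernel
# size kw). The function name is hypothetical.
def _sketch_patchgan_receptive_field(n_layers=3, kw=4):
    strides = [2] * n_layers + [1, 1]
    field, jump = 1, 1
    for stride in strides:
        # each layer widens the field by (kernel - 1) steps of the current input-space jump
        field += (kw - 1) * jump
        jump *= stride
    return field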


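# Defines a per-pixel discriminator: every convolution uses a 1x1 kernel,
# so each output value depends on a single input pixel.
class PixelDiscriminator(nn.Module):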
    def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d, use_sigmoid=False, gpu_ids=[]):
        super(PixelDiscriminator, self).__init__()
        self.gpu_ids = gpu_ids
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        self.net = [
            nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
            norm_layer(ndf * 2),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]

        if use_sigmoid:
            self.net.append(nn.Sigmoid())

        self.net = nn.Sequential(*self.net)

    def forward(self, input):
        if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor):
            return nn.parallel.data_parallel(self.net, input, self.gpu_ids)
        else:
            return self.net(input)


================================================
FILE: drawing_and_style_transfer/models/ost.py
================================================
import torch
from collections import OrderedDict
from torch.autograd import Variable
import itertools
import util.util as util
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks


class OSTModel(BaseModel):
    def name(self):
        return 'OSTModel'

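    def _compute_kl(self, mu):
        # With unit variance, KL(N(mu, I) || N(0, I)) = 0.5 * ||mu||^2, so the
        # KL term reduces (up to a constant factor) to the mean of mu squared.
        mu_2 = torch.pow(mu, 2)
        encoding_loss = torch.mean(mu_2)
        return encoding_loss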

    def set_encoders_and_decoders(self, opt):
        n_downsampling = opt.n_downsampling
        start_unshared = 0
        num_unshared = opt.num_unshared
        start_shared = num_unshared
        end_shared = n_downsampling
        start_dec_shared = start_unshared
        end_dec_shared = start_unshared + (end_shared - start_shared)
        start_dec_unshared = end_dec_shared
        end_dec_unshared = n_downsampling

        num_res_blocks_unshared = opt.num_res_blocks_unshared
        n_res_blocks_shared = opt.num_res_blocks_shared

        self.netEnc_a, self.netDec_a = networks.define_ED(opt.input_nc, opt.output_nc,
                                                          opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout,
                                                          opt.init_type, self.gpu_ids,
                                                          n_blocks_encoder=num_res_blocks_unshared,
                                                          n_blocks_decoder=num_res_blocks_unshared,
                                                          start=start_unshared,
                                                          end=num_unshared, n_downsampling=n_downsampling,
                                                          input_layer=True, output_layer=True,
                                                          start_dec=start_dec_unshared, end_dec=end_dec_unshared)

        self.netEnc_b, self.netDec_b = networks.define_ED(opt.input_nc, opt.output_nc,
                                                          opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout,
                                                          opt.init_type, self.gpu_ids,
                                                          n_blocks_encoder=num_res_blocks_unshared,
                                                          n_blocks_decoder=num_res_blocks_unshared,
                                                          start=start_unshared,
                                                          end=num_unshared, n_downsampling=n_downsampling,
                                                          input_layer=True,
                                                          output_layer=True, start_dec=start_dec_unshared,
                                                          end_dec=end_dec_unshared)

        self.netEnc_shared, self.netDec_shared = networks.define_ED(opt.input_nc, opt.output_nc,
                                                                    opt.ngf, opt.which_model_netG, opt.norm,
                                                                    not opt.no_dropout,
                                                                    opt.init_type, self.gpu_ids,
                                                                    n_blocks_encoder=n_res_blocks_shared,
                                                                    n_blocks_decoder=n_res_blocks_shared,
                                                                    start=start_shared, n_downsampling=n_downsampling,
                                                                    end=end_shared,
                                                                    input_layer=False,
                                                                    output_layer=False, start_dec=start_dec_shared,
                                                                    end_dec=end_dec_shared)

    def initialize(self, opt):
        BaseModel.initialize(self, opt)
        self.set_encoders_and_decoders(opt)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_a = networks.define_D(opt.output_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)
            self.netD_b = networks.define_D(opt.output_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)

            if not opt.dont_load_pretrained_autoencoder:
                which_epoch = opt.which_epoch
                self.load_network(self.netEnc_b, 'Enc_b', which_epoch)
                self.load_network(self.netDec_b, 'Dec_b', which_epoch)
                self.load_network(self.netEnc_shared, 'Enc_shared', which_epoch)
                self.load_network(self.netDec_shared, 'Dec_shared', which_epoch)

        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.netEnc_a, 'Enc_a', which_epoch)
            self.load_network(self.netDec_a, 'Dec_a', which_epoch)
            self.load_network(self.netEnc_b, 'Enc_b', which_epoch)
            self.load_network(self.netDec_b, 'Dec_b', which_epoch)
            self.load_network(self.netEnc_shared, 'Enc_shared', which_epoch)
            self.load_network(self.netDec_shared, 'Dec_shared', which_epoch)
            if self.isTrain:
                self.load_network(self.netD_a, 'D_a', which_epoch)
                self.load_network(self.netD_b, 'D_b', which_epoch)

        if self.isTrain:
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)

            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionIdt = torch.nn.L1Loss()

            # initialize optimizers
            self.optimizer_Enc_a = torch.optim.Adam(self.netEnc_a.parameters(),
                                                    lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_Dec_a = torch.optim.Adam(self.netDec_a.parameters(),
                                                    lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_Enc_b = torch.optim.Adam(
                itertools.chain(self.netEnc_b.parameters(), self.netEnc_shared.parameters()),
                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_Dec_b = torch.optim.Adam(
                itertools.chain(self.netDec_b.parameters(), self.netDec_shared.parameters()),
                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_a = torch.optim.Adam(self.netD_a.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_b = torch.optim.Adam(self.netD_b.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

            self.optimizers = []
            self.schedulers = []
            self.optimizers.append(self.optimizer_Enc_a)
            self.optimizers.append(self.optimizer_Dec_a)
            self.optimizers.append(self.optimizer_Enc_b)
            self.optimizers.append(self.optimizer_Dec_b)
            self.optimizers.append(self.optimizer_D_a)
            self.optimizers.append(self.optimizer_D_b)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

        print('---------- Networks initialized -------------')
        networks.print_network(self.netEnc_a)
        networks.print_network(self.netDec_a)
        networks.print_network(self.netEnc_b)
        networks.print_network(self.netDec_b)
        networks.print_network(self.netEnc_shared)
        networks.print_network(self.netDec_shared)
        if self.isTrain:
            networks.print_network(self.netD_a)
            networks.print_network(self.netD_b)
        print('-----------------------------------------------')

    def set_input(self, input):
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        if len(self.gpu_ids) > 0:
            input_A = input_A.cuda(self.gpu_ids[0], async=True)
            input_B = input_B.cuda(self.gpu_ids[0], async=True)
        self.input_A = input_A
        self.input_B = input_B
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        self.real_A = Variable(self.input_A)
        self.real_B = Variable(self.input_B)

    def test(self):
        real_A = Variable(self.input_A, volatile=True)
        real_B = Variable(self.input_B, volatile=True)
        enc_a = self.netEnc_shared(self.netEnc_a(real_A))
        enc_b = self.netEnc_shared(self.netEnc_b(real_B))

        fake_AA = self.netDec_a(self.netDec_shared(enc_a))
        fake_AB = self.netDec_b(self.netDec_shared(enc_a))
        fake_BB = self.netDec_b(self.netDec_shared(enc_b))

        enc_ab = self.netEnc_shared(self.netEnc_b(fake_AB))
        fake_ABA = self.netDec_a(self.netDec_shared(enc_ab))

        self.fake_AA = fake_AA.data
        self.fake_AB = fake_AB.data
        self.fake_BB = fake_BB.data
        self.fake_ABA = fake_ABA.data

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def backward_D_basic(self, netD, real, fake):
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D(self):
        fake_AB = self.fake_B_pool.query(self.fake_AB)
        loss_D_ab = self.backward_D_basic(self.netD_b, self.real_B, fake_AB)
        self.loss_D_ab = loss_D_ab.data[0]

        fake_BB = self.fake_B_pool.query(self.fake_BB)
        loss_D_bb = self.backward_D_basic(self.netD_b, self.real_B, fake_BB)
        self.loss_D_bb = loss_D_bb.data[0]

    def backward_G(self):
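        # encode A and B through the shared layers, then decode into
        # reconstructions (AA, BB), the translation (AB) and the cycle (ABA)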
        enc_a = self.netEnc_shared(self.netEnc_a(self.real_A))
        enc_b = self.netEnc_shared(self.netEnc_b(self.real_B))
        fake_AA = self.netDec_a(self.netDec_shared(enc_a))
        fake_AB = self.netDec_b(self.netDec_shared(enc_a))
        fake_BB = self.netDec_b(self.netDec_shared(enc_b))
        enc_ab = self.netEnc_shared(self.netEnc_b(fake_AB))
        fake_ABA = self.netDec_a(self.netDec_shared(enc_ab))

        pred_fake_AB = self.netD_b(fake_AB)
        loss_Gan_AB = self.criterionGAN(pred_fake_AB, True)
        loss_idt_A = self.criterionIdt(fake_AA, self.real_A)
        loss_cycle_A = self.opt.lambda_A * self.criterionIdt(fake_ABA, self.real_A)
        loss_idt_B = self.criterionIdt(fake_BB, self.real_B)
        pred_fake_BB = self.netD_b(fake_BB)
        loss_Gan_BB = self.criterionGAN(pred_fake_BB, True)
        loss_kl_B = self.opt.kl_lambda * self._compute_kl(enc_b)

        # combined losses
        loss_G_B = loss_idt_B + loss_kl_B + loss_Gan_BB
        loss_G_A = loss_Gan_AB + loss_cycle_A + loss_idt_A

        self.fake_AA = fake_AA.data
        self.fake_BB = fake_BB.data
        self.fake_AB = fake_AB.data
        self.fake_ABA = fake_ABA.data
        self.loss_cycle_A = loss_cycle_A.data[0]
        self.loss_Gan_AB = loss_Gan_AB.data[0]
        self.loss_Gan_BB = loss_Gan_BB.data[0]
        self.loss_idt_A = loss_idt_A.data[0]
        self.loss_idt_B = loss_idt_B.data[0]
        self.loss_kl_B = loss_kl_B.data[0]

        return loss_G_A, loss_G_B

    def optimize_parameters(self):
        # forward
        self.forward()
        loss_G_A, loss_G_B = self.backward_G()

        # A loss updates
        self.optimizer_Enc_a.zero_grad()
        self.optimizer_Dec_a.zero_grad()
        loss_G_A.backward(retain_graph=True)
        self.optimizer_Enc_a.step()
        self.optimizer_Dec_a.step()

        # B loss updates
        self.optimizer_Enc_b.zero_grad()
        self.optimizer_Dec_b.zero_grad()
        loss_G_B.backward()
        self.optimizer_Enc_b.step()
        self.optimizer_Dec_b.step()

        # D
        self.optimizer_D_a.zero_grad()
        self.optimizer_D_b.zero_grad()
        self.backward_D()
        self.optimizer_D_a.step()
        self.optimizer_D_b.step()

    def get_current_errors(self):
        ret_errors = OrderedDict(
            [('D_ab', self.loss_D_ab), ('D_bb', self.loss_D_bb),
             ('G_AB', self.loss_Gan_AB), ('G_BB', self.loss_Gan_BB),
             ('Idt_B', self.loss_idt_B), ('Idt_A', self.loss_idt_A),
             ('Cycle_A', self.loss_cycle_A), ('Kl_B', self.loss_kl_B), ])
        return ret_errors

    def get_current_visuals(self):
        real_A = util.tensor2im(self.input_A)
        real_B = util.tensor2im(self.input_B)
        fake_BB = util.tensor2im(self.fake_BB)
        fake_AB = util.tensor2im(self.fake_AB)
        fake_AA = util.tensor2im(self.fake_AA)
        fake_ABA = util.tensor2im(self.fake_ABA)

        ret_visuals = OrderedDict(
            [('real_B', real_B), ('fake_BB', fake_BB),
             ('real_A', real_A), ('fake_AA', fake_AA), ('fake_AB', fake_AB), ('fake_ABA', fake_ABA), ])
        return ret_visuals

    def save(self, label):
        self.save_network(self.netEnc_a, 'Enc_a', label, self.gpu_ids)
        self.save_network(self.netDec_a, 'Dec_a', label, self.gpu_ids)
        self.save_network(self.netD_a, 'D_a', label, self.gpu_ids)
        self.save_network(self.netEnc_b, 'Enc_b', label, self.gpu_ids)
        self.save_network(self.netDec_b, 'Dec_b', label, self.gpu_ids)
        self.save_network(self.netD_b, 'D_b', label, self.gpu_ids)
        self.save_network(self.netEnc_shared, 'Enc_shared', label, self.gpu_ids)
        self.save_network(self.netDec_shared, 'Dec_shared', label, self.gpu_ids)


================================================
FILE: drawing_and_style_transfer/models/test_model.py
================================================
from torch.autograd import Variable
from collections import OrderedDict
import util.util as util
from .base_model import BaseModel
from . import networks


class TestModel(BaseModel):
    def name(self):
        return 'TestModel'

    def initialize(self, opt):
        assert (not opt.isTrain)
        BaseModel.initialize(self, opt)
        self.netG = networks.define_G(opt.input_nc, opt.output_nc,
                                      opt.ngf, opt.which_model_netG,
                                      opt.norm, not opt.no_dropout,
                                      opt.init_type,
                                      self.gpu_ids)
        which_epoch = opt.which_epoch
        self.load_network(self.netG, 'G', which_epoch)

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG)
        print('-----------------------------------------------')

    def set_input(self, input):
        # we need to use single_dataset mode
        input_A = input['A']
        if len(self.gpu_ids) > 0:
            input_A = input_A.cuda(self.gpu_ids[0], async=True)
        self.input_A = input_A
        self.image_paths = input['A_paths']

    def test(self):
        self.real_A = Variable(self.input_A, volatile=True)
        self.fake_B = self.netG(self.real_A)

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def get_current_visuals(self):
        real_A = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        return OrderedDict([('real_A', real_A), ('fake_B', fake_B)])


================================================
FILE: drawing_and_style_transfer/options/__init__.py
================================================


================================================
FILE: drawing_and_style_transfer/options/base_options.py
================================================
import argparse
import os
from util import util
import torch


class BaseOptions():
    def __init__(self):
        self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        self.initialized = False

    def initialize(self):
        self.parser.add_argument('--dataroot', required=True,
                                 help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
        self.parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
        self.parser.add_argument('--max_items_A', type=int, default=-1,
                                 help='max number of items for domain A, -1 indicates no maximum')
        self.parser.add_argument('--max_items_B', type=int, default=-1,
                                 help='max number of items for domain B, -1 indicates no maximum')
        self.parser.add_argument('--start', type=int, default=0,
                                 help='starting index of items of domain A, after sorting')
        self.parser.add_argument('--loadSize', type=int, default=286, help='scale images to this size')
        self.parser.add_argument('--fineSize', type=int, default=256, help='then crop to this size')
        self.parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels')
        self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')
        self.parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
        self.parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
        self.parser.add_argument('--which_model_netD', type=str, default='basic', help='selects model to use for netD')
        self.parser.add_argument('--which_model_netG', type=str, default='resnet_9blocks',
                                 help='selects model to use for netG and netED')
        self.parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers')
        self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0  0,1,2, 0,2. use -1 for CPU')
        self.parser.add_argument('--name', type=str, default='experiment_name',
                                 help='name of the experiment. It decides where to store samples and models')
        self.parser.add_argument('--dataset_mode', type=str, default='unaligned',
                                 help='chooses how datasets are loaded. [unaligned | aligned | single]')
        self.parser.add_argument('--model', type=str, default='ost',
                                 help='chooses which model to use. ost, autoencoder, test.')
        self.parser.add_argument('--which_direction', type=str, default='AtoB', help='AtoB or BtoA')
        self.parser.add_argument('--nThreads', default=2, type=int, help='# threads for loading data')
        self.parser.add_argument('--load_dir', type=str, default='./checkpoints', help='models are loaded here')
        self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
        self.parser.add_argument('--norm', type=str, default='instance',
                                 help='instance normalization or batch normalization')
        self.parser.add_argument('--serial_batches', action='store_true',
                                 help='if true, takes images in order to make batches, otherwise takes them randomly')
        self.parser.add_argument('--display_winsize', type=int, default=256, help='display window size')
        self.parser.add_argument('--display_id', type=int, default=1, help='window id of the web display')
        self.parser.add_argument('--display_server', type=str, default="http://localhost",
                                 help='visdom server of the web display')
        self.parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
        self.parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator')
        self.parser.add_argument('--resize_or_crop', type=str, default='resize_and_crop',
                                 help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]')
        self.parser.add_argument('--no_flip_and_rotation', action='store_true',
                                 help='if specified, do not flip and rotate the images for data augmentation')
        self.parser.add_argument('--rotation_degree', type=int, default=7, help='rotation degree used for augmentation')
        self.parser.add_argument('--init_type', type=str, default='normal',
                                 help='network initialization [normal|xavier|kaiming|orthogonal]')
        self.parser.add_argument('--A', type=str, default='A',
                                 help='used to exchange dataset A for B by setting the value to B')
        self.parser.add_argument('--B', type=str, default='B',
                                 help='used to exchange dataset B for A by setting the value to A')
        self.parser.add_argument('--n_downsampling', type=int, default=2,
                                 help="number of downsampling/upsampling convolutional/deconvolutional layers")
        self.parser.add_argument('--num_unshared', type=int, default=1,
                                 help="number of unshared encoder/decoder layers, not including input and final layers")
        self.parser.add_argument('--num_res_blocks_unshared', type=int, default=0,
                                 help='number of unshared resnet blocks')
        self.parser.add_argument('--num_res_blocks_shared', type=int, default=6, help='number of shared resnet blocks')

        self.initialized = True

    def parse(self):
        if not self.initialized:
            self.initialize()
        self.opt = self.parser.parse_args()
        self.opt.isTrain = self.isTrain  # train or test

        str_ids = self.opt.gpu_ids.split(',')
        self.opt.gpu_ids = []
        for str_id in str_ids:
            id = int(str_id)
            if id >= 0:
                self.opt.gpu_ids.append(id)

        # set gpu ids
        if len(self.opt.gpu_ids) > 0:
            torch.cuda.set_device(self.opt.gpu_ids[0])

        args = vars(self.opt)

        print('------------ Options -------------')
        for k, v in sorted(args.items()):
            print('%s: %s' % (str(k), str(v)))
        print('-------------- End ----------------')

        # save to the disk
        expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name)
        util.mkdirs(expr_dir)
        file_name = os.path.join(expr_dir, 'opt.txt')
        with open(file_name, 'wt') as opt_file:
            opt_file.write('------------ Options -------------\n')
            for k, v in sorted(args.items()):
                opt_file.write('%s: %s\n' % (str(k), str(v)))
            opt_file.write('-------------- End ----------------\n')
        return self.opt


================================================
FILE: drawing_and_style_transfer/options/test_options.py
================================================
from .base_options import BaseOptions


class TestOptions(BaseOptions):
    def initialize(self):
        BaseOptions.initialize(self)
        self.parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
        self.parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
        self.parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
        self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
        self.parser.add_argument('--which_epoch', type=str, default='latest',
                                 help='which epoch to load? set to latest to use latest cached model')
        self.parser.add_argument('--how_many', type=int, default=50, help='how many test images to run')
        self.isTrain = False


================================================
FILE: drawing_and_style_transfer/options/train_options.py
================================================
from .base_options import BaseOptions


class TrainOptions(BaseOptions):
    def initialize(self):
        BaseOptions.initialize(self)
        self.parser.add_argument('--display_freq', type=int, default=100,
                                 help='frequency of showing training results on screen')
        self.parser.add_argument('--display_single_pane_ncols', type=int, default=0,
                                 help='if positive, display all images in a single visdom web panel with certain number of images per row.')
        self.parser.add_argument('--update_html_freq', type=int, default=1000,
                                 help='frequency of saving training results to html')
        self.parser.add_argument('--print_freq', type=int, default=100,
                                 help='frequency of showing training results on console')
        self.parser.add_argument('--save_latest_freq', type=int, default=10000,
                                 help='frequency of saving the latest results')
        self.parser.add_argument('--save_epoch_freq', type=int, default=10,
                                 help='frequency of saving checkpoints at the end of epochs')
        self.parser.add_argument('--continue_train', action='store_true',
                                 help='continue training: load the latest model')
        self.parser.add_argument('--epoch_count', type=int, default=1,
                                 help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
        self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
        self.parser.add_argument('--which_epoch', type=str, default='latest',
                                 help='which epoch to load? set to latest to use latest cached model')
        self.parser.add_argument('--niter', type=int, default=60, help='# of iter at starting learning rate')
        self.parser.add_argument('--niter_decay', type=int, default=20,
                                 help='# of iter to linearly decay learning rate to zero')
        self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
        self.parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
        self.parser.add_argument('--no_lsgan', action='store_true',
                                 help='do *not* use least square GAN; if specified, use vanilla GAN')
        self.parser.add_argument('--kl_lambda', type=float, default=0.1, help='weight for kl loss')
        self.parser.add_argument('--lambda_A', type=float, default=10.0, help='weight for cycle loss (A -> B -> A)')
        self.parser.add_argument('--lambda_B', type=float, default=10.0, help='weight for cycle loss (B -> A -> B)')
        self.parser.add_argument('--pool_size', type=int, default=50,
                                 help='the size of image buffer that stores previously generated images')
        self.parser.add_argument('--dont_load_pretrained_autoencoder', action='store_true',
                                 help='do not load pretrained autoencoder')
        self.parser.add_argument('--no_html', action='store_true',
                                 help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
        self.parser.add_argument('--lr_policy', type=str, default='lambda',
                                 help='learning rate policy: lambda|step|plateau')
        self.parser.add_argument('--lr_decay_iters', type=int, default=50,
                                 help='multiply by a gamma every lr_decay_iters iterations')

        self.isTrain = True


================================================
FILE: drawing_and_style_transfer/scripts/test_ost.sh
================================================
# images to cityscapes
python test.py --dataroot=./datasets/cityscapes/ --name=cityscapes_ost --load_dir=cityscapes_ost --model=ost --no_dropout --n_downsampling=3 --num_unshared=3 --start=0 --max_items_A=1
# cityscapes to images
python test.py --dataroot=./datasets/cityscapes/ --name=cityscapes_ost_reverse --load_dir=cityscapes_ost_reverse --model=ost --no_dropout --n_downsampling=2 --num_unshared=2 --start=0 --max_items_A=1 --A='B' --B='A'

# images to facades
python test.py --dataroot=./datasets/facades/ --name=facades_ost --load_dir=facades_ost --model=ost --no_dropout --n_downsampling=2 --num_unshared=2 --start=0 --max_items_A=1
# facades to images
python test.py --dataroot=./datasets/facades/ --name=facades_ost_reverse --load_dir=facades_ost_reverse --model=ost --no_dropout --n_downsampling=2 --num_unshared=2 --start=0 --max_items_A=1 --A='B' --B='A'

# aerial view to maps
python test.py --dataroot=./datasets/maps/ --name=maps_ost --load_dir=maps_ost --model=ost --no_dropout --n_downsampling=2 --num_unshared=2 --start=0 --max_items_A=1
# maps to aerial view
python test.py --dataroot=./datasets/maps/ --name=maps_ost_reverse --load_dir=maps_ost_reverse --model=ost --no_dropout --n_downsampling=2 --num_unshared=2 --start=0 --max_items_A=1 --A='B' --B='A'

# monet2photo
python test.py --dataroot=./datasets/monet2photo/ --name=monet2photo_ost --load_dir=monet2photo_ost --model=ost --no_dropout --n_downsampling=2 --num_unshared=0 --start=0 --max_items_A=1
# photo2monet
python test.py --dataroot=./datasets/monet2photo/ --name=monet2photo_ost_reverse --load_dir=monet2photo_ost_reverse --model=ost --no_dropout --n_downsampling=2 --num_unshared=0 --start=0 --max_items_A=1 --A='B' --B='A'

# summer2winter
python test.py --dataroot=./datasets/summer2winter_yosemite/ --name=summer2winter_yosemite_ost --load_dir=summer2winter_yosemite_ost --model=ost --no_dropout --n_downsampling=2 --num_unshared=0 --start=0 --max_items_A=1
# winter2summer
python test.py --dataroot=./datasets/summer2winter_yosemite/ --name=summer2winter_yosemite_ost_reverse --load_dir=summer2winter_yosemite_ost_reverse --model=ost --no_dropout --n_downsampling=2 --num_unshared=0 --start=0 --max_items_A=1 --A='B' --B='A'

================================================
FILE: drawing_and_style_transfer/scripts/train_autoencoder.sh
================================================
# images to cityscapes
python train.py --dataroot=./datasets/cityscapes/trainB --name=cityscapes_autoencoder --model=autoencoder --dataset_mode=single --no_dropout --n_downsampling=3 --num_unshared=3
# cityscapes to images
python train.py --dataroot=./datasets/cityscapes/trainA --name=cityscapes_autoencoder_reverse --model=autoencoder --dataset_mode=single --no_dropout --n_downsampling=2 --num_unshared=2

# images to facades
python train.py --dataroot=./datasets/facades/trainB --name=facades_autoencoder --model=autoencoder --dataset_mode=single --no_dropout --n_downsampling=2 --num_unshared=2
# facades to images
python train.py --dataroot=./datasets/facades/trainA --name=facades_autoencoder_reverse --model=autoencoder --dataset_mode=single --no_dropout --n_downsampling=2 --num_unshared=2

# aerial view to maps
python train.py --dataroot=./datasets/maps/trainB --name=maps_autoencoder --model=autoencoder --dataset_mode=single --no_dropout --n_downsampling=2 --num_unshared=2
# maps to aerial view
python train.py --dataroot=./datasets/maps/trainA --name=maps_autoencoder_reverse --model=autoencoder --dataset_mode=single --no_dropout --n_downsampling=2 --num_unshared=2

# monet2photo
python train.py --dataroot=./datasets/monet2photo/trainB --name=monet2photo_autoencoder --model=autoencoder --dataset_mode=single --no_dropout --n_downsampling=2 --num_unshared=0
# photo2monet
python train.py --dataroot=./datasets/monet2photo/trainA --name=monet2photo_autoencoder_reverse --model=autoencoder --dataset_mode=single --no_dropout --n_downsampling=2 --num_unshared=0

# summer2winter
python train.py --dataroot=./datasets/summer2winter_yosemite/trainB --name=summer2winter_yosemite_autoencoder --model=autoencoder --dataset_mode=single --no_dropout --n_downsampling=2 --num_unshared=0
# winter2summer
python train.py --dataroot=./datasets/summer2winter_yosemite/trainA --name=summer2winter_yosemite_autoencoder_reverse --model=autoencoder --dataset_mode=single --no_dropout --n_downsampling=2 --num_unshared=0

================================================
FILE: drawing_and_style_transfer/scripts/train_ost.sh
================================================
# images to cityscapes
python train.py --dataroot=./datasets/cityscapes/ --name=cityscapes_ost --load_dir=cityscapes_autoencoder --model=ost --no_dropout --n_downsampling=3 --num_unshared=3 --start=0 --max_items_A=1
# cityscapes to images
python train.py --dataroot=./datasets/cityscapes/ --name=cityscapes_ost_reverse --load_dir=cityscapes_autoencoder_reverse --model=ost --no_dropout --n_downsampling=2 --num_unshared=2 --start=0 --max_items_A=1 --A='B' --B='A'

# images to facades
python train.py --dataroot=./datasets/facades/ --name=facades_ost --load_dir=facades_autoencoder --model=ost --no_dropout --n_downsampling=2 --num_unshared=2 --start=0 --max_items_A=1
# facades to images
python train.py --dataroot=./datasets/facades/ --name=facades_ost_reverse --load_dir=facades_autoencoder_reverse --model=ost --no_dropout --n_downsampling=2 --num_unshared=2 --start=0 --max_items_A=1 --A='B' --B='A'

# aerial view to maps
python train.py --dataroot=./datasets/maps/ --name=maps_ost --load_dir=maps_autoencoder --model=ost --no_dropout --n_downsampling=2 --num_unshared=2 --start=0 --max_items_A=1
# maps to aerial view
python train.py --dataroot=./datasets/maps/ --name=maps_ost_reverse --load_dir=maps_autoencoder_reverse --model=ost --no_dropout --n_downsampling=2 --num_unshared=2 --start=0 --max_items_A=1 --A='B' --B='A'

# monet2photo
python train.py --dataroot=./datasets/monet2photo/ --name=monet2photo_ost --load_dir=monet2photo_autoencoder --model=ost --no_dropout --n_downsampling=2 --num_unshared=0 --start=0 --max_items_A=1
# photo2monet
python train.py --dataroot=./datasets/monet2photo/ --name=monet2photo_ost_reverse --load_dir=monet2photo_autoencoder_reverse --model=ost --no_dropout --n_downsampling=2 --num_unshared=0 --start=0 --max_items_A=1 --A='B' --B='A'

# summer2winter
python train.py --dataroot=./datasets/summer2winter_yosemite/ --name=summer2winter_yosemite_ost --load_dir=summer2winter_yosemite_autoencoder --model=ost --no_dropout --n_downsampling=2 --num_unshared=0 --start=0 --max_items_A=1
# winter2summer
python train.py --dataroot=./datasets/summer2winter_yosemite/ --name=summer2winter_yosemite_ost_reverse --load_dir=summer2winter_yosemite_autoencoder_reverse --model=ost --no_dropout --n_downsampling=2 --num_unshared=0 --start=0 --max_items_A=1 --A='B' --B='A'

================================================
FILE: drawing_and_style_transfer/test.py
================================================
import os
from options.test_options import TestOptions
from data import CreateDataLoader
from models import create_model
from util.visualizer import Visualizer
from util import html

if __name__ == '__main__':
    opt = TestOptions().parse()
    opt.nThreads = 1  # test code only supports nThreads = 1
    opt.batchSize = 1  # test code only supports batchSize = 1
    opt.serial_batches = True  # no shuffle
    opt.no_flip = True  # no flip

    # Test only on the samples x in A that were used for training
    opt.phase = 'train'
    if opt.max_items_A >= 0:
        opt.max_items_B = opt.max_items_A
    if opt.max_items_B >= 0:
        opt.max_items_A = opt.max_items_B

    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    model = create_model(opt)
    visualizer = Visualizer(opt)
    # create website
    web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
    webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
    # test
    for i, data in enumerate(dataset):
        if i >= opt.how_many:
            break
        model.set_input(data)
        model.test()
        visuals = model.get_current_visuals()
        img_path = model.get_image_paths()
        print('%04d: process image... %s' % (i, img_path))
        visualizer.save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, index=i, split=0)

    webpage.save()


================================================
FILE: drawing_and_style_transfer/train.py
================================================
import time
from options.train_options import TrainOptions
from data import CreateDataLoader
from models import create_model
from util.visualizer import Visualizer

if __name__ == '__main__':
    opt = TrainOptions().parse()

    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)

    model = create_model(opt)
    visualizer = Visualizer(opt)
    total_steps = 0

    for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        iter_data_time = time.time()
        epoch_iter = 0

        for i, data in enumerate(dataset):
            iter_start_time = time.time()
            if total_steps % opt.print_freq == 0:
                t_data = iter_start_time - iter_data_time
            visualizer.reset()
            total_steps += opt.batchSize
            epoch_iter += opt.batchSize
            model.set_input(data)
            model.optimize_parameters()

            if total_steps % opt.display_freq == 0:
                save_result = total_steps % opt.update_html_freq == 0
                visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)

            if total_steps % opt.print_freq == 0:
                errors = model.get_current_errors()
                t = (time.time() - iter_start_time) / opt.batchSize
                visualizer.print_current_errors(epoch, epoch_iter, errors, t, t_data)
                if opt.display_id > 0:
                    visualizer.plot_current_errors(epoch, float(epoch_iter) / dataset_size, opt, errors)

            if total_steps % opt.save_latest_freq == 0:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, total_steps))
                model.save('latest')

            iter_data_time = time.time()
        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_steps))
            model.save('latest')
            model.save(epoch)

        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
        model.update_learning_rate()


================================================
FILE: drawing_and_style_transfer/util/__init__.py
================================================


================================================
FILE: drawing_and_style_transfer/util/get_data.py
================================================
from __future__ import print_function
import os
import tarfile
import requests
from warnings import warn
from zipfile import ZipFile
from bs4 import BeautifulSoup
from os.path import abspath, isdir, join, basename


class GetData(object):
    """

    Download CycleGAN or Pix2Pix Data.

    Args:
        technique : str
            One of: 'cyclegan' or 'pix2pix'.
        verbose : bool
            If True, print additional information.

    Examples:
        >>> from util.get_data import GetData
        >>> gd = GetData(technique='cyclegan')
        >>> new_data_path = gd.get(save_path='./datasets')  # options will be displayed.

    """

    def __init__(self, technique='cyclegan', verbose=True):
        url_dict = {
            'pix2pix': 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets',
            'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets'
        }
        self.url = url_dict.get(technique.lower())
        self._verbose = verbose

    def _print(self, text):
        if self._verbose:
            print(text)

    @staticmethod
    def _get_options(r):
        soup = BeautifulSoup(r.text, 'lxml')
        options = [h.text for h in soup.find_all('a', href=True)
                   if h.text.endswith(('.zip', 'tar.gz'))]
        return options

    def _present_options(self):
        r = requests.get(self.url)
        options = self._get_options(r)
        print('Options:\n')
        for i, o in enumerate(options):
            print("{0}: {1}".format(i, o))
        choice = input("\nPlease enter the number of the "
                       "dataset above you wish to download:")
        return options[int(choice)]

    def _download_data(self, dataset_url, save_path):
        if not isdir(save_path):
            os.makedirs(save_path)

        base = basename(dataset_url)
        temp_save_path = join(save_path, base)

        with open(temp_save_path, "wb") as f:
            r = requests.get(dataset_url)
            f.write(r.content)

        if base.endswith('.tar.gz'):
            obj = tarfile.open(temp_save_path)
        elif base.endswith('.zip'):
            obj = ZipFile(temp_save_path, 'r')
        else:
            raise ValueError("Unknown File Type: {0}.".format(base))

        self._print("Unpacking Data...")
        obj.extractall(save_path)
        obj.close()
        os.remove(temp_save_path)

    def get(self, save_path, dataset=None):
        """

        Download a dataset.

        Args:
            save_path : str
                A directory to save the data to.
            dataset : str, optional
                A specific dataset to download.
                Note: this must include the file extension.
                If None, options will be presented for you
                to choose from.

        Returns:
            save_path_full : str
                The absolute path to the downloaded data.

        """
        if dataset is None:
            selected_dataset = self._present_options()
        else:
            selected_dataset = dataset

        save_path_full = join(save_path, selected_dataset.split('.')[0])

        if isdir(save_path_full):
            warn("\n'{0}' already exists. Skipping download.".format(
                save_path_full))
        else:
            self._print('Downloading Data...')
            url = "{0}/{1}".format(self.url, selected_dataset)
            self._download_data(url, save_path=save_path)

        return abspath(save_path_full)


================================================
FILE: drawing_and_style_transfer/util/html.py
================================================
import dominate
from dominate.tags import *
import os


class HTML:
    def __init__(self, web_dir, title, reflesh=0):
        self.title = title
        self.web_dir = web_dir
        self.img_dir = os.path.join(self.web_dir, 'images')
        if not os.path.exists(self.web_dir):
            os.makedirs(self.web_dir)
        if not os.path.exists(self.img_dir):
            os.makedirs(self.img_dir)
        # print(self.img_dir)

        self.doc = dominate.document(title=title)
        if reflesh > 0:
            with self.doc.head:
                meta(http_equiv="refresh", content=str(reflesh))

    def get_image_dir(self):
        return self.img_dir

    def add_header(self, text):
        with self.doc:
            h3(text)

    def add_table(self, border=1):
        self.t = table(border=border, style="table-layout: fixed;")
        self.doc.add(self.t)

    def add_images(self, ims, txts, links, width=400):
        self.add_table()
        with self.t:
            with tr():
                for im, txt, link in zip(ims, txts, links):
                    with td(style="word-wrap: break-word;", halign="center", valign="top"):
                        with p():
                            with a(href=os.path.join('images', link)):
                                img(style="width:%dpx" % width, src=os.path.join('images', im))
                            br()
                            p(txt)

    def save(self):
        html_file = '%s/index.html' % self.web_dir
        with open(html_file, 'wt') as f:
            f.write(self.doc.render())


if __name__ == '__main__':
    html = HTML('web/', 'test_html')
    html.add_header('hello world')

    ims = []
    txts = []
    links = []
    for n in range(4):
        ims.append('image_%d.png' % n)
        txts.append('text_%d' % n)
        links.append('image_%d.png' % n)
    html.add_images(ims, txts, links)
    html.save()


================================================
FILE: drawing_and_style_transfer/util/image_pool.py
================================================
import random
import torch
from torch.autograd import Variable


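class ImagePool():
    """Buffer of previously seen images.

    While the pool is not full, query() stores each incoming image and
    returns it unchanged. Once the pool is full, each incoming image is,
    with probability 0.5, swapped with a randomly chosen stored image
    (the stored image is returned); otherwise it is returned as-is.
    A pool_size of 0 disables the buffer.
    """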
    def __init__(self, pool_size):
        self.pool_size = pool_size
        if self.pool_size > 0:
            self.num_imgs = 0
            self.images = []

    def query(self, images):
        if self.pool_size == 0:
            return Variable(images)
        return_images = []
        for image in images:
            image = torch.unsqueeze(image, 0)
            if self.num_imgs < self.pool_size:
                self.num_imgs = self.num_imgs + 1
                self.images.append(image)
                return_images.append(image)
            else:
                p = random.uniform(0, 1)
                if p > 0.5:
                    random_id = random.randint(0, self.pool_size - 1)
                    tmp = self.images[random_id].clone()
                    self.images[random_id] = image
                    return_images.append(tmp)
                else:
                    return_images.append(image)
        return_images = Variable(torch.cat(return_images, 0))
        return return_images


================================================
FILE: drawing_and_style_transfer/util/util.py
================================================
from __future__ import print_function
import torch
import numpy as np
from PIL import Image
import os


# Converts a Tensor into a Numpy array
# |imtype|: the desired type of the converted numpy array
def tensor2im(image_tensor, imtype=np.uint8):
    image_numpy = image_tensor[0].cpu().float().numpy()
    if image_numpy.shape[0] == 1:
        image_numpy = np.tile(image_numpy, (3, 1, 1))
    image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
    return image_numpy.astype(imtype)


def diagnose_network(net, name='network'):
    mean = 0.0
    count = 0
    for param in net.parameters():
        if param.grad is not None:
            mean += torch.mean(torch.abs(param.grad.data))
            count += 1
    if count > 0:
        mean = mean / count
    print(name)
    print(mean)


def save_image(image_numpy, image_path):
    image_pil = Image.fromarray(image_numpy)
    image_pil.save(image_path)


def print_numpy(x, val=True, shp=False):
    x = x.astype(np.float64)
    if shp:
        print('shape,', x.shape)
    if val:
        x = x.flatten()
        print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
            np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))


def mkdirs(paths):
    if isinstance(paths, list):
        for path in paths:
            mkdir(path)
    else:
        mkdir(paths)


def mkdir(path):
    if not os.path.exists(path):
        os.makedirs(path)


================================================
FILE: drawing_and_style_transfer/util/visualizer.py
================================================
import numpy as np
import os
import ntpath
import time
from . import util
from . import html
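try:
    from scipy.misc import imresize
except ImportError:
    # scipy.misc.imresize was removed in SciPy 1.3; fall back to a minimal
    # Pillow-based replacement for the uint8 image arrays used below.
    from PIL import Image

    def imresize(arr, size, interp='bicubic'):
        modes = {'nearest': Image.NEAREST, 'bilinear': Image.BILINEAR,
                 'bicubic': Image.BICUBIC}
        height, width = size
        return np.array(Image.fromarray(arr).resize((width, height), modes[interp]))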


class Visualizer():
    def __init__(self, opt):
        # self.opt = opt
        self.display_id = opt.display_id
        self.use_html = opt.isTrain and not opt.no_html
        self.win_size = opt.display_winsize
        self.name = opt.name
        self.opt = opt
        self.saved = False
        if self.display_id > 0:
            import visdom
            self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port)

        if self.use_html:
            self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
            self.img_dir = os.path.join(self.web_dir, 'images')
            print('create web directory %s...' % self.web_dir)
            util.mkdirs([self.web_dir, self.img_dir])
        self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
        with open(self.log_name, "a") as log_file:
            now = time.strftime("%c")
            log_file.write('================ Training Loss (%s) ================\n' % now)

    def reset(self):
        self.saved = False

    # |visuals|: dictionary of images to display or save
    def display_current_results(self, visuals, epoch, save_result):
        if self.display_id > 0:  # show images in the browser
            ncols = self.opt.display_single_pane_ncols
            if ncols > 0:
                h, w = next(iter(visuals.values())).shape[:2]
                table_css = """<style>
                        table {border-collapse: separate; border-spacing:4px; white-space:nowrap; text-align:center}
                        table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
                        </style>""" % (w, h)
                title = self.name
                label_html = ''
                label_html_row = ''
                nrows = int(np.ceil(len(visuals.items()) / ncols))
                images = []
                idx = 0
                for label, image_numpy in visuals.items():
                    label_html_row += '<td>%s</td>' % label
                    images.append(image_numpy.transpose([2, 0, 1]))
                    idx += 1
                    if idx % ncols == 0:
                        label_html += '<tr>%s</tr>' % label_html_row
                        label_html_row = ''
                white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
                while idx % ncols != 0:
                    images.append(white_image)
                    label_html_row += '<td></td>'
                    idx += 1
                if label_html_row != '':
                    label_html += '<tr>%s</tr>' % label_html_row
                # pane col = image row
                self.vis.images(images, nrow=ncols, win=self.display_id + 1,
                                padding=2, opts=dict(title=title + ' images'))
                label_html = '<table>%s</table>' % label_html
                self.vis.text(table_css + label_html, win=self.display_id + 2,
                              opts=dict(title=title + ' labels'))
            else:
                idx = 1
                for label, image_numpy in visuals.items():
                    self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label),
                                   win=self.display_id + idx)
                    idx += 1

        if self.use_html and (save_result or not self.saved):  # save images to a html file
            self.saved = True
            for label, image_numpy in visuals.items():
                img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
                util.save_image(image_numpy, img_path)
            # update website
            webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, reflesh=1)
            for n in range(epoch, 0, -1):
                webpage.add_header('epoch [%d]' % n)
                ims = []
                txts = []
                links = []

                for label, image_numpy in visuals.items():
                    img_path = 'epoch%.3d_%s.png' % (n, label)
                    ims.append(img_path)
                    txts.append(label)
                    links.append(img_path)
                webpage.add_images(ims, txts, links, width=self.win_size)
            webpage.save()

    # errors: dictionary of error labels and values
    def plot_current_errors(self, epoch, counter_ratio, opt, errors):
        if not hasattr(self, 'plot_data'):
            self.plot_data = {'X': [], 'Y': [], 'legend': list(errors.keys())}
        self.plot_data['X'].append(epoch + counter_ratio)
        self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']])
        self.vis.line(
            X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
            Y=np.array(self.plot_data['Y']),
            opts={
                'title': self.name + ' loss over time',
                'legend': self.plot_data['legend'],
                'xlabel': 'epoch',
                'ylabel': 'loss'},
            win=self.display_id)

    # errors: same format as |errors| of plot_current_errors
    def print_current_errors(self, epoch, i, errors, t, t_data):
        message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, i, t, t_data)
        for k, v in errors.items():
            message += '%s: %.3f ' % (k, v)

        print(message)
        with open(self.log_name, "a") as log_file:
            log_file.write('%s\n' % message)

    # save image to the disk
    def save_images(self, webpage, visuals, image_path, aspect_ratio=1.0, index=None, split=1):
        image_dir = webpage.get_image_dir()
        short_path = ntpath.basename(image_path[0])
        name = os.path.splitext(short_path)[0]
        if index is not None:
            name_splits = name.split("_")
            if split == 0:
                name = str(index)
            else:
                name = str(index) + "_" + name_splits[split]

        webpage.add_header(name)
        ims = []
        txts = []
        links = []

        for label, im in visuals.items():
            image_name = '%s_%s.png' % (name, label)
            save_path = os.path.join(image_dir, image_name)
            h, w, _ = im.shape
            if aspect_ratio > 1.0:
                im = imresize(im, (h, int(w * aspect_ratio)), interp='bicubic')
            if aspect_ratio < 1.0:
                im = imresize(im, (int(h / aspect_ratio), w), interp='bicubic')
            util.save_image(im, save_path)

            ims.append(image_name)
            txts.append(label)
            links.append(image_name)
        webpage.add_images(ims, txts, links, width=self.win_size)


================================================
FILE: mnist_to_svhn/data_loader.py
================================================
import torch
from torchvision import datasets
from torchvision import transforms


def get_loader(config):
    """Builds and returns Dataloader for MNIST and SVHN dataset."""

    transform_list = []

    if config.use_augmentation:
        transform_list.append(transforms.RandomHorizontalFlip())
        transform_list.append(transforms.RandomRotation(0.1))

    # transforms.Scale is deprecated in favor of transforms.Resize
    transform_list.append(transforms.Resize(config.image_size))
    transform_list.append(transforms.ToTensor())
    transform_list.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))

    transform_test = transforms.Compose([
        transforms.Resize(config.image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    transform_train = transforms.Compose(transform_list)

    svhn = datasets.SVHN(root=config.svhn_path, download=True, transform=transform_train, split='train')
    mnist = datasets.MNIST(root=config.mnist_path, download=True, transform=transform_train, train=True)

    svhn_test = datasets.SVHN(root=config.svhn_path, download=True, transform=transform_test, split='test')
    mnist_test = datasets.MNIST(root=config.mnist_path, download=True, transform=transform_test, train=False)

    svhn_loader = torch.utils.data.DataLoader(dataset=svhn,
                                              batch_size=config.batch_size,
                                              shuffle=config.shuffle,
                                              num_workers=config.num_workers)

    mnist_loader = torch.utils.data.DataLoader(dataset=mnist,
                                               batch_size=config.batch_size,
                                               shuffle=config.shuffle,
                                               num_workers=config.num_workers)

    svhn_test_loader = torch.utils.data.DataLoader(dataset=svhn_test,
                                                   batch_size=config.batch_size,
                                                   shuffle=False,
                                                   num_workers=config.num_workers)

    mnist_test_loader = torch.utils.data.DataLoader(dataset=mnist_test,
                                                    batch_size=config.batch_size,
                                                    shuffle=False,
                                                    num_workers=config.num_workers)

    return svhn_loader, mnist_loader, svhn_test_loader, mnist_test_loader


================================================
FILE: mnist_to_svhn/download.sh
================================================
mkdir -p mnist
mkdir -p svhn

wget -O svhn/train_32x32.mat http://ufldl.stanford.edu/housenumbers/train_32x32.mat
wget -O svhn/test_32x32.mat http://ufldl.stanford.edu/housenumbers/test_32x32.mat
wget -O svhn/extra_32x32.mat http://ufldl.stanford.edu/housenumbers/extra_32x32.mat

================================================
FILE: mnist_to_svhn/main_autoencoder.py
================================================
import argparse
import os
from torch.backends import cudnn

from solver_autoencoder import Solver
from data_loader import get_loader


def str2bool(v):
    # Compare for equality: `in ('true')` is a substring test on the string 'true'
    return v.lower() == 'true'


def main(config):
    svhn_loader, mnist_loader, svhn_test_loader, mnist_test_loader = get_loader(config)

    solver = Solver(config, svhn_loader, mnist_loader)
    cudnn.benchmark = True

    # create directories if they do not exist
    if not os.path.exists(config.model_path):
        os.makedirs(config.model_path)
    if not os.path.exists(config.sample_path):
        os.makedirs(config.sample_path)

    if config.mode == 'train':
        solver.train()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # model hyper-parameters
    parser.add_argument('--image_size', type=int, default=32)
    parser.add_argument('--g_conv_dim', type=int, default=64)
    parser.add_argument('--d_conv_dim', type=int, default=64)
    parser.add_argument('--num_classes', type=int, default=10)

    # training hyper-parameters
    parser.add_argument('--train_iters', type=int, default=15000)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--num_workers', type=int, default=2)
    parser.add_argument('--lr', type=float, default=0.0002)
    parser.add_argument('--beta1', type=float, default=0.5)
    parser.add_argument('--beta2', type=float, default=0.999)
    parser.add_argument('--kl_lambda', type=float, default=0.1)

    # misc
    parser.add_argument('--mode', type=str, default='train')
    parser.add_argument('--model_path', type=str, default='./models_autoencoder')
    parser.add_argument('--sample_path', type=str, default='./samples_autoencoder')
    parser.add_argument('--mnist_path', type=str, default='./mnist')
    parser.add_argument('--svhn_path', type=str, default='./svhn')
    parser.add_argument('--log_step', type=int, default=10)
    parser.add_argument('--sample_step', type=int, default=500)
    # type=bool would treat any non-empty string (including 'False') as True
    parser.add_argument('--shuffle', type=str2bool, default=True)
    parser.add_argument('--use_augmentation', required=True, type=str2bool)

    config = parser.parse_args()
    print(config)
    main(config)


================================================
FILE: mnist_to_svhn/main_mnist_to_svhn.py
================================================
import argparse
import logging
import os
from torch.backends import cudnn

from data_loader import get_loader
from solver_mnist_to_svhn import Solver


def str2bool(v):
    # Compare for equality: `in ('true')` is a substring test on the string 'true'
    return v.lower() == 'true'


def main(config):
    svhn_loader, mnist_loader, svhn_test_loader, mnist_test_loader = get_loader(config)

    solver = Solver(config, svhn_loader, mnist_loader)
    cudnn.benchmark = True

    # create directories if they do not exist
    if not os.path.exists(config.model_path):
        os.makedirs(config.model_path)
    if not os.path.exists(config.sample_path):
        os.makedirs(config.sample_path)

    base = config.log_path
    filename = os.path.join(base, str(config.max_items))
    if not os.path.isdir(base):
        os.mkdir(base)
    logging.basicConfig(filename=filename, level=logging.DEBUG)

    if config.mode == 'train':
        solver.train()
    elif config.mode == 'sample':
        solver.sample()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # model hyper-parameters
    parser.add_argument('--image_size', type=int, default=32)
    parser.add_argument('--g_conv_dim', type=int, default=64)
    parser.add_argument('--d_conv_dim', type=int, default=64)
    parser.add_argument('--num_classes', type=int, default=10)

    # training hyper-parameters
    parser.add_argument('--train_iters', type=int, default=40000)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--num_workers', type=int, default=2)
    parser.add_argument('--lr', type=float, default=0.0002)
    parser.add_argument('--beta1', type=float, default=0.5)
    parser.add_argument('--beta2', type=float, default=0.999)
    parser.add_argument('--kl_lambda', type=float, default=0.1)

    # misc
    parser.add_argument('--mode', type=str, default='train')
    parser.add_argument('--mnist_path', type=str, default='./mnist')
    parser.add_argument('--svhn_path', type=str, default='./svhn')
    parser.add_argument('--log_step', type=int, default=10)
    # type=bool would treat any non-empty string (including 'False') as True
    parser.add_argument('--shuffle', type=str2bool, default=True)

    parser.add_argument('--load_iter', type=int, default=10000)
    parser.add_argument('--sample_step', type=int, default=500)
    parser.add_argument('--num_averaging_runs', type=int, default=1000)
    parser.add_argument('--num_iters_save_model_and_return', type=int, default=5000)
    parser.add_argument('--num_d_iterations', type=int, default=1)
    parser.add_argument('--num_g_iterations', type=int, default=1)
    parser.add_argument('--model_path', type=str, default='./models_ost')
    parser.add_argument('--sample_path', type=str, default='./samples_ost')
    parser.add_argument('--load_path', type=str, default='./models_autoencoder')
    parser.add_argument('--log_path', type=str, default='logs_ost')
    parser.add_argument('--pretrained_g', required=True, type=str2bool)
    parser.add_argument('--save_models_and_samples', required=True, type=str2bool)
    parser.add_argument('--use_augmentation', required=True, type=str2bool)
    parser.add_argument('--one_way_cycle', required=True, type=str2bool)
    parser.add_argument('--freeze_shared', required=True, type=str2bool)
    parser.add_argument('--max_items', type=int, default=1)

    config = parser.parse_args()
    print(config)
    main(config)


================================================
FILE: mnist_to_svhn/main_svhn_to_mnist.py
================================================
import argparse
import logging
import os

from data_loader import get_loader
from solver_svhn_to_mnist import Solver
from torch.backends import cudnn


def str2bool(v):
    # Compare for equality: `in ('true')` is a substring test on the string 'true'
    return v.lower() == 'true'


def main(config):
    svhn_loader, mnist_loader, svhn_test_loader, mnist_test_loader = get_loader(config)

    solver = Solver(config, svhn_loader, mnist_loader)
    cudnn.benchmark = True

    # create directories if they do not exist
    if not os.path.exists(config.model_path):
        os.makedirs(config.model_path)
    if not os.path.exists(config.sample_path):
        os.makedirs(config.sample_path)

    base = config.log_path
    filename = os.path.join(base, str(config.max_items))
    if not os.path.isdir(base):
        os.mkdir(base)
    logging.basicConfig(filename=filename, level=logging.DEBUG)

    if config.mode == 'train':
        solver.train()

    elif config.mode == 'sample':
        solver.sample()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # model hyper-parameters
    parser.add_argument('--image_size', type=int, default=32)
    parser.add_argument('--g_conv_dim', type=int, default=64)
    parser.add_argument('--d_conv_dim', type=int, default=64)
    parser.add_argument('--num_classes', type=int, default=10)

    # training hyper-parameters
    parser.add_argument('--train_iters', type=int, default=40000)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--num_workers', type=int, default=2)
    parser.add_argument('--lr', type=float, default=0.0002)
    parser.add_argument('--beta1', type=float, default=0.5)
    parser.add_argument('--beta2', type=float, default=0.999)
    parser.add_argument('--kl_lambda', type=float, default=0.1)

    # misc
    parser.add_argument('--mode', type=str, default='train')
    parser.add_argument('--mnist_path', type=str, default='./mnist')
    parser.add_argument('--svhn_path', type=str, default='./svhn')
    parser.add_argument('--log_step', type=int, default=10)
    # type=bool would treat any non-empty string (even 'False') as True
    parser.add_argument('--shuffle', type=str2bool, default=True)

    parser.add_argument('--load_iter', type=int, default=10000)
    parser.add_argument('--sample_step', type=int, default=500)
    parser.add_argument('--num_averaging_runs', type=int, default=1000)
    parser.add_argument('--num_iters_save_model_and_return', type=int, default=5000)
    parser.add_argument('--num_d_iterations', type=int, default=1)
    parser.add_argument('--num_g_iterations', type=int, default=1)
    parser.add_argument('--model_path', type=str, default='./models_ost')
    parser.add_argument('--sample_path', type=str, default='./samples_ost')
    parser.add_argument('--load_path', type=str, default='./models_autoencoder')
    parser.add_argument('--log_path', type=str, default='logs_ost')
    parser.add_argument('--pretrained_g', required=True, type=str2bool)
    parser.add_argument('--save_models_and_samples', required=True, type=str2bool)
    parser.add_argument('--use_augmentation', required=True, type=str2bool)
    parser.add_argument('--one_way_cycle', required=True, type=str2bool)
    parser.add_argument('--freeze_shared', required=True, type=str2bool)
    parser.add_argument('--max_items', type=int, default=1)

    config = parser.parse_args()
    print(config)
    main(config)


================================================
FILE: mnist_to_svhn/model.py
================================================
import torch.nn as nn
import torch.nn.functional as F


def deconv(c_in, c_out, k_size, stride=2, pad=1, bn=True):
    """Bias-free transposed convolution, optionally followed by batch norm."""
    layers = []
    layers.append(nn.ConvTranspose2d(c_in, c_out, k_size, stride, pad, bias=False))
    if bn:
        layers.append(nn.BatchNorm2d(c_out))
    return nn.Sequential(*layers)


def conv(c_in, c_out, k_size, stride=2, pad=1, bn=True):
    """Bias-free convolution, optionally followed by batch norm."""
    layers = []
    layers.append(nn.Conv2d(c_in, c_out, k_size, stride, pad, bias=False))
    if bn:
        layers.append(nn.BatchNorm2d(c_out))
    return nn.Sequential(*layers)


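class G11(nn.Module):
    """Autoencoder with shared middle layers and per-domain input/output layers.

    The default path (conv1 / deconv2) handles 1-channel MNIST images; passing
    svhn=True switches to the 3-channel SVHN layers (conv1_svhn / deconv2_svhn).
    """
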
    def __init__(self, conv_dim=64):
        super(G11, self).__init__()

        # encoding blocks
        self.conv1 = conv(1, conv_dim, 4)
        self.conv1_svhn = conv(3, conv_dim, 4)
        self.conv2 = conv(conv_dim, conv_dim * 2, 4)

        # residual blocks
        res_dim = conv_dim * 2
        self.conv3 = conv(res_dim, res_dim, 3, 1, 1)
        self.conv4 = conv(res_dim, res_dim, 3, 1, 1)

        # decoding blocks
        self.deconv1 = deconv(conv_dim * 2, conv_dim, 4)
        self.deconv2 = deconv(conv_dim, 1, 4, bn=False)
        self.deconv2_svhn = deconv(conv_dim, 3, 4, bn=False)

    def forward(self, x, svhn=False):
        if svhn:
            out = F.leaky_relu(self.conv1_svhn(x), 0.05)  # (?, 64, 16, 16)
        else:
            out = F.leaky_relu(self.conv1(x), 0.05)  # (?, 64, 16, 16)

        out = F.leaky_relu(self.conv2(out), 0.05)  # (?, 128, 8, 8)
        out = F.leaky_relu(self.conv3(out), 0.05)  # ( " )
        out = F.leaky_relu(self.conv4(out), 0.05)  # ( " )
        out = F.leaky_relu(self.deconv1(out), 0.05)  # (?, 64, 16, 16)

        if svhn:
            out = F.tanh(self.deconv2_svhn(out))  # (?, 3, 32, 32)
        else:
            out = F.tanh(self.deconv2(out))  # (?, 1, 32, 32)

        return out

    def encode(self, x, svhn=False):

        if svhn:
            out = F.leaky_relu(self.conv1_svhn(x), 0.05)  # (?, 64, 16, 16)
        else:
            out = F.leaky_relu(self.conv1(x), 0.05)  # (?, 64, 16, 16)

        out = F.leaky_relu(self.conv2(out), 0.05)  # (?, 128, 8, 8)
        out = F.leaky_relu(self.conv3(out), 0.05)  # ( " )

        return out

    def decode(self, out, svhn=False):

        out = F.leaky_relu(self.conv4(out), 0.05)
        out = F.leaky_relu(self.deconv1(out), 0.05)  # (?, 64, 16, 16)

        if svhn:
            out = F.tanh(self.deconv2_svhn(out))  # (?, 3, 32, 32)
        else:
            out = F.tanh(self.deconv2(out))  # (?, 1, 32, 32)

        return out

    def encode_params(self):
        layers_basic = list(self.conv1_svhn.parameters()) + \
                       list(self.conv1.parameters())
        layers_basic += list(self.conv2.parameters())
        layers_basic += list(self.conv3.parameters())

        return layers_basic

    def decode_params(self):
        layers_basic = list(self.deconv2_svhn.parameters()) + \
                       list(self.deconv2.parameters())
        layers_basic += list(self.deconv1.parameters())
        layers_basic += list(self.conv4.parameters())

        return layers_basic

    def unshared_parameters(self):
        return list(self.deconv2_svhn.parameters()) + list(self.conv1_svhn.parameters()) + \
               list(self.deconv2.parameters()) + list(self.conv1.parameters())


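class G22(nn.Module):
    """Autoencoder with shared middle layers and per-domain input/output layers.

    The default path (conv1 / deconv2) handles 3-channel SVHN images; passing
    mnist=True switches to the 1-channel MNIST layers (conv1_mnist / deconv2_mnist).
    """
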
    def __init__(self, conv_dim=64):
        super(G22, self).__init__()

        # encoding blocks
        self.conv1 = conv(3, conv_dim, 4)
        self.conv1_mnist = conv(1, conv_dim, 4)
        self.conv2 = conv(conv_dim, conv_dim * 2, 4)

        # residual blocks
        res_dim = conv_dim * 2
        self.conv3 = conv(res_dim, res_dim, 3, 1, 1)
        self.conv4 = conv(res_dim, res_dim, 3, 1, 1)

        # decoding blocks
        self.deconv1 = deconv(conv_dim * 2, conv_dim, 4)
        self.deconv2 = deconv(conv_dim, 3, 4, bn=False)
        self.deconv2_mnist = deconv(conv_dim, 1, 4, bn=False)

    def forward(self, x, mnist=False):
        if mnist:
            out = F.leaky_relu(self.conv1_mnist(x), 0.05)  # (?, 64, 16, 16)
        else:
            out = F.leaky_relu(self.conv1(x), 0.05)  # (?, 64, 16, 16)

        out = F.leaky_relu(self.conv2(out), 0.05)  # (?, 128, 8, 8)
        out = F.leaky_relu(self.conv3(out), 0.05)  # ( " )
        out = F.leaky_relu(self.conv4(out), 0.05)  # ( " )
        out = F.leaky_relu(self.deconv1(out), 0.05)  # (?, 64, 16, 16)

        if mnist:
            out = F.tanh(self.deconv2_mnist(out))  # (?, 1, 32, 32)
        else:
            out = F.tanh(self.deconv2(out))  # (?, 3, 32, 32)

        return out

    def encode(self, x, mnist=False):

        if mnist:
            out = F.leaky_relu(self.conv1_mnist(x), 0.05)  # (?, 64, 16, 16)
        else:
            out = F.leaky_relu(self.conv1(x), 0.05)  # (?, 64, 16, 16)

        out = F.leaky_relu(self.conv2(out), 0.05)  # (?, 128, 8, 8)
        out = F.leaky_relu(self.conv3(out), 0.05)  # ( " )

        return out

    def decode(self, out, mnist=False):

        out = F.leaky_relu(self.conv4(out), 0.05)
        out = F.leaky_relu(self.deconv1(out), 0.05)  # (?, 64, 16, 16)

        if mnist:
            out = F.tanh(self.deconv2_mnist(out))  # (?, 1, 32, 32)
        else:
            out = F.tanh(self.deconv2(out))  # (?, 3, 32, 32)

        return out

    def encode_params(self):
        layers_basic = list(self.conv1_mnist.parameters()) + \
                       list(self.conv1.parameters())
        layers_basic += list(self.conv2.parameters())
        layers_basic += list(self.conv3.parameters())

        return layers_basic

    def decode_params(self):
        layers_basic = list(self.deconv2_mnist.parameters()) + \
                       list(self.deconv2.parameters())
        layers_basic += list(self.deconv1.parameters())
        layers_basic += list(self.conv4.parameters())

        return layers_basic

    def unshared_parameters(self):
        return list(self.deconv2_mnist.parameters()) + list(self.conv1_mnist.parameters()) + \
               list(self.deconv2.parameters()) + list(self.conv1.parameters())


class D1(nn.Module):
    """Discriminator for mnist."""

    def __init__(self, conv_dim=64, use_labels=False):
        super(D1, self).__init__()
        self.conv1 = conv(1, conv_dim, 4, bn=False)
        self.conv2 = conv(conv_dim, conv_dim * 2, 4)
        self.conv3 = conv(conv_dim * 2, conv_dim * 4, 4)
        n_out = 11 if use_labels else 1
        self.fc = conv(conv_dim * 4, n_out, 4, 1, 0, False)

    def forward(self, x_0):
        out = F.leaky_relu(self.conv1(x_0), 0.05)  # (?, 64, 16, 16)
        out = F.leaky_relu(self.conv2(out), 0.05)  # (?, 128, 8, 8)
        out = F.leaky_relu(self.conv3(out), 0.05)  # (?, 256, 4, 4)
        out_0 = self.fc(out).squeeze()

        return out_0


class D2(nn.Module):
    """Discriminator for svhn."""

    def __init__(self, conv_dim=64, use_labels=False):
        super(D2, self).__init__()
        self.conv1 = conv(3, conv_dim, 4, bn=False)
        self.conv2 = conv(conv_dim, conv_dim * 2, 4)
        self.conv3 = conv(conv_dim * 2, conv_dim * 4, 4)
        n_out = 11 if use_labels else 1
        self.fc = conv(conv_dim * 4, n_out, 4, 1, 0, False)

    def forward(self, x_0):
        out = F.leaky_relu(self.conv1(x_0), 0.05)  # (?, 64, 16, 16)
        out = F.leaky_relu(self.conv2(out), 0.05)  # (?, 128, 8, 8)
        out = F.leaky_relu(self.conv3(out), 0.05)  # (?, 256, 4, 4)
        out_0 = self.fc(out).squeeze()

        return out_0
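

# Illustrative sketch, not part of the original module: the shape comments in
# the forward passes above can be checked with the standard output-size
# formulas for nn.Conv2d and nn.ConvTranspose2d (dilation 1, no output
# padding). The helper names conv_out_size and deconv_out_size are assumptions
# made for this example.
def conv_out_size(size, k_size, stride=2, pad=1):
    """Spatial size after a convolution built by conv()."""
    return (size + 2 * pad - k_size) // stride + 1


def deconv_out_size(size, k_size, stride=2, pad=1):
    """Spatial size after a transposed convolution built by deconv()."""
    return (size - 1) * stride - 2 * pad + k_size


# Generator path for a 32x32 input: 32 -> 16 -> 8 -> 8 -> 8 -> 16 -> 32.
# Discriminator path for a 32x32 input: 32 -> 16 -> 8 -> 4 -> 1.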


================================================
FILE: mnist_to_svhn/solver_autoencoder.py
================================================
import os

import numpy as np
import scipy.io
import scipy.misc  # imsave is used below, so import the submodule explicitly
import torch
from torch import optim
from torch.autograd import Variable

from model import D1, D2
from model import G11, G22


class Solver(object):
    def __init__(self, config, svhn_loader, mnist_loader):
        self.svhn_loader = svhn_loader
        self.mnist_loader = mnist_loader
        self.g11 = None
        self.g22 = None
        self.d1 = None
        self.d2 = None
        self.g_optimizer = None
        self.d_optimizer = None
        self.num_classes = config.num_classes
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.train_iters = config.train_iters
        self.batch_size = config.batch_size
        self.lr = config.lr
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.sample_path = config.sample_path
        self.model_path = config.model_path
        self.kl_lambda = config.kl_lambda
        self.build_model()

    def build_model(self):
        """Builds a generator and a discriminator."""
        self.g11 = G11(conv_dim=self.g_conv_dim)
        self.g22 = G22(conv_dim=self.g_conv_dim)
        self.d1 = D1(conv_dim=self.d_conv_dim, use_labels=False)
        self.d2 = D2(conv_dim=self.d_conv_dim, use_labels=False)

        g_params = list(self.g11.parameters()) + list(self.g22.parameters())
        d_params = list(self.d1.parameters()) + list(self.d2.parameters())

        self.g_optimizer = optim.Adam(g_params, self.lr, [self.beta1, self.beta2])
        self.d_optimizer = optim.Adam(d_params, self.lr, [self.beta1, self.beta2])

        if torch.cuda.is_available():
            self.g11.cuda()
            self.g22.cuda()
            self.d1.cuda()
            self.d2.cuda()

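    def merge_images(self, sources, targets, k=10):
        """Tiles (source, target) pairs side by side into one H x W x 3 grid."""
        _, _, h, w = sources.shape
        # round up so that batch sizes that are not perfect squares still fit
        row = int(np.ceil(np.sqrt(self.batch_size)))
        merged = np.zeros([3, row * h, row * w * 2])
        for idx, (s, t) in enumerate(zip(sources, targets)):
            i = idx // row
            j = idx % row
            # column offsets are measured in units of the image width
            merged[:, i * h:(i + 1) * h, (j * 2) * w:(j * 2 + 1) * w] = s
            merged[:, i * h:(i + 1) * h, (j * 2 + 1) * w:(j * 2 + 2) * w] = t
        return merged.transpose(1, 2, 0)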

    def to_var(self, x):
        """Wraps a tensor in a Variable, moving it to the GPU when one is available."""
        if torch.cuda.is_available():
            x = x.cuda()
        return Variable(x)

    def to_data(self, x):
        """Moves a Variable to the CPU and returns its data as a numpy array."""
        if torch.cuda.is_available():
            x = x.cpu()
        return x.data.numpy()

    def reset_grad(self):
        """Zeros the gradient buffers."""
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

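    def _compute_kl(self, mu):
        """Mean squared latent activation, used as a KL-style penalty on the code.

        For a unit-variance Gaussian encoder, KL(N(mu, I) || N(0, I)) equals
        0.5 * ||mu||^2, so this term matches it up to a constant scale.
        """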
        mu_2 = torch.pow(mu, 2)
        encoding_loss = torch.mean(mu_2)
        return encoding_loss

    def train(self):
        svhn_iter = iter(self.svhn_loader)
        mnist_iter = iter(self.mnist_loader)
        iter_per_epoch = min(len(svhn_iter), len(mnist_iter))

        # fixed mnist and svhn for sampling
        fixed_svhn = self.to_var(svhn_iter.next()[0])
        fixed_mnist = self.to_var(mnist_iter.next()[0])

        # Train autoencoder for mnist
        for step in range(self.train_iters + 1):
            # reset data_iter for each epoch
            if (step + 1) % iter_per_epoch == 0:
                mnist_iter = iter(self.mnist_loader)

            # mnist dataset
            mnist_data, m_labels_data = mnist_iter.next()
            mnist, m_labels = self.to_var(mnist_data), self.to_var(m_labels_data)

            # ============ train D ============#
            # train with real images
            self.reset_grad()
            
            out = self.d1(mnist)
            d1_loss = torch.mean((out - 1) ** 2)

            d_mnist_loss = d1_loss
            d_real_loss = d1_loss
            d_real_loss.backward()
            self.d_optimizer.step()

            # train with fake images
            self.reset_grad()
            fake_mnist = self.g22.forward(mnist, mnist=True)
            out = self.d1(fake_mnist)
            d2_loss = torch.mean(out ** 2)

            d_fake_loss = d2_loss
            d_fake_loss.backward()
            self.d_optimizer.step()

            # ============ train G ============
            self.reset_grad()
            fake_mnist = self.g22.forward(mnist, mnist=True)
            out = self.d1(fake_mnist)
            g_loss = torch.mean((out - 1) ** 2)
            g_loss += torch.mean((mnist - fake_mnist) ** 2)
            em = self.g22.encode(mnist, mnist=True)
            g_loss += self.kl_lambda * self._compute_kl(em)

            g_loss.backward()
            self.g_optimizer.step()

            # print the log info
            if (step + 1) % self.log_step == 0:
                print('Step [%d/%d], d_real_loss: %.4f, d_mnist_loss: %.4f, d_fake_loss: %.4f, '
                      'g_loss: %.4f'
                      % (step + 1, self.train_iters, d_real_loss.data[0], d_mnist_loss.data[0],
                         d_fake_loss.data[0], g_loss.data[0]))

            # save the sampled images
            if (step + 1) % self.sample_step == 0:
                fake_mnist = self.g22.forward(fixed_mnist, mnist=True)
                mnist, fake_mnist = self.to_data(fixed_mnist), self.to_data(fake_mnist)

                merged = self.merge_images(mnist, fake_mnist)
                path = os.path.join(self.sample_path, 'sample-%d-m-s.png' % (step + 1))
                scipy.misc.imsave(path, merged)
                print('saved %s' % path)

            if (step + 1) % 10000 == 0:
                # save the model parameters every 10000 steps
                g22_path = os.path.join(self.model_path, 'g22-%d.pkl' % (step + 1))
                d1_path = os.path.join(self.model_path, 'd1-%d.pkl' % (step + 1))
                torch.save(self.g22.state_dict(), g22_path)
                torch.save(self.d1.state_dict(), d1_path)

        # Train autoencoder for svhn
        for step in range(self.train_iters + 1):
            # reset data_iter for each epoch
            if (step + 1) % iter_per_epoch == 0:
                svhn_iter = iter(self.svhn_loader)

            # load svhn and mnist dataset
            svhn_data, s_labels_data = svhn_iter.next()
            svhn, s_labels = self.to_var(svhn_data), self.to_var(s_labels_data).long().squeeze()

            # ============ train D ============#

            # train with real images
            self.reset_grad()

            out = self.d2(svhn)
            d2_loss = torch.mean((out - 1) ** 2)

            d_svhn_loss = d2_loss
            d_real_loss = d2_loss
            d_real_loss.backward()
            self.d_optimizer.step()

            # train with fake images
            self.reset_grad()

            fake_svhn = self.g11.forward(svhn, svhn=True)
            out = self.d2(fake_svhn)
            d1_loss = torch.mean(out ** 2)

            d_fake_loss = d1_loss
            d_fake_loss.backward()
            self.d_optimizer.step()

            # ============ train G ============#

            self.reset_grad()
            fake_svhn = self.g11.forward(svhn, svhn=True)
            out = self.d2(fake_svhn)
            g_loss = torch.mean((out - 1) ** 2)
            g_loss += torch.mean((svhn - fake_svhn) ** 2)
            es = self.g11.encode(svhn, svhn=True)
            g_loss += self.kl_lambda * self._compute_kl(es)

            g_loss.backward()
            self.g_optimizer.step()

            # print the log info
            if (step + 1) % self.log_step == 0:
                print('Step [%d/%d], d_real_loss: %.4f, d_svhn_loss: %.4f, '
                      'd_fake_loss: %.4f, g_loss: %.4f'
                      % (step + 1, self.train_iters, d_real_loss.data[0],
                         d_svhn_loss.data[0], d_fake_loss.data[0], g_loss.data[0]))

            # save the sampled images
            if (step + 1) % self.sample_step == 0:
                fake_svhn = self.g11.forward(fixed_svhn, svhn=True)
                svhn, fake_svhn = self.to_data(fixed_svhn), self.to_data(fake_svhn)

                merged = self.merge_images(svhn, fake_svhn)
                path = os.path.join(self.sample_path, 'sample-%d-s-m.png' % (step + 1))
                scipy.misc.imsave(path, merged)
                print('saved %s' % path)

            if (step + 1) % 10000 == 0:
                # save the model parameters every 10000 steps
                g11_path = os.path.join(self.model_path, 'g11-%d.pkl' % (step + 1))
                d2_path = os.path.join(self.model_path, 'd2-%d.pkl' % (step + 1))
                torch.save(self.g11.state_dict(), g11_path)
                torch.save(self.d2.state_dict(), d2_path)
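

# Illustrative sketch, not part of the original solver: the discriminator and
# generator objectives used in train() have the least-squares GAN form. The
# plain Python helpers below (their names are assumptions made for this
# example) compute the same quantities from lists of discriminator scores.
# Note that train() applies the real and fake discriminator terms in two
# separate optimizer steps, whereas this sketch simply adds them.
def lsgan_d_loss(real_scores, fake_scores):
    """Pushes scores of real images towards 1 and scores of fakes towards 0."""
    real_term = sum((s - 1.0) ** 2 for s in real_scores) / len(real_scores)
    fake_term = sum(s ** 2 for s in fake_scores) / len(fake_scores)
    return real_term + fake_term


def lsgan_g_loss(fake_scores):
    """Pushes the discriminator's scores of generated images towards 1."""
    return sum((s - 1.0) ** 2 for s in fake_scores) / len(fake_scores)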


================================================
FILE: mnist_to_svhn/solver_mnist_to_svhn.py
================================================
import os

import numpy as np
import scipy.io
import scipy.misc  # imsave is used below, so import the submodule explicitly
import torch
from torch import optim
from torch.autograd import Variable

from model import D1, D2
from model import G11


class Solver(object):
    def __init__(self, config, svhn_loader, mnist_loader):
        self.config = config
        self.svhn_loader = svhn_loader
        self.mnist_loader = mnist_loader
        self.g11 = None
        self.g22 = None
        self.d1 = None
        self.d2 = None
        self.g_optimizer = None
        self.num_classes = config.num_classes
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.train_iters = config.train_iters
        self.batch_size = config.batch_size
        self.lr = config.lr
        self.kl_lambda = config.kl_lambda
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.sample_path = config.sample_path
        self.model_path = config.model_path
        self.g11_load_path = os.path.join(config.load_path, "g11-" + str(config.load_iter) + ".pkl")
        self.d1_load_path = os.path.join(config.load_path, "d1-" + str(config.load_iter) + ".pkl")
        self.g22_load_path = os.path.join(config.load_path, "g22-" + str(config.load_iter) + ".pkl")
        self.d2_load_path = os.path.join(config.load_path, "d2-" + str(config.load_iter) + ".pkl")
        self.build_model()

    def build_model(self):
        """Builds a generator and a discriminator."""
        self.g11 = G11(conv_dim=self.g_conv_dim)
        self.g_optimizer = optim.Adam(list(self.g11.encode_params()) + list(self.g11.decode_params()), self.lr,
                                      [self.beta1, self.beta2])
        self.unshared_optimizer = optim.Adam(list(self.g11.unshared_parameters()), self.lr,
                                             [self.beta1, self.beta2])

        self.d1 = D1(conv_dim=self.d_conv_dim, use_labels=False)
        self.d2 = D2(conv_dim=self.d_conv_dim, use_labels=False)

        self.d_optimizer = optim.Adam(list(self.d1.parameters()) + list(self.d2.parameters()), self.lr,
                                      [self.beta1, self.beta2])

        if torch.cuda.is_available():
            self.g11.cuda()
            self.d1.cuda()
            self.d2.cuda()

    def merge_images(self, sources, targets, k=10):
        _, _, h, w = sources.shape
        row = int(np.sqrt(self.batch_size)) + 1
        merged = np.zeros([3, row * h, row * w * 2])
        for idx, (s, t) in enumerate(zip(sources, targets)):
            i = idx // row
            j = idx % row
            # column offsets are measured in units of the image width
            merged[:, i * h:(i + 1) * h, (j * 2) * w:(j * 2 + 1) * w] = s
            merged[:, i * h:(i + 1) * h, (j * 2 + 1) * w:(j * 2 + 2) * w] = t
        return merged.transpose(1, 2, 0)

    def to_var(self, x, volatile=False):
        """Wraps a tensor in a (optionally volatile) Variable, on the GPU when available."""
        if torch.cuda.is_available():
            x = x.cuda()
        if volatile:
            return Variable(x, volatile=True)
        return Variable(x)

    def to_no_grad_var(self, x):
        x = self.to_data(x, no_numpy=True)
        return self.to_var(x, volatile=True)

    def to_data(self, x, no_numpy=False):
        """Moves a Variable to the CPU; returns a numpy array, or a tensor if no_numpy is set."""
        if torch.cuda.is_available():
            x = x.cpu()
        if no_numpy:
            return x.data
        return x.data.numpy()

    def reset_grad(self):
        """Zeros the gradient buffers."""
        self.unshared_optimizer.zero_grad()
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

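    def _compute_kl(self, mu):
        """Mean squared latent activation, used as a KL-style penalty on the code.

        For a unit-variance Gaussian encoder, KL(N(mu, I) || N(0, I)) equals
        0.5 * ||mu||^2, so this term matches it up to a constant scale.
        """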
        mu_2 = torch.pow(mu, 2)
        encoding_loss = torch.mean(mu_2)
        return encoding_loss

    def train(self):
        self.build_model()
        if self.config.pretrained_g:
            self.g11.load_state_dict(torch.load(self.g11_load_path))

        svhn_iter = iter(self.svhn_loader)
        mnist_iter = iter(self.mnist_loader)
        iter_per_epoch = min(len(svhn_iter), len(mnist_iter))

        # fixed mnist and svhn for sampling
        svhn_fixed_data, svhn_fixed_labels = svhn_iter.next()
        mnist_fixed_data, mnist_fixed_labels = mnist_iter.next()
        fixed_mnist = self.to_var(mnist_fixed_data)
        counter = 0

        for step in range(self.train_iters + 1):

            # reset data_iter for each epoch
            if (step + 1) % iter_per_epoch == 0:
                mnist_iter = iter(self.mnist_loader)
                svhn_iter = iter(self.svhn_loader)

            # load svhn and mnist dataset
            svhn_data, s_labels_data = svhn_iter.next()
            mnist_data, m_labels_data = mnist_iter.next()
            svhn, s_labels = self.to_var(svhn_data), self.to_var(s_labels_data).long().squeeze()
            mnist, m_labels = self.to_var(mnist_data), self.to_var(m_labels_data)

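            # Limit the number of MNIST (domain A) samples used to max_items.
            # max_items is assumed to be a multiple of batch_size; the MNIST
            # iterator is restarted once max_items samples have been drawn.
            if self.batch_size > self.config.max_items:
                raise ValueError('batch_size (%d) must not exceed max_items (%d)'
                                 % (self.batch_size, self.config.max_items))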
            elif self.batch_size == self.config.max_items:
                mnist = fixed_mnist
            elif self.batch_size < self.config.max_items:
                counter += 1
                if counter * self.batch_size >= self.config.max_items:
                    mnist_iter = iter(self.mnist_loader)
                    counter = 0

            # ============ train D ============#
            # train with real images
            self.reset_grad()
            out = self.d1(mnist)
            d1_loss = torch.mean((out - 1) ** 2)

            out = self.d2(svhn)
            d2_loss = torch.mean((out - 1) ** 2)

            d_mnist_loss = d1_loss
            d_svhn_loss = d2_loss
            # both discriminators are updated on real images
            d_real_loss = d1_loss + d2_loss
            d_real_loss.backward()
            self.d_optimizer.step()

            # train with fake images
            self.reset_grad()
            es = self.g11.encode(svhn, svhn=True)
            fake_mnist = self.g11.decode(es)
            out = self.d1(fake_mnist)
            d2_loss = torch.mean(out ** 2)

            em = self.g11.encode(mnist)
            fake_svhn = self.g11.decode(em, svhn=True)
            out = self.d2(fake_svhn)
            d1_loss = torch.mean(out ** 2)

            d_fake_loss = d2_loss + d1_loss
            d_fake_loss.backward()
            self.d_optimizer.step()

            # ============ train G ============#

            # adversarial loss for both translation directions
            self.reset_grad()
            es = self.g11.encode(svhn, svhn=True)
            fake_mnist = self.g11.decode(es)
            out = self.d1(fake_mnist)
            g_loss = torch.mean((out - 1) ** 2)

            em = self.g11.encode(mnist)
            fake_svhn = self.g11.decode(em, svhn=True)
            out = self.d2(fake_svhn)
            g_loss += torch.mean((out - 1) ** 2)

            self.reset_grad()
            em = self.g11.encode(mnist)
            fake_mnist = self.g11.decode(em)
            g_loss += torch.mean((mnist - fake_mnist) ** 2)

            # optional cycle loss: mnist -> svhn -> mnist
            if self.config.one_way_cycle:
                em = self.g11.encode(mnist)
                fake_svhn = self.g11.decode(em, svhn=True)
                es = self.g11.encode(fake_svhn, svhn=True)
                fake_mnist = self.g11.decode(es)
                g_loss += torch.mean((mnist - fake_mnist) ** 2)

            g_loss.backward()
            self.unshared_optimizer.step()

            # unless the shared layers are frozen, also train on SVHN reconstruction
            if not self.config.freeze_shared:
                self.reset_grad()
                es = self.g11.encode(svhn, svhn=True)
                fake_es = self.g11.decode(es, svhn=True)
                g_loss = torch.mean((svhn - fake_es) ** 2)
                g_loss += self.kl_lambda * self._compute_kl(es)

                g_loss.backward()
                self.g_optimizer.step()

            # print the log info
            if (step + 1) % self.log_step == 0:
                print('Step [%d/%d], d_real_loss: %.4f, d_mnist_loss: %.4f, d_svhn_loss: %.4f, '
                      'd_fake_loss: %.4f, g_loss: %.4f'
                      % (step + 1, self.train_iters, d_real_loss.data[0], d_mnist_loss.data[0],
                         d_svhn_loss.data[0], d_fake_loss.data[0], g_loss.data[0]))

            # save the sampled images
            if (step + 1) % self.sample_step == 0:
                em = self.g11.encode(fixed_mnist)
                fake_svhn_var = self.g11.decode(em, svhn=True)
                fake_svhn = self.to_data(fake_svhn_var)
                if self.config.save_models_and_samples:
                    merged = self.merge_images(mnist_fixed_data, fake_svhn)
                    path = os.path.join(self.sample_path, 'sample-%d-s-m.png' % (step + 1))
                    scipy.misc.imsave(path, merged)
                    print('saved %s' % path)

            if (step + 1) % self.config.num_iters_save_model_and_return == 0:
                # save the model parameters (if enabled), then stop training
                if self.config.save_models_and_samples:
                    g11_path = os.path.join(self.model_path, 'g11-%d.pkl' % (step + 1))
                    d1_path = os.path.join(self.model_path, 'd1-%d.pkl' % (step + 1))
                    d2_path = os.path.join(self.model_path, 'd2-%d.pkl' % (step + 1))
                    torch.save(self.g11.state_dict(), g11_path)
                    torch.save(self.d1.state_dict(), d1_path)
                    torch.save(self.d2.state_dict(), d2_path)

                return
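

# Illustrative sketch, not part of the original solver: the max_items handling
# in train() restarts the MNIST iterator once max_items samples have been
# drawn. The generator below shows that idea in isolation. Its name and
# arguments are assumptions made for this example, and it assumes the loader
# yields batches in the same order each time it is iterated (no shuffling).
def limited_batches(loader, batch_size, max_items, num_steps):
    """Yields num_steps batches drawn only from the first max_items samples."""
    if batch_size > max_items:
        raise ValueError('batch_size must not exceed max_items')
    data_iter = iter(loader)
    seen = 0
    for _ in range(num_steps):
        # restart before the next batch would exceed the allowed sample count
        if seen + batch_size > max_items:
            data_iter = iter(loader)
            seen = 0
        yield next(data_iter)
        seen += batch_size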


================================================
FILE: mnist_to_svhn/solver_svhn_to_mnist.py
================================================
import os

import numpy as np
import scipy.io
import scipy.misc  # imsave is used by this solver, so import the submodule explicitly
import torch
from torch import optim
from torch.autograd import Variable

from model import D1, D2
from model import G22


class Solver(object):
    def __init__(self, config, svhn_loader, mnist_loader):
        self.config = config
        self.svhn_loader = svhn_loader
        self.mnist_loader = mnist_loader
        self.g11 = None
        self.g22 = None
        self.d1 = None
        self.d2 = None
        self.g_optimizer = None
        self.num_classes = config.num_classes
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.train_iters = config.train_iters
        self.batch_size = config.batch_size
        self.lr = config.lr
        self.kl_lambda = config.kl_lambda
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.sample_path = config.sample_path
        self.model_path = config.model_path
        self.g11_load_path = os.path.join(config.load_path, "g11-" + str(config.load_iter) + ".pkl")
        self.d1_load_path = os.path.join(config.load_path, "d1-" + str(config.load_iter) + ".pkl")
        self.g22_load_path = os.path.join(config.load_path, "g22-" + str(config.load_iter) + ".pkl")
        self.d2_load_path = os.path.join(config.load_path, "d2-" + str(config.load_iter) + ".pkl")
        self.build_model()

    def build_model(self):
        """Builds a generator and a discriminator."""
        self.g22 = G22(conv_dim=self.g_conv_dim)
        self.g_optimizer = optim.Adam(list(self.g22.encode_params()) + list(self.g22.decode_params()), self.lr,
                                      [self.beta1, self.beta2])
        self.unshared_optimizer = optim.Adam(list(self.g22.unshared_parameters()), self.lr,
                                             [self.beta1, self.beta2])

        self.d1 = D1(conv_dim=self.d_conv_dim, use_labels=False)
        self.d2 = D2(conv_dim=self.d_conv_dim, use_labels=False)

        self.d_optimizer = optim.Adam(list(self.d1.parameters()) + list(self.d2.parameters()), self.lr,
                                      [self.beta1, self.beta2])

        if torch.cuda.is_available():
            self.g22.cuda()
            self.d1.cuda()
            self.d2.cuda()

    def merge_images(self, sources, targets, k=10):
        _, _, h, w = sources.shape
        row = int(np.sqrt(self.batch_size)) + 1
        merged = np.zeros([3, row * h, row * w * 2])
        for idx, (s, t) in enumerate(zip(sources, targets)):
            i = idx // row
            j = idx % row
            # column offsets are measured in units of the image width
            merged[:, i * h:(i + 1) * h, (j * 2) * w:(j * 2 + 1) * w] = s
            merged[:, i * h:(i + 1) * h, (j * 2 + 1) * w:(j * 2 + 2) * w] = t
        return merged.transpose(1, 2, 0)

    def to_var(self, x, volatile=False):
        """Wraps a tensor in a Variable, moving it to the GPU when available."""
        if torch.cuda.is_available():
            x = x.cuda()
        if volatile:
            return Variable(x, volatile=True)
        return Variable(x)

    def to_no_grad_var(self, x):
        """Detaches x from the graph and re-wraps it as a volatile (no-grad) Variable."""
        x = self.to_data(x, no_numpy=True)
        return self.to_var(x, volatile=True)

    def to_data(self, x, no_numpy=False):
        """Converts a Variable to a NumPy array (or to a CPU tensor if no_numpy is True)."""
        if torch.cuda.is_available():
            x = x.cpu()
        if no_numpy:
            return x.data
        return x.data.numpy()

    def reset_grad(self):
        """Zeros the gradient buffers."""
        self.unshared_optimizer.zero_grad()
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    def _compute_kl(self, mu):
        """Returns mean(mu ** 2), used as a simplified KL-style penalty on the latent code."""
        mu_2 = torch.pow(mu, 2)
        encoding_loss = torch.mean(mu_2)
        return encoding_loss

    def train(self):
        self.build_model()
        if self.config.pretrained_g:
            self.g22.load_state_dict(torch.load(self.g22_load_path))

        svhn_iter = iter(self.svhn_loader)
        mnist_iter = iter(self.mnist_loader)
        iter_per_epoch = min(len(svhn_iter), len(mnist_iter))

        # fixed mnist and svhn for sampling
        svhn_fixed_data, svhn_fixed_labels = next(svhn_iter)
        mnist_fixed_data, mnist_fixed_labels = next(mnist_iter)
        fixed_svhn = self.to_var(svhn_fixed_data)
        counter = 0

        for step in range(self.train_iters + 1):

            # reset data_iter for each epoch
            if (step + 1) % iter_per_epoch == 0:
                mnist_iter = iter(self.mnist_loader)
                svhn_iter = iter(self.svhn_loader)

            # load svhn and mnist dataset
            svhn_data, s_labels_data = next(svhn_iter)
            mnist_data, m_labels_data = next(mnist_iter)
            svhn, s_labels = self.to_var(svhn_data), self.to_var(s_labels_data).long().squeeze()
            mnist, m_labels = self.to_var(mnist_data), self.to_var(m_labels_data)

            # This caps the number of items drawn from the A domain (SVHN here).
            # We assume max_items is a multiple of batch_size and reset the
            # SVHN loader once the allowed number of items has been consumed.
            if self.batch_size > self.config.max_items:
                raise ValueError("batch_size must not exceed max_items")
            elif self.batch_size == self.config.max_items:
                svhn = fixed_svhn
            elif self.batch_size < self.config.max_items:
                counter += 1
                if counter * self.batch_size >= self.config.max_items:
                    svhn_iter = iter(self.svhn_loader)
                    counter = 0

            # ============ train D ============#
            # train with real images
            self.reset_grad()
            out = self.d1(mnist)
            d1_loss = torch.mean((out - 1) ** 2)

            out = self.d2(svhn)
            d2_loss = torch.mean((out - 1) ** 2)

            d_mnist_loss = d1_loss
            d_svhn_loss = d2_loss
            # Both d1 and d2 are updated on real images
            d_real_loss = d1_loss + d2_loss
            d_real_loss.backward()
            self.d_optimizer.step()

            # train with fake images
            self.reset_grad()
            es = self.g22.encode(svhn)
            fake_mnist = self.g22.decode(es, mnist=True)
            out = self.d1(fake_mnist)
            d1_fake_loss = torch.mean(out ** 2)  # d1 judges generated MNIST

            em = self.g22.encode(mnist, mnist=True)
            fake_svhn = self.g22.decode(em)
            out = self.d2(fake_svhn)
            d2_fake_loss = torch.mean(out ** 2)  # d2 judges generated SVHN

            d_fake_loss = d1_fake_loss + d2_fake_loss
            d_fake_loss.backward()
            self.d_optimizer.step()

            # ============ train G ============#

            self.reset_grad()
            es = self.g22.encode(svhn)
            fake_mnist = self.g22.decode(es, mnist=True)
            out = self.d1(fake_mnist)
            g_loss = torch.mean((out - 1) ** 2)

            em = self.g22.encode(mnist, mnist=True)
            fake_svhn = self.g22.decode(em)
            out = self.d2(fake_svhn)
            g_loss += torch.mean((out - 1) ** 2)

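            # Reconstruction loss: SVHN -> latent -> SVHN.
            # No backward pass has run since the previous reset, so this call
            # does not discard any gradients.
            self.reset_grad()
            es = self.g22.encode(svhn)
            fake_svhn = self.g22.decode(es)
            g_loss += torch.mean((svhn - fake_svhn) ** 2)

            if self.config.one_way_cycle:
                # One-way cycle loss: SVHN -> MNIST -> SVHN.
                es = self.g22.encode(svhn)
                fake_mnist = self.g22.decode(es, mnist=True)
                es = self.g22.encode(fake_mnist, mnist=True)
                fake_svhn = self.g22.decode(es)
                g_loss += torch.mean((svhn - fake_svhn) ** 2)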

            g_loss.backward()
            self.unshared_optimizer.step()

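            # Unless freeze_shared is set, also step g_optimizer on the MNIST
            # reconstruction loss plus the weighted KL-style penalty.
            if not self.config.freeze_shared: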
                self.reset_grad()
                em = self.g22.encode(mnist, mnist=True)
                fake_em = self.g22.decode(em, mnist=True)
                g_loss = torch.mean((mnist - fake_em) ** 2)
                g_loss += self.kl_lambda * self._compute_kl(em)

                g_loss.backward()
                self.g_optimizer.step()

            # print the log info
            if (step + 1) % self.log_step == 0:
                print('Step [%d/%d], d_real_loss: %.4f, d_mnist_loss: %.4f, d_svhn_loss: %.4f, '
                      'd_fake_loss: %.4f, g_loss: %.4f'
                      % (step + 1, self.train_iters, d_real_loss.data[0], d_mnist_loss.data[0],
                         d_svhn_loss.data[0], d_fake_loss.data[0], g_loss.data[0]))

            # save the sampled images
            if (step + 1) % self.sample_step == 0:
                es = self.g22.encode(fixed_svhn)
                fake_mnist_var = self.g22.decode(es, mnist=True)
                fake_mnist = self.to_data(fake_mnist_var)
                if self.config.save_models_and_samples:
                    merged = self.merge_images(svhn_fixed_data, fake_mnist)
                    path = os.path.join(self.sample_path, 'sample-%d-s-m.png' % (step + 1))
                    scipy.misc.imsave(path, merged)
                    print('saved %s' % path)

            if (step + 1) % self.config.num_iters_save_model_and_return == 0:
                # optionally save the model parameters, then stop training
                if self.config.save_models_and_samples:
                    g22_path = os.path.join(self.model_path, 'g22-%d.pkl' % (step + 1))
                    d1_path = os.path.join(self.model_path, 'd1-%d.pkl' % (step + 1))
                    d2_path = os.path.join(self.model_path, 'd2-%d.pkl' % (step + 1))
                    torch.save(self.g22.state_dict(), g22_path)
                    torch.save(self.d1.state_dict(), d1_path)
                    torch.save(self.d2.state_dict(), d2_path)

                return
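

# ---------------------------------------------------------------------------
# Illustrative sketch (NOT part of the original repository).
#
# Solver.train() above caps the number of SVHN items by restarting the SVHN
# iterator once `max_items` samples have been drawn. The helper below isolates
# that idea as a plain generator so it can be read and tried on its own.
#
# Assumptions: `limited_batches` is a hypothetical name; the loader is assumed
# to yield batches in a fixed order (no reshuffling between passes), and
# `max_items` is assumed to be a multiple of `batch_size`, as in train().
# ---------------------------------------------------------------------------
import itertools


def limited_batches(loader, batch_size, max_items):
    """Yield batches forever, restarting `loader` after `max_items` samples."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    if batch_size > max_items:
        raise ValueError("batch_size must not exceed max_items")
    batches_per_pass = max_items // batch_size
    while True:
        produced = 0
        # Take only the first `batches_per_pass` batches of each fresh pass.
        for batch in itertools.islice(iter(loader), batches_per_pass):
            produced += 1
            yield batch
        if produced == 0:
            return  # empty loader: stop instead of looping forever


# Possible usage (any iterable of batches works, e.g. a list for a dry run):
#     batches = limited_batches(svhn_loader, batch_size=64, max_items=128)
#     svhn_data, s_labels_data = next(batches)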
SYMBOL INDEX (239 symbols across 30 files)

FILE: drawing_and_style_transfer/data/__init__.py
  function CreateDataLoader (line 5) | def CreateDataLoader(opt):
  function CreateDataset (line 12) | def CreateDataset(opt):
  class CustomDatasetDataLoader (line 30) | class CustomDatasetDataLoader(BaseDataLoader):
    method name (line 31) | def name(self):
    method initialize (line 34) | def initialize(self, opt):
    method load_data (line 43) | def load_data(self):
    method __len__ (line 46) | def __len__(self):
    method __iter__ (line 49) | def __iter__(self):

FILE: drawing_and_style_transfer/data/aligned_dataset.py
  class AlignedDataset (line 10) | class AlignedDataset(BaseDataset):
    method initialize (line 11) | def initialize(self, opt):
    method __getitem__ (line 18) | def __getitem__(self, index):
    method __len__ (line 60) | def __len__(self):
    method name (line 63) | def name(self):

FILE: drawing_and_style_transfer/data/base_data_loader.py
  class BaseDataLoader (line 1) | class BaseDataLoader():
    method __init__ (line 2) | def __init__(self):
    method initialize (line 5) | def initialize(self, opt):
    method load_data (line 9) | def load_data(self):

FILE: drawing_and_style_transfer/data/base_dataset.py
  class BaseDataset (line 6) | class BaseDataset(data.Dataset):
    method __init__ (line 7) | def __init__(self):
    method name (line 10) | def name(self):
    method initialize (line 13) | def initialize(self, opt):
  function get_transform (line 17) | def get_transform(opt):
  function __scale_width (line 44) | def __scale_width(img, target_width):

FILE: drawing_and_style_transfer/data/image_folder.py
  function is_image_file (line 20) | def is_image_file(filename):
  function make_dataset (line 24) | def make_dataset(dir, max_items=-1, start=0):
  function default_loader (line 39) | def default_loader(path):
  class ImageFolder (line 43) | class ImageFolder(data.Dataset):
    method __init__ (line 44) | def __init__(self, root, transform=None, return_paths=False,
    method __getitem__ (line 58) | def __getitem__(self, index):
    method __len__ (line 68) | def __len__(self):

FILE: drawing_and_style_transfer/data/single_dataset.py
  class SingleDataset (line 7) | class SingleDataset(BaseDataset):
    method initialize (line 8) | def initialize(self, opt):
    method __getitem__ (line 19) | def __getitem__(self, index):
    method __len__ (line 34) | def __len__(self):
    method name (line 37) | def name(self):

FILE: drawing_and_style_transfer/data/unaligned_dataset.py
  class UnalignedDataset (line 8) | class UnalignedDataset(BaseDataset):
    method initialize (line 9) | def initialize(self, opt):
    method __getitem__ (line 23) | def __getitem__(self, index):
    method __len__ (line 53) | def __len__(self):
    method name (line 56) | def name(self):

FILE: drawing_and_style_transfer/datasets/make_dataset_aligned.py
  function get_file_paths (line 6) | def get_file_paths(folder):
  function align_images (line 20) | def align_images(a_file_paths, b_file_paths, target_path):

FILE: drawing_and_style_transfer/models/__init__.py
  function create_model (line 1) | def create_model(opt):

FILE: drawing_and_style_transfer/models/autoencoder_model.py
  class AutoEncoderModel (line 11) | class AutoEncoderModel(BaseModel):
    method name (line 12) | def name(self):
    method set_encoders_and_decoders (line 15) | def set_encoders_and_decoders(self, opt):
    method initialize (line 52) | def initialize(self, opt):
    method set_input (line 103) | def set_input(self, input):
    method forward (line 112) | def forward(self):
    method netEnc (line 115) | def netEnc(self, x):
    method netDec (line 118) | def netDec(self, x):
    method test (line 121) | def test(self):
    method get_image_paths (line 127) | def get_image_paths(self):
    method backward_D_basic (line 130) | def backward_D_basic(self, netD, real, fake):
    method backward_D (line 143) | def backward_D(self):
    method _compute_kl (line 148) | def _compute_kl(self, mu):
    method backward_G (line 153) | def backward_G(self):
    method optimize_parameters (line 172) | def optimize_parameters(self):
    method get_current_errors (line 188) | def get_current_errors(self):
    method get_current_visuals (line 192) | def get_current_visuals(self):
    method save (line 198) | def save(self, label):

FILE: drawing_and_style_transfer/models/base_model.py
  class BaseModel (line 5) | class BaseModel(object):
    method name (line 6) | def name(self):
    method initialize (line 9) | def initialize(self, opt):
    method set_input (line 17) | def set_input(self, input):
    method forward (line 20) | def forward(self):
    method test (line 24) | def test(self):
    method get_image_paths (line 27) | def get_image_paths(self):
    method optimize_parameters (line 30) | def optimize_parameters(self):
    method get_current_visuals (line 33) | def get_current_visuals(self):
    method get_current_errors (line 36) | def get_current_errors(self):
    method save (line 39) | def save(self, label):
    method save_network (line 43) | def save_network(self, network, network_label, epoch_label, gpu_ids):
    method load_network (line 51) | def load_network(self, network, network_label, epoch_label):
    method update_learning_rate (line 57) | def update_learning_rate(self):
    method as_np (line 63) | def as_np(self, data):

FILE: drawing_and_style_transfer/models/networks.py
  class pixel_norm (line 13) | class pixel_norm(nn.Module):
    method forward (line 14) | def forward(self, x, epsilon=1e-8):
  function weights_init_normal (line 18) | def weights_init_normal(m):
  function weights_init_xavier (line 30) | def weights_init_xavier(m):
  function weights_init_kaiming (line 42) | def weights_init_kaiming(m):
  function weights_init_orthogonal (line 54) | def weights_init_orthogonal(m):
  function init_weights (line 66) | def init_weights(net, init_type='normal'):
  function get_norm_layer (line 80) | def get_norm_layer(norm_type='instance'):
  function get_scheduler (line 92) | def get_scheduler(optimizer, opt):
  function define_ED (line 108) | def define_ED(input_nc, output_nc, ngf, which_model_netG, norm='batch', ...
  function define_G (line 141) | def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', u...
  function define_D (line 170) | def define_D(input_nc, ndf, which_model_netD,
  function print_network (line 194) | def print_network(net):
  class GANLoss (line 211) | class GANLoss(nn.Module):
    method __init__ (line 212) | def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_...
    method get_target_tensor (line 225) | def get_target_tensor(self, input, target_is_real):
    method __call__ (line 243) | def __call__(self, input, target_is_real):
  class ResnetEncoder (line 253) | class ResnetEncoder(nn.Module):
    method __init__ (line 254) | def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNor...
    method forward (line 290) | def forward(self, input):
  class ResnetDecoder (line 297) | class ResnetDecoder(nn.Module):
    method __init__ (line 298) | def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNor...
    method forward (line 333) | def forward(self, input):
  class ResnetGenerator (line 344) | class ResnetGenerator(nn.Module):
    method __init__ (line 345) | def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNor...
    method forward (line 391) | def forward(self, input):
  class ResnetBlock (line 399) | class ResnetBlock(nn.Module):
    method __init__ (line 400) | def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
    method build_conv_block (line 404) | def build_conv_block(self, dim, padding_type, norm_layer, use_dropout,...
    method forward (line 436) | def forward(self, x):
  class UnetGenerator (line 445) | class UnetGenerator(nn.Module):
    method __init__ (line 446) | def __init__(self, input_nc, output_nc, num_downs, ngf=64,
    method forward (line 467) | def forward(self, input):
  class UnetSkipConnectionBlock (line 477) | class UnetSkipConnectionBlock(nn.Module):
    method __init__ (line 478) | def __init__(self, outer_nc, inner_nc, input_nc=None,
    method forward (line 523) | def forward(self, x):
  class NLayerDiscriminator (line 531) | class NLayerDiscriminator(nn.Module):
    method __init__ (line 532) | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNo...
    method forward (line 575) | def forward(self, input):
  class PixelDiscriminator (line 582) | class PixelDiscriminator(nn.Module):
    method __init__ (line 583) | def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d, use_si...
    method forward (line 604) | def forward(self, input):

FILE: drawing_and_style_transfer/models/ost.py
  class OSTModel (line 11) | class OSTModel(BaseModel):
    method name (line 12) | def name(self):
    method _compute_kl (line 15) | def _compute_kl(self, mu):
    method set_encoders_and_decoders (line 20) | def set_encoders_and_decoders(self, opt):
    method initialize (line 67) | def initialize(self, opt):
    method set_input (line 144) | def set_input(self, input):
    method forward (line 155) | def forward(self):
    method test (line 159) | def test(self):
    method get_image_paths (line 178) | def get_image_paths(self):
    method backward_D_basic (line 181) | def backward_D_basic(self, netD, real, fake):
    method backward_D (line 194) | def backward_D(self):
    method backward_G (line 203) | def backward_G(self):
    method optimize_parameters (line 239) | def optimize_parameters(self):
    method get_current_errors (line 265) | def get_current_errors(self):
    method get_current_visuals (line 273) | def get_current_visuals(self):
    method save (line 286) | def save(self, label):

FILE: drawing_and_style_transfer/models/test_model.py
  class TestModel (line 8) | class TestModel(BaseModel):
    method name (line 9) | def name(self):
    method initialize (line 12) | def initialize(self, opt):
    method set_input (line 27) | def set_input(self, input):
    method test (line 35) | def test(self):
    method get_image_paths (line 40) | def get_image_paths(self):
    method get_current_visuals (line 43) | def get_current_visuals(self):

FILE: drawing_and_style_transfer/options/base_options.py
  class BaseOptions (line 7) | class BaseOptions():
    method __init__ (line 8) | def __init__(self):
    method initialize (line 12) | def initialize(self):
    method parse (line 74) | def parse(self):

FILE: drawing_and_style_transfer/options/test_options.py
  class TestOptions (line 4) | class TestOptions(BaseOptions):
    method initialize (line 5) | def initialize(self):

FILE: drawing_and_style_transfer/options/train_options.py
  class TrainOptions (line 4) | class TrainOptions(BaseOptions):
    method initialize (line 5) | def initialize(self):

FILE: drawing_and_style_transfer/util/get_data.py
  class GetData (line 11) | class GetData(object):
    method __init__ (line 29) | def __init__(self, technique='cyclegan', verbose=True):
    method _print (line 37) | def _print(self, text):
    method _get_options (line 42) | def _get_options(r):
    method _present_options (line 48) | def _present_options(self):
    method _download_data (line 58) | def _download_data(self, dataset_url, save_path):
    method get (line 81) | def get(self, save_path, dataset=None):

FILE: drawing_and_style_transfer/util/html.py
  class HTML (line 6) | class HTML:
    method __init__ (line 7) | def __init__(self, web_dir, title, reflesh=0):
    method get_image_dir (line 22) | def get_image_dir(self):
    method add_header (line 25) | def add_header(self, str):
    method add_table (line 29) | def add_table(self, border=1):
    method add_images (line 33) | def add_images(self, ims, txts, links, width=400):
    method save (line 45) | def save(self):

FILE: drawing_and_style_transfer/util/image_pool.py
  class ImagePool (line 6) | class ImagePool():
    method __init__ (line 7) | def __init__(self, pool_size):
    method query (line 13) | def query(self, images):

FILE: drawing_and_style_transfer/util/util.py
  function tensor2im (line 10) | def tensor2im(image_tensor, imtype=np.uint8):
  function diagnose_network (line 18) | def diagnose_network(net, name='network'):
  function save_image (line 31) | def save_image(image_numpy, image_path):
  function print_numpy (line 36) | def print_numpy(x, val=True, shp=False):
  function mkdirs (line 46) | def mkdirs(paths):
  function mkdir (line 54) | def mkdir(path):

FILE: drawing_and_style_transfer/util/visualizer.py
  class Visualizer (line 10) | class Visualizer():
    method __init__ (line 11) | def __init__(self, opt):
    method reset (line 33) | def reset(self):
    method display_current_results (line 37) | def display_current_results(self, visuals, epoch, save_result):
    method plot_current_errors (line 101) | def plot_current_errors(self, epoch, counter_ratio, opt, errors):
    method print_current_errors (line 117) | def print_current_errors(self, epoch, i, errors, t, t_data):
    method save_images (line 127) | def save_images(self, webpage, visuals, image_path, aspect_ratio=1.0, ...

FILE: mnist_to_svhn/data_loader.py
  function get_loader (line 6) | def get_loader(config):

FILE: mnist_to_svhn/main_autoencoder.py
  function str2bool (line 9) | def str2bool(v):
  function main (line 13) | def main(config):

FILE: mnist_to_svhn/main_mnist_to_svhn.py
  function str2bool (line 10) | def str2bool(v):
  function main (line 14) | def main(config):

FILE: mnist_to_svhn/main_svhn_to_mnist.py
  function str2bool (line 10) | def str2bool(v):
  function main (line 14) | def main(config):

FILE: mnist_to_svhn/model.py
  function deconv (line 5) | def deconv(c_in, c_out, k_size, stride=2, pad=1, bn=True):
  function conv (line 14) | def conv(c_in, c_out, k_size, stride=2, pad=1, bn=True):
  class G11 (line 23) | class G11(nn.Module):
    method __init__ (line 24) | def __init__(self, conv_dim=64):
    method forward (line 42) | def forward(self, x, svhn=False):
    method encode (line 60) | def encode(self, x, svhn=False):
    method decode (line 72) | def decode(self, out, svhn=False):
    method encode_params (line 84) | def encode_params(self):
    method decode_params (line 92) | def decode_params(self):
    method unshared_parameters (line 100) | def unshared_parameters(self):
  class G22 (line 105) | class G22(nn.Module):
    method __init__ (line 106) | def __init__(self, conv_dim=64):
    method forward (line 124) | def forward(self, x, mnist=False):
    method encode (line 142) | def encode(self, x, mnist=False):
    method decode (line 154) | def decode(self, out, mnist=False):
    method encode_params (line 166) | def encode_params(self):
    method decode_params (line 174) | def decode_params(self):
    method unshared_parameters (line 182) | def unshared_parameters(self):
  class D1 (line 187) | class D1(nn.Module):
    method __init__ (line 190) | def __init__(self, conv_dim=64, use_labels=False):
    method forward (line 198) | def forward(self, x_0):
  class D2 (line 207) | class D2(nn.Module):
    method __init__ (line 210) | def __init__(self, conv_dim=64, use_labels=False):
    method forward (line 218) | def forward(self, x_0):

FILE: mnist_to_svhn/solver_autoencoder.py
  class Solver (line 13) | class Solver(object):
    method __init__ (line 14) | def __init__(self, config, svhn_loader, mnist_loader):
    method build_model (line 38) | def build_model(self):
    method merge_images (line 57) | def merge_images(self, sources, targets, k=10):
    method to_var (line 68) | def to_var(self, x):
    method to_data (line 74) | def to_data(self, x):
    method reset_grad (line 80) | def reset_grad(self):
    method _compute_kl (line 85) | def _compute_kl(self, mu):
    method train (line 90) | def train(self):

FILE: mnist_to_svhn/solver_mnist_to_svhn.py
  class Solver (line 13) | class Solver(object):
    method __init__ (line 14) | def __init__(self, config, svhn_loader, mnist_loader):
    method build_model (line 42) | def build_model(self):
    method merge_images (line 61) | def merge_images(self, sources, targets, k=10):
    method to_var (line 72) | def to_var(self, x, volatile=False):
    method to_no_grad_var (line 80) | def to_no_grad_var(self, x):
    method to_data (line 84) | def to_data(self, x, no_numpy=False):
    method reset_grad (line 92) | def reset_grad(self):
    method _compute_kl (line 98) | def _compute_kl(self, mu):
    method train (line 104) | def train(self):

FILE: mnist_to_svhn/solver_svhn_to_mnist.py
  class Solver (line 13) | class Solver(object):
    method __init__ (line 14) | def __init__(self, config, svhn_loader, mnist_loader):
    method build_model (line 42) | def build_model(self):
    method merge_images (line 61) | def merge_images(self, sources, targets, k=10):
    method to_var (line 72) | def to_var(self, x, volatile=False):
    method to_no_grad_var (line 80) | def to_no_grad_var(self, x):
    method to_data (line 84) | def to_data(self, x, no_numpy=False):
    method reset_grad (line 92) | def reset_grad(self):
    method _compute_kl (line 98) | def _compute_kl(self, mu):
    method train (line 103) | def train(self):
Condensed preview: 43 files, each showing path, character count, and a content snippet.
[
  {
    "path": "LICENSE",
    "chars": 3568,
    "preview": "MIT License\n\nCopyright (c) 2017 \n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this s"
  },
  {
    "path": "README.md",
    "chars": 3008,
    "preview": "# Pytorch implementation of One-Shot Unsupervised Cross Domain Translation ([arxiv](https://arxiv.org/abs/1806.06029)).\n"
  },
  {
    "path": "drawing_and_style_transfer/data/__init__.py",
    "chars": 1480,
    "preview": "import torch.utils.data\nfrom data.base_data_loader import BaseDataLoader\n\n\ndef CreateDataLoader(opt):\n    data_loader = "
  },
  {
    "path": "drawing_and_style_transfer/data/aligned_dataset.py",
    "chars": 2410,
    "preview": "import os.path\nimport random\nimport torchvision.transforms as transforms\nimport torch\nfrom data.base_dataset import Base"
  },
  {
    "path": "drawing_and_style_transfer/data/base_data_loader.py",
    "chars": 175,
    "preview": "class BaseDataLoader():\n    def __init__(self):\n        pass\n\n    def initialize(self, opt):\n        self.opt = opt\n    "
  },
  {
    "path": "drawing_and_style_transfer/data/base_dataset.py",
    "chars": 1735,
    "preview": "import torch.utils.data as data\nfrom PIL import Image\nimport torchvision.transforms as transforms\n\n\nclass BaseDataset(da"
  },
  {
    "path": "drawing_and_style_transfer/data/image_folder.py",
    "chars": 2080,
    "preview": "###############################################################################\n# Code from\n# https://github.com/pytorch"
  },
  {
    "path": "drawing_and_style_transfer/data/single_dataset.py",
    "chars": 1056,
    "preview": "import os.path\nfrom data.base_dataset import BaseDataset, get_transform\nfrom data.image_folder import make_dataset\nfrom "
  },
  {
    "path": "drawing_and_style_transfer/data/unaligned_dataset.py",
    "chars": 2051,
    "preview": "import os.path\nfrom data.base_dataset import BaseDataset, get_transform\nfrom data.image_folder import make_dataset\nfrom "
  },
  {
    "path": "drawing_and_style_transfer/datasets/combine_A_and_B.py",
    "chars": 2124,
    "preview": "import os\nimport numpy as np\nimport cv2\nimport argparse\n\nparser = argparse.ArgumentParser('create image pairs')\nparser.a"
  },
  {
    "path": "drawing_and_style_transfer/datasets/download_cyclegan_dataset.sh",
    "chars": 809,
    "preview": "FILE=$1\n\nif [[ $FILE != \"ae_photos\" && $FILE != \"apple2orange\" && $FILE != \"summer2winter_yosemite\" &&  $FILE != \"horse2"
  },
  {
    "path": "drawing_and_style_transfer/datasets/make_dataset_aligned.py",
    "chars": 2257,
    "preview": "import os\n\nfrom PIL import Image\n\n\ndef get_file_paths(folder):\n    image_file_paths = []\n    for root, dirs, filenames i"
  },
  {
    "path": "drawing_and_style_transfer/environment.yml",
    "chars": 224,
    "preview": "name: OST\nchannels:\n- peterjc123\n- defaults\ndependencies:\n- python=3.6.5\n- pytorch=0.4.0\n- scipy\n- pip:\n  - dominate==2."
  },
  {
    "path": "drawing_and_style_transfer/models/__init__.py",
    "chars": 684,
    "preview": "def create_model(opt):\n    print(opt.model)\n    if opt.model == 'ost':\n        assert (opt.dataset_mode == 'unaligned')\n"
  },
  {
    "path": "drawing_and_style_transfer/models/autoencoder_model.py",
    "chars": 8669,
    "preview": "import torch\nfrom collections import OrderedDict\nfrom torch.autograd import Variable\nimport itertools\nimport util.util a"
  },
  {
    "path": "drawing_and_style_transfer/models/base_model.py",
    "chars": 1919,
    "preview": "import os\nimport torch\n\n\nclass BaseModel(object):\n    def name(self):\n        return 'BaseModel'\n\n    def initialize(sel"
  },
  {
    "path": "drawing_and_style_transfer/models/networks.py",
    "chars": 25147,
    "preview": "import torch\nimport torch.nn as nn\nfrom torch.nn import init\nimport functools\nfrom torch.autograd import Variable\nfrom t"
  },
  {
    "path": "drawing_and_style_transfer/models/ost.py",
    "chars": 14038,
    "preview": "import torch\nfrom collections import OrderedDict\nfrom torch.autograd import Variable\nimport itertools\nimport util.util a"
  },
  {
    "path": "drawing_and_style_transfer/models/test_model.py",
    "chars": 1606,
    "preview": "from torch.autograd import Variable\nfrom collections import OrderedDict\nimport util.util as util\nfrom .base_model import"
  },
  {
    "path": "drawing_and_style_transfer/options/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "drawing_and_style_transfer/options/base_options.py",
    "chars": 7094,
    "preview": "import argparse\nimport os\nfrom util import util\nimport torch\n\n\nclass BaseOptions():\n    def __init__(self):\n        self"
  },
  {
    "path": "drawing_and_style_transfer/options/test_options.py",
    "chars": 879,
    "preview": "from .base_options import BaseOptions\n\n\nclass TestOptions(BaseOptions):\n    def initialize(self):\n        BaseOptions.in"
  },
  {
    "path": "drawing_and_style_transfer/options/train_options.py",
    "chars": 3703,
    "preview": "from .base_options import BaseOptions\n\n\nclass TrainOptions(BaseOptions):\n    def initialize(self):\n        BaseOptions.i"
  },
  {
    "path": "drawing_and_style_transfer/scripts/test_ost.sh",
    "chars": 2217,
    "preview": "# images to cityscapes\npython test.py --dataroot=./datasets/cityscapes/ --name=cityscapes_ost --load_dir=cityscapes_ost "
  },
  {
    "path": "drawing_and_style_transfer/scripts/train_autoencoder.sh",
    "chars": 2019,
    "preview": "# images to cityscapes\npython train.py --dataroot=./datasets/cityscapes/trainB --name=cityscapes_autoencoder --model=aut"
  },
  {
    "path": "drawing_and_style_transfer/scripts/train_ost.sh",
    "chars": 2307,
    "preview": "# images to cityscapes\npython train.py --dataroot=./datasets/cityscapes/ --name=cityscapes_ost --load_dir=cityscapes_aut"
  },
  {
    "path": "drawing_and_style_transfer/test.py",
    "chars": 1482,
    "preview": "import os\nfrom options.test_options import TestOptions\nfrom data import CreateDataLoader\nfrom models import create_model"
  },
  {
    "path": "drawing_and_style_transfer/train.py",
    "chars": 2312,
    "preview": "import time\nfrom options.train_options import TrainOptions\nfrom data import CreateDataLoader\nfrom models import create_m"
  },
  {
    "path": "drawing_and_style_transfer/util/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "drawing_and_style_transfer/util/get_data.py",
    "chars": 3511,
    "preview": "from __future__ import print_function\nimport os\nimport tarfile\nimport requests\nfrom warnings import warn\nfrom zipfile im"
  },
  {
    "path": "drawing_and_style_transfer/util/html.py",
    "chars": 1912,
    "preview": "import dominate\nfrom dominate.tags import *\nimport os\n\n\nclass HTML:\n    def __init__(self, web_dir, title, reflesh=0):\n "
  },
  {
    "path": "drawing_and_style_transfer/util/image_pool.py",
    "chars": 1099,
    "preview": "import random\nimport torch\nfrom torch.autograd import Variable\n\n\nclass ImagePool():\n    def __init__(self, pool_size):\n "
  },
  {
    "path": "drawing_and_style_transfer/util/util.py",
    "chars": 1482,
    "preview": "from __future__ import print_function\nimport torch\nimport numpy as np\nfrom PIL import Image\nimport os\n\n\n# Converts a Ten"
  },
  {
    "path": "drawing_and_style_transfer/util/visualizer.py",
    "chars": 6788,
    "preview": "import numpy as np\nimport os\nimport ntpath\nimport time\nfrom . import util\nfrom . import html\nfrom scipy.misc import imre"
  },
  {
    "path": "mnist_to_svhn/data_loader.py",
    "chars": 2484,
    "preview": "import torch\nfrom torchvision import datasets\nfrom torchvision import transforms\n\n\ndef get_loader(config):\n    \"\"\"Builds"
  },
  {
    "path": "mnist_to_svhn/download.sh",
    "chars": 279,
    "preview": "mkdir -p mnist\nmkdir -p svhn\n\nwget -O svhn/train_32x32.mat http://ufldl.stanford.edu/housenumbers/train_32x32.mat\nwget -"
  },
  {
    "path": "mnist_to_svhn/main_autoencoder.py",
    "chars": 2164,
    "preview": "import argparse\nimport os\nfrom torch.backends import cudnn\n\nfrom solver_autoencoder import Solver\nfrom data_loader impor"
  },
  {
    "path": "mnist_to_svhn/main_mnist_to_svhn.py",
    "chars": 3292,
    "preview": "import argparse\nimport logging\nimport os\nfrom torch.backends import cudnn\n\nfrom data_loader import get_loader\nfrom solve"
  },
  {
    "path": "mnist_to_svhn/main_svhn_to_mnist.py",
    "chars": 3293,
    "preview": "import argparse\nimport logging\nimport os\n\nfrom data_loader import get_loader\nfrom solver_svhn_to_mnist import Solver\nfro"
  },
  {
    "path": "mnist_to_svhn/model.py",
    "chars": 7627,
    "preview": "import torch.nn as nn\nimport torch.nn.functional as F\n\n\ndef deconv(c_in, c_out, k_size, stride=2, pad=1, bn=True):\n    \""
  },
  {
    "path": "mnist_to_svhn/solver_autoencoder.py",
    "chars": 8815,
    "preview": "import os\n\nimport numpy as np\nimport scipy.io\nimport torch\nfrom torch import optim\nfrom torch.autograd import Variable\n\n"
  },
  {
    "path": "mnist_to_svhn/solver_mnist_to_svhn.py",
    "chars": 9652,
    "preview": "import os\n\nimport numpy as np\nimport scipy.io\nimport torch\nfrom torch import optim\nfrom torch.autograd import Variable\n\n"
  },
  {
    "path": "mnist_to_svhn/solver_svhn_to_mnist.py",
    "chars": 9609,
    "preview": "import os\n\nimport numpy as np\nimport scipy.io\nimport torch\nfrom torch import optim\nfrom torch.autograd import Variable\n\n"
  }
]

About this extraction

This page contains the full source code of the sagiebenaim/OneShotTranslation GitHub repository, extracted and formatted as plain text. The extraction includes 43 files (155.3 KB), approximately 39.6k tokens, and a symbol index with 239 extracted functions, classes, methods, constants, and types.

