Repository: sagiebenaim/OneShotTranslation Branch: master Commit: 6f790ae5f4eb Files: 43 Total size: 155.3 KB Directory structure: gitextract_3ffpn665/ ├── LICENSE ├── README.md ├── drawing_and_style_transfer/ │ ├── data/ │ │ ├── __init__.py │ │ ├── aligned_dataset.py │ │ ├── base_data_loader.py │ │ ├── base_dataset.py │ │ ├── image_folder.py │ │ ├── single_dataset.py │ │ └── unaligned_dataset.py │ ├── datasets/ │ │ ├── combine_A_and_B.py │ │ ├── download_cyclegan_dataset.sh │ │ └── make_dataset_aligned.py │ ├── environment.yml │ ├── models/ │ │ ├── __init__.py │ │ ├── autoencoder_model.py │ │ ├── base_model.py │ │ ├── networks.py │ │ ├── ost.py │ │ └── test_model.py │ ├── options/ │ │ ├── __init__.py │ │ ├── base_options.py │ │ ├── test_options.py │ │ └── train_options.py │ ├── scripts/ │ │ ├── test_ost.sh │ │ ├── train_autoencoder.sh │ │ └── train_ost.sh │ ├── test.py │ ├── train.py │ └── util/ │ ├── __init__.py │ ├── get_data.py │ ├── html.py │ ├── image_pool.py │ ├── util.py │ └── visualizer.py └── mnist_to_svhn/ ├── data_loader.py ├── download.sh ├── main_autoencoder.py ├── main_mnist_to_svhn.py ├── main_svhn_to_mnist.py ├── model.py ├── solver_autoencoder.py ├── solver_mnist_to_svhn.py └── solver_svhn_to_mnist.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2017 Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included 
in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. --------------------------- LICENSE FOR mnist-svhn-transfer --------- MIT License Copyright (c) 2017 Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. --------------------------- LICENSE FOR pytorch-CycleGAN-and-pix2pix --------- Copyright (c) 2017, Jun-Yan Zhu and Taesung Park All rights reserved. 
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ================================================ FILE: README.md ================================================ # Pytorch implementation of One-Shot Unsupervised Cross Domain Translation ([arxiv](https://arxiv.org/abs/1806.06029)). 
Prerequisites -------------- - Python 3.6 - PyTorch 0.4 - Numpy/Scipy/Pandas - Progressbar - OpenCV - [visdom](https://github.com/facebookresearch/visdom) - [dominate](https://github.com/Knio/dominate) ## MNIST-to-SVHN and SVHN-to-MNIST To train the autoencoder for both MNIST and SVHN (in the mnist_to_svhn folder): python main_autoencoder.py --use_augmentation=True To train OST for MNIST to SVHN: python main_mnist_to_svhn.py --pretrained_g=True --save_models_and_samples=True --use_augmentation=True --one_way_cycle=True --freeze_shared=False To train OST for SVHN to MNIST: python main_svhn_to_mnist.py --pretrained_g=True --save_models_and_samples=True --use_augmentation=True --one_way_cycle=True --freeze_shared=False ## Drawing and Style Transfer Tasks ### Download Dataset To download a dataset (in the drawing_and_style_transfer folder): bash datasets/download_cyclegan_dataset.sh $DATASET_NAME where DATASET_NAME is one of (facades, cityscapes, maps, monet2photo, summer2winter_yosemite) ### Train Autoencoder To train the autoencoder for facades (in the drawing_and_style_transfer folder): python train.py --dataroot=./datasets/facades/trainB --name=facades_autoencoder --model=autoencoder --dataset_mode=single --no_dropout --n_downsampling=2 --num_unshared=2 To train the autoencoder for the reverse direction (images of facades): python train.py --dataroot=./datasets/facades/trainA --name=facades_autoencoder_reverse --model=autoencoder --dataset_mode=single --no_dropout --n_downsampling=2 --num_unshared=2 ### Train OST To train OST for images to facades: python train.py --dataroot=./datasets/facades/ --name=facades_ost --load_dir=facades_autoencoder --model=ost --no_dropout --n_downsampling=2 --num_unshared=2 --start=0 --max_items_A=1 To train OST for facades to images (reverse direction): python train.py --dataroot=./datasets/facades/ --name=facades_ost_reverse --load_dir=facades_autoencoder_reverse --model=ost --no_dropout --n_downsampling=2 --num_unshared=2 --start=0 --max_items_A=1 --A='B' --B='A' To visualize 
losses: run python -m visdom.server ### Test OST To test OST for images to facades: python test.py --dataroot=./datasets/facades/ --name=facades_ost --model=ost --no_dropout --n_downsampling=2 --num_unshared=2 --start=0 --max_items_A=1 To test OST for facades to images (reverse direction): python test.py --dataroot=./datasets/facades/ --name=facades_ost_reverse --model=ost --no_dropout --n_downsampling=2 --num_unshared=2 --start=0 --max_items_A=1 --A='B' --B='A' ### Options Additional scripts for other datasets are at ./drawing_and_style_transfer/scripts Options are at ./drawing_and_style_transfer/options ## Reference If you found this code useful, please cite the following paper: ``` @inproceedings{Benaim2018OneShotUC, title={One-Shot Unsupervised Cross Domain Translation}, author={Sagie Benaim and Lior Wolf}, booktitle={NeurIPS}, year={2018} } ``` ================================================ FILE: drawing_and_style_transfer/data/__init__.py ================================================ import torch.utils.data from data.base_data_loader import BaseDataLoader def CreateDataLoader(opt): data_loader = CustomDatasetDataLoader() print(data_loader.name()) data_loader.initialize(opt) return data_loader def CreateDataset(opt): if opt.dataset_mode == 'aligned': from data.aligned_dataset import AlignedDataset dataset = AlignedDataset() elif opt.dataset_mode == 'unaligned': from data.unaligned_dataset import UnalignedDataset dataset = UnalignedDataset() elif opt.dataset_mode == 'single': from data.single_dataset import SingleDataset dataset = SingleDataset() else: raise ValueError("Dataset [%s] not recognized." 
% opt.dataset_mode) print("dataset [%s] was created" % (dataset.name())) dataset.initialize(opt) return dataset class CustomDatasetDataLoader(BaseDataLoader): def name(self): return 'CustomDatasetDataLoader' def initialize(self, opt): BaseDataLoader.initialize(self, opt) self.dataset = CreateDataset(opt) self.dataloader = torch.utils.data.DataLoader( self.dataset, batch_size=opt.batchSize, shuffle=not opt.serial_batches, num_workers=int(opt.nThreads)) def load_data(self): return self def __len__(self): return len(self.dataset) def __iter__(self): for i, data in enumerate(self.dataloader): yield data ================================================ FILE: drawing_and_style_transfer/data/aligned_dataset.py ================================================ import os.path import random import torchvision.transforms as transforms import torch from data.base_dataset import BaseDataset from data.image_folder import make_dataset from PIL import Image class AlignedDataset(BaseDataset): def initialize(self, opt): self.opt = opt self.root = opt.dataroot self.dir_AB = os.path.join(opt.dataroot, opt.phase) self.AB_paths = sorted(make_dataset(self.dir_AB)) assert (opt.resize_or_crop == 'resize_and_crop') def __getitem__(self, index): AB_path = self.AB_paths[index] AB = Image.open(AB_path).convert('RGB') w, h = AB.size w2 = int(w / 2) A = AB.crop((0, 0, w2, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC) B = AB.crop((w2, 0, w, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC) A = transforms.ToTensor()(A) B = transforms.ToTensor()(B) w_offset = random.randint(0, max(0, self.opt.loadSize - self.opt.fineSize - 1)) h_offset = random.randint(0, max(0, self.opt.loadSize - self.opt.fineSize - 1)) A = A[:, h_offset:h_offset + self.opt.fineSize, w_offset:w_offset + self.opt.fineSize] B = B[:, h_offset:h_offset + self.opt.fineSize, w_offset:w_offset + self.opt.fineSize] A = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(A) B = 
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(B) if self.opt.which_direction == 'BtoA': input_nc = self.opt.output_nc output_nc = self.opt.input_nc else: input_nc = self.opt.input_nc output_nc = self.opt.output_nc if (not self.opt.no_flip) and random.random() < 0.5: idx = [i for i in range(A.size(2) - 1, -1, -1)] idx = torch.LongTensor(idx) A = A.index_select(2, idx) B = B.index_select(2, idx) if input_nc == 1: # RGB to gray tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114 A = tmp.unsqueeze(0) if output_nc == 1: # RGB to gray tmp = B[0, ...] * 0.299 + B[1, ...] * 0.587 + B[2, ...] * 0.114 B = tmp.unsqueeze(0) return {'A': A, 'B': B, 'A_paths': AB_path, 'B_paths': AB_path} def __len__(self): return len(self.AB_paths) def name(self): return 'AlignedDataset' ================================================ FILE: drawing_and_style_transfer/data/base_data_loader.py ================================================ class BaseDataLoader(): def __init__(self): pass def initialize(self, opt): self.opt = opt pass def load_data(self): return None ================================================ FILE: drawing_and_style_transfer/data/base_dataset.py ================================================ import torch.utils.data as data from PIL import Image import torchvision.transforms as transforms class BaseDataset(data.Dataset): def __init__(self): super(BaseDataset, self).__init__() def name(self): return 'BaseDataset' def initialize(self, opt): pass def get_transform(opt): transform_list = [] if opt.resize_or_crop == 'resize_and_crop': osize = [opt.loadSize, opt.loadSize] transform_list.append(transforms.Scale(osize, Image.BICUBIC)) transform_list.append(transforms.RandomCrop(opt.fineSize)) elif opt.resize_or_crop == 'crop': transform_list.append(transforms.RandomCrop(opt.fineSize)) elif opt.resize_or_crop == 'scale_width': transform_list.append(transforms.Lambda( lambda img: __scale_width(img, opt.fineSize))) elif opt.resize_or_crop == 
'scale_width_and_crop': transform_list.append(transforms.Lambda( lambda img: __scale_width(img, opt.loadSize))) transform_list.append(transforms.RandomCrop(opt.fineSize)) if opt.isTrain and not opt.no_flip_and_rotation: # Default augmentations as in paper transform_list.append(transforms.RandomHorizontalFlip()) transform_list.append(transforms.RandomRotation(opt.rotation_degree)) transform_list += [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] return transforms.Compose(transform_list) def __scale_width(img, target_width): ow, oh = img.size if (ow == target_width): return img w = target_width h = int(target_width * oh / ow) return img.resize((w, h), Image.BICUBIC) ================================================ FILE: drawing_and_style_transfer/data/image_folder.py ================================================ ############################################################################### # Code from # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py # Modified the original code so that it also loads images from the current # directory as well as the subdirectories ############################################################################### import torch.utils.data as data from PIL import Image import os import os.path IMG_EXTENSIONS = [ '.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', ] def is_image_file(filename): return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) def make_dataset(dir, max_items=-1, start=0): images = [] assert os.path.isdir(dir), '%s is not a valid directory' % dir for root, _, fnames in sorted(os.walk(dir)): for fname in fnames: if is_image_file(fname): path = os.path.join(root, fname) images.append(path) if max_items >= 0: return sorted(images)[start:start + max_items] return images def default_loader(path): return Image.open(path).convert('RGB') class ImageFolder(data.Dataset): def __init__(self, root, transform=None, 
return_paths=False, loader=default_loader): imgs = make_dataset(root) if len(imgs) == 0: raise (RuntimeError("Found 0 images in: " + root + "\n" "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) self.root = root self.imgs = imgs self.transform = transform self.return_paths = return_paths self.loader = loader def __getitem__(self, index): path = self.imgs[index] img = self.loader(path) if self.transform is not None: img = self.transform(img) if self.return_paths: return img, path else: return img def __len__(self): return len(self.imgs) ================================================ FILE: drawing_and_style_transfer/data/single_dataset.py ================================================ import os.path from data.base_dataset import BaseDataset, get_transform from data.image_folder import make_dataset from PIL import Image class SingleDataset(BaseDataset): def initialize(self, opt): self.opt = opt self.root = opt.dataroot self.dir_A = os.path.join(opt.dataroot) self.A_paths = make_dataset(self.dir_A) self.A_paths = sorted(self.A_paths) self.transform = get_transform(opt) def __getitem__(self, index): A_path = self.A_paths[index] A_img = Image.open(A_path).convert('RGB') A = self.transform(A_img) if self.opt.which_direction == 'BtoA': input_nc = self.opt.output_nc else: input_nc = self.opt.input_nc if input_nc == 1: # RGB to gray tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] 
* 0.114 A = tmp.unsqueeze(0) return {'A': A, 'A_paths': A_path} def __len__(self): return len(self.A_paths) def name(self): return 'SingleImageDataset' ================================================ FILE: drawing_and_style_transfer/data/unaligned_dataset.py ================================================ import os.path from data.base_dataset import BaseDataset, get_transform from data.image_folder import make_dataset from PIL import Image import random class UnalignedDataset(BaseDataset): def initialize(self, opt): self.opt = opt self.root = opt.dataroot self.dir_A = os.path.join(opt.dataroot, opt.phase + opt.A) self.dir_B = os.path.join(opt.dataroot, opt.phase + opt.B) self.A_paths = make_dataset(self.dir_A, max_items=opt.max_items_A, start=opt.start) self.B_paths = make_dataset(self.dir_B, max_items=opt.max_items_B, start=opt.start) self.A_paths = sorted(self.A_paths) self.B_paths = sorted(self.B_paths) self.A_size = len(self.A_paths) self.B_size = len(self.B_paths) self.transform = get_transform(opt) def __getitem__(self, index): A_path = self.A_paths[index % self.A_size] if self.opt.serial_batches: index_B = index % self.B_size else: index_B = random.randint(0, self.B_size - 1) B_path = self.B_paths[index_B] # print('(A, B) = (%d, %d)' % (index_A, index_B)) A_img = Image.open(A_path).convert('RGB') B_img = Image.open(B_path).convert('RGB') A = self.transform(A_img) B = self.transform(B_img) if self.opt.which_direction == 'BtoA': input_nc = self.opt.output_nc output_nc = self.opt.input_nc else: input_nc = self.opt.input_nc output_nc = self.opt.output_nc if input_nc == 1: # RGB to gray tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114 A = tmp.unsqueeze(0) if output_nc == 1: # RGB to gray tmp = B[0, ...] * 0.299 + B[1, ...] * 0.587 + B[2, ...] 
* 0.114 B = tmp.unsqueeze(0) return {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path} def __len__(self): return max(self.A_size, self.B_size) def name(self): return 'UnalignedDataset' ================================================ FILE: drawing_and_style_transfer/datasets/combine_A_and_B.py ================================================ import os import numpy as np import cv2 import argparse parser = argparse.ArgumentParser('create image pairs') parser.add_argument('--fold_A', dest='fold_A', help='input directory for image A', type=str, default='../dataset/50kshoes_edges') parser.add_argument('--fold_B', dest='fold_B', help='input directory for image B', type=str, default='../dataset/50kshoes_jpg') parser.add_argument('--fold_AB', dest='fold_AB', help='output directory', type=str, default='../dataset/test_AB') parser.add_argument('--num_imgs', dest='num_imgs', help='number of images',type=int, default=1000000) parser.add_argument('--use_AB', dest='use_AB', help='if true: (0001_A, 0001_B) to (0001_AB)',action='store_true') args = parser.parse_args() for arg in vars(args): print('[%s] = ' % arg, getattr(args, arg)) splits = os.listdir(args.fold_A) for sp in splits: img_fold_A = os.path.join(args.fold_A, sp) img_fold_B = os.path.join(args.fold_B, sp) img_list = os.listdir(img_fold_A) if args.use_AB: img_list = [img_path for img_path in img_list if '_A.' 
in img_path] num_imgs = min(args.num_imgs, len(img_list)) print('split = %s, use %d/%d images' % (sp, num_imgs, len(img_list))) img_fold_AB = os.path.join(args.fold_AB, sp) if not os.path.isdir(img_fold_AB): os.makedirs(img_fold_AB) print('split = %s, number of images = %d' % (sp, num_imgs)) for n in range(num_imgs): name_A = img_list[n] path_A = os.path.join(img_fold_A, name_A) if args.use_AB: name_B = name_A.replace('_A.', '_B.') else: name_B = name_A path_B = os.path.join(img_fold_B, name_B) if os.path.isfile(path_A) and os.path.isfile(path_B): name_AB = name_A if args.use_AB: name_AB = name_AB.replace('_A.', '.') # remove _A path_AB = os.path.join(img_fold_AB, name_AB) im_A = cv2.imread(path_A, cv2.CV_LOAD_IMAGE_COLOR) im_B = cv2.imread(path_B, cv2.CV_LOAD_IMAGE_COLOR) im_AB = np.concatenate([im_A, im_B], 1) cv2.imwrite(path_AB, im_AB) ================================================ FILE: drawing_and_style_transfer/datasets/download_cyclegan_dataset.sh ================================================ FILE=$1 if [[ $FILE != "ae_photos" && $FILE != "apple2orange" && $FILE != "summer2winter_yosemite" && $FILE != "horse2zebra" && $FILE != "monet2photo" && $FILE != "cezanne2photo" && $FILE != "ukiyoe2photo" && $FILE != "vangogh2photo" && $FILE != "maps" && $FILE != "cityscapes" && $FILE != "facades" && $FILE != "iphone2dslr_flower" && $FILE != "ae_photos" ]]; then echo "Available datasets are: apple2orange, summer2winter_yosemite, horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, ae_photos" exit 1 fi URL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/$FILE.zip ZIP_FILE=./datasets/$FILE.zip TARGET_DIR=./datasets/$FILE/ wget -N $URL -O $ZIP_FILE mkdir $TARGET_DIR unzip $ZIP_FILE -d ./datasets/ rm $ZIP_FILE ================================================ FILE: drawing_and_style_transfer/datasets/make_dataset_aligned.py ================================================ import os from 
PIL import Image def get_file_paths(folder): image_file_paths = [] for root, dirs, filenames in os.walk(folder): filenames = sorted(filenames) for filename in filenames: input_path = os.path.abspath(root) file_path = os.path.join(input_path, filename) if filename.endswith('.png') or filename.endswith('.jpg'): image_file_paths.append(file_path) break # prevent descending into subfolders return image_file_paths def align_images(a_file_paths, b_file_paths, target_path): if not os.path.exists(target_path): os.makedirs(target_path) for i in range(len(a_file_paths)): img_a = Image.open(a_file_paths[i]) img_b = Image.open(b_file_paths[i]) assert(img_a.size == img_b.size) aligned_image = Image.new("RGB", (img_a.size[0] * 2, img_a.size[1])) aligned_image.paste(img_a, (0, 0)) aligned_image.paste(img_b, (img_a.size[0], 0)) aligned_image.save(os.path.join(target_path, '{:04d}.jpg'.format(i))) if __name__ == '__main__': import argparse parser = argparse.ArgumentParser() parser.add_argument( '--dataset-path', dest='dataset_path', help='Which folder to process (it should have subfolders testA, testB, trainA and trainB' ) args = parser.parse_args() dataset_folder = args.dataset_path print(dataset_folder) test_a_path = os.path.join(dataset_folder, 'testA') test_b_path = os.path.join(dataset_folder, 'testB') test_a_file_paths = get_file_paths(test_a_path) test_b_file_paths = get_file_paths(test_b_path) assert(len(test_a_file_paths) == len(test_b_file_paths)) test_path = os.path.join(dataset_folder, 'test') train_a_path = os.path.join(dataset_folder, 'trainA') train_b_path = os.path.join(dataset_folder, 'trainB') train_a_file_paths = get_file_paths(train_a_path) train_b_file_paths = get_file_paths(train_b_path) assert(len(train_a_file_paths) == len(train_b_file_paths)) train_path = os.path.join(dataset_folder, 'train') align_images(test_a_file_paths, test_b_file_paths, test_path) align_images(train_a_file_paths, train_b_file_paths, train_path) 
================================================ FILE: drawing_and_style_transfer/environment.yml ================================================ name: OST channels: - peterjc123 - defaults dependencies: - python=3.6.5 - pytorch=0.4.0 - scipy - pip: - dominate==2.3.1 - git+https://github.com/pytorch/vision.git - Pillow==5.0.0 - numpy==1.14.1 - visdom==0.1.7 ================================================ FILE: drawing_and_style_transfer/models/__init__.py ================================================ def create_model(opt): print(opt.model) if opt.model == 'ost': assert (opt.dataset_mode == 'unaligned') from .ost import OSTModel model = OSTModel() elif opt.model == 'autoencoder': assert (opt.dataset_mode == 'single') from .autoencoder_model import AutoEncoderModel model = AutoEncoderModel() elif opt.model == 'test': assert (opt.dataset_mode == 'single') from .test_model import TestModel model = TestModel() else: raise NotImplementedError('model [%s] not implemented.' % opt.model) model.initialize(opt) print("model [%s] was created" % (model.name())) return model ================================================ FILE: drawing_and_style_transfer/models/autoencoder_model.py ================================================ import torch from collections import OrderedDict from torch.autograd import Variable import itertools import util.util as util from util.image_pool import ImagePool from .base_model import BaseModel from . 
import networks class AutoEncoderModel(BaseModel): def name(self): return 'AutoEncoderModel' def set_encoders_and_decoders(self, opt): n_downsampling = opt.n_downsampling start_unshared = 0 num_unshared = opt.num_unshared start_shared = num_unshared end_shared = n_downsampling start_dec_shared = start_unshared end_dec_shared = start_unshared + (end_shared - start_shared) start_dec_unshared = end_dec_shared end_dec_unshared = n_downsampling num_res_blocks_unshared = opt.num_res_blocks_unshared n_res_blocks_shared = opt.num_res_blocks_shared self.netEnc_b, self.netDec_b = networks.define_ED(opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids, n_blocks_encoder=num_res_blocks_unshared, n_blocks_decoder=num_res_blocks_unshared, start=start_unshared, end=num_unshared, n_downsampling=n_downsampling, input_layer=True, output_layer=True, start_dec=start_dec_unshared, end_dec=end_dec_unshared) self.netEnc_shared, self.netDec_shared = networks.define_ED(opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids, n_blocks_encoder=n_res_blocks_shared, n_blocks_decoder=n_res_blocks_shared, start=start_shared, n_downsampling=n_downsampling, end=end_shared, input_layer=False, output_layer=False, start_dec=start_dec_shared, end_dec=end_dec_shared) def initialize(self, opt): BaseModel.initialize(self, opt) self.set_encoders_and_decoders(opt) if self.isTrain: use_sigmoid = opt.no_lsgan self.netD = networks.define_D(opt.output_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids) if not self.isTrain or opt.continue_train: which_epoch = opt.which_epoch self.load_network(self.netEnc_b, 'Enc_b', which_epoch) self.load_network(self.netDec_b, 'Dec_b', which_epoch) self.load_network(self.netEnc_shared, 'Enc_shared', which_epoch) self.load_network(self.netDec_shared, 'Dec_shared', which_epoch) if self.isTrain: 
self.load_network(self.netD, 'D', which_epoch) if self.isTrain: self.fake_B_pool = ImagePool(opt.pool_size) # define loss functions self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor) self.criterionIdt = torch.nn.L1Loss() # initialize optimizers self.optimizer_Enc = torch.optim.Adam( itertools.chain(self.netEnc_b.parameters(), self.netEnc_shared.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizer_Dec = torch.optim.Adam( itertools.chain(self.netDec_b.parameters(), self.netDec_shared.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) self.optimizers = [] self.schedulers = [] self.optimizers.append(self.optimizer_Enc) self.optimizers.append(self.optimizer_Dec) self.optimizers.append(self.optimizer_D) for optimizer in self.optimizers: self.schedulers.append(networks.get_scheduler(optimizer, opt)) print('---------- Networks initialized -------------') networks.print_network(self.netEnc_b) networks.print_network(self.netDec_b) networks.print_network(self.netEnc_shared) networks.print_network(self.netDec_shared) if self.isTrain: networks.print_network(self.netD) print('-----------------------------------------------') def set_input(self, input): # 'A' is given as single_dataset input_B = input['A'] if len(self.gpu_ids) > 0: input_B = input_B.cuda(self.gpu_ids[0], async=True) self.input_B = input_B # 'A' is given as single_dataset self.image_paths = input['A_paths'] def forward(self): self.real_B = Variable(self.input_B) def netEnc(self, x): return self.netEnc_shared(self.netEnc_b(x)) def netDec(self, x): return self.netDec_b(self.netDec_shared(x)) def test(self): real_B = Variable(self.input_B, volatile=True) fake_B = self.netDec(self.netEnc(real_B)) self.fake_B = fake_B.data # get image paths def get_image_paths(self): return self.image_paths def backward_D_basic(self, netD, real, fake): # Real pred_real = netD(real) 
loss_D_real = self.criterionGAN(pred_real, True) # Fake pred_fake = netD(fake.detach()) loss_D_fake = self.criterionGAN(pred_fake, False) # Combined loss loss_D = (loss_D_real + loss_D_fake) * 0.5 # backward loss_D.backward() return loss_D def backward_D(self): fake_B = self.fake_B_pool.query(self.fake_B) loss_D = self.backward_D_basic(self.netD, self.real_B, fake_B) self.loss_D = loss_D.data[0] def _compute_kl(self, mu): mu_2 = torch.pow(mu, 2) encoding_loss = torch.mean(mu_2) return encoding_loss def backward_G(self): lambda_B = self.opt.lambda_B # GAN loss D_B(G_B(B)) enc_b = self.netEnc(self.real_B) fake_B = self.netDec(enc_b) pred_fake = self.netD(fake_B) loss_Gan = self.criterionGAN(pred_fake, True) loss_idt_B = self.criterionIdt(fake_B, self.real_B) * lambda_B loss_kl_B = self.opt.kl_lambda * self._compute_kl(enc_b) # combined loss loss_G = loss_Gan + loss_idt_B + loss_kl_B loss_G.backward() self.fake_B = fake_B.data self.loss_Gan = loss_Gan.data[0] self.loss_idt_B = loss_idt_B.data[0] def optimize_parameters(self): # forward self.forward() # G self.optimizer_Enc.zero_grad() self.optimizer_Dec.zero_grad() self.backward_G() self.optimizer_Enc.step() self.optimizer_Dec.step() # D self.optimizer_D.zero_grad() self.backward_D() self.optimizer_D.step() def get_current_errors(self): ret_errors = OrderedDict([('D', self.loss_D), ('G_B', self.loss_Gan), ('Idt_B', self.loss_idt_B)]) return ret_errors def get_current_visuals(self): real_B = util.tensor2im(self.input_B) fake_B = util.tensor2im(self.fake_B) ret_visuals = OrderedDict([('real_B', real_B), ('fake_B', fake_B), ]) return ret_visuals def save(self, label): self.save_network(self.netEnc_b, 'Enc_b', label, self.gpu_ids) self.save_network(self.netDec_b, 'Dec_b', label, self.gpu_ids) self.save_network(self.netEnc_shared, 'Enc_shared', label, self.gpu_ids) self.save_network(self.netDec_shared, 'Dec_shared', label, self.gpu_ids) self.save_network(self.netD, 'D', label, self.gpu_ids) 
# ================================================
# FILE: drawing_and_style_transfer/models/base_model.py
# ================================================
import os
import torch


class BaseModel(object):
    """Base class for models: option storage, checkpoint paths and save/load helpers.

    Subclasses override the hook methods below (most are no-ops here) and, when
    training, are expected to populate ``self.optimizers`` and
    ``self.schedulers`` before ``update_learning_rate`` is called.
    """

    def name(self):
        return 'BaseModel'

    def initialize(self, opt):
        """Store ``opt`` and derive GPU settings and checkpoint directories from it."""
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        # Tensor constructor for the selected device (CUDA if any GPU id was given).
        self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
        # Weights are read from <checkpoints_dir>/<load_dir> but written to
        # <checkpoints_dir>/<name>; the two may be different directories.
        self.load_dir = os.path.join(opt.checkpoints_dir, opt.load_dir)
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)

    def set_input(self, input):
        self.input = input

    def forward(self):
        pass

    # used in test time, no backprop
    def test(self):
        pass

    def get_image_paths(self):
        pass

    def optimize_parameters(self):
        pass

    def get_current_visuals(self):
        return self.input

    def get_current_errors(self):
        return {}

    def save(self, label):
        pass

    # helper saving function that can be used by subclasses
    def save_network(self, network, network_label, epoch_label, gpu_ids):
        """Save ``network.state_dict()`` to ``<save_dir>/<epoch_label>_net_<network_label>.pth``.

        The network is moved to CPU before saving, then moved back to
        ``gpu_ids[0]`` when GPU ids were given and CUDA is available.
        """
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        torch.save(network.cpu().state_dict(), save_path)
        if len(gpu_ids) and torch.cuda.is_available():
            network.cuda(gpu_ids[0])

    # helper loading function that can be used by subclasses
    def load_network(self, network, network_label, epoch_label):
        """Load ``<load_dir>/<epoch_label>_net_<network_label>.pth`` into ``network``."""
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.load_dir, save_filename)
        network.load_state_dict(torch.load(save_path))

    # update learning rate (called once every epoch)
    def update_learning_rate(self):
        """Step every scheduler once and print the first optimizer's learning rate."""
        # NOTE(review): ReduceLROnPlateau.step() requires a metrics argument, so this
        # call fails for lr_policy == 'plateau'. TODO: pass a validation metric there.
        for scheduler in self.schedulers:
            scheduler.step()
        lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate = %.7f' % lr)

    def as_np(self, data):
        """Return ``data`` as a NumPy array (moved to CPU first)."""
        return data.cpu().data.numpy()


# ================================================
# FILE: drawing_and_style_transfer/models/networks.py
# ================================================
import torch
import torch.nn as nn
from torch.nn import init
import functools
from torch.autograd import Variable
from torch.optim import lr_scheduler

###############################################################################
# Functions
###############################################################################


class pixel_norm(nn.Module):
    """Scale each spatial position so its feature vector has unit RMS over dim 1."""

    def forward(self, x, epsilon=1e-8):
        return x * torch.rsqrt(torch.mean(x.pow(2), dim=1, keepdim=True) + epsilon)


def weights_init_normal(m):
    """Init Conv/Linear weights ~ N(0, 0.02); BatchNorm2d weight ~ N(1, 0.02), bias = 0."""
    classname = m.__class__.__name__
    if 'Conv' in classname:
        init.normal(m.weight.data, 0.0, 0.02)
    elif 'Linear' in classname:
        init.normal(m.weight.data, 0.0, 0.02)
    elif 'BatchNorm2d' in classname:
        init.normal(m.weight.data, 1.0, 0.02)
        init.constant(m.bias.data, 0.0)


def weights_init_xavier(m):
    """Init Conv/Linear weights with Xavier-normal (gain 0.02); BatchNorm2d as in the normal scheme."""
    classname = m.__class__.__name__
    if 'Conv' in classname:
        init.xavier_normal(m.weight.data, gain=0.02)
    elif 'Linear' in classname:
        init.xavier_normal(m.weight.data, gain=0.02)
    elif 'BatchNorm2d' in classname:
        init.normal(m.weight.data, 1.0, 0.02)
        init.constant(m.bias.data, 0.0)


def weights_init_kaiming(m):
    """Init Conv/Linear weights with Kaiming-normal (fan_in); BatchNorm2d as in the normal scheme."""
    classname = m.__class__.__name__
    if 'Conv' in classname:
        init.kaiming_normal(m.weight.data, a=0, mode='fan_in')
    elif 'Linear' in classname:
        init.kaiming_normal(m.weight.data, a=0, mode='fan_in')
    elif 'BatchNorm2d' in classname:
        init.normal(m.weight.data, 1.0, 0.02)
        init.constant(m.bias.data, 0.0)


def weights_init_orthogonal(m):
    """Init Conv/Linear weights orthogonally (gain 1); BatchNorm2d as in the normal scheme."""
    classname = m.__class__.__name__
    # Debug print of every module's class name removed for consistency with the
    # other initializers (they keep it commented out).
    if 'Conv' in classname:
        init.orthogonal(m.weight.data, gain=1)
    elif 'Linear' in classname:
        init.orthogonal(m.weight.data, gain=1)
    elif 'BatchNorm2d' in classname:
        init.normal(m.weight.data, 1.0, 0.02)
        init.constant(m.bias.data, 0.0)


def init_weights(net, init_type='normal'):
    """Apply the initializer named by ``init_type`` to every submodule of ``net``.

    Raises:
        NotImplementedError: for an unknown ``init_type``.
    """
    print('initialization method [%s]' % init_type)
    if init_type == 'normal':
        net.apply(weights_init_normal)
    elif init_type == 'xavier':
        net.apply(weights_init_xavier)
    elif init_type == 'kaiming':
        net.apply(weights_init_kaiming)
    elif init_type == 'orthogonal':
        net.apply(weights_init_orthogonal)
    else:
        raise NotImplementedError('initialization method [%s] is not implemented' % init_type)


def get_norm_layer(norm_type='instance'):
    """Return a normalization-layer constructor ('batch', 'instance') or None for 'none'."""
    if norm_type == 'batch':
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
    elif norm_type == 'instance':
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
    elif norm_type == 'none':
        norm_layer = None
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
    return norm_layer


def get_scheduler(optimizer, opt):
    """Build the learning-rate scheduler selected by ``opt.lr_policy``.

    'lambda' keeps the LR constant for ``opt.niter`` epochs, then decays it linearly
    over ``opt.niter_decay`` epochs; 'step' multiplies it by 0.1 every
    ``opt.lr_decay_iters`` epochs; 'plateau' uses ReduceLROnPlateau.

    Raises:
        NotImplementedError: for an unknown policy (previously the exception object
            was *returned* as if it were a scheduler).
    """
    if opt.lr_policy == 'lambda':
        def lambda_rule(epoch):
            lr_l = 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
            return lr_l
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif opt.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    elif opt.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    else:
        raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
    return scheduler


def define_ED(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, init_type='normal',
              gpu_ids=None, n_downsampling=2, start=0, end=2, input_layer=True, output_layer=True,
              n_blocks_encoder=9, n_blocks_decoder=9, start_dec=0, end_dec=1):
    """Build a (ResnetEncoder, ResnetDecoder) pair covering a slice of a ResNet generator.

    ``start``/``end`` select which downsampling stages the encoder builds and
    ``start_dec``/``end_dec`` which upsampling stages the decoder builds, so a full
    generator can be split into unshared and shared parts. Both networks are moved
    to ``gpu_ids[0]`` (if given) and initialized with ``init_type``.

    Raises:
        NotImplementedError: if ``which_model_netG`` is not a resnet variant.
    """
    gpu_ids = [] if gpu_ids is None else gpu_ids
    use_gpu = len(gpu_ids) > 0
    norm_layer = get_norm_layer(norm_type=norm)

    if use_gpu:
        assert (torch.cuda.is_available())

    # Both resnet names build the same encoder/decoder here; block counts come from
    # n_blocks_encoder / n_blocks_decoder rather than from the model name.
    if which_model_netG in ('resnet_9blocks', 'resnet_6blocks'):
        netE = ResnetEncoder(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout,
                             n_blocks=n_blocks_encoder, gpu_ids=gpu_ids, n_downsampling=n_downsampling,
                             start=start, end=end, input_layer=input_layer)
        netD = ResnetDecoder(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout,
                             n_blocks=n_blocks_decoder, gpu_ids=gpu_ids, n_downsampling=n_downsampling,
                             end=end_dec, start=start_dec, output_layer=output_layer)
    else:
        raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)

    if len(gpu_ids) > 0:
        netE.cuda(gpu_ids[0])
        netD.cuda(gpu_ids[0])

    init_weights(netE, init_type=init_type)
    init_weights(netD, init_type=init_type)
    return netE, netD


def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, init_type='normal',
             gpu_ids=None):
    """Build, place and initialize a full generator ('resnet_9blocks', 'resnet_6blocks', 'unet_128', 'unet_256')."""
    gpu_ids = [] if gpu_ids is None else gpu_ids
    netG = None
    use_gpu = len(gpu_ids) > 0
    norm_layer = get_norm_layer(norm_type=norm)

    if use_gpu:
        assert (torch.cuda.is_available())

    if which_model_netG == 'resnet_9blocks':
        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout,
                               n_blocks=9, gpu_ids=gpu_ids)
    elif which_model_netG == 'resnet_6blocks':
        netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout,
                               n_blocks=6, gpu_ids=gpu_ids)
    elif which_model_netG == 'unet_128':
        netG = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout,
                             gpu_ids=gpu_ids)
    elif which_model_netG == 'unet_256':
        netG = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout,
                             gpu_ids=gpu_ids)
    else:
        raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)
    if len(gpu_ids) > 0:
        netG.cuda(gpu_ids[0])
    init_weights(netG, init_type=init_type)
    return netG


def define_D(input_nc, ndf, which_model_netD, n_layers_D=3, norm='batch', use_sigmoid=False, init_type='normal',
             gpu_ids=None):
    """Build, place and initialize a discriminator ('basic', 'n_layers' or 'pixel')."""
    gpu_ids = [] if gpu_ids is None else gpu_ids
    use_gpu = len(gpu_ids) > 0
    norm_layer = get_norm_layer(norm_type=norm)

    if use_gpu:
        assert (torch.cuda.is_available())
    if which_model_netD == 'basic':
        netD = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid,
                                   gpu_ids=gpu_ids)
    elif which_model_netD == 'n_layers':
        netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid,
                                   gpu_ids=gpu_ids)
    elif which_model_netD == 'pixel':
        netD = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
    else:
        raise NotImplementedError('Discriminator model name [%s] is not recognized' % which_model_netD)
    if use_gpu:
        netD.cuda(gpu_ids[0])
    init_weights(netD, init_type=init_type)
    return netD


def print_network(net):
    """Print the module tree of ``net`` and its total parameter count."""
    num_params = sum(param.numel() for param in net.parameters())
    print(net)
    print('Total number of parameters: %d' % num_params)


##############################################################################
# Classes
##############################################################################


# Defines the GAN loss which uses either LSGAN or the regular GAN.
# When LSGAN is used, it is basically the same as MSELoss,
# but it abstracts away the need to create the target label tensor
# that has the same size as the input.
class GANLoss(nn.Module):
    """GAN objective with cached constant real/fake target tensors.

    Args:
        use_lsgan: if True use ``nn.MSELoss`` (LSGAN); otherwise ``nn.BCELoss``,
            which expects discriminator outputs already in [0, 1].
        target_real_label: value filled into "real" targets.
        target_fake_label: value filled into "fake" targets.
        tensor: tensor constructor used to allocate targets (e.g. a CUDA tensor type).
    """

    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.FloatTensor):
        super(GANLoss, self).__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        self.real_label_var = None
        self.fake_label_var = None
        self.Tensor = tensor
        if use_lsgan:
            self.loss = nn.MSELoss()
        else:
            self.loss = nn.BCELoss()

    def _cached_target(self, cached, input, value):
        """Return ``cached`` when its shape equals ``input``'s, else a fresh constant target."""
        # Compare full shapes rather than element counts: a cached target with the
        # same numel but a different shape would be broadcast by the loss and give
        # a wrong value.
        if cached is not None and cached.size() == input.size():
            return cached
        return Variable(self.Tensor(input.size()).fill_(value), requires_grad=False)

    def get_target_tensor(self, input, target_is_real):
        """Return a constant target shaped like ``input`` (real or fake label), reusing the cache."""
        if target_is_real:
            self.real_label_var = self._cached_target(self.real_label_var, input, self.real_label)
            return self.real_label_var
        self.fake_label_var = self._cached_target(self.fake_label_var, input, self.fake_label)
        return self.fake_label_var

    def __call__(self, input, target_is_real):
        target_tensor = self.get_target_tensor(input, target_is_real)
        return self.loss(input, target_tensor)


# Defines the generator that consists of Resnet blocks between a few
# downsampling/upsampling operations.
# Code and idea originally from Justin Johnson's architecture.
# https://github.com/jcjohnson/fast-neural-style/
class ResnetEncoder(nn.Module):
    """Encoder slice of a ResNet generator: optional 7x7 stem, stride-2 convs, residual blocks.

    Only downsampling stages ``range(start, end)`` are built; stage ``i`` maps
    ``ngf * 2**i`` channels to ``ngf * 2**(i + 1)``. The residual blocks therefore
    run at ``ngf * 2**end`` channels.
    """

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, gpu_ids=None,
                 padding_type='reflect', n_downsampling=2, start=0, end=2, input_layer=True, n_blocks=6):
        assert (n_blocks >= 0)
        super(ResnetEncoder, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf
        self.gpu_ids = [] if gpu_ids is None else gpu_ids
        # Instance norm has no learned shift, so the convs keep their own bias.
        if isinstance(norm_layer, functools.partial):
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        model = []
        if input_layer:
            model = [nn.ReflectionPad2d(3),
                     nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                     norm_layer(ngf),
                     nn.ReLU(True)]

        for i in range(start, end):
            mult = 2 ** i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]

        # Residual blocks operate on the output of the last stage built here
        # (ngf * 2**end channels). Using 2**n_downsampling only matched when
        # end == n_downsampling and crashed for unshared encoders with blocks.
        mult = 2 ** end
        for i in range(n_blocks):
            model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer,
                                  use_dropout=use_dropout, use_bias=use_bias)]

        self.model = nn.Sequential(*model)

    def forward(self, input):
        if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
            return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
        else:
            return self.model(input)


class ResnetDecoder(nn.Module):
    """Decoder slice of a ResNet generator: residual blocks, stride-2 transposed convs, optional output layer.

    Only upsampling stages ``range(start, end)`` are built; stage ``i`` maps
    ``ngf * 2**(n_downsampling - i)`` channels to half as many. The residual
    blocks run at the decoder's input width, ``ngf * 2**(n_downsampling - start)``.
    The output layer (7x7 conv + Tanh) assumes ``ngf`` input channels, i.e.
    ``end == n_downsampling``.
    """

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, gpu_ids=None,
                 padding_type='reflect', n_downsampling=2, start=0, end=2, output_layer=True, n_blocks=6):
        assert (n_blocks >= 0)
        super(ResnetDecoder, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf
        self.gpu_ids = [] if gpu_ids is None else gpu_ids
        if isinstance(norm_layer, functools.partial):
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        model = []
        # Residual blocks see the decoder's input width. Using 2**n_downsampling only
        # matched when start == 0 and crashed for unshared decoders with blocks.
        mult = 2 ** (n_downsampling - start)
        for i in range(n_blocks):
            model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer,
                                  use_dropout=use_dropout, use_bias=use_bias)]

        for i in range(start, end):
            mult = 2 ** (n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, ngf * mult // 2, kernel_size=3, stride=2,
                                         padding=1, output_padding=1, bias=use_bias),
                      norm_layer(ngf * mult // 2),
                      nn.ReLU(True)]

        if output_layer:
            model += [nn.ReflectionPad2d(3)]
            model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
            model += [nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, input):
        if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
            return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
        else:
            return self.model(input)


# Defines the generator that consists of Resnet blocks between a few
# downsampling/upsampling operations.
# Code and idea originally from Justin Johnson's architecture.
# https://github.com/jcjohnson/fast-neural-style/
class ResnetGenerator(nn.Module):
    """Full ResNet generator: 7x7 stem, 2 stride-2 downsamples, ``n_blocks`` residual blocks, 2 upsamples, Tanh."""

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6,
                 gpu_ids=None, padding_type='reflect'):
        assert (n_blocks >= 0)
        super(ResnetGenerator, self).__init__()
        self.input_nc = input_nc
        self.output_nc = output_nc
        self.ngf = ngf
        self.gpu_ids = [] if gpu_ids is None else gpu_ids
        if isinstance(norm_layer, functools.partial):
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                 norm_layer(ngf),
                 nn.ReLU(True)]

        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2 ** i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]

        mult = 2 ** n_downsampling
        for i in range(n_blocks):
            model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer,
                                  use_dropout=use_dropout, use_bias=use_bias)]

        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, ngf * mult // 2, kernel_size=3, stride=2,
                                         padding=1, output_padding=1, bias=use_bias),
                      norm_layer(ngf * mult // 2),
                      nn.ReLU(True)]
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, input):
        if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
            return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
        else:
            return self.model(input)


# Define a resnet block
class ResnetBlock(nn.Module):
    """Residual block computing ``x + F(x)`` with F = [pad, conv3x3, norm, ReLU, (dropout), pad, conv3x3, norm]."""

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    @staticmethod
    def _padding(padding_type):
        """Return (explicit padding layers, conv padding) keeping a 3x3 conv size-preserving.

        Raises:
            NotImplementedError: for an unknown ``padding_type``.
        """
        if padding_type == 'reflect':
            return [nn.ReflectionPad2d(1)], 0
        if padding_type == 'replicate':
            return [nn.ReplicationPad2d(1)], 0
        if padding_type == 'zero':
            return [], 1
        raise NotImplementedError('padding [%s] is not implemented' % padding_type)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        # Layer order is kept identical to the original so state_dict keys match.
        pad_layers, p = self._padding(padding_type)
        conv_block = pad_layers + [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
                                   norm_layer(dim),
                                   nn.ReLU(True)]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        pad_layers, p = self._padding(padding_type)
        conv_block += pad_layers + [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
                                    norm_layer(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        out = x + self.conv_block(x)
        return out


# Defines the Unet generator.
# |num_downs|: number of downsamplings in UNet.
# For example, if |num_downs| == 7, an image of size 128x128 will become of size 1x1
# at the bottleneck.
class UnetGenerator(nn.Module):
    """U-Net generator built recursively from UnetSkipConnectionBlock, innermost first.

    Args:
        num_downs: total number of stride-2 downsamplings; must be >= 5 because the
            fixed outer levels (ngf, ngf*2, ngf*4, ngf*8 and the innermost block)
            already account for five.

    Raises:
        ValueError: if ``num_downs < 5`` (previously a 5-level net was built silently).
    """

    def __init__(self, input_nc, output_nc, num_downs, ngf=64,
                 norm_layer=nn.BatchNorm2d, use_dropout=False, gpu_ids=None):
        super(UnetGenerator, self).__init__()
        if num_downs < 5:
            raise ValueError('UnetGenerator requires num_downs >= 5, got %s' % (num_downs,))
        self.gpu_ids = [] if gpu_ids is None else gpu_ids

        # construct unet structure
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None,
                                             norm_layer=norm_layer, innermost=True)
        for i in range(num_downs - 5):
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block,
                                                 norm_layer=norm_layer, use_dropout=use_dropout)
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block,
                                             norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block,
                                             norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block,
                                             norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block,
                                             outermost=True, norm_layer=norm_layer)

        self.model = unet_block

    def forward(self, input):
        if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
            return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
        else:
            return self.model(input)


# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
#   |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):
    """One U-Net level: downsample, run ``submodule``, upsample, then concat with the input.

    The outermost level returns the upsampled output directly (ending in Tanh);
    every other level returns ``cat([x, f(x)], dim=1)``. The innermost level has no
    submodule, so its transposed conv receives ``inner_nc`` channels instead of the
    ``2 * inner_nc`` produced by a child level's concatenation.
    """

    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        base_norm = norm_layer.func if type(norm_layer) == functools.partial else norm_layer
        use_bias = base_norm == nn.InstanceNorm2d
        in_channels = outer_nc if input_nc is None else input_nc

        # The down conv is constructed before the up conv, as in the original.
        down_conv = nn.Conv2d(in_channels, inner_nc, kernel_size=4,
                              stride=2, padding=1, bias=use_bias)

        if outermost:
            up_conv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                         kernel_size=4, stride=2,
                                         padding=1)
            layers = [down_conv, submodule, nn.ReLU(True), up_conv, nn.Tanh()]
        elif innermost:
            up_conv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                         kernel_size=4, stride=2,
                                         padding=1, bias=use_bias)
            layers = [nn.LeakyReLU(0.2, True), down_conv,
                      nn.ReLU(True), up_conv, norm_layer(outer_nc)]
        else:
            up_conv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                         kernel_size=4, stride=2,
                                         padding=1, bias=use_bias)
            layers = [nn.LeakyReLU(0.2, True), down_conv, norm_layer(inner_nc),
                      submodule,
                      nn.ReLU(True), up_conv, norm_layer(outer_nc)]
            if use_dropout:
                layers.append(nn.Dropout(0.5))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        # Non-outermost levels expose both the skip path and the processed path.
        if not self.outermost:
            return torch.cat([x, self.model(x)], 1)
        return self.model(x)


# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(nn.Module):
    """PatchGAN discriminator: stacked 4x4 convs producing a 1-channel score map.

    Args:
        input_nc: channels of the input image.
        ndf: filters in the first conv; later layers use up to ``8 * ndf``.
        n_layers: number of stride-2 conv layers.
        norm_layer: normalization-layer constructor (class or functools.partial).
        use_sigmoid: append a Sigmoid (needed when the loss is BCE).
        gpu_ids: GPU ids for data-parallel execution; None/empty means no data parallelism.
    """

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, gpu_ids=None):
        super(NLayerDiscriminator, self).__init__()
        self.gpu_ids = [] if gpu_ids is None else gpu_ids
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        kw = 4
        padw = 1
        sequence = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True)
        ]

        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                          kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        # The last two convs use stride 1, so the score map is not downsampled further.
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                      kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]

        if use_sigmoid:
            sequence += [nn.Sigmoid()]

        self.model = nn.Sequential(*sequence)

    def forward(self, input):
        if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor):
            return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
        else:
            return self.model(input)


class PixelDiscriminator(nn.Module):
    """Per-pixel discriminator made of 1x1 convs (no spatial context)."""

    def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d, use_sigmoid=False, gpu_ids=None):
        super(PixelDiscriminator, self).__init__()
        self.gpu_ids = [] if gpu_ids is None else gpu_ids
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        self.net = [
            nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
            norm_layer(ndf * 2),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]

        if use_sigmoid:
            self.net.append(nn.Sigmoid())

        self.net = nn.Sequential(*self.net)

    def forward(self, input):
        if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor):
            return nn.parallel.data_parallel(self.net, input, self.gpu_ids)
        else:
            return self.net(input)


# ================================================
# FILE: drawing_and_style_transfer/models/ost.py
# ================================================
import torch
from collections import OrderedDict
from torch.autograd import Variable
import itertools
import util.util as util
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks


class OSTModel(BaseModel):
    """Translation model with per-domain (unshared) and shared encoder/decoder parts.

    Domain A uses netEnc_a/netDec_a and domain B uses netEnc_b/netDec_b. Both
    route through netEnc_shared/netDec_shared. The B-side and shared networks can
    be initialized from a pretrained autoencoder checkpoint.
    """

    def name(self):
        return 'OSTModel'

    def _compute_kl(self, mu):
        """Return mean(mu ** 2).

        Up to a constant factor this is KL(N(mu, I) || N(0, I)) averaged over elements.
        """
        mu_2 = torch.pow(mu, 2)
        encoding_loss = torch.mean(mu_2)
        return encoding_loss

    def set_encoders_and_decoders(self, opt):
        """Build the A, B and shared encoder/decoder pairs.

        The first ``opt.num_unshared`` downsampling stages are per-domain; the
        remaining ones up to ``opt.n_downsampling`` are shared. The decoders mirror
        that split: the shared decoder runs first, then the per-domain one.
        """
        n_downsampling = opt.n_downsampling
        start_unshared = 0
        num_unshared = opt.num_unshared
        start_shared = num_unshared
        end_shared = n_downsampling
        start_dec_shared = start_unshared
        end_dec_shared = start_unshared + (end_shared - start_shared)
        start_dec_unshared = end_dec_shared
        end_dec_unshared = n_downsampling
        num_res_blocks_unshared = opt.num_res_blocks_unshared
        n_res_blocks_shared = opt.num_res_blocks_shared

        self.netEnc_a, self.netDec_a = networks.define_ED(
            opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout,
            opt.init_type, self.gpu_ids,
            n_blocks_encoder=num_res_blocks_unshared, n_blocks_decoder=num_res_blocks_unshared,
            start=start_unshared, end=num_unshared, n_downsampling=n_downsampling,
            input_layer=True, output_layer=True,
            start_dec=start_dec_unshared, end_dec=end_dec_unshared)

        self.netEnc_b, self.netDec_b = networks.define_ED(
            opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout,
            opt.init_type, self.gpu_ids,
            n_blocks_encoder=num_res_blocks_unshared, n_blocks_decoder=num_res_blocks_unshared,
            start=start_unshared, end=num_unshared, n_downsampling=n_downsampling,
            input_layer=True, output_layer=True,
            start_dec=start_dec_unshared, end_dec=end_dec_unshared)

        self.netEnc_shared, self.netDec_shared = networks.define_ED(
            opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout,
            opt.init_type, self.gpu_ids,
            n_blocks_encoder=n_res_blocks_shared, n_blocks_decoder=n_res_blocks_shared,
            start=start_shared, n_downsampling=n_downsampling, end=end_shared,
            input_layer=False, output_layer=False,
            start_dec=start_dec_shared, end_dec=end_dec_shared)

    def initialize(self, opt):
        """Build networks, optionally load checkpoints, and set up losses/optimizers for training."""
        BaseModel.initialize(self, opt)
        self.set_encoders_and_decoders(opt)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_a = networks.define_D(opt.output_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D,
                                            opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)
            self.netD_b = networks.define_D(opt.output_nc, opt.ndf, opt.which_model_netD, opt.n_layers_D,
                                            opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)

            # Start the B-side and shared networks from the pretrained autoencoder.
            # NOTE(review): nesting under isTrain inferred from the flattened source — confirm.
            if not opt.dont_load_pretrained_autoencoder:
                which_epoch = opt.which_epoch
                self.load_network(self.netEnc_b, 'Enc_b', which_epoch)
                self.load_network(self.netDec_b, 'Dec_b', which_epoch)
                self.load_network(self.netEnc_shared, 'Enc_shared', which_epoch)
                self.load_network(self.netDec_shared, 'Dec_shared', which_epoch)

        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.netEnc_a, 'Enc_a', which_epoch)
            self.load_network(self.netDec_a, 'Dec_a', which_epoch)
            self.load_network(self.netEnc_b, 'Enc_b', which_epoch)
            self.load_network(self.netDec_b, 'Dec_b', which_epoch)
            self.load_network(self.netEnc_shared, 'Enc_shared', which_epoch)
            self.load_network(self.netDec_shared, 'Dec_shared', which_epoch)

            if self.isTrain:
                self.load_network(self.netD_a, 'D_a', which_epoch)
                self.load_network(self.netD_b, 'D_b', which_epoch)

        if self.isTrain:
            # NOTE(review): fake_A_pool and netD_a are created (and netD_a is stepped and
            # saved) but no loss in this class uses them.
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionIdt = torch.nn.L1Loss()

            # initialize optimizers; the B optimizers also own the shared networks
            self.optimizer_Enc_a = torch.optim.Adam(self.netEnc_a.parameters(),
                                                    lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_Dec_a = torch.optim.Adam(self.netDec_a.parameters(),
                                                    lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_Enc_b = torch.optim.Adam(
                itertools.chain(self.netEnc_b.parameters(), self.netEnc_shared.parameters()),
                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_Dec_b = torch.optim.Adam(
                itertools.chain(self.netDec_b.parameters(), self.netDec_shared.parameters()),
                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_a = torch.optim.Adam(self.netD_a.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_b = torch.optim.Adam(self.netD_b.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

            self.optimizers = []
            self.schedulers = []
            self.optimizers.append(self.optimizer_Enc_a)
            self.optimizers.append(self.optimizer_Dec_a)
            self.optimizers.append(self.optimizer_Enc_b)
            self.optimizers.append(self.optimizer_Dec_b)
            self.optimizers.append(self.optimizer_D_a)
            self.optimizers.append(self.optimizer_D_b)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

        print('---------- Networks initialized -------------')
        networks.print_network(self.netEnc_a)
        networks.print_network(self.netDec_a)
        networks.print_network(self.netEnc_b)
        networks.print_network(self.netDec_b)
        networks.print_network(self.netEnc_shared)
        networks.print_network(self.netDec_shared)
        if self.isTrain:
            networks.print_network(self.netD_a)
            networks.print_network(self.netD_b)
        print('-----------------------------------------------')

    def set_input(self, input):
        """Take the batch dict and store domain A/B tensors (swapped when which_direction != 'AtoB')."""
        AtoB = self.opt.which_direction == 'AtoB'
        input_A = input['A' if AtoB else 'B']
        input_B = input['B' if AtoB else 'A']
        if len(self.gpu_ids) > 0:
            # The original passed async=True, which is a SyntaxError on Python 3.7+
            # ('async' is a reserved word); a plain .cuda(device) works on every
            # PyTorch version.
            input_A = input_A.cuda(self.gpu_ids[0])
            input_B = input_B.cuda(self.gpu_ids[0])
        self.input_A = input_A
        self.input_B = input_B
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        self.real_A = Variable(self.input_A)
        self.real_B = Variable(self.input_B)

    def test(self):
        """Run reconstructions (AA, BB), the A->B translation and the A->B->A cycle without gradients."""
        real_A = Variable(self.input_A, volatile=True)
        real_B = Variable(self.input_B, volatile=True)

        enc_a = self.netEnc_shared(self.netEnc_a(real_A))
        enc_b = self.netEnc_shared(self.netEnc_b(real_B))

        fake_AA = self.netDec_a(self.netDec_shared(enc_a))
        fake_AB = self.netDec_b(self.netDec_shared(enc_a))
        fake_BB = self.netDec_b(self.netDec_shared(enc_b))

        enc_ab = self.netEnc_shared(self.netEnc_b(fake_AB))
        fake_ABA = self.netDec_a(self.netDec_shared(enc_ab))

        self.fake_AA = fake_AA.data
        self.fake_AB = fake_AB.data
        self.fake_BB = fake_BB.data
        self.fake_ABA = fake_ABA.data

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def backward_D_basic(self, netD, real, fake):
        """Backprop 0.5 * (GAN(real, True) + GAN(fake, False)) through ``netD`` and return it."""
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake (detached so the generators receive no gradient here)
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        # backward
        loss_D.backward()
        return loss_D

    def backward_D(self):
        """Train netD_b against both the A->B translation and the B reconstruction."""
        fake_AB = self.fake_B_pool.query(self.fake_AB)
        loss_D_ab = self.backward_D_basic(self.netD_b, self.real_B, fake_AB)
        self.loss_D_ab = loss_D_ab.data[0]

        fake_BB = self.fake_B_pool.query(self.fake_BB)
        loss_D_bb = self.backward_D_basic(self.netD_b, self.real_B, fake_BB)
        self.loss_D_bb = loss_D_bb.data[0]

    def backward_G(self):
        """Compute the generator losses; returns (loss_G_A, loss_G_B) without calling backward."""
        enc_a = self.netEnc_shared(self.netEnc_a(self.real_A))
        enc_b = self.netEnc_shared(self.netEnc_b(self.real_B))

        fake_AA = self.netDec_a(self.netDec_shared(enc_a))
        fake_AB = self.netDec_b(self.netDec_shared(enc_a))
        fake_BB = self.netDec_b(self.netDec_shared(enc_b))

        enc_ab = self.netEnc_shared(self.netEnc_b(fake_AB))
        fake_ABA = self.netDec_a(self.netDec_shared(enc_ab))

        # Domain A: adversarial A->B, cycle A->B->A, and reconstruction A->A
        pred_fake_AB = self.netD_b(fake_AB)
        loss_Gan_AB = self.criterionGAN(pred_fake_AB, True)
        loss_idt_A = self.criterionIdt(fake_AA, self.real_A)
        loss_cycle_A = self.opt.lambda_A * self.criterionIdt(fake_ABA, self.real_A)

        # Domain B: reconstruction, adversarial B->B, and the encoding penalty
        loss_idt_B = self.criterionIdt(fake_BB, self.real_B)
        pred_fake_BB = self.netD_b(fake_BB)
        loss_Gan_BB = self.criterionGAN(pred_fake_BB, True)
        loss_kl_B = self.opt.kl_lambda * self._compute_kl(enc_b)

        # combined losses
        loss_G_B = loss_idt_B + loss_kl_B + loss_Gan_BB
        loss_G_A = loss_Gan_AB + loss_cycle_A + loss_idt_A

        self.fake_AA = fake_AA.data
        self.fake_BB = fake_BB.data
        self.fake_AB = fake_AB.data
        self.fake_ABA = fake_ABA.data

        self.loss_cycle_A = loss_cycle_A.data[0]
        self.loss_Gan_AB = loss_Gan_AB.data[0]
        # Previously read from loss_Gan_AB, so 'G_BB' reported the wrong loss.
        self.loss_Gan_BB = loss_Gan_BB.data[0]
        self.loss_idt_A = loss_idt_A.data[0]
        self.loss_idt_B = loss_idt_B.data[0]
        self.loss_kl_B = loss_kl_B.data[0]

        return loss_G_A, loss_G_B

    def optimize_parameters(self):
        """One training step: A-side generators, then B/shared generators, then discriminators."""
        # forward
        self.forward()
        loss_G_A, loss_G_B = self.backward_G()

        # A loss updates only netEnc_a / netDec_a. Gradients it leaves in the
        # B/shared networks and netD_b are cleared by the zero_grad() calls below.
        self.optimizer_Enc_a.zero_grad()
        self.optimizer_Dec_a.zero_grad()
        loss_G_A.backward(retain_graph=True)
        self.optimizer_Enc_a.step()
        self.optimizer_Dec_a.step()

        # B loss updates the B-side and shared networks
        self.optimizer_Enc_b.zero_grad()
        self.optimizer_Dec_b.zero_grad()
        loss_G_B.backward()
        self.optimizer_Enc_b.step()
        self.optimizer_Dec_b.step()

        # D
        self.optimizer_D_a.zero_grad()
        self.optimizer_D_b.zero_grad()
        self.backward_D()
        self.optimizer_D_a.step()
        self.optimizer_D_b.step()

    def get_current_errors(self):
        ret_errors = OrderedDict([('D_ab', self.loss_D_ab), ('D_bb', self.loss_D_bb),
                                  ('G_AB', self.loss_Gan_AB), ('G_BB', self.loss_Gan_BB),
                                  ('Idt_B', self.loss_idt_B), ('Idt_A', self.loss_idt_A),
                                  ('Cycle_A', self.loss_cycle_A), ('Kl_B', self.loss_kl_B), ])
        return ret_errors

    def get_current_visuals(self):
        real_A = util.tensor2im(self.input_A)
        real_B = util.tensor2im(self.input_B)
        fake_BB = util.tensor2im(self.fake_BB)
        fake_AB = util.tensor2im(self.fake_AB)
        fake_AA = util.tensor2im(self.fake_AA)
        fake_ABA = util.tensor2im(self.fake_ABA)
        ret_visuals = OrderedDict([('real_B', real_B), ('fake_BB', fake_BB),
                                   ('real_A', real_A), ('fake_AA', fake_AA),
                                   ('fake_AB', fake_AB), ('fake_ABA', fake_ABA), ])
        return ret_visuals

    def save(self, label):
        """Save every network (including discriminators, so only valid in training)."""
        self.save_network(self.netEnc_a, 'Enc_a', label, self.gpu_ids)
        self.save_network(self.netDec_a, 'Dec_a', label, self.gpu_ids)
        self.save_network(self.netD_a, 'D_a', label, self.gpu_ids)
        self.save_network(self.netEnc_b, 'Enc_b', label, self.gpu_ids)
        self.save_network(self.netDec_b, 'Dec_b', label, self.gpu_ids)
        self.save_network(self.netD_b, 'D_b', label, self.gpu_ids)
        self.save_network(self.netEnc_shared, 'Enc_shared', label, self.gpu_ids)
        self.save_network(self.netDec_shared, 'Dec_shared', label, self.gpu_ids)


# ================================================
# FILE: drawing_and_style_transfer/models/test_model.py
# ================================================
from torch.autograd import Variable
from collections import OrderedDict
import util.util as util
from .base_model import BaseModel
from . import networks


class TestModel(BaseModel):
    """Inference-only wrapper that loads a single generator 'G' and translates domain-A images."""

    def name(self):
        return 'TestModel'

    def initialize(self, opt):
        assert (not opt.isTrain)
        BaseModel.initialize(self, opt)
        self.netG = networks.define_G(opt.input_nc, opt.output_nc,
                                      opt.ngf, opt.which_model_netG,
                                      opt.norm, not opt.no_dropout,
                                      opt.init_type,
                                      self.gpu_ids)
        which_epoch = opt.which_epoch
        self.load_network(self.netG, 'G', which_epoch)

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG)
        print('-----------------------------------------------')

    def set_input(self, input):
        # we need to use single_dataset mode
        input_A = input['A']
        if len(self.gpu_ids) > 0:
            # 'async=True' removed: it is a SyntaxError on Python 3.7+.
            input_A = input_A.cuda(self.gpu_ids[0])
        self.input_A = input_A
        self.image_paths = input['A_paths']

    def test(self):
        self.real_A = Variable(self.input_A, volatile=True)
        self.fake_B = self.netG(self.real_A)

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def get_current_visuals(self):
        real_A = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        return OrderedDict([('real_A', real_A), ('fake_B', fake_B)])
================================================ FILE: drawing_and_style_transfer/options/__init__.py ================================================ ================================================ FILE: drawing_and_style_transfer/options/base_options.py ================================================ import argparse import os from util import util import torch class BaseOptions(): def __init__(self): self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) self.initialized = False def initialize(self): self.parser.add_argument('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)') self.parser.add_argument('--batchSize', type=int, default=1, help='input batch size') self.parser.add_argument('--max_items_A', type=int, default=-1, help='max number of items for domain A, -1 indicates no maximum') self.parser.add_argument('--max_items_B', type=int, default=-1, help='max number of items for domain B, -1 indicates no maximum') self.parser.add_argument('--start', type=int, default=0, help='starting index of items of domain A, after sorting') self.parser.add_argument('--loadSize', type=int, default=286, help='scale images to this size') self.parser.add_argument('--fineSize', type=int, default=256, help='then crop to this size') self.parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels') self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels') self.parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer') self.parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer') self.parser.add_argument('--which_model_netD', type=str, default='basic', help='selects model to use for netD') self.parser.add_argument('--which_model_netG', type=str, default='resnet_9blocks', help='selects model to use for netG and netED') 
self.parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers') self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') self.parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models') self.parser.add_argument('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single]') self.parser.add_argument('--model', type=str, default='ost', help='chooses which model to use. ost, autoencoder, test.') self.parser.add_argument('--which_direction', type=str, default='AtoB', help='AtoB or BtoA') self.parser.add_argument('--nThreads', default=2, type=int, help='# threads for loading data') self.parser.add_argument('--load_dir', type=str, default='./checkpoints', help='models are loaded here') self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') self.parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization') self.parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') self.parser.add_argument('--display_winsize', type=int, default=256, help='display window size') self.parser.add_argument('--display_id', type=int, default=1, help='window id of the web display') self.parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display') self.parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display') self.parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator') self.parser.add_argument('--resize_or_crop', type=str, default='resize_and_crop', help='scaling and cropping of images at load time 
[resize_and_crop|crop|scale_width|scale_width_and_crop]') self.parser.add_argument('--no_flip_and_rotation', action='store_true', help='if specified, do not flip and rotate the images for data augmentation') self.parser.add_argument('--rotation_degree', type=int, default=7, help='rotation degree used for augmentation') self.parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]') self.parser.add_argument('--A', type=str, default='A', help='used to exchange dataset A for B by setting the value to B') self.parser.add_argument('--B', type=str, default='B', help='used to exchange dataset B for A by setting the value to A') self.parser.add_argument('--n_downsampling', type=int, default=2, help="number of downsampling/upsampling convolutional/deconvolutional layers") self.parser.add_argument('--num_unshared', type=int, default=1, help="number of unshared encoder/decoder layers, not including input and final layers") self.parser.add_argument('--num_res_blocks_unshared', type=int, default=0, help='number of unshared resnet blocks') self.parser.add_argument('--num_res_blocks_shared', type=int, default=6, help='number of shared resnet blocks') self.initialized = True def parse(self): if not self.initialized: self.initialize() self.opt = self.parser.parse_args() self.opt.isTrain = self.isTrain # train or test str_ids = self.opt.gpu_ids.split(',') self.opt.gpu_ids = [] for str_id in str_ids: id = int(str_id) if id >= 0: self.opt.gpu_ids.append(id) # set gpu ids if len(self.opt.gpu_ids) > 0: torch.cuda.set_device(self.opt.gpu_ids[0]) args = vars(self.opt) print('------------ Options -------------') for k, v in sorted(args.items()): print('%s: %s' % (str(k), str(v))) print('-------------- End ----------------') # save to the disk expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name) util.mkdirs(expr_dir) file_name = os.path.join(expr_dir, 'opt.txt') with open(file_name, 'wt') as opt_file: 
opt_file.write('------------ Options -------------\n') for k, v in sorted(args.items()): opt_file.write('%s: %s\n' % (str(k), str(v))) opt_file.write('-------------- End ----------------\n') return self.opt ================================================ FILE: drawing_and_style_transfer/options/test_options.py ================================================ from .base_options import BaseOptions class TestOptions(BaseOptions): def initialize(self): BaseOptions.initialize(self) self.parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.') self.parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.') self.parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images') self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') self.parser.add_argument('--how_many', type=int, default=50, help='how many test images to run') self.isTrain = False ================================================ FILE: drawing_and_style_transfer/options/train_options.py ================================================ from .base_options import BaseOptions class TrainOptions(BaseOptions): def initialize(self): BaseOptions.initialize(self) self.parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen') self.parser.add_argument('--display_single_pane_ncols', type=int, default=0, help='if positive, display all images in a single visdom web panel with certain number of images per row.') self.parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html') self.parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') 
self.parser.add_argument('--save_latest_freq', type=int, default=10000, help='frequency of saving the latest results') self.parser.add_argument('--save_epoch_freq', type=int, default=10, help='frequency of saving checkpoints at the end of epochs') self.parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') self.parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...') self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') self.parser.add_argument('--niter', type=int, default=60, help='# of iter at starting learning rate') self.parser.add_argument('--niter_decay', type=int, default=20, help='# of iter to linearly decay learning rate to zero') self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam') self.parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam') self.parser.add_argument('--no_lsgan', action='store_true', help='do *not* use least square GAN, if false, use vanilla GAN') self.parser.add_argument('--kl_lambda', type=float, default=0.1, help='weight for kl loss') self.parser.add_argument('--lambda_A', type=float, default=10.0, help='weight for cycle loss (A -> B -> A)') self.parser.add_argument('--lambda_B', type=float, default=10.0, help='weight for cycle loss (B -> A -> B)') self.parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images') self.parser.add_argument('--dont_load_pretrained_autoencoder', action='store_true', help='do not load pretrained autoencoder') self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') 
self.parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau') self.parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations') self.isTrain = True ================================================ FILE: drawing_and_style_transfer/scripts/test_ost.sh ================================================ # images to cityscapes python test.py --dataroot=./datasets/cityscapes/ --name=cityscapes_ost --load_dir=cityscapes_ost --model=ost --no_dropout --n_downsampling=3 --num_unshared=3 --start=0 --max_items_A=1 # cityscapes to images python test.py --dataroot=./datasets/cityscapes/ --name=cityscapes_ost_reverse --load_dir=cityscapes_ost_reverse --model=ost --no_dropout --n_downsampling=2 --num_unshared=2 --start=0 --max_items_A=1 --A='B' --B='A' # images to facades python test.py --dataroot=./datasets/facades/ --name=facades_ost --load_dir=facades_ost --model=ost --no_dropout --n_downsampling=2 --num_unshared=2 --start=0 --max_items_A=1 # facades to images python test.py --dataroot=./datasets/facades/ --name=facades_ost_reverse --load_dir=facades_ost_reverse --model=ost --no_dropout --n_downsampling=2 --num_unshared=2 --start=0 --max_items_A=1 --A='B' --B='A' # aerial view to maps python test.py --dataroot=./datasets/maps/ --name=maps_ost --load_dir=maps_ost --model=ost --no_dropout --n_downsampling=2 --num_unshared=2 --start=0 --max_items_A=1 # maps to aerial view python test.py --dataroot=./datasets/maps/ --name=maps_ost_reverse --load_dir=maps_ost_reverse --model=ost --no_dropout --n_downsampling=2 --num_unshared=2 --start=0 --max_items_A=1 --A='B' --B='A' # monet2photo python test.py --dataroot=./datasets/monet2photo/ --name=monet2photo_ost --load_dir=monet2photo_ost --model=ost --no_dropout --n_downsampling=2 --num_unshared=0 --start=0 --max_items_A=1 # photo2monet python test.py --dataroot=./datasets/monet2photo/ --name=monet2photo_ost_reverse 
--load_dir=monet2photo_ost_reverse --model=ost --no_dropout --n_downsampling=2 --num_unshared=0 --start=0 --max_items_A=1 --A='B' --B='A' # summer2winter python test.py --dataroot=./datasets/summer2winter_yosemite/ --name=summer2winter_yosemite_ost --load_dir=summer2winter_yosemite_ost --model=ost --no_dropout --n_downsampling=2 --num_unshared=0 --start=0 --max_items_A=1 # winter2summer python test.py --dataroot=./datasets/summer2winter_yosemite/ --name=summer2winter_yosemite_ost_reverse --load_dir=summer2winter_yosemite_ost_reverse --model=ost --no_dropout --n_downsampling=2 --num_unshared=0 --start=0 --max_items_A=1 --A='B' --B='A' ================================================ FILE: drawing_and_style_transfer/scripts/train_autoencoder.sh ================================================ # images to cityscapes python train.py --dataroot=./datasets/cityscapes/trainB --name=cityscapes_autoencoder --model=autoencoder --dataset_mode=single --no_dropout --n_downsampling=3 --num_unshared=3 # cityscapes to images python train.py --dataroot=./datasets/cityscapes/trainA --name=cityscapes_autoencoder_reverse --model=autoencoder --dataset_mode=single --no_dropout --n_downsampling=2 --num_unshared=2 # images to facades python train.py --dataroot=./datasets/facades/trainB --name=facades_autoencoder --model=autoencoder --dataset_mode=single --no_dropout --n_downsampling=2 --num_unshared=2 # facades to images python train.py --dataroot=./datasets/facades/trainA --name=facades_autoencoder_reverse --model=autoencoder --dataset_mode=single --no_dropout --n_downsampling=2 --num_unshared=2 # aerial view to maps python train.py --dataroot=./datasets/maps/trainB --name=maps_autoencoder --model=autoencoder --dataset_mode=single --no_dropout --n_downsampling=2 --num_unshared=2 # maps to aerial view python train.py --dataroot=./datasets/maps/trainA --name=maps_autoencoder_reverse --model=autoencoder --dataset_mode=single --no_dropout --n_downsampling=2 --num_unshared=2 # monet2photo 
python train.py --dataroot=./datasets/monet2photo/trainB --name=monet2photo_autoencoder --model=autoencoder --dataset_mode=single --no_dropout --n_downsampling=2 --num_unshared=0 # photo2monet python train.py --dataroot=./datasets/monet2photo/trainA --name=monet2photo_autoencoder_reverse --model=autoencoder --dataset_mode=single --no_dropout --n_downsampling=2 --num_unshared=0 # summer2winter python train.py --dataroot=./datasets/summer2winter_yosemite/trainB --name=summer2winter_yosemite_autoencoder --model=autoencoder --dataset_mode=single --no_dropout --n_downsampling=2 --num_unshared=0 # winter2summer python train.py --dataroot=./datasets/summer2winter_yosemite/trainA --name=summer2winter_yosemite_autoencoder_reverse --model=autoencoder --dataset_mode=single --no_dropout --n_downsampling=2 --num_unshared=0 ================================================ FILE: drawing_and_style_transfer/scripts/train_ost.sh ================================================ # images to cityscapes python train.py --dataroot=./datasets/cityscapes/ --name=cityscapes_ost --load_dir=cityscapes_autoencoder --model=ost --no_dropout --n_downsampling=3 --num_unshared=3 --start=0 --max_items_A=1 # cityscapes to images python train.py --dataroot=./datasets/cityscapes/ --name=cityscapes_ost_reverse --load_dir=cityscapes_autoencoder_reverse --model=ost --no_dropout --n_downsampling=2 --num_unshared=2 --start=0 --max_items_A=1 --A='B' --B='A' # images to facades python train.py --dataroot=./datasets/facades/ --name=facades_ost --load_dir=facades_autoencoder --model=ost --no_dropout --n_downsampling=2 --num_unshared=2 --start=0 --max_items_A=1 # facades to images python train.py --dataroot=./datasets/facades/ --name=facades_ost_reverse --load_dir=facades_autoencoder_reverse --model=ost --no_dropout --n_downsampling=2 --num_unshared=2 --start=0 --max_items_A=1 --A='B' --B='A' # aerial view to maps python train.py --dataroot=./datasets/maps/ --name=maps_ost --load_dir=maps_autoencoder --model=ost 
--no_dropout --n_downsampling=2 --num_unshared=2 --start=0 --max_items_A=1 # maps to aerial view python train.py --dataroot=./datasets/maps/ --name=maps_ost_reverse --load_dir=maps_autoencoder_reverse --model=ost --no_dropout --n_downsampling=2 --num_unshared=2 --start=0 --max_items_A=1 --A='B' --B='A' # monet2photo python train.py --dataroot=./datasets/monet2photo/ --name=monet2photo_ost --load_dir=monet2photo_autoencoder --model=ost --no_dropout --n_downsampling=2 --num_unshared=0 --start=0 --max_items_A=1 # photo2monet python train.py --dataroot=./datasets/monet2photo/ --name=monet2photo_ost_reverse --load_dir=monet2photo_autoencoder_reverse --model=ost --no_dropout --n_downsampling=2 --num_unshared=0 --start=0 --max_items_A=1 --A='B' --B='A' # summer2winter python train.py --dataroot=./datasets/summer2winter_yosemite/ --name=summer2winter_yosemite_ost --load_dir=summer2winter_yosemite_autoencoder --model=ost --no_dropout --n_downsampling=2 --num_unshared=0 --start=0 --max_items_A=1 # winter2summer python train.py --dataroot=./datasets/summer2winter_yosemite/ --name=summer2winter_yosemite_ost_reverse --load_dir=summer2winter_yosemite_autoencoder_reverse --model=ost --no_dropout --n_downsampling=2 --num_unshared=0 --start=0 --max_items_A=1 --A='B' --B='A' ================================================ FILE: drawing_and_style_transfer/test.py ================================================ import os from options.test_options import TestOptions from data import CreateDataLoader from models import create_model from util.visualizer import Visualizer from util import html if __name__ == '__main__': opt = TestOptions().parse() opt.nThreads = 1 # test code only supports nThreads = 1 opt.batchSize = 1 # test code only supports batchSize = 1 opt.serial_batches = True # no shuffle opt.no_flip = True # no flip # We are interested on testing only on x's in A on which we trained opt.phase = 'train' if opt.max_items_A >= 0: opt.max_items_B = opt.max_items_A if 
opt.max_items_B >= 0: opt.max_items_A = opt.max_items_B data_loader = CreateDataLoader(opt) dataset = data_loader.load_data() model = create_model(opt) visualizer = Visualizer(opt) # create website web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch)) webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch)) # test for i, data in enumerate(dataset): if i >= opt.how_many: break model.set_input(data) model.test() visuals = model.get_current_visuals() img_path = model.get_image_paths() print('%04d: process image... %s' % (i, img_path)) visualizer.save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, index=i, split=0) webpage.save() ================================================ FILE: drawing_and_style_transfer/train.py ================================================ import time from options.train_options import TrainOptions from data import CreateDataLoader from models import create_model from util.visualizer import Visualizer if __name__ == '__main__': opt = TrainOptions().parse() data_loader = CreateDataLoader(opt) dataset = data_loader.load_data() dataset_size = len(data_loader) print('#training images = %d' % dataset_size) model = create_model(opt) visualizer = Visualizer(opt) total_steps = 0 for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1): epoch_start_time = time.time() iter_data_time = time.time() epoch_iter = 0 for i, data in enumerate(dataset): iter_start_time = time.time() if total_steps % opt.print_freq == 0: t_data = iter_start_time - iter_data_time visualizer.reset() total_steps += opt.batchSize epoch_iter += opt.batchSize model.set_input(data) model.optimize_parameters() if total_steps % opt.display_freq == 0: save_result = total_steps % opt.update_html_freq == 0 visualizer.display_current_results(model.get_current_visuals(), epoch, save_result) if total_steps % opt.print_freq == 0: errors = model.get_current_errors() t = 
(time.time() - iter_start_time) / opt.batchSize visualizer.print_current_errors(epoch, epoch_iter, errors, t, t_data) if opt.display_id > 0: visualizer.plot_current_errors(epoch, float(epoch_iter) / dataset_size, opt, errors) if total_steps % opt.save_latest_freq == 0: print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps)) model.save('latest') iter_data_time = time.time() if epoch % opt.save_epoch_freq == 0: print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps)) model.save('latest') model.save(epoch) print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time)) model.update_learning_rate() ================================================ FILE: drawing_and_style_transfer/util/__init__.py ================================================ ================================================ FILE: drawing_and_style_transfer/util/get_data.py ================================================ from __future__ import print_function import os import tarfile import requests from warnings import warn from zipfile import ZipFile from bs4 import BeautifulSoup from os.path import abspath, isdir, join, basename class GetData(object): """ Download CycleGAN or Pix2Pix Data. Args: technique : str One of: 'cyclegan' or 'pix2pix'. verbose : bool If True, print additional information. Examples: >>> from util.get_data import GetData >>> gd = GetData(technique='cyclegan') >>> new_data_path = gd.get(save_path='./datasets') # options will be displayed. 
""" def __init__(self, technique='cyclegan', verbose=True): url_dict = { 'pix2pix': 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets', 'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets' } self.url = url_dict.get(technique.lower()) self._verbose = verbose def _print(self, text): if self._verbose: print(text) @staticmethod def _get_options(r): soup = BeautifulSoup(r.text, 'lxml') options = [h.text for h in soup.find_all('a', href=True) if h.text.endswith(('.zip', 'tar.gz'))] return options def _present_options(self): r = requests.get(self.url) options = self._get_options(r) print('Options:\n') for i, o in enumerate(options): print("{0}: {1}".format(i, o)) choice = input("\nPlease enter the number of the " "dataset above you wish to download:") return options[int(choice)] def _download_data(self, dataset_url, save_path): if not isdir(save_path): os.makedirs(save_path) base = basename(dataset_url) temp_save_path = join(save_path, base) with open(temp_save_path, "wb") as f: r = requests.get(dataset_url) f.write(r.content) if base.endswith('.tar.gz'): obj = tarfile.open(temp_save_path) elif base.endswith('.zip'): obj = ZipFile(temp_save_path, 'r') else: raise ValueError("Unknown File Type: {0}.".format(base)) self._print("Unpacking Data...") obj.extractall(save_path) obj.close() os.remove(temp_save_path) def get(self, save_path, dataset=None): """ Download a dataset. Args: save_path : str A directory to save the data to. dataset : str, optional A specific dataset to download. Note: this must include the file extension. If None, options will be presented for you to choose from. Returns: save_path_full : str The absolute path to the downloaded data. """ if dataset is None: selected_dataset = self._present_options() else: selected_dataset = dataset save_path_full = join(save_path, selected_dataset.split('.')[0]) if isdir(save_path_full): warn("\n'{0}' already exists. 
Voiding Download.".format( save_path_full)) else: self._print('Downloading Data...') url = "{0}/{1}".format(self.url, selected_dataset) self._download_data(url, save_path=save_path) return abspath(save_path_full) ================================================ FILE: drawing_and_style_transfer/util/html.py ================================================ import dominate from dominate.tags import * import os class HTML: def __init__(self, web_dir, title, reflesh=0): self.title = title self.web_dir = web_dir self.img_dir = os.path.join(self.web_dir, 'images') if not os.path.exists(self.web_dir): os.makedirs(self.web_dir) if not os.path.exists(self.img_dir): os.makedirs(self.img_dir) # print(self.img_dir) self.doc = dominate.document(title=title) if reflesh > 0: with self.doc.head: meta(http_equiv="reflesh", content=str(reflesh)) def get_image_dir(self): return self.img_dir def add_header(self, str): with self.doc: h3(str) def add_table(self, border=1): self.t = table(border=border, style="table-layout: fixed;") self.doc.add(self.t) def add_images(self, ims, txts, links, width=400): self.add_table() with self.t: with tr(): for im, txt, link in zip(ims, txts, links): with td(style="word-wrap: break-word;", halign="center", valign="top"): with p(): with a(href=os.path.join('images', link)): img(style="width:%dpx" % width, src=os.path.join('images', im)) br() p(txt) def save(self): html_file = '%s/index.html' % self.web_dir f = open(html_file, 'wt') f.write(self.doc.render()) f.close() if __name__ == '__main__': html = HTML('web/', 'test_html') html.add_header('hello world') ims = [] txts = [] links = [] for n in range(4): ims.append('image_%d.png' % n) txts.append('text_%d' % n) links.append('image_%d.png' % n) html.add_images(ims, txts, links) html.save() ================================================ FILE: drawing_and_style_transfer/util/image_pool.py ================================================ import random import torch from torch.autograd import Variable 
class ImagePool(): def __init__(self, pool_size): self.pool_size = pool_size if self.pool_size > 0: self.num_imgs = 0 self.images = [] def query(self, images): if self.pool_size == 0: return Variable(images) return_images = [] for image in images: image = torch.unsqueeze(image, 0) if self.num_imgs < self.pool_size: self.num_imgs = self.num_imgs + 1 self.images.append(image) return_images.append(image) else: p = random.uniform(0, 1) if p > 0.5: random_id = random.randint(0, self.pool_size - 1) tmp = self.images[random_id].clone() self.images[random_id] = image return_images.append(tmp) else: return_images.append(image) return_images = Variable(torch.cat(return_images, 0)) return return_images ================================================ FILE: drawing_and_style_transfer/util/util.py ================================================ from __future__ import print_function import torch import numpy as np from PIL import Image import os # Converts a Tensor into a Numpy array # |imtype|: the desired type of the converted numpy array def tensor2im(image_tensor, imtype=np.uint8): image_numpy = image_tensor[0].cpu().float().numpy() if image_numpy.shape[0] == 1: image_numpy = np.tile(image_numpy, (3, 1, 1)) image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 return image_numpy.astype(imtype) def diagnose_network(net, name='network'): mean = 0.0 count = 0 for param in net.parameters(): if param.grad is not None: mean += torch.mean(torch.abs(param.grad.data)) count += 1 if count > 0: mean = mean / count print(name) print(mean) def save_image(image_numpy, image_path): image_pil = Image.fromarray(image_numpy) image_pil.save(image_path) def print_numpy(x, val=True, shp=False): x = x.astype(np.float64) if shp: print('shape,', x.shape) if val: x = x.flatten() print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) def mkdirs(paths): if isinstance(paths, list) and not 
isinstance(paths, str): for path in paths: mkdir(path) else: mkdir(paths) def mkdir(path): if not os.path.exists(path): os.makedirs(path) ================================================ FILE: drawing_and_style_transfer/util/visualizer.py ================================================ import numpy as np import os import ntpath import time from . import util from . import html from scipy.misc import imresize class Visualizer(): def __init__(self, opt): # self.opt = opt self.display_id = opt.display_id self.use_html = opt.isTrain and not opt.no_html self.win_size = opt.display_winsize self.name = opt.name self.opt = opt self.saved = False if self.display_id > 0: import visdom self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port) if self.use_html: self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') self.img_dir = os.path.join(self.web_dir, 'images') print('create web directory %s...' % self.web_dir) util.mkdirs([self.web_dir, self.img_dir]) self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') with open(self.log_name, "a") as log_file: now = time.strftime("%c") log_file.write('================ Training Loss (%s) ================\n' % now) def reset(self): self.saved = False # |visuals|: dictionary of images to display or save def display_current_results(self, visuals, epoch, save_result): if self.display_id > 0: # show images in the browser ncols = self.opt.display_single_pane_ncols if ncols > 0: h, w = next(iter(visuals.values())).shape[:2] table_css = """""" % (w, h) title = self.name label_html = '' label_html_row = '' nrows = int(np.ceil(len(visuals.items()) / ncols)) images = [] idx = 0 for label, image_numpy in visuals.items(): label_html_row += '%s' % label images.append(image_numpy.transpose([2, 0, 1])) idx += 1 if idx % ncols == 0: label_html += '%s' % label_html_row label_html_row = '' white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255 while idx % ncols != 0: 
images.append(white_image) label_html_row += '' idx += 1 if label_html_row != '': label_html += '%s' % label_html_row # pane col = image row self.vis.images(images, nrow=ncols, win=self.display_id + 1, padding=2, opts=dict(title=title + ' images')) label_html = '%s
' % label_html self.vis.text(table_css + label_html, win=self.display_id + 2, opts=dict(title=title + ' labels')) else: idx = 1 for label, image_numpy in visuals.items(): self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label), win=self.display_id + idx) idx += 1 if self.use_html and (save_result or not self.saved): # save images to a html file self.saved = True for label, image_numpy in visuals.items(): img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label)) util.save_image(image_numpy, img_path) # update website webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, reflesh=1) for n in range(epoch, 0, -1): webpage.add_header('epoch [%d]' % n) ims = [] txts = [] links = [] for label, image_numpy in visuals.items(): img_path = 'epoch%.3d_%s.png' % (n, label) ims.append(img_path) txts.append(label) links.append(img_path) webpage.add_images(ims, txts, links, width=self.win_size) webpage.save() # errors: dictionary of error labels and values def plot_current_errors(self, epoch, counter_ratio, opt, errors): if not hasattr(self, 'plot_data'): self.plot_data = {'X': [], 'Y': [], 'legend': list(errors.keys())} self.plot_data['X'].append(epoch + counter_ratio) self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']]) self.vis.line( X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1), Y=np.array(self.plot_data['Y']), opts={ 'title': self.name + ' loss over time', 'legend': self.plot_data['legend'], 'xlabel': 'epoch', 'ylabel': 'loss'}, win=self.display_id) # errors: same format as |errors| of plotCurrentErrors def print_current_errors(self, epoch, i, errors, t, t_data): message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, i, t, t_data) for k, v in errors.items(): message += '%s: %.3f ' % (k, v) print(message) with open(self.log_name, "a") as log_file: log_file.write('%s\n' % message) # save image to the disk def save_images(self, webpage, visuals, image_path, 
aspect_ratio=1.0, index=None, split=1): image_dir = webpage.get_image_dir() short_path = ntpath.basename(image_path[0]) name = os.path.splitext(short_path)[0] if index is not None: name_splits = name.split("_") if split == 0: name = str(index) else: name = str(index) + "_" + name_splits[split] webpage.add_header(name) ims = [] txts = [] links = [] for label, im in visuals.items(): image_name = '%s_%s.png' % (name, label) save_path = os.path.join(image_dir, image_name) h, w, _ = im.shape if aspect_ratio > 1.0: im = imresize(im, (h, int(w * aspect_ratio)), interp='bicubic') if aspect_ratio < 1.0: im = imresize(im, (int(h / aspect_ratio), w), interp='bicubic') util.save_image(im, save_path) ims.append(image_name) txts.append(label) links.append(image_name) webpage.add_images(ims, txts, links, width=self.win_size) ================================================ FILE: mnist_to_svhn/data_loader.py ================================================ import torch from torchvision import datasets from torchvision import transforms def get_loader(config): """Builds and returns Dataloader for MNIST and SVHN dataset.""" transform_list = [] if config.use_augmentation: transform_list.append(transforms.RandomHorizontalFlip()) transform_list.append(transforms.RandomRotation(0.1)) transform_list.append(transforms.Scale(config.image_size)) transform_list.append(transforms.ToTensor()) transform_list.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))) transform_test = transforms.Compose([ transforms.Scale(config.image_size), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) transform_train = transforms.Compose(transform_list) svhn = datasets.SVHN(root=config.svhn_path, download=True, transform=transform_train, split='train') mnist = datasets.MNIST(root=config.mnist_path, download=True, transform=transform_train, train=True) svhn_test = datasets.SVHN(root=config.svhn_path, download=True, transform=transform_test, split='test') mnist_test = 
datasets.MNIST(root=config.mnist_path, download=True, transform=transform_test, train=False) svhn_loader = torch.utils.data.DataLoader(dataset=svhn, batch_size=config.batch_size, shuffle=config.shuffle, num_workers=config.num_workers) mnist_loader = torch.utils.data.DataLoader(dataset=mnist, batch_size=config.batch_size, shuffle=config.shuffle, num_workers=config.num_workers) svhn_test_loader = torch.utils.data.DataLoader(dataset=svhn_test, batch_size=config.batch_size, shuffle=False, num_workers=config.num_workers) mnist_test_loader = torch.utils.data.DataLoader(dataset=mnist_test, batch_size=config.batch_size, shuffle=False, num_workers=config.num_workers) return svhn_loader, mnist_loader, svhn_test_loader, mnist_test_loader ================================================ FILE: mnist_to_svhn/download.sh ================================================ mkdir -p mnist mkdir -p svhn wget -O svhn/train_32x32.mat http://ufldl.stanford.edu/housenumbers/train_32x32.mat wget -O svhn/test_32x32.mat http://ufldl.stanford.edu/housenumbers/test_32x32.mat wget -O svhn/extra_32x32.mat http://ufldl.stanford.edu/housenumbers/extra_32x32.mat ================================================ FILE: mnist_to_svhn/main_autoencoder.py ================================================ import argparse import os from torch.backends import cudnn from solver_autoencoder import Solver from data_loader import get_loader def str2bool(v): return v.lower() in ('true') def main(config): svhn_loader, mnist_loader, svhn_test_loader, mnist_test_loader = get_loader(config) solver = Solver(config, svhn_loader, mnist_loader) cudnn.benchmark = True # create directories if not exist if not os.path.exists(config.model_path): os.makedirs(config.model_path) if not os.path.exists(config.sample_path): os.makedirs(config.sample_path) if config.mode == 'train': solver.train() if __name__ == '__main__': parser = argparse.ArgumentParser() # model hyper-parameters parser.add_argument('--image_size', type=int, 
default=32) parser.add_argument('--g_conv_dim', type=int, default=64) parser.add_argument('--d_conv_dim', type=int, default=64) parser.add_argument('--num_classes', type=int, default=10) # training hyper-parameters parser.add_argument('--train_iters', type=int, default=15000) parser.add_argument('--batch_size', type=int, default=64) parser.add_argument('--num_workers', type=int, default=2) parser.add_argument('--lr', type=float, default=0.0002) parser.add_argument('--beta1', type=float, default=0.5) parser.add_argument('--beta2', type=float, default=0.999) parser.add_argument('--kl_lambda', type=float, default=0.1) # misc parser.add_argument('--mode', type=str, default='train') parser.add_argument('--model_path', type=str, default='./models_autoencoder') parser.add_argument('--sample_path', type=str, default='./samples_autoencoder') parser.add_argument('--mnist_path', type=str, default='./mnist') parser.add_argument('--svhn_path', type=str, default='./svhn') parser.add_argument('--log_step', type=int, default=10) parser.add_argument('--sample_step', type=int, default=500) parser.add_argument('--shuffle', type=bool, default=True) parser.add_argument('--use_augmentation', required=True, type=str2bool) config = parser.parse_args() print(config) main(config) ================================================ FILE: mnist_to_svhn/main_mnist_to_svhn.py ================================================ import argparse import logging import os from torch.backends import cudnn from data_loader import get_loader from solver_mnist_to_svhn import Solver def str2bool(v): return v.lower() in ('true') def main(config): svhn_loader, mnist_loader, svhn_test_loader, mnist_test_loader = get_loader(config) solver = Solver(config, svhn_loader, mnist_loader) cudnn.benchmark = True # create directories if not exist if not os.path.exists(config.model_path): os.makedirs(config.model_path) if not os.path.exists(config.sample_path): os.makedirs(config.sample_path) base = config.log_path filename 
= os.path.join(base, str(config.max_items)) if not os.path.isdir(base): os.mkdir(base) logging.basicConfig(filename=filename, level=logging.DEBUG) if config.mode == 'train': solver.train() elif config.mode == 'sample': solver.sample() if __name__ == '__main__': parser = argparse.ArgumentParser() # model hyper-parameters parser.add_argument('--image_size', type=int, default=32) parser.add_argument('--g_conv_dim', type=int, default=64) parser.add_argument('--d_conv_dim', type=int, default=64) parser.add_argument('--num_classes', type=int, default=10) # training hyper-parameters parser.add_argument('--train_iters', type=int, default=40000) parser.add_argument('--batch_size', type=int, default=1) parser.add_argument('--num_workers', type=int, default=2) parser.add_argument('--lr', type=float, default=0.0002) parser.add_argument('--beta1', type=float, default=0.5) parser.add_argument('--beta2', type=float, default=0.999) parser.add_argument('--kl_lambda', type=float, default=0.1) # misc parser.add_argument('--mode', type=str, default='train') parser.add_argument('--mnist_path', type=str, default='./mnist') parser.add_argument('--svhn_path', type=str, default='./svhn') parser.add_argument('--log_step', type=int, default=10) parser.add_argument('--shuffle', type=bool, default=True) parser.add_argument('--load_iter', type=int, default=10000) parser.add_argument('--sample_step', type=int, default=500) parser.add_argument('--num_averaging_runs', type=int, default=1000) parser.add_argument('--num_iters_save_model_and_return', type=int, default=5000) parser.add_argument('--num_d_iterations', type=int, default=1) parser.add_argument('--num_g_iterations', type=int, default=1) parser.add_argument('--model_path', type=str, default='./models_ost') parser.add_argument('--sample_path', type=str, default='./samples_ost') parser.add_argument('--load_path', type=str, default='./models_autoencoder') parser.add_argument('--log_path', type=str, default='logs_ost') 
parser.add_argument('--pretrained_g', required=True, type=str2bool) parser.add_argument('--save_models_and_samples', required=True, type=str2bool) parser.add_argument('--use_augmentation', required=True, type=str2bool) parser.add_argument('--one_way_cycle', required=True, type=str2bool) parser.add_argument('--freeze_shared', required=True, type=str2bool) parser.add_argument('--max_items', type=int, default=1) config = parser.parse_args() print(config) main(config) ================================================ FILE: mnist_to_svhn/main_svhn_to_mnist.py ================================================ import argparse import logging import os from data_loader import get_loader from solver_svhn_to_mnist import Solver from torch.backends import cudnn def str2bool(v): return v.lower() in ('true') def main(config): svhn_loader, mnist_loader, svhn_test_loader, mnist_test_loader = get_loader(config) solver = Solver(config, svhn_loader, mnist_loader) cudnn.benchmark = True # create directories if not exist if not os.path.exists(config.model_path): os.makedirs(config.model_path) if not os.path.exists(config.sample_path): os.makedirs(config.sample_path) base = config.log_path filename = os.path.join(base, str(config.max_items)) if not os.path.isdir(base): os.mkdir(base) logging.basicConfig(filename=filename, level=logging.DEBUG) if config.mode == 'train': solver.train() elif config.mode == 'sample': solver.sample() if __name__ == '__main__': parser = argparse.ArgumentParser() # model hyper-parameters parser.add_argument('--image_size', type=int, default=32) parser.add_argument('--g_conv_dim', type=int, default=64) parser.add_argument('--d_conv_dim', type=int, default=64) parser.add_argument('--num_classes', type=int, default=10) # training hyper-parameters parser.add_argument('--train_iters', type=int, default=40000) parser.add_argument('--batch_size', type=int, default=1) parser.add_argument('--num_workers', type=int, default=2) parser.add_argument('--lr', type=float, 
default=0.0002) parser.add_argument('--beta1', type=float, default=0.5) parser.add_argument('--beta2', type=float, default=0.999) parser.add_argument('--kl_lambda', type=float, default=0.1) # misc parser.add_argument('--mode', type=str, default='train') parser.add_argument('--mnist_path', type=str, default='./mnist') parser.add_argument('--svhn_path', type=str, default='./svhn') parser.add_argument('--log_step', type=int, default=10) parser.add_argument('--shuffle', type=bool, default=True) parser.add_argument('--load_iter', type=int, default=10000) parser.add_argument('--sample_step', type=int, default=500) parser.add_argument('--num_averaging_runs', type=int, default=1000) parser.add_argument('--num_iters_save_model_and_return', type=int, default=5000) parser.add_argument('--num_d_iterations', type=int, default=1) parser.add_argument('--num_g_iterations', type=int, default=1) parser.add_argument('--model_path', type=str, default='./models_ost') parser.add_argument('--sample_path', type=str, default='./samples_ost') parser.add_argument('--load_path', type=str, default='./models_autoencoder') parser.add_argument('--log_path', type=str, default='logs_ost') parser.add_argument('--pretrained_g', required=True, type=str2bool) parser.add_argument('--save_models_and_samples', required=True, type=str2bool) parser.add_argument('--use_augmentation', required=True, type=str2bool) parser.add_argument('--one_way_cycle', required=True, type=str2bool) parser.add_argument('--freeze_shared', required=True, type=str2bool) parser.add_argument('--max_items', type=int, default=1) config = parser.parse_args() print(config) main(config) ================================================ FILE: mnist_to_svhn/model.py ================================================ import torch.nn as nn import torch.nn.functional as F def deconv(c_in, c_out, k_size, stride=2, pad=1, bn=True): """Custom deconvolutional layer for simplicity.""" layers = [] layers.append(nn.ConvTranspose2d(c_in, c_out, k_size, 
stride, pad, bias=False)) if bn: layers.append(nn.BatchNorm2d(c_out)) return nn.Sequential(*layers) def conv(c_in, c_out, k_size, stride=2, pad=1, bn=True): """Custom convolutional layer for simplicity.""" layers = [] layers.append(nn.Conv2d(c_in, c_out, k_size, stride, pad, bias=False)) if bn: layers.append(nn.BatchNorm2d(c_out)) return nn.Sequential(*layers) class G11(nn.Module): def __init__(self, conv_dim=64): super(G11, self).__init__() # encoding blocks self.conv1 = conv(1, conv_dim, 4) self.conv1_svhn = conv(3, conv_dim, 4) self.conv2 = conv(conv_dim, conv_dim * 2, 4) # residual blocks res_dim = conv_dim * 2 self.conv3 = conv(res_dim, res_dim, 3, 1, 1) self.conv4 = conv(res_dim, res_dim, 3, 1, 1) # decoding blocks self.deconv1 = deconv(conv_dim * 2, conv_dim, 4) self.deconv2 = deconv(conv_dim, 1, 4, bn=False) self.deconv2_svhn = deconv(conv_dim, 3, 4, bn=False) def forward(self, x, svhn=False): if svhn: out = F.leaky_relu(self.conv1_svhn(x), 0.05) # (?, 64, 16, 16) else: out = F.leaky_relu(self.conv1(x), 0.05) # (?, 64, 16, 16) out = F.leaky_relu(self.conv2(out), 0.05) # (?, 128, 8, 8) out = F.leaky_relu(self.conv3(out), 0.05) # ( " ) out = F.leaky_relu(self.conv4(out), 0.05) # ( " ) out = F.leaky_relu(self.deconv1(out), 0.05) # (?, 64, 16, 16) if svhn: out = F.tanh(self.deconv2_svhn(out)) # (?, 3, 32, 32) else: out = F.tanh(self.deconv2(out)) # (?, 3, 32, 32) return out def encode(self, x, svhn=False): if svhn: out = F.leaky_relu(self.conv1_svhn(x), 0.05) # (?, 64, 16, 16) else: out = F.leaky_relu(self.conv1(x), 0.05) # (?, 64, 16, 16) out = F.leaky_relu(self.conv2(out), 0.05) # (?, 128, 8, 8) out = F.leaky_relu(self.conv3(out), 0.05) # ( " ) return out def decode(self, out, svhn=False): out = F.leaky_relu(self.conv4(out), 0.05) out = F.leaky_relu(self.deconv1(out), 0.05) # (?, 64, 16, 16) if svhn: out = F.tanh(self.deconv2_svhn(out)) # (?, 3, 32, 32) else: out = F.tanh(self.deconv2(out)) # (?, 3, 32, 32) return out def encode_params(self): layers_basic = 
list(self.conv1_svhn.parameters()) + \ list(self.conv1.parameters()) layers_basic += list(self.conv2.parameters()) layers_basic += list(self.conv3.parameters()) return layers_basic def decode_params(self): layers_basic = list(self.deconv2_svhn.parameters()) + \ list(self.deconv2.parameters()) layers_basic += list(self.deconv1.parameters()) layers_basic += list(self.conv4.parameters()) return layers_basic def unshared_parameters(self): return list(self.deconv2_svhn.parameters()) + list(self.conv1_svhn.parameters()) + \ list(self.deconv2.parameters()) + list(self.conv1.parameters()) class G22(nn.Module): def __init__(self, conv_dim=64): super(G22, self).__init__() # encoding blocks self.conv1 = conv(3, conv_dim, 4) self.conv1_mnist = conv(1, conv_dim, 4) self.conv2 = conv(conv_dim, conv_dim * 2, 4) # residual blocks res_dim = conv_dim * 2 self.conv3 = conv(res_dim, res_dim, 3, 1, 1) self.conv4 = conv(res_dim, res_dim, 3, 1, 1) # decoding blocks self.deconv1 = deconv(conv_dim * 2, conv_dim, 4) self.deconv2 = deconv(conv_dim, 3, 4, bn=False) self.deconv2_mnist = deconv(conv_dim, 1, 4, bn=False) def forward(self, x, mnist=False): if mnist: out = F.leaky_relu(self.conv1_mnist(x), 0.05) # (?, 64, 16, 16) else: out = F.leaky_relu(self.conv1(x), 0.05) # (?, 64, 16, 16) out = F.leaky_relu(self.conv2(out), 0.05) # (?, 128, 8, 8) out = F.leaky_relu(self.conv3(out), 0.05) # ( " ) out = F.leaky_relu(self.conv4(out), 0.05) # ( " ) out = F.leaky_relu(self.deconv1(out), 0.05) # (?, 64, 16, 16) if mnist: out = F.tanh(self.deconv2_mnist(out)) # (?, 3, 32, 32) else: out = F.tanh(self.deconv2(out)) # (?, 3, 32, 32) return out def encode(self, x, mnist=False): if mnist: out = F.leaky_relu(self.conv1_mnist(x), 0.05) # (?, 64, 16, 16) else: out = F.leaky_relu(self.conv1(x), 0.05) # (?, 64, 16, 16) out = F.leaky_relu(self.conv2(out), 0.05) # (?, 128, 8, 8) out = F.leaky_relu(self.conv3(out), 0.05) # ( " ) return out def decode(self, out, mnist=False): out = F.leaky_relu(self.conv4(out), 
0.05) out = F.leaky_relu(self.deconv1(out), 0.05) # (?, 64, 16, 16) if mnist: out = F.tanh(self.deconv2_mnist(out)) # (?, 3, 32, 32) else: out = F.tanh(self.deconv2(out)) # (?, 3, 32, 32) return out def encode_params(self): layers_basic = list(self.conv1_mnist.parameters()) + \ list(self.conv1.parameters()) layers_basic += list(self.conv2.parameters()) layers_basic += list(self.conv3.parameters()) return layers_basic def decode_params(self): layers_basic = list(self.deconv2_mnist.parameters()) + \ list(self.deconv2.parameters()) layers_basic += list(self.deconv1.parameters()) layers_basic += list(self.conv4.parameters()) return layers_basic def unshared_parameters(self): return list(self.deconv2_mnist.parameters()) + list(self.conv1_mnist.parameters()) + \ list(self.deconv2.parameters()) + list(self.conv1.parameters()) class D1(nn.Module): """Discriminator for mnist.""" def __init__(self, conv_dim=64, use_labels=False): super(D1, self).__init__() self.conv1 = conv(1, conv_dim, 4, bn=False) self.conv2 = conv(conv_dim, conv_dim * 2, 4) self.conv3 = conv(conv_dim * 2, conv_dim * 4, 4) n_out = 11 if use_labels else 1 self.fc = conv(conv_dim * 4, n_out, 4, 1, 0, False) def forward(self, x_0): out = F.leaky_relu(self.conv1(x_0), 0.05) # (?, 64, 16, 16) out = F.leaky_relu(self.conv2(out), 0.05) # (?, 128, 8, 8) out = F.leaky_relu(self.conv3(out), 0.05) # (?, 256, 4, 4) out_0 = self.fc(out).squeeze() return out_0 class D2(nn.Module): """Discriminator for svhn.""" def __init__(self, conv_dim=64, use_labels=False): super(D2, self).__init__() self.conv1 = conv(3, conv_dim, 4, bn=False) self.conv2 = conv(conv_dim, conv_dim * 2, 4) self.conv3 = conv(conv_dim * 2, conv_dim * 4, 4) n_out = 11 if use_labels else 1 self.fc = conv(conv_dim * 4, n_out, 4, 1, 0, False) def forward(self, x_0): out = F.leaky_relu(self.conv1(x_0), 0.05) # (?, 64, 16, 16) out = F.leaky_relu(self.conv2(out), 0.05) # (?, 128, 8, 8) out = F.leaky_relu(self.conv3(out), 0.05) # (?, 256, 4, 4) out_0 = 
self.fc(out).squeeze() return out_0 ================================================ FILE: mnist_to_svhn/solver_autoencoder.py ================================================ import os import numpy as np import scipy.io import torch from torch import optim from torch.autograd import Variable from model import D1, D2 from model import G11, G22 class Solver(object): def __init__(self, config, svhn_loader, mnist_loader): self.svhn_loader = svhn_loader self.mnist_loader = mnist_loader self.g11 = None self.g22 = None self.d1 = None self.d2 = None self.g_optimizer = None self.d_optimizer = None self.num_classes = config.num_classes self.beta1 = config.beta1 self.beta2 = config.beta2 self.g_conv_dim = config.g_conv_dim self.d_conv_dim = config.d_conv_dim self.train_iters = config.train_iters self.batch_size = config.batch_size self.lr = config.lr self.log_step = config.log_step self.sample_step = config.sample_step self.sample_path = config.sample_path self.model_path = config.model_path self.kl_lambda = config.kl_lambda self.build_model() def build_model(self): """Builds a generator and a discriminator.""" self.g11 = G11(conv_dim=self.g_conv_dim) self.g22 = G22(conv_dim=self.g_conv_dim) self.d1 = D1(conv_dim=self.d_conv_dim, use_labels=False) self.d2 = D2(conv_dim=self.d_conv_dim, use_labels=False) g_params = list(self.g11.parameters()) + list(self.g22.parameters()) d_params = list(self.d1.parameters()) + list(self.d2.parameters()) self.g_optimizer = optim.Adam(g_params, self.lr, [self.beta1, self.beta2]) self.d_optimizer = optim.Adam(d_params, self.lr, [self.beta1, self.beta2]) if torch.cuda.is_available(): self.g11.cuda() self.g22.cuda() self.d1.cuda() self.d2.cuda() def merge_images(self, sources, targets, k=10): _, _, h, w = sources.shape row = int(np.sqrt(self.batch_size)) merged = np.zeros([3, row * h, row * w * 2]) for idx, (s, t) in enumerate(zip(sources, targets)): i = idx // row j = idx % row merged[:, i * h:(i + 1) * h, (j * 2) * h:(j * 2 + 1) * h] = s 
merged[:, i * h:(i + 1) * h, (j * 2 + 1) * h:(j * 2 + 2) * h] = t return merged.transpose(1, 2, 0) def to_var(self, x): """Converts numpy to variable.""" if torch.cuda.is_available(): x = x.cuda() return Variable(x) def to_data(self, x): """Converts variable to numpy.""" if torch.cuda.is_available(): x = x.cpu() return x.data.numpy() def reset_grad(self): """Zeros the gradient buffers.""" self.g_optimizer.zero_grad() self.d_optimizer.zero_grad() def _compute_kl(self, mu): mu_2 = torch.pow(mu, 2) encoding_loss = torch.mean(mu_2) return encoding_loss def train(self): svhn_iter = iter(self.svhn_loader) mnist_iter = iter(self.mnist_loader) iter_per_epoch = min(len(svhn_iter), len(mnist_iter)) # fixed mnist and svhn for sampling fixed_svhn = self.to_var(svhn_iter.next()[0]) fixed_mnist = self.to_var(mnist_iter.next()[0]) # Train autoencoder for mnist for step in range(self.train_iters + 1): # reset data_iter for each epoch if (step + 1) % iter_per_epoch == 0: mnist_iter = iter(self.mnist_loader) # mnist dataset mnist_data, m_labels_data = mnist_iter.next() mnist, m_labels = self.to_var(mnist_data), self.to_var(m_labels_data) # ============ train D ============# # train with real images self.reset_grad() out = self.d1(mnist) d1_loss = torch.mean((out - 1) ** 2) d_mnist_loss = d1_loss d_real_loss = d1_loss d_real_loss.backward() self.d_optimizer.step() # train with fake images self.reset_grad() fake_mnist = self.g22.forward(mnist, mnist=True) out = self.d1(fake_mnist) d2_loss = torch.mean(out ** 2) d_fake_loss = d2_loss d_fake_loss.backward() self.d_optimizer.step() # ============ train G ============ self.reset_grad() fake_mnist = self.g22.forward(mnist, mnist=True) out = self.d1(fake_mnist) g_loss = torch.mean((out - 1) ** 2) g_loss += torch.mean((mnist - fake_mnist) ** 2) em = self.g22.encode(mnist, mnist=True) g_loss += self.kl_lambda * self._compute_kl(em) g_loss.backward() self.g_optimizer.step() # print the log info if (step + 1) % self.log_step == 0: print('Step 
[%d/%d], d_real_loss: %.4f, d_mnist_loss: %.4f, d_svhn_loss: %.4f, ' 'g_loss: %.4f' % (step + 1, self.train_iters, d_real_loss.data[0], d_mnist_loss.data[0], d_fake_loss.data[0], g_loss.data[0])) # save the sampled images if (step + 1) % self.sample_step == 0: fake_mnist = self.g22.forward(fixed_mnist, mnist=True) mnist, fake_mnist = self.to_data(fixed_mnist), self.to_data(fake_mnist) merged = self.merge_images(mnist, fake_mnist) path = os.path.join(self.sample_path, 'sample-%d-m-s.png' % (step + 1)) scipy.misc.imsave(path, merged) print('saved %s' % path) if (step + 1) % 10000 == 0: # save the model parameters for each epoch g22_path = os.path.join(self.model_path, 'g22-%d.pkl' % (step + 1)) d1_path = os.path.join(self.model_path, 'd1-%d.pkl' % (step + 1)) torch.save(self.g22.state_dict(), g22_path) torch.save(self.d1.state_dict(), d1_path) # Train autoencoder for svhn for step in range(self.train_iters + 1): # reset data_iter for each epoch if (step + 1) % iter_per_epoch == 0: svhn_iter = iter(self.svhn_loader) # load svhn and mnist dataset svhn_data, s_labels_data = svhn_iter.next() svhn, s_labels = self.to_var(svhn_data), self.to_var(s_labels_data).long().squeeze() # ============ train D ============# # train with real images self.reset_grad() out = self.d2(svhn) d2_loss = torch.mean((out - 1) ** 2) d_svhn_loss = d2_loss d_real_loss = d2_loss d_real_loss.backward() self.d_optimizer.step() # train with fake images self.reset_grad() fake_svhn = self.g11.forward(svhn, svhn=True) out = self.d2(fake_svhn) d1_loss = torch.mean(out ** 2) d_fake_loss = d1_loss d_fake_loss.backward() self.d_optimizer.step() # ============ train G ============# self.reset_grad() fake_svhn = self.g11.forward(svhn, svhn=True) out = self.d2(fake_svhn) g_loss = torch.mean((out - 1) ** 2) g_loss += torch.mean((svhn - fake_svhn) ** 2) es = self.g11.encode(svhn, svhn=True) g_loss += self.kl_lambda * self._compute_kl(es) g_loss.backward() self.g_optimizer.step() # print the log info if (step + 
1) % self.log_step == 0: print('Step [%d/%d], d_real_loss: %.4f, d_mnist_loss: %.4f,' 'd_fake_loss: %.4f, g_loss: %.4f' % (step + 1, self.train_iters, d_real_loss.data[0], d_svhn_loss.data[0], d_fake_loss.data[0], g_loss.data[0])) # save the sampled images if (step + 1) % self.sample_step == 0: fake_svhn = self.g11.forward(fixed_svhn, svhn=True) svhn, fake_svhn = self.to_data(fixed_svhn), self.to_data(fake_svhn) merged = self.merge_images(svhn, fake_svhn) path = os.path.join(self.sample_path, 'sample-%d-s-m.png' % (step + 1)) scipy.misc.imsave(path, merged) print('saved %s' % path) if (step + 1) % 10000 == 0: # save the model parameters for each epoch g11_path = os.path.join(self.model_path, 'g11-%d.pkl' % (step + 1)) d2_path = os.path.join(self.model_path, 'd2-%d.pkl' % (step + 1)) torch.save(self.g11.state_dict(), g11_path) torch.save(self.d2.state_dict(), d2_path) ================================================ FILE: mnist_to_svhn/solver_mnist_to_svhn.py ================================================ import os import numpy as np import scipy.io import torch from torch import optim from torch.autograd import Variable from model import D1, D2 from model import G11 class Solver(object): def __init__(self, config, svhn_loader, mnist_loader): self.config = config self.svhn_loader = svhn_loader self.mnist_loader = mnist_loader self.g11 = None self.g22 = None self.d1 = None self.d2 = None self.g_optimizer = None self.num_classes = config.num_classes self.beta1 = config.beta1 self.beta2 = config.beta2 self.g_conv_dim = config.g_conv_dim self.d_conv_dim = config.d_conv_dim self.train_iters = config.train_iters self.batch_size = config.batch_size self.lr = config.lr self.kl_lambda = config.kl_lambda self.log_step = config.log_step self.sample_step = config.sample_step self.sample_path = config.sample_path self.model_path = config.model_path self.g11_load_path = os.path.join(config.load_path, "g11-" + str(config.load_iter) + ".pkl") self.d1_load_path = 
os.path.join(config.load_path, "d1-" + str(config.load_iter) + ".pkl") self.g22_load_path = os.path.join(config.load_path, "g22-" + str(config.load_iter) + ".pkl") self.d2_load_path = os.path.join(config.load_path, "d2-" + str(config.load_iter) + ".pkl") self.build_model() def build_model(self): """Builds a generator and a discriminator.""" self.g11 = G11(conv_dim=self.g_conv_dim) self.g_optimizer = optim.Adam(list(self.g11.encode_params()) + list(self.g11.decode_params()), self.lr, [self.beta1, self.beta2]) self.unshared_optimizer = optim.Adam(list(self.g11.unshared_parameters()), self.lr, [self.beta1, self.beta2]) self.d1 = D1(conv_dim=self.d_conv_dim, use_labels=False) self.d2 = D2(conv_dim=self.d_conv_dim, use_labels=False) self.d_optimizer = optim.Adam(list(self.d1.parameters()) + list(self.d2.parameters()), self.lr, [self.beta1, self.beta2]) if torch.cuda.is_available(): self.g11.cuda() self.d1.cuda() self.d2.cuda() def merge_images(self, sources, targets, k=10): _, _, h, w = sources.shape row = int(np.sqrt(self.batch_size)) + 1 merged = np.zeros([3, row * h, row * w * 2]) for idx, (s, t) in enumerate(zip(sources, targets)): i = idx // row j = idx % row merged[:, i * h:(i + 1) * h, (j * 2) * h:(j * 2 + 1) * h] = s merged[:, i * h:(i + 1) * h, (j * 2 + 1) * h:(j * 2 + 2) * h] = t return merged.transpose(1, 2, 0) def to_var(self, x, volatile=False): """Converts numpy to variable.""" if torch.cuda.is_available(): x = x.cuda() if volatile: return Variable(x, volatile=True) return Variable(x) def to_no_grad_var(self, x): x = self.to_data(x, no_numpy=True) return self.to_var(x, volatile=True) def to_data(self, x, no_numpy=False): """Converts variable to numpy.""" if torch.cuda.is_available(): x = x.cpu() if no_numpy: return x.data return x.data.numpy() def reset_grad(self): """Zeros the gradient buffers.""" self.unshared_optimizer.zero_grad() self.g_optimizer.zero_grad() self.d_optimizer.zero_grad() def _compute_kl(self, mu): mu_2 = torch.pow(mu, 2) encoding_loss = 
torch.mean(mu_2) return encoding_loss def train(self): self.build_model() if self.config.pretrained_g: self.g11.load_state_dict(torch.load(self.g11_load_path)) svhn_iter = iter(self.svhn_loader) mnist_iter = iter(self.mnist_loader) iter_per_epoch = min(len(svhn_iter), len(mnist_iter)) # fixed mnist and svhn for sampling svhn_fixed_data, svhn_fixed_labels = svhn_iter.next() mnist_fixed_data, mnist_fixed_labels = mnist_iter.next() fixed_mnist = self.to_var(mnist_fixed_data) counter = 0 for step in range(self.train_iters + 1): # reset data_iter for each epoch if (step + 1) % iter_per_epoch == 0: mnist_iter = iter(self.mnist_loader) svhn_iter = iter(self.svhn_loader) # load svhn and mnist dataset svhn_data, s_labels_data = svhn_iter.next() mnist_data, m_labels_data = mnist_iter.next() svhn, s_labels = self.to_var(svhn_data), self.to_var(s_labels_data).long().squeeze() mnist, m_labels = self.to_var(mnist_data), self.to_var(m_labels_data) # This sets the maximum number of items for A domain # We assume max_items is a multiple of batch_size # And reset mnist loader when we pass the number of allowed items. 
if self.batch_size > self.config.max_items: exit(-1) elif self.batch_size == self.config.max_items: mnist = fixed_mnist elif self.batch_size < self.config.max_items: counter += 1 if counter * self.batch_size >= self.config.max_items: mnist_iter = iter(self.mnist_loader) counter = 0 # ============ train D ============# # train with real images self.reset_grad() out = self.d1(mnist) d1_loss = torch.mean((out - 1) ** 2) out = self.d2(svhn) d2_loss = torch.mean((out - 1) ** 2) d_mnist_loss = d1_loss d_svhn_loss = d2_loss # Only optimizing d1 d_real_loss = d1_loss + d2_loss d_real_loss.backward() self.d_optimizer.step() # train with fake images self.reset_grad() es = self.g11.encode(svhn, svhn=True) fake_mnist = self.g11.decode(es) out = self.d1(fake_mnist) d2_loss = torch.mean(out ** 2) em = self.g11.encode(mnist) fake_svhn = self.g11.decode(em, svhn=True) out = self.d2(fake_svhn) d1_loss = torch.mean(out ** 2) d_fake_loss = d2_loss + d1_loss d_fake_loss.backward() self.d_optimizer.step() # ============ train G ============# # train mnist-svhn-mnist cycle self.reset_grad() es = self.g11.encode(svhn, svhn=True) fake_mnist = self.g11.decode(es) out = self.d1(fake_mnist) g_loss = torch.mean((out - 1) ** 2) em = self.g11.encode(mnist) fake_svhn = self.g11.decode(em, svhn=True) out = self.d2(fake_svhn) g_loss += torch.mean((out - 1) ** 2) self.reset_grad() em = self.g11.encode(mnist) fake_mnist = self.g11.decode(em) g_loss += torch.mean((mnist - fake_mnist) ** 2) if self.config.one_way_cycle: em = self.g11.encode(mnist) fake_svhn = self.g11.decode(em, svhn=True) es = self.g11.encode(fake_svhn, svhn=True) fake_mnist = self.g11.decode(es) g_loss += torch.mean((mnist - fake_mnist) ** 2) g_loss.backward() self.unshared_optimizer.step() if not self.config.freeze_shared: self.reset_grad() es = self.g11.encode(svhn, svhn=True) fake_es = self.g11.decode(es, svhn=True) g_loss = torch.mean((svhn - fake_es) ** 2) g_loss += self.kl_lambda * self._compute_kl(es) g_loss.backward() 
self.g_optimizer.step() # print the log info if (step + 1) % self.log_step == 0: print('Step [%d/%d], d_real_loss: %.4f, d_mnist_loss: %.4f, d_svhn_loss: %.4f, ' 'd_fake_loss: %.4f, g_loss: %.4f' % (step + 1, self.train_iters, d_real_loss.data[0], d_mnist_loss.data[0], d_svhn_loss.data[0], d_fake_loss.data[0], g_loss.data[0])) # save the sampled images if (step + 1) % self.sample_step == 0: em = self.g11.encode(fixed_mnist) fake_svhn_var = self.g11.decode(em, svhn=True) fake_svhn = self.to_data(fake_svhn_var) if self.config.save_models_and_samples: merged = self.merge_images(mnist_fixed_data, fake_svhn) path = os.path.join(self.sample_path, 'sample-%d-s-m.png' % (step + 1)) scipy.misc.imsave(path, merged) print('saved %s' % path) if (step + 1) % self.config.num_iters_save_model_and_return == 0: # save the model parameters for each epoch if self.config.save_models_and_samples: g11_path = os.path.join(self.model_path, 'g11-%d.pkl' % (step + 1)) d1_path = os.path.join(self.model_path, 'd1-%d.pkl' % (step + 1)) d2_path = os.path.join(self.model_path, 'd2-%d.pkl' % (step + 1)) torch.save(self.g11.state_dict(), g11_path) torch.save(self.d1.state_dict(), d1_path) torch.save(self.d2.state_dict(), d2_path) return ================================================ FILE: mnist_to_svhn/solver_svhn_to_mnist.py ================================================ import os import numpy as np import scipy.io import torch from torch import optim from torch.autograd import Variable from model import D1, D2 from model import G22 class Solver(object): def __init__(self, config, svhn_loader, mnist_loader): self.config = config self.svhn_loader = svhn_loader self.mnist_loader = mnist_loader self.g11 = None self.g22 = None self.d1 = None self.d2 = None self.g_optimizer = None self.num_classes = config.num_classes self.beta1 = config.beta1 self.beta2 = config.beta2 self.g_conv_dim = config.g_conv_dim self.d_conv_dim = config.d_conv_dim self.train_iters = config.train_iters self.batch_size = 
config.batch_size self.lr = config.lr self.kl_lambda = config.kl_lambda self.log_step = config.log_step self.sample_step = config.sample_step self.sample_path = config.sample_path self.model_path = config.model_path self.g11_load_path = os.path.join(config.load_path, "g11-" + str(config.load_iter) + ".pkl") self.d1_load_path = os.path.join(config.load_path, "d1-" + str(config.load_iter) + ".pkl") self.g22_load_path = os.path.join(config.load_path, "g22-" + str(config.load_iter) + ".pkl") self.d2_load_path = os.path.join(config.load_path, "d2-" + str(config.load_iter) + ".pkl") self.build_model() def build_model(self): """Builds a generator and a discriminator.""" self.g22 = G22(conv_dim=self.g_conv_dim) self.g_optimizer = optim.Adam(list(self.g22.encode_params()) + list(self.g22.decode_params()), self.lr, [self.beta1, self.beta2]) self.unshared_optimizer = optim.Adam(list(self.g22.unshared_parameters()), self.lr, [self.beta1, self.beta2]) self.d1 = D1(conv_dim=self.d_conv_dim, use_labels=False) self.d2 = D2(conv_dim=self.d_conv_dim, use_labels=False) self.d_optimizer = optim.Adam(list(self.d1.parameters()) + list(self.d2.parameters()), self.lr, [self.beta1, self.beta2]) if torch.cuda.is_available(): self.g22.cuda() self.d1.cuda() self.d2.cuda() def merge_images(self, sources, targets, k=10): _, _, h, w = sources.shape row = int(np.sqrt(self.batch_size)) + 1 merged = np.zeros([3, row * h, row * w * 2]) for idx, (s, t) in enumerate(zip(sources, targets)): i = idx // row j = idx % row merged[:, i * h:(i + 1) * h, (j * 2) * h:(j * 2 + 1) * h] = s merged[:, i * h:(i + 1) * h, (j * 2 + 1) * h:(j * 2 + 2) * h] = t return merged.transpose(1, 2, 0) def to_var(self, x, volatile=False): """Converts numpy to variable.""" if torch.cuda.is_available(): x = x.cuda() if volatile: return Variable(x, volatile=True) return Variable(x) def to_no_grad_var(self, x): x = self.to_data(x, no_numpy=True) return self.to_var(x, volatile=True) def to_data(self, x, no_numpy=False): """Converts 
variable to numpy.""" if torch.cuda.is_available(): x = x.cpu() if no_numpy: return x.data return x.data.numpy() def reset_grad(self): """Zeros the gradient buffers.""" self.unshared_optimizer.zero_grad() self.g_optimizer.zero_grad() self.d_optimizer.zero_grad() def _compute_kl(self, mu): mu_2 = torch.pow(mu, 2) encoding_loss = torch.mean(mu_2) return encoding_loss def train(self): self.build_model() if self.config.pretrained_g: self.g22.load_state_dict(torch.load(self.g22_load_path)) svhn_iter = iter(self.svhn_loader) mnist_iter = iter(self.mnist_loader) iter_per_epoch = min(len(svhn_iter), len(mnist_iter)) # fixed mnist and svhn for sampling svhn_fixed_data, svhn_fixed_labels = svhn_iter.next() mnist_fixed_data, mnist_fixed_labels = mnist_iter.next() fixed_svhn = self.to_var(svhn_fixed_data) counter = 0 for step in range(self.train_iters + 1): # reset data_iter for each epoch if (step + 1) % iter_per_epoch == 0: mnist_iter = iter(self.mnist_loader) svhn_iter = iter(self.svhn_loader) # load svhn and mnist dataset svhn_data, s_labels_data = svhn_iter.next() mnist_data, m_labels_data = mnist_iter.next() svhn, s_labels = self.to_var(svhn_data), self.to_var(s_labels_data).long().squeeze() mnist, m_labels = self.to_var(mnist_data), self.to_var(m_labels_data) # This sets the maximum number of items for A domain # We assume max_items is a multiple of batch_size # And reset mnist loader when we pass the number of allowed items. 
if self.batch_size > self.config.max_items: exit(-1) elif self.batch_size == self.config.max_items: svhn = fixed_svhn elif self.batch_size < self.config.max_items: counter += 1 if counter * self.batch_size >= self.config.max_items: svhn_iter = iter(self.svhn_loader) counter = 0 # ============ train D ============# # train with real images self.reset_grad() out = self.d1(mnist) d1_loss = torch.mean((out - 1) ** 2) out = self.d2(svhn) d2_loss = torch.mean((out - 1) ** 2) d_mnist_loss = d1_loss d_svhn_loss = d2_loss # Only optimizing d1 d_real_loss = d1_loss + d2_loss d_real_loss.backward() self.d_optimizer.step() # train with fake images self.reset_grad() es = self.g22.encode(svhn) fake_mnist = self.g22.decode(es, mnist=True) out = self.d1(fake_mnist) d2_loss = torch.mean(out ** 2) em = self.g22.encode(mnist, mnist=True) fake_svhn = self.g22.decode(em) out = self.d2(fake_svhn) d1_loss = torch.mean(out ** 2) d_fake_loss = d2_loss + d1_loss d_fake_loss.backward() self.d_optimizer.step() # ============ train G ============# self.reset_grad() es = self.g22.encode(svhn) fake_mnist = self.g22.decode(es, mnist=True) out = self.d1(fake_mnist) g_loss = torch.mean((out - 1) ** 2) em = self.g22.encode(mnist, mnist=True) fake_svhn = self.g22.decode(em) out = self.d2(fake_svhn) g_loss += torch.mean((out - 1) ** 2) self.reset_grad() es = self.g22.encode(svhn) fake_svhn = self.g22.decode(es) g_loss += torch.mean((svhn - fake_svhn) ** 2) if self.config.one_way_cycle: es = self.g22.encode(svhn) fake_mnist = self.g22.decode(es, mnist=True) es = self.g22.encode(fake_mnist, mnist=True) fake_svhn = self.g22.decode(es) g_loss += torch.mean((svhn - fake_svhn) ** 2) g_loss.backward() self.unshared_optimizer.step() if not self.config.freeze_shared: self.reset_grad() em = self.g22.encode(mnist, mnist=True) fake_em = self.g22.decode(em, mnist=True) g_loss = torch.mean((mnist - fake_em) ** 2) g_loss += self.kl_lambda * self._compute_kl(em) g_loss.backward() self.g_optimizer.step() # print the 
log info if (step + 1) % self.log_step == 0: print('Step [%d/%d], d_real_loss: %.4f, d_mnist_loss: %.4f, d_svhn_loss: %.4f, ' 'd_fake_loss: %.4f, g_loss: %.4f' % (step + 1, self.train_iters, d_real_loss.data[0], d_mnist_loss.data[0], d_svhn_loss.data[0], d_fake_loss.data[0], g_loss.data[0])) # save the sampled images if (step + 1) % self.sample_step == 0: es = self.g22.encode(fixed_svhn) fake_mnist_var = self.g22.decode(es, mnist=True) fake_mnist = self.to_data(fake_mnist_var) if self.config.save_models_and_samples: merged = self.merge_images(svhn_fixed_data, fake_mnist) path = os.path.join(self.sample_path, 'sample-%d-s-m.png' % (step + 1)) scipy.misc.imsave(path, merged) print('saved %s' % path) if (step + 1) % self.config.num_iters_save_model_and_return == 0: # save the model parameters for each epoch if self.config.save_models_and_samples: g22_path = os.path.join(self.model_path, 'g22-%d.pkl' % (step + 1)) d1_path = os.path.join(self.model_path, 'd1-%d.pkl' % (step + 1)) d2_path = os.path.join(self.model_path, 'd2-%d.pkl' % (step + 1)) torch.save(self.g22.state_dict(), g22_path) torch.save(self.d1.state_dict(), d1_path) torch.save(self.d2.state_dict(), d2_path) return