Repository: sagiebenaim/OneShotTranslation
Branch: master
Commit: 6f790ae5f4eb
Files: 43
Total size: 155.3 KB
Directory structure:
gitextract_3ffpn665/
├── LICENSE
├── README.md
├── drawing_and_style_transfer/
│   ├── data/
│   │   ├── __init__.py
│   │   ├── aligned_dataset.py
│   │   ├── base_data_loader.py
│   │   ├── base_dataset.py
│   │   ├── image_folder.py
│   │   ├── single_dataset.py
│   │   └── unaligned_dataset.py
│   ├── datasets/
│   │   ├── combine_A_and_B.py
│   │   ├── download_cyclegan_dataset.sh
│   │   └── make_dataset_aligned.py
│   ├── environment.yml
│   ├── models/
│   │   ├── __init__.py
│   │   ├── autoencoder_model.py
│   │   ├── base_model.py
│   │   ├── networks.py
│   │   ├── ost.py
│   │   └── test_model.py
│   ├── options/
│   │   ├── __init__.py
│   │   ├── base_options.py
│   │   ├── test_options.py
│   │   └── train_options.py
│   ├── scripts/
│   │   ├── test_ost.sh
│   │   ├── train_autoencoder.sh
│   │   └── train_ost.sh
│   ├── test.py
│   ├── train.py
│   └── util/
│       ├── __init__.py
│       ├── get_data.py
│       ├── html.py
│       ├── image_pool.py
│       ├── util.py
│       └── visualizer.py
└── mnist_to_svhn/
    ├── data_loader.py
    ├── download.sh
    ├── main_autoencoder.py
    ├── main_mnist_to_svhn.py
    ├── main_svhn_to_mnist.py
    ├── model.py
    ├── solver_autoencoder.py
    ├── solver_mnist_to_svhn.py
    └── solver_svhn_to_mnist.py
================================================
FILE CONTENTS
================================================
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2017
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------- LICENSE FOR mnist-svhn-transfer ---------
MIT License
Copyright (c) 2017
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------- LICENSE FOR pytorch-CycleGAN-and-pix2pix ---------
Copyright (c) 2017, Jun-Yan Zhu and Taesung Park
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
================================================
FILE: README.md
================================================
# PyTorch implementation of One-Shot Unsupervised Cross Domain Translation ([arXiv](https://arxiv.org/abs/1806.06029))
Prerequisites
--------------
- Python 3.6
- PyTorch 0.4
- Numpy/Scipy/Pandas
- Progressbar
- OpenCV
- [visdom](https://github.com/facebookresearch/visdom)
- [dominate](https://github.com/Knio/dominate)
## MNIST-to-SVHN and SVHN-to-MNIST
To train the autoencoder for both MNIST and SVHN (in the mnist_to_svhn folder):

    python main_autoencoder.py --use_augmentation=True

To train OST for MNIST to SVHN:

    python main_mnist_to_svhn.py --pretrained_g=True --save_models_and_samples=True --use_augmentation=True --one_way_cycle=True --freeze_shared=False

To train OST for SVHN to MNIST:

    python main_svhn_to_mnist.py --pretrained_g=True --save_models_and_samples=True --use_augmentation=True --one_way_cycle=True --freeze_shared=False

## Drawing and Style Transfer Tasks
### Download Dataset
To download a dataset (in the drawing_and_style_transfer folder):

    bash datasets/download_cyclegan_dataset.sh $DATASET_NAME

where `DATASET_NAME` is one of: facades, cityscapes, maps, monet2photo, summer2winter_yosemite.

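As a rough sanity check after the download, the sketch below tests whether the extracted folder contains the `trainA`, `trainB`, `testA`, and `testB` subfolders that `datasets/make_dataset_aligned.py` mentions. The helper name `missing_subfolders` is hypothetical and not part of this repository.

```python
import os
import tempfile

# Subfolder names taken from the help text in datasets/make_dataset_aligned.py.
EXPECTED = ("trainA", "trainB", "testA", "testB")


def missing_subfolders(dataset_dir, expected=EXPECTED):
    """Return the expected subfolders that do not exist under dataset_dir."""
    return [name for name in expected
            if not os.path.isdir(os.path.join(dataset_dir, name))]


if __name__ == "__main__":
    # Demo on a throwaway directory; in practice pass e.g. ./datasets/facades.
    with tempfile.TemporaryDirectory() as root:
        for name in ("trainA", "trainB"):
            os.makedirs(os.path.join(root, name))
        print(missing_subfolders(root))
```

An empty list means all four subfolders were found.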
### Train Autoencoder
To train the autoencoder for facades (in the drawing_and_style_transfer folder):

    python train.py --dataroot=./datasets/facades/trainB --name=facades_autoencoder --model=autoencoder --dataset_mode=single --no_dropout --n_downsampling=2 --num_unshared=2

In the reverse direction (images of facades):

    python train.py --dataroot=./datasets/facades/trainA --name=facades_autoencoder_reverse --model=autoencoder --dataset_mode=single --no_dropout --n_downsampling=2 --num_unshared=2

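To make the effect of `--n_downsampling` and `--num_unshared` a little more concrete, here is a minimal sketch that repeats the index arithmetic from `AutoEncoderModel.set_encoders_and_decoders` in `models/autoencoder_model.py`. The function name `layer_split` and the dictionary keys are hypothetical; how `networks.define_ED` consumes these `(start, end)` pairs is not shown here, so treat the interpretation as an assumption.

```python
def layer_split(n_downsampling, num_unshared):
    """Mirror the start/end indices computed in set_encoders_and_decoders."""
    start_unshared = 0
    start_shared = num_unshared
    end_shared = n_downsampling
    start_dec_shared = start_unshared
    end_dec_shared = start_unshared + (end_shared - start_shared)
    # Each value is the (start, end) pair handed to networks.define_ED.
    return {
        "enc_unshared": (start_unshared, num_unshared),
        "enc_shared": (start_shared, end_shared),
        "dec_shared": (start_dec_shared, end_dec_shared),
        "dec_unshared": (end_dec_shared, n_downsampling),
    }


# Settings used in the commands above.
print(layer_split(n_downsampling=2, num_unshared=2))
```

With `--n_downsampling=2 --num_unshared=2` the shared encoder and decoder ranges have equal start and end values. One reading, which is only an assumption, is that all downsampling layers are then domain-specific and the shared part consists of residual blocks alone.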
### Train OST
To train OST for images to facades:

    python train.py --dataroot=./datasets/facades/ --name=facades_ost --load_dir=facades_autoencoder --model=ost --no_dropout --n_downsampling=2 --num_unshared=2 --start=0 --max_items_A=1

To train OST for facades to images (reverse direction):

    python train.py --dataroot=./datasets/facades/ --name=facades_ost_reverse --load_dir=facades_autoencoder_reverse --model=ost --no_dropout --n_downsampling=2 --num_unshared=2 --start=0 --max_items_A=1 --A='B' --B='A'

To visualize losses, run `python -m visdom.server`.

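The `--start` and `--max_items_A` flags in the training commands pick which images of domain A are used. The sketch below mirrors the slicing at the end of `make_dataset` in `data/image_folder.py`; the function name `select_items` and the file names are made up for illustration.

```python
def select_items(paths, max_items=-1, start=0):
    """Mirror the final slicing step of make_dataset in data/image_folder.py."""
    if max_items >= 0:
        # Sort first, then take max_items entries beginning at index start.
        return sorted(paths)[start:start + max_items]
    return list(paths)


# Hypothetical file names, deliberately out of order.
paths = ["trainA/3.jpg", "trainA/1.jpg", "trainA/2.jpg"]
print(select_items(paths, max_items=1, start=0))  # ['trainA/1.jpg']
print(select_items(paths, max_items=1, start=2))  # ['trainA/3.jpg']
```

So `--start=0 --max_items_A=1` keeps only the first image of domain A in sorted order, which is the single sample in the one-shot setting.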
### Test OST
To test OST for images to facades:

    python test.py --dataroot=./datasets/facades/ --name=facades_ost --model=ost --no_dropout --n_downsampling=2 --num_unshared=2 --start=0 --max_items_A=1

To test OST for facades to images (reverse direction):

    python test.py --dataroot=./datasets/facades/ --name=facades_ost_reverse --model=ost --no_dropout --n_downsampling=2 --num_unshared=2 --start=0 --max_items_A=1 --A='B' --B='A'

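The `--A='B' --B='A'` flags used for the reverse direction swap the two image folders. A minimal sketch of the path construction in `UnalignedDataset.initialize` follows; the function name `domain_dirs` is hypothetical, and the defaults `A="A"`, `B="B"` and the phase name `"train"` are assumptions, since the option definitions are not shown here.

```python
import os


def domain_dirs(dataroot, phase, A="A", B="B"):
    """Mirror how UnalignedDataset.initialize joins dataroot, phase and A/B."""
    dir_A = os.path.join(dataroot, phase + A)
    dir_B = os.path.join(dataroot, phase + B)
    return dir_A, dir_B


# Forward direction (assumed defaults), then the swapped reverse direction.
print(domain_dirs("./datasets/facades/", "train"))
print(domain_dirs("./datasets/facades/", "train", A="B", B="A"))
```

In the reverse call, the folder that served as domain B becomes domain A, so `--max_items_A=1` then limits the other side of the dataset.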
### Options
Additional scripts for other datasets are at ./drawing_and_style_transfer/scripts
Options are at ./drawing_and_style_transfer/options
## Reference
If you found this code useful, please cite the following paper:
```
@inproceedings{Benaim2018OneShotUC,
title={One-Shot Unsupervised Cross Domain Translation},
author={Sagie Benaim and Lior Wolf},
booktitle={NeurIPS},
year={2018}
}
```
================================================
FILE: drawing_and_style_transfer/data/__init__.py
================================================
import torch.utils.data
from data.base_data_loader import BaseDataLoader
def CreateDataLoader(opt):
data_loader = CustomDatasetDataLoader()
print(data_loader.name())
data_loader.initialize(opt)
return data_loader
def CreateDataset(opt):
if opt.dataset_mode == 'aligned':
from data.aligned_dataset import AlignedDataset
dataset = AlignedDataset()
elif opt.dataset_mode == 'unaligned':
from data.unaligned_dataset import UnalignedDataset
dataset = UnalignedDataset()
elif opt.dataset_mode == 'single':
from data.single_dataset import SingleDataset
dataset = SingleDataset()
else:
raise ValueError("Dataset [%s] not recognized." % opt.dataset_mode)
print("dataset [%s] was created" % (dataset.name()))
dataset.initialize(opt)
return dataset
class CustomDatasetDataLoader(BaseDataLoader):
def name(self):
return 'CustomDatasetDataLoader'
def initialize(self, opt):
BaseDataLoader.initialize(self, opt)
self.dataset = CreateDataset(opt)
self.dataloader = torch.utils.data.DataLoader(
self.dataset,
batch_size=opt.batchSize,
shuffle=not opt.serial_batches,
num_workers=int(opt.nThreads))
def load_data(self):
return self
def __len__(self):
return len(self.dataset)
def __iter__(self):
for i, data in enumerate(self.dataloader):
yield data
================================================
FILE: drawing_and_style_transfer/data/aligned_dataset.py
================================================
import os.path
import random
import torchvision.transforms as transforms
import torch
from data.base_dataset import BaseDataset
from data.image_folder import make_dataset
from PIL import Image
class AlignedDataset(BaseDataset):
def initialize(self, opt):
self.opt = opt
self.root = opt.dataroot
self.dir_AB = os.path.join(opt.dataroot, opt.phase)
self.AB_paths = sorted(make_dataset(self.dir_AB))
assert (opt.resize_or_crop == 'resize_and_crop')
def __getitem__(self, index):
AB_path = self.AB_paths[index]
AB = Image.open(AB_path).convert('RGB')
w, h = AB.size
w2 = int(w / 2)
A = AB.crop((0, 0, w2, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC)
B = AB.crop((w2, 0, w, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC)
A = transforms.ToTensor()(A)
B = transforms.ToTensor()(B)
w_offset = random.randint(0, max(0, self.opt.loadSize - self.opt.fineSize - 1))
h_offset = random.randint(0, max(0, self.opt.loadSize - self.opt.fineSize - 1))
A = A[:, h_offset:h_offset + self.opt.fineSize, w_offset:w_offset + self.opt.fineSize]
B = B[:, h_offset:h_offset + self.opt.fineSize, w_offset:w_offset + self.opt.fineSize]
A = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(A)
B = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(B)
if self.opt.which_direction == 'BtoA':
input_nc = self.opt.output_nc
output_nc = self.opt.input_nc
else:
input_nc = self.opt.input_nc
output_nc = self.opt.output_nc
if (not self.opt.no_flip) and random.random() < 0.5:
idx = [i for i in range(A.size(2) - 1, -1, -1)]
idx = torch.LongTensor(idx)
A = A.index_select(2, idx)
B = B.index_select(2, idx)
if input_nc == 1: # RGB to gray
tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
A = tmp.unsqueeze(0)
if output_nc == 1: # RGB to gray
tmp = B[0, ...] * 0.299 + B[1, ...] * 0.587 + B[2, ...] * 0.114
B = tmp.unsqueeze(0)
return {'A': A, 'B': B,
'A_paths': AB_path, 'B_paths': AB_path}
def __len__(self):
return len(self.AB_paths)
def name(self):
return 'AlignedDataset'
================================================
FILE: drawing_and_style_transfer/data/base_data_loader.py
================================================
class BaseDataLoader():
def __init__(self):
pass
def initialize(self, opt):
self.opt = opt
pass
def load_data(self):
return None
================================================
FILE: drawing_and_style_transfer/data/base_dataset.py
================================================
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
class BaseDataset(data.Dataset):
def __init__(self):
super(BaseDataset, self).__init__()
def name(self):
return 'BaseDataset'
def initialize(self, opt):
pass
def get_transform(opt):
transform_list = []
if opt.resize_or_crop == 'resize_and_crop':
osize = [opt.loadSize, opt.loadSize]
# transforms.Resize is the current name of the deprecated transforms.Scale
transform_list.append(transforms.Resize(osize, Image.BICUBIC))
transform_list.append(transforms.RandomCrop(opt.fineSize))
elif opt.resize_or_crop == 'crop':
transform_list.append(transforms.RandomCrop(opt.fineSize))
elif opt.resize_or_crop == 'scale_width':
transform_list.append(transforms.Lambda(
lambda img: __scale_width(img, opt.fineSize)))
elif opt.resize_or_crop == 'scale_width_and_crop':
transform_list.append(transforms.Lambda(
lambda img: __scale_width(img, opt.loadSize)))
transform_list.append(transforms.RandomCrop(opt.fineSize))
if opt.isTrain and not opt.no_flip_and_rotation:
# Default augmentations as in paper
transform_list.append(transforms.RandomHorizontalFlip())
transform_list.append(transforms.RandomRotation(opt.rotation_degree))
transform_list += [transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5))]
return transforms.Compose(transform_list)
def __scale_width(img, target_width):
ow, oh = img.size
if (ow == target_width):
return img
w = target_width
h = int(target_width * oh / ow)
return img.resize((w, h), Image.BICUBIC)
================================================
FILE: drawing_and_style_transfer/data/image_folder.py
================================================
###############################################################################
# Code from
# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
# Modified the original code so that it also loads images from the current
# directory as well as the subdirectories
###############################################################################
import torch.utils.data as data
from PIL import Image
import os
import os.path
IMG_EXTENSIONS = [
'.jpg', '.JPG', '.jpeg', '.JPEG',
'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]
def is_image_file(filename):
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
def make_dataset(dir, max_items=-1, start=0):
images = []
assert os.path.isdir(dir), '%s is not a valid directory' % dir
for root, _, fnames in sorted(os.walk(dir)):
for fname in fnames:
if is_image_file(fname):
path = os.path.join(root, fname)
images.append(path)
if max_items >= 0:
return sorted(images)[start:start + max_items]
return images
def default_loader(path):
return Image.open(path).convert('RGB')
class ImageFolder(data.Dataset):
def __init__(self, root, transform=None, return_paths=False,
loader=default_loader):
imgs = make_dataset(root)
if len(imgs) == 0:
raise (RuntimeError("Found 0 images in: " + root + "\n"
"Supported image extensions are: " +
",".join(IMG_EXTENSIONS)))
self.root = root
self.imgs = imgs
self.transform = transform
self.return_paths = return_paths
self.loader = loader
def __getitem__(self, index):
path = self.imgs[index]
img = self.loader(path)
if self.transform is not None:
img = self.transform(img)
if self.return_paths:
return img, path
else:
return img
def __len__(self):
return len(self.imgs)
================================================
FILE: drawing_and_style_transfer/data/single_dataset.py
================================================
import os.path
from data.base_dataset import BaseDataset, get_transform
from data.image_folder import make_dataset
from PIL import Image
class SingleDataset(BaseDataset):
def initialize(self, opt):
self.opt = opt
self.root = opt.dataroot
self.dir_A = os.path.join(opt.dataroot)
self.A_paths = make_dataset(self.dir_A)
self.A_paths = sorted(self.A_paths)
self.transform = get_transform(opt)
def __getitem__(self, index):
A_path = self.A_paths[index]
A_img = Image.open(A_path).convert('RGB')
A = self.transform(A_img)
if self.opt.which_direction == 'BtoA':
input_nc = self.opt.output_nc
else:
input_nc = self.opt.input_nc
if input_nc == 1: # RGB to gray
tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
A = tmp.unsqueeze(0)
return {'A': A, 'A_paths': A_path}
def __len__(self):
return len(self.A_paths)
def name(self):
return 'SingleImageDataset'
================================================
FILE: drawing_and_style_transfer/data/unaligned_dataset.py
================================================
import os.path
from data.base_dataset import BaseDataset, get_transform
from data.image_folder import make_dataset
from PIL import Image
import random
class UnalignedDataset(BaseDataset):
def initialize(self, opt):
self.opt = opt
self.root = opt.dataroot
self.dir_A = os.path.join(opt.dataroot, opt.phase + opt.A)
self.dir_B = os.path.join(opt.dataroot, opt.phase + opt.B)
self.A_paths = make_dataset(self.dir_A, max_items=opt.max_items_A, start=opt.start)
self.B_paths = make_dataset(self.dir_B, max_items=opt.max_items_B, start=opt.start)
self.A_paths = sorted(self.A_paths)
self.B_paths = sorted(self.B_paths)
self.A_size = len(self.A_paths)
self.B_size = len(self.B_paths)
self.transform = get_transform(opt)
def __getitem__(self, index):
A_path = self.A_paths[index % self.A_size]
if self.opt.serial_batches:
index_B = index % self.B_size
else:
index_B = random.randint(0, self.B_size - 1)
B_path = self.B_paths[index_B]
# print('(A, B) = (%d, %d)' % (index_A, index_B))
A_img = Image.open(A_path).convert('RGB')
B_img = Image.open(B_path).convert('RGB')
A = self.transform(A_img)
B = self.transform(B_img)
if self.opt.which_direction == 'BtoA':
input_nc = self.opt.output_nc
output_nc = self.opt.input_nc
else:
input_nc = self.opt.input_nc
output_nc = self.opt.output_nc
if input_nc == 1: # RGB to gray
tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
A = tmp.unsqueeze(0)
if output_nc == 1: # RGB to gray
tmp = B[0, ...] * 0.299 + B[1, ...] * 0.587 + B[2, ...] * 0.114
B = tmp.unsqueeze(0)
return {'A': A, 'B': B,
'A_paths': A_path, 'B_paths': B_path}
def __len__(self):
return max(self.A_size, self.B_size)
def name(self):
return 'UnalignedDataset'
================================================
FILE: drawing_and_style_transfer/datasets/combine_A_and_B.py
================================================
import os
import numpy as np
import cv2
import argparse
parser = argparse.ArgumentParser('create image pairs')
parser.add_argument('--fold_A', dest='fold_A', help='input directory for image A', type=str, default='../dataset/50kshoes_edges')
parser.add_argument('--fold_B', dest='fold_B', help='input directory for image B', type=str, default='../dataset/50kshoes_jpg')
parser.add_argument('--fold_AB', dest='fold_AB', help='output directory', type=str, default='../dataset/test_AB')
parser.add_argument('--num_imgs', dest='num_imgs', help='number of images',type=int, default=1000000)
parser.add_argument('--use_AB', dest='use_AB', help='if true: (0001_A, 0001_B) to (0001_AB)',action='store_true')
args = parser.parse_args()
for arg in vars(args):
print('[%s] = ' % arg, getattr(args, arg))
splits = os.listdir(args.fold_A)
for sp in splits:
img_fold_A = os.path.join(args.fold_A, sp)
img_fold_B = os.path.join(args.fold_B, sp)
img_list = os.listdir(img_fold_A)
if args.use_AB:
img_list = [img_path for img_path in img_list if '_A.' in img_path]
num_imgs = min(args.num_imgs, len(img_list))
print('split = %s, use %d/%d images' % (sp, num_imgs, len(img_list)))
img_fold_AB = os.path.join(args.fold_AB, sp)
if not os.path.isdir(img_fold_AB):
os.makedirs(img_fold_AB)
print('split = %s, number of images = %d' % (sp, num_imgs))
for n in range(num_imgs):
name_A = img_list[n]
path_A = os.path.join(img_fold_A, name_A)
if args.use_AB:
name_B = name_A.replace('_A.', '_B.')
else:
name_B = name_A
path_B = os.path.join(img_fold_B, name_B)
if os.path.isfile(path_A) and os.path.isfile(path_B):
name_AB = name_A
if args.use_AB:
name_AB = name_AB.replace('_A.', '.') # remove _A
path_AB = os.path.join(img_fold_AB, name_AB)
# cv2.IMREAD_COLOR replaces the OpenCV 2 constant cv2.CV_LOAD_IMAGE_COLOR
im_A = cv2.imread(path_A, cv2.IMREAD_COLOR)
im_B = cv2.imread(path_B, cv2.IMREAD_COLOR)
im_AB = np.concatenate([im_A, im_B], 1)
cv2.imwrite(path_AB, im_AB)
================================================
FILE: drawing_and_style_transfer/datasets/download_cyclegan_dataset.sh
================================================
FILE=$1
if [[ $FILE != "apple2orange" && $FILE != "summer2winter_yosemite" && $FILE != "horse2zebra" && $FILE != "monet2photo" && $FILE != "cezanne2photo" && $FILE != "ukiyoe2photo" && $FILE != "vangogh2photo" && $FILE != "maps" && $FILE != "cityscapes" && $FILE != "facades" && $FILE != "iphone2dslr_flower" && $FILE != "ae_photos" ]]; then
echo "Available datasets are: apple2orange, summer2winter_yosemite, horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, ae_photos"
exit 1
fi
URL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/$FILE.zip
ZIP_FILE=./datasets/$FILE.zip
TARGET_DIR=./datasets/$FILE/
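# -N has no effect together with -O, so it is dropped; variables are quoted
wget "$URL" -O "$ZIP_FILE"
# -p avoids an error when the target directory already exists
mkdir -p "$TARGET_DIR"
unzip "$ZIP_FILE" -d ./datasets/
rm "$ZIP_FILE"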
================================================
FILE: drawing_and_style_transfer/datasets/make_dataset_aligned.py
================================================
import os
from PIL import Image
def get_file_paths(folder):
image_file_paths = []
for root, dirs, filenames in os.walk(folder):
filenames = sorted(filenames)
for filename in filenames:
input_path = os.path.abspath(root)
file_path = os.path.join(input_path, filename)
if filename.endswith('.png') or filename.endswith('.jpg'):
image_file_paths.append(file_path)
break # prevent descending into subfolders
return image_file_paths
def align_images(a_file_paths, b_file_paths, target_path):
if not os.path.exists(target_path):
os.makedirs(target_path)
for i in range(len(a_file_paths)):
img_a = Image.open(a_file_paths[i])
img_b = Image.open(b_file_paths[i])
assert(img_a.size == img_b.size)
aligned_image = Image.new("RGB", (img_a.size[0] * 2, img_a.size[1]))
aligned_image.paste(img_a, (0, 0))
aligned_image.paste(img_b, (img_a.size[0], 0))
aligned_image.save(os.path.join(target_path, '{:04d}.jpg'.format(i)))
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
'--dataset-path',
dest='dataset_path',
help='Which folder to process (it should have subfolders testA, testB, trainA and trainB)'
)
args = parser.parse_args()
dataset_folder = args.dataset_path
print(dataset_folder)
test_a_path = os.path.join(dataset_folder, 'testA')
test_b_path = os.path.join(dataset_folder, 'testB')
test_a_file_paths = get_file_paths(test_a_path)
test_b_file_paths = get_file_paths(test_b_path)
assert(len(test_a_file_paths) == len(test_b_file_paths))
test_path = os.path.join(dataset_folder, 'test')
train_a_path = os.path.join(dataset_folder, 'trainA')
train_b_path = os.path.join(dataset_folder, 'trainB')
train_a_file_paths = get_file_paths(train_a_path)
train_b_file_paths = get_file_paths(train_b_path)
assert(len(train_a_file_paths) == len(train_b_file_paths))
train_path = os.path.join(dataset_folder, 'train')
align_images(test_a_file_paths, test_b_file_paths, test_path)
align_images(train_a_file_paths, train_b_file_paths, train_path)
================================================
FILE: drawing_and_style_transfer/environment.yml
================================================
name: OST
channels:
- peterjc123
- defaults
dependencies:
- python=3.6.5
- pytorch=0.4.0
- scipy
- pip:
- dominate==2.3.1
- git+https://github.com/pytorch/vision.git
- Pillow==5.0.0
- numpy==1.14.1
- visdom==0.1.7
================================================
FILE: drawing_and_style_transfer/models/__init__.py
================================================
def create_model(opt):
print(opt.model)
if opt.model == 'ost':
assert (opt.dataset_mode == 'unaligned')
from .ost import OSTModel
model = OSTModel()
elif opt.model == 'autoencoder':
assert (opt.dataset_mode == 'single')
from .autoencoder_model import AutoEncoderModel
model = AutoEncoderModel()
elif opt.model == 'test':
assert (opt.dataset_mode == 'single')
from .test_model import TestModel
model = TestModel()
else:
raise NotImplementedError('model [%s] not implemented.' % opt.model)
model.initialize(opt)
print("model [%s] was created" % (model.name()))
return model
================================================
FILE: drawing_and_style_transfer/models/autoencoder_model.py
================================================
import torch
from collections import OrderedDict
from torch.autograd import Variable
import itertools
import util.util as util
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks
class AutoEncoderModel(BaseModel):
def name(self):
return 'AutoEncoderModel'
def set_encoders_and_decoders(self, opt):
n_downsampling = opt.n_downsampling
start_unshared = 0
num_unshared = opt.num_unshared
start_shared = num_unshared
end_shared = n_downsampling
start_dec_shared = start_unshared
end_dec_shared = start_unshared + (end_shared - start_shared)
start_dec_unshared = end_dec_shared
end_dec_unshared = n_downsampling
num_res_blocks_unshared = opt.num_res_blocks_unshared
n_res_blocks_shared = opt.num_res_blocks_shared
self.netEnc_b, self.netDec_b = networks.define_ED(opt.input_nc, opt.output_nc,
opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout,
opt.init_type, self.gpu_ids,
n_blocks_encoder=num_res_blocks_unshared,
n_blocks_decoder=num_res_blocks_unshared,
start=start_unshared,
end=num_unshared, n_downsampling=n_downsampling,
input_layer=True,
output_layer=True, start_dec=start_dec_unshared,
end_dec=end_dec_unshared)
self.netEnc_shared, self.netDec_shared = networks.define_ED(opt.input_nc, opt.output_nc,
opt.ngf, opt.which_model_netG, opt.norm,
not opt.no_dropout,
opt.init_type, self.gpu_ids,
n_blocks_encoder=n_res_blocks_shared,
n_blocks_decoder=n_res_blocks_shared,
start=start_shared, n_downsampling=n_downsampling,
end=end_shared,
input_layer=False,
output_layer=False, start_dec=start_dec_shared,
end_dec=end_dec_shared)
def initialize(self, opt):
BaseModel.initialize(self, opt)
self.set_encoders_and_decoders(opt)
if self.isTrain:
use_sigmoid = opt.no_lsgan
self.netD = networks.define_D(opt.output_nc, opt.ndf,
opt.which_model_netD,
opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)
if not self.isTrain or opt.continue_train:
which_epoch = opt.which_epoch
self.load_network(self.netEnc_b, 'Enc_b', which_epoch)
self.load_network(self.netDec_b, 'Dec_b', which_epoch)
self.load_network(self.netEnc_shared, 'Enc_shared', which_epoch)
self.load_network(self.netDec_shared, 'Dec_shared', which_epoch)
if self.isTrain:
self.load_network(self.netD, 'D', which_epoch)
if self.isTrain:
self.fake_B_pool = ImagePool(opt.pool_size)
# define loss functions
self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
self.criterionIdt = torch.nn.L1Loss()
# initialize optimizers
self.optimizer_Enc = torch.optim.Adam(
itertools.chain(self.netEnc_b.parameters(), self.netEnc_shared.parameters()),
lr=opt.lr, betas=(opt.beta1, 0.999))
self.optimizer_Dec = torch.optim.Adam(
itertools.chain(self.netDec_b.parameters(), self.netDec_shared.parameters()),
lr=opt.lr, betas=(opt.beta1, 0.999))
self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
self.optimizers = []
self.schedulers = []
self.optimizers.append(self.optimizer_Enc)
self.optimizers.append(self.optimizer_Dec)
self.optimizers.append(self.optimizer_D)
for optimizer in self.optimizers:
self.schedulers.append(networks.get_scheduler(optimizer, opt))
print('---------- Networks initialized -------------')
networks.print_network(self.netEnc_b)
networks.print_network(self.netDec_b)
networks.print_network(self.netEnc_shared)
networks.print_network(self.netDec_shared)
if self.isTrain:
networks.print_network(self.netD)
print('-----------------------------------------------')
def set_input(self, input):
# 'A' is given as single_dataset
input_B = input['A']
if len(self.gpu_ids) > 0:
# non_blocking replaces the old async argument (async is a reserved word since Python 3.7)
input_B = input_B.cuda(self.gpu_ids[0], non_blocking=True)
self.input_B = input_B
# 'A' is given as single_dataset
self.image_paths = input['A_paths']
def forward(self):
self.real_B = Variable(self.input_B)
def netEnc(self, x):
return self.netEnc_shared(self.netEnc_b(x))
def netDec(self, x):
return self.netDec_b(self.netDec_shared(x))
def test(self):
real_B = Variable(self.input_B, volatile=True)
fake_B = self.netDec(self.netEnc(real_B))
self.fake_B = fake_B.data
# get image paths
def get_image_paths(self):
return self.image_paths
def backward_D_basic(self, netD, real, fake):
# Real
pred_real = netD(real)
loss_D_real = self.criterionGAN(pred_real, True)
# Fake
pred_fake = netD(fake.detach())
loss_D_fake = self.criterionGAN(pred_fake, False)
# Combined loss
loss_D = (loss_D_real + loss_D_fake) * 0.5
# backward
loss_D.backward()
return loss_D
def backward_D(self):
fake_B = self.fake_B_pool.query(self.fake_B)
loss_D = self.backward_D_basic(self.netD, self.real_B, fake_B)
# .item() reads the scalar from a 0-dim loss tensor (PyTorch 0.4+)
self.loss_D = loss_D.item()
def _compute_kl(self, mu):
mu_2 = torch.pow(mu, 2)
encoding_loss = torch.mean(mu_2)
return encoding_loss
def backward_G(self):
lambda_B = self.opt.lambda_B
# GAN loss D_B(G_B(B))
enc_b = self.netEnc(self.real_B)
fake_B = self.netDec(enc_b)
pred_fake = self.netD(fake_B)
loss_Gan = self.criterionGAN(pred_fake, True)
loss_idt_B = self.criterionIdt(fake_B, self.real_B) * lambda_B
loss_kl_B = self.opt.kl_lambda * self._compute_kl(enc_b)
# combined loss
loss_G = loss_Gan + loss_idt_B + loss_kl_B
loss_G.backward()
self.fake_B = fake_B.data
# .item() reads the scalar from a 0-dim loss tensor (PyTorch 0.4+)
self.loss_Gan = loss_Gan.item()
self.loss_idt_B = loss_idt_B.item()
def optimize_parameters(self):
# forward
self.forward()
# G
self.optimizer_Enc.zero_grad()
self.optimizer_Dec.zero_grad()
self.backward_G()
self.optimizer_Enc.step()
self.optimizer_Dec.step()
# D
self.optimizer_D.zero_grad()
self.backward_D()
self.optimizer_D.step()
def get_current_errors(self):
ret_errors = OrderedDict([('D', self.loss_D), ('G_B', self.loss_Gan), ('Idt_B', self.loss_idt_B)])
return ret_errors
def get_current_visuals(self):
real_B = util.tensor2im(self.input_B)
fake_B = util.tensor2im(self.fake_B)
ret_visuals = OrderedDict([('real_B', real_B), ('fake_B', fake_B), ])
return ret_visuals
def save(self, label):
self.save_network(self.netEnc_b, 'Enc_b', label, self.gpu_ids)
self.save_network(self.netDec_b, 'Dec_b', label, self.gpu_ids)
self.save_network(self.netEnc_shared, 'Enc_shared', label, self.gpu_ids)
self.save_network(self.netDec_shared, 'Dec_shared', label, self.gpu_ids)
self.save_network(self.netD, 'D', label, self.gpu_ids)
================================================
FILE: drawing_and_style_transfer/models/base_model.py
================================================
import os
import torch
class BaseModel(object):
def name(self):
return 'BaseModel'
def initialize(self, opt):
self.opt = opt
self.gpu_ids = opt.gpu_ids
self.isTrain = opt.isTrain
self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
self.load_dir = os.path.join(opt.checkpoints_dir, opt.load_dir)
self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
def set_input(self, input):
self.input = input
def forward(self):
pass
# used in test time, no backprop
def test(self):
pass
def get_image_paths(self):
pass
def optimize_parameters(self):
pass
def get_current_visuals(self):
return self.input
def get_current_errors(self):
return {}
def save(self, label):
pass
# helper saving function that can be used by subclasses
def save_network(self, network, network_label, epoch_label, gpu_ids):
save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
save_path = os.path.join(self.save_dir, save_filename)
torch.save(network.cpu().state_dict(), save_path)
if len(gpu_ids) and torch.cuda.is_available():
network.cuda(gpu_ids[0])
# helper loading function that can be used by subclasses
def load_network(self, network, network_label, epoch_label):
save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
save_path = os.path.join(self.load_dir, save_filename)
network.load_state_dict(torch.load(save_path))
# update learning rate (called once every epoch)
def update_learning_rate(self):
for scheduler in self.schedulers:
scheduler.step()
lr = self.optimizers[0].param_groups[0]['lr']
print('learning rate = %.7f' % lr)
def as_np(self, data):
return data.cpu().data.numpy()
================================================
FILE: drawing_and_style_transfer/models/networks.py
================================================
import torch
import torch.nn as nn
from torch.nn import init
import functools
from torch.autograd import Variable
from torch.optim import lr_scheduler
###############################################################################
# Functions
###############################################################################
class pixel_norm(nn.Module):
def forward(self, x, epsilon=1e-8):
return x * torch.rsqrt(torch.mean(x.pow(2), dim=1, keepdim=True) + epsilon)
def weights_init_normal(m):
classname = m.__class__.__name__
# print(classname)
if classname.find('Conv') != -1:
init.normal(m.weight.data, 0.0, 0.02)
elif classname.find('Linear') != -1:
init.normal(m.weight.data, 0.0, 0.02)
elif classname.find('BatchNorm2d') != -1:
init.normal(m.weight.data, 1.0, 0.02)
init.constant(m.bias.data, 0.0)
def weights_init_xavier(m):
classname = m.__class__.__name__
# print(classname)
if classname.find('Conv') != -1:
init.xavier_normal(m.weight.data, gain=0.02)
elif classname.find('Linear') != -1:
init.xavier_normal(m.weight.data, gain=0.02)
elif classname.find('BatchNorm2d') != -1:
init.normal(m.weight.data, 1.0, 0.02)
init.constant(m.bias.data, 0.0)
def weights_init_kaiming(m):
classname = m.__class__.__name__
# print(classname)
if classname.find('Conv') != -1:
init.kaiming_normal(m.weight.data, a=0, mode='fan_in')
elif classname.find('Linear') != -1:
init.kaiming_normal(m.weight.data, a=0, mode='fan_in')
elif classname.find('BatchNorm2d') != -1:
init.normal(m.weight.data, 1.0, 0.02)
init.constant(m.bias.data, 0.0)
def weights_init_orthogonal(m):
classname = m.__class__.__name__
# print(classname)
if classname.find('Conv') != -1:
init.orthogonal(m.weight.data, gain=1)
elif classname.find('Linear') != -1:
init.orthogonal(m.weight.data, gain=1)
elif classname.find('BatchNorm2d') != -1:
init.normal(m.weight.data, 1.0, 0.02)
init.constant(m.bias.data, 0.0)
def init_weights(net, init_type='normal'):
print('initialization method [%s]' % init_type)
if init_type == 'normal':
net.apply(weights_init_normal)
elif init_type == 'xavier':
net.apply(weights_init_xavier)
elif init_type == 'kaiming':
net.apply(weights_init_kaiming)
elif init_type == 'orthogonal':
net.apply(weights_init_orthogonal)
else:
raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
def get_norm_layer(norm_type='instance'):
if norm_type == 'batch':
norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
elif norm_type == 'instance':
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
elif norm_type == 'none':
norm_layer = None
else:
raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
return norm_layer
def get_scheduler(optimizer, opt):
if opt.lr_policy == 'lambda':
def lambda_rule(epoch):
lr_l = 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
return lr_l
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
elif opt.lr_policy == 'step':
scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
elif opt.lr_policy == 'plateau':
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
else:
raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
return scheduler
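# Illustration only (not part of the original module): a minimal sketch of the
# multiplier that the 'lambda' policy in get_scheduler applies to the base
# learning rate. The helper name and the niter/niter_decay/epoch_count values
# below are assumptions chosen for the example; only the formula is taken from
# lambda_rule above.

```python
def linear_decay_multiplier(epoch, niter, niter_decay, epoch_count=1):
    """Same arithmetic as lambda_rule: 1.0 for the first epochs, then a linear ramp down."""
    return 1.0 - max(0, epoch + 1 + epoch_count - niter) / float(niter_decay + 1)


# Hypothetical schedule: 100 epochs at the base rate, then 100 epochs of decay.
NITER, NITER_DECAY = 100, 100

for epoch in (0, 98, 99, 150, 198):
    # The multiplier stays at 1.0 until epoch + 1 + epoch_count exceeds niter.
    print(epoch, linear_decay_multiplier(epoch, NITER, NITER_DECAY))
```

# With these assumed values the multiplier is 1.0 through scheduler epoch 98 and
# falls linearly to 1/101 at scheduler epoch 198, so it never reaches exactly zero.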
def define_ED(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, init_type='normal',
gpu_ids=[], n_downsampling=2, start=0, end=2, input_layer=True, output_layer=True, n_blocks_encoder=9,
n_blocks_decoder=9, start_dec=0, end_dec=1):
use_gpu = len(gpu_ids) > 0
norm_layer = get_norm_layer(norm_type=norm)
if use_gpu:
assert (torch.cuda.is_available())
if which_model_netG == 'resnet_9blocks':
netE = ResnetEncoder(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout,
n_blocks=n_blocks_encoder, gpu_ids=gpu_ids, n_downsampling=n_downsampling, start=start,
end=end, input_layer=input_layer)
netD = ResnetDecoder(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout,
n_blocks=n_blocks_decoder, gpu_ids=gpu_ids, n_downsampling=n_downsampling, end=end_dec,
start=start_dec, output_layer=output_layer)
elif which_model_netG == 'resnet_6blocks':
netE = ResnetEncoder(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout,
n_blocks=n_blocks_encoder, gpu_ids=gpu_ids, n_downsampling=n_downsampling, start=start,
end=end, input_layer=input_layer)
netD = ResnetDecoder(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout,
n_blocks=n_blocks_decoder, gpu_ids=gpu_ids, n_downsampling=n_downsampling, end=end_dec,
start=start_dec, output_layer=output_layer)
else:
raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)
if len(gpu_ids) > 0:
netE.cuda(gpu_ids[0])
netD.cuda(gpu_ids[0])
init_weights(netE, init_type=init_type)
init_weights(netD, init_type=init_type)
return netE, netD
def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, init_type='normal',
gpu_ids=[]):
netG = None
use_gpu = len(gpu_ids) > 0
norm_layer = get_norm_layer(norm_type=norm)
if use_gpu:
assert (torch.cuda.is_available())
if which_model_netG == 'resnet_9blocks':
netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9,
gpu_ids=gpu_ids)
elif which_model_netG == 'resnet_6blocks':
netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6,
gpu_ids=gpu_ids)
elif which_model_netG == 'unet_128':
netG = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout,
gpu_ids=gpu_ids)
elif which_model_netG == 'unet_256':
netG = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout,
gpu_ids=gpu_ids)
else:
raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)
if len(gpu_ids) > 0:
netG.cuda(gpu_ids[0])
init_weights(netG, init_type=init_type)
return netG
def define_D(input_nc, ndf, which_model_netD,
n_layers_D=3, norm='batch', use_sigmoid=False, init_type='normal', gpu_ids=[]):
use_gpu = len(gpu_ids) > 0
norm_layer = get_norm_layer(norm_type=norm)
if use_gpu:
assert (torch.cuda.is_available())
if which_model_netD == 'basic':
netD = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid,
gpu_ids=gpu_ids)
elif which_model_netD == 'n_layers':
netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid,
gpu_ids=gpu_ids)
elif which_model_netD == 'pixel':
netD = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
else:
raise NotImplementedError('Discriminator model name [%s] is not recognized' %
which_model_netD)
if use_gpu:
netD.cuda(gpu_ids[0])
init_weights(netD, init_type=init_type)
return netD
def print_network(net):
num_params = 0
for param in net.parameters():
num_params += param.numel()
print(net)
print('Total number of parameters: %d' % num_params)
##############################################################################
# Classes
##############################################################################
# Defines the GAN loss which uses either LSGAN or the regular GAN.
# When LSGAN is used, it is basically the same as MSELoss,
# but it abstracts away the need to create the target label tensor
# that has the same size as the input
class GANLoss(nn.Module):
def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
tensor=torch.FloatTensor):
super(GANLoss, self).__init__()
self.real_label = target_real_label
self.fake_label = target_fake_label
self.real_label_var = None
self.fake_label_var = None
self.Tensor = tensor
if use_lsgan:
self.loss = nn.MSELoss()
else:
self.loss = nn.BCELoss()
def get_target_tensor(self, input, target_is_real):
target_tensor = None
if target_is_real:
create_label = ((self.real_label_var is None) or
(self.real_label_var.numel() != input.numel()))
if create_label:
real_tensor = self.Tensor(input.size()).fill_(self.real_label)
self.real_label_var = Variable(real_tensor, requires_grad=False)
target_tensor = self.real_label_var
else:
create_label = ((self.fake_label_var is None) or
(self.fake_label_var.numel() != input.numel()))
if create_label:
fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
self.fake_label_var = Variable(fake_tensor, requires_grad=False)
target_tensor = self.fake_label_var
return target_tensor
def __call__(self, input, target_is_real):
target_tensor = self.get_target_tensor(input, target_is_real)
return self.loss(input, target_tensor)
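# Illustration only (not part of the original module): a dependency-free sketch
# of what GANLoss computes in the LSGAN case, namely the mean squared error
# between the discriminator outputs and a constant target label. The function
# name and the sample predictions are assumptions made for the example.

```python
def lsgan_loss(predictions, target_is_real, real_label=1.0, fake_label=0.0):
    """Mean squared error against a constant label, as nn.MSELoss would compute it."""
    target = real_label if target_is_real else fake_label
    return sum((p - target) ** 2 for p in predictions) / len(predictions)


# Hypothetical discriminator outputs for a small patch map.
patch_scores = [1.0, 0.0]

print(lsgan_loss(patch_scores, target_is_real=True))   # compares against 1.0
print(lsgan_loss(patch_scores, target_is_real=False))  # compares against 0.0
```

# GANLoss additionally caches the label tensor and rebuilds it only when the
# number of elements in the input changes; the sketch above skips that step.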
# Defines the encoder half of the Resnet generator: an optional input layer,
# the downsampling convolutions with indices in [start, end), then n_blocks Resnet blocks.
# Code and idea originally from Justin Johnson's architecture.
# https://github.com/jcjohnson/fast-neural-style/
class ResnetEncoder(nn.Module):
def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False,
gpu_ids=[], padding_type='reflect', n_downsampling=2, start=0, end=2, input_layer=True, n_blocks=6):
assert (n_blocks >= 0)
super(ResnetEncoder, self).__init__()
self.input_nc = input_nc
self.output_nc = output_nc
self.ngf = ngf
self.gpu_ids = gpu_ids
if type(norm_layer) == functools.partial:
use_bias = norm_layer.func == nn.InstanceNorm2d
else:
use_bias = norm_layer == nn.InstanceNorm2d
model = []
if input_layer:
model = [nn.ReflectionPad2d(3),
nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0,
bias=use_bias),
norm_layer(ngf),
nn.ReLU(True)]
for i in range(start, end):
mult = 2 ** i
model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
stride=2, padding=1, bias=use_bias),
norm_layer(ngf * mult * 2),
nn.ReLU(True)]
mult = 2 ** n_downsampling
for i in range(n_blocks):
model += [
ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout,
use_bias=use_bias)]
self.model = nn.Sequential(*model)
def forward(self, input):
if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
else:
return self.model(input)
class ResnetDecoder(nn.Module):
def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False,
gpu_ids=[], padding_type='reflect', n_downsampling=2, start=0, end=2, output_layer=True, n_blocks=6):
assert (n_blocks >= 0)
super(ResnetDecoder, self).__init__()
self.input_nc = input_nc
self.output_nc = output_nc
self.ngf = ngf
self.gpu_ids = gpu_ids
if type(norm_layer) == functools.partial:
use_bias = norm_layer.func == nn.InstanceNorm2d
else:
use_bias = norm_layer == nn.InstanceNorm2d
model = []
mult = 2 ** n_downsampling
for i in range(n_blocks):
model += [
ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout,
use_bias=use_bias)]
for i in range(start, end):
mult = 2 ** (n_downsampling - i)
model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
kernel_size=3, stride=2,
padding=1, output_padding=1,
bias=use_bias),
norm_layer(int(ngf * mult / 2)),
nn.ReLU(True)]
if output_layer:
model += [nn.ReflectionPad2d(3)]
model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
model += [nn.Tanh()]
self.model = nn.Sequential(*model)
def forward(self, input):
if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
else:
return self.model(input)
# Defines the generator that consists of Resnet blocks between a few
# downsampling/upsampling operations.
# Code and idea originally from Justin Johnson's architecture.
# https://github.com/jcjohnson/fast-neural-style/
class ResnetGenerator(nn.Module):
def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6,
gpu_ids=[], padding_type='reflect'):
assert (n_blocks >= 0)
super(ResnetGenerator, self).__init__()
self.input_nc = input_nc
self.output_nc = output_nc
self.ngf = ngf
self.gpu_ids = gpu_ids
if type(norm_layer) == functools.partial:
use_bias = norm_layer.func == nn.InstanceNorm2d
else:
use_bias = norm_layer == nn.InstanceNorm2d
model = [nn.ReflectionPad2d(3),
nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0,
bias=use_bias),
norm_layer(ngf),
nn.ReLU(True)]
n_downsampling = 2
for i in range(n_downsampling):
mult = 2 ** i
model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
stride=2, padding=1, bias=use_bias),
norm_layer(ngf * mult * 2),
nn.ReLU(True)]
mult = 2 ** n_downsampling
for i in range(n_blocks):
model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout,
use_bias=use_bias)]
for i in range(n_downsampling):
mult = 2 ** (n_downsampling - i)
model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
kernel_size=3, stride=2,
padding=1, output_padding=1,
bias=use_bias),
norm_layer(int(ngf * mult / 2)),
nn.ReLU(True)]
model += [nn.ReflectionPad2d(3)]
model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
model += [nn.Tanh()]
self.model = nn.Sequential(*model)
def forward(self, input):
if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
else:
return self.model(input)
# Define a resnet block
class ResnetBlock(nn.Module):
def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
super(ResnetBlock, self).__init__()
self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
conv_block = []
p = 0
if padding_type == 'reflect':
conv_block += [nn.ReflectionPad2d(1)]
elif padding_type == 'replicate':
conv_block += [nn.ReplicationPad2d(1)]
elif padding_type == 'zero':
p = 1
else:
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
norm_layer(dim),
nn.ReLU(True)]
if use_dropout:
conv_block += [nn.Dropout(0.5)]
p = 0
if padding_type == 'reflect':
conv_block += [nn.ReflectionPad2d(1)]
elif padding_type == 'replicate':
conv_block += [nn.ReplicationPad2d(1)]
elif padding_type == 'zero':
p = 1
else:
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
norm_layer(dim)]
return nn.Sequential(*conv_block)
def forward(self, x):
out = x + self.conv_block(x)
return out
# Defines the Unet generator.
# |num_downs|: number of downsamplings in UNet. For example,
# if |num_downs| == 7, an image of size 128x128 will become 1x1
# at the bottleneck
class UnetGenerator(nn.Module):
def __init__(self, input_nc, output_nc, num_downs, ngf=64,
norm_layer=nn.BatchNorm2d, use_dropout=False, gpu_ids=[]):
super(UnetGenerator, self).__init__()
self.gpu_ids = gpu_ids
# construct unet structure
unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer,
innermost=True)
for i in range(num_downs - 5):
unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block,
norm_layer=norm_layer, use_dropout=use_dropout)
unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block,
norm_layer=norm_layer)
unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block,
norm_layer=norm_layer)
unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True,
norm_layer=norm_layer)
self.model = unet_block
def forward(self, input):
if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
else:
return self.model(input)
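# Illustration only (not part of the original module): a small sketch checking
# the comment above UnetGenerator, that num_downs stride-2 convolutions reduce a
# 128x128 image to 1x1. It uses the standard convolution output-size formula
# with the kernel_size=4, stride=2, padding=1 settings of downconv in
# UnetSkipConnectionBlock. The helper names are assumptions made for the example.

```python
def conv_out_size(size, kernel=4, stride=2, padding=1):
    """Spatial output size of one convolution along a single dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def bottleneck_size(image_size, num_downs):
    """Apply num_downs downsampling convolutions and return the final size."""
    size = image_size
    for _ in range(num_downs):
        size = conv_out_size(size)
    return size


print(bottleneck_size(128, 7))  # the 'unet_128' setting in define_G
print(bottleneck_size(256, 8))  # the 'unet_256' setting in define_G
```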
# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
# |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):
def __init__(self, outer_nc, inner_nc, input_nc=None,
submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
super(UnetSkipConnectionBlock, self).__init__()
self.outermost = outermost
if type(norm_layer) == functools.partial:
use_bias = norm_layer.func == nn.InstanceNorm2d
else:
use_bias = norm_layer == nn.InstanceNorm2d
if input_nc is None:
input_nc = outer_nc
downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
stride=2, padding=1, bias=use_bias)
downrelu = nn.LeakyReLU(0.2, True)
downnorm = norm_layer(inner_nc)
uprelu = nn.ReLU(True)
upnorm = norm_layer(outer_nc)
if outermost:
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
kernel_size=4, stride=2,
padding=1)
down = [downconv]
up = [uprelu, upconv, nn.Tanh()]
model = down + [submodule] + up
elif innermost:
upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
kernel_size=4, stride=2,
padding=1, bias=use_bias)
down = [downrelu, downconv]
up = [uprelu, upconv, upnorm]
model = down + up
else:
upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
kernel_size=4, stride=2,
padding=1, bias=use_bias)
down = [downrelu, downconv, downnorm]
up = [uprelu, upconv, upnorm]
if use_dropout:
model = down + [submodule] + up + [nn.Dropout(0.5)]
else:
model = down + [submodule] + up
self.model = nn.Sequential(*model)
def forward(self, x):
if self.outermost:
return self.model(x)
else:
return torch.cat([x, self.model(x)], 1)
# Defines the PatchGAN discriminator with the specified arguments.
class NLayerDiscriminator(nn.Module):
def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, gpu_ids=[]):
super(NLayerDiscriminator, self).__init__()
self.gpu_ids = gpu_ids
if type(norm_layer) == functools.partial:
use_bias = norm_layer.func == nn.InstanceNorm2d
else:
use_bias = norm_layer == nn.InstanceNorm2d
kw = 4
padw = 1
sequence = [
nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
nn.LeakyReLU(0.2, True)
]
nf_mult = 1
nf_mult_prev = 1
for n in range(1, n_layers):
nf_mult_prev = nf_mult
nf_mult = min(2 ** n, 8)
sequence += [
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
kernel_size=kw, stride=2, padding=padw, bias=use_bias),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
]
nf_mult_prev = nf_mult
nf_mult = min(2 ** n_layers, 8)
sequence += [
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
kernel_size=kw, stride=1, padding=padw, bias=use_bias),
norm_layer(ndf * nf_mult),
nn.LeakyReLU(0.2, True)
]
sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
if use_sigmoid:
sequence += [nn.Sigmoid()]
self.model = nn.Sequential(*sequence)
def forward(self, input):
if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor):
return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
else:
return self.model(input)
class PixelDiscriminator(nn.Module):
def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d, use_sigmoid=False, gpu_ids=[]):
super(PixelDiscriminator, self).__init__()
self.gpu_ids = gpu_ids
if type(norm_layer) == functools.partial:
use_bias = norm_layer.func == nn.InstanceNorm2d
else:
use_bias = norm_layer == nn.InstanceNorm2d
self.net = [
nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
nn.LeakyReLU(0.2, True),
nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
norm_layer(ndf * 2),
nn.LeakyReLU(0.2, True),
nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]
if use_sigmoid:
self.net.append(nn.Sigmoid())
self.net = nn.Sequential(*self.net)
def forward(self, input):
if len(self.gpu_ids) and isinstance(input.data, torch.cuda.FloatTensor):
return nn.parallel.data_parallel(self.net, input, self.gpu_ids)
else:
return self.net(input)
================================================
FILE: drawing_and_style_transfer/models/ost.py
================================================
import torch
from collections import OrderedDict
from torch.autograd import Variable
import itertools
import util.util as util
from util.image_pool import ImagePool
from .base_model import BaseModel
from . import networks
class OSTModel(BaseModel):
def name(self):
return 'OSTModel'
def _compute_kl(self, mu):
mu_2 = torch.pow(mu, 2)
encoding_loss = torch.mean(mu_2)
return encoding_loss
def set_encoders_and_decoders(self, opt):
n_downsampling = opt.n_downsampling
start_unshared = 0
num_unshared = opt.num_unshared
start_shared = num_unshared
end_shared = n_downsampling
start_dec_shared = start_unshared
end_dec_shared = start_unshared + (end_shared - start_shared)
start_dec_unshared = end_dec_shared
end_dec_unshared = n_downsampling
num_res_blocks_unshared = opt.num_res_blocks_unshared
n_res_blocks_shared = opt.num_res_blocks_shared
self.netEnc_a, self.netDec_a = networks.define_ED(opt.input_nc, opt.output_nc,
opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout,
opt.init_type, self.gpu_ids,
n_blocks_encoder=num_res_blocks_unshared,
n_blocks_decoder=num_res_blocks_unshared,
start=start_unshared,
end=num_unshared, n_downsampling=n_downsampling,
input_layer=True, output_layer=True,
start_dec=start_dec_unshared, end_dec=end_dec_unshared)
self.netEnc_b, self.netDec_b = networks.define_ED(opt.input_nc, opt.output_nc,
opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout,
opt.init_type, self.gpu_ids,
n_blocks_encoder=num_res_blocks_unshared,
n_blocks_decoder=num_res_blocks_unshared,
start=start_unshared,
end=num_unshared, n_downsampling=n_downsampling,
input_layer=True,
output_layer=True, start_dec=start_dec_unshared,
end_dec=end_dec_unshared)
self.netEnc_shared, self.netDec_shared = networks.define_ED(opt.input_nc, opt.output_nc,
opt.ngf, opt.which_model_netG, opt.norm,
not opt.no_dropout,
opt.init_type, self.gpu_ids,
n_blocks_encoder=n_res_blocks_shared,
n_blocks_decoder=n_res_blocks_shared,
start=start_shared, n_downsampling=n_downsampling,
end=end_shared,
input_layer=False,
output_layer=False, start_dec=start_dec_shared,
end_dec=end_dec_shared)
def initialize(self, opt):
BaseModel.initialize(self, opt)
self.set_encoders_and_decoders(opt)
if self.isTrain:
use_sigmoid = opt.no_lsgan
self.netD_a = networks.define_D(opt.output_nc, opt.ndf,
opt.which_model_netD,
opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)
self.netD_b = networks.define_D(opt.output_nc, opt.ndf,
opt.which_model_netD,
opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)
if not opt.dont_load_pretrained_autoencoder:
which_epoch = opt.which_epoch
self.load_network(self.netEnc_b, 'Enc_b', which_epoch)
self.load_network(self.netDec_b, 'Dec_b', which_epoch)
self.load_network(self.netEnc_shared, 'Enc_shared', which_epoch)
self.load_network(self.netDec_shared, 'Dec_shared', which_epoch)
if not self.isTrain or opt.continue_train:
which_epoch = opt.which_epoch
self.load_network(self.netEnc_a, 'Enc_a', which_epoch)
self.load_network(self.netDec_a, 'Dec_a', which_epoch)
self.load_network(self.netEnc_b, 'Enc_b', which_epoch)
self.load_network(self.netDec_b, 'Dec_b', which_epoch)
self.load_network(self.netEnc_shared, 'Enc_shared', which_epoch)
self.load_network(self.netDec_shared, 'Dec_shared', which_epoch)
if self.isTrain:
self.load_network(self.netD_a, 'D_a', which_epoch)
self.load_network(self.netD_b, 'D_b', which_epoch)
if self.isTrain:
self.fake_A_pool = ImagePool(opt.pool_size)
self.fake_B_pool = ImagePool(opt.pool_size)
# define loss functions
self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
self.criterionIdt = torch.nn.L1Loss()
# initialize optimizers
self.optimizer_Enc_a = torch.optim.Adam(self.netEnc_a.parameters(),
lr=opt.lr, betas=(opt.beta1, 0.999))
self.optimizer_Dec_a = torch.optim.Adam(self.netDec_a.parameters(),
lr=opt.lr, betas=(opt.beta1, 0.999))
self.optimizer_Enc_b = torch.optim.Adam(
itertools.chain(self.netEnc_b.parameters(), self.netEnc_shared.parameters()),
lr=opt.lr, betas=(opt.beta1, 0.999))
self.optimizer_Dec_b = torch.optim.Adam(
itertools.chain(self.netDec_b.parameters(), self.netDec_shared.parameters()),
lr=opt.lr, betas=(opt.beta1, 0.999))
self.optimizer_D_a = torch.optim.Adam(self.netD_a.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
self.optimizer_D_b = torch.optim.Adam(self.netD_b.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
self.optimizers = []
self.schedulers = []
self.optimizers.append(self.optimizer_Enc_a)
self.optimizers.append(self.optimizer_Dec_a)
self.optimizers.append(self.optimizer_Enc_b)
self.optimizers.append(self.optimizer_Dec_b)
self.optimizers.append(self.optimizer_D_a)
self.optimizers.append(self.optimizer_D_b)
for optimizer in self.optimizers:
self.schedulers.append(networks.get_scheduler(optimizer, opt))
print('---------- Networks initialized -------------')
networks.print_network(self.netEnc_a)
networks.print_network(self.netDec_a)
networks.print_network(self.netEnc_b)
networks.print_network(self.netDec_b)
networks.print_network(self.netEnc_shared)
networks.print_network(self.netDec_shared)
if self.isTrain:
networks.print_network(self.netD_a)
networks.print_network(self.netD_b)
print('-----------------------------------------------')
def set_input(self, input):
AtoB = self.opt.which_direction == 'AtoB'
input_A = input['A' if AtoB else 'B']
input_B = input['B' if AtoB else 'A']
if len(self.gpu_ids) > 0:
input_A = input_A.cuda(self.gpu_ids[0], async=True)
input_B = input_B.cuda(self.gpu_ids[0], async=True)
self.input_A = input_A
self.input_B = input_B
self.image_paths = input['A_paths' if AtoB else 'B_paths']
def forward(self):
self.real_A = Variable(self.input_A)
self.real_B = Variable(self.input_B)
def test(self):
real_A = Variable(self.input_A, volatile=True)
real_B = Variable(self.input_B, volatile=True)
enc_a = self.netEnc_shared(self.netEnc_a(real_A))
enc_b = self.netEnc_shared(self.netEnc_b(real_B))
fake_AA = self.netDec_a(self.netDec_shared(enc_a))
fake_AB = self.netDec_b(self.netDec_shared(enc_a))
fake_BB = self.netDec_b(self.netDec_shared(enc_b))
enc_ab = self.netEnc_shared(self.netEnc_b(fake_AB))
fake_ABA = self.netDec_a(self.netDec_shared(enc_ab))
self.fake_AA = fake_AA.data
self.fake_AB = fake_AB.data
self.fake_BB = fake_BB.data
self.fake_ABA = fake_ABA.data
# get image paths
def get_image_paths(self):
return self.image_paths
def backward_D_basic(self, netD, real, fake):
# Real
pred_real = netD(real)
loss_D_real = self.criterionGAN(pred_real, True)
# Fake
pred_fake = netD(fake.detach())
loss_D_fake = self.criterionGAN(pred_fake, False)
# Combined loss
loss_D = (loss_D_real + loss_D_fake) * 0.5
# backward
loss_D.backward()
return loss_D
def backward_D(self):
fake_AB = self.fake_B_pool.query(self.fake_AB)
loss_D_ab = self.backward_D_basic(self.netD_b, self.real_B, fake_AB)
self.loss_D_ab = loss_D_ab.data[0]
fake_BB = self.fake_B_pool.query(self.fake_BB)
loss_D_bb = self.backward_D_basic(self.netD_b, self.real_B, fake_BB)
self.loss_D_bb = loss_D_bb.data[0]
def backward_G(self):
# Encode both domains, then decode along the A->A, A->B, B->B and A->B->A paths
enc_a = self.netEnc_shared(self.netEnc_a(self.real_A))
enc_b = self.netEnc_shared(self.netEnc_b(self.real_B))
fake_AA = self.netDec_a(self.netDec_shared(enc_a))
fake_AB = self.netDec_b(self.netDec_shared(enc_a))
fake_BB = self.netDec_b(self.netDec_shared(enc_b))
enc_ab = self.netEnc_shared(self.netEnc_b(fake_AB))
fake_ABA = self.netDec_a(self.netDec_shared(enc_ab))
pred_fake_AB = self.netD_b(fake_AB)
loss_Gan_AB = self.criterionGAN(pred_fake_AB, True)
loss_idt_A = self.criterionIdt(fake_AA, self.real_A)
loss_cycle_A = self.opt.lambda_A * self.criterionIdt(fake_ABA, self.real_A)
loss_idt_B = self.criterionIdt(fake_BB, self.real_B)
pred_fake_BB = self.netD_b(fake_BB)
loss_Gan_BB = self.criterionGAN(pred_fake_BB, True)
loss_kl_B = self.opt.kl_lambda * self._compute_kl(enc_b)
# combined losses
loss_G_B = loss_idt_B + loss_kl_B + loss_Gan_BB
loss_G_A = loss_Gan_AB + loss_cycle_A + loss_idt_A
self.fake_AA = fake_AA.data
self.fake_BB = fake_BB.data
self.fake_AB = fake_AB.data
self.fake_ABA = fake_ABA.data
self.loss_cycle_A = loss_cycle_A.data[0]
self.loss_Gan_AB = loss_Gan_AB.data[0]
self.loss_Gan_BB = loss_Gan_BB.data[0]
self.loss_idt_A = loss_idt_A.data[0]
self.loss_idt_B = loss_idt_B.data[0]
self.loss_kl_B = loss_kl_B.data[0]
return loss_G_A, loss_G_B
def optimize_parameters(self):
# forward
self.forward()
loss_G_A, loss_G_B = self.backward_G()
# A loss updates
self.optimizer_Enc_a.zero_grad()
self.optimizer_Dec_a.zero_grad()
loss_G_A.backward(retain_graph=True)
self.optimizer_Enc_a.step()
self.optimizer_Dec_a.step()
# B loss updates
self.optimizer_Enc_b.zero_grad()
self.optimizer_Dec_b.zero_grad()
loss_G_B.backward()
self.optimizer_Enc_b.step()
self.optimizer_Dec_b.step()
# D
self.optimizer_D_a.zero_grad()
self.optimizer_D_b.zero_grad()
self.backward_D()
self.optimizer_D_a.step()
self.optimizer_D_b.step()
def get_current_errors(self):
ret_errors = OrderedDict(
[('D_ab', self.loss_D_ab), ('D_bb', self.loss_D_bb),
('G_AB', self.loss_Gan_AB), ('G_BB', self.loss_Gan_BB),
('Idt_B', self.loss_idt_B), ('Idt_A', self.loss_idt_A),
('Cycle_A', self.loss_cycle_A), ('Kl_B', self.loss_kl_B), ])
return ret_errors
def get_current_visuals(self):
real_A = util.tensor2im(self.input_A)
real_B = util.tensor2im(self.input_B)
fake_BB = util.tensor2im(self.fake_BB)
fake_AB = util.tensor2im(self.fake_AB)
fake_AA = util.tensor2im(self.fake_AA)
fake_ABA = util.tensor2im(self.fake_ABA)
ret_visuals = OrderedDict(
[('real_B', real_B), ('fake_BB', fake_BB),
('real_A', real_A), ('fake_AA', fake_AA), ('fake_AB', fake_AB), ('fake_ABA', fake_ABA), ])
return ret_visuals
def save(self, label):
self.save_network(self.netEnc_a, 'Enc_a', label, self.gpu_ids)
self.save_network(self.netDec_a, 'Dec_a', label, self.gpu_ids)
self.save_network(self.netD_a, 'D_a', label, self.gpu_ids)
self.save_network(self.netEnc_b, 'Enc_b', label, self.gpu_ids)
self.save_network(self.netDec_b, 'Dec_b', label, self.gpu_ids)
self.save_network(self.netD_b, 'D_b', label, self.gpu_ids)
self.save_network(self.netEnc_shared, 'Enc_shared', label, self.gpu_ids)
self.save_network(self.netDec_shared, 'Dec_shared', label, self.gpu_ids)
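# Illustration only (not part of the original module): a sketch of the index
# arithmetic in OSTModel.set_encoders_and_decoders, showing which downsampling
# and upsampling layers end up in the domain-specific (unshared) networks and
# which in the shared ones. The helper names and the example values
# (n_downsampling=2, num_unshared=1, ngf=64) are assumptions; the channel
# formulas mirror ResnetEncoder and ResnetDecoder in networks.py.

```python
def layer_split(n_downsampling, num_unshared):
    """Return the (start, end) layer ranges computed in set_encoders_and_decoders."""
    start_unshared = 0
    start_shared = num_unshared
    end_shared = n_downsampling
    start_dec_shared = start_unshared
    end_dec_shared = start_unshared + (end_shared - start_shared)
    return {
        'enc_unshared': (start_unshared, num_unshared),
        'enc_shared': (start_shared, end_shared),
        'dec_shared': (start_dec_shared, end_dec_shared),
        'dec_unshared': (end_dec_shared, n_downsampling),
    }


def encoder_channels(ngf, start, end):
    """(in, out) channels of each stride-2 convolution built for range(start, end)."""
    return [(ngf * 2 ** i, ngf * 2 ** (i + 1)) for i in range(start, end)]


def decoder_channels(ngf, n_downsampling, start, end):
    """(in, out) channels of each transposed convolution built for range(start, end)."""
    return [(ngf * 2 ** (n_downsampling - i), ngf * 2 ** (n_downsampling - i) // 2)
            for i in range(start, end)]


split = layer_split(n_downsampling=2, num_unshared=1)
print(split)
print(encoder_channels(64, *split['enc_unshared']), encoder_channels(64, *split['enc_shared']))
print(decoder_channels(64, 2, *split['dec_shared']), decoder_channels(64, 2, *split['dec_unshared']))
```

# Under these assumed values the unshared encoder maps 64 to 128 channels, the
# shared encoder 128 to 256, the shared decoder 256 to 128, and the unshared
# decoder 128 to 64, so the shared pair sits at the innermost resolution.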
================================================
FILE: drawing_and_style_transfer/models/test_model.py
================================================
from torch.autograd import Variable
from collections import OrderedDict
import util.util as util
from .base_model import BaseModel
from . import networks
class TestModel(BaseModel):
def name(self):
return 'TestModel'
def initialize(self, opt):
assert (not opt.isTrain)
BaseModel.initialize(self, opt)
self.netG = networks.define_G(opt.input_nc, opt.output_nc,
opt.ngf, opt.which_model_netG,
opt.norm, not opt.no_dropout,
opt.init_type,
self.gpu_ids)
which_epoch = opt.which_epoch
self.load_network(self.netG, 'G', which_epoch)
print('---------- Networks initialized -------------')
networks.print_network(self.netG)
print('-----------------------------------------------')
def set_input(self, input):
# we need to use single_dataset mode
input_A = input['A']
if len(self.gpu_ids) > 0:
input_A = input_A.cuda(self.gpu_ids[0], async=True)
self.input_A = input_A
self.image_paths = input['A_paths']
def test(self):
self.real_A = Variable(self.input_A, volatile=True)
self.fake_B = self.netG(self.real_A)
# get image paths
def get_image_paths(self):
return self.image_paths
def get_current_visuals(self):
real_A = util.tensor2im(self.real_A.data)
fake_B = util.tensor2im(self.fake_B.data)
return OrderedDict([('real_A', real_A), ('fake_B', fake_B)])
================================================
FILE: drawing_and_style_transfer/options/__init__.py
================================================
================================================
FILE: drawing_and_style_transfer/options/base_options.py
================================================
import argparse
import os
from util import util
import torch
class BaseOptions():
def __init__(self):
self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
self.initialized = False
def initialize(self):
self.parser.add_argument('--dataroot', required=True,
help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
self.parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
self.parser.add_argument('--max_items_A', type=int, default=-1,
help='max number of items for domain A, -1 indicates no maximum')
self.parser.add_argument('--max_items_B', type=int, default=-1,
help='max number of items for domain B, -1 indicates no maximum')
self.parser.add_argument('--start', type=int, default=0,
help='starting index of items of domain A, after sorting')
self.parser.add_argument('--loadSize', type=int, default=286, help='scale images to this size')
self.parser.add_argument('--fineSize', type=int, default=256, help='then crop to this size')
self.parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels')
self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')
self.parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
self.parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
self.parser.add_argument('--which_model_netD', type=str, default='basic', help='selects model to use for netD')
self.parser.add_argument('--which_model_netG', type=str, default='resnet_9blocks',
help='selects model to use for netG and netED')
self.parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers')
self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids, e.g. "0", "0,1,2" or "0,2"; use -1 for CPU')
self.parser.add_argument('--name', type=str, default='experiment_name',
help='name of the experiment. It decides where to store samples and models')
self.parser.add_argument('--dataset_mode', type=str, default='unaligned',
help='chooses how datasets are loaded. [unaligned | aligned | single]')
self.parser.add_argument('--model', type=str, default='ost',
help='chooses which model to use. ost, autoencoder, test.')
self.parser.add_argument('--which_direction', type=str, default='AtoB', help='AtoB or BtoA')
self.parser.add_argument('--nThreads', default=2, type=int, help='# threads for loading data')
self.parser.add_argument('--load_dir', type=str, default='./checkpoints', help='models are loaded here')
self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
self.parser.add_argument('--norm', type=str, default='instance',
help='instance normalization or batch normalization')
self.parser.add_argument('--serial_batches', action='store_true',
help='if true, takes images in order to make batches, otherwise takes them randomly')
self.parser.add_argument('--display_winsize', type=int, default=256, help='display window size')
self.parser.add_argument('--display_id', type=int, default=1, help='window id of the web display')
self.parser.add_argument('--display_server', type=str, default="http://localhost",
help='visdom server of the web display')
self.parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
self.parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator')
self.parser.add_argument('--resize_or_crop', type=str, default='resize_and_crop',
help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]')
self.parser.add_argument('--no_flip_and_rotation', action='store_true',
help='if specified, do not flip and rotate the images for data augmentation')
self.parser.add_argument('--rotation_degree', type=int, default=7, help='rotation degree used for augmentation')
self.parser.add_argument('--init_type', type=str, default='normal',
help='network initialization [normal|xavier|kaiming|orthogonal]')
self.parser.add_argument('--A', type=str, default='A',
help='used to exchange dataset A for B by setting the value to B')
self.parser.add_argument('--B', type=str, default='B',
help='used to exchange dataset B for A by setting the value to A')
self.parser.add_argument('--n_downsampling', type=int, default=2,
help="number of downsampling/upsampling convolutional/deconvolutional layers")
self.parser.add_argument('--num_unshared', type=int, default=1,
help="number of unshared encoder/decoder layers, not including input and final layers")
self.parser.add_argument('--num_res_blocks_unshared', type=int, default=0,
help='number of unshared resnet blocks')
self.parser.add_argument('--num_res_blocks_shared', type=int, default=6, help='number of shared resnet blocks')
self.initialized = True
def parse(self):
if not self.initialized:
self.initialize()
self.opt = self.parser.parse_args()
self.opt.isTrain = self.isTrain # train or test
str_ids = self.opt.gpu_ids.split(',')
self.opt.gpu_ids = []
for str_id in str_ids:
id = int(str_id)
if id >= 0:
self.opt.gpu_ids.append(id)
# set gpu ids
if len(self.opt.gpu_ids) > 0:
torch.cuda.set_device(self.opt.gpu_ids[0])
args = vars(self.opt)
print('------------ Options -------------')
for k, v in sorted(args.items()):
print('%s: %s' % (str(k), str(v)))
print('-------------- End ----------------')
# save to the disk
expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name)
util.mkdirs(expr_dir)
file_name = os.path.join(expr_dir, 'opt.txt')
with open(file_name, 'wt') as opt_file:
opt_file.write('------------ Options -------------\n')
for k, v in sorted(args.items()):
opt_file.write('%s: %s\n' % (str(k), str(v)))
opt_file.write('-------------- End ----------------\n')
return self.opt
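The `--gpu_ids` handling in `parse()` above can be sketched in isolation. This is a minimal illustration and not part of the repository; `parse_gpu_ids` is a hypothetical helper name chosen for this example.

```python
def parse_gpu_ids(gpu_ids):
    """Turn a string such as '0,1,2' into a list of non-negative ints.

    Hypothetical helper mirroring the loop in BaseOptions.parse();
    negative ids (e.g. '-1') are dropped, which leaves an empty list
    and therefore selects CPU mode.
    """
    ids = []
    for str_id in gpu_ids.split(','):
        gpu_id = int(str_id)
        if gpu_id >= 0:
            ids.append(gpu_id)
    return ids


if __name__ == '__main__':
    print(parse_gpu_ids('0,2'))
    print(parse_gpu_ids('-1'))
```

An empty result corresponds to the branch in `parse()` where `torch.cuda.set_device` is never called.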
================================================
FILE: drawing_and_style_transfer/options/test_options.py
================================================
from .base_options import BaseOptions
class TestOptions(BaseOptions):
def initialize(self):
BaseOptions.initialize(self)
self.parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
self.parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
self.parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
self.parser.add_argument('--which_epoch', type=str, default='latest',
help='which epoch to load? set to latest to use latest cached model')
self.parser.add_argument('--how_many', type=int, default=50, help='how many test images to run')
self.isTrain = False
================================================
FILE: drawing_and_style_transfer/options/train_options.py
================================================
from .base_options import BaseOptions
class TrainOptions(BaseOptions):
def initialize(self):
BaseOptions.initialize(self)
self.parser.add_argument('--display_freq', type=int, default=100,
help='frequency of showing training results on screen')
self.parser.add_argument('--display_single_pane_ncols', type=int, default=0,
help='if positive, display all images in a single visdom web panel with certain number of images per row.')
self.parser.add_argument('--update_html_freq', type=int, default=1000,
help='frequency of saving training results to html')
self.parser.add_argument('--print_freq', type=int, default=100,
help='frequency of showing training results on console')
self.parser.add_argument('--save_latest_freq', type=int, default=10000,
help='frequency of saving the latest results')
self.parser.add_argument('--save_epoch_freq', type=int, default=10,
help='frequency of saving checkpoints at the end of epochs')
self.parser.add_argument('--continue_train', action='store_true',
help='continue training: load the latest model')
self.parser.add_argument('--epoch_count', type=int, default=1,
help='the starting epoch count; checkpoints are saved whenever the epoch number is a multiple of <save_epoch_freq>')
self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
self.parser.add_argument('--which_epoch', type=str, default='latest',
help='which epoch to load? set to latest to use latest cached model')
self.parser.add_argument('--niter', type=int, default=60, help='# of epochs at starting learning rate')
self.parser.add_argument('--niter_decay', type=int, default=20,
help='# of epochs to linearly decay learning rate to zero')
self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
self.parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
self.parser.add_argument('--no_lsgan', action='store_true',
help='do *not* use least-squares GAN; if specified, use vanilla GAN')
self.parser.add_argument('--kl_lambda', type=float, default=0.1, help='weight for kl loss')
self.parser.add_argument('--lambda_A', type=float, default=10.0, help='weight for cycle loss (A -> B -> A)')
self.parser.add_argument('--lambda_B', type=float, default=10.0, help='weight for cycle loss (B -> A -> B)')
self.parser.add_argument('--pool_size', type=int, default=50,
help='the size of image buffer that stores previously generated images')
self.parser.add_argument('--dont_load_pretrained_autoencoder', action='store_true',
help='do not load pretrained autoencoder')
self.parser.add_argument('--no_html', action='store_true',
help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
self.parser.add_argument('--lr_policy', type=str, default='lambda',
help='learning rate policy: lambda|step|plateau')
self.parser.add_argument('--lr_decay_iters', type=int, default=50,
help='multiply by a gamma every lr_decay_iters iterations')
self.isTrain = True
================================================
FILE: drawing_and_style_transfer/scripts/test_ost.sh
================================================
# images to cityscapes
python test.py --dataroot=./datasets/cityscapes/ --name=cityscapes_ost --load_dir=cityscapes_ost --model=ost --no_dropout --n_downsampling=3 --num_unshared=3 --start=0 --max_items_A=1
# cityscapes to images
python test.py --dataroot=./datasets/cityscapes/ --name=cityscapes_ost_reverse --load_dir=cityscapes_ost_reverse --model=ost --no_dropout --n_downsampling=2 --num_unshared=2 --start=0 --max_items_A=1 --A='B' --B='A'
# images to facades
python test.py --dataroot=./datasets/facades/ --name=facades_ost --load_dir=facades_ost --model=ost --no_dropout --n_downsampling=2 --num_unshared=2 --start=0 --max_items_A=1
# facades to images
python test.py --dataroot=./datasets/facades/ --name=facades_ost_reverse --load_dir=facades_ost_reverse --model=ost --no_dropout --n_downsampling=2 --num_unshared=2 --start=0 --max_items_A=1 --A='B' --B='A'
# aerial view to maps
python test.py --dataroot=./datasets/maps/ --name=maps_ost --load_dir=maps_ost --model=ost --no_dropout --n_downsampling=2 --num_unshared=2 --start=0 --max_items_A=1
# maps to aerial view
python test.py --dataroot=./datasets/maps/ --name=maps_ost_reverse --load_dir=maps_ost_reverse --model=ost --no_dropout --n_downsampling=2 --num_unshared=2 --start=0 --max_items_A=1 --A='B' --B='A'
# monet2photo
python test.py --dataroot=./datasets/monet2photo/ --name=monet2photo_ost --load_dir=monet2photo_ost --model=ost --no_dropout --n_downsampling=2 --num_unshared=0 --start=0 --max_items_A=1
# photo2monet
python test.py --dataroot=./datasets/monet2photo/ --name=monet2photo_ost_reverse --load_dir=monet2photo_ost_reverse --model=ost --no_dropout --n_downsampling=2 --num_unshared=0 --start=0 --max_items_A=1 --A='B' --B='A'
# summer2winter
python test.py --dataroot=./datasets/summer2winter_yosemite/ --name=summer2winter_yosemite_ost --load_dir=summer2winter_yosemite_ost --model=ost --no_dropout --n_downsampling=2 --num_unshared=0 --start=0 --max_items_A=1
# winter2summer
python test.py --dataroot=./datasets/summer2winter_yosemite/ --name=summer2winter_yosemite_ost_reverse --load_dir=summer2winter_yosemite_ost_reverse --model=ost --no_dropout --n_downsampling=2 --num_unshared=0 --start=0 --max_items_A=1 --A='B' --B='A'
================================================
FILE: drawing_and_style_transfer/scripts/train_autoencoder.sh
================================================
# images to cityscapes
python train.py --dataroot=./datasets/cityscapes/trainB --name=cityscapes_autoencoder --model=autoencoder --dataset_mode=single --no_dropout --n_downsampling=3 --num_unshared=3
# cityscapes to images
python train.py --dataroot=./datasets/cityscapes/trainA --name=cityscapes_autoencoder_reverse --model=autoencoder --dataset_mode=single --no_dropout --n_downsampling=2 --num_unshared=2
# images to facades
python train.py --dataroot=./datasets/facades/trainB --name=facades_autoencoder --model=autoencoder --dataset_mode=single --no_dropout --n_downsampling=2 --num_unshared=2
# facades to images
python train.py --dataroot=./datasets/facades/trainA --name=facades_autoencoder_reverse --model=autoencoder --dataset_mode=single --no_dropout --n_downsampling=2 --num_unshared=2
# aerial view to maps
python train.py --dataroot=./datasets/maps/trainB --name=maps_autoencoder --model=autoencoder --dataset_mode=single --no_dropout --n_downsampling=2 --num_unshared=2
# maps to aerial view
python train.py --dataroot=./datasets/maps/trainA --name=maps_autoencoder_reverse --model=autoencoder --dataset_mode=single --no_dropout --n_downsampling=2 --num_unshared=2
# monet2photo
python train.py --dataroot=./datasets/monet2photo/trainB --name=monet2photo_autoencoder --model=autoencoder --dataset_mode=single --no_dropout --n_downsampling=2 --num_unshared=0
# photo2monet
python train.py --dataroot=./datasets/monet2photo/trainA --name=monet2photo_autoencoder_reverse --model=autoencoder --dataset_mode=single --no_dropout --n_downsampling=2 --num_unshared=0
# summer2winter
python train.py --dataroot=./datasets/summer2winter_yosemite/trainB --name=summer2winter_yosemite_autoencoder --model=autoencoder --dataset_mode=single --no_dropout --n_downsampling=2 --num_unshared=0
# winter2summer
python train.py --dataroot=./datasets/summer2winter_yosemite/trainA --name=summer2winter_yosemite_autoencoder_reverse --model=autoencoder --dataset_mode=single --no_dropout --n_downsampling=2 --num_unshared=0
================================================
FILE: drawing_and_style_transfer/scripts/train_ost.sh
================================================
# images to cityscapes
python train.py --dataroot=./datasets/cityscapes/ --name=cityscapes_ost --load_dir=cityscapes_autoencoder --model=ost --no_dropout --n_downsampling=3 --num_unshared=3 --start=0 --max_items_A=1
# cityscapes to images
python train.py --dataroot=./datasets/cityscapes/ --name=cityscapes_ost_reverse --load_dir=cityscapes_autoencoder_reverse --model=ost --no_dropout --n_downsampling=2 --num_unshared=2 --start=0 --max_items_A=1 --A='B' --B='A'
# images to facades
python train.py --dataroot=./datasets/facades/ --name=facades_ost --load_dir=facades_autoencoder --model=ost --no_dropout --n_downsampling=2 --num_unshared=2 --start=0 --max_items_A=1
# facades to images
python train.py --dataroot=./datasets/facades/ --name=facades_ost_reverse --load_dir=facades_autoencoder_reverse --model=ost --no_dropout --n_downsampling=2 --num_unshared=2 --start=0 --max_items_A=1 --A='B' --B='A'
# aerial view to maps
python train.py --dataroot=./datasets/maps/ --name=maps_ost --load_dir=maps_autoencoder --model=ost --no_dropout --n_downsampling=2 --num_unshared=2 --start=0 --max_items_A=1
# maps to aerial view
python train.py --dataroot=./datasets/maps/ --name=maps_ost_reverse --load_dir=maps_autoencoder_reverse --model=ost --no_dropout --n_downsampling=2 --num_unshared=2 --start=0 --max_items_A=1 --A='B' --B='A'
# monet2photo
python train.py --dataroot=./datasets/monet2photo/ --name=monet2photo_ost --load_dir=monet2photo_autoencoder --model=ost --no_dropout --n_downsampling=2 --num_unshared=0 --start=0 --max_items_A=1
# photo2monet
python train.py --dataroot=./datasets/monet2photo/ --name=monet2photo_ost_reverse --load_dir=monet2photo_autoencoder_reverse --model=ost --no_dropout --n_downsampling=2 --num_unshared=0 --start=0 --max_items_A=1 --A='B' --B='A'
# summer2winter
python train.py --dataroot=./datasets/summer2winter_yosemite/ --name=summer2winter_yosemite_ost --load_dir=summer2winter_yosemite_autoencoder --model=ost --no_dropout --n_downsampling=2 --num_unshared=0 --start=0 --max_items_A=1
# winter2summer
python train.py --dataroot=./datasets/summer2winter_yosemite/ --name=summer2winter_yosemite_ost_reverse --load_dir=summer2winter_yosemite_autoencoder_reverse --model=ost --no_dropout --n_downsampling=2 --num_unshared=0 --start=0 --max_items_A=1 --A='B' --B='A'
================================================
FILE: drawing_and_style_transfer/test.py
================================================
import os
from options.test_options import TestOptions
from data import CreateDataLoader
from models import create_model
from util.visualizer import Visualizer
from util import html
if __name__ == '__main__':
opt = TestOptions().parse()
opt.nThreads = 1 # test code only supports nThreads = 1
opt.batchSize = 1 # test code only supports batchSize = 1
opt.serial_batches = True # no shuffle
opt.no_flip_and_rotation = True # no flip/rotation augmentation (the option is named no_flip_and_rotation)
# We are interested in testing only on the x's in A that we trained on
opt.phase = 'train'
if opt.max_items_A >= 0:
opt.max_items_B = opt.max_items_A
if opt.max_items_B >= 0:
opt.max_items_A = opt.max_items_B
data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
model = create_model(opt)
visualizer = Visualizer(opt)
# create website
web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
# test
for i, data in enumerate(dataset):
if i >= opt.how_many:
break
model.set_input(data)
model.test()
visuals = model.get_current_visuals()
img_path = model.get_image_paths()
print('%04d: process image... %s' % (i, img_path))
visualizer.save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, index=i, split=0)
webpage.save()
================================================
FILE: drawing_and_style_transfer/train.py
================================================
import time
from options.train_options import TrainOptions
from data import CreateDataLoader
from models import create_model
from util.visualizer import Visualizer
if __name__ == '__main__':
opt = TrainOptions().parse()
data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
dataset_size = len(data_loader)
print('#training images = %d' % dataset_size)
model = create_model(opt)
visualizer = Visualizer(opt)
total_steps = 0
for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
epoch_start_time = time.time()
iter_data_time = time.time()
epoch_iter = 0
for i, data in enumerate(dataset):
iter_start_time = time.time()
if total_steps % opt.print_freq == 0:
t_data = iter_start_time - iter_data_time
visualizer.reset()
total_steps += opt.batchSize
epoch_iter += opt.batchSize
model.set_input(data)
model.optimize_parameters()
if total_steps % opt.display_freq == 0:
save_result = total_steps % opt.update_html_freq == 0
visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)
if total_steps % opt.print_freq == 0:
errors = model.get_current_errors()
t = (time.time() - iter_start_time) / opt.batchSize
visualizer.print_current_errors(epoch, epoch_iter, errors, t, t_data)
if opt.display_id > 0:
visualizer.plot_current_errors(epoch, float(epoch_iter) / dataset_size, opt, errors)
if total_steps % opt.save_latest_freq == 0:
print('saving the latest model (epoch %d, total_steps %d)' %
(epoch, total_steps))
model.save('latest')
iter_data_time = time.time()
if epoch % opt.save_epoch_freq == 0:
print('saving the model at the end of epoch %d, iters %d' %
(epoch, total_steps))
model.save('latest')
model.save(epoch)
print('End of epoch %d / %d \t Time Taken: %d sec' %
(epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
model.update_learning_rate()
================================================
FILE: drawing_and_style_transfer/util/__init__.py
================================================
================================================
FILE: drawing_and_style_transfer/util/get_data.py
================================================
from __future__ import print_function
import os
import tarfile
import requests
from warnings import warn
from zipfile import ZipFile
from bs4 import BeautifulSoup
from os.path import abspath, isdir, join, basename
class GetData(object):
"""
Download CycleGAN or Pix2Pix Data.
Args:
technique : str
One of: 'cyclegan' or 'pix2pix'.
verbose : bool
If True, print additional information.
Examples:
>>> from util.get_data import GetData
>>> gd = GetData(technique='cyclegan')
>>> new_data_path = gd.get(save_path='./datasets') # options will be displayed.
"""
def __init__(self, technique='cyclegan', verbose=True):
url_dict = {
'pix2pix': 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets',
'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets'
}
self.url = url_dict.get(technique.lower())
self._verbose = verbose
def _print(self, text):
if self._verbose:
print(text)
@staticmethod
def _get_options(r):
soup = BeautifulSoup(r.text, 'lxml')
options = [h.text for h in soup.find_all('a', href=True)
if h.text.endswith(('.zip', '.tar.gz'))]
return options
def _present_options(self):
r = requests.get(self.url)
options = self._get_options(r)
print('Options:\n')
for i, o in enumerate(options):
print("{0}: {1}".format(i, o))
choice = input("\nPlease enter the number of the "
"dataset above you wish to download:")
return options[int(choice)]
def _download_data(self, dataset_url, save_path):
if not isdir(save_path):
os.makedirs(save_path)
base = basename(dataset_url)
temp_save_path = join(save_path, base)
with open(temp_save_path, "wb") as f:
r = requests.get(dataset_url)
f.write(r.content)
if base.endswith('.tar.gz'):
obj = tarfile.open(temp_save_path)
elif base.endswith('.zip'):
obj = ZipFile(temp_save_path, 'r')
else:
raise ValueError("Unknown File Type: {0}.".format(base))
self._print("Unpacking Data...")
obj.extractall(save_path)
obj.close()
os.remove(temp_save_path)
def get(self, save_path, dataset=None):
"""
Download a dataset.
Args:
save_path : str
A directory to save the data to.
dataset : str, optional
A specific dataset to download.
Note: this must include the file extension.
If None, options will be presented for you
to choose from.
Returns:
save_path_full : str
The absolute path to the downloaded data.
"""
if dataset is None:
selected_dataset = self._present_options()
else:
selected_dataset = dataset
save_path_full = join(save_path, selected_dataset.split('.')[0])
if isdir(save_path_full):
warn("\n'{0}' already exists. Skipping download.".format(
save_path_full))
else:
self._print('Downloading Data...')
url = "{0}/{1}".format(self.url, selected_dataset)
self._download_data(url, save_path=save_path)
return abspath(save_path_full)
================================================
FILE: drawing_and_style_transfer/util/html.py
================================================
import dominate
from dominate.tags import *
import os
class HTML:
def __init__(self, web_dir, title, reflesh=0):
self.title = title
self.web_dir = web_dir
self.img_dir = os.path.join(self.web_dir, 'images')
if not os.path.exists(self.web_dir):
os.makedirs(self.web_dir)
if not os.path.exists(self.img_dir):
os.makedirs(self.img_dir)
# print(self.img_dir)
self.doc = dominate.document(title=title)
if reflesh > 0:
with self.doc.head:
meta(http_equiv="refresh", content=str(reflesh))
def get_image_dir(self):
return self.img_dir
def add_header(self, str):
with self.doc:
h3(str)
def add_table(self, border=1):
self.t = table(border=border, style="table-layout: fixed;")
self.doc.add(self.t)
def add_images(self, ims, txts, links, width=400):
self.add_table()
with self.t:
with tr():
for im, txt, link in zip(ims, txts, links):
with td(style="word-wrap: break-word;", halign="center", valign="top"):
with p():
with a(href=os.path.join('images', link)):
img(style="width:%dpx" % width, src=os.path.join('images', im))
br()
p(txt)
def save(self):
html_file = '%s/index.html' % self.web_dir
with open(html_file, 'wt') as f:
f.write(self.doc.render())
if __name__ == '__main__':
html = HTML('web/', 'test_html')
html.add_header('hello world')
ims = []
txts = []
links = []
for n in range(4):
ims.append('image_%d.png' % n)
txts.append('text_%d' % n)
links.append('image_%d.png' % n)
html.add_images(ims, txts, links)
html.save()
================================================
FILE: drawing_and_style_transfer/util/image_pool.py
================================================
import random
import torch
from torch.autograd import Variable
class ImagePool():
def __init__(self, pool_size):
self.pool_size = pool_size
if self.pool_size > 0:
self.num_imgs = 0
self.images = []
def query(self, images):
if self.pool_size == 0:
return Variable(images)
return_images = []
for image in images:
image = torch.unsqueeze(image, 0)
if self.num_imgs < self.pool_size:
self.num_imgs = self.num_imgs + 1
self.images.append(image)
return_images.append(image)
else:
p = random.uniform(0, 1)
if p > 0.5:
random_id = random.randint(0, self.pool_size - 1)
tmp = self.images[random_id].clone()
self.images[random_id] = image
return_images.append(tmp)
else:
return_images.append(image)
return_images = Variable(torch.cat(return_images, 0))
return return_images
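The buffering rule in `ImagePool.query` can be illustrated without tensors. The sketch below is an assumption-laden stand-in, not the repository's class: `ListPool` and its `rng` parameter are hypothetical names, and plain Python objects replace image tensors.

```python
import random


class ListPool(object):
    """Hypothetical, tensor-free stand-in for ImagePool (illustration only)."""

    def __init__(self, pool_size, rng=None):
        self.pool_size = pool_size
        self.items = []
        # An injectable random source makes the behaviour reproducible in tests.
        self.rng = rng or random.Random()

    def query(self, batch):
        if self.pool_size == 0:
            return list(batch)  # pooling disabled: pass the batch through
        out = []
        for item in batch:
            if len(self.items) < self.pool_size:
                # Pool not full yet: store the item and return it unchanged.
                self.items.append(item)
                out.append(item)
            elif self.rng.uniform(0, 1) > 0.5:
                # Swap: return a stored item and keep the new one in its place.
                idx = self.rng.randint(0, self.pool_size - 1)
                out.append(self.items[idx])
                self.items[idx] = item
            else:
                # Otherwise return the new item and leave the pool untouched.
                out.append(item)
        return out
```

Once the pool is full, each incoming item has roughly an even chance of being returned directly or being exchanged for an older stored one, so the consumer sees a mix of recent and earlier items.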
================================================
FILE: drawing_and_style_transfer/util/util.py
================================================
from __future__ import print_function
import torch
import numpy as np
from PIL import Image
import os
# Converts a Tensor into a Numpy array
# |imtype|: the desired type of the converted numpy array
def tensor2im(image_tensor, imtype=np.uint8):
image_numpy = image_tensor[0].cpu().float().numpy()
if image_numpy.shape[0] == 1:
image_numpy = np.tile(image_numpy, (3, 1, 1))
image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
return image_numpy.astype(imtype)
def diagnose_network(net, name='network'):
mean = 0.0
count = 0
for param in net.parameters():
if param.grad is not None:
mean += torch.mean(torch.abs(param.grad.data))
count += 1
if count > 0:
mean = mean / count
print(name)
print(mean)
def save_image(image_numpy, image_path):
image_pil = Image.fromarray(image_numpy)
image_pil.save(image_path)
def print_numpy(x, val=True, shp=False):
x = x.astype(np.float64)
if shp:
print('shape,', x.shape)
if val:
x = x.flatten()
print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
def mkdirs(paths):
if isinstance(paths, list) and not isinstance(paths, str):
for path in paths:
mkdir(path)
else:
mkdir(paths)
def mkdir(path):
if not os.path.exists(path):
os.makedirs(path)
================================================
FILE: drawing_and_style_transfer/util/visualizer.py
================================================
import numpy as np
import os
import ntpath
import time
from . import util
from . import html
from scipy.misc import imresize
class Visualizer():
def __init__(self, opt):
# self.opt = opt
self.display_id = opt.display_id
self.use_html = opt.isTrain and not opt.no_html
self.win_size = opt.display_winsize
self.name = opt.name
self.opt = opt
self.saved = False
if self.display_id > 0:
import visdom
self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port)
if self.use_html:
self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
self.img_dir = os.path.join(self.web_dir, 'images')
print('create web directory %s...' % self.web_dir)
util.mkdirs([self.web_dir, self.img_dir])
self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
with open(self.log_name, "a") as log_file:
now = time.strftime("%c")
log_file.write('================ Training Loss (%s) ================\n' % now)
def reset(self):
self.saved = False
# |visuals|: dictionary of images to display or save
def display_current_results(self, visuals, epoch, save_result):
if self.display_id > 0: # show images in the browser
ncols = self.opt.display_single_pane_ncols
if ncols > 0:
h, w = next(iter(visuals.values())).shape[:2]
table_css = """<style>
table {border-collapse: separate; border-spacing:4px; white-space:nowrap; text-align:center}
table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
</style>""" % (w, h)
title = self.name
label_html = ''
label_html_row = ''
nrows = int(np.ceil(len(visuals.items()) / ncols))
images = []
idx = 0
for label, image_numpy in visuals.items():
label_html_row += '<td>%s</td>' % label
images.append(image_numpy.transpose([2, 0, 1]))
idx += 1
if idx % ncols == 0:
label_html += '<tr>%s</tr>' % label_html_row
label_html_row = ''
white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
while idx % ncols != 0:
images.append(white_image)
label_html_row += '<td></td>'
idx += 1
if label_html_row != '':
label_html += '<tr>%s</tr>' % label_html_row
# pane col = image row
self.vis.images(images, nrow=ncols, win=self.display_id + 1,
padding=2, opts=dict(title=title + ' images'))
label_html = '<table>%s</table>' % label_html
self.vis.text(table_css + label_html, win=self.display_id + 2,
opts=dict(title=title + ' labels'))
else:
idx = 1
for label, image_numpy in visuals.items():
self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label),
win=self.display_id + idx)
idx += 1
if self.use_html and (save_result or not self.saved): # save images to a html file
self.saved = True
for label, image_numpy in visuals.items():
img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
util.save_image(image_numpy, img_path)
# update website
webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, reflesh=1)
for n in range(epoch, 0, -1):
webpage.add_header('epoch [%d]' % n)
ims = []
txts = []
links = []
for label, image_numpy in visuals.items():
img_path = 'epoch%.3d_%s.png' % (n, label)
ims.append(img_path)
txts.append(label)
links.append(img_path)
webpage.add_images(ims, txts, links, width=self.win_size)
webpage.save()
# errors: dictionary of error labels and values
def plot_current_errors(self, epoch, counter_ratio, opt, errors):
if not hasattr(self, 'plot_data'):
self.plot_data = {'X': [], 'Y': [], 'legend': list(errors.keys())}
self.plot_data['X'].append(epoch + counter_ratio)
self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']])
self.vis.line(
X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
Y=np.array(self.plot_data['Y']),
opts={
'title': self.name + ' loss over time',
'legend': self.plot_data['legend'],
'xlabel': 'epoch',
'ylabel': 'loss'},
win=self.display_id)
# errors: same format as |errors| of plot_current_errors
def print_current_errors(self, epoch, i, errors, t, t_data):
message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, i, t, t_data)
for k, v in errors.items():
message += '%s: %.3f ' % (k, v)
print(message)
with open(self.log_name, "a") as log_file:
log_file.write('%s\n' % message)
# save image to the disk
def save_images(self, webpage, visuals, image_path, aspect_ratio=1.0, index=None, split=1):
image_dir = webpage.get_image_dir()
short_path = ntpath.basename(image_path[0])
name = os.path.splitext(short_path)[0]
if index is not None:
name_splits = name.split("_")
if split == 0:
name = str(index)
else:
name = str(index) + "_" + name_splits[split]
webpage.add_header(name)
ims = []
txts = []
links = []
for label, im in visuals.items():
image_name = '%s_%s.png' % (name, label)
save_path = os.path.join(image_dir, image_name)
h, w, _ = im.shape
if aspect_ratio > 1.0:
im = imresize(im, (h, int(w * aspect_ratio)), interp='bicubic')
if aspect_ratio < 1.0:
im = imresize(im, (int(h / aspect_ratio), w), interp='bicubic')
util.save_image(im, save_path)
ims.append(image_name)
txts.append(label)
links.append(image_name)
webpage.add_images(ims, txts, links, width=self.win_size)
================================================
FILE: mnist_to_svhn/data_loader.py
================================================
import torch
from torchvision import datasets
from torchvision import transforms
def get_loader(config):
"""Builds and returns Dataloader for MNIST and SVHN dataset."""
transform_list = []
if config.use_augmentation:
transform_list.append(transforms.RandomHorizontalFlip())
transform_list.append(transforms.RandomRotation(0.1))
transform_list.append(transforms.Scale(config.image_size))
transform_list.append(transforms.ToTensor())
transform_list.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
transform_test = transforms.Compose([
transforms.Scale(config.image_size),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
transform_train = transforms.Compose(transform_list)
svhn = datasets.SVHN(root=config.svhn_path, download=True, transform=transform_train, split='train')
mnist = datasets.MNIST(root=config.mnist_path, download=True, transform=transform_train, train=True)
svhn_test = datasets.SVHN(root=config.svhn_path, download=True, transform=transform_test, split='test')
mnist_test = datasets.MNIST(root=config.mnist_path, download=True, transform=transform_test, train=False)
svhn_loader = torch.utils.data.DataLoader(dataset=svhn,
batch_size=config.batch_size,
shuffle=config.shuffle,
num_workers=config.num_workers)
mnist_loader = torch.utils.data.DataLoader(dataset=mnist,
batch_size=config.batch_size,
shuffle=config.shuffle,
num_workers=config.num_workers)
svhn_test_loader = torch.utils.data.DataLoader(dataset=svhn_test,
batch_size=config.batch_size,
shuffle=False,
num_workers=config.num_workers)
mnist_test_loader = torch.utils.data.DataLoader(dataset=mnist_test,
batch_size=config.batch_size,
shuffle=False,
num_workers=config.num_workers)
return svhn_loader, mnist_loader, svhn_test_loader, mnist_test_loader
================================================
FILE: mnist_to_svhn/download.sh
================================================
mkdir -p mnist
mkdir -p svhn
wget -O svhn/train_32x32.mat http://ufldl.stanford.edu/housenumbers/train_32x32.mat
wget -O svhn/test_32x32.mat http://ufldl.stanford.edu/housenumbers/test_32x32.mat
wget -O svhn/extra_32x32.mat http://ufldl.stanford.edu/housenumbers/extra_32x32.mat
================================================
FILE: mnist_to_svhn/main_autoencoder.py
================================================
import argparse
import os
from torch.backends import cudnn
from solver_autoencoder import Solver
from data_loader import get_loader
def str2bool(v):
return v.lower() in ('true',)
def main(config):
svhn_loader, mnist_loader, svhn_test_loader, mnist_test_loader = get_loader(config)
solver = Solver(config, svhn_loader, mnist_loader)
cudnn.benchmark = True
# create directories if they do not exist
if not os.path.exists(config.model_path):
os.makedirs(config.model_path)
if not os.path.exists(config.sample_path):
os.makedirs(config.sample_path)
if config.mode == 'train':
solver.train()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
# model hyper-parameters
parser.add_argument('--image_size', type=int, default=32)
parser.add_argument('--g_conv_dim', type=int, default=64)
parser.add_argument('--d_conv_dim', type=int, default=64)
parser.add_argument('--num_classes', type=int, default=10)
# training hyper-parameters
parser.add_argument('--train_iters', type=int, default=15000)
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--num_workers', type=int, default=2)
parser.add_argument('--lr', type=float, default=0.0002)
parser.add_argument('--beta1', type=float, default=0.5)
parser.add_argument('--beta2', type=float, default=0.999)
parser.add_argument('--kl_lambda', type=float, default=0.1)
# misc
parser.add_argument('--mode', type=str, default='train')
parser.add_argument('--model_path', type=str, default='./models_autoencoder')
parser.add_argument('--sample_path', type=str, default='./samples_autoencoder')
parser.add_argument('--mnist_path', type=str, default='./mnist')
parser.add_argument('--svhn_path', type=str, default='./svhn')
parser.add_argument('--log_step', type=int, default=10)
parser.add_argument('--sample_step', type=int, default=500)
parser.add_argument('--shuffle', type=str2bool, default=True)
parser.add_argument('--use_augmentation', required=True, type=str2bool)
config = parser.parse_args()
print(config)
main(config)
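The entry-point scripts read their boolean flags through argparse. The sketch below is a standalone illustration of two pitfalls in that pattern: `type=bool` treats every non-empty string as true, and `in ('true')` is a substring test because the parentheses do not make a tuple. The `build_parser` helper and the `--naive_shuffle` flag are hypothetical and exist only for this demonstration; the stricter `str2bool` shown here is one possible variant, not the repository's function.

```python
import argparse


def str2bool(value):
    """Parse 'true'/'false' case-insensitively and reject anything else."""
    lowered = value.lower()
    if lowered == 'true':
        return True
    if lowered == 'false':
        return False
    raise argparse.ArgumentTypeError('expected true or false, got %r' % value)


def build_parser():
    # Hypothetical two-flag parser, not the repository's full option set.
    parser = argparse.ArgumentParser()
    parser.add_argument('--shuffle', type=str2bool, default=True)
    # type=bool calls bool('false'), and any non-empty string is truthy.
    parser.add_argument('--naive_shuffle', type=bool, default=True)
    return parser


if __name__ == '__main__':
    args = build_parser().parse_args(['--shuffle', 'false', '--naive_shuffle', 'false'])
    print(args.shuffle, args.naive_shuffle)
```

Raising `argparse.ArgumentTypeError` makes argparse report a usage error for a misspelled value instead of silently mapping it to `False`.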
================================================
FILE: mnist_to_svhn/main_mnist_to_svhn.py
================================================
import argparse
import logging
import os
from torch.backends import cudnn
from data_loader import get_loader
from solver_mnist_to_svhn import Solver
def str2bool(v):
return v.lower() in ('true',)
def main(config):
svhn_loader, mnist_loader, svhn_test_loader, mnist_test_loader = get_loader(config)
solver = Solver(config, svhn_loader, mnist_loader)
cudnn.benchmark = True
# create directories if they do not exist
if not os.path.exists(config.model_path):
os.makedirs(config.model_path)
if not os.path.exists(config.sample_path):
os.makedirs(config.sample_path)
base = config.log_path
filename = os.path.join(base, str(config.max_items))
if not os.path.isdir(base):
os.mkdir(base)
logging.basicConfig(filename=filename, level=logging.DEBUG)
if config.mode == 'train':
solver.train()
elif config.mode == 'sample':
solver.sample()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
# model hyper-parameters
parser.add_argument('--image_size', type=int, default=32)
parser.add_argument('--g_conv_dim', type=int, default=64)
parser.add_argument('--d_conv_dim', type=int, default=64)
parser.add_argument('--num_classes', type=int, default=10)
# training hyper-parameters
parser.add_argument('--train_iters', type=int, default=40000)
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('--num_workers', type=int, default=2)
parser.add_argument('--lr', type=float, default=0.0002)
parser.add_argument('--beta1', type=float, default=0.5)
parser.add_argument('--beta2', type=float, default=0.999)
parser.add_argument('--kl_lambda', type=float, default=0.1)
# misc
parser.add_argument('--mode', type=str, default='train')
parser.add_argument('--mnist_path', type=str, default='./mnist')
parser.add_argument('--svhn_path', type=str, default='./svhn')
parser.add_argument('--log_step', type=int, default=10)
parser.add_argument('--shuffle', type=str2bool, default=True)
parser.add_argument('--load_iter', type=int, default=10000)
parser.add_argument('--sample_step', type=int, default=500)
parser.add_argument('--num_averaging_runs', type=int, default=1000)
parser.add_argument('--num_iters_save_model_and_return', type=int, default=5000)
parser.add_argument('--num_d_iterations', type=int, default=1)
parser.add_argument('--num_g_iterations', type=int, default=1)
parser.add_argument('--model_path', type=str, default='./models_ost')
parser.add_argument('--sample_path', type=str, default='./samples_ost')
parser.add_argument('--load_path', type=str, default='./models_autoencoder')
parser.add_argument('--log_path', type=str, default='logs_ost')
parser.add_argument('--pretrained_g', required=True, type=str2bool)
parser.add_argument('--save_models_and_samples', required=True, type=str2bool)
parser.add_argument('--use_augmentation', required=True, type=str2bool)
parser.add_argument('--one_way_cycle', required=True, type=str2bool)
parser.add_argument('--freeze_shared', required=True, type=str2bool)
parser.add_argument('--max_items', type=int, default=1)
config = parser.parse_args()
print(config)
main(config)
================================================
FILE: mnist_to_svhn/main_svhn_to_mnist.py
================================================
import argparse
import logging
import os
from data_loader import get_loader
from solver_svhn_to_mnist import Solver
from torch.backends import cudnn
def str2bool(v):
return v.lower() in ('true',)
def main(config):
svhn_loader, mnist_loader, svhn_test_loader, mnist_test_loader = get_loader(config)
solver = Solver(config, svhn_loader, mnist_loader)
cudnn.benchmark = True
# create directories if they do not exist
if not os.path.exists(config.model_path):
os.makedirs(config.model_path)
if not os.path.exists(config.sample_path):
os.makedirs(config.sample_path)
base = config.log_path
filename = os.path.join(base, str(config.max_items))
if not os.path.isdir(base):
os.mkdir(base)
logging.basicConfig(filename=filename, level=logging.DEBUG)
if config.mode == 'train':
solver.train()
elif config.mode == 'sample':
solver.sample()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
# model hyper-parameters
parser.add_argument('--image_size', type=int, default=32)
parser.add_argument('--g_conv_dim', type=int, default=64)
parser.add_argument('--d_conv_dim', type=int, default=64)
parser.add_argument('--num_classes', type=int, default=10)
# training hyper-parameters
parser.add_argument('--train_iters', type=int, default=40000)
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('--num_workers', type=int, default=2)
parser.add_argument('--lr', type=float, default=0.0002)
parser.add_argument('--beta1', type=float, default=0.5)
parser.add_argument('--beta2', type=float, default=0.999)
parser.add_argument('--kl_lambda', type=float, default=0.1)
# misc
parser.add_argument('--mode', type=str, default='train')
parser.add_argument('--mnist_path', type=str, default='./mnist')
parser.add_argument('--svhn_path', type=str, default='./svhn')
parser.add_argument('--log_step', type=int, default=10)
parser.add_argument('--shuffle', type=str2bool, default=True)
parser.add_argument('--load_iter', type=int, default=10000)
parser.add_argument('--sample_step', type=int, default=500)
parser.add_argument('--num_averaging_runs', type=int, default=1000)
parser.add_argument('--num_iters_save_model_and_return', type=int, default=5000)
parser.add_argument('--num_d_iterations', type=int, default=1)
parser.add_argument('--num_g_iterations', type=int, default=1)
parser.add_argument('--model_path', type=str, default='./models_ost')
parser.add_argument('--sample_path', type=str, default='./samples_ost')
parser.add_argument('--load_path', type=str, default='./models_autoencoder')
parser.add_argument('--log_path', type=str, default='logs_ost')
parser.add_argument('--pretrained_g', required=True, type=str2bool)
parser.add_argument('--save_models_and_samples', required=True, type=str2bool)
parser.add_argument('--use_augmentation', required=True, type=str2bool)
parser.add_argument('--one_way_cycle', required=True, type=str2bool)
parser.add_argument('--freeze_shared', required=True, type=str2bool)
parser.add_argument('--max_items', type=int, default=1)
config = parser.parse_args()
print(config)
main(config)
================================================
FILE: mnist_to_svhn/model.py
================================================
import torch.nn as nn
import torch.nn.functional as F
def deconv(c_in, c_out, k_size, stride=2, pad=1, bn=True):
"""Custom deconvolutional layer for simplicity."""
layers = []
layers.append(nn.ConvTranspose2d(c_in, c_out, k_size, stride, pad, bias=False))
if bn:
layers.append(nn.BatchNorm2d(c_out))
return nn.Sequential(*layers)
def conv(c_in, c_out, k_size, stride=2, pad=1, bn=True):
"""Custom convolutional layer for simplicity."""
layers = []
layers.append(nn.Conv2d(c_in, c_out, k_size, stride, pad, bias=False))
if bn:
layers.append(nn.BatchNorm2d(c_out))
return nn.Sequential(*layers)
class G11(nn.Module):
def __init__(self, conv_dim=64):
super(G11, self).__init__()
# encoding blocks
self.conv1 = conv(1, conv_dim, 4)
self.conv1_svhn = conv(3, conv_dim, 4)
self.conv2 = conv(conv_dim, conv_dim * 2, 4)
# residual blocks
res_dim = conv_dim * 2
self.conv3 = conv(res_dim, res_dim, 3, 1, 1)
self.conv4 = conv(res_dim, res_dim, 3, 1, 1)
# decoding blocks
self.deconv1 = deconv(conv_dim * 2, conv_dim, 4)
self.deconv2 = deconv(conv_dim, 1, 4, bn=False)
self.deconv2_svhn = deconv(conv_dim, 3, 4, bn=False)
def forward(self, x, svhn=False):
if svhn:
out = F.leaky_relu(self.conv1_svhn(x), 0.05) # (?, 64, 16, 16)
else:
out = F.leaky_relu(self.conv1(x), 0.05) # (?, 64, 16, 16)
out = F.leaky_relu(self.conv2(out), 0.05) # (?, 128, 8, 8)
out = F.leaky_relu(self.conv3(out), 0.05) # ( " )
out = F.leaky_relu(self.conv4(out), 0.05) # ( " )
out = F.leaky_relu(self.deconv1(out), 0.05) # (?, 64, 16, 16)
if svhn:
out = F.tanh(self.deconv2_svhn(out)) # (?, 3, 32, 32)
else:
out = F.tanh(self.deconv2(out)) # (?, 1, 32, 32)
return out
def encode(self, x, svhn=False):
if svhn:
out = F.leaky_relu(self.conv1_svhn(x), 0.05) # (?, 64, 16, 16)
else:
out = F.leaky_relu(self.conv1(x), 0.05) # (?, 64, 16, 16)
out = F.leaky_relu(self.conv2(out), 0.05) # (?, 128, 8, 8)
out = F.leaky_relu(self.conv3(out), 0.05) # ( " )
return out
def decode(self, out, svhn=False):
out = F.leaky_relu(self.conv4(out), 0.05)
out = F.leaky_relu(self.deconv1(out), 0.05) # (?, 64, 16, 16)
if svhn:
out = F.tanh(self.deconv2_svhn(out)) # (?, 3, 32, 32)
else:
out = F.tanh(self.deconv2(out)) # (?, 1, 32, 32)
return out
def encode_params(self):
layers_basic = list(self.conv1_svhn.parameters()) + \
list(self.conv1.parameters())
layers_basic += list(self.conv2.parameters())
layers_basic += list(self.conv3.parameters())
return layers_basic
def decode_params(self):
layers_basic = list(self.deconv2_svhn.parameters()) + \
list(self.deconv2.parameters())
layers_basic += list(self.deconv1.parameters())
layers_basic += list(self.conv4.parameters())
return layers_basic
def unshared_parameters(self):
return list(self.deconv2_svhn.parameters()) + list(self.conv1_svhn.parameters()) + \
list(self.deconv2.parameters()) + list(self.conv1.parameters())
class G22(nn.Module):
def __init__(self, conv_dim=64):
super(G22, self).__init__()
# encoding blocks
self.conv1 = conv(3, conv_dim, 4)
self.conv1_mnist = conv(1, conv_dim, 4)
self.conv2 = conv(conv_dim, conv_dim * 2, 4)
# residual blocks
res_dim = conv_dim * 2
self.conv3 = conv(res_dim, res_dim, 3, 1, 1)
self.conv4 = conv(res_dim, res_dim, 3, 1, 1)
# decoding blocks
self.deconv1 = deconv(conv_dim * 2, conv_dim, 4)
self.deconv2 = deconv(conv_dim, 3, 4, bn=False)
self.deconv2_mnist = deconv(conv_dim, 1, 4, bn=False)
def forward(self, x, mnist=False):
if mnist:
out = F.leaky_relu(self.conv1_mnist(x), 0.05) # (?, 64, 16, 16)
else:
out = F.leaky_relu(self.conv1(x), 0.05) # (?, 64, 16, 16)
out = F.leaky_relu(self.conv2(out), 0.05) # (?, 128, 8, 8)
out = F.leaky_relu(self.conv3(out), 0.05) # ( " )
out = F.leaky_relu(self.conv4(out), 0.05) # ( " )
out = F.leaky_relu(self.deconv1(out), 0.05) # (?, 64, 16, 16)
if mnist:
out = F.tanh(self.deconv2_mnist(out)) # (?, 1, 32, 32)
else:
out = F.tanh(self.deconv2(out)) # (?, 3, 32, 32)
return out
def encode(self, x, mnist=False):
if mnist:
out = F.leaky_relu(self.conv1_mnist(x), 0.05) # (?, 64, 16, 16)
else:
out = F.leaky_relu(self.conv1(x), 0.05) # (?, 64, 16, 16)
out = F.leaky_relu(self.conv2(out), 0.05) # (?, 128, 8, 8)
out = F.leaky_relu(self.conv3(out), 0.05) # ( " )
return out
def decode(self, out, mnist=False):
out = F.leaky_relu(self.conv4(out), 0.05)
out = F.leaky_relu(self.deconv1(out), 0.05) # (?, 64, 16, 16)
if mnist:
out = F.tanh(self.deconv2_mnist(out)) # (?, 1, 32, 32)
else:
out = F.tanh(self.deconv2(out)) # (?, 3, 32, 32)
return out
def encode_params(self):
layers_basic = list(self.conv1_mnist.parameters()) + \
list(self.conv1.parameters())
layers_basic += list(self.conv2.parameters())
layers_basic += list(self.conv3.parameters())
return layers_basic
def decode_params(self):
layers_basic = list(self.deconv2_mnist.parameters()) + \
list(self.deconv2.parameters())
layers_basic += list(self.deconv1.parameters())
layers_basic += list(self.conv4.parameters())
return layers_basic
def unshared_parameters(self):
return list(self.deconv2_mnist.parameters()) + list(self.conv1_mnist.parameters()) + \
list(self.deconv2.parameters()) + list(self.conv1.parameters())
class D1(nn.Module):
"""Discriminator for mnist."""
def __init__(self, conv_dim=64, use_labels=False):
super(D1, self).__init__()
self.conv1 = conv(1, conv_dim, 4, bn=False)
self.conv2 = conv(conv_dim, conv_dim * 2, 4)
self.conv3 = conv(conv_dim * 2, conv_dim * 4, 4)
n_out = 11 if use_labels else 1
self.fc = conv(conv_dim * 4, n_out, 4, 1, 0, False)
def forward(self, x_0):
out = F.leaky_relu(self.conv1(x_0), 0.05) # (?, 64, 16, 16)
out = F.leaky_relu(self.conv2(out), 0.05) # (?, 128, 8, 8)
out = F.leaky_relu(self.conv3(out), 0.05) # (?, 256, 4, 4)
out_0 = self.fc(out).squeeze()
return out_0
class D2(nn.Module):
"""Discriminator for svhn."""
def __init__(self, conv_dim=64, use_labels=False):
super(D2, self).__init__()
self.conv1 = conv(3, conv_dim, 4, bn=False)
self.conv2 = conv(conv_dim, conv_dim * 2, 4)
self.conv3 = conv(conv_dim * 2, conv_dim * 4, 4)
n_out = 11 if use_labels else 1
self.fc = conv(conv_dim * 4, n_out, 4, 1, 0, False)
def forward(self, x_0):
out = F.leaky_relu(self.conv1(x_0), 0.05) # (?, 64, 16, 16)
out = F.leaky_relu(self.conv2(out), 0.05) # (?, 128, 8, 8)
out = F.leaky_relu(self.conv3(out), 0.05) # (?, 256, 4, 4)
out_0 = self.fc(out).squeeze()
return out_0
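The shape comments in the generators and discriminators above follow from the standard output-size formulas for strided convolutions and transposed convolutions. The sketch below recomputes them in plain Python, assuming 32x32 inputs, dilation 1, and no `output_padding`. The helper names `conv_out`, `deconv_out`, `generator_sizes`, and `discriminator_sizes` are made up for this illustration and are not part of `model.py`.

```python
def conv_out(size, k_size, stride=2, pad=1):
    """Spatial size after a Conv2d layer with the given kernel, stride, padding."""
    return (size + 2 * pad - k_size) // stride + 1


def deconv_out(size, k_size, stride=2, pad=1):
    """Spatial size after a ConvTranspose2d layer (no output_padding, dilation 1)."""
    return (size - 1) * stride - 2 * pad + k_size


def generator_sizes(size=32):
    """Spatial size after each layer of G11/G22, starting from the input."""
    sizes = [size]
    sizes.append(conv_out(sizes[-1], 4))        # conv1 / conv1_svhn / conv1_mnist
    sizes.append(conv_out(sizes[-1], 4))        # conv2
    sizes.append(conv_out(sizes[-1], 3, 1, 1))  # conv3
    sizes.append(conv_out(sizes[-1], 3, 1, 1))  # conv4
    sizes.append(deconv_out(sizes[-1], 4))      # deconv1
    sizes.append(deconv_out(sizes[-1], 4))      # deconv2 / deconv2_svhn / deconv2_mnist
    return sizes


def discriminator_sizes(size=32):
    """Spatial size after each layer of D1/D2, starting from the input."""
    sizes = [size]
    for _ in range(3):                          # conv1, conv2, conv3
        sizes.append(conv_out(sizes[-1], 4))
    sizes.append(conv_out(sizes[-1], 4, 1, 0))  # fc: 4x4 kernel, stride 1, no padding
    return sizes


if __name__ == '__main__':
    print(generator_sizes())
    print(discriminator_sizes())
```

Under these assumptions the generator maps 32 back to 32 and the discriminator collapses 32 down to a single spatial position, which is why `forward` can call `squeeze()` on the `fc` output.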
================================================
FILE: mnist_to_svhn/solver_autoencoder.py
================================================
import os
import numpy as np
import scipy.misc  # scipy.misc.imsave is used below
import torch
from torch import optim
from torch.autograd import Variable
from model import D1, D2
from model import G11, G22
class Solver(object):
def __init__(self, config, svhn_loader, mnist_loader):
self.svhn_loader = svhn_loader
self.mnist_loader = mnist_loader
self.g11 = None
self.g22 = None
self.d1 = None
self.d2 = None
self.g_optimizer = None
self.d_optimizer = None
self.num_classes = config.num_classes
self.beta1 = config.beta1
self.beta2 = config.beta2
self.g_conv_dim = config.g_conv_dim
self.d_conv_dim = config.d_conv_dim
self.train_iters = config.train_iters
self.batch_size = config.batch_size
self.lr = config.lr
self.log_step = config.log_step
self.sample_step = config.sample_step
self.sample_path = config.sample_path
self.model_path = config.model_path
self.kl_lambda = config.kl_lambda
self.build_model()
def build_model(self):
"""Builds both generators and both discriminators."""
self.g11 = G11(conv_dim=self.g_conv_dim)
self.g22 = G22(conv_dim=self.g_conv_dim)
self.d1 = D1(conv_dim=self.d_conv_dim, use_labels=False)
self.d2 = D2(conv_dim=self.d_conv_dim, use_labels=False)
g_params = list(self.g11.parameters()) + list(self.g22.parameters())
d_params = list(self.d1.parameters()) + list(self.d2.parameters())
self.g_optimizer = optim.Adam(g_params, self.lr, [self.beta1, self.beta2])
self.d_optimizer = optim.Adam(d_params, self.lr, [self.beta1, self.beta2])
if torch.cuda.is_available():
self.g11.cuda()
self.g22.cuda()
self.d1.cuda()
self.d2.cuda()
def merge_images(self, sources, targets, k=10):
_, _, h, w = sources.shape
row = int(np.sqrt(self.batch_size))
merged = np.zeros([3, row * h, row * w * 2])
for idx, (s, t) in enumerate(zip(sources, targets)):
i = idx // row
j = idx % row
merged[:, i * h:(i + 1) * h, (j * 2) * w:(j * 2 + 1) * w] = s
merged[:, i * h:(i + 1) * h, (j * 2 + 1) * w:(j * 2 + 2) * w] = t
return merged.transpose(1, 2, 0)
def to_var(self, x):
"""Wraps a tensor in a Variable, on the GPU if available."""
if torch.cuda.is_available():
x = x.cuda()
return Variable(x)
def to_data(self, x):
"""Converts variable to numpy."""
if torch.cuda.is_available():
x = x.cpu()
return x.data.numpy()
def reset_grad(self):
"""Zeros the gradient buffers."""
self.g_optimizer.zero_grad()
self.d_optimizer.zero_grad()
def _compute_kl(self, mu):
mu_2 = torch.pow(mu, 2)
encoding_loss = torch.mean(mu_2)
return encoding_loss
def train(self):
svhn_iter = iter(self.svhn_loader)
mnist_iter = iter(self.mnist_loader)
iter_per_epoch = min(len(svhn_iter), len(mnist_iter))
# fixed mnist and svhn for sampling
fixed_svhn = self.to_var(svhn_iter.next()[0])
fixed_mnist = self.to_var(mnist_iter.next()[0])
# Train autoencoder for mnist
for step in range(self.train_iters + 1):
# reset data_iter for each epoch
if (step + 1) % iter_per_epoch == 0:
mnist_iter = iter(self.mnist_loader)
# mnist dataset
mnist_data, m_labels_data = mnist_iter.next()
mnist, m_labels = self.to_var(mnist_data), self.to_var(m_labels_data)
# ============ train D ============#
# train with real images
self.reset_grad()
out = self.d1(mnist)
d1_loss = torch.mean((out - 1) ** 2)
d_mnist_loss = d1_loss
d_real_loss = d1_loss
d_real_loss.backward()
self.d_optimizer.step()
# train with fake images
self.reset_grad()
fake_mnist = self.g22.forward(mnist, mnist=True)
out = self.d1(fake_mnist)
d2_loss = torch.mean(out ** 2)
d_fake_loss = d2_loss
d_fake_loss.backward()
self.d_optimizer.step()
# ============ train G ============
self.reset_grad()
fake_mnist = self.g22.forward(mnist, mnist=True)
out = self.d1(fake_mnist)
g_loss = torch.mean((out - 1) ** 2)
g_loss += torch.mean((mnist - fake_mnist) ** 2)
em = self.g22.encode(mnist, mnist=True)
g_loss += self.kl_lambda * self._compute_kl(em)
g_loss.backward()
self.g_optimizer.step()
# print the log info
if (step + 1) % self.log_step == 0:
print('Step [%d/%d], d_real_loss: %.4f, d_mnist_loss: %.4f, d_fake_loss: %.4f, '
'g_loss: %.4f'
% (step + 1, self.train_iters, d_real_loss.data[0], d_mnist_loss.data[0],
d_fake_loss.data[0], g_loss.data[0]))
# save the sampled images
if (step + 1) % self.sample_step == 0:
fake_mnist = self.g22.forward(fixed_mnist, mnist=True)
mnist, fake_mnist = self.to_data(fixed_mnist), self.to_data(fake_mnist)
merged = self.merge_images(mnist, fake_mnist)
path = os.path.join(self.sample_path, 'sample-%d-m-s.png' % (step + 1))
scipy.misc.imsave(path, merged)
print('saved %s' % path)
if (step + 1) % 10000 == 0:
# save the model parameters every 10000 steps
g22_path = os.path.join(self.model_path, 'g22-%d.pkl' % (step + 1))
d1_path = os.path.join(self.model_path, 'd1-%d.pkl' % (step + 1))
torch.save(self.g22.state_dict(), g22_path)
torch.save(self.d1.state_dict(), d1_path)
# Train autoencoder for svhn
for step in range(self.train_iters + 1):
# reset data_iter for each epoch
if (step + 1) % iter_per_epoch == 0:
svhn_iter = iter(self.svhn_loader)
# svhn dataset
svhn_data, s_labels_data = svhn_iter.next()
svhn, s_labels = self.to_var(svhn_data), self.to_var(s_labels_data).long().squeeze()
# ============ train D ============#
# train with real images
self.reset_grad()
out = self.d2(svhn)
d2_loss = torch.mean((out - 1) ** 2)
d_svhn_loss = d2_loss
d_real_loss = d2_loss
d_real_loss.backward()
self.d_optimizer.step()
# train with fake images
self.reset_grad()
fake_svhn = self.g11.forward(svhn, svhn=True)
out = self.d2(fake_svhn)
d1_loss = torch.mean(out ** 2)
d_fake_loss = d1_loss
d_fake_loss.backward()
self.d_optimizer.step()
# ============ train G ============#
self.reset_grad()
fake_svhn = self.g11.forward(svhn, svhn=True)
out = self.d2(fake_svhn)
g_loss = torch.mean((out - 1) ** 2)
g_loss += torch.mean((svhn - fake_svhn) ** 2)
es = self.g11.encode(svhn, svhn=True)
g_loss += self.kl_lambda * self._compute_kl(es)
g_loss.backward()
self.g_optimizer.step()
# print the log info
if (step + 1) % self.log_step == 0:
print('Step [%d/%d], d_real_loss: %.4f, d_svhn_loss: %.4f, '
'd_fake_loss: %.4f, g_loss: %.4f'
% (step + 1, self.train_iters, d_real_loss.data[0],
d_svhn_loss.data[0], d_fake_loss.data[0], g_loss.data[0]))
# save the sampled images
if (step + 1) % self.sample_step == 0:
fake_svhn = self.g11.forward(fixed_svhn, svhn=True)
svhn, fake_svhn = self.to_data(fixed_svhn), self.to_data(fake_svhn)
merged = self.merge_images(svhn, fake_svhn)
path = os.path.join(self.sample_path, 'sample-%d-s-m.png' % (step + 1))
scipy.misc.imsave(path, merged)
print('saved %s' % path)
if (step + 1) % 10000 == 0:
# save the model parameters every 10000 steps
g11_path = os.path.join(self.model_path, 'g11-%d.pkl' % (step + 1))
d2_path = os.path.join(self.model_path, 'd2-%d.pkl' % (step + 1))
torch.save(self.g11.state_dict(), g11_path)
torch.save(self.d2.state_dict(), d2_path)
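The autoencoder solver above combines three terms in the generator objective: a least-squares adversarial term, a mean-squared reconstruction term, and a squared-activation penalty on the encoding weighted by `kl_lambda`. The sketch below restates those formulas on plain Python lists so the arithmetic can be checked without PyTorch. The function names and the flat-list inputs are assumptions made for this illustration; the solver itself operates on tensors.

```python
def mean(values):
    """Arithmetic mean of a non-empty sequence of floats."""
    values = list(values)
    return sum(values) / len(values)


def d_real_loss(d_out_real):
    """Discriminator loss on real images: push outputs towards 1."""
    return mean((o - 1.0) ** 2 for o in d_out_real)


def d_fake_loss(d_out_fake):
    """Discriminator loss on reconstructions: push outputs towards 0."""
    return mean(o ** 2 for o in d_out_fake)


def g_loss(d_out_fake, x, x_rec, code, kl_lambda=0.1):
    """Generator loss: adversarial + reconstruction + weighted encoding penalty."""
    adversarial = mean((o - 1.0) ** 2 for o in d_out_fake)
    reconstruction = mean((a - b) ** 2 for a, b in zip(x, x_rec))
    encoding = mean(c ** 2 for c in code)  # same quantity as _compute_kl
    return adversarial + reconstruction + kl_lambda * encoding


if __name__ == '__main__':
    print(g_loss([0.0], [1.0], [0.0], [2.0], kl_lambda=0.1))
```

Note that `_compute_kl` is the mean of squared encoder activations, so despite its name it acts as an L2 penalty on the code rather than a full KL divergence.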
================================================
FILE: mnist_to_svhn/solver_mnist_to_svhn.py
================================================
import os
import numpy as np
import scipy.misc  # scipy.misc.imsave is used below
import torch
from torch import optim
from torch.autograd import Variable
from model import D1, D2
from model import G11
class Solver(object):
def __init__(self, config, svhn_loader, mnist_loader):
self.config = config
self.svhn_loader = svhn_loader
self.mnist_loader = mnist_loader
self.g11 = None
self.g22 = None
self.d1 = None
self.d2 = None
self.g_optimizer = None
self.num_classes = config.num_classes
self.beta1 = config.beta1
self.beta2 = config.beta2
self.g_conv_dim = config.g_conv_dim
self.d_conv_dim = config.d_conv_dim
self.train_iters = config.train_iters
self.batch_size = config.batch_size
self.lr = config.lr
self.kl_lambda = config.kl_lambda
self.log_step = config.log_step
self.sample_step = config.sample_step
self.sample_path = config.sample_path
self.model_path = config.model_path
self.g11_load_path = os.path.join(config.load_path, "g11-" + str(config.load_iter) + ".pkl")
self.d1_load_path = os.path.join(config.load_path, "d1-" + str(config.load_iter) + ".pkl")
self.g22_load_path = os.path.join(config.load_path, "g22-" + str(config.load_iter) + ".pkl")
self.d2_load_path = os.path.join(config.load_path, "d2-" + str(config.load_iter) + ".pkl")
self.build_model()
def build_model(self):
"""Builds the shared generator and both discriminators."""
self.g11 = G11(conv_dim=self.g_conv_dim)
self.g_optimizer = optim.Adam(list(self.g11.encode_params()) + list(self.g11.decode_params()), self.lr,
[self.beta1, self.beta2])
self.unshared_optimizer = optim.Adam(list(self.g11.unshared_parameters()), self.lr,
[self.beta1, self.beta2])
self.d1 = D1(conv_dim=self.d_conv_dim, use_labels=False)
self.d2 = D2(conv_dim=self.d_conv_dim, use_labels=False)
self.d_optimizer = optim.Adam(list(self.d1.parameters()) + list(self.d2.parameters()), self.lr,
[self.beta1, self.beta2])
if torch.cuda.is_available():
self.g11.cuda()
self.d1.cuda()
self.d2.cuda()
def merge_images(self, sources, targets, k=10):
_, _, h, w = sources.shape
row = int(np.sqrt(self.batch_size)) + 1
merged = np.zeros([3, row * h, row * w * 2])
for idx, (s, t) in enumerate(zip(sources, targets)):
i = idx // row
j = idx % row
merged[:, i * h:(i + 1) * h, (j * 2) * w:(j * 2 + 1) * w] = s
merged[:, i * h:(i + 1) * h, (j * 2 + 1) * w:(j * 2 + 2) * w] = t
return merged.transpose(1, 2, 0)
def to_var(self, x, volatile=False):
"""Wraps a tensor in a Variable, on the GPU if available."""
if torch.cuda.is_available():
x = x.cuda()
if volatile:
return Variable(x, volatile=True)
return Variable(x)
def to_no_grad_var(self, x):
x = self.to_data(x, no_numpy=True)
return self.to_var(x, volatile=True)
def to_data(self, x, no_numpy=False):
"""Converts variable to numpy."""
if torch.cuda.is_available():
x = x.cpu()
if no_numpy:
return x.data
return x.data.numpy()
def reset_grad(self):
"""Zeros the gradient buffers."""
self.unshared_optimizer.zero_grad()
self.g_optimizer.zero_grad()
self.d_optimizer.zero_grad()
def _compute_kl(self, mu):
mu_2 = torch.pow(mu, 2)
encoding_loss = torch.mean(mu_2)
return encoding_loss
def train(self):
self.build_model()
if self.config.pretrained_g:
self.g11.load_state_dict(torch.load(self.g11_load_path))
svhn_iter = iter(self.svhn_loader)
mnist_iter = iter(self.mnist_loader)
iter_per_epoch = min(len(svhn_iter), len(mnist_iter))
# fixed mnist and svhn for sampling
svhn_fixed_data, svhn_fixed_labels = svhn_iter.next()
mnist_fixed_data, mnist_fixed_labels = mnist_iter.next()
fixed_mnist = self.to_var(mnist_fixed_data)
counter = 0
for step in range(self.train_iters + 1):
# reset data_iter for each epoch
if (step + 1) % iter_per_epoch == 0:
mnist_iter = iter(self.mnist_loader)
svhn_iter = iter(self.svhn_loader)
# load svhn and mnist dataset
svhn_data, s_labels_data = svhn_iter.next()
mnist_data, m_labels_data = mnist_iter.next()
svhn, s_labels = self.to_var(svhn_data), self.to_var(s_labels_data).long().squeeze()
mnist, m_labels = self.to_var(mnist_data), self.to_var(m_labels_data)
# This sets the maximum number of items for A domain
# We assume max_items is a multiple of batch_size
# And reset mnist loader when we pass the number of allowed items.
if self.batch_size > self.config.max_items:
exit(-1)
elif self.batch_size == self.config.max_items:
mnist = fixed_mnist
elif self.batch_size < self.config.max_items:
counter += 1
if counter * self.batch_size >= self.config.max_items:
mnist_iter = iter(self.mnist_loader)
counter = 0
# ============ train D ============#
# train with real images
self.reset_grad()
out = self.d1(mnist)
d1_loss = torch.mean((out - 1) ** 2)
out = self.d2(svhn)
d2_loss = torch.mean((out - 1) ** 2)
d_mnist_loss = d1_loss
d_svhn_loss = d2_loss
# optimize both d1 and d2 on real images
d_real_loss = d1_loss + d2_loss
d_real_loss.backward()
self.d_optimizer.step()
# train with fake images
self.reset_grad()
es = self.g11.encode(svhn, svhn=True)
fake_mnist = self.g11.decode(es)
out = self.d1(fake_mnist)
d2_loss = torch.mean(out ** 2)
em = self.g11.encode(mnist)
fake_svhn = self.g11.decode(em, svhn=True)
out = self.d2(fake_svhn)
d1_loss = torch.mean(out ** 2)
d_fake_loss = d2_loss + d1_loss
d_fake_loss.backward()
self.d_optimizer.step()
# ============ train G ============#
# train mnist-svhn-mnist cycle
self.reset_grad()
es = self.g11.encode(svhn, svhn=True)
fake_mnist = self.g11.decode(es)
out = self.d1(fake_mnist)
g_loss = torch.mean((out - 1) ** 2)
em = self.g11.encode(mnist)
fake_svhn = self.g11.decode(em, svhn=True)
out = self.d2(fake_svhn)
g_loss += torch.mean((out - 1) ** 2)
self.reset_grad()
em = self.g11.encode(mnist)
fake_mnist = self.g11.decode(em)
g_loss += torch.mean((mnist - fake_mnist) ** 2)
if self.config.one_way_cycle:
em = self.g11.encode(mnist)
fake_svhn = self.g11.decode(em, svhn=True)
es = self.g11.encode(fake_svhn, svhn=True)
fake_mnist = self.g11.decode(es)
g_loss += torch.mean((mnist - fake_mnist) ** 2)
g_loss.backward()
self.unshared_optimizer.step()
if not self.config.freeze_shared:
self.reset_grad()
es = self.g11.encode(svhn, svhn=True)
fake_es = self.g11.decode(es, svhn=True)
g_loss = torch.mean((svhn - fake_es) ** 2)
g_loss += self.kl_lambda * self._compute_kl(es)
g_loss.backward()
self.g_optimizer.step()
# print the log info
if (step + 1) % self.log_step == 0:
print('Step [%d/%d], d_real_loss: %.4f, d_mnist_loss: %.4f, d_svhn_loss: %.4f, '
'd_fake_loss: %.4f, g_loss: %.4f'
% (step + 1, self.train_iters, d_real_loss.data[0], d_mnist_loss.data[0],
d_svhn_loss.data[0], d_fake_loss.data[0], g_loss.data[0]))
# save the sampled images
if (step + 1) % self.sample_step == 0:
em = self.g11.encode(fixed_mnist)
fake_svhn_var = self.g11.decode(em, svhn=True)
fake_svhn = self.to_data(fake_svhn_var)
if self.config.save_models_and_samples:
merged = self.merge_images(mnist_fixed_data, fake_svhn)
path = os.path.join(self.sample_path, 'sample-%d-s-m.png' % (step + 1))
scipy.misc.imsave(path, merged)
print('saved %s' % path)
if (step + 1) % self.config.num_iters_save_model_and_return == 0:
# save the model parameters (if enabled), then stop training
if self.config.save_models_and_samples:
g11_path = os.path.join(self.model_path, 'g11-%d.pkl' % (step + 1))
d1_path = os.path.join(self.model_path, 'd1-%d.pkl' % (step + 1))
d2_path = os.path.join(self.model_path, 'd2-%d.pkl' % (step + 1))
torch.save(self.g11.state_dict(), g11_path)
torch.save(self.d1.state_dict(), d1_path)
torch.save(self.d2.state_dict(), d2_path)
return
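
Note on the adversarial terms in the training loop above: the expressions `torch.mean((out - 1) ** 2)` and `torch.mean(out ** 2)` have the form of least-squares GAN (LSGAN) losses. The sketch below restates them without any framework, as an illustration only. The helper names `mean_sq`, `lsgan_d_loss` and `lsgan_g_loss` are hypothetical and do not exist in this repository, and plain Python lists stand in for discriminator outputs. The solver applies the real and fake discriminator terms in two separate optimizer steps; they are summed here only for brevity.

```python
def mean_sq(values, target):
    """Mean squared distance between each score and a target label."""
    return sum((v - target) ** 2 for v in values) / len(values)


def lsgan_d_loss(real_scores, fake_scores):
    """Discriminator objective: push real scores to 1 and fake scores to 0."""
    return mean_sq(real_scores, 1.0) + mean_sq(fake_scores, 0.0)


def lsgan_g_loss(fake_scores):
    """Generator objective: push the scores of translated samples to 1."""
    return mean_sq(fake_scores, 1.0)


if __name__ == "__main__":
    # A discriminator that is completely fooled scores fakes as 1.
    print(lsgan_d_loss([1.0, 1.0], [1.0, 1.0]))
    print(lsgan_g_loss([1.0, 1.0]))
```

Under this formulation the generator loss is zero exactly when every translated sample is scored as real, which is the case that maximizes the fake part of the discriminator loss.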
================================================
FILE: mnist_to_svhn/solver_svhn_to_mnist.py
================================================
import os

import numpy as np
import scipy.io
import scipy.misc  # train() calls scipy.misc.imsave, which only older SciPy releases provide
import torch
from torch import optim
from torch.autograd import Variable

from model import D1, D2
from model import G22


class Solver(object):
    def __init__(self, config, svhn_loader, mnist_loader):
        self.config = config
        self.svhn_loader = svhn_loader
        self.mnist_loader = mnist_loader
        self.g11 = None
        self.g22 = None
        self.d1 = None
        self.d2 = None
        self.g_optimizer = None
        self.num_classes = config.num_classes
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.train_iters = config.train_iters
        self.batch_size = config.batch_size
        self.lr = config.lr
        self.kl_lambda = config.kl_lambda
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.sample_path = config.sample_path
        self.model_path = config.model_path
        self.g11_load_path = os.path.join(config.load_path, "g11-" + str(config.load_iter) + ".pkl")
        self.d1_load_path = os.path.join(config.load_path, "d1-" + str(config.load_iter) + ".pkl")
        self.g22_load_path = os.path.join(config.load_path, "g22-" + str(config.load_iter) + ".pkl")
        self.d2_load_path = os.path.join(config.load_path, "d2-" + str(config.load_iter) + ".pkl")
        self.build_model()

    def build_model(self):
        """Builds the generator, both discriminators, and their optimizers."""
        self.g22 = G22(conv_dim=self.g_conv_dim)
        self.g_optimizer = optim.Adam(list(self.g22.encode_params()) + list(self.g22.decode_params()), self.lr,
                                      [self.beta1, self.beta2])
        self.unshared_optimizer = optim.Adam(list(self.g22.unshared_parameters()), self.lr,
                                             [self.beta1, self.beta2])
        self.d1 = D1(conv_dim=self.d_conv_dim, use_labels=False)
        self.d2 = D2(conv_dim=self.d_conv_dim, use_labels=False)
        self.d_optimizer = optim.Adam(list(self.d1.parameters()) + list(self.d2.parameters()), self.lr,
                                      [self.beta1, self.beta2])
        if torch.cuda.is_available():
            self.g22.cuda()
            self.d1.cuda()
            self.d2.cuda()

    def merge_images(self, sources, targets, k=10):
        """Tiles source/target pairs side by side into one (H, W, 3) grid image."""
        _, _, h, w = sources.shape
        row = int(np.sqrt(self.batch_size)) + 1
        merged = np.zeros([3, row * h, row * w * 2])
        for idx, (s, t) in enumerate(zip(sources, targets)):
            i = idx // row
            j = idx % row
            # column offsets use h rather than w, so this assumes square images
            merged[:, i * h:(i + 1) * h, (j * 2) * h:(j * 2 + 1) * h] = s
            merged[:, i * h:(i + 1) * h, (j * 2 + 1) * h:(j * 2 + 2) * h] = t
        return merged.transpose(1, 2, 0)

    def to_var(self, x, volatile=False):
        """Moves a tensor to the GPU (if available) and wraps it in a Variable."""
        if torch.cuda.is_available():
            x = x.cuda()
        if volatile:
            return Variable(x, volatile=True)
        return Variable(x)

    def to_no_grad_var(self, x):
        """Re-wraps the data of a Variable in a volatile (no-gradient) Variable."""
        x = self.to_data(x, no_numpy=True)
        return self.to_var(x, volatile=True)

    def to_data(self, x, no_numpy=False):
        """Converts a Variable to a NumPy array, or to a CPU tensor if no_numpy is set."""
        if torch.cuda.is_available():
            x = x.cpu()
        if no_numpy:
            return x.data
        return x.data.numpy()

    def reset_grad(self):
        """Zeros the gradient buffers."""
        self.unshared_optimizer.zero_grad()
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    def _compute_kl(self, mu):
        """Latent penalty: mean of the squared encoding values."""
        mu_2 = torch.pow(mu, 2)
        encoding_loss = torch.mean(mu_2)
        return encoding_loss

    def train(self):
        self.build_model()
        if self.config.pretrained_g:
            self.g22.load_state_dict(torch.load(self.g22_load_path))
        svhn_iter = iter(self.svhn_loader)
        mnist_iter = iter(self.mnist_loader)
        iter_per_epoch = min(len(svhn_iter), len(mnist_iter))
        # fixed mnist and svhn for sampling
        svhn_fixed_data, svhn_fixed_labels = svhn_iter.next()
        mnist_fixed_data, mnist_fixed_labels = mnist_iter.next()
        fixed_svhn = self.to_var(svhn_fixed_data)
        counter = 0
        for step in range(self.train_iters + 1):
            # reset data_iter for each epoch
            if (step + 1) % iter_per_epoch == 0:
                mnist_iter = iter(self.mnist_loader)
                svhn_iter = iter(self.svhn_loader)
            # load svhn and mnist dataset
            svhn_data, s_labels_data = svhn_iter.next()
            mnist_data, m_labels_data = mnist_iter.next()
            svhn, s_labels = self.to_var(svhn_data), self.to_var(s_labels_data).long().squeeze()
            mnist, m_labels = self.to_var(mnist_data), self.to_var(m_labels_data)
            # max_items caps the number of SVHN samples (the source domain here).
            # We assume max_items is a multiple of batch_size and restart the
            # SVHN loader once the allowed number of items has been consumed.
            if self.batch_size > self.config.max_items:
                raise ValueError('batch_size must not exceed max_items')
            elif self.batch_size == self.config.max_items:
                svhn = fixed_svhn
            elif self.batch_size < self.config.max_items:
                counter += 1
                if counter * self.batch_size >= self.config.max_items:
                    svhn_iter = iter(self.svhn_loader)
                    counter = 0
            # ============ train D ============#
            # train with real images
            self.reset_grad()
            out = self.d1(mnist)
            d1_loss = torch.mean((out - 1) ** 2)
            out = self.d2(svhn)
            d2_loss = torch.mean((out - 1) ** 2)
            d_mnist_loss = d1_loss
            d_svhn_loss = d2_loss
            # d1 and d2 share one optimizer, so both are updated from the summed loss
            d_real_loss = d1_loss + d2_loss
            d_real_loss.backward()
            self.d_optimizer.step()
            # train with fake images
            self.reset_grad()
            es = self.g22.encode(svhn)
            fake_mnist = self.g22.decode(es, mnist=True)
            out = self.d1(fake_mnist)
            d1_loss = torch.mean(out ** 2)
            em = self.g22.encode(mnist, mnist=True)
            fake_svhn = self.g22.decode(em)
            out = self.d2(fake_svhn)
            d2_loss = torch.mean(out ** 2)
            d_fake_loss = d1_loss + d2_loss
            d_fake_loss.backward()
            self.d_optimizer.step()
            # ============ train G ============#
            self.reset_grad()
            # adversarial terms: translated samples should be scored as real
            es = self.g22.encode(svhn)
            fake_mnist = self.g22.decode(es, mnist=True)
            out = self.d1(fake_mnist)
            g_loss = torch.mean((out - 1) ** 2)
            em = self.g22.encode(mnist, mnist=True)
            fake_svhn = self.g22.decode(em)
            out = self.d2(fake_svhn)
            g_loss += torch.mean((out - 1) ** 2)
            self.reset_grad()
            # reconstruction of the SVHN batch
            es = self.g22.encode(svhn)
            fake_svhn = self.g22.decode(es)
            g_loss += torch.mean((svhn - fake_svhn) ** 2)
            if self.config.one_way_cycle:
                # one-way cycle: SVHN -> MNIST -> SVHN
                es = self.g22.encode(svhn)
                fake_mnist = self.g22.decode(es, mnist=True)
                es = self.g22.encode(fake_mnist, mnist=True)
                fake_svhn = self.g22.decode(es)
                g_loss += torch.mean((svhn - fake_svhn) ** 2)
            g_loss.backward()
            self.unshared_optimizer.step()
            if not self.config.freeze_shared:
                self.reset_grad()
                # MNIST autoencoding loss plus the latent penalty, applied through g_optimizer
                em = self.g22.encode(mnist, mnist=True)
                fake_em = self.g22.decode(em, mnist=True)
                g_loss = torch.mean((mnist - fake_em) ** 2)
                g_loss += self.kl_lambda * self._compute_kl(em)
                g_loss.backward()
                self.g_optimizer.step()
            # print the log info
            if (step + 1) % self.log_step == 0:
                print('Step [%d/%d], d_real_loss: %.4f, d_mnist_loss: %.4f, d_svhn_loss: %.4f, '
                      'd_fake_loss: %.4f, g_loss: %.4f'
                      % (step + 1, self.train_iters, d_real_loss.data[0], d_mnist_loss.data[0],
                         d_svhn_loss.data[0], d_fake_loss.data[0], g_loss.data[0]))
            # save the sampled images
            if (step + 1) % self.sample_step == 0:
                es = self.g22.encode(fixed_svhn)
                fake_mnist_var = self.g22.decode(es, mnist=True)
                fake_mnist = self.to_data(fake_mnist_var)
                if self.config.save_models_and_samples:
                    merged = self.merge_images(svhn_fixed_data, fake_mnist)
                    path = os.path.join(self.sample_path, 'sample-%d-s-m.png' % (step + 1))
                    scipy.misc.imsave(path, merged)
                    print('saved %s' % path)
            if (step + 1) % self.config.num_iters_save_model_and_return == 0:
                # save the model parameters (if enabled), then stop training
                if self.config.save_models_and_samples:
                    g22_path = os.path.join(self.model_path, 'g22-%d.pkl' % (step + 1))
                    d1_path = os.path.join(self.model_path, 'd1-%d.pkl' % (step + 1))
                    d2_path = os.path.join(self.model_path, 'd2-%d.pkl' % (step + 1))
                    torch.save(self.g22.state_dict(), g22_path)
                    torch.save(self.d1.state_dict(), d1_path)
                    torch.save(self.d2.state_dict(), d2_path)
                return
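
Note on the `max_items` logic in `train()` above: the counter restarts the source-domain iterator once `max_items` samples have been drawn, so training only ever sees a limited number of source samples. The following is a rough, framework-free sketch of that idea, not the repository's code. The function name `limited_batches` and the `make_iter` callback are hypothetical, and the sketch assumes the iterator yields batches in a fixed order. If the real loader shuffles, each restart would draw a fresh order, which is not modeled here.

```python
def limited_batches(make_iter, batch_size, max_items, num_steps):
    """Yield num_steps batches while never reading past max_items items.

    make_iter is a hypothetical zero-argument callable that returns a fresh
    iterator over batches; it stands in for iter(data_loader).
    """
    if batch_size > max_items:
        raise ValueError("batch_size must not exceed max_items")
    # Assumes max_items is a multiple of batch_size, as the solver does.
    batches_per_pass = max_items // batch_size
    batch_iter = make_iter()
    used = 0
    for _ in range(num_steps):
        if used == batches_per_pass:
            # Allowed items are exhausted: start again from the beginning.
            batch_iter = make_iter()
            used = 0
        yield next(batch_iter)
        used += 1


if __name__ == "__main__":
    data = [[0, 1], [2, 3], [4, 5], [6, 7]]  # four batches of size 2
    for batch in limited_batches(lambda: iter(data), 2, 4, 5):
        print(batch)
```

With `batch_size == max_items` the sketch restarts before every step and always yields the first batch, which mirrors the solver's use of the fixed batch in that case.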
SYMBOL INDEX (239 symbols across 30 files)
FILE: drawing_and_style_transfer/data/__init__.py
function CreateDataLoader (line 5) | def CreateDataLoader(opt):
function CreateDataset (line 12) | def CreateDataset(opt):
class CustomDatasetDataLoader (line 30) | class CustomDatasetDataLoader(BaseDataLoader):
method name (line 31) | def name(self):
method initialize (line 34) | def initialize(self, opt):
method load_data (line 43) | def load_data(self):
method __len__ (line 46) | def __len__(self):
method __iter__ (line 49) | def __iter__(self):
FILE: drawing_and_style_transfer/data/aligned_dataset.py
class AlignedDataset (line 10) | class AlignedDataset(BaseDataset):
method initialize (line 11) | def initialize(self, opt):
method __getitem__ (line 18) | def __getitem__(self, index):
method __len__ (line 60) | def __len__(self):
method name (line 63) | def name(self):
FILE: drawing_and_style_transfer/data/base_data_loader.py
class BaseDataLoader (line 1) | class BaseDataLoader():
method __init__ (line 2) | def __init__(self):
method initialize (line 5) | def initialize(self, opt):
method load_data (line 9) | def load_data(self):
FILE: drawing_and_style_transfer/data/base_dataset.py
class BaseDataset (line 6) | class BaseDataset(data.Dataset):
method __init__ (line 7) | def __init__(self):
method name (line 10) | def name(self):
method initialize (line 13) | def initialize(self, opt):
function get_transform (line 17) | def get_transform(opt):
function __scale_width (line 44) | def __scale_width(img, target_width):
FILE: drawing_and_style_transfer/data/image_folder.py
function is_image_file (line 20) | def is_image_file(filename):
function make_dataset (line 24) | def make_dataset(dir, max_items=-1, start=0):
function default_loader (line 39) | def default_loader(path):
class ImageFolder (line 43) | class ImageFolder(data.Dataset):
method __init__ (line 44) | def __init__(self, root, transform=None, return_paths=False,
method __getitem__ (line 58) | def __getitem__(self, index):
method __len__ (line 68) | def __len__(self):
FILE: drawing_and_style_transfer/data/single_dataset.py
class SingleDataset (line 7) | class SingleDataset(BaseDataset):
method initialize (line 8) | def initialize(self, opt):
method __getitem__ (line 19) | def __getitem__(self, index):
method __len__ (line 34) | def __len__(self):
method name (line 37) | def name(self):
FILE: drawing_and_style_transfer/data/unaligned_dataset.py
class UnalignedDataset (line 8) | class UnalignedDataset(BaseDataset):
method initialize (line 9) | def initialize(self, opt):
method __getitem__ (line 23) | def __getitem__(self, index):
method __len__ (line 53) | def __len__(self):
method name (line 56) | def name(self):
FILE: drawing_and_style_transfer/datasets/make_dataset_aligned.py
function get_file_paths (line 6) | def get_file_paths(folder):
function align_images (line 20) | def align_images(a_file_paths, b_file_paths, target_path):
FILE: drawing_and_style_transfer/models/__init__.py
function create_model (line 1) | def create_model(opt):
FILE: drawing_and_style_transfer/models/autoencoder_model.py
class AutoEncoderModel (line 11) | class AutoEncoderModel(BaseModel):
method name (line 12) | def name(self):
method set_encoders_and_decoders (line 15) | def set_encoders_and_decoders(self, opt):
method initialize (line 52) | def initialize(self, opt):
method set_input (line 103) | def set_input(self, input):
method forward (line 112) | def forward(self):
method netEnc (line 115) | def netEnc(self, x):
method netDec (line 118) | def netDec(self, x):
method test (line 121) | def test(self):
method get_image_paths (line 127) | def get_image_paths(self):
method backward_D_basic (line 130) | def backward_D_basic(self, netD, real, fake):
method backward_D (line 143) | def backward_D(self):
method _compute_kl (line 148) | def _compute_kl(self, mu):
method backward_G (line 153) | def backward_G(self):
method optimize_parameters (line 172) | def optimize_parameters(self):
method get_current_errors (line 188) | def get_current_errors(self):
method get_current_visuals (line 192) | def get_current_visuals(self):
method save (line 198) | def save(self, label):
FILE: drawing_and_style_transfer/models/base_model.py
class BaseModel (line 5) | class BaseModel(object):
method name (line 6) | def name(self):
method initialize (line 9) | def initialize(self, opt):
method set_input (line 17) | def set_input(self, input):
method forward (line 20) | def forward(self):
method test (line 24) | def test(self):
method get_image_paths (line 27) | def get_image_paths(self):
method optimize_parameters (line 30) | def optimize_parameters(self):
method get_current_visuals (line 33) | def get_current_visuals(self):
method get_current_errors (line 36) | def get_current_errors(self):
method save (line 39) | def save(self, label):
method save_network (line 43) | def save_network(self, network, network_label, epoch_label, gpu_ids):
method load_network (line 51) | def load_network(self, network, network_label, epoch_label):
method update_learning_rate (line 57) | def update_learning_rate(self):
method as_np (line 63) | def as_np(self, data):
FILE: drawing_and_style_transfer/models/networks.py
class pixel_norm (line 13) | class pixel_norm(nn.Module):
method forward (line 14) | def forward(self, x, epsilon=1e-8):
function weights_init_normal (line 18) | def weights_init_normal(m):
function weights_init_xavier (line 30) | def weights_init_xavier(m):
function weights_init_kaiming (line 42) | def weights_init_kaiming(m):
function weights_init_orthogonal (line 54) | def weights_init_orthogonal(m):
function init_weights (line 66) | def init_weights(net, init_type='normal'):
function get_norm_layer (line 80) | def get_norm_layer(norm_type='instance'):
function get_scheduler (line 92) | def get_scheduler(optimizer, opt):
function define_ED (line 108) | def define_ED(input_nc, output_nc, ngf, which_model_netG, norm='batch', ...
function define_G (line 141) | def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', u...
function define_D (line 170) | def define_D(input_nc, ndf, which_model_netD,
function print_network (line 194) | def print_network(net):
class GANLoss (line 211) | class GANLoss(nn.Module):
method __init__ (line 212) | def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_...
method get_target_tensor (line 225) | def get_target_tensor(self, input, target_is_real):
method __call__ (line 243) | def __call__(self, input, target_is_real):
class ResnetEncoder (line 253) | class ResnetEncoder(nn.Module):
method __init__ (line 254) | def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNor...
method forward (line 290) | def forward(self, input):
class ResnetDecoder (line 297) | class ResnetDecoder(nn.Module):
method __init__ (line 298) | def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNor...
method forward (line 333) | def forward(self, input):
class ResnetGenerator (line 344) | class ResnetGenerator(nn.Module):
method __init__ (line 345) | def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNor...
method forward (line 391) | def forward(self, input):
class ResnetBlock (line 399) | class ResnetBlock(nn.Module):
method __init__ (line 400) | def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
method build_conv_block (line 404) | def build_conv_block(self, dim, padding_type, norm_layer, use_dropout,...
method forward (line 436) | def forward(self, x):
class UnetGenerator (line 445) | class UnetGenerator(nn.Module):
method __init__ (line 446) | def __init__(self, input_nc, output_nc, num_downs, ngf=64,
method forward (line 467) | def forward(self, input):
class UnetSkipConnectionBlock (line 477) | class UnetSkipConnectionBlock(nn.Module):
method __init__ (line 478) | def __init__(self, outer_nc, inner_nc, input_nc=None,
method forward (line 523) | def forward(self, x):
class NLayerDiscriminator (line 531) | class NLayerDiscriminator(nn.Module):
method __init__ (line 532) | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNo...
method forward (line 575) | def forward(self, input):
class PixelDiscriminator (line 582) | class PixelDiscriminator(nn.Module):
method __init__ (line 583) | def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d, use_si...
method forward (line 604) | def forward(self, input):
FILE: drawing_and_style_transfer/models/ost.py
class OSTModel (line 11) | class OSTModel(BaseModel):
method name (line 12) | def name(self):
method _compute_kl (line 15) | def _compute_kl(self, mu):
method set_encoders_and_decoders (line 20) | def set_encoders_and_decoders(self, opt):
method initialize (line 67) | def initialize(self, opt):
method set_input (line 144) | def set_input(self, input):
method forward (line 155) | def forward(self):
method test (line 159) | def test(self):
method get_image_paths (line 178) | def get_image_paths(self):
method backward_D_basic (line 181) | def backward_D_basic(self, netD, real, fake):
method backward_D (line 194) | def backward_D(self):
method backward_G (line 203) | def backward_G(self):
method optimize_parameters (line 239) | def optimize_parameters(self):
method get_current_errors (line 265) | def get_current_errors(self):
method get_current_visuals (line 273) | def get_current_visuals(self):
method save (line 286) | def save(self, label):
FILE: drawing_and_style_transfer/models/test_model.py
class TestModel (line 8) | class TestModel(BaseModel):
method name (line 9) | def name(self):
method initialize (line 12) | def initialize(self, opt):
method set_input (line 27) | def set_input(self, input):
method test (line 35) | def test(self):
method get_image_paths (line 40) | def get_image_paths(self):
method get_current_visuals (line 43) | def get_current_visuals(self):
FILE: drawing_and_style_transfer/options/base_options.py
class BaseOptions (line 7) | class BaseOptions():
method __init__ (line 8) | def __init__(self):
method initialize (line 12) | def initialize(self):
method parse (line 74) | def parse(self):
FILE: drawing_and_style_transfer/options/test_options.py
class TestOptions (line 4) | class TestOptions(BaseOptions):
method initialize (line 5) | def initialize(self):
FILE: drawing_and_style_transfer/options/train_options.py
class TrainOptions (line 4) | class TrainOptions(BaseOptions):
method initialize (line 5) | def initialize(self):
FILE: drawing_and_style_transfer/util/get_data.py
class GetData (line 11) | class GetData(object):
method __init__ (line 29) | def __init__(self, technique='cyclegan', verbose=True):
method _print (line 37) | def _print(self, text):
method _get_options (line 42) | def _get_options(r):
method _present_options (line 48) | def _present_options(self):
method _download_data (line 58) | def _download_data(self, dataset_url, save_path):
method get (line 81) | def get(self, save_path, dataset=None):
FILE: drawing_and_style_transfer/util/html.py
class HTML (line 6) | class HTML:
method __init__ (line 7) | def __init__(self, web_dir, title, reflesh=0):
method get_image_dir (line 22) | def get_image_dir(self):
method add_header (line 25) | def add_header(self, str):
method add_table (line 29) | def add_table(self, border=1):
method add_images (line 33) | def add_images(self, ims, txts, links, width=400):
method save (line 45) | def save(self):
FILE: drawing_and_style_transfer/util/image_pool.py
class ImagePool (line 6) | class ImagePool():
method __init__ (line 7) | def __init__(self, pool_size):
method query (line 13) | def query(self, images):
FILE: drawing_and_style_transfer/util/util.py
function tensor2im (line 10) | def tensor2im(image_tensor, imtype=np.uint8):
function diagnose_network (line 18) | def diagnose_network(net, name='network'):
function save_image (line 31) | def save_image(image_numpy, image_path):
function print_numpy (line 36) | def print_numpy(x, val=True, shp=False):
function mkdirs (line 46) | def mkdirs(paths):
function mkdir (line 54) | def mkdir(path):
FILE: drawing_and_style_transfer/util/visualizer.py
class Visualizer (line 10) | class Visualizer():
method __init__ (line 11) | def __init__(self, opt):
method reset (line 33) | def reset(self):
method display_current_results (line 37) | def display_current_results(self, visuals, epoch, save_result):
method plot_current_errors (line 101) | def plot_current_errors(self, epoch, counter_ratio, opt, errors):
method print_current_errors (line 117) | def print_current_errors(self, epoch, i, errors, t, t_data):
method save_images (line 127) | def save_images(self, webpage, visuals, image_path, aspect_ratio=1.0, ...
FILE: mnist_to_svhn/data_loader.py
function get_loader (line 6) | def get_loader(config):
FILE: mnist_to_svhn/main_autoencoder.py
function str2bool (line 9) | def str2bool(v):
function main (line 13) | def main(config):
FILE: mnist_to_svhn/main_mnist_to_svhn.py
function str2bool (line 10) | def str2bool(v):
function main (line 14) | def main(config):
FILE: mnist_to_svhn/main_svhn_to_mnist.py
function str2bool (line 10) | def str2bool(v):
function main (line 14) | def main(config):
FILE: mnist_to_svhn/model.py
function deconv (line 5) | def deconv(c_in, c_out, k_size, stride=2, pad=1, bn=True):
function conv (line 14) | def conv(c_in, c_out, k_size, stride=2, pad=1, bn=True):
class G11 (line 23) | class G11(nn.Module):
method __init__ (line 24) | def __init__(self, conv_dim=64):
method forward (line 42) | def forward(self, x, svhn=False):
method encode (line 60) | def encode(self, x, svhn=False):
method decode (line 72) | def decode(self, out, svhn=False):
method encode_params (line 84) | def encode_params(self):
method decode_params (line 92) | def decode_params(self):
method unshared_parameters (line 100) | def unshared_parameters(self):
class G22 (line 105) | class G22(nn.Module):
method __init__ (line 106) | def __init__(self, conv_dim=64):
method forward (line 124) | def forward(self, x, mnist=False):
method encode (line 142) | def encode(self, x, mnist=False):
method decode (line 154) | def decode(self, out, mnist=False):
method encode_params (line 166) | def encode_params(self):
method decode_params (line 174) | def decode_params(self):
method unshared_parameters (line 182) | def unshared_parameters(self):
class D1 (line 187) | class D1(nn.Module):
method __init__ (line 190) | def __init__(self, conv_dim=64, use_labels=False):
method forward (line 198) | def forward(self, x_0):
class D2 (line 207) | class D2(nn.Module):
method __init__ (line 210) | def __init__(self, conv_dim=64, use_labels=False):
method forward (line 218) | def forward(self, x_0):
FILE: mnist_to_svhn/solver_autoencoder.py
class Solver (line 13) | class Solver(object):
method __init__ (line 14) | def __init__(self, config, svhn_loader, mnist_loader):
method build_model (line 38) | def build_model(self):
method merge_images (line 57) | def merge_images(self, sources, targets, k=10):
method to_var (line 68) | def to_var(self, x):
method to_data (line 74) | def to_data(self, x):
method reset_grad (line 80) | def reset_grad(self):
method _compute_kl (line 85) | def _compute_kl(self, mu):
method train (line 90) | def train(self):
FILE: mnist_to_svhn/solver_mnist_to_svhn.py
class Solver (line 13) | class Solver(object):
method __init__ (line 14) | def __init__(self, config, svhn_loader, mnist_loader):
method build_model (line 42) | def build_model(self):
method merge_images (line 61) | def merge_images(self, sources, targets, k=10):
method to_var (line 72) | def to_var(self, x, volatile=False):
method to_no_grad_var (line 80) | def to_no_grad_var(self, x):
method to_data (line 84) | def to_data(self, x, no_numpy=False):
method reset_grad (line 92) | def reset_grad(self):
method _compute_kl (line 98) | def _compute_kl(self, mu):
method train (line 104) | def train(self):
FILE: mnist_to_svhn/solver_svhn_to_mnist.py
class Solver (line 13) | class Solver(object):
method __init__ (line 14) | def __init__(self, config, svhn_loader, mnist_loader):
method build_model (line 42) | def build_model(self):
method merge_images (line 61) | def merge_images(self, sources, targets, k=10):
method to_var (line 72) | def to_var(self, x, volatile=False):
method to_no_grad_var (line 80) | def to_no_grad_var(self, x):
method to_data (line 84) | def to_data(self, x, no_numpy=False):
method reset_grad (line 92) | def reset_grad(self):
method _compute_kl (line 98) | def _compute_kl(self, mu):
method train (line 103) | def train(self):
Condensed preview — 43 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (166K chars).
[
{
"path": "LICENSE",
"chars": 3568,
"preview": "MIT License\n\nCopyright (c) 2017 \n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this s"
},
{
"path": "README.md",
"chars": 3008,
"preview": "# Pytorch implementation of One-Shot Unsupervised Cross Domain Translation ([arxiv](https://arxiv.org/abs/1806.06029)).\n"
},
{
"path": "drawing_and_style_transfer/data/__init__.py",
"chars": 1480,
"preview": "import torch.utils.data\nfrom data.base_data_loader import BaseDataLoader\n\n\ndef CreateDataLoader(opt):\n data_loader = "
},
{
"path": "drawing_and_style_transfer/data/aligned_dataset.py",
"chars": 2410,
"preview": "import os.path\nimport random\nimport torchvision.transforms as transforms\nimport torch\nfrom data.base_dataset import Base"
},
{
"path": "drawing_and_style_transfer/data/base_data_loader.py",
"chars": 175,
"preview": "class BaseDataLoader():\n def __init__(self):\n pass\n\n def initialize(self, opt):\n self.opt = opt\n "
},
{
"path": "drawing_and_style_transfer/data/base_dataset.py",
"chars": 1735,
"preview": "import torch.utils.data as data\nfrom PIL import Image\nimport torchvision.transforms as transforms\n\n\nclass BaseDataset(da"
},
{
"path": "drawing_and_style_transfer/data/image_folder.py",
"chars": 2080,
"preview": "###############################################################################\n# Code from\n# https://github.com/pytorch"
},
{
"path": "drawing_and_style_transfer/data/single_dataset.py",
"chars": 1056,
"preview": "import os.path\nfrom data.base_dataset import BaseDataset, get_transform\nfrom data.image_folder import make_dataset\nfrom "
},
{
"path": "drawing_and_style_transfer/data/unaligned_dataset.py",
"chars": 2051,
"preview": "import os.path\nfrom data.base_dataset import BaseDataset, get_transform\nfrom data.image_folder import make_dataset\nfrom "
},
{
"path": "drawing_and_style_transfer/datasets/combine_A_and_B.py",
"chars": 2124,
"preview": "import os\nimport numpy as np\nimport cv2\nimport argparse\n\nparser = argparse.ArgumentParser('create image pairs')\nparser.a"
},
{
"path": "drawing_and_style_transfer/datasets/download_cyclegan_dataset.sh",
"chars": 809,
"preview": "FILE=$1\n\nif [[ $FILE != \"ae_photos\" && $FILE != \"apple2orange\" && $FILE != \"summer2winter_yosemite\" && $FILE != \"horse2"
},
{
"path": "drawing_and_style_transfer/datasets/make_dataset_aligned.py",
"chars": 2257,
"preview": "import os\n\nfrom PIL import Image\n\n\ndef get_file_paths(folder):\n image_file_paths = []\n for root, dirs, filenames i"
},
{
"path": "drawing_and_style_transfer/environment.yml",
"chars": 224,
"preview": "name: OST\nchannels:\n- peterjc123\n- defaults\ndependencies:\n- python=3.6.5\n- pytorch=0.4.0\n- scipy\n- pip:\n - dominate==2."
},
{
"path": "drawing_and_style_transfer/models/__init__.py",
"chars": 684,
"preview": "def create_model(opt):\n print(opt.model)\n if opt.model == 'ost':\n assert (opt.dataset_mode == 'unaligned')\n"
},
{
"path": "drawing_and_style_transfer/models/autoencoder_model.py",
"chars": 8669,
"preview": "import torch\nfrom collections import OrderedDict\nfrom torch.autograd import Variable\nimport itertools\nimport util.util a"
},
{
"path": "drawing_and_style_transfer/models/base_model.py",
"chars": 1919,
"preview": "import os\nimport torch\n\n\nclass BaseModel(object):\n def name(self):\n return 'BaseModel'\n\n def initialize(sel"
},
{
"path": "drawing_and_style_transfer/models/networks.py",
"chars": 25147,
"preview": "import torch\nimport torch.nn as nn\nfrom torch.nn import init\nimport functools\nfrom torch.autograd import Variable\nfrom t"
},
{
"path": "drawing_and_style_transfer/models/ost.py",
"chars": 14038,
"preview": "import torch\nfrom collections import OrderedDict\nfrom torch.autograd import Variable\nimport itertools\nimport util.util a"
},
{
"path": "drawing_and_style_transfer/models/test_model.py",
"chars": 1606,
"preview": "from torch.autograd import Variable\nfrom collections import OrderedDict\nimport util.util as util\nfrom .base_model import"
},
{
"path": "drawing_and_style_transfer/options/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "drawing_and_style_transfer/options/base_options.py",
"chars": 7094,
"preview": "import argparse\nimport os\nfrom util import util\nimport torch\n\n\nclass BaseOptions():\n def __init__(self):\n self"
},
{
"path": "drawing_and_style_transfer/options/test_options.py",
"chars": 879,
"preview": "from .base_options import BaseOptions\n\n\nclass TestOptions(BaseOptions):\n def initialize(self):\n BaseOptions.in"
},
{
"path": "drawing_and_style_transfer/options/train_options.py",
"chars": 3703,
"preview": "from .base_options import BaseOptions\n\n\nclass TrainOptions(BaseOptions):\n def initialize(self):\n BaseOptions.i"
},
{
"path": "drawing_and_style_transfer/scripts/test_ost.sh",
"chars": 2217,
"preview": "# images to cityscapes\npython test.py --dataroot=./datasets/cityscapes/ --name=cityscapes_ost --load_dir=cityscapes_ost "
},
{
"path": "drawing_and_style_transfer/scripts/train_autoencoder.sh",
"chars": 2019,
"preview": "# images to cityscapes\npython train.py --dataroot=./datasets/cityscapes/trainB --name=cityscapes_autoencoder --model=aut"
},
{
"path": "drawing_and_style_transfer/scripts/train_ost.sh",
"chars": 2307,
"preview": "# images to cityscapes\npython train.py --dataroot=./datasets/cityscapes/ --name=cityscapes_ost --load_dir=cityscapes_aut"
},
{
"path": "drawing_and_style_transfer/test.py",
"chars": 1482,
"preview": "import os\nfrom options.test_options import TestOptions\nfrom data import CreateDataLoader\nfrom models import create_model"
},
{
"path": "drawing_and_style_transfer/train.py",
"chars": 2312,
"preview": "import time\nfrom options.train_options import TrainOptions\nfrom data import CreateDataLoader\nfrom models import create_m"
},
{
"path": "drawing_and_style_transfer/util/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "drawing_and_style_transfer/util/get_data.py",
"chars": 3511,
"preview": "from __future__ import print_function\nimport os\nimport tarfile\nimport requests\nfrom warnings import warn\nfrom zipfile im"
},
{
"path": "drawing_and_style_transfer/util/html.py",
"chars": 1912,
"preview": "import dominate\nfrom dominate.tags import *\nimport os\n\n\nclass HTML:\n def __init__(self, web_dir, title, reflesh=0):\n "
},
{
"path": "drawing_and_style_transfer/util/image_pool.py",
"chars": 1099,
"preview": "import random\nimport torch\nfrom torch.autograd import Variable\n\n\nclass ImagePool():\n def __init__(self, pool_size):\n "
},
{
"path": "drawing_and_style_transfer/util/util.py",
"chars": 1482,
"preview": "from __future__ import print_function\nimport torch\nimport numpy as np\nfrom PIL import Image\nimport os\n\n\n# Converts a Ten"
},
{
"path": "drawing_and_style_transfer/util/visualizer.py",
"chars": 6788,
"preview": "import numpy as np\nimport os\nimport ntpath\nimport time\nfrom . import util\nfrom . import html\nfrom scipy.misc import imre"
},
{
"path": "mnist_to_svhn/data_loader.py",
"chars": 2484,
"preview": "import torch\nfrom torchvision import datasets\nfrom torchvision import transforms\n\n\ndef get_loader(config):\n \"\"\"Builds"
},
{
"path": "mnist_to_svhn/download.sh",
"chars": 279,
"preview": "mkdir -p mnist\nmkdir -p svhn\n\nwget -O svhn/train_32x32.mat http://ufldl.stanford.edu/housenumbers/train_32x32.mat\nwget -"
},
{
"path": "mnist_to_svhn/main_autoencoder.py",
"chars": 2164,
"preview": "import argparse\nimport os\nfrom torch.backends import cudnn\n\nfrom solver_autoencoder import Solver\nfrom data_loader impor"
},
{
"path": "mnist_to_svhn/main_mnist_to_svhn.py",
"chars": 3292,
"preview": "import argparse\nimport logging\nimport os\nfrom torch.backends import cudnn\n\nfrom data_loader import get_loader\nfrom solve"
},
{
"path": "mnist_to_svhn/main_svhn_to_mnist.py",
"chars": 3293,
"preview": "import argparse\nimport logging\nimport os\n\nfrom data_loader import get_loader\nfrom solver_svhn_to_mnist import Solver\nfro"
},
{
"path": "mnist_to_svhn/model.py",
"chars": 7627,
"preview": "import torch.nn as nn\nimport torch.nn.functional as F\n\n\ndef deconv(c_in, c_out, k_size, stride=2, pad=1, bn=True):\n \""
},
{
"path": "mnist_to_svhn/solver_autoencoder.py",
"chars": 8815,
"preview": "import os\n\nimport numpy as np\nimport scipy.io\nimport torch\nfrom torch import optim\nfrom torch.autograd import Variable\n\n"
},
{
"path": "mnist_to_svhn/solver_mnist_to_svhn.py",
"chars": 9652,
"preview": "import os\n\nimport numpy as np\nimport scipy.io\nimport torch\nfrom torch import optim\nfrom torch.autograd import Variable\n\n"
},
{
"path": "mnist_to_svhn/solver_svhn_to_mnist.py",
"chars": 9609,
"preview": "import os\n\nimport numpy as np\nimport scipy.io\nimport torch\nfrom torch import optim\nfrom torch.autograd import Variable\n\n"
}
]
About this extraction
This page contains the full source code of the sagiebenaim/OneShotTranslation GitHub repository, extracted and formatted as plain text. The extraction includes 43 files (155.3 KB), approximately 39.6k tokens, and a symbol index with 239 extracted functions, classes, methods, constants, and types.