Repository: qiaott/MirrorGAN
Branch: master
Commit: deb220fdae8f
Files: 18
Total size: 104.0 KB
Directory structure:
gitextract_0488eq_x/
├── .gitignore
├── GLAttention.py
├── README.md
├── cfg/
│ ├── __init__.py
│ ├── config.py
│ ├── eval_bird.yml
│ └── train_bird.yml
├── datasets.py
├── do_test.sh
├── do_train.sh
├── main.py
├── miscc/
│ ├── __init__.py
│ ├── losses.py
│ └── utils.py
├── model.py
├── pretrain_DAMSM.py
├── test.py
└── trainer.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
*.pyc
miscc/*.pyc
.DS_Store
.idea/
================================================
FILE: GLAttention.py
================================================
import torch
import torch.nn as nn
def conv1x1(in_planes, out_planes):
    """Build a bias-free 1x1 ``nn.Conv2d`` (stride 1, no padding).

    Acts as a per-position linear projection from ``in_planes`` channels
    to ``out_planes`` channels; the spatial size is unchanged.
    """
    layer = nn.Conv2d(in_planes, out_planes,
                      kernel_size=1, stride=1, padding=0, bias=False)
    return layer
def func_attention(query, context, gamma1):
    """Attend from each query vector over the spatial positions of ``context``.

    query: batch x ndf x queryL
    context: batch x ndf x ih x iw (sourceL = ih * iw)
    gamma1: scalar that multiplies the second-stage logits before the
        softmax over positions (larger values give sharper attention).

    Returns:
        weightedContext: batch x ndf x queryL, the attention-weighted sum of
            the context vectors for every query.
        attn: batch x queryL x ih x iw, attention weights that sum to 1 over
            the ih x iw positions for each query.
    """
    batch_size, queryL = query.size(0), query.size(2)
    ih, iw = context.size(2), context.size(3)
    sourceL = ih * iw

    # --> batch x ndf x sourceL
    context = context.view(batch_size, -1, sourceL)
    # --> batch x sourceL x ndf
    contextT = torch.transpose(context, 1, 2).contiguous()

    # Get attention
    # (batch x sourceL x ndf)(batch x ndf x queryL)
    # --> batch x sourceL x queryL
    attn = torch.bmm(contextT, query)
    # --> batch*sourceL x queryL
    attn = attn.view(batch_size * sourceL, queryL)
    # Eq. (8): normalise over the queries at each position. dim=1 is the axis
    # nn.Softmax() selects implicitly for 2-D input; stating it avoids the
    # implicit-dim deprecation warning without changing the result.
    attn = nn.Softmax(dim=1)(attn)

    # --> batch x sourceL x queryL
    attn = attn.view(batch_size, sourceL, queryL)
    # --> batch*queryL x sourceL
    attn = torch.transpose(attn, 1, 2).contiguous()
    attn = attn.view(batch_size * queryL, sourceL)
    # Second softmax, over positions, sharpened by gamma1.
    attn = attn * gamma1
    attn = nn.Softmax(dim=1)(attn)
    attn = attn.view(batch_size, queryL, sourceL)
    # --> batch x sourceL x queryL
    attnT = torch.transpose(attn, 1, 2).contiguous()

    # (batch x ndf x sourceL)(batch x sourceL x queryL)
    # --> batch x ndf x queryL
    weightedContext = torch.bmm(context, attnT)

    return weightedContext, attn.view(batch_size, -1, ih, iw)
class GLAttentionGeneral(nn.Module):
    """Global-local attention over an image feature map (MirrorGAN Eq. 4-5).

    Word-level (local) attention mixes projected word features according to
    their similarity with every image position. Sentence-level (global)
    attention gates a projected sentence vector with a softmax over channels.

    Args:
        idf: number of channels of the image features.
        cdf: dimensionality of the word features.
        sdf: dimensionality of the sentence vector (default 100).
    """

    def __init__(self, idf, cdf, sdf=100):
        super(GLAttentionGeneral, self).__init__()
        self.conv_context = conv1x1(cdf, idf)
        self.conv_sentence_vis = conv1x1(idf, idf)
        self.linear = nn.Linear(sdf, idf)
        # dim=1 is the axis nn.Softmax() picks implicitly for the 2-D input it
        # gets in forward(); stating it only silences the deprecation warning.
        self.sm = nn.Softmax(dim=1)
        self.mask = None

    def applyMask(self, mask):
        """Store a batch x sourceL word mask; True/nonzero words get zero attention."""
        self.mask = mask  # batch x sourceL

    def forward(self, input, sentence, context):
        """
        input: batch x idf x ih x iw (queryL = ih x iw)
        sentence: batch x sdf sentence vector
        context: batch x cdf x sourceL (matrix of word vectors)

        Returns:
            weightedContext: batch x idf x ih x iw
            weightedSentence: batch x idf x ih x iw
            word_attn: batch x sourceL x ih x iw
            sent_att: batch x idf x ih x iw
        """
        idf, ih, iw = input.size(1), input.size(2), input.size(3)
        queryL = ih * iw
        batch_size, sourceL = context.size(0), context.size(2)

        # generated image feature: batch x idf x queryL --> batch x queryL x idf
        target = input.view(batch_size, -1, queryL)
        targetT = torch.transpose(target, 1, 2).contiguous()

        # Eq(4) in MirrorGAN: local-level attention
        # words feature: batch x cdf x sourceL --> batch x cdf x sourceL x 1
        sourceT = context.unsqueeze(3)
        # --> batch x idf x sourceL
        sourceT = self.conv_context(sourceT).squeeze(3)

        # --> batch x queryL x sourceL
        attn = torch.bmm(targetT, sourceT)
        # --> batch*queryL x sourceL; row r belongs to sample r // queryL
        attn = attn.view(batch_size * queryL, sourceL)
        if self.mask is not None:
            # Repeat each sample's mask row queryL times *contiguously* so it
            # lines up with attn's rows. mask.repeat(queryL, 1) would tile the
            # whole batch instead and pair rows with the wrong sample whenever
            # batch_size > 1.
            mask = self.mask.repeat(1, queryL).view(batch_size * queryL, sourceL)
            attn.data.masked_fill_(mask.data, -float('inf'))
        attn = self.sm(attn)  # Eq. (2)
        # --> batch x queryL x sourceL
        attn = attn.view(batch_size, queryL, sourceL)
        # --> batch x sourceL x queryL
        attn = torch.transpose(attn, 1, 2).contiguous()

        # (batch x idf x sourceL)(batch x sourceL x queryL)
        # --> batch x idf x queryL
        weightedContext = torch.bmm(sourceT, attn)
        weightedContext = weightedContext.view(batch_size, -1, ih, iw)
        word_attn = attn.view(batch_size, -1, ih, iw)  # batch x sourceL x ih x iw

        # Eq(5) in MirrorGAN: global-level attention
        sentence = self.linear(sentence)
        sentence = sentence.view(batch_size, idf, 1, 1)
        sentence = sentence.repeat(1, 1, ih, iw)
        sentence_vs = torch.mul(input, sentence)  # batch x idf x ih x iw
        sentence_vs = self.conv_sentence_vis(sentence_vs)  # batch x idf x ih x iw
        # Softmax over the channel axis (dim=1, the axis nn.Softmax() picks
        # implicitly for 4-D input).
        sent_att = nn.Softmax(dim=1)(sentence_vs)  # batch x idf x ih x iw
        weightedSentence = torch.mul(sentence, sent_att)  # batch x idf x ih x iw

        return weightedContext, weightedSentence, word_attn, sent_att
================================================
FILE: README.md
================================================
# MirrorGAN
PyTorch implementation of the paper [MirrorGAN: Learning Text-to-image Generation by Redescription](https://arxiv.org/abs/1903.05854) by Tingting Qiao, Jing Zhang, Duanqing Xu and Dacheng Tao. (The work was performed while Tingting Qiao was a visiting student at the UBTECH Sydney AI Centre in the School of Computer Science, FEIT, the University of Sydney.)

## Getting Started
### Installation
- Install PyTorch and dependencies from http://pytorch.org
- Install torchvision from source.
- Clone this repo:
```bash
git clone https://github.com/qiaott/MirrorGAN.git
cd MirrorGAN
```
- Download our preprocessed data from [here](https://drive.google.com/file/d/1CuW5ognTSkNbyx9TWoUFrgwqxZNk1cl0/view?usp=sharing).
- The STEM was pretrained using the code provided [here](https://github.com/taoxugit/AttnGAN)
- The STREAM was pretrained using the code provided [here](https://github.com/yunjey/pytorch-tutorial/tree/master/tutorials/03-advanced/image_captioning).
### Train/Test
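After obtaining the pretrained STEM (text encoder) and STREAM (image-captioning encoder/decoder) modules, you can train the text-to-image model. Their checkpoint paths are set by `TRAIN.NET_E` and `CAP.caption_cnn_path` / `CAP.caption_rnn_path` in `cfg/train_bird.yml`.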
- Train a model:
```bash
./do_train.sh
```
- Test a model:
```bash
./do_test.sh
```
## Citation
If you use this code for your research, please cite our paper.
```bibtex
@article{qiao2019mirrorgan,
title={MirrorGAN: Learning Text-to-image Generation by Redescription},
author={Qiao, Tingting and Zhang, Jing and Xu, Duanqing and Tao, Dacheng},
journal={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
year={2019}
}
```
================================================
FILE: cfg/__init__.py
================================================
================================================
FILE: cfg/config.py
================================================
from __future__ import division
from __future__ import print_function
import os.path as osp
import numpy as np
from easydict import EasyDict as edict
# Global default configuration (an EasyDict); ``cfg`` is the alias other
# modules import. cfg_from_file() overlays a YAML file on these defaults via
# _merge_a_into_b(), which rejects keys that are not defined here and values
# whose type differs from the default below (so a float option must be given
# as a float in YAML, e.g. 0.0002, not an int).
__C = edict()
cfg = __C

# Dataset name: flowers, birds
__C.DATASET_NAME = 'birds'
__C.CONFIG_NAME = ''
__C.DATA_DIR = ''
__C.GPU_ID = 0
__C.CUDA = True  # when True, batches and masks are moved to the GPU
__C.WORKERS = 6  # DataLoader worker processes (main.py)
__C.OUTPUT_PATH = ''  # main.py writes to <OUTPUT_PATH>/output/<run name>

__C.RNN_TYPE = 'LSTM'  # 'GRU'
# Only used when TRAIN.FLAG is False: main.py samples the whole test split if
# True, otherwise it generates images for the example captions.
__C.B_VALIDATION = False

__C.TREE = edict()
# Number of generator resolutions; TextDataset uses one image size per branch,
# starting at BASE_SIZE and doubling (64, 128, 256 for 3 branches).
__C.TREE.BRANCH_NUM = 3
__C.TREE.BASE_SIZE = 64


# Training options
__C.TRAIN = edict()
__C.TRAIN.BATCH_SIZE = 64
__C.TRAIN.MAX_EPOCH = 600
__C.TRAIN.SNAPSHOT_INTERVAL = 2000
__C.TRAIN.DISCRIMINATOR_LR = 2e-4
__C.TRAIN.GENERATOR_LR = 2e-4
__C.TRAIN.ENCODER_LR = 2e-4
__C.TRAIN.RNN_GRAD_CLIP = 0.25
__C.TRAIN.FLAG = True  # True: train on the 'train' split; False: evaluate on 'test'
__C.TRAIN.NET_E = ''  # path to the pretrained text encoder (STEM)
__C.TRAIN.NET_G = ''  # path to a generator checkpoint
__C.TRAIN.B_NET_D = True

__C.TRAIN.SMOOTH = edict()
__C.TRAIN.SMOOTH.GAMMA1 = 5.0  # sharpness of word-region attention (func_attention)
__C.TRAIN.SMOOTH.GAMMA3 = 10.0  # scale of matching scores in sent_loss / words_loss
__C.TRAIN.SMOOTH.GAMMA2 = 5.0  # word-similarity aggregation factor in words_loss (Eq. 10)
__C.TRAIN.SMOOTH.LAMBDA = 0.0
__C.TRAIN.SMOOTH.LAMBDA1 = 1.0  # weight of the caption (STREAM) loss in generator_loss

# Caption_model_settings added by tingting
# (captioning / STREAM model; the keys are lower-case, so YAML must match)
__C.CAP = edict()
__C.CAP.embed_size = 256
__C.CAP.hidden_size = 512
__C.CAP.num_layers = 1
__C.CAP.learning_rate = 0.001
__C.CAP.caption_cnn_path = ''
__C.CAP.caption_rnn_path = ''

# Modal options
__C.GAN = edict()
__C.GAN.DF_DIM = 64
__C.GAN.GF_DIM = 128
__C.GAN.Z_DIM = 100
__C.GAN.CONDITION_DIM = 100
__C.GAN.R_NUM = 2
__C.GAN.B_ATTENTION = True
__C.GAN.B_DCGAN = False  # if True, get_imgs returns one image instead of one per branch

__C.TEXT = edict()
__C.TEXT.CAPTIONS_PER_IMAGE = 10  # captions read per image file
__C.TEXT.EMBEDDING_DIM = 256
# Caption length cap: shorter captions are zero-padded, longer ones are
# randomly subsampled (TextDataset.get_caption).
__C.TEXT.WORDS_NUM = 18
def _merge_a_into_b(a, b):
    """Merge config dictionary ``a`` into config dictionary ``b`` in place.

    Options in ``b`` are clobbered whenever they are also specified in ``a``.
    Nested EasyDicts are merged recursively. Nothing happens if ``a`` is not
    an EasyDict.

    Raises:
        KeyError: ``a`` contains a key that ``b`` does not define.
        ValueError: a value's type differs from the default's type (values
            for NumPy-array defaults are coerced instead).
    """
    if type(a) is not edict:
        return

    # .items() and ``in`` work on Python 2 and 3; iteritems()/has_key() are
    # Python-2-only and raise AttributeError on Python 3.
    for k, v in a.items():
        # a must specify keys that are in b
        if k not in b:
            raise KeyError('{} is not a valid config key'.format(k))

        # the types must match, too
        old_type = type(b[k])
        if old_type is not type(v):
            if isinstance(b[k], np.ndarray):
                v = np.array(v, dtype=b[k].dtype)
            else:
                raise ValueError(('Type mismatch ({} vs. {}) '
                                  'for config key: {}').format(type(b[k]),
                                                               type(v), k))

        # recursively merge dicts
        if type(v) is edict:
            try:
                _merge_a_into_b(a[k], b[k])
            except Exception:
                # Report which section failed, then propagate the error.
                print('Error under config key: {}'.format(k))
                raise
        else:
            b[k] = v
def cfg_from_file(filename):
    """Load a YAML config file and merge it into the default options."""
    import yaml
    with open(filename, 'r') as f:
        # safe_load only builds plain YAML data. yaml.load() without an
        # explicit Loader can construct arbitrary objects, and on PyYAML >= 6
        # it is a TypeError (Loader became a required argument).
        yaml_cfg = edict(yaml.safe_load(f))

    _merge_a_into_b(yaml_cfg, __C)
================================================
FILE: cfg/eval_bird.yml
================================================
CONFIG_NAME: 'MirrorGAN'
DATASET_NAME: 'birds'
DATA_DIR: '../data/birds'
GPU_ID: 3
WORKERS: 1
B_VALIDATION: True # True # False
TREE:
BRANCH_NUM: 3
TRAIN:
FLAG: False
NET_G: '../data/output/bird/Model/netG.pth' # path to the trained model
B_NET_D: False
BATCH_SIZE: 12
NET_E: '../data/STEM/text_encoder.pth'
GAN:
DF_DIM: 64
GF_DIM: 32
Z_DIM: 100
R_NUM: 2
TEXT:
EMBEDDING_DIM: 256
CAPTIONS_PER_IMAGE: 10
WORDS_NUM: 25
================================================
FILE: cfg/train_bird.yml
================================================
CONFIG_NAME: 'MirrorGAN'
DATASET_NAME: 'birds'
DATA_DIR: '../data/birds'
GPU_ID: 3
WORKERS: 4
OUTPUT_PATH: '/data/qtt/MirrorGAN/'
TREE:
BRANCH_NUM: 3
TRAIN:
FLAG: True
NET_G: ''
B_NET_D: True
BATCH_SIZE: 12 # 22
MAX_EPOCH: 650
SNAPSHOT_INTERVAL: 50
DISCRIMINATOR_LR: 0.0002
GENERATOR_LR: 0.0002
NET_E: '../data/STEM/text_encoder.pth'
SMOOTH:
GAMMA1: 4.0 # 1,2,5 good 4 best 10&100bad
GAMMA2: 5.0
GAMMA3: 10.0 # 10good 1&100bad
LAMBDA: 0.0
LAMBDA1: 10.0
CAP:
embed_size: 256
hidden_size: 256
num_layers: 1
learning_rate: 0.001
caption_cnn_path: '../data/STREAM/cnn_encoder.ckpt'
caption_rnn_path: '../data/STREAM/rnn_decoder.ckpt'
GAN:
DF_DIM: 64
GF_DIM: 32
Z_DIM: 100
R_NUM: 2
TEXT:
EMBEDDING_DIM: 256
CAPTIONS_PER_IMAGE: 10
================================================
FILE: datasets.py
================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from nltk.tokenize import RegexpTokenizer
from collections import defaultdict
from cfg.config import cfg
import torch
import torch.utils.data as data
from torch.autograd import Variable
import torchvision.transforms as transforms
import os
import sys
import numpy as np
import pandas as pd
from PIL import Image
import numpy.random as random
if sys.version_info[0] == 2:
import cPickle as pickle
else:
import pickle
def prepare_data(data):
    """Reorder a collated batch by caption length (longest first) and wrap it.

    Args:
        data: ``(imgs, captions, captions_lens, class_ids, keys)``, matching
            what ``TextDataset.__getitem__`` returns after collation; ``imgs``
            is a list with one image batch per resolution.

    Returns:
        ``[real_imgs, captions, sorted_cap_lens, class_ids, keys]``. Every
        entry is permuted into descending caption-length order. Tensors are
        wrapped in ``Variable`` and moved to the GPU when ``cfg.CUDA`` is set.
        ``class_ids`` comes back as a NumPy array and ``keys`` as a list.
        The entries of the input ``imgs`` list are replaced in place by their
        reordered versions.
    """
    imgs, captions, captions_lens, class_ids, keys = data

    # Descending sort; ``order[k]`` is the old batch index of new position k.
    sorted_cap_lens, order = torch.sort(captions_lens, 0, True)

    def to_var(tensor):
        wrapped = Variable(tensor)
        return wrapped.cuda() if cfg.CUDA else wrapped

    real_imgs = []
    for level, batch in enumerate(imgs):
        imgs[level] = batch[order]
        real_imgs.append(to_var(imgs[level]))

    captions = captions[order].squeeze()
    class_ids = class_ids[order].numpy()
    keys = [keys[idx] for idx in order.numpy()]

    return [real_imgs, to_var(captions), to_var(sorted_cap_lens),
            class_ids, keys]
def get_imgs(img_path, imsize, bbox=None,
             transform=None, normalize=None):
    """Load an RGB image and return normalised copies for each branch.

    Args:
        img_path: path of the image file.
        imsize: target size per generator branch; all but the last branch are
            resized to ``imsize[i]``, and the last keeps the transformed image.
        bbox: optional ``[x, y, width, height]``. If given, the image is cropped
            to a square of side 1.5 * max(width, height) centred on the box and
            clipped to the image bounds.
        transform: optional transform applied after cropping.
        normalize: transform applied to every returned image (required).

    Returns:
        A list with one normalised image per branch, or a single image when
        ``cfg.GAN.B_DCGAN`` is set.
    """
    img = Image.open(img_path).convert('RGB')
    width, height = img.size
    if bbox is not None:
        r = int(np.maximum(bbox[2], bbox[3]) * 0.75)
        center_x = int((2 * bbox[0] + bbox[2]) / 2)
        center_y = int((2 * bbox[1] + bbox[3]) / 2)
        y1 = np.maximum(0, center_y - r)
        y2 = np.minimum(height, center_y + r)
        x1 = np.maximum(0, center_x - r)
        x2 = np.minimum(width, center_x + r)
        img = img.crop([x1, y1, x2, y2])

    if transform is not None:
        img = transform(img)

    ret = []
    if cfg.GAN.B_DCGAN:
        ret = [normalize(img)]
    else:
        # torchvision renamed transforms.Scale to Resize and later removed
        # Scale; use Resize when available, Scale on old releases.
        resize = getattr(transforms, 'Resize', None) or transforms.Scale
        for i in range(cfg.TREE.BRANCH_NUM):
            if i < (cfg.TREE.BRANCH_NUM - 1):
                re_img = resize(imsize[i])(img)
            else:
                re_img = img
            ret.append(normalize(re_img))

    return ret
class TextDataset(data.Dataset):
    """Dataset of images paired with tokenised captions.

    Each item is ``(imgs, caps, cap_len, cls_id, key)``:
    - ``imgs`` holds one normalised image per generator branch (see ``get_imgs``).
    - ``caps`` is a ``cfg.TEXT.WORDS_NUM x 1`` int64 array holding one
      randomly chosen caption of the image.
    - ``cap_len`` is that caption's length.
    - ``cls_id`` is the class id.
    - ``key`` is the image file name without extension.

    When ``data_dir`` contains 'birds', CUB-200-2011 bounding boxes are used
    to crop the images. Encoded captions and the vocabulary are cached in
    ``<data_dir>/bird_captions.pickle``.
    """

    def __init__(self, data_dir, split='train',
                 base_size=64,
                 transform=None, target_transform=None):
        self.transform = transform
        self.norm = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        self.target_transform = target_transform
        self.embeddings_num = cfg.TEXT.CAPTIONS_PER_IMAGE

        # One image size per generator branch, doubling from base_size.
        self.imsize = []
        for i in range(cfg.TREE.BRANCH_NUM):
            self.imsize.append(base_size)
            base_size = base_size * 2

        self.data = []
        self.data_dir = data_dir
        if data_dir.find('birds') != -1:
            self.bbox = self.load_bbox()
        else:
            self.bbox = None
        split_dir = os.path.join(data_dir, split)

        self.filenames, self.captions, self.ixtoword, \
            self.wordtoix, self.n_words = self.load_text_data(data_dir, split)

        self.class_id = self.load_class_id(split_dir, len(self.filenames))
        self.number_example = len(self.filenames)

    def load_bbox(self):
        """Return {image path without extension: [x-left, y-top, width, height]}."""
        data_dir = self.data_dir
        bbox_path = os.path.join(data_dir, 'CUB_200_2011/bounding_boxes.txt')
        # sep=r'\s+' is the non-deprecated equivalent of delim_whitespace=True.
        df_bounding_boxes = pd.read_csv(bbox_path,
                                        sep=r'\s+',
                                        header=None).astype(int)

        filepath = os.path.join(data_dir, 'CUB_200_2011/images.txt')
        df_filenames = pd.read_csv(filepath, sep=r'\s+', header=None)
        filenames = df_filenames[1].tolist()
        print('Total filenames: ', len(filenames), filenames[0])

        filename_bbox = {img_file[:-4]: [] for img_file in filenames}
        # range, not the Python-2-only xrange, so this also runs on Python 3.
        # Assumes bounding_boxes.txt lists images in the same order as
        # images.txt; its first column (skipped via [1:]) is presumably the id.
        for i in range(len(filenames)):
            # bbox = [x-left, y-top, width, height]
            bbox = df_bounding_boxes.iloc[i][1:].tolist()
            key = filenames[i][:-4]
            filename_bbox[key] = bbox
        return filename_bbox

    def load_captions(self, data_dir, filenames):
        """Read up to ``embeddings_num`` tokenised captions per file.

        Returns a flat list of token lists, in file order.
        """
        import io

        all_captions = []
        # Picks out sequences of alphanumeric characters as tokens and drops
        # everything else. The tokenizer is stateless, so build it once.
        tokenizer = RegexpTokenizer(r'\w+')
        for i in range(len(filenames)):
            cap_path = '%s/text/%s.txt' % (data_dir, filenames[i])
            # io.open with an explicit encoding yields decoded text on both
            # Python 2 and 3 (str.decode does not exist on Python 3).
            with io.open(cap_path, "r", encoding='utf8') as f:
                captions = f.read().split('\n')
            cnt = 0
            for cap in captions:
                if len(cap) == 0:
                    continue
                cap = cap.replace("\ufffd\ufffd", " ")
                tokens = tokenizer.tokenize(cap.lower())
                if len(tokens) == 0:
                    print('cap', cap)
                    continue

                tokens_new = []
                for t in tokens:
                    t = t.encode('ascii', 'ignore').decode('ascii')
                    if len(t) > 0:
                        tokens_new.append(t)
                all_captions.append(tokens_new)
                cnt += 1
                if cnt == self.embeddings_num:
                    break
            if cnt < self.embeddings_num:
                print('ERROR: the captions for %s less than %d'
                      % (filenames[i], cnt))
        return all_captions

    def build_dictionary(self, train_captions, test_captions):
        """Build the vocabulary and encode both caption sets as index lists.

        Index 0 is reserved for '<end>'. Every word that occurs is kept
        (the count threshold is 0).
        """
        word_counts = defaultdict(float)
        captions = train_captions + test_captions
        for sent in captions:
            for word in sent:
                word_counts[word] += 1

        vocab = [w for w in word_counts if word_counts[w] >= 0]

        ixtoword = {}
        ixtoword[0] = '<end>'
        wordtoix = {}
        wordtoix['<end>'] = 0
        ix = 1
        for w in vocab:
            wordtoix[w] = ix
            ixtoword[ix] = w
            ix += 1

        train_captions_new = []
        for t in train_captions:
            rev = []
            for w in t:
                if w in wordtoix:
                    rev.append(wordtoix[w])
            # rev.append(0)  # do not need '<end>' token
            train_captions_new.append(rev)

        test_captions_new = []
        for t in test_captions:
            rev = []
            for w in t:
                if w in wordtoix:
                    rev.append(wordtoix[w])
            # rev.append(0)  # do not need '<end>' token
            test_captions_new.append(rev)

        return [train_captions_new, test_captions_new,
                ixtoword, wordtoix, len(ixtoword)]

    def load_text_data(self, data_dir, split):
        """Return (filenames, captions, ixtoword, wordtoix, n_words) for ``split``.

        The encoded captions and vocabulary are built once and cached in
        ``bird_captions.pickle``.
        """
        filepath = os.path.join(data_dir, 'bird_captions.pickle')
        train_names = self.load_filenames(data_dir, 'train')
        test_names = self.load_filenames(data_dir, 'test')
        if not os.path.isfile(filepath):
            train_captions = self.load_captions(data_dir, train_names)
            test_captions = self.load_captions(data_dir, test_names)

            train_captions, test_captions, ixtoword, wordtoix, n_words = \
                self.build_dictionary(train_captions, test_captions)
            with open(filepath, 'wb') as f:
                pickle.dump([train_captions, test_captions,
                             ixtoword, wordtoix], f, protocol=2)
                print('Save to: ', filepath)
        else:
            # NOTE: pickle.load can execute code; only load trusted cache files.
            with open(filepath, 'rb') as f:
                x = pickle.load(f)
                train_captions, test_captions = x[0], x[1]
                ixtoword, wordtoix = x[2], x[3]
                del x
                n_words = len(ixtoword)
                print('Load from: ', filepath)
        if split == 'train':
            # a list of list: each list contains
            # the indices of words in a sentence
            captions = train_captions
            filenames = train_names
        else:  # split=='test'
            captions = test_captions
            filenames = test_names
        return filenames, captions, ixtoword, wordtoix, n_words

    def load_class_id(self, data_dir, total_num):
        """Load class ids from class_info.pickle, else use 0..total_num-1."""
        if os.path.isfile(data_dir + '/class_info.pickle'):
            with open(data_dir + '/class_info.pickle', 'rb') as f:
                class_id = pickle.load(f)
        else:
            class_id = np.arange(total_num)
        return class_id

    def load_filenames(self, data_dir, split):
        """Load the split's image keys from filenames.pickle ([] if missing)."""
        filepath = '%s/%s/filenames.pickle' % (data_dir, split)
        if os.path.isfile(filepath):
            with open(filepath, 'rb') as f:
                filenames = pickle.load(f)
            print('Load filenames from: %s (%d)' % (filepath, len(filenames)))
        else:
            filenames = []
        return filenames

    def get_caption(self, sent_ix):
        """Return (x, x_len): caption ``sent_ix`` as a WORDS_NUM x 1 int64 array.

        Shorter captions are zero-padded (0 is '<end>'). Longer ones keep a
        random, order-preserving subset of WORDS_NUM words.
        """
        # a list of indices for a sentence
        sent_caption = np.asarray(self.captions[sent_ix]).astype('int64')
        num_words = len(sent_caption)
        # pad with 0s (i.e., '<end>')
        x = np.zeros((cfg.TEXT.WORDS_NUM, 1), dtype='int64')
        x_len = num_words
        if num_words <= cfg.TEXT.WORDS_NUM:
            x[:num_words, 0] = sent_caption
        else:
            ix = list(np.arange(num_words))  # 1, 2, 3,..., maxNum
            np.random.shuffle(ix)
            ix = ix[:cfg.TEXT.WORDS_NUM]
            ix = np.sort(ix)
            x[:, 0] = sent_caption[ix]
            x_len = cfg.TEXT.WORDS_NUM
        return x, x_len

    def __getitem__(self, index):
        key = self.filenames[index]
        cls_id = self.class_id[index]

        if self.bbox is not None:
            bbox = self.bbox[key]
            data_dir = '%s/CUB_200_2011' % self.data_dir
        else:
            bbox = None
            data_dir = self.data_dir

        img_name = '%s/images/%s.jpg' % (data_dir, key)
        imgs = get_imgs(img_name, self.imsize,
                        bbox, self.transform, normalize=self.norm)
        # random select a sentence (numpy randint: upper bound exclusive)
        sent_ix = random.randint(0, self.embeddings_num)
        new_sent_ix = index * self.embeddings_num + sent_ix
        caps, cap_len = self.get_caption(new_sent_ix)
        return imgs, caps, cap_len, cls_id, key

    def __len__(self):
        return len(self.filenames)
================================================
FILE: do_test.sh
================================================
#!/usr/bin/env bash
# Evaluate a trained MirrorGAN model on CUB birds.
# Extra arguments (e.g. --gpu 0 --data_dir ...) are forwarded to main.py.
cfg=cfg/eval_bird.yml
python main.py --cfg "$cfg" "$@"
================================================
FILE: do_train.sh
================================================
#!/usr/bin/env bash
# Train MirrorGAN on CUB birds.
# Extra arguments (e.g. --gpu 0 --manualSeed 1) are forwarded to main.py.
cfg=cfg/train_bird.yml
python main.py --cfg "$cfg" "$@"
================================================
FILE: main.py
================================================
from __future__ import print_function
from cfg.config import cfg, cfg_from_file
from datasets import TextDataset
from trainer import Trainer as trainer
import os
import sys
import time
import random
import pprint
import datetime
import dateutil.tz
import argparse
import numpy as np
import torch
import torchvision.transforms as transforms
# Directory containing this script, added to the import path. (Joining './.'
# onto the *file* path just normalises back to the file itself.)
dir_path = os.path.dirname(os.path.abspath(os.path.realpath(__file__)))
sys.path.append(dir_path)
def parse_args():
    """Parse command-line options for main.py.

    Options:
        --cfg: YAML config file (default: cfg/train_bird.yml, which ships with
            the repository).
        --gpu: GPU id; -1 (default) keeps the config's GPU_ID.
        --data_dir: overrides cfg.DATA_DIR when non-empty.
        --manualSeed: random seed (training only; evaluation uses a fixed seed).
    """
    parser = argparse.ArgumentParser(description='Train a AttnGAN network')
    parser.add_argument('--cfg', dest='cfg_file',
                        help='optional config file',
                        default='cfg/train_bird.yml', type=str)
    parser.add_argument('--gpu', dest='gpu_id', type=int, default=-1)
    parser.add_argument('--data_dir', dest='data_dir', type=str, default='')
    parser.add_argument('--manualSeed', type=int, help='manual seed')
    args = parser.parse_args()
    return args
def gen_example(wordtoix, algo):
    '''Generate images from the example sentences listed under cfg.DATA_DIR.

    ``<DATA_DIR>/example_filenames.txt`` names, one per line, text files
    (relative to DATA_DIR, without ``.txt``) holding one caption per line.
    Each caption is tokenised and mapped through ``wordtoix`` (unknown words
    are dropped), and captions are sorted by length, longest first.
    ``algo.gen_example`` then receives ``{key: [cap_array, cap_lens,
    sorted_indices]}``, where ``key`` is the file name after its last '/'.
    Files without any usable caption are skipped.
    '''
    import io
    from nltk.tokenize import RegexpTokenizer

    # The tokenizer is stateless: build it once for every file.
    tokenizer = RegexpTokenizer(r'\w+')

    filepath = '%s/example_filenames.txt' % (cfg.DATA_DIR)
    # io.open with an explicit encoding yields decoded text on Python 2 and 3
    # (str.decode does not exist on Python 3).
    with io.open(filepath, "r", encoding='utf8') as f:
        filenames = f.read().split('\n')

    data_dic = {}
    for name in filenames:
        if len(name) == 0:
            continue
        filepath = '%s/%s.txt' % (cfg.DATA_DIR, name)
        with io.open(filepath, "r", encoding='utf8') as f:
            print('Load from:', name)
            sentences = f.read().split('\n')

        # a list of indices for a sentence
        captions = []
        cap_lens = []
        for sent in sentences:
            if len(sent) == 0:
                continue
            sent = sent.replace("\ufffd\ufffd", " ")
            tokens = tokenizer.tokenize(sent.lower())
            if len(tokens) == 0:
                print('sent', sent)
                continue

            rev = []
            for t in tokens:
                t = t.encode('ascii', 'ignore').decode('ascii')
                if len(t) > 0 and t in wordtoix:
                    rev.append(wordtoix[t])
            captions.append(rev)
            cap_lens.append(len(rev))

        if not captions:
            # np.max() below raises ValueError on an empty list.
            print('No captions found in:', name)
            continue

        max_len = np.max(cap_lens)
        sorted_indices = np.argsort(cap_lens)[::-1]
        cap_lens = np.asarray(cap_lens)
        cap_lens = cap_lens[sorted_indices]
        cap_array = np.zeros((len(captions), max_len), dtype='int64')
        for i in range(len(captions)):
            idx = sorted_indices[i]
            cap = captions[idx]
            c_len = len(cap)
            cap_array[i, :c_len] = cap
        key = name[(name.rfind('/') + 1):]
        data_dic[key] = [cap_array, cap_lens, sorted_indices]
    algo.gen_example(data_dic)
if __name__ == "__main__":
    args = parse_args()
    if args.cfg_file is not None:
        cfg_from_file(args.cfg_file)

    # --gpu is parsed by parse_args(); apply it rather than silently ignoring
    # it. The default of -1 keeps the GPU_ID from the config file.
    if args.gpu_id != -1:
        cfg.GPU_ID = args.gpu_id

    if args.data_dir != '':
        cfg.DATA_DIR = args.data_dir
    print('Using config:')
    pprint.pprint(cfg)

    # Evaluation always uses seed 100 so sampled images are reproducible;
    # training uses --manualSeed or a random seed.
    if not cfg.TRAIN.FLAG:
        args.manualSeed = 100
    elif args.manualSeed is None:
        args.manualSeed = random.randint(1, 10000)
    random.seed(args.manualSeed)
    np.random.seed(args.manualSeed)
    torch.manual_seed(args.manualSeed)
    if cfg.CUDA:
        torch.cuda.manual_seed_all(args.manualSeed)

    now = datetime.datetime.now(dateutil.tz.tzlocal())
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
    output_dir = '%s/output/%s_%s_%s' % \
        (cfg.OUTPUT_PATH, cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp)

    split_dir, bshuffle = 'train', True
    if not cfg.TRAIN.FLAG:
        # bshuffle = False
        split_dir = 'test'

    # Get data loader; images are loaded at the largest branch resolution.
    imsize = cfg.TREE.BASE_SIZE * (2 ** (cfg.TREE.BRANCH_NUM - 1))
    # torchvision renamed transforms.Scale to Resize and later removed Scale.
    resize = getattr(transforms, 'Resize', None) or transforms.Scale
    image_transform = transforms.Compose([
        resize(int(imsize * 76 / 64)),
        transforms.RandomCrop(imsize),
        transforms.RandomHorizontalFlip()])
    dataset = TextDataset(cfg.DATA_DIR, split_dir,
                          base_size=cfg.TREE.BASE_SIZE,
                          transform=image_transform)
    assert dataset
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=cfg.TRAIN.BATCH_SIZE,
        drop_last=True, shuffle=bshuffle, num_workers=int(cfg.WORKERS))

    # Define models and go to train/evaluate
    algo = trainer(output_dir, dataloader, dataset.n_words, dataset.ixtoword)

    start_t = time.time()
    if cfg.TRAIN.FLAG:
        algo.train()
    else:
        '''generate images from pre-extracted embeddings'''
        if cfg.B_VALIDATION:
            algo.sampling(split_dir)  # generate images for the whole valid dataset
        else:
            gen_example(dataset.wordtoix, algo)  # generate images for customized captions
    end_t = time.time()
    print('Total time for training:', end_t - start_t)
================================================
FILE: miscc/__init__.py
================================================
from __future__ import division
from __future__ import print_function
================================================
FILE: miscc/losses.py
================================================
import torch
import torch.nn as nn
import numpy as np
from cfg.config import cfg
from torch.nn.utils.rnn import pack_padded_sequence
from GLAttention import func_attention
# ##################Loss for matching text-image###################
def cosine_similarity(x1, x2, dim=1, eps=1e-8):
    """Cosine similarity of ``x1`` and ``x2`` computed along ``dim``.

    The product of the two L2 norms is clamped from below at ``eps`` so zero
    vectors yield 0 instead of NaN. Singleton dimensions of the result are
    squeezed away.
    """
    dot = (x1 * x2).sum(dim)
    denom = torch.norm(x1, 2, dim) * torch.norm(x2, 2, dim)
    return (dot / denom.clamp(min=eps)).squeeze()
def caption_loss(cap_output, captions):
    """Mean cross-entropy between caption logits and target word indices."""
    return nn.CrossEntropyLoss()(cap_output, captions)
def sent_loss(cnn_code, rnn_code, labels, class_ids,
              batch_size, eps=1e-8):
    """Sentence-level matching loss between image and sentence codes.

    Args:
        cnn_code, rnn_code: batch_size x nef image / sentence codes (a leading
            dimension is added when they are 2-D).
        labels: target index per row for CrossEntropyLoss, or None.
        class_ids: per-sample class ids (NumPy array) or None. Other samples of
            the same class are masked out of the candidate scores.
        batch_size: number of samples in the batch.
        eps: lower bound on the norm product to avoid division by zero.

    Returns:
        (loss0, loss1): cross-entropy over the image->sentence cosine scores
        (scaled by GAMMA3) and over their transpose, or (None, None) when
        ``labels`` is None.
    """
    # ### Mask mis-match samples ###
    # that come from the same class as the real sample ###
    masks = []
    if class_ids is not None:
        for i in range(batch_size):
            mask = (class_ids == class_ids[i]).astype(np.uint8)
            mask[i] = 0
            masks.append(mask.reshape((1, -1)))
        masks = np.concatenate(masks, 0)
        # masks: batch_size x batch_size
        masks = torch.ByteTensor(masks)
        # masked_fill_ takes a bool mask and newer PyTorch rejects uint8 masks.
        # Tensors gained .bool() in PyTorch 1.2, so convert when available.
        if hasattr(masks, 'bool'):
            masks = masks.bool()
        if cfg.CUDA:
            masks = masks.cuda()

    # --> seq_len x batch_size x nef
    if cnn_code.dim() == 2:
        cnn_code = cnn_code.unsqueeze(0)
        rnn_code = rnn_code.unsqueeze(0)

    # cnn_code_norm / rnn_code_norm: seq_len x batch_size x 1
    cnn_code_norm = torch.norm(cnn_code, 2, dim=2, keepdim=True)
    rnn_code_norm = torch.norm(rnn_code, 2, dim=2, keepdim=True)
    # scores* / norm*: seq_len x batch_size x batch_size
    scores0 = torch.bmm(cnn_code, rnn_code.transpose(1, 2))
    norm0 = torch.bmm(cnn_code_norm, rnn_code_norm.transpose(1, 2))
    scores0 = scores0 / norm0.clamp(min=eps) * cfg.TRAIN.SMOOTH.GAMMA3

    # --> batch_size x batch_size
    scores0 = scores0.squeeze()
    if class_ids is not None:
        scores0.data.masked_fill_(masks, -float('inf'))
    scores1 = scores0.transpose(0, 1)
    if labels is not None:
        loss0 = nn.CrossEntropyLoss()(scores0, labels)
        loss1 = nn.CrossEntropyLoss()(scores1, labels)
    else:
        loss0, loss1 = None, None
    return loss0, loss1
def words_loss(img_features, words_emb, labels,
               cap_lens, class_ids, batch_size):
    """Word-level matching loss between image regions and caption words.

    words_emb(query): batch x nef x seq_len
    img_features(context): batch x nef x 17 x 17

    For each caption i, the words of caption i attend over the regions of
    every image in the batch (func_attention). Their cosine similarities with
    the attended region context are aggregated (Eq. 10) into
    similarities[j, i] for image j.

    Args:
        labels: target index per row for CrossEntropyLoss, or None.
        cap_lens: true length of each caption.
        class_ids: per-sample class ids (NumPy array) or None. Pairs from the
            same class (other than the sample itself) are masked out.
        batch_size: number of samples in the batch.

    Returns:
        (loss0, loss1, att_maps): cross-entropy over the similarity matrix and
        its transpose (None, None when ``labels`` is None), plus a list with
        caption i's attention map over image i for every i.
    """
    masks = []
    att_maps = []
    similarities = []
    cap_lens = cap_lens.data.tolist()
    for i in range(batch_size):
        if class_ids is not None:
            mask = (class_ids == class_ids[i]).astype(np.uint8)
            mask[i] = 0
            masks.append(mask.reshape((1, -1)))
        # Get the i-th text description
        words_num = cap_lens[i]
        # -> 1 x nef x words_num
        word = words_emb[i, :, :words_num].unsqueeze(0).contiguous()
        # -> batch_size x nef x words_num
        word = word.repeat(batch_size, 1, 1)
        # batch x nef x 17*17
        context = img_features
        # word(query): batch x nef x words_num
        # context: batch x nef x 17 x 17
        # weiContext: batch x nef x words_num
        # attn: batch x words_num x 17 x 17
        weiContext, attn = func_attention(word, context, cfg.TRAIN.SMOOTH.GAMMA1)
        att_maps.append(attn[i].unsqueeze(0).contiguous())
        # --> batch_size x words_num x nef
        word = word.transpose(1, 2).contiguous()
        weiContext = weiContext.transpose(1, 2).contiguous()
        # --> batch_size*words_num x nef
        word = word.view(batch_size * words_num, -1)
        weiContext = weiContext.view(batch_size * words_num, -1)

        # -->batch_size*words_num
        row_sim = cosine_similarity(word, weiContext)
        # --> batch_size x words_num
        row_sim = row_sim.view(batch_size, words_num)

        # Eq. (10)
        row_sim.mul_(cfg.TRAIN.SMOOTH.GAMMA2).exp_()
        row_sim = row_sim.sum(dim=1, keepdim=True)
        row_sim = torch.log(row_sim)

        # --> batch_size x 1: similarity of every image with caption i
        similarities.append(row_sim)

    # batch_size x batch_size; similarities[j, i] pairs image j with caption i
    similarities = torch.cat(similarities, 1)
    if class_ids is not None:
        masks = np.concatenate(masks, 0)
        # masks: batch_size x batch_size
        masks = torch.ByteTensor(masks)
        # masked_fill_ takes a bool mask and newer PyTorch rejects uint8 masks.
        # Tensors gained .bool() in PyTorch 1.2, so convert when available.
        if hasattr(masks, 'bool'):
            masks = masks.bool()
        if cfg.CUDA:
            masks = masks.cuda()

    similarities = similarities * cfg.TRAIN.SMOOTH.GAMMA3
    if class_ids is not None:
        similarities.data.masked_fill_(masks, -float('inf'))
    similarities1 = similarities.transpose(0, 1)
    if labels is not None:
        loss0 = nn.CrossEntropyLoss()(similarities, labels)
        loss1 = nn.CrossEntropyLoss()(similarities1, labels)
    else:
        loss0, loss1 = None, None
    return loss0, loss1, att_maps
# ##################Loss for G and Ds##############################
def discriminator_loss(netD, real_imgs, fake_imgs, conditions,
                       real_labels, fake_labels):
    """Binary cross-entropy loss for one discriminator.

    Conditional terms score (real image, condition) as real and
    (fake image, condition) as fake. A mismatched term scores image k paired
    with condition k + 1 as fake. If ``netD.UNCOND_DNET`` exists, the
    unconditional real/fake terms are averaged in as well. Fake images are
    detached so no gradient reaches the generator.
    """
    bce = nn.BCELoss()
    real_feat = netD(real_imgs)
    fake_feat = netD(fake_imgs.detach())
    cond_net = netD.COND_DNET

    cond_real = bce(cond_net(real_feat, conditions), real_labels)
    cond_fake = bce(cond_net(fake_feat, conditions), fake_labels)
    n = real_feat.size(0)
    # Mismatched pairs: image k with the condition of sample k + 1.
    cond_wrong = bce(cond_net(real_feat[:(n - 1)], conditions[1:n]),
                     fake_labels[1:n])

    if netD.UNCOND_DNET is None:
        return cond_real + (cond_fake + cond_wrong) / 2.

    uncond_real = bce(netD.UNCOND_DNET(real_feat), real_labels)
    uncond_fake = bce(netD.UNCOND_DNET(fake_feat), fake_labels)
    return ((uncond_real + cond_real) / 2. +
            (uncond_fake + cond_fake + cond_wrong) / 3.)
def generator_loss(netsD, image_encoder, caption_cnn, caption_rnn, captions, fake_imgs, real_labels,
                   words_embs, sent_emb, match_labels,
                   cap_lens, class_ids):
    """Total generator loss and a printable log string.

    For each discriminator i, adds the conditional (on ``sent_emb``) and, if
    present, unconditional BCE losses of ``fake_imgs[i]`` against
    ``real_labels``. For the last (highest-resolution) output, it also adds the
    caption reconstruction loss: ``caption_cnn`` / ``caption_rnn`` re-caption
    the fake image and the result is compared with ``captions``, weighted by
    ``cfg.TRAIN.SMOOTH.LAMBDA1``.

    ``image_encoder``, ``words_embs``, ``match_labels`` and ``class_ids`` are
    accepted for interface compatibility but are not used here.

    Returns:
        (errG_total, logs)
    """
    def _scalar(t):
        # .item() exists from PyTorch 0.4 on; .data[0] raises on 0-dim
        # tensors there, so only use it as a legacy fallback.
        return t.item() if hasattr(t, 'item') else t.data[0]

    numDs = len(netsD)
    logs = ''
    # Forward
    errG_total = 0
    for i in range(numDs):
        features = netsD[i](fake_imgs[i])
        cond_logits = netsD[i].COND_DNET(features, sent_emb)
        cond_errG = nn.BCELoss()(cond_logits, real_labels)
        if netsD[i].UNCOND_DNET is not None:
            logits = netsD[i].UNCOND_DNET(features)
            errG = nn.BCELoss()(logits, real_labels)
            g_loss = errG + cond_errG
        else:
            g_loss = cond_errG
        errG_total += g_loss
        logs += 'g_loss%d: %.2f ' % (i, _scalar(g_loss))

        if i == (numDs - 1):
            fakeimg_feature = caption_cnn(fake_imgs[i])
            # Packed targets line up with the caption decoder's packed output.
            target_cap = pack_padded_sequence(captions, cap_lens.data.tolist(),
                                              batch_first=True)[0]
            if cfg.CUDA:
                target_cap = target_cap.cuda()
            cap_output = caption_rnn(fakeimg_feature, captions, cap_lens)
            cap_loss = caption_loss(cap_output, target_cap) * cfg.TRAIN.SMOOTH.LAMBDA1

            errG_total += cap_loss
            logs += 'cap_loss: %.2f, ' % _scalar(cap_loss)

    return errG_total, logs
##################################################################
def KL_loss(mu, logvar):
    """KL divergence of N(mu, exp(logvar)) from N(0, 1), averaged over elements.

    Computes -0.5 * mean(1 + logvar - mu^2 - exp(logvar)). The input tensors
    are not modified.
    """
    variance = logvar.exp()
    per_element = (1 - (mu.pow(2) + variance)) + logvar
    return torch.mean(per_element) * -0.5
================================================
FILE: miscc/utils.py
================================================
import os
import errno
import numpy as np
from torch.nn import init
import torch
import torch.nn as nn
from PIL import Image, ImageDraw, ImageFont
from copy import deepcopy
import skimage.transform
from cfg.config import cfg
# For visualization ################################################
# RGB background colour for each word column of the caption canvas built in
# build_super_images (key = word position in the caption).
# NOTE(review): only positions 0-19 are defined, while build_super_images
# indexes COLOR_DIC[i] for i < max_word_num (cfg.TEXT.WORDS_NUM, 25 in
# eval_bird.yml) -- confirm this cannot raise KeyError.
COLOR_DIC = {0:[128,64,128], 1:[244, 35,232],
             2:[70, 70, 70], 3:[102,102,156],
             4:[190,153,153], 5:[153,153,153],
             6:[250,170, 30], 7:[220, 220, 0],
             8:[107,142, 35], 9:[152,251,152],
             10:[70,130,180], 11:[220,20, 60],
             12:[255, 0, 0], 13:[0, 0, 142],
             14:[119,11, 32], 15:[0, 60,100],
             16:[0, 80, 100], 17:[0, 0, 230],
             18:[0, 0, 70], 19:[0, 0, 0]}
# Pixel height of one caption row on the text canvas (drawCaption's y-step).
FONT_MAX = 50
def drawCaption(convas, captions, ixtoword, vis_size, off1=2, off2=2):
    """Draw the words of each caption onto an image built from ``convas``.

    Args:
        convas: background array (build_super_images passes an H x W x 3
            uint8 canvas).
        captions: tensor of word indices, one caption per row; each caption
            stops at its first 0 index.
        ixtoword: mapping from word index to word.
        vis_size, off1, off2: word j of caption i is drawn as "j:<first 6
            chars>" at x = (j + off1) * (vis_size + off2), y = i * FONT_MAX.

    Returns:
        (img_txt, sentence_list): the PIL image with the text drawn, and the
        list of (ASCII-filtered) words of each caption.
    """
    num = captions.size(0)
    img_txt = Image.fromarray(convas)
    # get a font. The TrueType path is relative to the working directory and
    # is often missing; fall back to PIL's built-in font instead of crashing.
    try:
        fnt = ImageFont.truetype('Pillow/Tests/fonts/FreeMono.ttf', 50)
    except (IOError, OSError):
        fnt = ImageFont.load_default()
    # get a drawing context
    d = ImageDraw.Draw(img_txt)
    sentence_list = []
    for i in range(num):
        cap = captions[i].data.cpu().numpy()
        sentence = []
        for j in range(len(cap)):
            if cap[j] == 0:
                break
            word = ixtoword[cap[j]].encode('ascii', 'ignore').decode('ascii')
            d.text(((j + off1) * (vis_size + off2), i * FONT_MAX), '%d:%s' % (j, word[:6]),
                   font=fnt, fill=(255, 255, 255, 255))
            sentence.append(word)
        sentence_list.append(sentence)
    return img_txt, sentence_list
def build_super_images(real_imgs, captions, ixtoword,
                       attn_maps, att_sze, lr_imgs=None,
                       batch_size=cfg.TRAIN.BATCH_SIZE,
                       max_word_num=cfg.TEXT.WORDS_NUM):
    """Build a grid visualising per-word attention maps for up to 8 samples.

    For each sample three strips are stacked vertically: the caption words,
    the (low-res or real) image followed by each normalised attention map
    (the first map is the per-pixel max over all maps), and the real image
    blended with each map.

    Args:
        real_imgs: image tensor, batch x c x h x w, with values in [-1, 1];
            ``.data.numpy()`` is called on it, so it must be on the CPU.
        captions: tensor of word indices, one caption per row.
        ixtoword: mapping from word index to word string.
        attn_maps: per-sample attention tensors, each viewable as
            1 x n x att_sze x att_sze.
        att_sze: spatial side length of the attention maps.
        lr_imgs: optional low-resolution images shown in the first column
            instead of ``real_imgs``.
        batch_size: number of caption rows on the text canvas.
        max_word_num: number of word columns on the text canvas.

    Returns:
        (uint8 H x W x 3 array, list of caption word lists), or None when a
        text strip and an image strip disagree in width (or nothing was drawn).
    """
    nvis = 8
    real_imgs = real_imgs[:nvis]
    if lr_imgs is not None:
        lr_imgs = lr_imgs[:nvis]
    if att_sze == 17:
        vis_size = att_sze * 16
    else:
        vis_size = real_imgs.size(2)

    text_convas = \
        np.ones([batch_size * FONT_MAX,
                 (max_word_num + 2) * (vis_size + 2), 3],
                dtype=np.uint8)
    for i in range(max_word_num):
        istart = (i + 2) * (vis_size + 2)
        iend = (i + 3) * (vis_size + 2)
        # Cycle the palette so captions with more than len(COLOR_DIC) words
        # do not raise KeyError.
        text_convas[:, istart:iend, :] = COLOR_DIC[i % len(COLOR_DIC)]

    real_imgs = \
        nn.Upsample(size=(vis_size, vis_size), mode='bilinear')(real_imgs)
    # [-1, 1] --> [0, 255]
    real_imgs.add_(1).div_(2).mul_(255)
    real_imgs = real_imgs.data.numpy()
    # b x c x h x w --> b x h x w x c
    real_imgs = np.transpose(real_imgs, (0, 2, 3, 1))
    pad_sze = real_imgs.shape
    middle_pad = np.zeros([pad_sze[2], 2, 3])
    post_pad = np.zeros([pad_sze[1], pad_sze[2], 3])
    if lr_imgs is not None:
        lr_imgs = \
            nn.Upsample(size=(vis_size, vis_size), mode='bilinear')(lr_imgs)
        # [-1, 1] --> [0, 255]
        lr_imgs.add_(1).div_(2).mul_(255)
        lr_imgs = lr_imgs.data.numpy()
        # b x c x h x w --> b x h x w x c
        lr_imgs = np.transpose(lr_imgs, (0, 2, 3, 1))

    # batch x seq_len x 17 x 17 --> batch x 1 x 17 x 17
    seq_len = max_word_num
    img_set = []
    # Do not index past the samples actually supplied (batches smaller than
    # nvis previously raised IndexError).
    num = min(nvis, len(attn_maps), real_imgs.shape[0])
    text_map, sentences = \
        drawCaption(text_convas, captions, ixtoword, vis_size)
    text_map = np.asarray(text_map).astype(np.uint8)

    bUpdate = 1
    for i in range(num):
        attn = attn_maps[i].cpu().view(1, -1, att_sze, att_sze)
        # --> 1 x 1 x 17 x 17
        attn_max = attn.max(dim=1, keepdim=True)
        attn = torch.cat([attn_max[0], attn], 1)
        #
        attn = attn.view(-1, 1, att_sze, att_sze)
        attn = attn.repeat(1, 3, 1, 1).data.numpy()
        # n x c x h x w --> n x h x w x c
        attn = np.transpose(attn, (0, 2, 3, 1))
        num_attn = attn.shape[0]
        #
        img = real_imgs[i]
        if lr_imgs is None:
            lrI = img
        else:
            lrI = lr_imgs[i]
        row = [lrI, middle_pad]
        row_merge = [img, middle_pad]
        row_beforeNorm = []
        minVglobal, maxVglobal = 1, 0
        for j in range(num_attn):
            one_map = attn[j]
            if (vis_size // att_sze) > 1:
                one_map = \
                    skimage.transform.pyramid_expand(one_map, sigma=20,
                                                     upscale=vis_size // att_sze)
            row_beforeNorm.append(one_map)
            minV = one_map.min()
            maxV = one_map.max()
            if minVglobal > minV:
                minVglobal = minV
            if maxVglobal < maxV:
                maxVglobal = maxV
        # All maps share one min-max scale; guard the constant case, which
        # would otherwise produce 0/0 = NaN pixels.
        value_range = maxVglobal - minVglobal
        if value_range <= 0:
            value_range = 1
        for j in range(seq_len + 1):
            if j < num_attn:
                one_map = row_beforeNorm[j]
                one_map = (one_map - minVglobal) / value_range
                one_map *= 255
                #
                PIL_im = Image.fromarray(np.uint8(img))
                PIL_att = Image.fromarray(np.uint8(one_map))
                merged = \
                    Image.new('RGBA', (vis_size, vis_size), (0, 0, 0, 0))
                mask = Image.new('L', (vis_size, vis_size), (210))
                merged.paste(PIL_im, (0, 0))
                merged.paste(PIL_att, (0, 0), mask)
                merged = np.array(merged)[:, :, :3]
            else:
                # pad unused word slots so every row has the same width
                one_map = post_pad
                merged = post_pad
            row.append(one_map)
            row.append(middle_pad)
            #
            row_merge.append(merged)
            row_merge.append(middle_pad)
        row = np.concatenate(row, 1)
        row_merge = np.concatenate(row_merge, 1)
        txt = text_map[i * FONT_MAX: (i + 1) * FONT_MAX]
        if txt.shape[1] != row.shape[1]:
            print('txt', txt.shape, 'row', row.shape)
            bUpdate = 0
            break
        row = np.concatenate([txt, row, row_merge], 0)
        img_set.append(row)
    if bUpdate and img_set:
        img_set = np.concatenate(img_set, 0)
        img_set = img_set.astype(np.uint8)
        return img_set, sentences
    else:
        return None
def build_super_images2(real_imgs, captions, cap_lens, ixtoword,
                        attn_maps, att_sze, vis_size=256, topK=5):
    """Visualise the ``topK`` most confident word attention maps per sample.

    For each sample, every word's attention map is zeroed below
    ``2 / cap_len``, upsampled to ``vis_size`` and min-max normalised.  Words
    are ranked by their summed attention above ``4 / cap_len``; the ``topK``
    best are shown as a caption strip above the image blended with each map.

    Args:
        real_imgs: image tensor, batch x c x h x w, with values in [-1, 1];
            ``.data.numpy()`` is called on it, so it must be on the CPU.
        captions: tensor of word indices, one caption per row.
        cap_lens: per-sample caption lengths (``np.max`` is applied to it).
        ixtoword: mapping from word index to word string.
        attn_maps: per-sample attention tensors, each viewable as
            n x 1 x att_sze x att_sze with at least ``cap_lens[i]`` maps.
        att_sze: spatial side length of the attention maps.
        vis_size: side length in pixels of each rendered tile.
        topK: maximum number of words shown per sample.

    Returns:
        (uint8 H x W x 3 array, list of caption word lists), or None when a
        text strip and an image strip disagree in width.
    """
    batch_size = real_imgs.size(0)
    max_word_num = np.max(cap_lens)
    text_convas = np.ones([batch_size * FONT_MAX,
                           max_word_num * (vis_size + 2), 3],
                          dtype=np.uint8)

    real_imgs = \
        nn.Upsample(size=(vis_size, vis_size), mode='bilinear')(real_imgs)
    # [-1, 1] --> [0, 255]
    real_imgs.add_(1).div_(2).mul_(255)
    real_imgs = real_imgs.data.numpy()
    # b x c x h x w --> b x h x w x c
    real_imgs = np.transpose(real_imgs, (0, 2, 3, 1))
    pad_sze = real_imgs.shape
    middle_pad = np.zeros([pad_sze[2], 2, 3])

    # batch x seq_len x 17 x 17 --> batch x 1 x 17 x 17
    img_set = []
    num = len(attn_maps)
    text_map, sentences = \
        drawCaption(text_convas, captions, ixtoword, vis_size, off1=0)
    text_map = np.asarray(text_map).astype(np.uint8)

    bUpdate = 1
    for i in range(num):
        attn = attn_maps[i].cpu().view(1, -1, att_sze, att_sze)
        #
        attn = attn.view(-1, 1, att_sze, att_sze)
        attn = attn.repeat(1, 3, 1, 1).data.numpy()
        # n x c x h x w --> n x h x w x c
        attn = np.transpose(attn, (0, 2, 3, 1))
        num_attn = cap_lens[i]
        thresh = 2. / float(num_attn)
        #
        img = real_imgs[i]
        row = []
        row_merge = []
        row_txt = []
        row_beforeNorm = []
        conf_score = []
        for j in range(num_attn):
            one_map = attn[j]
            # word confidence = attention mass above twice the threshold
            mask0 = one_map > (2. * thresh)
            conf_score.append(np.sum(one_map * mask0))
            mask = one_map > thresh
            one_map = one_map * mask
            if (vis_size // att_sze) > 1:
                one_map = \
                    skimage.transform.pyramid_expand(one_map, sigma=20,
                                                     upscale=vis_size // att_sze)
            minV = one_map.min()
            maxV = one_map.max()
            # A constant map (e.g. nothing above the threshold) would give
            # 0/0 = NaN pixels; leave it at zero instead.
            one_map = one_map - minV
            if maxV > minV:
                one_map = one_map / (maxV - minV)
            row_beforeNorm.append(one_map)
        sorted_indices = np.argsort(conf_score)[::-1]

        for j in range(num_attn):
            one_map = row_beforeNorm[j]
            one_map *= 255
            #
            PIL_im = Image.fromarray(np.uint8(img))
            PIL_att = Image.fromarray(np.uint8(one_map))
            merged = \
                Image.new('RGBA', (vis_size, vis_size), (0, 0, 0, 0))
            mask = Image.new('L', (vis_size, vis_size), (180))  # (210)
            merged.paste(PIL_im, (0, 0))
            merged.paste(PIL_att, (0, 0), mask)
            merged = np.array(merged)[:, :, :3]

            row.append(np.concatenate([one_map, middle_pad], 1))
            #
            row_merge.append(np.concatenate([merged, middle_pad], 1))
            #
            txt = text_map[i * FONT_MAX:(i + 1) * FONT_MAX,
                           j * (vis_size + 2):(j + 1) * (vis_size + 2), :]
            row_txt.append(txt)
        # reorder by descending confidence
        row_new = []
        row_merge_new = []
        txt_new = []
        for j in range(num_attn):
            idx = sorted_indices[j]
            row_new.append(row[idx])
            row_merge_new.append(row_merge[idx])
            txt_new.append(row_txt[idx])
        row = np.concatenate(row_new[:topK], 1)
        row_merge = np.concatenate(row_merge_new[:topK], 1)
        txt = np.concatenate(txt_new[:topK], 1)
        if txt.shape[1] != row.shape[1]:
            # row_merge_new is a plain list (no .shape); report the
            # concatenated array instead of raising AttributeError.
            print('Warnings: txt', txt.shape, 'row', row.shape,
                  'row_merge', row_merge.shape)
            bUpdate = 0
            break
        row = np.concatenate([txt, row_merge], 0)
        img_set.append(row)
    if bUpdate:
        img_set = np.concatenate(img_set, 0)
        img_set = img_set.astype(np.uint8)
        return img_set, sentences
    else:
        return None
def weights_init(m):
    """Initialise one module's parameters in place (intended for ``net.apply``).

    * class name containing ``Conv``: orthogonal weight (gain 1.0); the bias
      is left untouched.
    * ``BatchNorm``: weight ~ N(1.0, 0.02), bias = 0.
    * ``Linear``: orthogonal weight (gain 1.0), bias = 0 when present.

    Parameters that are absent (e.g. ``BatchNorm(affine=False)`` has
    ``weight``/``bias`` set to None) are skipped instead of raising.
    """
    # Newer PyTorch names the in-place initialiser ``orthogonal_``; older
    # releases only have the (now deprecated) ``orthogonal``.
    orthogonal = getattr(nn.init, 'orthogonal_', None) or nn.init.orthogonal
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        if getattr(m, 'weight', None) is not None:
            orthogonal(m.weight.data, 1.0)
    elif classname.find('BatchNorm') != -1:
        if m.weight is not None:
            m.weight.data.normal_(1.0, 0.02)
        if m.bias is not None:
            m.bias.data.fill_(0)
    elif classname.find('Linear') != -1:
        orthogonal(m.weight.data, 1.0)
        if m.bias is not None:
            m.bias.data.fill_(0.0)
def load_params(model, new_param):
    """Copy each tensor in ``new_param`` into ``model``'s parameters, in place.

    Tensors are matched positionally with ``model.parameters()``; surplus
    entries on either side are ignored because ``zip`` stops early.
    """
    pairs = zip(model.parameters(), new_param)
    for target, source in pairs:
        target.data.copy_(source)
def copy_G_params(model):
    """Return an independent (deep-copied) list of ``model``'s parameter tensors.

    The snapshot keeps ``model.parameters()`` order and is unaffected by later
    in-place updates to the model.
    """
    snapshot = [param.data for param in model.parameters()]
    return deepcopy(snapshot)
def mkdir_p(path):
    """Create ``path`` and any missing parents, like ``mkdir -p``.

    An already-existing directory is accepted silently; every other
    ``OSError`` (including ``path`` existing as a non-directory) propagates.
    """
    try:
        os.makedirs(path)
    except OSError as err:  # Python >2.5
        already_a_dir = err.errno == errno.EEXIST and os.path.isdir(path)
        if not already_a_dir:
            raise
================================================
FILE: model.py
================================================
import torch
import torch.nn as nn
import torch.nn.parallel
from torch.autograd import Variable
from torchvision import models
import torch.utils.model_zoo as model_zoo
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from cfg.config import cfg
from GLAttention import GLAttentionGeneral as ATT_NET
class GLU(nn.Module):
    """Gated linear unit over the channel dimension (dim 1).

    The input's channels are split into two equal halves ``a`` and ``b`` and
    ``a * sigmoid(b)`` is returned, so the output has half the channels.
    """

    def __init__(self):
        super(GLU, self).__init__()

    def forward(self, x):
        nc = x.size(1)
        assert nc % 2 == 0, 'channels dont divide 2!'
        nc = nc // 2
        # Tensor.sigmoid() replaces the deprecated F.sigmoid (same values).
        return x[:, :nc] * x[:, nc:].sigmoid()
def conv1x1(in_planes, out_planes, bias=False):
    """Return a 1x1, stride-1 ``nn.Conv2d`` with no padding (keeps H x W).

    A bias term is added only when ``bias`` is True.
    """
    return nn.Conv2d(in_channels=in_planes, out_channels=out_planes,
                     kernel_size=1, stride=1, padding=0, bias=bias)
def conv3x3(in_planes, out_planes):
    """Return a bias-free 3x3 ``nn.Conv2d`` with stride 1 and padding 1 (keeps H x W)."""
    return nn.Conv2d(in_channels=in_planes, out_channels=out_planes,
                     kernel_size=3, stride=1, padding=1, bias=False)
# Upsample the spatial size by a factor of 2
def upBlock(in_planes, out_planes):
    """Nearest-neighbour 2x upsampling, then conv3x3 -> BatchNorm -> GLU.

    The conv emits ``2 * out_planes`` channels, which GLU gates back down to
    ``out_planes``.
    """
    doubled = out_planes * 2
    layers = [
        nn.Upsample(scale_factor=2, mode='nearest'),
        conv3x3(in_planes, doubled),
        nn.BatchNorm2d(doubled),
        GLU(),
    ]
    return nn.Sequential(*layers)
# Keep the spatial size
def Block3x3_relu(in_planes, out_planes):
    """conv3x3 -> BatchNorm -> GLU, preserving spatial size.

    Despite the name, the non-linearity is a GLU: the conv emits
    ``2 * out_planes`` channels and GLU gates them down to ``out_planes``.
    """
    width = out_planes * 2
    return nn.Sequential(conv3x3(in_planes, width),
                         nn.BatchNorm2d(width),
                         GLU())
class ResBlock(nn.Module):
    """Residual block computing x + BN(conv3x3(GLU(BN(conv3x3(x))))).

    Channel count and spatial size are preserved.
    """

    def __init__(self, channel_num):
        super(ResBlock, self).__init__()
        doubled = channel_num * 2
        self.block = nn.Sequential(
            conv3x3(channel_num, doubled),
            nn.BatchNorm2d(doubled),
            GLU(),
            conv3x3(channel_num, channel_num),
            nn.BatchNorm2d(channel_num))

    def forward(self, x):
        out = self.block(x)
        out += x  # skip connection
        return out
# ############## Text2Image Encoder-Decoder #######
class RNN_ENCODER(nn.Module):
    """Caption encoder: word embedding + (optionally bidirectional) LSTM/GRU.

    ``forward`` returns per-word features and one sentence feature built from
    the final hidden state of the *last* recurrent layer.  The constructor's
    ``nhidden`` is the total width across directions; each direction gets
    ``nhidden // num_directions`` units.
    """

    def __init__(self, ntoken, ninput=300, drop_prob=0.5,
                 nhidden=128, nlayers=1, bidirectional=True):
        super(RNN_ENCODER, self).__init__()
        self.n_steps = cfg.TEXT.WORDS_NUM
        self.ntoken = ntoken  # size of the dictionary
        self.ninput = ninput  # size of each embedding vector
        self.drop_prob = drop_prob  # probability of an element to be zeroed
        self.nlayers = nlayers  # Number of recurrent layers
        self.bidirectional = bidirectional
        self.rnn_type = cfg.RNN_TYPE
        if bidirectional:
            self.num_directions = 2
        else:
            self.num_directions = 1
        # number of features in the hidden state (per direction)
        self.nhidden = nhidden // self.num_directions

        self.define_module()
        self.init_weights()

    def define_module(self):
        self.encoder = nn.Embedding(self.ntoken, self.ninput)
        self.drop = nn.Dropout(self.drop_prob)
        if self.rnn_type == 'LSTM':
            # dropout: If non-zero, introduces a dropout layer on
            # the outputs of each RNN layer except the last layer
            self.rnn = nn.LSTM(self.ninput, self.nhidden,
                               self.nlayers, batch_first=True,
                               dropout=self.drop_prob,
                               bidirectional=self.bidirectional)
        elif self.rnn_type == 'GRU':
            self.rnn = nn.GRU(self.ninput, self.nhidden,
                              self.nlayers, batch_first=True,
                              dropout=self.drop_prob,
                              bidirectional=self.bidirectional)
        else:
            raise NotImplementedError

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        # RNN parameters keep PyTorch's default initialisation.

    def init_hidden(self, bsz):
        """Return a zero initial state ((h, c) for LSTM, h for GRU)."""
        weight = next(self.parameters()).data
        if self.rnn_type == 'LSTM':
            return (Variable(weight.new(self.nlayers * self.num_directions,
                                        bsz, self.nhidden).zero_()),
                    Variable(weight.new(self.nlayers * self.num_directions,
                                        bsz, self.nhidden).zero_()))
        else:
            return Variable(weight.new(self.nlayers * self.num_directions,
                                       bsz, self.nhidden).zero_())

    def forward(self, captions, cap_lens, hidden, mask=None):
        """Encode a batch of captions.

        Args:
            captions: word-index tensor, batch x n_steps.
            cap_lens: per-caption lengths; ``pack_padded_sequence`` (default
                settings) expects them in decreasing order.
            hidden: initial state, e.g. from ``init_hidden``.
            mask: unused.

        Returns:
            words_emb: batch x (nhidden * num_directions) x max(cap_lens)
            sent_emb: batch x (nhidden * num_directions)
        """
        # --> emb: batch x n_steps x ninput
        emb = self.drop(self.encoder(captions))
        cap_lens = cap_lens.data.tolist()
        emb = pack_padded_sequence(emb, cap_lens, batch_first=True)
        # hidden: (num_layers * num_directions) x batch x nhidden
        output, hidden = self.rnn(emb, hidden)
        # PackedSequence --> batch x seq_len x (nhidden * num_directions)
        output = pad_packed_sequence(output, batch_first=True)[0]
        # --> batch x (nhidden * num_directions) x seq_len
        words_emb = output.transpose(1, 2)

        if self.rnn_type == 'LSTM':
            h_n = hidden[0]
        else:
            h_n = hidden
        # Keep only the last layer's final states.  Flattening all layers
        # produced batch * nlayers rows whenever nlayers > 1.
        h_n = h_n[-self.num_directions:]
        # --> batch x (num_directions * nhidden)
        sent_emb = h_n.transpose(0, 1).contiguous()
        sent_emb = sent_emb.view(-1, self.nhidden * self.num_directions)
        return words_emb, sent_emb
class CNN_ENCODER(nn.Module):
    """Image encoder on top of a frozen, pretrained torchvision Inception-v3.

    ``forward`` returns two outputs:
      * ``features``: region features taken after ``Mixed_6e`` (17 x 17 x 768
        per the shape comments), projected to ``self.nef`` channels by a 1x1
        conv;
      * ``cnn_code``: a global code from the 8x8-average-pooled ``Mixed_7c``
        output (2048-d), projected to ``self.nef`` by a linear layer.
    Only the two projection layers are trainable.
    """
    def __init__(self, nef):
        super(CNN_ENCODER, self).__init__()
        if cfg.TRAIN.FLAG:
            self.nef = nef
        else:
            # NOTE: when not training, the ``nef`` argument is ignored and a
            # fixed width of 256 is used.
            self.nef = 256  # define a uniform ranker

        # Downloads (or reads from the model_zoo cache) the pretrained
        # Inception-v3 weights at construction time.
        model = models.inception_v3()
        url = 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth'
        model.load_state_dict(model_zoo.load_url(url))
        # Freeze the backbone. The projection layers are created afterwards
        # in define_module(), so they keep requires_grad=True.
        for param in model.parameters():
            param.requires_grad = False
        print('Load pretrained model from ', url)
        # print(model)

        self.define_module(model)
        self.init_trainable_weights()

    def define_module(self, model):
        """Reuse Inception-v3 stages up to Mixed_7c and add the two projections."""
        self.Conv2d_1a_3x3 = model.Conv2d_1a_3x3
        self.Conv2d_2a_3x3 = model.Conv2d_2a_3x3
        self.Conv2d_2b_3x3 = model.Conv2d_2b_3x3
        self.Conv2d_3b_1x1 = model.Conv2d_3b_1x1
        self.Conv2d_4a_3x3 = model.Conv2d_4a_3x3
        self.Mixed_5b = model.Mixed_5b
        self.Mixed_5c = model.Mixed_5c
        self.Mixed_5d = model.Mixed_5d
        self.Mixed_6a = model.Mixed_6a
        self.Mixed_6b = model.Mixed_6b
        self.Mixed_6c = model.Mixed_6c
        self.Mixed_6d = model.Mixed_6d
        self.Mixed_6e = model.Mixed_6e
        self.Mixed_7a = model.Mixed_7a
        self.Mixed_7b = model.Mixed_7b
        self.Mixed_7c = model.Mixed_7c

        # 768-channel Mixed_6e region features --> nef channels
        self.emb_features = conv1x1(768, self.nef)
        # 2048-d pooled Mixed_7c code --> nef
        self.emb_cnn_code = nn.Linear(2048, self.nef)

    def init_trainable_weights(self):
        """Uniform(-0.1, 0.1) init of both projection weights.

        The linear layer's bias keeps its default initialisation.
        """
        initrange = 0.1
        self.emb_features.weight.data.uniform_(-initrange, initrange)
        self.emb_cnn_code.weight.data.uniform_(-initrange, initrange)

    def forward(self, x):
        """Encode an image batch into (features, cnn_code).

        Returns:
            features: batch x nef x 17 x 17 region features.
            cnn_code: batch x nef global image code.
        """
        features = None
        # --> fixed-size input: batch x 3 x 299 x 299
        # (every input is resized to the 299 x 299 size the comments below assume)
        x = nn.Upsample(size=(299, 299), mode='bilinear')(x)
        # 299 x 299 x 3
        x = self.Conv2d_1a_3x3(x)
        # 149 x 149 x 32
        x = self.Conv2d_2a_3x3(x)
        # 147 x 147 x 32
        x = self.Conv2d_2b_3x3(x)
        # 147 x 147 x 64
        x = F.max_pool2d(x, kernel_size=3, stride=2)
        # 73 x 73 x 64
        x = self.Conv2d_3b_1x1(x)
        # 73 x 73 x 80
        x = self.Conv2d_4a_3x3(x)
        # 71 x 71 x 192

        x = F.max_pool2d(x, kernel_size=3, stride=2)
        # 35 x 35 x 192
        x = self.Mixed_5b(x)
        # 35 x 35 x 256
        x = self.Mixed_5c(x)
        # 35 x 35 x 288
        x = self.Mixed_5d(x)
        # 35 x 35 x 288

        x = self.Mixed_6a(x)
        # 17 x 17 x 768
        x = self.Mixed_6b(x)
        # 17 x 17 x 768
        x = self.Mixed_6c(x)
        # 17 x 17 x 768
        x = self.Mixed_6d(x)
        # 17 x 17 x 768
        x = self.Mixed_6e(x)
        # 17 x 17 x 768

        # image region features
        features = x
        # 17 x 17 x 768

        x = self.Mixed_7a(x)
        # 8 x 8 x 1280
        x = self.Mixed_7b(x)
        # 8 x 8 x 2048
        x = self.Mixed_7c(x)
        # 8 x 8 x 2048
        x = F.avg_pool2d(x, kernel_size=8)
        # 1 x 1 x 2048
        # x = F.dropout(x, training=self.training)
        # 1 x 1 x 2048
        x = x.view(x.size(0), -1)
        # 2048

        # global image features
        cnn_code = self.emb_cnn_code(x)
        # batch x nef
        # (features was assigned above, so this branch always runs)
        if features is not None:
            features = self.emb_features(features)
        return features, cnn_code
# ############## G networks ###################
class CA_NET(nn.Module):
    """Conditioning Augmentation: sample a condition code from a text embedding.

    ``fc`` followed by GLU maps the embedding to ``2 * c_dim`` values, split
    into ``mu`` (first ``c_dim``) and ``logvar`` (rest).  ``forward`` returns
    the reparameterised sample ``mu + eps * exp(0.5 * logvar)`` together with
    ``mu`` and ``logvar``.
    """
    # some code is modified from vae examples
    # (https://github.com/pytorch/examples/blob/master/vae/main.py)
    def __init__(self):
        super(CA_NET, self).__init__()
        self.t_dim = cfg.TEXT.EMBEDDING_DIM
        self.c_dim = cfg.GAN.CONDITION_DIM
        self.fc = nn.Linear(self.t_dim, self.c_dim * 4, bias=True)
        self.relu = GLU()

    def encode(self, text_embedding):
        x = self.relu(self.fc(text_embedding))
        mu = x[:, :self.c_dim]
        logvar = x[:, self.c_dim:]
        return mu, logvar

    def reparametrize(self, mu, logvar):
        std = logvar.mul(0.5).exp_()
        # Allocate the noise with the same device and dtype as ``std`` rather
        # than choosing CPU vs CUDA from cfg.CUDA, which fails whenever the
        # config and the tensor's actual device disagree.
        eps = std.data.new(std.size()).normal_()
        eps = Variable(eps)
        return eps.mul(std).add_(mu)

    def forward(self, text_embedding):
        mu, logvar = self.encode(text_embedding)
        c_code = self.reparametrize(mu, logvar)
        return c_code, mu, logvar
class INIT_STAGE_G(nn.Module):
    """First generator stage: (condition code, noise) -> 64 x 64 feature map.

    A linear layer + BatchNorm + GLU yields a ``gf_dim`` x 4 x 4 map, which
    four 2x upsampling blocks grow to ``gf_dim // 16`` x 64 x 64.
    """

    def __init__(self, ngf, ncf):
        super(INIT_STAGE_G, self).__init__()
        self.gf_dim = ngf
        self.in_dim = cfg.GAN.Z_DIM + ncf  # cfg.TEXT.EMBEDDING_DIM
        self.define_module()

    def define_module(self):
        nz, ngf = self.in_dim, self.gf_dim
        self.fc = nn.Sequential(
            nn.Linear(nz, ngf * 4 * 4 * 2, bias=False),
            nn.BatchNorm1d(ngf * 4 * 4 * 2),
            GLU())
        # upsample1..upsample4 each halve the channel count:
        # ngf -> ngf/2 -> ngf/4 -> ngf/8 -> ngf/16
        for stage in range(1, 5):
            setattr(self, 'upsample%d' % stage,
                    upBlock(ngf // 2 ** (stage - 1), ngf // 2 ** stage))

    def forward(self, z_code, c_code):
        """
        :param z_code: batch x cfg.GAN.Z_DIM
        :param c_code: batch x ncf (condition code)
        :return: batch x gf_dim/16 x 64 x 64
        """
        c_z_code = torch.cat((c_code, z_code), 1)
        # state size gf_dim x 4 x 4
        out_code = self.fc(c_z_code).view(-1, self.gf_dim, 4, 4)
        # 4x4 -> 8x8 -> 16x16 -> 32x32 -> 64x64
        for up in (self.upsample1, self.upsample2,
                   self.upsample3, self.upsample4):
            out_code = up(out_code)
        return out_code
# class NEXT_STAGE_G(nn.Module):
# def __init__(self, ngf, nef, ncf):
# super(NEXT_STAGE_G, self).__init__()
# self.gf_dim = ngf
# self.ef_dim = nef
# self.cf_dim = ncf
# self.num_residual = cfg.GAN.R_NUM
# self.define_module()
#
# def _make_layer(self, block, channel_num):
# layers = []
# for i in range(cfg.GAN.R_NUM):
# layers.append(block(channel_num))
# return nn.Sequential(*layers)
#
# def define_module(self):
# ngf = self.gf_dim
# self.att = ATT_NET(ngf, self.ef_dim)
# self.residual = self._make_layer(ResBlock, ngf * 2)
# self.upsample = upBlock(ngf * 2, ngf)
#
# def forward(self, h_code, c_code, word_embs, mask):
# """
# h_code1(query): batch x idf x ih x iw (queryL=ihxiw)
# word_embs(context): batch x cdf x sourceL (sourceL=seq_len)
# c_code1: batch x idf x queryL
# att1: batch x sourceL x queryL
# """
# self.att.applyMask(mask)
# c_code, att = self.att(h_code, word_embs)
# h_c_code = torch.cat((h_code, c_code), 1)
# print('h_c_code:', h_c_code.size()) \
# ('h_c_code:', (16, 64, 64, 64))
# ('h_c_code:', (16, 64, 128, 128))
# out_code = self.residual(h_c_code)
#
# # state size ngf/2 x 2in_size x 2in_size
# out_code = self.upsample(out_code)
#
# return out_code, att
class NEXT_STAGE_G(nn.Module):
    """Refinement stage: fuse image features with word and sentence context, upsample 2x."""

    def __init__(self, ngf, nef, ncf):
        super(NEXT_STAGE_G, self).__init__()
        self.gf_dim = ngf
        self.ef_dim = nef
        self.cf_dim = ncf
        self.num_residual = cfg.GAN.R_NUM
        self.define_module()
        # projects [h | word ctx | sentence ctx] (3 * ngf) down to 2 * ngf
        self.conv = conv1x1(ngf * 3, ngf * 2)

    def _make_layer(self, block, channel_num):
        blocks = [block(channel_num) for _ in range(cfg.GAN.R_NUM)]
        return nn.Sequential(*blocks)

    def define_module(self):
        ngf = self.gf_dim
        self.att = ATT_NET(ngf, self.ef_dim)
        self.residual = self._make_layer(ResBlock, ngf * 2)
        self.upsample = upBlock(ngf * 2, ngf)

    def forward(self, h_code, c_code, word_embs, mask):
        """Refine ``h_code`` with attention context, then upsample 2x.

        Args:
            h_code: image features, batch x ngf x ih x iw.
            c_code: sentence condition code, forwarded to the attention module.
            word_embs: word features, batch x nef x seq_len.
            mask: batch x seq_len mask handed to ``self.att.applyMask``.

        Returns:
            (batch x ngf x 2ih x 2iw features, word attention map)
        """
        self.att.applyMask(mask)
        # attention returns (word context, sentence context, word attn, sentence attn)
        word_ctx, sent_ctx, word_att, _ = self.att(h_code, c_code, word_embs)
        fused = torch.cat((h_code, word_ctx, sent_ctx), 1)
        fused = self.conv(fused)
        refined = self.residual(fused)
        return self.upsample(refined), word_att
class GET_IMAGE_G(nn.Module):
    """Map ``ngf``-channel features to a 3-channel image in [-1, 1] (conv3x3 + tanh)."""

    def __init__(self, ngf):
        super(GET_IMAGE_G, self).__init__()
        self.gf_dim = ngf
        layers = [conv3x3(ngf, 3), nn.Tanh()]
        self.img = nn.Sequential(*layers)

    def forward(self, h_code):
        return self.img(h_code)
# G_NET used in the paper
class G_NET(nn.Module):
    """Multi-stage generator: an initial stage plus up to two refinement stages.

    ``cfg.TREE.BRANCH_NUM`` decides how many stages are built and run.  Every
    stage emits an image; each refinement stage also yields an attention map.
    """

    def __init__(self):
        super(G_NET, self).__init__()
        ngf = cfg.GAN.GF_DIM
        nef = cfg.TEXT.EMBEDDING_DIM
        ncf = cfg.GAN.CONDITION_DIM
        self.ca_net = CA_NET()

        if cfg.TREE.BRANCH_NUM > 0:
            self.h_net1 = INIT_STAGE_G(ngf * 16, ncf)
            self.img_net1 = GET_IMAGE_G(ngf)
        # gf x 64 x 64
        if cfg.TREE.BRANCH_NUM > 1:
            self.h_net2 = NEXT_STAGE_G(ngf, nef, ncf)
            self.img_net2 = GET_IMAGE_G(ngf)
        if cfg.TREE.BRANCH_NUM > 2:
            self.h_net3 = NEXT_STAGE_G(ngf, nef, ncf)
            self.img_net3 = GET_IMAGE_G(ngf)

    def forward(self, z_code, sent_emb, word_embs, mask):
        """
        :param z_code: batch x cfg.GAN.Z_DIM
        :param sent_emb: batch x cfg.TEXT.EMBEDDING_DIM
        :param word_embs: batch x cdf x seq_len
        :param mask: batch x seq_len
        :return: (list of images, list of attention maps, mu, logvar)
        """
        fake_imgs = []
        att_maps = []
        # Conditioning Augmentation
        c_code, mu, logvar = self.ca_net(sent_emb)

        h_code = None
        if cfg.TREE.BRANCH_NUM > 0:
            h_code = self.h_net1(z_code, c_code)
            fake_imgs.append(self.img_net1(h_code))
        # refinement stages 2 and 3, each fed by the previous stage's features
        for stage in (2, 3):
            if cfg.TREE.BRANCH_NUM < stage:
                break
            h_net = getattr(self, 'h_net%d' % stage)
            img_net = getattr(self, 'img_net%d' % stage)
            h_code, att = h_net(h_code, c_code, word_embs, mask)
            fake_imgs.append(img_net(h_code))
            if att is not None:
                att_maps.append(att)

        return fake_imgs, att_maps, mu, logvar
class G_DCGAN(nn.Module):
    """Generator variant that chains all stages but renders only the final image."""

    def __init__(self):
        super(G_DCGAN, self).__init__()
        ngf = cfg.GAN.GF_DIM
        nef = cfg.TEXT.EMBEDDING_DIM
        ncf = cfg.GAN.CONDITION_DIM
        self.ca_net = CA_NET()

        # 16gf x 64 x 64 --> gf x 64 x 64 --> 3 x 64 x 64
        if cfg.TREE.BRANCH_NUM > 0:
            self.h_net1 = INIT_STAGE_G(ngf * 16, ncf)
        # gf x 64 x 64
        if cfg.TREE.BRANCH_NUM > 1:
            self.h_net2 = NEXT_STAGE_G(ngf, nef, ncf)
        if cfg.TREE.BRANCH_NUM > 2:
            self.h_net3 = NEXT_STAGE_G(ngf, nef, ncf)
        self.img_net = GET_IMAGE_G(ngf)

    def forward(self, z_code, sent_emb, word_embs, mask):
        """
        :param z_code: batch x cfg.GAN.Z_DIM
        :param sent_emb: batch x cfg.TEXT.EMBEDDING_DIM
        :param word_embs: batch x cdf x seq_len
        :param mask: batch x seq_len
        :return: ([final image], list of attention maps, mu, logvar)
        """
        att_maps = []
        c_code, mu, logvar = self.ca_net(sent_emb)
        if cfg.TREE.BRANCH_NUM > 0:
            h_code = self.h_net1(z_code, c_code)
        for stage in (2, 3):
            if cfg.TREE.BRANCH_NUM < stage:
                break
            h_net = getattr(self, 'h_net%d' % stage)
            h_code, att = h_net(h_code, c_code, word_embs, mask)
            if att is not None:
                att_maps.append(att)

        return [self.img_net(h_code)], att_maps, mu, logvar
# ############## D networks ##########################
def Block3x3_leakRelu(in_planes, out_planes):
    """conv3x3 -> BatchNorm -> LeakyReLU(0.2), preserving spatial size."""
    conv = conv3x3(in_planes, out_planes)
    norm = nn.BatchNorm2d(out_planes)
    act = nn.LeakyReLU(0.2, inplace=True)
    return nn.Sequential(conv, norm, act)
# Downsample the spatial size by a factor of 2
def downBlock(in_planes, out_planes):
    """4x4 stride-2 conv (no bias) -> BatchNorm -> LeakyReLU(0.2); halves H and W."""
    conv = nn.Conv2d(in_planes, out_planes, kernel_size=4, stride=2,
                     padding=1, bias=False)
    norm = nn.BatchNorm2d(out_planes)
    act = nn.LeakyReLU(0.2, inplace=True)
    return nn.Sequential(conv, norm, act)
# Downsample the spatial size by a factor of 16
def encode_image_by_16times(ndf):
    """Four 4x4 stride-2 convs taking 3 channels to 8*ndf, shrinking H and W by 16.

    The first conv is followed only by LeakyReLU(0.2); the other three by
    BatchNorm + LeakyReLU(0.2).  Channels: 3 -> ndf -> 2ndf -> 4ndf -> 8ndf.
    """
    layers = [nn.Conv2d(3, ndf, 4, 2, 1, bias=False),
              nn.LeakyReLU(0.2, inplace=True)]
    channels = ndf
    for _ in range(3):
        layers.append(nn.Conv2d(channels, channels * 2, 4, 2, 1, bias=False))
        layers.append(nn.BatchNorm2d(channels * 2))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        channels *= 2
    return nn.Sequential(*layers)
class D_GET_LOGITS(nn.Module):
    """Turn 4x4 discriminator features into one probability per sample.

    When ``bcondition`` is set and a condition code is passed, the code is
    tiled over the 4x4 grid, concatenated with the features and fused by a
    3x3 conv block before the final 4x4 conv + sigmoid.
    """

    def __init__(self, ndf, nef, bcondition=False):
        super(D_GET_LOGITS, self).__init__()
        self.df_dim = ndf
        self.ef_dim = nef
        self.bcondition = bcondition
        if bcondition:
            self.jointConv = Block3x3_leakRelu(ndf * 8 + nef, ndf * 8)
        self.outlogits = nn.Sequential(
            nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4),
            nn.Sigmoid())

    def forward(self, h_code, c_code=None):
        joint = h_code
        if self.bcondition and c_code is not None:
            # batch x nef --> batch x nef x 4 x 4
            tiled = c_code.view(-1, self.ef_dim, 1, 1).repeat(1, 1, 4, 4)
            joint = self.jointConv(torch.cat((h_code, tiled), 1))
        return self.outlogits(joint).view(-1)
# For 64 x 64 images
class D_NET64(nn.Module):
    """Discriminator trunk for 64x64 images: 16x downsampling to 8ndf x 4 x 4.

    ``UNCOND_DNET`` (only when ``b_jcu``, else None) and ``COND_DNET`` are the
    logit heads; ``forward`` itself returns just the shared feature map.
    """

    def __init__(self, b_jcu=True):
        super(D_NET64, self).__init__()
        ndf = cfg.GAN.DF_DIM
        nef = cfg.TEXT.EMBEDDING_DIM
        self.img_code_s16 = encode_image_by_16times(ndf)
        self.UNCOND_DNET = (D_GET_LOGITS(ndf, nef, bcondition=False)
                            if b_jcu else None)
        self.COND_DNET = D_GET_LOGITS(ndf, nef, bcondition=True)

    def forward(self, x_var):
        # --> batch x 8df x 4 x 4
        return self.img_code_s16(x_var)
# For 128 x 128 images
class D_NET128(nn.Module):
    """Discriminator trunk for 128x128 images, reducing them to 8ndf x 4 x 4.

    ``UNCOND_DNET`` (only when ``b_jcu``, else None) and ``COND_DNET`` are the
    logit heads; ``forward`` itself returns just the shared feature map.
    """

    def __init__(self, b_jcu=True):
        super(D_NET128, self).__init__()
        ndf = cfg.GAN.DF_DIM
        nef = cfg.TEXT.EMBEDDING_DIM
        self.img_code_s16 = encode_image_by_16times(ndf)
        self.img_code_s32 = downBlock(ndf * 8, ndf * 16)
        self.img_code_s32_1 = Block3x3_leakRelu(ndf * 16, ndf * 8)
        self.UNCOND_DNET = (D_GET_LOGITS(ndf, nef, bcondition=False)
                            if b_jcu else None)
        self.COND_DNET = D_GET_LOGITS(ndf, nef, bcondition=True)

    def forward(self, x_var):
        feats = self.img_code_s16(x_var)     # 8 x 8 x 8df
        feats = self.img_code_s32(feats)     # 4 x 4 x 16df
        return self.img_code_s32_1(feats)    # 4 x 4 x 8df
# For 256 x 256 images
class D_NET256(nn.Module):
    """Discriminator trunk for 256x256 images, reducing them to 8ndf x 4 x 4.

    ``UNCOND_DNET`` (only when ``b_jcu``, else None) and ``COND_DNET`` are the
    logit heads; ``forward`` itself returns just the shared feature map.
    """

    def __init__(self, b_jcu=True):
        super(D_NET256, self).__init__()
        ndf = cfg.GAN.DF_DIM
        nef = cfg.TEXT.EMBEDDING_DIM
        self.img_code_s16 = encode_image_by_16times(ndf)
        self.img_code_s32 = downBlock(ndf * 8, ndf * 16)
        self.img_code_s64 = downBlock(ndf * 16, ndf * 32)
        self.img_code_s64_1 = Block3x3_leakRelu(ndf * 32, ndf * 16)
        self.img_code_s64_2 = Block3x3_leakRelu(ndf * 16, ndf * 8)
        self.UNCOND_DNET = (D_GET_LOGITS(ndf, nef, bcondition=False)
                            if b_jcu else None)
        self.COND_DNET = D_GET_LOGITS(ndf, nef, bcondition=True)

    def forward(self, x_var):
        feats = self.img_code_s16(x_var)     # 16 x 16 x 8df
        feats = self.img_code_s32(feats)     # 8 x 8 x 16df
        feats = self.img_code_s64(feats)     # 4 x 4 x 32df
        feats = self.img_code_s64_1(feats)   # 4 x 4 x 16df
        return self.img_code_s64_2(feats)    # 4 x 4 x 8df
class CAPTION_CNN(nn.Module):
    """Frozen ResNet-152 image encoder projected to ``embed_size`` (linear + BatchNorm1d)."""

    def __init__(self, embed_size):
        """Load the pretrained ResNet-152 and replace top fc layer."""
        super(CAPTION_CNN, self).__init__()
        resnet = models.resnet152(pretrained=True)
        # every child except the final fc classifier
        backbone = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*backbone)
        for weight in self.resnet.parameters():
            weight.requires_grad = False
        self.linear = nn.Linear(resnet.fc.in_features, embed_size)
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)

    def forward(self, images):
        """Extract feature vectors from input images."""
        # resize to the 224 x 224 input fed to the ResNet backbone
        resized = nn.Upsample(size=(224, 224), mode='bilinear')(images)
        pooled = self.resnet(resized)
        flat = pooled.view(pooled.size(0), -1)
        return self.bn(self.linear(flat))
class CAPTION_RNN(nn.Module):
    """LSTM caption decoder conditioned on one image feature vector per sample."""

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, max_seq_length=20):
        """Set the hyper-parameters and build the layers."""
        super(CAPTION_RNN, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        # attribute name kept as-is; sample() reads it
        self.max_seg_length = max_seq_length

    def forward(self, features, captions, cap_lens):
        """Return vocabulary logits for every packed (non-padding) time step.

        The image feature is the first LSTM input, followed by the embedded
        caption words; ``cap_lens`` are the lengths used for packing.
        """
        word_vecs = self.embed(captions)
        inputs = torch.cat((features.unsqueeze(1), word_vecs), 1)
        packed_inputs = pack_padded_sequence(inputs, cap_lens, batch_first=True)
        packed_hiddens, _ = self.lstm(packed_inputs)
        # element 0 of a PackedSequence is its flat data tensor
        return self.linear(packed_hiddens[0])

    def sample(self, features, states=None):
        """Generate captions for given image features using greedy search."""
        inputs = features.unsqueeze(1)
        sampled_ids = []
        for _ in range(self.max_seg_length):
            hiddens, states = self.lstm(inputs, states)   # batch x 1 x hidden
            logits = self.linear(hiddens.squeeze(1))       # batch x vocab
            predicted = logits.max(1)[1]                   # batch
            sampled_ids.append(predicted)
            inputs = self.embed(predicted).unsqueeze(1)    # batch x 1 x embed
        return torch.stack(sampled_ids, 1)                 # batch x max_seq_length
================================================
FILE: pretrain_DAMSM.py
================================================
from __future__ import print_function
from miscc.utils import mkdir_p
from miscc.utils import build_super_images
from miscc.losses import sent_loss, words_loss
from cfg.config import cfg, cfg_from_file
from datasets import TextDataset
from datasets import prepare_data
from model import RNN_ENCODER, CNN_ENCODER
import os
import sys
import time
import random
import pprint
import datetime
import dateutil.tz
import argparse
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
# Make the directory containing this script importable.  The previous
# expression, abspath(join(realpath(__file__), './.')), normalised back to the
# script *file* path itself, so appending it to sys.path had no effect.
dir_path = os.path.dirname(os.path.realpath(__file__))
sys.path.append(dir_path)

# Running losses are reported every UPDATE_INTERVAL training steps.
UPDATE_INTERVAL = 200
def parse_args():
    """Parse the DAMSM pre-training command line from ``sys.argv``.

    Returns a namespace with ``cfg_file``, ``gpu_id``, ``data_dir`` and
    ``manualSeed``.
    """
    parser = argparse.ArgumentParser(description='Train a DAMSM network')
    # NOTE(review): this default config path does not appear under the repo's
    # cfg/ directory; confirm it exists before relying on the default.
    parser.add_argument('--cfg', dest='cfg_file', type=str,
                        default='cfg/DAMSM/bird.yml',
                        help='optional config file')
    parser.add_argument('--gpu', dest='gpu_id', type=int, default=0)
    parser.add_argument('--data_dir', dest='data_dir', type=str, default='')
    parser.add_argument('--manualSeed', type=int, help='manual seed')
    return parser.parse_args()
def train(dataloader, cnn_model, rnn_model, batch_size,
          labels, optimizer, epoch, ixtoword, image_dir):
    """Train the DAMSM image and text encoders for one epoch.

    On step 0 and every ``UPDATE_INTERVAL`` steps after it, the word- and
    sentence-level losses averaged over the batches seen since the previous
    report are printed, and an attention-map visualization is saved to
    ``image_dir``.

    Args:
        dataloader: training loader yielding batches accepted by ``prepare_data``.
        cnn_model: image encoder; called on ``imgs[-1]``, returns
            (words_features, sent_code).
        rnn_model: text encoder; returns (words_emb, sent_emb).
        batch_size: batch size of the loader (used for ``init_hidden`` and the
            DAMSM losses).
        labels: match labels passed to ``words_loss`` / ``sent_loss``.
        optimizer: optimizer stepped once per batch.
        epoch: current epoch index (for logging and the returned step count).
        ixtoword: index-to-word mapping used to caption the attention maps.
        image_dir: directory receiving ``attention_maps<step>.png`` files.

    Returns:
        ``epoch * len(dataloader) + step`` for the last reporting step, or
        ``(epoch + 1) * len(dataloader)`` if no step was reported.
    """
    def _scalar(value):
        # Python float from a one-element tensor.  Flattening first works both
        # for legacy 1-element tensors (PyTorch <= 0.3) and for 0-dim tensors
        # (PyTorch >= 0.4, where `value[0]` raises IndexError).
        return float(value.view(-1)[0])

    cnn_model.train()
    rnn_model.train()
    s_total_loss0 = 0.
    s_total_loss1 = 0.
    w_total_loss0 = 0.
    w_total_loss1 = 0.
    n_accum = 0  # batches accumulated since the last report
    count = (epoch + 1) * len(dataloader)
    start_time = time.time()
    for step, data in enumerate(dataloader, 0):
        rnn_model.zero_grad()
        cnn_model.zero_grad()

        imgs, captions, cap_lens, \
            class_ids, keys = prepare_data(data)

        # words_features: batch_size x nef x 17 x 17
        # sent_code: batch_size x nef
        words_features, sent_code = cnn_model(imgs[-1])
        att_sze = words_features.size(2)

        hidden = rnn_model.init_hidden(batch_size)
        # words_emb: batch_size x nef x seq_len
        # sent_emb: batch_size x nef
        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)

        w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb, labels,
                                                 cap_lens, class_ids, batch_size)
        w_total_loss0 += _scalar(w_loss0.data)
        w_total_loss1 += _scalar(w_loss1.data)
        loss = w_loss0 + w_loss1

        s_loss0, s_loss1 = \
            sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
        loss += s_loss0 + s_loss1
        s_total_loss0 += _scalar(s_loss0.data)
        s_total_loss1 += _scalar(s_loss1.data)
        n_accum += 1

        loss.backward()
        # `clip_grad_norm` helps prevent
        # the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm(rnn_model.parameters(),
                                      cfg.TRAIN.RNN_GRAD_CLIP)
        optimizer.step()

        if step % UPDATE_INTERVAL == 0:
            count = epoch * len(dataloader) + step

            # Average over the batches actually accumulated: on step 0 that is
            # a single batch, not UPDATE_INTERVAL of them.
            s_cur_loss0 = s_total_loss0 / n_accum
            s_cur_loss1 = s_total_loss1 / n_accum
            w_cur_loss0 = w_total_loss0 / n_accum
            w_cur_loss1 = w_total_loss1 / n_accum

            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                  's_loss {:5.2f} {:5.2f} | '
                  'w_loss {:5.2f} {:5.2f}'
                  .format(epoch, step, len(dataloader),
                          elapsed * 1000. / n_accum,
                          s_cur_loss0, s_cur_loss1,
                          w_cur_loss0, w_cur_loss1))
            s_total_loss0 = 0.
            s_total_loss1 = 0.
            w_total_loss0 = 0.
            w_total_loss1 = 0.
            n_accum = 0
            start_time = time.time()

            # attention Maps
            img_set, _ = \
                build_super_images(imgs[-1].cpu(), captions,
                                   ixtoword, attn_maps, att_sze)
            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '%s/attention_maps%d.png' % (image_dir, step)
                im.save(fullpath)
    return count
def evaluate(dataloader, cnn_model, rnn_model, batch_size, labels=None):
    """Compute mean validation DAMSM losses over at most the first 51 batches.

    Both encoders are switched to eval mode.  The caller is responsible for
    switching them back; ``train`` does this at its start.

    Args:
        dataloader: validation loader yielding batches accepted by ``prepare_data``.
        cnn_model: image encoder; returns (words_features, sent_code).
        rnn_model: text encoder; returns (words_emb, sent_emb).
        batch_size: batch size of the loader.
        labels: optional match labels for the DAMSM losses.  When omitted they
            are built here as ``LongTensor(range(batch_size))`` (on GPU if
            ``cfg.CUDA``), the same way ``build_models`` builds them, instead of
            depending on a module-level ``labels`` global.

    Returns:
        (s_cur_loss, w_cur_loss): sentence- and word-level losses averaged over
        the batches actually evaluated, or (nan, nan) for an empty loader.
    """
    def _scalar(value):
        # Python float from a one-element tensor.  This works both for legacy
        # 1-element tensors (PyTorch <= 0.3) and for 0-dim tensors
        # (PyTorch >= 0.4, where `value[0]` raises IndexError).
        return float(value.view(-1)[0])

    cnn_model.eval()
    rnn_model.eval()
    if labels is None:
        labels = Variable(torch.LongTensor(range(batch_size)))
        if cfg.CUDA:
            labels = labels.cuda()

    s_total_loss = 0.
    w_total_loss = 0.
    n_batches = 0
    for step, data in enumerate(dataloader, 0):
        real_imgs, captions, cap_lens, \
            class_ids, keys = prepare_data(data)

        words_features, sent_code = cnn_model(real_imgs[-1])
        hidden = rnn_model.init_hidden(batch_size)
        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)

        w_loss0, w_loss1, attn = words_loss(words_features, words_emb, labels,
                                            cap_lens, class_ids, batch_size)
        w_total_loss += _scalar((w_loss0 + w_loss1).data)

        s_loss0, s_loss1 = \
            sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)
        s_total_loss += _scalar((s_loss0 + s_loss1).data)

        n_batches += 1
        # Bound validation cost: steps 0..50 inclusive (51 batches).
        if step == 50:
            break

    # Divide by the number of batches actually summed.  Dividing by `step`
    # was off by one and failed for loaders with 0 or 1 batches.
    if n_batches == 0:
        return float('nan'), float('nan')
    return s_total_loss / n_batches, w_total_loss / n_batches
def build_models():
    """Build the DAMSM text/image encoders, optionally resuming from a checkpoint.

    Reads the module-level ``dataset`` (vocabulary size) and ``batch_size``
    globals that the ``__main__`` section defines before calling this.

    If ``cfg.TRAIN.NET_E`` names a text-encoder checkpoint such as
    ``.../text_encoder200.pth``, the matching image-encoder checkpoint (same
    path with 'text_encoder' replaced by 'image_encoder') is loaded too, and
    training resumes at the following epoch (201 in that example).

    Returns:
        (text_encoder, image_encoder, labels, start_epoch), where ``labels`` is
        ``LongTensor(range(batch_size))`` wrapped in a Variable.
    """
    # build model ############################################################
    text_encoder = RNN_ENCODER(dataset.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
    labels = Variable(torch.LongTensor(range(batch_size)))
    start_epoch = 0
    if cfg.TRAIN.NET_E != '':
        # Deserialize onto the CPU (as trainer.py does) so checkpoints saved on
        # a GPU can also be loaded with `--gpu -1`.  The models are moved to
        # the GPU below when cfg.CUDA is set.
        state_dict = torch.load(cfg.TRAIN.NET_E,
                                map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        print('Load ', cfg.TRAIN.NET_E)
        #
        name = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
        state_dict = torch.load(name,
                                map_location=lambda storage, loc: storage)
        image_encoder.load_state_dict(state_dict)
        print('Load ', name)

        # '..._encoder<N>.<ext>': skip the 8 characters of '_encoder' to reach <N>.
        istart = cfg.TRAIN.NET_E.rfind('_') + 8
        iend = cfg.TRAIN.NET_E.rfind('.')
        start_epoch = cfg.TRAIN.NET_E[istart:iend]
        start_epoch = int(start_epoch) + 1
        print('start_epoch', start_epoch)
    if cfg.CUDA:
        text_encoder = text_encoder.cuda()
        image_encoder = image_encoder.cuda()
        labels = labels.cuda()

    return text_encoder, image_encoder, labels, start_epoch
if __name__ == "__main__":
    args = parse_args()
    if args.cfg_file is not None:
        cfg_from_file(args.cfg_file)

    if args.gpu_id == -1:
        cfg.CUDA = False
    else:
        cfg.GPU_ID = args.gpu_id

    if args.data_dir != '':
        cfg.DATA_DIR = args.data_dir
    print('Using config:')
    pprint.pprint(cfg)

    # Non-training runs use a fixed seed; training uses --manualSeed or a random one.
    if not cfg.TRAIN.FLAG:
        args.manualSeed = 100
    elif args.manualSeed is None:
        args.manualSeed = random.randint(1, 10000)
    random.seed(args.manualSeed)
    np.random.seed(args.manualSeed)
    torch.manual_seed(args.manualSeed)
    if cfg.CUDA:
        torch.cuda.manual_seed_all(args.manualSeed)

    ##########################################################################
    now = datetime.datetime.now(dateutil.tz.tzlocal())
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
    output_dir = '../output/%s_%s_%s' % \
        (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp)

    model_dir = os.path.join(output_dir, 'Model')
    image_dir = os.path.join(output_dir, 'Image')
    mkdir_p(model_dir)
    mkdir_p(image_dir)

    # Only select a device when CUDA is enabled, so `--gpu -1` also works on
    # machines without a GPU.
    if cfg.CUDA:
        torch.cuda.set_device(cfg.GPU_ID)
    cudnn.benchmark = True

    # Get data loader ##################################################
    imsize = cfg.TREE.BASE_SIZE * (2 ** (cfg.TREE.BRANCH_NUM - 1))
    batch_size = cfg.TRAIN.BATCH_SIZE
    # `transforms.Scale` is the deprecated old name of `transforms.Resize` and
    # is missing from recent torchvision releases.  Prefer Resize when present.
    resize = getattr(transforms, 'Resize', None) or transforms.Scale
    image_transform = transforms.Compose([
        resize(int(imsize * 76 / 64)),
        transforms.RandomCrop(imsize),
        transforms.RandomHorizontalFlip()])
    dataset = TextDataset(cfg.DATA_DIR, 'train',
                          base_size=cfg.TREE.BASE_SIZE,
                          transform=image_transform)

    print(dataset.n_words, dataset.embeddings_num)
    assert dataset
    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, drop_last=True,
        shuffle=True, num_workers=int(cfg.WORKERS))

    # # validation data #
    dataset_val = TextDataset(cfg.DATA_DIR, 'test',
                              base_size=cfg.TREE.BASE_SIZE,
                              transform=image_transform)
    dataloader_val = torch.utils.data.DataLoader(
        dataset_val, batch_size=batch_size, drop_last=True,
        shuffle=True, num_workers=int(cfg.WORKERS))

    # Train ##############################################################
    text_encoder, image_encoder, labels, start_epoch = build_models()
    para = list(text_encoder.parameters())
    for v in image_encoder.parameters():
        if v.requires_grad:
            para.append(v)
    # At any point you can hit Ctrl + C to break out of training early.
    try:
        lr = cfg.TRAIN.ENCODER_LR
        last_epoch = cfg.TRAIN.MAX_EPOCH - 1
        for epoch in range(start_epoch, cfg.TRAIN.MAX_EPOCH):
            # Re-created every epoch so the decayed `lr` takes effect.  Note
            # that this also resets Adam's moment estimates each epoch.
            optimizer = optim.Adam(para, lr=lr, betas=(0.5, 0.999))
            epoch_start_time = time.time()
            count = train(dataloader, image_encoder, text_encoder,
                          batch_size, labels, optimizer, epoch,
                          dataset.ixtoword, image_dir)
            print('-' * 89)
            if len(dataloader_val) > 0:
                s_loss, w_loss = evaluate(dataloader_val, image_encoder,
                                          text_encoder, batch_size)
                print('| end epoch {:3d} | valid loss '
                      '{:5.2f} {:5.2f} | lr {:.5f}|'
                      .format(epoch, s_loss, w_loss, lr))
            print('-' * 89)
            # Decay lr by 2% per epoch while it is above a tenth of its start value.
            if lr > cfg.TRAIN.ENCODER_LR/10.:
                lr *= 0.98

            # `epoch` never reaches MAX_EPOCH inside this range, so compare
            # with the last epoch to make sure the final encoders are saved.
            if (epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0 or
                    epoch == last_epoch):
                torch.save(image_encoder.state_dict(),
                           '%s/image_encoder%d.pth' % (model_dir, epoch))
                torch.save(text_encoder.state_dict(),
                           '%s/text_encoder%d.pth' % (model_dir, epoch))
                print('Save G/Ds models.')
    except KeyboardInterrupt:
        print('-' * 89)
        print('Exiting from training early')
================================================
FILE: test.py
================================================
import torch.nn as nn
import torch
from torch.autograd import Variable
def conv1x1(in_planes, out_planes, bias=False):
    """1x1 convolution with stride 1 and no padding.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        bias: add a learnable bias (off by default, matching the previous
            behavior and the ``bias`` flag of ``model.conv1x1``).

    Returns:
        An ``nn.Conv2d`` that preserves spatial size.
    """
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1,
                     padding=0, bias=bias)
# Shape smoke test: print the sizes of a random (2, 3, 1, 1) input, of that
# input after a 3->3 1x1 convolution, and of its element-wise square.
x = Variable(torch.rand(2, 3, 1, 1))
print(x.size())

pointwise = conv1x1(3, 3)
y = pointwise(x)
print(y.size())

t = x * x  # element-wise product, same as torch.mul(x, x)
print(t.size())
================================================
FILE: trainer.py
================================================
from __future__ import print_function
from six.moves import range
import torch
import torch.optim as optim
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
from PIL import Image
from cfg.config import cfg
from miscc.utils import mkdir_p
from miscc.utils import build_super_images, build_super_images2
from miscc.utils import weights_init, load_params, copy_G_params
from model import G_DCGAN, G_NET
from datasets import prepare_data
from model import RNN_ENCODER, CNN_ENCODER, CAPTION_CNN, CAPTION_RNN
from miscc.losses import words_loss
from miscc.losses import discriminator_loss, generator_loss, KL_loss
import os
import time
import numpy as np
# MirrorGAN
class Trainer(object):
    """MirrorGAN trainer and sampler.

    Builds the generator, discriminators, frozen DAMSM encoders and frozen
    caption models, runs the GAN training loop, and generates images for
    evaluation (``sampling``) or for custom captions (``gen_example``).

    Args:
        output_dir: when ``cfg.TRAIN.FLAG`` is set, ``Model/`` and ``Image/``
            sub-directories are created here for checkpoints and samples.
        data_loader: loader yielding batches accepted by ``prepare_data``.
        n_words: vocabulary size for the text encoder and caption decoder.
        ixtoword: index-to-word mapping used to render captions.
    """

    def __init__(self, output_dir, data_loader, n_words, ixtoword):
        if cfg.TRAIN.FLAG:
            self.model_dir = os.path.join(output_dir, 'Model')
            self.image_dir = os.path.join(output_dir, 'Image')
            mkdir_p(self.model_dir)
            mkdir_p(self.image_dir)

        # Only select a GPU when CUDA is enabled.
        if cfg.CUDA:
            torch.cuda.set_device(cfg.GPU_ID)
        cudnn.benchmark = True

        self.batch_size = cfg.TRAIN.BATCH_SIZE
        self.max_epoch = cfg.TRAIN.MAX_EPOCH
        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL

        self.n_words = n_words
        self.ixtoword = ixtoword
        self.data_loader = data_loader
        self.num_batches = len(self.data_loader)

    @staticmethod
    def _to_scalar(value):
        """Return a Python float from a one-element tensor (e.g. ``loss.data``).

        ``loss.data[0]`` only works on PyTorch <= 0.3, where losses are
        1-element tensors.  On PyTorch >= 0.4 losses are 0-dim and indexing
        raises IndexError.  Flattening first works on both.
        """
        return float(value.view(-1)[0])

    def build_models(self):
        """Load the frozen encoders and caption models, then build netG / netsD.

        Returns:
            [text_encoder, image_encoder, caption_cnn, caption_rnn, netG,
             netsD, epoch].  ``epoch`` is 0, or N + 1 when ``cfg.TRAIN.NET_G``
            names a ``..._<N>.pth`` generator checkpoint to resume from.

        Raises:
            ValueError: if ``cfg.TRAIN.NET_E`` (pretrained DAMSM text encoder)
                is not set.
        """
        # text encoders
        if cfg.TRAIN.NET_E == '':
            print('Error: no pretrained text-image encoders')
            # Callers unpack seven values from the result.  A bare `return`
            # only surfaced later as an opaque unpacking TypeError.
            raise ValueError('cfg.TRAIN.NET_E is empty: a pretrained DAMSM '
                             'text encoder is required')

        image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)
        img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')
        state_dict = \
            torch.load(img_encoder_path, map_location=lambda storage, loc: storage)
        image_encoder.load_state_dict(state_dict)
        for p in image_encoder.parameters():
            p.requires_grad = False
        print('Load image encoder from:', img_encoder_path)
        image_encoder.eval()

        text_encoder = \
            RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
        state_dict = \
            torch.load(cfg.TRAIN.NET_E,
                       map_location=lambda storage, loc: storage)
        text_encoder.load_state_dict(state_dict)
        for p in text_encoder.parameters():
            p.requires_grad = False
        print('Load text encoder from:', cfg.TRAIN.NET_E)
        text_encoder.eval()

        # Caption models - cnn_encoder and rnn_decoder
        caption_cnn = CAPTION_CNN(cfg.CAP.embed_size)
        caption_cnn.load_state_dict(torch.load(cfg.CAP.caption_cnn_path, map_location=lambda storage, loc: storage))
        for p in caption_cnn.parameters():
            p.requires_grad = False
        print('Load caption model from:', cfg.CAP.caption_cnn_path)
        caption_cnn.eval()

        caption_rnn = CAPTION_RNN(cfg.CAP.embed_size, cfg.CAP.hidden_size * 2, self.n_words, cfg.CAP.num_layers)
        caption_rnn.load_state_dict(torch.load(cfg.CAP.caption_rnn_path, map_location=lambda storage, loc: storage))
        for p in caption_rnn.parameters():
            p.requires_grad = False
        print('Load caption model from:', cfg.CAP.caption_rnn_path)
        # NOTE(review): unlike caption_cnn, caption_rnn is not put in eval()
        # mode.  This is presumably deliberate: generator_loss backpropagates
        # through it, and cuDNN RNN backward requires training mode.  Confirm
        # before changing.

        # Generator and Discriminator:
        netsD = []
        if cfg.GAN.B_DCGAN:
            if cfg.TREE.BRANCH_NUM == 1:
                from model import D_NET64 as D_NET
            elif cfg.TREE.BRANCH_NUM == 2:
                from model import D_NET128 as D_NET
            else:  # cfg.TREE.BRANCH_NUM == 3:
                from model import D_NET256 as D_NET
            netG = G_DCGAN()
            netsD = [D_NET(b_jcu=False)]
        else:
            from model import D_NET64, D_NET128, D_NET256
            netG = G_NET()
            if cfg.TREE.BRANCH_NUM > 0:
                netsD.append(D_NET64())
            if cfg.TREE.BRANCH_NUM > 1:
                netsD.append(D_NET128())
            if cfg.TREE.BRANCH_NUM > 2:
                netsD.append(D_NET256())
        netG.apply(weights_init)
        for i in range(len(netsD)):
            netsD[i].apply(weights_init)
        print('# of netsD', len(netsD))

        epoch = 0
        if cfg.TRAIN.NET_G != '':
            state_dict = \
                torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load G from: ', cfg.TRAIN.NET_G)
            # '..._<N>.pth' -> resume at epoch N + 1
            istart = cfg.TRAIN.NET_G.rfind('_') + 1
            iend = cfg.TRAIN.NET_G.rfind('.')
            epoch = cfg.TRAIN.NET_G[istart:iend]
            epoch = int(epoch) + 1
            if cfg.TRAIN.B_NET_D:
                Gname = cfg.TRAIN.NET_G
                for i in range(len(netsD)):
                    s_tmp = Gname[:Gname.rfind('/')]
                    Dname = '%s/netD%d.pth' % (s_tmp, i)
                    print('Load D from: ', Dname)
                    state_dict = \
                        torch.load(Dname, map_location=lambda storage, loc: storage)
                    netsD[i].load_state_dict(state_dict)

        if cfg.CUDA:
            text_encoder = text_encoder.cuda()
            image_encoder = image_encoder.cuda()
            caption_cnn = caption_cnn.cuda()
            caption_rnn = caption_rnn.cuda()
            netG.cuda()
            for i in range(len(netsD)):
                netsD[i].cuda()
        return [text_encoder, image_encoder, caption_cnn, caption_rnn, netG, netsD, epoch]

    def define_optimizers(self, netG, netsD):
        """Create one Adam optimizer per discriminator and one for the generator.

        Returns:
            (optimizerG, optimizersD) with ``optimizersD[i]`` for ``netsD[i]``.
        """
        optimizersD = []
        num_Ds = len(netsD)
        for i in range(num_Ds):
            opt = optim.Adam(netsD[i].parameters(),
                             lr=cfg.TRAIN.DISCRIMINATOR_LR,
                             betas=(0.5, 0.999))
            optimizersD.append(opt)

        optimizerG = optim.Adam(netG.parameters(),
                                lr=cfg.TRAIN.GENERATOR_LR,
                                betas=(0.5, 0.999))

        return optimizerG, optimizersD

    def prepare_labels(self):
        """Return (real_labels, fake_labels, match_labels) for one batch.

        ``real_labels`` are all 1, ``fake_labels`` all 0 (float), and
        ``match_labels`` is ``0..batch_size-1`` (long).  They are moved to the
        GPU when ``cfg.CUDA`` is set.
        """
        batch_size = self.batch_size
        real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))
        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))
        match_labels = Variable(torch.LongTensor(range(batch_size)))
        if cfg.CUDA:
            real_labels = real_labels.cuda()
            fake_labels = fake_labels.cuda()
            match_labels = match_labels.cuda()

        return real_labels, fake_labels, match_labels

    def save_model(self, netG, avg_param_G, netsD, epoch):
        """Save netG with its averaged weights, plus every discriminator.

        The generator's current weights are restored after saving.
        Discriminator files (``netD<i>.pth``) are overwritten on every call.
        """
        backup_para = copy_G_params(netG)
        load_params(netG, avg_param_G)
        torch.save(netG.state_dict(),
                   '%s/netG_epoch_%d.pth' % (self.model_dir, epoch))
        load_params(netG, backup_para)
        #
        for i in range(len(netsD)):
            netD = netsD[i]
            torch.save(netD.state_dict(),
                       '%s/netD%d.pth' % (self.model_dir, i))
        print('Save G/Ds models.')

    def set_requires_grad_value(self, models_list, brequires):
        """Set ``requires_grad`` to ``brequires`` on every parameter of every model."""
        for i in range(len(models_list)):
            for p in models_list[i].parameters():
                p.requires_grad = brequires

    def save_img_results(self, netG, noise, sent_emb, words_embs, mask,
                         image_encoder, captions, cap_lens,
                         gen_iterations, name='current'):
        """Save generator attention visualizations and DAMSM word-attention maps.

        Writes ``G_<name>_<gen_iterations>_<i>.png`` for each generator
        attention stage, and ``D_<name>_<gen_iterations>.png`` for the image
        encoder's word attention on the final-stage image.
        """
        # Save images
        fake_imgs, attention_maps, _, _ = netG(noise, sent_emb, words_embs, mask)
        for i in range(len(attention_maps)):
            if len(fake_imgs) > 1:
                img = fake_imgs[i + 1].detach().cpu()
                lr_img = fake_imgs[i].detach().cpu()
            else:
                img = fake_imgs[0].detach().cpu()
                lr_img = None
            attn_maps = attention_maps[i]
            att_sze = attn_maps.size(2)
            img_set, _ = \
                build_super_images(img, captions, self.ixtoword,
                                   attn_maps, att_sze, lr_imgs=lr_img)
            if img_set is not None:
                im = Image.fromarray(img_set)
                fullpath = '%s/G_%s_%d_%d.png' \
                    % (self.image_dir, name, gen_iterations, i)
                im.save(fullpath)

        # Word attention of the image encoder on the last (largest) image.
        i = -1
        img = fake_imgs[i].detach()
        region_features, _ = image_encoder(img)
        att_sze = region_features.size(2)
        _, _, att_maps = words_loss(region_features.detach(),
                                    words_embs.detach(),
                                    None, cap_lens,
                                    None, self.batch_size)
        img_set, _ = \
            build_super_images(fake_imgs[i].detach().cpu(),
                               captions, self.ixtoword, att_maps, att_sze)
        if img_set is not None:
            im = Image.fromarray(img_set)
            fullpath = '%s/D_%s_%d.png' \
                % (self.image_dir, name, gen_iterations)
            im.save(fullpath)

    def train(self):
        """Run the MirrorGAN training loop from the resume epoch to ``max_epoch``.

        Each step updates every discriminator once, then the generator once,
        and keeps an exponential moving average (decay 0.999) of the
        generator weights.  That average is used for the saved sample images
        and checkpoints.
        """
        text_encoder, image_encoder, caption_cnn, caption_rnn, netG, netsD, start_epoch = self.build_models()
        avg_param_G = copy_G_params(netG)
        optimizerG, optimizersD = self.define_optimizers(netG, netsD)
        real_labels, fake_labels, match_labels = self.prepare_labels()

        batch_size = self.batch_size
        nz = cfg.GAN.Z_DIM
        noise = Variable(torch.FloatTensor(batch_size, nz))
        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))
        if cfg.CUDA:
            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()

        gen_iterations = 0
        for epoch in range(start_epoch, self.max_epoch):
            start_t = time.time()

            data_iter = iter(self.data_loader)
            step = 0
            while step < self.num_batches:
                # (1) Prepare training data and Compute text embeddings
                # `next(it)` works on every Python / PyTorch version, unlike
                # the Python-2-style `it.next()` method.
                data = next(data_iter)
                imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

                hidden = text_encoder.init_hidden(batch_size)
                # words_embs: batch_size x nef x seq_len
                # sent_emb: batch_size x nef
                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                mask = (captions == 0)
                num_words = words_embs.size(2)
                if mask.size(1) > num_words:
                    mask = mask[:, :num_words]

                # (2) Generate fake images
                noise.data.normal_(0, 1)
                fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs, mask)

                # (3) Update D network
                errD_total = 0
                D_logs = ''
                for i in range(len(netsD)):
                    netsD[i].zero_grad()
                    errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                              sent_emb, real_labels, fake_labels)
                    # backward and update parameters
                    errD.backward()
                    optimizersD[i].step()
                    errD_total += errD
                    D_logs += 'errD%d: %.2f ' % (i, self._to_scalar(errD.data))

                # (4) Update G network: maximize log(D(G(z)))
                # compute total loss for training G
                step += 1
                gen_iterations += 1

                netG.zero_grad()
                errG_total, G_logs = \
                    generator_loss(netsD, image_encoder, caption_cnn, caption_rnn, captions, fake_imgs, real_labels,
                                   words_embs, sent_emb, match_labels, cap_lens, class_ids)
                kl_loss = KL_loss(mu, logvar)
                errG_total += kl_loss
                G_logs += 'kl_loss: %.2f ' % self._to_scalar(kl_loss.data)
                # backward and update parameters
                errG_total.backward()
                optimizerG.step()
                # Exponential moving average of the generator weights.
                for p, avg_p in zip(netG.parameters(), avg_param_G):
                    avg_p.mul_(0.999).add_(0.001, p.data)

                if gen_iterations % 100 == 0:
                    print(D_logs + '\n' + G_logs)
                # save images
                if gen_iterations % 1000 == 0:
                    backup_para = copy_G_params(netG)
                    load_params(netG, avg_param_G)
                    self.save_img_results(netG, fixed_noise, sent_emb,
                                          words_embs, mask, image_encoder,
                                          captions, cap_lens, epoch, name='average')
                    load_params(netG, backup_para)

            end_t = time.time()

            print('''[%d/%d][%d]
                  Loss_D: %.2f Loss_G: %.2f Time: %.2fs'''
                  % (epoch, self.max_epoch, self.num_batches,
                     self._to_scalar(errD_total.data), self._to_scalar(errG_total.data),
                     end_t - start_t))

            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:  # and epoch != 0:
                self.save_model(netG, avg_param_G, netsD, epoch)

        self.save_model(netG, avg_param_G, netsD, self.max_epoch)

    def save_singleimages(self, images, filenames, save_dir,
                          split_dir, sentenceID=0):
        """Save each image in ``images`` as ``<save_dir>/single_samples/<split_dir>/<filename>_<sentenceID>.jpg``."""
        for i in range(images.size(0)):
            s_tmp = '%s/single_samples/%s/%s' % \
                (save_dir, split_dir, filenames[i])
            folder = s_tmp[:s_tmp.rfind('/')]
            if not os.path.isdir(folder):
                print('Make a new folder: ', folder)
                mkdir_p(folder)

            fullpath = '%s_%d.jpg' % (s_tmp, sentenceID)
            # map values from [-1, 1] to bytes in [0, 255]
            img = images[i].add(1).div(2).mul(255).clamp(0, 255).byte()
            # CHW -> HWC for PIL
            ndarr = img.permute(1, 2, 0).data.cpu().numpy()
            im = Image.fromarray(ndarr)
            im.save(fullpath)

    def sampling(self, split_dir):
        """Generate one final-stage image per caption in ``self.data_loader``.

        Images are written to ``<NET_G without .pth>/<split_dir>/single/<key>_s-1.png``
        (``split_dir`` 'test' is renamed to 'valid').
        """
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for model is not found!')
        else:
            if split_dir == 'test':
                split_dir = 'valid'
            # Build and load the generator
            if cfg.GAN.B_DCGAN:
                netG = G_DCGAN()
            else:
                netG = G_NET()
            netG.apply(weights_init)
            netG.cuda()
            netG.eval()
            #
            text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
            state_dict = \
                torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
            text_encoder.load_state_dict(state_dict)
            print('Load text encoder from:', cfg.TRAIN.NET_E)
            text_encoder = text_encoder.cuda()
            text_encoder.eval()

            batch_size = self.batch_size
            nz = cfg.GAN.Z_DIM
            noise = Variable(torch.FloatTensor(batch_size, nz), volatile=True)
            noise = noise.cuda()

            model_dir = cfg.TRAIN.NET_G
            state_dict = \
                torch.load(model_dir, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load G from: ', model_dir)

            # the path to save generated images
            s_tmp = model_dir[:model_dir.rfind('.pth')]
            save_dir = '%s/%s' % (s_tmp, split_dir)
            mkdir_p(save_dir)

            for _ in range(1):  # (cfg.TEXT.CAPTIONS_PER_IMAGE):
                for step, data in enumerate(self.data_loader, 0):
                    if step % 100 == 0:
                        print('step: ', step)

                    imgs, captions, cap_lens, class_ids, keys = prepare_data(data)

                    hidden = text_encoder.init_hidden(batch_size)
                    # words_embs: batch_size x nef x seq_len
                    # sent_emb: batch_size x nef
                    words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                    words_embs, sent_emb = words_embs.detach(), sent_emb.detach()
                    mask = (captions == 0)
                    num_words = words_embs.size(2)
                    if mask.size(1) > num_words:
                        mask = mask[:, :num_words]

                    # (2) Generate fake images
                    noise.data.normal_(0, 1)
                    fake_imgs, _, _, _ = netG(noise, sent_emb, words_embs, mask)
                    for j in range(batch_size):
                        s_tmp = '%s/single/%s' % (save_dir, keys[j])
                        folder = s_tmp[:s_tmp.rfind('/')]
                        if not os.path.isdir(folder):
                            print('Make a new folder: ', folder)
                            mkdir_p(folder)
                        k = -1  # only the last (highest-resolution) stage
                        im = fake_imgs[k][j].data.cpu().numpy()
                        # [-1, 1] --> [0, 255]
                        im = (im + 1.0) * 127.5
                        im = im.astype(np.uint8)
                        im = np.transpose(im, (1, 2, 0))
                        im = Image.fromarray(im)
                        fullpath = '%s_s%d.png' % (s_tmp, k)
                        im.save(fullpath)

    def gen_example(self, data_dic):
        """Generate images (all stages) and attention maps for custom captions.

        Args:
            data_dic: maps a name to ``(captions, cap_lens, sorted_indices)``
                numpy arrays.  Outputs go to ``<NET_G without .pth>/<name>/``
                as ``<i>_s_<idx>_g<k>.png`` (stage k images) and
                ``<i>_s_<idx>_a<k>.png`` (attention maps).
        """
        if cfg.TRAIN.NET_G == '':
            print('Error: the path for models is not found!')
        else:
            # Build and load the generator
            text_encoder = \
                RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)
            state_dict = \
                torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)
            text_encoder.load_state_dict(state_dict)
            print('Load text encoder from:', cfg.TRAIN.NET_E)
            text_encoder = text_encoder.cuda()
            text_encoder.eval()

            # the path to save generated images
            if cfg.GAN.B_DCGAN:
                netG = G_DCGAN()
            else:
                netG = G_NET()
            s_tmp = cfg.TRAIN.NET_G[:cfg.TRAIN.NET_G.rfind('.pth')]
            model_dir = cfg.TRAIN.NET_G
            state_dict = \
                torch.load(model_dir, map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load G from: ', model_dir)
            netG.cuda()
            netG.eval()
            for key in data_dic:
                save_dir = '%s/%s' % (s_tmp, key)
                mkdir_p(save_dir)
                captions, cap_lens, sorted_indices = data_dic[key]

                batch_size = captions.shape[0]
                nz = cfg.GAN.Z_DIM
                captions = Variable(torch.from_numpy(captions), volatile=True)
                cap_lens = Variable(torch.from_numpy(cap_lens), volatile=True)

                captions = captions.cuda()
                cap_lens = cap_lens.cuda()
                for i in range(1):  # 16
                    noise = Variable(torch.FloatTensor(batch_size, nz), volatile=True)
                    noise = noise.cuda()
                    # (1) Extract text embeddings
                    hidden = text_encoder.init_hidden(batch_size)
                    # words_embs: batch_size x nef x seq_len
                    # sent_emb: batch_size x nef
                    words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)
                    mask = (captions == 0)
                    # Trim the padding mask to the embedded sequence length, as
                    # in train() and sampling().
                    num_words = words_embs.size(2)
                    if mask.size(1) > num_words:
                        mask = mask[:, :num_words]
                    # (2) Generate fake images
                    noise.data.normal_(0, 1)
                    fake_imgs, attention_maps, _, _ = netG(noise, sent_emb, words_embs, mask)
                    # G attention
                    cap_lens_np = cap_lens.cpu().data.numpy()
                    for j in range(batch_size):
                        save_name = '%s/%d_s_%d' % (save_dir, i, sorted_indices[j])
                        for k in range(len(fake_imgs)):
                            im = fake_imgs[k][j].data.cpu().numpy()
                            # [-1, 1] --> [0, 255], CHW -> HWC
                            im = (im + 1.0) * 127.5
                            im = im.astype(np.uint8)
                            im = np.transpose(im, (1, 2, 0))
                            im = Image.fromarray(im)
                            fullpath = '%s_g%d.png' % (save_name, k)
                            im.save(fullpath)

                        for k in range(len(attention_maps)):
                            if len(fake_imgs) > 1:
                                im = fake_imgs[k + 1].detach().cpu()
                            else:
                                im = fake_imgs[0].detach().cpu()
                            attn_maps = attention_maps[k]
                            att_sze = attn_maps.size(2)
                            img_set, sentences = \
                                build_super_images2(im[j].unsqueeze(0),
                                                    captions[j].unsqueeze(0),
                                                    [cap_lens_np[j]], self.ixtoword,
                                                    [attn_maps[j]], att_sze)
                            if img_set is not None:
                                im = Image.fromarray(img_set)
                                fullpath = '%s_a%d.png' % (save_name, k)
                                im.save(fullpath)
gitextract_0488eq_x/ ├── .gitignore ├── GLAttention.py ├── README.md ├── cfg/ │ ├── __init__.py │ ├── config.py │ ├── eval_bird.yml │ └── train_bird.yml ├── datasets.py ├── do_test.sh ├── do_train.sh ├── main.py ├── miscc/ │ ├── __init__.py │ ├── losses.py │ └── utils.py ├── model.py ├── pretrain_DAMSM.py ├── test.py └── trainer.py
SYMBOL INDEX (120 symbols across 10 files)
FILE: GLAttention.py
function conv1x1 (line 4) | def conv1x1(in_planes, out_planes):
function func_attention (line 10) | def func_attention(query, context, gamma1):
class GLAttentionGeneral (line 51) | class GLAttentionGeneral(nn.Module):
method __init__ (line 52) | def __init__(self, idf, cdf):
method applyMask (line 60) | def applyMask(self, mask):
method forward (line 63) | def forward(self, input, sentence, context):
FILE: cfg/config.py
function _merge_a_into_b (line 77) | def _merge_a_into_b(a, b):
function cfg_from_file (line 110) | def cfg_from_file(filename):
FILE: datasets.py
function prepare_data (line 28) | def prepare_data(data):
function get_imgs (line 59) | def get_imgs(img_path, imsize, bbox=None,
class TextDataset (line 91) | class TextDataset(data.Dataset):
method __init__ (line 92) | def __init__(self, data_dir, split='train',
method load_bbox (line 121) | def load_bbox(self):
method load_captions (line 145) | def load_captions(self, data_dir, filenames):
method build_dictionary (line 179) | def build_dictionary(self, train_captions, test_captions):
method load_text_data (line 219) | def load_text_data(self, data_dir, split):
method load_class_id (line 251) | def load_class_id(self, data_dir, total_num):
method load_filenames (line 259) | def load_filenames(self, data_dir, split):
method get_caption (line 269) | def get_caption(self, sent_ix):
method __getitem__ (line 289) | def __getitem__(self, index):
method __len__ (line 311) | def __len__(self):
FILE: main.py
function parse_args (line 24) | def parse_args():
function gen_example (line 36) | def gen_example(wordtoix, algo):
FILE: miscc/losses.py
function cosine_similarity (line 11) | def cosine_similarity(x1, x2, dim=1, eps=1e-8):
function caption_loss (line 19) | def caption_loss(cap_output, captions):
function sent_loss (line 24) | def sent_loss(cnn_code, rnn_code, labels, class_ids,
function words_loss (line 66) | def words_loss(img_features, words_emb, labels,
function discriminator_loss (line 140) | def discriminator_loss(netD, real_imgs, fake_imgs, conditions,
function generator_loss (line 168) | def generator_loss(netsD, image_encoder, caption_cnn, caption_rnn, capti...
function KL_loss (line 204) | def KL_loss(mu, logvar):
FILE: miscc/utils.py
function drawCaption (line 30) | def drawCaption(convas, captions, ixtoword, vis_size, off1=2, off2=2):
function build_super_images (line 53) | def build_super_images(real_imgs, captions, ixtoword,
function build_super_images2 (line 179) | def build_super_images2(real_imgs, captions, cap_lens, ixtoword,
function weights_init (line 285) | def weights_init(m):
function load_params (line 298) | def load_params(model, new_param):
function copy_G_params (line 303) | def copy_G_params(model):
function mkdir_p (line 308) | def mkdir_p(path):
FILE: model.py
class GLU (line 13) | class GLU(nn.Module):
method __init__ (line 14) | def __init__(self):
method forward (line 17) | def forward(self, x):
function conv1x1 (line 24) | def conv1x1(in_planes, out_planes, bias=False):
function conv3x3 (line 29) | def conv3x3(in_planes, out_planes):
function upBlock (line 35) | def upBlock(in_planes, out_planes):
function Block3x3_relu (line 44) | def Block3x3_relu(in_planes, out_planes):
class ResBlock (line 52) | class ResBlock(nn.Module):
method __init__ (line 53) | def __init__(self, channel_num):
method forward (line 62) | def forward(self, x):
class RNN_ENCODER (line 70) | class RNN_ENCODER(nn.Module):
method __init__ (line 71) | def __init__(self, ntoken, ninput=300, drop_prob=0.5,
method define_module (line 91) | def define_module(self):
method init_weights (line 109) | def init_weights(self):
method init_hidden (line 117) | def init_hidden(self, bsz):
method forward (line 128) | def forward(self, captions, cap_lens, hidden, mask=None):
class CNN_ENCODER (line 157) | class CNN_ENCODER(nn.Module):
method __init__ (line 158) | def __init__(self, nef):
method define_module (line 176) | def define_module(self, model):
method init_trainable_weights (line 197) | def init_trainable_weights(self):
method forward (line 202) | def forward(self, x):
class CA_NET (line 266) | class CA_NET(nn.Module):
method __init__ (line 269) | def __init__(self):
method encode (line 276) | def encode(self, text_embedding):
method reparametrize (line 282) | def reparametrize(self, mu, logvar):
method forward (line 291) | def forward(self, text_embedding):
class INIT_STAGE_G (line 297) | class INIT_STAGE_G(nn.Module):
method __init__ (line 298) | def __init__(self, ngf, ncf):
method define_module (line 305) | def define_module(self):
method forward (line 317) | def forward(self, z_code, c_code):
class NEXT_STAGE_G (line 380) | class NEXT_STAGE_G(nn.Module):
method __init__ (line 381) | def __init__(self, ngf, nef, ncf):
method _make_layer (line 392) | def _make_layer(self, block, channel_num):
method define_module (line 398) | def define_module(self):
method forward (line 404) | def forward(self, h_code, c_code, word_embs, mask):
class GET_IMAGE_G (line 436) | class GET_IMAGE_G(nn.Module):
method __init__ (line 437) | def __init__(self, ngf):
method forward (line 445) | def forward(self, h_code):
class G_NET (line 450) | class G_NET(nn.Module):
method __init__ (line 451) | def __init__(self):
method forward (line 469) | def forward(self, z_code, sent_emb, word_embs, mask):
class G_DCGAN (line 508) | class G_DCGAN(nn.Module):
method __init__ (line 509) | def __init__(self):
method forward (line 526) | def forward(self, z_code, sent_emb, word_embs, mask):
function Block3x3_leakRelu (line 552) | def Block3x3_leakRelu(in_planes, out_planes):
function downBlock (line 562) | def downBlock(in_planes, out_planes):
function encode_image_by_16times (line 572) | def encode_image_by_16times(ndf):
class D_GET_LOGITS (line 593) | class D_GET_LOGITS(nn.Module):
method __init__ (line 594) | def __init__(self, ndf, nef, bcondition=False):
method forward (line 606) | def forward(self, h_code, c_code=None):
class D_NET64 (line 622) | class D_NET64(nn.Module):
method __init__ (line 623) | def __init__(self, b_jcu=True):
method forward (line 634) | def forward(self, x_var):
class D_NET128 (line 640) | class D_NET128(nn.Module):
method __init__ (line 641) | def __init__(self, b_jcu=True):
method forward (line 655) | def forward(self, x_var):
class D_NET256 (line 663) | class D_NET256(nn.Module):
method __init__ (line 664) | def __init__(self, b_jcu=True):
method forward (line 679) | def forward(self, x_var):
class CAPTION_CNN (line 686) | class CAPTION_CNN(nn.Module):
method __init__ (line 687) | def __init__(self, embed_size):
method forward (line 698) | def forward(self, images):
class CAPTION_RNN (line 709) | class CAPTION_RNN(nn.Module):
method __init__ (line 710) | def __init__(self, embed_size, hidden_size, vocab_size, num_layers, ma...
method forward (line 730) | def forward(self, features, captions, cap_lens):
method sample (line 739) | def sample(self, features, states=None):
FILE: pretrain_DAMSM.py
function parse_args (line 37) | def parse_args():
function train (line 49) | def train(dataloader, cnn_model, rnn_model, batch_size,
function evaluate (line 133) | def evaluate(dataloader, cnn_model, rnn_model, batch_size):
function build_models (line 166) | def build_models():
FILE: test.py
function conv1x1 (line 5) | def conv1x1(in_planes, out_planes):
FILE: trainer.py
class Trainer (line 23) | class Trainer(object):
method __init__ (line 24) | def __init__(self, output_dir, data_loader, n_words, ixtoword):
method build_models (line 43) | def build_models(self):
method define_optimizers (line 142) | def define_optimizers(self, netG, netsD):
method prepare_labels (line 157) | def prepare_labels(self):
method save_model (line 169) | def save_model(self, netG, avg_param_G, netsD, epoch):
method set_requires_grad_value (line 182) | def set_requires_grad_value(self, models_list, brequires):
method save_img_results (line 187) | def save_img_results(self, netG, noise, sent_emb, words_embs, mask,
method train (line 227) | def train(self):
method save_singleimages (line 318) | def save_singleimages(self, images, filenames, save_dir,
method sampling (line 337) | def sampling(self, split_dir):
method gen_example (line 418) | def gen_example(self, data_dic):
Condensed preview — 18 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (110K chars).
[
{
"path": ".gitignore",
"chars": 35,
"preview": "*.pyc\nmiscc/*.pyc\n.DS_Store\n.idea/\n"
},
{
"path": "GLAttention.py",
"chars": 4593,
"preview": "import torch\nimport torch.nn as nn\n\ndef conv1x1(in_planes, out_planes):\n \"1x1 convolution with padding\"\n return nn"
},
{
"path": "README.md",
"chars": 1571,
"preview": "# MirrorGAN\n\nPytorch implementation for Paper [MirrorGAN: Learning Text-to-image Generation by Redescription](https://ar"
},
{
"path": "cfg/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "cfg/config.py",
"chars": 2820,
"preview": "from __future__ import division\nfrom __future__ import print_function\n\nimport os.path as osp\nimport numpy as np\nfrom eas"
},
{
"path": "cfg/eval_bird.yml",
"chars": 476,
"preview": "CONFIG_NAME: 'MirrorGAN'\nDATASET_NAME: 'birds'\nDATA_DIR: '../data/birds'\nGPU_ID: 3\nWORKERS: 1\n\nB_VALIDATION: True # Tru"
},
{
"path": "cfg/train_bird.yml",
"chars": 866,
"preview": "CONFIG_NAME: 'MirrorGAN'\nDATASET_NAME: 'birds'\nDATA_DIR: '../data/birds'\nGPU_ID: 3\nWORKERS: 4\nOUTPUT_PATH: '/data/qtt/Mi"
},
{
"path": "datasets.py",
"chars": 11129,
"preview": "from __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\nfrom __futu"
},
{
"path": "do_test.sh",
"chars": 48,
"preview": "cfg=cfg/eval_bird.yml\npython main.py --cfg $cfg\n"
},
{
"path": "do_train.sh",
"chars": 49,
"preview": "cfg=cfg/train_bird.yml\npython main.py --cfg $cfg\n"
},
{
"path": "main.py",
"chars": 5098,
"preview": "from __future__ import print_function\n\nfrom cfg.config import cfg, cfg_from_file\nfrom datasets import TextDataset\nfrom t"
},
{
"path": "miscc/__init__.py",
"chars": 70,
"preview": "from __future__ import division\nfrom __future__ import print_function\n"
},
{
"path": "miscc/losses.py",
"chars": 7793,
"preview": "import torch\nimport torch.nn as nn\n\nimport numpy as np\nfrom cfg.config import cfg\nfrom torch.nn.utils.rnn import pack_pa"
},
{
"path": "miscc/utils.py",
"chars": 10954,
"preview": "import os\nimport errno\nimport numpy as np\nfrom torch.nn import init\n\nimport torch\nimport torch.nn as nn\n\nfrom PIL import"
},
{
"path": "model.py",
"chars": 28148,
"preview": "import torch\nimport torch.nn as nn\nimport torch.nn.parallel\nfrom torch.autograd import Variable\nfrom torchvision import "
},
{
"path": "pretrain_DAMSM.py",
"chars": 10745,
"preview": "from __future__ import print_function\n\nfrom miscc.utils import mkdir_p\nfrom miscc.utils import build_super_images\nfrom m"
},
{
"path": "test.py",
"chars": 422,
"preview": "import torch.nn as nn\nimport torch\nfrom torch.autograd import Variable\n\ndef conv1x1(in_planes, out_planes):\n \"1x1 con"
},
{
"path": "trainer.py",
"chars": 21646,
"preview": "from __future__ import print_function\nfrom six.moves import range\nimport torch\nimport torch.optim as optim\nfrom torch.au"
}
]
About this extraction
This page contains the full source code of the qiaott/MirrorGAN GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 18 files (104.0 KB), approximately 28.3k tokens, and a symbol index with 120 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.