Repository: qiaott/MirrorGAN Branch: master Commit: deb220fdae8f Files: 18 Total size: 104.0 KB Directory structure: gitextract_0488eq_x/ ├── .gitignore ├── GLAttention.py ├── README.md ├── cfg/ │ ├── __init__.py │ ├── config.py │ ├── eval_bird.yml │ └── train_bird.yml ├── datasets.py ├── do_test.sh ├── do_train.sh ├── main.py ├── miscc/ │ ├── __init__.py │ ├── losses.py │ └── utils.py ├── model.py ├── pretrain_DAMSM.py ├── test.py └── trainer.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ *.pyc miscc/*.pyc .DS_Store .idea/ ================================================ FILE: GLAttention.py ================================================ import torch import torch.nn as nn def conv1x1(in_planes, out_planes): "1x1 convolution with padding" return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) def func_attention(query, context, gamma1): """ query: batch x ndf x queryL context: batch x ndf x ih x iw (sourceL=ihxiw) mask: batch_size x sourceL """ batch_size, queryL = query.size(0), query.size(2) ih, iw = context.size(2), context.size(3) sourceL = ih * iw # --> batch x sourceL x ndf context = context.view(batch_size, -1, sourceL) contextT = torch.transpose(context, 1, 2).contiguous() # Get attention # (batch x sourceL x ndf)(batch x ndf x queryL) # -->batch x sourceL x queryL attn = torch.bmm(contextT, query) # --> batch*sourceL x queryL attn = attn.view(batch_size*sourceL, queryL) attn = nn.Softmax()(attn) # Eq. (8) # --> batch x sourceL x queryL attn = attn.view(batch_size, sourceL, queryL) # --> batch*queryL x sourceL attn = torch.transpose(attn, 1, 2).contiguous() attn = attn.view(batch_size*queryL, sourceL) attn = attn * gamma1 attn = nn.Softmax()(attn) attn = attn.view(batch_size, queryL, sourceL) # --> batch x sourceL x queryL attnT = torch.transpose(attn, 1, 2).contiguous() # (batch x ndf x sourceL)(batch x sourceL x queryL) # --> batch x ndf x queryL weightedContext = torch.bmm(context, attnT) return weightedContext, attn.view(batch_size, -1, ih, iw) class GLAttentionGeneral(nn.Module): def __init__(self, idf, cdf): super(GLAttentionGeneral, self).__init__() self.conv_context = conv1x1(cdf, idf) self.conv_sentence_vis = conv1x1(idf, idf) self.linear = nn.Linear(100, idf) self.sm = nn.Softmax() self.mask = None def applyMask(self, mask): self.mask = mask # batch x sourceL def forward(self, input, sentence, context): """ input: batch x idf x ih x iw (queryL=ihxiw) context: batch x cdf x sourceL (this is the matrix of word vectors) sentence (c_code1): batch x idf x queryL (this is the vectors of the sentence) queryL=ih x iw """ idf, ih, iw = input.size(1), input.size(2), input.size(3) queryL = ih * iw batch_size, sourceL = context.size(0), context.size(2) # generated image feature:--> batch x queryL x idf target = input.view(batch_size, -1, queryL) # batch x idf x queryL targetT = torch.transpose(target, 1, 2).contiguous() # batch x queryL x idf # Eq(4) in MirrorGAN : local-level attention # words feature: batch x cdf x sourceL --> batch x cdf x sourceL x 1 sourceT = context.unsqueeze(3) # --> batch x idf x sourceL sourceT = self.conv_context(sourceT).squeeze(3) attn = torch.bmm(targetT, sourceT) # --> batch*queryL x sourceL attn = attn.view(batch_size*queryL, sourceL) if self.mask is not None: # batch_size x sourceL --> batch_size*queryL x sourceL mask = self.mask.repeat(queryL, 1) attn.data.masked_fill_(mask.data, -float('inf')) attn = self.sm(attn) # Eq. (2) # --> batch x queryL x sourceL attn = attn.view(batch_size, queryL, sourceL) # --> batch x sourceL x queryL attn = torch.transpose(attn, 1, 2).contiguous() # (batch x idf x sourceL)(batch x sourceL x queryL) # --> batch x idf x queryL weightedContext = torch.bmm(sourceT, attn) weightedContext = weightedContext.view(batch_size, -1, ih, iw) # batch x idf x ih x iw word_attn = attn.view(batch_size, -1, ih, iw) # (batch x sourceL x ih x iw) # Eq(5) in MirrorGAN : global-level attention sentence = self.linear(sentence) sentence = sentence.view(batch_size, idf, 1, 1) sentence = sentence.repeat(1, 1, ih, iw) sentence_vs = torch.mul(input, sentence) # batch x idf x ih x iw sentence_vs = self.conv_sentence_vis(sentence_vs) # batch x idf x ih x iw sent_att = nn.Softmax()(sentence_vs) # batch x idf x ih x iw weightedSentence = torch.mul(sentence, sent_att) # batch x idf x ih x iw return weightedContext, weightedSentence, word_attn, sent_att # weightedContext: batch x idf x ih x iw # weightedSentence: batch x idf x ih x iw # word_attn: batch x sourceL x ih x iw # sent_vs_att: batch x idf x ih x iw ================================================ FILE: README.md ================================================ # MirrorGAN Pytorch implementation for Paper [MirrorGAN: Learning Text-to-image Generation by Redescription](https://arxiv.org/abs/1903.05854) by Tingting Qiao, Jing Zhang, Duanqing Xu, Dacheng Tao. (The work was performed when Tingting Qiao was a visiting student at UBTECH Sydney AI Centre in the School of Computer Science, FEIT, the University of Sydney). ![image](images/framework.jpg) ## Getting Started ### Installation - Install PyTorch and dependencies from http://pytorch.org - Install Torch vision from the source. - Clone this repo: ```bash git clone https://github.com/qiaott/MirrorGAN.git cd MirrorGAN ``` - Download our preprocessed data from [here](https://drive.google.com/file/d/1CuW5ognTSkNbyx9TWoUFrgwqxZNk1cl0/view?usp=sharing). - The STEM was pretrained using the code provided [here](https://github.com/taoxugit/AttnGAN) - The STREAM was pretrained using the code provided [here](https://github.com/yunjey/pytorch-tutorial/tree/master/tutorials/03-advanced/image_captioning). ### Train/Test After obtaining the pretrained STEM and STREAM modules, we can train the text2image model. - Train a model: ```bash ./do_train.sh ``` - Test a model: ```bash ./do_test.sh ``` ## Citation If you use this code for your research, please cite our paper. ```bash @article{qiao2019mirrorgan, title={MirrorGAN: Learning Text-to-image Generation by Redescription}, author={Qiao, Tingting and Zhang, Jing and Xu, Duanqing and Tao, Dacheng}, journal={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, year={2019} } ``` ================================================ FILE: cfg/__init__.py ================================================ ================================================ FILE: cfg/config.py ================================================ from __future__ import division from __future__ import print_function import os.path as osp import numpy as np from easydict import EasyDict as edict __C = edict() cfg = __C # Dataset name: flowers, birds __C.DATASET_NAME = 'birds' __C.CONFIG_NAME = '' __C.DATA_DIR = '' __C.GPU_ID = 0 __C.CUDA = True __C.WORKERS = 6 __C.OUTPUT_PATH = '' __C.RNN_TYPE = 'LSTM' # 'GRU' __C.B_VALIDATION = False __C.TREE = edict() __C.TREE.BRANCH_NUM = 3 __C.TREE.BASE_SIZE = 64 # Training options __C.TRAIN = edict() __C.TRAIN.BATCH_SIZE = 64 __C.TRAIN.MAX_EPOCH = 600 __C.TRAIN.SNAPSHOT_INTERVAL = 2000 __C.TRAIN.DISCRIMINATOR_LR = 2e-4 __C.TRAIN.GENERATOR_LR = 2e-4 __C.TRAIN.ENCODER_LR = 2e-4 __C.TRAIN.RNN_GRAD_CLIP = 0.25 __C.TRAIN.FLAG = True __C.TRAIN.NET_E = '' __C.TRAIN.NET_G = '' __C.TRAIN.B_NET_D = True __C.TRAIN.SMOOTH = edict() __C.TRAIN.SMOOTH.GAMMA1 = 5.0 __C.TRAIN.SMOOTH.GAMMA3 = 10.0 __C.TRAIN.SMOOTH.GAMMA2 = 5.0 __C.TRAIN.SMOOTH.LAMBDA = 0.0 __C.TRAIN.SMOOTH.LAMBDA1 = 1.0 # Caption_model_settings added by tingting __C.CAP = edict() __C.CAP.embed_size = 256 __C.CAP.hidden_size = 512 __C.CAP.num_layers = 1 __C.CAP.learning_rate = 0.001 __C.CAP.caption_cnn_path = '' __C.CAP.caption_rnn_path = '' # Modal options __C.GAN = edict() __C.GAN.DF_DIM = 64 __C.GAN.GF_DIM = 128 __C.GAN.Z_DIM = 100 __C.GAN.CONDITION_DIM = 100 __C.GAN.R_NUM = 2 __C.GAN.B_ATTENTION = True __C.GAN.B_DCGAN = False __C.TEXT = edict() __C.TEXT.CAPTIONS_PER_IMAGE = 10 __C.TEXT.EMBEDDING_DIM = 256 __C.TEXT.WORDS_NUM = 18 def _merge_a_into_b(a, b): """Merge config dictionary a into config dictionary b, clobbering the options in b whenever they are also specified in a. """ if type(a) is not edict: return for k, v in a.iteritems(): # a must specify keys that are in b if not b.has_key(k): raise KeyError('{} is not a valid config key'.format(k)) # the types must match, too old_type = type(b[k]) if old_type is not type(v): if isinstance(b[k], np.ndarray): v = np.array(v, dtype=b[k].dtype) else: raise ValueError(('Type mismatch ({} vs. {}) ' 'for config key: {}').format(type(b[k]), type(v), k)) # recursively merge dicts if type(v) is edict: try: _merge_a_into_b(a[k], b[k]) except: print('Error under config key: {}'.format(k)) raise else: b[k] = v def cfg_from_file(filename): """Load a config file and merge it into the default options.""" import yaml with open(filename, 'r') as f: yaml_cfg = edict(yaml.load(f)) _merge_a_into_b(yaml_cfg, __C) ================================================ FILE: cfg/eval_bird.yml ================================================ CONFIG_NAME: 'MirrorGAN' DATASET_NAME: 'birds' DATA_DIR: '../data/birds' GPU_ID: 3 WORKERS: 1 B_VALIDATION: True # True # False TREE: BRANCH_NUM: 3 TRAIN: FLAG: False NET_G: '../data/output/bird/Model/netG.pth' # path to the trained model B_NET_D: False BATCH_SIZE: 12 NET_E: '../data/STEM/text_encoder.pth' GAN: DF_DIM: 64 GF_DIM: 32 Z_DIM: 100 R_NUM: 2 TEXT: EMBEDDING_DIM: 256 CAPTIONS_PER_IMAGE: 10 WORDS_NUM: 25 ================================================ FILE: cfg/train_bird.yml ================================================ CONFIG_NAME: 'MirrorGAN' DATASET_NAME: 'birds' DATA_DIR: '../data/birds' GPU_ID: 3 WORKERS: 4 OUTPUT_PATH: '/data/qtt/MirrorGAN/' TREE: BRANCH_NUM: 3 TRAIN: FLAG: True NET_G: '' B_NET_D: True BATCH_SIZE: 12 # 22 MAX_EPOCH: 650 SNAPSHOT_INTERVAL: 50 DISCRIMINATOR_LR: 0.0002 GENERATOR_LR: 0.0002 NET_E: '../data/STEM/text_encoder.pth' SMOOTH: GAMMA1: 4.0 # 1,2,5 good 4 best 10&100bad GAMMA2: 5.0 GAMMA3: 10.0 # 10good 1&100bad LAMBDA: 0.0 LAMBDA1: 10.0 CAP: embed_size: 256 hidden_size: 256 num_layers: 1 learning_rate: 0.001 caption_cnn_path: '../data/STREAM/cnn_encoder.ckpt' caption_rnn_path: '../data/STREAM/rnn_decoder.ckpt' GAN: DF_DIM: 64 GF_DIM: 32 Z_DIM: 100 R_NUM: 2 TEXT: EMBEDDING_DIM: 256 CAPTIONS_PER_IMAGE: 10 ================================================ FILE: datasets.py ================================================ from __future__ import absolute_import from __future__ import division from __future__ import print_function from __future__ import unicode_literals from nltk.tokenize import RegexpTokenizer from collections import defaultdict from cfg.config import cfg import torch import torch.utils.data as data from torch.autograd import Variable import torchvision.transforms as transforms import os import sys import numpy as np import pandas as pd from PIL import Image import numpy.random as random if sys.version_info[0] == 2: import cPickle as pickle else: import pickle def prepare_data(data): imgs, captions, captions_lens, class_ids, keys = data # sort data by the length in a decreasing order sorted_cap_lens, sorted_cap_indices = \ torch.sort(captions_lens, 0, True) real_imgs = [] for i in range(len(imgs)): imgs[i] = imgs[i][sorted_cap_indices] if cfg.CUDA: real_imgs.append(Variable(imgs[i]).cuda()) else: real_imgs.append(Variable(imgs[i])) captions = captions[sorted_cap_indices].squeeze() class_ids = class_ids[sorted_cap_indices].numpy() # sent_indices = sent_indices[sorted_cap_indices] keys = [keys[i] for i in sorted_cap_indices.numpy()] # print('keys', type(keys), keys[-1]) # list if cfg.CUDA: captions = Variable(captions).cuda() sorted_cap_lens = Variable(sorted_cap_lens).cuda() else: captions = Variable(captions) sorted_cap_lens = Variable(sorted_cap_lens) return [real_imgs, captions, sorted_cap_lens, class_ids, keys] def get_imgs(img_path, imsize, bbox=None, transform=None, normalize=None): img = Image.open(img_path).convert('RGB') width, height = img.size if bbox is not None: r = int(np.maximum(bbox[2], bbox[3]) * 0.75) center_x = int((2 * bbox[0] + bbox[2]) / 2) center_y = int((2 * bbox[1] + bbox[3]) / 2) y1 = np.maximum(0, center_y - r) y2 = np.minimum(height, center_y + r) x1 = np.maximum(0, center_x - r) x2 = np.minimum(width, center_x + r) img = img.crop([x1, y1, x2, y2]) if transform is not None: img = transform(img) ret = [] if cfg.GAN.B_DCGAN: ret = [normalize(img)] else: for i in range(cfg.TREE.BRANCH_NUM): # print(imsize[i]) if i < (cfg.TREE.BRANCH_NUM - 1): re_img = transforms.Scale(imsize[i])(img) else: re_img = img ret.append(normalize(re_img)) return ret class TextDataset(data.Dataset): def __init__(self, data_dir, split='train', base_size=64, transform=None, target_transform=None): self.transform = transform self.norm = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) self.target_transform = target_transform self.embeddings_num = cfg.TEXT.CAPTIONS_PER_IMAGE self.imsize = [] for i in range(cfg.TREE.BRANCH_NUM): self.imsize.append(base_size) base_size = base_size * 2 self.data = [] self.data_dir = data_dir if data_dir.find('birds') != -1: self.bbox = self.load_bbox() else: self.bbox = None split_dir = os.path.join(data_dir, split) self.filenames, self.captions, self.ixtoword, \ self.wordtoix, self.n_words = self.load_text_data(data_dir, split) self.class_id = self.load_class_id(split_dir, len(self.filenames)) self.number_example = len(self.filenames) def load_bbox(self): data_dir = self.data_dir bbox_path = os.path.join(data_dir, 'CUB_200_2011/bounding_boxes.txt') df_bounding_boxes = pd.read_csv(bbox_path, delim_whitespace=True, header=None).astype(int) # filepath = os.path.join(data_dir, 'CUB_200_2011/images.txt') df_filenames = \ pd.read_csv(filepath, delim_whitespace=True, header=None) filenames = df_filenames[1].tolist() print('Total filenames: ', len(filenames), filenames[0]) # filename_bbox = {img_file[:-4]: [] for img_file in filenames} numImgs = len(filenames) for i in xrange(0, numImgs): # bbox = [x-left, y-top, width, height] bbox = df_bounding_boxes.iloc[i][1:].tolist() key = filenames[i][:-4] filename_bbox[key] = bbox # return filename_bbox def load_captions(self, data_dir, filenames): all_captions = [] for i in range(len(filenames)): cap_path = '%s/text/%s.txt' % (data_dir, filenames[i]) with open(cap_path, "r") as f: captions = f.read().decode('utf8').split('\n') cnt = 0 for cap in captions: if len(cap) == 0: continue cap = cap.replace("\ufffd\ufffd", " ") # picks out sequences of alphanumeric characters as tokens # and drops everything else tokenizer = RegexpTokenizer(r'\w+') tokens = tokenizer.tokenize(cap.lower()) # print('tokens', tokens) if len(tokens) == 0: print('cap', cap) continue tokens_new = [] for t in tokens: t = t.encode('ascii', 'ignore').decode('ascii') if len(t) > 0: tokens_new.append(t) all_captions.append(tokens_new) cnt += 1 if cnt == self.embeddings_num: break if cnt < self.embeddings_num: print('ERROR: the captions for %s less than %d' % (filenames[i], cnt)) return all_captions def build_dictionary(self, train_captions, test_captions): word_counts = defaultdict(float) captions = train_captions + test_captions for sent in captions: for word in sent: word_counts[word] += 1 vocab = [w for w in word_counts if word_counts[w] >= 0] ixtoword = {} ixtoword[0] = '' wordtoix = {} wordtoix[''] = 0 ix = 1 for w in vocab: wordtoix[w] = ix ixtoword[ix] = w ix += 1 train_captions_new = [] for t in train_captions: rev = [] for w in t: if w in wordtoix: rev.append(wordtoix[w]) # rev.append(0) # do not need '' token train_captions_new.append(rev) test_captions_new = [] for t in test_captions: rev = [] for w in t: if w in wordtoix: rev.append(wordtoix[w]) # rev.append(0) # do not need '' token test_captions_new.append(rev) return [train_captions_new, test_captions_new, ixtoword, wordtoix, len(ixtoword)] def load_text_data(self, data_dir, split): filepath = os.path.join(data_dir, 'bird_captions.pickle') train_names = self.load_filenames(data_dir, 'train') test_names = self.load_filenames(data_dir, 'test') if not os.path.isfile(filepath): train_captions = self.load_captions(data_dir, train_names) test_captions = self.load_captions(data_dir, test_names) train_captions, test_captions, ixtoword, wordtoix, n_words = \ self.build_dictionary(train_captions, test_captions) with open(filepath, 'wb') as f: pickle.dump([train_captions, test_captions, ixtoword, wordtoix], f, protocol=2) print('Save to: ', filepath) else: with open(filepath, 'rb') as f: x = pickle.load(f) train_captions, test_captions = x[0], x[1] ixtoword, wordtoix = x[2], x[3] del x n_words = len(ixtoword) print('Load from: ', filepath) if split == 'train': # a list of list: each list contains # the indices of words in a sentence captions = train_captions filenames = train_names else: # split=='test' captions = test_captions filenames = test_names return filenames, captions, ixtoword, wordtoix, n_words def load_class_id(self, data_dir, total_num): if os.path.isfile(data_dir + '/class_info.pickle'): with open(data_dir + '/class_info.pickle', 'rb') as f: class_id = pickle.load(f) else: class_id = np.arange(total_num) return class_id def load_filenames(self, data_dir, split): filepath = '%s/%s/filenames.pickle' % (data_dir, split) if os.path.isfile(filepath): with open(filepath, 'rb') as f: filenames = pickle.load(f) print('Load filenames from: %s (%d)' % (filepath, len(filenames))) else: filenames = [] return filenames def get_caption(self, sent_ix): # a list of indices for a sentence sent_caption = np.asarray(self.captions[sent_ix]).astype('int64') # if (sent_caption == 0).sum() > 0: # print('ERROR: do not need END (0) token', sent_caption) num_words = len(sent_caption) # pad with 0s (i.e., '') x = np.zeros((cfg.TEXT.WORDS_NUM, 1), dtype='int64') x_len = num_words if num_words <= cfg.TEXT.WORDS_NUM: x[:num_words, 0] = sent_caption else: ix = list(np.arange(num_words)) # 1, 2, 3,..., maxNum np.random.shuffle(ix) ix = ix[:cfg.TEXT.WORDS_NUM] ix = np.sort(ix) x[:, 0] = sent_caption[ix] x_len = cfg.TEXT.WORDS_NUM return x, x_len def __getitem__(self, index): # key = self.filenames[index] cls_id = self.class_id[index] # if self.bbox is not None: bbox = self.bbox[key] data_dir = '%s/CUB_200_2011' % self.data_dir else: bbox = None data_dir = self.data_dir # img_name = '%s/images/%s.jpg' % (data_dir, key) imgs = get_imgs(img_name, self.imsize, bbox, self.transform, normalize=self.norm) # random select a sentence sent_ix = random.randint(0, self.embeddings_num) new_sent_ix = index * self.embeddings_num + sent_ix caps, cap_len = self.get_caption(new_sent_ix) return imgs, caps, cap_len, cls_id, key def __len__(self): return len(self.filenames) ================================================ FILE: do_test.sh ================================================ cfg=cfg/eval_bird.yml python main.py --cfg $cfg ================================================ FILE: do_train.sh ================================================ cfg=cfg/train_bird.yml python main.py --cfg $cfg ================================================ FILE: main.py ================================================ from __future__ import print_function from cfg.config import cfg, cfg_from_file from datasets import TextDataset from trainer import Trainer as trainer import os import sys import time import random import pprint import datetime import dateutil.tz import argparse import numpy as np import torch import torchvision.transforms as transforms dir_path = (os.path.abspath(os.path.join(os.path.realpath(__file__), './.'))) sys.path.append(dir_path) def parse_args(): parser = argparse.ArgumentParser(description='Train a AttnGAN network') parser.add_argument('--cfg', dest='cfg_file', help='optional config file', default='cfg/bird_attn2.yml', type=str) parser.add_argument('--gpu', dest='gpu_id', type=int, default=-1) parser.add_argument('--data_dir', dest='data_dir', type=str, default='') parser.add_argument('--manualSeed', type=int, help='manual seed') args = parser.parse_args() return args def gen_example(wordtoix, algo): '''generate images from example sentences''' from nltk.tokenize import RegexpTokenizer filepath = '%s/example_filenames.txt' % (cfg.DATA_DIR) data_dic = {} with open(filepath, "r") as f: filenames = f.read().decode('utf8').split('\n') for name in filenames: if len(name) == 0: continue filepath = '%s/%s.txt' % (cfg.DATA_DIR, name) with open(filepath, "r") as f: print('Load from:', name) sentences = f.read().decode('utf8').split('\n') # a list of indices for a sentence captions = [] cap_lens = [] for sent in sentences: if len(sent) == 0: continue sent = sent.replace("\ufffd\ufffd", " ") tokenizer = RegexpTokenizer(r'\w+') tokens = tokenizer.tokenize(sent.lower()) if len(tokens) == 0: print('sent', sent) continue rev = [] for t in tokens: t = t.encode('ascii', 'ignore').decode('ascii') if len(t) > 0 and t in wordtoix: rev.append(wordtoix[t]) captions.append(rev) cap_lens.append(len(rev)) max_len = np.max(cap_lens) sorted_indices = np.argsort(cap_lens)[::-1] cap_lens = np.asarray(cap_lens) cap_lens = cap_lens[sorted_indices] cap_array = np.zeros((len(captions), max_len), dtype='int64') for i in range(len(captions)): idx = sorted_indices[i] cap = captions[idx] c_len = len(cap) cap_array[i, :c_len] = cap key = name[(name.rfind('/') + 1):] data_dic[key] = [cap_array, cap_lens, sorted_indices] algo.gen_example(data_dic) if __name__ == "__main__": args = parse_args() if args.cfg_file is not None: cfg_from_file(args.cfg_file) if args.data_dir != '': cfg.DATA_DIR = args.data_dir print('Using config:') pprint.pprint(cfg) if not cfg.TRAIN.FLAG: args.manualSeed = 100 elif args.manualSeed is None: args.manualSeed = random.randint(1, 10000) random.seed(args.manualSeed) np.random.seed(args.manualSeed) torch.manual_seed(args.manualSeed) if cfg.CUDA: torch.cuda.manual_seed_all(args.manualSeed) now = datetime.datetime.now(dateutil.tz.tzlocal()) timestamp = now.strftime('%Y_%m_%d_%H_%M_%S') output_dir = '%s/output/%s_%s_%s' % \ (cfg.OUTPUT_PATH, cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp) split_dir, bshuffle = 'train', True if not cfg.TRAIN.FLAG: # bshuffle = False split_dir = 'test' # Get data loader imsize = cfg.TREE.BASE_SIZE * (2 ** (cfg.TREE.BRANCH_NUM - 1)) image_transform = transforms.Compose([ transforms.Scale(int(imsize * 76 / 64)), transforms.RandomCrop(imsize), transforms.RandomHorizontalFlip()]) dataset = TextDataset(cfg.DATA_DIR, split_dir, base_size=cfg.TREE.BASE_SIZE, transform=image_transform) assert dataset dataloader = torch.utils.data.DataLoader( dataset, batch_size=cfg.TRAIN.BATCH_SIZE, drop_last=True, shuffle=bshuffle, num_workers=int(cfg.WORKERS)) # Define models and go to train/evaluate algo = trainer(output_dir, dataloader, dataset.n_words, dataset.ixtoword) start_t = time.time() if cfg.TRAIN.FLAG: algo.train() else: '''generate images from pre-extracted embeddings''' if cfg.B_VALIDATION: algo.sampling(split_dir) # generate images for the whole valid dataset else: gen_example(dataset.wordtoix, algo) # generate images for customized captions end_t = time.time() print('Total time for training:', end_t - start_t) ================================================ FILE: miscc/__init__.py ================================================ from __future__ import division from __future__ import print_function ================================================ FILE: miscc/losses.py ================================================ import torch import torch.nn as nn import numpy as np from cfg.config import cfg from torch.nn.utils.rnn import pack_padded_sequence from GLAttention import func_attention # ##################Loss for matching text-image################### def cosine_similarity(x1, x2, dim=1, eps=1e-8): """Returns cosine similarity between x1 and x2, computed along dim. """ w12 = torch.sum(x1 * x2, dim) w1 = torch.norm(x1, 2, dim) w2 = torch.norm(x2, 2, dim) return (w12 / (w1 * w2).clamp(min=eps)).squeeze() def caption_loss(cap_output, captions): criterion = nn.CrossEntropyLoss() caption_loss = criterion(cap_output, captions) return caption_loss def sent_loss(cnn_code, rnn_code, labels, class_ids, batch_size, eps=1e-8): # ### Mask mis-match samples ### # that come from the same class as the real sample ### masks = [] if class_ids is not None: for i in range(batch_size): mask = (class_ids == class_ids[i]).astype(np.uint8) mask[i] = 0 masks.append(mask.reshape((1, -1))) masks = np.concatenate(masks, 0) # masks: batch_size x batch_size masks = torch.ByteTensor(masks) if cfg.CUDA: masks = masks.cuda() # --> seq_len x batch_size x nef if cnn_code.dim() == 2: cnn_code = cnn_code.unsqueeze(0) rnn_code = rnn_code.unsqueeze(0) # cnn_code_norm / rnn_code_norm: seq_len x batch_size x 1 cnn_code_norm = torch.norm(cnn_code, 2, dim=2, keepdim=True) rnn_code_norm = torch.norm(rnn_code, 2, dim=2, keepdim=True) # scores* / norm*: seq_len x batch_size x batch_size scores0 = torch.bmm(cnn_code, rnn_code.transpose(1, 2)) norm0 = torch.bmm(cnn_code_norm, rnn_code_norm.transpose(1, 2)) scores0 = scores0 / norm0.clamp(min=eps) * cfg.TRAIN.SMOOTH.GAMMA3 # --> batch_size x batch_size scores0 = scores0.squeeze() if class_ids is not None: scores0.data.masked_fill_(masks, -float('inf')) scores1 = scores0.transpose(0, 1) if labels is not None: loss0 = nn.CrossEntropyLoss()(scores0, labels) loss1 = nn.CrossEntropyLoss()(scores1, labels) else: loss0, loss1 = None, None return loss0, loss1 def words_loss(img_features, words_emb, labels, cap_lens, class_ids, batch_size): """ words_emb(query): batch x nef x seq_len img_features(context): batch x nef x 17 x 17 """ masks = [] att_maps = [] similarities = [] cap_lens = cap_lens.data.tolist() for i in range(batch_size): if class_ids is not None: mask = (class_ids == class_ids[i]).astype(np.uint8) mask[i] = 0 masks.append(mask.reshape((1, -1))) # Get the i-th text description words_num = cap_lens[i] # -> 1 x nef x words_num word = words_emb[i, :, :words_num].unsqueeze(0).contiguous() # -> batch_size x nef x words_num word = word.repeat(batch_size, 1, 1) # batch x nef x 17*17 context = img_features """ word(query): batch x nef x words_num context: batch x nef x 17 x 17 weiContext: batch x nef x words_num attn: batch x words_num x 17 x 17 """ weiContext, attn = func_attention(word, context, cfg.TRAIN.SMOOTH.GAMMA1) att_maps.append(attn[i].unsqueeze(0).contiguous()) # --> batch_size x words_num x nef word = word.transpose(1, 2).contiguous() weiContext = weiContext.transpose(1, 2).contiguous() # --> batch_size*words_num x nef word = word.view(batch_size * words_num, -1) weiContext = weiContext.view(batch_size * words_num, -1) # # -->batch_size*words_num row_sim = cosine_similarity(word, weiContext) # --> batch_size x words_num row_sim = row_sim.view(batch_size, words_num) # Eq. (10) row_sim.mul_(cfg.TRAIN.SMOOTH.GAMMA2).exp_() row_sim = row_sim.sum(dim=1, keepdim=True) row_sim = torch.log(row_sim) # --> 1 x batch_size # similarities(i, j): the similarity between the i-th image and the j-th text description similarities.append(row_sim) # batch_size x batch_size similarities = torch.cat(similarities, 1) if class_ids is not None: masks = np.concatenate(masks, 0) # masks: batch_size x batch_size masks = torch.ByteTensor(masks) if cfg.CUDA: masks = masks.cuda() similarities = similarities * cfg.TRAIN.SMOOTH.GAMMA3 if class_ids is not None: similarities.data.masked_fill_(masks, -float('inf')) similarities1 = similarities.transpose(0, 1) if labels is not None: loss0 = nn.CrossEntropyLoss()(similarities, labels) loss1 = nn.CrossEntropyLoss()(similarities1, labels) else: loss0, loss1 = None, None return loss0, loss1, att_maps # ##################Loss for G and Ds############################## def discriminator_loss(netD, real_imgs, fake_imgs, conditions, real_labels, fake_labels): # Forward real_features = netD(real_imgs) fake_features = netD(fake_imgs.detach()) # loss # cond_real_logits = netD.COND_DNET(real_features, conditions) cond_real_errD = nn.BCELoss()(cond_real_logits, real_labels) cond_fake_logits = netD.COND_DNET(fake_features, conditions) cond_fake_errD = nn.BCELoss()(cond_fake_logits, fake_labels) # batch_size = real_features.size(0) cond_wrong_logits = netD.COND_DNET(real_features[:(batch_size - 1)], conditions[1:batch_size]) cond_wrong_errD = nn.BCELoss()(cond_wrong_logits, fake_labels[1:batch_size]) if netD.UNCOND_DNET is not None: real_logits = netD.UNCOND_DNET(real_features) fake_logits = netD.UNCOND_DNET(fake_features) real_errD = nn.BCELoss()(real_logits, real_labels) fake_errD = nn.BCELoss()(fake_logits, fake_labels) errD = ((real_errD + cond_real_errD) / 2. + (fake_errD + cond_fake_errD + cond_wrong_errD) / 3.) else: errD = cond_real_errD + (cond_fake_errD + cond_wrong_errD) / 2. return errD def generator_loss(netsD, image_encoder, caption_cnn, caption_rnn, captions, fake_imgs, real_labels, words_embs, sent_emb, match_labels, cap_lens, class_ids): numDs = len(netsD) logs = '' # Forward errG_total = 0 for i in range(numDs): features = netsD[i](fake_imgs[i]) cond_logits = netsD[i].COND_DNET(features, sent_emb) cond_errG = nn.BCELoss()(cond_logits, real_labels) if netsD[i].UNCOND_DNET is not None: logits = netsD[i].UNCOND_DNET(features) errG = nn.BCELoss()(logits, real_labels) g_loss = errG + cond_errG else: g_loss = cond_errG errG_total += g_loss logs += 'g_loss%d: %.2f ' % (i, g_loss.data[0]) if i == (numDs - 1): fakeimg_feature = caption_cnn(fake_imgs[i]) captions.cuda() target_cap = pack_padded_sequence(captions, cap_lens.data.tolist(), batch_first=True)[0].cuda() cap_output = caption_rnn(fakeimg_feature, captions, cap_lens) cap_loss = caption_loss(cap_output, target_cap) * cfg.TRAIN.SMOOTH.LAMBDA1 errG_total += cap_loss logs += 'cap_loss: %.2f, ' % cap_loss return errG_total, logs ################################################################## def KL_loss(mu, logvar): # -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar) KLD = torch.mean(KLD_element).mul_(-0.5) return KLD ================================================ FILE: miscc/utils.py ================================================ import os import errno import numpy as np from torch.nn import init import torch import torch.nn as nn from PIL import Image, ImageDraw, ImageFont from copy import deepcopy import skimage.transform from cfg.config import cfg # For visualization ################################################ COLOR_DIC = {0:[128,64,128], 1:[244, 35,232], 2:[70, 70, 70], 3:[102,102,156], 4:[190,153,153], 5:[153,153,153], 6:[250,170, 30], 7:[220, 220, 0], 8:[107,142, 35], 9:[152,251,152], 10:[70,130,180], 11:[220,20, 60], 12:[255, 0, 0], 13:[0, 0, 142], 14:[119,11, 32], 15:[0, 60,100], 16:[0, 80, 100], 17:[0, 0, 230], 18:[0, 0, 70], 19:[0, 0, 0]} FONT_MAX = 50 def drawCaption(convas, captions, ixtoword, vis_size, off1=2, off2=2): num = captions.size(0) img_txt = Image.fromarray(convas) # get a font # fnt = None # ImageFont.truetype('Pillow/Tests/fonts/FreeMono.ttf', 50) fnt = ImageFont.truetype('Pillow/Tests/fonts/FreeMono.ttf', 50) # get a drawing context d = ImageDraw.Draw(img_txt) sentence_list = [] for i in range(num): cap = captions[i].data.cpu().numpy() sentence = [] for j in range(len(cap)): if cap[j] == 0: break word = ixtoword[cap[j]].encode('ascii', 'ignore').decode('ascii') d.text(((j + off1) * (vis_size + off2), i * FONT_MAX), '%d:%s' % (j, word[:6]), font=fnt, fill=(255, 255, 255, 255)) sentence.append(word) sentence_list.append(sentence) return img_txt, sentence_list def build_super_images(real_imgs, captions, ixtoword, attn_maps, att_sze, lr_imgs=None, batch_size=cfg.TRAIN.BATCH_SIZE, max_word_num=cfg.TEXT.WORDS_NUM): nvis = 8 real_imgs = real_imgs[:nvis] if lr_imgs is not None: lr_imgs = lr_imgs[:nvis] if att_sze == 17: vis_size = att_sze * 16 else: vis_size = real_imgs.size(2) text_convas = \ np.ones([batch_size * FONT_MAX, (max_word_num + 2) * (vis_size + 2), 3], dtype=np.uint8) for i in range(max_word_num): istart = (i + 2) * (vis_size + 2) iend = (i + 3) * (vis_size + 2) text_convas[:, istart:iend, :] = COLOR_DIC[i] real_imgs = \ nn.Upsample(size=(vis_size, vis_size), mode='bilinear')(real_imgs) # [-1, 1] --> [0, 1] real_imgs.add_(1).div_(2).mul_(255) real_imgs = real_imgs.data.numpy() # b x c x h x w --> b x h x w x c real_imgs = np.transpose(real_imgs, (0, 2, 3, 1)) pad_sze = real_imgs.shape middle_pad = np.zeros([pad_sze[2], 2, 3]) post_pad = np.zeros([pad_sze[1], pad_sze[2], 3]) if lr_imgs is not None: lr_imgs = \ nn.Upsample(size=(vis_size, vis_size), mode='bilinear')(lr_imgs) # [-1, 1] --> [0, 1] lr_imgs.add_(1).div_(2).mul_(255) lr_imgs = lr_imgs.data.numpy() # b x c x h x w --> b x h x w x c lr_imgs = np.transpose(lr_imgs, (0, 2, 3, 1)) # batch x seq_len x 17 x 17 --> batch x 1 x 17 x 17 seq_len = max_word_num img_set = [] num = nvis # len(attn_maps) text_map, sentences = \ drawCaption(text_convas, captions, ixtoword, vis_size) text_map = np.asarray(text_map).astype(np.uint8) bUpdate = 1 for i in range(num): attn = attn_maps[i].cpu().view(1, -1, att_sze, att_sze) # --> 1 x 1 x 17 x 17 attn_max = attn.max(dim=1, keepdim=True) attn = torch.cat([attn_max[0], attn], 1) # attn = attn.view(-1, 1, att_sze, att_sze) attn = attn.repeat(1, 3, 1, 1).data.numpy() # n x c x h x w --> n x h x w x c attn = np.transpose(attn, (0, 2, 3, 1)) num_attn = attn.shape[0] # img = real_imgs[i] if lr_imgs is None: lrI = img else: lrI = lr_imgs[i] row = [lrI, middle_pad] row_merge = [img, middle_pad] row_beforeNorm = [] minVglobal, maxVglobal = 1, 0 for j in range(num_attn): one_map = attn[j] if (vis_size // att_sze) > 1: one_map = \ skimage.transform.pyramid_expand(one_map, sigma=20, upscale=vis_size // att_sze) row_beforeNorm.append(one_map) minV = one_map.min() maxV = one_map.max() if minVglobal > minV: minVglobal = minV if maxVglobal < maxV: maxVglobal = maxV for j in range(seq_len + 1): if j < num_attn: one_map = row_beforeNorm[j] one_map = (one_map - minVglobal) / (maxVglobal - minVglobal) one_map *= 255 # PIL_im = Image.fromarray(np.uint8(img)) PIL_att = Image.fromarray(np.uint8(one_map)) merged = \ Image.new('RGBA', (vis_size, vis_size), (0, 0, 0, 0)) mask = Image.new('L', (vis_size, vis_size), (210)) merged.paste(PIL_im, (0, 0)) merged.paste(PIL_att, (0, 0), mask) merged = np.array(merged)[:, :, :3] else: one_map = post_pad merged = post_pad row.append(one_map) row.append(middle_pad) # row_merge.append(merged) row_merge.append(middle_pad) row = np.concatenate(row, 1) row_merge = np.concatenate(row_merge, 1) txt = text_map[i * FONT_MAX: (i + 1) * FONT_MAX] if txt.shape[1] != row.shape[1]: print('txt', txt.shape, 'row', row.shape) bUpdate = 0 break row = np.concatenate([txt, row, row_merge], 0) img_set.append(row) if bUpdate: img_set = np.concatenate(img_set, 0) img_set = img_set.astype(np.uint8) return img_set, sentences else: return None def build_super_images2(real_imgs, captions, cap_lens, ixtoword, attn_maps, att_sze, vis_size=256, topK=5): batch_size = real_imgs.size(0) max_word_num = np.max(cap_lens) text_convas = np.ones([batch_size * FONT_MAX, max_word_num * (vis_size + 2), 3], dtype=np.uint8) real_imgs = \ nn.Upsample(size=(vis_size, vis_size), mode='bilinear')(real_imgs) # [-1, 1] --> [0, 1] real_imgs.add_(1).div_(2).mul_(255) real_imgs = real_imgs.data.numpy() # b x c x h x w --> b x h x w x c real_imgs = np.transpose(real_imgs, (0, 2, 3, 1)) pad_sze = real_imgs.shape middle_pad = np.zeros([pad_sze[2], 2, 3]) # batch x seq_len x 17 x 17 --> batch x 1 x 17 x 17 img_set = [] num = len(attn_maps) text_map, sentences = \ drawCaption(text_convas, captions, ixtoword, vis_size, off1=0) text_map = np.asarray(text_map).astype(np.uint8) bUpdate = 1 for i in range(num): attn = attn_maps[i].cpu().view(1, -1, att_sze, att_sze) # attn = attn.view(-1, 1, att_sze, att_sze) attn = attn.repeat(1, 3, 1, 1).data.numpy() # n x c x h x w --> n x h x w x c attn = np.transpose(attn, (0, 2, 3, 1)) num_attn = cap_lens[i] thresh = 2./float(num_attn) # img = real_imgs[i] row = [] row_merge = [] row_txt = [] row_beforeNorm = [] conf_score = [] for j in range(num_attn): one_map = attn[j] mask0 = one_map > (2. * thresh) conf_score.append(np.sum(one_map * mask0)) mask = one_map > thresh one_map = one_map * mask if (vis_size // att_sze) > 1: one_map = \ skimage.transform.pyramid_expand(one_map, sigma=20, upscale=vis_size // att_sze) minV = one_map.min() maxV = one_map.max() one_map = (one_map - minV) / (maxV - minV) row_beforeNorm.append(one_map) sorted_indices = np.argsort(conf_score)[::-1] for j in range(num_attn): one_map = row_beforeNorm[j] one_map *= 255 # PIL_im = Image.fromarray(np.uint8(img)) PIL_att = Image.fromarray(np.uint8(one_map)) merged = \ Image.new('RGBA', (vis_size, vis_size), (0, 0, 0, 0)) mask = Image.new('L', (vis_size, vis_size), (180)) # (210) merged.paste(PIL_im, (0, 0)) merged.paste(PIL_att, (0, 0), mask) merged = np.array(merged)[:, :, :3] row.append(np.concatenate([one_map, middle_pad], 1)) # row_merge.append(np.concatenate([merged, middle_pad], 1)) # txt = text_map[i * FONT_MAX:(i + 1) * FONT_MAX, j * (vis_size + 2):(j + 1) * (vis_size + 2), :] row_txt.append(txt) # reorder row_new = [] row_merge_new = [] txt_new = [] for j in range(num_attn): idx = sorted_indices[j] row_new.append(row[idx]) row_merge_new.append(row_merge[idx]) txt_new.append(row_txt[idx]) row = np.concatenate(row_new[:topK], 1) row_merge = np.concatenate(row_merge_new[:topK], 1) txt = np.concatenate(txt_new[:topK], 1) if txt.shape[1] != row.shape[1]: print('Warnings: txt', txt.shape, 'row', row.shape, 'row_merge_new', row_merge_new.shape) bUpdate = 0 break row = np.concatenate([txt, row_merge], 0) img_set.append(row) if bUpdate: img_set = np.concatenate(img_set, 0) img_set = img_set.astype(np.uint8) return img_set, sentences else: return None def weights_init(m): classname = m.__class__.__name__ if classname.find('Conv') != -1: nn.init.orthogonal(m.weight.data, 1.0) elif classname.find('BatchNorm') != -1: m.weight.data.normal_(1.0, 0.02) m.bias.data.fill_(0) elif classname.find('Linear') != -1: nn.init.orthogonal(m.weight.data, 1.0) if m.bias is not None: m.bias.data.fill_(0.0) def load_params(model, new_param): for p, new_p in zip(model.parameters(), new_param): p.data.copy_(new_p) def copy_G_params(model): flatten = deepcopy(list(p.data for p in model.parameters())) return flatten def mkdir_p(path): try: os.makedirs(path) except OSError as exc: # Python >2.5 if exc.errno == errno.EEXIST and os.path.isdir(path): pass else: raise ================================================ FILE: model.py ================================================ import torch import torch.nn as nn import torch.nn.parallel from torch.autograd import Variable from torchvision import models import torch.utils.model_zoo as model_zoo import torch.nn.functional as F from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence from cfg.config import cfg from GLAttention import GLAttentionGeneral as ATT_NET class GLU(nn.Module): def __init__(self): super(GLU, self).__init__() def forward(self, x): nc = x.size(1) assert nc % 2 == 0, 'channels dont divide 2!' nc = int(nc/2) return x[:, :nc] * F.sigmoid(x[:, nc:]) def conv1x1(in_planes, out_planes, bias=False): "1x1 convolution with padding" return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=bias) def conv3x3(in_planes, out_planes): "3x3 convolution with padding" return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=False) # Upsale the spatial size by a factor of 2 def upBlock(in_planes, out_planes): block = nn.Sequential( nn.Upsample(scale_factor=2, mode='nearest'), conv3x3(in_planes, out_planes * 2), nn.BatchNorm2d(out_planes * 2), GLU()) return block # Keep the spatial size def Block3x3_relu(in_planes, out_planes): block = nn.Sequential( conv3x3(in_planes, out_planes * 2), nn.BatchNorm2d(out_planes * 2), GLU()) return block class ResBlock(nn.Module): def __init__(self, channel_num): super(ResBlock, self).__init__() self.block = nn.Sequential( conv3x3(channel_num, channel_num * 2), nn.BatchNorm2d(channel_num * 2), GLU(), conv3x3(channel_num, channel_num), nn.BatchNorm2d(channel_num)) def forward(self, x): residual = x out = self.block(x) out += residual return out # ############## Text2Image Encoder-Decoder ####### class RNN_ENCODER(nn.Module): def __init__(self, ntoken, ninput=300, drop_prob=0.5, nhidden=128, nlayers=1, bidirectional=True): super(RNN_ENCODER, self).__init__() self.n_steps = cfg.TEXT.WORDS_NUM self.ntoken = ntoken # size of the dictionary self.ninput = ninput # size of each embedding vector self.drop_prob = drop_prob # probability of an element to be zeroed self.nlayers = nlayers # Number of recurrent layers self.bidirectional = bidirectional self.rnn_type = cfg.RNN_TYPE if bidirectional: self.num_directions = 2 else: self.num_directions = 1 # number of features in the hidden state self.nhidden = nhidden // self.num_directions self.define_module() self.init_weights() def define_module(self): self.encoder = nn.Embedding(self.ntoken, self.ninput) self.drop = nn.Dropout(self.drop_prob) if self.rnn_type == 'LSTM': # dropout: If non-zero, introduces a dropout layer on # the outputs of each RNN layer except the last layer self.rnn = nn.LSTM(self.ninput, self.nhidden, self.nlayers, batch_first=True, dropout=self.drop_prob, bidirectional=self.bidirectional) elif self.rnn_type == 'GRU': self.rnn = nn.GRU(self.ninput, self.nhidden, self.nlayers, batch_first=True, dropout=self.drop_prob, bidirectional=self.bidirectional) else: raise NotImplementedError def init_weights(self): initrange = 0.1 self.encoder.weight.data.uniform_(-initrange, initrange) # Do not need to initialize RNN parameters, which have been initialized # http://pytorch.org/docs/master/_modules/torch/nn/modules/rnn.html#LSTM # self.decoder.weight.data.uniform_(-initrange, initrange) # self.decoder.bias.data.fill_(0) def init_hidden(self, bsz): weight = next(self.parameters()).data if self.rnn_type == 'LSTM': return (Variable(weight.new(self.nlayers * self.num_directions, bsz, self.nhidden).zero_()), Variable(weight.new(self.nlayers * self.num_directions, bsz, self.nhidden).zero_())) else: return Variable(weight.new(self.nlayers * self.num_directions, bsz, self.nhidden).zero_()) def forward(self, captions, cap_lens, hidden, mask=None): # input: torch.LongTensor of size batch x n_steps # --> emb: batch x n_steps x ninput emb = self.drop(self.encoder(captions)) # # Returns: a PackedSequence object cap_lens = cap_lens.data.tolist() emb = pack_padded_sequence(emb, cap_lens, batch_first=True) # #hidden and memory (num_layers * num_directions, batch, hidden_size): # tensor containing the initial hidden state for each element in batch. # #output (batch, seq_len, hidden_size * num_directions) # #or a PackedSequence object: # tensor containing output features (h_t) from the last layer of RNN output, hidden = self.rnn(emb, hidden) # PackedSequence object # --> (batch, seq_len, hidden_size * num_directions) output = pad_packed_sequence(output, batch_first=True)[0] # output = self.drop(output) # --> batch x hidden_size*num_directions x seq_len words_emb = output.transpose(1, 2) # --> batch x num_directions*hidden_size if self.rnn_type == 'LSTM': sent_emb = hidden[0].transpose(0, 1).contiguous() else: sent_emb = hidden.transpose(0, 1).contiguous() sent_emb = sent_emb.view(-1, self.nhidden * self.num_directions) return words_emb, sent_emb class CNN_ENCODER(nn.Module): def __init__(self, nef): super(CNN_ENCODER, self).__init__() if cfg.TRAIN.FLAG: self.nef = nef else: self.nef = 256 # define a uniform ranker model = models.inception_v3() url = 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth' model.load_state_dict(model_zoo.load_url(url)) for param in model.parameters(): param.requires_grad = False print('Load pretrained model from ', url) # print(model) self.define_module(model) self.init_trainable_weights() def define_module(self, model): self.Conv2d_1a_3x3 = model.Conv2d_1a_3x3 self.Conv2d_2a_3x3 = model.Conv2d_2a_3x3 self.Conv2d_2b_3x3 = model.Conv2d_2b_3x3 self.Conv2d_3b_1x1 = model.Conv2d_3b_1x1 self.Conv2d_4a_3x3 = model.Conv2d_4a_3x3 self.Mixed_5b = model.Mixed_5b self.Mixed_5c = model.Mixed_5c self.Mixed_5d = model.Mixed_5d self.Mixed_6a = model.Mixed_6a self.Mixed_6b = model.Mixed_6b self.Mixed_6c = model.Mixed_6c self.Mixed_6d = model.Mixed_6d self.Mixed_6e = model.Mixed_6e self.Mixed_7a = model.Mixed_7a self.Mixed_7b = model.Mixed_7b self.Mixed_7c = model.Mixed_7c self.emb_features = conv1x1(768, self.nef) self.emb_cnn_code = nn.Linear(2048, self.nef) def init_trainable_weights(self): initrange = 0.1 self.emb_features.weight.data.uniform_(-initrange, initrange) self.emb_cnn_code.weight.data.uniform_(-initrange, initrange) def forward(self, x): features = None # --> fixed-size input: batch x 3 x 299 x 299 x = nn.Upsample(size=(299, 299), mode='bilinear')(x) # 299 x 299 x 3 x = self.Conv2d_1a_3x3(x) # 149 x 149 x 32 x = self.Conv2d_2a_3x3(x) # 147 x 147 x 32 x = self.Conv2d_2b_3x3(x) # 147 x 147 x 64 x = F.max_pool2d(x, kernel_size=3, stride=2) # 73 x 73 x 64 x = self.Conv2d_3b_1x1(x) # 73 x 73 x 80 x = self.Conv2d_4a_3x3(x) # 71 x 71 x 192 x = F.max_pool2d(x, kernel_size=3, stride=2) # 35 x 35 x 192 x = self.Mixed_5b(x) # 35 x 35 x 256 x = self.Mixed_5c(x) # 35 x 35 x 288 x = self.Mixed_5d(x) # 35 x 35 x 288 x = self.Mixed_6a(x) # 17 x 17 x 768 x = self.Mixed_6b(x) # 17 x 17 x 768 x = self.Mixed_6c(x) # 17 x 17 x 768 x = self.Mixed_6d(x) # 17 x 17 x 768 x = self.Mixed_6e(x) # 17 x 17 x 768 # image region features features = x # 17 x 17 x 768 x = self.Mixed_7a(x) # 8 x 8 x 1280 x = self.Mixed_7b(x) # 8 x 8 x 2048 x = self.Mixed_7c(x) # 8 x 8 x 2048 x = F.avg_pool2d(x, kernel_size=8) # 1 x 1 x 2048 # x = F.dropout(x, training=self.training) # 1 x 1 x 2048 x = x.view(x.size(0), -1) # 2048 # global image features cnn_code = self.emb_cnn_code(x) # 512 if features is not None: features = self.emb_features(features) return features, cnn_code # ############## G networks ################### class CA_NET(nn.Module): # some code is modified from vae examples # (https://github.com/pytorch/examples/blob/master/vae/main.py) def __init__(self): super(CA_NET, self).__init__() self.t_dim = cfg.TEXT.EMBEDDING_DIM self.c_dim = cfg.GAN.CONDITION_DIM self.fc = nn.Linear(self.t_dim, self.c_dim * 4, bias=True) self.relu = GLU() def encode(self, text_embedding): x = self.relu(self.fc(text_embedding)) mu = x[:, :self.c_dim] logvar = x[:, self.c_dim:] return mu, logvar def reparametrize(self, mu, logvar): std = logvar.mul(0.5).exp_() if cfg.CUDA: eps = torch.cuda.FloatTensor(std.size()).normal_() else: eps = torch.FloatTensor(std.size()).normal_() eps = Variable(eps) return eps.mul(std).add_(mu) def forward(self, text_embedding): mu, logvar = self.encode(text_embedding) c_code = self.reparametrize(mu, logvar) return c_code, mu, logvar class INIT_STAGE_G(nn.Module): def __init__(self, ngf, ncf): super(INIT_STAGE_G, self).__init__() self.gf_dim = ngf self.in_dim = cfg.GAN.Z_DIM + ncf # cfg.TEXT.EMBEDDING_DIM self.define_module() def define_module(self): nz, ngf = self.in_dim, self.gf_dim self.fc = nn.Sequential( nn.Linear(nz, ngf * 4 * 4 * 2, bias=False), nn.BatchNorm1d(ngf * 4 * 4 * 2), GLU()) self.upsample1 = upBlock(ngf, ngf // 2) self.upsample2 = upBlock(ngf // 2, ngf // 4) self.upsample3 = upBlock(ngf // 4, ngf // 8) self.upsample4 = upBlock(ngf // 8, ngf // 16) def forward(self, z_code, c_code): """ :param z_code: batch x cfg.GAN.Z_DIM :param c_code: batch x cfg.TEXT.EMBEDDING_DIM :return: batch x ngf/16 x 64 x 64 """ c_z_code = torch.cat((c_code, z_code), 1) # state size ngf x 4 x 4 out_code = self.fc(c_z_code) out_code = out_code.view(-1, self.gf_dim, 4, 4) # state size ngf/3 x 8 x 8 out_code = self.upsample1(out_code) # state size ngf/4 x 16 x 16 out_code = self.upsample2(out_code) # state size ngf/8 x 32 x 32 out_code32 = self.upsample3(out_code) # state size ngf/16 x 64 x 64 out_code64 = self.upsample4(out_code32) return out_code64 # class NEXT_STAGE_G(nn.Module): # def __init__(self, ngf, nef, ncf): # super(NEXT_STAGE_G, self).__init__() # self.gf_dim = ngf # self.ef_dim = nef # self.cf_dim = ncf # self.num_residual = cfg.GAN.R_NUM # self.define_module() # # def _make_layer(self, block, channel_num): # layers = [] # for i in range(cfg.GAN.R_NUM): # layers.append(block(channel_num)) # return nn.Sequential(*layers) # # def define_module(self): # ngf = self.gf_dim # self.att = ATT_NET(ngf, self.ef_dim) # self.residual = self._make_layer(ResBlock, ngf * 2) # self.upsample = upBlock(ngf * 2, ngf) # # def forward(self, h_code, c_code, word_embs, mask): # """ # h_code1(query): batch x idf x ih x iw (queryL=ihxiw) # word_embs(context): batch x cdf x sourceL (sourceL=seq_len) # c_code1: batch x idf x queryL # att1: batch x sourceL x queryL # """ # self.att.applyMask(mask) # c_code, att = self.att(h_code, word_embs) # h_c_code = torch.cat((h_code, c_code), 1) # print('h_c_code:', h_c_code.size()) \ # ('h_c_code:', (16, 64, 64, 64)) # ('h_c_code:', (16, 64, 128, 128)) # out_code = self.residual(h_c_code) # # # state size ngf/2 x 2in_size x 2in_size # out_code = self.upsample(out_code) # # return out_code, att class NEXT_STAGE_G(nn.Module): def __init__(self, ngf, nef, ncf): super(NEXT_STAGE_G, self).__init__() self.gf_dim = ngf self.ef_dim = nef self.cf_dim = ncf # print(ngf, nef, ncf) (32, 256, 100) # (32, 256, 100) self.num_residual = cfg.GAN.R_NUM self.define_module() self.conv = conv1x1(ngf * 3, ngf * 2) def _make_layer(self, block, channel_num): layers = [] for i in range(cfg.GAN.R_NUM): # 2 layers.append(block(channel_num)) return nn.Sequential(*layers) def define_module(self): ngf = self.gf_dim self.att = ATT_NET(ngf, self.ef_dim) self.residual = self._make_layer(ResBlock, ngf * 2) self.upsample = upBlock(ngf * 2, ngf) def forward(self, h_code, c_code, word_embs, mask): """ h_code1(query): batch x idf x ih x iw (queryL=ihxiw) word_embs(context): batch x cdf x sourceL (sourceL=seq_len) c_code1: batch x idf x queryL att1: batch x sourceL x queryL """ # print('========') # ((16, 32, 64, 64), (16, 100), (16, 256, 18), (16, 18)) # print(h_code.size(), c_code.size(), word_embs.size(), mask.size()) self.att.applyMask(mask) # here, a new c_code is generated by self.att() method. # weightedContext, weightedSentence, word_attn, sent_vs_att c_code, weightedSentence, att, sent_att = self.att(h_code, c_code, word_embs) # Then, image feature are concated with a new c_code, they become h_c_code, # so, here I can make some change, to concate more items together. # which means I need to get more output from line 369, self.att() # also, I need to feed more information to calculate the function, and let's see what the new idea will return. h_c_code = torch.cat((h_code, c_code), 1) # print('h_c_code.size:', h_c_code.size()) # ('h_c_code.size:', (16, 64, 64, 64)) h_c_sent_code = torch.cat((h_c_code, weightedSentence), 1) # print('h_c_sent_code.size:', h_c_sent_code.size()) # ('h_c_code.size:', (16, 64, 64, 64)) # ('h_c_sent_code.size:', (16, 96, 64, 64)) h_c_sent_code = self.conv(h_c_sent_code) out_code = self.residual(h_c_sent_code) # print('out_code:', out_code.size()) # state size ngf/2 x 2in_size x 2in_size out_code = self.upsample(out_code) return out_code, att class GET_IMAGE_G(nn.Module): def __init__(self, ngf): super(GET_IMAGE_G, self).__init__() self.gf_dim = ngf self.img = nn.Sequential( conv3x3(ngf, 3), nn.Tanh() ) def forward(self, h_code): out_img = self.img(h_code) return out_img #G_NET used in the paper class G_NET(nn.Module): def __init__(self): super(G_NET, self).__init__() ngf = cfg.GAN.GF_DIM nef = cfg.TEXT.EMBEDDING_DIM ncf = cfg.GAN.CONDITION_DIM self.ca_net = CA_NET() if cfg.TREE.BRANCH_NUM > 0: self.h_net1 = INIT_STAGE_G(ngf * 16, ncf) self.img_net1 = GET_IMAGE_G(ngf) # gf x 64 x 64 if cfg.TREE.BRANCH_NUM > 1: self.h_net2 = NEXT_STAGE_G(ngf, nef, ncf) self.img_net2 = GET_IMAGE_G(ngf) if cfg.TREE.BRANCH_NUM > 2: self.h_net3 = NEXT_STAGE_G(ngf, nef, ncf) self.img_net3 = GET_IMAGE_G(ngf) # netG(noise, sent_emb, words_embs, mask) def forward(self, z_code, sent_emb, word_embs, mask): """ :param z_code: batch x cfg.GAN.Z_DIM :param sent_emb: batch x cfg.TEXT.EMBEDDING_DIM :param word_embs: batch x cdf x seq_len :param mask: batch x seq_len :return: """ fake_imgs = [] att_maps = [] '''this is the Conditioning Augmentation''' # print('sent_emb:', sent_emb.size()) #('sent_emb:', (16, 256)) c_code, mu, logvar = self.ca_net(sent_emb) # print('=====') # print('first c_code.size():', c_code.size()) #(16, 100) # print('=====') if cfg.TREE.BRANCH_NUM > 0: h_code1 = self.h_net1(z_code, c_code) fake_img1 = self.img_net1(h_code1) fake_imgs.append(fake_img1) if cfg.TREE.BRANCH_NUM > 1: h_code2, att1 = \ self.h_net2(h_code1, c_code, word_embs, mask) fake_img2 = self.img_net2(h_code2) fake_imgs.append(fake_img2) if att1 is not None: att_maps.append(att1) if cfg.TREE.BRANCH_NUM > 2: h_code3, att2 = \ self.h_net3(h_code2, c_code, word_embs, mask) fake_img3 = self.img_net3(h_code3) fake_imgs.append(fake_img3) if att2 is not None: att_maps.append(att2) return fake_imgs, att_maps, mu, logvar class G_DCGAN(nn.Module): def __init__(self): super(G_DCGAN, self).__init__() ngf = cfg.GAN.GF_DIM nef = cfg.TEXT.EMBEDDING_DIM ncf = cfg.GAN.CONDITION_DIM self.ca_net = CA_NET() # 16gf x 64 x 64 --> gf x 64 x 64 --> 3 x 64 x 64 if cfg.TREE.BRANCH_NUM > 0: self.h_net1 = INIT_STAGE_G(ngf * 16, ncf) # gf x 64 x 64 if cfg.TREE.BRANCH_NUM > 1: self.h_net2 = NEXT_STAGE_G(ngf, nef, ncf) if cfg.TREE.BRANCH_NUM > 2: self.h_net3 = NEXT_STAGE_G(ngf, nef, ncf) self.img_net = GET_IMAGE_G(ngf) def forward(self, z_code, sent_emb, word_embs, mask): """ :param z_code: batch x cfg.GAN.Z_DIM :param sent_emb: batch x cfg.TEXT.EMBEDDING_DIM :param word_embs: batch x cdf x seq_len :param mask: batch x seq_len :return: """ att_maps = [] c_code, mu, logvar = self.ca_net(sent_emb) if cfg.TREE.BRANCH_NUM > 0: h_code = self.h_net1(z_code, c_code) if cfg.TREE.BRANCH_NUM > 1: h_code, att1 = self.h_net2(h_code, c_code, word_embs, mask) if att1 is not None: att_maps.append(att1) if cfg.TREE.BRANCH_NUM > 2: h_code, att2 = self.h_net3(h_code, c_code, word_embs, mask) if att2 is not None: att_maps.append(att2) fake_imgs = self.img_net(h_code) return [fake_imgs], att_maps, mu, logvar # ############## D networks ########################## def Block3x3_leakRelu(in_planes, out_planes): block = nn.Sequential( conv3x3(in_planes, out_planes), nn.BatchNorm2d(out_planes), nn.LeakyReLU(0.2, inplace=True) ) return block # Downsale the spatial size by a factor of 2 def downBlock(in_planes, out_planes): block = nn.Sequential( nn.Conv2d(in_planes, out_planes, 4, 2, 1, bias=False), nn.BatchNorm2d(out_planes), nn.LeakyReLU(0.2, inplace=True) ) return block # Downsale the spatial size by a factor of 16 def encode_image_by_16times(ndf): encode_img = nn.Sequential( # --> state size. ndf x in_size/2 x in_size/2 nn.Conv2d(3, ndf, 4, 2, 1, bias=False), nn.LeakyReLU(0.2, inplace=True), # --> state size 2ndf x x in_size/4 x in_size/4 nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False), nn.BatchNorm2d(ndf * 2), nn.LeakyReLU(0.2, inplace=True), # --> state size 4ndf x in_size/8 x in_size/8 nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False), nn.BatchNorm2d(ndf * 4), nn.LeakyReLU(0.2, inplace=True), # --> state size 8ndf x in_size/16 x in_size/16 nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False), nn.BatchNorm2d(ndf * 8), nn.LeakyReLU(0.2, inplace=True) ) return encode_img class D_GET_LOGITS(nn.Module): def __init__(self, ndf, nef, bcondition=False): super(D_GET_LOGITS, self).__init__() self.df_dim = ndf self.ef_dim = nef self.bcondition = bcondition if self.bcondition: self.jointConv = Block3x3_leakRelu(ndf * 8 + nef, ndf * 8) self.outlogits = nn.Sequential( nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4), nn.Sigmoid()) def forward(self, h_code, c_code=None): if self.bcondition and c_code is not None: # conditioning output c_code = c_code.view(-1, self.ef_dim, 1, 1) c_code = c_code.repeat(1, 1, 4, 4) # state size (ngf+egf) x 4 x 4 h_c_code = torch.cat((h_code, c_code), 1) # state size ngf x in_size x in_size h_c_code = self.jointConv(h_c_code) else: h_c_code = h_code output = self.outlogits(h_c_code) return output.view(-1) # For 64 x 64 images class D_NET64(nn.Module): def __init__(self, b_jcu=True): super(D_NET64, self).__init__() ndf = cfg.GAN.DF_DIM nef = cfg.TEXT.EMBEDDING_DIM self.img_code_s16 = encode_image_by_16times(ndf) if b_jcu: self.UNCOND_DNET = D_GET_LOGITS(ndf, nef, bcondition=False) else: self.UNCOND_DNET = None self.COND_DNET = D_GET_LOGITS(ndf, nef, bcondition=True) def forward(self, x_var): x_code4 = self.img_code_s16(x_var) # 4 x 4 x 8df return x_code4 # For 128 x 128 images class D_NET128(nn.Module): def __init__(self, b_jcu=True): super(D_NET128, self).__init__() ndf = cfg.GAN.DF_DIM nef = cfg.TEXT.EMBEDDING_DIM self.img_code_s16 = encode_image_by_16times(ndf) self.img_code_s32 = downBlock(ndf * 8, ndf * 16) self.img_code_s32_1 = Block3x3_leakRelu(ndf * 16, ndf * 8) # if b_jcu: self.UNCOND_DNET = D_GET_LOGITS(ndf, nef, bcondition=False) else: self.UNCOND_DNET = None self.COND_DNET = D_GET_LOGITS(ndf, nef, bcondition=True) def forward(self, x_var): x_code8 = self.img_code_s16(x_var) # 8 x 8 x 8df x_code4 = self.img_code_s32(x_code8) # 4 x 4 x 16df x_code4 = self.img_code_s32_1(x_code4) # 4 x 4 x 8df return x_code4 # For 256 x 256 images class D_NET256(nn.Module): def __init__(self, b_jcu=True): super(D_NET256, self).__init__() ndf = cfg.GAN.DF_DIM nef = cfg.TEXT.EMBEDDING_DIM self.img_code_s16 = encode_image_by_16times(ndf) self.img_code_s32 = downBlock(ndf * 8, ndf * 16) self.img_code_s64 = downBlock(ndf * 16, ndf * 32) self.img_code_s64_1 = Block3x3_leakRelu(ndf * 32, ndf * 16) self.img_code_s64_2 = Block3x3_leakRelu(ndf * 16, ndf * 8) if b_jcu: self.UNCOND_DNET = D_GET_LOGITS(ndf, nef, bcondition=False) else: self.UNCOND_DNET = None self.COND_DNET = D_GET_LOGITS(ndf, nef, bcondition=True) def forward(self, x_var): x_code16 = self.img_code_s16(x_var) x_code8 = self.img_code_s32(x_code16) x_code4 = self.img_code_s64(x_code8) x_code4 = self.img_code_s64_1(x_code4) x_code4 = self.img_code_s64_2(x_code4) return x_code4 class CAPTION_CNN(nn.Module): def __init__(self, embed_size): """Load the pretrained ResNet-152 and replace top fc layer.""" super(CAPTION_CNN, self).__init__() resnet = models.resnet152(pretrained=True) modules = list(resnet.children())[:-1] # delete the last fc layer. self.resnet = nn.Sequential(*modules) for param in self.resnet.parameters(): param.requires_grad = False self.linear = nn.Linear(resnet.fc.in_features, embed_size) self.bn = nn.BatchNorm1d(embed_size, momentum=0.01) def forward(self, images): """Extract feature vectors from input images.""" #print ('image feature size before unsample:', images.size()) m = nn.Upsample(size=(224, 224), mode='bilinear') unsampled_images = m(images) #print ('image feature size after unsample:', unsampled_images.size()) features = self.resnet(unsampled_images) features = features.view(features.size(0), -1) features = self.bn(self.linear(features)) return features class CAPTION_RNN(nn.Module): def __init__(self, embed_size, hidden_size, vocab_size, num_layers, max_seq_length=20): """Set the hyper-parameters and build the layers.""" super(CAPTION_RNN, self).__init__() self.embed = nn.Embedding(vocab_size, embed_size) self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True) self.linear = nn.Linear(hidden_size, vocab_size) self.max_seg_length = max_seq_length # def forward(self, features, captions, cap_lens): # """Decode image feature vectors and generates captions.""" # # print ('feature.size():', features.size()) #(6L, 256L) # # print ('captions.size():', captions.size()) # (6L, 12L) # # print ('embeddings.size:',embeddings.size()) #(6L, 12L, 256L) # embeddings = self.embed(captions) # embeddings = torch.cat((features.unsqueeze(1), embeddings), 1) # packed = pack_padded_sequence(embeddings, cap_lens.data.tolist(), batch_first=True) # outputs, hidden = self.lstm(packed) # output = self.linear(outputs[0]) # (batch size, vocab_size) # return output, hidden, outputs # words embedding, sentence embedding def forward(self, features, captions, cap_lens): """Decode image feature vectors and generates captions.""" embeddings = self.embed(captions) embeddings = torch.cat((features.unsqueeze(1), embeddings), 1) packed = pack_padded_sequence(embeddings, cap_lens, batch_first=True) hiddens, _ = self.lstm(packed) outputs = self.linear(hiddens[0]) return outputs def sample(self, features, states=None): """Generate captions for given image features using greedy search.""" sampled_ids = [] inputs = features.unsqueeze(1) for i in range(self.max_seg_length): hiddens, states = self.lstm(inputs, states) # hiddens: (batch_size, 1, hidden_size) outputs = self.linear(hiddens.squeeze(1)) # outputs: (batch_size, vocab_size) _, predicted = outputs.max(1) # predicted: (batch_size) sampled_ids.append(predicted) inputs = self.embed(predicted) # inputs: (batch_size, embed_size) inputs = inputs.unsqueeze(1) # inputs: (batch_size, 1, embed_size) sampled_ids = torch.stack(sampled_ids, 1) # sampled_ids: (batch_size, max_seq_length) return sampled_ids ================================================ FILE: pretrain_DAMSM.py ================================================ from __future__ import print_function from miscc.utils import mkdir_p from miscc.utils import build_super_images from miscc.losses import sent_loss, words_loss from cfg.config import cfg, cfg_from_file from datasets import TextDataset from datasets import prepare_data from model import RNN_ENCODER, CNN_ENCODER import os import sys import time import random import pprint import datetime import dateutil.tz import argparse import numpy as np from PIL import Image import torch import torch.nn as nn import torch.optim as optim from torch.autograd import Variable import torch.backends.cudnn as cudnn import torchvision.transforms as transforms dir_path = (os.path.abspath(os.path.join(os.path.realpath(__file__), './.'))) sys.path.append(dir_path) UPDATE_INTERVAL = 200 def parse_args(): parser = argparse.ArgumentParser(description='Train a DAMSM network') parser.add_argument('--cfg', dest='cfg_file', help='optional config file', default='cfg/DAMSM/bird.yml', type=str) parser.add_argument('--gpu', dest='gpu_id', type=int, default=0) parser.add_argument('--data_dir', dest='data_dir', type=str, default='') parser.add_argument('--manualSeed', type=int, help='manual seed') args = parser.parse_args() return args def train(dataloader, cnn_model, rnn_model, batch_size, labels, optimizer, epoch, ixtoword, image_dir): cnn_model.train() rnn_model.train() s_total_loss0 = 0 s_total_loss1 = 0 w_total_loss0 = 0 w_total_loss1 = 0 count = (epoch + 1) * len(dataloader) start_time = time.time() for step, data in enumerate(dataloader, 0): # print('step', step) rnn_model.zero_grad() cnn_model.zero_grad() imgs, captions, cap_lens, \ class_ids, keys = prepare_data(data) # words_features: batch_size x nef x 17 x 17 # sent_code: batch_size x nef words_features, sent_code = cnn_model(imgs[-1]) # --> batch_size x nef x 17*17 nef, att_sze = words_features.size(1), words_features.size(2) # words_features = words_features.view(batch_size, nef, -1) hidden = rnn_model.init_hidden(batch_size) # words_emb: batch_size x nef x seq_len # sent_emb: batch_size x nef words_emb, sent_emb = rnn_model(captions, cap_lens, hidden) w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb, labels, cap_lens, class_ids, batch_size) w_total_loss0 += w_loss0.data w_total_loss1 += w_loss1.data loss = w_loss0 + w_loss1 s_loss0, s_loss1 = \ sent_loss(sent_code, sent_emb, labels, class_ids, batch_size) loss += s_loss0 + s_loss1 s_total_loss0 += s_loss0.data s_total_loss1 += s_loss1.data # loss.backward() # # `clip_grad_norm` helps prevent # the exploding gradient problem in RNNs / LSTMs. torch.nn.utils.clip_grad_norm(rnn_model.parameters(), cfg.TRAIN.RNN_GRAD_CLIP) optimizer.step() if step % UPDATE_INTERVAL == 0: count = epoch * len(dataloader) + step s_cur_loss0 = s_total_loss0[0] / UPDATE_INTERVAL s_cur_loss1 = s_total_loss1[0] / UPDATE_INTERVAL w_cur_loss0 = w_total_loss0[0] / UPDATE_INTERVAL w_cur_loss1 = w_total_loss1[0] / UPDATE_INTERVAL elapsed = time.time() - start_time print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | ' 's_loss {:5.2f} {:5.2f} | ' 'w_loss {:5.2f} {:5.2f}' .format(epoch, step, len(dataloader), elapsed * 1000. / UPDATE_INTERVAL, s_cur_loss0, s_cur_loss1, w_cur_loss0, w_cur_loss1)) s_total_loss0 = 0 s_total_loss1 = 0 w_total_loss0 = 0 w_total_loss1 = 0 start_time = time.time() # attention Maps img_set, _ = \ build_super_images(imgs[-1].cpu(), captions, ixtoword, attn_maps, att_sze) if img_set is not None: im = Image.fromarray(img_set) fullpath = '%s/attention_maps%d.png' % (image_dir, step) im.save(fullpath) return count def evaluate(dataloader, cnn_model, rnn_model, batch_size): cnn_model.eval() rnn_model.eval() s_total_loss = 0 w_total_loss = 0 for step, data in enumerate(dataloader, 0): real_imgs, captions, cap_lens, \ class_ids, keys = prepare_data(data) words_features, sent_code = cnn_model(real_imgs[-1]) # nef = words_features.size(1) # words_features = words_features.view(batch_size, nef, -1) hidden = rnn_model.init_hidden(batch_size) words_emb, sent_emb = rnn_model(captions, cap_lens, hidden) w_loss0, w_loss1, attn = words_loss(words_features, words_emb, labels, cap_lens, class_ids, batch_size) w_total_loss += (w_loss0 + w_loss1).data s_loss0, s_loss1 = \ sent_loss(sent_code, sent_emb, labels, class_ids, batch_size) s_total_loss += (s_loss0 + s_loss1).data if step == 50: break s_cur_loss = s_total_loss[0] / step w_cur_loss = w_total_loss[0] / step return s_cur_loss, w_cur_loss def build_models(): # build model ############################################################ text_encoder = RNN_ENCODER(dataset.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM) image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM) labels = Variable(torch.LongTensor(range(batch_size))) start_epoch = 0 if cfg.TRAIN.NET_E != '': state_dict = torch.load(cfg.TRAIN.NET_E) text_encoder.load_state_dict(state_dict) print('Load ', cfg.TRAIN.NET_E) # name = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder') state_dict = torch.load(name) image_encoder.load_state_dict(state_dict) print('Load ', name) istart = cfg.TRAIN.NET_E.rfind('_') + 8 iend = cfg.TRAIN.NET_E.rfind('.') start_epoch = cfg.TRAIN.NET_E[istart:iend] start_epoch = int(start_epoch) + 1 print('start_epoch', start_epoch) if cfg.CUDA: text_encoder = text_encoder.cuda() image_encoder = image_encoder.cuda() labels = labels.cuda() return text_encoder, image_encoder, labels, start_epoch if __name__ == "__main__": args = parse_args() if args.cfg_file is not None: cfg_from_file(args.cfg_file) if args.gpu_id == -1: cfg.CUDA = False else: cfg.GPU_ID = args.gpu_id if args.data_dir != '': cfg.DATA_DIR = args.data_dir print('Using config:') pprint.pprint(cfg) if not cfg.TRAIN.FLAG: args.manualSeed = 100 elif args.manualSeed is None: args.manualSeed = random.randint(1, 10000) random.seed(args.manualSeed) np.random.seed(args.manualSeed) torch.manual_seed(args.manualSeed) if cfg.CUDA: torch.cuda.manual_seed_all(args.manualSeed) ########################################################################## now = datetime.datetime.now(dateutil.tz.tzlocal()) timestamp = now.strftime('%Y_%m_%d_%H_%M_%S') output_dir = '../output/%s_%s_%s' % \ (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp) model_dir = os.path.join(output_dir, 'Model') image_dir = os.path.join(output_dir, 'Image') mkdir_p(model_dir) mkdir_p(image_dir) torch.cuda.set_device(cfg.GPU_ID) cudnn.benchmark = True # Get data loader ################################################## imsize = cfg.TREE.BASE_SIZE * (2 ** (cfg.TREE.BRANCH_NUM-1)) batch_size = cfg.TRAIN.BATCH_SIZE image_transform = transforms.Compose([ transforms.Scale(int(imsize * 76 / 64)), transforms.RandomCrop(imsize), transforms.RandomHorizontalFlip()]) dataset = TextDataset(cfg.DATA_DIR, 'train', base_size=cfg.TREE.BASE_SIZE, transform=image_transform) print(dataset.n_words, dataset.embeddings_num) assert dataset dataloader = torch.utils.data.DataLoader( dataset, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=int(cfg.WORKERS)) # # validation data # dataset_val = TextDataset(cfg.DATA_DIR, 'test', base_size=cfg.TREE.BASE_SIZE, transform=image_transform) dataloader_val = torch.utils.data.DataLoader( dataset_val, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=int(cfg.WORKERS)) # Train ############################################################## text_encoder, image_encoder, labels, start_epoch = build_models() para = list(text_encoder.parameters()) for v in image_encoder.parameters(): if v.requires_grad: para.append(v) # optimizer = optim.Adam(para, lr=cfg.TRAIN.ENCODER_LR, betas=(0.5, 0.999)) # At any point you can hit Ctrl + C to break out of training early. try: lr = cfg.TRAIN.ENCODER_LR for epoch in range(start_epoch, cfg.TRAIN.MAX_EPOCH): optimizer = optim.Adam(para, lr=lr, betas=(0.5, 0.999)) epoch_start_time = time.time() count = train(dataloader, image_encoder, text_encoder, batch_size, labels, optimizer, epoch, dataset.ixtoword, image_dir) print('-' * 89) if len(dataloader_val) > 0: s_loss, w_loss = evaluate(dataloader_val, image_encoder, text_encoder, batch_size) print('| end epoch {:3d} | valid loss ' '{:5.2f} {:5.2f} | lr {:.5f}|' .format(epoch, s_loss, w_loss, lr)) print('-' * 89) if lr > cfg.TRAIN.ENCODER_LR/10.: lr *= 0.98 if (epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0 or epoch == cfg.TRAIN.MAX_EPOCH): torch.save(image_encoder.state_dict(), '%s/image_encoder%d.pth' % (model_dir, epoch)) torch.save(text_encoder.state_dict(), '%s/text_encoder%d.pth' % (model_dir, epoch)) print('Save G/Ds models.') except KeyboardInterrupt: print('-' * 89) print('Exiting from training early') ================================================ FILE: test.py ================================================ import torch.nn as nn import torch from torch.autograd import Variable def conv1x1(in_planes, out_planes): "1x1 convolution with padding" return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) x = Variable(torch.rand(2,3,1,1)) print(x.size()) y = conv1x1(3, 3)(x) print(y.size()) # z = torch.cat(x, x) # print(z.size()) t = torch.mul(x, x) print(t.size()) ================================================ FILE: trainer.py ================================================ from __future__ import print_function from six.moves import range import torch import torch.optim as optim from torch.autograd import Variable import torch.backends.cudnn as cudnn from PIL import Image from cfg.config import cfg from miscc.utils import mkdir_p from miscc.utils import build_super_images, build_super_images2 from miscc.utils import weights_init, load_params, copy_G_params from model import G_DCGAN, G_NET from datasets import prepare_data from model import RNN_ENCODER, CNN_ENCODER, CAPTION_CNN, CAPTION_RNN from miscc.losses import words_loss from miscc.losses import discriminator_loss, generator_loss, KL_loss import os import time import numpy as np # MirrorGAN class Trainer(object): def __init__(self, output_dir, data_loader, n_words, ixtoword): if cfg.TRAIN.FLAG: self.model_dir = os.path.join(output_dir, 'Model') self.image_dir = os.path.join(output_dir, 'Image') mkdir_p(self.model_dir) mkdir_p(self.image_dir) torch.cuda.set_device(cfg.GPU_ID) cudnn.benchmark = True self.batch_size = cfg.TRAIN.BATCH_SIZE self.max_epoch = cfg.TRAIN.MAX_EPOCH self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL self.n_words = n_words self.ixtoword = ixtoword self.data_loader = data_loader self.num_batches = len(self.data_loader) def build_models(self): # text encoders if cfg.TRAIN.NET_E == '': print('Error: no pretrained text-image encoders') return image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM) img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder') state_dict = \ torch.load(img_encoder_path, map_location=lambda storage, loc: storage) image_encoder.load_state_dict(state_dict) for p in image_encoder.parameters(): p.requires_grad = False print('Load image encoder from:', img_encoder_path) image_encoder.eval() text_encoder = \ RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM) state_dict = \ torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage) text_encoder.load_state_dict(state_dict) for p in text_encoder.parameters(): p.requires_grad = False print('Load text encoder from:', cfg.TRAIN.NET_E) text_encoder.eval() # Caption models - cnn_encoder and rnn_decoder caption_cnn = CAPTION_CNN(cfg.CAP.embed_size) caption_cnn.load_state_dict(torch.load(cfg.CAP.caption_cnn_path, map_location=lambda storage, loc: storage)) for p in caption_cnn.parameters(): p.requires_grad = False print('Load caption model from:', cfg.CAP.caption_cnn_path) caption_cnn.eval() caption_rnn = CAPTION_RNN(cfg.CAP.embed_size, cfg.CAP.hidden_size * 2, self.n_words, cfg.CAP.num_layers) caption_rnn.load_state_dict(torch.load(cfg.CAP.caption_rnn_path, map_location=lambda storage, loc: storage)) for p in caption_rnn.parameters(): p.requires_grad = False print('Load caption model from:', cfg.CAP.caption_rnn_path) # Generator and Discriminator: netsD = [] if cfg.GAN.B_DCGAN: if cfg.TREE.BRANCH_NUM == 1: from model import D_NET64 as D_NET elif cfg.TREE.BRANCH_NUM == 2: from model import D_NET128 as D_NET else: # cfg.TREE.BRANCH_NUM == 3: from model import D_NET256 as D_NET netG = G_DCGAN() netsD = [D_NET(b_jcu=False)] else: from model import D_NET64, D_NET128, D_NET256 netG = G_NET() if cfg.TREE.BRANCH_NUM > 0: netsD.append(D_NET64()) if cfg.TREE.BRANCH_NUM > 1: netsD.append(D_NET128()) if cfg.TREE.BRANCH_NUM > 2: netsD.append(D_NET256()) netG.apply(weights_init) # print(netG) for i in range(len(netsD)): netsD[i].apply(weights_init) # print(netsD[i]) print('# of netsD', len(netsD)) epoch = 0 if cfg.TRAIN.NET_G != '': state_dict = \ torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage) netG.load_state_dict(state_dict) print('Load G from: ', cfg.TRAIN.NET_G) istart = cfg.TRAIN.NET_G.rfind('_') + 1 iend = cfg.TRAIN.NET_G.rfind('.') epoch = cfg.TRAIN.NET_G[istart:iend] epoch = int(epoch) + 1 if cfg.TRAIN.B_NET_D: Gname = cfg.TRAIN.NET_G for i in range(len(netsD)): s_tmp = Gname[:Gname.rfind('/')] Dname = '%s/netD%d.pth' % (s_tmp, i) print('Load D from: ', Dname) state_dict = \ torch.load(Dname, map_location=lambda storage, loc: storage) netsD[i].load_state_dict(state_dict) if cfg.CUDA: text_encoder = text_encoder.cuda() image_encoder = image_encoder.cuda() caption_cnn = caption_cnn.cuda() caption_rnn = caption_rnn.cuda() netG.cuda() for i in range(len(netsD)): netsD[i].cuda() return [text_encoder, image_encoder, caption_cnn, caption_rnn, netG, netsD, epoch] def define_optimizers(self, netG, netsD): optimizersD = [] num_Ds = len(netsD) for i in range(num_Ds): opt = optim.Adam(netsD[i].parameters(), lr=cfg.TRAIN.DISCRIMINATOR_LR, betas=(0.5, 0.999)) optimizersD.append(opt) optimizerG = optim.Adam(netG.parameters(), lr=cfg.TRAIN.GENERATOR_LR, betas=(0.5, 0.999)) return optimizerG, optimizersD def prepare_labels(self): batch_size = self.batch_size real_labels = Variable(torch.FloatTensor(batch_size).fill_(1)) fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0)) match_labels = Variable(torch.LongTensor(range(batch_size))) if cfg.CUDA: real_labels = real_labels.cuda() fake_labels = fake_labels.cuda() match_labels = match_labels.cuda() return real_labels, fake_labels, match_labels def save_model(self, netG, avg_param_G, netsD, epoch): backup_para = copy_G_params(netG) load_params(netG, avg_param_G) torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (self.model_dir, epoch)) load_params(netG, backup_para) # for i in range(len(netsD)): netD = netsD[i] torch.save(netD.state_dict(), '%s/netD%d.pth' % (self.model_dir, i)) print('Save G/Ds models.') def set_requires_grad_value(self, models_list, brequires): for i in range(len(models_list)): for p in models_list[i].parameters(): p.requires_grad = brequires def save_img_results(self, netG, noise, sent_emb, words_embs, mask, image_encoder, captions, cap_lens, gen_iterations, name='current'): # Save images fake_imgs, attention_maps, _, _ = netG(noise, sent_emb, words_embs, mask) for i in range(len(attention_maps)): if len(fake_imgs) > 1: img = fake_imgs[i + 1].detach().cpu() lr_img = fake_imgs[i].detach().cpu() else: img = fake_imgs[0].detach().cpu() lr_img = None attn_maps = attention_maps[i] att_sze = attn_maps.size(2) img_set, _ = \ build_super_images(img, captions, self.ixtoword, attn_maps, att_sze, lr_imgs=lr_img) if img_set is not None: im = Image.fromarray(img_set) fullpath = '%s/G_%s_%d_%d.png' \ % (self.image_dir, name, gen_iterations, i) im.save(fullpath) i = -1 img = fake_imgs[i].detach() region_features, _ = image_encoder(img) att_sze = region_features.size(2) _, _, att_maps = words_loss(region_features.detach(), words_embs.detach(), None, cap_lens, None, self.batch_size) img_set, _ = \ build_super_images(fake_imgs[i].detach().cpu(), captions, self.ixtoword, att_maps, att_sze) if img_set is not None: im = Image.fromarray(img_set) fullpath = '%s/D_%s_%d.png' \ % (self.image_dir, name, gen_iterations) im.save(fullpath) def train(self): text_encoder, image_encoder, caption_cnn, caption_rnn, netG, netsD, start_epoch = self.build_models() avg_param_G = copy_G_params(netG) optimizerG, optimizersD = self.define_optimizers(netG, netsD) real_labels, fake_labels, match_labels = self.prepare_labels() batch_size = self.batch_size nz = cfg.GAN.Z_DIM noise = Variable(torch.FloatTensor(batch_size, nz)) fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1)) if cfg.CUDA: noise, fixed_noise = noise.cuda(), fixed_noise.cuda() gen_iterations = 0 for epoch in range(start_epoch, self.max_epoch): start_t = time.time() data_iter = iter(self.data_loader) step = 0 while step < self.num_batches: # (1) Prepare training data and Compute text embeddings data = data_iter.next() imgs, captions, cap_lens, class_ids, keys = prepare_data(data) hidden = text_encoder.init_hidden(batch_size) # words_embs: batch_size x nef x seq_len # sent_emb: batch_size x nef words_embs, sent_emb = text_encoder(captions, cap_lens, hidden) words_embs, sent_emb = words_embs.detach(), sent_emb.detach() mask = (captions == 0) num_words = words_embs.size(2) if mask.size(1) > num_words: mask = mask[:, :num_words] # (2) Generate fake images noise.data.normal_(0, 1) fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs, mask) # (3) Update D network errD_total = 0 D_logs = '' for i in range(len(netsD)): netsD[i].zero_grad() errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i], sent_emb, real_labels, fake_labels) # backward and update parameters errD.backward() optimizersD[i].step() errD_total += errD D_logs += 'errD%d: %.2f ' % (i, errD.data[0]) # (4) Update G network: maximize log(D(G(z))) # compute total loss for training G step += 1 gen_iterations += 1 netG.zero_grad() errG_total, G_logs = \ generator_loss(netsD, image_encoder, caption_cnn, caption_rnn, captions, fake_imgs, real_labels, words_embs, sent_emb, match_labels, cap_lens, class_ids) kl_loss = KL_loss(mu, logvar) errG_total += kl_loss G_logs += 'kl_loss: %.2f ' % kl_loss.data[0] # backward and update parameters errG_total.backward() optimizerG.step() for p, avg_p in zip(netG.parameters(), avg_param_G): avg_p.mul_(0.999).add_(0.001, p.data) if gen_iterations % 100 == 0: print(D_logs + '\n' + G_logs) # save images if gen_iterations % 1000 == 0: backup_para = copy_G_params(netG) load_params(netG, avg_param_G) self.save_img_results(netG, fixed_noise, sent_emb, words_embs, mask, image_encoder, captions, cap_lens, epoch, name='average') load_params(netG, backup_para) end_t = time.time() print('''[%d/%d][%d] Loss_D: %.2f Loss_G: %.2f Time: %.2fs''' % (epoch, self.max_epoch, self.num_batches, errD_total.data[0], errG_total.data[0], end_t - start_t)) if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0: # and epoch != 0: self.save_model(netG, avg_param_G, netsD, epoch) self.save_model(netG, avg_param_G, netsD, self.max_epoch) def save_singleimages(self, images, filenames, save_dir, split_dir, sentenceID=0): for i in range(images.size(0)): s_tmp = '%s/single_samples/%s/%s' % \ (save_dir, split_dir, filenames[i]) folder = s_tmp[:s_tmp.rfind('/')] if not os.path.isdir(folder): print('Make a new folder: ', folder) mkdir_p(folder) fullpath = '%s_%d.jpg' % (s_tmp, sentenceID) # range from [-1, 1] to [0, 1] # img = (images[i] + 1.0) / 2 img = images[i].add(1).div(2).mul(255).clamp(0, 255).byte() # range from [0, 1] to [0, 255] ndarr = img.permute(1, 2, 0).data.cpu().numpy() im = Image.fromarray(ndarr) im.save(fullpath) def sampling(self, split_dir): if cfg.TRAIN.NET_G == '': print('Error: the path for model is not found!') else: if split_dir == 'test': split_dir = 'valid' # Build and load the generator if cfg.GAN.B_DCGAN: netG = G_DCGAN() else: netG = G_NET() netG.apply(weights_init) netG.cuda() netG.eval() # text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM) state_dict = \ torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage) text_encoder.load_state_dict(state_dict) print('Load text encoder from:', cfg.TRAIN.NET_E) text_encoder = text_encoder.cuda() text_encoder.eval() batch_size = self.batch_size nz = cfg.GAN.Z_DIM noise = Variable(torch.FloatTensor(batch_size, nz), volatile=True) noise = noise.cuda() model_dir = cfg.TRAIN.NET_G state_dict = \ torch.load(model_dir, map_location=lambda storage, loc: storage) netG.load_state_dict(state_dict) print('Load G from: ', model_dir) # the path to save generated images s_tmp = model_dir[:model_dir.rfind('.pth')] save_dir = '%s/%s' % (s_tmp, split_dir) mkdir_p(save_dir) cnt = 0 for _ in range(1): # (cfg.TEXT.CAPTIONS_PER_IMAGE): for step, data in enumerate(self.data_loader, 0): cnt += batch_size if step % 100 == 0: print('step: ', step) # if step > 50: # break imgs, captions, cap_lens, class_ids, keys = prepare_data(data) hidden = text_encoder.init_hidden(batch_size) # words_embs: batch_size x nef x seq_len # sent_emb: batch_size x nef words_embs, sent_emb = text_encoder(captions, cap_lens, hidden) words_embs, sent_emb = words_embs.detach(), sent_emb.detach() mask = (captions == 0) num_words = words_embs.size(2) if mask.size(1) > num_words: mask = mask[:, :num_words] # (2) Generate fake images noise.data.normal_(0, 1) fake_imgs, _, _, _ = netG(noise, sent_emb, words_embs, mask) for j in range(batch_size): s_tmp = '%s/single/%s' % (save_dir, keys[j]) folder = s_tmp[:s_tmp.rfind('/')] if not os.path.isdir(folder): print('Make a new folder: ', folder) mkdir_p(folder) k = -1 # for k in range(len(fake_imgs)): im = fake_imgs[k][j].data.cpu().numpy() # [-1, 1] --> [0, 255] im = (im + 1.0) * 127.5 im = im.astype(np.uint8) im = np.transpose(im, (1, 2, 0)) im = Image.fromarray(im) fullpath = '%s_s%d.png' % (s_tmp, k) im.save(fullpath) def gen_example(self, data_dic): if cfg.TRAIN.NET_G == '': print('Error: the path for morels is not found!') else: # Build and load the generator text_encoder = \ RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM) state_dict = \ torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage) text_encoder.load_state_dict(state_dict) print('Load text encoder from:', cfg.TRAIN.NET_E) text_encoder = text_encoder.cuda() text_encoder.eval() # the path to save generated images if cfg.GAN.B_DCGAN: netG = G_DCGAN() else: netG = G_NET() s_tmp = cfg.TRAIN.NET_G[:cfg.TRAIN.NET_G.rfind('.pth')] model_dir = cfg.TRAIN.NET_G state_dict = \ torch.load(model_dir, map_location=lambda storage, loc: storage) netG.load_state_dict(state_dict) print('Load G from: ', model_dir) netG.cuda() netG.eval() for key in data_dic: save_dir = '%s/%s' % (s_tmp, key) mkdir_p(save_dir) captions, cap_lens, sorted_indices = data_dic[key] batch_size = captions.shape[0] nz = cfg.GAN.Z_DIM captions = Variable(torch.from_numpy(captions), volatile=True) cap_lens = Variable(torch.from_numpy(cap_lens), volatile=True) captions = captions.cuda() cap_lens = cap_lens.cuda() for i in range(1): # 16 noise = Variable(torch.FloatTensor(batch_size, nz), volatile=True) noise = noise.cuda() # (1) Extract text embeddings hidden = text_encoder.init_hidden(batch_size) # words_embs: batch_size x nef x seq_len # sent_emb: batch_size x nef words_embs, sent_emb = text_encoder(captions, cap_lens, hidden) mask = (captions == 0) # (2) Generate fake images noise.data.normal_(0, 1) fake_imgs, attention_maps, _, _ = netG(noise, sent_emb, words_embs, mask) # G attention cap_lens_np = cap_lens.cpu().data.numpy() for j in range(batch_size): save_name = '%s/%d_s_%d' % (save_dir, i, sorted_indices[j]) for k in range(len(fake_imgs)): im = fake_imgs[k][j].data.cpu().numpy() im = (im + 1.0) * 127.5 im = im.astype(np.uint8) # print('im', im.shape) im = np.transpose(im, (1, 2, 0)) # print('im', im.shape) im = Image.fromarray(im) fullpath = '%s_g%d.png' % (save_name, k) im.save(fullpath) for k in range(len(attention_maps)): if len(fake_imgs) > 1: im = fake_imgs[k + 1].detach().cpu() else: im = fake_imgs[0].detach().cpu() attn_maps = attention_maps[k] att_sze = attn_maps.size(2) img_set, sentences = \ build_super_images2(im[j].unsqueeze(0), captions[j].unsqueeze(0), [cap_lens_np[j]], self.ixtoword, [attn_maps[j]], att_sze) if img_set is not None: im = Image.fromarray(img_set) fullpath = '%s_a%d.png' % (save_name, k) im.save(fullpath)