[
  {
    "path": ".gitignore",
    "content": "*.pyc\nmiscc/*.pyc\n.DS_Store\n.idea/\n"
  },
  {
    "path": "GLAttention.py",
    "content": "import torch\nimport torch.nn as nn\n\ndef conv1x1(in_planes, out_planes):\n    \"1x1 convolution with padding\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1,\n                     padding=0, bias=False)\n\n\ndef func_attention(query, context, gamma1):\n    \"\"\"\n    query: batch x ndf x queryL\n    context: batch x ndf x ih x iw (sourceL=ihxiw)\n    mask: batch_size x sourceL\n    \"\"\"\n    batch_size, queryL = query.size(0), query.size(2)\n    ih, iw = context.size(2), context.size(3)\n    sourceL = ih * iw\n\n    # --> batch x sourceL x ndf\n    context = context.view(batch_size, -1, sourceL)\n    contextT = torch.transpose(context, 1, 2).contiguous()\n\n    # Get attention\n    # (batch x sourceL x ndf)(batch x ndf x queryL)\n    # -->batch x sourceL x queryL\n    attn = torch.bmm(contextT, query)\n    # --> batch*sourceL x queryL\n    attn = attn.view(batch_size*sourceL, queryL)\n    attn = nn.Softmax()(attn)  # Eq. (8)\n\n    # --> batch x sourceL x queryL\n    attn = attn.view(batch_size, sourceL, queryL)\n    # --> batch*queryL x sourceL\n    attn = torch.transpose(attn, 1, 2).contiguous()\n    attn = attn.view(batch_size*queryL, sourceL)\n\n    attn = attn * gamma1\n    attn = nn.Softmax()(attn)\n    attn = attn.view(batch_size, queryL, sourceL)\n    # --> batch x sourceL x queryL\n    attnT = torch.transpose(attn, 1, 2).contiguous()\n\n    # (batch x ndf x sourceL)(batch x sourceL x queryL)\n    # --> batch x ndf x queryL\n    weightedContext = torch.bmm(context, attnT)\n\n    return weightedContext, attn.view(batch_size, -1, ih, iw)\n\n\nclass GLAttentionGeneral(nn.Module):\n    def __init__(self, idf, cdf):\n        super(GLAttentionGeneral, self).__init__()\n        self.conv_context = conv1x1(cdf, idf)\n        self.conv_sentence_vis = conv1x1(idf, idf)\n        self.linear = nn.Linear(100, idf)\n        self.sm = nn.Softmax()\n        self.mask = None\n\n    def applyMask(self, mask):\n        self.mask = 
mask  # batch x sourceL\n\n    def forward(self, input, sentence, context):\n        \"\"\"\n            input: batch x idf x ih x iw (queryL=ihxiw)\n\n            context: batch x cdf x sourceL (this is the matrix of word vectors)\n\n            sentence (c_code1): batch x idf x queryL (this is the vectors of the sentence)\n            queryL=ih x iw\n        \"\"\"\n\n        idf, ih, iw = input.size(1), input.size(2), input.size(3)\n        queryL = ih * iw\n        batch_size, sourceL = context.size(0), context.size(2)\n\n        # generated image feature:--> batch x queryL x idf\n        target = input.view(batch_size, -1, queryL)             # batch x idf x queryL\n        targetT = torch.transpose(target, 1, 2).contiguous()    # batch x queryL x idf\n\n\n        # Eq(4) in MirrorGAN : local-level attention\n        # words feature:  batch x cdf x sourceL --> batch x cdf x sourceL x 1\n        sourceT = context.unsqueeze(3)\n        # --> batch x idf x sourceL\n        sourceT = self.conv_context(sourceT).squeeze(3)\n\n        attn = torch.bmm(targetT, sourceT)\n        # --> batch*queryL x sourceL\n        attn = attn.view(batch_size*queryL, sourceL)\n        if self.mask is not None:\n            # batch_size x sourceL --> batch_size*queryL x sourceL\n            mask = self.mask.repeat(queryL, 1)\n            attn.data.masked_fill_(mask.data, -float('inf'))\n        attn = self.sm(attn)  # Eq. 
(2)\n        # --> batch x queryL x sourceL\n        attn = attn.view(batch_size, queryL, sourceL)\n        # --> batch x sourceL x queryL\n        attn = torch.transpose(attn, 1, 2).contiguous()\n\n        # (batch x idf x sourceL)(batch x sourceL x queryL)\n        # --> batch x idf x queryL\n        weightedContext = torch.bmm(sourceT, attn)\n        weightedContext = weightedContext.view(batch_size, -1, ih, iw)  # batch x idf x ih x iw\n        word_attn = attn.view(batch_size, -1, ih, iw)  # (batch x sourceL x ih x iw)\n\n        # Eq(5) in MirrorGAN : global-level attention\n        sentence = self.linear(sentence)\n        sentence = sentence.view(batch_size, idf, 1, 1)\n        sentence = sentence.repeat(1, 1, ih, iw)\n        sentence_vs = torch.mul(input, sentence)   # batch x idf x ih x iw\n        sentence_vs = self.conv_sentence_vis(sentence_vs) # batch x idf x ih x iw\n        sent_att = nn.Softmax()(sentence_vs)  # batch x idf x ih x iw\n        weightedSentence = torch.mul(sentence, sent_att)  # batch x idf x ih x iw\n\n        return weightedContext, weightedSentence, word_attn, sent_att\n\n        # weightedContext: batch x idf x ih x iw\n        # weightedSentence: batch x idf x ih x iw\n        # word_attn: batch x sourceL x ih x iw\n        # sent_vs_att: batch x idf x ih x iw\n"
  },
  {
    "path": "README.md",
    "content": "# MirrorGAN\n\nPyTorch implementation of the paper [MirrorGAN: Learning Text-to-image Generation by Redescription](https://arxiv.org/abs/1903.05854) by Tingting Qiao, Jing Zhang, Duanqing Xu, and Dacheng Tao. (This work was performed while Tingting Qiao was a visiting student at the UBTECH Sydney AI Centre, School of Computer Science, FEIT, the University of Sydney.)\n\n![MirrorGAN framework](images/framework.jpg)\n\n## Getting Started\n### Installation\n\n- Install PyTorch and its dependencies from https://pytorch.org\n- Install torchvision from source.\n\n- Clone this repo:\n```bash\ngit clone https://github.com/qiaott/MirrorGAN.git\ncd MirrorGAN\n```\n- Download our preprocessed data from [here](https://drive.google.com/file/d/1CuW5ognTSkNbyx9TWoUFrgwqxZNk1cl0/view?usp=sharing).\n\n- The STEM was pretrained using the code provided [here](https://github.com/taoxugit/AttnGAN).\n\n- The STREAM was pretrained using the code provided [here](https://github.com/yunjey/pytorch-tutorial/tree/master/tutorials/03-advanced/image_captioning).\n\n### Train/Test\n\nAfter obtaining the pretrained STEM and STREAM modules, point `TRAIN.NET_E` (STEM text encoder) and `CAP.caption_cnn_path` / `CAP.caption_rnn_path` (STREAM encoder/decoder) in `cfg/train_bird.yml` to their checkpoints; you can then train the text-to-image model.\n- Train a model:\n```bash\n./do_train.sh\n```\n- Test a model (first set `TRAIN.NET_G` in `cfg/eval_bird.yml` to the trained generator checkpoint):\n```bash\n./do_test.sh\n```\n\n## Citation\nIf you use this code for your research, please cite our paper.\n\n```bibtex\n@inproceedings{qiao2019mirrorgan,\n  title={MirrorGAN: Learning Text-to-image Generation by Redescription},\n  author={Qiao, Tingting and Zhang, Jing and Xu, Duanqing and Tao, Dacheng},\n  booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},\n  year={2019}\n}\n```\n"
  },
  {
    "path": "cfg/__init__.py",
    "content": ""
  },
  {
    "path": "cfg/config.py",
    "content": "from __future__ import division\nfrom __future__ import print_function\n\nimport os.path as osp\nimport numpy as np\nfrom easydict import EasyDict as edict\n\n\n__C = edict()\ncfg = __C\n\n# Dataset name: flowers, birds\n__C.DATASET_NAME = 'birds'\n__C.CONFIG_NAME = ''\n__C.DATA_DIR = ''\n__C.GPU_ID = 0\n__C.CUDA = True\n__C.WORKERS = 6\n__C.OUTPUT_PATH = ''\n__C.RNN_TYPE = 'LSTM'   # 'GRU'\n__C.B_VALIDATION = False\n\n__C.TREE = edict()\n__C.TREE.BRANCH_NUM = 3\n__C.TREE.BASE_SIZE = 64\n\n\n# Training options\n__C.TRAIN = edict()\n__C.TRAIN.BATCH_SIZE = 64\n__C.TRAIN.MAX_EPOCH = 600\n__C.TRAIN.SNAPSHOT_INTERVAL = 2000\n__C.TRAIN.DISCRIMINATOR_LR = 2e-4\n__C.TRAIN.GENERATOR_LR = 2e-4\n__C.TRAIN.ENCODER_LR = 2e-4\n__C.TRAIN.RNN_GRAD_CLIP = 0.25\n__C.TRAIN.FLAG = True\n__C.TRAIN.NET_E = ''\n__C.TRAIN.NET_G = ''\n__C.TRAIN.B_NET_D = True\n\n__C.TRAIN.SMOOTH = edict()\n__C.TRAIN.SMOOTH.GAMMA1 = 5.0\n__C.TRAIN.SMOOTH.GAMMA3 = 10.0\n__C.TRAIN.SMOOTH.GAMMA2 = 5.0\n__C.TRAIN.SMOOTH.LAMBDA = 0.0\n__C.TRAIN.SMOOTH.LAMBDA1 = 1.0\n\n\n# Caption_model_settings added by tingting\n__C.CAP = edict()\n__C.CAP.embed_size = 256\n__C.CAP.hidden_size = 512\n__C.CAP.num_layers = 1\n__C.CAP.learning_rate = 0.001\n__C.CAP.caption_cnn_path = ''\n__C.CAP.caption_rnn_path = ''\n\n\n# Modal options\n__C.GAN = edict()\n__C.GAN.DF_DIM = 64\n__C.GAN.GF_DIM = 128\n__C.GAN.Z_DIM = 100\n__C.GAN.CONDITION_DIM = 100\n__C.GAN.R_NUM = 2\n__C.GAN.B_ATTENTION = True\n__C.GAN.B_DCGAN = False\n\n\n__C.TEXT = edict()\n__C.TEXT.CAPTIONS_PER_IMAGE = 10\n__C.TEXT.EMBEDDING_DIM = 256\n__C.TEXT.WORDS_NUM = 18\n\n\ndef _merge_a_into_b(a, b):\n    \"\"\"Merge config dictionary a into config dictionary b, clobbering the\n    options in b whenever they are also specified in a.\n    \"\"\"\n    if type(a) is not edict:\n        return\n\n    for k, v in a.iteritems():\n        # a must specify keys that are in b\n        if not b.has_key(k):\n            raise KeyError('{} is not a valid config 
key'.format(k))\n\n        # the types must match, too\n        old_type = type(b[k])\n        if old_type is not type(v):\n            if isinstance(b[k], np.ndarray):\n                v = np.array(v, dtype=b[k].dtype)\n            else:\n                raise ValueError(('Type mismatch ({} vs. {}) '\n                                  'for config key: {}').format(type(b[k]),\n                                                               type(v), k))\n\n        # recursively merge dicts\n        if type(v) is edict:\n            try:\n                _merge_a_into_b(a[k], b[k])\n            except:\n                print('Error under config key: {}'.format(k))\n                raise\n        else:\n            b[k] = v\n\n\ndef cfg_from_file(filename):\n    \"\"\"Load a config file and merge it into the default options.\"\"\"\n    import yaml\n    with open(filename, 'r') as f:\n        yaml_cfg = edict(yaml.load(f))\n\n    _merge_a_into_b(yaml_cfg, __C)\n"
  },
  {
    "path": "cfg/eval_bird.yml",
    "content": "CONFIG_NAME: 'MirrorGAN'\nDATASET_NAME: 'birds'\nDATA_DIR: '../data/birds'\nGPU_ID: 3\nWORKERS: 1\n\nB_VALIDATION: True  # True: generate images for the whole test split; False: generate images for the captions listed in DATA_DIR/example_filenames.txt\nTREE:\n    BRANCH_NUM: 3\n\nTRAIN:\n    FLAG: False\n    NET_G: '../data/output/bird/Model/netG.pth'   # path to the trained generator checkpoint\n    B_NET_D: False\n    BATCH_SIZE: 12\n    NET_E: '../data/STEM/text_encoder.pth'\nGAN:\n    DF_DIM: 64\n    GF_DIM: 32\n    Z_DIM: 100\n    R_NUM: 2\n\nTEXT:\n    EMBEDDING_DIM: 256\n    CAPTIONS_PER_IMAGE: 10\n    WORDS_NUM: 25\n"
  },
  {
    "path": "cfg/train_bird.yml",
    "content": "CONFIG_NAME: 'MirrorGAN'\nDATASET_NAME: 'birds'\nDATA_DIR: '../data/birds'\nGPU_ID: 3\nWORKERS: 4\nOUTPUT_PATH: '/data/qtt/MirrorGAN/'  # results are written to OUTPUT_PATH/output/<DATASET_NAME>_<CONFIG_NAME>_<timestamp>\nTREE:\n    BRANCH_NUM: 3\n\nTRAIN:\n    FLAG: True\n    NET_G: ''\n    B_NET_D: True\n    BATCH_SIZE: 12  # 22\n    MAX_EPOCH: 650\n    SNAPSHOT_INTERVAL: 50\n    DISCRIMINATOR_LR: 0.0002\n    GENERATOR_LR: 0.0002\n\n    NET_E: '../data/STEM/text_encoder.pth'  # pretrained STEM text encoder\n    SMOOTH:\n        GAMMA1: 4.0  # tried 1, 2, 5 (good), 4 (best), 10 and 100 (bad)\n        GAMMA2: 5.0\n        GAMMA3: 10.0  # tried 10 (good), 1 and 100 (bad)\n        LAMBDA: 0.0\n        LAMBDA1: 10.0\n\nCAP:\n    embed_size: 256\n    hidden_size: 256\n    num_layers: 1\n    learning_rate: 0.001\n    caption_cnn_path: '../data/STREAM/cnn_encoder.ckpt'  # pretrained STREAM image encoder\n    caption_rnn_path: '../data/STREAM/rnn_decoder.ckpt'  # pretrained STREAM caption decoder\n\nGAN:\n    DF_DIM: 64\n    GF_DIM: 32\n    Z_DIM: 100\n    R_NUM: 2\n\nTEXT:\n    EMBEDDING_DIM: 256\n    CAPTIONS_PER_IMAGE: 10\n"
  },
  {
    "path": "datasets.py",
    "content": "from __future__ import absolute_import\nfrom __future__ import division\nfrom __future__ import print_function\nfrom __future__ import unicode_literals\n\n\nfrom nltk.tokenize import RegexpTokenizer\nfrom collections import defaultdict\nfrom cfg.config import cfg\n\nimport torch\nimport torch.utils.data as data\nfrom torch.autograd import Variable\nimport torchvision.transforms as transforms\n\nimport os\nimport sys\nimport numpy as np\nimport pandas as pd\nfrom PIL import Image\nimport numpy.random as random\nif sys.version_info[0] == 2:\n    import cPickle as pickle\nelse:\n    import pickle\n\n\ndef prepare_data(data):\n    imgs, captions, captions_lens, class_ids, keys = data\n\n    # sort data by the length in a decreasing order\n    sorted_cap_lens, sorted_cap_indices = \\\n        torch.sort(captions_lens, 0, True)\n\n    real_imgs = []\n    for i in range(len(imgs)):\n        imgs[i] = imgs[i][sorted_cap_indices]\n        if cfg.CUDA:\n            real_imgs.append(Variable(imgs[i]).cuda())\n        else:\n            real_imgs.append(Variable(imgs[i]))\n\n    captions = captions[sorted_cap_indices].squeeze()\n    class_ids = class_ids[sorted_cap_indices].numpy()\n    # sent_indices = sent_indices[sorted_cap_indices]\n    keys = [keys[i] for i in sorted_cap_indices.numpy()]\n    # print('keys', type(keys), keys[-1])  # list\n    if cfg.CUDA:\n        captions = Variable(captions).cuda()\n        sorted_cap_lens = Variable(sorted_cap_lens).cuda()\n    else:\n        captions = Variable(captions)\n        sorted_cap_lens = Variable(sorted_cap_lens)\n\n    return [real_imgs, captions, sorted_cap_lens,\n            class_ids, keys]\n\n\ndef get_imgs(img_path, imsize, bbox=None,\n             transform=None, normalize=None):\n    img = Image.open(img_path).convert('RGB')\n    width, height = img.size\n    if bbox is not None:\n        r = int(np.maximum(bbox[2], bbox[3]) * 0.75)\n        center_x = int((2 * bbox[0] + bbox[2]) / 2)\n        
center_y = int((2 * bbox[1] + bbox[3]) / 2)\n        y1 = np.maximum(0, center_y - r)\n        y2 = np.minimum(height, center_y + r)\n        x1 = np.maximum(0, center_x - r)\n        x2 = np.minimum(width, center_x + r)\n        img = img.crop([x1, y1, x2, y2])\n\n    if transform is not None:\n        img = transform(img)\n\n    ret = []\n    if cfg.GAN.B_DCGAN:\n        ret = [normalize(img)]\n    else:\n        for i in range(cfg.TREE.BRANCH_NUM):\n            # print(imsize[i])\n            if i < (cfg.TREE.BRANCH_NUM - 1):\n                re_img = transforms.Scale(imsize[i])(img)\n            else:\n                re_img = img\n            ret.append(normalize(re_img))\n\n    return ret\n\n\nclass TextDataset(data.Dataset):\n    def __init__(self, data_dir, split='train',\n                 base_size=64,\n                 transform=None, target_transform=None):\n        self.transform = transform\n        self.norm = transforms.Compose([\n            transforms.ToTensor(),\n            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\n        self.target_transform = target_transform\n        self.embeddings_num = cfg.TEXT.CAPTIONS_PER_IMAGE\n\n        self.imsize = []\n        for i in range(cfg.TREE.BRANCH_NUM):\n            self.imsize.append(base_size)\n            base_size = base_size * 2\n\n        self.data = []\n        self.data_dir = data_dir\n        if data_dir.find('birds') != -1:\n            self.bbox = self.load_bbox()\n        else:\n            self.bbox = None\n        split_dir = os.path.join(data_dir, split)\n\n        self.filenames, self.captions, self.ixtoword, \\\n            self.wordtoix, self.n_words = self.load_text_data(data_dir, split)\n\n        self.class_id = self.load_class_id(split_dir, len(self.filenames))\n        self.number_example = len(self.filenames)\n\n    def load_bbox(self):\n        data_dir = self.data_dir\n        bbox_path = os.path.join(data_dir, 'CUB_200_2011/bounding_boxes.txt')\n        
df_bounding_boxes = pd.read_csv(bbox_path,\n                                        delim_whitespace=True,\n                                        header=None).astype(int)\n        #\n        filepath = os.path.join(data_dir, 'CUB_200_2011/images.txt')\n        df_filenames = \\\n            pd.read_csv(filepath, delim_whitespace=True, header=None)\n        filenames = df_filenames[1].tolist()\n        print('Total filenames: ', len(filenames), filenames[0])\n        #\n        filename_bbox = {img_file[:-4]: [] for img_file in filenames}\n        numImgs = len(filenames)\n        for i in xrange(0, numImgs):\n            # bbox = [x-left, y-top, width, height]\n            bbox = df_bounding_boxes.iloc[i][1:].tolist()\n\n            key = filenames[i][:-4]\n            filename_bbox[key] = bbox\n        #\n        return filename_bbox\n\n    def load_captions(self, data_dir, filenames):\n        all_captions = []\n        for i in range(len(filenames)):\n            cap_path = '%s/text/%s.txt' % (data_dir, filenames[i])\n            with open(cap_path, \"r\") as f:\n                captions = f.read().decode('utf8').split('\\n')\n                cnt = 0\n                for cap in captions:\n                    if len(cap) == 0:\n                        continue\n                    cap = cap.replace(\"\\ufffd\\ufffd\", \" \")\n                    # picks out sequences of alphanumeric characters as tokens\n                    # and drops everything else\n                    tokenizer = RegexpTokenizer(r'\\w+')\n                    tokens = tokenizer.tokenize(cap.lower())\n                    # print('tokens', tokens)\n                    if len(tokens) == 0:\n                        print('cap', cap)\n                        continue\n\n                    tokens_new = []\n                    for t in tokens:\n                        t = t.encode('ascii', 'ignore').decode('ascii')\n                        if len(t) > 0:\n                            
tokens_new.append(t)\n                    all_captions.append(tokens_new)\n                    cnt += 1\n                    if cnt == self.embeddings_num:\n                        break\n                if cnt < self.embeddings_num:\n                    print('ERROR: the captions for %s less than %d'\n                          % (filenames[i], cnt))\n        return all_captions\n\n    def build_dictionary(self, train_captions, test_captions):\n        word_counts = defaultdict(float)\n        captions = train_captions + test_captions\n        for sent in captions:\n            for word in sent:\n                word_counts[word] += 1\n\n        vocab = [w for w in word_counts if word_counts[w] >= 0]\n\n        ixtoword = {}\n        ixtoword[0] = '<end>'\n        wordtoix = {}\n        wordtoix['<end>'] = 0\n        ix = 1\n        for w in vocab:\n            wordtoix[w] = ix\n            ixtoword[ix] = w\n            ix += 1\n\n        train_captions_new = []\n        for t in train_captions:\n            rev = []\n            for w in t:\n                if w in wordtoix:\n                    rev.append(wordtoix[w])\n            # rev.append(0)  # do not need '<end>' token\n            train_captions_new.append(rev)\n\n        test_captions_new = []\n        for t in test_captions:\n            rev = []\n            for w in t:\n                if w in wordtoix:\n                    rev.append(wordtoix[w])\n            # rev.append(0)  # do not need '<end>' token\n            test_captions_new.append(rev)\n\n        return [train_captions_new, test_captions_new,\n                ixtoword, wordtoix, len(ixtoword)]\n\n    def load_text_data(self, data_dir, split):\n        filepath = os.path.join(data_dir, 'bird_captions.pickle')\n        train_names = self.load_filenames(data_dir, 'train')\n        test_names = self.load_filenames(data_dir, 'test')\n        if not os.path.isfile(filepath):\n            train_captions = self.load_captions(data_dir, train_names)\n 
           test_captions = self.load_captions(data_dir, test_names)\n\n            train_captions, test_captions, ixtoword, wordtoix, n_words = \\\n                self.build_dictionary(train_captions, test_captions)\n            with open(filepath, 'wb') as f:\n                pickle.dump([train_captions, test_captions,\n                             ixtoword, wordtoix], f, protocol=2)\n                print('Save to: ', filepath)\n        else:\n            with open(filepath, 'rb') as f:\n                x = pickle.load(f)\n                train_captions, test_captions = x[0], x[1]\n                ixtoword, wordtoix = x[2], x[3]\n                del x\n                n_words = len(ixtoword)\n                print('Load from: ', filepath)\n        if split == 'train':\n            # a list of list: each list contains\n            # the indices of words in a sentence\n            captions = train_captions\n            filenames = train_names\n        else:  # split=='test'\n            captions = test_captions\n            filenames = test_names\n        return filenames, captions, ixtoword, wordtoix, n_words\n\n    def load_class_id(self, data_dir, total_num):\n        if os.path.isfile(data_dir + '/class_info.pickle'):\n            with open(data_dir + '/class_info.pickle', 'rb') as f:\n                class_id = pickle.load(f)\n        else:\n            class_id = np.arange(total_num)\n        return class_id\n\n    def load_filenames(self, data_dir, split):\n        filepath = '%s/%s/filenames.pickle' % (data_dir, split)\n        if os.path.isfile(filepath):\n            with open(filepath, 'rb') as f:\n                filenames = pickle.load(f)\n            print('Load filenames from: %s (%d)' % (filepath, len(filenames)))\n        else:\n            filenames = []\n        return filenames\n\n    def get_caption(self, sent_ix):\n        # a list of indices for a sentence\n        sent_caption = np.asarray(self.captions[sent_ix]).astype('int64')\n        # 
if (sent_caption == 0).sum() > 0:\n        #     print('ERROR: do not need END (0) token', sent_caption)\n        num_words = len(sent_caption)\n        # pad with 0s (i.e., '<end>')\n        x = np.zeros((cfg.TEXT.WORDS_NUM, 1), dtype='int64')\n        x_len = num_words\n        if num_words <= cfg.TEXT.WORDS_NUM:\n            x[:num_words, 0] = sent_caption\n        else:\n            ix = list(np.arange(num_words))  # 1, 2, 3,..., maxNum\n            np.random.shuffle(ix)\n            ix = ix[:cfg.TEXT.WORDS_NUM]\n            ix = np.sort(ix)\n            x[:, 0] = sent_caption[ix]\n            x_len = cfg.TEXT.WORDS_NUM\n        return x, x_len\n\n    def __getitem__(self, index):\n        #\n        key = self.filenames[index]\n        cls_id = self.class_id[index]\n        #\n        if self.bbox is not None:\n            bbox = self.bbox[key]\n            data_dir = '%s/CUB_200_2011' % self.data_dir\n        else:\n            bbox = None\n            data_dir = self.data_dir\n        #\n        img_name = '%s/images/%s.jpg' % (data_dir, key)\n        imgs = get_imgs(img_name, self.imsize,\n                        bbox, self.transform, normalize=self.norm)\n        # random select a sentence\n        sent_ix = random.randint(0, self.embeddings_num)\n        new_sent_ix = index * self.embeddings_num + sent_ix\n        caps, cap_len = self.get_caption(new_sent_ix)\n        return imgs, caps, cap_len, cls_id, key\n\n\n    def __len__(self):\n        return len(self.filenames)\n"
  },
  {
    "path": "do_test.sh",
    "content": "cfg=cfg/eval_bird.yml\npython main.py --cfg $cfg\n"
  },
  {
    "path": "do_train.sh",
    "content": "cfg=cfg/train_bird.yml\npython main.py --cfg $cfg\n"
  },
  {
    "path": "main.py",
    "content": "from __future__ import print_function\n\nfrom cfg.config import cfg, cfg_from_file\nfrom datasets import TextDataset\nfrom trainer import Trainer as trainer\n\nimport os\nimport sys\nimport time\nimport random\nimport pprint\nimport datetime\nimport dateutil.tz\nimport argparse\nimport numpy as np\n\nimport torch\nimport torchvision.transforms as transforms\n\ndir_path = (os.path.abspath(os.path.join(os.path.realpath(__file__), './.')))\nsys.path.append(dir_path)\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description='Train a AttnGAN network')\n    parser.add_argument('--cfg', dest='cfg_file',\n                        help='optional config file',\n                        default='cfg/bird_attn2.yml', type=str)\n    parser.add_argument('--gpu', dest='gpu_id', type=int, default=-1)\n    parser.add_argument('--data_dir', dest='data_dir', type=str, default='')\n    parser.add_argument('--manualSeed', type=int, help='manual seed')\n    args = parser.parse_args()\n    return args\n\n\ndef gen_example(wordtoix, algo):\n    '''generate images from example sentences'''\n    from nltk.tokenize import RegexpTokenizer\n    filepath = '%s/example_filenames.txt' % (cfg.DATA_DIR)\n    data_dic = {}\n    with open(filepath, \"r\") as f:\n        filenames = f.read().decode('utf8').split('\\n')\n        for name in filenames:\n            if len(name) == 0:\n                continue\n            filepath = '%s/%s.txt' % (cfg.DATA_DIR, name)\n            with open(filepath, \"r\") as f:\n                print('Load from:', name)\n                sentences = f.read().decode('utf8').split('\\n')\n                # a list of indices for a sentence\n                captions = []\n                cap_lens = []\n                for sent in sentences:\n                    if len(sent) == 0:\n                        continue\n                    sent = sent.replace(\"\\ufffd\\ufffd\", \" \")\n                    tokenizer = RegexpTokenizer(r'\\w+')\n        
            tokens = tokenizer.tokenize(sent.lower())\n                    if len(tokens) == 0:\n                        print('sent', sent)\n                        continue\n\n                    rev = []\n                    for t in tokens:\n                        t = t.encode('ascii', 'ignore').decode('ascii')\n                        if len(t) > 0 and t in wordtoix:\n                            rev.append(wordtoix[t])\n                    captions.append(rev)\n                    cap_lens.append(len(rev))\n            max_len = np.max(cap_lens)\n\n            sorted_indices = np.argsort(cap_lens)[::-1]\n            cap_lens = np.asarray(cap_lens)\n            cap_lens = cap_lens[sorted_indices]\n            cap_array = np.zeros((len(captions), max_len), dtype='int64')\n            for i in range(len(captions)):\n                idx = sorted_indices[i]\n                cap = captions[idx]\n                c_len = len(cap)\n                cap_array[i, :c_len] = cap\n            key = name[(name.rfind('/') + 1):]\n            data_dic[key] = [cap_array, cap_lens, sorted_indices]\n    algo.gen_example(data_dic)\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    if args.cfg_file is not None:\n        cfg_from_file(args.cfg_file)\n\n    if args.data_dir != '':\n        cfg.DATA_DIR = args.data_dir\n    print('Using config:')\n    pprint.pprint(cfg)\n\n    if not cfg.TRAIN.FLAG:\n        args.manualSeed = 100\n    elif args.manualSeed is None:\n        args.manualSeed = random.randint(1, 10000)\n    random.seed(args.manualSeed)\n    np.random.seed(args.manualSeed)\n    torch.manual_seed(args.manualSeed)\n    if cfg.CUDA:\n        torch.cuda.manual_seed_all(args.manualSeed)\n\n    now = datetime.datetime.now(dateutil.tz.tzlocal())\n    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')\n    output_dir = '%s/output/%s_%s_%s' % \\\n        (cfg.OUTPUT_PATH, cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp)\n\n    split_dir, bshuffle = 'train', True\n    if not 
cfg.TRAIN.FLAG:\n        # bshuffle = False\n        split_dir = 'test'\n\n    # Get data loader\n    imsize = cfg.TREE.BASE_SIZE * (2 ** (cfg.TREE.BRANCH_NUM - 1))\n    image_transform = transforms.Compose([\n        transforms.Scale(int(imsize * 76 / 64)),\n        transforms.RandomCrop(imsize),\n        transforms.RandomHorizontalFlip()])\n    dataset = TextDataset(cfg.DATA_DIR, split_dir,\n                          base_size=cfg.TREE.BASE_SIZE,\n                          transform=image_transform)\n    assert dataset\n    dataloader = torch.utils.data.DataLoader(\n        dataset, batch_size=cfg.TRAIN.BATCH_SIZE,\n        drop_last=True, shuffle=bshuffle, num_workers=int(cfg.WORKERS))\n\n    # Define models and go to train/evaluate\n    algo = trainer(output_dir, dataloader, dataset.n_words, dataset.ixtoword)\n\n    start_t = time.time()\n    if cfg.TRAIN.FLAG:\n        algo.train()\n    else:\n        '''generate images from pre-extracted embeddings'''\n        if cfg.B_VALIDATION:\n            algo.sampling(split_dir)  # generate images for the whole valid dataset\n        else:\n            gen_example(dataset.wordtoix, algo)  # generate images for customized captions\n    end_t = time.time()\n    print('Total time for training:', end_t - start_t)\n"
  },
  {
    "path": "miscc/__init__.py",
    "content": "from __future__ import division\nfrom __future__ import print_function\n"
  },
  {
    "path": "miscc/losses.py",
    "content": "import torch\nimport torch.nn as nn\n\nimport numpy as np\nfrom cfg.config import cfg\nfrom torch.nn.utils.rnn import pack_padded_sequence\nfrom GLAttention import func_attention\n\n\n# ##################Loss for matching text-image###################\ndef cosine_similarity(x1, x2, dim=1, eps=1e-8):\n    \"\"\"Returns cosine similarity between x1 and x2, computed along dim.\n    \"\"\"\n    w12 = torch.sum(x1 * x2, dim)\n    w1 = torch.norm(x1, 2, dim)\n    w2 = torch.norm(x2, 2, dim)\n    return (w12 / (w1 * w2).clamp(min=eps)).squeeze()\n\ndef caption_loss(cap_output, captions):\n    criterion = nn.CrossEntropyLoss()\n    caption_loss = criterion(cap_output, captions)\n    return caption_loss\n\ndef sent_loss(cnn_code, rnn_code, labels, class_ids,\n              batch_size, eps=1e-8):\n    # ### Mask mis-match samples  ###\n    # that come from the same class as the real sample ###\n    masks = []\n    if class_ids is not None:\n        for i in range(batch_size):\n            mask = (class_ids == class_ids[i]).astype(np.uint8)\n            mask[i] = 0\n            masks.append(mask.reshape((1, -1)))\n        masks = np.concatenate(masks, 0)\n        # masks: batch_size x batch_size\n        masks = torch.ByteTensor(masks)\n        if cfg.CUDA:\n            masks = masks.cuda()\n\n    # --> seq_len x batch_size x nef\n    if cnn_code.dim() == 2:\n        cnn_code = cnn_code.unsqueeze(0)\n        rnn_code = rnn_code.unsqueeze(0)\n\n    # cnn_code_norm / rnn_code_norm: seq_len x batch_size x 1\n    cnn_code_norm = torch.norm(cnn_code, 2, dim=2, keepdim=True)\n    rnn_code_norm = torch.norm(rnn_code, 2, dim=2, keepdim=True)\n    # scores* / norm*: seq_len x batch_size x batch_size\n    scores0 = torch.bmm(cnn_code, rnn_code.transpose(1, 2))\n    norm0 = torch.bmm(cnn_code_norm, rnn_code_norm.transpose(1, 2))\n    scores0 = scores0 / norm0.clamp(min=eps) * cfg.TRAIN.SMOOTH.GAMMA3\n\n    # --> batch_size x batch_size\n    scores0 = scores0.squeeze()\n  
  if class_ids is not None:\n        scores0.data.masked_fill_(masks, -float('inf'))\n    scores1 = scores0.transpose(0, 1)\n    if labels is not None:\n        loss0 = nn.CrossEntropyLoss()(scores0, labels)\n        loss1 = nn.CrossEntropyLoss()(scores1, labels)\n    else:\n        loss0, loss1 = None, None\n    return loss0, loss1\n\n\ndef words_loss(img_features, words_emb, labels,\n               cap_lens, class_ids, batch_size):\n    \"\"\"\n        words_emb(query): batch x nef x seq_len\n        img_features(context): batch x nef x 17 x 17\n    \"\"\"\n    masks = []\n    att_maps = []\n    similarities = []\n    cap_lens = cap_lens.data.tolist()\n    for i in range(batch_size):\n        if class_ids is not None:\n            mask = (class_ids == class_ids[i]).astype(np.uint8)\n            mask[i] = 0\n            masks.append(mask.reshape((1, -1)))\n        # Get the i-th text description\n        words_num = cap_lens[i]\n        # -> 1 x nef x words_num\n        word = words_emb[i, :, :words_num].unsqueeze(0).contiguous()\n        # -> batch_size x nef x words_num\n        word = word.repeat(batch_size, 1, 1)\n        # batch x nef x 17*17\n        context = img_features\n        \"\"\"\n            word(query): batch x nef x words_num\n            context: batch x nef x 17 x 17\n            weiContext: batch x nef x words_num\n            attn: batch x words_num x 17 x 17\n        \"\"\"\n        weiContext, attn = func_attention(word, context, cfg.TRAIN.SMOOTH.GAMMA1)\n        att_maps.append(attn[i].unsqueeze(0).contiguous())\n        # --> batch_size x words_num x nef\n        word = word.transpose(1, 2).contiguous()\n        weiContext = weiContext.transpose(1, 2).contiguous()\n        # --> batch_size*words_num x nef\n        word = word.view(batch_size * words_num, -1)\n        weiContext = weiContext.view(batch_size * words_num, -1)\n        #\n        # -->batch_size*words_num\n        row_sim = cosine_similarity(word, weiContext)\n        # --> 
batch_size x words_num\n        row_sim = row_sim.view(batch_size, words_num)\n\n        # Eq. (10)\n        row_sim.mul_(cfg.TRAIN.SMOOTH.GAMMA2).exp_()\n        row_sim = row_sim.sum(dim=1, keepdim=True)\n        row_sim = torch.log(row_sim)\n\n        # --> 1 x batch_size\n        # similarities(i, j): the similarity between the i-th image and the j-th text description\n        similarities.append(row_sim)\n\n    # batch_size x batch_size\n    similarities = torch.cat(similarities, 1)\n    if class_ids is not None:\n        masks = np.concatenate(masks, 0)\n        # masks: batch_size x batch_size\n        masks = torch.ByteTensor(masks)\n        if cfg.CUDA:\n            masks = masks.cuda()\n\n    similarities = similarities * cfg.TRAIN.SMOOTH.GAMMA3\n    if class_ids is not None:\n        similarities.data.masked_fill_(masks, -float('inf'))\n    similarities1 = similarities.transpose(0, 1)\n    if labels is not None:\n        loss0 = nn.CrossEntropyLoss()(similarities, labels)\n        loss1 = nn.CrossEntropyLoss()(similarities1, labels)\n    else:\n        loss0, loss1 = None, None\n    return loss0, loss1, att_maps\n\n\n# ##################Loss for G and Ds##############################\ndef discriminator_loss(netD, real_imgs, fake_imgs, conditions,\n                       real_labels, fake_labels):\n    # Forward\n    real_features = netD(real_imgs)\n    fake_features = netD(fake_imgs.detach())\n    # loss\n    #\n    cond_real_logits = netD.COND_DNET(real_features, conditions)\n    cond_real_errD = nn.BCELoss()(cond_real_logits, real_labels)\n    cond_fake_logits = netD.COND_DNET(fake_features, conditions)\n    cond_fake_errD = nn.BCELoss()(cond_fake_logits, fake_labels)\n    #\n    batch_size = real_features.size(0)\n    cond_wrong_logits = netD.COND_DNET(real_features[:(batch_size - 1)], conditions[1:batch_size])\n    cond_wrong_errD = nn.BCELoss()(cond_wrong_logits, fake_labels[1:batch_size])\n\n    if netD.UNCOND_DNET is not None:\n        real_logits 
= netD.UNCOND_DNET(real_features)\n        fake_logits = netD.UNCOND_DNET(fake_features)\n        real_errD = nn.BCELoss()(real_logits, real_labels)\n        fake_errD = nn.BCELoss()(fake_logits, fake_labels)\n        errD = ((real_errD + cond_real_errD) / 2. +\n                (fake_errD + cond_fake_errD + cond_wrong_errD) / 3.)\n    else:\n        errD = cond_real_errD + (cond_fake_errD + cond_wrong_errD) / 2.\n    return errD\n\n\ndef generator_loss(netsD, image_encoder, caption_cnn, caption_rnn, captions, fake_imgs, real_labels,\n                   words_embs, sent_emb, match_labels,\n                   cap_lens, class_ids):\n    numDs = len(netsD)\n    logs = ''\n    # Forward\n    errG_total = 0\n\n    for i in range(numDs):\n        features = netsD[i](fake_imgs[i])\n        cond_logits = netsD[i].COND_DNET(features, sent_emb)\n        cond_errG = nn.BCELoss()(cond_logits, real_labels)\n\n        if netsD[i].UNCOND_DNET is  not None:\n            logits = netsD[i].UNCOND_DNET(features)\n            errG = nn.BCELoss()(logits, real_labels)\n            g_loss = errG + cond_errG\n        else:\n            g_loss = cond_errG\n        errG_total += g_loss\n\n        logs += 'g_loss%d: %.2f ' % (i, g_loss.data[0])\n\n        if i == (numDs - 1):\n            fakeimg_feature = caption_cnn(fake_imgs[i])\n            captions.cuda()\n            target_cap = pack_padded_sequence(captions, cap_lens.data.tolist(), batch_first=True)[0].cuda()\n            cap_output = caption_rnn(fakeimg_feature, captions, cap_lens)\n            cap_loss = caption_loss(cap_output, target_cap) * cfg.TRAIN.SMOOTH.LAMBDA1\n\n            errG_total += cap_loss\n            logs += 'cap_loss: %.2f, ' % cap_loss\n    return errG_total, logs\n\n\n##################################################################\ndef KL_loss(mu, logvar):\n    # -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)\n    KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)\n    KLD = 
torch.mean(KLD_element).mul_(-0.5)\n    return KLD\n"
  },
  {
    "path": "miscc/utils.py",
    "content": "import os\nimport errno\nimport numpy as np\nfrom torch.nn import init\n\nimport torch\nimport torch.nn as nn\n\nfrom PIL import Image, ImageDraw, ImageFont\nfrom copy import deepcopy\nimport skimage.transform\n\nfrom cfg.config import cfg\n\n\n# For visualization ################################################\nCOLOR_DIC = {0:[128,64,128],  1:[244, 35,232],\n             2:[70, 70, 70],  3:[102,102,156],\n             4:[190,153,153], 5:[153,153,153],\n             6:[250,170, 30], 7:[220, 220, 0],\n             8:[107,142, 35], 9:[152,251,152],\n             10:[70,130,180], 11:[220,20, 60],\n             12:[255, 0, 0],  13:[0, 0, 142],\n             14:[119,11, 32], 15:[0, 60,100],\n             16:[0, 80, 100], 17:[0, 0, 230],\n             18:[0,  0, 70],  19:[0, 0,  0]}\nFONT_MAX = 50\n\n\ndef drawCaption(convas, captions, ixtoword, vis_size, off1=2, off2=2):\n    num = captions.size(0)\n    img_txt = Image.fromarray(convas)\n    # get a font\n    # fnt = None  # ImageFont.truetype('Pillow/Tests/fonts/FreeMono.ttf', 50)\n    fnt = ImageFont.truetype('Pillow/Tests/fonts/FreeMono.ttf', 50)\n    # get a drawing context\n    d = ImageDraw.Draw(img_txt)\n    sentence_list = []\n    for i in range(num):\n        cap = captions[i].data.cpu().numpy()\n        sentence = []\n        for j in range(len(cap)):\n            if cap[j] == 0:\n                break\n            word = ixtoword[cap[j]].encode('ascii', 'ignore').decode('ascii')\n            d.text(((j + off1) * (vis_size + off2), i * FONT_MAX), '%d:%s' % (j, word[:6]),\n                   font=fnt, fill=(255, 255, 255, 255))\n            sentence.append(word)\n        sentence_list.append(sentence)\n    return img_txt, sentence_list\n\n\ndef build_super_images(real_imgs, captions, ixtoword,\n                       attn_maps, att_sze, lr_imgs=None,\n                       batch_size=cfg.TRAIN.BATCH_SIZE,\n                       max_word_num=cfg.TEXT.WORDS_NUM):\n    nvis = 8\n    real_imgs = 
real_imgs[:nvis]\n    if lr_imgs is not None:\n        lr_imgs = lr_imgs[:nvis]\n    if att_sze == 17:\n        vis_size = att_sze * 16\n    else:\n        vis_size = real_imgs.size(2)\n\n    text_convas = \\\n        np.ones([batch_size * FONT_MAX,\n                 (max_word_num + 2) * (vis_size + 2), 3],\n                dtype=np.uint8)\n\n    for i in range(max_word_num):\n        istart = (i + 2) * (vis_size + 2)\n        iend = (i + 3) * (vis_size + 2)\n        text_convas[:, istart:iend, :] = COLOR_DIC[i]\n\n\n    real_imgs = \\\n        nn.Upsample(size=(vis_size, vis_size), mode='bilinear')(real_imgs)\n    # [-1, 1] --> [0, 1]\n    real_imgs.add_(1).div_(2).mul_(255)\n    real_imgs = real_imgs.data.numpy()\n    # b x c x h x w --> b x h x w x c\n    real_imgs = np.transpose(real_imgs, (0, 2, 3, 1))\n    pad_sze = real_imgs.shape\n    middle_pad = np.zeros([pad_sze[2], 2, 3])\n    post_pad = np.zeros([pad_sze[1], pad_sze[2], 3])\n    if lr_imgs is not None:\n        lr_imgs = \\\n            nn.Upsample(size=(vis_size, vis_size), mode='bilinear')(lr_imgs)\n        # [-1, 1] --> [0, 1]\n        lr_imgs.add_(1).div_(2).mul_(255)\n        lr_imgs = lr_imgs.data.numpy()\n        # b x c x h x w --> b x h x w x c\n        lr_imgs = np.transpose(lr_imgs, (0, 2, 3, 1))\n\n    # batch x seq_len x 17 x 17 --> batch x 1 x 17 x 17\n    seq_len = max_word_num\n    img_set = []\n    num = nvis  # len(attn_maps)\n\n    text_map, sentences = \\\n        drawCaption(text_convas, captions, ixtoword, vis_size)\n    text_map = np.asarray(text_map).astype(np.uint8)\n\n    bUpdate = 1\n    for i in range(num):\n        attn = attn_maps[i].cpu().view(1, -1, att_sze, att_sze)\n        # --> 1 x 1 x 17 x 17\n        attn_max = attn.max(dim=1, keepdim=True)\n        attn = torch.cat([attn_max[0], attn], 1)\n        #\n        attn = attn.view(-1, 1, att_sze, att_sze)\n        attn = attn.repeat(1, 3, 1, 1).data.numpy()\n        # n x c x h x w --> n x h x w x c\n        attn = 
np.transpose(attn, (0, 2, 3, 1))\n        num_attn = attn.shape[0]\n        #\n        img = real_imgs[i]\n        if lr_imgs is None:\n            lrI = img\n        else:\n            lrI = lr_imgs[i]\n        row = [lrI, middle_pad]\n        row_merge = [img, middle_pad]\n        row_beforeNorm = []\n        minVglobal, maxVglobal = 1, 0\n        for j in range(num_attn):\n            one_map = attn[j]\n            if (vis_size // att_sze) > 1:\n                one_map = \\\n                    skimage.transform.pyramid_expand(one_map, sigma=20,\n                                                     upscale=vis_size // att_sze)\n            row_beforeNorm.append(one_map)\n            minV = one_map.min()\n            maxV = one_map.max()\n            if minVglobal > minV:\n                minVglobal = minV\n            if maxVglobal < maxV:\n                maxVglobal = maxV\n        for j in range(seq_len + 1):\n            if j < num_attn:\n                one_map = row_beforeNorm[j]\n                one_map = (one_map - minVglobal) / (maxVglobal - minVglobal)\n                one_map *= 255\n                #\n                PIL_im = Image.fromarray(np.uint8(img))\n                PIL_att = Image.fromarray(np.uint8(one_map))\n                merged = \\\n                    Image.new('RGBA', (vis_size, vis_size), (0, 0, 0, 0))\n                mask = Image.new('L', (vis_size, vis_size), (210))\n                merged.paste(PIL_im, (0, 0))\n                merged.paste(PIL_att, (0, 0), mask)\n                merged = np.array(merged)[:, :, :3]\n            else:\n                one_map = post_pad\n                merged = post_pad\n            row.append(one_map)\n            row.append(middle_pad)\n            #\n            row_merge.append(merged)\n            row_merge.append(middle_pad)\n        row = np.concatenate(row, 1)\n        row_merge = np.concatenate(row_merge, 1)\n        txt = text_map[i * FONT_MAX: (i + 1) * FONT_MAX]\n        if txt.shape[1] 
!= row.shape[1]:\n            print('txt', txt.shape, 'row', row.shape)\n            bUpdate = 0\n            break\n        row = np.concatenate([txt, row, row_merge], 0)\n        img_set.append(row)\n    if bUpdate:\n        img_set = np.concatenate(img_set, 0)\n        img_set = img_set.astype(np.uint8)\n        return img_set, sentences\n    else:\n        return None\n\n\ndef build_super_images2(real_imgs, captions, cap_lens, ixtoword,\n                        attn_maps, att_sze, vis_size=256, topK=5):\n    batch_size = real_imgs.size(0)\n    max_word_num = np.max(cap_lens)\n    text_convas = np.ones([batch_size * FONT_MAX,\n                           max_word_num * (vis_size + 2), 3],\n                           dtype=np.uint8)\n\n    real_imgs = \\\n        nn.Upsample(size=(vis_size, vis_size), mode='bilinear')(real_imgs)\n    # [-1, 1] --> [0, 1]\n    real_imgs.add_(1).div_(2).mul_(255)\n    real_imgs = real_imgs.data.numpy()\n    # b x c x h x w --> b x h x w x c\n    real_imgs = np.transpose(real_imgs, (0, 2, 3, 1))\n    pad_sze = real_imgs.shape\n    middle_pad = np.zeros([pad_sze[2], 2, 3])\n\n    # batch x seq_len x 17 x 17 --> batch x 1 x 17 x 17\n    img_set = []\n    num = len(attn_maps)\n\n    text_map, sentences = \\\n        drawCaption(text_convas, captions, ixtoword, vis_size, off1=0)\n    text_map = np.asarray(text_map).astype(np.uint8)\n\n    bUpdate = 1\n    for i in range(num):\n        attn = attn_maps[i].cpu().view(1, -1, att_sze, att_sze)\n        #\n        attn = attn.view(-1, 1, att_sze, att_sze)\n        attn = attn.repeat(1, 3, 1, 1).data.numpy()\n        # n x c x h x w --> n x h x w x c\n        attn = np.transpose(attn, (0, 2, 3, 1))\n        num_attn = cap_lens[i]\n        thresh = 2./float(num_attn)\n        #\n        img = real_imgs[i]\n        row = []\n        row_merge = []\n        row_txt = []\n        row_beforeNorm = []\n        conf_score = []\n        for j in range(num_attn):\n            one_map = attn[j]\n        
    mask0 = one_map > (2. * thresh)\n            conf_score.append(np.sum(one_map * mask0))\n            mask = one_map > thresh\n            one_map = one_map * mask\n            if (vis_size // att_sze) > 1:\n                one_map = \\\n                    skimage.transform.pyramid_expand(one_map, sigma=20,\n                                                     upscale=vis_size // att_sze)\n            minV = one_map.min()\n            maxV = one_map.max()\n            one_map = (one_map - minV) / (maxV - minV)\n            row_beforeNorm.append(one_map)\n        sorted_indices = np.argsort(conf_score)[::-1]\n\n        for j in range(num_attn):\n            one_map = row_beforeNorm[j]\n            one_map *= 255\n            #\n            PIL_im = Image.fromarray(np.uint8(img))\n            PIL_att = Image.fromarray(np.uint8(one_map))\n            merged = \\\n                Image.new('RGBA', (vis_size, vis_size), (0, 0, 0, 0))\n            mask = Image.new('L', (vis_size, vis_size), (180))  # (210)\n            merged.paste(PIL_im, (0, 0))\n            merged.paste(PIL_att, (0, 0), mask)\n            merged = np.array(merged)[:, :, :3]\n\n            row.append(np.concatenate([one_map, middle_pad], 1))\n            #\n            row_merge.append(np.concatenate([merged, middle_pad], 1))\n            #\n            txt = text_map[i * FONT_MAX:(i + 1) * FONT_MAX,\n                           j * (vis_size + 2):(j + 1) * (vis_size + 2), :]\n            row_txt.append(txt)\n        # reorder\n        row_new = []\n        row_merge_new = []\n        txt_new = []\n        for j in range(num_attn):\n            idx = sorted_indices[j]\n            row_new.append(row[idx])\n            row_merge_new.append(row_merge[idx])\n            txt_new.append(row_txt[idx])\n        row = np.concatenate(row_new[:topK], 1)\n        row_merge = np.concatenate(row_merge_new[:topK], 1)\n        txt = np.concatenate(txt_new[:topK], 1)\n        if txt.shape[1] != row.shape[1]:\n      
      print('Warnings: txt', txt.shape, 'row', row.shape,\n                  'row_merge_new', row_merge_new.shape)\n            bUpdate = 0\n            break\n        row = np.concatenate([txt, row_merge], 0)\n        img_set.append(row)\n    if bUpdate:\n        img_set = np.concatenate(img_set, 0)\n        img_set = img_set.astype(np.uint8)\n        return img_set, sentences\n    else:\n        return None\n\n\ndef weights_init(m):\n    classname = m.__class__.__name__\n    if classname.find('Conv') != -1:\n        nn.init.orthogonal(m.weight.data, 1.0)\n    elif classname.find('BatchNorm') != -1:\n        m.weight.data.normal_(1.0, 0.02)\n        m.bias.data.fill_(0)\n    elif classname.find('Linear') != -1:\n        nn.init.orthogonal(m.weight.data, 1.0)\n        if m.bias is not None:\n            m.bias.data.fill_(0.0)\n\n\ndef load_params(model, new_param):\n    for p, new_p in zip(model.parameters(), new_param):\n        p.data.copy_(new_p)\n\n\ndef copy_G_params(model):\n    flatten = deepcopy(list(p.data for p in model.parameters()))\n    return flatten\n\n\ndef mkdir_p(path):\n    try:\n        os.makedirs(path)\n    except OSError as exc:  # Python >2.5\n        if exc.errno == errno.EEXIST and os.path.isdir(path):\n            pass\n        else:\n            raise\n"
  },
  {
    "path": "model.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.parallel\nfrom torch.autograd import Variable\nfrom torchvision import models\nimport torch.utils.model_zoo as model_zoo\nimport torch.nn.functional as F\nfrom torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence\nfrom cfg.config import cfg\nfrom GLAttention import GLAttentionGeneral as ATT_NET\n\n\nclass GLU(nn.Module):\n    def __init__(self):\n        super(GLU, self).__init__()\n\n    def forward(self, x):\n        nc = x.size(1)\n        assert nc % 2 == 0, 'channels dont divide 2!'\n        nc = int(nc/2)\n        return x[:, :nc] * F.sigmoid(x[:, nc:])\n\n\ndef conv1x1(in_planes, out_planes, bias=False):\n    \"1x1 convolution with padding\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1,\n                     padding=0, bias=bias)\n\ndef conv3x3(in_planes, out_planes):\n    \"3x3 convolution with padding\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1,\n                     padding=1, bias=False)\n\n# Upsale the spatial size by a factor of 2\ndef upBlock(in_planes, out_planes):\n    block = nn.Sequential(\n        nn.Upsample(scale_factor=2, mode='nearest'),\n        conv3x3(in_planes, out_planes * 2),\n        nn.BatchNorm2d(out_planes * 2),\n        GLU())\n    return block\n\n# Keep the spatial size\ndef Block3x3_relu(in_planes, out_planes):\n    block = nn.Sequential(\n        conv3x3(in_planes, out_planes * 2),\n        nn.BatchNorm2d(out_planes * 2),\n        GLU())\n    return block\n\n\nclass ResBlock(nn.Module):\n    def __init__(self, channel_num):\n        super(ResBlock, self).__init__()\n        self.block = nn.Sequential(\n            conv3x3(channel_num, channel_num * 2),\n            nn.BatchNorm2d(channel_num * 2),\n            GLU(),\n            conv3x3(channel_num, channel_num),\n            nn.BatchNorm2d(channel_num))\n\n    def forward(self, x):\n        residual = x\n        out = self.block(x)\n        out 
+= residual\n        return out\n\n\n# ############## Text2Image Encoder-Decoder #######\nclass RNN_ENCODER(nn.Module):\n    def __init__(self, ntoken, ninput=300, drop_prob=0.5,\n                 nhidden=128, nlayers=1, bidirectional=True):\n        super(RNN_ENCODER, self).__init__()\n        self.n_steps = cfg.TEXT.WORDS_NUM\n        self.ntoken = ntoken  # size of the dictionary\n        self.ninput = ninput  # size of each embedding vector\n        self.drop_prob = drop_prob  # probability of an element to be zeroed\n        self.nlayers = nlayers  # Number of recurrent layers\n        self.bidirectional = bidirectional\n        self.rnn_type = cfg.RNN_TYPE\n        if bidirectional:\n            self.num_directions = 2\n        else:\n            self.num_directions = 1\n        # number of features in the hidden state\n        self.nhidden = nhidden // self.num_directions\n\n        self.define_module()\n        self.init_weights()\n\n    def define_module(self):\n        self.encoder = nn.Embedding(self.ntoken, self.ninput)\n        self.drop = nn.Dropout(self.drop_prob)\n        if self.rnn_type == 'LSTM':\n            # dropout: If non-zero, introduces a dropout layer on\n            # the outputs of each RNN layer except the last layer\n            self.rnn = nn.LSTM(self.ninput, self.nhidden,\n                               self.nlayers, batch_first=True,\n                               dropout=self.drop_prob,\n                               bidirectional=self.bidirectional)\n        elif self.rnn_type == 'GRU':\n            self.rnn = nn.GRU(self.ninput, self.nhidden,\n                              self.nlayers, batch_first=True,\n                              dropout=self.drop_prob,\n                              bidirectional=self.bidirectional)\n        else:\n            raise NotImplementedError\n\n    def init_weights(self):\n        initrange = 0.1\n        self.encoder.weight.data.uniform_(-initrange, initrange)\n        # Do not need to 
initialize RNN parameters, which have been initialized\n        # http://pytorch.org/docs/master/_modules/torch/nn/modules/rnn.html#LSTM\n        # self.decoder.weight.data.uniform_(-initrange, initrange)\n        # self.decoder.bias.data.fill_(0)\n\n    def init_hidden(self, bsz):\n        weight = next(self.parameters()).data\n        if self.rnn_type == 'LSTM':\n            return (Variable(weight.new(self.nlayers * self.num_directions,\n                                        bsz, self.nhidden).zero_()),\n                    Variable(weight.new(self.nlayers * self.num_directions,\n                                        bsz, self.nhidden).zero_()))\n        else:\n            return Variable(weight.new(self.nlayers * self.num_directions,\n                                       bsz, self.nhidden).zero_())\n\n    def forward(self, captions, cap_lens, hidden, mask=None):\n        # input: torch.LongTensor of size batch x n_steps\n        # --> emb: batch x n_steps x ninput\n        emb = self.drop(self.encoder(captions))\n        #\n        # Returns: a PackedSequence object\n        cap_lens = cap_lens.data.tolist()\n        emb = pack_padded_sequence(emb, cap_lens, batch_first=True)\n        # #hidden and memory (num_layers * num_directions, batch, hidden_size):\n        # tensor containing the initial hidden state for each element in batch.\n        # #output (batch, seq_len, hidden_size * num_directions)\n        # #or a PackedSequence object:\n        # tensor containing output features (h_t) from the last layer of RNN\n        output, hidden = self.rnn(emb, hidden)\n        # PackedSequence object\n        # --> (batch, seq_len, hidden_size * num_directions)\n        output = pad_packed_sequence(output, batch_first=True)[0]\n        # output = self.drop(output)\n        # --> batch x hidden_size*num_directions x seq_len\n        words_emb = output.transpose(1, 2)\n        # --> batch x num_directions*hidden_size\n        if self.rnn_type == 'LSTM':\n         
   sent_emb = hidden[0].transpose(0, 1).contiguous()\n        else:\n            sent_emb = hidden.transpose(0, 1).contiguous()\n        sent_emb = sent_emb.view(-1, self.nhidden * self.num_directions)\n        return words_emb, sent_emb\n\n\nclass CNN_ENCODER(nn.Module):\n    def __init__(self, nef):\n        super(CNN_ENCODER, self).__init__()\n        if cfg.TRAIN.FLAG:\n            self.nef = nef\n        else:\n            self.nef = 256  # define a uniform ranker\n\n        model = models.inception_v3()\n        url = 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth'\n        model.load_state_dict(model_zoo.load_url(url))\n        for param in model.parameters():\n            param.requires_grad = False\n        print('Load pretrained model from ', url)\n        # print(model)\n\n        self.define_module(model)\n        self.init_trainable_weights()\n\n    def define_module(self, model):\n        self.Conv2d_1a_3x3 = model.Conv2d_1a_3x3\n        self.Conv2d_2a_3x3 = model.Conv2d_2a_3x3\n        self.Conv2d_2b_3x3 = model.Conv2d_2b_3x3\n        self.Conv2d_3b_1x1 = model.Conv2d_3b_1x1\n        self.Conv2d_4a_3x3 = model.Conv2d_4a_3x3\n        self.Mixed_5b = model.Mixed_5b\n        self.Mixed_5c = model.Mixed_5c\n        self.Mixed_5d = model.Mixed_5d\n        self.Mixed_6a = model.Mixed_6a\n        self.Mixed_6b = model.Mixed_6b\n        self.Mixed_6c = model.Mixed_6c\n        self.Mixed_6d = model.Mixed_6d\n        self.Mixed_6e = model.Mixed_6e\n        self.Mixed_7a = model.Mixed_7a\n        self.Mixed_7b = model.Mixed_7b\n        self.Mixed_7c = model.Mixed_7c\n\n        self.emb_features = conv1x1(768, self.nef)\n        self.emb_cnn_code = nn.Linear(2048, self.nef)\n\n    def init_trainable_weights(self):\n        initrange = 0.1\n        self.emb_features.weight.data.uniform_(-initrange, initrange)\n        self.emb_cnn_code.weight.data.uniform_(-initrange, initrange)\n\n    def forward(self, x):\n        features = None\n        
# --> fixed-size input: batch x 3 x 299 x 299\n        x = nn.Upsample(size=(299, 299), mode='bilinear')(x)\n        # 299 x 299 x 3\n        x = self.Conv2d_1a_3x3(x)\n        # 149 x 149 x 32\n        x = self.Conv2d_2a_3x3(x)\n        # 147 x 147 x 32\n        x = self.Conv2d_2b_3x3(x)\n        # 147 x 147 x 64\n        x = F.max_pool2d(x, kernel_size=3, stride=2)\n        # 73 x 73 x 64\n        x = self.Conv2d_3b_1x1(x)\n        # 73 x 73 x 80\n        x = self.Conv2d_4a_3x3(x)\n        # 71 x 71 x 192\n\n        x = F.max_pool2d(x, kernel_size=3, stride=2)\n        # 35 x 35 x 192\n        x = self.Mixed_5b(x)\n        # 35 x 35 x 256\n        x = self.Mixed_5c(x)\n        # 35 x 35 x 288\n        x = self.Mixed_5d(x)\n        # 35 x 35 x 288\n\n        x = self.Mixed_6a(x)\n        # 17 x 17 x 768\n        x = self.Mixed_6b(x)\n        # 17 x 17 x 768\n        x = self.Mixed_6c(x)\n        # 17 x 17 x 768\n        x = self.Mixed_6d(x)\n        # 17 x 17 x 768\n        x = self.Mixed_6e(x)\n        # 17 x 17 x 768\n\n        # image region features\n        features = x\n        # 17 x 17 x 768\n\n        x = self.Mixed_7a(x)\n        # 8 x 8 x 1280\n        x = self.Mixed_7b(x)\n        # 8 x 8 x 2048\n        x = self.Mixed_7c(x)\n        # 8 x 8 x 2048\n        x = F.avg_pool2d(x, kernel_size=8)\n        # 1 x 1 x 2048\n        # x = F.dropout(x, training=self.training)\n        # 1 x 1 x 2048\n        x = x.view(x.size(0), -1)\n        # 2048\n\n        # global image features\n        cnn_code = self.emb_cnn_code(x)\n        # 512\n        if features is not None:\n            features = self.emb_features(features)\n        return features, cnn_code\n\n\n# ############## G networks ###################\nclass CA_NET(nn.Module):\n    # some code is modified from vae examples\n    # (https://github.com/pytorch/examples/blob/master/vae/main.py)\n    def __init__(self):\n        super(CA_NET, self).__init__()\n        self.t_dim = cfg.TEXT.EMBEDDING_DIM\n     
   self.c_dim = cfg.GAN.CONDITION_DIM\n        self.fc = nn.Linear(self.t_dim, self.c_dim * 4, bias=True)\n        self.relu = GLU()\n\n    def encode(self, text_embedding):\n        x = self.relu(self.fc(text_embedding))\n        mu = x[:, :self.c_dim]\n        logvar = x[:, self.c_dim:]\n        return mu, logvar\n\n    def reparametrize(self, mu, logvar):\n        std = logvar.mul(0.5).exp_()\n        if cfg.CUDA:\n            eps = torch.cuda.FloatTensor(std.size()).normal_()\n        else:\n            eps = torch.FloatTensor(std.size()).normal_()\n        eps = Variable(eps)\n        return eps.mul(std).add_(mu)\n\n    def forward(self, text_embedding):\n        mu, logvar = self.encode(text_embedding)\n        c_code = self.reparametrize(mu, logvar)\n        return c_code, mu, logvar\n\n\nclass INIT_STAGE_G(nn.Module):\n    def __init__(self, ngf, ncf):\n        super(INIT_STAGE_G, self).__init__()\n        self.gf_dim = ngf\n        self.in_dim = cfg.GAN.Z_DIM + ncf  # cfg.TEXT.EMBEDDING_DIM\n\n        self.define_module()\n\n    def define_module(self):\n        nz, ngf = self.in_dim, self.gf_dim\n        self.fc = nn.Sequential(\n            nn.Linear(nz, ngf * 4 * 4 * 2, bias=False),\n            nn.BatchNorm1d(ngf * 4 * 4 * 2),\n            GLU())\n\n        self.upsample1 = upBlock(ngf, ngf // 2)\n        self.upsample2 = upBlock(ngf // 2, ngf // 4)\n        self.upsample3 = upBlock(ngf // 4, ngf // 8)\n        self.upsample4 = upBlock(ngf // 8, ngf // 16)\n\n    def forward(self, z_code, c_code):\n        \"\"\"\n        :param z_code: batch x cfg.GAN.Z_DIM\n        :param c_code: batch x cfg.TEXT.EMBEDDING_DIM\n        :return: batch x ngf/16 x 64 x 64\n        \"\"\"\n        c_z_code = torch.cat((c_code, z_code), 1)\n        # state size ngf x 4 x 4\n        out_code = self.fc(c_z_code)\n        out_code = out_code.view(-1, self.gf_dim, 4, 4)\n        # state size ngf/3 x 8 x 8\n        out_code = self.upsample1(out_code)\n        # state size 
ngf/4 x 16 x 16\n        out_code = self.upsample2(out_code)\n        # state size ngf/8 x 32 x 32\n        out_code32 = self.upsample3(out_code)\n        # state size ngf/16 x 64 x 64\n        out_code64 = self.upsample4(out_code32)\n\n        return out_code64\n\n# class NEXT_STAGE_G(nn.Module):\n#     def __init__(self, ngf, nef, ncf):\n#         super(NEXT_STAGE_G, self).__init__()\n#         self.gf_dim = ngf\n#         self.ef_dim = nef\n#         self.cf_dim = ncf\n#         self.num_residual = cfg.GAN.R_NUM\n#         self.define_module()\n#\n#     def _make_layer(self, block, channel_num):\n#         layers = []\n#         for i in range(cfg.GAN.R_NUM):\n#             layers.append(block(channel_num))\n#         return nn.Sequential(*layers)\n#\n#     def define_module(self):\n#         ngf = self.gf_dim\n#         self.att = ATT_NET(ngf, self.ef_dim)\n#         self.residual = self._make_layer(ResBlock, ngf * 2)\n#         self.upsample = upBlock(ngf * 2, ngf)\n#\n#     def forward(self, h_code, c_code, word_embs, mask):\n#         \"\"\"\n#             h_code1(query):  batch x idf x ih x iw (queryL=ihxiw)\n#             word_embs(context): batch x cdf x sourceL (sourceL=seq_len)\n#             c_code1: batch x idf x queryL\n#             att1: batch x sourceL x queryL\n#         \"\"\"\n#         self.att.applyMask(mask)\n#         c_code, att = self.att(h_code, word_embs)\n#         h_c_code = torch.cat((h_code, c_code), 1)\n#         print('h_c_code:', h_c_code.size()) \\\n#             ('h_c_code:', (16, 64, 64, 64))\n#             ('h_c_code:', (16, 64, 128, 128))\n#         out_code = self.residual(h_c_code)\n#\n#         # state size ngf/2 x 2in_size x 2in_size\n#         out_code = self.upsample(out_code)\n#\n#         return out_code, att\n\n\nclass NEXT_STAGE_G(nn.Module):\n    def __init__(self, ngf, nef, ncf):\n        super(NEXT_STAGE_G, self).__init__()\n        self.gf_dim = ngf\n        self.ef_dim = nef\n        self.cf_dim = ncf\n        
# print(ngf, nef, ncf)  (32, 256, 100)\n        # (32, 256, 100)\n        self.num_residual = cfg.GAN.R_NUM\n        self.define_module()\n        self.conv = conv1x1(ngf * 3, ngf * 2)\n\n    def _make_layer(self, block, channel_num):\n        layers = []\n        for i in range(cfg.GAN.R_NUM): # 2\n            layers.append(block(channel_num))\n        return nn.Sequential(*layers)\n\n    def define_module(self):\n        ngf = self.gf_dim\n        self.att = ATT_NET(ngf, self.ef_dim)\n        self.residual = self._make_layer(ResBlock, ngf * 2)\n        self.upsample = upBlock(ngf * 2, ngf)\n\n    def forward(self, h_code, c_code, word_embs, mask):\n        \"\"\"\n            h_code1(query):  batch x idf x ih x iw (queryL=ihxiw)\n            word_embs(context): batch x cdf x sourceL (sourceL=seq_len)\n            c_code1: batch x idf x queryL\n            att1: batch x sourceL x queryL\n        \"\"\"\n        # print('========')\n        # ((16, 32, 64, 64), (16, 100), (16, 256, 18), (16, 18))\n        # print(h_code.size(), c_code.size(), word_embs.size(), mask.size())\n        self.att.applyMask(mask)\n        # here, a new c_code is generated by self.att() method.\n        # weightedContext, weightedSentence, word_attn, sent_vs_att\n        c_code, weightedSentence, att, sent_att = self.att(h_code, c_code, word_embs)\n        # Then, image feature are concated with a new c_code, they become h_c_code,\n        # so, here I can make some change, to concate more items together.\n        # which means I need to get more output from line 369, self.att()\n        # also, I need to feed more information to calculate the function, and let's see what the new idea will return.\n        h_c_code = torch.cat((h_code, c_code), 1)\n        # print('h_c_code.size:', h_c_code.size())  # ('h_c_code.size:', (16, 64, 64, 64))\n        h_c_sent_code = torch.cat((h_c_code, weightedSentence), 1)\n        # print('h_c_sent_code.size:', h_c_sent_code.size())\n        # 
('h_c_code.size:', (16, 64, 64, 64))\n        # ('h_c_sent_code.size:', (16, 96, 64, 64))\n        h_c_sent_code = self.conv(h_c_sent_code)\n        out_code = self.residual(h_c_sent_code)\n        # print('out_code:', out_code.size())\n        # state size ngf/2 x 2in_size x 2in_size\n        out_code = self.upsample(out_code)\n        return out_code, att\n\n\nclass GET_IMAGE_G(nn.Module):\n    def __init__(self, ngf):\n        super(GET_IMAGE_G, self).__init__()\n        self.gf_dim = ngf\n        self.img = nn.Sequential(\n            conv3x3(ngf, 3),\n            nn.Tanh()\n        )\n\n    def forward(self, h_code):\n        out_img = self.img(h_code)\n        return out_img\n\n#G_NET used in the paper\nclass G_NET(nn.Module):\n    def __init__(self):\n        super(G_NET, self).__init__()\n        ngf = cfg.GAN.GF_DIM\n        nef = cfg.TEXT.EMBEDDING_DIM\n        ncf = cfg.GAN.CONDITION_DIM\n        self.ca_net = CA_NET()\n\n        if cfg.TREE.BRANCH_NUM > 0:\n            self.h_net1 = INIT_STAGE_G(ngf * 16, ncf)\n            self.img_net1 = GET_IMAGE_G(ngf)\n        # gf x 64 x 64\n        if cfg.TREE.BRANCH_NUM > 1:\n            self.h_net2 = NEXT_STAGE_G(ngf, nef, ncf)\n            self.img_net2 = GET_IMAGE_G(ngf)\n        if cfg.TREE.BRANCH_NUM > 2:\n            self.h_net3 = NEXT_STAGE_G(ngf, nef, ncf)\n            self.img_net3 = GET_IMAGE_G(ngf)\n    # netG(noise, sent_emb, words_embs, mask)\n    def forward(self, z_code, sent_emb, word_embs, mask):\n        \"\"\"\n            :param z_code: batch x cfg.GAN.Z_DIM\n            :param sent_emb: batch x cfg.TEXT.EMBEDDING_DIM\n            :param word_embs: batch x cdf x seq_len\n            :param mask: batch x seq_len\n            :return:\n        \"\"\"\n        fake_imgs = []\n        att_maps = []\n        '''this is the Conditioning Augmentation'''\n        # print('sent_emb:', sent_emb.size())  #('sent_emb:', (16, 256))\n        c_code, mu, logvar = self.ca_net(sent_emb)\n        # 
print('=====')\n        # print('first c_code.size():', c_code.size())  #(16, 100)\n        # print('=====')\n        if cfg.TREE.BRANCH_NUM > 0:\n            h_code1 = self.h_net1(z_code, c_code)\n            fake_img1 = self.img_net1(h_code1)\n            fake_imgs.append(fake_img1)\n        if cfg.TREE.BRANCH_NUM > 1:\n            h_code2, att1 = \\\n                self.h_net2(h_code1, c_code, word_embs, mask)\n            fake_img2 = self.img_net2(h_code2)\n            fake_imgs.append(fake_img2)\n            if att1 is not None:\n                att_maps.append(att1)\n        if cfg.TREE.BRANCH_NUM > 2:\n            h_code3, att2 = \\\n                self.h_net3(h_code2, c_code, word_embs, mask)\n            fake_img3 = self.img_net3(h_code3)\n            fake_imgs.append(fake_img3)\n            if att2 is not None:\n                att_maps.append(att2)\n\n        return fake_imgs, att_maps, mu, logvar\n\n\n\nclass G_DCGAN(nn.Module):\n    def __init__(self):\n        super(G_DCGAN, self).__init__()\n        ngf = cfg.GAN.GF_DIM\n        nef = cfg.TEXT.EMBEDDING_DIM\n        ncf = cfg.GAN.CONDITION_DIM\n        self.ca_net = CA_NET()\n\n        # 16gf x 64 x 64 --> gf x 64 x 64 --> 3 x 64 x 64\n        if cfg.TREE.BRANCH_NUM > 0:\n            self.h_net1 = INIT_STAGE_G(ngf * 16, ncf)\n        # gf x 64 x 64\n        if cfg.TREE.BRANCH_NUM > 1:\n            self.h_net2 = NEXT_STAGE_G(ngf, nef, ncf)\n        if cfg.TREE.BRANCH_NUM > 2:\n            self.h_net3 = NEXT_STAGE_G(ngf, nef, ncf)\n        self.img_net = GET_IMAGE_G(ngf)\n\n    def forward(self, z_code, sent_emb, word_embs, mask):\n        \"\"\"\n            :param z_code: batch x cfg.GAN.Z_DIM\n            :param sent_emb: batch x cfg.TEXT.EMBEDDING_DIM\n            :param word_embs: batch x cdf x seq_len\n            :param mask: batch x seq_len\n            :return:\n        \"\"\"\n        att_maps = []\n        c_code, mu, logvar = self.ca_net(sent_emb)\n        if cfg.TREE.BRANCH_NUM > 0:\n    
        h_code = self.h_net1(z_code, c_code)\n        if cfg.TREE.BRANCH_NUM > 1:\n            h_code, att1 = self.h_net2(h_code, c_code, word_embs, mask)\n            if att1 is not None:\n                att_maps.append(att1)\n        if cfg.TREE.BRANCH_NUM > 2:\n            h_code, att2 = self.h_net3(h_code, c_code, word_embs, mask)\n            if att2 is not None:\n                att_maps.append(att2)\n\n        fake_imgs = self.img_net(h_code)\n        return [fake_imgs], att_maps, mu, logvar\n\n\n# ############## D networks ##########################\ndef Block3x3_leakRelu(in_planes, out_planes):\n    block = nn.Sequential(\n        conv3x3(in_planes, out_planes),\n        nn.BatchNorm2d(out_planes),\n        nn.LeakyReLU(0.2, inplace=True)\n    )\n    return block\n\n\n# Downsale the spatial size by a factor of 2\ndef downBlock(in_planes, out_planes):\n    block = nn.Sequential(\n        nn.Conv2d(in_planes, out_planes, 4, 2, 1, bias=False),\n        nn.BatchNorm2d(out_planes),\n        nn.LeakyReLU(0.2, inplace=True)\n    )\n    return block\n\n\n# Downsale the spatial size by a factor of 16\ndef encode_image_by_16times(ndf):\n    encode_img = nn.Sequential(\n        # --> state size. 
ndf x in_size/2 x in_size/2\n        nn.Conv2d(3, ndf, 4, 2, 1, bias=False),\n        nn.LeakyReLU(0.2, inplace=True),\n        # --> state size 2ndf x x in_size/4 x in_size/4\n        nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),\n        nn.BatchNorm2d(ndf * 2),\n        nn.LeakyReLU(0.2, inplace=True),\n        # --> state size 4ndf x in_size/8 x in_size/8\n        nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),\n        nn.BatchNorm2d(ndf * 4),\n        nn.LeakyReLU(0.2, inplace=True),\n        # --> state size 8ndf x in_size/16 x in_size/16\n        nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),\n        nn.BatchNorm2d(ndf * 8),\n        nn.LeakyReLU(0.2, inplace=True)\n    )\n    return encode_img\n\n\nclass D_GET_LOGITS(nn.Module):\n    def __init__(self, ndf, nef, bcondition=False):\n        super(D_GET_LOGITS, self).__init__()\n        self.df_dim = ndf\n        self.ef_dim = nef\n        self.bcondition = bcondition\n        if self.bcondition:\n            self.jointConv = Block3x3_leakRelu(ndf * 8 + nef, ndf * 8)\n\n        self.outlogits = nn.Sequential(\n            nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=4),\n            nn.Sigmoid())\n\n    def forward(self, h_code, c_code=None):\n        if self.bcondition and c_code is not None:\n            # conditioning output\n            c_code = c_code.view(-1, self.ef_dim, 1, 1)\n            c_code = c_code.repeat(1, 1, 4, 4)\n            # state size (ngf+egf) x 4 x 4\n            h_c_code = torch.cat((h_code, c_code), 1)\n            # state size ngf x in_size x in_size\n            h_c_code = self.jointConv(h_c_code)\n        else:\n            h_c_code = h_code\n        output = self.outlogits(h_c_code)\n        return output.view(-1)\n\n\n# For 64 x 64 images\nclass D_NET64(nn.Module):\n    def __init__(self, b_jcu=True):\n        super(D_NET64, self).__init__()\n        ndf = cfg.GAN.DF_DIM\n        nef = cfg.TEXT.EMBEDDING_DIM\n        self.img_code_s16 = encode_image_by_16times(ndf)\n       
 if b_jcu:\n            self.UNCOND_DNET = D_GET_LOGITS(ndf, nef, bcondition=False)\n        else:\n            self.UNCOND_DNET = None\n        self.COND_DNET = D_GET_LOGITS(ndf, nef, bcondition=True)\n\n    def forward(self, x_var):\n        x_code4 = self.img_code_s16(x_var)  # 4 x 4 x 8df\n        return x_code4\n\n\n# For 128 x 128 images\nclass D_NET128(nn.Module):\n    def __init__(self, b_jcu=True):\n        super(D_NET128, self).__init__()\n        ndf = cfg.GAN.DF_DIM\n        nef = cfg.TEXT.EMBEDDING_DIM\n        self.img_code_s16 = encode_image_by_16times(ndf)\n        self.img_code_s32 = downBlock(ndf * 8, ndf * 16)\n        self.img_code_s32_1 = Block3x3_leakRelu(ndf * 16, ndf * 8)\n        #\n        if b_jcu:\n            self.UNCOND_DNET = D_GET_LOGITS(ndf, nef, bcondition=False)\n        else:\n            self.UNCOND_DNET = None\n        self.COND_DNET = D_GET_LOGITS(ndf, nef, bcondition=True)\n\n    def forward(self, x_var):\n        x_code8 = self.img_code_s16(x_var)   # 8 x 8 x 8df\n        x_code4 = self.img_code_s32(x_code8)   # 4 x 4 x 16df\n        x_code4 = self.img_code_s32_1(x_code4)  # 4 x 4 x 8df\n        return x_code4\n\n\n# For 256 x 256 images\nclass D_NET256(nn.Module):\n    def __init__(self, b_jcu=True):\n        super(D_NET256, self).__init__()\n        ndf = cfg.GAN.DF_DIM\n        nef = cfg.TEXT.EMBEDDING_DIM\n        self.img_code_s16 = encode_image_by_16times(ndf)\n        self.img_code_s32 = downBlock(ndf * 8, ndf * 16)\n        self.img_code_s64 = downBlock(ndf * 16, ndf * 32)\n        self.img_code_s64_1 = Block3x3_leakRelu(ndf * 32, ndf * 16)\n        self.img_code_s64_2 = Block3x3_leakRelu(ndf * 16, ndf * 8)\n        if b_jcu:\n            self.UNCOND_DNET = D_GET_LOGITS(ndf, nef, bcondition=False)\n        else:\n            self.UNCOND_DNET = None\n        self.COND_DNET = D_GET_LOGITS(ndf, nef, bcondition=True)\n\n    def forward(self, x_var):\n        x_code16 = self.img_code_s16(x_var)\n        x_code8 = 
self.img_code_s32(x_code16)\n        x_code4 = self.img_code_s64(x_code8)\n        x_code4 = self.img_code_s64_1(x_code4)\n        x_code4 = self.img_code_s64_2(x_code4)\n        return x_code4\nclass CAPTION_CNN(nn.Module):\n    def __init__(self, embed_size):\n        \"\"\"Load the pretrained ResNet-152 and replace top fc layer.\"\"\"\n        super(CAPTION_CNN, self).__init__()\n        resnet = models.resnet152(pretrained=True)\n        modules = list(resnet.children())[:-1]  # delete the last fc layer.\n        self.resnet = nn.Sequential(*modules)\n        for param in self.resnet.parameters():\n            param.requires_grad = False\n        self.linear = nn.Linear(resnet.fc.in_features, embed_size)\n        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)\n\n    def forward(self, images):\n        \"\"\"Extract feature vectors from input images.\"\"\"\n        #print ('image feature size before unsample:', images.size())\n        m = nn.Upsample(size=(224, 224), mode='bilinear')\n        unsampled_images = m(images)\n        #print ('image feature size after unsample:', unsampled_images.size())\n        features = self.resnet(unsampled_images)\n        features = features.view(features.size(0), -1)\n        features = self.bn(self.linear(features))\n        return features\n\nclass CAPTION_RNN(nn.Module):\n    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, max_seq_length=20):\n        \"\"\"Set the hyper-parameters and build the layers.\"\"\"\n        super(CAPTION_RNN, self).__init__()\n        self.embed = nn.Embedding(vocab_size, embed_size)\n        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)\n        self.linear = nn.Linear(hidden_size, vocab_size)\n        self.max_seg_length = max_seq_length\n\n    # def forward(self, features, captions, cap_lens):\n    #     \"\"\"Decode image feature vectors and generates captions.\"\"\"\n    #     # print ('feature.size():', features.size()) #(6L, 256L)\n    
#     # print ('captions.size():', captions.size()) # (6L, 12L)\n    #     # print ('embeddings.size:',embeddings.size()) #(6L, 12L, 256L)\n    #     embeddings = self.embed(captions)\n    #     embeddings = torch.cat((features.unsqueeze(1), embeddings), 1)\n    #     packed = pack_padded_sequence(embeddings, cap_lens.data.tolist(), batch_first=True)\n    #     outputs, hidden = self.lstm(packed)\n    #     output = self.linear(outputs[0])   # (batch size, vocab_size)\n    #     return output, hidden, outputs     # words embedding, sentence embedding\n\n    def forward(self, features, captions, cap_lens):\n        \"\"\"Decode image feature vectors and generates captions.\"\"\"\n        embeddings = self.embed(captions)\n        embeddings = torch.cat((features.unsqueeze(1), embeddings), 1)\n        packed = pack_padded_sequence(embeddings, cap_lens, batch_first=True)\n        hiddens, _ = self.lstm(packed)\n        outputs = self.linear(hiddens[0])\n        return outputs\n\n    def sample(self, features, states=None):\n        \"\"\"Generate captions for given image features using greedy search.\"\"\"\n        sampled_ids = []\n        inputs = features.unsqueeze(1)\n        for i in range(self.max_seg_length):\n            hiddens, states = self.lstm(inputs, states)  # hiddens: (batch_size, 1, hidden_size)\n            outputs = self.linear(hiddens.squeeze(1))  # outputs:  (batch_size, vocab_size)\n            _, predicted = outputs.max(1)  # predicted: (batch_size)\n            sampled_ids.append(predicted)\n            inputs = self.embed(predicted)  # inputs: (batch_size, embed_size)\n            inputs = inputs.unsqueeze(1)  # inputs: (batch_size, 1, embed_size)\n        sampled_ids = torch.stack(sampled_ids, 1)  # sampled_ids: (batch_size, max_seq_length)\n        return sampled_ids"
  },
  {
    "path": "pretrain_DAMSM.py",
    "content": "from __future__ import print_function\n\nfrom miscc.utils import mkdir_p\nfrom miscc.utils import build_super_images\nfrom miscc.losses import sent_loss, words_loss\nfrom cfg.config import cfg, cfg_from_file\n\nfrom datasets import TextDataset\nfrom datasets import prepare_data\n\nfrom model import RNN_ENCODER, CNN_ENCODER\n\nimport os\nimport sys\nimport time\nimport random\nimport pprint\nimport datetime\nimport dateutil.tz\nimport argparse\nimport numpy as np\nfrom PIL import Image\n\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom torch.autograd import Variable\nimport torch.backends.cudnn as cudnn\nimport torchvision.transforms as transforms\n\n\ndir_path = (os.path.abspath(os.path.join(os.path.realpath(__file__), './.')))\nsys.path.append(dir_path)\n\n\nUPDATE_INTERVAL = 200\ndef parse_args():\n    parser = argparse.ArgumentParser(description='Train a DAMSM network')\n    parser.add_argument('--cfg', dest='cfg_file',\n                        help='optional config file',\n                        default='cfg/DAMSM/bird.yml', type=str)\n    parser.add_argument('--gpu', dest='gpu_id', type=int, default=0)\n    parser.add_argument('--data_dir', dest='data_dir', type=str, default='')\n    parser.add_argument('--manualSeed', type=int, help='manual seed')\n    args = parser.parse_args()\n    return args\n\n\ndef train(dataloader, cnn_model, rnn_model, batch_size,\n          labels, optimizer, epoch, ixtoword, image_dir):\n    cnn_model.train()\n    rnn_model.train()\n    s_total_loss0 = 0\n    s_total_loss1 = 0\n    w_total_loss0 = 0\n    w_total_loss1 = 0\n    count = (epoch + 1) * len(dataloader)\n    start_time = time.time()\n    for step, data in enumerate(dataloader, 0):\n        # print('step', step)\n        rnn_model.zero_grad()\n        cnn_model.zero_grad()\n\n        imgs, captions, cap_lens, \\\n            class_ids, keys = prepare_data(data)\n\n\n        # words_features: batch_size x nef x 17 x 17\n        # 
sent_code: batch_size x nef\n        words_features, sent_code = cnn_model(imgs[-1])\n        # --> batch_size x nef x 17*17\n        nef, att_sze = words_features.size(1), words_features.size(2)\n        # words_features = words_features.view(batch_size, nef, -1)\n\n        hidden = rnn_model.init_hidden(batch_size)\n        # words_emb: batch_size x nef x seq_len\n        # sent_emb: batch_size x nef\n        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)\n\n        w_loss0, w_loss1, attn_maps = words_loss(words_features, words_emb, labels,\n                                                 cap_lens, class_ids, batch_size)\n        w_total_loss0 += w_loss0.data\n        w_total_loss1 += w_loss1.data\n        loss = w_loss0 + w_loss1\n\n        s_loss0, s_loss1 = \\\n            sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)\n        loss += s_loss0 + s_loss1\n        s_total_loss0 += s_loss0.data\n        s_total_loss1 += s_loss1.data\n        #\n        loss.backward()\n        #\n        # `clip_grad_norm` helps prevent\n        # the exploding gradient problem in RNNs / LSTMs.\n        torch.nn.utils.clip_grad_norm(rnn_model.parameters(),\n                                      cfg.TRAIN.RNN_GRAD_CLIP)\n        optimizer.step()\n\n        if step % UPDATE_INTERVAL == 0:\n            count = epoch * len(dataloader) + step\n\n            s_cur_loss0 = s_total_loss0[0] / UPDATE_INTERVAL\n            s_cur_loss1 = s_total_loss1[0] / UPDATE_INTERVAL\n\n            w_cur_loss0 = w_total_loss0[0] / UPDATE_INTERVAL\n            w_cur_loss1 = w_total_loss1[0] / UPDATE_INTERVAL\n\n            elapsed = time.time() - start_time\n            print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '\n                  's_loss {:5.2f} {:5.2f} | '\n                  'w_loss {:5.2f} {:5.2f}'\n                  .format(epoch, step, len(dataloader),\n                          elapsed * 1000. 
/ UPDATE_INTERVAL,\n                          s_cur_loss0, s_cur_loss1,\n                          w_cur_loss0, w_cur_loss1))\n            s_total_loss0 = 0\n            s_total_loss1 = 0\n            w_total_loss0 = 0\n            w_total_loss1 = 0\n            start_time = time.time()\n            # attention Maps\n            img_set, _ = \\\n                build_super_images(imgs[-1].cpu(), captions,\n                                   ixtoword, attn_maps, att_sze)\n            if img_set is not None:\n                im = Image.fromarray(img_set)\n                fullpath = '%s/attention_maps%d.png' % (image_dir, step)\n                im.save(fullpath)\n    return count\n\n\ndef evaluate(dataloader, cnn_model, rnn_model, batch_size):\n    cnn_model.eval()\n    rnn_model.eval()\n    s_total_loss = 0\n    w_total_loss = 0\n    for step, data in enumerate(dataloader, 0):\n        real_imgs, captions, cap_lens, \\\n                class_ids, keys = prepare_data(data)\n\n        words_features, sent_code = cnn_model(real_imgs[-1])\n        # nef = words_features.size(1)\n        # words_features = words_features.view(batch_size, nef, -1)\n\n        hidden = rnn_model.init_hidden(batch_size)\n        words_emb, sent_emb = rnn_model(captions, cap_lens, hidden)\n\n        w_loss0, w_loss1, attn = words_loss(words_features, words_emb, labels,\n                                            cap_lens, class_ids, batch_size)\n        w_total_loss += (w_loss0 + w_loss1).data\n\n        s_loss0, s_loss1 = \\\n            sent_loss(sent_code, sent_emb, labels, class_ids, batch_size)\n        s_total_loss += (s_loss0 + s_loss1).data\n\n        if step == 50:\n            break\n\n    s_cur_loss = s_total_loss[0] / step\n    w_cur_loss = w_total_loss[0] / step\n\n    return s_cur_loss, w_cur_loss\n\n\ndef build_models():\n    # build model ############################################################\n    text_encoder = RNN_ENCODER(dataset.n_words, 
nhidden=cfg.TEXT.EMBEDDING_DIM)\n    image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)\n    labels = Variable(torch.LongTensor(range(batch_size)))\n    start_epoch = 0\n    if cfg.TRAIN.NET_E != '':\n        state_dict = torch.load(cfg.TRAIN.NET_E)\n        text_encoder.load_state_dict(state_dict)\n        print('Load ', cfg.TRAIN.NET_E)\n        #\n        name = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')\n        state_dict = torch.load(name)\n        image_encoder.load_state_dict(state_dict)\n        print('Load ', name)\n\n        istart = cfg.TRAIN.NET_E.rfind('_') + 8\n        iend = cfg.TRAIN.NET_E.rfind('.')\n        start_epoch = cfg.TRAIN.NET_E[istart:iend]\n        start_epoch = int(start_epoch) + 1\n        print('start_epoch', start_epoch)\n    if cfg.CUDA:\n        text_encoder = text_encoder.cuda()\n        image_encoder = image_encoder.cuda()\n        labels = labels.cuda()\n\n    return text_encoder, image_encoder, labels, start_epoch\n\n\nif __name__ == \"__main__\":\n    args = parse_args()\n    if args.cfg_file is not None:\n        cfg_from_file(args.cfg_file)\n\n    if args.gpu_id == -1:\n        cfg.CUDA = False\n    else:\n        cfg.GPU_ID = args.gpu_id\n\n    if args.data_dir != '':\n        cfg.DATA_DIR = args.data_dir\n    print('Using config:')\n    pprint.pprint(cfg)\n\n    if not cfg.TRAIN.FLAG:\n        args.manualSeed = 100\n    elif args.manualSeed is None:\n        args.manualSeed = random.randint(1, 10000)\n    random.seed(args.manualSeed)\n    np.random.seed(args.manualSeed)\n    torch.manual_seed(args.manualSeed)\n    if cfg.CUDA:\n        torch.cuda.manual_seed_all(args.manualSeed)\n\n    ##########################################################################\n    now = datetime.datetime.now(dateutil.tz.tzlocal())\n    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')\n    output_dir = '../output/%s_%s_%s' % \\\n        (cfg.DATASET_NAME, cfg.CONFIG_NAME, timestamp)\n\n    model_dir = os.path.join(output_dir, 
'Model')\n    image_dir = os.path.join(output_dir, 'Image')\n    mkdir_p(model_dir)\n    mkdir_p(image_dir)\n\n    torch.cuda.set_device(cfg.GPU_ID)\n    cudnn.benchmark = True\n\n    # Get data loader ##################################################\n    imsize = cfg.TREE.BASE_SIZE * (2 ** (cfg.TREE.BRANCH_NUM-1))\n    batch_size = cfg.TRAIN.BATCH_SIZE\n    image_transform = transforms.Compose([\n        transforms.Scale(int(imsize * 76 / 64)),\n        transforms.RandomCrop(imsize),\n        transforms.RandomHorizontalFlip()])\n    dataset = TextDataset(cfg.DATA_DIR, 'train',\n                          base_size=cfg.TREE.BASE_SIZE,\n                          transform=image_transform)\n\n    print(dataset.n_words, dataset.embeddings_num)\n    assert dataset\n    dataloader = torch.utils.data.DataLoader(\n        dataset, batch_size=batch_size, drop_last=True,\n        shuffle=True, num_workers=int(cfg.WORKERS))\n\n    # # validation data #\n    dataset_val = TextDataset(cfg.DATA_DIR, 'test',\n                              base_size=cfg.TREE.BASE_SIZE,\n                              transform=image_transform)\n    dataloader_val = torch.utils.data.DataLoader(\n        dataset_val, batch_size=batch_size, drop_last=True,\n        shuffle=True, num_workers=int(cfg.WORKERS))\n\n    # Train ##############################################################\n    text_encoder, image_encoder, labels, start_epoch = build_models()\n    para = list(text_encoder.parameters())\n    for v in image_encoder.parameters():\n        if v.requires_grad:\n            para.append(v)\n    # optimizer = optim.Adam(para, lr=cfg.TRAIN.ENCODER_LR, betas=(0.5, 0.999))\n    # At any point you can hit Ctrl + C to break out of training early.\n    try:\n        lr = cfg.TRAIN.ENCODER_LR\n        for epoch in range(start_epoch, cfg.TRAIN.MAX_EPOCH):\n            optimizer = optim.Adam(para, lr=lr, betas=(0.5, 0.999))\n            epoch_start_time = time.time()\n            count = 
train(dataloader, image_encoder, text_encoder,\n                          batch_size, labels, optimizer, epoch,\n                          dataset.ixtoword, image_dir)\n            print('-' * 89)\n            if len(dataloader_val) > 0:\n                s_loss, w_loss = evaluate(dataloader_val, image_encoder,\n                                          text_encoder, batch_size)\n                print('| end epoch {:3d} | valid loss '\n                      '{:5.2f} {:5.2f} | lr {:.5f}|'\n                      .format(epoch, s_loss, w_loss, lr))\n            print('-' * 89)\n            if lr > cfg.TRAIN.ENCODER_LR/10.:\n                lr *= 0.98\n\n            if (epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0 or\n                epoch == cfg.TRAIN.MAX_EPOCH):\n                torch.save(image_encoder.state_dict(),\n                           '%s/image_encoder%d.pth' % (model_dir, epoch))\n                torch.save(text_encoder.state_dict(),\n                           '%s/text_encoder%d.pth' % (model_dir, epoch))\n                print('Save G/Ds models.')\n    except KeyboardInterrupt:\n        print('-' * 89)\n        print('Exiting from training early')\n"
  },
  {
    "path": "test.py",
    "content": "import torch.nn as nn\nimport torch\nfrom torch.autograd import Variable\n\ndef conv1x1(in_planes, out_planes):\n    \"1x1 convolution with padding\"\n    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1,\n                     padding=0, bias=False)\nx = Variable(torch.rand(2,3,1,1))\nprint(x.size())\ny = conv1x1(3, 3)(x)\nprint(y.size())\n\n# z = torch.cat(x, x)\n# print(z.size())\n\nt = torch.mul(x, x)\nprint(t.size())\n\n"
  },
  {
    "path": "trainer.py",
    "content": "from __future__ import print_function\nfrom six.moves import range\nimport torch\nimport torch.optim as optim\nfrom torch.autograd import Variable\nimport torch.backends.cudnn as cudnn\nfrom PIL import Image\nfrom cfg.config import cfg\nfrom miscc.utils import mkdir_p\nfrom miscc.utils import build_super_images, build_super_images2\nfrom miscc.utils import weights_init, load_params, copy_G_params\nfrom model import G_DCGAN, G_NET\nfrom datasets import prepare_data\nfrom model import RNN_ENCODER, CNN_ENCODER, CAPTION_CNN, CAPTION_RNN\nfrom miscc.losses import words_loss\nfrom miscc.losses import discriminator_loss, generator_loss, KL_loss\nimport os\nimport time\nimport numpy as np\n\n\n# MirrorGAN\nclass Trainer(object):\n    def __init__(self, output_dir, data_loader, n_words, ixtoword):\n        if cfg.TRAIN.FLAG:\n            self.model_dir = os.path.join(output_dir, 'Model')\n            self.image_dir = os.path.join(output_dir, 'Image')\n            mkdir_p(self.model_dir)\n            mkdir_p(self.image_dir)\n\n        torch.cuda.set_device(cfg.GPU_ID)\n        cudnn.benchmark = True\n\n        self.batch_size = cfg.TRAIN.BATCH_SIZE\n        self.max_epoch = cfg.TRAIN.MAX_EPOCH\n        self.snapshot_interval = cfg.TRAIN.SNAPSHOT_INTERVAL\n\n        self.n_words = n_words\n        self.ixtoword = ixtoword\n        self.data_loader = data_loader\n        self.num_batches = len(self.data_loader)\n\n    def build_models(self):\n        # text encoders\n        if cfg.TRAIN.NET_E == '':\n            print('Error: no pretrained text-image encoders')\n            return\n\n        image_encoder = CNN_ENCODER(cfg.TEXT.EMBEDDING_DIM)\n        img_encoder_path = cfg.TRAIN.NET_E.replace('text_encoder', 'image_encoder')\n        state_dict = \\\n            torch.load(img_encoder_path, map_location=lambda storage, loc: storage)\n        image_encoder.load_state_dict(state_dict)\n        for p in image_encoder.parameters():\n            p.requires_grad = 
False\n        print('Load image encoder from:', img_encoder_path)\n        image_encoder.eval()\n\n        text_encoder = \\\n            RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)\n        state_dict = \\\n            torch.load(cfg.TRAIN.NET_E,\n                       map_location=lambda storage, loc: storage)\n        text_encoder.load_state_dict(state_dict)\n        for p in text_encoder.parameters():\n            p.requires_grad = False\n        print('Load text encoder from:', cfg.TRAIN.NET_E)\n        text_encoder.eval()\n\n        # Caption models - cnn_encoder and rnn_decoder\n        caption_cnn = CAPTION_CNN(cfg.CAP.embed_size)\n        caption_cnn.load_state_dict(torch.load(cfg.CAP.caption_cnn_path, map_location=lambda storage, loc: storage))\n        for p in caption_cnn.parameters():\n            p.requires_grad = False\n        print('Load caption model from:', cfg.CAP.caption_cnn_path)\n        caption_cnn.eval()\n\n        caption_rnn = CAPTION_RNN(cfg.CAP.embed_size, cfg.CAP.hidden_size * 2, self.n_words, cfg.CAP.num_layers)\n        caption_rnn.load_state_dict(torch.load(cfg.CAP.caption_rnn_path, map_location=lambda storage, loc: storage))\n        for p in caption_rnn.parameters():\n            p.requires_grad = False\n        print('Load caption model from:', cfg.CAP.caption_rnn_path)\n\n        # Generator and Discriminator:\n        netsD = []\n        if cfg.GAN.B_DCGAN:\n            if cfg.TREE.BRANCH_NUM == 1:\n                from model import D_NET64 as D_NET\n            elif cfg.TREE.BRANCH_NUM == 2:\n                from model import D_NET128 as D_NET\n            else:  # cfg.TREE.BRANCH_NUM == 3:\n                from model import D_NET256 as D_NET\n\n            netG = G_DCGAN()\n            netsD = [D_NET(b_jcu=False)]\n        else:\n            from model import D_NET64, D_NET128, D_NET256\n            netG = G_NET()\n            if cfg.TREE.BRANCH_NUM > 0:\n                netsD.append(D_NET64())\n            if 
cfg.TREE.BRANCH_NUM > 1:\n                netsD.append(D_NET128())\n            if cfg.TREE.BRANCH_NUM > 2:\n                netsD.append(D_NET256())\n        netG.apply(weights_init)\n        # print(netG)\n        for i in range(len(netsD)):\n            netsD[i].apply(weights_init)\n            # print(netsD[i])\n        print('# of netsD', len(netsD))\n\n        epoch = 0\n        if cfg.TRAIN.NET_G != '':\n            state_dict = \\\n                torch.load(cfg.TRAIN.NET_G, map_location=lambda storage, loc: storage)\n            netG.load_state_dict(state_dict)\n            print('Load G from: ', cfg.TRAIN.NET_G)\n            istart = cfg.TRAIN.NET_G.rfind('_') + 1\n            iend = cfg.TRAIN.NET_G.rfind('.')\n            epoch = cfg.TRAIN.NET_G[istart:iend]\n            epoch = int(epoch) + 1\n            if cfg.TRAIN.B_NET_D:\n                Gname = cfg.TRAIN.NET_G\n                for i in range(len(netsD)):\n                    s_tmp = Gname[:Gname.rfind('/')]\n                    Dname = '%s/netD%d.pth' % (s_tmp, i)\n                    print('Load D from: ', Dname)\n                    state_dict = \\\n                        torch.load(Dname, map_location=lambda storage, loc: storage)\n                    netsD[i].load_state_dict(state_dict)\n\n        if cfg.CUDA:\n            text_encoder = text_encoder.cuda()\n            image_encoder = image_encoder.cuda()\n            caption_cnn = caption_cnn.cuda()\n            caption_rnn = caption_rnn.cuda()\n            netG.cuda()\n            for i in range(len(netsD)):\n                netsD[i].cuda()\n        return [text_encoder, image_encoder, caption_cnn, caption_rnn, netG, netsD, epoch]\n\n    def define_optimizers(self, netG, netsD):\n        optimizersD = []\n        num_Ds = len(netsD)\n        for i in range(num_Ds):\n            opt = optim.Adam(netsD[i].parameters(),\n                             lr=cfg.TRAIN.DISCRIMINATOR_LR,\n                             betas=(0.5, 0.999))\n            
optimizersD.append(opt)\n\n        optimizerG = optim.Adam(netG.parameters(),\n                                lr=cfg.TRAIN.GENERATOR_LR,\n                                betas=(0.5, 0.999))\n\n        return optimizerG, optimizersD\n\n    def prepare_labels(self):\n        batch_size = self.batch_size\n        real_labels = Variable(torch.FloatTensor(batch_size).fill_(1))\n        fake_labels = Variable(torch.FloatTensor(batch_size).fill_(0))\n        match_labels = Variable(torch.LongTensor(range(batch_size)))\n        if cfg.CUDA:\n            real_labels = real_labels.cuda()\n            fake_labels = fake_labels.cuda()\n            match_labels = match_labels.cuda()\n\n        return real_labels, fake_labels, match_labels\n\n    def save_model(self, netG, avg_param_G, netsD, epoch):\n        backup_para = copy_G_params(netG)\n        load_params(netG, avg_param_G)\n        torch.save(netG.state_dict(),\n                   '%s/netG_epoch_%d.pth' % (self.model_dir, epoch))\n        load_params(netG, backup_para)\n        #\n        for i in range(len(netsD)):\n            netD = netsD[i]\n            torch.save(netD.state_dict(),\n                       '%s/netD%d.pth' % (self.model_dir, i))\n        print('Save G/Ds models.')\n\n    def set_requires_grad_value(self, models_list, brequires):\n        for i in range(len(models_list)):\n            for p in models_list[i].parameters():\n                p.requires_grad = brequires\n\n    def save_img_results(self, netG, noise, sent_emb, words_embs, mask,\n                         image_encoder, captions, cap_lens,\n                         gen_iterations, name='current'):\n        # Save images\n        fake_imgs, attention_maps, _, _ = netG(noise, sent_emb, words_embs, mask)\n        for i in range(len(attention_maps)):\n            if len(fake_imgs) > 1:\n                img = fake_imgs[i + 1].detach().cpu()\n                lr_img = fake_imgs[i].detach().cpu()\n            else:\n                img = 
fake_imgs[0].detach().cpu()\n                lr_img = None\n            attn_maps = attention_maps[i]\n            att_sze = attn_maps.size(2)\n            img_set, _ = \\\n                build_super_images(img, captions, self.ixtoword,\n                                   attn_maps, att_sze, lr_imgs=lr_img)\n            if img_set is not None:\n                im = Image.fromarray(img_set)\n                fullpath = '%s/G_%s_%d_%d.png' \\\n                           % (self.image_dir, name, gen_iterations, i)\n                im.save(fullpath)\n\n        i = -1\n        img = fake_imgs[i].detach()\n        region_features, _ = image_encoder(img)\n        att_sze = region_features.size(2)\n        _, _, att_maps = words_loss(region_features.detach(),\n                                    words_embs.detach(),\n                                    None, cap_lens,\n                                    None, self.batch_size)\n        img_set, _ = \\\n            build_super_images(fake_imgs[i].detach().cpu(),\n                               captions, self.ixtoword, att_maps, att_sze)\n        if img_set is not None:\n            im = Image.fromarray(img_set)\n            fullpath = '%s/D_%s_%d.png' \\\n                       % (self.image_dir, name, gen_iterations)\n            im.save(fullpath)\n\n    def train(self):\n        text_encoder, image_encoder, caption_cnn, caption_rnn, netG, netsD, start_epoch = self.build_models()\n        avg_param_G = copy_G_params(netG)\n        optimizerG, optimizersD = self.define_optimizers(netG, netsD)\n        real_labels, fake_labels, match_labels = self.prepare_labels()\n\n        batch_size = self.batch_size\n        nz = cfg.GAN.Z_DIM\n        noise = Variable(torch.FloatTensor(batch_size, nz))\n        fixed_noise = Variable(torch.FloatTensor(batch_size, nz).normal_(0, 1))\n        if cfg.CUDA:\n            noise, fixed_noise = noise.cuda(), fixed_noise.cuda()\n\n        gen_iterations = 0\n        for epoch in 
range(start_epoch, self.max_epoch):\n            start_t = time.time()\n\n            data_iter = iter(self.data_loader)\n            step = 0\n            while step < self.num_batches:\n                # (1) Prepare training data and Compute text embeddings\n                data = data_iter.next()\n                imgs, captions, cap_lens, class_ids, keys = prepare_data(data)\n\n                hidden = text_encoder.init_hidden(batch_size)\n                # words_embs: batch_size x nef x seq_len\n                # sent_emb: batch_size x nef\n                words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)\n                words_embs, sent_emb = words_embs.detach(), sent_emb.detach()\n                mask = (captions == 0)\n                num_words = words_embs.size(2)\n                if mask.size(1) > num_words:\n                    mask = mask[:, :num_words]\n\n                # (2) Generate fake images\n                noise.data.normal_(0, 1)\n                fake_imgs, _, mu, logvar = netG(noise, sent_emb, words_embs, mask)\n\n                # (3) Update D network\n                errD_total = 0\n                D_logs = ''\n                for i in range(len(netsD)):\n                    netsD[i].zero_grad()\n                    errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],\n                                              sent_emb, real_labels, fake_labels)\n                    # backward and update parameters\n                    errD.backward()\n                    optimizersD[i].step()\n                    errD_total += errD\n                    D_logs += 'errD%d: %.2f ' % (i, errD.data[0])\n\n                # (4) Update G network: maximize log(D(G(z)))\n                # compute total loss for training G\n                step += 1\n                gen_iterations += 1\n                netG.zero_grad()\n                errG_total, G_logs = \\\n                    generator_loss(netsD, image_encoder, caption_cnn, caption_rnn, 
captions, fake_imgs, real_labels,\n                                   words_embs, sent_emb, match_labels, cap_lens, class_ids)\n                kl_loss = KL_loss(mu, logvar)\n                errG_total += kl_loss\n                G_logs += 'kl_loss: %.2f ' % kl_loss.data[0]\n                # backward and update parameters\n                errG_total.backward()\n                optimizerG.step()\n                for p, avg_p in zip(netG.parameters(), avg_param_G):\n                    avg_p.mul_(0.999).add_(0.001, p.data)\n\n                if gen_iterations % 100 == 0:\n                    print(D_logs + '\\n' + G_logs)\n                # save images\n                if gen_iterations % 1000 == 0:\n                    backup_para = copy_G_params(netG)\n                    load_params(netG, avg_param_G)\n                    self.save_img_results(netG, fixed_noise, sent_emb,\n                                          words_embs, mask, image_encoder,\n                                          captions, cap_lens, epoch, name='average')\n                    load_params(netG, backup_para)\n            end_t = time.time()\n\n            print('''[%d/%d][%d]\n                  Loss_D: %.2f Loss_G: %.2f Time: %.2fs'''\n                  % (epoch, self.max_epoch, self.num_batches,\n                     errD_total.data[0], errG_total.data[0],\n                     end_t - start_t))\n\n            if epoch % cfg.TRAIN.SNAPSHOT_INTERVAL == 0:  # and epoch != 0:\n                self.save_model(netG, avg_param_G, netsD, epoch)\n\n        self.save_model(netG, avg_param_G, netsD, self.max_epoch)\n\n    def save_singleimages(self, images, filenames, save_dir,\n                          split_dir, sentenceID=0):\n        for i in range(images.size(0)):\n            s_tmp = '%s/single_samples/%s/%s' % \\\n                    (save_dir, split_dir, filenames[i])\n            folder = s_tmp[:s_tmp.rfind('/')]\n            if not os.path.isdir(folder):\n                print('Make a 
new folder: ', folder)\n                mkdir_p(folder)\n\n            fullpath = '%s_%d.jpg' % (s_tmp, sentenceID)\n            # range from [-1, 1] to [0, 1]\n            # img = (images[i] + 1.0) / 2\n            img = images[i].add(1).div(2).mul(255).clamp(0, 255).byte()\n            # range from [0, 1] to [0, 255]\n            ndarr = img.permute(1, 2, 0).data.cpu().numpy()\n            im = Image.fromarray(ndarr)\n            im.save(fullpath)\n\n    def sampling(self, split_dir):\n        if cfg.TRAIN.NET_G == '':\n            print('Error: the path for model is not found!')\n        else:\n            if split_dir == 'test':\n                split_dir = 'valid'\n            # Build and load the generator\n            if cfg.GAN.B_DCGAN:\n                netG = G_DCGAN()\n            else:\n                netG = G_NET()\n            netG.apply(weights_init)\n            netG.cuda()\n            netG.eval()\n            #\n            text_encoder = RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)\n            state_dict = \\\n                torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)\n            text_encoder.load_state_dict(state_dict)\n            print('Load text encoder from:', cfg.TRAIN.NET_E)\n            text_encoder = text_encoder.cuda()\n            text_encoder.eval()\n\n            batch_size = self.batch_size\n            nz = cfg.GAN.Z_DIM\n            noise = Variable(torch.FloatTensor(batch_size, nz), volatile=True)\n            noise = noise.cuda()\n\n            model_dir = cfg.TRAIN.NET_G\n            state_dict = \\\n                torch.load(model_dir, map_location=lambda storage, loc: storage)\n            netG.load_state_dict(state_dict)\n            print('Load G from: ', model_dir)\n\n            # the path to save generated images\n            s_tmp = model_dir[:model_dir.rfind('.pth')]\n            save_dir = '%s/%s' % (s_tmp, split_dir)\n            mkdir_p(save_dir)\n\n            cnt = 0\n\n 
           for _ in range(1):  # (cfg.TEXT.CAPTIONS_PER_IMAGE):\n                for step, data in enumerate(self.data_loader, 0):\n                    cnt += batch_size\n                    if step % 100 == 0:\n                        print('step: ', step)\n                    # if step > 50:\n                    #     break\n\n                    imgs, captions, cap_lens, class_ids, keys = prepare_data(data)\n\n                    hidden = text_encoder.init_hidden(batch_size)\n                    # words_embs: batch_size x nef x seq_len\n                    # sent_emb: batch_size x nef\n                    words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)\n                    words_embs, sent_emb = words_embs.detach(), sent_emb.detach()\n                    mask = (captions == 0)\n                    num_words = words_embs.size(2)\n                    if mask.size(1) > num_words:\n                        mask = mask[:, :num_words]\n\n                    # (2) Generate fake images\n                    noise.data.normal_(0, 1)\n                    fake_imgs, _, _, _ = netG(noise, sent_emb, words_embs, mask)\n                    for j in range(batch_size):\n                        s_tmp = '%s/single/%s' % (save_dir, keys[j])\n                        folder = s_tmp[:s_tmp.rfind('/')]\n                        if not os.path.isdir(folder):\n                            print('Make a new folder: ', folder)\n                            mkdir_p(folder)\n                        k = -1\n                        # for k in range(len(fake_imgs)):\n                        im = fake_imgs[k][j].data.cpu().numpy()\n                        # [-1, 1] --> [0, 255]\n                        im = (im + 1.0) * 127.5\n                        im = im.astype(np.uint8)\n                        im = np.transpose(im, (1, 2, 0))\n                        im = Image.fromarray(im)\n                        fullpath = '%s_s%d.png' % (s_tmp, k)\n                        
im.save(fullpath)\n\n    def gen_example(self, data_dic):\n        if cfg.TRAIN.NET_G == '':\n            print('Error: the path for morels is not found!')\n        else:\n            # Build and load the generator\n            text_encoder = \\\n                RNN_ENCODER(self.n_words, nhidden=cfg.TEXT.EMBEDDING_DIM)\n            state_dict = \\\n                torch.load(cfg.TRAIN.NET_E, map_location=lambda storage, loc: storage)\n            text_encoder.load_state_dict(state_dict)\n            print('Load text encoder from:', cfg.TRAIN.NET_E)\n            text_encoder = text_encoder.cuda()\n            text_encoder.eval()\n\n            # the path to save generated images\n            if cfg.GAN.B_DCGAN:\n                netG = G_DCGAN()\n            else:\n                netG = G_NET()\n            s_tmp = cfg.TRAIN.NET_G[:cfg.TRAIN.NET_G.rfind('.pth')]\n            model_dir = cfg.TRAIN.NET_G\n            state_dict = \\\n                torch.load(model_dir, map_location=lambda storage, loc: storage)\n            netG.load_state_dict(state_dict)\n            print('Load G from: ', model_dir)\n            netG.cuda()\n            netG.eval()\n            for key in data_dic:\n                save_dir = '%s/%s' % (s_tmp, key)\n                mkdir_p(save_dir)\n                captions, cap_lens, sorted_indices = data_dic[key]\n\n                batch_size = captions.shape[0]\n                nz = cfg.GAN.Z_DIM\n                captions = Variable(torch.from_numpy(captions), volatile=True)\n                cap_lens = Variable(torch.from_numpy(cap_lens), volatile=True)\n\n                captions = captions.cuda()\n                cap_lens = cap_lens.cuda()\n                for i in range(1):  # 16\n                    noise = Variable(torch.FloatTensor(batch_size, nz), volatile=True)\n                    noise = noise.cuda()\n                    # (1) Extract text embeddings\n                    hidden = text_encoder.init_hidden(batch_size)\n                
    # words_embs: batch_size x nef x seq_len\n                    # sent_emb: batch_size x nef\n                    words_embs, sent_emb = text_encoder(captions, cap_lens, hidden)\n                    mask = (captions == 0)\n                    # (2) Generate fake images\n                    noise.data.normal_(0, 1)\n                    fake_imgs, attention_maps, _, _ = netG(noise, sent_emb, words_embs, mask)\n                    # G attention\n                    cap_lens_np = cap_lens.cpu().data.numpy()\n                    for j in range(batch_size):\n                        save_name = '%s/%d_s_%d' % (save_dir, i, sorted_indices[j])\n                        for k in range(len(fake_imgs)):\n                            im = fake_imgs[k][j].data.cpu().numpy()\n                            im = (im + 1.0) * 127.5\n                            im = im.astype(np.uint8)\n                            # print('im', im.shape)\n                            im = np.transpose(im, (1, 2, 0))\n                            # print('im', im.shape)\n                            im = Image.fromarray(im)\n                            fullpath = '%s_g%d.png' % (save_name, k)\n                            im.save(fullpath)\n\n                        for k in range(len(attention_maps)):\n                            if len(fake_imgs) > 1:\n                                im = fake_imgs[k + 1].detach().cpu()\n                            else:\n                                im = fake_imgs[0].detach().cpu()\n                            attn_maps = attention_maps[k]\n                            att_sze = attn_maps.size(2)\n                            img_set, sentences = \\\n                                build_super_images2(im[j].unsqueeze(0),\n                                                    captions[j].unsqueeze(0),\n                                                    [cap_lens_np[j]], self.ixtoword,\n                                                    [attn_maps[j]], att_sze)\n          
                  if img_set is not None:\n                                im = Image.fromarray(img_set)\n                                fullpath = '%s_a%d.png' % (save_name, k)\n                                im.save(fullpath)\n"
  }
]