Full Code of facebookresearch/EmbodiedQA for AI

main 306fd6ef3064 cached
17 files
253.9 KB
56.4k tokens
122 symbols
1 requests
Download .txt
Showing preview only (263K chars total). Download the full file or copy to clipboard to get everything.
Repository: facebookresearch/EmbodiedQA
Branch: main
Commit: 306fd6ef3064
Files: 17
Total size: 253.9 KB

Directory structure:
gitextract_mc95tebh/

├── .gitignore
├── .gitmodules
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── requirements.txt
├── training/
│   ├── data.py
│   ├── metrics.py
│   ├── models.py
│   ├── train_eqa.py
│   ├── train_nav.py
│   ├── train_vqa.py
│   └── utils/
│       ├── preprocess_questions.py
│       └── preprocess_questions_pkl.py
└── utils/
    ├── house3d.py
    └── make_houses.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
tmp
data
logs
checkpoints
*.pem
*.sh
*autoenv*

# PYTHON
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

# NODE
# Logs
logs
*.log
npm-debug.log*
yarn-debug.log*
yarn-error.log*

# Runtime data
pids
*.pid
*.seed
*.pid.lock

# Directory for instrumented libs generated by jscoverage/JSCover
lib-cov

# Coverage directory used by tools like istanbul
coverage

# nyc test coverage
.nyc_output

# Grunt intermediate storage (http://gruntjs.com/creating-plugins#storing-task-files)
.grunt

# Bower dependency directory (https://bower.io/)
bower_components

# node-waf configuration
.lock-wscript

# Compiled binary addons (http://nodejs.org/api/addons.html)
build/Release

# Dependency directories
node_modules/
jspm_packages/

# Typescript v1 declaration files
typings/

# Optional npm cache directory
.npm

# Optional eslint cache
.eslintcache

# Optional REPL history
.node_repl_history

# Output of 'npm pack'
*.tgz

# Yarn Integrity file
.yarn-integrity

# dotenv environment variables file
.env



================================================
FILE: .gitmodules
================================================
[submodule "House3D"]
	path = House3D
	url = git@github.com:abhshkdz/House3D.git


================================================
FILE: CODE_OF_CONDUCT.md
================================================
# Code of Conduct

Facebook has adopted a Code of Conduct that we expect project participants to adhere to.
Please read the [full text](https://code.facebook.com/pages/876921332402685/open-source-code-of-conduct)
so that you can understand what actions will and will not be tolerated.

================================================
FILE: CONTRIBUTING.md
================================================
# Contributing to EmbodiedQA
We want to make contributing to this project as easy and transparent as
possible.

## Our Development Process
Minor changes and improvements will be released on an ongoing basis. Larger changes (e.g., changesets implementing a new paper) will be released on a more periodic basis.

## Pull Requests
We actively welcome your pull requests.

1. Fork the repo and create your branch from `master`.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Ensure the test suite passes.
5. Make sure your code lints.
6. If you haven't already, complete the Contributor License Agreement ("CLA").

## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Facebook's open source projects.

Complete your CLA here: <https://code.facebook.com/cla>

## Issues
We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.

Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
disclosure of security bugs. In those cases, please go through the process
outlined on that page and do not file a public issue.

## Coding Style  
* 4 spaces for indentation rather than tabs
* 80 character line length

## License
By contributing to EmbodiedQA, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree.

================================================
FILE: LICENSE
================================================
BSD License

For EmbodiedQA software

Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:

 * Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.

 * Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

 * Neither the name Facebook nor the names of its contributors may be used to
   endorse or promote products derived from this software without specific
   prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


================================================
FILE: README.md
================================================
# EmbodiedQA

Code for the paper

**[Embodied Question Answering][1]**  
Abhishek Das, Samyak Datta, Georgia Gkioxari, Stefan Lee, Devi Parikh, Dhruv Batra  
[arxiv.org/abs/1711.11543][2]  
CVPR 2018 (Oral)

In Embodied Question Answering (EmbodiedQA), an agent is spawned at a random location in a 3D environment and asked a question (for e.g. "What color is the car?"). In order to answer, the agent must first intelligently navigate to explore the environment, gather necessary visual information through first-person vision, and then answer the question ("orange").

![](https://i.imgur.com/jeI7bxm.jpg)

This repository provides

- [Pretrained CNN](#pretrained-cnn) for [House3D][house3d]
- Code for [generating EQA questions](#question-generation)
    - EQA v1: location, color, place preposition
    - EQA v1-extended: existence, logical, object counts, room counts, distance comparison
- Code to train and evaluate [navigation](#navigation) and [question-answering](#visual-question-answering) models
    - [independently with supervised learning](#supervised-learning) on shortest paths
    - jointly using [reinforcement learning](#reinforce)

If you find this code useful, consider citing our work:

```
@inproceedings{embodiedqa,
  title={{E}mbodied {Q}uestion {A}nswering},
  author={Abhishek Das and Samyak Datta and Georgia Gkioxari and Stefan Lee and Devi Parikh and Dhruv Batra},
  booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
  year={2018}
}
```

## Setup

```
virtualenv -p python3 .env
source .env/bin/activate
pip install -r requirements.txt
```

Download the [SUNCG v1 dataset](https://github.com/facebookresearch/House3D/blob/master/INSTRUCTION.md#usage-instructions) and [install House3D](https://github.com/abhshkdz/House3D/tree/master/renderer#rendering-code-of-house3d).

NOTE: This code uses a [fork of House3D](https://github.com/abhshkdz/house3d) with a few changes to support arbitrary map discretization resolutions.

## Question generation

Questions for EmbodiedQA are generated programmatically, in a manner similar to [CLEVR (Johnson et al., 2017)][clevr].

NOTE: Pre-generated EQA v1 questions are available for download [here][eqav1].

### Generating questions for all templates in EQA v1, v1-extended

```
cd data/question-gen
./run_me.sh MM_DD
```

### List defined question templates

```
from engine import Engine

E = Engine()
for i in E.template_defs:
    print(i, E.template_defs[i])
```

### Generate questions for a particular template (say `location`)

```
from house_parse import HouseParse
from engine import Engine

Hp = HouseParse(dataDir='/path/to/suncg')
Hp.parse('0aa5e04f06a805881285402096eac723')

E = Engine()
E.cacheHouse(Hp)
qns = E.executeFn(E.template_defs['location'])

print(qns[0]['question'], qns[0]['answer'])
# what room is the clock located in? bedroom

```

## Pretrained CNN

We trained a shallow encoder-decoder CNN from scratch in the House3D environment,
for RGB reconstruction, semantic segmentation and depth estimation.
Once trained, we throw away the decoders, and use the encoder as a frozen feature
extractor for navigation and question answering. The CNN is available for download here:

`wget https://www.dropbox.com/s/ju1zw4iipxlj966/03_13_h3d_hybrid_cnn.pt`

The training code expects the checkpoint to be present in `training/models/`.

## Supervised Learning

### Download and preprocess the dataset

Download [EQA v1][eqav1] and shortest path navigations:

```
wget https://www.dropbox.com/s/6zu1b1jzl0qt7t1/eqa_v1.json
wget https://www.dropbox.com/s/lhajthx7cdlnhns/a-star-500.zip
unzip a-star-500.zip
```

If this is the first time you are using SUNCG, you will have to clone and use the
[SUNCG toolbox](https://github.com/shurans/SUNCGtoolbox#convert-to-objmtl)
to generate obj + mtl files for the houses in EQA.

NOTE: Shortest paths have been updated.  Earlier we computed shortest paths using a discrete grid world, but we found that these shortest paths were sometimes innacurate.  Old shortest paths are [here](https://www.dropbox.com/s/vgp2ygh1bht1jyb/shortest-paths.zip).

```
cd utils
python make_houses.py \
    -eqa_path /path/to/eqa.json \
    -suncg_toolbox_path /path/to/SUNCGtoolbox \
    -suncg_data_path /path/to/suncg/data_root
```

Preprocess the dataset for training


```
cd training
python utils/preprocess_questions_pkl.py \
    -input_json /path/to/eqa_v1.json \
    -shortest_path_dir /path/to/shortest/paths/a-star-500 \
    -output_train_h5 data/train.h5 \
    -output_val_h5 data/val.h5 \
    -output_test_h5 data/test.h5 \
    -output_data_json data/data.json \
    -output_vocab data/vocab.json
```

### Visual question answering

Update pretrained CNN path in `models.py`.

`python train_vqa.py -input_type ques,image -identifier ques-image -log -cache`

This model computes question-conditioned attention over last 5 frames from oracle navigation (shortest paths),
and predicts answer. Assuming shortest paths are optimal for answering the question -- which is predominantly
true for most questions in EQA v1 (`location`, `color`, `place preposition`) with the
exception of a few `location` questions that might need more visual context than walking right up till the object --
this can be thought of as an upper bound on expected accuracy, and performance will get worse when navigation
trajectories are sampled from trained policies.

A pretrained VQA model is available for download [here](https://www.dropbox.com/s/jd15af00r7m8neh/vqa_11_18_2018_va0.6154.pt). This gets a top-1 accuracy of 61.54% on val, and 58.46% on test (with GT navigation).

Note that keeping the `cache` flag ON caches images as they are rendered in the first training epoch, so that subsequent epochs are very fast. This is memory-intensive though, and consumes ~100-120G RAM.

### Navigation

Download potential maps for evaluating navigation and training with REINFORCE.

```
wget https://www.dropbox.com/s/53edqtr04jts4q0/target-obj-conn-maps-500.zip
```

#### Planner-controller policy

`python train_nav.py -model_type pacman -identifier pacman -log`

## REINFORCE

```
python train_eqa.py \
    -nav_checkpoint_path /path/to/nav/ques-image-pacman/checkpoint.pt \
    -ans_checkpoint_path /path/to/vqa/ques-image/checkpoint.pt \
    -identifier ques-image-eqa \
    -log
```

## Changelog

### 09/07

- We added the baseline models from the CVPR paper (Reactive and LSTM).
- With the LSTM model, we achieved d_T values of: 0.74693/3.99891/8.10669 on the test set for d equal to 10/30/50 respectively training with behavior cloning (no reinforcement learning).
- We also updated the shortest paths to fix an issue with the shortest path algorithm we initially used.  Code to generate shortest paths is [here](https://github.com/facebookresearch/EmbodiedQA/blob/master/data/shortest-path-gen/generate-paths-a-star.py).

### 06/13

This code release contains the following changes over the CVPR version

- Larger dataset of questions + shortest paths
- Color names as answers to color questions (earlier they were hex strings)

## Acknowledgements

- Parts of this code are adapted from [pytorch-a3c][pytorch-a3c] by Ilya Kostrikov
- [Lisa Anne Hendricks](https://people.eecs.berkeley.edu/~lisa_anne/) and [Licheng Yu](http://www.cs.unc.edu/~licheng/)
helped with running / testing / debugging code prior to release

## License

BSD

[1]: https://embodiedqa.org
[2]: https://arxiv.org/abs/1711.11543
[house3d]: https://github.com/facebookresearch/house3d
[dijkstar]: https://bitbucket.org/wyatt/dijkstar
[pytorch-a3c]: https://github.com/ikostrikov/pytorch-a3c
[eqav1]: https://embodiedqa.org/data
[clevr]: https://github.com/facebookresearch/clevr-dataset-gen


================================================
FILE: requirements.txt
================================================
certifi==2018.4.16
chardet==3.0.4
future==0.16.0
gym==0.10.5
h5py==2.8.0
idna==2.6
numpy==1.14.4
opencv-python==3.4.1.15
Pillow==5.1.0
pyglet==1.3.2
requests==2.18.4
scipy==1.1.0
six==1.11.0
torch==0.3.1
torchvision==0.2.1
tqdm==4.23.4
urllib3==1.22


================================================
FILE: training/data.py
================================================
import math
import time
import h5py
import logging
import argparse
import numpy as np
import os, sys, json
from tqdm import tqdm

from scipy.misc import imread, imresize

import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataloader import default_collate
from torch.autograd import Variable

sys.path.insert(0, '../../House3D/')
from House3D import objrender, Environment, load_config
from House3D.core import local_create_house

sys.path.insert(0, '../utils/')
from house3d import House3DUtils

from models import MultitaskCNN

import pdb

def load_vocab(path):
    with open(path, 'r') as f:
        vocab = json.load(f)
        vocab['questionIdxToToken'] = invert_dict(vocab['questionTokenToIdx'])
        vocab['answerIdxToToken'] = invert_dict(vocab['answerTokenToIdx'])

    assert vocab['questionTokenToIdx']['<NULL>'] == 0
    assert vocab['questionTokenToIdx']['<START>'] == 1
    assert vocab['questionTokenToIdx']['<END>'] == 2
    return vocab


def invert_dict(d):
    return {v: k for k, v in d.items()}


"""
if the action sequence is [f, f, l, l, f, f, f, r]

input sequence to planner is [<start>, f, l, f, r]
output sequence for planner is [f, l, f, r, <end>]

input sequences to controller are [f, f, l, l, f, f, f, r]
output sequences for controller are [1, 0, 1, 0, 1, 1, 0, 0]
"""
def flat_to_hierarchical_actions(actions, controller_action_lim):
    assert len(actions) != 0

    controller_action_ctr = 0

    planner_actions, controller_actions = [1], []
    prev_action = 1

    pq_idx, cq_idx, ph_idx = [], [], []
    ph_trck = 0

    for i in range(1, len(actions)):

        if actions[i] != prev_action:
            planner_actions.append(actions[i])
            pq_idx.append(i-1)

        if i > 1:
            ph_idx.append(ph_trck)
            if actions[i] == prev_action:
                controller_actions.append(1)
                controller_action_ctr += 1
            else:
                controller_actions.append(0)
                controller_action_ctr = 0
                ph_trck += 1
            cq_idx.append(i-1)


        prev_action = actions[i]

        if controller_action_ctr == controller_action_lim-1:
            prev_action = False

    return planner_actions, controller_actions, pq_idx, cq_idx, ph_idx


def _dataset_to_tensor(dset, mask=None, dtype=np.int64):
    arr = np.asarray(dset, dtype=dtype)
    if mask is not None:
        arr = arr[mask]
    if dtype == np.float32:
        tensor = torch.FloatTensor(arr)
    else:
        tensor = torch.LongTensor(arr)
    return tensor


def eqaCollateCnn(batch):
    transposed = list(zip(*batch))
    idx_batch = default_collate(transposed[0])
    question_batch = default_collate(transposed[1])
    answer_batch = default_collate(transposed[2])
    images_batch = default_collate(transposed[3])
    actions_in_batch = default_collate(transposed[4])
    actions_out_batch = default_collate(transposed[5])
    action_lengths_batch = default_collate(transposed[6])
    return [
        idx_batch, question_batch, answer_batch, images_batch,
        actions_in_batch, actions_out_batch, action_lengths_batch
    ]


def eqaCollateSeq2seq(batch):
    transposed = list(zip(*batch))
    idx_batch = default_collate(transposed[0])
    questions_batch = default_collate(transposed[1])
    answers_batch = default_collate(transposed[2])
    images_batch = default_collate(transposed[3])
    actions_in_batch = default_collate(transposed[4])
    actions_out_batch = default_collate(transposed[5])
    action_lengths_batch = default_collate(transposed[6])
    mask_batch = default_collate(transposed[7])

    return [
        idx_batch, questions_batch, answers_batch, images_batch,
        actions_in_batch, actions_out_batch, action_lengths_batch, mask_batch
    ]


class EqaDataset(Dataset):
    def __init__(self,
                 questions_h5,
                 vocab,
                 num_frames=1,
                 data_json=False,
                 split='train',
                 gpu_id=0,
                 input_type='ques',
                 max_threads_per_gpu=10,
                 to_cache=False,
                 target_obj_conn_map_dir=False,
                 map_resolution=1000,
                 overfit=False,
                 max_controller_actions=5,
                 max_actions=None):

        self.questions_h5 = questions_h5
        self.vocab = load_vocab(vocab)
        self.num_frames = num_frames
        self.max_controller_actions = max_controller_actions

        np.random.seed()

        self.data_json = data_json
        self.split = split
        self.gpu_id = gpu_id

        self.input_type = input_type

        self.max_threads_per_gpu = max_threads_per_gpu

        self.target_obj_conn_map_dir = target_obj_conn_map_dir
        self.map_resolution = map_resolution
        self.overfit = overfit

        self.to_cache = to_cache
        self.img_data_cache = {}

        print('Reading question data into memory')
        self.idx = _dataset_to_tensor(questions_h5['idx'])
        self.questions = _dataset_to_tensor(questions_h5['questions'])
        self.answers = _dataset_to_tensor(questions_h5['answers'])
        self.actions = _dataset_to_tensor(questions_h5['action_labels'])
        self.action_lengths = _dataset_to_tensor(
            questions_h5['action_lengths'])

        if max_actions: #max actions will allow us to create arrays of a certain length.  Helpful if you only want to train with 10 actions.
            assert isinstance(max_actions, int)
            num_data_items = self.actions.shape[0]
            new_actions = np.zeros((num_data_items, max_actions+2), dtype=np.int64)
            new_lengths = np.ones((num_data_items,), dtype=np.int64)*max_actions
            for i in range(num_data_items):
                action_length = int(self.action_lengths[i])
                new_actions[i,0] = 1
                new_actions[i,1:max_actions+1] = self.actions[i, action_length-max_actions: action_length].numpy() 
            self.actions = torch.LongTensor(new_actions)
            self.action_lengths = torch.LongTensor(new_lengths)

        if self.data_json != False:
            data = json.load(open(self.data_json, 'r'))
            self.envs = data['envs']

            self.env_idx = data[self.split + '_env_idx']
            self.env_list = [self.envs[x] for x in self.env_idx]
            self.env_set = list(set(self.env_list))
            self.env_set.sort()

            if self.overfit == True:
                self.env_idx = self.env_idx[:1]
                self.env_set = self.env_list = [self.envs[x] for x in self.env_idx]
                print('Trying to overfit to [house %s]' % self.env_set[0])
                logging.info('Trying to overfit to [house {}]'.format(self.env_set[0]))

            print('Total envs: %d' % len(list(set(self.envs))))
            print('Envs in %s: %d' % (self.split,
                                      len(list(set(self.env_idx)))))

            if input_type != 'ques':
                ''''
                If training, randomly sample and load a subset of environments,
                train on those, and then cycle through to load the rest.

                On the validation and test set, load in order, and cycle through.

                For both, add optional caching so that if all environments
                have been cycled through once, then no need to re-load and
                instead, just the cache can be used.
                '''

                self.api_threads = []
                self._load_envs(start_idx=0, in_order=True)

                cnn_kwargs = {'num_classes': 191, 'pretrained': True}
                self.cnn = MultitaskCNN(**cnn_kwargs)
                self.cnn.eval()
                self.cnn.cuda()

            self.pos_queue = data[self.split + '_pos_queue']
            self.boxes = data[self.split + '_boxes']

            if max_actions:
                for i in range(len(self.pos_queue)):
                    self.pos_queue[i] = self.pos_queue[i][-1*max_actions:] 

        if input_type == 'pacman':

            self.planner_actions = self.actions.clone().fill_(0)
            self.controller_actions = self.actions.clone().fill_(-1)

            self.planner_action_lengths = self.action_lengths.clone().fill_(0)
            self.controller_action_lengths = self.action_lengths.clone().fill_(
                0)

            self.planner_hidden_idx = self.actions.clone().fill_(0)

            self.planner_pos_queue_idx, self.controller_pos_queue_idx = [], []

            # parsing flat actions to planner-controller hierarchy
            for i in tqdm(range(len(self.actions))):

                pa, ca, pq_idx, cq_idx, ph_idx = flat_to_hierarchical_actions(
                    actions=self.actions[i][:self.action_lengths[i]+1],
                    controller_action_lim=max_controller_actions)

                self.planner_actions[i][:len(pa)] = torch.Tensor(pa)
                self.controller_actions[i][:len(ca)] = torch.Tensor(ca)

                self.planner_action_lengths[i] = len(pa)-1
                self.controller_action_lengths[i] = len(ca)

                self.planner_pos_queue_idx.append(pq_idx)
                self.controller_pos_queue_idx.append(cq_idx)

                self.planner_hidden_idx[i][:len(ca)] = torch.Tensor(ph_idx)

    def _pick_envs_to_load(self,
                           split='train',
                           max_envs=10,
                           start_idx=0,
                           in_order=False):
        if split in ['val', 'test'] or in_order == True:
            pruned_env_set = self.env_set[start_idx:start_idx + max_envs]
        else:
            if max_envs < len(self.env_set):
                env_inds = np.random.choice(
                    len(self.env_set), max_envs, replace=False)
            else:
                env_inds = np.random.choice(
                    len(self.env_set), max_envs, replace=True)
            pruned_env_set = [self.env_set[x] for x in env_inds]
        return pruned_env_set

    def _load_envs(self, start_idx=-1, in_order=False):
        #self._clear_memory()
        if start_idx == -1:
            start_idx = self.env_set.index(self.pruned_env_set[-1]) + 1

        # Pick envs
        self.pruned_env_set = self._pick_envs_to_load(
            split=self.split,
            max_envs=self.max_threads_per_gpu,
            start_idx=start_idx,
            in_order=in_order)

        if len(self.pruned_env_set) == 0:
            return

        # Load api threads
        start = time.time()
        if len(self.api_threads) == 0:
            for i in range(self.max_threads_per_gpu):
                self.api_threads.append(
                    objrender.RenderAPIThread(
                        w=224, h=224, device=self.gpu_id))

        try:
            self.cfg = load_config('../House3D/tests/config.json')
        except:
            self.cfg = load_config('../../House3D/tests/config.json') #Sorry guys; this is so Lisa can run on her system; maybe we should make this an input somewhere?

        print('[%.02f] Loaded %d api threads' % (time.time() - start,
                                                 len(self.api_threads)))
        start = time.time()

        # Load houses
        from multiprocessing import Pool
        _args = ([h, self.cfg, self.map_resolution]
                 for h in self.pruned_env_set)
        with Pool(len(self.pruned_env_set)) as pool:
            self.all_houses = pool.starmap(local_create_house, _args)

        print('[%.02f] Loaded %d houses' % (time.time() - start,
                                            len(self.all_houses)))
        start = time.time()

        # Load envs
        self.env_loaded = {}
        for i in range(len(self.all_houses)):
            print('[%02d/%d][split:%s][gpu:%d][house:%s]' %
                  (i + 1, len(self.all_houses), self.split, self.gpu_id,
                   self.all_houses[i].house['id']))
            environment = Environment(self.api_threads[i], self.all_houses[i], self.cfg)
            self.env_loaded[self.all_houses[i].house['id']] = House3DUtils(
                environment,
                target_obj_conn_map_dir=self.target_obj_conn_map_dir,
                build_graph=False)

        # [TODO] Unused till now
        self.env_ptr = -1

        print('[%.02f] Loaded %d house3d envs' % (time.time() - start,
                                                  len(self.env_loaded)))

        # Mark available data indices
        self.available_idx = [
            i for i, v in enumerate(self.env_list) if v in self.env_loaded
        ]

        # [TODO] only keeping legit sequences
        # needed for things to play well with old data
        temp_available_idx = self.available_idx.copy()
        for i in range(len(temp_available_idx)):
            if self.action_lengths[temp_available_idx[i]] < 5:
                self.available_idx.remove(temp_available_idx[i])

        print('Available inds: %d' % len(self.available_idx))

        # Flag to check if loaded envs have been cycled through or not
        # [TODO] Unused till now
        self.all_envs_loaded = False

    def _clear_api_threads(self):
        for i in range(len(self.api_threads)):
            del self.api_threads[0]
        self.api_threads = []

    def _clear_memory(self):
        if hasattr(self, 'episode_house'):
            del self.episode_house
        if hasattr(self, 'env_loaded'):
            del self.env_loaded
        if hasattr(self, 'api_threads'):
            del self.api_threads
        self.api_threads = []

    def _check_if_all_envs_loaded(self):
        print('[CHECK][Cache:%d][Total:%d]' % (len(self.img_data_cache),
                                               len(self.env_list)))
        if len(self.img_data_cache) == len(self.env_list):
            self.available_idx = [i for i, v in enumerate(self.env_list)]
            return True
        else:
            return False

    def set_camera(self, e, pos, robot_height=1.0):
        assert len(pos) == 4

        e.env.cam.pos.x = pos[0]
        e.env.cam.pos.y = robot_height
        e.env.cam.pos.z = pos[2]
        e.env.cam.yaw = pos[3]

        e.env.cam.updateDirection()

    def render(self, e):
        return e.env.render()

    def get_frames(self, e, pos_queue, preprocess=True):
        if isinstance(pos_queue, list) == False:
            pos_queue = [pos_queue]

        res = []
        for i in range(len(pos_queue)):
            self.set_camera(e, pos_queue[i])
            img = np.array(self.render(e), copy=False, dtype=np.float32)

            if preprocess == True:
                img = img.transpose(2, 0, 1)
                img = img / 255.0

            res.append(img)

        return np.array(res)

    def get_hierarchical_features_till_spawn(self, actions, backtrack_steps=0, max_controller_actions=5):

        action_length = len(actions)-1
        pa, ca, pq_idx, cq_idx, ph_idx = flat_to_hierarchical_actions(
            actions=actions,
            controller_action_lim=max_controller_actions)
        
        # count how many actions of same type have been encountered pefore starting navigation
        backtrack_controller_steps = actions[1:action_length - backtrack_steps + 1:][::-1]
        counter = 0 

        if len(backtrack_controller_steps) > 0:
            while (counter <= self.max_controller_actions) and (counter < len(backtrack_controller_steps) and (backtrack_controller_steps[counter] == backtrack_controller_steps[0])):
                counter += 1 

        target_pos_idx = action_length - backtrack_steps

        controller_step = True
        if target_pos_idx in pq_idx:
            controller_step = False

        pq_idx_pruned = [v for v in pq_idx if v <= target_pos_idx]
        pa_pruned = pa[:len(pq_idx_pruned)+1]

        images = self.get_frames(
            self.episode_house,
            self.episode_pos_queue,
            preprocess=True)
        raw_img_feats = self.cnn(
            Variable(torch.FloatTensor(images)
                     .cuda())).data.cpu().numpy().copy()

        controller_img_feat = torch.from_numpy(raw_img_feats[target_pos_idx].copy())
        controller_action_in = pa_pruned[-1] - 2

        planner_img_feats = torch.from_numpy(raw_img_feats[pq_idx_pruned].copy())
        planner_actions_in = torch.from_numpy(np.array(pa_pruned[:-1]) - 1)

        return planner_actions_in, planner_img_feats, controller_step, controller_action_in, \
            controller_img_feat, self.episode_pos_queue[target_pos_idx], counter

    def __getitem__(self, index):
        # [VQA] question-only
        if self.input_type == 'ques':
            idx = self.idx[index]
            question = self.questions[index]
            answer = self.answers[index]

            return (idx, question, answer)

        # [VQA] question+image
        elif self.input_type == 'ques,image':
            index = self.available_idx[index]

            idx = self.idx[index]
            question = self.questions[index]
            answer = self.answers[index]

            action_length = self.action_lengths[index]
            actions = self.actions[index]

            actions_in = actions[action_length - self.num_frames:action_length]
            actions_out = actions[action_length - self.num_frames + 1:
                                  action_length + 1]

            if self.to_cache == True and index in self.img_data_cache:
                images = self.img_data_cache[index]
            else:
                pos_queue = self.pos_queue[index][
                    -self.num_frames:]  # last 5 frames
                images = self.get_frames(
                    self.env_loaded[self.env_list[index]],
                    pos_queue,
                    preprocess=True)
                if self.to_cache == True:
                    self.img_data_cache[index] = images.copy()

            return (idx, question, answer, images, actions_in, actions_out,
                    action_length)

        # [NAV] question+cnn
        elif self.input_type in ['cnn', 'cnn+q']:

            index = self.available_idx[index]

            idx = self.idx[index]
            question = self.questions[index]
            answer = self.answers[index]

            action_length = self.action_lengths[index]
            actions = self.actions[index]

            if self.to_cache == True and index in self.img_data_cache:
                img_feats = self.img_data_cache[index]
            else:
                pos_queue = self.pos_queue[index]
                images = self.get_frames(
                    self.env_loaded[self.env_list[index]],
                    pos_queue,
                    preprocess=True)
                img_feats = self.cnn(
                    Variable(torch.FloatTensor(images)
                             .cuda())).data.cpu().numpy().copy()
                if self.to_cache == True:
                    self.img_data_cache[index] = img_feats

            # for val or test (evaluation), or
            # when target_obj_conn_map_dir is defined (reinforce),
            # load entire shortest path navigation trajectory
            # and load connectivity map for intermediate rewards
            if self.split in ['val', 'test'
                              ] or self.target_obj_conn_map_dir != False:
                target_obj_id, target_room = False, False
                bbox_obj = [
                    x for x in self.boxes[index]
                    if x['type'] == 'object' and x['target'] == True
                ][0]['box']
                for obj_id in self.env_loaded[self.env_list[index]].objects:
                    box2 = self.env_loaded[self.env_list[index]].objects[
                        obj_id]['bbox']
                    if all([bbox_obj['min'][x] == box2['min'][x] for x in range(3)]) == True and \
                        all([bbox_obj['max'][x] == box2['max'][x] for x in range(3)]) == True:
                        target_obj_id = obj_id
                        break
                bbox_room = [
                    x for x in self.boxes[index]
                    if x['type'] == 'room' and x['target'] == False
                ][0]
                for room in self.env_loaded[self.env_list[
                        index]].env.house.all_rooms:
                    if all([room['bbox']['min'][i] == bbox_room['box']['min'][i] for i in range(3)]) and \
                        all([room['bbox']['max'][i] == bbox_room['box']['max'][i] for i in range(3)]):
                        target_room = room
                        break
                assert target_obj_id != False
                assert target_room != False
                self.env_loaded[self.env_list[index]].set_target_object(
                    self.env_loaded[self.env_list[index]].objects[
                        target_obj_id], target_room)

                # [NOTE] only works for batch size = 1
                self.episode_pos_queue = self.pos_queue[index]
                self.episode_house = self.env_loaded[self.env_list[index]]
                self.target_room = target_room
                self.target_obj = self.env_loaded[self.env_list[
                    index]].objects[target_obj_id]

                actions_in = actions[:action_length]
                actions_out = actions[1:action_length + 1] - 2

                return (idx, question, answer, img_feats, actions_in,
                        actions_out, action_length)

            # if action_length is n
            # images.shape[0] is also n
            # actions[0] is <START>
            # actions[n] is <END>

            # grab 5 random frames
            # [NOTE]: this'll break for longer-than-5 navigation sequences
            start_idx = np.random.choice(img_feats.shape[0] + 1 -
                                         self.num_frames)
            img_feats = img_feats[start_idx:start_idx + self.num_frames]

            actions_in = actions[start_idx:start_idx + self.num_frames]
            actions_out = actions[start_idx + self.num_frames] - 2

            return (idx, question, answer, img_feats, actions_in, actions_out,
                    action_length)

        # [NAV] question+lstm
        elif self.input_type in ['lstm', 'lstm+q']:

            index = self.available_idx[index]

            idx = self.idx[index]
            question = self.questions[index]
            answer = self.answers[index]

            action_length = self.action_lengths[index]
            actions = self.actions[index]

            if self.split == 'train':
                if self.to_cache == True and index in self.img_data_cache:
                    img_feats = self.img_data_cache[index]
                else:
                    pos_queue = self.pos_queue[index]
                    images = self.get_frames(
                        self.env_loaded[self.env_list[index]],
                        pos_queue,
                        preprocess=True)
                    raw_img_feats = self.cnn(
                        Variable(torch.FloatTensor(images)
                                 .cuda())).data.cpu().numpy().copy()
                    img_feats = np.zeros(
                        (self.actions.shape[1], raw_img_feats.shape[1]),
                        dtype=np.float32)
                    img_feats[:raw_img_feats.shape[
                        0], :] = raw_img_feats.copy()
                    if self.to_cache == True:
                        self.img_data_cache[index] = img_feats

            actions_in = actions.clone() - 1
            actions_out = actions[1:].clone() - 2

            actions_in[action_length:].fill_(0)
            mask = actions_out.clone().gt(-1)
            if len(actions_out) > action_length:
                actions_out[action_length:].fill_(0)

            # for val or test (evaluation), or
            # when target_obj_conn_map_dir is defined (reinforce),
            # load entire shortest path navigation trajectory
            # and load connectivity map for intermediate rewards
            if self.split in ['val', 'test'
                              ] or self.target_obj_conn_map_dir != False:
                target_obj_id, target_room = False, False
                bbox_obj = [
                    x for x in self.boxes[index]
                    if x['type'] == 'object' and x['target'] == True
                ][0]['box']
                for obj_id in self.env_loaded[self.env_list[index]].objects:
                    box2 = self.env_loaded[self.env_list[index]].objects[
                        obj_id]['bbox']
                    if all([bbox_obj['min'][x] == box2['min'][x] for x in range(3)]) == True and \
                        all([bbox_obj['max'][x] == box2['max'][x] for x in range(3)]) == True:
                        target_obj_id = obj_id
                        break
                bbox_room = [
                    x for x in self.boxes[index]
                    if x['type'] == 'room' and x['target'] == False
                ][0]
                for room in self.env_loaded[self.env_list[
                        index]].env.house.all_rooms:
                    if all([room['bbox']['min'][i] == bbox_room['box']['min'][i] for i in range(3)]) and \
                        all([room['bbox']['max'][i] == bbox_room['box']['max'][i] for i in range(3)]):
                        target_room = room
                        break
                assert target_obj_id != False
                assert target_room != False
                self.env_loaded[self.env_list[index]].set_target_object(
                    self.env_loaded[self.env_list[index]].objects[
                        target_obj_id], target_room)

                # [NOTE] only works for batch size = 1
                self.episode_pos_queue = self.pos_queue[index]
                self.episode_house = self.env_loaded[self.env_list[index]]
                self.target_room = target_room
                self.target_obj = self.env_loaded[self.env_list[
                    index]].objects[target_obj_id]

                return (idx, question, answer, False, actions_in, actions_out,
                        action_length, mask)

            return (idx, question, answer, img_feats, actions_in, actions_out,
                    action_length, mask)

        # [NAV] planner-controller
        elif self.input_type in ['pacman']:

            index = self.available_idx[index]

            idx = self.idx[index]
            question = self.questions[index]
            answer = self.answers[index]

            action_length = self.action_lengths[index]
            actions = self.actions[index]

            planner_actions = self.planner_actions[index]
            controller_actions = self.controller_actions[index]

            planner_action_length = self.planner_action_lengths[index]
            controller_action_length = self.controller_action_lengths[index]

            planner_hidden_idx = self.planner_hidden_idx[index]

            if self.split == 'train':
                if self.to_cache == True and index in self.img_data_cache:
                    img_feats = self.img_data_cache[index]
                else:
                    pos_queue = self.pos_queue[index]
                    images = self.get_frames(
                        self.env_loaded[self.env_list[index]],
                        pos_queue,
                        preprocess=True)
                    raw_img_feats = self.cnn(
                        Variable(torch.FloatTensor(images)
                                 .cuda())).data.cpu().numpy().copy()
                    img_feats = np.zeros(
                        (self.actions.shape[1], raw_img_feats.shape[1]),
                        dtype=np.float32)
                    img_feats[:raw_img_feats.shape[
                        0], :] = raw_img_feats.copy()
                    if self.to_cache == True:
                        self.img_data_cache[index] = img_feats

            if self.split in ['val', 'test'
                              ] or self.target_obj_conn_map_dir != False:
                target_obj_id, target_room = False, False
                bbox_obj = [
                    x for x in self.boxes[index]
                    if x['type'] == 'object' and x['target'] == True
                ][0]['box']
                for obj_id in self.env_loaded[self.env_list[index]].objects:
                    box2 = self.env_loaded[self.env_list[index]].objects[
                        obj_id]['bbox']
                    if all([bbox_obj['min'][x] == box2['min'][x] for x in range(3)]) == True and \
                        all([bbox_obj['max'][x] == box2['max'][x] for x in range(3)]) == True:
                        target_obj_id = obj_id
                        break
                bbox_room = [
                    x for x in self.boxes[index]
                    if x['type'] == 'room' and x['target'] == False
                ][0]
                for room in self.env_loaded[self.env_list[
                        index]].env.house.all_rooms:
                    if all([room['bbox']['min'][i] == bbox_room['box']['min'][i] for i in range(3)]) and \
                        all([room['bbox']['max'][i] == bbox_room['box']['max'][i] for i in range(3)]):
                        target_room = room
                        break
                assert target_obj_id != False
                assert target_room != False
                self.env_loaded[self.env_list[index]].set_target_object(
                    self.env_loaded[self.env_list[index]].objects[
                        target_obj_id], target_room)

                # [NOTE] only works for batch size = 1
                self.episode_pos_queue = self.pos_queue[index]
                self.episode_house = self.env_loaded[self.env_list[index]]
                self.target_room = target_room
                self.target_obj = self.env_loaded[self.env_list[
                    index]].objects[target_obj_id]

                return (idx, question, answer, actions, action_length)

            planner_pos_queue_idx = self.planner_pos_queue_idx[index]
            controller_pos_queue_idx = self.controller_pos_queue_idx[index]

            planner_img_feats = np.zeros(
                (self.actions.shape[1], img_feats.shape[1]), dtype=np.float32)
            planner_img_feats[:planner_action_length] = img_feats[
                planner_pos_queue_idx]

            planner_actions_in = planner_actions.clone() - 1
            planner_actions_out = planner_actions[1:].clone() - 2

            planner_actions_in[planner_action_length:].fill_(0)
            planner_mask = planner_actions_out.clone().gt(-1)
            if len(planner_actions_out) > planner_action_length:
                planner_actions_out[planner_action_length:].fill_(0)

            controller_img_feats = np.zeros(
                (self.actions.shape[1], img_feats.shape[1]), dtype=np.float32)
            controller_img_feats[:controller_action_length] = img_feats[
                controller_pos_queue_idx]

            controller_actions_in = actions[1:].clone() - 2
            if len(controller_actions_in) > controller_action_length:
                controller_actions_in[controller_action_length:].fill_(0)

            controller_out = controller_actions
            controller_mask = controller_out.clone().gt(-1)
            if len(controller_out) > controller_action_length:
                controller_out[controller_action_length:].fill_(0)

            # zero out forced controller return
            for i in range(controller_action_length):
                if i >= self.max_controller_actions - 1 and controller_out[i] == 0 and \
                        (self.max_controller_actions == 1 or
                         controller_out[i - self.max_controller_actions + 1:i].sum()
                         == self.max_controller_actions - 1):
                    controller_mask[i] = 0
                    
            return (idx, question, answer, planner_img_feats,
                    planner_actions_in, planner_actions_out,
                    planner_action_length, planner_mask, controller_img_feats,
                    controller_actions_in, planner_hidden_idx, controller_out,
                    controller_action_length, controller_mask)

    def __len__(self):
        if self.input_type == 'ques':
            return len(self.questions)
        else:
            return len(self.available_idx)


class EqaDataLoader(DataLoader):
    def __init__(self, **kwargs):
        if 'questions_h5' not in kwargs:
            raise ValueError('Must give questions_h5')
        if 'data_json' not in kwargs:
            raise ValueError('Must give data_json')
        if 'vocab' not in kwargs:
            raise ValueError('Must give vocab')
        if 'input_type' not in kwargs:
            raise ValueError('Must give input_type')
        if 'split' not in kwargs:
            raise ValueError('Must give split')
        if 'gpu_id' not in kwargs:
            raise ValueError('Must give gpu_id')

        questions_h5_path = kwargs.pop('questions_h5')
        data_json = kwargs.pop('data_json')
        input_type = kwargs.pop('input_type')

        split = kwargs.pop('split')
        vocab = kwargs.pop('vocab')

        gpu_id = kwargs.pop('gpu_id')

        if 'max_threads_per_gpu' in kwargs:
            max_threads_per_gpu = kwargs.pop('max_threads_per_gpu')
        else:
            max_threads_per_gpu = 10

        if 'to_cache' in kwargs:
            to_cache = kwargs.pop('to_cache')
        else:
            to_cache = False

        if 'target_obj_conn_map_dir' in kwargs:
            target_obj_conn_map_dir = kwargs.pop('target_obj_conn_map_dir')
        else:
            target_obj_conn_map_dir = False

        if 'map_resolution' in kwargs:
            map_resolution = kwargs.pop('map_resolution')
        else:
            map_resolution = 1000

        if 'image' in input_type or 'cnn' in input_type:
            kwargs['collate_fn'] = eqaCollateCnn
        elif 'lstm' in input_type:
            kwargs['collate_fn'] = eqaCollateSeq2seq

        if 'overfit' in kwargs:
            overfit = kwargs.pop('overfit')
        else:
            overfit = False

        if 'max_controller_actions' in kwargs:
            max_controller_actions = kwargs.pop('max_controller_actions')
        else:
            max_controller_actions = 5

        if 'max_actions' in kwargs:
            max_actions = kwargs.pop('max_actions')
        else:
            max_actions = None 

        print('Reading questions from ', questions_h5_path)
        with h5py.File(questions_h5_path, 'r') as questions_h5:
            self.dataset = EqaDataset(
                questions_h5,
                vocab,
                num_frames=kwargs.pop('num_frames'),
                data_json=data_json,
                split=split,
                gpu_id=gpu_id,
                input_type=input_type,
                max_threads_per_gpu=max_threads_per_gpu,
                to_cache=to_cache,
                target_obj_conn_map_dir=target_obj_conn_map_dir,
                map_resolution=map_resolution,
                overfit=overfit,
                max_controller_actions=max_controller_actions,
                max_actions=max_actions)

        super(EqaDataLoader, self).__init__(self.dataset, **kwargs)

    def close(self):
        pass

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-train_h5', default='data/04_22/train_v1.h5')
    parser.add_argument('-val_h5', default='data/04_22/val_v1.h5')
    parser.add_argument('-data_json', default='data/04_22/data_v1.json')
    parser.add_argument('-vocab_json', default='data/04_22/vocab_v1.json')

    parser.add_argument(
        '-input_type', default='ques', choices=['ques', 'ques,image'])
    parser.add_argument(
        '-num_frames', default=5,
        type=int)  # -1 = all frames of navigation sequence

    parser.add_argument('-batch_size', default=50, type=int)
    parser.add_argument('-max_threads_per_gpu', default=10, type=int)

    args = parser.parse_args()

    try:
        args.gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
        args.gpus = [int(x) for x in args.gpus]
    except KeyError:
        print("CPU not supported")
        exit()

    train_loader_kwargs = {
        'questions_h5': args.train_h5,
        'data_json': args.data_json,
        'vocab': args.vocab_json,
        'batch_size': args.batch_size,
        'input_type': args.input_type,
        'num_frames': args.num_frames,
        'split': 'train',
        'max_threads_per_gpu': args.max_threads_per_gpu,
        'gpu_id': args.gpus[0],
        'cache_path': False,
    }

    train_loader = EqaDataLoader(**train_loader_kwargs)
    train_loader.dataset._load_envs(start_idx=0, in_order=True)
    t = 0

    while True:
        done = False
        all_envs_loaded = train_loader.dataset._check_if_all_envs_loaded()
        while done == False:
            print('[Size:%d][t:%d][Cache:%d]' %
                  (len(train_loader.dataset), t,
                   len(train_loader.dataset.img_data_cache)))
            for batch in train_loader:
                t += 1

            if all_envs_loaded == False:
                train_loader.dataset._load_envs(in_order=True)
                if len(train_loader.dataset.pruned_env_set) == 0:
                    done = True
            else:
                done = True


================================================
FILE: training/metrics.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pdb

import copy
import json
import time
import os, sys
import argparse
import numpy as np

class Metric():
    def __init__(self, info={}, metric_names=[], log_json=None):
        self.info = info
        self.metric_names = metric_names

        self.metrics = [[None,None,None] for _ in self.metric_names]

        self.stats = []
        self.num_iters = 0

        self.log_json = log_json

    def update(self, values):
        assert isinstance(values, list)

        self.num_iters += 1
        current_stats = []

        for i in range(len(values)):
            if values[i] is None:
                continue

            if isinstance(values[i], list) == False:
                values[i] = [values[i]]

            if self.metrics[i][0] == None:
                self.metrics[i][0] = np.mean(values[i])
                self.metrics[i][1] = np.mean(values[i])
                self.metrics[i][2] = np.mean(values[i])
            else:
                self.metrics[i][0] = (self.metrics[i][0] * (self.num_iters - 1) + np.mean(values[i])) / self.num_iters
                self.metrics[i][1] = 0.95 * self.metrics[i][1] + 0.05 * np.mean(values[i])
                self.metrics[i][2] = np.mean(values[i])

            self.metrics[i][0] = float(self.metrics[i][0])
            self.metrics[i][1] = float(self.metrics[i][1])
            self.metrics[i][2] = float(self.metrics[i][2])

            current_stats.append(self.metrics[i])

        self.stats.append(copy.deepcopy(current_stats))

    def get_stat_string(self, mode=1):

        stat_string = ''

        for k, v in self.info.items():
            stat_string += '[%s:%s]' % (k, v)

        stat_string += '[iters:%d]' % self.num_iters

        for i in range(len(self.metric_names)):
            stat_string += '[%s:%.05f]' % (self.metric_names[i], self.metrics[i][mode])

        return stat_string

    def dump_log(self):

        if self.log_json == None:
            return False

        dict_to_save = {
            'metric_names': self.metric_names,
            'stats': self.stats
        }

        json.dump(dict_to_save, open(self.log_json, 'w'))

        return True

class VqaMetric(Metric):
    def __init__(self, info={}, metric_names=[], log_json=None):
        Metric.__init__(self, info, metric_names, log_json)

    def compute_ranks(self, scores, labels):
        accuracy = np.zeros(len(labels))
        ranks = np.full(len(labels), scores.shape[1])

        for i in range(scores.shape[0]):
            ranks[i] = scores[i].gt(scores[i][labels[i]]).sum() + 1
            if ranks[i] == 1:
                accuracy[i] = 1

        return accuracy, ranks

class NavMetric(Metric):
    def __init__(self, info={}, metric_names=[], log_json=None):
        Metric.__init__(self, info, metric_names, log_json)


================================================
FILE: training/models.py
================================================
# Model defs for navigation and question answering
# Navigation: CNN, LSTM, Planner-controller
# VQA: question-only, 5-frame + attention

import time
import h5py
import math
import argparse
import numpy as np
import os, sys, json

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

import pdb


def build_mlp(input_dim,
              hidden_dims,
              output_dim,
              use_batchnorm=False,
              dropout=0,
              add_sigmoid=1):
    layers = []
    D = input_dim
    if dropout > 0:
        layers.append(nn.Dropout(p=dropout))
    if use_batchnorm:
        layers.append(nn.BatchNorm1d(input_dim))
    for dim in hidden_dims:
        layers.append(nn.Linear(D, dim))
        if use_batchnorm:
            layers.append(nn.BatchNorm1d(dim))
        if dropout > 0:
            layers.append(nn.Dropout(p=dropout))
        layers.append(nn.ReLU(inplace=True))
        D = dim
    layers.append(nn.Linear(D, output_dim))

    if add_sigmoid == 1:
        layers.append(nn.Sigmoid())

    return nn.Sequential(*layers)


def get_state(m):
    if m is None:
        return None
    state = {}
    for k, v in m.state_dict().items():
        state[k] = v.clone()
    return state


def repackage_hidden(h, batch_size):
    # wraps hidden states in new Variables, to detach them from their history
    if type(h) == Variable:
        return Variable(
            h.data.resize_(h.size(0), batch_size, h.size(2)).zero_())
    else:
        return tuple(repackage_hidden(v, batch_size) for v in h)


def ensure_shared_grads(model, shared_model):
    for param, shared_param in zip(model.parameters(),
                                   shared_model.parameters()):
        if shared_param.grad is not None:
            return
        shared_param._grad = param.grad


class MaskedNLLCriterion(nn.Module):
    def __init__(self):
        super(MaskedNLLCriterion, self).__init__()

    def forward(self, input, target, mask):

        logprob_select = torch.gather(input, 1, target)

        out = torch.masked_select(logprob_select, mask)

        loss = -torch.sum(out) / mask.float().sum()
        return loss

class MultitaskCNNOutput(nn.Module):
    def __init__(
            self,
            num_classes=191,
            pretrained=True,
            checkpoint_path='models/03_13_h3d_hybrid_cnn.pt'
    ):
        super(MultitaskCNNOutput, self).__init__()

        self.num_classes = num_classes
        self.conv_block1 = nn.Sequential(
            nn.Conv2d(3, 8, 5),
            nn.BatchNorm2d(8),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2))
        self.conv_block2 = nn.Sequential(
            nn.Conv2d(8, 16, 5),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2))
        self.conv_block3 = nn.Sequential(
            nn.Conv2d(16, 32, 5),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2))
        self.conv_block4 = nn.Sequential(
            nn.Conv2d(32, 32, 5),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2))
        self.classifier = nn.Sequential(
            nn.Conv2d(32, 512, 5),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(),
            nn.Conv2d(512, 512, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d())

        self.encoder_seg = nn.Conv2d(512, self.num_classes, 1)
        self.encoder_depth = nn.Conv2d(512, 1, 1)
        self.encoder_ae = nn.Conv2d(512, 3, 1)

        self.score_pool2_seg = nn.Conv2d(16, self.num_classes, 1)
        self.score_pool3_seg = nn.Conv2d(32, self.num_classes, 1)

        self.score_pool2_depth = nn.Conv2d(16, 1, 1)
        self.score_pool3_depth = nn.Conv2d(32, 1, 1)

        self.score_pool2_ae = nn.Conv2d(16, 3, 1)
        self.score_pool3_ae = nn.Conv2d(32, 3, 1)

        self.pretrained = pretrained
        if self.pretrained == True:
            print('Loading CNN weights from %s' % checkpoint_path)
            checkpoint = torch.load(
                checkpoint_path, map_location={'cuda:0': 'cpu'})
            self.load_state_dict(checkpoint['model_state'])
            for param in self.parameters():
                param.requires_grad = False
        else:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    n = m.kernel_size[0] * m.kernel_size[1] * (
                        m.out_channels + m.in_channels)
                    m.weight.data.normal_(0, math.sqrt(2. / n))
                elif isinstance(m, nn.BatchNorm2d):
                    m.weight.data.fill_(1)
                    m.bias.data.zero_()

    def forward(self, x):

        conv1 = self.conv_block1(x)
        conv2 = self.conv_block2(conv1)
        conv3 = self.conv_block3(conv2)
        conv4 = self.conv_block4(conv3)

        encoder_output = self.classifier(conv4)

        encoder_output_seg = self.encoder_seg(encoder_output)
        encoder_output_depth = self.encoder_depth(encoder_output)
        encoder_output_ae = self.encoder_ae(encoder_output)

        score_pool2_seg = self.score_pool2_seg(conv2)
        score_pool3_seg = self.score_pool3_seg(conv3)

        score_pool2_depth = self.score_pool2_depth(conv2)
        score_pool3_depth = self.score_pool3_depth(conv3)

        score_pool2_ae = self.score_pool2_ae(conv2)
        score_pool3_ae = self.score_pool3_ae(conv3)

        score_seg = F.upsample(encoder_output_seg, score_pool3_seg.size()[2:], mode='bilinear')
        score_seg += score_pool3_seg
        score_seg = F.upsample(score_seg, score_pool2_seg.size()[2:], mode='bilinear')
        score_seg += score_pool2_seg
        out_seg = F.upsample(score_seg, x.size()[2:], mode='bilinear')

        score_depth = F.upsample(encoder_output_depth, score_pool3_depth.size()[2:], mode='bilinear')
        score_depth += score_pool3_depth
        score_depth = F.upsample(score_depth, score_pool2_depth.size()[2:], mode='bilinear')
        score_depth += score_pool2_depth
        out_depth = F.sigmoid(F.upsample(score_depth, x.size()[2:], mode='bilinear'))

        score_ae = F.upsample(encoder_output_ae, score_pool3_ae.size()[2:], mode='bilinear')
        score_ae += score_pool3_ae
        score_ae = F.upsample(score_ae, score_pool2_ae.size()[2:], mode='bilinear')
        score_ae += score_pool2_ae
        out_ae = F.sigmoid(F.upsample(score_ae, x.size()[2:], mode='bilinear'))

        return out_seg, out_depth, out_ae

class MultitaskCNN(nn.Module):
    def __init__(
            self,
            num_classes=191,
            pretrained=True,
            checkpoint_path='models/03_13_h3d_hybrid_cnn.pt'
    ):
        super(MultitaskCNN, self).__init__()

        self.num_classes = num_classes
        self.conv_block1 = nn.Sequential(
            nn.Conv2d(3, 8, 5),
            nn.BatchNorm2d(8),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2))
        self.conv_block2 = nn.Sequential(
            nn.Conv2d(8, 16, 5),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2))
        self.conv_block3 = nn.Sequential(
            nn.Conv2d(16, 32, 5),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2))
        self.conv_block4 = nn.Sequential(
            nn.Conv2d(32, 32, 5),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2))
        self.classifier = nn.Sequential(
            nn.Conv2d(32, 512, 5),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(),
            nn.Conv2d(512, 512, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d())

        self.encoder_seg = nn.Conv2d(512, self.num_classes, 1)
        self.encoder_depth = nn.Conv2d(512, 1, 1)
        self.encoder_ae = nn.Conv2d(512, 3, 1)

        self.score_pool2_seg = nn.Conv2d(16, self.num_classes, 1)
        self.score_pool3_seg = nn.Conv2d(32, self.num_classes, 1)

        self.score_pool2_depth = nn.Conv2d(16, 1, 1)
        self.score_pool3_depth = nn.Conv2d(32, 1, 1)

        self.score_pool2_ae = nn.Conv2d(16, 3, 1)
        self.score_pool3_ae = nn.Conv2d(32, 3, 1)

        self.pretrained = pretrained
        if self.pretrained == True:
            print('Loading CNN weights from %s' % checkpoint_path)
            checkpoint = torch.load(
                checkpoint_path, map_location={'cuda:0': 'cpu'})
            self.load_state_dict(checkpoint['model_state'])
            for param in self.parameters():
                param.requires_grad = False
        else:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    n = m.kernel_size[0] * m.kernel_size[1] * (
                        m.out_channels + m.in_channels)
                    m.weight.data.normal_(0, math.sqrt(2. / n))
                elif isinstance(m, nn.BatchNorm2d):
                    m.weight.data.fill_(1)
                    m.bias.data.zero_()

    def forward(self, x):

        assert self.training == False

        conv1 = self.conv_block1(x)
        conv2 = self.conv_block2(conv1)
        conv3 = self.conv_block3(conv2)
        conv4 = self.conv_block4(conv3)

        return conv4.view(-1, 32 * 10 * 10)

        # encoder_output = self.classifier(conv4)

        # encoder_output_seg = self.encoder_seg(encoder_output)
        # encoder_output_depth = self.encoder_depth(encoder_output)
        # encoder_output_ae = self.encoder_ae(encoder_output)

        # score_pool2_seg = self.score_pool2_seg(conv2)
        # score_pool3_seg = self.score_pool3_seg(conv3)

        # score_pool2_depth = self.score_pool2_depth(conv2)
        # score_pool3_depth = self.score_pool3_depth(conv3)

        # score_pool2_ae = self.score_pool2_ae(conv2)
        # score_pool3_ae = self.score_pool3_ae(conv3)

        # score_seg = F.upsample(encoder_output_seg, score_pool3_seg.size()[2:], mode='bilinear')
        # score_seg += score_pool3_seg
        # score_seg = F.upsample(score_seg, score_pool2_seg.size()[2:], mode='bilinear')
        # score_seg += score_pool2_seg
        # out_seg = F.upsample(score_seg, x.size()[2:], mode='bilinear')

        # score_depth = F.upsample(encoder_output_depth, score_pool3_depth.size()[2:], mode='bilinear')
        # score_depth += score_pool3_depth
        # score_depth = F.upsample(score_depth, score_pool2_depth.size()[2:], mode='bilinear')
        # score_depth += score_pool2_depth
        # out_depth = F.sigmoid(F.upsample(score_depth, x.size()[2:], mode='bilinear'))

        # score_ae = F.upsample(encoder_output_ae, score_pool3_ae.size()[2:], mode='bilinear')
        # score_ae += score_pool3_ae
        # score_ae = F.upsample(score_ae, score_pool2_ae.size()[2:], mode='bilinear')
        # score_ae += score_pool2_ae
        # out_ae = F.sigmoid(F.upsample(score_ae, x.size()[2:], mode='bilinear'))

        # return out_seg, out_depth, out_ae


class QuestionLstmEncoder(nn.Module):
    def __init__(self,
                 token_to_idx,
                 wordvec_dim=64,
                 rnn_dim=64,
                 rnn_num_layers=2,
                 rnn_dropout=0):
        super(QuestionLstmEncoder, self).__init__()
        self.token_to_idx = token_to_idx
        self.NULL = token_to_idx['<NULL>']
        self.START = token_to_idx['<START>']
        self.END = token_to_idx['<END>']

        self.embed = nn.Embedding(len(token_to_idx), wordvec_dim)
        self.rnn = nn.LSTM(
            wordvec_dim,
            rnn_dim,
            rnn_num_layers,
            dropout=rnn_dropout,
            batch_first=True)

        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        self.embed.weight.data.uniform_(-initrange, initrange)

    def forward(self, x):
        N, T = x.size()
        idx = torch.LongTensor(N).fill_(T - 1)

        # Find the last non-null element in each sequence
        x_cpu = x.data.cpu()
        for i in range(N):
            for t in range(T - 1):
                if x_cpu[i, t] != self.NULL and x_cpu[i, t + 1] == self.NULL:
                    idx[i] = t
                    break
        idx = idx.type_as(x.data).long()
        idx = Variable(idx, requires_grad=False)

        hs, _ = self.rnn(self.embed(x))

        idx = idx.view(N, 1, 1).expand(N, 1, hs.size(2))
        H = hs.size(2)
        return hs.gather(1, idx).view(N, H)


# ----------- VQA -----------


class VqaLstmModel(nn.Module):
    def __init__(self,
                 vocab,
                 rnn_wordvec_dim=64,
                 rnn_dim=64,
                 rnn_num_layers=2,
                 rnn_dropout=0.5,
                 fc_use_batchnorm=False,
                 fc_dropout=0.5,
                 fc_dims=(64, )):
        super(VqaLstmModel, self).__init__()
        rnn_kwargs = {
            'token_to_idx': vocab['questionTokenToIdx'],
            'wordvec_dim': rnn_wordvec_dim,
            'rnn_dim': rnn_dim,
            'rnn_num_layers': rnn_num_layers,
            'rnn_dropout': rnn_dropout,
        }
        self.rnn = QuestionLstmEncoder(**rnn_kwargs)

        classifier_kwargs = {
            'input_dim': rnn_dim,
            'hidden_dims': fc_dims,
            'output_dim': len(vocab['answerTokenToIdx']),
            'use_batchnorm': fc_use_batchnorm,
            'dropout': fc_dropout,
            'add_sigmoid': 0
        }
        self.classifier = build_mlp(**classifier_kwargs)

    def forward(self, questions):
        q_feats = self.rnn(questions)
        scores = self.classifier(q_feats)
        return scores


class VqaLstmCnnAttentionModel(nn.Module):
    def __init__(self,
                 vocab,
                 image_feat_dim=64,
                 question_wordvec_dim=64,
                 question_hidden_dim=64,
                 question_num_layers=2,
                 question_dropout=0.5,
                 fc_use_batchnorm=False,
                 fc_dropout=0.5,
                 fc_dims=(64, )):
        super(VqaLstmCnnAttentionModel, self).__init__()

        cnn_kwargs = {'num_classes': 191, 'pretrained': True}
        self.cnn = MultitaskCNN(**cnn_kwargs)
        self.cnn_fc_layer = nn.Sequential(
            nn.Linear(32 * 10 * 10, 64), nn.ReLU(), nn.Dropout(p=0.5))

        q_rnn_kwargs = {
            'token_to_idx': vocab['questionTokenToIdx'],
            'wordvec_dim': question_wordvec_dim,
            'rnn_dim': question_hidden_dim,
            'rnn_num_layers': question_num_layers,
            'rnn_dropout': question_dropout,
        }
        self.q_rnn = QuestionLstmEncoder(**q_rnn_kwargs)

        self.img_tr = nn.Sequential(nn.Linear(64, 64), nn.Dropout(p=0.5))

        self.ques_tr = nn.Sequential(nn.Linear(64, 64), nn.Dropout(p=0.5))

        classifier_kwargs = {
            'input_dim': 64,
            'hidden_dims': fc_dims,
            'output_dim': len(vocab['answerTokenToIdx']),
            'use_batchnorm': fc_use_batchnorm,
            'dropout': fc_dropout,
            'add_sigmoid': 0
        }
        self.classifier = build_mlp(**classifier_kwargs)

        self.att = nn.Sequential(
            nn.Tanh(), nn.Dropout(p=0.5), nn.Linear(128, 1))

    def forward(self, images, questions):

        N, T, _, _, _ = images.size()

        # bs x 5 x 3 x 224 x 224
        img_feats = self.cnn(images.contiguous().view(
            -1, images.size(2), images.size(3), images.size(4)))
        img_feats = self.cnn_fc_layer(img_feats)

        img_feats_tr = self.img_tr(img_feats)

        ques_feats = self.q_rnn(questions)
        ques_feats_repl = ques_feats.view(N, 1, -1).repeat(1, T, 1)
        ques_feats_repl = ques_feats_repl.view(N * T, -1)

        ques_feats_tr = self.ques_tr(ques_feats_repl)

        ques_img_feats = torch.cat([ques_feats_tr, img_feats_tr], 1)

        att_feats = self.att(ques_img_feats)
        att_probs = F.softmax(att_feats.view(N, T), dim=1)
        att_probs2 = att_probs.view(N, T, 1).repeat(1, 1, 64)

        att_img_feats = torch.mul(att_probs2, img_feats.view(N, T, 64))
        att_img_feats = torch.sum(att_img_feats, dim=1)

        mul_feats = torch.mul(ques_feats, att_img_feats)

        scores = self.classifier(mul_feats)

        return scores, att_probs


# ----------- Nav -----------


class NavCnnModel(nn.Module):
    def __init__(self,
                 num_frames=5,
                 num_actions=4,
                 question_input=False,
                 question_vocab=False,
                 question_wordvec_dim=64,
                 question_hidden_dim=64,
                 question_num_layers=2,
                 question_dropout=0.5,
                 fc_use_batchnorm=False,
                 fc_dropout=0.5,
                 fc_dims=(64, )):
        super(NavCnnModel, self).__init__()

        # cnn_kwargs = {'num_classes': 191, 'pretrained': True}
        # self.cnn = MultitaskCNN(**cnn_kwargs)
        self.cnn_fc_layer = nn.Sequential(
            nn.Linear(32 * 10 * 10, 64), nn.ReLU(), nn.Dropout(p=0.5))

        self.question_input = question_input
        if self.question_input == True:
            q_rnn_kwargs = {
                'token_to_idx': question_vocab['questionTokenToIdx'],
                'wordvec_dim': question_wordvec_dim,
                'rnn_dim': question_hidden_dim,
                'rnn_num_layers': question_num_layers,
                'rnn_dropout': question_dropout,
            }
            self.q_rnn = QuestionLstmEncoder(**q_rnn_kwargs)
            self.ques_tr = nn.Sequential(
                nn.Linear(64, 64), nn.ReLU(), nn.Dropout(p=0.5))

        classifier_kwargs = {
            'input_dim': 64 * num_frames + self.question_input * 64,
            'hidden_dims': fc_dims,
            'output_dim': num_actions,
            'use_batchnorm': fc_use_batchnorm,
            'dropout': fc_dropout,
            'add_sigmoid': 0
        }
        self.classifier = build_mlp(**classifier_kwargs)

    # batch forward, for supervised learning
    def forward(self, img_feats, questions=None):

        # bs x 5 x 3200
        N, T, _ = img_feats.size()

        img_feats = self.cnn_fc_layer(img_feats)

        img_feats = img_feats.view(N, T, -1)
        img_feats = img_feats.view(N, -1)

        if self.question_input == True:
            ques_feats = self.q_rnn(questions)
            ques_feats = self.ques_tr(ques_feats)

            img_feats = torch.cat([ques_feats, img_feats], 1)

        scores = self.classifier(img_feats)

        return scores

class NavRnnMult(nn.Module):
    def __init__(self,
                 image_input=False,
                 image_feat_dim=128,
                 question_input=False,
                 question_embed_dim=128,
                 action_input=False,
                 action_embed_dim=32,
                 num_actions=4,
                 mode='sl',
                 rnn_type='LSTM',
                 rnn_hidden_dim=128,
                 rnn_num_layers=2,
                 rnn_dropout=0,
                 return_states=False):
        super(NavRnnMult, self).__init__()

        self.image_input = image_input
        self.image_feat_dim = image_feat_dim

        self.question_input = question_input
        self.question_embed_dim = question_embed_dim

        self.action_input = action_input
        self.action_embed_dim = action_embed_dim

        self.num_actions = num_actions

        self.rnn_type = rnn_type
        self.rnn_hidden_dim = rnn_hidden_dim
        self.rnn_num_layers = rnn_num_layers

        self.return_states = return_states

        rnn_input_dim = 0
        if self.image_input == True:
            rnn_input_dim += image_feat_dim
            print('Adding input to %s: image, rnn dim: %d' % (self.rnn_type,
                                                              rnn_input_dim))

        if self.question_input == True:
            #rnn_input_dim += question_embed_dim
            print('Adding input to %s: question, rnn dim: %d' %
                  (self.rnn_type, rnn_input_dim))

        if self.action_input == True:
            self.action_embed = nn.Embedding(num_actions, action_embed_dim)
            rnn_input_dim += action_embed_dim
            print('Adding input to %s: action, rnn dim: %d' % (self.rnn_type,
                                                               rnn_input_dim))

        self.rnn = getattr(nn, self.rnn_type)(
            rnn_input_dim,
            self.rnn_hidden_dim,
            self.rnn_num_layers,
            dropout=rnn_dropout,
            batch_first=True)
        print('Building %s with hidden dim: %d' % (self.rnn_type,
                                                   rnn_hidden_dim))

        self.decoder = nn.Linear(self.rnn_hidden_dim, self.num_actions)

    def init_hidden(self, bsz):
        weight = next(self.parameters()).data
        if self.rnn_type == 'LSTM':
            return (Variable(
                weight.new(self.rnn_num_layers, bsz, self.rnn_hidden_dim)
                .zero_()), Variable(
                    weight.new(self.rnn_num_layers, bsz, self.rnn_hidden_dim)
                    .zero_()))
        elif self.rnn_type == 'GRU':
            return Variable(
                weight.new(self.rnn_num_layers, bsz, self.rnn_hidden_dim)
                .zero_())

    def forward(self,
                img_feats,
                question_feats,
                actions_in,
                action_lengths,
                hidden=False):
        input_feats = Variable()

        T = False
        if self.image_input == True:
            N, T, _ = img_feats.size()
            input_feats = img_feats

        if self.question_input == True:
            N, D = question_feats.size()
            question_feats = question_feats.view(N, 1, D)
            if T == False:
                T = actions_in.size(1)
            question_feats = question_feats.repeat(1, T, 1)
            if len(input_feats) == 0:
                input_feats = question_feats
            else:
                #input_feats = torch.cat([input_feats, question_feats], 2)
                input_feats = torch.mul(input_feats, question_feats)

        if self.action_input == True:
            if len(input_feats) == 0:
                input_feats = self.action_embed(actions_in)
            else:
                input_feats = torch.cat(
                    [input_feats, self.action_embed(actions_in)], 2)

        packed_input_feats = pack_padded_sequence(
            input_feats, action_lengths, batch_first=True)
        packed_output, hidden = self.rnn(packed_input_feats)
        rnn_output, _ = pad_packed_sequence(packed_output, batch_first=True)
        output = self.decoder(rnn_output.contiguous().view(
            rnn_output.size(0) * rnn_output.size(1), rnn_output.size(2)))

        if self.return_states == True:
            return rnn_output, output, hidden
        else:
            return output, hidden

    def step_forward(self, img_feats, question_feats, actions_in, hidden):
        input_feats = Variable()

        T = False
        if self.image_input == True:
            N, T, _ = img_feats.size()
            input_feats = img_feats

        if self.question_input == True:
            N, D = question_feats.size()
            question_feats = question_feats.view(N, 1, D)
            if T == False:
                T = actions_in.size(1)
            question_feats = question_feats.repeat(1, T, 1)
            if len(input_feats) == 0:
                input_feats = question_feats
            else:
                #input_feats = torch.cat([input_feats, question_feats], 2)
                input_feats = torch.mul(input_feats, question_feats)

        if self.action_input == True:
            if len(input_feats) == 0:
                input_feats = self.action_embed(actions_in)
            else:
                input_feats = torch.cat(
                    [input_feats, self.action_embed(actions_in)], 2)

        output, hidden = self.rnn(input_feats, hidden)

        output = self.decoder(output.contiguous().view(
            output.size(0) * output.size(1), output.size(2)))

        return output, hidden


class NavRnn(nn.Module):
    def __init__(self,
                 image_input=False,
                 image_feat_dim=128,
                 question_input=False,
                 question_embed_dim=128,
                 action_input=False,
                 action_embed_dim=32,
                 num_actions=4,
                 mode='sl',
                 rnn_type='LSTM',
                 rnn_hidden_dim=128,
                 rnn_num_layers=2,
                 rnn_dropout=0,
                 return_states=False):
        super(NavRnn, self).__init__()

        self.image_input = image_input
        self.image_feat_dim = image_feat_dim

        self.question_input = question_input
        self.question_embed_dim = question_embed_dim

        self.action_input = action_input
        self.action_embed_dim = action_embed_dim

        self.num_actions = num_actions

        self.rnn_type = rnn_type
        self.rnn_hidden_dim = rnn_hidden_dim
        self.rnn_num_layers = rnn_num_layers

        self.return_states = return_states

        rnn_input_dim = 0
        if self.image_input == True:
            rnn_input_dim += image_feat_dim
            print('Adding input to %s: image, rnn dim: %d' % (self.rnn_type,
                                                              rnn_input_dim))

        if self.question_input == True:
            rnn_input_dim += question_embed_dim
            print('Adding input to %s: question, rnn dim: %d' %
                  (self.rnn_type, rnn_input_dim))

        if self.action_input == True:
            self.action_embed = nn.Embedding(num_actions, action_embed_dim)
            rnn_input_dim += action_embed_dim
            print('Adding input to %s: action, rnn dim: %d' % (self.rnn_type,
                                                               rnn_input_dim))

        self.rnn = getattr(nn, self.rnn_type)(
            rnn_input_dim,
            self.rnn_hidden_dim,
            self.rnn_num_layers,
            dropout=rnn_dropout,
            batch_first=True)
        print('Building %s with hidden dim: %d' % (self.rnn_type,
                                                   rnn_hidden_dim))

        self.decoder = nn.Linear(self.rnn_hidden_dim, self.num_actions)

    def init_hidden(self, bsz):
        weight = next(self.parameters()).data
        if self.rnn_type == 'LSTM':
            return (Variable(
                weight.new(self.rnn_num_layers, bsz, self.rnn_hidden_dim)
                .zero_()), Variable(
                    weight.new(self.rnn_num_layers, bsz, self.rnn_hidden_dim)
                    .zero_()))
        elif self.rnn_type == 'GRU':
            return Variable(
                weight.new(self.rnn_num_layers, bsz, self.rnn_hidden_dim)
                .zero_())

    def forward(self,
                img_feats,
                question_feats,
                actions_in,
                action_lengths,
                hidden=False):
        input_feats = Variable()

        T = False
        if self.image_input == True:
            N, T, _ = img_feats.size()
            input_feats = img_feats

        if self.question_input == True:
            N, D = question_feats.size()
            question_feats = question_feats.view(N, 1, D)
            if T == False:
                T = actions_in.size(1)
            question_feats = question_feats.repeat(1, T, 1)
            if len(input_feats) == 0:
                input_feats = question_feats
            else:
                input_feats = torch.cat([input_feats, question_feats], 2)

        if self.action_input == True:
            if len(input_feats) == 0:
                input_feats = self.action_embed(actions_in)
            else:
                input_feats = torch.cat(
                    [input_feats, self.action_embed(actions_in)], 2)

        packed_input_feats = pack_padded_sequence(
            input_feats, action_lengths, batch_first=True)
        packed_output, hidden = self.rnn(packed_input_feats)
        rnn_output, _ = pad_packed_sequence(packed_output, batch_first=True)
        output = self.decoder(rnn_output.contiguous().view(
            rnn_output.size(0) * rnn_output.size(1), rnn_output.size(2)))

        if self.return_states == True:
            return rnn_output, output, hidden
        else:
            return output, hidden

    def step_forward(self, img_feats, question_feats, actions_in, hidden):
        input_feats = Variable()

        T = False
        if self.image_input == True:
            N, T, _ = img_feats.size()
            input_feats = img_feats

        if self.question_input == True:
            N, D = question_feats.size()
            question_feats = question_feats.view(N, 1, D)
            if T == False:
                T = actions_in.size(1)
            question_feats = question_feats.repeat(1, T, 1)
            if len(input_feats) == 0:
                input_feats = question_feats
            else:
                input_feats = torch.cat([input_feats, question_feats], 2)

        if self.action_input == True:
            if len(input_feats) == 0:
                input_feats = self.action_embed(actions_in)
            else:
                input_feats = torch.cat(
                    [input_feats, self.action_embed(actions_in)], 2)

        output, hidden = self.rnn(input_feats, hidden)

        output = self.decoder(output.contiguous().view(
            output.size(0) * output.size(1), output.size(2)))

        return output, hidden

class NavCnnRnnMultModel(nn.Module):
    def __init__(
            self,
            num_output=4,  # forward, left, right, stop
            rnn_image_input=True,
            rnn_image_feat_dim=128,
            question_input=False,
            question_vocab=False,
            question_wordvec_dim=64,
            question_hidden_dim=64,
            question_num_layers=2,
            question_dropout=0.5,
            rnn_question_embed_dim=128,
            rnn_action_input=True,
            rnn_action_embed_dim=32,
            rnn_type='LSTM',
            rnn_hidden_dim=1024,
            rnn_num_layers=1,
            rnn_dropout=0):
        super(NavCnnRnnMultModel, self).__init__()

        self.cnn_fc_layer = nn.Sequential(
            nn.Linear(32 * 10 * 10, rnn_image_feat_dim),
            nn.ReLU(),
            nn.Dropout(p=0.5))

        self.rnn_hidden_dim = rnn_hidden_dim

        self.question_input = question_input
        if self.question_input == True:
            q_rnn_kwargs = {
                'token_to_idx': question_vocab['questionTokenToIdx'],
                'wordvec_dim': question_wordvec_dim,
                'rnn_dim': question_hidden_dim,
                'rnn_num_layers': question_num_layers,
                'rnn_dropout': question_dropout,
            }
            self.q_rnn = QuestionLstmEncoder(**q_rnn_kwargs)
            self.ques_tr = nn.Sequential(
                nn.Linear(64, rnn_image_feat_dim), nn.ReLU(), nn.Dropout(p=0.5))

        self.nav_rnn = NavRnnMult(
            image_input=rnn_image_input,
            image_feat_dim=rnn_image_feat_dim,
            question_input=question_input,
            question_embed_dim=question_hidden_dim,
            action_input=rnn_action_input,
            action_embed_dim=rnn_action_embed_dim,
            num_actions=num_output,
            rnn_type=rnn_type,
            rnn_hidden_dim=rnn_hidden_dim,
            rnn_num_layers=rnn_num_layers,
            rnn_dropout=rnn_dropout)

    def forward(self,
                img_feats,
                questions,
                actions_in,
                action_lengths,
                hidden=False,
                step=False):
        N, T, _ = img_feats.size()

        # B x T x 128
        img_feats = self.cnn_fc_layer(img_feats)

        if self.question_input == True:
            ques_feats = self.q_rnn(questions)
            ques_feats = self.ques_tr(ques_feats)

            if step == True:
                output, hidden = self.nav_rnn.step_forward(
                    img_feats, ques_feats, actions_in, hidden)
            else:
                output, hidden = self.nav_rnn(img_feats, ques_feats,
                                              actions_in, action_lengths)
        else:
            if step == True:
                output, hidden = self.nav_rnn.step_forward(
                    img_feats, False, actions_in, hidden)
            else:
                output, hidden = self.nav_rnn(img_feats, False, actions_in,
                                              action_lengths)

        return output, hidden


class NavCnnRnnModel(nn.Module):
    def __init__(
            self,
            num_output=4,  # forward, left, right, stop
            rnn_image_input=True,
            rnn_image_feat_dim=128,
            question_input=False,
            question_vocab=False,
            question_wordvec_dim=64,
            question_hidden_dim=64,
            question_num_layers=2,
            question_dropout=0.5,
            rnn_question_embed_dim=128,
            rnn_action_input=True,
            rnn_action_embed_dim=32,
            rnn_type='LSTM',
            rnn_hidden_dim=1024,
            rnn_num_layers=1,
            rnn_dropout=0):
        super(NavCnnRnnModel, self).__init__()

        self.cnn_fc_layer = nn.Sequential(
            nn.Linear(32 * 10 * 10, rnn_image_feat_dim),
            nn.ReLU(),
            nn.Dropout(p=0.5))

        self.rnn_hidden_dim = rnn_hidden_dim

        self.question_input = question_input
        if self.question_input == True:
            q_rnn_kwargs = {
                'token_to_idx': question_vocab['questionTokenToIdx'],
                'wordvec_dim': question_wordvec_dim,
                'rnn_dim': question_hidden_dim,
                'rnn_num_layers': question_num_layers,
                'rnn_dropout': question_dropout,
            }
            self.q_rnn = QuestionLstmEncoder(**q_rnn_kwargs)
            self.ques_tr = nn.Sequential(
                nn.Linear(64, 64), nn.ReLU(), nn.Dropout(p=0.5))

        self.nav_rnn = NavRnn(
            image_input=rnn_image_input,
            image_feat_dim=rnn_image_feat_dim,
            question_input=question_input,
            question_embed_dim=question_hidden_dim,
            action_input=rnn_action_input,
            action_embed_dim=rnn_action_embed_dim,
            num_actions=num_output,
            rnn_type=rnn_type,
            rnn_hidden_dim=rnn_hidden_dim,
            rnn_num_layers=rnn_num_layers,
            rnn_dropout=rnn_dropout)

    def forward(self,
                img_feats,
                questions,
                actions_in,
                action_lengths,
                hidden=False,
                step=False):
        N, T, _ = img_feats.size()

        # B x T x 128
        img_feats = self.cnn_fc_layer(img_feats)

        if self.question_input == True:
            ques_feats = self.q_rnn(questions)
            ques_feats = self.ques_tr(ques_feats)

            if step == True:
                output, hidden = self.nav_rnn.step_forward(
                    img_feats, ques_feats, actions_in, hidden)
            else:
                output, hidden = self.nav_rnn(img_feats, ques_feats,
                                              actions_in, action_lengths)
        else:
            if step == True:
                output, hidden = self.nav_rnn.step_forward(
                    img_feats, False, actions_in, hidden)
            else:
                output, hidden = self.nav_rnn(img_feats, False, actions_in,
                                              action_lengths)

        return output, hidden


class NavPlannerControllerModel(nn.Module):
    def __init__(self,
                 question_vocab,
                 num_output=4,
                 question_wordvec_dim=64,
                 question_hidden_dim=64,
                 question_num_layers=2,
                 question_dropout=0.5,
                 planner_rnn_image_feat_dim=128,
                 planner_rnn_action_embed_dim=32,
                 planner_rnn_type='GRU',
                 planner_rnn_hidden_dim=1024,
                 planner_rnn_num_layers=1,
                 planner_rnn_dropout=0,
                 controller_fc_dims=(256, )):
        super(NavPlannerControllerModel, self).__init__()

        self.cnn_fc_layer = nn.Sequential(
            nn.Linear(32 * 10 * 10, planner_rnn_image_feat_dim),
            nn.ReLU(),
            nn.Dropout(p=0.5))

        q_rnn_kwargs = {
            'token_to_idx': question_vocab['questionTokenToIdx'],
            'wordvec_dim': question_wordvec_dim,
            'rnn_dim': question_hidden_dim,
            'rnn_num_layers': question_num_layers,
            'rnn_dropout': question_dropout,
        }
        self.q_rnn = QuestionLstmEncoder(**q_rnn_kwargs)
        self.ques_tr = nn.Sequential(
            nn.Linear(question_hidden_dim, question_hidden_dim),
            nn.ReLU(),
            nn.Dropout(p=0.5))

        self.planner_nav_rnn = NavRnn(
            image_input=True,
            image_feat_dim=planner_rnn_image_feat_dim,
            question_input=True,
            question_embed_dim=question_hidden_dim,
            action_input=True,
            action_embed_dim=planner_rnn_action_embed_dim,
            num_actions=num_output,
            rnn_type=planner_rnn_type,
            rnn_hidden_dim=planner_rnn_hidden_dim,
            rnn_num_layers=planner_rnn_num_layers,
            rnn_dropout=planner_rnn_dropout,
            return_states=True)

        controller_kwargs = {
            'input_dim':
            planner_rnn_image_feat_dim + planner_rnn_action_embed_dim +
            planner_rnn_hidden_dim,
            'hidden_dims':
            controller_fc_dims,
            'output_dim':
            2,
            'add_sigmoid':
            0
        }
        self.controller = build_mlp(**controller_kwargs)

    def forward(self,
                questions,
                planner_img_feats,
                planner_actions_in,
                planner_action_lengths,
                planner_hidden_index,
                controller_img_feats,
                controller_actions_in,
                controller_action_lengths,
                planner_hidden=False):

        # ts = time.time()
        N_p, T_p, _ = planner_img_feats.size()

        planner_img_feats = self.cnn_fc_layer(planner_img_feats)
        controller_img_feats = self.cnn_fc_layer(controller_img_feats)

        ques_feats = self.q_rnn(questions)
        ques_feats = self.ques_tr(ques_feats)

        planner_states, planner_scores, planner_hidden = self.planner_nav_rnn(
            planner_img_feats, ques_feats, planner_actions_in,
            planner_action_lengths)

        planner_hidden_index = planner_hidden_index[:, :
                                                    controller_action_lengths.
                                                    max()]
        controller_img_feats = controller_img_feats[:, :
                                                    controller_action_lengths.
                                                    max()]
        controller_actions_in = controller_actions_in[:, :
                                                      controller_action_lengths.
                                                      max()]

        N_c, T_c, _ = controller_img_feats.size()

        assert planner_hidden_index.max().data[0] < planner_states.size(1)

        planner_hidden_index = planner_hidden_index.contiguous().view(
            N_p, planner_hidden_index.size(1), 1).repeat(
                1, 1, planner_states.size(2))

        controller_hidden_in = planner_states.gather(1, planner_hidden_index)
        controller_hidden_in = controller_hidden_in.view(
            N_c * T_c, controller_hidden_in.size(2))

        controller_img_feats = controller_img_feats.contiguous().view(
            N_c * T_c, -1)
        controller_actions_embed = self.planner_nav_rnn.action_embed(
            controller_actions_in).view(N_c * T_c, -1)

        controller_in = torch.cat([
            controller_img_feats, controller_actions_embed,
            controller_hidden_in
        ], 1)
        controller_scores = self.controller(controller_in)

        return planner_scores, controller_scores, planner_hidden

    def planner_step(self, questions, img_feats, actions_in, planner_hidden):

        img_feats = self.cnn_fc_layer(img_feats)
        ques_feats = self.q_rnn(questions)
        ques_feats = self.ques_tr(ques_feats)
        planner_scores, planner_hidden = self.planner_nav_rnn.step_forward(
            img_feats, ques_feats, actions_in, planner_hidden)

        return planner_scores, planner_hidden

    def controller_step(self, img_feats, actions_in, hidden_in):

        img_feats = self.cnn_fc_layer(img_feats)
        actions_embed = self.planner_nav_rnn.action_embed(actions_in)

        img_feats = img_feats.view(1, -1)
        actions_embed = actions_embed.view(1, -1)
        hidden_in = hidden_in.view(1, -1)

        controller_in = torch.cat([img_feats, actions_embed, hidden_in], 1)
        controller_scores = self.controller(controller_in)

        return controller_scores


================================================
FILE: training/train_eqa.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import h5py
import time
import argparse
import numpy as np
import os, sys, json
from tqdm import tqdm

import torch
import torch.nn.functional as F
from torch.autograd import Variable
torch.backends.cudnn.enabled = False
import torch.multiprocessing as mp

from models import NavCnnModel, NavCnnRnnModel, NavPlannerControllerModel, VqaLstmCnnAttentionModel
from data import EqaDataset, EqaDataLoader
from metrics import NavMetric, VqaMetric

from models import MaskedNLLCriterion

from models import get_state, repackage_hidden, ensure_shared_grads
from data import load_vocab, flat_to_hierarchical_actions

def eval(rank, args, shared_nav_model, shared_ans_model):

    torch.cuda.set_device(args.gpus.index(args.gpus[rank % len(args.gpus)]))

    if args.model_type == 'pacman':

        model_kwargs = {'question_vocab': load_vocab(args.vocab_json)}
        nav_model = NavPlannerControllerModel(**model_kwargs)

    else:

        exit()

    model_kwargs = {'vocab': load_vocab(args.vocab_json)}
    ans_model = VqaLstmCnnAttentionModel(**model_kwargs)

    eval_loader_kwargs = {
        'questions_h5': getattr(args, args.eval_split + '_h5'),
        'data_json': args.data_json,
        'vocab': args.vocab_json,
        'target_obj_conn_map_dir': args.target_obj_conn_map_dir,
        'map_resolution': args.map_resolution,
        'batch_size': 1,
        'input_type': args.model_type,
        'num_frames': 5,
        'split': args.eval_split,
        'max_threads_per_gpu': args.max_threads_per_gpu,
        'gpu_id': args.gpus[rank % len(args.gpus)],
        'to_cache': False,
        'max_controller_actions': args.max_controller_actions,
        'max_actions': args.max_actions
    }

    eval_loader = EqaDataLoader(**eval_loader_kwargs)
    print('eval_loader has %d samples' % len(eval_loader.dataset))

    args.output_nav_log_path = os.path.join(args.log_dir,
                                            'nav_eval_' + str(rank) + '.json')
    args.output_ans_log_path = os.path.join(args.log_dir,
                                            'ans_eval_' + str(rank) + '.json')

    t, epoch, best_eval_acc = 0, 0, 0.0

    while epoch < int(args.max_epochs):

        start_time = time.time()
        invalids = []

        nav_model.load_state_dict(shared_nav_model.state_dict())
        nav_model.eval()

        ans_model.load_state_dict(shared_ans_model.state_dict())
        ans_model.eval()
        ans_model.cuda()

        # that's a lot of numbers
        nav_metrics = NavMetric(
            info={'split': args.eval_split,
                  'thread': rank},
            metric_names=[
                'd_0_10', 'd_0_30', 'd_0_50', 'd_T_10', 'd_T_30', 'd_T_50',
                'd_D_10', 'd_D_30', 'd_D_50', 'd_min_10', 'd_min_30',
                'd_min_50', 'r_T_10', 'r_T_30', 'r_T_50', 'r_e_10', 'r_e_30',
                'r_e_50', 'stop_10', 'stop_30', 'stop_50', 'ep_len_10',
                'ep_len_30', 'ep_len_50'
            ],
            log_json=args.output_nav_log_path)

        vqa_metrics = VqaMetric(
            info={'split': args.eval_split,
                  'thread': rank},
            metric_names=[
                'accuracy_10', 'accuracy_30', 'accuracy_50', 'mean_rank_10',
                'mean_rank_30', 'mean_rank_50', 'mean_reciprocal_rank_10',
                'mean_reciprocal_rank_30', 'mean_reciprocal_rank_50'
            ],
            log_json=args.output_ans_log_path)

        if 'pacman' in args.model_type:

            done = False

            while done == False:

                for batch in tqdm(eval_loader):

                    nav_model.load_state_dict(shared_nav_model.state_dict())
                    nav_model.eval()
                    nav_model.cuda()

                    idx, question, answer, actions, action_length = batch
                    metrics_slug = {}

                    h3d = eval_loader.dataset.episode_house

                    # evaluate at multiple initializations
                    for i in [10, 30, 50]:

                        t += 1

                        if i > action_length[0]:
                            invalids.append([idx[0], i])
                            continue

                        question_var = Variable(question.cuda())

                        controller_step = False
                        planner_hidden = nav_model.planner_nav_rnn.init_hidden(
                            1)

                        # forward through planner till spawn
                        (
                            planner_actions_in, planner_img_feats,
                            controller_step, controller_action_in,
                            controller_img_feat, init_pos,
                            controller_action_counter
                        ) = eval_loader.dataset.get_hierarchical_features_till_spawn(
                            actions[0, :action_length[0] + 1].numpy(), i, args.max_controller_actions
                        )

                        planner_actions_in_var = Variable(
                            planner_actions_in.cuda())
                        planner_img_feats_var = Variable(
                            planner_img_feats.cuda())

                        for step in range(planner_actions_in.size(0)):

                            planner_scores, planner_hidden = nav_model.planner_step(
                                question_var, planner_img_feats_var[step].view(
                                    1, 1,
                                    3200), planner_actions_in_var[step].view(
                                        1, 1), planner_hidden)

                        if controller_step == True:

                            controller_img_feat_var = Variable(
                                controller_img_feat.cuda())
                            controller_action_in_var = Variable(
                                torch.LongTensor(1, 1).fill_(
                                    int(controller_action_in)).cuda())

                            controller_scores = nav_model.controller_step(
                                controller_img_feat_var.view(1, 1, 3200),
                                controller_action_in_var.view(1, 1),
                                planner_hidden[0])

                            prob = F.softmax(controller_scores, dim=1)
                            controller_action = int(
                                prob.max(1)[1].data.cpu().numpy()[0])

                            if controller_action == 1:
                                controller_step = True
                            else:
                                controller_step = False

                            action = int(controller_action_in)
                            action_in = torch.LongTensor(
                                1, 1).fill_(action + 1).cuda()

                        else:

                            prob = F.softmax(planner_scores, dim=1)
                            action = int(prob.max(1)[1].data.cpu().numpy()[0])

                            action_in = torch.LongTensor(
                                1, 1).fill_(action + 1).cuda()

                        h3d.env.reset(
                            x=init_pos[0], y=init_pos[2], yaw=init_pos[3])

                        init_dist_to_target = h3d.get_dist_to_target(
                            h3d.env.cam.pos)
                        if init_dist_to_target < 0:  # unreachable
                            invalids.append([idx[0], i])
                            continue

                        episode_length = 0
                        episode_done = True
                        controller_action_counter = 0

                        dists_to_target, pos_queue, pred_actions = [
                            init_dist_to_target
                        ], [init_pos], []
                        planner_actions, controller_actions = [], []

                        if action != 3:

                            # take the first step
                            img, _, _ = h3d.step(action)
                            img = torch.from_numpy(img.transpose(
                                2, 0, 1)).float() / 255.0
                            img_feat_var = eval_loader.dataset.cnn(
                                Variable(img.view(1, 3, 224,
                                                  224).cuda())).view(
                                                      1, 1, 3200)

                            for step in range(args.max_episode_length):

                                episode_length += 1

                                if controller_step == False:
                                    planner_scores, planner_hidden = nav_model.planner_step(
                                        question_var, img_feat_var,
                                        Variable(action_in), planner_hidden)

                                    prob = F.softmax(planner_scores, dim=1)
                                    action = int(
                                        prob.max(1)[1].data.cpu().numpy()[0])
                                    planner_actions.append(action)

                                pred_actions.append(action)
                                img, _, episode_done = h3d.step(action)

                                episode_done = episode_done or episode_length >= args.max_episode_length

                                img = torch.from_numpy(img.transpose(
                                    2, 0, 1)).float() / 255.0
                                img_feat_var = eval_loader.dataset.cnn(
                                    Variable(img.view(1, 3, 224, 224)
                                             .cuda())).view(1, 1, 3200)

                                dists_to_target.append(
                                    h3d.get_dist_to_target(h3d.env.cam.pos))
                                pos_queue.append([
                                    h3d.env.cam.pos.x, h3d.env.cam.pos.y,
                                    h3d.env.cam.pos.z, h3d.env.cam.yaw
                                ])

                                if episode_done == True:
                                    break

                                # query controller to continue or not
                                controller_action_in = Variable(
                                    torch.LongTensor(1,
                                                     1).fill_(action).cuda())
                                controller_scores = nav_model.controller_step(
                                    img_feat_var, controller_action_in,
                                    planner_hidden[0])

                                prob = F.softmax(controller_scores, dim=1)
                                controller_action = int(
                                    prob.max(1)[1].data.cpu().numpy()[0])

                                if controller_action == 1 and controller_action_counter < 4:
                                    controller_action_counter += 1
                                    controller_step = True
                                else:
                                    controller_action_counter = 0
                                    controller_step = False
                                    controller_action = 0

                                controller_actions.append(controller_action)

                                action_in = torch.LongTensor(
                                    1, 1).fill_(action + 1).cuda()

                        # run answerer here
                        if len(pos_queue) < 5:
                            pos_queue = eval_loader.dataset.episode_pos_queue[len(
                                pos_queue) - 5:] + pos_queue
                        images = eval_loader.dataset.get_frames(
                            h3d, pos_queue[-5:], preprocess=True)
                        images_var = Variable(
                            torch.from_numpy(images).cuda()).view(
                                1, 5, 3, 224, 224)
                        scores, att_probs = ans_model(images_var, question_var)
                        ans_acc, ans_rank = vqa_metrics.compute_ranks(
                            scores.data.cpu(), answer)

                        pred_answer = scores.max(1)[1].data[0]

                        print('[Q_GT]', ' '.join([
                            eval_loader.dataset.vocab['questionIdxToToken'][x]
                            for x in question[0] if x != 0
                        ]))
                        print('[A_GT]', eval_loader.dataset.vocab[
                            'answerIdxToToken'][answer[0]])
                        print('[A_PRED]', eval_loader.dataset.vocab[
                            'answerIdxToToken'][pred_answer])

                        # compute stats
                        metrics_slug['accuracy_' + str(i)] = ans_acc[0]
                        metrics_slug['mean_rank_' + str(i)] = ans_rank[0]
                        metrics_slug['mean_reciprocal_rank_'
                                     + str(i)] = 1.0 / ans_rank[0]

                        metrics_slug['d_0_' + str(i)] = dists_to_target[0]
                        metrics_slug['d_T_' + str(i)] = dists_to_target[-1]
                        metrics_slug['d_D_' + str(
                            i)] = dists_to_target[0] - dists_to_target[-1]
                        metrics_slug['d_min_' + str(i)] = np.array(
                            dists_to_target).min()
                        metrics_slug['ep_len_' + str(i)] = episode_length
                        if action == 3:
                            metrics_slug['stop_' + str(i)] = 1
                        else:
                            metrics_slug['stop_' + str(i)] = 0
                        inside_room = []
                        for p in pos_queue:
                            inside_room.append(
                                h3d.is_inside_room(
                                    p, eval_loader.dataset.target_room))
                        if inside_room[-1] == True:
                            metrics_slug['r_T_' + str(i)] = 1
                        else:
                            metrics_slug['r_T_' + str(i)] = 0
                        if any([x == True for x in inside_room]) == True:
                            metrics_slug['r_e_' + str(i)] = 1
                        else:
                            metrics_slug['r_e_' + str(i)] = 0

                    # navigation metrics
                    metrics_list = []
                    for i in nav_metrics.metric_names:
                        if i not in metrics_slug:
                            metrics_list.append(nav_metrics.metrics[
                                nav_metrics.metric_names.index(i)][0])
                        else:
                            metrics_list.append(metrics_slug[i])

                    nav_metrics.update(metrics_list)

                    # vqa metrics
                    metrics_list = []
                    for i in vqa_metrics.metric_names:
                        if i not in metrics_slug:
                            metrics_list.append(vqa_metrics.metrics[
                                vqa_metrics.metric_names.index(i)][0])
                        else:
                            metrics_list.append(metrics_slug[i])

                    vqa_metrics.update(metrics_list)

                try:
                    print(nav_metrics.get_stat_string(mode=0))
                    print(vqa_metrics.get_stat_string(mode=0))
                except:
                    pass

                print('epoch', epoch)
                print('invalids', len(invalids))

                eval_loader.dataset._load_envs()
                if len(eval_loader.dataset.pruned_env_set) == 0:
                    done = True

        epoch += 1

        # checkpoint if best val accuracy
        if vqa_metrics.metrics[2][0] > best_eval_acc:  # ans_acc_50
            best_eval_acc = vqa_metrics.metrics[2][0]
            if epoch % args.eval_every == 0 and args.log == True:
                vqa_metrics.dump_log()
                nav_metrics.dump_log()

                model_state = get_state(nav_model)

                aad = dict(args.__dict__)
                ad = {}
                for i in aad:
                    if i[0] != '_':
                        ad[i] = aad[i]

                checkpoint = {'args': ad, 'state': model_state, 'epoch': epoch}

                checkpoint_path = '%s/epoch_%d_ans_50_%.04f.pt' % (
                    args.checkpoint_dir, epoch, best_eval_acc)
                print('Saving checkpoint to %s' % checkpoint_path)
                torch.save(checkpoint, checkpoint_path)

        print('[best_eval_ans_acc_50:%.04f]' % best_eval_acc)

        eval_loader.dataset._load_envs(start_idx=0, in_order=True)


def train(rank, args, shared_nav_model, shared_ans_model):

    torch.cuda.set_device(args.gpus.index(args.gpus[rank % len(args.gpus)]))

    if args.model_type == 'pacman':

        model_kwargs = {'question_vocab': load_vocab(args.vocab_json)}
        nav_model = NavPlannerControllerModel(**model_kwargs)

    else:

        exit()

    model_kwargs = {'vocab': load_vocab(args.vocab_json)}
    ans_model = VqaLstmCnnAttentionModel(**model_kwargs)

    optim = torch.optim.SGD(
        filter(lambda p: p.requires_grad, shared_nav_model.parameters()),
        lr=args.learning_rate)

    train_loader_kwargs = {
        'questions_h5': args.train_h5,
        'data_json': args.data_json,
        'vocab': args.vocab_json,
        'target_obj_conn_map_dir': args.target_obj_conn_map_dir,
        'map_resolution': args.map_resolution,
        'batch_size': 1,
        'input_type': args.model_type,
        'num_frames': 5,
        'split': 'train',
        'max_threads_per_gpu': args.max_threads_per_gpu,
        'gpu_id': args.gpus[rank % len(args.gpus)],
        'to_cache': args.cache,
        'max_controller_actions': args.max_controller_actions,
        'max_actions': args.max_actions
    }

    args.output_nav_log_path = os.path.join(args.log_dir,
                                            'nav_train_' + str(rank) + '.json')
    args.output_ans_log_path = os.path.join(args.log_dir,
                                            'ans_train_' + str(rank) + '.json')

    nav_model.load_state_dict(shared_nav_model.state_dict())
    nav_model.cuda()

    ans_model.load_state_dict(shared_ans_model.state_dict())
    ans_model.eval()
    ans_model.cuda()

    nav_metrics = NavMetric(
        info={'split': 'train',
              'thread': rank},
        metric_names=[
            'planner_loss', 'controller_loss', 'reward', 'episode_length'
        ],
        log_json=args.output_nav_log_path)

    vqa_metrics = VqaMetric(
        info={'split': 'train',
              'thread': rank},
        metric_names=['accuracy', 'mean_rank', 'mean_reciprocal_rank'],
        log_json=args.output_ans_log_path)

    train_loader = EqaDataLoader(**train_loader_kwargs)

    print('train_loader has %d samples' % len(train_loader.dataset))

    t, epoch = 0, 0
    p_losses, c_losses, reward_list, episode_length_list = [], [], [], []

    nav_metrics.update([10.0, 10.0, 0, 100])

    mult = 0.1

    while epoch < int(args.max_epochs):

        if 'pacman' in args.model_type:

            planner_lossFn = MaskedNLLCriterion().cuda()
            controller_lossFn = MaskedNLLCriterion().cuda()

            done = False
            all_envs_loaded = train_loader.dataset._check_if_all_envs_loaded()

            while done == False:

                for batch in train_loader:

                    nav_model.load_state_dict(shared_nav_model.state_dict())
                    nav_model.eval()
                    nav_model.cuda()

                    idx, question, answer, actions, action_length = batch
                    metrics_slug = {}

                    h3d = train_loader.dataset.episode_house

                    # evaluate at multiple initializations
                    # for i in [10, 30, 50]:

                    t += 1

                    question_var = Variable(question.cuda())

                    controller_step = False
                    planner_hidden = nav_model.planner_nav_rnn.init_hidden(1)

                    # forward through planner till spawn
                    (
                        planner_actions_in, planner_img_feats,
                        controller_step, controller_action_in,
                        controller_img_feat, init_pos,
                        controller_action_counter
                    ) = train_loader.dataset.get_hierarchical_features_till_spawn(
                        actions[0, :action_length[0] + 1].numpy(), max(3, int(mult * action_length[0])), args.max_controller_actions
                    )

                    planner_actions_in_var = Variable(
                        planner_actions_in.cuda())
                    planner_img_feats_var = Variable(planner_img_feats.cuda())

                    for step in range(planner_actions_in.size(0)):

                        planner_scores, planner_hidden = nav_model.planner_step(
                            question_var, planner_img_feats_var[step].view(
                                1, 1, 3200), planner_actions_in_var[step].view(
                                    1, 1), planner_hidden)

                    if controller_step == True:

                        controller_img_feat_var = Variable(
                            controller_img_feat.cuda())
                        controller_action_in_var = Variable(
                            torch.LongTensor(1, 1).fill_(
                                int(controller_action_in)).cuda())

                        controller_scores = nav_model.controller_step(
                            controller_img_feat_var.view(1, 1, 3200),
                            controller_action_in_var.view(1, 1),
                            planner_hidden[0])

                        prob = F.softmax(controller_scores, dim=1)
                        controller_action = int(
                            prob.max(1)[1].data.cpu().numpy()[0])

                        if controller_action == 1:
                            controller_step = True
                        else:
                            controller_step = False

                        action = int(controller_action_in)
                        action_in = torch.LongTensor(
                            1, 1).fill_(action + 1).cuda()

                    else:

                        prob = F.softmax(planner_scores, dim=1)
                        action = int(prob.max(1)[1].data.cpu().numpy()[0])

                        action_in = torch.LongTensor(
                            1, 1).fill_(action + 1).cuda()

                    h3d.env.reset(
                        x=init_pos[0], y=init_pos[2], yaw=init_pos[3])

                    init_dist_to_target = h3d.get_dist_to_target(
                        h3d.env.cam.pos)
                    if init_dist_to_target < 0:  # unreachable
                        # invalids.append([idx[0], i])
                        continue

                    episode_length = 0
                    episode_done = True
                    controller_action_counter = 0

                    dists_to_target, pos_queue = [init_dist_to_target], [
                        init_pos
                    ]

                    rewards, planner_actions, planner_log_probs, controller_actions, controller_log_probs = [], [], [], [], []

                    if action != 3:

                        # take the first step
                        img, rwd, episode_done = h3d.step(action, step_reward=True)
                        img = torch.from_numpy(img.transpose(
                            2, 0, 1)).float() / 255.0
                        img_feat_var = train_loader.dataset.cnn(
                            Variable(img.view(1, 3, 224, 224).cuda())).view(
                                1, 1, 3200)

                        for step in range(args.max_episode_length):

                            episode_length += 1

                            if controller_step == False:
                                planner_scores, planner_hidden = nav_model.planner_step(
                                    question_var, img_feat_var,
                                    Variable(action_in), planner_hidden)

                                planner_prob = F.softmax(planner_scores, dim=1)
                                planner_log_prob = F.log_softmax(
                                    planner_scores, dim=1)

                                action = planner_prob.multinomial().data
                                planner_log_prob = planner_log_prob.gather(
                                    1, Variable(action))

                                planner_log_probs.append(
                                    planner_log_prob.cpu())

                                action = int(action.cpu().numpy()[0, 0])
                                planner_actions.append(action)

                            img, rwd, episode_done = h3d.step(action, step_reward=True)

                            episode_done = episode_done or episode_length >= args.max_episode_length

                            rewards.append(rwd)

                            img = torch.from_numpy(img.transpose(
                                2, 0, 1)).float() / 255.0
                            img_feat_var = train_loader.dataset.cnn(
                                Variable(img.view(1, 3, 224, 224)
                                         .cuda())).view(1, 1, 3200)

                            dists_to_target.append(
                                h3d.get_dist_to_target(h3d.env.cam.pos))
                            pos_queue.append([
                                h3d.env.cam.pos.x, h3d.env.cam.pos.y,
                                h3d.env.cam.pos.z, h3d.env.cam.yaw
                            ])

                            if episode_done == True:
                                break

                            # query controller to continue or not
                            controller_action_in = Variable(
                                torch.LongTensor(1, 1).fill_(action).cuda())
                            controller_scores = nav_model.controller_step(
                                img_feat_var, controller_action_in,
                                planner_hidden[0])

                            controller_prob = F.softmax(
                                controller_scores, dim=1)
                            controller_log_prob = F.log_softmax(
                                controller_scores, dim=1)

                            controller_action = controller_prob.multinomial(
                            ).data

                            if int(controller_action[0]
                                   ) == 1 and controller_action_counter < 4:
                                controller_action_counter += 1
                                controller_step = True
                            else:
                                controller_action_counter = 0
                                controller_step = False
                                controller_action.fill_(0)

                            controller_log_prob = controller_log_prob.gather(
                                1, Variable(controller_action))
                            controller_log_probs.append(
                                controller_log_prob.cpu())

                            controller_action = int(
                                controller_action.cpu().numpy()[0, 0])
                            controller_actions.append(controller_action)
                            action_in = torch.LongTensor(
                                1, 1).fill_(action + 1).cuda()

                    # run answerer here
                    ans_acc = [0]
                    if action == 3:
                        if len(pos_queue) < 5:
                            pos_queue = train_loader.dataset.episode_pos_queue[len(
                                pos_queue) - 5:] + pos_queue
                        images = train_loader.dataset.get_frames(
                            h3d, pos_queue[-5:], preprocess=True)
                        images_var = Variable(
                            torch.from_numpy(images).cuda()).view(
                                1, 5, 3, 224, 224)
                        scores, att_probs = ans_model(images_var, question_var)
                        ans_acc, ans_rank = vqa_metrics.compute_ranks(
                            scores.data.cpu(), answer)
                        vqa_metrics.update([ans_acc, ans_rank, 1.0 / ans_rank])

                    rewards.append(h3d.success_reward * ans_acc[0])

                    R = torch.zeros(1, 1)

                    planner_loss = 0
                    controller_loss = 0

                    planner_rev_idx = -1
                    for i in reversed(range(len(rewards))):
                        R = 0.99 * R + rewards[i]
                        advantage = R - nav_metrics.metrics[2][1]

                        if i < len(controller_actions):
                            controller_loss = controller_loss - controller_log_probs[i] * Variable(
                                advantage)

                            if controller_actions[i] == 0 and planner_rev_idx + len(planner_log_probs) >= 0:
                                planner_loss = planner_loss - planner_log_probs[planner_rev_idx] * Variable(
                                    advantage)
                                planner_rev_idx -= 1

                        elif planner_rev_idx + len(planner_log_probs) >= 0:

                            planner_loss = planner_loss - planner_log_probs[planner_rev_idx] * Variable(
                                advantage)
                            planner_rev_idx -= 1

                    controller_loss /= max(1, len(controller_log_probs))
                    planner_loss /= max(1, len(planner_log_probs))

                    optim.zero_grad()

                    if isinstance(planner_loss, float) == False and isinstance(
                            controller_loss, float) == False:
                        p_losses.append(planner_loss.data[0, 0])
                        c_losses.append(controller_loss.data[0, 0])
                        reward_list.append(np.sum(rewards))
                        episode_length_list.append(episode_length)

                        (planner_loss + controller_loss).backward()

                        ensure_shared_grads(nav_model.cpu(), shared_nav_model)
                        optim.step()

                    if len(reward_list) > 50:

                        nav_metrics.update([
                            p_losses, c_losses, reward_list,
                            episode_length_list
                        ])

                        print(nav_metrics.get_stat_string())
                        if args.log == True:
                            nav_metrics.dump_log()

                        if nav_metrics.metrics[2][1] > 0.35:
                            mult = min(mult + 0.1, 1.0)

                        p_losses, c_losses, reward_list, episode_length_list = [], [], [], []

                if all_envs_loaded == False:
                    train_loader.dataset._load_envs(in_order=True)
                    if len(train_loader.dataset.pruned_env_set) == 0:
                        done = True
                        if args.cache == False:
                            train_loader.dataset._load_envs(
                                start_idx=0, in_order=True)
                else:
                    done = True

        epoch += 1


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # data params
    parser.add_argument('-train_h5', default='data/train.h5')
    parser.add_argument('-val_h5', default='data/val.h5')
    parser.add_argument('-test_h5', default='data/test.h5')
    parser.add_argument('-data_json', default='data/data.json')
    parser.add_argument('-vocab_json', default='data/vocab.json')

    parser.add_argument(
        '-target_obj_conn_map_dir',
        default='/path/to/target-obj-conn-maps/500')
    parser.add_argument('-map_resolution', default=500, type=int)

    parser.add_argument(
        '-mode',
        default='train+eval',
        type=str,
        choices=['train', 'eval', 'train+eval'])
    parser.add_argument('-eval_split', default='val', type=str)

    # model details
    parser.add_argument(
        '-model_type',
        default='pacman',
        choices=['cnn', 'cnn+q', 'lstm', 'lstm+q', 'pacman'])
    parser.add_argument('-max_episode_length', default=100, type=int)

    # optim params
    parser.add_argument('-batch_size', default=20, type=int)
    parser.add_argument('-learning_rate', default=1e-5, type=float)
    parser.add_argument('-max_epochs', default=1000, type=int)

    # bookkeeping
    parser.add_argument('-print_every', default=5, type=int)
    parser.add_argument('-eval_every', default=1, type=int)
    parser.add_argument('-identifier', default='cnn')
    parser.add_argument('-num_processes', default=1, type=int)
    parser.add_argument('-max_threads_per_gpu', default=10, type=int)

    # checkpointing
    parser.add_argument('-nav_checkpoint_path', default=False)
    parser.add_argument('-ans_checkpoint_path', default=False)

    parser.add_argument('-checkpoint_dir', default='checkpoints/eqa/')
    parser.add_argument('-log_dir', default='logs/eqa/')
    parser.add_argument('-log', default=False, action='store_true')
    parser.add_argument('-cache', default=False, action='store_true')
    parser.add_argument('-max_controller_actions', type=int, default=5)
    parser.add_argument('-max_actions', type=int)
    args = parser.parse_args()

    args.time_id = time.strftime("%m_%d_%H:%M")

    try:
        args.gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
        args.gpus = [int(x) for x in args.gpus]
    except KeyError:
        print("CPU not supported")
        exit()

    # Load navigation model
    if args.nav_checkpoint_path != False:
        print('Loading navigation checkpoint from %s' % args.nav_checkpoint_path)
        checkpoint = torch.load(
            args.nav_checkpoint_path, map_location={
                'cuda:0': 'cpu'
            })

        args_to_keep = ['model_type']

        for i in args.__dict__:
            if i not in args_to_keep:
                checkpoint['args'][i] = args.__dict__[i]

        args = type('new_dict', (object, ), checkpoint['args'])

    args.checkpoint_dir = os.path.join(args.checkpoint_dir,
                                       args.time_id + '_' + args.identifier)
    args.log_dir = os.path.join(args.log_dir,
                                args.time_id + '_' + args.identifier)
    print(args.__dict__)

    if not os.path.exists(args.checkpoint_dir) and args.log == True:
        os.makedirs(args.checkpoint_dir)
        os.makedirs(args.log_dir)

    if args.model_type == 'pacman':

        model_kwargs = {'question_vocab': load_vocab(args.vocab_json)}
        shared_nav_model = NavPlannerControllerModel(**model_kwargs)

    else:

        exit()

    shared_nav_model.share_memory()

    if args.nav_checkpoint_path != False:
        print('Loading navigation params from checkpoint: %s' %
            args.nav_checkpoint_path)
        shared_nav_model.load_state_dict(checkpoint['state'])

    # Load answering model
    if args.ans_checkpoint_path != False:
        print('Loading answering checkpoint from %s' % args.ans_checkpoint_path)
        ans_checkpoint = torch.load(
            args.ans_checkpoint_path, map_location={
                'cuda:0': 'cpu'
            })

    ans_model_kwargs = {'vocab': load_vocab(args.vocab_json)}
    shared_ans_model = VqaLstmCnnAttentionModel(**ans_model_kwargs)

    shared_ans_model.share_memory()

    if args.ans_checkpoint_path != False:
        print('Loading params from checkpoint: %s' % args.ans_checkpoint_path)
        shared_ans_model.load_state_dict(ans_checkpoint['state'])

    if args.mode == 'eval':

        eval(0, args, shared_nav_model, shared_ans_model)

    elif args.mode == 'train':

        train(0, args, shared_nav_model, shared_ans_model)

    else:

        processes = []

        p = mp.Process(
            target=eval, args=(0, args, shared_nav_model, shared_ans_model))
        p.start()
        processes.append(p)

        for rank in range(1, args.num_processes + 1):
            p = mp.Process(
                target=train,
                args=(rank, args, shared_nav_model, shared_ans_model))
            p.start()
            processes.append(p)

        for p in processes:
            p.join()

================================================
FILE: training/train_nav.py
================================================
import time
import argparse
from datetime import datetime
import logging
import numpy as np
import os
import torch
import torch.nn.functional as F
import torch.multiprocessing as mp
from models import NavCnnModel, NavCnnRnnModel, NavCnnRnnMultModel, NavPlannerControllerModel
from data import EqaDataLoader
from metrics import NavMetric
from models import MaskedNLLCriterion
from models import get_state, ensure_shared_grads
from data import load_vocab
from torch.autograd import Variable
from tqdm import tqdm
import time

torch.backends.cudnn.enabled = False

################################################################################################
#make models trained in pytorch 4 compatible with earlier pytorch versions
import torch._utils
try:
    torch._utils._rebuild_tensor_v2
except AttributeError:
    def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks):
        tensor = torch._utils._rebuild_tensor(storage, storage_offset, size, stride)
        tensor.requires_grad = requires_grad
        tensor._backward_hooks = backward_hooks
        return tensor
    torch._utils._rebuild_tensor_v2 = _rebuild_tensor_v2

################################################################################################

def eval(rank, args, shared_model):

    torch.cuda.set_device(args.gpus.index(args.gpus[rank % len(args.gpus)]))

    if args.model_type == 'cnn':

        model_kwargs = {}
        model = NavCnnModel(**model_kwargs)

    elif args.model_type == 'cnn+q':

        model_kwargs = {
            'question_input': True,
            'question_vocab': load_vocab(args.vocab_json)
        }
        model = NavCnnModel(**model_kwargs)

    elif args.model_type == 'lstm':

        model_kwargs = {}
        model = NavCnnRnnModel(**model_kwargs)

    elif args.model_type == 'lstm+q':

        model_kwargs = {
            'question_input': True,
            'question_vocab': load_vocab(args.vocab_json)
        }
        model = NavCnnRnnModel(**model_kwargs)

    elif args.model_type == 'lstm-mult+q':

        model_kwargs = {
            'question_input': True,
            'question_vocab': load_vocab(args.vocab_json)
        }
        model = NavCnnRnnMultModel(**model_kwargs)

    elif args.model_type == 'pacman':

        model_kwargs = {'question_vocab': load_vocab(args.vocab_json)}
        model = NavPlannerControllerModel(**model_kwargs)

    else:

        exit()

    eval_loader_kwargs = {
        'questions_h5': getattr(args, args.eval_split + '_h5'),
        'data_json': args.data_json,
        'vocab': args.vocab_json,
        'target_obj_conn_map_dir': args.target_obj_conn_map_dir,
        'map_resolution': args.map_resolution,
        'batch_size': 1,
        'input_type': args.model_type,
        'num_frames': 5,
        'split': args.eval_split,
        'max_threads_per_gpu': args.max_threads_per_gpu,
        'gpu_id': args.gpus[rank % len(args.gpus)],
        'to_cache': False,
        'overfit': args.overfit,
        'max_controller_actions': args.max_controller_actions,
    }

    eval_loader = EqaDataLoader(**eval_loader_kwargs)
    print('eval_loader has %d samples' % len(eval_loader.dataset))
    logging.info("EVAL: eval_loader has {} samples".format(len(eval_loader.dataset)))

    args.output_log_path = os.path.join(args.log_dir,
                                        'eval_' + str(rank) + '.json')

    t, epoch, best_eval_acc = 0, 0, 0.0

    max_epochs = args.max_epochs
    if args.mode == 'eval':
        max_epochs = 1
    while epoch < int(max_epochs):

        invalids = []

        model.load_state_dict(shared_model.state_dict())
        model.eval()

        # that's a lot of numbers
        metrics = NavMetric(
            info={'split': args.eval_split,
                  'thread': rank},
            metric_names=[
                'd_0_10', 'd_0_30', 'd_0_50', 'd_T_10', 'd_T_30', 'd_T_50',
                'd_D_10', 'd_D_30', 'd_D_50', 'd_min_10', 'd_min_30',
                'd_min_50', 'r_T_10', 'r_T_30', 'r_T_50', 'r_e_10', 'r_e_30',
                'r_e_50', 'stop_10', 'stop_30', 'stop_50', 'ep_len_10',
                'ep_len_30', 'ep_len_50'
            ],
            log_json=args.output_log_path)

        if 'cnn' in args.model_type:

            done = False

            while done == False:

                for batch in tqdm(eval_loader):

                    model.load_state_dict(shared_model.state_dict())
                    model.cuda()

                    idx, questions, _, img_feats, actions_in, actions_out, action_length = batch
                    metrics_slug = {}

                    # evaluate at multiple initializations
                    for i in [10, 30, 50]:

                        t += 1

                        if action_length[0] + 1 - i - 5 < 0:
                            invalids.append(idx[0])
                            continue

                        ep_inds = [
                            x for x in range(action_length[0] + 1 - i - 5,
                                             action_length[0] + 1 - i)
                        ]

                        sub_img_feats = torch.index_select(
                            img_feats, 1, torch.LongTensor(ep_inds))

                        init_pos = eval_loader.dataset.episode_pos_queue[
                            ep_inds[-1]]

                        h3d = eval_loader.dataset.episode_house

                        h3d.env.reset(
                            x=init_pos[0], y=init_pos[2], yaw=init_pos[3])

                        init_dist_to_target = h3d.get_dist_to_target(
                            h3d.env.cam.pos)
                        if init_dist_to_target < 0:  # unreachable
                            invalids.append(idx[0])
                            continue

                        sub_img_feats_var = Variable(sub_img_feats.cuda())
                        if '+q' in args.model_type:
                            questions_var = Variable(questions.cuda())

                        # sample actions till max steps or <stop>
                        # max no. of actions = 100

                        episode_length = 0
                        episode_done = True

                        dists_to_target, pos_queue, actions = [
                            init_dist_to_target
                        ], [init_pos], []

                        for step in range(args.max_episode_length):

                            episode_length += 1

                            if '+q' in args.model_type:
                                scores = model(sub_img_feats_var,
                                               questions_var)
                            else:
                                scores = model(sub_img_feats_var)

                            prob = F.softmax(scores, dim=1)

                            action = int(prob.max(1)[1].data.cpu().numpy()[0])

                            actions.append(action)

                            img, _, episode_done = h3d.step(action)

                            episode_done = episode_done or episode_length >= args.max_episode_length

                            img = torch.from_numpy(img.transpose(
                                2, 0, 1)).float() / 255.0
                            img_feat_var = eval_loader.dataset.cnn(
                                Variable(img.view(1, 3, 224, 224)
                                         .cuda())).view(1, 1, 3200)
                            sub_img_feats_var = torch.cat(
                                [sub_img_feats_var, img_feat_var], dim=1)
                            sub_img_feats_var = sub_img_feats_var[:, -5:, :]

                            dists_to_target.append(
                                h3d.get_dist_to_target(h3d.env.cam.pos))
                            pos_queue.append([
                                h3d.env.cam.pos.x, h3d.env.cam.pos.y,
                                h3d.env.cam.pos.z, h3d.env.cam.yaw
                            ])

                            if episode_done == True:
                                break

                        # compute stats
                        metrics_slug['d_0_' + str(i)] = dists_to_target[0]
                        metrics_slug['d_T_' + str(i)] = dists_to_target[-1]
                        metrics_slug['d_D_' + str(
                            i)] = dists_to_target[0] - dists_to_target[-1]
                        metrics_slug['d_min_' + str(i)] = np.array(
                            dists_to_target).min()
                        metrics_slug['ep_len_' + str(i)] = episode_length
                        if action == 3:
                            metrics_slug['stop_' + str(i)] = 1
                        else:
                            metrics_slug['stop_' + str(i)] = 0
                        inside_room = []
                        for p in pos_queue:
                            inside_room.append(
                                h3d.is_inside_room(
                                    p, eval_loader.dataset.target_room))
                        if inside_room[-1] == True:
                            metrics_slug['r_T_' + str(i)] = 1
                        else:
                            metrics_slug['r_T_' + str(i)] = 0
                        if any([x == True for x in inside_room]) == True:
                            metrics_slug['r_e_' + str(i)] = 1
                        else:
                            metrics_slug['r_e_' + str(i)] = 0

                    # collate and update metrics
                    metrics_list = []
                    for i in metrics.metric_names:
                        if i not in metrics_slug:
                            metrics_list.append(metrics.metrics[
                                metrics.metric_names.index(i)][0])
                        else:
                            metrics_list.append(metrics_slug[i])

                    # update metrics
                    metrics.update(metrics_list)

                print(metrics.get_stat_string(mode=0))
                print('invalids', len(invalids))
                logging.info("EVAL: metrics: {}".format(metrics.get_stat_string(mode=0)))
                logging.info("EVAL: invalids: {}".format(len(invalids)))

               # del h3d
                eval_loader.dataset._load_envs()
                if len(eval_loader.dataset.pruned_env_set) == 0:
                    done = True

        elif 'lstm' in args.model_type:

            done = False

            while done == False:

                if args.overfit:
                    metrics = NavMetric(
                        info={'split': args.eval_split,
                              'thread': rank},
                        metric_names=[
                            'd_0_10', 'd_0_30', 'd_0_50', 'd_T_10', 'd_T_30', 'd_T_50',
                            'd_D_10', 'd_D_30', 'd_D_50', 'd_min_10', 'd_min_30',
                            'd_min_50', 'r_T_10', 'r_T_30', 'r_T_50', 'r_e_10', 'r_e_30',
                            'r_e_50', 'stop_10', 'stop_30', 'stop_50', 'ep_len_10',
                            'ep_len_30', 'ep_len_50'
                        ],
                        log_json=args.output_log_path)

                for batch in tqdm(eval_loader):

                    model.load_state_dict(shared_model.state_dict())
                    model.cuda()

                    idx, questions, answer, _, actions_in, actions_out, action_lengths, _ = batch
                    question_var = Variable(questions.cuda())
                    metrics_slug = {}

                    # evaluate at multiple initializations
                    for i in [10, 30, 50]:

                        t += 1

                        if action_lengths[0] - 1 - i < 0:
                            invalids.append([idx[0], i])
                            continue

                        h3d = eval_loader.dataset.episode_house

                        # forward through lstm till spawn
                        if len(eval_loader.dataset.episode_pos_queue[:-i]
                               ) > 0:
                            images = eval_loader.dataset.get_frames(
                                h3d,
                                eval_loader.dataset.episode_pos_queue[:-i],
                                preprocess=True)
                            raw_img_feats = eval_loader.dataset.cnn(
                                Variable(torch.FloatTensor(images).cuda()))

                            actions_in_pruned = actions_in[:, :
                                                           action_lengths[0] -
                                                           i]
                            actions_in_var = Variable(actions_in_pruned.cuda())
                            action_lengths_pruned = action_lengths.clone(
                            ).fill_(action_lengths[0] - i)
                            img_feats_var = raw_img_feats.view(1, -1, 3200)

                            if '+q' in args.model_type:
                                scores, hidden = model(
                                    img_feats_var, question_var,
                                    actions_in_var,
                                    action_lengths_pruned.cpu().numpy())
                            else:
                                scores, hidden = model(
                                    img_feats_var, False, actions_in_var,
                                    action_lengths_pruned.cpu().numpy())
                            try:
                                init_pos = eval_loader.dataset.episode_pos_queue[
                                    -i]
                            except:
                                invalids.append([idx[0], i])
                                continue

                            action_in = torch.LongTensor(1, 1).fill_(
                                actions_in[0,
                                           action_lengths[0] - i]).cuda()
                        else:
                            init_pos = eval_loader.dataset.episode_pos_queue[
                                -i]
                            hidden = model.nav_rnn.init_hidden(1)
                            action_in = torch.LongTensor(1, 1).fill_(0).cuda()

                        h3d.env.reset(
                            x=init_pos[0], y=init_pos[2], yaw=init_pos[3])

                        init_dist_to_target = h3d.get_dist_to_target(
                            h3d.env.cam.pos)
                        if init_dist_to_target < 0:  # unreachable
                            invalids.append([idx[0], i])
                            continue

                        img = h3d.env.render()
                        img = torch.from_numpy(img.transpose(
                            2, 0, 1)).float() / 255.0
                        img_feat_var = eval_loader.dataset.cnn(
                            Variable(img.view(1, 3, 224, 224).cuda())).view(
                                1, 1, 3200)

                        episode_length = 0
                        episode_done = True

                        dists_to_target, pos_queue, actions = [
                            init_dist_to_target
                        ], [init_pos], []
                        actual_pos_queue = [(h3d.env.cam.pos.x, h3d.env.cam.pos.z, h3d.env.cam.yaw)]

                        for step in range(args.max_episode_length):

                            episode_length += 1

                            if '+q' in args.model_type:
                                scores, hidden = model(
                                    img_feat_var,
                                    question_var,
                                    Variable(action_in),
                                    False,
                                    hidden=hidden,
                                    step=True)
                            else:
                                scores, hidden = model(
                                    img_feat_var,
                                    False,
                                    Variable(action_in),
                                    False,
                                    hidden=hidden,
                                    step=True)

                            prob = F.softmax(scores, dim=1)

                            action = int(prob.max(1)[1].data.cpu().numpy()[0])

                            actions.append(action)

                            img, _, episode_done = h3d.step(action)

                            episode_done = episode_done or episode_length >= args.max_episode_length

                            img = torch.from_numpy(img.transpose(
                                2, 0, 1)).float() / 255.0
                            img_feat_var = eval_loader.dataset.cnn(
                                Variable(img.view(1, 3, 224, 224)
                                         .cuda())).view(1, 1, 3200)

                            action_in = torch.LongTensor(
                                1, 1).fill_(action + 1).cuda()

                            dists_to_target.append(
                                h3d.get_dist_to_target(h3d.env.cam.pos))
                            pos_queue.append([
                                h3d.env.cam.pos.x, h3d.env.cam.pos.y,
                                h3d.env.cam.pos.z, h3d.env.cam.yaw
                            ])

                            if episode_done == True:
                                break

                            actual_pos_queue.append([
                                h3d.env.cam.pos.x, h3d.env.cam.pos.z, h3d.env.cam.yaw])

                        # compute stats
                        metrics_slug['d_0_' + str(i)] = dists_to_target[0]
                        metrics_slug['d_T_' + str(i)] = dists_to_target[-1]
                        metrics_slug['d_D_' + str(
                            i)] = dists_to_target[0] - dists_to_target[-1]
                        metrics_slug['d_min_' + str(i)] = np.array(
                            dists_to_target).min()
                        metrics_slug['ep_len_' + str(i)] = episode_length
                        if action == 3:
                            metrics_slug['stop_' + str(i)] = 1
                        else:
                            metrics_slug['stop_' + str(i)] = 0
                        inside_room = []
                        for p in pos_queue:
                            inside_room.append(
                                h3d.is_inside_room(
                                    p, eval_loader.dataset.target_room))
                        if inside_room[-1] == True:
                            metrics_slug['r_T_' + str(i)] = 1
                        else:
                            metrics_slug['r_T_' + str(i)] = 0
                        if any([x == True for x in inside_room]) == True:
                            metrics_slug['r_e_' + str(i)] = 1
                        else:
                            metrics_slug['r_e_' + str(i)] = 0

                    # collate and update metrics
                    metrics_list = []
                    for i in metrics.metric_names:
                        if i not in metrics_slug:
                            metrics_list.append(metrics.metrics[
                                metrics.metric_names.index(i)][0])
                        else:
                            metrics_list.append(metrics_slug[i])

                    # update metrics
                    metrics.update(metrics_list)

                print(metrics.get_stat_string(mode=0))
                print('invalids', len(invalids))
                logging.info("EVAL: init_steps: {} metrics: {}".format(i, metrics.get_stat_string(mode=0)))
                logging.info("EVAL: init_steps: {} invalids: {}".format(i, len(invalids)))

                # del h3d
                eval_loader.dataset._load_envs()
                print("eval_loader pruned_env_set len: {}".format(len(eval_loader.dataset.pruned_env_set)))
                logging.info("eval_loader pruned_env_set len: {}".format(len(eval_loader.dataset.pruned_env_set)))
                assert len(eval_loader.dataset.pruned_env_set) > 0
                if len(eval_loader.dataset.pruned_env_set) == 0:
                    done = True

        elif 'pacman' in args.model_type:

            done = False

            while done == False:
                if args.overfit:
                    metrics = NavMetric(
                        info={'split': args.eval_split,
                              'thread': rank},
                        metric_names=[
                            'd_0_10', 'd_0_30', 'd_0_50', 'd_T_10', 'd_T_30', 'd_T_50',
                            'd_D_10', 'd_D_30', 'd_D_50', 'd_min_10', 'd_min_30',
                            'd_min_50', 'r_T_10', 'r_T_30', 'r_T_50', 'r_e_10', 'r_e_30',
                            'r_e_50', 'stop_10', 'stop_30', 'stop_50', 'ep_len_10',
                            'ep_len_30', 'ep_len_50'
                        ],
                        log_json=args.output_log_path)

                for batch in tqdm(eval_loader):

                    model.load_state_dict(shared_model.state_dict())
                    model.cuda()

                    idx, question, answer, actions, action_length = batch
                    metrics_slug = {}

                    h3d = eval_loader.dataset.episode_house

                    # evaluate at multiple initializations
                    for i in [10, 30, 50]:

                        t += 1

                        if i > action_length[0]:
                            invalids.append([idx[0], i])
                            continue

                        question_var = Variable(question.cuda())

                        controller_step = False
                        planner_hidden = model.planner_nav_rnn.init_hidden(1)

                        # get hierarchical action history
                        (
                            planner_actions_in, planner_img_feats,
                            controller_step, controller_action_in,
                            controller_img_feats, init_pos,
                            controller_action_counter
                        ) = eval_loader.dataset.get_hierarchical_features_till_spawn(
                            actions[0, :action_length[0] + 1].numpy(), i, args.max_controller_actions
                        )

                        planner_actions_in_var = Variable(
                            planner_actions_in.cuda())
                        planner_img_feats_var = Variable(
                            planner_img_feats.cuda())

                        # forward planner till spawn to update hidden state
                        for step in range(planner_actions_in.size(0)):

                            planner_scores, planner_hidden = model.planner_step(
                                question_var, planner_img_feats_var[step]
                                .unsqueeze(0).unsqueeze(0),
                                planner_actions_in_var[step].view(1, 1),
                                planner_hidden
                            )

                        h3d.env.reset(
                            x=init_pos[0], y=init_pos[2], yaw=init_pos[3])

                        init_dist_to_target = h3d.get_dist_to_target(
                            h3d.env.cam.pos)
                        if init_dist_to_target < 0:  # unreachable
                            invalids.append([idx[0], i])
                            continue

                        dists_to_target, pos_queue, pred_actions = [
                            init_dist_to_target
                        ], [init_pos], []
                        planner_actions, controller_actions = [], []

                        episode_length = 0
                        if args.max_controller_actions > 1:
                            controller_action_counter = controller_action_counter % args.max_controller_actions
                            controller_action_counter = max(controller_action_counter - 1, 0)
                        else:
                            controller_action_counter = 0

                        first_step = True
                        first_step_is_controller = controller_step
                        planner_step = True
                        action = int(controller_action_in)

                        for step in range(args.max_episode_length):
                            if not first_step:
                                img = torch.from_numpy(img.transpose(
                                    2, 0, 1)).float() / 255.0
                                img_feat_var = eval_loader.dataset.cnn(
                                    Variable(img.view(1, 3, 224,
                                                      224).cuda())).view(
                                                          1, 1, 3200)
                            else:
                                img_feat_var = Variable(controller_img_feats.cuda()).view(1, 1, 3200)

                            if not first_step or first_step_is_controller:
                                # query controller to continue or not
                                controller_action_in = Variable(
                                    torch.LongTensor(1, 1).fill_(action).cuda())
                                controller_scores = model.controller_step(
                                    img_feat_var, controller_action_in,
                                    planner_hidden[0])

                                prob = F.softmax(controller_scores, dim=1)
                                controller_action = int(
                                    prob.max(1)[1].data.cpu().numpy()[0])

                                if controller_action == 1 and controller_action_counter < args.max_controller_actions - 1:
                                    controller_action_counter += 1
                                    planner_step = False
                                else:
                                    controller_action_counter = 0
                                    planner_step = True
                                    controller_action = 0

                                controller_actions.append(controller_action)
                                first_step = False

                            if planner_step:
                                if not first_step:
                                    action_in = torch.LongTensor(
                                        1, 1).fill_(action + 1).cuda()
                                    planner_scores, planner_hidden = model.planner_step(
                                        question_var, img_feat_var,
                                        Variable(action_in), planner_hidden)

                                prob = F.softmax(planner_scores, dim=1)
                                action = int(
                                    prob.max(1)[1].data.cpu().numpy()[0])
                                planner_actions.append(action)

                            episode_done = action == 3 or episode_length >= args.max_episode_length

                            episode_length += 1
                            dists_to_target.append(
                                h3d.get_dist_to_target(h3d.env.cam.pos))
                            pos_queue.append([
                                h3d.env.cam.pos.x, h3d.env.cam.pos.y,
                                h3d.env.cam.pos.z, h3d.env.cam.yaw
                            ])

                            if episode_done:
                                break

                            img, _, _ = h3d.step(action)
                            first_step = False

                        # compute stats
                        metrics_slug['d_0_' + str(i)] = dists_to_target[0]
                        metrics_slug['d_T_' + str(i)] = dists_to_target[-1]
                        metrics_slug['d_D_' + str(
                            i)] = dists_to_target[0] - dists_to_target[-1]
                        metrics_slug['d_min_' + str(i)] = np.array(
                            dists_to_target).min()
                        metrics_slug['ep_len_' + str(i)] = episode_length
                        if action == 3:
                            metrics_slug['stop_' + str(i)] = 1
                        else:
                            metrics_slug['stop_' + str(i)] = 0
                        inside_room = []
                        for p in pos_queue:
                            inside_room.append(
                                h3d.is_inside_room(
                                    p, eval_loader.dataset.target_room))
                        if inside_room[-1] == True:
                            metrics_slug['r_T_' + str(i)] = 1
                        else:
                            metrics_slug['r_T_' + str(i)] = 0
                        if any([x == True for x in inside_room]) == True:
                            metrics_slug['r_e_' + str(i)] = 1
                        else:
                            metrics_slug['r_e_' + str(i)] = 0

                    # collate and update metrics
                    metrics_list = []
                    for i in metrics.metric_names:
                        if i not in metrics_slug:
                            metrics_list.append(metrics.metrics[
                                metrics.metric_names.index(i)][0])
                        else:
                            metrics_list.append(metrics_slug[i])

                    # update metrics
                    metrics.update(metrics_list)

                try:
                    print(metrics.get_stat_string(mode=0))
                    logging.info("EVAL: metrics: {}".format(metrics.get_stat_string(mode=0)))
                except:
                    pass

                print('epoch', epoch)
                print('invalids', len(invalids))
                logging.info("EVAL: epoch {}".format(epoch))
                logging.info("EVAL: invalids {}".format(invalids))

                # del h3d
                eval_loader.dataset._load_envs()
                if len(eval_loader.dataset.pruned_env_set) == 0:
                    done = True

        epoch += 1

        # checkpoint if best val loss
        if metrics.metrics[8][0] > best_eval_acc:  # d_D_50
            best_eval_acc = metrics.metrics[8][0]
            if epoch % args.eval_every == 0 and args.log == True:
                metrics.dump_log()

                model_state = get_state(model)

                aad = dict(args.__dict__)
                ad = {}
                for i in aad:
                    if i[0] != '_':
                        ad[i] = aad[i]

                checkpoint = {'args': ad, 'state': model_state, 'epoch': epoch}

                checkpoint_path = '%s/epoch_%d_d_D_50_%.04f.pt' % (
                    args.checkpoint_dir, epoch, best_eval_acc)
                print('Saving checkpoint to %s' % checkpoint_path)
                logging.info("EVAL: Saving checkpoint to {}".format(checkpoint_path))
                torch.save(checkpoint, checkpoint_path)

        print('[best_eval_d_D_50:%.04f]' % best_eval_acc)
        logging.info("EVAL: [best_eval_d_D_50:{:.04f}]".format(best_eval_acc))

        eval_loader.dataset._load_envs(start_idx=0, in_order=True)


def train(rank, args, shared_model):
    torch.cuda.set_device(args.gpus.index(args.gpus[rank % len(args.gpus)]))

    if args.model_type == 'cnn':

        model_kwargs = {}
        model = NavCnnModel(**model_kwargs)

    elif args.model_type == 'cnn+q':

        model_kwargs = {
            'question_input': True,
            'question_vocab': load_vocab(args.vocab_json)
        }
        model = NavCnnModel(**model_kwargs)

    elif args.model_type == 'lstm':

        model_kwargs = {}
        model = NavCnnRnnModel(**model_kwargs)

    elif args.model_type == 'lstm-mult+q':

        model_kwargs = {
            'question_input': True,
            'question_vocab': load_vocab(args.vocab_json)
        }
        model = NavCnnRnnMultModel(**model_kwargs)

    elif args.model_type == 'lstm+q':

        model_kwargs = {
            'question_input': True,
            'question_vocab': load_vocab(args.vocab_json)
        }
        model = NavCnnRnnModel(**model_kwargs)

    elif args.model_type == 'pacman':

        model_kwargs = {'question_vocab': load_vocab(args.vocab_json)}
        model = NavPlannerControllerModel(**model_kwargs)

    else:

        exit()

    lossFn = torch.nn.CrossEntropyLoss().cuda()

    optim = torch.optim.Adamax(
        filter(lambda p: p.requires_grad, shared_model.parameters()),
        lr=args.learning_rate)

    train_loader_kwargs = {
        'questions_h5': args.train_h5,
        'data_json': args.data_json,
        'vocab': args.vocab_json,
        'batch_size': args.batch_size,
        'input_type': args.model_type,
        'num_frames': 5,
        'map_resolution': args.map_resolution,
        'split': 'train',
        'max_threads_per_gpu': args.max_threads_per_gpu,
        'gpu_id': args.gpus[rank % len(args.gpus)],
        'to_cache': args.cache,
        'overfit': args.overfit,
        'max_controller_actions': args.max_controller_actions,
        'max_actions': args.max_actions
    }

    args.output_log_path = os.path.join(args.log_dir,
                                        'train_' + str(rank) + '.json')

    if 'pacman' in args.model_type:

        metrics = NavMetric(
            info={'split': 'train',
                  'thread': rank},
            metric_names=['planner_loss', 'controller_loss'],
            log_json=args.output_log_path)

    else:

        metrics = NavMetric(
            info={'split': 'train',
                  'thread': rank},
            metric_names=['loss'],
            log_json=args.output_log_path)

    train_loader = EqaDataLoader(**train_loader_kwargs)

    print('train_loader has %d samples' % len(train_loader.dataset))
    logging.info('TRAIN: train loader has {} samples'.format(len(train_loader.dataset)))

    t, epoch = 0, 0

    while epoch < int(args.max_epochs):

        if 'cnn' in args.model_type:

            done = False
            all_envs_loaded = train_loader.dataset._check_if_all_envs_loaded()

            while done == False:

                for batch in train_loader:

                    t += 1

                    model.load_state_dict(shared_model.state_dict())
                    model.train()
                    model.cuda()

                    idx, questions, _, img_feats, _, actions_out, _ = batch

                    img_feats_var = Variable(img_feats.cuda())
                    if '+q' in args.model_type:
                        questions_var = Variable(questions.cuda())
                    actions_out_var = Variable(actions_out.cuda())

                    if '+q' in args.model_type:
                        scores = model(img_feats_var, questions_var)
                    else:
                        scores = model(img_feats_var)

                    loss = lossFn(scores, actions_out_var)

                    # zero grad
                    optim.zero_grad()

                    # update metrics
                    metrics.update([loss.data[0]])

                    # backprop and update
                    loss.backward()

                    ensure_shared_grads(model.cpu(), shared_model)
                    optim.step()

                    if t % args.print_every == 0:
                        print(metrics.get_stat_string())
                        logging.info("TRAIN: metrics: {}".format(metrics.get_stat_string()))
                        if args.log == True:
                            metrics.dump_log()

                    print('[CHECK][Cache:%d][Total:%d]' %
                          (len(train_loader.dataset.img_data_cache),
                           len(train_loader.dataset.env_list)))
                    logging.info('TRAIN: [CHECK][Cache:{}][Total:{}]'.format(
                        len(train_loader.dataset.img_data_cache), len(train_loader.dataset.env_list)))

                if all_envs_loaded == False:
                    train_loader.dataset._load_envs(in_order=True)
                    if len(train_loader.dataset.pruned_env_set) == 0:
                        done = True
                        if args.cache == False:
                            train_loader.dataset._load_envs(
                                start_idx=0, in_order=True)
                else:
                    done = True

        elif 'lstm' in args.model_type:

            lossFn = MaskedNLLCriterion().cuda()

            done = False
            all_envs_loaded = train_loader.dataset._check_if_all_envs_loaded()
            total_times = []
            while done == False:

                start_time = time.time()
                for batch in train_loader:

                    t += 1


                    model.load_state_dict(shared_model.state_dict())
                    model.train()
                    model.cuda()

                    idx, questions, _, img_feats, actions_in, actions_out, action_lengths, masks = batch

                    img_feats_var = Variable(img_feats.cuda())
                    if '+q' in args.model_type:
                        questions_var = Variable(questions.cuda())
                    actions_in_var = Variable(actions_in.cuda())
                    actions_out_var = Variable(actions_out.cuda())
                    action_lengths = action_lengths.cuda()
                    masks_var = Variable(masks.cuda())

                    action_lengths, perm_idx = action_lengths.sort(
                        0, descending=True)

                    img_feats_var = img_feats_var[perm_idx]
                    if '+q' in args.model_type:
                        questions_var = questions_var[perm_idx]
                    actions_in_var = actions_in_var[perm_idx]
                    actions_out_var = actions_out_var[perm_idx]
                    masks_var = masks_var[perm_idx]

                    if '+q' in args.model_type:
                        scores, hidden = model(img_feats_var, questions_var,
                                               actions_in_var,
                                               action_lengths.cpu().numpy())
                    else:
                        scores, hidden = model(img_feats_var, False,
                                               actions_in_var,
                                               action_lengths.cpu().numpy())

                    #block out masks
                    if args.curriculum:
                        curriculum_length = (epoch+1)*5
                        for i, action_length in enumerate(action_lengths):
                            if action_length - curriculum_length > 0:
                                masks_var[i, :action_length-curriculum_length] = 0

                    logprob = F.log_softmax(scores, dim=1)
                    loss = lossFn(
                        logprob, actions_out_var[:, :action_lengths.max()]
                        .contiguous().view(-1, 1),
                        masks_var[:, :action_lengths.max()].contiguous().view(
                            -1, 1))

                    # zero grad
                    optim.zero_grad()

                    # update metrics
                    metrics.update([loss.data[0]])
                    logging.info("TRAIN LSTM loss: {:.6f}".format(loss.data[0]))

                    # backprop and update
                    loss.backward()

                    ensure_shared_grads(model.cpu(), shared_model)
                    optim.step()

                    if t % args.print_every == 0:
                        print(metrics.get_stat_string())
                        logging.info("TRAIN: metrics: {}".format(metrics.get_stat_string()))
                        if args.log == True:
                            metrics.dump_log()

                    print('[CHECK][Cache:%d][Total:%d]' %
                          (len(train_loader.dataset.img_data_cache),
                           len(train_loader.dataset.env_list)))
                    logging.info('TRAIN: [CHECK][Cache:{}][Total:{}]'.format(
                        len(train_loader.dataset.img_data_cache), len(train_loader.dataset.env_list)))


                if all_envs_loaded == False:
                    train_loader.dataset._load_envs(in_order=True)
                    if len(train_loader.dataset.pruned_env_set) == 0:
                        done = True
                        if args.cache == False:
                            train_loader.dataset._load_envs(
                                start_idx=0, in_order=True)
                else:
                    done = True

        elif 'pacman' in args.model_type:

            planner_lossFn = MaskedNLLCriterion().cuda()
            controller_lossFn = MaskedNLLCriterion().cuda()

            done = False
            all_envs_loaded = train_loader.dataset._check_if_all_envs_loaded()

            while done == False:

                for batch in train_loader:

                    t += 1

                    model.load_state_dict(shared_model.state_dict())
                    model.train()
                    model.cuda()

                    idx, questions, _, planner_img_feats, planner_actions_in, \
                        planner_actions_out, planner_action_lengths, planner_masks, \
                        controller_img_feats, controller_actions_in, planner_hidden_idx, \
                        controller_outs, controller_action_lengths, controller_masks = batch

                    questions_var = Variable(questions.cuda())

                    planner_img_feats_var = Variable(planner_img_feats.cuda())
                    planner_actions_in_var = Variable(
                        planner_actions_in.cuda())
                    planner_actions_out_var = Variable(
                        planner_actions_out.cuda())
                    planner_action_lengths = planner_action_lengths.cuda()
                    planner_masks_var = Variable(planner_masks.cuda())

                    controller_img_feats_var = Variable(
                        controller_img_feats.cuda())
                    controller_actions_in_var = Variable(
                        controller_actions_in.cuda())
                    planner_hidden_idx_var = Variable(
                        planner_hidden_idx.cuda())
                    controller_outs_var = Variable(controller_outs.cuda())
                    controller_action_lengths = controller_action_lengths.cuda(
                    )
                    controller_masks_var = Variable(controller_masks.cuda())

                    planner_action_lengths, perm_idx = planner_action_lengths.sort(
                        0, descending=True)

                    questions_var = questions_var[perm_idx]

                    planner_img_feats_var = planner_img_feats_var[perm_idx]
                    planner_actions_in_var = planner_actions_in_var[perm_idx]
                    planner_actions_out_var = planner_actions_out_var[perm_idx]
                    planner_masks_var = planner_masks_var[perm_idx]

                    controller_img_feats_var = controller_img_feats_var[
                        perm_idx]
                    controller_actions_in_var = controller_actions_in_var[
                        perm_idx]
                    controller_outs_var = controller_outs_var[perm_idx]
                    planner_hidden_idx_var = planner_hidden_idx_var[perm_idx]
                    controller_action_lengths = controller_action_lengths[
                        perm_idx]
                    controller_masks_var = controller_masks_var[perm_idx]

                    planner_scores, controller_scores, planner_hidden = model(
                        questions_var, planner_img_feats_var,
                        planner_actions_in_var,
                        planner_action_lengths.cpu().numpy(),
                        planner_hidden_idx_var, controller_img_feats_var,
                        controller_actions_in_var, controller_action_lengths)

                    planner_logprob = F.log_softmax(planner_scores, dim=1)
                    controller_logprob = F.log_softmax(
                        controller_scores, dim=1)

                    planner_loss = planner_lossFn(
                        planner_logprob,
                        planner_actions_out_var[:, :planner_action_lengths.max(
                        )].contiguous().view(-1, 1),
                        planner_masks_var[:, :planner_action_lengths.max()]
                        .contiguous().view(-1, 1))

                    controller_loss = controller_lossFn(
                        controller_logprob,
                        controller_outs_var[:, :controller_action_lengths.max(
                        )].contiguous().view(-1, 1),
                        controller_masks_var[:, :controller_action_lengths.max(
                        )].contiguous().view(-1, 1))

                    # zero grad
                    optim.zero_grad()

                    # update metrics
                    metrics.update(
                        [planner_loss.data[0], controller_loss.data[0]])
                    logging.info("TRAINING PACMAN planner-loss: {:.6f} controller-loss: {:.6f}".format(
                        planner_loss.data[0], controller_loss.data[0]))

                    # backprop and update
                    if args.max_controller_actions == 1:
                        (planner_loss).backward()
                    else:
                        (planner_loss + controller_loss).backward()

                    ensure_shared_grads(model.cpu(), shared_model)
                    optim.step()

                    if t % args.print_every == 0:
                        print(metrics.get_stat_string())
                        logging.info("TRAIN: metrics: {}".format(metrics.get_stat_string()))
                        if args.log == True:
                            metrics.dump_log()

                    print('[CHECK][Cache:%d][Total:%d]' %
                          (len(train_loader.dataset.img_data_cache),
                           len(train_loader.dataset.env_list)))
                    logging.info('TRAIN: [CHECK][Cache:{}][Total:{}]'.format(
                        len(train_loader.dataset.img_data_cache), len(train_loader.dataset.env_list)))

                if all_envs_loaded == False:
                    train_loader.dataset._load_envs(in_order=True)
                    if len(train_loader.dataset.pruned_env_set) == 0:
                        done = True
                        if args.cache == False:
                            train_loader.dataset._load_envs(
                                start_idx=0, in_order=True)
                else:
                    done = True

        epoch += 1

        if epoch % args.save_every == 0:

            model_state = get_state(model)
            optimizer_state = optim.state_dict()

            aad = dict(args.__dict__)
            ad = {}
            for i in aad:
                if i[0] != '_':
                    ad[i] = aad[i]

            checkpoint = {'args': ad,
                          'state': model_state,
                          'epoch': epoch,
                          'optimizer': optimizer_state}

            checkpoint_path = '%s/epoch_%d_thread_%d.pt' % (
                args.checkpoint_dir, epoch, rank)
            print('Saving checkpoint to %s' % checkpoint_path)
            logging.info("TRAIN: Saving checkpoint to {}".format(checkpoint_path))
            torch.save(checkpoint, checkpoint_path)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # data params
    parser.add_argument('-train_h5', default='data/train.h5')
    parser.add_argument('-val_h5', default='data/val.h5')
    parser.add_argument('-test_h5', default='data/test.h5')
    parser.add_argument('-data_json', default='data/data.json')
    parser.add_argument('-vocab_json', default='data/vocab.json')

    parser.add_argument(
        '-target_obj_conn_map_dir',
        default='data/target-obj-conn-maps/500')
    parser.add_argument('-map_resolution', default=500, type=int)

    parser.add_argument(
        '-mode',
        default='train+eval',
        type=str,
        choices=['train', 'eval', 'train+eval'])
    parser.add_argument('-eval_split', default='val', type=str)

    # model details
    parser.add_argument(
        '-model_type',
        default='cnn',
        choices=['cnn', 'cnn+q', 'lstm', 'lstm+q', 'lstm-mult+q', 'pacman'])
    parser.add_argument('-max_episode_length', default=100, type=int)
    parser.add_argument('-curriculum', default=0, type=int)

    # optim params
    parser.add_argument('-batch_size', default=20, type=int)
    parser.add_argument('-learning_rate', default=1e-3, type=float)
    parser.add_argument('-max_epochs', default=1000, type=int)
    parser.add_argument('-overfit', default=False, action='store_true')

    # bookkeeping
    parser.add_argument('-print_every', default=5, type=int)
    parser.add_argument('-eval_every', default=1, type=int)
    parser.add_argument('-save_every', default=1000, type=int) #optional if you would like to save specific epochs as opposed to relying on the eval thread
    parser.add_argument('-identifier', default='cnn')
    parser.add_argument('-num_processes', default=1, type=int)
    parser.add_argument('-max_threads_per_gpu', default=10, type=int)

    # checkpointing
    parser.add_argument('-checkpoint_path', default=False)
    parser.add_argument('-checkpoint_dir', default='checkpoints/nav/')
    parser.add_argument('-log_dir', default='logs/nav/')
    parser.add_argument('-log', default=False, action='store_true')
    parser.add_argument('-cache', default=False, action='store_true')
    parser.add_argument('-max_controller_actions', type=int, default=5)
    parser.add_argument('-max_actions', type=int)
    args = parser.parse_args()

    args.time_id = time.strftime("%m_%d_%H:%M")

    #MAX_CONTROLLER_ACTIONS = args.max_controller_actions

    if not os.path.isdir(args.log_dir):
        os.makedirs(args.log_dir)

    if args.curriculum:
        assert 'lstm' in args.model_type #TODO: Finish implementing curriculum for other model types

    logging.basicConfig(filename=os.path.join(args.log_dir, "run_{}.log".format(
                                                str(datetime.now()).replace(' ', '_'))),
                        level=logging.INFO,
                        format='%(asctime)-15s %(message)s')

    try:
        args.gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
        args.gpus = [int(x) for x in args.gpus]
    except KeyError:
        print("CPU not supported")
        logging.info("CPU not supported")
        exit()

    if args.checkpoint_path != False:

        print('Loading checkpoint from %s' % args.checkpoint_path)
        logging.info("Loading checkpoint from {}".format(args.checkpoint_path))

        args_to_keep = ['model_type']

        checkpoint = torch.load(args.checkpoint_path, map_location={
            'cuda:0': 'cpu'
        })

        for i in args.__dict__:
            if i not in args_to_keep:
                checkpoint['args'][i] = args.__dict__[i]

        args = type('new_dict', (object, ), checkpoint['args'])

    args.checkpoint_dir = os.path.join(args.checkpoint_dir,
                                       args.time_id + '_' + args.identifier)
    args.log_dir = os.path.join(args.log_dir,
                                args.time_id + '_' + args.identifier)


    # if set to overfit; set eval_split to train
    if args.overfit == True:
        args.eval_split = 'train'

    print(args.__dict__)
    logging.info(args.__dict__)

    if not os.path.exists(args.checkpoint_dir):
        os.makedirs(args.checkpoint_dir)
        os.makedirs(args.log_dir)

    if args.model_type == 'cnn':

        model_kwargs = {}
        shared_model = NavCnnModel(**model_kwargs)

    elif args.model_type == 'cnn+q':

        model_kwargs = {
            'question_input': True,
            'question_vocab': load_vocab(args.vocab_json)
        }
        shared_model = NavCnnModel(**model_kwargs)

    elif args.model_type == 'lstm':

        model_kwargs = {}
        shared_model = NavCnnRnnModel(**model_kwargs)

    elif args.model_type == 'lstm+q':

        model_kwargs = {
            'question_input': True,
            'question_vocab': load_vocab(args.vocab_json)
        }
        shared_model = NavCnnRnnModel(**model_kwargs)

    elif args.model_type == 'pacman':

        model_kwargs = {'question_vocab': load_vocab(args.vocab_json)}
        shared_model = NavPlannerControllerModel(**model_kwargs)

    else:

        exit()

    shared_model.share_memory()

    if args.checkpoint_path != False:
        print('Loading params from checkpoint: %s' % args.checkpoint_path)
        logging.info("Loading params from checkpoint: {}".format(args.checkpoint_path))
        shared_model.load_state_dict(checkpoint['state'])

    if args.mode == 'eval':

        eval(0, args, shared_model)

    elif args.mode == 'train':

        if args.num_processes > 1:
            processes = []
            for rank in range(0, args.num_processes):
                # for rank in range(0, args.num_processes):
                p = mp.Process(target=train, args=(rank, args, shared_model))
                p.start()
                processes.append(p)

            for p in processes:
                p.join()

        else:
            train(0, args, shared_model)

    else:
        processes = []

        # Start the eval thread
        p = mp.Process(target=eval, args=(0, args, shared_model))
        p.start()
        processes.append(p)

        # Start the training thread(s)
        for rank in range(1, args.num_processes + 1):
            # for rank in range(0, args.num_processes):
            p = mp.Process(target=train, args=(rank, args, shared_model))
            p.start()
            processes.append(p)

        for p in processes:
            p.join()


================================================
FILE: training/train_vqa.py
================================================
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import h5py
import time
import argparse
import numpy as np
import os, sys, json

import torch
from torch.autograd import Variable
torch.backends.cudnn.enabled = False
import torch.multiprocessing as mp

from models import VqaLstmModel, VqaLstmCnnAttentionModel
from data import EqaDataset, EqaDataLoader
from metrics import VqaMetric

from models import get_state, repackage_hidden, ensure_shared_grads
from data import load_vocab

import pdb


def eval(rank, args, shared_model):

    torch.cuda.set_device(args.gpus.index(args.gpus[rank % len(args.gpus)]))

    if args.input_type == 'ques':

        model_kwargs = {'vocab': load_vocab(args.vocab_json)}
        model = VqaLstmModel(**model_kwargs)

    elif args.input_type == 'ques,image':

        model_kwargs = {'vocab': load_vocab(args.vocab_json)}
        model = VqaLstmCnnAttentionModel(**model_kwargs)

    lossFn = torch.nn.CrossEntropyLoss().cuda()

    eval_loader_kwargs = {
        'questions_h5': getattr(args, args.eval_split + '_h5'),
        'data_json': args.data_json,
        'vocab': args.vocab_json,
        'batch_size': 1,
        'input_type': args.input_type,
        'num_frames': args.num_frames,
        'split': args.eval_split,
        'max_threads_per_gpu': args.max_threads_per_gpu,
        'gpu_id': args.gpus[rank%len(args.gpus)],
        'to_cache': args.cache
    }

    eval_loader = EqaDataLoader(**eval_loader_kwargs)
    print('eval_loader has %d samples' % len(eval_loader.dataset))

    args.output_log_path = os.path.join(args.log_dir,
                                        'eval_' + str(rank) + '.json')

    t, epoch, best_eval_acc = 0, 0, 0

    while epoch < int(args.max_epochs):

        model.load_state_dict(shared_model.state_dict())
        model.eval()

        metrics = VqaMetric(
            info={'split': args.eval_split},
            metric_names=[
                'loss', 'accuracy', 'mean_rank', 'mean_reciprocal_rank'
            ],
            log_json=args.output_log_path)

        if args.input_type == 'ques':
            for batch in eval_loader:
                t += 1

                model.cuda()

                idx, questions, answers = batch

                questions_var = Variable(questions.cuda())
                answers_var = Variable(answers.cuda())

                scores = model(questions_var)
                loss = lossFn(scores, answers_var)

                # update metrics
                accuracy, ranks = metrics.compute_ranks(
                    scores.data.cpu(), answers)
                metrics.update([loss.data[0], accuracy, ranks, 1.0 / ranks])

            print(metrics.get_stat_string(mode=0))

        elif args.input_type == 'ques,image':
            done = False
            all_envs_loaded = eval_loader.dataset._check_if_all_envs_loaded()

            while done == False:
                for batch in eval_loader:
                    t += 1

                    model.cuda()

                    idx, questions, answers, images, _, _, _ = batch

                    questions_var = Variable(questions.cuda())
                    answers_var = Variable(answers.cuda())
                    images_var = Variable(images.cuda())

                    scores, att_probs = model(images_var, questions_var)
                    loss = lossFn(scores, answers_var)

                    # update metrics
                    accuracy, ranks = metrics.compute_ranks(
                        scores.data.cpu(), answers)
                    metrics.update(
                        [loss.data[0], accuracy, ranks, 1.0 / ranks])

                print(metrics.get_stat_string(mode=0))

                if all_envs_loaded == False:
                    eval_loader.dataset._load_envs()
                    if len(eval_loader.dataset.pruned_env_set) == 0:
                        done = True
                else:
                    done = True

        epoch += 1

        # checkpoint if best val accuracy
        if metrics.metrics[1][0] > best_eval_acc:
            best_eval_acc = metrics.metrics[1][0]
            if epoch % args.eval_every == 0 and args.log == True:
                metrics.dump_log()

                model_state = get_state(model)

                if args.checkpoint_path != False:
                    ad = checkpoint['args']
                else:
                    ad = args.__dict__

                checkpoint = {'args': ad, 'state': model_state, 'epoch': epoch}

                checkpoint_path = '%s/epoch_%d_accuracy_%.04f.pt' % (
                    args.checkpoint_dir, epoch, best_eval_acc)
                print('Saving checkpoint to %s' % checkpoint_path)
                torch.save(checkpoint, checkpoint_path)

        print('[best_eval_accuracy:%.04f]' % best_eval_acc)


def train(rank, args, shared_model):

    torch.cuda.set_device(args.gpus.index(args.gpus[rank % len(args.gpus)]))

    if args.input_type == 'ques':

        model_kwargs = {'vocab': load_vocab(args.vocab_json)}
        model = VqaLstmModel(**model_kwargs)

    elif args.input_type == 'ques,image':

        model_kwargs = {'vocab': load_vocab(args.vocab_json)}
        model = VqaLstmCnnAttentionModel(**model_kwargs)

    lossFn = torch.nn.CrossEntropyLoss().cuda()

    optim = torch.optim.Adam(
        filter(lambda p: p.requires_grad, shared_model.parameters()),
        lr=args.learning_rate)

    train_loader_kwargs = {
        'questions_h5': args.train_h5,
        'data_json': args.data_json,
        'vocab': args.vocab_json,
        'batch_size': args.batch_size,
        'input_type': args.input_type,
        'num_frames': args.num_frames,
        'split': 'train',
        'max_threads_per_gpu': args.max_threads_per_gpu,
        'gpu_id': args.gpus[rank%len(args.gpus)],
        'to_cache': args.cache
    }

    args.output_log_path = os.path.join(args.log_dir,
                                        'train_' + str(rank) + '.json')

    metrics = VqaMetric(
        info={'split': 'train',
              'thread': rank},
        metric_names=['loss', 'accuracy', 'mean_rank', 'mean_reciprocal_rank'],
        log_json=args.output_log_path)

    train_loader = EqaDataLoader(**train_loader_kwargs)
    if args.input_type == 'ques,image':
        train_loader.dataset._load_envs(start_idx=0, in_order=True)

    print('train_loader has %d samples' % len(train_loader.dataset))

    t, epoch = 0, 0

    while epoch < int(args.max_epochs):

        if args.input_type == 'ques':

            for batch in train_loader:

                t += 1

                model.load_state_dict(shared_model.state_dict())
                model.train()
                model.cuda()

                idx, questions, answers = batch

                questions_var = Variable(questions.cuda())
                answers_var = Variable(answers.cuda())

                scores = model(questions_var)
                loss = lossFn(scores, answers_var)

                # zero grad
                optim.zero_grad()

                # update metrics
                accuracy, ranks = metrics.compute_ranks(scores.data.cpu(), answers)
                metrics.update([loss.data[0], accuracy, ranks, 1.0 / ranks])

                # backprop and update
                loss.backward()

                ensure_shared_grads(model.cpu(), shared_model)
                optim.step()

                if t % args.print_every == 0:
                    print(metrics.get_stat_string())
                    if args.log == True:
                        metrics.dump_log()

        elif args.input_type == 'ques,image':

            done = False
            all_envs_loaded = train_loader.dataset._check_if_all_envs_loaded()

            while done == False:

                for batch in train_loader:

                    t += 1

                    model.load_state_dict(shared_model.state_dict())
                    model.train()
                    model.cnn.eval()
                    model.cuda()

                    idx, questions, answers, images, _, _, _ = batch

                    questions_var = Variable(questions.cuda())
                    answers_var = Variable(answers.cuda())
                    images_var = Variable(images.cuda())

                    scores, att_probs = model(images_var, questions_var)
                    loss = lossFn(scores, answers_var)

                    # zero grad
                    optim.zero_grad()

                    # update metrics
                    accuracy, ranks = metrics.compute_ranks(scores.data.cpu(), answers)
                    metrics.update([loss.data[0], accuracy, ranks, 1.0 / ranks])

                    # backprop and update
                    loss.backward()

                    ensure_shared_grads(model.cpu(), shared_model)
                    optim.step()

                    if t % args.print_every == 0:
                        print(metrics.get_stat_string())
                        if args.log == True:
                            metrics.dump_log()

                if all_envs_loaded == False:
                    print('[CHECK][Cache:%d][Total:%d]' % (len(train_loader.dataset.img_data_cache),
                        len(train_loader.dataset.env_list)))
                    train_loader.dataset._load_envs(in_order=True)
                    if len(train_loader.dataset.pruned_env_set) == 0:
                        done = True
                else:
                    done = True

        epoch += 1


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # data params
    parser.add_argument('-train_h5', default='data/train.h5')
    parser.add_argument('-val_h5', default='data/val.h5')
    parser.add_argument('-test_h5', default='data/test.h5')
    parser.add_argument('-data_json', default='data/data.json')
    parser.add_argument('-vocab_json', default='data/vocab.json')

    parser.add_argument('-train_cache_path', default=False)
    parser.add_argument('-val_cache_path', default=False)

    parser.add_argument('-mode', default='train', type=str, choices=['train','eval'])
    parser.add_argument('-eval_split', default='val', type=str)

    # model details
    parser.add_argument(
        '-input_type', default='ques,image', choices=['ques', 'ques,image'])
    parser.add_argument(
        '-num_frames', default=5,
        type=int)  # -1 = all frames of navigation sequence

    # optim params
    parser.add_argument('-batch_size', default=20, type=int)
    parser.add_argument('-learning_rate', default=3e-4, type=float)
    parser.add_argument('-max_epochs', default=1000, type=int)

    # bookkeeping
    parser.add_argument('-print_every', default=50, type=int)
    parser.add_argument('-eval_every', default=1, type=int)
    parser.add_argument('-identifier', default='q-only')
    parser.add_argument('-num_processes', default=1, type=int)
    parser.add_argument('-max_threads_per_gpu', default=10, type=int)

    # checkpointing
    parser.add_argument('-checkpoint_path'
Download .txt
gitextract_mc95tebh/

├── .gitignore
├── .gitmodules
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── requirements.txt
├── training/
│   ├── data.py
│   ├── metrics.py
│   ├── models.py
│   ├── train_eqa.py
│   ├── train_nav.py
│   ├── train_vqa.py
│   └── utils/
│       ├── preprocess_questions.py
│       └── preprocess_questions_pkl.py
└── utils/
    ├── house3d.py
    └── make_houses.py
Download .txt
SYMBOL INDEX (122 symbols across 10 files)

FILE: training/data.py
  function load_vocab (line 28) | def load_vocab(path):
  function invert_dict (line 40) | def invert_dict(d):
  function flat_to_hierarchical_actions (line 53) | def flat_to_hierarchical_actions(actions, controller_action_lim):
  function _dataset_to_tensor (line 90) | def _dataset_to_tensor(dset, mask=None, dtype=np.int64):
  function eqaCollateCnn (line 101) | def eqaCollateCnn(batch):
  function eqaCollateSeq2seq (line 116) | def eqaCollateSeq2seq(batch):
  class EqaDataset (line 133) | class EqaDataset(Dataset):
    method __init__ (line 134) | def __init__(self,
    method _pick_envs_to_load (line 269) | def _pick_envs_to_load(self,
    method _load_envs (line 286) | def _load_envs(self, start_idx=-1, in_order=False):
    method _clear_api_threads (line 365) | def _clear_api_threads(self):
    method _clear_memory (line 370) | def _clear_memory(self):
    method _check_if_all_envs_loaded (line 379) | def _check_if_all_envs_loaded(self):
    method set_camera (line 388) | def set_camera(self, e, pos, robot_height=1.0):
    method render (line 398) | def render(self, e):
    method get_frames (line 401) | def get_frames(self, e, pos_queue, preprocess=True):
    method get_hierarchical_features_till_spawn (line 418) | def get_hierarchical_features_till_spawn(self, actions, backtrack_step...
    method __getitem__ (line 459) | def __getitem__(self, index):
    method __len__ (line 798) | def __len__(self):
  class EqaDataLoader (line 805) | class EqaDataLoader(DataLoader):
    method __init__ (line 806) | def __init__(self, **kwargs):
    method close (line 889) | def close(self):
    method __enter__ (line 892) | def __enter__(self):
    method __exit__ (line 895) | def __exit__(self, exc_type, exc_value, traceback):

FILE: training/metrics.py
  class Metric (line 16) | class Metric():
    method __init__ (line 17) | def __init__(self, info={}, metric_names=[], log_json=None):
    method update (line 28) | def update(self, values):
    method get_stat_string (line 58) | def get_stat_string(self, mode=1):
    method dump_log (line 72) | def dump_log(self):
  class VqaMetric (line 86) | class VqaMetric(Metric):
    method __init__ (line 87) | def __init__(self, info={}, metric_names=[], log_json=None):
    method compute_ranks (line 90) | def compute_ranks(self, scores, labels):
  class NavMetric (line 101) | class NavMetric(Metric):
    method __init__ (line 102) | def __init__(self, info={}, metric_names=[], log_json=None):

FILE: training/models.py
  function build_mlp (line 22) | def build_mlp(input_dim,
  function get_state (line 50) | def get_state(m):
  function repackage_hidden (line 59) | def repackage_hidden(h, batch_size):
  function ensure_shared_grads (line 68) | def ensure_shared_grads(model, shared_model):
  class MaskedNLLCriterion (line 76) | class MaskedNLLCriterion(nn.Module):
    method __init__ (line 77) | def __init__(self):
    method forward (line 80) | def forward(self, input, target, mask):
  class MultitaskCNNOutput (line 89) | class MultitaskCNNOutput(nn.Module):
    method __init__ (line 90) | def __init__(
    method forward (line 160) | def forward(self, x):
  class MultitaskCNN (line 202) | class MultitaskCNN(nn.Module):
    method __init__ (line 203) | def __init__(
    method forward (line 273) | def forward(self, x):
  class QuestionLstmEncoder (line 320) | class QuestionLstmEncoder(nn.Module):
    method __init__ (line 321) | def __init__(self,
    method init_weights (line 343) | def init_weights(self):
    method forward (line 347) | def forward(self, x):
  class VqaLstmModel (line 371) | class VqaLstmModel(nn.Module):
    method __init__ (line 372) | def __init__(self,
    method forward (line 401) | def forward(self, questions):
  class VqaLstmCnnAttentionModel (line 407) | class VqaLstmCnnAttentionModel(nn.Module):
    method __init__ (line 408) | def __init__(self,
    method forward (line 451) | def forward(self, images, questions):
  class NavCnnModel (line 487) | class NavCnnModel(nn.Module):
    method __init__ (line 488) | def __init__(self,
    method forward (line 531) | def forward(self, img_feats, questions=None):
  class NavRnnMult (line 551) | class NavRnnMult(nn.Module):
    method __init__ (line 552) | def __init__(self,
    method init_hidden (line 613) | def init_hidden(self, bsz):
    method forward (line 626) | def forward(self,
    method step_forward (line 670) | def step_forward(self, img_feats, question_feats, actions_in, hidden):
  class NavRnn (line 705) | class NavRnn(nn.Module):
    method __init__ (line 706) | def __init__(self,
    method init_hidden (line 767) | def init_hidden(self, bsz):
    method forward (line 780) | def forward(self,
    method step_forward (line 823) | def step_forward(self, img_feats, question_feats, actions_in, hidden):
  class NavCnnRnnMultModel (line 856) | class NavCnnRnnMultModel(nn.Module):
    method __init__ (line 857) | def __init__(
    method forward (line 910) | def forward(self,
  class NavCnnRnnModel (line 943) | class NavCnnRnnModel(nn.Module):
    method __init__ (line 944) | def __init__(
    method forward (line 997) | def forward(self,
  class NavPlannerControllerModel (line 1030) | class NavPlannerControllerModel(nn.Module):
    method __init__ (line 1031) | def __init__(self,
    method forward (line 1092) | def forward(self,
    method planner_step (line 1151) | def planner_step(self, questions, img_feats, actions_in, planner_hidden):
    method controller_step (line 1161) | def controller_step(self, img_feats, actions_in, hidden_in):

FILE: training/train_eqa.py
  function eval (line 29) | def eval(rank, args, shared_nav_model, shared_ans_model):
  function train (line 404) | def train(rank, args, shared_nav_model, shared_ans_model):

FILE: training/train_nav.py
  function _rebuild_tensor_v2 (line 28) | def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_g...
  function eval (line 37) | def eval(rank, args, shared_model):
  function train (line 729) | def train(rank, args, shared_model):

FILE: training/train_vqa.py
  function eval (line 28) | def eval(rank, args, shared_model):
  function train (line 157) | def train(rank, args, shared_model):

FILE: training/utils/preprocess_questions.py
  function tokenize (line 21) | def tokenize(seq,
  function buildVocab (line 41) | def buildVocab(sequences,
  function encode (line 77) | def encode(seqTokens, tokenToIdx, allowUnk=False):
  function decode (line 89) | def decode(seqIdx, idxToToken, delim=None, stopAtEnd=True):
  function preprocessImages (line 101) | def preprocessImages(obj, render_dir=False):
  function processActions (line 123) | def processActions(actions):

FILE: training/utils/preprocess_questions_pkl.py
  function tokenize (line 17) | def tokenize(seq,
  function buildVocab (line 37) | def buildVocab(sequences,
  function encode (line 73) | def encode(seqTokens, tokenToIdx, allowUnk=False):
  function decode (line 85) | def decode(seqIdx, idxToToken, delim=None, stopAtEnd=True):
  function preprocessImages (line 97) | def preprocessImages(obj, render_dir=False):
  function processActions (line 119) | def processActions(actions):

FILE: utils/house3d.py
  class House3DUtils (line 20) | class House3DUtils():
    method __init__ (line 21) | def __init__(
    method calibrate_steps (line 83) | def calibrate_steps(self, reset=True):
    method step (line 118) | def step(self, action, step_reward=False):
    method get_dist_to_target (line 159) | def get_dist_to_target(self, pos):
    method is_inside_room (line 167) | def is_inside_room(self, pos, room):
    method build_graph (line 180) | def build_graph(self, save_path=None):
    method load_graph (line 266) | def load_graph(self, path):
    method compute_shortest_path (line 280) | def compute_shortest_path(self, source, target, graph=None):
    method fit_grid_path_to_suncg (line 302) | def fit_grid_path_to_suncg(self, nodes, init_yaw=None, back_skip=2):
    method get_rotate_steps (line 395) | def get_rotate_steps(self, pos, target_yaw):
    method _vec_to_array (line 419) | def _vec_to_array(self, pos, yaw):
    method render_images_from_pos_queue (line 423) | def render_images_from_pos_queue(self,
    method render_video_from_pos_queue (line 478) | def render_video_from_pos_queue(self,
    method _parse (line 514) | def _parse(self, levelsToExplore=[0]):
    method spawn_room (line 592) | def spawn_room(self, room=None):
    method spawn_object (line 620) | def spawn_object(self, obj=None, room=None):
    method set_target_object (line 684) | def set_target_object(self, obj, room):
    method _load_semantic_classes (line 817) | def _load_semantic_classes(self, color_file=None):
    method _get_best_yaw_obj_from_pos (line 832) | def _get_best_yaw_obj_from_pos(self, obj_id, grid_pos, height=1.0):
    method _get_best_view_obj (line 863) | def _get_best_view_obj(self,

FILE: utils/make_houses.py
  function extract_threaded (line 32) | def extract_threaded(house):
Condensed preview — 17 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (268K chars).
[
  {
    "path": ".gitignore",
    "chars": 2098,
    "preview": "tmp\ndata\nlogs\ncheckpoints\n*.pem\n*.sh\n*autoenv*\n\n# PYTHON\n# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n"
  },
  {
    "path": ".gitmodules",
    "chars": 81,
    "preview": "[submodule \"House3D\"]\n\tpath = House3D\n\turl = git@github.com:abhshkdz/House3D.git\n"
  },
  {
    "path": "CODE_OF_CONDUCT.md",
    "chars": 284,
    "preview": "# Code of Conduct\n\nFacebook has adopted a Code of Conduct that we expect project participants to adhere to.\nPlease read "
  },
  {
    "path": "CONTRIBUTING.md",
    "chars": 1545,
    "preview": "# Contributing to EmbodiedQA\nWe want to make contributing to this project as easy and transparent as\npossible.\n\n## Our D"
  },
  {
    "path": "LICENSE",
    "chars": 1537,
    "preview": "BSD License\n\nFor EmbodiedQA software\n\nCopyright (c) Facebook, Inc. and its affiliates. All rights reserved.\n\nRedistribut"
  },
  {
    "path": "README.md",
    "chars": 7715,
    "preview": "# EmbodiedQA\n\nCode for the paper\n\n**[Embodied Question Answering][1]**  \nAbhishek Das, Samyak Datta, Georgia Gkioxari, S"
  },
  {
    "path": "requirements.txt",
    "chars": 250,
    "preview": "certifi==2018.4.16\nchardet==3.0.4\nfuture==0.16.0\ngym==0.10.5\nh5py==2.8.0\nidna==2.6\nnumpy==1.14.4\nopencv-python==3.4.1.15"
  },
  {
    "path": "training/data.py",
    "chars": 37697,
    "preview": "import math\nimport time\nimport h5py\nimport logging\nimport argparse\nimport numpy as np\nimport os, sys, json\nfrom tqdm imp"
  },
  {
    "path": "training/metrics.py",
    "chars": 3007,
    "preview": "# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the BSD"
  },
  {
    "path": "training/models.py",
    "chars": 41613,
    "preview": "# Model defs for navigation and question answering\n# Navigation: CNN, LSTM, Planner-controller\n# VQA: question-only, 5-f"
  },
  {
    "path": "training/train_eqa.py",
    "chars": 37090,
    "preview": "# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the BSD"
  },
  {
    "path": "training/train_nav.py",
    "chars": 54328,
    "preview": "import time\nimport argparse\nfrom datetime import datetime\nimport logging\nimport numpy as np\nimport os\nimport torch\nimpor"
  },
  {
    "path": "training/train_vqa.py",
    "chars": 13629,
    "preview": "# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the BSD"
  },
  {
    "path": "training/utils/preprocess_questions.py",
    "chars": 12048,
    "preview": "# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the BSD"
  },
  {
    "path": "training/utils/preprocess_questions_pkl.py",
    "chars": 11969,
    "preview": "# adapted from https://github.com/facebookresearch/clevr-iep/blob/master/iep/preprocess.py\n\nimport h5py\nimport argparse\n"
  },
  {
    "path": "utils/house3d.py",
    "chars": 33771,
    "preview": "# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the BSD"
  },
  {
    "path": "utils/make_houses.py",
    "chars": 1301,
    "preview": "# Copyright (c) Facebook, Inc. and its affiliates.\n# All rights reserved.\n#\n# This source code is licensed under the BSD"
  }
]

About this extraction

This page contains the full source code of the facebookresearch/EmbodiedQA GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 17 files (253.9 KB), approximately 56.4k tokens, and a symbol index with 122 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!