Repository: Yunbo426/MIM
Branch: master
Commit: a61eb6a1be20
Files: 22
Total size: 72.4 KB

Directory structure:
gitextract_lnbn37h0/

├── README.md
├── data/
│   └── human36m.sh
├── run.py
└── src/
    ├── __init__.py
    ├── data_provider/
    │   ├── __init__.py
    │   ├── datasets_factory.py
    │   ├── human.py
    │   ├── mnist.py
    │   └── taxibj.py
    ├── layers/
    │   ├── MIMBlock.py
    │   ├── MIMN.py
    │   ├── SpatioTemporalLSTMCellv2.py
    │   ├── TensorLayerNorm.py
    │   └── __init__.py
    ├── models/
    │   ├── __init__.py
    │   ├── mim.py
    │   └── model_factory.py
    ├── trainer.py
    └── utils/
        ├── __init__.py
        ├── metrics.py
        ├── optimizer.py
        └── preprocess.py

================================================
FILE CONTENTS
================================================

================================================
FILE: README.md
================================================
# Memory In Memory Networks

MIM is a neural network for video prediction and spatiotemporal modeling. It is based on the paper [Memory In Memory: A Predictive Neural Network for Learning Higher-Order Non-Stationarity from Spatiotemporal Dynamics](https://arxiv.org/pdf/1811.07490.pdf) to be presented at CVPR 2019.

## Abstract

Natural spatiotemporal processes can be highly non-stationary in many ways, e.g., low-level non-stationarity such as spatial correlations or temporal dependencies of local pixel values, and high-level non-stationarity such as the accumulation, deformation, or dissipation of radar echoes in precipitation forecasting.

We try to stationarize and approximate the non-stationary processes by modeling the differential signals with the MIM recurrent blocks. By stacking multiple MIM blocks, we could potentially handle higher-order non-stationarity. Our model achieves state-of-the-art results on three spatiotemporal prediction tasks across both synthetic and real-world data.
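
The following is a generic illustration of why differencing helps. It uses plain NumPy and is not the MIM recurrent block itself. For a signal with a quadratic trend, one difference still leaves a trend, but a second difference produces a constant series. This is the intuition the abstract relies on when it stacks blocks to handle higher-order non-stationarity.

```python
import numpy as np

# A toy signal with a quadratic trend: x_t = t^2
t = np.arange(10)
x = t ** 2

# First-order difference x_t - x_{t-1} = 2t - 1: still trending upward
d1 = np.diff(x)
# Second-order difference: constant, so the trend has been removed
d2 = np.diff(x, n=2)

print(d1)  # [ 1  3  5  7  9 11 13 15 17]
print(d2)  # [2 2 2 2 2 2 2 2]
```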

![model](https://github.com/ZJianjin/mim_images/blob/master/readme_structure.png)

## Pre-trained Models and Datasets

All pre-trained MIM models have been uploaded to [DROPBOX](https://www.dropbox.com/s/7kd82ijezk4lkmp/mim-lib.zip?dl=0) and [BAIDU YUN](https://pan.baidu.com/s/1O07H7l1NTWmAkx3UCDVMLA) (password: srhv).

The archive also includes our pre-processed training/testing data for Moving MNIST, Color-Changing Moving MNIST, and TaxiBJ.

For Human3.6M, download the data with `data/human36m.sh`.
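
If no shell is available, the same download steps can be written in Python. `download_h36m` and its `dry_run` flag are hypothetical and are not part of the repository. The base URL and subject list come from `data/human36m.sh`, and the sketch assumes the server still hosts those archives.

```python
import os
import tarfile
import urllib.request

# Base URL and subjects copied from data/human36m.sh
BASE_URL = "http://visiondata.cis.upenn.edu/volumetric/h36m"
SUBJECTS = ("S1", "S5", "S6", "S7", "S8", "S9", "S11")


def download_h36m(dest="human", subjects=SUBJECTS, dry_run=False):
    """Download each subject's .tar into dest, extract it there, then delete it.

    Returns the list of URLs; with dry_run=True nothing is fetched.
    """
    urls = ["%s/%s.tar" % (BASE_URL, s) for s in subjects]
    if dry_run:
        return urls
    os.makedirs(dest, exist_ok=True)
    for url in urls:
        tar_path = os.path.join(dest, os.path.basename(url))
        urllib.request.urlretrieve(url, tar_path)  # like wget
        with tarfile.open(tar_path) as tar:        # like tar -xf
            tar.extractall(dest)
        os.remove(tar_path)                        # like rm
    return urls
```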

## Generation Results

#### Moving MNIST

![mnist1](https://github.com/ZJianjin/mim_images/blob/master/mnist1.gif)

![mnist2](https://github.com/ZJianjin/mim_images/blob/master/mnist4.gif)

![mnist3](https://github.com/ZJianjin/mim_images/blob/master/mnist5.gif)

#### Color-Changing Moving MNIST

![mnistc1](https://github.com/ZJianjin/mim_images/blob/master/mnistc2.gif)

![mnistc2](https://github.com/ZJianjin/mim_images/blob/master/mnistc3.gif)

![mnistc3](https://github.com/ZJianjin/mim_images/blob/master/mnistc4.gif)

#### Radar Echoes

![radar1](https://github.com/ZJianjin/mim_images/blob/master/radar9.gif)

![radar2](https://github.com/ZJianjin/mim_images/blob/master/radar3.gif)

![radar3](https://github.com/ZJianjin/mim_images/blob/master/radar7.gif)

#### Human3.6M

![human1](https://github.com/ZJianjin/mim_images/blob/master/human3.gif)

![human2](https://github.com/ZJianjin/mim_images/blob/master/human5.gif)

![human3](https://github.com/ZJianjin/mim_images/blob/master/human10.gif)

## BibTeX
```
@article{wang2018memory,
  title={Memory In Memory: A Predictive Neural Network for Learning Higher-Order Non-Stationarity from Spatiotemporal Dynamics},
  author={Wang, Yunbo and Zhang, Jianjin and Zhu, Hongyu and Long, Mingsheng and Wang, Jianmin and Yu, Philip S},
  journal={arXiv preprint arXiv:1811.07490},
  year={2019}
}
```


================================================
FILE: data/human36m.sh
================================================
# Download H36M images
mkdir human
cd human
wget http://visiondata.cis.upenn.edu/volumetric/h36m/S1.tar
tar -xf S1.tar
rm S1.tar
wget http://visiondata.cis.upenn.edu/volumetric/h36m/S5.tar
tar -xf S5.tar
rm S5.tar
wget http://visiondata.cis.upenn.edu/volumetric/h36m/S6.tar
tar -xf S6.tar
rm S6.tar
wget http://visiondata.cis.upenn.edu/volumetric/h36m/S7.tar
tar -xf S7.tar
rm S7.tar
wget http://visiondata.cis.upenn.edu/volumetric/h36m/S8.tar
tar -xf S8.tar
rm S8.tar
wget http://visiondata.cis.upenn.edu/volumetric/h36m/S9.tar
tar -xf S9.tar
rm S9.tar
wget http://visiondata.cis.upenn.edu/volumetric/h36m/S11.tar
tar -xf S11.tar
rm S11.tar
cd ..


================================================
FILE: run.py
================================================
__author__ = 'yunbo'

import os

import tensorflow as tf
import numpy as np
from time import time

from src.data_provider import datasets_factory
from src.models.model_factory import Model
from src.utils import preprocess
import src.trainer as trainer

# -----------------------------------------------------------------------------
FLAGS = tf.app.flags.FLAGS

# os.environ["CUDA_VISIBLE_DEVICES"] = "2"

# mode
tf.app.flags.DEFINE_boolean('is_training', True, 'training or testing')

# data I/O
tf.app.flags.DEFINE_string('dataset_name', 'mnist',
                           'The name of dataset.')
tf.app.flags.DEFINE_string('train_data_paths',
                           'data/moving-mnist-example/moving-mnist-train.npz',
                           'train data paths.')
tf.app.flags.DEFINE_string('valid_data_paths',
                           'data/moving-mnist-example/moving-mnist-valid.npz',
                           'validation data paths.')
tf.app.flags.DEFINE_string('save_dir', 'checkpoints/mnist_predrnn_pp',
                           'dir to store trained net.')
tf.app.flags.DEFINE_string('gen_frm_dir', 'results/mnist_predrnn_pp',
                           'dir to store result.')
tf.app.flags.DEFINE_integer('input_length', 10,
                            'number of input frames.')
tf.app.flags.DEFINE_integer('total_length', 20,
                            'total input and output length.')
tf.app.flags.DEFINE_integer('img_width', 64,
                            'input image width.')
tf.app.flags.DEFINE_integer('img_channel', 1,
                            'number of image channel.')
# model[convlstm, predcnn, predrnn, predrnn_pp]
tf.app.flags.DEFINE_string('model_name', 'convlstm_net',
                           'The name of the architecture.')
tf.app.flags.DEFINE_string('pretrained_model', '',
                           'file of a pretrained model to initialize from.')
tf.app.flags.DEFINE_string('num_hidden', '64,64,64,64',
                           'COMMA separated number of units in a convlstm layer.')
tf.app.flags.DEFINE_integer('filter_size', 5,
                            'filter size of a convlstm layer.')
tf.app.flags.DEFINE_integer('stride', 1,
                            'stride of a convlstm layer.')
tf.app.flags.DEFINE_integer('patch_size', 1,
                            'patch size on one dimension.')
tf.app.flags.DEFINE_boolean('layer_norm', True,
                            'whether to apply tensor layer norm.')
# scheduled sampling
tf.app.flags.DEFINE_boolean('scheduled_sampling', True, 'for scheduled sampling')
tf.app.flags.DEFINE_integer('sampling_stop_iter', 50000, 'for scheduled sampling.')
tf.app.flags.DEFINE_float('sampling_start_value', 1.0, 'for scheduled sampling.')
tf.app.flags.DEFINE_float('sampling_changing_rate', 0.00002, 'for scheduled sampling.')
# optimization
tf.app.flags.DEFINE_float('lr', 0.001,
                          'base learning rate.')
tf.app.flags.DEFINE_boolean('reverse_input', True,
                            'whether to reverse the input frames while training.')
tf.app.flags.DEFINE_boolean('reverse_img', False,
                            'whether to reverse the input images while training.')
tf.app.flags.DEFINE_integer('batch_size', 8,
                            'batch size for training.')
tf.app.flags.DEFINE_integer('max_iterations', 80000,
                            'max num of steps.')
tf.app.flags.DEFINE_integer('display_interval', 1,
                            'number of iters showing training loss.')
tf.app.flags.DEFINE_integer('test_interval', 1000,
                            'number of iters for test.')
tf.app.flags.DEFINE_integer('snapshot_interval', 1000,
                            'number of iters saving models.')
tf.app.flags.DEFINE_integer('num_save_samples', 10,
                            'number of sequences to be saved.')
tf.app.flags.DEFINE_integer('n_gpu', 1,
                            'how many GPUs to distribute the training across.')
# gpu 
tf.app.flags.DEFINE_boolean('allow_gpu_growth', False,
                            'allow gpu growth')

tf.app.flags.DEFINE_integer('img_height', 0,
                            'input image height.')


def main(argv=None):
    if tf.gfile.Exists(FLAGS.save_dir):
        tf.gfile.DeleteRecursively(FLAGS.save_dir)
    tf.gfile.MakeDirs(FLAGS.save_dir)
    if tf.gfile.Exists(FLAGS.gen_frm_dir):
        tf.gfile.DeleteRecursively(FLAGS.gen_frm_dir)
    tf.gfile.MakeDirs(FLAGS.gen_frm_dir)

    gpu_list = np.asarray(os.environ.get('CUDA_VISIBLE_DEVICES', '-1').split(','), dtype=np.int32)
    FLAGS.n_gpu = len(gpu_list)
    print('Initializing models')

    model = Model(FLAGS)

    if FLAGS.is_training:
        train_wrapper(model)
    else:
        start = time()
        test_wrapper(model)
        stop = time()
        print("Time used: " + str(stop - start) + "s")


def schedule_sampling(eta, itr):
    if FLAGS.img_height > 0:
        height = FLAGS.img_height
    else:
        height = FLAGS.img_width
    zeros = np.zeros((FLAGS.batch_size,
                      FLAGS.total_length - FLAGS.input_length - 1,
                      FLAGS.img_width // FLAGS.patch_size,
                      height // FLAGS.patch_size,
                      FLAGS.patch_size ** 2 * FLAGS.img_channel))
    if not FLAGS.scheduled_sampling:
        return 0.0, zeros

    if itr < FLAGS.sampling_stop_iter:
        eta -= FLAGS.sampling_changing_rate
    else:
        eta = 0.0
    random_flip = np.random.random_sample(
        (FLAGS.batch_size, FLAGS.total_length - FLAGS.input_length - 1))
    true_token = (random_flip < eta)
    ones = np.ones((FLAGS.img_width // FLAGS.patch_size,
                    height // FLAGS.patch_size,
                    FLAGS.patch_size ** 2 * FLAGS.img_channel))
    zeros = np.zeros((FLAGS.img_width // FLAGS.patch_size,
                      height // FLAGS.patch_size,
                      FLAGS.patch_size ** 2 * FLAGS.img_channel))
    real_input_flag = []
    for i in range(FLAGS.batch_size):
        for j in range(FLAGS.total_length - FLAGS.input_length - 1):
            if true_token[i, j]:
                real_input_flag.append(ones)
            else:
                real_input_flag.append(zeros)
    real_input_flag = np.array(real_input_flag)
    real_input_flag = np.reshape(real_input_flag,
                                 (FLAGS.batch_size,
                                  FLAGS.total_length - FLAGS.input_length - 1,
                                  FLAGS.img_width // FLAGS.patch_size,
                                  height // FLAGS.patch_size,
                                  FLAGS.patch_size ** 2 * FLAGS.img_channel))
    return eta, real_input_flag


def train_wrapper(model):
    if FLAGS.pretrained_model:
        model.load(FLAGS.pretrained_model)
    # load data
    train_input_handle, test_input_handle = datasets_factory.data_provider(
        FLAGS.dataset_name, FLAGS.train_data_paths, FLAGS.valid_data_paths,
        FLAGS.batch_size * FLAGS.n_gpu, FLAGS.img_width, seq_length=FLAGS.total_length, is_training=True)

    eta = FLAGS.sampling_start_value

    for itr in range(1, FLAGS.max_iterations + 1):
        if train_input_handle.no_batch_left():
            train_input_handle.begin(do_shuffle=True)
        ims = train_input_handle.get_batch()
        ims_reverse = None
        if FLAGS.reverse_img:
            ims_reverse = ims[:, :, :, ::-1]
            ims_reverse = preprocess.reshape_patch(ims_reverse, FLAGS.patch_size)
        ims = preprocess.reshape_patch(ims, FLAGS.patch_size)

        eta, real_input_flag = schedule_sampling(eta, itr)

        trainer.train(model, ims, real_input_flag, FLAGS, itr, ims_reverse)

        if itr % FLAGS.snapshot_interval == 0:
            model.save(itr)

        if itr % FLAGS.test_interval == 0:
            trainer.test(model, test_input_handle, FLAGS, itr)

        train_input_handle.next()


def test_wrapper(model):
    model.load(FLAGS.pretrained_model)
    test_input_handle = datasets_factory.data_provider(
        FLAGS.dataset_name, FLAGS.train_data_paths, FLAGS.valid_data_paths,
        FLAGS.batch_size * FLAGS.n_gpu, FLAGS.img_width, seq_length=FLAGS.total_length, is_training=False)
    trainer.test(model, test_input_handle, FLAGS, 'test_result')


if __name__ == '__main__':
    tf.app.run()
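
`schedule_sampling` draws one Bernoulli(`eta`) token for each sequence and future step. It then expands every token into a frame-sized block of ones or zeros; in PredRNN-style models, ones usually mark steps that receive the ground-truth frame. The same mask can be built with broadcasting in place of the Python loop. `schedule_sampling_mask` is a hypothetical helper that is not part of the repository. It takes explicit sizes in place of `FLAGS`, uses the flag defaults for `stop_iter` and `changing_rate`, and assumes the same patched layout `(batch, steps, width // patch, height // patch, patch**2 * channels)`. Unlike the original, it does not cover the `scheduled_sampling=False` case.

```python
import numpy as np


def schedule_sampling_mask(eta, itr, batch_size, num_steps, width, height,
                           channels, stop_iter=50000, changing_rate=0.00002,
                           rng=None):
    """Vectorized equivalent of the nested loop in schedule_sampling.

    num_steps corresponds to total_length - input_length - 1.
    """
    rng = np.random.default_rng() if rng is None else rng
    # Decay eta linearly, then drop it to zero after stop_iter
    eta = eta - changing_rate if itr < stop_iter else 0.0
    # One Bernoulli(eta) draw per (sequence, future step)
    true_token = rng.random((batch_size, num_steps)) < eta
    # Broadcast each token over the spatial and channel axes
    mask = np.broadcast_to(true_token[:, :, None, None, None],
                           (batch_size, num_steps, width, height, channels))
    return eta, mask.astype(np.float64)
```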



================================================
FILE: src/__init__.py
================================================


================================================
FILE: src/data_provider/__init__.py
================================================


================================================
FILE: src/data_provider/datasets_factory.py
================================================
from src.data_provider import mnist, human, taxibj

datasets_map = {
    'mnist': mnist,
    'taxibj': taxibj,
    'human': human
}


def data_provider(dataset_name, train_data_paths, valid_data_paths, batch_size,
                  img_width, seq_length=20, is_training=True):
    '''Given a dataset name, returns the matching input handle(s).
    Args:
        dataset_name: String, the name of the dataset ('mnist', 'taxibj' or 'human').
        train_data_paths: String, comma-separated training data paths.
        valid_data_paths: String, comma-separated validation data paths
            ('human' and 'taxibj' read both splits from these paths).
        batch_size: Int
        img_width: Int
        seq_length: Int, number of frames per sequence (unused for 'mnist').
        is_training: Bool
    Returns:
        if is_training:
            Two dataset instances for both training and evaluation.
        else:
            One dataset instance for evaluation.
    Raises:
        ValueError: If `dataset_name` is unknown.
    '''
    if dataset_name not in datasets_map:
        raise ValueError('Name of dataset unknown %s' % dataset_name)
    train_data_list = train_data_paths.split(',')
    valid_data_list = valid_data_paths.split(',')
    if dataset_name == 'mnist':
        test_input_param = {'paths': valid_data_list,
                            'minibatch_size': batch_size,
                            'input_data_type': 'float32',
                            'is_output_sequence': True,
                            'name': dataset_name + ' test iterator'}
        test_input_handle = datasets_map[dataset_name].InputHandle(test_input_param)
        test_input_handle.begin(do_shuffle=False)
        if is_training:
            train_input_param = {'paths': train_data_list,
                                 'minibatch_size': batch_size,
                                 'input_data_type': 'float32',
                                 'is_output_sequence': True,
                                 'name': dataset_name + ' train iterator'}
            train_input_handle = datasets_map[dataset_name].InputHandle(train_input_param)
            train_input_handle.begin(do_shuffle=True)
            return train_input_handle, test_input_handle
        else:
            return test_input_handle

    if dataset_name == 'human':
        input_param = {'paths': valid_data_list,
                       'image_width': img_width,
                       'minibatch_size': batch_size,
                       'seq_length': seq_length,
                       'channel': 3,
                       'input_data_type': 'float32',
                       'name': 'human'}
        input_handle = datasets_map[dataset_name].DataProcess(input_param)
        test_input_handle = input_handle.get_test_input_handle()
        test_input_handle.begin(do_shuffle=False)
        if is_training:
            train_input_handle = input_handle.get_train_input_handle()
            train_input_handle.begin(do_shuffle=True)
            return train_input_handle, test_input_handle
        else:
            return test_input_handle

    if dataset_name == 'taxibj':
        input_param = {'paths': valid_data_list,
                       'image_width': img_width,
                       'minibatch_size': batch_size,
                       'seq_length': seq_length,
                       'input_data_type': 'float32',
                       'name': dataset_name + ' iterator'}
        input_handle = datasets_map[dataset_name].DataProcess(input_param)
        if is_training:
            train_input_handle = input_handle.get_train_input_handle()
            train_input_handle.begin(do_shuffle=True)
            test_input_handle = input_handle.get_test_input_handle()
            test_input_handle.begin(do_shuffle=False)
            return train_input_handle, test_input_handle
        else:
            test_input_handle = input_handle.get_test_input_handle()
            test_input_handle.begin(do_shuffle=False)
            return test_input_handle
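
`data_provider` returns handles that `begin()` has already rewound. `train_wrapper` in run.py then drives them through `no_batch_left()`, `get_batch()` and `next()`. The class below is a hypothetical toy stand-in, not part of the repository, that shows this protocol on its own. It uses the same end-of-epoch rule as `human.InputHandle`, which drops a partial final batch.

```python
import random

import numpy as np


class ToyInputHandle:
    """Hypothetical stand-in for the repository's InputHandle classes."""

    def __init__(self, data, minibatch_size):
        self.data = data
        self.minibatch_size = minibatch_size
        self.indices = list(range(len(data)))
        self.current_position = 0

    def begin(self, do_shuffle=True):
        # Rewind to the start of an epoch, optionally reshuffling
        if do_shuffle:
            random.shuffle(self.indices)
        self.current_position = 0

    def no_batch_left(self):
        # Same rule as human.InputHandle: a partial final batch is dropped
        return self.current_position + self.minibatch_size > len(self.indices)

    def get_batch(self):
        start = self.current_position
        return self.data[self.indices[start:start + self.minibatch_size]]

    def next(self):
        self.current_position += self.minibatch_size


def epoch_batch_sizes(handle):
    """Consume one pass over the handle and return the size of each batch."""
    handle.begin(do_shuffle=False)
    sizes = []
    while not handle.no_batch_left():
        sizes.append(handle.get_batch().shape[0])
        handle.next()
    return sizes
```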


================================================
FILE: src/data_provider/human.py
================================================
__author__ = 'jianjin'
import numpy as np
import os
import cv2
from PIL import Image
import logging
import random
import tensorflow as tf

logger = logging.getLogger(__name__)

class InputHandle:
    def __init__(self, datas, indices, input_param):
        self.name = input_param['name']
        self.input_data_type = input_param.get('input_data_type', 'float32')
        self.minibatch_size = input_param['minibatch_size']
        self.image_width = input_param['image_width']
        self.channel = input_param['channel']
        self.datas = datas
        self.indices = indices
        self.current_position = 0
        self.current_batch_indices = []
        self.current_input_length = input_param['seq_length']
        self.interval = 2

    def total(self):
        return len(self.indices)

    def begin(self, do_shuffle=True):
        logger.info("Initialization for read data ")
        if do_shuffle:
            random.shuffle(self.indices)
        self.current_position = 0
        self.current_batch_indices = self.indices[self.current_position:self.current_position + self.minibatch_size]

    def next(self):
        self.current_position += self.minibatch_size
        if self.no_batch_left():
            return None
        self.current_batch_indices = self.indices[self.current_position:self.current_position + self.minibatch_size]

    def no_batch_left(self):
        if self.current_position + self.minibatch_size > self.total():
            return True
        else:
            return False

    def get_batch(self):
        if self.no_batch_left():
            logger.error(
                "There is no batch left in " + self.name + ". Consider using iterators.begin() to rescan from the beginning of the iterators")
            return None
        input_batch = np.zeros(
            (self.minibatch_size, self.current_input_length, self.image_width, self.image_width, self.channel)).astype(
            self.input_data_type)
        for i in range(self.minibatch_size):
            batch_ind = self.current_batch_indices[i]
            begin = batch_ind
            end = begin + self.current_input_length * self.interval
            data_slice = self.datas[begin:end:self.interval]
            input_batch[i, :self.current_input_length, :, :, :] = data_slice
            # logger.info('data_slice shape')
            # logger.info(data_slice.shape)
            # logger.info(input_batch.shape)
        input_batch = input_batch.astype(self.input_data_type)
        return input_batch

    def print_stat(self):
        logger.info("Iterator Name: " + self.name)
        logger.info("    current_position: " + str(self.current_position))
        logger.info("    Minibatch Size: " + str(self.minibatch_size))
        logger.info("    total Size: " + str(self.total()))
        logger.info("    current_input_length: " + str(self.current_input_length))
        logger.info("    Input Data Type: " + str(self.input_data_type))

class DataProcess:
    def __init__(self, input_param):
        self.input_param = input_param
        self.paths = input_param['paths']
        self.image_width = input_param['image_width']
        self.seq_len = input_param['seq_length']

    def load_data(self, paths, mode='train'):
        data_dir = paths[0]
        intervel = 2

        frames_np = []
        scenarios = ['Walking']
        if mode == 'train':
            subjects = ['S1', 'S5', 'S6', 'S7', 'S8']
        elif mode == 'test':
            subjects = ['S9', 'S11']
        else:
            raise ValueError("mode must be 'train' or 'test', got %r" % mode)
        _path = data_dir
        print ('load data...', _path)
        filenames = os.listdir(_path)
        filenames.sort()
        print ('data size ', len(filenames))
        frames_file_name = []
        for filename in filenames:
            fix = filename.split('.')
            fix = fix[0]
            subject = fix.split('_')
            scenario = subject[1]
            subject = subject[0]
            if subject not in subjects or scenario not in scenarios:
                continue
            file_path = os.path.join(_path, filename)
            image = cv2.cvtColor(cv2.imread(file_path), cv2.COLOR_BGR2RGB)
            #[1000,1000,3]
            image = image[image.shape[0]//4:-image.shape[0]//4, image.shape[1]//4:-image.shape[1]//4, :]
            if self.image_width != image.shape[0]:
                image = cv2.resize(image, (self.image_width, self.image_width))
            #image = cv2.resize(image[100:-100,100:-100,:], (self.image_width, self.image_width),
            #                   interpolation=cv2.INTER_LINEAR)
            frames_np.append(np.array(image, dtype=np.float32) / 255.0)
            frames_file_name.append(filename)
#             if len(frames_np) % 100 == 0: print len(frames_np)
            #if len(frames_np) % 1000 == 0: break
        # is it a begin index of sequence
        indices = []
        index = 0
        print ('gen index')
        while index + intervel * self.seq_len - 1 < len(frames_file_name):
            # 'S11_Discussion_1.54138969_000471.jpg'
            # ['S11_Discussion_1', '54138969_000471', 'jpg']
            start_infos = frames_file_name[index].split('.')
            end_infos = frames_file_name[index+intervel*(self.seq_len-1)].split('.')
            if start_infos[0] != end_infos[0]:
                index += 1
                continue
            start_video_id, start_frame_id = start_infos[1].split('_')
            end_video_id, end_frame_id = end_infos[1].split('_')
            if start_video_id != end_video_id:
                index += 1
                continue
            if int(end_frame_id) - int(start_frame_id) == 5 * (self.seq_len - 1) * intervel:
                indices.append(index)
            if mode == 'train':
                index += 10
            elif mode == 'test':
                index += 5
        print("there are " + str(len(indices)) + " sequences")
        # data = np.asarray(frames_np)
        data = frames_np
        print("there are " + str(len(data)) + " pictures")
        return data, indices

    def get_train_input_handle(self):
        train_data, train_indices = self.load_data(self.paths, mode='train')
        return InputHandle(train_data, train_indices, self.input_param)

    def get_test_input_handle(self):
        test_data, test_indices = self.load_data(self.paths, mode='test')
        return InputHandle(test_data, test_indices, self.input_param)


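def main():
    # Sanity check: count frames and sequences in each partition of a frame directory.
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("input_dir", type=str)
    parser.add_argument("--image_width", type=int, default=64)
    parser.add_argument("--seq_length", type=int, default=20)
    args = parser.parse_args()

    processor = DataProcess({'paths': [args.input_dir], 'image_width': args.image_width,
                             'seq_length': args.seq_length})
    for mode in ['train', 'test']:
        data, indices = processor.load_data(processor.paths, mode=mode)
        print(mode, len(data), 'frames,', len(indices), 'sequences')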


if __name__ == '__main__':
    main()
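
`DataProcess.load_data` keeps a start index only under two conditions. First, the file `interval * (seq_len - 1)` positions later must belong to the same subject, action, take and camera. Second, its frame id must be exactly `5 * interval * (seq_len - 1)` larger. The check therefore assumes that the extracted JPEGs are every fifth raw frame and contain no gaps. The standalone sketch below reproduces this rule on names shaped like the one in the code comment (`S11_Discussion_1.54138969_000471.jpg`). `sequence_starts` and `clip_at` are hypothetical names. `stride` stands in for the step of 10 (train) or 5 (test) used in the original, and `clip_at` mirrors the strided slice in `InputHandle.get_batch`.

```python
def sequence_starts(filenames, seq_len, interval=2, frame_step=5, stride=10):
    """Start positions of valid clips in a sorted list of frame file names.

    Names look like 'S11_Discussion_1.54138969_000471.jpg': the part before the
    first '.' is subject_action_take, then camera id and frame id joined by '_'.
    """
    starts = []
    index = 0
    span = interval * (seq_len - 1)  # file offset of the clip's last frame
    while index + interval * seq_len - 1 < len(filenames):
        first = filenames[index].split('.')
        last = filenames[index + span].split('.')
        if first[0] != last[0]:  # different subject, action or take
            index += 1
            continue
        cam_first, frame_first = first[1].split('_')
        cam_last, frame_last = last[1].split('_')
        if cam_first != cam_last:  # different camera
            index += 1
            continue
        # No missing frames: stored ids advance by frame_step per file
        if int(frame_last) - int(frame_first) == frame_step * span:
            starts.append(index)
        index += stride
    return starts


def clip_at(frames, start, seq_len, interval=2):
    """Every `interval`-th frame from `start`, as in InputHandle.get_batch."""
    return frames[start:start + seq_len * interval:interval]
```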


================================================
FILE: src/data_provider/mnist.py
================================================
import numpy as np
import random

class InputHandle:
    def __init__(self, input_param):
        self.paths = input_param['paths']
        self.num_paths = len(input_param['paths'])
        self.name = input_param['name']
        self.input_data_type = input_param.get('input_data_type', 'float32')
        self.output_data_type = input_param.get('output_data_type', 'float32')
        self.minibatch_size = input_param['minibatch_size']
        self.is_output_sequence = input_param['is_output_sequence']
        self.data = {}
        self.indices = {}
        self.current_position = 0
        self.current_batch_size = 0
        self.current_batch_indices = []
        self.current_input_length = 0
        self.current_output_length = 0
        self.load()

    def load(self):
        dat_1 = np.load(self.paths[0])
        for key in dat_1.keys():
            self.data[key] = dat_1[key]
        if self.num_paths == 2:
            dat_2 = np.load(self.paths[1])
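            # np.load returns a lazy NpzFile: bind the array once so the edit sticks,
            # then shift its begin indices past the frames of the first file
            clips_2 = dat_2['clips']
            out_key = 'output_raw_data' if dat_1['dims'].shape == (2, 3) else 'input_raw_data'
            clips_2[0, :, 0] += dat_1['input_raw_data'].shape[0]
            clips_2[1, :, 0] += dat_1[out_key].shape[0]
            self.data['clips'] = np.concatenate(
                (dat_1['clips'], clips_2), axis=1)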
            self.data['input_raw_data'] = np.concatenate(
                (dat_1['input_raw_data'], dat_2['input_raw_data']), axis=0)
            self.data['output_raw_data'] = np.concatenate(
                (dat_1['output_raw_data'], dat_2['output_raw_data']), axis=0)
        for key in self.data.keys():
            print(key)
            print(self.data[key].shape)

    def total(self):
        return self.data['clips'].shape[1]

    def begin(self, do_shuffle = True):
        self.indices = np.arange(self.total(),dtype="int32")
        if do_shuffle:
            random.shuffle(self.indices)
        self.current_position = 0
        if self.current_position + self.minibatch_size <= self.total():
            self.current_batch_size = self.minibatch_size
        else:
            self.current_batch_size = self.total() - self.current_position
        self.current_batch_indices = self.indices[
            self.current_position:self.current_position + self.current_batch_size]
        self.current_input_length = max(self.data['clips'][0, ind, 1] for ind
                                        in self.current_batch_indices)
        self.current_output_length = max(self.data['clips'][1, ind, 1] for ind
                                         in self.current_batch_indices)

    def next(self):
        self.current_position += self.current_batch_size
        if self.no_batch_left():
            return None
        if self.current_position + self.minibatch_size <= self.total():
            self.current_batch_size = self.minibatch_size
        else:
            self.current_batch_size = self.total() - self.current_position
        self.current_batch_indices = self.indices[
            self.current_position:self.current_position + self.current_batch_size]
        self.current_input_length = max(self.data['clips'][0, ind, 1] for ind
                                        in self.current_batch_indices)
        self.current_output_length = max(self.data['clips'][1, ind, 1] for ind
                                         in self.current_batch_indices)

    def no_batch_left(self):
        # A full batch still fits while current_position + batch size <= total
        if self.current_position > self.total() - self.current_batch_size:
            return True
        else:
            return False

    def input_batch(self):
        if self.no_batch_left():
            return None
        input_batch = np.zeros(
            (self.current_batch_size, self.current_input_length) +
            tuple(self.data['dims'][0])).astype(self.input_data_type)
        input_batch = np.transpose(input_batch,(0,1,3,4,2))
        for i in range(self.current_batch_size):
            batch_ind = self.current_batch_indices[i]
            begin = self.data['clips'][0, batch_ind, 0]
            end = self.data['clips'][0, batch_ind, 0] + \
                    self.data['clips'][0, batch_ind, 1]
            data_slice = self.data['input_raw_data'][begin:end, :, :, :]
            data_slice = np.transpose(data_slice,(0,2,3,1))
            input_batch[i, :self.current_input_length, :, :, :] = data_slice
        input_batch = input_batch.astype(self.input_data_type)
        return input_batch

    def output_batch(self):
        if self.no_batch_left():
            return None
        if self.data['dims'].shape == (2, 3):
            raw_dat = self.data['output_raw_data']
        else:
            raw_dat = self.data['input_raw_data']
        if self.is_output_sequence:
            if (1, 3) == self.data['dims'].shape:
                output_dim = self.data['dims'][0]
            else:
                output_dim = self.data['dims'][1]
            output_batch = np.zeros(
                (self.current_batch_size,self.current_output_length) +
                tuple(output_dim))
        else:
            output_batch = np.zeros((self.current_batch_size, ) +
                                    tuple(self.data['dims'][1]))
        for i in range(self.current_batch_size):
            batch_ind = self.current_batch_indices[i]
            begin = self.data['clips'][1, batch_ind, 0]
            end = self.data['clips'][1, batch_ind, 0] + \
                    self.data['clips'][1, batch_ind, 1]
            if self.is_output_sequence:
                data_slice = raw_dat[begin:end, :, :, :]
                output_batch[i, : data_slice.shape[0], :, :, :] = data_slice
            else:
                data_slice = raw_dat[begin, :, :, :]
                output_batch[i,:, :, :] = data_slice
        output_batch = output_batch.astype(self.output_data_type)
        output_batch = np.transpose(output_batch, [0,1,3,4,2])
        return output_batch

    def get_batch(self):
        input_seq = self.input_batch()
        output_seq = self.output_batch()
        batch = np.concatenate((input_seq, output_seq), axis=1)
        return batch
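
`InputHandle` reads the input frames of clip `i` through `clips[0, i] = (begin, length)` and its output frames through `clips[1, i]`. Frames are stored channel-first and transposed to channel-last when a batch is built. The toy arrays below are invented for illustration: three 20-frame sequences, each split into 10 input and 10 output frames. They are not the real Moving MNIST file. The last part shows a NumPy behaviour that matters when editing loaded `.npz` data. `np.load` returns an `NpzFile`, and every key access reads a new array, so an in-place update on `dat['clips']` is lost unless the array is first bound to a name.

```python
import io

import numpy as np

# Toy data (invented): 3 sequences of 20 frames, each split 10 input + 10 output
num_seq, in_len, out_len = 3, 10, 10
frames = np.arange(num_seq * (in_len + out_len)).reshape(-1, 1, 1, 1)  # (N, C, H, W)
clips = np.zeros((2, num_seq, 2), dtype=np.int32)
for i in range(num_seq):
    start = i * (in_len + out_len)
    clips[0, i] = (start, in_len)            # input frames: (begin, length)
    clips[1, i] = (start + in_len, out_len)  # output frames: (begin, length)

# Slice sequence 1 the way input_batch / output_batch do
begin, length = clips[0, 1]
inputs = frames[begin:begin + length]
begin, length = clips[1, 1]
outputs = frames[begin:begin + length]
# Channel-first (T, C, H, W) -> channel-last (T, H, W, C), as in input_batch
inputs_hwc = np.transpose(inputs, (0, 2, 3, 1))

# NpzFile gotcha: every dat['clips'] access reads a fresh array from the archive
buf = io.BytesIO()
np.savez(buf, clips=clips)
buf.seek(0)
dat = np.load(buf)
dat['clips'][:, :, 0] += 100   # edits a temporary array, which is discarded
lost = dat['clips'][0, 0, 0]   # still 0
kept = dat['clips']            # bind the array once...
kept[:, :, 0] += 100           # ...so the offset survives
```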


================================================
FILE: src/data_provider/taxibj.py
================================================
__author__ = 'jianjin'

import random
import os.path
import logging
import os
from copy import copy
import numpy as np
import h5py
import pandas as pd
from datetime import datetime
import time

logger = logging.getLogger(__name__)


def string2timestamp(strings, T=48):
    timestamps = []

    time_per_slot = 24.0 / T
    num_per_T = T // 24
    for t in strings:
        year, month, day, slot = int(t[:4]), int(t[4:6]), int(t[6:8]), int(t[8:])-1
        timestamps.append(pd.Timestamp(datetime(year, month, day, hour=int(slot * time_per_slot),
                                                minute=(slot % num_per_T) * int(60.0 * time_per_slot))))

    return timestamps
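
With the default `T=48`, each slot lasts half an hour. The 1-based slot suffix of a `YYYYMMDDSS` key therefore maps to a start time `(slot - 1) * 30` minutes after midnight. The pandas-free sketch below uses the hypothetical name `slot_to_datetime`. For `T=48` it computes the same hour and minute as `string2timestamp`. It also accepts `bytes` keys, on the assumption that h5py may return the `date` strings as bytes.

```python
from datetime import datetime


def slot_to_datetime(key, T=48):
    """Convert a 'YYYYMMDDSS' key (SS = 1-based slot of the day) to a datetime."""
    if isinstance(key, bytes):  # assumption: h5py may hand back bytes
        key = key.decode('ascii')
    year, month, day = int(key[:4]), int(key[4:6]), int(key[6:8])
    slot = int(key[8:]) - 1            # 0-based slot index
    minutes = slot * (24 * 60 // T)    # each slot lasts 1440 / T minutes
    return datetime(year, month, day, hour=minutes // 60, minute=minutes % 60)
```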


class STMatrix(object):
    """Spatio-temporal matrix: frames indexed by timestamp, with an optional completeness check."""

    def __init__(self, data, timestamps, T=48, CheckComplete=True):
        super(STMatrix, self).__init__()
        assert len(data) == len(timestamps)
        self.data = data
        self.timestamps = timestamps
        self.T = T
        self.pd_timestamps = string2timestamp(timestamps, T=self.T)
        if CheckComplete:
            self.check_complete()
        # index
        self.make_index()

    def make_index(self):
        self.get_index = dict()
        for i, ts in enumerate(self.pd_timestamps):
            self.get_index[ts] = i

    def check_complete(self):
        missing_timestamps = []
        offset = pd.DateOffset(minutes=24 * 60 // self.T)
        pd_timestamps = self.pd_timestamps
        i = 1
        while i < len(pd_timestamps):
            if pd_timestamps[i-1] + offset != pd_timestamps[i]:
                missing_timestamps.append("(%s -- %s)" % (pd_timestamps[i-1], pd_timestamps[i]))
            i += 1
        for v in missing_timestamps:
            print(v)
        assert len(missing_timestamps) == 0

    def get_matrix(self, timestamp):
        return self.data[self.get_index[timestamp]]

    def save(self, fname):
        pass

    def check_it(self, depends):
        for d in depends:
            if d not in self.get_index.keys():
                return False
        return True

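    def create_dataset(self, len_closeness=20):
        """Build input sequences of `len_closeness` frames.

        For each timestamp t whose `len_closeness` preceding slots are all
        present, stack frames t-1, t-2, ..., t-len_closeness (most recent
        first, channels last). Returns the stacked sequences and the
        timestamps t that follow them.
        """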
        # offset_week = pd.DateOffset(days=7)
        offset_frame = pd.DateOffset(minutes=24 * 60 // self.T)
        XC = []
        timestamps_Y = []
        depends = [range(1, len_closeness+1)]

        i = len_closeness
        while i < len(self.pd_timestamps):
            Flag = True
            for depend in depends:
                if Flag is False:
                    break
                Flag = self.check_it([self.pd_timestamps[i] - j * offset_frame for j in depend])

            if Flag is False:
                i += 1
                continue
            x_c = [np.transpose(self.get_matrix(self.pd_timestamps[i] - j * offset_frame), [1, 2, 0]) for j in depends[0]]
            if len_closeness > 0:
                XC.append(np.stack(x_c, axis=0))
            timestamps_Y.append(self.timestamps[i])
            i += 1
        XC = np.stack(XC, axis=0)
        return XC, timestamps_Y
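# Illustrative sketch (hypothetical helper, not part of the original code):
# when no timestamps are missing, create_dataset reduces to a sliding window.
# For every index i >= len_closeness it stacks frames i-1, i-2, ...,
# i-len_closeness (most recent first) and moves channels last with the
# [1, 2, 0] transpose. This assumes `data` is shaped (N, C, H, W) with no gaps.
import numpy as np


def _sketch_sliding_windows(data, len_closeness):
    windows = []
    for i in range(len_closeness, len(data)):
        frames = [np.transpose(data[i - j], [1, 2, 0])
                  for j in range(1, len_closeness + 1)]
        windows.append(np.stack(frames, axis=0))
    # resulting shape: (N - len_closeness, len_closeness, H, W, C)
    return np.stack(windows, axis=0)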


def load_stdata(fname):
    # Dataset.value was removed in h5py 3.0; [()] reads the full dataset.
    with h5py.File(fname, 'r') as f:
        data = f['data'][()]
        timestamps = f['date'][()]
    return data, timestamps


def stat(fname):
    def get_nb_timeslot(f):
        s = f['date'][0]
        e = f['date'][-1]
        year, month, day = map(int, [s[:4], s[4:6], s[6:8]])
        ts = time.strptime("%04i-%02i-%02i" % (year, month, day), "%Y-%m-%d")
        year, month, day = map(int, [e[:4], e[4:6], e[6:8]])
        te = time.strptime("%04i-%02i-%02i" % (year, month, day), "%Y-%m-%d")
        nb_timeslot = (time.mktime(te) - time.mktime(ts)) / (0.5 * 3600) + 48
        ts_str, te_str = time.strftime("%Y-%m-%d", ts), time.strftime("%Y-%m-%d", te)
        return nb_timeslot, ts_str, te_str

    with h5py.File(fname, 'r') as f:
        nb_timeslot, ts_str, te_str = get_nb_timeslot(f)
        nb_day = int(nb_timeslot / 48)
        data = f['data'][()]  # Dataset.value was removed in h5py 3.0
        mmax = data.max()
        mmin = data.min()
        stat = '=' * 5 + 'stat' + '=' * 5 + '\n' + \
               'data shape: %s\n' % str(f['data'].shape) + \
               '# of days: %i, from %s to %s\n' % (nb_day, ts_str, te_str) + \
               '# of timeslots: %i\n' % int(nb_timeslot) + \
               '# of timeslots (available): %i\n' % f['date'].shape[0] + \
               'missing ratio of timeslots: %.1f%%\n' % ((1. - float(f['date'].shape[0] / nb_timeslot)) * 100) + \
               'max: %.3f, min: %.3f\n' % (mmax, mmin) + \
               '=' * 5 + 'stat' + '=' * 5
        print(stat)


class MinMaxNormalization(object):
    '''MinMax Normalization --> [0, 1]
       x = (x - min) / (max - min)
       The rescaling to [-1, 1] (x = x * 2 - 1) is disabled in transform().
    '''

    def __init__(self):
        pass

    def fit(self, X):
        self._min = X.min()
        self._max = X.max()
        print("min:", self._min, "max:", self._max)

    def transform(self, X):
        X = 1. * (X - self._min) / (self._max - self._min)
        # X = X * 2. - 1.
        return X

    def fit_transform(self, X):
        self.fit(X)
        return self.transform(X)

    def inverse_transform(self, X):
        X = 1. * X * (self._max - self._min) + self._min
        return X
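# Illustrative sketch (hypothetical helper, not part of the original code):
# with the "* 2 - 1" step disabled, transform maps values to [0, 1], so the
# exact inverse is x * (max - min) + min. Fitting on the training portion
# only (as load_data does) means test values may fall slightly outside [0, 1].
import numpy as np


def _sketch_minmax_roundtrip(train, x):
    lo, hi = train.min(), train.max()
    scaled = (x - lo) / (hi - lo)
    restored = scaled * (hi - lo) + lo
    return scaled, restored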


def timestamp2vec(timestamps):
    # tm_wday range [0, 6], Monday is 0
    # h5py yields bytes under Python 3, so decode before parsing
    vec = [time.strptime(t[:8].decode('utf-8') if isinstance(t, bytes) else t[:8], '%Y%m%d').tm_wday
           for t in timestamps]
    ret = []
    for i in vec:
        v = [0 for _ in range(7)]
        v[i] = 1
        if i >= 5:
            v.append(0)  # weekend
        else:
            v.append(1)  # weekday
        ret.append(v)
    return np.asarray(ret)
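# Illustrative sketch (hypothetical helper, not part of the original code):
# timestamp2vec one-hot encodes the weekday (Monday = index 0) and appends a
# flag that is 1 on weekdays and 0 on weekends, giving an 8-element vector
# per timestamp.
import time


def _sketch_weekday_vec(t):
    if isinstance(t, bytes):
        t = t.decode('utf-8')
    wday = time.strptime(t[:8], '%Y%m%d').tm_wday
    vec = [0] * 7
    vec[wday] = 1
    vec.append(0 if wday >= 5 else 1)  # weekend -> 0, weekday -> 1
    return vec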


def remove_incomplete_days(data, timestamps, T=48):
    # Remove every day that does not have all T timestamps.
    days = []  # complete days; some days only contain part of the slots
    days_incomplete = []
    i = 0
    while i < len(timestamps):
        if int(timestamps[i][8:]) != 1:
            i += 1
        elif i+T-1 < len(timestamps) and int(timestamps[i+T-1][8:]) == T:
            days.append(timestamps[i][:8])
            i += T
        else:
            days_incomplete.append(timestamps[i][:8])
            i += 1
    print("incomplete days: ", days_incomplete)
    days = set(days)
    idx = []
    for i, t in enumerate(timestamps):
        if t[:8] in days:
            idx.append(i)

    data = data[idx]
    timestamps = [timestamps[i] for i in idx]
    return data, timestamps
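# Illustrative sketch (hypothetical helper, not part of the original code):
# an alternative completeness check that counts distinct slots per day
# instead of scanning for slot 1 followed T - 1 positions later by slot T.
# It assumes 'YYYYMMDDSS' timestamps; a day is kept only if all T slots
# are present.
from collections import defaultdict


def _sketch_complete_days(timestamps, T=48):
    slots = defaultdict(set)
    for t in timestamps:
        slots[t[:8]].add(int(t[8:]))
    return {day for day, s in slots.items() if s == set(range(1, T + 1))}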


class InputHandle:
    def __init__(self, datas, indices, input_param):
        self.name = input_param['name']
        self.input_data_type = input_param.get('input_data_type', 'float32')
        self.minibatch_size = input_param['minibatch_size']
        self.image_width = input_param['image_width']
        self.datas = datas
        self.indices = indices
        self.current_position = 0
        self.current_batch_indices = []
        self.current_input_length = input_param['seq_length']

    def total(self):
        return len(self.indices)

    def begin(self, do_shuffle=True):
        logger.info("Initializing data reader")
        if do_shuffle:
            random.shuffle(self.indices)
        self.current_position = 0
        self.current_batch_indices = self.indices[self.current_position:self.current_position + self.minibatch_size]

    def next(self):
        self.current_position += self.minibatch_size
        if self.no_batch_left():
            return None
        self.current_batch_indices = self.indices[self.current_position:self.current_position + self.minibatch_size]

    def no_batch_left(self):
        # A full batch remains as long as it fits within the index list.
        return self.current_position + self.minibatch_size > self.total()

    def get_batch(self):
        if self.no_batch_left():
            logger.error(
                "There is no batch left in " + self.name + ". Consider using begin() to rescan from the beginning of the iterator.")
            return None
        input_batch = self.datas[self.current_batch_indices, :, :, :]
        input_batch = input_batch.astype(self.input_data_type)
        return input_batch

    def print_stat(self):
        logger.info("Iterator Name: " + self.name)
        logger.info("    current_position: " + str(self.current_position))
        logger.info("    Minibatch Size: " + str(self.minibatch_size))
        logger.info("    total Size: " + str(self.total()))
        logger.info("    current_input_length: " + str(self.current_input_length))
        logger.info("    Input Data Type: " + str(self.input_data_type))
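# Illustrative sketch (hypothetical helper, not part of the original code):
# the begin()/next()/no_batch_left() protocol walks the (optionally shuffled)
# index list in fixed-size minibatches. This generator, like InputHandle,
# skips any partial final batch.
import random


def _sketch_iter_batches(indices, minibatch_size, shuffle=False, seed=0):
    indices = list(indices)
    if shuffle:
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices) - minibatch_size + 1, minibatch_size):
        yield indices[start:start + minibatch_size]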


class DataProcess:
    def __init__(self, input_param):
        self.paths = input_param['paths']
        self.image_width = input_param['image_width']

        self.input_param = input_param
        self.seq_len = input_param['seq_length']
        self.train_data, self.test_data, _, _, _ = self.load_data(self.paths, len_closeness=input_param['seq_length'])
        self.train_indices = list(range(self.train_data.shape[0]))
        self.test_indices = list(range(self.test_data.shape[0]))

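    def load_data(self, datapath, T=48, nb_flow=2, len_closeness=None, len_test=48 * 7 * 4):
        """Load TaxiBJ (2013-2016) and split it into train and test sequences.

        Each sample is a sequence of `len_closeness` frames. The last
        `len_test` samples (48 * 7 * 4, about four weeks of half-hour slots)
        form the test set. Returns (X_train, X_test, mmn, timestamp_train,
        timestamp_test).
        """
        assert (len_closeness > 0)
        # load data for the years 2013 to 2016 (BJ13 ... BJ16)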
        data_all = []
        timestamps_all = list()
        for year in range(13, 17):
            fname = os.path.join(
                datapath[0], 'BJ{}_M32x32_T30_InOut.h5'.format(year))
            print("file name: ", fname)
            stat(fname)
            data, timestamps = load_stdata(fname)
            # print(timestamps)
            # remove days that do not have all T timestamps
            data, timestamps = remove_incomplete_days(data, timestamps, T)
            data = data[:, :nb_flow]
            data[data < 0] = 0.
            data_all.append(data)
            timestamps_all.append(timestamps)
            print("\n")

        # min-max scaling, fitted on all frames except the last len_test
        data_train = np.vstack(copy(data_all))[:-len_test]
        print('train_data shape: ', data_train.shape)
        mmn = MinMaxNormalization()
        mmn.fit(data_train)
        data_all_mmn = [mmn.transform(d) for d in data_all]

        XC = []
        timestamps_Y = []
        for data, timestamps in zip(data_all_mmn, timestamps_all):
            # instance-based dataset: each sample is a sequence of
            # len_closeness frames; the model predicts the later frames of
            # each sequence from the earlier ones.
            st = STMatrix(data, timestamps, T, CheckComplete=False)
            _XC, _timestamps_Y = st.create_dataset(len_closeness=len_closeness)
            XC.append(_XC)
            timestamps_Y += _timestamps_Y
        XC = np.concatenate(XC, axis=0)
        print("XC shape: ", XC.shape)

        XC_train = XC[:-len_test]
        XC_test = XC[-len_test:]
        timestamp_train, timestamp_test = timestamps_Y[:-len_test], timestamps_Y[-len_test:]

        X_train = XC_train
        X_test = XC_test
        print('train shape:', XC_train.shape,
              'test shape: ', XC_test.shape)

        return X_train, X_test, mmn, timestamp_train, timestamp_test

    def get_train_input_handle(self):
        return InputHandle(self.train_data, self.train_indices, self.input_param)

    def get_test_input_handle(self):
        return InputHandle(self.test_data, self.test_indices, self.input_param)
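# Illustrative sketch (hypothetical helper, not part of the original code):
# load_data holds out the last len_test samples as the test set and keeps the
# rest for training, so the split is chronological rather than random. With
# the default len_test = 48 * 7 * 4 that is 1344 samples.
import numpy as np


def _sketch_tail_split(samples, len_test):
    assert 0 < len_test < len(samples)  # samples[:-0] would be empty
    return samples[:-len_test], samples[-len_test:]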


================================================
FILE: src/layers/MIMBlock.py
================================================
import tensorflow as tf
from src.layers.TensorLayerNorm import tensor_layer_norm
import math


class MIMBlock():
    def __init__(self, layer_name, filter_size, num_hidden_in, num_hidden,
                 seq_shape, tln=False, initializer=None):
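        """Initialize a Memory In Memory (MIM) block.
        Args:
            layer_name: variable-scope name for this layer.
            filter_size: int or int tuple giving the height and width of the filter.
            num_hidden_in: number of hidden channels of the previous layer
                (stored, not used by this block).
            num_hidden: number of units in the output tensor.
            seq_shape: input sequence shape (batch, time, height, width, channels).
            tln: whether to apply tensor layer normalization.
            initializer: uniform init range; None or -1 uses a Glorot-style range.
        """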
        self.layer_name = layer_name
        self.filter_size = filter_size
        self.num_hidden_in = num_hidden_in
        self.num_hidden = num_hidden
        self.convlstm_c = None
        self.batch = seq_shape[0]
        self.height = seq_shape[2]
        self.width = seq_shape[3]
        self.layer_norm = tln
        self._forget_bias = 1.0

        def w_initializer(dim_in, dim_out):
            random_range = math.sqrt(6.0 / (dim_in + dim_out))
            return tf.random_uniform_initializer(-random_range, random_range)
        if initializer is None or initializer == -1:
            self.initializer = w_initializer
        else:
            # wrap the fixed range so it can be called like w_initializer
            self.initializer = lambda dim_in, dim_out: tf.random_uniform_initializer(
                -initializer, initializer)

    def init_state(self):
        return tf.zeros([self.batch, self.height, self.width, self.num_hidden],
                        dtype=tf.float32)

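    def MIMS(self, x, h_t, c_t):
        """Stationary (MIM-S) module.

        A ConvLSTM with element-wise peephole weights. Its "hidden state" is
        the previous temporal memory `h_t`, and its input `x` is the
        differential feature from the matching MIMN layer. Returns the new
        temporal memory and the module's own cell state.
        """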
        if h_t is None:
            h_t = self.init_state()
        if c_t is None:
            c_t = self.init_state()
        with tf.variable_scope(self.layer_name):
            h_concat = tf.layers.conv2d(h_t, self.num_hidden * 4,
                                        self.filter_size, 1, padding='same',
                                        kernel_initializer=self.initializer(self.num_hidden, self.num_hidden * 4),
                                        name='state_to_state')
            if self.layer_norm:
                h_concat = tensor_layer_norm(h_concat, 'state_to_state')
            i_h, g_h, f_h, o_h = tf.split(h_concat, 4, 3)

            ct_weight = tf.get_variable(
                'c_t_weight', [self.height,self.width,self.num_hidden*2])
            ct_activation = tf.multiply(tf.tile(c_t, [1,1,1,2]), ct_weight)
            i_c, f_c = tf.split(ct_activation, 2, 3)

            i_ = i_h + i_c
            f_ = f_h + f_c
            g_ = g_h
            o_ = o_h

            if x is not None:
                x_concat = tf.layers.conv2d(x, self.num_hidden * 4,
                                            self.filter_size, 1,
                                            padding='same',
                                            kernel_initializer=self.initializer(self.num_hidden, self.num_hidden * 4),
                                            name='input_to_state')
                if self.layer_norm:
                    x_concat = tensor_layer_norm(x_concat, 'input_to_state')
                i_x, g_x, f_x, o_x = tf.split(x_concat, 4, 3)

                i_ += i_x
                f_ += f_x
                g_ += g_x
                o_ += o_x

            i_ = tf.nn.sigmoid(i_)
            f_ = tf.nn.sigmoid(f_ + self._forget_bias)
            c_new = f_ * c_t + i_ * tf.nn.tanh(g_)

            oc_weight = tf.get_variable(
                'oc_weight', [self.height,self.width,self.num_hidden])
            o_c = tf.multiply(c_new, oc_weight)

            h_new = tf.nn.sigmoid(o_ + o_c) * tf.nn.tanh(c_new)

            return h_new, c_new

    def __call__(self, x, diff_h, h, c, m):
        if h is None:
            h = self.init_state()
        if c is None:
            c = self.init_state()
        if m is None:
            m = self.init_state()
        if diff_h is None:
            diff_h = tf.zeros_like(h)

        with tf.variable_scope(self.layer_name):
            t_cc = tf.layers.conv2d(
                h, self.num_hidden * 3,
                self.filter_size, 1, padding='same',
                kernel_initializer=self.initializer(self.num_hidden, self.num_hidden * 3),
                name='time_state_to_state')
            s_cc = tf.layers.conv2d(
                m, self.num_hidden * 4,
                self.filter_size, 1, padding='same',
                kernel_initializer=self.initializer(self.num_hidden, self.num_hidden * 4),
                name='spatio_state_to_state')
            x_shape_in = x.get_shape().as_list()[-1]
            x_cc = tf.layers.conv2d(
                x, self.num_hidden * 4,
                self.filter_size, 1, padding='same',
                kernel_initializer=self.initializer(x_shape_in, self.num_hidden * 4),
                name='input_to_state')
            if self.layer_norm:
                t_cc = tensor_layer_norm(t_cc, 'time_state_to_state')
                s_cc = tensor_layer_norm(s_cc, 'spatio_state_to_state')
                x_cc = tensor_layer_norm(x_cc, 'input_to_state')

            i_s, g_s, f_s, o_s = tf.split(s_cc, 4, 3)
            i_t, g_t, o_t = tf.split(t_cc, 3, 3)
            i_x, g_x, f_x, o_x = tf.split(x_cc, 4, 3)

            i = tf.nn.sigmoid(i_x + i_t)
            i_ = tf.nn.sigmoid(i_x + i_s)
            g = tf.nn.tanh(g_x + g_t)
            g_ = tf.nn.tanh(g_x + g_s)
            f_ = tf.nn.sigmoid(f_x + f_s + self._forget_bias)
            o = tf.nn.sigmoid(o_x + o_t + o_s)
            new_m = f_ * m + i_ * g_
            c, self.convlstm_c = self.MIMS(diff_h, c, self.convlstm_c)
            new_c = c + i * g
            cell = tf.concat([new_c, new_m], 3)
            cell = tf.layers.conv2d(cell, self.num_hidden, 1, 1,
                                    padding='same', name='cell_reduce')
            new_h = o * tf.nn.tanh(cell)

            return new_h, new_c, new_m
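# Illustrative sketch (hypothetical helper, not part of the original code):
# a NumPy version of the gate fusion at the end of MIMBlock.__call__. It
# assumes the convolutions have already been applied and split into per-gate
# pre-activations of shape (batch, height, width, num_hidden), stored in the
# dict `pre`. `c_mims` stands for the output of the MIMS module. The 1x1
# 'cell_reduce' convolution is written as a per-pixel matrix multiply over
# channels (w_reduce: (2 * num_hidden, num_hidden)), with its bias omitted.
import numpy as np


def _sketch_mim_block_update(pre, m, c_mims, w_reduce, forget_bias=1.0):
    def sig(z):
        return 1.0 / (1.0 + np.exp(-z))

    i = sig(pre['i_x'] + pre['i_t'])
    i_ = sig(pre['i_x'] + pre['i_s'])
    g = np.tanh(pre['g_x'] + pre['g_t'])
    g_ = np.tanh(pre['g_x'] + pre['g_s'])
    f_ = sig(pre['f_x'] + pre['f_s'] + forget_bias)
    o = sig(pre['o_x'] + pre['o_t'] + pre['o_s'])
    new_m = f_ * m + i_ * g_      # spatiotemporal memory
    new_c = c_mims + i * g        # temporal memory: MIMS replaces the forget gate
    cell = np.concatenate([new_c, new_m], axis=-1)
    new_h = o * np.tanh(cell @ w_reduce)
    return new_h, new_c, new_m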


================================================
FILE: src/layers/MIMN.py
================================================
import tensorflow as tf
from src.layers.TensorLayerNorm import tensor_layer_norm

class MIMN():
    def __init__(self, layer_name, filter_size, num_hidden, seq_shape, tln=True, initializer=0.001):
        """Initialize a MIMN (non-stationary) layer.
        Args:
            layer_name: variable-scope name for this layer.
            filter_size: int or int tuple giving the height and width of the filter.
            num_hidden: number of units in the output tensor.
            seq_shape: input sequence shape (batch, time, height, width, channels).
            tln: whether to apply tensor layer normalization.
            initializer: uniform init range; -1 or None uses TensorFlow's
                default kernel initializer.
        """
        self.layer_name = layer_name
        self.filter_size = filter_size
        self.num_hidden = num_hidden
        self.layer_norm = tln
        self.batch = seq_shape[0]
        self.height = seq_shape[2]
        self.width = seq_shape[3]
        self._forget_bias = 1.0
        if initializer is None or initializer == -1:
            self.initializer = None
        else:
            self.initializer = tf.random_uniform_initializer(-initializer, initializer)

    def init_state(self):
        shape = [self.batch, self.height, self.width, self.num_hidden]
        return tf.zeros(shape, dtype=tf.float32)

    def __call__(self, x, h_t, c_t):
        if h_t is None:
            h_t = self.init_state()
        if c_t is None:
            c_t = self.init_state()
        with tf.variable_scope(self.layer_name):
            h_concat = tf.layers.conv2d(h_t, self.num_hidden * 4,
                                        self.filter_size, 1, padding='same',
                                        kernel_initializer=self.initializer,
                                        name='state_to_state')
            if self.layer_norm:
                h_concat = tensor_layer_norm(h_concat, 'state_to_state')
            i_h, g_h, f_h, o_h = tf.split(h_concat, 4, 3)

            ct_weight = tf.get_variable(
                'c_t_weight', [self.height,self.width,self.num_hidden*2])
            ct_activation = tf.multiply(tf.tile(c_t, [1,1,1,2]), ct_weight)
            i_c, f_c = tf.split(ct_activation, 2, 3)

            i_ = i_h + i_c
            f_ = f_h + f_c
            g_ = g_h
            o_ = o_h

            if x is not None:
                x_concat = tf.layers.conv2d(x, self.num_hidden * 4,
                                            self.filter_size, 1,
                                            padding='same',
                                            kernel_initializer=self.initializer,
                                            name='input_to_state')
                if self.layer_norm:
                    x_concat = tensor_layer_norm(x_concat, 'input_to_state')
                i_x, g_x, f_x, o_x = tf.split(x_concat, 4, 3)

                i_ += i_x
                f_ += f_x
                g_ += g_x
                o_ += o_x

            i_ = tf.nn.sigmoid(i_)
            f_ = tf.nn.sigmoid(f_ + self._forget_bias)
            c_new = f_ * c_t + i_ * tf.nn.tanh(g_)

            oc_weight = tf.get_variable(
                'oc_weight', [self.height,self.width,self.num_hidden])
            o_c = tf.multiply(c_new, oc_weight)

            h_new = tf.nn.sigmoid(o_ + o_c) * tf.nn.tanh(c_new)

            return h_new, c_new
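# Illustrative sketch (hypothetical helper, not part of the original code):
# a NumPy version of the MIMN state update. It assumes the 'state_to_state'
# and 'input_to_state' convolutions have already been applied and summed into
# per-gate pre-activations. The cell state enters the input and forget gates,
# and the new cell state enters the output gate, through element-wise weights
# of shape (height, width, channels), one weight per position and channel.
import numpy as np


def _sketch_mimn_update(i_pre, g_pre, f_pre, o_pre, c_t, w_ic, w_fc, w_oc,
                        forget_bias=1.0):
    def sig(z):
        return 1.0 / (1.0 + np.exp(-z))

    i = sig(i_pre + c_t * w_ic)
    f = sig(f_pre + c_t * w_fc + forget_bias)
    c_new = f * c_t + i * np.tanh(g_pre)
    h_new = sig(o_pre + c_new * w_oc) * np.tanh(c_new)
    return h_new, c_new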



================================================
FILE: src/layers/SpatioTemporalLSTMCellv2.py
================================================
import math

import tensorflow as tf
from src.layers.TensorLayerNorm import tensor_layer_norm

class SpatioTemporalLSTMCell():
    def __init__(self, layer_name, filter_size, num_hidden_in, num_hidden,
                 seq_shape, tln=False, initializer=None):
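        """Initialize a spatiotemporal LSTM (ST-LSTM) cell.
        Args:
            layer_name: variable-scope name for this layer.
            filter_size: int or int tuple giving the height and width of the filter.
            num_hidden_in: channel count used to size the initializer of the
                state-to-state convolutions.
            num_hidden: number of units in the output tensor.
            seq_shape: input sequence shape (batch, time, height, width, channels).
            tln: whether to apply tensor layer normalization.
            initializer: uniform init range; None or -1 uses a Glorot-style range.
        """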
        self.layer_name = layer_name
        self.filter_size = filter_size
        self.num_hidden_in = num_hidden_in
        self.num_hidden = num_hidden
        self.batch = seq_shape[0]
        self.height = seq_shape[2]
        self.width = seq_shape[3]
        self.layer_norm = tln
        self._forget_bias = 1.0

        def w_initializer(dim_in, dim_out):
            random_range = math.sqrt(6.0 / (dim_in + dim_out))
            return tf.random_uniform_initializer(-random_range, random_range)
        if initializer is None or initializer == -1:
            self.initializer = w_initializer
        else:
            # wrap the fixed range so it can be called like w_initializer
            self.initializer = lambda dim_in, dim_out: tf.random_uniform_initializer(
                -initializer, initializer)

    def init_state(self):
        return tf.zeros([self.batch, self.height, self.width, self.num_hidden],
                        dtype=tf.float32)

    def __call__(self, x, h, c, m):
        if h is None:
            h = self.init_state()
        if c is None:
            c = self.init_state()
        if m is None:
            m = self.init_state()

        with tf.variable_scope(self.layer_name):
            t_cc = tf.layers.conv2d(
                h, self.num_hidden*4,
                self.filter_size, 1, padding='same',
                kernel_initializer=self.initializer(self.num_hidden_in, self.num_hidden*4),
                name='time_state_to_state')
            s_cc = tf.layers.conv2d(
                m, self.num_hidden*4,
                self.filter_size, 1, padding='same',
                kernel_initializer=self.initializer(self.num_hidden_in, self.num_hidden*4),
                name='spatio_state_to_state')
            x_shape_in = x.get_shape().as_list()[-1]
            x_cc = tf.layers.conv2d(
                x, self.num_hidden*4,
                self.filter_size, 1, padding='same',
                kernel_initializer=self.initializer(x_shape_in, self.num_hidden*4),
                name='input_to_state')
            if self.layer_norm:
                t_cc = tensor_layer_norm(t_cc, 'time_state_to_state')
                s_cc = tensor_layer_norm(s_cc, 'spatio_state_to_state')
                x_cc = tensor_layer_norm(x_cc, 'input_to_state')

            i_s, g_s, f_s, o_s = tf.split(s_cc, 4, 3)
            i_t, g_t, f_t, o_t = tf.split(t_cc, 4, 3)
            i_x, g_x, f_x, o_x = tf.split(x_cc, 4, 3)

            i = tf.nn.sigmoid(i_x + i_t)
            i_ = tf.nn.sigmoid(i_x + i_s)
            g = tf.nn.tanh(g_x + g_t)
            g_ = tf.nn.tanh(g_x + g_s)
            f = tf.nn.sigmoid(f_x + f_t + self._forget_bias)
            f_ = tf.nn.sigmoid(f_x + f_s + self._forget_bias)
            o = tf.nn.sigmoid(o_x + o_t + o_s)
            new_m = f_ * m + i_ * g_
            new_c = f * c + i * g
            cell = tf.concat([new_c, new_m],3)
            cell = tf.layers.conv2d(cell, self.num_hidden, 1, 1, padding='same',
                                    kernel_initializer=self.initializer(self.num_hidden*2, self.num_hidden),
                                    name='cell_reduce')
            new_h = o * tf.nn.tanh(cell)

            return new_h, new_c, new_m
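# Illustrative sketch (hypothetical helper, not part of the original code):
# w_initializer draws weights from U(-r, r) with r = sqrt(6 / (dim_in +
# dim_out)), the Glorot/Xavier uniform range. The cells pass channel counts
# as dim_in and dim_out, so unlike Keras' glorot_uniform the kernel area
# (filter_size ** 2) is not part of the fan computation; `kernel_area` here
# shows the difference.
import math


def _sketch_uniform_range(dim_in, dim_out, kernel_area=1):
    return math.sqrt(6.0 / (kernel_area * (dim_in + dim_out)))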



================================================
FILE: src/layers/TensorLayerNorm.py
================================================
import tensorflow as tf

EPSILON = 0.00001


def tensor_layer_norm(x, state_name):
    x_shape = x.get_shape()
    dims = x_shape.ndims
    params_shape = x_shape[-1:]
    if dims == 4:
        m, v = tf.nn.moments(x, [1,2,3], keep_dims=True)
    elif dims == 5:
        m, v = tf.nn.moments(x, [1,2,3,4], keep_dims=True)
    elif dims == 2:
        m, v = tf.nn.moments(x, [1], keep_dims=True)
    else:
        raise ValueError('input tensor for layer normalization must be rank 2, 4 or 5.')
    b = tf.get_variable(state_name+'b',initializer=tf.zeros(params_shape))
    s = tf.get_variable(state_name+'s',initializer=tf.ones(params_shape))
    x_tln = tf.nn.batch_normalization(x, m, v, b, s, EPSILON)
    return x_tln
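# Illustrative sketch (hypothetical helper, not part of the original code):
# tensor_layer_norm computes one mean and (population) variance per example
# over all non-batch axes. It then applies a per-channel scale `s` and offset
# `b`, initialised to 1 and 0. A NumPy equivalent with the same epsilon:
import numpy as np


def _sketch_tensor_layer_norm(x, scale=None, offset=None, eps=1e-5):
    axes = tuple(range(1, x.ndim))
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)  # ddof=0, like tf.nn.moments
    if scale is None:
        scale = np.ones(x.shape[-1])
    if offset is None:
        offset = np.zeros(x.shape[-1])
    return (x - mean) / np.sqrt(var + eps) * scale + offset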


================================================
FILE: src/layers/__init__.py
================================================


================================================
FILE: src/models/__init__.py
================================================


================================================
FILE: src/models/mim.py
================================================
__author__ = 'jianjin'

import tensorflow as tf
from src.layers.SpatioTemporalLSTMCellv2 import SpatioTemporalLSTMCell as stlstm
from src.layers.MIMBlock import MIMBlock as mimblock
from src.layers.MIMN import MIMN as mimn
import math


def w_initializer(dim_in, dim_out):
    random_range = math.sqrt(6.0 / (dim_in + dim_out))
    return tf.random_uniform_initializer(-random_range, random_range)


def mim(images, params, schedual_sampling_bool, num_layers, num_hidden, filter_size,
        stride=1, total_length=20, input_length=10, tln=True):
    gen_images = []
    stlstm_layer = []
    stlstm_layer_diff = []
    cell_state = []
    hidden_state = []
    cell_state_diff = []
    hidden_state_diff = []
    shape = images.get_shape().as_list()
    output_channels = shape[-1]

    for i in range(num_layers):
        if i == 0:
            num_hidden_in = num_hidden[num_layers - 1]
        else:
            num_hidden_in = num_hidden[i - 1]
        if i == 0:  # the first layer is an ST-LSTM; later layers are MIM blocks
            new_stlstm_layer = stlstm('stlstm_' + str(i + 1),
                                      filter_size,
                                      num_hidden_in,
                                      num_hidden[i],
                                      shape,
                                      tln=tln)
        else:
            new_stlstm_layer = mimblock('stlstm_' + str(i + 1),
                                        filter_size,
                                        num_hidden_in,
                                        num_hidden[i],
                                        shape,
                                        tln=tln)
        stlstm_layer.append(new_stlstm_layer)
        cell_state.append(None)
        hidden_state.append(None)

    # one MIMN layer per MIM block; it models differences of hidden states
    for i in range(num_layers - 1):
        new_stlstm_layer = mimn('stlstm_diff' + str(i + 1),
                                filter_size,
                                num_hidden[i + 1],
                                shape,
                                tln=tln)
        stlstm_layer_diff.append(new_stlstm_layer)
        cell_state_diff.append(None)
        hidden_state_diff.append(None)

    st_memory = None

    for time_step in range(total_length - 1):
        reuse = bool(gen_images)
        with tf.variable_scope('predrnn', reuse=reuse):
            if time_step < input_length:
                x_gen = images[:,time_step]
            else:
                x_gen = schedual_sampling_bool[:,time_step-input_length]*images[:,time_step] + \
                        (1-schedual_sampling_bool[:,time_step-input_length])*x_gen
            preh = hidden_state[0]
            hidden_state[0], cell_state[0], st_memory = stlstm_layer[0](
                x_gen, hidden_state[0], cell_state[0], st_memory)
            for i in range(1, num_layers):
                if time_step > 0:
                    if i == 1:
                        hidden_state_diff[i - 1], cell_state_diff[i - 1] = stlstm_layer_diff[i - 1](
                            hidden_state[i - 1] - preh, hidden_state_diff[i - 1], cell_state_diff[i - 1])
                    else:
                        hidden_state_diff[i - 1], cell_state_diff[i - 1] = stlstm_layer_diff[i - 1](
                            hidden_state_diff[i - 2], hidden_state_diff[i - 1], cell_state_diff[i - 1])
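                else:
                    # At t = 0 there is no previous hidden state to difference.
                    # Call the MIMN layer once so its variables exist before
                    # reuse=True, and leave hidden_state_diff as None (zeros).
                    stlstm_layer_diff[i - 1](tf.zeros_like(hidden_state[i - 1]), None, None)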
                preh = hidden_state[i]
                hidden_state[i], cell_state[i], st_memory = stlstm_layer[i](
                    hidden_state[i - 1], hidden_state_diff[i - 1], hidden_state[i], cell_state[i], st_memory)
            x_gen = tf.layers.conv2d(hidden_state[num_layers - 1],
                                     filters=output_channels,
                                     kernel_size=1,
                                     strides=1,
                                     padding='same',
                                     kernel_initializer=w_initializer(num_hidden[num_layers - 1], output_channels),
                                     name="back_to_pixel")
            gen_images.append(x_gen)

    gen_images = tf.stack(gen_images, axis=1)
    loss = tf.nn.l2_loss(gen_images - images[:, 1:])

    return [gen_images, loss]
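# Illustrative sketch (hypothetical helper, not part of the original code):
# after the first input_length steps, mim() feeds
#     x = mask * ground_truth + (1 - mask) * previous_prediction
# where schedual_sampling_bool supplies the 0/1 mask for each predicted step.
# The loop runs total_length - 1 steps, and each output is compared with the
# next ground-truth frame, hence the images[:, 1:] target in the loss.
import numpy as np


def _sketch_next_input(images, prev_pred, mask, time_step, input_length):
    if time_step < input_length:
        return images[:, time_step]
    m = mask[:, time_step - input_length]
    return m * images[:, time_step] + (1 - m) * prev_pred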


================================================
FILE: src/models/model_factory.py
================================================
import os

import tensorflow as tf

from src.utils import optimizer
from src.models import mim
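

# Illustrative sketch (hypothetical helper, not part of the original code):
# the input placeholders hold patch-reshaped frames. Each p x p patch of a
# frame with C channels is folded into the channel axis, so a frame of size
# (height, width, C) becomes (height // p, width // p, p * p * C).
def _sketch_patched_shape(batch_size, total_length, img_height, img_width,
                          img_channel, patch_size):
    p = patch_size
    return [batch_size, total_length, img_height // p, img_width // p,
            p * p * img_channel]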


class Model(object):
    def __init__(self, configs):
        self.configs = configs
        # inputs
        if configs.img_height > 0:
            height = configs.img_height
        else:
            height = configs.img_width
        self.x = [tf.placeholder(tf.float32,
                                 [self.configs.batch_size,
                                  self.configs.total_length,
                                  self.configs.img_width // self.configs.patch_size,
                                  height // self.configs.patch_size,
                                  self.configs.patch_size * self.configs.patch_size * self.configs.img_channel])
                  for i in range(self.configs.n_gpu)]

        self.real_input_flag = tf.placeholder(tf.float32,
                                        [self.configs.batch_size,
                                         self.configs.total_length - self.configs.input_length - 1,
                                         self.configs.img_width // self.configs.patch_size,
                                         height // self.configs.patch_size,
                                         self.configs.patch_size * self.configs.patch_size * self.configs.img_channel])

        grads = []
        loss_train = []
        self.pred_seq = []
        self.tf_lr = tf.placeholder(tf.float32, shape=[])
        self.params = dict()
        self.params.update(self.configs.__dict__['__flags'])
        num_hidden = [int(x) for x in self.configs.num_hidden.split(',')]
        num_layers = len(num_hidden)
        for i in range(self.configs.n_gpu):
            with tf.device('/gpu:%d' % i):
                with tf.variable_scope(tf.get_variable_scope(),
                                       reuse=True if i > 0 else None):
                    # define a model
                    output_list = self.construct_model(
                        self.configs.model_name,
                        self.x[i],
                        self.params,
                        self.real_input_flag,
                        num_layers,
                        num_hidden,
                        self.configs.filter_size,
                        self.configs.stride,
                        self.configs.total_length,
                        self.configs.input_length,
                        self.configs.layer_norm)

                    gen_ims = output_list[0]
                    loss = output_list[1]
                    if len(output_list) > 2:
                        self.debug = output_list[2]
                    else:
                        self.debug = []
                    pred_ims = gen_ims[:, self.configs.input_length - self.configs.total_length:]
                    loss_train.append(loss / self.configs.batch_size)
                    # gradients
                    all_params = tf.trainable_variables()
                    grads.append(tf.gradients(loss, all_params))
                    self.pred_seq.append(pred_ims)

        # add losses and gradients together and get training updates
        with tf.device('/gpu:0'):
            for i in range(1, self.configs.n_gpu):
                loss_train[0] += loss_train[i]
                for j in range(len(grads[0])):
                    grads[0][j] += grads[i][j]
        # keep track of moving average
        ema = tf.train.ExponentialMovingAverage(decay=0.9995)
        maintain_averages_op = tf.group(ema.apply(all_params))
        self.train_op = tf.group(optimizer.adam_updates(
            all_params, grads[0], lr=self.tf_lr, mom1=0.95, mom2=0.9995),
            maintain_averages_op)

        self.loss_train = loss_train[0] / self.configs.n_gpu

        # session
        variables = tf.global_variables()
        self.saver = tf.train.Saver(variables)
        init = tf.global_variables_initializer()
        configProt = tf.ConfigProto()
        configProt.gpu_options.allow_growth = configs.allow_gpu_growth
        configProt.allow_soft_placement = True
        self.sess = tf.Session(config=configProt)
        self.sess.run(init)
        if self.configs.pretrained_model:
            self.saver.restore(self.sess, self.configs.pretrained_model)

    def train(self, inputs, lr, real_input_flag):
        feed_dict = {self.x[i]: inputs[i] for i in range(self.configs.n_gpu)}
        feed_dict.update({self.tf_lr: lr})
        feed_dict.update({self.real_input_flag: real_input_flag})
        loss, _, debug = self.sess.run((self.loss_train, self.train_op, self.debug), feed_dict)
        return loss

    def test(self, inputs, real_input_flag):
        feed_dict = {self.x[i]: inputs[i] for i in range(self.configs.n_gpu)}
        feed_dict.update({self.real_input_flag: real_input_flag})
        gen_ims, debug = self.sess.run((self.pred_seq, self.debug), feed_dict)
        return gen_ims, debug

    def save(self, itr):
        checkpoint_path = os.path.join(self.configs.save_dir, 'model.ckpt')
        self.saver.save(self.sess, checkpoint_path, global_step=itr)
        print('saved to ' + self.configs.save_dir)

    def load(self, checkpoint_path):
        print('load model:', checkpoint_path)
        self.saver.restore(self.sess, checkpoint_path)

    def construct_model(self, name, images, model_params, real_input_flag, num_layers, num_hidden,
                        filter_size, stride, total_length, input_length, tln):
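        '''Builds the network `name` and returns its outputs.
        Args:
            name: network name; must be a key of `networks_map` (currently only 'mim').
            images: input frame sequence for one GPU tower.
            model_params: dict of extra parameters merged into `params`.
            real_input_flag: mask for scheduled sampling.
            num_layers: number of recurrent layers.
            num_hidden: list with the number of hidden channels of each layer.
            filter_size: kernel size of the convolutions inside each recurrent cell.
            stride: stride of the convolutions inside each recurrent cell.
            total_length: sequence length, including input and output frames.
            input_length: number of input frames.
            tln: whether to apply tensor layer normalization.
        Returns:
            A list whose first element is the sequence of generated frames and whose
            second element is the loss (l2 or l1+l2); an optional third element
            holds debug tensors.
        Raises:
            ValueError: If network `name` is not recognized.
        '''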

        networks_map = {
            'mim': mim.mim,
        }

        params = dict(mask=real_input_flag, num_layers=num_layers, num_hidden=num_hidden, filter_size=filter_size,
                      stride=stride, total_length=total_length, input_length=input_length, is_training=True)
        params.update(model_params)
        if name in networks_map:
            func = networks_map[name]
            return func(images, params, real_input_flag, num_layers, num_hidden, filter_size,
                        stride, total_length, input_length, tln)
        else:
            raise ValueError('Name of network unknown %s' % name)
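The multi-GPU block in `Model.__init__` adds each tower's gradients into `grads[0]` instead of averaging them. A minimal NumPy sketch, assuming each tower's loss is a sum over its own examples (a linear least-squares model stands in for the network, and `sq_loss_grad` is a hypothetical helper), shows why that sum equals the gradient of the loss over the full, unsplit batch:

```python
import numpy as np


def sq_loss_grad(X, y, w):
    """Gradient w.r.t. w of the summed squared error sum((X @ w - y) ** 2)."""
    return 2.0 * X.T @ (X @ w - y)


rng = np.random.default_rng(0)
X = rng.normal(size=(8, 3))  # batch of 8 examples, 3 features
y = rng.normal(size=8)
w = rng.normal(size=3)

n_gpu = 2
# Split along the batch axis, as trainer.train does with np.split(ims, configs.n_gpu).
towers = zip(np.split(X, n_gpu), np.split(y, n_gpu))
tower_grads = [sq_loss_grad(Xi, yi, w) for Xi, yi in towers]

# Accumulate into the first tower's gradient, like grads[0][j] += grads[i][j].
summed = tower_grads[0].copy()
for g in tower_grads[1:]:
    summed += g

full = sq_loss_grad(X, y, w)  # gradient over the unsplit batch
```

Because the update then goes through `adam_updates`, a constant rescaling of the gradient (summing rather than averaging over `n_gpu` towers) largely cancels in the ratio of the first moment to the square root of the second moment; only the small epsilon term is affected.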

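`tf.train.ExponentialMovingAverage(decay=0.9995)` keeps a shadow copy of every trainable variable; without `num_updates`, each run of the apply op moves the shadow toward the current value by a fixed fraction `1 - decay`. The sketch below restates that rule in NumPy. `ParamEMA` is a hypothetical helper, not part of this repository or of TensorFlow. In the code above, `test()` evaluates the raw variables; the shadow values are only written to checkpoints by the `Saver`, which is built from `tf.global_variables()`.

```python
import numpy as np


class ParamEMA:
    """Hypothetical helper: fixed-decay moving average of named parameters."""

    def __init__(self, params, decay=0.9995):
        self.decay = decay
        # Shadow copies start from the parameters' current values.
        self.shadow = {k: np.array(v, dtype=np.float64) for k, v in params.items()}

    def update(self, params):
        for k, v in params.items():
            # shadow <- shadow - (1 - decay) * (shadow - value)
            self.shadow[k] -= (1.0 - self.decay) * (self.shadow[k] - v)


ema = ParamEMA({'w': np.zeros(3)}, decay=0.9995)
for _ in range(1000):
    ema.update({'w': np.ones(3)})
# Tracking a constant 1.0 from 0.0, the shadow after n updates is 1 - decay ** n.
```

With `decay=0.9995` the shadow needs on the order of a few thousand updates to forget its starting point, which is why it smooths out iteration-to-iteration noise in the weights.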

================================================
FILE: src/trainer.py
================================================
import os.path
import datetime
import cv2
import numpy as np
from skimage.measure import compare_ssim
from src.utils import metrics
from src.utils import preprocess


def train(model, ims, real_input_flag, configs, itr, ims_reverse=None):
    ims = ims[:, :configs.total_length]
    ims_list = np.split(ims, configs.n_gpu)
    cost = model.train(ims_list, configs.lr, real_input_flag)

    flag = 1

    if configs.reverse_img:
        ims_rev = np.split(ims_reverse, configs.n_gpu)
        cost += model.train(ims_rev, configs.lr, real_input_flag)
        flag += 1

    if configs.reverse_input:
        ims_rev = np.split(ims[:, ::-1], configs.n_gpu)
        cost += model.train(ims_rev, configs.lr, real_input_flag)
        flag += 1
        if configs.reverse_img:
            ims_rev = np.split(ims_reverse[:, ::-1], configs.n_gpu)
            cost += model.train(ims_rev, configs.lr, real_input_flag)
            flag += 1

    cost = cost / flag

    if itr % configs.display_interval == 0:
        print(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'itr: ' + str(itr))
        print('training loss: ' + str(cost))


def test(model, test_input_handle, configs, save_name):
    print(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'test...')
    test_input_handle.begin(do_shuffle=False)
    res_path = os.path.join(configs.gen_frm_dir, str(save_name))
    os.mkdir(res_path)
    avg_mse = 0
    batch_id = 0
    img_mse, ssim, psnr, fmae, sharp = [], [], [], [], []

    for i in range(configs.total_length - configs.input_length):
        img_mse.append(0)
        ssim.append(0)
        psnr.append(0)
        fmae.append(0)
        sharp.append(0)

    if configs.img_height > 0:
        height = configs.img_height
    else:
        height = configs.img_width

    real_input_flag = np.zeros(
        (configs.batch_size,
         configs.total_length - configs.input_length - 1,
         configs.img_width // configs.patch_size,
         height // configs.patch_size,
         configs.patch_size ** 2 * configs.img_channel))

    while not test_input_handle.no_batch_left():
        batch_id = batch_id + 1
        if save_name != 'test_result':
            if batch_id > 100: break
        test_ims = test_input_handle.get_batch()
        test_ims = test_ims[:, :configs.total_length]
        if len(test_ims.shape) > 3:
            test_dat = preprocess.reshape_patch(test_ims, configs.patch_size)
        else:
            test_dat = test_ims
        test_dat = np.split(test_dat, configs.n_gpu)
        img_gen, debug = model.test(test_dat, real_input_flag)

        # concat outputs of different gpus along batch
        img_gen = np.concatenate(img_gen)
        if len(img_gen.shape) > 3:
            img_gen = preprocess.reshape_patch_back(img_gen, configs.patch_size)
        # MSE per frame
        for i in range(configs.total_length - configs.input_length):
            x = test_ims[:, i + configs.input_length, :, :, :]
            x = x[:configs.batch_size * configs.n_gpu]
            x = x - np.where(x > 10000, np.floor_divide(x, 10000) * 10000, np.zeros_like(x))
            gx = img_gen[:, i, :, :, :]
            fmae[i] += metrics.batch_mae_frame_float(gx, x)
            gx = np.maximum(gx, 0)
            gx = np.minimum(gx, 1)
            mse = np.square(x - gx).sum()
            img_mse[i] += mse
            avg_mse += mse
            real_frm = np.uint8(x * 255)
            pred_frm = np.uint8(gx * 255)
            psnr[i] += metrics.batch_psnr(pred_frm, real_frm)
            for b in range(configs.batch_size):
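                # Sharpness: max absolute Laplacian response (computed in 16-bit, then saturated to uint8).
                sharp[i] += np.max(
                    cv2.convertScaleAbs(cv2.Laplacian(pred_frm[b], cv2.CV_16S)))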

                score, _ = compare_ssim(gx[b], x[b], full=True, multichannel=True)
                ssim[i] += score

        # save prediction examples
        if batch_id <= configs.num_save_samples:
            path = os.path.join(res_path, str(batch_id))
            os.mkdir(path)
            if len(debug) != 0:
                np.save(os.path.join(path, "f.npy"), debug)
            for i in range(configs.total_length):
                name = 'gt' + str(i + 1) + '.png'
                file_name = os.path.join(path, name)
                img_gt = np.uint8(test_ims[0, i, :, :, :] * 255)
                if configs.img_channel == 2:
                    img_gt = img_gt[:, :, :1]
                cv2.imwrite(file_name, img_gt)
            for i in range(configs.total_length - configs.input_length):
                name = 'pd' + str(i + 1 + configs.input_length) + '.png'
                file_name = os.path.join(path, name)
                img_pd = img_gen[0, i, :, :, :]
                if configs.img_channel == 2:
                    img_pd = img_pd[:, :, :1]
                img_pd = np.maximum(img_pd, 0)
                img_pd = np.minimum(img_pd, 1)
                img_pd = np.uint8(img_pd * 255)
                cv2.imwrite(file_name, img_pd)
        test_input_handle.next()

    avg_mse = avg_mse / (batch_id * configs.batch_size * configs.n_gpu)
    print('mse per seq: ' + str(avg_mse))
    for i in range(configs.total_length - configs.input_length):
        print(img_mse[i] / (batch_id * configs.batch_size * configs.n_gpu))

    psnr = np.asarray(psnr, dtype=np.float32) / batch_id
    fmae = np.asarray(fmae, dtype=np.float32) / batch_id
    ssim = np.asarray(ssim, dtype=np.float32) / (configs.batch_size * batch_id)
    sharp = np.asarray(sharp, dtype=np.float32) / (configs.batch_size * batch_id)

    print('psnr per frame: ' + str(np.mean(psnr)))
    for i in range(configs.total_length - configs.input_length):
        print(psnr[i])
    print('fmae per frame: ' + str(np.mean(fmae)))
    for i in range(configs.total_length - configs.input_length):
        print(fmae[i])
    print('ssim per frame: ' + str(np.mean(ssim)))
    for i in range(configs.total_length - configs.input_length):
        print(ssim[i])
    print('sharpness per frame: ' + str(np.mean(sharp)))
    for i in range(configs.total_length - configs.input_length):
        print(sharp[i])
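`trainer.test` feeds an all-zero `real_input_flag`, so during evaluation the model never sees ground-truth frames after the input window. The blending itself happens inside `src/models/mim.py`, which is not shown here; the sketch below assumes the common PredRNN-style rule in which a flag of 1 selects the true frame and 0 selects the model's previous prediction. `mix_input` and `sample_mask` are hypothetical names, and drawing one flag per sequence with probability `eta` is an assumption about how training masks are built.

```python
import numpy as np


def mix_input(mask_t, true_frame, pred_frame):
    """Flag 1 selects the ground-truth frame, flag 0 the model's own prediction."""
    return mask_t * true_frame + (1.0 - mask_t) * pred_frame


def sample_mask(rng, batch, frame_shape, eta):
    """Hypothetical training mask: one Bernoulli(eta) draw per sequence, broadcast over the frame."""
    keep = (rng.random(batch) < eta).astype(np.float32)
    shape = (batch,) + tuple(frame_shape)
    return np.broadcast_to(keep.reshape((batch,) + (1,) * len(frame_shape)), shape)


rng = np.random.default_rng(0)
frame_shape = (16, 16, 16)  # e.g. (W/p, H/p, p*p*C) after reshape_patch
true_frame = np.ones((4,) + frame_shape, dtype=np.float32)
pred_frame = np.zeros((4,) + frame_shape, dtype=np.float32)

# Evaluation: trainer.test uses an all-zero flag, so only predictions are fed back.
eval_input = mix_input(np.zeros_like(true_frame), true_frame, pred_frame)
# Training with eta = 1.0: every sequence is teacher-forced with ground truth.
teacher_input = mix_input(sample_mask(rng, 4, frame_shape, eta=1.0), true_frame, pred_frame)
```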

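The sharpness score in `test` is the largest absolute Laplacian response of each predicted uint8 frame. As a dependency-free approximation, assuming OpenCV's defaults for `cv2.Laplacian` (aperture size 1, i.e. the 4-neighbour kernel, with reflect-101 borders) and the uint8 saturation of `cv2.convertScaleAbs`, the hypothetical `laplacian_sharpness` below computes the same quantity in NumPy:

```python
import numpy as np


def laplacian_sharpness(frame):
    """Hypothetical NumPy stand-in for max(convertScaleAbs(Laplacian(frame))) on a uint8 frame."""
    f = np.asarray(frame, dtype=np.int32)
    if f.ndim == 2:
        f = f[:, :, None]  # treat a grayscale frame as one channel
    # Mirror padding without repeating the edge pixel (OpenCV's BORDER_REFLECT_101).
    p = np.pad(f, ((1, 1), (1, 1), (0, 0)), mode='reflect')
    # 4-neighbour Laplacian kernel [[0, 1, 0], [1, -4, 1], [0, 1, 0]], applied per channel.
    lap = (p[:-2, 1:-1] + p[2:, 1:-1] + p[1:-1, :-2] + p[1:-1, 2:]
           - 4 * p[1:-1, 1:-1])
    # convertScaleAbs takes |value| and saturates it to the uint8 range [0, 255].
    return int(np.clip(np.abs(lap), 0, 255).max())


img = np.zeros((5, 5), dtype=np.uint8)
img[2, 2] = 10
score = laplacian_sharpness(img)  # 40: |-4 * 10| at the bright pixel
```

Because the response saturates at 255, any frame containing a single strong edge already scores near the maximum, so this metric mainly distinguishes blurry predictions from ones with at least some crisp detail.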

================================================
FILE: src/utils/__init__.py
================================================


================================================
FILE: src/utils/metrics.py
================================================
__author__ = 'yunbo'

import numpy as np
from scipy.signal import convolve2d


def batch_mae_frame_float(gen_frames, gt_frames):
    # [batch, width, height] or [batch, width, height, channel]
    if gen_frames.ndim == 3:
        axis = (1, 2)
    elif gen_frames.ndim == 4:
        axis = (1, 2, 3)
    else:
        raise ValueError('expected 3-D or 4-D frames, got ndim=%d' % gen_frames.ndim)
    x = np.float32(gen_frames)
    y = np.float32(gt_frames)
    mae = np.sum(np.absolute(x - y), axis=axis, dtype=np.float32)
    return np.mean(mae)


def batch_psnr(gen_frames, gt_frames):
    # [batch, width, height] or [batch, width, height, channel]
    if gen_frames.ndim == 3:
        axis = (1, 2)
    elif gen_frames.ndim == 4:
        axis = (1, 2, 3)
    else:
        raise ValueError('expected 3-D or 4-D frames, got ndim=%d' % gen_frames.ndim)
    x = np.int32(gen_frames)
    y = np.int32(gt_frames)
    num_pixels = float(np.size(gen_frames[0]))
    mse = np.sum((x - y) ** 2, axis=axis, dtype=np.float32) / num_pixels
    psnr = 20 * np.log10(255) - 10 * np.log10(mse)
    return np.mean(psnr)
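`batch_psnr` computes, per frame, PSNR = 20·log10(255) - 10·log10(MSE) on 8-bit values and then averages over the batch. As a worked example, if every pixel is off by exactly one gray level, MSE = 1 and PSNR = 20·log10(255) ≈ 48.13 dB. The NumPy sketch below (`psnr_uint8` is a hypothetical stand-in that mirrors `batch_psnr`) checks that number. Identical frames give MSE = 0 and therefore an infinite PSNR, which would propagate into the per-frame average.

```python
import numpy as np


def psnr_uint8(pred, gt):
    """Mean PSNR in dB over the batch axis, mirroring batch_psnr."""
    x = pred.astype(np.int32)
    y = gt.astype(np.int32)
    axes = tuple(range(1, x.ndim))  # every axis except the batch axis
    mse = np.mean((x - y) ** 2, axis=axes, dtype=np.float64)
    return float(np.mean(20 * np.log10(255.0) - 10 * np.log10(mse)))


gt = np.full((2, 4, 4, 1), 100, dtype=np.uint8)
pred = np.full((2, 4, 4, 1), 101, dtype=np.uint8)  # every pixel off by one gray level
score = psnr_uint8(pred, gt)  # MSE = 1, so PSNR = 20 * log10(255), about 48.13 dB
```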

================================================
FILE: src/utils/optimizer.py
================================================
import tensorflow as tf


def adam_updates(params, cost_or_grads, lr=0.001, mom1=0.9, mom2=0.999):
    updates = []
    if type(cost_or_grads) is not list:
        grads = tf.gradients(cost_or_grads, params)
    else:
        grads = cost_or_grads
    # Adam state (step counter and moment estimates) is not trainable.
    t = tf.Variable(1., trainable=False)
    for p, g in zip(params, grads):
        mg = tf.Variable(tf.zeros(p.get_shape()), trainable=False)
        if mom1 > 0:
            v = tf.Variable(tf.zeros(p.get_shape()), trainable=False)
            v_t = mom1 * v + (1. - mom1) * g
            v_hat = v_t / (1. - tf.pow(mom1, t))
            updates.append(v.assign(v_t))
        else:
            v_hat = g
        mg_t = mom2 * mg + (1. - mom2) * tf.square(g)
        mg_hat = mg_t / (1. - tf.pow(mom2, t))
        g_t = v_hat / tf.sqrt(mg_hat + 1e-8)
        p_t = p - lr * g_t
        updates.append(mg.assign(mg_t))
        updates.append(p.assign(p_t))
    updates.append(t.assign_add(1))
    return tf.group(*updates)
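`adam_updates` implements bias-corrected Adam, with `mom1` and `mom2` as the first- and second-moment decay rates. One detail differs from the textbook form: epsilon is added inside the square root, `g_t = v_hat / sqrt(mg_hat + 1e-8)`, rather than after it. The NumPy sketch below (`adam_update` is a hypothetical helper, not part of the repository) mirrors the graph ops for a single parameter, using the momentum values that `model_factory.py` passes in; the learning rate of 0.01 is an arbitrary choice for the demo.

```python
import numpy as np


def adam_update(p, g, state, lr=0.001, mom1=0.9, mom2=0.999):
    """Hypothetical single-parameter mirror of adam_updates; returns the new value of p.

    `state` holds the step counter 't' (starting at 1) and the moments 'v' and 'mg'.
    """
    t = state.get('t', 1.0)
    mg = state.get('mg', np.zeros_like(p, dtype=np.float64))
    if mom1 > 0:
        v = state.get('v', np.zeros_like(p, dtype=np.float64))
        v_t = mom1 * v + (1.0 - mom1) * g
        v_hat = v_t / (1.0 - mom1 ** t)  # bias-corrected first moment
        state['v'] = v_t
    else:
        v_hat = g  # no momentum: use the raw gradient
    mg_t = mom2 * mg + (1.0 - mom2) * np.square(g)
    mg_hat = mg_t / (1.0 - mom2 ** t)  # bias-corrected second moment
    state['mg'] = mg_t
    state['t'] = t + 1.0
    # Epsilon sits inside the square root, as in adam_updates.
    return p - lr * v_hat / np.sqrt(mg_hat + 1e-8)


# Minimise f(x) = x ** 2 (gradient 2x) with mom1=0.95, mom2=0.9995 as in model_factory.py.
x, state = np.array(1.0), {}
for _ in range(1000):
    x = adam_update(x, 2.0 * x, state, lr=0.01, mom1=0.95, mom2=0.9995)
```

On the first step the bias corrections cancel the `(1 - mom)` factors, so the parameter moves by almost exactly `lr` against the sign of the gradient, whatever the gradient's magnitude.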



================================================
FILE: src/utils/preprocess.py
================================================
__author__ = 'yunbo'

import numpy as np


def reshape_patch(img_tensor, patch_size):
    assert 5 == img_tensor.ndim
    batch_size = np.shape(img_tensor)[0]
    seq_length = np.shape(img_tensor)[1]
    img_height = np.shape(img_tensor)[2]
    img_width = np.shape(img_tensor)[3]
    num_channels = np.shape(img_tensor)[4]
    a = np.reshape(img_tensor, [batch_size, seq_length,
                                img_height//patch_size, patch_size,
                                img_width//patch_size, patch_size,
                                num_channels])
    b = np.transpose(a, [0,1,2,4,3,5,6])
    patch_tensor = np.reshape(b, [batch_size, seq_length,
                                  img_height//patch_size,
                                  img_width//patch_size,
                                  patch_size*patch_size*num_channels])
    return patch_tensor


def reshape_patch_back(patch_tensor, patch_size):
    assert 5 == patch_tensor.ndim
    batch_size = np.shape(patch_tensor)[0]
    seq_length = np.shape(patch_tensor)[1]
    patch_height = np.shape(patch_tensor)[2]
    patch_width = np.shape(patch_tensor)[3]
    channels = np.shape(patch_tensor)[4]
    img_channels = channels // (patch_size*patch_size)
    a = np.reshape(patch_tensor, [batch_size, seq_length,
                                  patch_height, patch_width,
                                  patch_size, patch_size,
                                  img_channels])
    b = np.transpose(a, [0,1,2,4,3,5,6])
    img_tensor = np.reshape(b, [batch_size, seq_length,
                                patch_height * patch_size,
                                patch_width * patch_size,
                                img_channels])
    return img_tensor
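`reshape_patch` folds each non-overlapping `patch_size × patch_size` block of a frame into the channel axis, turning `[batch, seq, H, W, C]` into `[batch, seq, H/p, W/p, p·p·C]`, and `reshape_patch_back` inverts it. The compact NumPy restatement below (`to_patches` and `from_patches` are hypothetical names for the same reshape and transpose steps) makes the layout concrete on a 4×4 frame: the first patch vector holds the top-left 2×2 block in row-major order. H and W must be divisible by `p`; otherwise the first reshape raises a `ValueError`.

```python
import numpy as np


def to_patches(img, p):
    """[b, s, H, W, C] -> [b, s, H/p, W/p, p*p*C], same steps as reshape_patch."""
    b, s, h, w, c = img.shape
    a = img.reshape(b, s, h // p, p, w // p, p, c)
    a = a.transpose(0, 1, 2, 4, 3, 5, 6)  # bring the two patch-grid axes together
    return a.reshape(b, s, h // p, w // p, p * p * c)


def from_patches(pt, p):
    """Inverse of to_patches, same steps as reshape_patch_back."""
    b, s, hp, wp, cc = pt.shape
    c = cc // (p * p)
    a = pt.reshape(b, s, hp, wp, p, p, c)
    a = a.transpose(0, 1, 2, 4, 3, 5, 6)  # interleave grid and in-patch axes again
    return a.reshape(b, s, hp * p, wp * p, c)


frame = np.arange(16).reshape(1, 1, 4, 4, 1)  # one 4x4 single-channel frame
patches = to_patches(frame, 2)                 # shape (1, 1, 2, 2, 4)
top_left = patches[0, 0, 0, 0]                 # [0, 1, 4, 5]: the top-left 2x2 block
restored = from_patches(patches, 2)            # equal to frame
```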


About this extraction

This page contains the full source code of the Yunbo426/MIM GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 22 files (72.4 KB), approximately 18.1k tokens, and a symbol index with 91 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!