Repository: Yunbo426/MIM
Branch: master
Commit: a61eb6a1be20
Files: 22
Total size: 72.4 KB
Directory structure:
gitextract_lnbn37h0/
├── README.md
├── data/
│   └── human36m.sh
├── run.py
└── src/
    ├── __init__.py
    ├── data_provider/
    │   ├── __init__.py
    │   ├── datasets_factory.py
    │   ├── human.py
    │   ├── mnist.py
    │   └── taxibj.py
    ├── layers/
    │   ├── MIMBlock.py
    │   ├── MIMN.py
    │   ├── SpatioTemporalLSTMCellv2.py
    │   ├── TensorLayerNorm.py
    │   └── __init__.py
    ├── models/
    │   ├── __init__.py
    │   ├── mim.py
    │   └── model_factory.py
    ├── trainer.py
    └── utils/
        ├── __init__.py
        ├── metrics.py
        ├── optimizer.py
        └── preprocess.py
================================================
FILE CONTENTS
================================================
================================================
FILE: README.md
================================================
# Memory In Memory Networks
MIM is a neural network for video prediction and spatiotemporal modeling. It is based on the paper [Memory In Memory: A Predictive Neural Network for Learning Higher-Order Non-Stationarity from Spatiotemporal Dynamics](https://arxiv.org/pdf/1811.07490.pdf) to be presented at CVPR 2019.
## Abstract
Natural spatiotemporal processes can be highly non-stationary in many ways: low-level non-stationarity, such as spatial correlations or temporal dependencies of local pixel values, and high-level non-stationarity, such as the accumulation, deformation, or dissipation of radar echoes in precipitation forecasting.
We try to stationalize and approximate the non-stationary processes by modeling the differential signals with the MIM recurrent blocks. By stacking multiple MIM blocks, we could potentially handle higher-order non-stationarity. Our model achieves state-of-the-art results on three spatiotemporal prediction tasks across both synthetic and real-world data.
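
As a loose illustration of why differencing helps (plain NumPy on a toy 1-D signal, not the MIM block itself): a signal with a quadratic trend is non-stationary, its first difference still trends linearly, and its second difference is constant. Repeating the differencing step is only an analogy for stacking MIM blocks. The function name `difference` and the toy signal are assumptions made for this sketch.

```python
import numpy as np


def difference(signal, order=1):
    """Apply first-order differencing `order` times along the time axis."""
    out = np.asarray(signal, dtype=np.float64)
    for _ in range(order):
        out = out[1:] - out[:-1]
    return out


# Toy signal with a quadratic trend: x_t = t^2 (hypothetical data).
t = np.arange(8)
x = t ** 2

first = difference(x, order=1)   # still trends linearly: 2t + 1
second = difference(x, order=2)  # constant, so the trend is gone

print(first)
print(second)
```

Each differencing pass shortens the signal by one sample, so `second` has six entries for the eight-sample input.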

## Pre-trained Models and Datasets
All pre-trained MIM models have been uploaded to [DROPBOX](https://www.dropbox.com/s/7kd82ijezk4lkmp/mim-lib.zip?dl=0) and [BAIDU YUN](https://pan.baidu.com/s/1O07H7l1NTWmAkx3UCDVMLA) (password: srhv).
The archive also includes our pre-processed training/testing data for Moving MNIST, Color-Changing Moving MNIST, and TaxiBJ.
For Human3.6M, you may download the data using `data/human36m.sh`.
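
For orientation, the Moving MNIST loader in `src/data_provider/mnist.py` reads the arrays `clips`, `dims`, and `input_raw_data` from each `.npz` file. The sketch below builds a tiny synthetic archive with the same keys and slices one sequence out of it in the way the loader indexes them. The array sizes and the helper name `read_sequence` are assumptions for illustration, not part of the repository.

```python
import io

import numpy as np


def read_sequence(data, index):
    """Return one (frames, height, width, channels) sequence: input then output."""
    parts = []
    for role in (0, 1):  # 0 = input frames, 1 = frames to predict
        start, length = data['clips'][role, index]
        frames = data['input_raw_data'][start:start + length]  # (T, C, H, W)
        parts.append(np.transpose(frames, (0, 2, 3, 1)))        # (T, H, W, C)
    return np.concatenate(parts, axis=0)


# Synthetic archive: 2 sequences, each 3 input + 2 output frames of 4x4 pixels.
n_seq, n_in, n_out = 2, 3, 2
raw = np.random.rand(n_seq * (n_in + n_out), 1, 4, 4).astype(np.float32)
clips = np.zeros((2, n_seq, 2), dtype=np.int32)
for i in range(n_seq):
    clips[0, i] = (i * (n_in + n_out), n_in)           # (start, length) of input
    clips[1, i] = (i * (n_in + n_out) + n_in, n_out)   # (start, length) of output
dims = np.array([[1, 4, 4]], dtype=np.int32)           # (channels, height, width)

# Write the archive to memory instead of disk, then read it back.
buffer = io.BytesIO()
np.savez(buffer, clips=clips, dims=dims, input_raw_data=raw)
buffer.seek(0)

data = np.load(buffer)
sequence = read_sequence(data, index=1)
print(sequence.shape)  # (5, 4, 4, 1)
```

Running the same inspection on a downloaded file (replace the in-memory buffer with the file path) is a quick way to check that `--input_length`, `--total_length`, and `--img_width` match the data.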
## Generation Results
#### Moving MNIST



#### Color-Changing Moving MNIST



#### Radar Echoes



#### Human3.6M



## BibTeX
```
@article{wang2018memory,
title={Memory In Memory: A Predictive Neural Network for Learning Higher-Order Non-Stationarity from Spatiotemporal Dynamics},
author={Wang, Yunbo and Zhang, Jianjin and Zhu, Hongyu and Long, Mingsheng and Wang, Jianmin and Yu, Philip S},
journal={arXiv preprint arXiv:1811.07490},
year={2019}
}
```
================================================
FILE: data/human36m.sh
================================================
# Download H36M images
mkdir -p human
cd human
# Fetch, unpack, and clean up one subject archive at a time
for subject in S1 S5 S6 S7 S8 S9 S11; do
    wget "http://visiondata.cis.upenn.edu/volumetric/h36m/${subject}.tar"
    tar -xf "${subject}.tar"
    rm "${subject}.tar"
done
cd ..
================================================
FILE: run.py
================================================
__author__ = 'yunbo'
import os
import tensorflow as tf
import numpy as np
from time import time
from src.data_provider import datasets_factory
from src.models.model_factory import Model
from src.utils import preprocess
import src.trainer as trainer
# -----------------------------------------------------------------------------
FLAGS = tf.app.flags.FLAGS
# os.environ["CUDA_VISIBLE_DEVICES"] = "2"
# mode
tf.app.flags.DEFINE_boolean('is_training', True, 'training or testing')
# data I/O
tf.app.flags.DEFINE_string('dataset_name', 'mnist',
'The name of dataset.')
tf.app.flags.DEFINE_string('train_data_paths',
'data/moving-mnist-example/moving-mnist-train.npz',
'train data paths.')
tf.app.flags.DEFINE_string('valid_data_paths',
'data/moving-mnist-example/moving-mnist-valid.npz',
'validation data paths.')
tf.app.flags.DEFINE_string('save_dir', 'checkpoints/mnist_predrnn_pp',
'dir to store trained net.')
tf.app.flags.DEFINE_string('gen_frm_dir', 'results/mnist_predrnn_pp',
'dir to store result.')
tf.app.flags.DEFINE_integer('input_length', 10,
'number of input frames.')
tf.app.flags.DEFINE_integer('total_length', 20,
'total input and output length.')
tf.app.flags.DEFINE_integer('img_width', 64,
'input image width.')
tf.app.flags.DEFINE_integer('img_channel', 1,
'number of image channel.')
# model[convlstm, predcnn, predrnn, predrnn_pp]
tf.app.flags.DEFINE_string('model_name', 'convlstm_net',
'The name of the architecture.')
tf.app.flags.DEFINE_string('pretrained_model', '',
'file of a pretrained model to initialize from.')
tf.app.flags.DEFINE_string('num_hidden', '64,64,64,64',
'COMMA separated number of units in a convlstm layer.')
tf.app.flags.DEFINE_integer('filter_size', 5,
'filter of a convlstm layer.')
tf.app.flags.DEFINE_integer('stride', 1,
'stride of a convlstm layer.')
tf.app.flags.DEFINE_integer('patch_size', 1,
'patch size on one dimension.')
tf.app.flags.DEFINE_boolean('layer_norm', True,
'whether to apply tensor layer norm.')
# scheduled sampling
tf.app.flags.DEFINE_boolean('scheduled_sampling', True, 'for scheduled sampling')
tf.app.flags.DEFINE_integer('sampling_stop_iter', 50000, 'for scheduled sampling.')
tf.app.flags.DEFINE_float('sampling_start_value', 1.0, 'for scheduled sampling.')
tf.app.flags.DEFINE_float('sampling_changing_rate', 0.00002, 'for scheduled sampling.')
# optimization
tf.app.flags.DEFINE_float('lr', 0.001,
'base learning rate.')
tf.app.flags.DEFINE_boolean('reverse_input', True,
'whether to reverse the input frames while training.')
tf.app.flags.DEFINE_boolean('reverse_img', False,
'whether to reverse the input images while training.')
tf.app.flags.DEFINE_integer('batch_size', 8,
'batch size for training.')
tf.app.flags.DEFINE_integer('max_iterations', 80000,
'max num of steps.')
tf.app.flags.DEFINE_integer('display_interval', 1,
'number of iters showing training loss.')
tf.app.flags.DEFINE_integer('test_interval', 1000,
'number of iters for test.')
tf.app.flags.DEFINE_integer('snapshot_interval', 1000,
'number of iters saving models.')
tf.app.flags.DEFINE_integer('num_save_samples', 10,
'number of sequences to be saved.')
tf.app.flags.DEFINE_integer('n_gpu', 1,
'how many GPUs to distribute the training across.')
# gpu
tf.app.flags.DEFINE_boolean('allow_gpu_growth', False,
'allow gpu growth')
tf.app.flags.DEFINE_integer('img_height', 0,
'input image height.')
def main(argv=None):
if tf.gfile.Exists(FLAGS.save_dir):
tf.gfile.DeleteRecursively(FLAGS.save_dir)
tf.gfile.MakeDirs(FLAGS.save_dir)
if tf.gfile.Exists(FLAGS.gen_frm_dir):
tf.gfile.DeleteRecursively(FLAGS.gen_frm_dir)
tf.gfile.MakeDirs(FLAGS.gen_frm_dir)
gpu_list = np.asarray(os.environ.get('CUDA_VISIBLE_DEVICES', '-1').split(',') ,dtype=np.int32)
FLAGS.n_gpu = len(gpu_list)
print('Initializing models')
model = Model(FLAGS)
if FLAGS.is_training:
train_wrapper(model)
else:
start = time()
test_wrapper(model)
stop = time()
print("Time used: " + str(stop - start) + "s")
def schedule_sampling(eta, itr):
    # Fall back to square frames when img_height is not set.
    if FLAGS.img_height > 0:
        height = FLAGS.img_height
    else:
        height = FLAGS.img_width
    # Shape of a single mask in patch space.
    mask_shape = (FLAGS.img_width // FLAGS.patch_size,
                  height // FLAGS.patch_size,
                  FLAGS.patch_size ** 2 * FLAGS.img_channel)
    # One mask for each of the total_length - input_length - 1 steps.
    num_masks = FLAGS.total_length - FLAGS.input_length - 1
    zeros = np.zeros((FLAGS.batch_size, num_masks) + mask_shape)
    if not FLAGS.scheduled_sampling:
        return 0.0, zeros
    # Decay eta linearly until sampling_stop_iter, then clamp it to zero.
    if itr < FLAGS.sampling_stop_iter:
        eta -= FLAGS.sampling_changing_rate
    else:
        eta = 0.0
    random_flip = np.random.random_sample((FLAGS.batch_size, num_masks))
    true_token = (random_flip < eta)
    # Expand each boolean into a full mask of ones (true token) or zeros.
    real_input_flag = np.zeros((FLAGS.batch_size, num_masks) + mask_shape)
    real_input_flag[true_token] = 1.0
    return eta, real_input_flag
def train_wrapper(model):
if FLAGS.pretrained_model:
model.load(FLAGS.pretrained_model)
# load data
train_input_handle, test_input_handle = datasets_factory.data_provider(
FLAGS.dataset_name, FLAGS.train_data_paths, FLAGS.valid_data_paths,
FLAGS.batch_size * FLAGS.n_gpu, FLAGS.img_width, seq_length=FLAGS.total_length, is_training=True)
eta = FLAGS.sampling_start_value
for itr in range(1, FLAGS.max_iterations + 1):
if train_input_handle.no_batch_left():
train_input_handle.begin(do_shuffle=True)
ims = train_input_handle.get_batch()
ims_reverse = None
if FLAGS.reverse_img:
ims_reverse = ims[:, :, :, ::-1]
ims_reverse = preprocess.reshape_patch(ims_reverse, FLAGS.patch_size)
ims = preprocess.reshape_patch(ims, FLAGS.patch_size)
eta, real_input_flag = schedule_sampling(eta, itr)
trainer.train(model, ims, real_input_flag, FLAGS, itr, ims_reverse)
if itr % FLAGS.snapshot_interval == 0:
model.save(itr)
if itr % FLAGS.test_interval == 0:
trainer.test(model, test_input_handle, FLAGS, itr)
train_input_handle.next()
def test_wrapper(model):
model.load(FLAGS.pretrained_model)
test_input_handle = datasets_factory.data_provider(
FLAGS.dataset_name, FLAGS.train_data_paths, FLAGS.valid_data_paths,
FLAGS.batch_size * FLAGS.n_gpu, FLAGS.img_width, seq_length=FLAGS.total_length, is_training=False)
trainer.test(model, test_input_handle, FLAGS, 'test_result')
if __name__ == '__main__':
tf.app.run()
================================================
FILE: src/__init__.py
================================================
================================================
FILE: src/data_provider/__init__.py
================================================
================================================
FILE: src/data_provider/datasets_factory.py
================================================
from src.data_provider import mnist, human, taxibj
datasets_map = {
'mnist': mnist,
'taxibj': taxibj,
'human': human
}
def data_provider(dataset_name, train_data_paths, valid_data_paths, batch_size,
                  img_width, seq_length=20, is_training=True):
    '''Given a dataset name, return the corresponding input handle(s).
    Args:
        dataset_name: String, the name of the dataset.
        train_data_paths: String, comma-separated training data paths.
        valid_data_paths: String, comma-separated validation data paths.
        batch_size: Int
        img_width: Int
        seq_length: Int, number of frames per sequence.
        is_training: Bool
    Returns:
        if is_training:
            Two dataset instances for both training and evaluation.
        else:
            One dataset instance for evaluation.
    Raises:
        ValueError: If `dataset_name` is unknown.
    '''
    if dataset_name not in datasets_map:
        raise ValueError('Name of dataset unknown %s' % dataset_name)
    train_data_list = train_data_paths.split(',')
    valid_data_list = valid_data_paths.split(',')
    if dataset_name == 'mnist':
        test_input_param = {'paths': valid_data_list,
                            'minibatch_size': batch_size,
                            'input_data_type': 'float32',
                            'is_output_sequence': True,
                            'name': dataset_name + ' test iterator'}
        test_input_handle = datasets_map[dataset_name].InputHandle(test_input_param)
        test_input_handle.begin(do_shuffle=False)
        if is_training:
            train_input_param = {'paths': train_data_list,
                                 'minibatch_size': batch_size,
                                 'input_data_type': 'float32',
                                 'is_output_sequence': True,
                                 'name': dataset_name + ' train iterator'}
            train_input_handle = datasets_map[dataset_name].InputHandle(train_input_param)
            train_input_handle.begin(do_shuffle=True)
            return train_input_handle, test_input_handle
        else:
            return test_input_handle
    # For 'human' and 'taxibj', both splits are read from valid_data_paths;
    # train_data_paths is not used.
    if dataset_name == 'human':
        input_param = {'paths': valid_data_list,
                       'image_width': img_width,
                       'minibatch_size': batch_size,
                       'seq_length': seq_length,
                       'channel': 3,
                       'input_data_type': 'float32',
                       'name': 'human'}
        input_handle = datasets_map[dataset_name].DataProcess(input_param)
        test_input_handle = input_handle.get_test_input_handle()
        test_input_handle.begin(do_shuffle=False)
        if is_training:
            train_input_handle = input_handle.get_train_input_handle()
            train_input_handle.begin(do_shuffle=True)
            return train_input_handle, test_input_handle
        else:
            return test_input_handle
    if dataset_name == 'taxibj':
        input_param = {'paths': valid_data_list,
                       'image_width': img_width,
                       'minibatch_size': batch_size,
                       'seq_length': seq_length,
                       'input_data_type': 'float32',
                       'name': dataset_name + ' iterator'}
        input_handle = datasets_map[dataset_name].DataProcess(input_param)
        if is_training:
            train_input_handle = input_handle.get_train_input_handle()
            train_input_handle.begin(do_shuffle=True)
            test_input_handle = input_handle.get_test_input_handle()
            test_input_handle.begin(do_shuffle=False)
            return train_input_handle, test_input_handle
        else:
            test_input_handle = input_handle.get_test_input_handle()
            test_input_handle.begin(do_shuffle=False)
            return test_input_handle
================================================
FILE: src/data_provider/human.py
================================================
__author__ = 'jianjin'
import numpy as np
import os
import cv2
from PIL import Image
import logging
import random
import tensorflow as tf
logger = logging.getLogger(__name__)
class InputHandle:
def __init__(self, datas, indices, input_param):
self.name = input_param['name']
self.input_data_type = input_param.get('input_data_type', 'float32')
self.minibatch_size = input_param['minibatch_size']
self.image_width = input_param['image_width']
self.channel = input_param['channel']
self.datas = datas
self.indices = indices
self.current_position = 0
self.current_batch_indices = []
self.current_input_length = input_param['seq_length']
self.interval = 2
def total(self):
return len(self.indices)
def begin(self, do_shuffle=True):
logger.info("Initialization for read data ")
if do_shuffle:
random.shuffle(self.indices)
self.current_position = 0
self.current_batch_indices = self.indices[self.current_position:self.current_position + self.minibatch_size]
def next(self):
self.current_position += self.minibatch_size
if self.no_batch_left():
return None
self.current_batch_indices = self.indices[self.current_position:self.current_position + self.minibatch_size]
def no_batch_left(self):
if self.current_position + self.minibatch_size > self.total():
return True
else:
return False
    def get_batch(self):
        if self.no_batch_left():
            logger.error(
                "There is no batch left in " + self.name + ". Consider using iterators.begin() to rescan from the beginning of the iterators")
            return None
        input_batch = np.zeros(
            (self.minibatch_size, self.current_input_length, self.image_width, self.image_width, self.channel)).astype(
            self.input_data_type)
        for i in range(self.minibatch_size):
            batch_ind = self.current_batch_indices[i]
            # Take every `interval`-th frame starting at the sequence index.
            begin = batch_ind
            end = begin + self.current_input_length * self.interval
            data_slice = self.datas[begin:end:self.interval]
            input_batch[i, :self.current_input_length, :, :, :] = data_slice
            # logger.info('data_slice shape')
            # logger.info(data_slice.shape)
            # logger.info(input_batch.shape)
        input_batch = input_batch.astype(self.input_data_type)
        return input_batch
def print_stat(self):
logger.info("Iterator Name: " + self.name)
logger.info(" current_position: " + str(self.current_position))
logger.info(" Minibatch Size: " + str(self.minibatch_size))
logger.info(" total Size: " + str(self.total()))
logger.info(" current_input_length: " + str(self.current_input_length))
logger.info(" Input Data Type: " + str(self.input_data_type))
class DataProcess:
def __init__(self, input_param):
self.input_param = input_param
self.paths = input_param['paths']
self.image_width = input_param['image_width']
self.seq_len = input_param['seq_length']
    def load_data(self, paths, mode='train'):
        data_dir = paths[0]
        interval = 2
        frames_np = []
        scenarios = ['Walking']
        if mode == 'train':
            subjects = ['S1', 'S5', 'S6', 'S7', 'S8']
        elif mode == 'test':
            subjects = ['S9', 'S11']
        else:
            # Fail early instead of continuing with `subjects` undefined.
            raise ValueError("mode must be 'train' or 'test', got %r" % mode)
        _path = data_dir
        print('load data...', _path)
        filenames = os.listdir(_path)
        filenames.sort()
        print('data size ', len(filenames))
        frames_file_name = []
        for filename in filenames:
            fix = filename.split('.')
            fix = fix[0]
            subject = fix.split('_')
            scenario = subject[1]
            subject = subject[0]
            if subject not in subjects or scenario not in scenarios:
                continue
            file_path = os.path.join(_path, filename)
            image = cv2.cvtColor(cv2.imread(file_path), cv2.COLOR_BGR2RGB)
            # [1000, 1000, 3]: crop the central half along each axis
            image = image[image.shape[0]//4:-image.shape[0]//4, image.shape[1]//4:-image.shape[1]//4, :]
            if self.image_width != image.shape[0]:
                image = cv2.resize(image, (self.image_width, self.image_width))
            # image = cv2.resize(image[100:-100,100:-100,:], (self.image_width, self.image_width),
            #                    interpolation=cv2.INTER_LINEAR)
            frames_np.append(np.array(image, dtype=np.float32) / 255.0)
            frames_file_name.append(filename)
            # if len(frames_np) % 100 == 0: print len(frames_np)
            # if len(frames_np) % 1000 == 0: break
        # collect the indices at which a valid sequence begins
        indices = []
        index = 0
        print('gen index')
        while index + interval * self.seq_len - 1 < len(frames_file_name):
            # 'S11_Discussion_1.54138969_000471.jpg'
            # ['S11_Discussion_1', '54138969_000471', 'jpg']
            start_infos = frames_file_name[index].split('.')
            end_infos = frames_file_name[index+interval*(self.seq_len-1)].split('.')
            if start_infos[0] != end_infos[0]:
                index += 1
                continue
            start_video_id, start_frame_id = start_infos[1].split('_')
            end_video_id, end_frame_id = end_infos[1].split('_')
            if start_video_id != end_video_id:
                index += 1
                continue
            if int(end_frame_id) - int(start_frame_id) == 5 * (self.seq_len - 1) * interval:
                indices.append(index)
            if mode == 'train':
                index += 10
            elif mode == 'test':
                index += 5
        print("there are " + str(len(indices)) + " sequences")
        # data = np.asarray(frames_np)
        data = frames_np
        print("there are " + str(len(data)) + " pictures")
        return data, indices
def get_train_input_handle(self):
train_data, train_indices = self.load_data(self.paths, mode='train')
return InputHandle(train_data, train_indices, self.input_param)
def get_test_input_handle(self):
test_data, test_indices = self.load_data(self.paths, mode='test')
return InputHandle(test_data, test_indices, self.input_param)
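def main():
    # NOTE: leftover entry point. `partition_data` is not defined in this module,
    # so running this file directly raises a NameError.
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("input_dir", type=str)
    parser.add_argument("output_dir", type=str)
    args = parser.parse_args()
    partition_names = ['train', 'test']
    partition_fnames = partition_data(args.input_dir)
if __name__ == '__main__':
    main()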
================================================
FILE: src/data_provider/mnist.py
================================================
import numpy as np
import random
class InputHandle:
def __init__(self, input_param):
self.paths = input_param['paths']
self.num_paths = len(input_param['paths'])
self.name = input_param['name']
self.input_data_type = input_param.get('input_data_type', 'float32')
self.output_data_type = input_param.get('output_data_type', 'float32')
self.minibatch_size = input_param['minibatch_size']
self.is_output_sequence = input_param['is_output_sequence']
self.data = {}
self.indices = {}
self.current_position = 0
self.current_batch_size = 0
self.current_batch_indices = []
self.current_input_length = 0
self.current_output_length = 0
self.load()
def load(self):
dat_1 = np.load(self.paths[0])
for key in dat_1.keys():
self.data[key] = dat_1[key]
if self.num_paths == 2:
dat_2 = np.load(self.paths[1])
num_clips_1 = dat_1['clips'].shape[1]
dat_2['clips'][:,:,0] += num_clips_1
self.data['clips'] = np.concatenate(
(dat_1['clips'], dat_2['clips']), axis=1)
self.data['input_raw_data'] = np.concatenate(
(dat_1['input_raw_data'], dat_2['input_raw_data']), axis=0)
self.data['output_raw_data'] = np.concatenate(
(dat_1['output_raw_data'], dat_2['output_raw_data']), axis=0)
for key in self.data.keys():
print(key)
print(self.data[key].shape)
def total(self):
return self.data['clips'].shape[1]
def begin(self, do_shuffle = True):
self.indices = np.arange(self.total(),dtype="int32")
if do_shuffle:
random.shuffle(self.indices)
self.current_position = 0
if self.current_position + self.minibatch_size <= self.total():
self.current_batch_size = self.minibatch_size
else:
self.current_batch_size = self.total() - self.current_position
self.current_batch_indices = self.indices[
self.current_position:self.current_position + self.current_batch_size]
self.current_input_length = max(self.data['clips'][0, ind, 1] for ind
in self.current_batch_indices)
self.current_output_length = max(self.data['clips'][1, ind, 1] for ind
in self.current_batch_indices)
def next(self):
self.current_position += self.current_batch_size
if self.no_batch_left():
return None
if self.current_position + self.minibatch_size <= self.total():
self.current_batch_size = self.minibatch_size
else:
self.current_batch_size = self.total() - self.current_position
self.current_batch_indices = self.indices[
self.current_position:self.current_position + self.current_batch_size]
self.current_input_length = max(self.data['clips'][0, ind, 1] for ind
in self.current_batch_indices)
self.current_output_length = max(self.data['clips'][1, ind, 1] for ind
in self.current_batch_indices)
def no_batch_left(self):
if self.current_position >= self.total() - self.current_batch_size:
return True
else:
return False
def input_batch(self):
if self.no_batch_left():
return None
input_batch = np.zeros(
(self.current_batch_size, self.current_input_length) +
tuple(self.data['dims'][0])).astype(self.input_data_type)
input_batch = np.transpose(input_batch,(0,1,3,4,2))
for i in range(self.current_batch_size):
batch_ind = self.current_batch_indices[i]
begin = self.data['clips'][0, batch_ind, 0]
end = self.data['clips'][0, batch_ind, 0] + \
self.data['clips'][0, batch_ind, 1]
data_slice = self.data['input_raw_data'][begin:end, :, :, :]
data_slice = np.transpose(data_slice,(0,2,3,1))
input_batch[i, :self.current_input_length, :, :, :] = data_slice
input_batch = input_batch.astype(self.input_data_type)
return input_batch
def output_batch(self):
if self.no_batch_left():
return None
if(2 ,3) == self.data['dims'].shape:
raw_dat = self.data['output_raw_data']
else:
raw_dat = self.data['input_raw_data']
if self.is_output_sequence:
if (1, 3) == self.data['dims'].shape:
output_dim = self.data['dims'][0]
else:
output_dim = self.data['dims'][1]
output_batch = np.zeros(
(self.current_batch_size,self.current_output_length) +
tuple(output_dim))
else:
output_batch = np.zeros((self.current_batch_size, ) +
tuple(self.data['dims'][1]))
for i in range(self.current_batch_size):
batch_ind = self.current_batch_indices[i]
begin = self.data['clips'][1, batch_ind, 0]
end = self.data['clips'][1, batch_ind, 0] + \
self.data['clips'][1, batch_ind, 1]
if self.is_output_sequence:
data_slice = raw_dat[begin:end, :, :, :]
output_batch[i, : data_slice.shape[0], :, :, :] = data_slice
else:
data_slice = raw_dat[begin, :, :, :]
output_batch[i,:, :, :] = data_slice
output_batch = output_batch.astype(self.output_data_type)
output_batch = np.transpose(output_batch, [0,1,3,4,2])
return output_batch
def get_batch(self):
input_seq = self.input_batch()
output_seq = self.output_batch()
batch = np.concatenate((input_seq, output_seq), axis=1)
return batch
================================================
FILE: src/data_provider/taxibj.py
================================================
__author__ = 'jianjin'
import random
import os.path
import logging
import os
from copy import copy
import numpy as np
import h5py
import pandas as pd
from datetime import datetime
import time
logger = logging.getLogger(__name__)
def string2timestamp(strings, T=48):
timestamps = []
time_per_slot = 24.0 / T
num_per_T = T // 24
for t in strings:
year, month, day, slot = int(t[:4]), int(t[4:6]), int(t[6:8]), int(t[8:])-1
timestamps.append(pd.Timestamp(datetime(year, month, day, hour=int(slot * time_per_slot),
minute=(slot % num_per_T) * int(60.0 * time_per_slot))))
return timestamps
class STMatrix(object):
"""docstring for STMatrix"""
def __init__(self, data, timestamps, T=48, CheckComplete=True):
super(STMatrix, self).__init__()
assert len(data) == len(timestamps)
self.data = data
self.timestamps = timestamps
self.T = T
self.pd_timestamps = string2timestamp(timestamps, T=self.T)
if CheckComplete:
self.check_complete()
# index
self.make_index()
def make_index(self):
self.get_index = dict()
for i, ts in enumerate(self.pd_timestamps):
self.get_index[ts] = i
def check_complete(self):
missing_timestamps = []
offset = pd.DateOffset(minutes=24 * 60 // self.T)
pd_timestamps = self.pd_timestamps
i = 1
while i < len(pd_timestamps):
if pd_timestamps[i-1] + offset != pd_timestamps[i]:
missing_timestamps.append("(%s -- %s)" % (pd_timestamps[i-1], pd_timestamps[i]))
i += 1
for v in missing_timestamps:
print(v)
assert len(missing_timestamps) == 0
def get_matrix(self, timestamp):
return self.data[self.get_index[timestamp]]
def save(self, fname):
pass
def check_it(self, depends):
for d in depends:
if d not in self.get_index.keys():
return False
return True
def create_dataset(self, len_closeness=20):
"""current version
"""
# offset_week = pd.DateOffset(days=7)
offset_frame = pd.DateOffset(minutes=24 * 60 // self.T)
XC = []
timestamps_Y = []
depends = [range(1, len_closeness+1)]
i = len_closeness
while i < len(self.pd_timestamps):
Flag = True
for depend in depends:
if Flag is False:
break
Flag = self.check_it([self.pd_timestamps[i] - j * offset_frame for j in depend])
if Flag is False:
i += 1
continue
x_c = [np.transpose(self.get_matrix(self.pd_timestamps[i] - j * offset_frame), [1, 2, 0]) for j in depends[0]]
if len_closeness > 0:
XC.append(np.stack(x_c, axis=0))
timestamps_Y.append(self.timestamps[i])
i += 1
XC = np.stack(XC, axis=0)
return XC, timestamps_Y
def load_stdata(fname):
    # Dataset.value was removed in h5py 3.0; [()] reads the full array instead.
    with h5py.File(fname, 'r') as f:
        data = f['data'][()]
        timestamps = f['date'][()]
    return data, timestamps
def stat(fname):
    def get_nb_timeslot(f):
        s = f['date'][0]
        e = f['date'][-1]
        year, month, day = map(int, [s[:4], s[4:6], s[6:8]])
        ts = time.strptime("%04i-%02i-%02i" % (year, month, day), "%Y-%m-%d")
        year, month, day = map(int, [e[:4], e[4:6], e[6:8]])
        te = time.strptime("%04i-%02i-%02i" % (year, month, day), "%Y-%m-%d")
        # 30-minute slots between the first and last day, plus the last day's 48 slots
        nb_timeslot = (time.mktime(te) - time.mktime(ts)) / (0.5 * 3600) + 48
        ts_str, te_str = time.strftime("%Y-%m-%d", ts), time.strftime("%Y-%m-%d", te)
        return nb_timeslot, ts_str, te_str
    with h5py.File(fname, 'r') as f:
        nb_timeslot, ts_str, te_str = get_nb_timeslot(f)
        nb_day = int(nb_timeslot / 48)
        # Dataset.value was removed in h5py 3.0; [()] reads the full array instead.
        mmax = f['data'][()].max()
        mmin = f['data'][()].min()
        stat = '=' * 5 + 'stat' + '=' * 5 + '\n' + \
               'data shape: %s\n' % str(f['data'].shape) + \
               '# of days: %i, from %s to %s\n' % (nb_day, ts_str, te_str) + \
               '# of timeslots: %i\n' % int(nb_timeslot) + \
               '# of timeslots (available): %i\n' % f['date'].shape[0] + \
               'missing ratio of timeslots: %.1f%%\n' % ((1. - float(f['date'].shape[0] / nb_timeslot)) * 100) + \
               'max: %.3f, min: %.3f\n' % (mmax, mmin) + \
               '=' * 5 + 'stat' + '=' * 5
        print(stat)
class MinMaxNormalization(object):
    '''MinMax Normalization --> [0, 1]
       x = (x - min) / (max - min).
    '''
    def __init__(self):
        pass
    def fit(self, X):
        self._min = X.min()
        self._max = X.max()
        print("min:", self._min, "max:", self._max)
    def transform(self, X):
        X = 1. * (X - self._min) / (self._max - self._min)
        # The rescaling to [-1, 1] is disabled:
        # X = X * 2. - 1.
        return X
    def fit_transform(self, X):
        self.fit(X)
        return self.transform(X)
    def inverse_transform(self, X):
        # Inverse of transform(): map [0, 1] back to [min, max].
        X = 1. * X * (self._max - self._min) + self._min
        return X
def timestamp2vec(timestamps):
    # tm_wday range [0, 6], Monday is 0
    # Timestamps may be bytes (Python 3 with HDF5 data) or str; decode when needed.
    vec = [time.strptime(t[:8].decode('utf-8') if isinstance(t, bytes) else t[:8], '%Y%m%d').tm_wday
           for t in timestamps]
    ret = []
    for i in vec:
        # One-hot day of week, followed by a weekday flag.
        v = [0 for _ in range(7)]
        v[i] = 1
        if i >= 5:
            v.append(0)  # weekend
        else:
            v.append(1)  # weekday
        ret.append(v)
    return np.asarray(ret)
def remove_incomplete_days(data, timestamps, T=48):
# remove a certain day which has not 48 timestamps
days = [] # available days: some day only contain some seqs
days_incomplete = []
i = 0
while i < len(timestamps):
if int(timestamps[i][8:]) != 1:
i += 1
elif i+T-1 < len(timestamps) and int(timestamps[i+T-1][8:]) == T:
days.append(timestamps[i][:8])
i += T
else:
days_incomplete.append(timestamps[i][:8])
i += 1
print("incomplete days: ", days_incomplete)
days = set(days)
idx = []
for i, t in enumerate(timestamps):
if t[:8] in days:
idx.append(i)
data = data[idx]
timestamps = [timestamps[i] for i in idx]
return data, timestamps
class InputHandle:
def __init__(self, datas, indices, input_param):
self.name = input_param['name']
self.input_data_type = input_param.get('input_data_type', 'float32')
self.minibatch_size = input_param['minibatch_size']
self.image_width = input_param['image_width']
self.datas = datas
self.indices = indices
self.current_position = 0
self.current_batch_indices = []
self.current_input_length = input_param['seq_length']
def total(self):
return len(self.indices)
def begin(self, do_shuffle=True):
logger.info("Initialization for read data ")
if do_shuffle:
random.shuffle(self.indices)
self.current_position = 0
self.current_batch_indices = self.indices[self.current_position:self.current_position + self.minibatch_size]
def next(self):
self.current_position += self.minibatch_size
if self.no_batch_left():
return None
self.current_batch_indices = self.indices[self.current_position:self.current_position + self.minibatch_size]
def no_batch_left(self):
if self.current_position + self.minibatch_size >= self.total():
return True
else:
return False
    def get_batch(self):
        if self.no_batch_left():
            logger.error(
                "There is no batch left in " + self.name + ". Consider using iterators.begin() to rescan from the beginning of the iterators")
            return None
        input_batch = self.datas[self.current_batch_indices, :, :, :]
        input_batch = input_batch.astype(self.input_data_type)
        return input_batch
def print_stat(self):
logger.info("Iterator Name: " + self.name)
logger.info(" current_position: " + str(self.current_position))
logger.info(" Minibatch Size: " + str(self.minibatch_size))
logger.info(" total Size: " + str(self.total()))
logger.info(" current_input_length: " + str(self.current_input_length))
logger.info(" Input Data Type: " + str(self.input_data_type))
class DataProcess:
def __init__(self, input_param):
self.paths = input_param['paths']
self.image_width = input_param['image_width']
self.input_param = input_param
self.seq_len = input_param['seq_length']
self.train_data, self.test_data, _, _, _ = self.load_data(self.paths, len_closeness=input_param['seq_length'])
self.train_indices = list(range(self.train_data.shape[0]))
self.test_indices = list(range(self.test_data.shape[0]))
def load_data(self, datapath, T=48, nb_flow=2, len_closeness=None, len_test=48 * 7 * 4):
"""
"""
assert (len_closeness > 0)
# load data
# 13 - 16
data_all = []
timestamps_all = list()
for year in range(13, 17):
fname = os.path.join(
datapath[0], 'BJ{}_M32x32_T30_InOut.h5'.format(year))
print("file name: ", fname)
stat(fname)
data, timestamps = load_stdata(fname)
# print(timestamps)
# remove a certain day which does not have 48 timestamps
data, timestamps = remove_incomplete_days(data, timestamps, T)
data = data[:, :nb_flow]
data[data < 0] = 0.
data_all.append(data)
timestamps_all.append(timestamps)
print("\n")
# minmax_scale
data_train = np.vstack(copy(data_all))[:-len_test]
print('train_data shape: ', data_train.shape)
mmn = MinMaxNormalization()
mmn.fit(data_train)
data_all_mmn = [mmn.transform(d) for d in data_all]
XC = []
timestamps_Y = []
for data, timestamps in zip(data_all_mmn, timestamps_all):
# instance-based dataset --> sequences with format as (X, Y) where X is
# a sequence of images and Y is an image.
st = STMatrix(data, timestamps, T, CheckComplete=False)
_XC, _timestamps_Y = st.create_dataset(len_closeness=len_closeness)
XC.append(_XC)
timestamps_Y += _timestamps_Y
XC = np.concatenate(XC, axis=0)
print("XC shape: ", XC.shape)
XC_train = XC[:-len_test]
XC_test = XC[-len_test:]
timestamp_train, timestamp_test = timestamps_Y[:-len_test], timestamps_Y[-len_test:]
X_train = XC_train
X_test = XC_test
print('train shape:', XC_train.shape,
'test shape: ', XC_test.shape)
return X_train, X_test, mmn, timestamp_train, timestamp_test
def get_train_input_handle(self):
return InputHandle(self.train_data, self.train_indices, self.input_param)
def get_test_input_handle(self):
return InputHandle(self.test_data, self.test_indices, self.input_param)
================================================
FILE: src/layers/MIMBlock.py
================================================
import tensorflow as tf
from src.layers.TensorLayerNorm import tensor_layer_norm
import math
class MIMBlock():
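def __init__(self, layer_name, filter_size, num_hidden_in, num_hidden,
seq_shape, tln=False, initializer=None):
"""Initialize the MIM block.
Args:
layer_name: variable-scope name for this layer.
filter_size: int tuple that's the height and width of the filter.
num_hidden_in: number of units in the input tensor (stored, but not otherwise used by this block).
num_hidden: number of units in output tensor.
seq_shape: shape of the input sequence; entries 0, 2 and 3 are read as batch, height and width.
tln: whether to apply tensor layer normalization.
initializer: half-width of the uniform weight-initialization range; None or -1 selects sqrt(6 / (dim_in + dim_out)).
The bias added to the forget gates is fixed at 1.0.
"""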
self.layer_name = layer_name
self.filter_size = filter_size
self.num_hidden_in = num_hidden_in
self.num_hidden = num_hidden
self.convlstm_c = None
self.batch = seq_shape[0]
self.height = seq_shape[2]
self.width = seq_shape[3]
self.layer_norm = tln
self._forget_bias = 1.0
def w_initializer(dim_in, dim_out):
random_range = math.sqrt(6.0 / (dim_in + dim_out))
return tf.random_uniform_initializer(-random_range, random_range)
if initializer is None or initializer == -1:
self.initializer = w_initializer
else:
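# Wrap the fixed-range initializer so it can be called as self.initializer(dim_in, dim_out), like w_initializer.
self.initializer = lambda dim_in, dim_out: tf.random_uniform_initializer(-initializer, initializer)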
def init_state(self):
return tf.zeros([self.batch, self.height, self.width, self.num_hidden],
dtype=tf.float32)
def MIMS(self, x, h_t, c_t):
if h_t is None:
h_t = self.init_state()
if c_t is None:
c_t = self.init_state()
with tf.variable_scope(self.layer_name):
h_concat = tf.layers.conv2d(h_t, self.num_hidden * 4,
self.filter_size, 1, padding='same',
kernel_initializer=self.initializer(self.num_hidden, self.num_hidden * 4),
name='state_to_state')
if self.layer_norm:
h_concat = tensor_layer_norm(h_concat, 'state_to_state')
i_h, g_h, f_h, o_h = tf.split(h_concat, 4, 3)
ct_weight = tf.get_variable(
'c_t_weight', [self.height,self.width,self.num_hidden*2])
ct_activation = tf.multiply(tf.tile(c_t, [1,1,1,2]), ct_weight)
i_c, f_c = tf.split(ct_activation, 2, 3)
i_ = i_h + i_c
f_ = f_h + f_c
g_ = g_h
o_ = o_h
if x is not None:
x_concat = tf.layers.conv2d(x, self.num_hidden * 4,
self.filter_size, 1,
padding='same',
kernel_initializer=self.initializer(self.num_hidden, self.num_hidden * 4),
name='input_to_state')
if self.layer_norm:
x_concat = tensor_layer_norm(x_concat, 'input_to_state')
i_x, g_x, f_x, o_x = tf.split(x_concat, 4, 3)
i_ += i_x
f_ += f_x
g_ += g_x
o_ += o_x
i_ = tf.nn.sigmoid(i_)
f_ = tf.nn.sigmoid(f_ + self._forget_bias)
c_new = f_ * c_t + i_ * tf.nn.tanh(g_)
oc_weight = tf.get_variable(
'oc_weight', [self.height,self.width,self.num_hidden])
o_c = tf.multiply(c_new, oc_weight)
h_new = tf.nn.sigmoid(o_ + o_c) * tf.nn.tanh(c_new)
return h_new, c_new
def __call__(self, x, diff_h, h, c, m):
if h is None:
h = self.init_state()
if c is None:
c = self.init_state()
if m is None:
m = self.init_state()
if diff_h is None:
diff_h = tf.zeros_like(h)
with tf.variable_scope(self.layer_name):
t_cc = tf.layers.conv2d(
h, self.num_hidden * 3,
self.filter_size, 1, padding='same',
kernel_initializer=self.initializer(self.num_hidden, self.num_hidden * 3),
name='time_state_to_state')
s_cc = tf.layers.conv2d(
m, self.num_hidden * 4,
self.filter_size, 1, padding='same',
kernel_initializer=self.initializer(self.num_hidden, self.num_hidden * 4),
name='spatio_state_to_state')
x_shape_in = x.get_shape().as_list()[-1]
x_cc = tf.layers.conv2d(
x, self.num_hidden * 4,
self.filter_size, 1, padding='same',
kernel_initializer=self.initializer(x_shape_in, self.num_hidden * 4),
name='input_to_state')
if self.layer_norm:
t_cc = tensor_layer_norm(t_cc, 'time_state_to_state')
s_cc = tensor_layer_norm(s_cc, 'spatio_state_to_state')
x_cc = tensor_layer_norm(x_cc, 'input_to_state')
i_s, g_s, f_s, o_s = tf.split(s_cc, 4, 3)
i_t, g_t, o_t = tf.split(t_cc, 3, 3)
i_x, g_x, f_x, o_x = tf.split(x_cc, 4, 3)
i = tf.nn.sigmoid(i_x + i_t)
i_ = tf.nn.sigmoid(i_x + i_s)
g = tf.nn.tanh(g_x + g_t)
g_ = tf.nn.tanh(g_x + g_s)
f_ = tf.nn.sigmoid(f_x + f_s + self._forget_bias)
o = tf.nn.sigmoid(o_x + o_t + o_s)
new_m = f_ * m + i_ * g_
c, self.convlstm_c = self.MIMS(diff_h, c, self.convlstm_c)
new_c = c + i * g
cell = tf.concat([new_c, new_m], 3)
cell = tf.layers.conv2d(cell, self.num_hidden, 1, 1,
padding='same', name='cell_reduce')
new_h = o * tf.nn.tanh(cell)
return new_h, new_c, new_m
================================================
FILE: src/layers/MIMN.py
================================================
import tensorflow as tf
from src.layers.TensorLayerNorm import tensor_layer_norm
class MIMN():
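def __init__(self, layer_name, filter_size, num_hidden, seq_shape, tln=True, initializer=0.001):
"""Initialize the MIM-N cell (a ConvLSTM-style cell with cell-state peephole weights).
Args:
layer_name: variable-scope name for this layer.
filter_size: int tuple that's the height and width of the filter.
num_hidden: number of units in output tensor.
seq_shape: shape of the input sequence; entries 0, 2 and 3 are read as batch, height and width.
tln: whether to apply tensor layer normalization.
initializer: half-width of the uniform weight-initialization range; -1 leaves kernel_initializer as None.
"""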
self.layer_name = layer_name
self.filter_size = filter_size
self.num_hidden = num_hidden
self.layer_norm = tln
self.batch = seq_shape[0]
self.height = seq_shape[2]
self.width = seq_shape[3]
self._forget_bias = 1.0
if initializer == -1:
self.initializer = None
else:
self.initializer = tf.random_uniform_initializer(-initializer,initializer)
def init_state(self):
shape = [self.batch, self.height, self.width, self.num_hidden]
return tf.zeros(shape, dtype=tf.float32)
def __call__(self, x, h_t, c_t):
if h_t is None:
h_t = self.init_state()
if c_t is None:
c_t = self.init_state()
with tf.variable_scope(self.layer_name):
h_concat = tf.layers.conv2d(h_t, self.num_hidden * 4,
self.filter_size, 1, padding='same',
kernel_initializer=self.initializer,
name='state_to_state')
if self.layer_norm:
h_concat = tensor_layer_norm(h_concat, 'state_to_state')
i_h, g_h, f_h, o_h = tf.split(h_concat, 4, 3)
ct_weight = tf.get_variable(
'c_t_weight', [self.height,self.width,self.num_hidden*2])
ct_activation = tf.multiply(tf.tile(c_t, [1,1,1,2]), ct_weight)
i_c, f_c = tf.split(ct_activation, 2, 3)
i_ = i_h + i_c
f_ = f_h + f_c
g_ = g_h
o_ = o_h
if x is not None:
x_concat = tf.layers.conv2d(x, self.num_hidden * 4,
self.filter_size, 1,
padding='same',
kernel_initializer=self.initializer,
name='input_to_state')
if self.layer_norm:
x_concat = tensor_layer_norm(x_concat, 'input_to_state')
i_x, g_x, f_x, o_x = tf.split(x_concat, 4, 3)
i_ += i_x
f_ += f_x
g_ += g_x
o_ += o_x
i_ = tf.nn.sigmoid(i_)
f_ = tf.nn.sigmoid(f_ + self._forget_bias)
c_new = f_ * c_t + i_ * tf.nn.tanh(g_)
oc_weight = tf.get_variable(
'oc_weight', [self.height,self.width,self.num_hidden])
o_c = tf.multiply(c_new, oc_weight)
h_new = tf.nn.sigmoid(o_ + o_c) * tf.nn.tanh(c_new)
return h_new, c_new
================================================
FILE: src/layers/SpatioTemporalLSTMCellv2.py
================================================
import math
import tensorflow as tf
from src.layers.TensorLayerNorm import tensor_layer_norm
class SpatioTemporalLSTMCell():
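def __init__(self, layer_name, filter_size, num_hidden_in, num_hidden,
seq_shape, tln=False, initializer=None):
"""Initialize the spatiotemporal LSTM cell.
Args:
layer_name: variable-scope name for this layer.
filter_size: int tuple that's the height and width of the filter.
num_hidden_in: fan-in used for the initialization range of the hidden-state and memory convolutions.
num_hidden: number of units in output tensor.
seq_shape: shape of the input sequence; entries 0, 2 and 3 are read as batch, height and width.
tln: whether to apply tensor layer normalization.
initializer: half-width of the uniform weight-initialization range; None or -1 selects sqrt(6 / (dim_in + dim_out)).
The bias added to the forget gates is fixed at 1.0.
"""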
self.layer_name = layer_name
self.filter_size = filter_size
self.num_hidden_in = num_hidden_in
self.num_hidden = num_hidden
self.batch = seq_shape[0]
self.height = seq_shape[2]
self.width = seq_shape[3]
self.layer_norm = tln
self._forget_bias = 1.0
def w_initializer(dim_in, dim_out):
random_range = math.sqrt(6.0 / (dim_in + dim_out))
return tf.random_uniform_initializer(-random_range, random_range)
if initializer is None or initializer == -1:
self.initializer = w_initializer
else:
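# Wrap the fixed-range initializer so it can be called as self.initializer(dim_in, dim_out), like w_initializer.
self.initializer = lambda dim_in, dim_out: tf.random_uniform_initializer(-initializer, initializer)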
def init_state(self):
return tf.zeros([self.batch, self.height, self.width, self.num_hidden],
dtype=tf.float32)
def __call__(self, x, h, c, m):
if h is None:
h = self.init_state()
if c is None:
c = self.init_state()
if m is None:
m = self.init_state()
with tf.variable_scope(self.layer_name):
t_cc = tf.layers.conv2d(
h, self.num_hidden*4,
self.filter_size, 1, padding='same',
kernel_initializer=self.initializer(self.num_hidden_in, self.num_hidden*4),
name='time_state_to_state')
s_cc = tf.layers.conv2d(
m, self.num_hidden*4,
self.filter_size, 1, padding='same',
kernel_initializer=self.initializer(self.num_hidden_in, self.num_hidden*4),
name='spatio_state_to_state')
x_shape_in = x.get_shape().as_list()[-1]
x_cc = tf.layers.conv2d(
x, self.num_hidden*4,
self.filter_size, 1, padding='same',
kernel_initializer=self.initializer(x_shape_in, self.num_hidden*4),
name='input_to_state')
if self.layer_norm:
t_cc = tensor_layer_norm(t_cc, 'time_state_to_state')
s_cc = tensor_layer_norm(s_cc, 'spatio_state_to_state')
x_cc = tensor_layer_norm(x_cc, 'input_to_state')
i_s, g_s, f_s, o_s = tf.split(s_cc, 4, 3)
i_t, g_t, f_t, o_t = tf.split(t_cc, 4, 3)
i_x, g_x, f_x, o_x = tf.split(x_cc, 4, 3)
i = tf.nn.sigmoid(i_x + i_t)
i_ = tf.nn.sigmoid(i_x + i_s)
g = tf.nn.tanh(g_x + g_t)
g_ = tf.nn.tanh(g_x + g_s)
f = tf.nn.sigmoid(f_x + f_t + self._forget_bias)
f_ = tf.nn.sigmoid(f_x + f_s + self._forget_bias)
o = tf.nn.sigmoid(o_x + o_t + o_s)
new_m = f_ * m + i_ * g_
new_c = f * c + i * g
cell = tf.concat([new_c, new_m],3)
cell = tf.layers.conv2d(cell, self.num_hidden, 1, 1, padding='same',
kernel_initializer=self.initializer(self.num_hidden*2, self.num_hidden),
name='cell_reduce')
new_h = o * tf.nn.tanh(cell)
return new_h, new_c, new_m
================================================
FILE: src/layers/TensorLayerNorm.py
================================================
import tensorflow as tf
EPSILON = 0.00001
def tensor_layer_norm(x, state_name):
x_shape = x.get_shape()
dims = x_shape.ndims
params_shape = x_shape[-1:]
if dims == 4:
m, v = tf.nn.moments(x, [1,2,3], keep_dims=True)
elif dims == 5:
m, v = tf.nn.moments(x, [1,2,3,4], keep_dims=True)
elif dims == 2:
m, v = tf.nn.moments(x, [1], keep_dims=True)
else:
raise ValueError('input tensor for layer normalization must be rank 2, 4 or 5.')
b = tf.get_variable(state_name+'b',initializer=tf.zeros(params_shape))
s = tf.get_variable(state_name+'s',initializer=tf.ones(params_shape))
x_tln = tf.nn.batch_normalization(x, m, v, b, s, EPSILON)
return x_tln
================================================
FILE: src/layers/__init__.py
================================================
================================================
FILE: src/models/__init__.py
================================================
================================================
FILE: src/models/mim.py
================================================
__author__ = 'jianjin'
import tensorflow as tf
from src.layers.SpatioTemporalLSTMCellv2 import SpatioTemporalLSTMCell as stlstm
from src.layers.MIMBlock import MIMBlock as mimblock
from src.layers.MIMN import MIMN as mimn
import math
def w_initializer(dim_in, dim_out):
random_range = math.sqrt(6.0 / (dim_in + dim_out))
return tf.random_uniform_initializer(-random_range, random_range)
def mim(images, params, schedual_sampling_bool, num_layers, num_hidden, filter_size,
stride=1, total_length=20, input_length=10, tln=True):
gen_images = []
stlstm_layer = []
stlstm_layer_diff = []
cell_state = []
hidden_state = []
cell_state_diff = []
hidden_state_diff = []
shape = images.get_shape().as_list()
output_channels = shape[-1]
for i in range(num_layers):
if i == 0:
num_hidden_in = num_hidden[num_layers - 1]
else:
num_hidden_in = num_hidden[i - 1]
if i < 1:
new_stlstm_layer = stlstm('stlstm_' + str(i + 1),
filter_size,
num_hidden_in,
num_hidden[i],
shape,
tln=tln)
else:
new_stlstm_layer = mimblock('stlstm_' + str(i + 1),
filter_size,
num_hidden_in,
num_hidden[i],
shape,
tln=tln)
stlstm_layer.append(new_stlstm_layer)
cell_state.append(None)
hidden_state.append(None)
for i in range(num_layers - 1):
new_stlstm_layer = mimn('stlstm_diff' + str(i + 1),
filter_size,
num_hidden[i + 1],
shape,
tln=tln)
stlstm_layer_diff.append(new_stlstm_layer)
cell_state_diff.append(None)
hidden_state_diff.append(None)
st_memory = None
for time_step in range(total_length - 1):
reuse = bool(gen_images)
with tf.variable_scope('predrnn', reuse=reuse):
if time_step < input_length:
x_gen = images[:,time_step]
else:
x_gen = schedual_sampling_bool[:,time_step-input_length]*images[:,time_step] + \
(1-schedual_sampling_bool[:,time_step-input_length])*x_gen
preh = hidden_state[0]
hidden_state[0], cell_state[0], st_memory = stlstm_layer[0](
x_gen, hidden_state[0], cell_state[0], st_memory)
for i in range(1, num_layers):
if time_step > 0:
if i == 1:
hidden_state_diff[i - 1], cell_state_diff[i - 1] = stlstm_layer_diff[i - 1](
hidden_state[i - 1] - preh, hidden_state_diff[i - 1], cell_state_diff[i - 1])
else:
hidden_state_diff[i - 1], cell_state_diff[i - 1] = stlstm_layer_diff[i - 1](
hidden_state_diff[i - 2], hidden_state_diff[i - 1], cell_state_diff[i - 1])
else:
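# At the first time step there is no previous hidden state to difference against.
# The diff layer is called on zeros only so that its variables are created; the outputs are discarded.
stlstm_layer_diff[i - 1](tf.zeros_like(hidden_state[i - 1]), None, None)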
preh = hidden_state[i]
hidden_state[i], cell_state[i], st_memory = stlstm_layer[i](
hidden_state[i - 1], hidden_state_diff[i - 1], hidden_state[i], cell_state[i], st_memory)
x_gen = tf.layers.conv2d(hidden_state[num_layers - 1],
filters=output_channels,
kernel_size=1,
strides=1,
padding='same',
kernel_initializer=w_initializer(num_hidden[num_layers - 1], output_channels),
name="back_to_pixel")
gen_images.append(x_gen)
gen_images = tf.stack(gen_images, axis=1)
loss = tf.nn.l2_loss(gen_images - images[:, 1:])
return [gen_images, loss]
================================================
FILE: src/models/model_factory.py
================================================
import os
import tensorflow as tf
from src.utils import optimizer
from src.models import mim
class Model(object):
def __init__(self, configs):
self.configs = configs
# inputs
if configs.img_height > 0:
height = configs.img_height
else:
height = configs.img_width
self.x = [tf.placeholder(tf.float32,
[self.configs.batch_size,
self.configs.total_length,
self.configs.img_width // self.configs.patch_size,
height // self.configs.patch_size,
self.configs.patch_size * self.configs.patch_size * self.configs.img_channel])
for i in range(self.configs.n_gpu)]
self.real_input_flag = tf.placeholder(tf.float32,
[self.configs.batch_size,
self.configs.total_length - self.configs.input_length - 1,
self.configs.img_width // self.configs.patch_size,
height // self.configs.patch_size,
self.configs.patch_size * self.configs.patch_size * self.configs.img_channel])
grads = []
loss_train = []
self.pred_seq = []
self.tf_lr = tf.placeholder(tf.float32, shape=[])
self.params = dict()
self.params.update(self.configs.__dict__['__flags'])
num_hidden = [int(x) for x in self.configs.num_hidden.split(',')]
num_layers = len(num_hidden)
for i in range(self.configs.n_gpu):
with tf.device('/gpu:%d' % i):
with tf.variable_scope(tf.get_variable_scope(),
reuse=True if i > 0 else None):
# define a model
output_list = self.construct_model(
self.configs.model_name,
self.x[i],
self.params,
self.real_input_flag,
num_layers,
num_hidden,
self.configs.filter_size,
self.configs.stride,
self.configs.total_length,
self.configs.input_length,
self.configs.layer_norm)
gen_ims = output_list[0]
loss = output_list[1]
if len(output_list) > 2:
self.debug = output_list[2]
else:
self.debug = []
pred_ims = gen_ims[:, self.configs.input_length - self.configs.total_length:]
loss_train.append(loss / self.configs.batch_size)
# gradients
all_params = tf.trainable_variables()
grads.append(tf.gradients(loss, all_params))
self.pred_seq.append(pred_ims)
# add losses and gradients together and get training updates
with tf.device('/gpu:0'):
for i in range(1, self.configs.n_gpu):
loss_train[0] += loss_train[i]
for j in range(len(grads[0])):
grads[0][j] += grads[i][j]
# keep track of moving average
ema = tf.train.ExponentialMovingAverage(decay=0.9995)
maintain_averages_op = tf.group(ema.apply(all_params))
self.train_op = tf.group(optimizer.adam_updates(
all_params, grads[0], lr=self.tf_lr, mom1=0.95, mom2=0.9995),
maintain_averages_op)
self.loss_train = loss_train[0] / self.configs.n_gpu
# session
variables = tf.global_variables()
self.saver = tf.train.Saver(variables)
init = tf.global_variables_initializer()
configProt = tf.ConfigProto()
configProt.gpu_options.allow_growth = configs.allow_gpu_growth
configProt.allow_soft_placement = True
self.sess = tf.Session(config=configProt)
self.sess.run(init)
if self.configs.pretrained_model:
self.saver.restore(self.sess, self.configs.pretrained_model)
def train(self, inputs, lr, real_input_flag):
feed_dict = {self.x[i]: inputs[i] for i in range(self.configs.n_gpu)}
feed_dict.update({self.tf_lr: lr})
feed_dict.update({self.real_input_flag: real_input_flag})
loss, _, debug = self.sess.run((self.loss_train, self.train_op, self.debug), feed_dict)
return loss
def test(self, inputs, real_input_flag):
feed_dict = {self.x[i]: inputs[i] for i in range(self.configs.n_gpu)}
feed_dict.update({self.real_input_flag: real_input_flag})
gen_ims, debug = self.sess.run((self.pred_seq, self.debug), feed_dict)
return gen_ims, debug
def save(self, itr):
checkpoint_path = os.path.join(self.configs.save_dir, 'model.ckpt')
self.saver.save(self.sess, checkpoint_path, global_step=itr)
print('saved to ' + self.configs.save_dir)
def load(self, checkpoint_path):
print('load model:', checkpoint_path)
self.saver.restore(self.sess, checkpoint_path)
def construct_model(self, name, images, model_params, real_input_flag, num_layers, num_hidden,
filter_size, stride, total_length, input_length, tln):
'''Builds the network `name` and returns its generated frames and loss.
Args:
name: network name; only 'mim' is registered.
images: input frame sequence.
model_params: dict for extra parameters of some models.
real_input_flag: mask for scheduled sampling.
num_layers: number of stacked recurrent layers.
num_hidden: list with the number of units in each lstm layer.
filter_size: for convolutions inside lstm.
stride: for convolutions inside lstm.
total_length: including ins and outs.
input_length: for inputs.
tln: whether to apply tensor layer normalization.
Returns:
gen_images: a seq of frames.
loss: l2 loss between generated and ground-truth frames.
Raises:
ValueError: If network `name` is not recognized.
'''
networks_map = {
'mim': mim.mim,
}
params = dict(mask=real_input_flag, num_layers=num_layers, num_hidden=num_hidden, filter_size=filter_size,
stride=stride, total_length=total_length, input_length=input_length, is_training=True)
params.update(model_params)
if name in networks_map:
func = networks_map[name]
return func(images, params, real_input_flag, num_layers, num_hidden, filter_size,
stride, total_length, input_length, tln)
else:
raise ValueError('Name of network unknown %s' % name)
================================================
FILE: src/trainer.py
================================================
import os.path
import datetime
import cv2
import numpy as np
from skimage.measure import compare_ssim
from src.utils import metrics
from src.utils import preprocess
def train(model, ims, real_input_flag, configs, itr, ims_reverse=None):
ims = ims[:, :configs.total_length]
ims_list = np.split(ims, configs.n_gpu)
cost = model.train(ims_list, configs.lr, real_input_flag)
flag = 1
if configs.reverse_img:
ims_rev = np.split(ims_reverse, configs.n_gpu)
cost += model.train(ims_rev, configs.lr, real_input_flag)
flag += 1
if configs.reverse_input:
ims_rev = np.split(ims[:, ::-1], configs.n_gpu)
cost += model.train(ims_rev, configs.lr, real_input_flag)
flag += 1
if configs.reverse_img:
ims_rev = np.split(ims_reverse[:, ::-1], configs.n_gpu)
cost += model.train(ims_rev, configs.lr, real_input_flag)
flag += 1
cost = cost / flag
if itr % configs.display_interval == 0:
print(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'itr: ' + str(itr))
print('training loss: ' + str(cost))
def test(model, test_input_handle, configs, save_name):
print(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'test...')
test_input_handle.begin(do_shuffle=False)
res_path = os.path.join(configs.gen_frm_dir, str(save_name))
os.mkdir(res_path)
avg_mse = 0
batch_id = 0
img_mse, ssim, psnr, fmae, sharp = [], [], [], [], []
for i in range(configs.total_length - configs.input_length):
img_mse.append(0)
ssim.append(0)
psnr.append(0)
fmae.append(0)
sharp.append(0)
if configs.img_height > 0:
height = configs.img_height
else:
height = configs.img_width
real_input_flag = np.zeros(
(configs.batch_size,
configs.total_length - configs.input_length - 1,
configs.img_width // configs.patch_size,
height // configs.patch_size,
configs.patch_size ** 2 * configs.img_channel))
while not test_input_handle.no_batch_left():
batch_id = batch_id + 1
if save_name != 'test_result':
if batch_id > 100: break
test_ims = test_input_handle.get_batch()
test_ims = test_ims[:, :configs.total_length]
if len(test_ims.shape) > 3:
test_dat = preprocess.reshape_patch(test_ims, configs.patch_size)
else:
test_dat = test_ims
test_dat = np.split(test_dat, configs.n_gpu)
img_gen, debug = model.test(test_dat, real_input_flag)
# concat outputs of different gpus along batch
img_gen = np.concatenate(img_gen)
if len(img_gen.shape) > 3:
img_gen = preprocess.reshape_patch_back(img_gen, configs.patch_size)
# MSE per frame
for i in range(configs.total_length - configs.input_length):
x = test_ims[:, i + configs.input_length, :, :, :]
x = x[:configs.batch_size * configs.n_gpu]
x = x - np.where(x > 10000, np.floor_divide(x, 10000) * 10000, np.zeros_like(x))
gx = img_gen[:, i, :, :, :]
fmae[i] += metrics.batch_mae_frame_float(gx, x)
gx = np.maximum(gx, 0)
gx = np.minimum(gx, 1)
mse = np.square(x - gx).sum()
img_mse[i] += mse
avg_mse += mse
real_frm = np.uint8(x * 255)
pred_frm = np.uint8(gx * 255)
psnr[i] += metrics.batch_psnr(pred_frm, real_frm)
for b in range(configs.batch_size):
sharp[i] += np.max(
cv2.convertScaleAbs(cv2.Laplacian(pred_frm[b], 3)))
score, _ = compare_ssim(gx[b], x[b], full=True, multichannel=True)
ssim[i] += score
# save prediction examples
if batch_id <= configs.num_save_samples:
path = os.path.join(res_path, str(batch_id))
os.mkdir(path)
if len(debug) != 0:
np.save(os.path.join(path, "f.npy"), debug)
for i in range(configs.total_length):
name = 'gt' + str(i + 1) + '.png'
file_name = os.path.join(path, name)
img_gt = np.uint8(test_ims[0, i, :, :, :] * 255)
if configs.img_channel == 2:
img_gt = img_gt[:, :, :1]
cv2.imwrite(file_name, img_gt)
for i in range(configs.total_length - configs.input_length):
name = 'pd' + str(i + 1 + configs.input_length) + '.png'
file_name = os.path.join(path, name)
img_pd = img_gen[0, i, :, :, :]
if configs.img_channel == 2:
img_pd = img_pd[:, :, :1]
img_pd = np.maximum(img_pd, 0)
img_pd = np.minimum(img_pd, 1)
img_pd = np.uint8(img_pd * 255)
cv2.imwrite(file_name, img_pd)
test_input_handle.next()
avg_mse = avg_mse / (batch_id * configs.batch_size * configs.n_gpu)
print('mse per seq: ' + str(avg_mse))
for i in range(configs.total_length - configs.input_length):
print(img_mse[i] / (batch_id * configs.batch_size * configs.n_gpu))
psnr = np.asarray(psnr, dtype=np.float32) / batch_id
fmae = np.asarray(fmae, dtype=np.float32) / batch_id
ssim = np.asarray(ssim, dtype=np.float32) / (configs.batch_size * batch_id)
sharp = np.asarray(sharp, dtype=np.float32) / (configs.batch_size * batch_id)
print('psnr per frame: ' + str(np.mean(psnr)))
for i in range(configs.total_length - configs.input_length):
print(psnr[i])
print('fmae per frame: ' + str(np.mean(fmae)))
for i in range(configs.total_length - configs.input_length):
print(fmae[i])
print('ssim per frame: ' + str(np.mean(ssim)))
for i in range(configs.total_length - configs.input_length):
print(ssim[i])
print('sharpness per frame: ' + str(np.mean(sharp)))
for i in range(configs.total_length - configs.input_length):
print(sharp[i])
================================================
FILE: src/utils/__init__.py
================================================
================================================
FILE: src/utils/metrics.py
================================================
__author__ = 'yunbo'
import numpy as np
from scipy.signal import convolve2d
def batch_mae_frame_float(gen_frames, gt_frames):
# [batch, width, height] or [batch, width, height, channel]
if gen_frames.ndim == 3:
axis = (1, 2)
elif gen_frames.ndim == 4:
axis = (1, 2, 3)
else:
raise ValueError('gen_frames must be rank 3 or 4.')
x = np.float32(gen_frames)
y = np.float32(gt_frames)
mae = np.sum(np.absolute(x - y), axis=axis, dtype=np.float32)
return np.mean(mae)
def batch_psnr(gen_frames, gt_frames):
# [batch, width, height] or [batch, width, height, channel]
if gen_frames.ndim == 3:
axis = (1, 2)
elif gen_frames.ndim == 4:
axis = (1, 2, 3)
else:
raise ValueError('gen_frames must be rank 3 or 4.')
x = np.int32(gen_frames)
y = np.int32(gt_frames)
num_pixels = float(np.size(gen_frames[0]))
mse = np.sum((x - y) ** 2, axis=axis, dtype=np.float32) / num_pixels
psnr = 20 * np.log10(255) - 10 * np.log10(mse)
return np.mean(psnr)
================================================
FILE: src/utils/optimizer.py
================================================
import tensorflow as tf
def adam_updates(params, cost_or_grads, lr=0.001, mom1=0.9, mom2=0.999):
updates = []
if type(cost_or_grads) is not list:
grads = tf.gradients(cost_or_grads, params)
else:
grads = cost_or_grads
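# NOTE: the second positional argument of tf.Variable is `trainable`, not `name`,
# so the strings passed here and below act as truthy `trainable` flags and do not name the variables.
t = tf.Variable(1., 'adam_t')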
for p, g in zip(params, grads):
mg = tf.Variable(tf.zeros(p.get_shape()), p.name + '_adam_mg')
if mom1 > 0:
v = tf.Variable(tf.zeros(p.get_shape()), p.name + '_adam_v')
v_t = mom1 * v + (1. - mom1) * g
v_hat = v_t / (1. - tf.pow(mom1, t))
updates.append(v.assign(v_t))
else:
v_hat = g
mg_t = mom2 * mg + (1. - mom2) * tf.square(g)
mg_hat = mg_t / (1. - tf.pow(mom2, t))
g_t = v_hat / tf.sqrt(mg_hat + 1e-8)
p_t = p - lr * g_t
updates.append(mg.assign(mg_t))
updates.append(p.assign(p_t))
updates.append(t.assign_add(1))
return tf.group(*updates)
================================================
FILE: src/utils/preprocess.py
================================================
__author__ = 'yunbo'
import numpy as np
def reshape_patch(img_tensor, patch_size):
assert 5 == img_tensor.ndim
batch_size = np.shape(img_tensor)[0]
seq_length = np.shape(img_tensor)[1]
img_height = np.shape(img_tensor)[2]
img_width = np.shape(img_tensor)[3]
num_channels = np.shape(img_tensor)[4]
a = np.reshape(img_tensor, [batch_size, seq_length,
img_height//patch_size, patch_size,
img_width//patch_size, patch_size,
num_channels])
b = np.transpose(a, [0,1,2,4,3,5,6])
patch_tensor = np.reshape(b, [batch_size, seq_length,
img_height//patch_size,
img_width//patch_size,
patch_size*patch_size*num_channels])
return patch_tensor
def reshape_patch_back(patch_tensor, patch_size):
assert 5 == patch_tensor.ndim
batch_size = np.shape(patch_tensor)[0]
seq_length = np.shape(patch_tensor)[1]
patch_height = np.shape(patch_tensor)[2]
patch_width = np.shape(patch_tensor)[3]
channels = np.shape(patch_tensor)[4]
img_channels = channels // (patch_size*patch_size)
a = np.reshape(patch_tensor, [batch_size, seq_length,
patch_height, patch_width,
patch_size, patch_size,
img_channels])
b = np.transpose(a, [0,1,2,4,3,5,6])
img_tensor = np.reshape(b, [batch_size, seq_length,
patch_height * patch_size,
patch_width * patch_size,
img_channels])
return img_tensor
SYMBOL INDEX (91 symbols across 15 files)
FILE: run.py
function main (line 92) | def main(argv=None):
function schedule_sampling (line 115) | def schedule_sampling(eta, itr):
function train_wrapper (line 158) | def train_wrapper(model):
function test_wrapper (line 191) | def test_wrapper(model):
FILE: src/data_provider/datasets_factory.py
function data_provider (line 10) | def data_provider(dataset_name, train_data_paths, valid_data_paths, batc...
FILE: src/data_provider/human.py
class InputHandle (line 12) | class InputHandle:
method __init__ (line 13) | def __init__(self, datas, indices, input_param):
method total (line 26) | def total(self):
method begin (line 29) | def begin(self, do_shuffle=True):
method next (line 36) | def next(self):
method no_batch_left (line 42) | def no_batch_left(self):
method get_batch (line 48) | def get_batch(self):
method print_stat (line 68) | def print_stat(self):
class DataProcess (line 76) | class DataProcess:
method __init__ (line 77) | def __init__(self, input_param):
method load_data (line 83) | def load_data(self, paths, mode='train'):
method get_train_input_handle (line 150) | def get_train_input_handle(self):
method get_test_input_handle (line 154) | def get_test_input_handle(self):
function main (line 159) | def main():
FILE: src/data_provider/mnist.py
class InputHandle (line 4) | class InputHandle:
method __init__ (line 5) | def __init__(self, input_param):
method load (line 22) | def load(self):
method total (line 40) | def total(self):
method begin (line 43) | def begin(self, do_shuffle = True):
method next (line 59) | def next(self):
method no_batch_left (line 74) | def no_batch_left(self):
method input_batch (line 80) | def input_batch(self):
method output_batch (line 98) | def output_batch(self):
method get_batch (line 131) | def get_batch(self):
FILE: src/data_provider/taxibj.py
function string2timestamp (line 17) | def string2timestamp(strings, T=48):
class STMatrix (line 30) | class STMatrix(object):
method __init__ (line 33) | def __init__(self, data, timestamps, T=48, CheckComplete=True):
method make_index (line 45) | def make_index(self):
method check_complete (line 50) | def check_complete(self):
method get_matrix (line 63) | def get_matrix(self, timestamp):
method save (line 66) | def save(self, fname):
method check_it (line 69) | def check_it(self, depends):
method create_dataset (line 75) | def create_dataset(self, len_closeness=20):
function load_stdata (line 104) | def load_stdata(fname):
function stat (line 112) | def stat(fname):
class MinMaxNormalization (line 140) | class MinMaxNormalization(object):
method __init__ (line 146) | def __init__(self):
method fit (line 149) | def fit(self, X):
method transform (line 154) | def transform(self, X):
method fit_transform (line 159) | def fit_transform(self, X):
method inverse_transform (line 163) | def inverse_transform(self, X):
function timestamp2vec (line 169) | def timestamp2vec(timestamps):
function remove_incomplete_days (line 185) | def remove_incomplete_days(data, timestamps, T=48):
class InputHandle (line 211) | class InputHandle:
method __init__ (line 212) | def __init__(self, datas, indices, input_param):
method total (line 223) | def total(self):
method begin (line 226) | def begin(self, do_shuffle=True):
method next (line 233) | def next(self):
method no_batch_left (line 239) | def no_batch_left(self):
method get_batch (line 245) | def get_batch(self):
method print_stat (line 254) | def print_stat(self):
class DataProcess (line 263) | class DataProcess:
method __init__ (line 264) | def __init__(self, input_param):
method load_data (line 274) | def load_data(self, datapath, T=48, nb_flow=2, len_closeness=None, len...
method get_train_input_handle (line 327) | def get_train_input_handle(self):
method get_test_input_handle (line 330) | def get_test_input_handle(self):
FILE: src/layers/MIMBlock.py
class MIMBlock (line 6) | class MIMBlock():
method __init__ (line 7) | def __init__(self, layer_name, filter_size, num_hidden_in, num_hidden,
method init_state (line 36) | def init_state(self):
method MIMS (line 40) | def MIMS(self, x, h_t, c_t):
method __call__ (line 91) | def __call__(self, x, diff_h, h, c, m):
FILE: src/layers/MIMN.py
class MIMN (line 4) | class MIMN():
method __init__ (line 5) | def __init__(self, layer_name, filter_size, num_hidden, seq_shape, tln...
method init_state (line 26) | def init_state(self):
method __call__ (line 30) | def __call__(self, x, h_t, c_t):
FILE: src/layers/SpatioTemporalLSTMCellv2.py
class SpatioTemporalLSTMCell (line 6) | class SpatioTemporalLSTMCell():
method __init__ (line 7) | def __init__(self, layer_name, filter_size, num_hidden_in, num_hidden,
method init_state (line 35) | def init_state(self):
method __call__ (line 39) | def __call__(self, x, h, c, m):
FILE: src/layers/TensorLayerNorm.py
function tensor_layer_norm (line 6) | def tensor_layer_norm(x, state_name):
FILE: src/models/mim.py
function w_initializer (line 10) | def w_initializer(dim_in, dim_out):
function mim (line 15) | def mim(images, params, schedual_sampling_bool, num_layers, num_hidden, ...
FILE: src/models/model_factory.py
class Model (line 9) | class Model(object):
method __init__ (line 10) | def __init__(self, configs):
method train (line 98) | def train(self, inputs, lr, real_input_flag):
method test (line 105) | def test(self, inputs, real_input_flag):
method save (line 111) | def save(self, itr):
method load (line 116) | def load(self, checkpoint_path):
method construct_model (line 120) | def construct_model(self, name, images, model_params, real_input_flag,...
FILE: src/trainer.py
function train (line 10) | def train(model, ims, real_input_flag, configs, itr, ims_reverse=None):
function test (line 38) | def test(model, test_input_handle, configs, save_name):
FILE: src/utils/metrics.py
function batch_mae_frame_float (line 7) | def batch_mae_frame_float(gen_frames, gt_frames):
function batch_psnr (line 19) | def batch_psnr(gen_frames, gt_frames):
FILE: src/utils/optimizer.py
function adam_updates (line 4) | def adam_updates(params, cost_or_grads, lr=0.001, mom1=0.9, mom2=0.999):
FILE: src/utils/preprocess.py
function reshape_patch (line 6) | def reshape_patch(img_tensor, patch_size):
function reshape_patch_back (line 25) | def reshape_patch_back(patch_tensor, patch_size):
Condensed preview: 22 files, each showing path, character count, and a content snippet.
[
{
"path": "README.md",
"chars": 2869,
"preview": "# Memory In Memory Networks\n\nMIM is a neural network for video prediction and spatiotemporal modeling. It is based on th"
},
{
"path": "data/human36m.sh",
"chars": 648,
"preview": "# Download H36M images\nmkdir human\ncd human\nwget http://visiondata.cis.upenn.edu/volumetric/h36m/S1.tar\ntar -xf S1.tar\nr"
},
{
"path": "run.py",
"chars": 8301,
"preview": "__author__ = 'yunbo'\n\nimport os\n\nimport tensorflow as tf\nimport numpy as np\nfrom time import time\n\nfrom src.data_provide"
},
{
"path": "src/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "src/data_provider/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "src/data_provider/datasets_factory.py",
"chars": 3849,
"preview": "from src.data_provider import mnist, human, taxibj\n\ndatasets_map = {\n 'mnist': mnist,\n 'taxibj': taxibj,\n 'hum"
},
{
"path": "src/data_provider/human.py",
"chars": 6760,
"preview": "__author__ = 'jianjin'\nimport numpy as np\nimport os\nimport cv2\nfrom PIL import Image\nimport logging\nimport random\nimport"
},
{
"path": "src/data_provider/mnist.py",
"chars": 5967,
"preview": "import numpy as np\nimport random\n\nclass InputHandle:\n def __init__(self, input_param):\n self.paths = input_par"
},
{
"path": "src/data_provider/taxibj.py",
"chars": 11418,
"preview": "__author__ = 'jianjin'\n\nimport random\nimport os.path\nimport logging\nimport os\nfrom copy import copy\nimport numpy as np\ni"
},
{
"path": "src/layers/MIMBlock.py",
"chars": 5901,
"preview": "import tensorflow as tf\nfrom src.layers.TensorLayerNorm import tensor_layer_norm\nimport math\n\n\nclass MIMBlock():\n def"
},
{
"path": "src/layers/MIMN.py",
"chars": 3172,
"preview": "import tensorflow as tf\nfrom src.layers.TensorLayerNorm import tensor_layer_norm\n\nclass MIMN():\n def __init__(self, l"
},
{
"path": "src/layers/SpatioTemporalLSTMCellv2.py",
"chars": 3803,
"preview": "import math\n\nimport tensorflow as tf\nfrom src.layers.TensorLayerNorm import tensor_layer_norm\n\nclass SpatioTemporalLSTMC"
},
{
"path": "src/layers/TensorLayerNorm.py",
"chars": 719,
"preview": "import tensorflow as tf\n\nEPSILON = 0.00001\n\n\ndef tensor_layer_norm(x, state_name):\n x_shape = x.get_shape()\n dims "
},
{
"path": "src/layers/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "src/models/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "src/models/mim.py",
"chars": 4256,
"preview": "__author__ = 'jianjin'\n\nimport tensorflow as tf\nfrom src.layers.SpatioTemporalLSTMCellv2 import SpatioTemporalLSTMCell a"
},
{
"path": "src/models/model_factory.py",
"chars": 6803,
"preview": "import os\n\nimport tensorflow as tf\n\nfrom src.utils import optimizer\nfrom src.models import mim\n\n\nclass Model(object):\n "
},
{
"path": "src/trainer.py",
"chars": 6072,
"preview": "import os.path\nimport datetime\nimport cv2\nimport numpy as np\nfrom skimage.measure import compare_ssim\nfrom src.utils imp"
},
{
"path": "src/utils/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "src/utils/metrics.py",
"chars": 915,
"preview": "__author__ = 'yunbo'\n\nimport numpy as np\nfrom scipy.signal import convolve2d\n\n\ndef batch_mae_frame_float(gen_frames, gt_"
},
{
"path": "src/utils/optimizer.py",
"chars": 973,
"preview": "import tensorflow as tf\n\n\ndef adam_updates(params, cost_or_grads, lr=0.001, mom1=0.9, mom2=0.999):\n updates = []\n "
},
{
"path": "src/utils/preprocess.py",
"chars": 1738,
"preview": "__author__ = 'yunbo'\n\nimport numpy as np\n\n\ndef reshape_patch(img_tensor, patch_size):\n assert 5 == img_tensor.ndim\n "
}
]
About this extraction
This page contains the full source code of the Yunbo426/MIM GitHub repository, extracted and formatted as plain text. The extraction includes 22 files (72.4 KB), approximately 18.1k tokens, and a symbol index with 91 extracted functions, classes, methods, constants, and types.
Extracted by GitExtract.