Repository: Yunbo426/MIM Branch: master Commit: a61eb6a1be20 Files: 22 Total size: 72.4 KB Directory structure: gitextract_lnbn37h0/ ├── README.md ├── data/ │ └── human36m.sh ├── run.py └── src/ ├── __init__.py ├── data_provider/ │ ├── __init__.py │ ├── datasets_factory.py │ ├── human.py │ ├── mnist.py │ └── taxibj.py ├── layers/ │ ├── MIMBlock.py │ ├── MIMN.py │ ├── SpatioTemporalLSTMCellv2.py │ ├── TensorLayerNorm.py │ └── __init__.py ├── models/ │ ├── __init__.py │ ├── mim.py │ └── model_factory.py ├── trainer.py └── utils/ ├── __init__.py ├── metrics.py ├── optimizer.py └── preprocess.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: README.md ================================================ # Memory In Memory Networks MIM is a neural network for video prediction and spatiotemporal modeling. It is based on the paper [Memory In Memory: A Predictive Neural Network for Learning Higher-Order Non-Stationarity from Spatiotemporal Dynamics](https://arxiv.org/pdf/1811.07490.pdf) to be presented at CVPR 2019. ## Abstract Natural spatiotemporal processes can be highly non-stationary in many ways, e.g. the low-level non-stationarity such as spatial correlations or temporal dependencies of local pixel values; and the high-level non-stationarity such as the accumulation, deformation or dissipation of radar echoes in precipitation forecasting. We try to stationalize and approximate the non-stationary processes by modeling the differential signals with the MIM recurrent blocks. By stacking multiple MIM blocks, we could potentially handle higher-order non-stationarity. Our model achieves the state-of-the-art results on three spatiotemporal prediction tasks across both synthetic and real-world data. ![model](https://github.com/ZJianjin/mim_images/blob/master/readme_structure.png) ## Pre-trained Models and Datasets All pre-trained MIM models have been uploaded to [DROPBOX](https://www.dropbox.com/s/7kd82ijezk4lkmp/mim-lib.zip?dl=0) and [BAIDU YUN](https://pan.baidu.com/s/1O07H7l1NTWmAkx3UCDVMLA) (password: srhv). It also includes our pre-processed training/testing data for Moving MNIST, Color-Changing Moving MNIST, and TaxiBJ. For Human3.6M, you may download it using data/human36m.sh. ## Generation Results #### Moving MNIST ![mnist1](https://github.com/ZJianjin/mim_images/blob/master/mnist1.gif) ![mnist2](https://github.com/ZJianjin/mim_images/blob/master/mnist4.gif) ![mnist2](https://github.com/ZJianjin/mim_images/blob/master/mnist5.gif) #### Color-Changing Moving MNIST ![mnistc1](https://github.com/ZJianjin/mim_images/blob/master/mnistc2.gif) ![mnistc2](https://github.com/ZJianjin/mim_images/blob/master/mnistc3.gif) ![mnistc2](https://github.com/ZJianjin/mim_images/blob/master/mnistc4.gif) #### Radar Echos ![radar1](https://github.com/ZJianjin/mim_images/blob/master/radar9.gif) ![radar2](https://github.com/ZJianjin/mim_images/blob/master/radar3.gif) ![radar3](https://github.com/ZJianjin/mim_images/blob/master/radar7.gif) #### Human3.6M ![human1](https://github.com/ZJianjin/mim_images/blob/master/human3.gif) ![human2](https://github.com/ZJianjin/mim_images/blob/master/human5.gif) ![human3](https://github.com/ZJianjin/mim_images/blob/master/human10.gif) ## BibTeX ``` @article{wang2018memory, title={Memory In Memory: A Predictive Neural Network for Learning Higher-Order Non-Stationarity from Spatiotemporal Dynamics}, author={Wang, Yunbo and Zhang, Jianjin and Zhu, Hongyu and Long, Mingsheng and Wang, Jianmin and Yu, Philip S}, journal={arXiv preprint arXiv:1811.07490}, year={2019} } ``` ================================================ FILE: data/human36m.sh ================================================ # Download H36M images mkdir human cd human wget http://visiondata.cis.upenn.edu/volumetric/h36m/S1.tar tar -xf S1.tar rm S1.tar wget http://visiondata.cis.upenn.edu/volumetric/h36m/S5.tar tar -xf S5.tar rm S5.tar wget http://visiondata.cis.upenn.edu/volumetric/h36m/S6.tar tar -xf S6.tar rm S6.tar wget http://visiondata.cis.upenn.edu/volumetric/h36m/S7.tar tar -xf S7.tar rm S7.tar wget http://visiondata.cis.upenn.edu/volumetric/h36m/S8.tar tar -xf S8.tar rm S8.tar wget http://visiondata.cis.upenn.edu/volumetric/h36m/S9.tar tar -xf S9.tar rm S9.tar wget http://visiondata.cis.upenn.edu/volumetric/h36m/S11.tar tar -xf S11.tar rm S11.tar cd .. ================================================ FILE: run.py ================================================ __author__ = 'yunbo' import os import tensorflow as tf import numpy as np from time import time from src.data_provider import datasets_factory from src.models.model_factory import Model from src.utils import preprocess import src.trainer as trainer # ----------------------------------------------------------------------------- FLAGS = tf.app.flags.FLAGS # os.environ["CUDA_VISIBLE_DEVICES"] = "2" # mode tf.app.flags.DEFINE_boolean('is_training', True, 'training or testing') # data I/O tf.app.flags.DEFINE_string('dataset_name', 'mnist', 'The name of dataset.') tf.app.flags.DEFINE_string('train_data_paths', 'data/moving-mnist-example/moving-mnist-train.npz', 'train data paths.') tf.app.flags.DEFINE_string('valid_data_paths', 'data/moving-mnist-example/moving-mnist-valid.npz', 'validation data paths.') tf.app.flags.DEFINE_string('save_dir', 'checkpoints/mnist_predrnn_pp', 'dir to store trained net.') tf.app.flags.DEFINE_string('gen_frm_dir', 'results/mnist_predrnn_pp', 'dir to store result.') tf.app.flags.DEFINE_integer('input_length', 10, 'encoder hidden states.') tf.app.flags.DEFINE_integer('total_length', 20, 'total input and output length.') tf.app.flags.DEFINE_integer('img_width', 64, 'input image width.') tf.app.flags.DEFINE_integer('img_channel', 1, 'number of image channel.') # model[convlstm, predcnn, predrnn, predrnn_pp] tf.app.flags.DEFINE_string('model_name', 'convlstm_net', 'The name of the architecture.') tf.app.flags.DEFINE_string('pretrained_model', '', 'file of a pretrained model to initialize from.') tf.app.flags.DEFINE_string('num_hidden', '64,64,64,64', 'COMMA separated number of units in a convlstm layer.') tf.app.flags.DEFINE_integer('filter_size', 5, 'filter of a convlstm layer.') tf.app.flags.DEFINE_integer('stride', 1, 'stride of a convlstm layer.') tf.app.flags.DEFINE_integer('patch_size', 1, 'patch size on one dimension.') tf.app.flags.DEFINE_boolean('layer_norm', True, 'whether to apply tensor layer norm.') # scheduled sampling tf.app.flags.DEFINE_boolean('scheduled_sampling', True, 'for scheduled sampling') tf.app.flags.DEFINE_integer('sampling_stop_iter', 50000, 'for scheduled sampling.') tf.app.flags.DEFINE_float('sampling_start_value', 1.0, 'for scheduled sampling.') tf.app.flags.DEFINE_float('sampling_changing_rate', 0.00002, 'for scheduled sampling.') # optimization tf.app.flags.DEFINE_float('lr', 0.001, 'base learning rate.') tf.app.flags.DEFINE_boolean('reverse_input', True, 'whether to reverse the input frames while training.') tf.app.flags.DEFINE_boolean('reverse_img', False, 'whether to reverse the input images while training.') tf.app.flags.DEFINE_integer('batch_size', 8, 'batch size for training.') tf.app.flags.DEFINE_integer('max_iterations', 80000, 'max num of steps.') tf.app.flags.DEFINE_integer('display_interval', 1, 'number of iters showing training loss.') tf.app.flags.DEFINE_integer('test_interval', 1000, 'number of iters for test.') tf.app.flags.DEFINE_integer('snapshot_interval', 1000, 'number of iters saving models.') tf.app.flags.DEFINE_integer('num_save_samples', 10, 'number of sequences to be saved.') tf.app.flags.DEFINE_integer('n_gpu', 1, 'how many GPUs to distribute the training across.') # gpu tf.app.flags.DEFINE_boolean('allow_gpu_growth', False, 'allow gpu growth') tf.app.flags.DEFINE_integer('img_height', 0, 'input image height.') def main(argv=None): if tf.gfile.Exists(FLAGS.save_dir): tf.gfile.DeleteRecursively(FLAGS.save_dir) tf.gfile.MakeDirs(FLAGS.save_dir) if tf.gfile.Exists(FLAGS.gen_frm_dir): tf.gfile.DeleteRecursively(FLAGS.gen_frm_dir) tf.gfile.MakeDirs(FLAGS.gen_frm_dir) gpu_list = np.asarray(os.environ.get('CUDA_VISIBLE_DEVICES', '-1').split(',') ,dtype=np.int32) FLAGS.n_gpu = len(gpu_list) print('Initializing models') model = Model(FLAGS) if FLAGS.is_training: train_wrapper(model) else: start = time() test_wrapper(model) stop = time() print("Time used: " + str(stop - start) + "s") def schedule_sampling(eta, itr): if FLAGS.img_height > 0: height = FLAGS.img_height else: height = FLAGS.img_width zeros = np.zeros((FLAGS.batch_size, FLAGS.total_length - FLAGS.input_length - 1, FLAGS.img_width // FLAGS.patch_size, height // FLAGS.patch_size, FLAGS.patch_size ** 2 * FLAGS.img_channel)) if not FLAGS.scheduled_sampling: return 0.0, zeros if itr < FLAGS.sampling_stop_iter: eta -= FLAGS.sampling_changing_rate else: eta = 0.0 random_flip = np.random.random_sample( (FLAGS.batch_size, FLAGS.total_length - FLAGS.input_length - 1)) true_token = (random_flip < eta) ones = np.ones((FLAGS.img_width // FLAGS.patch_size, height // FLAGS.patch_size, FLAGS.patch_size ** 2 * FLAGS.img_channel)) zeros = np.zeros((FLAGS.img_width // FLAGS.patch_size, height // FLAGS.patch_size, FLAGS.patch_size ** 2 * FLAGS.img_channel)) real_input_flag = [] for i in range(FLAGS.batch_size): for j in range(FLAGS.total_length - FLAGS.input_length - 1): if true_token[i, j]: real_input_flag.append(ones) else: real_input_flag.append(zeros) real_input_flag = np.array(real_input_flag) real_input_flag = np.reshape(real_input_flag, (FLAGS.batch_size, FLAGS.total_length - FLAGS.input_length - 1, FLAGS.img_width // FLAGS.patch_size, height // FLAGS.patch_size, FLAGS.patch_size ** 2 * FLAGS.img_channel)) return eta, real_input_flag def train_wrapper(model): if FLAGS.pretrained_model: model.load(FLAGS.pretrained_model) # load data train_input_handle, test_input_handle = datasets_factory.data_provider( FLAGS.dataset_name, FLAGS.train_data_paths, FLAGS.valid_data_paths, FLAGS.batch_size * FLAGS.n_gpu, FLAGS.img_width, seq_length=FLAGS.total_length, is_training=True) eta = FLAGS.sampling_start_value for itr in range(1, FLAGS.max_iterations + 1): if train_input_handle.no_batch_left(): train_input_handle.begin(do_shuffle=True) ims = train_input_handle.get_batch() ims_reverse = None if FLAGS.reverse_img: ims_reverse = ims[:, :, :, ::-1] ims_reverse = preprocess.reshape_patch(ims_reverse, FLAGS.patch_size) ims = preprocess.reshape_patch(ims, FLAGS.patch_size) eta, real_input_flag = schedule_sampling(eta, itr) trainer.train(model, ims, real_input_flag, FLAGS, itr, ims_reverse) if itr % FLAGS.snapshot_interval == 0: model.save(itr) if itr % FLAGS.test_interval == 0: trainer.test(model, test_input_handle, FLAGS, itr) train_input_handle.next() def test_wrapper(model): model.load(FLAGS.pretrained_model) test_input_handle = datasets_factory.data_provider( FLAGS.dataset_name, FLAGS.train_data_paths, FLAGS.valid_data_paths, FLAGS.batch_size * FLAGS.n_gpu, FLAGS.img_width, seq_length=FLAGS.total_length, is_training=False) trainer.test(model, test_input_handle, FLAGS, 'test_result') if __name__ == '__main__': tf.app.run() ================================================ FILE: src/__init__.py ================================================ ================================================ FILE: src/data_provider/__init__.py ================================================ ================================================ FILE: src/data_provider/datasets_factory.py ================================================ from src.data_provider import mnist, human, taxibj datasets_map = { 'mnist': mnist, 'taxibj': taxibj, 'human': human } def data_provider(dataset_name, train_data_paths, valid_data_paths, batch_size, img_width, seq_length=20, is_training=True): '''Given a dataset name and returns a Dataset. Args: dataset_name: String, the name of the dataset. train_data_paths: List, [train_data_path1, train_data_path2...] valid_data_paths: List, [val_data_path1, val_data_path2...] batch_size: Int img_width: Int is_training: Bool Returns: if is_training: Two dataset instances for both training and evaluation. else: One dataset instance for evaluation. Raises: ValueError: If `dataset_name` is unknown. ''' if dataset_name not in datasets_map: raise ValueError('Name of dataset unknown %s' % dataset_name) train_data_list = train_data_paths.split(',') valid_data_list = valid_data_paths.split(',') if dataset_name == 'mnist': test_input_param = {'paths': valid_data_list, 'minibatch_size': batch_size, 'input_data_type': 'float32', 'is_output_sequence': True, 'name': dataset_name + 'test iterator'} test_input_handle = datasets_map[dataset_name].InputHandle(test_input_param) test_input_handle.begin(do_shuffle=False) if is_training: train_input_param = {'paths': train_data_list, 'minibatch_size': batch_size, 'input_data_type': 'float32', 'is_output_sequence': True, 'name': dataset_name + ' train iterator'} train_input_handle = datasets_map[dataset_name].InputHandle(train_input_param) train_input_handle.begin(do_shuffle=True) return train_input_handle, test_input_handle else: return test_input_handle if dataset_name == 'human': input_param = {'paths': valid_data_list, 'image_width': img_width, 'minibatch_size': batch_size, 'seq_length': seq_length, 'channel': 3, 'input_data_type': 'float32', 'name': 'human'} input_handle = datasets_map[dataset_name].DataProcess(input_param) test_input_handle = input_handle.get_test_input_handle() test_input_handle.begin(do_shuffle=False) if is_training: train_input_handle = input_handle.get_train_input_handle() train_input_handle.begin(do_shuffle=True) return train_input_handle, test_input_handle else: return test_input_handle if dataset_name == 'taxibj': input_param = {'paths': valid_data_list, 'image_width': img_width, 'minibatch_size': batch_size, 'seq_length': seq_length, 'input_data_type': 'float32', 'name': dataset_name + ' iterator'} input_handle = datasets_map[dataset_name].DataProcess(input_param) if is_training: train_input_handle = input_handle.get_train_input_handle() train_input_handle.begin(do_shuffle=True) test_input_handle = input_handle.get_test_input_handle() test_input_handle.begin(do_shuffle=False) return train_input_handle, test_input_handle else: test_input_handle = input_handle.get_test_input_handle() test_input_handle.begin(do_shuffle=False) return test_input_handle ================================================ FILE: src/data_provider/human.py ================================================ __author__ = 'jianjin' import numpy as np import os import cv2 from PIL import Image import logging import random import tensorflow as tf logger = logging.getLogger(__name__) class InputHandle: def __init__(self, datas, indices, input_param): self.name = input_param['name'] self.input_data_type = input_param.get('input_data_type', 'float32') self.minibatch_size = input_param['minibatch_size'] self.image_width = input_param['image_width'] self.channel = input_param['channel'] self.datas = datas self.indices = indices self.current_position = 0 self.current_batch_indices = [] self.current_input_length = input_param['seq_length'] self.interval = 2 def total(self): return len(self.indices) def begin(self, do_shuffle=True): logger.info("Initialization for read data ") if do_shuffle: random.shuffle(self.indices) self.current_position = 0 self.current_batch_indices = self.indices[self.current_position:self.current_position + self.minibatch_size] def next(self): self.current_position += self.minibatch_size if self.no_batch_left(): return None self.current_batch_indices = self.indices[self.current_position:self.current_position + self.minibatch_size] def no_batch_left(self): if self.current_position + self.minibatch_size > self.total(): return True else: return False def get_batch(self): if self.no_batch_left(): logger.error( "There is no batch left in " + self.name + ". Consider to user iterators.begin() to rescan from the beginning of the iterators") return None input_batch = np.zeros( (self.minibatch_size, self.current_input_length, self.image_width, self.image_width, self.channel)).astype( self.input_data_type) for i in range(self.minibatch_size): batch_ind = self.current_batch_indices[i] begin = batch_ind end = begin + self.current_input_length * self.interval data_slice = self.datas[begin:end:self.interval] input_batch[i, :self.current_input_length, :, :, :] = data_slice # logger.info('data_slice shape') # logger.info(data_slice.shape) # logger.info(input_batch.shape) input_batch = input_batch.astype(self.input_data_type) return input_batch def print_stat(self): logger.info("Iterator Name: " + self.name) logger.info(" current_position: " + str(self.current_position)) logger.info(" Minibatch Size: " + str(self.minibatch_size)) logger.info(" total Size: " + str(self.total())) logger.info(" current_input_length: " + str(self.current_input_length)) logger.info(" Input Data Type: " + str(self.input_data_type)) class DataProcess: def __init__(self, input_param): self.input_param = input_param self.paths = input_param['paths'] self.image_width = input_param['image_width'] self.seq_len = input_param['seq_length'] def load_data(self, paths, mode='train'): data_dir = paths[0] intervel = 2 frames_np = [] scenarios = ['Walking'] if mode == 'train': subjects = ['S1', 'S5', 'S6', 'S7', 'S8'] elif mode == 'test': subjects = ['S9', 'S11'] else: print ("MODE ERROR") _path = data_dir print ('load data...', _path) filenames = os.listdir(_path) filenames.sort() print ('data size ', len(filenames)) frames_file_name = [] for filename in filenames: fix = filename.split('.') fix = fix[0] subject = fix.split('_') scenario = subject[1] subject = subject[0] if subject not in subjects or scenario not in scenarios: continue file_path = os.path.join(_path, filename) image = cv2.cvtColor(cv2.imread(file_path), cv2.COLOR_BGR2RGB) #[1000,1000,3] image = image[image.shape[0]//4:-image.shape[0]//4, image.shape[1]//4:-image.shape[1]//4, :] if self.image_width != image.shape[0]: image = cv2.resize(image, (self.image_width, self.image_width)) #image = cv2.resize(image[100:-100,100:-100,:], (self.image_width, self.image_width), # interpolation=cv2.INTER_LINEAR) frames_np.append(np.array(image, dtype=np.float32) / 255.0) frames_file_name.append(filename) # if len(frames_np) % 100 == 0: print len(frames_np) #if len(frames_np) % 1000 == 0: break # is it a begin index of sequence indices = [] index = 0 print ('gen index') while index + intervel * self.seq_len - 1 < len(frames_file_name): # 'S11_Discussion_1.54138969_000471.jpg' # ['S11_Discussion_1', '54138969_000471', 'jpg'] start_infos = frames_file_name[index].split('.') end_infos = frames_file_name[index+intervel*(self.seq_len-1)].split('.') if start_infos[0] != end_infos[0]: index += 1 continue start_video_id, start_frame_id = start_infos[1].split('_') end_video_id, end_frame_id = end_infos[1].split('_') if start_video_id != end_video_id: index += 1 continue if int(end_frame_id) - int(start_frame_id) == 5 * (self.seq_len - 1) * intervel: indices.append(index) if mode == 'train': index += 10 elif mode == 'test': index += 5 print("there are " + str(len(indices)) + " sequences") # data = np.asarray(frames_np) data = frames_np print("there are " + str(len(data)) + " pictures") return data, indices def get_train_input_handle(self): train_data, train_indices = self.load_data(self.paths, mode='train') return InputHandle(train_data, train_indices, self.input_param) def get_test_input_handle(self): test_data, test_indices = self.load_data(self.paths, mode='test') return InputHandle(test_data, test_indices, self.input_param) def main(): parser = argparse.ArgumentParser() parser.add_argument("input_dir", type=str) parser.add_argument("output_dir", type=str) args = parser.parse_args() partition_names = ['train', 'test'] partition_fnames = partition_data(args.input_dir) if __name__ == '__main__': main() ================================================ FILE: src/data_provider/mnist.py ================================================ import numpy as np import random class InputHandle: def __init__(self, input_param): self.paths = input_param['paths'] self.num_paths = len(input_param['paths']) self.name = input_param['name'] self.input_data_type = input_param.get('input_data_type', 'float32') self.output_data_type = input_param.get('output_data_type', 'float32') self.minibatch_size = input_param['minibatch_size'] self.is_output_sequence = input_param['is_output_sequence'] self.data = {} self.indices = {} self.current_position = 0 self.current_batch_size = 0 self.current_batch_indices = [] self.current_input_length = 0 self.current_output_length = 0 self.load() def load(self): dat_1 = np.load(self.paths[0]) for key in dat_1.keys(): self.data[key] = dat_1[key] if self.num_paths == 2: dat_2 = np.load(self.paths[1]) num_clips_1 = dat_1['clips'].shape[1] dat_2['clips'][:,:,0] += num_clips_1 self.data['clips'] = np.concatenate( (dat_1['clips'], dat_2['clips']), axis=1) self.data['input_raw_data'] = np.concatenate( (dat_1['input_raw_data'], dat_2['input_raw_data']), axis=0) self.data['output_raw_data'] = np.concatenate( (dat_1['output_raw_data'], dat_2['output_raw_data']), axis=0) for key in self.data.keys(): print(key) print(self.data[key].shape) def total(self): return self.data['clips'].shape[1] def begin(self, do_shuffle = True): self.indices = np.arange(self.total(),dtype="int32") if do_shuffle: random.shuffle(self.indices) self.current_position = 0 if self.current_position + self.minibatch_size <= self.total(): self.current_batch_size = self.minibatch_size else: self.current_batch_size = self.total() - self.current_position self.current_batch_indices = self.indices[ self.current_position:self.current_position + self.current_batch_size] self.current_input_length = max(self.data['clips'][0, ind, 1] for ind in self.current_batch_indices) self.current_output_length = max(self.data['clips'][1, ind, 1] for ind in self.current_batch_indices) def next(self): self.current_position += self.current_batch_size if self.no_batch_left(): return None if self.current_position + self.minibatch_size <= self.total(): self.current_batch_size = self.minibatch_size else: self.current_batch_size = self.total() - self.current_position self.current_batch_indices = self.indices[ self.current_position:self.current_position + self.current_batch_size] self.current_input_length = max(self.data['clips'][0, ind, 1] for ind in self.current_batch_indices) self.current_output_length = max(self.data['clips'][1, ind, 1] for ind in self.current_batch_indices) def no_batch_left(self): if self.current_position >= self.total() - self.current_batch_size: return True else: return False def input_batch(self): if self.no_batch_left(): return None input_batch = np.zeros( (self.current_batch_size, self.current_input_length) + tuple(self.data['dims'][0])).astype(self.input_data_type) input_batch = np.transpose(input_batch,(0,1,3,4,2)) for i in range(self.current_batch_size): batch_ind = self.current_batch_indices[i] begin = self.data['clips'][0, batch_ind, 0] end = self.data['clips'][0, batch_ind, 0] + \ self.data['clips'][0, batch_ind, 1] data_slice = self.data['input_raw_data'][begin:end, :, :, :] data_slice = np.transpose(data_slice,(0,2,3,1)) input_batch[i, :self.current_input_length, :, :, :] = data_slice input_batch = input_batch.astype(self.input_data_type) return input_batch def output_batch(self): if self.no_batch_left(): return None if(2 ,3) == self.data['dims'].shape: raw_dat = self.data['output_raw_data'] else: raw_dat = self.data['input_raw_data'] if self.is_output_sequence: if (1, 3) == self.data['dims'].shape: output_dim = self.data['dims'][0] else: output_dim = self.data['dims'][1] output_batch = np.zeros( (self.current_batch_size,self.current_output_length) + tuple(output_dim)) else: output_batch = np.zeros((self.current_batch_size, ) + tuple(self.data['dims'][1])) for i in range(self.current_batch_size): batch_ind = self.current_batch_indices[i] begin = self.data['clips'][1, batch_ind, 0] end = self.data['clips'][1, batch_ind, 0] + \ self.data['clips'][1, batch_ind, 1] if self.is_output_sequence: data_slice = raw_dat[begin:end, :, :, :] output_batch[i, : data_slice.shape[0], :, :, :] = data_slice else: data_slice = raw_dat[begin, :, :, :] output_batch[i,:, :, :] = data_slice output_batch = output_batch.astype(self.output_data_type) output_batch = np.transpose(output_batch, [0,1,3,4,2]) return output_batch def get_batch(self): input_seq = self.input_batch() output_seq = self.output_batch() batch = np.concatenate((input_seq, output_seq), axis=1) return batch ================================================ FILE: src/data_provider/taxibj.py ================================================ __author__ = 'jianjin' import random import os.path import logging import os from copy import copy import numpy as np import h5py import pandas as pd from datetime import datetime import time logger = logging.getLogger(__name__) def string2timestamp(strings, T=48): timestamps = [] time_per_slot = 24.0 / T num_per_T = T // 24 for t in strings: year, month, day, slot = int(t[:4]), int(t[4:6]), int(t[6:8]), int(t[8:])-1 timestamps.append(pd.Timestamp(datetime(year, month, day, hour=int(slot * time_per_slot), minute=(slot % num_per_T) * int(60.0 * time_per_slot)))) return timestamps class STMatrix(object): """docstring for STMatrix""" def __init__(self, data, timestamps, T=48, CheckComplete=True): super(STMatrix, self).__init__() assert len(data) == len(timestamps) self.data = data self.timestamps = timestamps self.T = T self.pd_timestamps = string2timestamp(timestamps, T=self.T) if CheckComplete: self.check_complete() # index self.make_index() def make_index(self): self.get_index = dict() for i, ts in enumerate(self.pd_timestamps): self.get_index[ts] = i def check_complete(self): missing_timestamps = [] offset = pd.DateOffset(minutes=24 * 60 // self.T) pd_timestamps = self.pd_timestamps i = 1 while i < len(pd_timestamps): if pd_timestamps[i-1] + offset != pd_timestamps[i]: missing_timestamps.append("(%s -- %s)" % (pd_timestamps[i-1], pd_timestamps[i])) i += 1 for v in missing_timestamps: print(v) assert len(missing_timestamps) == 0 def get_matrix(self, timestamp): return self.data[self.get_index[timestamp]] def save(self, fname): pass def check_it(self, depends): for d in depends: if d not in self.get_index.keys(): return False return True def create_dataset(self, len_closeness=20): """current version """ # offset_week = pd.DateOffset(days=7) offset_frame = pd.DateOffset(minutes=24 * 60 // self.T) XC = [] timestamps_Y = [] depends = [range(1, len_closeness+1)] i = len_closeness while i < len(self.pd_timestamps): Flag = True for depend in depends: if Flag is False: break Flag = self.check_it([self.pd_timestamps[i] - j * offset_frame for j in depend]) if Flag is False: i += 1 continue x_c = [np.transpose(self.get_matrix(self.pd_timestamps[i] - j * offset_frame), [1, 2, 0]) for j in depends[0]] if len_closeness > 0: XC.append(np.stack(x_c, axis=0)) timestamps_Y.append(self.timestamps[i]) i += 1 XC = np.stack(XC, axis=0) return XC, timestamps_Y def load_stdata(fname): f = h5py.File(fname, 'r') data = f['data'].value timestamps = f['date'].value f.close() return data, timestamps def stat(fname): def get_nb_timeslot(f): s = f['date'][0] e = f['date'][-1] year, month, day = map(int, [s[:4], s[4:6], s[6:8]]) ts = time.strptime("%04i-%02i-%02i" % (year, month, day), "%Y-%m-%d") year, month, day = map(int, [e[:4], e[4:6], e[6:8]]) te = time.strptime("%04i-%02i-%02i" % (year, month, day), "%Y-%m-%d") nb_timeslot = (time.mktime(te) - time.mktime(ts)) / (0.5 * 3600) + 48 ts_str, te_str = time.strftime("%Y-%m-%d", ts), time.strftime("%Y-%m-%d", te) return nb_timeslot, ts_str, te_str with h5py.File(fname, 'r') as f: nb_timeslot, ts_str, te_str = get_nb_timeslot(f) nb_day = int(nb_timeslot / 48) mmax = f['data'].value.max() mmin = f['data'].value.min() stat = '=' * 5 + 'stat' + '=' * 5 + '\n' + \ 'data shape: %s\n' % str(f['data'].shape) + \ '# of days: %i, from %s to %s\n' % (nb_day, ts_str, te_str) + \ '# of timeslots: %i\n' % int(nb_timeslot) + \ '# of timeslots (available): %i\n' % f['date'].shape[0] + \ 'missing ratio of timeslots: %.1f%%\n' % ((1. - float(f['date'].shape[0] / nb_timeslot)) * 100) + \ 'max: %.3f, min: %.3f\n' % (mmax, mmin) + \ '=' * 5 + 'stat' + '=' * 5 print(stat) class MinMaxNormalization(object): '''MinMax Normalization --> [-1, 1] x = (x - min) / (max - min). x = x * 2 - 1 ''' def __init__(self): pass def fit(self, X): self._min = X.min() self._max = X.max() print("min:", self._min, "max:", self._max) def transform(self, X): X = 1. * (X - self._min) / (self._max - self._min) # X = X * 2. - 1. return X def fit_transform(self, X): self.fit(X) return self.transform(X) def inverse_transform(self, X): X = (X + 1.) / 2. X = 1. * X * (self._max - self._min) + self._min return X def timestamp2vec(timestamps): # tm_wday range [0, 6], Monday is 0 # vec = [time.strptime(str(t[:8], encoding='utf-8'), '%Y%m%d').tm_wday for t in timestamps] # python3 vec = [time.strptime(t[:8], '%Y%m%d').tm_wday for t in timestamps] # python2 ret = [] for i in vec: v = [0 for _ in range(7)] v[i] = 1 if i >= 5: v.append(0) # weekend else: v.append(1) # weekday ret.append(v) return np.asarray(ret) def remove_incomplete_days(data, timestamps, T=48): # remove a certain day which has not 48 timestamps days = [] # available days: some day only contain some seqs days_incomplete = [] i = 0 while i < len(timestamps): if int(timestamps[i][8:]) != 1: i += 1 elif i+T-1 < len(timestamps) and int(timestamps[i+T-1][8:]) == T: days.append(timestamps[i][:8]) i += T else: days_incomplete.append(timestamps[i][:8]) i += 1 print("incomplete days: ", days_incomplete) days = set(days) idx = [] for i, t in enumerate(timestamps): if t[:8] in days: idx.append(i) data = data[idx] timestamps = [timestamps[i] for i in idx] return data, timestamps class InputHandle: def __init__(self, datas, indices, input_param): self.name = input_param['name'] self.input_data_type = input_param.get('input_data_type', 'float32') self.minibatch_size = input_param['minibatch_size'] self.image_width = input_param['image_width'] self.datas = datas self.indices = indices self.current_position = 0 self.current_batch_indices = [] self.current_input_length = input_param['seq_length'] def total(self): return len(self.indices) def begin(self, do_shuffle=True): logger.info("Initialization for read data ") if do_shuffle: random.shuffle(self.indices) self.current_position = 0 self.current_batch_indices = self.indices[self.current_position:self.current_position + self.minibatch_size] def next(self): self.current_position += self.minibatch_size if self.no_batch_left(): return None self.current_batch_indices = self.indices[self.current_position:self.current_position + self.minibatch_size] def no_batch_left(self): if self.current_position + self.minibatch_size >= self.total(): return True else: return False def get_batch(self): if self.no_batch_left(): logger.error( "There is no batch left in " + self.name + ". Consider to user iterators.begin() to rescan from the beginning of the iterators") return None input_batch = self.datas[self.current_batch_indices, :, :, :] input_batch = input_batch.astype(self.input_data_type) return input_batch def print_stat(self): logger.info("Iterator Name: " + self.name) logger.info(" current_position: " + str(self.current_position)) logger.info(" Minibatch Size: " + str(self.minibatch_size)) logger.info(" total Size: " + str(self.total())) logger.info(" current_input_length: " + str(self.current_input_length)) logger.info(" Input Data Type: " + str(self.input_data_type)) class DataProcess: def __init__(self, input_param): self.paths = input_param['paths'] self.image_width = input_param['image_width'] self.input_param = input_param self.seq_len = input_param['seq_length'] self.train_data, self.test_data, _, _, _ = self.load_data(self.paths, len_closeness=input_param['seq_length']) self.train_indices = list(range(self.train_data.shape[0])) self.test_indices = list(range(self.test_data.shape[0])) def load_data(self, datapath, T=48, nb_flow=2, len_closeness=None, len_test=48 * 7 * 4): """ """ assert (len_closeness > 0) # load data # 13 - 16 data_all = [] timestamps_all = list() for year in range(13, 17): fname = os.path.join( datapath[0], 'BJ{}_M32x32_T30_InOut.h5'.format(year)) print("file name: ", fname) stat(fname) data, timestamps = load_stdata(fname) # print(timestamps) # remove a certain day which does not have 48 timestamps data, timestamps = remove_incomplete_days(data, timestamps, T) data = data[:, :nb_flow] data[data < 0] = 0. data_all.append(data) timestamps_all.append(timestamps) print("\n") # minmax_scale data_train = np.vstack(copy(data_all))[:-len_test] print('train_data shape: ', data_train.shape) mmn = MinMaxNormalization() mmn.fit(data_train) data_all_mmn = [mmn.transform(d) for d in data_all] XC = [] timestamps_Y = [] for data, timestamps in zip(data_all_mmn, timestamps_all): # instance-based dataset --> sequences with format as (X, Y) where X is # a sequence of images and Y is an image. st = STMatrix(data, timestamps, T, CheckComplete=False) _XC, _timestamps_Y = st.create_dataset(len_closeness=len_closeness) XC.append(_XC) timestamps_Y += _timestamps_Y XC = np.concatenate(XC, axis=0) print("XC shape: ", XC.shape) XC_train = XC[:-len_test] XC_test = XC[-len_test:] timestamp_train, timestamp_test = timestamps_Y[:-len_test], timestamps_Y[-len_test:] X_train = XC_train X_test = XC_test print('train shape:', XC_train.shape, 'test shape: ', XC_test.shape) return X_train, X_test, mmn, timestamp_train, timestamp_test def get_train_input_handle(self): return InputHandle(self.train_data, self.train_indices, self.input_param) def get_test_input_handle(self): return InputHandle(self.test_data, self.test_indices, self.input_param) ================================================ FILE: src/layers/MIMBlock.py ================================================ import tensorflow as tf from src.layers.TensorLayerNorm import tensor_layer_norm import math class MIMBlock(): def __init__(self, layer_name, filter_size, num_hidden_in, num_hidden, seq_shape, tln=False, initializer=None): """Initialize the basic Conv LSTM cell. Args: layer_name: layer names for different convlstm layers. filter_size: int tuple thats the height and width of the filter. num_hidden: number of units in output tensor. forget_bias: float, The bias added to forget gates (see above). tln: whether to apply tensor layer normalization """ self.layer_name = layer_name self.filter_size = filter_size self.num_hidden_in = num_hidden_in self.num_hidden = num_hidden self.convlstm_c = None self.batch = seq_shape[0] self.height = seq_shape[2] self.width = seq_shape[3] self.layer_norm = tln self._forget_bias = 1.0 def w_initializer(dim_in, dim_out): random_range = math.sqrt(6.0 / (dim_in + dim_out)) return tf.random_uniform_initializer(-random_range, random_range) if initializer is None or initializer == -1: self.initializer = w_initializer else: self.initializer = tf.random_uniform_initializer(-initializer, initializer) def init_state(self): return tf.zeros([self.batch, self.height, self.width, self.num_hidden], dtype=tf.float32) def MIMS(self, x, h_t, c_t): if h_t is None: h_t = self.init_state() if c_t is None: c_t = self.init_state() with tf.variable_scope(self.layer_name): h_concat = tf.layers.conv2d(h_t, self.num_hidden * 4, self.filter_size, 1, padding='same', kernel_initializer=self.initializer(self.num_hidden, self.num_hidden * 4), name='state_to_state') if self.layer_norm: h_concat = tensor_layer_norm(h_concat, 'state_to_state') i_h, g_h, f_h, o_h = tf.split(h_concat, 4, 3) ct_weight = tf.get_variable( 'c_t_weight', [self.height,self.width,self.num_hidden*2]) ct_activation = tf.multiply(tf.tile(c_t, [1,1,1,2]), ct_weight) i_c, f_c = tf.split(ct_activation, 2, 3) i_ = i_h + i_c f_ = f_h + f_c g_ = g_h o_ = o_h if x != None: x_concat = tf.layers.conv2d(x, self.num_hidden * 4, self.filter_size, 1, padding='same', kernel_initializer=self.initializer(self.num_hidden, self.num_hidden * 4), name='input_to_state') if self.layer_norm: x_concat = tensor_layer_norm(x_concat, 'input_to_state') i_x, g_x, f_x, o_x = tf.split(x_concat, 4, 3) i_ += i_x f_ += f_x g_ += g_x o_ += o_x i_ = tf.nn.sigmoid(i_) f_ = tf.nn.sigmoid(f_ + self._forget_bias) c_new = f_ * c_t + i_ * tf.nn.tanh(g_) oc_weight = tf.get_variable( 'oc_weight', [self.height,self.width,self.num_hidden]) o_c = tf.multiply(c_new, oc_weight) h_new = tf.nn.sigmoid(o_ + o_c) * tf.nn.tanh(c_new) return h_new, c_new def __call__(self, x, diff_h, h, c, m): if h is None: h = self.init_state() if c is None: c = self.init_state() if m is None: m = self.init_state() if diff_h is None: diff_h = tf.zeros_like(h) with tf.variable_scope(self.layer_name): t_cc = tf.layers.conv2d( h, self.num_hidden * 3, self.filter_size, 1, padding='same', kernel_initializer=self.initializer(self.num_hidden, self.num_hidden * 3), name='time_state_to_state') s_cc = tf.layers.conv2d( m, self.num_hidden * 4, self.filter_size, 1, padding='same', kernel_initializer=self.initializer(self.num_hidden, self.num_hidden * 4), name='spatio_state_to_state') x_shape_in = x.get_shape().as_list()[-1] x_cc = tf.layers.conv2d( x, self.num_hidden * 4, self.filter_size, 1, padding='same', kernel_initializer=self.initializer(x_shape_in, self.num_hidden * 4), name='input_to_state') if self.layer_norm: t_cc = tensor_layer_norm(t_cc, 'time_state_to_state') s_cc = tensor_layer_norm(s_cc, 'spatio_state_to_state') x_cc = tensor_layer_norm(x_cc, 'input_to_state') i_s, g_s, f_s, o_s = tf.split(s_cc, 4, 3) i_t, g_t, o_t = tf.split(t_cc, 3, 3) i_x, g_x, f_x, o_x = tf.split(x_cc, 4, 3) i = tf.nn.sigmoid(i_x + i_t) i_ = tf.nn.sigmoid(i_x + i_s) g = tf.nn.tanh(g_x + g_t) g_ = tf.nn.tanh(g_x + g_s) f_ = tf.nn.sigmoid(f_x + f_s + self._forget_bias) o = tf.nn.sigmoid(o_x + o_t + o_s) new_m = f_ * m + i_ * g_ c, self.convlstm_c = self.MIMS(diff_h, c, self.convlstm_c) new_c = c + i * g cell = tf.concat([new_c, new_m], 3) cell = tf.layers.conv2d(cell, self.num_hidden, 1, 1, padding='same', name='cell_reduce') new_h = o * tf.nn.tanh(cell) return new_h, new_c, new_m ================================================ FILE: src/layers/MIMN.py ================================================ import tensorflow as tf from src.layers.TensorLayerNorm import tensor_layer_norm class MIMN(): def __init__(self, layer_name, filter_size, num_hidden, seq_shape, tln=True, initializer=0.001): """Initialize the basic Conv LSTM cell. Args: layer_name: layer names for different convlstm layers. filter_size: int tuple thats the height and width of the filter. num_hidden: number of units in output tensor. tln: whether to apply tensor layer normalization. """ self.layer_name = layer_name self.filter_size = filter_size self.num_hidden = num_hidden self.layer_norm = tln self.batch = seq_shape[0] self.height = seq_shape[2] self.width = seq_shape[3] self._forget_bias = 1.0 if initializer == -1: self.initializer = None else: self.initializer = tf.random_uniform_initializer(-initializer,initializer) def init_state(self): shape = [self.batch, self.height, self.width, self.num_hidden] return tf.zeros(shape, dtype=tf.float32) def __call__(self, x, h_t, c_t): if h_t is None: h_t = self.init_state() if c_t is None: c_t = self.init_state() with tf.variable_scope(self.layer_name): h_concat = tf.layers.conv2d(h_t, self.num_hidden * 4, self.filter_size, 1, padding='same', kernel_initializer=self.initializer, name='state_to_state') if self.layer_norm: h_concat = tensor_layer_norm(h_concat, 'state_to_state') i_h, g_h, f_h, o_h = tf.split(h_concat, 4, 3) ct_weight = tf.get_variable( 'c_t_weight', [self.height,self.width,self.num_hidden*2]) ct_activation = tf.multiply(tf.tile(c_t, [1,1,1,2]), ct_weight) i_c, f_c = tf.split(ct_activation, 2, 3) i_ = i_h + i_c f_ = f_h + f_c g_ = g_h o_ = o_h if x != None: x_concat = tf.layers.conv2d(x, self.num_hidden * 4, self.filter_size, 1, padding='same', kernel_initializer=self.initializer, name='input_to_state') if self.layer_norm: x_concat = tensor_layer_norm(x_concat, 'input_to_state') i_x, g_x, f_x, o_x = tf.split(x_concat, 4, 3) i_ += i_x f_ += f_x g_ += g_x o_ += o_x i_ = tf.nn.sigmoid(i_) f_ = tf.nn.sigmoid(f_ + self._forget_bias) c_new = f_ * c_t + i_ * tf.nn.tanh(g_) oc_weight = tf.get_variable( 'oc_weight', [self.height,self.width,self.num_hidden]) o_c = tf.multiply(c_new, oc_weight) h_new = tf.nn.sigmoid(o_ + o_c) * tf.nn.tanh(c_new) return h_new, c_new ================================================ FILE: src/layers/SpatioTemporalLSTMCellv2.py ================================================ import math import tensorflow as tf from src.layers.TensorLayerNorm import tensor_layer_norm class SpatioTemporalLSTMCell(): def __init__(self, layer_name, filter_size, num_hidden_in, num_hidden, seq_shape, tln=False, initializer=None): """Initialize the basic Conv LSTM cell. Args: layer_name: layer names for different convlstm layers. filter_size: int tuple thats the height and width of the filter. num_hidden: number of units in output tensor. forget_bias: float, The bias added to forget gates (see above). tln: whether to apply tensor layer normalization """ self.layer_name = layer_name self.filter_size = filter_size self.num_hidden_in = num_hidden_in self.num_hidden = num_hidden self.batch = seq_shape[0] self.height = seq_shape[2] self.width = seq_shape[3] self.layer_norm = tln self._forget_bias = 1.0 def w_initializer(dim_in, dim_out): random_range = math.sqrt(6.0 / (dim_in + dim_out)) return tf.random_uniform_initializer(-random_range, random_range) if initializer is None or initializer == -1: self.initializer = w_initializer else: self.initializer = tf.random_uniform_initializer(-initializer, initializer) def init_state(self): return tf.zeros([self.batch, self.height, self.width, self.num_hidden], dtype=tf.float32) def __call__(self, x, h, c, m): if h is None: h = self.init_state() if c is None: c = self.init_state() if m is None: m = self.init_state() with tf.variable_scope(self.layer_name): t_cc = tf.layers.conv2d( h, self.num_hidden*4, self.filter_size, 1, padding='same', kernel_initializer=self.initializer(self.num_hidden_in, self.num_hidden*4), name='time_state_to_state') s_cc = tf.layers.conv2d( m, self.num_hidden*4, self.filter_size, 1, padding='same', kernel_initializer=self.initializer(self.num_hidden_in, self.num_hidden*4), name='spatio_state_to_state') x_shape_in = x.get_shape().as_list()[-1] x_cc = tf.layers.conv2d( x, self.num_hidden*4, self.filter_size, 1, padding='same', kernel_initializer=self.initializer(x_shape_in, self.num_hidden*4), name='input_to_state') if self.layer_norm: t_cc = tensor_layer_norm(t_cc, 'time_state_to_state') s_cc = tensor_layer_norm(s_cc, 'spatio_state_to_state') x_cc = tensor_layer_norm(x_cc, 'input_to_state') i_s, g_s, f_s, o_s = tf.split(s_cc, 4, 3) i_t, g_t, f_t, o_t = tf.split(t_cc, 4, 3) i_x, g_x, f_x, o_x = tf.split(x_cc, 4, 3) i = tf.nn.sigmoid(i_x + i_t) i_ = tf.nn.sigmoid(i_x + i_s) g = tf.nn.tanh(g_x + g_t) g_ = tf.nn.tanh(g_x + g_s) f = tf.nn.sigmoid(f_x + f_t + self._forget_bias) f_ = tf.nn.sigmoid(f_x + f_s + self._forget_bias) o = tf.nn.sigmoid(o_x + o_t + o_s) new_m = f_ * m + i_ * g_ new_c = f * c + i * g cell = tf.concat([new_c, new_m],3) cell = tf.layers.conv2d(cell, self.num_hidden, 1, 1, padding='same', kernel_initializer=self.initializer(self.num_hidden*2, self.num_hidden), name='cell_reduce') new_h = o * tf.nn.tanh(cell) return new_h, new_c, new_m ================================================ FILE: src/layers/TensorLayerNorm.py ================================================ import tensorflow as tf EPSILON = 0.00001 def tensor_layer_norm(x, state_name): x_shape = x.get_shape() dims = x_shape.ndims params_shape = x_shape[-1:] if dims == 4: m, v = tf.nn.moments(x, [1,2,3], keep_dims=True) elif dims == 5: m, v = tf.nn.moments(x, [1,2,3,4], keep_dims=True) elif dims == 2: m, v = tf.nn.moments(x, [1], keep_dims=True) else: raise ValueError('input tensor for layer normalization must be rank 4 or 5.') b = tf.get_variable(state_name+'b',initializer=tf.zeros(params_shape)) s = tf.get_variable(state_name+'s',initializer=tf.ones(params_shape)) x_tln = tf.nn.batch_normalization(x, m, v, b, s, EPSILON) return x_tln ================================================ FILE: src/layers/__init__.py ================================================ ================================================ FILE: src/models/__init__.py ================================================ ================================================ FILE: src/models/mim.py ================================================ __author__ = 'jianjin' import tensorflow as tf from src.layers.SpatioTemporalLSTMCellv2 import SpatioTemporalLSTMCell as stlstm from src.layers.MIMBlock import MIMBlock as mimblock from src.layers.MIMN import MIMN as mimn import math def w_initializer(dim_in, dim_out): random_range = math.sqrt(6.0 / (dim_in + dim_out)) return tf.random_uniform_initializer(-random_range, random_range) def mim(images, params, schedual_sampling_bool, num_layers, num_hidden, filter_size, stride=1, total_length=20, input_length=10, tln=True): gen_images = [] stlstm_layer = [] stlstm_layer_diff = [] cell_state = [] hidden_state = [] cell_state_diff = [] hidden_state_diff = [] shape = images.get_shape().as_list() output_channels = shape[-1] for i in range(num_layers): if i == 0: num_hidden_in = num_hidden[num_layers - 1] else: num_hidden_in = num_hidden[i - 1] if i < 1: new_stlstm_layer = stlstm('stlstm_' + str(i + 1), filter_size, num_hidden_in, num_hidden[i], shape, tln=tln) else: new_stlstm_layer = mimblock('stlstm_' + str(i + 1), filter_size, num_hidden_in, num_hidden[i], shape, tln=tln) stlstm_layer.append(new_stlstm_layer) cell_state.append(None) hidden_state.append(None) for i in range(num_layers - 1): new_stlstm_layer = mimn('stlstm_diff' + str(i + 1), filter_size, num_hidden[i + 1], shape, tln=tln) stlstm_layer_diff.append(new_stlstm_layer) cell_state_diff.append(None) hidden_state_diff.append(None) st_memory = None for time_step in range(total_length - 1): reuse = bool(gen_images) with tf.variable_scope('predrnn', reuse=reuse): if time_step < input_length: x_gen = images[:,time_step] else: x_gen = schedual_sampling_bool[:,time_step-input_length]*images[:,time_step] + \ (1-schedual_sampling_bool[:,time_step-input_length])*x_gen preh = hidden_state[0] hidden_state[0], cell_state[0], st_memory = stlstm_layer[0]( x_gen, hidden_state[0], cell_state[0], st_memory) for i in range(1, num_layers): if time_step > 0: if i == 1: hidden_state_diff[i - 1], cell_state_diff[i - 1] = stlstm_layer_diff[i - 1]( hidden_state[i - 1] - preh, hidden_state_diff[i - 1], cell_state_diff[i - 1]) else: hidden_state_diff[i - 1], cell_state_diff[i - 1] = stlstm_layer_diff[i - 1]( hidden_state_diff[i - 2], hidden_state_diff[i - 1], cell_state_diff[i - 1]) else: stlstm_layer_diff[i - 1](tf.zeros_like(hidden_state[i - 1]), None, None) preh = hidden_state[i] hidden_state[i], cell_state[i], st_memory = stlstm_layer[i]( hidden_state[i - 1], hidden_state_diff[i - 1], hidden_state[i], cell_state[i], st_memory) x_gen = tf.layers.conv2d(hidden_state[num_layers - 1], filters=output_channels, kernel_size=1, strides=1, padding='same', kernel_initializer=w_initializer(num_hidden[num_layers - 1], output_channels), name="back_to_pixel") gen_images.append(x_gen) gen_images = tf.stack(gen_images, axis=1) loss = tf.nn.l2_loss(gen_images - images[:, 1:]) return [gen_images, loss] ================================================ FILE: src/models/model_factory.py ================================================ import os import tensorflow as tf from src.utils import optimizer from src.models import mim class Model(object): def __init__(self, configs): self.configs = configs # inputs if configs.img_height > 0: height = configs.img_height else: height = configs.img_width self.x = [tf.placeholder(tf.float32, [self.configs.batch_size, self.configs.total_length, self.configs.img_width // self.configs.patch_size, height // self.configs.patch_size, self.configs.patch_size * self.configs.patch_size * self.configs.img_channel]) for i in range(self.configs.n_gpu)] self.real_input_flag = tf.placeholder(tf.float32, [self.configs.batch_size, self.configs.total_length - self.configs.input_length - 1, self.configs.img_width // self.configs.patch_size, height // self.configs.patch_size, self.configs.patch_size * self.configs.patch_size * self.configs.img_channel]) grads = [] loss_train = [] self.pred_seq = [] self.tf_lr = tf.placeholder(tf.float32, shape=[]) self.params = dict() self.params.update(self.configs.__dict__['__flags']) num_hidden = [int(x) for x in self.configs.num_hidden.split(',')] num_layers = len(num_hidden) for i in range(self.configs.n_gpu): with tf.device('/gpu:%d' % i): with tf.variable_scope(tf.get_variable_scope(), reuse=True if i > 0 else None): # define a model output_list = self.construct_model( self.configs.model_name, self.x[i], self.params, self.real_input_flag, num_layers, num_hidden, self.configs.filter_size, self.configs.stride, self.configs.total_length, self.configs.input_length, self.configs.layer_norm) gen_ims = output_list[0] loss = output_list[1] if len(output_list) > 2: self.debug = output_list[2] else: self.debug = [] pred_ims = gen_ims[:, self.configs.input_length - self.configs.total_length:] loss_train.append(loss / self.configs.batch_size) # gradients all_params = tf.trainable_variables() grads.append(tf.gradients(loss, all_params)) self.pred_seq.append(pred_ims) # add losses and gradients together and get training updates with tf.device('/gpu:0'): for i in range(1, self.configs.n_gpu): loss_train[0] += loss_train[i] for j in range(len(grads[0])): grads[0][j] += grads[i][j] # keep track of moving average ema = tf.train.ExponentialMovingAverage(decay=0.9995) maintain_averages_op = tf.group(ema.apply(all_params)) self.train_op = tf.group(optimizer.adam_updates( all_params, grads[0], lr=self.tf_lr, mom1=0.95, mom2=0.9995), maintain_averages_op) self.loss_train = loss_train[0] / self.configs.n_gpu # session variables = tf.global_variables() self.saver = tf.train.Saver(variables) init = tf.global_variables_initializer() configProt = tf.ConfigProto() configProt.gpu_options.allow_growth = configs.allow_gpu_growth configProt.allow_soft_placement = True self.sess = tf.Session(config=configProt) self.sess.run(init) if self.configs.pretrained_model: self.saver.restore(self.sess, self.configs.pretrained_model) def train(self, inputs, lr, real_input_flag): feed_dict = {self.x[i]: inputs[i] for i in range(self.configs.n_gpu)} feed_dict.update({self.tf_lr: lr}) feed_dict.update({self.real_input_flag: real_input_flag}) loss, _, debug = self.sess.run((self.loss_train, self.train_op, self.debug), feed_dict) return loss def test(self, inputs, real_input_flag): feed_dict = {self.x[i]: inputs[i] for i in range(self.configs.n_gpu)} feed_dict.update({self.real_input_flag: real_input_flag}) gen_ims, debug = self.sess.run((self.pred_seq, self.debug), feed_dict) return gen_ims, debug def save(self, itr): checkpoint_path = os.path.join(self.configs.save_dir, 'model.ckpt') self.saver.save(self.sess, checkpoint_path, global_step=itr) print('saved to ' + self.configs.save_dir) def load(self, checkpoint_path): print('load model:', checkpoint_path) self.saver.restore(self.sess, checkpoint_path) def construct_model(self, name, images, model_params, real_input_flag, num_layers, num_hidden, filter_size, stride, total_length, input_length, tln): '''Returns a sequence of generated frames Args: name: [predrnn_pp] params: dict for extra parameters of some models real_input_flag: for schedualed sampling. num_hidden: number of units in a lstm layer. filter_size: for convolutions inside lstm. stride: for convolutions inside lstm. total_length: including ins and outs. input_length: for inputs. tln: whether to apply tensor layer normalization. Returns: gen_images: a seq of frames. loss: [l2 / l1+l2]. Raises: ValueError: If network `name` is not recognized. ''' networks_map = { 'mim': mim.mim, } params = dict(mask=real_input_flag, num_layers=num_layers, num_hidden=num_hidden, filter_size=filter_size, stride=stride, total_length=total_length, input_length=input_length, is_training=True) params.update(model_params) if name in networks_map: func = networks_map[name] return func(images, params, real_input_flag, num_layers, num_hidden, filter_size, stride, total_length, input_length, tln) else: raise ValueError('Name of network unknown %s' % name) ================================================ FILE: src/trainer.py ================================================ import os.path import datetime import cv2 import numpy as np from skimage.measure import compare_ssim from src.utils import metrics from src.utils import preprocess def train(model, ims, real_input_flag, configs, itr, ims_reverse=None): ims = ims[:, :configs.total_length] ims_list = np.split(ims, configs.n_gpu) cost = model.train(ims_list, configs.lr, real_input_flag) flag = 1 if configs.reverse_img: ims_rev = np.split(ims_reverse, configs.n_gpu) cost += model.train(ims_rev, configs.lr, real_input_flag) flag += 1 if configs.reverse_input: ims_rev = np.split(ims[:, ::-1], configs.n_gpu) cost += model.train(ims_rev, configs.lr, real_input_flag) flag += 1 if configs.reverse_img: ims_rev = np.split(ims_reverse[:, ::-1], configs.n_gpu) cost += model.train(ims_rev, configs.lr, real_input_flag) flag += 1 cost = cost / flag if itr % configs.display_interval == 0: print(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'itr: ' + str(itr)) print('training loss: ' + str(cost)) def test(model, test_input_handle, configs, save_name): print(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'test...') test_input_handle.begin(do_shuffle=False) res_path = os.path.join(configs.gen_frm_dir, str(save_name)) os.mkdir(res_path) avg_mse = 0 batch_id = 0 img_mse, ssim, psnr, fmae, sharp = [], [], [], [], [] for i in range(configs.total_length - configs.input_length): img_mse.append(0) ssim.append(0) psnr.append(0) fmae.append(0) sharp.append(0) if configs.img_height > 0: height = configs.img_height else: height = configs.img_width real_input_flag = np.zeros( (configs.batch_size, configs.total_length - configs.input_length - 1, configs.img_width // configs.patch_size, height // configs.patch_size, configs.patch_size ** 2 * configs.img_channel)) while not test_input_handle.no_batch_left(): batch_id = batch_id + 1 if save_name != 'test_result': if batch_id > 100: break test_ims = test_input_handle.get_batch() test_ims = test_ims[:, :configs.total_length] if len(test_ims.shape) > 3: test_dat = preprocess.reshape_patch(test_ims, configs.patch_size) else: test_dat = test_ims test_dat = np.split(test_dat, configs.n_gpu) img_gen, debug = model.test(test_dat, real_input_flag) # concat outputs of different gpus along batch img_gen = np.concatenate(img_gen) if len(img_gen.shape) > 3: img_gen = preprocess.reshape_patch_back(img_gen, configs.patch_size) # MSE per frame for i in range(configs.total_length - configs.input_length): x = test_ims[:, i + configs.input_length, :, :, :] x = x[:configs.batch_size * configs.n_gpu] x = x - np.where(x > 10000, np.floor_divide(x, 10000) * 10000, np.zeros_like(x)) gx = img_gen[:, i, :, :, :] fmae[i] += metrics.batch_mae_frame_float(gx, x) gx = np.maximum(gx, 0) gx = np.minimum(gx, 1) mse = np.square(x - gx).sum() img_mse[i] += mse avg_mse += mse real_frm = np.uint8(x * 255) pred_frm = np.uint8(gx * 255) psnr[i] += metrics.batch_psnr(pred_frm, real_frm) for b in range(configs.batch_size): sharp[i] += np.max( cv2.convertScaleAbs(cv2.Laplacian(pred_frm[b], 3))) score, _ = compare_ssim(gx[b], x[b], full=True, multichannel=True) ssim[i] += score # save prediction examples if batch_id <= configs.num_save_samples: path = os.path.join(res_path, str(batch_id)) os.mkdir(path) if len(debug) != 0: np.save(os.path.join(path, "f.npy"), debug) for i in range(configs.total_length): name = 'gt' + str(i + 1) + '.png' file_name = os.path.join(path, name) img_gt = np.uint8(test_ims[0, i, :, :, :] * 255) if configs.img_channel == 2: img_gt = img_gt[:, :, :1] cv2.imwrite(file_name, img_gt) for i in range(configs.total_length - configs.input_length): name = 'pd' + str(i + 1 + configs.input_length) + '.png' file_name = os.path.join(path, name) img_pd = img_gen[0, i, :, :, :] if configs.img_channel == 2: img_pd = img_pd[:, :, :1] img_pd = np.maximum(img_pd, 0) img_pd = np.minimum(img_pd, 1) img_pd = np.uint8(img_pd * 255) cv2.imwrite(file_name, img_pd) test_input_handle.next() avg_mse = avg_mse / (batch_id * configs.batch_size * configs.n_gpu) print('mse per seq: ' + str(avg_mse)) for i in range(configs.total_length - configs.input_length): print(img_mse[i] / (batch_id * configs.batch_size * configs.n_gpu)) psnr = np.asarray(psnr, dtype=np.float32) / batch_id fmae = np.asarray(fmae, dtype=np.float32) / batch_id ssim = np.asarray(ssim, dtype=np.float32) / (configs.batch_size * batch_id) sharp = np.asarray(sharp, dtype=np.float32) / (configs.batch_size * batch_id) print('psnr per frame: ' + str(np.mean(psnr))) for i in range(configs.total_length - configs.input_length): print(psnr[i]) print('fmae per frame: ' + str(np.mean(fmae))) for i in range(configs.total_length - configs.input_length): print(fmae[i]) print('ssim per frame: ' + str(np.mean(ssim))) for i in range(configs.total_length - configs.input_length): print(ssim[i]) print('sharpness per frame: ' + str(np.mean(sharp))) for i in range(configs.total_length - configs.input_length): print(sharp[i]) ================================================ FILE: src/utils/__init__.py ================================================ ================================================ FILE: src/utils/metrics.py ================================================ __author__ = 'yunbo' import numpy as np from scipy.signal import convolve2d def batch_mae_frame_float(gen_frames, gt_frames): # [batch, width, height] or [batch, width, height, channel] if gen_frames.ndim == 3: axis = (1, 2) elif gen_frames.ndim == 4: axis = (1, 2, 3) x = np.float32(gen_frames) y = np.float32(gt_frames) mae = np.sum(np.absolute(x - y), axis=axis, dtype=np.float32) return np.mean(mae) def batch_psnr(gen_frames, gt_frames): # [batch, width, height] or [batch, width, height, channel] if gen_frames.ndim == 3: axis = (1, 2) elif gen_frames.ndim == 4: axis = (1, 2, 3) x = np.int32(gen_frames) y = np.int32(gt_frames) num_pixels = float(np.size(gen_frames[0])) mse = np.sum((x - y) ** 2, axis=axis, dtype=np.float32) / num_pixels psnr = 20 * np.log10(255) - 10 * np.log10(mse) return np.mean(psnr) ================================================ FILE: src/utils/optimizer.py ================================================ import tensorflow as tf def adam_updates(params, cost_or_grads, lr=0.001, mom1=0.9, mom2=0.999): updates = [] if type(cost_or_grads) is not list: grads = tf.gradients(cost_or_grads, params) else: grads = cost_or_grads t = tf.Variable(1., 'adam_t') for p, g in zip(params, grads): mg = tf.Variable(tf.zeros(p.get_shape()), p.name + '_adam_mg') if mom1 > 0: v = tf.Variable(tf.zeros(p.get_shape()), p.name + '_adam_v') v_t = mom1 * v + (1. - mom1) * g v_hat = v_t / (1. - tf.pow(mom1, t)) updates.append(v.assign(v_t)) else: v_hat = g mg_t = mom2 * mg + (1. - mom2) * tf.square(g) mg_hat = mg_t / (1. - tf.pow(mom2, t)) g_t = v_hat / tf.sqrt(mg_hat + 1e-8) p_t = p - lr * g_t updates.append(mg.assign(mg_t)) updates.append(p.assign(p_t)) updates.append(t.assign_add(1)) return tf.group(*updates) ================================================ FILE: src/utils/preprocess.py ================================================ __author__ = 'yunbo' import numpy as np def reshape_patch(img_tensor, patch_size): assert 5 == img_tensor.ndim batch_size = np.shape(img_tensor)[0] seq_length = np.shape(img_tensor)[1] img_height = np.shape(img_tensor)[2] img_width = np.shape(img_tensor)[3] num_channels = np.shape(img_tensor)[4] a = np.reshape(img_tensor, [batch_size, seq_length, img_height//patch_size, patch_size, img_width//patch_size, patch_size, num_channels]) b = np.transpose(a, [0,1,2,4,3,5,6]) patch_tensor = np.reshape(b, [batch_size, seq_length, img_height//patch_size, img_width//patch_size, patch_size*patch_size*num_channels]) return patch_tensor def reshape_patch_back(patch_tensor, patch_size): assert 5 == patch_tensor.ndim batch_size = np.shape(patch_tensor)[0] seq_length = np.shape(patch_tensor)[1] patch_height = np.shape(patch_tensor)[2] patch_width = np.shape(patch_tensor)[3] channels = np.shape(patch_tensor)[4] img_channels = channels // (patch_size*patch_size) a = np.reshape(patch_tensor, [batch_size, seq_length, patch_height, patch_width, patch_size, patch_size, img_channels]) b = np.transpose(a, [0,1,2,4,3,5,6]) img_tensor = np.reshape(b, [batch_size, seq_length, patch_height * patch_size, patch_width * patch_size, img_channels]) return img_tensor