[
  {
    "path": "README.md",
    "content": "# Memory In Memory Networks\n\nMIM is a neural network for video prediction and spatiotemporal modeling. It is based on the paper [Memory In Memory: A Predictive Neural Network for Learning Higher-Order Non-Stationarity from Spatiotemporal Dynamics](https://arxiv.org/pdf/1811.07490.pdf), presented at CVPR 2019.\n\n## Abstract\n\nNatural spatiotemporal processes can be highly non-stationary in many ways, e.g., low-level non-stationarity, such as the spatial correlations or temporal dependencies of local pixel values, and high-level non-stationarity, such as the accumulation, deformation, or dissipation of radar echoes in precipitation forecasting.\n\nWe try to stationarize and approximate the non-stationary processes by modeling the differential signals with MIM recurrent blocks. By stacking multiple MIM blocks, we can potentially handle higher-order non-stationarity. Our model achieves state-of-the-art results on three spatiotemporal prediction tasks across both synthetic and real-world data.\n\n![model](https://github.com/ZJianjin/mim_images/blob/master/readme_structure.png)\n\n## Pre-trained Models and Datasets\n\nAll pre-trained MIM models have been uploaded to [DROPBOX](https://www.dropbox.com/s/7kd82ijezk4lkmp/mim-lib.zip?dl=0) and [BAIDU YUN](https://pan.baidu.com/s/1O07H7l1NTWmAkx3UCDVMLA) (password: srhv).\n\nThe download also includes our pre-processed training/testing data for Moving MNIST, Color-Changing Moving MNIST, and TaxiBJ. 
\n\nFor Human3.6M, download the data with `data/human36m.sh`.\n\n## Generation Results\n\n#### Moving MNIST\n\n![mnist1](https://github.com/ZJianjin/mim_images/blob/master/mnist1.gif)\n\n![mnist2](https://github.com/ZJianjin/mim_images/blob/master/mnist4.gif)\n\n![mnist3](https://github.com/ZJianjin/mim_images/blob/master/mnist5.gif)\n\n#### Color-Changing Moving MNIST\n\n![mnistc1](https://github.com/ZJianjin/mim_images/blob/master/mnistc2.gif)\n\n![mnistc2](https://github.com/ZJianjin/mim_images/blob/master/mnistc3.gif)\n\n![mnistc3](https://github.com/ZJianjin/mim_images/blob/master/mnistc4.gif)\n\n#### Radar Echoes\n\n![radar1](https://github.com/ZJianjin/mim_images/blob/master/radar9.gif)\n\n![radar2](https://github.com/ZJianjin/mim_images/blob/master/radar3.gif)\n\n![radar3](https://github.com/ZJianjin/mim_images/blob/master/radar7.gif)\n\n#### Human3.6M\n\n![human1](https://github.com/ZJianjin/mim_images/blob/master/human3.gif)\n\n![human2](https://github.com/ZJianjin/mim_images/blob/master/human5.gif)\n\n![human3](https://github.com/ZJianjin/mim_images/blob/master/human10.gif)\n\n## BibTeX\n```\n@article{wang2018memory,\n  title={Memory In Memory: A Predictive Neural Network for Learning Higher-Order Non-Stationarity from Spatiotemporal Dynamics},\n  author={Wang, Yunbo and Zhang, Jianjin and Zhu, Hongyu and Long, Mingsheng and Wang, Jianmin and Yu, Philip S},\n  journal={arXiv preprint arXiv:1811.07490},\n  year={2019}\n}\n```\n"
  },
  {
    "path": "data/human36m.sh",
    "content": "#!/bin/sh\n# Download H36M images for subjects S1, S5, S6, S7, S8, S9, and S11\nmkdir -p human\ncd human\nfor subject in S1 S5 S6 S7 S8 S9 S11; do\n    wget http://visiondata.cis.upenn.edu/volumetric/h36m/${subject}.tar\n    tar -xf ${subject}.tar\n    rm ${subject}.tar\ndone\ncd ..\n"
  },
  {
    "path": "run.py",
    "content": "__author__ = 'yunbo'\n\nimport os\n\nimport tensorflow as tf\nimport numpy as np\nfrom time import time\n\nfrom src.data_provider import datasets_factory\nfrom src.models.model_factory import Model\nfrom src.utils import preprocess\nimport src.trainer as trainer\n\n# -----------------------------------------------------------------------------\nFLAGS = tf.app.flags.FLAGS\n\n# os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"2\"\n\n# mode\ntf.app.flags.DEFINE_boolean('is_training', True, 'training or testing')\n\n# data I/O\ntf.app.flags.DEFINE_string('dataset_name', 'mnist',\n                           'The name of the dataset.')\ntf.app.flags.DEFINE_string('train_data_paths',\n                           'data/moving-mnist-example/moving-mnist-train.npz',\n                           'comma-separated train data paths.')\ntf.app.flags.DEFINE_string('valid_data_paths',\n                           'data/moving-mnist-example/moving-mnist-valid.npz',\n                           'comma-separated validation data paths.')\ntf.app.flags.DEFINE_string('save_dir', 'checkpoints/mnist_predrnn_pp',\n                           'dir to store trained net.')\ntf.app.flags.DEFINE_string('gen_frm_dir', 'results/mnist_predrnn_pp',\n                           'dir to store generated frames.')\ntf.app.flags.DEFINE_integer('input_length', 10,\n                            'number of input frames.')\ntf.app.flags.DEFINE_integer('total_length', 20,\n                            'total number of input and output frames.')\ntf.app.flags.DEFINE_integer('img_width', 64,\n                            'input image width.')\ntf.app.flags.DEFINE_integer('img_channel', 1,\n                            'number of image channels.')\n# model[convlstm, predcnn, predrnn, predrnn_pp]\ntf.app.flags.DEFINE_string('model_name', 'convlstm_net',\n                           'The name of the architecture.')\ntf.app.flags.DEFINE_string('pretrained_model', '',\n                           'file of a pretrained model to initialize 
from.')\ntf.app.flags.DEFINE_string('num_hidden', '64,64,64,64',\n                           'COMMA separated number of units in a convlstm layer.')\ntf.app.flags.DEFINE_integer('filter_size', 5,\n                            'filter of a convlstm layer.')\ntf.app.flags.DEFINE_integer('stride', 1,\n                            'stride of a convlstm layer.')\ntf.app.flags.DEFINE_integer('patch_size', 1,\n                            'patch size on one dimension.')\ntf.app.flags.DEFINE_boolean('layer_norm', True,\n                            'whether to apply tensor layer norm.')\n# scheduled sampling\ntf.app.flags.DEFINE_boolean('scheduled_sampling', True, 'for scheduled sampling')\ntf.app.flags.DEFINE_integer('sampling_stop_iter', 50000, 'for scheduled sampling.')\ntf.app.flags.DEFINE_float('sampling_start_value', 1.0, 'for scheduled sampling.')\ntf.app.flags.DEFINE_float('sampling_changing_rate', 0.00002, 'for scheduled sampling.')\n# optimization\ntf.app.flags.DEFINE_float('lr', 0.001,\n                          'base learning rate.')\ntf.app.flags.DEFINE_boolean('reverse_input', True,\n                            'whether to reverse the input frames while training.')\ntf.app.flags.DEFINE_boolean('reverse_img', False,\n                            'whether to reverse the input images while training.')\ntf.app.flags.DEFINE_integer('batch_size', 8,\n                            'batch size for training.')\ntf.app.flags.DEFINE_integer('max_iterations', 80000,\n                            'max num of steps.')\ntf.app.flags.DEFINE_integer('display_interval', 1,\n                            'number of iters showing training loss.')\ntf.app.flags.DEFINE_integer('test_interval', 1000,\n                            'number of iters for test.')\ntf.app.flags.DEFINE_integer('snapshot_interval', 1000,\n                            'number of iters saving models.')\ntf.app.flags.DEFINE_integer('num_save_samples', 10,\n                            'number of sequences to be 
saved.')\ntf.app.flags.DEFINE_integer('n_gpu', 1,\n                            'how many GPUs to distribute the training across.')\n# gpu \ntf.app.flags.DEFINE_boolean('allow_gpu_growth', False,\n                            'allow gpu growth')\n\ntf.app.flags.DEFINE_integer('img_height', 0,\n                            'input image height.')\n\n\ndef main(argv=None):\n    if tf.gfile.Exists(FLAGS.save_dir):\n        tf.gfile.DeleteRecursively(FLAGS.save_dir)\n    tf.gfile.MakeDirs(FLAGS.save_dir)\n    if tf.gfile.Exists(FLAGS.gen_frm_dir):\n        tf.gfile.DeleteRecursively(FLAGS.gen_frm_dir)\n    tf.gfile.MakeDirs(FLAGS.gen_frm_dir)\n\n    gpu_list = np.asarray(os.environ.get('CUDA_VISIBLE_DEVICES', '-1').split(',') ,dtype=np.int32)\n    FLAGS.n_gpu = len(gpu_list)\n    print('Initializing models')\n\n    model = Model(FLAGS)\n\n    if FLAGS.is_training:\n        train_wrapper(model)\n    else:\n        start = time()\n        test_wrapper(model)\n        stop = time()\n        print(\"Time used: \" + str(stop - start) + \"s\")\n\n\ndef schedule_sampling(eta, itr):\n    if FLAGS.img_height > 0:\n        height = FLAGS.img_height\n    else:\n        height = FLAGS.img_width\n    zeros = np.zeros((FLAGS.batch_size,\n                      FLAGS.total_length - FLAGS.input_length - 1,\n                      FLAGS.img_width // FLAGS.patch_size,\n                      height // FLAGS.patch_size,\n                      FLAGS.patch_size ** 2 * FLAGS.img_channel))\n    if not FLAGS.scheduled_sampling:\n        return 0.0, zeros\n\n    if itr < FLAGS.sampling_stop_iter:\n        eta -= FLAGS.sampling_changing_rate\n    else:\n        eta = 0.0\n    random_flip = np.random.random_sample(\n        (FLAGS.batch_size, FLAGS.total_length - FLAGS.input_length - 1))\n    true_token = (random_flip < eta)\n    ones = np.ones((FLAGS.img_width // FLAGS.patch_size,\n                    height // FLAGS.patch_size,\n                    FLAGS.patch_size ** 2 * FLAGS.img_channel))\n    
zeros = np.zeros((FLAGS.img_width // FLAGS.patch_size,\n                      height // FLAGS.patch_size,\n                      FLAGS.patch_size ** 2 * FLAGS.img_channel))\n    real_input_flag = []\n    for i in range(FLAGS.batch_size):\n        for j in range(FLAGS.total_length - FLAGS.input_length - 1):\n            if true_token[i, j]:\n                real_input_flag.append(ones)\n            else:\n                real_input_flag.append(zeros)\n    real_input_flag = np.array(real_input_flag)\n    real_input_flag = np.reshape(real_input_flag,\n                           (FLAGS.batch_size,\n                            FLAGS.total_length - FLAGS.input_length - 1,\n                            FLAGS.img_width // FLAGS.patch_size,\n                            height // FLAGS.patch_size,\n                            FLAGS.patch_size ** 2 * FLAGS.img_channel))\n    return eta, real_input_flag\n\n\ndef train_wrapper(model):\n    if FLAGS.pretrained_model:\n        model.load(FLAGS.pretrained_model)\n    # load data\n    train_input_handle, test_input_handle = datasets_factory.data_provider(\n        FLAGS.dataset_name, FLAGS.train_data_paths, FLAGS.valid_data_paths,\n        FLAGS.batch_size * FLAGS.n_gpu, FLAGS.img_width, seq_length=FLAGS.total_length, is_training=True)\n\n    eta = FLAGS.sampling_start_value\n\n    for itr in range(1, FLAGS.max_iterations + 1):\n        if train_input_handle.no_batch_left():\n            train_input_handle.begin(do_shuffle=True)\n        ims = train_input_handle.get_batch()\n        ims_reverse = None\n        if FLAGS.reverse_img:\n            ims_reverse = ims[:, :, :, ::-1]\n            ims_reverse = preprocess.reshape_patch(ims_reverse, FLAGS.patch_size)\n        ims = preprocess.reshape_patch(ims, FLAGS.patch_size)\n\n        eta, real_input_flag = schedule_sampling(eta, itr)\n\n        trainer.train(model, ims, real_input_flag, FLAGS, itr, ims_reverse)\n\n        if itr % FLAGS.snapshot_interval == 0:\n            
model.save(itr)\n\n        if itr % FLAGS.test_interval == 0:\n            trainer.test(model, test_input_handle, FLAGS, itr)\n\n        train_input_handle.next()\n\n\ndef test_wrapper(model):\n    model.load(FLAGS.pretrained_model)\n    test_input_handle = datasets_factory.data_provider(\n        FLAGS.dataset_name, FLAGS.train_data_paths, FLAGS.valid_data_paths,\n        FLAGS.batch_size * FLAGS.n_gpu, FLAGS.img_width, seq_length=FLAGS.total_length, is_training=False)\n    trainer.test(model, test_input_handle, FLAGS, 'test_result')\n\n\nif __name__ == '__main__':\n    tf.app.run()\n\n"
  },
  {
    "path": "src/__init__.py",
    "content": ""
  },
  {
    "path": "src/data_provider/__init__.py",
    "content": ""
  },
  {
    "path": "src/data_provider/datasets_factory.py",
    "content": "from src.data_provider import mnist, human, taxibj\n\ndatasets_map = {\n    'mnist': mnist,\n    'taxibj': taxibj,\n    'human': human\n}\n\n\ndef data_provider(dataset_name, train_data_paths, valid_data_paths, batch_size,\n                  img_width, seq_length=20, is_training=True):\n    '''Given a dataset name, returns the corresponding input handle(s).\n    Args:\n        dataset_name: String, the name of the dataset.\n        train_data_paths: String, comma-separated training data paths.\n        valid_data_paths: String, comma-separated validation data paths.\n        batch_size: Int\n        img_width: Int\n        seq_length: Int, total number of frames per sequence (human and taxibj only).\n        is_training: Bool\n    Returns:\n        if is_training:\n            Two input handles, (train, test).\n        else:\n            One input handle for evaluation.\n    Raises:\n        ValueError: If `dataset_name` is unknown.\n    '''\n    if dataset_name not in datasets_map:\n        raise ValueError('Unknown dataset name: %s' % dataset_name)\n    train_data_list = train_data_paths.split(',')\n    valid_data_list = valid_data_paths.split(',')\n    if dataset_name == 'mnist':\n        test_input_param = {'paths': valid_data_list,\n                            'minibatch_size': batch_size,\n                            'input_data_type': 'float32',\n                            'is_output_sequence': True,\n                            'name': dataset_name + ' test iterator'}\n        test_input_handle = datasets_map[dataset_name].InputHandle(test_input_param)\n        test_input_handle.begin(do_shuffle=False)\n        if is_training:\n            train_input_param = {'paths': train_data_list,\n                                 'minibatch_size': batch_size,\n                                 'input_data_type': 'float32',\n                                 'is_output_sequence': True,\n                                 'name': dataset_name + ' train iterator'}\n            train_input_handle = 
datasets_map[dataset_name].InputHandle(train_input_param)\n            train_input_handle.begin(do_shuffle=True)\n            return train_input_handle, test_input_handle\n        else:\n            return test_input_handle\n\n    if dataset_name == 'human':\n        input_param = {'paths': valid_data_list,\n                       'image_width': img_width,\n                       'minibatch_size': batch_size,\n                       'seq_length': seq_length,\n                       'channel': 3,\n                       'input_data_type': 'float32',\n                       'name': 'human'}\n        input_handle = datasets_map[dataset_name].DataProcess(input_param)\n        test_input_handle = input_handle.get_test_input_handle()\n        test_input_handle.begin(do_shuffle=False)\n        if is_training:\n            train_input_handle = input_handle.get_train_input_handle()\n            train_input_handle.begin(do_shuffle=True)\n            return train_input_handle, test_input_handle\n        else:\n            return test_input_handle\n\n    if dataset_name == 'taxibj':\n        input_param = {'paths': valid_data_list,\n                       'image_width': img_width,\n                       'minibatch_size': batch_size,\n                       'seq_length': seq_length,\n                       'input_data_type': 'float32',\n                       'name': dataset_name + ' iterator'}\n        input_handle = datasets_map[dataset_name].DataProcess(input_param)\n        if is_training:\n            train_input_handle = input_handle.get_train_input_handle()\n            train_input_handle.begin(do_shuffle=True)\n            test_input_handle = input_handle.get_test_input_handle()\n            test_input_handle.begin(do_shuffle=False)\n            return train_input_handle, test_input_handle\n        else:\n            test_input_handle = input_handle.get_test_input_handle()\n            test_input_handle.begin(do_shuffle=False)\n            return test_input_handle\n"
  },
  {
    "path": "src/data_provider/human.py",
    "content": "__author__ = 'jianjin'\nimport numpy as np\nimport os\nimport cv2\nfrom PIL import Image\nimport logging\nimport random\nimport tensorflow as tf\n\nlogger = logging.getLogger(__name__)\n\nclass InputHandle:\n    def __init__(self, datas, indices, input_param):\n        self.name = input_param['name']\n        self.input_data_type = input_param.get('input_data_type', 'float32')\n        self.minibatch_size = input_param['minibatch_size']\n        self.image_width = input_param['image_width']\n        self.channel = input_param['channel']\n        self.datas = datas\n        self.indices = indices\n        self.current_position = 0\n        self.current_batch_indices = []\n        self.current_input_length = input_param['seq_length']\n        self.interval = 2\n\n    def total(self):\n        return len(self.indices)\n\n    def begin(self, do_shuffle=True):\n        logger.info(\"Initialization for read data \")\n        if do_shuffle:\n            random.shuffle(self.indices)\n        self.current_position = 0\n        self.current_batch_indices = self.indices[self.current_position:self.current_position + self.minibatch_size]\n\n    def next(self):\n        self.current_position += self.minibatch_size\n        if self.no_batch_left():\n            return None\n        self.current_batch_indices = self.indices[self.current_position:self.current_position + self.minibatch_size]\n\n    def no_batch_left(self):\n        if self.current_position + self.minibatch_size > self.total():\n            return True\n        else:\n            return False\n\n    def get_batch(self):\n        if self.no_batch_left():\n            logger.error(\n                \"There is no batch left in \" + self.name + \". 
Consider calling begin() to rescan from the beginning of the iterator.\")\n            return None\n        input_batch = np.zeros(\n            (self.minibatch_size, self.current_input_length, self.image_width, self.image_width, self.channel)).astype(\n            self.input_data_type)\n        for i in range(self.minibatch_size):\n            batch_ind = self.current_batch_indices[i]\n            begin = batch_ind\n            end = begin + self.current_input_length * self.interval\n            data_slice = self.datas[begin:end:self.interval]\n            input_batch[i, :self.current_input_length, :, :, :] = data_slice\n            # logger.info('data_slice shape')\n            # logger.info(data_slice.shape)\n            # logger.info(input_batch.shape)\n        input_batch = input_batch.astype(self.input_data_type)\n        return input_batch\n\n    def print_stat(self):\n        logger.info(\"Iterator Name: \" + self.name)\n        logger.info(\"    current_position: \" + str(self.current_position))\n        logger.info(\"    Minibatch Size: \" + str(self.minibatch_size))\n        logger.info(\"    Total Size: \" + str(self.total()))\n        logger.info(\"    current_input_length: \" + str(self.current_input_length))\n        logger.info(\"    Input Data Type: \" + str(self.input_data_type))\n\nclass DataProcess:\n    def __init__(self, input_param):\n        self.input_param = input_param\n        self.paths = input_param['paths']\n        self.image_width = input_param['image_width']\n        self.seq_len = input_param['seq_length']\n\n    def load_data(self, paths, mode='train'):\n        data_dir = paths[0]\n        intervel = 2\n\n        frames_np = []\n        scenarios = ['Walking']\n        if mode == 'train':\n            subjects = ['S1', 'S5', 'S6', 'S7', 'S8']\n        elif mode == 'test':\n            subjects = ['S9', 'S11']\n        else:\n            raise ValueError(\"Unknown mode: %s\" % mode)\n        _path = data_dir\n        print('load data...', 
_path)\n        filenames = os.listdir(_path)\n        filenames.sort()\n        print ('data size ', len(filenames))\n        frames_file_name = []\n        for filename in filenames:\n            fix = filename.split('.')\n            fix = fix[0]\n            subject = fix.split('_')\n            scenario = subject[1]\n            subject = subject[0]\n            if subject not in subjects or scenario not in scenarios:\n                continue\n            file_path = os.path.join(_path, filename)\n            image = cv2.cvtColor(cv2.imread(file_path), cv2.COLOR_BGR2RGB)\n            #[1000,1000,3]\n            image = image[image.shape[0]//4:-image.shape[0]//4, image.shape[1]//4:-image.shape[1]//4, :]\n            if self.image_width != image.shape[0]:\n                image = cv2.resize(image, (self.image_width, self.image_width))\n            #image = cv2.resize(image[100:-100,100:-100,:], (self.image_width, self.image_width),\n            #                   interpolation=cv2.INTER_LINEAR)\n            frames_np.append(np.array(image, dtype=np.float32) / 255.0)\n            frames_file_name.append(filename)\n#             if len(frames_np) % 100 == 0: print len(frames_np)\n            #if len(frames_np) % 1000 == 0: break\n        # is it a begin index of sequence\n        indices = []\n        index = 0\n        print ('gen index')\n        while index + intervel * self.seq_len - 1 < len(frames_file_name):\n            # 'S11_Discussion_1.54138969_000471.jpg'\n            # ['S11_Discussion_1', '54138969_000471', 'jpg']\n            start_infos = frames_file_name[index].split('.')\n            end_infos = frames_file_name[index+intervel*(self.seq_len-1)].split('.')\n            if start_infos[0] != end_infos[0]:\n                index += 1\n                continue\n            start_video_id, start_frame_id = start_infos[1].split('_')\n            end_video_id, end_frame_id = end_infos[1].split('_')\n            if start_video_id != end_video_id:\n      
          index += 1\n                continue\n            if int(end_frame_id) - int(start_frame_id) == 5 * (self.seq_len - 1) * intervel:\n                indices.append(index)\n            if mode == 'train':\n                index += 10\n            elif mode == 'test':\n                index += 5\n        print(\"there are \" + str(len(indices)) + \" sequences\")\n        # data = np.asarray(frames_np)\n        data = frames_np\n        print(\"there are \" + str(len(data)) + \" pictures\")\n        return data, indices\n\n    def get_train_input_handle(self):\n        train_data, train_indices = self.load_data(self.paths, mode='train')\n        return InputHandle(train_data, train_indices, self.input_param)\n\n    def get_test_input_handle(self):\n        test_data, test_indices = self.load_data(self.paths, mode='test')\n        return InputHandle(test_data, test_indices, self.input_param)\n\n\ndef main():\n    # Smoke test: load the train and test splits from input_dir.\n    # The default image_width and seq_length are illustrative assumptions.\n    import argparse\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"input_dir\", type=str)\n    parser.add_argument(\"--image_width\", type=int, default=64)\n    parser.add_argument(\"--seq_length\", type=int, default=20)\n    args = parser.parse_args()\n\n    input_param = {'paths': [args.input_dir],\n                   'image_width': args.image_width,\n                   'minibatch_size': 1,\n                   'seq_length': args.seq_length,\n                   'channel': 3,\n                   'input_data_type': 'float32',\n                   'name': 'human'}\n    data_process = DataProcess(input_param)\n    for mode in ['train', 'test']:\n        data_process.load_data(input_param['paths'], mode=mode)\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "src/data_provider/mnist.py",
    "content": "import numpy as np\nimport random\n\nclass InputHandle:\n    def __init__(self, input_param):\n        self.paths = input_param['paths']\n        self.num_paths = len(input_param['paths'])\n        self.name = input_param['name']\n        self.input_data_type = input_param.get('input_data_type', 'float32')\n        self.output_data_type = input_param.get('output_data_type', 'float32')\n        self.minibatch_size = input_param['minibatch_size']\n        self.is_output_sequence = input_param['is_output_sequence']\n        self.data = {}\n        self.indices = {}\n        self.current_position = 0\n        self.current_batch_size = 0\n        self.current_batch_indices = []\n        self.current_input_length = 0\n        self.current_output_length = 0\n        self.load()\n\n    def load(self):\n        dat_1 = np.load(self.paths[0])\n        for key in dat_1.keys():\n            self.data[key] = dat_1[key]\n        if self.num_paths == 2:\n            dat_2 = np.load(self.paths[1])\n            num_clips_1 = dat_1['clips'].shape[1]\n            dat_2['clips'][:,:,0] += num_clips_1\n            self.data['clips'] = np.concatenate(\n                (dat_1['clips'], dat_2['clips']), axis=1)\n            self.data['input_raw_data'] = np.concatenate(\n                (dat_1['input_raw_data'], dat_2['input_raw_data']), axis=0)\n            self.data['output_raw_data'] = np.concatenate(\n                (dat_1['output_raw_data'], dat_2['output_raw_data']), axis=0)\n        for key in self.data.keys():\n            print(key)\n            print(self.data[key].shape)\n\n    def total(self):\n        return self.data['clips'].shape[1]\n\n    def begin(self, do_shuffle = True):\n        self.indices = np.arange(self.total(),dtype=\"int32\")\n        if do_shuffle:\n            random.shuffle(self.indices)\n        self.current_position = 0\n        if self.current_position + self.minibatch_size <= self.total():\n            self.current_batch_size = 
self.minibatch_size\n        else:\n            self.current_batch_size = self.total() - self.current_position\n        self.current_batch_indices = self.indices[\n            self.current_position:self.current_position + self.current_batch_size]\n        self.current_input_length = max(self.data['clips'][0, ind, 1] for ind\n                                        in self.current_batch_indices)\n        self.current_output_length = max(self.data['clips'][1, ind, 1] for ind\n                                         in self.current_batch_indices)\n\n    def next(self):\n        self.current_position += self.current_batch_size\n        if self.no_batch_left():\n            return None\n        if self.current_position + self.minibatch_size <= self.total():\n            self.current_batch_size = self.minibatch_size\n        else:\n            self.current_batch_size = self.total() - self.current_position\n        self.current_batch_indices = self.indices[\n            self.current_position:self.current_position + self.current_batch_size]\n        self.current_input_length = max(self.data['clips'][0, ind, 1] for ind\n                                        in self.current_batch_indices)\n        self.current_output_length = max(self.data['clips'][1, ind, 1] for ind\n                                         in self.current_batch_indices)\n\n    def no_batch_left(self):\n        if self.current_position >= self.total() - self.current_batch_size:\n            return True\n        else:\n            return False\n\n    def input_batch(self):\n        if self.no_batch_left():\n            return None\n        input_batch = np.zeros(\n            (self.current_batch_size, self.current_input_length) +\n            tuple(self.data['dims'][0])).astype(self.input_data_type)\n        input_batch = np.transpose(input_batch,(0,1,3,4,2))\n        for i in range(self.current_batch_size):\n            batch_ind = self.current_batch_indices[i]\n            begin = self.data['clips'][0, 
batch_ind, 0]\n            end = self.data['clips'][0, batch_ind, 0] + \\\n                    self.data['clips'][0, batch_ind, 1]\n            data_slice = self.data['input_raw_data'][begin:end, :, :, :]\n            data_slice = np.transpose(data_slice,(0,2,3,1))\n            input_batch[i, :self.current_input_length, :, :, :] = data_slice\n        input_batch = input_batch.astype(self.input_data_type)\n        return input_batch\n\n    def output_batch(self):\n        if self.no_batch_left():\n            return None\n        if(2 ,3) == self.data['dims'].shape:\n            raw_dat = self.data['output_raw_data']\n        else:\n            raw_dat = self.data['input_raw_data']\n        if self.is_output_sequence:\n            if (1, 3) == self.data['dims'].shape:\n                output_dim = self.data['dims'][0]\n            else:\n                output_dim = self.data['dims'][1]\n            output_batch = np.zeros(\n                (self.current_batch_size,self.current_output_length) +\n                tuple(output_dim))\n        else:\n            output_batch = np.zeros((self.current_batch_size, ) +\n                                    tuple(self.data['dims'][1]))\n        for i in range(self.current_batch_size):\n            batch_ind = self.current_batch_indices[i]\n            begin = self.data['clips'][1, batch_ind, 0]\n            end = self.data['clips'][1, batch_ind, 0] + \\\n                    self.data['clips'][1, batch_ind, 1]\n            if self.is_output_sequence:\n                data_slice = raw_dat[begin:end, :, :, :]\n                output_batch[i, : data_slice.shape[0], :, :, :] = data_slice\n            else:\n                data_slice = raw_dat[begin, :, :, :]\n                output_batch[i,:, :, :] = data_slice\n        output_batch = output_batch.astype(self.output_data_type)\n        output_batch = np.transpose(output_batch, [0,1,3,4,2])\n        return output_batch\n\n    def get_batch(self):\n        input_seq = 
self.input_batch()\n        output_seq = self.output_batch()\n        batch = np.concatenate((input_seq, output_seq), axis=1)\n        return batch\n"
  },
  {
    "path": "src/data_provider/taxibj.py",
    "content": "__author__ = 'jianjin'\n\nimport random\nimport os.path\nimport logging\nimport os\nfrom copy import copy\nimport numpy as np\nimport h5py\nimport pandas as pd\nfrom datetime import datetime\nimport time\n\nlogger = logging.getLogger(__name__)\n\n\ndef string2timestamp(strings, T=48):\n    timestamps = []\n\n    time_per_slot = 24.0 / T\n    num_per_T = T // 24\n    for t in strings:\n        year, month, day, slot = int(t[:4]), int(t[4:6]), int(t[6:8]), int(t[8:])-1\n        timestamps.append(pd.Timestamp(datetime(year, month, day, hour=int(slot * time_per_slot),\n                                                minute=(slot % num_per_T) * int(60.0 * time_per_slot))))\n\n    return timestamps\n\n\nclass STMatrix(object):\n    \"\"\"docstring for STMatrix\"\"\"\n\n    def __init__(self, data, timestamps, T=48, CheckComplete=True):\n        super(STMatrix, self).__init__()\n        assert len(data) == len(timestamps)\n        self.data = data\n        self.timestamps = timestamps\n        self.T = T\n        self.pd_timestamps = string2timestamp(timestamps, T=self.T)\n        if CheckComplete:\n            self.check_complete()\n        # index\n        self.make_index()\n\n    def make_index(self):\n        self.get_index = dict()\n        for i, ts in enumerate(self.pd_timestamps):\n            self.get_index[ts] = i\n\n    def check_complete(self):\n        missing_timestamps = []\n        offset = pd.DateOffset(minutes=24 * 60 // self.T)\n        pd_timestamps = self.pd_timestamps\n        i = 1\n        while i < len(pd_timestamps):\n            if pd_timestamps[i-1] + offset != pd_timestamps[i]:\n                missing_timestamps.append(\"(%s -- %s)\" % (pd_timestamps[i-1], pd_timestamps[i]))\n            i += 1\n        for v in missing_timestamps:\n            print(v)\n        assert len(missing_timestamps) == 0\n\n    def get_matrix(self, timestamp):\n        return self.data[self.get_index[timestamp]]\n\n    def save(self, fname):\n        
pass\n\n    def check_it(self, depends):\n        for d in depends:\n            if d not in self.get_index.keys():\n                return False\n        return True\n\n    def create_dataset(self, len_closeness=20):\n        \"\"\"Build windows of the len_closeness frames preceding each timestamp.\n\n        Each window is ordered from the most recent frame to the oldest, and\n        each frame is transposed from channel-first to channel-last layout.\n        Returns XC with shape (num_samples, len_closeness, H, W, C) and the\n        list of timestamps that immediately follow each window.\n        \"\"\"\n        # offset_week = pd.DateOffset(days=7)\n        offset_frame = pd.DateOffset(minutes=24 * 60 // self.T)\n        XC = []\n        timestamps_Y = []\n        depends = [range(1, len_closeness+1)]\n\n        i = len_closeness\n        while i < len(self.pd_timestamps):\n            Flag = True\n            for depend in depends:\n                if Flag is False:\n                    break\n                Flag = self.check_it([self.pd_timestamps[i] - j * offset_frame for j in depend])\n\n            if Flag is False:\n                i += 1\n                continue\n            x_c = [np.transpose(self.get_matrix(self.pd_timestamps[i] - j * offset_frame), [1, 2, 0]) for j in depends[0]]\n            if len_closeness > 0:\n                XC.append(np.stack(x_c, axis=0))\n            timestamps_Y.append(self.timestamps[i])\n            i += 1\n        XC = np.stack(XC, axis=0)\n        return XC, timestamps_Y\n\n\ndef load_stdata(fname):\n    # [()] reads the whole HDF5 dataset into a NumPy array.\n    with h5py.File(fname, 'r') as f:\n        data = f['data'][()]\n        timestamps = f['date'][()]\n    return data, timestamps\n\n\ndef stat(fname):\n    def get_nb_timeslot(f):\n        s = f['date'][0]\n        e = f['date'][-1]\n        year, month, day = map(int, [s[:4], s[4:6], s[6:8]])\n        ts = time.strptime(\"%04i-%02i-%02i\" % (year, month, day), \"%Y-%m-%d\")\n        year, month, day = map(int, [e[:4], e[4:6], e[6:8]])\n        te = time.strptime(\"%04i-%02i-%02i\" % (year, month, day), \"%Y-%m-%d\")\n        nb_timeslot = (time.mktime(te) - time.mktime(ts)) / (0.5 * 3600) + 48\n        ts_str, te_str = time.strftime(\"%Y-%m-%d\", ts), time.strftime(\"%Y-%m-%d\", te)\n        return nb_timeslot, ts_str, te_str\n\n    with h5py.File(fname, 
'r') as f:\n        nb_timeslot, ts_str, te_str = get_nb_timeslot(f)\n        nb_day = int(nb_timeslot / 48)\n        data = f['data'][()]\n        mmax = data.max()\n        mmin = data.min()\n        stat = '=' * 5 + 'stat' + '=' * 5 + '\\n' + \\\n               'data shape: %s\\n' % str(f['data'].shape) + \\\n               '# of days: %i, from %s to %s\\n' % (nb_day, ts_str, te_str) + \\\n               '# of timeslots: %i\\n' % int(nb_timeslot) + \\\n               '# of timeslots (available): %i\\n' % f['date'].shape[0] + \\\n               'missing ratio of timeslots: %.1f%%\\n' % ((1. - float(f['date'].shape[0] / nb_timeslot)) * 100) + \\\n               'max: %.3f, min: %.3f\\n' % (mmax, mmin) + \\\n               '=' * 5 + 'stat' + '=' * 5\n        print(stat)\n\n\nclass MinMaxNormalization(object):\n    '''MinMax Normalization --> [0, 1]\n       x = (x - min) / (max - min)\n    '''\n\n    def __init__(self):\n        pass\n\n    def fit(self, X):\n        self._min = X.min()\n        self._max = X.max()\n        print(\"min:\", self._min, \"max:\", self._max)\n\n    def transform(self, X):\n        X = 1. * (X - self._min) / (self._max - self._min)\n        return X\n\n    def fit_transform(self, X):\n        self.fit(X)\n        return self.transform(X)\n\n    def inverse_transform(self, X):\n        # undo transform(): map [0, 1] back to [min, max]\n        X = 1. 
* X * (self._max - self._min) + self._min\n        return X\n\n\ndef timestamp2vec(timestamps):\n    # one-hot day of week (tm_wday: Monday is 0, Sunday is 6) plus a weekday flag\n    # the dates may come back from h5py as bytes under Python 3, so decode them first\n    vec = [time.strptime(t[:8].decode('utf-8') if isinstance(t, bytes) else t[:8], '%Y%m%d').tm_wday\n           for t in timestamps]\n    ret = []\n    for i in vec:\n        v = [0 for _ in range(7)]\n        v[i] = 1\n        if i >= 5:\n            v.append(0)  # weekend\n        else:\n            v.append(1)  # weekday\n        ret.append(v)\n    return np.asarray(ret)\n\n\ndef remove_incomplete_days(data, timestamps, T=48):\n    # keep only the days whose time slots run from 1 to T\n    days = []  # complete days\n    days_incomplete = []  # days that contain only some of their slots\n    i = 0\n    while i < len(timestamps):\n        if int(timestamps[i][8:]) != 1:\n            i += 1\n        elif i+T-1 < len(timestamps) and int(timestamps[i+T-1][8:]) == T:\n            days.append(timestamps[i][:8])\n            i += T\n        else:\n            days_incomplete.append(timestamps[i][:8])\n            i += 1\n    print(\"incomplete days: \", days_incomplete)\n    days = set(days)\n    idx = []\n    for i, t in enumerate(timestamps):\n        if t[:8] in days:\n            idx.append(i)\n\n    data = data[idx]\n    timestamps = [timestamps[i] for i in idx]\n    return data, timestamps\n\n\nclass InputHandle:\n    def __init__(self, datas, indices, input_param):\n        self.name = input_param['name']\n        self.input_data_type = input_param.get('input_data_type', 'float32')\n        self.minibatch_size = input_param['minibatch_size']\n        self.image_width = input_param['image_width']\n        self.datas = datas\n        self.indices = indices\n        self.current_position = 0\n        self.current_batch_indices = []\n        self.current_input_length = input_param['seq_length']\n\n    def total(self):\n        return len(self.indices)\n\n    def begin(self, 
do_shuffle=True):\n        logger.info(\"Initializing data iterator \" + self.name)\n        if do_shuffle:\n            random.shuffle(self.indices)\n        self.current_position = 0\n        self.current_batch_indices = self.indices[self.current_position:self.current_position + self.minibatch_size]\n\n    def next(self):\n        self.current_position += self.minibatch_size\n        if self.no_batch_left():\n            return None\n        self.current_batch_indices = self.indices[self.current_position:self.current_position + self.minibatch_size]\n\n    def no_batch_left(self):\n        # True once a full minibatch no longer fits in the remaining indices;\n        # '>' rather than '>=' so that the last full minibatch is not skipped\n        return self.current_position + self.minibatch_size > self.total()\n\n    def get_batch(self):\n        if self.no_batch_left():\n            logger.error(\n                \"There is no batch left in \" + self.name + \". Call begin() to rescan the iterator from the beginning.\")\n            return None\n        input_batch = self.datas[self.current_batch_indices, :, :, :]\n        input_batch = input_batch.astype(self.input_data_type)\n        return input_batch\n\n    def print_stat(self):\n        logger.info(\"Iterator Name: \" + self.name)\n        logger.info(\"    current_position: \" + str(self.current_position))\n        logger.info(\"    Minibatch Size: \" + str(self.minibatch_size))\n        logger.info(\"    total Size: \" + str(self.total()))\n        logger.info(\"    current_input_length: \" + str(self.current_input_length))\n        logger.info(\"    Input Data Type: \" + str(self.input_data_type))\n\n\nclass DataProcess:\n    def __init__(self, input_param):\n        self.paths = input_param['paths']\n        self.image_width = input_param['image_width']\n\n        self.input_param = input_param\n        self.seq_len = input_param['seq_length']\n        self.train_data, self.test_data, _, _, _ = self.load_data(self.paths, len_closeness=input_param['seq_length'])\n        self.train_indices = list(range(self.train_data.shape[0]))\n        self.test_indices = list(range(self.test_data.shape[0]))\n\n    def load_data(self, datapath, T=48, nb_flow=2, len_closeness=None, len_test=48 * 7 * 4):\n        \"\"\"Load the TaxiBJ in/out flows for 2013-2016 and build sequence samples.\n\n        Min-max statistics are fitted on all but the last len_test frames\n        (by default 48 * 7 * 4, i.e. four weeks of 30-minute slots), and the\n        last len_test samples form the test set.\n\n        Returns X_train, X_test, the fitted MinMaxNormalization, and the\n        train/test timestamps.\n        \"\"\"\n        assert (len_closeness > 0)\n        # load the yearly files BJ13 to BJ16\n        data_all = []\n        timestamps_all = list()\n        for year in range(13, 17):\n            fname = os.path.join(\n                datapath[0], 'BJ{}_M32x32_T30_InOut.h5'.format(year))\n            print(\"file name: \", fname)\n            stat(fname)\n            data, timestamps = load_stdata(fname)\n            # drop days that do not have all T timestamps\n            data, timestamps = remove_incomplete_days(data, timestamps, T)\n            data = data[:, :nb_flow]\n            data[data < 0] = 0.\n            data_all.append(data)\n            timestamps_all.append(timestamps)\n            print(\"\\n\")\n\n        # min-max scale using statistics from the training period only\n        data_train = np.vstack(copy(data_all))[:-len_test]\n        print('train_data shape: ', data_train.shape)\n        mmn = MinMaxNormalization()\n        mmn.fit(data_train)\n        data_all_mmn = [mmn.transform(d) for d in data_all]\n\n        XC = []\n        timestamps_Y = []\n        for data, timestamps in zip(data_all_mmn, timestamps_all):\n            # each sample is a sequence of len_closeness consecutive frames\n            st = STMatrix(data, timestamps, T, CheckComplete=False)\n            _XC, _timestamps_Y = st.create_dataset(len_closeness=len_closeness)\n            XC.append(_XC)\n            timestamps_Y += _timestamps_Y\n        XC = np.concatenate(XC, axis=0)\n        print(\"XC shape: \", XC.shape)\n\n        XC_train = XC[:-len_test]\n        XC_test = XC[-len_test:]\n        timestamp_train, timestamp_test = timestamps_Y[:-len_test], timestamps_Y[-len_test:]\n\n        X_train = XC_train\n        X_test = XC_test\n        print('train shape:', XC_train.shape,\n              'test shape: ', XC_test.shape)\n\n        return X_train, X_test, mmn, timestamp_train, timestamp_test\n\n    def get_train_input_handle(self):\n        return InputHandle(self.train_data, self.train_indices, self.input_param)\n\n    def get_test_input_handle(self):\n        return InputHandle(self.test_data, self.test_indices, self.input_param)\n"
  },
  {
    "path": "src/layers/MIMBlock.py",
    "content": "import tensorflow as tf\nfrom src.layers.TensorLayerNorm import tensor_layer_norm\nimport math\n\n\nclass MIMBlock():\n    def __init__(self, layer_name, filter_size, num_hidden_in, num_hidden,\n                 seq_shape, tln=False, initializer=None):\n        \"\"\"Initialize a MIM block.\n\n        The temporal memory is updated by the MIMS sub-cell, which is driven by\n        diff_h (the differential features from the MIMN layer) instead of a plain\n        forget gate.\n        Args:\n            layer_name: variable scope name of this layer.\n            filter_size: int or int tuple, the height and width of the convolution filters.\n            num_hidden_in: number of hidden channels of the preceding layer.\n            num_hidden: number of units (channels) in the output tensor.\n            seq_shape: input shape [batch, length, height, width, channels].\n            tln: whether to apply tensor layer normalization.\n            initializer: None or -1 for a uniform initializer with range\n                sqrt(6 / (dim_in + dim_out)); otherwise the half-width of a\n                fixed uniform initializer.\n        \"\"\"\n        self.layer_name = layer_name\n        self.filter_size = filter_size\n        self.num_hidden_in = num_hidden_in\n        self.num_hidden = num_hidden\n        self.convlstm_c = None\n        self.batch = seq_shape[0]\n        self.height = seq_shape[2]\n        self.width = seq_shape[3]\n        self.layer_norm = tln\n        self._forget_bias = 1.0\n\n        def w_initializer(dim_in, dim_out):\n            random_range = math.sqrt(6.0 / (dim_in + dim_out))\n            return tf.random_uniform_initializer(-random_range, random_range)\n        if initializer is None or initializer == -1:\n            self.initializer = w_initializer\n        else:\n            # keep the same (dim_in, dim_out) -> initializer interface as w_initializer\n            self.initializer = lambda dim_in, dim_out: tf.random_uniform_initializer(-initializer, initializer)\n\n    def init_state(self):\n        return tf.zeros([self.batch, self.height, self.width, self.num_hidden],\n                        dtype=tf.float32)\n\n    def MIMS(self, x, h_t, c_t):\n        if h_t is None:\n            h_t = self.init_state()\n        if c_t is None:\n            c_t = self.init_state()\n        with tf.variable_scope(self.layer_name):\n            h_concat = tf.layers.conv2d(h_t, self.num_hidden * 4,\n                                        self.filter_size, 1, padding='same',\n                                       
 kernel_initializer=self.initializer(self.num_hidden, self.num_hidden * 4),\n                                        name='state_to_state')\n            if self.layer_norm:\n                h_concat = tensor_layer_norm(h_concat, 'state_to_state')\n            i_h, g_h, f_h, o_h = tf.split(h_concat, 4, 3)\n\n            ct_weight = tf.get_variable(\n                'c_t_weight', [self.height,self.width,self.num_hidden*2])\n            ct_activation = tf.multiply(tf.tile(c_t, [1,1,1,2]), ct_weight)\n            i_c, f_c = tf.split(ct_activation, 2, 3)\n\n            i_ = i_h + i_c\n            f_ = f_h + f_c\n            g_ = g_h\n            o_ = o_h\n\n            if x != None:\n                x_concat = tf.layers.conv2d(x, self.num_hidden * 4,\n                                            self.filter_size, 1,\n                                            padding='same',\n                                            kernel_initializer=self.initializer(self.num_hidden, self.num_hidden * 4),\n                                            name='input_to_state')\n                if self.layer_norm:\n                    x_concat = tensor_layer_norm(x_concat, 'input_to_state')\n                i_x, g_x, f_x, o_x = tf.split(x_concat, 4, 3)\n\n                i_ += i_x\n                f_ += f_x\n                g_ += g_x\n                o_ += o_x\n\n            i_ = tf.nn.sigmoid(i_)\n            f_ = tf.nn.sigmoid(f_ + self._forget_bias)\n            c_new = f_ * c_t + i_ * tf.nn.tanh(g_)\n\n            oc_weight = tf.get_variable(\n                'oc_weight', [self.height,self.width,self.num_hidden])\n            o_c = tf.multiply(c_new, oc_weight)\n\n            h_new = tf.nn.sigmoid(o_ + o_c) * tf.nn.tanh(c_new)\n\n            return h_new, c_new\n\n    def __call__(self, x, diff_h, h, c, m):\n        if h is None:\n            h = self.init_state()\n        if c is None:\n            c = self.init_state()\n        if m is None:\n            m = self.init_state()\n      
  if diff_h is None:\n            diff_h = tf.zeros_like(h)\n\n        with tf.variable_scope(self.layer_name):\n            t_cc = tf.layers.conv2d(\n                h, self.num_hidden * 3,\n                self.filter_size, 1, padding='same',\n                kernel_initializer=self.initializer(self.num_hidden, self.num_hidden * 3),\n                name='time_state_to_state')\n            s_cc = tf.layers.conv2d(\n                m, self.num_hidden * 4,\n                self.filter_size, 1, padding='same',\n                kernel_initializer=self.initializer(self.num_hidden, self.num_hidden * 4),\n                name='spatio_state_to_state')\n            x_shape_in = x.get_shape().as_list()[-1]\n            x_cc = tf.layers.conv2d(\n                x, self.num_hidden * 4,\n                self.filter_size, 1, padding='same',\n                kernel_initializer=self.initializer(x_shape_in, self.num_hidden * 4),\n                name='input_to_state')\n            if self.layer_norm:\n                t_cc = tensor_layer_norm(t_cc, 'time_state_to_state')\n                s_cc = tensor_layer_norm(s_cc, 'spatio_state_to_state')\n                x_cc = tensor_layer_norm(x_cc, 'input_to_state')\n\n            i_s, g_s, f_s, o_s = tf.split(s_cc, 4, 3)\n            i_t, g_t, o_t = tf.split(t_cc, 3, 3)\n            i_x, g_x, f_x, o_x = tf.split(x_cc, 4, 3)\n\n            i = tf.nn.sigmoid(i_x + i_t)\n            i_ = tf.nn.sigmoid(i_x + i_s)\n            g = tf.nn.tanh(g_x + g_t)\n            g_ = tf.nn.tanh(g_x + g_s)\n            f_ = tf.nn.sigmoid(f_x + f_s + self._forget_bias)\n            o = tf.nn.sigmoid(o_x + o_t + o_s)\n            new_m = f_ * m + i_ * g_\n            c, self.convlstm_c = self.MIMS(diff_h, c, self.convlstm_c)\n            new_c = c + i * g\n            cell = tf.concat([new_c, new_m], 3)\n            cell = tf.layers.conv2d(cell, self.num_hidden, 1, 1,\n                                    padding='same', name='cell_reduce')\n            new_h = 
o * tf.nn.tanh(cell)\n\n            return new_h, new_c, new_m\n"
  },
  {
    "path": "src/layers/MIMN.py",
    "content": "import tensorflow as tf\nfrom src.layers.TensorLayerNorm import tensor_layer_norm\n\nclass MIMN():\n    def __init__(self, layer_name, filter_size, num_hidden, seq_shape, tln=True, initializer=0.001):\n        \"\"\"Initialize a MIMN cell: a convolutional LSTM with spatially varying\n        peephole weights on the cell state, whose input is the difference between\n        consecutive hidden states (or the output of the MIMN layer below).\n        Args:\n            layer_name: variable scope name of this layer.\n            filter_size: int or int tuple, the height and width of the convolution filters.\n            num_hidden: number of units (channels) in the output tensor.\n            seq_shape: input shape [batch, length, height, width, channels].\n            tln: whether to apply tensor layer normalization.\n            initializer: half-width of a uniform kernel initializer, or -1 to use\n                the tf.layers default.\n        \"\"\"\n        self.layer_name = layer_name\n        self.filter_size = filter_size\n        self.num_hidden = num_hidden\n        self.layer_norm = tln\n        self.batch = seq_shape[0]\n        self.height = seq_shape[2]\n        self.width = seq_shape[3]\n        self._forget_bias = 1.0\n        if initializer == -1:\n            self.initializer = None\n        else:\n            self.initializer = tf.random_uniform_initializer(-initializer,initializer)\n\n    def init_state(self):\n        shape = [self.batch, self.height, self.width, self.num_hidden]\n        return tf.zeros(shape, dtype=tf.float32)\n\n    def __call__(self, x, h_t, c_t):\n        if h_t is None:\n            h_t = self.init_state()\n        if c_t is None:\n            c_t = self.init_state()\n        with tf.variable_scope(self.layer_name):\n            h_concat = tf.layers.conv2d(h_t, self.num_hidden * 4,\n                                        self.filter_size, 1, padding='same',\n                                        kernel_initializer=self.initializer,\n                                        name='state_to_state')\n            if self.layer_norm:\n                h_concat = tensor_layer_norm(h_concat, 'state_to_state')\n            i_h, g_h, f_h, o_h = tf.split(h_concat, 4, 3)\n\n            ct_weight = tf.get_variable(\n                'c_t_weight', [self.height,self.width,self.num_hidden*2])\n            ct_activation = 
tf.multiply(tf.tile(c_t, [1,1,1,2]), ct_weight)\n            i_c, f_c = tf.split(ct_activation, 2, 3)\n\n            i_ = i_h + i_c\n            f_ = f_h + f_c\n            g_ = g_h\n            o_ = o_h\n\n            if x is not None:\n                x_concat = tf.layers.conv2d(x, self.num_hidden * 4,\n                                            self.filter_size, 1,\n                                            padding='same',\n                                            kernel_initializer=self.initializer,\n                                            name='input_to_state')\n                if self.layer_norm:\n                    x_concat = tensor_layer_norm(x_concat, 'input_to_state')\n                i_x, g_x, f_x, o_x = tf.split(x_concat, 4, 3)\n\n                i_ += i_x\n                f_ += f_x\n                g_ += g_x\n                o_ += o_x\n\n            i_ = tf.nn.sigmoid(i_)\n            f_ = tf.nn.sigmoid(f_ + self._forget_bias)\n            c_new = f_ * c_t + i_ * tf.nn.tanh(g_)\n\n            oc_weight = tf.get_variable(\n                'oc_weight', [self.height,self.width,self.num_hidden])\n            o_c = tf.multiply(c_new, oc_weight)\n\n            h_new = tf.nn.sigmoid(o_ + o_c) * tf.nn.tanh(c_new)\n\n            return h_new, c_new\n"
  },
  {
    "path": "src/layers/SpatioTemporalLSTMCellv2.py",
    "content": "import math\n\nimport tensorflow as tf\nfrom src.layers.TensorLayerNorm import tensor_layer_norm\n\nclass SpatioTemporalLSTMCell():\n    def __init__(self, layer_name, filter_size, num_hidden_in, num_hidden,\n                 seq_shape, tln=False, initializer=None):\n        \"\"\"Initialize a spatiotemporal LSTM (ST-LSTM) cell with a temporal\n        memory c and a spatiotemporal memory m.\n        Args:\n            layer_name: variable scope name of this layer.\n            filter_size: int or int tuple, the height and width of the convolution filters.\n            num_hidden_in: number of hidden channels of the preceding layer.\n            num_hidden: number of units (channels) in the output tensor.\n            seq_shape: input shape [batch, length, height, width, channels].\n            tln: whether to apply tensor layer normalization.\n            initializer: None or -1 for a uniform initializer with range\n                sqrt(6 / (dim_in + dim_out)); otherwise the half-width of a\n                fixed uniform initializer.\n        \"\"\"\n        self.layer_name = layer_name\n        self.filter_size = filter_size\n        self.num_hidden_in = num_hidden_in\n        self.num_hidden = num_hidden\n        self.batch = seq_shape[0]\n        self.height = seq_shape[2]\n        self.width = seq_shape[3]\n        self.layer_norm = tln\n        self._forget_bias = 1.0\n\n        def w_initializer(dim_in, dim_out):\n            random_range = math.sqrt(6.0 / (dim_in + dim_out))\n            return tf.random_uniform_initializer(-random_range, random_range)\n        if initializer is None or initializer == -1:\n            self.initializer = w_initializer\n        else:\n            # keep the same (dim_in, dim_out) -> initializer interface as w_initializer\n            self.initializer = lambda dim_in, dim_out: tf.random_uniform_initializer(-initializer, initializer)\n\n    def init_state(self):\n        return tf.zeros([self.batch, self.height, self.width, self.num_hidden],\n                        dtype=tf.float32)\n\n    def __call__(self, x, h, c, m):\n        if h is None:\n            h = self.init_state()\n        if c is None:\n            c = self.init_state()\n        if m is None:\n            m = self.init_state()\n\n        with tf.variable_scope(self.layer_name):\n            t_cc = tf.layers.conv2d(\n                h, self.num_hidden*4,\n                self.filter_size, 1, padding='same',\n                
kernel_initializer=self.initializer(self.num_hidden_in, self.num_hidden*4),\n                name='time_state_to_state')\n            s_cc = tf.layers.conv2d(\n                m, self.num_hidden*4,\n                self.filter_size, 1, padding='same',\n                kernel_initializer=self.initializer(self.num_hidden_in, self.num_hidden*4),\n                name='spatio_state_to_state')\n            x_shape_in = x.get_shape().as_list()[-1]\n            x_cc = tf.layers.conv2d(\n                x, self.num_hidden*4,\n                self.filter_size, 1, padding='same',\n                kernel_initializer=self.initializer(x_shape_in, self.num_hidden*4),\n                name='input_to_state')\n            if self.layer_norm:\n                t_cc = tensor_layer_norm(t_cc, 'time_state_to_state')\n                s_cc = tensor_layer_norm(s_cc, 'spatio_state_to_state')\n                x_cc = tensor_layer_norm(x_cc, 'input_to_state')\n\n            i_s, g_s, f_s, o_s = tf.split(s_cc, 4, 3)\n            i_t, g_t, f_t, o_t = tf.split(t_cc, 4, 3)\n            i_x, g_x, f_x, o_x = tf.split(x_cc, 4, 3)\n\n            i = tf.nn.sigmoid(i_x + i_t)\n            i_ = tf.nn.sigmoid(i_x + i_s)\n            g = tf.nn.tanh(g_x + g_t)\n            g_ = tf.nn.tanh(g_x + g_s)\n            f = tf.nn.sigmoid(f_x + f_t + self._forget_bias)\n            f_ = tf.nn.sigmoid(f_x + f_s + self._forget_bias)\n            o = tf.nn.sigmoid(o_x + o_t + o_s)\n            new_m = f_ * m + i_ * g_\n            new_c = f * c + i * g\n            cell = tf.concat([new_c, new_m],3)\n            cell = tf.layers.conv2d(cell, self.num_hidden, 1, 1, padding='same',\n                                    kernel_initializer=self.initializer(self.num_hidden*2, self.num_hidden),\n                                    name='cell_reduce')\n            new_h = o * tf.nn.tanh(cell)\n\n            return new_h, new_c, new_m\n\n\n\n\n\n\n\n"
  },
  {
    "path": "src/layers/TensorLayerNorm.py",
    "content": "import tensorflow as tf\n\nEPSILON = 0.00001\n\n\ndef tensor_layer_norm(x, state_name):\n    \"\"\"Normalize x over all non-batch axes, then apply a learned per-channel offset b and scale s.\"\"\"\n    x_shape = x.get_shape()\n    dims = x_shape.ndims\n    params_shape = x_shape[-1:]\n    if dims == 4:\n        m, v = tf.nn.moments(x, [1,2,3], keep_dims=True)\n    elif dims == 5:\n        m, v = tf.nn.moments(x, [1,2,3,4], keep_dims=True)\n    elif dims == 2:\n        m, v = tf.nn.moments(x, [1], keep_dims=True)\n    else:\n        raise ValueError('input tensor for layer normalization must be rank 2, 4 or 5.')\n    b = tf.get_variable(state_name+'b',initializer=tf.zeros(params_shape))\n    s = tf.get_variable(state_name+'s',initializer=tf.ones(params_shape))\n    x_tln = tf.nn.batch_normalization(x, m, v, b, s, EPSILON)\n    return x_tln\n"
  },
  {
    "path": "src/layers/__init__.py",
    "content": ""
  },
  {
    "path": "src/models/__init__.py",
    "content": ""
  },
  {
    "path": "src/models/mim.py",
    "content": "__author__ = 'jianjin'\n\nimport tensorflow as tf\nfrom src.layers.SpatioTemporalLSTMCellv2 import SpatioTemporalLSTMCell as stlstm\nfrom src.layers.MIMBlock import MIMBlock as mimblock\nfrom src.layers.MIMN import MIMN as mimn\nimport math\n\n\ndef w_initializer(dim_in, dim_out):\n    random_range = math.sqrt(6.0 / (dim_in + dim_out))\n    return tf.random_uniform_initializer(-random_range, random_range)\n\n\ndef mim(images, params, schedual_sampling_bool, num_layers, num_hidden, filter_size,\n        stride=1, total_length=20, input_length=10, tln=True):\n    gen_images = []\n    stlstm_layer = []\n    stlstm_layer_diff = []\n    cell_state = []\n    hidden_state = []\n    cell_state_diff = []\n    hidden_state_diff = []\n    shape = images.get_shape().as_list()\n    output_channels = shape[-1]\n\n    for i in range(num_layers):\n        if i == 0:\n            num_hidden_in = num_hidden[num_layers - 1]\n        else:\n            num_hidden_in = num_hidden[i - 1]\n        if i < 1:\n            new_stlstm_layer = stlstm('stlstm_' + str(i + 1),\n                                      filter_size,\n                                      num_hidden_in,\n                                      num_hidden[i],\n                                      shape,\n                                      tln=tln)\n        else:\n            new_stlstm_layer = mimblock('stlstm_' + str(i + 1),\n                                        filter_size,\n                                        num_hidden_in,\n                                        num_hidden[i],\n                                        shape,\n                                        tln=tln)\n        stlstm_layer.append(new_stlstm_layer)\n        cell_state.append(None)\n        hidden_state.append(None)\n\n    for i in range(num_layers - 1):\n        new_stlstm_layer = mimn('stlstm_diff' + str(i + 1),\n                                filter_size,\n                                num_hidden[i + 1],\n         
                       shape,\n                                tln=tln)\n        stlstm_layer_diff.append(new_stlstm_layer)\n        cell_state_diff.append(None)\n        hidden_state_diff.append(None)\n\n    st_memory = None\n\n    for time_step in range(total_length - 1):\n        reuse = bool(gen_images)\n        with tf.variable_scope('predrnn', reuse=reuse):\n            if time_step < input_length:\n                x_gen = images[:,time_step]\n            else:\n                x_gen = schedual_sampling_bool[:,time_step-input_length]*images[:,time_step] + \\\n                        (1-schedual_sampling_bool[:,time_step-input_length])*x_gen\n            preh = hidden_state[0]\n            hidden_state[0], cell_state[0], st_memory = stlstm_layer[0](\n                x_gen, hidden_state[0], cell_state[0], st_memory)\n            for i in range(1, num_layers):\n                if time_step > 0:\n                    if i == 1:\n                        hidden_state_diff[i - 1], cell_state_diff[i - 1] = stlstm_layer_diff[i - 1](\n                            hidden_state[i - 1] - preh, hidden_state_diff[i - 1], cell_state_diff[i - 1])\n                    else:\n                        hidden_state_diff[i - 1], cell_state_diff[i - 1] = stlstm_layer_diff[i - 1](\n                            hidden_state_diff[i - 2], hidden_state_diff[i - 1], cell_state_diff[i - 1])\n                else:\n                    stlstm_layer_diff[i - 1](tf.zeros_like(hidden_state[i - 1]), None, None)\n                preh = hidden_state[i]\n                hidden_state[i], cell_state[i], st_memory = stlstm_layer[i](\n                    hidden_state[i - 1], hidden_state_diff[i - 1], hidden_state[i], cell_state[i], st_memory)\n            x_gen = tf.layers.conv2d(hidden_state[num_layers - 1],\n                                     filters=output_channels,\n                                     kernel_size=1,\n                                     strides=1,\n                                  
   padding='same',\n                                     kernel_initializer=w_initializer(num_hidden[num_layers - 1], output_channels),\n                                     name=\"back_to_pixel\")\n            gen_images.append(x_gen)\n\n    gen_images = tf.stack(gen_images, axis=1)\n    loss = tf.nn.l2_loss(gen_images - images[:, 1:])\n\n    return [gen_images, loss]\n"
  },
  {
    "path": "src/models/model_factory.py",
    "content": "import os\n\nimport tensorflow as tf\n\nfrom src.utils import optimizer\nfrom src.models import mim\n\n\nclass Model(object):\n    def __init__(self, configs):\n        self.configs = configs\n        # inputs\n        if configs.img_height > 0:\n            height = configs.img_height\n        else:\n            height = configs.img_width\n        self.x = [tf.placeholder(tf.float32,\n                                 [self.configs.batch_size,\n                                  self.configs.total_length,\n                                  self.configs.img_width // self.configs.patch_size,\n                                  height // self.configs.patch_size,\n                                  self.configs.patch_size * self.configs.patch_size * self.configs.img_channel])\n                  for i in range(self.configs.n_gpu)]\n\n        self.real_input_flag = tf.placeholder(tf.float32,\n                                        [self.configs.batch_size,\n                                         self.configs.total_length - self.configs.input_length - 1,\n                                         self.configs.img_width // self.configs.patch_size,\n                                         height // self.configs.patch_size,\n                                         self.configs.patch_size * self.configs.patch_size * self.configs.img_channel])\n\n        grads = []\n        loss_train = []\n        self.pred_seq = []\n        self.tf_lr = tf.placeholder(tf.float32, shape=[])\n        self.params = dict()\n        self.params.update(self.configs.__dict__['__flags'])\n        num_hidden = [int(x) for x in self.configs.num_hidden.split(',')]\n        num_layers = len(num_hidden)\n        for i in range(self.configs.n_gpu):\n            with tf.device('/gpu:%d' % i):\n                with tf.variable_scope(tf.get_variable_scope(),\n                                       reuse=True if i > 0 else None):\n                    # define a model\n                    
output_list = self.construct_model(\n                        self.configs.model_name,\n                        self.x[i],\n                        self.params,\n                        self.real_input_flag,\n                        num_layers,\n                        num_hidden,\n                        self.configs.filter_size,\n                        self.configs.stride,\n                        self.configs.total_length,\n                        self.configs.input_length,\n                        self.configs.layer_norm)\n\n                    gen_ims = output_list[0]\n                    loss = output_list[1]\n                    if len(output_list) > 2:\n                        self.debug = output_list[2]\n                    else:\n                        self.debug = []\n                    pred_ims = gen_ims[:, self.configs.input_length - self.configs.total_length:]\n                    loss_train.append(loss / self.configs.batch_size)\n                    # gradients\n                    all_params = tf.trainable_variables()\n                    grads.append(tf.gradients(loss, all_params))\n                    self.pred_seq.append(pred_ims)\n\n        # add losses and gradients together and get training updates\n        with tf.device('/gpu:0'):\n            for i in range(1, self.configs.n_gpu):\n                loss_train[0] += loss_train[i]\n                for j in range(len(grads[0])):\n                    grads[0][j] += grads[i][j]\n        # keep track of moving average\n        ema = tf.train.ExponentialMovingAverage(decay=0.9995)\n        maintain_averages_op = tf.group(ema.apply(all_params))\n        self.train_op = tf.group(optimizer.adam_updates(\n            all_params, grads[0], lr=self.tf_lr, mom1=0.95, mom2=0.9995),\n            maintain_averages_op)\n\n        self.loss_train = loss_train[0] / self.configs.n_gpu\n\n        # session\n        variables = tf.global_variables()\n        self.saver = tf.train.Saver(variables)\n        init = 
tf.global_variables_initializer()\n        configProt = tf.ConfigProto()\n        configProt.gpu_options.allow_growth = configs.allow_gpu_growth\n        configProt.allow_soft_placement = True\n        self.sess = tf.Session(config=configProt)\n        self.sess.run(init)\n        if self.configs.pretrained_model:\n            self.saver.restore(self.sess, self.configs.pretrained_model)\n\n    def train(self, inputs, lr, real_input_flag):\n        feed_dict = {self.x[i]: inputs[i] for i in range(self.configs.n_gpu)}\n        feed_dict.update({self.tf_lr: lr})\n        feed_dict.update({self.real_input_flag: real_input_flag})\n        loss, _, debug = self.sess.run((self.loss_train, self.train_op, self.debug), feed_dict)\n        return loss\n\n    def test(self, inputs, real_input_flag):\n        feed_dict = {self.x[i]: inputs[i] for i in range(self.configs.n_gpu)}\n        feed_dict.update({self.real_input_flag: real_input_flag})\n        gen_ims, debug = self.sess.run((self.pred_seq, self.debug), feed_dict)\n        return gen_ims, debug\n\n    def save(self, itr):\n        checkpoint_path = os.path.join(self.configs.save_dir, 'model.ckpt')\n        self.saver.save(self.sess, checkpoint_path, global_step=itr)\n        print('saved to ' + self.configs.save_dir)\n\n    def load(self, checkpoint_path):\n        print('load model:', checkpoint_path)\n        self.saver.restore(self.sess, checkpoint_path)\n\n    def construct_model(self, name, images, model_params, real_input_flag, num_layers, num_hidden,\n                        filter_size, stride, total_length, input_length, tln):\n        '''Build the network called `name` and return its generated frames and loss.\n        Args:\n            name: network name; only 'mim' is registered.\n            images: input sequence, including both input and target frames.\n            model_params: dict of extra parameters merged into params.\n            real_input_flag: mask for scheduled sampling.\n            num_layers: number of recurrent layers.\n            num_hidden: list with the number of units in each recurrent layer.\n            filter_size: kernel size of the convolutions inside the recurrent cells.\n            stride: passed through to the network (not used by mim).\n            total_length: number of frames, including inputs and outputs.\n            input_length: number of input frames.\n            tln: whether to apply tensor layer normalization.\n        Returns:\n            gen_images: the generated frame sequence.\n            loss: l2 loss (tf.nn.l2_loss) between generated and ground-truth frames.\n        Raises:\n            ValueError: If network `name` is not recognized.\n        '''\n\n        networks_map = {\n            'mim': mim.mim,\n        }\n\n        params = dict(mask=real_input_flag, num_layers=num_layers, num_hidden=num_hidden, filter_size=filter_size,\n                      stride=stride, total_length=total_length, input_length=input_length, is_training=True)\n        params.update(model_params)\n        if name in networks_map:\n            func = networks_map[name]\n            return func(images, params, real_input_flag, num_layers, num_hidden, filter_size,\n                        stride, total_length, input_length, tln)\n        else:\n            raise ValueError('Unknown network name: %s' % name)\n"
  },
  {
    "path": "src/trainer.py",
    "content": "import os.path\nimport datetime\nimport cv2\nimport numpy as np\nfrom skimage.measure import compare_ssim\nfrom src.utils import metrics\nfrom src.utils import preprocess\n\n\ndef train(model, ims, real_input_flag, configs, itr, ims_reverse=None):\n    ims = ims[:, :configs.total_length]\n    ims_list = np.split(ims, configs.n_gpu)\n    cost = model.train(ims_list, configs.lr, real_input_flag)\n\n    flag = 1\n\n    if configs.reverse_img:\n        ims_rev = np.split(ims_reverse, configs.n_gpu)\n        cost += model.train(ims_rev, configs.lr, real_input_flag)\n        flag += 1\n\n    if configs.reverse_input:\n        ims_rev = np.split(ims[:, ::-1], configs.n_gpu)\n        cost += model.train(ims_rev, configs.lr, real_input_flag)\n        flag += 1\n        if configs.reverse_img:\n            ims_rev = np.split(ims_reverse[:, ::-1], configs.n_gpu)\n            cost += model.train(ims_rev, configs.lr, real_input_flag)\n            flag += 1\n\n    cost = cost / flag\n\n    if itr % configs.display_interval == 0:\n        print(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'itr: ' + str(itr))\n        print('training loss: ' + str(cost))\n\n\ndef test(model, test_input_handle, configs, save_name):\n    print(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'test...')\n    test_input_handle.begin(do_shuffle=False)\n    res_path = os.path.join(configs.gen_frm_dir, str(save_name))\n    os.mkdir(res_path)\n    avg_mse = 0\n    batch_id = 0\n    img_mse, ssim, psnr, fmae, sharp = [], [], [], [], []\n\n    for i in range(configs.total_length - configs.input_length):\n        img_mse.append(0)\n        ssim.append(0)\n        psnr.append(0)\n        fmae.append(0)\n        sharp.append(0)\n\n    if configs.img_height > 0:\n        height = configs.img_height\n    else:\n        height = configs.img_width\n\n    real_input_flag = np.zeros(\n        (configs.batch_size,\n         configs.total_length - configs.input_length - 1,\n         
configs.img_width // configs.patch_size,\n         height // configs.patch_size,\n         configs.patch_size ** 2 * configs.img_channel))\n\n    while not test_input_handle.no_batch_left():\n        # outside the final test run, evaluate at most 100 batches; check before\n        # incrementing so that batch_id counts only the batches actually processed\n        if save_name != 'test_result' and batch_id >= 100:\n            break\n        batch_id = batch_id + 1\n        test_ims = test_input_handle.get_batch()\n        test_ims = test_ims[:, :configs.total_length]\n        if len(test_ims.shape) > 3:\n            test_dat = preprocess.reshape_patch(test_ims, configs.patch_size)\n        else:\n            test_dat = test_ims\n        test_dat = np.split(test_dat, configs.n_gpu)\n        img_gen, debug = model.test(test_dat, real_input_flag)\n\n        # concat outputs of different gpus along batch\n        img_gen = np.concatenate(img_gen)\n        if len(img_gen.shape) > 3:\n            img_gen = preprocess.reshape_patch_back(img_gen, configs.patch_size)\n        # per-frame metrics: MAE, MSE, PSNR, sharpness and SSIM\n        for i in range(configs.total_length - configs.input_length):\n            x = test_ims[:, i + configs.input_length, :, :, :]\n            x = x[:configs.batch_size * configs.n_gpu]\n            x = x - np.where(x > 10000, np.floor_divide(x, 10000) * 10000, np.zeros_like(x))\n            gx = img_gen[:, i, :, :, :]\n            fmae[i] += metrics.batch_mae_frame_float(gx, x)\n            gx = np.maximum(gx, 0)\n            gx = np.minimum(gx, 1)\n            mse = np.square(x - gx).sum()\n            img_mse[i] += mse\n            avg_mse += mse\n            real_frm = np.uint8(x * 255)\n            pred_frm = np.uint8(gx * 255)\n            psnr[i] += metrics.batch_psnr(pred_frm, real_frm)\n            for b in range(configs.batch_size):\n                sharp[i] += np.max(\n                    cv2.convertScaleAbs(cv2.Laplacian(pred_frm[b], 3)))\n\n                score, _ = compare_ssim(gx[b], x[b], full=True, multichannel=True)\n                ssim[i] += score\n\n        # save prediction examples\n        if batch_id <= 
configs.num_save_samples:\n            path = os.path.join(res_path, str(batch_id))\n            os.mkdir(path)\n            if len(debug) != 0:\n                np.save(os.path.join(path, \"f.npy\"), debug)\n            for i in range(configs.total_length):\n                name = 'gt' + str(i + 1) + '.png'\n                file_name = os.path.join(path, name)\n                img_gt = np.uint8(test_ims[0, i, :, :, :] * 255)\n                if configs.img_channel == 2:\n                    img_gt = img_gt[:, :, :1]\n                cv2.imwrite(file_name, img_gt)\n            for i in range(configs.total_length - configs.input_length):\n                name = 'pd' + str(i + 1 + configs.input_length) + '.png'\n                file_name = os.path.join(path, name)\n                img_pd = img_gen[0, i, :, :, :]\n                if configs.img_channel == 2:\n                    img_pd = img_pd[:, :, :1]\n                img_pd = np.maximum(img_pd, 0)\n                img_pd = np.minimum(img_pd, 1)\n                img_pd = np.uint8(img_pd * 255)\n                cv2.imwrite(file_name, img_pd)\n        test_input_handle.next()\n\n    avg_mse = avg_mse / (batch_id * configs.batch_size * configs.n_gpu)\n    print('mse per seq: ' + str(avg_mse))\n    for i in range(configs.total_length - configs.input_length):\n        print(img_mse[i] / (batch_id * configs.batch_size * configs.n_gpu))\n\n    psnr = np.asarray(psnr, dtype=np.float32) / batch_id\n    fmae = np.asarray(fmae, dtype=np.float32) / batch_id\n    ssim = np.asarray(ssim, dtype=np.float32) / (configs.batch_size * batch_id)\n    sharp = np.asarray(sharp, dtype=np.float32) / (configs.batch_size * batch_id)\n\n    print('psnr per frame: ' + str(np.mean(psnr)))\n    for i in range(configs.total_length - configs.input_length):\n        print(psnr[i])\n    print('fmae per frame: ' + str(np.mean(fmae)))\n    for i in range(configs.total_length - configs.input_length):\n        print(fmae[i])\n    print('ssim per frame: 
' + str(np.mean(ssim)))\n    for i in range(configs.total_length - configs.input_length):\n        print(ssim[i])\n    print('sharpness per frame: ' + str(np.mean(sharp)))\n    for i in range(configs.total_length - configs.input_length):\n        print(sharp[i])\n"
  },
  {
    "path": "src/utils/__init__.py",
    "content": ""
  },
  {
    "path": "src/utils/metrics.py",
    "content": "__author__ = 'yunbo'\n\nimport numpy as np\n\n\ndef batch_mae_frame_float(gen_frames, gt_frames):\n    # [batch, height, width] or [batch, height, width, channel]\n    # per-frame sum of absolute errors, averaged over the batch\n    if gen_frames.ndim == 3:\n        axis = (1, 2)\n    elif gen_frames.ndim == 4:\n        axis = (1, 2, 3)\n    else:\n        raise ValueError('expected 3-D or 4-D frames, got %d-D' % gen_frames.ndim)\n    x = np.float32(gen_frames)\n    y = np.float32(gt_frames)\n    mae = np.sum(np.absolute(x - y), axis=axis, dtype=np.float32)\n    return np.mean(mae)\n\n\ndef batch_psnr(gen_frames, gt_frames):\n    # [batch, height, width] or [batch, height, width, channel], uint8 values in [0, 255]\n    if gen_frames.ndim == 3:\n        axis = (1, 2)\n    elif gen_frames.ndim == 4:\n        axis = (1, 2, 3)\n    else:\n        raise ValueError('expected 3-D or 4-D frames, got %d-D' % gen_frames.ndim)\n    # widen to int32 so that uint8 differences do not wrap around\n    x = np.int32(gen_frames)\n    y = np.int32(gt_frames)\n    num_pixels = float(np.size(gen_frames[0]))\n    mse = np.sum((x - y) ** 2, axis=axis, dtype=np.float32) / num_pixels\n    # identical frames give mse == 0 and therefore an infinite PSNR\n    psnr = 20 * np.log10(255) - 10 * np.log10(mse)\n    return np.mean(psnr)\n"
  },
  {
    "path": "src/utils/optimizer.py",
    "content": "import tensorflow as tf\n\n\ndef adam_updates(params, cost_or_grads, lr=0.001, mom1=0.9, mom2=0.999):\n    '''Builds an op that applies one Adam step to `params`.\n\n    `cost_or_grads` is either a scalar cost or a list of gradients aligned with `params`.\n    '''\n    updates = []\n    if not isinstance(cost_or_grads, list):\n        grads = tf.gradients(cost_or_grads, params)\n    else:\n        grads = cost_or_grads\n    # step counter for bias correction (optimizer state, so not trainable)\n    t = tf.Variable(1., trainable=False)\n    for p, g in zip(params, grads):\n        # running average of squared gradients (second moment)\n        mg = tf.Variable(tf.zeros(p.get_shape()), trainable=False)\n        if mom1 > 0:\n            # running average of gradients (first moment)\n            v = tf.Variable(tf.zeros(p.get_shape()), trainable=False)\n            v_t = mom1 * v + (1. - mom1) * g\n            v_hat = v_t / (1. - tf.pow(mom1, t))\n            updates.append(v.assign(v_t))\n        else:\n            v_hat = g\n        mg_t = mom2 * mg + (1. - mom2) * tf.square(g)\n        mg_hat = mg_t / (1. - tf.pow(mom2, t))\n        g_t = v_hat / tf.sqrt(mg_hat + 1e-8)\n        p_t = p - lr * g_t\n        updates.append(mg.assign(mg_t))\n        updates.append(p.assign(p_t))\n    updates.append(t.assign_add(1))\n    return tf.group(*updates)\n\n"
  },
  {
    "path": "src/utils/preprocess.py",
    "content": "__author__ = 'yunbo'\n\nimport numpy as np\n\n\ndef reshape_patch(img_tensor, patch_size):\n    # [batch, seq, height, width, channels] ->\n    # [batch, seq, height/p, width/p, p*p*channels]: each p x p patch is folded into channels\n    assert 5 == img_tensor.ndim\n    batch_size = np.shape(img_tensor)[0]\n    seq_length = np.shape(img_tensor)[1]\n    img_height = np.shape(img_tensor)[2]\n    img_width = np.shape(img_tensor)[3]\n    num_channels = np.shape(img_tensor)[4]\n    a = np.reshape(img_tensor, [batch_size, seq_length,\n                                img_height//patch_size, patch_size,\n                                img_width//patch_size, patch_size,\n                                num_channels])\n    b = np.transpose(a, [0,1,2,4,3,5,6])\n    patch_tensor = np.reshape(b, [batch_size, seq_length,\n                                  img_height//patch_size,\n                                  img_width//patch_size,\n                                  patch_size*patch_size*num_channels])\n    return patch_tensor\n\n\ndef reshape_patch_back(patch_tensor, patch_size):\n    # inverse of reshape_patch: unfold the channel axis back into p x p patches\n    assert 5 == patch_tensor.ndim\n    batch_size = np.shape(patch_tensor)[0]\n    seq_length = np.shape(patch_tensor)[1]\n    patch_height = np.shape(patch_tensor)[2]\n    patch_width = np.shape(patch_tensor)[3]\n    channels = np.shape(patch_tensor)[4]\n    img_channels = channels // (patch_size*patch_size)\n    a = np.reshape(patch_tensor, [batch_size, seq_length,\n                                  patch_height, patch_width,\n                                  patch_size, patch_size,\n                                  img_channels])\n    b = np.transpose(a, [0,1,2,4,3,5,6])\n    img_tensor = np.reshape(b, [batch_size, seq_length,\n                                patch_height * patch_size,\n                                patch_width * patch_size,\n                                img_channels])\n    return img_tensor\n\n"
  }
]
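The patch folding performed by `reshape_patch` and `reshape_patch_back` in `src/utils/preprocess.py` can be checked in isolation. The sketch below is a standalone NumPy re-statement that uses the same reshape/transpose sequence; `to_patches` and `from_patches` are hypothetical names, not functions in the repo. It uses a non-square clip so that height and width cannot be confused.

```python
import numpy as np


def to_patches(frames, p):
    """[B, T, H, W, C] -> [B, T, H/p, W/p, p*p*C], same steps as reshape_patch."""
    b, t, h, w, c = frames.shape
    assert h % p == 0 and w % p == 0, "H and W must be divisible by the patch size"
    a = frames.reshape(b, t, h // p, p, w // p, p, c)
    # move the two in-patch axes next to each other, then fold them into channels
    return a.transpose(0, 1, 2, 4, 3, 5, 6).reshape(b, t, h // p, w // p, p * p * c)


def from_patches(patches, p):
    """Inverse of to_patches, same steps as reshape_patch_back."""
    b, t, hp, wp, cp = patches.shape
    c = cp // (p * p)
    a = patches.reshape(b, t, hp, wp, p, p, c)
    return a.transpose(0, 1, 2, 4, 3, 5, 6).reshape(b, t, hp * p, wp * p, c)


frames = np.random.rand(2, 20, 64, 48, 1).astype(np.float32)  # non-square on purpose
patches = to_patches(frames, 4)
print(patches.shape)  # (2, 20, 16, 12, 16)
assert np.array_equal(from_patches(patches, 4), frames)
```

Because the recurrent cells then see `p*p*C` channels on a `p`-times smaller grid, the same convolution kernel covers a larger area of the original frame.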
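`batch_psnr` in `src/utils/metrics.py` computes PSNR per frame as 20·log10(255) − 10·log10(MSE) and averages over the batch. The sketch below is a self-contained re-statement of that computation with a worked value: a constant error of 5 grey levels gives MSE = 25 and so about 48.13 − 13.98 ≈ 34.15 dB. It takes the mean directly instead of summing and dividing by the pixel count. The frame sizes are illustrative assumptions.

```python
import numpy as np


def batch_psnr(gen_frames, gt_frames):
    """Mean PSNR in dB over a batch of uint8 frames, following src/utils/metrics.py."""
    x = gen_frames.astype(np.int32)  # widen first so uint8 differences cannot wrap
    y = gt_frames.astype(np.int32)
    axes = tuple(range(1, x.ndim))  # every axis except the batch axis
    mse = np.mean((x - y) ** 2, axis=axes, dtype=np.float64)
    psnr = 20 * np.log10(255.0) - 10 * np.log10(mse)
    return float(np.mean(psnr))


gt = np.full((4, 64, 64, 1), 100, dtype=np.uint8)
pred = gt + np.uint8(5)  # every pixel off by 5, so MSE = 25
print(round(batch_psnr(pred, gt), 2))  # 34.15
```

As in the repo's version, identical frames give an MSE of 0 and hence an infinite PSNR (NumPy also warns about the division by zero inside `log10`).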
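`adam_updates` in `src/utils/optimizer.py` builds graph ops for a hand-written Adam step. The sketch below replays one step of the same update rule in plain NumPy so the arithmetic can be inspected without TensorFlow. It covers only the `mom1 > 0` branch, and `adam_step` is a hypothetical helper, not part of the repo. On the first step the bias corrections cancel the moment decay, so each weight moves by roughly `lr * sign(g)`.

```python
import numpy as np


def adam_step(p, g, v, mg, t, lr=0.001, mom1=0.9, mom2=0.999):
    """One update in the style of adam_updates; returns (p, v, mg, t + 1)."""
    v = mom1 * v + (1.0 - mom1) * g           # first moment (running mean of g)
    mg = mom2 * mg + (1.0 - mom2) * g ** 2    # second moment (running mean of g^2)
    v_hat = v / (1.0 - mom1 ** t)             # bias corrections at step t
    mg_hat = mg / (1.0 - mom2 ** t)
    p = p - lr * v_hat / np.sqrt(mg_hat + 1e-8)  # epsilon sits inside the sqrt here
    return p, v, mg, t + 1


p = np.array([1.0, -1.0])
g = np.array([2.0, -0.5])
v = np.zeros_like(p)
mg = np.zeros_like(p)
p, v, mg, t = adam_step(p, g, v, mg, t=1, lr=0.1)
print(p)  # approximately [0.9, -0.9]
```

Placing epsilon inside the square root, as `adam_updates` does, differs slightly from the common form `v_hat / (sqrt(mg_hat) + eps)`; the two only diverge noticeably when `mg_hat` is close to zero.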