[
  {
    "path": "LICENSE",
    "content": "\nMIT License\n\nCopyright (c) 2020 Virginia Tech Vision and Learning Lab\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n\n--------------------------- LICENSE FOR EdgeConnect --------------------------------\n\nAttribution-NonCommercial 4.0 International\n"
  },
  {
    "path": "README.md",
    "content": "# Dynamic View Synthesis from Dynamic Monocular Video\n\n[![arXiv](https://img.shields.io/badge/arXiv-2108.00946-b31b1b.svg)](https://arxiv.org/abs/2105.06468)\n\n[Project Website](https://free-view-video.github.io/) | [Video](https://youtu.be/j8CUzIR0f8M) | [Paper](https://arxiv.org/abs/2105.06468)\n\n> **Dynamic View Synthesis from Dynamic Monocular Video**<br>\n> [Chen Gao](http://chengao.vision), [Ayush Saraf](#), [Johannes Kopf](https://johanneskopf.de/), [Jia-Bin Huang](https://filebox.ece.vt.edu/~jbhuang/) <br>\nin ICCV 2021 <br>\n\n## Setup\nThe code is test with\n* Linux (tested on CentOS Linux release 7.4.1708)\n* Anaconda 3\n* Python 3.7.11\n* CUDA 10.1\n* 1 V100 GPU\n\n\nTo get started, please create the conda environment `dnerf` by running\n```\nconda create --name dnerf python=3.7\nconda activate dnerf\nconda install pytorch=1.6.0 torchvision=0.7.0 cudatoolkit=10.1 matplotlib tensorboard scipy opencv -c pytorch\npip install imageio scikit-image configargparse timm lpips\n```\nand install [COLMAP](https://colmap.github.io/install.html) manually. Then download MiDaS and RAFT weights\n```\nROOT_PATH=/path/to/the/DynamicNeRF/folder\ncd $ROOT_PATH\nwget --no-check-certificate https://filebox.ece.vt.edu/~chengao/free-view-video/weights.zip\nunzip weights.zip\nrm weights.zip\n```\n\n## Dynamic Scene Dataset\nThe [Dynamic Scene Dataset](https://www-users.cse.umn.edu/~jsyoon/dynamic_synth/) is used to\nquantitatively evaluate our method. Please download the pre-processed data by running:\n```\ncd $ROOT_PATH\nwget --no-check-certificate https://filebox.ece.vt.edu/~chengao/free-view-video/data.zip\nunzip data.zip\nrm data.zip\n```\n\n### Training\nYou can train a model from scratch by running:\n```\ncd $ROOT_PATH/\npython run_nerf.py --config configs/config_Balloon2.txt\n```\n\nEvery 100k iterations, you should get videos like the following examples\n\nThe novel view-time synthesis results will be saved in `$ROOT_PATH/logs/Balloon2_H270_DyNeRF/novelviewtime`.\n![novelviewtime](https://filebox.ece.vt.edu/~chengao/free-view-video/gif/novelviewtime_Balloon2.gif)\n<!-- <img src=\"https://filebox.ece.vt.edu/~chengao/free-view-video/gif/novelviewtime.gif\" height=\"270\" /> -->\n\nThe reconstruction results will be saved in `$ROOT_PATH/logs/Balloon2_H270_DyNeRF/testset`.\n![testset](https://filebox.ece.vt.edu/~chengao/free-view-video/gif/testset_Balloon2.gif)\n\nThe fix-view-change-time results will be saved in `$ROOT_PATH/logs/Balloon2_H270_DyNeRF/testset_view000`.\n![testset_view000](https://filebox.ece.vt.edu/~chengao/free-view-video/gif/testset_view000_Balloon2.gif)\n\nThe fix-time-change-view results will be saved in `$ROOT_PATH/logs/Balloon2_H270_DyNeRF/testset_time000`.\n![testset_time000](https://filebox.ece.vt.edu/~chengao/free-view-video/gif/testset_time000_Balloon2.gif)\n\n\n### Rendering from pre-trained models\nWe also provide pre-trained models. You can download them by running:\n```\ncd $ROOT_PATH/\nwget --no-check-certificate https://filebox.ece.vt.edu/~chengao/free-view-video/logs.zip\nunzip logs.zip\nrm logs.zip\n```\n\nThen you can render the results directly by running:\n```\npython run_nerf.py --config configs/config_Balloon2.txt --render_only --ft_path $ROOT_PATH/logs/Balloon2_H270_DyNeRF_pretrain/300000.tar\n```\n\n### Evaluating our method and others\nOur goal is to make the evaluation as simple as possible for you. 
We have collected the fix-view-change-time results of the following methods:\n\n`NeRF` \\\n`NeRF + t` \\\n`Yoon et al.` \\\n`Non-Rigid NeRF` \\\n`NSFF` \\\n`DynamicNeRF (ours)`\n\nPlease download the results by running:\n```\ncd $ROOT_PATH/\nwget --no-check-certificate https://filebox.ece.vt.edu/~chengao/free-view-video/results.zip\nunzip results.zip\nrm results.zip\n```\n\nThen you can calculate the PSNR/SSIM/LPIPS by running:\n```\ncd $ROOT_PATH/utils\npython evaluation.py\n```\n\n| PSNR / LPIPS |    Jumping    |    Skating    |     Truck     |    Umbrella   |    Balloon1   |    Balloon2   |   Playground  |    Average    |\n|:-------------|:-------------:|:-------------:|:-------------:|:-------------:|:-------------:|:-------------:|:-------------:|:-------------:|\n| NeRF         | 20.99 / 0.305 | 23.67 / 0.311 | 22.73 / 0.229 | 21.29 / 0.440 | 19.82 / 0.205 | 24.37 / 0.098 | 21.07 / 0.165 | 21.99 / 0.250 |\n| NeRF + t     | 18.04 / 0.455 | 20.32 / 0.512 | 18.33 / 0.382 | 17.69 / 0.728 | 18.54 / 0.275 | 20.69 / 0.216 | 14.68 / 0.421 | 18.33 / 0.427 |\n| NR NeRF      | 20.09 / 0.287 | 23.95 / 0.227 | 19.33 / 0.446 | 19.63 / 0.421 | 17.39 / 0.348 | 22.41 / 0.213 | 15.06 / 0.317 | 19.69 / 0.323 |\n| NSFF         | 24.65 / 0.151 | 29.29 / 0.129 | 25.96 / 0.167 | 22.97 / 0.295 | 21.96 / 0.215 | 24.27 / 0.222 | 21.22 / 0.212 | 24.33 / 0.199 |\n| Ours         | 24.68 / 0.090 | 32.66 / 0.035 | 28.56 / 0.082 | 23.26 / 0.137 | 22.36 / 0.104 | 27.06 / 0.049 | 24.15 / 0.080 | 26.10 / 0.082 |\n\n\nPlease note:\n1. The numbers reported in the paper are calculated using TF code. The numbers here are calculated using this improved Pytorch version.\n2. In Yoon's results, the first frame and the last frame are missing. To compare with Yoon's results, we have to omit the first frame and the last frame. To do so, please uncomment line 72 and comment line 73 in `evaluation.py`.\n3. We obtain the results of NSFF and NR NeRF using the official implementation with default parameters.\n\n\n## Train a model on your sequence\n0. Set some paths\n\n```\nROOT_PATH=/path/to/the/DynamicNeRF/folder\nDATASET_NAME=name_of_the_video_without_extension\nDATASET_PATH=$ROOT_PATH/data/$DATASET_NAME\n```\n\n1. Prepare training images and background masks from a video.\n\n```\ncd $ROOT_PATH/utils\npython generate_data.py --videopath /path/to/the/video\n```\n\n2. Use COLMAP to obtain camera poses.\n\n```\ncolmap feature_extractor \\\n--database_path $DATASET_PATH/database.db \\\n--image_path $DATASET_PATH/images_colmap \\\n--ImageReader.mask_path $DATASET_PATH/background_mask \\\n--ImageReader.single_camera 1\n\ncolmap exhaustive_matcher \\\n--database_path $DATASET_PATH/database.db\n\nmkdir $DATASET_PATH/sparse\ncolmap mapper \\\n    --database_path $DATASET_PATH/database.db \\\n    --image_path $DATASET_PATH/images_colmap \\\n    --output_path $DATASET_PATH/sparse \\\n    --Mapper.num_threads 16 \\\n    --Mapper.init_min_tri_angle 4 \\\n    --Mapper.multiple_models 0 \\\n    --Mapper.extract_colors 0\n```\n\n3. Save camera poses into the format that NeRF reads.\n\n```\ncd $ROOT_PATH/utils\npython generate_pose.py --dataset_path $DATASET_PATH\n```\n\n4. Estimate monocular depth.\n\n```\ncd $ROOT_PATH/utils\npython generate_depth.py --dataset_path $DATASET_PATH --model $ROOT_PATH/weights/midas_v21-f6b98070.pt\n```\n\n5. Predict optical flows.\n\n```\ncd $ROOT_PATH/utils\npython generate_flow.py --dataset_path $DATASET_PATH --model $ROOT_PATH/weights/raft-things.pth\n```\n\n6. 
Obtain motion mask (code adapted from NSFF).\n\n```\ncd $ROOT_PATH/utils\npython generate_motion_mask.py --dataset_path $DATASET_PATH\n```\n\n7. Train a model. Please change `expname` and `datadir` in `configs/config.txt`.\n\n```\ncd $ROOT_PATH/\npython run_nerf.py --config configs/config.txt\n```\n\nExplanation of each parameter:\n\n- `expname`: experiment name\n- `basedir`: where to store ckpts and logs\n- `datadir`: input data directory\n- `factor`: downsample factor for the input images\n- `N_rand`: number of random rays per gradient step\n- `N_samples`: number of samples per ray\n- `netwidth`: channels per layer\n- `use_viewdirs`: whether to enable view-dependency for StaticNeRF\n- `use_viewdirsDyn`: whether to enable view-dependency for DynamicNeRF\n- `raw_noise_std`: std dev of noise added to regularize sigma_a output\n- `no_ndc`: do not use normalized device coordinates\n- `lindisp`: sampling linearly in disparity rather than depth\n- `i_video`: frequency of novel view-time synthesis video saving\n- `i_testset`: frequency of testset video saving\n- `N_iters`: number of training iterations\n- `i_img`: frequency of tensorboard image logging\n- `DyNeRF_blending`: whether to use DynamicNeRF to predict the blending weight\n- `pretrain`: whether to pre-train StaticNeRF\n\n## License\nThis work is licensed under the MIT License. See [LICENSE](LICENSE) for details.\n\nIf you find this code useful for your research, please consider citing the following paper:\n\n\t@inproceedings{Gao-ICCV-DynNeRF,\n\t    author    = {Gao, Chen and Saraf, Ayush and Kopf, Johannes and Huang, Jia-Bin},\n\t    title     = {Dynamic View Synthesis from Dynamic Monocular Video},\n\t    booktitle = {Proceedings of the IEEE International Conference on Computer Vision},\n\t    year      = {2021}\n\t}\n\n## Acknowledgments\nOur training code is built upon\n[NeRF](https://github.com/bmild/nerf),\n[NeRF-pytorch](https://github.com/yenchenlin/nerf-pytorch), and\n[NSFF](https://github.com/zl548/Neural-Scene-Flow-Fields).\nOur flow prediction code is modified from [RAFT](https://github.com/princeton-vl/RAFT).\nOur depth prediction code is modified from [MiDaS](https://github.com/isl-org/MiDaS).\n"
  },
  {
    "path": "configs/config.txt",
    "content": "expname = xxxxxx_DyNeRF_pretrain_test\nbasedir = ./logs\ndatadir = ./data/xxxxxx/\n\ndataset_type = llff\n\nfactor = 4\nN_rand = 1024\nN_samples = 64\nnetwidth = 256\n\ni_video = 100000\ni_testset = 100000\nN_iters = 500001\ni_img = 500\n\nuse_viewdirs = True\nuse_viewdirsDyn = True\nraw_noise_std = 1e0\nno_ndc = False\nlindisp = False\n\ndynamic_loss_lambda = 1.0\nstatic_loss_lambda = 1.0\nfull_loss_lambda = 3.0\ndepth_loss_lambda = 0.04\norder_loss_lambda = 0.1\nflow_loss_lambda = 0.02\nslow_loss_lambda = 0.01\nsmooth_loss_lambda = 0.1\nconsistency_loss_lambda = 1.0\nmask_loss_lambda = 0.01\nsparse_loss_lambda = 0.001\nDyNeRF_blending = True\npretrain = True\n"
  },
  {
    "path": "configs/config_Balloon1.txt",
    "content": "expname = Balloon1_H270_DyNeRF_pretrain\nbasedir = ./logs\ndatadir = ./data/Balloon1/\n\ndataset_type = llff\n\nfactor = 2\nN_rand = 1024\nN_samples = 64\nN_importance = 0\nnetwidth = 256\n\ni_video = 100000\ni_testset = 100000\nN_iters = 300001\ni_img = 500\n\nuse_viewdirs = True\nuse_viewdirsDyn = False\nraw_noise_std = 1e0\nno_ndc = False\nlindisp = False\n\ndynamic_loss_lambda = 1.0\nstatic_loss_lambda = 1.0\nfull_loss_lambda = 3.0\ndepth_loss_lambda = 0.04\norder_loss_lambda = 0.1\nflow_loss_lambda = 0.02\nslow_loss_lambda = 0.01\nsmooth_loss_lambda = 0.1\nconsistency_loss_lambda = 1.0\nmask_loss_lambda = 0.1\nsparse_loss_lambda = 0.001\nDyNeRF_blending = True\npretrain = True\n"
  },
  {
    "path": "configs/config_Balloon2.txt",
    "content": "expname = Balloon2_H270_DyNeRF_pretrain\nbasedir = ./logs\ndatadir = ./data/Balloon2/\n\ndataset_type = llff\n\nfactor = 2\nN_rand = 1024\nN_samples = 64\nN_importance = 0\nnetwidth = 256\n\ni_video = 100000\ni_testset = 100000\nN_iters = 300001\ni_img = 500\n\nuse_viewdirs = True\nuse_viewdirsDyn = True\nraw_noise_std = 1e0\nno_ndc = False\nlindisp = False\n\ndynamic_loss_lambda = 1.0\nstatic_loss_lambda = 1.0\nfull_loss_lambda = 3.0\ndepth_loss_lambda = 0.04\norder_loss_lambda = 0.1\nflow_loss_lambda = 0.02\nslow_loss_lambda = 0.01\nsmooth_loss_lambda = 0.1\nconsistency_loss_lambda = 1.0\nmask_loss_lambda = 0.1\nsparse_loss_lambda = 0.001\nDyNeRF_blending = True\npretrain = True\n"
  },
  {
    "path": "configs/config_Jumping.txt",
    "content": "expname = Jumping_H270_DyNeRF_pretrain\nbasedir = ./logs\ndatadir = ./data/Jumping/\n\ndataset_type = llff\n\nfactor = 2\nN_rand = 1024\nN_samples = 64\nN_importance = 0\nnetwidth = 256\n\ni_video = 100000\ni_testset = 100000\nN_iters = 300001\ni_img = 500\n\nuse_viewdirs = True\nuse_viewdirsDyn = False\nraw_noise_std = 1e0\nno_ndc = False\nlindisp = False\n\ndynamic_loss_lambda = 1.0\nstatic_loss_lambda = 1.0\nfull_loss_lambda = 3.0\ndepth_loss_lambda = 0.04\norder_loss_lambda = 0.1\nflow_loss_lambda = 0.02\nslow_loss_lambda = 0.01\nsmooth_loss_lambda = 0.1\nconsistency_loss_lambda = 1.0\nmask_loss_lambda = 0.1\nsparse_loss_lambda = 0.001\nDyNeRF_blending = True\npretrain = True\n"
  },
  {
    "path": "configs/config_Playground.txt",
    "content": "expname = Playground_H270_DyNeRF_pretrain\nbasedir = ./logs\ndatadir = ./data/Playground/\n\ndataset_type = llff\n\nfactor = 2\nN_rand = 1024\nN_samples = 64\nN_importance = 0\nnetwidth = 256\n\ni_video = 100000\ni_testset = 100000\nN_iters = 300001\ni_img = 500\n\nuse_viewdirs = True\nuse_viewdirsDyn = True\nraw_noise_std = 1e0\nno_ndc = False\nlindisp = False\n\ndynamic_loss_lambda = 1.0\nstatic_loss_lambda = 1.0\nfull_loss_lambda = 3.0\ndepth_loss_lambda = 0.04\norder_loss_lambda = 0.1\nflow_loss_lambda = 0.02\nslow_loss_lambda = 0.01\nsmooth_loss_lambda = 0.1\nconsistency_loss_lambda = 1.0\nmask_loss_lambda = 0.1\nsparse_loss_lambda = 0.001\nDyNeRF_blending = True\npretrain = True\n"
  },
  {
    "path": "configs/config_Skating.txt",
    "content": "expname = Skating_H270_DyNeRF_pretrain\nbasedir = ./logs\ndatadir = ./data/Skating/\n\ndataset_type = llff\n\nfactor = 2\nN_rand = 1024\nN_samples = 64\nN_importance = 0\nnetwidth = 256\n\ni_video = 100000\ni_testset = 100000\nN_iters = 300001\ni_img = 500\n\nuse_viewdirs = True\nuse_viewdirsDyn = True\nraw_noise_std = 1e0\nno_ndc = False\nlindisp = False\n\ndynamic_loss_lambda = 1.0\nstatic_loss_lambda = 1.0\nfull_loss_lambda = 3.0\ndepth_loss_lambda = 0.04\norder_loss_lambda = 0.1\nflow_loss_lambda = 0.02\nslow_loss_lambda = 0.01\nsmooth_loss_lambda = 0.1\nconsistency_loss_lambda = 1.0\nmask_loss_lambda = 0.1\nsparse_loss_lambda = 0.001\nDyNeRF_blending = True\npretrain = True\n"
  },
  {
    "path": "configs/config_Truck.txt",
    "content": "expname = Truck_H270_DyNeRF_pretrain\nbasedir = ./logs\ndatadir = ./data/Truck/\n\ndataset_type = llff\n\nfactor = 2\nN_rand = 1024\nN_samples = 64\nN_importance = 0\nnetwidth = 256\n\ni_video = 100000\ni_testset = 100000\nN_iters = 300001\ni_img = 500\n\nuse_viewdirs = True\nuse_viewdirsDyn = True\nraw_noise_std = 1e0\nno_ndc = False\nlindisp = False\n\ndynamic_loss_lambda = 1.0\nstatic_loss_lambda = 1.0\nfull_loss_lambda = 3.0\ndepth_loss_lambda = 0.04\norder_loss_lambda = 0.1\nflow_loss_lambda = 0.02\nslow_loss_lambda = 0.01\nsmooth_loss_lambda = 0.1\nconsistency_loss_lambda = 1.0\nmask_loss_lambda = 0.1\nsparse_loss_lambda = 0.001\nDyNeRF_blending = True\npretrain = True\n"
  },
  {
    "path": "configs/config_Umbrella.txt",
    "content": "expname = Umbrella_H270_DyNeRF_pretrain\nbasedir = ./logs\ndatadir = ./data/Umbrella/\n\ndataset_type = llff\n\nfactor = 2\nN_rand = 1024\nN_samples = 64\nN_importance = 0\nnetwidth = 256\n\ni_video = 100000\ni_testset = 100000\nN_iters = 300001\ni_img = 500\n\nuse_viewdirs = True\nuse_viewdirsDyn = True\nraw_noise_std = 1e0\nno_ndc = False\nlindisp = False\n\ndynamic_loss_lambda = 1.0\nstatic_loss_lambda = 1.0\nfull_loss_lambda = 3.0\ndepth_loss_lambda = 0.04\norder_loss_lambda = 0.1\nflow_loss_lambda = 0.02\nslow_loss_lambda = 0.01\nsmooth_loss_lambda = 0.1\nconsistency_loss_lambda = 1.0\nmask_loss_lambda = 0.1\nsparse_loss_lambda = 0.001\nDyNeRF_blending = True\npretrain = True\n"
  },
  {
    "path": "load_llff.py",
    "content": "import os\nimport cv2\nimport imageio\nimport numpy as np\n\nfrom utils.flow_utils import resize_flow\nfrom run_nerf_helpers import get_grid\n\n\ndef _minify(basedir, factors=[], resolutions=[]):\n    needtoload = False\n    for r in factors:\n        imgdir = os.path.join(basedir, 'images_{}'.format(r))\n        if not os.path.exists(imgdir):\n            needtoload = True\n    for r in resolutions:\n        imgdir = os.path.join(basedir, 'images_{}x{}'.format(r[1], r[0]))\n        if not os.path.exists(imgdir):\n            needtoload = True\n    if not needtoload:\n        return\n\n    from shutil import copy\n    from subprocess import check_output\n\n    imgdir = os.path.join(basedir, 'images')\n    imgs = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir))]\n    imgs = [f for f in imgs if any([f.endswith(ex) for ex in ['JPG', 'jpg', 'png', 'jpeg', 'PNG']])]\n    imgdir_orig = imgdir\n\n    wd = os.getcwd()\n\n    for r in factors + resolutions:\n        if isinstance(r, int):\n            name = 'images_{}'.format(r)\n            resizearg = '{}%'.format(100./r)\n        else:\n            name = 'images_{}x{}'.format(r[1], r[0])\n            resizearg = '{}x{}'.format(r[1], r[0])\n        imgdir = os.path.join(basedir, name)\n        if os.path.exists(imgdir):\n            continue\n\n        print('Minifying', r, basedir)\n\n        os.makedirs(imgdir)\n        check_output('cp {}/* {}'.format(imgdir_orig, imgdir), shell=True)\n\n        ext = imgs[0].split('.')[-1]\n        args = ' '.join(['mogrify', '-resize', resizearg, '-format', 'png', '*.{}'.format(ext)])\n        print(args)\n        os.chdir(imgdir)\n        check_output(args, shell=True)\n        os.chdir(wd)\n\n        if ext != 'png':\n            check_output('rm {}/*.{}'.format(imgdir, ext), shell=True)\n            print('Removed duplicates')\n        print('Done')\n\n\ndef _load_data(basedir, factor=None, width=None, height=None, load_imgs=True):\n    print('factor ', factor)\n    poses_arr = np.load(os.path.join(basedir, 'poses_bounds.npy'))\n    poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1,2,0])\n    bds = poses_arr[:, -2:].transpose([1,0])\n\n    img0 = [os.path.join(basedir, 'images', f) for f in sorted(os.listdir(os.path.join(basedir, 'images'))) \\\n            if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')][0]\n    sh = imageio.imread(img0).shape\n\n    sfx = ''\n\n    if factor is not None:\n        sfx = '_{}'.format(factor)\n        _minify(basedir, factors=[factor])\n        factor = factor\n    elif height is not None:\n        factor = sh[0] / float(height)\n        width = int(sh[1] / factor)\n        if width % 2 == 1:\n            width -= 1\n        _minify(basedir, resolutions=[[height, width]])\n        sfx = '_{}x{}'.format(width, height)\n    elif width is not None:\n        factor = sh[1] / float(width)\n        height = int(sh[0] / factor)\n        if height % 2 == 1:\n            height -= 1\n        _minify(basedir, resolutions=[[height, width]])\n        sfx = '_{}x{}'.format(width, height)\n    else:\n        factor = 1\n\n    imgdir = os.path.join(basedir, 'images' + sfx)\n    if not os.path.exists(imgdir):\n        print( imgdir, 'does not exist, returning' )\n        return\n\n    imgfiles = [os.path.join(imgdir, f) for f in sorted(os.listdir(imgdir)) \\\n                if f.endswith('JPG') or f.endswith('jpg') or f.endswith('png')]\n    if poses.shape[-1] != len(imgfiles):\n        print( 'Mismatch between imgs {} and poses {} 
!!!!'.format(len(imgfiles), poses.shape[-1]) )\n        return\n\n    sh = imageio.imread(imgfiles[0]).shape\n    num_img = len(imgfiles)\n    poses[:2, 4, :] = np.array(sh[:2]).reshape([2, 1])\n    poses[2, 4, :] = poses[2, 4, :] * 1./factor\n\n    if not load_imgs:\n        return poses, bds\n\n    def imread(f):\n        if f.endswith('png'):\n            return imageio.imread(f, ignoregamma=True)\n        else:\n            return imageio.imread(f)\n\n    imgs = [imread(f)[..., :3] / 255. for f in imgfiles]\n    imgs = np.stack(imgs, -1)\n\n    assert imgs.shape[0] == sh[0]\n    assert imgs.shape[1] == sh[1]\n\n    disp_dir = os.path.join(basedir, 'disp')\n\n    dispfiles = [os.path.join(disp_dir, f) \\\n                for f in sorted(os.listdir(disp_dir)) if f.endswith('npy')]\n\n    disp = [cv2.resize(np.load(f),\n                    (sh[1], sh[0]),\n                    interpolation=cv2.INTER_NEAREST) for f in dispfiles]\n    disp = np.stack(disp, -1)\n\n    mask_dir = os.path.join(basedir, 'motion_masks')\n    maskfiles = [os.path.join(mask_dir, f) \\\n                for f in sorted(os.listdir(mask_dir)) if f.endswith('png')]\n\n    masks = [cv2.resize(imread(f)/255., (sh[1], sh[0]),\n                        interpolation=cv2.INTER_NEAREST) for f in maskfiles]\n    masks = np.stack(masks, -1)\n    masks = np.float32(masks > 1e-3)\n\n    flow_dir = os.path.join(basedir, 'flow')\n    flows_f = []\n    flow_masks_f = []\n    flows_b = []\n    flow_masks_b = []\n    for i in range(num_img):\n        if i == num_img - 1:\n            fwd_flow, fwd_mask = np.zeros((sh[0], sh[1], 2)), np.zeros((sh[0], sh[1]))\n        else:\n            fwd_flow_path = os.path.join(flow_dir, '%03d_fwd.npz'%i)\n            fwd_data = np.load(fwd_flow_path)\n            fwd_flow, fwd_mask = fwd_data['flow'], fwd_data['mask']\n            fwd_flow = resize_flow(fwd_flow, sh[0], sh[1])\n            fwd_mask = np.float32(fwd_mask)\n            fwd_mask = cv2.resize(fwd_mask, (sh[1], sh[0]),\n                                interpolation=cv2.INTER_NEAREST)\n        flows_f.append(fwd_flow)\n        flow_masks_f.append(fwd_mask)\n\n        if i == 0:\n            bwd_flow, bwd_mask = np.zeros((sh[0], sh[1], 2)), np.zeros((sh[0], sh[1]))\n        else:\n            bwd_flow_path = os.path.join(flow_dir, '%03d_bwd.npz'%i)\n            bwd_data = np.load(bwd_flow_path)\n            bwd_flow, bwd_mask = bwd_data['flow'], bwd_data['mask']\n            bwd_flow = resize_flow(bwd_flow, sh[0], sh[1])\n            bwd_mask = np.float32(bwd_mask)\n            bwd_mask = cv2.resize(bwd_mask, (sh[1], sh[0]),\n                                interpolation=cv2.INTER_NEAREST)\n        flows_b.append(bwd_flow)\n        flow_masks_b.append(bwd_mask)\n\n    flows_f = np.stack(flows_f, -1)\n    flow_masks_f = np.stack(flow_masks_f, -1)\n    flows_b = np.stack(flows_b, -1)\n    flow_masks_b = np.stack(flow_masks_b, -1)\n\n    print(imgs.shape)\n    print(disp.shape)\n    print(masks.shape)\n    print(flows_f.shape)\n    print(flow_masks_f.shape)\n\n    assert(imgs.shape[0] == disp.shape[0])\n    assert(imgs.shape[0] == masks.shape[0])\n    assert(imgs.shape[0] == flows_f.shape[0])\n    assert(imgs.shape[0] == flow_masks_f.shape[0])\n\n    assert(imgs.shape[1] == disp.shape[1])\n    assert(imgs.shape[1] == masks.shape[1])\n\n    return poses, bds, imgs, disp, masks, flows_f, flow_masks_f, flows_b, flow_masks_b\n\n\ndef normalize(x):\n    return x / np.linalg.norm(x)\n\ndef viewmatrix(z, up, pos):\n    vec2 = normalize(z)\n    
vec1_avg = up\n    vec0 = normalize(np.cross(vec1_avg, vec2))\n    vec1 = normalize(np.cross(vec2, vec0))\n    m = np.stack([vec0, vec1, vec2, pos], 1)\n    return m\n\n\ndef poses_avg(poses):\n\n    hwf = poses[0, :3, -1:]\n\n    center = poses[:, :3, 3].mean(0)\n    vec2 = normalize(poses[:, :3, 2].sum(0))\n    up = poses[:, :3, 1].sum(0)\n    c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1)\n\n    return c2w\n\n\n\ndef render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N):\n    render_poses = []\n    rads = np.array(list(rads) + [1.])\n    hwf = c2w[:,4:5]\n\n    for theta in np.linspace(0., 2. * np.pi * rots, N+1)[:-1]:\n        c = np.dot(c2w[:3, :4],\n                    np.array([np.cos(theta),\n                             -np.sin(theta),\n                             -np.sin(theta*zrate),\n                              1.]) * rads)\n        z = normalize(c - np.dot(c2w[:3,:4], np.array([0,0,-focal, 1.])))\n        render_poses.append(np.concatenate([viewmatrix(z, up, c), hwf], 1))\n    return render_poses\n\n\n\ndef recenter_poses(poses):\n\n    poses_ = poses+0\n    bottom = np.reshape([0,0,0,1.], [1,4])\n    c2w = poses_avg(poses)\n    c2w = np.concatenate([c2w[:3,:4], bottom], -2)\n    bottom = np.tile(np.reshape(bottom, [1,1,4]), [poses.shape[0],1,1])\n    poses = np.concatenate([poses[:,:3,:4], bottom], -2)\n\n    poses = np.linalg.inv(c2w) @ poses\n    poses_[:,:3,:4] = poses[:,:3,:4]\n    poses = poses_\n    return poses\n\n\ndef spherify_poses(poses, bds):\n\n    p34_to_44 = lambda p : np.concatenate([p, np.tile(np.reshape(np.eye(4)[-1,:], [1,1,4]), [p.shape[0], 1,1])], 1)\n\n    rays_d = poses[:,:3,2:3]\n    rays_o = poses[:,:3,3:4]\n\n    def min_line_dist(rays_o, rays_d):\n        A_i = np.eye(3) - rays_d * np.transpose(rays_d, [0,2,1])\n        b_i = -A_i @ rays_o\n        pt_mindist = np.squeeze(-np.linalg.inv((np.transpose(A_i, [0,2,1]) @ A_i).mean(0)) @ (b_i).mean(0))\n        return pt_mindist\n\n    pt_mindist = min_line_dist(rays_o, rays_d)\n\n    center = pt_mindist\n    up = (poses[:,:3,3] - center).mean(0)\n\n    vec0 = normalize(up)\n    vec1 = normalize(np.cross([.1,.2,.3], vec0))\n    vec2 = normalize(np.cross(vec0, vec1))\n    pos = center\n    c2w = np.stack([vec1, vec2, vec0, pos], 1)\n\n    poses_reset = np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:,:3,:4])\n\n    rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:,:3,3]), -1)))\n\n    sc = 1./rad\n    poses_reset[:,:3,3] *= sc\n    bds *= sc\n    rad *= sc\n\n    centroid = np.mean(poses_reset[:,:3,3], 0)\n    zh = centroid[2]\n    radcircle = np.sqrt(rad**2-zh**2)\n    new_poses = []\n\n    for th in np.linspace(0.,2.*np.pi, 120):\n\n        camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh])\n        up = np.array([0,0,-1.])\n\n        vec2 = normalize(camorigin)\n        vec0 = normalize(np.cross(vec2, up))\n        vec1 = normalize(np.cross(vec2, vec0))\n        pos = camorigin\n        p = np.stack([vec0, vec1, vec2, pos], 1)\n\n        new_poses.append(p)\n\n    new_poses = np.stack(new_poses, 0)\n\n    new_poses = np.concatenate([new_poses, np.broadcast_to(poses[0,:3,-1:], new_poses[:,:3,-1:].shape)], -1)\n    poses_reset = np.concatenate([poses_reset[:,:3,:4], np.broadcast_to(poses[0,:3,-1:], poses_reset[:,:3,-1:].shape)], -1)\n\n    return poses_reset, new_poses, bds\n\n\ndef load_llff_data(args, basedir,\n                   factor=2,\n                   recenter=True, bd_factor=.75,\n                   spherify=False, 
path_zflat=False,\n                   frame2dolly=10):\n\n    poses, bds, imgs, disp, masks, flows_f, flow_masks_f, flows_b, flow_masks_b = \\\n        _load_data(basedir, factor=factor) # factor=2 downsamples original imgs by 2x\n\n    print('Loaded', basedir, bds.min(), bds.max())\n\n    # Correct rotation matrix ordering and move variable dim to axis 0\n    poses = np.concatenate([poses[:, 1:2, :],\n                           -poses[:, 0:1, :],\n                            poses[:, 2:, :]], 1)\n    poses = np.moveaxis(poses, -1, 0).astype(np.float32)\n    images = np.moveaxis(imgs, -1, 0).astype(np.float32)\n    bds = np.moveaxis(bds, -1, 0).astype(np.float32)\n    disp = np.moveaxis(disp, -1, 0).astype(np.float32)\n    masks = np.moveaxis(masks, -1, 0).astype(np.float32)\n    flows_f = np.moveaxis(flows_f, -1, 0).astype(np.float32)\n    flow_masks_f = np.moveaxis(flow_masks_f, -1, 0).astype(np.float32)\n    flows_b = np.moveaxis(flows_b, -1, 0).astype(np.float32)\n    flow_masks_b = np.moveaxis(flow_masks_b, -1, 0).astype(np.float32)\n\n    # Rescale if bd_factor is provided\n    sc = 1. if bd_factor is None else 1./(np.percentile(bds[:, 0], 5) * bd_factor)\n\n    poses[:, :3, 3] *= sc\n    bds *= sc\n\n    if recenter:\n        poses = recenter_poses(poses)\n\n    # Only for rendering\n    if frame2dolly == -1:\n        c2w = poses_avg(poses)\n    else:\n        c2w = poses[frame2dolly, :, :]\n\n    H, W, _ = c2w[:, -1]\n\n    # Generate poses for novel views\n    render_poses, render_focals = generate_path(c2w, args)\n    render_poses = np.array(render_poses).astype(np.float32)\n\n    grids = get_grid(int(H), int(W), len(poses), flows_f, flow_masks_f, flows_b, flow_masks_b) # [N, H, W, 8]\n\n    return images, disp, masks, poses, bds,\\\n        render_poses, render_focals, grids\n\n\ndef generate_path(c2w, args):\n    hwf = c2w[:, 4:5]\n    num_novelviews = args.num_novelviews\n    max_disp = 48.0\n    H, W, focal = hwf[:, 0]\n\n    max_trans = max_disp / focal\n    output_poses = []\n    output_focals = []\n\n    # Rendering teaser. Add translation.\n    for i in range(num_novelviews):\n        x_trans = max_trans * np.sin(2.0 * np.pi * float(i) / float(num_novelviews)) * args.x_trans_multiplier\n        y_trans = max_trans * (np.cos(2.0 * np.pi * float(i) / float(num_novelviews)) - 1.) * args.y_trans_multiplier\n        z_trans = 0.\n\n        i_pose = np.concatenate([\n            np.concatenate(\n                [np.eye(3), np.array([x_trans, y_trans, z_trans])[:, np.newaxis]], axis=1),\n            np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]\n        ],axis=0)\n\n        i_pose = np.linalg.inv(i_pose)\n\n        ref_pose = np.concatenate([c2w[:3, :4], np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]], axis=0)\n\n        render_pose = np.dot(ref_pose, i_pose)\n        output_poses.append(np.concatenate([render_pose[:3, :], hwf], 1))\n        output_focals.append(focal)\n\n    # Rendering teaser. 
Add zooming.\n    if args.frame2dolly != -1:\n        for i in range(num_novelviews // 2 + 1):\n            x_trans = 0.\n            y_trans = 0.\n            # z_trans = max_trans * np.sin(2.0 * np.pi * float(i) / float(num_novelviews)) * args.z_trans_multiplier\n            z_trans = max_trans * args.z_trans_multiplier * i / float(num_novelviews // 2)\n            i_pose = np.concatenate([\n                np.concatenate(\n                    [np.eye(3), np.array([x_trans, y_trans, z_trans])[:, np.newaxis]], axis=1),\n                np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]\n            ],axis=0)\n\n            i_pose = np.linalg.inv(i_pose) #torch.tensor(np.linalg.inv(i_pose)).float()\n\n            ref_pose = np.concatenate([c2w[:3, :4], np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]], axis=0)\n\n            render_pose = np.dot(ref_pose, i_pose)\n            output_poses.append(np.concatenate([render_pose[:3, :], hwf], 1))\n            output_focals.append(focal)\n            print(z_trans / max_trans / args.z_trans_multiplier)\n\n    # Rendering teaser. Add dolly zoom.\n    if args.frame2dolly != -1:\n        for i in range(num_novelviews // 2 + 1):\n            x_trans = 0.\n            y_trans = 0.\n            z_trans = max_trans * args.z_trans_multiplier * i / float(num_novelviews // 2)\n            i_pose = np.concatenate([\n                np.concatenate(\n                    [np.eye(3), np.array([x_trans, y_trans, z_trans])[:, np.newaxis]], axis=1),\n                np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]\n            ],axis=0)\n\n            i_pose = np.linalg.inv(i_pose)\n\n            ref_pose = np.concatenate([c2w[:3, :4], np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]], axis=0)\n\n            render_pose = np.dot(ref_pose, i_pose)\n            output_poses.append(np.concatenate([render_pose[:3, :], hwf], 1))\n            new_focal = focal - args.focal_decrease * z_trans / max_trans / args.z_trans_multiplier\n            output_focals.append(new_focal)\n            print(z_trans / max_trans / args.z_trans_multiplier, new_focal)\n\n    return output_poses, output_focals\n"
  },
  {
    "path": "render_utils.py",
    "content": "import os\nimport time\nimport torch\nimport imageio\nimport numpy as np\nimport torch.nn.functional as F\n\nfrom run_nerf_helpers import *\nfrom utils.flow_utils import flow_to_image\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef batchify_rays(t, chain_5frames,\n                rays_flat, chunk=1024*16, **kwargs):\n    \"\"\"Render rays in smaller minibatches to avoid OOM.\n    \"\"\"\n    all_ret = {}\n    for i in range(0, rays_flat.shape[0], chunk):\n        ret = render_rays(t, chain_5frames, rays_flat[i:i+chunk], **kwargs)\n        for k in ret:\n            if k not in all_ret:\n                all_ret[k] = []\n            all_ret[k].append(ret[k])\n\n    all_ret = {k: torch.cat(all_ret[k], 0) for k in all_ret}\n    return all_ret\n\n\ndef render(t, chain_5frames,\n           H, W, focal, focal_render=None,\n           chunk=1024*16, rays=None, c2w=None, ndc=True,\n           near=0., far=1.,\n           use_viewdirs=False, c2w_staticcam=None,\n           **kwargs):\n    \"\"\"Render rays\n    Args:\n      H: int. Height of image in pixels.\n      W: int. Width of image in pixels.\n      focal: float. Focal length of pinhole camera.\n      chunk: int. Maximum number of rays to process simultaneously. Used to\n        control maximum memory usage. Does not affect final results.\n      rays: array of shape [2, batch_size, 3]. Ray origin and direction for\n        each example in batch.\n      c2w: array of shape [3, 4]. Camera-to-world transformation matrix.\n      ndc: bool. If True, represent ray origin, direction in NDC coordinates.\n      near: float or array of shape [batch_size]. Nearest distance for a ray.\n      far: float or array of shape [batch_size]. Farthest distance for a ray.\n      use_viewdirs: bool. If True, use viewing direction of a point in space in model.\n      c2w_staticcam: array of shape [3, 4]. If not None, use this transformation matrix for\n       camera while using other c2w argument for viewing directions.\n    Returns:\n      rgb_map: [batch_size, 3]. Predicted RGB values for rays.\n      disp_map: [batch_size]. Disparity map. Inverse of depth.\n      acc_map: [batch_size]. Accumulated opacity (alpha) along a ray.\n      extras: dict with everything returned by render_rays().\n    \"\"\"\n\n    if c2w is not None:\n        # special case to render full image\n        if focal_render is not None:\n            # Render full image using different focal length for dolly zoom. 
Inference only.\n            rays_o, rays_d = get_rays(H, W, focal_render, c2w)\n        else:\n            rays_o, rays_d = get_rays(H, W, focal, c2w)\n    else:\n        # use provided ray batch\n        rays_o, rays_d = rays\n\n    if use_viewdirs:\n        # provide ray directions as input\n        viewdirs = rays_d\n        if c2w_staticcam is not None:\n            raise NotImplementedError\n        # Make all directions unit magnitude.\n        # shape: [batch_size, 3]\n        viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)\n        viewdirs = torch.reshape(viewdirs, [-1, 3]).float()\n\n    sh = rays_d.shape # [..., 3]\n    if ndc:\n        # for forward facing scenes\n        rays_o, rays_d = ndc_rays(H, W, focal, 1., rays_o, rays_d)\n\n    # Create ray batch\n    rays_o = torch.reshape(rays_o, [-1, 3]).float()\n    rays_d = torch.reshape(rays_d, [-1, 3]).float()\n    near, far = near * \\\n        torch.ones_like(rays_d[..., :1]), far * torch.ones_like(rays_d[..., :1])\n\n    # (ray origin, ray direction, min dist, max dist) for each ray\n    rays = torch.cat([rays_o, rays_d, near, far], -1)\n    if use_viewdirs:\n        rays = torch.cat([rays, viewdirs], -1)\n\n    # Render and reshape\n    all_ret = batchify_rays(t, chain_5frames,\n                        rays, chunk, **kwargs)\n    for k in all_ret:\n        k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:])\n        all_ret[k] = torch.reshape(all_ret[k], k_sh)\n\n    return all_ret\n\n\ndef render_path_batch(render_poses, time2render,\n                    hwf, chunk, render_kwargs, savedir=None, focal2render=None):\n    \"\"\"Render frames using batch.\n\n    Args:\n      render_poses: array of shape [num_frame, 3, 4]. Camera-to-world transformation matrix of each frame.\n      time2render: array of shape [num_frame]. Time of each frame.\n      hwf: list. [Height of image in pixels, Width of image in pixels, Focal length of pinhole camera]\n      chunk: int. Maximum number of rays to process simultaneously. Used to\n        control maximum memory usage. Does not affect final results.\n      render_kwargs: dictionary. args for the render function.\n      savedir: string. Directory to save results.\n      focal2render: list. Only used to perform dolly-zoom.\n    Returns:\n      ret_dict: dictionary. 
Final and intermediate results.\n    \"\"\"\n    H, W, focal = hwf\n\n    ret_dict = {}\n    rgbs = []\n    rgbs_d = []\n    rgbs_s = []\n    dynamicnesses = []\n\n    time_curr = time.time()\n    for i, c2w in enumerate(render_poses):\n\n        print(i, time.time() - time_curr)\n        time_curr = time.time()\n\n        t = time2render[i]\n\n        if focal2render is not None:\n            # Render full image using different focal length\n            rays_o, rays_d = get_rays(H, W, focal2render[i], c2w)\n        else:\n            rays_o, rays_d = get_rays(H, W, focal, c2w)\n        rays_o = torch.reshape(rays_o, (-1, 3))\n        rays_d = torch.reshape(rays_d, (-1, 3))\n        batch_rays = torch.stack([rays_o, rays_d], 0)\n        rgb = []\n        rgb_d = []\n        rgb_s = []\n        dynamicness = []\n        for j in range(0, batch_rays.shape[1], chunk):\n            # print(j, '/', batch_rays.shape[1])\n            ret = render(t, False,\n                         H, W, focal,\n                         chunk=chunk, rays=batch_rays[:, j:j+chunk, :],\n                         **render_kwargs)\n            rgb.append(ret['rgb_map_full'].cpu())\n            rgb_d.append(ret['rgb_map_d'].cpu())\n            rgb_s.append(ret['rgb_map_s'].cpu())\n            dynamicness.append(ret['dynamicness_map'].cpu())\n        rgb = torch.reshape(torch.cat(rgb, 0), (H, W, 3)).numpy()\n        rgb_d = torch.reshape(torch.cat(rgb_d, 0), (H, W, 3)).numpy()\n        rgb_s = torch.reshape(torch.cat(rgb_s, 0), (H, W, 3)).numpy()\n        dynamicness = torch.reshape(torch.cat(dynamicness, 0), (H, W)).numpy()\n\n        # Not a good solution. Should take care of this when preparing the data.\n        if W%2 == 1:\n            # rgb = cv2.resize(rgb, (W - 1, H))\n            rgb = rgb[:, :-1, :]\n            rgb_d = rgb_d[:, :-1, :]\n            rgb_s = rgb_s[:, :-1, :]\n            dynamicness = dynamicness[:, :-1]\n        rgbs.append(rgb)\n        rgbs_d.append(rgb_d)\n        rgbs_s.append(rgb_s)\n        dynamicnesses.append(dynamicness)\n\n        if savedir is not None:\n            rgb8 = to8b(rgbs[-1])\n            filename = os.path.join(savedir, '{:03d}.png'.format(i))\n            imageio.imwrite(filename, rgb8)\n\n    ret_dict['rgbs'] = np.stack(rgbs, 0)\n    ret_dict['rgbs_d'] = np.stack(rgbs_d, 0)\n    ret_dict['rgbs_s'] = np.stack(rgbs_s, 0)\n    ret_dict['dynamicnesses'] = np.stack(dynamicnesses, 0)\n\n    return ret_dict\n\n\ndef render_path(render_poses,\n                time2render,\n                hwf,\n                chunk,\n                render_kwargs,\n                savedir=None,\n                flows_gt_f=None,\n                flows_gt_b=None,\n                focal2render=None):\n    \"\"\"Render frames.\n\n    Args:\n      render_poses: array of shape [num_frame, 3, 4]. Camera-to-world transformation matrix of each frame.\n      time2render: array of shape [num_frame]. Time of each frame.\n      hwf: list. [Height of image in pixels, Width of image in pixels, Focal length of pinhole camera]\n      chunk: int. Maximum number of rays to process simultaneously. Used to\n        control maximum memory usage. Does not affect final results.\n      render_kwargs: dictionary. args for the render function.\n      savedir: string. Directory to save results.\n      focal2render: list. Only used to perform dolly-zoom.\n    Returns:\n      ret_dict: dictionary. 
Final and intermediate results.\n    \"\"\"\n    H, W, focal = hwf\n\n    ret_dict = {}\n    rgbs = []\n    rgbs_d = []\n    rgbs_s = []\n    depths = []\n    depths_d = []\n    depths_s = []\n    flows_f = []\n    flows_b = []\n    dynamicness = []\n    blending = []\n\n    grid = np.stack(np.meshgrid(np.arange(W, dtype=np.float32),\n                       np.arange(H, dtype=np.float32), indexing='xy'), -1)\n    grid = torch.Tensor(grid)\n    time_curr = time.time()\n    for i, c2w in enumerate(render_poses):\n        t = time2render[i]\n        pose = c2w[:3, :4]\n        print(i, time.time() - time_curr)\n        time_curr = time.time()\n\n        if focal2render is None:\n            # Normal rendering.\n            ret = render(t, False,\n                         H, W, focal,\n                         chunk=1024*32, c2w=pose,\n                         **render_kwargs)\n        else:\n            # Render image using different focal length.\n            ret = render(t, False,\n                         H, W, focal, focal_render=focal2render[i],\n                         chunk=1024*32, c2w=pose,\n                         **render_kwargs)\n\n        rgbs.append(ret['rgb_map_full'].cpu().numpy())\n        rgbs_d.append(ret['rgb_map_d'].cpu().numpy())\n        rgbs_s.append(ret['rgb_map_s'].cpu().numpy())\n\n        depths.append(ret['depth_map_full'].cpu().numpy())\n        depths_d.append(ret['depth_map_d'].cpu().numpy())\n        depths_s.append(ret['depth_map_s'].cpu().numpy())\n\n        dynamicness.append(ret['dynamicness_map'].cpu().numpy())\n\n        if flows_gt_f is not None:\n            # Reconstruction. Flow is caused by both changing camera and changing time.\n            pose_f = render_poses[min(i + 1, int(len(render_poses)) - 1), :3, :4]\n            pose_b = render_poses[max(i - 1, 0), :3, :4]\n        else:\n            # Non training view-time. 
Flow is caused by changing time (just for visualization).\n            pose_f = render_poses[i, :3, :4]\n            pose_b = render_poses[i, :3, :4]\n\n        # Sceneflow induced optical flow\n        induced_flow_f_ = induce_flow(H, W, focal, pose_f, ret['weights_d'], ret['raw_pts_f'], grid[..., :2])\n        induced_flow_b_ = induce_flow(H, W, focal, pose_b, ret['weights_d'], ret['raw_pts_b'], grid[..., :2])\n\n        if (i + 1) >= len(render_poses):\n            induced_flow_f = np.zeros((H, W, 2))\n        else:\n            induced_flow_f = induced_flow_f_.cpu().numpy()\n        if flows_gt_f is not None:\n            flow_gt_f = flows_gt_f[i].cpu().numpy()\n            induced_flow_f = np.concatenate((induced_flow_f, flow_gt_f), 0)\n        induced_flow_f_img = flow_to_image(induced_flow_f)\n        flows_f.append(induced_flow_f_img)\n\n        if (i - 1) < 0:\n            induced_flow_b = np.zeros((H, W, 2))\n        else:\n            induced_flow_b = induced_flow_b_.cpu().numpy()\n        if flows_gt_b is not None:\n            flow_gt_b = flows_gt_b[i].cpu().numpy()\n            induced_flow_b = np.concatenate((induced_flow_b, flow_gt_b), 0)\n        induced_flow_b_img = flow_to_image(induced_flow_b)\n        flows_b.append(induced_flow_b_img)\n\n        if i == 0:\n            ret_dict['sceneflow_f_NDC'] = ret['sceneflow_f'].cpu().numpy()\n            ret_dict['sceneflow_b_NDC'] = ret['sceneflow_b'].cpu().numpy()\n            ret_dict['blending'] = ret['blending'].cpu().numpy()\n\n            weights = np.concatenate((ret['weights_d'][..., None].cpu().numpy(),\n                                      ret['weights_s'][..., None].cpu().numpy(),\n                                      ret['blending'][..., None].cpu().numpy(),\n                                      ret['weights_full'][..., None].cpu().numpy()))\n            ret_dict['weights'] = np.moveaxis(weights, [0, 1, 2, 3], [1, 2, 0, 3])\n\n        if savedir is not None:\n            rgb8 = to8b(rgbs[-1])\n            filename = os.path.join(savedir, '{:03d}.png'.format(i))\n            imageio.imwrite(filename, rgb8)\n\n    ret_dict['rgbs'] = np.stack(rgbs, 0)\n    ret_dict['rgbs_d'] = np.stack(rgbs_d, 0)\n    ret_dict['rgbs_s'] = np.stack(rgbs_s, 0)\n    ret_dict['depths'] = np.stack(depths, 0)\n    ret_dict['depths_d'] = np.stack(depths_d, 0)\n    ret_dict['depths_s'] = np.stack(depths_s, 0)\n    ret_dict['dynamicness'] = np.stack(dynamicness, 0)\n    ret_dict['flows_f'] = np.stack(flows_f, 0)\n    ret_dict['flows_b'] = np.stack(flows_b, 0)\n\n    return ret_dict\n\n\ndef raw2outputs(raw_s,\n                raw_d,\n                blending,\n                z_vals,\n                rays_d,\n                raw_noise_std):\n    \"\"\"Transforms model's predictions to semantically meaningful values.\n\n    Args:\n      raw_d: [num_rays, num_samples along ray, 4]. Prediction from Dynamic model.\n      raw_s: [num_rays, num_samples along ray, 4]. Prediction from Static model.\n      z_vals: [num_rays, num_samples along ray]. Integration time.\n      rays_d: [num_rays, 3]. Direction of each ray.\n\n    Returns:\n      rgb_map: [num_rays, 3]. Estimated RGB color of a ray.\n      disp_map: [num_rays]. Disparity map. Inverse of depth map.\n      acc_map: [num_rays]. Sum of weights along each ray.\n      weights: [num_rays, num_samples]. Weights assigned to each sampled color.\n      depth_map: [num_rays]. Estimated distance to object.\n    \"\"\"\n    # Function for computing density from model prediction. 
This value is\n    # strictly between [0, 1].\n    def raw2alpha(raw, dists, act_fn=F.relu): return 1.0 - \\\n        torch.exp(-act_fn(raw) * dists)\n\n    # Compute 'distance' (in time) between each integration time along a ray.\n    dists = z_vals[..., 1:] - z_vals[..., :-1]\n\n    # The 'distance' from the last integration time is infinity.\n    dists = torch.cat(\n        [dists, torch.Tensor([1e10]).expand(dists[..., :1].shape)],\n         -1) # [N_rays, N_samples]\n\n    # Multiply each distance by the norm of its corresponding direction ray\n    # to convert to real world distance (accounts for non-unit directions).\n    dists = dists * torch.norm(rays_d[..., None, :], dim=-1)\n\n    # Extract RGB of each sample position along each ray.\n    rgb_d = torch.sigmoid(raw_d[..., :3])  # [N_rays, N_samples, 3]\n    rgb_s = torch.sigmoid(raw_s[..., :3])  # [N_rays, N_samples, 3]\n\n    # Add noise to model's predictions for density. Can be used to\n    # regularize network during training (prevents floater artifacts).\n    noise = 0.\n    if raw_noise_std > 0.:\n        noise = torch.randn(raw_d[..., 3].shape) * raw_noise_std\n\n    # Predict density of each sample along each ray. Higher values imply\n    # higher likelihood of being absorbed at this point.\n    alpha_d = raw2alpha(raw_d[..., 3] + noise, dists) # [N_rays, N_samples]\n    alpha_s = raw2alpha(raw_s[..., 3] + noise, dists) # [N_rays, N_samples]\n    alphas  = 1. - (1. - alpha_s) * (1. - alpha_d) # [N_rays, N_samples]\n\n    T_d    = torch.cumprod(torch.cat([torch.ones((alpha_d.shape[0], 1)), 1. - alpha_d + 1e-10], -1), -1)[:, :-1]\n    T_s    = torch.cumprod(torch.cat([torch.ones((alpha_s.shape[0], 1)), 1. - alpha_s + 1e-10], -1), -1)[:, :-1]\n    T_full = torch.cumprod(torch.cat([torch.ones((alpha_d.shape[0], 1)), (1. - alpha_d * blending) * (1. - alpha_s * (1. - blending)) + 1e-10], -1), -1)[:, :-1]\n    # T_full = torch.cumprod(torch.cat([torch.ones((alpha_d.shape[0], 1)), torch.pow(1. - alpha_d + 1e-10, blending) * torch.pow(1. - alpha_s + 1e-10, 1. - blending)], -1), -1)[:, :-1]\n    # T_full = torch.cumprod(torch.cat([torch.ones((alpha_d.shape[0], 1)), (1. - alpha_d) * (1. - alpha_s) + 1e-10], -1), -1)[:, :-1]\n\n    # Compute weight for RGB of each sample along each ray.  A cumprod() is\n    # used to express the idea of the ray not having reflected up to this\n    # sample yet.\n    weights_d = alpha_d * T_d\n    weights_s = alpha_s * T_s\n    weights_full = (alpha_d * blending + alpha_s * (1. - blending)) * T_full\n    # weights_full = alphas * T_full\n\n    # Computed weighted color of each sample along each ray.\n    rgb_map_d = torch.sum(weights_d[..., None] * rgb_d, -2)\n    rgb_map_s = torch.sum(weights_s[..., None] * rgb_s, -2)\n    rgb_map_full = torch.sum(\n        (T_full * alpha_d * blending)[..., None] * rgb_d + \\\n        (T_full * alpha_s * (1. - blending))[..., None] * rgb_s, -2)\n\n    # Estimated depth map is expected distance.\n    depth_map_d = torch.sum(weights_d * z_vals, -1)\n    depth_map_s = torch.sum(weights_s * z_vals, -1)\n    depth_map_full = torch.sum(weights_full * z_vals, -1)\n\n    # Sum of weights along each ray. 
This value is in [0, 1] up to numerical error.\n    acc_map_d = torch.sum(weights_d, -1)\n    acc_map_s = torch.sum(weights_s, -1)\n    acc_map_full = torch.sum(weights_full, -1)\n\n    # Computed dynamicness\n    dynamicness_map = torch.sum(weights_full * blending, -1)\n    # dynamicness_map = 1 - T_d[..., -1]\n\n    return rgb_map_full, depth_map_full, acc_map_full, weights_full, \\\n           rgb_map_s, depth_map_s, acc_map_s, weights_s, \\\n           rgb_map_d, depth_map_d, acc_map_d, weights_d, dynamicness_map\n\n\ndef raw2outputs_d(raw_d,\n                  z_vals,\n                  rays_d,\n                  raw_noise_std):\n\n    # Function for computing density from model prediction. This value is\n    # strictly between [0, 1].\n    def raw2alpha(raw, dists, act_fn=F.relu): return 1.0 - \\\n        torch.exp(-act_fn(raw) * dists)\n\n    # Compute 'distance' (in time) between each integration time along a ray.\n    dists = z_vals[..., 1:] - z_vals[..., :-1]\n\n    # The 'distance' from the last integration time is infinity.\n    dists = torch.cat(\n        [dists, torch.Tensor([1e10]).expand(dists[..., :1].shape)],\n        -1)  # [N_rays, N_samples]\n\n    # Multiply each distance by the norm of its corresponding direction ray\n    # to convert to real world distance (accounts for non-unit directions).\n    dists = dists * torch.norm(rays_d[..., None, :], dim=-1)\n\n    # Extract RGB of each sample position along each ray.\n    rgb_d = torch.sigmoid(raw_d[..., :3])  # [N_rays, N_samples, 3]\n\n    # Add noise to model's predictions for density. Can be used to\n    # regularize network during training (prevents floater artifacts).\n    noise = 0.\n    if raw_noise_std > 0.:\n        noise = torch.randn(raw_d[..., 3].shape) * raw_noise_std\n\n    # Predict density of each sample along each ray. Higher values imply\n    # higher likelihood of being absorbed at this point.\n    alpha_d = raw2alpha(raw_d[..., 3] + noise, dists)  # [N_rays, N_samples]\n\n    T_d = torch.cumprod(torch.cat([torch.ones((alpha_d.shape[0], 1)), 1. - alpha_d + 1e-10], -1), -1)[:, :-1]\n    # Compute weight for RGB of each sample along each ray.  A cumprod() is\n    # used to express the idea of the ray not having reflected up to this\n    # sample yet.\n    weights_d = alpha_d * T_d\n\n    # Computed weighted color of each sample along each ray.\n    rgb_map_d = torch.sum(weights_d[..., None] * rgb_d, -2)\n\n    return rgb_map_d, weights_d\n\n\ndef render_rays(t,\n                chain_5frames,\n                ray_batch,\n                network_fn_d,\n                network_fn_s,\n                network_query_fn_d,\n                network_query_fn_s,\n                N_samples,\n                num_img,\n                DyNeRF_blending,\n                pretrain=False,\n                lindisp=False,\n                perturb=0.,\n                N_importance=0,\n                raw_noise_std=0.,\n                inference=False):\n\n    \"\"\"Volumetric rendering.\n    Args:\n      ray_batch: array of shape [batch_size, ...]. All information necessary\n        for sampling along a ray, including: ray origin, ray direction, min\n        dist, max dist, and unit-magnitude viewing direction.\n      network_fn_d: function. Model for predicting RGB and density at each point\n        in space.\n      network_query_fn_d: function used for passing queries to network_fn_d.\n      N_samples: int. Number of different times to sample along each ray.\n      lindisp: bool. 
If True, sample linearly in inverse depth rather than in depth.\n      perturb: float, 0 or 1. If non-zero, each ray is sampled at stratified\n        random points in time.\n      N_importance: int. Number of additional times to sample along each ray.\n        These samples are only passed to network_fine.\n      network_fine: \"fine\" network with same spec as network_fn.\n      raw_noise_std: ...\n    Returns:\n      rgb_map: [num_rays, 3]. Estimated RGB color of a ray. Comes from fine model.\n      disp_map: [num_rays]. Disparity map. 1 / depth.\n      acc_map: [num_rays]. Accumulated opacity along each ray. Comes from fine model.\n      raw: [num_rays, num_samples, 4]. Raw predictions from model.\n      rgb0: See rgb_map. Output for coarse model.\n      disp0: See disp_map. Output for coarse model.\n      acc0: See acc_map. Output for coarse model.\n      z_std: [num_rays]. Standard deviation of distances along ray for each\n        sample.\n    \"\"\"\n\n    # batch size\n    N_rays = ray_batch.shape[0]\n\n    # ray_batch: [N_rays, 11]\n    # rays_o:    [N_rays, 0:3]\n    # rays_d:    [N_rays, 3:6]\n    # near:      [N_rays, 6:7]\n    # far:       [N_rays, 7:8]\n    # viewdirs:  [N_rays, 8:11]\n\n    # Extract ray origin, direction.\n    rays_o, rays_d = ray_batch[:, 0:3], ray_batch[:, 3:6] # [N_rays, 3] each\n\n    # Extract unit-normalized viewing direction.\n    viewdirs = ray_batch[:, -3:] if ray_batch.shape[-1] > 8 else None\n\n    # Extract lower, upper bound for ray distance.\n    bounds = torch.reshape(ray_batch[..., 6:8], [-1, 1, 2])\n    near, far = bounds[..., 0], bounds[..., 1]\n\n    # Decide where to sample along each ray. Under the logic, all rays will be sampled at\n    # the same times.\n    t_vals = torch.linspace(0., 1., steps=N_samples)\n    if not lindisp:\n        # Space integration times linearly between 'near' and 'far'. 
Same\n        # integration points will be used for all rays.\n        z_vals = near * (1.-t_vals) + far * (t_vals)\n    else:\n        # Sample linearly in inverse depth (disparity).\n        z_vals = 1./(1./near * (1.-t_vals) + 1./far * (t_vals))\n    z_vals = z_vals.expand([N_rays, N_samples])\n\n    # Perturb sampling time along each ray.\n    if perturb > 0.:\n        # get intervals between samples\n        mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])\n        upper = torch.cat([mids, z_vals[..., -1:]], -1)\n        lower = torch.cat([z_vals[..., :1], mids], -1)\n        # stratified samples in those intervals\n        t_rand = torch.rand(z_vals.shape)\n        z_vals = lower + (upper - lower) * t_rand\n\n    # Points in space to evaluate model at.\n    pts = rays_o[..., None, :] + rays_d[..., None, :] * \\\n        z_vals[..., :, None] # [N_rays, N_samples, 3]\n\n    # Add the time dimension to xyz.\n    pts_ref = torch.cat([pts, torch.ones_like(pts[..., 0:1]) * t], -1)\n\n    # First pass: we have the staticNeRF results\n    raw_s = network_query_fn_s(pts_ref[..., :3], viewdirs, network_fn_s)\n    # raw_s:          [N_rays, N_samples, 5]\n    # raw_s_rgb:      [N_rays, N_samples, 0:3]\n    # raw_s_a:        [N_rays, N_samples, 3:4]\n    # raw_s_blending: [N_rays, N_samples, 4:5]\n\n    # Second pass: we have the DyanmicNeRF results and the blending weight\n    raw_d = network_query_fn_d(pts_ref, viewdirs, network_fn_d)\n    # raw_d:          [N_rays, N_samples, 11]\n    # raw_d_rgb:      [N_rays, N_samples, 0:3]\n    # raw_d_a:        [N_rays, N_samples, 3:4]\n    # sceneflow_b:    [N_rays, N_samples, 4:7]\n    # sceneflow_f:    [N_rays, N_samples, 7:10]\n    # raw_d_blending: [N_rays, N_samples, 10:11]\n\n    if pretrain:\n        rgb_map_s, _ = raw2outputs_d(raw_s[..., :4],\n                                     z_vals,\n                                     rays_d,\n                                     raw_noise_std)\n        ret = {'rgb_map_s': rgb_map_s}\n        return ret\n\n    raw_s_rgba = raw_s[..., :4]\n    raw_d_rgba = raw_d[..., :4]\n\n    # We need the sceneflow from the dynamicNeRF.\n    sceneflow_b = raw_d[..., 4:7]\n    sceneflow_f = raw_d[..., 7:10]\n\n    if DyNeRF_blending:\n        blending = raw_d[..., 10]\n    else:\n        blending = raw_s[..., 4]\n\n    # if sfmask:\n    #     sceneflow_f = sceneflow_f * blending.detach()[..., None]\n    #     sceneflow_b = sceneflow_b * blending.detach()[..., None]\n\n    # Rerndering.\n    rgb_map_full, depth_map_full, acc_map_full, weights_full, \\\n    rgb_map_s, depth_map_s, acc_map_s, weights_s, \\\n    rgb_map_d, depth_map_d, acc_map_d, weights_d, \\\n    dynamicness_map = raw2outputs(raw_s_rgba,\n                                  raw_d_rgba,\n                                  blending,\n                                  z_vals,\n                                  rays_d,\n                                  raw_noise_std)\n\n    ret = {'rgb_map_full': rgb_map_full, 'depth_map_full': depth_map_full, 'acc_map_full': acc_map_full, 'weights_full': weights_full,\n           'rgb_map_s': rgb_map_s, 'depth_map_s': depth_map_s, 'acc_map_s': acc_map_s, 'weights_s': weights_s,\n           'rgb_map_d': rgb_map_d, 'depth_map_d': depth_map_d, 'acc_map_d': acc_map_d, 'weights_d': weights_d,\n           'dynamicness_map': dynamicness_map}\n\n    t_interval = 1. 
/ num_img * 2.\n    pts_f = torch.cat([pts + sceneflow_f, torch.ones_like(pts[..., 0:1]) * (t + t_interval)], -1)\n    pts_b = torch.cat([pts + sceneflow_b, torch.ones_like(pts[..., 0:1]) * (t - t_interval)], -1)\n\n    ret['sceneflow_b'] = sceneflow_b\n    ret['sceneflow_f'] = sceneflow_f\n    ret['raw_pts'] = pts_ref[..., :3]\n    ret['raw_pts_f'] = pts_f[..., :3]\n    ret['raw_pts_b'] = pts_b[..., :3]\n    ret['blending'] = blending\n\n    # Third pass: we have the DynamicNeRF results at time t - 1\n    raw_d_b = network_query_fn_d(pts_b, viewdirs, network_fn_d)\n    raw_d_b_rgba = raw_d_b[..., :4]\n    sceneflow_b_b = raw_d_b[..., 4:7]\n    sceneflow_b_f = raw_d_b[..., 7:10]\n\n    # Rendering at t - 1\n    rgb_map_d_b, weights_d_b = raw2outputs_d(raw_d_b_rgba,\n                                             z_vals,\n                                             rays_d,\n                                             raw_noise_std)\n\n    ret['sceneflow_b_f'] = sceneflow_b_f\n    ret['rgb_map_d_b'] = rgb_map_d_b\n    ret['acc_map_d_b'] = torch.abs(torch.sum(weights_d_b - weights_d, -1))\n\n    # Fourth pass: we have the DynamicNeRF results at time t + 1\n    raw_d_f = network_query_fn_d(pts_f, viewdirs, network_fn_d)\n    raw_d_f_rgba = raw_d_f[..., :4]\n    sceneflow_f_b = raw_d_f[..., 4:7]\n    sceneflow_f_f = raw_d_f[..., 7:10]\n\n    rgb_map_d_f, weights_d_f = raw2outputs_d(raw_d_f_rgba,\n                                             z_vals,\n                                             rays_d,\n                                             raw_noise_std)\n\n    ret['sceneflow_f_b'] = sceneflow_f_b\n    ret['rgb_map_d_f'] = rgb_map_d_f\n    ret['acc_map_d_f'] = torch.abs(torch.sum(weights_d_f - weights_d, -1))\n\n    if inference:\n        return ret\n\n    # Also consider time t - 2 and t + 2 (adapted from NSFF)\n\n    # Fifth pass: we have the DynamicNeRF results at time t - 2\n    pts_b_b = torch.cat([pts_b[..., :3] + sceneflow_b_b, torch.ones_like(pts[..., 0:1]) * (t - t_interval * 2)], -1)\n    ret['raw_pts_b_b'] = pts_b_b[..., :3]\n\n    if chain_5frames:\n        raw_d_b_b = network_query_fn_d(pts_b_b, viewdirs, network_fn_d)\n        raw_d_b_b_rgba = raw_d_b_b[..., :4]\n        rgb_map_d_b_b, _ = raw2outputs_d(raw_d_b_b_rgba,\n                                      z_vals,\n                                      rays_d,\n                                      raw_noise_std)\n\n        ret['rgb_map_d_b_b'] = rgb_map_d_b_b\n\n    # Sixth pass: we have the DynamicNeRF results at time t + 2\n    pts_f_f = torch.cat([pts_f[..., :3] + sceneflow_f_f, torch.ones_like(pts[..., 0:1]) * (t + t_interval * 2)], -1)\n    ret['raw_pts_f_f'] = pts_f_f[..., :3]\n\n    if chain_5frames:\n        raw_d_f_f = network_query_fn_d(pts_f_f, viewdirs, network_fn_d)\n        raw_d_f_f_rgba = raw_d_f_f[..., :4]\n        rgb_map_d_f_f, _ = raw2outputs_d(raw_d_f_f_rgba,\n                                      z_vals,\n                                      rays_d,\n                                      raw_noise_std)\n\n        ret['rgb_map_d_f_f'] = rgb_map_d_f_f\n\n    for k in ret:\n        if torch.isnan(ret[k]).any() or torch.isinf(ret[k]).any():\n            print(f\"! [Numerical Error] {k} contains nan or inf.\")\n            import ipdb; ipdb.set_trace()\n\n    return ret\n"
  },
  {
    "path": "run_nerf.py",
    "content": "import os\nimport time\nimport torch\nimport imageio\nimport numpy as np\nfrom torch.utils.tensorboard import SummaryWriter\n\nfrom render_utils import *\nfrom run_nerf_helpers import *\nfrom load_llff import *\nfrom utils.flow_utils import flow_to_image\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef config_parser():\n\n    import configargparse\n    parser = configargparse.ArgumentParser()\n    parser.add_argument('--config', is_config_file=True,\n                        help='config file path')\n    parser.add_argument(\"--expname\", type=str,\n                        help='experiment name')\n    parser.add_argument(\"--basedir\", type=str, default='./logs/',\n                        help='where to store ckpts and logs')\n    parser.add_argument(\"--datadir\", type=str, default='./data/llff/fern',\n                        help='input data directory')\n\n    # training options\n    parser.add_argument(\"--netdepth\", type=int, default=8,\n                        help='layers in network')\n    parser.add_argument(\"--netwidth\", type=int, default=256,\n                        help='channels per layer')\n    parser.add_argument(\"--netdepth_fine\", type=int, default=8,\n                        help='layers in fine network')\n    parser.add_argument(\"--netwidth_fine\", type=int, default=256,\n                        help='channels per layer in fine network')\n    parser.add_argument(\"--N_rand\", type=int, default=32*32*4,\n                        help='batch size (number of random rays per gradient step)')\n    parser.add_argument(\"--lrate\", type=float, default=5e-4,\n                        help='learning rate')\n    parser.add_argument(\"--lrate_decay\", type=int, default=300000,\n                        help='exponential learning rate decay')\n    parser.add_argument(\"--chunk\", type=int, default=1024*128,\n                        help='number of rays processed in parallel, decrease if running out of memory')\n    parser.add_argument(\"--netchunk\", type=int, default=1024*128,\n                        help='number of pts sent through network in parallel, decrease if running out of memory')\n    parser.add_argument(\"--no_reload\", action='store_true',\n                        help='do not reload weights from saved ckpt')\n    parser.add_argument(\"--ft_path\", type=str, default=None,\n                        help='specific weights npy file to reload for coarse network')\n    parser.add_argument(\"--random_seed\", type=int, default=1,\n                        help='fix random seed for repeatability')\n\n    # rendering options\n    parser.add_argument(\"--N_samples\", type=int, default=64,\n                        help='number of coarse samples per ray')\n    parser.add_argument(\"--N_importance\", type=int, default=0,\n                        help='number of additional fine samples per ray')\n    parser.add_argument(\"--perturb\", type=float, default=1.,\n                        help='set to 0. for no jitter, 1. 
for jitter')\n    parser.add_argument(\"--use_viewdirs\", action='store_true',\n                        help='use full 5D input instead of 3D')\n    parser.add_argument(\"--use_viewdirsDyn\", action='store_true',\n                        help='use full 5D input instead of 3D for D-NeRF')\n    parser.add_argument(\"--i_embed\", type=int, default=0,\n                        help='set 0 for default positional encoding, -1 for none')\n    parser.add_argument(\"--multires\", type=int, default=10,\n                        help='log2 of max freq for positional encoding (3D location)')\n    parser.add_argument(\"--multires_views\", type=int, default=4,\n                        help='log2 of max freq for positional encoding (2D direction)')\n    parser.add_argument(\"--raw_noise_std\", type=float, default=0.,\n                        help='std dev of noise added to regularize sigma_a output, 1e0 recommended')\n    parser.add_argument(\"--render_only\", action='store_true',\n                        help='do not optimize, reload weights and render out render_poses path')\n\n    # dataset options\n    parser.add_argument(\"--dataset_type\", type=str, default='llff',\n                        help='options: llff')\n\n    # llff flags\n    parser.add_argument(\"--factor\", type=int, default=8,\n                        help='downsample factor for LLFF images')\n    parser.add_argument(\"--no_ndc\", action='store_true',\n                        help='do not use normalized device coordinates (set for non-forward facing scenes)')\n    parser.add_argument(\"--lindisp\", action='store_true',\n                        help='sampling linearly in disparity rather than depth')\n    parser.add_argument(\"--spherify\", action='store_true',\n                        help='set for spherical 360 scenes')\n\n    # logging/saving options\n    parser.add_argument(\"--i_print\",   type=int, default=500,\n                        help='frequency of console printout and metric logging')\n    parser.add_argument(\"--i_img\",     type=int, default=500,\n                        help='frequency of tensorboard image logging')\n    parser.add_argument(\"--i_weights\", type=int, default=10000,\n                        help='frequency of weight ckpt saving')\n    parser.add_argument(\"--i_testset\", type=int, default=50000,\n                        help='frequency of testset saving')\n    parser.add_argument(\"--i_video\",   type=int, default=50000,\n                        help='frequency of render_poses video saving')\n    parser.add_argument(\"--N_iters\", type=int, default=1000000,\n                        help='number of training iterations')\n\n    # Dynamic NeRF lambdas\n    parser.add_argument(\"--dynamic_loss_lambda\", type=float, default=1.,\n                        help='lambda of dynamic loss')\n    parser.add_argument(\"--static_loss_lambda\", type=float, default=1.,\n                        help='lambda of static loss')\n    parser.add_argument(\"--full_loss_lambda\", type=float, default=3.,\n                        help='lambda of full loss')\n    parser.add_argument(\"--depth_loss_lambda\", type=float, default=0.04,\n                        help='lambda of depth loss')\n    parser.add_argument(\"--order_loss_lambda\", type=float, default=0.1,\n                        help='lambda of order loss')\n    parser.add_argument(\"--flow_loss_lambda\", type=float, default=0.02,\n                        help='lambda of optical flow loss')\n    parser.add_argument(\"--slow_loss_lambda\", type=float, default=0.1,\n                  
      help='lambda of sf slow regularization')\n    parser.add_argument(\"--smooth_loss_lambda\", type=float, default=0.1,\n                        help='lambda of sf smooth regularization')\n    parser.add_argument(\"--consistency_loss_lambda\", type=float, default=0.1,\n                        help='lambda of sf cycle consistency regularization')\n    parser.add_argument(\"--mask_loss_lambda\", type=float, default=0.1,\n                        help='lambda of the mask loss')\n    parser.add_argument(\"--sparse_loss_lambda\", type=float, default=0.1,\n                        help='lambda of sparse loss')\n    parser.add_argument(\"--DyNeRF_blending\", action='store_true',\n                        help='use Dynamic NeRF to predict blending weight')\n    parser.add_argument(\"--pretrain\", action='store_true',\n                        help='Pretrain the StaticneRF')\n    parser.add_argument(\"--ft_path_S\", type=str, default=None,\n                        help='specific weights npy file to reload for StaticNeRF')\n\n    # For rendering teasers\n    parser.add_argument(\"--frame2dolly\", type=int, default=-1,\n                        help='choose frame to perform dolly zoom')\n    parser.add_argument(\"--x_trans_multiplier\", type=float, default=1.,\n                        help='x_trans_multiplier')\n    parser.add_argument(\"--y_trans_multiplier\", type=float, default=0.33,\n                        help='y_trans_multiplier')\n    parser.add_argument(\"--z_trans_multiplier\", type=float, default=5.,\n                        help='z_trans_multiplier')\n    parser.add_argument(\"--num_novelviews\", type=int, default=60,\n                        help='num_novelviews')\n    parser.add_argument(\"--focal_decrease\", type=float, default=200,\n                        help='focal_decrease')\n    return parser\n\n\ndef train():\n\n    parser = config_parser()\n    args = parser.parse_args()\n\n    if args.random_seed is not None:\n        print('Fixing random seed', args.random_seed)\n        np.random.seed(args.random_seed)\n\n    # Load data\n    if args.dataset_type == 'llff':\n        frame2dolly = args.frame2dolly\n        images, invdepths, masks, poses, bds, \\\n        render_poses, render_focals, grids = load_llff_data(args, args.datadir,\n                                                            args.factor,\n                                                            frame2dolly=frame2dolly,\n                                                            recenter=True, bd_factor=.9,\n                                                            spherify=args.spherify)\n\n        hwf = poses[0, :3, -1]\n        poses = poses[:, :3, :4]\n        num_img = float(poses.shape[0])\n        assert len(poses) == len(images)\n        print('Loaded llff', images.shape,\n            render_poses.shape, hwf, args.datadir)\n\n        # Use all views to train\n        i_train = np.array([i for i in np.arange(int(images.shape[0]))])\n\n        print('DEFINING BOUNDS')\n        if args.no_ndc:\n            raise NotImplementedError\n            near = np.ndarray.min(bds) * .9\n            far = np.ndarray.max(bds) * 1.\n        else:\n            near = 0.\n            far = 1.\n        print('NEAR FAR', near, far)\n    else:\n        print('Unknown dataset type', args.dataset_type, 'exiting')\n        return\n\n    # Cast intrinsics to right types\n    H, W, focal = hwf\n    H, W = int(H), int(W)\n    hwf = [H, W, focal]\n\n    # Create log dir and copy the config file\n    basedir = args.basedir\n    
expname = args.expname\n    os.makedirs(os.path.join(basedir, expname), exist_ok=True)\n\n    if not args.render_only:\n        f = os.path.join(basedir, expname, 'args.txt')\n        with open(f, 'w') as file:\n            for arg in sorted(vars(args)):\n                attr = getattr(args, arg)\n                file.write('{} = {}\\n'.format(arg, attr))\n        if args.config is not None:\n            f = os.path.join(basedir, expname, 'config.txt')\n            with open(f, 'w') as file:\n                file.write(open(args.config, 'r').read())\n\n    # Create nerf model\n    render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(args)\n    global_step = start\n\n    bds_dict = {\n        'near': near,\n        'far': far,\n        'num_img': num_img,\n    }\n    render_kwargs_train.update(bds_dict)\n    render_kwargs_test.update(bds_dict)\n\n    # Short circuit if only rendering out from trained model\n    if args.render_only:\n        print('RENDER ONLY')\n        i = start - 1\n\n        # Change time and change view at the same time.\n        time2render = np.concatenate((np.repeat((i_train / float(num_img) * 2. - 1.0), 4),\n                                      np.repeat((i_train / float(num_img) * 2. - 1.0)[::-1][1:-1], 4)))\n        if len(time2render) > len(render_poses):\n            pose2render = np.tile(render_poses, (int(np.ceil(len(time2render) / len(render_poses))), 1, 1))\n            pose2render = pose2render[:len(time2render)]\n            pose2render = torch.Tensor(pose2render)\n        else:\n            time2render = np.tile(time2render, int(np.ceil(len(render_poses) / len(time2render))))\n            time2render = time2render[:len(render_poses)]\n            pose2render = torch.Tensor(render_poses)\n        result_type = 'novelviewtime'\n\n        testsavedir = os.path.join(\n            basedir, expname, result_type + '_{:06d}'.format(i))\n        os.makedirs(testsavedir, exist_ok=True)\n        with torch.no_grad():\n            ret = render_path(pose2render, time2render,\n                              hwf, args.chunk, render_kwargs_test, savedir=testsavedir)\n        moviebase = os.path.join(\n            testsavedir, '{}_{}_{:06d}_'.format(expname, result_type, i))\n        save_res(moviebase, ret)\n\n        # Fix view (first view) and change time.\n        pose2render = torch.Tensor(poses[0:1, ...]).expand([int(num_img), 3, 4])\n        time2render = i_train / float(num_img) * 2. 
- 1.0\n        result_type = 'testset_view000'\n\n        testsavedir = os.path.join(\n            basedir, expname, result_type + '_{:06d}'.format(i))\n        os.makedirs(testsavedir, exist_ok=True)\n        with torch.no_grad():\n            ret = render_path(pose2render, time2render,\n                              hwf, args.chunk, render_kwargs_test, savedir=testsavedir)\n        moviebase = os.path.join(\n            testsavedir, '{}_{}_{:06d}_'.format(expname, result_type, i))\n        save_res(moviebase, ret)\n\n        return\n\n    N_rand = args.N_rand\n\n    # Move training data to GPU\n    images = torch.Tensor(images)\n    invdepths = torch.Tensor(invdepths)\n    masks = 1.0 - torch.Tensor(masks)\n    poses = torch.Tensor(poses)\n    grids = torch.Tensor(grids)\n\n    print('Begin')\n    print('TRAIN views are', i_train)\n\n    # Summary writers\n    writer = SummaryWriter(os.path.join(basedir, 'summaries', expname))\n\n    decay_iteration = max(25, num_img)\n\n    # Pre-train StaticNeRF\n    if args.pretrain:\n        render_kwargs_train.update({'pretrain': True})\n\n        # Pre-train StaticNeRF first and use DynamicNeRF to blend\n        assert args.DyNeRF_blending == True\n\n        if args.ft_path_S is not None and args.ft_path_S != 'None':\n            # Load Pre-trained StaticNeRF\n            ckpt_path = args.ft_path_S\n            print('Reloading StaticNeRF from', ckpt_path)\n            ckpt = torch.load(ckpt_path)\n            render_kwargs_train['network_fn_s'].load_state_dict(ckpt['network_fn_s_state_dict'])\n        else:\n            # Train StaticNeRF from scratch\n            for i in range(args.N_iters):\n                time0 = time.time()\n\n                # No raybatching as we need to take random rays from one image at a time\n                img_i = np.random.choice(i_train)\n                t = img_i / num_img * 2. 
- 1.0 # time of the current frame\n                target = images[img_i]\n                pose = poses[img_i, :3, :4]\n                mask = masks[img_i] # Static region mask\n\n                rays_o, rays_d = get_rays(H, W, focal, torch.Tensor(pose)) # (H, W, 3), (H, W, 3)\n                coords_s = torch.stack((torch.where(mask >= 0.5)), -1)\n                select_inds_s = np.random.choice(coords_s.shape[0], size=[N_rand], replace=False)\n                select_coords = coords_s[select_inds_s]\n\n                def select_batch(value, select_coords=select_coords):\n                    return value[select_coords[:, 0], select_coords[:, 1]]\n\n                rays_o = select_batch(rays_o) # (N_rand, 3)\n                rays_d = select_batch(rays_d) # (N_rand, 3)\n                target_rgb = select_batch(target)\n                batch_mask = select_batch(mask[..., None])\n                batch_rays = torch.stack([rays_o, rays_d], 0)\n\n                #####  Core optimization loop  #####\n                ret = render(t,\n                             False,\n                             H, W, focal,\n                             chunk=args.chunk,\n                             rays=batch_rays,\n                             **render_kwargs_train)\n\n                optimizer.zero_grad()\n\n                # Compute MSE loss between rgb_s and true RGB.\n                img_s_loss = img2mse(ret['rgb_map_s'], target_rgb)\n                psnr_s = mse2psnr(img_s_loss)\n                loss = args.static_loss_lambda * img_s_loss\n\n                loss.backward()\n                optimizer.step()\n\n                # Learning rate decay.\n                decay_rate = 0.1\n                decay_steps = args.lrate_decay\n                new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps))\n                for param_group in optimizer.param_groups:\n                    param_group['lr'] = new_lrate\n\n                dt = time.time() - time0\n\n                print(f\"Pretraining step: {global_step}, Loss: {loss}, Time: {dt}, expname: {expname}\")\n\n                if i % args.i_print == 0:\n                    writer.add_scalar(\"loss\", loss.item(), i)\n                    writer.add_scalar(\"lr\", new_lrate, i)\n                    writer.add_scalar(\"psnr_s\", psnr_s.item(), i)\n\n                if i % args.i_img == 0:\n                    target = images[img_i]\n                    pose = poses[img_i, :3, :4]\n                    mask = masks[img_i]\n\n                    with torch.no_grad():\n                        ret = render(t,\n                                     False,\n                                     H, W, focal,\n                                     chunk=1024*16,\n                                     c2w=pose,\n                                     **render_kwargs_test)\n\n                        # Save out the validation image for Tensorboard-free monitoring\n                        writer.add_image(\"rgb_holdout\", target, global_step=i, dataformats='HWC')\n                        writer.add_image(\"mask\", mask, global_step=i, dataformats='HW')\n                        writer.add_image(\"rgb_s\", torch.clamp(ret['rgb_map_s'], 0., 1.), global_step=i, dataformats='HWC')\n                        writer.add_image(\"depth_s\", normalize_depth(ret['depth_map_s']), global_step=i, dataformats='HW')\n                        writer.add_image(\"acc_s\", ret['acc_map_s'], global_step=i, dataformats='HW')\n\n                global_step += 1\n\n        # Save the 
pretrained weight\n        torch.save({\n            'global_step': global_step,\n            'network_fn_s_state_dict': render_kwargs_train['network_fn_s'].state_dict(),\n            'optimizer_state_dict': optimizer.state_dict(),\n        }, os.path.join(basedir, expname, 'Pretrained_S.tar'))\n\n        # Reset\n        render_kwargs_train.update({'pretrain': False})\n        global_step = start\n\n        # Fix the StaticNeRF and only train the DynamicNeRF\n        grad_vars_d = list(render_kwargs_train['network_fn_d'].parameters())\n        optimizer = torch.optim.Adam(params=grad_vars_d, lr=args.lrate, betas=(0.9, 0.999))\n\n    for i in range(start, args.N_iters):\n        time0 = time.time()\n\n        # Use frames at t-2, t-1, t, t+1, t+2 (adapted from NSFF)\n        if i < decay_iteration * 2000:\n            chain_5frames = False\n        else:\n            chain_5frames = True\n\n        # Lambda decay.\n        Temp = 1. / (10 ** (i // (decay_iteration * 1000)))\n\n        if i % (decay_iteration * 1000) == 0:\n            torch.cuda.empty_cache()\n\n        # No raybatching as we need to take random rays from one image at a time\n        img_i = np.random.choice(i_train)\n        t = img_i / num_img * 2. - 1.0 # time of the current frame\n        target = images[img_i]\n        pose = poses[img_i, :3, :4]\n        mask = masks[img_i] # Static region mask\n        invdepth = invdepths[img_i]\n        grid = grids[img_i]\n\n        rays_o, rays_d = get_rays(H, W, focal, torch.Tensor(pose)) # (H, W, 3), (H, W, 3)\n        coords_d = torch.stack((torch.where(mask < 0.5)), -1)\n        coords_s = torch.stack((torch.where(mask >= 0.5)), -1)\n        coords = torch.stack((torch.where(mask > -1)), -1)\n\n        # Evenly sample dynamic region and static region\n        select_inds_d = np.random.choice(coords_d.shape[0], size=[min(len(coords_d), N_rand//2)], replace=False)\n        select_inds_s = np.random.choice(coords_s.shape[0], size=[N_rand//2], replace=False)\n        select_coords = torch.cat([coords_s[select_inds_s],\n                                   coords_d[select_inds_d]], 0)\n\n        def select_batch(value, select_coords=select_coords):\n            return value[select_coords[:, 0], select_coords[:, 1]]\n\n        rays_o = select_batch(rays_o) # (N_rand, 3)\n        rays_d = select_batch(rays_d) # (N_rand, 3)\n        target_rgb = select_batch(target)\n        batch_grid = select_batch(grid) # (N_rand, 8)\n        batch_mask = select_batch(mask[..., None])\n        batch_invdepth = select_batch(invdepth)\n        batch_rays = torch.stack([rays_o, rays_d], 0)\n\n        #####  Core optimization loop  #####\n        ret = render(t,\n                     chain_5frames,\n                     H, W, focal,\n                     chunk=args.chunk,\n                     rays=batch_rays,\n                     **render_kwargs_train)\n\n        optimizer.zero_grad()\n        loss = 0\n        loss_dict = {}\n\n        # Compute MSE loss between rgb_full and true RGB.\n        img_loss = img2mse(ret['rgb_map_full'], target_rgb)\n        psnr = mse2psnr(img_loss)\n        loss_dict['psnr'] = psnr\n        loss_dict['img_loss'] = img_loss\n        loss += args.full_loss_lambda * loss_dict['img_loss']\n\n        # Compute MSE loss between rgb_s and true RGB.\n        img_s_loss = img2mse(ret['rgb_map_s'], target_rgb, batch_mask)\n        psnr_s = mse2psnr(img_s_loss)\n        loss_dict['psnr_s'] = psnr_s\n        loss_dict['img_s_loss'] = img_s_loss\n        loss += 
args.static_loss_lambda * loss_dict['img_s_loss']\n\n        # Compute MSE loss between rgb_d and true RGB.\n        img_d_loss = img2mse(ret['rgb_map_d'], target_rgb)\n        psnr_d = mse2psnr(img_d_loss)\n        loss_dict['psnr_d'] = psnr_d\n        loss_dict['img_d_loss'] = img_d_loss\n        loss += args.dynamic_loss_lambda * loss_dict['img_d_loss']\n\n        # Compute MSE loss between rgb_d_f and true RGB.\n        img_d_f_loss = img2mse(ret['rgb_map_d_f'], target_rgb)\n        psnr_d_f = mse2psnr(img_d_f_loss)\n        loss_dict['psnr_d_f'] = psnr_d_f\n        loss_dict['img_d_f_loss'] = img_d_f_loss\n        loss += args.dynamic_loss_lambda * loss_dict['img_d_f_loss']\n\n        # Compute MSE loss between rgb_d_b and true RGB.\n        img_d_b_loss = img2mse(ret['rgb_map_d_b'], target_rgb)\n        psnr_d_b = mse2psnr(img_d_b_loss)\n        loss_dict['psnr_d_b'] = psnr_d_b\n        loss_dict['img_d_b_loss'] = img_d_b_loss\n        loss += args.dynamic_loss_lambda * loss_dict['img_d_b_loss']\n\n        # Motion loss.\n        # Compute EPE between induced flow and true flow (forward flow).\n        # The last frame does not have forward flow.\n        if img_i < num_img - 1:\n            pts_f = ret['raw_pts_f']\n            weight = ret['weights_d']\n            pose_f = poses[img_i + 1, :3, :4]\n            induced_flow_f = induce_flow(H, W, focal, pose_f, weight, pts_f, batch_grid[..., :2])\n            flow_f_loss = img2mae(induced_flow_f, batch_grid[:, 2:4], batch_grid[:, 4:5])\n            loss_dict['flow_f_loss'] = flow_f_loss\n            loss += args.flow_loss_lambda * Temp * loss_dict['flow_f_loss']\n\n        # Compute EPE between induced flow and true flow (backward flow).\n        # The first frame does not have backward flow.\n        if img_i > 0:\n            pts_b = ret['raw_pts_b']\n            weight = ret['weights_d']\n            pose_b = poses[img_i - 1, :3, :4]\n            induced_flow_b = induce_flow(H, W, focal, pose_b, weight, pts_b, batch_grid[..., :2])\n            flow_b_loss = img2mae(induced_flow_b, batch_grid[:, 5:7], batch_grid[:, 7:8])\n            loss_dict['flow_b_loss'] = flow_b_loss\n            loss += args.flow_loss_lambda * Temp * loss_dict['flow_b_loss']\n\n        # Slow scene flow. The forward and backward sceneflow should be small.\n        slow_loss = L1(ret['sceneflow_b']) + L1(ret['sceneflow_f'])\n        loss_dict['slow_loss'] = slow_loss\n        loss += args.slow_loss_lambda * loss_dict['slow_loss']\n\n        # Smooth scene flow. The summation of the forward and backward sceneflow should be small.\n        smooth_loss = compute_sf_smooth_loss(ret['raw_pts'],\n                                             ret['raw_pts_f'],\n                                             ret['raw_pts_b'],\n                                             H, W, focal)\n        loss_dict['smooth_loss'] = smooth_loss\n        loss += args.smooth_loss_lambda * loss_dict['smooth_loss']\n\n        # Spatial smooth scene flow. 
(loss adapted from NSFF)\n        sp_smooth_loss = compute_sf_smooth_s_loss(ret['raw_pts'], ret['raw_pts_f'], H, W, focal) \\\n                       + compute_sf_smooth_s_loss(ret['raw_pts'], ret['raw_pts_b'], H, W, focal)\n        loss_dict['sp_smooth_loss'] = sp_smooth_loss\n        loss += args.smooth_loss_lambda * loss_dict['sp_smooth_loss']\n\n        # Consistency loss.\n        consistency_loss = L1(ret['sceneflow_f'] + ret['sceneflow_f_b']) + \\\n                           L1(ret['sceneflow_b'] + ret['sceneflow_b_f'])\n        loss_dict['consistency_loss'] = consistency_loss\n        loss += args.consistency_loss_lambda * loss_dict['consistency_loss']\n\n        # Mask loss.\n        mask_loss = L1(ret['blending'][batch_mask[:, 0].type(torch.bool)]) + \\\n                    img2mae(ret['dynamicness_map'][..., None], 1 - batch_mask)\n        loss_dict['mask_loss'] = mask_loss\n        if i < decay_iteration * 1000:\n            loss += args.mask_loss_lambda * loss_dict['mask_loss']\n\n        # Sparsity loss.\n        sparse_loss = entropy(ret['weights_d']) + entropy(ret['blending'])\n        loss_dict['sparse_loss'] = sparse_loss\n        loss += args.sparse_loss_lambda * loss_dict['sparse_loss']\n\n        # Depth constraint\n        # Depth in NDC space equals to negative disparity in Euclidean space.\n        depth_loss = compute_depth_loss(ret['depth_map_d'], -batch_invdepth)\n        loss_dict['depth_loss'] = depth_loss\n        loss += args.depth_loss_lambda * Temp * loss_dict['depth_loss']\n\n        # Order loss\n        order_loss = torch.mean(torch.square(ret['depth_map_d'][batch_mask[:, 0].type(torch.bool)] - \\\n                                             ret['depth_map_s'].detach()[batch_mask[:, 0].type(torch.bool)]))\n        loss_dict['order_loss'] = order_loss\n        loss += args.order_loss_lambda * loss_dict['order_loss']\n\n        sf_smooth_loss = compute_sf_smooth_loss(ret['raw_pts_b'],\n                                                ret['raw_pts'],\n                                                ret['raw_pts_b_b'],\n                                                H, W, focal) + \\\n                         compute_sf_smooth_loss(ret['raw_pts_f'],\n                                                ret['raw_pts_f_f'],\n                                                ret['raw_pts'],\n                                                H, W, focal)\n        loss_dict['sf_smooth_loss'] = sf_smooth_loss\n        loss += args.smooth_loss_lambda * loss_dict['sf_smooth_loss']\n\n        if chain_5frames:\n            img_d_b_b_loss = img2mse(ret['rgb_map_d_b_b'], target_rgb)\n            loss_dict['img_d_b_b_loss'] = img_d_b_b_loss\n            loss += args.dynamic_loss_lambda * loss_dict['img_d_b_b_loss']\n\n            img_d_f_f_loss = img2mse(ret['rgb_map_d_f_f'], target_rgb)\n            loss_dict['img_d_f_f_loss'] = img_d_f_f_loss\n            loss += args.dynamic_loss_lambda * loss_dict['img_d_f_f_loss']\n\n        loss.backward()\n        optimizer.step()\n\n        # Learning rate decay.\n        decay_rate = 0.1\n        decay_steps = args.lrate_decay\n        new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps))\n        for param_group in optimizer.param_groups:\n            param_group['lr'] = new_lrate\n\n        dt = time.time() - time0\n\n        print(f\"Step: {global_step}, Loss: {loss}, Time: {dt}, chain_5frames: {chain_5frames}, expname: {expname}\")\n\n        # Rest is logging\n        if i % args.i_weights==0:\n            path = 
os.path.join(basedir, expname, '{:06d}.tar'.format(i))\n\n            if args.N_importance > 0:\n                raise NotImplementedError\n            else:\n                torch.save({\n                    'global_step': global_step,\n                    'network_fn_d_state_dict': render_kwargs_train['network_fn_d'].state_dict(),\n                    'network_fn_s_state_dict': render_kwargs_train['network_fn_s'].state_dict(),\n                    'optimizer_state_dict': optimizer.state_dict(),\n                }, path)\n\n            print('Saved weights at', path)\n\n        if i % args.i_video == 0 and i > 0:\n\n            # Change time and change view at the same time.\n            time2render = np.concatenate((np.repeat((i_train / float(num_img) * 2. - 1.0), 4),\n                                          np.repeat((i_train / float(num_img) * 2. - 1.0)[::-1][1:-1], 4)))\n            if len(time2render) > len(render_poses):\n                pose2render = np.tile(render_poses, (int(np.ceil(len(time2render) / len(render_poses))), 1, 1))\n                pose2render = pose2render[:len(time2render)]\n                pose2render = torch.Tensor(pose2render)\n            else:\n                time2render = np.tile(time2render, int(np.ceil(len(render_poses) / len(time2render))))\n                time2render = time2render[:len(render_poses)]\n                pose2render = torch.Tensor(render_poses)\n            result_type = 'novelviewtime'\n\n            testsavedir = os.path.join(\n                basedir, expname, result_type + '_{:06d}'.format(i))\n            os.makedirs(testsavedir, exist_ok=True)\n            with torch.no_grad():\n                ret = render_path(pose2render, time2render,\n                                  hwf, args.chunk, render_kwargs_test, savedir=testsavedir)\n            moviebase = os.path.join(\n                testsavedir, '{}_{}_{:06d}_'.format(expname, result_type, i))\n            save_res(moviebase, ret)\n\n        if i % args.i_testset == 0 and i > 0:\n\n            # Change view and time.\n            pose2render = torch.Tensor(poses)\n            time2render = i_train / float(num_img) * 2. - 1.0\n            result_type = 'testset'\n\n            testsavedir = os.path.join(\n                basedir, expname, result_type + '_{:06d}'.format(i))\n            os.makedirs(testsavedir, exist_ok=True)\n            with torch.no_grad():\n                ret = render_path(pose2render, time2render,\n                                  hwf, args.chunk, render_kwargs_test, savedir=testsavedir,\n                                  flows_gt_f=grids[:, :, :, 2:4], flows_gt_b=grids[:, :, :, 5:7])\n            moviebase = os.path.join(\n                testsavedir, '{}_{}_{:06d}_'.format(expname, result_type, i))\n            save_res(moviebase, ret)\n\n            # Fix view (first view) and change time.\n            pose2render = torch.Tensor(poses[0:1, ...].expand([int(num_img), 3, 4]))\n            time2render = i_train / float(num_img) * 2. 
- 1.0\n            result_type = 'testset_view000'\n\n            testsavedir = os.path.join(\n                basedir, expname, result_type + '_{:06d}'.format(i))\n            os.makedirs(testsavedir, exist_ok=True)\n            with torch.no_grad():\n                ret = render_path(pose2render, time2render,\n                                  hwf, args.chunk, render_kwargs_test, savedir=testsavedir)\n            moviebase = os.path.join(\n                testsavedir, '{}_{}_{:06d}_'.format(expname, result_type, i))\n            save_res(moviebase, ret)\n\n            # Fix time (the first timestamp) and change view.\n            pose2render = torch.Tensor(poses)\n            time2render = np.tile(i_train[0], [int(num_img)]) / float(num_img) * 2. - 1.0\n            result_type = 'testset_time000'\n\n            testsavedir = os.path.join(\n                basedir, expname, result_type + '_{:06d}'.format(i))\n            os.makedirs(testsavedir, exist_ok=True)\n            with torch.no_grad():\n                ret = render_path(pose2render, time2render,\n                                  hwf, args.chunk, render_kwargs_test, savedir=testsavedir)\n            moviebase = os.path.join(\n                testsavedir, '{}_{}_{:06d}_'.format(expname, result_type, i))\n            save_res(moviebase, ret)\n\n        if i % args.i_print == 0:\n            writer.add_scalar(\"loss\", loss.item(), i)\n            writer.add_scalar(\"lr\", new_lrate, i)\n            writer.add_scalar(\"Temp\", Temp, i)\n            for loss_key in loss_dict:\n                writer.add_scalar(loss_key, loss_dict[loss_key].item(), i)\n\n        if i % args.i_img == 0:\n            # Log a rendered training view to Tensorboard.\n            # img_i = np.random.choice(i_train[1:-1])\n            target = images[img_i]\n            pose = poses[img_i, :3, :4]\n            mask = masks[img_i]\n            grid = grids[img_i]\n            invdepth = invdepths[img_i]\n\n            flow_f_img = flow_to_image(grid[..., 2:4].cpu().numpy())\n            flow_b_img = flow_to_image(grid[..., 5:7].cpu().numpy())\n\n            with torch.no_grad():\n                ret = render(t,\n                             False,\n                             H, W, focal,\n                             chunk=1024*16,\n                             c2w=pose,\n                             **render_kwargs_test)\n\n                # The last frame does not have forward flow.\n                pose_f = poses[min(img_i + 1, int(num_img) - 1), :3, :4]\n                induced_flow_f = induce_flow(H, W, focal, pose_f, ret['weights_d'], ret['raw_pts_f'], grid[..., :2])\n\n                # The first frame does not have backward flow.\n                pose_b = poses[max(img_i - 1, 0), :3, :4]\n                induced_flow_b = induce_flow(H, W, focal, pose_b, ret['weights_d'], ret['raw_pts_b'], grid[..., :2])\n\n                induced_flow_f_img = flow_to_image(induced_flow_f.cpu().numpy())\n                induced_flow_b_img = flow_to_image(induced_flow_b.cpu().numpy())\n\n                psnr = mse2psnr(img2mse(ret['rgb_map_full'], target))\n\n                # Save out the validation image for Tensorboard-free monitoring\n                testimgdir = os.path.join(basedir, expname, 'tboard_val_imgs')\n                if i == 0:\n                    os.makedirs(testimgdir, exist_ok=True)\n                imageio.imwrite(os.path.join(testimgdir, '{:06d}.png'.format(i)), to8b(ret['rgb_map_full'].cpu().numpy()))\n\n                
writer.add_scalar(\"psnr_holdout\", psnr.item(), i)\n                writer.add_image(\"rgb_holdout\", target, global_step=i, dataformats='HWC')\n                writer.add_image(\"mask\", mask, global_step=i, dataformats='HW')\n                writer.add_image(\"disp\", torch.clamp(invdepth / percentile(invdepth, 97), 0., 1.), global_step=i, dataformats='HW')\n\n                writer.add_image(\"rgb\", torch.clamp(ret['rgb_map_full'], 0., 1.), global_step=i, dataformats='HWC')\n                writer.add_image(\"depth\", normalize_depth(ret['depth_map_full']), global_step=i, dataformats='HW')\n                writer.add_image(\"acc\", ret['acc_map_full'], global_step=i, dataformats='HW')\n\n                writer.add_image(\"rgb_s\", torch.clamp(ret['rgb_map_s'], 0., 1.), global_step=i, dataformats='HWC')\n                writer.add_image(\"depth_s\", normalize_depth(ret['depth_map_s']), global_step=i, dataformats='HW')\n                writer.add_image(\"acc_s\", ret['acc_map_s'], global_step=i, dataformats='HW')\n\n                writer.add_image(\"rgb_d\", torch.clamp(ret['rgb_map_d'], 0., 1.), global_step=i, dataformats='HWC')\n                writer.add_image(\"depth_d\", normalize_depth(ret['depth_map_d']), global_step=i, dataformats='HW')\n                writer.add_image(\"acc_d\", ret['acc_map_d'], global_step=i, dataformats='HW')\n\n                writer.add_image(\"induced_flow_f\", induced_flow_f_img, global_step=i, dataformats='HWC')\n                writer.add_image(\"induced_flow_b\", induced_flow_b_img, global_step=i, dataformats='HWC')\n                writer.add_image(\"flow_f_gt\", flow_f_img, global_step=i, dataformats='HWC')\n                writer.add_image(\"flow_b_gt\", flow_b_img, global_step=i, dataformats='HWC')\n\n                writer.add_image(\"dynamicness\", ret['dynamicness_map'], global_step=i, dataformats='HW')\n\n        global_step += 1\n\n\nif __name__ == '__main__':\n    torch.set_default_tensor_type('torch.cuda.FloatTensor')\n    train()\n"
  },
  {
    "path": "run_nerf_helpers.py",
    "content": "import os\nimport torch\nimport imageio\nimport numpy as np\nimport torch.nn as nn\nimport torch.nn.functional as F\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\n# Misc utils\ndef img2mse(x, y, M=None):\n    if M == None:\n        return torch.mean((x - y) ** 2)\n    else:\n        return torch.sum((x - y) ** 2 * M) / (torch.sum(M) + 1e-8) / x.shape[-1]\n\n\ndef img2mae(x, y, M=None):\n    if M == None:\n        return torch.mean(torch.abs(x - y))\n    else:\n        return torch.sum(torch.abs(x - y) * M) / (torch.sum(M) + 1e-8) / x.shape[-1]\n\n\ndef L1(x, M=None):\n    if M == None:\n        return torch.mean(torch.abs(x))\n    else:\n        return torch.sum(torch.abs(x) * M) / (torch.sum(M) + 1e-8) / x.shape[-1]\n\n\ndef L2(x, M=None):\n    if M == None:\n        return torch.mean(x ** 2)\n    else:\n        return torch.sum((x ** 2) * M) / (torch.sum(M) + 1e-8) / x.shape[-1]\n\n\ndef entropy(x):\n    return -torch.sum(x * torch.log(x + 1e-19)) / x.shape[0]\n\n\ndef mse2psnr(x): return -10. * torch.log(x) / torch.log(torch.Tensor([10.]))\n\n\ndef to8b(x): return (255 * np.clip(x, 0, 1)).astype(np.uint8)\n\n\nclass Embedder:\n\n    def __init__(self, **kwargs):\n\n        self.kwargs = kwargs\n        self.create_embedding_fn()\n\n    def create_embedding_fn(self):\n\n        embed_fns = []\n        d = self.kwargs['input_dims']\n        out_dim = 0\n        if self.kwargs['include_input']:\n            embed_fns.append(lambda x: x)\n            out_dim += d\n\n        max_freq = self.kwargs['max_freq_log2']\n        N_freqs = self.kwargs['num_freqs']\n\n        if self.kwargs['log_sampling']:\n            freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs)\n        else:\n            freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs)\n\n        for freq in freq_bands:\n            for p_fn in self.kwargs['periodic_fns']:\n                embed_fns.append(lambda x, p_fn=p_fn,\n                                 freq=freq : p_fn(x * freq))\n                out_dim += d\n\n        self.embed_fns = embed_fns\n        self.out_dim = out_dim\n\n    def embed(self, inputs):\n        return torch.cat([fn(inputs) for fn in self.embed_fns], -1)\n\n\ndef get_embedder(multires, i=0, input_dims=3):\n\n    if i == -1:\n        return nn.Identity(), 3\n\n    embed_kwargs = {\n        'include_input': True,\n        'input_dims': input_dims,\n        'max_freq_log2': multires-1,\n        'num_freqs': multires,\n        'log_sampling': True,\n        'periodic_fns': [torch.sin, torch.cos],\n    }\n\n    embedder_obj = Embedder(**embed_kwargs)\n    def embed(x, eo=embedder_obj): return eo.embed(x)\n    return embed, embedder_obj.out_dim\n\n\n# Dynamic NeRF model architecture\nclass NeRF_d(nn.Module):\n    def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, output_ch=4, skips=[4], use_viewdirsDyn=True):\n        \"\"\"\n        \"\"\"\n        super(NeRF_d, self).__init__()\n        self.D = D\n        self.W = W\n        self.input_ch = input_ch\n        self.input_ch_views = input_ch_views\n        self.skips = skips\n        self.use_viewdirsDyn = use_viewdirsDyn\n\n        self.pts_linears = nn.ModuleList(\n            [nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch, W) for i in range(D-1)])\n\n        self.views_linears = nn.ModuleList([nn.Linear(input_ch_views + W, W//2)])\n\n        if self.use_viewdirsDyn:\n            self.feature_linear = nn.Linear(W, W)\n            
self.alpha_linear = nn.Linear(W, 1)\n            self.rgb_linear = nn.Linear(W//2, 3)\n        else:\n            self.output_linear = nn.Linear(W, output_ch)\n\n        self.sf_linear = nn.Linear(W, 6)\n        self.weight_linear = nn.Linear(W, 1)\n\n    def forward(self, x):\n        input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1)\n        h = input_pts\n        for i, l in enumerate(self.pts_linears):\n            h = self.pts_linears[i](h)\n            h = F.relu(h)\n            if i in self.skips:\n                h = torch.cat([input_pts, h], -1)\n\n        # Scene flow should be unbounded. However, in NDC space the coordinate is\n        # bounded in [-1, 1].\n        sf = torch.tanh(self.sf_linear(h))\n        blending = torch.sigmoid(self.weight_linear(h))\n\n        if self.use_viewdirsDyn:\n            alpha = self.alpha_linear(h)\n            feature = self.feature_linear(h)\n            h = torch.cat([feature, input_views], -1)\n\n            for i, l in enumerate(self.views_linears):\n                h = self.views_linears[i](h)\n                h = F.relu(h)\n\n            rgb = self.rgb_linear(h)\n            outputs = torch.cat([rgb, alpha], -1)\n        else:\n            outputs = self.output_linear(h)\n\n        return torch.cat([outputs, sf, blending], dim=-1)\n\n\n# Static NeRF model architecture\nclass NeRF_s(nn.Module):\n    def __init__(self, D=8, W=256, input_ch=3, input_ch_views=3, output_ch=4, skips=[4], use_viewdirs=True):\n        \"\"\"\n        \"\"\"\n        super(NeRF_s, self).__init__()\n        self.D = D\n        self.W = W\n        self.input_ch = input_ch\n        self.input_ch_views = input_ch_views\n        self.skips = skips\n        self.use_viewdirs = use_viewdirs\n\n        self.pts_linears = nn.ModuleList(\n            [nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in self.skips else nn.Linear(W + input_ch, W) for i in range(D-1)])\n\n        self.views_linears = nn.ModuleList([nn.Linear(input_ch_views + W, W//2)])\n\n        if self.use_viewdirs:\n            self.feature_linear = nn.Linear(W, W)\n            self.alpha_linear = nn.Linear(W, 1)\n            self.rgb_linear = nn.Linear(W//2, 3)\n        else:\n            self.output_linear = nn.Linear(W, output_ch)\n\n        self.weight_linear = nn.Linear(W, 1)\n\n    def forward(self, x):\n        input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1)\n        h = input_pts\n        for i, l in enumerate(self.pts_linears):\n            h = self.pts_linears[i](h)\n            h = F.relu(h)\n            if i in self.skips:\n                h = torch.cat([input_pts, h], -1)\n\n        blending = torch.sigmoid(self.weight_linear(h))\n        if self.use_viewdirs:\n            alpha = self.alpha_linear(h)\n            feature = self.feature_linear(h)\n            h = torch.cat([feature, input_views], -1)\n\n            for i, l in enumerate(self.views_linears):\n                h = self.views_linears[i](h)\n                h = F.relu(h)\n\n            rgb = self.rgb_linear(h)\n            outputs = torch.cat([rgb, alpha], -1)\n        else:\n            outputs = self.output_linear(h)\n\n        return torch.cat([outputs, blending], -1)\n\n\ndef batchify(fn, chunk):\n    \"\"\"Constructs a version of 'fn' that applies to smaller batches.\n    \"\"\"\n    if chunk is None:\n        return fn\n\n    def ret(inputs):\n        return torch.cat([fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0)\n    return 
ret\n\n\ndef run_network(inputs, viewdirs, fn, embed_fn, embeddirs_fn, netchunk=1024*64):\n    \"\"\"Prepares inputs and applies network 'fn'.\n    \"\"\"\n\n    inputs_flat = torch.reshape(inputs, [-1, inputs.shape[-1]])\n\n    embedded = embed_fn(inputs_flat)\n    if viewdirs is not None:\n        input_dirs = viewdirs[:, None].expand(inputs[:, :, :3].shape)\n        input_dirs_flat = torch.reshape(input_dirs, [-1, input_dirs.shape[-1]])\n        embedded_dirs = embeddirs_fn(input_dirs_flat)\n        embedded = torch.cat([embedded, embedded_dirs], -1)\n\n    outputs_flat = batchify(fn, netchunk)(embedded)\n    outputs = torch.reshape(outputs_flat, list(\n        inputs.shape[:-1]) + [outputs_flat.shape[-1]])\n    return outputs\n\n\ndef create_nerf(args):\n    \"\"\"Instantiate NeRF's MLP model.\n    \"\"\"\n\n    embed_fn_d, input_ch_d = get_embedder(args.multires, args.i_embed, 4)\n    # 10 * 2 * 4 + 4 = 84\n    # L * (sin, cos) * (x, y, z, t) + (x, y, z, t)\n\n    input_ch_views = 0\n    embeddirs_fn = None\n    if args.use_viewdirs:\n        embeddirs_fn, input_ch_views = get_embedder(\n            args.multires_views, args.i_embed, 3)\n        # 4 * 2 * 3 + 3 = 27\n        # L * (sin, cos) * (3 Cartesian viewing direction unit vector from [theta, phi]) + (3 Cartesian viewing direction unit vector from [theta, phi])\n    output_ch = 5 if args.N_importance > 0 else 4\n    skips = [4]\n    model_d = NeRF_d(D=args.netdepth, W=args.netwidth,\n                     input_ch=input_ch_d, output_ch=output_ch, skips=skips,\n                     input_ch_views=input_ch_views,\n                     use_viewdirsDyn=args.use_viewdirsDyn).to(device)\n\n    device_ids = list(range(torch.cuda.device_count()))\n    model_d = torch.nn.DataParallel(model_d, device_ids=device_ids)\n    grad_vars = list(model_d.parameters())\n\n    embed_fn_s, input_ch_s = get_embedder(args.multires, args.i_embed, 3)\n    # 10 * 2 * 3 + 3 = 63\n    # L * (sin, cos) * (x, y, z) + (x, y, z)\n\n    model_s = NeRF_s(D=args.netdepth, W=args.netwidth,\n                     input_ch=input_ch_s, output_ch=output_ch, skips=skips,\n                     input_ch_views=input_ch_views,\n                     use_viewdirs=args.use_viewdirs).to(device)\n\n    model_s = torch.nn.DataParallel(model_s, device_ids=device_ids)\n    grad_vars += list(model_s.parameters())\n\n    model_fine = None\n    if args.N_importance > 0:\n        raise NotImplementedError\n\n    def network_query_fn_d(inputs, viewdirs, network_fn): return run_network(\n        inputs, viewdirs, network_fn,\n        embed_fn=embed_fn_d,\n        embeddirs_fn=embeddirs_fn,\n        netchunk=args.netchunk)\n\n    def network_query_fn_s(inputs, viewdirs, network_fn): return run_network(\n        inputs, viewdirs, network_fn,\n        embed_fn=embed_fn_s,\n        embeddirs_fn=embeddirs_fn,\n        netchunk=args.netchunk)\n\n    render_kwargs_train = {\n        'network_query_fn_d': network_query_fn_d,\n        'network_query_fn_s': network_query_fn_s,\n        'network_fn_d': model_d,\n        'network_fn_s': model_s,\n        'perturb': args.perturb,\n        'N_importance': args.N_importance,\n        'N_samples': args.N_samples,\n        'use_viewdirs': args.use_viewdirs,\n        'raw_noise_std': args.raw_noise_std,\n        'inference': False,\n        'DyNeRF_blending': args.DyNeRF_blending,\n    }\n\n    # NDC only good for LLFF-style forward facing data\n    if args.dataset_type != 'llff' or args.no_ndc:\n        print('Not ndc!')\n        render_kwargs_train['ndc'] 
= False\n        render_kwargs_train['lindisp'] = args.lindisp\n    else:\n        render_kwargs_train['ndc'] = True\n\n    render_kwargs_test = {\n        k: render_kwargs_train[k] for k in render_kwargs_train}\n    render_kwargs_test['perturb'] = False\n    render_kwargs_test['raw_noise_std'] = 0.\n    render_kwargs_test['inference'] = True\n\n    # Create optimizer\n    optimizer = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999))\n\n    start = 0\n    basedir = args.basedir\n    expname = args.expname\n\n    if args.ft_path is not None and args.ft_path != 'None':\n        ckpts = [args.ft_path]\n    else:\n        ckpts = [os.path.join(basedir, expname, f) for f in sorted(os.listdir(os.path.join(basedir, expname))) if 'tar' in f]\n    print('Found ckpts', ckpts)\n    if len(ckpts) > 0 and not args.no_reload:\n        ckpt_path = ckpts[-1]\n        print('Reloading from', ckpt_path)\n        ckpt = torch.load(ckpt_path)\n\n        start = ckpt['global_step'] + 1\n        # optimizer.load_state_dict(ckpt['optimizer_state_dict'])\n        model_d.load_state_dict(ckpt['network_fn_d_state_dict'])\n        model_s.load_state_dict(ckpt['network_fn_s_state_dict'])\n        print('Resetting step to', start)\n\n        if model_fine is not None:\n            raise NotImplementedError\n\n    return render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer\n\n\n# Ray helpers\ndef get_rays(H, W, focal, c2w):\n    \"\"\"Get ray origins, directions from a pinhole camera.\"\"\"\n    i, j = torch.meshgrid(torch.linspace(0, W-1, W), torch.linspace(0, H-1, H)) # pytorch's meshgrid has indexing='ij'\n    i = i.t()\n    j = j.t()\n    dirs = torch.stack([(i-W*.5)/focal, -(j-H*.5)/focal, -torch.ones_like(i)], -1)\n    # Rotate ray directions from camera frame to the world frame\n    rays_d = torch.sum(dirs[..., np.newaxis, :] * c2w[:3, :3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs]\n    # Translate camera frame's origin to the world frame. It is the origin of all rays.\n    rays_o = c2w[:3, -1].expand(rays_d.shape)\n    return rays_o, rays_d\n\n\ndef ndc_rays(H, W, focal, near, rays_o, rays_d):\n    \"\"\"Normalized device coordinate rays.\n    Space such that the canvas is a cube with sides [-1, 1] in each axis.\n    Args:\n      H: int. Height in pixels.\n      W: int. Width in pixels.\n      focal: float. Focal length of pinhole camera.\n      near: float or array of shape[batch_size]. Near depth bound for the scene.\n      rays_o: array of shape [batch_size, 3]. Camera origin.\n      rays_d: array of shape [batch_size, 3]. Ray direction.\n    Returns:\n      rays_o: array of shape [batch_size, 3]. Camera origin in NDC.\n      rays_d: array of shape [batch_size, 3]. Ray direction in NDC.\n    \"\"\"\n    # Shift ray origins to near plane\n    t = -(near + rays_o[..., 2]) / rays_d[..., 2]\n    rays_o = rays_o + t[..., None] * rays_d\n\n    # Projection\n    o0 = -1./(W/(2.*focal)) * rays_o[..., 0] / rays_o[..., 2]\n    o1 = -1./(H/(2.*focal)) * rays_o[..., 1] / rays_o[..., 2]\n    o2 = 1. + 2. * near / rays_o[..., 2]\n\n    d0 = -1./(W/(2.*focal)) * \\\n        (rays_d[..., 0]/rays_d[..., 2] - rays_o[..., 0]/rays_o[..., 2])\n    d1 = -1./(H/(2.*focal)) * \\\n    (rays_d[..., 1]/rays_d[..., 2] - rays_o[..., 1]/rays_o[..., 2])\n    d2 = -2. 
* near / rays_o[..., 2]\n\n    rays_o = torch.stack([o0, o1, o2], -1)\n    rays_d = torch.stack([d0, d1, d2], -1)\n\n    return rays_o, rays_d\n\n\ndef get_grid(H, W, num_img, flows_f, flow_masks_f, flows_b, flow_masks_b):\n\n    # |--------------------|  |--------------------|\n    # |       j            |  |       v            |\n    # |   i   *            |  |   u   *            |\n    # |                    |  |                    |\n    # |--------------------|  |--------------------|\n\n    i, j = np.meshgrid(np.arange(W, dtype=np.float32),\n                       np.arange(H, dtype=np.float32), indexing='xy')\n\n    grid = np.empty((0, H, W, 8), np.float32)\n    for idx in range(num_img):\n        grid = np.concatenate((grid, np.stack([i,\n                                               j,\n                                               flows_f[idx, :, :, 0],\n                                               flows_f[idx, :, :, 1],\n                                               flow_masks_f[idx, :, :],\n                                               flows_b[idx, :, :, 0],\n                                               flows_b[idx, :, :, 1],\n                                               flow_masks_b[idx, :, :]], -1)[None, ...]))\n    return grid\n\n\ndef NDC2world(pts, H, W, f):\n\n    # NDC coordinate to world coordinate\n    pts_z = 2 / (torch.clamp(pts[..., 2:], min=-1., max=1-1e-3) - 1)\n    pts_x = - pts[..., 0:1] * pts_z * W / 2 / f\n    pts_y = - pts[..., 1:2] * pts_z * H / 2 / f\n    pts_world = torch.cat([pts_x, pts_y, pts_z], -1)\n\n    return pts_world\n\n\ndef render_3d_point(H, W, f, pose, weights, pts):\n    \"\"\"Render 3D position along each ray and project it to the image plane.\n    \"\"\"\n\n    c2w = pose\n    w2c = c2w[:3, :3].transpose(0, 1) # same as np.linalg.inv(c2w[:3, :3])\n\n    # Rendered 3D position in NDC coordinate\n    pts_map_NDC = torch.sum(weights[..., None] * pts, -2)\n\n    # NDC coordinate to world coordinate\n    pts_map_world = NDC2world(pts_map_NDC, H, W, f)\n\n    # World coordinate to camera coordinate\n    # Translate\n    pts_map_world = pts_map_world - c2w[:, 3]\n    # Rotate\n    pts_map_cam = torch.sum(pts_map_world[..., None, :] * w2c[:3, :3], -1)\n\n    # Camera coordinate to 2D image coordinate\n    pts_plane = torch.cat([pts_map_cam[..., 0:1] / (- pts_map_cam[..., 2:]) * f + W * .5,\n                         - pts_map_cam[..., 1:2] / (- pts_map_cam[..., 2:]) * f + H * .5],\n                         -1)\n\n    return pts_plane\n\n\ndef induce_flow(H, W, focal, pose_neighbor, weights, pts_3d_neighbor, pts_2d):\n\n    # Render 3D position along each ray and project it to the neighbor frame's image plane.\n    pts_2d_neighbor = render_3d_point(H, W, focal,\n                                      pose_neighbor,\n                                      weights,\n                                      pts_3d_neighbor)\n    induced_flow = pts_2d_neighbor - pts_2d\n\n    return induced_flow\n\n\ndef compute_depth_loss(dyn_depth, gt_depth):\n\n    t_d = torch.median(dyn_depth)\n    s_d = torch.mean(torch.abs(dyn_depth - t_d))\n    dyn_depth_norm = (dyn_depth - t_d) / s_d\n\n    t_gt = torch.median(gt_depth)\n    s_gt = torch.mean(torch.abs(gt_depth - t_gt))\n    gt_depth_norm = (gt_depth - t_gt) / s_gt\n\n    return torch.mean((dyn_depth_norm - gt_depth_norm) ** 2)\n\n\ndef normalize_depth(depth):\n    return torch.clamp(depth / percentile(depth, 97), 0., 1.)\n\n\ndef percentile(t, q):\n    \"\"\"\n    Return the ``q``-th percentile of the 
flattened input tensor's data.\n\n    CAUTION:\n     * Needs PyTorch >= 1.1.0, as ``torch.kthvalue()`` is used.\n     * Values are not interpolated, which corresponds to\n       ``numpy.percentile(..., interpolation=\"nearest\")``.\n\n    :param t: Input tensor.\n    :param q: Percentile to compute, which must be between 0 and 100 inclusive.\n    :return: Resulting value (scalar).\n    \"\"\"\n\n    k = 1 + round(.01 * float(q) * (t.numel() - 1))\n    result = t.view(-1).kthvalue(k).values.item()\n    return result\n\n\ndef save_res(moviebase, ret, fps=None):\n\n    if fps == None:\n        if len(ret['rgbs']) < 25:\n            fps = 4\n        else:\n            fps = 24\n\n    for k in ret:\n        if 'rgbs' in k:\n            imageio.mimwrite(moviebase + k + '.mp4',\n                             to8b(ret[k]), fps=fps, quality=8, macro_block_size=1)\n            # imageio.mimsave(moviebase + k + '.gif',\n            #                  to8b(ret[k]), format='gif', fps=fps)\n        elif 'depths' in k:\n            imageio.mimwrite(moviebase + k + '.mp4',\n                             to8b(ret[k]), fps=fps, quality=8, macro_block_size=1)\n            # imageio.mimsave(moviebase + k + '.gif',\n            #                  to8b(ret[k]), format='gif', fps=fps)\n        elif 'disps' in k:\n            imageio.mimwrite(moviebase + k + '.mp4',\n                             to8b(ret[k] / np.max(ret[k])), fps=fps, quality=8, macro_block_size=1)\n            # imageio.mimsave(moviebase + k + '.gif',\n            #                  to8b(ret[k] / np.max(ret[k])), format='gif', fps=fps)\n        elif 'sceneflow_' in k:\n            imageio.mimwrite(moviebase + k + '.mp4',\n                             to8b(norm_sf(ret[k])), fps=fps, quality=8, macro_block_size=1)\n            # imageio.mimsave(moviebase + k + '.gif',\n            #                  to8b(norm_sf(ret[k])), format='gif', fps=fps)\n        elif 'flows' in k:\n            imageio.mimwrite(moviebase + k + '.mp4',\n                             ret[k], fps=fps, quality=8, macro_block_size=1)\n            # imageio.mimsave(moviebase + k + '.gif',\n            #                  ret[k], format='gif', fps=fps)\n        elif 'dynamicness' in k:\n            imageio.mimwrite(moviebase + k + '.mp4',\n                             to8b(ret[k]), fps=fps, quality=8, macro_block_size=1)\n            # imageio.mimsave(moviebase + k + '.gif',\n            #                  to8b(ret[k]), format='gif', fps=fps)\n        elif 'disocclusions' in k:\n            imageio.mimwrite(moviebase + k + '.mp4',\n                             to8b(ret[k][..., 0]), fps=fps, quality=8, macro_block_size=1)\n            # imageio.mimsave(moviebase + k + '.gif',\n            #                  to8b(ret[k][..., 0]), format='gif', fps=fps)\n        elif 'blending' in k:\n            blending = ret[k][..., None]\n            blending = np.moveaxis(blending, [0, 1, 2, 3], [1, 2, 0, 3])\n            imageio.mimwrite(moviebase + k + '.mp4',\n                             to8b(blending), fps=fps, quality=8, macro_block_size=1)\n            # imageio.mimsave(moviebase + k + '.gif',\n            #                  to8b(blending), format='gif', fps=fps)\n        elif 'weights' in k:\n            imageio.mimwrite(moviebase + k + '.mp4',\n                             to8b(ret[k]), fps=fps, quality=8, macro_block_size=1)\n        else:\n            raise NotImplementedError\n\n\ndef norm_sf_channel(sf_ch):\n\n    # Make sure zero scene flow is not shifted\n    sf_ch[sf_ch >= 0] = 
sf_ch[sf_ch >= 0] / sf_ch.max() / 2\n    sf_ch[sf_ch < 0] = sf_ch[sf_ch < 0] / np.abs(sf_ch.min()) / 2\n    sf_ch = sf_ch + 0.5\n    return sf_ch\n\n\ndef norm_sf(sf):\n\n    sf = np.concatenate((norm_sf_channel(sf[..., 0:1]),\n                         norm_sf_channel(sf[..., 1:2]),\n                         norm_sf_channel(sf[..., 2:3])), -1)\n    sf = np.moveaxis(sf, [0, 1, 2, 3], [1, 2, 0, 3])\n    return sf\n\n\n# Spatial smoothness (adapted from NSFF)\ndef compute_sf_smooth_s_loss(pts1, pts2, H, W, f):\n\n    N_samples = pts1.shape[1]\n\n    # NDC coordinate to world coordinate\n    pts1_world = NDC2world(pts1[..., :int(N_samples * 0.95), :], H, W, f)\n    pts2_world = NDC2world(pts2[..., :int(N_samples * 0.95), :], H, W, f)\n\n    # scene flow in world coordinate\n    scene_flow_world = pts1_world - pts2_world\n\n    return L1(scene_flow_world[..., :-1, :] - scene_flow_world[..., 1:, :])\n\n\n# Temporal smoothness\ndef compute_sf_smooth_loss(pts, pts_f, pts_b, H, W, f):\n\n    N_samples = pts.shape[1]\n\n    pts_world   = NDC2world(pts[..., :int(N_samples * 0.9), :],   H, W, f)\n    pts_f_world = NDC2world(pts_f[..., :int(N_samples * 0.9), :], H, W, f)\n    pts_b_world = NDC2world(pts_b[..., :int(N_samples * 0.9), :], H, W, f)\n\n    # scene flow in world coordinate\n    sceneflow_f = pts_f_world - pts_world\n    sceneflow_b = pts_b_world - pts_world\n\n    # For a 3D point, its forward and backward sceneflow should be opposite.\n    return L2(sceneflow_f + sceneflow_b)\n"
  },
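  {
    "path": "examples/depth_loss_sketch.py",
    "content": "# Hypothetical example file (path and contents are not part of the original repo).\n# A minimal sketch of the scale-and-shift-invariant depth loss and the\n# percentile-based depth normalization used by the helpers above; the random\n# inputs are assumptions for illustration only.\nimport torch\n\n\ndef normalize_md(depth):\n    # Shift by the median and scale by the mean absolute deviation,\n    # mirroring what compute_depth_loss does for rendered and MiDaS depth.\n    t = torch.median(depth)\n    s = torch.mean(torch.abs(depth - t))\n    return (depth - t) / s\n\n\nif __name__ == '__main__':\n    torch.manual_seed(0)\n    gt_depth = torch.rand(1024) * 10.0 + 1.0   # stand-in for a MiDaS depth map\n    dyn_depth = 0.5 * gt_depth + 2.0           # same geometry up to scale and shift\n\n    # The loss is near zero because the two depths differ only by an affine map.\n    loss = torch.mean((normalize_md(dyn_depth) - normalize_md(gt_depth)) ** 2)\n    print('depth loss:', loss.item())\n\n    # 97th-percentile clamp, as in normalize_depth / percentile.\n    k = 1 + round(0.01 * 97 * (gt_depth.numel() - 1))\n    p97 = gt_depth.view(-1).kthvalue(k).values\n    print('normalized depth max:', torch.clamp(gt_depth / p97, 0., 1.).max().item())\n"
  },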
  {
    "path": "utils/RAFT/__init__.py",
    "content": "# from .demo import RAFT_infer\nfrom .raft import RAFT\n"
  },
  {
    "path": "utils/RAFT/corr.py",
    "content": "import torch\nimport torch.nn.functional as F\nfrom .utils.utils import bilinear_sampler, coords_grid\n\ntry:\n    import alt_cuda_corr\nexcept:\n    # alt_cuda_corr is not compiled\n    pass\n\n\nclass CorrBlock:\n    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):\n        self.num_levels = num_levels\n        self.radius = radius\n        self.corr_pyramid = []\n\n        # all pairs correlation\n        corr = CorrBlock.corr(fmap1, fmap2)\n\n        batch, h1, w1, dim, h2, w2 = corr.shape\n        corr = corr.reshape(batch*h1*w1, dim, h2, w2)\n\n        self.corr_pyramid.append(corr)\n        for i in range(self.num_levels-1):\n            corr = F.avg_pool2d(corr, 2, stride=2)\n            self.corr_pyramid.append(corr)\n\n    def __call__(self, coords):\n        r = self.radius\n        coords = coords.permute(0, 2, 3, 1)\n        batch, h1, w1, _ = coords.shape\n\n        out_pyramid = []\n        for i in range(self.num_levels):\n            corr = self.corr_pyramid[i]\n            dx = torch.linspace(-r, r, 2*r+1)\n            dy = torch.linspace(-r, r, 2*r+1)\n            delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device)\n\n            centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i\n            delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)\n            coords_lvl = centroid_lvl + delta_lvl\n\n            corr = bilinear_sampler(corr, coords_lvl)\n            corr = corr.view(batch, h1, w1, -1)\n            out_pyramid.append(corr)\n\n        out = torch.cat(out_pyramid, dim=-1)\n        return out.permute(0, 3, 1, 2).contiguous().float()\n\n    @staticmethod\n    def corr(fmap1, fmap2):\n        batch, dim, ht, wd = fmap1.shape\n        fmap1 = fmap1.view(batch, dim, ht*wd)\n        fmap2 = fmap2.view(batch, dim, ht*wd)\n\n        corr = torch.matmul(fmap1.transpose(1,2), fmap2)\n        corr = corr.view(batch, ht, wd, 1, ht, wd)\n        return corr  / torch.sqrt(torch.tensor(dim).float())\n\n\nclass CorrLayer(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, fmap1, fmap2, coords, r):\n        fmap1 = fmap1.contiguous()\n        fmap2 = fmap2.contiguous()\n        coords = coords.contiguous()\n        ctx.save_for_backward(fmap1, fmap2, coords)\n        ctx.r = r\n        corr, = correlation_cudaz.forward(fmap1, fmap2, coords, ctx.r)\n        return corr\n\n    @staticmethod\n    def backward(ctx, grad_corr):\n        fmap1, fmap2, coords = ctx.saved_tensors\n        grad_corr = grad_corr.contiguous()\n        fmap1_grad, fmap2_grad, coords_grad = \\\n            correlation_cudaz.backward(fmap1, fmap2, coords, grad_corr, ctx.r)\n        return fmap1_grad, fmap2_grad, coords_grad, None\n\n\nclass AlternateCorrBlock:\n    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):\n        self.num_levels = num_levels\n        self.radius = radius\n\n        self.pyramid = [(fmap1, fmap2)]\n        for i in range(self.num_levels):\n            fmap1 = F.avg_pool2d(fmap1, 2, stride=2)\n            fmap2 = F.avg_pool2d(fmap2, 2, stride=2)\n            self.pyramid.append((fmap1, fmap2))\n\n    def __call__(self, coords):\n\n        coords = coords.permute(0, 2, 3, 1)\n        B, H, W, _ = coords.shape\n\n        corr_list = []\n        for i in range(self.num_levels):\n            r = self.radius\n            fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1)\n            fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1)\n\n            coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()\n            corr 
= alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)[0]  # alt_cuda_corr is a compiled extension module, so call its forward() and take the single returned tensor\n            corr_list.append(corr.squeeze(1))\n\n        corr = torch.stack(corr_list, dim=1)\n        corr = corr.reshape(B, -1, H, W)\n        return corr / 16.0\n"
  },
  {
    "path": "utils/RAFT/datasets.py",
    "content": "# Data loading based on https://github.com/NVIDIA/flownet2-pytorch\n\nimport numpy as np\nimport torch\nimport torch.utils.data as data\nimport torch.nn.functional as F\n\nimport os\nimport math\nimport random\nfrom glob import glob\nimport os.path as osp\n\nfrom utils import frame_utils\nfrom utils.augmentor import FlowAugmentor, SparseFlowAugmentor\n\n\nclass FlowDataset(data.Dataset):\n    def __init__(self, aug_params=None, sparse=False):\n        self.augmentor = None\n        self.sparse = sparse\n        if aug_params is not None:\n            if sparse:\n                self.augmentor = SparseFlowAugmentor(**aug_params)\n            else:\n                self.augmentor = FlowAugmentor(**aug_params)\n\n        self.is_test = False\n        self.init_seed = False\n        self.flow_list = []\n        self.image_list = []\n        self.extra_info = []\n\n    def __getitem__(self, index):\n\n        if self.is_test:\n            img1 = frame_utils.read_gen(self.image_list[index][0])\n            img2 = frame_utils.read_gen(self.image_list[index][1])\n            img1 = np.array(img1).astype(np.uint8)[..., :3]\n            img2 = np.array(img2).astype(np.uint8)[..., :3]\n            img1 = torch.from_numpy(img1).permute(2, 0, 1).float()\n            img2 = torch.from_numpy(img2).permute(2, 0, 1).float()\n            return img1, img2, self.extra_info[index]\n\n        if not self.init_seed:\n            worker_info = torch.utils.data.get_worker_info()\n            if worker_info is not None:\n                torch.manual_seed(worker_info.id)\n                np.random.seed(worker_info.id)\n                random.seed(worker_info.id)\n                self.init_seed = True\n\n        index = index % len(self.image_list)\n        valid = None\n        if self.sparse:\n            flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])\n        else:\n            flow = frame_utils.read_gen(self.flow_list[index])\n\n        img1 = frame_utils.read_gen(self.image_list[index][0])\n        img2 = frame_utils.read_gen(self.image_list[index][1])\n\n        flow = np.array(flow).astype(np.float32)\n        img1 = np.array(img1).astype(np.uint8)\n        img2 = np.array(img2).astype(np.uint8)\n\n        # grayscale images\n        if len(img1.shape) == 2:\n            img1 = np.tile(img1[...,None], (1, 1, 3))\n            img2 = np.tile(img2[...,None], (1, 1, 3))\n        else:\n            img1 = img1[..., :3]\n            img2 = img2[..., :3]\n\n        if self.augmentor is not None:\n            if self.sparse:\n                img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid)\n            else:\n                img1, img2, flow = self.augmentor(img1, img2, flow)\n\n        img1 = torch.from_numpy(img1).permute(2, 0, 1).float()\n        img2 = torch.from_numpy(img2).permute(2, 0, 1).float()\n        flow = torch.from_numpy(flow).permute(2, 0, 1).float()\n\n        if valid is not None:\n            valid = torch.from_numpy(valid)\n        else:\n            valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000)\n\n        return img1, img2, flow, valid.float()\n\n\n    def __rmul__(self, v):\n        self.flow_list = v * self.flow_list\n        self.image_list = v * self.image_list\n        return self\n        \n    def __len__(self):\n        return len(self.image_list)\n        \n\nclass MpiSintel(FlowDataset):\n    def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'):\n        super(MpiSintel, 
self).__init__(aug_params)\n        flow_root = osp.join(root, split, 'flow')\n        image_root = osp.join(root, split, dstype)\n\n        if split == 'test':\n            self.is_test = True\n\n        for scene in os.listdir(image_root):\n            image_list = sorted(glob(osp.join(image_root, scene, '*.png')))\n            for i in range(len(image_list)-1):\n                self.image_list += [ [image_list[i], image_list[i+1]] ]\n                self.extra_info += [ (scene, i) ] # scene and frame_id\n\n            if split != 'test':\n                self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo')))\n\n\nclass FlyingChairs(FlowDataset):\n    def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'):\n        super(FlyingChairs, self).__init__(aug_params)\n\n        images = sorted(glob(osp.join(root, '*.ppm')))\n        flows = sorted(glob(osp.join(root, '*.flo')))\n        assert (len(images)//2 == len(flows))\n\n        split_list = np.loadtxt('chairs_split.txt', dtype=np.int32)\n        for i in range(len(flows)):\n            xid = split_list[i]\n            if (split=='training' and xid==1) or (split=='validation' and xid==2):\n                self.flow_list += [ flows[i] ]\n                self.image_list += [ [images[2*i], images[2*i+1]] ]\n\n\nclass FlyingThings3D(FlowDataset):\n    def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'):\n        super(FlyingThings3D, self).__init__(aug_params)\n\n        for cam in ['left']:\n            for direction in ['into_future', 'into_past']:\n                image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*')))\n                image_dirs = sorted([osp.join(f, cam) for f in image_dirs])\n\n                flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*')))\n                flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs])\n\n                for idir, fdir in zip(image_dirs, flow_dirs):\n                    images = sorted(glob(osp.join(idir, '*.png')) )\n                    flows = sorted(glob(osp.join(fdir, '*.pfm')) )\n                    for i in range(len(flows)-1):\n                        if direction == 'into_future':\n                            self.image_list += [ [images[i], images[i+1]] ]\n                            self.flow_list += [ flows[i] ]\n                        elif direction == 'into_past':\n                            self.image_list += [ [images[i+1], images[i]] ]\n                            self.flow_list += [ flows[i+1] ]\n      \n\nclass KITTI(FlowDataset):\n    def __init__(self, aug_params=None, split='training', root='datasets/KITTI'):\n        super(KITTI, self).__init__(aug_params, sparse=True)\n        if split == 'testing':\n            self.is_test = True\n\n        root = osp.join(root, split)\n        images1 = sorted(glob(osp.join(root, 'image_2/*_10.png')))\n        images2 = sorted(glob(osp.join(root, 'image_2/*_11.png')))\n\n        for img1, img2 in zip(images1, images2):\n            frame_id = img1.split('/')[-1]\n            self.extra_info += [ [frame_id] ]\n            self.image_list += [ [img1, img2] ]\n\n        if split == 'training':\n            self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png')))\n\n\nclass HD1K(FlowDataset):\n    def __init__(self, aug_params=None, root='datasets/HD1k'):\n        super(HD1K, self).__init__(aug_params, sparse=True)\n\n        seq_ix = 0\n        while 1:\n            flows = 
sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix)))\n            images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix)))\n\n            if len(flows) == 0:\n                break\n\n            for i in range(len(flows)-1):\n                self.flow_list += [flows[i]]\n                self.image_list += [ [images[i], images[i+1]] ]\n\n            seq_ix += 1\n\n\ndef fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):\n    \"\"\" Create the data loader for the corresponding trainign set \"\"\"\n\n    if args.stage == 'chairs':\n        aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True}\n        train_dataset = FlyingChairs(aug_params, split='training')\n    \n    elif args.stage == 'things':\n        aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True}\n        clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass')\n        final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass')\n        train_dataset = clean_dataset + final_dataset\n\n    elif args.stage == 'sintel':\n        aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True}\n        things = FlyingThings3D(aug_params, dstype='frames_cleanpass')\n        sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')\n        sintel_final = MpiSintel(aug_params, split='training', dstype='final')        \n\n        if TRAIN_DS == 'C+T+K+S+H':\n            kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True})\n            hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True})\n            train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things\n\n        elif TRAIN_DS == 'C+T+K/S':\n            train_dataset = 100*sintel_clean + 100*sintel_final + things\n\n    elif args.stage == 'kitti':\n        aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}\n        train_dataset = KITTI(aug_params, split='training')\n\n    train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, \n        pin_memory=False, shuffle=True, num_workers=4, drop_last=True)\n\n    print('Training with %d image pairs' % len(train_dataset))\n    return train_loader\n\n"
  },
  {
    "path": "utils/RAFT/demo.py",
    "content": "import sys\nimport argparse\nimport os\nimport cv2\nimport glob\nimport numpy as np\nimport torch\nfrom PIL import Image\n\nfrom .raft import RAFT\nfrom .utils import flow_viz\nfrom .utils.utils import InputPadder\n\n\n\nDEVICE = 'cuda'\n\ndef load_image(imfile):\n    img = np.array(Image.open(imfile)).astype(np.uint8)\n    img = torch.from_numpy(img).permute(2, 0, 1).float()\n    return img\n\n\ndef load_image_list(image_files):\n    images = []\n    for imfile in sorted(image_files):\n        images.append(load_image(imfile))\n\n    images = torch.stack(images, dim=0)\n    images = images.to(DEVICE)\n\n    padder = InputPadder(images.shape)\n    return padder.pad(images)[0]\n\n\ndef viz(img, flo):\n    img = img[0].permute(1,2,0).cpu().numpy()\n    flo = flo[0].permute(1,2,0).cpu().numpy()\n\n    # map flow to rgb image\n    flo = flow_viz.flow_to_image(flo)\n    # img_flo = np.concatenate([img, flo], axis=0)\n    img_flo = flo\n\n    cv2.imwrite('/home/chengao/test/flow.png', img_flo[:, :, [2,1,0]])\n    # cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0)\n    # cv2.waitKey()\n\n\ndef demo(args):\n    model = torch.nn.DataParallel(RAFT(args))\n    model.load_state_dict(torch.load(args.model))\n\n    model = model.module\n    model.to(DEVICE)\n    model.eval()\n\n    with torch.no_grad():\n        images = glob.glob(os.path.join(args.path, '*.png')) + \\\n                 glob.glob(os.path.join(args.path, '*.jpg'))\n\n        images = load_image_list(images)\n        for i in range(images.shape[0]-1):\n            image1 = images[i,None]\n            image2 = images[i+1,None]\n\n            flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)\n            viz(image1, flow_up)\n\n\ndef RAFT_infer(args):\n    model = torch.nn.DataParallel(RAFT(args))\n    model.load_state_dict(torch.load(args.model))\n\n    model = model.module\n    model.to(DEVICE)\n    model.eval()\n\n    return model\n"
  },
  {
    "path": "utils/RAFT/extractor.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass ResidualBlock(nn.Module):\n    def __init__(self, in_planes, planes, norm_fn='group', stride=1):\n        super(ResidualBlock, self).__init__()\n  \n        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)\n        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)\n        self.relu = nn.ReLU(inplace=True)\n\n        num_groups = planes // 8\n\n        if norm_fn == 'group':\n            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)\n            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)\n            if not stride == 1:\n                self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)\n        \n        elif norm_fn == 'batch':\n            self.norm1 = nn.BatchNorm2d(planes)\n            self.norm2 = nn.BatchNorm2d(planes)\n            if not stride == 1:\n                self.norm3 = nn.BatchNorm2d(planes)\n        \n        elif norm_fn == 'instance':\n            self.norm1 = nn.InstanceNorm2d(planes)\n            self.norm2 = nn.InstanceNorm2d(planes)\n            if not stride == 1:\n                self.norm3 = nn.InstanceNorm2d(planes)\n\n        elif norm_fn == 'none':\n            self.norm1 = nn.Sequential()\n            self.norm2 = nn.Sequential()\n            if not stride == 1:\n                self.norm3 = nn.Sequential()\n\n        if stride == 1:\n            self.downsample = None\n        \n        else:    \n            self.downsample = nn.Sequential(\n                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)\n\n\n    def forward(self, x):\n        y = x\n        y = self.relu(self.norm1(self.conv1(y)))\n        y = self.relu(self.norm2(self.conv2(y)))\n\n        if self.downsample is not None:\n            x = self.downsample(x)\n\n        return self.relu(x+y)\n\n\n\nclass BottleneckBlock(nn.Module):\n    def __init__(self, in_planes, planes, norm_fn='group', stride=1):\n        super(BottleneckBlock, self).__init__()\n  \n        self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)\n        self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)\n        self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)\n        self.relu = nn.ReLU(inplace=True)\n\n        num_groups = planes // 8\n\n        if norm_fn == 'group':\n            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)\n            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)\n            self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)\n            if not stride == 1:\n                self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)\n        \n        elif norm_fn == 'batch':\n            self.norm1 = nn.BatchNorm2d(planes//4)\n            self.norm2 = nn.BatchNorm2d(planes//4)\n            self.norm3 = nn.BatchNorm2d(planes)\n            if not stride == 1:\n                self.norm4 = nn.BatchNorm2d(planes)\n        \n        elif norm_fn == 'instance':\n            self.norm1 = nn.InstanceNorm2d(planes//4)\n            self.norm2 = nn.InstanceNorm2d(planes//4)\n            self.norm3 = nn.InstanceNorm2d(planes)\n            if not stride == 1:\n                self.norm4 = nn.InstanceNorm2d(planes)\n\n        elif norm_fn == 'none':\n            self.norm1 = nn.Sequential()\n            self.norm2 = 
nn.Sequential()\n            self.norm3 = nn.Sequential()\n            if not stride == 1:\n                self.norm4 = nn.Sequential()\n\n        if stride == 1:\n            self.downsample = None\n        \n        else:    \n            self.downsample = nn.Sequential(\n                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)\n\n\n    def forward(self, x):\n        y = x\n        y = self.relu(self.norm1(self.conv1(y)))\n        y = self.relu(self.norm2(self.conv2(y)))\n        y = self.relu(self.norm3(self.conv3(y)))\n\n        if self.downsample is not None:\n            x = self.downsample(x)\n\n        return self.relu(x+y)\n\nclass BasicEncoder(nn.Module):\n    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):\n        super(BasicEncoder, self).__init__()\n        self.norm_fn = norm_fn\n\n        if self.norm_fn == 'group':\n            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)\n            \n        elif self.norm_fn == 'batch':\n            self.norm1 = nn.BatchNorm2d(64)\n\n        elif self.norm_fn == 'instance':\n            self.norm1 = nn.InstanceNorm2d(64)\n\n        elif self.norm_fn == 'none':\n            self.norm1 = nn.Sequential()\n\n        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)\n        self.relu1 = nn.ReLU(inplace=True)\n\n        self.in_planes = 64\n        self.layer1 = self._make_layer(64,  stride=1)\n        self.layer2 = self._make_layer(96, stride=2)\n        self.layer3 = self._make_layer(128, stride=2)\n\n        # output convolution\n        self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)\n\n        self.dropout = None\n        if dropout > 0:\n            self.dropout = nn.Dropout2d(p=dropout)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):\n                if m.weight is not None:\n                    nn.init.constant_(m.weight, 1)\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n\n    def _make_layer(self, dim, stride=1):\n        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)\n        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)\n        layers = (layer1, layer2)\n        \n        self.in_planes = dim\n        return nn.Sequential(*layers)\n\n\n    def forward(self, x):\n\n        # if input is list, combine batch dimension\n        is_list = isinstance(x, tuple) or isinstance(x, list)\n        if is_list:\n            batch_dim = x[0].shape[0]\n            x = torch.cat(x, dim=0)\n\n        x = self.conv1(x)\n        x = self.norm1(x)\n        x = self.relu1(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n\n        x = self.conv2(x)\n\n        if self.training and self.dropout is not None:\n            x = self.dropout(x)\n\n        if is_list:\n            x = torch.split(x, [batch_dim, batch_dim], dim=0)\n\n        return x\n\n\nclass SmallEncoder(nn.Module):\n    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):\n        super(SmallEncoder, self).__init__()\n        self.norm_fn = norm_fn\n\n        if self.norm_fn == 'group':\n            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)\n            \n        elif self.norm_fn == 'batch':\n            self.norm1 = nn.BatchNorm2d(32)\n\n        elif 
self.norm_fn == 'instance':\n            self.norm1 = nn.InstanceNorm2d(32)\n\n        elif self.norm_fn == 'none':\n            self.norm1 = nn.Sequential()\n\n        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)\n        self.relu1 = nn.ReLU(inplace=True)\n\n        self.in_planes = 32\n        self.layer1 = self._make_layer(32,  stride=1)\n        self.layer2 = self._make_layer(64, stride=2)\n        self.layer3 = self._make_layer(96, stride=2)\n\n        self.dropout = None\n        if dropout > 0:\n            self.dropout = nn.Dropout2d(p=dropout)\n        \n        self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)\n\n        for m in self.modules():\n            if isinstance(m, nn.Conv2d):\n                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):\n                if m.weight is not None:\n                    nn.init.constant_(m.weight, 1)\n                if m.bias is not None:\n                    nn.init.constant_(m.bias, 0)\n\n    def _make_layer(self, dim, stride=1):\n        layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)\n        layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)\n        layers = (layer1, layer2)\n    \n        self.in_planes = dim\n        return nn.Sequential(*layers)\n\n\n    def forward(self, x):\n\n        # if input is list, combine batch dimension\n        is_list = isinstance(x, tuple) or isinstance(x, list)\n        if is_list:\n            batch_dim = x[0].shape[0]\n            x = torch.cat(x, dim=0)\n\n        x = self.conv1(x)\n        x = self.norm1(x)\n        x = self.relu1(x)\n\n        x = self.layer1(x)\n        x = self.layer2(x)\n        x = self.layer3(x)\n        x = self.conv2(x)\n\n        if self.training and self.dropout is not None:\n            x = self.dropout(x)\n\n        if is_list:\n            x = torch.split(x, [batch_dim, batch_dim], dim=0)\n\n        return x\n"
  },
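  {
    "path": "examples/raft_encoder_sketch.py",
    "content": "# Hypothetical example file (not part of the original repo). A minimal smoke\n# test showing that BasicEncoder maps an RGB frame to 1/8-resolution features;\n# it assumes the repo root is on PYTHONPATH so utils.RAFT is importable.\nimport torch\n\nfrom utils.RAFT.extractor import BasicEncoder\n\nif __name__ == '__main__':\n    encoder = BasicEncoder(output_dim=256, norm_fn='instance', dropout=0.0)\n    encoder.eval()\n    frame = torch.rand(1, 3, 64, 64)   # any H and W divisible by 8\n    with torch.no_grad():\n        features = encoder(frame)\n    print(features.shape)              # torch.Size([1, 256, 8, 8])\n"
  },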
  {
    "path": "utils/RAFT/raft.py",
    "content": "import numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom .update import BasicUpdateBlock, SmallUpdateBlock\nfrom .extractor import BasicEncoder, SmallEncoder\nfrom .corr import CorrBlock, AlternateCorrBlock\nfrom .utils.utils import bilinear_sampler, coords_grid, upflow8\n\ntry:\n    autocast = torch.cuda.amp.autocast\nexcept:\n    # dummy autocast for PyTorch < 1.6\n    class autocast:\n        def __init__(self, enabled):\n            pass\n        def __enter__(self):\n            pass\n        def __exit__(self, *args):\n            pass\n\n\nclass RAFT(nn.Module):\n    def __init__(self, args):\n        super(RAFT, self).__init__()\n        self.args = args\n\n        if args.small:\n            self.hidden_dim = hdim = 96\n            self.context_dim = cdim = 64\n            args.corr_levels = 4\n            args.corr_radius = 3\n\n        else:\n            self.hidden_dim = hdim = 128\n            self.context_dim = cdim = 128\n            args.corr_levels = 4\n            args.corr_radius = 4\n\n        if 'dropout' not in args._get_kwargs():\n            args.dropout = 0\n\n        if 'alternate_corr' not in args._get_kwargs():\n            args.alternate_corr = False\n\n        # feature network, context network, and update block\n        if args.small:\n            self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)\n            self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)\n            self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)\n\n        else:\n            self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)\n            self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)\n            self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)\n\n\n    def freeze_bn(self):\n        for m in self.modules():\n            if isinstance(m, nn.BatchNorm2d):\n                m.eval()\n\n    def initialize_flow(self, img):\n        \"\"\" Flow is represented as difference between two coordinate grids flow = coords1 - coords0\"\"\"\n        N, C, H, W = img.shape\n        coords0 = coords_grid(N, H//8, W//8).to(img.device)\n        coords1 = coords_grid(N, H//8, W//8).to(img.device)\n\n        # optical flow computed as difference: flow = coords1 - coords0\n        return coords0, coords1\n\n    def upsample_flow(self, flow, mask):\n        \"\"\" Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination \"\"\"\n        N, _, H, W = flow.shape\n        mask = mask.view(N, 1, 9, 8, 8, H, W)\n        mask = torch.softmax(mask, dim=2)\n\n        up_flow = F.unfold(8 * flow, [3,3], padding=1)\n        up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)\n\n        up_flow = torch.sum(mask * up_flow, dim=2)\n        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)\n        return up_flow.reshape(N, 2, 8*H, 8*W)\n\n\n    def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):\n        \"\"\" Estimate optical flow between pair of frames \"\"\"\n\n        image1 = 2 * (image1 / 255.0) - 1.0\n        image2 = 2 * (image2 / 255.0) - 1.0\n\n        image1 = image1.contiguous()\n        image2 = image2.contiguous()\n\n        hdim = self.hidden_dim\n        cdim = self.context_dim\n\n        # run the feature network\n        with autocast(enabled=self.args.mixed_precision):\n            fmap1, fmap2 = self.fnet([image1, image2])\n\n        
fmap1 = fmap1.float()\n        fmap2 = fmap2.float()\n        if self.args.alternate_corr:\n            corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)\n        else:\n            corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)\n\n        # run the context network\n        with autocast(enabled=self.args.mixed_precision):\n            cnet = self.cnet(image1)\n            net, inp = torch.split(cnet, [hdim, cdim], dim=1)\n            net = torch.tanh(net)\n            inp = torch.relu(inp)\n\n        coords0, coords1 = self.initialize_flow(image1)\n\n        if flow_init is not None:\n            coords1 = coords1 + flow_init\n\n        flow_predictions = []\n        for itr in range(iters):\n            coords1 = coords1.detach()\n            corr = corr_fn(coords1) # index correlation volume\n\n            flow = coords1 - coords0\n            with autocast(enabled=self.args.mixed_precision):\n                net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)\n\n            # F(t+1) = F(t) + \\Delta(t)\n            coords1 = coords1 + delta_flow\n\n            # upsample predictions\n            if up_mask is None:\n                flow_up = upflow8(coords1 - coords0)\n            else:\n                flow_up = self.upsample_flow(coords1 - coords0, up_mask)\n\n            flow_predictions.append(flow_up)\n\n        if test_mode:\n            return coords1 - coords0, flow_up\n\n        return flow_predictions\n"
  },
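  {
    "path": "examples/raft_forward_sketch.py",
    "content": "# Hypothetical example file (not part of the original repo). A minimal sketch\n# that runs RAFT with random weights on random frames just to illustrate the\n# expected input/output shapes; it assumes the repo root is on PYTHONPATH. For\n# real flow, load the downloaded weights as in demo.py / RAFT_infer.\nimport argparse\n\nimport torch\n\nfrom utils.RAFT.raft import RAFT\n\nif __name__ == '__main__':\n    args = argparse.Namespace(small=False, mixed_precision=False, alternate_corr=False)\n    model = RAFT(args)\n    model.eval()\n\n    frame1 = torch.rand(1, 3, 128, 128) * 255.0   # RAFT expects images in [0, 255]\n    frame2 = torch.rand(1, 3, 128, 128) * 255.0\n    with torch.no_grad():\n        flow_low, flow_up = model(frame1, frame2, iters=4, test_mode=True)\n    print(flow_low.shape, flow_up.shape)          # (1, 2, 16, 16) and (1, 2, 128, 128)\n"
  },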
  {
    "path": "utils/RAFT/update.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass FlowHead(nn.Module):\n    def __init__(self, input_dim=128, hidden_dim=256):\n        super(FlowHead, self).__init__()\n        self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)\n        self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)\n        self.relu = nn.ReLU(inplace=True)\n\n    def forward(self, x):\n        return self.conv2(self.relu(self.conv1(x)))\n\nclass ConvGRU(nn.Module):\n    def __init__(self, hidden_dim=128, input_dim=192+128):\n        super(ConvGRU, self).__init__()\n        self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)\n        self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)\n        self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)\n\n    def forward(self, h, x):\n        hx = torch.cat([h, x], dim=1)\n\n        z = torch.sigmoid(self.convz(hx))\n        r = torch.sigmoid(self.convr(hx))\n        q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)))\n\n        h = (1-z) * h + z * q\n        return h\n\nclass SepConvGRU(nn.Module):\n    def __init__(self, hidden_dim=128, input_dim=192+128):\n        super(SepConvGRU, self).__init__()\n        self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))\n        self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))\n        self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))\n\n        self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))\n        self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))\n        self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))\n\n\n    def forward(self, h, x):\n        # horizontal\n        hx = torch.cat([h, x], dim=1)\n        z = torch.sigmoid(self.convz1(hx))\n        r = torch.sigmoid(self.convr1(hx))\n        q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))        \n        h = (1-z) * h + z * q\n\n        # vertical\n        hx = torch.cat([h, x], dim=1)\n        z = torch.sigmoid(self.convz2(hx))\n        r = torch.sigmoid(self.convr2(hx))\n        q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))       \n        h = (1-z) * h + z * q\n\n        return h\n\nclass SmallMotionEncoder(nn.Module):\n    def __init__(self, args):\n        super(SmallMotionEncoder, self).__init__()\n        cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2\n        self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)\n        self.convf1 = nn.Conv2d(2, 64, 7, padding=3)\n        self.convf2 = nn.Conv2d(64, 32, 3, padding=1)\n        self.conv = nn.Conv2d(128, 80, 3, padding=1)\n\n    def forward(self, flow, corr):\n        cor = F.relu(self.convc1(corr))\n        flo = F.relu(self.convf1(flow))\n        flo = F.relu(self.convf2(flo))\n        cor_flo = torch.cat([cor, flo], dim=1)\n        out = F.relu(self.conv(cor_flo))\n        return torch.cat([out, flow], dim=1)\n\nclass BasicMotionEncoder(nn.Module):\n    def __init__(self, args):\n        super(BasicMotionEncoder, self).__init__()\n        cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2\n        self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)\n        self.convc2 = nn.Conv2d(256, 192, 3, padding=1)\n        self.convf1 = nn.Conv2d(2, 128, 7, padding=3)\n        self.convf2 = nn.Conv2d(128, 64, 3, padding=1)\n        self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)\n\n    def 
forward(self, flow, corr):\n        cor = F.relu(self.convc1(corr))\n        cor = F.relu(self.convc2(cor))\n        flo = F.relu(self.convf1(flow))\n        flo = F.relu(self.convf2(flo))\n\n        cor_flo = torch.cat([cor, flo], dim=1)\n        out = F.relu(self.conv(cor_flo))\n        return torch.cat([out, flow], dim=1)\n\nclass SmallUpdateBlock(nn.Module):\n    def __init__(self, args, hidden_dim=96):\n        super(SmallUpdateBlock, self).__init__()\n        self.encoder = SmallMotionEncoder(args)\n        self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)\n        self.flow_head = FlowHead(hidden_dim, hidden_dim=128)\n\n    def forward(self, net, inp, corr, flow):\n        motion_features = self.encoder(flow, corr)\n        inp = torch.cat([inp, motion_features], dim=1)\n        net = self.gru(net, inp)\n        delta_flow = self.flow_head(net)\n\n        return net, None, delta_flow\n\nclass BasicUpdateBlock(nn.Module):\n    def __init__(self, args, hidden_dim=128, input_dim=128):\n        super(BasicUpdateBlock, self).__init__()\n        self.args = args\n        self.encoder = BasicMotionEncoder(args)\n        self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)\n        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)\n\n        self.mask = nn.Sequential(\n            nn.Conv2d(128, 256, 3, padding=1),\n            nn.ReLU(inplace=True),\n            nn.Conv2d(256, 64*9, 1, padding=0))\n\n    def forward(self, net, inp, corr, flow, upsample=True):\n        motion_features = self.encoder(flow, corr)\n        inp = torch.cat([inp, motion_features], dim=1)\n\n        net = self.gru(net, inp)\n        delta_flow = self.flow_head(net)\n\n        # scale mask to balance gradients\n        mask = .25 * self.mask(net)\n        return net, mask, delta_flow\n\n\n\n"
  },
  {
    "path": "utils/RAFT/utils/__init__.py",
    "content": "from .flow_viz import flow_to_image\nfrom .frame_utils import writeFlow\n"
  },
  {
    "path": "utils/RAFT/utils/augmentor.py",
    "content": "import numpy as np\nimport random\nimport math\nfrom PIL import Image\n\nimport cv2\ncv2.setNumThreads(0)\ncv2.ocl.setUseOpenCL(False)\n\nimport torch\nfrom torchvision.transforms import ColorJitter\nimport torch.nn.functional as F\n\n\nclass FlowAugmentor:\n    def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True):\n        \n        # spatial augmentation params\n        self.crop_size = crop_size\n        self.min_scale = min_scale\n        self.max_scale = max_scale\n        self.spatial_aug_prob = 0.8\n        self.stretch_prob = 0.8\n        self.max_stretch = 0.2\n\n        # flip augmentation params\n        self.do_flip = do_flip\n        self.h_flip_prob = 0.5\n        self.v_flip_prob = 0.1\n\n        # photometric augmentation params\n        self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14)\n        self.asymmetric_color_aug_prob = 0.2\n        self.eraser_aug_prob = 0.5\n\n    def color_transform(self, img1, img2):\n        \"\"\" Photometric augmentation \"\"\"\n\n        # asymmetric\n        if np.random.rand() < self.asymmetric_color_aug_prob:\n            img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)\n            img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)\n\n        # symmetric\n        else:\n            image_stack = np.concatenate([img1, img2], axis=0)\n            image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)\n            img1, img2 = np.split(image_stack, 2, axis=0)\n\n        return img1, img2\n\n    def eraser_transform(self, img1, img2, bounds=[50, 100]):\n        \"\"\" Occlusion augmentation \"\"\"\n\n        ht, wd = img1.shape[:2]\n        if np.random.rand() < self.eraser_aug_prob:\n            mean_color = np.mean(img2.reshape(-1, 3), axis=0)\n            for _ in range(np.random.randint(1, 3)):\n                x0 = np.random.randint(0, wd)\n                y0 = np.random.randint(0, ht)\n                dx = np.random.randint(bounds[0], bounds[1])\n                dy = np.random.randint(bounds[0], bounds[1])\n                img2[y0:y0+dy, x0:x0+dx, :] = mean_color\n\n        return img1, img2\n\n    def spatial_transform(self, img1, img2, flow):\n        # randomly sample scale\n        ht, wd = img1.shape[:2]\n        min_scale = np.maximum(\n            (self.crop_size[0] + 8) / float(ht), \n            (self.crop_size[1] + 8) / float(wd))\n\n        scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)\n        scale_x = scale\n        scale_y = scale\n        if np.random.rand() < self.stretch_prob:\n            scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)\n            scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)\n        \n        scale_x = np.clip(scale_x, min_scale, None)\n        scale_y = np.clip(scale_y, min_scale, None)\n\n        if np.random.rand() < self.spatial_aug_prob:\n            # rescale the images\n            img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)\n            img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)\n            flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)\n            flow = flow * [scale_x, scale_y]\n\n        if self.do_flip:\n            if np.random.rand() < self.h_flip_prob: # h-flip\n                img1 = img1[:, ::-1]\n                img2 = img2[:, ::-1]\n            
    flow = flow[:, ::-1] * [-1.0, 1.0]\n\n            if np.random.rand() < self.v_flip_prob: # v-flip\n                img1 = img1[::-1, :]\n                img2 = img2[::-1, :]\n                flow = flow[::-1, :] * [1.0, -1.0]\n\n        y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])\n        x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])\n        \n        img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]\n        img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]\n        flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]\n\n        return img1, img2, flow\n\n    def __call__(self, img1, img2, flow):\n        img1, img2 = self.color_transform(img1, img2)\n        img1, img2 = self.eraser_transform(img1, img2)\n        img1, img2, flow = self.spatial_transform(img1, img2, flow)\n\n        img1 = np.ascontiguousarray(img1)\n        img2 = np.ascontiguousarray(img2)\n        flow = np.ascontiguousarray(flow)\n\n        return img1, img2, flow\n\nclass SparseFlowAugmentor:\n    def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False):\n        # spatial augmentation params\n        self.crop_size = crop_size\n        self.min_scale = min_scale\n        self.max_scale = max_scale\n        self.spatial_aug_prob = 0.8\n        self.stretch_prob = 0.8\n        self.max_stretch = 0.2\n\n        # flip augmentation params\n        self.do_flip = do_flip\n        self.h_flip_prob = 0.5\n        self.v_flip_prob = 0.1\n\n        # photometric augmentation params\n        self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14)\n        self.asymmetric_color_aug_prob = 0.2\n        self.eraser_aug_prob = 0.5\n        \n    def color_transform(self, img1, img2):\n        image_stack = np.concatenate([img1, img2], axis=0)\n        image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)\n        img1, img2 = np.split(image_stack, 2, axis=0)\n        return img1, img2\n\n    def eraser_transform(self, img1, img2):\n        ht, wd = img1.shape[:2]\n        if np.random.rand() < self.eraser_aug_prob:\n            mean_color = np.mean(img2.reshape(-1, 3), axis=0)\n            for _ in range(np.random.randint(1, 3)):\n                x0 = np.random.randint(0, wd)\n                y0 = np.random.randint(0, ht)\n                dx = np.random.randint(50, 100)\n                dy = np.random.randint(50, 100)\n                img2[y0:y0+dy, x0:x0+dx, :] = mean_color\n\n        return img1, img2\n\n    def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):\n        ht, wd = flow.shape[:2]\n        coords = np.meshgrid(np.arange(wd), np.arange(ht))\n        coords = np.stack(coords, axis=-1)\n\n        coords = coords.reshape(-1, 2).astype(np.float32)\n        flow = flow.reshape(-1, 2).astype(np.float32)\n        valid = valid.reshape(-1).astype(np.float32)\n\n        coords0 = coords[valid>=1]\n        flow0 = flow[valid>=1]\n\n        ht1 = int(round(ht * fy))\n        wd1 = int(round(wd * fx))\n\n        coords1 = coords0 * [fx, fy]\n        flow1 = flow0 * [fx, fy]\n\n        xx = np.round(coords1[:,0]).astype(np.int32)\n        yy = np.round(coords1[:,1]).astype(np.int32)\n\n        v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1)\n        xx = xx[v]\n        yy = yy[v]\n        flow1 = flow1[v]\n\n        flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32)\n        valid_img = np.zeros([ht1, wd1], dtype=np.int32)\n\n        flow_img[yy, xx] = flow1\n     
   valid_img[yy, xx] = 1\n\n        return flow_img, valid_img\n\n    def spatial_transform(self, img1, img2, flow, valid):\n        # randomly sample scale\n\n        ht, wd = img1.shape[:2]\n        min_scale = np.maximum(\n            (self.crop_size[0] + 1) / float(ht), \n            (self.crop_size[1] + 1) / float(wd))\n\n        scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)\n        scale_x = np.clip(scale, min_scale, None)\n        scale_y = np.clip(scale, min_scale, None)\n\n        if np.random.rand() < self.spatial_aug_prob:\n            # rescale the images\n            img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)\n            img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)\n            flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y)\n\n        if self.do_flip:\n            if np.random.rand() < 0.5: # h-flip\n                img1 = img1[:, ::-1]\n                img2 = img2[:, ::-1]\n                flow = flow[:, ::-1] * [-1.0, 1.0]\n                valid = valid[:, ::-1]\n\n        margin_y = 20\n        margin_x = 50\n\n        y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y)\n        x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x)\n\n        y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])\n        x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])\n\n        img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]\n        img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]\n        flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]\n        valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]\n        return img1, img2, flow, valid\n\n\n    def __call__(self, img1, img2, flow, valid):\n        img1, img2 = self.color_transform(img1, img2)\n        img1, img2 = self.eraser_transform(img1, img2)\n        img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid)\n\n        img1 = np.ascontiguousarray(img1)\n        img2 = np.ascontiguousarray(img2)\n        flow = np.ascontiguousarray(flow)\n        valid = np.ascontiguousarray(valid)\n\n        return img1, img2, flow, valid\n"
  },
  {
    "path": "utils/RAFT/utils/flow_viz.py",
    "content": "# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization\n\n\n# MIT License\n#\n# Copyright (c) 2018 Tom Runia\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to conditions.\n#\n# Author: Tom Runia\n# Date Created: 2018-08-03\n\nimport numpy as np\n\ndef make_colorwheel():\n    \"\"\"\n    Generates a color wheel for optical flow visualization as presented in:\n        Baker et al. \"A Database and Evaluation Methodology for Optical Flow\" (ICCV, 2007)\n        URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf\n\n    Code follows the original C++ source code of Daniel Scharstein.\n    Code follows the the Matlab source code of Deqing Sun.\n\n    Returns:\n        np.ndarray: Color wheel\n    \"\"\"\n\n    RY = 15\n    YG = 6\n    GC = 4\n    CB = 11\n    BM = 13\n    MR = 6\n\n    ncols = RY + YG + GC + CB + BM + MR\n    colorwheel = np.zeros((ncols, 3))\n    col = 0\n\n    # RY\n    colorwheel[0:RY, 0] = 255\n    colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)\n    col = col+RY\n    # YG\n    colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)\n    colorwheel[col:col+YG, 1] = 255\n    col = col+YG\n    # GC\n    colorwheel[col:col+GC, 1] = 255\n    colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)\n    col = col+GC\n    # CB\n    colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)\n    colorwheel[col:col+CB, 2] = 255\n    col = col+CB\n    # BM\n    colorwheel[col:col+BM, 2] = 255\n    colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)\n    col = col+BM\n    # MR\n    colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)\n    colorwheel[col:col+MR, 0] = 255\n    return colorwheel\n\n\ndef flow_uv_to_colors(u, v, convert_to_bgr=False):\n    \"\"\"\n    Applies the flow color wheel to (possibly clipped) flow components u and v.\n\n    According to the C++ source code of Daniel Scharstein\n    According to the Matlab source code of Deqing Sun\n\n    Args:\n        u (np.ndarray): Input horizontal flow of shape [H,W]\n        v (np.ndarray): Input vertical flow of shape [H,W]\n        convert_to_bgr (bool, optional): Convert output image to BGR. 
Defaults to False.\n\n    Returns:\n        np.ndarray: Flow visualization image of shape [H,W,3]\n    \"\"\"\n    flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)\n    colorwheel = make_colorwheel()  # shape [55x3]\n    ncols = colorwheel.shape[0]\n    rad = np.sqrt(np.square(u) + np.square(v))\n    a = np.arctan2(-v, -u)/np.pi\n    fk = (a+1) / 2*(ncols-1)\n    k0 = np.floor(fk).astype(np.int32)\n    k1 = k0 + 1\n    k1[k1 == ncols] = 0\n    f = fk - k0\n    for i in range(colorwheel.shape[1]):\n        tmp = colorwheel[:,i]\n        col0 = tmp[k0] / 255.0\n        col1 = tmp[k1] / 255.0\n        col = (1-f)*col0 + f*col1\n        idx = (rad <= 1)\n        col[idx]  = 1 - rad[idx] * (1-col[idx])\n        col[~idx] = col[~idx] * 0.75   # out of range\n        # Note the 2-i => BGR instead of RGB\n        ch_idx = 2-i if convert_to_bgr else i\n        flow_image[:,:,ch_idx] = np.floor(255 * col)\n    return flow_image\n\n\ndef flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):\n    \"\"\"\n    Expects a two dimensional flow image of shape.\n\n    Args:\n        flow_uv (np.ndarray): Flow UV image of shape [H,W,2]\n        clip_flow (float, optional): Clip maximum of flow values. Defaults to None.\n        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.\n\n    Returns:\n        np.ndarray: Flow visualization image of shape [H,W,3]\n    \"\"\"\n    assert flow_uv.ndim == 3, 'input flow must have three dimensions'\n    assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'\n    if clip_flow is not None:\n        flow_uv = np.clip(flow_uv, 0, clip_flow)\n    u = flow_uv[:,:,0]\n    v = flow_uv[:,:,1]\n    rad = np.sqrt(np.square(u) + np.square(v))\n    rad_max = np.max(rad)\n    epsilon = 1e-5\n    u = u / (rad_max + epsilon)\n    v = v / (rad_max + epsilon)\n    return flow_uv_to_colors(u, v, convert_to_bgr)"
  },
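  {
    "path": "examples/flow_viz_sketch.py",
    "content": "# Hypothetical example file (not part of the original repo). A minimal sketch\n# that colorizes a synthetic flow field with flow_to_image; it assumes the repo\n# root is on PYTHONPATH so utils.RAFT.utils is importable.\nimport numpy as np\n\nfrom utils.RAFT.utils.flow_viz import flow_to_image\n\nif __name__ == '__main__':\n    h, w = 64, 64\n    u = np.tile(np.linspace(-5.0, 5.0, w, dtype=np.float32), (h, 1))  # horizontal motion\n    v = np.zeros((h, w), dtype=np.float32)                            # no vertical motion\n    flow = np.stack([u, v], axis=-1)                                  # [H, W, 2]\n\n    rgb = flow_to_image(flow)\n    print(rgb.shape, rgb.dtype)   # (64, 64, 3) uint8\n"
  },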
  {
    "path": "utils/RAFT/utils/frame_utils.py",
    "content": "import numpy as np\nfrom PIL import Image\nfrom os.path import *\nimport re\n\nimport cv2\ncv2.setNumThreads(0)\ncv2.ocl.setUseOpenCL(False)\n\nTAG_CHAR = np.array([202021.25], np.float32)\n\ndef readFlow(fn):\n    \"\"\" Read .flo file in Middlebury format\"\"\"\n    # Code adapted from:\n    # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy\n\n    # WARNING: this will work on little-endian architectures (eg Intel x86) only!\n    # print 'fn = %s'%(fn)\n    with open(fn, 'rb') as f:\n        magic = np.fromfile(f, np.float32, count=1)\n        if 202021.25 != magic:\n            print('Magic number incorrect. Invalid .flo file')\n            return None\n        else:\n            w = np.fromfile(f, np.int32, count=1)\n            h = np.fromfile(f, np.int32, count=1)\n            # print 'Reading %d x %d flo file\\n' % (w, h)\n            data = np.fromfile(f, np.float32, count=2*int(w)*int(h))\n            # Reshape data into 3D array (columns, rows, bands)\n            # The reshape here is for visualization, the original code is (w,h,2)\n            return np.resize(data, (int(h), int(w), 2))\n\ndef readPFM(file):\n    file = open(file, 'rb')\n\n    color = None\n    width = None\n    height = None\n    scale = None\n    endian = None\n\n    header = file.readline().rstrip()\n    if header == b'PF':\n        color = True\n    elif header == b'Pf':\n        color = False\n    else:\n        raise Exception('Not a PFM file.')\n\n    dim_match = re.match(rb'^(\\d+)\\s(\\d+)\\s$', file.readline())\n    if dim_match:\n        width, height = map(int, dim_match.groups())\n    else:\n        raise Exception('Malformed PFM header.')\n\n    scale = float(file.readline().rstrip())\n    if scale < 0: # little-endian\n        endian = '<'\n        scale = -scale\n    else:\n        endian = '>' # big-endian\n\n    data = np.fromfile(file, endian + 'f')\n    shape = (height, width, 3) if color else (height, width)\n\n    data = np.reshape(data, shape)\n    data = np.flipud(data)\n    return data\n\ndef writeFlow(filename,uv,v=None):\n    \"\"\" Write optical flow to file.\n    \n    If v is None, uv is assumed to contain both u and v channels,\n    stacked in depth.\n    Original code by Deqing Sun, adapted from Daniel Scharstein.\n    \"\"\"\n    nBands = 2\n\n    if v is None:\n        assert(uv.ndim == 3)\n        assert(uv.shape[2] == 2)\n        u = uv[:,:,0]\n        v = uv[:,:,1]\n    else:\n        u = uv\n\n    assert(u.shape == v.shape)\n    height,width = u.shape\n    f = open(filename,'wb')\n    # write the header\n    f.write(TAG_CHAR)\n    np.array(width).astype(np.int32).tofile(f)\n    np.array(height).astype(np.int32).tofile(f)\n    # arrange into matrix form\n    tmp = np.zeros((height, width*nBands))\n    tmp[:,np.arange(width)*2] = u\n    tmp[:,np.arange(width)*2 + 1] = v\n    tmp.astype(np.float32).tofile(f)\n    f.close()\n\n\ndef readFlowKITTI(filename):\n    flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR)\n    flow = flow[:,:,::-1].astype(np.float32)\n    flow, valid = flow[:, :, :2], flow[:, :, 2]\n    flow = (flow - 2**15) / 64.0\n    return flow, valid\n\ndef readDispKITTI(filename):\n    disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0\n    valid = disp > 0.0\n    flow = np.stack([-disp, np.zeros_like(disp)], -1)\n    return flow, valid\n\n\ndef writeFlowKITTI(filename, uv):\n    uv = 64.0 * uv + 2**15\n    valid = np.ones([uv.shape[0], uv.shape[1], 1])\n    uv = 
np.concatenate([uv, valid], axis=-1).astype(np.uint16)\n    cv2.imwrite(filename, uv[..., ::-1])\n    \n\ndef read_gen(file_name, pil=False):\n    ext = splitext(file_name)[-1]\n    if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':\n        return Image.open(file_name)\n    elif ext == '.bin' or ext == '.raw':\n        return np.load(file_name)\n    elif ext == '.flo':\n        return readFlow(file_name).astype(np.float32)\n    elif ext == '.pfm':\n        flow = readPFM(file_name).astype(np.float32)\n        if len(flow.shape) == 2:\n            return flow\n        else:\n            return flow[:, :, :-1]\n    return []"
  },
  {
    "path": "utils/RAFT/utils/utils.py",
    "content": "import torch\nimport torch.nn.functional as F\nimport numpy as np\nfrom scipy import interpolate\n\n\nclass InputPadder:\n    \"\"\" Pads images such that dimensions are divisible by 8 \"\"\"\n    def __init__(self, dims, mode='sintel'):\n        self.ht, self.wd = dims[-2:]\n        pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8\n        pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8\n        if mode == 'sintel':\n            self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]\n        else:\n            self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]\n\n    def pad(self, *inputs):\n        return [F.pad(x, self._pad, mode='replicate') for x in inputs]\n\n    def unpad(self,x):\n        ht, wd = x.shape[-2:]\n        c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]\n        return x[..., c[0]:c[1], c[2]:c[3]]\n\ndef forward_interpolate(flow):\n    flow = flow.detach().cpu().numpy()\n    dx, dy = flow[0], flow[1]\n\n    ht, wd = dx.shape\n    x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))\n\n    x1 = x0 + dx\n    y1 = y0 + dy\n    \n    x1 = x1.reshape(-1)\n    y1 = y1.reshape(-1)\n    dx = dx.reshape(-1)\n    dy = dy.reshape(-1)\n\n    valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)\n    x1 = x1[valid]\n    y1 = y1[valid]\n    dx = dx[valid]\n    dy = dy[valid]\n\n    flow_x = interpolate.griddata(\n        (x1, y1), dx, (x0, y0), method='nearest', fill_value=0)\n\n    flow_y = interpolate.griddata(\n        (x1, y1), dy, (x0, y0), method='nearest', fill_value=0)\n\n    flow = np.stack([flow_x, flow_y], axis=0)\n    return torch.from_numpy(flow).float()\n\n\ndef bilinear_sampler(img, coords, mode='bilinear', mask=False):\n    \"\"\" Wrapper for grid_sample, uses pixel coordinates \"\"\"\n    H, W = img.shape[-2:]\n    xgrid, ygrid = coords.split([1,1], dim=-1)\n    xgrid = 2*xgrid/(W-1) - 1\n    ygrid = 2*ygrid/(H-1) - 1\n\n    grid = torch.cat([xgrid, ygrid], dim=-1)\n    img = F.grid_sample(img, grid, align_corners=True)\n\n    if mask:\n        mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)\n        return img, mask.float()\n\n    return img\n\n\ndef coords_grid(batch, ht, wd):\n    coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))\n    coords = torch.stack(coords[::-1], dim=0).float()\n    return coords[None].repeat(batch, 1, 1, 1)\n\n\ndef upflow8(flow, mode='bilinear'):\n    new_size = (8 * flow.shape[2], 8 * flow.shape[3])\n    return  8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)\n"
  },
  {
    "path": "utils/colmap_utils.py",
    "content": "# Copyright (c) 2018, ETH Zurich and UNC Chapel Hill.\n# All rights reserved.\n#\n# Redistribution and use in source and binary forms, with or without\n# modification, are permitted provided that the following conditions are met:\n#\n#     * Redistributions of source code must retain the above copyright\n#       notice, this list of conditions and the following disclaimer.\n#\n#     * Redistributions in binary form must reproduce the above copyright\n#       notice, this list of conditions and the following disclaimer in the\n#       documentation and/or other materials provided with the distribution.\n#\n#     * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of\n#       its contributors may be used to endorse or promote products derived\n#       from this software without specific prior written permission.\n#\n# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE\n# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE\n# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR\n# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF\n# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS\n# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN\n# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)\n# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE\n# POSSIBILITY OF SUCH DAMAGE.\n#\n# Author: Johannes L. Schoenberger (jsch at inf.ethz.ch)\n\nimport os\nimport sys\nimport collections\nimport numpy as np\nimport struct\n\n\nCameraModel = collections.namedtuple(\n    \"CameraModel\", [\"model_id\", \"model_name\", \"num_params\"])\nCamera = collections.namedtuple(\n    \"Camera\", [\"id\", \"model\", \"width\", \"height\", \"params\"])\nBaseImage = collections.namedtuple(\n    \"Image\", [\"id\", \"qvec\", \"tvec\", \"camera_id\", \"name\", \"xys\", \"point3D_ids\"])\nPoint3D = collections.namedtuple(\n    \"Point3D\", [\"id\", \"xyz\", \"rgb\", \"error\", \"image_ids\", \"point2D_idxs\"])\n\nclass Image(BaseImage):\n    def qvec2rotmat(self):\n        return qvec2rotmat(self.qvec)\n\n\nCAMERA_MODELS = {\n    CameraModel(model_id=0, model_name=\"SIMPLE_PINHOLE\", num_params=3),\n    CameraModel(model_id=1, model_name=\"PINHOLE\", num_params=4),\n    CameraModel(model_id=2, model_name=\"SIMPLE_RADIAL\", num_params=4),\n    CameraModel(model_id=3, model_name=\"RADIAL\", num_params=5),\n    CameraModel(model_id=4, model_name=\"OPENCV\", num_params=8),\n    CameraModel(model_id=5, model_name=\"OPENCV_FISHEYE\", num_params=8),\n    CameraModel(model_id=6, model_name=\"FULL_OPENCV\", num_params=12),\n    CameraModel(model_id=7, model_name=\"FOV\", num_params=5),\n    CameraModel(model_id=8, model_name=\"SIMPLE_RADIAL_FISHEYE\", num_params=4),\n    CameraModel(model_id=9, model_name=\"RADIAL_FISHEYE\", num_params=5),\n    CameraModel(model_id=10, model_name=\"THIN_PRISM_FISHEYE\", num_params=12)\n}\nCAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model) \\\n                         for camera_model in CAMERA_MODELS])\n\n\ndef read_next_bytes(fid, num_bytes, format_char_sequence, endian_character=\"<\"):\n    \"\"\"Read and unpack the next bytes from a binary file.\n    :param fid:\n    :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 
2, 6, 16, 30, etc.\n    :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.\n    :param endian_character: Any of {@, =, <, >, !}\n    :return: Tuple of read and unpacked values.\n    \"\"\"\n    data = fid.read(num_bytes)\n    return struct.unpack(endian_character + format_char_sequence, data)\n\n\ndef read_cameras_text(path):\n    \"\"\"\n    see: src/base/reconstruction.cc\n        void Reconstruction::WriteCamerasText(const std::string& path)\n        void Reconstruction::ReadCamerasText(const std::string& path)\n    \"\"\"\n    cameras = {}\n    with open(path, \"r\") as fid:\n        while True:\n            line = fid.readline()\n            if not line:\n                break\n            line = line.strip()\n            if len(line) > 0 and line[0] != \"#\":\n                elems = line.split()\n                camera_id = int(elems[0])\n                model = elems[1]\n                width = int(elems[2])\n                height = int(elems[3])\n                params = np.array(tuple(map(float, elems[4:])))\n                cameras[camera_id] = Camera(id=camera_id, model=model,\n                                            width=width, height=height,\n                                            params=params)\n    return cameras\n\n\ndef read_cameras_binary(path_to_model_file):\n    \"\"\"\n    see: src/base/reconstruction.cc\n        void Reconstruction::WriteCamerasBinary(const std::string& path)\n        void Reconstruction::ReadCamerasBinary(const std::string& path)\n    \"\"\"\n    cameras = {}\n    with open(path_to_model_file, \"rb\") as fid:\n        num_cameras = read_next_bytes(fid, 8, \"Q\")[0]\n        for camera_line_index in range(num_cameras):\n            camera_properties = read_next_bytes(\n                fid, num_bytes=24, format_char_sequence=\"iiQQ\")\n            camera_id = camera_properties[0]\n            model_id = camera_properties[1]\n            model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name\n            width = camera_properties[2]\n            height = camera_properties[3]\n            num_params = CAMERA_MODEL_IDS[model_id].num_params\n            params = read_next_bytes(fid, num_bytes=8*num_params,\n                                     format_char_sequence=\"d\"*num_params)\n            cameras[camera_id] = Camera(id=camera_id,\n                                        model=model_name,\n                                        width=width,\n                                        height=height,\n                                        params=np.array(params))\n        assert len(cameras) == num_cameras\n    return cameras\n\n\ndef read_images_text(path):\n    \"\"\"\n    see: src/base/reconstruction.cc\n        void Reconstruction::ReadImagesText(const std::string& path)\n        void Reconstruction::WriteImagesText(const std::string& path)\n    \"\"\"\n    images = {}\n    with open(path, \"r\") as fid:\n        while True:\n            line = fid.readline()\n            if not line:\n                break\n            line = line.strip()\n            if len(line) > 0 and line[0] != \"#\":\n                elems = line.split()\n                image_id = int(elems[0])\n                qvec = np.array(tuple(map(float, elems[1:5])))\n                tvec = np.array(tuple(map(float, elems[5:8])))\n                camera_id = int(elems[8])\n                image_name = elems[9]\n                elems = fid.readline().split()\n                xys = np.column_stack([tuple(map(float, elems[0::3])),\n             
                          tuple(map(float, elems[1::3]))])\n                point3D_ids = np.array(tuple(map(int, elems[2::3])))\n                images[image_id] = Image(\n                    id=image_id, qvec=qvec, tvec=tvec,\n                    camera_id=camera_id, name=image_name,\n                    xys=xys, point3D_ids=point3D_ids)\n    return images\n\n\ndef read_images_binary(path_to_model_file):\n    \"\"\"\n    see: src/base/reconstruction.cc\n        void Reconstruction::ReadImagesBinary(const std::string& path)\n        void Reconstruction::WriteImagesBinary(const std::string& path)\n    \"\"\"\n    images = {}\n    with open(path_to_model_file, \"rb\") as fid:\n        num_reg_images = read_next_bytes(fid, 8, \"Q\")[0]\n        for image_index in range(num_reg_images):\n            binary_image_properties = read_next_bytes(\n                fid, num_bytes=64, format_char_sequence=\"idddddddi\")\n            image_id = binary_image_properties[0]\n            qvec = np.array(binary_image_properties[1:5])\n            tvec = np.array(binary_image_properties[5:8])\n            camera_id = binary_image_properties[8]\n            image_name = \"\"\n            current_char = read_next_bytes(fid, 1, \"c\")[0]\n            while current_char != b\"\\x00\":   # look for the ASCII 0 entry\n                image_name += current_char.decode(\"utf-8\")\n                current_char = read_next_bytes(fid, 1, \"c\")[0]\n            num_points2D = read_next_bytes(fid, num_bytes=8,\n                                           format_char_sequence=\"Q\")[0]\n            x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D,\n                                       format_char_sequence=\"ddq\"*num_points2D)\n            xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),\n                                   tuple(map(float, x_y_id_s[1::3]))])\n            point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))\n            images[image_id] = Image(\n                id=image_id, qvec=qvec, tvec=tvec,\n                camera_id=camera_id, name=image_name,\n                xys=xys, point3D_ids=point3D_ids)\n    return images\n\n\ndef read_points3D_text(path):\n    \"\"\"\n    see: src/base/reconstruction.cc\n        void Reconstruction::ReadPoints3DText(const std::string& path)\n        void Reconstruction::WritePoints3DText(const std::string& path)\n    \"\"\"\n    points3D = {}\n    with open(path, \"r\") as fid:\n        while True:\n            line = fid.readline()\n            if not line:\n                break\n            line = line.strip()\n            if len(line) > 0 and line[0] != \"#\":\n                elems = line.split()\n                point3D_id = int(elems[0])\n                xyz = np.array(tuple(map(float, elems[1:4])))\n                rgb = np.array(tuple(map(int, elems[4:7])))\n                error = float(elems[7])\n                image_ids = np.array(tuple(map(int, elems[8::2])))\n                point2D_idxs = np.array(tuple(map(int, elems[9::2])))\n                points3D[point3D_id] = Point3D(id=point3D_id, xyz=xyz, rgb=rgb,\n                                               error=error, image_ids=image_ids,\n                                               point2D_idxs=point2D_idxs)\n    return points3D\n\n\ndef read_points3d_binary(path_to_model_file):\n    \"\"\"\n    see: src/base/reconstruction.cc\n        void Reconstruction::ReadPoints3DBinary(const std::string& path)\n        void Reconstruction::WritePoints3DBinary(const std::string& path)\n    
\"\"\"\n    points3D = {}\n    with open(path_to_model_file, \"rb\") as fid:\n        num_points = read_next_bytes(fid, 8, \"Q\")[0]\n        for point_line_index in range(num_points):\n            binary_point_line_properties = read_next_bytes(\n                fid, num_bytes=43, format_char_sequence=\"QdddBBBd\")\n            point3D_id = binary_point_line_properties[0]\n            xyz = np.array(binary_point_line_properties[1:4])\n            rgb = np.array(binary_point_line_properties[4:7])\n            error = np.array(binary_point_line_properties[7])\n            track_length = read_next_bytes(\n                fid, num_bytes=8, format_char_sequence=\"Q\")[0]\n            track_elems = read_next_bytes(\n                fid, num_bytes=8*track_length,\n                format_char_sequence=\"ii\"*track_length)\n            image_ids = np.array(tuple(map(int, track_elems[0::2])))\n            point2D_idxs = np.array(tuple(map(int, track_elems[1::2])))\n            points3D[point3D_id] = Point3D(\n                id=point3D_id, xyz=xyz, rgb=rgb,\n                error=error, image_ids=image_ids,\n                point2D_idxs=point2D_idxs)\n    return points3D\n\n\ndef read_model(path, ext):\n    if ext == \".txt\":\n        cameras = read_cameras_text(os.path.join(path, \"cameras\" + ext))\n        images = read_images_text(os.path.join(path, \"images\" + ext))\n        points3D = read_points3D_text(os.path.join(path, \"points3D\") + ext)\n    else:\n        cameras = read_cameras_binary(os.path.join(path, \"cameras\" + ext))\n        images = read_images_binary(os.path.join(path, \"images\" + ext))\n        points3D = read_points3d_binary(os.path.join(path, \"points3D\") + ext)\n    return cameras, images, points3D\n\n\ndef qvec2rotmat(qvec):\n    return np.array([\n        [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2,\n         2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],\n         2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]],\n        [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],\n         1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,\n         2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]],\n        [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],\n         2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],\n         1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]])\n\n\ndef rotmat2qvec(R):\n    Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat\n    K = np.array([\n        [Rxx - Ryy - Rzz, 0, 0, 0],\n        [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],\n        [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],\n        [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0\n    eigvals, eigvecs = np.linalg.eigh(K)\n    qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]\n    if qvec[0] < 0:\n        qvec *= -1\n    return qvec\n\n\ndef main():\n    if len(sys.argv) != 3:\n        print(\"Usage: python read_model.py path/to/model/folder [.txt,.bin]\")\n        return\n\n    cameras, images, points3D = read_model(path=sys.argv[1], ext=sys.argv[2])\n\n    print(\"num_cameras:\", len(cameras))\n    print(\"num_images:\", len(images))\n    print(\"num_points3D:\", len(points3D))\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "utils/evaluation.py",
    "content": "import os\nimport cv2\nimport lpips\nimport torch\nimport numpy as np\nfrom skimage.metrics import structural_similarity\n\n\ndef im2tensor(img):\n    return torch.Tensor(img.transpose(2, 0, 1) / 127.5 - 1.0)[None, ...]\n\n\ndef create_dir(dir):\n    if not os.path.exists(dir):\n        os.makedirs(dir)\n\n\ndef readimage(data_dir, sequence, time, method):\n    img = cv2.imread(os.path.join(data_dir, method, sequence, 'v000_t' + str(time).zfill(3) + '.png'))\n    return img\n\n\ndef calculate_metrics(data_dir, sequence, methods, lpips_loss):\n\n    PSNRs = np.zeros((len(methods)))\n    SSIMs = np.zeros((len(methods)))\n    LPIPSs = np.zeros((len(methods)))\n\n    nFrame = 0\n\n    # Yoon's results do not include v000_t000 and v000_t011. Omit these two\n    # frames if evaluating Yoon's method.\n    if 'Yoon' in methods:\n        time_start = 1\n        time_end = 11\n    else:\n        time_start = 0\n        time_end = 12\n\n    for time in range(time_start, time_end): # Fix view v0, change time\n\n        nFrame += 1\n\n        img_true = readimage(data_dir, sequence, time, 'gt')\n\n        for method_idx, method in enumerate(methods):\n\n            if 'Yoon' in methods and sequence == 'Truck' and time == 10:\n                break\n\n            img = readimage(data_dir, sequence, time, method)\n            PSNR = cv2.PSNR(img_true, img)\n            SSIM = structural_similarity(img_true, img, multichannel=True)\n            LPIPS = lpips_loss.forward(im2tensor(img_true), im2tensor(img)).item()\n\n            PSNRs[method_idx] += PSNR\n            SSIMs[method_idx] += SSIM\n            LPIPSs[method_idx] += LPIPS\n\n    PSNRs = PSNRs / nFrame\n    SSIMs = SSIMs / nFrame\n    LPIPSs = LPIPSs / nFrame\n\n    return PSNRs, SSIMs, LPIPSs\n\n\nif __name__ == '__main__':\n\n    lpips_loss = lpips.LPIPS(net='alex') # best forward scores\n    data_dir = '../results'\n    sequences = ['Balloon1', 'Balloon2', 'Jumping', 'Playground', 'Skating', 'Truck', 'Umbrella']\n    # methods = ['NeRF', 'NeRF_t', 'Yoon', 'NR', 'NSFF', 'Ours']\n    methods = ['NeRF', 'NeRF_t', 'NR', 'NSFF', 'Ours']\n\n    PSNRs_total = np.zeros((len(methods)))\n    SSIMs_total = np.zeros((len(methods)))\n    LPIPSs_total = np.zeros((len(methods)))\n    for sequence in sequences:\n        print(sequence)\n        PSNRs, SSIMs, LPIPSs = calculate_metrics(data_dir, sequence, methods, lpips_loss)\n        for method_idx, method in enumerate(methods):\n            print(method.ljust(7) + '%.2f'%(PSNRs[method_idx]) + ' / %.4f'%(SSIMs[method_idx]) + ' / %.3f'%(LPIPSs[method_idx]))\n\n        PSNRs_total += PSNRs\n        SSIMs_total += SSIMs\n        LPIPSs_total += LPIPSs\n\n    PSNRs_total = PSNRs_total / len(sequences)\n    SSIMs_total = SSIMs_total / len(sequences)\n    LPIPSs_total = LPIPSs_total / len(sequences)\n    print('Avg.')\n    for method_idx, method in enumerate(methods):\n        print(method.ljust(7) + '%.2f'%(PSNRs_total[method_idx]) + ' / %.4f'%(SSIMs_total[method_idx]) + ' / %.3f'%(LPIPSs_total[method_idx]))\n"
  },
  {
    "path": "utils/flow_utils.py",
    "content": "import os\nimport cv2\nimport numpy as np\nfrom PIL import Image\nfrom os.path import *\nUNKNOWN_FLOW_THRESH = 1e7\n\ndef flow_to_image(flow, global_max=None):\n    \"\"\"\n    Convert flow into middlebury color code image\n    :param flow: optical flow map\n    :return: optical flow image in middlebury color\n    \"\"\"\n    u = flow[:, :, 0]\n    v = flow[:, :, 1]\n\n    maxu = -999.\n    maxv = -999.\n    minu = 999.\n    minv = 999.\n\n    idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH)\n    u[idxUnknow] = 0\n    v[idxUnknow] = 0\n\n    maxu = max(maxu, np.max(u))\n    minu = min(minu, np.min(u))\n\n    maxv = max(maxv, np.max(v))\n    minv = min(minv, np.min(v))\n\n    rad = np.sqrt(u ** 2 + v ** 2)\n\n    if global_max == None:\n        maxrad = max(-1, np.max(rad))\n    else:\n        maxrad = global_max\n\n    u = u/(maxrad + np.finfo(float).eps)\n    v = v/(maxrad + np.finfo(float).eps)\n\n    img = compute_color(u, v)\n\n    idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2)\n    img[idx] = 0\n\n    return np.uint8(img)\n\n\ndef compute_color(u, v):\n    \"\"\"\n    compute optical flow color map\n    :param u: optical flow horizontal map\n    :param v: optical flow vertical map\n    :return: optical flow in color code\n    \"\"\"\n    [h, w] = u.shape\n    img = np.zeros([h, w, 3])\n    nanIdx = np.isnan(u) | np.isnan(v)\n    u[nanIdx] = 0\n    v[nanIdx] = 0\n\n    colorwheel = make_color_wheel()\n    ncols = np.size(colorwheel, 0)\n\n    rad = np.sqrt(u**2+v**2)\n\n    a = np.arctan2(-v, -u) / np.pi\n\n    fk = (a+1) / 2 * (ncols - 1) + 1\n\n    k0 = np.floor(fk).astype(int)\n\n    k1 = k0 + 1\n    k1[k1 == ncols+1] = 1\n    f = fk - k0\n\n    for i in range(0, np.size(colorwheel,1)):\n        tmp = colorwheel[:, i]\n        col0 = tmp[k0-1] / 255\n        col1 = tmp[k1-1] / 255\n        col = (1-f) * col0 + f * col1\n\n        idx = rad <= 1\n        col[idx] = 1-rad[idx]*(1-col[idx])\n        notidx = np.logical_not(idx)\n\n        col[notidx] *= 0.75\n        img[:, :, i] = np.uint8(np.floor(255 * col*(1-nanIdx)))\n\n    return img\n\n\ndef make_color_wheel():\n    \"\"\"\n    Generate color wheel according Middlebury color code\n    :return: Color wheel\n    \"\"\"\n    RY = 15\n    YG = 6\n    GC = 4\n    CB = 11\n    BM = 13\n    MR = 6\n\n    ncols = RY + YG + GC + CB + BM + MR\n\n    colorwheel = np.zeros([ncols, 3])\n\n    col = 0\n\n    # RY\n    colorwheel[0:RY, 0] = 255\n    colorwheel[0:RY, 1] = np.transpose(np.floor(255*np.arange(0, RY) / RY))\n    col += RY\n\n    # YG\n    colorwheel[col:col+YG, 0] = 255 - np.transpose(np.floor(255*np.arange(0, YG) / YG))\n    colorwheel[col:col+YG, 1] = 255\n    col += YG\n\n    # GC\n    colorwheel[col:col+GC, 1] = 255\n    colorwheel[col:col+GC, 2] = np.transpose(np.floor(255*np.arange(0, GC) / GC))\n    col += GC\n\n    # CB\n    colorwheel[col:col+CB, 1] = 255 - np.transpose(np.floor(255*np.arange(0, CB) / CB))\n    colorwheel[col:col+CB, 2] = 255\n    col += CB\n\n    # BM\n    colorwheel[col:col+BM, 2] = 255\n    colorwheel[col:col+BM, 0] = np.transpose(np.floor(255*np.arange(0, BM) / BM))\n    col += + BM\n\n    # MR\n    colorwheel[col:col+MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR))\n    colorwheel[col:col+MR, 0] = 255\n\n    return colorwheel\n\n\ndef resize_flow(flow, H_new, W_new):\n    H_old, W_old = flow.shape[0:2]\n    flow_resized = cv2.resize(flow, (W_new, H_new), interpolation=cv2.INTER_LINEAR)\n    flow_resized[:, :, 0] *= H_new / H_old\n  
  flow_resized[:, :, 1] *= W_new / W_old\n    return flow_resized\n\n\n\ndef warp_flow(img, flow):\n    h, w = flow.shape[:2]\n    flow_new = flow.copy()\n    flow_new[:,:,0] += np.arange(w)\n    flow_new[:,:,1] += np.arange(h)[:,np.newaxis]\n\n    res = cv2.remap(img, flow_new, None,\n                    cv2.INTER_CUBIC,\n                    borderMode=cv2.BORDER_CONSTANT)\n    return res\n\n\ndef consistCheck(flowB, flowF):\n\n    # |--------------------|  |--------------------|\n    # |       y            |  |       v            |\n    # |   x   *            |  |   u   *            |\n    # |                    |  |                    |\n    # |--------------------|  |--------------------|\n\n    # sub: numPix * [y x t]\n\n    imgH, imgW, _ = flowF.shape\n\n    (fy, fx) = np.mgrid[0 : imgH, 0 : imgW].astype(np.float32)\n    fxx = fx + flowB[:, :, 0]  # horizontal\n    fyy = fy + flowB[:, :, 1]  # vertical\n\n    u = (fxx + cv2.remap(flowF[:, :, 0], fxx, fyy, cv2.INTER_LINEAR) - fx)\n    v = (fyy + cv2.remap(flowF[:, :, 1], fxx, fyy, cv2.INTER_LINEAR) - fy)\n    BFdiff = (u ** 2 + v ** 2) ** 0.5\n\n    return BFdiff, np.stack((u, v), axis=2)\n\n\ndef read_optical_flow(basedir, img_i_name, read_fwd):\n    flow_dir = os.path.join(basedir, 'flow')\n\n    fwd_flow_path = os.path.join(flow_dir, '%s_fwd.npz'%img_i_name[:-4])\n    bwd_flow_path = os.path.join(flow_dir, '%s_bwd.npz'%img_i_name[:-4])\n\n    if read_fwd:\n      fwd_data = np.load(fwd_flow_path)\n      fwd_flow, fwd_mask = fwd_data['flow'], fwd_data['mask']\n      return fwd_flow, fwd_mask\n    else:\n      bwd_data = np.load(bwd_flow_path)\n      bwd_flow, bwd_mask = bwd_data['flow'], bwd_data['mask']\n      return bwd_flow, bwd_mask\n\n\ndef compute_epipolar_distance(T_21, K, p_1, p_2):\n    R_21 = T_21[:3, :3]\n    t_21 = T_21[:3, 3]\n\n    E_mat = np.dot(skew(t_21), R_21)\n    # compute bearing vector\n    inv_K = np.linalg.inv(K)\n\n    F_mat = np.dot(np.dot(inv_K.T, E_mat), inv_K)\n\n    l_2 = np.dot(F_mat, p_1)\n    algebric_e_distance = np.sum(p_2 * l_2, axis=0)\n    n_term = np.sqrt(l_2[0, :]**2 + l_2[1, :]**2) + 1e-8\n    geometric_e_distance = algebric_e_distance/n_term\n    geometric_e_distance = np.abs(geometric_e_distance)\n\n    return geometric_e_distance\n\n\ndef skew(x):\n    return np.array([[0, -x[2], x[1]],\n                     [x[2], 0, -x[0]],\n                     [-x[1], x[0], 0]])\n"
  },
  {
    "path": "utils/generate_data.py",
    "content": "import os\nimport numpy as np\nimport imageio\nimport glob\nimport torch\nimport torchvision\nimport skimage.morphology\nimport argparse\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n\ndef create_dir(dir):\n    if not os.path.exists(dir):\n        os.makedirs(dir)\n\n\ndef multi_view_multi_time(args):\n    \"\"\"\n    Generating multi view multi time data\n    \"\"\"\n\n    Maskrcnn = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True).cuda().eval()\n    threshold = 0.5\n\n    videoname, ext = os.path.splitext(os.path.basename(args.videopath))\n\n    imgs = []\n    reader = imageio.get_reader(args.videopath)\n    for i, im in enumerate(reader):\n        imgs.append(im)\n\n    imgs = np.array(imgs)\n    num_frames, H, W, _ = imgs.shape\n    imgs = imgs[::int(np.ceil(num_frames / 100))]\n\n    create_dir(os.path.join(args.data_dir, videoname, 'images'))\n    create_dir(os.path.join(args.data_dir, videoname, 'images_colmap'))\n    create_dir(os.path.join(args.data_dir, videoname, 'background_mask'))\n\n    for idx, img in enumerate(imgs):\n        print(idx)\n        imageio.imwrite(os.path.join(args.data_dir, videoname, 'images', str(idx).zfill(3) + '.png'), img)\n        imageio.imwrite(os.path.join(args.data_dir, videoname, 'images_colmap', str(idx).zfill(3) + '.jpg'), img)\n\n        # Get coarse background mask\n        img = torchvision.transforms.functional.to_tensor(img).to(device)\n        background_mask = torch.FloatTensor(H, W).fill_(1.0).to(device)\n        objPredictions = Maskrcnn([img])[0]\n\n        for intMask in range(len(objPredictions['masks'])):\n            if objPredictions['scores'][intMask].item() > threshold:\n                if objPredictions['labels'][intMask].item() == 1: # person\n                    background_mask[objPredictions['masks'][intMask, 0, :, :] > threshold] = 0.0\n\n        background_mask_np = ((background_mask.cpu().numpy() > 0.1) * 255).astype(np.uint8)\n        imageio.imwrite(os.path.join(args.data_dir, videoname, 'background_mask', str(idx).zfill(3) + '.jpg.png'), background_mask_np)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--videopath\", type=str,\n                        help='video path')\n    parser.add_argument(\"--data_dir\", type=str, default='../data/',\n                        help='where to store data')\n\n    args = parser.parse_args()\n\n    multi_view_multi_time(args)\n"
  },
  {
    "path": "utils/generate_depth.py",
    "content": "\"\"\"Compute depth maps for images in the input folder.\n\"\"\"\nimport os\nimport cv2\nimport glob\nimport torch\nimport argparse\nimport numpy as np\n\nfrom torchvision.transforms import Compose\nfrom midas.midas_net import MidasNet\nfrom midas.transforms import Resize, NormalizeImage, PrepareForNet\n\n\ndef create_dir(dir):\n    if not os.path.exists(dir):\n        os.makedirs(dir)\n\n\ndef read_image(path):\n    \"\"\"Read image and output RGB image (0-1).\n\n    Args:\n        path (str): path to file\n\n    Returns:\n        array: RGB image (0-1)\n    \"\"\"\n    img = cv2.imread(path)\n\n    if img.ndim == 2:\n        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)\n\n    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0\n\n    return img\n\n\ndef run(input_path, output_path, output_img_path, model_path):\n    \"\"\"Run MonoDepthNN to compute depth maps.\n    Args:\n        input_path (str): path to input folder\n        output_path (str): path to output folder\n        model_path (str): path to saved model\n    \"\"\"\n    print(\"initialize\")\n\n    # select device\n    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n    print(\"device: %s\" % device)\n\n    # load network\n    model = MidasNet(model_path, non_negative=True)\n    sh = cv2.imread(sorted(glob.glob(os.path.join(input_path, \"*\")))[0]).shape\n    net_w, net_h = sh[1], sh[0]\n\n    resize_mode=\"upper_bound\"\n\n    transform = Compose(\n        [\n            Resize(\n                net_w,\n                net_h,\n                resize_target=None,\n                keep_aspect_ratio=True,\n                ensure_multiple_of=32,\n                resize_method=resize_mode,\n                image_interpolation_method=cv2.INTER_CUBIC,\n            ),\n            NormalizeImage(mean=[0.485, 0.456, 0.406],\n                        std=[0.229, 0.224, 0.225]),\n            PrepareForNet(),\n        ]\n    )\n\n    model.eval()\n    model.to(device)\n\n    # get input\n    img_names = sorted(glob.glob(os.path.join(input_path, \"*\")))\n    num_images = len(img_names)\n\n    # create output folder\n    os.makedirs(output_path, exist_ok=True)\n\n    print(\"start processing\")\n\n    for ind, img_name in enumerate(img_names):\n\n        print(\"  processing {} ({}/{})\".format(img_name, ind + 1, num_images))\n\n        # input\n        img = read_image(img_name)\n        img_input = transform({\"image\": img})[\"image\"]\n\n        # compute\n        with torch.no_grad():\n            sample = torch.from_numpy(img_input).to(device).unsqueeze(0)\n            prediction = model.forward(sample)\n            prediction = (\n                torch.nn.functional.interpolate(\n                    prediction.unsqueeze(1),\n                    size=[net_h, net_w],\n                    mode=\"bicubic\",\n                    align_corners=False,\n                )\n                .squeeze()\n                .cpu()\n                .numpy()\n            )\n\n        # output\n        filename = os.path.join(\n            output_path, os.path.splitext(os.path.basename(img_name))[0]\n        )\n\n        print(filename + '.npy')\n        np.save(filename + '.npy', prediction.astype(np.float32))\n\n        depth_min = prediction.min()\n        depth_max = prediction.max()\n\n        max_val = (2**(8*2))-1\n\n        if depth_max - depth_min > np.finfo(\"float\").eps:\n            out = max_val * (prediction - depth_min) / (depth_max - depth_min)\n        else:\n            out = 
np.zeros(prediction.shape, dtype=prediction.dtype)\n\n        cv2.imwrite(os.path.join(output_img_path, os.path.splitext(os.path.basename(img_name))[0] + '.png'), out.astype(\"uint16\"))\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--dataset_path\", type=str, help='Dataset path')\n    parser.add_argument('--model', help=\"restore midas checkpoint\")\n    args = parser.parse_args()\n\n    input_path = os.path.join(args.dataset_path, 'images')\n    output_path = os.path.join(args.dataset_path, 'disp')\n    output_img_path = os.path.join(args.dataset_path, 'disp_png')\n    create_dir(output_path)\n    create_dir(output_img_path)\n\n    # set torch options\n    torch.backends.cudnn.enabled = True\n    torch.backends.cudnn.benchmark = True\n\n    # compute depth maps\n    run(input_path, output_path, output_img_path, args.model)\n"
  },
  {
    "path": "utils/generate_flow.py",
    "content": "import argparse\nimport os\nimport cv2\nimport glob\nimport numpy as np\nimport torch\nfrom PIL import Image\n\nfrom RAFT.raft import RAFT\nfrom RAFT.utils import flow_viz\nfrom RAFT.utils.utils import InputPadder\n\nfrom flow_utils import *\n\nDEVICE = 'cuda'\n\n\ndef create_dir(dir):\n    if not os.path.exists(dir):\n        os.makedirs(dir)\n\n\ndef load_image(imfile):\n    img = np.array(Image.open(imfile)).astype(np.uint8)\n    img = torch.from_numpy(img).permute(2, 0, 1).float()\n    return img[None].to(DEVICE)\n\n\ndef warp_flow(img, flow):\n    h, w = flow.shape[:2]\n    flow_new = flow.copy()\n    flow_new[:,:,0] += np.arange(w)\n    flow_new[:,:,1] += np.arange(h)[:,np.newaxis]\n\n    res = cv2.remap(img, flow_new, None, cv2.INTER_CUBIC, borderMode=cv2.BORDER_CONSTANT)\n    return res\n\n\ndef compute_fwdbwd_mask(fwd_flow, bwd_flow):\n    alpha_1 = 0.5\n    alpha_2 = 0.5\n\n    bwd2fwd_flow = warp_flow(bwd_flow, fwd_flow)\n    fwd_lr_error = np.linalg.norm(fwd_flow + bwd2fwd_flow, axis=-1)\n    fwd_mask = fwd_lr_error < alpha_1  * (np.linalg.norm(fwd_flow, axis=-1) \\\n                + np.linalg.norm(bwd2fwd_flow, axis=-1)) + alpha_2\n\n    fwd2bwd_flow = warp_flow(fwd_flow, bwd_flow)\n    bwd_lr_error = np.linalg.norm(bwd_flow + fwd2bwd_flow, axis=-1)\n\n    bwd_mask = bwd_lr_error < alpha_1  * (np.linalg.norm(bwd_flow, axis=-1) \\\n                + np.linalg.norm(fwd2bwd_flow, axis=-1)) + alpha_2\n\n    return fwd_mask, bwd_mask\n\ndef run(args, input_path, output_path, output_img_path):\n    model = torch.nn.DataParallel(RAFT(args))\n    model.load_state_dict(torch.load(args.model))\n\n    model = model.module\n    model.to(DEVICE)\n    model.eval()\n\n    with torch.no_grad():\n        images = glob.glob(os.path.join(input_path, '*.png')) + \\\n                 glob.glob(os.path.join(input_path, '*.jpg'))\n\n        images = sorted(images)\n        for i in range(len(images) - 1):\n            print(i)\n            image1 = load_image(images[i])\n            image2 = load_image(images[i + 1])\n\n            padder = InputPadder(image1.shape)\n            image1, image2 = padder.pad(image1, image2)\n\n            _, flow_fwd = model(image1, image2, iters=20, test_mode=True)\n            _, flow_bwd = model(image2, image1, iters=20, test_mode=True)\n\n            flow_fwd = padder.unpad(flow_fwd[0]).cpu().numpy().transpose(1, 2, 0)\n            flow_bwd = padder.unpad(flow_bwd[0]).cpu().numpy().transpose(1, 2, 0)\n\n            mask_fwd, mask_bwd = compute_fwdbwd_mask(flow_fwd, flow_bwd)\n\n            # Save flow\n            np.savez(os.path.join(output_path, '%03d_fwd.npz'%i), flow=flow_fwd, mask=mask_fwd)\n            np.savez(os.path.join(output_path, '%03d_bwd.npz'%(i + 1)), flow=flow_bwd, mask=mask_bwd)\n\n            # Save flow_img\n            Image.fromarray(flow_viz.flow_to_image(flow_fwd)).save(os.path.join(output_img_path, '%03d_fwd.png'%i))\n            Image.fromarray(flow_viz.flow_to_image(flow_bwd)).save(os.path.join(output_img_path, '%03d_bwd.png'%(i + 1)))\n\n            Image.fromarray(mask_fwd).save(os.path.join(output_img_path, '%03d_fwd_mask.png'%i))\n            Image.fromarray(mask_bwd).save(os.path.join(output_img_path, '%03d_bwd_mask.png'%(i + 1)))\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--dataset_path\", type=str, help='Dataset path')\n    parser.add_argument('--model', help=\"restore RAFT checkpoint\")\n    parser.add_argument('--small', action='store_true', help='use 
small model')\n    parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')\n    args = parser.parse_args()\n\n    input_path = os.path.join(args.dataset_path, 'images')\n    output_path = os.path.join(args.dataset_path, 'flow')\n    output_img_path = os.path.join(args.dataset_path, 'flow_png')\n    create_dir(output_path)\n    create_dir(output_img_path)\n\n    run(args, input_path, output_path, output_img_path)\n"
  },
  {
    "path": "utils/generate_motion_mask.py",
    "content": "import os\nimport cv2\nimport PIL\nimport glob\nimport torch\nimport argparse\nimport numpy as np\n\nfrom colmap_utils import read_cameras_binary, read_images_binary, read_points3d_binary\n\nimport skimage.morphology\nimport torchvision\nfrom flow_utils import read_optical_flow, compute_epipolar_distance, skew\n\n\n\ndef create_dir(dir):\n    if not os.path.exists(dir):\n        os.makedirs(dir)\n\n\ndef extract_poses(im):\n    R = im.qvec2rotmat()\n    t = im.tvec.reshape([3,1])\n    bottom = np.array([0,0,0,1.]).reshape([1,4])\n\n    m = np.concatenate([np.concatenate([R, t], 1), bottom], 0)\n\n    return m\n\n\ndef load_colmap_data(realdir):\n\n    camerasfile = os.path.join(realdir, 'sparse/0/cameras.bin')\n    camdata = read_cameras_binary(camerasfile)\n\n    list_of_keys = list(camdata.keys())\n    cam = camdata[list_of_keys[0]]\n    print( 'Cameras', len(cam))\n\n    h, w, f = cam.height, cam.width, cam.params[0]\n    # w, h, f = factor * w, factor * h, factor * f\n    hwf = np.array([h,w,f]).reshape([3,1])\n\n    imagesfile = os.path.join(realdir, 'sparse/0/images.bin')\n    imdata = read_images_binary(imagesfile)\n\n    w2c_mats = []\n    # bottom = np.array([0,0,0,1.]).reshape([1,4])\n\n    names = [imdata[k].name for k in imdata]\n    img_keys = [k for k in imdata]\n\n    print( 'Images #', len(names))\n    perm = np.argsort(names)\n\n    return imdata, perm, img_keys, hwf\n\n\ndef run_maskrcnn(model, img_path, intWidth=1024, intHeight=576):\n\n    # intHeight = 576\n    # intWidth = 1024\n\n    threshold = 0.5\n\n    o_image = PIL.Image.open(img_path)\n    image = o_image.resize((intWidth, intHeight), PIL.Image.ANTIALIAS)\n\n    image_tensor = torchvision.transforms.functional.to_tensor(image).cuda()\n\n    tenHumans = torch.FloatTensor(intHeight, intWidth).fill_(1.0).cuda()\n\n    objPredictions = model([image_tensor])[0]\n\n    for intMask in range(objPredictions['masks'].size(0)):\n        if objPredictions['scores'][intMask].item() > threshold:\n            if objPredictions['labels'][intMask].item() == 1: # person\n                tenHumans[objPredictions['masks'][intMask, 0, :, :] > threshold] = 0.0\n\n            if objPredictions['labels'][intMask].item() == 4: # motorcycle\n                tenHumans[objPredictions['masks'][intMask, 0, :, :] > threshold] = 0.0\n\n            if objPredictions['labels'][intMask].item() == 2: # bicycle\n                tenHumans[objPredictions['masks'][intMask, 0, :, :] > threshold] = 0.0\n\n            if objPredictions['labels'][intMask].item() == 8: # truck\n                tenHumans[objPredictions['masks'][intMask, 0, :, :] > threshold] = 0.0\n\n            if objPredictions['labels'][intMask].item() == 28: # umbrella\n                tenHumans[objPredictions['masks'][intMask, 0, :, :] > threshold] = 0.0\n\n            if objPredictions['labels'][intMask].item() == 17: # cat\n                tenHumans[objPredictions['masks'][intMask, 0, :, :] > threshold] = 0.0\n\n            if objPredictions['labels'][intMask].item() == 18: # dog\n                tenHumans[objPredictions['masks'][intMask, 0, :, :] > threshold] = 0.0\n\n            if objPredictions['labels'][intMask].item() == 36: # snowboard\n                tenHumans[objPredictions['masks'][intMask, 0, :, :] > threshold] = 0.0\n\n            if objPredictions['labels'][intMask].item() == 41: # skateboard\n                tenHumans[objPredictions['masks'][intMask, 0, :, :] > threshold] = 0.0\n\n    npyMask = skimage.morphology.erosion(tenHumans.cpu().numpy(),\n      
                                   skimage.morphology.disk(1))\n    npyMask = ((npyMask < 1e-3) * 255.0).clip(0.0, 255.0).astype(np.uint8)\n    return npyMask\n\n\ndef motion_segmentation(basedir, threshold,\n                        input_semantic_w=1024,\n                        input_semantic_h=576):\n\n    points3dfile = os.path.join(basedir, 'sparse/0/points3D.bin')\n    pts3d = read_points3d_binary(points3dfile)\n\n    img_dir = glob.glob(basedir + '/images_colmap')[0]\n    img0 = glob.glob(glob.glob(img_dir)[0] + '/*jpg')[0]\n    shape_0 = cv2.imread(img0).shape\n\n    resized_height, resized_width = shape_0[0], shape_0[1]\n\n    imdata, perm, img_keys, hwf = load_colmap_data(basedir)\n    scale_x, scale_y = resized_width / float(hwf[1]), resized_height / float(hwf[0])\n\n    K = np.eye(3)\n    K[0, 0] = hwf[2]\n    K[0, 2] = hwf[1] / 2.\n    K[1, 1] = hwf[2]\n    K[1, 2] = hwf[0] / 2.\n\n    xx = range(0, resized_width)\n    yy = range(0, resized_height)\n    xv, yv = np.meshgrid(xx, yy)\n    p_ref = np.float32(np.stack((xv, yv), axis=-1))\n    p_ref_h = np.reshape(p_ref, (-1, 2))\n    p_ref_h = np.concatenate((p_ref_h, np.ones((p_ref_h.shape[0], 1))), axis=-1).T\n\n    num_frames = len(perm)\n\n    if os.path.isdir(os.path.join(basedir, 'images_colmap')):\n        num_colmap_frames = len(glob.glob(os.path.join(basedir, 'images_colmap', '*.jpg')))\n        num_data_frames = len(glob.glob(os.path.join(basedir, 'images', '*.png')))\n\n        if num_colmap_frames != num_data_frames:\n            num_frames = num_data_frames\n\n\n    save_mask_dir = os.path.join(basedir, 'motion_segmentation')\n    create_dir(save_mask_dir)\n\n    for i in range(0, num_frames):\n        im_prev = imdata[img_keys[perm[max(0, i - 1)]]]\n        im_ref = imdata[img_keys[perm[i]]]\n        im_post = imdata[img_keys[perm[min(num_frames -1, i + 1)]]]\n\n        print(im_prev.name, im_ref.name, im_post.name)\n\n        T_prev_G = extract_poses(im_prev)\n        T_ref_G = extract_poses(im_ref)\n        T_post_G = extract_poses(im_post)\n\n        T_ref2prev = np.dot(T_prev_G, np.linalg.inv(T_ref_G))\n        T_ref2post = np.dot(T_post_G, np.linalg.inv(T_ref_G))\n        # load optical flow\n\n        if i == 0:\n          fwd_flow, _ = read_optical_flow(basedir,\n                                       im_ref.name,\n                                       read_fwd=True)\n          bwd_flow = np.zeros_like(fwd_flow)\n        elif i == num_frames - 1:\n          bwd_flow, _ = read_optical_flow(basedir,\n                                       im_ref.name,\n                                       read_fwd=False)\n          fwd_flow = np.zeros_like(bwd_flow)\n        else:\n          fwd_flow, _ = read_optical_flow(basedir,\n                                       im_ref.name,\n                                       read_fwd=True)\n          bwd_flow, _ = read_optical_flow(basedir,\n                                       im_ref.name,\n                                       read_fwd=False)\n\n        p_post = p_ref + fwd_flow\n        p_post_h = np.reshape(p_post, (-1, 2))\n        p_post_h = np.concatenate((p_post_h, np.ones((p_post_h.shape[0], 1))), axis=-1).T\n\n        fwd_e_dist = compute_epipolar_distance(T_ref2post, K,\n                                               p_ref_h, p_post_h)\n        fwd_e_dist = np.reshape(fwd_e_dist, (fwd_flow.shape[0], fwd_flow.shape[1]))\n\n        p_prev = p_ref + bwd_flow\n        p_prev_h = np.reshape(p_prev, (-1, 2))\n        p_prev_h = np.concatenate((p_prev_h, 
np.ones((p_prev_h.shape[0], 1))), axis=-1).T\n\n        bwd_e_dist = compute_epipolar_distance(T_ref2prev, K,\n                                               p_ref_h, p_prev_h)\n        bwd_e_dist = np.reshape(bwd_e_dist, (bwd_flow.shape[0], bwd_flow.shape[1]))\n\n        e_dist = np.maximum(bwd_e_dist, fwd_e_dist)\n\n        motion_mask = skimage.morphology.binary_opening(e_dist > threshold, skimage.morphology.disk(1))\n\n        cv2.imwrite(os.path.join(save_mask_dir, im_ref.name.replace('.jpg', '.png')), np.uint8(255 * (0. + motion_mask)))\n\n    # RUN SEMANTIC SEGMENTATION\n    img_dir = os.path.join(basedir, 'images')\n    img_path_list = sorted(glob.glob(os.path.join(img_dir, '*.jpg'))) \\\n                  + sorted(glob.glob(os.path.join(img_dir, '*.png')))\n    semantic_mask_dir = os.path.join(basedir, 'semantic_mask')\n    netMaskrcnn = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True).cuda().eval()\n    create_dir(semantic_mask_dir)\n\n\n    for i in range(0, len(img_path_list)):\n        img_path = img_path_list[i]\n        img_name = img_path.split('/')[-1]\n        semantic_mask = run_maskrcnn(netMaskrcnn, img_path,\n                                     input_semantic_w,\n                                     input_semantic_h)\n        cv2.imwrite(os.path.join(semantic_mask_dir,\n                                img_name.replace('.jpg', '.png')),\n                    semantic_mask)\n\n    # combine them\n    save_mask_dir = os.path.join(basedir, 'motion_masks')\n    create_dir(save_mask_dir)\n\n    mask_dir = os.path.join(basedir, 'motion_segmentation')\n    mask_path_list = sorted(glob.glob(os.path.join(mask_dir, '*.png')))\n\n    semantic_dir = os.path.join(basedir, 'semantic_mask')\n\n    for mask_path in mask_path_list:\n        print(mask_path)\n\n        motion_mask = cv2.imread(mask_path)\n        motion_mask = cv2.resize(motion_mask, (resized_width, resized_height),\n                                interpolation=cv2.INTER_NEAREST)\n        motion_mask = motion_mask[:, :, 0] > 0.1\n\n        # combine from motion segmentation\n        semantic_mask = cv2.imread(os.path.join(semantic_dir, mask_path.split('/')[-1]))\n        semantic_mask = cv2.resize(semantic_mask, (resized_width, resized_height),\n                                interpolation=cv2.INTER_NEAREST)\n        semantic_mask = semantic_mask[:, :, 0] > 0.1\n        motion_mask = semantic_mask | motion_mask\n\n        motion_mask = skimage.morphology.dilation(motion_mask, skimage.morphology.disk(2))\n        cv2.imwrite(os.path.join(save_mask_dir, '%s'%mask_path.split('/')[-1]),\n                    np.uint8(np.clip((motion_mask), 0, 1) * 255) )\n\n    # delete old mask dir\n    os.system('rm -r %s'%mask_dir)\n    os.system('rm -r %s'%semantic_dir)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--dataset_path\", type=str, help='Dataset path')\n    parser.add_argument(\"--epi_threshold\", type=float,\n                        default=1.0,\n                        help='epipolar distance threshold for physical motion segmentation')\n\n    parser.add_argument(\"--input_flow_w\", type=int,\n                        default=768,\n                        help='input image width for optical flow, \\\n                        the height will be computed based on original aspect ratio ')\n\n    parser.add_argument(\"--input_semantic_w\", type=int,\n                        default=1024,\n                        help='input image width for semantic 
segmentation')\n\n    parser.add_argument(\"--input_semantic_h\", type=int,\n                        default=576,\n                        help='input image height for semantic segmentation')\n    args = parser.parse_args()\n\n    motion_segmentation(args.dataset_path, args.epi_threshold,\n                        args.input_semantic_w,\n                        args.input_semantic_h)\n"
  },
  {
    "path": "utils/generate_pose.py",
    "content": "import os\nimport glob\nimport argparse\nimport numpy as np\nfrom colmap_utils import read_cameras_binary, read_images_binary, read_points3d_binary\n\n\ndef load_colmap_data(realdir):\n\n    camerasfile = os.path.join(realdir, 'sparse/0/cameras.bin')\n    camdata = read_cameras_binary(camerasfile)\n\n    list_of_keys = list(camdata.keys())\n    cam = camdata[list_of_keys[0]]\n    print( 'Cameras', len(cam))\n\n    h, w, f = cam.height, cam.width, cam.params[0]\n    # w, h, f = factor * w, factor * h, factor * f\n    hwf = np.array([h,w,f]).reshape([3,1])\n\n    imagesfile = os.path.join(realdir, 'sparse/0/images.bin')\n    imdata = read_images_binary(imagesfile)\n\n    w2c_mats = []\n    bottom = np.array([0,0,0,1.]).reshape([1,4])\n\n    names = [imdata[k].name for k in imdata]\n    img_keys = [k for k in imdata]\n\n    print('Images #', len(names))\n    perm = np.argsort(names)\n\n    points3dfile = os.path.join(realdir, 'sparse/0/points3D.bin')\n    pts3d = read_points3d_binary(points3dfile)\n\n    bounds_mats = []\n\n    for i in perm[0:len(img_keys)]:\n\n        im = imdata[img_keys[i]]\n        print(im.name)\n        R = im.qvec2rotmat()\n        t = im.tvec.reshape([3,1])\n        m = np.concatenate([np.concatenate([R, t], 1), bottom], 0)\n        w2c_mats.append(m)\n\n        pts_3d_idx = im.point3D_ids\n        pts_3d_vis_idx = pts_3d_idx[pts_3d_idx >= 0]\n\n        #\n        depth_list = []\n        for k in range(len(pts_3d_vis_idx)):\n          point_info = pts3d[pts_3d_vis_idx[k]]\n\n          P_g = point_info.xyz\n          P_c = np.dot(R, P_g.reshape(3, 1)) + t.reshape(3, 1)\n          depth_list.append(P_c[2])\n\n        zs = np.array(depth_list)\n        close_depth, inf_depth = np.percentile(zs, 5), np.percentile(zs, 95)\n        bounds = np.array([close_depth, inf_depth])\n        bounds_mats.append(bounds)\n\n    w2c_mats = np.stack(w2c_mats, 0)\n    c2w_mats = np.linalg.inv(w2c_mats)\n\n    poses = c2w_mats[:, :3, :4].transpose([1,2,0])\n    poses = np.concatenate([poses, np.tile(hwf[..., np.newaxis],\n                                        [1,1,poses.shape[-1]])], 1)\n\n    # must switch to [-u, r, -t] from [r, -u, t], NOT [r, u, -t]\n    poses = np.concatenate([poses[:, 1:2, :],\n                            poses[:, 0:1, :],\n                           -poses[:, 2:3, :],\n                            poses[:, 3:4, :],\n                            poses[:, 4:5, :]], 1)\n\n    save_arr = []\n\n    for i in range((poses.shape[2])):\n        save_arr.append(np.concatenate([poses[..., i].ravel(), bounds_mats[i]], 0))\n\n    save_arr = np.array(save_arr)\n    print(save_arr.shape)\n\n    # Use all frames to calculate COLMAP camera poses.\n    if os.path.isdir(os.path.join(realdir, 'images_colmap')):\n        num_colmap_frames = len(glob.glob(os.path.join(realdir, 'images_colmap', '*.jpg')))\n        num_data_frames = len(glob.glob(os.path.join(realdir, 'images', '*.png')))\n\n        assert num_colmap_frames == save_arr.shape[0]\n        np.save(os.path.join(realdir, 'poses_bounds.npy'), save_arr[:num_data_frames, :])\n    else:\n        np.save(os.path.join(realdir, 'poses_bounds.npy'), save_arr)\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--dataset_path\", type=str,\n                        help='Dataset path')\n\n    args = parser.parse_args()\n\n    load_colmap_data(args.dataset_path)\n"
  },
  {
    "path": "utils/midas/base_model.py",
    "content": "import torch\n\n\nclass BaseModel(torch.nn.Module):\n    def load(self, path):\n        \"\"\"Load model from file.\n        Args:\n            path (str): file path\n        \"\"\"\n        parameters = torch.load(path, map_location=torch.device('cpu'))\n\n        if \"optimizer\" in parameters:\n            parameters = parameters[\"model\"]\n\n        self.load_state_dict(parameters)\n"
  },
  {
    "path": "utils/midas/blocks.py",
    "content": "import torch\nimport torch.nn as nn\n\nfrom .vit import (\n    _make_pretrained_vitb_rn50_384,\n    _make_pretrained_vitl16_384,\n    _make_pretrained_vitb16_384,\n    forward_vit,\n)\n\ndef _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout=\"ignore\",):\n    if backbone == \"vitl16_384\":\n        pretrained = _make_pretrained_vitl16_384(\n            use_pretrained, hooks=hooks, use_readout=use_readout\n        )\n        scratch = _make_scratch(\n            [256, 512, 1024, 1024], features, groups=groups, expand=expand\n        )  # ViT-L/16 - 85.0% Top1 (backbone)\n    elif backbone == \"vitb_rn50_384\":\n        pretrained = _make_pretrained_vitb_rn50_384(\n            use_pretrained,\n            hooks=hooks,\n            use_vit_only=use_vit_only,\n            use_readout=use_readout,\n        )\n        scratch = _make_scratch(\n            [256, 512, 768, 768], features, groups=groups, expand=expand\n        )  # ViT-H/16 - 85.0% Top1 (backbone)\n    elif backbone == \"vitb16_384\":\n        pretrained = _make_pretrained_vitb16_384(\n            use_pretrained, hooks=hooks, use_readout=use_readout\n        )\n        scratch = _make_scratch(\n            [96, 192, 384, 768], features, groups=groups, expand=expand\n        )  # ViT-B/16 - 84.6% Top1 (backbone)\n    elif backbone == \"resnext101_wsl\":\n        pretrained = _make_pretrained_resnext101_wsl(use_pretrained)\n        scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand)     # efficientnet_lite3\n    elif backbone == \"efficientnet_lite3\":\n        pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)\n        scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand)  # efficientnet_lite3\n    else:\n        print(f\"Backbone '{backbone}' not implemented\")\n        assert False\n\n    return pretrained, scratch\n\n\ndef _make_scratch(in_shape, out_shape, groups=1, expand=False):\n    scratch = nn.Module()\n\n    out_shape1 = out_shape\n    out_shape2 = out_shape\n    out_shape3 = out_shape\n    out_shape4 = out_shape\n    if expand==True:\n        out_shape1 = out_shape\n        out_shape2 = out_shape*2\n        out_shape3 = out_shape*4\n        out_shape4 = out_shape*8\n\n    scratch.layer1_rn = nn.Conv2d(\n        in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups\n    )\n    scratch.layer2_rn = nn.Conv2d(\n        in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups\n    )\n    scratch.layer3_rn = nn.Conv2d(\n        in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups\n    )\n    scratch.layer4_rn = nn.Conv2d(\n        in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups\n    )\n\n    return scratch\n\n\ndef _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):\n    efficientnet = torch.hub.load(\n        \"rwightman/gen-efficientnet-pytorch\",\n        \"tf_efficientnet_lite3\",\n        pretrained=use_pretrained,\n        exportable=exportable\n    )\n    return _make_efficientnet_backbone(efficientnet)\n\n\ndef _make_efficientnet_backbone(effnet):\n    pretrained = nn.Module()\n\n    pretrained.layer1 = nn.Sequential(\n        effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]\n    )\n    pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])\n    
pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])\n    pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])\n\n    return pretrained\n\n\ndef _make_resnet_backbone(resnet):\n    pretrained = nn.Module()\n    pretrained.layer1 = nn.Sequential(\n        resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1\n    )\n\n    pretrained.layer2 = resnet.layer2\n    pretrained.layer3 = resnet.layer3\n    pretrained.layer4 = resnet.layer4\n\n    return pretrained\n\n\ndef _make_pretrained_resnext101_wsl(use_pretrained):\n    resnet = torch.hub.load(\"facebookresearch/WSL-Images\", \"resnext101_32x8d_wsl\")\n    return _make_resnet_backbone(resnet)\n\n\n\nclass Interpolate(nn.Module):\n    \"\"\"Interpolation module.\n    \"\"\"\n\n    def __init__(self, scale_factor, mode, align_corners=False):\n        \"\"\"Init.\n\n        Args:\n            scale_factor (float): scaling\n            mode (str): interpolation mode\n        \"\"\"\n        super(Interpolate, self).__init__()\n\n        self.interp = nn.functional.interpolate\n        self.scale_factor = scale_factor\n        self.mode = mode\n        self.align_corners = align_corners\n\n    def forward(self, x):\n        \"\"\"Forward pass.\n\n        Args:\n            x (tensor): input\n\n        Returns:\n            tensor: interpolated data\n        \"\"\"\n\n        x = self.interp(\n            x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners\n        )\n\n        return x\n\n\nclass ResidualConvUnit(nn.Module):\n    \"\"\"Residual convolution module.\n    \"\"\"\n\n    def __init__(self, features):\n        \"\"\"Init.\n\n        Args:\n            features (int): number of features\n        \"\"\"\n        super().__init__()\n\n        self.conv1 = nn.Conv2d(\n            features, features, kernel_size=3, stride=1, padding=1, bias=True\n        )\n\n        self.conv2 = nn.Conv2d(\n            features, features, kernel_size=3, stride=1, padding=1, bias=True\n        )\n\n        self.relu = nn.ReLU(inplace=True)\n\n    def forward(self, x):\n        \"\"\"Forward pass.\n\n        Args:\n            x (tensor): input\n\n        Returns:\n            tensor: output\n        \"\"\"\n        out = self.relu(x)\n        out = self.conv1(out)\n        out = self.relu(out)\n        out = self.conv2(out)\n\n        return out + x\n\n\nclass FeatureFusionBlock(nn.Module):\n    \"\"\"Feature fusion block.\n    \"\"\"\n\n    def __init__(self, features):\n        \"\"\"Init.\n\n        Args:\n            features (int): number of features\n        \"\"\"\n        super(FeatureFusionBlock, self).__init__()\n\n        self.resConfUnit1 = ResidualConvUnit(features)\n        self.resConfUnit2 = ResidualConvUnit(features)\n\n    def forward(self, *xs):\n        \"\"\"Forward pass.\n\n        Returns:\n            tensor: output\n        \"\"\"\n        output = xs[0]\n\n        if len(xs) == 2:\n            output += self.resConfUnit1(xs[1])\n\n        output = self.resConfUnit2(output)\n\n        output = nn.functional.interpolate(\n            output, scale_factor=2, mode=\"bilinear\", align_corners=True\n        )\n\n        return output\n\n\n\n\nclass ResidualConvUnit_custom(nn.Module):\n    \"\"\"Residual convolution module.\n    \"\"\"\n\n    def __init__(self, features, activation, bn):\n        \"\"\"Init.\n\n        Args:\n            features (int): number of features\n        \"\"\"\n        super().__init__()\n\n        self.bn = bn\n\n        self.groups=1\n\n        self.conv1 = 
nn.Conv2d(\n            features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups\n        )\n\n        self.conv2 = nn.Conv2d(\n            features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups\n        )\n\n        if self.bn==True:\n            self.bn1 = nn.BatchNorm2d(features)\n            self.bn2 = nn.BatchNorm2d(features)\n\n        self.activation = activation\n\n        self.skip_add = nn.quantized.FloatFunctional()\n\n    def forward(self, x):\n        \"\"\"Forward pass.\n\n        Args:\n            x (tensor): input\n\n        Returns:\n            tensor: output\n        \"\"\"\n\n        out = self.activation(x)\n        out = self.conv1(out)\n        if self.bn==True:\n            out = self.bn1(out)\n\n        out = self.activation(out)\n        out = self.conv2(out)\n        if self.bn==True:\n            out = self.bn2(out)\n\n        if self.groups > 1:\n            out = self.conv_merge(out)\n\n        return self.skip_add.add(out, x)\n\n        # return out + x\n\n\nclass FeatureFusionBlock_custom(nn.Module):\n    \"\"\"Feature fusion block.\n    \"\"\"\n\n    def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):\n        \"\"\"Init.\n\n        Args:\n            features (int): number of features\n        \"\"\"\n        super(FeatureFusionBlock_custom, self).__init__()\n\n        self.deconv = deconv\n        self.align_corners = align_corners\n\n        self.groups=1\n\n        self.expand = expand\n        out_features = features\n        if self.expand==True:\n            out_features = features//2\n\n        self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)\n\n        self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)\n        self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)\n\n        self.skip_add = nn.quantized.FloatFunctional()\n\n    def forward(self, *xs):\n        \"\"\"Forward pass.\n\n        Returns:\n            tensor: output\n        \"\"\"\n        output = xs[0]\n\n        if len(xs) == 2:\n            res = self.resConfUnit1(xs[1])\n            output = self.skip_add.add(output, res)\n            # output += res\n\n        output = self.resConfUnit2(output)\n\n        output = nn.functional.interpolate(\n            output, scale_factor=2, mode=\"bilinear\", align_corners=self.align_corners\n        )\n\n        output = self.out_conv(output)\n\n        return output\n"
  },
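  {
    "path": "utils/midas/examples/blocks_demo.py",
    "content": "\"\"\"Illustrative sketch (added for this write-up, not part of the original MiDaS/DynamicNeRF code): wires the helpers from utils/midas/blocks.py into a tiny RefineNet-style decoder on random tensors, mirroring how MidasNet.forward uses them. Channel widths and the input size are arbitrary assumptions.\"\"\"\nimport torch\n\nfrom utils.midas.blocks import FeatureFusionBlock, _make_scratch\n\nfeatures = 64\n\n# Project a hypothetical 4-level encoder (strides 4/8/16/32) to a common width.\nscratch = _make_scratch([256, 512, 1024, 2048], features)\nrefine4, refine3 = FeatureFusionBlock(features), FeatureFusionBlock(features)\nrefine2, refine1 = FeatureFusionBlock(features), FeatureFusionBlock(features)\n\n# Random stand-ins for encoder activations of a 3x256x256 image.\nlayer_1 = torch.randn(1, 256, 64, 64)\nlayer_2 = torch.randn(1, 512, 32, 32)\nlayer_3 = torch.randn(1, 1024, 16, 16)\nlayer_4 = torch.randn(1, 2048, 8, 8)\n\nlayer_1_rn = scratch.layer1_rn(layer_1)\nlayer_2_rn = scratch.layer2_rn(layer_2)\nlayer_3_rn = scratch.layer3_rn(layer_3)\nlayer_4_rn = scratch.layer4_rn(layer_4)\n\n# Each fusion block adds the skip connection and upsamples by a factor of 2.\npath = refine4(layer_4_rn)        # 8x8   -> 16x16\npath = refine3(path, layer_3_rn)  # 16x16 -> 32x32\npath = refine2(path, layer_2_rn)  # 32x32 -> 64x64\npath = refine1(path, layer_1_rn)  # 64x64 -> 128x128\nprint(path.shape)                 # torch.Size([1, 64, 128, 128])\n"
  },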
  {
    "path": "utils/midas/midas_net.py",
    "content": "\"\"\"MidashNet: Network for monocular depth estimation trained by mixing several datasets.\nThis file contains code that is adapted from\nhttps://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py\n\"\"\"\nimport torch\nimport torch.nn as nn\n\nfrom .base_model import BaseModel\nfrom .blocks import FeatureFusionBlock, Interpolate, _make_encoder\n\n\nclass MidasNet(BaseModel):\n    \"\"\"Network for monocular depth estimation.\n    \"\"\"\n\n    def __init__(self, path=None, features=256, non_negative=True):\n        \"\"\"Init.\n\n        Args:\n            path (str, optional): Path to saved model. Defaults to None.\n            features (int, optional): Number of features. Defaults to 256.\n            backbone (str, optional): Backbone network for encoder. Defaults to resnet50\n        \"\"\"\n        print(\"Loading weights: \", path)\n\n        super(MidasNet, self).__init__()\n\n        use_pretrained = False if path is None else True\n\n        self.pretrained, self.scratch = _make_encoder(backbone=\"resnext101_wsl\", features=features, use_pretrained=use_pretrained)\n\n        self.scratch.refinenet4 = FeatureFusionBlock(features)\n        self.scratch.refinenet3 = FeatureFusionBlock(features)\n        self.scratch.refinenet2 = FeatureFusionBlock(features)\n        self.scratch.refinenet1 = FeatureFusionBlock(features)\n\n        self.scratch.output_conv = nn.Sequential(\n            nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),\n            Interpolate(scale_factor=2, mode=\"bilinear\"),\n            nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),\n            nn.ReLU(True),\n            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),\n            nn.ReLU(True) if non_negative else nn.Identity(),\n        )\n\n        if path:\n            self.load(path)\n\n    def forward(self, x):\n        \"\"\"Forward pass.\n\n        Args:\n            x (tensor): input data (image)\n\n        Returns:\n            tensor: depth\n        \"\"\"\n\n        layer_1 = self.pretrained.layer1(x)\n        layer_2 = self.pretrained.layer2(layer_1)\n        layer_3 = self.pretrained.layer3(layer_2)\n        layer_4 = self.pretrained.layer4(layer_3)\n\n        layer_1_rn = self.scratch.layer1_rn(layer_1)\n        layer_2_rn = self.scratch.layer2_rn(layer_2)\n        layer_3_rn = self.scratch.layer3_rn(layer_3)\n        layer_4_rn = self.scratch.layer4_rn(layer_4)\n\n        path_4 = self.scratch.refinenet4(layer_4_rn)\n        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)\n        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)\n        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)\n\n        out = self.scratch.output_conv(path_1)\n\n        return torch.squeeze(out, dim=1)\n"
  },
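  {
    "path": "utils/midas/examples/midas_net_demo.py",
    "content": "\"\"\"Illustrative sketch (added for this write-up, not part of the original repository): loads MidasNet with downloaded MiDaS weights and predicts a disparity map for one frame. The checkpoint and image paths are assumptions; point them at the files you actually have.\"\"\"\nimport cv2\nimport torch\nfrom torchvision.transforms import Compose\n\nfrom utils.midas.midas_net import MidasNet\nfrom utils.midas.transforms import NormalizeImage, PrepareForNet, Resize\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\nmodel = MidasNet(\"weights/midas_v21-f6b98070.pt\").to(device).eval()  # assumed filename\n\n# MiDaS v2.1 style preprocessing: fit the frame into 384x384 (multiples of 32),\n# normalize with ImageNet statistics, convert to contiguous CHW float32.\ntransform = Compose(\n    [\n        Resize(\n            384,\n            384,\n            resize_target=False,  # only the image needs resizing here\n            keep_aspect_ratio=True,\n            ensure_multiple_of=32,\n            resize_method=\"upper_bound\",\n            image_interpolation_method=cv2.INTER_CUBIC,\n        ),\n        NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n        PrepareForNet(),\n    ]\n)\n\nimg = cv2.cvtColor(cv2.imread(\"path/to/frame.png\"), cv2.COLOR_BGR2RGB) / 255.0\nsample = transform({\"image\": img})\nwith torch.no_grad():\n    disparity = model(torch.from_numpy(sample[\"image\"]).unsqueeze(0).to(device))\nprint(disparity.shape)  # (1, H', W') relative inverse depth\n"
  },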
  {
    "path": "utils/midas/transforms.py",
    "content": "import numpy as np\nimport cv2\nimport math\n\n\ndef apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):\n    \"\"\"Rezise the sample to ensure the given size. Keeps aspect ratio.\n\n    Args:\n        sample (dict): sample\n        size (tuple): image size\n\n    Returns:\n        tuple: new size\n    \"\"\"\n    shape = list(sample[\"disparity\"].shape)\n\n    if shape[0] >= size[0] and shape[1] >= size[1]:\n        return sample\n\n    scale = [0, 0]\n    scale[0] = size[0] / shape[0]\n    scale[1] = size[1] / shape[1]\n\n    scale = max(scale)\n\n    shape[0] = math.ceil(scale * shape[0])\n    shape[1] = math.ceil(scale * shape[1])\n\n    # resize\n    sample[\"image\"] = cv2.resize(\n        sample[\"image\"], tuple(shape[::-1]), interpolation=image_interpolation_method\n    )\n\n    sample[\"disparity\"] = cv2.resize(\n        sample[\"disparity\"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST\n    )\n    sample[\"mask\"] = cv2.resize(\n        sample[\"mask\"].astype(np.float32),\n        tuple(shape[::-1]),\n        interpolation=cv2.INTER_NEAREST,\n    )\n    sample[\"mask\"] = sample[\"mask\"].astype(bool)\n\n    return tuple(shape)\n\n\nclass Resize(object):\n    \"\"\"Resize sample to given size (width, height).\n    \"\"\"\n\n    def __init__(\n        self,\n        width,\n        height,\n        resize_target=True,\n        keep_aspect_ratio=False,\n        ensure_multiple_of=1,\n        resize_method=\"lower_bound\",\n        image_interpolation_method=cv2.INTER_AREA,\n    ):\n        \"\"\"Init.\n\n        Args:\n            width (int): desired output width\n            height (int): desired output height\n            resize_target (bool, optional):\n                True: Resize the full sample (image, mask, target).\n                False: Resize image only.\n                Defaults to True.\n            keep_aspect_ratio (bool, optional):\n                True: Keep the aspect ratio of the input sample.\n                Output sample might not have the given width and height, and\n                resize behaviour depends on the parameter 'resize_method'.\n                Defaults to False.\n            ensure_multiple_of (int, optional):\n                Output width and height is constrained to be multiple of this parameter.\n                Defaults to 1.\n            resize_method (str, optional):\n                \"lower_bound\": Output will be at least as large as the given size.\n                \"upper_bound\": Output will be at max as large as the given size. (Output size might be smaller than given size.)\n                \"minimal\": Scale as least as possible.  
(Output size might be smaller than given size.)\n                Defaults to \"lower_bound\".\n        \"\"\"\n        self.__width = width\n        self.__height = height\n\n        self.__resize_target = resize_target\n        self.__keep_aspect_ratio = keep_aspect_ratio\n        self.__multiple_of = ensure_multiple_of\n        self.__resize_method = resize_method\n        self.__image_interpolation_method = image_interpolation_method\n\n    def constrain_to_multiple_of(self, x, min_val=0, max_val=None):\n        y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)\n\n        if max_val is not None and y > max_val:\n            y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)\n\n        if y < min_val:\n            y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)\n\n        return y\n\n    def get_size(self, width, height):\n        # determine new height and width\n        scale_height = self.__height / height\n        scale_width = self.__width / width\n\n        if self.__keep_aspect_ratio:\n            if self.__resize_method == \"lower_bound\":\n                # scale such that output size is lower bound\n                if scale_width > scale_height:\n                    # fit width\n                    scale_height = scale_width\n                else:\n                    # fit height\n                    scale_width = scale_height\n            elif self.__resize_method == \"upper_bound\":\n                # scale such that output size is upper bound\n                if scale_width < scale_height:\n                    # fit width\n                    scale_height = scale_width\n                else:\n                    # fit height\n                    scale_width = scale_height\n            elif self.__resize_method == \"minimal\":\n                # scale as least as possbile\n                if abs(1 - scale_width) < abs(1 - scale_height):\n                    # fit width\n                    scale_height = scale_width\n                else:\n                    # fit height\n                    scale_width = scale_height\n            else:\n                raise ValueError(\n                    f\"resize_method {self.__resize_method} not implemented\"\n                )\n\n        if self.__resize_method == \"lower_bound\":\n            new_height = self.constrain_to_multiple_of(\n                scale_height * height, min_val=self.__height\n            )\n            new_width = self.constrain_to_multiple_of(\n                scale_width * width, min_val=self.__width\n            )\n        elif self.__resize_method == \"upper_bound\":\n            new_height = self.constrain_to_multiple_of(\n                scale_height * height, max_val=self.__height\n            )\n            new_width = self.constrain_to_multiple_of(\n                scale_width * width, max_val=self.__width\n            )\n        elif self.__resize_method == \"minimal\":\n            new_height = self.constrain_to_multiple_of(scale_height * height)\n            new_width = self.constrain_to_multiple_of(scale_width * width)\n        else:\n            raise ValueError(f\"resize_method {self.__resize_method} not implemented\")\n\n        return (new_width, new_height)\n\n    def __call__(self, sample):\n        width, height = self.get_size(\n            sample[\"image\"].shape[1], sample[\"image\"].shape[0]\n        )\n\n        # resize sample\n        sample[\"image\"] = cv2.resize(\n            sample[\"image\"],\n            (width, 
height),\n            interpolation=self.__image_interpolation_method,\n        )\n\n        if self.__resize_target:\n            if \"disparity\" in sample:\n                sample[\"disparity\"] = cv2.resize(\n                    sample[\"disparity\"],\n                    (width, height),\n                    interpolation=cv2.INTER_NEAREST,\n                )\n\n            if \"depth\" in sample:\n                sample[\"depth\"] = cv2.resize(\n                    sample[\"depth\"], (width, height), interpolation=cv2.INTER_NEAREST\n                )\n\n            sample[\"mask\"] = cv2.resize(\n                sample[\"mask\"].astype(np.float32),\n                (width, height),\n                interpolation=cv2.INTER_NEAREST,\n            )\n            sample[\"mask\"] = sample[\"mask\"].astype(bool)\n\n        return sample\n\n\nclass NormalizeImage(object):\n    \"\"\"Normlize image by given mean and std.\n    \"\"\"\n\n    def __init__(self, mean, std):\n        self.__mean = mean\n        self.__std = std\n\n    def __call__(self, sample):\n        sample[\"image\"] = (sample[\"image\"] - self.__mean) / self.__std\n\n        return sample\n\n\nclass PrepareForNet(object):\n    \"\"\"Prepare sample for usage as network input.\n    \"\"\"\n\n    def __init__(self):\n        pass\n\n    def __call__(self, sample):\n        image = np.transpose(sample[\"image\"], (2, 0, 1))\n        sample[\"image\"] = np.ascontiguousarray(image).astype(np.float32)\n\n        if \"mask\" in sample:\n            sample[\"mask\"] = sample[\"mask\"].astype(np.float32)\n            sample[\"mask\"] = np.ascontiguousarray(sample[\"mask\"])\n\n        if \"disparity\" in sample:\n            disparity = sample[\"disparity\"].astype(np.float32)\n            sample[\"disparity\"] = np.ascontiguousarray(disparity)\n\n        if \"depth\" in sample:\n            depth = sample[\"depth\"].astype(np.float32)\n            sample[\"depth\"] = np.ascontiguousarray(depth)\n\n        return sample\n"
  },
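  {
    "path": "utils/midas/examples/transforms_demo.py",
    "content": "\"\"\"Illustrative sketch (added for this write-up, not part of the original repository): shows how the Resize transform from utils/midas/transforms.py picks the output geometry. The numbers below are for a 1920x1080 frame.\"\"\"\nimport numpy as np\n\nfrom utils.midas.transforms import PrepareForNet, Resize\n\nresize = Resize(\n    384,\n    384,\n    resize_target=True,\n    keep_aspect_ratio=True,\n    ensure_multiple_of=32,\n    resize_method=\"lower_bound\",\n)\n\n# \"lower_bound\": both sides end up >= 384 and divisible by 32, while the\n# aspect ratio is preserved as closely as possible.\nprint(resize.get_size(1920, 1080))  # (672, 384), i.e. (width, height)\n\nsample = {\n    \"image\": np.random.rand(1080, 1920, 3).astype(np.float32),\n    \"mask\": np.ones((1080, 1920), dtype=bool),\n}\nsample = resize(sample)\nsample = PrepareForNet()(sample)\nprint(sample[\"image\"].shape, sample[\"mask\"].shape)  # (3, 384, 672) (384, 672)\n"
  },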
  {
    "path": "utils/midas/vit.py",
    "content": "import torch\nimport torch.nn as nn\nimport timm\nimport types\nimport math\nimport torch.nn.functional as F\n\n\nclass Slice(nn.Module):\n    def __init__(self, start_index=1):\n        super(Slice, self).__init__()\n        self.start_index = start_index\n\n    def forward(self, x):\n        return x[:, self.start_index :]\n\n\nclass AddReadout(nn.Module):\n    def __init__(self, start_index=1):\n        super(AddReadout, self).__init__()\n        self.start_index = start_index\n\n    def forward(self, x):\n        if self.start_index == 2:\n            readout = (x[:, 0] + x[:, 1]) / 2\n        else:\n            readout = x[:, 0]\n        return x[:, self.start_index :] + readout.unsqueeze(1)\n\n\nclass ProjectReadout(nn.Module):\n    def __init__(self, in_features, start_index=1):\n        super(ProjectReadout, self).__init__()\n        self.start_index = start_index\n\n        self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())\n\n    def forward(self, x):\n        readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index :])\n        features = torch.cat((x[:, self.start_index :], readout), -1)\n\n        return self.project(features)\n\n\nclass Transpose(nn.Module):\n    def __init__(self, dim0, dim1):\n        super(Transpose, self).__init__()\n        self.dim0 = dim0\n        self.dim1 = dim1\n\n    def forward(self, x):\n        x = x.transpose(self.dim0, self.dim1)\n        return x\n\n\ndef forward_vit(pretrained, x):\n    b, c, h, w = x.shape\n\n    glob = pretrained.model.forward_flex(x)\n\n    layer_1 = pretrained.activations[\"1\"]\n    layer_2 = pretrained.activations[\"2\"]\n    layer_3 = pretrained.activations[\"3\"]\n    layer_4 = pretrained.activations[\"4\"]\n\n    layer_1 = pretrained.act_postprocess1[0:2](layer_1)\n    layer_2 = pretrained.act_postprocess2[0:2](layer_2)\n    layer_3 = pretrained.act_postprocess3[0:2](layer_3)\n    layer_4 = pretrained.act_postprocess4[0:2](layer_4)\n\n    unflatten = nn.Sequential(\n        nn.Unflatten(\n            2,\n            torch.Size(\n                [\n                    h // pretrained.model.patch_size[1],\n                    w // pretrained.model.patch_size[0],\n                ]\n            ),\n        )\n    )\n\n    if layer_1.ndim == 3:\n        layer_1 = unflatten(layer_1)\n    if layer_2.ndim == 3:\n        layer_2 = unflatten(layer_2)\n    if layer_3.ndim == 3:\n        layer_3 = unflatten(layer_3)\n    if layer_4.ndim == 3:\n        layer_4 = unflatten(layer_4)\n\n    layer_1 = pretrained.act_postprocess1[3 : len(pretrained.act_postprocess1)](layer_1)\n    layer_2 = pretrained.act_postprocess2[3 : len(pretrained.act_postprocess2)](layer_2)\n    layer_3 = pretrained.act_postprocess3[3 : len(pretrained.act_postprocess3)](layer_3)\n    layer_4 = pretrained.act_postprocess4[3 : len(pretrained.act_postprocess4)](layer_4)\n\n    return layer_1, layer_2, layer_3, layer_4\n\n\ndef _resize_pos_embed(self, posemb, gs_h, gs_w):\n    posemb_tok, posemb_grid = (\n        posemb[:, : self.start_index],\n        posemb[0, self.start_index :],\n    )\n\n    gs_old = int(math.sqrt(len(posemb_grid)))\n\n    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)\n    posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode=\"bilinear\")\n    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)\n\n    posemb = torch.cat([posemb_tok, posemb_grid], dim=1)\n\n    return posemb\n\n\ndef forward_flex(self, x):\n    b, c, h, w = 
x.shape\n\n    pos_embed = self._resize_pos_embed(\n        self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]\n    )\n\n    B = x.shape[0]\n\n    if hasattr(self.patch_embed, \"backbone\"):\n        x = self.patch_embed.backbone(x)\n        if isinstance(x, (list, tuple)):\n            x = x[-1]  # last feature if backbone outputs list/tuple of features\n\n    x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)\n\n    if getattr(self, \"dist_token\", None) is not None:\n        cls_tokens = self.cls_token.expand(\n            B, -1, -1\n        )  # stole cls_tokens impl from Phil Wang, thanks\n        dist_token = self.dist_token.expand(B, -1, -1)\n        x = torch.cat((cls_tokens, dist_token, x), dim=1)\n    else:\n        cls_tokens = self.cls_token.expand(\n            B, -1, -1\n        )  # stole cls_tokens impl from Phil Wang, thanks\n        x = torch.cat((cls_tokens, x), dim=1)\n\n    x = x + pos_embed\n    x = self.pos_drop(x)\n\n    for blk in self.blocks:\n        x = blk(x)\n\n    x = self.norm(x)\n\n    return x\n\n\nactivations = {}\n\n\ndef get_activation(name):\n    def hook(model, input, output):\n        activations[name] = output\n\n    return hook\n\n\ndef get_readout_oper(vit_features, features, use_readout, start_index=1):\n    if use_readout == \"ignore\":\n        readout_oper = [Slice(start_index)] * len(features)\n    elif use_readout == \"add\":\n        readout_oper = [AddReadout(start_index)] * len(features)\n    elif use_readout == \"project\":\n        readout_oper = [\n            ProjectReadout(vit_features, start_index) for out_feat in features\n        ]\n    else:\n        assert (\n            False\n        ), \"wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'\"\n\n    return readout_oper\n\n\ndef _make_vit_b16_backbone(\n    model,\n    features=[96, 192, 384, 768],\n    size=[384, 384],\n    hooks=[2, 5, 8, 11],\n    vit_features=768,\n    use_readout=\"ignore\",\n    start_index=1,\n):\n    pretrained = nn.Module()\n\n    pretrained.model = model\n    pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation(\"1\"))\n    pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation(\"2\"))\n    pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation(\"3\"))\n    pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation(\"4\"))\n\n    pretrained.activations = activations\n\n    readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)\n\n    # 32, 48, 136, 384\n    pretrained.act_postprocess1 = nn.Sequential(\n        readout_oper[0],\n        Transpose(1, 2),\n        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),\n        nn.Conv2d(\n            in_channels=vit_features,\n            out_channels=features[0],\n            kernel_size=1,\n            stride=1,\n            padding=0,\n        ),\n        nn.ConvTranspose2d(\n            in_channels=features[0],\n            out_channels=features[0],\n            kernel_size=4,\n            stride=4,\n            padding=0,\n            bias=True,\n            dilation=1,\n            groups=1,\n        ),\n    )\n\n    pretrained.act_postprocess2 = nn.Sequential(\n        readout_oper[1],\n        Transpose(1, 2),\n        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),\n        nn.Conv2d(\n            in_channels=vit_features,\n            out_channels=features[1],\n            kernel_size=1,\n            stride=1,\n            padding=0,\n     
   ),\n        nn.ConvTranspose2d(\n            in_channels=features[1],\n            out_channels=features[1],\n            kernel_size=2,\n            stride=2,\n            padding=0,\n            bias=True,\n            dilation=1,\n            groups=1,\n        ),\n    )\n\n    pretrained.act_postprocess3 = nn.Sequential(\n        readout_oper[2],\n        Transpose(1, 2),\n        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),\n        nn.Conv2d(\n            in_channels=vit_features,\n            out_channels=features[2],\n            kernel_size=1,\n            stride=1,\n            padding=0,\n        ),\n    )\n\n    pretrained.act_postprocess4 = nn.Sequential(\n        readout_oper[3],\n        Transpose(1, 2),\n        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),\n        nn.Conv2d(\n            in_channels=vit_features,\n            out_channels=features[3],\n            kernel_size=1,\n            stride=1,\n            padding=0,\n        ),\n        nn.Conv2d(\n            in_channels=features[3],\n            out_channels=features[3],\n            kernel_size=3,\n            stride=2,\n            padding=1,\n        ),\n    )\n\n    pretrained.model.start_index = start_index\n    pretrained.model.patch_size = [16, 16]\n\n    # We inject this function into the VisionTransformer instances so that\n    # we can use it with interpolated position embeddings without modifying the library source.\n    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)\n    pretrained.model._resize_pos_embed = types.MethodType(\n        _resize_pos_embed, pretrained.model\n    )\n\n    return pretrained\n\n\ndef _make_pretrained_vitl16_384(pretrained, use_readout=\"ignore\", hooks=None):\n    model = timm.create_model(\"vit_large_patch16_384\", pretrained=pretrained)\n\n    hooks = [5, 11, 17, 23] if hooks == None else hooks\n    return _make_vit_b16_backbone(\n        model,\n        features=[256, 512, 1024, 1024],\n        hooks=hooks,\n        vit_features=1024,\n        use_readout=use_readout,\n    )\n\n\ndef _make_pretrained_vitb16_384(pretrained, use_readout=\"ignore\", hooks=None):\n    model = timm.create_model(\"vit_base_patch16_384\", pretrained=pretrained)\n\n    hooks = [2, 5, 8, 11] if hooks == None else hooks\n    return _make_vit_b16_backbone(\n        model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout\n    )\n\n\ndef _make_pretrained_deitb16_384(pretrained, use_readout=\"ignore\", hooks=None):\n    model = timm.create_model(\"vit_deit_base_patch16_384\", pretrained=pretrained)\n\n    hooks = [2, 5, 8, 11] if hooks == None else hooks\n    return _make_vit_b16_backbone(\n        model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout\n    )\n\n\ndef _make_pretrained_deitb16_distil_384(pretrained, use_readout=\"ignore\", hooks=None):\n    model = timm.create_model(\n        \"vit_deit_base_distilled_patch16_384\", pretrained=pretrained\n    )\n\n    hooks = [2, 5, 8, 11] if hooks == None else hooks\n    return _make_vit_b16_backbone(\n        model,\n        features=[96, 192, 384, 768],\n        hooks=hooks,\n        use_readout=use_readout,\n        start_index=2,\n    )\n\n\ndef _make_vit_b_rn50_backbone(\n    model,\n    features=[256, 512, 768, 768],\n    size=[384, 384],\n    hooks=[0, 1, 8, 11],\n    vit_features=768,\n    use_vit_only=False,\n    use_readout=\"ignore\",\n    start_index=1,\n):\n    pretrained = nn.Module()\n\n    pretrained.model = model\n\n    if 
use_vit_only == True:\n        pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation(\"1\"))\n        pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation(\"2\"))\n    else:\n        pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(\n            get_activation(\"1\")\n        )\n        pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(\n            get_activation(\"2\")\n        )\n\n    pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation(\"3\"))\n    pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation(\"4\"))\n\n    pretrained.activations = activations\n\n    readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)\n\n    if use_vit_only == True:\n        pretrained.act_postprocess1 = nn.Sequential(\n            readout_oper[0],\n            Transpose(1, 2),\n            nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),\n            nn.Conv2d(\n                in_channels=vit_features,\n                out_channels=features[0],\n                kernel_size=1,\n                stride=1,\n                padding=0,\n            ),\n            nn.ConvTranspose2d(\n                in_channels=features[0],\n                out_channels=features[0],\n                kernel_size=4,\n                stride=4,\n                padding=0,\n                bias=True,\n                dilation=1,\n                groups=1,\n            ),\n        )\n\n        pretrained.act_postprocess2 = nn.Sequential(\n            readout_oper[1],\n            Transpose(1, 2),\n            nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),\n            nn.Conv2d(\n                in_channels=vit_features,\n                out_channels=features[1],\n                kernel_size=1,\n                stride=1,\n                padding=0,\n            ),\n            nn.ConvTranspose2d(\n                in_channels=features[1],\n                out_channels=features[1],\n                kernel_size=2,\n                stride=2,\n                padding=0,\n                bias=True,\n                dilation=1,\n                groups=1,\n            ),\n        )\n    else:\n        pretrained.act_postprocess1 = nn.Sequential(\n            nn.Identity(), nn.Identity(), nn.Identity()\n        )\n        pretrained.act_postprocess2 = nn.Sequential(\n            nn.Identity(), nn.Identity(), nn.Identity()\n        )\n\n    pretrained.act_postprocess3 = nn.Sequential(\n        readout_oper[2],\n        Transpose(1, 2),\n        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),\n        nn.Conv2d(\n            in_channels=vit_features,\n            out_channels=features[2],\n            kernel_size=1,\n            stride=1,\n            padding=0,\n        ),\n    )\n\n    pretrained.act_postprocess4 = nn.Sequential(\n        readout_oper[3],\n        Transpose(1, 2),\n        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),\n        nn.Conv2d(\n            in_channels=vit_features,\n            out_channels=features[3],\n            kernel_size=1,\n            stride=1,\n            padding=0,\n        ),\n        nn.Conv2d(\n            in_channels=features[3],\n            out_channels=features[3],\n            kernel_size=3,\n            stride=2,\n            padding=1,\n        ),\n    )\n\n    pretrained.model.start_index = start_index\n    pretrained.model.patch_size = [16, 16]\n\n    # We inject this function into the 
VisionTransformer instances so that\n    # we can use it with interpolated position embeddings without modifying the library source.\n    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)\n\n    # We inject this function into the VisionTransformer instances so that\n    # we can use it with interpolated position embeddings without modifying the library source.\n    pretrained.model._resize_pos_embed = types.MethodType(\n        _resize_pos_embed, pretrained.model\n    )\n\n    return pretrained\n\n\ndef _make_pretrained_vitb_rn50_384(\n    pretrained, use_readout=\"ignore\", hooks=None, use_vit_only=False\n):\n    model = timm.create_model(\"vit_base_resnet50_384\", pretrained=pretrained)\n\n    hooks = [0, 1, 8, 11] if hooks == None else hooks\n    return _make_vit_b_rn50_backbone(\n        model,\n        features=[256, 512, 768, 768],\n        size=[384, 384],\n        hooks=hooks,\n        use_vit_only=use_vit_only,\n        use_readout=use_readout,\n    )\n"
  }
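  ,
  {
    "path": "utils/midas/examples/vit_backbone_demo.py",
    "content": "\"\"\"Illustrative sketch (added for this write-up, not part of the original repository): builds the ViT-Hybrid backbone from utils/midas/vit.py and inspects the four activations consumed by the decoder. timm weights are downloaded on first use, and the snippet assumes the timm/torch versions this repository was tested with.\"\"\"\nimport torch\n\nfrom utils.midas.vit import _make_pretrained_vitb_rn50_384, forward_vit\n\npretrained = _make_pretrained_vitb_rn50_384(True, use_readout=\"project\")\n\nx = torch.randn(1, 3, 384, 384)\nwith torch.no_grad():\n    layer_1, layer_2, layer_3, layer_4 = forward_vit(pretrained, x)\n\n# Four feature maps at strides 4, 8, 16 and 32 with 256/512/768/768 channels:\n# the two fine levels come from the ResNet-50 stem of the hybrid patch\n# embedding, the two coarse levels are reassembled from transformer tokens.\nfor feat in (layer_1, layer_2, layer_3, layer_4):\n    print(tuple(feat.shape))\n"
  }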
]