[
  {
    "path": ".gitignore",
    "content": "outputs/\ncheckpoints/\ndebug_figs/\n*__pycache__"
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2022 Jeff Wang\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# MVSTER\nMVSTER: Epipolar Transformer for Efficient Multi-View Stereo, ECCV 2022. [arXiv](https://arxiv.org/abs/2204.07346)\n\nThis repository contains the official implementation of the paper: \"MVSTER: Epipolar Transformer for Efficient Multi-View Stereo\".\n\n\n## Introduction\nMVSTER is a learning-based MVS method which achieves competitive reconstruction performance with significantly higher efficiency. MVSTER leverages the proposed epipolar Transformer to learn both 2D semantics and 3D spatial associations efficiently. Specifically, the epipolar Transformer utilizes a detachable monocular depth estimator to enhance 2D semantics and uses cross-attention to construct data-dependent 3D associations along epipolar line. Additionally, MVSTER is built in a cascade structure, where entropy-regularized optimal transport is leveraged to propagate finer depth estimations in each stage.\n![](img/arch.png)\n\n\n\n## Installation\nMVSTER is tested on:\n* python 3.7\n* CUDA 11.1\n### Requirements\n```\npip install -r requirements.txt\n```\n\n## Training\n* Dowload [DTU dataset](https://roboimagedata.compute.dtu.dk/). For convenience, can download the preprocessed [DTU training data](https://drive.google.com/file/d/1eDjh-_bxKKnEuz5h-HXS7EDJn59clx6V/view)\n and [Depths_raw](https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/cascade-stereo/CasMVSNet/dtu_data/dtu_train_hr/Depths_raw.zip) \n (both from [Original MVSNet](https://github.com/YoYo000/MVSNet)), and upzip it as the $DTU_TRAINING folder. For training and testing with raw image size, you can download [Rectified_raw](http://roboimagedata2.compute.dtu.dk/data/MVS/Rectified.zip), and unzip it.\n\n```                \n├── Cameras    \n├── Depths\n├── Depths_raw   \n├── Rectified\n├── Rectified_raw (Optional)                                      \n```\nIn ``scripts/train_dtu.sh``, set ``DTU_TRAINING`` as $DTU_TRAINING\n\nTrain MVSTER (Multi-GPU training): \n* Train with middle size (512x640):\n```\nbash ./scripts/train_dtu.sh mid exp_name\n```\n* Train with raw size (1200x1600):\n```\nbash ./scripts/train_dtu.sh raw exp_name\n```\nAfter training, you will get model checkpoints in ./checkpoints/dtu/exp_name.\n\n## Testing\n* Download the preprocessed test data [DTU testing data](https://drive.google.com/open?id=135oKPefcPTsdtLRzoDAQtPpHuoIrpRI_) (from [Original MVSNet](https://github.com/YoYo000/MVSNet)) and unzip it as the $DTU_TESTPATH folder, which should contain one ``cams`` folder, one ``images`` folder and one ``pair.txt`` file.\n* In ``scripts/test_dtu.sh``, set ``DTU_TESTPATH`` as $DTU_TESTPATH.\n* The ``DTU_CKPT_FILE`` is automatically set as your pretrained checkpoint file, you also can download my [pretrained model](https://github.com/JeffWang987/MVSTER/releases/tag/dtu_ckpt).\n* Test with middle size:\n```\nbash ./scripts/test_dtu.sh mid exp_name\n```\n* Test with raw size:\n```\nbash ./scripts/test_dtu.sh raw exp_name\n```\n* Test with provided pretrained model:\n```\nbash scripts/test_dtu.sh mid benchmark --loadckpt PATH_TO_CKPT_FILE\n```\nAfter testing, you will get reconstructed point clouds of DTU test set in ./outputs/dtu/exp_name.\n\n## Metric\n* For quantitative evaluation, download [SampleSet](http://roboimagedata.compute.dtu.dk/?page_id=36) and [Points](http://roboimagedata.compute.dtu.dk/?page_id=36) from DTU's website. Unzip them and place `Points` folder in `SampleSet/MVS Data/`. 
\n## Installation\nMVSTER is tested on:\n* python 3.7\n* CUDA 11.1\n### Requirements\n```\npip install -r requirements.txt\n```\n\n## Training\n* Download the [DTU dataset](https://roboimagedata.compute.dtu.dk/). For convenience, you can download the preprocessed [DTU training data](https://drive.google.com/file/d/1eDjh-_bxKKnEuz5h-HXS7EDJn59clx6V/view) and [Depths_raw](https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/cascade-stereo/CasMVSNet/dtu_data/dtu_train_hr/Depths_raw.zip) (both from [Original MVSNet](https://github.com/YoYo000/MVSNet)), and unzip them as the $DTU_TRAINING folder. For training and testing with raw image size, you can download [Rectified_raw](http://roboimagedata2.compute.dtu.dk/data/MVS/Rectified.zip) and unzip it.\n\n```\n├── Cameras\n├── Depths\n├── Depths_raw\n├── Rectified\n├── Rectified_raw (Optional)\n```\nIn ``scripts/train_dtu.sh``, set ``DTU_TRAINING`` as $DTU_TRAINING.\n\nTrain MVSTER (multi-GPU training):\n* Train with middle size (512x640):\n```\nbash ./scripts/train_dtu.sh mid exp_name\n```\n* Train with raw size (1200x1600):\n```\nbash ./scripts/train_dtu.sh raw exp_name\n```\nAfter training, you will get model checkpoints in ./checkpoints/dtu/exp_name.\n\n## Testing\n* Download the preprocessed [DTU testing data](https://drive.google.com/open?id=135oKPefcPTsdtLRzoDAQtPpHuoIrpRI_) (from [Original MVSNet](https://github.com/YoYo000/MVSNet)) and unzip it as the $DTU_TESTPATH folder, which should contain one ``cams`` folder, one ``images`` folder and one ``pair.txt`` file.\n* In ``scripts/test_dtu.sh``, set ``DTU_TESTPATH`` as $DTU_TESTPATH.\n* ``DTU_CKPT_FILE`` is automatically set to your pretrained checkpoint file; you can also download the [pretrained model](https://github.com/JeffWang987/MVSTER/releases/tag/dtu_ckpt).\n* Test with middle size:\n```\nbash ./scripts/test_dtu.sh mid exp_name\n```\n* Test with raw size:\n```\nbash ./scripts/test_dtu.sh raw exp_name\n```\n* Test with the provided pretrained model:\n```\nbash scripts/test_dtu.sh mid benchmark --loadckpt PATH_TO_CKPT_FILE\n```\nAfter testing, you will get reconstructed point clouds of the DTU test set in ./outputs/dtu/exp_name.\n\n## Metric\n* For quantitative evaluation, download [SampleSet](http://roboimagedata.compute.dtu.dk/?page_id=36) and [Points](http://roboimagedata.compute.dtu.dk/?page_id=36) from DTU's website. Unzip them and place the `Points` folder in `SampleSet/MVS Data/`. The structure looks like:\n```\nSampleSet\n├──MVS Data\n      └──Points\n```\n* For convenient evaluation, install MATLAB (tested on Ubuntu 18.04), uncomment the **mrun_rst** function at the end of **./test_mvs4.py**, and change the path of the MATLAB executable file (for me, it is /mnt/cfs/algorithm/xiaofeng.wang/jeff/code/MVS/misc/matlab/bin/matlab). Point cloud reconstruction results are then evaluated automatically when testing finishes.\n\n* You can also evaluate the metrics with the traditional steps: in ``evaluations/dtu/BaseEvalMain_web.m``, set `dataPath` as the path to `SampleSet/MVS Data/`, `plyPath` as the directory that stores the reconstructed point clouds, and `resultsPath` as the directory to store the evaluation results. Then run ``evaluations/dtu/BaseEvalMain_web.m`` in MATLAB.\n\n## Results on DTU (single RTX 3090)\n|                       | Acc.   | Comp.  | Overall | Inf. Time |\n|-----------------------|--------|--------|---------|-----------|\n| MVSTER (mid size)     | 0.350  | 0.276  | 0.313   |    0.09s  |\n| MVSTER (raw size)     | 0.340  | 0.266  | 0.303   |    0.17s  |\n\nPoint cloud results on [DTU](https://github.com/JeffWang987/MVSTER/releases/tag/DTU_ply), [Tanks and Temples](https://github.com/JeffWang987/MVSTER/releases/tag/T%26T_ply), [ETH3D](https://github.com/JeffWang987/MVSTER/releases/tag/ETH3D_ply)\n\n![](img/vegetables.gif) ![](img/house.gif)\n\n![](img/sculpture.gif) ![](img/rabit.gif)\n\n\nIf you find this project useful for your research, please cite:\n```\n@misc{wang2022mvster,\n      title={MVSTER: Epipolar Transformer for Efficient Multi-View Stereo},\n      author={Xiaofeng Wang and Zheng Zhu and Fangbo Qin and Yun Ye and Guan Huang and Xu Chi and Yijia He and Xingang Wang},\n      journal={arXiv preprint arXiv:2204.07346},\n      year={2022}\n}\n```\n\n\n## Acknowledgements\nOur work is partially based on these open-source works: [MVSNet](https://github.com/YoYo000/MVSNet), [MVSNet-pytorch](https://github.com/xy-guo/MVSNet_pytorch), [cascade-stereo](https://github.com/alibaba/cascade-stereo), [PatchmatchNet](https://github.com/FangjinhuaWang/PatchmatchNet).\n\nWe appreciate their contributions to the MVS community.\n"
  },
  {
    "path": "datasets/__init__.py",
    "content": "import importlib\n\n\n# find the dataset definition by name, for example dtu_yao (dtu_yao.py)\ndef find_dataset_def(dataset_name):\n    module_name = 'datasets.{}'.format(dataset_name)\n    module = importlib.import_module(module_name)\n    return getattr(module, \"MVSDataset\")\n"
  },
  {
    "path": "datasets/blendedmvs.py",
    "content": "from torch.utils.data import Dataset\nfrom datasets.data_io import *\nimport os\nimport numpy as np\nimport cv2\nfrom PIL import Image\nfrom torchvision import transforms as T\nimport random\nimport copy\n\ndef check_invalid_input(imgs, depths, masks, depth_mins, depth_maxs):\n    for img in imgs:\n        assert np.isnan(img).sum() == 0\n        assert np.isinf(img).sum() == 0\n    for depth in depths.values():\n        assert np.isnan(depth).sum() == 0\n        assert np.isinf(depth).sum() == 0\n    for mask in masks.values():\n        assert np.isnan(mask).sum() == 0\n        assert np.isinf(mask).sum() == 0\n\n    assert (depth_mins<=0) == 0\n    assert (depth_maxs<=depth_mins) == 0\n\n\nclass MVSDataset(Dataset):\n    def __init__(self, datapath, listfile, split, nviews, img_wh=(768, 576), robust_train=True):\n        \n        super(MVSDataset, self).__init__()\n        self.levels = 4 \n        self.datapath = datapath\n        self.split = split\n        self.listfile = listfile\n        self.robust_train = robust_train\n        assert self.split in ['train', 'val', 'all'], \\\n            'split must be either \"train\", \"val\" or \"all\"!'\n\n        self.img_wh = img_wh\n        if img_wh is not None:\n            assert img_wh[0]%32==0 and img_wh[1]%32==0, \\\n                'img_wh must both be multiples of 32!'\n        self.nviews = nviews\n        self.scale_factors = {} # depth scale factors for each scan\n        self.scale_factor = 0 # depth scale factors for each scan\n        self.build_metas()\n\n        self.color_augment = T.ColorJitter(brightness=0.5, contrast=0.5)\n\n    def build_metas(self):\n        self.metas = []\n        with open(self.listfile) as f:\n            self.scans = [line.rstrip() for line in f.readlines()]\n        for scan in self.scans:\n            with open(os.path.join(self.datapath, scan, \"cams/pair.txt\")) as f:\n                num_viewpoint = int(f.readline())\n                for _ in range(num_viewpoint):\n                    ref_view = int(f.readline().rstrip())\n                    src_views = [int(x) for x in f.readline().rstrip().split()[1::2]]\n                    if len(src_views) >= self.nviews-1:\n                        self.metas += [(scan, ref_view, src_views)]\n\n    def read_cam_file(self, scan, filename):\n        with open(filename) as f:\n            lines = f.readlines()\n            lines = [line.rstrip() for line in lines]\n        # extrinsics: line [1,5), 4x4 matrix\n        extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4))\n        # intrinsics: line [7-10), 3x3 matrix\n        intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3))\n        depth_min = float(lines[11].split()[0])\n        depth_max = float(lines[11].split()[-1])\n\n        if scan not in self.scale_factors:\n            self.scale_factors[scan] = 100.0 / depth_min\n        depth_min *= self.scale_factors[scan]\n        depth_max *= self.scale_factors[scan]\n        extrinsics[:3, 3] *= self.scale_factors[scan]\n\n        return intrinsics, extrinsics, depth_min, depth_max\n\n    def read_depth_mask(self, scan, filename, depth_min, depth_max, scale):\n        depth = np.array(read_pfm(filename)[0], dtype=np.float32)\n        # depth = (depth * self.scale_factor) * scale\n        depth = (depth * self.scale_factors[scan]) * scale\n        # depth = depth * scale\n        # depth = np.squeeze(depth,2)\n\n        mask = (depth>=depth_min) & 
(depth<=depth_max)\n        assert mask.sum() > 0\n        mask = mask.astype(np.float32)\n        if self.img_wh is not None:\n            depth = cv2.resize(depth, self.img_wh,\n                                 interpolation=cv2.INTER_NEAREST)\n        h, w = depth.shape\n        depth_ms = {}\n        mask_ms = {}\n\n        for i in range(4):\n            depth_cur = cv2.resize(depth, (w//(2**i), h//(2**i)), interpolation=cv2.INTER_NEAREST)\n            mask_cur = cv2.resize(mask, (w//(2**i), h//(2**i)), interpolation=cv2.INTER_NEAREST)\n\n            depth_ms[f\"stage{4-i}\"] = depth_cur\n            mask_ms[f\"stage{4-i}\"] = mask_cur\n\n        return depth_ms, mask_ms\n\n\n    def read_img(self, filename):\n        img = Image.open(filename)\n        # img = self.color_augment(img)\n        # scale 0~255 to 0~1\n        np_img = np.array(img, dtype=np.float32) / 255.\n        return np_img\n\n    def __len__(self):\n        return len(self.metas)\n\n    def __getitem__(self, idx):\n        meta = self.metas[idx]\n        scan, ref_view, src_views = meta\n        \n        if self.robust_train:\n            num_src_views = len(src_views)\n            index = random.sample(range(num_src_views), self.nviews - 1)\n            view_ids = [ref_view] + [src_views[i] for i in index]\n            scale = random.uniform(0.8, 1.25)\n\n        else:\n            view_ids = [ref_view] + src_views[:self.nviews - 1]\n            scale = 1\n\n        imgs = []\n        mask = None\n        depth = None\n        depth_min = None\n        depth_max = None\n\n        proj={}\n        proj_matrices_0 = []\n        proj_matrices_1 = []\n        proj_matrices_2 = []\n        proj_matrices_3 = []\n\n\n        for i, vid in enumerate(view_ids):\n            img_filename = os.path.join(self.datapath, '{}/blended_images/{:0>8}.jpg'.format(scan, vid))\n            depth_filename = os.path.join(self.datapath, '{}/rendered_depth_maps/{:0>8}.pfm'.format(scan, vid))\n            proj_mat_filename = os.path.join(self.datapath, '{}/cams/{:0>8}_cam.txt'.format(scan, vid))\n\n            img = self.read_img(img_filename)\n            imgs.append(img.transpose(2,0,1))\n\n            intrinsics, extrinsics, depth_min_, depth_max_ = self.read_cam_file(scan, proj_mat_filename)\n            # proj_mat_filename = os.path.join(self.datapath, 'Cameras/train/{:0>8}_cam.txt').format(vid)\n\n\n            proj_mat_0 = np.zeros(shape=(2, 4, 4), dtype=np.float32)\n            proj_mat_1 = np.zeros(shape=(2, 4, 4), dtype=np.float32)\n            proj_mat_2 = np.zeros(shape=(2, 4, 4), dtype=np.float32)\n            proj_mat_3 = np.zeros(shape=(2, 4, 4), dtype=np.float32)\n            extrinsics[:3, 3] *= scale\n            intrinsics[:2,:] *= 0.125\n            proj_mat_0[0,:4,:4] = extrinsics.copy()\n            proj_mat_0[1,:3,:3] = intrinsics.copy()\n\n            intrinsics[:2,:] *= 2\n            proj_mat_1[0,:4,:4] = extrinsics.copy()\n            proj_mat_1[1,:3,:3] = intrinsics.copy()\n\n            intrinsics[:2,:] *= 2\n            proj_mat_2[0,:4,:4] = extrinsics.copy()\n            proj_mat_2[1,:3,:3] = intrinsics.copy()\n\n            intrinsics[:2,:] *= 2\n            proj_mat_3[0,:4,:4] = extrinsics.copy()\n            proj_mat_3[1,:3,:3] = intrinsics.copy()  \n\n            proj_matrices_0.append(proj_mat_0)\n            proj_matrices_1.append(proj_mat_1)\n            proj_matrices_2.append(proj_mat_2)\n            proj_matrices_3.append(proj_mat_3)\n\n            if i == 0:  # reference view\n                
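# only the reference view carries the ground-truth depth, mask and depth range\n                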
depth_min = depth_min_ * scale\n                depth_max = depth_max_ * scale\n                depth, mask = self.read_depth_mask(scan, depth_filename, depth_min, depth_max, scale)\n                # per-stage depths/masks are kept as 2D [H, W] arrays (no channel dim)\n\n        proj['stage1'] = np.stack(proj_matrices_0)\n        proj['stage2'] = np.stack(proj_matrices_1)\n        proj['stage3'] = np.stack(proj_matrices_2)\n        proj['stage4'] = np.stack(proj_matrices_3)\n\n        # check_invalid_input(imgs, depth, mask, depth_min, depth_max)\n        # all values are numpy arrays\n        return {\"imgs\": imgs,                   # list of Nv arrays, each [3, H, W]\n                \"proj_matrices\": proj,          # dict: stage1-4, each [Nv, 2, 4, 4]\n                \"depth\": depth,                 # dict: stage1-4 depth maps, each [H, W]\n                \"depth_values\": np.array([depth_min, depth_max], dtype=np.float32),\n                \"mask\": mask}                   # dict: stage1-4 masks, each [H, W]\n"
  },
  {
    "path": "datasets/data_io.py",
    "content": "import numpy as np\nimport re\nimport sys\n\n\ndef read_pfm(filename):\n    file = open(filename, 'rb')\n    color = None\n    width = None\n    height = None\n    scale = None\n    endian = None\n\n    header = file.readline().decode('utf-8').rstrip()\n    if header == 'PF':\n        color = True\n    elif header == 'Pf':\n        color = False\n    else:\n        raise Exception('Not a PFM file.')\n\n    dim_match = re.match(r'^(\\d+)\\s(\\d+)\\s$', file.readline().decode('utf-8'))\n    if dim_match:\n        width, height = map(int, dim_match.groups())\n    else:\n        raise Exception('Malformed PFM header.')\n\n    scale = float(file.readline().rstrip())\n    if scale < 0:  # little-endian\n        endian = '<'\n        scale = -scale\n    else:\n        endian = '>'  # big-endian\n\n    data = np.fromfile(file, endian + 'f')\n    shape = (height, width, 3) if color else (height, width)\n\n    data = np.reshape(data, shape)\n    data = np.flipud(data)\n    file.close()\n    return data, scale\n\n\ndef save_pfm(filename, image, scale=1):\n    file = open(filename, \"wb\")\n    color = None\n\n    image = np.flipud(image)\n\n    if image.dtype.name != 'float32':\n        raise Exception('Image dtype must be float32.')\n\n    if len(image.shape) == 3 and image.shape[2] == 3:  # color image\n        color = True\n    elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1:  # greyscale\n        color = False\n    else:\n        raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.')\n\n    file.write('PF\\n'.encode('utf-8') if color else 'Pf\\n'.encode('utf-8'))\n    file.write('{} {}\\n'.format(image.shape[1], image.shape[0]).encode('utf-8'))\n\n    endian = image.dtype.byteorder\n\n    if endian == '<' or endian == '=' and sys.byteorder == 'little':\n        scale = -scale\n\n    file.write(('%f\\n' % scale).encode('utf-8'))\n\n    image.tofile(file)\n    file.close()\n\n\nimport random, cv2\nclass RandomCrop(object):\n    def __init__(self, CropSize=0.1):\n        self.CropSize = CropSize\n\n    def __call__(self, image, normal):\n        h, w = normal.shape[:2]\n        img_h, img_w = image.shape[:2]\n        CropSize_w, CropSize_h = max(1, int(w * self.CropSize)), max(1, int(h * self.CropSize))\n        x1, y1 = random.randint(0, CropSize_w), random.randint(0, CropSize_h)\n        x2, y2 = random.randint(w - CropSize_w, w), random.randint(h - CropSize_h, h)\n\n        normal_crop = normal[y1:y2, x1:x2]\n        normal_resize = cv2.resize(normal_crop, (w, h), interpolation=cv2.INTER_NEAREST)\n\n        image_crop = image[4*y1:4*y2, 4*x1:4*x2]\n        image_resize = cv2.resize(image_crop, (img_w, img_h), interpolation=cv2.INTER_LINEAR)\n\n        # import matplotlib.pyplot as plt\n        # plt.subplot(2, 3, 1)\n        # plt.imshow(image)\n        # plt.subplot(2, 3, 2)\n        # plt.imshow(image_crop)\n        # plt.subplot(2, 3, 3)\n        # plt.imshow(image_resize)\n        #\n        # plt.subplot(2, 3, 4)\n        # plt.imshow((normal + 1.0) / 2, cmap=\"rainbow\")\n        # plt.subplot(2, 3, 5)\n        # plt.imshow((normal_crop + 1.0) / 2, cmap=\"rainbow\")\n        # plt.subplot(2, 3, 6)\n        # plt.imshow((normal_resize + 1.0) / 2, cmap=\"rainbow\")\n        # plt.show()\n        # plt.pause(1)\n        # plt.close()\n\n        return image_resize, normal_resize"
  },
  {
    "path": "datasets/dtu_yao4.py",
    "content": "from torch.utils.data import Dataset\nimport numpy as np\nimport os, cv2, time, math\nfrom PIL import Image\nfrom datasets.data_io import *\nfrom torchvision import transforms\n\n# the DTU dataset preprocessed by Yao Yao (only for training)\nclass MVSDataset(Dataset):\n    def __init__(self, datapath, listfile, mode, nviews, interval_scale=1.06, **kwargs):\n        super(MVSDataset, self).__init__()\n        self.datapath = datapath\n        self.listfile = listfile\n        self.mode = mode\n        self.nviews = nviews\n        self.ndepths = 192  # Hardcode\n        self.interval_scale = interval_scale\n        self.kwargs = kwargs\n        self.rt = kwargs.get(\"rt\", False)\n        self.use_raw_train = kwargs.get(\"use_raw_train\", False)\n        self.color_augment = transforms.ColorJitter(brightness=0.5, contrast=0.5)\n\n        assert self.mode in [\"train\", \"val\", \"test\"]\n        self.metas = self.build_list()\n\n    def build_list(self):\n        metas = []\n        with open(self.listfile) as f:\n            scans = f.readlines()\n            scans = [line.rstrip() for line in scans]\n\n        # scans\n        for scan in scans:\n            pair_file = \"Cameras/pair.txt\"\n            # read the pair file\n            with open(os.path.join(self.datapath, pair_file)) as f:\n                num_viewpoint = int(f.readline())\n                # viewpoints (49)\n                for view_idx in range(num_viewpoint):\n                    ref_view = int(f.readline().rstrip())\n                    src_views = [int(x) for x in f.readline().rstrip().split()[1::2]]\n                    # light conditions 0-6\n                    for light_idx in range(7):\n                        metas.append((scan, light_idx, ref_view, src_views))\n        # print(\"dataset\", self.mode, \"metas:\", len(metas))\n        return metas\n\n    def __len__(self):\n        return len(self.metas)\n\n    def read_cam_file(self, filename):\n        with open(filename) as f:\n            lines = f.readlines()\n            lines = [line.rstrip() for line in lines]\n        # extrinsics: line [1,5), 4x4 matrix\n        extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4))\n        # intrinsics: line [7-10), 3x3 matrix\n        intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3))\n        # depth_min & depth_interval: line 11\n        depth_min = float(lines[11].split()[0])\n        depth_interval = float(lines[11].split()[1]) * self.interval_scale\n        return intrinsics, extrinsics, depth_min, depth_interval\n\n    def read_img(self, filename):\n        img = Image.open(filename)\n        if self.mode == 'train':\n            img = self.color_augment(img)\n        # scale 0~255 to 0~1\n        np_img = np.array(img, dtype=np.float32) / 255.\n        return np_img\n\n    def crop_img(self, img):\n        raw_h, raw_w = img.shape[:2]\n        start_h = (raw_h-1024)//2\n        start_w = (raw_w-1280)//2\n        return img[start_h:start_h+1024, start_w:start_w+1280, :]  # 1024, 1280, C\n\n    def prepare_img(self, hr_img):\n        h, w = hr_img.shape\n        if not self.use_raw_train:\n            #w1600-h1200-> 800-600 ; crop -> 640, 512; downsample 1/4 -> 160, 128\n            #downsample\n            hr_img_ds = cv2.resize(hr_img, (w//2, h//2), interpolation=cv2.INTER_NEAREST)\n            h, w = hr_img_ds.shape\n            target_h, target_w = 512, 640\n            start_h, start_w = (h - target_h)//2, (w - 
target_w)//2\n            hr_img_crop = hr_img_ds[start_h: start_h + target_h, start_w: start_w + target_w]\n        elif self.use_raw_train:\n            hr_img_crop = hr_img[h//2-1024//2:h//2+1024//2, w//2-1280//2:w//2+1280//2]  # 1024, 1280, c\n        return hr_img_crop\n\n    def read_mask_hr(self, filename):\n        img = Image.open(filename)\n        np_img = np.array(img, dtype=np.float32)\n        np_img = (np_img > 10).astype(np.float32)\n        np_img = self.prepare_img(np_img)\n\n        h, w = np_img.shape\n        np_img_ms = {\n            \"stage1\": cv2.resize(np_img, (w//8, h//8), interpolation=cv2.INTER_NEAREST),\n            \"stage2\": cv2.resize(np_img, (w//4, h//4), interpolation=cv2.INTER_NEAREST),\n            \"stage3\": cv2.resize(np_img, (w//2, h//2), interpolation=cv2.INTER_NEAREST),\n            \"stage4\": np_img,\n        }\n        return np_img_ms\n\n\n    def read_depth_hr(self, filename, scale):\n        # read pfm depth file\n        #w1600-h1200-> 800-600 ; crop -> 640, 512; downsample 1/4 -> 160, 128\n        depth_hr = np.array(read_pfm(filename)[0], dtype=np.float32) * scale\n        depth_lr = self.prepare_img(depth_hr)\n\n        h, w = depth_lr.shape\n        depth_lr_ms = {\n            \"stage1\": cv2.resize(depth_lr, (w//8, h//8), interpolation=cv2.INTER_NEAREST),\n            \"stage2\": cv2.resize(depth_lr, (w//4, h//4), interpolation=cv2.INTER_NEAREST),\n            \"stage3\": cv2.resize(depth_lr, (w//2, h//2), interpolation=cv2.INTER_NEAREST),\n            \"stage4\": depth_lr,\n        }\n        return depth_lr_ms\n\n    def __getitem__(self, idx):\n        meta = self.metas[idx]\n        scan, light_idx, ref_view, src_views = meta\n        # use only the reference view and first nviews-1 source views\n\n        if self.mode == 'train' and self.rt:\n            num_src_views = len(src_views)\n            index = random.sample(range(num_src_views), self.nviews - 1)\n            view_ids = [ref_view] + [src_views[i] for i in index]\n            scale = random.uniform(0.8, 1.25)\n        else:\n            view_ids = [ref_view] + src_views[:self.nviews - 1]\n            scale = 1\n        imgs = []\n        mask = None\n        depth_values = None\n        proj_matrices = []\n        for i, vid in enumerate(view_ids):\n            # NOTE that the id in image file names is from 1 to 49 (not 0~48)\n            if not self.use_raw_train:\n                img_filename = os.path.join(self.datapath, 'Rectified/{}_train/rect_{:0>3}_{}_r5000.png'.format(scan, vid + 1, light_idx))\n            else:\n                img_filename = os.path.join(self.datapath, 'Rectified_raw/{}/rect_{:0>3}_{}_r5000.png'.format(scan, vid + 1, light_idx))\n            mask_filename_hr = os.path.join(self.datapath, 'Depths_raw/{}/depth_visual_{:0>4}.png'.format(scan, vid))\n            depth_filename_hr = os.path.join(self.datapath, 'Depths_raw/{}/depth_map_{:0>4}.pfm'.format(scan, vid))\n            proj_mat_filename = os.path.join(self.datapath, 'Cameras/train/{:0>8}_cam.txt').format(vid)\n            img = self.read_img(img_filename)\n            if self.use_raw_train:\n                img = self.crop_img(img)\n            intrinsics, extrinsics, depth_min, depth_interval = self.read_cam_file(proj_mat_filename)\n            if self.rt:\n                extrinsics[:3,3] *= scale\n            if self.use_raw_train:\n                intrinsics[:2, :] *= 2.0\n\n            if i == 0:\n\n                mask_read_ms = self.read_mask_hr(mask_filename_hr)\n                
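# ground-truth depth is scaled together with the extrinsics when robust training is on\n                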
depth_ms = self.read_depth_hr(depth_filename_hr, scale)\n                #get depth values\n                depth_max = depth_interval * self.ndepths + depth_min\n                depth_values = np.array([depth_min * scale, depth_max * scale], dtype=np.float32)\n                mask = mask_read_ms\n\n            proj_mat = np.zeros(shape=(2, 4, 4), dtype=np.float32)  #\n            proj_mat[0, :4, :4] = extrinsics\n            proj_mat[1, :3, :3] = intrinsics\n            proj_matrices.append(proj_mat)\n            imgs.append(img.transpose(2,0,1))\n\n        #all\n        # imgs = np.stack(imgs).transpose([0, 3, 1, 2])\n        #ms proj_mats\n        proj_matrices = np.stack(proj_matrices)\n        stage1_pjmats = proj_matrices.copy()\n        stage1_pjmats[:, 1, :2, :] = proj_matrices[:, 1, :2, :] / 2.0\n        stage3_pjmats = proj_matrices.copy()\n        stage3_pjmats[:, 1, :2, :] = proj_matrices[:, 1, :2, :] * 2\n        stage4_pjmats = proj_matrices.copy()\n        stage4_pjmats[:, 1, :2, :] = proj_matrices[:, 1, :2, :] * 4\n\n        proj_matrices_ms = {\n            \"stage1\": stage1_pjmats,\n            \"stage2\": proj_matrices,\n            \"stage3\": stage3_pjmats,\n            \"stage4\": stage4_pjmats\n        }\n\n        return {\"imgs\": imgs,  # Nv C H W\n                \"proj_matrices\": proj_matrices_ms,  # 4 stage of Nv 2 4 4\n                \"depth\": depth_ms,\n                \"depth_values\": depth_values,\n                \"mask\": mask }"
  },
  {
    "path": "datasets/eth3d.py",
    "content": "from torch.utils.data import Dataset\nfrom datasets.data_io import *\nimport os\nimport numpy as np\nimport cv2\nfrom PIL import Image\n\nclass MVSDataset(Dataset):\n    def __init__(self, datapath, split='test', n_views=7, img_wh=(1920,1280)):\n        self.levels = 4\n        self.datapath = datapath\n        self.img_wh = img_wh\n        self.split = split\n        self.build_metas()\n        self.n_views = n_views\n\n    def build_metas(self):\n        self.metas = []\n        if self.split == \"test\":\n            self.scans = ['botanical_garden', 'boulders', 'bridge', 'door',\n                'exhibition_hall', 'lecture_room', 'living_room', 'lounge',\n                'observatory', 'old_computer', 'statue', 'terrace_2']\n\n        elif self.split == \"train\":\n            self.scans = ['courtyard', 'delivery_area', 'electro', 'facade',\n                    'kicker', 'meadow', 'office', 'pipes', 'playground',\n                    'relief', 'relief_2', 'terrace', 'terrains']\n        \n\n        for scan in self.scans:\n            with open(os.path.join(self.datapath, scan, 'pair.txt')) as f:\n                num_viewpoint = int(f.readline())\n                for view_idx in range(num_viewpoint):\n                    ref_view = int(f.readline().rstrip())\n                    src_views = [int(x) for x in f.readline().rstrip().split()[1::2]]\n                    if len(src_views) != 0:\n                        self.metas += [(scan, -1, ref_view, src_views)]\n                    \n\n    def read_cam_file(self, filename):\n        with open(filename) as f:\n            lines = [line.rstrip() for line in f.readlines()]\n        # extrinsics: line [1,5), 4x4 matrix\n        extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ')\n        extrinsics = extrinsics.reshape((4, 4))\n        # intrinsics: line [7-10), 3x3 matrix\n        intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ')\n        intrinsics = intrinsics.reshape((3, 3))\n        \n        depth_min = float(lines[11].split()[0])\n        if depth_min < 0:\n            depth_min = 1\n        depth_max = float(lines[11].split()[-1])\n\n        return intrinsics, extrinsics, depth_min, depth_max\n\n    def read_img(self, filename):\n        img = Image.open(filename)\n        np_img = np.array(img, dtype=np.float32) / 255.\n        original_h, original_w, _ = np_img.shape\n        np_img = cv2.resize(np_img, self.img_wh, interpolation=cv2.INTER_LINEAR)\n        return np_img, original_h, original_w\n\n    def __len__(self):\n        return len(self.metas)\n\n    def __getitem__(self, idx):\n        scan, _, ref_view, src_views = self.metas[idx]\n        # use only the reference view and first nviews-1 source views\n        view_ids = [ref_view] + src_views[:self.n_views-1]\n        imgs = []\n\n        # depth = None\n        depth_min = None\n        depth_max = None\n\n        proj_matrices_0 = []\n        proj_matrices_1 = []\n        proj_matrices_2 = []\n        proj_matrices_3 = []\n\n        for i, vid in enumerate(view_ids):\n            img_filename = os.path.join(self.datapath,  scan, f'images/{vid:08d}.jpg')\n            proj_mat_filename = os.path.join(self.datapath, scan, f'cams_1/{vid:08d}_cam.txt')\n\n            img, original_h, original_w = self.read_img(img_filename)\n\n            intrinsics, extrinsics, depth_min_, depth_max_ = self.read_cam_file(proj_mat_filename)\n            intrinsics[0] *= self.img_wh[0]/original_w\n            intrinsics[1] *= 
self.img_wh[1]/original_h\n            imgs.append(img.transpose(2,0,1))\n\n            proj_mat_0 = np.zeros(shape=(2, 4, 4), dtype=np.float32)\n            proj_mat_1 = np.zeros(shape=(2, 4, 4), dtype=np.float32)\n            proj_mat_2 = np.zeros(shape=(2, 4, 4), dtype=np.float32)\n            proj_mat_3 = np.zeros(shape=(2, 4, 4), dtype=np.float32)\n\n            intrinsics[:2,:] *= 0.125\n            proj_mat_0[0,:4,:4] = extrinsics.copy()\n            proj_mat_0[1,:3,:3] = intrinsics.copy()\n\n            intrinsics[:2,:] *= 2\n            proj_mat_1[0,:4,:4] = extrinsics.copy()\n            proj_mat_1[1,:3,:3] = intrinsics.copy()\n\n            intrinsics[:2,:] *= 2\n            proj_mat_2[0,:4,:4] = extrinsics.copy()\n            proj_mat_2[1,:3,:3] = intrinsics.copy()\n\n            intrinsics[:2,:] *= 2\n            proj_mat_3[0,:4,:4] = extrinsics.copy()\n            proj_mat_3[1,:3,:3] = intrinsics.copy()  \n\n            proj_matrices_0.append(proj_mat_0)\n            proj_matrices_1.append(proj_mat_1)\n            proj_matrices_2.append(proj_mat_2)\n            proj_matrices_3.append(proj_mat_3)\n\n            if i == 0:  # reference view\n                depth_min =  depth_min_\n                depth_max = depth_max_\n\n        # proj_matrices: N*4*4\n        proj={}\n        proj['stage1'] = np.stack(proj_matrices_0)\n        proj['stage2'] = np.stack(proj_matrices_1)\n        proj['stage3'] = np.stack(proj_matrices_2)\n        proj['stage4'] = np.stack(proj_matrices_3)\n\n\n        return {\"imgs\": imgs,                   # N*3*H0*W0\n                \"proj_matrices\": proj, # N*4*4\n                \"depth_values\": np.array([depth_min, depth_max], dtype=np.float32),\n                \"filename\": scan + '/{}/' + '{:0>8}'.format(view_ids[0]) + \"{}\"\n                }  \n"
  },
  {
    "path": "datasets/general_eval4.py",
    "content": "from torch.utils.data import Dataset\nimport numpy as np\nimport os, cv2, time\nfrom PIL import Image\nfrom datasets.data_io import *\n\ns_h, s_w = 0, 0\nclass MVSDataset(Dataset):\n    def __init__(self, datapath, listfile, mode, nviews, interval_scale=1.06, **kwargs):\n        super(MVSDataset, self).__init__()\n        self.datapath = datapath\n        self.listfile = listfile\n        self.mode = mode\n        self.nviews = nviews\n        self.ndepths = 192  # Hardcode\n        self.interval_scale = interval_scale\n        self.max_h, self.max_w = kwargs[\"max_h\"], kwargs[\"max_w\"]\n        self.fix_res = kwargs.get(\"fix_res\", False)  #whether to fix the resolution of input image.\n        self.fix_wh = False\n\n        assert self.mode == \"test\"\n        self.metas = self.build_list()\n\n    def build_list(self):\n        metas = []\n        scans = self.listfile\n\n        interval_scale_dict = {}\n        # scans\n        for scan in scans:\n            # determine the interval scale of each scene. default is 1.06\n            if isinstance(self.interval_scale, float):\n                interval_scale_dict[scan] = self.interval_scale\n            else:\n                interval_scale_dict[scan] = self.interval_scale[scan]\n\n            pair_file = \"{}/pair.txt\".format(scan)\n            # read the pair file\n            with open(os.path.join(self.datapath, pair_file)) as f:\n                num_viewpoint = int(f.readline())\n                # viewpoints\n                for view_idx in range(num_viewpoint):\n                    ref_view = int(f.readline().rstrip())\n                    src_views = [int(x) for x in f.readline().rstrip().split()[1::2]]\n                    # filter by no src view and fill to nviews\n                    if len(src_views) > 0:\n                        if len(src_views) < self.nviews:\n                            print(\"{}< num_views:{}\".format(len(src_views), self.nviews))\n                            src_views += [src_views[0]] * (self.nviews - len(src_views))\n                        metas.append((scan, ref_view, src_views, scan))\n\n        self.interval_scale = interval_scale_dict\n        print(\"dataset\", self.mode, \"metas:\", len(metas), \"interval_scale:{}\".format(self.interval_scale))\n        return metas\n\n    def __len__(self):\n        return len(self.metas)\n\n    def read_cam_file(self, filename, interval_scale):\n        with open(filename) as f:\n            lines = f.readlines()\n            lines = [line.rstrip() for line in lines]\n        # extrinsics: line [1,5), 4x4 matrix\n        extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4))\n        # intrinsics: line [7-10), 3x3 matrix\n        intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3))\n        intrinsics[:2, :] /= 4.0\n        # depth_min & depth_interval: line 11\n        depth_min = float(lines[11].split()[0])\n        depth_interval = float(lines[11].split()[1])\n\n        if len(lines[11].split()) >= 3:\n            num_depth = lines[11].split()[2]\n            depth_max = depth_min + int(float(num_depth)) * depth_interval\n            depth_interval = (depth_max - depth_min) / self.ndepths\n\n        depth_interval *= interval_scale\n\n        return intrinsics, extrinsics, depth_min, depth_interval\n\n    def read_img(self, filename):\n        img = Image.open(filename)\n        # scale 0~255 to 0~1\n        np_img = np.array(img, dtype=np.float32) / 255.\n\n  
      return np_img\n\n    def read_depth(self, filename):\n        # read pfm depth file\n        return np.array(read_pfm(filename)[0], dtype=np.float32)\n\n    def scale_mvs_input(self, img, intrinsics, max_w, max_h, base=64):\n        h, w = img.shape[:2]\n        if h > max_h or w > max_w:\n            scale = 1.0 * max_h / h\n            if scale * w > max_w:\n                scale = 1.0 * max_w / w\n            new_w, new_h = scale * w // base * base, scale * h // base * base\n        else:\n            new_w, new_h = 1.0 * w // base * base, 1.0 * h // base * base\n\n        scale_w = 1.0 * new_w / w\n        scale_h = 1.0 * new_h / h\n        intrinsics[0, :] *= scale_w\n        intrinsics[1, :] *= scale_h\n\n        img = cv2.resize(img, (int(new_w), int(new_h)))\n\n        return img, intrinsics\n\n    def __getitem__(self, idx):\n        global s_h, s_w\n        meta = self.metas[idx]\n        scan, ref_view, src_views, scene_name = meta\n        # use only the reference view and first nviews-1 source views\n        view_ids = [ref_view] + src_views[:self.nviews - 1]\n\n        imgs = []\n        depth_values = None\n        proj_matrices = []\n\n        for i, vid in enumerate(view_ids):\n            img_filename = os.path.join(self.datapath, '{}/images_post/{:0>8}.jpg'.format(scan, vid))\n            if not os.path.exists(img_filename):\n                img_filename = os.path.join(self.datapath, '{}/images/{:0>8}.jpg'.format(scan, vid))\n\n            proj_mat_filename = os.path.join(self.datapath, '{}/cams/{:0>8}_cam.txt'.format(scan, vid))\n\n            img = self.read_img(img_filename)\n            intrinsics, extrinsics, depth_min, depth_interval = self.read_cam_file(proj_mat_filename, interval_scale=\n                                                                                   self.interval_scale[scene_name])\n            # scale input\n            img, intrinsics = self.scale_mvs_input(img, intrinsics, self.max_w, self.max_h)\n\n            if self.fix_res:\n                # using the same standard height or width in entire scene.\n                s_h, s_w = img.shape[:2]\n                self.fix_res = False\n                self.fix_wh = True\n\n            if i == 0:\n                if not self.fix_wh:\n                    # using the same standard height or width in each nviews.\n                    s_h, s_w = img.shape[:2]\n\n            # resize to standard height or width\n            c_h, c_w = img.shape[:2]\n            if (c_h != s_h) or (c_w != s_w):\n                scale_h = 1.0 * s_h / c_h\n                scale_w = 1.0 * s_w / c_w\n                img = cv2.resize(img, (s_w, s_h))\n                intrinsics[0, :] *= scale_w\n                intrinsics[1, :] *= scale_h\n\n\n            imgs.append(img.transpose(2,0,1))\n            # extrinsics, intrinsics\n            proj_mat = np.zeros(shape=(2, 4, 4), dtype=np.float32)  #\n            proj_mat[0, :4, :4] = extrinsics\n            proj_mat[1, :3, :3] = intrinsics\n            proj_matrices.append(proj_mat)\n\n            if i == 0:  # reference view\n                depth_values = np.arange(depth_min, depth_interval * (self.ndepths - 0.5) + depth_min, depth_interval,\n                                         dtype=np.float32)\n\n        #all\n        # imgs = np.stack(imgs).transpose([0, 3, 1, 2])\n        #ms proj_mats\n        proj_matrices = np.stack(proj_matrices)\n        stage1_pjmats = proj_matrices.copy()\n        stage1_pjmats[:, 1, :2, :] = proj_matrices[:, 1, :2, :] / 2.0\n        
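# the intrinsics were divided by 4 in read_cam_file, so stage2 is the 1/4-resolution base;\n        # stage3 and stage4 below scale them back up for the finer cascade stages\n        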
stage3_pjmats = proj_matrices.copy()\n        stage3_pjmats[:, 1, :2, :] = proj_matrices[:, 1, :2, :] * 2\n        stage4_pjmats = proj_matrices.copy()\n        stage4_pjmats[:, 1, :2, :] = proj_matrices[:, 1, :2, :] * 4\n\n        proj_matrices_ms = {\n            \"stage1\": stage1_pjmats,\n            \"stage2\": proj_matrices,\n            \"stage3\": stage3_pjmats,\n            \"stage4\": stage4_pjmats\n        }\n\n        return {\"imgs\": imgs,\n                \"proj_matrices\": proj_matrices_ms,\n                \"depth_values\": depth_values,\n                \"filename\": scan + '/{}/' + '{:0>8}'.format(view_ids[0]) + \"{}\"}\n"
  },
  {
    "path": "datasets/tanks.py",
    "content": "from torch.utils.data import Dataset\nfrom datasets.data_io import *\nimport os\nimport numpy as np\nimport cv2\nfrom PIL import Image\n\nclass MVSDataset(Dataset):\n    def __init__(self, datapath, n_views=7, split='intermediate'):\n        self.levels = 4\n        self.datapath = datapath\n        self.split = split\n        self.build_metas()\n        self.n_views = n_views\n\n    def build_metas(self):\n        self.metas = []\n        if self.split == 'intermediate':\n            self.scans = ['Family', 'Francis', 'Horse', 'Playground', 'Train', 'Lighthouse', 'M60', 'Panther']\n        elif self.split == 'advanced':\n            self.scans = ['Auditorium', 'Ballroom', 'Courtroom',\n                          'Museum', 'Palace', 'Temple']\n\n        for scan in self.scans:\n            with open(os.path.join(self.datapath, self.split, scan, 'pair.txt')) as f:\n                num_viewpoint = int(f.readline())\n                for view_idx in range(num_viewpoint):\n                    ref_view = int(f.readline().rstrip())\n                    src_views = [int(x) for x in f.readline().rstrip().split()[1::2]]\n                    if len(src_views) != 0:\n                        self.metas += [(scan, -1, ref_view, src_views)]\n   \n    def read_cam_file(self, filename):\n        with open(filename) as f:\n            lines = [line.rstrip() for line in f.readlines()]\n        # extrinsics: line [1,5), 4x4 matrix\n        extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ')\n        extrinsics = extrinsics.reshape((4, 4))\n        # intrinsics: line [7-10), 3x3 matrix\n        intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ')\n        intrinsics = intrinsics.reshape((3, 3))\n        \n        depth_min = float(lines[11].split()[0])\n        depth_max = float(lines[11].split()[-1])\n\n        return intrinsics, extrinsics, depth_min, depth_max\n\n    def read_img(self, filename):\n        img = Image.open(filename)\n        np_img = np.array(img, dtype=np.float32) / 255.\n        return np_img\n\n    def scale_input(self, intrinsics, img):\n        \"\"\"\n        intrinsics: 3x3\n        img: W H C\n        \"\"\"\n        intrinsics[1,2] =  intrinsics[1,2] - 28  # 1080 -> 1024\n        img = img[28:1080-28, :, :]\n        return intrinsics, img\n\n    def __len__(self):\n        return len(self.metas)\n\n    def __getitem__(self, idx):\n        scan, _, ref_view, src_views = self.metas[idx]\n        # use only the reference view and first nviews-1 source views\n        view_ids = [ref_view] + src_views[:self.n_views-1]\n\n        imgs = []\n\n        # depth = None\n        depth_min = None\n        depth_max = None\n\n        proj_matrices_0 = []\n        proj_matrices_1 = []\n        proj_matrices_2 = []\n        proj_matrices_3 = []\n\n        for i, vid in enumerate(view_ids):\n            img_filename = os.path.join(self.datapath, self.split, scan, f'images/{vid:08d}.jpg')\n            proj_mat_filename = os.path.join(self.datapath, self.split, scan, f'cams/{vid:08d}_cam.txt')\n\n            img = self.read_img(img_filename)\n\n            intrinsics, extrinsics, depth_min_, depth_max_ = self.read_cam_file(proj_mat_filename)\n            intrinsics, img = self.scale_input(intrinsics, img)\n            imgs.append(img.transpose(2,0,1))\n\n            proj_mat_0 = np.zeros(shape=(2, 4, 4), dtype=np.float32)\n            proj_mat_1 = np.zeros(shape=(2, 4, 4), dtype=np.float32)\n            proj_mat_2 = np.zeros(shape=(2, 4, 
4), dtype=np.float32)\n            proj_mat_3 = np.zeros(shape=(2, 4, 4), dtype=np.float32)\n\n            intrinsics[:2,:] *= 0.125\n            proj_mat_0[0,:4,:4] = extrinsics.copy()\n            proj_mat_0[1,:3,:3] = intrinsics.copy()\n\n            intrinsics[:2,:] *= 2\n            proj_mat_1[0,:4,:4] = extrinsics.copy()\n            proj_mat_1[1,:3,:3] = intrinsics.copy()\n\n            intrinsics[:2,:] *= 2\n            proj_mat_2[0,:4,:4] = extrinsics.copy()\n            proj_mat_2[1,:3,:3] = intrinsics.copy()\n\n            intrinsics[:2,:] *= 2\n            proj_mat_3[0,:4,:4] = extrinsics.copy()\n            proj_mat_3[1,:3,:3] = intrinsics.copy()  \n\n            proj_matrices_0.append(proj_mat_0)\n            proj_matrices_1.append(proj_mat_1)\n            proj_matrices_2.append(proj_mat_2)\n            proj_matrices_3.append(proj_mat_3)\n\n            if i == 0:  # reference view\n                depth_min =  depth_min_\n                depth_max = depth_max_\n\n\n        # proj_matrices: N*4*4\n        proj={}\n        proj['stage1'] = np.stack(proj_matrices_0)\n        proj['stage2'] = np.stack(proj_matrices_1)\n        proj['stage3'] = np.stack(proj_matrices_2)\n        proj['stage4'] = np.stack(proj_matrices_3)\n\n        return {\"imgs\": imgs,                   # N*3*H0*W0\n                \"proj_matrices\": proj, # N*4*4\n                \"depth_values\": np.array([depth_min, depth_max], dtype=np.float32),\n                \"filename\": scan + '/{}/' + '{:0>8}'.format(view_ids[0]) + \"{}\"\n                }  \n"
  },
  {
    "path": "evaluations/dtu/BaseEval2Obj_web.m",
    "content": "function BaseEval2Obj_web(BaseEval,method_string,outputPath)\r\n\r\nif(nargin<3)\r\n    outputPath='./';\r\nend\r\n\r\n% tresshold for coloring alpha channel in the range of 0-10 mm\r\ndist_tresshold=10;\r\n\r\ncSet=BaseEval.cSet;\r\n\r\nQdata=BaseEval.Qdata;\r\nalpha=min(BaseEval.Ddata,dist_tresshold)/dist_tresshold;\r\n\r\nfid=fopen([outputPath method_string '2Stl_' num2str(cSet) ' .obj'],'w+');\r\n\r\nfor cP=1:size(Qdata,2)\r\n    if(BaseEval.DataInMask(cP))\r\n        C=[1 0 0]*alpha(cP)+[1 1 1]*(1-alpha(cP)); %coloring from red to white in the range of 0-10 mm (0 to dist_tresshold)\r\n    else\r\n        C=[0 1 0]*alpha(cP)+[0 0 1]*(1-alpha(cP)); %green to blue for points outside the mask (which are not included in the analysis)\r\n    end\r\n    fprintf(fid,'v %f %f %f %f %f %f\\n',[Qdata(1,cP) Qdata(2,cP) Qdata(3,cP) C(1) C(2) C(3)]);\r\nend\r\nfclose(fid);\r\n\r\ndisp('Data2Stl saved as obj')\r\n\r\nQstl=BaseEval.Qstl;\r\nfid=fopen([outputPath 'Stl2' method_string '_' num2str(cSet) '.obj'],'w+');\r\n\r\nalpha=min(BaseEval.Dstl,dist_tresshold)/dist_tresshold;\r\n\r\nfor cP=1:size(Qstl,2)\r\n    if(BaseEval.StlAbovePlane(cP))\r\n        C=[1 0 0]*alpha(cP)+[1 1 1]*(1-alpha(cP)); %coloring from red to white in the range of 0-10 mm (0 to dist_tresshold)\r\n    else\r\n        C=[0 1 0]*alpha(cP)+[0 0 1]*(1-alpha(cP)); %green to blue for points below plane (which are not included in the analysis)\r\n    end\r\n    fprintf(fid,'v %f %f %f %f %f %f\\n',[Qstl(1,cP) Qstl(2,cP) Qstl(3,cP) C(1) C(2) C(3)]);\r\nend\r\nfclose(fid);\r\n\r\ndisp('Stl2Data saved as obj')"
  },
  {
    "path": "evaluations/dtu/BaseEvalMain_func.m",
    "content": "function None = BaseEvalMain_func(plyPath)\r\n\r\n% clear all\r\n% close all\r\nformat compact\r\n\r\n% script to calculate distances have been measured for all included scans (UsedSets)\r\n\r\ndataPath='/mnt/cfs/algorithm/public_data/mvs/dtu_evalset/SampleSet/MVS Data';\r\n% pred_results='cascade_hr/48-32-8_4-2-1_dlossw-0.5-1.0-2.0_chs888/gipuma_4_0.9/';\r\n% plyPath=['../../outputs/1101/dtu/' pred_results];\r\n\r\n% plyPath = '/mnt/cfs/algorithm/xiaofeng.wang/jeff/code/MVS/cascade-stereo/CasMVSNet/outputs/repo_model_aligncorners_ITGT'\r\nresultsPath=[plyPath '/eval_out/'];\r\ndisp(resultsPath);\r\nmkdir(resultsPath);\r\n\r\nmethod_string='mvsnet';\r\nlight_string='l3'; % l3 is the setting with all lights on, l7 is randomly sampled between the 7 settings (index 0-6)\r\nrepresentation_string='Points'; %mvs representation 'Points' or 'Surfaces'\r\n\r\nswitch representation_string\r\n    case 'Points'\r\n        eval_string='_Eval_'; %results naming\r\n        settings_string='';\r\nend\r\n\r\n% get sets used in evaluation\r\nUsedSets=[1 4 9 10 11 12 13 15 23 24 29 32 33 34 48 49 62 75 77 110 114 118];\r\n% UsedSets=[15];\r\n\r\ndst=0.2;    %Min dist between points when reducing\r\n\r\nparfor cIdx=1:length(UsedSets)\r\n    %Data set number\r\n    cSet = UsedSets(cIdx)\r\n    %input data name\r\n    DataInName=[plyPath sprintf('/%s%03d_%s%s.ply',lower(method_string),cSet,light_string,settings_string)]\r\n\r\n    %results name\r\n    EvalName=[resultsPath method_string eval_string num2str(cSet) '.mat']\r\n\r\n    %check if file is already computed\r\n    if(~exist(EvalName,'file'))\r\n        disp(DataInName);\r\n\r\n        time=clock;time(4:5), drawnow\r\n\r\n        tic\r\n        Mesh = plyread(DataInName);\r\n        Qdata=[Mesh.vertex.x Mesh.vertex.y Mesh.vertex.z]';\r\n        toc\r\n\r\n        BaseEval=PointCompareMain(cSet,Qdata,dst,dataPath);\r\n\r\n        disp('Saving results'), drawnow\r\n        toc\r\n        mySave(EvalName, BaseEval);\r\n        toc\r\n\r\n        % write obj-file of evaluation\r\n        % BaseEval2Obj_web(BaseEval,method_string, resultsPath)\r\n        % toc\r\n        time=clock;time(4:5), drawnow\r\n\r\n        BaseEval.MaxDist=20; %outlier threshold of 20 mm\r\n\r\n        BaseEval.FilteredDstl=BaseEval.Dstl(BaseEval.StlAbovePlane); %use only points that are above the plane\r\n        BaseEval.FilteredDstl=BaseEval.FilteredDstl(BaseEval.FilteredDstl<BaseEval.MaxDist); % discard outliers\r\n\r\n        BaseEval.FilteredDdata=BaseEval.Ddata(BaseEval.DataInMask); %use only points that within mask\r\n        BaseEval.FilteredDdata=BaseEval.FilteredDdata(BaseEval.FilteredDdata<BaseEval.MaxDist); % discard outliers\r\n\r\n        fprintf(\"mean/median Data (acc.) %f/%f\\n\", mean(BaseEval.FilteredDdata), median(BaseEval.FilteredDdata));\r\n        fprintf(\"mean/median Stl (comp.) %f/%f\\n\", mean(BaseEval.FilteredDstl), median(BaseEval.FilteredDstl));\r\n    end\r\nend\r\n\r\nend\r\n\r\nfunction mySave(filenm, data)\r\n    save(filenm, 'data');\r\nend"
  },
  {
    "path": "evaluations/dtu/BaseEvalMain_web.m",
    "content": "clear all\r\nclose all\r\nformat compact\r\nclc\r\n\r\n% script to calculate distances have been measured for all included scans (UsedSets)\r\n\r\ndataPath='/mnt/cfs/algorithm/public_data/mvs/dtu_evalset/SampleSet/MVS Data';\r\n% pred_results='cascade_hr/48-32-8_4-2-1_dlossw-0.5-1.0-2.0_chs888/gipuma_4_0.9/';\r\n% plyPath=['../../outputs/1101/dtu/' pred_results];\r\n\r\nplyPath = '/mnt/cfs/algorithm/xiaofeng.wang/jeff/code/MVS/cascade-stereo/CasMVSNet/outputs/ccc_4x2_scedule_aligncorners'\r\nresultsPath=[plyPath '/eval_out/'];\r\ndisp(resultsPath);\r\nmkdir(resultsPath);\r\n\r\nmethod_string='mvsnet';\r\nlight_string='l3'; % l3 is the setting with all lights on, l7 is randomly sampled between the 7 settings (index 0-6)\r\nrepresentation_string='Points'; %mvs representation 'Points' or 'Surfaces'\r\n\r\nswitch representation_string\r\n    case 'Points'\r\n        eval_string='_Eval_'; %results naming\r\n        settings_string='';\r\nend\r\n\r\n% get sets used in evaluation\r\nUsedSets=[1 4 9 10 11 12 13 15 23 24 29 32 33 34 48 49 62 75 77 110 114 118];\r\n% UsedSets=[15];\r\n\r\ndst=0.2;    %Min dist between points when reducing\r\n\r\nparfor cIdx=1:length(UsedSets)\r\n    %Data set number\r\n    cSet = UsedSets(cIdx)\r\n    %input data name\r\n    DataInName=[plyPath sprintf('/%s%03d_%s%s.ply',lower(method_string),cSet,light_string,settings_string)]\r\n\r\n    %results name\r\n    EvalName=[resultsPath method_string eval_string num2str(cSet) '.mat']\r\n\r\n    %check if file is already computed\r\n    if(~exist(EvalName,'file'))\r\n        disp(DataInName);\r\n\r\n        time=clock;time(4:5), drawnow\r\n\r\n        tic\r\n        Mesh = plyread(DataInName);\r\n        Qdata=[Mesh.vertex.x Mesh.vertex.y Mesh.vertex.z]';\r\n        toc\r\n\r\n        BaseEval=PointCompareMain(cSet,Qdata,dst,dataPath);\r\n\r\n        disp('Saving results'), drawnow\r\n        toc\r\n        mySave(EvalName, BaseEval);\r\n        toc\r\n\r\n        % write obj-file of evaluation\r\n        % BaseEval2Obj_web(BaseEval,method_string, resultsPath)\r\n        % toc\r\n        time=clock;time(4:5), drawnow\r\n\r\n        BaseEval.MaxDist=20; %outlier threshold of 20 mm\r\n\r\n        BaseEval.FilteredDstl=BaseEval.Dstl(BaseEval.StlAbovePlane); %use only points that are above the plane\r\n        BaseEval.FilteredDstl=BaseEval.FilteredDstl(BaseEval.FilteredDstl<BaseEval.MaxDist); % discard outliers\r\n\r\n        BaseEval.FilteredDdata=BaseEval.Ddata(BaseEval.DataInMask); %use only points that within mask\r\n        BaseEval.FilteredDdata=BaseEval.FilteredDdata(BaseEval.FilteredDdata<BaseEval.MaxDist); % discard outliers\r\n\r\n        fprintf(\"mean/median Data (acc.) %f/%f\\n\", mean(BaseEval.FilteredDdata), median(BaseEval.FilteredDdata));\r\n        fprintf(\"mean/median Stl (comp.) %f/%f\\n\", mean(BaseEval.FilteredDstl), median(BaseEval.FilteredDstl));\r\n    end\r\nend\r\n\r\n\r\nfunction mySave(filenm, data)\r\n    save(filenm, 'data');\r\nend"
  },
  {
    "path": "evaluations/dtu/ComputeStat_func.m",
    "content": "function None = ComputeStat_func(plyPath)\r\nformat compact\r\n\r\n% script to calculate the statistics for each scan given this will currently only run if distances have been measured\r\n% for all included scans (UsedSets)\r\n\r\n% modify the path to evaluate your models\r\ndataPath='/mnt/cfs/algorithm/public_data/mvs/dtu_evalset/SampleSet/MVS Data';\r\n% resultsPath=['/mnt/cfs/algorithm/xiaofeng.wang/jeff/code/MVS/cascade-stereo/CasMVSNet/outputs/repo_model_aligncorners_ITGT/eval_out/'];\r\nresultsPath=[plyPath '/eval_out/'];\r\n\r\nMaxDist=20; %outlier thresshold of 20 mm\r\n\r\ntime=clock;\r\n\r\nmethod_string='mvsnet';\r\nlight_string='l3'; %'l7'; l3 is the setting with all lights on, l7 is randomly sampled between the 7 settings (index 0-6)\r\nrepresentation_string='Points'; %mvs representation 'Points' or 'Surfaces'\r\n\r\nswitch representation_string\r\n    case 'Points'\r\n        eval_string='_Eval_'; %results naming\r\n        settings_string='';\r\nend\r\n\r\n% get sets used in evaluation\r\nUsedSets=[1 4 9 10 11 12 13 15 23 24 29 32 33 34 48 49 62 75 77 110 114 118];\r\n\r\nnStat=length(UsedSets);\r\n\r\nBaseStat.nStl=zeros(1,nStat);\r\nBaseStat.nData=zeros(1,nStat);\r\nBaseStat.MeanStl=zeros(1,nStat);\r\nBaseStat.MeanData=zeros(1,nStat);\r\nBaseStat.VarStl=zeros(1,nStat);\r\nBaseStat.VarData=zeros(1,nStat);\r\nBaseStat.MedStl=zeros(1,nStat);\r\nBaseStat.MedData=zeros(1,nStat);\r\n\r\nfor cStat=1:length(UsedSets) %Data set number\r\n\r\n    currentSet=UsedSets(cStat);\r\n\r\n    %input results name\r\n    EvalName=[resultsPath method_string eval_string num2str(currentSet) '.mat'];\r\n\r\n    disp(EvalName);\r\n    load(EvalName);\r\n\r\n    Dstl=data.Dstl(data.StlAbovePlane); %use only points that are above the plane\r\n    Dstl=Dstl(Dstl<MaxDist); % discard outliers\r\n\r\n    Ddata=data.Ddata(data.DataInMask); %use only points that within mask\r\n    Ddata=Ddata(Ddata<MaxDist); % discard outliers\r\n\r\n    BaseStat.nStl(cStat)=length(Dstl);\r\n    BaseStat.nData(cStat)=length(Ddata);\r\n\r\n    BaseStat.MeanStl(cStat)=mean(Dstl);\r\n    BaseStat.MeanData(cStat)=mean(Ddata);\r\n\r\n    BaseStat.VarStl(cStat)=var(Dstl);\r\n    BaseStat.VarData(cStat)=var(Ddata);\r\n\r\n    BaseStat.MedStl(cStat)=median(Dstl);\r\n    BaseStat.MedData(cStat)=median(Ddata);\r\n\r\n    disp(\"acc\");\r\n    disp(mean(Ddata));\r\n    disp(\"comp\");\r\n    disp(mean(Dstl));\r\n    time=clock;\r\nend\r\n\r\ndisp(BaseStat);\r\ndisp(\"mean acc\")\r\ndisp(mean(BaseStat.MeanData));\r\ndisp(\"mean comp\")\r\ndisp(mean(BaseStat.MeanStl));\r\ndisp(\"mean overall\")\r\ndisp((mean(BaseStat.MeanStl)+mean(BaseStat.MeanData))/2.0);\r\n\r\ntotalStatName=[resultsPath 'TotalStat_' method_string eval_string '.mat']\r\nsave(totalStatName,'BaseStat','time','MaxDist');\r\n\r\ntotalStatName=[resultsPath 'TotalStat_' method_string eval_string '.txt']\r\nfp=fopen(totalStatName,'a');\r\nfprintf(fp,'%f\\n',mean(BaseStat.MeanData));\r\nfprintf(fp,'%f\\n',mean(BaseStat.MeanStl));\r\nend"
  },
  {
    "path": "evaluations/dtu/ComputeStat_web.m",
    "content": "clear all\r\nclose all\r\nformat compact\r\nclc\r\n\r\n% script to calculate the statistics for each scan given this will currently only run if distances have been measured\r\n% for all included scans (UsedSets)\r\n\r\n% modify the path to evaluate your models\r\ndataPath='/mnt/cfs/algorithm/public_data/mvs/dtu_evalset/SampleSet/MVS Data';\r\nresultsPath=['/mnt/cfs/algorithm/xiaofeng.wang/jeff/code/MVS/cascade-stereo/CasMVSNet/outputs/repo_model_aligncorners_ITGT/eval_out/'];\r\n\r\nMaxDist=20; %outlier thresshold of 20 mm\r\n\r\ntime=clock;\r\n\r\nmethod_string='mvsnet';\r\nlight_string='l3'; %'l7'; l3 is the setting with all lights on, l7 is randomly sampled between the 7 settings (index 0-6)\r\nrepresentation_string='Points'; %mvs representation 'Points' or 'Surfaces'\r\n\r\nswitch representation_string\r\n    case 'Points'\r\n        eval_string='_Eval_'; %results naming\r\n        settings_string='';\r\nend\r\n\r\n% get sets used in evaluation\r\nUsedSets=[1 4 9 10 11 12 13 15 23 24 29 32 33 34 48 49 62 75 77 110 114 118];\r\n\r\nnStat=length(UsedSets);\r\n\r\nBaseStat.nStl=zeros(1,nStat);\r\nBaseStat.nData=zeros(1,nStat);\r\nBaseStat.MeanStl=zeros(1,nStat);\r\nBaseStat.MeanData=zeros(1,nStat);\r\nBaseStat.VarStl=zeros(1,nStat);\r\nBaseStat.VarData=zeros(1,nStat);\r\nBaseStat.MedStl=zeros(1,nStat);\r\nBaseStat.MedData=zeros(1,nStat);\r\n\r\nfor cStat=1:length(UsedSets) %Data set number\r\n\r\n    currentSet=UsedSets(cStat);\r\n\r\n    %input results name\r\n    EvalName=[resultsPath method_string eval_string num2str(currentSet) '.mat'];\r\n\r\n    disp(EvalName);\r\n    load(EvalName);\r\n\r\n    Dstl=data.Dstl(data.StlAbovePlane); %use only points that are above the plane\r\n    Dstl=Dstl(Dstl<MaxDist); % discard outliers\r\n\r\n    Ddata=data.Ddata(data.DataInMask); %use only points that within mask\r\n    Ddata=Ddata(Ddata<MaxDist); % discard outliers\r\n\r\n    BaseStat.nStl(cStat)=length(Dstl);\r\n    BaseStat.nData(cStat)=length(Ddata);\r\n\r\n    BaseStat.MeanStl(cStat)=mean(Dstl);\r\n    BaseStat.MeanData(cStat)=mean(Ddata);\r\n\r\n    BaseStat.VarStl(cStat)=var(Dstl);\r\n    BaseStat.VarData(cStat)=var(Ddata);\r\n\r\n    BaseStat.MedStl(cStat)=median(Dstl);\r\n    BaseStat.MedData(cStat)=median(Ddata);\r\n\r\n    disp(\"acc\");\r\n    disp(mean(Ddata));\r\n    disp(\"comp\");\r\n    disp(mean(Dstl));\r\n    time=clock;\r\nend\r\n\r\ndisp(BaseStat);\r\ndisp(\"mean acc\")\r\ndisp(mean(BaseStat.MeanData));\r\ndisp(\"mean comp\")\r\ndisp(mean(BaseStat.MeanStl));\r\ndisp(\"mean overall\")\r\ndisp((mean(BaseStat.MeanStl)+mean(BaseStat.MeanData))/2.0);\r\n\r\ntotalStatName=[resultsPath 'TotalStat_' method_string eval_string '.mat']\r\nsave(totalStatName,'BaseStat','time','MaxDist');\r\n\r\ntotalStatName=[resultsPath 'TotalStat_' method_string eval_string '.txt']\r\nfp=fopen(totalStatName,'a');\r\nfprintf(fp,'%f\\n',mean(BaseStat.MeanData));\r\nfprintf(fp,'%f\\n',mean(BaseStat.MeanStl));\r\n"
  },
  {
    "path": "evaluations/dtu/MaxDistCP.m",
    "content": "function Dist = MaxDistCP(Qto,Qfrom,BB,MaxDist)\r\n\r\nDist=ones(1,size(Qfrom,2))*MaxDist;\r\n\r\nRange=floor((BB(2,:)-BB(1,:))/MaxDist);\r\n\r\ntic\r\nDone=0;\r\nLookAt=zeros(1,size(Qfrom,2));\r\nfor x=0:Range(1),\r\n    for y=0:Range(2),\r\n        for z=0:Range(3),\r\n            \r\n            Low=BB(1,:)+[x y z]*MaxDist;\r\n            High=Low+MaxDist;\r\n            \r\n            idxF=find(Qfrom(1,:)>=Low(1) & Qfrom(2,:)>=Low(2) & Qfrom(3,:)>=Low(3) &...\r\n                Qfrom(1,:)<High(1) & Qfrom(2,:)<High(2) & Qfrom(3,:)<High(3));\r\n            SQfrom=Qfrom(:,idxF);\r\n            LookAt(idxF)=LookAt(idxF)+1; %Debug\r\n            \r\n            Low=Low-MaxDist;\r\n            High=High+MaxDist;\r\n            idxT=find(Qto(1,:)>=Low(1) & Qto(2,:)>=Low(2) & Qto(3,:)>=Low(3) &...\r\n                Qto(1,:)<High(1) & Qto(2,:)<High(2) & Qto(3,:)<High(3));\r\n            SQto=Qto(:,idxT);\r\n            \r\n            if(isempty(SQto))\r\n                Dist(idxF)=MaxDist;\r\n            else\r\n                KDstl=KDTreeSearcher(SQto');\r\n                [~,SDist] = knnsearch(KDstl,SQfrom');\r\n                Dist(idxF)=SDist;\r\n                \r\n            end\r\n            \r\n            Done=Done+length(idxF); %Debug\r\n            \r\n        end\r\n    end\r\n    %Complete=Done/size(Qfrom,2);\r\n    %EstTime=(toc/Complete)/60\r\n    %toc\r\n    %LA=[sum(LookAt==0),...\r\n    %\tsum(LookAt==1),...\r\n   % \tsum(LookAt==2),...\r\n   % \tsum(LookAt==3),...\r\n   % \tsum(LookAt>3)]\r\nend\r\n\r\n"
  },
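MaxDistCP.m caps every nearest-neighbour distance at MaxDist, which lets it sweep the bounding box in MaxDist-sized cells and, for each cell of query points, build a KD-tree over only the target points within one cell of padding. A rough Python rendering of that blocking scheme (illustrative, assuming (3, N) point arrays, a (2, 3) bounding box, and scipy):

import numpy as np
from scipy.spatial import cKDTree

def max_dist_cp(q_to, q_from, bb, max_dist):
    """Nearest-neighbour distance from each column of q_from (3, N) to the
    q_to set, initialised to max_dist, mirroring the cell sweep above."""
    dist = np.full(q_from.shape[1], float(max_dist))
    rng = np.floor((bb[1] - bb[0]) / max_dist).astype(int)
    for x in range(rng[0] + 1):
        for y in range(rng[1] + 1):
            for z in range(rng[2] + 1):
                low = bb[0] + np.array([x, y, z]) * max_dist
                high = low + max_dist
                sel_f = np.all((q_from.T >= low) & (q_from.T < high), axis=1)
                if not sel_f.any():
                    continue
                # pad the target box by one cell on every side
                sel_t = np.all((q_to.T >= low - max_dist) & (q_to.T < high + max_dist), axis=1)
                if sel_t.any():
                    dist[sel_f] = cKDTree(q_to.T[sel_t]).query(q_from.T[sel_f])[0]
    return dist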
  {
    "path": "evaluations/dtu/PointCompareMain.m",
    "content": "function BaseEval=PointCompareMain(cSet,Qdata,dst,dataPath)\r\n% evaluation function the calculates the distantes from the reference data (stl) to the evalution points (Qdata) and the\r\n% distances from the evaluation points to the reference\r\n\r\ntic\r\n% reduce points 0.2 mm neighbourhood density\r\nQdata=reducePts_haa(Qdata,dst);\r\ntoc\r\n\r\nStlInName=[dataPath '/Points/stl/stl' sprintf('%03d',cSet) '_total.ply'];\r\n\r\nStlMesh = plyread(StlInName);  %STL points already reduced 0.2 mm neighbourhood density\r\nQstl=[StlMesh.vertex.x StlMesh.vertex.y StlMesh.vertex.z]';\r\n\r\n%Load Mask (ObsMask) and Bounding box (BB) and Resolution (Res)\r\nMargin=10;\r\nMaskName=[dataPath '/ObsMask/ObsMask' num2str(cSet) '_' num2str(Margin) '.mat'];\r\nload(MaskName)\r\n\r\nMaxDist=60;\r\ndisp('Computing Data 2 Stl distances')\r\nDdata = MaxDistCP(Qstl,Qdata,BB,MaxDist);\r\ntoc\r\n\r\ndisp('Computing Stl 2 Data distances')\r\nDstl=MaxDistCP(Qdata,Qstl,BB,MaxDist);\r\ndisp('Distances computed')\r\ntoc\r\n\r\n%use mask\r\n%From Get mask - inverted & modified.\r\nOne=ones(1,size(Qdata,2));\r\nQv=(Qdata-BB(1,:)'*One)/Res+1;\r\nQv=round(Qv);\r\n\r\nMidx1=find(Qv(1,:)>0 & Qv(1,:)<=size(ObsMask,1) & Qv(2,:)>0 & Qv(2,:)<=size(ObsMask,2) & Qv(3,:)>0 & Qv(3,:)<=size(ObsMask,3));\r\nMidxA=sub2ind(size(ObsMask),Qv(1,Midx1),Qv(2,Midx1),Qv(3,Midx1));\r\nMidx2=find(ObsMask(MidxA));\r\n\r\nBaseEval.DataInMask(1:size(Qv,2))=false;\r\nBaseEval.DataInMask(Midx1(Midx2))=true; %If Data is within the mask\r\n\r\nBaseEval.cSet=cSet;\r\nBaseEval.Margin=Margin;         %Margin of masks\r\nBaseEval.dst=dst;               %Min dist between points when reducing\r\nBaseEval.Qdata=Qdata;           %Input data points\r\nBaseEval.Ddata=Ddata;           %distance from data to stl\r\nBaseEval.Qstl=Qstl;             %Input stl points\r\nBaseEval.Dstl=Dstl;             %Distance from the stl to data\r\n\r\nload([dataPath '/ObsMask/Plane' num2str(cSet)],'P')\r\nBaseEval.GroundPlane=P;         % Plane used to destinguise which Stl points are 'used'\r\nBaseEval.StlAbovePlane=(P'*[Qstl;ones(1,size(Qstl,2))])>0; %Is stl above 'ground plane'\r\nBaseEval.Time=clock;            %Time when computation is finished\r\n\r\n\r\n\r\n\r\n"
  },
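The mask test in PointCompareMain.m snaps each point to the ObsMask voxel grid (origin BB(1,:), spacing Res, 1-based indexing in MATLAB) and keeps points whose voxel is marked observed. The same logic in numpy, 0-based, as a sketch (points_in_obs_mask is a hypothetical name, not part of the repo):

import numpy as np

def points_in_obs_mask(q_data, bb_low, res, obs_mask):
    """q_data: (3, N) points; bb_low: (3,) grid origin; obs_mask: 3-D bool grid."""
    qv = np.round((q_data - bb_low[:, None]) / res).astype(int)  # voxel per point
    in_bounds = np.all((qv >= 0) & (qv < np.array(obs_mask.shape)[:, None]), axis=0)
    keep = np.zeros(q_data.shape[1], dtype=bool)
    ix = qv[:, in_bounds]
    keep[in_bounds] = obs_mask[ix[0], ix[1], ix[2]]
    return keep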
  {
    "path": "evaluations/dtu/plyread.m",
    "content": "%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%\r\nfunction [Elements,varargout] = plyread(Path,Str)\r\n%PLYREAD   Read a PLY 3D data file.\r\n%   [DATA,COMMENTS] = PLYREAD(FILENAME) reads a version 1.0 PLY file\r\n%   FILENAME and returns a structure DATA.  The fields in this structure\r\n%   are defined by the PLY header; each element type is a field and each\r\n%   element property is a subfield.  If the file contains any comments,\r\n%   they are returned in a cell string array COMMENTS.\r\n%\r\n%   [TRI,PTS] = PLYREAD(FILENAME,'tri') or\r\n%   [TRI,PTS,DATA,COMMENTS] = PLYREAD(FILENAME,'tri') converts vertex\r\n%   and face data into triangular connectivity and vertex arrays.  The\r\n%   mesh can then be displayed using the TRISURF command.\r\n%\r\n%   Note: This function is slow for large mesh files (+50K faces),\r\n%   especially when reading data with list type properties.\r\n%\r\n%   Example:\r\n%   [Tri,Pts] = PLYREAD('cow.ply','tri');\r\n%   trisurf(Tri,Pts(:,1),Pts(:,2),Pts(:,3)); \r\n%   colormap(gray); axis equal;\r\n%\r\n%   See also: PLYWRITE\r\n\r\n% Pascal Getreuer 2004\r\n\r\n[fid,Msg] = fopen(Path,'rt');\t% open file in read text mode\r\n\r\nif fid == -1, error(Msg); end\r\n\r\nBuf = fscanf(fid,'%s',1);\r\nif ~strcmp(Buf,'ply')\r\n   fclose(fid);\r\n   error('Not a PLY file.'); \r\nend\r\n\r\n\r\n%%% read header %%%\r\n\r\nPosition = ftell(fid);\r\nFormat = '';\r\nNumComments = 0;\r\nComments = {};\t\t\t\t% for storing any file comments\r\nNumElements = 0;\r\nNumProperties = 0;\r\nElements = [];\t\t\t\t% structure for holding the element data\r\nElementCount = [];\t\t% number of each type of element in file\r\nPropertyTypes = [];\t\t% corresponding structure recording property types\r\nElementNames = {};\t\t% list of element names in the order they are stored in the file\r\nPropertyNames = [];\t\t% structure of lists of property names\r\n\r\nwhile 1\r\n   Buf = fgetl(fid);   \t\t\t\t\t\t\t\t% read one line from file\r\n   BufRem = Buf;\r\n   Token = {};\r\n   Count = 0;\r\n   \r\n   while ~isempty(BufRem)\t\t\t\t\t\t\t\t% split line into tokens\r\n      [tmp,BufRem] = strtok(BufRem);\r\n      \r\n      if ~isempty(tmp)\r\n         Count = Count + 1;\t\t\t\t\t\t\t% count tokens\r\n         Token{Count} = tmp;\r\n      end\r\n   end\r\n   \r\n   if Count \t\t% parse line\r\n      switch lower(Token{1})\r\n      case 'format'\t\t% read data format\r\n         if Count >= 2\r\n            Format = lower(Token{2});\r\n            \r\n            if Count == 3 & ~strcmp(Token{3},'1.0')\r\n               fclose(fid);\r\n               error('Only PLY format version 1.0 supported.');\r\n            end\r\n         end\r\n      case 'comment'\t\t% read file comment\r\n         NumComments = NumComments + 1;\r\n         Comments{NumComments} = '';\r\n         for i = 2:Count\r\n            Comments{NumComments} = [Comments{NumComments},Token{i},' '];\r\n         end\r\n      case 'element'\t\t% element name\r\n         if Count >= 3\r\n            if isfield(Elements,Token{2})\r\n               fclose(fid);\r\n               error(['Duplicate element name, ''',Token{2},'''.']);\r\n            end\r\n            \r\n            NumElements = NumElements + 1;\r\n            NumProperties = 0;\r\n   \t      Elements = setfield(Elements,Token{2},[]);\r\n            PropertyTypes = setfield(PropertyTypes,Token{2},[]);\r\n            ElementNames{NumElements} = Token{2};\r\n            PropertyNames = setfield(PropertyNames,Token{2},{});\r\n          
  CurElement = Token{2};\r\n            ElementCount(NumElements) = str2double(Token{3});\r\n            \r\n            if isnan(ElementCount(NumElements))\r\n               fclose(fid);\r\n               error(['Bad element definition: ',Buf]); \r\n            end            \r\n         else\r\n            error(['Bad element definition: ',Buf]);\r\n         end         \r\n      case 'property'\t% element property\r\n         if ~isempty(CurElement) & Count >= 3            \r\n            NumProperties = NumProperties + 1;\r\n            eval(['tmp=isfield(Elements.',CurElement,',Token{Count});'],...\r\n               'fclose(fid);error([''Error reading property: '',Buf])');\r\n            \r\n            if tmp\r\n               error(['Duplicate property name, ''',CurElement,'.',Token{2},'''.']);\r\n            end            \r\n            \r\n            % add property subfield to Elements\r\n            eval(['Elements.',CurElement,'.',Token{Count},'=[];'], ...\r\n               'fclose(fid);error([''Error reading property: '',Buf])');            \r\n            % add property subfield to PropertyTypes and save type\r\n            eval(['PropertyTypes.',CurElement,'.',Token{Count},'={Token{2:Count-1}};'], ...\r\n               'fclose(fid);error([''Error reading property: '',Buf])');            \r\n            % record property name order \r\n            eval(['PropertyNames.',CurElement,'{NumProperties}=Token{Count};'], ...\r\n               'fclose(fid);error([''Error reading property: '',Buf])');\r\n         else\r\n            fclose(fid);\r\n            \r\n            if isempty(CurElement)            \r\n               error(['Property definition without element definition: ',Buf]);\r\n            else               \r\n               error(['Bad property definition: ',Buf]);\r\n            end            \r\n         end         \r\n      case 'end_header'\t% end of header, break from while loop\r\n         break;\t\t\r\n      end\r\n   end\r\nend\r\n\r\n%%% set reading for specified data format %%%\r\n\r\nif isempty(Format)\r\n\twarning('Data format unspecified, assuming ASCII.');\r\n   Format = 'ascii';\r\nend\r\n\r\nswitch Format\r\ncase 'ascii'\r\n   Format = 0;\r\ncase 'binary_little_endian'\r\n   Format = 1;\r\ncase 'binary_big_endian'\r\n   Format = 2;\r\notherwise\r\n   fclose(fid);\r\n   error(['Data format ''',Format,''' not supported.']);\r\nend\r\n\r\nif ~Format   \r\n   Buf = fscanf(fid,'%f');\t\t% read the rest of the file as ASCII data\r\n   BufOff = 1;\r\nelse\r\n   % reopen the file in read binary mode\r\n   fclose(fid);\r\n   \r\n   if Format == 1\r\n      fid = fopen(Path,'r','ieee-le.l64');\t\t% little endian\r\n   else\r\n      fid = fopen(Path,'r','ieee-be.l64');\t\t% big endian\r\n   end\r\n   \r\n   % find the end of the header again (using ftell on the old handle doesn't give the correct position)   \r\n   BufSize = 8192;\r\n   Buf = [blanks(10),char(fread(fid,BufSize,'uchar')')];\r\n   i = [];\r\n   tmp = -11;\r\n   \r\n   while isempty(i)\r\n   \ti = findstr(Buf,['end_header',13,10]);\t\t\t% look for end_header + CR/LF\r\n   \ti = [i,findstr(Buf,['end_header',10])];\t\t% look for end_header + LF\r\n      \r\n      if isempty(i)\r\n         tmp = tmp + BufSize;\r\n         Buf = [Buf(BufSize+1:BufSize+10),char(fread(fid,BufSize,'uchar')')];\r\n      end\r\n   end\r\n   \r\n   % seek to just after the line feed\r\n   fseek(fid,i + tmp + 11 + (Buf(i + 10) == 13),-1);\r\nend\r\n\r\n\r\n%%% read element data %%%\r\n\r\n% PLY and MATLAB data types 
(for fread)\r\nPlyTypeNames = {'char','uchar','short','ushort','int','uint','float','double', ...\r\n   'char8','uchar8','short16','ushort16','int32','uint32','float32','double64'};\r\nMatlabTypeNames = {'schar','uchar','int16','uint16','int32','uint32','single','double'};\r\nSizeOf = [1,1,2,2,4,4,4,8];\t% size in bytes of each type\r\n\r\nfor i = 1:NumElements\r\n   % get current element property information\r\n   eval(['CurPropertyNames=PropertyNames.',ElementNames{i},';']);\r\n   eval(['CurPropertyTypes=PropertyTypes.',ElementNames{i},';']);\r\n   NumProperties = size(CurPropertyNames,2);\r\n   \r\n%   fprintf('Reading %s...\\n',ElementNames{i});\r\n      \r\n   if ~Format\t%%% read ASCII data %%%\r\n      for j = 1:NumProperties\r\n         Token = getfield(CurPropertyTypes,CurPropertyNames{j});\r\n         \r\n         if strcmpi(Token{1},'list')\r\n            Type(j) = 1;\r\n         else\r\n            Type(j) = 0;\r\n\t\t\tend\r\n      end\r\n      \r\n      % parse buffer\r\n      if ~any(Type)\r\n         % no list types\r\n         Data = reshape(Buf(BufOff:BufOff+ElementCount(i)*NumProperties-1),NumProperties,ElementCount(i))';\r\n         BufOff = BufOff + ElementCount(i)*NumProperties;\r\n      else\r\n         ListData = cell(NumProperties,1);\r\n         \r\n         for k = 1:NumProperties\r\n            ListData{k} = cell(ElementCount(i),1);\r\n         end\r\n         \r\n         % list type\r\n\t\t   for j = 1:ElementCount(i)\r\n   \t      for k = 1:NumProperties\r\n      \t      if ~Type(k)\r\n         \t      Data(j,k) = Buf(BufOff);\r\n            \t   BufOff = BufOff + 1;\r\n\t            else\r\n   \t            tmp = Buf(BufOff);\r\n      \t         ListData{k}{j} = Buf(BufOff+(1:tmp))';\r\n         \t      BufOff = BufOff + tmp + 1;\r\n            \tend\r\n            end\r\n         end\r\n      end\r\n   else\t\t%%% read binary data %%%\r\n      % translate PLY data type names to MATLAB data type names\r\n      ListFlag = 0;\t\t% = 1 if there is a list type \r\n      SameFlag = 1;     % = 1 if all types are the same\r\n      \r\n      for j = 1:NumProperties\r\n         Token = getfield(CurPropertyTypes,CurPropertyNames{j});\r\n         \r\n         if ~strcmp(Token{1},'list')\t\t\t% non-list type\r\n\t         tmp = rem(strmatch(Token{1},PlyTypeNames,'exact')-1,8)+1;\r\n         \r\n            if ~isempty(tmp)\r\n               TypeSize(j) = SizeOf(tmp);\r\n               Type{j} = MatlabTypeNames{tmp};\r\n               TypeSize2(j) = 0;\r\n               Type2{j} = '';\r\n               \r\n               SameFlag = SameFlag & strcmp(Type{1},Type{j});\r\n\t         else\r\n   \t         fclose(fid);\r\n               error(['Unknown property data type, ''',Token{1},''', in ', ...\r\n                     ElementNames{i},'.',CurPropertyNames{j},'.']);\r\n         \tend\r\n         else\t\t\t\t\t\t\t\t\t\t\t% list type\r\n            if length(Token) == 3\r\n               ListFlag = 1;\r\n               SameFlag = 0;\r\n               tmp = rem(strmatch(Token{2},PlyTypeNames,'exact')-1,8)+1;\r\n               tmp2 = rem(strmatch(Token{3},PlyTypeNames,'exact')-1,8)+1;\r\n         \r\n               if ~isempty(tmp) & ~isempty(tmp2)\r\n                  TypeSize(j) = SizeOf(tmp);\r\n                  Type{j} = MatlabTypeNames{tmp};\r\n                  TypeSize2(j) = SizeOf(tmp2);\r\n                  Type2{j} = MatlabTypeNames{tmp2};\r\n\t   \t      else\r\n   \t   \t      fclose(fid);\r\n               \terror(['Unknown property data type, ''list 
',Token{2},' ',Token{3},''', in ', ...\r\n                        ElementNames{i},'.',CurPropertyNames{j},'.']);\r\n               end\r\n            else\r\n               fclose(fid);\r\n               error(['Invalid list syntax in ',ElementNames{i},'.',CurPropertyNames{j},'.']);\r\n            end\r\n         end\r\n      end\r\n      \r\n      % read file\r\n      if ~ListFlag\r\n         if SameFlag\r\n            % no list types, all the same type (fast)\r\n            Data = fread(fid,[NumProperties,ElementCount(i)],Type{1})';\r\n         else\r\n            % no list types, mixed type\r\n            Data = zeros(ElementCount(i),NumProperties);\r\n            \r\n         \tfor j = 1:ElementCount(i)\r\n        \t\t\tfor k = 1:NumProperties\r\n               \tData(j,k) = fread(fid,1,Type{k});\r\n              \tend\r\n         \tend\r\n         end\r\n      else\r\n         ListData = cell(NumProperties,1);\r\n         \r\n         for k = 1:NumProperties\r\n            ListData{k} = cell(ElementCount(i),1);\r\n         end\r\n         \r\n         if NumProperties == 1\r\n            BufSize = 512;\r\n            SkipNum = 4;\r\n            j = 0;\r\n            \r\n            % list type, one property (fast if lists are usually the same length)\r\n            while j < ElementCount(i)\r\n               Position = ftell(fid);\r\n               % read in BufSize count values, assuming all counts = SkipNum\r\n               [Buf,BufSize] = fread(fid,BufSize,Type{1},SkipNum*TypeSize2(1));\r\n               Miss = find(Buf ~= SkipNum);\t\t\t\t\t% find first count that is not SkipNum\r\n               fseek(fid,Position + TypeSize(1),-1); \t\t% seek back to after first count                              \r\n               \r\n               if isempty(Miss)\t\t\t\t\t\t\t\t\t% all counts are SkipNum\r\n                  Buf = fread(fid,[SkipNum,BufSize],[int2str(SkipNum),'*',Type2{1}],TypeSize(1))';\r\n                  fseek(fid,-TypeSize(1),0); \t\t\t\t% undo last skip\r\n                  \r\n                  for k = 1:BufSize\r\n                     ListData{1}{j+k} = Buf(k,:);\r\n                  end\r\n                  \r\n                  j = j + BufSize;\r\n                  BufSize = floor(1.5*BufSize);\r\n               else\r\n                  if Miss(1) > 1\t\t\t\t\t\t\t\t\t% some counts are SkipNum\r\n                     Buf2 = fread(fid,[SkipNum,Miss(1)-1],[int2str(SkipNum),'*',Type2{1}],TypeSize(1))';                     \r\n                     \r\n                     for k = 1:Miss(1)-1\r\n                        ListData{1}{j+k} = Buf2(k,:);\r\n                     end\r\n                     \r\n                     j = j + k;\r\n                  end\r\n                  \r\n                  % read in the list with the missed count\r\n                  SkipNum = Buf(Miss(1));\r\n                  j = j + 1;\r\n                  ListData{1}{j} = fread(fid,[1,SkipNum],Type2{1});\r\n                  BufSize = ceil(0.6*BufSize);\r\n               end\r\n            end\r\n         else\r\n            % list type(s), multiple properties (slow)\r\n            Data = zeros(ElementCount(i),NumProperties);\r\n            \r\n            for j = 1:ElementCount(i)\r\n         \t\tfor k = 1:NumProperties\r\n            \t\tif isempty(Type2{k})\r\n               \t\tData(j,k) = fread(fid,1,Type{k});\r\n            \t\telse\r\n               \t\ttmp = fread(fid,1,Type{k});\r\n               \t\tListData{k}{j} = fread(fid,[1,tmp],Type2{k});\r\n\t\t            end\r\n      
\t\t   end\r\n      \t\tend\r\n         end\r\n      end\r\n   end\r\n   \r\n   % put data into Elements structure\r\n   for k = 1:NumProperties\r\n   \tif (~Format & ~Type(k)) | (Format & isempty(Type2{k}))\r\n      \teval(['Elements.',ElementNames{i},'.',CurPropertyNames{k},'=Data(:,k);']);\r\n      else\r\n      \teval(['Elements.',ElementNames{i},'.',CurPropertyNames{k},'=ListData{k};']);\r\n\t\tend\r\n   end\r\nend\r\n\r\nclear Data ListData;\r\nfclose(fid);\r\n\r\nif (nargin > 1 & strcmpi(Str,'Tri')) | nargout > 2   \r\n   % find vertex element field\r\n   Name = {'vertex','Vertex','point','Point','pts','Pts'};\r\n   Names = [];\r\n   \r\n   for i = 1:length(Name)\r\n      if any(strcmp(ElementNames,Name{i}))\r\n         Names = getfield(PropertyNames,Name{i});\r\n         Name = Name{i};         \r\n         break;\r\n      end\r\n   end\r\n   \r\n   if any(strcmp(Names,'x')) & any(strcmp(Names,'y')) & any(strcmp(Names,'z'))\r\n      eval(['varargout{1}=[Elements.',Name,'.x,Elements.',Name,'.y,Elements.',Name,'.z];']);\r\n   else\r\n      varargout{1} = zeros(1,3);\r\n\tend\r\n           \r\n   varargout{2} = Elements;\r\n   varargout{3} = Comments;\r\n   Elements = [];\r\n   \r\n   % find face element field\r\n   Name = {'face','Face','poly','Poly','tri','Tri'};\r\n   Names = [];\r\n   \r\n   for i = 1:length(Name)\r\n      if any(strcmp(ElementNames,Name{i}))\r\n         Names = getfield(PropertyNames,Name{i});\r\n         Name = Name{i};\r\n         break;\r\n      end\r\n   end\r\n   \r\n   if ~isempty(Names)\r\n      % find vertex indices property subfield\r\n\t   PropertyName = {'vertex_indices','vertex_indexes','vertex_index','indices','indexes'};           \r\n      \r\n   \tfor i = 1:length(PropertyName)\r\n      \tif any(strcmp(Names,PropertyName{i}))\r\n         \tPropertyName = PropertyName{i};\r\n\t         break;\r\n   \t   end\r\n      end\r\n      \r\n      if ~iscell(PropertyName)\r\n         % convert face index lists to triangular connectivity\r\n         eval(['FaceIndices=varargout{2}.',Name,'.',PropertyName,';']);\r\n  \t\t\tN = length(FaceIndices);\r\n   \t\tElements = zeros(N*2,3);\r\n   \t\tExtra = 0;   \r\n\r\n\t\t\tfor k = 1:N\r\n   \t\t\tElements(k,:) = FaceIndices{k}(1:3);\r\n   \r\n   \t\t\tfor j = 4:length(FaceIndices{k})\r\n      \t\t\tExtra = Extra + 1;      \r\n\t      \t\tElements(N + Extra,:) = [Elements(k,[1,j-1]),FaceIndices{k}(j)];\r\n   \t\t\tend\r\n         end\r\n         Elements = Elements(1:N+Extra,:) + 1;\r\n      end\r\n   end\r\nelse\r\n   varargout{1} = Comments;\r\nend"
  },
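plyread.m is Pascal Getreuer's general-purpose PLY reader. If only the vertex array is needed on the Python side, the third-party plyfile package (an external dependency, not part of this repo) does the same job:

import numpy as np
from plyfile import PlyData

def read_ply_vertices(path):
    """Return the vertex positions of a PLY file as an (N, 3) array."""
    v = PlyData.read(path)['vertex']
    return np.stack([v['x'], v['y'], v['z']], axis=1)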
  {
    "path": "evaluations/dtu/reducePts_haa.m",
    "content": "function [ptsOut,indexSet] = reducePts_haa(pts, dst)\n\n%Reduces a point set, pts, in a stochastic manner, such that the minimum sdistance\n% between points is 'dst'. Writen by abd, edited by haa, then by raje\n\nnPoints=size(pts,2);\n\nindexSet=true(nPoints,1);\nRandOrd=randperm(nPoints);\n\n%tic\nNS = KDTreeSearcher(pts');\n%toc\n\n% search the KNTree for close neighbours in a chunk-wise fashion to save memory if point cloud is really big\nChunks=1:min(4e6,nPoints-1):nPoints;\nChunks(end)=nPoints;\n\nfor cChunk=1:(length(Chunks)-1)\n    Range=Chunks(cChunk):Chunks(cChunk+1);\n    \n    idx = rangesearch(NS,pts(:,RandOrd(Range))',dst);\n    \n    for i = 1:size(idx,1)\n        id =RandOrd(i-1+Chunks(cChunk));\n        if (indexSet(id))\n            indexSet(idx{i}) = 0;\n            indexSet(id) = 1;\n        end\n    end\nend\n\nptsOut = pts(:,indexSet);\n\ndisp(['downsample factor: ' num2str(nPoints/sum(indexSet))]);\n"
  },
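reducePts_haa.m visits points in random order and, whenever it reaches a still-kept point, suppresses every neighbour within dst of it. A compact Python equivalent (illustrative; the memory-saving chunking is omitted):

import numpy as np
from scipy.spatial import cKDTree

def reduce_pts(pts, dst):
    """pts: (3, N). Keep a random maximal subset with pairwise spacing
    of roughly dst or more, as in reducePts_haa.m."""
    n = pts.shape[1]
    keep = np.ones(n, dtype=bool)
    order = np.random.permutation(n)
    neighbours = cKDTree(pts.T).query_ball_point(pts.T[order], dst)
    for i, idx in zip(order, neighbours):
        if keep[i]:
            keep[idx] = False  # suppress the whole neighbourhood...
            keep[i] = True     # ...but keep the visited point itself
    return pts[:, keep], keep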
  {
    "path": "lists/blendedmvs/train.txt",
    "content": "5c1f33f1d33e1f2e4aa6dda4\n5bfe5ae0fe0ea555e6a969ca\n5bff3c5cfe0ea555e6bcbf3a\n58eaf1513353456af3a1682a\n5bfc9d5aec61ca1dd69132a2\n5bf18642c50e6f7f8bdbd492\n5bf26cbbd43923194854b270\n5bf17c0fd439231948355385\n5be3ae47f44e235bdbbc9771\n5be3a5fb8cfdd56947f6b67c\n5bbb6eb2ea1cfa39f1af7e0c\n5ba75d79d76ffa2c86cf2f05\n5bb7a08aea1cfa39f1a947ab\n5b864d850d072a699b32f4ae\n5b6eff8b67b396324c5b2672\n5b6e716d67b396324c2d77cb\n5b69cc0cb44b61786eb959bf\n5b62647143840965efc0dbde\n5b60fa0c764f146feef84df0\n5b558a928bbfb62204e77ba2\n5b271079e0878c3816dacca4\n5b08286b2775267d5b0634ba\n5afacb69ab00705d0cefdd5b\n5af28cea59bc705737003253\n5af02e904c8216544b4ab5a2\n5aa515e613d42d091d29d300\n5c34529873a8df509ae57b58\n5c34300a73a8df509add216d\n5c1af2e2bee9a723c963d019\n5c1892f726173c3a09ea9aeb\n5c0d13b795da9479e12e2ee9\n5c062d84a96e33018ff6f0a6\n5bfd0f32ec61ca1dd69dc77b\n5bf21799d43923194842c001\n5bf3a82cd439231948877aed\n5bf03590d4392319481971dc\n5beb6e66abd34c35e18e66b9\n5be883a4f98cee15019d5b83\n5be47bf9b18881428d8fbc1d\n5bcf979a6d5f586b95c258cd\n5bce7ac9ca24970bce4934b6\n5bb8a49aea1cfa39f1aa7f75\n5b78e57afc8fcf6781d0c3ba\n5b21e18c58e2823a67a10dd8\n5b22269758e2823a67a3bd03\n5b192eb2170cf166458ff886\n5ae2e9c5fe405c5076abc6b2\n5adc6bd52430a05ecb2ffb85\n5ab8b8e029f5351f7f2ccf59\n5abc2506b53b042ead637d86\n5ab85f1dac4291329b17cb50\n5a969eea91dfc339a9a3ad2c\n5a8aa0fab18050187cbe060e\n5a7d3db14989e929563eb153\n5a69c47d0d5d0a7f3b2e9752\n5a618c72784780334bc1972d\n5a6464143d809f1d8208c43c\n5a588a8193ac3d233f77fbca\n5a57542f333d180827dfc132\n5a572fd9fc597b0478a81d14\n5a563183425d0f5186314855\n5a4a38dad38c8a075495b5d2\n5a48d4b2c7dab83a7d7b9851\n5a489fb1c7dab83a7d7b1070\n5a48ba95c7dab83a7d7b44ed\n5a3ca9cb270f0e3f14d0eddb\n5a3cb4e4270f0e3f14d12f43\n5a3f4aba5889373fbbc5d3b5\n5a0271884e62597cdee0d0eb\n59e864b2a9e91f2c5529325f\n599aa591d5b41f366fed0d58\n59350ca084b7f26bf5ce6eb8\n59338e76772c3e6384afbb15\n5c20ca3a0843bc542d94e3e2\n5c1dbf200843bc542d8ef8c4\n5c1b1500bee9a723c96c3e78\n5bea87f4abd34c35e1860ab5\n5c2b3ed5e611832e8aed46bf\n57f8d9bbe73f6760f10e916a\n5bf7d63575c26f32dbf7413b\n5be4ab93870d330ff2dce134\n5bd43b4ba6b28b1ee86b92dd\n5bccd6beca24970bce448134\n5bc5f0e896b66a2cd8f9bd36\n5b908d3dc6ab78485f3d24a9\n5b2c67b5e0878c381608b8d8\n5b4933abf2b5f44e95de482a\n5b3b353d8d46a939f93524b9\n5acf8ca0f3d8a750097e4b15\n5ab8713ba3799a1d138bd69a\n5aa235f64a17b335eeaf9609\n5aa0f9d7a9efce63548c69a1\n5a8315f624b8e938486e0bd8\n5a48c4e9c7dab83a7d7b5cc7\n59ecfd02e225f6492d20fcc9\n59f87d0bfa6280566fb38c9a\n59f363a8b45be22330016cad\n59f70ab1e5c5d366af29bf3e\n59e75a2ca9e91f2c5526005d\n5947719bf1b45630bd096665\n5947b62af1b45630bd0c2a02\n59056e6760bb961de55f3501\n58f7f7299f5b5647873cb110\n58cf4771d0f5fb221defe6da\n58d36897f387231e6c929903\n58c4bb4f4a69c55606122be4\n"
  },
  {
    "path": "lists/blendedmvs/val.txt",
    "content": "5b7a3890fc8fcf6781e2593a\r\n5c189f2326173c3a09ed7ef3\r\n5b950c71608de421b1e7318f\r\n5a6400933d809f1d8200af15\r\n59d2657f82ca7774b1ec081d\r\n5ba19a8a360c7c30c1c169df\r\n59817e4a1bd4b175e7038d19\r\n"
  },
  {
    "path": "lists/dtu/test.txt",
    "content": "scan1\nscan4\nscan9\nscan10\nscan11\nscan12\nscan13\nscan15\nscan23\nscan24\nscan29\nscan32\nscan33\nscan34\nscan48\nscan49\nscan62\nscan75\nscan77\nscan110\nscan114\nscan118"
  },
  {
    "path": "lists/dtu/train.txt",
    "content": "scan2\nscan6\nscan7\nscan8\nscan14\nscan16\nscan18\nscan19\nscan20\nscan22\nscan30\nscan31\nscan36\nscan39\nscan41\nscan42\nscan44\nscan45\nscan46\nscan47\nscan50\nscan51\nscan52\nscan53\nscan55\nscan57\nscan58\nscan60\nscan61\nscan63\nscan64\nscan65\nscan68\nscan69\nscan70\nscan71\nscan72\nscan74\nscan76\nscan83\nscan84\nscan85\nscan87\nscan88\nscan89\nscan90\nscan91\nscan92\nscan93\nscan94\nscan95\nscan96\nscan97\nscan98\nscan99\nscan100\nscan101\nscan102\nscan103\nscan104\nscan105\nscan107\nscan108\nscan109\nscan111\nscan112\nscan113\nscan115\nscan116\nscan119\nscan120\nscan121\nscan122\nscan123\nscan124\nscan125\nscan126\nscan127\nscan128"
  },
  {
    "path": "lists/dtu/trainval.txt",
    "content": "scan2\nscan6\nscan7\nscan8\nscan14\nscan16\nscan18\nscan19\nscan20\nscan22\nscan30\nscan31\nscan36\nscan39\nscan41\nscan42\nscan44\nscan45\nscan46\nscan47\nscan50\nscan51\nscan52\nscan53\nscan55\nscan57\nscan58\nscan60\nscan61\nscan63\nscan64\nscan65\nscan68\nscan69\nscan70\nscan71\nscan72\nscan74\nscan76\nscan83\nscan84\nscan85\nscan87\nscan88\nscan89\nscan90\nscan91\nscan92\nscan93\nscan94\nscan95\nscan96\nscan97\nscan98\nscan99\nscan100\nscan101\nscan102\nscan103\nscan104\nscan105\nscan107\nscan108\nscan109\nscan111\nscan112\nscan113\nscan115\nscan116\nscan119\nscan120\nscan121\nscan122\nscan123\nscan124\nscan125\nscan126\nscan127\nscan128\nscan3\nscan5\nscan17\nscan21\nscan28\nscan35\nscan37\nscan38\nscan40\nscan43\nscan56\nscan59\nscan66\nscan67\nscan82\nscan86\nscan106\nscan117"
  },
  {
    "path": "lists/dtu/val.txt",
    "content": "scan3\nscan5\nscan17\nscan21\nscan28\nscan35\nscan37\nscan38\nscan40\nscan43\nscan56\nscan59\nscan66\nscan67\nscan82\nscan86\nscan106\nscan117"
  },
  {
    "path": "models/MVS4Net.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\nfrom models.mvs4net_utils import stagenet, reg2d, reg3d, FPN4, FPN4_convnext, FPN4_convnext4, PosEncSine, PosEncLearned, \\\n        init_range, schedule_range, init_inverse_range, schedule_inverse_range, sinkhorn, mono_depth_decoder, ASFF\n\n\nclass MVS4net(nn.Module):\n    def __init__(self, arch_mode=\"fpn\", reg_net='reg2d', num_stage=4, fpn_base_channel=8, \n                reg_channel=8, stage_splits=[8,8,4,4], depth_interals_ratio=[0.5,0.5,0.5,1],\n                group_cor=False, group_cor_dim=[8,8,8,8],\n                inverse_depth=False,\n                agg_type='ConvBnReLU3D',\n                dcn=False,\n                pos_enc=0,\n                mono=False,\n                asff=False,\n                attn_temp=2,\n                attn_fuse_d=True,\n                vis_ETA=False,\n                vis_mono=False\n                ):\n        # pos_enc: 0 no pos enc; 1 depth sine; 2 learnable pos enc\n        super(MVS4net, self).__init__()\n        self.arch_mode = arch_mode\n        self.num_stage = num_stage\n        self.depth_interals_ratio = depth_interals_ratio\n        self.group_cor = group_cor\n        self.group_cor_dim = group_cor_dim\n        self.inverse_depth = inverse_depth\n        self.asff = asff\n        if self.asff:\n            self.asff = nn.ModuleList([ASFF(i) for i in range(num_stage)])\n        self.attn_ob = nn.ModuleList()\n        if arch_mode == \"fpn\":\n            self.feature = FPN4(base_channels=fpn_base_channel, gn=False, dcn=dcn)\n        self.vis_mono = vis_mono\n        self.stagenet = stagenet(inverse_depth, mono, attn_fuse_d, vis_ETA, attn_temp)\n        self.stage_splits = stage_splits\n        self.reg = nn.ModuleList()\n        self.pos_enc = pos_enc\n        self.pos_enc_func = nn.ModuleList()\n        self.mono = mono\n        if self.mono:\n            self.mono_depth_decoder = mono_depth_decoder()\n        if reg_net == 'reg3d':\n            self.down_size = [3,3,2,2]\n        for idx in range(num_stage):\n            if self.group_cor:\n                in_dim = group_cor_dim[idx]\n            else:\n                in_dim = self.feature.out_channels[idx]\n            if reg_net == 'reg2d':\n                self.reg.append(reg2d(input_channel=in_dim, base_channel=reg_channel, conv_name=agg_type))\n            elif reg_net == 'reg3d':\n                self.reg.append(reg3d(in_channels=in_dim, base_channels=reg_channel, down_size=self.down_size[idx]))\n\n\n    def forward(self, imgs, proj_matrices, depth_values, filename=None):\n        depth_min = depth_values[:, 0].cpu().numpy()\n        depth_max = depth_values[:, -1].cpu().numpy()\n        depth_interval = (depth_max - depth_min) / depth_values.size(1)\n\n        # step 1. feature extraction\n        features = []\n        for nview_idx in range(len(imgs)):  #imgs shape (B, N, C, H, W)\n            img = imgs[nview_idx]\n            features.append(self.feature(img))\n        if self.vis_mono:\n            scan_name = filename[0].split('/')[0]\n            image_name = filename[0].split('/')[2][:-2]\n            save_fn = './debug_figs/vis_mono/feat_{}'.format(scan_name+'_'+image_name)\n            feat_ = features[-1]['stage4'].detach().cpu().numpy()\n            np.save(save_fn, feat_)\n        # step 2. 
iter (multi-scale)\n        outputs = {}\n        for stage_idx in range(self.num_stage):\n            if not self.asff:\n                features_stage = [feat[\"stage{}\".format(stage_idx+1)] for feat in features]\n            else:\n                features_stage = [self.asff[stage_idx](feat['stage1'],feat['stage2'],feat['stage3'],feat['stage4']) for feat in features]\n\n            proj_matrices_stage = proj_matrices[\"stage{}\".format(stage_idx + 1)]\n            B,C,H,W = features[0]['stage{}'.format(stage_idx+1)].shape\n\n            # init range\n            if stage_idx == 0:\n                if self.inverse_depth:\n                    depth_hypo = init_inverse_range(depth_values, self.stage_splits[stage_idx], img[0].device, img[0].dtype, H, W)\n                else:\n                    depth_hypo = init_range(depth_values, self.stage_splits[stage_idx], img[0].device, img[0].dtype, H, W)\n            else:\n                if self.inverse_depth:\n                    depth_hypo = schedule_inverse_range(outputs_stage['inverse_min_depth'].detach(), outputs_stage['inverse_max_depth'].detach(), self.stage_splits[stage_idx], H, W)  # B D H W\n                else:\n                    depth_hypo = schedule_range(outputs_stage['depth'].detach(), self.stage_splits[stage_idx], self.depth_interals_ratio[stage_idx] * depth_interval, H, W)\n\n            outputs_stage = self.stagenet(features_stage, proj_matrices_stage, depth_hypo=depth_hypo, regnet=self.reg[stage_idx], stage_idx=stage_idx,\n                                        group_cor=self.group_cor, group_cor_dim=self.group_cor_dim[stage_idx],\n                                        split_itv=self.depth_interals_ratio[stage_idx],\n                                        fn=filename)\n\n            outputs[\"stage{}\".format(stage_idx + 1)] = outputs_stage\n            outputs.update(outputs_stage)\n        \n        if self.mono and self.training:\n        # if self.mono:\n            outputs = self.mono_depth_decoder(outputs, depth_values[:,0], depth_values[:,1])\n\n        return outputs\n\ndef MVS4net_loss(inputs, depth_gt_ms, mask_ms, **kwargs):\n    stage_lw = kwargs.get(\"stage_lw\", [1,1,1,1])\n    l1ot_lw = kwargs.get(\"l1ot_lw\", [0,1])\n    inverse = kwargs.get(\"inverse_depth\", False)\n    ot_iter = kwargs.get(\"ot_iter\", 3)\n    ot_eps = kwargs.get(\"ot_eps\", 1)\n    ot_continous = kwargs.get(\"ot_continous\", False)\n    mono = kwargs.get(\"mono\", False)\n    total_loss = torch.tensor(0.0, dtype=torch.float32, device=mask_ms[\"stage1\"].device, requires_grad=False)\n    stage_ot_loss = []\n    stage_l1_loss = []\n    range_err_ratio = []\n    for stage_idx, (stage_inputs, stage_key) in enumerate([(inputs[k], k) for k in inputs.keys() if \"stage\" in k]):\n        depth_pred = stage_inputs['depth']\n        hypo_depth = stage_inputs['hypo_depth']\n        attn_weight = stage_inputs['attn_weight']\n        B,H,W = depth_pred.shape\n        D = hypo_depth.shape[1]\n        mask = mask_ms[stage_key]\n        mask = mask > 0.5\n        depth_gt = depth_gt_ms[stage_key]\n\n        if mono and stage_idx!=0:\n            this_stage_l1_loss = F.l1_loss(stage_inputs['mono_depth'][mask], depth_gt[mask], reduction='mean')\n        else:\n            this_stage_l1_loss = torch.tensor(0.0, dtype=torch.float32, device=mask_ms[\"stage1\"].device, requires_grad=False)\n\n        # mask range\n        if inverse:\n            depth_itv = (1/hypo_depth[:,2,:,:]-1/hypo_depth[:,1,:,:]).abs()  # B H W\n            mask_out_of_range = 
((1/hypo_depth - 1/depth_gt.unsqueeze(1)).abs() <= depth_itv.unsqueeze(1)).sum(1) == 0 # B H W\n        else:\n            depth_itv = (hypo_depth[:,2,:,:]-hypo_depth[:,1,:,:]).abs()  # B H W\n            mask_out_of_range = ((hypo_depth - depth_gt.unsqueeze(1)).abs() <= depth_itv.unsqueeze(1)).sum(1) == 0 # B H W\n        range_err_ratio.append(mask_out_of_range[mask].float().mean())\n\n        this_stage_ot_loss = sinkhorn(depth_gt, hypo_depth, attn_weight, mask, iters=ot_iter, eps=ot_eps, continuous=ot_continous)[1]\n\n        stage_l1_loss.append(this_stage_l1_loss)\n        stage_ot_loss.append(this_stage_ot_loss)\n        total_loss = total_loss + stage_lw[stage_idx] * (l1ot_lw[0] * this_stage_l1_loss + l1ot_lw[1] * this_stage_ot_loss)\n\n    return total_loss, stage_l1_loss, stage_ot_loss, range_err_ratio\n\n\ndef Blend_loss(inputs, depth_gt_ms, mask_ms, **kwargs):\n    stage_lw = kwargs.get(\"stage_lw\", [1,1,1,1])\n    l1ot_lw = kwargs.get(\"l1ot_lw\", [0,1])\n    inverse = kwargs.get(\"inverse_depth\", False)\n    ot_iter = kwargs.get(\"ot_iter\", 3)\n    ot_eps = kwargs.get(\"ot_eps\", 1)\n    ot_continous = kwargs.get(\"ot_continous\", False)\n    depth_max = kwargs.get(\"depth_max\", 100)\n    depth_min = kwargs.get(\"depth_min\", 1)\n    mono = kwargs.get(\"mono\", False)\n    total_loss = torch.tensor(0.0, dtype=torch.float32, device=mask_ms[\"stage1\"].device, requires_grad=False)\n    stage_ot_loss = []\n    stage_l1_loss = []\n    range_err_ratio = []\n    for stage_idx, (stage_inputs, stage_key) in enumerate([(inputs[k], k) for k in inputs.keys() if \"stage\" in k]):\n        depth_pred = stage_inputs['depth']\n        hypo_depth = stage_inputs['hypo_depth']\n        attn_weight = stage_inputs['attn_weight']\n        B,H,W = depth_pred.shape\n        mask = mask_ms[stage_key]\n        mask = mask > 0.5\n        depth_gt = depth_gt_ms[stage_key]\n        depth_pred_norm = depth_pred * 128 / (depth_max - depth_min)[:,None,None]  # B H W\n        depth_gt_norm = depth_gt * 128 / (depth_max - depth_min)[:,None,None]  # B H W\n\n        if mono and stage_idx!=0:\n            this_stage_l1_loss = F.l1_loss(stage_inputs['mono_depth'][mask], depth_gt[mask], reduction='mean')\n        else:\n            this_stage_l1_loss = torch.tensor(0.0, dtype=torch.float32, device=mask_ms[\"stage1\"].device, requires_grad=False)\n\n        if inverse:\n            depth_itv = (1/hypo_depth[:,2,:,:]-1/hypo_depth[:,1,:,:]).abs()  # B H W\n            mask_out_of_range = ((1/hypo_depth - 1/depth_gt.unsqueeze(1)).abs() <= depth_itv.unsqueeze(1)).sum(1) == 0 # B H W\n        else:\n            depth_itv = (hypo_depth[:,2,:,:]-hypo_depth[:,1,:,:]).abs()  # B H W\n            mask_out_of_range = ((hypo_depth - depth_gt.unsqueeze(1)).abs() <= depth_itv.unsqueeze(1)).sum(1) == 0 # B H W\n        range_err_ratio.append(mask_out_of_range[mask].float().mean())\n\n        this_stage_ot_loss = sinkhorn(depth_gt, hypo_depth, attn_weight, mask, iters=ot_iter, eps=ot_eps, continuous=ot_continous)[1]\n\n        stage_l1_loss.append(this_stage_l1_loss)\n        stage_ot_loss.append(this_stage_ot_loss)\n        total_loss = total_loss + stage_lw[stage_idx] * (l1ot_lw[0] * this_stage_l1_loss + l1ot_lw[1] * this_stage_ot_loss)\n\n    abs_err = torch.abs(depth_pred_norm[mask] - depth_gt_norm[mask])\n    epe = abs_err.mean()\n    err3 = (abs_err<=3).float().mean()*100\n    err1= (abs_err<=1).float().mean()*100\n    return total_loss, stage_l1_loss, stage_ot_loss, range_err_ratio, epe, err3, err1"
  },
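For orientation: in MVS4net.forward, stage 1 spreads stage_splits[0] depth hypotheses over the full [depth_min, depth_max] range (uniformly, or uniformly in inverse depth), and every later stage re-centres fewer, finer hypotheses on the previous stage's prediction with spacing depth_interals_ratio[stage] * depth_interval. The actual init_range/schedule_range live in models/mvs4net_utils.py; the sketch below only illustrates the shapes and the scheduling idea (simplified: scalar interval, no range clamping):

import torch
import torch.nn.functional as F

def init_range_sketch(depth_values, ndepth, H, W):
    """Uniform hypotheses in [depth_min, depth_max]; depth_values is (B, D)."""
    dmin, dmax = depth_values[:, 0], depth_values[:, -1]        # (B,)
    t = torch.linspace(0, 1, ndepth, device=depth_values.device)
    hypo = dmin[:, None] + t[None] * (dmax - dmin)[:, None]     # (B, ndepth)
    return hypo[:, :, None, None].expand(-1, -1, H, W)          # B D H W

def schedule_range_sketch(prev_depth, ndepth, interval, H, W):
    """Re-centre ndepth hypotheses on the previous prediction (B h w),
    spaced by `interval`, then upsample to the current stage size."""
    offsets = (torch.arange(ndepth, device=prev_depth.device) - ndepth // 2) * interval
    hypo = prev_depth[:, None] + offsets[None, :, None, None]   # B D h w
    return F.interpolate(hypo, size=(H, W), mode='bilinear', align_corners=False)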
  {
    "path": "models/__init__.py",
    "content": "\nfrom models.MVS4Net import MVS4net, MVS4net_loss, Blend_loss"
  },
  {
    "path": "models/module.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport time\nimport sys\nimport seaborn as sns\nimport numpy as np\nimport matplotlib.pyplot as plt\nsys.path.append(\"..\")\nfrom utils import local_pcd\nfrom modules.deform_conv import DeformConvPack\n\n\ndef init_bn(module):\n    if module.weight is not None:\n        nn.init.ones_(module.weight)\n    if module.bias is not None:\n        nn.init.zeros_(module.bias)\n    return\n\n\ndef init_uniform(module, init_method):\n    if module.weight is not None:\n        if init_method == \"kaiming\":\n            nn.init.kaiming_uniform_(module.weight)\n        elif init_method == \"xavier\":\n            nn.init.xavier_uniform_(module.weight)\n    return\n\nclass Conv2d(nn.Module):\n    \"\"\"Applies a 2D convolution (optionally with batch normalization and relu activation)\n    over an input signal composed of several input planes.\n\n    Attributes:\n        conv (nn.Module): convolution module\n        bn (nn.Module): batch normalization module\n        relu (bool): whether to activate by relu\n\n    Notes:\n        Default momentum for batch normalization is set to be 0.01,\n\n    \"\"\"\n\n    def __init__(self, in_channels, out_channels, kernel_size, stride=1,\n                 relu=True, bn=True, bn_momentum=0.1, init_method=\"xavier\", **kwargs):\n        super(Conv2d, self).__init__()\n\n        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride,\n                              bias=(not bn), **kwargs)\n        self.kernel_size = kernel_size\n        self.stride = stride\n        self.bn = nn.BatchNorm2d(out_channels, momentum=bn_momentum) if bn else None\n        self.relu = relu\n\n        # assert init_method in [\"kaiming\", \"xavier\"]\n        # self.init_weights(init_method)\n\n    def forward(self, x):\n        x = self.conv(x)\n        if self.bn is not None:\n            x = self.bn(x)\n        if self.relu:\n            x = F.relu(x, inplace=True)\n        return x\n\n    def init_weights(self, init_method):\n        \"\"\"default initialization\"\"\"\n        init_uniform(self.conv, init_method)\n        if self.bn is not None:\n            init_bn(self.bn)\n\nclass DCNConv2d(nn.Module):\n\n    def __init__(self, in_channels, out_channels, kernel_size, stride=1,\n                 relu=True, bn=True, bn_momentum=0.1, init_method=\"xavier\", **kwargs):\n        super(DCNConv2d, self).__init__()\n\n        self.conv = DeformConvPack(in_channels, out_channels, kernel_size, stride=stride, padding=1, bias=(not bn), im2col_step=16)\n        self.kernel_size = kernel_size\n        self.stride = stride\n        self.bn = nn.BatchNorm2d(out_channels, momentum=bn_momentum) if bn else None\n        self.relu = relu\n\n        # assert init_method in [\"kaiming\", \"xavier\"]\n        # self.init_weights(init_method)\n\n    def forward(self, x):\n        x = self.conv(x)\n        if self.bn is not None:\n            x = self.bn(x)\n        if self.relu:\n            x = F.relu(x, inplace=True)\n        return x\n\n    def init_weights(self, init_method):\n        \"\"\"default initialization\"\"\"\n        init_uniform(self.conv, init_method)\n        if self.bn is not None:\n            init_bn(self.bn)\n\nclass Deconv2d(nn.Module):\n    \"\"\"Applies a 2D deconvolution (optionally with batch normalization and relu activation)\n       over an input signal composed of several input planes.\n\n       Attributes:\n           conv (nn.Module): convolution module\n           bn 
(nn.Module): batch normalization module\n           relu (bool): whether to activate by relu\n\n       Notes:\n           Default momentum for batch normalization is set to be 0.01,\n\n       \"\"\"\n\n    def __init__(self, in_channels, out_channels, kernel_size, stride=1,\n                 relu=True, bn=True, bn_momentum=0.1, init_method=\"xavier\", **kwargs):\n        super(Deconv2d, self).__init__()\n        self.out_channels = out_channels\n        assert stride in [1, 2]\n        self.stride = stride\n\n        self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride,\n                                       bias=(not bn), **kwargs)\n        self.bn = nn.BatchNorm2d(out_channels, momentum=bn_momentum) if bn else None\n        self.relu = relu\n\n        # assert init_method in [\"kaiming\", \"xavier\"]\n        # self.init_weights(init_method)\n\n    def forward(self, x):\n        y = self.conv(x)\n        if self.stride == 2:\n            h, w = list(x.size())[2:]\n            y = y[:, :, :2 * h, :2 * w].contiguous()\n        if self.bn is not None:\n            x = self.bn(y)\n        if self.relu:\n            x = F.relu(x, inplace=True)\n        return x\n\n    def init_weights(self, init_method):\n        \"\"\"default initialization\"\"\"\n        init_uniform(self.conv, init_method)\n        if self.bn is not None:\n            init_bn(self.bn)\n\nclass Conv3d(nn.Module):\n    \"\"\"Applies a 3D convolution (optionally with batch normalization and relu activation)\n    over an input signal composed of several input planes.\n\n    Attributes:\n        conv (nn.Module): convolution module\n        bn (nn.Module): batch normalization module\n        relu (bool): whether to activate by relu\n\n    Notes:\n        Default momentum for batch normalization is set to be 0.01,\n\n    \"\"\"\n\n    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,\n                 relu=True, bn=True, bn_momentum=0.1, init_method=\"xavier\", **kwargs):\n        super(Conv3d, self).__init__()\n        self.out_channels = out_channels\n        self.kernel_size = kernel_size\n        assert stride in [1, 2]\n        self.stride = stride\n\n        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride,\n                              bias=(not bn), **kwargs)\n        self.bn = nn.BatchNorm3d(out_channels, momentum=bn_momentum) if bn else None\n        self.relu = relu\n\n        # assert init_method in [\"kaiming\", \"xavier\"]\n        # self.init_weights(init_method)\n\n    def forward(self, x):\n        x = self.conv(x)\n        if self.bn is not None:\n            x = self.bn(x)\n        if self.relu:\n            x = F.relu(x, inplace=True)\n        return x\n\n    def init_weights(self, init_method):\n        \"\"\"default initialization\"\"\"\n        init_uniform(self.conv, init_method)\n        if self.bn is not None:\n            init_bn(self.bn)\n\nclass PConv3d(nn.Module):\n\n    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,\n                 relu=True, bn=True, bn_momentum=0.1, padding=1, init_method=\"xavier\", **kwargs):\n        super(PConv3d, self).__init__()\n        self.out_channels = out_channels\n        self.kernel_size_xy = (1, kernel_size, kernel_size)\n        self.kernel_size_d = (kernel_size, 1, 1)\n        assert stride in [1, 2]\n        self.stride_xy = (1, stride, stride)\n        self.stride_d = (stride, 1, 1)\n        self.padding_xy = (0, padding, padding)\n        self.padding_d 
= (padding, 0, 0)\n\n        self.convxy = nn.Conv3d(in_channels, in_channels, self.kernel_size_xy, stride=self.stride_xy, padding=self.padding_xy, bias=(not bn), **kwargs)\n        self.convd = nn.Conv3d(in_channels, out_channels, self.kernel_size_d, stride=self.stride_d, padding=self.padding_d, bias=(not bn), **kwargs)\n        self.bn = nn.BatchNorm3d(out_channels, momentum=bn_momentum) if bn else None\n        self.relu = relu\n\n        # assert init_method in [\"kaiming\", \"xavier\"]\n        # self.init_weights(init_method)\n\n    def forward(self, x):\n        x = self.convxy(x)\n        x = self.convd(x)\n        if self.bn is not None:\n            x = self.bn(x)\n        if self.relu:\n            x = F.relu(x, inplace=True)\n        return x\n\n    def init_weights(self, init_method):\n        \"\"\"default initialization\"\"\"\n        init_uniform(self.convxy, init_method)\n        init_uniform(self.convd, init_method)\n        if self.bn is not None:\n            init_bn(self.bn)\n\n\nclass Deconv3d(nn.Module):\n    \"\"\"Applies a 3D deconvolution (optionally with batch normalization and relu activation)\n       over an input signal composed of several input planes.\n\n       Attributes:\n           conv (nn.Module): convolution module\n           bn (nn.Module): batch normalization module\n           relu (bool): whether to activate by relu\n\n       Notes:\n           Default momentum for batch normalization is set to be 0.01,\n\n       \"\"\"\n\n    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,\n                 relu=True, bn=True, bn_momentum=0.1, init_method=\"xavier\", **kwargs):\n        super(Deconv3d, self).__init__()\n        self.out_channels = out_channels\n        assert stride in [1, 2]\n        self.stride = stride\n\n        self.conv = nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride=stride,\n                                       bias=(not bn), **kwargs)\n        self.bn = nn.BatchNorm3d(out_channels, momentum=bn_momentum) if bn else None\n        self.relu = relu\n\n        # assert init_method in [\"kaiming\", \"xavier\"]\n        # self.init_weights(init_method)\n\n    def forward(self, x):\n        y = self.conv(x)\n        if self.bn is not None:\n            x = self.bn(y)\n        if self.relu:\n            x = F.relu(x, inplace=True)\n        return x\n\n    def init_weights(self, init_method):\n        \"\"\"default initialization\"\"\"\n        init_uniform(self.conv, init_method)\n        if self.bn is not None:\n            init_bn(self.bn)\n\n\nclass PDeconv3d(nn.Module):\n\n    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1,output_padding=1,\n                 relu=True, bn=True, bn_momentum=0.1, init_method=\"xavier\", **kwargs):\n        super(PDeconv3d, self).__init__()\n        self.out_channels = out_channels\n        assert stride in [1, 2]\n        self.stride = stride\n        self.kernel_size_xy = (1, kernel_size,kernel_size)\n        self.kernel_size_d = (kernel_size, 1,1)\n        self.stride_xy = (1, stride, stride)\n        self.stride_d = (stride, 1, 1)\n        self.padding_xy = (0, padding, padding)\n        self.padding_d = (padding, 0, 0)\n        self.outpadding_xy = (0, output_padding, output_padding)\n        self.outpadding_d = (output_padding, 0, 0)\n        self.convxy = nn.ConvTranspose3d(in_channels, in_channels, self.kernel_size_xy, stride=self.stride_xy, padding=self.padding_xy, output_padding=self.outpadding_xy, bias=(not bn))\n        
self.convd = nn.ConvTranspose3d(in_channels, out_channels, self.kernel_size_d, stride=self.stride_d, padding=self.padding_d, output_padding=self.outpadding_d, bias=(not bn))\n        self.bn = nn.BatchNorm3d(out_channels, momentum=bn_momentum) if bn else None\n        self.relu = relu\n\n        # assert init_method in [\"kaiming\", \"xavier\"]\n        # self.init_weights(init_method)\n\n    def forward(self, x):\n        x = self.convxy(x)\n        y = self.convd(x)\n        if self.bn is not None:\n            x = self.bn(y)\n        if self.relu:\n            x = F.relu(x, inplace=True)\n        return x\n\n    def init_weights(self, init_method):\n        \"\"\"default initialization\"\"\"\n        init_uniform(self.convxy, init_method)\n        init_uniform(self.convd, init_method)\n        if self.bn is not None:\n            init_bn(self.bn)\n\nclass ConvBnReLU(nn.Module):\n    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1):\n        super(ConvBnReLU, self).__init__()\n        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False)\n        self.bn = nn.BatchNorm2d(out_channels)\n\n    def forward(self, x):\n        return F.relu(self.bn(self.conv(x)), inplace=True)\n\nclass ConvBn(nn.Module):\n    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1):\n        super(ConvBn, self).__init__()\n        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False)\n        self.bn = nn.BatchNorm2d(out_channels)\n\n    def forward(self, x):\n        return self.bn(self.conv(x))\n\nclass ConvBnReLU3D(nn.Module):\n    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1):\n        super(ConvBnReLU3D, self).__init__()\n        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False)\n        self.bn = nn.BatchNorm3d(out_channels)\n\n    def forward(self, x):\n        return F.relu(self.bn(self.conv(x)), inplace=True)\n\n\nclass ConvBn3D(nn.Module):\n    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1):\n        super(ConvBn3D, self).__init__()\n        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False)\n        self.bn = nn.BatchNorm3d(out_channels)\n\n    def forward(self, x):\n        return self.bn(self.conv(x))\n\n\nclass BasicBlock(nn.Module):\n    def __init__(self, in_channels, out_channels, stride, downsample=None):\n        super(BasicBlock, self).__init__()\n\n        self.conv1 = ConvBnReLU(in_channels, out_channels, kernel_size=3, stride=stride, pad=1)\n        self.conv2 = ConvBn(out_channels, out_channels, kernel_size=3, stride=1, pad=1)\n\n        self.downsample = downsample\n        self.stride = stride\n\n    def forward(self, x):\n        out = self.conv1(x)\n        out = self.conv2(out)\n        if self.downsample is not None:\n            x = self.downsample(x)\n        out += x\n        return out\n\n\nclass Hourglass3d(nn.Module):\n    def __init__(self, channels):\n        super(Hourglass3d, self).__init__()\n\n        self.conv1a = ConvBnReLU3D(channels, channels * 2, kernel_size=3, stride=2, pad=1)\n        self.conv1b = ConvBnReLU3D(channels * 2, channels * 2, kernel_size=3, stride=1, pad=1)\n\n        self.conv2a = ConvBnReLU3D(channels * 2, channels * 4, kernel_size=3, stride=2, pad=1)\n        self.conv2b = ConvBnReLU3D(channels * 4, channels * 4, kernel_size=3, stride=1, 
pad=1)\n\n        self.dconv2 = nn.Sequential(\n            nn.ConvTranspose3d(channels * 4, channels * 2, kernel_size=3, padding=1, output_padding=1, stride=2,\n                               bias=False),\n            nn.BatchNorm3d(channels * 2))\n\n        self.dconv1 = nn.Sequential(\n            nn.ConvTranspose3d(channels * 2, channels, kernel_size=3, padding=1, output_padding=1, stride=2,\n                               bias=False),\n            nn.BatchNorm3d(channels))\n\n        self.redir1 = ConvBn3D(channels, channels, kernel_size=1, stride=1, pad=0)\n        self.redir2 = ConvBn3D(channels * 2, channels * 2, kernel_size=1, stride=1, pad=0)\n\n    def forward(self, x):\n        conv1 = self.conv1b(self.conv1a(x))\n        conv2 = self.conv2b(self.conv2a(conv1))\n        dconv2 = F.relu(self.dconv2(conv2) + self.redir2(conv1), inplace=True)\n        dconv1 = F.relu(self.dconv1(dconv2) + self.redir1(x), inplace=True)\n        return dconv1\n\n\ndef homo_warping(src_fea, src_proj, ref_proj, depth_values, align_corners=False):\n    # src_fea: [B, C, H, W]\n    # src_proj: [B, 4, 4]\n    # ref_proj: [B, 4, 4]\n    # depth_values: [B, Ndepth] o [B, Ndepth, H, W]\n    # out: [B, C, Ndepth, H, W]\n    batch, channels = src_fea.shape[0], src_fea.shape[1]\n    num_depth = depth_values.shape[1]\n    height, width = src_fea.shape[2], src_fea.shape[3]\n\n    with torch.no_grad():\n        proj = torch.matmul(src_proj, torch.inverse(ref_proj))\n        rot = proj[:, :3, :3]  # [B,3,3]\n        trans = proj[:, :3, 3:4]  # [B,3,1]\n\n        y, x = torch.meshgrid([torch.arange(0, height, dtype=torch.float32, device=src_fea.device),\n                               torch.arange(0, width, dtype=torch.float32, device=src_fea.device)])\n        y, x = y.contiguous(), x.contiguous()\n        y, x = y.view(height * width), x.view(height * width)\n        xyz = torch.stack((x, y, torch.ones_like(x)))  # [3, H*W]\n        xyz = torch.unsqueeze(xyz, 0).repeat(batch, 1, 1)  # [B, 3, H*W]\n        rot_xyz = torch.matmul(rot, xyz)  # [B, 3, H*W]\n        rot_depth_xyz = rot_xyz.unsqueeze(2).repeat(1, 1, num_depth, 1) * depth_values.view(batch, 1, num_depth, -1)  # [B, 3, Ndepth, H*W]\n        proj_xyz = rot_depth_xyz + trans.view(batch, 3, 1, 1)  # [B, 3, Ndepth, H*W]\n        proj_xy = proj_xyz[:, :2, :, :] / proj_xyz[:, 2:3, :, :]  # [B, 2, Ndepth, H*W]\n        proj_x_normalized = proj_xy[:, 0, :, :] / ((width - 1) / 2) - 1\n        proj_y_normalized = proj_xy[:, 1, :, :] / ((height - 1) / 2) - 1\n        proj_xy = torch.stack((proj_x_normalized, proj_y_normalized), dim=3)  # [B, Ndepth, H*W, 2]\n        grid = proj_xy\n\n    warped_src_fea = F.grid_sample(src_fea, grid.view(batch, num_depth * height, width, 2), mode='bilinear', padding_mode='zeros', align_corners=align_corners)\n    warped_src_fea = warped_src_fea.view(batch, channels, num_depth, height, width)\n\n    return warped_src_fea\n\nclass DeConv2dFuse(nn.Module):\n    def __init__(self, in_channels, out_channels, kernel_size, relu=True, bn=True,\n                 bn_momentum=0.1):\n        super(DeConv2dFuse, self).__init__()\n\n        self.deconv = Deconv2d(in_channels, out_channels, kernel_size, stride=2, padding=1, output_padding=1,\n                               bn=True, relu=relu, bn_momentum=bn_momentum)\n\n        self.conv = Conv2d(2*out_channels, out_channels, kernel_size, stride=1, padding=1,\n                           bn=bn, relu=relu, bn_momentum=bn_momentum)\n\n        # assert init_method in [\"kaiming\", \"xavier\"]\n        
# self.init_weights(init_method)\n\n    def forward(self, x_pre, x):\n        x = self.deconv(x)\n        x = torch.cat((x, x_pre), dim=1)\n        x = self.conv(x)\n        return x\n\n\nclass FeatureNet(nn.Module):\n    def __init__(self, base_channels, num_stage=3, stride=4, arch_mode=\"unet\"):\n        super(FeatureNet, self).__init__()\n        assert arch_mode in [\"unet\", \"fpn\"], \"arch_mode must be 'unet' or 'fpn', but got: {}\".format(arch_mode)\n        print(\"*************feature extraction arch mode:{}****************\".format(arch_mode))\n        self.arch_mode = arch_mode\n        self.stride = stride\n        self.base_channels = base_channels\n        self.num_stage = num_stage\n\n        self.conv0 = nn.Sequential(\n            Conv2d(3, base_channels, 3, 1, padding=1),\n            Conv2d(base_channels, base_channels, 3, 1, padding=1),\n        )\n\n        self.conv1 = nn.Sequential(\n            Conv2d(base_channels, base_channels * 2, 5, stride=2, padding=2),\n            Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1),\n            Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1),\n        )\n\n        self.conv2 = nn.Sequential(\n            Conv2d(base_channels * 2, base_channels * 4, 5, stride=2, padding=2),\n            Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1),\n            Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1),\n        )\n\n        self.out1 = nn.Conv2d(base_channels * 4, base_channels * 4, 1, bias=False)\n        self.out_channels = [4 * base_channels]\n\n        if self.arch_mode == 'unet':\n            if num_stage == 3:\n                self.deconv1 = DeConv2dFuse(base_channels * 4, base_channels * 2, 3)\n                self.deconv2 = DeConv2dFuse(base_channels * 2, base_channels, 3)\n\n                self.out2 = nn.Conv2d(base_channels * 2, base_channels * 2, 1, bias=False)\n                self.out3 = nn.Conv2d(base_channels, base_channels, 1, bias=False)\n                self.out_channels.append(2 * base_channels)\n                self.out_channels.append(base_channels)\n\n            elif num_stage == 2:\n                self.deconv1 = DeConv2dFuse(base_channels * 4, base_channels * 2, 3)\n\n                self.out2 = nn.Conv2d(base_channels * 2, base_channels * 2, 1, bias=False)\n                self.out_channels.append(2 * base_channels)\n        elif self.arch_mode == \"fpn\":\n            final_chs = base_channels * 4\n            if num_stage == 3:\n                self.inner1 = nn.Conv2d(base_channels * 2, final_chs, 1, bias=True)\n                self.inner2 = nn.Conv2d(base_channels * 1, final_chs, 1, bias=True)\n\n                self.out2 = nn.Conv2d(final_chs, base_channels * 2, 3, padding=1, bias=False)\n                self.out3 = nn.Conv2d(final_chs, base_channels, 3, padding=1, bias=False)\n                self.out_channels.append(base_channels * 2)\n                self.out_channels.append(base_channels)\n\n            elif num_stage == 2:\n                self.inner1 = nn.Conv2d(base_channels * 2, final_chs, 1, bias=True)\n\n                self.out2 = nn.Conv2d(final_chs, base_channels, 3, padding=1, bias=False)\n                self.out_channels.append(base_channels)\n\n    def forward(self, x):\n        conv0 = self.conv0(x)\n        conv1 = self.conv1(conv0)\n        conv2 = self.conv2(conv1)\n\n        intra_feat = conv2\n        outputs = {}\n        out = self.out1(intra_feat)\n        outputs[\"stage1\"] = out\n        
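# decoder: 'unet' fuses skip features with DeConv2dFuse; 'fpn' adds 1x1 laterals\n        # to nearest-upsampled coarse features\n        if self.arch_mode == 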
\"unet\":\n            if self.num_stage == 3:\n                intra_feat = self.deconv1(conv1, intra_feat)\n                out = self.out2(intra_feat)\n                outputs[\"stage2\"] = out\n\n                intra_feat = self.deconv2(conv0, intra_feat)\n                out = self.out3(intra_feat)\n                outputs[\"stage3\"] = out\n\n            elif self.num_stage == 2:\n                intra_feat = self.deconv1(conv1, intra_feat)\n                out = self.out2(intra_feat)\n                outputs[\"stage2\"] = out\n\n        elif self.arch_mode == \"fpn\":\n            if self.num_stage == 3:\n                intra_feat = F.interpolate(intra_feat, scale_factor=2, mode=\"nearest\") + self.inner1(conv1)\n                out = self.out2(intra_feat)\n                outputs[\"stage2\"] = out\n\n                intra_feat = F.interpolate(intra_feat, scale_factor=2, mode=\"nearest\") + self.inner2(conv0)\n                out = self.out3(intra_feat)\n                outputs[\"stage3\"] = out\n\n            elif self.num_stage == 2:\n                intra_feat = F.interpolate(intra_feat, scale_factor=2, mode=\"nearest\") + self.inner1(conv1)\n                out = self.out2(intra_feat)\n                outputs[\"stage2\"] = out\n\n        return outputs\n\nclass FPNDCNpath(nn.Module):\n    \"\"\"\n    FPN+DCN pathway\"\"\"\n    def __init__(self, base_channels, stride=4):\n        super(FPNDCNpath, self).__init__()\n        self.stride = stride\n        self.base_channels = base_channels\n\n        self.conv0 = nn.Sequential(\n            Conv2d(3, base_channels, 3, 1, padding=1),\n            Conv2d(base_channels, base_channels, 3, 1, padding=1),\n        )\n\n        self.conv1 = nn.Sequential(\n            Conv2d(base_channels, base_channels * 2, 5, stride=2, padding=2),\n            Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1),\n            Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1),\n        )\n\n        self.conv2 = nn.Sequential(\n            Conv2d(base_channels * 2, base_channels * 4, 5, stride=2, padding=2),\n            Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1),\n            Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1),\n        )\n\n        self.out1 = nn.Sequential(\n            DCNConv2d(base_channels * 4, base_channels * 4, 3,  stride=1, padding=1),\n            DCNConv2d(base_channels * 4, base_channels * 4, 3,  stride=1, padding=1),\n            DeformConvPack(base_channels * 4, base_channels * 4, 3,  stride=1, padding=1, bias=False, im2col_step=16)\n        )\n        self.out_channels = [4 * base_channels]\n\n        final_chs = base_channels * 4\n        self.inner1 = nn.Conv2d(base_channels * 2, final_chs, 1, bias=True)\n        self.inner2 = nn.Conv2d(base_channels * 1, final_chs, 1, bias=True)\n\n        self.out2 = nn.Sequential(\n            DCNConv2d(base_channels * 4, base_channels * 2, 3,  stride=1, padding=1),\n            DCNConv2d(base_channels * 2, base_channels * 2, 3,  stride=1, padding=1),\n            DeformConvPack(base_channels * 2, base_channels * 2, 3,  stride=1, padding=1, bias=False, im2col_step=16)\n        )\n        self.out2pathconv = nn.Conv2d(base_channels * 4, base_channels * 2, 3,  stride=1, padding=1)\n        self.out3 = nn.Sequential(\n            DCNConv2d(base_channels * 4, base_channels * 1, 3,  stride=1, padding=1),\n            DCNConv2d(base_channels * 1, base_channels * 1, 3,  stride=1, padding=1),\n            DeformConvPack(base_channels * 1, 
base_channels * 1, 3,  stride=1, padding=1, bias=False, im2col_step=16)\n        )\n        self.out3pathconv = nn.Conv2d(base_channels * 2, base_channels * 1, 3,  stride=1, padding=1)\n        self.out_channels.append(base_channels * 2)\n        self.out_channels.append(base_channels)\n\n    def forward(self, x):\n        conv0 = self.conv0(x)\n        conv1 = self.conv1(conv0)\n        conv2 = self.conv2(conv1)\n\n        intra_feat = conv2\n        outputs = {}\n        out1 = self.out1(intra_feat)\n\n        intra_feat = F.interpolate(intra_feat, scale_factor=2, mode=\"bilinear\", align_corners=True) + self.inner1(conv1)\n        out2 = self.out2(intra_feat)\n        out2 = out2 + self.out2pathconv(F.interpolate(out1, scale_factor=2, mode=\"bilinear\", align_corners=True))\n\n        intra_feat = F.interpolate(intra_feat, scale_factor=2, mode=\"bilinear\", align_corners=True) + self.inner2(conv0)\n        out3 = self.out3(intra_feat)\n        out3 = out3 + self.out3pathconv(F.interpolate(out2, scale_factor=2, mode=\"bilinear\", align_corners=True))\n\n        outputs[\"stage1\"] = out1\n        outputs[\"stage2\"] = out2\n        outputs[\"stage3\"] = out3\n\n        return outputs\n\nclass FPNDCN(nn.Module):\n    \"\"\"\n    FPN+DCN\"\"\"\n    def __init__(self, base_channels, stride=4):\n        super(FPNDCN, self).__init__()\n        self.stride = stride\n        self.base_channels = base_channels\n\n        self.conv0 = nn.Sequential(\n            Conv2d(3, base_channels, 3, 1, padding=1),\n            Conv2d(base_channels, base_channels, 3, 1, padding=1),\n        )\n\n        self.conv1 = nn.Sequential(\n            Conv2d(base_channels, base_channels * 2, 5, stride=2, padding=2),\n            Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1),\n            Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1),\n        )\n\n        self.conv2 = nn.Sequential(\n            Conv2d(base_channels * 2, base_channels * 4, 5, stride=2, padding=2),\n            Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1),\n            Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1),\n        )\n\n        self.out1 = nn.Sequential(\n            DCNConv2d(base_channels * 4, base_channels * 4, 3,  stride=1, padding=1),\n            DCNConv2d(base_channels * 4, base_channels * 4, 3,  stride=1, padding=1),\n            DeformConvPack(base_channels * 4, base_channels * 4, 3,  stride=1, padding=1, bias=False, im2col_step=16)\n        )\n        self.out_channels = [4 * base_channels]\n\n        final_chs = base_channels * 4\n        self.inner1 = nn.Conv2d(base_channels * 2, final_chs, 1, bias=True)\n        self.inner2 = nn.Conv2d(base_channels * 1, final_chs, 1, bias=True)\n\n        self.out2 = nn.Sequential(\n            DCNConv2d(base_channels * 4, base_channels * 2, 3,  stride=1, padding=1),\n            DCNConv2d(base_channels * 2, base_channels * 2, 3,  stride=1, padding=1),\n            DeformConvPack(base_channels * 2, base_channels * 2, 3,  stride=1, padding=1, bias=False, im2col_step=16)\n        )\n        self.out3 = nn.Sequential(\n            DCNConv2d(base_channels * 4, base_channels * 1, 3,  stride=1, padding=1),\n            DCNConv2d(base_channels * 1, base_channels * 1, 3,  stride=1, padding=1),\n            DeformConvPack(base_channels * 1, base_channels * 1, 3,  stride=1, padding=1, bias=False, im2col_step=16)\n        )\n        self.out_channels.append(base_channels * 2)\n        self.out_channels.append(base_channels)\n\n    def 
forward(self, x):\n        conv0 = self.conv0(x)\n        conv1 = self.conv1(conv0)\n        conv2 = self.conv2(conv1)\n\n        intra_feat = conv2\n        outputs = {}\n        out1 = self.out1(intra_feat)\n\n        intra_feat = F.interpolate(intra_feat, scale_factor=2, mode=\"bilinear\", align_corners=True) + self.inner1(conv1)\n        out2 = self.out2(intra_feat)\n\n        intra_feat = F.interpolate(intra_feat, scale_factor=2, mode=\"bilinear\", align_corners=True) + self.inner2(conv0)\n        out3 = self.out3(intra_feat)\n\n        outputs[\"stage1\"] = out1\n        outputs[\"stage2\"] = out2\n        outputs[\"stage3\"] = out3\n\n        return outputs\n\nclass FPNA(nn.Module):\n    \"\"\"\n    FPN aligncorners\"\"\"\n    def __init__(self, base_channels, stride=4):\n        super(FPNA, self).__init__()\n        self.stride = stride\n        self.base_channels = base_channels\n\n        self.conv0 = nn.Sequential(\n            Conv2d(3, base_channels, 3, 1, padding=1),\n            Conv2d(base_channels, base_channels, 3, 1, padding=1),\n        )\n\n        self.conv1 = nn.Sequential(\n            Conv2d(base_channels, base_channels * 2, 5, stride=2, padding=2),\n            Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1),\n            Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1),\n        )\n\n        self.conv2 = nn.Sequential(\n            Conv2d(base_channels * 2, base_channels * 4, 5, stride=2, padding=2),\n            Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1),\n            Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1),\n        )\n\n        self.out1 = nn.Conv2d(base_channels * 4, base_channels * 4, 1, bias=False)\n        self.out_channels = [4 * base_channels]\n\n        final_chs = base_channels * 4\n        self.inner1 = nn.Conv2d(base_channels * 2, final_chs, 1, bias=True)\n        self.inner2 = nn.Conv2d(base_channels * 1, final_chs, 1, bias=True)\n\n        self.out2 = nn.Conv2d(final_chs, base_channels * 2, 3, padding=1, bias=False)\n        self.out3 = nn.Conv2d(final_chs, base_channels, 3, padding=1, bias=False)\n\n        self.out_channels.append(base_channels * 2)\n        self.out_channels.append(base_channels)\n\n    def forward(self, x):\n        conv0 = self.conv0(x)\n        conv1 = self.conv1(conv0)\n        conv2 = self.conv2(conv1)\n\n        intra_feat = conv2\n        outputs = {}\n        out1 = self.out1(intra_feat)\n\n        intra_feat = F.interpolate(intra_feat, scale_factor=2, mode=\"bilinear\", align_corners=True) + self.inner1(conv1)\n        out2 = self.out2(intra_feat)\n\n        intra_feat = F.interpolate(intra_feat, scale_factor=2, mode=\"bilinear\", align_corners=True) + self.inner2(conv0)\n        out3 = self.out3(intra_feat)\n\n        outputs[\"stage1\"] = out1\n        outputs[\"stage2\"] = out2\n        outputs[\"stage3\"] = out3\n\n        return outputs\n\nclass FPNA4(nn.Module):\n    \"\"\"\n    FPN aligncorners downsample 4x\"\"\"\n    def __init__(self, base_channels):\n        super(FPNA4, self).__init__()\n        self.base_channels = base_channels\n\n        self.conv0 = nn.Sequential(\n            Conv2d(3, base_channels, 3, 1, padding=1),\n            Conv2d(base_channels, base_channels, 3, 1, padding=1),\n        )\n\n        self.conv1 = nn.Sequential(\n            Conv2d(base_channels, base_channels * 2, 5, stride=2, padding=2),\n            Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1),\n            Conv2d(base_channels * 2, 
base_channels * 2, 3, 1, padding=1),\n        )\n\n        self.conv2 = nn.Sequential(\n            Conv2d(base_channels * 2, base_channels * 4, 5, stride=2, padding=2),\n            Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1),\n            Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1),\n        )\n\n        self.conv3 = nn.Sequential(\n            Conv2d(base_channels * 4, base_channels * 8, 5, stride=2, padding=2),\n            Conv2d(base_channels * 8, base_channels * 8, 3, 1, padding=1),\n            Conv2d(base_channels * 8, base_channels * 8, 3, 1, padding=1),\n        )\n\n        self.out_channels = [8 * base_channels]\n        final_chs = base_channels * 8\n\n        self.inner1 = nn.Conv2d(base_channels * 4, final_chs, 1, bias=True)\n        self.inner2 = nn.Conv2d(base_channels * 2, final_chs, 1, bias=True)\n        self.inner3 = nn.Conv2d(base_channels * 1, final_chs, 1, bias=True)\n\n        self.out1 = nn.Conv2d(final_chs, base_channels * 8, 1, bias=False)\n        self.out2 = nn.Conv2d(final_chs, base_channels * 4, 3, padding=1, bias=False)\n        self.out3 = nn.Conv2d(final_chs, base_channels * 2, 3, padding=1, bias=False)\n        self.out4 = nn.Conv2d(final_chs, base_channels, 3, padding=1, bias=False)\n\n        self.out_channels.append(base_channels * 4)\n        self.out_channels.append(base_channels * 2)\n        self.out_channels.append(base_channels)\n\n    def forward(self, x):\n        conv0 = self.conv0(x)\n        conv1 = self.conv1(conv0)\n        conv2 = self.conv2(conv1)\n        conv3 = self.conv3(conv2)\n\n        intra_feat = conv3\n        outputs = {}\n        out1 = self.out1(intra_feat)\n\n        intra_feat = F.interpolate(intra_feat, scale_factor=2, mode=\"bilinear\", align_corners=True) + self.inner1(conv2)\n        out2 = self.out2(intra_feat)\n\n        intra_feat = F.interpolate(intra_feat, scale_factor=2, mode=\"bilinear\", align_corners=True) + self.inner2(conv1)\n        out3 = self.out3(intra_feat)\n\n        intra_feat = F.interpolate(intra_feat, scale_factor=2, mode=\"bilinear\", align_corners=True) + self.inner3(conv0)\n        out4 = self.out4(intra_feat)\n\n        outputs[\"stage1\"] = out1\n        outputs[\"stage2\"] = out2\n        outputs[\"stage3\"] = out3\n        outputs[\"stage4\"] = out4\n\n        return outputs\n\nclass CostRegNet(nn.Module):\n    def __init__(self, in_channels, base_channels, down_size=3):\n        super(CostRegNet, self).__init__()\n        self.down_size = down_size\n        self.conv0 = Conv3d(in_channels, base_channels, padding=1)\n\n        self.conv1 = Conv3d(base_channels, base_channels * 2, stride=2, padding=1)\n        self.conv2 = Conv3d(base_channels * 2, base_channels * 2, padding=1)\n\n        if down_size >= 2:\n            self.conv3 = Conv3d(base_channels * 2, base_channels * 4, stride=2, padding=1)\n            self.conv4 = Conv3d(base_channels * 4, base_channels * 4, padding=1)\n\n        if down_size >= 3:\n            self.conv5 = Conv3d(base_channels * 4, base_channels * 8, stride=2, padding=1)\n            self.conv6 = Conv3d(base_channels * 8, base_channels * 8, padding=1)\n            self.conv7 = Deconv3d(base_channels * 8, base_channels * 4, stride=2, padding=1, output_padding=1)\n\n        if down_size >= 2:\n            self.conv9 = Deconv3d(base_channels * 4, base_channels * 2, stride=2, padding=1, output_padding=1)\n            \n        self.conv11 = Deconv3d(base_channels * 2, base_channels * 1, stride=2, padding=1, output_padding=1)\n       
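 # 1-channel head producing the volume that is normalized over D into depth probabilities\n        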
self.prob = nn.Conv3d(base_channels, 1, 3, stride=1, padding=1, bias=False)\n\n    def forward(self, x):\n        if self.down_size==3:\n            conv0 = self.conv0(x)\n            conv2 = self.conv2(self.conv1(conv0))\n            conv4 = self.conv4(self.conv3(conv2))\n            x = self.conv6(self.conv5(conv4))\n            x = conv4 + self.conv7(x)\n            x = conv2 + self.conv9(x)\n            x = conv0 + self.conv11(x)\n            x = self.prob(x)\n        elif self.down_size==2:\n            conv0 = self.conv0(x)\n            conv2 = self.conv2(self.conv1(conv0))\n            x = self.conv4(self.conv3(conv2))\n            x = conv2 + self.conv9(x)\n            x = conv0 + self.conv11(x)\n            x = self.prob(x)\n        else:\n            conv0 = self.conv0(x)\n            x = self.conv2(self.conv1(conv0))\n            x = conv0 + self.conv11(x)\n            x = self.prob(x)\n        return x\n\nclass P3DConv(nn.Module):\n    \"\"\"\n    Pseudo 3D conv: 3x3x1 + 1x3x3\n    \"\"\"\n    def __init__(self, in_channels, base_channels):\n        super(P3DConv, self).__init__()\n        self.conv0 = PConv3d(in_channels, base_channels, padding=1)\n\n        self.conv1 = PConv3d(base_channels, base_channels * 2, stride=2, padding=1)\n        self.conv2 = PConv3d(base_channels * 2, base_channels * 2, padding=1)\n\n        self.conv3 = PConv3d(base_channels * 2, base_channels * 4, stride=2, padding=1)\n        self.conv4 = PConv3d(base_channels * 4, base_channels * 4, padding=1)\n\n        self.conv5 = PConv3d(base_channels * 4, base_channels * 8, stride=2, padding=1)\n        self.conv6 = PConv3d(base_channels * 8, base_channels * 8, padding=1)\n\n        self.conv7 = PDeconv3d(base_channels * 8, base_channels * 4, stride=2, padding=1, output_padding=1)\n\n        self.conv9 = PDeconv3d(base_channels * 4, base_channels * 2, stride=2, padding=1, output_padding=1)\n\n        self.conv11 = PDeconv3d(base_channels * 2, base_channels * 1, stride=2, padding=1, output_padding=1)\n\n        self.prob = nn.Conv3d(base_channels, 1, 3, stride=1, padding=1, bias=False)\n\n    def forward(self, x):\n        conv0 = self.conv0(x)\n        conv2 = self.conv2(self.conv1(conv0))\n        conv4 = self.conv4(self.conv3(conv2))\n        x = self.conv6(self.conv5(conv4))\n        x = conv4 + self.conv7(x)\n        x = conv2 + self.conv9(x)\n        x = conv0 + self.conv11(x)\n        x = self.prob(x)\n        return x\n\nclass RefineNet(nn.Module):\n    def __init__(self):\n        super(RefineNet, self).__init__()\n        self.conv1 = ConvBnReLU(4, 32)\n        self.conv2 = ConvBnReLU(32, 32)\n        self.conv3 = ConvBnReLU(32, 32)\n        self.res = ConvBnReLU(32, 1)\n\n    def forward(self, img, depth_init):\n        concat = torch.cat((img, depth_init), dim=1)\n        depth_residual = self.res(self.conv3(self.conv2(self.conv1(concat))))\n        depth_refined = depth_init + depth_residual\n        return depth_refined\n\n\ndef depth_regression(p, depth_values):\n    if depth_values.dim() <= 2:\n        # print(\"regression dim <= 2\")\n        depth_values = depth_values.view(*depth_values.shape, 1, 1)\n    depth = torch.sum(p * depth_values, 1)\n\n    return depth\n\ndef cas_mvsnet_loss(inputs, depth_gt_ms, mask_ms, **kwargs):\n    depth_loss_weights = kwargs.get(\"dlossw\", None)\n\n    total_loss = torch.tensor(0.0, dtype=torch.float32, device=mask_ms[\"stage1\"].device, requires_grad=False)\n\n    for (stage_inputs, stage_key) in [(inputs[k], k) for k in inputs.keys() if \"stage\" in k]:\n 
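       # per-stage masked smooth-L1; stages weighted by kwargs[\"dlossw\"] when provided\n 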
       depth_est = stage_inputs[\"depth\"]\n        depth_gt = depth_gt_ms[stage_key]\n        mask = mask_ms[stage_key]\n        mask = mask > 0.5\n\n        depth_loss = F.smooth_l1_loss(depth_est[mask], depth_gt[mask], reduction='mean')\n\n        if depth_loss_weights is not None:\n            stage_idx = int(stage_key.replace(\"stage\", \"\")) - 1\n            total_loss += depth_loss_weights[stage_idx] * depth_loss\n        else:\n            total_loss += 1.0 * depth_loss\n\n    return total_loss, depth_loss\n\ndef cas_mvsnet_T_loss(inputs, depth_gt_ms, mask_ms, **kwargs):\n    depth_loss_weights = kwargs.get(\"dlossw\", None)\n    l1ce_lw = kwargs.get(\"l1ce_lw\", [0.1, 1])\n    range_thres = kwargs.get(\"range_thres\", [84.8, 10.6])\n    cas_method = kwargs.get(\"cascade_method\", None)\n    last_conv3d = kwargs.get(\"last_conv3d\", False)\n    visual = kwargs.get(\"visual\", False)\n    wt = kwargs.get(\"wt\", False)\n    fl = kwargs.get(\"fl\", False)\n    shrink_method = kwargs.get(\"shrink_method\", 'schedule')\n    upsampled_loss = kwargs.get(\"upsampled_loss\", False)\n    selected_loss = kwargs.get(\"selected_loss\", False)\n    mask_range_loss = kwargs.get(\"mask_range_loss\", False)\n    det = kwargs.get(\"det\", False)\n    if visual:\n        f, axs = plt.subplots(figsize=(30, 10),ncols=3)  # depth offset\n        f2, axs2 = plt.subplots(figsize=(30, 10),ncols=3)  # attn weight max\n        f3, axs3 = plt.subplots(figsize=(30, 10),ncols=3)  # attn weight gt val\n        f4, axs4 = plt.subplots(figsize=(30, 10),ncols=3)  # max gt offset\n        err_848_str = ''\n        err_106_str = ''\n        err_002_str = ''\n\n    total_loss = torch.tensor(0.0, dtype=torch.float32, device=mask_ms[\"stage1\"].device, requires_grad=False)\n    stage_depth_loss = []\n    stage_ce_loss = []\n    range_err_ratio = []\n    upsampled_depth_losses = []\n    det_offset_losses = []\n    for stage_idx, (stage_inputs, stage_key) in enumerate([(inputs[k], k) for k in inputs.keys() if \"stage\" in k]):\n        depth_est = stage_inputs[\"depth\"]\n        B,H,W = depth_est.shape\n        mask = mask_ms[stage_key]\n        mask = mask > 0.5\n        depth_gt = depth_gt_ms[stage_key]\n\n        if upsampled_loss:\n            if stage_idx!=0 :\n                upsampled_depth = stage_inputs[\"upsampled_depth\"]\n                upsampled_depth_loss = F.smooth_l1_loss(upsampled_depth[mask], depth_gt[mask], reduction='mean')\n                upsampled_depth_losses.append(upsampled_depth_loss)\n        else:\n            if stage_idx!=0 :\n                upsampled_depth_losses.append(torch.tensor(0.0, dtype=torch.float32, device=mask_ms[\"stage1\"].device, requires_grad=False))\n        \n        if mask_range_loss:\n            if stage_idx != 0:\n                depth_offset = next_stage_depth_hypo - depth_gt  # B H W\n                this_stage_mask_range = torch.abs(depth_offset)<range_thres[stage_idx-1]\n                mask = mask & this_stage_mask_range  # B H W\n            next_stage_depth_hypo = F.interpolate(depth_est.unsqueeze(1), scale_factor=2, mode='bilinear', align_corners=True).squeeze(1)\n            \n\n        if stage_idx != len(range_thres):\n            depth_offset = depth_est - depth_gt\n            depth_offset[~mask] = 0\n            depth_offset = depth_offset # B H W\n            range_err_ratio.append((torch.abs(depth_offset)>range_thres[stage_idx]).float().mean())\n\n\n        if visual:\n            depth_offset = depth_est - depth_gt\n            depth_offset[~mask] 
= 0\n            depth_offset = depth_offset.detach().cpu().numpy()[0] # H W  \n            err_848_str += str((np.abs(depth_offset)>84.8).sum()) + ','\n            err_106_str += str((np.abs(depth_offset)>10.6).sum()) + ','\n            err_002_str += str((np.abs(depth_offset)>2).sum()) + ','\n            sns.heatmap(depth_offset, annot=False, ax=axs[stage_idx])\n\n            attn_weights = stage_inputs[\"attn_weights\"][0]  # D H W\n            attn_weights_max, ind_max = torch.max(attn_weights, 0)\n            attn_weights_max = attn_weights_max.detach().cpu().numpy()  # H W\n            sns.heatmap(attn_weights_max, annot=False, ax=axs2[stage_idx])\n\n            this_stage_depth_val = stage_inputs['depth_values']  # B D H W\n            depth_offsets = torch.abs(this_stage_depth_val- depth_gt[:,None,:,:])[0]  # D,H,W\n            _, indices = torch.min(depth_offsets, dim=0, keepdim=True)  # [1, H, W]\n            attn_gt = torch.gather(attn_weights, 0, indices)[0]  # [H W]\n            attn_gt = attn_gt.detach().cpu().numpy()\n            sns.heatmap(attn_gt, annot=False, ax=axs3[stage_idx])\n\n            max_gt_offset = ind_max - indices[0]  # H W\n            max_gt_offset = max_gt_offset.detach().cpu().numpy()\n            sns.heatmap(max_gt_offset, annot=False, ax=axs4[stage_idx])\n\n        if cas_method[stage_idx] == 't' or cas_method[stage_idx] == 'r' or cas_method[stage_idx] == 'p':\n            # Loss for transformer \n            depth_loss = torch.tensor(0.0, dtype=torch.float32, device=mask_ms[\"stage1\"].device, requires_grad=False)\n            if last_conv3d:\n                depth_loss = F.smooth_l1_loss(depth_est[mask], depth_gt[mask], reduction='mean')\n            attn_weights = stage_inputs[\"attn_weights\"].permute(0,2,3,1).reshape(B*H*W, -1)  # BHW D\n            this_stage_depth_val = stage_inputs['depth_values']  # B D H W\n            depth_offsets = torch.abs(this_stage_depth_val- depth_gt[:,None,:,:])  # B,D,H,W\n            _, indices = torch.min(depth_offsets, dim=1)  # [B, H, W]\n            indices = indices.reshape(-1)  # [BHW]\n            mask = mask.reshape(-1)  # BHW\n            if fl:  # -p(1-q)^a log(q)\n                this_stage_ce_loss = F.nll_loss((1-attn_weights[mask])**2 * torch.log(attn_weights[mask]+1e-12), indices[mask], reduction='mean')\n            else:  # -plog(q)\n                this_stage_ce_loss = F.nll_loss(torch.log(attn_weights[mask]+1e-12), indices[mask], reduction='mean')\n            stage_depth_loss.append(depth_loss)\n            stage_ce_loss.append(this_stage_ce_loss)\n\n            this_stage_loss = l1ce_lw[0]*depth_loss + l1ce_lw[1]*this_stage_ce_loss\n        \n        # Loss for 3D conv\n        else: \n            if wt:\n                depth_loss = torch.tensor(0.0, dtype=torch.float32, device=mask_ms[\"stage1\"].device, requires_grad=False)\n                stage_depth_loss.append(depth_loss)\n                attn_weights = stage_inputs[\"attn_weights\"].permute(0,2,3,1).reshape(B*H*W, -1)  # BHW D\n                depth_offsets = torch.abs(stage_inputs['depth_values']- depth_gt[:,None,:,:])  # B,D,H,W\n                indices = torch.min(depth_offsets, dim=1)[1].reshape(-1)  # [BHW]\n                mask = mask.reshape(-1)  # BHW\n                if fl:  # -p(1-q)^a log(q)\n                    this_stage_ce_loss = F.nll_loss((1-attn_weights[mask])**2 * torch.log(attn_weights[mask]+1e-12), indices[mask], reduction='mean')\n                else:  # -plog(q)\n                    this_stage_ce_loss = 
F.nll_loss(torch.log(attn_weights[mask]+1e-12), indices[mask], reduction='mean')\n                stage_ce_loss.append(this_stage_ce_loss)\n            else:\n                depth_loss = F.smooth_l1_loss(depth_est[mask], depth_gt[mask], reduction='mean')\n                stage_depth_loss.append(depth_loss)\n                this_stage_ce_loss = torch.tensor(0.0, dtype=torch.float32, device=mask_ms[\"stage1\"].device, requires_grad=False)\n                stage_ce_loss.append(this_stage_ce_loss)\n\n            this_stage_loss = l1ce_lw[0]*depth_loss + l1ce_lw[1]*this_stage_ce_loss\n        \n        if upsampled_loss:\n            if stage_idx!=0:\n                this_stage_loss = this_stage_loss + upsampled_depth_loss * l1ce_lw[0]\n        if shrink_method == 'DPF':\n            if stage_idx!=0:\n                depth_offsets = stage_inputs['depth_values'] - depth_gt[:,None,:,:]  # B,D,H,W\n                depth_offset_clamp = torch.clamp(depth_offsets, -1, 1)\n                this_stage_loss = this_stage_loss + torch.abs(depth_offset_clamp).permute(0,2,3,1).reshape(B*H*W, -1)[mask.reshape(-1)].mean()\n        if selected_loss:\n            select_weight = stage_inputs[\"select_weight\"].permute(0,2,3,1).reshape(B*H*W, -1)  # BHW D\n            depth_offsets = torch.abs(stage_inputs['depth_values']- depth_gt[:,None,:,:]) \n            indices = torch.min(depth_offsets, dim=1)[1]  # [B, H, W]\n            indices = indices.reshape(-1)  # [BHW]\n            mask = mask.reshape(-1)  # BHW\n            this_stage_selected_loss = F.nll_loss(torch.log(select_weight[mask]+1e-12), indices[mask], reduction='mean')\n            this_stage_loss = this_stage_loss + this_stage_selected_loss * 0.01*l1ce_lw[1]\n        if det:\n            assert wt\n            depth_itv = stage_inputs['depth_values'][:,1,:,:] - stage_inputs['depth_values'][:,0,:,:]   # B H W\n            pred_offset = stage_inputs['offset_reg'].reshape(-1)  # BHW\n            offset_gt = (depth_gt - (depth_est - stage_inputs['offset_reg'])).reshape(-1) / depth_itv.reshape(-1) # BHW\n            det_offset_loss = F.smooth_l1_loss(pred_offset[mask], offset_gt[mask], reduction='mean')\n            det_offset_losses.append(det_offset_loss)\n            this_stage_loss += det_offset_loss\n        else:\n            det_offset_losses.append(torch.tensor(0.0, dtype=torch.float32, device=mask_ms[\"stage1\"].device, requires_grad=False))\n\n        if depth_loss_weights is not None:\n            stage_idx = int(stage_key.replace(\"stage\", \"\")) - 1\n            total_loss += depth_loss_weights[stage_idx] * this_stage_loss\n        else:\n            total_loss += 1.0 * this_stage_loss\n\n    if visual:\n        axs[1].set_title('err848:{}'.format(err_848_str) + 'err_106:{}'.format(err_106_str) + 'err_002:{}'.format(err_002_str))\n        f.savefig('./debug_figs/offset_heatmap.png')\n        f.clf()\n        f2.savefig('./debug_figs/attn_max_heatmap.png')\n        f2.clf()\n        f3.savefig('./debug_figs/attn_gt_heatmap.png')\n        f3.clf()\n        f4.savefig('./debug_figs/max_gt_offset_heatmap.png')\n        f4.clf()\n\n    return total_loss, depth_loss, stage_depth_loss, stage_ce_loss, range_err_ratio, upsampled_depth_losses, det_offset_losses\n\n\n
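# --- depth hypothesis sampling helpers ---\n# A hypothetical example: with ndepth=8 and depth_inteval_pixel=4, the samples\n# below span cur_depth +/- 16 with a per-bin step of 32/(8-1) ~= 4.57.\ndef 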
get_cur_depth_range_samples(cur_depth, ndepth, depth_inteval_pixel, shape, max_depth=192.0, min_depth=0.0):\n    #shape, (B, H, W)\n    #cur_depth: (B, H, W)\n    #return depth_range_values: (B, D, H, W)\n    cur_depth_min = (cur_depth - ndepth / 2 * depth_inteval_pixel)  # (B, H, W)\n    cur_depth_max = (cur_depth + ndepth / 2 * depth_inteval_pixel)\n    # cur_depth_min = (cur_depth - ndepth / 2 * depth_inteval_pixel).clamp(min=0.0)   #(B, H, W)\n    # cur_depth_max = (cur_depth_min + (ndepth - 1) * depth_inteval_pixel).clamp(max=max_depth)\n\n    assert cur_depth.shape == torch.Size(shape), \"cur_depth:{}, input shape:{}\".format(cur_depth.shape, shape)\n    new_interval = (cur_depth_max - cur_depth_min) / (ndepth - 1)  # (B, H, W)\n\n    depth_range_samples = cur_depth_min.unsqueeze(1) + (torch.arange(0, ndepth, device=cur_depth.device,\n                                                                  dtype=cur_depth.dtype,\n                                                                  requires_grad=False).reshape(1, -1, 1, 1) * new_interval.unsqueeze(1))\n\n    return depth_range_samples\n\n\ndef get_depth_range_samples(cur_depth, ndepth, depth_inteval_pixel, device, dtype, shape,\n                           max_depth=192.0, min_depth=0.0):\n    #shape: (B, H, W)\n    #cur_depth: (B, H, W) or (B, D)\n    #return depth_range_samples: (B, D, H, W)\n    if cur_depth.dim() == 2:\n        cur_depth_min = cur_depth[:, 0]  # (B,)\n        cur_depth_max = cur_depth[:, -1]\n        new_interval = (cur_depth_max - cur_depth_min) / (ndepth - 1)  # (B, )\n\n        depth_range_samples = cur_depth_min.unsqueeze(1) + (torch.arange(0, ndepth, device=device, dtype=dtype,\n                                                                       requires_grad=False).reshape(1, -1) * new_interval.unsqueeze(1)) #(B, D)\n\n        depth_range_samples = depth_range_samples.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, shape[1], shape[2]) #(B, D, H, W)\n\n    else:\n\n        depth_range_samples = get_cur_depth_range_samples(cur_depth, ndepth, depth_inteval_pixel, shape, max_depth, min_depth)\n\n    return depth_range_samples\n\n\n\nif __name__ == \"__main__\":\n    # some testing code, just IGNORE it\n    import sys\n    sys.path.append(\"../\")\n    from datasets import find_dataset_def\n    from torch.utils.data import DataLoader\n    import numpy as np\n    import cv2\n    import matplotlib as mpl\n    mpl.use('Agg')\n    import matplotlib.pyplot as plt\n\n    # MVSDataset = find_dataset_def(\"colmap\")\n    # dataset = MVSDataset(\"../data/results/ford/num10_1/\", 3, 'test',\n    #                      128, interval_scale=1.06, max_h=1250, max_w=1024)\n\n    MVSDataset = find_dataset_def(\"dtu_yao\")\n    num_depth = 48\n    dataset = MVSDataset(\"../data/DTU/mvs_training/dtu/\", '../lists/dtu/train.txt', 'train',\n                         3, num_depth, interval_scale=1.06 * 192 / num_depth)\n\n    dataloader = DataLoader(dataset, batch_size=1)\n    item = next(iter(dataloader))\n\n    imgs = item[\"imgs\"][:, :, :, ::4, ::4]  #(B, N, 3, H, W)\n    # imgs = item[\"imgs\"][:, :, :, :, :]\n    proj_matrices = item[\"proj_matrices\"]   #(B, N, 2, 4, 4) dim=N: N view; dim=2: index 0 for extr, 1 for intric\n    proj_matrices[:, :, 1, :2, :] = proj_matrices[:, :, 1, :2, :]\n    # proj_matrices[:, :, 1, :2, :] = proj_matrices[:, :, 1, :2, :] * 4\n    depth_values = item[\"depth_values\"]     #(B, D)\n\n    imgs = torch.unbind(imgs, 1)\n    proj_matrices = torch.unbind(proj_matrices, 1)\n    ref_img, src_imgs = 
imgs[0], imgs[1:]\n    ref_proj, src_proj = proj_matrices[0], proj_matrices[1:][0]  #only vis first view\n\n    src_proj_new = src_proj[:, 0].clone()\n    src_proj_new[:, :3, :4] = torch.matmul(src_proj[:, 1, :3, :3], src_proj[:, 0, :3, :4])\n    ref_proj_new = ref_proj[:, 0].clone()\n    ref_proj_new[:, :3, :4] = torch.matmul(ref_proj[:, 1, :3, :3], ref_proj[:, 0, :3, :4])\n\n    warped_imgs = homo_warping(src_imgs[0], src_proj_new, ref_proj_new, depth_values)\n\n    ref_img_np = ref_img.permute([0, 2, 3, 1])[0].detach().cpu().numpy()[:, :, ::-1] * 255\n    cv2.imwrite('../tmp/ref.png', ref_img_np)\n    cv2.imwrite('../tmp/src.png', src_imgs[0].permute([0, 2, 3, 1])[0].detach().cpu().numpy()[:, :, ::-1] * 255)\n\n    for i in range(warped_imgs.shape[2]):\n        warped_img = warped_imgs[:, :, i, :, :].permute([0, 2, 3, 1]).contiguous()\n        img_np = warped_img[0].detach().cpu().numpy()\n        img_np = img_np[:, :, ::-1] * 255\n\n        alpha = 0.5\n        beta = 1 - alpha\n        gamma = 0\n        img_add = cv2.addWeighted(ref_img_np, alpha, img_np, beta, gamma)\n        cv2.imwrite('../tmp/tmp{}.png'.format(i), np.hstack([ref_img_np, img_np, img_add])) #* ratio + img_np*(1-ratio)]))"
  },
  {
    "path": "models/mvs4net_utils.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport importlib\ntry:\n    from modules.deform_conv import DeformConvPack\nexcept:\n    print('DeformConvPack not found, please install it from: https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch')\n    pass\nimport math\nimport numpy as np\n\ndef homo_warping(src_fea, src_proj, ref_proj, depth_values, vis_ETA=False, fn=None):\n    # src_fea: [B, C, H, W]\n    # src_proj: [B, 4, 4]\n    # ref_proj: [B, 4, 4]\n    # depth_values: [B, Ndepth] o [B, Ndepth, H, W]\n    # out: [B, C, Ndepth, H, W]\n    C = src_fea.shape[1]\n    Hs,Ws = src_fea.shape[-2:]\n    B,num_depth,Hr,Wr = depth_values.shape\n\n    with torch.no_grad():\n        proj = torch.matmul(src_proj, torch.inverse(ref_proj))\n        rot = proj[:, :3, :3]  # [B,3,3]\n        trans = proj[:, :3, 3:4]  # [B,3,1]\n\n        y, x = torch.meshgrid([torch.arange(0, Hr, dtype=torch.float32, device=src_fea.device),\n                               torch.arange(0, Wr, dtype=torch.float32, device=src_fea.device)])\n        y = y.reshape(Hr*Wr)\n        x = x.reshape(Hr*Wr)\n        xyz = torch.stack((x, y, torch.ones_like(x)))  # [3, H*W]\n        xyz = torch.unsqueeze(xyz, 0).repeat(B, 1, 1)  # [B, 3, H*W]\n        rot_xyz = torch.matmul(rot, xyz)  # [B, 3, H*W]\n        rot_depth_xyz = rot_xyz.unsqueeze(2).repeat(1, 1, num_depth, 1) * depth_values.reshape(B, 1, num_depth, -1)  # [B, 3, Ndepth, H*W]\n        proj_xyz = rot_depth_xyz + trans.reshape(B, 3, 1, 1)  # [B, 3, Ndepth, H*W]\n        # FIXME divide 0\n        temp = proj_xyz[:, 2:3, :, :]\n        temp[temp==0] = 1e-9\n        proj_xy = proj_xyz[:, :2, :, :] / temp  # [B, 2, Ndepth, H*W]\n        # proj_xy = proj_xyz[:, :2, :, :] / proj_xyz[:, 2:3, :, :]  # [B, 2, Ndepth, H*W]\n\n        proj_x_normalized = proj_xy[:, 0, :, :] / ((Ws - 1) / 2) - 1\n        proj_y_normalized = proj_xy[:, 1, :, :] / ((Hs - 1) / 2) - 1\n        proj_xy = torch.stack((proj_x_normalized, proj_y_normalized), dim=3)  # [B, Ndepth, H*W, 2]\n        if vis_ETA:\n            tensor_saved = proj_xy.reshape(B,num_depth,Hs,Ws,2).cpu().numpy()\n            np.save(fn+'_grid', tensor_saved)\n        grid = proj_xy\n    if len(src_fea.shape)==4:\n        warped_src_fea = F.grid_sample(src_fea, grid.reshape(B, num_depth * Hr, Wr, 2), mode='bilinear', padding_mode='zeros', align_corners=True)\n        warped_src_fea = warped_src_fea.reshape(B, C, num_depth, Hr, Wr)\n    elif len(src_fea.shape)==5:\n        warped_src_fea = []\n        for d in range(src_fea.shape[2]):\n            warped_src_fea.append(F.grid_sample(src_fea[:,:,d], grid.reshape(B, num_depth, Hr, Wr, 2)[:,d], mode='bilinear', padding_mode='zeros', align_corners=True))\n        warped_src_fea = torch.stack(warped_src_fea, dim=2)\n\n    return warped_src_fea\n\ndef init_range(cur_depth, ndepths, device, dtype, H, W):\n    cur_depth_min = cur_depth[:, 0]  # (B,)\n    cur_depth_max = cur_depth[:, -1]\n    new_interval = (cur_depth_max - cur_depth_min) / (ndepths - 1)  # (B, )\n    new_interval = new_interval[:, None, None]  # B H W\n    depth_range_samples = cur_depth_min.unsqueeze(1) + (torch.arange(0, ndepths, device=device, dtype=dtype,\n                                                                requires_grad=False).reshape(1, -1) * new_interval.squeeze(1)) #(B, D)\n    depth_range_samples = depth_range_samples.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, H, W) #(B, D, H, W)\n    return depth_range_samples\n\ndef init_inverse_range(cur_depth, ndepths, 
device, dtype, H, W):\n    inverse_depth_min = 1. / cur_depth[:, 0]  # (B,)\n    inverse_depth_max = 1. / cur_depth[:, -1]\n    itv = torch.arange(0, ndepths, device=device, dtype=dtype, requires_grad=False).reshape(1, -1,1,1).repeat(1, 1, H, W)  / (ndepths - 1)  # 1 D H W\n    inverse_depth_hypo = inverse_depth_max[:,None, None, None] + (inverse_depth_min - inverse_depth_max)[:,None, None, None] * itv\n\n    return 1./inverse_depth_hypo\n\ndef schedule_inverse_range(inverse_min_depth, inverse_max_depth, ndepths, H, W):\n    #cur_depth_min, (B, H, W)\n    #cur_depth_max: (B, H, W)\n    itv = torch.arange(0, ndepths, device=inverse_min_depth.device, dtype=inverse_min_depth.dtype, requires_grad=False).reshape(1, -1,1,1).repeat(1, 1, H//2, W//2)  / (ndepths - 1)  # 1 D H W\n\n    inverse_depth_hypo = inverse_max_depth[:,None, :, :] + (inverse_min_depth - inverse_max_depth)[:,None, :, :] * itv  # B D H W\n    inverse_depth_hypo = F.interpolate(inverse_depth_hypo.unsqueeze(1), [ndepths, H, W], mode='trilinear', align_corners=True).squeeze(1)\n    return 1./inverse_depth_hypo\n\ndef schedule_range(cur_depth, ndepth, depth_inteval_pixel, H, W):\n    #shape, (B, H, W)\n    #cur_depth: (B, H, W)\n    #return depth_range_values: (B, D, H, W)\n    cur_depth_min = (cur_depth - ndepth / 2 * depth_inteval_pixel[:,None,None])  # (B, H, W)\n    cur_depth_max = (cur_depth + ndepth / 2 * depth_inteval_pixel[:,None,None])\n    new_interval = (cur_depth_max - cur_depth_min) / (ndepth - 1)  # (B, H, W)\n\n    depth_range_samples = cur_depth_min.unsqueeze(1) + (torch.arange(0, ndepth, device=cur_depth.device, dtype=cur_depth.dtype,\n                                                                  requires_grad=False).reshape(1, -1, 1, 1) * new_interval.unsqueeze(1))\n    depth_range_samples = F.interpolate(depth_range_samples.unsqueeze(1), [ndepth, H, W], mode='trilinear', align_corners=True).squeeze(1)\n    return depth_range_samples\n\ndef init_bn(module):\n    if module.weight is not None:\n        nn.init.ones_(module.weight)\n    if module.bias is not None:\n        nn.init.zeros_(module.bias)\n    return\n\ndef init_uniform(module, init_method):\n    if module.weight is not None:\n        if init_method == \"kaiming\":\n            nn.init.kaiming_uniform_(module.weight)\n        elif init_method == \"xavier\":\n            nn.init.xavier_uniform_(module.weight)\n    return\n\nclass ConvBnReLU3D(nn.Module):\n    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1):\n        super(ConvBnReLU3D, self).__init__()\n        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False)\n        self.bn = nn.BatchNorm3d(out_channels)\n\n    def forward(self, x):\n        return F.relu(self.bn(self.conv(x)), inplace=True)\n\nclass ConvBnReLU3D_CAM(nn.Module):\n    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1):\n        super(ConvBnReLU3D_CAM, self).__init__()\n        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False)\n        self.bn = nn.BatchNorm3d(out_channels)\n        self.linear_agg = nn.Sequential(\n            nn.Linear(out_channels, out_channels//2),\n            nn.ReLU(),\n            nn.Linear(out_channels//2, out_channels)\n        )\n\n    def forward(self, input):\n        x = self.conv(input)\n        B,C,D,H,W = x.shape\n        avg_attn = self.linear_agg(x.reshape(B,C,D*H*W).mean(2))\n        max_attn = 
self.linear_agg(x.reshape(B,C,D*H*W).max(2)[0])  # B C\n        attn = F.sigmoid(max_attn+avg_attn)[:,:,None,None,None]  # B C,1,1,1\n        x = x * attn\n        return F.relu(self.bn(x+input), inplace=True)\n\nclass ConvBnReLU3D_DCAM(nn.Module):\n    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1):\n        super(ConvBnReLU3D_DCAM, self).__init__()\n        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False)\n        self.bn = nn.BatchNorm3d(out_channels)\n        self.linear_agg = nn.Sequential(\n            nn.Linear(out_channels, out_channels//2),\n            nn.ReLU(),\n            nn.Linear(out_channels//2, out_channels)\n        )\n\n    def forward(self, input):\n        x = self.conv(input)\n        B,C,D,H,W = x.shape\n        avg_attn = self.linear_agg(x.reshape(B,C,D,H*W).mean(3).permute(0,2,1).reshape(B*D,C)).reshape(B,D,C).permute(0,2,1)\n        max_attn = self.linear_agg(x.reshape(B,C,D,H*W).max(3)[0].permute(0,2,1).reshape(B*D,C)).reshape(B,D,C).permute(0,2,1)  # B C D\n        attn = F.sigmoid(max_attn+avg_attn)[:,:,:,None,None]  # B C,D,1,1\n        x = x * attn\n        return F.relu(self.bn(x+input), inplace=True)\n\nclass ConvBnReLU3D_PAM(nn.Module):\n    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1):\n        super(ConvBnReLU3D_PAM, self).__init__()\n        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False)\n        self.bn = nn.BatchNorm3d(out_channels)\n        self.pixel_conv = nn.Conv2d(2,1,7,stride=1,padding='same')\n\n    def forward(self, input):\n        x = self.conv(input)\n        B,C,D,H,W = x.shape\n        max_attn = x.reshape(B,C*D,H,W).max(1, keepdim=True)[0]\n        avg_attn = x.reshape(B,C*D,H,W).mean(1, keepdim=True)  # B 1 H W\n        attn = F.sigmoid(self.pixel_conv(torch.cat([max_attn, avg_attn], dim=1)))[:,:,None,:,:]  # B 1,1,H,W\n        x = x * attn\n        return F.relu(self.bn(x+input), inplace=True)\n\nclass ConvBnReLU3D_PDAM(nn.Module):\n    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1):\n        super(ConvBnReLU3D_PDAM, self).__init__()\n        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False)\n        self.bn = nn.BatchNorm3d(out_channels)\n        self.spatial_conv = nn.Conv3d(2,1,7,stride=1,padding='same')\n\n    def forward(self, input):\n        x = self.conv(input)\n        B,C,D,H,W = x.shape\n        max_attn = x.max(1, keepdim=True)[0]\n        avg_attn = x.mean(1, keepdim=True)  # B 1 D H W\n        attn = F.sigmoid(self.spatial_conv(torch.cat([max_attn, avg_attn], dim=1)))  # B 1,D,H,W\n        x = x * attn\n        return F.relu(self.bn(x+input), inplace=True)\n\nclass Deconv3d(nn.Module):\n\n    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,\n                 relu=True, bn=True, bn_momentum=0.1, init_method=\"xavier\", **kwargs):\n        super(Deconv3d, self).__init__()\n        self.out_channels = out_channels\n        assert stride in [1, 2]\n        self.stride = stride\n\n        self.conv = nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride=stride,\n                                       bias=(not bn), **kwargs)\n        self.bn = nn.BatchNorm3d(out_channels, momentum=bn_momentum) if bn else None\n        self.relu = relu\n\n    def forward(self, x):\n        y = self.conv(x)\n        if self.bn is not None:\n            x = 
self.bn(y)\n        else:\n            x = y\n        if self.relu:\n            x = F.relu(x, inplace=True)\n        return x\n\n    def init_weights(self, init_method):\n        init_uniform(self.conv, init_method)\n        if self.bn is not None:\n            init_bn(self.bn)\n\nclass Conv2d(nn.Module):\n\n    def __init__(self, in_channels, out_channels, kernel_size, stride=1,\n                 relu=True, bn_momentum=0.1, init_method=\"xavier\", gn=False, group_channel=8, **kwargs):\n        super(Conv2d, self).__init__()\n        bn = not gn\n        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride,\n                              bias=(not bn), **kwargs)\n        self.kernel_size = kernel_size\n        self.stride = stride\n        self.bn = nn.BatchNorm2d(out_channels, momentum=bn_momentum) if bn else None\n        self.gn = nn.GroupNorm(int(max(1, out_channels / group_channel)), out_channels) if gn else None\n        self.relu = relu\n\n    def forward(self, x):\n        x = self.conv(x)\n        if self.bn is not None:\n            x = self.bn(x)\n        else:\n            x = self.gn(x)\n        if self.relu:\n            x = F.relu(x, inplace=True)\n        return x\n\n    def init_weights(self, init_method):\n        init_uniform(self.conv, init_method)\n        if self.bn is not None:\n            init_bn(self.bn)\n\nclass Deconv2d(nn.Module):\n\n    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,\n                 relu=True, bn=True, bn_momentum=0.1, init_method=\"xavier\", **kwargs):\n        super(Deconv2d, self).__init__()\n        self.out_channels = out_channels\n        assert stride in [1, 2]\n        self.stride = stride\n\n        self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride,\n                                       bias=(not bn), **kwargs)\n        self.bn = nn.BatchNorm2d(out_channels, momentum=bn_momentum) if bn else None\n        self.relu = relu\n\n    def forward(self, x):\n        # forward was missing; mirror Deconv3d: conv -> optional BN -> optional ReLU\n        x = self.conv(x)\n        if self.bn is not None:\n            x = self.bn(x)\n        if self.relu:\n            x = F.relu(x, inplace=True)\n        return x\n\nclass DeformConv2d(nn.Module):\n    def __init__(self, inc, outc, kernel_size=3, padding=1, stride=1, bias=None, modulation=True):\n        super(DeformConv2d, self).__init__()\n        self.kernel_size = kernel_size\n        self.padding = padding\n        self.stride = stride\n        self.zero_padding = nn.ZeroPad2d(padding)\n        self.conv = nn.Conv2d(inc, outc, kernel_size=kernel_size, stride=kernel_size, bias=bias)\n\n        self.p_conv = nn.Conv2d(inc, 2*kernel_size*kernel_size, kernel_size=3, padding=1, stride=stride)\n        nn.init.constant_(self.p_conv.weight, 0)\n        self.p_conv.register_backward_hook(self._set_lr)\n\n        self.modulation = modulation\n        if modulation:\n            self.m_conv = nn.Conv2d(inc, kernel_size*kernel_size, kernel_size=3, padding=1, stride=stride)\n            nn.init.constant_(self.m_conv.weight, 0)\n            self.m_conv.register_backward_hook(self._set_lr)\n\n    @staticmethod\n    def _set_lr(module, grad_input, grad_output):\n        # a backward hook only takes effect if it returns the new gradients; the\n        # original built generator expressions and discarded them, so it was a no-op\n        return tuple(g * 0.1 if g is not None else g for g in grad_input)\n\n    def forward(self, x):\n        offset = self.p_conv(x)\n        if self.modulation:\n            m = torch.sigmoid(self.m_conv(x))\n\n        dtype = offset.data.type()\n        ks = self.kernel_size\n        N = offset.size(1) // 2\n\n        if self.padding:\n            x = self.zero_padding(x)\n\n        # (b, 2N, h, w)\n        p = self._get_p(offset, dtype)\n\n        # (b, h, w, 2N)\n        
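# p holds absolute sampling coords: base grid + kernel-relative offsets + learned offsets\n        p = 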
p.contiguous().permute(0, 2, 3, 1)\n        q_lt = p.detach().floor()\n        q_rb = q_lt + 1\n\n        q_lt = torch.cat([torch.clamp(q_lt[..., :N], 0, x.size(2)-1), torch.clamp(q_lt[..., N:], 0, x.size(3)-1)], dim=-1).long()\n        q_rb = torch.cat([torch.clamp(q_rb[..., :N], 0, x.size(2)-1), torch.clamp(q_rb[..., N:], 0, x.size(3)-1)], dim=-1).long()\n        q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], dim=-1)\n        q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], dim=-1)\n\n        # clip p\n        p = torch.cat([torch.clamp(p[..., :N], 0, x.size(2)-1), torch.clamp(p[..., N:], 0, x.size(3)-1)], dim=-1)\n\n        # bilinear kernel (b, h, w, N)\n        g_lt = (1 + (q_lt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_lt[..., N:].type_as(p) - p[..., N:]))\n        g_rb = (1 - (q_rb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_rb[..., N:].type_as(p) - p[..., N:]))\n        g_lb = (1 + (q_lb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_lb[..., N:].type_as(p) - p[..., N:]))\n        g_rt = (1 - (q_rt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_rt[..., N:].type_as(p) - p[..., N:]))\n\n        # (b, c, h, w, N)\n        x_q_lt = self._get_x_q(x, q_lt, N)\n        x_q_rb = self._get_x_q(x, q_rb, N)\n        x_q_lb = self._get_x_q(x, q_lb, N)\n        x_q_rt = self._get_x_q(x, q_rt, N)\n\n        # (b, c, h, w, N)\n        x_offset = g_lt.unsqueeze(dim=1) * x_q_lt + \\\n                   g_rb.unsqueeze(dim=1) * x_q_rb + \\\n                   g_lb.unsqueeze(dim=1) * x_q_lb + \\\n                   g_rt.unsqueeze(dim=1) * x_q_rt\n\n        # modulation\n        if self.modulation:\n            m = m.contiguous().permute(0, 2, 3, 1)\n            m = m.unsqueeze(dim=1)\n            m = torch.cat([m for _ in range(x_offset.size(1))], dim=1)\n            x_offset *= m\n\n        x_offset = self._reshape_x_offset(x_offset, ks)\n        out = self.conv(x_offset)\n\n        return out\n\n    def _get_p_n(self, N, dtype):\n        p_n_x, p_n_y = torch.meshgrid(\n            torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1),\n            torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1))\n        # (2N, 1)\n        p_n = torch.cat([torch.flatten(p_n_x), torch.flatten(p_n_y)], 0)\n        p_n = p_n.view(1, 2*N, 1, 1).type(dtype)\n\n        return p_n\n\n    def _get_p_0(self, h, w, N, dtype):\n        p_0_x, p_0_y = torch.meshgrid(\n            torch.arange(1, h*self.stride+1, self.stride),\n            torch.arange(1, w*self.stride+1, self.stride))\n        p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1)\n        p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1)\n        p_0 = torch.cat([p_0_x, p_0_y], 1).type(dtype)\n\n        return p_0\n\n    def _get_p(self, offset, dtype):\n        N, h, w = offset.size(1)//2, offset.size(2), offset.size(3)\n\n        # (1, 2N, 1, 1)\n        p_n = self._get_p_n(N, dtype)\n        # (1, 2N, h, w)\n        p_0 = self._get_p_0(h, w, N, dtype)\n        p = p_0 + p_n + offset\n        return p\n\n    def _get_x_q(self, x, q, N):\n        b, h, w, _ = q.size()\n        padded_w = x.size(3)\n        c = x.size(1)\n        # (b, c, h*w)\n        x = x.contiguous().view(b, c, -1)\n\n        # (b, h, w, N)\n        index = q[..., :N]*padded_w + q[..., N:]  # offset_x*w + offset_y\n        # (b, c, h*w*N)\n        index = index.contiguous().unsqueeze(dim=1).expand(-1, c, -1, -1, -1).contiguous().view(b, c, -1)\n\n        x_offset = x.gather(dim=-1, index=index).contiguous().view(b, c, h, w, N)\n\n     
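   # features gathered at the N integer sampling locations around each output pixel\n     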
   return x_offset\n\n    @staticmethod\n    def _reshape_x_offset(x_offset, ks):\n        b, c, h, w, N = x_offset.size()\n        x_offset = torch.cat([x_offset[..., s:s+ks].contiguous().view(b, c, h, w*ks) for s in range(0, N, ks)], dim=-1)\n        x_offset = x_offset.contiguous().view(b, c, h*ks, w*ks)\n\n        return x_offset\n\ndef NA_DCN(in_channels, kernel_size=3, stride=1, dilation=1, bias=True, group_channel=8, gn=False):\n    if gn:\n        return nn.Sequential(\n            nn.GroupNorm(int(max(1, in_channels / group_channel)), in_channels),\n            nn.ReLU(inplace=True),\n            # DeformConv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride,  bias=bias),\n            DeformConvPack(in_channels, in_channels, kernel_size=3, stride=1, padding=1, groups=1, deformable_groups=1, bias=False, im2col_step=16)\n        )\n    else:\n        return nn.Sequential(\n            nn.BatchNorm2d(in_channels, momentum=0.1),\n            nn.ReLU(inplace=True),\n            # DeformConv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride,  bias=bias),\n            DeformConvPack(in_channels, in_channels, kernel_size=3, stride=1, padding=1, groups=1, deformable_groups=1, bias=False, im2col_step=16)\n        )\n\nclass FPN4(nn.Module):\n    \"\"\"\n    FPN aligncorners downsample 4x\"\"\"\n    def __init__(self, base_channels, gn=False, dcn=False):\n        super(FPN4, self).__init__()\n        self.base_channels = base_channels\n\n        self.conv0 = nn.Sequential(\n            Conv2d(3, base_channels, 3, 1, padding=1, gn=gn),\n            Conv2d(base_channels, base_channels, 3, 1, padding=1, gn=gn),\n        )\n\n        self.conv1 = nn.Sequential(\n            Conv2d(base_channels, base_channels * 2, 5, stride=2, padding=2, gn=gn),\n            Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1, gn=gn),\n            Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1, gn=gn),\n        )\n\n        self.conv2 = nn.Sequential(\n            Conv2d(base_channels * 2, base_channels * 4, 5, stride=2, padding=2, gn=gn),\n            Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1, gn=gn),\n            Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1, gn=gn),\n        )\n\n        self.conv3 = nn.Sequential(\n            Conv2d(base_channels * 4, base_channels * 8, 5, stride=2, padding=2, gn=gn),\n            Conv2d(base_channels * 8, base_channels * 8, 3, 1, padding=1, gn=gn),\n            Conv2d(base_channels * 8, base_channels * 8, 3, 1, padding=1, gn=gn),\n        )\n\n        self.out_channels = [8 * base_channels]\n        final_chs = base_channels * 8\n\n        self.inner1 = nn.Conv2d(base_channels * 4, final_chs, 1, bias=True)\n        self.inner2 = nn.Conv2d(base_channels * 2, final_chs, 1, bias=True)\n        self.inner3 = nn.Conv2d(base_channels * 1, final_chs, 1, bias=True)\n\n        self.out1 = nn.Conv2d(final_chs, base_channels * 8, 1, bias=False)\n        self.out2 = nn.Conv2d(final_chs, base_channels * 4, 3, padding=1, bias=False)\n        self.out3 = nn.Conv2d(final_chs, base_channels * 2, 3, padding=1, bias=False)\n        self.out4 = nn.Conv2d(final_chs, base_channels, 3, padding=1, bias=False)\n\n        self.dcn = dcn\n        if self.dcn:\n            self.dcn1 = NA_DCN(base_channels * 8, 3, gn=gn)\n            self.dcn2 = NA_DCN(base_channels * 4, 3, gn=gn)\n            self.dcn3 = NA_DCN(base_channels * 2, 3, gn=gn)\n            self.dcn4 = NA_DCN(base_channels * 1, 3, gn=gn)\n\n        
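# per-stage output widths, coarsest stage first\n        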
self.out_channels.append(base_channels * 4)\n        self.out_channels.append(base_channels * 2)\n        self.out_channels.append(base_channels)\n\n    def forward(self, x):\n        conv0 = self.conv0(x)\n        conv1 = self.conv1(conv0)\n        conv2 = self.conv2(conv1)\n        conv3 = self.conv3(conv2)\n\n        intra_feat = conv3\n        outputs = {}\n        out1 = self.out1(intra_feat)\n\n        intra_feat = F.interpolate(intra_feat, scale_factor=2, mode=\"bilinear\", align_corners=True) + self.inner1(conv2)\n        out2 = self.out2(intra_feat)\n\n        intra_feat = F.interpolate(intra_feat, scale_factor=2, mode=\"bilinear\", align_corners=True) + self.inner2(conv1)\n        out3 = self.out3(intra_feat)\n\n        intra_feat = F.interpolate(intra_feat, scale_factor=2, mode=\"bilinear\", align_corners=True) + self.inner3(conv0)\n        out4 = self.out4(intra_feat)\n\n        if self.dcn:\n            out1 = self.dcn1(out1)\n            out2 = self.dcn2(out2)\n            out3 = self.dcn3(out3)\n            out4 = self.dcn4(out4)\n\n        outputs[\"stage1\"] = out1\n        outputs[\"stage2\"] = out2\n        outputs[\"stage3\"] = out3\n        outputs[\"stage4\"] = out4\n\n        return outputs\n\nclass LayerNorm(nn.Module):\n\n    def __init__(self, normalized_shape, eps=1e-6, data_format=\"channels_last\"):\n        super().__init__()\n        self.weight = nn.Parameter(torch.ones(normalized_shape))\n        self.bias = nn.Parameter(torch.zeros(normalized_shape))\n        self.eps = eps\n        self.data_format = data_format\n        if self.data_format not in [\"channels_last\", \"channels_first\"]:\n            raise NotImplementedError \n        self.normalized_shape = (normalized_shape, )\n    \n    def forward(self, x):\n        if self.data_format == \"channels_last\":\n            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)\n        elif self.data_format == \"channels_first\":\n            u = x.mean(1, keepdim=True)\n            s = (x - u).pow(2).mean(1, keepdim=True)\n            x = (x - u) / torch.sqrt(s + self.eps)\n            x = self.weight[:, None, None] * x + self.bias[:, None, None]\n            return x\n\nclass convnext_block(nn.Module):\n\n    def __init__(self, dim, layer_scale_init_value=1e-6):\n        super().__init__()\n        self.dwconv = nn.Conv2d(dim, 2*dim, kernel_size=7, stride=2, padding=3, groups=dim) # depthwise conv\n        self.norm = LayerNorm(2*dim, eps=1e-6)\n        self.pwconv1 = nn.Linear(2*dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers\n        self.act = nn.GELU()\n        self.pwconv2 = nn.Linear(4 * dim, 2*dim)\n        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((2*dim)), \n                                    requires_grad=True) if layer_scale_init_value > 0 else None\n\n    def forward(self, x):\n        input = x\n        x = self.dwconv(x)\n        x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)\n        x = self.norm(x)\n        x = self.pwconv1(x)\n        x = self.act(x)\n        x = self.pwconv2(x)\n        if self.gamma is not None:\n            x = self.gamma * x\n        x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)\n\n        # x = input + x\n        return x\n\nclass convnext4_block(nn.Module):\n\n    def __init__(self, dim, layer_scale_init_value=1e-6):\n        super().__init__()\n        self.sconv = nn.Conv2d(dim, 2*dim, kernel_size=2, stride=2, padding=0) # stride=2 conv\n        self.dwconv = 
nn.Conv2d(2*dim, 2*dim, kernel_size=7, stride=1, padding=3, groups=dim) # depthwise conv\n        self.norm = LayerNorm(2*dim, eps=1e-6)\n        self.pwconv1 = nn.Linear(2*dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers\n        self.act = nn.GELU()\n        self.pwconv2 = nn.Linear(4 * dim, 2*dim)\n        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((2*dim)), \n                                    requires_grad=True) if layer_scale_init_value > 0 else None\n\n    def forward(self, x):\n        input = self.sconv(x)\n        x = self.dwconv(input)\n        x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)\n        x = self.norm(x)\n        x = self.pwconv1(x)\n        x = self.act(x)\n        x = self.pwconv2(x)\n        if self.gamma is not None:\n            x = self.gamma * x\n        x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)\n\n        x = input + x\n        return x\n\nclass FPN4_convnext(nn.Module):\n    \"\"\"\n    FPN aligncorners downsample 4x\"\"\"\n    def __init__(self, base_channels, gn=False, dcn=False):\n        super(FPN4_convnext, self).__init__()\n        self.base_channels = base_channels\n\n        self.conv0 = nn.Sequential(\n            Conv2d(3, base_channels, 3, 1, padding=1, gn=gn),\n            Conv2d(base_channels, base_channels, 3, 1, padding=1, gn=gn),\n        )\n\n        self.conv1 = convnext_block(base_channels)\n        self.conv2 = convnext_block(2*base_channels)\n        self.conv3 = convnext_block(4*base_channels)\n\n        self.out_channels = [8 * base_channels]\n        final_chs = base_channels * 8\n\n        self.inner1 = nn.Conv2d(base_channels * 4, final_chs, 1, bias=True)\n        self.inner2 = nn.Conv2d(base_channels * 2, final_chs, 1, bias=True)\n        self.inner3 = nn.Conv2d(base_channels * 1, final_chs, 1, bias=True)\n\n        self.out1 = nn.Conv2d(final_chs, base_channels * 8, 1, bias=False)\n        self.out2 = nn.Conv2d(final_chs, base_channels * 4, 3, padding=1, bias=False)\n        self.out3 = nn.Conv2d(final_chs, base_channels * 2, 3, padding=1, bias=False)\n        self.out4 = nn.Conv2d(final_chs, base_channels, 3, padding=1, bias=False)\n\n        self.dcn = dcn\n        if self.dcn:\n            self.dcn1 = NA_DCN(base_channels * 8, 3, gn=gn)\n            self.dcn2 = NA_DCN(base_channels * 4, 3, gn=gn)\n            self.dcn3 = NA_DCN(base_channels * 2, 3, gn=gn)\n            self.dcn4 = NA_DCN(base_channels * 1, 3, gn=gn)\n\n        self.out_channels.append(base_channels * 4)\n        self.out_channels.append(base_channels * 2)\n        self.out_channels.append(base_channels)\n\n    def forward(self, x):\n        conv0 = self.conv0(x)\n        conv1 = self.conv1(conv0)\n        conv2 = self.conv2(conv1)\n        conv3 = self.conv3(conv2)\n\n        intra_feat = conv3\n        outputs = {}\n        out1 = self.out1(intra_feat)\n\n        intra_feat = F.interpolate(intra_feat, scale_factor=2, mode=\"bilinear\", align_corners=True) + self.inner1(conv2)\n        out2 = self.out2(intra_feat)\n\n        intra_feat = F.interpolate(intra_feat, scale_factor=2, mode=\"bilinear\", align_corners=True) + self.inner2(conv1)\n        out3 = self.out3(intra_feat)\n\n        intra_feat = F.interpolate(intra_feat, scale_factor=2, mode=\"bilinear\", align_corners=True) + self.inner3(conv0)\n        out4 = self.out4(intra_feat)\n\n        if self.dcn:\n            out1 = self.dcn1(out1)\n            out2 = self.dcn2(out2)\n            out3 = self.dcn3(out3)\n            out4 = 
self.dcn4(out4)\n\n        outputs[\"stage1\"] = out1\n        outputs[\"stage2\"] = out2\n        outputs[\"stage3\"] = out3\n        outputs[\"stage4\"] = out4\n\n        return outputs\n\nclass FPN4_convnext4(nn.Module):\n    \"\"\"\n    FPN aligncorners downsample 4x\"\"\"\n    def __init__(self, base_channels, gn=False, dcn=False):\n        super(FPN4_convnext4, self).__init__()\n        self.base_channels = base_channels\n\n        self.conv0 = nn.Sequential(\n            Conv2d(3, base_channels, 3, 1, padding=1, gn=gn),\n            Conv2d(base_channels, base_channels, 3, 1, padding=1, gn=gn),\n        )\n\n        self.conv1 = convnext4_block(base_channels)\n        self.conv2 = convnext4_block(2*base_channels)\n        self.conv3 = convnext4_block(4*base_channels)\n\n        self.out_channels = [8 * base_channels]\n        final_chs = base_channels * 8\n\n        self.inner1 = nn.Conv2d(base_channels * 4, final_chs, 1, bias=True)\n        self.inner2 = nn.Conv2d(base_channels * 2, final_chs, 1, bias=True)\n        self.inner3 = nn.Conv2d(base_channels * 1, final_chs, 1, bias=True)\n\n        self.out1 = nn.Conv2d(final_chs, base_channels * 8, 1, bias=False)\n        self.out2 = nn.Conv2d(final_chs, base_channels * 4, 3, padding=1, bias=False)\n        self.out3 = nn.Conv2d(final_chs, base_channels * 2, 3, padding=1, bias=False)\n        self.out4 = nn.Conv2d(final_chs, base_channels, 3, padding=1, bias=False)\n\n        self.dcn = dcn\n        if self.dcn:\n            self.dcn1 = NA_DCN(base_channels * 8, 3, gn=gn)\n            self.dcn2 = NA_DCN(base_channels * 4, 3, gn=gn)\n            self.dcn3 = NA_DCN(base_channels * 2, 3, gn=gn)\n            self.dcn4 = NA_DCN(base_channels * 1, 3, gn=gn)\n\n        self.out_channels.append(base_channels * 4)\n        self.out_channels.append(base_channels * 2)\n        self.out_channels.append(base_channels)\n\n    def forward(self, x):\n        conv0 = self.conv0(x)\n        conv1 = self.conv1(conv0)\n        conv2 = self.conv2(conv1)\n        conv3 = self.conv3(conv2)\n\n        intra_feat = conv3\n        outputs = {}\n        out1 = self.out1(intra_feat)\n\n        intra_feat = F.interpolate(intra_feat, scale_factor=2, mode=\"bilinear\", align_corners=True) + self.inner1(conv2)\n        out2 = self.out2(intra_feat)\n\n        intra_feat = F.interpolate(intra_feat, scale_factor=2, mode=\"bilinear\", align_corners=True) + self.inner2(conv1)\n        out3 = self.out3(intra_feat)\n\n        intra_feat = F.interpolate(intra_feat, scale_factor=2, mode=\"bilinear\", align_corners=True) + self.inner3(conv0)\n        out4 = self.out4(intra_feat)\n\n        if self.dcn:\n            out1 = self.dcn1(out1)\n            out2 = self.dcn2(out2)\n            out3 = self.dcn3(out3)\n            out4 = self.dcn4(out4)\n\n        outputs[\"stage1\"] = out1\n        outputs[\"stage2\"] = out2\n        outputs[\"stage3\"] = out3\n        outputs[\"stage4\"] = out4\n\n        return outputs\n    \nclass ASFF(nn.Module):\n    def __init__(self, level):\n        super(ASFF, self).__init__()\n        self.level = level\n        self.dim = [64,32,16,8]\n        self.inter_dim = self.dim[self.level]\n        if level==0:\n            self.stride_level_1 = Conv2d(32, 64, 3, stride=2, padding=1)\n            self.stride_level_2 = Conv2d(16, 64, 3, stride=2, padding=1)\n            self.stride_level_3 = Conv2d(8, 64, 3, stride=2, padding=1)\n            self.expand = Conv2d(64, 64, 3, stride=1, padding=1)\n        elif level==1:\n            self.compress_level_0 
=  Conv2d(64, 32, 1, stride=1, padding=0)\n            self.stride_level_2 = Conv2d(16, 32, 3, stride=2, padding=1)\n            self.stride_level_3 = Conv2d(8, 32, 3, stride=2, padding=1)\n            self.expand = Conv2d(32, 32, 3, stride=1, padding=1)\n        elif level==2:\n            self.compress_level_0 = Conv2d(64, 16, 1, stride=1, padding=0)\n            self.compress_level_1 = Conv2d(32, 16, 1, stride=1, padding=0)\n            self.stride_level_3 = Conv2d(8, 16, 3, stride=2, padding=1)\n            self.expand = Conv2d(16, 16, 3, stride=1, padding=1)\n        elif level==3:\n            self.compress_level_0 = Conv2d(64, 8, 1, stride=1, padding=0)\n            self.compress_level_1 = Conv2d(32, 8, 1, stride=1, padding=0)\n            self.compress_level_2 = Conv2d(16, 8, 1, stride=1, padding=0)\n            self.expand = Conv2d(8, 8, 3, stride=1, padding=1)\n\n        self.weight_level_0 = Conv2d(self.dim[level], 8, 1, 1, 0)\n        self.weight_level_1 = Conv2d(self.dim[level], 8, 1, 1, 0)\n        self.weight_level_2 = Conv2d(self.dim[level], 8, 1, 1, 0)\n        self.weight_level_3 = Conv2d(self.dim[level], 8, 1, 1, 0)\n\n        self.weight_levels = nn.Conv2d(32, 4, kernel_size=1, stride=1, padding=0)\n\n\n    def forward(self, x_level_0, x_level_1, x_level_2, x_level_3):\n        if self.level==0:\n            level_0_resized = x_level_0\n            level_1_resized = self.stride_level_1(x_level_1)\n            level_2_downsampled_inter = F.max_pool2d(x_level_2, 2, stride=2, padding=0)\n            level_2_resized = self.stride_level_2(level_2_downsampled_inter)\n            level_3_downsampled_inter = F.max_pool2d(x_level_3, 4, stride=4, padding=0)\n            level_3_resized = self.stride_level_3(level_3_downsampled_inter)\n\n        elif self.level==1:\n            level_0_compressed = self.compress_level_0(x_level_0)\n            level_0_resized = F.interpolate(level_0_compressed, scale_factor=2, mode='nearest')\n            level_1_resized = x_level_1\n            level_2_resized = self.stride_level_2(x_level_2)\n            level_3_downsampled_inter = F.max_pool2d(x_level_3, 2, stride=2, padding=0)\n            level_3_resized = self.stride_level_3(level_3_downsampled_inter)\n        elif self.level==2:\n            level_0_compressed = self.compress_level_0(x_level_0)\n            level_0_resized = F.interpolate(level_0_compressed, scale_factor=4, mode='nearest')\n            level_1_compressed = self.compress_level_1(x_level_1)\n            level_1_resized = F.interpolate(level_1_compressed, scale_factor=2, mode='nearest')\n            level_2_resized = x_level_2\n            level_3_resized = self.stride_level_3(x_level_3)\n        elif self.level==3:\n            level_0_compressed = self.compress_level_0(x_level_0)\n            level_0_resized = F.interpolate(level_0_compressed, scale_factor=8, mode='nearest')\n            level_1_compressed = self.compress_level_1(x_level_1)\n            level_1_resized = F.interpolate(level_1_compressed, scale_factor=4, mode='nearest')\n            level_2_compressed = self.compress_level_2(x_level_2)\n            level_2_resized = F.interpolate(level_2_compressed, scale_factor=2, mode='nearest')\n            level_3_resized = x_level_3\n\n        level_0_weight_v = self.weight_level_0(level_0_resized)\n        level_1_weight_v = self.weight_level_1(level_1_resized)\n        level_2_weight_v = self.weight_level_2(level_2_resized)\n        level_3_weight_v = self.weight_level_3(level_3_resized)\n        levels_weight_v = 
torch.cat((level_0_weight_v, level_1_weight_v, level_2_weight_v, level_3_weight_v),1)\n        levels_weight = self.weight_levels(levels_weight_v)\n        levels_weight = F.softmax(levels_weight, dim=1)\n\n        fused_out_reduced = level_0_resized * levels_weight[:,0:1,:,:]+\\\n                            level_1_resized * levels_weight[:,1:2,:,:]+\\\n                            level_2_resized * levels_weight[:,2:3,:,:]+\\\n                            level_3_resized * levels_weight[:,3:,:,:]\n\n        out = self.expand(fused_out_reduced)\n\n        return out\n\nclass FullImageEncoder(nn.Module):\n    def __init__(self, h, w, kernel_size):\n        super(FullImageEncoder, self).__init__()\n        self.global_pooling = nn.AvgPool2d(kernel_size, stride=kernel_size, padding=kernel_size // 2)  # KITTI 16 16\n        self.dropout = nn.Dropout2d(p=0.5)\n        self.h = h // kernel_size + 1\n        self.w = w // kernel_size + 1\n        # print(\"h=\", self.h, \" w=\", self.w, h, w)\n        self.global_fc = nn.Linear(2048 * self.h * self.w, 512)  # kitti 4x5\n        self.relu = nn.ReLU(inplace=True)\n        self.conv1 = nn.Conv2d(512, 512, 1)  # 1x1 convolution\n\n    def forward(self, x):\n        # print('x size:', x.size())\n        x1 = self.global_pooling(x)\n        # print('# x1 size:', x1.size())\n        x2 = self.dropout(x1)\n        x3 = x2.view(-1, 2048 * self.h * self.w)  # kitti 4x5\n        x4 = self.relu(self.global_fc(x3))\n        # print('# x4 size:', x4.size())\n        x4 = x4.view(-1, 512, 1, 1)\n        # print('# x4 size:', x4.size())\n        x5 = self.conv1(x4)\n        # out = self.upsample(x5)\n        return x5\n\nclass mono_depth_decoder(nn.Module):\n\n    def __init__(self):\n        super(mono_depth_decoder, self).__init__()\n        self.convblocks = nn.ModuleList(\n            [Conv2d(64, 32, 3, 1, padding=1),\n            Conv2d(32, 16, 3, 1, padding=1),\n            Conv2d(16, 8, 3, 1, padding=1)]\n        )\n        self.conv3x3 = nn.ModuleList(\n           [nn.Conv2d(64, 1, 3, 1, 1),\n            nn.Conv2d(32, 1, 3, 1, 1),\n            nn.Conv2d(16, 1, 3, 1, 1)]\n        )\n        self.sigmoid = nn.Sigmoid()\n\n    def forward(self, outputs, d_min, d_max):\n        \"\"\"\n        d_min, d_max: B (per-batch depth range used to scale the sigmoid disparity)\n        \"\"\"\n        for i in range(1,4):  # 1 2 3\n            mono_small_feat = outputs['stage{}'.format(i)]['mono_feat']\n            mono_large_feat = outputs['stage{}'.format(i+1)]['mono_feat']\n\n            mono_small_feat = self.convblocks[i-1](mono_small_feat)\n            mono_small_feat = F.interpolate(mono_small_feat, scale_factor=2, mode=\"nearest\")\n\n            mono_feat = self.conv3x3[i-1](torch.cat([mono_small_feat, mono_large_feat], 1))  # B C H W\n\n            disp = self.sigmoid(mono_feat)\n            min_disp = (1 / d_max)[:,None,None,None]  # B 1 1 1\n            max_disp = (1 / d_min)[:,None,None,None]\n            scaled_disp = min_disp + (max_disp - min_disp) * disp\n            depth = 1 / scaled_disp\n            outputs['stage{}'.format(i+1)]['mono_depth'] = depth.squeeze(1)\n        return outputs\n\nclass reg2d(nn.Module):\n    def __init__(self, input_channel=128, base_channel=32, conv_name='ConvBnReLU3D'):\n        super(reg2d, self).__init__()\n        module = importlib.import_module(\"models.mvs4net_utils\")\n        stride_conv_name = 'ConvBnReLU3D'\n        self.conv0 = getattr(module, stride_conv_name)(input_channel, base_channel, kernel_size=(1,3,3), pad=(0,1,1))\n        self.conv1 = getattr(module, 
stride_conv_name)(base_channel, base_channel*2, kernel_size=(1,3,3), stride=(1,2,2), pad=(0,1,1))\n        self.conv2 = getattr(module, conv_name)(base_channel*2, base_channel*2)\n\n        self.conv3 = getattr(module, stride_conv_name)(base_channel*2, base_channel*4, kernel_size=(1,3,3), stride=(1,2,2), pad=(0,1,1))\n        self.conv4 = getattr(module, conv_name)(base_channel*4, base_channel*4)\n\n        self.conv5 = getattr(module, stride_conv_name)(base_channel*4, base_channel*8, kernel_size=(1,3,3), stride=(1,2,2), pad=(0,1,1))\n        self.conv6 = getattr(module, conv_name)(base_channel*8, base_channel*8)\n\n        self.conv7 = nn.Sequential(\n            nn.ConvTranspose3d(base_channel*8, base_channel*4, kernel_size=(1,3,3), padding=(0,1,1), output_padding=(0,1,1), stride=(1,2,2), bias=False),\n            nn.BatchNorm3d(base_channel*4),\n            nn.ReLU(inplace=True))\n\n        self.conv9 = nn.Sequential(\n            nn.ConvTranspose3d(base_channel*4, base_channel*2, kernel_size=(1,3,3), padding=(0,1,1), output_padding=(0,1,1), stride=(1,2,2), bias=False),\n            nn.BatchNorm3d(base_channel*2),\n            nn.ReLU(inplace=True))\n\n        self.conv11 = nn.Sequential(\n            nn.ConvTranspose3d(base_channel*2, base_channel, kernel_size=(1,3,3), padding=(0,1,1), output_padding=(0,1,1), stride=(1,2,2), bias=False),\n            nn.BatchNorm3d(base_channel),\n            nn.ReLU(inplace=True))\n\n        self.prob = nn.Conv3d(base_channel, 1, 1, stride=1, padding=0)  # conv11 outputs base_channel features (was hard-coded to 8)\n\n    def forward(self, x):\n        conv0 = self.conv0(x)\n        conv2 = self.conv2(self.conv1(conv0))\n        conv4 = self.conv4(self.conv3(conv2))\n        x = self.conv6(self.conv5(conv4))\n        x = conv4 + self.conv7(x)\n        x = conv2 + self.conv9(x)\n        x = conv0 + self.conv11(x)\n        x = self.prob(x)\n\n        return x.squeeze(1)\n\nclass reg3d(nn.Module):\n    def __init__(self, in_channels, base_channels, down_size=3):\n        super(reg3d, self).__init__()\n        self.down_size = down_size\n        self.conv0 = ConvBnReLU3D(in_channels, base_channels, kernel_size=3, pad=1)\n        self.conv1 = ConvBnReLU3D(base_channels, base_channels*2, kernel_size=3, stride=2, pad=1)\n        self.conv2 = ConvBnReLU3D(base_channels*2, base_channels*2)\n        if down_size >= 2:\n            self.conv3 = ConvBnReLU3D(base_channels*2, base_channels*4, kernel_size=3, stride=2, pad=1)\n            self.conv4 = ConvBnReLU3D(base_channels*4, base_channels*4)\n        if down_size >= 3:\n            self.conv5 = ConvBnReLU3D(base_channels*4, base_channels*8, kernel_size=3, stride=2, pad=1)\n            self.conv6 = ConvBnReLU3D(base_channels*8, base_channels*8)\n            self.conv7 = nn.Sequential(\n                nn.ConvTranspose3d(base_channels*8, base_channels*4, kernel_size=3, padding=1, output_padding=1, stride=2, bias=False),\n                nn.BatchNorm3d(base_channels*4),\n                nn.ReLU(inplace=True))\n        if down_size >= 2:\n            self.conv9 = nn.Sequential(\n                nn.ConvTranspose3d(base_channels*4, base_channels*2, kernel_size=3, padding=1, output_padding=1, stride=2, bias=False),\n                nn.BatchNorm3d(base_channels*2),\n                nn.ReLU(inplace=True))\n\n        self.conv11 = nn.Sequential(\n            nn.ConvTranspose3d(base_channels*2, base_channels, kernel_size=3, padding=1, output_padding=1, stride=2, bias=False),\n            nn.BatchNorm3d(base_channels),\n            nn.ReLU(inplace=True))\n        self.prob = 
nn.Conv3d(base_channels, 1, 3, stride=1, padding=1, bias=False)\n\n    def forward(self, x):\n        if self.down_size==3:\n            conv0 = self.conv0(x)\n            conv2 = self.conv2(self.conv1(conv0))\n            conv4 = self.conv4(self.conv3(conv2))\n            x = self.conv6(self.conv5(conv4))\n            x = conv4 + self.conv7(x)\n            x = conv2 + self.conv9(x)\n            x = conv0 + self.conv11(x)\n            x = self.prob(x)\n        elif self.down_size==2:\n            conv0 = self.conv0(x)\n            conv2 = self.conv2(self.conv1(conv0))\n            x = self.conv4(self.conv3(conv2))\n            x = conv2 + self.conv9(x)\n            x = conv0 + self.conv11(x)\n            x = self.prob(x)\n        else:\n            conv0 = self.conv0(x)\n            x = self.conv2(self.conv1(conv0))\n            x = conv0 + self.conv11(x)\n            x = self.prob(x)\n        return x.squeeze(1)  # B D H W\n\nclass PosEncSine(nn.Module):\n\n    def __init__(self, temperature=1000):\n        super(PosEncSine, self).__init__()\n        self.temperature = temperature\n\n    def forward(self, x, depth):\n        # depth : B D H W\n        with torch.no_grad():\n            B,C,D,H,W = x.shape\n            depth = depth.permute(0,2,3,1).reshape(B*H*W, D) / self.temperature  # BHW D\n            pos = torch.stack([torch.sin(i * math.pi * depth) for i in range(C//2)] + [torch.cos(i * math.pi * depth) for i in range(C//2)], dim=-1)  # BHW,D,C\n            pos = pos.reshape(B,H,W,D,C).permute(0,4,3,1,2)  # B C D H W\n        x = x + pos\n        return x\n\nclass PosEncLearned(nn.Module):\n    \"\"\"\n    Absolute pos embedding, learned.\n    \"\"\"\n    def __init__(self, D, C):\n        super().__init__()\n        self.D = D\n        self.C = C\n        self.depth_embed = nn.Parameter(torch.Tensor(C, self.D))\n        self.reset_parameters()\n\n    def reset_parameters(self):\n        nn.init.uniform_(self.depth_embed)\n\n    def forward(self, x, **kwargs):\n        B,C,D,H,W = x.shape\n        pos = self.depth_embed[None,:,:,None,None].repeat(B,1,1,H,W)  # B C D H W\n        x = x + pos\n        return x\n\nclass stagenet(nn.Module):\n    def __init__(self, inverse_depth=False, mono=False, attn_fuse_d=True, vis_ETA=False, attn_temp=1):\n        super(stagenet, self).__init__()\n        self.inverse_depth = inverse_depth\n        self.mono = mono\n        self.attn_fuse_d = attn_fuse_d\n        self.vis_ETA = vis_ETA\n        self.attn_temp = attn_temp\n\n    def forward(self, features, proj_matrices, depth_hypo, regnet, stage_idx, group_cor=False, group_cor_dim=8, split_itv=1, fn=None):\n\n        # step 1. feature extraction\n        proj_matrices = torch.unbind(proj_matrices, 1)\n        ref_feature, src_features = features[0], features[1:]\n        ref_proj, src_projs = proj_matrices[0], proj_matrices[1:]\n        B,D,H,W = depth_hypo.shape\n        C = ref_feature.shape[1]\n\n        ref_volume =  ref_feature.unsqueeze(2).repeat(1, 1, D, 1, 1)\n        cor_weight_sum = 1e-8\n        cor_feats = 0\n        # step 2. 
Epipolar Transformer Aggregation\n        for src_idx, (src_fea, src_proj) in enumerate(zip(src_features, src_projs)):\n            if self.vis_ETA:\n                scan_name = fn[0].split('/')[0]\n                image_name = fn[0].split('/')[2][:-2]\n                save_fn = './debug_figs/vis_ETA/{}_stage{}_src{}'.format(scan_name+'_'+image_name, stage_idx, src_idx)\n            else:\n                save_fn = None\n            src_proj_new = src_proj[:, 0].clone()\n            src_proj_new[:, :3, :4] = torch.matmul(src_proj[:, 1, :3, :3], src_proj[:, 0, :3, :4])\n            ref_proj_new = ref_proj[:, 0].clone()\n            ref_proj_new[:, :3, :4] = torch.matmul(ref_proj[:, 1, :3, :3], ref_proj[:, 0, :3, :4])\n            warped_src = homo_warping(src_fea, src_proj_new, ref_proj_new, depth_hypo, self.vis_ETA, save_fn)  # B C D H W\n            if group_cor:\n                warped_src = warped_src.reshape(B, group_cor_dim, C//group_cor_dim, D, H, W)\n                ref_volume = ref_volume.reshape(B, group_cor_dim, C//group_cor_dim, D, H, W)\n                cor_feat = (warped_src * ref_volume).mean(2)  # B G D H W\n            else:\n                cor_feat = (ref_volume - warped_src)**2 # B C D H W \n            del warped_src, src_proj, src_fea\n            if self.vis_ETA:\n                vis_weight = torch.softmax(cor_feat.sum(1), 1).detach().cpu().numpy()\n                np.save(save_fn, vis_weight)\n\n            if not self.attn_fuse_d:\n                cor_weight = torch.softmax(cor_feat.sum(1), 1).max(1)[0]  # B H W\n                cor_weight_sum += cor_weight  # B H W\n                cor_feats += cor_weight.unsqueeze(1).unsqueeze(1) * cor_feat  # B C D H W\n            else:\n                cor_weight = torch.softmax(cor_feat.sum(1) / self.attn_temp, 1) / math.sqrt(C)  # B D H W\n                cor_weight_sum += cor_weight  # B D H W\n                cor_feats += cor_weight.unsqueeze(1) * cor_feat  # B C D H W\n            del cor_weight, cor_feat\n        if not self.attn_fuse_d:\n            cor_feats = cor_feats / cor_weight_sum.unsqueeze(1).unsqueeze(1)  # B C D H W\n        else:\n            cor_feats = cor_feats / cor_weight_sum.unsqueeze(1)  # B C D H W\n\n        del cor_weight_sum, src_features\n        \n    \n        # step 3. regularization\n        attn_weight = regnet(cor_feats)  # B D H W\n        del cor_feats\n        attn_weight = F.softmax(attn_weight, dim=1)  # B D H W\n\n        # step 4. 
depth argmax\n        attn_max_indices = attn_weight.max(1, keepdim=True)[1]  # B 1 H W\n        depth = torch.gather(depth_hypo, 1, attn_max_indices).squeeze(1)  # B H W\n\n        if not self.training:\n            with torch.no_grad():\n                photometric_confidence = attn_weight.max(1)[0]  # B H W\n                photometric_confidence = F.interpolate(photometric_confidence.unsqueeze(1), scale_factor=2**(3-stage_idx), mode='bilinear', align_corners=True).squeeze(1)\n        else:\n            photometric_confidence = torch.tensor(0.0, dtype=torch.float32, device=ref_feature.device, requires_grad=False)\n        \n        ret_dict = {\"depth\": depth,  \"photometric_confidence\": photometric_confidence, \"hypo_depth\": depth_hypo, \"attn_weight\": attn_weight}\n        \n        if self.inverse_depth:\n            last_depth_itv = 1./depth_hypo[:,2,:,:] - 1./depth_hypo[:,1,:,:]\n            inverse_min_depth = 1/depth + split_itv * last_depth_itv  # B H W\n            inverse_max_depth = 1/depth - split_itv * last_depth_itv  # B H W\n            ret_dict['inverse_min_depth'] = inverse_min_depth\n            ret_dict['inverse_max_depth'] = inverse_max_depth\n\n        # if self.mono and self.training:\n        if self.mono:\n            ret_dict['mono_feat'] = ref_feature  # B C H W\n            \n        return ret_dict\n \ndef sinkhorn(gt_depth, hypo_depth, attn_weight, mask, iters, eps=1, continuous=False):\n    \"\"\"\n    gt_depth: B H W\n    hypo_depth: B D H W\n    attn_weight: B D H W\n    mask: B H W\n    \"\"\"\n    B,D,H,W = attn_weight.shape\n    if not continuous:\n        D_map = torch.stack([torch.arange(-i,D-i,1, dtype=torch.float32, device=gt_depth.device) for i in range(D)], dim=1).abs()\n        D_map = D_map[None,None,:,:].repeat(B,H*W,1,1)  # B HW D D\n        gt_indices = torch.abs(hypo_depth - gt_depth[:,None,:,:]).min(1)[1].squeeze(1).reshape(B*H*W, 1)  # BHW, 1\n        gt_dist = torch.zeros_like(hypo_depth).permute(0,2,3,1).reshape(B*H*W, D)\n        gt_dist.scatter_add_(1,gt_indices,torch.ones([gt_dist.shape[0],1], dtype=gt_dist.dtype, device=gt_dist.device))\n        gt_dist = gt_dist.reshape(B,H*W,D)  # B HW D\n    else:\n        gt_dist = torch.zeros((B,H*W,D+1), dtype=torch.float32, device=gt_depth.device, requires_grad=False)  # B HW D+1\n        gt_dist[:,:,-1] = 1\n        D_map = torch.zeros((B,D,D+1), dtype=torch.float32, device=gt_depth.device, requires_grad=False)  # B D D+1\n        D_map[:, :D, :D] = torch.stack([torch.arange(-i,D-i,1, dtype=torch.float32, device=gt_depth.device) for i in range(D)], dim=1).abs().unsqueeze(0)  # B D D+1\n        D_map = D_map[:,None,None,:,:].repeat(1,H,W,1,1)  # B H W D D+1\n        itv = 1/hypo_depth[:,2,:,:] - 1/hypo_depth[:,1,:,:]  # B H W\n        gt_bin_distance_ = (1/gt_depth - 1/hypo_depth[:,0,:,:]) / itv  # B H W\n        #FIXME hard code 100\n        gt_bin_distance_[~mask] = 10\n\n        gt_bin_distance = torch.stack([(gt_bin_distance_ - i).abs() for i in range(D)], dim=1).permute(0,2,3,1)  # B H W D\n        D_map[:,:,:,:,-1] = gt_bin_distance\n        D_map = D_map.reshape(B,H*W,D,1+D)  # B HW D D+1\n\n    pred_dist = attn_weight.permute(0,2,3,1).reshape(B,H*W,D)  # B HW D\n\n    # map to log space for stability\n    log_mu = (gt_dist+1e-12).log()\n    log_nu = (pred_dist+1e-12).log()  # B HW D or D+1\n\n    u, v = torch.zeros_like(log_nu), torch.zeros_like(log_mu)\n    for _ in range(iters):\n        # scale v first then u to ensure row sum is 1, col sum slightly larger than 1\n        v = 
log_mu - torch.logsumexp(D_map/eps + u.unsqueeze(3), dim=2)  # log(sum(exp()))\n        u = log_nu - torch.logsumexp(D_map/eps + v.unsqueeze(2), dim=3)\n\n    # convert back from log space to recover the transport plan T_map\n    T_map = (D_map/eps + u.unsqueeze(3) + v.unsqueeze(2)).exp()  # B HW D D\n    loss = (T_map * D_map).reshape(B*H*W,-1)[mask.reshape(-1)].sum(-1).mean()\n    \n    return T_map, loss"
  },
  {
    "path": "requirements.txt",
    "content": "torch==1.9.0\ntorchvision==0.10.0\nnumpy\npillow\ntensorboardX\nopencv-python\nplyfile"
  },
  {
    "path": "scripts/test_dtu.sh",
    "content": "#!/usr/bin/env bash\nDTU_TESTPATH=\"/mnt/cfs/algorithm/public_data/mvs/dtu_test\"\nDTU_TESTLIST=\"lists/dtu/test.txt\"\n\nDTU_size=$1\nexp=$2\nPY_ARGS=${@:3}\n\nDTU_LOG_DIR=\"./checkpoints/dtu/\"$exp \nif [ ! -d $DTU_LOG_DIR ]; then\n    mkdir -p $DTU_LOG_DIR\nfi\nDTU_CKPT_FILE=$DTU_LOG_DIR\"/finalmodel.ckpt\"\nDTU_OUT_DIR=\"./outputs/dtu/\"$exp\n\n\n\nif [ $DTU_size = \"raw\" ] ; then\npython test_mvs4.py --dataset=general_eval4 --batch_size=1 --testpath=$DTU_TESTPATH  --testlist=$DTU_TESTLIST --loadckpt $DTU_CKPT_FILE --interval_scale 1.06 --outdir $DTU_OUT_DIR\\\n             --use_raw_train --thres_view 4 --conf 0.5 --group_cor --attn_temp 2 --inverse_depth $PY_ARGS | tee -a $DTU_LOG_DIR/log_test.txt\nelse\npython test_mvs4.py --dataset=general_eval4 --batch_size=1 --testpath=$DTU_TESTPATH  --testlist=$DTU_TESTLIST --loadckpt $DTU_CKPT_FILE --interval_scale 1.06 --outdir $DTU_OUT_DIR\\\n             --thres_view 4 --conf 0.5 --group_cor --attn_temp 2 --inverse_depth $PY_ARGS | tee -a $DTU_LOG_DIR/log_test.txt\nfi\n"
  },
  {
    "path": "scripts/train_dtu.sh",
    "content": "#!/usr/bin/env bash\nDTU_TRAINING=\"/mnt/cfs/algorithm/public_data/mvs/mvs_training/dtu\"\nDTU_TRAINLIST=\"lists/dtu/train.txt\"\nDTU_TESTLIST=\"lists/dtu/test.txt\"\n\nDTU_trainsize=$1\nexp=$2\nPY_ARGS=${@:3}\n\nDTU_LOG_DIR=\"./checkpoints/dtu/\"$exp \nif [ ! -d $DTU_LOG_DIR ]; then\n    mkdir -p $DTU_LOG_DIR\nfi\n\nDTU_CKPT_FILE=$DTU_LOG_DIR\"/finalmodel.ckpt\"\nDTU_OUT_DIR=\"./outputs/dtu/\"$exp\n\n\nif [ $DTU_trainsize = \"raw\" ] ; then\npython -m torch.distributed.launch --nproc_per_node=4 train_mvs4.py --logdir $DTU_LOG_DIR --dataset=dtu_yao4 --batch_size=2 --trainpath=$DTU_TRAINING --summary_freq 100 \\\n                --group_cor --inverse_depth --rt --mono --attn_temp 2 --use_raw_train --trainlist $DTU_TRAINLIST --testlist $DTU_TESTLIST  $PY_ARGS | tee -a $DTU_LOG_DIR/log.txt\nelse\npython -m torch.distributed.launch --nproc_per_node=4 train_mvs4.py --logdir $DTU_LOG_DIR --dataset=dtu_yao4 --batch_size=2 --trainpath=$DTU_TRAINING --summary_freq 100 \\\n                --group_cor --inverse_depth --rt --mono --attn_temp 2 --trainlist $DTU_TRAINLIST --testlist $DTU_TESTLIST  $PY_ARGS | tee -a $DTU_LOG_DIR/log.txt\nfi\n\n"
  },
  {
    "path": "test_mvs4.py",
    "content": "import argparse, os, time, sys, gc, cv2\nimport torch\nimport torch.nn as nn\nimport torch.nn.parallel\nimport torch.backends.cudnn as cudnn\nfrom torch.utils.data import DataLoader\nimport torch.nn.functional as F\nimport numpy as np\nfrom datasets import find_dataset_def\nfrom models import *\nfrom utils import *\nfrom datasets.data_io import read_pfm, save_pfm\nfrom plyfile import PlyData, PlyElement\nfrom PIL import Image\n\nfrom multiprocessing import Pool\nfrom functools import partial\nimport signal\n\ncudnn.benchmark = True\n\nparser = argparse.ArgumentParser(description='Predict depth, filter, and fuse')\nparser.add_argument('--model', default='mvsnet', help='select model')\n\nparser.add_argument('--dataset', default='dtu_yao_eval', help='select dataset')\nparser.add_argument('--testpath', help='testing data dir for some scenes')\nparser.add_argument('--testlist', help='testing scene list')\n\nparser.add_argument('--batch_size', type=int, default=1, help='testing batch size')\n\nparser.add_argument('--loadckpt', default=None, help='load a specific checkpoint')\nparser.add_argument('--outdir', default='./outputs', help='output dir')\n\nparser.add_argument('--share_cr', action='store_true', help='whether share the cost volume regularization')\n\nparser.add_argument('--ndepths', type=str, default=\"8,8,4,4\", help='ndepths')\nparser.add_argument('--depth_inter_r', type=str, default=\"0.5,0.5,0.5,1\", help='depth_intervals_ratio')\n\nparser.add_argument('--interval_scale', type=float, required=True, help='the depth interval scale')\nparser.add_argument('--num_view', type=int, default=5, help='num of view')\nparser.add_argument('--max_h', type=int, default=864, help='testing max h')\nparser.add_argument('--max_w', type=int, default=1152, help='testing max w')\nparser.add_argument('--fix_res', action='store_true', help='scene all using same res')\n\nparser.add_argument('--num_worker', type=int, default=4, help='depth_filer worker')\nparser.add_argument('--save_freq', type=int, default=20, help='save freq of local pcd')\n\nparser.add_argument('--filter_method', type=str, default='normal', choices=[\"gipuma\", \"normal\"], help=\"filter method\")\n\n#filter\nparser.add_argument('--conf', type=float, default=0.9, help='prob confidence')\nparser.add_argument('--thres_view', type=int, default=5, help='threshold of num view')\n\nparser.add_argument(\"--fpn_base_channel\", type=int, default=8)\nparser.add_argument(\"--reg_channel\", type=int, default=8)\nparser.add_argument('--reg_mode', type=str, default=\"reg2d\")\nparser.add_argument('--dlossw', type=str, default=\"1,1,1,1\", help='depth loss weight for different stage')\nparser.add_argument('--resume', action='store_true', help='continue to train the model')\nparser.add_argument('--group_cor', action='store_true',help='group correlation')\nparser.add_argument('--group_cor_dim', type=str, default=\"8,8,4,4\", help='group correlation dim')\nparser.add_argument('--inverse_depth', action='store_true',help='inverse depth')\nparser.add_argument('--agg_type', type=str, default=\"ConvBnReLU3D\", help='cost regularization type')\nparser.add_argument('--dcn', action='store_true',help='dcn')\nparser.add_argument('--arch_mode', type=str, default=\"fpn\")\nparser.add_argument('--ot_continous', action='store_true',help='optimal transport continous gt bin')\nparser.add_argument('--ot_eps', type=float, default=1)\nparser.add_argument('--ot_iter', type=int, default=0)\nparser.add_argument('--rt', action='store_true',help='robust 
training')\nparser.add_argument('--use_raw_train', action='store_true',help='use 1200x1600 training resolution')\nparser.add_argument('--mono', action='store_true',help='build the monocular depth prediction branch and its loss')\nparser.add_argument('--split', type=str, default='intermediate', help='intermediate or advanced')\nparser.add_argument('--save_jpg', action='store_true')\nparser.add_argument('--ASFF', action='store_true')\nparser.add_argument('--vis_ETA', action='store_true')\nparser.add_argument('--vis_mono', action='store_true')\nparser.add_argument('--attn_temp', type=float, default=2)\n\n# parse arguments and check\nargs = parser.parse_args()\nprint(\"argv:\", sys.argv[1:])\nprint_args(args)\n\nif args.use_raw_train:\n    args.max_h = 1200\n    args.max_w = 1600\n    \nnum_stage = len([int(nd) for nd in args.ndepths.split(\",\") if nd])\n\nInterval_Scale = args.interval_scale\nprint(\"***********Interval_Scale**********\\n\", Interval_Scale)\n\n\n# read intrinsics and extrinsics\ndef read_camera_parameters(filename):\n    with open(filename) as f:\n        lines = f.readlines()\n        lines = [line.rstrip() for line in lines]\n    # extrinsics: line [1,5), 4x4 matrix\n    extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4))\n    # intrinsics: line [7-10), 3x3 matrix\n    intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3))\n    return intrinsics, extrinsics\n\n\n# read an image\ndef read_img(filename):\n    img = Image.open(filename)\n    # scale 0~255 to 0~1\n    np_img = np.array(img, dtype=np.float32) / 255.\n    return np_img\n\n\n# read a binary mask\ndef read_mask(filename):\n    return read_img(filename) > 0.5\n\n\n# save a binary mask\ndef save_mask(filename, mask):\n    assert mask.dtype == np.bool_  # np.bool was removed from modern NumPy; use np.bool_\n    mask = mask.astype(np.uint8) * 255\n    Image.fromarray(mask).save(filename)\n\n\n# read a pair file, [(ref_view1, [src_view1-1, ...]), (ref_view2, [src_view2-1, ...]), ...]\ndef read_pair_file(filename):\n    data = []\n    with open(filename) as f:\n        num_viewpoint = int(f.readline())\n        # 49 viewpoints\n        for view_idx in range(num_viewpoint):\n            ref_view = int(f.readline().rstrip())\n            src_views = [int(x) for x in f.readline().rstrip().split()[1::2]]\n            if len(src_views) > 0:\n                data.append((ref_view, src_views))\n    return data\n\ndef write_cam(file, cam):\n    f = open(file, \"w\")\n    f.write('extrinsic\\n')\n    for i in range(0, 4):\n        for j in range(0, 4):\n            f.write(str(cam[0][i][j]) + ' ')\n        f.write('\\n')\n    f.write('\\n')\n\n    f.write('intrinsic\\n')\n    for i in range(0, 3):\n        for j in range(0, 3):\n            f.write(str(cam[1][i][j]) + ' ')\n        f.write('\\n')\n\n    f.write('\\n' + str(cam[1][3][0]) + ' ' + str(cam[1][3][1]) + ' ' + str(cam[1][3][2]) + ' ' + str(cam[1][3][3]) + '\\n')\n\n    f.close()\n\ndef save_depth(testlist):\n    torch.cuda.reset_peak_memory_stats()\n    total_time = 0\n    total_sample = 0\n    for scene in testlist:\n        time_this_scene, sample_this_scene = save_scene_depth([scene])\n        total_time += time_this_scene\n        total_sample += sample_this_scene\n    gpu_measure = torch.cuda.max_memory_allocated() / 1024. / 1024. /1024.    
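  # peak allocated GPU memory over all scenes, converted from bytes to GiB (printed below)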
\n    print('avg time: {}'.format(total_time/total_sample))\n    print('max gpu: {}'.format(gpu_measure))\n\n\ndef save_scene_depth(testlist):\n    # dataset, dataloader\n    MVSDataset = find_dataset_def(args.dataset)\n    test_dataset = MVSDataset(args.testpath, testlist, \"test\", args.num_view, Interval_Scale,\n                            max_h=args.max_h, max_w=args.max_w, fix_res=args.fix_res)\n    TestImgLoader = DataLoader(test_dataset, args.batch_size, shuffle=False, num_workers=4, drop_last=False)\n\n    # model\n    model = MVS4net(arch_mode=args.arch_mode, reg_net=args.reg_mode, num_stage=4, \n                    fpn_base_channel=args.fpn_base_channel, reg_channel=args.reg_channel, \n                    stage_splits=[int(n) for n in args.ndepths.split(\",\")], \n                    depth_interals_ratio=[float(ir) for ir in args.depth_inter_r.split(\",\")],\n                    group_cor=args.group_cor, group_cor_dim=[int(n) for n in args.group_cor_dim.split(\",\")],\n                    inverse_depth=args.inverse_depth,\n                    agg_type=args.agg_type,\n                    dcn=args.dcn,\n                    mono=args.mono,\n                    asff=args.ASFF,\n                    attn_temp=args.attn_temp,\n                    vis_ETA=args.vis_ETA,\n                    vis_mono=args.vis_mono\n                )\n    # load checkpoint file specified by args.loadckpt\n    print(\"loading model {}\".format(args.loadckpt))\n    state_dict = torch.load(args.loadckpt, map_location=torch.device(\"cpu\"))\n    model.load_state_dict(state_dict['model'], strict=True)\n    model = nn.DataParallel(model)\n    model.cuda()\n    model.eval()\n    \n    total_time = 0\n    with torch.no_grad():\n        for batch_idx, sample in enumerate(TestImgLoader):\n            sample_cuda = tocuda(sample)\n            start_time = time.time()\n            outputs = model(sample_cuda[\"imgs\"], sample_cuda[\"proj_matrices\"], sample_cuda[\"depth_values\"], sample[\"filename\"])\n            end_time = time.time()\n            total_time += end_time - start_time\n            outputs = tensor2numpy(outputs)\n            del sample_cuda\n            filenames = sample[\"filename\"]\n            cams = sample[\"proj_matrices\"][\"stage{}\".format(num_stage)].numpy()\n            imgs = sample[\"imgs\"]\n            print('Iter {}/{}, Time:{} Res:{}'.format(batch_idx, len(TestImgLoader), end_time - start_time, imgs[0].shape))\n\n            # save depth maps and confidence maps\n            for filename, cam, img, depth_est, photometric_confidence in zip(filenames, cams, imgs, \\\n                                                            outputs[\"depth\"], outputs[\"photometric_confidence\"]):\n                img = img[0].numpy()  #ref view\n                cam = cam[0]  #ref cam\n                depth_filename = os.path.join(args.outdir, filename.format('depth_est', '.pfm'))\n                confidence_filename = os.path.join(args.outdir, filename.format('confidence', '.pfm'))\n                cam_filename = os.path.join(args.outdir, filename.format('cams', '_cam.txt'))\n                img_filename = os.path.join(args.outdir, filename.format('images', '.jpg'))\n                ply_filename = os.path.join(args.outdir, filename.format('ply_local', '.ply'))\n                os.makedirs(depth_filename.rsplit('/', 1)[0], exist_ok=True)\n                os.makedirs(confidence_filename.rsplit('/', 1)[0], exist_ok=True)\n                os.makedirs(cam_filename.rsplit('/', 1)[0], exist_ok=True)\n      
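# one output subfolder per type (depth_est, confidence, cams, images, ply_local) for this scan\n      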
          os.makedirs(img_filename.rsplit('/', 1)[0], exist_ok=True)\n                os.makedirs(ply_filename.rsplit('/', 1)[0], exist_ok=True)\n                #save depth maps\n                save_pfm(depth_filename, depth_est)\n                if args.save_jpg:\n                    for stage_idx in range(4):\n                        depth_jpg_filename = os.path.join(args.outdir, filename.format('depth_est', '{}_{}.jpg'.format('stage',str(stage_idx+1))))\n                        stage_depth = outputs['stage{}'.format(stage_idx+1)]['depth'][0]\n                        mi = np.min(stage_depth[stage_depth>0])\n                        ma = np.max(stage_depth)\n                        depth = (stage_depth-mi)/(ma-mi+1e-8)\n                        depth = (255*depth).astype(np.uint8)\n                        depth_img = cv2.applyColorMap(depth, cv2.COLORMAP_JET)\n                        print(cv2.imwrite(depth_jpg_filename, depth_img))\n                        if stage_idx == 0:\n                            continue\n                        mono_depth_jpg_filename = os.path.join(args.outdir, filename.format('depth_est', '{}_{}.jpg'.format('mono',str(stage_idx+1))))\n                        stage_mono_depth = outputs['stage{}'.format(stage_idx+1)]['mono_depth'][0]\n                        mi = np.min(stage_mono_depth[stage_mono_depth>0])\n                        ma = np.max(stage_mono_depth)\n                        depth = (stage_mono_depth-mi)/(ma-mi+1e-8)\n                        depth = (255*depth).astype(np.uint8)\n                        depth_img = cv2.applyColorMap(depth, cv2.COLORMAP_JET)\n                        print(cv2.imwrite(mono_depth_jpg_filename, depth_img))\n                #save confidence maps\n                confidence_list = [outputs['stage{}'.format(i)]['photometric_confidence'].squeeze(0) for i in range(1,5)]\n\n                photometric_confidence = confidence_list[-1]  # H W\n                save_pfm(confidence_filename, photometric_confidence) \n                #save cams, img\n                write_cam(cam_filename, cam)\n                img = np.clip(np.transpose(img, (1, 2, 0)) * 255, 0, 255).astype(np.uint8)\n                img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)\n                cv2.imwrite(img_filename, img_bgr)\n\n                if batch_idx % args.save_freq == 0:\n                    generate_pointcloud(img, depth_est, ply_filename, cam[1, :3, :3])\n\n    torch.cuda.empty_cache()\n    gc.collect()\n    return total_time, len(TestImgLoader)\n\n\n\n# project the reference point cloud into the source view, then project back\ndef reproject_with_depth(depth_ref, intrinsics_ref, extrinsics_ref, depth_src, intrinsics_src, extrinsics_src):\n    width, height = depth_ref.shape[1], depth_ref.shape[0]\n    ## step1. project reference pixels to the source view\n    # reference view x, y\n    x_ref, y_ref = np.meshgrid(np.arange(0, width), np.arange(0, height))\n    x_ref, y_ref = x_ref.reshape([-1]), y_ref.reshape([-1])\n    # reference 3D space\n    xyz_ref = np.matmul(np.linalg.inv(intrinsics_ref),\n                        np.vstack((x_ref, y_ref, np.ones_like(x_ref))) * depth_ref.reshape([-1]))\n    # source 3D space\n    xyz_src = np.matmul(np.matmul(extrinsics_src, np.linalg.inv(extrinsics_ref)),\n                        np.vstack((xyz_ref, np.ones_like(x_ref))))[:3]\n    # source view x, y\n    K_xyz_src = np.matmul(intrinsics_src, xyz_src)\n    xy_src = K_xyz_src[:2] / K_xyz_src[2:3]\n\n    ## step2. 
reproject the source view points with source view depth estimation\n    # find the depth estimation of the source view\n    x_src = xy_src[0].reshape([height, width]).astype(np.float32)\n    y_src = xy_src[1].reshape([height, width]).astype(np.float32)\n    sampled_depth_src = cv2.remap(depth_src, x_src, y_src, interpolation=cv2.INTER_LINEAR)\n    # mask = sampled_depth_src > 0\n\n    # source 3D space\n    # NOTE that we should use sampled source-view depth_here to project back\n    xyz_src = np.matmul(np.linalg.inv(intrinsics_src),\n                        np.vstack((xy_src, np.ones_like(x_ref))) * sampled_depth_src.reshape([-1]))\n    # reference 3D space\n    xyz_reprojected = np.matmul(np.matmul(extrinsics_ref, np.linalg.inv(extrinsics_src)),\n                                np.vstack((xyz_src, np.ones_like(x_ref))))[:3]\n    # source view x, y, depth\n    depth_reprojected = xyz_reprojected[2].reshape([height, width]).astype(np.float32)\n    K_xyz_reprojected = np.matmul(intrinsics_ref, xyz_reprojected)\n    xy_reprojected = K_xyz_reprojected[:2] / K_xyz_reprojected[2:3]\n    x_reprojected = xy_reprojected[0].reshape([height, width]).astype(np.float32)\n    y_reprojected = xy_reprojected[1].reshape([height, width]).astype(np.float32)\n\n    return depth_reprojected, x_reprojected, y_reprojected, x_src, y_src\n\n\ndef check_geometric_consistency(depth_ref, intrinsics_ref, extrinsics_ref, depth_src, intrinsics_src, extrinsics_src):\n    width, height = depth_ref.shape[1], depth_ref.shape[0]\n    x_ref, y_ref = np.meshgrid(np.arange(0, width), np.arange(0, height))\n    depth_reprojected, x2d_reprojected, y2d_reprojected, x2d_src, y2d_src = reproject_with_depth(depth_ref, intrinsics_ref, extrinsics_ref,\n                                                     depth_src, intrinsics_src, extrinsics_src)\n    # check |p_reproj-p_1| < 1\n    dist = np.sqrt((x2d_reprojected - x_ref) ** 2 + (y2d_reprojected - y_ref) ** 2)\n\n    # check |d_reproj-d_1| / d_1 < 0.01\n    depth_diff = np.abs(depth_reprojected - depth_ref)\n    relative_depth_diff = depth_diff / depth_ref\n\n    mask = np.logical_and(dist < 1, relative_depth_diff < 0.01)\n    depth_reprojected[~mask] = 0\n\n    return mask, depth_reprojected, x2d_src, y2d_src\n\n\ndef filter_depth(pair_folder, scan_folder, out_folder, plyfilename):\n    # the pair file\n    pair_file = os.path.join(pair_folder, \"pair.txt\")\n    # for the final point cloud\n    vertexs = []\n    vertex_colors = []\n\n    pair_data = read_pair_file(pair_file)\n\n    # for each reference view and the corresponding source views\n    for ref_view, src_views in pair_data:\n        # src_views = src_views[:args.num_view]\n        # load the camera parameters\n        ref_intrinsics, ref_extrinsics = read_camera_parameters(\n            os.path.join(scan_folder, 'cams/{:0>8}_cam.txt'.format(ref_view)))\n        # load the reference image\n        ref_img = read_img(os.path.join(scan_folder, 'images/{:0>8}.jpg'.format(ref_view)))\n        # load the estimated depth of the reference view\n        ref_depth_est = read_pfm(os.path.join(out_folder, 'depth_est/{:0>8}.pfm'.format(ref_view)))[0]\n        # load the photometric mask of the reference view\n        confidence = read_pfm(os.path.join(out_folder, 'confidence/{:0>8}.pfm'.format(ref_view)))[0]\n        photo_mask = confidence > args.conf\n\n        all_srcview_depth_ests = []\n        all_srcview_x = []\n        all_srcview_y = []\n        all_srcview_geomask = []\n\n        # compute the geometric mask\n        
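# geo_mask_sum counts, per pixel, how many source views pass the reprojection checks below (< 1 px reprojection error and < 1% relative depth difference)\n        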
geo_mask_sum = 0\n        for src_view in src_views:\n            # camera parameters of the source view\n            src_intrinsics, src_extrinsics = read_camera_parameters(\n                os.path.join(scan_folder, 'cams/{:0>8}_cam.txt'.format(src_view)))\n            # the estimated depth of the source view\n            src_depth_est = read_pfm(os.path.join(out_folder, 'depth_est/{:0>8}.pfm'.format(src_view)))[0]\n\n            geo_mask, depth_reprojected, x2d_src, y2d_src = check_geometric_consistency(ref_depth_est, ref_intrinsics, ref_extrinsics,\n                                                                      src_depth_est,\n                                                                      src_intrinsics, src_extrinsics)\n            geo_mask_sum += geo_mask.astype(np.int32)\n            all_srcview_depth_ests.append(depth_reprojected)\n            all_srcview_x.append(x2d_src)\n            all_srcview_y.append(y2d_src)\n            all_srcview_geomask.append(geo_mask)\n\n        depth_est_averaged = (sum(all_srcview_depth_ests) + ref_depth_est) / (geo_mask_sum + 1)\n        # keep pixels that are consistent in at least args.thres_view source views\n        geo_mask = geo_mask_sum >= args.thres_view\n        final_mask = np.logical_and(photo_mask, geo_mask)\n\n        os.makedirs(os.path.join(out_folder, \"mask\"), exist_ok=True)\n        save_mask(os.path.join(out_folder, \"mask/{:0>8}_photo.png\".format(ref_view)), photo_mask)\n        save_mask(os.path.join(out_folder, \"mask/{:0>8}_geo.png\".format(ref_view)), geo_mask)\n        save_mask(os.path.join(out_folder, \"mask/{:0>8}_final.png\".format(ref_view)), final_mask)\n\n        print(\"processing {}, ref-view{:0>2}, photo/geo/final-mask:{}/{}/{}\".format(scan_folder, ref_view,\n                                                                                    photo_mask.mean(),\n                                                                                    geo_mask.mean(), final_mask.mean()))\n\n        height, width = depth_est_averaged.shape[:2]\n        x, y = np.meshgrid(np.arange(0, width), np.arange(0, height))\n        # valid_points = np.logical_and(final_mask, ~used_mask[ref_view])\n        valid_points = final_mask\n        print(\"valid_points\", valid_points.mean())\n        x, y, depth = x[valid_points], y[valid_points], depth_est_averaged[valid_points]\n        #color = ref_img[1:-16:4, 1::4, :][valid_points]  # hardcoded for DTU dataset\n        color = ref_img[valid_points]\n\n        xyz_ref = np.matmul(np.linalg.inv(ref_intrinsics),\n                            np.vstack((x, y, np.ones_like(x))) * depth)\n        xyz_world = np.matmul(np.linalg.inv(ref_extrinsics),\n                              np.vstack((xyz_ref, np.ones_like(x))))[:3]\n        vertexs.append(xyz_world.transpose((1, 0)))\n        vertex_colors.append((color * 255).astype(np.uint8))\n\n\n    vertexs = np.concatenate(vertexs, axis=0)\n    vertex_colors = np.concatenate(vertex_colors, axis=0)\n    vertexs = np.array([tuple(v) for v in vertexs], dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')])\n    vertex_colors = np.array([tuple(v) for v in vertex_colors], dtype=[('red', 'u1'), ('green', 'u1'), ('blue', 'u1')])\n\n    vertex_all = np.empty(len(vertexs), vertexs.dtype.descr + vertex_colors.dtype.descr)\n    for prop in vertexs.dtype.names:\n        vertex_all[prop] = vertexs[prop]\n    for prop in vertex_colors.dtype.names:\n        vertex_all[prop] = vertex_colors[prop]\n\n    el = PlyElement.describe(vertex_all, 'vertex')\n    PlyData([el]).write(plyfilename)\n    
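# one fused point cloud is written per scan; the DTU MATLAB evaluation launched by mrun_rst reads these .ply files\n    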
print(\"saving the final model to\", plyfilename)\n\n\ndef init_worker():\n    '''\n    Catch Ctrl+C signal to termiante workers\n    '''\n    signal.signal(signal.SIGINT, signal.SIG_IGN)\n\n\ndef pcd_filter_worker(scan):\n    if args.testlist != \"all\":\n        scan_id = int(scan[4:])\n        save_name = 'mvsnet{:0>3}_l3.ply'.format(scan_id)\n    else:\n        save_name = '{}.ply'.format(scan)\n    pair_folder = os.path.join(args.testpath, scan)\n    scan_folder = os.path.join(args.outdir, scan)\n    out_folder = os.path.join(args.outdir, scan)\n    filter_depth(pair_folder, scan_folder, out_folder, os.path.join(args.outdir, save_name))\n\n\ndef pcd_filter(testlist, number_worker):\n\n    partial_func = partial(pcd_filter_worker)\n\n    p = Pool(number_worker, init_worker)\n    try:\n        p.map(partial_func, testlist)\n    except KeyboardInterrupt:\n        print(\"....\\nCaught KeyboardInterrupt, terminating workers\")\n        p.terminate()\n    else:\n        p.close()\n    p.join()\n\ndef mrun_rst(eval_dir, plyPath):\n    print('Runing BaseEvalMain_func.m...')\n    os.chdir(eval_dir)\n    os.system('/mnt/cfs/algorithm/xiaofeng.wang/jeff/code/MVS/misc/matlab/bin/matlab -nodesktop -nosplash -r \"BaseEvalMain_func(\\'{}\\'); quit\" '.format(plyPath))\n    print('Runing ComputeStat_func.m...')\n    os.system('/mnt/cfs/algorithm/xiaofeng.wang/jeff/code/MVS/misc/matlab/bin/matlab -nodesktop -nosplash -r \"ComputeStat_func(\\'{}\\'); quit\" '.format(plyPath))\n    print('Check your results! ^-^')\n\nif __name__ == '__main__':\n\n    if args.vis_ETA:\n        os.makedirs('./debug_figs/vis_ETA', exist_ok=True)\n\n    if args.testlist != \"all\":\n        with open(args.testlist) as f:\n            content = f.readlines()\n            testlist = [line.rstrip() for line in content]\n\n    # step1. save all the depth maps and the masks in outputs directory\n    save_depth(testlist)\n\n    if args.dataset.startswith('general'):\n        # step2. filter saved depth maps with photometric confidence maps and geometric constraints\n        pcd_filter(testlist, args.num_worker)\n\n        # Make sure the matlab is installed and you can comment out the following lines\n        # And you also need to change the path of the matlab script\n\n        mrun_rst(\n            eval_dir='./evaluations/dtu/',\n            plyPath='./'+args.outdir[1:]\n        )\n\n\n "
  },
  {
    "path": "train_mvs4.py",
    "content": "import argparse, os, sys, time, gc, datetime\nimport torch\nimport torch.nn as nn\nimport torch.nn.parallel\nimport torch.backends.cudnn as cudnn\nimport torch.optim as optim\nfrom torch.utils.data import DataLoader\nfrom tensorboardX import SummaryWriter\nfrom datasets import find_dataset_def\nfrom models import *\nfrom utils import *\nimport torch.distributed as dist\n\ncudnn.benchmark = True\n\nparser = argparse.ArgumentParser(description='A PyTorch Implementation of MVSTER')\nparser.add_argument('--mode', default='train', help='train or test', choices=['train', 'test', 'profile'])\nparser.add_argument('--device', default='cuda', help='select model')\n\nparser.add_argument('--dataset', default='dtu_yao4', help='select dataset')\nparser.add_argument('--trainpath', help='train datapath')\nparser.add_argument('--testpath', help='test datapath')\nparser.add_argument('--trainlist', help='train list')\nparser.add_argument('--testlist', help='test list')\n\nparser.add_argument('--epochs', type=int, default=10, help='number of epochs to train')\nparser.add_argument('--lr', type=float, default=0.001, help='learning rate')\nparser.add_argument('--lrepochs', type=str, default=\"6,8,9:2\", help='epoch ids to downscale lr and the downscale rate')\nparser.add_argument('--wd', type=float, default=0.0, help='weight decay')\n\nparser.add_argument('--batch_size', type=int, default=1, help='train batch size')\nparser.add_argument('--interval_scale', type=float, default=1.06, help='the number of depth values')\n\nparser.add_argument('--loadckpt', default=None, help='load a specific checkpoint')\nparser.add_argument('--logdir', default='./checkpoints/debug', help='the directory to save checkpoints/logs')\nparser.add_argument('--resume', action='store_true', help='continue to train the model')\n\nparser.add_argument('--summary_freq', type=int, default=2, help='print and summary frequency')\nparser.add_argument('--save_freq', type=int, default=1, help='save checkpoint frequency')\nparser.add_argument('--eval_freq', type=int, default=1, help='eval freq')\n\nparser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed')\nparser.add_argument('--pin_m', action='store_true', help='data loader pin memory')\nparser.add_argument(\"--local_rank\", type=int, default=0)\n\nparser.add_argument('--ndepths', type=str, default=\"8,8,4,4\", help='ndepths')\nparser.add_argument('--depth_inter_r', type=str, default=\"0.5,0.5,0.5,1\", help='depth_intervals_ratio')\nparser.add_argument('--dlossw', type=str, default=\"1,1,1,1\", help='depth loss weight for different stage')\n\nparser.add_argument('--l1ce_lw', type=str, default=\"0,1\", help='loss weight for l1 and ce loss')\nparser.add_argument(\"--fpn_base_channel\", type=int, default=8)\nparser.add_argument(\"--reg_channel\", type=int, default=8)\nparser.add_argument('--reg_mode', type=str, default=\"reg2d\")\n\nparser.add_argument('--group_cor', action='store_true',help='group correlation')\nparser.add_argument('--group_cor_dim', type=str, default=\"8,8,4,4\", help='group correlation dim')\n\nparser.add_argument('--inverse_depth', action='store_true',help='inverse depth')\nparser.add_argument('--agg_type', type=str, default=\"ConvBnReLU3D\", help='cost regularization type')\nparser.add_argument('--dcn', action='store_true',help='dcn')\nparser.add_argument('--pos_enc', type=int, default=0, help='pos_enc: 0 no pos enc; 1 depth sine; 2 learnable pos enc')\nparser.add_argument('--arch_mode', type=str, 
default=\"fpn\")\n\nparser.add_argument('--ot_continous', action='store_true',help='optimal transport continous gt bin')\nparser.add_argument('--ot_iter', type=int, default=10)\nparser.add_argument('--ot_eps', type=float, default=1)\n\nparser.add_argument('--rt', action='store_true',help='robust training')\n\nparser.add_argument('--max_h', type=int, default=864, help='testing max h')\nparser.add_argument('--max_w', type=int, default=1152, help='testing max w')\nparser.add_argument('--use_raw_train', action='store_true',help='using 1200x1600 training')\nparser.add_argument('--mono', action='store_true',help='query to build mono depth prediction and loss')\nparser.add_argument('--lr_scheduler', type=str, default='MS')\nparser.add_argument('--ASFF', action='store_true')\nparser.add_argument('--attn_temp', type=float, default=2)\n\n\nnum_gpus = int(os.environ[\"WORLD_SIZE\"]) if \"WORLD_SIZE\" in os.environ else 1\nis_distributed = num_gpus > 1\n\n# main function\ndef train(model, model_loss, optimizer, TrainImgLoader, TestImgLoader, start_epoch, args):\n    milestones = [len(TrainImgLoader) * int(epoch_idx) for epoch_idx in args.lrepochs.split(':')[0].split(',')]\n    lr_gamma = 1 / float(args.lrepochs.split(':')[1])\n    if args.lr_scheduler == 'MS':\n        lr_scheduler = WarmupMultiStepLR(optimizer, milestones, gamma=lr_gamma, warmup_factor=1.0/3, warmup_iters=500,\n                                                            last_epoch=len(TrainImgLoader) * start_epoch - 1)\n    elif args.lr_scheduler == 'cos':\n        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=int(args.epochs*len(TrainImgLoader)), eta_min=0)\n    elif args.lr_scheduler == 'onecycle':\n        lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=args.lr,total_steps=int(args.epochs*len(TrainImgLoader)))\n\n    for epoch_idx in range(start_epoch, args.epochs):\n        print('Epoch {}:'.format(epoch_idx))\n        global_step = len(TrainImgLoader) * epoch_idx\n\n        # training\n        for batch_idx, sample in enumerate(TrainImgLoader):\n            start_time = time.time()\n            global_step = len(TrainImgLoader) * epoch_idx + batch_idx\n            do_summary = global_step % args.summary_freq == 0\n            loss, scalar_outputs, image_outputs = train_sample(model, model_loss, optimizer, sample, args)\n            lr_scheduler.step()\n            if (not is_distributed) or (dist.get_rank() == 0):\n                if do_summary:\n                    save_scalars(logger, 'train', scalar_outputs, global_step)\n                    save_images(logger, 'train', image_outputs, global_step)\n                    print(\n                       \"Epoch {}/{}, Iter {}/{}, lr {:.6f}, train loss = {:.3f}, d_loss = {:.3f}, {:.3f}, {:.3f}, {:.3f}, c_loss = {:.3f}, {:.3f}, {:.3f}, {:.3f}, range_err = {:.3f}, {:.3f}, {:.3f}, {:.3f}, time = {:.3f}\".format(\n                           epoch_idx, args.epochs, batch_idx, len(TrainImgLoader),\n                           optimizer.param_groups[0][\"lr\"], \n                           loss,\n                           scalar_outputs[\"s0_d_loss\"],\n                           scalar_outputs[\"s1_d_loss\"],\n                           scalar_outputs[\"s2_d_loss\"],\n                           scalar_outputs[\"s3_d_loss\"],\n                           scalar_outputs[\"s0_c_loss\"],\n                           scalar_outputs[\"s1_c_loss\"],\n                           scalar_outputs[\"s2_c_loss\"],\n                           
scalar_outputs[\"s3_c_loss\"],\n                           scalar_outputs[\"s0_range_err_ratio\"],\n                           scalar_outputs[\"s1_range_err_ratio\"],\n                           scalar_outputs[\"s2_range_err_ratio\"],\n                           scalar_outputs[\"s3_range_err_ratio\"],\n                           time.time() - start_time))\n                del scalar_outputs, image_outputs\n\n        # checkpoint\n        if (not is_distributed) or (dist.get_rank() == 0):\n            if (epoch_idx + 1) % args.save_freq == 0:\n                if epoch_idx == args.epochs - 1:\n                    torch.save({\n                        'epoch': epoch_idx,\n                        'model': model.module.state_dict(),\n                        'optimizer': optimizer.state_dict()},\n                        \"{}/finalmodel.ckpt\".format(args.logdir))  \n        gc.collect()\n\n        # testing\n        if (epoch_idx % args.eval_freq == 0) or (epoch_idx == args.epochs - 1):\n            avg_test_scalars = DictAverageMeter()\n            for batch_idx, sample in enumerate(TestImgLoader):\n                start_time = time.time()\n                global_step = len(TrainImgLoader) * epoch_idx + batch_idx\n                do_summary = global_step % args.summary_freq == 0\n                loss, scalar_outputs, image_outputs = test_sample_depth(model, model_loss, sample, args)\n                if (not is_distributed) or (dist.get_rank() == 0):\n                    if do_summary:\n                        save_scalars(logger, 'test', scalar_outputs, global_step)\n                        save_images(logger, 'test', image_outputs, global_step)\n                        print(\n                            \"Epoch {}/{}, Iter {}/{}, lr {:.6f}, test loss = {:.3f}, d_loss = {:.3f}, {:.3f}, {:.3f}, {:.3f}, c_loss = {:.3f}, {:.3f}, {:.3f}, {:.3f}, range_err = {:.3f}, {:.3f}, {:.3f}, {:.3f}, time = {:.3f}\".format(\n                            epoch_idx, args.epochs, batch_idx, len(TrainImgLoader),\n                               optimizer.param_groups[0][\"lr\"], \n                           loss,\n                           scalar_outputs[\"s0_d_loss\"],\n                           scalar_outputs[\"s1_d_loss\"],\n                           scalar_outputs[\"s2_d_loss\"],\n                           scalar_outputs[\"s3_d_loss\"],\n                           scalar_outputs[\"s0_c_loss\"],\n                           scalar_outputs[\"s1_c_loss\"],\n                           scalar_outputs[\"s2_c_loss\"],\n                           scalar_outputs[\"s3_c_loss\"],\n                           scalar_outputs[\"s0_range_err_ratio\"],\n                           scalar_outputs[\"s1_range_err_ratio\"],\n                           scalar_outputs[\"s2_range_err_ratio\"],\n                           scalar_outputs[\"s3_range_err_ratio\"],\n                            time.time() - start_time))\n                    avg_test_scalars.update(scalar_outputs)\n                    del scalar_outputs, image_outputs\n\n            if (not is_distributed) or (dist.get_rank() == 0):\n                save_scalars(logger, 'fulltest', avg_test_scalars.mean(), global_step)\n                print(\"avg_test_scalars:\", avg_test_scalars.mean())\n            gc.collect()\n\n\ndef test(model, model_loss, TestImgLoader, args):\n    avg_test_scalars = DictAverageMeter()\n    for batch_idx, sample in enumerate(TestImgLoader):\n        start_time = time.time()\n        loss, scalar_outputs, image_outputs = test_sample_depth(model, 
model_loss, sample, args)\n        avg_test_scalars.update(scalar_outputs)\n        del scalar_outputs, image_outputs\n        if (not is_distributed) or (dist.get_rank() == 0):\n            print('Iter {}/{}, test loss = {:.3f}, time = {:3f}'.format(batch_idx, len(TestImgLoader), loss,\n                                                                        time.time() - start_time))\n            if batch_idx % 100 == 0:\n                print(\"Iter {}/{}, test results = {}\".format(batch_idx, len(TestImgLoader), avg_test_scalars.mean()))\n    if (not is_distributed) or (dist.get_rank() == 0):\n        print(\"final\", avg_test_scalars.mean())\n\n\ndef train_sample(model, model_loss, optimizer, sample, args):\n    model.train()\n    optimizer.zero_grad()\n\n    sample_cuda = tocuda(sample)\n    depth_gt_ms = sample_cuda[\"depth\"]\n    mask_ms = sample_cuda[\"mask\"]\n\n    num_stage = len([int(nd) for nd in args.ndepths.split(\",\") if nd])\n    depth_gt = depth_gt_ms[\"stage{}\".format(num_stage)]\n    mask = mask_ms[\"stage{}\".format(num_stage)]\n\n    outputs = model(sample_cuda[\"imgs\"], sample_cuda[\"proj_matrices\"], sample_cuda[\"depth_values\"])\n    depth_est = outputs[\"depth\"]\n\n    loss, stage_d_loss, stage_c_loss, range_err_ratio = model_loss(\n                                        outputs, depth_gt_ms, mask_ms, stage_lw=[float(e) for e in args.dlossw.split(\",\") if e], \n                                        l1ce_lw=[float(lw) for lw in args.l1ce_lw.split(\",\")],\n                                        inverse_depth=args.inverse_depth,\n                                        ot_iter=args.ot_iter, ot_continous=args.ot_continous, ot_eps=args.ot_eps,\n                                        mono=args.mono\n                                        )\n    loss.backward()\n    optimizer.step()\n\n    scalar_outputs = {\"loss\": loss,\n                      \"s0_d_loss\": stage_d_loss[0],\n                      \"s1_d_loss\": stage_d_loss[1],\n                      \"s2_d_loss\": stage_d_loss[2],\n                      \"s3_d_loss\": stage_d_loss[3],\n                      \"s0_c_loss\": stage_c_loss[0],\n                      \"s1_c_loss\": stage_c_loss[1],\n                      \"s2_c_loss\": stage_c_loss[2],\n                      \"s3_c_loss\": stage_c_loss[3],\n                      \"s0_range_err_ratio\":range_err_ratio[0],\n                      \"s1_range_err_ratio\":range_err_ratio[1],\n                      \"s2_range_err_ratio\":range_err_ratio[2],\n                      \"s3_range_err_ratio\":range_err_ratio[3],\n                      \"abs_depth_error\": AbsDepthError_metrics(depth_est, depth_gt, mask > 0.5),\n                      \"thres2mm_error\": Thres_metrics(depth_est, depth_gt, mask > 0.5, 2),\n                      \"thres4mm_error\": Thres_metrics(depth_est, depth_gt, mask > 0.5, 4),\n                      \"thres8mm_error\": Thres_metrics(depth_est, depth_gt, mask > 0.5, 8),}\n\n    image_outputs = {\"depth_est\": depth_est * mask,\n                     \"depth_est_nomask\": depth_est,\n                     \"depth_gt\": sample[\"depth\"][\"stage1\"],\n                     \"ref_img\": sample[\"imgs\"][0],\n                     \"mask\": sample[\"mask\"][\"stage1\"],\n                     \"errormap\": (depth_est - depth_gt).abs() * mask,\n                     }\n\n    if is_distributed:\n        scalar_outputs = reduce_scalar_outputs(scalar_outputs)\n\n    return tensor2float(scalar_outputs[\"loss\"]), tensor2float(scalar_outputs), 
tensor2numpy(image_outputs)\n\n\n@make_nograd_func\ndef test_sample_depth(model, model_loss, sample, args):\n    if is_distributed:\n        model_eval = model.module\n    else:\n        model_eval = model\n    model_eval.eval()\n\n    sample_cuda = tocuda(sample)\n    depth_gt_ms = sample_cuda[\"depth\"]\n    mask_ms = sample_cuda[\"mask\"]\n\n    num_stage = len([int(nd) for nd in args.ndepths.split(\",\") if nd])\n    depth_gt = depth_gt_ms[\"stage{}\".format(num_stage)]\n    mask = mask_ms[\"stage{}\".format(num_stage)]\n\n    outputs = model_eval(sample_cuda[\"imgs\"], sample_cuda[\"proj_matrices\"], sample_cuda[\"depth_values\"])\n    depth_est = outputs[\"depth\"]\n\n    loss, stage_d_loss, stage_c_loss, range_err_ratio = model_loss(\n                                        outputs, depth_gt_ms, mask_ms, stage_lw=[float(e) for e in args.dlossw.split(\",\") if e], \n                                        l1ce_lw=[float(lw) for lw in args.l1ce_lw.split(\",\")],\n                                        inverse_depth=args.inverse_depth,\n                                        ot_iter=args.ot_iter, ot_continous=args.ot_continous, ot_eps=args.ot_eps,\n                                        mono=False\n                                        )\n    scalar_outputs = {\"loss\": loss,\n                      \"s0_d_loss\": stage_d_loss[0],\n                      \"s1_d_loss\": stage_d_loss[1],\n                      \"s2_d_loss\": stage_d_loss[2],\n                      \"s3_d_loss\": stage_d_loss[3],\n                      \"s0_c_loss\": stage_c_loss[0],\n                      \"s1_c_loss\": stage_c_loss[1],\n                      \"s2_c_loss\": stage_c_loss[2],\n                      \"s3_c_loss\": stage_c_loss[3],\n                      \"s0_range_err_ratio\":range_err_ratio[0],\n                      \"s1_range_err_ratio\":range_err_ratio[1],\n                      \"s2_range_err_ratio\":range_err_ratio[2],\n                      \"s3_range_err_ratio\":range_err_ratio[3],\n                      \"abs_depth_error\": AbsDepthError_metrics(depth_est, depth_gt, mask > 0.5),\n                      \"thres2mm_error\": Thres_metrics(depth_est, depth_gt, mask > 0.5, 2),\n                      \"thres4mm_error\": Thres_metrics(depth_est, depth_gt, mask > 0.5, 4),\n                      \"thres8mm_error\": Thres_metrics(depth_est, depth_gt, mask > 0.5, 8),\n                    }\n\n    image_outputs = {\"depth_est\": depth_est * mask,\n                     \"depth_est_nomask\": depth_est,\n                     \"depth_gt\": sample[\"depth\"][\"stage1\"],\n                     \"ref_img\": sample[\"imgs\"][0],\n                     \"mask\": sample[\"mask\"][\"stage1\"],\n                     \"errormap\": (depth_est - depth_gt).abs() * mask}\n\n    if is_distributed:\n        scalar_outputs = reduce_scalar_outputs(scalar_outputs)\n\n    return tensor2float(scalar_outputs[\"loss\"]), tensor2float(scalar_outputs), tensor2numpy(image_outputs)\n\n\nif __name__ == '__main__':\n    # parse arguments and check\n    args = parser.parse_args()\n\n    if args.resume:\n        assert args.mode == \"train\"\n        assert args.loadckpt is None\n\n    if args.testpath is None:\n        args.testpath = args.trainpath\n\n    if is_distributed:\n        torch.cuda.set_device(args.local_rank)\n        torch.distributed.init_process_group(\n            backend=\"nccl\", init_method=\"env://\"\n        )\n        synchronize()\n\n    set_random_seed(args.seed)\n    device = torch.device(args.device)\n\n    if (not 
is_distributed) or (dist.get_rank() == 0):\n        # create the tensorboard logger (only in \"train\" mode)\n        if args.mode == \"train\":\n            if not os.path.isdir(args.logdir):\n                os.makedirs(args.logdir)\n            current_time_str = str(datetime.datetime.now().strftime('%Y%m%d_%H%M%S'))\n            print(\"current time\", current_time_str)\n            print(\"creating new summary file\")\n            logger = SummaryWriter(args.logdir)\n        print(\"argv:\", sys.argv[1:])\n        print_args(args)\n\n    # model, optimizer\n    model = MVS4net(arch_mode=args.arch_mode, reg_net=args.reg_mode, num_stage=4, \n                    fpn_base_channel=args.fpn_base_channel, reg_channel=args.reg_channel, \n                    stage_splits=[int(n) for n in args.ndepths.split(\",\")], \n                    depth_interals_ratio=[float(ir) for ir in args.depth_inter_r.split(\",\")],\n                    group_cor=args.group_cor, group_cor_dim=[int(n) for n in args.group_cor_dim.split(\",\")],\n                    inverse_depth=args.inverse_depth,\n                    agg_type=args.agg_type,\n                    dcn=args.dcn,\n                    pos_enc=args.pos_enc,\n                    mono=args.mono,\n                    asff=args.ASFF,\n                    attn_temp=args.attn_temp,\n                )\n\n    model.to(device)\n    model_loss = MVS4net_loss\n\n    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, betas=(0.9, 0.999), weight_decay=args.wd)\n\n    # load parameters\n    start_epoch = 0\n    if args.resume:\n        # only consider model_XXXXXX.ckpt files, so that finalmodel.ckpt cannot break the epoch sort below\n        saved_models = [fn for fn in os.listdir(args.logdir) if fn.startswith(\"model_\") and fn.endswith(\".ckpt\")]\n        saved_models = sorted(saved_models, key=lambda x: int(x.split('_')[-1].split('.')[0]))\n        # use the latest checkpoint file\n        loadckpt = os.path.join(args.logdir, saved_models[-1])\n        print(\"resuming\", loadckpt)\n        state_dict = torch.load(loadckpt, map_location=torch.device(\"cpu\"))\n        model.load_state_dict(state_dict['model'])\n        optimizer.load_state_dict(state_dict['optimizer'])\n        start_epoch = state_dict['epoch'] + 1\n    elif args.loadckpt:\n        # load checkpoint file specified by args.loadckpt\n        print(\"loading model {}\".format(args.loadckpt))\n        state_dict = torch.load(args.loadckpt, map_location=torch.device(\"cpu\"))\n        model.load_state_dict(state_dict['model'])\n\n    if (not is_distributed) or (dist.get_rank() == 0):\n        print(\"start at epoch {}\".format(start_epoch))\n        print('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()])))\n\n\n    if is_distributed:\n        if dist.get_rank() == 0:\n            print(\"Let's use\", torch.cuda.device_count(), \"GPUs!\")\n        model = torch.nn.parallel.DistributedDataParallel(\n            model, device_ids=[args.local_rank], output_device=args.local_rank,\n            # find_unused_parameters=True,\n        )\n    else:\n        if torch.cuda.is_available():\n            print(\"Let's use\", torch.cuda.device_count(), \"GPUs!\")\n            model = nn.DataParallel(model)\n\n    # dataset, dataloader\n    MVSDataset = find_dataset_def(args.dataset)\n    if args.dataset.startswith('dtu'):\n        train_dataset = MVSDataset(args.trainpath, args.trainlist, \"train\", 5, args.interval_scale, rt=args.rt, use_raw_train=args.use_raw_train)\n        test_dataset = MVSDataset(args.testpath, args.testlist, \"val\", 5, args.interval_scale)\n    elif 
args.dataset.startswith('blendedmvs'):\n        train_dataset = MVSDataset(args.trainpath, args.trainlist, \"train\", 7, robust_train=args.rt)\n        test_dataset = MVSDataset(args.testpath, args.testlist, \"val\", 7)\n    if is_distributed:\n        train_sampler = torch.utils.data.DistributedSampler(train_dataset, num_replicas=dist.get_world_size(),\n                                                            rank=dist.get_rank())\n        test_sampler = torch.utils.data.DistributedSampler(test_dataset, num_replicas=dist.get_world_size(),\n                                                           rank=dist.get_rank())\n\n        TrainImgLoader = DataLoader(train_dataset, args.batch_size, sampler=train_sampler, num_workers=1,\n                                    drop_last=True,\n                                    pin_memory=args.pin_m)\n        TestImgLoader = DataLoader(test_dataset, args.batch_size, sampler=test_sampler, num_workers=1, drop_last=False,\n                                   pin_memory=args.pin_m)\n    else:\n        TrainImgLoader = DataLoader(train_dataset, args.batch_size, shuffle=True, num_workers=1, drop_last=True,\n                                    pin_memory=args.pin_m)\n        TestImgLoader = DataLoader(test_dataset, args.batch_size, shuffle=False, num_workers=1, drop_last=False,\n                                   pin_memory=args.pin_m)\n\n\n    if args.mode == \"train\":\n        train(model, model_loss, optimizer, TrainImgLoader, TestImgLoader, start_epoch, args)\n    elif args.mode == \"test\":\n        test(model, model_loss, TestImgLoader, args)\n    else:\n        raise NotImplementedError"
  },
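  {
    "path": "examples/lrepochs_parsing_sketch.py",
    "content": "# Hypothetical sketch, not part of the original MVSTER release: it shows how\n# train_mvs4.py turns the --lrepochs string (default '6,8,9:2') into the arguments of\n# WarmupMultiStepLR in utils.py. Epoch ids before the colon become per-iteration\n# milestones (scaled by the number of iterations per epoch), and the number after the\n# colon is the downscale rate, so gamma = 1 / rate.\n\nlrepochs = '6,8,9:2'\niters_per_epoch = 1000  # assumed value standing in for len(TrainImgLoader)\n\nmilestones = [iters_per_epoch * int(e) for e in lrepochs.split(':')[0].split(',')]\ngamma = 1 / float(lrepochs.split(':')[1])\n\nprint(milestones)  # -> [6000, 8000, 9000]\nprint(gamma)       # -> 0.5, i.e. the lr is halved at epochs 6, 8 and 9\n"
  },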
  {
    "path": "utils.py",
    "content": "import numpy as np\nimport torchvision.utils as vutils\nimport torch, random\nimport torch.nn.functional as F\n\n\n# print arguments\ndef print_args(args):\n    print(\"################################  args  ################################\")\n    for k, v in args.__dict__.items():\n        print(\"{0: <10}\\t{1: <30}\\t{2: <20}\".format(k, str(v), str(type(v))))\n    print(\"########################################################################\")\n\n\n# torch.no_grad warpper for functions\ndef make_nograd_func(func):\n    def wrapper(*f_args, **f_kwargs):\n        with torch.no_grad():\n            ret = func(*f_args, **f_kwargs)\n        return ret\n\n    return wrapper\n\n\n# convert a function into recursive style to handle nested dict/list/tuple variables\ndef make_recursive_func(func):\n    def wrapper(vars):\n        if isinstance(vars, list):\n            return [wrapper(x) for x in vars]\n        elif isinstance(vars, tuple):\n            return tuple([wrapper(x) for x in vars])\n        elif isinstance(vars, dict):\n            return {k: wrapper(v) for k, v in vars.items()}\n        else:\n            return func(vars)\n\n    return wrapper\n\n\n@make_recursive_func\ndef tensor2float(vars):\n    if isinstance(vars, float):\n        return vars\n    elif isinstance(vars, torch.Tensor):\n        return vars.data.item()\n    else:\n        raise NotImplementedError(\"invalid input type {} for tensor2float\".format(type(vars)))\n\n\n@make_recursive_func\ndef tensor2numpy(vars):\n    if isinstance(vars, np.ndarray):\n        return vars\n    elif isinstance(vars, torch.Tensor):\n        return vars.detach().cpu().numpy().copy()\n    else:\n        raise NotImplementedError(\"invalid input type {} for tensor2numpy\".format(type(vars)))\n\n\n@make_recursive_func\ndef tocuda(vars):\n    if isinstance(vars, torch.Tensor):\n        return vars.to(torch.device(\"cuda\"))\n    elif isinstance(vars, str):\n        return vars\n    else:\n        raise NotImplementedError(\"invalid input type {} for tensor2numpy\".format(type(vars)))\n\n\ndef save_scalars(logger, mode, scalar_dict, global_step):\n    scalar_dict = tensor2float(scalar_dict)\n    for key, value in scalar_dict.items():\n        if not isinstance(value, (list, tuple)):\n            name = '{}/{}'.format(mode, key)\n            logger.add_scalar(name, value, global_step)\n        else:\n            for idx in range(len(value)):\n                name = '{}/{}_{}'.format(mode, key, idx)\n                logger.add_scalar(name, value[idx], global_step)\n\n\ndef save_images(logger, mode, images_dict, global_step):\n    images_dict = tensor2numpy(images_dict)\n\n    def preprocess(name, img):\n        if not (len(img.shape) == 3 or len(img.shape) == 4):\n            raise NotImplementedError(\"invalid img shape {}:{} in save_images\".format(name, img.shape))\n        if len(img.shape) == 3:\n            img = img[:, np.newaxis, :, :]\n        img = torch.from_numpy(img[:1])\n        return vutils.make_grid(img, padding=0, nrow=1, normalize=True, scale_each=True)\n\n    for key, value in images_dict.items():\n        if not isinstance(value, (list, tuple)):\n            name = '{}/{}'.format(mode, key)\n            logger.add_image(name, preprocess(name, value), global_step)\n        else:\n            for idx in range(len(value)):\n                name = '{}/{}_{}'.format(mode, key, idx)\n                logger.add_image(name, preprocess(name, value[idx]), global_step)\n\n\nclass DictAverageMeter(object):\n    def 
__init__(self):\n        self.data = {}\n        self.count = 0\n\n    def update(self, new_input):\n        self.count += 1\n        if len(self.data) == 0:\n            for k, v in new_input.items():\n                if not isinstance(v, float):\n                    raise NotImplementedError(\"invalid data {}: {}\".format(k, type(v)))\n                self.data[k] = v\n        else:\n            for k, v in new_input.items():\n                if not isinstance(v, float):\n                    raise NotImplementedError(\"invalid data {}: {}\".format(k, type(v)))\n                self.data[k] += v\n\n    def mean(self):\n        return {k: v / self.count for k, v in self.data.items()}\n\n\n# a wrapper to compute metrics for each image individually\ndef compute_metrics_for_each_image(metric_func):\n    def wrapper(depth_est, depth_gt, mask, *args):\n        batch_size = depth_gt.shape[0]\n        results = []\n        # compute result one by one\n        for idx in range(batch_size):\n            ret = metric_func(depth_est[idx], depth_gt[idx], mask[idx], *args)\n            results.append(ret)\n        return torch.stack(results).mean()\n\n    return wrapper\n\n\n@make_nograd_func\n@compute_metrics_for_each_image\ndef Thres_metrics(depth_est, depth_gt, mask, thres):\n    assert isinstance(thres, (int, float))\n    depth_est, depth_gt = depth_est[mask], depth_gt[mask]\n    errors = torch.abs(depth_est - depth_gt)\n    err_mask = errors > thres\n    return torch.mean(err_mask.float())\n\n\n# NOTE: please do not use this to build up training loss\n@make_nograd_func\n@compute_metrics_for_each_image\ndef AbsDepthError_metrics(depth_est, depth_gt, mask, thres=None):\n    depth_est, depth_gt = depth_est[mask], depth_gt[mask]\n    error = (depth_est - depth_gt).abs()\n    if thres is not None:\n        error = error[(error >= float(thres[0])) & (error <= float(thres[1]))]\n        if error.shape[0] == 0:\n            return torch.tensor(0, device=error.device, dtype=error.dtype)\n    return torch.mean(error)\n\nimport torch.distributed as dist\ndef synchronize():\n    \"\"\"\n    Helper function to synchronize (barrier) among all processes when\n    using distributed training\n    \"\"\"\n    if not dist.is_available():\n        return\n    if not dist.is_initialized():\n        return\n    world_size = dist.get_world_size()\n    if world_size == 1:\n        return\n    dist.barrier()\n\ndef get_world_size():\n    if not dist.is_available():\n        return 1\n    if not dist.is_initialized():\n        return 1\n    return dist.get_world_size()\n\ndef reduce_scalar_outputs(scalar_outputs):\n    world_size = get_world_size()\n    if world_size < 2:\n        return scalar_outputs\n    with torch.no_grad():\n        names = []\n        scalars = []\n        for k in sorted(scalar_outputs.keys()):\n            names.append(k)\n            scalars.append(scalar_outputs[k])\n        scalars = torch.stack(scalars, dim=0)\n        dist.reduce(scalars, dst=0)\n        if dist.get_rank() == 0:\n            # only main process gets accumulated, so only divide by\n            # world_size in this case\n            scalars /= world_size\n        reduced_scalars = {k: v for k, v in zip(names, scalars)}\n\n    return reduced_scalars\n\nimport torch\nfrom bisect import bisect_right\n# FIXME ideally this would be achieved with a CombinedLRScheduler,\n# separating MultiStepLR with WarmupLR\n# but the current LRScheduler design doesn't allow it\nclass WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):\n    def 
__init__(\n        self,\n        optimizer,\n        milestones,\n        gamma=0.1,\n        warmup_factor=1.0 / 3,\n        warmup_iters=500,\n        warmup_method=\"linear\",\n        last_epoch=-1,\n    ):\n        if not list(milestones) == sorted(milestones):\n            raise ValueError(\n                \"Milestones should be a list of increasing integers. Got {}\".format(milestones)\n            )\n\n        if warmup_method not in (\"constant\", \"linear\"):\n            raise ValueError(\n                \"Only 'constant' or 'linear' warmup_method accepted, \"\n                \"got {}\".format(warmup_method)\n            )\n        self.milestones = milestones\n        self.gamma = gamma\n        self.warmup_factor = warmup_factor\n        self.warmup_iters = warmup_iters\n        self.warmup_method = warmup_method\n        super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)\n\n    def get_lr(self):\n        warmup_factor = 1\n        if self.last_epoch < self.warmup_iters:\n            if self.warmup_method == \"constant\":\n                warmup_factor = self.warmup_factor\n            elif self.warmup_method == \"linear\":\n                # linearly ramp the factor from warmup_factor to 1 over warmup_iters steps\n                alpha = float(self.last_epoch) / self.warmup_iters\n                warmup_factor = self.warmup_factor * (1 - alpha) + alpha\n        return [\n            base_lr\n            * warmup_factor\n            * self.gamma ** bisect_right(self.milestones, self.last_epoch)\n            for base_lr in self.base_lrs\n        ]\n\n\ndef set_random_seed(seed):\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)\n    torch.cuda.manual_seed_all(seed)\n\n\ndef local_pcd(depth, intr):\n    nx = depth.shape[1]  # w\n    ny = depth.shape[0]  # h\n    x, y = np.meshgrid(np.arange(nx), np.arange(ny), indexing='xy')\n    x = x.reshape(nx * ny)\n    y = y.reshape(nx * ny)\n    p2d = np.array([x, y, np.ones_like(y)])\n    p3d = np.matmul(np.linalg.inv(intr), p2d)\n    depth = depth.reshape(1, nx * ny)\n    p3d *= depth\n    p3d = np.transpose(p3d, (1, 0))\n    p3d = p3d.reshape(ny, nx, 3).astype(np.float32)\n    return p3d\n\n\ndef generate_pointcloud(rgb, depth, ply_file, intr, scale=1.0):\n    \"\"\"\n    Generate a colored point cloud in PLY format from a color and a depth image.\n    Input:\n    rgb -- color image as an (H, W, 3) array\n    depth -- depth map as an (H, W) array\n    ply_file -- output filename of the ply file\n    \"\"\"\n    fx, fy, cx, cy = intr[0, 0], intr[1, 1], intr[0, 2], intr[1, 2]\n    points = []\n    for v in range(rgb.shape[0]):\n        for u in range(rgb.shape[1]):\n            color = rgb[v, u]\n            Z = depth[v, u] / scale\n            if Z == 0: continue\n            # back-project the pixel through the intrinsics\n            X = (u - cx) * Z / fx\n            Y = (v - cy) * Z / fy\n            points.append(\"%f %f %f %d %d %d 0\\n\" % (X, Y, Z, color[0], color[1], color[2]))\n    # the PLY header must start at column 0, so these lines are deliberately not indented\n    with open(ply_file, \"w\") as file:\n        file.write('''ply\nformat ascii 1.0\nelement vertex %d\nproperty float x\nproperty float y\nproperty float z\nproperty uchar red\nproperty uchar green\nproperty uchar blue\nproperty uchar alpha\nend_header\n%s''' % (len(points), \"\".join(points)))\n    print(\"save ply, fx:{}, fy:{}, cx:{}, cy:{}\".format(fx, fy, cx, cy))"
  }
]