Repository: JeffWang987/MVSTER Branch: main Commit: 3f8bb98bba0c Files: 35 Total size: 246.6 KB Directory structure: gitextract_j4zdg8al/ ├── .gitignore ├── LICENSE ├── README.md ├── datasets/ │ ├── __init__.py │ ├── blendedmvs.py │ ├── data_io.py │ ├── dtu_yao4.py │ ├── eth3d.py │ ├── general_eval4.py │ └── tanks.py ├── evaluations/ │ └── dtu/ │ ├── BaseEval2Obj_web.m │ ├── BaseEvalMain_func.m │ ├── BaseEvalMain_web.m │ ├── ComputeStat_func.m │ ├── ComputeStat_web.m │ ├── MaxDistCP.m │ ├── PointCompareMain.m │ ├── plyread.m │ └── reducePts_haa.m ├── lists/ │ ├── blendedmvs/ │ │ ├── train.txt │ │ └── val.txt │ └── dtu/ │ ├── test.txt │ ├── train.txt │ ├── trainval.txt │ └── val.txt ├── models/ │ ├── MVS4Net.py │ ├── __init__.py │ ├── module.py │ └── mvs4net_utils.py ├── requirements.txt ├── scripts/ │ ├── test_dtu.sh │ └── train_dtu.sh ├── test_mvs4.py ├── train_mvs4.py └── utils.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ outputs/ checkpoints/ debug_figs/ *__pycache__ ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2022 Jeff Wang Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================ # MVSTER MVSTER: Epipolar Transformer for Efficient Multi-View Stereo, ECCV 2022. [arXiv](https://arxiv.org/abs/2204.07346) This repository contains the official implementation of the paper: "MVSTER: Epipolar Transformer for Efficient Multi-View Stereo". ## Introduction MVSTER is a learning-based MVS method which achieves competitive reconstruction performance with significantly higher efficiency. MVSTER leverages the proposed epipolar Transformer to learn both 2D semantics and 3D spatial associations efficiently. Specifically, the epipolar Transformer utilizes a detachable monocular depth estimator to enhance 2D semantics and uses cross-attention to construct data-dependent 3D associations along epipolar line. Additionally, MVSTER is built in a cascade structure, where entropy-regularized optimal transport is leveraged to propagate finer depth estimations in each stage. ![](img/arch.png) ## Installation MVSTER is tested on: * python 3.7 * CUDA 11.1 ### Requirements ``` pip install -r requirements.txt ``` ## Training * Dowload [DTU dataset](https://roboimagedata.compute.dtu.dk/). For convenience, can download the preprocessed [DTU training data](https://drive.google.com/file/d/1eDjh-_bxKKnEuz5h-HXS7EDJn59clx6V/view) and [Depths_raw](https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/cascade-stereo/CasMVSNet/dtu_data/dtu_train_hr/Depths_raw.zip) (both from [Original MVSNet](https://github.com/YoYo000/MVSNet)), and upzip it as the $DTU_TRAINING folder. For training and testing with raw image size, you can download [Rectified_raw](http://roboimagedata2.compute.dtu.dk/data/MVS/Rectified.zip), and unzip it. ``` ├── Cameras ├── Depths ├── Depths_raw ├── Rectified ├── Rectified_raw (Optional) ``` In ``scripts/train_dtu.sh``, set ``DTU_TRAINING`` as $DTU_TRAINING Train MVSTER (Multi-GPU training): * Train with middle size (512x640): ``` bash ./scripts/train_dtu.sh mid exp_name ``` * Train with raw size (1200x1600): ``` bash ./scripts/train_dtu.sh raw exp_name ``` After training, you will get model checkpoints in ./checkpoints/dtu/exp_name. ## Testing * Download the preprocessed test data [DTU testing data](https://drive.google.com/open?id=135oKPefcPTsdtLRzoDAQtPpHuoIrpRI_) (from [Original MVSNet](https://github.com/YoYo000/MVSNet)) and unzip it as the $DTU_TESTPATH folder, which should contain one ``cams`` folder, one ``images`` folder and one ``pair.txt`` file. * In ``scripts/test_dtu.sh``, set ``DTU_TESTPATH`` as $DTU_TESTPATH. * The ``DTU_CKPT_FILE`` is automatically set as your pretrained checkpoint file, you also can download my [pretrained model](https://github.com/JeffWang987/MVSTER/releases/tag/dtu_ckpt). * Test with middle size: ``` bash ./scripts/test_dtu.sh mid exp_name ``` * Test with raw size: ``` bash ./scripts/test_dtu.sh raw exp_name ``` * Test with provided pretrained model: ``` bash scripts/test_dtu.sh mid benchmark --loadckpt PATH_TO_CKPT_FILE ``` After testing, you will get reconstructed point clouds of DTU test set in ./outputs/dtu/exp_name. ## Metric * For quantitative evaluation, download [SampleSet](http://roboimagedata.compute.dtu.dk/?page_id=36) and [Points](http://roboimagedata.compute.dtu.dk/?page_id=36) from DTU's website. Unzip them and place `Points` folder in `SampleSet/MVS Data/`. The structure looks like: ``` SampleSet ├──MVS Data └──Points ``` * For convinience evaluation, please install matlab (tested on Ubuntu 18.04) and uncomment **mrun_rst** function at the end of **./test_mvs4.py**, and you also need to change the path of matlab excutable file (for me, it is /mnt/cfs/algorithm/xiaofeng.wang/jeff/code/MVS/misc/matlab/bin/matlab). Then you can evaluate point cloud reconstruction results when testing is finished. * You can also evaluate the metrics with the traditional steps: In ``evaluations/dtu/BaseEvalMain_web.m``, set `dataPath` as the path to `SampleSet/MVS Data/`, `plyPath` as directory that stores the reconstructed point clouds and `resultsPath` as directory to store the evaluation results. Then run ``evaluations/dtu/BaseEvalMain_web.m`` in matlab. ## Results on DTU (single RTX 3090) | | Acc. | Comp. | Overall. | Inf. Time | |-----------------------|--------|--------|----------|-----------| | MVSTER (mid size) | 0.350 | 0.276 | 0.313 | 0.09s | | MVSTER (raw size) | 0.340 | 0.266 | 0.303 | 0.17s | Point cloud results on [DTU](https://github.com/JeffWang987/MVSTER/releases/tag/DTU_ply), [Tanks and Temples](https://github.com/JeffWang987/MVSTER/releases/tag/T%26T_ply), [ETH3D](https://github.com/JeffWang987/MVSTER/releases/tag/ETH3D_ply) ![](img/vegetables.gif) ![](img/house.gif) ![](img/sculpture.gif) ![](img/rabit.gif) If you find this project useful for your research, please cite: ``` @misc{wang2022mvster, title={MVSTER: Epipolar Transformer for Efficient Multi-View Stereo}, author={Xiaofeng Wang, Zheng Zhu, Fangbo Qin, Yun Ye, Guan Huang, Xu Chi, Yijia He and Xingang Wang}, journal={arXiv preprint arXiv:2204.07346}, year={2022} } ``` ## Acknowledgements Our work is partially baed on these opening source work: [MVSNet](https://github.com/YoYo000/MVSNet), [MVSNet-pytorch](https://github.com/xy-guo/MVSNet_pytorch), [cascade-stereo](https://github.com/alibaba/cascade-stereo), [PatchmatchNet](https://github.com/FangjinhuaWang/PatchmatchNet). We appreciate their contributions to the MVS community. ================================================ FILE: datasets/__init__.py ================================================ import importlib # find the dataset definition by name, for example dtu_yao (dtu_yao.py) def find_dataset_def(dataset_name): module_name = 'datasets.{}'.format(dataset_name) module = importlib.import_module(module_name) return getattr(module, "MVSDataset") ================================================ FILE: datasets/blendedmvs.py ================================================ from torch.utils.data import Dataset from datasets.data_io import * import os import numpy as np import cv2 from PIL import Image from torchvision import transforms as T import random import copy def check_invalid_input(imgs, depths, masks, depth_mins, depth_maxs): for img in imgs: assert np.isnan(img).sum() == 0 assert np.isinf(img).sum() == 0 for depth in depths.values(): assert np.isnan(depth).sum() == 0 assert np.isinf(depth).sum() == 0 for mask in masks.values(): assert np.isnan(mask).sum() == 0 assert np.isinf(mask).sum() == 0 assert (depth_mins<=0) == 0 assert (depth_maxs<=depth_mins) == 0 class MVSDataset(Dataset): def __init__(self, datapath, listfile, split, nviews, img_wh=(768, 576), robust_train=True): super(MVSDataset, self).__init__() self.levels = 4 self.datapath = datapath self.split = split self.listfile = listfile self.robust_train = robust_train assert self.split in ['train', 'val', 'all'], \ 'split must be either "train", "val" or "all"!' self.img_wh = img_wh if img_wh is not None: assert img_wh[0]%32==0 and img_wh[1]%32==0, \ 'img_wh must both be multiples of 32!' self.nviews = nviews self.scale_factors = {} # depth scale factors for each scan self.scale_factor = 0 # depth scale factors for each scan self.build_metas() self.color_augment = T.ColorJitter(brightness=0.5, contrast=0.5) def build_metas(self): self.metas = [] with open(self.listfile) as f: self.scans = [line.rstrip() for line in f.readlines()] for scan in self.scans: with open(os.path.join(self.datapath, scan, "cams/pair.txt")) as f: num_viewpoint = int(f.readline()) for _ in range(num_viewpoint): ref_view = int(f.readline().rstrip()) src_views = [int(x) for x in f.readline().rstrip().split()[1::2]] if len(src_views) >= self.nviews-1: self.metas += [(scan, ref_view, src_views)] def read_cam_file(self, scan, filename): with open(filename) as f: lines = f.readlines() lines = [line.rstrip() for line in lines] # extrinsics: line [1,5), 4x4 matrix extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4)) # intrinsics: line [7-10), 3x3 matrix intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3)) depth_min = float(lines[11].split()[0]) depth_max = float(lines[11].split()[-1]) if scan not in self.scale_factors: self.scale_factors[scan] = 100.0 / depth_min depth_min *= self.scale_factors[scan] depth_max *= self.scale_factors[scan] extrinsics[:3, 3] *= self.scale_factors[scan] return intrinsics, extrinsics, depth_min, depth_max def read_depth_mask(self, scan, filename, depth_min, depth_max, scale): depth = np.array(read_pfm(filename)[0], dtype=np.float32) # depth = (depth * self.scale_factor) * scale depth = (depth * self.scale_factors[scan]) * scale # depth = depth * scale # depth = np.squeeze(depth,2) mask = (depth>=depth_min) & (depth<=depth_max) assert mask.sum() > 0 mask = mask.astype(np.float32) if self.img_wh is not None: depth = cv2.resize(depth, self.img_wh, interpolation=cv2.INTER_NEAREST) h, w = depth.shape depth_ms = {} mask_ms = {} for i in range(4): depth_cur = cv2.resize(depth, (w//(2**i), h//(2**i)), interpolation=cv2.INTER_NEAREST) mask_cur = cv2.resize(mask, (w//(2**i), h//(2**i)), interpolation=cv2.INTER_NEAREST) depth_ms[f"stage{4-i}"] = depth_cur mask_ms[f"stage{4-i}"] = mask_cur return depth_ms, mask_ms def read_img(self, filename): img = Image.open(filename) # img = self.color_augment(img) # scale 0~255 to 0~1 np_img = np.array(img, dtype=np.float32) / 255. return np_img def __len__(self): return len(self.metas) def __getitem__(self, idx): meta = self.metas[idx] scan, ref_view, src_views = meta if self.robust_train: num_src_views = len(src_views) index = random.sample(range(num_src_views), self.nviews - 1) view_ids = [ref_view] + [src_views[i] for i in index] scale = random.uniform(0.8, 1.25) else: view_ids = [ref_view] + src_views[:self.nviews - 1] scale = 1 imgs = [] mask = None depth = None depth_min = None depth_max = None proj={} proj_matrices_0 = [] proj_matrices_1 = [] proj_matrices_2 = [] proj_matrices_3 = [] for i, vid in enumerate(view_ids): img_filename = os.path.join(self.datapath, '{}/blended_images/{:0>8}.jpg'.format(scan, vid)) depth_filename = os.path.join(self.datapath, '{}/rendered_depth_maps/{:0>8}.pfm'.format(scan, vid)) proj_mat_filename = os.path.join(self.datapath, '{}/cams/{:0>8}_cam.txt'.format(scan, vid)) img = self.read_img(img_filename) imgs.append(img.transpose(2,0,1)) intrinsics, extrinsics, depth_min_, depth_max_ = self.read_cam_file(scan, proj_mat_filename) # proj_mat_filename = os.path.join(self.datapath, 'Cameras/train/{:0>8}_cam.txt').format(vid) proj_mat_0 = np.zeros(shape=(2, 4, 4), dtype=np.float32) proj_mat_1 = np.zeros(shape=(2, 4, 4), dtype=np.float32) proj_mat_2 = np.zeros(shape=(2, 4, 4), dtype=np.float32) proj_mat_3 = np.zeros(shape=(2, 4, 4), dtype=np.float32) extrinsics[:3, 3] *= scale intrinsics[:2,:] *= 0.125 proj_mat_0[0,:4,:4] = extrinsics.copy() proj_mat_0[1,:3,:3] = intrinsics.copy() intrinsics[:2,:] *= 2 proj_mat_1[0,:4,:4] = extrinsics.copy() proj_mat_1[1,:3,:3] = intrinsics.copy() intrinsics[:2,:] *= 2 proj_mat_2[0,:4,:4] = extrinsics.copy() proj_mat_2[1,:3,:3] = intrinsics.copy() intrinsics[:2,:] *= 2 proj_mat_3[0,:4,:4] = extrinsics.copy() proj_mat_3[1,:3,:3] = intrinsics.copy() proj_matrices_0.append(proj_mat_0) proj_matrices_1.append(proj_mat_1) proj_matrices_2.append(proj_mat_2) proj_matrices_3.append(proj_mat_3) if i == 0: # reference view depth_min = depth_min_ * scale depth_max = depth_max_ * scale depth, mask = self.read_depth_mask(scan, depth_filename, depth_min, depth_max, scale) for l in range(self.levels): mask[f'stage{l+1}'] = mask[f'stage{l+1}'] # np.expand_dims(mask[f'stage{l+1}'],2) depth[f'stage{l+1}'] = depth[f'stage{l+1}'] proj['stage1'] = np.stack(proj_matrices_0) proj['stage2'] = np.stack(proj_matrices_1) proj['stage3'] = np.stack(proj_matrices_2) proj['stage4'] = np.stack(proj_matrices_3) # check_invalid_input(imgs, depth, mask, depth_min, depth_max) # data is numpy array return {"imgs": imgs, # [Nv, 3, H, W] "proj_matrices": proj, # [N,2,4,4] "depth": depth, # [1, H, W] "depth_values": np.array([depth_min, depth_max], dtype=np.float32), "mask": mask} # [1, H, W] ================================================ FILE: datasets/data_io.py ================================================ import numpy as np import re import sys def read_pfm(filename): file = open(filename, 'rb') color = None width = None height = None scale = None endian = None header = file.readline().decode('utf-8').rstrip() if header == 'PF': color = True elif header == 'Pf': color = False else: raise Exception('Not a PFM file.') dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8')) if dim_match: width, height = map(int, dim_match.groups()) else: raise Exception('Malformed PFM header.') scale = float(file.readline().rstrip()) if scale < 0: # little-endian endian = '<' scale = -scale else: endian = '>' # big-endian data = np.fromfile(file, endian + 'f') shape = (height, width, 3) if color else (height, width) data = np.reshape(data, shape) data = np.flipud(data) file.close() return data, scale def save_pfm(filename, image, scale=1): file = open(filename, "wb") color = None image = np.flipud(image) if image.dtype.name != 'float32': raise Exception('Image dtype must be float32.') if len(image.shape) == 3 and image.shape[2] == 3: # color image color = True elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale color = False else: raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.') file.write('PF\n'.encode('utf-8') if color else 'Pf\n'.encode('utf-8')) file.write('{} {}\n'.format(image.shape[1], image.shape[0]).encode('utf-8')) endian = image.dtype.byteorder if endian == '<' or endian == '=' and sys.byteorder == 'little': scale = -scale file.write(('%f\n' % scale).encode('utf-8')) image.tofile(file) file.close() import random, cv2 class RandomCrop(object): def __init__(self, CropSize=0.1): self.CropSize = CropSize def __call__(self, image, normal): h, w = normal.shape[:2] img_h, img_w = image.shape[:2] CropSize_w, CropSize_h = max(1, int(w * self.CropSize)), max(1, int(h * self.CropSize)) x1, y1 = random.randint(0, CropSize_w), random.randint(0, CropSize_h) x2, y2 = random.randint(w - CropSize_w, w), random.randint(h - CropSize_h, h) normal_crop = normal[y1:y2, x1:x2] normal_resize = cv2.resize(normal_crop, (w, h), interpolation=cv2.INTER_NEAREST) image_crop = image[4*y1:4*y2, 4*x1:4*x2] image_resize = cv2.resize(image_crop, (img_w, img_h), interpolation=cv2.INTER_LINEAR) # import matplotlib.pyplot as plt # plt.subplot(2, 3, 1) # plt.imshow(image) # plt.subplot(2, 3, 2) # plt.imshow(image_crop) # plt.subplot(2, 3, 3) # plt.imshow(image_resize) # # plt.subplot(2, 3, 4) # plt.imshow((normal + 1.0) / 2, cmap="rainbow") # plt.subplot(2, 3, 5) # plt.imshow((normal_crop + 1.0) / 2, cmap="rainbow") # plt.subplot(2, 3, 6) # plt.imshow((normal_resize + 1.0) / 2, cmap="rainbow") # plt.show() # plt.pause(1) # plt.close() return image_resize, normal_resize ================================================ FILE: datasets/dtu_yao4.py ================================================ from torch.utils.data import Dataset import numpy as np import os, cv2, time, math from PIL import Image from datasets.data_io import * from torchvision import transforms # the DTU dataset preprocessed by Yao Yao (only for training) class MVSDataset(Dataset): def __init__(self, datapath, listfile, mode, nviews, interval_scale=1.06, **kwargs): super(MVSDataset, self).__init__() self.datapath = datapath self.listfile = listfile self.mode = mode self.nviews = nviews self.ndepths = 192 # Hardcode self.interval_scale = interval_scale self.kwargs = kwargs self.rt = kwargs.get("rt", False) self.use_raw_train = kwargs.get("use_raw_train", False) self.color_augment = transforms.ColorJitter(brightness=0.5, contrast=0.5) assert self.mode in ["train", "val", "test"] self.metas = self.build_list() def build_list(self): metas = [] with open(self.listfile) as f: scans = f.readlines() scans = [line.rstrip() for line in scans] # scans for scan in scans: pair_file = "Cameras/pair.txt" # read the pair file with open(os.path.join(self.datapath, pair_file)) as f: num_viewpoint = int(f.readline()) # viewpoints (49) for view_idx in range(num_viewpoint): ref_view = int(f.readline().rstrip()) src_views = [int(x) for x in f.readline().rstrip().split()[1::2]] # light conditions 0-6 for light_idx in range(7): metas.append((scan, light_idx, ref_view, src_views)) # print("dataset", self.mode, "metas:", len(metas)) return metas def __len__(self): return len(self.metas) def read_cam_file(self, filename): with open(filename) as f: lines = f.readlines() lines = [line.rstrip() for line in lines] # extrinsics: line [1,5), 4x4 matrix extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4)) # intrinsics: line [7-10), 3x3 matrix intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3)) # depth_min & depth_interval: line 11 depth_min = float(lines[11].split()[0]) depth_interval = float(lines[11].split()[1]) * self.interval_scale return intrinsics, extrinsics, depth_min, depth_interval def read_img(self, filename): img = Image.open(filename) if self.mode == 'train': img = self.color_augment(img) # scale 0~255 to 0~1 np_img = np.array(img, dtype=np.float32) / 255. return np_img def crop_img(self, img): raw_h, raw_w = img.shape[:2] start_h = (raw_h-1024)//2 start_w = (raw_w-1280)//2 return img[start_h:start_h+1024, start_w:start_w+1280, :] # 1024, 1280, C def prepare_img(self, hr_img): h, w = hr_img.shape if not self.use_raw_train: #w1600-h1200-> 800-600 ; crop -> 640, 512; downsample 1/4 -> 160, 128 #downsample hr_img_ds = cv2.resize(hr_img, (w//2, h//2), interpolation=cv2.INTER_NEAREST) h, w = hr_img_ds.shape target_h, target_w = 512, 640 start_h, start_w = (h - target_h)//2, (w - target_w)//2 hr_img_crop = hr_img_ds[start_h: start_h + target_h, start_w: start_w + target_w] elif self.use_raw_train: hr_img_crop = hr_img[h//2-1024//2:h//2+1024//2, w//2-1280//2:w//2+1280//2] # 1024, 1280, c return hr_img_crop def read_mask_hr(self, filename): img = Image.open(filename) np_img = np.array(img, dtype=np.float32) np_img = (np_img > 10).astype(np.float32) np_img = self.prepare_img(np_img) h, w = np_img.shape np_img_ms = { "stage1": cv2.resize(np_img, (w//8, h//8), interpolation=cv2.INTER_NEAREST), "stage2": cv2.resize(np_img, (w//4, h//4), interpolation=cv2.INTER_NEAREST), "stage3": cv2.resize(np_img, (w//2, h//2), interpolation=cv2.INTER_NEAREST), "stage4": np_img, } return np_img_ms def read_depth_hr(self, filename, scale): # read pfm depth file #w1600-h1200-> 800-600 ; crop -> 640, 512; downsample 1/4 -> 160, 128 depth_hr = np.array(read_pfm(filename)[0], dtype=np.float32) * scale depth_lr = self.prepare_img(depth_hr) h, w = depth_lr.shape depth_lr_ms = { "stage1": cv2.resize(depth_lr, (w//8, h//8), interpolation=cv2.INTER_NEAREST), "stage2": cv2.resize(depth_lr, (w//4, h//4), interpolation=cv2.INTER_NEAREST), "stage3": cv2.resize(depth_lr, (w//2, h//2), interpolation=cv2.INTER_NEAREST), "stage4": depth_lr, } return depth_lr_ms def __getitem__(self, idx): meta = self.metas[idx] scan, light_idx, ref_view, src_views = meta # use only the reference view and first nviews-1 source views if self.mode == 'train' and self.rt: num_src_views = len(src_views) index = random.sample(range(num_src_views), self.nviews - 1) view_ids = [ref_view] + [src_views[i] for i in index] scale = random.uniform(0.8, 1.25) else: view_ids = [ref_view] + src_views[:self.nviews - 1] scale = 1 imgs = [] mask = None depth_values = None proj_matrices = [] for i, vid in enumerate(view_ids): # NOTE that the id in image file names is from 1 to 49 (not 0~48) if not self.use_raw_train: img_filename = os.path.join(self.datapath, 'Rectified/{}_train/rect_{:0>3}_{}_r5000.png'.format(scan, vid + 1, light_idx)) else: img_filename = os.path.join(self.datapath, 'Rectified_raw/{}/rect_{:0>3}_{}_r5000.png'.format(scan, vid + 1, light_idx)) mask_filename_hr = os.path.join(self.datapath, 'Depths_raw/{}/depth_visual_{:0>4}.png'.format(scan, vid)) depth_filename_hr = os.path.join(self.datapath, 'Depths_raw/{}/depth_map_{:0>4}.pfm'.format(scan, vid)) proj_mat_filename = os.path.join(self.datapath, 'Cameras/train/{:0>8}_cam.txt').format(vid) img = self.read_img(img_filename) if self.use_raw_train: img = self.crop_img(img) intrinsics, extrinsics, depth_min, depth_interval = self.read_cam_file(proj_mat_filename) if self.rt: extrinsics[:3,3] *= scale if self.use_raw_train: intrinsics[:2, :] *= 2.0 if i == 0: mask_read_ms = self.read_mask_hr(mask_filename_hr) depth_ms = self.read_depth_hr(depth_filename_hr, scale) #get depth values depth_max = depth_interval * self.ndepths + depth_min depth_values = np.array([depth_min * scale, depth_max * scale], dtype=np.float32) mask = mask_read_ms proj_mat = np.zeros(shape=(2, 4, 4), dtype=np.float32) # proj_mat[0, :4, :4] = extrinsics proj_mat[1, :3, :3] = intrinsics proj_matrices.append(proj_mat) imgs.append(img.transpose(2,0,1)) #all # imgs = np.stack(imgs).transpose([0, 3, 1, 2]) #ms proj_mats proj_matrices = np.stack(proj_matrices) stage1_pjmats = proj_matrices.copy() stage1_pjmats[:, 1, :2, :] = proj_matrices[:, 1, :2, :] / 2.0 stage3_pjmats = proj_matrices.copy() stage3_pjmats[:, 1, :2, :] = proj_matrices[:, 1, :2, :] * 2 stage4_pjmats = proj_matrices.copy() stage4_pjmats[:, 1, :2, :] = proj_matrices[:, 1, :2, :] * 4 proj_matrices_ms = { "stage1": stage1_pjmats, "stage2": proj_matrices, "stage3": stage3_pjmats, "stage4": stage4_pjmats } return {"imgs": imgs, # Nv C H W "proj_matrices": proj_matrices_ms, # 4 stage of Nv 2 4 4 "depth": depth_ms, "depth_values": depth_values, "mask": mask } ================================================ FILE: datasets/eth3d.py ================================================ from torch.utils.data import Dataset from datasets.data_io import * import os import numpy as np import cv2 from PIL import Image class MVSDataset(Dataset): def __init__(self, datapath, split='test', n_views=7, img_wh=(1920,1280)): self.levels = 4 self.datapath = datapath self.img_wh = img_wh self.split = split self.build_metas() self.n_views = n_views def build_metas(self): self.metas = [] if self.split == "test": self.scans = ['botanical_garden', 'boulders', 'bridge', 'door', 'exhibition_hall', 'lecture_room', 'living_room', 'lounge', 'observatory', 'old_computer', 'statue', 'terrace_2'] elif self.split == "train": self.scans = ['courtyard', 'delivery_area', 'electro', 'facade', 'kicker', 'meadow', 'office', 'pipes', 'playground', 'relief', 'relief_2', 'terrace', 'terrains'] for scan in self.scans: with open(os.path.join(self.datapath, scan, 'pair.txt')) as f: num_viewpoint = int(f.readline()) for view_idx in range(num_viewpoint): ref_view = int(f.readline().rstrip()) src_views = [int(x) for x in f.readline().rstrip().split()[1::2]] if len(src_views) != 0: self.metas += [(scan, -1, ref_view, src_views)] def read_cam_file(self, filename): with open(filename) as f: lines = [line.rstrip() for line in f.readlines()] # extrinsics: line [1,5), 4x4 matrix extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ') extrinsics = extrinsics.reshape((4, 4)) # intrinsics: line [7-10), 3x3 matrix intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ') intrinsics = intrinsics.reshape((3, 3)) depth_min = float(lines[11].split()[0]) if depth_min < 0: depth_min = 1 depth_max = float(lines[11].split()[-1]) return intrinsics, extrinsics, depth_min, depth_max def read_img(self, filename): img = Image.open(filename) np_img = np.array(img, dtype=np.float32) / 255. original_h, original_w, _ = np_img.shape np_img = cv2.resize(np_img, self.img_wh, interpolation=cv2.INTER_LINEAR) return np_img, original_h, original_w def __len__(self): return len(self.metas) def __getitem__(self, idx): scan, _, ref_view, src_views = self.metas[idx] # use only the reference view and first nviews-1 source views view_ids = [ref_view] + src_views[:self.n_views-1] imgs = [] # depth = None depth_min = None depth_max = None proj_matrices_0 = [] proj_matrices_1 = [] proj_matrices_2 = [] proj_matrices_3 = [] for i, vid in enumerate(view_ids): img_filename = os.path.join(self.datapath, scan, f'images/{vid:08d}.jpg') proj_mat_filename = os.path.join(self.datapath, scan, f'cams_1/{vid:08d}_cam.txt') img, original_h, original_w = self.read_img(img_filename) intrinsics, extrinsics, depth_min_, depth_max_ = self.read_cam_file(proj_mat_filename) intrinsics[0] *= self.img_wh[0]/original_w intrinsics[1] *= self.img_wh[1]/original_h imgs.append(img.transpose(2,0,1)) proj_mat_0 = np.zeros(shape=(2, 4, 4), dtype=np.float32) proj_mat_1 = np.zeros(shape=(2, 4, 4), dtype=np.float32) proj_mat_2 = np.zeros(shape=(2, 4, 4), dtype=np.float32) proj_mat_3 = np.zeros(shape=(2, 4, 4), dtype=np.float32) intrinsics[:2,:] *= 0.125 proj_mat_0[0,:4,:4] = extrinsics.copy() proj_mat_0[1,:3,:3] = intrinsics.copy() intrinsics[:2,:] *= 2 proj_mat_1[0,:4,:4] = extrinsics.copy() proj_mat_1[1,:3,:3] = intrinsics.copy() intrinsics[:2,:] *= 2 proj_mat_2[0,:4,:4] = extrinsics.copy() proj_mat_2[1,:3,:3] = intrinsics.copy() intrinsics[:2,:] *= 2 proj_mat_3[0,:4,:4] = extrinsics.copy() proj_mat_3[1,:3,:3] = intrinsics.copy() proj_matrices_0.append(proj_mat_0) proj_matrices_1.append(proj_mat_1) proj_matrices_2.append(proj_mat_2) proj_matrices_3.append(proj_mat_3) if i == 0: # reference view depth_min = depth_min_ depth_max = depth_max_ # proj_matrices: N*4*4 proj={} proj['stage1'] = np.stack(proj_matrices_0) proj['stage2'] = np.stack(proj_matrices_1) proj['stage3'] = np.stack(proj_matrices_2) proj['stage4'] = np.stack(proj_matrices_3) return {"imgs": imgs, # N*3*H0*W0 "proj_matrices": proj, # N*4*4 "depth_values": np.array([depth_min, depth_max], dtype=np.float32), "filename": scan + '/{}/' + '{:0>8}'.format(view_ids[0]) + "{}" } ================================================ FILE: datasets/general_eval4.py ================================================ from torch.utils.data import Dataset import numpy as np import os, cv2, time from PIL import Image from datasets.data_io import * s_h, s_w = 0, 0 class MVSDataset(Dataset): def __init__(self, datapath, listfile, mode, nviews, interval_scale=1.06, **kwargs): super(MVSDataset, self).__init__() self.datapath = datapath self.listfile = listfile self.mode = mode self.nviews = nviews self.ndepths = 192 # Hardcode self.interval_scale = interval_scale self.max_h, self.max_w = kwargs["max_h"], kwargs["max_w"] self.fix_res = kwargs.get("fix_res", False) #whether to fix the resolution of input image. self.fix_wh = False assert self.mode == "test" self.metas = self.build_list() def build_list(self): metas = [] scans = self.listfile interval_scale_dict = {} # scans for scan in scans: # determine the interval scale of each scene. default is 1.06 if isinstance(self.interval_scale, float): interval_scale_dict[scan] = self.interval_scale else: interval_scale_dict[scan] = self.interval_scale[scan] pair_file = "{}/pair.txt".format(scan) # read the pair file with open(os.path.join(self.datapath, pair_file)) as f: num_viewpoint = int(f.readline()) # viewpoints for view_idx in range(num_viewpoint): ref_view = int(f.readline().rstrip()) src_views = [int(x) for x in f.readline().rstrip().split()[1::2]] # filter by no src view and fill to nviews if len(src_views) > 0: if len(src_views) < self.nviews: print("{}< num_views:{}".format(len(src_views), self.nviews)) src_views += [src_views[0]] * (self.nviews - len(src_views)) metas.append((scan, ref_view, src_views, scan)) self.interval_scale = interval_scale_dict print("dataset", self.mode, "metas:", len(metas), "interval_scale:{}".format(self.interval_scale)) return metas def __len__(self): return len(self.metas) def read_cam_file(self, filename, interval_scale): with open(filename) as f: lines = f.readlines() lines = [line.rstrip() for line in lines] # extrinsics: line [1,5), 4x4 matrix extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4)) # intrinsics: line [7-10), 3x3 matrix intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3)) intrinsics[:2, :] /= 4.0 # depth_min & depth_interval: line 11 depth_min = float(lines[11].split()[0]) depth_interval = float(lines[11].split()[1]) if len(lines[11].split()) >= 3: num_depth = lines[11].split()[2] depth_max = depth_min + int(float(num_depth)) * depth_interval depth_interval = (depth_max - depth_min) / self.ndepths depth_interval *= interval_scale return intrinsics, extrinsics, depth_min, depth_interval def read_img(self, filename): img = Image.open(filename) # scale 0~255 to 0~1 np_img = np.array(img, dtype=np.float32) / 255. return np_img def read_depth(self, filename): # read pfm depth file return np.array(read_pfm(filename)[0], dtype=np.float32) def scale_mvs_input(self, img, intrinsics, max_w, max_h, base=64): h, w = img.shape[:2] if h > max_h or w > max_w: scale = 1.0 * max_h / h if scale * w > max_w: scale = 1.0 * max_w / w new_w, new_h = scale * w // base * base, scale * h // base * base else: new_w, new_h = 1.0 * w // base * base, 1.0 * h // base * base scale_w = 1.0 * new_w / w scale_h = 1.0 * new_h / h intrinsics[0, :] *= scale_w intrinsics[1, :] *= scale_h img = cv2.resize(img, (int(new_w), int(new_h))) return img, intrinsics def __getitem__(self, idx): global s_h, s_w meta = self.metas[idx] scan, ref_view, src_views, scene_name = meta # use only the reference view and first nviews-1 source views view_ids = [ref_view] + src_views[:self.nviews - 1] imgs = [] depth_values = None proj_matrices = [] for i, vid in enumerate(view_ids): img_filename = os.path.join(self.datapath, '{}/images_post/{:0>8}.jpg'.format(scan, vid)) if not os.path.exists(img_filename): img_filename = os.path.join(self.datapath, '{}/images/{:0>8}.jpg'.format(scan, vid)) proj_mat_filename = os.path.join(self.datapath, '{}/cams/{:0>8}_cam.txt'.format(scan, vid)) img = self.read_img(img_filename) intrinsics, extrinsics, depth_min, depth_interval = self.read_cam_file(proj_mat_filename, interval_scale= self.interval_scale[scene_name]) # scale input img, intrinsics = self.scale_mvs_input(img, intrinsics, self.max_w, self.max_h) if self.fix_res: # using the same standard height or width in entire scene. s_h, s_w = img.shape[:2] self.fix_res = False self.fix_wh = True if i == 0: if not self.fix_wh: # using the same standard height or width in each nviews. s_h, s_w = img.shape[:2] # resize to standard height or width c_h, c_w = img.shape[:2] if (c_h != s_h) or (c_w != s_w): scale_h = 1.0 * s_h / c_h scale_w = 1.0 * s_w / c_w img = cv2.resize(img, (s_w, s_h)) intrinsics[0, :] *= scale_w intrinsics[1, :] *= scale_h imgs.append(img.transpose(2,0,1)) # extrinsics, intrinsics proj_mat = np.zeros(shape=(2, 4, 4), dtype=np.float32) # proj_mat[0, :4, :4] = extrinsics proj_mat[1, :3, :3] = intrinsics proj_matrices.append(proj_mat) if i == 0: # reference view depth_values = np.arange(depth_min, depth_interval * (self.ndepths - 0.5) + depth_min, depth_interval, dtype=np.float32) #all # imgs = np.stack(imgs).transpose([0, 3, 1, 2]) #ms proj_mats proj_matrices = np.stack(proj_matrices) stage1_pjmats = proj_matrices.copy() stage1_pjmats[:, 1, :2, :] = proj_matrices[:, 1, :2, :] / 2.0 stage3_pjmats = proj_matrices.copy() stage3_pjmats[:, 1, :2, :] = proj_matrices[:, 1, :2, :] * 2 stage4_pjmats = proj_matrices.copy() stage4_pjmats[:, 1, :2, :] = proj_matrices[:, 1, :2, :] * 4 proj_matrices_ms = { "stage1": stage1_pjmats, "stage2": proj_matrices, "stage3": stage3_pjmats, "stage4": stage4_pjmats } return {"imgs": imgs, "proj_matrices": proj_matrices_ms, "depth_values": depth_values, "filename": scan + '/{}/' + '{:0>8}'.format(view_ids[0]) + "{}"} ================================================ FILE: datasets/tanks.py ================================================ from torch.utils.data import Dataset from datasets.data_io import * import os import numpy as np import cv2 from PIL import Image class MVSDataset(Dataset): def __init__(self, datapath, n_views=7, split='intermediate'): self.levels = 4 self.datapath = datapath self.split = split self.build_metas() self.n_views = n_views def build_metas(self): self.metas = [] if self.split == 'intermediate': self.scans = ['Family', 'Francis', 'Horse', 'Playground', 'Train', 'Lighthouse', 'M60', 'Panther'] elif self.split == 'advanced': self.scans = ['Auditorium', 'Ballroom', 'Courtroom', 'Museum', 'Palace', 'Temple'] for scan in self.scans: with open(os.path.join(self.datapath, self.split, scan, 'pair.txt')) as f: num_viewpoint = int(f.readline()) for view_idx in range(num_viewpoint): ref_view = int(f.readline().rstrip()) src_views = [int(x) for x in f.readline().rstrip().split()[1::2]] if len(src_views) != 0: self.metas += [(scan, -1, ref_view, src_views)] def read_cam_file(self, filename): with open(filename) as f: lines = [line.rstrip() for line in f.readlines()] # extrinsics: line [1,5), 4x4 matrix extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ') extrinsics = extrinsics.reshape((4, 4)) # intrinsics: line [7-10), 3x3 matrix intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ') intrinsics = intrinsics.reshape((3, 3)) depth_min = float(lines[11].split()[0]) depth_max = float(lines[11].split()[-1]) return intrinsics, extrinsics, depth_min, depth_max def read_img(self, filename): img = Image.open(filename) np_img = np.array(img, dtype=np.float32) / 255. return np_img def scale_input(self, intrinsics, img): """ intrinsics: 3x3 img: W H C """ intrinsics[1,2] = intrinsics[1,2] - 28 # 1080 -> 1024 img = img[28:1080-28, :, :] return intrinsics, img def __len__(self): return len(self.metas) def __getitem__(self, idx): scan, _, ref_view, src_views = self.metas[idx] # use only the reference view and first nviews-1 source views view_ids = [ref_view] + src_views[:self.n_views-1] imgs = [] # depth = None depth_min = None depth_max = None proj_matrices_0 = [] proj_matrices_1 = [] proj_matrices_2 = [] proj_matrices_3 = [] for i, vid in enumerate(view_ids): img_filename = os.path.join(self.datapath, self.split, scan, f'images/{vid:08d}.jpg') proj_mat_filename = os.path.join(self.datapath, self.split, scan, f'cams/{vid:08d}_cam.txt') img = self.read_img(img_filename) intrinsics, extrinsics, depth_min_, depth_max_ = self.read_cam_file(proj_mat_filename) intrinsics, img = self.scale_input(intrinsics, img) imgs.append(img.transpose(2,0,1)) proj_mat_0 = np.zeros(shape=(2, 4, 4), dtype=np.float32) proj_mat_1 = np.zeros(shape=(2, 4, 4), dtype=np.float32) proj_mat_2 = np.zeros(shape=(2, 4, 4), dtype=np.float32) proj_mat_3 = np.zeros(shape=(2, 4, 4), dtype=np.float32) intrinsics[:2,:] *= 0.125 proj_mat_0[0,:4,:4] = extrinsics.copy() proj_mat_0[1,:3,:3] = intrinsics.copy() intrinsics[:2,:] *= 2 proj_mat_1[0,:4,:4] = extrinsics.copy() proj_mat_1[1,:3,:3] = intrinsics.copy() intrinsics[:2,:] *= 2 proj_mat_2[0,:4,:4] = extrinsics.copy() proj_mat_2[1,:3,:3] = intrinsics.copy() intrinsics[:2,:] *= 2 proj_mat_3[0,:4,:4] = extrinsics.copy() proj_mat_3[1,:3,:3] = intrinsics.copy() proj_matrices_0.append(proj_mat_0) proj_matrices_1.append(proj_mat_1) proj_matrices_2.append(proj_mat_2) proj_matrices_3.append(proj_mat_3) if i == 0: # reference view depth_min = depth_min_ depth_max = depth_max_ # proj_matrices: N*4*4 proj={} proj['stage1'] = np.stack(proj_matrices_0) proj['stage2'] = np.stack(proj_matrices_1) proj['stage3'] = np.stack(proj_matrices_2) proj['stage4'] = np.stack(proj_matrices_3) return {"imgs": imgs, # N*3*H0*W0 "proj_matrices": proj, # N*4*4 "depth_values": np.array([depth_min, depth_max], dtype=np.float32), "filename": scan + '/{}/' + '{:0>8}'.format(view_ids[0]) + "{}" } ================================================ FILE: evaluations/dtu/BaseEval2Obj_web.m ================================================ function BaseEval2Obj_web(BaseEval,method_string,outputPath) if(nargin<3) outputPath='./'; end % tresshold for coloring alpha channel in the range of 0-10 mm dist_tresshold=10; cSet=BaseEval.cSet; Qdata=BaseEval.Qdata; alpha=min(BaseEval.Ddata,dist_tresshold)/dist_tresshold; fid=fopen([outputPath method_string '2Stl_' num2str(cSet) ' .obj'],'w+'); for cP=1:size(Qdata,2) if(BaseEval.DataInMask(cP)) C=[1 0 0]*alpha(cP)+[1 1 1]*(1-alpha(cP)); %coloring from red to white in the range of 0-10 mm (0 to dist_tresshold) else C=[0 1 0]*alpha(cP)+[0 0 1]*(1-alpha(cP)); %green to blue for points outside the mask (which are not included in the analysis) end fprintf(fid,'v %f %f %f %f %f %f\n',[Qdata(1,cP) Qdata(2,cP) Qdata(3,cP) C(1) C(2) C(3)]); end fclose(fid); disp('Data2Stl saved as obj') Qstl=BaseEval.Qstl; fid=fopen([outputPath 'Stl2' method_string '_' num2str(cSet) '.obj'],'w+'); alpha=min(BaseEval.Dstl,dist_tresshold)/dist_tresshold; for cP=1:size(Qstl,2) if(BaseEval.StlAbovePlane(cP)) C=[1 0 0]*alpha(cP)+[1 1 1]*(1-alpha(cP)); %coloring from red to white in the range of 0-10 mm (0 to dist_tresshold) else C=[0 1 0]*alpha(cP)+[0 0 1]*(1-alpha(cP)); %green to blue for points below plane (which are not included in the analysis) end fprintf(fid,'v %f %f %f %f %f %f\n',[Qstl(1,cP) Qstl(2,cP) Qstl(3,cP) C(1) C(2) C(3)]); end fclose(fid); disp('Stl2Data saved as obj') ================================================ FILE: evaluations/dtu/BaseEvalMain_func.m ================================================ function None = BaseEvalMain_func(plyPath) % clear all % close all format compact % script to calculate distances have been measured for all included scans (UsedSets) dataPath='/mnt/cfs/algorithm/public_data/mvs/dtu_evalset/SampleSet/MVS Data'; % pred_results='cascade_hr/48-32-8_4-2-1_dlossw-0.5-1.0-2.0_chs888/gipuma_4_0.9/'; % plyPath=['../../outputs/1101/dtu/' pred_results]; % plyPath = '/mnt/cfs/algorithm/xiaofeng.wang/jeff/code/MVS/cascade-stereo/CasMVSNet/outputs/repo_model_aligncorners_ITGT' resultsPath=[plyPath '/eval_out/']; disp(resultsPath); mkdir(resultsPath); method_string='mvsnet'; light_string='l3'; % l3 is the setting with all lights on, l7 is randomly sampled between the 7 settings (index 0-6) representation_string='Points'; %mvs representation 'Points' or 'Surfaces' switch representation_string case 'Points' eval_string='_Eval_'; %results naming settings_string=''; end % get sets used in evaluation UsedSets=[1 4 9 10 11 12 13 15 23 24 29 32 33 34 48 49 62 75 77 110 114 118]; % UsedSets=[15]; dst=0.2; %Min dist between points when reducing parfor cIdx=1:length(UsedSets) %Data set number cSet = UsedSets(cIdx) %input data name DataInName=[plyPath sprintf('/%s%03d_%s%s.ply',lower(method_string),cSet,light_string,settings_string)] %results name EvalName=[resultsPath method_string eval_string num2str(cSet) '.mat'] %check if file is already computed if(~exist(EvalName,'file')) disp(DataInName); time=clock;time(4:5), drawnow tic Mesh = plyread(DataInName); Qdata=[Mesh.vertex.x Mesh.vertex.y Mesh.vertex.z]'; toc BaseEval=PointCompareMain(cSet,Qdata,dst,dataPath); disp('Saving results'), drawnow toc mySave(EvalName, BaseEval); toc % write obj-file of evaluation % BaseEval2Obj_web(BaseEval,method_string, resultsPath) % toc time=clock;time(4:5), drawnow BaseEval.MaxDist=20; %outlier threshold of 20 mm BaseEval.FilteredDstl=BaseEval.Dstl(BaseEval.StlAbovePlane); %use only points that are above the plane BaseEval.FilteredDstl=BaseEval.FilteredDstl(BaseEval.FilteredDstl=Low(1) & Qfrom(2,:)>=Low(2) & Qfrom(3,:)>=Low(3) &... Qfrom(1,:)=Low(1) & Qto(2,:)>=Low(2) & Qto(3,:)>=Low(3) &... Qto(1,:)3)] end ================================================ FILE: evaluations/dtu/PointCompareMain.m ================================================ function BaseEval=PointCompareMain(cSet,Qdata,dst,dataPath) % evaluation function the calculates the distantes from the reference data (stl) to the evalution points (Qdata) and the % distances from the evaluation points to the reference tic % reduce points 0.2 mm neighbourhood density Qdata=reducePts_haa(Qdata,dst); toc StlInName=[dataPath '/Points/stl/stl' sprintf('%03d',cSet) '_total.ply']; StlMesh = plyread(StlInName); %STL points already reduced 0.2 mm neighbourhood density Qstl=[StlMesh.vertex.x StlMesh.vertex.y StlMesh.vertex.z]'; %Load Mask (ObsMask) and Bounding box (BB) and Resolution (Res) Margin=10; MaskName=[dataPath '/ObsMask/ObsMask' num2str(cSet) '_' num2str(Margin) '.mat']; load(MaskName) MaxDist=60; disp('Computing Data 2 Stl distances') Ddata = MaxDistCP(Qstl,Qdata,BB,MaxDist); toc disp('Computing Stl 2 Data distances') Dstl=MaxDistCP(Qdata,Qstl,BB,MaxDist); disp('Distances computed') toc %use mask %From Get mask - inverted & modified. One=ones(1,size(Qdata,2)); Qv=(Qdata-BB(1,:)'*One)/Res+1; Qv=round(Qv); Midx1=find(Qv(1,:)>0 & Qv(1,:)<=size(ObsMask,1) & Qv(2,:)>0 & Qv(2,:)<=size(ObsMask,2) & Qv(3,:)>0 & Qv(3,:)<=size(ObsMask,3)); MidxA=sub2ind(size(ObsMask),Qv(1,Midx1),Qv(2,Midx1),Qv(3,Midx1)); Midx2=find(ObsMask(MidxA)); BaseEval.DataInMask(1:size(Qv,2))=false; BaseEval.DataInMask(Midx1(Midx2))=true; %If Data is within the mask BaseEval.cSet=cSet; BaseEval.Margin=Margin; %Margin of masks BaseEval.dst=dst; %Min dist between points when reducing BaseEval.Qdata=Qdata; %Input data points BaseEval.Ddata=Ddata; %distance from data to stl BaseEval.Qstl=Qstl; %Input stl points BaseEval.Dstl=Dstl; %Distance from the stl to data load([dataPath '/ObsMask/Plane' num2str(cSet)],'P') BaseEval.GroundPlane=P; % Plane used to destinguise which Stl points are 'used' BaseEval.StlAbovePlane=(P'*[Qstl;ones(1,size(Qstl,2))])>0; %Is stl above 'ground plane' BaseEval.Time=clock; %Time when computation is finished ================================================ FILE: evaluations/dtu/plyread.m ================================================ %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% function [Elements,varargout] = plyread(Path,Str) %PLYREAD Read a PLY 3D data file. % [DATA,COMMENTS] = PLYREAD(FILENAME) reads a version 1.0 PLY file % FILENAME and returns a structure DATA. The fields in this structure % are defined by the PLY header; each element type is a field and each % element property is a subfield. If the file contains any comments, % they are returned in a cell string array COMMENTS. % % [TRI,PTS] = PLYREAD(FILENAME,'tri') or % [TRI,PTS,DATA,COMMENTS] = PLYREAD(FILENAME,'tri') converts vertex % and face data into triangular connectivity and vertex arrays. The % mesh can then be displayed using the TRISURF command. % % Note: This function is slow for large mesh files (+50K faces), % especially when reading data with list type properties. % % Example: % [Tri,Pts] = PLYREAD('cow.ply','tri'); % trisurf(Tri,Pts(:,1),Pts(:,2),Pts(:,3)); % colormap(gray); axis equal; % % See also: PLYWRITE % Pascal Getreuer 2004 [fid,Msg] = fopen(Path,'rt'); % open file in read text mode if fid == -1, error(Msg); end Buf = fscanf(fid,'%s',1); if ~strcmp(Buf,'ply') fclose(fid); error('Not a PLY file.'); end %%% read header %%% Position = ftell(fid); Format = ''; NumComments = 0; Comments = {}; % for storing any file comments NumElements = 0; NumProperties = 0; Elements = []; % structure for holding the element data ElementCount = []; % number of each type of element in file PropertyTypes = []; % corresponding structure recording property types ElementNames = {}; % list of element names in the order they are stored in the file PropertyNames = []; % structure of lists of property names while 1 Buf = fgetl(fid); % read one line from file BufRem = Buf; Token = {}; Count = 0; while ~isempty(BufRem) % split line into tokens [tmp,BufRem] = strtok(BufRem); if ~isempty(tmp) Count = Count + 1; % count tokens Token{Count} = tmp; end end if Count % parse line switch lower(Token{1}) case 'format' % read data format if Count >= 2 Format = lower(Token{2}); if Count == 3 & ~strcmp(Token{3},'1.0') fclose(fid); error('Only PLY format version 1.0 supported.'); end end case 'comment' % read file comment NumComments = NumComments + 1; Comments{NumComments} = ''; for i = 2:Count Comments{NumComments} = [Comments{NumComments},Token{i},' ']; end case 'element' % element name if Count >= 3 if isfield(Elements,Token{2}) fclose(fid); error(['Duplicate element name, ''',Token{2},'''.']); end NumElements = NumElements + 1; NumProperties = 0; Elements = setfield(Elements,Token{2},[]); PropertyTypes = setfield(PropertyTypes,Token{2},[]); ElementNames{NumElements} = Token{2}; PropertyNames = setfield(PropertyNames,Token{2},{}); CurElement = Token{2}; ElementCount(NumElements) = str2double(Token{3}); if isnan(ElementCount(NumElements)) fclose(fid); error(['Bad element definition: ',Buf]); end else error(['Bad element definition: ',Buf]); end case 'property' % element property if ~isempty(CurElement) & Count >= 3 NumProperties = NumProperties + 1; eval(['tmp=isfield(Elements.',CurElement,',Token{Count});'],... 'fclose(fid);error([''Error reading property: '',Buf])'); if tmp error(['Duplicate property name, ''',CurElement,'.',Token{2},'''.']); end % add property subfield to Elements eval(['Elements.',CurElement,'.',Token{Count},'=[];'], ... 'fclose(fid);error([''Error reading property: '',Buf])'); % add property subfield to PropertyTypes and save type eval(['PropertyTypes.',CurElement,'.',Token{Count},'={Token{2:Count-1}};'], ... 'fclose(fid);error([''Error reading property: '',Buf])'); % record property name order eval(['PropertyNames.',CurElement,'{NumProperties}=Token{Count};'], ... 'fclose(fid);error([''Error reading property: '',Buf])'); else fclose(fid); if isempty(CurElement) error(['Property definition without element definition: ',Buf]); else error(['Bad property definition: ',Buf]); end end case 'end_header' % end of header, break from while loop break; end end end %%% set reading for specified data format %%% if isempty(Format) warning('Data format unspecified, assuming ASCII.'); Format = 'ascii'; end switch Format case 'ascii' Format = 0; case 'binary_little_endian' Format = 1; case 'binary_big_endian' Format = 2; otherwise fclose(fid); error(['Data format ''',Format,''' not supported.']); end if ~Format Buf = fscanf(fid,'%f'); % read the rest of the file as ASCII data BufOff = 1; else % reopen the file in read binary mode fclose(fid); if Format == 1 fid = fopen(Path,'r','ieee-le.l64'); % little endian else fid = fopen(Path,'r','ieee-be.l64'); % big endian end % find the end of the header again (using ftell on the old handle doesn't give the correct position) BufSize = 8192; Buf = [blanks(10),char(fread(fid,BufSize,'uchar')')]; i = []; tmp = -11; while isempty(i) i = findstr(Buf,['end_header',13,10]); % look for end_header + CR/LF i = [i,findstr(Buf,['end_header',10])]; % look for end_header + LF if isempty(i) tmp = tmp + BufSize; Buf = [Buf(BufSize+1:BufSize+10),char(fread(fid,BufSize,'uchar')')]; end end % seek to just after the line feed fseek(fid,i + tmp + 11 + (Buf(i + 10) == 13),-1); end %%% read element data %%% % PLY and MATLAB data types (for fread) PlyTypeNames = {'char','uchar','short','ushort','int','uint','float','double', ... 'char8','uchar8','short16','ushort16','int32','uint32','float32','double64'}; MatlabTypeNames = {'schar','uchar','int16','uint16','int32','uint32','single','double'}; SizeOf = [1,1,2,2,4,4,4,8]; % size in bytes of each type for i = 1:NumElements % get current element property information eval(['CurPropertyNames=PropertyNames.',ElementNames{i},';']); eval(['CurPropertyTypes=PropertyTypes.',ElementNames{i},';']); NumProperties = size(CurPropertyNames,2); % fprintf('Reading %s...\n',ElementNames{i}); if ~Format %%% read ASCII data %%% for j = 1:NumProperties Token = getfield(CurPropertyTypes,CurPropertyNames{j}); if strcmpi(Token{1},'list') Type(j) = 1; else Type(j) = 0; end end % parse buffer if ~any(Type) % no list types Data = reshape(Buf(BufOff:BufOff+ElementCount(i)*NumProperties-1),NumProperties,ElementCount(i))'; BufOff = BufOff + ElementCount(i)*NumProperties; else ListData = cell(NumProperties,1); for k = 1:NumProperties ListData{k} = cell(ElementCount(i),1); end % list type for j = 1:ElementCount(i) for k = 1:NumProperties if ~Type(k) Data(j,k) = Buf(BufOff); BufOff = BufOff + 1; else tmp = Buf(BufOff); ListData{k}{j} = Buf(BufOff+(1:tmp))'; BufOff = BufOff + tmp + 1; end end end end else %%% read binary data %%% % translate PLY data type names to MATLAB data type names ListFlag = 0; % = 1 if there is a list type SameFlag = 1; % = 1 if all types are the same for j = 1:NumProperties Token = getfield(CurPropertyTypes,CurPropertyNames{j}); if ~strcmp(Token{1},'list') % non-list type tmp = rem(strmatch(Token{1},PlyTypeNames,'exact')-1,8)+1; if ~isempty(tmp) TypeSize(j) = SizeOf(tmp); Type{j} = MatlabTypeNames{tmp}; TypeSize2(j) = 0; Type2{j} = ''; SameFlag = SameFlag & strcmp(Type{1},Type{j}); else fclose(fid); error(['Unknown property data type, ''',Token{1},''', in ', ... ElementNames{i},'.',CurPropertyNames{j},'.']); end else % list type if length(Token) == 3 ListFlag = 1; SameFlag = 0; tmp = rem(strmatch(Token{2},PlyTypeNames,'exact')-1,8)+1; tmp2 = rem(strmatch(Token{3},PlyTypeNames,'exact')-1,8)+1; if ~isempty(tmp) & ~isempty(tmp2) TypeSize(j) = SizeOf(tmp); Type{j} = MatlabTypeNames{tmp}; TypeSize2(j) = SizeOf(tmp2); Type2{j} = MatlabTypeNames{tmp2}; else fclose(fid); error(['Unknown property data type, ''list ',Token{2},' ',Token{3},''', in ', ... ElementNames{i},'.',CurPropertyNames{j},'.']); end else fclose(fid); error(['Invalid list syntax in ',ElementNames{i},'.',CurPropertyNames{j},'.']); end end end % read file if ~ListFlag if SameFlag % no list types, all the same type (fast) Data = fread(fid,[NumProperties,ElementCount(i)],Type{1})'; else % no list types, mixed type Data = zeros(ElementCount(i),NumProperties); for j = 1:ElementCount(i) for k = 1:NumProperties Data(j,k) = fread(fid,1,Type{k}); end end end else ListData = cell(NumProperties,1); for k = 1:NumProperties ListData{k} = cell(ElementCount(i),1); end if NumProperties == 1 BufSize = 512; SkipNum = 4; j = 0; % list type, one property (fast if lists are usually the same length) while j < ElementCount(i) Position = ftell(fid); % read in BufSize count values, assuming all counts = SkipNum [Buf,BufSize] = fread(fid,BufSize,Type{1},SkipNum*TypeSize2(1)); Miss = find(Buf ~= SkipNum); % find first count that is not SkipNum fseek(fid,Position + TypeSize(1),-1); % seek back to after first count if isempty(Miss) % all counts are SkipNum Buf = fread(fid,[SkipNum,BufSize],[int2str(SkipNum),'*',Type2{1}],TypeSize(1))'; fseek(fid,-TypeSize(1),0); % undo last skip for k = 1:BufSize ListData{1}{j+k} = Buf(k,:); end j = j + BufSize; BufSize = floor(1.5*BufSize); else if Miss(1) > 1 % some counts are SkipNum Buf2 = fread(fid,[SkipNum,Miss(1)-1],[int2str(SkipNum),'*',Type2{1}],TypeSize(1))'; for k = 1:Miss(1)-1 ListData{1}{j+k} = Buf2(k,:); end j = j + k; end % read in the list with the missed count SkipNum = Buf(Miss(1)); j = j + 1; ListData{1}{j} = fread(fid,[1,SkipNum],Type2{1}); BufSize = ceil(0.6*BufSize); end end else % list type(s), multiple properties (slow) Data = zeros(ElementCount(i),NumProperties); for j = 1:ElementCount(i) for k = 1:NumProperties if isempty(Type2{k}) Data(j,k) = fread(fid,1,Type{k}); else tmp = fread(fid,1,Type{k}); ListData{k}{j} = fread(fid,[1,tmp],Type2{k}); end end end end end end % put data into Elements structure for k = 1:NumProperties if (~Format & ~Type(k)) | (Format & isempty(Type2{k})) eval(['Elements.',ElementNames{i},'.',CurPropertyNames{k},'=Data(:,k);']); else eval(['Elements.',ElementNames{i},'.',CurPropertyNames{k},'=ListData{k};']); end end end clear Data ListData; fclose(fid); if (nargin > 1 & strcmpi(Str,'Tri')) | nargout > 2 % find vertex element field Name = {'vertex','Vertex','point','Point','pts','Pts'}; Names = []; for i = 1:length(Name) if any(strcmp(ElementNames,Name{i})) Names = getfield(PropertyNames,Name{i}); Name = Name{i}; break; end end if any(strcmp(Names,'x')) & any(strcmp(Names,'y')) & any(strcmp(Names,'z')) eval(['varargout{1}=[Elements.',Name,'.x,Elements.',Name,'.y,Elements.',Name,'.z];']); else varargout{1} = zeros(1,3); end varargout{2} = Elements; varargout{3} = Comments; Elements = []; % find face element field Name = {'face','Face','poly','Poly','tri','Tri'}; Names = []; for i = 1:length(Name) if any(strcmp(ElementNames,Name{i})) Names = getfield(PropertyNames,Name{i}); Name = Name{i}; break; end end if ~isempty(Names) % find vertex indices property subfield PropertyName = {'vertex_indices','vertex_indexes','vertex_index','indices','indexes'}; for i = 1:length(PropertyName) if any(strcmp(Names,PropertyName{i})) PropertyName = PropertyName{i}; break; end end if ~iscell(PropertyName) % convert face index lists to triangular connectivity eval(['FaceIndices=varargout{2}.',Name,'.',PropertyName,';']); N = length(FaceIndices); Elements = zeros(N*2,3); Extra = 0; for k = 1:N Elements(k,:) = FaceIndices{k}(1:3); for j = 4:length(FaceIndices{k}) Extra = Extra + 1; Elements(N + Extra,:) = [Elements(k,[1,j-1]),FaceIndices{k}(j)]; end end Elements = Elements(1:N+Extra,:) + 1; end end else varargout{1} = Comments; end ================================================ FILE: evaluations/dtu/reducePts_haa.m ================================================ function [ptsOut,indexSet] = reducePts_haa(pts, dst) %Reduces a point set, pts, in a stochastic manner, such that the minimum sdistance % between points is 'dst'. Writen by abd, edited by haa, then by raje nPoints=size(pts,2); indexSet=true(nPoints,1); RandOrd=randperm(nPoints); %tic NS = KDTreeSearcher(pts'); %toc % search the KNTree for close neighbours in a chunk-wise fashion to save memory if point cloud is really big Chunks=1:min(4e6,nPoints-1):nPoints; Chunks(end)=nPoints; for cChunk=1:(length(Chunks)-1) Range=Chunks(cChunk):Chunks(cChunk+1); idx = rangesearch(NS,pts(:,RandOrd(Range))',dst); for i = 1:size(idx,1) id =RandOrd(i-1+Chunks(cChunk)); if (indexSet(id)) indexSet(idx{i}) = 0; indexSet(id) = 1; end end end ptsOut = pts(:,indexSet); disp(['downsample factor: ' num2str(nPoints/sum(indexSet))]); ================================================ FILE: lists/blendedmvs/train.txt ================================================ 5c1f33f1d33e1f2e4aa6dda4 5bfe5ae0fe0ea555e6a969ca 5bff3c5cfe0ea555e6bcbf3a 58eaf1513353456af3a1682a 5bfc9d5aec61ca1dd69132a2 5bf18642c50e6f7f8bdbd492 5bf26cbbd43923194854b270 5bf17c0fd439231948355385 5be3ae47f44e235bdbbc9771 5be3a5fb8cfdd56947f6b67c 5bbb6eb2ea1cfa39f1af7e0c 5ba75d79d76ffa2c86cf2f05 5bb7a08aea1cfa39f1a947ab 5b864d850d072a699b32f4ae 5b6eff8b67b396324c5b2672 5b6e716d67b396324c2d77cb 5b69cc0cb44b61786eb959bf 5b62647143840965efc0dbde 5b60fa0c764f146feef84df0 5b558a928bbfb62204e77ba2 5b271079e0878c3816dacca4 5b08286b2775267d5b0634ba 5afacb69ab00705d0cefdd5b 5af28cea59bc705737003253 5af02e904c8216544b4ab5a2 5aa515e613d42d091d29d300 5c34529873a8df509ae57b58 5c34300a73a8df509add216d 5c1af2e2bee9a723c963d019 5c1892f726173c3a09ea9aeb 5c0d13b795da9479e12e2ee9 5c062d84a96e33018ff6f0a6 5bfd0f32ec61ca1dd69dc77b 5bf21799d43923194842c001 5bf3a82cd439231948877aed 5bf03590d4392319481971dc 5beb6e66abd34c35e18e66b9 5be883a4f98cee15019d5b83 5be47bf9b18881428d8fbc1d 5bcf979a6d5f586b95c258cd 5bce7ac9ca24970bce4934b6 5bb8a49aea1cfa39f1aa7f75 5b78e57afc8fcf6781d0c3ba 5b21e18c58e2823a67a10dd8 5b22269758e2823a67a3bd03 5b192eb2170cf166458ff886 5ae2e9c5fe405c5076abc6b2 5adc6bd52430a05ecb2ffb85 5ab8b8e029f5351f7f2ccf59 5abc2506b53b042ead637d86 5ab85f1dac4291329b17cb50 5a969eea91dfc339a9a3ad2c 5a8aa0fab18050187cbe060e 5a7d3db14989e929563eb153 5a69c47d0d5d0a7f3b2e9752 5a618c72784780334bc1972d 5a6464143d809f1d8208c43c 5a588a8193ac3d233f77fbca 5a57542f333d180827dfc132 5a572fd9fc597b0478a81d14 5a563183425d0f5186314855 5a4a38dad38c8a075495b5d2 5a48d4b2c7dab83a7d7b9851 5a489fb1c7dab83a7d7b1070 5a48ba95c7dab83a7d7b44ed 5a3ca9cb270f0e3f14d0eddb 5a3cb4e4270f0e3f14d12f43 5a3f4aba5889373fbbc5d3b5 5a0271884e62597cdee0d0eb 59e864b2a9e91f2c5529325f 599aa591d5b41f366fed0d58 59350ca084b7f26bf5ce6eb8 59338e76772c3e6384afbb15 5c20ca3a0843bc542d94e3e2 5c1dbf200843bc542d8ef8c4 5c1b1500bee9a723c96c3e78 5bea87f4abd34c35e1860ab5 5c2b3ed5e611832e8aed46bf 57f8d9bbe73f6760f10e916a 5bf7d63575c26f32dbf7413b 5be4ab93870d330ff2dce134 5bd43b4ba6b28b1ee86b92dd 5bccd6beca24970bce448134 5bc5f0e896b66a2cd8f9bd36 5b908d3dc6ab78485f3d24a9 5b2c67b5e0878c381608b8d8 5b4933abf2b5f44e95de482a 5b3b353d8d46a939f93524b9 5acf8ca0f3d8a750097e4b15 5ab8713ba3799a1d138bd69a 5aa235f64a17b335eeaf9609 5aa0f9d7a9efce63548c69a1 5a8315f624b8e938486e0bd8 5a48c4e9c7dab83a7d7b5cc7 59ecfd02e225f6492d20fcc9 59f87d0bfa6280566fb38c9a 59f363a8b45be22330016cad 59f70ab1e5c5d366af29bf3e 59e75a2ca9e91f2c5526005d 5947719bf1b45630bd096665 5947b62af1b45630bd0c2a02 59056e6760bb961de55f3501 58f7f7299f5b5647873cb110 58cf4771d0f5fb221defe6da 58d36897f387231e6c929903 58c4bb4f4a69c55606122be4 ================================================ FILE: lists/blendedmvs/val.txt ================================================ 5b7a3890fc8fcf6781e2593a 5c189f2326173c3a09ed7ef3 5b950c71608de421b1e7318f 5a6400933d809f1d8200af15 59d2657f82ca7774b1ec081d 5ba19a8a360c7c30c1c169df 59817e4a1bd4b175e7038d19 ================================================ FILE: lists/dtu/test.txt ================================================ scan1 scan4 scan9 scan10 scan11 scan12 scan13 scan15 scan23 scan24 scan29 scan32 scan33 scan34 scan48 scan49 scan62 scan75 scan77 scan110 scan114 scan118 ================================================ FILE: lists/dtu/train.txt ================================================ scan2 scan6 scan7 scan8 scan14 scan16 scan18 scan19 scan20 scan22 scan30 scan31 scan36 scan39 scan41 scan42 scan44 scan45 scan46 scan47 scan50 scan51 scan52 scan53 scan55 scan57 scan58 scan60 scan61 scan63 scan64 scan65 scan68 scan69 scan70 scan71 scan72 scan74 scan76 scan83 scan84 scan85 scan87 scan88 scan89 scan90 scan91 scan92 scan93 scan94 scan95 scan96 scan97 scan98 scan99 scan100 scan101 scan102 scan103 scan104 scan105 scan107 scan108 scan109 scan111 scan112 scan113 scan115 scan116 scan119 scan120 scan121 scan122 scan123 scan124 scan125 scan126 scan127 scan128 ================================================ FILE: lists/dtu/trainval.txt ================================================ scan2 scan6 scan7 scan8 scan14 scan16 scan18 scan19 scan20 scan22 scan30 scan31 scan36 scan39 scan41 scan42 scan44 scan45 scan46 scan47 scan50 scan51 scan52 scan53 scan55 scan57 scan58 scan60 scan61 scan63 scan64 scan65 scan68 scan69 scan70 scan71 scan72 scan74 scan76 scan83 scan84 scan85 scan87 scan88 scan89 scan90 scan91 scan92 scan93 scan94 scan95 scan96 scan97 scan98 scan99 scan100 scan101 scan102 scan103 scan104 scan105 scan107 scan108 scan109 scan111 scan112 scan113 scan115 scan116 scan119 scan120 scan121 scan122 scan123 scan124 scan125 scan126 scan127 scan128 scan3 scan5 scan17 scan21 scan28 scan35 scan37 scan38 scan40 scan43 scan56 scan59 scan66 scan67 scan82 scan86 scan106 scan117 ================================================ FILE: lists/dtu/val.txt ================================================ scan3 scan5 scan17 scan21 scan28 scan35 scan37 scan38 scan40 scan43 scan56 scan59 scan66 scan67 scan82 scan86 scan106 scan117 ================================================ FILE: models/MVS4Net.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F import numpy as np from models.mvs4net_utils import stagenet, reg2d, reg3d, FPN4, FPN4_convnext, FPN4_convnext4, PosEncSine, PosEncLearned, \ init_range, schedule_range, init_inverse_range, schedule_inverse_range, sinkhorn, mono_depth_decoder, ASFF class MVS4net(nn.Module): def __init__(self, arch_mode="fpn", reg_net='reg2d', num_stage=4, fpn_base_channel=8, reg_channel=8, stage_splits=[8,8,4,4], depth_interals_ratio=[0.5,0.5,0.5,1], group_cor=False, group_cor_dim=[8,8,8,8], inverse_depth=False, agg_type='ConvBnReLU3D', dcn=False, pos_enc=0, mono=False, asff=False, attn_temp=2, attn_fuse_d=True, vis_ETA=False, vis_mono=False ): # pos_enc: 0 no pos enc; 1 depth sine; 2 learnable pos enc super(MVS4net, self).__init__() self.arch_mode = arch_mode self.num_stage = num_stage self.depth_interals_ratio = depth_interals_ratio self.group_cor = group_cor self.group_cor_dim = group_cor_dim self.inverse_depth = inverse_depth self.asff = asff if self.asff: self.asff = nn.ModuleList([ASFF(i) for i in range(num_stage)]) self.attn_ob = nn.ModuleList() if arch_mode == "fpn": self.feature = FPN4(base_channels=fpn_base_channel, gn=False, dcn=dcn) self.vis_mono = vis_mono self.stagenet = stagenet(inverse_depth, mono, attn_fuse_d, vis_ETA, attn_temp) self.stage_splits = stage_splits self.reg = nn.ModuleList() self.pos_enc = pos_enc self.pos_enc_func = nn.ModuleList() self.mono = mono if self.mono: self.mono_depth_decoder = mono_depth_decoder() if reg_net == 'reg3d': self.down_size = [3,3,2,2] for idx in range(num_stage): if self.group_cor: in_dim = group_cor_dim[idx] else: in_dim = self.feature.out_channels[idx] if reg_net == 'reg2d': self.reg.append(reg2d(input_channel=in_dim, base_channel=reg_channel, conv_name=agg_type)) elif reg_net == 'reg3d': self.reg.append(reg3d(in_channels=in_dim, base_channels=reg_channel, down_size=self.down_size[idx])) def forward(self, imgs, proj_matrices, depth_values, filename=None): depth_min = depth_values[:, 0].cpu().numpy() depth_max = depth_values[:, -1].cpu().numpy() depth_interval = (depth_max - depth_min) / depth_values.size(1) # step 1. feature extraction features = [] for nview_idx in range(len(imgs)): #imgs shape (B, N, C, H, W) img = imgs[nview_idx] features.append(self.feature(img)) if self.vis_mono: scan_name = filename[0].split('/')[0] image_name = filename[0].split('/')[2][:-2] save_fn = './debug_figs/vis_mono/feat_{}'.format(scan_name+'_'+image_name) feat_ = features[-1]['stage4'].detach().cpu().numpy() np.save(save_fn, feat_) # step 2. iter (multi-scale) outputs = {} for stage_idx in range(self.num_stage): if not self.asff: features_stage = [feat["stage{}".format(stage_idx+1)] for feat in features] else: features_stage = [self.asff[stage_idx](feat['stage1'],feat['stage2'],feat['stage3'],feat['stage4']) for feat in features] proj_matrices_stage = proj_matrices["stage{}".format(stage_idx + 1)] B,C,H,W = features[0]['stage{}'.format(stage_idx+1)].shape # init range if stage_idx == 0: if self.inverse_depth: depth_hypo = init_inverse_range(depth_values, self.stage_splits[stage_idx], img[0].device, img[0].dtype, H, W) else: depth_hypo = init_range(depth_values, self.stage_splits[stage_idx], img[0].device, img[0].dtype, H, W) else: if self.inverse_depth: depth_hypo = schedule_inverse_range(outputs_stage['inverse_min_depth'].detach(), outputs_stage['inverse_max_depth'].detach(), self.stage_splits[stage_idx], H, W) # B D H W else: depth_hypo = schedule_range(outputs_stage['depth'].detach(), self.stage_splits[stage_idx], self.depth_interals_ratio[stage_idx] * depth_interval, H, W) outputs_stage = self.stagenet(features_stage, proj_matrices_stage, depth_hypo=depth_hypo, regnet=self.reg[stage_idx], stage_idx=stage_idx, group_cor=self.group_cor, group_cor_dim=self.group_cor_dim[stage_idx], split_itv=self.depth_interals_ratio[stage_idx], fn=filename) outputs["stage{}".format(stage_idx + 1)] = outputs_stage outputs.update(outputs_stage) if self.mono and self.training: # if self.mono: outputs = self.mono_depth_decoder(outputs, depth_values[:,0], depth_values[:,1]) return outputs def MVS4net_loss(inputs, depth_gt_ms, mask_ms, **kwargs): stage_lw = kwargs.get("stage_lw", [1,1,1,1]) l1ot_lw = kwargs.get("l1ot_lw", [0,1]) inverse = kwargs.get("inverse_depth", False) ot_iter = kwargs.get("ot_iter", 3) ot_eps = kwargs.get("ot_eps", 1) ot_continous = kwargs.get("ot_continous", False) mono = kwargs.get("mono", False) total_loss = torch.tensor(0.0, dtype=torch.float32, device=mask_ms["stage1"].device, requires_grad=False) stage_ot_loss = [] stage_l1_loss = [] range_err_ratio = [] for stage_idx, (stage_inputs, stage_key) in enumerate([(inputs[k], k) for k in inputs.keys() if "stage" in k]): depth_pred = stage_inputs['depth'] hypo_depth = stage_inputs['hypo_depth'] attn_weight = stage_inputs['attn_weight'] B,H,W = depth_pred.shape D = hypo_depth.shape[1] mask = mask_ms[stage_key] mask = mask > 0.5 depth_gt = depth_gt_ms[stage_key] if mono and stage_idx!=0: this_stage_l1_loss = F.l1_loss(stage_inputs['mono_depth'][mask], depth_gt[mask], reduction='mean') else: this_stage_l1_loss = torch.tensor(0.0, dtype=torch.float32, device=mask_ms["stage1"].device, requires_grad=False) # mask range if inverse: depth_itv = (1/hypo_depth[:,2,:,:]-1/hypo_depth[:,1,:,:]).abs() # B H W mask_out_of_range = ((1/hypo_depth - 1/depth_gt.unsqueeze(1)).abs() <= depth_itv.unsqueeze(1)).sum(1) == 0 # B H W else: depth_itv = (hypo_depth[:,2,:,:]-hypo_depth[:,1,:,:]).abs() # B H W mask_out_of_range = ((hypo_depth - depth_gt.unsqueeze(1)).abs() <= depth_itv.unsqueeze(1)).sum(1) == 0 # B H W range_err_ratio.append(mask_out_of_range[mask].float().mean()) this_stage_ot_loss = sinkhorn(depth_gt, hypo_depth, attn_weight, mask, iters=ot_iter, eps=ot_eps, continuous=ot_continous)[1] stage_l1_loss.append(this_stage_l1_loss) stage_ot_loss.append(this_stage_ot_loss) total_loss = total_loss + stage_lw[stage_idx] * (l1ot_lw[0] * this_stage_l1_loss + l1ot_lw[1] * this_stage_ot_loss) return total_loss, stage_l1_loss, stage_ot_loss, range_err_ratio def Blend_loss(inputs, depth_gt_ms, mask_ms, **kwargs): stage_lw = kwargs.get("stage_lw", [1,1,1,1]) l1ot_lw = kwargs.get("l1ot_lw", [0,1]) inverse = kwargs.get("inverse_depth", False) ot_iter = kwargs.get("ot_iter", 3) ot_eps = kwargs.get("ot_eps", 1) ot_continous = kwargs.get("ot_continous", False) depth_max = kwargs.get("depth_max", 100) depth_min = kwargs.get("depth_min", 1) mono = kwargs.get("mono", False) total_loss = torch.tensor(0.0, dtype=torch.float32, device=mask_ms["stage1"].device, requires_grad=False) stage_ot_loss = [] stage_l1_loss = [] range_err_ratio = [] for stage_idx, (stage_inputs, stage_key) in enumerate([(inputs[k], k) for k in inputs.keys() if "stage" in k]): depth_pred = stage_inputs['depth'] hypo_depth = stage_inputs['hypo_depth'] attn_weight = stage_inputs['attn_weight'] B,H,W = depth_pred.shape mask = mask_ms[stage_key] mask = mask > 0.5 depth_gt = depth_gt_ms[stage_key] depth_pred_norm = depth_pred * 128 / (depth_max - depth_min)[:,None,None] # B H W depth_gt_norm = depth_gt * 128 / (depth_max - depth_min)[:,None,None] # B H W if mono and stage_idx!=0: this_stage_l1_loss = F.l1_loss(stage_inputs['mono_depth'][mask], depth_gt[mask], reduction='mean') else: this_stage_l1_loss = torch.tensor(0.0, dtype=torch.float32, device=mask_ms["stage1"].device, requires_grad=False) if inverse: depth_itv = (1/hypo_depth[:,2,:,:]-1/hypo_depth[:,1,:,:]).abs() # B H W mask_out_of_range = ((1/hypo_depth - 1/depth_gt.unsqueeze(1)).abs() <= depth_itv.unsqueeze(1)).sum(1) == 0 # B H W else: depth_itv = (hypo_depth[:,2,:,:]-hypo_depth[:,1,:,:]).abs() # B H W mask_out_of_range = ((hypo_depth - depth_gt.unsqueeze(1)).abs() <= depth_itv.unsqueeze(1)).sum(1) == 0 # B H W range_err_ratio.append(mask_out_of_range[mask].float().mean()) this_stage_ot_loss = sinkhorn(depth_gt, hypo_depth, attn_weight, mask, iters=ot_iter, eps=ot_eps, continuous=ot_continous)[1] stage_l1_loss.append(this_stage_l1_loss) stage_ot_loss.append(this_stage_ot_loss) total_loss = total_loss + stage_lw[stage_idx] * (l1ot_lw[0] * this_stage_l1_loss + l1ot_lw[1] * this_stage_ot_loss) abs_err = torch.abs(depth_pred_norm[mask] - depth_gt_norm[mask]) epe = abs_err.mean() err3 = (abs_err<=3).float().mean()*100 err1= (abs_err<=1).float().mean()*100 return total_loss, stage_l1_loss, stage_ot_loss, range_err_ratio, epe, err3, err1 ================================================ FILE: models/__init__.py ================================================ from models.MVS4Net import MVS4net, MVS4net_loss, Blend_loss ================================================ FILE: models/module.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F import time import sys import seaborn as sns import numpy as np import matplotlib.pyplot as plt sys.path.append("..") from utils import local_pcd from modules.deform_conv import DeformConvPack def init_bn(module): if module.weight is not None: nn.init.ones_(module.weight) if module.bias is not None: nn.init.zeros_(module.bias) return def init_uniform(module, init_method): if module.weight is not None: if init_method == "kaiming": nn.init.kaiming_uniform_(module.weight) elif init_method == "xavier": nn.init.xavier_uniform_(module.weight) return class Conv2d(nn.Module): """Applies a 2D convolution (optionally with batch normalization and relu activation) over an input signal composed of several input planes. Attributes: conv (nn.Module): convolution module bn (nn.Module): batch normalization module relu (bool): whether to activate by relu Notes: Default momentum for batch normalization is set to be 0.01, """ def __init__(self, in_channels, out_channels, kernel_size, stride=1, relu=True, bn=True, bn_momentum=0.1, init_method="xavier", **kwargs): super(Conv2d, self).__init__() self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, bias=(not bn), **kwargs) self.kernel_size = kernel_size self.stride = stride self.bn = nn.BatchNorm2d(out_channels, momentum=bn_momentum) if bn else None self.relu = relu # assert init_method in ["kaiming", "xavier"] # self.init_weights(init_method) def forward(self, x): x = self.conv(x) if self.bn is not None: x = self.bn(x) if self.relu: x = F.relu(x, inplace=True) return x def init_weights(self, init_method): """default initialization""" init_uniform(self.conv, init_method) if self.bn is not None: init_bn(self.bn) class DCNConv2d(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1, relu=True, bn=True, bn_momentum=0.1, init_method="xavier", **kwargs): super(DCNConv2d, self).__init__() self.conv = DeformConvPack(in_channels, out_channels, kernel_size, stride=stride, padding=1, bias=(not bn), im2col_step=16) self.kernel_size = kernel_size self.stride = stride self.bn = nn.BatchNorm2d(out_channels, momentum=bn_momentum) if bn else None self.relu = relu # assert init_method in ["kaiming", "xavier"] # self.init_weights(init_method) def forward(self, x): x = self.conv(x) if self.bn is not None: x = self.bn(x) if self.relu: x = F.relu(x, inplace=True) return x def init_weights(self, init_method): """default initialization""" init_uniform(self.conv, init_method) if self.bn is not None: init_bn(self.bn) class Deconv2d(nn.Module): """Applies a 2D deconvolution (optionally with batch normalization and relu activation) over an input signal composed of several input planes. Attributes: conv (nn.Module): convolution module bn (nn.Module): batch normalization module relu (bool): whether to activate by relu Notes: Default momentum for batch normalization is set to be 0.01, """ def __init__(self, in_channels, out_channels, kernel_size, stride=1, relu=True, bn=True, bn_momentum=0.1, init_method="xavier", **kwargs): super(Deconv2d, self).__init__() self.out_channels = out_channels assert stride in [1, 2] self.stride = stride self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride, bias=(not bn), **kwargs) self.bn = nn.BatchNorm2d(out_channels, momentum=bn_momentum) if bn else None self.relu = relu # assert init_method in ["kaiming", "xavier"] # self.init_weights(init_method) def forward(self, x): y = self.conv(x) if self.stride == 2: h, w = list(x.size())[2:] y = y[:, :, :2 * h, :2 * w].contiguous() if self.bn is not None: x = self.bn(y) if self.relu: x = F.relu(x, inplace=True) return x def init_weights(self, init_method): """default initialization""" init_uniform(self.conv, init_method) if self.bn is not None: init_bn(self.bn) class Conv3d(nn.Module): """Applies a 3D convolution (optionally with batch normalization and relu activation) over an input signal composed of several input planes. Attributes: conv (nn.Module): convolution module bn (nn.Module): batch normalization module relu (bool): whether to activate by relu Notes: Default momentum for batch normalization is set to be 0.01, """ def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, relu=True, bn=True, bn_momentum=0.1, init_method="xavier", **kwargs): super(Conv3d, self).__init__() self.out_channels = out_channels self.kernel_size = kernel_size assert stride in [1, 2] self.stride = stride self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, bias=(not bn), **kwargs) self.bn = nn.BatchNorm3d(out_channels, momentum=bn_momentum) if bn else None self.relu = relu # assert init_method in ["kaiming", "xavier"] # self.init_weights(init_method) def forward(self, x): x = self.conv(x) if self.bn is not None: x = self.bn(x) if self.relu: x = F.relu(x, inplace=True) return x def init_weights(self, init_method): """default initialization""" init_uniform(self.conv, init_method) if self.bn is not None: init_bn(self.bn) class PConv3d(nn.Module): def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, relu=True, bn=True, bn_momentum=0.1, padding=1, init_method="xavier", **kwargs): super(PConv3d, self).__init__() self.out_channels = out_channels self.kernel_size_xy = (1, kernel_size, kernel_size) self.kernel_size_d = (kernel_size, 1, 1) assert stride in [1, 2] self.stride_xy = (1, stride, stride) self.stride_d = (stride, 1, 1) self.padding_xy = (0, padding, padding) self.padding_d = (padding, 0, 0) self.convxy = nn.Conv3d(in_channels, in_channels, self.kernel_size_xy, stride=self.stride_xy, padding=self.padding_xy, bias=(not bn), **kwargs) self.convd = nn.Conv3d(in_channels, out_channels, self.kernel_size_d, stride=self.stride_d, padding=self.padding_d, bias=(not bn), **kwargs) self.bn = nn.BatchNorm3d(out_channels, momentum=bn_momentum) if bn else None self.relu = relu # assert init_method in ["kaiming", "xavier"] # self.init_weights(init_method) def forward(self, x): x = self.convxy(x) x = self.convd(x) if self.bn is not None: x = self.bn(x) if self.relu: x = F.relu(x, inplace=True) return x def init_weights(self, init_method): """default initialization""" init_uniform(self.convxy, init_method) init_uniform(self.convd, init_method) if self.bn is not None: init_bn(self.bn) class Deconv3d(nn.Module): """Applies a 3D deconvolution (optionally with batch normalization and relu activation) over an input signal composed of several input planes. Attributes: conv (nn.Module): convolution module bn (nn.Module): batch normalization module relu (bool): whether to activate by relu Notes: Default momentum for batch normalization is set to be 0.01, """ def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, relu=True, bn=True, bn_momentum=0.1, init_method="xavier", **kwargs): super(Deconv3d, self).__init__() self.out_channels = out_channels assert stride in [1, 2] self.stride = stride self.conv = nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride=stride, bias=(not bn), **kwargs) self.bn = nn.BatchNorm3d(out_channels, momentum=bn_momentum) if bn else None self.relu = relu # assert init_method in ["kaiming", "xavier"] # self.init_weights(init_method) def forward(self, x): y = self.conv(x) if self.bn is not None: x = self.bn(y) if self.relu: x = F.relu(x, inplace=True) return x def init_weights(self, init_method): """default initialization""" init_uniform(self.conv, init_method) if self.bn is not None: init_bn(self.bn) class PDeconv3d(nn.Module): def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1,output_padding=1, relu=True, bn=True, bn_momentum=0.1, init_method="xavier", **kwargs): super(PDeconv3d, self).__init__() self.out_channels = out_channels assert stride in [1, 2] self.stride = stride self.kernel_size_xy = (1, kernel_size,kernel_size) self.kernel_size_d = (kernel_size, 1,1) self.stride_xy = (1, stride, stride) self.stride_d = (stride, 1, 1) self.padding_xy = (0, padding, padding) self.padding_d = (padding, 0, 0) self.outpadding_xy = (0, output_padding, output_padding) self.outpadding_d = (output_padding, 0, 0) self.convxy = nn.ConvTranspose3d(in_channels, in_channels, self.kernel_size_xy, stride=self.stride_xy, padding=self.padding_xy, output_padding=self.outpadding_xy, bias=(not bn)) self.convd = nn.ConvTranspose3d(in_channels, out_channels, self.kernel_size_d, stride=self.stride_d, padding=self.padding_d, output_padding=self.outpadding_d, bias=(not bn)) self.bn = nn.BatchNorm3d(out_channels, momentum=bn_momentum) if bn else None self.relu = relu # assert init_method in ["kaiming", "xavier"] # self.init_weights(init_method) def forward(self, x): x = self.convxy(x) y = self.convd(x) if self.bn is not None: x = self.bn(y) if self.relu: x = F.relu(x, inplace=True) return x def init_weights(self, init_method): """default initialization""" init_uniform(self.convxy, init_method) init_uniform(self.convd, init_method) if self.bn is not None: init_bn(self.bn) class ConvBnReLU(nn.Module): def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1): super(ConvBnReLU, self).__init__() self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False) self.bn = nn.BatchNorm2d(out_channels) def forward(self, x): return F.relu(self.bn(self.conv(x)), inplace=True) class ConvBn(nn.Module): def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1): super(ConvBn, self).__init__() self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False) self.bn = nn.BatchNorm2d(out_channels) def forward(self, x): return self.bn(self.conv(x)) class ConvBnReLU3D(nn.Module): def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1): super(ConvBnReLU3D, self).__init__() self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False) self.bn = nn.BatchNorm3d(out_channels) def forward(self, x): return F.relu(self.bn(self.conv(x)), inplace=True) class ConvBn3D(nn.Module): def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1): super(ConvBn3D, self).__init__() self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False) self.bn = nn.BatchNorm3d(out_channels) def forward(self, x): return self.bn(self.conv(x)) class BasicBlock(nn.Module): def __init__(self, in_channels, out_channels, stride, downsample=None): super(BasicBlock, self).__init__() self.conv1 = ConvBnReLU(in_channels, out_channels, kernel_size=3, stride=stride, pad=1) self.conv2 = ConvBn(out_channels, out_channels, kernel_size=3, stride=1, pad=1) self.downsample = downsample self.stride = stride def forward(self, x): out = self.conv1(x) out = self.conv2(out) if self.downsample is not None: x = self.downsample(x) out += x return out class Hourglass3d(nn.Module): def __init__(self, channels): super(Hourglass3d, self).__init__() self.conv1a = ConvBnReLU3D(channels, channels * 2, kernel_size=3, stride=2, pad=1) self.conv1b = ConvBnReLU3D(channels * 2, channels * 2, kernel_size=3, stride=1, pad=1) self.conv2a = ConvBnReLU3D(channels * 2, channels * 4, kernel_size=3, stride=2, pad=1) self.conv2b = ConvBnReLU3D(channels * 4, channels * 4, kernel_size=3, stride=1, pad=1) self.dconv2 = nn.Sequential( nn.ConvTranspose3d(channels * 4, channels * 2, kernel_size=3, padding=1, output_padding=1, stride=2, bias=False), nn.BatchNorm3d(channels * 2)) self.dconv1 = nn.Sequential( nn.ConvTranspose3d(channels * 2, channels, kernel_size=3, padding=1, output_padding=1, stride=2, bias=False), nn.BatchNorm3d(channels)) self.redir1 = ConvBn3D(channels, channels, kernel_size=1, stride=1, pad=0) self.redir2 = ConvBn3D(channels * 2, channels * 2, kernel_size=1, stride=1, pad=0) def forward(self, x): conv1 = self.conv1b(self.conv1a(x)) conv2 = self.conv2b(self.conv2a(conv1)) dconv2 = F.relu(self.dconv2(conv2) + self.redir2(conv1), inplace=True) dconv1 = F.relu(self.dconv1(dconv2) + self.redir1(x), inplace=True) return dconv1 def homo_warping(src_fea, src_proj, ref_proj, depth_values, align_corners=False): # src_fea: [B, C, H, W] # src_proj: [B, 4, 4] # ref_proj: [B, 4, 4] # depth_values: [B, Ndepth] o [B, Ndepth, H, W] # out: [B, C, Ndepth, H, W] batch, channels = src_fea.shape[0], src_fea.shape[1] num_depth = depth_values.shape[1] height, width = src_fea.shape[2], src_fea.shape[3] with torch.no_grad(): proj = torch.matmul(src_proj, torch.inverse(ref_proj)) rot = proj[:, :3, :3] # [B,3,3] trans = proj[:, :3, 3:4] # [B,3,1] y, x = torch.meshgrid([torch.arange(0, height, dtype=torch.float32, device=src_fea.device), torch.arange(0, width, dtype=torch.float32, device=src_fea.device)]) y, x = y.contiguous(), x.contiguous() y, x = y.view(height * width), x.view(height * width) xyz = torch.stack((x, y, torch.ones_like(x))) # [3, H*W] xyz = torch.unsqueeze(xyz, 0).repeat(batch, 1, 1) # [B, 3, H*W] rot_xyz = torch.matmul(rot, xyz) # [B, 3, H*W] rot_depth_xyz = rot_xyz.unsqueeze(2).repeat(1, 1, num_depth, 1) * depth_values.view(batch, 1, num_depth, -1) # [B, 3, Ndepth, H*W] proj_xyz = rot_depth_xyz + trans.view(batch, 3, 1, 1) # [B, 3, Ndepth, H*W] proj_xy = proj_xyz[:, :2, :, :] / proj_xyz[:, 2:3, :, :] # [B, 2, Ndepth, H*W] proj_x_normalized = proj_xy[:, 0, :, :] / ((width - 1) / 2) - 1 proj_y_normalized = proj_xy[:, 1, :, :] / ((height - 1) / 2) - 1 proj_xy = torch.stack((proj_x_normalized, proj_y_normalized), dim=3) # [B, Ndepth, H*W, 2] grid = proj_xy warped_src_fea = F.grid_sample(src_fea, grid.view(batch, num_depth * height, width, 2), mode='bilinear', padding_mode='zeros', align_corners=align_corners) warped_src_fea = warped_src_fea.view(batch, channels, num_depth, height, width) return warped_src_fea class DeConv2dFuse(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, relu=True, bn=True, bn_momentum=0.1): super(DeConv2dFuse, self).__init__() self.deconv = Deconv2d(in_channels, out_channels, kernel_size, stride=2, padding=1, output_padding=1, bn=True, relu=relu, bn_momentum=bn_momentum) self.conv = Conv2d(2*out_channels, out_channels, kernel_size, stride=1, padding=1, bn=bn, relu=relu, bn_momentum=bn_momentum) # assert init_method in ["kaiming", "xavier"] # self.init_weights(init_method) def forward(self, x_pre, x): x = self.deconv(x) x = torch.cat((x, x_pre), dim=1) x = self.conv(x) return x class FeatureNet(nn.Module): def __init__(self, base_channels, num_stage=3, stride=4, arch_mode="unet"): super(FeatureNet, self).__init__() assert arch_mode in ["unet", "fpn"], print("mode must be in 'unet' or 'fpn', but get:{}".format(arch_mode)) print("*************feature extraction arch mode:{}****************".format(arch_mode)) self.arch_mode = arch_mode self.stride = stride self.base_channels = base_channels self.num_stage = num_stage self.conv0 = nn.Sequential( Conv2d(3, base_channels, 3, 1, padding=1), Conv2d(base_channels, base_channels, 3, 1, padding=1), ) self.conv1 = nn.Sequential( Conv2d(base_channels, base_channels * 2, 5, stride=2, padding=2), Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1), Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1), ) self.conv2 = nn.Sequential( Conv2d(base_channels * 2, base_channels * 4, 5, stride=2, padding=2), Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1), Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1), ) self.out1 = nn.Conv2d(base_channels * 4, base_channels * 4, 1, bias=False) self.out_channels = [4 * base_channels] if self.arch_mode == 'unet': if num_stage == 3: self.deconv1 = DeConv2dFuse(base_channels * 4, base_channels * 2, 3) self.deconv2 = DeConv2dFuse(base_channels * 2, base_channels, 3) self.out2 = nn.Conv2d(base_channels * 2, base_channels * 2, 1, bias=False) self.out3 = nn.Conv2d(base_channels, base_channels, 1, bias=False) self.out_channels.append(2 * base_channels) self.out_channels.append(base_channels) elif num_stage == 2: self.deconv1 = DeConv2dFuse(base_channels * 4, base_channels * 2, 3) self.out2 = nn.Conv2d(base_channels * 2, base_channels * 2, 1, bias=False) self.out_channels.append(2 * base_channels) elif self.arch_mode == "fpn": final_chs = base_channels * 4 if num_stage == 3: self.inner1 = nn.Conv2d(base_channels * 2, final_chs, 1, bias=True) self.inner2 = nn.Conv2d(base_channels * 1, final_chs, 1, bias=True) self.out2 = nn.Conv2d(final_chs, base_channels * 2, 3, padding=1, bias=False) self.out3 = nn.Conv2d(final_chs, base_channels, 3, padding=1, bias=False) self.out_channels.append(base_channels * 2) self.out_channels.append(base_channels) elif num_stage == 2: self.inner1 = nn.Conv2d(base_channels * 2, final_chs, 1, bias=True) self.out2 = nn.Conv2d(final_chs, base_channels, 3, padding=1, bias=False) self.out_channels.append(base_channels) def forward(self, x): conv0 = self.conv0(x) conv1 = self.conv1(conv0) conv2 = self.conv2(conv1) intra_feat = conv2 outputs = {} out = self.out1(intra_feat) outputs["stage1"] = out if self.arch_mode == "unet": if self.num_stage == 3: intra_feat = self.deconv1(conv1, intra_feat) out = self.out2(intra_feat) outputs["stage2"] = out intra_feat = self.deconv2(conv0, intra_feat) out = self.out3(intra_feat) outputs["stage3"] = out elif self.num_stage == 2: intra_feat = self.deconv1(conv1, intra_feat) out = self.out2(intra_feat) outputs["stage2"] = out elif self.arch_mode == "fpn": if self.num_stage == 3: intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="nearest") + self.inner1(conv1) out = self.out2(intra_feat) outputs["stage2"] = out intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="nearest") + self.inner2(conv0) out = self.out3(intra_feat) outputs["stage3"] = out elif self.num_stage == 2: intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="nearest") + self.inner1(conv1) out = self.out2(intra_feat) outputs["stage2"] = out return outputs class FPNDCNpath(nn.Module): """ FPN+DCN pathway""" def __init__(self, base_channels, stride=4): super(FPNDCNpath, self).__init__() self.stride = stride self.base_channels = base_channels self.conv0 = nn.Sequential( Conv2d(3, base_channels, 3, 1, padding=1), Conv2d(base_channels, base_channels, 3, 1, padding=1), ) self.conv1 = nn.Sequential( Conv2d(base_channels, base_channels * 2, 5, stride=2, padding=2), Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1), Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1), ) self.conv2 = nn.Sequential( Conv2d(base_channels * 2, base_channels * 4, 5, stride=2, padding=2), Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1), Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1), ) self.out1 = nn.Sequential( DCNConv2d(base_channels * 4, base_channels * 4, 3, stride=1, padding=1), DCNConv2d(base_channels * 4, base_channels * 4, 3, stride=1, padding=1), DeformConvPack(base_channels * 4, base_channels * 4, 3, stride=1, padding=1, bias=False, im2col_step=16) ) self.out_channels = [4 * base_channels] final_chs = base_channels * 4 self.inner1 = nn.Conv2d(base_channels * 2, final_chs, 1, bias=True) self.inner2 = nn.Conv2d(base_channels * 1, final_chs, 1, bias=True) self.out2 = nn.Sequential( DCNConv2d(base_channels * 4, base_channels * 2, 3, stride=1, padding=1), DCNConv2d(base_channels * 2, base_channels * 2, 3, stride=1, padding=1), DeformConvPack(base_channels * 2, base_channels * 2, 3, stride=1, padding=1, bias=False, im2col_step=16) ) self.out2pathconv = nn.Conv2d(base_channels * 4, base_channels * 2, 3, stride=1, padding=1) self.out3 = nn.Sequential( DCNConv2d(base_channels * 4, base_channels * 1, 3, stride=1, padding=1), DCNConv2d(base_channels * 1, base_channels * 1, 3, stride=1, padding=1), DeformConvPack(base_channels * 1, base_channels * 1, 3, stride=1, padding=1, bias=False, im2col_step=16) ) self.out3pathconv = nn.Conv2d(base_channels * 2, base_channels * 1, 3, stride=1, padding=1) self.out_channels.append(base_channels * 2) self.out_channels.append(base_channels) def forward(self, x): conv0 = self.conv0(x) conv1 = self.conv1(conv0) conv2 = self.conv2(conv1) intra_feat = conv2 outputs = {} out1 = self.out1(intra_feat) intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner1(conv1) out2 = self.out2(intra_feat) out2 = out2 + self.out2pathconv(F.interpolate(out1, scale_factor=2, mode="bilinear", align_corners=True)) intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner2(conv0) out3 = self.out3(intra_feat) out3 = out3 + self.out3pathconv(F.interpolate(out2, scale_factor=2, mode="bilinear", align_corners=True)) outputs["stage1"] = out1 outputs["stage2"] = out2 outputs["stage3"] = out3 return outputs class FPNDCN(nn.Module): """ FPN+DCN""" def __init__(self, base_channels, stride=4): super(FPNDCN, self).__init__() self.stride = stride self.base_channels = base_channels self.conv0 = nn.Sequential( Conv2d(3, base_channels, 3, 1, padding=1), Conv2d(base_channels, base_channels, 3, 1, padding=1), ) self.conv1 = nn.Sequential( Conv2d(base_channels, base_channels * 2, 5, stride=2, padding=2), Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1), Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1), ) self.conv2 = nn.Sequential( Conv2d(base_channels * 2, base_channels * 4, 5, stride=2, padding=2), Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1), Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1), ) self.out1 = nn.Sequential( DCNConv2d(base_channels * 4, base_channels * 4, 3, stride=1, padding=1), DCNConv2d(base_channels * 4, base_channels * 4, 3, stride=1, padding=1), DeformConvPack(base_channels * 4, base_channels * 4, 3, stride=1, padding=1, bias=False, im2col_step=16) ) self.out_channels = [4 * base_channels] final_chs = base_channels * 4 self.inner1 = nn.Conv2d(base_channels * 2, final_chs, 1, bias=True) self.inner2 = nn.Conv2d(base_channels * 1, final_chs, 1, bias=True) self.out2 = nn.Sequential( DCNConv2d(base_channels * 4, base_channels * 2, 3, stride=1, padding=1), DCNConv2d(base_channels * 2, base_channels * 2, 3, stride=1, padding=1), DeformConvPack(base_channels * 2, base_channels * 2, 3, stride=1, padding=1, bias=False, im2col_step=16) ) self.out3 = nn.Sequential( DCNConv2d(base_channels * 4, base_channels * 1, 3, stride=1, padding=1), DCNConv2d(base_channels * 1, base_channels * 1, 3, stride=1, padding=1), DeformConvPack(base_channels * 1, base_channels * 1, 3, stride=1, padding=1, bias=False, im2col_step=16) ) self.out_channels.append(base_channels * 2) self.out_channels.append(base_channels) def forward(self, x): conv0 = self.conv0(x) conv1 = self.conv1(conv0) conv2 = self.conv2(conv1) intra_feat = conv2 outputs = {} out1 = self.out1(intra_feat) intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner1(conv1) out2 = self.out2(intra_feat) intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner2(conv0) out3 = self.out3(intra_feat) outputs["stage1"] = out1 outputs["stage2"] = out2 outputs["stage3"] = out3 return outputs class FPNA(nn.Module): """ FPN aligncorners""" def __init__(self, base_channels, stride=4): super(FPNA, self).__init__() self.stride = stride self.base_channels = base_channels self.conv0 = nn.Sequential( Conv2d(3, base_channels, 3, 1, padding=1), Conv2d(base_channels, base_channels, 3, 1, padding=1), ) self.conv1 = nn.Sequential( Conv2d(base_channels, base_channels * 2, 5, stride=2, padding=2), Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1), Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1), ) self.conv2 = nn.Sequential( Conv2d(base_channels * 2, base_channels * 4, 5, stride=2, padding=2), Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1), Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1), ) self.out1 = nn.Conv2d(base_channels * 4, base_channels * 4, 1, bias=False) self.out_channels = [4 * base_channels] final_chs = base_channels * 4 self.inner1 = nn.Conv2d(base_channels * 2, final_chs, 1, bias=True) self.inner2 = nn.Conv2d(base_channels * 1, final_chs, 1, bias=True) self.out2 = nn.Conv2d(final_chs, base_channels * 2, 3, padding=1, bias=False) self.out3 = nn.Conv2d(final_chs, base_channels, 3, padding=1, bias=False) self.out_channels.append(base_channels * 2) self.out_channels.append(base_channels) def forward(self, x): conv0 = self.conv0(x) conv1 = self.conv1(conv0) conv2 = self.conv2(conv1) intra_feat = conv2 outputs = {} out1 = self.out1(intra_feat) intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner1(conv1) out2 = self.out2(intra_feat) intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner2(conv0) out3 = self.out3(intra_feat) outputs["stage1"] = out1 outputs["stage2"] = out2 outputs["stage3"] = out3 return outputs class FPNA4(nn.Module): """ FPN aligncorners downsample 4x""" def __init__(self, base_channels): super(FPNA4, self).__init__() self.base_channels = base_channels self.conv0 = nn.Sequential( Conv2d(3, base_channels, 3, 1, padding=1), Conv2d(base_channels, base_channels, 3, 1, padding=1), ) self.conv1 = nn.Sequential( Conv2d(base_channels, base_channels * 2, 5, stride=2, padding=2), Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1), Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1), ) self.conv2 = nn.Sequential( Conv2d(base_channels * 2, base_channels * 4, 5, stride=2, padding=2), Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1), Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1), ) self.conv3 = nn.Sequential( Conv2d(base_channels * 4, base_channels * 8, 5, stride=2, padding=2), Conv2d(base_channels * 8, base_channels * 8, 3, 1, padding=1), Conv2d(base_channels * 8, base_channels * 8, 3, 1, padding=1), ) self.out_channels = [8 * base_channels] final_chs = base_channels * 8 self.inner1 = nn.Conv2d(base_channels * 4, final_chs, 1, bias=True) self.inner2 = nn.Conv2d(base_channels * 2, final_chs, 1, bias=True) self.inner3 = nn.Conv2d(base_channels * 1, final_chs, 1, bias=True) self.out1 = nn.Conv2d(final_chs, base_channels * 8, 1, bias=False) self.out2 = nn.Conv2d(final_chs, base_channels * 4, 3, padding=1, bias=False) self.out3 = nn.Conv2d(final_chs, base_channels * 2, 3, padding=1, bias=False) self.out4 = nn.Conv2d(final_chs, base_channels, 3, padding=1, bias=False) self.out_channels.append(base_channels * 4) self.out_channels.append(base_channels * 2) self.out_channels.append(base_channels) def forward(self, x): conv0 = self.conv0(x) conv1 = self.conv1(conv0) conv2 = self.conv2(conv1) conv3 = self.conv3(conv2) intra_feat = conv3 outputs = {} out1 = self.out1(intra_feat) intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner1(conv2) out2 = self.out2(intra_feat) intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner2(conv1) out3 = self.out3(intra_feat) intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner3(conv0) out4 = self.out4(intra_feat) outputs["stage1"] = out1 outputs["stage2"] = out2 outputs["stage3"] = out3 outputs["stage4"] = out4 return outputs class CostRegNet(nn.Module): def __init__(self, in_channels, base_channels, down_size=3): super(CostRegNet, self).__init__() self.down_size = down_size self.conv0 = Conv3d(in_channels, base_channels, padding=1) self.conv1 = Conv3d(base_channels, base_channels * 2, stride=2, padding=1) self.conv2 = Conv3d(base_channels * 2, base_channels * 2, padding=1) if down_size >= 2: self.conv3 = Conv3d(base_channels * 2, base_channels * 4, stride=2, padding=1) self.conv4 = Conv3d(base_channels * 4, base_channels * 4, padding=1) if down_size >= 3: self.conv5 = Conv3d(base_channels * 4, base_channels * 8, stride=2, padding=1) self.conv6 = Conv3d(base_channels * 8, base_channels * 8, padding=1) self.conv7 = Deconv3d(base_channels * 8, base_channels * 4, stride=2, padding=1, output_padding=1) if down_size >= 2: self.conv9 = Deconv3d(base_channels * 4, base_channels * 2, stride=2, padding=1, output_padding=1) self.conv11 = Deconv3d(base_channels * 2, base_channels * 1, stride=2, padding=1, output_padding=1) self.prob = nn.Conv3d(base_channels, 1, 3, stride=1, padding=1, bias=False) def forward(self, x): if self.down_size==3: conv0 = self.conv0(x) conv2 = self.conv2(self.conv1(conv0)) conv4 = self.conv4(self.conv3(conv2)) x = self.conv6(self.conv5(conv4)) x = conv4 + self.conv7(x) x = conv2 + self.conv9(x) x = conv0 + self.conv11(x) x = self.prob(x) elif self.down_size==2: conv0 = self.conv0(x) conv2 = self.conv2(self.conv1(conv0)) x = self.conv4(self.conv3(conv2)) x = conv2 + self.conv9(x) x = conv0 + self.conv11(x) x = self.prob(x) else: conv0 = self.conv0(x) x = self.conv2(self.conv1(conv0)) x = conv0 + self.conv11(x) x = self.prob(x) return x class P3DConv(nn.Module): """ Pseudo 3D conv: 3x3x1 + 1x3x3 """ def __init__(self, in_channels, base_channels): super(P3DConv, self).__init__() self.conv0 = PConv3d(in_channels, base_channels, padding=1) self.conv1 = PConv3d(base_channels, base_channels * 2, stride=2, padding=1) self.conv2 = PConv3d(base_channels * 2, base_channels * 2, padding=1) self.conv3 = PConv3d(base_channels * 2, base_channels * 4, stride=2, padding=1) self.conv4 = PConv3d(base_channels * 4, base_channels * 4, padding=1) self.conv5 = PConv3d(base_channels * 4, base_channels * 8, stride=2, padding=1) self.conv6 = PConv3d(base_channels * 8, base_channels * 8, padding=1) self.conv7 = PDeconv3d(base_channels * 8, base_channels * 4, stride=2, padding=1, output_padding=1) self.conv9 = PDeconv3d(base_channels * 4, base_channels * 2, stride=2, padding=1, output_padding=1) self.conv11 = PDeconv3d(base_channels * 2, base_channels * 1, stride=2, padding=1, output_padding=1) self.prob = nn.Conv3d(base_channels, 1, 3, stride=1, padding=1, bias=False) def forward(self, x): conv0 = self.conv0(x) conv2 = self.conv2(self.conv1(conv0)) conv4 = self.conv4(self.conv3(conv2)) x = self.conv6(self.conv5(conv4)) x = conv4 + self.conv7(x) x = conv2 + self.conv9(x) x = conv0 + self.conv11(x) x = self.prob(x) return x class RefineNet(nn.Module): def __init__(self): super(RefineNet, self).__init__() self.conv1 = ConvBnReLU(4, 32) self.conv2 = ConvBnReLU(32, 32) self.conv3 = ConvBnReLU(32, 32) self.res = ConvBnReLU(32, 1) def forward(self, img, depth_init): concat = F.cat((img, depth_init), dim=1) depth_residual = self.res(self.conv3(self.conv2(self.conv1(concat)))) depth_refined = depth_init + depth_residual return depth_refined def depth_regression(p, depth_values): if depth_values.dim() <= 2: # print("regression dim <= 2") depth_values = depth_values.view(*depth_values.shape, 1, 1) depth = torch.sum(p * depth_values, 1) return depth def cas_mvsnet_loss(inputs, depth_gt_ms, mask_ms, **kwargs): depth_loss_weights = kwargs.get("dlossw", None) total_loss = torch.tensor(0.0, dtype=torch.float32, device=mask_ms["stage1"].device, requires_grad=False) for (stage_inputs, stage_key) in [(inputs[k], k) for k in inputs.keys() if "stage" in k]: depth_est = stage_inputs["depth"] depth_gt = depth_gt_ms[stage_key] mask = mask_ms[stage_key] mask = mask > 0.5 depth_loss = F.smooth_l1_loss(depth_est[mask], depth_gt[mask], reduction='mean') if depth_loss_weights is not None: stage_idx = int(stage_key.replace("stage", "")) - 1 total_loss += depth_loss_weights[stage_idx] * depth_loss else: total_loss += 1.0 * depth_loss return total_loss, depth_loss def cas_mvsnet_T_loss(inputs, depth_gt_ms, mask_ms, **kwargs): depth_loss_weights = kwargs.get("dlossw", None) l1ce_lw = kwargs.get("l1ce_lw", [0.1, 1]) range_thres = kwargs.get("range_thres", [84.8, 10.6]) cas_method = kwargs.get("cascade_method", None) last_conv3d = kwargs.get("last_conv3d", False) visual = kwargs.get("visual", False) wt = kwargs.get("wt", False) fl = kwargs.get("fl", False) shrink_method = kwargs.get("shrink_method", 'schedule') upsampled_loss = kwargs.get("upsampled_loss", False) selected_loss = kwargs.get("selected_loss", False) mask_range_loss = kwargs.get("mask_range_loss", False) det = kwargs.get("det", False) if visual: f, axs = plt.subplots(figsize=(30, 10),ncols=3) # depth offset f2, axs2 = plt.subplots(figsize=(30, 10),ncols=3) # attn weight max f3, axs3 = plt.subplots(figsize=(30, 10),ncols=3) # attn weight gt val f4, axs4 = plt.subplots(figsize=(30, 10),ncols=3) # max gt offset err_848_str = '' err_106_str = '' err_002_str = '' total_loss = torch.tensor(0.0, dtype=torch.float32, device=mask_ms["stage1"].device, requires_grad=False) stage_depth_loss = [] stage_ce_loss = [] range_err_ratio = [] upsampled_depth_losses = [] det_offset_losses = [] for stage_idx, (stage_inputs, stage_key) in enumerate([(inputs[k], k) for k in inputs.keys() if "stage" in k]): depth_est = stage_inputs["depth"] B,H,W = depth_est.shape mask = mask_ms[stage_key] mask = mask > 0.5 depth_gt = depth_gt_ms[stage_key] if upsampled_loss: if stage_idx!=0 : upsampled_depth = stage_inputs["upsampled_depth"] upsampled_depth_loss = F.smooth_l1_loss(upsampled_depth[mask], depth_gt[mask], reduction='mean') upsampled_depth_losses.append(upsampled_depth_loss) else: if stage_idx!=0 : upsampled_depth_losses.append(torch.tensor(0.0, dtype=torch.float32, device=mask_ms["stage1"].device, requires_grad=False)) if mask_range_loss: if stage_idx != 0: depth_offset = next_stage_depth_hypo - depth_gt # B H W this_stage_mask_range = torch.abs(depth_offset)range_thres[stage_idx]).float().mean()) if visual: depth_offset = depth_est - depth_gt depth_offset[~mask] = 0 depth_offset = depth_offset.detach().cpu().numpy()[0] # H W err_848_str += str((np.abs(depth_offset)>84.8).sum()) + ',' err_106_str += str((np.abs(depth_offset)>10.6).sum()) + ',' err_002_str += str((np.abs(depth_offset)>2).sum()) + ',' sns.heatmap(depth_offset, annot=False, ax=axs[stage_idx]) attn_weights = stage_inputs["attn_weights"][0] # D H W attn_weights_max, ind_max = torch.max(attn_weights, 0) attn_weights_max = attn_weights_max.detach().cpu().numpy() # H W sns.heatmap(attn_weights_max, annot=False, ax=axs2[stage_idx]) this_stage_depth_val = stage_inputs['depth_values'] # B D H W depth_offsets = torch.abs(this_stage_depth_val- depth_gt[:,None,:,:])[0] # D,H,W _, indices = torch.min(depth_offsets, dim=0, keepdim=True) # [1, H, W] attn_gt = torch.gather(attn_weights, 0, indices)[0] # [H W] attn_gt = attn_gt.detach().cpu().numpy() sns.heatmap(attn_gt, annot=False, ax=axs3[stage_idx]) max_gt_offset = ind_max - indices[0] # H W max_gt_offset = max_gt_offset.detach().cpu().numpy() sns.heatmap(max_gt_offset, annot=False, ax=axs4[stage_idx]) if cas_method[stage_idx] == 't' or cas_method[stage_idx] == 'r' or cas_method[stage_idx] == 'p': # Loss for transformer depth_loss = torch.tensor(0.0, dtype=torch.float32, device=mask_ms["stage1"].device, requires_grad=False) if last_conv3d: depth_loss = F.smooth_l1_loss(depth_est[mask], depth_gt[mask], reduction='mean') attn_weights = stage_inputs["attn_weights"].permute(0,2,3,1).reshape(B*H*W, -1) # BHW D this_stage_depth_val = stage_inputs['depth_values'] # B D H W depth_offsets = torch.abs(this_stage_depth_val- depth_gt[:,None,:,:]) # B,D,H,W _, indices = torch.min(depth_offsets, dim=1) # [B, H, W] indices = indices.reshape(-1) # [BHW] mask = mask.reshape(-1) # BHW if fl: # -p(1-q)^a log(q) this_stage_ce_loss = F.nll_loss((1-attn_weights[mask])**2 * torch.log(attn_weights[mask]+1e-12), indices[mask], reduce='mean') else: # -plog(q) this_stage_ce_loss = F.nll_loss(torch.log(attn_weights[mask]+1e-12), indices[mask], reduce='mean') stage_depth_loss.append(depth_loss) stage_ce_loss.append(this_stage_ce_loss) this_stage_loss = l1ce_lw[0]*depth_loss + l1ce_lw[1]*this_stage_ce_loss # Loss for 3D conv else: if wt: depth_loss = torch.tensor(0.0, dtype=torch.float32, device=mask_ms["stage1"].device, requires_grad=False) stage_depth_loss.append(depth_loss) attn_weights = stage_inputs["attn_weights"].permute(0,2,3,1).reshape(B*H*W, -1) # BHW D depth_offsets = torch.abs(stage_inputs['depth_values']- depth_gt[:,None,:,:]) # B,D,H,W indices = torch.min(depth_offsets, dim=1)[1].reshape(-1) # [BHW] mask = mask.reshape(-1) # BHW if fl: # -p(1-q)^a log(q) this_stage_ce_loss = F.nll_loss((1-attn_weights[mask])**2 * torch.log(attn_weights[mask]+1e-12), indices[mask], reduce='mean') else: # -plog(q) this_stage_ce_loss = F.nll_loss(torch.log(attn_weights[mask]+1e-12), indices[mask], reduce='mean') stage_ce_loss.append(this_stage_ce_loss) else: depth_loss = F.smooth_l1_loss(depth_est[mask], depth_gt[mask], reduction='mean') stage_depth_loss.append(depth_loss) this_stage_ce_loss = torch.tensor(0.0, dtype=torch.float32, device=mask_ms["stage1"].device, requires_grad=False) stage_ce_loss.append(this_stage_ce_loss) this_stage_loss = l1ce_lw[0]*depth_loss + l1ce_lw[1]*this_stage_ce_loss if upsampled_loss: if stage_idx!=0: this_stage_loss = this_stage_loss + upsampled_depth_loss * l1ce_lw[0] if shrink_method == 'DPF': if stage_idx!=0: depth_offsets = stage_inputs['depth_values'] - depth_gt[:,None,:,:] # B,D,H,W depth_offset_clamp = torch.clamp(depth_offsets, -1, 1) this_stage_loss = this_stage_loss + torch.abs(depth_offset_clamp).permute(0,2,3,1).reshape(B*H*W, -1)[mask.reshape(-1)].mean() if selected_loss: select_weight = stage_inputs["select_weight"].permute(0,2,3,1).reshape(B*H*W, -1) # BHW D depth_offsets = torch.abs(stage_inputs['depth_values']- depth_gt[:,None,:,:]) indices = torch.min(depth_offsets, dim=1)[1] # [B, H, W] indices = indices.reshape(-1) # [BHW] mask = mask.reshape(-1) # BHW this_stage_selected_loss = F.nll_loss(torch.log(select_weight[mask]+1e-12), indices[mask], reduce='mean') this_stage_loss = this_stage_loss + this_stage_selected_loss * 0.01*l1ce_lw[1] if det: assert wt depth_itv = stage_inputs['depth_values'][:,1,:,:] - stage_inputs['depth_values'][:,0,:,:] # B H W pred_offset = stage_inputs['offset_reg'].reshape(-1) # BHW offset_gt = (depth_gt - (depth_est - stage_inputs['offset_reg'])).reshape(-1) / depth_itv.reshape(-1) # BHW det_offset_loss = F.smooth_l1_loss(pred_offset[mask], offset_gt[mask], reduction='mean') det_offset_losses.append(det_offset_loss) this_stage_loss += det_offset_loss else: det_offset_losses.append(torch.tensor(0.0, dtype=torch.float32, device=mask_ms["stage1"].device, requires_grad=False)) if depth_loss_weights is not None: stage_idx = int(stage_key.replace("stage", "")) - 1 total_loss += depth_loss_weights[stage_idx] * this_stage_loss else: total_loss += 1.0 * this_stage_loss if visual: axs[1].set_title('err848:{}'.format(err_848_str) + 'err_106:{}'.format(err_106_str) + 'err_002:{}'.format(err_002_str)) f.savefig('/mnt/cfs/algorithm/xiaofeng.wang/jeff/code/MVS/cascade-stereo/CasMVSNet/debug_figs/offset_heatmap.png') f.clf() f2.savefig('/mnt/cfs/algorithm/xiaofeng.wang/jeff/code/MVS/cascade-stereo/CasMVSNet/debug_figs/attn_max_heatmap.png') f2.clf() f3.savefig('/mnt/cfs/algorithm/xiaofeng.wang/jeff/code/MVS/cascade-stereo/CasMVSNet/debug_figs/attn_gt_heatmap.png') f3.clf() f4.savefig('/mnt/cfs/algorithm/xiaofeng.wang/jeff/code/MVS/cascade-stereo/CasMVSNet/debug_figs/max_gt_offset_heatmap.png') f4.clf() return total_loss, depth_loss, stage_depth_loss, stage_ce_loss, range_err_ratio, upsampled_depth_losses, det_offset_losses def get_cur_depth_range_samples(cur_depth, ndepth, depth_inteval_pixel, shape, max_depth=192.0, min_depth=0.0): #shape, (B, H, W) #cur_depth: (B, H, W) #return depth_range_values: (B, D, H, W) cur_depth_min = (cur_depth - ndepth / 2 * depth_inteval_pixel) # (B, H, W) cur_depth_max = (cur_depth + ndepth / 2 * depth_inteval_pixel) # cur_depth_min = (cur_depth - ndepth / 2 * depth_inteval_pixel).clamp(min=0.0) #(B, H, W) # cur_depth_max = (cur_depth_min + (ndepth - 1) * depth_inteval_pixel).clamp(max=max_depth) assert cur_depth.shape == torch.Size(shape), "cur_depth:{}, input shape:{}".format(cur_depth.shape, shape) new_interval = (cur_depth_max - cur_depth_min) / (ndepth - 1) # (B, H, W) depth_range_samples = cur_depth_min.unsqueeze(1) + (torch.arange(0, ndepth, device=cur_depth.device, dtype=cur_depth.dtype, requires_grad=False).reshape(1, -1, 1, 1) * new_interval.unsqueeze(1)) return depth_range_samples def get_depth_range_samples(cur_depth, ndepth, depth_inteval_pixel, device, dtype, shape, max_depth=192.0, min_depth=0.0): #shape: (B, H, W) #cur_depth: (B, H, W) or (B, D) #return depth_range_samples: (B, D, H, W) if cur_depth.dim() == 2: cur_depth_min = cur_depth[:, 0] # (B,) cur_depth_max = cur_depth[:, -1] new_interval = (cur_depth_max - cur_depth_min) / (ndepth - 1) # (B, ) depth_range_samples = cur_depth_min.unsqueeze(1) + (torch.arange(0, ndepth, device=device, dtype=dtype, requires_grad=False).reshape(1, -1) * new_interval.unsqueeze(1)) #(B, D) depth_range_samples = depth_range_samples.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, shape[1], shape[2]) #(B, D, H, W) else: depth_range_samples = get_cur_depth_range_samples(cur_depth, ndepth, depth_inteval_pixel, shape, max_depth, min_depth) return depth_range_samples if __name__ == "__main__": # some testing code, just IGNORE it import sys sys.path.append("../") from datasets import find_dataset_def from torch.utils.data import DataLoader import numpy as np import cv2 import matplotlib as mpl mpl.use('Agg') import matplotlib.pyplot as plt # MVSDataset = find_dataset_def("colmap") # dataset = MVSDataset("../data/results/ford/num10_1/", 3, 'test', # 128, interval_scale=1.06, max_h=1250, max_w=1024) MVSDataset = find_dataset_def("dtu_yao") num_depth = 48 dataset = MVSDataset("../data/DTU/mvs_training/dtu/", '../lists/dtu/train.txt', 'train', 3, num_depth, interval_scale=1.06 * 192 / num_depth) dataloader = DataLoader(dataset, batch_size=1) item = next(iter(dataloader)) imgs = item["imgs"][:, :, :, ::4, ::4] #(B, N, 3, H, W) # imgs = item["imgs"][:, :, :, :, :] proj_matrices = item["proj_matrices"] #(B, N, 2, 4, 4) dim=N: N view; dim=2: index 0 for extr, 1 for intric proj_matrices[:, :, 1, :2, :] = proj_matrices[:, :, 1, :2, :] # proj_matrices[:, :, 1, :2, :] = proj_matrices[:, :, 1, :2, :] * 4 depth_values = item["depth_values"] #(B, D) imgs = torch.unbind(imgs, 1) proj_matrices = torch.unbind(proj_matrices, 1) ref_img, src_imgs = imgs[0], imgs[1:] ref_proj, src_proj = proj_matrices[0], proj_matrices[1:][0] #only vis first view src_proj_new = src_proj[:, 0].clone() src_proj_new[:, :3, :4] = torch.matmul(src_proj[:, 1, :3, :3], src_proj[:, 0, :3, :4]) ref_proj_new = ref_proj[:, 0].clone() ref_proj_new[:, :3, :4] = torch.matmul(ref_proj[:, 1, :3, :3], ref_proj[:, 0, :3, :4]) warped_imgs = homo_warping(src_imgs[0], src_proj_new, ref_proj_new, depth_values) ref_img_np = ref_img.permute([0, 2, 3, 1])[0].detach().cpu().numpy()[:, :, ::-1] * 255 cv2.imwrite('../tmp/ref.png', ref_img_np) cv2.imwrite('../tmp/src.png', src_imgs[0].permute([0, 2, 3, 1])[0].detach().cpu().numpy()[:, :, ::-1] * 255) for i in range(warped_imgs.shape[2]): warped_img = warped_imgs[:, :, i, :, :].permute([0, 2, 3, 1]).contiguous() img_np = warped_img[0].detach().cpu().numpy() img_np = img_np[:, :, ::-1] * 255 alpha = 0.5 beta = 1 - alpha gamma = 0 img_add = cv2.addWeighted(ref_img_np, alpha, img_np, beta, gamma) cv2.imwrite('../tmp/tmp{}.png'.format(i), np.hstack([ref_img_np, img_np, img_add])) #* ratio + img_np*(1-ratio)])) ================================================ FILE: models/mvs4net_utils.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F import importlib try: from modules.deform_conv import DeformConvPack except: print('DeformConvPack not found, please install it from: https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch') pass import math import numpy as np def homo_warping(src_fea, src_proj, ref_proj, depth_values, vis_ETA=False, fn=None): # src_fea: [B, C, H, W] # src_proj: [B, 4, 4] # ref_proj: [B, 4, 4] # depth_values: [B, Ndepth] o [B, Ndepth, H, W] # out: [B, C, Ndepth, H, W] C = src_fea.shape[1] Hs,Ws = src_fea.shape[-2:] B,num_depth,Hr,Wr = depth_values.shape with torch.no_grad(): proj = torch.matmul(src_proj, torch.inverse(ref_proj)) rot = proj[:, :3, :3] # [B,3,3] trans = proj[:, :3, 3:4] # [B,3,1] y, x = torch.meshgrid([torch.arange(0, Hr, dtype=torch.float32, device=src_fea.device), torch.arange(0, Wr, dtype=torch.float32, device=src_fea.device)]) y = y.reshape(Hr*Wr) x = x.reshape(Hr*Wr) xyz = torch.stack((x, y, torch.ones_like(x))) # [3, H*W] xyz = torch.unsqueeze(xyz, 0).repeat(B, 1, 1) # [B, 3, H*W] rot_xyz = torch.matmul(rot, xyz) # [B, 3, H*W] rot_depth_xyz = rot_xyz.unsqueeze(2).repeat(1, 1, num_depth, 1) * depth_values.reshape(B, 1, num_depth, -1) # [B, 3, Ndepth, H*W] proj_xyz = rot_depth_xyz + trans.reshape(B, 3, 1, 1) # [B, 3, Ndepth, H*W] # FIXME divide 0 temp = proj_xyz[:, 2:3, :, :] temp[temp==0] = 1e-9 proj_xy = proj_xyz[:, :2, :, :] / temp # [B, 2, Ndepth, H*W] # proj_xy = proj_xyz[:, :2, :, :] / proj_xyz[:, 2:3, :, :] # [B, 2, Ndepth, H*W] proj_x_normalized = proj_xy[:, 0, :, :] / ((Ws - 1) / 2) - 1 proj_y_normalized = proj_xy[:, 1, :, :] / ((Hs - 1) / 2) - 1 proj_xy = torch.stack((proj_x_normalized, proj_y_normalized), dim=3) # [B, Ndepth, H*W, 2] if vis_ETA: tensor_saved = proj_xy.reshape(B,num_depth,Hs,Ws,2).cpu().numpy() np.save(fn+'_grid', tensor_saved) grid = proj_xy if len(src_fea.shape)==4: warped_src_fea = F.grid_sample(src_fea, grid.reshape(B, num_depth * Hr, Wr, 2), mode='bilinear', padding_mode='zeros', align_corners=True) warped_src_fea = warped_src_fea.reshape(B, C, num_depth, Hr, Wr) elif len(src_fea.shape)==5: warped_src_fea = [] for d in range(src_fea.shape[2]): warped_src_fea.append(F.grid_sample(src_fea[:,:,d], grid.reshape(B, num_depth, Hr, Wr, 2)[:,d], mode='bilinear', padding_mode='zeros', align_corners=True)) warped_src_fea = torch.stack(warped_src_fea, dim=2) return warped_src_fea def init_range(cur_depth, ndepths, device, dtype, H, W): cur_depth_min = cur_depth[:, 0] # (B,) cur_depth_max = cur_depth[:, -1] new_interval = (cur_depth_max - cur_depth_min) / (ndepths - 1) # (B, ) new_interval = new_interval[:, None, None] # B H W depth_range_samples = cur_depth_min.unsqueeze(1) + (torch.arange(0, ndepths, device=device, dtype=dtype, requires_grad=False).reshape(1, -1) * new_interval.squeeze(1)) #(B, D) depth_range_samples = depth_range_samples.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, H, W) #(B, D, H, W) return depth_range_samples def init_inverse_range(cur_depth, ndepths, device, dtype, H, W): inverse_depth_min = 1. / cur_depth[:, 0] # (B,) inverse_depth_max = 1. / cur_depth[:, -1] itv = torch.arange(0, ndepths, device=device, dtype=dtype, requires_grad=False).reshape(1, -1,1,1).repeat(1, 1, H, W) / (ndepths - 1) # 1 D H W inverse_depth_hypo = inverse_depth_max[:,None, None, None] + (inverse_depth_min - inverse_depth_max)[:,None, None, None] * itv return 1./inverse_depth_hypo def schedule_inverse_range(inverse_min_depth, inverse_max_depth, ndepths, H, W): #cur_depth_min, (B, H, W) #cur_depth_max: (B, H, W) itv = torch.arange(0, ndepths, device=inverse_min_depth.device, dtype=inverse_min_depth.dtype, requires_grad=False).reshape(1, -1,1,1).repeat(1, 1, H//2, W//2) / (ndepths - 1) # 1 D H W inverse_depth_hypo = inverse_max_depth[:,None, :, :] + (inverse_min_depth - inverse_max_depth)[:,None, :, :] * itv # B D H W inverse_depth_hypo = F.interpolate(inverse_depth_hypo.unsqueeze(1), [ndepths, H, W], mode='trilinear', align_corners=True).squeeze(1) return 1./inverse_depth_hypo def schedule_range(cur_depth, ndepth, depth_inteval_pixel, H, W): #shape, (B, H, W) #cur_depth: (B, H, W) #return depth_range_values: (B, D, H, W) cur_depth_min = (cur_depth - ndepth / 2 * depth_inteval_pixel[:,None,None]) # (B, H, W) cur_depth_max = (cur_depth + ndepth / 2 * depth_inteval_pixel[:,None,None]) new_interval = (cur_depth_max - cur_depth_min) / (ndepth - 1) # (B, H, W) depth_range_samples = cur_depth_min.unsqueeze(1) + (torch.arange(0, ndepth, device=cur_depth.device, dtype=cur_depth.dtype, requires_grad=False).reshape(1, -1, 1, 1) * new_interval.unsqueeze(1)) depth_range_samples = F.interpolate(depth_range_samples.unsqueeze(1), [ndepth, H, W], mode='trilinear', align_corners=True).squeeze(1) return depth_range_samples def init_bn(module): if module.weight is not None: nn.init.ones_(module.weight) if module.bias is not None: nn.init.zeros_(module.bias) return def init_uniform(module, init_method): if module.weight is not None: if init_method == "kaiming": nn.init.kaiming_uniform_(module.weight) elif init_method == "xavier": nn.init.xavier_uniform_(module.weight) return class ConvBnReLU3D(nn.Module): def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1): super(ConvBnReLU3D, self).__init__() self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False) self.bn = nn.BatchNorm3d(out_channels) def forward(self, x): return F.relu(self.bn(self.conv(x)), inplace=True) class ConvBnReLU3D_CAM(nn.Module): def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1): super(ConvBnReLU3D_CAM, self).__init__() self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False) self.bn = nn.BatchNorm3d(out_channels) self.linear_agg = nn.Sequential( nn.Linear(out_channels, out_channels//2), nn.ReLU(), nn.Linear(out_channels//2, out_channels) ) def forward(self, input): x = self.conv(input) B,C,D,H,W = x.shape avg_attn = self.linear_agg(x.reshape(B,C,D*H*W).mean(2)) max_attn = self.linear_agg(x.reshape(B,C,D*H*W).max(2)[0]) # B C attn = F.sigmoid(max_attn+avg_attn)[:,:,None,None,None] # B C,1,1,1 x = x * attn return F.relu(self.bn(x+input), inplace=True) class ConvBnReLU3D_DCAM(nn.Module): def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1): super(ConvBnReLU3D_DCAM, self).__init__() self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False) self.bn = nn.BatchNorm3d(out_channels) self.linear_agg = nn.Sequential( nn.Linear(out_channels, out_channels//2), nn.ReLU(), nn.Linear(out_channels//2, out_channels) ) def forward(self, input): x = self.conv(input) B,C,D,H,W = x.shape avg_attn = self.linear_agg(x.reshape(B,C,D,H*W).mean(3).permute(0,2,1).reshape(B*D,C)).reshape(B,D,C).permute(0,2,1) max_attn = self.linear_agg(x.reshape(B,C,D,H*W).max(3)[0].permute(0,2,1).reshape(B*D,C)).reshape(B,D,C).permute(0,2,1) # B C D attn = F.sigmoid(max_attn+avg_attn)[:,:,:,None,None] # B C,D,1,1 x = x * attn return F.relu(self.bn(x+input), inplace=True) class ConvBnReLU3D_PAM(nn.Module): def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1): super(ConvBnReLU3D_PAM, self).__init__() self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False) self.bn = nn.BatchNorm3d(out_channels) self.pixel_conv = nn.Conv2d(2,1,7,stride=1,padding='same') def forward(self, input): x = self.conv(input) B,C,D,H,W = x.shape max_attn = x.reshape(B,C*D,H,W).max(1, keepdim=True)[0] avg_attn = x.reshape(B,C*D,H,W).mean(1, keepdim=True) # B 1 H W attn = F.sigmoid(self.pixel_conv(torch.cat([max_attn, avg_attn], dim=1)))[:,:,None,:,:] # B 1,1,H,W x = x * attn return F.relu(self.bn(x+input), inplace=True) class ConvBnReLU3D_PDAM(nn.Module): def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1): super(ConvBnReLU3D_PDAM, self).__init__() self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False) self.bn = nn.BatchNorm3d(out_channels) self.spatial_conv = nn.Conv3d(2,1,7,stride=1,padding='same') def forward(self, input): x = self.conv(input) B,C,D,H,W = x.shape max_attn = x.max(1, keepdim=True)[0] avg_attn = x.mean(1, keepdim=True) # B 1 D H W attn = F.sigmoid(self.spatial_conv(torch.cat([max_attn, avg_attn], dim=1))) # B 1,D,H,W x = x * attn return F.relu(self.bn(x+input), inplace=True) class Deconv3d(nn.Module): def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, relu=True, bn=True, bn_momentum=0.1, init_method="xavier", **kwargs): super(Deconv3d, self).__init__() self.out_channels = out_channels assert stride in [1, 2] self.stride = stride self.conv = nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride=stride, bias=(not bn), **kwargs) self.bn = nn.BatchNorm3d(out_channels, momentum=bn_momentum) if bn else None self.relu = relu def forward(self, x): y = self.conv(x) if self.bn is not None: x = self.bn(y) if self.relu: x = F.relu(x, inplace=True) return x def init_weights(self, init_method): init_uniform(self.conv, init_method) if self.bn is not None: init_bn(self.bn) class Conv2d(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1, relu=True, bn_momentum=0.1, init_method="xavier", gn=False, group_channel=8, **kwargs): super(Conv2d, self).__init__() bn = not gn self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, bias=(not bn), **kwargs) self.kernel_size = kernel_size self.stride = stride self.bn = nn.BatchNorm2d(out_channels, momentum=bn_momentum) if bn else None self.gn = nn.GroupNorm(int(max(1, out_channels / group_channel)), out_channels) if gn else None self.relu = relu def forward(self, x): x = self.conv(x) if self.bn is not None: x = self.bn(x) else: x = self.gn(x) if self.relu: x = F.relu(x, inplace=True) return x def init_weights(self, init_method): init_uniform(self.conv, init_method) if self.bn is not None: init_bn(self.bn) class Deconv2d(nn.Module): def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, relu=True, bn=True, bn_momentum=0.1, init_method="xavier", **kwargs): super(Deconv2d, self).__init__() self.out_channels = out_channels assert stride in [1, 2] self.stride = stride self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride, bias=(not bn), **kwargs) self.bn = nn.BatchNorm2d(out_channels, momentum=bn_momentum) if bn else None self.relu = relu class DeformConv2d(nn.Module): def __init__(self, inc, outc, kernel_size=3, padding=1, stride=1, bias=None, modulation=True): super(DeformConv2d, self).__init__() self.kernel_size = kernel_size self.padding = padding self.stride = stride self.zero_padding = nn.ZeroPad2d(padding) self.conv = nn.Conv2d(inc, outc, kernel_size=kernel_size, stride=kernel_size, bias=bias) self.p_conv = nn.Conv2d(inc, 2*kernel_size*kernel_size, kernel_size=3, padding=1, stride=stride) nn.init.constant_(self.p_conv.weight, 0) self.p_conv.register_backward_hook(self._set_lr) self.modulation = modulation if modulation: self.m_conv = nn.Conv2d(inc, kernel_size*kernel_size, kernel_size=3, padding=1, stride=stride) nn.init.constant_(self.m_conv.weight, 0) self.m_conv.register_backward_hook(self._set_lr) @staticmethod def _set_lr(module, grad_input, grad_output): grad_input = (grad_input[i] * 0.1 for i in range(len(grad_input))) grad_output = (grad_output[i] * 0.1 for i in range(len(grad_output))) def forward(self, x): offset = self.p_conv(x) if self.modulation: m = torch.sigmoid(self.m_conv(x)) dtype = offset.data.type() ks = self.kernel_size N = offset.size(1) // 2 if self.padding: x = self.zero_padding(x) # (b, 2N, h, w) p = self._get_p(offset, dtype) # (b, h, w, 2N) p = p.contiguous().permute(0, 2, 3, 1) q_lt = p.detach().floor() q_rb = q_lt + 1 q_lt = torch.cat([torch.clamp(q_lt[..., :N], 0, x.size(2)-1), torch.clamp(q_lt[..., N:], 0, x.size(3)-1)], dim=-1).long() q_rb = torch.cat([torch.clamp(q_rb[..., :N], 0, x.size(2)-1), torch.clamp(q_rb[..., N:], 0, x.size(3)-1)], dim=-1).long() q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], dim=-1) q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], dim=-1) # clip p p = torch.cat([torch.clamp(p[..., :N], 0, x.size(2)-1), torch.clamp(p[..., N:], 0, x.size(3)-1)], dim=-1) # bilinear kernel (b, h, w, N) g_lt = (1 + (q_lt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_lt[..., N:].type_as(p) - p[..., N:])) g_rb = (1 - (q_rb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_rb[..., N:].type_as(p) - p[..., N:])) g_lb = (1 + (q_lb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_lb[..., N:].type_as(p) - p[..., N:])) g_rt = (1 - (q_rt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_rt[..., N:].type_as(p) - p[..., N:])) # (b, c, h, w, N) x_q_lt = self._get_x_q(x, q_lt, N) x_q_rb = self._get_x_q(x, q_rb, N) x_q_lb = self._get_x_q(x, q_lb, N) x_q_rt = self._get_x_q(x, q_rt, N) # (b, c, h, w, N) x_offset = g_lt.unsqueeze(dim=1) * x_q_lt + \ g_rb.unsqueeze(dim=1) * x_q_rb + \ g_lb.unsqueeze(dim=1) * x_q_lb + \ g_rt.unsqueeze(dim=1) * x_q_rt # modulation if self.modulation: m = m.contiguous().permute(0, 2, 3, 1) m = m.unsqueeze(dim=1) m = torch.cat([m for _ in range(x_offset.size(1))], dim=1) x_offset *= m x_offset = self._reshape_x_offset(x_offset, ks) out = self.conv(x_offset) return out def _get_p_n(self, N, dtype): p_n_x, p_n_y = torch.meshgrid( torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1), torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1)) # (2N, 1) p_n = torch.cat([torch.flatten(p_n_x), torch.flatten(p_n_y)], 0) p_n = p_n.view(1, 2*N, 1, 1).type(dtype) return p_n def _get_p_0(self, h, w, N, dtype): p_0_x, p_0_y = torch.meshgrid( torch.arange(1, h*self.stride+1, self.stride), torch.arange(1, w*self.stride+1, self.stride)) p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1) p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1) p_0 = torch.cat([p_0_x, p_0_y], 1).type(dtype) return p_0 def _get_p(self, offset, dtype): N, h, w = offset.size(1)//2, offset.size(2), offset.size(3) # (1, 2N, 1, 1) p_n = self._get_p_n(N, dtype) # (1, 2N, h, w) p_0 = self._get_p_0(h, w, N, dtype) p = p_0 + p_n + offset return p def _get_x_q(self, x, q, N): b, h, w, _ = q.size() padded_w = x.size(3) c = x.size(1) # (b, c, h*w) x = x.contiguous().view(b, c, -1) # (b, h, w, N) index = q[..., :N]*padded_w + q[..., N:] # offset_x*w + offset_y # (b, c, h*w*N) index = index.contiguous().unsqueeze(dim=1).expand(-1, c, -1, -1, -1).contiguous().view(b, c, -1) x_offset = x.gather(dim=-1, index=index).contiguous().view(b, c, h, w, N) return x_offset @staticmethod def _reshape_x_offset(x_offset, ks): b, c, h, w, N = x_offset.size() x_offset = torch.cat([x_offset[..., s:s+ks].contiguous().view(b, c, h, w*ks) for s in range(0, N, ks)], dim=-1) x_offset = x_offset.contiguous().view(b, c, h*ks, w*ks) return x_offset def NA_DCN(in_channels, kernel_size=3, stride=1, dilation=1, bias=True, group_channel=8, gn=False): if gn: return nn.Sequential( nn.GroupNorm(int(max(1, in_channels / group_channel)), in_channels), nn.ReLU(inplace=True), # DeformConv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride, bias=bias), DeformConvPack(in_channels, in_channels, kernel_size=3, stride=1, padding=1, groups=1, deformable_groups=1, bias=False, im2col_step=16) ) else: return nn.Sequential( nn.BatchNorm2d(in_channels, momentum=0.1), nn.ReLU(inplace=True), # DeformConv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride, bias=bias), DeformConvPack(in_channels, in_channels, kernel_size=3, stride=1, padding=1, groups=1, deformable_groups=1, bias=False, im2col_step=16) ) class FPN4(nn.Module): """ FPN aligncorners downsample 4x""" def __init__(self, base_channels, gn=False, dcn=False): super(FPN4, self).__init__() self.base_channels = base_channels self.conv0 = nn.Sequential( Conv2d(3, base_channels, 3, 1, padding=1, gn=gn), Conv2d(base_channels, base_channels, 3, 1, padding=1, gn=gn), ) self.conv1 = nn.Sequential( Conv2d(base_channels, base_channels * 2, 5, stride=2, padding=2, gn=gn), Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1, gn=gn), Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1, gn=gn), ) self.conv2 = nn.Sequential( Conv2d(base_channels * 2, base_channels * 4, 5, stride=2, padding=2, gn=gn), Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1, gn=gn), Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1, gn=gn), ) self.conv3 = nn.Sequential( Conv2d(base_channels * 4, base_channels * 8, 5, stride=2, padding=2, gn=gn), Conv2d(base_channels * 8, base_channels * 8, 3, 1, padding=1, gn=gn), Conv2d(base_channels * 8, base_channels * 8, 3, 1, padding=1, gn=gn), ) self.out_channels = [8 * base_channels] final_chs = base_channels * 8 self.inner1 = nn.Conv2d(base_channels * 4, final_chs, 1, bias=True) self.inner2 = nn.Conv2d(base_channels * 2, final_chs, 1, bias=True) self.inner3 = nn.Conv2d(base_channels * 1, final_chs, 1, bias=True) self.out1 = nn.Conv2d(final_chs, base_channels * 8, 1, bias=False) self.out2 = nn.Conv2d(final_chs, base_channels * 4, 3, padding=1, bias=False) self.out3 = nn.Conv2d(final_chs, base_channels * 2, 3, padding=1, bias=False) self.out4 = nn.Conv2d(final_chs, base_channels, 3, padding=1, bias=False) self.dcn = dcn if self.dcn: self.dcn1 = NA_DCN(base_channels * 8, 3, gn=gn) self.dcn2 = NA_DCN(base_channels * 4, 3, gn=gn) self.dcn3 = NA_DCN(base_channels * 2, 3, gn=gn) self.dcn4 = NA_DCN(base_channels * 1, 3, gn=gn) self.out_channels.append(base_channels * 4) self.out_channels.append(base_channels * 2) self.out_channels.append(base_channels) def forward(self, x): conv0 = self.conv0(x) conv1 = self.conv1(conv0) conv2 = self.conv2(conv1) conv3 = self.conv3(conv2) intra_feat = conv3 outputs = {} out1 = self.out1(intra_feat) intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner1(conv2) out2 = self.out2(intra_feat) intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner2(conv1) out3 = self.out3(intra_feat) intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner3(conv0) out4 = self.out4(intra_feat) if self.dcn: out1 = self.dcn1(out1) out2 = self.dcn2(out2) out3 = self.dcn3(out3) out4 = self.dcn4(out4) outputs["stage1"] = out1 outputs["stage2"] = out2 outputs["stage3"] = out3 outputs["stage4"] = out4 return outputs class LayerNorm(nn.Module): def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): super().__init__() self.weight = nn.Parameter(torch.ones(normalized_shape)) self.bias = nn.Parameter(torch.zeros(normalized_shape)) self.eps = eps self.data_format = data_format if self.data_format not in ["channels_last", "channels_first"]: raise NotImplementedError self.normalized_shape = (normalized_shape, ) def forward(self, x): if self.data_format == "channels_last": return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) elif self.data_format == "channels_first": u = x.mean(1, keepdim=True) s = (x - u).pow(2).mean(1, keepdim=True) x = (x - u) / torch.sqrt(s + self.eps) x = self.weight[:, None, None] * x + self.bias[:, None, None] return x class convnext_block(nn.Module): def __init__(self, dim, layer_scale_init_value=1e-6): super().__init__() self.dwconv = nn.Conv2d(dim, 2*dim, kernel_size=7, stride=2, padding=3, groups=dim) # depthwise conv self.norm = LayerNorm(2*dim, eps=1e-6) self.pwconv1 = nn.Linear(2*dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers self.act = nn.GELU() self.pwconv2 = nn.Linear(4 * dim, 2*dim) self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((2*dim)), requires_grad=True) if layer_scale_init_value > 0 else None def forward(self, x): input = x x = self.dwconv(x) x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) x = self.norm(x) x = self.pwconv1(x) x = self.act(x) x = self.pwconv2(x) if self.gamma is not None: x = self.gamma * x x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) # x = input + x return x class convnext4_block(nn.Module): def __init__(self, dim, layer_scale_init_value=1e-6): super().__init__() self.sconv = nn.Conv2d(dim, 2*dim, kernel_size=2, stride=2, padding=0) # stride=2 conv self.dwconv = nn.Conv2d(2*dim, 2*dim, kernel_size=7, stride=1, padding=3, groups=dim) # depthwise conv self.norm = LayerNorm(2*dim, eps=1e-6) self.pwconv1 = nn.Linear(2*dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers self.act = nn.GELU() self.pwconv2 = nn.Linear(4 * dim, 2*dim) self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((2*dim)), requires_grad=True) if layer_scale_init_value > 0 else None def forward(self, x): input = self.sconv(x) x = self.dwconv(input) x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) x = self.norm(x) x = self.pwconv1(x) x = self.act(x) x = self.pwconv2(x) if self.gamma is not None: x = self.gamma * x x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) x = input + x return x class FPN4_convnext(nn.Module): """ FPN aligncorners downsample 4x""" def __init__(self, base_channels, gn=False, dcn=False): super(FPN4_convnext, self).__init__() self.base_channels = base_channels self.conv0 = nn.Sequential( Conv2d(3, base_channels, 3, 1, padding=1, gn=gn), Conv2d(base_channels, base_channels, 3, 1, padding=1, gn=gn), ) self.conv1 = convnext_block(base_channels) self.conv2 = convnext_block(2*base_channels) self.conv3 = convnext_block(4*base_channels) self.out_channels = [8 * base_channels] final_chs = base_channels * 8 self.inner1 = nn.Conv2d(base_channels * 4, final_chs, 1, bias=True) self.inner2 = nn.Conv2d(base_channels * 2, final_chs, 1, bias=True) self.inner3 = nn.Conv2d(base_channels * 1, final_chs, 1, bias=True) self.out1 = nn.Conv2d(final_chs, base_channels * 8, 1, bias=False) self.out2 = nn.Conv2d(final_chs, base_channels * 4, 3, padding=1, bias=False) self.out3 = nn.Conv2d(final_chs, base_channels * 2, 3, padding=1, bias=False) self.out4 = nn.Conv2d(final_chs, base_channels, 3, padding=1, bias=False) self.dcn = dcn if self.dcn: self.dcn1 = NA_DCN(base_channels * 8, 3, gn=gn) self.dcn2 = NA_DCN(base_channels * 4, 3, gn=gn) self.dcn3 = NA_DCN(base_channels * 2, 3, gn=gn) self.dcn4 = NA_DCN(base_channels * 1, 3, gn=gn) self.out_channels.append(base_channels * 4) self.out_channels.append(base_channels * 2) self.out_channels.append(base_channels) def forward(self, x): conv0 = self.conv0(x) conv1 = self.conv1(conv0) conv2 = self.conv2(conv1) conv3 = self.conv3(conv2) intra_feat = conv3 outputs = {} out1 = self.out1(intra_feat) intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner1(conv2) out2 = self.out2(intra_feat) intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner2(conv1) out3 = self.out3(intra_feat) intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner3(conv0) out4 = self.out4(intra_feat) if self.dcn: out1 = self.dcn1(out1) out2 = self.dcn2(out2) out3 = self.dcn3(out3) out4 = self.dcn4(out4) outputs["stage1"] = out1 outputs["stage2"] = out2 outputs["stage3"] = out3 outputs["stage4"] = out4 return outputs class FPN4_convnext4(nn.Module): """ FPN aligncorners downsample 4x""" def __init__(self, base_channels, gn=False, dcn=False): super(FPN4_convnext4, self).__init__() self.base_channels = base_channels self.conv0 = nn.Sequential( Conv2d(3, base_channels, 3, 1, padding=1, gn=gn), Conv2d(base_channels, base_channels, 3, 1, padding=1, gn=gn), ) self.conv1 = convnext4_block(base_channels) self.conv2 = convnext4_block(2*base_channels) self.conv3 = convnext4_block(4*base_channels) self.out_channels = [8 * base_channels] final_chs = base_channels * 8 self.inner1 = nn.Conv2d(base_channels * 4, final_chs, 1, bias=True) self.inner2 = nn.Conv2d(base_channels * 2, final_chs, 1, bias=True) self.inner3 = nn.Conv2d(base_channels * 1, final_chs, 1, bias=True) self.out1 = nn.Conv2d(final_chs, base_channels * 8, 1, bias=False) self.out2 = nn.Conv2d(final_chs, base_channels * 4, 3, padding=1, bias=False) self.out3 = nn.Conv2d(final_chs, base_channels * 2, 3, padding=1, bias=False) self.out4 = nn.Conv2d(final_chs, base_channels, 3, padding=1, bias=False) self.dcn = dcn if self.dcn: self.dcn1 = NA_DCN(base_channels * 8, 3, gn=gn) self.dcn2 = NA_DCN(base_channels * 4, 3, gn=gn) self.dcn3 = NA_DCN(base_channels * 2, 3, gn=gn) self.dcn4 = NA_DCN(base_channels * 1, 3, gn=gn) self.out_channels.append(base_channels * 4) self.out_channels.append(base_channels * 2) self.out_channels.append(base_channels) def forward(self, x): conv0 = self.conv0(x) conv1 = self.conv1(conv0) conv2 = self.conv2(conv1) conv3 = self.conv3(conv2) intra_feat = conv3 outputs = {} out1 = self.out1(intra_feat) intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner1(conv2) out2 = self.out2(intra_feat) intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner2(conv1) out3 = self.out3(intra_feat) intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner3(conv0) out4 = self.out4(intra_feat) if self.dcn: out1 = self.dcn1(out1) out2 = self.dcn2(out2) out3 = self.dcn3(out3) out4 = self.dcn4(out4) outputs["stage1"] = out1 outputs["stage2"] = out2 outputs["stage3"] = out3 outputs["stage4"] = out4 return outputs class ASFF(nn.Module): def __init__(self, level): super(ASFF, self).__init__() self.level = level self.dim = [64,32,16,8] self.inter_dim = self.dim[self.level] if level==0: self.stride_level_1 = Conv2d(32, 64, 3, stride=2, padding=1) self.stride_level_2 = Conv2d(16, 64, 3, stride=2, padding=1) self.stride_level_3 = Conv2d(8, 64, 3, stride=2, padding=1) self.expand = Conv2d(64, 64, 3, stride=1, padding=1) elif level==1: self.compress_level_0 = Conv2d(64, 32, 1, stride=1, padding=0) self.stride_level_2 = Conv2d(16, 32, 3, stride=2, padding=1) self.stride_level_3 = Conv2d(8, 32, 3, stride=2, padding=1) self.expand = Conv2d(32, 32, 3, stride=1, padding=1) elif level==2: self.compress_level_0 = Conv2d(64, 16, 1, stride=1, padding=0) self.compress_level_1 = Conv2d(32, 16, 1, stride=1, padding=0) self.stride_level_3 = Conv2d(8, 16, 3, stride=2, padding=1) self.expand = Conv2d(16, 16, 3, stride=1, padding=1) elif level==3: self.compress_level_0 = Conv2d(64, 8, 1, stride=1, padding=0) self.compress_level_1 = Conv2d(32, 8, 1, stride=1, padding=0) self.compress_level_2 = Conv2d(16, 8, 1, stride=1, padding=0) self.expand = Conv2d(8, 8, 3, stride=1, padding=1) self.weight_level_0 = Conv2d(self.dim[level], 8, 1, 1, 0) self.weight_level_1 = Conv2d(self.dim[level], 8, 1, 1, 0) self.weight_level_2 = Conv2d(self.dim[level], 8, 1, 1, 0) self.weight_level_3 = Conv2d(self.dim[level], 8, 1, 1, 0) self.weight_levels = nn.Conv2d(32, 4, kernel_size=1, stride=1, padding=0) def forward(self, x_level_0, x_level_1, x_level_2, x_level_3): if self.level==0: level_0_resized = x_level_0 level_1_resized = self.stride_level_1(x_level_1) level_2_downsampled_inter = F.max_pool2d(x_level_2, 2, stride=2, padding=0) level_2_resized = self.stride_level_2(level_2_downsampled_inter) level_3_downsampled_inter = F.max_pool2d(x_level_3, 4, stride=4, padding=0) level_3_resized = self.stride_level_3(level_3_downsampled_inter) elif self.level==1: level_0_compressed = self.compress_level_0(x_level_0) level_0_resized = F.interpolate(level_0_compressed, scale_factor=2, mode='nearest') level_1_resized = x_level_1 level_2_resized = self.stride_level_2(x_level_2) level_3_downsampled_inter = F.max_pool2d(x_level_3, 2, stride=2, padding=0) level_3_resized = self.stride_level_3(level_3_downsampled_inter) elif self.level==2: level_0_compressed = self.compress_level_0(x_level_0) level_0_resized = F.interpolate(level_0_compressed, scale_factor=4, mode='nearest') level_1_compressed = self.compress_level_1(x_level_1) level_1_resized = F.interpolate(level_1_compressed, scale_factor=2, mode='nearest') level_2_resized = x_level_2 level_3_resized = self.stride_level_3(x_level_3) elif self.level==3: level_0_compressed = self.compress_level_0(x_level_0) level_0_resized = F.interpolate(level_0_compressed, scale_factor=8, mode='nearest') level_1_compressed = self.compress_level_1(x_level_1) level_1_resized = F.interpolate(level_1_compressed, scale_factor=4, mode='nearest') level_2_compressed = self.compress_level_2(x_level_2) level_2_resized = F.interpolate(level_2_compressed, scale_factor=2, mode='nearest') level_3_resized = x_level_3 level_0_weight_v = self.weight_level_0(level_0_resized) level_1_weight_v = self.weight_level_1(level_1_resized) level_2_weight_v = self.weight_level_2(level_2_resized) level_3_weight_v = self.weight_level_3(level_3_resized) levels_weight_v = torch.cat((level_0_weight_v, level_1_weight_v, level_2_weight_v, level_3_weight_v),1) levels_weight = self.weight_levels(levels_weight_v) levels_weight = F.softmax(levels_weight, dim=1) fused_out_reduced = level_0_resized * levels_weight[:,0:1,:,:]+\ level_1_resized * levels_weight[:,1:2,:,:]+\ level_2_resized * levels_weight[:,2:3,:,:]+\ level_3_resized * levels_weight[:,3:,:,:] out = self.expand(fused_out_reduced) return out class FullImageEncoder(nn.Module): def __init__(self, h, w, kernel_size): super(FullImageEncoder, self).__init__() self.global_pooling = nn.AvgPool2d(kernel_size, stride=kernel_size, padding=kernel_size // 2) # KITTI 16 16 self.dropout = nn.Dropout2d(p=0.5) self.h = h // kernel_size + 1 self.w = w // kernel_size + 1 # print("h=", self.h, " w=", self.w, h, w) self.global_fc = nn.Linear(2048 * self.h * self.w, 512) # kitti 4x5 self.relu = nn.ReLU(inplace=True) self.conv1 = nn.Conv2d(512, 512, 1) # 1x1 卷积 def forward(self, x): # print('x size:', x.size()) x1 = self.global_pooling(x) # print('# x1 size:', x1.size()) x2 = self.dropout(x1) x3 = x2.view(-1, 2048 * self.h * self.w) # kitti 4x5 x4 = self.relu(self.global_fc(x3)) # print('# x4 size:', x4.size()) x4 = x4.view(-1, 512, 1, 1) # print('# x4 size:', x4.size()) x5 = self.conv1(x4) # out = self.upsample(x5) return x5 class mono_depth_decoder(nn.Module): def __init__(self): super(mono_depth_decoder, self).__init__() self.convblocks = nn.ModuleList( [Conv2d(64, 32, 3, 1, padding=1), Conv2d(32, 16, 3, 1, padding=1), Conv2d(16, 8, 3, 1, padding=1)] ) self.conv3x3 = nn.ModuleList( [nn.Conv2d(64, 1, 3, 1, 1), nn.Conv2d(32, 1, 3, 1, 1), nn.Conv2d(16, 1, 3, 1, 1)] ) self.sigmoid = nn.Sigmoid() def forward(self, outputs, d_min, d_max): """ d_max: B """ for i in range(1,4): # 1 2 3 mono_small_feat = outputs['stage{}'.format(i)]['mono_feat'] mono_large_feat = outputs['stage{}'.format(i+1)]['mono_feat'] mono_small_feat = self.convblocks[i-1](mono_small_feat) mono_small_feat = F.interpolate(mono_small_feat, scale_factor=2, mode="nearest") mono_feat = self.conv3x3[i-1](torch.cat([mono_small_feat, mono_large_feat], 1)) # B C H W disp = self.sigmoid(mono_feat) min_disp = (1 / d_max)[:,None,None,None] # B 1 1 1 max_disp = (1 / d_min)[:,None,None,None] scaled_disp = min_disp + (max_disp - min_disp) * disp depth = 1 / scaled_disp outputs['stage{}'.format(i+1)]['mono_depth'] = depth.squeeze(1) return outputs class reg2d(nn.Module): def __init__(self, input_channel=128, base_channel=32, conv_name='ConvBnReLU3D'): super(reg2d, self).__init__() module = importlib.import_module("models.mvs4net_utils") stride_conv_name = 'ConvBnReLU3D' self.conv0 = getattr(module, stride_conv_name)(input_channel, base_channel, kernel_size=(1,3,3), pad=(0,1,1)) self.conv1 = getattr(module, stride_conv_name)(base_channel, base_channel*2, kernel_size=(1,3,3), stride=(1,2,2), pad=(0,1,1)) self.conv2 = getattr(module, conv_name)(base_channel*2, base_channel*2) self.conv3 = getattr(module, stride_conv_name)(base_channel*2, base_channel*4, kernel_size=(1,3,3), stride=(1,2,2), pad=(0,1,1)) self.conv4 = getattr(module, conv_name)(base_channel*4, base_channel*4) self.conv5 = getattr(module, stride_conv_name)(base_channel*4, base_channel*8, kernel_size=(1,3,3), stride=(1,2,2), pad=(0,1,1)) self.conv6 = getattr(module, conv_name)(base_channel*8, base_channel*8) self.conv7 = nn.Sequential( nn.ConvTranspose3d(base_channel*8, base_channel*4, kernel_size=(1,3,3), padding=(0,1,1), output_padding=(0,1,1), stride=(1,2,2), bias=False), nn.BatchNorm3d(base_channel*4), nn.ReLU(inplace=True)) self.conv9 = nn.Sequential( nn.ConvTranspose3d(base_channel*4, base_channel*2, kernel_size=(1,3,3), padding=(0,1,1), output_padding=(0,1,1), stride=(1,2,2), bias=False), nn.BatchNorm3d(base_channel*2), nn.ReLU(inplace=True)) self.conv11 = nn.Sequential( nn.ConvTranspose3d(base_channel*2, base_channel, kernel_size=(1,3,3), padding=(0,1,1), output_padding=(0,1,1), stride=(1,2,2), bias=False), nn.BatchNorm3d(base_channel), nn.ReLU(inplace=True)) self.prob = nn.Conv3d(8, 1, 1, stride=1, padding=0) def forward(self, x): conv0 = self.conv0(x) conv2 = self.conv2(self.conv1(conv0)) conv4 = self.conv4(self.conv3(conv2)) x = self.conv6(self.conv5(conv4)) x = conv4 + self.conv7(x) x = conv2 + self.conv9(x) x = conv0 + self.conv11(x) x = self.prob(x) return x.squeeze(1) class reg3d(nn.Module): def __init__(self, in_channels, base_channels, down_size=3): super(reg3d, self).__init__() self.down_size = down_size self.conv0 = ConvBnReLU3D(in_channels, base_channels, kernel_size=3, pad=1) self.conv1 = ConvBnReLU3D(base_channels, base_channels*2, kernel_size=3, stride=2, pad=1) self.conv2 = ConvBnReLU3D(base_channels*2, base_channels*2) if down_size >= 2: self.conv3 = ConvBnReLU3D(base_channels*2, base_channels*4, kernel_size=3, stride=2, pad=1) self.conv4 = ConvBnReLU3D(base_channels*4, base_channels*4) if down_size >= 3: self.conv5 = ConvBnReLU3D(base_channels*4, base_channels*8, kernel_size=3, stride=2, pad=1) self.conv6 = ConvBnReLU3D(base_channels*8, base_channels*8) self.conv7 = nn.Sequential( nn.ConvTranspose3d(base_channels*8, base_channels*4, kernel_size=3, padding=1, output_padding=1, stride=2, bias=False), nn.BatchNorm3d(base_channels*4), nn.ReLU(inplace=True)) if down_size >= 2: self.conv9 = nn.Sequential( nn.ConvTranspose3d(base_channels*4, base_channels*2, kernel_size=3, padding=1, output_padding=1, stride=2, bias=False), nn.BatchNorm3d(base_channels*2), nn.ReLU(inplace=True)) self.conv11 = nn.Sequential( nn.ConvTranspose3d(base_channels*2, base_channels, kernel_size=3, padding=1, output_padding=1, stride=2, bias=False), nn.BatchNorm3d(base_channels), nn.ReLU(inplace=True)) self.prob = nn.Conv3d(base_channels, 1, 3, stride=1, padding=1, bias=False) def forward(self, x): if self.down_size==3: conv0 = self.conv0(x) conv2 = self.conv2(self.conv1(conv0)) conv4 = self.conv4(self.conv3(conv2)) x = self.conv6(self.conv5(conv4)) x = conv4 + self.conv7(x) x = conv2 + self.conv9(x) x = conv0 + self.conv11(x) x = self.prob(x) elif self.down_size==2: conv0 = self.conv0(x) conv2 = self.conv2(self.conv1(conv0)) x = self.conv4(self.conv3(conv2)) x = conv2 + self.conv9(x) x = conv0 + self.conv11(x) x = self.prob(x) else: conv0 = self.conv0(x) x = self.conv2(self.conv1(conv0)) x = conv0 + self.conv11(x) x = self.prob(x) return x.squeeze(1) # B D H W class PosEncSine(nn.Module): def __init__(self, temperature=1000): super(PosEncSine, self).__init__() self.temperature = temperature def forward(self, x, depth): # depth : B D H W with torch.no_grad(): B,C,D,H,W = x.shape depth = depth.permute(0,2,3,1).reshape(B*H*W, D) / self.temperature # BHW D pos = torch.stack([torch.sin(i * math.pi * depth) for i in range(C//2)] + [torch.cos(i * math.pi * depth) for i in range(C//2)], dim=-1) # BHW,D,C pos = pos.reshape(B,H,W,D,C).permute(0,4,3,1,2) # B C D H W x = x + pos return x class PosEncLearned(nn.Module): """ Absolute pos embedding, learned. """ def __init__(self, D, C): super().__init__() self.D = D self.C = C self.depth_embed = nn.Parameter(torch.Tensor(C, self.D)) self.reset_parameters() def reset_parameters(self): nn.init.uniform_(self.depth_embed) def forward(self, x, **kwargs): B,C,D,H,W = x.shape pos = self.depth_embed[None,:,:,None,None].repeat(B,1,1,H,W) # B C D H W x = x + pos return x class stagenet(nn.Module): def __init__(self, inverse_depth=False, mono=False, attn_fuse_d=True, vis_ETA=False, attn_temp=1): super(stagenet, self).__init__() self.inverse_depth = inverse_depth self.mono = mono self.attn_fuse_d = attn_fuse_d self.vis_ETA = vis_ETA self.attn_temp = attn_temp def forward(self, features, proj_matrices, depth_hypo, regnet, stage_idx, group_cor=False, group_cor_dim=8, split_itv=1, fn=None): # step 1. feature extraction proj_matrices = torch.unbind(proj_matrices, 1) ref_feature, src_features = features[0], features[1:] ref_proj, src_projs = proj_matrices[0], proj_matrices[1:] B,D,H,W = depth_hypo.shape C = ref_feature.shape[1] ref_volume = ref_feature.unsqueeze(2).repeat(1, 1, D, 1, 1) cor_weight_sum = 1e-8 cor_feats = 0 # step 2. Epipolar Transformer Aggregation for src_idx, (src_fea, src_proj) in enumerate(zip(src_features, src_projs)): if self.vis_ETA: scan_name = fn[0].split('/')[0] image_name = fn[0].split('/')[2][:-2] save_fn = './debug_figs/vis_ETA/{}_stage{}_src{}'.format(scan_name+'_'+image_name, stage_idx, src_idx) else: save_fn = None src_proj_new = src_proj[:, 0].clone() src_proj_new[:, :3, :4] = torch.matmul(src_proj[:, 1, :3, :3], src_proj[:, 0, :3, :4]) ref_proj_new = ref_proj[:, 0].clone() ref_proj_new[:, :3, :4] = torch.matmul(ref_proj[:, 1, :3, :3], ref_proj[:, 0, :3, :4]) warped_src = homo_warping(src_fea, src_proj_new, ref_proj_new, depth_hypo, self.vis_ETA, save_fn) # B C D H W if group_cor: warped_src = warped_src.reshape(B, group_cor_dim, C//group_cor_dim, D, H, W) ref_volume = ref_volume.reshape(B, group_cor_dim, C//group_cor_dim, D, H, W) cor_feat = (warped_src * ref_volume).mean(2) # B G D H W else: cor_feat = (ref_volume - warped_src)**2 # B C D H W del warped_src, src_proj, src_fea if self.vis_ETA: vis_weight = torch.softmax(cor_feat.sum(1), 1).detach().cpu().numpy() np.save(save_fn, vis_weight) if not self.attn_fuse_d: cor_weight = torch.softmax(cor_feat.sum(1), 1).max(1)[0] # B H W cor_weight_sum += cor_weight # B H W cor_feats += cor_weight.unsqueeze(1).unsqueeze(1) * cor_feat # B C D H W else: cor_weight = torch.softmax(cor_feat.sum(1) / self.attn_temp, 1) / math.sqrt(C) # B D H W cor_weight_sum += cor_weight # B D H W cor_feats += cor_weight.unsqueeze(1) * cor_feat # B C D H W del cor_weight, cor_feat if not self.attn_fuse_d: cor_feats = cor_feats / cor_weight_sum.unsqueeze(1).unsqueeze(1) # B C D H W else: cor_feats = cor_feats / cor_weight_sum.unsqueeze(1) # B C D H W del cor_weight_sum, src_features # step 3. regularization attn_weight = regnet(cor_feats) # B D H W del cor_feats attn_weight = F.softmax(attn_weight, dim=1) # B D H W # step 4. depth argmax attn_max_indices = attn_weight.max(1, keepdim=True)[1] # B 1 H W depth = torch.gather(depth_hypo, 1, attn_max_indices).squeeze(1) # B H W if not self.training: with torch.no_grad(): photometric_confidence = attn_weight.max(1)[0] # B H W photometric_confidence = F.interpolate(photometric_confidence.unsqueeze(1), scale_factor=2**(3-stage_idx), mode='bilinear', align_corners=True).squeeze(1) else: photometric_confidence = torch.tensor(0.0, dtype=torch.float32, device=ref_feature.device, requires_grad=False) ret_dict = {"depth": depth, "photometric_confidence": photometric_confidence, "hypo_depth": depth_hypo, "attn_weight": attn_weight} if self.inverse_depth: last_depth_itv = 1./depth_hypo[:,2,:,:] - 1./depth_hypo[:,1,:,:] inverse_min_depth = 1/depth + split_itv * last_depth_itv # B H W inverse_max_depth = 1/depth - split_itv * last_depth_itv # B H W ret_dict['inverse_min_depth'] = inverse_min_depth ret_dict['inverse_max_depth'] = inverse_max_depth # if self.mono and self.training: if self.mono: ret_dict['mono_feat'] = ref_feature # B C H W return ret_dict def sinkhorn(gt_depth, hypo_depth, attn_weight, mask, iters, eps=1, continuous=False): """ gt_depth: B H W hypo_depth: B D H W attn_weight: B D H W mask: B H W """ B,D,H,W = attn_weight.shape if not continuous: D_map = torch.stack([torch.arange(-i,D-i,1, dtype=torch.float32, device=gt_depth.device) for i in range(D)], dim=1).abs() D_map = D_map[None,None,:,:].repeat(B,H*W,1,1) # B HW D D gt_indices = torch.abs(hypo_depth - gt_depth[:,None,:,:]).min(1)[1].squeeze(1).reshape(B*H*W, 1) # BHW, 1 gt_dist = torch.zeros_like(hypo_depth).permute(0,2,3,1).reshape(B*H*W, D) gt_dist.scatter_add_(1,gt_indices,torch.ones([gt_dist.shape[0],1], dtype=gt_dist.dtype, device=gt_dist.device)) gt_dist = gt_dist.reshape(B,H*W,D) # B HW D else: gt_dist = torch.zeros((B,H*W,D+1), dtype=torch.float32, device=gt_depth.device, requires_grad=False) # B HW D+1 gt_dist[:,:,-1] = 1 D_map = torch.zeros((B,D,D+1), dtype=torch.float32, device=gt_depth.device, requires_grad=False) # B D D+1 D_map[:, :D, :D] = torch.stack([torch.arange(-i,D-i,1, dtype=torch.float32, device=gt_depth.device) for i in range(D)], dim=1).abs().unsqueeze(0) # B D D+1 D_map = D_map[:,None,None,:,:].repeat(1,H,W,1,1) # B H W D D+1 itv = 1/hypo_depth[:,2,:,:] - 1/hypo_depth[:,1,:,:] # B H W gt_bin_distance_ = (1/gt_depth - 1/hypo_depth[:,0,:,:]) / itv # B H W #FIXME hard code 100 gt_bin_distance_[~mask] = 10 gt_bin_distance = torch.stack([(gt_bin_distance_ - i).abs() for i in range(D)], dim=1).permute(0,2,3,1) # B H W D D_map[:,:,:,:,-1] = gt_bin_distance D_map = D_map.reshape(B,H*W,D,1+D) # B HW D D+1 pred_dist = attn_weight.permute(0,2,3,1).reshape(B,H*W,D) # B HW D # map to log space for stability log_mu = (gt_dist+1e-12).log() log_nu = (pred_dist+1e-12).log() # B HW D or D+1 u, v = torch.zeros_like(log_nu), torch.zeros_like(log_mu) for _ in range(iters): # scale v first then u to ensure row sum is 1, col sum slightly larger than 1 v = log_mu - torch.logsumexp(D_map/eps + u.unsqueeze(3), dim=2) # log(sum(exp())) u = log_nu - torch.logsumexp(D_map/eps + v.unsqueeze(2), dim=3) # convert back from log space, recover probabilities by normalization 2W T_map = (D_map/eps + u.unsqueeze(3) + v.unsqueeze(2)).exp() # B HW D D loss = (T_map * D_map).reshape(B*H*W,-1)[mask.reshape(-1)].sum(-1).mean() return T_map, loss ================================================ FILE: requirements.txt ================================================ torch==1.9.0 torchvision==0.10.0 numpy pillow tensorboardX opencv-python plyfile ================================================ FILE: scripts/test_dtu.sh ================================================ #!/usr/bin/env bash DTU_TESTPATH="/mnt/cfs/algorithm/public_data/mvs/dtu_test" DTU_TESTLIST="lists/dtu/test.txt" DTU_size=$1 exp=$2 PY_ARGS=${@:3} DTU_LOG_DIR="./checkpoints/dtu/"$exp if [ ! -d $DTU_LOG_DIR ]; then mkdir -p $DTU_LOG_DIR fi DTU_CKPT_FILE=$DTU_LOG_DIR"/finalmodel.ckpt" DTU_OUT_DIR="./outputs/dtu/"$exp if [ $DTU_size = "raw" ] ; then python test_mvs4.py --dataset=general_eval4 --batch_size=1 --testpath=$DTU_TESTPATH --testlist=$DTU_TESTLIST --loadckpt $DTU_CKPT_FILE --interval_scale 1.06 --outdir $DTU_OUT_DIR\ --use_raw_train --thres_view 4 --conf 0.5 --group_cor --attn_temp 2 --inverse_depth $PY_ARGS | tee -a $DTU_LOG_DIR/log_test.txt else python test_mvs4.py --dataset=general_eval4 --batch_size=1 --testpath=$DTU_TESTPATH --testlist=$DTU_TESTLIST --loadckpt $DTU_CKPT_FILE --interval_scale 1.06 --outdir $DTU_OUT_DIR\ --thres_view 4 --conf 0.5 --group_cor --attn_temp 2 --inverse_depth $PY_ARGS | tee -a $DTU_LOG_DIR/log_test.txt fi ================================================ FILE: scripts/train_dtu.sh ================================================ #!/usr/bin/env bash DTU_TRAINING="/mnt/cfs/algorithm/public_data/mvs/mvs_training/dtu" DTU_TRAINLIST="lists/dtu/train.txt" DTU_TESTLIST="lists/dtu/test.txt" DTU_trainsize=$1 exp=$2 PY_ARGS=${@:3} DTU_LOG_DIR="./checkpoints/dtu/"$exp if [ ! -d $DTU_LOG_DIR ]; then mkdir -p $DTU_LOG_DIR fi DTU_CKPT_FILE=$DTU_LOG_DIR"/finalmodel.ckpt" DTU_OUT_DIR="./outputs/dtu/"$exp if [ $DTU_trainsize = "raw" ] ; then python -m torch.distributed.launch --nproc_per_node=4 train_mvs4.py --logdir $DTU_LOG_DIR --dataset=dtu_yao4 --batch_size=2 --trainpath=$DTU_TRAINING --summary_freq 100 \ --group_cor --inverse_depth --rt --mono --attn_temp 2 --use_raw_train --trainlist $DTU_TRAINLIST --testlist $DTU_TESTLIST $PY_ARGS | tee -a $DTU_LOG_DIR/log.txt else python -m torch.distributed.launch --nproc_per_node=4 train_mvs4.py --logdir $DTU_LOG_DIR --dataset=dtu_yao4 --batch_size=2 --trainpath=$DTU_TRAINING --summary_freq 100 \ --group_cor --inverse_depth --rt --mono --attn_temp 2 --trainlist $DTU_TRAINLIST --testlist $DTU_TESTLIST $PY_ARGS | tee -a $DTU_LOG_DIR/log.txt fi ================================================ FILE: test_mvs4.py ================================================ import argparse, os, time, sys, gc, cv2 import torch import torch.nn as nn import torch.nn.parallel import torch.backends.cudnn as cudnn from torch.utils.data import DataLoader import torch.nn.functional as F import numpy as np from datasets import find_dataset_def from models import * from utils import * from datasets.data_io import read_pfm, save_pfm from plyfile import PlyData, PlyElement from PIL import Image from multiprocessing import Pool from functools import partial import signal cudnn.benchmark = True parser = argparse.ArgumentParser(description='Predict depth, filter, and fuse') parser.add_argument('--model', default='mvsnet', help='select model') parser.add_argument('--dataset', default='dtu_yao_eval', help='select dataset') parser.add_argument('--testpath', help='testing data dir for some scenes') parser.add_argument('--testlist', help='testing scene list') parser.add_argument('--batch_size', type=int, default=1, help='testing batch size') parser.add_argument('--loadckpt', default=None, help='load a specific checkpoint') parser.add_argument('--outdir', default='./outputs', help='output dir') parser.add_argument('--share_cr', action='store_true', help='whether share the cost volume regularization') parser.add_argument('--ndepths', type=str, default="8,8,4,4", help='ndepths') parser.add_argument('--depth_inter_r', type=str, default="0.5,0.5,0.5,1", help='depth_intervals_ratio') parser.add_argument('--interval_scale', type=float, required=True, help='the depth interval scale') parser.add_argument('--num_view', type=int, default=5, help='num of view') parser.add_argument('--max_h', type=int, default=864, help='testing max h') parser.add_argument('--max_w', type=int, default=1152, help='testing max w') parser.add_argument('--fix_res', action='store_true', help='scene all using same res') parser.add_argument('--num_worker', type=int, default=4, help='depth_filer worker') parser.add_argument('--save_freq', type=int, default=20, help='save freq of local pcd') parser.add_argument('--filter_method', type=str, default='normal', choices=["gipuma", "normal"], help="filter method") #filter parser.add_argument('--conf', type=float, default=0.9, help='prob confidence') parser.add_argument('--thres_view', type=int, default=5, help='threshold of num view') parser.add_argument("--fpn_base_channel", type=int, default=8) parser.add_argument("--reg_channel", type=int, default=8) parser.add_argument('--reg_mode', type=str, default="reg2d") parser.add_argument('--dlossw', type=str, default="1,1,1,1", help='depth loss weight for different stage') parser.add_argument('--resume', action='store_true', help='continue to train the model') parser.add_argument('--group_cor', action='store_true',help='group correlation') parser.add_argument('--group_cor_dim', type=str, default="8,8,4,4", help='group correlation dim') parser.add_argument('--inverse_depth', action='store_true',help='inverse depth') parser.add_argument('--agg_type', type=str, default="ConvBnReLU3D", help='cost regularization type') parser.add_argument('--dcn', action='store_true',help='dcn') parser.add_argument('--arch_mode', type=str, default="fpn") parser.add_argument('--ot_continous', action='store_true',help='optimal transport continous gt bin') parser.add_argument('--ot_eps', type=float, default=1) parser.add_argument('--ot_iter', type=int, default=0) parser.add_argument('--rt', action='store_true',help='robust training') parser.add_argument('--use_raw_train', action='store_true',help='using 1200x1600 training') parser.add_argument('--mono', action='store_true',help='query to build mono depth prediction and loss') parser.add_argument('--split', type=str, default='intermediate', help='intermediate or advanced') parser.add_argument('--save_jpg', action='store_true') parser.add_argument('--ASFF', action='store_true') parser.add_argument('--vis_ETA', action='store_true') parser.add_argument('--vis_mono', action='store_true') parser.add_argument('--attn_temp', type=float, default=2) # parse arguments and check args = parser.parse_args() print("argv:", sys.argv[1:]) print_args(args) if args.use_raw_train: args.max_h = 1200 args.max_w = 1600 num_stage = len([int(nd) for nd in args.ndepths.split(",") if nd]) Interval_Scale = args.interval_scale print("***********Interval_Scale**********\n", Interval_Scale) # read intrinsics and extrinsics def read_camera_parameters(filename): with open(filename) as f: lines = f.readlines() lines = [line.rstrip() for line in lines] # extrinsics: line [1,5), 4x4 matrix extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4)) # intrinsics: line [7-10), 3x3 matrix intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3)) return intrinsics, extrinsics # read an image def read_img(filename): img = Image.open(filename) # scale 0~255 to 0~1 np_img = np.array(img, dtype=np.float32) / 255. return np_img # read a binary mask def read_mask(filename): return read_img(filename) > 0.5 # save a binary mask def save_mask(filename, mask): assert mask.dtype == np.bool mask = mask.astype(np.uint8) * 255 Image.fromarray(mask).save(filename) # read a pair file, [(ref_view1, [src_view1-1, ...]), (ref_view2, [src_view2-1, ...]), ...] def read_pair_file(filename): data = [] with open(filename) as f: num_viewpoint = int(f.readline()) # 49 viewpoints for view_idx in range(num_viewpoint): ref_view = int(f.readline().rstrip()) src_views = [int(x) for x in f.readline().rstrip().split()[1::2]] if len(src_views) > 0: data.append((ref_view, src_views)) return data def write_cam(file, cam): f = open(file, "w") f.write('extrinsic\n') for i in range(0, 4): for j in range(0, 4): f.write(str(cam[0][i][j]) + ' ') f.write('\n') f.write('\n') f.write('intrinsic\n') for i in range(0, 3): for j in range(0, 3): f.write(str(cam[1][i][j]) + ' ') f.write('\n') f.write('\n' + str(cam[1][3][0]) + ' ' + str(cam[1][3][1]) + ' ' + str(cam[1][3][2]) + ' ' + str(cam[1][3][3]) + '\n') f.close() def save_depth(testlist): torch.cuda.reset_peak_memory_stats() total_time = 0 total_sample = 0 for scene in testlist: time_this_scene, sample_this_scene = save_scene_depth([scene]) total_time += time_this_scene total_sample += sample_this_scene gpu_measure = torch.cuda.max_memory_allocated() / 1024. / 1024. /1024. print('avg time: {}'.format(total_time/total_sample)) print('max gpu: {}'.format(gpu_measure)) def save_scene_depth(testlist): # dataset, dataloader MVSDataset = find_dataset_def(args.dataset) test_dataset = MVSDataset(args.testpath, testlist, "test", args.num_view, Interval_Scale, max_h=args.max_h, max_w=args.max_w, fix_res=args.fix_res) TestImgLoader = DataLoader(test_dataset, args.batch_size, shuffle=False, num_workers=4, drop_last=False) # model model = MVS4net(arch_mode=args.arch_mode, reg_net=args.reg_mode, num_stage=4, fpn_base_channel=args.fpn_base_channel, reg_channel=args.reg_channel, stage_splits=[int(n) for n in args.ndepths.split(",")], depth_interals_ratio=[float(ir) for ir in args.depth_inter_r.split(",")], group_cor=args.group_cor, group_cor_dim=[int(n) for n in args.group_cor_dim.split(",")], inverse_depth=args.inverse_depth, agg_type=args.agg_type, dcn=args.dcn, mono=args.mono, asff=args.ASFF, attn_temp=args.attn_temp, vis_ETA=args.vis_ETA, vis_mono=args.vis_mono ) # load checkpoint file specified by args.loadckpt print("loading model {}".format(args.loadckpt)) state_dict = torch.load(args.loadckpt, map_location=torch.device("cpu")) model.load_state_dict(state_dict['model'], strict=True) model = nn.DataParallel(model) model.cuda() model.eval() total_time = 0 with torch.no_grad(): for batch_idx, sample in enumerate(TestImgLoader): sample_cuda = tocuda(sample) start_time = time.time() outputs = model(sample_cuda["imgs"], sample_cuda["proj_matrices"], sample_cuda["depth_values"], sample["filename"]) end_time = time.time() total_time += end_time - start_time outputs = tensor2numpy(outputs) del sample_cuda filenames = sample["filename"] cams = sample["proj_matrices"]["stage{}".format(num_stage)].numpy() imgs = sample["imgs"] print('Iter {}/{}, Time:{} Res:{}'.format(batch_idx, len(TestImgLoader), end_time - start_time, imgs[0].shape)) # save depth maps and confidence maps for filename, cam, img, depth_est, photometric_confidence in zip(filenames, cams, imgs, \ outputs["depth"], outputs["photometric_confidence"]): img = img[0].numpy() #ref view cam = cam[0] #ref cam depth_filename = os.path.join(args.outdir, filename.format('depth_est', '.pfm')) confidence_filename = os.path.join(args.outdir, filename.format('confidence', '.pfm')) cam_filename = os.path.join(args.outdir, filename.format('cams', '_cam.txt')) img_filename = os.path.join(args.outdir, filename.format('images', '.jpg')) ply_filename = os.path.join(args.outdir, filename.format('ply_local', '.ply')) os.makedirs(depth_filename.rsplit('/', 1)[0], exist_ok=True) os.makedirs(confidence_filename.rsplit('/', 1)[0], exist_ok=True) os.makedirs(cam_filename.rsplit('/', 1)[0], exist_ok=True) os.makedirs(img_filename.rsplit('/', 1)[0], exist_ok=True) os.makedirs(ply_filename.rsplit('/', 1)[0], exist_ok=True) #save depth maps save_pfm(depth_filename, depth_est) if args.save_jpg: for stage_idx in range(4): depth_jpg_filename = os.path.join(args.outdir, filename.format('depth_est', '{}_{}.jpg'.format('stage',str(stage_idx+1)))) stage_depth = outputs['stage{}'.format(stage_idx+1)]['depth'][0] mi = np.min(stage_depth[stage_depth>0]) ma = np.max(stage_depth) depth = (stage_depth-mi)/(ma-mi+1e-8) depth = (255*depth).astype(np.uint8) depth_img = cv2.applyColorMap(depth, cv2.COLORMAP_JET) print(cv2.imwrite(depth_jpg_filename, depth_img)) if stage_idx == 0: continue mono_depth_jpg_filename = os.path.join(args.outdir, filename.format('depth_est', '{}_{}.jpg'.format('mono',str(stage_idx+1)))) stage_mono_depth = outputs['stage{}'.format(stage_idx+1)]['mono_depth'][0] mi = np.min(stage_mono_depth[stage_mono_depth>0]) ma = np.max(stage_mono_depth) depth = (stage_mono_depth-mi)/(ma-mi+1e-8) depth = (255*depth).astype(np.uint8) depth_img = cv2.applyColorMap(depth, cv2.COLORMAP_JET) print(cv2.imwrite(mono_depth_jpg_filename, depth_img)) #save confidence maps confidence_list = [outputs['stage{}'.format(i)]['photometric_confidence'].squeeze(0) for i in range(1,5)] photometric_confidence = confidence_list[-1] # H W save_pfm(confidence_filename, photometric_confidence) #save cams, img write_cam(cam_filename, cam) img = np.clip(np.transpose(img, (1, 2, 0)) * 255, 0, 255).astype(np.uint8) img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) cv2.imwrite(img_filename, img_bgr) if batch_idx % args.save_freq == 0: generate_pointcloud(img, depth_est, ply_filename, cam[1, :3, :3]) torch.cuda.empty_cache() gc.collect() return total_time, len(TestImgLoader) # project the reference point cloud into the source view, then project back def reproject_with_depth(depth_ref, intrinsics_ref, extrinsics_ref, depth_src, intrinsics_src, extrinsics_src): width, height = depth_ref.shape[1], depth_ref.shape[0] ## step1. project reference pixels to the source view # reference view x, y x_ref, y_ref = np.meshgrid(np.arange(0, width), np.arange(0, height)) x_ref, y_ref = x_ref.reshape([-1]), y_ref.reshape([-1]) # reference 3D space xyz_ref = np.matmul(np.linalg.inv(intrinsics_ref), np.vstack((x_ref, y_ref, np.ones_like(x_ref))) * depth_ref.reshape([-1])) # source 3D space xyz_src = np.matmul(np.matmul(extrinsics_src, np.linalg.inv(extrinsics_ref)), np.vstack((xyz_ref, np.ones_like(x_ref))))[:3] # source view x, y K_xyz_src = np.matmul(intrinsics_src, xyz_src) xy_src = K_xyz_src[:2] / K_xyz_src[2:3] ## step2. reproject the source view points with source view depth estimation # find the depth estimation of the source view x_src = xy_src[0].reshape([height, width]).astype(np.float32) y_src = xy_src[1].reshape([height, width]).astype(np.float32) sampled_depth_src = cv2.remap(depth_src, x_src, y_src, interpolation=cv2.INTER_LINEAR) # mask = sampled_depth_src > 0 # source 3D space # NOTE that we should use sampled source-view depth_here to project back xyz_src = np.matmul(np.linalg.inv(intrinsics_src), np.vstack((xy_src, np.ones_like(x_ref))) * sampled_depth_src.reshape([-1])) # reference 3D space xyz_reprojected = np.matmul(np.matmul(extrinsics_ref, np.linalg.inv(extrinsics_src)), np.vstack((xyz_src, np.ones_like(x_ref))))[:3] # source view x, y, depth depth_reprojected = xyz_reprojected[2].reshape([height, width]).astype(np.float32) K_xyz_reprojected = np.matmul(intrinsics_ref, xyz_reprojected) xy_reprojected = K_xyz_reprojected[:2] / K_xyz_reprojected[2:3] x_reprojected = xy_reprojected[0].reshape([height, width]).astype(np.float32) y_reprojected = xy_reprojected[1].reshape([height, width]).astype(np.float32) return depth_reprojected, x_reprojected, y_reprojected, x_src, y_src def check_geometric_consistency(depth_ref, intrinsics_ref, extrinsics_ref, depth_src, intrinsics_src, extrinsics_src): width, height = depth_ref.shape[1], depth_ref.shape[0] x_ref, y_ref = np.meshgrid(np.arange(0, width), np.arange(0, height)) depth_reprojected, x2d_reprojected, y2d_reprojected, x2d_src, y2d_src = reproject_with_depth(depth_ref, intrinsics_ref, extrinsics_ref, depth_src, intrinsics_src, extrinsics_src) # check |p_reproj-p_1| < 1 dist = np.sqrt((x2d_reprojected - x_ref) ** 2 + (y2d_reprojected - y_ref) ** 2) # check |d_reproj-d_1| / d_1 < 0.01 depth_diff = np.abs(depth_reprojected - depth_ref) relative_depth_diff = depth_diff / depth_ref mask = np.logical_and(dist < 1, relative_depth_diff < 0.01) depth_reprojected[~mask] = 0 return mask, depth_reprojected, x2d_src, y2d_src def filter_depth(pair_folder, scan_folder, out_folder, plyfilename): # the pair file pair_file = os.path.join(pair_folder, "pair.txt") # for the final point cloud vertexs = [] vertex_colors = [] pair_data = read_pair_file(pair_file) # for each reference view and the corresponding source views for ref_view, src_views in pair_data: # src_views = src_views[:args.num_view] # load the camera parameters ref_intrinsics, ref_extrinsics = read_camera_parameters( os.path.join(scan_folder, 'cams/{:0>8}_cam.txt'.format(ref_view))) # load the reference image ref_img = read_img(os.path.join(scan_folder, 'images/{:0>8}.jpg'.format(ref_view))) # load the estimated depth of the reference view ref_depth_est = read_pfm(os.path.join(out_folder, 'depth_est/{:0>8}.pfm'.format(ref_view)))[0] # load the photometric mask of the reference view confidence = read_pfm(os.path.join(out_folder, 'confidence/{:0>8}.pfm'.format(ref_view)))[0] photo_mask = confidence > args.conf all_srcview_depth_ests = [] all_srcview_x = [] all_srcview_y = [] all_srcview_geomask = [] # compute the geometric mask geo_mask_sum = 0 for src_view in src_views: # camera parameters of the source view src_intrinsics, src_extrinsics = read_camera_parameters( os.path.join(scan_folder, 'cams/{:0>8}_cam.txt'.format(src_view))) # the estimated depth of the source view src_depth_est = read_pfm(os.path.join(out_folder, 'depth_est/{:0>8}.pfm'.format(src_view)))[0] geo_mask, depth_reprojected, x2d_src, y2d_src = check_geometric_consistency(ref_depth_est, ref_intrinsics, ref_extrinsics, src_depth_est, src_intrinsics, src_extrinsics) geo_mask_sum += geo_mask.astype(np.int32) all_srcview_depth_ests.append(depth_reprojected) all_srcview_x.append(x2d_src) all_srcview_y.append(y2d_src) all_srcview_geomask.append(geo_mask) depth_est_averaged = (sum(all_srcview_depth_ests) + ref_depth_est) / (geo_mask_sum + 1) # at least 3 source views matched geo_mask = geo_mask_sum >= args.thres_view final_mask = np.logical_and(photo_mask, geo_mask) os.makedirs(os.path.join(out_folder, "mask"), exist_ok=True) save_mask(os.path.join(out_folder, "mask/{:0>8}_photo.png".format(ref_view)), photo_mask) save_mask(os.path.join(out_folder, "mask/{:0>8}_geo.png".format(ref_view)), geo_mask) save_mask(os.path.join(out_folder, "mask/{:0>8}_final.png".format(ref_view)), final_mask) print("processing {}, ref-view{:0>2}, photo/geo/final-mask:{}/{}/{}".format(scan_folder, ref_view, photo_mask.mean(), geo_mask.mean(), final_mask.mean())) height, width = depth_est_averaged.shape[:2] x, y = np.meshgrid(np.arange(0, width), np.arange(0, height)) # valid_points = np.logical_and(final_mask, ~used_mask[ref_view]) valid_points = final_mask print("valid_points", valid_points.mean()) x, y, depth = x[valid_points], y[valid_points], depth_est_averaged[valid_points] #color = ref_img[1:-16:4, 1::4, :][valid_points] # hardcoded for DTU dataset color = ref_img[valid_points] xyz_ref = np.matmul(np.linalg.inv(ref_intrinsics), np.vstack((x, y, np.ones_like(x))) * depth) xyz_world = np.matmul(np.linalg.inv(ref_extrinsics), np.vstack((xyz_ref, np.ones_like(x))))[:3] vertexs.append(xyz_world.transpose((1, 0))) vertex_colors.append((color * 255).astype(np.uint8)) vertexs = np.concatenate(vertexs, axis=0) vertex_colors = np.concatenate(vertex_colors, axis=0) vertexs = np.array([tuple(v) for v in vertexs], dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')]) vertex_colors = np.array([tuple(v) for v in vertex_colors], dtype=[('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]) vertex_all = np.empty(len(vertexs), vertexs.dtype.descr + vertex_colors.dtype.descr) for prop in vertexs.dtype.names: vertex_all[prop] = vertexs[prop] for prop in vertex_colors.dtype.names: vertex_all[prop] = vertex_colors[prop] el = PlyElement.describe(vertex_all, 'vertex') PlyData([el]).write(plyfilename) print("saving the final model to", plyfilename) def init_worker(): ''' Catch Ctrl+C signal to termiante workers ''' signal.signal(signal.SIGINT, signal.SIG_IGN) def pcd_filter_worker(scan): if args.testlist != "all": scan_id = int(scan[4:]) save_name = 'mvsnet{:0>3}_l3.ply'.format(scan_id) else: save_name = '{}.ply'.format(scan) pair_folder = os.path.join(args.testpath, scan) scan_folder = os.path.join(args.outdir, scan) out_folder = os.path.join(args.outdir, scan) filter_depth(pair_folder, scan_folder, out_folder, os.path.join(args.outdir, save_name)) def pcd_filter(testlist, number_worker): partial_func = partial(pcd_filter_worker) p = Pool(number_worker, init_worker) try: p.map(partial_func, testlist) except KeyboardInterrupt: print("....\nCaught KeyboardInterrupt, terminating workers") p.terminate() else: p.close() p.join() def mrun_rst(eval_dir, plyPath): print('Runing BaseEvalMain_func.m...') os.chdir(eval_dir) os.system('/mnt/cfs/algorithm/xiaofeng.wang/jeff/code/MVS/misc/matlab/bin/matlab -nodesktop -nosplash -r "BaseEvalMain_func(\'{}\'); quit" '.format(plyPath)) print('Runing ComputeStat_func.m...') os.system('/mnt/cfs/algorithm/xiaofeng.wang/jeff/code/MVS/misc/matlab/bin/matlab -nodesktop -nosplash -r "ComputeStat_func(\'{}\'); quit" '.format(plyPath)) print('Check your results! ^-^') if __name__ == '__main__': if args.vis_ETA: os.makedirs('./debug_figs/vis_ETA', exist_ok=True) if args.testlist != "all": with open(args.testlist) as f: content = f.readlines() testlist = [line.rstrip() for line in content] # step1. save all the depth maps and the masks in outputs directory save_depth(testlist) if args.dataset.startswith('general'): # step2. filter saved depth maps with photometric confidence maps and geometric constraints pcd_filter(testlist, args.num_worker) # Make sure the matlab is installed and you can comment out the following lines # And you also need to change the path of the matlab script mrun_rst( eval_dir='./evaluations/dtu/', plyPath='./'+args.outdir[1:] ) ================================================ FILE: train_mvs4.py ================================================ import argparse, os, sys, time, gc, datetime import torch import torch.nn as nn import torch.nn.parallel import torch.backends.cudnn as cudnn import torch.optim as optim from torch.utils.data import DataLoader from tensorboardX import SummaryWriter from datasets import find_dataset_def from models import * from utils import * import torch.distributed as dist cudnn.benchmark = True parser = argparse.ArgumentParser(description='A PyTorch Implementation of MVSTER') parser.add_argument('--mode', default='train', help='train or test', choices=['train', 'test', 'profile']) parser.add_argument('--device', default='cuda', help='select model') parser.add_argument('--dataset', default='dtu_yao4', help='select dataset') parser.add_argument('--trainpath', help='train datapath') parser.add_argument('--testpath', help='test datapath') parser.add_argument('--trainlist', help='train list') parser.add_argument('--testlist', help='test list') parser.add_argument('--epochs', type=int, default=10, help='number of epochs to train') parser.add_argument('--lr', type=float, default=0.001, help='learning rate') parser.add_argument('--lrepochs', type=str, default="6,8,9:2", help='epoch ids to downscale lr and the downscale rate') parser.add_argument('--wd', type=float, default=0.0, help='weight decay') parser.add_argument('--batch_size', type=int, default=1, help='train batch size') parser.add_argument('--interval_scale', type=float, default=1.06, help='the number of depth values') parser.add_argument('--loadckpt', default=None, help='load a specific checkpoint') parser.add_argument('--logdir', default='./checkpoints/debug', help='the directory to save checkpoints/logs') parser.add_argument('--resume', action='store_true', help='continue to train the model') parser.add_argument('--summary_freq', type=int, default=2, help='print and summary frequency') parser.add_argument('--save_freq', type=int, default=1, help='save checkpoint frequency') parser.add_argument('--eval_freq', type=int, default=1, help='eval freq') parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed') parser.add_argument('--pin_m', action='store_true', help='data loader pin memory') parser.add_argument("--local_rank", type=int, default=0) parser.add_argument('--ndepths', type=str, default="8,8,4,4", help='ndepths') parser.add_argument('--depth_inter_r', type=str, default="0.5,0.5,0.5,1", help='depth_intervals_ratio') parser.add_argument('--dlossw', type=str, default="1,1,1,1", help='depth loss weight for different stage') parser.add_argument('--l1ce_lw', type=str, default="0,1", help='loss weight for l1 and ce loss') parser.add_argument("--fpn_base_channel", type=int, default=8) parser.add_argument("--reg_channel", type=int, default=8) parser.add_argument('--reg_mode', type=str, default="reg2d") parser.add_argument('--group_cor', action='store_true',help='group correlation') parser.add_argument('--group_cor_dim', type=str, default="8,8,4,4", help='group correlation dim') parser.add_argument('--inverse_depth', action='store_true',help='inverse depth') parser.add_argument('--agg_type', type=str, default="ConvBnReLU3D", help='cost regularization type') parser.add_argument('--dcn', action='store_true',help='dcn') parser.add_argument('--pos_enc', type=int, default=0, help='pos_enc: 0 no pos enc; 1 depth sine; 2 learnable pos enc') parser.add_argument('--arch_mode', type=str, default="fpn") parser.add_argument('--ot_continous', action='store_true',help='optimal transport continous gt bin') parser.add_argument('--ot_iter', type=int, default=10) parser.add_argument('--ot_eps', type=float, default=1) parser.add_argument('--rt', action='store_true',help='robust training') parser.add_argument('--max_h', type=int, default=864, help='testing max h') parser.add_argument('--max_w', type=int, default=1152, help='testing max w') parser.add_argument('--use_raw_train', action='store_true',help='using 1200x1600 training') parser.add_argument('--mono', action='store_true',help='query to build mono depth prediction and loss') parser.add_argument('--lr_scheduler', type=str, default='MS') parser.add_argument('--ASFF', action='store_true') parser.add_argument('--attn_temp', type=float, default=2) num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 is_distributed = num_gpus > 1 # main function def train(model, model_loss, optimizer, TrainImgLoader, TestImgLoader, start_epoch, args): milestones = [len(TrainImgLoader) * int(epoch_idx) for epoch_idx in args.lrepochs.split(':')[0].split(',')] lr_gamma = 1 / float(args.lrepochs.split(':')[1]) if args.lr_scheduler == 'MS': lr_scheduler = WarmupMultiStepLR(optimizer, milestones, gamma=lr_gamma, warmup_factor=1.0/3, warmup_iters=500, last_epoch=len(TrainImgLoader) * start_epoch - 1) elif args.lr_scheduler == 'cos': lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=int(args.epochs*len(TrainImgLoader)), eta_min=0) elif args.lr_scheduler == 'onecycle': lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=args.lr,total_steps=int(args.epochs*len(TrainImgLoader))) for epoch_idx in range(start_epoch, args.epochs): print('Epoch {}:'.format(epoch_idx)) global_step = len(TrainImgLoader) * epoch_idx # training for batch_idx, sample in enumerate(TrainImgLoader): start_time = time.time() global_step = len(TrainImgLoader) * epoch_idx + batch_idx do_summary = global_step % args.summary_freq == 0 loss, scalar_outputs, image_outputs = train_sample(model, model_loss, optimizer, sample, args) lr_scheduler.step() if (not is_distributed) or (dist.get_rank() == 0): if do_summary: save_scalars(logger, 'train', scalar_outputs, global_step) save_images(logger, 'train', image_outputs, global_step) print( "Epoch {}/{}, Iter {}/{}, lr {:.6f}, train loss = {:.3f}, d_loss = {:.3f}, {:.3f}, {:.3f}, {:.3f}, c_loss = {:.3f}, {:.3f}, {:.3f}, {:.3f}, range_err = {:.3f}, {:.3f}, {:.3f}, {:.3f}, time = {:.3f}".format( epoch_idx, args.epochs, batch_idx, len(TrainImgLoader), optimizer.param_groups[0]["lr"], loss, scalar_outputs["s0_d_loss"], scalar_outputs["s1_d_loss"], scalar_outputs["s2_d_loss"], scalar_outputs["s3_d_loss"], scalar_outputs["s0_c_loss"], scalar_outputs["s1_c_loss"], scalar_outputs["s2_c_loss"], scalar_outputs["s3_c_loss"], scalar_outputs["s0_range_err_ratio"], scalar_outputs["s1_range_err_ratio"], scalar_outputs["s2_range_err_ratio"], scalar_outputs["s3_range_err_ratio"], time.time() - start_time)) del scalar_outputs, image_outputs # checkpoint if (not is_distributed) or (dist.get_rank() == 0): if (epoch_idx + 1) % args.save_freq == 0: if epoch_idx == args.epochs - 1: torch.save({ 'epoch': epoch_idx, 'model': model.module.state_dict(), 'optimizer': optimizer.state_dict()}, "{}/finalmodel.ckpt".format(args.logdir)) gc.collect() # testing if (epoch_idx % args.eval_freq == 0) or (epoch_idx == args.epochs - 1): avg_test_scalars = DictAverageMeter() for batch_idx, sample in enumerate(TestImgLoader): start_time = time.time() global_step = len(TrainImgLoader) * epoch_idx + batch_idx do_summary = global_step % args.summary_freq == 0 loss, scalar_outputs, image_outputs = test_sample_depth(model, model_loss, sample, args) if (not is_distributed) or (dist.get_rank() == 0): if do_summary: save_scalars(logger, 'test', scalar_outputs, global_step) save_images(logger, 'test', image_outputs, global_step) print( "Epoch {}/{}, Iter {}/{}, lr {:.6f}, test loss = {:.3f}, d_loss = {:.3f}, {:.3f}, {:.3f}, {:.3f}, c_loss = {:.3f}, {:.3f}, {:.3f}, {:.3f}, range_err = {:.3f}, {:.3f}, {:.3f}, {:.3f}, time = {:.3f}".format( epoch_idx, args.epochs, batch_idx, len(TrainImgLoader), optimizer.param_groups[0]["lr"], loss, scalar_outputs["s0_d_loss"], scalar_outputs["s1_d_loss"], scalar_outputs["s2_d_loss"], scalar_outputs["s3_d_loss"], scalar_outputs["s0_c_loss"], scalar_outputs["s1_c_loss"], scalar_outputs["s2_c_loss"], scalar_outputs["s3_c_loss"], scalar_outputs["s0_range_err_ratio"], scalar_outputs["s1_range_err_ratio"], scalar_outputs["s2_range_err_ratio"], scalar_outputs["s3_range_err_ratio"], time.time() - start_time)) avg_test_scalars.update(scalar_outputs) del scalar_outputs, image_outputs if (not is_distributed) or (dist.get_rank() == 0): save_scalars(logger, 'fulltest', avg_test_scalars.mean(), global_step) print("avg_test_scalars:", avg_test_scalars.mean()) gc.collect() def test(model, model_loss, TestImgLoader, args): avg_test_scalars = DictAverageMeter() for batch_idx, sample in enumerate(TestImgLoader): start_time = time.time() loss, scalar_outputs, image_outputs = test_sample_depth(model, model_loss, sample, args) avg_test_scalars.update(scalar_outputs) del scalar_outputs, image_outputs if (not is_distributed) or (dist.get_rank() == 0): print('Iter {}/{}, test loss = {:.3f}, time = {:3f}'.format(batch_idx, len(TestImgLoader), loss, time.time() - start_time)) if batch_idx % 100 == 0: print("Iter {}/{}, test results = {}".format(batch_idx, len(TestImgLoader), avg_test_scalars.mean())) if (not is_distributed) or (dist.get_rank() == 0): print("final", avg_test_scalars.mean()) def train_sample(model, model_loss, optimizer, sample, args): model.train() optimizer.zero_grad() sample_cuda = tocuda(sample) depth_gt_ms = sample_cuda["depth"] mask_ms = sample_cuda["mask"] num_stage = len([int(nd) for nd in args.ndepths.split(",") if nd]) depth_gt = depth_gt_ms["stage{}".format(num_stage)] mask = mask_ms["stage{}".format(num_stage)] outputs = model(sample_cuda["imgs"], sample_cuda["proj_matrices"], sample_cuda["depth_values"]) depth_est = outputs["depth"] loss, stage_d_loss, stage_c_loss, range_err_ratio = model_loss( outputs, depth_gt_ms, mask_ms, stage_lw=[float(e) for e in args.dlossw.split(",") if e], l1ce_lw=[float(lw) for lw in args.l1ce_lw.split(",")], inverse_depth=args.inverse_depth, ot_iter=args.ot_iter, ot_continous=args.ot_continous, ot_eps=args.ot_eps, mono=args.mono ) loss.backward() optimizer.step() scalar_outputs = {"loss": loss, "s0_d_loss": stage_d_loss[0], "s1_d_loss": stage_d_loss[1], "s2_d_loss": stage_d_loss[2], "s3_d_loss": stage_d_loss[3], "s0_c_loss": stage_c_loss[0], "s1_c_loss": stage_c_loss[1], "s2_c_loss": stage_c_loss[2], "s3_c_loss": stage_c_loss[3], "s0_range_err_ratio":range_err_ratio[0], "s1_range_err_ratio":range_err_ratio[1], "s2_range_err_ratio":range_err_ratio[2], "s3_range_err_ratio":range_err_ratio[3], "abs_depth_error": AbsDepthError_metrics(depth_est, depth_gt, mask > 0.5), "thres2mm_error": Thres_metrics(depth_est, depth_gt, mask > 0.5, 2), "thres4mm_error": Thres_metrics(depth_est, depth_gt, mask > 0.5, 4), "thres8mm_error": Thres_metrics(depth_est, depth_gt, mask > 0.5, 8),} image_outputs = {"depth_est": depth_est * mask, "depth_est_nomask": depth_est, "depth_gt": sample["depth"]["stage1"], "ref_img": sample["imgs"][0], "mask": sample["mask"]["stage1"], "errormap": (depth_est - depth_gt).abs() * mask, } if is_distributed: scalar_outputs = reduce_scalar_outputs(scalar_outputs) return tensor2float(scalar_outputs["loss"]), tensor2float(scalar_outputs), tensor2numpy(image_outputs) @make_nograd_func def test_sample_depth(model, model_loss, sample, args): if is_distributed: model_eval = model.module else: model_eval = model model_eval.eval() sample_cuda = tocuda(sample) depth_gt_ms = sample_cuda["depth"] mask_ms = sample_cuda["mask"] num_stage = len([int(nd) for nd in args.ndepths.split(",") if nd]) depth_gt = depth_gt_ms["stage{}".format(num_stage)] mask = mask_ms["stage{}".format(num_stage)] outputs = model_eval(sample_cuda["imgs"], sample_cuda["proj_matrices"], sample_cuda["depth_values"]) depth_est = outputs["depth"] loss, stage_d_loss, stage_c_loss, range_err_ratio = model_loss( outputs, depth_gt_ms, mask_ms, stage_lw=[float(e) for e in args.dlossw.split(",") if e], l1ce_lw=[float(lw) for lw in args.l1ce_lw.split(",")], inverse_depth=args.inverse_depth, ot_iter=args.ot_iter, ot_continous=args.ot_continous, ot_eps=args.ot_eps, mono=False ) scalar_outputs = {"loss": loss, "s0_d_loss": stage_d_loss[0], "s1_d_loss": stage_d_loss[1], "s2_d_loss": stage_d_loss[2], "s3_d_loss": stage_d_loss[3], "s0_c_loss": stage_c_loss[0], "s1_c_loss": stage_c_loss[1], "s2_c_loss": stage_c_loss[2], "s3_c_loss": stage_c_loss[3], "s0_range_err_ratio":range_err_ratio[0], "s1_range_err_ratio":range_err_ratio[1], "s2_range_err_ratio":range_err_ratio[2], "s3_range_err_ratio":range_err_ratio[3], "abs_depth_error": AbsDepthError_metrics(depth_est, depth_gt, mask > 0.5), "thres2mm_error": Thres_metrics(depth_est, depth_gt, mask > 0.5, 2), "thres4mm_error": Thres_metrics(depth_est, depth_gt, mask > 0.5, 4), "thres8mm_error": Thres_metrics(depth_est, depth_gt, mask > 0.5, 8), } image_outputs = {"depth_est": depth_est * mask, "depth_est_nomask": depth_est, "depth_gt": sample["depth"]["stage1"], "ref_img": sample["imgs"][0], "mask": sample["mask"]["stage1"], "errormap": (depth_est - depth_gt).abs() * mask} if is_distributed: scalar_outputs = reduce_scalar_outputs(scalar_outputs) return tensor2float(scalar_outputs["loss"]), tensor2float(scalar_outputs), tensor2numpy(image_outputs) if __name__ == '__main__': # parse arguments and check args = parser.parse_args() if args.resume: assert args.mode == "train" assert args.loadckpt is None if args.testpath is None: args.testpath = args.trainpath if is_distributed: torch.cuda.set_device(args.local_rank) torch.distributed.init_process_group( backend="nccl", init_method="env://" ) synchronize() set_random_seed(args.seed) device = torch.device(args.device) if (not is_distributed) or (dist.get_rank() == 0): # create logger for mode "train" and "testall" if args.mode == "train": if not os.path.isdir(args.logdir): os.makedirs(args.logdir) current_time_str = str(datetime.datetime.now().strftime('%Y%m%d_%H%M%S')) print("current time", current_time_str) print("creating new summary file") logger = SummaryWriter(args.logdir) print("argv:", sys.argv[1:]) print_args(args) # model, optimizer model = MVS4net(arch_mode=args.arch_mode, reg_net=args.reg_mode, num_stage=4, fpn_base_channel=args.fpn_base_channel, reg_channel=args.reg_channel, stage_splits=[int(n) for n in args.ndepths.split(",")], depth_interals_ratio=[float(ir) for ir in args.depth_inter_r.split(",")], group_cor=args.group_cor, group_cor_dim=[int(n) for n in args.group_cor_dim.split(",")], inverse_depth=args.inverse_depth, agg_type=args.agg_type, dcn=args.dcn, pos_enc=args.pos_enc, mono=args.mono, asff=args.ASFF, attn_temp=args.attn_temp, ) model.to(device) model_loss = MVS4net_loss optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, betas=(0.9, 0.999), weight_decay=args.wd) # load parameters start_epoch = 0 if args.resume: saved_models = [fn for fn in os.listdir(args.logdir) if fn.endswith(".ckpt")] saved_models = sorted(saved_models, key=lambda x: int(x.split('_')[-1].split('.')[0])) # use the latest checkpoint file loadckpt = os.path.join(args.logdir, saved_models[-1]) print("resuming", loadckpt) state_dict = torch.load(loadckpt, map_location=torch.device("cpu")) model.load_state_dict(state_dict['model']) optimizer.load_state_dict(state_dict['optimizer']) start_epoch = state_dict['epoch'] + 1 elif args.loadckpt: # load checkpoint file specified by args.loadckpt print("loading model {}".format(args.loadckpt)) state_dict = torch.load(args.loadckpt, map_location=torch.device("cpu")) model.load_state_dict(state_dict['model']) if (not is_distributed) or (dist.get_rank() == 0): print("start at epoch {}".format(start_epoch)) print('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()]))) if is_distributed: if dist.get_rank() == 0: print("Let's use", torch.cuda.device_count(), "GPUs!") model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[args.local_rank], output_device=args.local_rank, # find_unused_parameters=True, ) else: if torch.cuda.is_available(): print("Let's use", torch.cuda.device_count(), "GPUs!") model = nn.DataParallel(model) # dataset, dataloader MVSDataset = find_dataset_def(args.dataset) if args.dataset.startswith('dtu'): train_dataset = MVSDataset(args.trainpath, args.trainlist, "train", 5, args.interval_scale, rt=args.rt, use_raw_train=args.use_raw_train) test_dataset = MVSDataset(args.testpath, args.testlist, "val", 5, args.interval_scale) elif args.dataset.startswith('blendedmvs'): train_dataset = MVSDataset(args.trainpath, args.trainlist, "train", 7, robust_train=args.rt) test_dataset = MVSDataset(args.testpath, args.testlist, "val", 7) if is_distributed: train_sampler = torch.utils.data.DistributedSampler(train_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank()) test_sampler = torch.utils.data.DistributedSampler(test_dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank()) TrainImgLoader = DataLoader(train_dataset, args.batch_size, sampler=train_sampler, num_workers=1, drop_last=True, pin_memory=args.pin_m) TestImgLoader = DataLoader(test_dataset, args.batch_size, sampler=test_sampler, num_workers=1, drop_last=False, pin_memory=args.pin_m) else: TrainImgLoader = DataLoader(train_dataset, args.batch_size, shuffle=True, num_workers=1, drop_last=True, pin_memory=args.pin_m) TestImgLoader = DataLoader(test_dataset, args.batch_size, shuffle=False, num_workers=1, drop_last=False, pin_memory=args.pin_m) if args.mode == "train": train(model, model_loss, optimizer, TrainImgLoader, TestImgLoader, start_epoch, args) elif args.mode == "test": test(model, model_loss, TestImgLoader, args) else: raise NotImplementedError ================================================ FILE: utils.py ================================================ import numpy as np import torchvision.utils as vutils import torch, random import torch.nn.functional as F # print arguments def print_args(args): print("################################ args ################################") for k, v in args.__dict__.items(): print("{0: <10}\t{1: <30}\t{2: <20}".format(k, str(v), str(type(v)))) print("########################################################################") # torch.no_grad warpper for functions def make_nograd_func(func): def wrapper(*f_args, **f_kwargs): with torch.no_grad(): ret = func(*f_args, **f_kwargs) return ret return wrapper # convert a function into recursive style to handle nested dict/list/tuple variables def make_recursive_func(func): def wrapper(vars): if isinstance(vars, list): return [wrapper(x) for x in vars] elif isinstance(vars, tuple): return tuple([wrapper(x) for x in vars]) elif isinstance(vars, dict): return {k: wrapper(v) for k, v in vars.items()} else: return func(vars) return wrapper @make_recursive_func def tensor2float(vars): if isinstance(vars, float): return vars elif isinstance(vars, torch.Tensor): return vars.data.item() else: raise NotImplementedError("invalid input type {} for tensor2float".format(type(vars))) @make_recursive_func def tensor2numpy(vars): if isinstance(vars, np.ndarray): return vars elif isinstance(vars, torch.Tensor): return vars.detach().cpu().numpy().copy() else: raise NotImplementedError("invalid input type {} for tensor2numpy".format(type(vars))) @make_recursive_func def tocuda(vars): if isinstance(vars, torch.Tensor): return vars.to(torch.device("cuda")) elif isinstance(vars, str): return vars else: raise NotImplementedError("invalid input type {} for tensor2numpy".format(type(vars))) def save_scalars(logger, mode, scalar_dict, global_step): scalar_dict = tensor2float(scalar_dict) for key, value in scalar_dict.items(): if not isinstance(value, (list, tuple)): name = '{}/{}'.format(mode, key) logger.add_scalar(name, value, global_step) else: for idx in range(len(value)): name = '{}/{}_{}'.format(mode, key, idx) logger.add_scalar(name, value[idx], global_step) def save_images(logger, mode, images_dict, global_step): images_dict = tensor2numpy(images_dict) def preprocess(name, img): if not (len(img.shape) == 3 or len(img.shape) == 4): raise NotImplementedError("invalid img shape {}:{} in save_images".format(name, img.shape)) if len(img.shape) == 3: img = img[:, np.newaxis, :, :] img = torch.from_numpy(img[:1]) return vutils.make_grid(img, padding=0, nrow=1, normalize=True, scale_each=True) for key, value in images_dict.items(): if not isinstance(value, (list, tuple)): name = '{}/{}'.format(mode, key) logger.add_image(name, preprocess(name, value), global_step) else: for idx in range(len(value)): name = '{}/{}_{}'.format(mode, key, idx) logger.add_image(name, preprocess(name, value[idx]), global_step) class DictAverageMeter(object): def __init__(self): self.data = {} self.count = 0 def update(self, new_input): self.count += 1 if len(self.data) == 0: for k, v in new_input.items(): if not isinstance(v, float): raise NotImplementedError("invalid data {}: {}".format(k, type(v))) self.data[k] = v else: for k, v in new_input.items(): if not isinstance(v, float): raise NotImplementedError("invalid data {}: {}".format(k, type(v))) self.data[k] += v def mean(self): return {k: v / self.count for k, v in self.data.items()} # a wrapper to compute metrics for each image individually def compute_metrics_for_each_image(metric_func): def wrapper(depth_est, depth_gt, mask, *args): batch_size = depth_gt.shape[0] results = [] # compute result one by one for idx in range(batch_size): ret = metric_func(depth_est[idx], depth_gt[idx], mask[idx], *args) results.append(ret) return torch.stack(results).mean() return wrapper @make_nograd_func @compute_metrics_for_each_image def Thres_metrics(depth_est, depth_gt, mask, thres): assert isinstance(thres, (int, float)) depth_est, depth_gt = depth_est[mask], depth_gt[mask] errors = torch.abs(depth_est - depth_gt) err_mask = errors > thres return torch.mean(err_mask.float()) # NOTE: please do not use this to build up training loss @make_nograd_func @compute_metrics_for_each_image def AbsDepthError_metrics(depth_est, depth_gt, mask, thres=None): depth_est, depth_gt = depth_est[mask], depth_gt[mask] error = (depth_est - depth_gt).abs() if thres is not None: error = error[(error >= float(thres[0])) & (error <= float(thres[1]))] if error.shape[0] == 0: return torch.tensor(0, device=error.device, dtype=error.dtype) return torch.mean(error) import torch.distributed as dist def synchronize(): """ Helper function to synchronize (barrier) among all processes when using distributed training """ if not dist.is_available(): return if not dist.is_initialized(): return world_size = dist.get_world_size() if world_size == 1: return dist.barrier() def get_world_size(): if not dist.is_available(): return 1 if not dist.is_initialized(): return 1 return dist.get_world_size() def reduce_scalar_outputs(scalar_outputs): world_size = get_world_size() if world_size < 2: return scalar_outputs with torch.no_grad(): names = [] scalars = [] for k in sorted(scalar_outputs.keys()): names.append(k) scalars.append(scalar_outputs[k]) scalars = torch.stack(scalars, dim=0) dist.reduce(scalars, dst=0) if dist.get_rank() == 0: # only main process gets accumulated, so only divide by # world_size in this case scalars /= world_size reduced_scalars = {k: v for k, v in zip(names, scalars)} return reduced_scalars import torch from bisect import bisect_right # FIXME ideally this would be achieved with a CombinedLRScheduler, # separating MultiStepLR with WarmupLR # but the current LRScheduler design doesn't allow it class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): def __init__( self, optimizer, milestones, gamma=0.1, warmup_factor=1.0 / 3, warmup_iters=500, warmup_method="linear", last_epoch=-1, ): if not list(milestones) == sorted(milestones): raise ValueError( "Milestones should be a list of" " increasing integers. Got {}", milestones, ) if warmup_method not in ("constant", "linear"): raise ValueError( "Only 'constant' or 'linear' warmup_method accepted" "got {}".format(warmup_method) ) self.milestones = milestones self.gamma = gamma self.warmup_factor = warmup_factor self.warmup_iters = warmup_iters self.warmup_method = warmup_method super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) def get_lr(self): warmup_factor = 1 if self.last_epoch < self.warmup_iters: if self.warmup_method == "constant": warmup_factor = self.warmup_factor elif self.warmup_method == "linear": alpha = float(self.last_epoch) / self.warmup_iters warmup_factor = self.warmup_factor * (1 - alpha) + alpha return [ base_lr * warmup_factor * self.gamma ** bisect_right(self.milestones, self.last_epoch) for base_lr in self.base_lrs ] def set_random_seed(seed): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) def local_pcd(depth, intr): nx = depth.shape[1] # w ny = depth.shape[0] # h x, y = np.meshgrid(np.arange(nx), np.arange(ny), indexing='xy') x = x.reshape(nx * ny) y = y.reshape(nx * ny) p2d = np.array([x, y, np.ones_like(y)]) p3d = np.matmul(np.linalg.inv(intr), p2d) depth = depth.reshape(1, nx * ny) p3d *= depth p3d = np.transpose(p3d, (1, 0)) p3d = p3d.reshape(ny, nx, 3).astype(np.float32) return p3d def generate_pointcloud(rgb, depth, ply_file, intr, scale=1.0): """ Generate a colored point cloud in PLY format from a color and a depth image. Input: rgb_file -- filename of color image depth_file -- filename of depth image ply_file -- filename of ply file """ fx, fy, cx, cy = intr[0, 0], intr[1, 1], intr[0, 2], intr[1, 2] points = [] for v in range(rgb.shape[0]): for u in range(rgb.shape[1]): color = rgb[v, u] #rgb.getpixel((u, v)) Z = depth[v, u] / scale if Z == 0: continue X = (u - cx) * Z / fx Y = (v - cy) * Z / fy points.append("%f %f %f %d %d %d 0\n" % (X, Y, Z, color[0], color[1], color[2])) file = open(ply_file, "w") file.write('''ply format ascii 1.0 element vertex %d property float x property float y property float z property uchar red property uchar green property uchar blue property uchar alpha end_header %s ''' % (len(points), "".join(points))) file.close() print("save ply, fx:{}, fy:{}, cx:{}, cy:{}".format(fx, fy, cx, cy))