Repository: JeffWang987/MVSTER
Branch: main
Commit: 3f8bb98bba0c
Files: 35
Total size: 246.6 KB
Directory structure:
gitextract_j4zdg8al/
├── .gitignore
├── LICENSE
├── README.md
├── datasets/
│ ├── __init__.py
│ ├── blendedmvs.py
│ ├── data_io.py
│ ├── dtu_yao4.py
│ ├── eth3d.py
│ ├── general_eval4.py
│ └── tanks.py
├── evaluations/
│ └── dtu/
│ ├── BaseEval2Obj_web.m
│ ├── BaseEvalMain_func.m
│ ├── BaseEvalMain_web.m
│ ├── ComputeStat_func.m
│ ├── ComputeStat_web.m
│ ├── MaxDistCP.m
│ ├── PointCompareMain.m
│ ├── plyread.m
│ └── reducePts_haa.m
├── lists/
│ ├── blendedmvs/
│ │ ├── train.txt
│ │ └── val.txt
│ └── dtu/
│ ├── test.txt
│ ├── train.txt
│ ├── trainval.txt
│ └── val.txt
├── models/
│ ├── MVS4Net.py
│ ├── __init__.py
│ ├── module.py
│ └── mvs4net_utils.py
├── requirements.txt
├── scripts/
│ ├── test_dtu.sh
│ └── train_dtu.sh
├── test_mvs4.py
├── train_mvs4.py
└── utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
outputs/
checkpoints/
debug_figs/
*__pycache__
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2022 Jeff Wang
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# MVSTER
MVSTER: Epipolar Transformer for Efficient Multi-View Stereo, ECCV 2022. [arXiv](https://arxiv.org/abs/2204.07346)
This repository contains the official implementation of the paper: "MVSTER: Epipolar Transformer for Efficient Multi-View Stereo".
## Introduction
MVSTER is a learning-based MVS method that achieves competitive reconstruction performance with significantly higher efficiency. MVSTER leverages the proposed epipolar Transformer to learn both 2D semantics and 3D spatial associations efficiently. Specifically, the epipolar Transformer utilizes a detachable monocular depth estimator to enhance 2D semantics and uses cross-attention to construct data-dependent 3D associations along epipolar lines. Additionally, MVSTER is built in a cascade structure, where entropy-regularized optimal transport is leveraged to propagate finer depth estimations in each stage.
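Conceptually, each source view is warped into the reference frustum at a set of depth hypotheses, and the reference feature queries the warped source features so that the attention weights fall along the corresponding epipolar line. A minimal sketch of that cross-attention step (illustrative only: the tensor layout, scaling, and attention form here are simplifying assumptions, not the actual code in `models/MVS4Net.py`):
```
import torch
import torch.nn.functional as F

def epipolar_cross_attention(ref_feat, warped_src):
    # ref_feat:   [B, C, H, W]    reference-view features (queries)
    # warped_src: [B, C, D, H, W] source features warped to D depth hypotheses
    B, C, D, H, W = warped_src.shape
    q = ref_feat.unsqueeze(2)                      # [B, C, 1, H, W]
    sim = (q * warped_src).sum(1) / C ** 0.5       # [B, D, H, W] score per hypothesis
    attn = F.softmax(sim, dim=1)                   # weights along the epipolar samples
    agg = (attn.unsqueeze(1) * warped_src).sum(2)  # [B, C, H, W] aggregated feature
    return agg, attn
```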

## Installation
MVSTER is tested on:
* python 3.7
* CUDA 11.1
### Requirements
```
pip install -r requirements.txt
```
## Training
* Download the [DTU dataset](https://roboimagedata.compute.dtu.dk/). For convenience, you can download the preprocessed [DTU training data](https://drive.google.com/file/d/1eDjh-_bxKKnEuz5h-HXS7EDJn59clx6V/view)
and [Depths_raw](https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/cascade-stereo/CasMVSNet/dtu_data/dtu_train_hr/Depths_raw.zip)
(both from [Original MVSNet](https://github.com/YoYo000/MVSNet)), and unzip them into the $DTU_TRAINING folder. For training and testing with the raw image size, you can download [Rectified_raw](http://roboimagedata2.compute.dtu.dk/data/MVS/Rectified.zip) and unzip it.
```
├── Cameras
├── Depths
├── Depths_raw
├── Rectified
├── Rectified_raw (Optional)
```
In ``scripts/train_dtu.sh``, set ``DTU_TRAINING`` to $DTU_TRAINING.
Train MVSTER (Multi-GPU training):
* Train with middle size (512x640):
```
bash ./scripts/train_dtu.sh mid exp_name
```
* Train with raw size (1200x1600):
```
bash ./scripts/train_dtu.sh raw exp_name
```
After training, you will get model checkpoints in ./checkpoints/dtu/exp_name.
## Testing
* Download the preprocessed test data [DTU testing data](https://drive.google.com/open?id=135oKPefcPTsdtLRzoDAQtPpHuoIrpRI_) (from [Original MVSNet](https://github.com/YoYo000/MVSNet)) and unzip it as the $DTU_TESTPATH folder, which should contain one ``cams`` folder, one ``images`` folder and one ``pair.txt`` file.
* In ``scripts/test_dtu.sh``, set ``DTU_TESTPATH`` as $DTU_TESTPATH.
* The ``DTU_CKPT_FILE`` is automatically set to your pretrained checkpoint file; you can also download my [pretrained model](https://github.com/JeffWang987/MVSTER/releases/tag/dtu_ckpt).
* Test with middle size:
```
bash ./scripts/test_dtu.sh mid exp_name
```
* Test with raw size:
```
bash ./scripts/test_dtu.sh raw exp_name
```
* Test with provided pretrained model:
```
bash scripts/test_dtu.sh mid benchmark --loadckpt PATH_TO_CKPT_FILE
```
After testing, you will get reconstructed point clouds of DTU test set in ./outputs/dtu/exp_name.
## Metric
* For quantitative evaluation, download [SampleSet](http://roboimagedata.compute.dtu.dk/?page_id=36) and [Points](http://roboimagedata.compute.dtu.dk/?page_id=36) from DTU's website. Unzip them and place the `Points` folder in `SampleSet/MVS Data/`. The structure looks like:
```
SampleSet
├──MVS Data
└──Points
```
* For convenient evaluation, please install MATLAB (tested on Ubuntu 18.04) and uncomment the **mrun_rst** function at the end of **./test_mvs4.py**; you also need to change the path of the MATLAB executable file (for me, it is /mnt/cfs/algorithm/xiaofeng.wang/jeff/code/MVS/misc/matlab/bin/matlab). Then you can evaluate the point cloud reconstruction results when testing is finished.
* You can also evaluate the metrics with the traditional steps:
In ``evaluations/dtu/BaseEvalMain_web.m``, set `dataPath` to the path of `SampleSet/MVS Data/`, `plyPath` to the directory that stores the reconstructed point clouds, and `resultsPath` to the directory where the evaluation results should be stored. Then run ``evaluations/dtu/BaseEvalMain_web.m`` in MATLAB.
## Results on DTU (single RTX 3090)
| | Acc. | Comp. | Overall | Inf. Time |
|-----------------------|--------|--------|----------|-----------|
| MVSTER (mid size) | 0.350 | 0.276 | 0.313 | 0.09s |
| MVSTER (raw size) | 0.340 | 0.266 | 0.303 | 0.17s |
Point cloud results on [DTU](https://github.com/JeffWang987/MVSTER/releases/tag/DTU_ply), [Tanks and Temples](https://github.com/JeffWang987/MVSTER/releases/tag/T%26T_ply), [ETH3D](https://github.com/JeffWang987/MVSTER/releases/tag/ETH3D_ply)
 
 
If you find this project useful for your research, please cite:
```
@misc{wang2022mvster,
title={MVSTER: Epipolar Transformer for Efficient Multi-View Stereo},
author={Xiaofeng Wang and Zheng Zhu and Fangbo Qin and Yun Ye and Guan Huang and Xu Chi and Yijia He and Xingang Wang},
journal={arXiv preprint arXiv:2204.07346},
year={2022}
}
```
## Acknowledgements
Our work is partially based on these open source works: [MVSNet](https://github.com/YoYo000/MVSNet), [MVSNet-pytorch](https://github.com/xy-guo/MVSNet_pytorch), [cascade-stereo](https://github.com/alibaba/cascade-stereo), [PatchmatchNet](https://github.com/FangjinhuaWang/PatchmatchNet).
We appreciate their contributions to the MVS community.
================================================
FILE: datasets/__init__.py
================================================
import importlib
# find the dataset definition by name, for example dtu_yao (dtu_yao.py)
def find_dataset_def(dataset_name):
module_name = 'datasets.{}'.format(dataset_name)
module = importlib.import_module(module_name)
return getattr(module, "MVSDataset")
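# Usage sketch (illustrative, not part of the original file): the dataset name maps
# to a module in this package, e.g. find_dataset_def("dtu_yao4") imports
# datasets/dtu_yao4.py and returns its MVSDataset class:
#   MVSDataset = find_dataset_def("dtu_yao4")
#   dataset = MVSDataset(datapath, listfile, "train", nviews=5)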
================================================
FILE: datasets/blendedmvs.py
================================================
from torch.utils.data import Dataset
from datasets.data_io import *
import os
import numpy as np
import cv2
from PIL import Image
from torchvision import transforms as T
import random
import copy
def check_invalid_input(imgs, depths, masks, depth_mins, depth_maxs):
for img in imgs:
assert np.isnan(img).sum() == 0
assert np.isinf(img).sum() == 0
for depth in depths.values():
assert np.isnan(depth).sum() == 0
assert np.isinf(depth).sum() == 0
for mask in masks.values():
assert np.isnan(mask).sum() == 0
assert np.isinf(mask).sum() == 0
assert depth_mins > 0
assert depth_maxs > depth_mins
class MVSDataset(Dataset):
def __init__(self, datapath, listfile, split, nviews, img_wh=(768, 576), robust_train=True):
super(MVSDataset, self).__init__()
self.levels = 4
self.datapath = datapath
self.split = split
self.listfile = listfile
self.robust_train = robust_train
assert self.split in ['train', 'val', 'all'], \
'split must be either "train", "val" or "all"!'
self.img_wh = img_wh
if img_wh is not None:
assert img_wh[0]%32==0 and img_wh[1]%32==0, \
'img_wh must both be multiples of 32!'
self.nviews = nviews
self.scale_factors = {} # depth scale factors for each scan
self.scale_factor = 0 # single global depth scale factor (unused; the per-scan factors above are used instead)
self.build_metas()
self.color_augment = T.ColorJitter(brightness=0.5, contrast=0.5)
def build_metas(self):
self.metas = []
with open(self.listfile) as f:
self.scans = [line.rstrip() for line in f.readlines()]
for scan in self.scans:
with open(os.path.join(self.datapath, scan, "cams/pair.txt")) as f:
num_viewpoint = int(f.readline())
for _ in range(num_viewpoint):
ref_view = int(f.readline().rstrip())
src_views = [int(x) for x in f.readline().rstrip().split()[1::2]]
if len(src_views) >= self.nviews-1:
self.metas += [(scan, ref_view, src_views)]
def read_cam_file(self, scan, filename):
with open(filename) as f:
lines = f.readlines()
lines = [line.rstrip() for line in lines]
# extrinsics: line [1,5), 4x4 matrix
extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4))
# intrinsics: line [7-10), 3x3 matrix
intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3))
depth_min = float(lines[11].split()[0])
depth_max = float(lines[11].split()[-1])
if scan not in self.scale_factors:
self.scale_factors[scan] = 100.0 / depth_min
depth_min *= self.scale_factors[scan]
depth_max *= self.scale_factors[scan]
extrinsics[:3, 3] *= self.scale_factors[scan]
return intrinsics, extrinsics, depth_min, depth_max
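# Cam-file layout assumed by this parser (MVSNet-style text format):
#   line 0:    'extrinsic'
#   lines 1-4: 4x4 world-to-camera extrinsic matrix
#   line 6:    'intrinsic'
#   lines 7-9: 3x3 intrinsic matrix
#   line 11:   depth_min depth_interval [num_depth depth_max]
# Depths and the camera translation are rescaled per scan so that depth_min maps to 100.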
def read_depth_mask(self, scan, filename, depth_min, depth_max, scale):
depth = np.array(read_pfm(filename)[0], dtype=np.float32)
# depth = (depth * self.scale_factor) * scale
depth = (depth * self.scale_factors[scan]) * scale
# depth = depth * scale
# depth = np.squeeze(depth,2)
mask = (depth>=depth_min) & (depth<=depth_max)
assert mask.sum() > 0
mask = mask.astype(np.float32)
if self.img_wh is not None:
depth = cv2.resize(depth, self.img_wh,
interpolation=cv2.INTER_NEAREST)
h, w = depth.shape
depth_ms = {}
mask_ms = {}
for i in range(4):
depth_cur = cv2.resize(depth, (w//(2**i), h//(2**i)), interpolation=cv2.INTER_NEAREST)
mask_cur = cv2.resize(mask, (w//(2**i), h//(2**i)), interpolation=cv2.INTER_NEAREST)
depth_ms[f"stage{4-i}"] = depth_cur
mask_ms[f"stage{4-i}"] = mask_cur
return depth_ms, mask_ms
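# Note: stage indices grow with resolution -- stage4 is the full (resized)
# depth/mask and stage1 the 1/8-scale version, matching the coarse-to-fine cascade.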
def read_img(self, filename):
img = Image.open(filename)
# img = self.color_augment(img)
# scale 0~255 to 0~1
np_img = np.array(img, dtype=np.float32) / 255.
return np_img
def __len__(self):
return len(self.metas)
def __getitem__(self, idx):
meta = self.metas[idx]
scan, ref_view, src_views = meta
if self.robust_train:
num_src_views = len(src_views)
index = random.sample(range(num_src_views), self.nviews - 1)
view_ids = [ref_view] + [src_views[i] for i in index]
scale = random.uniform(0.8, 1.25)
else:
view_ids = [ref_view] + src_views[:self.nviews - 1]
scale = 1
imgs = []
mask = None
depth = None
depth_min = None
depth_max = None
proj={}
proj_matrices_0 = []
proj_matrices_1 = []
proj_matrices_2 = []
proj_matrices_3 = []
for i, vid in enumerate(view_ids):
img_filename = os.path.join(self.datapath, '{}/blended_images/{:0>8}.jpg'.format(scan, vid))
depth_filename = os.path.join(self.datapath, '{}/rendered_depth_maps/{:0>8}.pfm'.format(scan, vid))
proj_mat_filename = os.path.join(self.datapath, '{}/cams/{:0>8}_cam.txt'.format(scan, vid))
img = self.read_img(img_filename)
imgs.append(img.transpose(2,0,1))
intrinsics, extrinsics, depth_min_, depth_max_ = self.read_cam_file(scan, proj_mat_filename)
# proj_mat_filename = os.path.join(self.datapath, 'Cameras/train/{:0>8}_cam.txt').format(vid)
proj_mat_0 = np.zeros(shape=(2, 4, 4), dtype=np.float32)
proj_mat_1 = np.zeros(shape=(2, 4, 4), dtype=np.float32)
proj_mat_2 = np.zeros(shape=(2, 4, 4), dtype=np.float32)
proj_mat_3 = np.zeros(shape=(2, 4, 4), dtype=np.float32)
extrinsics[:3, 3] *= scale
intrinsics[:2,:] *= 0.125
proj_mat_0[0,:4,:4] = extrinsics.copy()
proj_mat_0[1,:3,:3] = intrinsics.copy()
intrinsics[:2,:] *= 2
proj_mat_1[0,:4,:4] = extrinsics.copy()
proj_mat_1[1,:3,:3] = intrinsics.copy()
intrinsics[:2,:] *= 2
proj_mat_2[0,:4,:4] = extrinsics.copy()
proj_mat_2[1,:3,:3] = intrinsics.copy()
intrinsics[:2,:] *= 2
proj_mat_3[0,:4,:4] = extrinsics.copy()
proj_mat_3[1,:3,:3] = intrinsics.copy()
proj_matrices_0.append(proj_mat_0)
proj_matrices_1.append(proj_mat_1)
proj_matrices_2.append(proj_mat_2)
proj_matrices_3.append(proj_mat_3)
if i == 0: # reference view
depth_min = depth_min_ * scale
depth_max = depth_max_ * scale
depth, mask = self.read_depth_mask(scan, depth_filename, depth_min, depth_max, scale)
for l in range(self.levels):
mask[f'stage{l+1}'] = mask[f'stage{l+1}'] # np.expand_dims(mask[f'stage{l+1}'],2)
depth[f'stage{l+1}'] = depth[f'stage{l+1}']
proj['stage1'] = np.stack(proj_matrices_0)
proj['stage2'] = np.stack(proj_matrices_1)
proj['stage3'] = np.stack(proj_matrices_2)
proj['stage4'] = np.stack(proj_matrices_3)
# check_invalid_input(imgs, depth, mask, depth_min, depth_max)
# data is numpy array
return {"imgs": imgs, # [Nv, 3, H, W]
"proj_matrices": proj, # [N,2,4,4]
"depth": depth, # [1, H, W]
"depth_values": np.array([depth_min, depth_max], dtype=np.float32),
"mask": mask} # [1, H, W]
================================================
FILE: datasets/data_io.py
================================================
import numpy as np
import re
import sys
def read_pfm(filename):
file = open(filename, 'rb')
color = None
width = None
height = None
scale = None
endian = None
header = file.readline().decode('utf-8').rstrip()
if header == 'PF':
color = True
elif header == 'Pf':
color = False
else:
raise Exception('Not a PFM file.')
dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8'))
if dim_match:
width, height = map(int, dim_match.groups())
else:
raise Exception('Malformed PFM header.')
scale = float(file.readline().rstrip())
if scale < 0: # little-endian
endian = '<'
scale = -scale
else:
endian = '>' # big-endian
data = np.fromfile(file, endian + 'f')
shape = (height, width, 3) if color else (height, width)
data = np.reshape(data, shape)
data = np.flipud(data)
file.close()
return data, scale
def save_pfm(filename, image, scale=1):
file = open(filename, "wb")
color = None
image = np.flipud(image)
if image.dtype.name != 'float32':
raise Exception('Image dtype must be float32.')
if len(image.shape) == 3 and image.shape[2] == 3: # color image
color = True
elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale
color = False
else:
raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.')
file.write('PF\n'.encode('utf-8') if color else 'Pf\n'.encode('utf-8'))
file.write('{} {}\n'.format(image.shape[1], image.shape[0]).encode('utf-8'))
endian = image.dtype.byteorder
if endian == '<' or endian == '=' and sys.byteorder == 'little':
scale = -scale
file.write(('%f\n' % scale).encode('utf-8'))
image.tofile(file)
file.close()
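# Round-trip sketch (illustrative): save_pfm writes float32 maps that read_pfm
# loads back, e.g.
#   save_pfm('depth.pfm', depth.astype(np.float32))
#   depth_loaded, scale = read_pfm('depth.pfm')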
import random, cv2
class RandomCrop(object):
def __init__(self, CropSize=0.1):
self.CropSize = CropSize
def __call__(self, image, normal):
h, w = normal.shape[:2]
img_h, img_w = image.shape[:2]
CropSize_w, CropSize_h = max(1, int(w * self.CropSize)), max(1, int(h * self.CropSize))
x1, y1 = random.randint(0, CropSize_w), random.randint(0, CropSize_h)
x2, y2 = random.randint(w - CropSize_w, w), random.randint(h - CropSize_h, h)
normal_crop = normal[y1:y2, x1:x2]
normal_resize = cv2.resize(normal_crop, (w, h), interpolation=cv2.INTER_NEAREST)
image_crop = image[4*y1:4*y2, 4*x1:4*x2]
image_resize = cv2.resize(image_crop, (img_w, img_h), interpolation=cv2.INTER_LINEAR)
# import matplotlib.pyplot as plt
# plt.subplot(2, 3, 1)
# plt.imshow(image)
# plt.subplot(2, 3, 2)
# plt.imshow(image_crop)
# plt.subplot(2, 3, 3)
# plt.imshow(image_resize)
#
# plt.subplot(2, 3, 4)
# plt.imshow((normal + 1.0) / 2, cmap="rainbow")
# plt.subplot(2, 3, 5)
# plt.imshow((normal_crop + 1.0) / 2, cmap="rainbow")
# plt.subplot(2, 3, 6)
# plt.imshow((normal_resize + 1.0) / 2, cmap="rainbow")
# plt.show()
# plt.pause(1)
# plt.close()
return image_resize, normal_resize
================================================
FILE: datasets/dtu_yao4.py
================================================
from torch.utils.data import Dataset
import numpy as np
import os, cv2, time, math, random
from PIL import Image
from datasets.data_io import *
from torchvision import transforms
# the DTU dataset preprocessed by Yao Yao (only for training)
class MVSDataset(Dataset):
def __init__(self, datapath, listfile, mode, nviews, interval_scale=1.06, **kwargs):
super(MVSDataset, self).__init__()
self.datapath = datapath
self.listfile = listfile
self.mode = mode
self.nviews = nviews
self.ndepths = 192 # Hardcode
self.interval_scale = interval_scale
self.kwargs = kwargs
self.rt = kwargs.get("rt", False)
self.use_raw_train = kwargs.get("use_raw_train", False)
self.color_augment = transforms.ColorJitter(brightness=0.5, contrast=0.5)
assert self.mode in ["train", "val", "test"]
self.metas = self.build_list()
def build_list(self):
metas = []
with open(self.listfile) as f:
scans = f.readlines()
scans = [line.rstrip() for line in scans]
# scans
for scan in scans:
pair_file = "Cameras/pair.txt"
# read the pair file
with open(os.path.join(self.datapath, pair_file)) as f:
num_viewpoint = int(f.readline())
# viewpoints (49)
for view_idx in range(num_viewpoint):
ref_view = int(f.readline().rstrip())
src_views = [int(x) for x in f.readline().rstrip().split()[1::2]]
# light conditions 0-6
for light_idx in range(7):
metas.append((scan, light_idx, ref_view, src_views))
# print("dataset", self.mode, "metas:", len(metas))
return metas
def __len__(self):
return len(self.metas)
def read_cam_file(self, filename):
with open(filename) as f:
lines = f.readlines()
lines = [line.rstrip() for line in lines]
# extrinsics: line [1,5), 4x4 matrix
extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4))
# intrinsics: line [7-10), 3x3 matrix
intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3))
# depth_min & depth_interval: line 11
depth_min = float(lines[11].split()[0])
depth_interval = float(lines[11].split()[1]) * self.interval_scale
return intrinsics, extrinsics, depth_min, depth_interval
def read_img(self, filename):
img = Image.open(filename)
if self.mode == 'train':
img = self.color_augment(img)
# scale 0~255 to 0~1
np_img = np.array(img, dtype=np.float32) / 255.
return np_img
def crop_img(self, img):
raw_h, raw_w = img.shape[:2]
start_h = (raw_h-1024)//2
start_w = (raw_w-1280)//2
return img[start_h:start_h+1024, start_w:start_w+1280, :] # 1024, 1280, C
def prepare_img(self, hr_img):
h, w = hr_img.shape
if not self.use_raw_train:
# w1600 x h1200 -> downsample 1/2 -> 800 x 600 -> center crop -> 640 x 512 (stage1 is a further 1/4 -> 160 x 128)
# downsample
hr_img_ds = cv2.resize(hr_img, (w//2, h//2), interpolation=cv2.INTER_NEAREST)
h, w = hr_img_ds.shape
target_h, target_w = 512, 640
start_h, start_w = (h - target_h)//2, (w - target_w)//2
hr_img_crop = hr_img_ds[start_h: start_h + target_h, start_w: start_w + target_w]
elif self.use_raw_train:
hr_img_crop = hr_img[h//2-1024//2:h//2+1024//2, w//2-1280//2:w//2+1280//2] # 1024, 1280, c
return hr_img_crop
def read_mask_hr(self, filename):
img = Image.open(filename)
np_img = np.array(img, dtype=np.float32)
np_img = (np_img > 10).astype(np.float32)
np_img = self.prepare_img(np_img)
h, w = np_img.shape
np_img_ms = {
"stage1": cv2.resize(np_img, (w//8, h//8), interpolation=cv2.INTER_NEAREST),
"stage2": cv2.resize(np_img, (w//4, h//4), interpolation=cv2.INTER_NEAREST),
"stage3": cv2.resize(np_img, (w//2, h//2), interpolation=cv2.INTER_NEAREST),
"stage4": np_img,
}
return np_img_ms
def read_depth_hr(self, filename, scale):
# read pfm depth file
# w1600 x h1200 -> downsample 1/2 -> 800 x 600 -> center crop -> 640 x 512 (stage1 is a further 1/4 -> 160 x 128)
depth_hr = np.array(read_pfm(filename)[0], dtype=np.float32) * scale
depth_lr = self.prepare_img(depth_hr)
h, w = depth_lr.shape
depth_lr_ms = {
"stage1": cv2.resize(depth_lr, (w//8, h//8), interpolation=cv2.INTER_NEAREST),
"stage2": cv2.resize(depth_lr, (w//4, h//4), interpolation=cv2.INTER_NEAREST),
"stage3": cv2.resize(depth_lr, (w//2, h//2), interpolation=cv2.INTER_NEAREST),
"stage4": depth_lr,
}
return depth_lr_ms
def __getitem__(self, idx):
meta = self.metas[idx]
scan, light_idx, ref_view, src_views = meta
# use only the reference view and first nviews-1 source views
if self.mode == 'train' and self.rt:
num_src_views = len(src_views)
index = random.sample(range(num_src_views), self.nviews - 1)
view_ids = [ref_view] + [src_views[i] for i in index]
scale = random.uniform(0.8, 1.25)
else:
view_ids = [ref_view] + src_views[:self.nviews - 1]
scale = 1
imgs = []
mask = None
depth_values = None
proj_matrices = []
for i, vid in enumerate(view_ids):
# NOTE that the id in image file names is from 1 to 49 (not 0~48)
if not self.use_raw_train:
img_filename = os.path.join(self.datapath, 'Rectified/{}_train/rect_{:0>3}_{}_r5000.png'.format(scan, vid + 1, light_idx))
else:
img_filename = os.path.join(self.datapath, 'Rectified_raw/{}/rect_{:0>3}_{}_r5000.png'.format(scan, vid + 1, light_idx))
mask_filename_hr = os.path.join(self.datapath, 'Depths_raw/{}/depth_visual_{:0>4}.png'.format(scan, vid))
depth_filename_hr = os.path.join(self.datapath, 'Depths_raw/{}/depth_map_{:0>4}.pfm'.format(scan, vid))
proj_mat_filename = os.path.join(self.datapath, 'Cameras/train/{:0>8}_cam.txt').format(vid)
img = self.read_img(img_filename)
if self.use_raw_train:
img = self.crop_img(img)
intrinsics, extrinsics, depth_min, depth_interval = self.read_cam_file(proj_mat_filename)
if self.rt:
extrinsics[:3,3] *= scale
if self.use_raw_train:
intrinsics[:2, :] *= 2.0
if i == 0:
mask_read_ms = self.read_mask_hr(mask_filename_hr)
depth_ms = self.read_depth_hr(depth_filename_hr, scale)
#get depth values
depth_max = depth_interval * self.ndepths + depth_min
depth_values = np.array([depth_min * scale, depth_max * scale], dtype=np.float32)
mask = mask_read_ms
proj_mat = np.zeros(shape=(2, 4, 4), dtype=np.float32) #
proj_mat[0, :4, :4] = extrinsics
proj_mat[1, :3, :3] = intrinsics
proj_matrices.append(proj_mat)
imgs.append(img.transpose(2,0,1))
#all
# imgs = np.stack(imgs).transpose([0, 3, 1, 2])
#ms proj_mats
proj_matrices = np.stack(proj_matrices)
stage1_pjmats = proj_matrices.copy()
stage1_pjmats[:, 1, :2, :] = proj_matrices[:, 1, :2, :] / 2.0
stage3_pjmats = proj_matrices.copy()
stage3_pjmats[:, 1, :2, :] = proj_matrices[:, 1, :2, :] * 2
stage4_pjmats = proj_matrices.copy()
stage4_pjmats[:, 1, :2, :] = proj_matrices[:, 1, :2, :] * 4
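# The intrinsics read from Cameras/train correspond to 1/4 of the (cropped) image
# resolution, i.e. the stage-2 feature scale: stage1 halves them, while stage3 and
# stage4 scale them up to the 1/2- and full-resolution feature maps.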
proj_matrices_ms = {
"stage1": stage1_pjmats,
"stage2": proj_matrices,
"stage3": stage3_pjmats,
"stage4": stage4_pjmats
}
return {"imgs": imgs, # Nv C H W
"proj_matrices": proj_matrices_ms, # 4 stage of Nv 2 4 4
"depth": depth_ms,
"depth_values": depth_values,
"mask": mask }
================================================
FILE: datasets/eth3d.py
================================================
from torch.utils.data import Dataset
from datasets.data_io import *
import os
import numpy as np
import cv2
from PIL import Image
class MVSDataset(Dataset):
def __init__(self, datapath, split='test', n_views=7, img_wh=(1920,1280)):
self.levels = 4
self.datapath = datapath
self.img_wh = img_wh
self.split = split
self.build_metas()
self.n_views = n_views
def build_metas(self):
self.metas = []
if self.split == "test":
self.scans = ['botanical_garden', 'boulders', 'bridge', 'door',
'exhibition_hall', 'lecture_room', 'living_room', 'lounge',
'observatory', 'old_computer', 'statue', 'terrace_2']
elif self.split == "train":
self.scans = ['courtyard', 'delivery_area', 'electro', 'facade',
'kicker', 'meadow', 'office', 'pipes', 'playground',
'relief', 'relief_2', 'terrace', 'terrains']
for scan in self.scans:
with open(os.path.join(self.datapath, scan, 'pair.txt')) as f:
num_viewpoint = int(f.readline())
for view_idx in range(num_viewpoint):
ref_view = int(f.readline().rstrip())
src_views = [int(x) for x in f.readline().rstrip().split()[1::2]]
if len(src_views) != 0:
self.metas += [(scan, -1, ref_view, src_views)]
def read_cam_file(self, filename):
with open(filename) as f:
lines = [line.rstrip() for line in f.readlines()]
# extrinsics: line [1,5), 4x4 matrix
extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ')
extrinsics = extrinsics.reshape((4, 4))
# intrinsics: line [7-10), 3x3 matrix
intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ')
intrinsics = intrinsics.reshape((3, 3))
depth_min = float(lines[11].split()[0])
if depth_min < 0:
depth_min = 1
depth_max = float(lines[11].split()[-1])
return intrinsics, extrinsics, depth_min, depth_max
def read_img(self, filename):
img = Image.open(filename)
np_img = np.array(img, dtype=np.float32) / 255.
original_h, original_w, _ = np_img.shape
np_img = cv2.resize(np_img, self.img_wh, interpolation=cv2.INTER_LINEAR)
return np_img, original_h, original_w
def __len__(self):
return len(self.metas)
def __getitem__(self, idx):
scan, _, ref_view, src_views = self.metas[idx]
# use only the reference view and first nviews-1 source views
view_ids = [ref_view] + src_views[:self.n_views-1]
imgs = []
# depth = None
depth_min = None
depth_max = None
proj_matrices_0 = []
proj_matrices_1 = []
proj_matrices_2 = []
proj_matrices_3 = []
for i, vid in enumerate(view_ids):
img_filename = os.path.join(self.datapath, scan, f'images/{vid:08d}.jpg')
proj_mat_filename = os.path.join(self.datapath, scan, f'cams_1/{vid:08d}_cam.txt')
img, original_h, original_w = self.read_img(img_filename)
intrinsics, extrinsics, depth_min_, depth_max_ = self.read_cam_file(proj_mat_filename)
intrinsics[0] *= self.img_wh[0]/original_w
intrinsics[1] *= self.img_wh[1]/original_h
imgs.append(img.transpose(2,0,1))
proj_mat_0 = np.zeros(shape=(2, 4, 4), dtype=np.float32)
proj_mat_1 = np.zeros(shape=(2, 4, 4), dtype=np.float32)
proj_mat_2 = np.zeros(shape=(2, 4, 4), dtype=np.float32)
proj_mat_3 = np.zeros(shape=(2, 4, 4), dtype=np.float32)
intrinsics[:2,:] *= 0.125
proj_mat_0[0,:4,:4] = extrinsics.copy()
proj_mat_0[1,:3,:3] = intrinsics.copy()
intrinsics[:2,:] *= 2
proj_mat_1[0,:4,:4] = extrinsics.copy()
proj_mat_1[1,:3,:3] = intrinsics.copy()
intrinsics[:2,:] *= 2
proj_mat_2[0,:4,:4] = extrinsics.copy()
proj_mat_2[1,:3,:3] = intrinsics.copy()
intrinsics[:2,:] *= 2
proj_mat_3[0,:4,:4] = extrinsics.copy()
proj_mat_3[1,:3,:3] = intrinsics.copy()
proj_matrices_0.append(proj_mat_0)
proj_matrices_1.append(proj_mat_1)
proj_matrices_2.append(proj_mat_2)
proj_matrices_3.append(proj_mat_3)
if i == 0: # reference view
depth_min = depth_min_
depth_max = depth_max_
# proj_matrices: dict of 4 stages, each [N, 2, 4, 4]
proj={}
proj['stage1'] = np.stack(proj_matrices_0)
proj['stage2'] = np.stack(proj_matrices_1)
proj['stage3'] = np.stack(proj_matrices_2)
proj['stage4'] = np.stack(proj_matrices_3)
return {"imgs": imgs, # N*3*H0*W0
"proj_matrices": proj, # N*4*4
"depth_values": np.array([depth_min, depth_max], dtype=np.float32),
"filename": scan + '/{}/' + '{:0>8}'.format(view_ids[0]) + "{}"
}
================================================
FILE: datasets/general_eval4.py
================================================
from torch.utils.data import Dataset
import numpy as np
import os, cv2, time
from PIL import Image
from datasets.data_io import *
s_h, s_w = 0, 0
class MVSDataset(Dataset):
def __init__(self, datapath, listfile, mode, nviews, interval_scale=1.06, **kwargs):
super(MVSDataset, self).__init__()
self.datapath = datapath
self.listfile = listfile
self.mode = mode
self.nviews = nviews
self.ndepths = 192 # Hardcode
self.interval_scale = interval_scale
self.max_h, self.max_w = kwargs["max_h"], kwargs["max_w"]
self.fix_res = kwargs.get("fix_res", False) #whether to fix the resolution of input image.
self.fix_wh = False
assert self.mode == "test"
self.metas = self.build_list()
def build_list(self):
metas = []
scans = self.listfile
interval_scale_dict = {}
# scans
for scan in scans:
# determine the interval scale of each scene. default is 1.06
if isinstance(self.interval_scale, float):
interval_scale_dict[scan] = self.interval_scale
else:
interval_scale_dict[scan] = self.interval_scale[scan]
pair_file = "{}/pair.txt".format(scan)
# read the pair file
with open(os.path.join(self.datapath, pair_file)) as f:
num_viewpoint = int(f.readline())
# viewpoints
for view_idx in range(num_viewpoint):
ref_view = int(f.readline().rstrip())
src_views = [int(x) for x in f.readline().rstrip().split()[1::2]]
# filter by no src view and fill to nviews
if len(src_views) > 0:
if len(src_views) < self.nviews:
print("{}< num_views:{}".format(len(src_views), self.nviews))
src_views += [src_views[0]] * (self.nviews - len(src_views))
metas.append((scan, ref_view, src_views, scan))
self.interval_scale = interval_scale_dict
print("dataset", self.mode, "metas:", len(metas), "interval_scale:{}".format(self.interval_scale))
return metas
def __len__(self):
return len(self.metas)
def read_cam_file(self, filename, interval_scale):
with open(filename) as f:
lines = f.readlines()
lines = [line.rstrip() for line in lines]
# extrinsics: line [1,5), 4x4 matrix
extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4))
# intrinsics: line [7-10), 3x3 matrix
intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3))
intrinsics[:2, :] /= 4.0
# depth_min & depth_interval: line 11
depth_min = float(lines[11].split()[0])
depth_interval = float(lines[11].split()[1])
if len(lines[11].split()) >= 3:
num_depth = lines[11].split()[2]
depth_max = depth_min + int(float(num_depth)) * depth_interval
depth_interval = (depth_max - depth_min) / self.ndepths
depth_interval *= interval_scale
return intrinsics, extrinsics, depth_min, depth_interval
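# When line 11 also provides num_depth (and hence depth_max), the interval is
# re-derived above so that self.ndepths (192) samples still span
# [depth_min, depth_max] before interval_scale is applied.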
def read_img(self, filename):
img = Image.open(filename)
# scale 0~255 to 0~1
np_img = np.array(img, dtype=np.float32) / 255.
return np_img
def read_depth(self, filename):
# read pfm depth file
return np.array(read_pfm(filename)[0], dtype=np.float32)
def scale_mvs_input(self, img, intrinsics, max_w, max_h, base=64):
h, w = img.shape[:2]
if h > max_h or w > max_w:
scale = 1.0 * max_h / h
if scale * w > max_w:
scale = 1.0 * max_w / w
new_w, new_h = scale * w // base * base, scale * h // base * base
else:
new_w, new_h = 1.0 * w // base * base, 1.0 * h // base * base
scale_w = 1.0 * new_w / w
scale_h = 1.0 * new_h / h
intrinsics[0, :] *= scale_w
intrinsics[1, :] *= scale_h
img = cv2.resize(img, (int(new_w), int(new_h)))
return img, intrinsics
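# Example (illustrative): a 1200x1600 image with max_h=1024, max_w=1536 is scaled
# by 1024/1200 and snapped down to multiples of base=64, giving 1024x1344; the
# intrinsics are rescaled by the same per-axis factors.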
def __getitem__(self, idx):
global s_h, s_w
meta = self.metas[idx]
scan, ref_view, src_views, scene_name = meta
# use only the reference view and first nviews-1 source views
view_ids = [ref_view] + src_views[:self.nviews - 1]
imgs = []
depth_values = None
proj_matrices = []
for i, vid in enumerate(view_ids):
img_filename = os.path.join(self.datapath, '{}/images_post/{:0>8}.jpg'.format(scan, vid))
if not os.path.exists(img_filename):
img_filename = os.path.join(self.datapath, '{}/images/{:0>8}.jpg'.format(scan, vid))
proj_mat_filename = os.path.join(self.datapath, '{}/cams/{:0>8}_cam.txt'.format(scan, vid))
img = self.read_img(img_filename)
intrinsics, extrinsics, depth_min, depth_interval = self.read_cam_file(proj_mat_filename, interval_scale=
self.interval_scale[scene_name])
# scale input
img, intrinsics = self.scale_mvs_input(img, intrinsics, self.max_w, self.max_h)
if self.fix_res:
# using the same standard height or width in entire scene.
s_h, s_w = img.shape[:2]
self.fix_res = False
self.fix_wh = True
if i == 0:
if not self.fix_wh:
# using the same standard height or width in each nviews.
s_h, s_w = img.shape[:2]
# resize to standard height or width
c_h, c_w = img.shape[:2]
if (c_h != s_h) or (c_w != s_w):
scale_h = 1.0 * s_h / c_h
scale_w = 1.0 * s_w / c_w
img = cv2.resize(img, (s_w, s_h))
intrinsics[0, :] *= scale_w
intrinsics[1, :] *= scale_h
imgs.append(img.transpose(2,0,1))
# extrinsics, intrinsics
proj_mat = np.zeros(shape=(2, 4, 4), dtype=np.float32) #
proj_mat[0, :4, :4] = extrinsics
proj_mat[1, :3, :3] = intrinsics
proj_matrices.append(proj_mat)
if i == 0: # reference view
depth_values = np.arange(depth_min, depth_interval * (self.ndepths - 0.5) + depth_min, depth_interval,
dtype=np.float32)
#all
# imgs = np.stack(imgs).transpose([0, 3, 1, 2])
#ms proj_mats
proj_matrices = np.stack(proj_matrices)
stage1_pjmats = proj_matrices.copy()
stage1_pjmats[:, 1, :2, :] = proj_matrices[:, 1, :2, :] / 2.0
stage3_pjmats = proj_matrices.copy()
stage3_pjmats[:, 1, :2, :] = proj_matrices[:, 1, :2, :] * 2
stage4_pjmats = proj_matrices.copy()
stage4_pjmats[:, 1, :2, :] = proj_matrices[:, 1, :2, :] * 4
proj_matrices_ms = {
"stage1": stage1_pjmats,
"stage2": proj_matrices,
"stage3": stage3_pjmats,
"stage4": stage4_pjmats
}
return {"imgs": imgs,
"proj_matrices": proj_matrices_ms,
"depth_values": depth_values,
"filename": scan + '/{}/' + '{:0>8}'.format(view_ids[0]) + "{}"}
================================================
FILE: datasets/tanks.py
================================================
from torch.utils.data import Dataset
from datasets.data_io import *
import os
import numpy as np
import cv2
from PIL import Image
class MVSDataset(Dataset):
def __init__(self, datapath, n_views=7, split='intermediate'):
self.levels = 4
self.datapath = datapath
self.split = split
self.build_metas()
self.n_views = n_views
def build_metas(self):
self.metas = []
if self.split == 'intermediate':
self.scans = ['Family', 'Francis', 'Horse', 'Playground', 'Train', 'Lighthouse', 'M60', 'Panther']
elif self.split == 'advanced':
self.scans = ['Auditorium', 'Ballroom', 'Courtroom',
'Museum', 'Palace', 'Temple']
for scan in self.scans:
with open(os.path.join(self.datapath, self.split, scan, 'pair.txt')) as f:
num_viewpoint = int(f.readline())
for view_idx in range(num_viewpoint):
ref_view = int(f.readline().rstrip())
src_views = [int(x) for x in f.readline().rstrip().split()[1::2]]
if len(src_views) != 0:
self.metas += [(scan, -1, ref_view, src_views)]
def read_cam_file(self, filename):
with open(filename) as f:
lines = [line.rstrip() for line in f.readlines()]
# extrinsics: line [1,5), 4x4 matrix
extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ')
extrinsics = extrinsics.reshape((4, 4))
# intrinsics: line [7-10), 3x3 matrix
intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ')
intrinsics = intrinsics.reshape((3, 3))
depth_min = float(lines[11].split()[0])
depth_max = float(lines[11].split()[-1])
return intrinsics, extrinsics, depth_min, depth_max
def read_img(self, filename):
img = Image.open(filename)
np_img = np.array(img, dtype=np.float32) / 255.
return np_img
def scale_input(self, intrinsics, img):
"""
intrinsics: 3x3
img: H x W x C
"""
intrinsics[1,2] = intrinsics[1,2] - 28 # 1080 -> 1024
img = img[28:1080-28, :, :]
return intrinsics, img
def __len__(self):
return len(self.metas)
def __getitem__(self, idx):
scan, _, ref_view, src_views = self.metas[idx]
# use only the reference view and first nviews-1 source views
view_ids = [ref_view] + src_views[:self.n_views-1]
imgs = []
# depth = None
depth_min = None
depth_max = None
proj_matrices_0 = []
proj_matrices_1 = []
proj_matrices_2 = []
proj_matrices_3 = []
for i, vid in enumerate(view_ids):
img_filename = os.path.join(self.datapath, self.split, scan, f'images/{vid:08d}.jpg')
proj_mat_filename = os.path.join(self.datapath, self.split, scan, f'cams/{vid:08d}_cam.txt')
img = self.read_img(img_filename)
intrinsics, extrinsics, depth_min_, depth_max_ = self.read_cam_file(proj_mat_filename)
intrinsics, img = self.scale_input(intrinsics, img)
imgs.append(img.transpose(2,0,1))
proj_mat_0 = np.zeros(shape=(2, 4, 4), dtype=np.float32)
proj_mat_1 = np.zeros(shape=(2, 4, 4), dtype=np.float32)
proj_mat_2 = np.zeros(shape=(2, 4, 4), dtype=np.float32)
proj_mat_3 = np.zeros(shape=(2, 4, 4), dtype=np.float32)
intrinsics[:2,:] *= 0.125
proj_mat_0[0,:4,:4] = extrinsics.copy()
proj_mat_0[1,:3,:3] = intrinsics.copy()
intrinsics[:2,:] *= 2
proj_mat_1[0,:4,:4] = extrinsics.copy()
proj_mat_1[1,:3,:3] = intrinsics.copy()
intrinsics[:2,:] *= 2
proj_mat_2[0,:4,:4] = extrinsics.copy()
proj_mat_2[1,:3,:3] = intrinsics.copy()
intrinsics[:2,:] *= 2
proj_mat_3[0,:4,:4] = extrinsics.copy()
proj_mat_3[1,:3,:3] = intrinsics.copy()
proj_matrices_0.append(proj_mat_0)
proj_matrices_1.append(proj_mat_1)
proj_matrices_2.append(proj_mat_2)
proj_matrices_3.append(proj_mat_3)
if i == 0: # reference view
depth_min = depth_min_
depth_max = depth_max_
# proj_matrices: dict of 4 stages, each [N, 2, 4, 4]
proj={}
proj['stage1'] = np.stack(proj_matrices_0)
proj['stage2'] = np.stack(proj_matrices_1)
proj['stage3'] = np.stack(proj_matrices_2)
proj['stage4'] = np.stack(proj_matrices_3)
return {"imgs": imgs, # N*3*H0*W0
"proj_matrices": proj, # N*4*4
"depth_values": np.array([depth_min, depth_max], dtype=np.float32),
"filename": scan + '/{}/' + '{:0>8}'.format(view_ids[0]) + "{}"
}
================================================
FILE: evaluations/dtu/BaseEval2Obj_web.m
================================================
function BaseEval2Obj_web(BaseEval,method_string,outputPath)
if(nargin<3)
outputPath='./';
end
% threshold for coloring alpha channel in the range of 0-10 mm
dist_tresshold=10;
cSet=BaseEval.cSet;
Qdata=BaseEval.Qdata;
alpha=min(BaseEval.Ddata,dist_tresshold)/dist_tresshold;
fid=fopen([outputPath method_string '2Stl_' num2str(cSet) '.obj'],'w+');
for cP=1:size(Qdata,2)
if(BaseEval.DataInMask(cP))
C=[1 0 0]*alpha(cP)+[1 1 1]*(1-alpha(cP)); %coloring from red to white in the range of 0-10 mm (0 to dist_tresshold)
else
C=[0 1 0]*alpha(cP)+[0 0 1]*(1-alpha(cP)); %green to blue for points outside the mask (which are not included in the analysis)
end
fprintf(fid,'v %f %f %f %f %f %f\n',[Qdata(1,cP) Qdata(2,cP) Qdata(3,cP) C(1) C(2) C(3)]);
end
fclose(fid);
disp('Data2Stl saved as obj')
Qstl=BaseEval.Qstl;
fid=fopen([outputPath 'Stl2' method_string '_' num2str(cSet) '.obj'],'w+');
alpha=min(BaseEval.Dstl,dist_tresshold)/dist_tresshold;
for cP=1:size(Qstl,2)
if(BaseEval.StlAbovePlane(cP))
C=[1 0 0]*alpha(cP)+[1 1 1]*(1-alpha(cP)); %coloring from red to white in the range of 0-10 mm (0 to dist_tresshold)
else
C=[0 1 0]*alpha(cP)+[0 0 1]*(1-alpha(cP)); %green to blue for points below plane (which are not included in the analysis)
end
fprintf(fid,'v %f %f %f %f %f %f\n',[Qstl(1,cP) Qstl(2,cP) Qstl(3,cP) C(1) C(2) C(3)]);
end
fclose(fid);
disp('Stl2Data saved as obj')
================================================
FILE: evaluations/dtu/BaseEvalMain_func.m
================================================
function BaseEvalMain_func(plyPath)
% clear all
% close all
format compact
% script to calculate distances for all included scans (UsedSets)
dataPath='/mnt/cfs/algorithm/public_data/mvs/dtu_evalset/SampleSet/MVS Data';
% pred_results='cascade_hr/48-32-8_4-2-1_dlossw-0.5-1.0-2.0_chs888/gipuma_4_0.9/';
% plyPath=['../../outputs/1101/dtu/' pred_results];
% plyPath = '/mnt/cfs/algorithm/xiaofeng.wang/jeff/code/MVS/cascade-stereo/CasMVSNet/outputs/repo_model_aligncorners_ITGT'
resultsPath=[plyPath '/eval_out/'];
disp(resultsPath);
mkdir(resultsPath);
method_string='mvsnet';
light_string='l3'; % l3 is the setting with all lights on, l7 is randomly sampled between the 7 settings (index 0-6)
representation_string='Points'; %mvs representation 'Points' or 'Surfaces'
switch representation_string
case 'Points'
eval_string='_Eval_'; %results naming
settings_string='';
end
% get sets used in evaluation
UsedSets=[1 4 9 10 11 12 13 15 23 24 29 32 33 34 48 49 62 75 77 110 114 118];
% UsedSets=[15];
dst=0.2; %Min dist between points when reducing
parfor cIdx=1:length(UsedSets)
%Data set number
cSet = UsedSets(cIdx)
%input data name
DataInName=[plyPath sprintf('/%s%03d_%s%s.ply',lower(method_string),cSet,light_string,settings_string)]
%results name
EvalName=[resultsPath method_string eval_string num2str(cSet) '.mat']
%check if file is already computed
if(~exist(EvalName,'file'))
disp(DataInName);
time=clock;time(4:5), drawnow
tic
Mesh = plyread(DataInName);
Qdata=[Mesh.vertex.x Mesh.vertex.y Mesh.vertex.z]';
toc
BaseEval=PointCompareMain(cSet,Qdata,dst,dataPath);
disp('Saving results'), drawnow
toc
mySave(EvalName, BaseEval);
toc
% write obj-file of evaluation
% BaseEval2Obj_web(BaseEval,method_string, resultsPath)
% toc
time=clock;time(4:5), drawnow
BaseEval.MaxDist=20; %outlier threshold of 20 mm
BaseEval.FilteredDstl=BaseEval.Dstl(BaseEval.StlAbovePlane); %use only points that are above the plane
BaseEval.FilteredDstl=BaseEval.FilteredDstl(BaseEval.FilteredDstl<BaseEval.MaxDist); % discard outliers
BaseEval.FilteredDdata=BaseEval.Ddata(BaseEval.DataInMask); %use only points that within mask
BaseEval.FilteredDdata=BaseEval.FilteredDdata(BaseEval.FilteredDdata<BaseEval.MaxDist); % discard outliers
fprintf("mean/median Data (acc.) %f/%f\n", mean(BaseEval.FilteredDdata), median(BaseEval.FilteredDdata));
fprintf("mean/median Stl (comp.) %f/%f\n", mean(BaseEval.FilteredDstl), median(BaseEval.FilteredDstl));
end
end
end
function mySave(filenm, data)
save(filenm, 'data');
end
================================================
FILE: evaluations/dtu/BaseEvalMain_web.m
================================================
clear all
close all
format compact
clc
% script to calculate distances for all included scans (UsedSets)
dataPath='/mnt/cfs/algorithm/public_data/mvs/dtu_evalset/SampleSet/MVS Data';
% pred_results='cascade_hr/48-32-8_4-2-1_dlossw-0.5-1.0-2.0_chs888/gipuma_4_0.9/';
% plyPath=['../../outputs/1101/dtu/' pred_results];
plyPath = '/mnt/cfs/algorithm/xiaofeng.wang/jeff/code/MVS/cascade-stereo/CasMVSNet/outputs/ccc_4x2_scedule_aligncorners'
resultsPath=[plyPath '/eval_out/'];
disp(resultsPath);
mkdir(resultsPath);
method_string='mvsnet';
light_string='l3'; % l3 is the setting with all lights on, l7 is randomly sampled between the 7 settings (index 0-6)
representation_string='Points'; %mvs representation 'Points' or 'Surfaces'
switch representation_string
case 'Points'
eval_string='_Eval_'; %results naming
settings_string='';
end
% get sets used in evaluation
UsedSets=[1 4 9 10 11 12 13 15 23 24 29 32 33 34 48 49 62 75 77 110 114 118];
% UsedSets=[15];
dst=0.2; %Min dist between points when reducing
parfor cIdx=1:length(UsedSets)
%Data set number
cSet = UsedSets(cIdx)
%input data name
DataInName=[plyPath sprintf('/%s%03d_%s%s.ply',lower(method_string),cSet,light_string,settings_string)]
%results name
EvalName=[resultsPath method_string eval_string num2str(cSet) '.mat']
%check if file is already computed
if(~exist(EvalName,'file'))
disp(DataInName);
time=clock;time(4:5), drawnow
tic
Mesh = plyread(DataInName);
Qdata=[Mesh.vertex.x Mesh.vertex.y Mesh.vertex.z]';
toc
BaseEval=PointCompareMain(cSet,Qdata,dst,dataPath);
disp('Saving results'), drawnow
toc
mySave(EvalName, BaseEval);
toc
% write obj-file of evaluation
% BaseEval2Obj_web(BaseEval,method_string, resultsPath)
% toc
time=clock;time(4:5), drawnow
BaseEval.MaxDist=20; %outlier threshold of 20 mm
BaseEval.FilteredDstl=BaseEval.Dstl(BaseEval.StlAbovePlane); %use only points that are above the plane
BaseEval.FilteredDstl=BaseEval.FilteredDstl(BaseEval.FilteredDstl<BaseEval.MaxDist); % discard outliers
BaseEval.FilteredDdata=BaseEval.Ddata(BaseEval.DataInMask); %use only points that within mask
BaseEval.FilteredDdata=BaseEval.FilteredDdata(BaseEval.FilteredDdata<BaseEval.MaxDist); % discard outliers
fprintf("mean/median Data (acc.) %f/%f\n", mean(BaseEval.FilteredDdata), median(BaseEval.FilteredDdata));
fprintf("mean/median Stl (comp.) %f/%f\n", mean(BaseEval.FilteredDstl), median(BaseEval.FilteredDstl));
end
end
function mySave(filenm, data)
save(filenm, 'data');
end
================================================
FILE: evaluations/dtu/ComputeStat_func.m
================================================
function ComputeStat_func(plyPath)
format compact
% script to calculate the statistics for each scan; this will only run if distances have
% been measured for all included scans (UsedSets)
% modify the path to evaluate your models
dataPath='/mnt/cfs/algorithm/public_data/mvs/dtu_evalset/SampleSet/MVS Data';
% resultsPath=['/mnt/cfs/algorithm/xiaofeng.wang/jeff/code/MVS/cascade-stereo/CasMVSNet/outputs/repo_model_aligncorners_ITGT/eval_out/'];
resultsPath=[plyPath '/eval_out/'];
MaxDist=20; %outlier threshold of 20 mm
time=clock;
method_string='mvsnet';
light_string='l3'; %'l7'; l3 is the setting with all lights on, l7 is randomly sampled between the 7 settings (index 0-6)
representation_string='Points'; %mvs representation 'Points' or 'Surfaces'
switch representation_string
case 'Points'
eval_string='_Eval_'; %results naming
settings_string='';
end
% get sets used in evaluation
UsedSets=[1 4 9 10 11 12 13 15 23 24 29 32 33 34 48 49 62 75 77 110 114 118];
nStat=length(UsedSets);
BaseStat.nStl=zeros(1,nStat);
BaseStat.nData=zeros(1,nStat);
BaseStat.MeanStl=zeros(1,nStat);
BaseStat.MeanData=zeros(1,nStat);
BaseStat.VarStl=zeros(1,nStat);
BaseStat.VarData=zeros(1,nStat);
BaseStat.MedStl=zeros(1,nStat);
BaseStat.MedData=zeros(1,nStat);
for cStat=1:length(UsedSets) %Data set number
currentSet=UsedSets(cStat);
%input results name
EvalName=[resultsPath method_string eval_string num2str(currentSet) '.mat'];
disp(EvalName);
load(EvalName);
Dstl=data.Dstl(data.StlAbovePlane); %use only points that are above the plane
Dstl=Dstl(Dstl<MaxDist); % discard outliers
Ddata=data.Ddata(data.DataInMask); %use only points that within mask
Ddata=Ddata(Ddata<MaxDist); % discard outliers
BaseStat.nStl(cStat)=length(Dstl);
BaseStat.nData(cStat)=length(Ddata);
BaseStat.MeanStl(cStat)=mean(Dstl);
BaseStat.MeanData(cStat)=mean(Ddata);
BaseStat.VarStl(cStat)=var(Dstl);
BaseStat.VarData(cStat)=var(Ddata);
BaseStat.MedStl(cStat)=median(Dstl);
BaseStat.MedData(cStat)=median(Ddata);
disp("acc");
disp(mean(Ddata));
disp("comp");
disp(mean(Dstl));
time=clock;
end
disp(BaseStat);
disp("mean acc")
disp(mean(BaseStat.MeanData));
disp("mean comp")
disp(mean(BaseStat.MeanStl));
disp("mean overall")
disp((mean(BaseStat.MeanStl)+mean(BaseStat.MeanData))/2.0);
totalStatName=[resultsPath 'TotalStat_' method_string eval_string '.mat']
save(totalStatName,'BaseStat','time','MaxDist');
totalStatName=[resultsPath 'TotalStat_' method_string eval_string '.txt']
fp=fopen(totalStatName,'a');
fprintf(fp,'%f\n',mean(BaseStat.MeanData));
fprintf(fp,'%f\n',mean(BaseStat.MeanStl));
fclose(fp);
end
================================================
FILE: evaluations/dtu/ComputeStat_web.m
================================================
clear all
close all
format compact
clc
% script to calculate the statistics for each scan; this will only run if distances have
% been measured for all included scans (UsedSets)
% modify the path to evaluate your models
dataPath='/mnt/cfs/algorithm/public_data/mvs/dtu_evalset/SampleSet/MVS Data';
resultsPath=['/mnt/cfs/algorithm/xiaofeng.wang/jeff/code/MVS/cascade-stereo/CasMVSNet/outputs/repo_model_aligncorners_ITGT/eval_out/'];
MaxDist=20; %outlier threshold of 20 mm
time=clock;
method_string='mvsnet';
light_string='l3'; %'l7'; l3 is the setting with all lights on, l7 is randomly sampled between the 7 settings (index 0-6)
representation_string='Points'; %mvs representation 'Points' or 'Surfaces'
switch representation_string
case 'Points'
eval_string='_Eval_'; %results naming
settings_string='';
end
% get sets used in evaluation
UsedSets=[1 4 9 10 11 12 13 15 23 24 29 32 33 34 48 49 62 75 77 110 114 118];
nStat=length(UsedSets);
BaseStat.nStl=zeros(1,nStat);
BaseStat.nData=zeros(1,nStat);
BaseStat.MeanStl=zeros(1,nStat);
BaseStat.MeanData=zeros(1,nStat);
BaseStat.VarStl=zeros(1,nStat);
BaseStat.VarData=zeros(1,nStat);
BaseStat.MedStl=zeros(1,nStat);
BaseStat.MedData=zeros(1,nStat);
for cStat=1:length(UsedSets) %Data set number
currentSet=UsedSets(cStat);
%input results name
EvalName=[resultsPath method_string eval_string num2str(currentSet) '.mat'];
disp(EvalName);
load(EvalName);
Dstl=data.Dstl(data.StlAbovePlane); %use only points that are above the plane
Dstl=Dstl(Dstl<MaxDist); % discard outliers
Ddata=data.Ddata(data.DataInMask); %use only points that within mask
Ddata=Ddata(Ddata<MaxDist); % discard outliers
BaseStat.nStl(cStat)=length(Dstl);
BaseStat.nData(cStat)=length(Ddata);
BaseStat.MeanStl(cStat)=mean(Dstl);
BaseStat.MeanData(cStat)=mean(Ddata);
BaseStat.VarStl(cStat)=var(Dstl);
BaseStat.VarData(cStat)=var(Ddata);
BaseStat.MedStl(cStat)=median(Dstl);
BaseStat.MedData(cStat)=median(Ddata);
disp("acc");
disp(mean(Ddata));
disp("comp");
disp(mean(Dstl));
time=clock;
end
disp(BaseStat);
disp("mean acc")
disp(mean(BaseStat.MeanData));
disp("mean comp")
disp(mean(BaseStat.MeanStl));
disp("mean overall")
disp((mean(BaseStat.MeanStl)+mean(BaseStat.MeanData))/2.0);
totalStatName=[resultsPath 'TotalStat_' method_string eval_string '.mat']
save(totalStatName,'BaseStat','time','MaxDist');
totalStatName=[resultsPath 'TotalStat_' method_string eval_string '.txt']
fp=fopen(totalStatName,'a');
fprintf(fp,'%f\n',mean(BaseStat.MeanData));
fprintf(fp,'%f\n',mean(BaseStat.MeanStl));
fclose(fp);
================================================
FILE: evaluations/dtu/MaxDistCP.m
================================================
function Dist = MaxDistCP(Qto,Qfrom,BB,MaxDist)
Dist=ones(1,size(Qfrom,2))*MaxDist;
Range=floor((BB(2,:)-BB(1,:))/MaxDist);
tic
Done=0;
LookAt=zeros(1,size(Qfrom,2));
for x=0:Range(1),
for y=0:Range(2),
for z=0:Range(3),
Low=BB(1,:)+[x y z]*MaxDist;
High=Low+MaxDist;
idxF=find(Qfrom(1,:)>=Low(1) & Qfrom(2,:)>=Low(2) & Qfrom(3,:)>=Low(3) &...
Qfrom(1,:)<High(1) & Qfrom(2,:)<High(2) & Qfrom(3,:)<High(3));
SQfrom=Qfrom(:,idxF);
LookAt(idxF)=LookAt(idxF)+1; %Debug
Low=Low-MaxDist;
High=High+MaxDist;
idxT=find(Qto(1,:)>=Low(1) & Qto(2,:)>=Low(2) & Qto(3,:)>=Low(3) &...
Qto(1,:)<High(1) & Qto(2,:)<High(2) & Qto(3,:)<High(3));
SQto=Qto(:,idxT);
if(isempty(SQto))
Dist(idxF)=MaxDist;
else
KDstl=KDTreeSearcher(SQto');
[~,SDist] = knnsearch(KDstl,SQfrom');
Dist(idxF)=SDist;
end
Done=Done+length(idxF); %Debug
end
end
%Complete=Done/size(Qfrom,2);
%EstTime=(toc/Complete)/60
%toc
%LA=[sum(LookAt==0),...
% sum(LookAt==1),...
% sum(LookAt==2),...
% sum(LookAt==3),...
% sum(LookAt>3)]
end
================================================
FILE: evaluations/dtu/PointCompareMain.m
================================================
function BaseEval=PointCompareMain(cSet,Qdata,dst,dataPath)
% evaluation function that calculates the distances from the reference data (stl) to the evaluation points (Qdata) and the
% distances from the evaluation points to the reference
tic
% reduce points 0.2 mm neighbourhood density
Qdata=reducePts_haa(Qdata,dst);
toc
StlInName=[dataPath '/Points/stl/stl' sprintf('%03d',cSet) '_total.ply'];
StlMesh = plyread(StlInName); %STL points already reduced 0.2 mm neighbourhood density
Qstl=[StlMesh.vertex.x StlMesh.vertex.y StlMesh.vertex.z]';
%Load Mask (ObsMask) and Bounding box (BB) and Resolution (Res)
Margin=10;
MaskName=[dataPath '/ObsMask/ObsMask' num2str(cSet) '_' num2str(Margin) '.mat'];
load(MaskName)
MaxDist=60;
disp('Computing Data 2 Stl distances')
Ddata = MaxDistCP(Qstl,Qdata,BB,MaxDist);
toc
disp('Computing Stl 2 Data distances')
Dstl=MaxDistCP(Qdata,Qstl,BB,MaxDist);
disp('Distances computed')
toc
%use mask
%From Get mask - inverted & modified.
One=ones(1,size(Qdata,2));
Qv=(Qdata-BB(1,:)'*One)/Res+1;
Qv=round(Qv);
Midx1=find(Qv(1,:)>0 & Qv(1,:)<=size(ObsMask,1) & Qv(2,:)>0 & Qv(2,:)<=size(ObsMask,2) & Qv(3,:)>0 & Qv(3,:)<=size(ObsMask,3));
MidxA=sub2ind(size(ObsMask),Qv(1,Midx1),Qv(2,Midx1),Qv(3,Midx1));
Midx2=find(ObsMask(MidxA));
BaseEval.DataInMask(1:size(Qv,2))=false;
BaseEval.DataInMask(Midx1(Midx2))=true; %If Data is within the mask
BaseEval.cSet=cSet;
BaseEval.Margin=Margin; %Margin of masks
BaseEval.dst=dst; %Min dist between points when reducing
BaseEval.Qdata=Qdata; %Input data points
BaseEval.Ddata=Ddata; %distance from data to stl
BaseEval.Qstl=Qstl; %Input stl points
BaseEval.Dstl=Dstl; %Distance from the stl to data
load([dataPath '/ObsMask/Plane' num2str(cSet)],'P')
BaseEval.GroundPlane=P; % Plane used to distinguish which Stl points are 'used'
BaseEval.StlAbovePlane=(P'*[Qstl;ones(1,size(Qstl,2))])>0; %Is stl above 'ground plane'
BaseEval.Time=clock; %Time when computation is finished
================================================
FILE: evaluations/dtu/plyread.m
================================================
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [Elements,varargout] = plyread(Path,Str)
%PLYREAD Read a PLY 3D data file.
% [DATA,COMMENTS] = PLYREAD(FILENAME) reads a version 1.0 PLY file
% FILENAME and returns a structure DATA. The fields in this structure
% are defined by the PLY header; each element type is a field and each
% element property is a subfield. If the file contains any comments,
% they are returned in a cell string array COMMENTS.
%
% [TRI,PTS] = PLYREAD(FILENAME,'tri') or
% [TRI,PTS,DATA,COMMENTS] = PLYREAD(FILENAME,'tri') converts vertex
% and face data into triangular connectivity and vertex arrays. The
% mesh can then be displayed using the TRISURF command.
%
% Note: This function is slow for large mesh files (+50K faces),
% especially when reading data with list type properties.
%
% Example:
% [Tri,Pts] = PLYREAD('cow.ply','tri');
% trisurf(Tri,Pts(:,1),Pts(:,2),Pts(:,3));
% colormap(gray); axis equal;
%
% See also: PLYWRITE
% Pascal Getreuer 2004
[fid,Msg] = fopen(Path,'rt'); % open file in read text mode
if fid == -1, error(Msg); end
Buf = fscanf(fid,'%s',1);
if ~strcmp(Buf,'ply')
fclose(fid);
error('Not a PLY file.');
end
%%% read header %%%
Position = ftell(fid);
Format = '';
NumComments = 0;
Comments = {}; % for storing any file comments
NumElements = 0;
NumProperties = 0;
Elements = []; % structure for holding the element data
ElementCount = []; % number of each type of element in file
PropertyTypes = []; % corresponding structure recording property types
ElementNames = {}; % list of element names in the order they are stored in the file
PropertyNames = []; % structure of lists of property names
while 1
Buf = fgetl(fid); % read one line from file
BufRem = Buf;
Token = {};
Count = 0;
while ~isempty(BufRem) % split line into tokens
[tmp,BufRem] = strtok(BufRem);
if ~isempty(tmp)
Count = Count + 1; % count tokens
Token{Count} = tmp;
end
end
if Count % parse line
switch lower(Token{1})
case 'format' % read data format
if Count >= 2
Format = lower(Token{2});
if Count == 3 & ~strcmp(Token{3},'1.0')
fclose(fid);
error('Only PLY format version 1.0 supported.');
end
end
case 'comment' % read file comment
NumComments = NumComments + 1;
Comments{NumComments} = '';
for i = 2:Count
Comments{NumComments} = [Comments{NumComments},Token{i},' '];
end
case 'element' % element name
if Count >= 3
if isfield(Elements,Token{2})
fclose(fid);
error(['Duplicate element name, ''',Token{2},'''.']);
end
NumElements = NumElements + 1;
NumProperties = 0;
Elements = setfield(Elements,Token{2},[]);
PropertyTypes = setfield(PropertyTypes,Token{2},[]);
ElementNames{NumElements} = Token{2};
PropertyNames = setfield(PropertyNames,Token{2},{});
CurElement = Token{2};
ElementCount(NumElements) = str2double(Token{3});
if isnan(ElementCount(NumElements))
fclose(fid);
error(['Bad element definition: ',Buf]);
end
else
error(['Bad element definition: ',Buf]);
end
case 'property' % element property
if ~isempty(CurElement) & Count >= 3
NumProperties = NumProperties + 1;
eval(['tmp=isfield(Elements.',CurElement,',Token{Count});'],...
'fclose(fid);error([''Error reading property: '',Buf])');
if tmp
error(['Duplicate property name, ''',CurElement,'.',Token{2},'''.']);
end
% add property subfield to Elements
eval(['Elements.',CurElement,'.',Token{Count},'=[];'], ...
'fclose(fid);error([''Error reading property: '',Buf])');
% add property subfield to PropertyTypes and save type
eval(['PropertyTypes.',CurElement,'.',Token{Count},'={Token{2:Count-1}};'], ...
'fclose(fid);error([''Error reading property: '',Buf])');
% record property name order
eval(['PropertyNames.',CurElement,'{NumProperties}=Token{Count};'], ...
'fclose(fid);error([''Error reading property: '',Buf])');
else
fclose(fid);
if isempty(CurElement)
error(['Property definition without element definition: ',Buf]);
else
error(['Bad property definition: ',Buf]);
end
end
case 'end_header' % end of header, break from while loop
break;
end
end
end
%%% set reading for specified data format %%%
if isempty(Format)
warning('Data format unspecified, assuming ASCII.');
Format = 'ascii';
end
switch Format
case 'ascii'
Format = 0;
case 'binary_little_endian'
Format = 1;
case 'binary_big_endian'
Format = 2;
otherwise
fclose(fid);
error(['Data format ''',Format,''' not supported.']);
end
if ~Format
Buf = fscanf(fid,'%f'); % read the rest of the file as ASCII data
BufOff = 1;
else
% reopen the file in read binary mode
fclose(fid);
if Format == 1
fid = fopen(Path,'r','ieee-le.l64'); % little endian
else
fid = fopen(Path,'r','ieee-be.l64'); % big endian
end
% find the end of the header again (using ftell on the old handle doesn't give the correct position)
BufSize = 8192;
Buf = [blanks(10),char(fread(fid,BufSize,'uchar')')];
i = [];
tmp = -11;
while isempty(i)
i = findstr(Buf,['end_header',13,10]); % look for end_header + CR/LF
i = [i,findstr(Buf,['end_header',10])]; % look for end_header + LF
if isempty(i)
tmp = tmp + BufSize;
Buf = [Buf(BufSize+1:BufSize+10),char(fread(fid,BufSize,'uchar')')];
end
end
% seek to just after the line feed
fseek(fid,i + tmp + 11 + (Buf(i + 10) == 13),-1);
end
%%% read element data %%%
% PLY and MATLAB data types (for fread)
PlyTypeNames = {'char','uchar','short','ushort','int','uint','float','double', ...
'char8','uchar8','short16','ushort16','int32','uint32','float32','double64'};
MatlabTypeNames = {'schar','uchar','int16','uint16','int32','uint32','single','double'};
SizeOf = [1,1,2,2,4,4,4,8]; % size in bytes of each type
for i = 1:NumElements
% get current element property information
eval(['CurPropertyNames=PropertyNames.',ElementNames{i},';']);
eval(['CurPropertyTypes=PropertyTypes.',ElementNames{i},';']);
NumProperties = size(CurPropertyNames,2);
% fprintf('Reading %s...\n',ElementNames{i});
if ~Format %%% read ASCII data %%%
for j = 1:NumProperties
Token = getfield(CurPropertyTypes,CurPropertyNames{j});
if strcmpi(Token{1},'list')
Type(j) = 1;
else
Type(j) = 0;
end
end
% parse buffer
if ~any(Type)
% no list types
Data = reshape(Buf(BufOff:BufOff+ElementCount(i)*NumProperties-1),NumProperties,ElementCount(i))';
BufOff = BufOff + ElementCount(i)*NumProperties;
else
ListData = cell(NumProperties,1);
for k = 1:NumProperties
ListData{k} = cell(ElementCount(i),1);
end
% list type
for j = 1:ElementCount(i)
for k = 1:NumProperties
if ~Type(k)
Data(j,k) = Buf(BufOff);
BufOff = BufOff + 1;
else
tmp = Buf(BufOff);
ListData{k}{j} = Buf(BufOff+(1:tmp))';
BufOff = BufOff + tmp + 1;
end
end
end
end
else %%% read binary data %%%
% translate PLY data type names to MATLAB data type names
ListFlag = 0; % = 1 if there is a list type
SameFlag = 1; % = 1 if all types are the same
for j = 1:NumProperties
Token = getfield(CurPropertyTypes,CurPropertyNames{j});
if ~strcmp(Token{1},'list') % non-list type
tmp = rem(strmatch(Token{1},PlyTypeNames,'exact')-1,8)+1;
if ~isempty(tmp)
TypeSize(j) = SizeOf(tmp);
Type{j} = MatlabTypeNames{tmp};
TypeSize2(j) = 0;
Type2{j} = '';
SameFlag = SameFlag & strcmp(Type{1},Type{j});
else
fclose(fid);
error(['Unknown property data type, ''',Token{1},''', in ', ...
ElementNames{i},'.',CurPropertyNames{j},'.']);
end
else % list type
if length(Token) == 3
ListFlag = 1;
SameFlag = 0;
tmp = rem(strmatch(Token{2},PlyTypeNames,'exact')-1,8)+1;
tmp2 = rem(strmatch(Token{3},PlyTypeNames,'exact')-1,8)+1;
if ~isempty(tmp) & ~isempty(tmp2)
TypeSize(j) = SizeOf(tmp);
Type{j} = MatlabTypeNames{tmp};
TypeSize2(j) = SizeOf(tmp2);
Type2{j} = MatlabTypeNames{tmp2};
else
fclose(fid);
error(['Unknown property data type, ''list ',Token{2},' ',Token{3},''', in ', ...
ElementNames{i},'.',CurPropertyNames{j},'.']);
end
else
fclose(fid);
error(['Invalid list syntax in ',ElementNames{i},'.',CurPropertyNames{j},'.']);
end
end
end
% read file
if ~ListFlag
if SameFlag
% no list types, all the same type (fast)
Data = fread(fid,[NumProperties,ElementCount(i)],Type{1})';
else
% no list types, mixed type
Data = zeros(ElementCount(i),NumProperties);
for j = 1:ElementCount(i)
for k = 1:NumProperties
Data(j,k) = fread(fid,1,Type{k});
end
end
end
else
ListData = cell(NumProperties,1);
for k = 1:NumProperties
ListData{k} = cell(ElementCount(i),1);
end
if NumProperties == 1
BufSize = 512;
SkipNum = 4;
j = 0;
% list type, one property (fast if lists are usually the same length)
while j < ElementCount(i)
Position = ftell(fid);
% read in BufSize count values, assuming all counts = SkipNum
[Buf,BufSize] = fread(fid,BufSize,Type{1},SkipNum*TypeSize2(1));
Miss = find(Buf ~= SkipNum); % find first count that is not SkipNum
fseek(fid,Position + TypeSize(1),-1); % seek back to after first count
if isempty(Miss) % all counts are SkipNum
Buf = fread(fid,[SkipNum,BufSize],[int2str(SkipNum),'*',Type2{1}],TypeSize(1))';
fseek(fid,-TypeSize(1),0); % undo last skip
for k = 1:BufSize
ListData{1}{j+k} = Buf(k,:);
end
j = j + BufSize;
BufSize = floor(1.5*BufSize);
else
if Miss(1) > 1 % some counts are SkipNum
Buf2 = fread(fid,[SkipNum,Miss(1)-1],[int2str(SkipNum),'*',Type2{1}],TypeSize(1))';
for k = 1:Miss(1)-1
ListData{1}{j+k} = Buf2(k,:);
end
j = j + k;
end
% read in the list with the missed count
SkipNum = Buf(Miss(1));
j = j + 1;
ListData{1}{j} = fread(fid,[1,SkipNum],Type2{1});
BufSize = ceil(0.6*BufSize);
end
end
else
% list type(s), multiple properties (slow)
Data = zeros(ElementCount(i),NumProperties);
for j = 1:ElementCount(i)
for k = 1:NumProperties
if isempty(Type2{k})
Data(j,k) = fread(fid,1,Type{k});
else
tmp = fread(fid,1,Type{k});
ListData{k}{j} = fread(fid,[1,tmp],Type2{k});
end
end
end
end
end
end
% put data into Elements structure
for k = 1:NumProperties
if (~Format & ~Type(k)) | (Format & isempty(Type2{k}))
eval(['Elements.',ElementNames{i},'.',CurPropertyNames{k},'=Data(:,k);']);
else
eval(['Elements.',ElementNames{i},'.',CurPropertyNames{k},'=ListData{k};']);
end
end
end
clear Data ListData;
fclose(fid);
if (nargin > 1 & strcmpi(Str,'Tri')) | nargout > 2
% find vertex element field
Name = {'vertex','Vertex','point','Point','pts','Pts'};
Names = [];
for i = 1:length(Name)
if any(strcmp(ElementNames,Name{i}))
Names = getfield(PropertyNames,Name{i});
Name = Name{i};
break;
end
end
if any(strcmp(Names,'x')) & any(strcmp(Names,'y')) & any(strcmp(Names,'z'))
eval(['varargout{1}=[Elements.',Name,'.x,Elements.',Name,'.y,Elements.',Name,'.z];']);
else
varargout{1} = zeros(1,3);
end
varargout{2} = Elements;
varargout{3} = Comments;
Elements = [];
% find face element field
Name = {'face','Face','poly','Poly','tri','Tri'};
Names = [];
for i = 1:length(Name)
if any(strcmp(ElementNames,Name{i}))
Names = getfield(PropertyNames,Name{i});
Name = Name{i};
break;
end
end
if ~isempty(Names)
% find vertex indices property subfield
PropertyName = {'vertex_indices','vertex_indexes','vertex_index','indices','indexes'};
for i = 1:length(PropertyName)
if any(strcmp(Names,PropertyName{i}))
PropertyName = PropertyName{i};
break;
end
end
if ~iscell(PropertyName)
% convert face index lists to triangular connectivity
eval(['FaceIndices=varargout{2}.',Name,'.',PropertyName,';']);
N = length(FaceIndices);
Elements = zeros(N*2,3);
Extra = 0;
for k = 1:N
Elements(k,:) = FaceIndices{k}(1:3);
for j = 4:length(FaceIndices{k})
Extra = Extra + 1;
Elements(N + Extra,:) = [Elements(k,[1,j-1]),FaceIndices{k}(j)];
end
end
Elements = Elements(1:N+Extra,:) + 1;
end
end
else
varargout{1} = Comments;
end
================================================
FILE: evaluations/dtu/reducePts_haa.m
================================================
function [ptsOut,indexSet] = reducePts_haa(pts, dst)
%Reduces a point set, pts, in a stochastic manner, such that the minimum distance
% between points is 'dst'. Written by abd, edited by haa, then by raje
nPoints=size(pts,2);
indexSet=true(nPoints,1);
RandOrd=randperm(nPoints);
%tic
NS = KDTreeSearcher(pts');
%toc
% search the KD-tree for close neighbours in a chunk-wise fashion to save memory when the point cloud is very big
Chunks=1:min(4e6,nPoints-1):nPoints;
Chunks(end)=nPoints;
for cChunk=1:(length(Chunks)-1)
Range=Chunks(cChunk):Chunks(cChunk+1);
idx = rangesearch(NS,pts(:,RandOrd(Range))',dst);
for i = 1:size(idx,1)
        id = RandOrd(i-1+Chunks(cChunk));
if (indexSet(id))
indexSet(idx{i}) = 0;
indexSet(id) = 1;
end
end
end
ptsOut = pts(:,indexSet);
disp(['downsample factor: ' num2str(nPoints/sum(indexSet))]);
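% Usage sketch (editor's note, hypothetical data): enforce a minimum spacing
% of 0.2 units on a random 3xN cloud; indexSet marks the surviving columns.
%
%   pts = rand(3, 1e5);
%   [ptsOut, keep] = reducePts_haa(pts, 0.2);
%   assert(isequal(ptsOut, pts(:, keep)));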
================================================
FILE: lists/blendedmvs/train.txt
================================================
5c1f33f1d33e1f2e4aa6dda4
5bfe5ae0fe0ea555e6a969ca
5bff3c5cfe0ea555e6bcbf3a
58eaf1513353456af3a1682a
5bfc9d5aec61ca1dd69132a2
5bf18642c50e6f7f8bdbd492
5bf26cbbd43923194854b270
5bf17c0fd439231948355385
5be3ae47f44e235bdbbc9771
5be3a5fb8cfdd56947f6b67c
5bbb6eb2ea1cfa39f1af7e0c
5ba75d79d76ffa2c86cf2f05
5bb7a08aea1cfa39f1a947ab
5b864d850d072a699b32f4ae
5b6eff8b67b396324c5b2672
5b6e716d67b396324c2d77cb
5b69cc0cb44b61786eb959bf
5b62647143840965efc0dbde
5b60fa0c764f146feef84df0
5b558a928bbfb62204e77ba2
5b271079e0878c3816dacca4
5b08286b2775267d5b0634ba
5afacb69ab00705d0cefdd5b
5af28cea59bc705737003253
5af02e904c8216544b4ab5a2
5aa515e613d42d091d29d300
5c34529873a8df509ae57b58
5c34300a73a8df509add216d
5c1af2e2bee9a723c963d019
5c1892f726173c3a09ea9aeb
5c0d13b795da9479e12e2ee9
5c062d84a96e33018ff6f0a6
5bfd0f32ec61ca1dd69dc77b
5bf21799d43923194842c001
5bf3a82cd439231948877aed
5bf03590d4392319481971dc
5beb6e66abd34c35e18e66b9
5be883a4f98cee15019d5b83
5be47bf9b18881428d8fbc1d
5bcf979a6d5f586b95c258cd
5bce7ac9ca24970bce4934b6
5bb8a49aea1cfa39f1aa7f75
5b78e57afc8fcf6781d0c3ba
5b21e18c58e2823a67a10dd8
5b22269758e2823a67a3bd03
5b192eb2170cf166458ff886
5ae2e9c5fe405c5076abc6b2
5adc6bd52430a05ecb2ffb85
5ab8b8e029f5351f7f2ccf59
5abc2506b53b042ead637d86
5ab85f1dac4291329b17cb50
5a969eea91dfc339a9a3ad2c
5a8aa0fab18050187cbe060e
5a7d3db14989e929563eb153
5a69c47d0d5d0a7f3b2e9752
5a618c72784780334bc1972d
5a6464143d809f1d8208c43c
5a588a8193ac3d233f77fbca
5a57542f333d180827dfc132
5a572fd9fc597b0478a81d14
5a563183425d0f5186314855
5a4a38dad38c8a075495b5d2
5a48d4b2c7dab83a7d7b9851
5a489fb1c7dab83a7d7b1070
5a48ba95c7dab83a7d7b44ed
5a3ca9cb270f0e3f14d0eddb
5a3cb4e4270f0e3f14d12f43
5a3f4aba5889373fbbc5d3b5
5a0271884e62597cdee0d0eb
59e864b2a9e91f2c5529325f
599aa591d5b41f366fed0d58
59350ca084b7f26bf5ce6eb8
59338e76772c3e6384afbb15
5c20ca3a0843bc542d94e3e2
5c1dbf200843bc542d8ef8c4
5c1b1500bee9a723c96c3e78
5bea87f4abd34c35e1860ab5
5c2b3ed5e611832e8aed46bf
57f8d9bbe73f6760f10e916a
5bf7d63575c26f32dbf7413b
5be4ab93870d330ff2dce134
5bd43b4ba6b28b1ee86b92dd
5bccd6beca24970bce448134
5bc5f0e896b66a2cd8f9bd36
5b908d3dc6ab78485f3d24a9
5b2c67b5e0878c381608b8d8
5b4933abf2b5f44e95de482a
5b3b353d8d46a939f93524b9
5acf8ca0f3d8a750097e4b15
5ab8713ba3799a1d138bd69a
5aa235f64a17b335eeaf9609
5aa0f9d7a9efce63548c69a1
5a8315f624b8e938486e0bd8
5a48c4e9c7dab83a7d7b5cc7
59ecfd02e225f6492d20fcc9
59f87d0bfa6280566fb38c9a
59f363a8b45be22330016cad
59f70ab1e5c5d366af29bf3e
59e75a2ca9e91f2c5526005d
5947719bf1b45630bd096665
5947b62af1b45630bd0c2a02
59056e6760bb961de55f3501
58f7f7299f5b5647873cb110
58cf4771d0f5fb221defe6da
58d36897f387231e6c929903
58c4bb4f4a69c55606122be4
================================================
FILE: lists/blendedmvs/val.txt
================================================
5b7a3890fc8fcf6781e2593a
5c189f2326173c3a09ed7ef3
5b950c71608de421b1e7318f
5a6400933d809f1d8200af15
59d2657f82ca7774b1ec081d
5ba19a8a360c7c30c1c169df
59817e4a1bd4b175e7038d19
================================================
FILE: lists/dtu/test.txt
================================================
scan1
scan4
scan9
scan10
scan11
scan12
scan13
scan15
scan23
scan24
scan29
scan32
scan33
scan34
scan48
scan49
scan62
scan75
scan77
scan110
scan114
scan118
================================================
FILE: lists/dtu/train.txt
================================================
scan2
scan6
scan7
scan8
scan14
scan16
scan18
scan19
scan20
scan22
scan30
scan31
scan36
scan39
scan41
scan42
scan44
scan45
scan46
scan47
scan50
scan51
scan52
scan53
scan55
scan57
scan58
scan60
scan61
scan63
scan64
scan65
scan68
scan69
scan70
scan71
scan72
scan74
scan76
scan83
scan84
scan85
scan87
scan88
scan89
scan90
scan91
scan92
scan93
scan94
scan95
scan96
scan97
scan98
scan99
scan100
scan101
scan102
scan103
scan104
scan105
scan107
scan108
scan109
scan111
scan112
scan113
scan115
scan116
scan119
scan120
scan121
scan122
scan123
scan124
scan125
scan126
scan127
scan128
================================================
FILE: lists/dtu/trainval.txt
================================================
scan2
scan6
scan7
scan8
scan14
scan16
scan18
scan19
scan20
scan22
scan30
scan31
scan36
scan39
scan41
scan42
scan44
scan45
scan46
scan47
scan50
scan51
scan52
scan53
scan55
scan57
scan58
scan60
scan61
scan63
scan64
scan65
scan68
scan69
scan70
scan71
scan72
scan74
scan76
scan83
scan84
scan85
scan87
scan88
scan89
scan90
scan91
scan92
scan93
scan94
scan95
scan96
scan97
scan98
scan99
scan100
scan101
scan102
scan103
scan104
scan105
scan107
scan108
scan109
scan111
scan112
scan113
scan115
scan116
scan119
scan120
scan121
scan122
scan123
scan124
scan125
scan126
scan127
scan128
scan3
scan5
scan17
scan21
scan28
scan35
scan37
scan38
scan40
scan43
scan56
scan59
scan66
scan67
scan82
scan86
scan106
scan117
================================================
FILE: lists/dtu/val.txt
================================================
scan3
scan5
scan17
scan21
scan28
scan35
scan37
scan38
scan40
scan43
scan56
scan59
scan66
scan67
scan82
scan86
scan106
scan117
================================================
FILE: models/MVS4Net.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from models.mvs4net_utils import stagenet, reg2d, reg3d, FPN4, FPN4_convnext, FPN4_convnext4, PosEncSine, PosEncLearned, \
init_range, schedule_range, init_inverse_range, schedule_inverse_range, sinkhorn, mono_depth_decoder, ASFF
class MVS4net(nn.Module):
def __init__(self, arch_mode="fpn", reg_net='reg2d', num_stage=4, fpn_base_channel=8,
reg_channel=8, stage_splits=[8,8,4,4], depth_interals_ratio=[0.5,0.5,0.5,1],
group_cor=False, group_cor_dim=[8,8,8,8],
inverse_depth=False,
agg_type='ConvBnReLU3D',
dcn=False,
pos_enc=0,
mono=False,
asff=False,
attn_temp=2,
attn_fuse_d=True,
vis_ETA=False,
vis_mono=False
):
# pos_enc: 0 no pos enc; 1 depth sine; 2 learnable pos enc
super(MVS4net, self).__init__()
self.arch_mode = arch_mode
self.num_stage = num_stage
self.depth_interals_ratio = depth_interals_ratio
self.group_cor = group_cor
self.group_cor_dim = group_cor_dim
self.inverse_depth = inverse_depth
self.asff = asff
if self.asff:
self.asff = nn.ModuleList([ASFF(i) for i in range(num_stage)])
self.attn_ob = nn.ModuleList()
if arch_mode == "fpn":
self.feature = FPN4(base_channels=fpn_base_channel, gn=False, dcn=dcn)
self.vis_mono = vis_mono
self.stagenet = stagenet(inverse_depth, mono, attn_fuse_d, vis_ETA, attn_temp)
self.stage_splits = stage_splits
self.reg = nn.ModuleList()
self.pos_enc = pos_enc
self.pos_enc_func = nn.ModuleList()
self.mono = mono
if self.mono:
self.mono_depth_decoder = mono_depth_decoder()
if reg_net == 'reg3d':
self.down_size = [3,3,2,2]
for idx in range(num_stage):
if self.group_cor:
in_dim = group_cor_dim[idx]
else:
in_dim = self.feature.out_channels[idx]
if reg_net == 'reg2d':
self.reg.append(reg2d(input_channel=in_dim, base_channel=reg_channel, conv_name=agg_type))
elif reg_net == 'reg3d':
self.reg.append(reg3d(in_channels=in_dim, base_channels=reg_channel, down_size=self.down_size[idx]))
def forward(self, imgs, proj_matrices, depth_values, filename=None):
depth_min = depth_values[:, 0].cpu().numpy()
depth_max = depth_values[:, -1].cpu().numpy()
depth_interval = (depth_max - depth_min) / depth_values.size(1)
# step 1. feature extraction
features = []
for nview_idx in range(len(imgs)): #imgs shape (B, N, C, H, W)
img = imgs[nview_idx]
features.append(self.feature(img))
if self.vis_mono:
scan_name = filename[0].split('/')[0]
image_name = filename[0].split('/')[2][:-2]
save_fn = './debug_figs/vis_mono/feat_{}'.format(scan_name+'_'+image_name)
feat_ = features[-1]['stage4'].detach().cpu().numpy()
np.save(save_fn, feat_)
# step 2. iter (multi-scale)
outputs = {}
for stage_idx in range(self.num_stage):
if not self.asff:
features_stage = [feat["stage{}".format(stage_idx+1)] for feat in features]
else:
features_stage = [self.asff[stage_idx](feat['stage1'],feat['stage2'],feat['stage3'],feat['stage4']) for feat in features]
proj_matrices_stage = proj_matrices["stage{}".format(stage_idx + 1)]
B,C,H,W = features[0]['stage{}'.format(stage_idx+1)].shape
# init range
if stage_idx == 0:
if self.inverse_depth:
depth_hypo = init_inverse_range(depth_values, self.stage_splits[stage_idx], img[0].device, img[0].dtype, H, W)
else:
depth_hypo = init_range(depth_values, self.stage_splits[stage_idx], img[0].device, img[0].dtype, H, W)
else:
if self.inverse_depth:
depth_hypo = schedule_inverse_range(outputs_stage['inverse_min_depth'].detach(), outputs_stage['inverse_max_depth'].detach(), self.stage_splits[stage_idx], H, W) # B D H W
else:
depth_hypo = schedule_range(outputs_stage['depth'].detach(), self.stage_splits[stage_idx], self.depth_interals_ratio[stage_idx] * depth_interval, H, W)
outputs_stage = self.stagenet(features_stage, proj_matrices_stage, depth_hypo=depth_hypo, regnet=self.reg[stage_idx], stage_idx=stage_idx,
group_cor=self.group_cor, group_cor_dim=self.group_cor_dim[stage_idx],
split_itv=self.depth_interals_ratio[stage_idx],
fn=filename)
outputs["stage{}".format(stage_idx + 1)] = outputs_stage
outputs.update(outputs_stage)
if self.mono and self.training:
# if self.mono:
outputs = self.mono_depth_decoder(outputs, depth_values[:,0], depth_values[:,1])
return outputs
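# Hypothesis-scheduling sketch (editor's note): with the default
# stage_splits=[8,8,4,4] and depth_interals_ratio=[0.5,0.5,0.5,1], stage 1
# spreads 8 hypotheses over the full [depth_min, depth_max] range; each later
# stage re-centres its hypotheses on the previous prediction (or on the
# detached inverse-depth bounds), spacing the planes by
#   depth_interals_ratio[s] * (depth_max - depth_min) / depth_values.size(1),
# so the searched depth range shrinks as spatial resolution grows.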
def MVS4net_loss(inputs, depth_gt_ms, mask_ms, **kwargs):
stage_lw = kwargs.get("stage_lw", [1,1,1,1])
l1ot_lw = kwargs.get("l1ot_lw", [0,1])
inverse = kwargs.get("inverse_depth", False)
ot_iter = kwargs.get("ot_iter", 3)
ot_eps = kwargs.get("ot_eps", 1)
ot_continous = kwargs.get("ot_continous", False)
mono = kwargs.get("mono", False)
total_loss = torch.tensor(0.0, dtype=torch.float32, device=mask_ms["stage1"].device, requires_grad=False)
stage_ot_loss = []
stage_l1_loss = []
range_err_ratio = []
for stage_idx, (stage_inputs, stage_key) in enumerate([(inputs[k], k) for k in inputs.keys() if "stage" in k]):
depth_pred = stage_inputs['depth']
hypo_depth = stage_inputs['hypo_depth']
attn_weight = stage_inputs['attn_weight']
B,H,W = depth_pred.shape
D = hypo_depth.shape[1]
mask = mask_ms[stage_key]
mask = mask > 0.5
depth_gt = depth_gt_ms[stage_key]
if mono and stage_idx!=0:
this_stage_l1_loss = F.l1_loss(stage_inputs['mono_depth'][mask], depth_gt[mask], reduction='mean')
else:
this_stage_l1_loss = torch.tensor(0.0, dtype=torch.float32, device=mask_ms["stage1"].device, requires_grad=False)
# mask range
if inverse:
depth_itv = (1/hypo_depth[:,2,:,:]-1/hypo_depth[:,1,:,:]).abs() # B H W
mask_out_of_range = ((1/hypo_depth - 1/depth_gt.unsqueeze(1)).abs() <= depth_itv.unsqueeze(1)).sum(1) == 0 # B H W
else:
depth_itv = (hypo_depth[:,2,:,:]-hypo_depth[:,1,:,:]).abs() # B H W
mask_out_of_range = ((hypo_depth - depth_gt.unsqueeze(1)).abs() <= depth_itv.unsqueeze(1)).sum(1) == 0 # B H W
range_err_ratio.append(mask_out_of_range[mask].float().mean())
this_stage_ot_loss = sinkhorn(depth_gt, hypo_depth, attn_weight, mask, iters=ot_iter, eps=ot_eps, continuous=ot_continous)[1]
stage_l1_loss.append(this_stage_l1_loss)
stage_ot_loss.append(this_stage_ot_loss)
total_loss = total_loss + stage_lw[stage_idx] * (l1ot_lw[0] * this_stage_l1_loss + l1ot_lw[1] * this_stage_ot_loss)
return total_loss, stage_l1_loss, stage_ot_loss, range_err_ratio
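# Weighting sketch (editor's note): the total loss accumulated above is
#   sum_s stage_lw[s] * (l1ot_lw[0] * L1_s + l1ot_lw[1] * OT_s),
# so the default l1ot_lw=[0, 1] trains purely with the entropy-regularized
# optimal-transport (sinkhorn) term; the L1 term, which supervises the
# monocular branch on stages after the first when mono=True, stays switched
# off unless l1ot_lw[0] is raised above zero.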
def Blend_loss(inputs, depth_gt_ms, mask_ms, **kwargs):
stage_lw = kwargs.get("stage_lw", [1,1,1,1])
l1ot_lw = kwargs.get("l1ot_lw", [0,1])
inverse = kwargs.get("inverse_depth", False)
ot_iter = kwargs.get("ot_iter", 3)
ot_eps = kwargs.get("ot_eps", 1)
ot_continous = kwargs.get("ot_continous", False)
depth_max = kwargs.get("depth_max", 100)
depth_min = kwargs.get("depth_min", 1)
mono = kwargs.get("mono", False)
total_loss = torch.tensor(0.0, dtype=torch.float32, device=mask_ms["stage1"].device, requires_grad=False)
stage_ot_loss = []
stage_l1_loss = []
range_err_ratio = []
for stage_idx, (stage_inputs, stage_key) in enumerate([(inputs[k], k) for k in inputs.keys() if "stage" in k]):
depth_pred = stage_inputs['depth']
hypo_depth = stage_inputs['hypo_depth']
attn_weight = stage_inputs['attn_weight']
B,H,W = depth_pred.shape
mask = mask_ms[stage_key]
mask = mask > 0.5
depth_gt = depth_gt_ms[stage_key]
depth_pred_norm = depth_pred * 128 / (depth_max - depth_min)[:,None,None] # B H W
depth_gt_norm = depth_gt * 128 / (depth_max - depth_min)[:,None,None] # B H W
if mono and stage_idx!=0:
this_stage_l1_loss = F.l1_loss(stage_inputs['mono_depth'][mask], depth_gt[mask], reduction='mean')
else:
this_stage_l1_loss = torch.tensor(0.0, dtype=torch.float32, device=mask_ms["stage1"].device, requires_grad=False)
if inverse:
depth_itv = (1/hypo_depth[:,2,:,:]-1/hypo_depth[:,1,:,:]).abs() # B H W
mask_out_of_range = ((1/hypo_depth - 1/depth_gt.unsqueeze(1)).abs() <= depth_itv.unsqueeze(1)).sum(1) == 0 # B H W
else:
depth_itv = (hypo_depth[:,2,:,:]-hypo_depth[:,1,:,:]).abs() # B H W
mask_out_of_range = ((hypo_depth - depth_gt.unsqueeze(1)).abs() <= depth_itv.unsqueeze(1)).sum(1) == 0 # B H W
range_err_ratio.append(mask_out_of_range[mask].float().mean())
this_stage_ot_loss = sinkhorn(depth_gt, hypo_depth, attn_weight, mask, iters=ot_iter, eps=ot_eps, continuous=ot_continous)[1]
stage_l1_loss.append(this_stage_l1_loss)
stage_ot_loss.append(this_stage_ot_loss)
total_loss = total_loss + stage_lw[stage_idx] * (l1ot_lw[0] * this_stage_l1_loss + l1ot_lw[1] * this_stage_ot_loss)
abs_err = torch.abs(depth_pred_norm[mask] - depth_gt_norm[mask])
epe = abs_err.mean()
        err3 = (abs_err <= 3).float().mean() * 100
        err1 = (abs_err <= 1).float().mean() * 100
return total_loss, stage_l1_loss, stage_ot_loss, range_err_ratio, epe, err3, err1
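# Editor's note: epe/err3/err1 are reported in normalized units. Depths are
# rescaled by 128 / (depth_max - depth_min) per sample, so the thresholds 3
# and 1 mean "within 3 (resp. 1) units of a 128-unit depth range", making
# the metrics comparable across scenes with very different metric scales.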
================================================
FILE: models/__init__.py
================================================
from models.MVS4Net import MVS4net, MVS4net_loss, Blend_loss
================================================
FILE: models/module.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import sys
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
sys.path.append("..")
from utils import local_pcd
from modules.deform_conv import DeformConvPack
def init_bn(module):
if module.weight is not None:
nn.init.ones_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
return
def init_uniform(module, init_method):
if module.weight is not None:
if init_method == "kaiming":
nn.init.kaiming_uniform_(module.weight)
elif init_method == "xavier":
nn.init.xavier_uniform_(module.weight)
return
class Conv2d(nn.Module):
"""Applies a 2D convolution (optionally with batch normalization and relu activation)
over an input signal composed of several input planes.
Attributes:
conv (nn.Module): convolution module
bn (nn.Module): batch normalization module
relu (bool): whether to activate by relu
Notes:
        Default momentum for batch normalization is set to 0.1.
"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
relu=True, bn=True, bn_momentum=0.1, init_method="xavier", **kwargs):
super(Conv2d, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride,
bias=(not bn), **kwargs)
self.kernel_size = kernel_size
self.stride = stride
self.bn = nn.BatchNorm2d(out_channels, momentum=bn_momentum) if bn else None
self.relu = relu
# assert init_method in ["kaiming", "xavier"]
# self.init_weights(init_method)
def forward(self, x):
x = self.conv(x)
if self.bn is not None:
x = self.bn(x)
if self.relu:
x = F.relu(x, inplace=True)
return x
def init_weights(self, init_method):
"""default initialization"""
init_uniform(self.conv, init_method)
if self.bn is not None:
init_bn(self.bn)
class DCNConv2d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
relu=True, bn=True, bn_momentum=0.1, init_method="xavier", **kwargs):
super(DCNConv2d, self).__init__()
self.conv = DeformConvPack(in_channels, out_channels, kernel_size, stride=stride, padding=1, bias=(not bn), im2col_step=16)
self.kernel_size = kernel_size
self.stride = stride
self.bn = nn.BatchNorm2d(out_channels, momentum=bn_momentum) if bn else None
self.relu = relu
# assert init_method in ["kaiming", "xavier"]
# self.init_weights(init_method)
def forward(self, x):
x = self.conv(x)
if self.bn is not None:
x = self.bn(x)
if self.relu:
x = F.relu(x, inplace=True)
return x
def init_weights(self, init_method):
"""default initialization"""
init_uniform(self.conv, init_method)
if self.bn is not None:
init_bn(self.bn)
class Deconv2d(nn.Module):
"""Applies a 2D deconvolution (optionally with batch normalization and relu activation)
over an input signal composed of several input planes.
Attributes:
conv (nn.Module): convolution module
bn (nn.Module): batch normalization module
relu (bool): whether to activate by relu
Notes:
        Default momentum for batch normalization is set to 0.1.
"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
relu=True, bn=True, bn_momentum=0.1, init_method="xavier", **kwargs):
super(Deconv2d, self).__init__()
self.out_channels = out_channels
assert stride in [1, 2]
self.stride = stride
self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride,
bias=(not bn), **kwargs)
self.bn = nn.BatchNorm2d(out_channels, momentum=bn_momentum) if bn else None
self.relu = relu
# assert init_method in ["kaiming", "xavier"]
# self.init_weights(init_method)
def forward(self, x):
y = self.conv(x)
if self.stride == 2:
h, w = list(x.size())[2:]
y = y[:, :, :2 * h, :2 * w].contiguous()
        if self.bn is not None:
            y = self.bn(y)
        if self.relu:
            y = F.relu(y, inplace=True)
        return y
def init_weights(self, init_method):
"""default initialization"""
init_uniform(self.conv, init_method)
if self.bn is not None:
init_bn(self.bn)
class Conv3d(nn.Module):
"""Applies a 3D convolution (optionally with batch normalization and relu activation)
over an input signal composed of several input planes.
Attributes:
conv (nn.Module): convolution module
bn (nn.Module): batch normalization module
relu (bool): whether to activate by relu
Notes:
        Default momentum for batch normalization is set to 0.1.
"""
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
relu=True, bn=True, bn_momentum=0.1, init_method="xavier", **kwargs):
super(Conv3d, self).__init__()
self.out_channels = out_channels
self.kernel_size = kernel_size
assert stride in [1, 2]
self.stride = stride
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride,
bias=(not bn), **kwargs)
self.bn = nn.BatchNorm3d(out_channels, momentum=bn_momentum) if bn else None
self.relu = relu
# assert init_method in ["kaiming", "xavier"]
# self.init_weights(init_method)
def forward(self, x):
x = self.conv(x)
if self.bn is not None:
x = self.bn(x)
if self.relu:
x = F.relu(x, inplace=True)
return x
def init_weights(self, init_method):
"""default initialization"""
init_uniform(self.conv, init_method)
if self.bn is not None:
init_bn(self.bn)
class PConv3d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
relu=True, bn=True, bn_momentum=0.1, padding=1, init_method="xavier", **kwargs):
super(PConv3d, self).__init__()
self.out_channels = out_channels
self.kernel_size_xy = (1, kernel_size, kernel_size)
self.kernel_size_d = (kernel_size, 1, 1)
assert stride in [1, 2]
self.stride_xy = (1, stride, stride)
self.stride_d = (stride, 1, 1)
self.padding_xy = (0, padding, padding)
self.padding_d = (padding, 0, 0)
self.convxy = nn.Conv3d(in_channels, in_channels, self.kernel_size_xy, stride=self.stride_xy, padding=self.padding_xy, bias=(not bn), **kwargs)
self.convd = nn.Conv3d(in_channels, out_channels, self.kernel_size_d, stride=self.stride_d, padding=self.padding_d, bias=(not bn), **kwargs)
self.bn = nn.BatchNorm3d(out_channels, momentum=bn_momentum) if bn else None
self.relu = relu
# assert init_method in ["kaiming", "xavier"]
# self.init_weights(init_method)
def forward(self, x):
x = self.convxy(x)
x = self.convd(x)
if self.bn is not None:
x = self.bn(x)
if self.relu:
x = F.relu(x, inplace=True)
return x
def init_weights(self, init_method):
"""default initialization"""
init_uniform(self.convxy, init_method)
init_uniform(self.convd, init_method)
if self.bn is not None:
init_bn(self.bn)
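# Cost sketch (editor's note): factoring a full 3x3x3 kernel into a (1,3,3)
# spatial conv plus a (3,1,1) depth conv reduces the parameter count from
# 27*Cin*Cout to 9*Cin*Cin + 3*Cin*Cout; with Cin == Cout == C that is
# 12*C^2 instead of 27*C^2, at the price of a rank-restricted kernel.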
class Deconv3d(nn.Module):
"""Applies a 3D deconvolution (optionally with batch normalization and relu activation)
over an input signal composed of several input planes.
Attributes:
conv (nn.Module): convolution module
bn (nn.Module): batch normalization module
relu (bool): whether to activate by relu
Notes:
        Default momentum for batch normalization is set to 0.1.
"""
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
relu=True, bn=True, bn_momentum=0.1, init_method="xavier", **kwargs):
super(Deconv3d, self).__init__()
self.out_channels = out_channels
assert stride in [1, 2]
self.stride = stride
self.conv = nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride=stride,
bias=(not bn), **kwargs)
self.bn = nn.BatchNorm3d(out_channels, momentum=bn_momentum) if bn else None
self.relu = relu
# assert init_method in ["kaiming", "xavier"]
# self.init_weights(init_method)
def forward(self, x):
y = self.conv(x)
        if self.bn is not None:
            y = self.bn(y)
        if self.relu:
            y = F.relu(y, inplace=True)
        return y
def init_weights(self, init_method):
"""default initialization"""
init_uniform(self.conv, init_method)
if self.bn is not None:
init_bn(self.bn)
class PDeconv3d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1,output_padding=1,
relu=True, bn=True, bn_momentum=0.1, init_method="xavier", **kwargs):
super(PDeconv3d, self).__init__()
self.out_channels = out_channels
assert stride in [1, 2]
self.stride = stride
self.kernel_size_xy = (1, kernel_size,kernel_size)
self.kernel_size_d = (kernel_size, 1,1)
self.stride_xy = (1, stride, stride)
self.stride_d = (stride, 1, 1)
self.padding_xy = (0, padding, padding)
self.padding_d = (padding, 0, 0)
self.outpadding_xy = (0, output_padding, output_padding)
self.outpadding_d = (output_padding, 0, 0)
self.convxy = nn.ConvTranspose3d(in_channels, in_channels, self.kernel_size_xy, stride=self.stride_xy, padding=self.padding_xy, output_padding=self.outpadding_xy, bias=(not bn))
self.convd = nn.ConvTranspose3d(in_channels, out_channels, self.kernel_size_d, stride=self.stride_d, padding=self.padding_d, output_padding=self.outpadding_d, bias=(not bn))
self.bn = nn.BatchNorm3d(out_channels, momentum=bn_momentum) if bn else None
self.relu = relu
# assert init_method in ["kaiming", "xavier"]
# self.init_weights(init_method)
def forward(self, x):
x = self.convxy(x)
y = self.convd(x)
        if self.bn is not None:
            y = self.bn(y)
        if self.relu:
            y = F.relu(y, inplace=True)
        return y
def init_weights(self, init_method):
"""default initialization"""
init_uniform(self.convxy, init_method)
init_uniform(self.convd, init_method)
if self.bn is not None:
init_bn(self.bn)
class ConvBnReLU(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1):
super(ConvBnReLU, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False)
self.bn = nn.BatchNorm2d(out_channels)
def forward(self, x):
return F.relu(self.bn(self.conv(x)), inplace=True)
class ConvBn(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1):
super(ConvBn, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False)
self.bn = nn.BatchNorm2d(out_channels)
def forward(self, x):
return self.bn(self.conv(x))
class ConvBnReLU3D(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1):
super(ConvBnReLU3D, self).__init__()
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False)
self.bn = nn.BatchNorm3d(out_channels)
def forward(self, x):
return F.relu(self.bn(self.conv(x)), inplace=True)
class ConvBn3D(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1):
super(ConvBn3D, self).__init__()
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False)
self.bn = nn.BatchNorm3d(out_channels)
def forward(self, x):
return self.bn(self.conv(x))
class BasicBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = ConvBnReLU(in_channels, out_channels, kernel_size=3, stride=stride, pad=1)
self.conv2 = ConvBn(out_channels, out_channels, kernel_size=3, stride=1, pad=1)
self.downsample = downsample
self.stride = stride
def forward(self, x):
out = self.conv1(x)
out = self.conv2(out)
if self.downsample is not None:
x = self.downsample(x)
out += x
return out
class Hourglass3d(nn.Module):
def __init__(self, channels):
super(Hourglass3d, self).__init__()
self.conv1a = ConvBnReLU3D(channels, channels * 2, kernel_size=3, stride=2, pad=1)
self.conv1b = ConvBnReLU3D(channels * 2, channels * 2, kernel_size=3, stride=1, pad=1)
self.conv2a = ConvBnReLU3D(channels * 2, channels * 4, kernel_size=3, stride=2, pad=1)
self.conv2b = ConvBnReLU3D(channels * 4, channels * 4, kernel_size=3, stride=1, pad=1)
self.dconv2 = nn.Sequential(
nn.ConvTranspose3d(channels * 4, channels * 2, kernel_size=3, padding=1, output_padding=1, stride=2,
bias=False),
nn.BatchNorm3d(channels * 2))
self.dconv1 = nn.Sequential(
nn.ConvTranspose3d(channels * 2, channels, kernel_size=3, padding=1, output_padding=1, stride=2,
bias=False),
nn.BatchNorm3d(channels))
self.redir1 = ConvBn3D(channels, channels, kernel_size=1, stride=1, pad=0)
self.redir2 = ConvBn3D(channels * 2, channels * 2, kernel_size=1, stride=1, pad=0)
def forward(self, x):
conv1 = self.conv1b(self.conv1a(x))
conv2 = self.conv2b(self.conv2a(conv1))
dconv2 = F.relu(self.dconv2(conv2) + self.redir2(conv1), inplace=True)
dconv1 = F.relu(self.dconv1(dconv2) + self.redir1(x), inplace=True)
return dconv1
def homo_warping(src_fea, src_proj, ref_proj, depth_values, align_corners=False):
# src_fea: [B, C, H, W]
# src_proj: [B, 4, 4]
# ref_proj: [B, 4, 4]
    # depth_values: [B, Ndepth] or [B, Ndepth, H, W]
# out: [B, C, Ndepth, H, W]
batch, channels = src_fea.shape[0], src_fea.shape[1]
num_depth = depth_values.shape[1]
height, width = src_fea.shape[2], src_fea.shape[3]
with torch.no_grad():
proj = torch.matmul(src_proj, torch.inverse(ref_proj))
rot = proj[:, :3, :3] # [B,3,3]
trans = proj[:, :3, 3:4] # [B,3,1]
y, x = torch.meshgrid([torch.arange(0, height, dtype=torch.float32, device=src_fea.device),
torch.arange(0, width, dtype=torch.float32, device=src_fea.device)])
y, x = y.contiguous(), x.contiguous()
y, x = y.view(height * width), x.view(height * width)
xyz = torch.stack((x, y, torch.ones_like(x))) # [3, H*W]
xyz = torch.unsqueeze(xyz, 0).repeat(batch, 1, 1) # [B, 3, H*W]
rot_xyz = torch.matmul(rot, xyz) # [B, 3, H*W]
rot_depth_xyz = rot_xyz.unsqueeze(2).repeat(1, 1, num_depth, 1) * depth_values.view(batch, 1, num_depth, -1) # [B, 3, Ndepth, H*W]
proj_xyz = rot_depth_xyz + trans.view(batch, 3, 1, 1) # [B, 3, Ndepth, H*W]
proj_xy = proj_xyz[:, :2, :, :] / proj_xyz[:, 2:3, :, :] # [B, 2, Ndepth, H*W]
proj_x_normalized = proj_xy[:, 0, :, :] / ((width - 1) / 2) - 1
proj_y_normalized = proj_xy[:, 1, :, :] / ((height - 1) / 2) - 1
proj_xy = torch.stack((proj_x_normalized, proj_y_normalized), dim=3) # [B, Ndepth, H*W, 2]
grid = proj_xy
warped_src_fea = F.grid_sample(src_fea, grid.view(batch, num_depth * height, width, 2), mode='bilinear', padding_mode='zeros', align_corners=align_corners)
warped_src_fea = warped_src_fea.view(batch, channels, num_depth, height, width)
return warped_src_fea
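# Usage sketch (editor's note, hypothetical shapes and values): warp a source
# feature map into the reference view over D fronto-parallel depth planes.
# With identity projections the sampling grid is the identity, so the output
# simply repeats the source features at every depth:
#
#   src_fea = torch.rand(1, 8, 64, 80)
#   src_proj = torch.eye(4).unsqueeze(0)   # full 4x4 projection composites
#   ref_proj = torch.eye(4).unsqueeze(0)
#   depths = torch.linspace(425., 935., 4).view(1, 4, 1, 1).expand(1, 4, 64, 80)
#   warped = homo_warping(src_fea, src_proj, ref_proj, depths, align_corners=True)
#   # warped: [1, 8, 4, 64, 80], equal to src_fea on each depth plane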
class DeConv2dFuse(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, relu=True, bn=True,
bn_momentum=0.1):
super(DeConv2dFuse, self).__init__()
self.deconv = Deconv2d(in_channels, out_channels, kernel_size, stride=2, padding=1, output_padding=1,
bn=True, relu=relu, bn_momentum=bn_momentum)
self.conv = Conv2d(2*out_channels, out_channels, kernel_size, stride=1, padding=1,
bn=bn, relu=relu, bn_momentum=bn_momentum)
# assert init_method in ["kaiming", "xavier"]
# self.init_weights(init_method)
def forward(self, x_pre, x):
x = self.deconv(x)
x = torch.cat((x, x_pre), dim=1)
x = self.conv(x)
return x
class FeatureNet(nn.Module):
def __init__(self, base_channels, num_stage=3, stride=4, arch_mode="unet"):
super(FeatureNet, self).__init__()
        assert arch_mode in ["unet", "fpn"], "arch_mode must be 'unet' or 'fpn', but got: {}".format(arch_mode)
print("*************feature extraction arch mode:{}****************".format(arch_mode))
self.arch_mode = arch_mode
self.stride = stride
self.base_channels = base_channels
self.num_stage = num_stage
self.conv0 = nn.Sequential(
Conv2d(3, base_channels, 3, 1, padding=1),
Conv2d(base_channels, base_channels, 3, 1, padding=1),
)
self.conv1 = nn.Sequential(
Conv2d(base_channels, base_channels * 2, 5, stride=2, padding=2),
Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1),
Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1),
)
self.conv2 = nn.Sequential(
Conv2d(base_channels * 2, base_channels * 4, 5, stride=2, padding=2),
Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1),
Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1),
)
self.out1 = nn.Conv2d(base_channels * 4, base_channels * 4, 1, bias=False)
self.out_channels = [4 * base_channels]
if self.arch_mode == 'unet':
if num_stage == 3:
self.deconv1 = DeConv2dFuse(base_channels * 4, base_channels * 2, 3)
self.deconv2 = DeConv2dFuse(base_channels * 2, base_channels, 3)
self.out2 = nn.Conv2d(base_channels * 2, base_channels * 2, 1, bias=False)
self.out3 = nn.Conv2d(base_channels, base_channels, 1, bias=False)
self.out_channels.append(2 * base_channels)
self.out_channels.append(base_channels)
elif num_stage == 2:
self.deconv1 = DeConv2dFuse(base_channels * 4, base_channels * 2, 3)
self.out2 = nn.Conv2d(base_channels * 2, base_channels * 2, 1, bias=False)
self.out_channels.append(2 * base_channels)
elif self.arch_mode == "fpn":
final_chs = base_channels * 4
if num_stage == 3:
self.inner1 = nn.Conv2d(base_channels * 2, final_chs, 1, bias=True)
self.inner2 = nn.Conv2d(base_channels * 1, final_chs, 1, bias=True)
self.out2 = nn.Conv2d(final_chs, base_channels * 2, 3, padding=1, bias=False)
self.out3 = nn.Conv2d(final_chs, base_channels, 3, padding=1, bias=False)
self.out_channels.append(base_channels * 2)
self.out_channels.append(base_channels)
elif num_stage == 2:
self.inner1 = nn.Conv2d(base_channels * 2, final_chs, 1, bias=True)
self.out2 = nn.Conv2d(final_chs, base_channels, 3, padding=1, bias=False)
self.out_channels.append(base_channels)
def forward(self, x):
conv0 = self.conv0(x)
conv1 = self.conv1(conv0)
conv2 = self.conv2(conv1)
intra_feat = conv2
outputs = {}
out = self.out1(intra_feat)
outputs["stage1"] = out
if self.arch_mode == "unet":
if self.num_stage == 3:
intra_feat = self.deconv1(conv1, intra_feat)
out = self.out2(intra_feat)
outputs["stage2"] = out
intra_feat = self.deconv2(conv0, intra_feat)
out = self.out3(intra_feat)
outputs["stage3"] = out
elif self.num_stage == 2:
intra_feat = self.deconv1(conv1, intra_feat)
out = self.out2(intra_feat)
outputs["stage2"] = out
elif self.arch_mode == "fpn":
if self.num_stage == 3:
intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="nearest") + self.inner1(conv1)
out = self.out2(intra_feat)
outputs["stage2"] = out
intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="nearest") + self.inner2(conv0)
out = self.out3(intra_feat)
outputs["stage3"] = out
elif self.num_stage == 2:
intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="nearest") + self.inner1(conv1)
out = self.out2(intra_feat)
outputs["stage2"] = out
return outputs
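# Shape sketch (editor's note, hypothetical input size): with
# FeatureNet(base_channels=8, num_stage=3, arch_mode="fpn") and a
# [B, 3, 512, 640] image, the returned pyramid is
#   stage1: [B, 32, 128, 160]   (1/4 resolution)
#   stage2: [B, 16, 256, 320]   (1/2 resolution)
#   stage3: [B,  8, 512, 640]   (full resolution)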
class FPNDCNpath(nn.Module):
"""
FPN+DCN pathway"""
def __init__(self, base_channels, stride=4):
super(FPNDCNpath, self).__init__()
self.stride = stride
self.base_channels = base_channels
self.conv0 = nn.Sequential(
Conv2d(3, base_channels, 3, 1, padding=1),
Conv2d(base_channels, base_channels, 3, 1, padding=1),
)
self.conv1 = nn.Sequential(
Conv2d(base_channels, base_channels * 2, 5, stride=2, padding=2),
Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1),
Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1),
)
self.conv2 = nn.Sequential(
Conv2d(base_channels * 2, base_channels * 4, 5, stride=2, padding=2),
Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1),
Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1),
)
self.out1 = nn.Sequential(
DCNConv2d(base_channels * 4, base_channels * 4, 3, stride=1, padding=1),
DCNConv2d(base_channels * 4, base_channels * 4, 3, stride=1, padding=1),
DeformConvPack(base_channels * 4, base_channels * 4, 3, stride=1, padding=1, bias=False, im2col_step=16)
)
self.out_channels = [4 * base_channels]
final_chs = base_channels * 4
self.inner1 = nn.Conv2d(base_channels * 2, final_chs, 1, bias=True)
self.inner2 = nn.Conv2d(base_channels * 1, final_chs, 1, bias=True)
self.out2 = nn.Sequential(
DCNConv2d(base_channels * 4, base_channels * 2, 3, stride=1, padding=1),
DCNConv2d(base_channels * 2, base_channels * 2, 3, stride=1, padding=1),
DeformConvPack(base_channels * 2, base_channels * 2, 3, stride=1, padding=1, bias=False, im2col_step=16)
)
self.out2pathconv = nn.Conv2d(base_channels * 4, base_channels * 2, 3, stride=1, padding=1)
self.out3 = nn.Sequential(
DCNConv2d(base_channels * 4, base_channels * 1, 3, stride=1, padding=1),
DCNConv2d(base_channels * 1, base_channels * 1, 3, stride=1, padding=1),
DeformConvPack(base_channels * 1, base_channels * 1, 3, stride=1, padding=1, bias=False, im2col_step=16)
)
self.out3pathconv = nn.Conv2d(base_channels * 2, base_channels * 1, 3, stride=1, padding=1)
self.out_channels.append(base_channels * 2)
self.out_channels.append(base_channels)
def forward(self, x):
conv0 = self.conv0(x)
conv1 = self.conv1(conv0)
conv2 = self.conv2(conv1)
intra_feat = conv2
outputs = {}
out1 = self.out1(intra_feat)
intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner1(conv1)
out2 = self.out2(intra_feat)
out2 = out2 + self.out2pathconv(F.interpolate(out1, scale_factor=2, mode="bilinear", align_corners=True))
intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner2(conv0)
out3 = self.out3(intra_feat)
out3 = out3 + self.out3pathconv(F.interpolate(out2, scale_factor=2, mode="bilinear", align_corners=True))
outputs["stage1"] = out1
outputs["stage2"] = out2
outputs["stage3"] = out3
return outputs
class FPNDCN(nn.Module):
"""
FPN+DCN"""
def __init__(self, base_channels, stride=4):
super(FPNDCN, self).__init__()
self.stride = stride
self.base_channels = base_channels
self.conv0 = nn.Sequential(
Conv2d(3, base_channels, 3, 1, padding=1),
Conv2d(base_channels, base_channels, 3, 1, padding=1),
)
self.conv1 = nn.Sequential(
Conv2d(base_channels, base_channels * 2, 5, stride=2, padding=2),
Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1),
Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1),
)
self.conv2 = nn.Sequential(
Conv2d(base_channels * 2, base_channels * 4, 5, stride=2, padding=2),
Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1),
Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1),
)
self.out1 = nn.Sequential(
DCNConv2d(base_channels * 4, base_channels * 4, 3, stride=1, padding=1),
DCNConv2d(base_channels * 4, base_channels * 4, 3, stride=1, padding=1),
DeformConvPack(base_channels * 4, base_channels * 4, 3, stride=1, padding=1, bias=False, im2col_step=16)
)
self.out_channels = [4 * base_channels]
final_chs = base_channels * 4
self.inner1 = nn.Conv2d(base_channels * 2, final_chs, 1, bias=True)
self.inner2 = nn.Conv2d(base_channels * 1, final_chs, 1, bias=True)
self.out2 = nn.Sequential(
DCNConv2d(base_channels * 4, base_channels * 2, 3, stride=1, padding=1),
DCNConv2d(base_channels * 2, base_channels * 2, 3, stride=1, padding=1),
DeformConvPack(base_channels * 2, base_channels * 2, 3, stride=1, padding=1, bias=False, im2col_step=16)
)
self.out3 = nn.Sequential(
DCNConv2d(base_channels * 4, base_channels * 1, 3, stride=1, padding=1),
DCNConv2d(base_channels * 1, base_channels * 1, 3, stride=1, padding=1),
DeformConvPack(base_channels * 1, base_channels * 1, 3, stride=1, padding=1, bias=False, im2col_step=16)
)
self.out_channels.append(base_channels * 2)
self.out_channels.append(base_channels)
def forward(self, x):
conv0 = self.conv0(x)
conv1 = self.conv1(conv0)
conv2 = self.conv2(conv1)
intra_feat = conv2
outputs = {}
out1 = self.out1(intra_feat)
intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner1(conv1)
out2 = self.out2(intra_feat)
intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner2(conv0)
out3 = self.out3(intra_feat)
outputs["stage1"] = out1
outputs["stage2"] = out2
outputs["stage3"] = out3
return outputs
class FPNA(nn.Module):
"""
FPN aligncorners"""
def __init__(self, base_channels, stride=4):
super(FPNA, self).__init__()
self.stride = stride
self.base_channels = base_channels
self.conv0 = nn.Sequential(
Conv2d(3, base_channels, 3, 1, padding=1),
Conv2d(base_channels, base_channels, 3, 1, padding=1),
)
self.conv1 = nn.Sequential(
Conv2d(base_channels, base_channels * 2, 5, stride=2, padding=2),
Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1),
Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1),
)
self.conv2 = nn.Sequential(
Conv2d(base_channels * 2, base_channels * 4, 5, stride=2, padding=2),
Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1),
Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1),
)
self.out1 = nn.Conv2d(base_channels * 4, base_channels * 4, 1, bias=False)
self.out_channels = [4 * base_channels]
final_chs = base_channels * 4
self.inner1 = nn.Conv2d(base_channels * 2, final_chs, 1, bias=True)
self.inner2 = nn.Conv2d(base_channels * 1, final_chs, 1, bias=True)
self.out2 = nn.Conv2d(final_chs, base_channels * 2, 3, padding=1, bias=False)
self.out3 = nn.Conv2d(final_chs, base_channels, 3, padding=1, bias=False)
self.out_channels.append(base_channels * 2)
self.out_channels.append(base_channels)
def forward(self, x):
conv0 = self.conv0(x)
conv1 = self.conv1(conv0)
conv2 = self.conv2(conv1)
intra_feat = conv2
outputs = {}
out1 = self.out1(intra_feat)
intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner1(conv1)
out2 = self.out2(intra_feat)
intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner2(conv0)
out3 = self.out3(intra_feat)
outputs["stage1"] = out1
outputs["stage2"] = out2
outputs["stage3"] = out3
return outputs
class FPNA4(nn.Module):
"""
FPN aligncorners downsample 4x"""
def __init__(self, base_channels):
super(FPNA4, self).__init__()
self.base_channels = base_channels
self.conv0 = nn.Sequential(
Conv2d(3, base_channels, 3, 1, padding=1),
Conv2d(base_channels, base_channels, 3, 1, padding=1),
)
self.conv1 = nn.Sequential(
Conv2d(base_channels, base_channels * 2, 5, stride=2, padding=2),
Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1),
Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1),
)
self.conv2 = nn.Sequential(
Conv2d(base_channels * 2, base_channels * 4, 5, stride=2, padding=2),
Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1),
Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1),
)
self.conv3 = nn.Sequential(
Conv2d(base_channels * 4, base_channels * 8, 5, stride=2, padding=2),
Conv2d(base_channels * 8, base_channels * 8, 3, 1, padding=1),
Conv2d(base_channels * 8, base_channels * 8, 3, 1, padding=1),
)
self.out_channels = [8 * base_channels]
final_chs = base_channels * 8
self.inner1 = nn.Conv2d(base_channels * 4, final_chs, 1, bias=True)
self.inner2 = nn.Conv2d(base_channels * 2, final_chs, 1, bias=True)
self.inner3 = nn.Conv2d(base_channels * 1, final_chs, 1, bias=True)
self.out1 = nn.Conv2d(final_chs, base_channels * 8, 1, bias=False)
self.out2 = nn.Conv2d(final_chs, base_channels * 4, 3, padding=1, bias=False)
self.out3 = nn.Conv2d(final_chs, base_channels * 2, 3, padding=1, bias=False)
self.out4 = nn.Conv2d(final_chs, base_channels, 3, padding=1, bias=False)
self.out_channels.append(base_channels * 4)
self.out_channels.append(base_channels * 2)
self.out_channels.append(base_channels)
def forward(self, x):
conv0 = self.conv0(x)
conv1 = self.conv1(conv0)
conv2 = self.conv2(conv1)
conv3 = self.conv3(conv2)
intra_feat = conv3
outputs = {}
out1 = self.out1(intra_feat)
intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner1(conv2)
out2 = self.out2(intra_feat)
intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner2(conv1)
out3 = self.out3(intra_feat)
intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner3(conv0)
out4 = self.out4(intra_feat)
outputs["stage1"] = out1
outputs["stage2"] = out2
outputs["stage3"] = out3
outputs["stage4"] = out4
return outputs
class CostRegNet(nn.Module):
def __init__(self, in_channels, base_channels, down_size=3):
super(CostRegNet, self).__init__()
self.down_size = down_size
self.conv0 = Conv3d(in_channels, base_channels, padding=1)
self.conv1 = Conv3d(base_channels, base_channels * 2, stride=2, padding=1)
self.conv2 = Conv3d(base_channels * 2, base_channels * 2, padding=1)
if down_size >= 2:
self.conv3 = Conv3d(base_channels * 2, base_channels * 4, stride=2, padding=1)
self.conv4 = Conv3d(base_channels * 4, base_channels * 4, padding=1)
if down_size >= 3:
self.conv5 = Conv3d(base_channels * 4, base_channels * 8, stride=2, padding=1)
self.conv6 = Conv3d(base_channels * 8, base_channels * 8, padding=1)
self.conv7 = Deconv3d(base_channels * 8, base_channels * 4, stride=2, padding=1, output_padding=1)
if down_size >= 2:
self.conv9 = Deconv3d(base_channels * 4, base_channels * 2, stride=2, padding=1, output_padding=1)
self.conv11 = Deconv3d(base_channels * 2, base_channels * 1, stride=2, padding=1, output_padding=1)
self.prob = nn.Conv3d(base_channels, 1, 3, stride=1, padding=1, bias=False)
def forward(self, x):
if self.down_size==3:
conv0 = self.conv0(x)
conv2 = self.conv2(self.conv1(conv0))
conv4 = self.conv4(self.conv3(conv2))
x = self.conv6(self.conv5(conv4))
x = conv4 + self.conv7(x)
x = conv2 + self.conv9(x)
x = conv0 + self.conv11(x)
x = self.prob(x)
elif self.down_size==2:
conv0 = self.conv0(x)
conv2 = self.conv2(self.conv1(conv0))
x = self.conv4(self.conv3(conv2))
x = conv2 + self.conv9(x)
x = conv0 + self.conv11(x)
x = self.prob(x)
else:
conv0 = self.conv0(x)
x = self.conv2(self.conv1(conv0))
x = conv0 + self.conv11(x)
x = self.prob(x)
return x
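# Editor's note: CostRegNet is a 3D U-Net over a [B, C, D, H, W] cost volume;
# with down_size=3 it downsamples D, H and W by up to 8x, so those dimensions
# should be divisible by 8 for the skip connections to line up. The output
# [B, 1, D, H, W] holds per-depth logits, typically squeezed and softmax-ed
# over D before being passed to depth_regression.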
class P3DConv(nn.Module):
"""
Pseudo 3D conv: 3x3x1 + 1x3x3
"""
def __init__(self, in_channels, base_channels):
super(P3DConv, self).__init__()
self.conv0 = PConv3d(in_channels, base_channels, padding=1)
self.conv1 = PConv3d(base_channels, base_channels * 2, stride=2, padding=1)
self.conv2 = PConv3d(base_channels * 2, base_channels * 2, padding=1)
self.conv3 = PConv3d(base_channels * 2, base_channels * 4, stride=2, padding=1)
self.conv4 = PConv3d(base_channels * 4, base_channels * 4, padding=1)
self.conv5 = PConv3d(base_channels * 4, base_channels * 8, stride=2, padding=1)
self.conv6 = PConv3d(base_channels * 8, base_channels * 8, padding=1)
self.conv7 = PDeconv3d(base_channels * 8, base_channels * 4, stride=2, padding=1, output_padding=1)
self.conv9 = PDeconv3d(base_channels * 4, base_channels * 2, stride=2, padding=1, output_padding=1)
self.conv11 = PDeconv3d(base_channels * 2, base_channels * 1, stride=2, padding=1, output_padding=1)
self.prob = nn.Conv3d(base_channels, 1, 3, stride=1, padding=1, bias=False)
def forward(self, x):
conv0 = self.conv0(x)
conv2 = self.conv2(self.conv1(conv0))
conv4 = self.conv4(self.conv3(conv2))
x = self.conv6(self.conv5(conv4))
x = conv4 + self.conv7(x)
x = conv2 + self.conv9(x)
x = conv0 + self.conv11(x)
x = self.prob(x)
return x
class RefineNet(nn.Module):
def __init__(self):
super(RefineNet, self).__init__()
self.conv1 = ConvBnReLU(4, 32)
self.conv2 = ConvBnReLU(32, 32)
self.conv3 = ConvBnReLU(32, 32)
self.res = ConvBnReLU(32, 1)
def forward(self, img, depth_init):
        concat = torch.cat((img, depth_init), dim=1)
depth_residual = self.res(self.conv3(self.conv2(self.conv1(concat))))
depth_refined = depth_init + depth_residual
return depth_refined
def depth_regression(p, depth_values):
if depth_values.dim() <= 2:
# print("regression dim <= 2")
depth_values = depth_values.view(*depth_values.shape, 1, 1)
depth = torch.sum(p * depth_values, 1)
return depth
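# Worked example (editor's note): depth_regression is a soft-argmax, i.e. the
# expectation E[d] = sum_d p(d) * d along the hypothesis axis:
#
#   p = torch.full((1, 4, 2, 2), 0.25)            # uniform over 4 hypotheses
#   d = torch.tensor([[100., 200., 300., 400.]])  # [B, Ndepth]
#   depth_regression(p, d)                        # -> 250.0, shape [1, 2, 2]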
def cas_mvsnet_loss(inputs, depth_gt_ms, mask_ms, **kwargs):
depth_loss_weights = kwargs.get("dlossw", None)
total_loss = torch.tensor(0.0, dtype=torch.float32, device=mask_ms["stage1"].device, requires_grad=False)
for (stage_inputs, stage_key) in [(inputs[k], k) for k in inputs.keys() if "stage" in k]:
depth_est = stage_inputs["depth"]
depth_gt = depth_gt_ms[stage_key]
mask = mask_ms[stage_key]
mask = mask > 0.5
depth_loss = F.smooth_l1_loss(depth_est[mask], depth_gt[mask], reduction='mean')
if depth_loss_weights is not None:
stage_idx = int(stage_key.replace("stage", "")) - 1
total_loss += depth_loss_weights[stage_idx] * depth_loss
else:
total_loss += 1.0 * depth_loss
return total_loss, depth_loss
def cas_mvsnet_T_loss(inputs, depth_gt_ms, mask_ms, **kwargs):
depth_loss_weights = kwargs.get("dlossw", None)
l1ce_lw = kwargs.get("l1ce_lw", [0.1, 1])
range_thres = kwargs.get("range_thres", [84.8, 10.6])
cas_method = kwargs.get("cascade_method", None)
last_conv3d = kwargs.get("last_conv3d", False)
visual = kwargs.get("visual", False)
wt = kwargs.get("wt", False)
fl = kwargs.get("fl", False)
shrink_method = kwargs.get("shrink_method", 'schedule')
upsampled_loss = kwargs.get("upsampled_loss", False)
selected_loss = kwargs.get("selected_loss", False)
mask_range_loss = kwargs.get("mask_range_loss", False)
det = kwargs.get("det", False)
if visual:
f, axs = plt.subplots(figsize=(30, 10),ncols=3) # depth offset
f2, axs2 = plt.subplots(figsize=(30, 10),ncols=3) # attn weight max
f3, axs3 = plt.subplots(figsize=(30, 10),ncols=3) # attn weight gt val
f4, axs4 = plt.subplots(figsize=(30, 10),ncols=3) # max gt offset
err_848_str = ''
err_106_str = ''
err_002_str = ''
total_loss = torch.tensor(0.0, dtype=torch.float32, device=mask_ms["stage1"].device, requires_grad=False)
stage_depth_loss = []
stage_ce_loss = []
range_err_ratio = []
upsampled_depth_losses = []
det_offset_losses = []
for stage_idx, (stage_inputs, stage_key) in enumerate([(inputs[k], k) for k in inputs.keys() if "stage" in k]):
depth_est = stage_inputs["depth"]
B,H,W = depth_est.shape
mask = mask_ms[stage_key]
mask = mask > 0.5
depth_gt = depth_gt_ms[stage_key]
if upsampled_loss:
if stage_idx!=0 :
upsampled_depth = stage_inputs["upsampled_depth"]
upsampled_depth_loss = F.smooth_l1_loss(upsampled_depth[mask], depth_gt[mask], reduction='mean')
upsampled_depth_losses.append(upsampled_depth_loss)
else:
if stage_idx!=0 :
upsampled_depth_losses.append(torch.tensor(0.0, dtype=torch.float32, device=mask_ms["stage1"].device, requires_grad=False))
if mask_range_loss:
if stage_idx != 0:
depth_offset = next_stage_depth_hypo - depth_gt # B H W
this_stage_mask_range = torch.abs(depth_offset)<range_thres[stage_idx-1]
mask = mask & this_stage_mask_range # B H W
next_stage_depth_hypo = F.interpolate(depth_est.unsqueeze(1), scale_factor=2, mode='bilinear', align_corners=True).squeeze(1)
if stage_idx < len(range_thres):  # guard: range_thres may hold fewer entries than there are stages
depth_offset = depth_est - depth_gt
depth_offset[~mask] = 0 # B H W
range_err_ratio.append((torch.abs(depth_offset)>range_thres[stage_idx]).float().mean())
if visual:
depth_offset = depth_est - depth_gt
depth_offset[~mask] = 0
depth_offset = depth_offset.detach().cpu().numpy()[0] # H W
err_848_str += str((np.abs(depth_offset)>84.8).sum()) + ','
err_106_str += str((np.abs(depth_offset)>10.6).sum()) + ','
err_002_str += str((np.abs(depth_offset)>2).sum()) + ','
sns.heatmap(depth_offset, annot=False, ax=axs[stage_idx])
attn_weights = stage_inputs["attn_weights"][0] # D H W
attn_weights_max, ind_max = torch.max(attn_weights, 0)
attn_weights_max = attn_weights_max.detach().cpu().numpy() # H W
sns.heatmap(attn_weights_max, annot=False, ax=axs2[stage_idx])
this_stage_depth_val = stage_inputs['depth_values'] # B D H W
depth_offsets = torch.abs(this_stage_depth_val- depth_gt[:,None,:,:])[0] # D,H,W
_, indices = torch.min(depth_offsets, dim=0, keepdim=True) # [1, H, W]
attn_gt = torch.gather(attn_weights, 0, indices)[0] # [H W]
attn_gt = attn_gt.detach().cpu().numpy()
sns.heatmap(attn_gt, annot=False, ax=axs3[stage_idx])
max_gt_offset = ind_max - indices[0] # H W
max_gt_offset = max_gt_offset.detach().cpu().numpy()
sns.heatmap(max_gt_offset, annot=False, ax=axs4[stage_idx])
if cas_method[stage_idx] == 't' or cas_method[stage_idx] == 'r' or cas_method[stage_idx] == 'p':
# Loss for transformer
depth_loss = torch.tensor(0.0, dtype=torch.float32, device=mask_ms["stage1"].device, requires_grad=False)
if last_conv3d:
depth_loss = F.smooth_l1_loss(depth_est[mask], depth_gt[mask], reduction='mean')
attn_weights = stage_inputs["attn_weights"].permute(0,2,3,1).reshape(B*H*W, -1) # BHW D
this_stage_depth_val = stage_inputs['depth_values'] # B D H W
depth_offsets = torch.abs(this_stage_depth_val- depth_gt[:,None,:,:]) # B,D,H,W
_, indices = torch.min(depth_offsets, dim=1) # [B, H, W]
indices = indices.reshape(-1) # [BHW]
mask = mask.reshape(-1) # BHW
if fl: # -p(1-q)^a log(q)
this_stage_ce_loss = F.nll_loss((1-attn_weights[mask])**2 * torch.log(attn_weights[mask]+1e-12), indices[mask], reduction='mean')
else: # -plog(q)
this_stage_ce_loss = F.nll_loss(torch.log(attn_weights[mask]+1e-12), indices[mask], reduction='mean')
stage_depth_loss.append(depth_loss)
stage_ce_loss.append(this_stage_ce_loss)
this_stage_loss = l1ce_lw[0]*depth_loss + l1ce_lw[1]*this_stage_ce_loss
# Loss for 3D conv
else:
if wt:
depth_loss = torch.tensor(0.0, dtype=torch.float32, device=mask_ms["stage1"].device, requires_grad=False)
stage_depth_loss.append(depth_loss)
attn_weights = stage_inputs["attn_weights"].permute(0,2,3,1).reshape(B*H*W, -1) # BHW D
depth_offsets = torch.abs(stage_inputs['depth_values']- depth_gt[:,None,:,:]) # B,D,H,W
indices = torch.min(depth_offsets, dim=1)[1].reshape(-1) # [BHW]
mask = mask.reshape(-1) # BHW
if fl: # -p(1-q)^a log(q)
this_stage_ce_loss = F.nll_loss((1-attn_weights[mask])**2 * torch.log(attn_weights[mask]+1e-12), indices[mask], reduction='mean')
else: # -plog(q)
this_stage_ce_loss = F.nll_loss(torch.log(attn_weights[mask]+1e-12), indices[mask], reduction='mean')
stage_ce_loss.append(this_stage_ce_loss)
else:
depth_loss = F.smooth_l1_loss(depth_est[mask], depth_gt[mask], reduction='mean')
stage_depth_loss.append(depth_loss)
this_stage_ce_loss = torch.tensor(0.0, dtype=torch.float32, device=mask_ms["stage1"].device, requires_grad=False)
stage_ce_loss.append(this_stage_ce_loss)
this_stage_loss = l1ce_lw[0]*depth_loss + l1ce_lw[1]*this_stage_ce_loss
if upsampled_loss:
if stage_idx!=0:
this_stage_loss = this_stage_loss + upsampled_depth_loss * l1ce_lw[0]
if shrink_method == 'DPF':
if stage_idx!=0:
depth_offsets = stage_inputs['depth_values'] - depth_gt[:,None,:,:] # B,D,H,W
depth_offset_clamp = torch.clamp(depth_offsets, -1, 1)
this_stage_loss = this_stage_loss + torch.abs(depth_offset_clamp).permute(0,2,3,1).reshape(B*H*W, -1)[mask.reshape(-1)].mean()
if selected_loss:
select_weight = stage_inputs["select_weight"].permute(0,2,3,1).reshape(B*H*W, -1) # BHW D
depth_offsets = torch.abs(stage_inputs['depth_values']- depth_gt[:,None,:,:])
indices = torch.min(depth_offsets, dim=1)[1] # [B, H, W]
indices = indices.reshape(-1) # [BHW]
mask = mask.reshape(-1) # BHW
this_stage_selected_loss = F.nll_loss(torch.log(select_weight[mask]+1e-12), indices[mask], reduction='mean')
this_stage_loss = this_stage_loss + this_stage_selected_loss * 0.01*l1ce_lw[1]
if det:
assert wt
depth_itv = stage_inputs['depth_values'][:,1,:,:] - stage_inputs['depth_values'][:,0,:,:] # B H W
pred_offset = stage_inputs['offset_reg'].reshape(-1) # BHW
offset_gt = (depth_gt - (depth_est - stage_inputs['offset_reg'])).reshape(-1) / depth_itv.reshape(-1) # BHW
det_offset_loss = F.smooth_l1_loss(pred_offset[mask], offset_gt[mask], reduction='mean')
det_offset_losses.append(det_offset_loss)
this_stage_loss += det_offset_loss
else:
det_offset_losses.append(torch.tensor(0.0, dtype=torch.float32, device=mask_ms["stage1"].device, requires_grad=False))
if depth_loss_weights is not None:
stage_idx = int(stage_key.replace("stage", "")) - 1
total_loss += depth_loss_weights[stage_idx] * this_stage_loss
else:
total_loss += 1.0 * this_stage_loss
if visual:
axs[1].set_title('err848:{}'.format(err_848_str) + 'err_106:{}'.format(err_106_str) + 'err_002:{}'.format(err_002_str))
f.savefig('/mnt/cfs/algorithm/xiaofeng.wang/jeff/code/MVS/cascade-stereo/CasMVSNet/debug_figs/offset_heatmap.png')
f.clf()
f2.savefig('/mnt/cfs/algorithm/xiaofeng.wang/jeff/code/MVS/cascade-stereo/CasMVSNet/debug_figs/attn_max_heatmap.png')
f2.clf()
f3.savefig('/mnt/cfs/algorithm/xiaofeng.wang/jeff/code/MVS/cascade-stereo/CasMVSNet/debug_figs/attn_gt_heatmap.png')
f3.clf()
f4.savefig('/mnt/cfs/algorithm/xiaofeng.wang/jeff/code/MVS/cascade-stereo/CasMVSNet/debug_figs/max_gt_offset_heatmap.png')
f4.clf()
return total_loss, depth_loss, stage_depth_loss, stage_ce_loss, range_err_ratio, upsampled_depth_losses, det_offset_losses
def get_cur_depth_range_samples(cur_depth, ndepth, depth_inteval_pixel, shape, max_depth=192.0, min_depth=0.0):
#shape, (B, H, W)
#cur_depth: (B, H, W)
#return depth_range_values: (B, D, H, W)
cur_depth_min = (cur_depth - ndepth / 2 * depth_inteval_pixel) # (B, H, W)
cur_depth_max = (cur_depth + ndepth / 2 * depth_inteval_pixel)
# cur_depth_min = (cur_depth - ndepth / 2 * depth_inteval_pixel).clamp(min=0.0) #(B, H, W)
# cur_depth_max = (cur_depth_min + (ndepth - 1) * depth_inteval_pixel).clamp(max=max_depth)
assert cur_depth.shape == torch.Size(shape), "cur_depth:{}, input shape:{}".format(cur_depth.shape, shape)
new_interval = (cur_depth_max - cur_depth_min) / (ndepth - 1) # (B, H, W)
depth_range_samples = cur_depth_min.unsqueeze(1) + (torch.arange(0, ndepth, device=cur_depth.device,
dtype=cur_depth.dtype,
requires_grad=False).reshape(1, -1, 1, 1) * new_interval.unsqueeze(1))
return depth_range_samples
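# The samples form ndepth hypotheses centred on cur_depth, spanning
# cur_depth +/- ndepth/2 * depth_inteval_pixel with uniform per-pixel spacing
# new_interval; max_depth/min_depth are only used by the commented clamped
# variant above.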
def get_depth_range_samples(cur_depth, ndepth, depth_inteval_pixel, device, dtype, shape,
max_depth=192.0, min_depth=0.0):
#shape: (B, H, W)
#cur_depth: (B, H, W) or (B, D)
#return depth_range_samples: (B, D, H, W)
if cur_depth.dim() == 2:
cur_depth_min = cur_depth[:, 0] # (B,)
cur_depth_max = cur_depth[:, -1]
new_interval = (cur_depth_max - cur_depth_min) / (ndepth - 1) # (B, )
depth_range_samples = cur_depth_min.unsqueeze(1) + (torch.arange(0, ndepth, device=device, dtype=dtype,
requires_grad=False).reshape(1, -1) * new_interval.unsqueeze(1)) #(B, D)
depth_range_samples = depth_range_samples.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, shape[1], shape[2]) #(B, D, H, W)
else:
depth_range_samples = get_cur_depth_range_samples(cur_depth, ndepth, depth_inteval_pixel, shape, max_depth, min_depth)
return depth_range_samples
if __name__ == "__main__":
# some testing code, just IGNORE it
import sys
sys.path.append("../")
from datasets import find_dataset_def
from torch.utils.data import DataLoader
import numpy as np
import cv2
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
# MVSDataset = find_dataset_def("colmap")
# dataset = MVSDataset("../data/results/ford/num10_1/", 3, 'test',
# 128, interval_scale=1.06, max_h=1250, max_w=1024)
MVSDataset = find_dataset_def("dtu_yao")
num_depth = 48
dataset = MVSDataset("../data/DTU/mvs_training/dtu/", '../lists/dtu/train.txt', 'train',
3, num_depth, interval_scale=1.06 * 192 / num_depth)
dataloader = DataLoader(dataset, batch_size=1)
item = next(iter(dataloader))
imgs = item["imgs"][:, :, :, ::4, ::4] #(B, N, 3, H, W)
# imgs = item["imgs"][:, :, :, :, :]
proj_matrices = item["proj_matrices"] #(B, N, 2, 4, 4) dim=N: N view; dim=2: index 0 for extr, 1 for intric
proj_matrices[:, :, 1, :2, :] = proj_matrices[:, :, 1, :2, :]
# proj_matrices[:, :, 1, :2, :] = proj_matrices[:, :, 1, :2, :] * 4
depth_values = item["depth_values"] #(B, D)
imgs = torch.unbind(imgs, 1)
proj_matrices = torch.unbind(proj_matrices, 1)
ref_img, src_imgs = imgs[0], imgs[1:]
ref_proj, src_proj = proj_matrices[0], proj_matrices[1:][0] #only vis first view
src_proj_new = src_proj[:, 0].clone()
src_proj_new[:, :3, :4] = torch.matmul(src_proj[:, 1, :3, :3], src_proj[:, 0, :3, :4])
ref_proj_new = ref_proj[:, 0].clone()
ref_proj_new[:, :3, :4] = torch.matmul(ref_proj[:, 1, :3, :3], ref_proj[:, 0, :3, :4])
warped_imgs = homo_warping(src_imgs[0], src_proj_new, ref_proj_new, depth_values)
ref_img_np = ref_img.permute([0, 2, 3, 1])[0].detach().cpu().numpy()[:, :, ::-1] * 255
cv2.imwrite('../tmp/ref.png', ref_img_np)
cv2.imwrite('../tmp/src.png', src_imgs[0].permute([0, 2, 3, 1])[0].detach().cpu().numpy()[:, :, ::-1] * 255)
for i in range(warped_imgs.shape[2]):
warped_img = warped_imgs[:, :, i, :, :].permute([0, 2, 3, 1]).contiguous()
img_np = warped_img[0].detach().cpu().numpy()
img_np = img_np[:, :, ::-1] * 255
alpha = 0.5
beta = 1 - alpha
gamma = 0
img_add = cv2.addWeighted(ref_img_np, alpha, img_np, beta, gamma)
cv2.imwrite('../tmp/tmp{}.png'.format(i), np.hstack([ref_img_np, img_np, img_add])) #* ratio + img_np*(1-ratio)]))
================================================
FILE: models/mvs4net_utils.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import importlib
try:
from modules.deform_conv import DeformConvPack
except ImportError:
print('DeformConvPack not found, please install it from: https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch')
import math
import numpy as np
def homo_warping(src_fea, src_proj, ref_proj, depth_values, vis_ETA=False, fn=None):
# src_fea: [B, C, H, W]
# src_proj: [B, 4, 4]
# ref_proj: [B, 4, 4]
# depth_values: [B, Ndepth, H, W] (this version unpacks four dims below)
# out: [B, C, Ndepth, H, W]
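# Differentiable homography warp: proj = src_proj @ ref_proj^-1 maps ref-view
# pixels, back-projected at each hypothesised depth, into the source view; the
# source feature is then bilinearly sampled at those locations via F.grid_sample.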
C = src_fea.shape[1]
Hs,Ws = src_fea.shape[-2:]
B,num_depth,Hr,Wr = depth_values.shape
with torch.no_grad():
proj = torch.matmul(src_proj, torch.inverse(ref_proj))
rot = proj[:, :3, :3] # [B,3,3]
trans = proj[:, :3, 3:4] # [B,3,1]
y, x = torch.meshgrid([torch.arange(0, Hr, dtype=torch.float32, device=src_fea.device),
torch.arange(0, Wr, dtype=torch.float32, device=src_fea.device)])
y = y.reshape(Hr*Wr)
x = x.reshape(Hr*Wr)
xyz = torch.stack((x, y, torch.ones_like(x))) # [3, H*W]
xyz = torch.unsqueeze(xyz, 0).repeat(B, 1, 1) # [B, 3, H*W]
rot_xyz = torch.matmul(rot, xyz) # [B, 3, H*W]
rot_depth_xyz = rot_xyz.unsqueeze(2).repeat(1, 1, num_depth, 1) * depth_values.reshape(B, 1, num_depth, -1) # [B, 3, Ndepth, H*W]
proj_xyz = rot_depth_xyz + trans.reshape(B, 3, 1, 1) # [B, 3, Ndepth, H*W]
# FIXME divide 0
temp = proj_xyz[:, 2:3, :, :]
temp[temp==0] = 1e-9
proj_xy = proj_xyz[:, :2, :, :] / temp # [B, 2, Ndepth, H*W]
# proj_xy = proj_xyz[:, :2, :, :] / proj_xyz[:, 2:3, :, :] # [B, 2, Ndepth, H*W]
proj_x_normalized = proj_xy[:, 0, :, :] / ((Ws - 1) / 2) - 1
proj_y_normalized = proj_xy[:, 1, :, :] / ((Hs - 1) / 2) - 1
proj_xy = torch.stack((proj_x_normalized, proj_y_normalized), dim=3) # [B, Ndepth, H*W, 2]
if vis_ETA:
tensor_saved = proj_xy.reshape(B,num_depth,Hs,Ws,2).cpu().numpy()
np.save(fn+'_grid', tensor_saved)
grid = proj_xy
if len(src_fea.shape)==4:
warped_src_fea = F.grid_sample(src_fea, grid.reshape(B, num_depth * Hr, Wr, 2), mode='bilinear', padding_mode='zeros', align_corners=True)
warped_src_fea = warped_src_fea.reshape(B, C, num_depth, Hr, Wr)
elif len(src_fea.shape)==5:
warped_src_fea = []
for d in range(src_fea.shape[2]):
warped_src_fea.append(F.grid_sample(src_fea[:,:,d], grid.reshape(B, num_depth, Hr, Wr, 2)[:,d], mode='bilinear', padding_mode='zeros', align_corners=True))
warped_src_fea = torch.stack(warped_src_fea, dim=2)
return warped_src_fea
def init_range(cur_depth, ndepths, device, dtype, H, W):
cur_depth_min = cur_depth[:, 0] # (B,)
cur_depth_max = cur_depth[:, -1]
new_interval = (cur_depth_max - cur_depth_min) / (ndepths - 1) # (B, )
new_interval = new_interval[:, None, None] # (B, 1, 1)
depth_range_samples = cur_depth_min.unsqueeze(1) + (torch.arange(0, ndepths, device=device, dtype=dtype,
requires_grad=False).reshape(1, -1) * new_interval.squeeze(1)) #(B, D)
depth_range_samples = depth_range_samples.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, H, W) #(B, D, H, W)
return depth_range_samples
def init_inverse_range(cur_depth, ndepths, device, dtype, H, W):
inverse_depth_min = 1. / cur_depth[:, 0] # (B,)
inverse_depth_max = 1. / cur_depth[:, -1]
itv = torch.arange(0, ndepths, device=device, dtype=dtype, requires_grad=False).reshape(1, -1,1,1).repeat(1, 1, H, W) / (ndepths - 1) # 1 D H W
inverse_depth_hypo = inverse_depth_max[:,None, None, None] + (inverse_depth_min - inverse_depth_max)[:,None, None, None] * itv
return 1./inverse_depth_hypo
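# Initial hypotheses are uniform in inverse depth between 1/depth_max and
# 1/depth_min, which places more samples near the camera than a uniform split
# in metric depth.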
def schedule_inverse_range(inverse_min_depth, inverse_max_depth, ndepths, H, W):
#cur_depth_min, (B, H, W)
#cur_depth_max: (B, H, W)
itv = torch.arange(0, ndepths, device=inverse_min_depth.device, dtype=inverse_min_depth.dtype, requires_grad=False).reshape(1, -1,1,1).repeat(1, 1, H//2, W//2) / (ndepths - 1) # 1 D H W
inverse_depth_hypo = inverse_max_depth[:,None, :, :] + (inverse_min_depth - inverse_max_depth)[:,None, :, :] * itv # B D H W
inverse_depth_hypo = F.interpolate(inverse_depth_hypo.unsqueeze(1), [ndepths, H, W], mode='trilinear', align_corners=True).squeeze(1)
return 1./inverse_depth_hypo
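# Later stages narrow the inverse-depth range around the previous prediction
# (inverse_min_depth/inverse_max_depth come from stagenet below) and trilinearly
# upsample the hypothesis volume to the new stage resolution.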
def schedule_range(cur_depth, ndepth, depth_inteval_pixel, H, W):
#shape, (B, H, W)
#cur_depth: (B, H, W)
#return depth_range_values: (B, D, H, W)
cur_depth_min = (cur_depth - ndepth / 2 * depth_inteval_pixel[:,None,None]) # (B, H, W)
cur_depth_max = (cur_depth + ndepth / 2 * depth_inteval_pixel[:,None,None])
new_interval = (cur_depth_max - cur_depth_min) / (ndepth - 1) # (B, H, W)
depth_range_samples = cur_depth_min.unsqueeze(1) + (torch.arange(0, ndepth, device=cur_depth.device, dtype=cur_depth.dtype,
requires_grad=False).reshape(1, -1, 1, 1) * new_interval.unsqueeze(1))
depth_range_samples = F.interpolate(depth_range_samples.unsqueeze(1), [ndepth, H, W], mode='trilinear', align_corners=True).squeeze(1)
return depth_range_samples
def init_bn(module):
if module.weight is not None:
nn.init.ones_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
return
def init_uniform(module, init_method):
if module.weight is not None:
if init_method == "kaiming":
nn.init.kaiming_uniform_(module.weight)
elif init_method == "xavier":
nn.init.xavier_uniform_(module.weight)
return
class ConvBnReLU3D(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1):
super(ConvBnReLU3D, self).__init__()
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False)
self.bn = nn.BatchNorm3d(out_channels)
def forward(self, x):
return F.relu(self.bn(self.conv(x)), inplace=True)
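# The ConvBnReLU3D_* variants below add a residual attention branch before the
# BN+ReLU: CAM re-weights channels (squeeze-excite style, pooled over D*H*W),
# DCAM re-weights channels per depth plane, PAM re-weights spatial positions
# (H, W), and PDAM re-weights the full (D, H, W) volume.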
class ConvBnReLU3D_CAM(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1):
super(ConvBnReLU3D_CAM, self).__init__()
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False)
self.bn = nn.BatchNorm3d(out_channels)
self.linear_agg = nn.Sequential(
nn.Linear(out_channels, out_channels//2),
nn.ReLU(),
nn.Linear(out_channels//2, out_channels)
)
def forward(self, input):
x = self.conv(input)
B,C,D,H,W = x.shape
avg_attn = self.linear_agg(x.reshape(B,C,D*H*W).mean(2))
max_attn = self.linear_agg(x.reshape(B,C,D*H*W).max(2)[0]) # B C
attn = torch.sigmoid(max_attn+avg_attn)[:,:,None,None,None] # B C 1 1 1
x = x * attn
return F.relu(self.bn(x+input), inplace=True)
class ConvBnReLU3D_DCAM(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1):
super(ConvBnReLU3D_DCAM, self).__init__()
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False)
self.bn = nn.BatchNorm3d(out_channels)
self.linear_agg = nn.Sequential(
nn.Linear(out_channels, out_channels//2),
nn.ReLU(),
nn.Linear(out_channels//2, out_channels)
)
def forward(self, input):
x = self.conv(input)
B,C,D,H,W = x.shape
avg_attn = self.linear_agg(x.reshape(B,C,D,H*W).mean(3).permute(0,2,1).reshape(B*D,C)).reshape(B,D,C).permute(0,2,1)
max_attn = self.linear_agg(x.reshape(B,C,D,H*W).max(3)[0].permute(0,2,1).reshape(B*D,C)).reshape(B,D,C).permute(0,2,1) # B C D
attn = torch.sigmoid(max_attn+avg_attn)[:,:,:,None,None] # B C D 1 1
x = x * attn
return F.relu(self.bn(x+input), inplace=True)
class ConvBnReLU3D_PAM(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1):
super(ConvBnReLU3D_PAM, self).__init__()
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False)
self.bn = nn.BatchNorm3d(out_channels)
self.pixel_conv = nn.Conv2d(2,1,7,stride=1,padding='same')
def forward(self, input):
x = self.conv(input)
B,C,D,H,W = x.shape
max_attn = x.reshape(B,C*D,H,W).max(1, keepdim=True)[0]
avg_attn = x.reshape(B,C*D,H,W).mean(1, keepdim=True) # B 1 H W
attn = torch.sigmoid(self.pixel_conv(torch.cat([max_attn, avg_attn], dim=1)))[:,:,None,:,:] # B 1 1 H W
x = x * attn
return F.relu(self.bn(x+input), inplace=True)
class ConvBnReLU3D_PDAM(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1):
super(ConvBnReLU3D_PDAM, self).__init__()
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False)
self.bn = nn.BatchNorm3d(out_channels)
self.spatial_conv = nn.Conv3d(2,1,7,stride=1,padding='same')
def forward(self, input):
x = self.conv(input)
B,C,D,H,W = x.shape
max_attn = x.max(1, keepdim=True)[0]
avg_attn = x.mean(1, keepdim=True) # B 1 D H W
attn = torch.sigmoid(self.spatial_conv(torch.cat([max_attn, avg_attn], dim=1))) # B 1 D H W
x = x * attn
return F.relu(self.bn(x+input), inplace=True)
class Deconv3d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
relu=True, bn=True, bn_momentum=0.1, init_method="xavier", **kwargs):
super(Deconv3d, self).__init__()
self.out_channels = out_channels
assert stride in [1, 2]
self.stride = stride
self.conv = nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride=stride,
bias=(not bn), **kwargs)
self.bn = nn.BatchNorm3d(out_channels, momentum=bn_momentum) if bn else None
self.relu = relu
def forward(self, x):
y = self.conv(x)
if self.bn is not None:
x = self.bn(y)
if self.relu:
x = F.relu(x, inplace=True)
return x
def init_weights(self, init_method):
init_uniform(self.conv, init_method)
if self.bn is not None:
init_bn(self.bn)
class Conv2d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
relu=True, bn_momentum=0.1, init_method="xavier", gn=False, group_channel=8, **kwargs):
super(Conv2d, self).__init__()
bn = not gn
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride,
bias=(not bn), **kwargs)
self.kernel_size = kernel_size
self.stride = stride
self.bn = nn.BatchNorm2d(out_channels, momentum=bn_momentum) if bn else None
self.gn = nn.GroupNorm(int(max(1, out_channels / group_channel)), out_channels) if gn else None
self.relu = relu
def forward(self, x):
x = self.conv(x)
if self.bn is not None:
x = self.bn(x)
else:
x = self.gn(x)
if self.relu:
x = F.relu(x, inplace=True)
return x
def init_weights(self, init_method):
init_uniform(self.conv, init_method)
if self.bn is not None:
init_bn(self.bn)
class Deconv2d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
relu=True, bn=True, bn_momentum=0.1, init_method="xavier", **kwargs):
super(Deconv2d, self).__init__()
self.out_channels = out_channels
assert stride in [1, 2]
self.stride = stride
self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride,
bias=(not bn), **kwargs)
self.bn = nn.BatchNorm2d(out_channels, momentum=bn_momentum) if bn else None
self.relu = relu
class DeformConv2d(nn.Module):
def __init__(self, inc, outc, kernel_size=3, padding=1, stride=1, bias=None, modulation=True):
super(DeformConv2d, self).__init__()
self.kernel_size = kernel_size
self.padding = padding
self.stride = stride
self.zero_padding = nn.ZeroPad2d(padding)
self.conv = nn.Conv2d(inc, outc, kernel_size=kernel_size, stride=kernel_size, bias=bias)
self.p_conv = nn.Conv2d(inc, 2*kernel_size*kernel_size, kernel_size=3, padding=1, stride=stride)
nn.init.constant_(self.p_conv.weight, 0)
self.p_conv.register_backward_hook(self._set_lr)
self.modulation = modulation
if modulation:
self.m_conv = nn.Conv2d(inc, kernel_size*kernel_size, kernel_size=3, padding=1, stride=stride)
nn.init.constant_(self.m_conv.weight, 0)
self.m_conv.register_backward_hook(self._set_lr)
@staticmethod
def _set_lr(module, grad_input, grad_output):
# scale the offset/modulation gradients by 0.1; the original generator
# expressions were never consumed, so the hook silently did nothing
return tuple(g * 0.1 if g is not None else g for g in grad_input)
def forward(self, x):
offset = self.p_conv(x)
if self.modulation:
m = torch.sigmoid(self.m_conv(x))
dtype = offset.data.type()
ks = self.kernel_size
N = offset.size(1) // 2
if self.padding:
x = self.zero_padding(x)
# (b, 2N, h, w)
p = self._get_p(offset, dtype)
# (b, h, w, 2N)
p = p.contiguous().permute(0, 2, 3, 1)
q_lt = p.detach().floor()
q_rb = q_lt + 1
q_lt = torch.cat([torch.clamp(q_lt[..., :N], 0, x.size(2)-1), torch.clamp(q_lt[..., N:], 0, x.size(3)-1)], dim=-1).long()
q_rb = torch.cat([torch.clamp(q_rb[..., :N], 0, x.size(2)-1), torch.clamp(q_rb[..., N:], 0, x.size(3)-1)], dim=-1).long()
q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], dim=-1)
q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], dim=-1)
# clip p
p = torch.cat([torch.clamp(p[..., :N], 0, x.size(2)-1), torch.clamp(p[..., N:], 0, x.size(3)-1)], dim=-1)
# bilinear kernel (b, h, w, N)
g_lt = (1 + (q_lt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_lt[..., N:].type_as(p) - p[..., N:]))
g_rb = (1 - (q_rb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_rb[..., N:].type_as(p) - p[..., N:]))
g_lb = (1 + (q_lb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_lb[..., N:].type_as(p) - p[..., N:]))
g_rt = (1 - (q_rt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_rt[..., N:].type_as(p) - p[..., N:]))
# (b, c, h, w, N)
x_q_lt = self._get_x_q(x, q_lt, N)
x_q_rb = self._get_x_q(x, q_rb, N)
x_q_lb = self._get_x_q(x, q_lb, N)
x_q_rt = self._get_x_q(x, q_rt, N)
# (b, c, h, w, N)
x_offset = g_lt.unsqueeze(dim=1) * x_q_lt + \
g_rb.unsqueeze(dim=1) * x_q_rb + \
g_lb.unsqueeze(dim=1) * x_q_lb + \
g_rt.unsqueeze(dim=1) * x_q_rt
# modulation
if self.modulation:
m = m.contiguous().permute(0, 2, 3, 1)
m = m.unsqueeze(dim=1)
m = torch.cat([m for _ in range(x_offset.size(1))], dim=1)
x_offset *= m
x_offset = self._reshape_x_offset(x_offset, ks)
out = self.conv(x_offset)
return out
def _get_p_n(self, N, dtype):
p_n_x, p_n_y = torch.meshgrid(
torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1),
torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1))
# (2N, 1)
p_n = torch.cat([torch.flatten(p_n_x), torch.flatten(p_n_y)], 0)
p_n = p_n.view(1, 2*N, 1, 1).type(dtype)
return p_n
def _get_p_0(self, h, w, N, dtype):
p_0_x, p_0_y = torch.meshgrid(
torch.arange(1, h*self.stride+1, self.stride),
torch.arange(1, w*self.stride+1, self.stride))
p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1)
p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1)
p_0 = torch.cat([p_0_x, p_0_y], 1).type(dtype)
return p_0
def _get_p(self, offset, dtype):
N, h, w = offset.size(1)//2, offset.size(2), offset.size(3)
# (1, 2N, 1, 1)
p_n = self._get_p_n(N, dtype)
# (1, 2N, h, w)
p_0 = self._get_p_0(h, w, N, dtype)
p = p_0 + p_n + offset
return p
def _get_x_q(self, x, q, N):
b, h, w, _ = q.size()
padded_w = x.size(3)
c = x.size(1)
# (b, c, h*w)
x = x.contiguous().view(b, c, -1)
# (b, h, w, N)
index = q[..., :N]*padded_w + q[..., N:] # offset_x*w + offset_y
# (b, c, h*w*N)
index = index.contiguous().unsqueeze(dim=1).expand(-1, c, -1, -1, -1).contiguous().view(b, c, -1)
x_offset = x.gather(dim=-1, index=index).contiguous().view(b, c, h, w, N)
return x_offset
@staticmethod
def _reshape_x_offset(x_offset, ks):
b, c, h, w, N = x_offset.size()
x_offset = torch.cat([x_offset[..., s:s+ks].contiguous().view(b, c, h, w*ks) for s in range(0, N, ks)], dim=-1)
x_offset = x_offset.contiguous().view(b, c, h*ks, w*ks)
return x_offset
def NA_DCN(in_channels, kernel_size=3, stride=1, dilation=1, bias=True, group_channel=8, gn=False):
if gn:
return nn.Sequential(
nn.GroupNorm(int(max(1, in_channels / group_channel)), in_channels),
nn.ReLU(inplace=True),
# DeformConv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride, bias=bias),
DeformConvPack(in_channels, in_channels, kernel_size=3, stride=1, padding=1, groups=1, deformable_groups=1, bias=False, im2col_step=16)
)
else:
return nn.Sequential(
nn.BatchNorm2d(in_channels, momentum=0.1),
nn.ReLU(inplace=True),
# DeformConv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride, bias=bias),
DeformConvPack(in_channels, in_channels, kernel_size=3, stride=1, padding=1, groups=1, deformable_groups=1, bias=False, im2col_step=16)
)
class FPN4(nn.Module):
"""
FPN aligncorners downsample 4x"""
def __init__(self, base_channels, gn=False, dcn=False):
super(FPN4, self).__init__()
self.base_channels = base_channels
self.conv0 = nn.Sequential(
Conv2d(3, base_channels, 3, 1, padding=1, gn=gn),
Conv2d(base_channels, base_channels, 3, 1, padding=1, gn=gn),
)
self.conv1 = nn.Sequential(
Conv2d(base_channels, base_channels * 2, 5, stride=2, padding=2, gn=gn),
Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1, gn=gn),
Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1, gn=gn),
)
self.conv2 = nn.Sequential(
Conv2d(base_channels * 2, base_channels * 4, 5, stride=2, padding=2, gn=gn),
Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1, gn=gn),
Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1, gn=gn),
)
self.conv3 = nn.Sequential(
Conv2d(base_channels * 4, base_channels * 8, 5, stride=2, padding=2, gn=gn),
Conv2d(base_channels * 8, base_channels * 8, 3, 1, padding=1, gn=gn),
Conv2d(base_channels * 8, base_channels * 8, 3, 1, padding=1, gn=gn),
)
self.out_channels = [8 * base_channels]
final_chs = base_channels * 8
self.inner1 = nn.Conv2d(base_channels * 4, final_chs, 1, bias=True)
self.inner2 = nn.Conv2d(base_channels * 2, final_chs, 1, bias=True)
self.inner3 = nn.Conv2d(base_channels * 1, final_chs, 1, bias=True)
self.out1 = nn.Conv2d(final_chs, base_channels * 8, 1, bias=False)
self.out2 = nn.Conv2d(final_chs, base_channels * 4, 3, padding=1, bias=False)
self.out3 = nn.Conv2d(final_chs, base_channels * 2, 3, padding=1, bias=False)
self.out4 = nn.Conv2d(final_chs, base_channels, 3, padding=1, bias=False)
self.dcn = dcn
if self.dcn:
self.dcn1 = NA_DCN(base_channels * 8, 3, gn=gn)
self.dcn2 = NA_DCN(base_channels * 4, 3, gn=gn)
self.dcn3 = NA_DCN(base_channels * 2, 3, gn=gn)
self.dcn4 = NA_DCN(base_channels * 1, 3, gn=gn)
self.out_channels.append(base_channels * 4)
self.out_channels.append(base_channels * 2)
self.out_channels.append(base_channels)
def forward(self, x):
conv0 = self.conv0(x)
conv1 = self.conv1(conv0)
conv2 = self.conv2(conv1)
conv3 = self.conv3(conv2)
intra_feat = conv3
outputs = {}
out1 = self.out1(intra_feat)
intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner1(conv2)
out2 = self.out2(intra_feat)
intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner2(conv1)
out3 = self.out3(intra_feat)
intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner3(conv0)
out4 = self.out4(intra_feat)
if self.dcn:
out1 = self.dcn1(out1)
out2 = self.dcn2(out2)
out3 = self.dcn3(out3)
out4 = self.dcn4(out4)
outputs["stage1"] = out1
outputs["stage2"] = out2
outputs["stage3"] = out3
outputs["stage4"] = out4
return outputs
class LayerNorm(nn.Module):
def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
super().__init__()
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.bias = nn.Parameter(torch.zeros(normalized_shape))
self.eps = eps
self.data_format = data_format
if self.data_format not in ["channels_last", "channels_first"]:
raise NotImplementedError
self.normalized_shape = (normalized_shape, )
def forward(self, x):
if self.data_format == "channels_last":
return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
elif self.data_format == "channels_first":
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x
class convnext_block(nn.Module):
def __init__(self, dim, layer_scale_init_value=1e-6):
super().__init__()
self.dwconv = nn.Conv2d(dim, 2*dim, kernel_size=7, stride=2, padding=3, groups=dim) # depthwise conv
self.norm = LayerNorm(2*dim, eps=1e-6)
self.pwconv1 = nn.Linear(2*dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
self.act = nn.GELU()
self.pwconv2 = nn.Linear(4 * dim, 2*dim)
self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((2*dim)),
requires_grad=True) if layer_scale_init_value > 0 else None
def forward(self, x):
input = x
x = self.dwconv(x)
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
x = self.norm(x)
x = self.pwconv1(x)
x = self.act(x)
x = self.pwconv2(x)
if self.gamma is not None:
x = self.gamma * x
x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
# x = input + x
return x
class convnext4_block(nn.Module):
def __init__(self, dim, layer_scale_init_value=1e-6):
super().__init__()
self.sconv = nn.Conv2d(dim, 2*dim, kernel_size=2, stride=2, padding=0) # stride=2 conv
self.dwconv = nn.Conv2d(2*dim, 2*dim, kernel_size=7, stride=1, padding=3, groups=dim) # depthwise conv
self.norm = LayerNorm(2*dim, eps=1e-6)
self.pwconv1 = nn.Linear(2*dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
self.act = nn.GELU()
self.pwconv2 = nn.Linear(4 * dim, 2*dim)
self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((2*dim)),
requires_grad=True) if layer_scale_init_value > 0 else None
def forward(self, x):
input = self.sconv(x)
x = self.dwconv(input)
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
x = self.norm(x)
x = self.pwconv1(x)
x = self.act(x)
x = self.pwconv2(x)
if self.gamma is not None:
x = self.gamma * x
x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
x = input + x
return x
class FPN4_convnext(nn.Module):
"""
FPN aligncorners downsample 4x"""
def __init__(self, base_channels, gn=False, dcn=False):
super(FPN4_convnext, self).__init__()
self.base_channels = base_channels
self.conv0 = nn.Sequential(
Conv2d(3, base_channels, 3, 1, padding=1, gn=gn),
Conv2d(base_channels, base_channels, 3, 1, padding=1, gn=gn),
)
self.conv1 = convnext_block(base_channels)
self.conv2 = convnext_block(2*base_channels)
self.conv3 = convnext_block(4*base_channels)
self.out_channels = [8 * base_channels]
final_chs = base_channels * 8
self.inner1 = nn.Conv2d(base_channels * 4, final_chs, 1, bias=True)
self.inner2 = nn.Conv2d(base_channels * 2, final_chs, 1, bias=True)
self.inner3 = nn.Conv2d(base_channels * 1, final_chs, 1, bias=True)
self.out1 = nn.Conv2d(final_chs, base_channels * 8, 1, bias=False)
self.out2 = nn.Conv2d(final_chs, base_channels * 4, 3, padding=1, bias=False)
self.out3 = nn.Conv2d(final_chs, base_channels * 2, 3, padding=1, bias=False)
self.out4 = nn.Conv2d(final_chs, base_channels, 3, padding=1, bias=False)
self.dcn = dcn
if self.dcn:
self.dcn1 = NA_DCN(base_channels * 8, 3, gn=gn)
self.dcn2 = NA_DCN(base_channels * 4, 3, gn=gn)
self.dcn3 = NA_DCN(base_channels * 2, 3, gn=gn)
self.dcn4 = NA_DCN(base_channels * 1, 3, gn=gn)
self.out_channels.append(base_channels * 4)
self.out_channels.append(base_channels * 2)
self.out_channels.append(base_channels)
def forward(self, x):
conv0 = self.conv0(x)
conv1 = self.conv1(conv0)
conv2 = self.conv2(conv1)
conv3 = self.conv3(conv2)
intra_feat = conv3
outputs = {}
out1 = self.out1(intra_feat)
intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner1(conv2)
out2 = self.out2(intra_feat)
intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner2(conv1)
out3 = self.out3(intra_feat)
intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner3(conv0)
out4 = self.out4(intra_feat)
if self.dcn:
out1 = self.dcn1(out1)
out2 = self.dcn2(out2)
out3 = self.dcn3(out3)
out4 = self.dcn4(out4)
outputs["stage1"] = out1
outputs["stage2"] = out2
outputs["stage3"] = out3
outputs["stage4"] = out4
return outputs
class FPN4_convnext4(nn.Module):
"""
FPN aligncorners downsample 4x"""
def __init__(self, base_channels, gn=False, dcn=False):
super(FPN4_convnext4, self).__init__()
self.base_channels = base_channels
self.conv0 = nn.Sequential(
Conv2d(3, base_channels, 3, 1, padding=1, gn=gn),
Conv2d(base_channels, base_channels, 3, 1, padding=1, gn=gn),
)
self.conv1 = convnext4_block(base_channels)
self.conv2 = convnext4_block(2*base_channels)
self.conv3 = convnext4_block(4*base_channels)
self.out_channels = [8 * base_channels]
final_chs = base_channels * 8
self.inner1 = nn.Conv2d(base_channels * 4, final_chs, 1, bias=True)
self.inner2 = nn.Conv2d(base_channels * 2, final_chs, 1, bias=True)
self.inner3 = nn.Conv2d(base_channels * 1, final_chs, 1, bias=True)
self.out1 = nn.Conv2d(final_chs, base_channels * 8, 1, bias=False)
self.out2 = nn.Conv2d(final_chs, base_channels * 4, 3, padding=1, bias=False)
self.out3 = nn.Conv2d(final_chs, base_channels * 2, 3, padding=1, bias=False)
self.out4 = nn.Conv2d(final_chs, base_channels, 3, padding=1, bias=False)
self.dcn = dcn
if self.dcn:
self.dcn1 = NA_DCN(base_channels * 8, 3, gn=gn)
self.dcn2 = NA_DCN(base_channels * 4, 3, gn=gn)
self.dcn3 = NA_DCN(base_channels * 2, 3, gn=gn)
self.dcn4 = NA_DCN(base_channels * 1, 3, gn=gn)
self.out_channels.append(base_channels * 4)
self.out_channels.append(base_channels * 2)
self.out_channels.append(base_channels)
def forward(self, x):
conv0 = self.conv0(x)
conv1 = self.conv1(conv0)
conv2 = self.conv2(conv1)
conv3 = self.conv3(conv2)
intra_feat = conv3
outputs = {}
out1 = self.out1(intra_feat)
intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner1(conv2)
out2 = self.out2(intra_feat)
intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner2(conv1)
out3 = self.out3(intra_feat)
intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner3(conv0)
out4 = self.out4(intra_feat)
if self.dcn:
out1 = self.dcn1(out1)
out2 = self.dcn2(out2)
out3 = self.dcn3(out3)
out4 = self.dcn4(out4)
outputs["stage1"] = out1
outputs["stage2"] = out2
outputs["stage3"] = out3
outputs["stage4"] = out4
return outputs
class ASFF(nn.Module):
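# ASFF-style fusion (Adaptively Spatial Feature Fusion): each level resizes the
# other three pyramid levels to its own resolution/width, predicts per-pixel
# softmax weights over the four levels, and sums the re-weighted features.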
def __init__(self, level):
super(ASFF, self).__init__()
self.level = level
self.dim = [64,32,16,8]
self.inter_dim = self.dim[self.level]
if level==0:
self.stride_level_1 = Conv2d(32, 64, 3, stride=2, padding=1)
self.stride_level_2 = Conv2d(16, 64, 3, stride=2, padding=1)
self.stride_level_3 = Conv2d(8, 64, 3, stride=2, padding=1)
self.expand = Conv2d(64, 64, 3, stride=1, padding=1)
elif level==1:
self.compress_level_0 = Conv2d(64, 32, 1, stride=1, padding=0)
self.stride_level_2 = Conv2d(16, 32, 3, stride=2, padding=1)
self.stride_level_3 = Conv2d(8, 32, 3, stride=2, padding=1)
self.expand = Conv2d(32, 32, 3, stride=1, padding=1)
elif level==2:
self.compress_level_0 = Conv2d(64, 16, 1, stride=1, padding=0)
self.compress_level_1 = Conv2d(32, 16, 1, stride=1, padding=0)
self.stride_level_3 = Conv2d(8, 16, 3, stride=2, padding=1)
self.expand = Conv2d(16, 16, 3, stride=1, padding=1)
elif level==3:
self.compress_level_0 = Conv2d(64, 8, 1, stride=1, padding=0)
self.compress_level_1 = Conv2d(32, 8, 1, stride=1, padding=0)
self.compress_level_2 = Conv2d(16, 8, 1, stride=1, padding=0)
self.expand = Conv2d(8, 8, 3, stride=1, padding=1)
self.weight_level_0 = Conv2d(self.dim[level], 8, 1, 1, 0)
self.weight_level_1 = Conv2d(self.dim[level], 8, 1, 1, 0)
self.weight_level_2 = Conv2d(self.dim[level], 8, 1, 1, 0)
self.weight_level_3 = Conv2d(self.dim[level], 8, 1, 1, 0)
self.weight_levels = nn.Conv2d(32, 4, kernel_size=1, stride=1, padding=0)
def forward(self, x_level_0, x_level_1, x_level_2, x_level_3):
if self.level==0:
level_0_resized = x_level_0
level_1_resized = self.stride_level_1(x_level_1)
level_2_downsampled_inter = F.max_pool2d(x_level_2, 2, stride=2, padding=0)
level_2_resized = self.stride_level_2(level_2_downsampled_inter)
level_3_downsampled_inter = F.max_pool2d(x_level_3, 4, stride=4, padding=0)
level_3_resized = self.stride_level_3(level_3_downsampled_inter)
elif self.level==1:
level_0_compressed = self.compress_level_0(x_level_0)
level_0_resized = F.interpolate(level_0_compressed, scale_factor=2, mode='nearest')
level_1_resized = x_level_1
level_2_resized = self.stride_level_2(x_level_2)
level_3_downsampled_inter = F.max_pool2d(x_level_3, 2, stride=2, padding=0)
level_3_resized = self.stride_level_3(level_3_downsampled_inter)
elif self.level==2:
level_0_compressed = self.compress_level_0(x_level_0)
level_0_resized = F.interpolate(level_0_compressed, scale_factor=4, mode='nearest')
level_1_compressed = self.compress_level_1(x_level_1)
level_1_resized = F.interpolate(level_1_compressed, scale_factor=2, mode='nearest')
level_2_resized = x_level_2
level_3_resized = self.stride_level_3(x_level_3)
elif self.level==3:
level_0_compressed = self.compress_level_0(x_level_0)
level_0_resized = F.interpolate(level_0_compressed, scale_factor=8, mode='nearest')
level_1_compressed = self.compress_level_1(x_level_1)
level_1_resized = F.interpolate(level_1_compressed, scale_factor=4, mode='nearest')
level_2_compressed = self.compress_level_2(x_level_2)
level_2_resized = F.interpolate(level_2_compressed, scale_factor=2, mode='nearest')
level_3_resized = x_level_3
level_0_weight_v = self.weight_level_0(level_0_resized)
level_1_weight_v = self.weight_level_1(level_1_resized)
level_2_weight_v = self.weight_level_2(level_2_resized)
level_3_weight_v = self.weight_level_3(level_3_resized)
levels_weight_v = torch.cat((level_0_weight_v, level_1_weight_v, level_2_weight_v, level_3_weight_v),1)
levels_weight = self.weight_levels(levels_weight_v)
levels_weight = F.softmax(levels_weight, dim=1)
fused_out_reduced = level_0_resized * levels_weight[:,0:1,:,:]+\
level_1_resized * levels_weight[:,1:2,:,:]+\
level_2_resized * levels_weight[:,2:3,:,:]+\
level_3_resized * levels_weight[:,3:,:,:]
out = self.expand(fused_out_reduced)
return out
class FullImageEncoder(nn.Module):
def __init__(self, h, w, kernel_size):
super(FullImageEncoder, self).__init__()
self.global_pooling = nn.AvgPool2d(kernel_size, stride=kernel_size, padding=kernel_size // 2) # KITTI 16 16
self.dropout = nn.Dropout2d(p=0.5)
self.h = h // kernel_size + 1
self.w = w // kernel_size + 1
# print("h=", self.h, " w=", self.w, h, w)
self.global_fc = nn.Linear(2048 * self.h * self.w, 512) # kitti 4x5
self.relu = nn.ReLU(inplace=True)
self.conv1 = nn.Conv2d(512, 512, 1) # 1x1 convolution
def forward(self, x):
# print('x size:', x.size())
x1 = self.global_pooling(x)
# print('# x1 size:', x1.size())
x2 = self.dropout(x1)
x3 = x2.view(-1, 2048 * self.h * self.w) # kitti 4x5
x4 = self.relu(self.global_fc(x3))
# print('# x4 size:', x4.size())
x4 = x4.view(-1, 512, 1, 1)
# print('# x4 size:', x4.size())
x5 = self.conv1(x4)
# out = self.upsample(x5)
return x5
class mono_depth_decoder(nn.Module):
def __init__(self):
super(mono_depth_decoder, self).__init__()
self.convblocks = nn.ModuleList(
[Conv2d(64, 32, 3, 1, padding=1),
Conv2d(32, 16, 3, 1, padding=1),
Conv2d(16, 8, 3, 1, padding=1)]
)
self.conv3x3 = nn.ModuleList(
[nn.Conv2d(64, 1, 3, 1, 1),
nn.Conv2d(32, 1, 3, 1, 1),
nn.Conv2d(16, 1, 3, 1, 1)]
)
self.sigmoid = nn.Sigmoid()
def forward(self, outputs, d_min, d_max):
"""
d_max: B
"""
for i in range(1,4): # 1 2 3
mono_small_feat = outputs['stage{}'.format(i)]['mono_feat']
mono_large_feat = outputs['stage{}'.format(i+1)]['mono_feat']
mono_small_feat = self.convblocks[i-1](mono_small_feat)
mono_small_feat = F.interpolate(mono_small_feat, scale_factor=2, mode="nearest")
mono_feat = self.conv3x3[i-1](torch.cat([mono_small_feat, mono_large_feat], 1)) # B C H W
disp = self.sigmoid(mono_feat)
min_disp = (1 / d_max)[:,None,None,None] # B 1 1 1
max_disp = (1 / d_min)[:,None,None,None]
scaled_disp = min_disp + (max_disp - min_disp) * disp
depth = 1 / scaled_disp
outputs['stage{}'.format(i+1)]['mono_depth'] = depth.squeeze(1)
return outputs
class reg2d(nn.Module):
def __init__(self, input_channel=128, base_channel=32, conv_name='ConvBnReLU3D'):
super(reg2d, self).__init__()
module = importlib.import_module("models.mvs4net_utils")
stride_conv_name = 'ConvBnReLU3D'
self.conv0 = getattr(module, stride_conv_name)(input_channel, base_channel, kernel_size=(1,3,3), pad=(0,1,1))
self.conv1 = getattr(module, stride_conv_name)(base_channel, base_channel*2, kernel_size=(1,3,3), stride=(1,2,2), pad=(0,1,1))
self.conv2 = getattr(module, conv_name)(base_channel*2, base_channel*2)
self.conv3 = getattr(module, stride_conv_name)(base_channel*2, base_channel*4, kernel_size=(1,3,3), stride=(1,2,2), pad=(0,1,1))
self.conv4 = getattr(module, conv_name)(base_channel*4, base_channel*4)
self.conv5 = getattr(module, stride_conv_name)(base_channel*4, base_channel*8, kernel_size=(1,3,3), stride=(1,2,2), pad=(0,1,1))
self.conv6 = getattr(module, conv_name)(base_channel*8, base_channel*8)
self.conv7 = nn.Sequential(
nn.ConvTranspose3d(base_channel*8, base_channel*4, kernel_size=(1,3,3), padding=(0,1,1), output_padding=(0,1,1), stride=(1,2,2), bias=False),
nn.BatchNorm3d(base_channel*4),
nn.ReLU(inplace=True))
self.conv9 = nn.Sequential(
nn.ConvTranspose3d(base_channel*4, base_channel*2, kernel_size=(1,3,3), padding=(0,1,1), output_padding=(0,1,1), stride=(1,2,2), bias=False),
nn.BatchNorm3d(base_channel*2),
nn.ReLU(inplace=True))
self.conv11 = nn.Sequential(
nn.ConvTranspose3d(base_channel*2, base_channel, kernel_size=(1,3,3), padding=(0,1,1), output_padding=(0,1,1), stride=(1,2,2), bias=False),
nn.BatchNorm3d(base_channel),
nn.ReLU(inplace=True))
self.prob = nn.Conv3d(base_channel, 1, 1, stride=1, padding=0)  # conv11 outputs base_channel maps (the hardcoded 8 only matched base_channel=8)
def forward(self, x):
conv0 = self.conv0(x)
conv2 = self.conv2(self.conv1(conv0))
conv4 = self.conv4(self.conv3(conv2))
x = self.conv6(self.conv5(conv4))
x = conv4 + self.conv7(x)
x = conv2 + self.conv9(x)
x = conv0 + self.conv11(x)
x = self.prob(x)
return x.squeeze(1)
class reg3d(nn.Module):
def __init__(self, in_channels, base_channels, down_size=3):
super(reg3d, self).__init__()
self.down_size = down_size
self.conv0 = ConvBnReLU3D(in_channels, base_channels, kernel_size=3, pad=1)
self.conv1 = ConvBnReLU3D(base_channels, base_channels*2, kernel_size=3, stride=2, pad=1)
self.conv2 = ConvBnReLU3D(base_channels*2, base_channels*2)
if down_size >= 2:
self.conv3 = ConvBnReLU3D(base_channels*2, base_channels*4, kernel_size=3, stride=2, pad=1)
self.conv4 = ConvBnReLU3D(base_channels*4, base_channels*4)
if down_size >= 3:
self.conv5 = ConvBnReLU3D(base_channels*4, base_channels*8, kernel_size=3, stride=2, pad=1)
self.conv6 = ConvBnReLU3D(base_channels*8, base_channels*8)
self.conv7 = nn.Sequential(
nn.ConvTranspose3d(base_channels*8, base_channels*4, kernel_size=3, padding=1, output_padding=1, stride=2, bias=False),
nn.BatchNorm3d(base_channels*4),
nn.ReLU(inplace=True))
if down_size >= 2:
self.conv9 = nn.Sequential(
nn.ConvTranspose3d(base_channels*4, base_channels*2, kernel_size=3, padding=1, output_padding=1, stride=2, bias=False),
nn.BatchNorm3d(base_channels*2),
nn.ReLU(inplace=True))
self.conv11 = nn.Sequential(
nn.ConvTranspose3d(base_channels*2, base_channels, kernel_size=3, padding=1, output_padding=1, stride=2, bias=False),
nn.BatchNorm3d(base_channels),
nn.ReLU(inplace=True))
self.prob = nn.Conv3d(base_channels, 1, 3, stride=1, padding=1, bias=False)
def forward(self, x):
if self.down_size==3:
conv0 = self.conv0(x)
conv2 = self.conv2(self.conv1(conv0))
conv4 = self.conv4(self.conv3(conv2))
x = self.conv6(self.conv5(conv4))
x = conv4 + self.conv7(x)
x = conv2 + self.conv9(x)
x = conv0 + self.conv11(x)
x = self.prob(x)
elif self.down_size==2:
conv0 = self.conv0(x)
conv2 = self.conv2(self.conv1(conv0))
x = self.conv4(self.conv3(conv2))
x = conv2 + self.conv9(x)
x = conv0 + self.conv11(x)
x = self.prob(x)
else:
conv0 = self.conv0(x)
x = self.conv2(self.conv1(conv0))
x = conv0 + self.conv11(x)
x = self.prob(x)
return x.squeeze(1) # B D H W
class PosEncSine(nn.Module):
def __init__(self, temperature=1000):
super(PosEncSine, self).__init__()
self.temperature = temperature
def forward(self, x, depth):
# depth : B D H W
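# Sinusoidal positional encoding indexed by metric depth: the first C/2
# channels are sin(i*pi*depth/temperature), the last C/2 the matching cosines,
# added to the volume like Transformer position embeddings.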
with torch.no_grad():
B,C,D,H,W = x.shape
depth = depth.permute(0,2,3,1).reshape(B*H*W, D) / self.temperature # BHW D
pos = torch.stack([torch.sin(i * math.pi * depth) for i in range(C//2)] + [torch.cos(i * math.pi * depth) for i in range(C//2)], dim=-1) # BHW,D,C
pos = pos.reshape(B,H,W,D,C).permute(0,4,3,1,2) # B C D H W
x = x + pos
return x
class PosEncLearned(nn.Module):
"""
Absolute pos embedding, learned.
"""
def __init__(self, D, C):
super().__init__()
self.D = D
self.C = C
self.depth_embed = nn.Parameter(torch.Tensor(C, self.D))
self.reset_parameters()
def reset_parameters(self):
nn.init.uniform_(self.depth_embed)
def forward(self, x, **kwargs):
B,C,D,H,W = x.shape
pos = self.depth_embed[None,:,:,None,None].repeat(B,1,1,H,W) # B C D H W
x = x + pos
return x
class stagenet(nn.Module):
def __init__(self, inverse_depth=False, mono=False, attn_fuse_d=True, vis_ETA=False, attn_temp=1):
super(stagenet, self).__init__()
self.inverse_depth = inverse_depth
self.mono = mono
self.attn_fuse_d = attn_fuse_d
self.vis_ETA = vis_ETA
self.attn_temp = attn_temp
def forward(self, features, proj_matrices, depth_hypo, regnet, stage_idx, group_cor=False, group_cor_dim=8, split_itv=1, fn=None):
# step 1. feature extraction
proj_matrices = torch.unbind(proj_matrices, 1)
ref_feature, src_features = features[0], features[1:]
ref_proj, src_projs = proj_matrices[0], proj_matrices[1:]
B,D,H,W = depth_hypo.shape
C = ref_feature.shape[1]
ref_volume = ref_feature.unsqueeze(2).repeat(1, 1, D, 1, 1)
cor_weight_sum = 1e-8
cor_feats = 0
# step 2. Epipolar Transformer Aggregation
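# Each source view is warped onto the reference hypotheses and scored by a
# softmax over depth of the summed (group) correlation; with attn_fuse_d the
# score is divided by sqrt(C), i.e. scaled-dot-product-style cross-attention
# along the epipolar line.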
for src_idx, (src_fea, src_proj) in enumerate(zip(src_features, src_projs)):
if self.vis_ETA:
scan_name = fn[0].split('/')[0]
image_name = fn[0].split('/')[2][:-2]
save_fn = './debug_figs/vis_ETA/{}_stage{}_src{}'.format(scan_name+'_'+image_name, stage_idx, src_idx)
else:
save_fn = None
src_proj_new = src_proj[:, 0].clone()
src_proj_new[:, :3, :4] = torch.matmul(src_proj[:, 1, :3, :3], src_proj[:, 0, :3, :4])
ref_proj_new = ref_proj[:, 0].clone()
ref_proj_new[:, :3, :4] = torch.matmul(ref_proj[:, 1, :3, :3], ref_proj[:, 0, :3, :4])
warped_src = homo_warping(src_fea, src_proj_new, ref_proj_new, depth_hypo, self.vis_ETA, save_fn) # B C D H W
if group_cor:
warped_src = warped_src.reshape(B, group_cor_dim, C//group_cor_dim, D, H, W)
ref_volume = ref_volume.reshape(B, group_cor_dim, C//group_cor_dim, D, H, W)
cor_feat = (warped_src * ref_volume).mean(2) # B G D H W
else:
cor_feat = (ref_volume - warped_src)**2 # B C D H W
del warped_src, src_proj, src_fea
if self.vis_ETA:
vis_weight = torch.softmax(cor_feat.sum(1), 1).detach().cpu().numpy()
np.save(save_fn, vis_weight)
if not self.attn_fuse_d:
cor_weight = torch.softmax(cor_feat.sum(1), 1).max(1)[0] # B H W
cor_weight_sum += cor_weight # B H W
cor_feats += cor_weight.unsqueeze(1).unsqueeze(1) * cor_feat # B C D H W
else:
cor_weight = torch.softmax(cor_feat.sum(1) / self.attn_temp, 1) / math.sqrt(C) # B D H W
cor_weight_sum += cor_weight # B D H W
cor_feats += cor_weight.unsqueeze(1) * cor_feat # B C D H W
del cor_weight, cor_feat
if not self.attn_fuse_d:
cor_feats = cor_feats / cor_weight_sum.unsqueeze(1).unsqueeze(1) # B C D H W
else:
cor_feats = cor_feats / cor_weight_sum.unsqueeze(1) # B C D H W
del cor_weight_sum, src_features
# step 3. regularization
attn_weight = regnet(cor_feats) # B D H W
del cor_feats
attn_weight = F.softmax(attn_weight, dim=1) # B D H W
# step 4. depth argmax
attn_max_indices = attn_weight.max(1, keepdim=True)[1] # B 1 H W
depth = torch.gather(depth_hypo, 1, attn_max_indices).squeeze(1) # B H W
if not self.training:
with torch.no_grad():
photometric_confidence = attn_weight.max(1)[0] # B H W
photometric_confidence = F.interpolate(photometric_confidence.unsqueeze(1), scale_factor=2**(3-stage_idx), mode='bilinear', align_corners=True).squeeze(1)
else:
photometric_confidence = torch.tensor(0.0, dtype=torch.float32, device=ref_feature.device, requires_grad=False)
ret_dict = {"depth": depth, "photometric_confidence": photometric_confidence, "hypo_depth": depth_hypo, "attn_weight": attn_weight}
if self.inverse_depth:
last_depth_itv = 1./depth_hypo[:,2,:,:] - 1./depth_hypo[:,1,:,:]
inverse_min_depth = 1/depth + split_itv * last_depth_itv # B H W
inverse_max_depth = 1/depth - split_itv * last_depth_itv # B H W
ret_dict['inverse_min_depth'] = inverse_min_depth
ret_dict['inverse_max_depth'] = inverse_max_depth
# if self.mono and self.training:
if self.mono:
ret_dict['mono_feat'] = ref_feature # B C H W
return ret_dict
def sinkhorn(gt_depth, hypo_depth, attn_weight, mask, iters, eps=1, continuous=False):
"""
gt_depth: B H W
hypo_depth: B D H W
attn_weight: B D H W
mask: B H W
"""
B,D,H,W = attn_weight.shape
if not continuous:
D_map = torch.stack([torch.arange(-i,D-i,1, dtype=torch.float32, device=gt_depth.device) for i in range(D)], dim=1).abs()
D_map = D_map[None,None,:,:].repeat(B,H*W,1,1) # B HW D D
gt_indices = torch.abs(hypo_depth - gt_depth[:,None,:,:]).min(1)[1].squeeze(1).reshape(B*H*W, 1) # BHW, 1
gt_dist = torch.zeros_like(hypo_depth).permute(0,2,3,1).reshape(B*H*W, D)
gt_dist.scatter_add_(1,gt_indices,torch.ones([gt_dist.shape[0],1], dtype=gt_dist.dtype, device=gt_dist.device))
gt_dist = gt_dist.reshape(B,H*W,D) # B HW D
else:
gt_dist = torch.zeros((B,H*W,D+1), dtype=torch.float32, device=gt_depth.device, requires_grad=False) # B HW D+1
gt_dist[:,:,-1] = 1
D_map = torch.zeros((B,D,D+1), dtype=torch.float32, device=gt_depth.device, requires_grad=False) # B D D+1
D_map[:, :D, :D] = torch.stack([torch.arange(-i,D-i,1, dtype=torch.float32, device=gt_depth.device) for i in range(D)], dim=1).abs().unsqueeze(0) # B D D+1
D_map = D_map[:,None,None,:,:].repeat(1,H,W,1,1) # B H W D D+1
itv = 1/hypo_depth[:,2,:,:] - 1/hypo_depth[:,1,:,:] # B H W
gt_bin_distance_ = (1/gt_depth - 1/hypo_depth[:,0,:,:]) / itv # B H W
#FIXME hard-coded fill value for masked-out pixels
gt_bin_distance_[~mask] = 10
gt_bin_distance = torch.stack([(gt_bin_distance_ - i).abs() for i in range(D)], dim=1).permute(0,2,3,1) # B H W D
D_map[:,:,:,:,-1] = gt_bin_distance
D_map = D_map.reshape(B,H*W,D,1+D) # B HW D D+1
pred_dist = attn_weight.permute(0,2,3,1).reshape(B,H*W,D) # B HW D
# map to log space for stability
log_mu = (gt_dist+1e-12).log()
log_nu = (pred_dist+1e-12).log() # B HW D or D+1
u, v = torch.zeros_like(log_nu), torch.zeros_like(log_mu)
for _ in range(iters):
# scale v first then u to ensure row sum is 1, col sum slightly larger than 1
v = log_mu - torch.logsumexp(D_map/eps + u.unsqueeze(3), dim=2) # log(sum(exp()))
u = log_nu - torch.logsumexp(D_map/eps + v.unsqueeze(2), dim=3)
# convert back from log space to recover the transport plan
T_map = (D_map/eps + u.unsqueeze(3) + v.unsqueeze(2)).exp() # B HW D D
loss = (T_map * D_map).reshape(B*H*W,-1)[mask.reshape(-1)].sum(-1).mean()
return T_map, loss
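# Minimal usage sketch (shapes are assumptions consistent with the docstring):
# gt = torch.rand(2, 32, 40)                     # B H W ground-truth depth
# hypo = torch.rand(2, 8, 32, 40).sort(1)[0]     # B D H W hypotheses
# attn = torch.softmax(torch.rand(2, 8, 32, 40), 1)
# mask = torch.ones(2, 32, 40, dtype=torch.bool)
# T_map, ot_loss = sinkhorn(gt, hypo, attn, mask, iters=4, eps=1)
# The iterations run in log space (u, v are log scaling vectors) for numerical
# stability; the returned loss is the plan-weighted cost averaged over masked
# pixels.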
================================================
FILE: requirements.txt
================================================
torch==1.9.0
torchvision==0.10.0
numpy
pillow
tensorboardX
opencv-python
plyfile
================================================
FILE: scripts/test_dtu.sh
================================================
#!/usr/bin/env bash
DTU_TESTPATH="/mnt/cfs/algorithm/public_data/mvs/dtu_test"
DTU_TESTLIST="lists/dtu/test.txt"
DTU_size=$1
exp=$2
PY_ARGS=${@:3}
DTU_LOG_DIR="./checkpoints/dtu/"$exp
if [ ! -d $DTU_LOG_DIR ]; then
mkdir -p $DTU_LOG_DIR
fi
DTU_CKPT_FILE=$DTU_LOG_DIR"/finalmodel.ckpt"
DTU_OUT_DIR="./outputs/dtu/"$exp
if [ "$DTU_size" = "raw" ] ; then
python test_mvs4.py --dataset=general_eval4 --batch_size=1 --testpath=$DTU_TESTPATH --testlist=$DTU_TESTLIST --loadckpt $DTU_CKPT_FILE --interval_scale 1.06 --outdir $DTU_OUT_DIR\
--use_raw_train --thres_view 4 --conf 0.5 --group_cor --attn_temp 2 --inverse_depth $PY_ARGS | tee -a $DTU_LOG_DIR/log_test.txt
else
python test
SYMBOL INDEX (265 symbols across 13 files)
FILE: datasets/__init__.py
function find_dataset_def (line 5) | def find_dataset_def(dataset_name):
FILE: datasets/blendedmvs.py
function check_invalid_input (line 11) | def check_invalid_input(imgs, depths, masks, depth_mins, depth_maxs):
class MVSDataset (line 26) | class MVSDataset(Dataset):
method __init__ (line 27) | def __init__(self, datapath, listfile, split, nviews, img_wh=(768, 576...
method build_metas (line 49) | def build_metas(self):
method read_cam_file (line 62) | def read_cam_file(self, scan, filename):
method read_depth_mask (line 81) | def read_depth_mask(self, scan, filename, depth_min, depth_max, scale):
method read_img (line 108) | def read_img(self, filename):
method __len__ (line 115) | def __len__(self):
method __getitem__ (line 118) | def __getitem__(self, idx):
FILE: datasets/data_io.py
function read_pfm (line 6) | def read_pfm(filename):
function save_pfm (line 44) | def save_pfm(filename, image, scale=1):
class RandomCrop (line 75) | class RandomCrop(object):
method __init__ (line 76) | def __init__(self, CropSize=0.1):
method __call__ (line 79) | def __call__(self, image, normal):
FILE: datasets/dtu_yao4.py
class MVSDataset (line 9) | class MVSDataset(Dataset):
method __init__ (line 10) | def __init__(self, datapath, listfile, mode, nviews, interval_scale=1....
method build_list (line 26) | def build_list(self):
method __len__ (line 48) | def __len__(self):
method read_cam_file (line 51) | def read_cam_file(self, filename):
method read_img (line 64) | def read_img(self, filename):
method crop_img (line 72) | def crop_img(self, img):
method prepare_img (line 78) | def prepare_img(self, hr_img):
method read_mask_hr (line 92) | def read_mask_hr(self, filename):
method read_depth_hr (line 108) | def read_depth_hr(self, filename, scale):
method __getitem__ (line 123) | def __getitem__(self, idx):
FILE: datasets/eth3d.py
class MVSDataset (line 8) | class MVSDataset(Dataset):
method __init__ (line 9) | def __init__(self, datapath, split='test', n_views=7, img_wh=(1920,128...
method build_metas (line 17) | def build_metas(self):
method read_cam_file (line 40) | def read_cam_file(self, filename):
method read_img (line 57) | def read_img(self, filename):
method __len__ (line 64) | def __len__(self):
method __getitem__ (line 67) | def __getitem__(self, idx):
FILE: datasets/general_eval4.py
class MVSDataset (line 8) | class MVSDataset(Dataset):
method __init__ (line 9) | def __init__(self, datapath, listfile, mode, nviews, interval_scale=1....
method build_list (line 24) | def build_list(self):
method __len__ (line 56) | def __len__(self):
method read_cam_file (line 59) | def read_cam_file(self, filename, interval_scale):
method read_img (line 81) | def read_img(self, filename):
method read_depth (line 88) | def read_depth(self, filename):
method scale_mvs_input (line 92) | def scale_mvs_input(self, img, intrinsics, max_w, max_h, base=64):
method __getitem__ (line 111) | def __getitem__(self, idx):
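`scale_mvs_input` resizes an evaluation image and its intrinsics together, with `base=64` so the result divides evenly by the network's total downsampling stride. A sketch of the usual recipe; the exact rounding policy is an assumption:

```python
import cv2

def scale_mvs_input_sketch(img, intrinsics, max_w, max_h, base=64):
    # Cap the image at (max_h, max_w), snap both sides down to multiples
    # of `base`, and rescale the intrinsics rows by the same per-axis factors.
    h, w = img.shape[:2]
    scale = 1.0
    if h > max_h or w > max_w:
        scale = min(1.0 * max_h / h, 1.0 * max_w / w)
    new_w, new_h = int(scale * w // base) * base, int(scale * h // base) * base
    intrinsics = intrinsics.copy()
    intrinsics[0, :] *= 1.0 * new_w / w   # fx, cx follow the width
    intrinsics[1, :] *= 1.0 * new_h / h   # fy, cy follow the height
    return cv2.resize(img, (new_w, new_h)), intrinsics
```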
FILE: datasets/tanks.py
class MVSDataset (line 8) | class MVSDataset(Dataset):
method __init__ (line 9) | def __init__(self, datapath, n_views=7, split='intermediate'):
method build_metas (line 16) | def build_metas(self):
method read_cam_file (line 33) | def read_cam_file(self, filename):
method read_img (line 48) | def read_img(self, filename):
method scale_input (line 53) | def scale_input(self, intrinsics, img):
method __len__ (line 62) | def __len__(self):
method __getitem__ (line 65) | def __getitem__(self, idx):
FILE: models/MVS4Net.py
class MVS4net (line 9) | class MVS4net(nn.Module):
method __init__ (line 10) | def __init__(self, arch_mode="fpn", reg_net='reg2d', num_stage=4, fpn_...
method forward (line 60) | def forward(self, imgs, proj_matrices, depth_values, filename=None):
function MVS4net_loss (line 113) | def MVS4net_loss(inputs, depth_gt_ms, mask_ms, **kwargs):
function Blend_loss (line 158) | def Blend_loss(inputs, depth_gt_ms, mask_ms, **kwargs):
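`MVS4net_loss` supervises all cascade stages at once. Its optimal-transport and monocular terms cannot be recovered from this index, but the backbone of such a loss is a masked, stage-weighted sum; everything below (weights, key names, the smooth-L1 choice) is illustrative only:

```python
import torch.nn.functional as F

def cascade_loss_sketch(stage_outputs, depth_gt_ms, mask_ms,
                        stage_weights=(0.5, 1.0, 1.5, 2.0)):
    # stage_outputs: list of per-stage dicts holding a 'depth' map;
    # depth_gt_ms / mask_ms: dicts of ground truth at each stage resolution.
    total = 0.0
    for i, (out, w) in enumerate(zip(stage_outputs, stage_weights)):
        depth_gt = depth_gt_ms['stage{}'.format(i + 1)]
        mask = mask_ms['stage{}'.format(i + 1)] > 0.5
        if mask.any():
            total = total + w * F.smooth_l1_loss(out['depth'][mask], depth_gt[mask])
    return total
```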
FILE: models/module.py
function init_bn (line 14) | def init_bn(module):
function init_uniform (line 22) | def init_uniform(module, init_method):
class Conv2d (line 30) | class Conv2d(nn.Module):
method __init__ (line 44) | def __init__(self, in_channels, out_channels, kernel_size, stride=1,
method forward (line 58) | def forward(self, x):
method init_weights (line 66) | def init_weights(self, init_method):
class DCNConv2d (line 72) | class DCNConv2d(nn.Module):
method __init__ (line 74) | def __init__(self, in_channels, out_channels, kernel_size, stride=1,
method forward (line 87) | def forward(self, x):
method init_weights (line 95) | def init_weights(self, init_method):
class Deconv2d (line 101) | class Deconv2d(nn.Module):
method __init__ (line 115) | def __init__(self, in_channels, out_channels, kernel_size, stride=1,
method forward (line 130) | def forward(self, x):
method init_weights (line 141) | def init_weights(self, init_method):
class Conv3d (line 147) | class Conv3d(nn.Module):
method __init__ (line 161) | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
method forward (line 177) | def forward(self, x):
method init_weights (line 185) | def init_weights(self, init_method):
class PConv3d (line 191) | class PConv3d(nn.Module):
method __init__ (line 193) | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
method forward (line 213) | def forward(self, x):
method init_weights (line 222) | def init_weights(self, init_method):
class Deconv3d (line 230) | class Deconv3d(nn.Module):
method __init__ (line 244) | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
method forward (line 259) | def forward(self, x):
method init_weights (line 267) | def init_weights(self, init_method):
class PDeconv3d (line 274) | class PDeconv3d(nn.Module):
method __init__ (line 276) | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,...
method forward (line 298) | def forward(self, x):
method init_weights (line 307) | def init_weights(self, init_method):
class ConvBnReLU (line 314) | class ConvBnReLU(nn.Module):
method __init__ (line 315) | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,...
method forward (line 320) | def forward(self, x):
class ConvBn (line 323) | class ConvBn(nn.Module):
method __init__ (line 324) | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,...
method forward (line 329) | def forward(self, x):
class ConvBnReLU3D (line 332) | class ConvBnReLU3D(nn.Module):
method __init__ (line 333) | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,...
method forward (line 338) | def forward(self, x):
class ConvBn3D (line 342) | class ConvBn3D(nn.Module):
method __init__ (line 343) | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,...
method forward (line 348) | def forward(self, x):
class BasicBlock (line 352) | class BasicBlock(nn.Module):
method __init__ (line 353) | def __init__(self, in_channels, out_channels, stride, downsample=None):
method forward (line 362) | def forward(self, x):
class Hourglass3d (line 371) | class Hourglass3d(nn.Module):
method __init__ (line 372) | def __init__(self, channels):
method forward (line 394) | def forward(self, x):
function homo_warping (line 402) | def homo_warping(src_fea, src_proj, ref_proj, depth_values, align_corner... (sketched after this file's index)
class DeConv2dFuse (line 437) | class DeConv2dFuse(nn.Module):
method __init__ (line 438) | def __init__(self, in_channels, out_channels, kernel_size, relu=True, ...
method forward (line 451) | def forward(self, x_pre, x):
class FeatureNet (line 458) | class FeatureNet(nn.Module):
method __init__ (line 459) | def __init__(self, base_channels, num_stage=3, stride=4, arch_mode="un...
method forward (line 520) | def forward(self, x):
class FPNDCNpath (line 561) | class FPNDCNpath(nn.Module):
method __init__ (line 564) | def __init__(self, base_channels, stride=4):
method forward (line 612) | def forward(self, x):
class FPNDCN (line 635) | class FPNDCN(nn.Module):
method __init__ (line 638) | def __init__(self, base_channels, stride=4):
method forward (line 684) | def forward(self, x):
class FPNA (line 705) | class FPNA(nn.Module):
method __init__ (line 708) | def __init__(self, base_channels, stride=4):
method forward (line 743) | def forward(self, x):
class FPNA4 (line 764) | class FPNA4(nn.Module):
method __init__ (line 767) | def __init__(self, base_channels):
method forward (line 810) | def forward(self, x):
class CostRegNet (line 836) | class CostRegNet(nn.Module):
method __init__ (line 837) | def __init__(self, in_channels, base_channels, down_size=3):
method forward (line 860) | def forward(self, x):
class P3DConv (line 884) | class P3DConv(nn.Module):
method __init__ (line 888) | def __init__(self, in_channels, base_channels):
method forward (line 909) | def forward(self, x):
class RefineNet (line 920) | class RefineNet(nn.Module):
method __init__ (line 921) | def __init__(self):
method forward (line 928) | def forward(self, img, depth_init):
function depth_regression (line 935) | def depth_regression(p, depth_values): (sketched after this file's index)
function cas_mvsnet_loss (line 943) | def cas_mvsnet_loss(inputs, depth_gt_ms, mask_ms, **kwargs):
function cas_mvsnet_T_loss (line 964) | def cas_mvsnet_T_loss(inputs, depth_gt_ms, mask_ms, **kwargs):
function get_cur_depth_range_samples (line 1138) | def get_cur_depth_range_samples(cur_depth, ndepth, depth_inteval_pixel, ...
function get_depth_range_samples (line 1157) | def get_depth_range_samples(cur_depth, ndepth, depth_inteval_pixel, devi...
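`homo_warping` (line 402, flagged above) is the geometric core of the pipeline: it warps source-view features into the reference view for every depth hypothesis, and `depth_regression` (line 935) collapses the resulting probability volume into a depth map via soft-argmin. Both follow the standard MVSNet formulation; below is a self-contained sketch, assuming the usual convention that intrinsics are pre-composed into the 4x4 projection matrices:

```python
import torch
import torch.nn.functional as F

def homo_warping_sketch(src_fea, src_proj, ref_proj, depth_values):
    # src_fea: [B,C,H,W]; src_proj/ref_proj: [B,4,4];
    # depth_values: [B,D] or [B,D,H,W]. Returns [B,C,D,H,W].
    batch, channels, height, width = src_fea.shape
    num_depth = depth_values.shape[1]
    with torch.no_grad():
        proj = torch.matmul(src_proj, torch.inverse(ref_proj))
        rot, trans = proj[:, :3, :3], proj[:, :3, 3:4]
        # Pixel grid of the reference view, in homogeneous coordinates.
        y = torch.arange(height, dtype=torch.float32, device=src_fea.device)
        x = torch.arange(width, dtype=torch.float32, device=src_fea.device)
        y = y.view(height, 1).expand(height, width).reshape(-1)
        x = x.view(1, width).expand(height, width).reshape(-1)
        xyz = torch.stack((x, y, torch.ones_like(x)))                  # [3,H*W]
        xyz = xyz.unsqueeze(0).expand(batch, -1, -1)                   # [B,3,H*W]
        # Back-project at each hypothesized depth, reproject into the source view.
        rot_xyz = torch.matmul(rot, xyz)                               # [B,3,H*W]
        rot_depth_xyz = rot_xyz.unsqueeze(2) * depth_values.view(batch, 1, num_depth, -1)
        proj_xyz = rot_depth_xyz + trans.view(batch, 3, 1, 1)          # [B,3,D,H*W]
        proj_xy = proj_xyz[:, :2] / proj_xyz[:, 2:3]                   # pixel coordinates
        grid_x = proj_xy[:, 0] / ((width - 1) / 2) - 1                 # to [-1,1] for grid_sample
        grid_y = proj_xy[:, 1] / ((height - 1) / 2) - 1
        grid = torch.stack((grid_x, grid_y), dim=3)                    # [B,D,H*W,2]
    warped = F.grid_sample(src_fea, grid.view(batch, num_depth * height, width, 2),
                           mode='bilinear', padding_mode='zeros', align_corners=True)
    return warped.view(batch, channels, num_depth, height, width)

def depth_regression_sketch(p, depth_values):
    # Soft-argmin: the expected depth under the probability volume
    # p [B,D,H,W], with hypotheses depth_values [B,D].
    return torch.sum(p * depth_values.view(depth_values.shape[0], -1, 1, 1), dim=1)
```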
FILE: models/mvs4net_utils.py
function homo_warping (line 13) | def homo_warping(src_fea, src_proj, ref_proj, depth_values, vis_ETA=Fals...
function init_range (line 61) | def init_range(cur_depth, ndepths, device, dtype, H, W):
function init_inverse_range (line 71) | def init_inverse_range(cur_depth, ndepths, device, dtype, H, W):
function schedule_inverse_range (line 79) | def schedule_inverse_range(inverse_min_depth, inverse_max_depth, ndepths...
function schedule_range (line 88) | def schedule_range(cur_depth, ndepth, depth_inteval_pixel, H, W):
function init_bn (line 101) | def init_bn(module):
function init_uniform (line 108) | def init_uniform(module, init_method):
class ConvBnReLU3D (line 116) | class ConvBnReLU3D(nn.Module):
method __init__ (line 117) | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,...
method forward (line 122) | def forward(self, x):
class ConvBnReLU3D_CAM (line 125) | class ConvBnReLU3D_CAM(nn.Module):
method __init__ (line 126) | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,...
method forward (line 136) | def forward(self, input):
class ConvBnReLU3D_DCAM (line 145) | class ConvBnReLU3D_DCAM(nn.Module):
method __init__ (line 146) | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,...
method forward (line 156) | def forward(self, input):
class ConvBnReLU3D_PAM (line 165) | class ConvBnReLU3D_PAM(nn.Module):
method __init__ (line 166) | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,...
method forward (line 172) | def forward(self, input):
class ConvBnReLU3D_PDAM (line 181) | class ConvBnReLU3D_PDAM(nn.Module):
method __init__ (line 182) | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,...
method forward (line 188) | def forward(self, input):
class Deconv3d (line 197) | class Deconv3d(nn.Module):
method __init__ (line 199) | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
method forward (line 211) | def forward(self, x):
method init_weights (line 219) | def init_weights(self, init_method):
class Conv2d (line 224) | class Conv2d(nn.Module):
method __init__ (line 226) | def __init__(self, in_channels, out_channels, kernel_size, stride=1,
method forward (line 238) | def forward(self, x):
method init_weights (line 248) | def init_weights(self, init_method):
class Deconv2d (line 253) | class Deconv2d(nn.Module):
method __init__ (line 255) | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
class DeformConv2d (line 267) | class DeformConv2d(nn.Module):
method __init__ (line 268) | def __init__(self, inc, outc, kernel_size=3, padding=1, stride=1, bias...
method _set_lr (line 287) | def _set_lr(module, grad_input, grad_output):
method forward (line 291) | def forward(self, x):
method _get_p_n (line 349) | def _get_p_n(self, N, dtype):
method _get_p_0 (line 359) | def _get_p_0(self, h, w, N, dtype):
method _get_p (line 369) | def _get_p(self, offset, dtype):
method _get_x_q (line 379) | def _get_x_q(self, x, q, N):
method _reshape_x_offset (line 396) | def _reshape_x_offset(x_offset, ks):
function NA_DCN (line 403) | def NA_DCN(in_channels, kernel_size=3, stride=1, dilation=1, bias=True, ...
class FPN4 (line 419) | class FPN4(nn.Module):
method __init__ (line 422) | def __init__(self, base_channels, gn=False, dcn=False):
method forward (line 472) | def forward(self, x):
class LayerNorm (line 504) | class LayerNorm(nn.Module):
method __init__ (line 506) | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_l...
method forward (line 516) | def forward(self, x):
class convnext_block (line 526) | class convnext_block(nn.Module):
method __init__ (line 528) | def __init__(self, dim, layer_scale_init_value=1e-6):
method forward (line 538) | def forward(self, x):
class convnext4_block (line 553) | class convnext4_block(nn.Module):
method __init__ (line 555) | def __init__(self, dim, layer_scale_init_value=1e-6):
method forward (line 566) | def forward(self, x):
class FPN4_convnext (line 581) | class FPN4_convnext(nn.Module):
method __init__ (line 584) | def __init__(self, base_channels, gn=False, dcn=False):
method forward (line 620) | def forward(self, x):
class FPN4_convnext4 (line 652) | class FPN4_convnext4(nn.Module):
method __init__ (line 655) | def __init__(self, base_channels, gn=False, dcn=False):
method forward (line 691) | def forward(self, x):
class ASFF (line 723) | class ASFF(nn.Module):
method __init__ (line 724) | def __init__(self, level):
method forward (line 758) | def forward(self, x_level_0, x_level_1, x_level_2, x_level_3):
class FullImageEncoder (line 807) | class FullImageEncoder(nn.Module):
method __init__ (line 808) | def __init__(self, h, w, kernel_size):
method forward (line 819) | def forward(self, x):
class mono_depth_decoder (line 833) | class mono_depth_decoder(nn.Module):
method __init__ (line 835) | def __init__(self):
method forward (line 849) | def forward(self, outputs, d_min, d_max):
class reg2d (line 870) | class reg2d(nn.Module):
method __init__ (line 871) | def __init__(self, input_channel=128, base_channel=32, conv_name='Conv...
method forward (line 902) | def forward(self, x):
class reg3d (line 914) | class reg3d(nn.Module):
method __init__ (line 915) | def __init__(self, in_channels, base_channels, down_size=3):
method forward (line 943) | def forward(self, x):
class PosEncSine (line 967) | class PosEncSine(nn.Module):
method __init__ (line 969) | def __init__(self, temperature=1000):
method forward (line 973) | def forward(self, x, depth):
class PosEncLearned (line 983) | class PosEncLearned(nn.Module):
method __init__ (line 987) | def __init__(self, D, C):
method reset_parameters (line 994) | def reset_parameters(self):
method forward (line 997) | def forward(self, x, **kwargs):
class stagenet (line 1003) | class stagenet(nn.Module):
method __init__ (line 1004) | def __init__(self, inverse_depth=False, mono=False, attn_fuse_d=True, ...
method forward (line 1012) | def forward(self, features, proj_matrices, depth_hypo, regnet, stage_i...
function sinkhorn (line 1096) | def sinkhorn(gt_depth, hypo_depth, attn_weight, mask, iters, eps=1, cont...
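`sinkhorn` implements the entropy-regularized optimal-transport step that supervises the attention weights. Its depth-specific cost construction (from `gt_depth`, `hypo_depth`, and `attn_weight`) cannot be recovered from this index, so the sketch below shows only the generic log-domain Sinkhorn loop, with uniform marginals as an assumption:

```python
import math
import torch

def sinkhorn_sketch(cost, eps=1.0, iters=10):
    # Alternately match the row and column marginals of the Gibbs kernel
    # exp(-cost / eps) to uniform distributions, working in the log domain
    # for numerical stability. cost: [..., n, m]; returns a transport plan
    # of the same shape whose marginals are approximately (1/n, 1/m).
    n, m = cost.shape[-2], cost.shape[-1]
    log_K = -cost / eps
    log_u = torch.zeros(cost.shape[:-1], device=cost.device)           # [..., n]
    log_v = torch.zeros(cost.shape[:-2] + (m,), device=cost.device)    # [..., m]
    log_r, log_c = -math.log(n), -math.log(m)                          # uniform marginals
    for _ in range(iters):
        log_u = log_r - torch.logsumexp(log_K + log_v.unsqueeze(-2), dim=-1)
        log_v = log_c - torch.logsumexp(log_K + log_u.unsqueeze(-1), dim=-2)
    return torch.exp(log_K + log_u.unsqueeze(-1) + log_v.unsqueeze(-2))

# e.g. plan = sinkhorn_sketch(torch.rand(4, 8, 8)); plan.sum(-1) is ~1/8 per row
```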
FILE: test_mvs4.py
function read_camera_parameters (line 94) | def read_camera_parameters(filename):
function read_img (line 106) | def read_img(filename):
function read_mask (line 114) | def read_mask(filename):
function save_mask (line 119) | def save_mask(filename, mask):
function read_pair_file (line 126) | def read_pair_file(filename): (sketched after this file's index)
function write_cam (line 138) | def write_cam(file, cam):
function save_depth (line 157) | def save_depth(testlist):
function save_scene_depth (line 170) | def save_scene_depth(testlist):
function reproject_with_depth (line 273) | def reproject_with_depth(depth_ref, intrinsics_ref, extrinsics_ref, dept...
function check_geometric_consistency (line 313) | def check_geometric_consistency(depth_ref, intrinsics_ref, extrinsics_re...
function filter_depth (line 331) | def filter_depth(pair_folder, scan_folder, out_folder, plyfilename):
function init_worker (line 424) | def init_worker():
function pcd_filter_worker (line 431) | def pcd_filter_worker(scan):
function pcd_filter (line 443) | def pcd_filter(testlist, number_worker):
function mrun_rst (line 457) | def mrun_rst(eval_dir, plyPath):
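`read_pair_file` (flagged above) parses the view-selection file that pairs each reference image with its scored source views. A sketch assuming the conventional MVSNet `pair.txt` layout: the first line gives the number of viewpoints, and each viewpoint then takes two lines, a reference id followed by `num_src src_id score src_id score ...`:

```python
def read_pair_file_sketch(filename):
    data = []
    with open(filename) as f:
        num_viewpoint = int(f.readline())
        for _ in range(num_viewpoint):
            ref_view = int(f.readline().rstrip())
            fields = f.readline().rstrip().split()
            src_views = [int(v) for v in fields[1::2]]  # ids sit between the scores
            if src_views:
                data.append((ref_view, src_views))
    return data
```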
FILE: train_mvs4.py
function train (line 83) | def train(model, model_loss, optimizer, TrainImgLoader, TestImgLoader, s...
function test (line 179) | def test(model, model_loss, TestImgLoader, args):
function train_sample (line 195) | def train_sample(model, model_loss, optimizer, sample, args):
function test_sample_depth (line 253) | def test_sample_depth(model, model_loss, sample, args):
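`train_sample` wraps one optimization step. Given the `MVS4net.forward` signature above (`imgs, proj_matrices, depth_values`) and the loss signature (`inputs, depth_gt_ms, mask_ms`), the skeleton is likely close to the following; the sample key names are assumptions based on the dataset classes:

```python
def train_sample_sketch(model, model_loss, optimizer, sample):
    # Forward on the batch, compute the multi-stage loss against the
    # ground-truth depths/masks, backpropagate, and step the optimizer.
    model.train()
    optimizer.zero_grad()
    outputs = model(sample['imgs'], sample['proj_matrices'], sample['depth_values'])
    loss = model_loss(outputs, sample['depth'], sample['mask'])
    loss.backward()
    optimizer.step()
    return loss.item()
```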
FILE: utils.py
function print_args (line 8) | def print_args(args):
function make_nograd_func (line 16) | def make_nograd_func(func):
function make_recursive_func (line 26) | def make_recursive_func(func):
function tensor2float (line 41) | def tensor2float(vars):
function tensor2numpy (line 51) | def tensor2numpy(vars):
function tocuda (line 61) | def tocuda(vars):
function save_scalars (line 70) | def save_scalars(logger, mode, scalar_dict, global_step):
function save_images (line 82) | def save_images(logger, mode, images_dict, global_step):
class DictAverageMeter (line 103) | class DictAverageMeter(object):
method __init__ (line 104) | def __init__(self):
method update (line 108) | def update(self, new_input):
method mean (line 121) | def mean(self):
function compute_metrics_for_each_image (line 126) | def compute_metrics_for_each_image(metric_func):
function Thres_metrics (line 141) | def Thres_metrics(depth_est, depth_gt, mask, thres):
function AbsDepthError_metrics (line 152) | def AbsDepthError_metrics(depth_est, depth_gt, mask, thres=None):
function synchronize (line 162) | def synchronize():
function get_world_size (line 176) | def get_world_size():
function reduce_scalar_outputs (line 183) | def reduce_scalar_outputs(scalar_outputs):
class WarmupMultiStepLR (line 208) | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
method __init__ (line 209) | def __init__(
method get_lr (line 237) | def get_lr(self):
function set_random_seed (line 253) | def set_random_seed(seed):
function local_pcd (line 260) | def local_pcd(depth, intr):
function generate_pointcloud (line 274) | def generate_pointcloud(rgb, depth, ply_file, intr, scale=1.0):
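`make_recursive_func` explains the trio `tensor2float` / `tensor2numpy` / `tocuda`: one decorator lifts a leaf-level conversion over arbitrarily nested containers. A sketch of that pattern (the real implementation may reject unknown leaf types instead of passing them through):

```python
import torch

def make_recursive_func(func):
    # Apply func to every leaf of nested lists, tuples, and dicts.
    def wrapper(vars):
        if isinstance(vars, list):
            return [wrapper(x) for x in vars]
        if isinstance(vars, tuple):
            return tuple(wrapper(x) for x in vars)
        if isinstance(vars, dict):
            return {k: wrapper(v) for k, v in vars.items()}
        return func(vars)
    return wrapper

@make_recursive_func
def tensor2float(x):
    # Tensor leaves become Python floats; everything else passes through.
    return x.item() if isinstance(x, torch.Tensor) else x
```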