Full Code of JeffWang987/MVSTER for AI

main 3f8bb98bba0c cached
35 files
246.6 KB
71.3k tokens
265 symbols
1 requests
Download .txt
Showing preview only (259K chars total). Download the full file or copy to clipboard to get everything.
Repository: JeffWang987/MVSTER
Branch: main
Commit: 3f8bb98bba0c
Files: 35
Total size: 246.6 KB

Directory structure:
gitextract_j4zdg8al/

├── .gitignore
├── LICENSE
├── README.md
├── datasets/
│   ├── __init__.py
│   ├── blendedmvs.py
│   ├── data_io.py
│   ├── dtu_yao4.py
│   ├── eth3d.py
│   ├── general_eval4.py
│   └── tanks.py
├── evaluations/
│   └── dtu/
│       ├── BaseEval2Obj_web.m
│       ├── BaseEvalMain_func.m
│       ├── BaseEvalMain_web.m
│       ├── ComputeStat_func.m
│       ├── ComputeStat_web.m
│       ├── MaxDistCP.m
│       ├── PointCompareMain.m
│       ├── plyread.m
│       └── reducePts_haa.m
├── lists/
│   ├── blendedmvs/
│   │   ├── train.txt
│   │   └── val.txt
│   └── dtu/
│       ├── test.txt
│       ├── train.txt
│       ├── trainval.txt
│       └── val.txt
├── models/
│   ├── MVS4Net.py
│   ├── __init__.py
│   ├── module.py
│   └── mvs4net_utils.py
├── requirements.txt
├── scripts/
│   ├── test_dtu.sh
│   └── train_dtu.sh
├── test_mvs4.py
├── train_mvs4.py
└── utils.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
outputs/
checkpoints/
debug_figs/
*__pycache__

================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2022 Jeff Wang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README.md
================================================
# MVSTER
MVSTER: Epipolar Transformer for Efficient Multi-View Stereo, ECCV 2022. [arXiv](https://arxiv.org/abs/2204.07346)

This repository contains the official implementation of the paper: "MVSTER: Epipolar Transformer for Efficient Multi-View Stereo".


## Introduction
MVSTER is a learning-based MVS method which achieves competitive reconstruction performance with significantly higher efficiency. MVSTER leverages the proposed epipolar Transformer to learn both 2D semantics and 3D spatial associations efficiently. Specifically, the epipolar Transformer utilizes a detachable monocular depth estimator to enhance 2D semantics and uses cross-attention to construct data-dependent 3D associations along epipolar line. Additionally, MVSTER is built in a cascade structure, where entropy-regularized optimal transport is leveraged to propagate finer depth estimations in each stage.
![](img/arch.png)



## Installation
MVSTER is tested on:
* python 3.7
* CUDA 11.1
### Requirements
```
pip install -r requirements.txt
```

## Training
* Dowload [DTU dataset](https://roboimagedata.compute.dtu.dk/). For convenience, can download the preprocessed [DTU training data](https://drive.google.com/file/d/1eDjh-_bxKKnEuz5h-HXS7EDJn59clx6V/view)
 and [Depths_raw](https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/cascade-stereo/CasMVSNet/dtu_data/dtu_train_hr/Depths_raw.zip) 
 (both from [Original MVSNet](https://github.com/YoYo000/MVSNet)), and upzip it as the $DTU_TRAINING folder. For training and testing with raw image size, you can download [Rectified_raw](http://roboimagedata2.compute.dtu.dk/data/MVS/Rectified.zip), and unzip it.

```                
├── Cameras    
├── Depths
├── Depths_raw   
├── Rectified
├── Rectified_raw (Optional)                                      
```
In ``scripts/train_dtu.sh``, set ``DTU_TRAINING`` as $DTU_TRAINING

Train MVSTER (Multi-GPU training): 
* Train with middle size (512x640):
```
bash ./scripts/train_dtu.sh mid exp_name
```
* Train with raw size (1200x1600):
```
bash ./scripts/train_dtu.sh raw exp_name
```
After training, you will get model checkpoints in ./checkpoints/dtu/exp_name.

## Testing
* Download the preprocessed test data [DTU testing data](https://drive.google.com/open?id=135oKPefcPTsdtLRzoDAQtPpHuoIrpRI_) (from [Original MVSNet](https://github.com/YoYo000/MVSNet)) and unzip it as the $DTU_TESTPATH folder, which should contain one ``cams`` folder, one ``images`` folder and one ``pair.txt`` file.
* In ``scripts/test_dtu.sh``, set ``DTU_TESTPATH`` as $DTU_TESTPATH.
* The ``DTU_CKPT_FILE`` is automatically set as your pretrained checkpoint file, you also can download my [pretrained model](https://github.com/JeffWang987/MVSTER/releases/tag/dtu_ckpt).
* Test with middle size:
```
bash ./scripts/test_dtu.sh mid exp_name
```
* Test with raw size:
```
bash ./scripts/test_dtu.sh raw exp_name
```
* Test with provided pretrained model:
```
bash scripts/test_dtu.sh mid benchmark --loadckpt PATH_TO_CKPT_FILE
```
After testing, you will get reconstructed point clouds of DTU test set in ./outputs/dtu/exp_name.

## Metric
* For quantitative evaluation, download [SampleSet](http://roboimagedata.compute.dtu.dk/?page_id=36) and [Points](http://roboimagedata.compute.dtu.dk/?page_id=36) from DTU's website. Unzip them and place `Points` folder in `SampleSet/MVS Data/`. The structure looks like:
```
SampleSet
├──MVS Data
      └──Points
```
* For convinience evaluation, please install matlab (tested on Ubuntu 18.04) and uncomment **mrun_rst** function at the end of **./test_mvs4.py**, and you also need to change the path of matlab excutable file (for me, it is /mnt/cfs/algorithm/xiaofeng.wang/jeff/code/MVS/misc/matlab/bin/matlab). Then you can evaluate point cloud reconstruction results when testing is finished.

* You can also evaluate the metrics with the traditional steps:
In ``evaluations/dtu/BaseEvalMain_web.m``, set `dataPath` as the path to `SampleSet/MVS Data/`, `plyPath` as directory that stores the reconstructed point clouds and `resultsPath` as directory to store the evaluation results. Then run ``evaluations/dtu/BaseEvalMain_web.m`` in matlab.

## Results on DTU (single RTX 3090)
|                       | Acc.   | Comp.  | Overall. | Inf. Time |
|-----------------------|--------|--------|----------|-----------|
| MVSTER (mid size)     | 0.350  | 0.276  | 0.313    |    0.09s  |
| MVSTER (raw size)     | 0.340  | 0.266  | 0.303    |    0.17s  |

Point cloud results on [DTU](https://github.com/JeffWang987/MVSTER/releases/tag/DTU_ply), [Tanks and Temples](https://github.com/JeffWang987/MVSTER/releases/tag/T%26T_ply), [ETH3D](https://github.com/JeffWang987/MVSTER/releases/tag/ETH3D_ply)

![](img/vegetables.gif) ![](img/house.gif) 

![](img/sculpture.gif) ![](img/rabit.gif)


If you find this project useful for your research, please cite: 
```
@misc{wang2022mvster,
      title={MVSTER: Epipolar Transformer for Efficient Multi-View Stereo}, 
      author={Xiaofeng Wang, Zheng Zhu, Fangbo Qin, Yun Ye, Guan Huang, Xu Chi, Yijia He and Xingang Wang},
      journal={arXiv preprint arXiv:2204.07346},
      year={2022}
}
```


## Acknowledgements
Our work is partially baed on these opening source work: [MVSNet](https://github.com/YoYo000/MVSNet), [MVSNet-pytorch](https://github.com/xy-guo/MVSNet_pytorch), [cascade-stereo](https://github.com/alibaba/cascade-stereo), [PatchmatchNet](https://github.com/FangjinhuaWang/PatchmatchNet).

We appreciate their contributions to the MVS community.


================================================
FILE: datasets/__init__.py
================================================
import importlib


# find the dataset definition by name, for example dtu_yao (dtu_yao.py)
def find_dataset_def(dataset_name):
    module_name = 'datasets.{}'.format(dataset_name)
    module = importlib.import_module(module_name)
    return getattr(module, "MVSDataset")


================================================
FILE: datasets/blendedmvs.py
================================================
from torch.utils.data import Dataset
from datasets.data_io import *
import os
import numpy as np
import cv2
from PIL import Image
from torchvision import transforms as T
import random
import copy

def check_invalid_input(imgs, depths, masks, depth_mins, depth_maxs):
    for img in imgs:
        assert np.isnan(img).sum() == 0
        assert np.isinf(img).sum() == 0
    for depth in depths.values():
        assert np.isnan(depth).sum() == 0
        assert np.isinf(depth).sum() == 0
    for mask in masks.values():
        assert np.isnan(mask).sum() == 0
        assert np.isinf(mask).sum() == 0

    assert (depth_mins<=0) == 0
    assert (depth_maxs<=depth_mins) == 0


class MVSDataset(Dataset):
    def __init__(self, datapath, listfile, split, nviews, img_wh=(768, 576), robust_train=True):
        
        super(MVSDataset, self).__init__()
        self.levels = 4 
        self.datapath = datapath
        self.split = split
        self.listfile = listfile
        self.robust_train = robust_train
        assert self.split in ['train', 'val', 'all'], \
            'split must be either "train", "val" or "all"!'

        self.img_wh = img_wh
        if img_wh is not None:
            assert img_wh[0]%32==0 and img_wh[1]%32==0, \
                'img_wh must both be multiples of 32!'
        self.nviews = nviews
        self.scale_factors = {} # depth scale factors for each scan
        self.scale_factor = 0 # depth scale factors for each scan
        self.build_metas()

        self.color_augment = T.ColorJitter(brightness=0.5, contrast=0.5)

    def build_metas(self):
        self.metas = []
        with open(self.listfile) as f:
            self.scans = [line.rstrip() for line in f.readlines()]
        for scan in self.scans:
            with open(os.path.join(self.datapath, scan, "cams/pair.txt")) as f:
                num_viewpoint = int(f.readline())
                for _ in range(num_viewpoint):
                    ref_view = int(f.readline().rstrip())
                    src_views = [int(x) for x in f.readline().rstrip().split()[1::2]]
                    if len(src_views) >= self.nviews-1:
                        self.metas += [(scan, ref_view, src_views)]

    def read_cam_file(self, scan, filename):
        with open(filename) as f:
            lines = f.readlines()
            lines = [line.rstrip() for line in lines]
        # extrinsics: line [1,5), 4x4 matrix
        extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4))
        # intrinsics: line [7-10), 3x3 matrix
        intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3))
        depth_min = float(lines[11].split()[0])
        depth_max = float(lines[11].split()[-1])

        if scan not in self.scale_factors:
            self.scale_factors[scan] = 100.0 / depth_min
        depth_min *= self.scale_factors[scan]
        depth_max *= self.scale_factors[scan]
        extrinsics[:3, 3] *= self.scale_factors[scan]

        return intrinsics, extrinsics, depth_min, depth_max

    def read_depth_mask(self, scan, filename, depth_min, depth_max, scale):
        depth = np.array(read_pfm(filename)[0], dtype=np.float32)
        # depth = (depth * self.scale_factor) * scale
        depth = (depth * self.scale_factors[scan]) * scale
        # depth = depth * scale
        # depth = np.squeeze(depth,2)

        mask = (depth>=depth_min) & (depth<=depth_max)
        assert mask.sum() > 0
        mask = mask.astype(np.float32)
        if self.img_wh is not None:
            depth = cv2.resize(depth, self.img_wh,
                                 interpolation=cv2.INTER_NEAREST)
        h, w = depth.shape
        depth_ms = {}
        mask_ms = {}

        for i in range(4):
            depth_cur = cv2.resize(depth, (w//(2**i), h//(2**i)), interpolation=cv2.INTER_NEAREST)
            mask_cur = cv2.resize(mask, (w//(2**i), h//(2**i)), interpolation=cv2.INTER_NEAREST)

            depth_ms[f"stage{4-i}"] = depth_cur
            mask_ms[f"stage{4-i}"] = mask_cur

        return depth_ms, mask_ms


    def read_img(self, filename):
        img = Image.open(filename)
        # img = self.color_augment(img)
        # scale 0~255 to 0~1
        np_img = np.array(img, dtype=np.float32) / 255.
        return np_img

    def __len__(self):
        return len(self.metas)

    def __getitem__(self, idx):
        meta = self.metas[idx]
        scan, ref_view, src_views = meta
        
        if self.robust_train:
            num_src_views = len(src_views)
            index = random.sample(range(num_src_views), self.nviews - 1)
            view_ids = [ref_view] + [src_views[i] for i in index]
            scale = random.uniform(0.8, 1.25)

        else:
            view_ids = [ref_view] + src_views[:self.nviews - 1]
            scale = 1

        imgs = []
        mask = None
        depth = None
        depth_min = None
        depth_max = None

        proj={}
        proj_matrices_0 = []
        proj_matrices_1 = []
        proj_matrices_2 = []
        proj_matrices_3 = []


        for i, vid in enumerate(view_ids):
            img_filename = os.path.join(self.datapath, '{}/blended_images/{:0>8}.jpg'.format(scan, vid))
            depth_filename = os.path.join(self.datapath, '{}/rendered_depth_maps/{:0>8}.pfm'.format(scan, vid))
            proj_mat_filename = os.path.join(self.datapath, '{}/cams/{:0>8}_cam.txt'.format(scan, vid))

            img = self.read_img(img_filename)
            imgs.append(img.transpose(2,0,1))

            intrinsics, extrinsics, depth_min_, depth_max_ = self.read_cam_file(scan, proj_mat_filename)
            # proj_mat_filename = os.path.join(self.datapath, 'Cameras/train/{:0>8}_cam.txt').format(vid)


            proj_mat_0 = np.zeros(shape=(2, 4, 4), dtype=np.float32)
            proj_mat_1 = np.zeros(shape=(2, 4, 4), dtype=np.float32)
            proj_mat_2 = np.zeros(shape=(2, 4, 4), dtype=np.float32)
            proj_mat_3 = np.zeros(shape=(2, 4, 4), dtype=np.float32)
            extrinsics[:3, 3] *= scale
            intrinsics[:2,:] *= 0.125
            proj_mat_0[0,:4,:4] = extrinsics.copy()
            proj_mat_0[1,:3,:3] = intrinsics.copy()

            intrinsics[:2,:] *= 2
            proj_mat_1[0,:4,:4] = extrinsics.copy()
            proj_mat_1[1,:3,:3] = intrinsics.copy()

            intrinsics[:2,:] *= 2
            proj_mat_2[0,:4,:4] = extrinsics.copy()
            proj_mat_2[1,:3,:3] = intrinsics.copy()

            intrinsics[:2,:] *= 2
            proj_mat_3[0,:4,:4] = extrinsics.copy()
            proj_mat_3[1,:3,:3] = intrinsics.copy()  

            proj_matrices_0.append(proj_mat_0)
            proj_matrices_1.append(proj_mat_1)
            proj_matrices_2.append(proj_mat_2)
            proj_matrices_3.append(proj_mat_3)

            if i == 0:  # reference view
                depth_min = depth_min_ * scale
                depth_max = depth_max_ * scale
                depth, mask = self.read_depth_mask(scan, depth_filename, depth_min, depth_max, scale)
                for l in range(self.levels):
                    mask[f'stage{l+1}'] = mask[f'stage{l+1}'] # np.expand_dims(mask[f'stage{l+1}'],2)
                    depth[f'stage{l+1}'] = depth[f'stage{l+1}']

        proj['stage1'] = np.stack(proj_matrices_0)
        proj['stage2'] = np.stack(proj_matrices_1)
        proj['stage3'] = np.stack(proj_matrices_2)
        proj['stage4'] = np.stack(proj_matrices_3)

        # check_invalid_input(imgs, depth, mask, depth_min, depth_max)
        # data is numpy array
        return {"imgs": imgs,                   # [Nv, 3, H, W]
                "proj_matrices": proj,          # [N,2,4,4]
                "depth": depth,                 # [1, H, W]
                "depth_values": np.array([depth_min, depth_max], dtype=np.float32),
                "mask": mask}                   # [1, H, W]
        

================================================
FILE: datasets/data_io.py
================================================
import numpy as np
import re
import sys


def read_pfm(filename):
    file = open(filename, 'rb')
    color = None
    width = None
    height = None
    scale = None
    endian = None

    header = file.readline().decode('utf-8').rstrip()
    if header == 'PF':
        color = True
    elif header == 'Pf':
        color = False
    else:
        raise Exception('Not a PFM file.')

    dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8'))
    if dim_match:
        width, height = map(int, dim_match.groups())
    else:
        raise Exception('Malformed PFM header.')

    scale = float(file.readline().rstrip())
    if scale < 0:  # little-endian
        endian = '<'
        scale = -scale
    else:
        endian = '>'  # big-endian

    data = np.fromfile(file, endian + 'f')
    shape = (height, width, 3) if color else (height, width)

    data = np.reshape(data, shape)
    data = np.flipud(data)
    file.close()
    return data, scale


def save_pfm(filename, image, scale=1):
    file = open(filename, "wb")
    color = None

    image = np.flipud(image)

    if image.dtype.name != 'float32':
        raise Exception('Image dtype must be float32.')

    if len(image.shape) == 3 and image.shape[2] == 3:  # color image
        color = True
    elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1:  # greyscale
        color = False
    else:
        raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.')

    file.write('PF\n'.encode('utf-8') if color else 'Pf\n'.encode('utf-8'))
    file.write('{} {}\n'.format(image.shape[1], image.shape[0]).encode('utf-8'))

    endian = image.dtype.byteorder

    if endian == '<' or endian == '=' and sys.byteorder == 'little':
        scale = -scale

    file.write(('%f\n' % scale).encode('utf-8'))

    image.tofile(file)
    file.close()


import random, cv2
class RandomCrop(object):
    def __init__(self, CropSize=0.1):
        self.CropSize = CropSize

    def __call__(self, image, normal):
        h, w = normal.shape[:2]
        img_h, img_w = image.shape[:2]
        CropSize_w, CropSize_h = max(1, int(w * self.CropSize)), max(1, int(h * self.CropSize))
        x1, y1 = random.randint(0, CropSize_w), random.randint(0, CropSize_h)
        x2, y2 = random.randint(w - CropSize_w, w), random.randint(h - CropSize_h, h)

        normal_crop = normal[y1:y2, x1:x2]
        normal_resize = cv2.resize(normal_crop, (w, h), interpolation=cv2.INTER_NEAREST)

        image_crop = image[4*y1:4*y2, 4*x1:4*x2]
        image_resize = cv2.resize(image_crop, (img_w, img_h), interpolation=cv2.INTER_LINEAR)

        # import matplotlib.pyplot as plt
        # plt.subplot(2, 3, 1)
        # plt.imshow(image)
        # plt.subplot(2, 3, 2)
        # plt.imshow(image_crop)
        # plt.subplot(2, 3, 3)
        # plt.imshow(image_resize)
        #
        # plt.subplot(2, 3, 4)
        # plt.imshow((normal + 1.0) / 2, cmap="rainbow")
        # plt.subplot(2, 3, 5)
        # plt.imshow((normal_crop + 1.0) / 2, cmap="rainbow")
        # plt.subplot(2, 3, 6)
        # plt.imshow((normal_resize + 1.0) / 2, cmap="rainbow")
        # plt.show()
        # plt.pause(1)
        # plt.close()

        return image_resize, normal_resize

================================================
FILE: datasets/dtu_yao4.py
================================================
from torch.utils.data import Dataset
import numpy as np
import os, cv2, time, math
from PIL import Image
from datasets.data_io import *
from torchvision import transforms

# the DTU dataset preprocessed by Yao Yao (only for training)
class MVSDataset(Dataset):
    def __init__(self, datapath, listfile, mode, nviews, interval_scale=1.06, **kwargs):
        super(MVSDataset, self).__init__()
        self.datapath = datapath
        self.listfile = listfile
        self.mode = mode
        self.nviews = nviews
        self.ndepths = 192  # Hardcode
        self.interval_scale = interval_scale
        self.kwargs = kwargs
        self.rt = kwargs.get("rt", False)
        self.use_raw_train = kwargs.get("use_raw_train", False)
        self.color_augment = transforms.ColorJitter(brightness=0.5, contrast=0.5)

        assert self.mode in ["train", "val", "test"]
        self.metas = self.build_list()

    def build_list(self):
        metas = []
        with open(self.listfile) as f:
            scans = f.readlines()
            scans = [line.rstrip() for line in scans]

        # scans
        for scan in scans:
            pair_file = "Cameras/pair.txt"
            # read the pair file
            with open(os.path.join(self.datapath, pair_file)) as f:
                num_viewpoint = int(f.readline())
                # viewpoints (49)
                for view_idx in range(num_viewpoint):
                    ref_view = int(f.readline().rstrip())
                    src_views = [int(x) for x in f.readline().rstrip().split()[1::2]]
                    # light conditions 0-6
                    for light_idx in range(7):
                        metas.append((scan, light_idx, ref_view, src_views))
        # print("dataset", self.mode, "metas:", len(metas))
        return metas

    def __len__(self):
        return len(self.metas)

    def read_cam_file(self, filename):
        with open(filename) as f:
            lines = f.readlines()
            lines = [line.rstrip() for line in lines]
        # extrinsics: line [1,5), 4x4 matrix
        extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4))
        # intrinsics: line [7-10), 3x3 matrix
        intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3))
        # depth_min & depth_interval: line 11
        depth_min = float(lines[11].split()[0])
        depth_interval = float(lines[11].split()[1]) * self.interval_scale
        return intrinsics, extrinsics, depth_min, depth_interval

    def read_img(self, filename):
        img = Image.open(filename)
        if self.mode == 'train':
            img = self.color_augment(img)
        # scale 0~255 to 0~1
        np_img = np.array(img, dtype=np.float32) / 255.
        return np_img

    def crop_img(self, img):
        raw_h, raw_w = img.shape[:2]
        start_h = (raw_h-1024)//2
        start_w = (raw_w-1280)//2
        return img[start_h:start_h+1024, start_w:start_w+1280, :]  # 1024, 1280, C

    def prepare_img(self, hr_img):
        h, w = hr_img.shape
        if not self.use_raw_train:
            #w1600-h1200-> 800-600 ; crop -> 640, 512; downsample 1/4 -> 160, 128
            #downsample
            hr_img_ds = cv2.resize(hr_img, (w//2, h//2), interpolation=cv2.INTER_NEAREST)
            h, w = hr_img_ds.shape
            target_h, target_w = 512, 640
            start_h, start_w = (h - target_h)//2, (w - target_w)//2
            hr_img_crop = hr_img_ds[start_h: start_h + target_h, start_w: start_w + target_w]
        elif self.use_raw_train:
            hr_img_crop = hr_img[h//2-1024//2:h//2+1024//2, w//2-1280//2:w//2+1280//2]  # 1024, 1280, c
        return hr_img_crop

    def read_mask_hr(self, filename):
        img = Image.open(filename)
        np_img = np.array(img, dtype=np.float32)
        np_img = (np_img > 10).astype(np.float32)
        np_img = self.prepare_img(np_img)

        h, w = np_img.shape
        np_img_ms = {
            "stage1": cv2.resize(np_img, (w//8, h//8), interpolation=cv2.INTER_NEAREST),
            "stage2": cv2.resize(np_img, (w//4, h//4), interpolation=cv2.INTER_NEAREST),
            "stage3": cv2.resize(np_img, (w//2, h//2), interpolation=cv2.INTER_NEAREST),
            "stage4": np_img,
        }
        return np_img_ms


    def read_depth_hr(self, filename, scale):
        # read pfm depth file
        #w1600-h1200-> 800-600 ; crop -> 640, 512; downsample 1/4 -> 160, 128
        depth_hr = np.array(read_pfm(filename)[0], dtype=np.float32) * scale
        depth_lr = self.prepare_img(depth_hr)

        h, w = depth_lr.shape
        depth_lr_ms = {
            "stage1": cv2.resize(depth_lr, (w//8, h//8), interpolation=cv2.INTER_NEAREST),
            "stage2": cv2.resize(depth_lr, (w//4, h//4), interpolation=cv2.INTER_NEAREST),
            "stage3": cv2.resize(depth_lr, (w//2, h//2), interpolation=cv2.INTER_NEAREST),
            "stage4": depth_lr,
        }
        return depth_lr_ms

    def __getitem__(self, idx):
        meta = self.metas[idx]
        scan, light_idx, ref_view, src_views = meta
        # use only the reference view and first nviews-1 source views

        if self.mode == 'train' and self.rt:
            num_src_views = len(src_views)
            index = random.sample(range(num_src_views), self.nviews - 1)
            view_ids = [ref_view] + [src_views[i] for i in index]
            scale = random.uniform(0.8, 1.25)
        else:
            view_ids = [ref_view] + src_views[:self.nviews - 1]
            scale = 1
        imgs = []
        mask = None
        depth_values = None
        proj_matrices = []
        for i, vid in enumerate(view_ids):
            # NOTE that the id in image file names is from 1 to 49 (not 0~48)
            if not self.use_raw_train:
                img_filename = os.path.join(self.datapath, 'Rectified/{}_train/rect_{:0>3}_{}_r5000.png'.format(scan, vid + 1, light_idx))
            else:
                img_filename = os.path.join(self.datapath, 'Rectified_raw/{}/rect_{:0>3}_{}_r5000.png'.format(scan, vid + 1, light_idx))
            mask_filename_hr = os.path.join(self.datapath, 'Depths_raw/{}/depth_visual_{:0>4}.png'.format(scan, vid))
            depth_filename_hr = os.path.join(self.datapath, 'Depths_raw/{}/depth_map_{:0>4}.pfm'.format(scan, vid))
            proj_mat_filename = os.path.join(self.datapath, 'Cameras/train/{:0>8}_cam.txt').format(vid)
            img = self.read_img(img_filename)
            if self.use_raw_train:
                img = self.crop_img(img)
            intrinsics, extrinsics, depth_min, depth_interval = self.read_cam_file(proj_mat_filename)
            if self.rt:
                extrinsics[:3,3] *= scale
            if self.use_raw_train:
                intrinsics[:2, :] *= 2.0

            if i == 0:

                mask_read_ms = self.read_mask_hr(mask_filename_hr)
                depth_ms = self.read_depth_hr(depth_filename_hr, scale)
                #get depth values
                depth_max = depth_interval * self.ndepths + depth_min
                depth_values = np.array([depth_min * scale, depth_max * scale], dtype=np.float32)
                mask = mask_read_ms

            proj_mat = np.zeros(shape=(2, 4, 4), dtype=np.float32)  #
            proj_mat[0, :4, :4] = extrinsics
            proj_mat[1, :3, :3] = intrinsics
            proj_matrices.append(proj_mat)
            imgs.append(img.transpose(2,0,1))

        #all
        # imgs = np.stack(imgs).transpose([0, 3, 1, 2])
        #ms proj_mats
        proj_matrices = np.stack(proj_matrices)
        stage1_pjmats = proj_matrices.copy()
        stage1_pjmats[:, 1, :2, :] = proj_matrices[:, 1, :2, :] / 2.0
        stage3_pjmats = proj_matrices.copy()
        stage3_pjmats[:, 1, :2, :] = proj_matrices[:, 1, :2, :] * 2
        stage4_pjmats = proj_matrices.copy()
        stage4_pjmats[:, 1, :2, :] = proj_matrices[:, 1, :2, :] * 4

        proj_matrices_ms = {
            "stage1": stage1_pjmats,
            "stage2": proj_matrices,
            "stage3": stage3_pjmats,
            "stage4": stage4_pjmats
        }

        return {"imgs": imgs,  # Nv C H W
                "proj_matrices": proj_matrices_ms,  # 4 stage of Nv 2 4 4
                "depth": depth_ms,
                "depth_values": depth_values,
                "mask": mask }

================================================
FILE: datasets/eth3d.py
================================================
from torch.utils.data import Dataset
from datasets.data_io import *
import os
import numpy as np
import cv2
from PIL import Image

class MVSDataset(Dataset):
    def __init__(self, datapath, split='test', n_views=7, img_wh=(1920,1280)):
        self.levels = 4
        self.datapath = datapath
        self.img_wh = img_wh
        self.split = split
        self.build_metas()
        self.n_views = n_views

    def build_metas(self):
        self.metas = []
        if self.split == "test":
            self.scans = ['botanical_garden', 'boulders', 'bridge', 'door',
                'exhibition_hall', 'lecture_room', 'living_room', 'lounge',
                'observatory', 'old_computer', 'statue', 'terrace_2']

        elif self.split == "train":
            self.scans = ['courtyard', 'delivery_area', 'electro', 'facade',
                    'kicker', 'meadow', 'office', 'pipes', 'playground',
                    'relief', 'relief_2', 'terrace', 'terrains']
        

        for scan in self.scans:
            with open(os.path.join(self.datapath, scan, 'pair.txt')) as f:
                num_viewpoint = int(f.readline())
                for view_idx in range(num_viewpoint):
                    ref_view = int(f.readline().rstrip())
                    src_views = [int(x) for x in f.readline().rstrip().split()[1::2]]
                    if len(src_views) != 0:
                        self.metas += [(scan, -1, ref_view, src_views)]
                    

    def read_cam_file(self, filename):
        with open(filename) as f:
            lines = [line.rstrip() for line in f.readlines()]
        # extrinsics: line [1,5), 4x4 matrix
        extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ')
        extrinsics = extrinsics.reshape((4, 4))
        # intrinsics: line [7-10), 3x3 matrix
        intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ')
        intrinsics = intrinsics.reshape((3, 3))
        
        depth_min = float(lines[11].split()[0])
        if depth_min < 0:
            depth_min = 1
        depth_max = float(lines[11].split()[-1])

        return intrinsics, extrinsics, depth_min, depth_max

    def read_img(self, filename):
        img = Image.open(filename)
        np_img = np.array(img, dtype=np.float32) / 255.
        original_h, original_w, _ = np_img.shape
        np_img = cv2.resize(np_img, self.img_wh, interpolation=cv2.INTER_LINEAR)
        return np_img, original_h, original_w

    def __len__(self):
        return len(self.metas)

    def __getitem__(self, idx):
        scan, _, ref_view, src_views = self.metas[idx]
        # use only the reference view and first nviews-1 source views
        view_ids = [ref_view] + src_views[:self.n_views-1]
        imgs = []

        # depth = None
        depth_min = None
        depth_max = None

        proj_matrices_0 = []
        proj_matrices_1 = []
        proj_matrices_2 = []
        proj_matrices_3 = []

        for i, vid in enumerate(view_ids):
            img_filename = os.path.join(self.datapath,  scan, f'images/{vid:08d}.jpg')
            proj_mat_filename = os.path.join(self.datapath, scan, f'cams_1/{vid:08d}_cam.txt')

            img, original_h, original_w = self.read_img(img_filename)

            intrinsics, extrinsics, depth_min_, depth_max_ = self.read_cam_file(proj_mat_filename)
            intrinsics[0] *= self.img_wh[0]/original_w
            intrinsics[1] *= self.img_wh[1]/original_h
            imgs.append(img.transpose(2,0,1))

            proj_mat_0 = np.zeros(shape=(2, 4, 4), dtype=np.float32)
            proj_mat_1 = np.zeros(shape=(2, 4, 4), dtype=np.float32)
            proj_mat_2 = np.zeros(shape=(2, 4, 4), dtype=np.float32)
            proj_mat_3 = np.zeros(shape=(2, 4, 4), dtype=np.float32)

            intrinsics[:2,:] *= 0.125
            proj_mat_0[0,:4,:4] = extrinsics.copy()
            proj_mat_0[1,:3,:3] = intrinsics.copy()

            intrinsics[:2,:] *= 2
            proj_mat_1[0,:4,:4] = extrinsics.copy()
            proj_mat_1[1,:3,:3] = intrinsics.copy()

            intrinsics[:2,:] *= 2
            proj_mat_2[0,:4,:4] = extrinsics.copy()
            proj_mat_2[1,:3,:3] = intrinsics.copy()

            intrinsics[:2,:] *= 2
            proj_mat_3[0,:4,:4] = extrinsics.copy()
            proj_mat_3[1,:3,:3] = intrinsics.copy()  

            proj_matrices_0.append(proj_mat_0)
            proj_matrices_1.append(proj_mat_1)
            proj_matrices_2.append(proj_mat_2)
            proj_matrices_3.append(proj_mat_3)

            if i == 0:  # reference view
                depth_min =  depth_min_
                depth_max = depth_max_

        # proj_matrices: N*4*4
        proj={}
        proj['stage1'] = np.stack(proj_matrices_0)
        proj['stage2'] = np.stack(proj_matrices_1)
        proj['stage3'] = np.stack(proj_matrices_2)
        proj['stage4'] = np.stack(proj_matrices_3)


        return {"imgs": imgs,                   # N*3*H0*W0
                "proj_matrices": proj, # N*4*4
                "depth_values": np.array([depth_min, depth_max], dtype=np.float32),
                "filename": scan + '/{}/' + '{:0>8}'.format(view_ids[0]) + "{}"
                }  


================================================
FILE: datasets/general_eval4.py
================================================
from torch.utils.data import Dataset
import numpy as np
import os, cv2, time
from PIL import Image
from datasets.data_io import *

s_h, s_w = 0, 0
class MVSDataset(Dataset):
    def __init__(self, datapath, listfile, mode, nviews, interval_scale=1.06, **kwargs):
        super(MVSDataset, self).__init__()
        self.datapath = datapath
        self.listfile = listfile
        self.mode = mode
        self.nviews = nviews
        self.ndepths = 192  # Hardcode
        self.interval_scale = interval_scale
        self.max_h, self.max_w = kwargs["max_h"], kwargs["max_w"]
        self.fix_res = kwargs.get("fix_res", False)  #whether to fix the resolution of input image.
        self.fix_wh = False

        assert self.mode == "test"
        self.metas = self.build_list()

    def build_list(self):
        metas = []
        scans = self.listfile

        interval_scale_dict = {}
        # scans
        for scan in scans:
            # determine the interval scale of each scene. default is 1.06
            if isinstance(self.interval_scale, float):
                interval_scale_dict[scan] = self.interval_scale
            else:
                interval_scale_dict[scan] = self.interval_scale[scan]

            pair_file = "{}/pair.txt".format(scan)
            # read the pair file
            with open(os.path.join(self.datapath, pair_file)) as f:
                num_viewpoint = int(f.readline())
                # viewpoints
                for view_idx in range(num_viewpoint):
                    ref_view = int(f.readline().rstrip())
                    src_views = [int(x) for x in f.readline().rstrip().split()[1::2]]
                    # filter by no src view and fill to nviews
                    if len(src_views) > 0:
                        if len(src_views) < self.nviews:
                            print("{}< num_views:{}".format(len(src_views), self.nviews))
                            src_views += [src_views[0]] * (self.nviews - len(src_views))
                        metas.append((scan, ref_view, src_views, scan))

        self.interval_scale = interval_scale_dict
        print("dataset", self.mode, "metas:", len(metas), "interval_scale:{}".format(self.interval_scale))
        return metas

    def __len__(self):
        return len(self.metas)

    def read_cam_file(self, filename, interval_scale):
        with open(filename) as f:
            lines = f.readlines()
            lines = [line.rstrip() for line in lines]
        # extrinsics: line [1,5), 4x4 matrix
        extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4))
        # intrinsics: line [7-10), 3x3 matrix
        intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3))
        intrinsics[:2, :] /= 4.0
        # depth_min & depth_interval: line 11
        depth_min = float(lines[11].split()[0])
        depth_interval = float(lines[11].split()[1])

        if len(lines[11].split()) >= 3:
            num_depth = lines[11].split()[2]
            depth_max = depth_min + int(float(num_depth)) * depth_interval
            depth_interval = (depth_max - depth_min) / self.ndepths

        depth_interval *= interval_scale

        return intrinsics, extrinsics, depth_min, depth_interval

    def read_img(self, filename):
        img = Image.open(filename)
        # scale 0~255 to 0~1
        np_img = np.array(img, dtype=np.float32) / 255.

        return np_img

    def read_depth(self, filename):
        # read pfm depth file
        return np.array(read_pfm(filename)[0], dtype=np.float32)

    def scale_mvs_input(self, img, intrinsics, max_w, max_h, base=64):
        h, w = img.shape[:2]
        if h > max_h or w > max_w:
            scale = 1.0 * max_h / h
            if scale * w > max_w:
                scale = 1.0 * max_w / w
            new_w, new_h = scale * w // base * base, scale * h // base * base
        else:
            new_w, new_h = 1.0 * w // base * base, 1.0 * h // base * base

        scale_w = 1.0 * new_w / w
        scale_h = 1.0 * new_h / h
        intrinsics[0, :] *= scale_w
        intrinsics[1, :] *= scale_h

        img = cv2.resize(img, (int(new_w), int(new_h)))

        return img, intrinsics

    def __getitem__(self, idx):
        global s_h, s_w
        meta = self.metas[idx]
        scan, ref_view, src_views, scene_name = meta
        # use only the reference view and first nviews-1 source views
        view_ids = [ref_view] + src_views[:self.nviews - 1]

        imgs = []
        depth_values = None
        proj_matrices = []

        for i, vid in enumerate(view_ids):
            img_filename = os.path.join(self.datapath, '{}/images_post/{:0>8}.jpg'.format(scan, vid))
            if not os.path.exists(img_filename):
                img_filename = os.path.join(self.datapath, '{}/images/{:0>8}.jpg'.format(scan, vid))

            proj_mat_filename = os.path.join(self.datapath, '{}/cams/{:0>8}_cam.txt'.format(scan, vid))

            img = self.read_img(img_filename)
            intrinsics, extrinsics, depth_min, depth_interval = self.read_cam_file(proj_mat_filename, interval_scale=
                                                                                   self.interval_scale[scene_name])
            # scale input
            img, intrinsics = self.scale_mvs_input(img, intrinsics, self.max_w, self.max_h)

            if self.fix_res:
                # using the same standard height or width in entire scene.
                s_h, s_w = img.shape[:2]
                self.fix_res = False
                self.fix_wh = True

            if i == 0:
                if not self.fix_wh:
                    # using the same standard height or width in each nviews.
                    s_h, s_w = img.shape[:2]

            # resize to standard height or width
            c_h, c_w = img.shape[:2]
            if (c_h != s_h) or (c_w != s_w):
                scale_h = 1.0 * s_h / c_h
                scale_w = 1.0 * s_w / c_w
                img = cv2.resize(img, (s_w, s_h))
                intrinsics[0, :] *= scale_w
                intrinsics[1, :] *= scale_h


            imgs.append(img.transpose(2,0,1))
            # extrinsics, intrinsics
            proj_mat = np.zeros(shape=(2, 4, 4), dtype=np.float32)  #
            proj_mat[0, :4, :4] = extrinsics
            proj_mat[1, :3, :3] = intrinsics
            proj_matrices.append(proj_mat)

            if i == 0:  # reference view
                depth_values = np.arange(depth_min, depth_interval * (self.ndepths - 0.5) + depth_min, depth_interval,
                                         dtype=np.float32)

        #all
        # imgs = np.stack(imgs).transpose([0, 3, 1, 2])
        #ms proj_mats
        proj_matrices = np.stack(proj_matrices)
        stage1_pjmats = proj_matrices.copy()
        stage1_pjmats[:, 1, :2, :] = proj_matrices[:, 1, :2, :] / 2.0
        stage3_pjmats = proj_matrices.copy()
        stage3_pjmats[:, 1, :2, :] = proj_matrices[:, 1, :2, :] * 2
        stage4_pjmats = proj_matrices.copy()
        stage4_pjmats[:, 1, :2, :] = proj_matrices[:, 1, :2, :] * 4

        proj_matrices_ms = {
            "stage1": stage1_pjmats,
            "stage2": proj_matrices,
            "stage3": stage3_pjmats,
            "stage4": stage4_pjmats
        }

        return {"imgs": imgs,
                "proj_matrices": proj_matrices_ms,
                "depth_values": depth_values,
                "filename": scan + '/{}/' + '{:0>8}'.format(view_ids[0]) + "{}"}


================================================
FILE: datasets/tanks.py
================================================
from torch.utils.data import Dataset
from datasets.data_io import *
import os
import numpy as np
import cv2
from PIL import Image

class MVSDataset(Dataset):
    def __init__(self, datapath, n_views=7, split='intermediate'):
        self.levels = 4
        self.datapath = datapath
        self.split = split
        self.build_metas()
        self.n_views = n_views

    def build_metas(self):
        self.metas = []
        if self.split == 'intermediate':
            self.scans = ['Family', 'Francis', 'Horse', 'Playground', 'Train', 'Lighthouse', 'M60', 'Panther']
        elif self.split == 'advanced':
            self.scans = ['Auditorium', 'Ballroom', 'Courtroom',
                          'Museum', 'Palace', 'Temple']

        for scan in self.scans:
            with open(os.path.join(self.datapath, self.split, scan, 'pair.txt')) as f:
                num_viewpoint = int(f.readline())
                for view_idx in range(num_viewpoint):
                    ref_view = int(f.readline().rstrip())
                    src_views = [int(x) for x in f.readline().rstrip().split()[1::2]]
                    if len(src_views) != 0:
                        self.metas += [(scan, -1, ref_view, src_views)]
   
    def read_cam_file(self, filename):
        with open(filename) as f:
            lines = [line.rstrip() for line in f.readlines()]
        # extrinsics: line [1,5), 4x4 matrix
        extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ')
        extrinsics = extrinsics.reshape((4, 4))
        # intrinsics: line [7-10), 3x3 matrix
        intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ')
        intrinsics = intrinsics.reshape((3, 3))
        
        depth_min = float(lines[11].split()[0])
        depth_max = float(lines[11].split()[-1])

        return intrinsics, extrinsics, depth_min, depth_max

    def read_img(self, filename):
        img = Image.open(filename)
        np_img = np.array(img, dtype=np.float32) / 255.
        return np_img

    def scale_input(self, intrinsics, img):
        """
        intrinsics: 3x3
        img: W H C
        """
        intrinsics[1,2] =  intrinsics[1,2] - 28  # 1080 -> 1024
        img = img[28:1080-28, :, :]
        return intrinsics, img

    def __len__(self):
        return len(self.metas)

    def __getitem__(self, idx):
        scan, _, ref_view, src_views = self.metas[idx]
        # use only the reference view and first nviews-1 source views
        view_ids = [ref_view] + src_views[:self.n_views-1]

        imgs = []

        # depth = None
        depth_min = None
        depth_max = None

        proj_matrices_0 = []
        proj_matrices_1 = []
        proj_matrices_2 = []
        proj_matrices_3 = []

        for i, vid in enumerate(view_ids):
            img_filename = os.path.join(self.datapath, self.split, scan, f'images/{vid:08d}.jpg')
            proj_mat_filename = os.path.join(self.datapath, self.split, scan, f'cams/{vid:08d}_cam.txt')

            img = self.read_img(img_filename)

            intrinsics, extrinsics, depth_min_, depth_max_ = self.read_cam_file(proj_mat_filename)
            intrinsics, img = self.scale_input(intrinsics, img)
            imgs.append(img.transpose(2,0,1))

            proj_mat_0 = np.zeros(shape=(2, 4, 4), dtype=np.float32)
            proj_mat_1 = np.zeros(shape=(2, 4, 4), dtype=np.float32)
            proj_mat_2 = np.zeros(shape=(2, 4, 4), dtype=np.float32)
            proj_mat_3 = np.zeros(shape=(2, 4, 4), dtype=np.float32)

            intrinsics[:2,:] *= 0.125
            proj_mat_0[0,:4,:4] = extrinsics.copy()
            proj_mat_0[1,:3,:3] = intrinsics.copy()

            intrinsics[:2,:] *= 2
            proj_mat_1[0,:4,:4] = extrinsics.copy()
            proj_mat_1[1,:3,:3] = intrinsics.copy()

            intrinsics[:2,:] *= 2
            proj_mat_2[0,:4,:4] = extrinsics.copy()
            proj_mat_2[1,:3,:3] = intrinsics.copy()

            intrinsics[:2,:] *= 2
            proj_mat_3[0,:4,:4] = extrinsics.copy()
            proj_mat_3[1,:3,:3] = intrinsics.copy()  

            proj_matrices_0.append(proj_mat_0)
            proj_matrices_1.append(proj_mat_1)
            proj_matrices_2.append(proj_mat_2)
            proj_matrices_3.append(proj_mat_3)

            if i == 0:  # reference view
                depth_min =  depth_min_
                depth_max = depth_max_


        # proj_matrices: N*4*4
        proj={}
        proj['stage1'] = np.stack(proj_matrices_0)
        proj['stage2'] = np.stack(proj_matrices_1)
        proj['stage3'] = np.stack(proj_matrices_2)
        proj['stage4'] = np.stack(proj_matrices_3)

        return {"imgs": imgs,                   # N*3*H0*W0
                "proj_matrices": proj, # N*4*4
                "depth_values": np.array([depth_min, depth_max], dtype=np.float32),
                "filename": scan + '/{}/' + '{:0>8}'.format(view_ids[0]) + "{}"
                }  


================================================
FILE: evaluations/dtu/BaseEval2Obj_web.m
================================================
function BaseEval2Obj_web(BaseEval,method_string,outputPath)

if(nargin<3)
    outputPath='./';
end

% tresshold for coloring alpha channel in the range of 0-10 mm
dist_tresshold=10;

cSet=BaseEval.cSet;

Qdata=BaseEval.Qdata;
alpha=min(BaseEval.Ddata,dist_tresshold)/dist_tresshold;

fid=fopen([outputPath method_string '2Stl_' num2str(cSet) ' .obj'],'w+');

for cP=1:size(Qdata,2)
    if(BaseEval.DataInMask(cP))
        C=[1 0 0]*alpha(cP)+[1 1 1]*(1-alpha(cP)); %coloring from red to white in the range of 0-10 mm (0 to dist_tresshold)
    else
        C=[0 1 0]*alpha(cP)+[0 0 1]*(1-alpha(cP)); %green to blue for points outside the mask (which are not included in the analysis)
    end
    fprintf(fid,'v %f %f %f %f %f %f\n',[Qdata(1,cP) Qdata(2,cP) Qdata(3,cP) C(1) C(2) C(3)]);
end
fclose(fid);

disp('Data2Stl saved as obj')

Qstl=BaseEval.Qstl;
fid=fopen([outputPath 'Stl2' method_string '_' num2str(cSet) '.obj'],'w+');

alpha=min(BaseEval.Dstl,dist_tresshold)/dist_tresshold;

for cP=1:size(Qstl,2)
    if(BaseEval.StlAbovePlane(cP))
        C=[1 0 0]*alpha(cP)+[1 1 1]*(1-alpha(cP)); %coloring from red to white in the range of 0-10 mm (0 to dist_tresshold)
    else
        C=[0 1 0]*alpha(cP)+[0 0 1]*(1-alpha(cP)); %green to blue for points below plane (which are not included in the analysis)
    end
    fprintf(fid,'v %f %f %f %f %f %f\n',[Qstl(1,cP) Qstl(2,cP) Qstl(3,cP) C(1) C(2) C(3)]);
end
fclose(fid);

disp('Stl2Data saved as obj')

================================================
FILE: evaluations/dtu/BaseEvalMain_func.m
================================================
function None = BaseEvalMain_func(plyPath)

% clear all
% close all
format compact

% script to calculate distances have been measured for all included scans (UsedSets)

dataPath='/mnt/cfs/algorithm/public_data/mvs/dtu_evalset/SampleSet/MVS Data';
% pred_results='cascade_hr/48-32-8_4-2-1_dlossw-0.5-1.0-2.0_chs888/gipuma_4_0.9/';
% plyPath=['../../outputs/1101/dtu/' pred_results];

% plyPath = '/mnt/cfs/algorithm/xiaofeng.wang/jeff/code/MVS/cascade-stereo/CasMVSNet/outputs/repo_model_aligncorners_ITGT'
resultsPath=[plyPath '/eval_out/'];
disp(resultsPath);
mkdir(resultsPath);

method_string='mvsnet';
light_string='l3'; % l3 is the setting with all lights on, l7 is randomly sampled between the 7 settings (index 0-6)
representation_string='Points'; %mvs representation 'Points' or 'Surfaces'

switch representation_string
    case 'Points'
        eval_string='_Eval_'; %results naming
        settings_string='';
end

% get sets used in evaluation
UsedSets=[1 4 9 10 11 12 13 15 23 24 29 32 33 34 48 49 62 75 77 110 114 118];
% UsedSets=[15];

dst=0.2;    %Min dist between points when reducing

parfor cIdx=1:length(UsedSets)
    %Data set number
    cSet = UsedSets(cIdx)
    %input data name
    DataInName=[plyPath sprintf('/%s%03d_%s%s.ply',lower(method_string),cSet,light_string,settings_string)]

    %results name
    EvalName=[resultsPath method_string eval_string num2str(cSet) '.mat']

    %check if file is already computed
    if(~exist(EvalName,'file'))
        disp(DataInName);

        time=clock;time(4:5), drawnow

        tic
        Mesh = plyread(DataInName);
        Qdata=[Mesh.vertex.x Mesh.vertex.y Mesh.vertex.z]';
        toc

        BaseEval=PointCompareMain(cSet,Qdata,dst,dataPath);

        disp('Saving results'), drawnow
        toc
        mySave(EvalName, BaseEval);
        toc

        % write obj-file of evaluation
        % BaseEval2Obj_web(BaseEval,method_string, resultsPath)
        % toc
        time=clock;time(4:5), drawnow

        BaseEval.MaxDist=20; %outlier threshold of 20 mm

        BaseEval.FilteredDstl=BaseEval.Dstl(BaseEval.StlAbovePlane); %use only points that are above the plane
        BaseEval.FilteredDstl=BaseEval.FilteredDstl(BaseEval.FilteredDstl<BaseEval.MaxDist); % discard outliers

        BaseEval.FilteredDdata=BaseEval.Ddata(BaseEval.DataInMask); %use only points that within mask
        BaseEval.FilteredDdata=BaseEval.FilteredDdata(BaseEval.FilteredDdata<BaseEval.MaxDist); % discard outliers

        fprintf("mean/median Data (acc.) %f/%f\n", mean(BaseEval.FilteredDdata), median(BaseEval.FilteredDdata));
        fprintf("mean/median Stl (comp.) %f/%f\n", mean(BaseEval.FilteredDstl), median(BaseEval.FilteredDstl));
    end
end

end

function mySave(filenm, data)
    save(filenm, 'data');
end

================================================
FILE: evaluations/dtu/BaseEvalMain_web.m
================================================
clear all
close all
format compact
clc

% script to calculate distances have been measured for all included scans (UsedSets)

dataPath='/mnt/cfs/algorithm/public_data/mvs/dtu_evalset/SampleSet/MVS Data';
% pred_results='cascade_hr/48-32-8_4-2-1_dlossw-0.5-1.0-2.0_chs888/gipuma_4_0.9/';
% plyPath=['../../outputs/1101/dtu/' pred_results];

plyPath = '/mnt/cfs/algorithm/xiaofeng.wang/jeff/code/MVS/cascade-stereo/CasMVSNet/outputs/ccc_4x2_scedule_aligncorners'
resultsPath=[plyPath '/eval_out/'];
disp(resultsPath);
mkdir(resultsPath);

method_string='mvsnet';
light_string='l3'; % l3 is the setting with all lights on, l7 is randomly sampled between the 7 settings (index 0-6)
representation_string='Points'; %mvs representation 'Points' or 'Surfaces'

switch representation_string
    case 'Points'
        eval_string='_Eval_'; %results naming
        settings_string='';
end

% get sets used in evaluation
UsedSets=[1 4 9 10 11 12 13 15 23 24 29 32 33 34 48 49 62 75 77 110 114 118];
% UsedSets=[15];

dst=0.2;    %Min dist between points when reducing

parfor cIdx=1:length(UsedSets)
    %Data set number
    cSet = UsedSets(cIdx)
    %input data name
    DataInName=[plyPath sprintf('/%s%03d_%s%s.ply',lower(method_string),cSet,light_string,settings_string)]

    %results name
    EvalName=[resultsPath method_string eval_string num2str(cSet) '.mat']

    %check if file is already computed
    if(~exist(EvalName,'file'))
        disp(DataInName);

        time=clock;time(4:5), drawnow

        tic
        Mesh = plyread(DataInName);
        Qdata=[Mesh.vertex.x Mesh.vertex.y Mesh.vertex.z]';
        toc

        BaseEval=PointCompareMain(cSet,Qdata,dst,dataPath);

        disp('Saving results'), drawnow
        toc
        mySave(EvalName, BaseEval);
        toc

        % write obj-file of evaluation
        % BaseEval2Obj_web(BaseEval,method_string, resultsPath)
        % toc
        time=clock;time(4:5), drawnow

        BaseEval.MaxDist=20; %outlier threshold of 20 mm

        BaseEval.FilteredDstl=BaseEval.Dstl(BaseEval.StlAbovePlane); %use only points that are above the plane
        BaseEval.FilteredDstl=BaseEval.FilteredDstl(BaseEval.FilteredDstl<BaseEval.MaxDist); % discard outliers

        BaseEval.FilteredDdata=BaseEval.Ddata(BaseEval.DataInMask); %use only points that within mask
        BaseEval.FilteredDdata=BaseEval.FilteredDdata(BaseEval.FilteredDdata<BaseEval.MaxDist); % discard outliers

        fprintf("mean/median Data (acc.) %f/%f\n", mean(BaseEval.FilteredDdata), median(BaseEval.FilteredDdata));
        fprintf("mean/median Stl (comp.) %f/%f\n", mean(BaseEval.FilteredDstl), median(BaseEval.FilteredDstl));
    end
end


function mySave(filenm, data)
    save(filenm, 'data');
end

================================================
FILE: evaluations/dtu/ComputeStat_func.m
================================================
function None = ComputeStat_func(plyPath)
format compact

% script to calculate the statistics for each scan given this will currently only run if distances have been measured
% for all included scans (UsedSets)

% modify the path to evaluate your models
dataPath='/mnt/cfs/algorithm/public_data/mvs/dtu_evalset/SampleSet/MVS Data';
% resultsPath=['/mnt/cfs/algorithm/xiaofeng.wang/jeff/code/MVS/cascade-stereo/CasMVSNet/outputs/repo_model_aligncorners_ITGT/eval_out/'];
resultsPath=[plyPath '/eval_out/'];

MaxDist=20; %outlier thresshold of 20 mm

time=clock;

method_string='mvsnet';
light_string='l3'; %'l7'; l3 is the setting with all lights on, l7 is randomly sampled between the 7 settings (index 0-6)
representation_string='Points'; %mvs representation 'Points' or 'Surfaces'

switch representation_string
    case 'Points'
        eval_string='_Eval_'; %results naming
        settings_string='';
end

% get sets used in evaluation
UsedSets=[1 4 9 10 11 12 13 15 23 24 29 32 33 34 48 49 62 75 77 110 114 118];

nStat=length(UsedSets);

BaseStat.nStl=zeros(1,nStat);
BaseStat.nData=zeros(1,nStat);
BaseStat.MeanStl=zeros(1,nStat);
BaseStat.MeanData=zeros(1,nStat);
BaseStat.VarStl=zeros(1,nStat);
BaseStat.VarData=zeros(1,nStat);
BaseStat.MedStl=zeros(1,nStat);
BaseStat.MedData=zeros(1,nStat);

for cStat=1:length(UsedSets) %Data set number

    currentSet=UsedSets(cStat);

    %input results name
    EvalName=[resultsPath method_string eval_string num2str(currentSet) '.mat'];

    disp(EvalName);
    load(EvalName);

    Dstl=data.Dstl(data.StlAbovePlane); %use only points that are above the plane
    Dstl=Dstl(Dstl<MaxDist); % discard outliers

    Ddata=data.Ddata(data.DataInMask); %use only points that within mask
    Ddata=Ddata(Ddata<MaxDist); % discard outliers

    BaseStat.nStl(cStat)=length(Dstl);
    BaseStat.nData(cStat)=length(Ddata);

    BaseStat.MeanStl(cStat)=mean(Dstl);
    BaseStat.MeanData(cStat)=mean(Ddata);

    BaseStat.VarStl(cStat)=var(Dstl);
    BaseStat.VarData(cStat)=var(Ddata);

    BaseStat.MedStl(cStat)=median(Dstl);
    BaseStat.MedData(cStat)=median(Ddata);

    disp("acc");
    disp(mean(Ddata));
    disp("comp");
    disp(mean(Dstl));
    time=clock;
end

disp(BaseStat);
disp("mean acc")
disp(mean(BaseStat.MeanData));
disp("mean comp")
disp(mean(BaseStat.MeanStl));
disp("mean overall")
disp((mean(BaseStat.MeanStl)+mean(BaseStat.MeanData))/2.0);

totalStatName=[resultsPath 'TotalStat_' method_string eval_string '.mat']
save(totalStatName,'BaseStat','time','MaxDist');

totalStatName=[resultsPath 'TotalStat_' method_string eval_string '.txt']
fp=fopen(totalStatName,'a');
fprintf(fp,'%f\n',mean(BaseStat.MeanData));
fprintf(fp,'%f\n',mean(BaseStat.MeanStl));
end

================================================
FILE: evaluations/dtu/ComputeStat_web.m
================================================
clear all
close all
format compact
clc

% script to calculate the statistics for each scan given this will currently only run if distances have been measured
% for all included scans (UsedSets)

% modify the path to evaluate your models
dataPath='/mnt/cfs/algorithm/public_data/mvs/dtu_evalset/SampleSet/MVS Data';
resultsPath=['/mnt/cfs/algorithm/xiaofeng.wang/jeff/code/MVS/cascade-stereo/CasMVSNet/outputs/repo_model_aligncorners_ITGT/eval_out/'];

MaxDist=20; %outlier thresshold of 20 mm

time=clock;

method_string='mvsnet';
light_string='l3'; %'l7'; l3 is the setting with all lights on, l7 is randomly sampled between the 7 settings (index 0-6)
representation_string='Points'; %mvs representation 'Points' or 'Surfaces'

switch representation_string
    case 'Points'
        eval_string='_Eval_'; %results naming
        settings_string='';
end

% get sets used in evaluation
UsedSets=[1 4 9 10 11 12 13 15 23 24 29 32 33 34 48 49 62 75 77 110 114 118];

nStat=length(UsedSets);

BaseStat.nStl=zeros(1,nStat);
BaseStat.nData=zeros(1,nStat);
BaseStat.MeanStl=zeros(1,nStat);
BaseStat.MeanData=zeros(1,nStat);
BaseStat.VarStl=zeros(1,nStat);
BaseStat.VarData=zeros(1,nStat);
BaseStat.MedStl=zeros(1,nStat);
BaseStat.MedData=zeros(1,nStat);

for cStat=1:length(UsedSets) %Data set number

    currentSet=UsedSets(cStat);

    %input results name
    EvalName=[resultsPath method_string eval_string num2str(currentSet) '.mat'];

    disp(EvalName);
    load(EvalName);

    Dstl=data.Dstl(data.StlAbovePlane); %use only points that are above the plane
    Dstl=Dstl(Dstl<MaxDist); % discard outliers

    Ddata=data.Ddata(data.DataInMask); %use only points that within mask
    Ddata=Ddata(Ddata<MaxDist); % discard outliers

    BaseStat.nStl(cStat)=length(Dstl);
    BaseStat.nData(cStat)=length(Ddata);

    BaseStat.MeanStl(cStat)=mean(Dstl);
    BaseStat.MeanData(cStat)=mean(Ddata);

    BaseStat.VarStl(cStat)=var(Dstl);
    BaseStat.VarData(cStat)=var(Ddata);

    BaseStat.MedStl(cStat)=median(Dstl);
    BaseStat.MedData(cStat)=median(Ddata);

    disp("acc");
    disp(mean(Ddata));
    disp("comp");
    disp(mean(Dstl));
    time=clock;
end

disp(BaseStat);
disp("mean acc")
disp(mean(BaseStat.MeanData));
disp("mean comp")
disp(mean(BaseStat.MeanStl));
disp("mean overall")
disp((mean(BaseStat.MeanStl)+mean(BaseStat.MeanData))/2.0);

totalStatName=[resultsPath 'TotalStat_' method_string eval_string '.mat']
save(totalStatName,'BaseStat','time','MaxDist');

totalStatName=[resultsPath 'TotalStat_' method_string eval_string '.txt']
fp=fopen(totalStatName,'a');
fprintf(fp,'%f\n',mean(BaseStat.MeanData));
fprintf(fp,'%f\n',mean(BaseStat.MeanStl));


================================================
FILE: evaluations/dtu/MaxDistCP.m
================================================
function Dist = MaxDistCP(Qto,Qfrom,BB,MaxDist)

Dist=ones(1,size(Qfrom,2))*MaxDist;

Range=floor((BB(2,:)-BB(1,:))/MaxDist);

tic
Done=0;
LookAt=zeros(1,size(Qfrom,2));
for x=0:Range(1),
    for y=0:Range(2),
        for z=0:Range(3),
            
            Low=BB(1,:)+[x y z]*MaxDist;
            High=Low+MaxDist;
            
            idxF=find(Qfrom(1,:)>=Low(1) & Qfrom(2,:)>=Low(2) & Qfrom(3,:)>=Low(3) &...
                Qfrom(1,:)<High(1) & Qfrom(2,:)<High(2) & Qfrom(3,:)<High(3));
            SQfrom=Qfrom(:,idxF);
            LookAt(idxF)=LookAt(idxF)+1; %Debug
            
            Low=Low-MaxDist;
            High=High+MaxDist;
            idxT=find(Qto(1,:)>=Low(1) & Qto(2,:)>=Low(2) & Qto(3,:)>=Low(3) &...
                Qto(1,:)<High(1) & Qto(2,:)<High(2) & Qto(3,:)<High(3));
            SQto=Qto(:,idxT);
            
            if(isempty(SQto))
                Dist(idxF)=MaxDist;
            else
                KDstl=KDTreeSearcher(SQto');
                [~,SDist] = knnsearch(KDstl,SQfrom');
                Dist(idxF)=SDist;
                
            end
            
            Done=Done+length(idxF); %Debug
            
        end
    end
    %Complete=Done/size(Qfrom,2);
    %EstTime=(toc/Complete)/60
    %toc
    %LA=[sum(LookAt==0),...
    %	sum(LookAt==1),...
   % 	sum(LookAt==2),...
   % 	sum(LookAt==3),...
   % 	sum(LookAt>3)]
end



================================================
FILE: evaluations/dtu/PointCompareMain.m
================================================
function BaseEval=PointCompareMain(cSet,Qdata,dst,dataPath)
% evaluation function the calculates the distantes from the reference data (stl) to the evalution points (Qdata) and the
% distances from the evaluation points to the reference

tic
% reduce points 0.2 mm neighbourhood density
Qdata=reducePts_haa(Qdata,dst);
toc

StlInName=[dataPath '/Points/stl/stl' sprintf('%03d',cSet) '_total.ply'];

StlMesh = plyread(StlInName);  %STL points already reduced 0.2 mm neighbourhood density
Qstl=[StlMesh.vertex.x StlMesh.vertex.y StlMesh.vertex.z]';

%Load Mask (ObsMask) and Bounding box (BB) and Resolution (Res)
Margin=10;
MaskName=[dataPath '/ObsMask/ObsMask' num2str(cSet) '_' num2str(Margin) '.mat'];
load(MaskName)

MaxDist=60;
disp('Computing Data 2 Stl distances')
Ddata = MaxDistCP(Qstl,Qdata,BB,MaxDist);
toc

disp('Computing Stl 2 Data distances')
Dstl=MaxDistCP(Qdata,Qstl,BB,MaxDist);
disp('Distances computed')
toc

%use mask
%From Get mask - inverted & modified.
One=ones(1,size(Qdata,2));
Qv=(Qdata-BB(1,:)'*One)/Res+1;
Qv=round(Qv);

Midx1=find(Qv(1,:)>0 & Qv(1,:)<=size(ObsMask,1) & Qv(2,:)>0 & Qv(2,:)<=size(ObsMask,2) & Qv(3,:)>0 & Qv(3,:)<=size(ObsMask,3));
MidxA=sub2ind(size(ObsMask),Qv(1,Midx1),Qv(2,Midx1),Qv(3,Midx1));
Midx2=find(ObsMask(MidxA));

BaseEval.DataInMask(1:size(Qv,2))=false;
BaseEval.DataInMask(Midx1(Midx2))=true; %If Data is within the mask

BaseEval.cSet=cSet;
BaseEval.Margin=Margin;         %Margin of masks
BaseEval.dst=dst;               %Min dist between points when reducing
BaseEval.Qdata=Qdata;           %Input data points
BaseEval.Ddata=Ddata;           %distance from data to stl
BaseEval.Qstl=Qstl;             %Input stl points
BaseEval.Dstl=Dstl;             %Distance from the stl to data

load([dataPath '/ObsMask/Plane' num2str(cSet)],'P')
BaseEval.GroundPlane=P;         % Plane used to destinguise which Stl points are 'used'
BaseEval.StlAbovePlane=(P'*[Qstl;ones(1,size(Qstl,2))])>0; %Is stl above 'ground plane'
BaseEval.Time=clock;            %Time when computation is finished






================================================
FILE: evaluations/dtu/plyread.m
================================================
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [Elements,varargout] = plyread(Path,Str)
%PLYREAD   Read a PLY 3D data file.
%   [DATA,COMMENTS] = PLYREAD(FILENAME) reads a version 1.0 PLY file
%   FILENAME and returns a structure DATA.  The fields in this structure
%   are defined by the PLY header; each element type is a field and each
%   element property is a subfield.  If the file contains any comments,
%   they are returned in a cell string array COMMENTS.
%
%   [TRI,PTS] = PLYREAD(FILENAME,'tri') or
%   [TRI,PTS,DATA,COMMENTS] = PLYREAD(FILENAME,'tri') converts vertex
%   and face data into triangular connectivity and vertex arrays.  The
%   mesh can then be displayed using the TRISURF command.
%
%   Note: This function is slow for large mesh files (+50K faces),
%   especially when reading data with list type properties.
%
%   Example:
%   [Tri,Pts] = PLYREAD('cow.ply','tri');
%   trisurf(Tri,Pts(:,1),Pts(:,2),Pts(:,3)); 
%   colormap(gray); axis equal;
%
%   See also: PLYWRITE

% Pascal Getreuer 2004

[fid,Msg] = fopen(Path,'rt');	% open file in read text mode

if fid == -1, error(Msg); end

Buf = fscanf(fid,'%s',1);
if ~strcmp(Buf,'ply')
   fclose(fid);
   error('Not a PLY file.'); 
end


%%% read header %%%

Position = ftell(fid);
Format = '';
NumComments = 0;
Comments = {};				% for storing any file comments
NumElements = 0;
NumProperties = 0;
Elements = [];				% structure for holding the element data
ElementCount = [];		% number of each type of element in file
PropertyTypes = [];		% corresponding structure recording property types
ElementNames = {};		% list of element names in the order they are stored in the file
PropertyNames = [];		% structure of lists of property names

while 1
   Buf = fgetl(fid);   								% read one line from file
   BufRem = Buf;
   Token = {};
   Count = 0;
   
   while ~isempty(BufRem)								% split line into tokens
      [tmp,BufRem] = strtok(BufRem);
      
      if ~isempty(tmp)
         Count = Count + 1;							% count tokens
         Token{Count} = tmp;
      end
   end
   
   if Count 		% parse line
      switch lower(Token{1})
      case 'format'		% read data format
         if Count >= 2
            Format = lower(Token{2});
            
            if Count == 3 & ~strcmp(Token{3},'1.0')
               fclose(fid);
               error('Only PLY format version 1.0 supported.');
            end
         end
      case 'comment'		% read file comment
         NumComments = NumComments + 1;
         Comments{NumComments} = '';
         for i = 2:Count
            Comments{NumComments} = [Comments{NumComments},Token{i},' '];
         end
      case 'element'		% element name
         if Count >= 3
            if isfield(Elements,Token{2})
               fclose(fid);
               error(['Duplicate element name, ''',Token{2},'''.']);
            end
            
            NumElements = NumElements + 1;
            NumProperties = 0;
   	      Elements = setfield(Elements,Token{2},[]);
            PropertyTypes = setfield(PropertyTypes,Token{2},[]);
            ElementNames{NumElements} = Token{2};
            PropertyNames = setfield(PropertyNames,Token{2},{});
            CurElement = Token{2};
            ElementCount(NumElements) = str2double(Token{3});
            
            if isnan(ElementCount(NumElements))
               fclose(fid);
               error(['Bad element definition: ',Buf]); 
            end            
         else
            error(['Bad element definition: ',Buf]);
         end         
      case 'property'	% element property
         if ~isempty(CurElement) & Count >= 3            
            NumProperties = NumProperties + 1;
            eval(['tmp=isfield(Elements.',CurElement,',Token{Count});'],...
               'fclose(fid);error([''Error reading property: '',Buf])');
            
            if tmp
               error(['Duplicate property name, ''',CurElement,'.',Token{2},'''.']);
            end            
            
            % add property subfield to Elements
            eval(['Elements.',CurElement,'.',Token{Count},'=[];'], ...
               'fclose(fid);error([''Error reading property: '',Buf])');            
            % add property subfield to PropertyTypes and save type
            eval(['PropertyTypes.',CurElement,'.',Token{Count},'={Token{2:Count-1}};'], ...
               'fclose(fid);error([''Error reading property: '',Buf])');            
            % record property name order 
            eval(['PropertyNames.',CurElement,'{NumProperties}=Token{Count};'], ...
               'fclose(fid);error([''Error reading property: '',Buf])');
         else
            fclose(fid);
            
            if isempty(CurElement)            
               error(['Property definition without element definition: ',Buf]);
            else               
               error(['Bad property definition: ',Buf]);
            end            
         end         
      case 'end_header'	% end of header, break from while loop
         break;		
      end
   end
end

%%% set reading for specified data format %%%

if isempty(Format)
	warning('Data format unspecified, assuming ASCII.');
   Format = 'ascii';
end

switch Format
case 'ascii'
   Format = 0;
case 'binary_little_endian'
   Format = 1;
case 'binary_big_endian'
   Format = 2;
otherwise
   fclose(fid);
   error(['Data format ''',Format,''' not supported.']);
end

if ~Format   
   Buf = fscanf(fid,'%f');		% read the rest of the file as ASCII data
   BufOff = 1;
else
   % reopen the file in read binary mode
   fclose(fid);
   
   if Format == 1
      fid = fopen(Path,'r','ieee-le.l64');		% little endian
   else
      fid = fopen(Path,'r','ieee-be.l64');		% big endian
   end
   
   % find the end of the header again (using ftell on the old handle doesn't give the correct position)   
   BufSize = 8192;
   Buf = [blanks(10),char(fread(fid,BufSize,'uchar')')];
   i = [];
   tmp = -11;
   
   while isempty(i)
   	i = findstr(Buf,['end_header',13,10]);			% look for end_header + CR/LF
   	i = [i,findstr(Buf,['end_header',10])];		% look for end_header + LF
      
      if isempty(i)
         tmp = tmp + BufSize;
         Buf = [Buf(BufSize+1:BufSize+10),char(fread(fid,BufSize,'uchar')')];
      end
   end
   
   % seek to just after the line feed
   fseek(fid,i + tmp + 11 + (Buf(i + 10) == 13),-1);
end


%%% read element data %%%

% PLY and MATLAB data types (for fread)
PlyTypeNames = {'char','uchar','short','ushort','int','uint','float','double', ...
   'char8','uchar8','short16','ushort16','int32','uint32','float32','double64'};
MatlabTypeNames = {'schar','uchar','int16','uint16','int32','uint32','single','double'};
SizeOf = [1,1,2,2,4,4,4,8];	% size in bytes of each type

for i = 1:NumElements
   % get current element property information
   eval(['CurPropertyNames=PropertyNames.',ElementNames{i},';']);
   eval(['CurPropertyTypes=PropertyTypes.',ElementNames{i},';']);
   NumProperties = size(CurPropertyNames,2);
   
%   fprintf('Reading %s...\n',ElementNames{i});
      
   if ~Format	%%% read ASCII data %%%
      for j = 1:NumProperties
         Token = getfield(CurPropertyTypes,CurPropertyNames{j});
         
         if strcmpi(Token{1},'list')
            Type(j) = 1;
         else
            Type(j) = 0;
			end
      end
      
      % parse buffer
      if ~any(Type)
         % no list types
         Data = reshape(Buf(BufOff:BufOff+ElementCount(i)*NumProperties-1),NumProperties,ElementCount(i))';
         BufOff = BufOff + ElementCount(i)*NumProperties;
      else
         ListData = cell(NumProperties,1);
         
         for k = 1:NumProperties
            ListData{k} = cell(ElementCount(i),1);
         end
         
         % list type
		   for j = 1:ElementCount(i)
   	      for k = 1:NumProperties
      	      if ~Type(k)
         	      Data(j,k) = Buf(BufOff);
            	   BufOff = BufOff + 1;
	            else
   	            tmp = Buf(BufOff);
      	         ListData{k}{j} = Buf(BufOff+(1:tmp))';
         	      BufOff = BufOff + tmp + 1;
            	end
            end
         end
      end
   else		%%% read binary data %%%
      % translate PLY data type names to MATLAB data type names
      ListFlag = 0;		% = 1 if there is a list type 
      SameFlag = 1;     % = 1 if all types are the same
      
      for j = 1:NumProperties
         Token = getfield(CurPropertyTypes,CurPropertyNames{j});
         
         if ~strcmp(Token{1},'list')			% non-list type
	         tmp = rem(strmatch(Token{1},PlyTypeNames,'exact')-1,8)+1;
         
            if ~isempty(tmp)
               TypeSize(j) = SizeOf(tmp);
               Type{j} = MatlabTypeNames{tmp};
               TypeSize2(j) = 0;
               Type2{j} = '';
               
               SameFlag = SameFlag & strcmp(Type{1},Type{j});
	         else
   	         fclose(fid);
               error(['Unknown property data type, ''',Token{1},''', in ', ...
                     ElementNames{i},'.',CurPropertyNames{j},'.']);
         	end
         else											% list type
            if length(Token) == 3
               ListFlag = 1;
               SameFlag = 0;
               tmp = rem(strmatch(Token{2},PlyTypeNames,'exact')-1,8)+1;
               tmp2 = rem(strmatch(Token{3},PlyTypeNames,'exact')-1,8)+1;
         
               if ~isempty(tmp) & ~isempty(tmp2)
                  TypeSize(j) = SizeOf(tmp);
                  Type{j} = MatlabTypeNames{tmp};
                  TypeSize2(j) = SizeOf(tmp2);
                  Type2{j} = MatlabTypeNames{tmp2};
	   	      else
   	   	      fclose(fid);
               	error(['Unknown property data type, ''list ',Token{2},' ',Token{3},''', in ', ...
                        ElementNames{i},'.',CurPropertyNames{j},'.']);
               end
            else
               fclose(fid);
               error(['Invalid list syntax in ',ElementNames{i},'.',CurPropertyNames{j},'.']);
            end
         end
      end
      
      % read file
      if ~ListFlag
         if SameFlag
            % no list types, all the same type (fast)
            Data = fread(fid,[NumProperties,ElementCount(i)],Type{1})';
         else
            % no list types, mixed type
            Data = zeros(ElementCount(i),NumProperties);
            
         	for j = 1:ElementCount(i)
        			for k = 1:NumProperties
               	Data(j,k) = fread(fid,1,Type{k});
              	end
         	end
         end
      else
         ListData = cell(NumProperties,1);
         
         for k = 1:NumProperties
            ListData{k} = cell(ElementCount(i),1);
         end
         
         if NumProperties == 1
            BufSize = 512;
            SkipNum = 4;
            j = 0;
            
            % list type, one property (fast if lists are usually the same length)
            while j < ElementCount(i)
               Position = ftell(fid);
               % read in BufSize count values, assuming all counts = SkipNum
               [Buf,BufSize] = fread(fid,BufSize,Type{1},SkipNum*TypeSize2(1));
               Miss = find(Buf ~= SkipNum);					% find first count that is not SkipNum
               fseek(fid,Position + TypeSize(1),-1); 		% seek back to after first count                              
               
               if isempty(Miss)									% all counts are SkipNum
                  Buf = fread(fid,[SkipNum,BufSize],[int2str(SkipNum),'*',Type2{1}],TypeSize(1))';
                  fseek(fid,-TypeSize(1),0); 				% undo last skip
                  
                  for k = 1:BufSize
                     ListData{1}{j+k} = Buf(k,:);
                  end
                  
                  j = j + BufSize;
                  BufSize = floor(1.5*BufSize);
               else
                  if Miss(1) > 1									% some counts are SkipNum
                     Buf2 = fread(fid,[SkipNum,Miss(1)-1],[int2str(SkipNum),'*',Type2{1}],TypeSize(1))';                     
                     
                     for k = 1:Miss(1)-1
                        ListData{1}{j+k} = Buf2(k,:);
                     end
                     
                     j = j + k;
                  end
                  
                  % read in the list with the missed count
                  SkipNum = Buf(Miss(1));
                  j = j + 1;
                  ListData{1}{j} = fread(fid,[1,SkipNum],Type2{1});
                  BufSize = ceil(0.6*BufSize);
               end
            end
         else
            % list type(s), multiple properties (slow)
            Data = zeros(ElementCount(i),NumProperties);
            
            for j = 1:ElementCount(i)
         		for k = 1:NumProperties
            		if isempty(Type2{k})
               		Data(j,k) = fread(fid,1,Type{k});
            		else
               		tmp = fread(fid,1,Type{k});
               		ListData{k}{j} = fread(fid,[1,tmp],Type2{k});
		            end
      		   end
      		end
         end
      end
   end
   
   % put data into Elements structure
   for k = 1:NumProperties
   	if (~Format & ~Type(k)) | (Format & isempty(Type2{k}))
      	eval(['Elements.',ElementNames{i},'.',CurPropertyNames{k},'=Data(:,k);']);
      else
      	eval(['Elements.',ElementNames{i},'.',CurPropertyNames{k},'=ListData{k};']);
		end
   end
end

clear Data ListData;
fclose(fid);

if (nargin > 1 & strcmpi(Str,'Tri')) | nargout > 2   
   % find vertex element field
   Name = {'vertex','Vertex','point','Point','pts','Pts'};
   Names = [];
   
   for i = 1:length(Name)
      if any(strcmp(ElementNames,Name{i}))
         Names = getfield(PropertyNames,Name{i});
         Name = Name{i};         
         break;
      end
   end
   
   if any(strcmp(Names,'x')) & any(strcmp(Names,'y')) & any(strcmp(Names,'z'))
      eval(['varargout{1}=[Elements.',Name,'.x,Elements.',Name,'.y,Elements.',Name,'.z];']);
   else
      varargout{1} = zeros(1,3);
	end
           
   varargout{2} = Elements;
   varargout{3} = Comments;
   Elements = [];
   
   % find face element field
   Name = {'face','Face','poly','Poly','tri','Tri'};
   Names = [];
   
   for i = 1:length(Name)
      if any(strcmp(ElementNames,Name{i}))
         Names = getfield(PropertyNames,Name{i});
         Name = Name{i};
         break;
      end
   end
   
   if ~isempty(Names)
      % find vertex indices property subfield
	   PropertyName = {'vertex_indices','vertex_indexes','vertex_index','indices','indexes'};           
      
   	for i = 1:length(PropertyName)
      	if any(strcmp(Names,PropertyName{i}))
         	PropertyName = PropertyName{i};
	         break;
   	   end
      end
      
      if ~iscell(PropertyName)
         % convert face index lists to triangular connectivity
         eval(['FaceIndices=varargout{2}.',Name,'.',PropertyName,';']);
  			N = length(FaceIndices);
   		Elements = zeros(N*2,3);
   		Extra = 0;   

			for k = 1:N
   			Elements(k,:) = FaceIndices{k}(1:3);
   
   			for j = 4:length(FaceIndices{k})
      			Extra = Extra + 1;      
	      		Elements(N + Extra,:) = [Elements(k,[1,j-1]),FaceIndices{k}(j)];
   			end
         end
         Elements = Elements(1:N+Extra,:) + 1;
      end
   end
else
   varargout{1} = Comments;
end

================================================
FILE: evaluations/dtu/reducePts_haa.m
================================================
function [ptsOut,indexSet] = reducePts_haa(pts, dst)

%Reduces a point set, pts, in a stochastic manner, such that the minimum sdistance
% between points is 'dst'. Writen by abd, edited by haa, then by raje

nPoints=size(pts,2);

indexSet=true(nPoints,1);
RandOrd=randperm(nPoints);

%tic
NS = KDTreeSearcher(pts');
%toc

% search the KNTree for close neighbours in a chunk-wise fashion to save memory if point cloud is really big
Chunks=1:min(4e6,nPoints-1):nPoints;
Chunks(end)=nPoints;

for cChunk=1:(length(Chunks)-1)
    Range=Chunks(cChunk):Chunks(cChunk+1);
    
    idx = rangesearch(NS,pts(:,RandOrd(Range))',dst);
    
    for i = 1:size(idx,1)
        id =RandOrd(i-1+Chunks(cChunk));
        if (indexSet(id))
            indexSet(idx{i}) = 0;
            indexSet(id) = 1;
        end
    end
end

ptsOut = pts(:,indexSet);

disp(['downsample factor: ' num2str(nPoints/sum(indexSet))]);


================================================
FILE: lists/blendedmvs/train.txt
================================================
5c1f33f1d33e1f2e4aa6dda4
5bfe5ae0fe0ea555e6a969ca
5bff3c5cfe0ea555e6bcbf3a
58eaf1513353456af3a1682a
5bfc9d5aec61ca1dd69132a2
5bf18642c50e6f7f8bdbd492
5bf26cbbd43923194854b270
5bf17c0fd439231948355385
5be3ae47f44e235bdbbc9771
5be3a5fb8cfdd56947f6b67c
5bbb6eb2ea1cfa39f1af7e0c
5ba75d79d76ffa2c86cf2f05
5bb7a08aea1cfa39f1a947ab
5b864d850d072a699b32f4ae
5b6eff8b67b396324c5b2672
5b6e716d67b396324c2d77cb
5b69cc0cb44b61786eb959bf
5b62647143840965efc0dbde
5b60fa0c764f146feef84df0
5b558a928bbfb62204e77ba2
5b271079e0878c3816dacca4
5b08286b2775267d5b0634ba
5afacb69ab00705d0cefdd5b
5af28cea59bc705737003253
5af02e904c8216544b4ab5a2
5aa515e613d42d091d29d300
5c34529873a8df509ae57b58
5c34300a73a8df509add216d
5c1af2e2bee9a723c963d019
5c1892f726173c3a09ea9aeb
5c0d13b795da9479e12e2ee9
5c062d84a96e33018ff6f0a6
5bfd0f32ec61ca1dd69dc77b
5bf21799d43923194842c001
5bf3a82cd439231948877aed
5bf03590d4392319481971dc
5beb6e66abd34c35e18e66b9
5be883a4f98cee15019d5b83
5be47bf9b18881428d8fbc1d
5bcf979a6d5f586b95c258cd
5bce7ac9ca24970bce4934b6
5bb8a49aea1cfa39f1aa7f75
5b78e57afc8fcf6781d0c3ba
5b21e18c58e2823a67a10dd8
5b22269758e2823a67a3bd03
5b192eb2170cf166458ff886
5ae2e9c5fe405c5076abc6b2
5adc6bd52430a05ecb2ffb85
5ab8b8e029f5351f7f2ccf59
5abc2506b53b042ead637d86
5ab85f1dac4291329b17cb50
5a969eea91dfc339a9a3ad2c
5a8aa0fab18050187cbe060e
5a7d3db14989e929563eb153
5a69c47d0d5d0a7f3b2e9752
5a618c72784780334bc1972d
5a6464143d809f1d8208c43c
5a588a8193ac3d233f77fbca
5a57542f333d180827dfc132
5a572fd9fc597b0478a81d14
5a563183425d0f5186314855
5a4a38dad38c8a075495b5d2
5a48d4b2c7dab83a7d7b9851
5a489fb1c7dab83a7d7b1070
5a48ba95c7dab83a7d7b44ed
5a3ca9cb270f0e3f14d0eddb
5a3cb4e4270f0e3f14d12f43
5a3f4aba5889373fbbc5d3b5
5a0271884e62597cdee0d0eb
59e864b2a9e91f2c5529325f
599aa591d5b41f366fed0d58
59350ca084b7f26bf5ce6eb8
59338e76772c3e6384afbb15
5c20ca3a0843bc542d94e3e2
5c1dbf200843bc542d8ef8c4
5c1b1500bee9a723c96c3e78
5bea87f4abd34c35e1860ab5
5c2b3ed5e611832e8aed46bf
57f8d9bbe73f6760f10e916a
5bf7d63575c26f32dbf7413b
5be4ab93870d330ff2dce134
5bd43b4ba6b28b1ee86b92dd
5bccd6beca24970bce448134
5bc5f0e896b66a2cd8f9bd36
5b908d3dc6ab78485f3d24a9
5b2c67b5e0878c381608b8d8
5b4933abf2b5f44e95de482a
5b3b353d8d46a939f93524b9
5acf8ca0f3d8a750097e4b15
5ab8713ba3799a1d138bd69a
5aa235f64a17b335eeaf9609
5aa0f9d7a9efce63548c69a1
5a8315f624b8e938486e0bd8
5a48c4e9c7dab83a7d7b5cc7
59ecfd02e225f6492d20fcc9
59f87d0bfa6280566fb38c9a
59f363a8b45be22330016cad
59f70ab1e5c5d366af29bf3e
59e75a2ca9e91f2c5526005d
5947719bf1b45630bd096665
5947b62af1b45630bd0c2a02
59056e6760bb961de55f3501
58f7f7299f5b5647873cb110
58cf4771d0f5fb221defe6da
58d36897f387231e6c929903
58c4bb4f4a69c55606122be4


================================================
FILE: lists/blendedmvs/val.txt
================================================
5b7a3890fc8fcf6781e2593a
5c189f2326173c3a09ed7ef3
5b950c71608de421b1e7318f
5a6400933d809f1d8200af15
59d2657f82ca7774b1ec081d
5ba19a8a360c7c30c1c169df
59817e4a1bd4b175e7038d19


================================================
FILE: lists/dtu/test.txt
================================================
scan1
scan4
scan9
scan10
scan11
scan12
scan13
scan15
scan23
scan24
scan29
scan32
scan33
scan34
scan48
scan49
scan62
scan75
scan77
scan110
scan114
scan118

================================================
FILE: lists/dtu/train.txt
================================================
scan2
scan6
scan7
scan8
scan14
scan16
scan18
scan19
scan20
scan22
scan30
scan31
scan36
scan39
scan41
scan42
scan44
scan45
scan46
scan47
scan50
scan51
scan52
scan53
scan55
scan57
scan58
scan60
scan61
scan63
scan64
scan65
scan68
scan69
scan70
scan71
scan72
scan74
scan76
scan83
scan84
scan85
scan87
scan88
scan89
scan90
scan91
scan92
scan93
scan94
scan95
scan96
scan97
scan98
scan99
scan100
scan101
scan102
scan103
scan104
scan105
scan107
scan108
scan109
scan111
scan112
scan113
scan115
scan116
scan119
scan120
scan121
scan122
scan123
scan124
scan125
scan126
scan127
scan128

================================================
FILE: lists/dtu/trainval.txt
================================================
scan2
scan6
scan7
scan8
scan14
scan16
scan18
scan19
scan20
scan22
scan30
scan31
scan36
scan39
scan41
scan42
scan44
scan45
scan46
scan47
scan50
scan51
scan52
scan53
scan55
scan57
scan58
scan60
scan61
scan63
scan64
scan65
scan68
scan69
scan70
scan71
scan72
scan74
scan76
scan83
scan84
scan85
scan87
scan88
scan89
scan90
scan91
scan92
scan93
scan94
scan95
scan96
scan97
scan98
scan99
scan100
scan101
scan102
scan103
scan104
scan105
scan107
scan108
scan109
scan111
scan112
scan113
scan115
scan116
scan119
scan120
scan121
scan122
scan123
scan124
scan125
scan126
scan127
scan128
scan3
scan5
scan17
scan21
scan28
scan35
scan37
scan38
scan40
scan43
scan56
scan59
scan66
scan67
scan82
scan86
scan106
scan117

================================================
FILE: lists/dtu/val.txt
================================================
scan3
scan5
scan17
scan21
scan28
scan35
scan37
scan38
scan40
scan43
scan56
scan59
scan66
scan67
scan82
scan86
scan106
scan117

================================================
FILE: models/MVS4Net.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from models.mvs4net_utils import stagenet, reg2d, reg3d, FPN4, FPN4_convnext, FPN4_convnext4, PosEncSine, PosEncLearned, \
        init_range, schedule_range, init_inverse_range, schedule_inverse_range, sinkhorn, mono_depth_decoder, ASFF


class MVS4net(nn.Module):
    def __init__(self, arch_mode="fpn", reg_net='reg2d', num_stage=4, fpn_base_channel=8, 
                reg_channel=8, stage_splits=[8,8,4,4], depth_interals_ratio=[0.5,0.5,0.5,1],
                group_cor=False, group_cor_dim=[8,8,8,8],
                inverse_depth=False,
                agg_type='ConvBnReLU3D',
                dcn=False,
                pos_enc=0,
                mono=False,
                asff=False,
                attn_temp=2,
                attn_fuse_d=True,
                vis_ETA=False,
                vis_mono=False
                ):
        # pos_enc: 0 no pos enc; 1 depth sine; 2 learnable pos enc
        super(MVS4net, self).__init__()
        self.arch_mode = arch_mode
        self.num_stage = num_stage
        self.depth_interals_ratio = depth_interals_ratio
        self.group_cor = group_cor
        self.group_cor_dim = group_cor_dim
        self.inverse_depth = inverse_depth
        self.asff = asff
        if self.asff:
            self.asff = nn.ModuleList([ASFF(i) for i in range(num_stage)])
        self.attn_ob = nn.ModuleList()
        if arch_mode == "fpn":
            self.feature = FPN4(base_channels=fpn_base_channel, gn=False, dcn=dcn)
        self.vis_mono = vis_mono
        self.stagenet = stagenet(inverse_depth, mono, attn_fuse_d, vis_ETA, attn_temp)
        self.stage_splits = stage_splits
        self.reg = nn.ModuleList()
        self.pos_enc = pos_enc
        self.pos_enc_func = nn.ModuleList()
        self.mono = mono
        if self.mono:
            self.mono_depth_decoder = mono_depth_decoder()
        if reg_net == 'reg3d':
            self.down_size = [3,3,2,2]
        for idx in range(num_stage):
            if self.group_cor:
                in_dim = group_cor_dim[idx]
            else:
                in_dim = self.feature.out_channels[idx]
            if reg_net == 'reg2d':
                self.reg.append(reg2d(input_channel=in_dim, base_channel=reg_channel, conv_name=agg_type))
            elif reg_net == 'reg3d':
                self.reg.append(reg3d(in_channels=in_dim, base_channels=reg_channel, down_size=self.down_size[idx]))


    def forward(self, imgs, proj_matrices, depth_values, filename=None):
        depth_min = depth_values[:, 0].cpu().numpy()
        depth_max = depth_values[:, -1].cpu().numpy()
        depth_interval = (depth_max - depth_min) / depth_values.size(1)

        # step 1. feature extraction
        features = []
        for nview_idx in range(len(imgs)):  #imgs shape (B, N, C, H, W)
            img = imgs[nview_idx]
            features.append(self.feature(img))
        if self.vis_mono:
            scan_name = filename[0].split('/')[0]
            image_name = filename[0].split('/')[2][:-2]
            save_fn = './debug_figs/vis_mono/feat_{}'.format(scan_name+'_'+image_name)
            feat_ = features[-1]['stage4'].detach().cpu().numpy()
            np.save(save_fn, feat_)
        # step 2. iter (multi-scale)
        outputs = {}
        for stage_idx in range(self.num_stage):
            if not self.asff:
                features_stage = [feat["stage{}".format(stage_idx+1)] for feat in features]
            else:
                features_stage = [self.asff[stage_idx](feat['stage1'],feat['stage2'],feat['stage3'],feat['stage4']) for feat in features]

            proj_matrices_stage = proj_matrices["stage{}".format(stage_idx + 1)]
            B,C,H,W = features[0]['stage{}'.format(stage_idx+1)].shape

            # init range
            if stage_idx == 0:
                if self.inverse_depth:
                    depth_hypo = init_inverse_range(depth_values, self.stage_splits[stage_idx], img[0].device, img[0].dtype, H, W)
                else:
                    depth_hypo = init_range(depth_values, self.stage_splits[stage_idx], img[0].device, img[0].dtype, H, W)
            else:
                if self.inverse_depth:
                    depth_hypo = schedule_inverse_range(outputs_stage['inverse_min_depth'].detach(), outputs_stage['inverse_max_depth'].detach(), self.stage_splits[stage_idx], H, W)  # B D H W
                else:
                    depth_hypo = schedule_range(outputs_stage['depth'].detach(), self.stage_splits[stage_idx], self.depth_interals_ratio[stage_idx] * depth_interval, H, W)

            outputs_stage = self.stagenet(features_stage, proj_matrices_stage, depth_hypo=depth_hypo, regnet=self.reg[stage_idx], stage_idx=stage_idx,
                                        group_cor=self.group_cor, group_cor_dim=self.group_cor_dim[stage_idx],
                                        split_itv=self.depth_interals_ratio[stage_idx],
                                        fn=filename)

            outputs["stage{}".format(stage_idx + 1)] = outputs_stage
            outputs.update(outputs_stage)
        
        if self.mono and self.training:
        # if self.mono:
            outputs = self.mono_depth_decoder(outputs, depth_values[:,0], depth_values[:,1])

        return outputs

def MVS4net_loss(inputs, depth_gt_ms, mask_ms, **kwargs):
    stage_lw = kwargs.get("stage_lw", [1,1,1,1])
    l1ot_lw = kwargs.get("l1ot_lw", [0,1])
    inverse = kwargs.get("inverse_depth", False)
    ot_iter = kwargs.get("ot_iter", 3)
    ot_eps = kwargs.get("ot_eps", 1)
    ot_continous = kwargs.get("ot_continous", False)
    mono = kwargs.get("mono", False)
    total_loss = torch.tensor(0.0, dtype=torch.float32, device=mask_ms["stage1"].device, requires_grad=False)
    stage_ot_loss = []
    stage_l1_loss = []
    range_err_ratio = []
    for stage_idx, (stage_inputs, stage_key) in enumerate([(inputs[k], k) for k in inputs.keys() if "stage" in k]):
        depth_pred = stage_inputs['depth']
        hypo_depth = stage_inputs['hypo_depth']
        attn_weight = stage_inputs['attn_weight']
        B,H,W = depth_pred.shape
        D = hypo_depth.shape[1]
        mask = mask_ms[stage_key]
        mask = mask > 0.5
        depth_gt = depth_gt_ms[stage_key]

        if mono and stage_idx!=0:
            this_stage_l1_loss = F.l1_loss(stage_inputs['mono_depth'][mask], depth_gt[mask], reduction='mean')
        else:
            this_stage_l1_loss = torch.tensor(0.0, dtype=torch.float32, device=mask_ms["stage1"].device, requires_grad=False)

        # mask range
        if inverse:
            depth_itv = (1/hypo_depth[:,2,:,:]-1/hypo_depth[:,1,:,:]).abs()  # B H W
            mask_out_of_range = ((1/hypo_depth - 1/depth_gt.unsqueeze(1)).abs() <= depth_itv.unsqueeze(1)).sum(1) == 0 # B H W
        else:
            depth_itv = (hypo_depth[:,2,:,:]-hypo_depth[:,1,:,:]).abs()  # B H W
            mask_out_of_range = ((hypo_depth - depth_gt.unsqueeze(1)).abs() <= depth_itv.unsqueeze(1)).sum(1) == 0 # B H W
        range_err_ratio.append(mask_out_of_range[mask].float().mean())

        this_stage_ot_loss = sinkhorn(depth_gt, hypo_depth, attn_weight, mask, iters=ot_iter, eps=ot_eps, continuous=ot_continous)[1]

        stage_l1_loss.append(this_stage_l1_loss)
        stage_ot_loss.append(this_stage_ot_loss)
        total_loss = total_loss + stage_lw[stage_idx] * (l1ot_lw[0] * this_stage_l1_loss + l1ot_lw[1] * this_stage_ot_loss)

    return total_loss, stage_l1_loss, stage_ot_loss, range_err_ratio


def Blend_loss(inputs, depth_gt_ms, mask_ms, **kwargs):
    stage_lw = kwargs.get("stage_lw", [1,1,1,1])
    l1ot_lw = kwargs.get("l1ot_lw", [0,1])
    inverse = kwargs.get("inverse_depth", False)
    ot_iter = kwargs.get("ot_iter", 3)
    ot_eps = kwargs.get("ot_eps", 1)
    ot_continous = kwargs.get("ot_continous", False)
    depth_max = kwargs.get("depth_max", 100)
    depth_min = kwargs.get("depth_min", 1)
    mono = kwargs.get("mono", False)
    total_loss = torch.tensor(0.0, dtype=torch.float32, device=mask_ms["stage1"].device, requires_grad=False)
    stage_ot_loss = []
    stage_l1_loss = []
    range_err_ratio = []
    for stage_idx, (stage_inputs, stage_key) in enumerate([(inputs[k], k) for k in inputs.keys() if "stage" in k]):
        depth_pred = stage_inputs['depth']
        hypo_depth = stage_inputs['hypo_depth']
        attn_weight = stage_inputs['attn_weight']
        B,H,W = depth_pred.shape
        mask = mask_ms[stage_key]
        mask = mask > 0.5
        depth_gt = depth_gt_ms[stage_key]
        depth_pred_norm = depth_pred * 128 / (depth_max - depth_min)[:,None,None]  # B H W
        depth_gt_norm = depth_gt * 128 / (depth_max - depth_min)[:,None,None]  # B H W

        if mono and stage_idx!=0:
            this_stage_l1_loss = F.l1_loss(stage_inputs['mono_depth'][mask], depth_gt[mask], reduction='mean')
        else:
            this_stage_l1_loss = torch.tensor(0.0, dtype=torch.float32, device=mask_ms["stage1"].device, requires_grad=False)

        if inverse:
            depth_itv = (1/hypo_depth[:,2,:,:]-1/hypo_depth[:,1,:,:]).abs()  # B H W
            mask_out_of_range = ((1/hypo_depth - 1/depth_gt.unsqueeze(1)).abs() <= depth_itv.unsqueeze(1)).sum(1) == 0 # B H W
        else:
            depth_itv = (hypo_depth[:,2,:,:]-hypo_depth[:,1,:,:]).abs()  # B H W
            mask_out_of_range = ((hypo_depth - depth_gt.unsqueeze(1)).abs() <= depth_itv.unsqueeze(1)).sum(1) == 0 # B H W
        range_err_ratio.append(mask_out_of_range[mask].float().mean())

        this_stage_ot_loss = sinkhorn(depth_gt, hypo_depth, attn_weight, mask, iters=ot_iter, eps=ot_eps, continuous=ot_continous)[1]

        stage_l1_loss.append(this_stage_l1_loss)
        stage_ot_loss.append(this_stage_ot_loss)
        total_loss = total_loss + stage_lw[stage_idx] * (l1ot_lw[0] * this_stage_l1_loss + l1ot_lw[1] * this_stage_ot_loss)

    abs_err = torch.abs(depth_pred_norm[mask] - depth_gt_norm[mask])
    epe = abs_err.mean()
    err3 = (abs_err<=3).float().mean()*100
    err1= (abs_err<=1).float().mean()*100
    return total_loss, stage_l1_loss, stage_ot_loss, range_err_ratio, epe, err3, err1

================================================
FILE: models/__init__.py
================================================

from models.MVS4Net import MVS4net, MVS4net_loss, Blend_loss

================================================
FILE: models/module.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import sys
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
sys.path.append("..")
from utils import local_pcd
from modules.deform_conv import DeformConvPack


def init_bn(module):
    if module.weight is not None:
        nn.init.ones_(module.weight)
    if module.bias is not None:
        nn.init.zeros_(module.bias)
    return


def init_uniform(module, init_method):
    if module.weight is not None:
        if init_method == "kaiming":
            nn.init.kaiming_uniform_(module.weight)
        elif init_method == "xavier":
            nn.init.xavier_uniform_(module.weight)
    return

class Conv2d(nn.Module):
    """Applies a 2D convolution (optionally with batch normalization and relu activation)
    over an input signal composed of several input planes.

    Attributes:
        conv (nn.Module): convolution module
        bn (nn.Module): batch normalization module
        relu (bool): whether to activate by relu

    Notes:
        Default momentum for batch normalization is set to be 0.01,

    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 relu=True, bn=True, bn_momentum=0.1, init_method="xavier", **kwargs):
        super(Conv2d, self).__init__()

        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride,
                              bias=(not bn), **kwargs)
        self.kernel_size = kernel_size
        self.stride = stride
        self.bn = nn.BatchNorm2d(out_channels, momentum=bn_momentum) if bn else None
        self.relu = relu

        # assert init_method in ["kaiming", "xavier"]
        # self.init_weights(init_method)

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu:
            x = F.relu(x, inplace=True)
        return x

    def init_weights(self, init_method):
        """default initialization"""
        init_uniform(self.conv, init_method)
        if self.bn is not None:
            init_bn(self.bn)

class DCNConv2d(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 relu=True, bn=True, bn_momentum=0.1, init_method="xavier", **kwargs):
        super(DCNConv2d, self).__init__()

        self.conv = DeformConvPack(in_channels, out_channels, kernel_size, stride=stride, padding=1, bias=(not bn), im2col_step=16)
        self.kernel_size = kernel_size
        self.stride = stride
        self.bn = nn.BatchNorm2d(out_channels, momentum=bn_momentum) if bn else None
        self.relu = relu

        # assert init_method in ["kaiming", "xavier"]
        # self.init_weights(init_method)

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu:
            x = F.relu(x, inplace=True)
        return x

    def init_weights(self, init_method):
        """default initialization"""
        init_uniform(self.conv, init_method)
        if self.bn is not None:
            init_bn(self.bn)

class Deconv2d(nn.Module):
    """Applies a 2D deconvolution (optionally with batch normalization and relu activation)
       over an input signal composed of several input planes.

       Attributes:
           conv (nn.Module): convolution module
           bn (nn.Module): batch normalization module
           relu (bool): whether to activate by relu

       Notes:
           Default momentum for batch normalization is set to be 0.01,

       """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 relu=True, bn=True, bn_momentum=0.1, init_method="xavier", **kwargs):
        super(Deconv2d, self).__init__()
        self.out_channels = out_channels
        assert stride in [1, 2]
        self.stride = stride

        self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride,
                                       bias=(not bn), **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, momentum=bn_momentum) if bn else None
        self.relu = relu

        # assert init_method in ["kaiming", "xavier"]
        # self.init_weights(init_method)

    def forward(self, x):
        y = self.conv(x)
        if self.stride == 2:
            h, w = list(x.size())[2:]
            y = y[:, :, :2 * h, :2 * w].contiguous()
        if self.bn is not None:
            x = self.bn(y)
        if self.relu:
            x = F.relu(x, inplace=True)
        return x

    def init_weights(self, init_method):
        """default initialization"""
        init_uniform(self.conv, init_method)
        if self.bn is not None:
            init_bn(self.bn)

class Conv3d(nn.Module):
    """Applies a 3D convolution (optionally with batch normalization and relu activation)
    over an input signal composed of several input planes.

    Attributes:
        conv (nn.Module): convolution module
        bn (nn.Module): batch normalization module
        relu (bool): whether to activate by relu

    Notes:
        Default momentum for batch normalization is set to be 0.01,

    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
                 relu=True, bn=True, bn_momentum=0.1, init_method="xavier", **kwargs):
        super(Conv3d, self).__init__()
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        assert stride in [1, 2]
        self.stride = stride

        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride,
                              bias=(not bn), **kwargs)
        self.bn = nn.BatchNorm3d(out_channels, momentum=bn_momentum) if bn else None
        self.relu = relu

        # assert init_method in ["kaiming", "xavier"]
        # self.init_weights(init_method)

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu:
            x = F.relu(x, inplace=True)
        return x

    def init_weights(self, init_method):
        """default initialization"""
        init_uniform(self.conv, init_method)
        if self.bn is not None:
            init_bn(self.bn)

class PConv3d(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
                 relu=True, bn=True, bn_momentum=0.1, padding=1, init_method="xavier", **kwargs):
        super(PConv3d, self).__init__()
        self.out_channels = out_channels
        self.kernel_size_xy = (1, kernel_size, kernel_size)
        self.kernel_size_d = (kernel_size, 1, 1)
        assert stride in [1, 2]
        self.stride_xy = (1, stride, stride)
        self.stride_d = (stride, 1, 1)
        self.padding_xy = (0, padding, padding)
        self.padding_d = (padding, 0, 0)

        self.convxy = nn.Conv3d(in_channels, in_channels, self.kernel_size_xy, stride=self.stride_xy, padding=self.padding_xy, bias=(not bn), **kwargs)
        self.convd = nn.Conv3d(in_channels, out_channels, self.kernel_size_d, stride=self.stride_d, padding=self.padding_d, bias=(not bn), **kwargs)
        self.bn = nn.BatchNorm3d(out_channels, momentum=bn_momentum) if bn else None
        self.relu = relu

        # assert init_method in ["kaiming", "xavier"]
        # self.init_weights(init_method)

    def forward(self, x):
        x = self.convxy(x)
        x = self.convd(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu:
            x = F.relu(x, inplace=True)
        return x

    def init_weights(self, init_method):
        """default initialization"""
        init_uniform(self.convxy, init_method)
        init_uniform(self.convd, init_method)
        if self.bn is not None:
            init_bn(self.bn)


class Deconv3d(nn.Module):
    """Applies a 3D deconvolution (optionally with batch normalization and relu activation)
       over an input signal composed of several input planes.

       Attributes:
           conv (nn.Module): convolution module
           bn (nn.Module): batch normalization module
           relu (bool): whether to activate by relu

       Notes:
           Default momentum for batch normalization is set to be 0.01,

       """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
                 relu=True, bn=True, bn_momentum=0.1, init_method="xavier", **kwargs):
        super(Deconv3d, self).__init__()
        self.out_channels = out_channels
        assert stride in [1, 2]
        self.stride = stride

        self.conv = nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride=stride,
                                       bias=(not bn), **kwargs)
        self.bn = nn.BatchNorm3d(out_channels, momentum=bn_momentum) if bn else None
        self.relu = relu

        # assert init_method in ["kaiming", "xavier"]
        # self.init_weights(init_method)

    def forward(self, x):
        y = self.conv(x)
        if self.bn is not None:
            x = self.bn(y)
        if self.relu:
            x = F.relu(x, inplace=True)
        return x

    def init_weights(self, init_method):
        """default initialization"""
        init_uniform(self.conv, init_method)
        if self.bn is not None:
            init_bn(self.bn)


class PDeconv3d(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1,output_padding=1,
                 relu=True, bn=True, bn_momentum=0.1, init_method="xavier", **kwargs):
        super(PDeconv3d, self).__init__()
        self.out_channels = out_channels
        assert stride in [1, 2]
        self.stride = stride
        self.kernel_size_xy = (1, kernel_size,kernel_size)
        self.kernel_size_d = (kernel_size, 1,1)
        self.stride_xy = (1, stride, stride)
        self.stride_d = (stride, 1, 1)
        self.padding_xy = (0, padding, padding)
        self.padding_d = (padding, 0, 0)
        self.outpadding_xy = (0, output_padding, output_padding)
        self.outpadding_d = (output_padding, 0, 0)
        self.convxy = nn.ConvTranspose3d(in_channels, in_channels, self.kernel_size_xy, stride=self.stride_xy, padding=self.padding_xy, output_padding=self.outpadding_xy, bias=(not bn))
        self.convd = nn.ConvTranspose3d(in_channels, out_channels, self.kernel_size_d, stride=self.stride_d, padding=self.padding_d, output_padding=self.outpadding_d, bias=(not bn))
        self.bn = nn.BatchNorm3d(out_channels, momentum=bn_momentum) if bn else None
        self.relu = relu

        # assert init_method in ["kaiming", "xavier"]
        # self.init_weights(init_method)

    def forward(self, x):
        x = self.convxy(x)
        y = self.convd(x)
        if self.bn is not None:
            x = self.bn(y)
        if self.relu:
            x = F.relu(x, inplace=True)
        return x

    def init_weights(self, init_method):
        """default initialization"""
        init_uniform(self.convxy, init_method)
        init_uniform(self.convd, init_method)
        if self.bn is not None:
            init_bn(self.bn)

class ConvBnReLU(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1):
        super(ConvBnReLU, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        return F.relu(self.bn(self.conv(x)), inplace=True)

class ConvBn(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1):
        super(ConvBn, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        return self.bn(self.conv(x))

class ConvBnReLU3D(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1):
        super(ConvBnReLU3D, self).__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False)
        self.bn = nn.BatchNorm3d(out_channels)

    def forward(self, x):
        return F.relu(self.bn(self.conv(x)), inplace=True)


class ConvBn3D(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1):
        super(ConvBn3D, self).__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False)
        self.bn = nn.BatchNorm3d(out_channels)

    def forward(self, x):
        return self.bn(self.conv(x))


class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride, downsample=None):
        super(BasicBlock, self).__init__()

        self.conv1 = ConvBnReLU(in_channels, out_channels, kernel_size=3, stride=stride, pad=1)
        self.conv2 = ConvBn(out_channels, out_channels, kernel_size=3, stride=1, pad=1)

        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        if self.downsample is not None:
            x = self.downsample(x)
        out += x
        return out


class Hourglass3d(nn.Module):
    def __init__(self, channels):
        super(Hourglass3d, self).__init__()

        self.conv1a = ConvBnReLU3D(channels, channels * 2, kernel_size=3, stride=2, pad=1)
        self.conv1b = ConvBnReLU3D(channels * 2, channels * 2, kernel_size=3, stride=1, pad=1)

        self.conv2a = ConvBnReLU3D(channels * 2, channels * 4, kernel_size=3, stride=2, pad=1)
        self.conv2b = ConvBnReLU3D(channels * 4, channels * 4, kernel_size=3, stride=1, pad=1)

        self.dconv2 = nn.Sequential(
            nn.ConvTranspose3d(channels * 4, channels * 2, kernel_size=3, padding=1, output_padding=1, stride=2,
                               bias=False),
            nn.BatchNorm3d(channels * 2))

        self.dconv1 = nn.Sequential(
            nn.ConvTranspose3d(channels * 2, channels, kernel_size=3, padding=1, output_padding=1, stride=2,
                               bias=False),
            nn.BatchNorm3d(channels))

        self.redir1 = ConvBn3D(channels, channels, kernel_size=1, stride=1, pad=0)
        self.redir2 = ConvBn3D(channels * 2, channels * 2, kernel_size=1, stride=1, pad=0)

    def forward(self, x):
        conv1 = self.conv1b(self.conv1a(x))
        conv2 = self.conv2b(self.conv2a(conv1))
        dconv2 = F.relu(self.dconv2(conv2) + self.redir2(conv1), inplace=True)
        dconv1 = F.relu(self.dconv1(dconv2) + self.redir1(x), inplace=True)
        return dconv1


def homo_warping(src_fea, src_proj, ref_proj, depth_values, align_corners=False):
    # src_fea: [B, C, H, W]
    # src_proj: [B, 4, 4]
    # ref_proj: [B, 4, 4]
    # depth_values: [B, Ndepth] o [B, Ndepth, H, W]
    # out: [B, C, Ndepth, H, W]
    batch, channels = src_fea.shape[0], src_fea.shape[1]
    num_depth = depth_values.shape[1]
    height, width = src_fea.shape[2], src_fea.shape[3]

    with torch.no_grad():
        proj = torch.matmul(src_proj, torch.inverse(ref_proj))
        rot = proj[:, :3, :3]  # [B,3,3]
        trans = proj[:, :3, 3:4]  # [B,3,1]

        y, x = torch.meshgrid([torch.arange(0, height, dtype=torch.float32, device=src_fea.device),
                               torch.arange(0, width, dtype=torch.float32, device=src_fea.device)])
        y, x = y.contiguous(), x.contiguous()
        y, x = y.view(height * width), x.view(height * width)
        xyz = torch.stack((x, y, torch.ones_like(x)))  # [3, H*W]
        xyz = torch.unsqueeze(xyz, 0).repeat(batch, 1, 1)  # [B, 3, H*W]
        rot_xyz = torch.matmul(rot, xyz)  # [B, 3, H*W]
        rot_depth_xyz = rot_xyz.unsqueeze(2).repeat(1, 1, num_depth, 1) * depth_values.view(batch, 1, num_depth, -1)  # [B, 3, Ndepth, H*W]
        proj_xyz = rot_depth_xyz + trans.view(batch, 3, 1, 1)  # [B, 3, Ndepth, H*W]
        proj_xy = proj_xyz[:, :2, :, :] / proj_xyz[:, 2:3, :, :]  # [B, 2, Ndepth, H*W]
        proj_x_normalized = proj_xy[:, 0, :, :] / ((width - 1) / 2) - 1
        proj_y_normalized = proj_xy[:, 1, :, :] / ((height - 1) / 2) - 1
        proj_xy = torch.stack((proj_x_normalized, proj_y_normalized), dim=3)  # [B, Ndepth, H*W, 2]
        grid = proj_xy

    warped_src_fea = F.grid_sample(src_fea, grid.view(batch, num_depth * height, width, 2), mode='bilinear', padding_mode='zeros', align_corners=align_corners)
    warped_src_fea = warped_src_fea.view(batch, channels, num_depth, height, width)

    return warped_src_fea

class DeConv2dFuse(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, relu=True, bn=True,
                 bn_momentum=0.1):
        super(DeConv2dFuse, self).__init__()

        self.deconv = Deconv2d(in_channels, out_channels, kernel_size, stride=2, padding=1, output_padding=1,
                               bn=True, relu=relu, bn_momentum=bn_momentum)

        self.conv = Conv2d(2*out_channels, out_channels, kernel_size, stride=1, padding=1,
                           bn=bn, relu=relu, bn_momentum=bn_momentum)

        # assert init_method in ["kaiming", "xavier"]
        # self.init_weights(init_method)

    def forward(self, x_pre, x):
        x = self.deconv(x)
        x = torch.cat((x, x_pre), dim=1)
        x = self.conv(x)
        return x


class FeatureNet(nn.Module):
    def __init__(self, base_channels, num_stage=3, stride=4, arch_mode="unet"):
        super(FeatureNet, self).__init__()
        assert arch_mode in ["unet", "fpn"], print("mode must be in 'unet' or 'fpn', but get:{}".format(arch_mode))
        print("*************feature extraction arch mode:{}****************".format(arch_mode))
        self.arch_mode = arch_mode
        self.stride = stride
        self.base_channels = base_channels
        self.num_stage = num_stage

        self.conv0 = nn.Sequential(
            Conv2d(3, base_channels, 3, 1, padding=1),
            Conv2d(base_channels, base_channels, 3, 1, padding=1),
        )

        self.conv1 = nn.Sequential(
            Conv2d(base_channels, base_channels * 2, 5, stride=2, padding=2),
            Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1),
            Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1),
        )

        self.conv2 = nn.Sequential(
            Conv2d(base_channels * 2, base_channels * 4, 5, stride=2, padding=2),
            Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1),
            Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1),
        )

        self.out1 = nn.Conv2d(base_channels * 4, base_channels * 4, 1, bias=False)
        self.out_channels = [4 * base_channels]

        if self.arch_mode == 'unet':
            if num_stage == 3:
                self.deconv1 = DeConv2dFuse(base_channels * 4, base_channels * 2, 3)
                self.deconv2 = DeConv2dFuse(base_channels * 2, base_channels, 3)

                self.out2 = nn.Conv2d(base_channels * 2, base_channels * 2, 1, bias=False)
                self.out3 = nn.Conv2d(base_channels, base_channels, 1, bias=False)
                self.out_channels.append(2 * base_channels)
                self.out_channels.append(base_channels)

            elif num_stage == 2:
                self.deconv1 = DeConv2dFuse(base_channels * 4, base_channels * 2, 3)

                self.out2 = nn.Conv2d(base_channels * 2, base_channels * 2, 1, bias=False)
                self.out_channels.append(2 * base_channels)
        elif self.arch_mode == "fpn":
            final_chs = base_channels * 4
            if num_stage == 3:
                self.inner1 = nn.Conv2d(base_channels * 2, final_chs, 1, bias=True)
                self.inner2 = nn.Conv2d(base_channels * 1, final_chs, 1, bias=True)

                self.out2 = nn.Conv2d(final_chs, base_channels * 2, 3, padding=1, bias=False)
                self.out3 = nn.Conv2d(final_chs, base_channels, 3, padding=1, bias=False)
                self.out_channels.append(base_channels * 2)
                self.out_channels.append(base_channels)

            elif num_stage == 2:
                self.inner1 = nn.Conv2d(base_channels * 2, final_chs, 1, bias=True)

                self.out2 = nn.Conv2d(final_chs, base_channels, 3, padding=1, bias=False)
                self.out_channels.append(base_channels)

    def forward(self, x):
        conv0 = self.conv0(x)
        conv1 = self.conv1(conv0)
        conv2 = self.conv2(conv1)

        intra_feat = conv2
        outputs = {}
        out = self.out1(intra_feat)
        outputs["stage1"] = out
        if self.arch_mode == "unet":
            if self.num_stage == 3:
                intra_feat = self.deconv1(conv1, intra_feat)
                out = self.out2(intra_feat)
                outputs["stage2"] = out

                intra_feat = self.deconv2(conv0, intra_feat)
                out = self.out3(intra_feat)
                outputs["stage3"] = out

            elif self.num_stage == 2:
                intra_feat = self.deconv1(conv1, intra_feat)
                out = self.out2(intra_feat)
                outputs["stage2"] = out

        elif self.arch_mode == "fpn":
            if self.num_stage == 3:
                intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="nearest") + self.inner1(conv1)
                out = self.out2(intra_feat)
                outputs["stage2"] = out

                intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="nearest") + self.inner2(conv0)
                out = self.out3(intra_feat)
                outputs["stage3"] = out

            elif self.num_stage == 2:
                intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="nearest") + self.inner1(conv1)
                out = self.out2(intra_feat)
                outputs["stage2"] = out

        return outputs

class FPNDCNpath(nn.Module):
    """
    FPN+DCN pathway"""
    def __init__(self, base_channels, stride=4):
        super(FPNDCNpath, self).__init__()
        self.stride = stride
        self.base_channels = base_channels

        self.conv0 = nn.Sequential(
            Conv2d(3, base_channels, 3, 1, padding=1),
            Conv2d(base_channels, base_channels, 3, 1, padding=1),
        )

        self.conv1 = nn.Sequential(
            Conv2d(base_channels, base_channels * 2, 5, stride=2, padding=2),
            Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1),
            Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1),
        )

        self.conv2 = nn.Sequential(
            Conv2d(base_channels * 2, base_channels * 4, 5, stride=2, padding=2),
            Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1),
            Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1),
        )

        self.out1 = nn.Sequential(
            DCNConv2d(base_channels * 4, base_channels * 4, 3,  stride=1, padding=1),
            DCNConv2d(base_channels * 4, base_channels * 4, 3,  stride=1, padding=1),
            DeformConvPack(base_channels * 4, base_channels * 4, 3,  stride=1, padding=1, bias=False, im2col_step=16)
        )
        self.out_channels = [4 * base_channels]

        final_chs = base_channels * 4
        self.inner1 = nn.Conv2d(base_channels * 2, final_chs, 1, bias=True)
        self.inner2 = nn.Conv2d(base_channels * 1, final_chs, 1, bias=True)

        self.out2 = nn.Sequential(
            DCNConv2d(base_channels * 4, base_channels * 2, 3,  stride=1, padding=1),
            DCNConv2d(base_channels * 2, base_channels * 2, 3,  stride=1, padding=1),
            DeformConvPack(base_channels * 2, base_channels * 2, 3,  stride=1, padding=1, bias=False, im2col_step=16)
        )
        self.out2pathconv = nn.Conv2d(base_channels * 4, base_channels * 2, 3,  stride=1, padding=1)
        self.out3 = nn.Sequential(
            DCNConv2d(base_channels * 4, base_channels * 1, 3,  stride=1, padding=1),
            DCNConv2d(base_channels * 1, base_channels * 1, 3,  stride=1, padding=1),
            DeformConvPack(base_channels * 1, base_channels * 1, 3,  stride=1, padding=1, bias=False, im2col_step=16)
        )
        self.out3pathconv = nn.Conv2d(base_channels * 2, base_channels * 1, 3,  stride=1, padding=1)
        self.out_channels.append(base_channels * 2)
        self.out_channels.append(base_channels)

    def forward(self, x):
        conv0 = self.conv0(x)
        conv1 = self.conv1(conv0)
        conv2 = self.conv2(conv1)

        intra_feat = conv2
        outputs = {}
        out1 = self.out1(intra_feat)

        intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner1(conv1)
        out2 = self.out2(intra_feat)
        out2 = out2 + self.out2pathconv(F.interpolate(out1, scale_factor=2, mode="bilinear", align_corners=True))

        intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner2(conv0)
        out3 = self.out3(intra_feat)
        out3 = out3 + self.out3pathconv(F.interpolate(out2, scale_factor=2, mode="bilinear", align_corners=True))

        outputs["stage1"] = out1
        outputs["stage2"] = out2
        outputs["stage3"] = out3

        return outputs

class FPNDCN(nn.Module):
    """
    FPN+DCN"""
    def __init__(self, base_channels, stride=4):
        super(FPNDCN, self).__init__()
        self.stride = stride
        self.base_channels = base_channels

        self.conv0 = nn.Sequential(
            Conv2d(3, base_channels, 3, 1, padding=1),
            Conv2d(base_channels, base_channels, 3, 1, padding=1),
        )

        self.conv1 = nn.Sequential(
            Conv2d(base_channels, base_channels * 2, 5, stride=2, padding=2),
            Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1),
            Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1),
        )

        self.conv2 = nn.Sequential(
            Conv2d(base_channels * 2, base_channels * 4, 5, stride=2, padding=2),
            Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1),
            Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1),
        )

        self.out1 = nn.Sequential(
            DCNConv2d(base_channels * 4, base_channels * 4, 3,  stride=1, padding=1),
            DCNConv2d(base_channels * 4, base_channels * 4, 3,  stride=1, padding=1),
            DeformConvPack(base_channels * 4, base_channels * 4, 3,  stride=1, padding=1, bias=False, im2col_step=16)
        )
        self.out_channels = [4 * base_channels]

        final_chs = base_channels * 4
        self.inner1 = nn.Conv2d(base_channels * 2, final_chs, 1, bias=True)
        self.inner2 = nn.Conv2d(base_channels * 1, final_chs, 1, bias=True)

        self.out2 = nn.Sequential(
            DCNConv2d(base_channels * 4, base_channels * 2, 3,  stride=1, padding=1),
            DCNConv2d(base_channels * 2, base_channels * 2, 3,  stride=1, padding=1),
            DeformConvPack(base_channels * 2, base_channels * 2, 3,  stride=1, padding=1, bias=False, im2col_step=16)
        )
        self.out3 = nn.Sequential(
            DCNConv2d(base_channels * 4, base_channels * 1, 3,  stride=1, padding=1),
            DCNConv2d(base_channels * 1, base_channels * 1, 3,  stride=1, padding=1),
            DeformConvPack(base_channels * 1, base_channels * 1, 3,  stride=1, padding=1, bias=False, im2col_step=16)
        )
        self.out_channels.append(base_channels * 2)
        self.out_channels.append(base_channels)

    def forward(self, x):
        conv0 = self.conv0(x)
        conv1 = self.conv1(conv0)
        conv2 = self.conv2(conv1)

        intra_feat = conv2
        outputs = {}
        out1 = self.out1(intra_feat)

        intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner1(conv1)
        out2 = self.out2(intra_feat)

        intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner2(conv0)
        out3 = self.out3(intra_feat)

        outputs["stage1"] = out1
        outputs["stage2"] = out2
        outputs["stage3"] = out3

        return outputs

class FPNA(nn.Module):
    """
    FPN aligncorners"""
    def __init__(self, base_channels, stride=4):
        super(FPNA, self).__init__()
        self.stride = stride
        self.base_channels = base_channels

        self.conv0 = nn.Sequential(
            Conv2d(3, base_channels, 3, 1, padding=1),
            Conv2d(base_channels, base_channels, 3, 1, padding=1),
        )

        self.conv1 = nn.Sequential(
            Conv2d(base_channels, base_channels * 2, 5, stride=2, padding=2),
            Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1),
            Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1),
        )

        self.conv2 = nn.Sequential(
            Conv2d(base_channels * 2, base_channels * 4, 5, stride=2, padding=2),
            Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1),
            Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1),
        )

        self.out1 = nn.Conv2d(base_channels * 4, base_channels * 4, 1, bias=False)
        self.out_channels = [4 * base_channels]

        final_chs = base_channels * 4
        self.inner1 = nn.Conv2d(base_channels * 2, final_chs, 1, bias=True)
        self.inner2 = nn.Conv2d(base_channels * 1, final_chs, 1, bias=True)

        self.out2 = nn.Conv2d(final_chs, base_channels * 2, 3, padding=1, bias=False)
        self.out3 = nn.Conv2d(final_chs, base_channels, 3, padding=1, bias=False)

        self.out_channels.append(base_channels * 2)
        self.out_channels.append(base_channels)

    def forward(self, x):
        conv0 = self.conv0(x)
        conv1 = self.conv1(conv0)
        conv2 = self.conv2(conv1)

        intra_feat = conv2
        outputs = {}
        out1 = self.out1(intra_feat)

        intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner1(conv1)
        out2 = self.out2(intra_feat)

        intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner2(conv0)
        out3 = self.out3(intra_feat)

        outputs["stage1"] = out1
        outputs["stage2"] = out2
        outputs["stage3"] = out3

        return outputs

class FPNA4(nn.Module):
    """
    FPN aligncorners downsample 4x"""
    def __init__(self, base_channels):
        super(FPNA4, self).__init__()
        self.base_channels = base_channels

        self.conv0 = nn.Sequential(
            Conv2d(3, base_channels, 3, 1, padding=1),
            Conv2d(base_channels, base_channels, 3, 1, padding=1),
        )

        self.conv1 = nn.Sequential(
            Conv2d(base_channels, base_channels * 2, 5, stride=2, padding=2),
            Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1),
            Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1),
        )

        self.conv2 = nn.Sequential(
            Conv2d(base_channels * 2, base_channels * 4, 5, stride=2, padding=2),
            Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1),
            Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1),
        )

        self.conv3 = nn.Sequential(
            Conv2d(base_channels * 4, base_channels * 8, 5, stride=2, padding=2),
            Conv2d(base_channels * 8, base_channels * 8, 3, 1, padding=1),
            Conv2d(base_channels * 8, base_channels * 8, 3, 1, padding=1),
        )

        self.out_channels = [8 * base_channels]
        final_chs = base_channels * 8

        self.inner1 = nn.Conv2d(base_channels * 4, final_chs, 1, bias=True)
        self.inner2 = nn.Conv2d(base_channels * 2, final_chs, 1, bias=True)
        self.inner3 = nn.Conv2d(base_channels * 1, final_chs, 1, bias=True)

        self.out1 = nn.Conv2d(final_chs, base_channels * 8, 1, bias=False)
        self.out2 = nn.Conv2d(final_chs, base_channels * 4, 3, padding=1, bias=False)
        self.out3 = nn.Conv2d(final_chs, base_channels * 2, 3, padding=1, bias=False)
        self.out4 = nn.Conv2d(final_chs, base_channels, 3, padding=1, bias=False)

        self.out_channels.append(base_channels * 4)
        self.out_channels.append(base_channels * 2)
        self.out_channels.append(base_channels)

    def forward(self, x):
        conv0 = self.conv0(x)
        conv1 = self.conv1(conv0)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)

        intra_feat = conv3
        outputs = {}
        out1 = self.out1(intra_feat)

        intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner1(conv2)
        out2 = self.out2(intra_feat)

        intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner2(conv1)
        out3 = self.out3(intra_feat)

        intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner3(conv0)
        out4 = self.out4(intra_feat)

        outputs["stage1"] = out1
        outputs["stage2"] = out2
        outputs["stage3"] = out3
        outputs["stage4"] = out4

        return outputs

class CostRegNet(nn.Module):
    def __init__(self, in_channels, base_channels, down_size=3):
        super(CostRegNet, self).__init__()
        self.down_size = down_size
        self.conv0 = Conv3d(in_channels, base_channels, padding=1)

        self.conv1 = Conv3d(base_channels, base_channels * 2, stride=2, padding=1)
        self.conv2 = Conv3d(base_channels * 2, base_channels * 2, padding=1)

        if down_size >= 2:
            self.conv3 = Conv3d(base_channels * 2, base_channels * 4, stride=2, padding=1)
            self.conv4 = Conv3d(base_channels * 4, base_channels * 4, padding=1)

        if down_size >= 3:
            self.conv5 = Conv3d(base_channels * 4, base_channels * 8, stride=2, padding=1)
            self.conv6 = Conv3d(base_channels * 8, base_channels * 8, padding=1)
            self.conv7 = Deconv3d(base_channels * 8, base_channels * 4, stride=2, padding=1, output_padding=1)

        if down_size >= 2:
            self.conv9 = Deconv3d(base_channels * 4, base_channels * 2, stride=2, padding=1, output_padding=1)
            
        self.conv11 = Deconv3d(base_channels * 2, base_channels * 1, stride=2, padding=1, output_padding=1)
        self.prob = nn.Conv3d(base_channels, 1, 3, stride=1, padding=1, bias=False)

    def forward(self, x):
        if self.down_size==3:
            conv0 = self.conv0(x)
            conv2 = self.conv2(self.conv1(conv0))
            conv4 = self.conv4(self.conv3(conv2))
            x = self.conv6(self.conv5(conv4))
            x = conv4 + self.conv7(x)
            x = conv2 + self.conv9(x)
            x = conv0 + self.conv11(x)
            x = self.prob(x)
        elif self.down_size==2:
            conv0 = self.conv0(x)
            conv2 = self.conv2(self.conv1(conv0))
            x = self.conv4(self.conv3(conv2))
            x = conv2 + self.conv9(x)
            x = conv0 + self.conv11(x)
            x = self.prob(x)
        else:
            conv0 = self.conv0(x)
            x = self.conv2(self.conv1(conv0))
            x = conv0 + self.conv11(x)
            x = self.prob(x)
        return x

class P3DConv(nn.Module):
    """
    Pseudo 3D conv: 3x3x1 + 1x3x3
    """
    def __init__(self, in_channels, base_channels):
        super(P3DConv, self).__init__()
        self.conv0 = PConv3d(in_channels, base_channels, padding=1)

        self.conv1 = PConv3d(base_channels, base_channels * 2, stride=2, padding=1)
        self.conv2 = PConv3d(base_channels * 2, base_channels * 2, padding=1)

        self.conv3 = PConv3d(base_channels * 2, base_channels * 4, stride=2, padding=1)
        self.conv4 = PConv3d(base_channels * 4, base_channels * 4, padding=1)

        self.conv5 = PConv3d(base_channels * 4, base_channels * 8, stride=2, padding=1)
        self.conv6 = PConv3d(base_channels * 8, base_channels * 8, padding=1)

        self.conv7 = PDeconv3d(base_channels * 8, base_channels * 4, stride=2, padding=1, output_padding=1)

        self.conv9 = PDeconv3d(base_channels * 4, base_channels * 2, stride=2, padding=1, output_padding=1)

        self.conv11 = PDeconv3d(base_channels * 2, base_channels * 1, stride=2, padding=1, output_padding=1)

        self.prob = nn.Conv3d(base_channels, 1, 3, stride=1, padding=1, bias=False)

    def forward(self, x):
        conv0 = self.conv0(x)
        conv2 = self.conv2(self.conv1(conv0))
        conv4 = self.conv4(self.conv3(conv2))
        x = self.conv6(self.conv5(conv4))
        x = conv4 + self.conv7(x)
        x = conv2 + self.conv9(x)
        x = conv0 + self.conv11(x)
        x = self.prob(x)
        return x

class RefineNet(nn.Module):
    def __init__(self):
        super(RefineNet, self).__init__()
        self.conv1 = ConvBnReLU(4, 32)
        self.conv2 = ConvBnReLU(32, 32)
        self.conv3 = ConvBnReLU(32, 32)
        self.res = ConvBnReLU(32, 1)

    def forward(self, img, depth_init):
        concat = F.cat((img, depth_init), dim=1)
        depth_residual = self.res(self.conv3(self.conv2(self.conv1(concat))))
        depth_refined = depth_init + depth_residual
        return depth_refined


def depth_regression(p, depth_values):
    if depth_values.dim() <= 2:
        # print("regression dim <= 2")
        depth_values = depth_values.view(*depth_values.shape, 1, 1)
    depth = torch.sum(p * depth_values, 1)

    return depth

def cas_mvsnet_loss(inputs, depth_gt_ms, mask_ms, **kwargs):
    depth_loss_weights = kwargs.get("dlossw", None)

    total_loss = torch.tensor(0.0, dtype=torch.float32, device=mask_ms["stage1"].device, requires_grad=False)

    for (stage_inputs, stage_key) in [(inputs[k], k) for k in inputs.keys() if "stage" in k]:
        depth_est = stage_inputs["depth"]
        depth_gt = depth_gt_ms[stage_key]
        mask = mask_ms[stage_key]
        mask = mask > 0.5

        depth_loss = F.smooth_l1_loss(depth_est[mask], depth_gt[mask], reduction='mean')

        if depth_loss_weights is not None:
            stage_idx = int(stage_key.replace("stage", "")) - 1
            total_loss += depth_loss_weights[stage_idx] * depth_loss
        else:
            total_loss += 1.0 * depth_loss

    return total_loss, depth_loss

def cas_mvsnet_T_loss(inputs, depth_gt_ms, mask_ms, **kwargs):
    depth_loss_weights = kwargs.get("dlossw", None)
    l1ce_lw = kwargs.get("l1ce_lw", [0.1, 1])
    range_thres = kwargs.get("range_thres", [84.8, 10.6])
    cas_method = kwargs.get("cascade_method", None)
    last_conv3d = kwargs.get("last_conv3d", False)
    visual = kwargs.get("visual", False)
    wt = kwargs.get("wt", False)
    fl = kwargs.get("fl", False)
    shrink_method = kwargs.get("shrink_method", 'schedule')
    upsampled_loss = kwargs.get("upsampled_loss", False)
    selected_loss = kwargs.get("selected_loss", False)
    mask_range_loss = kwargs.get("mask_range_loss", False)
    det = kwargs.get("det", False)
    if visual:
        f, axs = plt.subplots(figsize=(30, 10),ncols=3)  # depth offset
        f2, axs2 = plt.subplots(figsize=(30, 10),ncols=3)  # attn weight max
        f3, axs3 = plt.subplots(figsize=(30, 10),ncols=3)  # attn weight gt val
        f4, axs4 = plt.subplots(figsize=(30, 10),ncols=3)  # max gt offset
        err_848_str = ''
        err_106_str = ''
        err_002_str = ''

    total_loss = torch.tensor(0.0, dtype=torch.float32, device=mask_ms["stage1"].device, requires_grad=False)
    stage_depth_loss = []
    stage_ce_loss = []
    range_err_ratio = []
    upsampled_depth_losses = []
    det_offset_losses = []
    for stage_idx, (stage_inputs, stage_key) in enumerate([(inputs[k], k) for k in inputs.keys() if "stage" in k]):
        depth_est = stage_inputs["depth"]
        B,H,W = depth_est.shape
        mask = mask_ms[stage_key]
        mask = mask > 0.5
        depth_gt = depth_gt_ms[stage_key]

        if upsampled_loss:
            if stage_idx!=0 :
                upsampled_depth = stage_inputs["upsampled_depth"]
                upsampled_depth_loss = F.smooth_l1_loss(upsampled_depth[mask], depth_gt[mask], reduction='mean')
                upsampled_depth_losses.append(upsampled_depth_loss)
        else:
            if stage_idx!=0 :
                upsampled_depth_losses.append(torch.tensor(0.0, dtype=torch.float32, device=mask_ms["stage1"].device, requires_grad=False))
        
        if mask_range_loss:
            if stage_idx != 0:
                depth_offset = next_stage_depth_hypo - depth_gt  # B H W
                this_stage_mask_range = torch.abs(depth_offset)<range_thres[stage_idx-1]
                mask = mask & this_stage_mask_range  # B H W
            next_stage_depth_hypo = F.interpolate(depth_est.unsqueeze(1), scale_factor=2, mode='bilinear', align_corners=True).squeeze(1)
            

        if stage_idx != len(range_thres):
            depth_offset = depth_est - depth_gt
            depth_offset[~mask] = 0
            depth_offset = depth_offset # B H W
            range_err_ratio.append((torch.abs(depth_offset)>range_thres[stage_idx]).float().mean())


        if visual:
            depth_offset = depth_est - depth_gt
            depth_offset[~mask] = 0
            depth_offset = depth_offset.detach().cpu().numpy()[0] # H W  
            err_848_str += str((np.abs(depth_offset)>84.8).sum()) + ','
            err_106_str += str((np.abs(depth_offset)>10.6).sum()) + ','
            err_002_str += str((np.abs(depth_offset)>2).sum()) + ','
            sns.heatmap(depth_offset, annot=False, ax=axs[stage_idx])

            attn_weights = stage_inputs["attn_weights"][0]  # D H W
            attn_weights_max, ind_max = torch.max(attn_weights, 0)
            attn_weights_max = attn_weights_max.detach().cpu().numpy()  # H W
            sns.heatmap(attn_weights_max, annot=False, ax=axs2[stage_idx])

            this_stage_depth_val = stage_inputs['depth_values']  # B D H W
            depth_offsets = torch.abs(this_stage_depth_val- depth_gt[:,None,:,:])[0]  # D,H,W
            _, indices = torch.min(depth_offsets, dim=0, keepdim=True)  # [1, H, W]
            attn_gt = torch.gather(attn_weights, 0, indices)[0]  # [H W]
            attn_gt = attn_gt.detach().cpu().numpy()
            sns.heatmap(attn_gt, annot=False, ax=axs3[stage_idx])

            max_gt_offset = ind_max - indices[0]  # H W
            max_gt_offset = max_gt_offset.detach().cpu().numpy()
            sns.heatmap(max_gt_offset, annot=False, ax=axs4[stage_idx])

        if cas_method[stage_idx] == 't' or cas_method[stage_idx] == 'r' or cas_method[stage_idx] == 'p':
            # Loss for transformer 
            depth_loss = torch.tensor(0.0, dtype=torch.float32, device=mask_ms["stage1"].device, requires_grad=False)
            if last_conv3d:
                depth_loss = F.smooth_l1_loss(depth_est[mask], depth_gt[mask], reduction='mean')
            attn_weights = stage_inputs["attn_weights"].permute(0,2,3,1).reshape(B*H*W, -1)  # BHW D
            this_stage_depth_val = stage_inputs['depth_values']  # B D H W
            depth_offsets = torch.abs(this_stage_depth_val- depth_gt[:,None,:,:])  # B,D,H,W
            _, indices = torch.min(depth_offsets, dim=1)  # [B, H, W]
            indices = indices.reshape(-1)  # [BHW]
            mask = mask.reshape(-1)  # BHW
            if fl:  # -p(1-q)^a log(q)
                this_stage_ce_loss = F.nll_loss((1-attn_weights[mask])**2 * torch.log(attn_weights[mask]+1e-12), indices[mask], reduce='mean')
            else:  # -plog(q)
                this_stage_ce_loss = F.nll_loss(torch.log(attn_weights[mask]+1e-12), indices[mask], reduce='mean')
            stage_depth_loss.append(depth_loss)
            stage_ce_loss.append(this_stage_ce_loss)

            this_stage_loss = l1ce_lw[0]*depth_loss + l1ce_lw[1]*this_stage_ce_loss
        
        # Loss for 3D conv
        else: 
            if wt:
                depth_loss = torch.tensor(0.0, dtype=torch.float32, device=mask_ms["stage1"].device, requires_grad=False)
                stage_depth_loss.append(depth_loss)
                attn_weights = stage_inputs["attn_weights"].permute(0,2,3,1).reshape(B*H*W, -1)  # BHW D
                depth_offsets = torch.abs(stage_inputs['depth_values']- depth_gt[:,None,:,:])  # B,D,H,W
                indices = torch.min(depth_offsets, dim=1)[1].reshape(-1)  # [BHW]
                mask = mask.reshape(-1)  # BHW
                if fl:  # -p(1-q)^a log(q)
                    this_stage_ce_loss = F.nll_loss((1-attn_weights[mask])**2 * torch.log(attn_weights[mask]+1e-12), indices[mask], reduce='mean')
                else:  # -plog(q)
                    this_stage_ce_loss = F.nll_loss(torch.log(attn_weights[mask]+1e-12), indices[mask], reduce='mean')
                stage_ce_loss.append(this_stage_ce_loss)
            else:
                depth_loss = F.smooth_l1_loss(depth_est[mask], depth_gt[mask], reduction='mean')
                stage_depth_loss.append(depth_loss)
                this_stage_ce_loss = torch.tensor(0.0, dtype=torch.float32, device=mask_ms["stage1"].device, requires_grad=False)
                stage_ce_loss.append(this_stage_ce_loss)

            this_stage_loss = l1ce_lw[0]*depth_loss + l1ce_lw[1]*this_stage_ce_loss
        
        if upsampled_loss:
            if stage_idx!=0:
                this_stage_loss = this_stage_loss + upsampled_depth_loss * l1ce_lw[0]
        if shrink_method == 'DPF':
            if stage_idx!=0:
                depth_offsets = stage_inputs['depth_values'] - depth_gt[:,None,:,:]  # B,D,H,W
                depth_offset_clamp = torch.clamp(depth_offsets, -1, 1)
                this_stage_loss = this_stage_loss + torch.abs(depth_offset_clamp).permute(0,2,3,1).reshape(B*H*W, -1)[mask.reshape(-1)].mean()
        if selected_loss:
            select_weight = stage_inputs["select_weight"].permute(0,2,3,1).reshape(B*H*W, -1)  # BHW D
            depth_offsets = torch.abs(stage_inputs['depth_values']- depth_gt[:,None,:,:]) 
            indices = torch.min(depth_offsets, dim=1)[1]  # [B, H, W]
            indices = indices.reshape(-1)  # [BHW]
            mask = mask.reshape(-1)  # BHW
            this_stage_selected_loss = F.nll_loss(torch.log(select_weight[mask]+1e-12), indices[mask], reduce='mean')
            this_stage_loss = this_stage_loss + this_stage_selected_loss * 0.01*l1ce_lw[1]
        if det:
            assert wt
            depth_itv = stage_inputs['depth_values'][:,1,:,:] - stage_inputs['depth_values'][:,0,:,:]   # B H W
            pred_offset = stage_inputs['offset_reg'].reshape(-1)  # BHW
            offset_gt = (depth_gt - (depth_est - stage_inputs['offset_reg'])).reshape(-1) / depth_itv.reshape(-1) # BHW
            det_offset_loss = F.smooth_l1_loss(pred_offset[mask], offset_gt[mask], reduction='mean')
            det_offset_losses.append(det_offset_loss)
            this_stage_loss += det_offset_loss
        else:
            det_offset_losses.append(torch.tensor(0.0, dtype=torch.float32, device=mask_ms["stage1"].device, requires_grad=False))

        if depth_loss_weights is not None:
            stage_idx = int(stage_key.replace("stage", "")) - 1
            total_loss += depth_loss_weights[stage_idx] * this_stage_loss
        else:
            total_loss += 1.0 * this_stage_loss

    if visual:
        axs[1].set_title('err848:{}'.format(err_848_str) + 'err_106:{}'.format(err_106_str) + 'err_002:{}'.format(err_002_str))
        f.savefig('/mnt/cfs/algorithm/xiaofeng.wang/jeff/code/MVS/cascade-stereo/CasMVSNet/debug_figs/offset_heatmap.png')
        f.clf()
        f2.savefig('/mnt/cfs/algorithm/xiaofeng.wang/jeff/code/MVS/cascade-stereo/CasMVSNet/debug_figs/attn_max_heatmap.png')
        f2.clf()
        f3.savefig('/mnt/cfs/algorithm/xiaofeng.wang/jeff/code/MVS/cascade-stereo/CasMVSNet/debug_figs/attn_gt_heatmap.png')
        f3.clf()
        f4.savefig('/mnt/cfs/algorithm/xiaofeng.wang/jeff/code/MVS/cascade-stereo/CasMVSNet/debug_figs/max_gt_offset_heatmap.png')
        f4.clf()

    return total_loss, depth_loss, stage_depth_loss, stage_ce_loss, range_err_ratio, upsampled_depth_losses, det_offset_losses


def get_cur_depth_range_samples(cur_depth, ndepth, depth_inteval_pixel, shape, max_depth=192.0, min_depth=0.0):
    #shape, (B, H, W)
    #cur_depth: (B, H, W)
    #return depth_range_values: (B, D, H, W)
    cur_depth_min = (cur_depth - ndepth / 2 * depth_inteval_pixel)  # (B, H, W)
    cur_depth_max = (cur_depth + ndepth / 2 * depth_inteval_pixel)
    # cur_depth_min = (cur_depth - ndepth / 2 * depth_inteval_pixel).clamp(min=0.0)   #(B, H, W)
    # cur_depth_max = (cur_depth_min + (ndepth - 1) * depth_inteval_pixel).clamp(max=max_depth)

    assert cur_depth.shape == torch.Size(shape), "cur_depth:{}, input shape:{}".format(cur_depth.shape, shape)
    new_interval = (cur_depth_max - cur_depth_min) / (ndepth - 1)  # (B, H, W)

    depth_range_samples = cur_depth_min.unsqueeze(1) + (torch.arange(0, ndepth, device=cur_depth.device,
                                                                  dtype=cur_depth.dtype,
                                                                  requires_grad=False).reshape(1, -1, 1, 1) * new_interval.unsqueeze(1))

    return depth_range_samples


def get_depth_range_samples(cur_depth, ndepth, depth_inteval_pixel, device, dtype, shape,
                           max_depth=192.0, min_depth=0.0):
    #shape: (B, H, W)
    #cur_depth: (B, H, W) or (B, D)
    #return depth_range_samples: (B, D, H, W)
    if cur_depth.dim() == 2:
        cur_depth_min = cur_depth[:, 0]  # (B,)
        cur_depth_max = cur_depth[:, -1]
        new_interval = (cur_depth_max - cur_depth_min) / (ndepth - 1)  # (B, )

        depth_range_samples = cur_depth_min.unsqueeze(1) + (torch.arange(0, ndepth, device=device, dtype=dtype,
                                                                       requires_grad=False).reshape(1, -1) * new_interval.unsqueeze(1)) #(B, D)

        depth_range_samples = depth_range_samples.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, shape[1], shape[2]) #(B, D, H, W)

    else:

        depth_range_samples = get_cur_depth_range_samples(cur_depth, ndepth, depth_inteval_pixel, shape, max_depth, min_depth)

    return depth_range_samples



if __name__ == "__main__":
    # some testing code, just IGNORE it
    import sys
    sys.path.append("../")
    from datasets import find_dataset_def
    from torch.utils.data import DataLoader
    import numpy as np
    import cv2
    import matplotlib as mpl
    mpl.use('Agg')
    import matplotlib.pyplot as plt

    # MVSDataset = find_dataset_def("colmap")
    # dataset = MVSDataset("../data/results/ford/num10_1/", 3, 'test',
    #                      128, interval_scale=1.06, max_h=1250, max_w=1024)

    MVSDataset = find_dataset_def("dtu_yao")
    num_depth = 48
    dataset = MVSDataset("../data/DTU/mvs_training/dtu/", '../lists/dtu/train.txt', 'train',
                         3, num_depth, interval_scale=1.06 * 192 / num_depth)

    dataloader = DataLoader(dataset, batch_size=1)
    item = next(iter(dataloader))

    imgs = item["imgs"][:, :, :, ::4, ::4]  #(B, N, 3, H, W)
    # imgs = item["imgs"][:, :, :, :, :]
    proj_matrices = item["proj_matrices"]   #(B, N, 2, 4, 4) dim=N: N view; dim=2: index 0 for extr, 1 for intric
    proj_matrices[:, :, 1, :2, :] = proj_matrices[:, :, 1, :2, :]
    # proj_matrices[:, :, 1, :2, :] = proj_matrices[:, :, 1, :2, :] * 4
    depth_values = item["depth_values"]     #(B, D)

    imgs = torch.unbind(imgs, 1)
    proj_matrices = torch.unbind(proj_matrices, 1)
    ref_img, src_imgs = imgs[0], imgs[1:]
    ref_proj, src_proj = proj_matrices[0], proj_matrices[1:][0]  #only vis first view

    src_proj_new = src_proj[:, 0].clone()
    src_proj_new[:, :3, :4] = torch.matmul(src_proj[:, 1, :3, :3], src_proj[:, 0, :3, :4])
    ref_proj_new = ref_proj[:, 0].clone()
    ref_proj_new[:, :3, :4] = torch.matmul(ref_proj[:, 1, :3, :3], ref_proj[:, 0, :3, :4])

    warped_imgs = homo_warping(src_imgs[0], src_proj_new, ref_proj_new, depth_values)

    ref_img_np = ref_img.permute([0, 2, 3, 1])[0].detach().cpu().numpy()[:, :, ::-1] * 255
    cv2.imwrite('../tmp/ref.png', ref_img_np)
    cv2.imwrite('../tmp/src.png', src_imgs[0].permute([0, 2, 3, 1])[0].detach().cpu().numpy()[:, :, ::-1] * 255)

    for i in range(warped_imgs.shape[2]):
        warped_img = warped_imgs[:, :, i, :, :].permute([0, 2, 3, 1]).contiguous()
        img_np = warped_img[0].detach().cpu().numpy()
        img_np = img_np[:, :, ::-1] * 255

        alpha = 0.5
        beta = 1 - alpha
        gamma = 0
        img_add = cv2.addWeighted(ref_img_np, alpha, img_np, beta, gamma)
        cv2.imwrite('../tmp/tmp{}.png'.format(i), np.hstack([ref_img_np, img_np, img_add])) #* ratio + img_np*(1-ratio)]))

================================================
FILE: models/mvs4net_utils.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import importlib
try:
    from modules.deform_conv import DeformConvPack
except:
    print('DeformConvPack not found, please install it from: https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch')
    pass
import math
import numpy as np

def homo_warping(src_fea, src_proj, ref_proj, depth_values, vis_ETA=False, fn=None):
    # src_fea: [B, C, H, W]
    # src_proj: [B, 4, 4]
    # ref_proj: [B, 4, 4]
    # depth_values: [B, Ndepth] o [B, Ndepth, H, W]
    # out: [B, C, Ndepth, H, W]
    C = src_fea.shape[1]
    Hs,Ws = src_fea.shape[-2:]
    B,num_depth,Hr,Wr = depth_values.shape

    with torch.no_grad():
        proj = torch.matmul(src_proj, torch.inverse(ref_proj))
        rot = proj[:, :3, :3]  # [B,3,3]
        trans = proj[:, :3, 3:4]  # [B,3,1]

        y, x = torch.meshgrid([torch.arange(0, Hr, dtype=torch.float32, device=src_fea.device),
                               torch.arange(0, Wr, dtype=torch.float32, device=src_fea.device)])
        y = y.reshape(Hr*Wr)
        x = x.reshape(Hr*Wr)
        xyz = torch.stack((x, y, torch.ones_like(x)))  # [3, H*W]
        xyz = torch.unsqueeze(xyz, 0).repeat(B, 1, 1)  # [B, 3, H*W]
        rot_xyz = torch.matmul(rot, xyz)  # [B, 3, H*W]
        rot_depth_xyz = rot_xyz.unsqueeze(2).repeat(1, 1, num_depth, 1) * depth_values.reshape(B, 1, num_depth, -1)  # [B, 3, Ndepth, H*W]
        proj_xyz = rot_depth_xyz + trans.reshape(B, 3, 1, 1)  # [B, 3, Ndepth, H*W]
        # FIXME divide 0
        temp = proj_xyz[:, 2:3, :, :]
        temp[temp==0] = 1e-9
        proj_xy = proj_xyz[:, :2, :, :] / temp  # [B, 2, Ndepth, H*W]
        # proj_xy = proj_xyz[:, :2, :, :] / proj_xyz[:, 2:3, :, :]  # [B, 2, Ndepth, H*W]

        proj_x_normalized = proj_xy[:, 0, :, :] / ((Ws - 1) / 2) - 1
        proj_y_normalized = proj_xy[:, 1, :, :] / ((Hs - 1) / 2) - 1
        proj_xy = torch.stack((proj_x_normalized, proj_y_normalized), dim=3)  # [B, Ndepth, H*W, 2]
        if vis_ETA:
            tensor_saved = proj_xy.reshape(B,num_depth,Hs,Ws,2).cpu().numpy()
            np.save(fn+'_grid', tensor_saved)
        grid = proj_xy
    if len(src_fea.shape)==4:
        warped_src_fea = F.grid_sample(src_fea, grid.reshape(B, num_depth * Hr, Wr, 2), mode='bilinear', padding_mode='zeros', align_corners=True)
        warped_src_fea = warped_src_fea.reshape(B, C, num_depth, Hr, Wr)
    elif len(src_fea.shape)==5:
        warped_src_fea = []
        for d in range(src_fea.shape[2]):
            warped_src_fea.append(F.grid_sample(src_fea[:,:,d], grid.reshape(B, num_depth, Hr, Wr, 2)[:,d], mode='bilinear', padding_mode='zeros', align_corners=True))
        warped_src_fea = torch.stack(warped_src_fea, dim=2)

    return warped_src_fea

def init_range(cur_depth, ndepths, device, dtype, H, W):
    cur_depth_min = cur_depth[:, 0]  # (B,)
    cur_depth_max = cur_depth[:, -1]
    new_interval = (cur_depth_max - cur_depth_min) / (ndepths - 1)  # (B, )
    new_interval = new_interval[:, None, None]  # B H W
    depth_range_samples = cur_depth_min.unsqueeze(1) + (torch.arange(0, ndepths, device=device, dtype=dtype,
                                                                requires_grad=False).reshape(1, -1) * new_interval.squeeze(1)) #(B, D)
    depth_range_samples = depth_range_samples.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, H, W) #(B, D, H, W)
    return depth_range_samples

def init_inverse_range(cur_depth, ndepths, device, dtype, H, W):
    inverse_depth_min = 1. / cur_depth[:, 0]  # (B,)
    inverse_depth_max = 1. / cur_depth[:, -1]
    itv = torch.arange(0, ndepths, device=device, dtype=dtype, requires_grad=False).reshape(1, -1,1,1).repeat(1, 1, H, W)  / (ndepths - 1)  # 1 D H W
    inverse_depth_hypo = inverse_depth_max[:,None, None, None] + (inverse_depth_min - inverse_depth_max)[:,None, None, None] * itv

    return 1./inverse_depth_hypo

def schedule_inverse_range(inverse_min_depth, inverse_max_depth, ndepths, H, W):
    #cur_depth_min, (B, H, W)
    #cur_depth_max: (B, H, W)
    itv = torch.arange(0, ndepths, device=inverse_min_depth.device, dtype=inverse_min_depth.dtype, requires_grad=False).reshape(1, -1,1,1).repeat(1, 1, H//2, W//2)  / (ndepths - 1)  # 1 D H W

    inverse_depth_hypo = inverse_max_depth[:,None, :, :] + (inverse_min_depth - inverse_max_depth)[:,None, :, :] * itv  # B D H W
    inverse_depth_hypo = F.interpolate(inverse_depth_hypo.unsqueeze(1), [ndepths, H, W], mode='trilinear', align_corners=True).squeeze(1)
    return 1./inverse_depth_hypo

def schedule_range(cur_depth, ndepth, depth_inteval_pixel, H, W):
    #shape, (B, H, W)
    #cur_depth: (B, H, W)
    #return depth_range_values: (B, D, H, W)
    cur_depth_min = (cur_depth - ndepth / 2 * depth_inteval_pixel[:,None,None])  # (B, H, W)
    cur_depth_max = (cur_depth + ndepth / 2 * depth_inteval_pixel[:,None,None])
    new_interval = (cur_depth_max - cur_depth_min) / (ndepth - 1)  # (B, H, W)

    depth_range_samples = cur_depth_min.unsqueeze(1) + (torch.arange(0, ndepth, device=cur_depth.device, dtype=cur_depth.dtype,
                                                                  requires_grad=False).reshape(1, -1, 1, 1) * new_interval.unsqueeze(1))
    depth_range_samples = F.interpolate(depth_range_samples.unsqueeze(1), [ndepth, H, W], mode='trilinear', align_corners=True).squeeze(1)
    return depth_range_samples

def init_bn(module):
    if module.weight is not None:
        nn.init.ones_(module.weight)
    if module.bias is not None:
        nn.init.zeros_(module.bias)
    return

def init_uniform(module, init_method):
    if module.weight is not None:
        if init_method == "kaiming":
            nn.init.kaiming_uniform_(module.weight)
        elif init_method == "xavier":
            nn.init.xavier_uniform_(module.weight)
    return

class ConvBnReLU3D(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1):
        super(ConvBnReLU3D, self).__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False)
        self.bn = nn.BatchNorm3d(out_channels)

    def forward(self, x):
        return F.relu(self.bn(self.conv(x)), inplace=True)

class ConvBnReLU3D_CAM(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1):
        super(ConvBnReLU3D_CAM, self).__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False)
        self.bn = nn.BatchNorm3d(out_channels)
        self.linear_agg = nn.Sequential(
            nn.Linear(out_channels, out_channels//2),
            nn.ReLU(),
            nn.Linear(out_channels//2, out_channels)
        )

    def forward(self, input):
        x = self.conv(input)
        B,C,D,H,W = x.shape
        avg_attn = self.linear_agg(x.reshape(B,C,D*H*W).mean(2))
        max_attn = self.linear_agg(x.reshape(B,C,D*H*W).max(2)[0])  # B C
        attn = F.sigmoid(max_attn+avg_attn)[:,:,None,None,None]  # B C,1,1,1
        x = x * attn
        return F.relu(self.bn(x+input), inplace=True)

class ConvBnReLU3D_DCAM(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1):
        super(ConvBnReLU3D_DCAM, self).__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False)
        self.bn = nn.BatchNorm3d(out_channels)
        self.linear_agg = nn.Sequential(
            nn.Linear(out_channels, out_channels//2),
            nn.ReLU(),
            nn.Linear(out_channels//2, out_channels)
        )

    def forward(self, input):
        x = self.conv(input)
        B,C,D,H,W = x.shape
        avg_attn = self.linear_agg(x.reshape(B,C,D,H*W).mean(3).permute(0,2,1).reshape(B*D,C)).reshape(B,D,C).permute(0,2,1)
        max_attn = self.linear_agg(x.reshape(B,C,D,H*W).max(3)[0].permute(0,2,1).reshape(B*D,C)).reshape(B,D,C).permute(0,2,1)  # B C D
        attn = F.sigmoid(max_attn+avg_attn)[:,:,:,None,None]  # B C,D,1,1
        x = x * attn
        return F.relu(self.bn(x+input), inplace=True)

class ConvBnReLU3D_PAM(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1):
        super(ConvBnReLU3D_PAM, self).__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False)
        self.bn = nn.BatchNorm3d(out_channels)
        self.pixel_conv = nn.Conv2d(2,1,7,stride=1,padding='same')

    def forward(self, input):
        x = self.conv(input)
        B,C,D,H,W = x.shape
        max_attn = x.reshape(B,C*D,H,W).max(1, keepdim=True)[0]
        avg_attn = x.reshape(B,C*D,H,W).mean(1, keepdim=True)  # B 1 H W
        attn = F.sigmoid(self.pixel_conv(torch.cat([max_attn, avg_attn], dim=1)))[:,:,None,:,:]  # B 1,1,H,W
        x = x * attn
        return F.relu(self.bn(x+input), inplace=True)

class ConvBnReLU3D_PDAM(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1):
        super(ConvBnReLU3D_PDAM, self).__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False)
        self.bn = nn.BatchNorm3d(out_channels)
        self.spatial_conv = nn.Conv3d(2,1,7,stride=1,padding='same')

    def forward(self, input):
        x = self.conv(input)
        B,C,D,H,W = x.shape
        max_attn = x.max(1, keepdim=True)[0]
        avg_attn = x.mean(1, keepdim=True)  # B 1 D H W
        attn = F.sigmoid(self.spatial_conv(torch.cat([max_attn, avg_attn], dim=1)))  # B 1,D,H,W
        x = x * attn
        return F.relu(self.bn(x+input), inplace=True)

class Deconv3d(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
                 relu=True, bn=True, bn_momentum=0.1, init_method="xavier", **kwargs):
        super(Deconv3d, self).__init__()
        self.out_channels = out_channels
        assert stride in [1, 2]
        self.stride = stride

        self.conv = nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride=stride,
                                       bias=(not bn), **kwargs)
        self.bn = nn.BatchNorm3d(out_channels, momentum=bn_momentum) if bn else None
        self.relu = relu

    def forward(self, x):
        y = self.conv(x)
        if self.bn is not None:
            x = self.bn(y)
        if self.relu:
            x = F.relu(x, inplace=True)
        return x

    def init_weights(self, init_method):
        init_uniform(self.conv, init_method)
        if self.bn is not None:
            init_bn(self.bn)

class Conv2d(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 relu=True, bn_momentum=0.1, init_method="xavier", gn=False, group_channel=8, **kwargs):
        super(Conv2d, self).__init__()
        bn = not gn
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride,
                              bias=(not bn), **kwargs)
        self.kernel_size = kernel_size
        self.stride = stride
        self.bn = nn.BatchNorm2d(out_channels, momentum=bn_momentum) if bn else None
        self.gn = nn.GroupNorm(int(max(1, out_channels / group_channel)), out_channels) if gn else None
        self.relu = relu

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        else:
            x = self.gn(x)
        if self.relu:
            x = F.relu(x, inplace=True)
        return x

    def init_weights(self, init_method):
        init_uniform(self.conv, init_method)
        if self.bn is not None:
            init_bn(self.bn)

class Deconv2d(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
                 relu=True, bn=True, bn_momentum=0.1, init_method="xavier", **kwargs):
        super(Deconv2d, self).__init__()
        self.out_channels = out_channels
        assert stride in [1, 2]
        self.stride = stride

        self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride,
                                       bias=(not bn), **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, momentum=bn_momentum) if bn else None
        self.relu = relu

class DeformConv2d(nn.Module):
    def __init__(self, inc, outc, kernel_size=3, padding=1, stride=1, bias=None, modulation=True):
        super(DeformConv2d, self).__init__()
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        self.zero_padding = nn.ZeroPad2d(padding)
        self.conv = nn.Conv2d(inc, outc, kernel_size=kernel_size, stride=kernel_size, bias=bias)

        self.p_conv = nn.Conv2d(inc, 2*kernel_size*kernel_size, kernel_size=3, padding=1, stride=stride)
        nn.init.constant_(self.p_conv.weight, 0)
        self.p_conv.register_backward_hook(self._set_lr)

        self.modulation = modulation
        if modulation:
            self.m_conv = nn.Conv2d(inc, kernel_size*kernel_size, kernel_size=3, padding=1, stride=stride)
            nn.init.constant_(self.m_conv.weight, 0)
            self.m_conv.register_backward_hook(self._set_lr)

    @staticmethod
    def _set_lr(module, grad_input, grad_output):
        grad_input = (grad_input[i] * 0.1 for i in range(len(grad_input)))
        grad_output = (grad_output[i] * 0.1 for i in range(len(grad_output)))

    def forward(self, x):
        offset = self.p_conv(x)
        if self.modulation:
            m = torch.sigmoid(self.m_conv(x))

        dtype = offset.data.type()
        ks = self.kernel_size
        N = offset.size(1) // 2

        if self.padding:
            x = self.zero_padding(x)

        # (b, 2N, h, w)
        p = self._get_p(offset, dtype)

        # (b, h, w, 2N)
        p = p.contiguous().permute(0, 2, 3, 1)
        q_lt = p.detach().floor()
        q_rb = q_lt + 1

        q_lt = torch.cat([torch.clamp(q_lt[..., :N], 0, x.size(2)-1), torch.clamp(q_lt[..., N:], 0, x.size(3)-1)], dim=-1).long()
        q_rb = torch.cat([torch.clamp(q_rb[..., :N], 0, x.size(2)-1), torch.clamp(q_rb[..., N:], 0, x.size(3)-1)], dim=-1).long()
        q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], dim=-1)
        q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], dim=-1)

        # clip p
        p = torch.cat([torch.clamp(p[..., :N], 0, x.size(2)-1), torch.clamp(p[..., N:], 0, x.size(3)-1)], dim=-1)

        # bilinear kernel (b, h, w, N)
        g_lt = (1 + (q_lt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_lt[..., N:].type_as(p) - p[..., N:]))
        g_rb = (1 - (q_rb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_rb[..., N:].type_as(p) - p[..., N:]))
        g_lb = (1 + (q_lb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_lb[..., N:].type_as(p) - p[..., N:]))
        g_rt = (1 - (q_rt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_rt[..., N:].type_as(p) - p[..., N:]))

        # (b, c, h, w, N)
        x_q_lt = self._get_x_q(x, q_lt, N)
        x_q_rb = self._get_x_q(x, q_rb, N)
        x_q_lb = self._get_x_q(x, q_lb, N)
        x_q_rt = self._get_x_q(x, q_rt, N)

        # (b, c, h, w, N)
        x_offset = g_lt.unsqueeze(dim=1) * x_q_lt + \
                   g_rb.unsqueeze(dim=1) * x_q_rb + \
                   g_lb.unsqueeze(dim=1) * x_q_lb + \
                   g_rt.unsqueeze(dim=1) * x_q_rt

        # modulation
        if self.modulation:
            m = m.contiguous().permute(0, 2, 3, 1)
            m = m.unsqueeze(dim=1)
            m = torch.cat([m for _ in range(x_offset.size(1))], dim=1)
            x_offset *= m

        x_offset = self._reshape_x_offset(x_offset, ks)
        out = self.conv(x_offset)

        return out

    def _get_p_n(self, N, dtype):
        p_n_x, p_n_y = torch.meshgrid(
            torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1),
            torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1))
        # (2N, 1)
        p_n = torch.cat([torch.flatten(p_n_x), torch.flatten(p_n_y)], 0)
        p_n = p_n.view(1, 2*N, 1, 1).type(dtype)

        return p_n

    def _get_p_0(self, h, w, N, dtype):
        p_0_x, p_0_y = torch.meshgrid(
            torch.arange(1, h*self.stride+1, self.stride),
            torch.arange(1, w*self.stride+1, self.stride))
        p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1)
        p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1)
        p_0 = torch.cat([p_0_x, p_0_y], 1).type(dtype)

        return p_0

    def _get_p(self, offset, dtype):
        N, h, w = offset.size(1)//2, offset.size(2), offset.size(3)

        # (1, 2N, 1, 1)
        p_n = self._get_p_n(N, dtype)
        # (1, 2N, h, w)
        p_0 = self._get_p_0(h, w, N, dtype)
        p = p_0 + p_n + offset
        return p

    def _get_x_q(self, x, q, N):
        b, h, w, _ = q.size()
        padded_w = x.size(3)
        c = x.size(1)
        # (b, c, h*w)
        x = x.contiguous().view(b, c, -1)

        # (b, h, w, N)
        index = q[..., :N]*padded_w + q[..., N:]  # offset_x*w + offset_y
        # (b, c, h*w*N)
        index = index.contiguous().unsqueeze(dim=1).expand(-1, c, -1, -1, -1).contiguous().view(b, c, -1)

        x_offset = x.gather(dim=-1, index=index).contiguous().view(b, c, h, w, N)

        return x_offset

    @staticmethod
    def _reshape_x_offset(x_offset, ks):
        b, c, h, w, N = x_offset.size()
        x_offset = torch.cat([x_offset[..., s:s+ks].contiguous().view(b, c, h, w*ks) for s in range(0, N, ks)], dim=-1)
        x_offset = x_offset.contiguous().view(b, c, h*ks, w*ks)

        return x_offset

def NA_DCN(in_channels, kernel_size=3, stride=1, dilation=1, bias=True, group_channel=8, gn=False):
    if gn:
        return nn.Sequential(
            nn.GroupNorm(int(max(1, in_channels / group_channel)), in_channels),
            nn.ReLU(inplace=True),
            # DeformConv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride,  bias=bias),
            DeformConvPack(in_channels, in_channels, kernel_size=3, stride=1, padding=1, groups=1, deformable_groups=1, bias=False, im2col_step=16)
        )
    else:
        return nn.Sequential(
            nn.BatchNorm2d(in_channels, momentum=0.1),
            nn.ReLU(inplace=True),
            # DeformConv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride,  bias=bias),
            DeformConvPack(in_channels, in_channels, kernel_size=3, stride=1, padding=1, groups=1, deformable_groups=1, bias=False, im2col_step=16)
        )

class FPN4(nn.Module):
    """
    FPN aligncorners downsample 4x"""
    def __init__(self, base_channels, gn=False, dcn=False):
        super(FPN4, self).__init__()
        self.base_channels = base_channels

        self.conv0 = nn.Sequential(
            Conv2d(3, base_channels, 3, 1, padding=1, gn=gn),
            Conv2d(base_channels, base_channels, 3, 1, padding=1, gn=gn),
        )

        self.conv1 = nn.Sequential(
            Conv2d(base_channels, base_channels * 2, 5, stride=2, padding=2, gn=gn),
            Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1, gn=gn),
            Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1, gn=gn),
        )

        self.conv2 = nn.Sequential(
            Conv2d(base_channels * 2, base_channels * 4, 5, stride=2, padding=2, gn=gn),
            Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1, gn=gn),
            Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1, gn=gn),
        )

        self.conv3 = nn.Sequential(
            Conv2d(base_channels * 4, base_channels * 8, 5, stride=2, padding=2, gn=gn),
            Conv2d(base_channels * 8, base_channels * 8, 3, 1, padding=1, gn=gn),
            Conv2d(base_channels * 8, base_channels * 8, 3, 1, padding=1, gn=gn),
        )

        self.out_channels = [8 * base_channels]
        final_chs = base_channels * 8

        self.inner1 = nn.Conv2d(base_channels * 4, final_chs, 1, bias=True)
        self.inner2 = nn.Conv2d(base_channels * 2, final_chs, 1, bias=True)
        self.inner3 = nn.Conv2d(base_channels * 1, final_chs, 1, bias=True)

        self.out1 = nn.Conv2d(final_chs, base_channels * 8, 1, bias=False)
        self.out2 = nn.Conv2d(final_chs, base_channels * 4, 3, padding=1, bias=False)
        self.out3 = nn.Conv2d(final_chs, base_channels * 2, 3, padding=1, bias=False)
        self.out4 = nn.Conv2d(final_chs, base_channels, 3, padding=1, bias=False)

        self.dcn = dcn
        if self.dcn:
            self.dcn1 = NA_DCN(base_channels * 8, 3, gn=gn)
            self.dcn2 = NA_DCN(base_channels * 4, 3, gn=gn)
            self.dcn3 = NA_DCN(base_channels * 2, 3, gn=gn)
            self.dcn4 = NA_DCN(base_channels * 1, 3, gn=gn)

        self.out_channels.append(base_channels * 4)
        self.out_channels.append(base_channels * 2)
        self.out_channels.append(base_channels)

    def forward(self, x):
        conv0 = self.conv0(x)
        conv1 = self.conv1(conv0)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)

        intra_feat = conv3
        outputs = {}
        out1 = self.out1(intra_feat)

        intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner1(conv2)
        out2 = self.out2(intra_feat)

        intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner2(conv1)
        out3 = self.out3(intra_feat)

        intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner3(conv0)
        out4 = self.out4(intra_feat)

        if self.dcn:
            out1 = self.dcn1(out1)
            out2 = self.dcn2(out2)
            out3 = self.dcn3(out3)
            out4 = self.dcn4(out4)

        outputs["stage1"] = out1
        outputs["stage2"] = out2
        outputs["stage3"] = out3
        outputs["stage4"] = out4

        return outputs

class LayerNorm(nn.Module):

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError 
        self.normalized_shape = (normalized_shape, )
    
    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x

class convnext_block(nn.Module):

    def __init__(self, dim, layer_scale_init_value=1e-6):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, 2*dim, kernel_size=7, stride=2, padding=3, groups=dim) # depthwise conv
        self.norm = LayerNorm(2*dim, eps=1e-6)
        self.pwconv1 = nn.Linear(2*dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, 2*dim)
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((2*dim)), 
                                    requires_grad=True) if layer_scale_init_value > 0 else None

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)

        # x = input + x
        return x

class convnext4_block(nn.Module):

    def __init__(self, dim, layer_scale_init_value=1e-6):
        super().__init__()
        self.sconv = nn.Conv2d(dim, 2*dim, kernel_size=2, stride=2, padding=0) # stride=2 conv
        self.dwconv = nn.Conv2d(2*dim, 2*dim, kernel_size=7, stride=1, padding=3, groups=dim) # depthwise conv
        self.norm = LayerNorm(2*dim, eps=1e-6)
        self.pwconv1 = nn.Linear(2*dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, 2*dim)
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((2*dim)), 
                                    requires_grad=True) if layer_scale_init_value > 0 else None

    def forward(self, x):
        input = self.sconv(x)
        x = self.dwconv(input)
        x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)

        x = input + x
        return x

class FPN4_convnext(nn.Module):
    """
    FPN aligncorners downsample 4x"""
    def __init__(self, base_channels, gn=False, dcn=False):
        super(FPN4_convnext, self).__init__()
        self.base_channels = base_channels

        self.conv0 = nn.Sequential(
            Conv2d(3, base_channels, 3, 1, padding=1, gn=gn),
            Conv2d(base_channels, base_channels, 3, 1, padding=1, gn=gn),
        )

        self.conv1 = convnext_block(base_channels)
        self.conv2 = convnext_block(2*base_channels)
        self.conv3 = convnext_block(4*base_channels)

        self.out_channels = [8 * base_channels]
        final_chs = base_channels * 8

        self.inner1 = nn.Conv2d(base_channels * 4, final_chs, 1, bias=True)
        self.inner2 = nn.Conv2d(base_channels * 2, final_chs, 1, bias=True)
        self.inner3 = nn.Conv2d(base_channels * 1, final_chs, 1, bias=True)

        self.out1 = nn.Conv2d(final_chs, base_channels * 8, 1, bias=False)
        self.out2 = nn.Conv2d(final_chs, base_channels * 4, 3, padding=1, bias=False)
        self.out3 = nn.Conv2d(final_chs, base_channels * 2, 3, padding=1, bias=False)
        self.out4 = nn.Conv2d(final_chs, base_channels, 3, padding=1, bias=False)

        self.dcn = dcn
        if self.dcn:
            self.dcn1 = NA_DCN(base_channels * 8, 3, gn=gn)
            self.dcn2 = NA_DCN(base_channels * 4, 3, gn=gn)
            self.dcn3 = NA_DCN(base_channels * 2, 3, gn=gn)
            self.dcn4 = NA_DCN(base_channels * 1, 3, gn=gn)

        self.out_channels.append(base_channels * 4)
        self.out_channels.append(base_channels * 2)
        self.out_channels.append(base_channels)

    def forward(self, x):
        conv0 = self.conv0(x)
        conv1 = self.conv1(conv0)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)

        intra_feat = conv3
        outputs = {}
        out1 = self.out1(intra_feat)

        intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner1(conv2)
        out2 = self.out2(intra_feat)

        intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner2(conv1)
        out3 = self.out3(intra_feat)

        intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner3(conv0)
        out4 = self.out4(intra_feat)

        if self.dcn:
            out1 = self.dcn1(out1)
            out2 = self.dcn2(out2)
            out3 = self.dcn3(out3)
            out4 = self.dcn4(out4)

        outputs["stage1"] = out1
        outputs["stage2"] = out2
        outputs["stage3"] = out3
        outputs["stage4"] = out4

        return outputs

class FPN4_convnext4(nn.Module):
    """
    FPN aligncorners downsample 4x"""
    def __init__(self, base_channels, gn=False, dcn=False):
        super(FPN4_convnext4, self).__init__()
        self.base_channels = base_channels

        self.conv0 = nn.Sequential(
            Conv2d(3, base_channels, 3, 1, padding=1, gn=gn),
            Conv2d(base_channels, base_channels, 3, 1, padding=1, gn=gn),
        )

        self.conv1 = convnext4_block(base_channels)
        self.conv2 = convnext4_block(2*base_channels)
        self.conv3 = convnext4_block(4*base_channels)

        self.out_channels = [8 * base_channels]
        final_chs = base_channels * 8

        self.inner1 = nn.Conv2d(base_channels * 4, final_chs, 1, bias=True)
        self.inner2 = nn.Conv2d(base_channels * 2, final_chs, 1, bias=True)
        self.inner3 = nn.Conv2d(base_channels * 1, final_chs, 1, bias=True)

        self.out1 = nn.Conv2d(final_chs, base_channels * 8, 1, bias=False)
        self.out2 = nn.Conv2d(final_chs, base_channels * 4, 3, padding=1, bias=False)
        self.out3 = nn.Conv2d(final_chs, base_channels * 2, 3, padding=1, bias=False)
        self.out4 = nn.Conv2d(final_chs, base_channels, 3, padding=1, bias=False)

        self.dcn = dcn
        if self.dcn:
            self.dcn1 = NA_DCN(base_channels * 8, 3, gn=gn)
            self.dcn2 = NA_DCN(base_channels * 4, 3, gn=gn)
            self.dcn3 = NA_DCN(base_channels * 2, 3, gn=gn)
            self.dcn4 = NA_DCN(base_channels * 1, 3, gn=gn)

        self.out_channels.append(base_channels * 4)
        self.out_channels.append(base_channels * 2)
        self.out_channels.append(base_channels)

    def forward(self, x):
        conv0 = self.conv0(x)
        conv1 = self.conv1(conv0)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)

        intra_feat = conv3
        outputs = {}
        out1 = self.out1(intra_feat)

        intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner1(conv2)
        out2 = self.out2(intra_feat)

        intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner2(conv1)
        out3 = self.out3(intra_feat)

        intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear", align_corners=True) + self.inner3(conv0)
        out4 = self.out4(intra_feat)

        if self.dcn:
            out1 = self.dcn1(out1)
            out2 = self.dcn2(out2)
            out3 = self.dcn3(out3)
            out4 = self.dcn4(out4)

        outputs["stage1"] = out1
        outputs["stage2"] = out2
        outputs["stage3"] = out3
        outputs["stage4"] = out4

        return outputs
    
class ASFF(nn.Module):
    def __init__(self, level):
        super(ASFF, self).__init__()
        self.level = level
        self.dim = [64,32,16,8]
        self.inter_dim = self.dim[self.level]
        if level==0:
            self.stride_level_1 = Conv2d(32, 64, 3, stride=2, padding=1)
            self.stride_level_2 = Conv2d(16, 64, 3, stride=2, padding=1)
            self.stride_level_3 = Conv2d(8, 64, 3, stride=2, padding=1)
            self.expand = Conv2d(64, 64, 3, stride=1, padding=1)
        elif level==1:
            self.compress_level_0 =  Conv2d(64, 32, 1, stride=1, padding=0)
            self.stride_level_2 = Conv2d(16, 32, 3, stride=2, padding=1)
            self.stride_level_3 = Conv2d(8, 32, 3, stride=2, padding=1)
            self.expand = Conv2d(32, 32, 3, stride=1, padding=1)
        elif level==2:
            self.compress_level_0 = Conv2d(64, 16, 1, stride=1, padding=0)
            self.compress_level_1 = Conv2d(32, 16, 1, stride=1, padding=0)
            self.stride_level_3 = Conv2d(8, 16, 3, stride=2, padding=1)
            self.expand = Conv2d(16, 16, 3, stride=1, padding=1)
        elif level==3:
            self.compress_level_0 = Conv2d(64, 8, 1, stride=1, padding=0)
            self.compress_level_1 = Conv2d(32, 8, 1, stride=1, padding=0)
            self.compress_level_2 = Conv2d(16, 8, 1, stride=1, padding=0)
            self.expand = Conv2d(8, 8, 3, stride=1, padding=1)

        self.weight_level_0 = Conv2d(self.dim[level], 8, 1, 1, 0)
        self.weight_level_1 = Conv2d(self.dim[level], 8, 1, 1, 0)
        self.weight_level_2 = Conv2d(self.dim[level], 8, 1, 1, 0)
        self.weight_level_3 = Conv2d(self.dim[level], 8, 1, 1, 0)

        self.weight_levels = nn.Conv2d(32, 4, kernel_size=1, stride=1, padding=0)


    def forward(self, x_level_0, x_level_1, x_level_2, x_level_3):
        if self.level==0:
            level_0_resized = x_level_0
            level_1_resized = self.stride_level_1(x_level_1)
            level_2_downsampled_inter = F.max_pool2d(x_level_2, 2, stride=2, padding=0)
            level_2_resized = self.stride_level_2(level_2_downsampled_inter)
            level_3_downsampled_inter = F.max_pool2d(x_level_3, 4, stride=4, padding=0)
            level_3_resized = self.stride_level_3(level_3_downsampled_inter)

        elif self.level==1:
            level_0_compressed = self.compress_level_0(x_level_0)
            level_0_resized = F.interpolate(level_0_compressed, scale_factor=2, mode='nearest')
            level_1_resized = x_level_1
            level_2_resized = self.stride_level_2(x_level_2)
            level_3_downsampled_inter = F.max_pool2d(x_level_3, 2, stride=2, padding=0)
            level_3_resized = self.stride_level_3(level_3_downsampled_inter)
        elif self.level==2:
            level_0_compressed = self.compress_level_0(x_level_0)
            level_0_resized = F.interpolate(level_0_compressed, scale_factor=4, mode='nearest')
            level_1_compressed = self.compress_level_1(x_level_1)
            level_1_resized = F.interpolate(level_1_compressed, scale_factor=2, mode='nearest')
            level_2_resized = x_level_2
            level_3_resized = self.stride_level_3(x_level_3)
        elif self.level==3:
            level_0_compressed = self.compress_level_0(x_level_0)
            level_0_resized = F.interpolate(level_0_compressed, scale_factor=8, mode='nearest')
            level_1_compressed = self.compress_level_1(x_level_1)
            level_1_resized = F.interpolate(level_1_compressed, scale_factor=4, mode='nearest')
            level_2_compressed = self.compress_level_2(x_level_2)
            level_2_resized = F.interpolate(level_2_compressed, scale_factor=2, mode='nearest')
            level_3_resized = x_level_3

        level_0_weight_v = self.weight_level_0(level_0_resized)
        level_1_weight_v = self.weight_level_1(level_1_resized)
        level_2_weight_v = self.weight_level_2(level_2_resized)
        level_3_weight_v = self.weight_level_3(level_3_resized)
        levels_weight_v = torch.cat((level_0_weight_v, level_1_weight_v, level_2_weight_v, level_3_weight_v),1)
        levels_weight = self.weight_levels(levels_weight_v)
        levels_weight = F.softmax(levels_weight, dim=1)

        fused_out_reduced = level_0_resized * levels_weight[:,0:1,:,:]+\
                            level_1_resized * levels_weight[:,1:2,:,:]+\
                            level_2_resized * levels_weight[:,2:3,:,:]+\
                            level_3_resized * levels_weight[:,3:,:,:]

        out = self.expand(fused_out_reduced)

        return out

class FullImageEncoder(nn.Module):
    def __init__(self, h, w, kernel_size):
        super(FullImageEncoder, self).__init__()
        self.global_pooling = nn.AvgPool2d(kernel_size, stride=kernel_size, padding=kernel_size // 2)  # KITTI 16 16
        self.dropout = nn.Dropout2d(p=0.5)
        self.h = h // kernel_size + 1
        self.w = w // kernel_size + 1
        # print("h=", self.h, " w=", self.w, h, w)
        self.global_fc = nn.Linear(2048 * self.h * self.w, 512)  # kitti 4x5
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(512, 512, 1)  # 1x1 卷积

    def forward(self, x):
        # print('x size:', x.size())
        x1 = self.global_pooling(x)
        # print('# x1 size:', x1.size())
        x2 = self.dropout(x1)
        x3 = x2.view(-1, 2048 * self.h * self.w)  # kitti 4x5
        x4 = self.relu(self.global_fc(x3))
        # print('# x4 size:', x4.size())
        x4 = x4.view(-1, 512, 1, 1)
        # print('# x4 size:', x4.size())
        x5 = self.conv1(x4)
        # out = self.upsample(x5)
        return x5

class mono_depth_decoder(nn.Module):

    def __init__(self):
        super(mono_depth_decoder, self).__init__()
        self.convblocks = nn.ModuleList(
            [Conv2d(64, 32, 3, 1, padding=1),
            Conv2d(32, 16, 3, 1, padding=1),
            Conv2d(16, 8, 3, 1, padding=1)]
        )
        self.conv3x3 = nn.ModuleList(
           [nn.Conv2d(64, 1, 3, 1, 1),
            nn.Conv2d(32, 1, 3, 1, 1),
            nn.Conv2d(16, 1, 3, 1, 1)]
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, outputs, d_min, d_max):
        """
        d_max: B
        """
        for i in range(1,4):  # 1 2 3
            mono_small_feat = outputs['stage{}'.format(i)]['mono_feat']
            mono_large_feat = outputs['stage{}'.format(i+1)]['mono_feat']

            mono_small_feat = self.convblocks[i-1](mono_small_feat)
            mono_small_feat = F.interpolate(mono_small_feat, scale_factor=2, mode="nearest")

            mono_feat = self.conv3x3[i-1](torch.cat([mono_small_feat, mono_large_feat], 1))  # B C H W

            disp = self.sigmoid(mono_feat)
            min_disp = (1 / d_max)[:,None,None,None]  # B 1 1 1
            max_disp = (1 / d_min)[:,None,None,None]
            scaled_disp = min_disp + (max_disp - min_disp) * disp
            depth = 1 / scaled_disp
            outputs['stage{}'.format(i+1)]['mono_depth'] = depth.squeeze(1)
        return outputs

class reg2d(nn.Module):
    def __init__(self, input_channel=128, base_channel=32, conv_name='ConvBnReLU3D'):
        super(reg2d, self).__init__()
        module = importlib.import_module("models.mvs4net_utils")
        stride_conv_name = 'ConvBnReLU3D'
        self.conv0 = getattr(module, stride_conv_name)(input_channel, base_channel, kernel_size=(1,3,3), pad=(0,1,1))
        self.conv1 = getattr(module, stride_conv_name)(base_channel, base_channel*2, kernel_size=(1,3,3), stride=(1,2,2), pad=(0,1,1))
        self.conv2 = getattr(module, conv_name)(base_channel*2, base_channel*2)

        self.conv3 = getattr(module, stride_conv_name)(base_channel*2, base_channel*4, kernel_size=(1,3,3), stride=(1,2,2), pad=(0,1,1))
        self.conv4 = getattr(module, conv_name)(base_channel*4, base_channel*4)

        self.conv5 = getattr(module, stride_conv_name)(base_channel*4, base_channel*8, kernel_size=(1,3,3), stride=(1,2,2), pad=(0,1,1))
        self.conv6 = getattr(module, conv_name)(base_channel*8, base_channel*8)

        self.conv7 = nn.Sequential(
            nn.ConvTranspose3d(base_channel*8, base_channel*4, kernel_size=(1,3,3), padding=(0,1,1), output_padding=(0,1,1), stride=(1,2,2), bias=False),
            nn.BatchNorm3d(base_channel*4),
            nn.ReLU(inplace=True))

        self.conv9 = nn.Sequential(
            nn.ConvTranspose3d(base_channel*4, base_channel*2, kernel_size=(1,3,3), padding=(0,1,1), output_padding=(0,1,1), stride=(1,2,2), bias=False),
            nn.BatchNorm3d(base_channel*2),
            nn.ReLU(inplace=True))

        self.conv11 = nn.Sequential(
            nn.ConvTranspose3d(base_channel*2, base_channel, kernel_size=(1,3,3), padding=(0,1,1), output_padding=(0,1,1), stride=(1,2,2), bias=False),
            nn.BatchNorm3d(base_channel),
            nn.ReLU(inplace=True))

        self.prob = nn.Conv3d(8, 1, 1, stride=1, padding=0)

    def forward(self, x):
        conv0 = self.conv0(x)
        conv2 = self.conv2(self.conv1(conv0))
        conv4 = self.conv4(self.conv3(conv2))
        x = self.conv6(self.conv5(conv4))
        x = conv4 + self.conv7(x)
        x = conv2 + self.conv9(x)
        x = conv0 + self.conv11(x)
        x = self.prob(x)

        return x.squeeze(1)

class reg3d(nn.Module):
    def __init__(self, in_channels, base_channels, down_size=3):
        super(reg3d, self).__init__()
        self.down_size = down_size
        self.conv0 = ConvBnReLU3D(in_channels, base_channels, kernel_size=3, pad=1)
        self.conv1 = ConvBnReLU3D(base_channels, base_channels*2, kernel_size=3, stride=2, pad=1)
        self.conv2 = ConvBnReLU3D(base_channels*2, base_channels*2)
        if down_size >= 2:
            self.conv3 = ConvBnReLU3D(base_channels*2, base_channels*4, kernel_size=3, stride=2, pad=1)
            self.conv4 = ConvBnReLU3D(base_channels*4, base_channels*4)
        if down_size >= 3:
            self.conv5 = ConvBnReLU3D(base_channels*4, base_channels*8, kernel_size=3, stride=2, pad=1)
            self.conv6 = ConvBnReLU3D(base_channels*8, base_channels*8)
            self.conv7 = nn.Sequential(
                nn.ConvTranspose3d(base_channels*8, base_channels*4, kernel_size=3, padding=1, output_padding=1, stride=2, bias=False),
                nn.BatchNorm3d(base_channels*4),
                nn.ReLU(inplace=True))
        if down_size >= 2:
            self.conv9 = nn.Sequential(
                nn.ConvTranspose3d(base_channels*4, base_channels*2, kernel_size=3, padding=1, output_padding=1, stride=2, bias=False),
                nn.BatchNorm3d(base_channels*2),
                nn.ReLU(inplace=True))

        self.conv11 = nn.Sequential(
            nn.ConvTranspose3d(base_channels*2, base_channels, kernel_size=3, padding=1, output_padding=1, stride=2, bias=False),
            nn.BatchNorm3d(base_channels),
            nn.ReLU(inplace=True))
        self.prob = nn.Conv3d(base_channels, 1, 3, stride=1, padding=1, bias=False)

    def forward(self, x):
        if self.down_size==3:
            conv0 = self.conv0(x)
            conv2 = self.conv2(self.conv1(conv0))
            conv4 = self.conv4(self.conv3(conv2))
            x = self.conv6(self.conv5(conv4))
            x = conv4 + self.conv7(x)
            x = conv2 + self.conv9(x)
            x = conv0 + self.conv11(x)
            x = self.prob(x)
        elif self.down_size==2:
            conv0 = self.conv0(x)
            conv2 = self.conv2(self.conv1(conv0))
            x = self.conv4(self.conv3(conv2))
            x = conv2 + self.conv9(x)
            x = conv0 + self.conv11(x)
            x = self.prob(x)
        else:
            conv0 = self.conv0(x)
            x = self.conv2(self.conv1(conv0))
            x = conv0 + self.conv11(x)
            x = self.prob(x)
        return x.squeeze(1)  # B D H W

class PosEncSine(nn.Module):

    def __init__(self, temperature=1000):
        super(PosEncSine, self).__init__()
        self.temperature = temperature

    def forward(self, x, depth):
        # depth : B D H W
        with torch.no_grad():
            B,C,D,H,W = x.shape
            depth = depth.permute(0,2,3,1).reshape(B*H*W, D) / self.temperature  # BHW D
            pos = torch.stack([torch.sin(i * math.pi * depth) for i in range(C//2)] + [torch.cos(i * math.pi * depth) for i in range(C//2)], dim=-1)  # BHW,D,C
            pos = pos.reshape(B,H,W,D,C).permute(0,4,3,1,2)  # B C D H W
        x = x + pos
        return x

class PosEncLearned(nn.Module):
    """
    Absolute pos embedding, learned.
    """
    def __init__(self, D, C):
        super().__init__()
        self.D = D
        self.C = C
        self.depth_embed = nn.Parameter(torch.Tensor(C, self.D))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.uniform_(self.depth_embed)

    def forward(self, x, **kwargs):
        B,C,D,H,W = x.shape
        pos = self.depth_embed[None,:,:,None,None].repeat(B,1,1,H,W)  # B C D H W
        x = x + pos
        return x

class stagenet(nn.Module):
    def __init__(self, inverse_depth=False, mono=False, attn_fuse_d=True, vis_ETA=False, attn_temp=1):
        super(stagenet, self).__init__()
        self.inverse_depth = inverse_depth
        self.mono = mono
        self.attn_fuse_d = attn_fuse_d
        self.vis_ETA = vis_ETA
        self.attn_temp = attn_temp

    def forward(self, features, proj_matrices, depth_hypo, regnet, stage_idx, group_cor=False, group_cor_dim=8, split_itv=1, fn=None):

        # step 1. feature extraction
        proj_matrices = torch.unbind(proj_matrices, 1)
        ref_feature, src_features = features[0], features[1:]
        ref_proj, src_projs = proj_matrices[0], proj_matrices[1:]
        B,D,H,W = depth_hypo.shape
        C = ref_feature.shape[1]

        ref_volume =  ref_feature.unsqueeze(2).repeat(1, 1, D, 1, 1)
        cor_weight_sum = 1e-8
        cor_feats = 0
        # step 2. Epipolar Transformer Aggregation
        for src_idx, (src_fea, src_proj) in enumerate(zip(src_features, src_projs)):
            if self.vis_ETA:
                scan_name = fn[0].split('/')[0]
                image_name = fn[0].split('/')[2][:-2]
                save_fn = './debug_figs/vis_ETA/{}_stage{}_src{}'.format(scan_name+'_'+image_name, stage_idx, src_idx)
            else:
                save_fn = None
            src_proj_new = src_proj[:, 0].clone()
            src_proj_new[:, :3, :4] = torch.matmul(src_proj[:, 1, :3, :3], src_proj[:, 0, :3, :4])
            ref_proj_new = ref_proj[:, 0].clone()
            ref_proj_new[:, :3, :4] = torch.matmul(ref_proj[:, 1, :3, :3], ref_proj[:, 0, :3, :4])
            warped_src = homo_warping(src_fea, src_proj_new, ref_proj_new, depth_hypo, self.vis_ETA, save_fn)  # B C D H W
            if group_cor:
                warped_src = warped_src.reshape(B, group_cor_dim, C//group_cor_dim, D, H, W)
                ref_volume = ref_volume.reshape(B, group_cor_dim, C//group_cor_dim, D, H, W)
                cor_feat = (warped_src * ref_volume).mean(2)  # B G D H W
            else:
                cor_feat = (ref_volume - warped_src)**2 # B C D H W 
            del warped_src, src_proj, src_fea
            if self.vis_ETA:
                vis_weight = torch.softmax(cor_feat.sum(1), 1).detach().cpu().numpy()
                np.save(save_fn, vis_weight)

            if not self.attn_fuse_d:
                cor_weight = torch.softmax(cor_feat.sum(1), 1).max(1)[0]  # B H W
                cor_weight_sum += cor_weight  # B H W
                cor_feats += cor_weight.unsqueeze(1).unsqueeze(1) * cor_feat  # B C D H W
            else:
                cor_weight = torch.softmax(cor_feat.sum(1) / self.attn_temp, 1) / math.sqrt(C)  # B D H W
                cor_weight_sum += cor_weight  # B D H W
                cor_feats += cor_weight.unsqueeze(1) * cor_feat  # B C D H W
            del cor_weight, cor_feat
        if not self.attn_fuse_d:
            cor_feats = cor_feats / cor_weight_sum.unsqueeze(1).unsqueeze(1)  # B C D H W
        else:
            cor_feats = cor_feats / cor_weight_sum.unsqueeze(1)  # B C D H W

        del cor_weight_sum, src_features
        
    
        # step 3. regularization
        attn_weight = regnet(cor_feats)  # B D H W
        del cor_feats
        attn_weight = F.softmax(attn_weight, dim=1)  # B D H W

        # step 4. depth argmax
        attn_max_indices = attn_weight.max(1, keepdim=True)[1]  # B 1 H W
        depth = torch.gather(depth_hypo, 1, attn_max_indices).squeeze(1)  # B H W

        if not self.training:
            with torch.no_grad():
                photometric_confidence = attn_weight.max(1)[0]  # B H W
                photometric_confidence = F.interpolate(photometric_confidence.unsqueeze(1), scale_factor=2**(3-stage_idx), mode='bilinear', align_corners=True).squeeze(1)
        else:
            photometric_confidence = torch.tensor(0.0, dtype=torch.float32, device=ref_feature.device, requires_grad=False)
        
        ret_dict = {"depth": depth,  "photometric_confidence": photometric_confidence, "hypo_depth": depth_hypo, "attn_weight": attn_weight}
        
        if self.inverse_depth:
            last_depth_itv = 1./depth_hypo[:,2,:,:] - 1./depth_hypo[:,1,:,:]
            inverse_min_depth = 1/depth + split_itv * last_depth_itv  # B H W
            inverse_max_depth = 1/depth - split_itv * last_depth_itv  # B H W
            ret_dict['inverse_min_depth'] = inverse_min_depth
            ret_dict['inverse_max_depth'] = inverse_max_depth

        # if self.mono and self.training:
        if self.mono:
            ret_dict['mono_feat'] = ref_feature  # B C H W
            
        return ret_dict
 
def sinkhorn(gt_depth, hypo_depth, attn_weight, mask, iters, eps=1, continuous=False):
    """
    gt_depth: B H W
    hypo_depth: B D H W
    attn_weight: B D H W
    mask: B H W
    """
    B,D,H,W = attn_weight.shape
    if not continuous:
        D_map = torch.stack([torch.arange(-i,D-i,1, dtype=torch.float32, device=gt_depth.device) for i in range(D)], dim=1).abs()
        D_map = D_map[None,None,:,:].repeat(B,H*W,1,1)  # B HW D D
        gt_indices = torch.abs(hypo_depth - gt_depth[:,None,:,:]).min(1)[1].squeeze(1).reshape(B*H*W, 1)  # BHW, 1
        gt_dist = torch.zeros_like(hypo_depth).permute(0,2,3,1).reshape(B*H*W, D)
        gt_dist.scatter_add_(1,gt_indices,torch.ones([gt_dist.shape[0],1], dtype=gt_dist.dtype, device=gt_dist.device))
        gt_dist = gt_dist.reshape(B,H*W,D)  # B HW D
    else:
        gt_dist = torch.zeros((B,H*W,D+1), dtype=torch.float32, device=gt_depth.device, requires_grad=False)  # B HW D+1
        gt_dist[:,:,-1] = 1
        D_map = torch.zeros((B,D,D+1), dtype=torch.float32, device=gt_depth.device, requires_grad=False)  # B D D+1
        D_map[:, :D, :D] = torch.stack([torch.arange(-i,D-i,1, dtype=torch.float32, device=gt_depth.device) for i in range(D)], dim=1).abs().unsqueeze(0)  # B D D+1
        D_map = D_map[:,None,None,:,:].repeat(1,H,W,1,1)  # B H W D D+1
        itv = 1/hypo_depth[:,2,:,:] - 1/hypo_depth[:,1,:,:]  # B H W
        gt_bin_distance_ = (1/gt_depth - 1/hypo_depth[:,0,:,:]) / itv  # B H W
        #FIXME hard code 100
        gt_bin_distance_[~mask] = 10

        gt_bin_distance = torch.stack([(gt_bin_distance_ - i).abs() for i in range(D)], dim=1).permute(0,2,3,1)  # B H W D
        D_map[:,:,:,:,-1] = gt_bin_distance
        D_map = D_map.reshape(B,H*W,D,1+D)  # B HW D D+1

    pred_dist = attn_weight.permute(0,2,3,1).reshape(B,H*W,D)  # B HW D

    # map to log space for stability
    log_mu = (gt_dist+1e-12).log()
    log_nu = (pred_dist+1e-12).log()  # B HW D or D+1

    u, v = torch.zeros_like(log_nu), torch.zeros_like(log_mu)
    for _ in range(iters):
        # scale v first then u to ensure row sum is 1, col sum slightly larger than 1
        v = log_mu - torch.logsumexp(D_map/eps + u.unsqueeze(3), dim=2)  # log(sum(exp()))
        u = log_nu - torch.logsumexp(D_map/eps + v.unsqueeze(2), dim=3)

    # convert back from log space, recover probabilities by normalization 2W
    T_map = (D_map/eps + u.unsqueeze(3) + v.unsqueeze(2)).exp()  # B HW D D
    loss = (T_map * D_map).reshape(B*H*W,-1)[mask.reshape(-1)].sum(-1).mean()
    
    return T_map, loss

================================================
FILE: requirements.txt
================================================
torch==1.9.0
torchvision==0.10.0
numpy
pillow
tensorboardX
opencv-python
plyfile

================================================
FILE: scripts/test_dtu.sh
================================================
#!/usr/bin/env bash
DTU_TESTPATH="/mnt/cfs/algorithm/public_data/mvs/dtu_test"
DTU_TESTLIST="lists/dtu/test.txt"

DTU_size=$1
exp=$2
PY_ARGS=${@:3}

DTU_LOG_DIR="./checkpoints/dtu/"$exp 
if [ ! -d $DTU_LOG_DIR ]; then
    mkdir -p $DTU_LOG_DIR
fi
DTU_CKPT_FILE=$DTU_LOG_DIR"/finalmodel.ckpt"
DTU_OUT_DIR="./outputs/dtu/"$exp



if [ $DTU_size = "raw" ] ; then
python test_mvs4.py --dataset=general_eval4 --batch_size=1 --testpath=$DTU_TESTPATH  --testlist=$DTU_TESTLIST --loadckpt $DTU_CKPT_FILE --interval_scale 1.06 --outdir $DTU_OUT_DIR\
             --use_raw_train --thres_view 4 --conf 0.5 --group_cor --attn_temp 2 --inverse_depth $PY_ARGS | tee -a $DTU_LOG_DIR/log_test.txt
else
python test
Download .txt
gitextract_j4zdg8al/

├── .gitignore
├── LICENSE
├── README.md
├── datasets/
│   ├── __init__.py
│   ├── blendedmvs.py
│   ├── data_io.py
│   ├── dtu_yao4.py
│   ├── eth3d.py
│   ├── general_eval4.py
│   └── tanks.py
├── evaluations/
│   └── dtu/
│       ├── BaseEval2Obj_web.m
│       ├── BaseEvalMain_func.m
│       ├── BaseEvalMain_web.m
│       ├── ComputeStat_func.m
│       ├── ComputeStat_web.m
│       ├── MaxDistCP.m
│       ├── PointCompareMain.m
│       ├── plyread.m
│       └── reducePts_haa.m
├── lists/
│   ├── blendedmvs/
│   │   ├── train.txt
│   │   └── val.txt
│   └── dtu/
│       ├── test.txt
│       ├── train.txt
│       ├── trainval.txt
│       └── val.txt
├── models/
│   ├── MVS4Net.py
│   ├── __init__.py
│   ├── module.py
│   └── mvs4net_utils.py
├── requirements.txt
├── scripts/
│   ├── test_dtu.sh
│   └── train_dtu.sh
├── test_mvs4.py
├── train_mvs4.py
└── utils.py
Download .txt
SYMBOL INDEX (265 symbols across 13 files)

FILE: datasets/__init__.py
  function find_dataset_def (line 5) | def find_dataset_def(dataset_name):

FILE: datasets/blendedmvs.py
  function check_invalid_input (line 11) | def check_invalid_input(imgs, depths, masks, depth_mins, depth_maxs):
  class MVSDataset (line 26) | class MVSDataset(Dataset):
    method __init__ (line 27) | def __init__(self, datapath, listfile, split, nviews, img_wh=(768, 576...
    method build_metas (line 49) | def build_metas(self):
    method read_cam_file (line 62) | def read_cam_file(self, scan, filename):
    method read_depth_mask (line 81) | def read_depth_mask(self, scan, filename, depth_min, depth_max, scale):
    method read_img (line 108) | def read_img(self, filename):
    method __len__ (line 115) | def __len__(self):
    method __getitem__ (line 118) | def __getitem__(self, idx):

FILE: datasets/data_io.py
  function read_pfm (line 6) | def read_pfm(filename):
  function save_pfm (line 44) | def save_pfm(filename, image, scale=1):
  class RandomCrop (line 75) | class RandomCrop(object):
    method __init__ (line 76) | def __init__(self, CropSize=0.1):
    method __call__ (line 79) | def __call__(self, image, normal):

FILE: datasets/dtu_yao4.py
  class MVSDataset (line 9) | class MVSDataset(Dataset):
    method __init__ (line 10) | def __init__(self, datapath, listfile, mode, nviews, interval_scale=1....
    method build_list (line 26) | def build_list(self):
    method __len__ (line 48) | def __len__(self):
    method read_cam_file (line 51) | def read_cam_file(self, filename):
    method read_img (line 64) | def read_img(self, filename):
    method crop_img (line 72) | def crop_img(self, img):
    method prepare_img (line 78) | def prepare_img(self, hr_img):
    method read_mask_hr (line 92) | def read_mask_hr(self, filename):
    method read_depth_hr (line 108) | def read_depth_hr(self, filename, scale):
    method __getitem__ (line 123) | def __getitem__(self, idx):

FILE: datasets/eth3d.py
  class MVSDataset (line 8) | class MVSDataset(Dataset):
    method __init__ (line 9) | def __init__(self, datapath, split='test', n_views=7, img_wh=(1920,128...
    method build_metas (line 17) | def build_metas(self):
    method read_cam_file (line 40) | def read_cam_file(self, filename):
    method read_img (line 57) | def read_img(self, filename):
    method __len__ (line 64) | def __len__(self):
    method __getitem__ (line 67) | def __getitem__(self, idx):

FILE: datasets/general_eval4.py
  class MVSDataset (line 8) | class MVSDataset(Dataset):
    method __init__ (line 9) | def __init__(self, datapath, listfile, mode, nviews, interval_scale=1....
    method build_list (line 24) | def build_list(self):
    method __len__ (line 56) | def __len__(self):
    method read_cam_file (line 59) | def read_cam_file(self, filename, interval_scale):
    method read_img (line 81) | def read_img(self, filename):
    method read_depth (line 88) | def read_depth(self, filename):
    method scale_mvs_input (line 92) | def scale_mvs_input(self, img, intrinsics, max_w, max_h, base=64):
    method __getitem__ (line 111) | def __getitem__(self, idx):

FILE: datasets/tanks.py
  class MVSDataset (line 8) | class MVSDataset(Dataset):
    method __init__ (line 9) | def __init__(self, datapath, n_views=7, split='intermediate'):
    method build_metas (line 16) | def build_metas(self):
    method read_cam_file (line 33) | def read_cam_file(self, filename):
    method read_img (line 48) | def read_img(self, filename):
    method scale_input (line 53) | def scale_input(self, intrinsics, img):
    method __len__ (line 62) | def __len__(self):
    method __getitem__ (line 65) | def __getitem__(self, idx):

FILE: models/MVS4Net.py
  class MVS4net (line 9) | class MVS4net(nn.Module):
    method __init__ (line 10) | def __init__(self, arch_mode="fpn", reg_net='reg2d', num_stage=4, fpn_...
    method forward (line 60) | def forward(self, imgs, proj_matrices, depth_values, filename=None):
  function MVS4net_loss (line 113) | def MVS4net_loss(inputs, depth_gt_ms, mask_ms, **kwargs):
  function Blend_loss (line 158) | def Blend_loss(inputs, depth_gt_ms, mask_ms, **kwargs):

FILE: models/module.py
  function init_bn (line 14) | def init_bn(module):
  function init_uniform (line 22) | def init_uniform(module, init_method):
  class Conv2d (line 30) | class Conv2d(nn.Module):
    method __init__ (line 44) | def __init__(self, in_channels, out_channels, kernel_size, stride=1,
    method forward (line 58) | def forward(self, x):
    method init_weights (line 66) | def init_weights(self, init_method):
  class DCNConv2d (line 72) | class DCNConv2d(nn.Module):
    method __init__ (line 74) | def __init__(self, in_channels, out_channels, kernel_size, stride=1,
    method forward (line 87) | def forward(self, x):
    method init_weights (line 95) | def init_weights(self, init_method):
  class Deconv2d (line 101) | class Deconv2d(nn.Module):
    method __init__ (line 115) | def __init__(self, in_channels, out_channels, kernel_size, stride=1,
    method forward (line 130) | def forward(self, x):
    method init_weights (line 141) | def init_weights(self, init_method):
  class Conv3d (line 147) | class Conv3d(nn.Module):
    method __init__ (line 161) | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
    method forward (line 177) | def forward(self, x):
    method init_weights (line 185) | def init_weights(self, init_method):
  class PConv3d (line 191) | class PConv3d(nn.Module):
    method __init__ (line 193) | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
    method forward (line 213) | def forward(self, x):
    method init_weights (line 222) | def init_weights(self, init_method):
  class Deconv3d (line 230) | class Deconv3d(nn.Module):
    method __init__ (line 244) | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
    method forward (line 259) | def forward(self, x):
    method init_weights (line 267) | def init_weights(self, init_method):
  class PDeconv3d (line 274) | class PDeconv3d(nn.Module):
    method __init__ (line 276) | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,...
    method forward (line 298) | def forward(self, x):
    method init_weights (line 307) | def init_weights(self, init_method):
  class ConvBnReLU (line 314) | class ConvBnReLU(nn.Module):
    method __init__ (line 315) | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,...
    method forward (line 320) | def forward(self, x):
  class ConvBn (line 323) | class ConvBn(nn.Module):
    method __init__ (line 324) | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,...
    method forward (line 329) | def forward(self, x):
  class ConvBnReLU3D (line 332) | class ConvBnReLU3D(nn.Module):
    method __init__ (line 333) | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,...
    method forward (line 338) | def forward(self, x):
  class ConvBn3D (line 342) | class ConvBn3D(nn.Module):
    method __init__ (line 343) | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,...
    method forward (line 348) | def forward(self, x):
  class BasicBlock (line 352) | class BasicBlock(nn.Module):
    method __init__ (line 353) | def __init__(self, in_channels, out_channels, stride, downsample=None):
    method forward (line 362) | def forward(self, x):
  class Hourglass3d (line 371) | class Hourglass3d(nn.Module):
    method __init__ (line 372) | def __init__(self, channels):
    method forward (line 394) | def forward(self, x):
  function homo_warping (line 402) | def homo_warping(src_fea, src_proj, ref_proj, depth_values, align_corner...
  class DeConv2dFuse (line 437) | class DeConv2dFuse(nn.Module):
    method __init__ (line 438) | def __init__(self, in_channels, out_channels, kernel_size, relu=True, ...
    method forward (line 451) | def forward(self, x_pre, x):
  class FeatureNet (line 458) | class FeatureNet(nn.Module):
    method __init__ (line 459) | def __init__(self, base_channels, num_stage=3, stride=4, arch_mode="un...
    method forward (line 520) | def forward(self, x):
  class FPNDCNpath (line 561) | class FPNDCNpath(nn.Module):
    method __init__ (line 564) | def __init__(self, base_channels, stride=4):
    method forward (line 612) | def forward(self, x):
  class FPNDCN (line 635) | class FPNDCN(nn.Module):
    method __init__ (line 638) | def __init__(self, base_channels, stride=4):
    method forward (line 684) | def forward(self, x):
  class FPNA (line 705) | class FPNA(nn.Module):
    method __init__ (line 708) | def __init__(self, base_channels, stride=4):
    method forward (line 743) | def forward(self, x):
  class FPNA4 (line 764) | class FPNA4(nn.Module):
    method __init__ (line 767) | def __init__(self, base_channels):
    method forward (line 810) | def forward(self, x):
  class CostRegNet (line 836) | class CostRegNet(nn.Module):
    method __init__ (line 837) | def __init__(self, in_channels, base_channels, down_size=3):
    method forward (line 860) | def forward(self, x):
  class P3DConv (line 884) | class P3DConv(nn.Module):
    method __init__ (line 888) | def __init__(self, in_channels, base_channels):
    method forward (line 909) | def forward(self, x):
  class RefineNet (line 920) | class RefineNet(nn.Module):
    method __init__ (line 921) | def __init__(self):
    method forward (line 928) | def forward(self, img, depth_init):
  function depth_regression (line 935) | def depth_regression(p, depth_values):
  function cas_mvsnet_loss (line 943) | def cas_mvsnet_loss(inputs, depth_gt_ms, mask_ms, **kwargs):
  function cas_mvsnet_T_loss (line 964) | def cas_mvsnet_T_loss(inputs, depth_gt_ms, mask_ms, **kwargs):
  function get_cur_depth_range_samples (line 1138) | def get_cur_depth_range_samples(cur_depth, ndepth, depth_inteval_pixel, ...
  function get_depth_range_samples (line 1157) | def get_depth_range_samples(cur_depth, ndepth, depth_inteval_pixel, devi...

FILE: models/mvs4net_utils.py
  function homo_warping (line 13) | def homo_warping(src_fea, src_proj, ref_proj, depth_values, vis_ETA=Fals...
  function init_range (line 61) | def init_range(cur_depth, ndepths, device, dtype, H, W):
  function init_inverse_range (line 71) | def init_inverse_range(cur_depth, ndepths, device, dtype, H, W):
  function schedule_inverse_range (line 79) | def schedule_inverse_range(inverse_min_depth, inverse_max_depth, ndepths...
  function schedule_range (line 88) | def schedule_range(cur_depth, ndepth, depth_inteval_pixel, H, W):
  function init_bn (line 101) | def init_bn(module):
  function init_uniform (line 108) | def init_uniform(module, init_method):
  class ConvBnReLU3D (line 116) | class ConvBnReLU3D(nn.Module):
    method __init__ (line 117) | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,...
    method forward (line 122) | def forward(self, x):
  class ConvBnReLU3D_CAM (line 125) | class ConvBnReLU3D_CAM(nn.Module):
    method __init__ (line 126) | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,...
    method forward (line 136) | def forward(self, input):
  class ConvBnReLU3D_DCAM (line 145) | class ConvBnReLU3D_DCAM(nn.Module):
    method __init__ (line 146) | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,...
    method forward (line 156) | def forward(self, input):
  class ConvBnReLU3D_PAM (line 165) | class ConvBnReLU3D_PAM(nn.Module):
    method __init__ (line 166) | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,...
    method forward (line 172) | def forward(self, input):
  class ConvBnReLU3D_PDAM (line 181) | class ConvBnReLU3D_PDAM(nn.Module):
    method __init__ (line 182) | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,...
    method forward (line 188) | def forward(self, input):
  class Deconv3d (line 197) | class Deconv3d(nn.Module):
    method __init__ (line 199) | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
    method forward (line 211) | def forward(self, x):
    method init_weights (line 219) | def init_weights(self, init_method):
  class Conv2d (line 224) | class Conv2d(nn.Module):
    method __init__ (line 226) | def __init__(self, in_channels, out_channels, kernel_size, stride=1,
    method forward (line 238) | def forward(self, x):
    method init_weights (line 248) | def init_weights(self, init_method):
  class Deconv2d (line 253) | class Deconv2d(nn.Module):
    method __init__ (line 255) | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
  class DeformConv2d (line 267) | class DeformConv2d(nn.Module):
    method __init__ (line 268) | def __init__(self, inc, outc, kernel_size=3, padding=1, stride=1, bias...
    method _set_lr (line 287) | def _set_lr(module, grad_input, grad_output):
    method forward (line 291) | def forward(self, x):
    method _get_p_n (line 349) | def _get_p_n(self, N, dtype):
    method _get_p_0 (line 359) | def _get_p_0(self, h, w, N, dtype):
    method _get_p (line 369) | def _get_p(self, offset, dtype):
    method _get_x_q (line 379) | def _get_x_q(self, x, q, N):
    method _reshape_x_offset (line 396) | def _reshape_x_offset(x_offset, ks):
  function NA_DCN (line 403) | def NA_DCN(in_channels, kernel_size=3, stride=1, dilation=1, bias=True, ...
  class FPN4 (line 419) | class FPN4(nn.Module):
    method __init__ (line 422) | def __init__(self, base_channels, gn=False, dcn=False):
    method forward (line 472) | def forward(self, x):
  class LayerNorm (line 504) | class LayerNorm(nn.Module):
    method __init__ (line 506) | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_l...
    method forward (line 516) | def forward(self, x):
  class convnext_block (line 526) | class convnext_block(nn.Module):
    method __init__ (line 528) | def __init__(self, dim, layer_scale_init_value=1e-6):
    method forward (line 538) | def forward(self, x):
  class convnext4_block (line 553) | class convnext4_block(nn.Module):
    method __init__ (line 555) | def __init__(self, dim, layer_scale_init_value=1e-6):
    method forward (line 566) | def forward(self, x):
  class FPN4_convnext (line 581) | class FPN4_convnext(nn.Module):
    method __init__ (line 584) | def __init__(self, base_channels, gn=False, dcn=False):
    method forward (line 620) | def forward(self, x):
  class FPN4_convnext4 (line 652) | class FPN4_convnext4(nn.Module):
    method __init__ (line 655) | def __init__(self, base_channels, gn=False, dcn=False):
    method forward (line 691) | def forward(self, x):
  class ASFF (line 723) | class ASFF(nn.Module):
    method __init__ (line 724) | def __init__(self, level):
    method forward (line 758) | def forward(self, x_level_0, x_level_1, x_level_2, x_level_3):
  class FullImageEncoder (line 807) | class FullImageEncoder(nn.Module):
    method __init__ (line 808) | def __init__(self, h, w, kernel_size):
    method forward (line 819) | def forward(self, x):
  class mono_depth_decoder (line 833) | class mono_depth_decoder(nn.Module):
    method __init__ (line 835) | def __init__(self):
    method forward (line 849) | def forward(self, outputs, d_min, d_max):
  class reg2d (line 870) | class reg2d(nn.Module):
    method __init__ (line 871) | def __init__(self, input_channel=128, base_channel=32, conv_name='Conv...
    method forward (line 902) | def forward(self, x):
  class reg3d (line 914) | class reg3d(nn.Module):
    method __init__ (line 915) | def __init__(self, in_channels, base_channels, down_size=3):
    method forward (line 943) | def forward(self, x):
  class PosEncSine (line 967) | class PosEncSine(nn.Module):
    method __init__ (line 969) | def __init__(self, temperature=1000):
    method forward (line 973) | def forward(self, x, depth):
  class PosEncLearned (line 983) | class PosEncLearned(nn.Module):
    method __init__ (line 987) | def __init__(self, D, C):
    method reset_parameters (line 994) | def reset_parameters(self):
    method forward (line 997) | def forward(self, x, **kwargs):
  class stagenet (line 1003) | class stagenet(nn.Module):
    method __init__ (line 1004) | def __init__(self, inverse_depth=False, mono=False, attn_fuse_d=True, ...
    method forward (line 1012) | def forward(self, features, proj_matrices, depth_hypo, regnet, stage_i...
  function sinkhorn (line 1096) | def sinkhorn(gt_depth, hypo_depth, attn_weight, mask, iters, eps=1, cont...

FILE: test_mvs4.py
  function read_camera_parameters (line 94) | def read_camera_parameters(filename):
  function read_img (line 106) | def read_img(filename):
  function read_mask (line 114) | def read_mask(filename):
  function save_mask (line 119) | def save_mask(filename, mask):
  function read_pair_file (line 126) | def read_pair_file(filename):
  function write_cam (line 138) | def write_cam(file, cam):
  function save_depth (line 157) | def save_depth(testlist):
  function save_scene_depth (line 170) | def save_scene_depth(testlist):
  function reproject_with_depth (line 273) | def reproject_with_depth(depth_ref, intrinsics_ref, extrinsics_ref, dept...
  function check_geometric_consistency (line 313) | def check_geometric_consistency(depth_ref, intrinsics_ref, extrinsics_re...
  function filter_depth (line 331) | def filter_depth(pair_folder, scan_folder, out_folder, plyfilename):
  function init_worker (line 424) | def init_worker():
  function pcd_filter_worker (line 431) | def pcd_filter_worker(scan):
  function pcd_filter (line 443) | def pcd_filter(testlist, number_worker):
  function mrun_rst (line 457) | def mrun_rst(eval_dir, plyPath):

FILE: train_mvs4.py
  function train (line 83) | def train(model, model_loss, optimizer, TrainImgLoader, TestImgLoader, s...
  function test (line 179) | def test(model, model_loss, TestImgLoader, args):
  function train_sample (line 195) | def train_sample(model, model_loss, optimizer, sample, args):
  function test_sample_depth (line 253) | def test_sample_depth(model, model_loss, sample, args):

FILE: utils.py
  function print_args (line 8) | def print_args(args):
  function make_nograd_func (line 16) | def make_nograd_func(func):
  function make_recursive_func (line 26) | def make_recursive_func(func):
  function tensor2float (line 41) | def tensor2float(vars):
  function tensor2numpy (line 51) | def tensor2numpy(vars):
  function tocuda (line 61) | def tocuda(vars):
  function save_scalars (line 70) | def save_scalars(logger, mode, scalar_dict, global_step):
  function save_images (line 82) | def save_images(logger, mode, images_dict, global_step):
  class DictAverageMeter (line 103) | class DictAverageMeter(object):
    method __init__ (line 104) | def __init__(self):
    method update (line 108) | def update(self, new_input):
    method mean (line 121) | def mean(self):
  function compute_metrics_for_each_image (line 126) | def compute_metrics_for_each_image(metric_func):
  function Thres_metrics (line 141) | def Thres_metrics(depth_est, depth_gt, mask, thres):
  function AbsDepthError_metrics (line 152) | def AbsDepthError_metrics(depth_est, depth_gt, mask, thres=None):
  function synchronize (line 162) | def synchronize():
  function get_world_size (line 176) | def get_world_size():
  function reduce_scalar_outputs (line 183) | def reduce_scalar_outputs(scalar_outputs):
  class WarmupMultiStepLR (line 208) | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
    method __init__ (line 209) | def __init__(
    method get_lr (line 237) | def get_lr(self):
  function set_random_seed (line 253) | def set_random_seed(seed):
  function local_pcd (line 260) | def local_pcd(depth, intr):
  function generate_pointcloud (line 274) | def generate_pointcloud(rgb, depth, ply_file, intr, scale=1.0):
Condensed preview — 35 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (263K chars).
[
  {
    "path": ".gitignore",
    "chars": 46,
    "preview": "outputs/\ncheckpoints/\ndebug_figs/\n*__pycache__"
  },
  {
    "path": "LICENSE",
    "chars": 1066,
    "preview": "MIT License\n\nCopyright (c) 2022 Jeff Wang\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\n"
  },
  {
    "path": "README.md",
    "chars": 5508,
    "preview": "# MVSTER\nMVSTER: Epipolar Transformer for Efficient Multi-View Stereo, ECCV 2022. [arXiv](https://arxiv.org/abs/2204.073"
  },
  {
    "path": "datasets/__init__.py",
    "chars": 271,
    "preview": "import importlib\n\n\n# find the dataset definition by name, for example dtu_yao (dtu_yao.py)\ndef find_dataset_def(dataset_"
  },
  {
    "path": "datasets/blendedmvs.py",
    "chars": 7911,
    "preview": "from torch.utils.data import Dataset\nfrom datasets.data_io import *\nimport os\nimport numpy as np\nimport cv2\nfrom PIL imp"
  },
  {
    "path": "datasets/data_io.py",
    "chars": 3260,
    "preview": "import numpy as np\nimport re\nimport sys\n\n\ndef read_pfm(filename):\n    file = open(filename, 'rb')\n    color = None\n    w"
  },
  {
    "path": "datasets/dtu_yao4.py",
    "chars": 8342,
    "preview": "from torch.utils.data import Dataset\nimport numpy as np\nimport os, cv2, time, math\nfrom PIL import Image\nfrom datasets.d"
  },
  {
    "path": "datasets/eth3d.py",
    "chars": 5210,
    "preview": "from torch.utils.data import Dataset\nfrom datasets.data_io import *\nimport os\nimport numpy as np\nimport cv2\nfrom PIL imp"
  },
  {
    "path": "datasets/general_eval4.py",
    "chars": 7515,
    "preview": "from torch.utils.data import Dataset\nimport numpy as np\nimport os, cv2, time\nfrom PIL import Image\nfrom datasets.data_io"
  },
  {
    "path": "datasets/tanks.py",
    "chars": 4948,
    "preview": "from torch.utils.data import Dataset\nfrom datasets.data_io import *\nimport os\nimport numpy as np\nimport cv2\nfrom PIL imp"
  },
  {
    "path": "evaluations/dtu/BaseEval2Obj_web.m",
    "chars": 1501,
    "preview": "function BaseEval2Obj_web(BaseEval,method_string,outputPath)\r\n\r\nif(nargin<3)\r\n    outputPath='./';\r\nend\r\n\r\n% tresshold f"
  },
  {
    "path": "evaluations/dtu/BaseEvalMain_func.m",
    "chars": 2866,
    "preview": "function None = BaseEvalMain_func(plyPath)\r\n\r\n% clear all\r\n% close all\r\nformat compact\r\n\r\n% script to calculate distance"
  },
  {
    "path": "evaluations/dtu/BaseEvalMain_web.m",
    "chars": 2814,
    "preview": "clear all\r\nclose all\r\nformat compact\r\nclc\r\n\r\n% script to calculate distances have been measured for all included scans ("
  },
  {
    "path": "evaluations/dtu/ComputeStat_func.m",
    "chars": 2816,
    "preview": "function None = ComputeStat_func(plyPath)\r\nformat compact\r\n\r\n% script to calculate the statistics for each scan given th"
  },
  {
    "path": "evaluations/dtu/ComputeStat_web.m",
    "chars": 2758,
    "preview": "clear all\r\nclose all\r\nformat compact\r\nclc\r\n\r\n% script to calculate the statistics for each scan given this will currentl"
  },
  {
    "path": "evaluations/dtu/MaxDistCP.m",
    "chars": 1444,
    "preview": "function Dist = MaxDistCP(Qto,Qfrom,BB,MaxDist)\r\n\r\nDist=ones(1,size(Qfrom,2))*MaxDist;\r\n\r\nRange=floor((BB(2,:)-BB(1,:))/"
  },
  {
    "path": "evaluations/dtu/PointCompareMain.m",
    "chars": 2103,
    "preview": "function BaseEval=PointCompareMain(cSet,Qdata,dst,dataPath)\r\n% evaluation function the calculates the distantes from the"
  },
  {
    "path": "evaluations/dtu/plyread.m",
    "chars": 15651,
    "preview": "%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%\r\nfunction [Elements,varargout] = plyread(Path,Str)\r\n"
  },
  {
    "path": "evaluations/dtu/reducePts_haa.m",
    "chars": 900,
    "preview": "function [ptsOut,indexSet] = reducePts_haa(pts, dst)\n\n%Reduces a point set, pts, in a stochastic manner, such that the m"
  },
  {
    "path": "lists/blendedmvs/train.txt",
    "chars": 2650,
    "preview": "5c1f33f1d33e1f2e4aa6dda4\n5bfe5ae0fe0ea555e6a969ca\n5bff3c5cfe0ea555e6bcbf3a\n58eaf1513353456af3a1682a\n5bfc9d5aec61ca1dd691"
  },
  {
    "path": "lists/blendedmvs/val.txt",
    "chars": 182,
    "preview": "5b7a3890fc8fcf6781e2593a\r\n5c189f2326173c3a09ed7ef3\r\n5b950c71608de421b1e7318f\r\n5a6400933d809f1d8200af15\r\n59d2657f82ca7774"
  },
  {
    "path": "lists/dtu/test.txt",
    "chars": 153,
    "preview": "scan1\nscan4\nscan9\nscan10\nscan11\nscan12\nscan13\nscan15\nscan23\nscan24\nscan29\nscan32\nscan33\nscan34\nscan48\nscan49\nscan62\nscan"
  },
  {
    "path": "lists/dtu/train.txt",
    "chars": 572,
    "preview": "scan2\nscan6\nscan7\nscan8\nscan14\nscan16\nscan18\nscan19\nscan20\nscan22\nscan30\nscan31\nscan36\nscan39\nscan41\nscan42\nscan44\nscan4"
  },
  {
    "path": "lists/dtu/trainval.txt",
    "chars": 698,
    "preview": "scan2\nscan6\nscan7\nscan8\nscan14\nscan16\nscan18\nscan19\nscan20\nscan22\nscan30\nscan31\nscan36\nscan39\nscan41\nscan42\nscan44\nscan4"
  },
  {
    "path": "lists/dtu/val.txt",
    "chars": 125,
    "preview": "scan3\nscan5\nscan17\nscan21\nscan28\nscan35\nscan37\nscan38\nscan40\nscan43\nscan56\nscan59\nscan66\nscan67\nscan82\nscan86\nscan106\nsc"
  },
  {
    "path": "models/MVS4Net.py",
    "chars": 10195,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\nfrom models.mvs4net_utils import s"
  },
  {
    "path": "models/__init__.py",
    "chars": 61,
    "preview": "\nfrom models.MVS4Net import MVS4net, MVS4net_loss, Blend_loss"
  },
  {
    "path": "models/module.py",
    "chars": 52886,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport time\nimport sys\nimport seaborn as sns\nimport n"
  },
  {
    "path": "models/mvs4net_utils.py",
    "chars": 50767,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport importlib\ntry:\n    from modules.deform_conv im"
  },
  {
    "path": "requirements.txt",
    "chars": 80,
    "preview": "torch==1.9.0\ntorchvision==0.10.0\nnumpy\npillow\ntensorboardX\nopencv-python\nplyfile"
  },
  {
    "path": "scripts/test_dtu.sh",
    "chars": 996,
    "preview": "#!/usr/bin/env bash\nDTU_TESTPATH=\"/mnt/cfs/algorithm/public_data/mvs/dtu_test\"\nDTU_TESTLIST=\"lists/dtu/test.txt\"\n\nDTU_si"
  },
  {
    "path": "scripts/train_dtu.sh",
    "chars": 1105,
    "preview": "#!/usr/bin/env bash\nDTU_TRAINING=\"/mnt/cfs/algorithm/public_data/mvs/mvs_training/dtu\"\nDTU_TRAINLIST=\"lists/dtu/train.tx"
  },
  {
    "path": "test_mvs4.py",
    "chars": 22811,
    "preview": "import argparse, os, time, sys, gc, cv2\nimport torch\nimport torch.nn as nn\nimport torch.nn.parallel\nimport torch.backend"
  },
  {
    "path": "train_mvs4.py",
    "chars": 22156,
    "preview": "import argparse, os, sys, time, gc, datetime\nimport torch\nimport torch.nn as nn\nimport torch.nn.parallel\nimport torch.ba"
  },
  {
    "path": "utils.py",
    "chars": 10159,
    "preview": "import numpy as np\nimport torchvision.utils as vutils\nimport torch, random\nimport torch.nn.functional as F\n\n\n# print arg"
  }
]

About this extraction

This page contains the full source code of the JeffWang987/MVSTER GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 35 files (246.6 KB), approximately 71.3k tokens, and a symbol index with 265 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!