Repository: jingsenzhu/i2-sdf
Branch: main
Commit: 58c9a8241feb
Files: 31
Total size: 217.1 KB

Directory structure:
i2-sdf/

├── .gitignore
├── DATA_CONVENTION.md
├── LICENSE
├── README.md
├── config/
│   ├── synthetic.yml
│   └── synthetic_light_mask.yml
├── data/
│   ├── normalize_cameras.py
│   └── npz_to_blender.py
├── dataset/
│   ├── __init__.py
│   ├── eval_dataset.py
│   └── train_dataset.py
├── environment.yml
├── i2-sdf-dataset-links.csv
├── main_recon.py
├── model/
│   ├── __init__.py
│   ├── eval/
│   │   ├── __init__.py
│   │   └── recon.py
│   ├── network/
│   │   ├── __init__.py
│   │   ├── density.py
│   │   ├── embedder.py
│   │   ├── mlp.py
│   │   └── ray_sampler.py
│   ├── rendering/
│   │   ├── __init__.py
│   │   └── brdf.py
│   └── trainer/
│       ├── __init__.py
│       └── recon.py
└── utils/
    ├── __init__.py
    ├── cfgnode.py
    ├── mesh_util.py
    ├── plots.py
    └── rend_util.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
data/synthetic
exps
__pycache__
archive
tmp*

================================================
FILE: DATA_CONVENTION.md
================================================
# Data Convention

The format of our multi-view dataset is derived from [VolSDF](https://github.com/lioryariv/volsdf/blob/main/DATA_CONVENTION.md).

### Directory Structure

```python
scan<scan_id>/
	cameras.npz
	image/ -> {:04d}.png # tone-mapped LDR images
	depth/ -> {:04d}.exr
	normal/ -> {:04d}.exr
	mask/ -> {:04d}.png
	val/ -> {:04d}.png # validation images (LDR)
	hdr/ -> {:04d}.exr # raw HDR images
	# the following are optional
	light_mask/ -> {:04d}.png # emitter mask images
	material/ -> {:04d}_kd.exr, {:04d}_ks.exr, {:04d}_rough.exr # diffuse, specular albedo and roughness
```

Zero-valued areas in the depth maps and normal maps indicate invalid regions such as windows.

Note that not all areas inside the scenes use the GGX material model. We'll provide a mask of the invalid areas.
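
For example, a minimal sketch of turning the zero-valued regions of the depth and normal maps into validity masks, mirroring the thresholds used by the training dataset loader (the paths below are placeholders):

```python
import os
os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' # enable EXR support before importing cv2
import cv2
import numpy as np

# Placeholder paths: substitute your own scan directory and frame index
depth = cv2.imread('scan0/depth/0000.exr', -1)
normal = cv2.imread('scan0/normal/0000.exr', -1)

valid_depth = depth > 1e-3                            # zero depth marks invalid pixels (e.g. windows)
valid_normal = np.linalg.norm(normal, axis=-1) > 1e-3 # zero-length normals are invalid as well
```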

### Camera Information

The `cameras.npz` file contains each image's associated camera projection matrix `'world_mat_{i}'` and a normalization matrix `'scale_mat_{i}'`, the same as VolSDF. In addition, we provide a validation set of images for novel view synthesis, whose associated camera projection matrices are `'val_mat_{i}'`. The validation and training sets share the same normalization matrix.
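
As a sketch of how these matrices are consumed, the loaders in this repository form the projection matrix as the product of the world and scale matrices and decompose it with OpenCV (the path below is a placeholder):

```python
import cv2
import numpy as np

cameras = np.load('scan0/cameras_normalize.npz') # placeholder path
i = 0
P = (cameras['world_mat_%d' % i] @ cameras['scale_mat_%d' % i])[:3, :4]

# Decompose P = K [R | t] into intrinsics and a camera-to-world pose
K, R, t = cv2.decomposeProjectionMatrix(P)[:3]
K = K / K[2, 2]
pose = np.eye(4, dtype=np.float32)
pose[:3, :3] = R.transpose()       # R is world-to-camera; transpose for camera-to-world
pose[:3, 3] = (t[:3] / t[3])[:, 0] # homogeneous camera center
```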

If the normalization matrices are not present in `cameras.npz`, you can manually run `data/normalize_cameras.py` to generate `cameras_normalize.npz`. Since our method requires the entire scene to lie within a radius-3 bounding sphere, we suggest normalizing cameras with radius 2.0 or 2.5.

An example of running `normalize_cameras.py`:

```shell
python normalize_cameras.py --id <scan_id> -n <synthetic/real/...> -r 2.0
```

Note that we follow **OpenCV camera coordinate system** (X right, Y downwards, Z into the image plane).
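
If your captures follow the OpenGL/Blender convention (X right, Y up, Z out of the image plane) instead, the two conventions differ by a sign flip of the camera's Y and Z axes; a minimal sketch of the pose conversion (`normalize_cameras.py --convert_coord` performs the equivalent conversion on projection matrices):

```python
import numpy as np

def flip_gl_cv(pose):
    # Negate the camera's Y and Z axes; the camera center is unchanged.
    # The flip is its own inverse, so this converts in either direction.
    out = pose.copy()
    out[:3, :3] = out[:3, :3] @ np.diag([1.0, -1.0, -1.0])
    return out
```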

### Dataset Format Conversion

If you want to convert the dataset format to NeRF blender format, run `npz_to_blender.py`:

```sh
python npz_to_blender.py --root /path/to/dataset
```

With the `--scale` flag, the script scales all pose matrices to fit within a `[-1, 1]` bounding box.
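
A minimal sketch of the scaling it applies, mirroring `get_offset` and `scale_pose` in the script: center the camera positions, then divide by half the largest extent:

```python
import numpy as np

def fit_poses_to_unit_box(poses):
    """poses: list of (4, 4) camera-to-world matrices."""
    eyes = np.stack([p[:3, 3] for p in poses])
    offset = -(eyes.max(axis=0) + eyes.min(axis=0)) / 2     # center the camera positions
    scale = (eyes.max(axis=0) - eyes.min(axis=0)).max() / 2 # half the largest extent
    scaled = []
    for p in poses:
        q = p.copy()
        q[:3, 3] = (q[:3, 3] + offset) / scale # translations now lie within [-1, 1]
        scaled.append(q)
    return scaled
```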

### About Real Dataset

Our real dataset comes from [Inria](https://repo-sam.inria.fr/fungraph/deep-indoor-relight/) and [Scalable-NISR](https://xchaowu.github.io/papers/scalable-nisr/), with depths estimated by MVS tools and manually-labeled light masks (for the 2 living room scenes). All depths have an absolute scale, so no shift alignment is needed (unlike MonoSDF). All camera calibrations and depths are provided by the authors of [Scalable-NISR](https://xchaowu.github.io/papers/scalable-nisr/). We thank them for providing the datasets.

Normals are not provided in the real dataset; we find plausible reconstruction achievable without normal supervision in these scenes. Of course, you can estimate normals with any method if you want to enable normal supervision.

### About EXR format

We suggest using OpenCV to load an `.exr` format `float32` image:

```python
import os
os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' # Enable OpenCV support for EXR
import cv2
...
im = cv2.imread(im_path, -1) # im will be a numpy.float32 array of shape (H, W, C)
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) # cv2 reads images in BGR order; convert to RGB
```

We suggest using [tev](https://github.com/Tom94/tev) to preview HDR `.exr` images conveniently.

### About Mesh

Due to copyright issues, we cannot release the original 3D meshes of our synthetic scenes. Instead, we'll provide point clouds sampled from the GT meshes to enable 3D reconstruction evaluation.
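
For reference, a minimal sketch of a symmetric Chamfer distance between such a point cloud and points sampled from a reconstructed mesh, using `trimesh` and `scipy` (both listed in `environment.yml`; the file names are placeholders):

```python
import numpy as np
import trimesh
from scipy.spatial import cKDTree

mesh = trimesh.load('recon.ply')      # placeholder: your reconstructed mesh
gt_points = np.load('gt_points.npy')  # placeholder: the released GT point cloud, (N, 3)

pred_points, _ = trimesh.sample.sample_surface(mesh, len(gt_points))

# Symmetric Chamfer distance: mean nearest-neighbor distance in both directions
d_gt_to_pred = cKDTree(pred_points).query(gt_points)[0].mean()
d_pred_to_gt = cKDTree(gt_points).query(pred_points)[0].mean()
chamfer = (d_gt_to_pred + d_pred_to_gt) / 2
print(f'Chamfer distance: {chamfer:.4f}')
```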

================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2023 Jingsen Zhu

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README.md
================================================
**News**

- `04/04/2023` dataset preview release: 2 synthetic scenes available
- `15/04/2023` code release: 3D reconstruction and novel view synthesis part
- `21/04/2023` dataset release: real data

**TODO**

- [ ] Full dataset release
- [x] Code release for 3D reconstruction and novel view synthesis
- [ ] Code release for intrinsic decomposition and scene editing

**Dataset released**

- Synthetic: `kitchen_0`, `bedroom_relight_0`, `bedroom_0`, `bedroom_1`, `bedroom_relight_1`, `diningroom_0`, `livingroom_0`, `livingroom_1`, more scenes to be released
- Real: `inria_livingroom`, `nisr_livingroom`, `nisr_coffee_shop_0`, `nisr_coffee_shop_1`, release complete

# I<sup>2</sup>-SDF: Intrinsic Indoor Scene Reconstruction and Editing via Raytracing in Neural SDFs (CVPR 2023)

### [Project Page](https://jingsenzhu.github.io/i2-sdf/) | [Paper](https://arxiv.org/abs/2303.07634) | [Dataset](i2-sdf-dataset-links.csv)

## Setup

### Installation

```
conda env create -f environment.yml
conda activate i2sdf
```

### Data preparation

Download our synthetic dataset and extract the scenes into `data/synthetic`. If you want to run on your own customized dataset, we provide a brief introduction to our data convention [here](DATA_CONVENTION.md).

## Dataset

We provide a high-quality synthetic indoor scene multi-view dataset, with ground truth camera pose and geometry annotations. See [HERE](DATA_CONVENTION.md) for data conventions. Click [HERE](https://mega.nz/folder/jdhDnTqL#Ija678SU2Va_JJOiwqmdEg) to download.

## 3D Reconstruction and Novel View Synthesis

### Training

```
python main_recon.py --conf config/<config_file>.yml --scan_id <scan_id> -d <gpu_id> -v <version>
```

Note: `config/synthetic.yml` does not include the light mask network, while `config/synthetic_light_mask.yml` does.

If you run out of GPU memory, try reducing `split_n_pixels` (the validation batch size) and `batch_size` in the config. The default parameters are tuned for an RTX A6000 (48 GB). For an RTX 3090 (24 GB), try setting `split_n_pixels` to 5000.

### Evaluation

#### Novel view synthesis

```
python main_recon.py --conf config/<config_file>.yml --scan_id <scan_id> -d <gpu_id> -v <version> --test [--is_val] [--full]
```

The optional flag `--is_val` evaluates on the validation set instead of the training set; `--full` produces full-resolution rendered images without downsampling.

#### View Interpolation

```
python main_recon.py --conf config/<config_file>.yml --scan_id <scan_id> -d <gpu_id> -v <version> --test --test_mode interpolate --inter_id <view_id_0> <view_id_1> [--full]
```

Generates a view interpolation video between the 2 views. Requires `ffmpeg` to be installed.

The number of frames and the frame rate of the video can be specified with the `--n_frames` and `--frame_rate` options; a sketch of the interpolation follows below.
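
For reference, a simplified sketch of the interpolation performed by the dataset code: rotations are interpolated with Slerp, camera centers linearly, and the interpolation ratio is eased with a sine curve:

```python
import numpy as np
from scipy.spatial.transform import Rotation as Rot, Slerp

def interpolate_poses(pose0, pose1, num_frames=60):
    # Slerp between the two camera rotations; lerp the camera centers
    slerp = Slerp([0, 1], Rot.from_matrix(np.stack([pose0[:3, :3], pose1[:3, :3]])))
    poses = []
    for i in range(num_frames):
        r = np.sin((i / num_frames - 0.5) * np.pi) * 0.5 + 0.5 # sine ease-in/ease-out
        p = np.eye(4, dtype=np.float32)
        p[:3, :3] = slerp(r).as_matrix()
        p[:3, 3] = (1 - r) * pose0[:3, 3] + r * pose1[:3, 3]
        poses.append(p)
    return poses
```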

#### Mesh Extraction

```
python main_recon.py --conf config/<config_file>.yml --scan_id <scan_id> -d <gpu_id> -v <version> --test --test_mode mesh
```
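
Conceptually, this evaluates the learned SDF on a dense grid and runs marching cubes at the zero level set; a minimal sketch with `PyMCubes` (listed in `environment.yml`), where `sdf` is a hypothetical callable mapping `(N, 3)` points to signed distances:

```python
import mcubes
import torch

@torch.no_grad()
def extract_mesh(sdf, resolution=512, bound=1.5):
    # Evaluate the SDF on a dense grid covering [-bound, bound]^3, in chunks
    xs = torch.linspace(-bound, bound, resolution)
    grid = torch.stack(torch.meshgrid(xs, xs, xs, indexing='ij'), dim=-1).reshape(-1, 3)
    values = torch.cat([sdf(pts) for pts in grid.split(100000)])
    values = values.reshape(resolution, resolution, resolution).cpu().numpy()
    # Marching cubes at the zero level set; rescale vertices back to world units
    verts, faces = mcubes.marching_cubes(values, 0.0)
    verts = verts / (resolution - 1) * 2 * bound - bound
    return verts, faces
```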

## Intrinsic Decomposition and Scene Editing

**Brewing🍺, code coming soon.**

## Citation

If you find our work useful, please consider citing:

```
@inproceedings{zhu2023i2sdf,
    title = {I$^2$-SDF: Intrinsic Indoor Scene Reconstruction and Editing via Raytracing in Neural SDFs},
    author = {Jingsen Zhu and Yuchi Huo and Qi Ye and Fujun Luan and Jifan Li and Dianbing Xi and Lisha Wang and Rui Tang and Wei Hua and Hujun Bao and Rui Wang},
    booktitle = {CVPR},
    year = {2023}
}
```

## Acknowledgement

- This repository is built upon [Pytorch lightning](https://lightning.ai/).
- Thanks to Lior Yariv for her excellent work [VolSDF](https://lioryariv.github.io/volsdf/).
- Thanks to [Scalable-NISR](https://xchaowu.github.io/papers/scalable-nisr/) team for providing their real-world dataset.


================================================
FILE: config/synthetic.yml
================================================
train:
    expname: synthetic
    learning_rate: 5.0e-4
    steps: 200000
    checkpoint_freq: 10000
    plot_freq: 500
    split_n_pixels: 12000
    batch_size: 1600
    pdf_criterion: DEPTH

plot:
    plot_nimgs: 1
    grid_boundary: [-1.5, 1.5]

loss:
    eikonal_weight: 0.1
    smooth_weight: 0.01
    smooth_iter: 150000
    depth_weight: 0.1
    normal_weight: 0.05
    bubble_weight: 0.5
    min_bubble_iter: 50000
    max_bubble_iter: 150000

dataset:
    data_dir: synthetic
    img_res: [480, 640]
    downsample: 2
    pdf_prune: 0.05
    pdf_max: 0.2

model:
    feature_vector_size: 256
    scene_bounding_sphere: 3.0
    implicit_network:
        d_in: 3
        d_out: 1
        dims: [ 256, 256, 256, 256, 256, 256, 256, 256 ]
        geometric_init: True
        bias: 0.6
        skip_in: [4]
        weight_norm: True
        embed_type: 'positional'
        multires: 6
    
    rendering_network:
        # mode: idr
        # d_in: 9
        # Don't find actual differences between 'nerf' and 'idr' mode
        # Choose 'nerf' mode for a slight faster performance
        mode: nerf
        d_in: 3
        d_out: 3
        dims: [ 256, 256, 256, 256 ]
        weight_norm: True
        embed_type: 'positional'
        multires: 4
    
    density:
        params_init:
            beta: 0.1
        
        beta_min: 0.0001
    
    ray_sampler:
        near: 0.0
        N_samples: 64
        N_samples_eval: 128
        N_samples_extra: 32
        eps: 0.1
        beta_iters: 10
        max_total_iters: 5
        N_samples_inverse_sphere: 32
        add_tiny: 1.0e-6



================================================
FILE: config/synthetic_light_mask.yml
================================================
train:
    expname: synthetic_light
    learning_rate: 5.0e-4
    steps: 200000
    checkpoint_freq: 10000
    plot_freq: 500
    split_n_pixels: 12000
    batch_size: 1600
    pdf_criterion: DEPTH

plot:
    plot_nimgs: 1
    grid_boundary: [-1.5, 1.5]

loss:
    eikonal_weight: 0.1
    smooth_weight: 0.01
    smooth_iter: 150000
    depth_weight: 0.1
    normal_weight: 0.05
    bubble_weight: 0.5
    light_mask_weight: 0.5
    min_bubble_iter: 50000
    max_bubble_iter: 150000

dataset:
    data_dir: synthetic
    img_res: [480, 640]
    downsample: 2
    pdf_prune: 0.05
    pdf_max: 0.2

model:
    feature_vector_size: 256
    scene_bounding_sphere: 3.0
    implicit_network:
        d_in: 3
        d_out: 1
        dims: [ 256, 256, 256, 256, 256, 256 ]
        geometric_init: True
        bias: 0.6
        skip_in: [3]
        weight_norm: True
        embed_type: 'positional'
        multires: 6
    
    rendering_network:
        mode: nerf
        d_in: 3
        d_out: 3
        dims: [ 256, 256, 256 ]
        weight_norm: True
        embed_type: 'positional'
        multires: 4
    
    light_network:
        dims: [ 128 ]
        weight_norm: True
    
    density:
        params_init:
            beta: 0.1
        
        beta_min: 0.0001
    
    ray_sampler:
        near: 0.0
        N_samples: 64
        N_samples_eval: 128
        N_samples_extra: 32
        eps: 0.1
        beta_iters: 10
        max_total_iters: 5
        N_samples_inverse_sphere: 32
        add_tiny: 1.0e-6



================================================
FILE: data/normalize_cameras.py
================================================
import cv2
import numpy as np
import argparse
from copy import deepcopy

def get_center_point(num_cams,cameras):
    # Least-squares solve for the point closest to all camera optical axes,
    # used as the scene center for normalization
    A = np.zeros((3 * num_cams, 3 + num_cams))
    b = np.zeros((3 * num_cams, 1))
    camera_centers=np.zeros((3,num_cams))
    for i in range(num_cams):
        P0 = cameras['world_mat_%d' % i][:3, :]

        K, R, c = cv2.decomposeProjectionMatrix(P0)[:3]
        c = c / c[3]
        camera_centers[:,i] = c[:3].flatten()

        # v = np.linalg.inv(K) @ np.array([800, 600, 1])
        # v = v / np.linalg.norm(v)

        v = R[2,:] # camera viewing direction: third row of the world-to-camera rotation
        A[3 * i:(3 * i + 3), :3] = np.eye(3)
        A[3 * i:(3 * i + 3), 3 + i] = -v
        b[3 * i:(3 * i + 3)] = c[:3]

    soll = np.linalg.pinv(A) @ b # first 3 entries give the center point

    return soll,camera_centers

def normalize_cameras(original_cameras_filename,output_cameras_filename,num_of_cameras,radius,convert_coord):
    cameras = np.load(original_cameras_filename)
    if num_of_cameras==-1:
        all_files=cameras.files
        maximal_ind=0
        for field in all_files:
            if 'val' not in field:
                maximal_ind=np.maximum(maximal_ind,int(field.split('_')[-1]))
        num_of_cameras=maximal_ind+1
    soll, camera_centers = get_center_point(num_of_cameras, cameras)

    center = soll[:3].flatten()

    max_radius = np.linalg.norm((center[:, np.newaxis] - camera_centers), axis=0).max() * 1.1

    normalization = np.eye(4).astype(np.float32)

    normalization[0, 3] = center[0]
    normalization[1, 3] = center[1]
    normalization[2, 3] = center[2]

    normalization[0, 0] = max_radius / radius
    normalization[1, 1] = max_radius / radius
    normalization[2, 2] = max_radius / radius

    cameras_new = {}
    cameras_new = deepcopy(dict(cameras))
    for i in range(num_of_cameras):
        cameras_new['scale_mat_%d' % i] = normalization
        # cameras_new['world_mat_%d' % i] = cameras['world_mat_%d' % i].copy()
        # if ('val_mat_%d' % i) in cameras:
        #     cameras_new['val_mat_%d' % i] = cameras['val_mat_%d' % i].copy()
        
        def opengl2opencv(P):
            out = cv2.decomposeProjectionMatrix(P[:3,:])
            K, R, t = out[0:3]
            K = K/K[2,2]
            intrinsics = np.eye(4, dtype=np.float32)
            intrinsics[:3, :3] = K
            t = (t[:3] / t[3]).squeeze()
            w2c = np.eye(4, dtype=np.float32)
            w2c[:3,:3] = R
            w2c[:3,3] = -R @ t
            T = np.diag([1, -1, -1, 1])
            w2c = T @ w2c
            return intrinsics @ w2c
        if convert_coord:
            cameras_new['world_mat_%d' % i] = opengl2opencv(cameras_new['world_mat_%d' % i])
            if ('val_mat_%d' % i) in cameras_new:
                cameras_new['val_mat_%d' % i] = opengl2opencv(cameras_new['val_mat_%d' % i])
            # cameras_new['world_mat_%d' % i] = T @ cameras_new['world_mat_%d' % i]
    np.savez(output_cameras_filename, **cameras_new)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Normalizing cameras')
    parser.add_argument('-i', '--input_cameras_file', type=str, default="cameras.npz",
                        help='the input cameras file')
    parser.add_argument('-o', '--output_cameras_file', type=str, default="cameras_normalize.npz",
                        help='the output cameras file')
    parser.add_argument('--id', type=int, nargs='?')
    parser.add_argument('-n', '--name', type=str, default='synthetic')
    parser.add_argument('--number_of_cams',type=int, default=-1,
                        help='Number of cameras, if -1 use all')
    parser.add_argument('-r', '--radius', type=float, default=2.0)
    parser.add_argument('-c', '--convert_coord', action='store_true')

    args = parser.parse_args()
    if args.id is not None: # plain 'if args.id' would wrongly skip scan id 0
        args.input_cameras_file = f'{args.name}/scan{args.id}/cameras.npz'
        args.output_cameras_file = f'{args.name}/scan{args.id}/cameras_normalize.npz'

    normalize_cameras(args.input_cameras_file, args.output_cameras_file, args.number_of_cams, args.radius, args.convert_coord)


================================================
FILE: data/npz_to_blender.py
================================================
"""
    Transform npz-formatted scenes to json-formatted scene (NeRF blender format)
    Scale all poses to fit in a [-1, 1] box
"""

import copy
import json
import os
import cv2
import numpy as np
from tqdm import tqdm
import argparse

def to16b(img):
    img = img.clip(0, 1) * 65535
    return img.astype(np.uint16)


def opencv_to_gl(pose):
    mat = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]])
    pose[:3, :3] = pose[:3, :3] @ mat
    return pose


def get_offset(poses):
    eyes = np.stack([pose[:3, 3] for pose in poses])

    scale = eyes.max(axis=0) - eyes.min(axis=0)
    print(f'scale : {scale}')

    offset = -(eyes.max(axis=0) + eyes.min(axis=0)) / 2
    print(f'offset : {offset}')

    return scale / 2, offset


def scale_pose(pose, scale, offset):
    pose[:3, 3] = (pose[:3, 3] + offset) / scale
    # print(pose[:3, 3])
    return pose.tolist()


def load_K_Rt_from_P(filename, P=None):
    if P is None:
        lines = open(filename).read().splitlines()
        if len(lines) == 4:
            lines = lines[1:]
        lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
        P = np.asarray(lines).astype(np.float32).squeeze()

    out = cv2.decomposeProjectionMatrix(P)
    K = out[0]
    R = out[1]
    t = out[2]

    K = K / K[2, 2]
    intrinsics = np.eye(4)
    intrinsics[:3, :3] = K

    pose = np.eye(4, dtype=np.float32)
    pose[:3, :3] = R.transpose()
    pose[:3, 3] = (t[:3] / t[3])[:, 0]

    return intrinsics, pose


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--root', required=True)
    parser.add_argument('--scale', action='store_true')
    args = parser.parse_args()
    os.chdir(os.path.join(args.root))

    image_dir = 'image'
    n_images = len(os.listdir(image_dir))
    val_dir = 'val'
    n_val = len(os.listdir(val_dir))
    os.makedirs('depths', exist_ok=True)

    cam_file = 'cameras.npz'
    camera_dict = np.load(cam_file)
    world_mats = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(n_images)]
    val_mats = [camera_dict['val_mat_%d' % idx].astype(np.float32) for idx in range(n_val)]

    intrinsics_all = []
    pose_all = []
    for mat in world_mats + val_mats:
        P = mat
        P = P[:3, :4]
        intrinsics, pose = load_K_Rt_from_P(None, P)
        intrinsics_all.append(intrinsics)
        pose_all.append(opencv_to_gl(pose))

    train_json = dict()

    # Reuse the intrinsics from the last decomposed camera (all views are assumed to share intrinsics)
    train_json['fl_y'] = intrinsics[1, 1]
    train_json['h'] = int(intrinsics[1, 2] * 2)
    train_json['fl_x'] = intrinsics[0, 0]
    train_json['w'] = int(intrinsics[0, 2] * 2)

    if args.scale:
        scale, offset = get_offset(pose_all)

    # train_json['enable_depth_loading'] = True
    # train_json['integer_depth_scale'] = 1 / 65535

    train_json['frames'] = []

    test_json = copy.deepcopy(train_json)
    test_json['enable_depth_loading'] = False

    for i in tqdm(range(n_images)):
        frames = train_json['frames']
        if args.scale:
            depth = cv2.imread(os.path.join('depth', '{:04d}.exr'.format(i)), -1)
            cv2.imwrite(os.path.join('depths', '{:04d}.exr'.format(i)), depth / scale.max())

        pose = pose_all[i].tolist() if not args.scale else scale_pose(pose_all[i], scale.max(), offset)
        frame = {
            'file_path': f'./image/{i:04d}',
            'depth_path': f'./depths/{i:04d}.exr' if args.scale else f'./depth/{i:04d}.exr',
            'transform_matrix': pose
        }
        frames.append(frame)

    for i in tqdm(range(n_val)):
        frames = test_json['frames']
        pose = pose_all[i + n_images].tolist() if not args.scale else scale_pose(pose_all[i + n_images], scale.max(), offset)
        frame = {
            'file_path': f'./val/{i:04d}',
            'transform_matrix': pose
        }
        frames.append(frame)

    with open('transforms_train.json', 'w') as f:
        json.dump(train_json, f, indent=4)
    with open('transforms_test.json', 'w') as f:
        json.dump(test_json, f, indent=4)
    with open('transforms_val.json', 'w') as f:
        json.dump(test_json, f, indent=4)


if __name__ == '__main__':
    main()


================================================
FILE: dataset/__init__.py
================================================
from .train_dataset import *
from .eval_dataset import *

================================================
FILE: dataset/eval_dataset.py
================================================
from copy import deepcopy
import os
import torch
import numpy as np
from torch.utils.data import Dataset
import utils.plots as plt
import torch.nn.functional as F
import utils
from utils import rend_util
from tqdm.contrib import tzip, tenumerate
from scipy.spatial.transform import Rotation as Rot
from scipy.spatial.transform import Slerp
import cv2

class GridDataset(Dataset):
    """
    Used for mesh extraction
    """
    def __init__(self, points, xyz) -> None:
        super().__init__()
        self.grid_points = points
        self.xyz = xyz
    
    def __len__(self):
        return self.grid_points.size(0)
    
    def __getitem__(self, index):
        return self.grid_points[index]


class PlotDataset(torch.utils.data.Dataset):
    def __init__(self,
                 data_dir,
                 plot_nimgs,
                 scan_id=0,
                 is_val=False,
                 data=None,
                 is_hdr=False,
                 indices=None,
                 use_lmask=False,
                 **kwargs
                 ):

        self.instance_dir = os.path.join('data', data_dir, 'scan{0}'.format(scan_id))
        val_dir = '{0}/val'.format(self.instance_dir)
        is_val = is_val and os.path.exists(val_dir)
        lmask_dir = '{0}/light_mask'.format(self.instance_dir)
        self.use_lmask = use_lmask and os.path.exists(lmask_dir)
        if is_val:
            print("[INFO] Validation set detected")
        if is_val or data is None:
            assert os.path.exists(self.instance_dir), "Data directory is empty"

            if is_val:
                image_dir = val_dir
            elif is_hdr:
                image_dir = '{0}/hdr'.format(self.instance_dir)
            else:
                image_dir = '{0}/image'.format(self.instance_dir)
            image_paths = sorted(utils.glob_imgs(image_dir))
            if indices is not None:
                print(f"[INFO] Selecting indices: {indices}")
                image_paths = [image_paths[i] for i in indices]
            self.n_images = len(image_paths)
            self.indices = indices if indices is not None else list(range(self.n_images))

            self.cam_file = '{0}/cameras_normalize.npz'.format(self.instance_dir)
            camera_dict = np.load(self.cam_file)
            scale_mats = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in self.indices] if not is_val else [camera_dict['scale_mat_0'].astype(np.float32)] * len(self.indices)
            world_mats = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in self.indices] if not is_val else [camera_dict['val_mat_%d' % idx].astype(np.float32) for idx in self.indices]

            self.intrinsics_all = []
            self.pose_all = []
            for scale_mat, world_mat in zip(scale_mats, world_mats):
                P = world_mat @ scale_mat
                P = P[:3, :4]
                intrinsics, pose = rend_util.load_K_Rt_from_P(None, P)
                self.intrinsics_all.append(torch.from_numpy(intrinsics).float())
                self.pose_all.append(torch.from_numpy(pose).float())
            self.intrinsics_all = torch.stack(self.intrinsics_all, 0)
            self.pose_all = torch.stack(self.pose_all, 0)
            self.rgb_images = []
            for path in image_paths:
                rgb = rend_util.load_rgb(path, is_hdr=is_hdr)
                self.img_res = [rgb.shape[1], rgb.shape[2]]
                rgb = rgb.reshape(3, -1).transpose(1, 0)
                self.rgb_images.append(torch.from_numpy(rgb).float())
            self.rgb_images = torch.stack(self.rgb_images, 0)
            if self.use_lmask:
                self.lightmask_images = []
                lmask_paths = sorted(utils.glob_imgs(lmask_dir))
                for path in lmask_paths:
                    lmask = rend_util.load_mask(path)
                    lmask = lmask.reshape(-1, 1)
                    self.lightmask_images.append(torch.from_numpy(lmask).float())
                self.lightmask_images = torch.stack(self.lightmask_images, 0)
            self.total_pixels = self.rgb_images.size(1)
        else:
            self.intrinsics_all = data['intrinsics']
            self.pose_all = data['pose']
            self.rgb_images = data['rgb']
            self.n_images = len(self.rgb_images)
            self.img_res = [data['img_res'][0], data['img_res'][1]]
            self.total_pixels = self.img_res[0] * self.img_res[1]
            if 'light_mask' in data:
                self.lightmask_images = data['light_mask']
                self.use_lmask = True
        
        if (scale := kwargs.get('downsample', 1)) > 1:
            old_img_res = deepcopy(self.img_res)
            self.img_res[0] //= scale
            self.img_res[1] //= scale
            self.total_pixels = self.img_res[0] * self.img_res[1]
            self.rgb_images = self.rgb_images.transpose(1, 2).reshape(-1, 3, old_img_res[0], old_img_res[1])
            self.rgb_images = F.interpolate(self.rgb_images, self.img_res, mode='area')
            self.rgb_images = self.rgb_images.reshape(-1, 3, self.total_pixels).transpose(1, 2)
            if self.use_lmask:
                self.lightmask_images = self.lightmask_images.transpose(1, 2).reshape(-1, 1, old_img_res[0], old_img_res[1])
                self.lightmask_images = F.interpolate(self.lightmask_images, self.img_res, mode='area')
                self.lightmask_images = self.lightmask_images.reshape(-1, 1, self.total_pixels).transpose(1, 2)

            self.intrinsics_all = self.intrinsics_all.clone()
            self.intrinsics_all[:,0,0] /= scale
            self.intrinsics_all[:,1,1] /= scale
            self.intrinsics_all[:,0,2] /= scale
            self.intrinsics_all[:,1,2] /= scale
        
        print(f"[INFO] Plot image size: {self.img_res[1]}x{self.img_res[0]}, {self.total_pixels} pixels in total")
        if plot_nimgs == -1:
            self.plot_nimgs = self.n_images
        else:
            self.plot_nimgs = min(plot_nimgs, self.n_images)
        self.shuffle = kwargs.get('shuffle', True)
        if self.shuffle:
            self.shuffle_plot_index()

    def shuffle_plot_index(self):
        if self.shuffle:
            self.plot_index = torch.randperm(self.n_images)[:self.plot_nimgs]

    def __len__(self):
        return self.plot_nimgs
    
    def get_uv(self):
        uv = np.mgrid[0:self.img_res[0], 0:self.img_res[1]].astype(np.int32)
        uv = torch.from_numpy(np.flip(uv, axis=0).copy()).float()
        uv = uv.reshape(2, -1).transpose(1, 0)
        return uv

    def __getitem__(self, idx):
        if self.shuffle:
            idx = self.plot_index[idx]
        uv = self.get_uv()

        sample = {
            "uv": uv,
            "intrinsics": self.intrinsics_all[idx],
            "pose": self.pose_all[idx]
        }

        ground_truth = {
            "rgb": self.rgb_images[idx]
        }

        if self.use_lmask:
            ground_truth['light_mask'] = self.lightmask_images[idx]

        return idx, sample, ground_truth

    def collate_fn(self, batch_list):
        # Take a list of (idx, sample, ground_truth) tuples and stack them into batched dicts/tensors
        batch_list = zip(*batch_list)

        all_parsed = []
        for entry in batch_list:
            if type(entry[0]) is dict:
                # make them all into a new dict
                ret = {}
                for k in entry[0].keys():
                    ret[k] = torch.stack([obj[k] for obj in entry])
                all_parsed.append(ret)
            else:
                all_parsed.append(torch.LongTensor(entry))

        return tuple(all_parsed)


class InterpolateDataset(torch.utils.data.Dataset):
    """
    View interpolation: specify 2 view ids from training set and generate a video moving between them
    """
    def __init__(self,
                 data_dir,
                #  img_res,
                 id0,
                 id1,
                 num_frames=60,
                 scan_id=0,
                 **kwargs
                 ):

        self.instance_dir = os.path.join('data', data_dir, 'scan{0}'.format(scan_id))
        assert os.path.exists(self.instance_dir), "Data directory is empty"

        image_dir = '{0}/image'.format(self.instance_dir)
        im = cv2.imread(f"{image_dir}/{id0:04d}.png")
        h, w, _ = im.shape
        self.img_res = [h, w]
        self.total_pixels = h * w

        self.cam_file = '{0}/cameras_normalize.npz'.format(self.instance_dir)
        camera_dict = np.load(self.cam_file)
        P0 = camera_dict['world_mat_%d' % id0].astype(np.float32) @ camera_dict['scale_mat_%d' % id0].astype(np.float32)
        P1 = camera_dict['world_mat_%d' % id1].astype(np.float32) @ camera_dict['scale_mat_%d' % id1].astype(np.float32)
        P0 = P0[:3,:]
        P1 = P1[:3,:]
        K, pose0 = rend_util.load_K_Rt_from_P(None, P0)
        _, pose1 = rend_util.load_K_Rt_from_P(None, P1)
        rots = Rot.from_matrix(np.stack([pose0[:3,:3].T, pose1[:3,:3].T]))
        slerp = Slerp([0, 1], rots)

        if (scale := kwargs.get('downsample', 1)) > 1:
            self.img_res[0] = self.img_res[0] // scale
            self.img_res[1] = self.img_res[1] // scale
            self.total_pixels = self.img_res[0] * self.img_res[1]
            K[0,0] /= scale
            K[1,1] /= scale
            K[0,2] /= scale
            K[1,2] /= scale

        self.intrinsics = torch.from_numpy(K).float()
        self.pose_all = []
        for i in range(num_frames):
            ratio = np.sin(((i / num_frames) - 0.5) * np.pi) * 0.5 + 0.5
            t = (1 - ratio) * pose0[:3,3] + ratio * pose1[:3,3]
            R = slerp(ratio).as_matrix()
            pose = np.eye(4, dtype=np.float32)
            pose[:3,3] = t
            pose[:3,:3] = R.T
            self.pose_all.append(torch.from_numpy(pose).float())
        self.pose_all = torch.stack(self.pose_all)
        self.n_frames = num_frames

    def __len__(self):
        return self.n_frames

    def __getitem__(self, idx):
        uv = np.mgrid[0:self.img_res[0], 0:self.img_res[1]].astype(np.int32)
        uv = torch.from_numpy(np.flip(uv, axis=0).copy()).float()
        uv = uv.reshape(2, -1).transpose(1, 0)
        sample = {
            "uv": uv,
            "intrinsics": self.intrinsics,
            "pose": self.pose_all[idx]
        }
        return idx, sample

    def collate_fn(self, batch_list):
        # Take a list of (idx, sample) tuples and stack them into batched dicts/tensors
        batch_list = zip(*batch_list)

        all_parsed = []
        for entry in batch_list:
            if type(entry[0]) is dict:
                # make them all into a new dict
                ret = {}
                for k in entry[0].keys():
                    ret[k] = torch.stack([obj[k] for obj in entry])
                all_parsed.append(ret)
            else:
                all_parsed.append(torch.LongTensor(entry))

        return tuple(all_parsed)


class RelightDataset(PlotDataset):
    def __init__(self, data_dir, edit_cfg, scan_id=0, is_val=False, **kwargs):
        super().__init__(data_dir, 1, scan_id, is_val, None, False, [edit_cfg['index']], True, **kwargs)
        self.edit_mask = 'mask' in edit_cfg
        if self.edit_mask:
            self.mask = rend_util.load_mask(edit_cfg['mask']).astype(np.float32)
            mh, mw = self.mask.shape
            if mh != self.img_res[0] or mw != self.img_res[1]:
                self.mask = cv2.resize(self.mask, (self.img_res[1], self.img_res[0]), interpolation=cv2.INTER_AREA)
                self.mask = (self.mask > 0.5)
            self.mask = torch.from_numpy(self.mask).float().flatten()
            if 'normal' in edit_cfg:
                self.loadattr(edit_cfg, 'normal', 0)
                self.normal = self.normal.reshape(-1, 3)
                self.normal = F.normalize(self.normal, dim=-1, eps=1e-6)
            if 'rough' in edit_cfg:
                self.loadattr(edit_cfg, 'rough', 1)
                self.rough = self.rough.reshape(-1, 1)
            if 'kd' in edit_cfg:
                self.loadattr(edit_cfg, 'kd', 2)
                self.kd = self.kd.reshape(-1, 3)
            if 'ks' in edit_cfg:
                self.loadattr(edit_cfg, 'ks', 2)
                self.ks = self.ks.reshape(-1, 3)
        self.uv = self.get_uv()
    
    def loadattr(self, edit_cfg, attr, mode=0):
        if mode == 0:
            im = rend_util.load_normal(edit_cfg[attr])
        elif mode == 1:
            im = cv2.imread(edit_cfg[attr], -1)
            if len(im.shape) == 3:
                im = im[:,:,-1]
        else:
            im = rend_util.load_rgb(edit_cfg[attr]).transpose(1, 2, 0)
        h, w = im.shape[:2]
        if h != self.img_res[0] or w != self.img_res[1]:
            im = cv2.resize(im, (self.img_res[1], self.img_res[0]), interpolation=cv2.INTER_AREA)
        setattr(self, attr, torch.from_numpy(im).float())

    def __len__(self):
        return self.total_pixels
    
    def __getitem__(self, idx):
        sample = {
            "uv": self.uv[idx].unsqueeze(0),
            "intrinsics": self.intrinsics_all[0],
            "pose": self.pose_all[0] 
        }
        ground_truth = {
            "rgb": self.rgb_images[0][idx],
            'light_mask': self.lightmask_images[0][idx]
            # 'edit_mask': self.edit_mask[idx]
        }
        if self.edit_mask:
            ground_truth['mask'] = self.mask[idx]
        if hasattr(self, 'normal'):
            ground_truth['normal'] = self.normal[idx]
        if hasattr(self, 'rough'):
            ground_truth['rough'] = self.rough[idx]
        if hasattr(self, 'kd'):
            ground_truth['kd'] = self.kd[idx]
        if hasattr(self, 'ks'):
            ground_truth['ks'] = self.ks[idx]
        return idx, sample, ground_truth

    
class RelightVideoDataset(PlotDataset):
    def __init__(self, data_dir, edit_cfg, scan_id=0, is_val=False, **kwargs):
        self.n_frames = edit_cfg['n_frames']
        self.img_idx = edit_cfg['index']
        super().__init__(data_dir, 1, scan_id, is_val, None, False, [edit_cfg['index']] * self.n_frames, True, **kwargs)
        self.edit_mask = 'mask' in edit_cfg
        if self.edit_mask:
            self.mask = rend_util.load_mask(edit_cfg['mask']).astype(np.float32)
            mh, mw = self.mask.shape
            if mh != self.img_res[0] or mw != self.img_res[1]:
                self.mask = cv2.resize(self.mask, (self.img_res[1], self.img_res[0]), interpolation=cv2.INTER_AREA)
                self.mask = (self.mask > 0.5)
            self.mask = torch.from_numpy(self.mask).float().flatten()
        self.uv = self.get_uv()
    
    def __len__(self):
        return self.n_frames
    
    def __getitem__(self, idx):
        sample = {
            "uv": self.uv,
            "intrinsics": self.intrinsics_all[idx],
            "pose": self.pose_all[idx]
        }
        ground_truth = {
            "rgb": self.rgb_images[idx],
            'light_mask': self.lightmask_images[idx]
            # 'edit_mask': self.edit_mask[idx]
        }
        if self.edit_mask:
            ground_truth['mask'] = self.mask
        return idx, sample, ground_truth


================================================
FILE: dataset/train_dataset.py
================================================
import json
import os
import cv2
import torch
import numpy as np
from torch.utils.data import Dataset
import torch.nn.functional as F
import utils
from utils import rend_util
from tqdm.contrib import tzip, tenumerate
from scipy.spatial.transform import Rotation as Rot
from scipy.spatial.transform import Slerp


class ReconDataset(torch.utils.data.Dataset):

    def __init__(self,
                 data_dir,
                 scan_id=0,
                 use_mask=False,
                 use_depth=False,
                 use_normal=False,
                 use_bubble=False,
                 use_lightmask=False,
                 is_hdr=False,
                 **kwargs
                 ):
        self.sampling_idx = slice(None)

        self.instance_dir = os.path.join('data', data_dir, 'scan{0}'.format(scan_id))
        assert os.path.exists(self.instance_dir), "Data directory is empty"
        print(f"[INFO] Loading data from {self.instance_dir}")

        image_dir = '{0}/image'.format(self.instance_dir) if not is_hdr else '{0}/hdr'.format(self.instance_dir)
        self.is_hdr = is_hdr
        if is_hdr:
            print("[INFO] Using HDR image")
        image_paths = sorted(utils.glob_imgs(image_dir))
        self.n_images = len(image_paths)

        self.cam_file = '{0}/cameras_normalize.npz'.format(self.instance_dir)
        camera_dict = np.load(self.cam_file)
        scale_mats = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)]
        world_mats = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)]

        self.intrinsics_all = []
        self.pose_all = []
        for scale_mat, world_mat in zip(scale_mats, world_mats):
            P = world_mat @ scale_mat
            P = P[:3, :4]
            intrinsics, pose = rend_util.load_K_Rt_from_P(None, P)
            self.intrinsics_all.append(torch.from_numpy(intrinsics).float())
            self.pose_all.append(torch.from_numpy(pose).float())
        self.intrinsics_all = torch.stack(self.intrinsics_all, 0)
        self.pose_all = torch.stack(self.pose_all, 0)

        self.rgb_images = []
        for path in image_paths:
            rgb = rend_util.load_rgb(path, is_hdr=is_hdr)
            self.img_res = [rgb.shape[1], rgb.shape[2]]
            rgb = rgb.reshape(3, -1).transpose(1, 0)
            self.rgb_images.append(torch.from_numpy(rgb).float())
        self.rgb_images = torch.stack(self.rgb_images, 0)
        self.total_pixels = self.rgb_images.size(1)
        print(f"[INFO] image size: {self.img_res[1]}x{self.img_res[0]}, {self.total_pixels} pixels in total")
        
        uv = np.mgrid[0:self.img_res[0], 0:self.img_res[1]].astype(np.int32)
        uv = np.flip(uv, axis=0).copy()
        self.uv = torch.from_numpy(uv).float()
        self.uv = self.uv.reshape(2, -1).transpose(1, 0) # (h*w, 2)

        mask_dir = '{0}/mask'.format(self.instance_dir)
        self.use_mask = use_mask
        if self.use_mask:
            if os.path.exists(mask_dir):
                mask_paths = sorted(utils.glob_imgs(mask_dir))
                # assert len(mask_paths) == self.n_images
                self.mask_images = []
                for path in mask_paths:
                    mask = rend_util.load_mask(path)
                    mask = mask.reshape(-1, 1)
                    self.mask_images.append(torch.from_numpy(mask).float())
                self.mask_images = torch.stack(self.mask_images, 0)
            else:
                print("[INFO] No existing mask image, use one mask as default")
                self.mask_images = torch.ones(self.n_images, self.total_pixels, 1, dtype=torch.float)
        
        lmask_dir = '{0}/light_mask'.format(self.instance_dir)
        self.use_lightmask = use_lightmask and os.path.exists(lmask_dir)
        if self.use_lightmask:
            lmask_paths = sorted(utils.glob_imgs(lmask_dir))
            self.lightmask_images = []
            for path in lmask_paths:
                lmask = rend_util.load_mask(path)
                lmask = lmask.reshape(-1, 1)
                self.lightmask_images.append(torch.from_numpy(lmask).float())
            self.lightmask_images = torch.stack(self.lightmask_images, 0)
        
        depth_dir = '{0}/depth'.format(self.instance_dir)
        self.use_depth = use_depth and os.path.exists(depth_dir)
        self.use_bubble = use_bubble and os.path.exists(depth_dir)
        if self.use_depth or self.use_bubble:
            self.depth_images = []
            self.depth_masks = []
            self.pointcloud = [] # point cloud for the bubble loss, unprojected from depths and poses
            self.pointlinks = [] # maps pixel index to point cloud index (-1 where the pixel has no valid point)
            self.pixlinks = [] # maps point cloud index back to pixel index
            depth_paths = sorted(utils.glob_depths(depth_dir))
            n_points = 0
            if kwargs.get('noise_scale', 0.0) > 0:
                print(f"[INFO] Ablation study: using noise scale {kwargs.get('noise_scale')}")
            for scale_mat, depth_path, intrinsics, pose, i in tzip(scale_mats, depth_paths, self.intrinsics_all, self.pose_all, range(len(self.pose_all))):
                depth = rend_util.load_depth(depth_path)
                depth = torch.from_numpy(depth.reshape(-1)).float()
                depth = depth / scale_mat[2,2]
                valid_indices = torch.where((depth > 1e-3) & (depth < 6))[0]
                if i == 0 and scale_mat[2,2] != 1:
                    print(f"[INFO] Depth scaled by {scale_mat[2,2]:.2f}")
                depth_mask = torch.zeros([self.total_pixels], dtype=torch.bool)
                depth_mask[valid_indices] = True
                # if self.use_depth:
                if (noise_scale := kwargs.get('noise_scale', 0.0)) > 0:
                    depth = rend_util.add_depth_noise(depth, depth_mask.float(), noise_scale)
                self.depth_images.append(depth)
                self.depth_masks.append(depth_mask)
                if self.use_bubble:
                    pointlink = -torch.ones([self.total_pixels], dtype=torch.long)
                    pointlink[depth_mask] = torch.arange(0, len(valid_indices), dtype=torch.long) + n_points
                    pixlink = torch.arange(i * self.total_pixels, (i + 1) * self.total_pixels, dtype=torch.long)[depth_mask]
                    n_points += len(valid_indices)
                    self.pointlinks.append(pointlink)
                    self.pixlinks.append(pixlink)
                    self.pointcloud.append(rend_util.depth_to_world(self.uv, intrinsics, pose, depth, depth_mask))

            self.depth_images = torch.stack(self.depth_images, 0)
            self.depth_masks = torch.stack(self.depth_masks, 0)
            if self.use_bubble:
                self.pointcloud = torch.cat(self.pointcloud, 0)
                self.pointlinks = torch.cat(self.pointlinks, 0)
                self.pixlinks = torch.cat(self.pixlinks, 0)
                self.pointcloud = self.pointcloud[:,:3] / self.pointcloud[:,3:]
                self.pdf_prune = kwargs.get('pdf_prune', 0)
                self.pdf_max = kwargs.get('pdf_max', None)
                print(f"[INFO] PDF clamped to {self.pdf_prune}")
        
        normal_dir = '{0}/normal'.format(self.instance_dir)
        self.use_normal = use_normal and os.path.exists(normal_dir)
        if self.use_normal:
            self.normal_images = []
            self.normal_masks = []
            normal_paths = sorted(utils.glob_normal(normal_dir))
            for pose, normal_path in tzip(self.pose_all, normal_paths):
                normal = rend_util.load_normal(normal_path)
                normal = torch.from_numpy(normal.reshape(-1, 3)).float()
                valid_indices = torch.where(torch.linalg.vector_norm(normal, dim=1) > 1e-3)[0]
                R = pose[:3,:3]
                normal = (R @ normal.T).T # convert normal from view space to world space
                normal = F.normalize(normal, dim=1, eps=1e-6)
                self.normal_images.append(normal)
                normal_mask = torch.zeros([self.total_pixels], dtype=torch.bool)
                normal_mask[valid_indices] = True
                self.normal_masks.append(normal_mask)
            self.normal_images = torch.stack(self.normal_images, 0)
            self.normal_masks = torch.stack(self.normal_masks, 0)

    def __len__(self):
        return self.n_images * self.total_pixels

    def __getitem__(self, idx):
        pidx = idx % self.total_pixels
        tidx = idx
        idx = idx // self.total_pixels
        sample = {
            "uv": self.uv[pidx].unsqueeze(0),
            "intrinsics": self.intrinsics_all[idx],
            "pose": self.pose_all[idx]
        }
        ground_truth = {
            "rgb": self.rgb_images[idx][pidx]
        }
        if self.use_mask:
            ground_truth['mask'] = self.mask_images[idx][pidx]
        if self.use_lightmask:
            ground_truth['light_mask'] = self.lightmask_images[idx][pidx]
        if self.use_depth or self.use_bubble:
            ground_truth['depth'] = self.depth_images[idx][pidx]
            ground_truth['depth_mask'] = self.depth_masks[idx][pidx]
        if self.use_normal:
            ground_truth['normal'] = self.normal_images[idx][pidx]
            ground_truth['normal_mask'] = self.normal_masks[idx][pidx]

        return tidx, idx, sample, ground_truth

    def collate_fn(self, batch_list):
        # Take a list of (tidx, idx, sample, ground_truth) tuples and stack them into batched dicts/tensors
        batch_list = zip(*batch_list)

        all_parsed = []
        for entry in batch_list:
            if type(entry[0]) is dict:
                # make them all into a new dict
                ret = {}
                for k in entry[0].keys():
                    ret[k] = torch.stack([obj[k] for obj in entry])
                all_parsed.append(ret)
            else:
                all_parsed.append(torch.LongTensor(entry))

        return tuple(all_parsed)


class MaterialDataset(torch.utils.data.Dataset):

    def __init__(self,
                 data_dir,
                 scan_id=0,
                 downsample_train=1,
                 is_hdr=False,
                 **kwargs
                 ):
        self.sampling_idx = slice(None)

        self.instance_dir = os.path.join('data', data_dir, 'scan{0}'.format(scan_id))

        assert os.path.exists(self.instance_dir), "Data directory is empty"

        image_dir = '{0}/image'.format(self.instance_dir) if not is_hdr else '{0}/hdr'.format(self.instance_dir)
        self.is_hdr = is_hdr
        if is_hdr:
            print("[INFO] Using HDR image")
        image_paths = sorted(utils.glob_imgs(image_dir))
        self.n_images = len(image_paths)

        self.cam_file = '{0}/cameras_normalize.npz'.format(self.instance_dir)
        camera_dict = np.load(self.cam_file)
        scale_mats = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)]
        world_mats = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)]

        self.intrinsics_all = []
        self.pose_all = []

        for scale_mat, world_mat in zip(scale_mats, world_mats):
            P = world_mat @ scale_mat
            P = P[:3, :4]
            intrinsics, pose = rend_util.load_K_Rt_from_P(None, P)
            self.intrinsics_all.append(torch.from_numpy(intrinsics).float())
            self.pose_all.append(torch.from_numpy(pose).float())

        self.intrinsics_all = torch.stack(self.intrinsics_all, 0)
        self.pose_all = torch.stack(self.pose_all, 0)

        self.rgb_images = []
        for path in image_paths:
            rgb = rend_util.load_rgb(path, is_hdr=is_hdr)
            self.img_res = [rgb.shape[1], rgb.shape[2]]
            rgb = rgb.reshape(3, -1).transpose(1, 0)
            self.rgb_images.append(torch.from_numpy(rgb).float())
        self.rgb_images = torch.stack(self.rgb_images, 0)
        self.total_pixels = self.rgb_images.size(1)

        mask_dir = '{0}/mask'.format(self.instance_dir)
        self.use_mask = os.path.exists(mask_dir)
        if self.use_mask:
            mask_paths = sorted(utils.glob_imgs(mask_dir))
            # assert len(mask_paths) == self.n_images
            self.mask_images = []
            for path in mask_paths:
                mask = rend_util.load_mask(path)
                mask = mask.reshape(-1, 1)
                self.mask_images.append(torch.from_numpy(mask).float())
            self.mask_images = torch.stack(self.mask_images, 0)
        
        lmask_dir = '{0}/light_mask'.format(self.instance_dir)
        self.use_lightmask = os.path.exists(lmask_dir)
        if self.use_lightmask:
            print("[INFO] Light mask detected")
            lmask_paths = sorted(utils.glob_imgs(lmask_dir))
            self.lightmask_images = []
            for path in lmask_paths:
                lmask = rend_util.load_mask(path)
                lmask = lmask.reshape(-1, 1)
                self.lightmask_images.append(torch.from_numpy(lmask).float())
            self.lightmask_images = torch.stack(self.lightmask_images, 0)
        
        if downsample_train > 1:
            old_res = (self.img_res[0], self.img_res[1])
            self.rgb_images = self.rgb_images.transpose(1, 2).reshape(-1, 3, *old_res)
            self.img_res[0] //= downsample_train
            self.img_res[1] //= downsample_train
            self.total_pixels = self.img_res[0] * self.img_res[1]
            self.rgb_images = F.interpolate(self.rgb_images, self.img_res, mode='area')
            self.rgb_images = self.rgb_images.reshape(-1, 3, self.total_pixels).transpose(1, 2)
            self.intrinsics_all = self.intrinsics_all.clone()
            self.intrinsics_all[:,0,0] /= downsample_train
            self.intrinsics_all[:,1,1] /= downsample_train
            self.intrinsics_all[:,0,2] /= downsample_train
            self.intrinsics_all[:,1,2] /= downsample_train
            if self.use_mask:
                self.mask_images = self.mask_images.transpose(1, 2).reshape(-1, 1, *old_res)
                self.mask_images = F.interpolate(self.mask_images, self.img_res, mode='area')
                self.mask_images = self.mask_images.reshape(-1, 1, self.total_pixels).transpose(1, 2)
            if self.use_lightmask:
                self.lightmask_images = self.lightmask_images.transpose(1, 2).reshape(-1, 1, *old_res)
                self.lightmask_images = F.interpolate(self.lightmask_images, self.img_res, mode='area')
                self.lightmask_images[self.lightmask_images > 0] = 1
                self.lightmask_images = self.lightmask_images.reshape(-1, 1, self.total_pixels).transpose(1, 2)

        print(f"[INFO] image size: {self.img_res[1]}x{self.img_res[0]}, {self.total_pixels} pixels in total")
        
        uv = np.mgrid[0:self.img_res[0], 0:self.img_res[1]].astype(np.int32)
        uv = np.flip(uv, axis=0).copy()
        self.uv = torch.from_numpy(uv).float()
        self.uv = self.uv.reshape(2, -1).transpose(1, 0) # (h*w, 2)

    def __len__(self):
        return self.n_images * self.total_pixels

    def __getitem__(self, idx):
        pidx = idx % self.total_pixels
        tidx = idx
        idx = idx // self.total_pixels
        sample = {
            "uv": self.uv[pidx].unsqueeze(0),
            "intrinsics": self.intrinsics_all[idx],
            "pose": self.pose_all[idx]
        }
        ground_truth = {
            "rgb": self.rgb_images[idx][pidx]
        }
        if self.use_mask:
            ground_truth['mask'] = self.mask_images[idx][pidx]
        if self.use_lightmask:
            ground_truth['light_mask'] = self.lightmask_images[idx][pidx]
        return tidx, idx, sample, ground_truth

    def collate_fn(self, batch_list):
        # Take a list of (tidx, idx, sample, ground_truth) tuples and stack them into batched dicts/tensors
        batch_list = zip(*batch_list)

        all_parsed = []
        for entry in batch_list:
            if type(entry[0]) is dict:
                # make them all into a new dict
                ret = {}
                for k in entry[0].keys():
                    ret[k] = torch.stack([obj[k] for obj in entry])
                all_parsed.append(ret)
            else:
                all_parsed.append(torch.LongTensor(entry))

        return tuple(all_parsed)

================================================
FILE: environment.yml
================================================
name: i2sdf
channels:
  - pytorch
  - conda-forge
  - defaults
dependencies:
  - cudatoolkit=11.3.1=h9edb442_10
  - ffmpeg=4.3=hf484d3e_0
  - numpy=1.23.5=py39h14f4228_0
  - pip=23.0.1=py39h06a4308_0
  - python=3.9.16=h7a1cb2a_2
  - pytorch=1.12.1=py3.9_cuda11.3_cudnn8.3.2_0
  - torchaudio=0.12.1=py39_cu113
  - torchvision=0.13.1=py39_cu113
  - pip:
    - fast-pytorch-kmeans==0.1.9
    - ffmpeg-python==0.2.0
    - gputil==1.4.0
    - lpips==0.1.4
    - open3d==0.17.0
    - opencv-python==4.7.0.72
    - pymcubes==0.1.4
    - pytorch-lightning==1.9.0
    - pyyaml==6.0
    - rich==13.3.3
    - scikit-image==0.20.0
    - scikit-learn==1.2.2
    - scipy==1.9.1
    - tensorboard==2.12.0
    - tensorboardx==2.6
    - torchmetrics==0.11.4
    - tqdm==4.65.0
    - trimesh==3.21.4


================================================
FILE: i2-sdf-dataset-links.csv
================================================
file,url
interiorverse/i2-sdf/i2-sdf/bedroom_0.zip,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/bedroom_0.zip
interiorverse/i2-sdf/i2-sdf/bedroom_0_preview.png,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/bedroom_0_preview.png
interiorverse/i2-sdf/i2-sdf/bedroom_1.zip,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/bedroom_1.zip
interiorverse/i2-sdf/i2-sdf/bedroom_1_preview.png,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/bedroom_1_preview.png
interiorverse/i2-sdf/i2-sdf/bedroom_relight_0.zip,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/bedroom_relight_0.zip
interiorverse/i2-sdf/i2-sdf/bedroom_relight_0_preview.png,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/bedroom_relight_0_preview.png
interiorverse/i2-sdf/i2-sdf/bedroom_relight_1.zip,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/bedroom_relight_1.zip
interiorverse/i2-sdf/i2-sdf/bedroom_relight_1_preview.png,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/bedroom_relight_1_preview.png
interiorverse/i2-sdf/i2-sdf/diningroom_0.zip,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/diningroom_0.zip
interiorverse/i2-sdf/i2-sdf/diningroom_0_preview.png,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/diningroom_0_preview.png
interiorverse/i2-sdf/i2-sdf/kitchen_0.zip,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/kitchen_0.zip
interiorverse/i2-sdf/i2-sdf/kitchen_0_preview.png,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/kitchen_0_preview.png
interiorverse/i2-sdf/i2-sdf/livingroom_0.zip,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/livingroom_0.zip
interiorverse/i2-sdf/i2-sdf/livingroom_0_preview.png,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/livingroom_0_preview.png
interiorverse/i2-sdf/i2-sdf/livingroom_1.zip,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/livingroom_1.zip
interiorverse/i2-sdf/i2-sdf/livingroom_1_preview.png,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/livingroom_1_preview.png
interiorverse/i2-sdf/i2-sdf/real_data/inria_livingroom.zip,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/real_data/inria_livingroom.zip
interiorverse/i2-sdf/i2-sdf/real_data/nisr_coffee_shop_0.zip,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/real_data/nisr_coffee_shop_0.zip
interiorverse/i2-sdf/i2-sdf/real_data/nisr_coffee_shop_1.zip,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/real_data/nisr_coffee_shop_1.zip
interiorverse/i2-sdf/i2-sdf/real_data/nisr_livingroom.zip,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/real_data/nisr_livingroom.zip


================================================
FILE: main_recon.py
================================================
import torch
import yaml
import pytorch_lightning as pl
import argparse
import os
import utils
import model
from pytorch_lightning import loggers
from pytorch_lightning.callbacks import ModelCheckpoint
from rich.progress import TextColumn
import GPUtil

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--conf", type=str, required=True, help="Path to (.yml) config file.")
    parser.add_argument('-d', "--device_ids", type=int, nargs='+', default=None, help="GPU devices to use")
    parser.add_argument("--exps_folder", type=str, default="exps")
    parser.add_argument('--expname', type=str, default='')
    parser.add_argument('--scan_id', type=int, default=-1, help='If set, taken to be the scan id.')
    parser.add_argument('--test', action='store_true')
    parser.add_argument('--test_mode', choices=['render', 'mesh', 'interpolate'], default='render')
    parser.add_argument('-v', '--version', type=int, nargs='?')
    parser.add_argument('--inter_id', type=int, nargs=2, required=False, help='2 view ids for interpolation video.')
    parser.add_argument('-i', '--indices', nargs='*', type=int, help='If set, render only specified indices of the dataset instead of all images.')
    parser.add_argument('--n_frames', type=int, default=60, help='Number of frames in the interpolation video.')
    parser.add_argument('--frame_rate', type=int, default=24, help='Frame rate of the interpolation video.')
    parser.add_argument('-f', '--full_res', action='store_true', help='If set, dataset downscaling will be ignored.')
    parser.add_argument('--is_val', action='store_true', help='If set, render the validation set instead of training set.')
    parser.add_argument('--val_mesh', action='store_true', help='If set, extract and save mesh every validation epoch.')
    parser.add_argument('--score', action='store_true', help='If set, evaluate the meshing score (need to provide GT mesh).')
    parser.add_argument('--far_clip', type=float, default=5.0)
    parser.add_argument('--ckpt', type=str, default='last')
    parser.add_argument('--resolution', type=int, default=512, help='Resolution for marching cube algorithm')
    parser.add_argument('--spp', type=int, default=128)
    parser.add_argument('--seed', type=int, default=42)
    args = parser.parse_args()

    with open(args.conf) as f:
        cfg_dict = yaml.load(f, Loader=yaml.FullLoader)
        cfg = utils.CfgNode(cfg_dict)
    
    expname = args.expname if args.expname else cfg.train.expname
    scan_id = cfg.dataset.scan_id if args.scan_id == -1 else args.scan_id
    cfg.dataset.scan_id = scan_id
    expname = expname + '_' + str(scan_id)

    if args.version is None and (v := args.conf.find("version_")) != -1:
        args.version = int(args.conf[v + 8:args.conf.find("/config")])
        print(f"[INFO] Loaded version {args.version} from config file")
    
    if args.version is not None:
        logger = loggers.TensorBoardLogger(save_dir=args.exps_folder, name=expname, version=args.version)
    else:
        logger = loggers.TensorBoardLogger(save_dir=args.exps_folder, name=expname)
    
    if args.device_ids is None:
        args.device_ids = GPUtil.getAvailable(order='memory', limit=1, maxLoad=0.5, maxMemory=0.5, includeNan=False,
                                              excludeID=[], excludeUUID=[])
        print("Selected GPU {} automatically".format(args.device_ids[0]))
    torch.cuda.set_device(args.device_ids[0])
    torch.set_float32_matmul_precision('medium')
    progbar_callback = utils.RichProgressBarWithScanId(scan_id, leave=False)
    pl.seed_everything(args.seed)
    
    if args.test:
        version = args.version if args.version is not None else logger.version - 1
        exp_dir = os.path.join(logger.root_dir, f"version_{version}")
        del logger
        if args.test_mode == 'render':
            system = model.VolumeRenderSystem(cfg, exp_dir, indices=args.indices, is_val=args.is_val, full_res=args.full_res)
            if not args.ckpt.endswith('.ckpt'):
                args.ckpt += '.ckpt'
            ckpt = torch.load(os.path.join(exp_dir, 'checkpoints', args.ckpt), map_location='cuda')
            system.load_state_dict(ckpt['state_dict'])
            model.lpips.cuda()
        elif args.test_mode == 'mesh':
            system = model.SDFMeshSystem(cfg, exp_dir, args.resolution, args.score, far_clip=args.far_clip)
            if not args.ckpt.endswith('.ckpt'):
                args.ckpt += '.ckpt'
            ckpt = torch.load(os.path.join(exp_dir, 'checkpoints', args.ckpt), map_location='cuda')
            system.load_state_dict(ckpt['state_dict'])
            system.cuda()
            system.eval()
            system.initialize()
        # elif args.test_mode == 'interpolate':
        else:
            system = model.ViewInterpolateSystem(cfg, exp_dir, *args.inter_id, n_frames=args.n_frames, frame_rate=args.frame_rate)
            if not args.ckpt.endswith('.ckpt'):
                args.ckpt += '.ckpt'
            ckpt = torch.load(os.path.join(exp_dir, 'checkpoints', args.ckpt), map_location='cuda')
            system.load_state_dict(ckpt['state_dict'])
        trainer = pl.Trainer(
            logger=False,
            accelerator='gpu',
            devices=args.device_ids,
            callbacks=[progbar_callback]
        )
        trainer.test(system)
    else:
        max_steps = cfg.train.get('steps', 200000)
        print(f"Training for {max_steps} steps")
        exp_dir = logger.log_dir
        checkpoint_callback = ModelCheckpoint(os.path.join(exp_dir, 'checkpoints'), save_last=True, every_n_train_steps=cfg.train.checkpoint_freq)
        if hasattr(cfg.train, 'plot_freq'):
            kwargs = {'val_check_interval': cfg.train.plot_freq}
        else:
            kwargs = {'check_val_every_n_epoch': cfg.train.plot_epochs}
        trainer = pl.Trainer(
            logger=logger,
            accelerator='gpu',
            devices=args.device_ids,
            strategy=None,
            callbacks=[checkpoint_callback, progbar_callback],
            max_steps=max_steps,
            **kwargs
        )
        system = model.ReconstructionTrainer(
            cfg, progbar_callback,
            exp_dir=exp_dir,
            is_val=args.is_val,
            val_mesh=args.val_mesh
        )
        trainer.fit(system)
    torch.cuda.empty_cache()
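# Example invocations (sketch; config path and scan id are placeholders):
#   training:        python main_recon.py --conf config/synthetic.yml --scan_id 1
#   mesh extraction: python main_recon.py --conf config/synthetic.yml --scan_id 1 --test --test_mode mesh --resolution 512
#   novel views:     python main_recon.py --conf config/synthetic.yml --scan_id 1 --test --test_mode render --is_val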

================================================
FILE: model/__init__.py
================================================
from .network import *
from .trainer import *
# from .material import *
from .eval import *
from .rendering import RenderingLayer

================================================
FILE: model/eval/__init__.py
================================================
from .recon import *

================================================
FILE: model/eval/recon.py
================================================
import torch
import pytorch_lightning as pl
import numpy as np
import os
from glob import glob
from torch.utils.data import DataLoader
import utils
from utils import rend_util
import utils.plots as plt
import dataset
import model
from skimage import measure
import cv2
import trimesh
from rich.progress import track
from torchmetrics.functional import structural_similarity_index_measure as ssim
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity as LPIPS

lpips = LPIPS()

class SDFMeshSystem(pl.LightningModule):
    def __init__(self, conf, exp_dir, resolution, score=False, far_clip=5.0) -> None:
        super().__init__()
        self.expdir = exp_dir
        conf_model = conf.model
        conf_model.use_normal = False
        self.model = model.I2SDFNetwork(conf_model)
        self.resolution = resolution
        self.grid_boundary = conf.plot.grid_boundary
        self.initialized = False
        self.instance_dir = os.path.join('data', conf.dataset.data_dir, 'scan{0}'.format(conf.dataset.scan_id))
        camera_dict = np.load(os.path.join(self.instance_dir, 'cameras_normalize.npz'))
        self.scale_mat = camera_dict['scale_mat_0']
        self.scan_id = conf.dataset.scan_id
        self.score = score
        if score:
            self.n_imgs = len(os.listdir(os.path.join(self.instance_dir, 'image')))
            self.poses = []
            self.far_clip = far_clip
            for i in range(self.n_imgs):
                K, pose = rend_util.load_K_Rt_from_P(None, camera_dict[f'world_mat_{i}'][:3,:])
                self.poses.append(pose)
            self.K = K
            self.H, self.W, _ = cv2.imread(os.path.join(self.instance_dir, 'image', '0000.png')).shape

    def initialize(self):
        grid = plt.get_grid_uniform(100, self.grid_boundary)
        z = []
        points = grid['grid_points']
        for pnts in track(torch.split(points, 1000000, dim=0)):
            z.append(self.model.implicit_network(pnts)[:,0].detach().cpu().numpy())
        z = np.concatenate(z, axis=0).astype(np.float32)
        verts, faces, normals, values = measure.marching_cubes(
            volume=z.reshape(grid['xyz'][1].shape[0], grid['xyz'][0].shape[0],
                             grid['xyz'][2].shape[0]).transpose([1, 0, 2]),
            level=0,
            spacing=(grid['xyz'][0][2] - grid['xyz'][0][1],
                     grid['xyz'][0][2] - grid['xyz'][0][1],
                     grid['xyz'][0][2] - grid['xyz'][0][1]))
        verts = verts + np.array([grid['xyz'][0][0], grid['xyz'][1][0], grid['xyz'][2][0]])
        mesh_low_res = trimesh.Trimesh(verts, faces, normals)
        recon_pc = trimesh.sample.sample_surface(mesh_low_res, 10000)[0]
        recon_pc = torch.from_numpy(recon_pc).float().cuda()
        s_mean = recon_pc.mean(dim=0)
        s_cov = recon_pc - s_mean
        s_cov = torch.mm(s_cov.transpose(0, 1), s_cov)
        self.vecs = torch.view_as_real(torch.linalg.eig(s_cov)[1].transpose(0, 1))[:, :, 0]
        if torch.det(self.vecs) < 0:
            self.vecs = torch.mm(torch.tensor([[1, 0, 0], [0, 0, 1], [0, 1, 0]]).cuda().float(), self.vecs)
        helper = torch.bmm(self.vecs.unsqueeze(0).repeat(recon_pc.shape[0], 1, 1),
                        (recon_pc - s_mean).unsqueeze(-1)).squeeze().cpu()
        grid_aligned = plt.get_grid(helper, self.resolution)
        grid_points = grid_aligned['grid_points']
        g = []
        for pnts in track(torch.split(grid_points, 1000000, dim=0)):
            g.append(torch.bmm(self.vecs.unsqueeze(0).repeat(pnts.shape[0], 1, 1).transpose(1, 2),
                           pnts.unsqueeze(-1)).squeeze() + s_mean)
        grid_points = torch.cat(g, dim=0)
        points = grid_points.cpu()
        self.test_dataset = dataset.GridDataset(points, grid_aligned['xyz'])
        self.grid_points = grid_points
        self.initialized = True
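        # Note on the alignment above: the eigenvectors of the sample covariance
        # give the principal axes of the coarse surface point cloud; the
        # marching-cubes grid is built in that rotated frame for a tighter fit,
        # and the query points are rotated back to world space before evaluation.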

    def test_dataloader(self):
        assert self.initialized
        print(len(self.test_dataset))
        return DataLoader(self.test_dataset, batch_size=2000000, shuffle=False, num_workers=32)
    
    def test_step(self, batch, batch_idx):
        return self.model.implicit_network(batch)[:,0].detach().cpu().numpy()
    
    def test_epoch_end(self, outputs) -> None:
        # z = torch.cat(outputs, dim=0).cpu().numpy()
        z = np.concatenate(outputs, axis=0).astype(np.float32)
        # Run marching cubes only if the predicted SDF crosses the zero level set
        if np.min(z) <= 0 <= np.max(z):
            verts, faces, normals, values = measure.marching_cubes(
                            volume=z.reshape(self.test_dataset.xyz[1].shape[0], self.test_dataset.xyz[0].shape[0], self.test_dataset.xyz[2].shape[0]).transpose([1, 0, 2]),
                            level=0,
                            spacing=(self.test_dataset.xyz[0][2] - self.test_dataset.xyz[0][1],
                                     self.test_dataset.xyz[0][2] - self.test_dataset.xyz[0][1],
                                     self.test_dataset.xyz[0][2] - self.test_dataset.xyz[0][1]))
            verts = torch.from_numpy(verts).float().cuda()
            verts = torch.bmm(self.vecs.unsqueeze(0).repeat(verts.shape[0], 1, 1).transpose(1, 2),
                   verts.unsqueeze(-1)).squeeze()
            verts = (verts + self.grid_points[0]).cpu().numpy()
            mesh = trimesh.Trimesh(verts, faces, normals)
            mesh.apply_transform(self.scale_mat)
            mesh_folder = os.path.join(self.expdir, 'eval/mesh')
            os.makedirs(mesh_folder, exist_ok=True)
            mesh.export(os.path.join(mesh_folder, 'scan{0}.ply'.format(self.scan_id)), 'ply')
            if self.score:
                from utils import mesh_util
                import open3d as o3d
                mesh = mesh_util.refuse(mesh, self.poses, self.K, self.H, self.W)
                out_mesh_path = os.path.join(mesh_folder, 'scan{0}_refined.ply'.format(self.scan_id))
                o3d.io.write_triangle_mesh(out_mesh_path, mesh)
                mesh = trimesh.load(out_mesh_path)
                print("[INFO] Pred mesh refined")
                gt_mesh = trimesh.load(os.path.join(self.instance_dir, 'mesh.ply'))
                gt_mesh = mesh_util.refuse(gt_mesh, self.poses, self.K, self.H, self.W, self.far_clip)
                out_mesh_path = os.path.join(mesh_folder, 'scan{0}_gt.ply'.format(self.scan_id))
                o3d.io.write_triangle_mesh(out_mesh_path, gt_mesh)
                gt_mesh = trimesh.load(out_mesh_path)
                print("[INFO] GT mesh refined")
                metrics = mesh_util.evaluate(mesh, gt_mesh)
                with open(f"{mesh_folder}/metrics.txt", 'w') as f:
                    for k in metrics:
                        f.write(f"{k.upper()}: {metrics[k]}\n")
                print(f"[INFO] Metrics saved to {mesh_folder}/metrics.txt\n")

    def forward(self):
        raise NotImplementedError("forward not supported by trainer")


class VolumeRenderSystem(pl.LightningModule):
    def __init__(self, conf, exp_dir, indices=None, is_val=False, score_mesh=False, full_res=False) -> None:
        super().__init__()
        self.expdir = exp_dir
        conf_model = conf.model
        conf_model.use_normal = False
        self.model = model.I2SDFNetwork(conf_model)
        self.scan_id = conf.dataset.scan_id
        dataset_conf = conf.dataset
        if full_res:
            dataset_conf.downsample = 1
        self.test_dataset = dataset.PlotDataset(**dataset_conf, plot_nimgs=-1, shuffle=False, indices=indices, is_val=is_val)
        self.total_pixels = self.test_dataset.total_pixels
        self.img_res = self.test_dataset.img_res
        self.split_n_pixels = conf.train.split_n_pixels
        self.expdir = os.path.join(self.expdir, 'eval')
        if is_val:
            self.expdir = os.path.join(self.expdir, 'test')
        os.makedirs(os.path.join(self.expdir, 'rendering'), exist_ok=True)
        os.makedirs(os.path.join(self.expdir, 'depth'), exist_ok=True)
        os.makedirs(os.path.join(self.expdir, 'normal'), exist_ok=True)

    def test_dataloader(self):
        print(len(self.test_dataset))
        return DataLoader(self.test_dataset, batch_size=1, shuffle=False, collate_fn=self.test_dataset.collate_fn)

    @torch.inference_mode(False)
    @torch.no_grad()
    def test_step(self, batch, batch_idx):
        indices, model_input, ground_truth = batch
        # idx = batch_idx
        idx = self.test_dataset.indices[batch_idx]
        split = utils.split_input(model_input, self.total_pixels, self.split_n_pixels)
        res = []
        for s in split:
            out = utils.detach_dict(self.model(s))
            d = {
                'rgb_values': out['rgb_values'].detach(),
                'depth_values': out['depth_values'].detach()
            }
            d['normal_map'] = out['normal_map'].detach()
            # d['surface_point'] = out['surface_point'].detach()
            del out
            res.append(d)
        model_outputs = utils.merge_output(res, self.total_pixels, 1)
        _, num_samples, _ = ground_truth['rgb'].shape
        model_outputs['rgb_values'] = model_outputs['rgb_values'].reshape(1, num_samples, 3)
        model_outputs['depth_values'] = model_outputs['depth_values'].reshape(1, num_samples, 1)
        plt.plot_imgs_wo_gt(model_outputs['normal_map'].reshape(1, num_samples, 3), self.expdir, "{:04d}w".format(idx), 1, self.img_res, is_hdr=True)
        normal_map = model_outputs['normal_map'].reshape(num_samples, 3).T # (3, h*w)
        R = model_input['pose'].squeeze()[:3,:3].T
        normal_map = R @ normal_map
        model_outputs['normal_map'] = normal_map.T.reshape(1, num_samples, 3)
        plt.plot_imgs_wo_gt(model_outputs['normal_map'], self.expdir, "{:04d}".format(idx), 1, self.img_res, is_hdr=True)
        model_outputs['normal_map'] = (model_outputs['normal_map'] + 1.) / 2.
        plt.plot_imgs_wo_gt(model_outputs['normal_map'], self.expdir, "{:04d}".format(idx), 1, self.img_res)
        # plt.plot_imgs_wo_gt(model_outputs['surface_point'].reshape(1, num_samples, 3), self.expdir, "{:04d}p".format(idx), 1, self.img_res, is_hdr=True)

        plt.plot_images(model_outputs['rgb_values'], ground_truth['rgb'], self.expdir, "{:04d}".format(idx), 1, self.img_res)
        plt.plot_imgs_wo_gt(model_outputs['rgb_values'], self.expdir, "{:04d}_pred".format(idx), 1, self.img_res, 'rendering')
        plt.plot_depths(model_outputs['depth_values'], self.expdir, "{:04d}".format(idx), 1, self.img_res)
        plt.plot_depths(model_outputs['depth_values'], self.expdir, "{:04d}".format(idx), 1, self.img_res, None)
        pred_img = model_outputs['rgb_values'].T.reshape(3, *self.img_res).unsqueeze(0)
        gt_img = ground_truth['rgb'].T.reshape(3, *self.img_res).unsqueeze(0)
        return {
            'psnr': utils.get_psnr(model_outputs['rgb_values'], ground_truth['rgb']).item(),
            'ssim': ssim(pred_img, gt_img).item(),
            'lpips': lpips(pred_img.clamp(0, 1) * 2 - 1, gt_img.clamp(0, 1) * 2 - 1).item()  # LPIPS expects inputs in [-1, 1]
        }
    
    def test_epoch_end(self, outputs):
        with open(os.path.join(self.expdir, 'metrics.txt'), 'w') as f:
            f.write(f"# IMAGE RESOLUTION {self.img_res}\n")
            psnr_sum = ssim_sum = lpips_sum = 0
            psnrs = []
            ssims = []
            lpipss = []
            for i, metrics in enumerate(outputs):
                f.write(f"[{i:04d}] [PSNR]{metrics['psnr']:.2f} [SSIM]{metrics['ssim']:.2f} [LPIPS]{metrics['lpips']:.2f}\n")
                psnrs.append(metrics['psnr'])
                ssims.append(metrics['ssim'])
                lpipss.append(metrics['lpips'])
                psnr_sum += metrics['psnr']
                ssim_sum += metrics['ssim']
                lpips_sum += metrics['lpips']
            f.write(f"[MEAN] [PSNR]{psnr_sum/len(outputs):.2f} [SSIM]{ssim_sum/len(outputs):.2f} [LPIPS]{lpips_sum/len(outputs):.2f}\n")
            np.savez_compressed(os.path.join(self.expdir, 'metrics.npz'), psnr=np.array(psnrs), ssim=np.array(ssims), lpips=np.array(lpipss))

    def forward(self):
        raise NotImplementedError("forward not supported by trainer")


class ViewInterpolateSystem(pl.LightningModule):
    def __init__(self, conf, exp_dir, id0, id1, n_frames=60, frame_rate=24, use_normal=True) -> None:
        super().__init__()
        self.expdir = exp_dir
        conf_model = conf.model
        conf_model.use_normal = False
        self.model = model.I2SDFNetwork(conf_model)
        self.scan_id = conf.dataset.scan_id
        dataset_conf = conf.dataset
        self.test_dataset = dataset.InterpolateDataset(**dataset_conf, id0=id0, id1=id1, num_frames=n_frames)
        self.total_pixels = self.test_dataset.total_pixels
        self.img_res = self.test_dataset.img_res
        self.split_n_pixels = conf.train.split_n_pixels
        self.n_frames = n_frames
        self.frame_rate = frame_rate
        self.video_dir = os.path.join(self.expdir, 'eval/interpolate')
        self.id0 = id0
        self.id1 = id1
        self.use_normal = use_normal
        os.makedirs(self.video_dir, exist_ok=True)
        self.frame_dir = os.path.join(self.video_dir, f"{self.id0:04d}_{self.id1:04d}")
        os.makedirs(self.frame_dir, exist_ok=True)
        if self.use_normal:
            self.normal_fdir = os.path.join(self.video_dir, f"{self.id0:04d}_{self.id1:04d}_normal")
            os.makedirs(self.normal_fdir, exist_ok=True)

    def test_dataloader(self):
        print(len(self.test_dataset))
        return DataLoader(self.test_dataset, batch_size=1, shuffle=False, collate_fn=self.test_dataset.collate_fn)
    
    @torch.inference_mode(False)
    @torch.no_grad()
    def test_step(self, batch, batch_idx):
        indices, model_input = batch
        idx = batch_idx
        split = utils.split_input(model_input, self.total_pixels, self.split_n_pixels)
        res = []
        res_normal = []
        for s in split:
            out = utils.detach_dict(self.model(s, predict_only=not self.use_normal))
            rgb = out['rgb_values'].detach()
            res.append(rgb)
            if self.use_normal:
                res_normal.append(out['normal_map'])
            del out
        rendered = torch.cat(res, dim=0).reshape(self.img_res[0], self.img_res[1], 3).cpu().numpy()
        rendered = (rendered * 255).clip(0, 255).astype(np.uint8)
        cv2.imwrite(f"{self.frame_dir}/{idx:04d}.png", rendered[:,:,::-1])
        if self.use_normal:
            normal_map = torch.cat(res_normal, dim=0).reshape(-1, 3).T
            R = model_input['pose'].squeeze()[:3,:3].T
            normal_map = R @ normal_map
            normal_map = normal_map.T.reshape(self.img_res[0], self.img_res[1], 3).cpu().numpy()
            normal_map = (((normal_map + 1) * 0.5) * 255).clip(0, 255).astype(np.uint8)
            cv2.imwrite(f"{self.normal_fdir}/{idx:04d}.png", normal_map[:,:,::-1])

    
    def test_epoch_end(self, outputs):
        import ffmpeg
        (
            ffmpeg
            .input(os.path.join(self.frame_dir, '*.png'), pattern_type='glob', framerate=self.frame_rate)
            .output(os.path.join(self.video_dir, f"scan{self.scan_id}_{self.id0:04d}_{self.id1:04d}.mp4"), vcodec='h264')
            .overwrite_output()
            .run()
        )
        if self.use_normal:
            (
                ffmpeg
                .input(os.path.join(self.normal_fdir, '*.png'), pattern_type='glob', framerate=self.frame_rate)
                .output(os.path.join(self.video_dir, f"scan{self.scan_id}_{self.id0:04d}_{self.id1:04d}_normal.mp4"), vcodec='h264')
                .overwrite_output()
                .run()
            )

    def forward(self):
        raise NotImplementedError("forward not supported by trainer")



================================================
FILE: model/network/__init__.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np

import utils
from model.network.mlp import ImplicitNetwork, RenderingNetwork
from model.network.density import LaplaceDensity, AbsDensity
from model.network.ray_sampler import ErrorBoundSampler
from fast_pytorch_kmeans import KMeans
from sklearn.cluster import DBSCAN


"""
For modeling more complex backgrounds, we follow the inverted sphere parametrization from NeRF++ 
https://github.com/Kai-46/nerfplusplus 
"""
class I2SDFNetwork(nn.Module):
    def __init__(self, conf):
        super().__init__()
        self.feature_vector_size = conf.feature_vector_size
        self.scene_bounding_sphere = getattr(conf, 'scene_bounding_sphere', 1.0)

        # Foreground object's networks
        self.implicit_network = ImplicitNetwork(self.feature_vector_size, 0.0, **conf.implicit_network)
        self.rendering_network = RenderingNetwork(self.feature_vector_size, **conf.rendering_network)

        self.use_light = hasattr(conf, 'light_network')
        if self.use_light:
            # self.light_network = RenderingNetwork(self.feature_vector_size, mode='nerf', output_activation='sigmoid', use_dir=False, **conf.light_network)
            self.light_network = ImplicitNetwork(0, 0, d_in=self.feature_vector_size, d_out=1, geometric_init=False, embed_type=None, output_activation='sigmoid', **conf.light_network)

        self.density = LaplaceDensity(**conf.density)

        # Background's networks
        self.use_bg = hasattr(conf, 'bg_network')
        if self.use_bg:
            bg_feature_vector_size = conf.bg_network.feature_vector_size
            self.bg_implicit_network = ImplicitNetwork(bg_feature_vector_size, 0.0, **conf.bg_network.implicit_network)
            self.bg_rendering_network = RenderingNetwork(bg_feature_vector_size, **conf.bg_network.rendering_network)
            self.bg_density = AbsDensity(**getattr(conf.bg_network, 'density', {}))
        else:
            print("[INFO] BG Network Disabled")
        self.ray_sampler = ErrorBoundSampler(self.scene_bounding_sphere, inverse_sphere_bg=self.use_bg, **conf.ray_sampler)
        self.use_normal = conf.get('use_normal', False)
        self.detach_light_feature = conf.get('detach_light_feature', True)
    
    def init_emission_groups(self, n_emitters, pointcloud, init_emission=1.0, use_dbscan=False):
        if use_dbscan:
            """
            Use DBSCAN algorithm to initialize emitter cluster centroids for K-Means from a small random batch
            Note that DBSCAN can automatically determine the number of clusters
            """
            pt_samples = pointcloud[torch.randperm(len(pointcloud))[:10000]].cpu().numpy()
            lab_samples = torch.from_numpy(DBSCAN(n_jobs=16).fit_predict(pt_samples))
            if n_emitters != len(torch.unique(lab_samples)):
                print(f"[ERROR] Inconsistent emitter count: {n_emitters} / {len(torch.unique(lab_samples))}")
                # n_emitters = len(torch.unique(lab_samples))
                exit()
            init_centroids = torch.zeros(n_emitters, 3)
            for i in range(n_emitters):
                idx = (lab_samples == i).int().argmax()
                init_centroids[i,:] = torch.from_numpy(pt_samples[idx, :])
            init_centroids = init_centroids.to(pointcloud.device)
        else:
            """
            Use K-Means plus plus to initialize emitter cluster centroids for K-Means
            """
            init_centroids = utils.kmeans_pp_centroid(pointcloud, n_emitters)
        self.emitter_clusters = KMeans(n_emitters)
        labels = self.emitter_clusters.fit_predict(pointcloud, init_centroids)
        print("[INFO] emitters clustered")
        self.emissions = nn.Parameter(torch.empty(n_emitters, 3).fill_(init_emission), True)
        return labels, self.emitter_clusters.centroids
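    # Usage sketch (names assumed): `pc` is an (N, 3) tensor of candidate emitter
    # surface points on the same device as the network:
    #   labels, centroids = net.init_emission_groups(n_emitters=2, pointcloud=pc)
    #   # labels: per-point cluster id; self.emissions: learnable (n_emitters, 3) RGB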
    
    def get_param_groups(self, lr):
        return [{'params': self.parameters(), 'lr': lr}]

    def forward(self, input, predict_only=False):

        intrinsics = input["intrinsics"]
        uv = input["uv"]
        pose = input["pose"]

        ray_dirs, cam_loc = utils.get_camera_params(uv, pose, intrinsics)

        batch_size, num_pixels, _ = ray_dirs.shape

        cam_loc = cam_loc.unsqueeze(1).repeat(1, num_pixels, 1).reshape(-1, 3)
        ray_dirs = ray_dirs.reshape(-1, 3)
        ray_dirs_norm = torch.linalg.vector_norm(ray_dirs, dim=1)
        ray_dirs = F.normalize(ray_dirs, dim=1)

        z_vals, z_samples_eik = self.ray_sampler.get_z_vals(ray_dirs, cam_loc, self)

        if self.use_bg:
            z_vals, z_vals_bg = z_vals
        z_max = z_vals[:,-1]
        z_vals = z_vals[:,:-1]
        N_samples = z_vals.shape[1]

        points = cam_loc.unsqueeze(1) + z_vals.unsqueeze(2) * ray_dirs.unsqueeze(1)
        points_flat = points.reshape(-1, 3)

        dirs = ray_dirs.unsqueeze(1).repeat(1,N_samples,1)
        dirs_flat = dirs.reshape(-1, 3)

        returns_grad = self.use_normal or (not self.training) or (self.rendering_network.mode == 'idr')
        # with torch.enable_grad():
        with torch.set_grad_enabled(returns_grad):
        # with torch.inference_mode(not returns_grad):
            sdf, feature_vectors, gradients = self.implicit_network.get_outputs(points_flat, returns_grad)

        rgb_flat = self.rendering_network(points_flat, gradients, dirs_flat, feature_vectors)
        rgb = rgb_flat.reshape(-1, N_samples, 3)

        weights, bg_transmittance = self.volume_rendering(z_vals, z_max, sdf)

        fg_rgb_values = torch.sum(weights.unsqueeze(-1) * rgb, 1)

        weight_sum = torch.sum(weights, -1, keepdim=True)
        # dist = torch.sum(weights / weight_sum.clamp(min=1e-6) * z_vals, 1)
        dist = torch.sum(weights * z_vals, 1)
        depth_values = dist / torch.clamp(ray_dirs_norm, min=1e-6)
        # depth_values = torch.sum(weights * z_vals, 1) / torch.clamp(torch.sum(weights, 1), min=1e-6) # (bn,)

        # Background rendering
        if self.use_bg:
            N_bg_samples = z_vals_bg.shape[1]
            z_vals_bg = torch.flip(z_vals_bg, dims=[-1, ])  # 1--->0

            bg_dirs = ray_dirs.unsqueeze(1).repeat(1,N_bg_samples,1)
            bg_locs = cam_loc.unsqueeze(1).repeat(1,N_bg_samples,1)

            bg_points = self.depth2pts_outside(bg_locs, bg_dirs, z_vals_bg)  # [..., N_samples, 4]
            bg_points_flat = bg_points.reshape(-1, 4)
            bg_dirs_flat = bg_dirs.reshape(-1, 3)

            output = self.bg_implicit_network(bg_points_flat)
            bg_sdf = output[:,:1]
            bg_feature_vectors = output[:, 1:]
            bg_rgb_flat = self.bg_rendering_network(None, None, bg_dirs_flat, bg_feature_vectors)
            bg_rgb = bg_rgb_flat.reshape(-1, N_bg_samples, 3)

            bg_weights = self.bg_volume_rendering(z_vals_bg, bg_sdf)

            bg_rgb_values = torch.sum(bg_weights.unsqueeze(-1) * bg_rgb, 1)

            # Composite foreground and background
            bg_rgb_values = bg_transmittance.unsqueeze(-1) * bg_rgb_values
            rgb_values = fg_rgb_values + bg_rgb_values
        else:
            rgb_values = fg_rgb_values

        output = {
            'rgb_values': rgb_values,
            'depth_values': depth_values,
            'weight_sum': weight_sum
        }
        
        if self.use_light:
            light_features = F.relu(feature_vectors)
            if self.detach_light_feature:
                light_features = light_features.detach_()
            # lmask_flat = self.light_network(None, None, None, light_features)
            lmask_flat = self.light_network(light_features)
            lmask = lmask_flat.reshape(-1, N_samples, 1)
            lmask_values = torch.sum(weights.unsqueeze(-1).detach() * lmask, 1)
            output['light_mask'] = lmask_values

        if predict_only:
            return output

        if self.training:
            # Sample points for the eikonal loss
            n_eik_points = batch_size * num_pixels
            eikonal_points = torch.empty(n_eik_points, 3, device=cam_loc.device).uniform_(-self.scene_bounding_sphere, self.scene_bounding_sphere)

            # Add some of the near surface points
            eik_near_points = (cam_loc.unsqueeze(1) + z_samples_eik.unsqueeze(2) * ray_dirs.unsqueeze(1)).reshape(-1, 3)
            n_eik_near = eik_near_points.size(0)
            eikonal_points = torch.cat([eikonal_points, eik_near_points], 0)

            # Add neighbor points near surface for smoothness loss
            eik_near_neighbors = eik_near_points + torch.empty_like(eik_near_points).uniform_(-0.005, 0.005)
            eikonal_points = torch.cat([eikonal_points, eik_near_neighbors], 0)
            grad_theta = self.implicit_network.gradient(eikonal_points)
            output['grad_theta'] = grad_theta[:n_eik_points+n_eik_near,]
            normals = grad_theta[n_eik_points:,]
            normals = F.normalize(normals, dim=1, eps=1e-6)
            diff_norm = torch.norm(normals[:n_eik_near,:] - normals[n_eik_near:,:], dim=1)
            output['diff_norm'] = diff_norm

            # Sample pointclouds for bubble loss
            if 'pointcloud' in input:
                surface_points = input['pointcloud']
                cam_loc_selected = cam_loc[np.random.randint(0, len(cam_loc)),:]
                surface_points = torch.cat([surface_points, cam_loc_selected.unsqueeze(0)], dim=0)
                surface_sdf = self.implicit_network.get_sdf_vals(surface_points)
                output['surface_sdf'] = surface_sdf[:-1,:]

            # Accumulate gradients for normal loss
            if self.use_normal:
                normals = F.normalize(gradients, dim=-1)
                normals = normals.reshape(-1, N_samples, 3)
                normal_map = torch.sum(weights.unsqueeze(-1).detach() * normals, 1)
                normal_map = F.normalize(normal_map, dim=-1)
                output['normal_values'] = normal_map

        # elif not self.training:
        else:
            # Accumulate gradients for normal visualization
            gradients = gradients.detach()
            normals = F.normalize(gradients, dim=-1)
            normals = normals.reshape(-1, N_samples, 3)
            normal_map = torch.sum(weights.unsqueeze(-1) * normals, 1)
            normal_map = F.normalize(normal_map, dim=-1)
            output['normal_map'] = normal_map

        return output

    def volume_rendering(self, z_vals, z_max, sdf):
        density_flat = self.density(sdf)
        density = density_flat.reshape(-1, z_vals.shape[1]) # (batch_size * num_pixels) x N_samples

        # also include the distance from the last sample to the sphere intersection
        dists = z_vals[:, 1:] - z_vals[:, :-1]
        dists = torch.cat([dists, z_max.unsqueeze(-1) - z_vals[:, -1:]], -1)

        # LOG SPACE
        free_energy = dists * density
        shifted_free_energy = torch.cat([torch.zeros(dists.shape[0], 1, device=free_energy.device), free_energy], dim=-1)  # prepend 0 so transmittance is 1 at t_0
        alpha = 1 - torch.exp(-free_energy)  # probability that this interval is not empty
        transmittance = torch.exp(-torch.cumsum(shifted_free_energy, dim=-1))  # probability that everything up to this point is empty
        fg_transmittance = transmittance[:, :-1]
        weights = alpha * fg_transmittance  # probability that the ray hits something in this interval
        bg_transmittance = transmittance[:, -1]  # factor multiplied into the background volume rendering

        return weights, bg_transmittance
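    # For reference, the weights above implement the standard discrete volume
    # rendering quadrature (as in VolSDF / NeRF):
    #   alpha_i = 1 - exp(-sigma_i * delta_i)
    #   T_i     = exp(-sum_{j<i} sigma_j * delta_j)
    #   w_i     = alpha_i * T_i
    # with sigma the Laplace density of the SDF, delta_i the sample spacing,
    # and bg_transmittance = T_{N+1} used to composite the background.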

    def bg_volume_rendering(self, z_vals_bg, bg_sdf):
        bg_density_flat = self.bg_density(bg_sdf)
        bg_density = bg_density_flat.reshape(-1, z_vals_bg.shape[1]) # (batch_size * num_pixels) x N_samples

        bg_dists = z_vals_bg[:, :-1] - z_vals_bg[:, 1:]
        bg_dists = torch.cat([bg_dists, torch.tensor([1e10], device=bg_dists.device).unsqueeze(0).repeat(bg_dists.shape[0], 1)], -1)

        # LOG SPACE
        bg_free_energy = bg_dists * bg_density
        bg_shifted_free_energy = torch.cat([torch.zeros(bg_dists.shape[0], 1, device=bg_free_energy.device), bg_free_energy[:, :-1]], dim=-1)  # shift one step
        bg_alpha = 1 - torch.exp(-bg_free_energy)  # probability that this interval is not empty
        bg_transmittance = torch.exp(-torch.cumsum(bg_shifted_free_energy, dim=-1))  # probability that everything up to this point is empty
        bg_weights = bg_alpha * bg_transmittance  # probability that the ray hits something in this interval

        return bg_weights

    def depth2pts_outside(self, ray_o, ray_d, depth):

        '''
        ray_o, ray_d: [..., 3]
        depth: [...]; inverse of distance to sphere origin
        '''

        o_dot_d = torch.sum(ray_d * ray_o, dim=-1)
        under_sqrt = o_dot_d ** 2 - ((ray_o ** 2).sum(-1) - self.scene_bounding_sphere ** 2)
        d_sphere = torch.sqrt(under_sqrt) - o_dot_d
        p_sphere = ray_o + d_sphere.unsqueeze(-1) * ray_d
        p_mid = ray_o - o_dot_d.unsqueeze(-1) * ray_d
        p_mid_norm = torch.norm(p_mid, dim=-1)

        rot_axis = torch.cross(ray_o, p_sphere, dim=-1)
        rot_axis = rot_axis / torch.norm(rot_axis, dim=-1, keepdim=True)
        phi = torch.asin(p_mid_norm / self.scene_bounding_sphere)
        theta = torch.asin(p_mid_norm * depth)  # depth is inside [0, 1]
        rot_angle = (phi - theta).unsqueeze(-1)  # [..., 1]

        # now rotate p_sphere
        # Rodrigues formula: https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula
        p_sphere_new = p_sphere * torch.cos(rot_angle) + \
                       torch.cross(rot_axis, p_sphere, dim=-1) * torch.sin(rot_angle) + \
                       rot_axis * torch.sum(rot_axis * p_sphere, dim=-1, keepdim=True) * (1. - torch.cos(rot_angle))
        p_sphere_new = p_sphere_new / torch.norm(p_sphere_new, dim=-1, keepdim=True)
        pts = torch.cat((p_sphere_new, depth.unsqueeze(-1)), dim=-1)

        return pts
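    # The rotation above is Rodrigues' formula for rotating a vector v about a
    # unit axis k by angle a:
    #   v' = v cos(a) + (k x v) sin(a) + k (k . v) (1 - cos(a))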


class I2SDFLoss(nn.Module):
    def __init__(self, eikonal_weight=0.1, smooth_weight=0.0, mask_weight=0.0, depth_weight=0.1, normal_weight=0.05, angular_weight=0.05, bubble_weight=0.0, min_bubble_iter=0, max_bubble_iter=None, smooth_iter=None, light_mask_weight=0.0, eikonal_weight_bubble=0.0):
        super().__init__()
        self.eikonal_weight = eikonal_weight
        self.rgb_loss = F.l1_loss
        self.smooth_weight = smooth_weight
        self.mask_weight = mask_weight
        self.depth_weight = depth_weight
        self.normal_weight = normal_weight
        self.angular_weight = angular_weight
        self.bubble_weight = bubble_weight
        # self.eikonal_weight_bubble = eikonal_weight_bubble if eikonal_weight_bubble else self.eikonal_weight
        self.min_bubble_iter = min_bubble_iter
        self.max_bubble_iter = max_bubble_iter
        self.smooth_iter = smooth_iter
        if self.bubble_weight > 0 and self.max_bubble_iter is not None and \
                (self.smooth_iter is None or self.smooth_iter < self.max_bubble_iter):
            self.smooth_iter = self.max_bubble_iter  # Disable smoothness loss during bubble steps
        self.light_mask_weight = light_mask_weight

    def get_rgb_loss(self, rgb_values, rgb_gt):
        rgb_gt = rgb_gt.reshape(-1, 3)
        rgb_loss = self.rgb_loss(rgb_values, rgb_gt)
        return rgb_loss

    def get_eikonal_loss(self, grad_theta):
        eikonal_loss = ((grad_theta.norm(2, dim=1) - 1) ** 2).mean()
        return eikonal_loss

    def get_mask_loss(self, mask_pred, mask_gt):
        return F.binary_cross_entropy(mask_pred.clip(1e-3, 1.0 - 1e-3), mask_gt)
    
    def get_depth_loss(self, depth, depth_gt, depth_mask):
        depth_gt = depth_gt.flatten()
        depth_mask = depth_mask.flatten()
        # TODO: Add support for scale invariant depth loss (like MonoSDF)
        return F.mse_loss(depth[depth_mask], depth_gt[depth_mask])
    
    def get_normal_l1_loss(self, normal, normal_gt, normal_mask):
        normal_gt = normal_gt.reshape(-1, 3)
        normal_mask = normal_mask.flatten()
        return torch.abs(1 - torch.sum(normal[normal_mask] * normal_gt[normal_mask], dim=-1)).mean()
    
    def get_normal_angular_loss(self, normal, normal_gt, normal_mask):
        normal_gt = normal_gt.reshape(-1, 3)
        normal_mask = normal_mask.flatten()
        dot = torch.sum(normal[normal_mask] * normal_gt[normal_mask], dim=-1)
        angle = torch.acos(torch.clamp(dot, -1.0+1e-6, 1.0-1e-6)) / math.tau
        return angle.clamp_max(0.5).abs().mean()

    def forward(self, model_outputs, ground_truth, current_step):
        rgb_gt = ground_truth['rgb']

        rgb_loss = self.get_rgb_loss(model_outputs['rgb_values'], rgb_gt)
        if 'grad_theta' in model_outputs:
            eikonal_loss = self.get_eikonal_loss(model_outputs['grad_theta'])
        else:
            eikonal_loss = torch.tensor(0.0, device=model_outputs['rgb_values'].device).float()
        
        smooth_activated = self.smooth_iter is None or current_step > self.smooth_iter
        if smooth_activated and self.smooth_weight > 0 and 'diff_norm' in model_outputs:
            smooth_loss = model_outputs['diff_norm'].mean()
        else:
            smooth_loss = torch.tensor(0.0, device=model_outputs['rgb_values'].device).float()
        
        if 'mask' in ground_truth and self.mask_weight > 0:
            mask_loss = self.get_mask_loss(model_outputs['weight_sum'], ground_truth['mask'])
        else:
            mask_loss = torch.tensor(0.0, device=model_outputs['rgb_values'].device).float()
        
        if 'depth' in ground_truth and self.depth_weight > 0:
            depth_loss = self.get_depth_loss(model_outputs['depth_values'], ground_truth['depth'], ground_truth['depth_mask'])
        else:
            depth_loss = torch.tensor(0.0, device=model_outputs['rgb_values'].device).float()
        
        if 'normal' in ground_truth and self.normal_weight > 0:
            normal_loss = self.get_normal_l1_loss(model_outputs['normal_values'], ground_truth['normal'], ground_truth['normal_mask'])
        else:
            normal_loss = torch.tensor(0.0, device=model_outputs['rgb_values'].device).float()

        if 'normal' in ground_truth and self.angular_weight > 0:
            angular_loss = self.get_normal_angular_loss(model_outputs['normal_values'], ground_truth['normal'], ground_truth['normal_mask'])
        else:
            angular_loss = torch.tensor(0.0, device=model_outputs['rgb_values'].device).float()
        
        if 'surface_sdf' in model_outputs and self.bubble_weight > 0:
            bubble_loss = model_outputs['surface_sdf'].abs().mean()
        else:
            bubble_loss = torch.tensor(0.0, device=model_outputs['rgb_values'].device).float()
        
        if 'light_mask' in model_outputs and self.light_mask_weight > 0:
            light_mask_loss = self.get_mask_loss(model_outputs['light_mask'].reshape(-1, 1), ground_truth['light_mask'].reshape(-1, 1))
        else:
            light_mask_loss = torch.tensor(0.0, device=model_outputs['rgb_values'].device).float()

        loss = rgb_loss + \
               self.eikonal_weight * eikonal_loss + \
               self.smooth_weight * smooth_loss + \
               self.mask_weight * mask_loss + \
               self.depth_weight * depth_loss + \
               self.normal_weight * normal_loss + \
               self.angular_weight * angular_loss + \
               self.bubble_weight * bubble_loss + \
               self.light_mask_weight * light_mask_loss

        output = {
            'loss': loss,
            'rgb_loss': rgb_loss,
            'eikonal_loss': eikonal_loss,
            'smooth_loss': smooth_loss,
            'mask_loss': mask_loss,
            'depth_loss': depth_loss,
            'normal_loss': normal_loss,
            'angular_loss': angular_loss,
            'bubble_loss': bubble_loss,
            'light_mask_loss': light_mask_loss
        }

        return output
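    # Usage sketch (names assumed):
    #   criterion = I2SDFLoss(depth_weight=0.1, normal_weight=0.05)
    #   losses = criterion(model_outputs, ground_truth, current_step=global_step)
    #   losses['loss'].backward()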


================================================
FILE: model/network/density.py
================================================
import torch.nn as nn
import torch


class Density(nn.Module):
    def __init__(self, params_init={}):
        super().__init__()
        for p in params_init:
            param = nn.Parameter(torch.tensor(params_init[p]))
            setattr(self, p, param)

    def forward(self, sdf, beta=None):
        return self.density_func(sdf, beta=beta)


class LaplaceDensity(Density):  # alpha * Laplace(loc=0, scale=beta).cdf(-sdf)
    def __init__(self, params_init={}, beta_min=0.0001):
        super().__init__(params_init=params_init)
        self.beta_min = torch.tensor(beta_min)

    def density_func(self, sdf, beta=None):
        if beta is None:
            beta = self.get_beta()

        alpha = 1 / beta
        return alpha * (0.5 + 0.5 * sdf.sign() * torch.expm1(-sdf.abs() / beta))

    def get_beta(self):
        beta = self.beta.abs() + self.beta_min
        return beta
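    # Sanity check of density_func against the class comment: for s = sdf,
    #   0.5 + 0.5 * sign(s) * (exp(-|s|/beta) - 1)
    # equals 0.5 * exp(-s/beta) for s >= 0 and 1 - 0.5 * exp(s/beta) for s < 0,
    # i.e. exactly Laplace(loc=0, scale=beta).cdf(-s), scaled by alpha = 1/beta.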


class AbsDensity(Density):  # like NeRF++
    def density_func(self, sdf, beta=None):
        return torch.abs(sdf)


class SimpleDensity(Density):  # like NeRF
    def __init__(self, params_init={}, noise_std=1.0):
        super().__init__(params_init=params_init)
        self.noise_std = noise_std

    def density_func(self, sdf, beta=None):
        if self.training and self.noise_std > 0.0:
            noise = torch.randn(sdf.shape).to(sdf.device) * self.noise_std
            sdf = sdf + noise
        return torch.relu(sdf)


================================================
FILE: model/network/embedder.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class Embedder:
    """ Positional encoding embedding. Code was taken from https://github.com/bmild/nerf. """
    def __init__(self, **kwargs):
        self.kwargs = kwargs
        self.create_embedding_fn()

    def create_embedding_fn(self):
        embed_fns = []
        d = self.kwargs['input_dims']
        out_dim = 0
        if self.kwargs['include_input']:
            embed_fns.append(lambda x: x)
            out_dim += d

        max_freq = self.kwargs['max_freq_log2']
        N_freqs = self.kwargs['num_freqs']

        if self.kwargs['log_sampling']:
            freq_bands = 2. ** torch.linspace(0., max_freq, N_freqs)
        else:
            freq_bands = torch.linspace(2.**0., 2.**max_freq, N_freqs)

        for freq in freq_bands:
            for p_fn in self.kwargs['periodic_fns']:
                embed_fns.append(lambda x, p_fn=p_fn,
                                 freq=freq: p_fn(x * freq))
                out_dim += d

        self.embed_fns = embed_fns
        self.out_dim = out_dim

    def embed(self, inputs):
        return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
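    # The resulting encoding is the NeRF positional encoding (log_sampling):
    #   gamma(x) = (x, sin(2^0 x), cos(2^0 x), ..., sin(2^(L-1) x), cos(2^(L-1) x))
    # giving out_dim = d * (1 + 2 * num_freqs) when include_input is set.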


class SHEncoder(nn.Module):
    def __init__(self, input_dims=3, degree=4):
    
        super().__init__()

        self.input_dims = input_dims
        self.degree = degree

        assert self.input_dims == 3
        assert self.degree >= 1 and self.degree <= 5

        self.out_dim = degree ** 2

        self.C0 = 0.28209479177387814
        self.C1 = 0.4886025119029199
        self.C2 = [
            1.0925484305920792,
            -1.0925484305920792,
            0.31539156525252005,
            -1.0925484305920792,
            0.5462742152960396
        ]
        self.C3 = [
            -0.5900435899266435,
            2.890611442640554,
            -0.4570457994644658,
            0.3731763325901154,
            -0.4570457994644658,
            1.445305721320277,
            -0.5900435899266435
        ]
        self.C4 = [
            2.5033429417967046,
            -1.7701307697799304,
            0.9461746957575601,
            -0.6690465435572892,
            0.10578554691520431,
            -0.6690465435572892,
            0.47308734787878004,
            -1.7701307697799304,
            0.6258357354491761
        ]

    def forward(self, input, **kwargs):

        result = torch.empty((*input.shape[:-1], self.out_dim), dtype=input.dtype, device=input.device)
        x, y, z = input.unbind(-1)

        result[..., 0] = self.C0
        if self.degree > 1:
            result[..., 1] = -self.C1 * y
            result[..., 2] = self.C1 * z
            result[..., 3] = -self.C1 * x
            if self.degree > 2:
                xx, yy, zz = x * x, y * y, z * z
                xy, yz, xz = x * y, y * z, x * z
                result[..., 4] = self.C2[0] * xy
                result[..., 5] = self.C2[1] * yz
                result[..., 6] = self.C2[2] * (2.0 * zz - xx - yy)
                #result[..., 6] = self.C2[2] * (3.0 * zz - 1) # xx + yy + zz == 1, but this will lead to different backward gradients, interesting...
                result[..., 7] = self.C2[3] * xz
                result[..., 8] = self.C2[4] * (xx - yy)
                if self.degree > 3:
                    result[..., 9] = self.C3[0] * y * (3 * xx - yy)
                    result[..., 10] = self.C3[1] * xy * z
                    result[..., 11] = self.C3[2] * y * (4 * zz - xx - yy)
                    result[..., 12] = self.C3[3] * z * (2 * zz - 3 * xx - 3 * yy)
                    result[..., 13] = self.C3[4] * x * (4 * zz - xx - yy)
                    result[..., 14] = self.C3[5] * z * (xx - yy)
                    result[..., 15] = self.C3[6] * x * (xx - 3 * yy)
                    if self.degree > 4:
                        result[..., 16] = self.C4[0] * xy * (xx - yy)
                        result[..., 17] = self.C4[1] * yz * (3 * xx - yy)
                        result[..., 18] = self.C4[2] * xy * (7 * zz - 1)
                        result[..., 19] = self.C4[3] * yz * (7 * zz - 3)
                        result[..., 20] = self.C4[4] * (zz * (35 * zz - 30) + 3)
                        result[..., 21] = self.C4[5] * xz * (7 * zz - 3)
                        result[..., 22] = self.C4[6] * (xx - yy) * (7 * zz - 1)
                        result[..., 23] = self.C4[7] * xz * (xx - 3 * yy)
                        result[..., 24] = self.C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy))

        return result
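    # Usage sketch: encode unit view directions with a degree-3 real SH basis
    # (out_dim = degree ** 2 = 9):
    #   enc = SHEncoder(degree=3)(F.normalize(torch.randn(8, 3), dim=-1))  # (8, 9)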


class FourierFeature(nn.Module):
    def __init__(self, channels, sigma=1.0, input_dims=3, include_input=True) -> None:
        super().__init__()
        self.register_buffer('B', torch.randn(input_dims, channels) * sigma, True)
        self.channels = channels
        self.out_dim = 2 * self.channels + input_dims if include_input else 2 * self.channels
        self.include_input = include_input
    
    def forward(self, x):
        xp = torch.matmul(2 * np.pi * x, self.B)
        return torch.cat([x, torch.sin(xp), torch.cos(xp)], dim=-1) if self.include_input else torch.cat([torch.sin(xp), torch.cos(xp)], dim=-1)
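    # This is the random Fourier feature mapping of Tancik et al.:
    #   gamma(x) = (x, sin(2*pi*x*B), cos(2*pi*x*B)),  B ~ N(0, sigma^2)
    # with B drawn once at init and registered as a buffer, so it is saved
    # alongside the model weights.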


def get_embedder(embed_type='positional', **kwargs):
    if embed_type == 'positional':
        input_dims = kwargs['input_dims']
        multires = kwargs['multires']
        embed_kwargs = {
            'include_input': True,
            'input_dims': input_dims,
            'max_freq_log2': multires-1,
            'num_freqs': multires,
            'log_sampling': True,
            'periodic_fns': [torch.sin, torch.cos],
        }
        embedder_obj = Embedder(**embed_kwargs)
        def embed(x, eo=embedder_obj): return eo.embed(x)
        return embed, embedder_obj.out_dim
    elif embed_type == 'spherical_harmonics':
        embedder = SHEncoder(**kwargs)
        return embedder, embedder.out_dim
    elif embed_type == 'fourier':
        embedder = FourierFeature(**kwargs)
        return embedder, embedder.out_dim
    else:
        raise ValueError('Unknown embedding type: {}'.format(embed_type))
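# Usage sketch (assumed parameters): a 6-octave positional encoding of 3D points:
#   embed_fn, out_dim = get_embedder('positional', input_dims=3, multires=6)
#   enc = embed_fn(torch.rand(1024, 3))  # (1024, 39) since 3 * (1 + 2 * 6) = 39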

================================================
FILE: model/network/mlp.py
================================================
import torch.nn as nn
import numpy as np

import utils
from .embedder import *
from .density import LaplaceDensity
from .ray_sampler import ErrorBoundSampler


class ImplicitNetwork(nn.Module):
    def __init__(
            self,
            feature_vector_size,
            sdf_bounding_sphere,
            d_in,
            d_out,
            dims,
            geometric_init=True,
            bias=1.0,
            skip_in=(),
            weight_norm=True,
            embed_type=None,
            sphere_scale=1.0,
            output_activation=None,
            **kwargs
    ):
        super().__init__()

        self.sdf_bounding_sphere = sdf_bounding_sphere
        self.sphere_scale = sphere_scale
        dims = [d_in] + dims + [d_out + feature_vector_size]

        self.embed_fn = None
        if embed_type:
            embed_fn, input_ch = get_embedder(embed_type, input_dims=d_in, **kwargs)
            self.embed_fn = embed_fn
            dims[0] = input_ch
        
        print(f"[INFO] Implicit network dims: {dims}")

        self.num_layers = len(dims)
        self.skip_in = skip_in
        self.weight_norm = weight_norm

        for l in range(0, self.num_layers - 1):
            if l + 1 in self.skip_in:
                out_dim = dims[l + 1] - dims[0]
                if out_dim < 0:
                    print(dims)
            else:
                out_dim = dims[l + 1]

            lin = nn.Linear(dims[l], out_dim)

            if geometric_init:
                if l == self.num_layers - 2:
                    torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)
                    torch.nn.init.constant_(lin.bias, -bias)
                elif (embed_type or getattr(self, 'use_grid', False)) and l == 0:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
                    torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim))
                elif (embed_type or getattr(self, 'use_grid', False)) and l in self.skip_in:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))
                    torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0)
                else:
                    torch.nn.init.constant_(lin.bias, 0.0)
                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))

            if weight_norm:
                lin = nn.utils.weight_norm(lin)

            setattr(self, "lin" + str(l), lin)

        self.activation = nn.Softplus(beta=100)
        self.output_activation = None
        if output_activation is not None:
            self.output_activation = activations[output_activation]

    def get_param_groups(self, lr):
        return [{'params': self.parameters(), 'lr': lr}]

    def forward(self, input):

        if self.embed_fn is not None:
            input = self.embed_fn(input)

        x = input

        for l in range(0, self.num_layers - 1):
            lin = getattr(self, "lin" + str(l))

            if l in self.skip_in:
                x = torch.cat([x, input], 1) / np.sqrt(2)

            x = lin(x)

            if l < self.num_layers - 2:
                x = self.activation(x)

        if self.output_activation is not None:
            x = self.output_activation(x)
        
        return x

    def gradient(self, x):
        x.requires_grad_(True)
        y = self.forward(x)[:,:1]
        d_output = torch.ones_like(y, requires_grad=False, device=y.device)
        gradients = torch.autograd.grad(
            outputs=y,
            inputs=x,
            grad_outputs=d_output,
            create_graph=True,
            retain_graph=True,
            only_inputs=True)[0]
        return gradients

    def feature(self, x):
        return self.forward(x)[:,1:]

    def get_outputs(self, x, returns_grad=True):
        x.requires_grad_(returns_grad)
        output = self.forward(x)
        sdf = output[:,:1]
        ''' Clamping the SDF with the scene bounding sphere, so that all rays are eventually occluded '''
        if self.sdf_bounding_sphere > 0.0:
            sphere_sdf = self.sphere_scale * (self.sdf_bounding_sphere - x.norm(2,1, keepdim=True))
            sdf = torch.minimum(sdf, sphere_sdf)
        feature_vectors = output[:, 1:]
        if returns_grad:
            d_output = torch.ones_like(sdf, requires_grad=False, device=sdf.device)
            gradients = torch.autograd.grad(
                outputs=sdf,
                inputs=x,
                grad_outputs=d_output,
                create_graph=True,
                retain_graph=True,
                only_inputs=True)[0]
            return sdf, feature_vectors, gradients
        else:
            return sdf, feature_vectors, None

    def get_sdf_vals(self, x):
        sdf = self.forward(x)[:,:1]
        ''' Clamping the SDF with the scene bounding sphere, so that all rays are eventually occluded '''
        if self.sdf_bounding_sphere > 0.0:
            sphere_sdf = self.sphere_scale * (self.sdf_bounding_sphere - x.norm(2,1, keepdim=True))
            sdf = torch.minimum(sdf, sphere_sdf)
        return sdf
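# Note: with geometric_init=True, the layer initialization above appears to
# follow the SAL/IDR-style geometric initialization, so the untrained network
# approximates the SDF of a sphere with radius `bias`, which keeps early ray
# sampling and the eikonal term well-behaved.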

activations = {
    'sigmoid': nn.Sigmoid(),
    'relu': nn.ReLU(),
    'softplus': nn.Softplus()
}

class RenderingNetwork(nn.Module):
    def __init__(
            self,
            feature_vector_size,
            mode,
            d_in,
            d_out,
            dims,
            weight_norm=True,
            embed_type=None,
            embed_point=None,
            output_activation='sigmoid',
            **kwargs
    ):
        super().__init__()

        self.mode = mode
        dims = [d_in + feature_vector_size] + dims + [d_out]
        self.d_out = d_out

        self.embedview_fn = None
        if embed_type:
            embedview_fn, input_ch = get_embedder(embed_type, input_dims=3, **kwargs)
            self.embedview_fn = embedview_fn
            dims[0] += (input_ch - 3)
        
        if mode == 'idr':
            self.embedpoint_fn = None
            if embed_point is not None:
                embedpoint_fn, input_ch = get_embedder(input_dims=3, **embed_point)
                self.embedpoint_fn = embedpoint_fn
                dims[0] += (input_ch - 3)

        print(f"[INFO] Rendering network dims: {dims}")
        self.num_layers = len(dims)
        self.weight_norm = weight_norm

        for l in range(0, self.num_layers - 1):
            out_dim = dims[l + 1]
            lin = nn.Linear(dims[l], out_dim)

            if weight_norm:
                lin = nn.utils.weight_norm(lin)

            setattr(self, "lin" + str(l), lin)

        self.activation = nn.ReLU()
        self.output_activation = activations[output_activation]

    def forward(self, points, normals, view_dirs, feature_vectors):
        if self.embedview_fn is not None:
            view_dirs = self.embedview_fn(view_dirs)

        if self.mode == 'idr':
            rendering_input = torch.cat([points, view_dirs, normals, feature_vectors], dim=-1)
        # elif self.mode == 'nerf':
        else:
            rendering_input = torch.cat([view_dirs, feature_vectors], dim=-1)
        x = rendering_input

        for l in range(0, self.num_layers - 1):
            lin = getattr(self, "lin" + str(l))

            x = lin(x)

            if l < self.num_layers - 2:
                x = self.activation(x)

        x = self.output_activation(x)

        return x
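
def _rendering_network_demo():
    """Editor's sketch (not part of the original repository): a smoke test of
    RenderingNetwork in 'idr' mode. d_in=9 covers points/view_dirs/normals
    (3 channels each); feature_vector_size and dims are arbitrary here."""
    net = RenderingNetwork(feature_vector_size=256, mode='idr', d_in=9, d_out=3,
                           dims=[256, 256])
    n = 4
    points = torch.randn(n, 3)
    normals = torch.randn(n, 3)
    normals = normals / normals.norm(dim=-1, keepdim=True)
    view_dirs = torch.randn(n, 3)
    view_dirs = view_dirs / view_dirs.norm(dim=-1, keepdim=True)
    features = torch.randn(n, 256)
    rgb = net(points, normals, view_dirs, features)  # (n, 3), in (0, 1) after sigmoid
    assert rgb.shape == (n, 3)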



================================================
FILE: model/network/ray_sampler.py
================================================
import abc
import torch
from utils import rend_util
import utils

class RaySampler(metaclass=abc.ABCMeta):
    def __init__(self, near, far):
        self.near = near
        self.far = far

    @abc.abstractmethod
    def get_z_vals(self, ray_dirs, cam_loc, model):
        pass

class UniformSampler(RaySampler):
    def __init__(self, scene_bounding_sphere, near, N_samples, take_sphere_intersection=False, far=-1):
        super().__init__(near, 2.0 * scene_bounding_sphere if far == -1 else far)  # default far is 2*R
        self.N_samples = N_samples
        self.scene_bounding_sphere = scene_bounding_sphere
        self.take_sphere_intersection = take_sphere_intersection

    def get_z_vals(self, ray_dirs, cam_loc, model):
        if not self.take_sphere_intersection:
            near, far = self.near * torch.ones(ray_dirs.shape[0], 1, device=ray_dirs.device), self.far * torch.ones(ray_dirs.shape[0], 1, device=ray_dirs.device)
        else:
            sphere_intersections = utils.get_sphere_intersections(cam_loc, ray_dirs, r=self.scene_bounding_sphere)
            near = self.near * torch.ones(ray_dirs.shape[0], 1, device=ray_dirs.device)
            far = sphere_intersections[:,1:]

        t_vals = torch.linspace(0., 1., steps=self.N_samples, device=ray_dirs.device)
        z_vals = near * (1. - t_vals) + far * (t_vals)

        if model.training:
            # get intervals between samples
            mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
            upper = torch.cat([mids, z_vals[..., -1:]], -1)
            lower = torch.cat([z_vals[..., :1], mids], -1)
            # stratified samples in those intervals
            t_rand = torch.rand(z_vals.shape, device=z_vals.device)

            z_vals = lower + (upper - lower) * t_rand

        return z_vals
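
# Editor's note (not part of the original repository): during training, the uniform
# depths above are jittered within their bins ("stratified sampling"); e.g. with
# near=0.5, far=6.0, N_samples=4 the fixed depths [0.50, 2.33, 4.17, 6.00] each
# become one uniformly random depth inside the corresponding bin, which decorrelates
# samples across training iterations at no extra cost.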


class ErrorBoundSampler(RaySampler):
    def __init__(self, scene_bounding_sphere, near, N_samples, N_samples_eval, N_samples_extra,
                 eps, beta_iters, max_total_iters,
                 inverse_sphere_bg=False, N_samples_inverse_sphere=0, add_tiny=0.0):
        super().__init__(near, 2.0 * scene_bounding_sphere)
        self.N_samples = N_samples
        self.N_samples_eval = N_samples_eval
        self.uniform_sampler = UniformSampler(scene_bounding_sphere, near, N_samples_eval, take_sphere_intersection=inverse_sphere_bg)

        self.N_samples_extra = N_samples_extra

        self.eps = eps
        self.beta_iters = beta_iters
        self.max_total_iters = max_total_iters
        self.scene_bounding_sphere = scene_bounding_sphere
        self.add_tiny = add_tiny

        self.inverse_sphere_bg = inverse_sphere_bg
        if inverse_sphere_bg:
            self.inverse_sphere_sampler = UniformSampler(1.0, 0.0, N_samples_inverse_sphere, False, far=1.0)

    def get_z_vals(self, ray_dirs, cam_loc, model):
        beta0 = model.density.get_beta().detach()

        # Start with uniform sampling
        z_vals = self.uniform_sampler.get_z_vals(ray_dirs, cam_loc, model)
        samples, samples_idx = z_vals, None

        # Get maximum beta from the upper bound (Lemma 2)
        dists = z_vals[:, 1:] - z_vals[:, :-1]
        bound = (1.0 / (4.0 * torch.log(torch.tensor(self.eps + 1.0)))) * (dists ** 2.).sum(-1)
        beta = torch.sqrt(bound)
        # beta = torch.sqrt(bound).clone()

        total_iters, not_converge = 0, True

        # Algorithm 1
        while not_converge and total_iters < self.max_total_iters:
            points = cam_loc.unsqueeze(1) + samples.unsqueeze(2) * ray_dirs.unsqueeze(1)
            points_flat = points.reshape(-1, 3)

            # Calculating the SDF only for the new sampled points
            with torch.no_grad():
                samples_sdf = model.implicit_network.get_sdf_vals(points_flat)
            if samples_idx is not None:
                sdf_merge = torch.cat([sdf.reshape(-1, z_vals.shape[1] - samples.shape[1]),
                                       samples_sdf.reshape(-1, samples.shape[1])], -1)
                sdf = torch.gather(sdf_merge, 1, samples_idx).reshape(-1, 1)
            else:
                sdf = samples_sdf


            # Calculating the bound d* (Theorem 1)
            d = sdf.reshape(z_vals.shape)
            dists = z_vals[:, 1:] - z_vals[:, :-1]
            a, b, c = dists, d[:, :-1].abs(), d[:, 1:].abs()
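            # d* lower-bounds the distance from the surface to the segment [z_i, z_{i+1}]:
            # treat a (interval length) and b, c (|sdf| at the endpoints) as triangle
            # sides; an obtuse angle at an endpoint means the closest point is that
            # endpoint (d* = b or c), otherwise d* is the triangle height 2*area/a,
            # with the area computed via Heron's formula below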
            first_cond = a.pow(2) + b.pow(2) <= c.pow(2)
            second_cond = a.pow(2) + c.pow(2) <= b.pow(2)
            # d_star = torch.zeros(z_vals.shape[0], z_vals.shape[1] - 1, device=z_vals.device)
            # d_star[first_cond] = b[first_cond]
            # d_star[second_cond] = c[second_cond]
            s = (a + b + c) / 2.0
            area_before_sqrt = s * (s - a) * (s - b) * (s - c)
            mask = ~first_cond & ~second_cond & (b + c - a > 0)
            # d_star[mask] = (2.0 * torch.sqrt(area_before_sqrt[mask])) / (a[mask])
            # Optimization: multiplication is 5-20 times faster than indexing
            first_cond = first_cond & ~second_cond
            d_star = first_cond * b + second_cond * c + torch.nan_to_num((2.0 * torch.sqrt(area_before_sqrt)) / a) * mask
            d_star = (d[:, 1:].sign() * d[:, :-1].sign() == 1) * d_star  # Fixing the sign


            # Updating beta using line search
            curr_error = self.get_error_bound(beta0, model, sdf, z_vals, dists, d_star)
            # beta[curr_error <= self.eps] = beta0
            # Optimization: multiplication is 5-20 times faster than indexing
            beta0_mask = curr_error <= self.eps
            beta = beta * ~beta0_mask + beta0 * beta0_mask
            beta_min, beta_max = beta0.unsqueeze(0).repeat(z_vals.shape[0]), beta
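            # Bisection (line search): shrink [beta_min, beta_max] towards the
            # smallest beta whose error bound still stays below eps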
            for j in range(self.beta_iters):
                beta_mid = (beta_min + beta_max) / 2.
                curr_error = self.get_error_bound(beta_mid.unsqueeze(-1), model, sdf, z_vals, dists, d_star)
                # beta_max[curr_error <= self.eps] = beta_mid[curr_error <= self.eps]
                # beta_min[curr_error > self.eps] = beta_mid[curr_error > self.eps]
                beta_mid_mask = curr_error <= self.eps
                beta_max = beta_max * ~beta_mid_mask + beta_mid * beta_mid_mask
                beta_min = beta_min * beta_mid_mask + beta_mid * ~beta_mid_mask
            beta = beta_max


            # Upsample more points
            density = model.density(sdf.reshape(z_vals.shape), beta=beta.unsqueeze(-1))

            # dists = torch.cat([dists, torch.tensor([1e10], device=dists.device).unsqueeze(0).repeat(dists.shape[0], 1)], -1)
            dists = torch.cat([dists, torch.full([dists.shape[0], 1], 1e10, device=dists.device)], -1)
            free_energy = dists * density
            shifted_free_energy = torch.cat([torch.zeros(dists.shape[0], 1, device=free_energy.device), free_energy[:, :-1]], dim=-1)
            alpha = 1 - torch.exp(-free_energy)
            transmittance = torch.exp(-torch.cumsum(shifted_free_energy, dim=-1))
            weights = alpha * transmittance  # probability that the ray hits something in this interval

            #  Check if we are done and this is the last sampling
            total_iters += 1
            not_converge = beta.max() > beta0

            if not_converge and total_iters < self.max_total_iters:
                ''' Sample more points proportional to the current error bound'''

                N = self.N_samples_eval

                bins = z_vals
                error_per_section = torch.exp(-d_star / beta.unsqueeze(-1)) * (dists[:,:-1] ** 2.) / (4 * beta.unsqueeze(-1) ** 2)
                error_integral = torch.cumsum(error_per_section, dim=-1)
                bound_opacity = (torch.clamp(torch.exp(error_integral),max=1.e6) - 1.0) * transmittance[:,:-1]

                pdf = bound_opacity + self.add_tiny
                pdf = pdf / torch.sum(pdf, -1, keepdim=True)
                cdf = torch.cumsum(pdf, -1)
                cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)

            else:
                ''' Sample the final sample set to be used in the volume rendering integral '''

                N = self.N_samples

                bins = z_vals
                pdf = weights[..., :-1]
                pdf = pdf + 1e-5  # prevent nans
                pdf = pdf / torch.sum(pdf, -1, keepdim=True)
                cdf = torch.cumsum(pdf, -1)
                cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)  # (batch, len(bins))


            # Invert CDF
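            # Inverse-transform sampling: u in [0,1) (stratified linspace while still
            # upsampling or at eval time, uniform random for the final training pass)
            # is mapped through the piecewise-linear inverse of the CDF over the bins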
            if (not_converge and total_iters < self.max_total_iters) or (not model.training):
                u = torch.linspace(0., 1., steps=N, device=cdf.device).unsqueeze(0).repeat(cdf.shape[0], 1)
            else:
                u = torch.rand(list(cdf.shape[:-1]) + [N], device=cdf.device)
            u = u.contiguous()

            inds = torch.searchsorted(cdf, u, right=True)
            below = torch.max(torch.zeros_like(inds - 1), inds - 1)
            above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
            inds_g = torch.stack([below, above], -1)  # (batch, N_samples, 2)

            matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
            cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
            bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)

            denom = (cdf_g[..., 1] - cdf_g[..., 0])
            denom_mask = denom < 1e-5
            denom = denom_mask + ~denom_mask * denom
            # denom = torch.where(denom < 1e-5, 1.0, denom)
            t = (u - cdf_g[..., 0]) / denom
            samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])


            # Add the new samples if we have not converged
            if not_converge and total_iters < self.max_total_iters:
                z_vals, samples_idx = torch.sort(torch.cat([z_vals, samples], -1), -1)


        z_samples = samples

        near, far = self.near * torch.ones(ray_dirs.shape[0], 1, device=ray_dirs.device), self.far * torch.ones(ray_dirs.shape[0], 1, device=ray_dirs.device)
        if self.inverse_sphere_bg: # if inverse sphere then need to add the far sphere intersection
            far = rend_util.get_sphere_intersections(cam_loc, ray_dirs, r=self.scene_bounding_sphere)[:,1:]

        if self.N_samples_extra > 0:
            if model.training:
                sampling_idx = torch.randperm(z_vals.shape[1])[:self.N_samples_extra]
            else:
                sampling_idx = torch.linspace(0, z_vals.shape[1]-1, self.N_samples_extra).long()
            z_vals_extra = torch.cat([near, far, z_vals[:,sampling_idx]], -1)
        else:
            z_vals_extra = torch.cat([near, far], -1)

        z_vals, _ = torch.sort(torch.cat([z_samples, z_vals_extra], -1), -1)

        # add some of the near surface points
        idx = torch.randint(z_vals.shape[-1], (z_vals.shape[0],), device=z_vals.device)
        z_samples_eik = torch.gather(z_vals, 1, idx.unsqueeze(-1))

        if self.inverse_sphere_bg:
            z_vals_inverse_sphere = self.inverse_sphere_sampler.get_z_vals(ray_dirs, cam_loc, model)
            z_vals_inverse_sphere = z_vals_inverse_sphere * (1./self.scene_bounding_sphere)
            z_vals = (z_vals, z_vals_inverse_sphere)

        return z_vals, z_samples_eik

    def get_error_bound(self, beta, model, sdf, z_vals, dists, d_star):
        density = model.density(sdf.reshape(z_vals.shape), beta=beta)
        shifted_free_energy = torch.cat([torch.zeros(dists.shape[0], 1, device=dists.device), dists * density[:, :-1]], dim=-1)
        integral_estimation = torch.cumsum(shifted_free_energy, dim=-1)
        error_per_section = torch.exp(-d_star / beta) * (dists ** 2.) / (4 * beta ** 2)
        error_integral = torch.cumsum(error_per_section, dim=-1)
        bound_opacity = (torch.clamp(torch.exp(error_integral), max=1.e6) - 1.0) * torch.exp(-integral_estimation[:, :-1])

        return bound_opacity.max(-1)[0]




================================================
FILE: model/rendering/__init__.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import utils
import cv2
from .brdf import *


class RenderingLayer(nn.Module):
    def __init__(self, spp, split_n_pixels, preserve_light=True) -> None:
        super().__init__()
        self.spp = spp
        self.split_n_pixels = split_n_pixels
        self.preserve_light = preserve_light
    
    def forward(
            self,
            model,
            surface_points,
            view_direction,
            Kd,
            Ks,
            normal,
            rough,
            radiance_scale=None,
            intersect_func=None
        ):
        """
        Render according to material, normal and lighting conditions
        Params:
            model: NeRF model to predict radiance
            surface_points, view_direction, albedo, normal, rough, metal: (bn, c)
        """
        bn = normal.size(0)

        cx, cy, cz = create_frame(normal)
        wi_x = torch.sum(cx*view_direction, dim=1)
        wi_y = torch.sum(cy*view_direction, dim=1)
        wi_z = torch.sum(cz*view_direction, dim=1)
        wi = torch.stack([wi_x, wi_y, wi_z], dim=1)
        wi_mask = (wi[:,2] >= 0.00001)
        wi[:,2,...] = torch.where(wi[:,2,...] < 0.00001, torch.ones_like(wi[:,2,...]) * 0.00001, wi[:,2,...])
        wi = F.normalize(wi, dim=1, eps=1e-6)
        # wi_mask = torch.where(wi[:,2:3,...] < 0, torch.zeros_like(wi[:,2:3,...]), torch.ones_like(wi[:,2:3,...]))
        wi = wi.unsqueeze(1) # (bn, 1, 3)

        # with torch.no_grad():
        if True:  # apparently a stand-in so this block can be switched to no_grad without re-indenting
            samples = torch.rand(bn, self.spp, 3, device=normal.device)
            pS = probabilityToSampleSpecular(Kd, Ks)
            clamp_value = 0.0
            pS.clamp_min_(clamp_value)
            sample_diffuse = samples[:,:,0] >= pS

            ls_diffuse = square_to_cosine_hemisphere(samples[:,:,1:])
            ls_specular = sample_ggx_specular(samples[:,:,1:], rough, wi)
            wo = torch.where(sample_diffuse.unsqueeze(2).expand(bn, self.spp, 3), ls_diffuse, ls_specular) # (bn, spp, 3)
        pdfs = pdf_ggx(Kd, Ks, rough, wi, wo, clamp_value).unsqueeze(2)
        eval_diff, eval_spec, wo_mask = eval_ggx(Kd, Ks, rough, wi, wo)
        # wo_mask = torch.all(wo_mask, dim=1)

        direction = to_global(wo, cx.unsqueeze(1), cy.unsqueeze(1), cz.unsqueeze(1))

        # surface_points = surface_points + 0.01 * view_direction # prevent self-intersection
        surface_points = surface_points.unsqueeze(1).expand_as(direction).reshape(-1, 3)
        direction = direction.reshape(-1, 3)
        surface_points = surface_points + direction * 0.01 # prevent self-intersection
        
        pts_splits = torch.split(surface_points, self.split_n_pixels, dim=0)
        dirs_splits = torch.split(direction, self.split_n_pixels, dim=0)
        radiance = []
        # with torch.no_grad():
        for pts, dirs in zip(pts_splits, dirs_splits):
            radiance.append(model.get_incident_radiance(pts, dirs, intersect_func))
        radiance = torch.cat(radiance, dim=0)

        radiance = radiance.view(bn, self.spp, 3)
        if radiance_scale is not None:
            radiance = radiance * radiance_scale[None,None,:]
        pdfs = torch.clamp(pdfs, min=0.00001)
        ndl = torch.clamp(wo[:,:,2:], min=0)

        brdfDiffuse = eval_diff.expand(bn, self.spp, 3) * ndl / pdfs
        colorDiffuse = torch.mean(brdfDiffuse * radiance, dim=1)
        brdfSpec = eval_spec.expand(bn, self.spp, 3) * ndl / pdfs
        colorSpec = torch.mean(brdfSpec * radiance, dim=1)

        return colorDiffuse, colorSpec, wi_mask
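
def _mc_estimator_demo():
    """Editor's sketch (not part of the original repository): the Monte Carlo
    estimator used in RenderingLayer.forward, checked against a closed form.
    For a Lambertian BRDF Kd/pi under constant incident radiance L, the integral
    of fr * cos(theta) over the hemisphere is Kd * L; with cosine-hemisphere
    sampling (pdf = cos(theta)/pi) the estimator mean(fr * cos / pdf * L)
    recovers it exactly."""
    bn, spp = 2, 1024
    Kd = torch.rand(bn, 3)
    L = 2.0
    wo = square_to_cosine_hemisphere(torch.rand(bn, spp, 2))  # (bn, spp, 3)
    ndl = wo[:, :, 2:]
    pdf = ndl / np.pi
    fr = Kd.unsqueeze(1) / np.pi  # Lambertian BRDF
    color = torch.mean(fr * ndl / pdf.clamp_min(1e-5) * L, dim=1)
    assert torch.allclose(color, Kd * L, atol=1e-4)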




================================================
FILE: model/rendering/brdf.py
================================================
import torch
import torch.nn.functional as F
import numpy as np

def create_frame(n: torch.Tensor, eps:float = 1e-6):
    """
    Generate orthonormal coordinate system based on surface normal
    [Duff et al. 17] Building An Orthonormal Basis, Revisited. JCGT. 2017.
    :param: n (bn, 3, ...)
    """
    z = F.normalize(n, dim=1, eps=eps)
    sgn = torch.where(z[:,2,...] >= 0, 1.0, -1.0)
    a = -1.0 / (sgn + z[:,2,...])
    b = z[:,0,...] * z[:,1,...] * a
    x = torch.stack([1.0 + sgn * z[:,0,...] * z[:,0,...] * a, sgn * b, -sgn * z[:,0,...]], dim=1)
    y = torch.stack([b, sgn + z[:,1,...] * z[:,1,...] * a, -z[:,1,...]], dim=1)
    return x, y, z
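
def _frame_demo():
    """Editor's sketch (not part of the original repository): the frame returned
    by create_frame is orthonormal and right-handed (x cross y = z) for any
    nonzero input normal."""
    n = torch.randn(16, 3)
    x, y, z = create_frame(n)
    assert torch.allclose((x * y).sum(dim=1), torch.zeros(16), atol=1e-5)
    assert torch.allclose(torch.cross(x, y, dim=1), z, atol=1e-5)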


def get_rendering_parameters(albedo_raw, rough_raw, use_metallic):
    if use_metallic:
        assert albedo_raw.size(-1) == 3 and rough_raw.size(-1) == 2
        metal = rough_raw[:,1:]
        rough = rough_raw[:,:1].clamp_min(0.01)
        Ks = baseColorToSpecularF0(albedo_raw, metal)
        Kd = albedo_raw * (1 - metal)
    else:
        assert albedo_raw.size(-1) == 6 and rough_raw.size(-1) == 1
        Kd = albedo_raw[:,:3]
        Ks = albedo_raw[:,3:].clamp_min(0.04)
        rough = rough_raw.clamp_min(0.01)
    return Kd, Ks, rough


def to_global(d, x, y, z):
    """
    d, x, y, z: (*, 3)
    """
    return d[...,0:1] * x + d[...,1:2] * y + d[...,2:3] * z

def sqrt_(x: torch.Tensor, eps=1e-8) -> torch.Tensor:
    """
    clamping 0 values of sqrt input to avoid NAN gradients
    """
    return torch.sqrt(torch.clamp(x, min=eps))

def reflect(v: torch.Tensor, h: torch.Tensor):
    dot = torch.sum(v*h, dim=2, keepdim=True)
    return 2 * dot * h - v

def square_to_cosine_hemisphere(sample: torch.Tensor):
    u, v = sample[:,:,0,...], sample[:,:,1,...]
    phi = u * 2 * np.pi
    r = sqrt_(v)
    cos_theta = sqrt_(torch.clamp(1 - v, 0))
    return torch.stack([torch.cos(phi) * r, torch.sin(phi) * r, cos_theta], dim=2)


def get_cos_theta(v: torch.Tensor):
    return v[:,:,2,...]


def get_phi(v: torch.Tensor):
    cos_theta = torch.clamp(v[:,:,2,...], min=0, max=1)
    sin_theta = torch.clamp(sqrt_(1 - cos_theta*cos_theta), min=1e-8)
    cos_phi = torch.clamp(v[:,:,0,...] / sin_theta, -1, 1)
    sin_phi = v[:,:,1,...] / sin_theta
    phi = torch.acos(cos_phi) # (0, pi)
    return torch.where(sin_phi > 0, phi, 2*np.pi - phi)


def sample_disney_specular(sample: torch.Tensor, roughness: torch.Tensor, wi: torch.Tensor):
    """
    :param: sample (bn, spp, 3, h, w)
    :param: roughness (bn, 1, 1, h, w)
    :param: wi (*, *, 3, h, w), supposed to be normalized
    :return: wo (bn, spp, 3, h, w), phi (bn, spp, h, w), cos theta (bn, spp, h, w)
    """
    # a = torch.clamp(roughness, 0.001)
    a = roughness
    u, v = sample[:,:,0,...], sample[:,:,1,...]
    phi = u * 2 * np.pi
    cos_theta = sqrt_((1 - v) / (1 + (a*a - 1) * v))
    sin_theta = sqrt_(1 - cos_theta*cos_theta)
    cos_phi = torch.cos(phi)
    sin_phi = torch.sin(phi)
    half = torch.stack([sin_theta*cos_phi, sin_theta*sin_phi, cos_theta], dim=2)
    wo = F.normalize(reflect(wi.expand_as(half), half), dim=2, eps=1e-8)
    return wo
    #, phi.squeeze(2), cos_theta.squeeze(2)



def GTR2(ndh, a):
    a2 = a*a
    t = 1.0 + (a2 - 1.0) * ndh * ndh
    return a2 / (np.pi * t * t)

def SchlickFresnel(u):
    m = torch.clamp(1.0 - u, 0, 1)
    return m**5

def smithG_GGX(ndv, a):
    a = a*a
    b = ndv*ndv
    return 1.0 / (ndv + sqrt_(a + b - a * b))


def pdf_disney(roughness: torch.Tensor, metallic: torch.Tensor, wi: torch.Tensor, wo: torch.Tensor):
    """
    :param: roughness/metallic (bn, 1, h, w)
    :param: wi (*, *, 3, h, w), supposed to be normalized
    :param: wo (*, *, 3, h, w), supposed to be normalized
    """
    # specularAlpha = torch.clamp(roughness, 0.001)
    specularAlpha = roughness
    diffuseRatio = 0.5 * (1 - metallic)
    specularRatio = 1 - diffuseRatio
    half = F.normalize(wi + wo, dim=2, eps=1e-8)
    cosTheta = torch.abs(half[:,:,2,...])
    pdfGTR2 = GTR2(cosTheta, specularAlpha) * cosTheta
    pdfSpec = pdfGTR2 / torch.clamp(4.0 * torch.abs(torch.sum(wo*half, dim=2)), min=1e-8)
    pdfDiff = torch.abs(wo[:,:,2,...]) / np.pi
    pdf = diffuseRatio * pdfDiff + specularRatio * pdfSpec
    pdf = torch.where(wi[:,:,2,...] < 0.0001, torch.ones_like(pdf) * 0.0001, pdf)
    pdf = torch.where(wo[:,:,2,...] < 0.0001, torch.ones_like(pdf) * 0.0001, pdf)
    return pdf


def eval_disney(albedo: torch.Tensor, roughness: torch.Tensor, metallic: torch.Tensor, wi: torch.Tensor, wo: torch.Tensor):
    """
    :param: albedo/roughness/metallic (bn, c, h, w)
    :param: wi (*, *, 3, h, w), supposed to be normalized
    :param: wo (*, *, 3, h, w), supposed to be normalized
    """
    h = wi + wo
    h = F.normalize(h, dim=2, eps=1e-8)
    
    CSpec0 = torch.lerp(torch.ones_like(albedo)*0.04, albedo, metallic).unsqueeze(1)

    ldh = torch.clamp(torch.sum( (wo * h), dim = 2), 0, 1).unsqueeze(2)
    ndv = wi[:,:,2:3,...]
    ndl = wo[:,:,2:3,...]
    ndh = h[:,:,2:3,...]

    FL, FV = SchlickFresnel(ndl), SchlickFresnel(ndv)
    roughness = roughness.unsqueeze(1)
    Fd90 = 0.5 + 2.0 * ldh * ldh * roughness
    Fd = torch.lerp(torch.ones_like(Fd90), Fd90, FL) * torch.lerp(torch.ones_like(Fd90), Fd90, FV)

    Ds = GTR2(ndh, roughness)
    FH = SchlickFresnel(ldh)
    Fs = torch.lerp(CSpec0, torch.ones_like(CSpec0), FH)
    roughg = (roughness * 0.5 + 0.5) ** 2
    Gs1, Gs2 = smithG_GGX(ndl, roughg), smithG_GGX(ndv, roughg)
    Gs = Gs1 * Gs2

    eval_diff = Fd * albedo.unsqueeze(1) * (1.0 - metallic.unsqueeze(1)) / np.pi
    eval_spec = Gs * Fs * Ds
    mask = torch.where(ndl < 0, torch.zeros_like(ndl), torch.ones(ndl))
    return eval_diff, eval_spec, mask


def F_Schlick(SpecularColor, VoH):
    Fc = (1 - VoH)**5
    return torch.clamp(50.0 * SpecularColor[:,:,1:2,...], min=0, max=1) * Fc + (1 - Fc) * SpecularColor

def GetSpecularEventProbability(SpecularColor, NoV) -> torch.Tensor:
    f = F_Schlick(SpecularColor, NoV)
    return (f[:,:,0,...] + f[:,:,1,...] + f[:,:,2,...]) / 3

def baseColorToSpecularF0(baseColor, metalness):
    return torch.lerp(torch.empty_like(baseColor).fill_(0.04), baseColor, metalness)

def luminance(color):
    if color.size(1) == 1:
        return color
    # return color.mean(dim=1, keepdim=True)
    return color[:,0:1,...] * 0.212671 + color[:,1:2,...] * 0.715160 + color[:,2:3,...] * 0.072169

def probabilityToSampleSpecular(difColor, specColor) -> torch.Tensor:
    lumDiffuse = torch.clamp(luminance(difColor), min=0.01)
    lumSpecular = torch.clamp(luminance(specColor), min=0.01)
    return lumSpecular / (lumDiffuse + lumSpecular)

def shadowedF90(F0):
    t = 1 / 0.04
    return torch.clamp(t * luminance(F0), max=1)

def evalFresnel(f0, f90, NdotS):
    return f0 + (f90 - f0) * (1 - NdotS)**5

def Smith_G1_GGX(alphaSquared, NdotSSquared):
    return 2 / (sqrt_(((alphaSquared * (1 - NdotSSquared)) + NdotSSquared) / NdotSSquared) + 1)

def Smith_G2_GGX(alphaSquared, NdotL, NdotV):
    a = NdotV * sqrt_(alphaSquared + NdotL * (NdotL - alphaSquared * NdotL))
    b = NdotL * sqrt_(alphaSquared + NdotV * (NdotV - alphaSquared * NdotV))
    return 0.5 / (a + b)

def GGX_D(alphaSquared, NdotH):
    b = ((alphaSquared - 1) * NdotH * NdotH + 1)
    return alphaSquared / (np.pi * b * b)

def pdf_ggx(Kd: torch.Tensor, Ks: torch.Tensor, roughness: torch.Tensor, wi: torch.Tensor, wo: torch.Tensor, ps_min=0.0):
    """
    :param: color (bn, 3, h, w)
    :param: roughness/metallic (bn, 1, h, w)
    :param: wi (*, *, 3, h, w), supposed to be normalized
    :param: wo (*, *, 3, h, w), supposed to be normalized
    :return: pdf (*, *, h, w)
    """
    alpha = roughness * roughness
    alphaSquared = alpha * alpha
    NdotV = wi[:,:,2,...]
    h = F.normalize(wi + wo, dim=2, eps=1e-8)
    NdotH = h[:,:,2,...]
    ggxd = GGX_D(torch.clamp(alphaSquared, min=0.00001), NdotH)
    smith = Smith_G1_GGX(alphaSquared, NdotV * NdotV)
    pdf_spec = ggxd * smith / (4 * NdotV)
    with torch.no_grad():
        pS = probabilityToSampleSpecular(Kd, Ks).clamp_min(ps_min)
    pdf_diff = wo[:,:,2,...] / np.pi
    # print("#########################################")
    # print("#########################################")
    # print("#########################################")
    # print(torch.any(torch.isnan(kS)), torch.any(torch.isnan(pdf_spec)), torch.any(torch.isnan(pdf_diff)))
    # print("#########################################")
    # print("#########################################")
    # print("#########################################")
    pdf = pS * pdf_spec + (1 - pS) * pdf_diff
    pdf = torch.where(wi[:,:,2,...] <= 0.0001, torch.ones_like(pdf) * 0.0001, pdf)
    pdf = torch.where(wo[:,:,2,...] <= 0.0001, torch.ones_like(pdf) * 0.0001, pdf)
    return pdf

def eval_ggx(Kd: torch.Tensor, Ks: torch.Tensor, roughness: torch.Tensor, wi: torch.Tensor, wo: torch.Tensor):
    """
    :param: color (bn, c, h, w)
    :param: roughness/metallic (bn, 1, h, w)
    :param: wi (*, *, c, h, w), supposed to be normalized
    :param: wo (*, *, c, h, w), supposed to be normalized
    :return: fr(wi, wo) (*, *, c, h, w)
    """
    NDotL = wo[:,:,2:3,...]
    NDotV = wi[:,:,2:3,...]
    H = F.normalize(wi + wo, dim=2, eps=1e-8)
    NDotH = H[:,:,2:3,...]
    LDotH = torch.sum(wo*H, dim=2, keepdim=True)
    roughness = roughness.unsqueeze(1)
    alpha = roughness * roughness
    alpha2 = alpha * alpha
    D = GGX_D(torch.clamp(alpha2, min=0.00001), NDotH)
    G2 = Smith_G2_GGX(alpha2, NDotL, NDotV)
    f = evalFresnel(Ks.unsqueeze(1), shadowedF90(Ks).unsqueeze(1), LDotH)
    # spec = torch.where(NDotL <= 0, torch.zeros_like(NDotL), f * G2 * D)
    # mask = torch.where(NDotL <= 0, torch.zeros_like(NDotL), torch.ones_like(NDotL))
    spec = torch.where(NDotL < 0.0001, torch.ones_like(NDotL) * 0.0001, f * G2 * D)
    # mask = torch.where(NDotL <= 0, torch.zeros_like(NDotL), torch.ones_like(NDotL))
    mask = (NDotL >= 0.0001).squeeze(-1)
    return Kd.unsqueeze(1) / np.pi, spec, mask


def sample_weight_ggx(alphaSquared, NdotL, NdotV):
    G1V = Smith_G1_GGX(alphaSquared, NdotV*NdotV)
    G1L = Smith_G1_GGX(alphaSquared, NdotL*NdotL)
    return G1L / (G1V + G1L - G1V * G1L)

def sample_ggx(sample: torch.Tensor, Kd: torch.Tensor, Ks: torch.Tensor, roughness: torch.Tensor, wi: torch.Tensor):
    """
    :param: sample (bn, spp, 3, h, w)
    :param: roughness (bn, 1, h, w)
    :param: wi (*, *, 3, h, w), supposed to be normalized
    :return: wo (bn, spp, 3, h, w), weight (bn, spp, 3, h, w)
    """
    with torch.no_grad():
        pS = probabilityToSampleSpecular(Kd, Ks)
    sample_diffuse = sample[:,:,2,...] >= pS

    wo_diff = square_to_cosine_hemisphere(sample[:,:,1:,...])
    weight_diff = Kd / (1 - pS)
    weight_diff = weight_diff.unsqueeze(1)

    roughness = roughness.unsqueeze(1)
    alpha = roughness * roughness
    # alpha = roughness
    Vh = F.normalize(torch.cat([alpha * wi[:,:,0:1,...], alpha * wi[:,:,1:2,...], wi[:,:,2:3,...]], dim=2), dim=2, eps=1e-8)
    lensq = Vh[:,:,0:1,...]**2 + Vh[:,:,1:2,...]**2
    zero_ = torch.zeros_like(Vh[:,:,0,...])
    one_ = torch.ones_like(Vh[:,:,0,...])
    T1 = torch.where(
        lensq > 0, 
        torch.stack([-Vh[:,:,1,...], Vh[:,:,0,...], zero_], dim=2) / sqrt_(lensq),
        torch.stack([one_, zero_, zero_], dim=2)
    )
    T2 = torch.cross(Vh, T1, dim=2)
    r = sqrt_(sample[:,:,0:1,...])
    phi = 2 * np.pi * sample[:,:,1:2,...]
    t1 = r * torch.cos(phi)
    t2 = r * torch.sin(phi)
    s = 0.5 * (1 + Vh[:,:,2:3,...])
    t2 = torch.lerp(sqrt_(1 - t1**2), t2, s)
    Nh = t1 * T1 + t2 * T2 + sqrt_(torch.clamp(1 - t1*t1 - t2*t2, min=0)) * Vh
    h = F.normalize(torch.cat([alpha * Nh[:,:,0:1,...], alpha * Nh[:,:,1:2,...], torch.clamp(Nh[:,:,2:3,...], min=0)], dim=2), dim=2, eps=1e-8)
    wo = reflect(wi, h)

    HdotL = torch.clamp(torch.sum(h*wo, dim=2, keepdim=True), min=0.0001, max=1.0)
    NdotL = torch.clamp(wo[:,:,2:3,...], min=0.0001, max=1.0)
    NdotV = torch.clamp(wi[:,:,2:3,...], min=0.0001, max=1.0)
    # NdotH = torch.clamp(h[:,:,2:3,...], min=0.00001, max=1.0)
    # F = evalFresnel(specularF0, shadowedF90(specularF0), HdotL)
    weight = evalFresnel(Ks, shadowedF90(Ks), HdotL) * sample_weight_ggx(alpha*alpha, NdotL, NdotV) / pS.unsqueeze(1)

    wo = torch.where(sample_diffuse.unsqueeze(2), wo_diff, wo)
    weight = torch.where(sample_diffuse.unsqueeze(2), weight_diff, weight)

    return wo, weight



def sample_ggx_specular(sample: torch.Tensor, roughness: torch.Tensor, wi: torch.Tensor):
    """
    :param: sample (bn, spp, 2, h, w)
    :param: roughness (bn, 1, h, w)
    :param: wi (*, *, 3, h, w), supposed to be normalized
    :return: wo (bn, spp, 3, h, w), phi (bn, spp, h, w), cos theta (bn, spp, h, w)
    """
    roughness = roughness.unsqueeze(1)
    alpha = roughness * roughness
    # alpha = roughness
    Vh = F.normalize(torch.cat([alpha * wi[:,:,0:1,...], alpha * wi[:,:,1:2,...], wi[:,:,2:3,...]], dim=2), dim=2, eps=1e-8)
    # bn, spp, _, row, col = Vh.shape
    # Vh = Vh.view(-1, 3, row, col)
    # T1, T2, Vh = utils.hughes_moeller(Vh)
    # T1 = T1.view(bn, spp, 3, row, col)
    # T2 = T2.view(bn, spp, 3, row, col)
    # Vh = Vh.view(bn, spp, 3, row, col)
    lensq = Vh[:,:,0:1,...]**2 + Vh[:,:,1:2,...]**2
    zero_ = torch.zeros_like(Vh[:,:,0,...])
    one_ = torch.ones_like(Vh[:,:,0,...])
    T1 = torch.where(
        lensq > 0, 
        torch.stack([-Vh[:,:,1,...], Vh[:,:,0,...], zero_], dim=2) / sqrt_(lensq),
        torch.stack([one_, zero_, zero_], dim=2)
    )
    T2 = torch.cross(Vh, T1, dim=2)
    r = sqrt_(sample[:,:,0:1,...])
    phi = 2 * np.pi * sample[:,:,1:2,...]
    t1 = r * torch.cos(phi)
    t2 = r * torch.sin(phi)
    s = 0.5 * (1 + Vh[:,:,2:3,...])
    t2 = torch.lerp(sqrt_(1 - t1**2), t2, s)
    Nh = t1 * T1 + t2 * T2 + sqrt_(torch.clamp(1 - t1*t1 - t2*t2, min=0)) * Vh
    h = F.normalize(torch.cat([alpha * Nh[:,:,0:1,...], alpha * Nh[:,:,1:2,...], torch.clamp(Nh[:,:,2:3,...], min=0)], dim=2), dim=2, eps=1e-8)
    wo = reflect(wi, h)
    return wo

================================================
FILE: model/trainer/__init__.py
================================================
from .recon import ReconstructionTrainer

================================================
FILE: model/trainer/recon.py
================================================
import math
import torch
import pytorch_lightning as pl
import numpy as np
import torch.optim as optim
import os
from torch.utils.data import DataLoader
import utils
from utils import rend_util
import utils.plots as plt
import dataset
import model
from tqdm import trange
from pytorch_lightning.callbacks import RichProgressBar
from torchmetrics.functional import structural_similarity_index_measure as ssim
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity as LPIPS
os.environ["OPENCV_IO_ENABLE_OPENEXR"]="1"
import cv2


lpips = LPIPS()

class ReconstructionTrainer(pl.LightningModule):
    def __init__(self, conf, prog_bar: RichProgressBar, exp_dir, model_only=False, val_mesh=False, is_val=False, **kwargs) -> None:
        super().__init__()
        self.conf = conf
        self.prog_bar = prog_bar
        self.batch_size = conf.train.batch_size
        self.bubble_batch_size = getattr(conf.train, 'bubble_batch_size', self.batch_size)
        self.expdir = exp_dir
        self.val_mesh = val_mesh

        conf_model = conf.model
        use_normal = (getattr(conf.loss, 'normal_weight', 0) > 0) or (getattr(conf.loss, 'angular_weight', 0) > 0)
        conf_model.use_normal = use_normal
        self.model = model.I2SDFNetwork(conf_model)
        
        if model_only:
            return

        print('[INFO] Loading data ...')
        dataset_conf = conf.dataset
        self.scan_id = dataset_conf.scan_id
        self.train_dataset = dataset.ReconDataset(
                **dataset_conf, 
                use_mask=getattr(conf.loss, 'mask_weight', 0) > 0, 
                use_depth=getattr(conf.loss, 'depth_weight', 0) > 0, 
                use_normal=use_normal, 
                use_bubble=getattr(conf.loss, 'bubble_weight', 0) > 0, 
                use_lightmask=getattr(conf.loss, 'light_mask_weight', 0) > 0
            )
        if self.train_dataset.use_bubble:
            os.makedirs(os.path.join(self.expdir, 'hotmap'), exist_ok=True)
            os.makedirs(os.path.join(self.expdir, 'countmap'), exist_ok=True)
            self.pdf_criterion = getattr(conf.train, 'pdf_criterion', 'DEPTH')
            assert self.pdf_criterion in ['RGB', 'DEPTH']

        self.is_hdr = self.train_dataset.is_hdr
        self.plots_dir = os.path.join(self.expdir, 'plots')
        self.trace_bub_idx = self.conf.train.get('trace_bub_idx', -1)
        if self.trace_bub_idx != -1:
            os.makedirs(f"{self.plots_dir}/bubble", exist_ok=True)
            print(f"[INFO] Activate hotmap visualization for #{self.trace_bub_idx}")
            self.plot_dataset = dataset.PlotDataset(**dataset_conf, indices=[self.trace_bub_idx], plot_nimgs=1, is_val=is_val)
        else:
            data = {
                'intrinsics': self.train_dataset.intrinsics_all,
                'pose': self.train_dataset.pose_all,
                'rgb': self.train_dataset.rgb_images,
                'img_res': self.train_dataset.img_res
            }
            if self.train_dataset.use_lightmask:
                data['light_mask'] = self.train_dataset.lightmask_images
            self.plot_dataset = dataset.PlotDataset(**dataset_conf, data=data, plot_nimgs=conf.plot.plot_nimgs, is_val=is_val)

        os.makedirs(self.plots_dir, exist_ok=True)
        with open(f"{self.expdir}/config.yml", 'w') as f:
            f.write(self.conf.dump())
        if self.train_dataset.use_bubble:
            points = self.train_dataset.pointcloud
            index = torch.randperm(points.size(0))[:200000]
            points = points[index,:]
            plt.visualize_pointcloud(points, f"{self.expdir}/pointcloud.html")
            print(f"[INFO] Pointcloud visualization success: saved to {self.expdir}/pointcloud.html")
            self.pdf_prune = self.train_dataset.pdf_prune
            self.pdf_max = self.train_dataset.pdf_max

        # self.ds_len = len(self.train_dataset)
        self.ds_len = self.train_dataset.n_images
        print('[INFO] Finished loading data. Dataset size: {0}'.format(self.ds_len))
        epoch_steps = len(self.train_dataset) / self.batch_size
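        # schedule length targets roughly 200k optimizer steps in total,
        # independent of dataset size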
        self.nepochs = int(math.ceil(200000 / epoch_steps))

        self.loss = model.I2SDFLoss(**conf.loss)
        self.total_pixels = self.plot_dataset.total_pixels
        self.img_res = self.plot_dataset.img_res
        self.bubble_activated = False
        self.uniform_bubble = getattr(self.conf.train, 'uniform_bubble', False)
        if self.uniform_bubble:
            print("[INFO] Ablation study: uniform sampling for bubble loss")
        self.checkpoint_freq = self.conf.train.checkpoint_freq
        self.split_n_pixels = self.conf.train.split_n_pixels
        self.plot_conf = self.conf.plot
        self.progbar_task = None
        if self.train_dataset.use_lightmask and getattr(self.conf.train, 'flip_light', False):
            self.train_dataset.lightmask_images = 1.0 - self.train_dataset.lightmask_images
            self.plot_dataset.lightmask_images = 1.0 - self.plot_dataset.lightmask_images

    def forward(self):
        raise NotImplementedError("forward not supported by trainer")

    def plot_hotmap(self, path):
        assert self.bubble_activated
        ds = self.train_dataset
        hotmaps = torch.zeros(self.ds_len * ds.total_pixels)
        hotmaps[ds.pixlinks] = self.pdf.cpu()
        hotmaps = hotmaps.reshape(self.ds_len, *ds.img_res)
        for i, hotmap in enumerate(hotmaps):
            hotmap = hotmap.numpy()
            # hotmap /= max(1e-4, hotmap.max())
            hotmap = (hotmap * 255).astype(np.uint8)
            hotmap = cv2.applyColorMap(hotmap, cv2.COLORMAP_MAGMA)
            cv2.imwrite(os.path.join(path, "{:04d}.png".format(i)), hotmap)
            if self.trace_bub_idx == i:
                cv2.imwrite(os.path.join(f"{self.plots_dir}/bubble", f"{self.global_step}_hot.png"), hotmap)
    
    def plot_countmap(self, path):
        assert self.bubble_activated
        ds = self.train_dataset
        countmaps = torch.zeros(self.ds_len * ds.total_pixels)
        countmaps[ds.pixlinks] = self.sample_count.cpu().float()
        countmaps = countmaps.reshape(self.ds_len, *ds.img_res)
        countmaps = countmaps / max(1, countmaps.max())
        for i, countmap in enumerate(countmaps):
            countmap = countmap.numpy()
            countmap = (countmap * 255).astype(np.uint8)
            countmap = cv2.applyColorMap(countmap, cv2.COLORMAP_MAGMA)
            cv2.imwrite(os.path.join(path, "{:04d}.png".format(i)), countmap)
            if self.trace_bub_idx == i:
                cv2.imwrite(os.path.join(f"{self.plots_dir}/bubble", f"{self.global_step}_cnt.png"), countmap)

    def update_pdf(self, value, idx):
        assert self.bubble_activated
        ds = self.train_dataset
        value = value.to(self.pdf.device)
        if self.pdf_max is not None:
            value = value.clamp(max=self.pdf_max)
        value[value < self.pdf_prune] = 0 # PDF pruning
        link = ds.pointlinks[idx]
        mask = (link != -1)
        link = link[mask]
        value = value[mask]
        self.pdf[link] = value
    
    def sample_bubble(self, batch_size):
        assert self.bubble_activated
        ds = self.train_dataset
        if self.uniform_bubble:
            sample_idx = torch.randperm(ds.pointcloud.size(0), device=ds.pointcloud.device)[:batch_size]
            return ds.pointcloud[sample_idx,:]
        sample_idx = torch.where(self.pdf > 0)[0]
        pdf_samples = self.pdf[sample_idx]
        pointcloud_samples = ds.pointcloud[sample_idx,:]
        if sample_idx.size(0) >= (1 << 24):
            # print(sample_idx.size(0), self.pdf.size(0), (1 << 24))
            print("[ERROR] PDF capacity exceeds maximum limit of PyTorch")
            exit(1)
        idx = torch.multinomial(pdf_samples, batch_size, replacement=False) # importance sampling
        self.sample_count[sample_idx[idx]] += 1
        return pointcloud_samples[idx,:]

    def initialize_bubble_pdf(self, split_size):
        ds = self.train_dataset
        # ds.pdf = ds.pdf.cuda()
        self.register_buffer('pdf', torch.zeros(len(ds.pointcloud)), False)
        self.register_buffer('sample_count', torch.zeros(len(ds.pointcloud)), False)
        self.pdf = self.pdf.cuda()
        # self.sample_count = self.sample_count.cuda()
        for i in trange(ds.n_images):
            intrinsics = ds.intrinsics_all[i].cuda().unsqueeze(0)
            pose = ds.pose_all[i].cuda().unsqueeze(0)
            img = ds.rgb_images[i].cuda() if self.pdf_criterion != 'DEPTH' else ds.depth_images[i].cuda()
            uv = ds.uv.cuda().unsqueeze(1) # (h*w, 1, 2)
            img_splits = torch.split(img, split_size)
            uv_splits = torch.split(uv, split_size)
            indices = torch.arange(i * ds.total_pixels, (i + 1) * ds.total_pixels, dtype=torch.long, device='cuda')
            index_splits = torch.split(indices, split_size)
            for img_split, uv_split, index_split in zip(img_splits, uv_splits, index_splits):
                data = {
                    'uv': uv_split,
                    'intrinsics': intrinsics.repeat(len(uv_split), 1, 1),
                    'pose': pose.repeat(len(uv_split), 1, 1)
                }
                model_output = self.model.forward(data, True)
                if self.pdf_criterion == 'RGB':
                    self.update_pdf((model_output['rgb_values'].detach().clamp(0, 1) - img_split.clamp(0, 1)).abs().mean(dim=-1), index_split)
                # elif self.pdf_criterion == 'DEPTH':
                else:
                    self.update_pdf((model_output['depth_values'].detach() - img_split).abs(), index_split)

    def configure_optimizers(self):
        lr = self.conf.train.learning_rate
        optimizer = optim.Adam(self.model.get_param_groups(lr), eps=1e-15)
        decay_rate = getattr(self.conf.train, 'sched_decay_rate', 0.1)
        decay_steps = self.nepochs * self.ds_len
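        # per-step multiplier chosen so the learning rate decays by `decay_rate` in
        # total: (decay_rate ** (1/decay_steps)) ** decay_steps == decay_rate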
        scheduler = optim.lr_scheduler.ExponentialLR(optimizer, decay_rate ** (1./decay_steps))
        return [optimizer], [scheduler]

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, collate_fn=self.train_dataset.collate_fn, num_workers=4)
        
    def val_dataloader(self):
        return DataLoader(self.plot_dataset, batch_size=self.conf.plot.plot_nimgs, shuffle=False, collate_fn=self.train_dataset.collate_fn)

    def log_if_nonzero(self, name, value, *args, **kwargs):
        if value > 0:
            self.log(name, value, *args, **kwargs)

    def training_step(self, batch, batch_idx):
        indices, img_indices, model_input, ground_truth = batch
        if not self.bubble_activated and self.train_dataset.use_bubble and self.global_step >= self.loss.min_bubble_iter and self.global_step < self.loss.max_bubble_iter:
            # Start bubble step
            with torch.no_grad():
                self.bubble_activated = True
                self.train_dataset.pointcloud = self.train_dataset.pointcloud.cuda()
                # self.loss.eikonal_weight = self.loss.eikonal_weight_pointcloud

                # Disable normal loss, since it will discourage the growth of bubbles
                self.loss.normal_weight_bak = self.loss.normal_weight
                self.loss.normal_weight = 0.0
                self.loss.angular_weight_bak = self.loss.angular_weight
                self.loss.angular_weight = 0.0
                if not self.uniform_bubble:
                    print(f"[INFO] Start to initializing pointcloud PDF, criterion: {self.pdf_criterion}")
                    self.initialize_bubble_pdf(self.split_n_pixels) # initialize PDF maps for each image by computing losses
                    torch.save(self.pdf, os.path.join(self.expdir, 'checkpoints', "pdf.pt"))
                    torch.cuda.empty_cache()
                    self.plot_hotmap(os.path.join(self.expdir, 'hotmap'))
                    print("[INFO] Finish to initializing pointcloud PDF")
                    print(f"[INFO] {torch.count_nonzero(self.pdf).item()}/{self.pdf.size(0)} points to be sampled")

        if self.bubble_activated:
            model_input['pointcloud'] = self.sample_bubble(self.bubble_batch_size)

        model_outputs = self.model(model_input)
        if self.bubble_activated and not self.uniform_bubble:
            with torch.no_grad():
                if self.pdf_criterion == 'RGB':
                    self.update_pdf((model_outputs['rgb_values'].detach().clamp(0, 1) - ground_truth['rgb'].clamp(0, 1)).abs().mean(dim=-1), indices)
                # elif self.pdf_criterion == 'DEPTH':
                else:
                    self.update_pdf((model_outputs['depth_values'].detach() - ground_truth['depth']).abs(), indices)

        loss_output = self.loss(model_outputs, ground_truth, self.global_step)
        if self.bubble_activated and self.loss.max_bubble_iter is not None and self.global_step >= self.loss.max_bubble_iter:
            # End bubble step
            self.train_dataset.use_bubble = False
            self.bubble_activated = False
            del self.train_dataset.pointcloud
            del self.train_dataset.pointlinks
            del self.train_dataset.pixlinks
            if not self.uniform_bubble:
                delattr(self, 'pdf')
                delattr(self, 'sample_count')
            torch.cuda.empty_cache()
            # self.loss.eikonal_weight = self.conf.loss.eikonal_weight
            # Restore normal loss
            self.loss.normal_weight = self.loss.normal_weight_bak
            self.loss.angular_weight = self.loss.angular_weight_bak
        loss = loss_output['loss']

        with torch.no_grad():
            psnr = rend_util.get_psnr(model_outputs['rgb_values'].detach(), ground_truth['rgb'].view(-1, 3))
            self.log('train/loss', loss.item())
            self.log('train/psnr', psnr.item(), True)
            self.log('train/rgb_loss', loss_output['rgb_loss'].item())
            self.log_if_nonzero('train/eikonal_loss', loss_output['eikonal_loss'].item())
            self.log_if_nonzero('train/smooth_loss', loss_output['smooth_loss'].item())
            self.log_if_nonzero('train/mask_loss', loss_output['mask_loss'].item())
            self.log_if_nonzero('train/depth_loss', loss_output['depth_loss'].item())
            self.log_if_nonzero('train/normal_loss', loss_output['normal_loss'].item())
            self.log_if_nonzero('train/angular_loss', loss_output['angular_loss'].item())
            self.log_if_nonzero('train/bubble_loss', loss_output['bubble_loss'].item())
            self.log_if_nonzero('train/light_mask_loss', loss_output['light_mask_loss'].item())
            self.log('train/beta', self.model.density.beta.item())

        return loss


    def validation_step(self, batch, batch_idx):

        indices, model_input, ground_truth = batch

        split = utils.split_input(model_input, self.total_pixels, self.split_n_pixels)
        res = []
        if self.progbar_task is None and self.prog_bar.progress:
            self.progbar_task = self.prog_bar.progress.add_task("[cyan]Validation split", total=len(split))
        elif self.progbar_task:
            self.prog_bar.progress.reset(self.progbar_task, total=len(split), visible=True)

        for s in split:
            out = utils.detach_dict(self.model(s))
            d = {
                'rgb_values': out['rgb_values'].detach(),
                'depth_values': out['depth_values'].detach()
            }
            if 'normal_map' in out:
                d['normal_map'] = out['normal_map'].detach()
            if 'light_mask' in out:
                d['light_mask'] = out['light_mask'].detach()
            del out
            res.append(d)
            if self.progbar_task:
                self.prog_bar.progress.update(self.progbar_task, advance=1, refresh=True)
        if self.progbar_task:
            self.prog_bar.progress.update(self.progbar_task, visible=False)
        batch_size = ground_truth['rgb'].shape[0]
        model_outputs = utils.merge_output(res, self.total_pixels, batch_size)

        def get_plot_data(model_outputs, pose, ground_truth):
            rgb_gt = ground_truth['rgb']
            batch_size, num_samples, _ = rgb_gt.shape
            rgb_eval = model_outputs['rgb_values'].reshape(batch_size, num_samples, 3)
            if self.is_hdr:
                eval_hdr = rgb_eval
                gt_hdr = rgb_gt
                rgb_eval = rend_util.linear_to_srgb(rgb_eval.clamp(0, 1))
                rgb_gt = rend_util.linear_to_srgb(rgb_gt.clamp(0, 1))
            depth_eval = model_outputs['depth_values'].reshape(batch_size, num_samples, 1)
            plot_data = {
                'rgb_gt': rgb_gt,
                'pose': pose,
                'rgb_eval': rgb_eval,
                'depth_eval': depth_eval
            }
            if self.is_hdr:
                plot_data['hdr_gt'] = gt_hdr
                plot_data['hdr_eval'] = eval_hdr
            if 'normal_map' in model_outputs:
                normal_map = model_outputs['normal_map'].reshape(batch_size, num_samples, 3)
                normal_map = normal_map.transpose(1, 2) # (bn, 3, h*w)
                R = pose[:,:3,:3].transpose(1, 2)
                normal_map = torch.bmm(R, normal_map) # world to camera
                normal_map = normal_map.transpose(1, 2)
                normal_map = (normal_map + 1.) / 2.
                plot_data['normal_map'] = normal_map
            if 'light_mask' in model_outputs:
                plot_data['lmask_eval'] = model_outputs['light_mask'].reshape(batch_size, num_samples, 1)
                plot_data['lmask_gt'] = ground_truth['light_mask'].reshape(batch_size, num_samples, 1)
            return plot_data

        plot_data = get_plot_data(model_outputs, model_input['pose'], ground_truth)
        return {
            'indices': indices,
            'plot_data': plot_data
        }

    def validation_epoch_end(self, outputs) -> None:
        self.plot_dataset.shuffle_plot_index()
        indices = torch.cat([x['indices'] for x in outputs], dim=0)
        plot_data = utils.merge_dict([x['plot_data'] for x in outputs])

        rgb_eval = plot_data['rgb_eval']
        rgb_gt = plot_data['rgb_gt']
        psnr = rend_util.get_psnr(rgb_eval, rgb_gt)
        self.log('val/psnr', psnr.item())
        rgb_gt = rgb_gt.transpose(1, 2).view(-1, 3, *self.img_res) # (bn, h*w, 3) => (bn, 3, h, w)
        rgb_eval = rgb_eval.transpose(1, 2).view(-1, 3, *self.img_res)
        self.log('val/ssim', ssim(rgb_eval, rgb_gt).item())
        lpips.to(rgb_eval.device)
        self.log('val/lpips', lpips(rgb_eval.clamp(0, 1) * 2 - 1, rgb_gt.clamp(0, 1) * 2 - 1).item())

        os.makedirs(self.plots_dir, exist_ok=True)
        os.makedirs('{0}/rendering'.format(self.plots_dir), exist_ok=True)
        if self.is_hdr:
            os.makedirs('{0}/hdr'.format(self.plots_dir), exist_ok=True)
        os.makedirs('{0}/depth'.format(self.plots_dir), exist_ok=True)
        if 'normal_map' in plot_data:
            os.makedirs('{0}/normal'.format(self.plots_dir), exist_ok=True)
        if 'lmask_eval' in plot_data:
            os.makedirs('{0}/light_mask'.format(self.plots_dir), exist_ok=True)
        if self.val_mesh:
            os.makedirs('{0}/mesh'.format(self.plots_dir), exist_ok=True)
        if self.bubble_activated and not self.uniform_bubble:
            self.plot_hotmap(os.path.join(self.expdir, 'hotmap'))
            self.plot_countmap(os.path.join(self.expdir, 'countmap'))
        plt.plot(self.model.implicit_network,
                indices,
                plot_data,
                self.plots_dir,
                self.global_step,
                self.img_res,
                meshing=self.val_mesh,
                **self.plot_conf
                )




================================================
FILE: utils/__init__.py
================================================
from .cfgnode import CfgNode
from .rend_util import *
import torch
import torch.nn.functional as F
import torch.nn as nn
from glob import glob
import os
import numpy as np
from pytorch_lightning.callbacks import RichProgressBar
from rich.progress import TextColumn

class RichProgressBarWithScanId(RichProgressBar):
    def __init__(self, scan_id, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.custom_column = TextColumn(f"[progress.description]scan_id: {scan_id}")
    
    def configure_columns(self, trainer):
        return super().configure_columns(trainer) + [self.custom_column]


def glob_imgs(path):
    imgs = []
    for ext in ['*.png', '*.jpg', '*.JPEG', '*.JPG', '*.exr']:
        imgs.extend(glob(os.path.join(path, ext)))
    return imgs

def glob_depths(path):
    imgs = []
    for ext in ['*.exr']:
        imgs.extend(glob(os.path.join(path, ext)))
    return imgs

glob_normal = glob_depths

def split_input(model_input, total_pixels, n_pixels=10000):
    '''
    Split the input into chunks so that high-resolution inputs fit in CUDA memory.
    Decrease n_pixels in case of a CUDA out-of-memory error.
    '''
    split = []
    for i, indx in enumerate(torch.split(torch.arange(total_pixels, device=model_input['uv'].device), n_pixels, dim=0)):
        data = model_input.copy()
        data['uv'] = torch.index_select(model_input['uv'], 1, indx)
        if 'object_mask' in data:
            data['object_mask'] = torch.index_select(model_input['object_mask'], 1, indx)
        split.append(data)
    return split


def split_dict(d, batch_size=10000):
    keys = d.keys()
    splits = {}
    for k in d:
        splits[k] = torch.split(d[k], batch_size)
        n_splits = len(splits[k])
    split_inputs = []
    for i in range(n_splits):
        split = {}
        for k in d:
            split[k] = splits[k][i]
        split_inputs.append(split)
    return split_inputs



def detach_dict(d):
    return {k: v.detach() for k, v in d.items() if torch.is_tensor(v)}


def merge_output(res, total_pixels, batch_size):
    ''' Merge the split output. '''

    model_outputs = {}
    for entry in res[0]:
        if res[0][entry] is None:
            continue
        if len(res[0][entry].shape) == 1:
            model_outputs[entry] = torch.cat([r[entry].reshape(batch_size, -1, 1) for r in res],
                                             1).reshape(batch_size * total_pixels)
        else:
            model_outputs[entry] = torch.cat([r[entry].reshape(batch_size, -1, r[entry].shape[-1]) for r in res],
                                             1).reshape(batch_size * total_pixels, -1)

    return model_outputs


def merge_dict(dicts):
    output = {}
    for entry in dicts[0]:
        output[entry] = torch.cat([r[entry] for r in dicts], dim=0)
    return output

from torch.autograd import Function
from torch.cuda.amp import custom_bwd, custom_fwd 

class _trunc_exp(Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32) # cast to float32
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.exp(x)

    @staticmethod
    @custom_bwd
    def backward(ctx, g):
        x = ctx.saved_tensors[0]
        return g * torch.exp(x.clamp(-15, 15))

trunc_exp = _trunc_exp.apply
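
# Editor's note (not part of the original repository): trunc_exp matches torch.exp
# in the forward pass but clamps the input to [-15, 15] when computing the backward
# pass, so a single large activation cannot produce an inf/NaN gradient, e.g.
#   x = torch.tensor([30.0], requires_grad=True)
#   trunc_exp(x).backward()  # x.grad == exp(15.), not exp(30.)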

def kmeans_pp_centroid(points: torch.Tensor, k):
    # k-means++-style seeding: each new centroid is sampled with probability
    # proportional to its distance to the nearest centroid chosen so far.
    n, c = points.shape
    centroids = torch.zeros(k, c, device=points.device)
    centroids[0, :] = points[np.random.randint(0, n), :].clone()
    for i in range(1, k):
        # distance from every point to its nearest existing centroid
        d = (points.unsqueeze(1) - centroids[:i, :].unsqueeze(0)).norm(p=2, dim=-1).min(dim=1).values
        # inverse-CDF sampling over the distance weights
        threshold = d.sum() * np.random.random()
        cumsum = torch.cumsum(d, dim=0)
        j = ((cumsum - threshold) > 0).int().argmax()
        centroids[i, :] = points[j, :].clone()
    return centroids
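
# Usage sketch (editor's addition): seed Lloyd iterations with the k-means++
# centroids above. This clustering loop is illustrative, not part of the repo.
def _example_kmeans(points: torch.Tensor, k: int, iters: int = 10):
    centroids = kmeans_pp_centroid(points, k)
    for _ in range(iters):
        # assign each point to its nearest centroid, then recompute the means
        labels = (points.unsqueeze(1) - centroids.unsqueeze(0)).norm(dim=-1).argmin(dim=1)
        for c in range(k):
            mask = labels == c
            if mask.any():
                centroids[c] = points[mask].mean(dim=0)
    return labels, centroids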

================================================
FILE: utils/cfgnode.py
================================================
"""
Define a class to hold configurations.
Borrows and merges stuff from YACS, fvcore, and detectron2
https://github.com/rbgirshick/yacs
https://github.com/facebookresearch/fvcore/
https://github.com/facebookresearch/detectron2/
"""

import copy
import importlib.util
import io
import logging
import os
from ast import literal_eval
from typing import Optional

import yaml

# File exts for yaml
_YAML_EXTS = {"", ".yml", ".yaml"}
# File exts for python
_PY_EXTS = {".py"}

# CfgNodes can only contain a limited set of valid types
_VALID_TYPES = {tuple, list, str, int, float, bool}

# Valid file object types
_FILE_TYPES = (io.IOBase,)

# Logger
logger = logging.getLogger(__name__)


class CfgNode(dict):
    r"""CfgNode is a `node` in the configuration `tree`. It's a simple wrapper around a `dict` and supports access to
    `attributes` via `keys`.
    """

    IMMUTABLE = "__immutable__"
    DEPRECATED_KEYS = "__deprecated_keys__"
    RENAMED_KEYS = "__renamed_keys__"
    NEW_ALLOWED = "__new_allowed__"

    def __init__(
        self,
        init_dict: Optional[dict] = None,
        key_list: Optional[list] = None,
        new_allowed: Optional[bool] = False,
    ):
        r"""
        Args:
            init_dict (dict): A dictionary to initialize the `CfgNode`.
            key_list (list[str]): A list of names that index this `CfgNode` from the root. Currently, only used for
                logging.
            new_allowed (bool): Whether adding a new key is allowed when merging with other `CfgNode` objects.
        """

        # Recursively convert nested dictionaries in `init_dict` to config tree.
        init_dict = {} if init_dict is None else init_dict
        key_list = [] if key_list is None else key_list
        init_dict = self._create_config_tree_from_dict(init_dict, key_list)
        super(CfgNode, self).__init__(init_dict)

        # Control the immutability of the `CfgNode`.
        self.__dict__[CfgNode.IMMUTABLE] = False
        # Support for deprecated options.
        # If you choose to remove support for an option in code, but don't want to change all of the config files
        # (to allow for deprecated config files to run), you can add the full config key as a string to this set.
        self.__dict__[CfgNode.DEPRECATED_KEYS] = set()
        # Support for renamed options.
        # If you rename an option, record the mapping from the old name to the new name in this dictionary. Optionally,
        # if the type also changed, you can make this value a tuple that specifies two things: the renamed key, and the
        # instructions to edit the config file.
        self.__dict__[CfgNode.RENAMED_KEYS] = {
            # 'EXAMPLE.OLD.KEY': 'EXAMPLE.NEW.KEY',  # Dummy example
            # 'EXAMPLE.OLD.KEY': (                   # A more complex example
            #     'EXAMPLE.NEW.KEY',
            #     "Also convert to a tuple, eg. 'foo' -> ('foo', ) or "
            #     + "'foo.bar' -> ('foo', 'bar')"
            # ),
        }

        # Allow new attributes after initialization.
        self.__dict__[CfgNode.NEW_ALLOWED] = new_allowed

    @classmethod
    def _create_config_tree_from_dict(cls, init_dict: dict, key_list: list):
        r"""Create a configuration tree using the input dict. Any dict-like objects inside `init_dict` will be treated
        as new `CfgNode` objects.
        Args:
            init_dict (dict): Input dictionary, to create config tree from.
            key_list (list): A list of names that index this `CfgNode` from the root. Currently only used for logging.
        """

        d = copy.deepcopy(init_dict)
        for k, v in d.items():
            if isinstance(v, dict):
                # Convert dictionary to CfgNode
                d[k] = cls(v, key_list=key_list + [k])
            else:
                # Check for valid leaf type or nested CfgNode
                _assert_with_logging(
                    _valid_type(v, allow_cfg_node=False),
                    "Key {} with value {} is not a valid type; valid types: {}".format(
                        ".".join(key_list + [k]), type(v), _VALID_TYPES
                    ),
                )
        return d

    def __getattr__(self, name: str):
        if name in self:
            return self[name]
        else:
            raise AttributeError(name)

    def __setattr__(self, name: str, value):
        if self.is_frozen():
            raise AttributeError(
                "Attempted to set {} to {}, but CfgNode is immutable".format(
                    name, value
                )
            )

        _assert_with_logging(
            name not in self.__dict__,
            "Invalid attempt to modify internal CfgNode state: {}".format(name),
        )

        _assert_with_logging(
            _valid_type(value, allow_cfg_node=True),
            "Invalid type {} for key {}; valid types = {}".format(
                type(value), name, _VALID_TYPES
            ),
        )

        self[name] = value

    def __str__(self):
        def _indent(s_, num_spaces):
            s = s_.split("\n")
            if len(s) == 1:
                return s_
            first = s.pop(0)
            s = [(num_spaces * " ") + line for line in s]
            s = "\n".join(s)
            s = first + "\n" + s
            return s

        r = ""
        s = []
        for k, v in sorted(self.items()):
            separator = "\n" if isinstance(v, CfgNode) else " "
            attr_str = "{}:{}{}".format(str(k), separator, str(v))
            attr_str = _indent(attr_str, 2)
            s.append(attr_str)
        r += "\n".join(s)
        return r

    def __repr__(self):
        return "{}({})".format(self.__class__.__name__, super(CfgNode, self).__repr__())

    def dump(self, **kwargs):
        r"""Dump CfgNode to a string.
        """

        def _convert_to_dict(cfg_node, key_list):
            if not isinstance(cfg_node, CfgNode):
                _assert_with_logging(
                    _valid_type(cfg_node),
                    "Key {} with value {} is not a valid type; valid types: {}".format(
                        ".".join(key_list), type(cfg_node), _VALID_TYPES
                    ),
                )
                return cfg_node
            else:
                cfg_dict = dict(cfg_node)
                for k, v in cfg_dict.items():
                    cfg_dict[k] = _convert_to_dict(v, key_list + [k])
                return cfg_dict

        self_as_dict = _convert_to_dict(self, [])
        return yaml.safe_dump(self_as_dict, **kwargs)

    def merge_from_file(self, cfg_filename: str):
        r"""Load a yaml config file and merge it with this CfgNode.
        Args:
            cfg_filename (str): Config file path.
        """
        with open(cfg_filename, "r") as f:
            cfg = self.load_cfg(f)
        self.merge_from_other_cfg(cfg)

    def merge_from_other_cfg(self, cfg_other):
        r"""Merge `cfg_other` into the current `CfgNode`.
        Args:
            cfg_other
        """
        _merge_a_into_b(cfg_other, self, self, [])

    def merge_from_list(self, cfg_list: list):
        r"""Merge config (keys, values) in a list (eg. from commandline) into this `CfgNode`.
        Eg. `cfg_list = ['FOO.BAR', 0.5]`.
        """
        _assert_with_logging(
            len(cfg_list) % 2 == 0,
            "Override list has odd lengths: {}; it must be a list of pairs".format(
                cfg_list
            ),
        )
        root = self
        for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]):
            if root.key_is_deprecated(full_key):
                continue
            if root.key_is_renamed(full_key):
                root.raise_key_rename_error(full_key)
            key_list = full_key.split(".")
            d = self
            for subkey in key_list[:-1]:
                _assert_with_logging(
                    subkey in d, "Non-existent key: {}".format(full_key)
                )
                d = d[subkey]
            subkey = key_list[-1]
            _assert_with_logging(subkey in d, "Non-existent key: {}".format(full_key))
            value = self._decode_cfg_value(v)
            value = _check_and_coerce_cfg_value_type(value, d[subkey], subkey, full_key)
            d[subkey] = value

    def freeze(self):
        r"""Make this `CfgNode` and all of its children immutable. """
        self._immutable(True)

    def defrost(self):
        r"""Make this `CfgNode` and all of its children mutable. """
        self._immutable(False)

    def is_frozen(self):
        r"""Return mutability. """
        return self.__dict__[CfgNode.IMMUTABLE]

    def _immutable(self, is_immutable: bool):
        r"""Set mutability and recursively apply to all nested `CfgNode` objects.
        Args:
            is_immutable (bool): Whether or not the `CfgNode` and its children are immutable.
        """
        self.__dict__[CfgNode.IMMUTABLE] = is_immutable
        # Recursively propagate state to all children.
        for v in self.__dict__.values():
            if isinstance(v, CfgNode):
                v._immutable(is_immutable)
        for v in self.values():
            if isinstance(v, CfgNode):
                v._immutable(is_immutable)

    def clone(self):
        r"""Recursively copy this `CfgNode`. """
        return copy.deepcopy(self)

    def register_deprecated_key(self, key: str):
        r"""Register key (eg. `FOO.BAR`) a deprecated option. When merging deprecated keys, a warning is generated and
        the key is ignored.
        """

        _assert_with_logging(
            key not in self.__dict__[CfgNode.DEPRECATED_KEYS],
            "key {} is already registered as a deprecated key".format(key),
        )
        self.__dict__[CfgNode.DEPRECATED_KEYS].add(key)

    def register_renamed_key(
        self, old_name: str, new_name: str, message: Optional[str] = None
    ):
        r"""Register a key as having been renamed from `old_name` to `new_name`. When merging a renamed key, an
        exception is thrown alerting the user to the fact that the key has been renamed.
        """

        _assert_with_logging(
            old_name not in self.__dict__[CfgNode.RENAMED_KEYS],
            "key {} is already registered as a renamed cfg key".format(old_name),
        )
        value = new_name
        if message:
            value = (new_name, message)
        self.__dict__[CfgNode.RENAMED_KEYS][old_name] = value

    def key_is_deprecated(self, full_key: str):
        r"""Test if a key is deprecated. """
        if full_key in self.__dict__[CfgNode.DEPRECATED_KEYS]:
            logger.warning("deprecated config key (ignoring): {}".format(full_key))
            return True
        return False

    def key_is_renamed(self, full_key: str):
        r"""Test if a key is renamed. """
        return full_key in self.__dict__[CfgNode.RENAMED_KEYS]

    def raise_key_rename_error(self, full_key: str):
        new_key = self.__dict__[CfgNode.RENAMED_KEYS][full_key]
        if isinstance(new_key, tuple):
            msg = " Note: " + new_key[1]
            new_key = new_key[0]
        else:
            msg = ""
        raise KeyError(
            "Key {} was renamed to {}; please update your config.{}".format(
                full_key, new_key, msg
            )
        )

    def is_new_allowed(self):
        return self.__dict__[CfgNode.NEW_ALLOWED]

    @classmethod
    def load_cfg(cls, cfg_file_obj_or_str):
        r"""Load a configuration into the `CfgNode`.
        Args:
            cfg_file_obj_or_str (str or cfg compatible object): Supports loading from:
                - A file object backed by a YAML file.
                - A file object backed by a Python source file that exports an attribute "cfg" (dict or `CfgNode`).
                - A string that can be parsed as valid YAML.
        """
        _assert_with_logging(
            isinstance(cfg_file_obj_or_str, _FILE_TYPES + (str,)),
            "Expected first argument to be of type {} or {}, but got {}".format(
                _FILE_TYPES, str, type(cfg_file_obj_or_str)
            ),
        )
        if isinstance(cfg_file_obj_or_str, str):
            return cls._load_cfg_from_yaml_str(cfg_file_obj_or_str)
        elif isinstance(cfg_file_obj_or_str, _FILE_TYPES):
            return cls._load_cfg_from_file(cfg_file_obj_or_str)
        else:
            raise NotImplementedError("Impossible to reach here (unless there's a bug)")

    @classmethod
    def _load_cfg_from_file(cls, file_obj):
        r"""Load a config from a YAML file or a Python source file. """
        _, file_ext = os.path.splitext(file_obj.name)
        if file_ext in _YAML_EXTS:
            return cls._load_cfg_from_yaml_str(file_obj.read())
        elif file_ext in _PY_EXTS:
            return cls._load_cfg_py_source(file_obj.name)
        else:
            raise Exception(
                "Attempt to load from an unsupported filetype {}; only {} supported".format(
                    file_ext, _YAML_EXTS.union(_PY_EXTS)
                )
            )

    @classmethod
    def _load_cfg_from_yaml_str(cls, str_obj):
        r"""Load a config from a YAML string encoding. """
        cfg_as_dict = yaml.safe_load(str_obj)
        return cls(cfg_as_dict)

    @classmethod
    def _load_cfg_py_source(cls, filename):
        r"""Load a config from a Python source file. """
        module = _load_module_from_file("yacs.config.override", filename)
        _assert_with_logging(
            hasattr(module, "cfg"),
            "Python module from file {} must export a 'cfg' attribute".format(filename),
        )
        VALID_ATTR_TYPES = {dict, CfgNode}
        _assert_with_logging(
            type(module.cfg) in VALID_ATTR_TYPES,
            "Import module 'cfg' attribute must be in {} but is {}".format(
                VALID_ATTR_TYPES, type(module.cfg)
            ),
        )
        return cls(module.cfg)

    @classmethod
    def _decode_cfg_value(cls, value):
        r"""Decodes a raw config value (eg. from a yaml config file or commandline argument) into a Python object.
        If `value` is a dict, it will be interpreted as a new `CfgNode`.
        If `value` is a str, it will be evaluated as a literal.
        Otherwise, it is returned as is.
        """
        # Configs parsed from raw yaml will contain dictionary keys that need to be converted to `CfgNode` objects.
        if isinstance(value, dict):
            return cls(value)
        # All remaining processing is only applied to strings.
        if not isinstance(value, str):
            return value
        # Try to interpret `value` as a: string, number, tuple, list, dict, bool, or None
        try:
            value = literal_eval(value)
        # The following two excepts allow `value` to pass through when it represents a string.
        # The type of `value` is always a string (before calling `literal_eval`), but sometimes it *represents* a
        # string and other times a data structure, like a list. In the case that `value` represents a str, what we
        # got back from the yaml parser is `foo` *without quotes* (so, not `"foo"`). `literal_eval` is ok with `"foo"`,
        # but will raise a `ValueError` if given `foo`. In other cases, like paths (`val = 'foo/bar'`) `literal_eval`
        # will raise a `SyntaxError`.
        except ValueError:
            pass
        except SyntaxError:
            pass
        return value


# Keep this function in global scope, for backward compatibility.
load_cfg = CfgNode.load_cfg
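

# Usage sketch (editor's addition): a typical CfgNode round trip. The YAML
# string and keys here are made up for illustration.
def _example_cfgnode():
    cfg = CfgNode.load_cfg("train:\n  lr: 0.001\n  steps: 1000\n")
    cfg.merge_from_list(["train.lr", "0.01"])  # CLI-style override
    cfg.freeze()                               # reject further mutation
    assert cfg.train.lr == 0.01
    return cfg.dump()                          # serialize back to YAML text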


def _valid_type(value, allow_cfg_node: Optional[bool] = False):
    return (type(value) in _VALID_TYPES) or (
        allow_cfg_node and isinstance(value, CfgNode)
    )


def _merge_a_into_b(a: CfgNode, b: CfgNode, root: CfgNode, key_list: list):
    r"""Merge `CfgNode` `a` into `CfgNode` `b`, clobbering the options in `b` wherever they are also specified in `a`.
    """
    _assert_with_logging(
        isinstance(a, CfgNode),
        "`a` (cur type {}) must be an instance of {}".format(type(a), CfgNode),
    )
    _assert_with_logging(
        isinstance(b, CfgNode),
        "`b` (cur type {}) must be an instance of {}".format(type(b), CfgNode),
    )

    for k, v_ in a.items():
        full_key = ".".join(key_list + [k])
        v = copy.deepcopy(v_)
        v = b._decode_cfg_value(v)

        if k in b:
            v = _check_and_coerce_cfg_value_type(v, b[k], k, full_key)
            # Recursively merge dicts.
            if isinstance(v, CfgNode):
                _merge_a_into_b(v, b[k], root, key_list + [k])
            else:
                b[k] = v
        elif b.is_new_allowed():
            b[k] = v
        else:
            if root.key_is_deprecated(full_key):
                continue
            elif root.key_is_renamed(full_key):
                root.raise_key_rename_error(full_key)
            else:
                raise KeyError("Non-existent config key: {}".format(full_key))


def _check_and_coerce_cfg_value_type(replacement, original, key, full_key):
    r"""Checks that `replacement`, which is intended to replace `original` is of the right type. The type is correct if
    it matches exactly or is one of a few cases in which the type can easily be coerced.
    """

    original_type = type(original)
    replacement_type = type(replacement)
    if replacement_type == original_type:
        return replacement

    # Cast `replacement` from `from_type` to `to_type` when the (replacement, original) types match that pair.
    def _conditional_cast(from_type, to_type):
        if replacement_type == from_type and original_type == to_type:
            return True, to_type(replacement)
        else:
            return False, None

    # Conditional casts.
    # list <-> tuple
    casts = [(tuple, list), (list, tuple)]
    for (from_type, to_type) in casts:
        converted, converted_value = _conditional_cast(from_type, to_type)
        if converted:
            return converted_value

    raise ValueError(
        "Type mismatch ({} vs. {} with values ({} vs. {}) for config key: {}".format(
            original_type, replacement_type, original, replacement, full_key
        )
    )


def _assert_with_logging(cond, msg):
    if not cond:
        logger.debug(msg)
    assert cond, msg


def _load_module_from_file(name, filename):
    spec = importlib.util.spec_from_file_location(name, filename)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module

================================================
FILE: utils/mesh_util.py
================================================
# adapted from https://github.com/zju3dv/manhattan_sdf
import numpy as np
import open3d as o3d
from sklearn.neighbors import KDTree
import trimesh
import os
os.environ['PYOPENGL_PLATFORM'] = 'egl'
import pyrender
from tqdm.contrib import tenumerate, tzip


def nn_correspondance(verts1, verts2):
    """For each vertex in verts2, return the distance to its nearest neighbor in verts1."""
    if len(verts1) == 0 or len(verts2) == 0:
        return np.array([])

    kdtree = KDTree(verts1)
    distances, _ = kdtree.query(verts2)
    return distances.reshape(-1)


def evaluate(mesh_pred, mesh_trgt, threshold=.05, down_sample=.02):
    pcd_trgt = o3d.geometry.PointCloud()
    pcd_pred = o3d.geometry.PointCloud()
    
    pcd_trgt.points = o3d.utility.Vector3dVector(mesh_trgt.vertices[:, :3])
    pcd_pred.points = o3d.utility.Vector3dVector(mesh_pred.vertices[:, :3])

    if down_sample:
        pcd_pred = pcd_pred.voxel_down_sample(down_sample)
        pcd_trgt = pcd_trgt.voxel_down_sample(down_sample)

    verts_pred = np.asarray(pcd_pred.points)
    verts_trgt = np.asarray(pcd_trgt.points)

    dist1 = nn_correspondance(verts_pred, verts_trgt)
    dist2 = nn_correspondance(verts_trgt, verts_pred)

    precision = np.mean((dist2 < threshold).astype('float'))
    recall = np.mean((dist1 < threshold).astype('float'))
    fscore = 2 * precision * recall / (precision + recall)
    metrics = {
        'Acc': np.mean(dist2),
        'Comp': np.mean(dist1),
        'Prec': precision,
        'Recall': recall,
        'F-score': fscore,
    }
    return metrics
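
# Usage sketch (editor's addition): compare a reconstructed mesh against a
# ground-truth mesh. The file paths are placeholders.
def _example_mesh_eval():
    mesh_pred = trimesh.load("pred_mesh.ply")
    mesh_gt = trimesh.load("gt_mesh.ply")
    # returns {'Acc', 'Comp', 'Prec', 'Recall', 'F-score'}
    return evaluate(mesh_pred, mesh_gt, threshold=0.05, down_sample=0.02)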


class Renderer():
    def __init__(self, height=480, width=640):
        self.renderer = pyrender.OffscreenRenderer(width, height)
        self.scene = pyrender.Scene()
        # self.render_flags = pyrender.RenderFlags.SKIP_CULL_FACES

    def __call__(self, height, width, intrinsics, pose, mesh):
        self.renderer.viewport_height = height
        self.renderer.viewport_width = width
        self.scene.clear()
        self.scene.add(mesh)
        cam = pyrender.IntrinsicsCamera(cx=intrinsics[0, 2], cy=intrinsics[1, 2],
                                        fx=intrinsics[0, 0], fy=intrinsics[1, 1])
        self.scene.add(cam, pose=self.fix_pose(pose))
        return self.renderer.render(self.scene)  # , self.render_flags)

    def fix_pose(self, pose):
        # 3D Rotation about the x-axis.
        t = np.pi
        c = np.cos(t)
        s = np.sin(t)
        R = np.array([[1, 0, 0],
                      [0, c, -s],
                      [0, s, c]])
        axis_transform = np.eye(4)
        axis_transform[:3, :3] = R
        return pose @ axis_transform

    def mesh_opengl(self, mesh):
        return pyrender.Mesh.from_trimesh(mesh)

    def delete(self):
        self.renderer.delete()
        

def refuse(mesh, poses, K, H, W, far_clip=5.0):
    renderer = Renderer()
    mesh_opengl = renderer.mesh_opengl(mesh)
    volume = o3d.pipelines.integration.ScalableTSDFVolume(
        voxel_length=0.01,
        sdf_trunc=3 * 0.01,
        color_type=o3d.pipelines.integration.TSDFVolumeColorType.RGB8
    )
    
    for i, pose in tenumerate(poses):
        intrinsic = K
        
        rgb = np.ones((H, W, 3))
        rgb = (rgb * 255).astype(np.uint8)
        rgb = o3d.geometry.Image(rgb)
        _, depth_pred = renderer(H, W, intrinsic, pose, mesh_opengl)
        depth_pred = o3d.geometry.Image(depth_pred)
        rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth(
            rgb, depth_pred, depth_scale=1.0, depth_trunc=far_clip, convert_rgb_to_intensity=False
        )
        fx, fy, cx, cy = intrinsic[0, 0], intrinsic[1, 1], intrinsic[0, 2], intrinsic[1, 2]
        intrinsic = o3d.camera.PinholeCameraIntrinsic(width=W, height=H, fx=fx,  fy=fy, cx=cx, cy=cy)
        extrinsic = np.linalg.inv(pose)
        volume.integrate(rgbd, intrinsic, extrinsic)
    
    return volume.extract_triangle_mesh()

def depth2mesh(depths, poses, K, H, W):
    volume = o3d.pipelines.integration.ScalableTSDFVolume(
        voxel_length=0.01,
        sdf_trunc=3 * 0.01,
        color_type=o3d.pipelines.integration.TSDFVolumeColorType.RGB8
    )
    for depth, pose in tzip(depths, poses):
        rgb = np.ones((H, W, 3))
        rgb = (rgb * 255).astype(np.uint8)
        rgb = o3d.geometry.Image(rgb)
        depth = o3d.geometry.Image(depth)
        rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth(
            rgb, depth, depth_scale=1.0, depth_trunc=5.0, convert_rgb_to_intensity=False
        )
        fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2]
        intrinsic = o3d.camera.PinholeCameraIntrinsic(width=W, height=H, fx=fx,  fy=fy, cx=cx, cy=cy)
        extrinsic = np.linalg.inv(pose)
        volume.integrate(rgbd, intrinsic, extrinsic)
    return volume.extract_triangle_mesh()
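
# Usage sketch (editor's addition): fuse per-view depth maps into a mesh via
# TSDF integration. Shapes and the data source are assumptions: depths is a
# list of (H, W) float32 arrays, poses a list of 4x4 camera-to-world matrices,
# K a 3x3 intrinsics matrix (OpenCV convention, as used throughout this repo).
def _example_depth_fusion(depths, poses, K, H=480, W=640):
    mesh = depth2mesh(depths, poses, K, H, W)
    o3d.io.write_triangle_mesh("fused.ply", mesh)
    return mesh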

================================================
FILE: utils/plots.py
================================================
import plotly.graph_objs as go
import plotly.offline as offline
from plotly.subplots import make_subplots
import numpy as np
import torch
from skimage import measure
import torchvision.utils as vutils
import trimesh
from PIL import Image
import cv2
import mcubes
from utils import rend_util


def plot(implicit_network, indices, plot_data, path, epoch, img_res, plot_nimgs, resolution=None, grid_boundary=None, meshing=True, level=0):

    if plot_data is not None:
        if 'rgb_eval' in plot_data:
            plot_images(plot_data['rgb_eval'], plot_data['rgb_gt'], path, epoch, plot_nimgs, img_res)
        if 'hdr_eval' in plot_data:
            plot_images(plot_data['hdr_eval'], plot_data['hdr_gt'], path, epoch, plot_nimgs, img_res, 'hdr', True)
        if 'rgb_surface' in plot_data:
            plot_images(plot_data['rgb_surface'], plot_data['rgb_gt'], path, '{0}s'.format(epoch), plot_nimgs, img_res)
        if 'rendered' in plot_data:
            plot_images(plot_data['rendered'], plot_data['rgb_gt'], path, epoch, plot_nimgs, img_res, 'rendered')

        if 'normal_map' in plot_data:
            plot_imgs_wo_gt(plot_data['normal_map'], path, epoch, plot_nimgs, img_res)
        if 'depth_eval' in plot_data:
            plot_depths(plot_data['depth_eval'], path, epoch, plot_nimgs, img_res)
        if 'lmask_eval' in plot_data:
            plot_images(plot_data['lmask_eval'], plot_data['lmask_gt'], path, epoch, plot_nimgs, img_res, 'light_mask')

        if 'Kd' in plot_data:
            # plot_imgs_wo_gt(plot_data['albedo'], path, epoch, plot_nimgs, img_res, 'albedo')
            plot_images(plot_data['Ks'], plot_data['Kd'], path, epoch, plot_nimgs, img_res, 'albedo')
        if 'roughness' in plot_data:
            plot_colormap(plot_data['roughness'], path, epoch, plot_nimgs, img_res)
        if 'metallic' in plot_data:
            plot_colormap(plot_data['metallic'], path, epoch, plot_nimgs, img_res, 'metallic')
        if 'emission' in plot_data:
            plot_imgs_wo_gt(plot_data['emission'], path, '{0}'.format(epoch), plot_nimgs, img_res, 'emission', True)

    if meshing:
        path = f"{path}/mesh"
        data = []

        # plot surface
        surface_traces = get_surface_trace(path=path,
                                        epoch=epoch,
                                        sdf=lambda x: implicit_network(x)[:, 0],
                                        resolution=resolution,
                                        grid_boundary=grid_boundary,
                                        level=level
                                        )

        if surface_traces is not None:
            data.append(surface_traces[0])

        # plot camera locations
        if plot_data is not None:
            cam_loc, cam_dir = rend_util.get_camera_for_plot(plot_data['pose'])
            for i, loc, dir in zip(indices, cam_loc, cam_dir):
                data.append(get_3D_quiver_trace(loc.unsqueeze(0), dir.unsqueeze(0), name='camera_{0}'.format(i)))

        fig = go.Figure(data=data)
        scene_dict = dict(xaxis=dict(range=[-6, 6], autorange=False),
                        yaxis=dict(range=[-6, 6], autorange=False),
                        zaxis=dict(range=[-6, 6], autorange=False),
                        aspectratio=dict(x=1, y=1, z=1))
        fig.update_layout(scene=scene_dict, width=1200, height=1200, showlegend=True)
        filename = '{0}/surface_{1}.html'.format(path, epoch)
        offline.plot(fig, filename=filename, auto_open=False)


def visualize_pointcloud(points, filename):
    fig = go.Figure(
        data = [
            get_3D_scatter_trace(points, 'Pointcloud', 1)
        ]
    )
    scene_dict = dict(xaxis=dict(range=[-3, 3], autorange=False),
                      yaxis=dict(range=[-3, 3], autorange=False),
                      zaxis=dict(range=[-3, 3], autorange=False),
                      aspectratio=dict(x=1, y=1, z=1))
    fig.update_layout(scene=scene_dict, width=1200, height=1200, showlegend=True)
    offline.plot(fig, filename=filename, auto_open=False)
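
# Usage sketch (editor's addition): dump a synthetic point cloud to a
# standalone HTML file for inspection in a browser.
def _example_visualize():
    points = torch.rand(1000, 3) * 2 - 1  # 1000 random points in [-1, 1]^3
    visualize_pointcloud(points, "pointcloud.html")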


def visualize_clustered_pointcloud(points, labels, centroids, filename):
    fig = go.Figure()
    if centroids is not None:
        fig.add_trace(get_3D_scatter_trace(centroids, "Centroids", 10))
    for c in torch.unique(labels):
        cluster = points[labels == c, :]
        fig.add_trace(get_3D_scatter_trace(cluster, f"Emitter #{int(c)}"))
    scene_dict = dict(xaxis=dict(range=[-3, 3], autorange=False),
                      yaxis=dict(range=[-3, 3], autorange=False),
                      zaxis=dict(range=[-3, 3], autorange=False),
                      aspectratio=dict(x=1, y=1, z=1))
    fig.update_layout(scene=scene_dict, width=1200, height=1200, showlegend=True)
    offline.plot(fig, filename=filename, auto_open=False)


def visualize_marked_pointcloud(points, counts, path, epoch):
    fig = go.Figure(
        data = [
            get_3D_marked_scatter_trace(points, counts, 'Pointcloud samples')
        ]
    )
    scene_dict = dict(xaxis=dict(range=[-3, 3], autorange=False),
                      yaxis=dict(range=[-3, 3], autorange=False),
                      zaxis=dict(range=[-3, 3], autorange=False),
                      aspectratio=dict(x=1, y=1, z=1))
    fig.update_layout(scene=scene_dict, width=1200, height=1200, showlegend=True)
    filename = '{0}/pointcloud/{1}.html'.format(path, epoch)
    offline.plot(fig, filename=filename, auto_open=False)


def get_3D_scatter_trace(points, name='', size=3, caption=None):
    assert points.shape[1] == 3, "3d scatter plot input points are not correctly shaped"
    assert len(points.shape) == 2, "3d scatter plot input points are not correctly shaped"

    trace = go.Scatter3d(
        x=points[:, 0].cpu(),
        y=points[:, 1].cpu(),
        z=points[:, 2].cpu(),
        mode='markers',
        name=name,
        marker=dict(
            size=size,
            line=dict(
                width=2,
            ),
            opacity=1.0,
        ), text=caption)

    return trace


def get_3D_marked_scatter_trace(points, marks, name='', size=1, caption=None):
    assert points.shape[1] == 3, "3d scatter plot input points are not correctly shaped"
    assert len(points.shape) == 2, "3d scatter plot input points are not correctly shaped"

    trace = go.Scatter3d(
        x=points[:, 0].cpu(),
        y=points[:, 1].cpu(),
        z=points[:, 2].cpu(),
        mode='markers',
        name=name,
        marker=dict(
            size=size,
            # ... (preview truncated here; the remainder of utils/plots.py is
            # available only in the full .txt download)
SYMBOL INDEX (262 symbols across 18 files)

FILE: data/normalize_cameras.py
  function get_center_point (line 6) | def get_center_point(num_cams,cameras):
  function normalize_cameras (line 31) | def normalize_cameras(original_cameras_filename,output_cameras_filename,...

FILE: data/npz_to_blender.py
  function to16b (line 14) | def to16b(img):
  function opencv_to_gl (line 19) | def opencv_to_gl(pose):
  function get_offset (line 25) | def get_offset(poses):
  function scale_pose (line 37) | def scale_pose(pose, scale, offset):
  function load_K_Rt_from_P (line 43) | def load_K_Rt_from_P(filename, P=None):
  function main (line 67) | def main():

FILE: dataset/eval_dataset.py
  class GridDataset (line 15) | class GridDataset(Dataset):
    method __init__ (line 19) | def __init__(self, points, xyz) -> None:
    method __len__ (line 24) | def __len__(self):
    method __getitem__ (line 27) | def __getitem__(self, index):
  class PlotDataset (line 31) | class PlotDataset(torch.utils.data.Dataset):
    method __init__ (line 32) | def __init__(self,
    method shuffle_plot_index (line 137) | def shuffle_plot_index(self):
    method __len__ (line 141) | def __len__(self):
    method get_uv (line 144) | def get_uv(self):
    method __getitem__ (line 150) | def __getitem__(self, idx):
    method collate_fn (line 170) | def collate_fn(self, batch_list):
  class InterpolateDataset (line 188) | class InterpolateDataset(torch.utils.data.Dataset):
    method __init__ (line 192) | def __init__(self,
    method __len__ (line 244) | def __len__(self):
    method __getitem__ (line 247) | def __getitem__(self, idx):
    method collate_fn (line 258) | def collate_fn(self, batch_list):
  class RelightDataset (line 276) | class RelightDataset(PlotDataset):
    method __init__ (line 277) | def __init__(self, data_dir, edit_cfg, scan_id=0, is_val=False, **kwar...
    method loadattr (line 302) | def loadattr(self, edit_cfg, attr, mode=0):
    method __len__ (line 316) | def __len__(self):
    method __getitem__ (line 319) | def __getitem__(self, idx):
  class RelightVideoDataset (line 343) | class RelightVideoDataset(PlotDataset):
    method __init__ (line 344) | def __init__(self, data_dir, edit_cfg, scan_id=0, is_val=False, **kwar...
    method __len__ (line 358) | def __len__(self):
    method __getitem__ (line 361) | def __getitem__(self, idx):

FILE: dataset/train_dataset.py
  class ReconDataset (line 15) | class ReconDataset(torch.utils.data.Dataset):
    method __init__ (line 17) | def __init__(self,
    method __len__ (line 166) | def __len__(self):
    method __getitem__ (line 169) | def __getitem__(self, idx):
    method collate_fn (line 194) | def collate_fn(self, batch_list):
  class MaterialDataset (line 212) | class MaterialDataset(torch.utils.data.Dataset):
    method __init__ (line 214) | def __init__(self,
    method __len__ (line 315) | def __len__(self):
    method __getitem__ (line 318) | def __getitem__(self, idx):
    method collate_fn (line 336) | def collate_fn(self, batch_list):

FILE: model/eval/recon.py
  class SDFMeshSystem (line 21) | class SDFMeshSystem(pl.LightningModule):
    method __init__ (line 22) | def __init__(self, conf, exp_dir, resolution, score=False, far_clip=5....
    method initialize (line 46) | def initialize(self):
    method test_dataloader (line 84) | def test_dataloader(self):
    method test_step (line 89) | def test_step(self, batch, batch_idx):
    method test_epoch_end (line 92) | def test_epoch_end(self, outputs) -> None:
    method forward (line 131) | def forward(self):
  class VolumeRenderSystem (line 135) | class VolumeRenderSystem(pl.LightningModule):
    method __init__ (line 136) | def __init__(self, conf, exp_dir, indices=None, is_val=False, score_me...
    method test_dataloader (line 157) | def test_dataloader(self):
    method test_step (line 163) | def test_step(self, batch, batch_idx):
    method test_epoch_end (line 205) | def test_epoch_end(self, outputs):
    method forward (line 223) | def forward(self):
  class ViewInterpolateSystem (line 227) | class ViewInterpolateSystem(pl.LightningModule):
    method __init__ (line 228) | def __init__(self, conf, exp_dir, id0, id1, n_frames=60, frame_rate=24...
    method test_dataloader (line 253) | def test_dataloader(self):
    method test_step (line 259) | def test_step(self, batch, batch_idx):
    method test_epoch_end (line 284) | def test_epoch_end(self, outputs):
    method forward (line 302) | def forward(self):

FILE: model/network/__init__.py
  class I2SDFNetwork (line 19) | class I2SDFNetwork(nn.Module):
    method __init__ (line 20) | def __init__(self, conf):
    method init_emission_groups (line 49) | def init_emission_groups(self, n_emitters, pointcloud, init_emission=1...
    method get_param_groups (line 77) | def get_param_groups(self, lr):
    method forward (line 80) | def forward(self, input, predict_only=False):
    method volume_rendering (line 223) | def volume_rendering(self, z_vals, z_max, sdf):
    method bg_volume_rendering (line 242) | def bg_volume_rendering(self, z_vals_bg, bg_sdf):
    method depth2pts_outside (line 258) | def depth2pts_outside(self, ray_o, ray_d, depth):
  class I2SDFLoss (line 289) | class I2SDFLoss(nn.Module):
    method __init__ (line 290) | def __init__(self, eikonal_weight=0.1, smooth_weight=0.0, mask_weight=...
    method get_rgb_loss (line 308) | def get_rgb_loss(self, rgb_values, rgb_gt):
    method get_eikonal_loss (line 313) | def get_eikonal_loss(self, grad_theta):
    method get_mask_loss (line 317) | def get_mask_loss(self, mask_pred, mask_gt):
    method get_depth_loss (line 320) | def get_depth_loss(self, depth, depth_gt, depth_mask):
    method get_normal_l1_loss (line 326) | def get_normal_l1_loss(self, normal, normal_gt, normal_mask):
    method get_normal_angular_loss (line 331) | def get_normal_angular_loss(self, normal, normal_gt, normal_mask):
    method forward (line 338) | def forward(self, model_outputs, ground_truth, current_step):

FILE: model/network/density.py
  class Density (line 5) | class Density(nn.Module):
    method __init__ (line 6) | def __init__(self, params_init={}):
    method forward (line 12) | def forward(self, sdf, beta=None):
  class LaplaceDensity (line 16) | class LaplaceDensity(Density):  # alpha * Laplace(loc=0, scale=beta).cdf...
    method __init__ (line 17) | def __init__(self, params_init={}, beta_min=0.0001):
    method density_func (line 21) | def density_func(self, sdf, beta=None):
    method get_beta (line 28) | def get_beta(self):
  class AbsDensity (line 33) | class AbsDensity(Density):  # like NeRF++
    method density_func (line 34) | def density_func(self, sdf, beta=None):
  class SimpleDensity (line 38) | class SimpleDensity(Density):  # like NeRF
    method __init__ (line 39) | def __init__(self, params_init={}, noise_std=1.0):
    method density_func (line 43) | def density_func(self, sdf, beta=None):

FILE: model/network/embedder.py
  class Embedder (line 6) | class Embedder:
    method __init__ (line 8) | def __init__(self, **kwargs):
    method create_embedding_fn (line 12) | def create_embedding_fn(self):
    method embed (line 37) | def embed(self, inputs):
  class SHEncoder (line 41) | class SHEncoder(nn.Module):
    method __init__ (line 42) | def __init__(self, input_dims=3, degree=4):
    method forward (line 84) | def forward(self, input, **kwargs):
  class FourierFeature (line 125) | class FourierFeature(nn.Module):
    method __init__ (line 126) | def __init__(self, channels, sigma=1.0, input_dims=3, include_input=Tr...
    method forward (line 133) | def forward(self, x):
  function get_embedder (line 138) | def get_embedder(embed_type='positional', **kwargs):

FILE: model/network/mlp.py
  class ImplicitNetwork (line 10) | class ImplicitNetwork(nn.Module):
    method __init__ (line 11) | def __init__(
    method get_param_groups (line 81) | def get_param_groups(self, lr):
    method forward (line 84) | def forward(self, input):
    method gradient (line 107) | def gradient(self, x):
    method feature (line 120) | def feature(self, x):
    method get_outputs (line 123) | def get_outputs(self, x, returns_grad=True):
    method get_sdf_vals (line 145) | def get_sdf_vals(self, x):
  class RenderingNetwork (line 159) | class RenderingNetwork(nn.Module):
    method __init__ (line 160) | def __init__(
    method forward (line 208) | def forward(self, points, normals, view_dirs, feature_vectors):

FILE: model/network/ray_sampler.py
  class RaySampler (line 6) | class RaySampler(metaclass=abc.ABCMeta):
    method __init__ (line 7) | def __init__(self, near, far):
    method get_z_vals (line 12) | def get_z_vals(self, ray_dirs, cam_loc, model):
  class UniformSampler (line 15) | class UniformSampler(RaySampler):
    method __init__ (line 16) | def __init__(self, scene_bounding_sphere, near, N_samples, take_sphere...
    method get_z_vals (line 22) | def get_z_vals(self, ray_dirs, cam_loc, model):
  class ErrorBoundSampler (line 46) | class ErrorBoundSampler(RaySampler):
    method __init__ (line 47) | def __init__(self, scene_bounding_sphere, near, N_samples, N_samples_e...
    method get_z_vals (line 67) | def get_z_vals(self, ray_dirs, cam_loc, model):
    method get_error_bound (line 243) | def get_error_bound(self, beta, model, sdf, z_vals, dists, d_star):

FILE: model/rendering/__init__.py
  class RenderingLayer (line 10) | class RenderingLayer(nn.Module):
    method __init__ (line 11) | def __init__(self, spp, split_n_pixels, preserve_light=True) -> None:
    method forward (line 17) | def forward(

FILE: model/rendering/brdf.py
  function create_frame (line 5) | def create_frame(n: torch.Tensor, eps:float = 1e-6):
  function get_rendering_parameters (line 20) | def get_rendering_parameters(albedo_raw, rough_raw, use_metallic):
  function to_global (line 35) | def to_global(d, x, y, z):
  function sqrt_ (line 41) | def sqrt_(x: torch.Tensor, eps=1e-8) -> torch.Tensor:
  function reflect (line 47) | def reflect(v: torch.Tensor, h: torch.Tensor):
  function square_to_cosine_hemisphere (line 51) | def square_to_cosine_hemisphere(sample: torch.Tensor):
  function get_cos_theta (line 59) | def get_cos_theta(v: torch.Tensor):
  function get_phi (line 63) | def get_phi(v: torch.Tensor):
  function sample_disney_specular (line 72) | def sample_disney_specular(sample: torch.Tensor, roughness: torch.Tensor...
  function GTR2 (line 94) | def GTR2(ndh, a):
  function SchlickFresnel (line 99) | def SchlickFresnel(u):
  function smithG_GGX (line 103) | def smithG_GGX(ndv, a):
  function pdf_disney (line 109) | def pdf_disney(roughness: torch.Tensor, metallic: torch.Tensor, wi: torc...
  function eval_disney (line 130) | def eval_disney(albedo: torch.Tensor, roughness: torch.Tensor, metallic:...
  function F_Schlick (line 164) | def F_Schlick(SpecularColor, VoH):
  function GetSpecularEventProbability (line 168) | def GetSpecularEventProbability(SpecularColor, NoV) -> torch.Tensor:
  function baseColorToSpecularF0 (line 172) | def baseColorToSpecularF0(baseColor, metalness):
  function luminance (line 175) | def luminance(color):
  function probabilityToSampleSpecular (line 181) | def probabilityToSampleSpecular(difColor, specColor) -> torch.Tensor:
  function shadowedF90 (line 186) | def shadowedF90(F0):
  function evalFresnel (line 190) | def evalFresnel(f0, f90, NdotS):
  function Smith_G1_GGX (line 194) | def Smith_G1_GGX(alphaSquared, NdotSSquared):
  function Smith_G2_GGX (line 197) | def Smith_G2_GGX(alphaSquared, NdotL, NdotV):
  function GGX_D (line 202) | def GGX_D(alphaSquared, NdotH):
  function pdf_ggx (line 206) | def pdf_ggx(Kd: torch.Tensor, Ks: torch.Tensor, roughness: torch.Tensor,...
  function eval_ggx (line 241) | def eval_ggx(Kd: torch.Tensor, Ks: torch.Tensor, roughness: torch.Tensor...
  function sample_weight_ggx (line 268) | def sample_weight_ggx(alphaSquared, NdotL, NdotV):
  function sample_ggx (line 273) | def sample_ggx(sample: torch.Tensor, Kd: torch.Tensor, Ks: torch.Tensor,...
  function sample_ggx_specular (line 325) | def sample_ggx_specular(sample: torch.Tensor, roughness: torch.Tensor, w...

FILE: model/trainer/recon.py
  class ReconstructionTrainer (line 23) | class ReconstructionTrainer(pl.LightningModule):
    method __init__ (line 24) | def __init__(self, conf, prog_bar: RichProgressBar, exp_dir, model_onl...
    method forward (line 109) | def forward(self):
    method plot_hotmap (line 112) | def plot_hotmap(self, path):
    method plot_countmap (line 127) | def plot_countmap(self, path):
    method update_pdf (line 142) | def update_pdf(self, value, idx):
    method sample_bubble (line 155) | def sample_bubble(self, batch_size):
    method initialize_bubble_pdf (line 172) | def initialize_bubble_pdf(self, split_size):
    method configure_optimizers (line 201) | def configure_optimizers(self):
    method train_dataloader (line 209) | def train_dataloader(self):
    method val_dataloader (line 212) | def val_dataloader(self):
    method log_if_nonzero (line 215) | def log_if_nonzero(self, name, value, *args, **kwargs):
    method training_step (line 219) | def training_step(self, batch, batch_idx):
    method validation_step (line 290) | def validation_step(self, batch, batch_idx):
    method validation_epoch_end (line 358) | def validation_epoch_end(self, outputs) -> None:

FILE: utils/__init__.py
  class RichProgressBarWithScanId (line 12) | class RichProgressBarWithScanId(RichProgressBar):
    method __init__ (line 13) | def __init__(self, scan_id, *args, **kwargs) -> None:
    method configure_columns (line 17) | def configure_columns(self, trainer):
  function glob_imgs (line 21) | def glob_imgs(path):
  function glob_depths (line 27) | def glob_depths(path):
  function split_input (line 35) | def split_input(model_input, total_pixels, n_pixels=10000):
  function split_dict (line 50) | def split_dict(d, batch_size=10000):
  function detach_dict (line 66) | def detach_dict(d):
  function merge_output (line 70) | def merge_output(res, total_pixels, batch_size):
  function merge_dict (line 87) | def merge_dict(dicts):
  class _trunc_exp (line 96) | class _trunc_exp(Function):
    method forward (line 99) | def forward(ctx, x):
    method backward (line 105) | def backward(ctx, g):
  function kmeans_pp_centroid (line 111) | def kmeans_pp_centroid(points: torch.Tensor, k):

FILE: utils/cfgnode.py
  class CfgNode (line 34) | class CfgNode(dict):
    method __init__ (line 44) | def __init__(
    method _create_config_tree_from_dict (line 87) | def _create_config_tree_from_dict(cls, init_dict: dict, key_list: list):
    method __getattr__ (line 110) | def __getattr__(self, name: str):
    method __setattr__ (line 116) | def __setattr__(self, name: str, value):
    method __str__ (line 138) | def __str__(self):
    method __repr__ (line 159) | def __repr__(self):
    method dump (line 162) | def dump(self, **kwargs):
    method merge_from_file (line 184) | def merge_from_file(self, cfg_filename: str):
    method merge_from_other_cfg (line 193) | def merge_from_other_cfg(self, cfg_other):
    method merge_from_list (line 200) | def merge_from_list(self, cfg_list: list):
    method freeze (line 229) | def freeze(self):
    method defrost (line 233) | def defrost(self):
    method is_frozen (line 237) | def is_frozen(self):
    method _immutable (line 241) | def _immutable(self, is_immutable: bool):
    method clone (line 255) | def clone(self):
    method register_deprecated_key (line 259) | def register_deprecated_key(self, key: str):
    method register_renamed_key (line 270) | def register_renamed_key(
    method key_is_deprecated (line 286) | def key_is_deprecated(self, full_key: str):
    method key_is_renamed (line 293) | def key_is_renamed(self, full_key: str):
    method raise_key_rename_error (line 297) | def raise_key_rename_error(self, full_key: str):
    method is_new_allowed (line 310) | def is_new_allowed(self):
    method load_cfg (line 314) | def load_cfg(cls, cfg_file_obj_or_str):
    method _load_cfg_from_file (line 336) | def _load_cfg_from_file(cls, file_obj):
    method _load_cfg_from_yaml_str (line 351) | def _load_cfg_from_yaml_str(cls, str_obj):
    method _load_cfg_py_source (line 357) | def _load_cfg_py_source(cls, filename):
    method _decode_cfg_value (line 374) | def _decode_cfg_value(cls, value):
  function _valid_type (line 406) | def _valid_type(value, allow_cfg_node: Optional[bool] = False):
  function _merge_a_into_b (line 412) | def _merge_a_into_b(a: CfgNode, b: CfgNode, root: CfgNode, key_list: list):
  function _check_and_coerce_cfg_value_type (line 450) | def _check_and_coerce_cfg_value_type(replacement, original, key, full_key):
  function _assert_with_logging (line 482) | def _assert_with_logging(cond, msg):
  function _load_module_from_file (line 488) | def _load_module_from_file(name, filename):

FILE: utils/mesh_util.py
  function nn_correspondance (line 12) | def nn_correspondance(verts1, verts2):
  function evaluate (line 25) | def evaluate(mesh_pred, mesh_trgt, threshold=.05, down_sample=.02):
  class Renderer (line 55) | class Renderer():
    method __init__ (line 56) | def __init__(self, height=480, width=640):
    method __call__ (line 61) | def __call__(self, height, width, intrinsics, pose, mesh):
    method fix_pose (line 71) | def fix_pose(self, pose):
    method mesh_opengl (line 83) | def mesh_opengl(self, mesh):
    method delete (line 86) | def delete(self):
  function refuse (line 90) | def refuse(mesh, poses, K, H, W, far_clip=5.0):
  function depth2mesh (line 117) | def depth2mesh(depths, poses, K, H, W):

FILE: utils/plots.py
  function plot (line 15) | def plot(implicit_network, indices, plot_data, path, epoch, img_res, plo...
  function visualize_pointcloud (line 76) | def visualize_pointcloud(points, filename):
  function visualize_clustered_pointcloud (line 90) | def visualize_clustered_pointcloud(points, labels, centroids, filename):
  function visualize_marked_pointcloud (line 105) | def visualize_marked_pointcloud(points, counts, path, epoch):
  function get_3D_scatter_trace (line 120) | def get_3D_scatter_trace(points, name='', size=3, caption=None):
  function get_3D_marked_scatter_trace (line 141) | def get_3D_marked_scatter_trace(points, marks, name='', size=1, caption=...
  function get_3D_quiver_trace (line 164) | def get_3D_quiver_trace(points, directions, color='#bd1540', name=''):
  function get_surface_trace (line 188) | def get_surface_trace(path, epoch, sdf, resolution=100, grid_boundary=[-...
  function extract_fields (line 228) | def extract_fields(bound_min, bound_max, resolution, query_func):
  function extract_geometry (line 246) | def extract_geometry(bound_min, bound_max, resolution, threshold, query_...
  function get_surface_high_res_mesh (line 258) | def get_surface_high_res_mesh(sdf, resolution=100, grid_boundary=[-2.0, ...
  function get_surface_by_grid (line 339) | def get_surface_by_grid(grid_params, sdf, resolution=100, level=0, highe...
  function get_grid_uniform (line 440) | def get_grid_uniform(resolution, grid_boundary=[-2.0, 2.0]):
  function get_grid (line 453) | def get_grid(points, resolution, input_min=None, input_max=None, eps=0.1):
  function plot_imgs_wo_gt (line 492) | def plot_imgs_wo_gt(normal_maps, path, epoch, plot_nrow, img_res, path_n...
  function plot_imgs_filter (line 508) | def plot_imgs_filter(rgb_points, ground_true, path, epoch, img_res, path...
  function plot_colormap (line 522) | def plot_colormap(mat_info, path, epoch, plot_nrow, img_res, colormap=cv...
  function plot_depths (line 538) | def plot_depths(depth_maps, path, epoch, plot_nrow, img_res, colormap=cv...
  function plot_images (line 560) | def plot_images(rgb_points, ground_true, path, epoch, plot_nrow, img_res...
  function lin2img (line 581) | def lin2img(tensor, img_res):

FILE: utils/rend_util.py
  function linear_to_srgb (line 9) | def linear_to_srgb(data):
  function get_psnr (line 13) | def get_psnr(img1, img2, normalize_rgb=False):
  function load_rgb (line 25) | def load_rgb(path, normalize_rgb = False, is_hdr = False):
  function load_mask (line 38) | def load_mask(path):
  function load_depth (line 46) | def load_depth(path):
  function load_normal (line 52) | def load_normal(path):
  function load_K_Rt_from_P (line 57) | def load_K_Rt_from_P(filename, P=None):
  function depth_to_world (line 81) | def depth_to_world(uv, intrinsics, pose, depth, depth_mask=None):
  function get_camera_params (line 92) | def get_camera_params(uv, pose, intrinsics):
  function get_camera_for_plot (line 123) | def get_camera_for_plot(pose):
  function lift (line 134) | def lift(x, y, z, intrinsics):
  function quat_to_rot (line 150) | def quat_to_rot(q):
  function rot_to_quat (line 170) | def rot_to_quat(R):
  function get_general_sphere_intersections (line 191) | def get_general_sphere_intersections(cam_loc, ray_directions, center, r):
  function get_sphere_intersections (line 211) | def get_sphere_intersections(cam_loc, ray_directions, r = 1.0):
  function add_depth_noise (line 229) | def add_depth_noise(depth, depth_mask, scale=1):

About this extraction

This page contains the full source code of the jingsenzhu/i2-sdf GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 31 files (217.1 KB), approximately 59.9k tokens, and a symbol index with 262 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.
