[
  {
    "path": ".gitignore",
    "content": "data/synthetic\nexps\n__pycache__\narchive\ntmp*"
  },
  {
    "path": "DATA_CONVENTION.md",
    "content": "# Data Convention\n\nThe format of our multi-view dataset is derived from [VolSDF](https://github.com/lioryariv/volsdf/blob/main/DATA_CONVENTION.md).\n\n### Directory Structure\n\n```python\nscan<scan_id>/\n\tcameras.npz\n\timage/ -> {:04d}.png # tone-mapped LDR images\n\tdepth/ -> {:04d}.exr\n\tnormal/ -> {:04d}.exr\n\tmask/ -> {:04d}.png\n\tval/ -> {:04d}.png # validation images (LDR)\n\thdr/ -> {:04d}.exr # raw HDR images\n\t# followings are optional\n\tlight_mask/ -> {:04d}.png # emitter mask images\n\tmaterial/ -> {:04d}_kd.exr, {:04d}_ks.exr, {:04d}_rough.exr # diffuse, specular albedo and roughness\n```\n\nZeros areas in the depth maps and normal maps indicate invalid areas such as windows.\n\nNote that not all areas inside the scenes use the GGX material model. We'll provide a mask of the invalid areas.\n\n### Camera Information\n\nThe `cameras.npz` contains each image's associated camera projection matrix `'world_mat_{i}'` and a normalization matrix `'scale_mat_{i}'`, the same as VolSDF. Besides, we also provide a validation set of images for novel view synthesis, whose associated camera projection matrices are `'val_mat_{i}'`. Validation set and training set share the same normalization matrix.\n\nThe normalization matrices may not be readily available in `cameras.npz`. You can manually run `data/normalize_cameras.py` to generate `cameras_normalize.npz`. Since our method requires the entire scene to be within a radius-3 bounding sphere, we suggest normalizing cameras by radius 2.0 or 2.5.\n\nAn example of running `normalize_cameras.py`:\n\n```shell\npython normalize_cameras.py --id <scan_id> -n <synthetic/real/...> -r 2.0\n```\n\nNote that we follow **OpenCV camera coordinate system** (X right, Y downwards, Z into the image plane).\n\n### Dataset Format Conversion\n\nIf you want to convert the dataset format to NeRF blender format, run `npz_to_blender.py`:\n\n```sh\npython npz_to_blender.py --root /path/to/dataset\n```\n\nThe script will automatically scale all pose matrices to fit within a `[-1, 1]` bounding box.\n\n### About Real Dataset\n\nOur real dataset comes from [Inria](https://repo-sam.inria.fr/fungraph/deep-indoor-relight/) and [Scalable-NISR](https://xchaowu.github.io/papers/scalable-nisr/), with estimated depth from MVS tools and manually-labeled light masks (2 living room scenes). All depths has an absolute scale without needs of shifting like MonoSDF. All camera calibrations and depths are provided by the authors of [Scalable-NISR](https://xchaowu.github.io/papers/scalable-nisr/). We thank them for providing the datasets.\n\nNormal is not provided in the real dataset, and we find it sufficient for plausible reconstruction without a normal supervision in these scenes. Of course, you can estimate normal using any methods if you want to enable normal supervision.\n\n### About EXR format\n\nWe suggest using OpenCV to load an `.exr` format `float32` image:\n\n```python\nimport os\nos.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' # Enable OpenCV support for EXR\nimport cv2\n...\nim = cv2.imread(im_path, -1) # im will be an numpy.float32 array of shape (H, W, C)\nim = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) # cv2 reads image in BGR shape, convert into RGB\n```\n\nWe suggest using [tev](https://github.com/Tom94/tev) to preview HDR `.exr` images conveniently.\n\n### About Mesh\n\nDue to copyright issues, we could not release the original 3D mesh of our synthetic scenes. 
\n\n### Dataset Format Conversion\n\nIf you want to convert the dataset to the NeRF Blender format, run `npz_to_blender.py`:\n\n```sh\npython npz_to_blender.py --root /path/to/dataset\n```\n\nWith the `--scale` flag, the script also scales all camera positions to fit within a `[-1, 1]` bounding box.\n\n### About Real Dataset\n\nOur real dataset comes from [Inria](https://repo-sam.inria.fr/fungraph/deep-indoor-relight/) and [Scalable-NISR](https://xchaowu.github.io/papers/scalable-nisr/), with depth estimated by MVS tools and manually labeled light masks (2 living room scenes). All depths have an absolute scale, without the need for shift alignment as in MonoSDF. All camera calibrations and depths are provided by the authors of [Scalable-NISR](https://xchaowu.github.io/papers/scalable-nisr/). We thank them for providing the datasets.\n\nNormal maps are not provided in the real dataset; we find these scenes can be plausibly reconstructed without normal supervision. Of course, you can estimate normals with any method if you want to enable normal supervision.\n\n### About EXR format\n\nWe suggest using OpenCV to load an `.exr` format `float32` image:\n\n```python\nimport os\nos.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' # Enable OpenCV support for EXR\nimport cv2\n...\nim = cv2.imread(im_path, -1) # im will be a numpy.float32 array of shape (H, W, C)\nim = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) # cv2 loads images in BGR order; convert to RGB\n```\n\nWe suggest using [tev](https://github.com/Tom94/tev) to conveniently preview HDR `.exr` images.\n\n### About Mesh\n\nDue to copyright issues, we could not release the original 3D meshes of our synthetic scenes. Instead, we'll provide a point cloud sampled from the GT mesh to enable 3D reconstruction evaluation.
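\n\nAs an illustration, a one-sided Chamfer evaluation against such a point cloud could look like the following sketch (assuming `trimesh` and `scipy` from our environment; `mesh.ply` is a mesh extracted with `--test_mode mesh`, and the file name `gt_pcd.ply` is a placeholder, not necessarily the released name):\n\n```python\nimport trimesh\nfrom scipy.spatial import cKDTree\n\npred = trimesh.load('mesh.ply').sample(100000) # points sampled on the reconstructed mesh\ngt = trimesh.load('gt_pcd.ply').vertices # GT point cloud (placeholder file name)\naccuracy = cKDTree(gt).query(pred)[0].mean() # predicted -> GT distance\ncompleteness = cKDTree(pred).query(gt)[0].mean() # GT -> predicted distance\nprint(f'chamfer-L1: {(accuracy + completeness) / 2:.4f}')\n```"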
  },
  {
    "path": "LICENSE",
    "content": "MIT License\n\nCopyright (c) 2023 Jingsen Zhu\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "**News**\n\n- `04/04/2023` dataset preview release: 2 synthetic scenes available\n- `15/04/2023` code release: 3D reconstruction and novel view synthesis part\n- `21/04/2023` dataset release: real data\n\n**TODO**\n\n- [ ] Full dataset release\n- [x] Code release for 3D reconstruction and novel view synthesis\n- [ ] Code release for intrinsic decomposition and scene editing\n\n**Dataset released**\n\n- Synthetic: `kitchen_0`, `bedroom_relight_0`, `bedroom_0`, `bedroom_1`, `bedroom_relight_1`, `diningroom_0`, `livingroom_0`, `livingroom_1`, more scenes to be released\n- Real: `inria_livingroom`, `nisr_livingroom`, `nisr_coffee_shop_0`, `nisr_coffee_shop_1`, release complete\n\n# I<sup>2</sup>-SDF: Intrinsic Indoor Scene Reconstruction and Editing via Raytracing in Neural SDFs (CVPR 2023)\n\n### [Project Page](https://jingsenzhu.github.io/i2-sdf/) | [Paper](https://arxiv.org/abs/2303.07634) | [Dataset](i2-sdf-dataset-links.csv)\n\n## Setup\n\n### Installation\n\n```\nconda env create -f environment.yml\nconda activate i2sdf\n```\n\n### Data preparation\n\nDownload our synthetic dataset and extract them into `data/synthetic`. If you want to run on your customized dataset, we provide a brief introduction to our data convention [here](DATA_CONVENTION.md).\n\n## Dataset\n\nWe provide a high-quality synthetic indoor scene multi-view dataset, with ground truth camera pose and geometry annotations. See [HERE](DATA_CONVENTION.md) for data conventions. Click [HERE](https://mega.nz/folder/jdhDnTqL#Ija678SU2Va_JJOiwqmdEg) to download.\n\n## 3D Reconstruction and Novel View Synthesis\n\n### Training\n\n```\npython main_recon.py --conf config/<config_file>.yml --scan_id <scan_id> -d <gpu_id> -v <version>\n```\n\nNote: `config/synthetic.yml` doesn't contain light mask network, while `config/synthetic_light_mask.yml` contains.\n\nIf you run out of GPU memory, try to reduce the `split_n_pixels` (i.e. validation batch size), `batch_size` in the config. The default parameters are evaluated under RTX A6000 (48GB). For RTX 3090 (24GB), try to set `split_n_pixels` 5000.\n\n### Evaluation\n\n#### Novel view synthesis\n\n```\npython main_recon.py --conf config/<config_file>.yml --scan_id <scan_id> -d <gpu_id> -v <version> --test [--is_val] [--full]\n```\n\nThe optional flag `--is_val` evaluates on the validation set instead of training set, `--full` produces full-resolution rendered images without downsampling.\n\n#### View Interpolation\n\n```\npython main_recon.py --conf config/<config_file>.yml --scan_id <scan_id> -d <gpu_id> -v <version> --test --test_mode interpolate --inter_id <view_id_0> <view_id_1> [--full]\n```\n\nGenerates a view interpolation video between 2 views. 
\n\n### Evaluation\n\n#### Novel view synthesis\n\n```\npython main_recon.py --conf config/<config_file>.yml --scan_id <scan_id> -d <gpu_id> -v <version> --test [--is_val] [--full]\n```\n\nThe optional flag `--is_val` evaluates on the validation set instead of the training set; `--full` produces full-resolution rendered images without downsampling.\n\n#### View Interpolation\n\n```\npython main_recon.py --conf config/<config_file>.yml --scan_id <scan_id> -d <gpu_id> -v <version> --test --test_mode interpolate --inter_id <view_id_0> <view_id_1> [--full]\n```\n\nGenerates a view interpolation video between the two views. Requires `ffmpeg` to be installed.\n\nThe number of frames and the frame rate of the video can be specified with the `--n_frames` and `--frame_rate` options.\n\n#### Mesh Extraction\n\n```\npython main_recon.py --conf config/<config_file>.yml --scan_id <scan_id> -d <gpu_id> -v <version> --test --test_mode mesh\n```\n\n## Intrinsic Decomposition and Scene Editing\n\n**Brewing🍺, code coming soon.**\n\n## Citation\n\nIf you find our work useful, please consider citing:\n\n```\n@inproceedings{zhu2023i2sdf,\n    title = {I$^2$-SDF: Intrinsic Indoor Scene Reconstruction and Editing via Raytracing in Neural SDFs},\n    author = {Jingsen Zhu and Yuchi Huo and Qi Ye and Fujun Luan and Jifan Li and Dianbing Xi and Lisha Wang and Rui Tang and Wei Hua and Hujun Bao and Rui Wang},\n    booktitle = {CVPR},\n    year = {2023}\n}\n```\n\n## Acknowledgement\n\n- This repository is built upon [PyTorch Lightning](https://lightning.ai/).\n- Thanks to Lior Yariv for her excellent work [VolSDF](https://lioryariv.github.io/volsdf/).\n- Thanks to the [Scalable-NISR](https://xchaowu.github.io/papers/scalable-nisr/) team for providing their real-world dataset.\n"
  },
  {
    "path": "config/synthetic.yml",
    "content": "train:\n    expname: synthetic\n    learning_rate: 5.0e-4\n    steps: 200000\n    checkpoint_freq: 10000\n    plot_freq: 500\n    split_n_pixels: 12000\n    batch_size: 1600\n    pdf_criterion: DEPTH\n\nplot:\n    plot_nimgs: 1\n    grid_boundary: [-1.5, 1.5]\n\nloss:\n    eikonal_weight: 0.1\n    smooth_weight: 0.01\n    smooth_iter: 150000\n    depth_weight: 0.1\n    normal_weight: 0.05\n    bubble_weight: 0.5\n    min_bubble_iter: 50000\n    max_bubble_iter: 150000\n\ndataset:\n    data_dir: synthetic\n    img_res: [480, 640]\n    downsample: 2\n    pdf_prune: 0.05\n    pdf_max: 0.2\n\nmodel:\n    feature_vector_size: 256\n    scene_bounding_sphere: 3.0\n    implicit_network:\n        d_in: 3\n        d_out: 1\n        dims: [ 256, 256, 256, 256, 256, 256, 256, 256 ]\n        geometric_init: True\n        bias: 0.6\n        skip_in: [4]\n        weight_norm: True\n        embed_type: 'positional'\n        multires: 6\n    \n    rendering_network:\n        # mode: idr\n        # d_in: 9\n        # Don't find actual differences between 'nerf' and 'idr' mode\n        # Choose 'nerf' mode for a slight faster performance\n        mode: nerf\n        d_in: 3\n        d_out: 3\n        dims: [ 256, 256, 256, 256 ]\n        weight_norm: True\n        embed_type: 'positional'\n        multires: 4\n    \n    density:\n        params_init:\n            beta: 0.1\n        \n        beta_min: 0.0001\n    \n    ray_sampler:\n        near: 0.0\n        N_samples: 64\n        N_samples_eval: 128\n        N_samples_extra: 32\n        eps: 0.1\n        beta_iters: 10\n        max_total_iters: 5\n        N_samples_inverse_sphere: 32\n        add_tiny: 1.0e-6\n\n"
  },
  {
    "path": "config/synthetic_light_mask.yml",
    "content": "train:\n    expname: synthetic_light\n    learning_rate: 5.0e-4\n    steps: 200000\n    checkpoint_freq: 10000\n    plot_freq: 500\n    split_n_pixels: 12000\n    batch_size: 1600\n    pdf_criterion: DEPTH\n\nplot:\n    plot_nimgs: 1\n    grid_boundary: [-1.5, 1.5]\n\nloss:\n    eikonal_weight: 0.1\n    smooth_weight: 0.01\n    smooth_iter: 150000\n    depth_weight: 0.1\n    normal_weight: 0.05\n    bubble_weight: 0.5\n    light_mask_weight: 0.5\n    min_bubble_iter: 50000\n    max_bubble_iter: 150000\n\ndataset:\n    data_dir: synthetic\n    img_res: [480, 640]\n    downsample: 2\n    pdf_prune: 0.05\n    pdf_max: 0.2\n\nmodel:\n    feature_vector_size: 256\n    scene_bounding_sphere: 3.0\n    implicit_network:\n        d_in: 3\n        d_out: 1\n        dims: [ 256, 256, 256, 256, 256, 256 ]\n        geometric_init: True\n        bias: 0.6\n        skip_in: [3]\n        weight_norm: True\n        embed_type: 'positional'\n        multires: 6\n    \n    rendering_network:\n        mode: nerf\n        d_in: 3\n        d_out: 3\n        dims: [ 256, 256, 256 ]\n        weight_norm: True\n        embed_type: 'positional'\n        multires: 4\n    \n    light_network:\n        dims: [ 128 ]\n        weight_norm: True\n    \n    density:\n        params_init:\n            beta: 0.1\n        \n        beta_min: 0.0001\n    \n    ray_sampler:\n        near: 0.0\n        N_samples: 64\n        N_samples_eval: 128\n        N_samples_extra: 32\n        eps: 0.1\n        beta_iters: 10\n        max_total_iters: 5\n        N_samples_inverse_sphere: 32\n        add_tiny: 1.0e-6\n\n"
  },
  {
    "path": "data/normalize_cameras.py",
    "content": "import cv2\nimport numpy as np\nimport argparse\nfrom copy import deepcopy\n\ndef get_center_point(num_cams,cameras):\n    A = np.zeros((3 * num_cams, 3 + num_cams))\n    b = np.zeros((3 * num_cams, 1))\n    camera_centers=np.zeros((3,num_cams))\n    for i in range(num_cams):\n        P0 = cameras['world_mat_%d' % i][:3, :]\n\n        K = cv2.decomposeProjectionMatrix(P0)[0]\n        R = cv2.decomposeProjectionMatrix(P0)[1]\n        c = cv2.decomposeProjectionMatrix(P0)[2]\n        c = c / c[3]\n        camera_centers[:,i]=c[:3].flatten()\n\n        # v = np.linalg.inv(K) @ np.array([800, 600, 1])\n        # v = v / np.linalg.norm(v)\n\n        v=R[2,:]\n        A[3 * i:(3 * i + 3), :3] = np.eye(3)\n        A[3 * i:(3 * i + 3), 3 + i] = -v\n        b[3 * i:(3 * i + 3)] = c[:3]\n\n    soll= np.linalg.pinv(A) @ b\n\n    return soll,camera_centers\n\ndef normalize_cameras(original_cameras_filename,output_cameras_filename,num_of_cameras,radius,convert_coord):\n    cameras = np.load(original_cameras_filename)\n    if num_of_cameras==-1:\n        all_files=cameras.files\n        maximal_ind=0\n        for field in all_files:\n            if 'val' not in field:\n                maximal_ind=np.maximum(maximal_ind,int(field.split('_')[-1]))\n        num_of_cameras=maximal_ind+1\n    soll, camera_centers = get_center_point(num_of_cameras, cameras)\n\n    center = soll[:3].flatten()\n\n    max_radius = np.linalg.norm((center[:, np.newaxis] - camera_centers), axis=0).max() * 1.1\n\n    normalization = np.eye(4).astype(np.float32)\n\n    normalization[0, 3] = center[0]\n    normalization[1, 3] = center[1]\n    normalization[2, 3] = center[2]\n\n    normalization[0, 0] = max_radius / radius\n    normalization[1, 1] = max_radius / radius\n    normalization[2, 2] = max_radius / radius\n\n    cameras_new = {}\n    cameras_new = deepcopy(dict(cameras))\n    for i in range(num_of_cameras):\n        cameras_new['scale_mat_%d' % i] = normalization\n        # cameras_new['world_mat_%d' % i] = cameras['world_mat_%d' % i].copy()\n        # if ('val_mat_%d' % i) in cameras:\n        #     cameras_new['val_mat_%d' % i] = cameras['val_mat_%d' % i].copy()\n        \n        def opengl2opencv(P):\n            out = cv2.decomposeProjectionMatrix(P[:3,:])\n            K, R, t = out[0:3]\n            K = K/K[2,2]\n            intrinsics = np.eye(4, dtype=np.float32)\n            intrinsics[:3, :3] = K\n            t = (t[:3] / t[3]).squeeze()\n            w2c = np.eye(4, dtype=np.float32)\n            w2c[:3,:3] = R\n            w2c[:3,3] = -R @ t\n            T = np.diag([1, -1, -1, 1])\n            w2c = T @ w2c\n            return intrinsics @ w2c\n        if convert_coord:\n            cameras_new['world_mat_%d' % i] = opengl2opencv(cameras_new['world_mat_%d' % i])\n            if ('val_mat_%d' % i) in cameras_new:\n                cameras_new['val_mat_%d' % i] = opengl2opencv(cameras_new['val_mat_%d' % i])\n            # cameras_new['world_mat_%d' % i] = T @ cameras_new['world_mat_%d' % i]\n    np.savez(output_cameras_filename, **cameras_new)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(description='Normalizing cameras')\n    parser.add_argument('-i', '--input_cameras_file', type=str, default=\"cameras.npz\",\n                        help='the input cameras file')\n    parser.add_argument('-o', '--output_cameras_file', type=str, default=\"cameras_normalize.npz\",\n                        help='the output cameras file')\n    parser.add_argument('--id', type=int, nargs='?')\n 
    parser.add_argument('-n', '--name', type=str, default='synthetic')\n    parser.add_argument('--number_of_cams', type=int, default=-1,\n                        help='Number of cameras, if -1 use all')\n    parser.add_argument('-r', '--radius', type=float, default=2.0)\n    parser.add_argument('-c', '--convert_coord', action='store_true')\n\n    args = parser.parse_args()\n    if args.id is not None: # 'is not None' so that scan id 0 is handled correctly\n        args.input_cameras_file = f'{args.name}/scan{args.id}/cameras.npz'\n        args.output_cameras_file = f'{args.name}/scan{args.id}/cameras_normalize.npz'\n\n    normalize_cameras(args.input_cameras_file, args.output_cameras_file, args.number_of_cams, args.radius, args.convert_coord)\n"
  },
  {
    "path": "data/npz_to_blender.py",
    "content": "\"\"\"\n    Transform npz-formatted scenes to json-formatted scene (NeRF blender format)\n    Scale all poses to fit in a [-1, 1] box\n\"\"\"\n\nimport copy\nimport json\nimport os\nimport cv2\nimport numpy as np\nfrom tqdm import tqdm\nimport argparse\n\ndef to16b(img):\n    img = img.clip(0, 1) * 65535\n    return img.astype(np.uint16)\n\n\ndef opencv_to_gl(pose):\n    mat = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]])\n    pose[:3, :3] = pose[:3, :3] @ mat\n    return pose\n\n\ndef get_offset(poses):\n    eyes = np.stack([pose[:3, 3] for pose in poses])\n\n    scale = eyes.max(axis=0) - eyes.min(axis=0)\n    print(f'scale : {scale}')\n\n    offset = -(eyes.max(axis=0) + eyes.min(axis=0)) / 2\n    print(f'offset : {offset}')\n\n    return scale / 2, offset\n\n\ndef scale_pose(pose, scale, offset):\n    pose[:3, 3] = (pose[:3, 3] + offset) / scale\n    # print(pose[:3, 3])\n    return pose.tolist()\n\n\ndef load_K_Rt_from_P(filename, P=None):\n    if P is None:\n        lines = open(filename).read().splitlines()\n        if len(lines) == 4:\n            lines = lines[1:]\n        lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(\" \") for x in lines)]\n        P = np.asarray(lines).astype(np.float32).squeeze()\n\n    out = cv2.decomposeProjectionMatrix(P)\n    K = out[0]\n    R = out[1]\n    t = out[2]\n\n    K = K / K[2, 2]\n    intrinsics = np.eye(4)\n    intrinsics[:3, :3] = K\n\n    pose = np.eye(4, dtype=np.float32)\n    pose[:3, :3] = R.transpose()\n    pose[:3, 3] = (t[:3] / t[3])[:, 0]\n\n    return intrinsics, pose\n\n\ndef main():\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--root', required=True)\n    parser.add_argument('--scale', action='store_true')\n    args = parser.parse_args()\n    os.chdir(os.path.join(args.root))\n\n    image_dir = 'image'\n    n_images = len(os.listdir(image_dir))\n    val_dir = 'val'\n    n_val = len(os.listdir(val_dir))\n    os.makedirs('depths', exist_ok=True)\n\n    cam_file = 'cameras.npz'\n    camera_dict = np.load(cam_file)\n    world_mats = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(n_images)]\n    val_mats = [camera_dict['val_mat_%d' % idx].astype(np.float32) for idx in range(n_val)]\n\n    intrinsics_all = []\n    pose_all = []\n    for mat in world_mats + val_mats:\n        P = mat\n        P = P[:3, :4]\n        intrinsics, pose = load_K_Rt_from_P(None, P)\n        intrinsics_all.append(intrinsics)\n        pose_all.append(opencv_to_gl(pose))\n\n    train_json = dict()\n\n    train_json['fl_y'] = intrinsics[1][1]\n    train_json['h'] = int(intrinsics[1, 2] * 2)\n    train_json['fl_x'] = intrinsics[0][0]\n    train_json['w'] = int(intrinsics[0, 2] * 2)\n\n    if args.scale:\n        scale, offset = get_offset(pose_all)\n\n    # train_json['enable_depth_loading'] = True\n    # train_json['integer_depth_scale'] = 1 / 65535\n\n    train_json['frames'] = []\n\n    test_json = copy.deepcopy(train_json)\n    test_json['enable_depth_loading'] = False\n\n    for i in tqdm(range(n_images)):\n        frames = train_json['frames']\n        if args.scale:\n            depth = cv2.imread(os.path.join('depth', '{:04d}.exr'.format(i)), -1)\n            cv2.imwrite(os.path.join('depths', '{:04d}.exr'.format(i)), depth / scale.max())\n\n        pose = pose_all[i].tolist() if not args.scale else scale_pose(pose_all[i], scale.max(), offset)\n        frame = {\n            'file_path': f'./image/{i:04d}',\n            'depth_path': f'./depths/{i:04d}.exr' if args.scale else 
f'./depth/{i:04d}.exr',\n            'transform_matrix': pose\n        }\n        frames.append(frame)\n\n    for i in tqdm(range(n_val)):\n        frames = test_json['frames']\n        pose = pose_all[i + n_images].tolist() if not args.scale else scale_pose(pose_all[i + n_images], scale.max(), offset)\n        frame = {\n            'file_path': f'./val/{i:04d}',\n            'transform_matrix': pose\n        }\n        frames.append(frame)\n\n    with open('transforms_train.json', 'w') as f:\n        json.dump(train_json, f, indent=4)\n    with open('transforms_test.json', 'w') as f:\n        json.dump(test_json, f, indent=4)\n    with open('transforms_val.json', 'w') as f:\n        json.dump(test_json, f, indent=4)\n\n\nif __name__ == '__main__':\n    main()\n"
  },
  {
    "path": "dataset/__init__.py",
    "content": "from .train_dataset import *\nfrom .eval_dataset import *"
  },
  {
    "path": "dataset/eval_dataset.py",
    "content": "from copy import deepcopy\nimport os\nimport torch\nimport numpy as np\nfrom torch.utils.data import Dataset\nimport utils.plots as plt\nimport torch.nn.functional as F\nimport utils\nfrom utils import rend_util\nfrom tqdm.contrib import tzip, tenumerate\nfrom scipy.spatial.transform import Rotation as Rot\nfrom scipy.spatial.transform import Slerp\nimport cv2\n\nclass GridDataset(Dataset):\n    \"\"\"\n    Used for mesh extraction\n    \"\"\"\n    def __init__(self, points, xyz) -> None:\n        super().__init__()\n        self.grid_points = points\n        self.xyz = xyz\n    \n    def __len__(self):\n        return self.grid_points.size(0)\n    \n    def __getitem__(self, index):\n        return self.grid_points[index]\n\n\nclass PlotDataset(torch.utils.data.Dataset):\n    def __init__(self,\n                 data_dir,\n                 plot_nimgs,\n                 scan_id=0,\n                 is_val=False,\n                 data=None,\n                 is_hdr=False,\n                 indices=None,\n                 use_lmask=False,\n                 **kwargs\n                 ):\n\n        self.instance_dir = os.path.join('data', data_dir, 'scan{0}'.format(scan_id))\n        val_dir = '{0}/val'.format(self.instance_dir)\n        is_val = is_val and os.path.exists(val_dir)\n        lmask_dir = '{0}/light_mask'.format(self.instance_dir)\n        self.use_lmask = use_lmask and os.path.exists(lmask_dir)\n        if is_val:\n            print(\"[INFO] Validation set detected\")\n        if is_val or data is None:\n            assert os.path.exists(self.instance_dir), \"Data directory is empty\"\n\n            if is_val:\n                image_dir = val_dir\n            elif is_hdr:\n                image_dir = '{0}/hdr'.format(self.instance_dir)\n            else:\n                image_dir = '{0}/image'.format(self.instance_dir)\n            image_paths = sorted(utils.glob_imgs(image_dir))\n            if indices is not None:\n                print(f\"[INFO] Selecting indices: {indices}\")\n                image_paths = [image_paths[i] for i in indices]\n            self.n_images = len(image_paths)\n            self.indices = indices if indices is not None else list(range(self.n_images))\n\n            self.cam_file = '{0}/cameras_normalize.npz'.format(self.instance_dir)\n            camera_dict = np.load(self.cam_file)\n            scale_mats = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in self.indices] if not is_val else [camera_dict['scale_mat_0'].astype(np.float32)] * len(self.indices)\n            world_mats = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in self.indices] if not is_val else [camera_dict['val_mat_%d' % idx].astype(np.float32) for idx in self.indices]\n\n            self.intrinsics_all = []\n            self.pose_all = []\n            for scale_mat, world_mat in zip(scale_mats, world_mats):\n                P = world_mat @ scale_mat\n                P = P[:3, :4]\n                intrinsics, pose = rend_util.load_K_Rt_from_P(None, P)\n                self.intrinsics_all.append(torch.from_numpy(intrinsics).float())\n                self.pose_all.append(torch.from_numpy(pose).float())\n            self.intrinsics_all = torch.stack(self.intrinsics_all, 0)\n            self.pose_all = torch.stack(self.pose_all, 0)\n            self.rgb_images = []\n            for path in image_paths:\n                rgb = rend_util.load_rgb(path, is_hdr=is_hdr)\n                self.img_res = [rgb.shape[1], rgb.shape[2]]\n               
 rgb = rgb.reshape(3, -1).transpose(1, 0)\n                self.rgb_images.append(torch.from_numpy(rgb).float())\n            self.rgb_images = torch.stack(self.rgb_images, 0)\n            if self.use_lmask:\n                self.lightmask_images = []\n                lmask_paths = sorted(utils.glob_imgs(lmask_dir))\n                for path in lmask_paths:\n                    lmask = rend_util.load_mask(path)\n                    lmask = lmask.reshape(-1, 1)\n                    self.lightmask_images.append(torch.from_numpy(lmask).float())\n                self.lightmask_images = torch.stack(self.lightmask_images, 0)\n            self.total_pixels = self.rgb_images.size(1)\n        else:\n            self.intrinsics_all = data['intrinsics']\n            self.pose_all = data['pose']\n            self.rgb_images = data['rgb']\n            self.n_images = len(self.rgb_images)\n            self.img_res = [data['img_res'][0], data['img_res'][1]]\n            self.total_pixels = self.img_res[0] * self.img_res[1]\n            if 'light_mask' in data:\n                self.lightmask_images = data['light_mask']\n                self.use_lmask = True\n        \n        if (scale := kwargs.get('downsample', 1)) > 1:\n            old_img_res = deepcopy(self.img_res)\n            self.img_res[0] //= scale\n            self.img_res[1] //= scale\n            self.total_pixels = self.img_res[0] * self.img_res[1]\n            self.rgb_images = self.rgb_images.transpose(1, 2).reshape(-1, 3, old_img_res[0], old_img_res[1])\n            self.rgb_images = F.interpolate(self.rgb_images, self.img_res, mode='area')\n            self.rgb_images = self.rgb_images.reshape(-1, 3, self.total_pixels).transpose(1, 2)\n            if self.use_lmask:\n                self.lightmask_images = self.lightmask_images.transpose(1, 2).reshape(-1, 1, old_img_res[0], old_img_res[1])\n                self.lightmask_images = F.interpolate(self.lightmask_images, self.img_res, mode='area')\n                self.lightmask_images = self.lightmask_images.reshape(-1, 1, self.total_pixels).transpose(1, 2)\n\n            self.intrinsics_all = self.intrinsics_all.clone()\n            self.intrinsics_all[:,0,0] /= scale\n            self.intrinsics_all[:,1,1] /= scale\n            self.intrinsics_all[:,0,2] /= scale\n            self.intrinsics_all[:,1,2] /= scale\n        \n        print(f\"[INFO] Plot image size: {self.img_res[1]}x{self.img_res[0]}, {self.total_pixels} pixels in total\")\n        if plot_nimgs == -1:\n            self.plot_nimgs = self.n_images\n        else:\n            self.plot_nimgs = min(plot_nimgs, self.n_images)\n        self.shuffle = kwargs.get('shuffle', True)\n        if self.shuffle:\n            self.shuffle_plot_index()\n\n    def shuffle_plot_index(self):\n        if self.shuffle:\n            self.plot_index = torch.randperm(self.n_images)[:self.plot_nimgs]\n\n    def __len__(self):\n        return self.plot_nimgs\n    \n    def get_uv(self):\n        uv = np.mgrid[0:self.img_res[0], 0:self.img_res[1]].astype(np.int32)\n        uv = torch.from_numpy(np.flip(uv, axis=0).copy()).float()\n        uv = uv.reshape(2, -1).transpose(1, 0)\n        return uv\n\n    def __getitem__(self, idx):\n        if self.shuffle:\n            idx = self.plot_index[idx]\n        uv = self.get_uv()\n\n        sample = {\n            \"uv\": uv,\n            \"intrinsics\": self.intrinsics_all[idx],\n            \"pose\": self.pose_all[idx]\n        }\n\n        ground_truth = {\n            \"rgb\": self.rgb_images[idx]\n      
  }\n\n        if self.use_lmask:\n            ground_truth['light_mask'] = self.lightmask_images[idx]\n\n        return idx, sample, ground_truth\n\n    def collate_fn(self, batch_list):\n        # get list of dictionaries and returns input, ground_true as dictionary for all batch instances\n        batch_list = zip(*batch_list)\n\n        all_parsed = []\n        for entry in batch_list:\n            if type(entry[0]) is dict:\n                # make them all into a new dict\n                ret = {}\n                for k in entry[0].keys():\n                    ret[k] = torch.stack([obj[k] for obj in entry])\n                all_parsed.append(ret)\n            else:\n                all_parsed.append(torch.LongTensor(entry))\n\n        return tuple(all_parsed)\n\n\nclass InterpolateDataset(torch.utils.data.Dataset):\n    \"\"\"\n    View interpolation: specify 2 view ids from training set and generate a video moving between them\n    \"\"\"\n    def __init__(self,\n                 data_dir,\n                #  img_res,\n                 id0,\n                 id1,\n                 num_frames=60,\n                 scan_id=0,\n                 **kwargs\n                 ):\n\n        self.instance_dir = os.path.join('data', data_dir, 'scan{0}'.format(scan_id))\n        assert os.path.exists(self.instance_dir), \"Data directory is empty\"\n\n        image_dir = '{0}/image'.format(self.instance_dir)\n        im = cv2.imread(f\"{image_dir}/{id0:04d}.png\")\n        h, w, _ = im.shape\n        self.img_res = [h, w]\n        self.total_pixels = h * w\n\n        self.cam_file = '{0}/cameras_normalize.npz'.format(self.instance_dir)\n        camera_dict = np.load(self.cam_file)\n        P0 = camera_dict['world_mat_%d' % id0].astype(np.float32) @ camera_dict['scale_mat_%d' % id0].astype(np.float32)\n        P1 = camera_dict['world_mat_%d' % id1].astype(np.float32) @ camera_dict['scale_mat_%d' % id1].astype(np.float32)\n        P0 = P0[:3,:]\n        P1 = P1[:3,:]\n        K, pose0 = rend_util.load_K_Rt_from_P(None, P0)\n        _, pose1 = rend_util.load_K_Rt_from_P(None, P1)\n        rots = Rot.from_matrix(np.stack([pose0[:3,:3].T, pose1[:3,:3].T]))\n        slerp = Slerp([0, 1], rots)\n\n        if (scale := kwargs.get('downsample', 1)) > 1:\n            self.img_res[0] = self.img_res[0] // scale\n            self.img_res[1] = self.img_res[1] // scale\n            self.total_pixels = self.img_res[0] * self.img_res[1]\n            K[0,0] /= scale\n            K[1,1] /= scale\n            K[0,2] /= scale\n            K[1,2] /= scale\n\n        self.intrinsics = torch.from_numpy(K).float()\n        self.pose_all = []\n        for i in range(num_frames):\n            ratio = np.sin(((i / num_frames) - 0.5) * np.pi) * 0.5 + 0.5\n            t = (1 - ratio) * pose0[:3,3] + ratio * pose1[:3,3]\n            R = slerp(ratio).as_matrix()\n            pose = np.eye(4, dtype=np.float32)\n            pose[:3,3] = t\n            pose[:3,:3] = R.T\n            self.pose_all.append(torch.from_numpy(pose).float())\n        self.pose_all = torch.stack(self.pose_all)\n        self.n_frames = num_frames\n\n    def __len__(self):\n        return self.n_frames\n\n    def __getitem__(self, idx):\n        uv = np.mgrid[0:self.img_res[0], 0:self.img_res[1]].astype(np.int32)\n        uv = torch.from_numpy(np.flip(uv, axis=0).copy()).float()\n        uv = uv.reshape(2, -1).transpose(1, 0)\n        sample = {\n            \"uv\": uv,\n            \"intrinsics\": self.intrinsics,\n            \"pose\": 
self.pose_all[idx]\n        }\n        return idx, sample\n\n    def collate_fn(self, batch_list):\n        # get list of dictionaries and returns input, ground_true as dictionary for all batch instances\n        batch_list = zip(*batch_list)\n\n        all_parsed = []\n        for entry in batch_list:\n            if type(entry[0]) is dict:\n                # make them all into a new dict\n                ret = {}\n                for k in entry[0].keys():\n                    ret[k] = torch.stack([obj[k] for obj in entry])\n                all_parsed.append(ret)\n            else:\n                all_parsed.append(torch.LongTensor(entry))\n\n        return tuple(all_parsed)\n\n\nclass RelightDataset(PlotDataset):\n    def __init__(self, data_dir, edit_cfg, scan_id=0, is_val=False, **kwargs):\n        super().__init__(data_dir, 1, scan_id, is_val, None, False, [edit_cfg['index']], True, **kwargs)\n        self.edit_mask = 'mask' in edit_cfg\n        if self.edit_mask:\n            self.mask = rend_util.load_mask(edit_cfg['mask']).astype(np.float32)\n            mh, mw = self.mask.shape\n            if mh != self.img_res[0] or mw != self.img_res[1]:\n                self.mask = cv2.resize(self.mask, (self.img_res[1], self.img_res[0]), interpolation=cv2.INTER_AREA)\n                self.mask = (self.mask > 0.5)\n            self.mask = torch.from_numpy(self.mask).float().flatten()\n            if 'normal' in edit_cfg:\n                self.loadattr(edit_cfg, 'normal', 0)\n                self.normal = self.normal.reshape(-1, 3)\n                self.normal = F.normalize(self.normal, dim=-1, eps=1e-6)\n            if 'rough' in edit_cfg:\n                self.loadattr(edit_cfg, 'rough', 1)\n                self.rough = self.rough.reshape(-1, 1)\n            if 'kd' in edit_cfg:\n                self.loadattr(edit_cfg, 'kd', 2)\n                self.kd = self.kd.reshape(-1, 3)\n            if 'ks' in edit_cfg:\n                self.loadattr(edit_cfg, 'ks', 2)\n                self.ks = self.ks.reshape(-1, 3)\n        self.uv = self.get_uv()\n    \n    def loadattr(self, edit_cfg, attr, mode=0):\n        if mode == 0:\n            im = rend_util.load_normal(edit_cfg[attr])\n        elif mode == 1:\n            im = cv2.imread(edit_cfg[attr], -1)\n            if len(im.shape) == 3:\n                im = im[:,:,-1]\n        else:\n            im = rend_util.load_rgb(edit_cfg[attr]).transpose(1, 2, 0)\n        h, w = im.shape[:2]\n        if h != self.img_res[0] or w != self.img_res[1]:\n            im = cv2.resize(im, (self.img_res[1], self.img_res[0]), interpolation=cv2.INTER_AREA)\n        setattr(self, attr, torch.from_numpy(im).float())\n\n    def __len__(self):\n        return self.total_pixels\n    \n    def __getitem__(self, idx):\n        sample = {\n            \"uv\": self.uv[idx].unsqueeze(0),\n            \"intrinsics\": self.intrinsics_all[0],\n            \"pose\": self.pose_all[0] \n        }\n        ground_truth = {\n            \"rgb\": self.rgb_images[0][idx],\n            'light_mask': self.lightmask_images[0][idx]\n            # 'edit_mask': self.edit_mask[idx]\n        }\n        if self.edit_mask:\n            ground_truth['mask'] = self.mask[idx]\n        if hasattr(self, 'normal'):\n            ground_truth['normal'] = self.normal[idx]\n        if hasattr(self, 'rough'):\n            ground_truth['rough'] = self.rough[idx]\n        if hasattr(self, 'kd'):\n            ground_truth['kd'] = self.kd[idx]\n        if hasattr(self, 'ks'):\n            ground_truth['ks'] = 
self.ks[idx]\n        return idx, sample, ground_truth\n\n    \nclass RelightVideoDataset(PlotDataset):\n    def __init__(self, data_dir, edit_cfg, scan_id=0, is_val=False, **kwargs):\n        self.n_frames = edit_cfg['n_frames']\n        self.img_idx = edit_cfg['index']\n        super().__init__(data_dir, 1, scan_id, is_val, None, False, [edit_cfg['index']] * self.n_frames, True, **kwargs)\n        self.edit_mask = 'mask' in edit_cfg\n        if self.edit_mask:\n            self.mask = rend_util.load_mask(edit_cfg['mask']).astype(np.float32)\n            mh, mw = self.mask.shape\n            if mh != self.img_res[0] or mw != self.img_res[1]:\n                self.mask = cv2.resize(self.mask, (self.img_res[1], self.img_res[0]), interpolation=cv2.INTER_AREA)\n                self.mask = (self.mask > 0.5)\n            self.mask = torch.from_numpy(self.mask).float().flatten()\n        self.uv = self.get_uv()\n    \n    def __len__(self):\n        return self.n_frames\n    \n    def __getitem__(self, idx):\n        sample = {\n            \"uv\": self.uv,\n            \"intrinsics\": self.intrinsics_all[idx],\n            \"pose\": self.pose_all[idx]\n        }\n        ground_truth = {\n            \"rgb\": self.rgb_images[idx],\n            'light_mask': self.lightmask_images[idx]\n            # 'edit_mask': self.edit_mask[idx]\n        }\n        if self.edit_mask:\n            ground_truth['mask'] = self.mask\n        return idx, sample, ground_truth\n"
  },
  {
    "path": "dataset/train_dataset.py",
    "content": "import json\nimport os\nimport cv2\nimport torch\nimport numpy as np\nfrom torch.utils.data import Dataset\nimport torch.nn.functional as F\nimport utils\nfrom utils import rend_util\nfrom tqdm.contrib import tzip, tenumerate\nfrom scipy.spatial.transform import Rotation as Rot\nfrom scipy.spatial.transform import Slerp\n\n\nclass ReconDataset(torch.utils.data.Dataset):\n\n    def __init__(self,\n                 data_dir,\n                 scan_id=0,\n                 use_mask=False,\n                 use_depth=False,\n                 use_normal=False,\n                 use_bubble=False,\n                 use_lightmask=False,\n                 is_hdr=False,\n                 **kwargs\n                 ):\n        self.sampling_idx = slice(None)\n\n        self.instance_dir = os.path.join('data', data_dir, 'scan{0}'.format(scan_id))\n        assert os.path.exists(self.instance_dir), \"Data directory is empty\"\n        print(f\"[INFO] Loading data from {self.instance_dir}\")\n\n        image_dir = '{0}/image'.format(self.instance_dir) if not is_hdr else '{0}/hdr'.format(self.instance_dir)\n        self.is_hdr = is_hdr\n        if is_hdr:\n            print(\"[INFO] Using HDR image\")\n        image_paths = sorted(utils.glob_imgs(image_dir))\n        self.n_images = len(image_paths)\n\n        self.cam_file = '{0}/cameras_normalize.npz'.format(self.instance_dir)\n        camera_dict = np.load(self.cam_file)\n        scale_mats = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)]\n        world_mats = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)]\n\n        self.intrinsics_all = []\n        self.pose_all = []\n        for scale_mat, world_mat in zip(scale_mats, world_mats):\n            P = world_mat @ scale_mat\n            P = P[:3, :4]\n            intrinsics, pose = rend_util.load_K_Rt_from_P(None, P)\n            self.intrinsics_all.append(torch.from_numpy(intrinsics).float())\n            self.pose_all.append(torch.from_numpy(pose).float())\n        self.intrinsics_all = torch.stack(self.intrinsics_all, 0)\n        self.pose_all = torch.stack(self.pose_all, 0)\n\n        self.rgb_images = []\n        for path in image_paths:\n            rgb = rend_util.load_rgb(path, is_hdr=is_hdr)\n            self.img_res = [rgb.shape[1], rgb.shape[2]]\n            rgb = rgb.reshape(3, -1).transpose(1, 0)\n            self.rgb_images.append(torch.from_numpy(rgb).float())\n        self.rgb_images = torch.stack(self.rgb_images, 0)\n        self.total_pixels = self.rgb_images.size(1)\n        print(f\"[INFO] image size: {self.img_res[1]}x{self.img_res[0]}, {self.total_pixels} pixels in total\")\n        \n        uv = np.mgrid[0:self.img_res[0], 0:self.img_res[1]].astype(np.int32)\n        uv = np.flip(uv, axis=0).copy()\n        self.uv = torch.from_numpy(uv).float()\n        self.uv = self.uv.reshape(2, -1).transpose(1, 0) # (h*w, 2)\n\n        mask_dir = '{0}/mask'.format(self.instance_dir)\n        self.use_mask = use_mask\n        if self.use_mask:\n            if os.path.exists(mask_dir):\n                mask_paths = sorted(utils.glob_imgs(mask_dir))\n                # assert len(mask_paths) == self.n_images\n                self.mask_images = []\n                for path in mask_paths:\n                    mask = rend_util.load_mask(path)\n                    mask = mask.reshape(-1, 1)\n                    self.mask_images.append(torch.from_numpy(mask).float())\n                self.mask_images = 
torch.stack(self.mask_images, 0)\n            else:\n                print(\"[INFO] No existing mask image, use one mask as default\")\n                self.mask_images = torch.ones(self.n_images, self.total_pixels, 1, dtype=torch.float)\n        \n        lmask_dir = '{0}/light_mask'.format(self.instance_dir)\n        self.use_lightmask = use_lightmask and os.path.exists(lmask_dir)\n        if self.use_lightmask:\n            lmask_paths = sorted(utils.glob_imgs(lmask_dir))\n            self.lightmask_images = []\n            for path in lmask_paths:\n                lmask = rend_util.load_mask(path)\n                lmask = lmask.reshape(-1, 1)\n                self.lightmask_images.append(torch.from_numpy(lmask).float())\n            self.lightmask_images = torch.stack(self.lightmask_images, 0)\n        \n        depth_dir = '{0}/depth'.format(self.instance_dir)\n        self.use_depth = use_depth and os.path.exists(depth_dir)\n        self.use_bubble = use_bubble and os.path.exists(depth_dir)\n        if self.use_depth or self.use_bubble:\n            self.depth_images = []\n            self.depth_masks = []\n            self.pointcloud = [] # pointcloud for bubble loss, unprojected from depth and poses\n            self.pointlinks = [] # link from pixel index to pointcloud index, value -1 when the pixel is invalid at pointcloud\n            self.pixlinks = [] # link from pointcloud index to pixel index\n            depth_paths = sorted(utils.glob_depths(depth_dir))\n            n_points = 0\n            if kwargs.get('noise_scale', 0.0) > 0:\n                print(f\"[INFO] Ablation study: using noise scale {kwargs.get('noise_scale')}\")\n            for scale_mat, depth_path, intrinsics, pose, i in tzip(scale_mats, depth_paths, self.intrinsics_all, self.pose_all, range(len(self.pose_all))):\n                depth = rend_util.load_depth(depth_path)\n                depth = torch.from_numpy(depth.reshape(-1)).float()\n                depth = depth / scale_mat[2,2]\n                valid_indices = torch.where((depth > 1e-3) & (depth < 6))[0]\n                if i == 0 and scale_mat[2,2] != 1:\n                    print(f\"[INFO] Depth scaled by {scale_mat[2,2]:.2f}\")\n                depth_mask = torch.zeros([self.total_pixels], dtype=torch.bool)\n                depth_mask[valid_indices] = True\n                # if self.use_depth:\n                if (noise_scale := kwargs.get('noise_scale', 0.0)) > 0:\n                    depth = rend_util.add_depth_noise(depth, depth_mask.float(), noise_scale)\n                self.depth_images.append(depth)\n                self.depth_masks.append(depth_mask)\n                if self.use_bubble:\n                    pointlink = -torch.ones([self.total_pixels], dtype=torch.long)\n                    pointlink[depth_mask] = torch.arange(0, len(valid_indices), dtype=torch.long) + n_points\n                    pixlink = torch.arange(i * self.total_pixels, (i + 1) * self.total_pixels, dtype=torch.long)[depth_mask]\n                    n_points += len(valid_indices)\n                    self.pointlinks.append(pointlink)\n                    self.pixlinks.append(pixlink)\n                    self.pointcloud.append(rend_util.depth_to_world(self.uv, intrinsics, pose, depth, depth_mask))\n\n            self.depth_images = torch.stack(self.depth_images, 0)\n            self.depth_masks = torch.stack(self.depth_masks, 0)\n            if self.use_bubble:\n                self.pointcloud = torch.cat(self.pointcloud, 0)\n                self.pointlinks = 
torch.cat(self.pointlinks, 0)\n                self.pixlinks = torch.cat(self.pixlinks, 0)\n                self.pointcloud = self.pointcloud[:,:3] / self.pointcloud[:,3:]\n                self.pdf_prune = kwargs.get('pdf_prune', 0)\n                self.pdf_max = kwargs.get('pdf_max', None)\n                print(f\"[INFO] PDF clamped to {self.pdf_prune}\")\n        \n        normal_dir = '{0}/normal'.format(self.instance_dir)\n        self.use_normal = use_normal and os.path.exists(normal_dir)\n        if self.use_normal:\n            self.normal_images = []\n            self.normal_masks = []\n            normal_paths = sorted(utils.glob_normal(normal_dir))\n            for pose, normal_path in tzip(self.pose_all, normal_paths):\n                normal = rend_util.load_normal(normal_path)\n                normal = torch.from_numpy(normal.reshape(-1, 3)).float()\n                valid_indices = torch.where(torch.linalg.vector_norm(normal, dim=1) > 1e-3)[0]\n                R = pose[:3,:3]\n                normal = (R @ normal.T).T # convert normal from view space to world space\n                normal = F.normalize(normal, dim=1, eps=1e-6)\n                self.normal_images.append(normal)\n                normal_mask = torch.zeros([self.total_pixels], dtype=torch.bool)\n                normal_mask[valid_indices] = True\n                self.normal_masks.append(normal_mask)\n            self.normal_images = torch.stack(self.normal_images, 0)\n            self.normal_masks = torch.stack(self.normal_masks, 0)\n\n    def __len__(self):\n        return self.n_images * self.total_pixels\n\n    def __getitem__(self, idx):\n        pidx = idx % self.total_pixels\n        tidx = idx\n        idx = idx // self.total_pixels\n        sample = {\n            \"uv\": self.uv[pidx].unsqueeze(0),\n            \"intrinsics\": self.intrinsics_all[idx],\n            \"pose\": self.pose_all[idx]\n        }\n        ground_truth = {\n            \"rgb\": self.rgb_images[idx][pidx]\n        }\n        if self.use_mask:\n            ground_truth['mask'] = self.mask_images[idx][pidx]\n        if self.use_lightmask:\n            ground_truth['light_mask'] = self.lightmask_images[idx][pidx]\n        if self.use_depth or self.use_bubble:\n            ground_truth['depth'] = self.depth_images[idx][pidx]\n            ground_truth['depth_mask'] = self.depth_masks[idx][pidx]\n        if self.use_normal:\n            ground_truth['normal'] = self.normal_images[idx][pidx]\n            ground_truth['normal_mask'] = self.normal_masks[idx][pidx]\n\n        return tidx, idx, sample, ground_truth\n\n    def collate_fn(self, batch_list):\n        # get list of dictionaries and returns input, ground_true as dictionary for all batch instances\n        batch_list = zip(*batch_list)\n\n        all_parsed = []\n        for entry in batch_list:\n            if type(entry[0]) is dict:\n                # make them all into a new dict\n                ret = {}\n                for k in entry[0].keys():\n                    ret[k] = torch.stack([obj[k] for obj in entry])\n                all_parsed.append(ret)\n            else:\n                all_parsed.append(torch.LongTensor(entry))\n\n        return tuple(all_parsed)\n\n\nclass MaterialDataset(torch.utils.data.Dataset):\n\n    def __init__(self,\n                 data_dir,\n                 scan_id=0,\n                 downsample_train=1,\n                 is_hdr=False,\n                 **kwargs\n                 ):\n        self.sampling_idx = slice(None)\n\n        
self.instance_dir = os.path.join('data', data_dir, 'scan{0}'.format(scan_id))\n\n        assert os.path.exists(self.instance_dir), \"Data directory is empty\"\n\n        image_dir = '{0}/image'.format(self.instance_dir) if not is_hdr else '{0}/hdr'.format(self.instance_dir)\n        self.is_hdr = is_hdr\n        if is_hdr:\n            print(\"[INFO] Using HDR image\")\n        image_paths = sorted(utils.glob_imgs(image_dir))\n        self.n_images = len(image_paths)\n\n        self.cam_file = '{0}/cameras_normalize.npz'.format(self.instance_dir)\n        camera_dict = np.load(self.cam_file)\n        scale_mats = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)]\n        world_mats = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(self.n_images)]\n\n        self.intrinsics_all = []\n        self.pose_all = []\n\n        for scale_mat, world_mat in zip(scale_mats, world_mats):\n            P = world_mat @ scale_mat\n            P = P[:3, :4]\n            intrinsics, pose = rend_util.load_K_Rt_from_P(None, P)\n            self.intrinsics_all.append(torch.from_numpy(intrinsics).float())\n            self.pose_all.append(torch.from_numpy(pose).float())\n\n        self.intrinsics_all = torch.stack(self.intrinsics_all, 0)\n        self.pose_all = torch.stack(self.pose_all, 0)\n\n        self.rgb_images = []\n        for path in image_paths:\n            rgb = rend_util.load_rgb(path, is_hdr=is_hdr)\n            self.img_res = [rgb.shape[1], rgb.shape[2]]\n            rgb = rgb.reshape(3, -1).transpose(1, 0)\n            self.rgb_images.append(torch.from_numpy(rgb).float())\n        self.rgb_images = torch.stack(self.rgb_images, 0)\n        self.total_pixels = self.rgb_images.size(1)\n\n        mask_dir = '{0}/mask'.format(self.instance_dir)\n        self.use_mask = os.path.exists(mask_dir)\n        if self.use_mask:\n            mask_paths = sorted(utils.glob_imgs(mask_dir))\n            # assert len(mask_paths) == self.n_images\n            self.mask_images = []\n            for path in mask_paths:\n                mask = rend_util.load_mask(path)\n                mask = mask.reshape(-1, 1)\n                self.mask_images.append(torch.from_numpy(mask).float())\n            self.mask_images = torch.stack(self.mask_images, 0)\n        \n        lmask_dir = '{0}/light_mask'.format(self.instance_dir)\n        self.use_lightmask = os.path.exists(lmask_dir)\n        if self.use_lightmask:\n            print(\"[INFO] Light mask detected\")\n            lmask_paths = sorted(utils.glob_imgs(lmask_dir))\n            self.lightmask_images = []\n            for path in lmask_paths:\n                lmask = rend_util.load_mask(path)\n                lmask = lmask.reshape(-1, 1)\n                self.lightmask_images.append(torch.from_numpy(lmask).float())\n            self.lightmask_images = torch.stack(self.lightmask_images, 0)\n        \n        if downsample_train > 1:\n            old_res = (self.img_res[0], self.img_res[1])\n            self.rgb_images = self.rgb_images.transpose(1, 2).reshape(-1, 3, *old_res)\n            self.img_res[0] //= downsample_train\n            self.img_res[1] //= downsample_train\n            self.total_pixels = self.img_res[0] * self.img_res[1]\n            self.rgb_images = F.interpolate(self.rgb_images, self.img_res, mode='area')\n            self.rgb_images = self.rgb_images.reshape(-1, 3, self.total_pixels).transpose(1, 2)\n            self.intrinsics_all = self.intrinsics_all.clone()\n            
self.intrinsics_all[:,0,0] /= downsample_train\n            self.intrinsics_all[:,1,1] /= downsample_train\n            self.intrinsics_all[:,0,2] /= downsample_train\n            self.intrinsics_all[:,1,2] /= downsample_train\n            if self.use_mask:\n                self.mask_images = self.mask_images.transpose(1, 2).reshape(-1, 1, *old_res)\n                self.mask_images = F.interpolate(self.mask_images, self.img_res, mode='area')\n                self.mask_images = self.mask_images.reshape(-1, 1, self.total_pixels).transpose(1, 2)\n            if self.use_lightmask:\n                self.lightmask_images = self.lightmask_images.transpose(1, 2).reshape(-1, 1, *old_res)\n                self.lightmask_images = F.interpolate(self.lightmask_images, self.img_res, mode='area')\n                self.lightmask_images[self.lightmask_images > 0] = 1\n                self.lightmask_images = self.lightmask_images.reshape(-1, 1, self.total_pixels).transpose(1, 2)\n\n        print(f\"[INFO] image size: {self.img_res[1]}x{self.img_res[0]}, {self.total_pixels} pixels in total\")\n        \n        uv = np.mgrid[0:self.img_res[0], 0:self.img_res[1]].astype(np.int32)\n        uv = np.flip(uv, axis=0).copy()\n        self.uv = torch.from_numpy(uv).float()\n        self.uv = self.uv.reshape(2, -1).transpose(1, 0) # (h*w, 2)\n\n    def __len__(self):\n        return self.n_images * self.total_pixels\n\n    def __getitem__(self, idx):\n        pidx = idx % self.total_pixels\n        tidx = idx\n        idx = idx // self.total_pixels\n        sample = {\n            \"uv\": self.uv[pidx].unsqueeze(0),\n            \"intrinsics\": self.intrinsics_all[idx],\n            \"pose\": self.pose_all[idx]\n        }\n        ground_truth = {\n            \"rgb\": self.rgb_images[idx][pidx]\n        }\n        if self.use_mask:\n            ground_truth['mask'] = self.mask_images[idx][pidx]\n        if self.use_lightmask:\n            ground_truth['light_mask'] = self.lightmask_images[idx][pidx]\n        return tidx, idx, sample, ground_truth\n\n    def collate_fn(self, batch_list):\n        # get list of dictionaries and returns input, ground_true as dictionary for all batch instances\n        batch_list = zip(*batch_list)\n\n        all_parsed = []\n        for entry in batch_list:\n            if type(entry[0]) is dict:\n                # make them all into a new dict\n                ret = {}\n                for k in entry[0].keys():\n                    ret[k] = torch.stack([obj[k] for obj in entry])\n                all_parsed.append(ret)\n            else:\n                all_parsed.append(torch.LongTensor(entry))\n\n        return tuple(all_parsed)"
  },
  {
    "path": "environment.yml",
    "content": "name: i2sdf\nchannels:\n  - pytorch\n  - conda-forge\n  - defaults\ndependencies:\n  - cudatoolkit=11.3.1=h9edb442_10\n  - ffmpeg=4.3=hf484d3e_0\n  - numpy=1.23.5=py39h14f4228_0\n  - pip=23.0.1=py39h06a4308_0\n  - python=3.9.16=h7a1cb2a_2\n  - pytorch=1.12.1=py3.9_cuda11.3_cudnn8.3.2_0\n  - torchaudio=0.12.1=py39_cu113\n  - torchvision=0.13.1=py39_cu113\n  - pip:\n    - fast-pytorch-kmeans==0.1.9\n    - ffmpeg-python==0.2.0\n    - gputil==1.4.0\n    - lpips==0.1.4\n    - open3d==0.17.0\n    - opencv-python==4.7.0.72\n    - pymcubes==0.1.4\n    - pytorch-lightning==1.9.0\n    - pyyaml==6.0\n    - rich==13.3.3\n    - scikit-image==0.20.0\n    - scikit-learn==1.2.2\n    - scipy==1.9.1\n    - tensorboard==2.12.0\n    - tensorboardx==2.6\n    - torchmetrics==0.11.4\n    - tqdm==4.65.0\n    - trimesh==3.21.4\n"
  },
  {
    "path": "i2-sdf-dataset-links.csv",
    "content": "file,url\ninteriorverse/i2-sdf/i2-sdf/bedroom_0.zip,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/bedroom_0.zip\ninteriorverse/i2-sdf/i2-sdf/bedroom_0_preview.png,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/bedroom_0_preview.png\ninteriorverse/i2-sdf/i2-sdf/bedroom_1.zip,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/bedroom_1.zip\ninteriorverse/i2-sdf/i2-sdf/bedroom_1_preview.png,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/bedroom_1_preview.png\ninteriorverse/i2-sdf/i2-sdf/bedroom_relight_0.zip,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/bedroom_relight_0.zip\ninteriorverse/i2-sdf/i2-sdf/bedroom_relight_0_preview.png,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/bedroom_relight_0_preview.png\ninteriorverse/i2-sdf/i2-sdf/bedroom_relight_1.zip,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/bedroom_relight_1.zip\ninteriorverse/i2-sdf/i2-sdf/bedroom_relight_1_preview.png,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/bedroom_relight_1_preview.png\ninteriorverse/i2-sdf/i2-sdf/diningroom_0.zip,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/diningroom_0.zip\ninteriorverse/i2-sdf/i2-sdf/diningroom_0_preview.png,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/diningroom_0_preview.png\ninteriorverse/i2-sdf/i2-sdf/kitchen_0.zip,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/kitchen_0.zip\ninteriorverse/i2-sdf/i2-sdf/kitchen_0_preview.png,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/kitchen_0_preview.png\ninteriorverse/i2-sdf/i2-sdf/livingroom_0.zip,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/livingroom_0.zip\ninteriorverse/i2-sdf/i2-sdf/livingroom_0_preview.png,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/livingroom_0_preview.png\ninteriorverse/i2-sdf/i2-sdf/livingroom_1.zip,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/livingroom_1.zip\ninteriorverse/i2-sdf/i2-sdf/livingroom_1_preview.png,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/livingroom_1_preview.png\ninteriorverse/i2-sdf/i2-sdf/real_data/inria_livingroom.zip,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/real_data/inria_livingroom.zip\ninteriorverse/i2-sdf/i2-sdf/real_data/nisr_coffee_shop_0.zip,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/real_data/nisr_coffee_shop_0.zip\ninteriorverse/i2-sdf/i2-sdf/real_data/nisr_coffee_shop_1.zip,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/real_data/nisr_coffee_shop_1.zip\ninteriorverse/i2-sdf/i2-sdf/real_data/nisr_livingroom.zip,https://kloudsim-usa-cos.kujiale.com/interiorverse/i2-sdf/i2-sdf/real_data/nisr_livingroom.zip\n"
  },
  {
    "path": "main_recon.py",
    "content": "import torch\nimport yaml\nimport pytorch_lightning as pl\nimport argparse\nimport os\nimport utils\nimport model\nfrom pytorch_lightning import loggers\nfrom pytorch_lightning.callbacks import ModelCheckpoint\nfrom rich.progress import TextColumn\nimport GPUtil\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--conf\", type=str, required=True, help=\"Path to (.yml) config file.\")\n    parser.add_argument('-d', \"--device_ids\", type=int, nargs='+', default=None, help=\"GPU devices to use\")\n    parser.add_argument(\"--exps_folder\", type=str, default=\"exps\")\n    parser.add_argument('--expname', type=str, default='')\n    parser.add_argument('--scan_id', type=int, default=-1, help='If set, taken to be the scan id.')\n    parser.add_argument('--test', action='store_true')\n    parser.add_argument('--test_mode', choices=['render', 'mesh', 'interpolate'], default='render')\n    parser.add_argument('-v', '--version', type=int, nargs='?')\n    parser.add_argument('--inter_id', type=int, nargs=2, required=False, help='2 view ids for interpolation video.')\n    parser.add_argument('-i', '--indices', nargs='*', type=int, help='If set, render only specified indices of the dataset instead of all images.')\n    parser.add_argument('--n_frames', type=int, default=60, help='Number of frames in the interpolation video.')\n    parser.add_argument('--frame_rate', type=int, default=24, help='Frame rate of the interpolation video.')\n    parser.add_argument('-f', '--full_res', action='store_true', help='If set, dataset downscaling will be ignored.')\n    parser.add_argument('--is_val', action='store_true', help='If set, render the validation set instead of training set.')\n    parser.add_argument('--val_mesh', action='store_true', help='If set, extract and save mesh every validation epoch.')\n    parser.add_argument('--score', action='store_true', help='If set, evaluate the meshing score (need to provide GT mesh).')\n    parser.add_argument('--far_clip', type=float, default=5.0)\n    parser.add_argument('--ckpt', type=str, default='last')\n    parser.add_argument('--resolution', type=int, default=512, help='Resolution for marching cube algorithm')\n    parser.add_argument('--spp', type=int, default=128)\n    parser.add_argument('--seed', type=int, default=42)\n    args = parser.parse_args()\n\n    with open(args.conf) as f:\n        cfg_dict = yaml.load(f, Loader=yaml.FullLoader)\n        cfg = utils.CfgNode(cfg_dict)\n    \n    expname = args.expname if args.expname else cfg.train.expname\n    scan_id = cfg.dataset.scan_id if args.scan_id == -1 else args.scan_id\n    cfg.dataset.scan_id = scan_id\n    expname = expname + '_' + str(scan_id)\n\n    if args.version is None and (v := args.conf.find(\"version_\")) != -1:\n        args.version = int(args.conf[v + 8:args.conf.find(\"/config\")])\n        print(f\"[INFO] Loaded version {args.version} from config file\")\n    \n    if args.version is not None:\n        logger = loggers.TensorBoardLogger(save_dir=args.exps_folder, name=expname, version=args.version)\n    else:\n        logger = loggers.TensorBoardLogger(save_dir=args.exps_folder, name=expname)\n    \n    if args.device_ids is None:\n        args.device_ids = GPUtil.getAvailable(order='memory', limit=1, maxLoad=0.5, maxMemory=0.5, includeNan=False,\n                                              excludeID=[], excludeUUID=[])\n        print(\"Selected GPU {} automatically\".format(args.device_ids[0]))\n    
torch.cuda.set_device(args.device_ids[0])\n    torch.set_float32_matmul_precision('medium')\n    progbar_callback = utils.RichProgressBarWithScanId(scan_id, leave=False)\n    pl.seed_everything(args.seed)\n    \n    if args.test:\n        version = args.version if args.version is not None else logger.version - 1\n        exp_dir = os.path.join(logger.root_dir, f\"version_{version}\")\n        del logger\n        # Resolve and load the checkpoint once; every test mode restores the same state dict\n        if not args.ckpt.endswith('.ckpt'):\n            args.ckpt += '.ckpt'\n        ckpt = torch.load(os.path.join(exp_dir, 'checkpoints', args.ckpt), map_location='cuda')\n        if args.test_mode == 'render':\n            system = model.VolumeRenderSystem(cfg, exp_dir, indices=args.indices, is_val=args.is_val, full_res=args.full_res)\n            system.load_state_dict(ckpt['state_dict'])\n            model.lpips.cuda()\n        elif args.test_mode == 'mesh':\n            system = model.SDFMeshSystem(cfg, exp_dir, args.resolution, args.score, far_clip=args.far_clip)\n            system.load_state_dict(ckpt['state_dict'])\n            system.cuda()\n            system.eval()\n            system.initialize()\n        # elif args.test_mode == 'interpolate':\n        else:\n            system = model.ViewInterpolateSystem(cfg, exp_dir, *args.inter_id, n_frames=args.n_frames, frame_rate=args.frame_rate)\n            system.load_state_dict(ckpt['state_dict'])\n        trainer = pl.Trainer(\n            logger=False,\n            accelerator='gpu',\n            devices=args.device_ids,\n            callbacks=[progbar_callback]\n        )\n        trainer.test(system)\n    else:\n        max_steps = cfg.train.get('steps', 200000)\n        print(f\"Training for {max_steps} steps\")\n        exp_dir = logger.log_dir\n        checkpoint_callback = ModelCheckpoint(os.path.join(exp_dir, 'checkpoints'), save_last=True, every_n_train_steps=cfg.train.checkpoint_freq)\n        if hasattr(cfg.train, 'plot_freq'):\n            kwargs = {'val_check_interval': cfg.train.plot_freq}\n        else:\n            kwargs = {'check_val_every_n_epoch': cfg.train.plot_epochs}\n        trainer = pl.Trainer(\n            logger=logger,\n            accelerator='gpu',\n            devices=args.device_ids,\n            strategy=None,\n            callbacks=[checkpoint_callback, progbar_callback],\n            max_steps=max_steps,\n            **kwargs\n        )\n        system = model.ReconstructionTrainer(\n            cfg, progbar_callback,\n            exp_dir=exp_dir,\n            is_val=args.is_val,\n            val_mesh=args.val_mesh\n        )\n        trainer.fit(system)\n    torch.cuda.empty_cache()"
  },
  {
    "path": "model/__init__.py",
    "content": "from .network import *\nfrom .trainer import *\n# from .material import *\nfrom .eval import *\nfrom .rendering import RenderingLayer"
  },
  {
    "path": "model/eval/__init__.py",
    "content": "from .recon import *"
  },
  {
    "path": "model/eval/recon.py",
    "content": "import torch\nimport pytorch_lightning as pl\nimport numpy as np\nimport os\nfrom glob import glob\nfrom torch.utils.data import DataLoader\nimport utils\nfrom utils import rend_util\nimport utils.plots as plt\nimport dataset\nimport model\nfrom skimage import measure\nimport cv2\nimport trimesh\nfrom rich.progress import track\nfrom torchmetrics.functional import structural_similarity_index_measure as ssim\nfrom torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity as LPIPS\n\nlpips = LPIPS()\n\nclass SDFMeshSystem(pl.LightningModule):\n    def __init__(self, conf, exp_dir, resolution, score=False, far_clip=5.0) -> None:\n        super().__init__()\n        self.expdir = exp_dir\n        conf_model = conf.model\n        conf_model.use_normal = False\n        self.model = model.I2SDFNetwork(conf_model)\n        self.resolution = resolution\n        self.grid_boundary = conf.plot.grid_boundary\n        self.initialized = False\n        self.instance_dir = os.path.join('data', conf.dataset.data_dir, 'scan{0}'.format(conf.dataset.scan_id))\n        camera_dict = np.load(os.path.join(self.instance_dir, 'cameras_normalize.npz'))\n        self.scale_mat = camera_dict['scale_mat_0']\n        self.scan_id = conf.dataset.scan_id\n        self.score = score\n        if score:\n            self.n_imgs = len(os.listdir(os.path.join(self.instance_dir, 'image')))\n            self.poses = []\n            self.far_clip = far_clip\n            for i in range(self.n_imgs):\n                K, pose = rend_util.load_K_Rt_from_P(None, camera_dict[f'world_mat_{i}'][:3,:])\n                self.poses.append(pose)\n            self.K = K\n            self.H, self.W, _ = cv2.imread(os.path.join(self.instance_dir, 'image', '0000.png')).shape\n\n    def initialize(self):\n        grid = plt.get_grid_uniform(100, self.grid_boundary)\n        z = []\n        points = grid['grid_points']\n        for pnts in track(torch.split(points, 1000000, dim=0)):\n            z.append(self.model.implicit_network(pnts)[:,0].detach().cpu().numpy())\n        z = np.concatenate(z, axis=0).astype(np.float32)\n        verts, faces, normals, values = measure.marching_cubes(\n        volume=z.reshape(grid['xyz'][1].shape[0], grid['xyz'][0].shape[0],\n                         grid['xyz'][2].shape[0]).transpose([1, 0, 2]),\n                         level=0,\n                         spacing=(grid['xyz'][0][2] - grid['xyz'][0][1],\n                         grid['xyz'][0][2] - grid['xyz'][0][1],\n                         grid['xyz'][0][2] - grid['xyz'][0][1]))\n        verts = verts + np.array([grid['xyz'][0][0], grid['xyz'][1][0], grid['xyz'][2][0]])\n        mesh_low_res = trimesh.Trimesh(verts, faces, normals)\n        recon_pc = trimesh.sample.sample_surface(mesh_low_res, 10000)[0]\n        recon_pc = torch.from_numpy(recon_pc).float().cuda()\n        s_mean = recon_pc.mean(dim=0)\n        s_cov = recon_pc - s_mean\n        s_cov = torch.mm(s_cov.transpose(0, 1), s_cov)\n        self.vecs = torch.view_as_real(torch.linalg.eig(s_cov)[1].transpose(0, 1))[:, :, 0]\n        if torch.det(self.vecs) < 0:\n            self.vecs = torch.mm(torch.tensor([[1, 0, 0], [0, 0, 1], [0, 1, 0]]).cuda().float(), self.vecs)\n        helper = torch.bmm(self.vecs.unsqueeze(0).repeat(recon_pc.shape[0], 1, 1),\n                        (recon_pc - s_mean).unsqueeze(-1)).squeeze().cpu()\n        grid_aligned = plt.get_grid(helper, self.resolution)\n        grid_points = grid_aligned['grid_points']\n        g = []\n        for pnts 
in track(torch.split(grid_points, 1000000, dim=0)):\n            g.append(torch.bmm(self.vecs.unsqueeze(0).repeat(pnts.shape[0], 1, 1).transpose(1, 2),\n                           pnts.unsqueeze(-1)).squeeze() + s_mean)\n        grid_points = torch.cat(g, dim=0)\n        points = grid_points.cpu()\n        self.test_dataset = dataset.GridDataset(points, grid_aligned['xyz'])\n        self.grid_points = grid_points\n        self.initialized = True\n\n    def test_dataloader(self):\n        assert self.initialized\n        print(len(self.test_dataset))\n        return DataLoader(self.test_dataset, batch_size=2000000, shuffle=False, num_workers=32)\n    \n    def test_step(self, batch, batch_idx):\n        return self.model.implicit_network(batch)[:,0].detach().cpu().numpy()\n    \n    def test_epoch_end(self, outputs) -> None:\n        # z = torch.cat(outputs, dim=0).cpu().numpy()\n        z = np.concatenate(outputs, axis=0).astype(np.float32)\n        # Only run marching cubes if the SDF changes sign somewhere inside the query grid\n        if (not (np.min(z) > 0 or np.max(z) < 0)):\n            verts, faces, normals, values = measure.marching_cubes(\n                            volume=z.reshape(self.test_dataset.xyz[1].shape[0], self.test_dataset.xyz[0].shape[0], self.test_dataset.xyz[2].shape[0]).transpose([1, 0, 2]),\n                            level=0,\n                            spacing=(self.test_dataset.xyz[0][2] - self.test_dataset.xyz[0][1],\n                                     self.test_dataset.xyz[0][2] - self.test_dataset.xyz[0][1],\n                                     self.test_dataset.xyz[0][2] - self.test_dataset.xyz[0][1]))\n            verts = torch.from_numpy(verts).float().cuda()\n            verts = torch.bmm(self.vecs.unsqueeze(0).repeat(verts.shape[0], 1, 1).transpose(1, 2),\n                   verts.unsqueeze(-1)).squeeze()\n            verts = (verts + self.grid_points[0]).cpu().numpy()\n            # Marching-cubes normals are left in the grid-aligned frame, so drop them and let trimesh recompute\n            mesh = trimesh.Trimesh(verts, faces)\n            mesh.apply_transform(self.scale_mat)\n            mesh_folder = os.path.join(self.expdir, 'eval/mesh')\n            os.makedirs(mesh_folder, exist_ok=True)\n            mesh.export(os.path.join(mesh_folder, 'scan{0}.ply'.format(self.scan_id)), 'ply')\n            if self.score:\n                from utils import mesh_util\n                import open3d as o3d\n                mesh = mesh_util.refuse(mesh, self.poses, self.K, self.H, self.W)\n                out_mesh_path = os.path.join(mesh_folder, 'scan{0}_refined.ply'.format(self.scan_id))\n                o3d.io.write_triangle_mesh(out_mesh_path, mesh)\n                mesh = trimesh.load(out_mesh_path)\n                print(\"[INFO] Pred mesh refined\")\n                gt_mesh = trimesh.load(os.path.join(self.instance_dir, 'mesh.ply'))\n                gt_mesh = mesh_util.refuse(gt_mesh, self.poses, self.K, self.H, self.W, self.far_clip)\n                out_mesh_path = os.path.join(mesh_folder, 'scan{0}_gt.ply'.format(self.scan_id))\n                o3d.io.write_triangle_mesh(out_mesh_path, gt_mesh)\n                gt_mesh = trimesh.load(out_mesh_path)\n                print(\"[INFO] GT mesh refined\")\n                metrics = mesh_util.evaluate(mesh, gt_mesh)\n                with open(f\"{mesh_folder}/metrics.txt\", 'w') as f:\n                    for k in metrics:\n                        f.write(f\"{k.upper()}: {metrics[k]}\\n\")\n                print(f\"[INFO] Metrics saved to {mesh_folder}/metrics.txt\\n\")\n\n    def forward(self):\n        raise NotImplementedError(\"forward not supported by trainer\")\n\n\nclass 
VolumeRenderSystem(pl.LightningModule):\n    def __init__(self, conf, exp_dir, indices=None, is_val=False, score_mesh=False, full_res=False) -> None:\n        super().__init__()\n        self.expdir = exp_dir\n        conf_model = conf.model\n        conf_model.use_normal = False\n        self.model = model.I2SDFNetwork(conf_model)\n        self.scan_id = conf.dataset.scan_id\n        dataset_conf = conf.dataset\n        if full_res:\n            dataset_conf.downsample = 1\n        self.test_dataset = dataset.PlotDataset(**dataset_conf, plot_nimgs=-1, shuffle=False, indices=indices, is_val=is_val)\n        self.total_pixels = self.test_dataset.total_pixels\n        self.img_res = self.test_dataset.img_res\n        self.split_n_pixels = conf.train.split_n_pixels\n        self.expdir = os.path.join(self.expdir, 'eval')\n        if is_val:\n            self.expdir = os.path.join(self.expdir, 'test')\n        os.makedirs(os.path.join(self.expdir, 'rendering'), exist_ok=True)\n        os.makedirs(os.path.join(self.expdir, 'depth'), exist_ok=True)\n        os.makedirs(os.path.join(self.expdir, 'normal'), exist_ok=True)\n\n    def test_dataloader(self):\n        print(len(self.test_dataset))\n        return DataLoader(self.test_dataset, batch_size=1, shuffle=False, collate_fn=self.test_dataset.collate_fn)\n\n    @torch.inference_mode(False)\n    @torch.no_grad()\n    def test_step(self, batch, batch_idx):\n        indices, model_input, ground_truth = batch\n        # idx = batch_idx\n        idx = self.test_dataset.indices[batch_idx]\n        split = utils.split_input(model_input, self.total_pixels, self.split_n_pixels)\n        res = []\n        for s in split:\n            out = utils.detach_dict(self.model(s))\n            d = {\n                'rgb_values': out['rgb_values'].detach(),\n                'depth_values': out['depth_values'].detach()\n            }\n            d['normal_map'] = out['normal_map'].detach()\n            # d['surface_point'] = out['surface_point'].detach()\n            del out\n            res.append(d)\n        model_outputs = utils.merge_output(res, self.total_pixels, 1)\n        _, num_samples, _ = ground_truth['rgb'].shape\n        model_outputs['rgb_values'] = model_outputs['rgb_values'].reshape(1, num_samples, 3)\n        model_outputs['depth_values'] = model_outputs['depth_values'].reshape(1, num_samples, 1)\n        plt.plot_imgs_wo_gt(model_outputs['normal_map'].reshape(1, num_samples, 3), self.expdir, \"{:04d}w\".format(idx), 1, self.img_res, is_hdr=True)\n        normal_map = model_outputs['normal_map'].reshape(num_samples, 3).T # (3, h*w)\n        R = model_input['pose'].squeeze()[:3,:3].T\n        normal_map = R @ normal_map\n        model_outputs['normal_map'] = normal_map.T.reshape(1, num_samples, 3)\n        plt.plot_imgs_wo_gt(model_outputs['normal_map'], self.expdir, \"{:04d}\".format(idx), 1, self.img_res, is_hdr=True)\n        model_outputs['normal_map'] = (model_outputs['normal_map'] + 1.) 
/ 2.\n        plt.plot_imgs_wo_gt(model_outputs['normal_map'], self.expdir, \"{:04d}\".format(idx), 1, self.img_res)\n        # plt.plot_imgs_wo_gt(model_outputs['surface_point'].reshape(1, num_samples, 3), self.expdir, \"{:04d}p\".format(idx), 1, self.img_res, is_hdr=True)\n\n        plt.plot_images(model_outputs['rgb_values'], ground_truth['rgb'], self.expdir, \"{:04d}\".format(idx), 1, self.img_res)\n        plt.plot_imgs_wo_gt(model_outputs['rgb_values'], self.expdir, \"{:04d}_pred\".format(idx), 1, self.img_res, 'rendering')\n        plt.plot_depths(model_outputs['depth_values'], self.expdir, \"{:04d}\".format(idx), 1, self.img_res)\n        plt.plot_depths(model_outputs['depth_values'], self.expdir, \"{:04d}\".format(idx), 1, self.img_res, None)\n        pred_img = model_outputs['rgb_values'].T.reshape(3, *self.img_res).unsqueeze(0)\n        gt_img = ground_truth['rgb'].T.reshape(3, *self.img_res).unsqueeze(0)\n        return {\n            'psnr': utils.get_psnr(model_outputs['rgb_values'], ground_truth['rgb']).item(),\n            'ssim': ssim(pred_img, gt_img).item(),\n            'lpips': lpips(pred_img.clamp(0, 1) * 2 - 1, gt_img.clamp(0, 1) * 2 - 1).item()\n        }\n    \n    def test_epoch_end(self, outputs):\n        with open(os.path.join(self.expdir, 'metrics.txt'), 'w') as f:\n            f.write(f\"# IMAGE RESOLUTION {self.img_res}\\n\")\n            psnr_sum = ssim_sum = lpips_sum = 0\n            psnrs = []\n            ssims = []\n            lpipss = []\n            for i, metrics in enumerate(outputs):\n                f.write(f\"[{i:04d}] [PSNR]{metrics['psnr']:.2f} [SSIM]{metrics['ssim']:.2f} [LPIPS]{metrics['lpips']:.2f}\\n\")\n                psnrs.append(metrics['psnr'])\n                ssims.append(metrics['ssim'])\n                lpipss.append(metrics['lpips'])\n                psnr_sum += metrics['psnr']\n                ssim_sum += metrics['ssim']\n                lpips_sum += metrics['lpips']\n            f.write(f\"[MEAN] [PSNR]{psnr_sum/len(outputs):.2f} [SSIM]{ssim_sum/len(outputs):.2f} [LPIPS]{lpips_sum/len(outputs):.2f}\\n\")\n            np.savez_compressed(os.path.join(self.expdir, 'metrics.npz'), psnr=np.array(psnrs), ssim=np.array(ssims), lpips=np.array(lpipss))\n\n    def forward(self):\n        raise NotImplementedError(\"forward not supported by trainer\")\n\n\nclass ViewInterpolateSystem(pl.LightningModule):\n    def __init__(self, conf, exp_dir, id0, id1, n_frames=60, frame_rate=24, use_normal=True) -> None:\n        super().__init__()\n        self.expdir = exp_dir\n        conf_model = conf.model\n        conf_model.use_normal = False\n        self.model = model.I2SDFNetwork(conf_model)\n        self.scan_id = conf.dataset.scan_id\n        dataset_conf = conf.dataset\n        self.test_dataset = dataset.InterpolateDataset(**dataset_conf, id0=id0, id1=id1, num_frames=n_frames)\n        self.total_pixels = self.test_dataset.total_pixels\n        self.img_res = self.test_dataset.img_res\n        self.split_n_pixels = conf.train.split_n_pixels\n        self.n_frames = n_frames\n        self.frame_rate = frame_rate\n        self.video_dir = os.path.join(self.expdir, 'eval/interpolate')\n        self.id0 = id0\n        self.id1 = id1\n        self.use_normal = use_normal\n        os.makedirs(self.video_dir, exist_ok=True)\n        self.frame_dir = os.path.join(self.video_dir, f\"{self.id0:04d}_{self.id1:04d}\")\n        os.makedirs(self.frame_dir, exist_ok=True)\n        if self.use_normal:\n            self.normal_fdir = 
os.path.join(self.video_dir, f\"{self.id0:04d}_{self.id1:04d}_normal\")\n            os.makedirs(self.normal_fdir, exist_ok=True)\n\n    def test_dataloader(self):\n        print(len(self.test_dataset))\n        return DataLoader(self.test_dataset, batch_size=1, shuffle=False, collate_fn=self.test_dataset.collate_fn)\n    \n    @torch.inference_mode(False)\n    @torch.no_grad()\n    def test_step(self, batch, batch_idx):\n        indices, model_input = batch\n        idx = batch_idx\n        split = utils.split_input(model_input, self.total_pixels, self.split_n_pixels)\n        res = []\n        res_normal = []\n        for s in split:\n            out = utils.detach_dict(self.model(s, predict_only=not self.use_normal))\n            rgb = out['rgb_values'].detach()\n            res.append(rgb)\n            if self.use_normal:\n                res_normal.append(out['normal_map'])\n            del out\n        rendered = torch.cat(res, dim=0).reshape(self.img_res[0], self.img_res[1], 3).cpu().numpy()\n        rendered = (rendered * 255).clip(0, 255).astype(np.uint8)\n        cv2.imwrite(f\"{self.frame_dir}/{idx:04d}.png\", rendered[:,:,::-1])\n        if self.use_normal:\n            normal_map = torch.cat(res_normal, dim=0).reshape(-1, 3).T\n            R = model_input['pose'].squeeze()[:3,:3].T\n            normal_map = R @ normal_map\n            normal_map = normal_map.T.reshape(self.img_res[0], self.img_res[1], 3).cpu().numpy()\n            normal_map = (((normal_map + 1) * 0.5) * 255).clip(0, 255).astype(np.uint8)\n            cv2.imwrite(f\"{self.normal_fdir}/{idx:04d}.png\", normal_map[:,:,::-1])\n\n    \n    def test_epoch_end(self, outputs):\n        import ffmpeg\n        (\n            ffmpeg\n            .input(os.path.join(self.frame_dir, '*.png'), pattern_type='glob', framerate=self.frame_rate)\n            .output(os.path.join(self.video_dir, f\"scan{self.scan_id}_{self.id0:04d}_{self.id1:04d}.mp4\"), vcodec='h264')\n            .overwrite_output()\n            .run()\n        )\n        if self.use_normal:\n            (\n                ffmpeg\n                .input(os.path.join(self.normal_fdir, '*.png'), pattern_type='glob', framerate=self.frame_rate)\n                .output(os.path.join(self.video_dir, f\"scan{self.scan_id}_{self.id0:04d}_{self.id1:04d}_normal.mp4\"), vcodec='h264')\n                .overwrite_output()\n                .run()\n            )\n\n    def forward(self):\n        raise NotImplementedError(\"forward not supported by trainer\")\n\n"
  },
  {
    "path": "model/network/__init__.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport math\nimport numpy as np\n\nimport utils\nfrom model.network.mlp import ImplicitNetwork, RenderingNetwork\nfrom model.network.density import LaplaceDensity, AbsDensity\nfrom model.network.ray_sampler import ErrorBoundSampler\nfrom fast_pytorch_kmeans import KMeans\nfrom sklearn.cluster import DBSCAN\n\n\n\"\"\"\nFor modeling more complex backgrounds, we follow the inverted sphere parametrization from NeRF++ \nhttps://github.com/Kai-46/nerfplusplus \n\"\"\"\nclass I2SDFNetwork(nn.Module):\n    def __init__(self, conf):\n        super().__init__()\n        self.feature_vector_size = conf.feature_vector_size\n        self.scene_bounding_sphere = getattr(conf, 'scene_bounding_sphere', 1.0)\n\n        # Foreground object's networks\n        self.implicit_network = ImplicitNetwork(self.feature_vector_size, 0.0, **conf.implicit_network)\n        self.rendering_network = RenderingNetwork(self.feature_vector_size, **conf.rendering_network)\n\n        self.use_light = hasattr(conf, 'light_network')\n        if self.use_light:\n            # self.light_network = RenderingNetwork(self.feature_vector_size, mode='nerf', output_activation='sigmoid', use_dir=False, **conf.light_network)\n            self.light_network = ImplicitNetwork(0, 0, d_in=self.feature_vector_size, d_out=1, geometric_init=False, embed_type=None, output_activation='sigmoid', **conf.light_network)\n\n        self.density = LaplaceDensity(**conf.density)\n\n        # Background's networks\n        self.use_bg = hasattr(conf, 'bg_network')\n        if self.use_bg:\n            bg_feature_vector_size = conf.bg_network.feature_vector_size\n            self.bg_implicit_network = ImplicitNetwork(bg_feature_vector_size, 0.0, **conf.bg_network.implicit_network)\n            self.bg_rendering_network = RenderingNetwork(bg_feature_vector_size, **conf.bg_network.rendering_network)\n            self.bg_density = AbsDensity(**getattr(conf.bg_network, 'density', {}))\n        else:\n            print(\"[INFO] BG Network Disabled\")\n        self.ray_sampler = ErrorBoundSampler(self.scene_bounding_sphere, inverse_sphere_bg=self.use_bg, **conf.ray_sampler)\n        self.use_normal = conf.get('use_normal', False)\n        self.detach_light_feature = conf.get('detach_light_feature', True)\n    \n    def init_emission_groups(self, n_emitters, pointcloud, init_emission=1.0, use_dbscan=False):\n        if use_dbscan:\n            \"\"\"\n            Use DBSCAN algorithm to initialize emitter cluster centroids for K-Means from a small random batch\n            Note that DBSCAN can automatically determine the number of clusters\n            \"\"\"\n            pt_samples = pointcloud[torch.randperm(len(pointcloud))[:10000]].cpu().numpy()\n            lab_samples = torch.from_numpy(DBSCAN(n_jobs=16).fit_predict(pt_samples))\n            if n_emitters != len(torch.unique(lab_samples)):\n                print(f\"[ERROR] Inconsistent emitter count: {n_emitters} / {len(torch.unique(lab_samples))}\")\n                # n_emitters = len(torch.unique(lab_samples))\n                exit()\n            init_centroids = torch.zeros(n_emitters, 3)\n            for i in range(n_emitters):\n                idx = (lab_samples == i).int().argmax()\n                init_centroids[i,:] = torch.from_numpy(pt_samples[idx, :])\n            init_centroids = init_centroids.to(pointcloud.device)\n        else:\n            \"\"\"\n            Use K-Means plus plus to initialize emitter 
cluster centroids for K-Means\n            \"\"\"\n            init_centroids = utils.kmeans_pp_centroid(pointcloud, n_emitters)\n        self.emitter_clusters = KMeans(n_emitters)\n        labels = self.emitter_clusters.fit_predict(pointcloud, init_centroids)\n        print(\"[INFO] emitters clustered\")\n        self.emissions = nn.Parameter(torch.empty(n_emitters, 3).fill_(init_emission), True)\n        return labels, self.emitter_clusters.centroids\n    \n    def get_param_groups(self, lr):\n        return [{'params': self.parameters(), 'lr': lr}]\n\n    def forward(self, input, predict_only=False):\n\n        intrinsics = input[\"intrinsics\"]\n        uv = input[\"uv\"]\n        pose = input[\"pose\"]\n\n        ray_dirs, cam_loc = utils.get_camera_params(uv, pose, intrinsics)\n\n        batch_size, num_pixels, _ = ray_dirs.shape\n\n        cam_loc = cam_loc.unsqueeze(1).repeat(1, num_pixels, 1).reshape(-1, 3)\n        ray_dirs = ray_dirs.reshape(-1, 3)\n        ray_dirs_norm = torch.linalg.vector_norm(ray_dirs, dim=1)\n        ray_dirs = F.normalize(ray_dirs, dim=1)\n\n        z_vals, z_samples_eik = self.ray_sampler.get_z_vals(ray_dirs, cam_loc, self)\n\n        if self.use_bg:\n            z_vals, z_vals_bg = z_vals\n        z_max = z_vals[:,-1]\n        z_vals = z_vals[:,:-1]\n        N_samples = z_vals.shape[1]\n\n        points = cam_loc.unsqueeze(1) + z_vals.unsqueeze(2) * ray_dirs.unsqueeze(1)\n        points_flat = points.reshape(-1, 3)\n\n        dirs = ray_dirs.unsqueeze(1).repeat(1,N_samples,1)\n        dirs_flat = dirs.reshape(-1, 3)\n\n        returns_grad = self.use_normal or (not self.training) or (self.rendering_network.mode == 'idr')\n        # with torch.enable_grad():\n        with torch.set_grad_enabled(returns_grad):\n        # with torch.inference_mode(not returns_grad):\n            sdf, feature_vectors, gradients = self.implicit_network.get_outputs(points_flat, returns_grad)\n\n        rgb_flat = self.rendering_network(points_flat, gradients, dirs_flat, feature_vectors)\n        rgb = rgb_flat.reshape(-1, N_samples, 3)\n\n        weights, bg_transmittance = self.volume_rendering(z_vals, z_max, sdf)\n\n        fg_rgb_values = torch.sum(weights.unsqueeze(-1) * rgb, 1)\n\n        weight_sum = torch.sum(weights, -1, keepdim=True)\n        # dist = torch.sum(weights / weight_sum.clamp(min=1e-6) * z_vals, 1)\n        dist = torch.sum(weights * z_vals, 1)\n        depth_values = dist / torch.clamp(ray_dirs_norm, min=1e-6)\n        # depth_values = torch.sum(weights * z_vals, 1) / torch.clamp(torch.sum(weights, 1), min=1e-6) # (bn,)\n\n        # Background rendering\n        if self.use_bg:\n            N_bg_samples = z_vals_bg.shape[1]\n            z_vals_bg = torch.flip(z_vals_bg, dims=[-1, ])  # 1--->0\n\n            bg_dirs = ray_dirs.unsqueeze(1).repeat(1,N_bg_samples,1)\n            bg_locs = cam_loc.unsqueeze(1).repeat(1,N_bg_samples,1)\n\n            bg_points = self.depth2pts_outside(bg_locs, bg_dirs, z_vals_bg)  # [..., N_samples, 4]\n            bg_points_flat = bg_points.reshape(-1, 4)\n            bg_dirs_flat = bg_dirs.reshape(-1, 3)\n\n            output = self.bg_implicit_network(bg_points_flat)\n            bg_sdf = output[:,:1]\n            bg_feature_vectors = output[:, 1:]\n            bg_rgb_flat = self.bg_rendering_network(None, None, bg_dirs_flat, bg_feature_vectors)\n            bg_rgb = bg_rgb_flat.reshape(-1, N_bg_samples, 3)\n\n            bg_weights = self.bg_volume_rendering(z_vals_bg, bg_sdf)\n\n            bg_rgb_values = 
torch.sum(bg_weights.unsqueeze(-1) * bg_rgb, 1)\n\n            # Composite foreground and background\n            bg_rgb_values = bg_transmittance.unsqueeze(-1) * bg_rgb_values\n            rgb_values = fg_rgb_values + bg_rgb_values\n        else:\n            rgb_values = fg_rgb_values\n\n        output = {\n            'rgb_values': rgb_values,\n            'depth_values': depth_values,\n            'weight_sum': weight_sum\n        }\n        \n        if self.use_light:\n            light_features = F.relu(feature_vectors)\n            if self.detach_light_feature:\n                light_features = light_features.detach_()\n            # lmask_flat = self.light_network(None, None, None, light_features)\n            lmask_flat = self.light_network(light_features)\n            lmask = lmask_flat.reshape(-1, N_samples, 1)\n            lmask_values = torch.sum(weights.unsqueeze(-1).detach() * lmask, 1)\n            output['light_mask'] = lmask_values\n\n        if predict_only:\n            return output\n\n        if self.training:\n            # Sample points for the eikonal loss\n            n_eik_points = batch_size * num_pixels\n            eikonal_points = torch.empty(n_eik_points, 3, device=cam_loc.device).uniform_(-self.scene_bounding_sphere, self.scene_bounding_sphere)\n\n            # Add some of the near surface points\n            eik_near_points = (cam_loc.unsqueeze(1) + z_samples_eik.unsqueeze(2) * ray_dirs.unsqueeze(1)).reshape(-1, 3)\n            n_eik_near = eik_near_points.size(0)\n            eikonal_points = torch.cat([eikonal_points, eik_near_points], 0)\n\n            # Add neighbor points near surface for smoothness loss\n            eik_near_neighbors = eik_near_points + torch.empty_like(eik_near_points).uniform_(-0.005, 0.005)\n            eikonal_points = torch.cat([eikonal_points, eik_near_neighbors], 0)\n            grad_theta = self.implicit_network.gradient(eikonal_points)\n            output['grad_theta'] = grad_theta[:n_eik_points+n_eik_near,]\n            normals = grad_theta[n_eik_points:,]\n            normals = F.normalize(normals, dim=1, eps=1e-6)\n            diff_norm = torch.norm(normals[:n_eik_near,:] - normals[n_eik_near:,:], dim=1)\n            output['diff_norm'] = diff_norm\n\n            # Sample pointclouds for bubble loss\n            if 'pointcloud' in input:\n                surface_points = input['pointcloud']\n                cam_loc_selected = cam_loc[np.random.randint(0, len(cam_loc)),:]\n                surface_points = torch.cat([surface_points, cam_loc_selected.unsqueeze(0)], dim=0)\n                surface_sdf = self.implicit_network.get_sdf_vals(surface_points)\n                output['surface_sdf'] = surface_sdf[:-1,:]\n\n            # Accumulate gradients for normal loss\n            if self.use_normal:\n                normals = F.normalize(gradients, dim=-1)\n                normals = normals.reshape(-1, N_samples, 3)\n                normal_map = torch.sum(weights.unsqueeze(-1).detach() * normals, 1)\n                normal_map = F.normalize(normal_map, dim=-1)\n                output['normal_values'] = normal_map\n\n        # elif not self.training:\n        else:\n            # Accumulate gradients for normal visualization\n            gradients = gradients.detach()\n            normals = F.normalize(gradients, dim=-1)\n            normals = normals.reshape(-1, N_samples, 3)\n            normal_map = torch.sum(weights.unsqueeze(-1) * normals, 1)\n            normal_map = F.normalize(normal_map, dim=-1)\n            
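# Normals accumulated here stay in world coordinates; the eval systems rotate\n            # them into camera space (R = pose[:3,:3].T) before visualization.\n            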
output['normal_map'] = normal_map\n\n        return output\n\n    def volume_rendering(self, z_vals, z_max, sdf):\n        density_flat = self.density(sdf)\n        density = density_flat.reshape(-1, z_vals.shape[1]) # (batch_size * num_pixels) x N_samples\n\n        # also include the dist to the sphere intersection\n        dists = z_vals[:, 1:] - z_vals[:, :-1]\n        dists = torch.cat([dists, z_max.unsqueeze(-1) - z_vals[:, -1:]], -1)\n\n        # LOG SPACE\n        free_energy = dists * density\n        shifted_free_energy = torch.cat([torch.zeros(dists.shape[0], 1, device=free_energy.device), free_energy], dim=-1)  # prepend 0 so transmittance is 1 at t_0\n        alpha = 1 - torch.exp(-free_energy)  # probability that this interval is not empty\n        transmittance = torch.exp(-torch.cumsum(shifted_free_energy, dim=-1))  # probability that everything up to now is empty\n        fg_transmittance = transmittance[:, :-1]\n        weights = alpha * fg_transmittance  # probability that the ray hits something here\n        bg_transmittance = transmittance[:, -1]  # factor to be multiplied with the bg volume rendering\n\n        return weights, bg_transmittance\n\n    def bg_volume_rendering(self, z_vals_bg, bg_sdf):\n        bg_density_flat = self.bg_density(bg_sdf)\n        bg_density = bg_density_flat.reshape(-1, z_vals_bg.shape[1]) # (batch_size * num_pixels) x N_samples\n\n        bg_dists = z_vals_bg[:, :-1] - z_vals_bg[:, 1:]\n        bg_dists = torch.cat([bg_dists, torch.tensor([1e10], device=bg_dists.device).unsqueeze(0).repeat(bg_dists.shape[0], 1)], -1)\n\n        # LOG SPACE\n        bg_free_energy = bg_dists * bg_density\n        bg_shifted_free_energy = torch.cat([torch.zeros(bg_dists.shape[0], 1, device=bg_free_energy.device), bg_free_energy[:, :-1]], dim=-1)  # shift one step\n        bg_alpha = 1 - torch.exp(-bg_free_energy)  # probability that this interval is not empty\n        bg_transmittance = torch.exp(-torch.cumsum(bg_shifted_free_energy, dim=-1))  # probability that everything up to now is empty\n        bg_weights = bg_alpha * bg_transmittance # probability that the ray hits something here\n\n        return bg_weights\n\n    def depth2pts_outside(self, ray_o, ray_d, depth):\n\n        '''\n        ray_o, ray_d: [..., 3]\n        depth: [...]; inverse of distance to sphere origin\n        '''\n\n        o_dot_d = torch.sum(ray_d * ray_o, dim=-1)\n        under_sqrt = o_dot_d ** 2 - ((ray_o ** 2).sum(-1) - self.scene_bounding_sphere ** 2)\n        d_sphere = torch.sqrt(under_sqrt) - o_dot_d\n        p_sphere = ray_o + d_sphere.unsqueeze(-1) * ray_d\n        p_mid = ray_o - o_dot_d.unsqueeze(-1) * ray_d\n        p_mid_norm = torch.norm(p_mid, dim=-1)\n\n        rot_axis = torch.cross(ray_o, p_sphere, dim=-1)\n        rot_axis = rot_axis / torch.norm(rot_axis, dim=-1, keepdim=True)\n        phi = torch.asin(p_mid_norm / self.scene_bounding_sphere)\n        theta = torch.asin(p_mid_norm * depth)  # depth is inside [0, 1]\n        rot_angle = (phi - theta).unsqueeze(-1)  # [..., 1]\n\n        # now rotate p_sphere\n        # Rodrigues formula: https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula\n        p_sphere_new = p_sphere * torch.cos(rot_angle) + \\\n                       torch.cross(rot_axis, p_sphere, dim=-1) * torch.sin(rot_angle) + \\\n                       rot_axis * torch.sum(rot_axis * p_sphere, dim=-1, keepdim=True) * (1. 
- torch.cos(rot_angle))\n        p_sphere_new = p_sphere_new / torch.norm(p_sphere_new, dim=-1, keepdim=True)\n        pts = torch.cat((p_sphere_new, depth.unsqueeze(-1)), dim=-1)\n\n        return pts\n\n\nclass I2SDFLoss(nn.Module):\n    def __init__(self, eikonal_weight=0.1, smooth_weight=0.0, mask_weight=0.0, depth_weight=0.1, normal_weight=0.05, angular_weight=0.05, bubble_weight=0.0, min_bubble_iter=0, max_bubble_iter=None, smooth_iter=None, light_mask_weight=0.0, eikonal_weight_bubble=0.0):\n        super().__init__()\n        self.eikonal_weight = eikonal_weight\n        self.rgb_loss = F.l1_loss\n        self.smooth_weight = smooth_weight\n        self.mask_weight = mask_weight\n        self.depth_weight = depth_weight\n        self.normal_weight = normal_weight\n        self.angular_weight = angular_weight\n        self.bubble_weight = bubble_weight\n        # self.eikonal_weight_bubble = eikonal_weight_bubble if eikonal_weight_bubble else self.eikonal_weight\n        self.min_bubble_iter = min_bubble_iter\n        self.max_bubble_iter = max_bubble_iter\n        self.smooth_iter = smooth_iter\n        if self.bubble_weight > 0 and self.max_bubble_iter is not None and (self.smooth_iter is None or self.smooth_iter < self.max_bubble_iter):\n            self.smooth_iter = self.max_bubble_iter # Disable smoothness loss during bubble steps\n        self.light_mask_weight = light_mask_weight\n\n    def get_rgb_loss(self, rgb_values, rgb_gt):\n        rgb_gt = rgb_gt.reshape(-1, 3)\n        rgb_loss = self.rgb_loss(rgb_values, rgb_gt)\n        return rgb_loss\n\n    def get_eikonal_loss(self, grad_theta):\n        eikonal_loss = ((grad_theta.norm(2, dim=1) - 1) ** 2).mean()\n        return eikonal_loss\n\n    def get_mask_loss(self, mask_pred, mask_gt):\n        return F.binary_cross_entropy(mask_pred.clip(1e-3, 1.0 - 1e-3), mask_gt)\n    \n    def get_depth_loss(self, depth, depth_gt, depth_mask):\n        depth_gt = depth_gt.flatten()\n        depth_mask = depth_mask.flatten()\n        # TODO: Add support for scale invariant depth loss (like MonoSDF)\n        return F.mse_loss(depth[depth_mask], depth_gt[depth_mask])\n    \n    def get_normal_l1_loss(self, normal, normal_gt, normal_mask):\n        normal_gt = normal_gt.reshape(-1, 3)\n        normal_mask = normal_mask.flatten()\n        return torch.abs(1 - torch.sum(normal[normal_mask] * normal_gt[normal_mask], dim=-1)).mean()\n    \n    def get_normal_angular_loss(self, normal, normal_gt, normal_mask):\n        normal_gt = normal_gt.reshape(-1, 3)\n        normal_mask = normal_mask.flatten()\n        dot = torch.sum(normal[normal_mask] * normal_gt[normal_mask], dim=-1)\n        angle = torch.acos(torch.clamp(dot, -1.0+1e-6, 1.0-1e-6)) / math.tau\n        return angle.clamp_max(0.5).abs().mean()\n\n    def forward(self, model_outputs, ground_truth, current_step):\n        rgb_gt = ground_truth['rgb']\n\n        rgb_loss = self.get_rgb_loss(model_outputs['rgb_values'], rgb_gt)\n        if 'grad_theta' in model_outputs:\n            eikonal_loss = self.get_eikonal_loss(model_outputs['grad_theta'])\n        else:\n            eikonal_loss = torch.tensor(0.0, device=model_outputs['rgb_values'].device).float()\n        \n        smooth_activated = self.smooth_iter is None or current_step > self.smooth_iter\n        if smooth_activated and self.smooth_weight > 0 and 'diff_norm' in model_outputs:\n            smooth_loss = model_outputs['diff_norm'].mean()\n        else:\n            smooth_loss = torch.tensor(0.0, device=model_outputs['rgb_values'].device).float()\n     
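   # Every optional term below falls back to a zero tensor when its supervision\n        # signal or weight is absent, so the weighted sum stays well-defined.\n     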
   \n        if 'mask' in ground_truth and self.mask_weight > 0:\n            mask_loss = self.get_mask_loss(model_outputs['weight_sum'], ground_truth['mask'])\n        else:\n            mask_loss = torch.tensor(0.0, device=model_outputs['rgb_values'].device).float()\n        \n        if 'depth' in ground_truth and self.depth_weight > 0:\n            depth_loss = self.get_depth_loss(model_outputs['depth_values'], ground_truth['depth'], ground_truth['depth_mask'])\n        else:\n            depth_loss = torch.tensor(0.0, device=model_outputs['rgb_values'].device).float()\n        \n        if 'normal' in ground_truth and self.normal_weight > 0:\n            normal_loss = self.get_normal_l1_loss(model_outputs['normal_values'], ground_truth['normal'], ground_truth['normal_mask'])\n        else:\n            normal_loss = torch.tensor(0.0, device=model_outputs['rgb_values'].device).float()\n\n        if 'normal' in ground_truth and self.angular_weight > 0:\n            angular_loss = self.get_normal_angular_loss(model_outputs['normal_values'], ground_truth['normal'], ground_truth['normal_mask'])\n        else:\n            angular_loss = torch.tensor(0.0, device=model_outputs['rgb_values'].device).float()\n        \n        if 'surface_sdf' in model_outputs and self.bubble_weight > 0:\n            bubble_loss = model_outputs['surface_sdf'].abs().mean()\n        else:\n            bubble_loss = torch.tensor(0.0, device=model_outputs['rgb_values'].device).float()\n        \n        if 'light_mask' in model_outputs and self.light_mask_weight > 0:\n            light_mask_loss = self.get_mask_loss(model_outputs['light_mask'].reshape(-1, 1), ground_truth['light_mask'].reshape(-1, 1))\n        else:\n            light_mask_loss = torch.tensor(0.0, device=model_outputs['rgb_values'].device).float()\n\n        loss = rgb_loss + \\\n               self.eikonal_weight * eikonal_loss + \\\n               self.smooth_weight * smooth_loss + \\\n               self.mask_weight * mask_loss + \\\n               self.depth_weight * depth_loss + \\\n               self.normal_weight * normal_loss + \\\n               self.angular_weight * angular_loss + \\\n               self.bubble_weight * bubble_loss + \\\n               self.light_mask_weight * light_mask_loss\n\n        output = {\n            'loss': loss,\n            'rgb_loss': rgb_loss,\n            'eikonal_loss': eikonal_loss,\n            'smooth_loss': smooth_loss,\n            'mask_loss': mask_loss,\n            'depth_loss': depth_loss,\n            'normal_loss': normal_loss,\n            'angular_loss': angular_loss,\n            'bubble_loss': bubble_loss,\n            'light_mask_loss': light_mask_loss\n        }\n\n        return output\n"
  },
  {
    "path": "model/network/density.py",
    "content": "import torch.nn as nn\nimport torch\n\n\nclass Density(nn.Module):\n    def __init__(self, params_init={}):\n        super().__init__()\n        for p in params_init:\n            param = nn.Parameter(torch.tensor(params_init[p]))\n            setattr(self, p, param)\n\n    def forward(self, sdf, beta=None):\n        return self.density_func(sdf, beta=beta)\n\n\nclass LaplaceDensity(Density):  # alpha * Laplace(loc=0, scale=beta).cdf(-sdf)\n    def __init__(self, params_init={}, beta_min=0.0001):\n        super().__init__(params_init=params_init)\n        self.beta_min = torch.tensor(beta_min)\n\n    def density_func(self, sdf, beta=None):\n        if beta is None:\n            beta = self.get_beta()\n\n        alpha = 1 / beta\n        return alpha * (0.5 + 0.5 * sdf.sign() * torch.expm1(-sdf.abs() / beta))\n\n    def get_beta(self):\n        beta = self.beta.abs() + self.beta_min\n        return beta\n\n\nclass AbsDensity(Density):  # like NeRF++\n    def density_func(self, sdf, beta=None):\n        return torch.abs(sdf)\n\n\nclass SimpleDensity(Density):  # like NeRF\n    def __init__(self, params_init={}, noise_std=1.0):\n        super().__init__(params_init=params_init)\n        self.noise_std = noise_std\n\n    def density_func(self, sdf, beta=None):\n        if self.training and self.noise_std > 0.0:\n            noise = torch.randn(sdf.shape).to(sdf.device) * self.noise_std\n            sdf = sdf + noise\n        return torch.relu(sdf)\n"
  },
  {
    "path": "model/network/embedder.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\n\nclass Embedder:\n    \"\"\" Positional encoding embedding. Code was taken from https://github.com/bmild/nerf. \"\"\"\n    def __init__(self, **kwargs):\n        self.kwargs = kwargs\n        self.create_embedding_fn()\n\n    def create_embedding_fn(self):\n        embed_fns = []\n        d = self.kwargs['input_dims']\n        out_dim = 0\n        if self.kwargs['include_input']:\n            embed_fns.append(lambda x: x)\n            out_dim += d\n\n        max_freq = self.kwargs['max_freq_log2']\n        N_freqs = self.kwargs['num_freqs']\n\n        if self.kwargs['log_sampling']:\n            freq_bands = 2. ** torch.linspace(0., max_freq, N_freqs)\n        else:\n            freq_bands = torch.linspace(2.**0., 2.**max_freq, N_freqs)\n\n        for freq in freq_bands:\n            for p_fn in self.kwargs['periodic_fns']:\n                embed_fns.append(lambda x, p_fn=p_fn,\n                                 freq=freq: p_fn(x * freq))\n                out_dim += d\n\n        self.embed_fns = embed_fns\n        self.out_dim = out_dim\n\n    def embed(self, inputs):\n        return torch.cat([fn(inputs) for fn in self.embed_fns], -1)\n\n\nclass SHEncoder(nn.Module):\n    def __init__(self, input_dims=3, degree=4):\n    \n        super().__init__()\n\n        self.input_dims = input_dims\n        self.degree = degree\n\n        assert self.input_dims == 3\n        assert self.degree >= 1 and self.degree <= 5\n\n        self.out_dim = degree ** 2\n\n        self.C0 = 0.28209479177387814\n        self.C1 = 0.4886025119029199\n        self.C2 = [\n            1.0925484305920792,\n            -1.0925484305920792,\n            0.31539156525252005,\n            -1.0925484305920792,\n            0.5462742152960396\n        ]\n        self.C3 = [\n            -0.5900435899266435,\n            2.890611442640554,\n            -0.4570457994644658,\n            0.3731763325901154,\n            -0.4570457994644658,\n            1.445305721320277,\n            -0.5900435899266435\n        ]\n        self.C4 = [\n            2.5033429417967046,\n            -1.7701307697799304,\n            0.9461746957575601,\n            -0.6690465435572892,\n            0.10578554691520431,\n            -0.6690465435572892,\n            0.47308734787878004,\n            -1.7701307697799304,\n            0.6258357354491761\n        ]\n\n    def forward(self, input, **kwargs):\n\n        result = torch.empty((*input.shape[:-1], self.out_dim), dtype=input.dtype, device=input.device)\n        x, y, z = input.unbind(-1)\n\n        result[..., 0] = self.C0\n        if self.degree > 1:\n            result[..., 1] = -self.C1 * y\n            result[..., 2] = self.C1 * z\n            result[..., 3] = -self.C1 * x\n            if self.degree > 2:\n                xx, yy, zz = x * x, y * y, z * z\n                xy, yz, xz = x * y, y * z, x * z\n                result[..., 4] = self.C2[0] * xy\n                result[..., 5] = self.C2[1] * yz\n                result[..., 6] = self.C2[2] * (2.0 * zz - xx - yy)\n                #result[..., 6] = self.C2[2] * (3.0 * zz - 1) # xx + yy + zz == 1, but this will lead to different backward gradients, interesting...\n                result[..., 7] = self.C2[3] * xz\n                result[..., 8] = self.C2[4] * (xx - yy)\n                if self.degree > 3:\n                    result[..., 9] = self.C3[0] * y * (3 * xx - yy)\n                    result[..., 10] = self.C3[1] * xy * 
z\n                    result[..., 11] = self.C3[2] * y * (4 * zz - xx - yy)\n                    result[..., 12] = self.C3[3] * z * (2 * zz - 3 * xx - 3 * yy)\n                    result[..., 13] = self.C3[4] * x * (4 * zz - xx - yy)\n                    result[..., 14] = self.C3[5] * z * (xx - yy)\n                    result[..., 15] = self.C3[6] * x * (xx - 3 * yy)\n                    if self.degree > 4:\n                        result[..., 16] = self.C4[0] * xy * (xx - yy)\n                        result[..., 17] = self.C4[1] * yz * (3 * xx - yy)\n                        result[..., 18] = self.C4[2] * xy * (7 * zz - 1)\n                        result[..., 19] = self.C4[3] * yz * (7 * zz - 3)\n                        result[..., 20] = self.C4[4] * (zz * (35 * zz - 30) + 3)\n                        result[..., 21] = self.C4[5] * xz * (7 * zz - 3)\n                        result[..., 22] = self.C4[6] * (xx - yy) * (7 * zz - 1)\n                        result[..., 23] = self.C4[7] * xz * (xx - 3 * yy)\n                        result[..., 24] = self.C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy))\n\n        return result\n\n\nclass FourierFeature(nn.Module):\n    def __init__(self, channels, sigma=1.0, input_dims=3, include_input=True) -> None:\n        super().__init__()\n        self.register_buffer('B', torch.randn(input_dims, channels) * sigma, True)\n        self.channels = channels\n        # include_input concatenates the raw input, which has input_dims channels\n        self.out_dim = 2 * self.channels + input_dims if include_input else 2 * self.channels\n        self.include_input = include_input\n    \n    def forward(self, x):\n        xp = torch.matmul(2 * np.pi * x, self.B)\n        return torch.cat([x, torch.sin(xp), torch.cos(xp)], dim=-1) if self.include_input else torch.cat([torch.sin(xp), torch.cos(xp)], dim=-1)\n\n\ndef get_embedder(embed_type='positional', **kwargs):\n    if embed_type == 'positional':\n        input_dims = kwargs['input_dims']\n        multires = kwargs['multires']\n        embed_kwargs = {\n            'include_input': True,\n            'input_dims': input_dims,\n            'max_freq_log2': multires-1,\n            'num_freqs': multires,\n            'log_sampling': True,\n            'periodic_fns': [torch.sin, torch.cos],\n        }\n        embedder_obj = Embedder(**embed_kwargs)\n        def embed(x, eo=embedder_obj): return eo.embed(x)\n        return embed, embedder_obj.out_dim\n    elif embed_type == 'spherical_harmonics':\n        embedder = SHEncoder(**kwargs)\n        return embedder, embedder.out_dim\n    elif embed_type == 'fourier':\n        embedder = FourierFeature(**kwargs)\n        return embedder, embedder.out_dim\n    else:\n        raise ValueError('Unknown embedding type: {}'.format(embed_type))"
  },
  {
    "path": "model/network/mlp.py",
    "content": "import torch.nn as nn\nimport numpy as np\n\nimport utils\nfrom .embedder import *\nfrom .density import LaplaceDensity\nfrom .ray_sampler import ErrorBoundSampler\n\n\nclass ImplicitNetwork(nn.Module):\n    def __init__(\n            self,\n            feature_vector_size,\n            sdf_bounding_sphere,\n            d_in,\n            d_out,\n            dims,\n            geometric_init=True,\n            bias=1.0,\n            skip_in=(),\n            weight_norm=True,\n            embed_type=None,\n            sphere_scale=1.0,\n            output_activation=None,\n            **kwargs\n    ):\n        super().__init__()\n\n        self.sdf_bounding_sphere = sdf_bounding_sphere\n        self.sphere_scale = sphere_scale\n        dims = [d_in] + dims + [d_out + feature_vector_size]\n\n        self.embed_fn = None\n        if embed_type:\n            embed_fn, input_ch = get_embedder(embed_type, input_dims=d_in, **kwargs)\n            self.embed_fn = embed_fn\n            dims[0] = input_ch\n        \n        print(f\"[INFO] Implicit network dims: {dims}\")\n\n        self.num_layers = len(dims)\n        self.skip_in = skip_in\n        self.weight_norm = weight_norm\n\n        for l in range(0, self.num_layers - 1):\n            if l + 1 in self.skip_in:\n                out_dim = dims[l + 1] - dims[0]\n                if out_dim < 0:\n                    print(dims)\n            else:\n                out_dim = dims[l + 1]\n\n            lin = nn.Linear(dims[l], out_dim)\n\n            if geometric_init:\n                if l == self.num_layers - 2:\n                    torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001)\n                    torch.nn.init.constant_(lin.bias, -bias)\n                elif (embed_type or self.use_grid) and l == 0:\n                    torch.nn.init.constant_(lin.bias, 0.0)\n                    torch.nn.init.constant_(lin.weight[:, 3:], 0.0)\n                    torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim))\n                elif (embed_type or self.use_grid) and l in self.skip_in:\n                    torch.nn.init.constant_(lin.bias, 0.0)\n                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))\n                    torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0)\n                else:\n                    torch.nn.init.constant_(lin.bias, 0.0)\n                    torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim))\n\n            if weight_norm:\n                lin = nn.utils.weight_norm(lin)\n\n            setattr(self, \"lin\" + str(l), lin)\n\n        self.activation = nn.Softplus(beta=100)\n        self.output_activation = None\n        if output_activation is not None:\n            self.output_activation = activations[output_activation]\n\n    def get_param_groups(self, lr):\n        return [{'params': self.parameters(), 'lr': lr}]\n\n    def forward(self, input):\n\n        if self.embed_fn is not None:\n            input = self.embed_fn(input)\n\n        x = input\n\n        for l in range(0, self.num_layers - 1):\n            lin = getattr(self, \"lin\" + str(l))\n\n            if l in self.skip_in:\n                x = torch.cat([x, input], 1) / np.sqrt(2)\n\n            x = lin(x)\n\n            if l < self.num_layers - 2:\n                x = self.activation(x)\n\n        if self.output_activation is not None:\n            x = self.output_activation(x)\n        \n        return x\n\n    def 
gradient(self, x):\n        x.requires_grad_(True)\n        y = self.forward(x)[:,:1]\n        d_output = torch.ones_like(y, requires_grad=False, device=y.device)\n        gradients = torch.autograd.grad(\n            outputs=y,\n            inputs=x,\n            grad_outputs=d_output,\n            create_graph=True,\n            retain_graph=True,\n            only_inputs=True)[0]\n        return gradients\n\n    def feature(self, x):\n        return self.forward(x)[:,1:]\n\n    def get_outputs(self, x, returns_grad=True):\n        x.requires_grad_(returns_grad)\n        output = self.forward(x)\n        sdf = output[:,:1]\n        ''' Clamping the SDF with the scene bounding sphere, so that all rays are eventually occluded '''\n        if self.sdf_bounding_sphere > 0.0:\n            sphere_sdf = self.sphere_scale * (self.sdf_bounding_sphere - x.norm(2,1, keepdim=True))\n            sdf = torch.minimum(sdf, sphere_sdf)\n        feature_vectors = output[:, 1:]\n        if returns_grad:\n            d_output = torch.ones_like(sdf, requires_grad=False, device=sdf.device)\n            gradients = torch.autograd.grad(\n                outputs=sdf,\n                inputs=x,\n                grad_outputs=d_output,\n                create_graph=True,\n                retain_graph=True,\n                only_inputs=True)[0]\n            return sdf, feature_vectors, gradients\n        else:\n            return sdf, feature_vectors, None\n\n    def get_sdf_vals(self, x):\n        sdf = self.forward(x)[:,:1]\n        ''' Clamping the SDF with the scene bounding sphere, so that all rays are eventually occluded '''\n        if self.sdf_bounding_sphere > 0.0:\n            sphere_sdf = self.sphere_scale * (self.sdf_bounding_sphere - x.norm(2,1, keepdim=True))\n            sdf = torch.minimum(sdf, sphere_sdf)\n        return sdf\n\nactivations = {\n    'sigmoid': nn.Sigmoid(),\n    'relu': nn.ReLU(),\n    'softplus': nn.Softplus()\n}\n\nclass RenderingNetwork(nn.Module):\n    def __init__(\n            self,\n            feature_vector_size,\n            mode,\n            d_in,\n            d_out,\n            dims,\n            weight_norm=True,\n            embed_type=None,\n            embed_point=None,\n            output_activation='sigmoid',\n            **kwargs\n    ):\n        super().__init__()\n\n        self.mode = mode\n        dims = [d_in + feature_vector_size] + dims + [d_out]\n        self.d_out = d_out\n\n        self.embedview_fn = None\n        if embed_type:\n            embedview_fn, input_ch = get_embedder(embed_type, input_dims=3, **kwargs)\n            self.embedview_fn = embedview_fn\n            dims[0] += (input_ch - 3)\n        \n        if mode == 'idr':\n            self.embedpoint_fn = None\n            if embed_point is not None:\n                embedpoint_fn, input_ch = get_embedder(input_dims=3, **embed_point)\n                self.embedpoint_fn = embedpoint_fn\n                dims[0] += (input_ch - 3)\n\n        print(f\"[INFO] Rendering network dims: {dims}\")\n        self.num_layers = len(dims)\n        self.weight_norm = weight_norm\n\n        for l in range(0, self.num_layers - 1):\n            out_dim = dims[l + 1]\n            lin = nn.Linear(dims[l], out_dim)\n\n            if weight_norm:\n                lin = nn.utils.weight_norm(lin)\n\n            setattr(self, \"lin\" + str(l), lin)\n\n        self.activation = nn.ReLU()\n        self.output_activation = activations[output_activation]\n\n    def forward(self, points, normals, view_dirs, 
feature_vectors):\n        if self.embedview_fn is not None:\n            view_dirs = self.embedview_fn(view_dirs)\n\n        if self.mode == 'idr':\n            rendering_input = torch.cat([points, view_dirs, normals, feature_vectors], dim=-1)\n        # elif self.mode == 'nerf':\n        else:\n            rendering_input = torch.cat([view_dirs, feature_vectors], dim=-1)\n        x = rendering_input\n\n        for l in range(0, self.num_layers - 1):\n            lin = getattr(self, \"lin\" + str(l))\n\n            x = lin(x)\n\n            if l < self.num_layers - 2:\n                x = self.activation(x)\n\n        x = self.output_activation(x)\n\n        return x\n\n"
  },
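  {
    "path": "examples/sdf_normal_sketch.py",
    "content": "# Illustrative sketch, not part of the pipeline: how ImplicitNetwork.gradient turns an\n# SDF MLP into analytic surface normals with torch.autograd.grad. TinySDF below is a\n# hypothetical stand-in for the real implicit network.\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass TinySDF(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.net = nn.Sequential(nn.Linear(3, 64), nn.Softplus(beta=100), nn.Linear(64, 1))\n\n    def forward(self, x):\n        return self.net(x)\n\nif __name__ == '__main__':\n    sdf = TinySDF()\n    x = torch.rand(8, 3)\n    x.requires_grad_(True)\n    y = sdf(x)\n    # Differentiate the scalar SDF w.r.t. its query points; the gradient field is the\n    # (unnormalized) normal used for eikonal and normal supervision.\n    grad = torch.autograd.grad(y, x, torch.ones_like(y), create_graph=True)[0]\n    print(F.normalize(grad, dim=-1).shape)  # torch.Size([8, 3])\n"
  },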
  {
    "path": "model/network/ray_sampler.py",
    "content": "import abc\nimport torch\nfrom utils import rend_util\nimport utils\n\nclass RaySampler(metaclass=abc.ABCMeta):\n    def __init__(self, near, far):\n        self.near = near\n        self.far = far\n\n    @abc.abstractmethod\n    def get_z_vals(self, ray_dirs, cam_loc, model):\n        pass\n\nclass UniformSampler(RaySampler):\n    def __init__(self, scene_bounding_sphere, near, N_samples, take_sphere_intersection=False, far=-1):\n        super().__init__(near, 2.0 * scene_bounding_sphere if far == -1 else far)  # default far is 2*R\n        self.N_samples = N_samples\n        self.scene_bounding_sphere = scene_bounding_sphere\n        self.take_sphere_intersection = take_sphere_intersection\n\n    def get_z_vals(self, ray_dirs, cam_loc, model):\n        if not self.take_sphere_intersection:\n            near, far = self.near * torch.ones(ray_dirs.shape[0], 1, device=ray_dirs.device), self.far * torch.ones(ray_dirs.shape[0], 1, device=ray_dirs.device)\n        else:\n            sphere_intersections = utils.get_sphere_intersections(cam_loc, ray_dirs, r=self.scene_bounding_sphere)\n            near = self.near * torch.ones(ray_dirs.shape[0], 1, device=ray_dirs.device)\n            far = sphere_intersections[:,1:]\n\n        t_vals = torch.linspace(0., 1., steps=self.N_samples, device=ray_dirs.device)\n        z_vals = near * (1. - t_vals) + far * (t_vals)\n\n        if model.training:\n            # get intervals between samples\n            mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])\n            upper = torch.cat([mids, z_vals[..., -1:]], -1)\n            lower = torch.cat([z_vals[..., :1], mids], -1)\n            # stratified samples in those intervals\n            t_rand = torch.rand(z_vals.shape, device=z_vals.device)\n\n            z_vals = lower + (upper - lower) * t_rand\n\n        return z_vals\n\n\nclass ErrorBoundSampler(RaySampler):\n    def __init__(self, scene_bounding_sphere, near, N_samples, N_samples_eval, N_samples_extra,\n                 eps, beta_iters, max_total_iters,\n                 inverse_sphere_bg=False, N_samples_inverse_sphere=0, add_tiny=0.0):\n        super().__init__(near, 2.0 * scene_bounding_sphere)\n        self.N_samples = N_samples\n        self.N_samples_eval = N_samples_eval\n        self.uniform_sampler = UniformSampler(scene_bounding_sphere, near, N_samples_eval, take_sphere_intersection=inverse_sphere_bg)\n\n        self.N_samples_extra = N_samples_extra\n\n        self.eps = eps\n        self.beta_iters = beta_iters\n        self.max_total_iters = max_total_iters\n        self.scene_bounding_sphere = scene_bounding_sphere\n        self.add_tiny = add_tiny\n\n        self.inverse_sphere_bg = inverse_sphere_bg\n        if inverse_sphere_bg:\n            self.inverse_sphere_sampler = UniformSampler(1.0, 0.0, N_samples_inverse_sphere, False, far=1.0)\n\n    def get_z_vals(self, ray_dirs, cam_loc, model):\n        beta0 = model.density.get_beta().detach()\n\n        # Start with uniform sampling\n        z_vals = self.uniform_sampler.get_z_vals(ray_dirs, cam_loc, model)\n        samples, samples_idx = z_vals, None\n\n        # Get maximum beta from the upper bound (Lemma 2)\n        dists = z_vals[:, 1:] - z_vals[:, :-1]\n        bound = (1.0 / (4.0 * torch.log(torch.tensor(self.eps + 1.0)))) * (dists ** 2.).sum(-1)\n        beta = torch.sqrt(bound)\n        # beta = torch.sqrt(bound).clone()\n\n        total_iters, not_converge = 0, True\n\n        # Algorithm 1\n        while not_converge and total_iters < 
self.max_total_iters:\n            points = cam_loc.unsqueeze(1) + samples.unsqueeze(2) * ray_dirs.unsqueeze(1)\n            points_flat = points.reshape(-1, 3)\n\n            # Calculating the SDF only for the new sampled points\n            with torch.no_grad():\n                samples_sdf = model.implicit_network.get_sdf_vals(points_flat)\n            if samples_idx is not None:\n                sdf_merge = torch.cat([sdf.reshape(-1, z_vals.shape[1] - samples.shape[1]),\n                                       samples_sdf.reshape(-1, samples.shape[1])], -1)\n                sdf = torch.gather(sdf_merge, 1, samples_idx).reshape(-1, 1)\n            else:\n                sdf = samples_sdf\n\n\n            # Calculating the bound d* (Theorem 1)\n            d = sdf.reshape(z_vals.shape)\n            dists = z_vals[:, 1:] - z_vals[:, :-1]\n            a, b, c = dists, d[:, :-1].abs(), d[:, 1:].abs()\n            first_cond = a.pow(2) + b.pow(2) <= c.pow(2)\n            second_cond = a.pow(2) + c.pow(2) <= b.pow(2)\n            # d_star = torch.zeros(z_vals.shape[0], z_vals.shape[1] - 1, device=z_vals.device)\n            # d_star[first_cond] = b[first_cond]\n            # d_star[second_cond] = c[second_cond]\n            s = (a + b + c) / 2.0\n            area_before_sqrt = s * (s - a) * (s - b) * (s - c)\n            mask = ~first_cond & ~second_cond & (b + c - a > 0)\n            # d_star[mask] = (2.0 * torch.sqrt(area_before_sqrt[mask])) / (a[mask])\n            # Optimization: multiplication is 5-20 times faster than indexing\n            first_cond = first_cond & ~second_cond\n            d_star = first_cond * b + second_cond * c + torch.nan_to_num((2.0 * torch.sqrt(area_before_sqrt)) / a) * mask\n            d_star = (d[:, 1:].sign() * d[:, :-1].sign() == 1) * d_star  # Fixing the sign\n\n\n            # Updating beta using line search\n            curr_error = self.get_error_bound(beta0, model, sdf, z_vals, dists, d_star)\n            # beta[curr_error <= self.eps] = beta0\n            # Optimization: multiplication is 5-20 times faster than indexing\n            beta0_mask = curr_error <= self.eps\n            beta = beta * ~beta0_mask + beta0 * beta0_mask\n            beta_min, beta_max = beta0.unsqueeze(0).repeat(z_vals.shape[0]), beta\n            for j in range(self.beta_iters):\n                beta_mid = (beta_min + beta_max) / 2.\n                curr_error = self.get_error_bound(beta_mid.unsqueeze(-1), model, sdf, z_vals, dists, d_star)\n                # beta_max[curr_error <= self.eps] = beta_mid[curr_error <= self.eps]\n                # beta_min[curr_error > self.eps] = beta_mid[curr_error > self.eps]\n                beta_mid_mask = curr_error <= self.eps\n                beta_max = beta_max * ~beta_mid_mask + beta_mid * beta_mid_mask\n                beta_min = beta_min * beta_mid_mask + beta_mid * ~beta_mid_mask\n            beta = beta_max\n\n\n            # Upsample more points\n            # tmp0 = beta.unsqueeze(-1).clone()\n            # tmp0 = beta.unsqueeze(-1)\n            # density = model.density(sdf.reshape(z_vals.shape), beta=tmp0)\n            density = model.density(sdf.reshape(z_vals.shape), beta=beta.unsqueeze(-1))\n\n            # dists = torch.cat([dists, torch.tensor([1e10], device=dists.device).unsqueeze(0).repeat(dists.shape[0], 1)], -1)\n            dists = torch.cat([dists, torch.full([dists.shape[0], 1], 1e10, device=dists.device)], -1)\n            free_energy = dists * density\n            shifted_free_energy = 
torch.cat([torch.zeros(dists.shape[0], 1, device=free_energy.device), free_energy[:, :-1]], dim=-1)\n            alpha = 1 - torch.exp(-free_energy)\n            transmittance = torch.exp(-torch.cumsum(shifted_free_energy, dim=-1))\n            weights = alpha * transmittance  # probability of the ray hits something here\n\n            #  Check if we are done and this is the last sampling\n            total_iters += 1\n            not_converge = beta.max() > beta0\n\n            if not_converge and total_iters < self.max_total_iters:\n                ''' Sample more points proportional to the current error bound'''\n\n                N = self.N_samples_eval\n\n                bins = z_vals\n                error_per_section = torch.exp(-d_star / beta.unsqueeze(-1)) * (dists[:,:-1] ** 2.) / (4 * beta.unsqueeze(-1) ** 2)\n                # tmp0 = beta.unsqueeze(-1).clone()\n                # tmp1 = -d_star / tmp0\n                # tmp2 = dists[:,:-1] ** 2.\n                # tmp3 = 4 * tmp0 ** 2\n                # error_per_section = tmp1 * tmp2 / tmp3\n                error_integral = torch.cumsum(error_per_section, dim=-1)\n                bound_opacity = (torch.clamp(torch.exp(error_integral),max=1.e6) - 1.0) * transmittance[:,:-1]\n\n                pdf = bound_opacity + self.add_tiny\n                pdf = pdf / torch.sum(pdf, -1, keepdim=True)\n                cdf = torch.cumsum(pdf, -1)\n                cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)\n\n            else:\n                ''' Sample the final sample set to be used in the volume rendering integral '''\n\n                N = self.N_samples\n\n                bins = z_vals\n                pdf = weights[..., :-1]\n                pdf = pdf + 1e-5  # prevent nans\n                pdf = pdf / torch.sum(pdf, -1, keepdim=True)\n                cdf = torch.cumsum(pdf, -1)\n                cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)  # (batch, len(bins))\n\n\n            # Invert CDF\n            if (not_converge and total_iters < self.max_total_iters) or (not model.training):\n                u = torch.linspace(0., 1., steps=N, device=cdf.device).unsqueeze(0).repeat(cdf.shape[0], 1)\n            else:\n                u = torch.rand(list(cdf.shape[:-1]) + [N], device=cdf.device)\n            u = u.contiguous()\n\n            inds = torch.searchsorted(cdf, u, right=True)\n            below = torch.max(torch.zeros_like(inds - 1), inds - 1)\n            above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)\n            inds_g = torch.stack([below, above], -1)  # (batch, N_samples, 2)\n\n            matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]\n            cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)\n            bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)\n\n            denom = (cdf_g[..., 1] - cdf_g[..., 0])\n            denom_mask = denom < 1e-5\n            denom = denom_mask + ~denom_mask * denom\n            # denom = torch.where(denom < 1e-5, 1.0, denom)\n            t = (u - cdf_g[..., 0]) / denom\n            samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])\n\n\n            # Adding samples if we not converged\n            if not_converge and total_iters < self.max_total_iters:\n                z_vals, samples_idx = torch.sort(torch.cat([z_vals, samples], -1), -1)\n\n\n        z_samples = samples\n\n        near, far = self.near * torch.ones(ray_dirs.shape[0], 1, device=ray_dirs.device), 
self.far * torch.ones(ray_dirs.shape[0], 1, device=ray_dirs.device)\n        if self.inverse_sphere_bg: # if inverse sphere then need to add the far sphere intersection\n            far = rend_util.get_sphere_intersections(cam_loc, ray_dirs, r=self.scene_bounding_sphere)[:,1:]\n\n        if self.N_samples_extra > 0:\n            if model.training:\n                sampling_idx = torch.randperm(z_vals.shape[1])[:self.N_samples_extra]\n            else:\n                sampling_idx = torch.linspace(0, z_vals.shape[1]-1, self.N_samples_extra).long()\n            z_vals_extra = torch.cat([near, far, z_vals[:,sampling_idx]], -1)\n        else:\n            z_vals_extra = torch.cat([near, far], -1)\n\n        z_vals, _ = torch.sort(torch.cat([z_samples, z_vals_extra], -1), -1)\n\n        # add some of the near surface points\n        idx = torch.randint(z_vals.shape[-1], (z_vals.shape[0],), device=z_vals.device)\n        z_samples_eik = torch.gather(z_vals, 1, idx.unsqueeze(-1))\n\n        if self.inverse_sphere_bg:\n            z_vals_inverse_sphere = self.inverse_sphere_sampler.get_z_vals(ray_dirs, cam_loc, model)\n            z_vals_inverse_sphere = z_vals_inverse_sphere * (1./self.scene_bounding_sphere)\n            z_vals = (z_vals, z_vals_inverse_sphere)\n\n        return z_vals, z_samples_eik\n\n    def get_error_bound(self, beta, model, sdf, z_vals, dists, d_star):\n        density = model.density(sdf.reshape(z_vals.shape), beta=beta)\n        shifted_free_energy = torch.cat([torch.zeros(dists.shape[0], 1, device=dists.device), dists * density[:, :-1]], dim=-1)\n        integral_estimation = torch.cumsum(shifted_free_energy, dim=-1)\n        error_per_section = torch.exp(-d_star / beta) * (dists ** 2.) / (4 * beta ** 2)\n        error_integral = torch.cumsum(error_per_section, dim=-1)\n        bound_opacity = (torch.clamp(torch.exp(error_integral), max=1.e6) - 1.0) * torch.exp(-integral_estimation[:, :-1])\n\n        return bound_opacity.max(-1)[0]\n\n\n"
  },
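  {
    "path": "examples/invert_cdf_sketch.py",
    "content": "# Standalone sketch of the 'Invert CDF' step in ErrorBoundSampler.get_z_vals, with toy\n# bins and weights (assumed values, not taken from the sampler): build a PDF over ray\n# segments, integrate it to a CDF, then evaluate the inverse CDF at uniform positions so\n# new z-values concentrate where the weights (error bound or rendering weights) are large.\nimport torch\n\nbins = torch.linspace(0., 4., steps=9).unsqueeze(0)         # (1, 9) segment boundaries\nweights = torch.tensor([[0., 0., 1., 4., 2., 1., 0., 0.]])  # (1, 8) mass per segment\npdf = weights + 1e-5                                        # prevent nans, as in the sampler\npdf = pdf / pdf.sum(-1, keepdim=True)\ncdf = torch.cumsum(pdf, -1)\ncdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)  # (1, 9), starts at 0\n\nu = torch.linspace(0., 1., steps=16).unsqueeze(0).contiguous()\ninds = torch.searchsorted(cdf, u, right=True)\nbelow = (inds - 1).clamp(min=0)\nabove = inds.clamp(max=cdf.shape[-1] - 1)\ncdf_lo, cdf_hi = torch.gather(cdf, 1, below), torch.gather(cdf, 1, above)\nbin_lo, bin_hi = torch.gather(bins, 1, below), torch.gather(bins, 1, above)\nt = (u - cdf_lo) / (cdf_hi - cdf_lo).clamp(min=1e-5)\nsamples = bin_lo + t * (bin_hi - bin_lo)\nprint(samples)  # most samples fall in [1.0, 2.5], where the toy weights peak\n"
  },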
  {
    "path": "model/rendering/__init__.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport numpy as np\nimport utils\nimport cv2\nfrom .brdf import *\n\n\nclass RenderingLayer(nn.Module):\n    def __init__(self, spp, split_n_pixels, preserve_light=True) -> None:\n        super().__init__()\n        self.spp = spp\n        self.split_n_pixels = split_n_pixels\n        self.preserve_light = preserve_light\n    \n    def forward(\n            self,\n            model,\n            surface_points,\n            view_direction,\n            Kd,\n            Ks,\n            normal,\n            rough,\n            radiance_scale=None,\n            intersect_func=None\n        ):\n        \"\"\"\n        Render according to material, normal and lighting conditions\n        Params:\n            model: NeRF model to predict radiance\n            surface_points, view_direction, albedo, normal, rough, metal: (bn, c)\n        \"\"\"\n        bn = normal.size(0)\n\n        cx, cy, cz = create_frame(normal)\n        wi_x = torch.sum(cx*view_direction, dim=1)\n        wi_y = torch.sum(cy*view_direction, dim=1)\n        wi_z = torch.sum(cz*view_direction, dim=1)\n        wi = torch.stack([wi_x, wi_y, wi_z], dim=1)\n        wi_mask = (wi[:,2] >= 0.00001)\n        wi[:,2,...] = torch.where(wi[:,2,...] < 0.00001, torch.ones_like(wi[:,2,...]) * 0.00001, wi[:,2,...])\n        wi = F.normalize(wi, dim=1, eps=1e-6)\n        # wi_mask = torch.where(wi[:,2:3,...] < 0, torch.zeros_like(wi[:,2:3,...]), torch.ones_like(wi[:,2:3,...]))\n        wi = wi.unsqueeze(1) # (bn, 1, 3)\n\n        # with torch.no_grad():\n        if True:\n            samples = torch.rand(bn, self.spp, 3, device=normal.device)\n            pS = probabilityToSampleSpecular(Kd, Ks)\n            clamp_value = 0.0\n            pS.clamp_min_(clamp_value)\n            sample_diffuse = samples[:,:,0] >= pS\n\n            ls_diffuse = square_to_cosine_hemisphere(samples[:,:,1:])\n            ls_specular = sample_ggx_specular(samples[:,:,1:], rough, wi)\n            wo = torch.where(sample_diffuse.unsqueeze(2).expand(bn, self.spp, 3), ls_diffuse, ls_specular) # (bn, spp, 3)\n        pdfs = pdf_ggx(Kd, Ks, rough, wi, wo, clamp_value).unsqueeze(2)\n        eval_diff, eval_spec, wo_mask = eval_ggx(Kd, Ks, rough, wi, wo)\n        # wo_mask = torch.all(wo_mask, dim=1)\n\n        direction = to_global(wo, cx.unsqueeze(1), cy.unsqueeze(1), cz.unsqueeze(1))\n\n        # surface_points = surface_points + 0.01 * view_direction # prevent self-intersection\n        surface_points = surface_points.unsqueeze(1).expand_as(direction).reshape(-1, 3)\n        direction = direction.reshape(-1, 3)\n        surface_points = surface_points + direction * 0.01 # prevent self-intersection\n        \n        pts_splits = torch.split(surface_points, self.split_n_pixels, dim=0)\n        dirs_splits = torch.split(direction, self.split_n_pixels, dim=0)\n        radiance = []\n        # with torch.no_grad():\n        for pts, dirs in zip(pts_splits, dirs_splits):\n            radiance.append(model.get_incident_radiance(pts, dirs, intersect_func))\n        radiance = torch.cat(radiance, dim=0)\n\n        radiance = radiance.view(bn, self.spp, 3)\n        if radiance_scale is not None:\n            radiance = radiance * radiance_scale[None,None,:]\n        pdfs = torch.clamp(pdfs, min=0.00001)\n        ndl = torch.clamp(wo[:,:,2:], min=0)\n\n        brdfDiffuse = eval_diff.expand(bn, self.spp, 3) * ndl / pdfs\n        colorDiffuse = torch.mean(brdfDiffuse * radiance, dim=1)\n        
brdfSpec = eval_spec.expand(bn, self.spp, 3) * ndl / pdfs\n        colorSpec = torch.mean(brdfSpec * radiance, dim=1)\n\n        return colorDiffuse, colorSpec, wi_mask\n\n\n"
  },
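  {
    "path": "examples/mc_render_sketch.py",
    "content": "# Sketch of the Monte Carlo estimator structure used in RenderingLayer.forward, with a\n# constant toy radiance instead of the incident-radiance network: draw directions from a\n# sampling density, evaluate brdf * cos(theta) / pdf, and average over spp. With\n# cosine-hemisphere sampling pdf = cos(theta) / pi, so a Lambertian lobe under constant\n# radiance L integrates back to exactly albedo * L (the ratio cancels termwise).\nimport math\nimport torch\n\ndef square_to_cosine_hemisphere(sample):  # sample: (n, 2) uniform in [0, 1)\n    phi = 2 * math.pi * sample[:, 0]\n    r = torch.sqrt(sample[:, 1])\n    z = torch.sqrt((1 - sample[:, 1]).clamp(min=0))\n    return torch.stack([r * torch.cos(phi), r * torch.sin(phi), z], dim=1)\n\nspp, L = 4096, 2.0                        # toy sample count and radiance\nalbedo = torch.tensor([0.5, 0.3, 0.1])\nwo = square_to_cosine_hemisphere(torch.rand(spp, 2))\ncos_theta = wo[:, 2:].clamp(min=1e-4)\npdf = cos_theta / math.pi                 # density of the sampled directions\nbrdf = albedo.expand(spp, 3) / math.pi    # Lambertian BRDF, Kd / pi as in eval_ggx\ncolor = (brdf * cos_theta / pdf * L).mean(dim=0)\nprint(color)                              # tensor([1.0000, 0.6000, 0.2000])\n"
  },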
  {
    "path": "model/rendering/brdf.py",
    "content": "import torch\nimport torch.nn.functional as F\nimport numpy as np\n\ndef create_frame(n: torch.Tensor, eps:float = 1e-6):\n    \"\"\"\n    Generate orthonormal coordinate system based on surface normal\n    [Duff et al. 17] Building An Orthonormal Basis, Revisited. JCGT. 2017.\n    :param: n (bn, 3, ...)\n    \"\"\"\n    z = F.normalize(n, dim=1, eps=eps)\n    sgn = torch.where(z[:,2,...] >= 0, 1.0, -1.0)\n    a = -1.0 / (sgn + z[:,2,...])\n    b = z[:,0,...] * z[:,1,...] * a\n    x = torch.stack([1.0 + sgn * z[:,0,...] * z[:,0,...] * a, sgn * b, -sgn * z[:,0,...]], dim=1)\n    y = torch.stack([b, sgn + z[:,1,...] * z[:,1,...] * a, -z[:,1,...]], dim=1)\n    return x, y, z\n\n\ndef get_rendering_parameters(albedo_raw, rough_raw, use_metallic):\n    if use_metallic:\n        assert albedo_raw.size(-1) == 3 and rough_raw.size(-1) == 2\n        metal = rough_raw[:,1:]\n        rough = rough_raw[:,:1].clamp_min(0.01)\n        Ks = baseColorToSpecularF0(albedo_raw, metal)\n        Kd = albedo_raw * (1 - metal)\n    else:\n        assert albedo_raw.size(-1) == 6 and rough_raw.size(-1) == 1\n        Kd = albedo_raw[:,:3]\n        Ks = albedo_raw[:,3:].clamp_min(0.04)\n        rough = rough_raw.clamp_min(0.01)\n    return Kd, Ks, rough\n\n\ndef to_global(d, x, y, z):\n    \"\"\"\n    d, x, y, z: (*, 3)\n    \"\"\"\n    return d[...,0:1] * x + d[...,1:2] * y + d[...,2:3] * z\n\ndef sqrt_(x: torch.Tensor, eps=1e-8) -> torch.Tensor:\n    \"\"\"\n    clamping 0 values of sqrt input to avoid NAN gradients\n    \"\"\"\n    return torch.sqrt(torch.clamp(x, min=eps))\n\ndef reflect(v: torch.Tensor, h: torch.Tensor):\n    dot = torch.sum(v*h, dim=2, keepdim=True)\n    return 2 * dot * h - v\n\ndef square_to_cosine_hemisphere(sample: torch.Tensor):\n    u, v = sample[:,:,0,...], sample[:,:,1,...]\n    phi = u * 2 * np.pi\n    r = sqrt_(v)\n    cos_theta = sqrt_(torch.clamp(1 - v, 0))\n    return torch.stack([torch.cos(phi) * r, torch.sin(phi) * r, cos_theta], dim=2)\n\n\ndef get_cos_theta(v: torch.Tensor):\n    return v[:,:,2,...]\n\n\ndef get_phi(v: torch.Tensor):\n    cos_theta = torch.clamp(v[:,:,2,...], min=0, max=1)\n    sin_theta = torch.clamp(sqrt_(1 - cos_theta*cos_theta), min=1e-8)\n    cos_phi = torch.clamp(v[:,:,0,...] / sin_theta, -1, 1)\n    sin_phi = v[:,:,1,...] 
/ sin_theta\n    phi = torch.acos(cos_phi) # (0, pi)\n    return torch.where(sin_phi > 0, phi, 2*np.pi - phi)\n\n\ndef sample_disney_specular(sample: torch.Tensor, roughness: torch.Tensor, wi: torch.Tensor):\n    \"\"\"\n    :param: sample (bn, spp, 3, h, w)\n    :param: roughness (bn, 1, 1, h, w)\n    :param: wi (*, *, 3, h, w), supposed to be normalized\n    :return: wo (bn, spp, 3, h, w), phi (bn, spp, h, w), cos theta (bn, spp, h, w)\n    \"\"\"\n    # a = torch.clamp(roughness, 0.001)\n    a = roughness\n    u, v = sample[:,:,0,...], sample[:,:,1,...]\n    phi = u * 2 * np.pi\n    cos_theta = sqrt_((1 - v) / (1 + (a*a - 1) * v))\n    sin_theta = sqrt_(1 - cos_theta*cos_theta)\n    cos_phi = torch.cos(phi)\n    sin_phi = torch.sin(phi)\n    half = torch.stack([sin_theta*cos_phi, sin_theta*sin_phi, cos_theta], dim=2)\n    wo = F.normalize(reflect(wi.expand_as(half), half), dim=2, eps=1e-8)\n    return wo\n    #, phi.squeeze(2), cos_theta.squeeze(2)\n\n\n\ndef GTR2(ndh, a):\n    a2 = a*a\n    t = 1.0 + (a2 - 1.0) * ndh * ndh\n    return a2 / (np.pi * t * t)\n\ndef SchlickFresnel(u):\n    m = torch.clamp(1.0 - u, 0, 1)\n    return m**5\n\ndef smithG_GGX(ndv, a):\n    a = a*a\n    b = ndv*ndv\n    return 1.0 / (ndv + sqrt_(a + b - a * b))\n\n\ndef pdf_disney(roughness: torch.Tensor, metallic: torch.Tensor, wi: torch.Tensor, wo: torch.Tensor):\n    \"\"\"\n    :param: roughness/metallic (bn, 1, h, w)\n    :param: wi (*, *, 3, h, w), supposed to be normalized\n    :param: wo (*, *, 3, h, w), supposed to be normalized\n    \"\"\"\n    # specularAlpha = torch.clamp(roughness, 0.001)\n    specularAlpha = roughness\n    diffuseRatio = 0.5 * (1 - metallic)\n    specularRatio = 1 - diffuseRatio\n    half = F.normalize(wi + wo, dim=2, eps=1e-8)\n    cosTheta = torch.abs(half[:,:,2,...])\n    pdfGTR2 = GTR2(cosTheta, specularAlpha) * cosTheta\n    pdfSpec = pdfGTR2 / torch.clamp(4.0 * torch.abs(torch.sum(wo*half, dim=2)), min=1e-8)\n    pdfDiff = torch.abs(wo[:,:,2,...]) / np.pi\n    pdf = diffuseRatio * pdfDiff + specularRatio * pdfSpec\n    pdf = torch.where(wi[:,:,2,...] < 0.0001, torch.ones_like(pdf) * 0.0001, pdf)\n    pdf = torch.where(wo[:,:,2,...] 
< 0.0001, torch.ones_like(pdf) * 0.0001, pdf)\n    return pdf\n\n\ndef eval_disney(albedo: torch.Tensor, roughness: torch.Tensor, metallic: torch.Tensor, wi: torch.Tensor, wo: torch.Tensor):\n    \"\"\"\n    :param: albedo/roughness/metallic (bn, c, h, w)\n    :param: wi (*, *, 3, h, w), supposed to be normalized\n    :param: wo (*, *, 3, h, w), supposed to be normalized\n    \"\"\"\n    h = wi + wo\n    h = F.normalize(h, dim=2, eps=1e-8)\n    \n    CSpec0 = torch.lerp(torch.ones_like(albedo)*0.04, albedo, metallic).unsqueeze(1)\n\n    ldh = torch.clamp(torch.sum(wo * h, dim=2), 0, 1).unsqueeze(2)\n    ndv = wi[:,:,2:3,...]\n    ndl = wo[:,:,2:3,...]\n    ndh = h[:,:,2:3,...]\n\n    FL, FV = SchlickFresnel(ndl), SchlickFresnel(ndv)\n    roughness = roughness.unsqueeze(1)\n    Fd90 = 0.5 + 2.0 * ldh * ldh * roughness\n    Fd = torch.lerp(torch.ones_like(Fd90), Fd90, FL) * torch.lerp(torch.ones_like(Fd90), Fd90, FV)\n\n    Ds = GTR2(ndh, roughness)\n    FH = SchlickFresnel(ldh)\n    Fs = torch.lerp(CSpec0, torch.ones_like(CSpec0), FH)\n    roughg = (roughness * 0.5 + 0.5) ** 2\n    Gs1, Gs2 = smithG_GGX(ndl, roughg), smithG_GGX(ndv, roughg)\n    Gs = Gs1 * Gs2\n\n    eval_diff = Fd * albedo.unsqueeze(1) * (1.0 - metallic.unsqueeze(1)) / np.pi\n    eval_spec = Gs * Fs * Ds\n    mask = torch.where(ndl < 0, torch.zeros_like(ndl), torch.ones_like(ndl)) # mask out back-facing outgoing directions\n    return eval_diff, eval_spec, mask\n\n\ndef F_Schlick(SpecularColor, VoH):\n    Fc = (1 - VoH)**5\n    return torch.clamp(50.0 * SpecularColor[:,:,1:2,...], min=0, max=1) * Fc + (1 - Fc) * SpecularColor\n\ndef GetSpecularEventProbability(SpecularColor, NoV) -> torch.Tensor:\n    f = F_Schlick(SpecularColor, NoV)\n    return (f[:,:,0,...] + f[:,:,1,...] + f[:,:,2,...]) / 3\n\ndef baseColorToSpecularF0(baseColor, metalness):\n    return torch.lerp(torch.empty_like(baseColor).fill_(0.04), baseColor, metalness)\n\ndef luminance(color):\n    if color.size(1) == 1:\n        return color\n    # return color.mean(dim=1, keepdim=True)\n    return color[:,0:1,...] * 0.212671 + color[:,1:2,...] * 0.715160 + color[:,2:3,...] * 0.072169\n\ndef probabilityToSampleSpecular(difColor, specColor) -> torch.Tensor:\n    lumDiffuse = torch.clamp(luminance(difColor), min=0.01)\n    lumSpecular = torch.clamp(luminance(specColor), min=0.01)\n    return lumSpecular / (lumDiffuse + lumSpecular)\n\ndef shadowedF90(F0):\n    t = 1 / 0.04\n    return torch.clamp(t * luminance(F0), max=1)\n\ndef evalFresnel(f0, f90, NdotS):\n    return f0 + (f90 - f0) * (1 - NdotS)**5\n\ndef Smith_G1_GGX(alphaSquared, NdotSSquared):\n    return 2 / (sqrt_(((alphaSquared * (1 - NdotSSquared)) + NdotSSquared) / NdotSSquared) + 1)\n\ndef Smith_G2_GGX(alphaSquared, NdotL, NdotV):\n    a = NdotV * sqrt_(alphaSquared + NdotL * (NdotL - alphaSquared * NdotL))\n    b = NdotL * sqrt_(alphaSquared + NdotV * (NdotV - alphaSquared * NdotV))\n    return 0.5 / (a + b)\n\ndef GGX_D(alphaSquared, NdotH):\n    b = ((alphaSquared - 1) * NdotH * NdotH + 1)\n    return alphaSquared / (np.pi * b * b)\n\ndef pdf_ggx(Kd: torch.Tensor, Ks: torch.Tensor, roughness: torch.Tensor, wi: torch.Tensor, wo: torch.Tensor, ps_min=0.0):\n    \"\"\"\n    :param: Kd/Ks (bn, 3, h, w)\n    :param: roughness (bn, 1, h, w)\n    :param: wi (*, *, 3, h, w), supposed to be normalized\n    :param: wo (*, *, 3, h, w), supposed to be normalized\n    :return: pdf (*, *, h, w)\n    \"\"\"\n    alpha = roughness * roughness\n    alphaSquared = alpha * alpha\n    NdotV = wi[:,:,2,...]\n    h = F.normalize(wi + wo, dim=2, eps=1e-8)\n    NdotH = h[:,:,2,...]\n    ggxd = GGX_D(torch.clamp(alphaSquared, min=0.00001), NdotH)\n    smith = Smith_G1_GGX(alphaSquared, NdotV * NdotV)\n    pdf_spec = ggxd * smith / (4 * NdotV)\n    with torch.no_grad():\n        pS = probabilityToSampleSpecular(Kd, Ks).clamp_min(ps_min)\n    pdf_diff = wo[:,:,2,...] / np.pi\n    pdf = pS * pdf_spec + (1 - pS) * pdf_diff\n    pdf = torch.where(wi[:,:,2,...] <= 0.0001, torch.ones_like(pdf) * 0.0001, pdf)\n    pdf = torch.where(wo[:,:,2,...] 
<= 0.0001, torch.ones_like(pdf) * 0.0001, pdf)\n    return pdf\n\ndef eval_ggx(Kd: torch.Tensor, Ks: torch.Tensor, roughness: torch.Tensor, wi: torch.Tensor, wo: torch.Tensor):\n    \"\"\"\n    :param: color (bn, c, h, w)\n    :param: roughness/metallic (bn, 1, h, w)\n    :param: wi (*, *, c, h, w), supposed to be normalized\n    :param: wo (*, *, c, h, w), supposed to be normalized\n    :return: fr(wi, wo) (*, *, c, h, w)\n    \"\"\"\n    NDotL = wo[:,:,2:3,...]\n    NDotV = wi[:,:,2:3,...]\n    H = F.normalize(wi + wo, dim=2, eps=1e-8)\n    NDotH = H[:,:,2:3,...]\n    LDotH = torch.sum(wo*H, dim=2, keepdim=True)\n    roughness = roughness.unsqueeze(1)\n    alpha = roughness * roughness\n    alpha2 = alpha * alpha\n    D = GGX_D(torch.clamp(alpha2, min=0.00001), NDotH)\n    G2 = Smith_G2_GGX(alpha2, NDotL, NDotV)\n    f = evalFresnel(Ks.unsqueeze(1), shadowedF90(Ks).unsqueeze(1), LDotH)\n    # spec = torch.where(NDotL <= 0, torch.zeros_like(NDotL), f * G2 * D)\n    # mask = torch.where(NDotL <= 0, torch.zeros_like(NDotL), torch.ones_like(NDotL))\n    spec = torch.where(NDotL < 0.0001, torch.ones_like(NDotL) * 0.0001, f * G2 * D)\n    # mask = torch.where(NDotL <= 0, torch.zeros_like(NDotL), torch.ones_like(NDotL))\n    mask = (NDotL >= 0.0001).squeeze(-1)\n    return Kd.unsqueeze(1) / np.pi, spec, mask\n\n\ndef sample_weight_ggx(alphaSquared, NdotL, NdotV):\n    G1V = Smith_G1_GGX(alphaSquared, NdotV*NdotV)\n    G1L = Smith_G1_GGX(alphaSquared, NdotL*NdotL)\n    return G1L / (G1V + G1L - G1V * G1L)\n\ndef sample_ggx(sample: torch.Tensor, Kd: torch.Tensor, Ks: torch.Tensor, roughness: torch.Tensor, wi: torch.Tensor):\n    \"\"\"\n    :param: sample (bn, spp, 3, h, w)\n    :param: roughness (bn, 1, h, w)\n    :param: wi (*, *, 3, h, w), supposed to be normalized\n    :return: wo (bn, spp, 3, h, w), weight (bn, spp, 3, h, w)\n    \"\"\"\n    with torch.no_grad():\n        pS = probabilityToSampleSpecular(Kd, Ks)\n    sample_diffuse = sample[:,:,2,...] 
>= pS\n\n    wo_diff = square_to_cosine_hemisphere(sample[:,:,1:,...])\n    weight_diff = Kd / (1 - pS)\n    weight_diff = weight_diff.unsqueeze(1)\n\n    roughness = roughness.unsqueeze(1)\n    alpha = roughness * roughness\n    # alpha = roughness\n    Vh = F.normalize(torch.cat([alpha * wi[:,:,0:1,...], alpha * wi[:,:,1:2,...], wi[:,:,2:3,...]], dim=2), dim=2, eps=1e-8)\n    lensq = Vh[:,:,0:1,...]**2 + Vh[:,:,1:2,...]**2\n    zero_ = torch.zeros_like(Vh[:,:,0,...])\n    one_ = torch.ones_like(Vh[:,:,0,...])\n    T1 = torch.where(\n        lensq > 0, \n        torch.stack([-Vh[:,:,1,...], Vh[:,:,0,...], zero_], dim=2) / sqrt_(lensq),\n        torch.stack([one_, zero_, zero_], dim=2)\n    )\n    T2 = torch.cross(Vh, T1, dim=2)\n    r = sqrt_(sample[:,:,0:1,...])\n    phi = 2 * np.pi * sample[:,:,1:2,...]\n    t1 = r * torch.cos(phi)\n    t2 = r * torch.sin(phi)\n    s = 0.5 * (1 + Vh[:,:,2:3,...])\n    t2 = torch.lerp(sqrt_(1 - t1**2), t2, s)\n    Nh = t1 * T1 + t2 * T2 + sqrt_(torch.clamp(1 - t1*t1 - t2*t2, min=0)) * Vh\n    h = F.normalize(torch.cat([alpha * Nh[:,:,0:1,...], alpha * Nh[:,:,1:2,...], torch.clamp(Nh[:,:,2:3,...], min=0)], dim=2), dim=2, eps=1e-8)\n    wo = reflect(wi, h)\n\n    HdotL = torch.clamp(torch.sum(h*wo, dim=2, keepdim=True), min=0.0001, max=1.0)\n    NdotL = torch.clamp(wo[:,:,2:3,...], min=0.0001, max=1.0)\n    NdotV = torch.clamp(wi[:,:,2:3,...], min=0.0001, max=1.0)\n    # NdotH = torch.clamp(h[:,:,2:3,...], min=0.00001, max=1.0)\n    # F = evalFresnel(specularF0, shadowedF90(specularF0), HdotL)\n    weight = evalFresnel(Ks, shadowedF90(Ks), HdotL) * sample_weight_ggx(alpha*alpha, NdotL, NdotV) / pS.unsqueeze(1)\n\n    wo = torch.where(sample_diffuse.unsqueeze(2), wo_diff, wo)\n    weight = torch.where(sample_diffuse.unsqueeze(2), weight_diff, weight)\n\n    return wo, weight\n\n\n\ndef sample_ggx_specular(sample: torch.Tensor, roughness: torch.Tensor, wi: torch.Tensor):\n    \"\"\"\n    :param: sample (bn, spp, 2, h, w)\n    :param: roughness (bn, 1, h, w)\n    :param: wi (*, *, 3, h, w), supposed to be normalized\n    :return: wo (bn, spp, 3, h, w), phi (bn, spp, h, w), cos theta (bn, spp, h, w)\n    \"\"\"\n    roughness = roughness.unsqueeze(1)\n    alpha = roughness * roughness\n    # alpha = roughness\n    Vh = F.normalize(torch.cat([alpha * wi[:,:,0:1,...], alpha * wi[:,:,1:2,...], wi[:,:,2:3,...]], dim=2), dim=2, eps=1e-8)\n    # bn, spp, _, row, col = Vh.shape\n    # Vh = Vh.view(-1, 3, row, col)\n    # T1, T2, Vh = utils.hughes_moeller(Vh)\n    # T1 = T1.view(bn, spp, 3, row, col)\n    # T2 = T2.view(bn, spp, 3, row, col)\n    # Vh = Vh.view(bn, spp, 3, row, col)\n    lensq = Vh[:,:,0:1,...]**2 + Vh[:,:,1:2,...]**2\n    zero_ = torch.zeros_like(Vh[:,:,0,...])\n    one_ = torch.ones_like(Vh[:,:,0,...])\n    T1 = torch.where(\n        lensq > 0, \n        torch.stack([-Vh[:,:,1,...], Vh[:,:,0,...], zero_], dim=2) / sqrt_(lensq),\n        torch.stack([one_, zero_, zero_], dim=2)\n    )\n    T2 = torch.cross(Vh, T1, dim=2)\n    r = sqrt_(sample[:,:,0:1,...])\n    phi = 2 * np.pi * sample[:,:,1:2,...]\n    t1 = r * torch.cos(phi)\n    t2 = r * torch.sin(phi)\n    s = 0.5 * (1 + Vh[:,:,2:3,...])\n    t2 = torch.lerp(sqrt_(1 - t1**2), t2, s)\n    Nh = t1 * T1 + t2 * T2 + sqrt_(torch.clamp(1 - t1*t1 - t2*t2, min=0)) * Vh\n    h = F.normalize(torch.cat([alpha * Nh[:,:,0:1,...], alpha * Nh[:,:,1:2,...], torch.clamp(Nh[:,:,2:3,...], min=0)], dim=2), dim=2, eps=1e-8)\n    wo = reflect(wi, h)\n    return wo"
  },
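  {
    "path": "examples/onb_check_sketch.py",
    "content": "# Quick numerical check (standalone, toy inputs) of the branchless orthonormal-basis\n# construction from [Duff et al. 2017] used by brdf.create_frame: for random unit\n# normals, the returned (x, y, z) axes should be mutually orthogonal unit vectors with\n# z aligned to the input normal, and no degenerate frame even near z = (0, 0, -1).\nimport torch\nimport torch.nn.functional as F\n\ndef create_frame(n, eps=1e-6):  # n: (bn, 3); same math as brdf.create_frame for this shape\n    z = F.normalize(n, dim=1, eps=eps)\n    sgn = torch.where(z[:, 2] >= 0, 1.0, -1.0)\n    a = -1.0 / (sgn + z[:, 2])\n    b = z[:, 0] * z[:, 1] * a\n    x = torch.stack([1.0 + sgn * z[:, 0] ** 2 * a, sgn * b, -sgn * z[:, 0]], dim=1)\n    y = torch.stack([b, sgn + z[:, 1] ** 2 * a, -z[:, 1]], dim=1)\n    return x, y, z\n\nn = F.normalize(torch.randn(1024, 3), dim=1)\nx, y, z = create_frame(n)\nfor u, v in [(x, y), (x, z), (y, z)]:\n    assert (u * v).sum(-1).abs().max() < 1e-4   # pairwise orthogonal\nassert (x.norm(dim=-1) - 1).abs().max() < 1e-4  # unit length\nassert (z - n).abs().max() < 1e-4               # z is the input normal\nprint('orthonormal frame OK')\n"
  },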
  {
    "path": "model/trainer/__init__.py",
    "content": "from .recon import ReconstructionTrainer"
  },
  {
    "path": "model/trainer/recon.py",
    "content": "import math\nimport torch\nimport pytorch_lightning as pl\nimport numpy as np\nimport torch.optim as optim\nimport os\nfrom torch.utils.data import DataLoader\nimport utils\nfrom utils import rend_util\nimport utils.plots as plt\nimport dataset\nimport model\nfrom tqdm import trange\nfrom pytorch_lightning.callbacks import RichProgressBar\nfrom torchmetrics.functional import structural_similarity_index_measure as ssim\nfrom torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity as LPIPS\nos.environ[\"OPENCV_IO_ENABLE_OPENEXR\"]=\"1\"\nimport cv2\n\n\nlpips = LPIPS()\n\nclass ReconstructionTrainer(pl.LightningModule):\n    def __init__(self, conf, prog_bar: RichProgressBar, exp_dir, model_only=False, val_mesh=False, is_val=False, **kwargs) -> None:\n        super().__init__()\n        self.conf = conf\n        self.prog_bar = prog_bar\n        self.batch_size = conf.train.batch_size\n        self.bubble_batch_size = getattr(conf.train, 'bubble_batch_size', self.batch_size)\n        self.expdir = exp_dir\n        self.val_mesh = val_mesh\n\n        conf_model = conf.model\n        use_normal = (getattr(conf.loss, 'normal_weight', 0) > 0) or (getattr(conf.loss, 'angular_weight', 0) > 0)\n        conf_model.use_normal = use_normal\n        self.model = model.I2SDFNetwork(conf_model)\n        \n        if model_only:\n            return\n\n        print('[INFO] Loading data ...')\n        dataset_conf = conf.dataset\n        self.scan_id = dataset_conf.scan_id\n        self.train_dataset = dataset.ReconDataset(\n                **dataset_conf, \n                use_mask=getattr(conf.loss, 'mask_weight', 0) > 0, \n                use_depth=getattr(conf.loss, 'depth_weight', 0) > 0, \n                use_normal=use_normal, \n                use_bubble=getattr(conf.loss, 'bubble_weight', 0) > 0, \n                use_lightmask=getattr(conf.loss, 'light_mask_weight', 0) > 0\n            )\n        if self.train_dataset.use_bubble:\n            os.makedirs(os.path.join(self.expdir, 'hotmap'), exist_ok=True)\n            os.makedirs(os.path.join(self.expdir, 'countmap'), exist_ok=True)\n            self.pdf_criterion = getattr(conf.train, 'pdf_criterion', 'DEPTH')\n            assert self.pdf_criterion in ['RGB', 'DEPTH']\n\n        self.is_hdr = self.train_dataset.is_hdr\n        self.plots_dir = os.path.join(self.expdir, 'plots')\n        self.trace_bub_idx = self.conf.train.get('trace_bub_idx', -1)\n        if self.trace_bub_idx != -1:\n            os.makedirs(f\"{self.plots_dir}/bubble\", exist_ok=True)\n            print(f\"[INFO] Activate hotmap visualization for #{self.trace_bub_idx}\")\n            self.plot_dataset = dataset.PlotDataset(**dataset_conf, indices=[self.trace_bub_idx], plot_nimgs=1, is_val=is_val)\n        else:\n            data = {\n                'intrinsics': self.train_dataset.intrinsics_all,\n                'pose': self.train_dataset.pose_all,\n                'rgb': self.train_dataset.rgb_images,\n                'img_res': self.train_dataset.img_res\n            }\n            if self.train_dataset.use_lightmask:\n                data['light_mask'] = self.train_dataset.lightmask_images\n            self.plot_dataset = dataset.PlotDataset(**dataset_conf, data=data, plot_nimgs=conf.plot.plot_nimgs, is_val=is_val)\n\n        os.makedirs(self.plots_dir, exist_ok=True)\n        with open(f\"{self.expdir}/config.yml\", 'w') as f:\n            f.write(self.conf.dump())\n        if self.train_dataset.use_bubble:\n            points = 
self.train_dataset.pointcloud\n            index = torch.randperm(points.size(0))[:200000]\n            points = points[index,:]\n            plt.visualize_pointcloud(points, f\"{self.expdir}/pointcloud.html\")\n            print(f\"[INFO] Pointcloud visualization success: saved to {self.expdir}/pointcloud.html\")\n            self.pdf_prune = self.train_dataset.pdf_prune\n            self.pdf_max = self.train_dataset.pdf_max\n\n        # self.ds_len = len(self.train_dataset)\n        self.ds_len = self.train_dataset.n_images\n        print('[INFO] Finish loading data. Data-set size: {0}'.format(self.ds_len))\n        epoch_steps = len(self.train_dataset) / self.batch_size\n        self.nepochs = int(math.ceil(200000 / epoch_steps))\n\n        self.loss = model.I2SDFLoss(**conf.loss)\n        self.total_pixels = self.plot_dataset.total_pixels\n        self.img_res = self.plot_dataset.img_res\n        self.bubble_activated = False\n        self.uniform_bubble = getattr(self.conf.train, 'uniform_bubble', False)\n        if self.uniform_bubble:\n            print(\"[INFO] Ablation study: uniform sampling for bubble loss\")\n        self.checkpoint_freq = self.conf.train.checkpoint_freq\n        self.split_n_pixels = self.conf.train.split_n_pixels\n        self.plot_conf = self.conf.plot\n        self.progbar_task = None\n        if self.train_dataset.use_lightmask and getattr(self.conf.train, 'flip_light', False):\n            self.train_dataset.lightmask_images = 1.0 - self.train_dataset.lightmask_images\n            self.plot_dataset.lightmask_images = 1.0 - self.plot_dataset.lightmask_images\n\n    def forward(self):\n        raise NotImplementedError(\"forward not supported by trainer\")\n\n    def plot_hotmap(self, path):\n        assert self.bubble_activated\n        ds = self.train_dataset\n        hotmaps = torch.zeros(self.ds_len * ds.total_pixels)\n        hotmaps[ds.pixlinks] = self.pdf.cpu()\n        hotmaps = hotmaps.reshape(self.ds_len, *ds.img_res)\n        for i, hotmap in enumerate(hotmaps):\n            hotmap = hotmap.numpy()\n            # hotmap /= max(1e-4, hotmap.max())\n            hotmap = (hotmap * 255).astype(np.uint8)\n            hotmap = cv2.applyColorMap(hotmap, cv2.COLORMAP_MAGMA)\n            cv2.imwrite(os.path.join(path, \"{:04d}.png\".format(i)), hotmap)\n            if self.trace_bub_idx == i:\n                cv2.imwrite(os.path.join(f\"{self.plots_dir}/bubble\", f\"{self.global_step}_hot.png\"), hotmap)\n    \n    def plot_countmap(self, path):\n        assert self.bubble_activated\n        ds = self.train_dataset\n        countmaps = torch.zeros(self.ds_len * ds.total_pixels)\n        countmaps[ds.pixlinks] = self.sample_count.cpu().float()\n        countmaps = countmaps.reshape(self.ds_len, *ds.img_res)\n        countmaps = countmaps / max(1, countmaps.max())\n        for i, countmap in enumerate(countmaps):\n            countmap = countmap.numpy()\n            countmap = (countmap * 255).astype(np.uint8)\n            countmap = cv2.applyColorMap(countmap, cv2.COLORMAP_MAGMA)\n            cv2.imwrite(os.path.join(path, \"{:04d}.png\".format(i)), countmap)\n            if self.trace_bub_idx == i:\n                cv2.imwrite(os.path.join(f\"{self.plots_dir}/bubble\", f\"{self.global_step}_cnt.png\"), countmap)\n\n    def update_pdf(self, value, idx):\n        assert self.bubble_activated\n        ds = self.train_dataset\n        value = value.to(self.pdf.device)\n        if self.pdf_max is not None:\n            value = value.clamp(max=self.pdf_max)\n   
     value[value < self.pdf_prune] = 0 # PDF pruning\n        link = ds.pointlinks[idx]\n        mask = (link != -1)\n        link = link[mask]\n        value = value[mask]\n        self.pdf[link] = value\n    \n    def sample_bubble(self, batch_size):\n        assert self.bubble_activated\n        ds = self.train_dataset\n        if self.uniform_bubble:\n            sample_idx = torch.randperm(ds.pointcloud.size(0), device=ds.pointcloud.device)[:batch_size]\n            return ds.pointcloud[sample_idx,:]\n        sample_idx = torch.where(self.pdf > 0)[0]\n        pdf_samples = self.pdf[sample_idx]\n        pointcloud_samples = ds.pointcloud[sample_idx,:]\n        if sample_idx.size(0) >= (1 << 24):\n            # print(sample_idx.size(0), self.pdf.size(0), (1 << 24))\n            print(\"[ERROR] PDF capacity exceeds maximum limit of PyTorch\")\n            exit(1)\n        idx = torch.multinomial(pdf_samples, batch_size, replacement=False) # importance sampling\n        self.sample_count[sample_idx[idx]] += 1\n        return pointcloud_samples[idx,:]\n\n    def initialize_bubble_pdf(self, split_size):\n        ds = self.train_dataset\n        # ds.pdf = ds.pdf.cuda()\n        self.register_buffer('pdf', torch.zeros(len(ds.pointcloud)), False)\n        self.register_buffer('sample_count', torch.zeros(len(ds.pointcloud)), False)\n        self.pdf = self.pdf.cuda()\n        # self.sample_count = self.sample_count.cuda()\n        for i in trange(ds.n_images):\n            intrinsics = ds.intrinsics_all[i].cuda().unsqueeze(0)\n            pose = ds.pose_all[i].cuda().unsqueeze(0)\n            img = ds.rgb_images[i].cuda() if self.pdf_criterion != 'DEPTH' else ds.depth_images[i].cuda()\n            uv = ds.uv.cuda().unsqueeze(1) # (h*w, 1, 2)\n            img_splits = torch.split(img, split_size)\n            uv_splits = torch.split(uv, split_size)\n            indices = torch.arange(i * ds.total_pixels, (i + 1) * ds.total_pixels, dtype=torch.long, device='cuda')\n            index_splits = torch.split(indices, split_size)\n            for img_split, uv_split, index_split in zip(img_splits, uv_splits, index_splits):\n                data = {\n                    'uv': uv_split,\n                    'intrinsics': intrinsics.repeat(len(uv_split), 1, 1),\n                    'pose': pose.repeat(len(uv_split), 1, 1)\n                }\n                model_output = self.model.forward(data, True)\n                if self.pdf_criterion == 'RGB':\n                    self.update_pdf((model_output['rgb_values'].detach().clamp(0, 1) - img_split.clamp(0, 1)).abs().mean(dim=-1), index_split)\n                # elif self.pdf_criterion == 'DEPTH':\n                else:\n                    self.update_pdf((model_output['depth_values'].detach() - img_split).abs(), index_split)\n\n    def configure_optimizers(self):\n        lr = self.conf.train.learning_rate\n        optimizer = optim.Adam(self.model.get_param_groups(lr), eps=1e-15)\n        decay_rate = getattr(self.conf.train, 'sched_decay_rate', 0.1)\n        decay_steps = self.nepochs * self.ds_len\n        scheduler = optim.lr_scheduler.ExponentialLR(optimizer, decay_rate ** (1./decay_steps))\n        return [optimizer], [scheduler]\n\n    def train_dataloader(self):\n        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, collate_fn=self.train_dataset.collate_fn, num_workers=4)\n        \n    def val_dataloader(self):\n        return DataLoader(self.plot_dataset, batch_size=self.conf.plot.plot_nimgs, shuffle=False, 
collate_fn=self.train_dataset.collate_fn)\n\n    def log_if_nonzero(self, name, value, *args, **kwargs):\n        if value > 0:\n            self.log(name, value, *args, **kwargs)\n\n    def training_step(self, batch, batch_idx):\n        indices, img_indices, model_input, ground_truth = batch\n        if not self.bubble_activated and self.train_dataset.use_bubble and self.global_step >= self.loss.min_bubble_iter and self.global_step < self.loss.max_bubble_iter:\n            # Start bubble step\n            with torch.no_grad():\n                self.bubble_activated = True\n                self.train_dataset.pointcloud = self.train_dataset.pointcloud.cuda()\n                # self.loss.eikonal_weight = self.loss.eikonal_weight_pointcloud\n\n                # Disable normal loss, since it will discourage the growth of bubbles\n                self.loss.normal_weight_bak = self.loss.normal_weight\n                self.loss.normal_weight = 0.0\n                self.loss.angular_weight_bak = self.loss.angular_weight\n                self.loss.angular_weight = 0.0\n                if not self.uniform_bubble:\n                    print(f\"[INFO] Start initializing pointcloud PDF, criterion: {self.pdf_criterion}\")\n                    self.initialize_bubble_pdf(self.split_n_pixels) # initialize PDF maps for each image by computing losses\n                    torch.save(self.pdf, os.path.join(self.expdir, 'checkpoints', \"pdf.pt\"))\n                    torch.cuda.empty_cache()\n                    self.plot_hotmap(os.path.join(self.expdir, 'hotmap'))\n                    print(\"[INFO] Finished initializing pointcloud PDF\")\n                    print(f\"[INFO] {torch.count_nonzero(self.pdf).item()}/{self.pdf.size(0)} points to be sampled\")\n\n        if self.bubble_activated:\n            model_input['pointcloud'] = self.sample_bubble(self.bubble_batch_size)\n\n        model_outputs = self.model(model_input)\n        if self.bubble_activated and not self.uniform_bubble:\n            with torch.no_grad():\n                if self.pdf_criterion == 'RGB':\n                    self.update_pdf((model_outputs['rgb_values'].detach().clamp(0, 1) - ground_truth['rgb'].clamp(0, 1)).abs().mean(dim=-1), indices)\n                # elif self.pdf_criterion == 'DEPTH':\n                else:\n                    self.update_pdf((model_outputs['depth_values'].detach() - ground_truth['depth']).abs(), indices)\n\n        loss_output = self.loss(model_outputs, ground_truth, self.global_step)\n        if self.bubble_activated and self.loss.max_bubble_iter is not None and self.global_step >= self.loss.max_bubble_iter:\n            # End bubble step\n            self.train_dataset.use_bubble = False\n            self.bubble_activated = False\n            del self.train_dataset.pointcloud\n            del self.train_dataset.pointlinks\n            del self.train_dataset.pixlinks\n            if not self.uniform_bubble:\n                delattr(self, 'pdf')\n                delattr(self, 'sample_count')\n            torch.cuda.empty_cache()\n            # self.loss.eikonal_weight = self.conf.loss.eikonal_weight\n            # Restore normal loss\n            self.loss.normal_weight = self.loss.normal_weight_bak\n            self.loss.angular_weight = self.loss.angular_weight_bak\n        loss = loss_output['loss']\n\n        with torch.no_grad():\n            psnr = rend_util.get_psnr(model_outputs['rgb_values'].detach(), ground_truth['rgb'].view(-1, 3))\n            self.log('train/loss', loss.item())\n         
   self.log('train/psnr', psnr.item(), True)\n            self.log('train/rgb_loss', loss_output['rgb_loss'].item())\n            self.log_if_nonzero('train/eikonal_loss', loss_output['eikonal_loss'].item())\n            self.log_if_nonzero('train/smooth_loss', loss_output['smooth_loss'].item())\n            self.log_if_nonzero('train/mask_loss', loss_output['mask_loss'].item())\n            self.log_if_nonzero('train/depth_loss', loss_output['depth_loss'].item())\n            self.log_if_nonzero('train/normal_loss', loss_output['normal_loss'].item())\n            self.log_if_nonzero('train/angular_loss', loss_output['angular_loss'].item())\n            self.log_if_nonzero('train/bubble_loss', loss_output['bubble_loss'].item())\n            self.log_if_nonzero('train/light_mask_loss', loss_output['light_mask_loss'].item())\n            self.log('train/beta', self.model.density.beta.item())\n\n        return loss\n\n\n    def validation_step(self, batch, batch_idx):\n\n        indices, model_input, ground_truth = batch\n\n        split = utils.split_input(model_input, self.total_pixels, self.split_n_pixels)\n        res = []\n        if self.progbar_task is None and self.prog_bar.progress:\n            self.progbar_task = self.prog_bar.progress.add_task(\"[cyan]Validation split\", total=len(split))\n        elif self.progbar_task:\n            self.prog_bar.progress.reset(self.progbar_task, total=len(split), visible=True)\n\n        for s in split:\n            out = utils.detach_dict(self.model(s))\n            d = {\n                'rgb_values': out['rgb_values'].detach(),\n                'depth_values': out['depth_values'].detach()\n            }\n            if 'normal_map' in out:\n                d['normal_map'] = out['normal_map'].detach()\n            if 'light_mask' in out:\n                d['light_mask'] = out['light_mask'].detach()\n            del out\n            res.append(d)\n            if self.progbar_task:\n                self.prog_bar.progress.update(self.progbar_task, advance=1, refresh=True)\n        if self.progbar_task:\n            self.prog_bar.progress.update(self.progbar_task, visible=False)\n        batch_size = ground_truth['rgb'].shape[0]\n        model_outputs = utils.merge_output(res, self.total_pixels, batch_size)\n\n        def get_plot_data(model_outputs, pose, ground_truth):\n            rgb_gt = ground_truth['rgb']\n            batch_size, num_samples, _ = rgb_gt.shape\n            rgb_eval = model_outputs['rgb_values'].reshape(batch_size, num_samples, 3)\n            if self.is_hdr:\n                eval_hdr = rgb_eval\n                gt_hdr = rgb_gt\n                rgb_eval = rend_util.linear_to_srgb(rgb_eval.clamp(0, 1))\n                rgb_gt = rend_util.linear_to_srgb(rgb_gt.clamp(0, 1))\n            depth_eval = model_outputs['depth_values'].reshape(batch_size, num_samples, 1)\n            plot_data = {\n                'rgb_gt': rgb_gt,\n                'pose': pose,\n                'rgb_eval': rgb_eval,\n                'depth_eval': depth_eval\n            }\n            if self.is_hdr:\n                plot_data['hdr_gt'] = gt_hdr\n                plot_data['hdr_eval'] = eval_hdr\n            if 'normal_map' in model_outputs:\n                normal_map = model_outputs['normal_map'].reshape(batch_size, num_samples, 3)\n                normal_map = normal_map.transpose(1, 2) # (bn, 3, h*w)\n                R = pose[:,:3,:3].transpose(1, 2)\n                normal_map = torch.bmm(R, normal_map) # world to camera\n                
normal_map = normal_map.transpose(1, 2)\n                normal_map = (normal_map + 1.) / 2.\n                plot_data['normal_map'] = normal_map\n            if 'light_mask' in model_outputs:\n                plot_data['lmask_eval'] = model_outputs['light_mask'].reshape(batch_size, num_samples, 1)\n                plot_data['lmask_gt'] = ground_truth['light_mask'].reshape(batch_size, num_samples, 1)\n            return plot_data\n\n        plot_data = get_plot_data(model_outputs, model_input['pose'], ground_truth)\n        return {\n            'indices': indices,\n            'plot_data': plot_data\n        }\n\n    def validation_epoch_end(self, outputs) -> None:\n        self.plot_dataset.shuffle_plot_index()\n        indices = torch.cat([x['indices'] for x in outputs], dim=0)\n        plot_data = utils.merge_dict([x['plot_data'] for x in outputs])\n\n        rgb_eval = plot_data['rgb_eval']\n        rgb_gt = plot_data['rgb_gt']\n        psnr = rend_util.get_psnr(rgb_eval, rgb_gt)\n        self.log('val/psnr', psnr.item())\n        rgb_gt = rgb_gt.transpose(1, 2).view(-1, 3, *self.img_res) # (bn, h*w, 3) => (bn, 3, h, w)\n        rgb_eval = rgb_eval.transpose(1, 2).view(-1, 3, *self.img_res)\n        self.log('val/ssim', ssim(rgb_eval, rgb_gt).item())\n        lpips.to(rgb_eval.device)\n        self.log('val/lpips', lpips(rgb_eval.clamp(0, 1) * 2 - 1, rgb_gt.clamp(0, 1) * 2 - 1).item())\n\n        os.makedirs(self.plots_dir, exist_ok=True)\n        os.makedirs('{0}/rendering'.format(self.plots_dir), exist_ok=True)\n        if self.is_hdr:\n            os.makedirs('{0}/hdr'.format(self.plots_dir), exist_ok=True)\n        os.makedirs('{0}/depth'.format(self.plots_dir), exist_ok=True)\n        if 'normal_map' in plot_data:\n            os.makedirs('{0}/normal'.format(self.plots_dir), exist_ok=True)\n        if 'lmask_eval' in plot_data:\n            os.makedirs('{0}/light_mask'.format(self.plots_dir), exist_ok=True)\n        if self.val_mesh:\n            os.makedirs('{0}/mesh'.format(self.plots_dir), exist_ok=True)\n        if self.bubble_activated and not self.uniform_bubble:\n            self.plot_hotmap(os.path.join(self.expdir, 'hotmap'))\n            self.plot_countmap(os.path.join(self.expdir, 'countmap'))\n        plt.plot(self.model.implicit_network,\n                indices,\n                plot_data,\n                self.plots_dir,\n                self.global_step,\n                self.img_res,\n                meshing=self.val_mesh,\n                **self.plot_conf\n                )\n\n\n"
  },
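  {
    "path": "examples/bubble_sampling_sketch.py",
    "content": "# Sketch of the importance-sampling step behind ReconstructionTrainer.sample_bubble and\n# update_pdf, with a toy pointcloud and made-up per-point errors: points whose PDF value\n# survives pruning are drawn by torch.multinomial in proportion to their error, so bubble\n# supervision concentrates where the reconstruction currently disagrees with ground truth.\nimport torch\n\npointcloud = torch.rand(10000, 3)     # toy points\npdf = torch.rand(10000)               # stand-in for per-point depth/RGB error\npdf[pdf < 0.5] = 0                    # PDF pruning, as in update_pdf\ncandidates = torch.where(pdf > 0)[0]  # only positive-mass points can be drawn\nidx = torch.multinomial(pdf[candidates], 512, replacement=False)\nbatch = pointcloud[candidates[idx]]\nprint(batch.shape)                    # torch.Size([512, 3])\n"
  },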
  {
    "path": "utils/__init__.py",
    "content": "from .cfgnode import CfgNode\nfrom .rend_util import *\nimport torch\nimport torch.nn.functional as F\nimport torch.nn as nn\nfrom glob import glob\nimport os\nimport numpy as np\nfrom pytorch_lightning.callbacks import RichProgressBar\nfrom rich.progress import TextColumn\n\nclass RichProgressBarWithScanId(RichProgressBar):\n    def __init__(self, scan_id, *args, **kwargs) -> None:\n        super().__init__(*args, **kwargs)\n        self.custom_column = TextColumn(f\"[progress.description]scan_id: {scan_id}\")\n    \n    def configure_columns(self, trainer):\n        return super().configure_columns(trainer) + [self.custom_column]\n\n\ndef glob_imgs(path):\n    imgs = []\n    for ext in ['*.png', '*.jpg', '*.JPEG', '*.JPG', '*.exr']:\n        imgs.extend(glob(os.path.join(path, ext)))\n    return imgs\n\ndef glob_depths(path):\n    imgs = []\n    for ext in ['*.exr']:\n        imgs.extend(glob(os.path.join(path, ext)))\n    return imgs\n\nglob_normal = glob_depths\n\ndef split_input(model_input, total_pixels, n_pixels=10000):\n    '''\n     Split the input to fit Cuda memory for large resolution.\n     Can decrease the value of n_pixels in case of cuda out of memory error.\n     '''\n    split = []\n    for i, indx in enumerate(torch.split(torch.arange(total_pixels, device=model_input['uv'].device), n_pixels, dim=0)):\n        data = model_input.copy()\n        data['uv'] = torch.index_select(model_input['uv'], 1, indx)\n        if 'object_mask' in data:\n            data['object_mask'] = torch.index_select(model_input['object_mask'], 1, indx)\n        split.append(data)\n    return split\n\n\ndef split_dict(d, batch_size=10000):\n    keys = d.keys()\n    splits = {}\n    for k in d:\n        splits[k] = torch.split(d[k], batch_size)\n        n_splits = len(splits[k])\n    split_inputs = []\n    for i in range(n_splits):\n        split = {}\n        for k in d:\n            split[k] = splits[k][i]\n        split_inputs.append(split)\n    return split_inputs\n\n\n\ndef detach_dict(d):\n    return {k: v.detach() for k, v in d.items() if torch.is_tensor(v)}\n\n\ndef merge_output(res, total_pixels, batch_size):\n    ''' Merge the split output. 
'''\n\n    model_outputs = {}\n    for entry in res[0]:\n        if res[0][entry] is None:\n            continue\n        if len(res[0][entry].shape) == 1:\n            model_outputs[entry] = torch.cat([r[entry].reshape(batch_size, -1, 1) for r in res],\n                                             1).reshape(batch_size * total_pixels)\n        else:\n            model_outputs[entry] = torch.cat([r[entry].reshape(batch_size, -1, r[entry].shape[-1]) for r in res],\n                                             1).reshape(batch_size * total_pixels, -1)\n\n    return model_outputs\n\n\ndef merge_dict(dicts):\n    output = {}\n    for entry in dicts[0]:\n        output[entry] = torch.cat([r[entry] for r in dicts], dim=0)\n    return output\n\nfrom torch.autograd import Function\nfrom torch.cuda.amp import custom_bwd, custom_fwd \n\nclass _trunc_exp(Function):\n    @staticmethod\n    @custom_fwd(cast_inputs=torch.float32) # cast to float32\n    def forward(ctx, x):\n        ctx.save_for_backward(x)\n        return torch.exp(x)\n\n    @staticmethod\n    @custom_bwd\n    def backward(ctx, g):\n        x = ctx.saved_tensors[0]\n        return g * torch.exp(x.clamp(-15, 15))\n\ntrunc_exp = _trunc_exp.apply\n\ndef kmeans_pp_centroid(points: torch.Tensor, k):\n    n, c = points.shape\n    centroids = torch.zeros(k, c, device=points.device)\n    centroids[0, :] = points[np.random.randint(0, n), :].clone()\n    d = [0.0] * n\n    for i in range(1, k):\n        sum_all = 0\n        d = (points.unsqueeze(1) - centroids[:i,:].unsqueeze(0)).norm(p=2, dim=-1).min(dim=1).values\n        sum_all = d.sum() * np.random.random()\n        cumsum = torch.cumsum(d, dim=0)\n        j = ((cumsum - sum_all) > 0).int().argmax()\n        centroids[i,:] = points[j,:].clone()\n    return centroids"
  },
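  {
    "path": "examples/trunc_exp_sketch.py",
    "content": "# Sketch of why utils.trunc_exp clamps inside backward (the AMP decorators are omitted\n# here): for large activations, d/dx exp(x) = exp(x) overflows float32 and back-propagates\n# inf, while the truncated version saturates the gradient at exp(15) and keeps training\n# alive.\nimport torch\n\nclass _trunc_exp(torch.autograd.Function):\n    @staticmethod\n    def forward(ctx, x):\n        ctx.save_for_backward(x)\n        return torch.exp(x)\n\n    @staticmethod\n    def backward(ctx, g):\n        x, = ctx.saved_tensors\n        return g * torch.exp(x.clamp(-15, 15))  # clamp the input before exponentiating\n\nx = torch.full((1,), 100.0, requires_grad=True)\n_trunc_exp.apply(x).backward()\nprint(x.grad)  # ~3.27e6 = exp(15), finite\n\ny = torch.full((1,), 100.0, requires_grad=True)\ntorch.exp(y).backward()\nprint(y.grad)  # tensor([inf]): the failure mode trunc_exp avoids\n"
  },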
  {
    "path": "utils/cfgnode.py",
    "content": "\"\"\"\nDefine a class to hold configurations.\nBorrows and merges stuff from YACS, fvcore, and detectron2\nhttps://github.com/rbgirshick/yacs\nhttps://github.com/facebookresearch/fvcore/\nhttps://github.com/facebookresearch/detectron2/\n\"\"\"\n\nimport copy\nimport importlib.util\nimport io\nimport logging\nimport os\nfrom ast import literal_eval\nfrom typing import Optional\n\nimport yaml\n\n# File exts for yaml\n_YAML_EXTS = {\"\", \".yml\", \".yaml\"}\n# File exts for python\n_PY_EXTS = {\".py\"}\n\n# CfgNodes can only contain a limited set of valid types\n_VALID_TYPES = {tuple, list, str, int, float, bool}\n\n# Valid file object types\n_FILE_TYPES = (io.IOBase,)\n\n# Logger\nlogger = logging.getLogger(__name__)\n\n\nclass CfgNode(dict):\n    r\"\"\"CfgNode is a `node` in the configuration `tree`. It's a simple wrapper around a `dict` and supports access to\n    `attributes` via `keys`.\n    \"\"\"\n\n    IMMUTABLE = \"__immutable__\"\n    DEPRECATED_KEYS = \"__deprecated_keys__\"\n    RENAMED_KEYS = \"__renamed_keys__\"\n    NEW_ALLOWED = \"__new_allowed__\"\n\n    def __init__(\n        self,\n        init_dict: Optional[dict] = None,\n        key_list: Optional[list] = None,\n        new_allowed: Optional[bool] = False,\n    ):\n        r\"\"\"\n        Args:\n            init_dict (dict): A dictionary to initialize the `CfgNode`.\n            key_list (list[str]): A list of names that index this `CfgNode` from the root. Currently, only used for\n                logging.\n            new_allowed (bool): Whether adding a new key is allowed when merging with other `CfgNode` objects.\n        \"\"\"\n\n        # Recursively convert nested dictionaries in `init_dict` to config tree.\n        init_dict = {} if init_dict is None else init_dict\n        key_list = [] if key_list is None else key_list\n        init_dict = self._create_config_tree_from_dict(init_dict, key_list)\n        super(CfgNode, self).__init__(init_dict)\n\n        # Control the immutability of the `CfgNode`.\n        self.__dict__[CfgNode.IMMUTABLE] = False\n        # Support for deprecated options.\n        # If you choose to remove support for an option in code, but don't want to change all of the config files\n        # (to allow for deprecated config files to run), you can add the full config key as a string to this set.\n        self.__dict__[CfgNode.DEPRECATED_KEYS] = set()\n        # Support for renamed options.\n        # If you rename an option, record the mapping from the old name to the new name in this dictionary. Optionally,\n        # if the type also changed, you can make this value a tuple that specifies two things: the renamed key, and the\n        # instructions to edit the config file.\n        self.__dict__[CfgNode.RENAMED_KEYS] = {\n            # 'EXAMPLE.OLD.KEY': 'EXAMPLE.NEW.KEY',  # Dummy example\n            # 'EXAMPLE.OLD.KEY': (                   # A more complex example\n            #     'EXAMPLE.NEW.KEY',\n            #     \"Also convert to a tuple, eg. 'foo' -> ('foo', ) or \"\n            #     + \"'foo.bar' -> ('foo', 'bar')\"\n            # ),\n        }\n\n        # Allow new attributes after initialization.\n        self.__dict__[CfgNode.NEW_ALLOWED] = new_allowed\n\n    @classmethod\n    def _create_config_tree_from_dict(cls, init_dict: dict, key_list: list):\n        r\"\"\"Create a configuration tree using the input dict. 
Any dict-like objects inside `init_dict` will be treated\n        as new `CfgNode` objects.\n        Args:\n            init_dict (dict): Input dictionary, to create config tree from.\n            key_list (list): A list of names that index this `CfgNode` from the root. Currently only used for logging.\n        \"\"\"\n\n        d = copy.deepcopy(init_dict)\n        for k, v in d.items():\n            if isinstance(v, dict):\n                # Convert dictionary to CfgNode\n                d[k] = cls(v, key_list=key_list + [k])\n            else:\n                # Check for valid leaf type or nested CfgNode\n                _assert_with_logging(\n                    _valid_type(v, allow_cfg_node=False),\n                    \"Key {} with value {} is not a valid type; valid types: {}\".format(\n                        \".\".join(key_list + [k]), type(v), _VALID_TYPES\n                    ),\n                )\n        return d\n\n    def __getattr__(self, name: str):\n        if name in self:\n            return self[name]\n        else:\n            raise AttributeError(name)\n\n    def __setattr__(self, name: str, value):\n        if self.is_frozen():\n            raise AttributeError(\n                \"Attempted to set {} to {}, but CfgNode is immutable\".format(\n                    name, value\n                )\n            )\n\n        _assert_with_logging(\n            name not in self.__dict__,\n            \"Invalid attempt to modify internal CfgNode state: {}\".format(name),\n        )\n\n        _assert_with_logging(\n            _valid_type(value, allow_cfg_node=True),\n            \"Invalid type {} for key {}; valid types = {}\".format(\n                type(value), name, _VALID_TYPES\n            ),\n        )\n\n        self[name] = value\n\n    def __str__(self):\n        def _indent(s_, num_spaces):\n            s = s_.split(\"\\n\")\n            if len(s) == 1:\n                return s_\n            first = s.pop(0)\n            s = [(num_spaces * \" \") + line for line in s]\n            s = \"\\n\".join(s)\n            s = first + \"\\n\" + s\n            return s\n\n        r = \"\"\n        s = []\n        for k, v in sorted(self.items()):\n            separator = \"\\n\" if isinstance(v, CfgNode) else \" \"\n            attr_str = \"{}:{}{}\".format(str(k), separator, str(v))\n            attr_str = _indent(attr_str, 2)\n            s.append(attr_str)\n        r += \"\\n\".join(s)\n        return r\n\n    def __repr__(self):\n        return \"{}({})\".format(self.__class__.__name__, super(CfgNode, self).__repr__())\n\n    def dump(self, **kwargs):\n        r\"\"\"Dump CfgNode to a string.\n        \"\"\"\n\n        def _convert_to_dict(cfg_node, key_list):\n            if not isinstance(cfg_node, CfgNode):\n                _assert_with_logging(\n                    _valid_type(cfg_node),\n                    \"Key {} with value {} is not a valid type; valid types: {}\".format(\n                        \".\".join(key_list), type(cfg_node), _VALID_TYPES\n                    ),\n                )\n                return cfg_node\n            else:\n                cfg_dict = dict(cfg_node)\n                for k, v in cfg_dict.items():\n                    cfg_dict[k] = _convert_to_dict(v, key_list + [k])\n                return cfg_dict\n\n        self_as_dict = _convert_to_dict(self, [])\n        return yaml.safe_dump(self_as_dict, **kwargs)\n\n    def merge_from_file(self, cfg_filename: str):\n        r\"\"\"Load a yaml config file and merge it with this 
CfgNode.\n        Args:\n            cfg_filename (str): Config file path.\n        \"\"\"\n        with open(cfg_filename, \"r\") as f:\n            cfg = self.load_cfg(f)\n        self.merge_from_other_cfg(cfg)\n\n    def merge_from_other_cfg(self, cfg_other):\n        r\"\"\"Merge `cfg_other` into the current `CfgNode`.\n        Args:\n            cfg_other (CfgNode): The `CfgNode` to merge into this one.\n        \"\"\"\n        _merge_a_into_b(cfg_other, self, self, [])\n\n    def merge_from_list(self, cfg_list: list):\n        r\"\"\"Merge config (keys, values) in a list (eg. from commandline) into this `CfgNode`.\n        Eg. `cfg_list = ['FOO.BAR', 0.5]`.\n        \"\"\"\n        _assert_with_logging(\n            len(cfg_list) % 2 == 0,\n            \"Override list has odd length: {}; it must be a list of pairs\".format(\n                cfg_list\n            ),\n        )\n        root = self\n        for full_key, v in zip(cfg_list[0::2], cfg_list[1::2]):\n            if root.key_is_deprecated(full_key):\n                continue\n            if root.key_is_renamed(full_key):\n                root.raise_key_rename_error(full_key)\n            key_list = full_key.split(\".\")\n            d = self\n            for subkey in key_list[:-1]:\n                _assert_with_logging(\n                    subkey in d, \"Non-existent key: {}\".format(full_key)\n                )\n                d = d[subkey]\n            subkey = key_list[-1]\n            _assert_with_logging(subkey in d, \"Non-existent key: {}\".format(full_key))\n            value = self._decode_cfg_value(v)\n            value = _check_and_coerce_cfg_value_type(value, d[subkey], subkey, full_key)\n            d[subkey] = value\n\n    def freeze(self):\n        r\"\"\"Make this `CfgNode` and all of its children immutable. \"\"\"\n        self._immutable(True)\n\n    def defrost(self):\n        r\"\"\"Make this `CfgNode` and all of its children mutable. \"\"\"\n        self._immutable(False)\n\n    def is_frozen(self):\n        r\"\"\"Return mutability. \"\"\"\n        return self.__dict__[CfgNode.IMMUTABLE]\n\n    def _immutable(self, is_immutable: bool):\n        r\"\"\"Set mutability and recursively apply to all nested `CfgNode` objects.\n        Args:\n            is_immutable (bool): Whether or not the `CfgNode` and its children are immutable.\n        \"\"\"\n        self.__dict__[CfgNode.IMMUTABLE] = is_immutable\n        # Recursively propagate state to all children.\n        for v in self.__dict__.values():\n            if isinstance(v, CfgNode):\n                v._immutable(is_immutable)\n        for v in self.values():\n            if isinstance(v, CfgNode):\n                v._immutable(is_immutable)\n\n    def clone(self):\n        r\"\"\"Recursively copy this `CfgNode`. \"\"\"\n        return copy.deepcopy(self)\n\n    def register_deprecated_key(self, key: str):\n        r\"\"\"Register key (eg. `FOO.BAR`) as a deprecated option. When merging deprecated keys, a warning is generated and\n        the key is ignored.\n        \"\"\"\n\n        _assert_with_logging(\n            key not in self.__dict__[CfgNode.DEPRECATED_KEYS],\n            \"key {} is already registered as a deprecated key\".format(key),\n        )\n        self.__dict__[CfgNode.DEPRECATED_KEYS].add(key)\n\n    def register_renamed_key(\n        self, old_name: str, new_name: str, message: Optional[str] = None\n    ):\n        r\"\"\"Register a key as having been renamed from `old_name` to `new_name`. 
When merging a renamed key, an\n        exception is thrown alerting the user to the fact that the key has been renamed.\n        \"\"\"\n\n        _assert_with_logging(\n            old_name not in self.__dict__[CfgNode.RENAMED_KEYS],\n            \"key {} is already registered as a renamed cfg key\".format(old_name),\n        )\n        value = new_name\n        if message:\n            value = (new_name, message)\n        self.__dict__[CfgNode.RENAMED_KEYS][old_name] = value\n\n    def key_is_deprecated(self, full_key: str):\n        r\"\"\"Test if a key is deprecated. \"\"\"\n        if full_key in self.__dict__[CfgNode.DEPRECATED_KEYS]:\n            logger.warning(\"deprecated config key (ignoring): {}\".format(full_key))\n            return True\n        return False\n\n    def key_is_renamed(self, full_key: str):\n        r\"\"\"Test if a key is renamed. \"\"\"\n        return full_key in self.__dict__[CfgNode.RENAMED_KEYS]\n\n    def raise_key_rename_error(self, full_key: str):\n        new_key = self.__dict__[CfgNode.RENAMED_KEYS][full_key]\n        if isinstance(new_key, tuple):\n            msg = \" Note: \" + new_key[1]\n            new_key = new_key[0]\n        else:\n            msg = \"\"\n        raise KeyError(\n            \"Key {} was renamed to {}; please update your config.{}\".format(\n                full_key, new_key, msg\n            )\n        )\n\n    def is_new_allowed(self):\n        return self.__dict__[CfgNode.NEW_ALLOWED]\n\n    @classmethod\n    def load_cfg(cls, cfg_file_obj_or_str):\n        r\"\"\"Load a configuration into the `CfgNode`.\n        Args:\n            cfg_file_obj_or_str (str or cfg compatible object): Supports loading from:\n                - A file object backed by a YAML file.\n                - A file object backed by a Python source file that exports an attribute \"cfg\" (dict or `CfgNode`).\n                - A string that can be parsed as valid YAML.\n        \"\"\"\n        _assert_with_logging(\n            isinstance(cfg_file_obj_or_str, _FILE_TYPES + (str,)),\n            \"Expected first argument to be of type {} or {}, but got {}\".format(\n                _FILE_TYPES, str, type(cfg_file_obj_or_str)\n            ),\n        )\n        if isinstance(cfg_file_obj_or_str, str):\n            return cls._load_cfg_from_yaml_str(cfg_file_obj_or_str)\n        elif isinstance(cfg_file_obj_or_str, _FILE_TYPES):\n            return cls._load_cfg_from_file(cfg_file_obj_or_str)\n        else:\n            raise NotImplementedError(\"Impossible to reach here (unless there's a bug)\")\n\n    @classmethod\n    def _load_cfg_from_file(cls, file_obj):\n        r\"\"\"Load a config from a YAML file or a Python source file. \"\"\"\n        _, file_ext = os.path.splitext(file_obj.name)\n        if file_ext in _YAML_EXTS:\n            return cls._load_cfg_from_yaml_str(file_obj.read())\n        elif file_ext in _PY_EXTS:\n            return cls._load_cfg_py_source(file_obj.name)\n        else:\n            raise Exception(\n                \"Attempt to load from an unsupported filetype {}; only {} supported\".format(\n                    file_ext, _YAML_EXTS.union(_PY_EXTS)\n                )\n            )\n\n    @classmethod\n    def _load_cfg_from_yaml_str(cls, str_obj):\n        r\"\"\"Load a config from a YAML string encoding. \"\"\"\n        cfg_as_dict = yaml.safe_load(str_obj)\n        return cls(cfg_as_dict)\n\n    @classmethod\n    def _load_cfg_py_source(cls, filename):\n        r\"\"\"Load a config from a Python source file. 
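The file must define a module-level 'cfg' attribute (a dict or a CfgNode), as enforced below. 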
\"\"\"\n        module = _load_module_from_file(\"yacs.config.override\", filename)\n        _assert_with_logging(\n            hasattr(module, \"cfg\"),\n            \"Python module from file {} must export a 'cfg' attribute\".format(filename),\n        )\n        VALID_ATTR_TYPES = {dict, CfgNode}\n        _assert_with_logging(\n            type(module.cfg) in VALID_ATTR_TYPES,\n            \"Imported module's 'cfg' attribute must be one of {} but is {}\".format(\n                VALID_ATTR_TYPES, type(module.cfg)\n            ),\n        )\n        return cls(module.cfg)\n\n    @classmethod\n    def _decode_cfg_value(cls, value):\n        r\"\"\"Decodes a raw config value (eg. from a yaml config file or commandline argument) into a Python object.\n        If `value` is a dict, it will be interpreted as a new `CfgNode`.\n        If `value` is a str, it will be evaluated as a literal.\n        Otherwise, it is returned as is.\n        \"\"\"\n        # Configs parsed from raw yaml will contain dictionary keys that need to be converted to `CfgNode` objects.\n        if isinstance(value, dict):\n            return cls(value)\n        # All remaining processing is only applied to strings.\n        if not isinstance(value, str):\n            return value\n        # Try to interpret `value` as a: string, number, tuple, list, dict, bool, or None\n        try:\n            value = literal_eval(value)\n        # The following two excepts allow `value` to pass through when it represents a string.\n        # The type of `value` is always a string (before calling `literal_eval`), but sometimes it *represents* a\n        # string and other times a data structure, like a list. In the case that `value` represents a str, what we\n        # got back from the yaml parser is `foo` *without quotes* (so, not `\"foo\"`). `literal_eval` is ok with `\"foo\"`,\n        # but will raise a `ValueError` if given `foo`. 
In other cases, like paths (`val = 'foo/bar'`) `literal_eval`\n        # will raise a `SyntaxError`.\n        except ValueError:\n            pass\n        except SyntaxError:\n            pass\n        return value\n\n\n# Keep this function in global scope, for backward compatibility.\nload_cfg = CfgNode.load_cfg\n\n\ndef _valid_type(value, allow_cfg_node: Optional[bool] = False):\n    return (type(value) in _VALID_TYPES) or (\n        allow_cfg_node and isinstance(value, CfgNode)\n    )\n\n\ndef _merge_a_into_b(a: CfgNode, b: CfgNode, root: CfgNode, key_list: list):\n    r\"\"\"Merge `CfgNode` `a` into `CfgNode` `b`, clobbering the options in `b` wherever they are also specified in `a`.\n    \"\"\"\n    _assert_with_logging(\n        isinstance(a, CfgNode),\n        \"`a` (cur type {}) must be an instance of {}\".format(type(a), CfgNode),\n    )\n    _assert_with_logging(\n        isinstance(b, CfgNode),\n        \"`b` (cur type {}) must be an instance of {}\".format(type(b), CfgNode),\n    )\n\n    for k, v_ in a.items():\n        full_key = \".\".join(key_list + [k])\n        v = copy.deepcopy(v_)\n        v = b._decode_cfg_value(v)\n\n        if k in b:\n            v = _check_and_coerce_cfg_value_type(v, b[k], k, full_key)\n            # Recursively merge dicts.\n            if isinstance(v, CfgNode):\n                _merge_a_into_b(v, b[k], root, key_list + [k])\n            else:\n                b[k] = v\n        elif b.is_new_allowed():\n            b[k] = v\n        else:\n            if root.key_is_deprecated(full_key):\n                continue\n            elif root.key_is_renamed(full_key):\n                root.raise_key_rename_error(full_key)\n            else:\n                raise KeyError(\"Non-existent config key: {}\".format(full_key))\n\n\ndef _check_and_coerce_cfg_value_type(replacement, original, key, full_key):\n    r\"\"\"Checks that `replacement`, which is intended to replace `original` is of the right type. The type is correct if\n    it matches exactly or is one of a few cases in which the type can easily be coerced.\n    \"\"\"\n\n    original_type = type(original)\n    replacement_type = type(replacement)\n    if replacement_type == original_type:\n        return replacement\n\n    # Cast `replacement` from `from_type` to `to_type` when `original` is of type `to_type`.\n    def _conditional_cast(from_type, to_type):\n        if replacement_type == from_type and original_type == to_type:\n            return True, to_type(replacement)\n        else:\n            return False, None\n\n    # Conditional casts.\n    # list <-> tuple\n    casts = [(tuple, list), (list, tuple)]\n    for (from_type, to_type) in casts:\n        converted, converted_value = _conditional_cast(from_type, to_type)\n        if converted:\n            return converted_value\n\n    raise ValueError(\n        \"Type mismatch ({} vs. {}) with values ({} vs. {}) for config key: {}\".format(\n            original_type, replacement_type, original, replacement, full_key\n        )\n    )\n\n\ndef _assert_with_logging(cond, msg):\n    if not cond:\n        logger.debug(msg)\n    assert cond, msg\n\n\ndef _load_module_from_file(name, filename):\n    spec = importlib.util.spec_from_file_location(name, filename)\n    module = importlib.util.module_from_spec(spec)\n    spec.loader.exec_module(module)\n    return module
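\n\n\nif __name__ == \"__main__\":\n    # Usage sketch (illustrative only): build a config tree from a plain dict,\n    # override a value with a CLI-style key/value list, then freeze the result.\n    cfg = CfgNode({\"train\": {\"lr\": 0.001, \"epochs\": 10}})\n    cfg.merge_from_list([\"train.lr\", \"0.01\"])  # string values are literal_eval'd\n    cfg.freeze()\n    print(cfg.dump())  # YAML dump of the merged config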
  },
  {
    "path": "utils/mesh_util.py",
    "content": "# adapted from https://github.com/zju3dv/manhattan_sdf\nimport numpy as np\nimport open3d as o3d\nfrom sklearn.neighbors import KDTree\nimport trimesh\nimport os\nos.environ['PYOPENGL_PLATFORM'] = 'egl'\nimport pyrender\nfrom tqdm.contrib import tenumerate, tzip\n\n\ndef nn_correspondance(verts1, verts2):\n    \"\"\"Return, for every vertex in verts2, the distance to its nearest neighbour in verts1.\"\"\"\n    if len(verts1) == 0 or len(verts2) == 0:\n        return np.array([])\n\n    kdtree = KDTree(verts1)\n    distances, _ = kdtree.query(verts2)\n    return distances.reshape(-1)\n\n\ndef evaluate(mesh_pred, mesh_trgt, threshold=.05, down_sample=.02):\n    pcd_trgt = o3d.geometry.PointCloud()\n    pcd_pred = o3d.geometry.PointCloud()\n    \n    pcd_trgt.points = o3d.utility.Vector3dVector(mesh_trgt.vertices[:, :3])\n    pcd_pred.points = o3d.utility.Vector3dVector(mesh_pred.vertices[:, :3])\n\n    if down_sample:\n        pcd_pred = pcd_pred.voxel_down_sample(down_sample)\n        pcd_trgt = pcd_trgt.voxel_down_sample(down_sample)\n\n    verts_pred = np.asarray(pcd_pred.points)\n    verts_trgt = np.asarray(pcd_trgt.points)\n\n    dist1 = nn_correspondance(verts_pred, verts_trgt)  # GT -> pred, drives completeness\n    dist2 = nn_correspondance(verts_trgt, verts_pred)  # pred -> GT, drives accuracy\n\n    precision = np.mean((dist2 < threshold).astype('float'))\n    recall = np.mean((dist1 < threshold).astype('float'))\n    fscore = 2 * precision * recall / max(precision + recall, 1e-8)  # guard against 0/0\n    metrics = {\n        'Acc': np.mean(dist2),\n        'Comp': np.mean(dist1),\n        'Prec': precision,\n        'Recal': recall,\n        'F-score': fscore,\n    }\n    return metrics\n\n\nclass Renderer():\n    def __init__(self, height=480, width=640):\n        self.renderer = pyrender.OffscreenRenderer(width, height)\n        self.scene = pyrender.Scene()\n        # self.render_flags = pyrender.RenderFlags.SKIP_CULL_FACES\n\n    def __call__(self, height, width, intrinsics, pose, mesh):\n        self.renderer.viewport_height = height\n        self.renderer.viewport_width = width\n        self.scene.clear()\n        self.scene.add(mesh)\n        cam = pyrender.IntrinsicsCamera(cx=intrinsics[0, 2], cy=intrinsics[1, 2],\n                                        fx=intrinsics[0, 0], fy=intrinsics[1, 1])\n        self.scene.add(cam, pose=self.fix_pose(pose))\n        return self.renderer.render(self.scene)  # , self.render_flags)\n\n    def fix_pose(self, pose):\n        # Rotate the camera frame by pi about the x-axis (OpenCV -> OpenGL convention).\n        t = np.pi\n        c = np.cos(t)\n        s = np.sin(t)\n        R = np.array([[1, 0, 0],\n                      [0, c, -s],\n                      [0, s, c]])\n        axis_transform = np.eye(4)\n        axis_transform[:3, :3] = R\n        return pose @ axis_transform\n\n    def mesh_opengl(self, mesh):\n        return pyrender.Mesh.from_trimesh(mesh)\n\n    def delete(self):\n        self.renderer.delete()\n        \n\ndef refuse(mesh, poses, K, H, W, far_clip=5.0):\n    \"\"\"Re-fuse a mesh into a TSDF volume by rendering its depth from the given views.\"\"\"\n    renderer = Renderer()\n    mesh_opengl = renderer.mesh_opengl(mesh)\n    volume = o3d.pipelines.integration.ScalableTSDFVolume(\n        voxel_length=0.01,\n        sdf_trunc=3 * 0.01,\n        color_type=o3d.pipelines.integration.TSDFVolumeColorType.RGB8\n    )\n    \n    for i, pose in tenumerate(poses):\n        intrinsic = K\n        \n        rgb = np.ones((H, W, 3))  # dummy white image; only depth matters here\n        rgb = (rgb * 255).astype(np.uint8)\n        rgb = o3d.geometry.Image(rgb)\n        _, depth_pred = renderer(H, W, intrinsic, pose, mesh_opengl)\n        depth_pred = o3d.geometry.Image(depth_pred)\n        rgbd = 
o3d.geometry.RGBDImage.create_from_color_and_depth(\n            rgb, depth_pred, depth_scale=1.0, depth_trunc=far_clip, convert_rgb_to_intensity=False\n        )\n        fx, fy, cx, cy = intrinsic[0, 0], intrinsic[1, 1], intrinsic[0, 2], intrinsic[1, 2]\n        intrinsic = o3d.camera.PinholeCameraIntrinsic(width=W, height=H, fx=fx, fy=fy, cx=cx, cy=cy)\n        extrinsic = np.linalg.inv(pose)  # Open3D expects world-to-camera extrinsics\n        volume.integrate(rgbd, intrinsic, extrinsic)\n    \n    return volume.extract_triangle_mesh()\n\n\ndef depth2mesh(depths, poses, K, H, W):\n    \"\"\"Fuse per-view depth maps into a TSDF volume and extract a triangle mesh.\"\"\"\n    volume = o3d.pipelines.integration.ScalableTSDFVolume(\n        voxel_length=0.01,\n        sdf_trunc=3 * 0.01,\n        color_type=o3d.pipelines.integration.TSDFVolumeColorType.RGB8\n    )\n    for depth, pose in tzip(depths, poses):\n        rgb = np.ones((H, W, 3))  # dummy white image; only depth matters here\n        rgb = (rgb * 255).astype(np.uint8)\n        rgb = o3d.geometry.Image(rgb)\n        depth = o3d.geometry.Image(depth)\n        rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth(\n            rgb, depth, depth_scale=1.0, depth_trunc=5.0, convert_rgb_to_intensity=False\n        )\n        fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2]\n        intrinsic = o3d.camera.PinholeCameraIntrinsic(width=W, height=H, fx=fx, fy=fy, cx=cx, cy=cy)\n        extrinsic = np.linalg.inv(pose)  # Open3D expects world-to-camera extrinsics\n        volume.integrate(rgbd, intrinsic, extrinsic)\n    return volume.extract_triangle_mesh()
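\n\n\nif __name__ == \"__main__\":\n    # Illustrative sketch; the .ply paths and camera arrays below are placeholders.\n    # Typical evaluation flow: render-and-refuse a predicted mesh through the\n    # training views, then score it against the ground-truth mesh.\n    mesh_pred = trimesh.load(\"pred.ply\")  # hypothetical predicted mesh\n    mesh_gt = trimesh.load(\"gt.ply\")      # hypothetical ground-truth mesh\n    # poses: (N, 4, 4) camera-to-world, K: (3, 3) intrinsics, H/W: image size\n    # mesh_fused = refuse(mesh_pred, poses, K, H, W)\n    print(evaluate(mesh_pred, mesh_gt))"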
  },
  {
    "path": "utils/plots.py",
    "content": "import plotly.graph_objs as go\nimport plotly.offline as offline\nimport numpy as np\nimport torch\nfrom skimage import measure\nimport torchvision.utils as vutils\nimport trimesh\nfrom PIL import Image\nimport cv2\nimport mcubes\nfrom utils import rend_util\n\n\ndef plot(implicit_network, indices, plot_data, path, epoch, img_res, plot_nimgs, resolution=None, grid_boundary=None, meshing=True, level=0):\n\n    if plot_data is not None:\n        if 'rgb_eval' in plot_data:\n            plot_images(plot_data['rgb_eval'], plot_data['rgb_gt'], path, epoch, plot_nimgs, img_res)\n        if 'hdr_eval' in plot_data:\n            plot_images(plot_data['hdr_eval'], plot_data['hdr_gt'], path, epoch, plot_nimgs, img_res, 'hdr', True)\n        if 'rgb_surface' in plot_data:\n            plot_images(plot_data['rgb_surface'], plot_data['rgb_gt'], path, '{0}s'.format(epoch), plot_nimgs, img_res)\n        if 'rendered' in plot_data:\n            plot_images(plot_data['rendered'], plot_data['rgb_gt'], path, epoch, plot_nimgs, img_res, 'rendered')\n\n        if 'normal_map' in plot_data:\n            plot_imgs_wo_gt(plot_data['normal_map'], path, epoch, plot_nimgs, img_res)\n        if 'depth_eval' in plot_data:\n            plot_depths(plot_data['depth_eval'], path, epoch, plot_nimgs, img_res)\n        if 'lmask_eval' in plot_data:\n            plot_images(plot_data['lmask_eval'], plot_data['lmask_gt'], path, epoch, plot_nimgs, img_res, 'light_mask')\n\n        if 'Kd' in plot_data:\n            # plot_imgs_wo_gt(plot_data['albedo'], path, epoch, plot_nimgs, img_res, 'albedo')\n            plot_images(plot_data['Ks'], plot_data['Kd'], path, epoch, plot_nimgs, img_res, 'albedo')\n        if 'roughness' in plot_data:\n            plot_colormap(plot_data['roughness'], path, epoch, plot_nimgs, img_res)\n        if 'metallic' in plot_data:\n            plot_colormap(plot_data['metallic'], path, epoch, plot_nimgs, img_res, 'metallic')\n        if 'emission' in plot_data:\n            plot_imgs_wo_gt(plot_data['emission'], path, '{0}'.format(epoch), plot_nimgs, img_res, 'emission', True)\n\n    if meshing:\n        path = f\"{path}/mesh\"\n        data = []\n\n        # plot surface\n        surface_traces = get_surface_trace(path=path,\n                                        epoch=epoch,\n                                        sdf=lambda x: implicit_network(x)[:, 0],\n                                        resolution=resolution,\n                                        grid_boundary=grid_boundary,\n                                        level=level\n                                        )\n\n        if surface_traces is not None:\n            data.append(surface_traces[0])\n\n        # plot cameras locations (only when plot_data carries the poses)\n        if plot_data is not None:\n            cam_loc, cam_dir = rend_util.get_camera_for_plot(plot_data['pose'])\n            for i, loc, dir in zip(indices, cam_loc, cam_dir):\n                data.append(get_3D_quiver_trace(loc.unsqueeze(0), dir.unsqueeze(0), name='camera_{0}'.format(i)))\n\n        fig = go.Figure(data=data)\n        scene_dict = dict(xaxis=dict(range=[-6, 6], autorange=False),\n                        yaxis=dict(range=[-6, 6], autorange=False),\n                        zaxis=dict(range=[-6, 6], autorange=False),\n                        aspectratio=dict(x=1, y=1, z=1))\n        fig.update_layout(scene=scene_dict, width=1200, height=1200, showlegend=True)\n        filename = '{0}/surface_{1}.html'.format(path, epoch)\n        # write the interactive figure to disk without opening a browser\n        
offline.plot(fig, filename=filename, auto_open=False)\n\n\ndef visualize_pointcloud(points, filename):\n    fig = go.Figure(\n        data = [\n            get_3D_scatter_trace(points, 'Pointcloud', 1)\n        ]\n    )\n    scene_dict = dict(xaxis=dict(range=[-3, 3], autorange=False),\n                      yaxis=dict(range=[-3, 3], autorange=False),\n                      zaxis=dict(range=[-3, 3], autorange=False),\n                      aspectratio=dict(x=1, y=1, z=1))\n    fig.update_layout(scene=scene_dict, width=1200, height=1200, showlegend=True)\n    offline.plot(fig, filename=filename, auto_open=False)\n\n\ndef visualize_clustered_pointcloud(points, labels, centroids, filename):\n    fig = go.Figure()\n    if centroids is not None:\n        fig.add_trace(get_3D_scatter_trace(centroids, \"Centroids\", 10))\n    for c in torch.unique(labels):\n        cluster = points[labels == c, :]\n        fig.add_trace(get_3D_scatter_trace(cluster, f\"Emitter #{int(c)}\"))\n    scene_dict = dict(xaxis=dict(range=[-3, 3], autorange=False),\n                      yaxis=dict(range=[-3, 3], autorange=False),\n                      zaxis=dict(range=[-3, 3], autorange=False),\n                      aspectratio=dict(x=1, y=1, z=1))\n    fig.update_layout(scene=scene_dict, width=1200, height=1200, showlegend=True)\n    offline.plot(fig, filename=filename, auto_open=False)\n\n\ndef visualize_marked_pointcloud(points, counts, path, epoch):\n    fig = go.Figure(\n        data = [\n            get_3D_marked_scatter_trace(points, counts, 'Pointcloud samples')\n        ]\n    )\n    scene_dict = dict(xaxis=dict(range=[-3, 3], autorange=False),\n                      yaxis=dict(range=[-3, 3], autorange=False),\n                      zaxis=dict(range=[-3, 3], autorange=False),\n                      aspectratio=dict(x=1, y=1, z=1))\n    fig.update_layout(scene=scene_dict, width=1200, height=1200, showlegend=True)\n    filename = '{0}/pointcloud/{1}.html'.format(path, epoch)\n    offline.plot(fig, filename=filename, auto_open=False)\n\n\ndef get_3D_scatter_trace(points, name='', size=3, caption=None):\n    assert points.shape[1] == 3, \"3d scatter plot input points are not correctly shaped\"\n    assert len(points.shape) == 2, \"3d scatter plot input points are not correctly shaped\"\n\n    trace = go.Scatter3d(\n        x=points[:, 0].cpu(),\n        y=points[:, 1].cpu(),\n        z=points[:, 2].cpu(),\n        mode='markers',\n        name=name,\n        marker=dict(\n            size=size,\n            line=dict(\n                width=2,\n            ),\n            opacity=1.0,\n        ), text=caption)\n\n    return trace\n\n\ndef get_3D_marked_scatter_trace(points, marks, name='', size=1, caption=None):\n    assert points.shape[1] == 3, \"3d scatter plot input points are not correctly shaped\"\n    assert len(points.shape) == 2, \"3d scatter plot input points are not correctly shaped\"\n\n    trace = go.Scatter3d(\n        x=points[:, 0].cpu(),\n        y=points[:, 1].cpu(),\n        z=points[:, 2].cpu(),\n        mode='markers',\n        name=name,\n        marker=dict(\n            size=size,\n            line=dict(\n                width=2,\n            ),\n            color=marks.squeeze().cpu(),\n            colorscale='Viridis',\n            opacity=1.0,\n        ), text=caption)\n\n    return trace\n\n\ndef get_3D_quiver_trace(points, directions, color='#bd1540', name=''):\n    assert points.shape[1] == 3, \"3d cone plot input points are not correctly shaped\"\n    assert 
len(points.shape) == 2, \"3d cone plot input points are not correctly shaped\"\n    assert directions.shape[1] == 3, \"3d cone plot input directions are not correctly shaped\"\n    assert len(directions.shape) == 2, \"3d cone plot input directions are not correctly shaped\"\n\n    trace = go.Cone(\n        name=name,\n        x=points[:, 0].cpu(),\n        y=points[:, 1].cpu(),\n        z=points[:, 2].cpu(),\n        u=directions[:, 0].cpu(),\n        v=directions[:, 1].cpu(),\n        w=directions[:, 2].cpu(),\n        sizemode='absolute',\n        sizeref=0.125,\n        showscale=False,\n        colorscale=[[0, color], [1, color]],\n        anchor=\"tail\"\n    )\n\n    return trace\n\n\ndef get_surface_trace(path, epoch, sdf, resolution=100, grid_boundary=[-2.0, 2.0], return_mesh=False, level=0):\n    grid = get_grid_uniform(resolution, grid_boundary)\n    points = grid['grid_points']\n\n    z = []\n    for i, pnts in enumerate(torch.split(points, 100000, dim=0)):\n        z.append(sdf(pnts).detach().cpu().numpy())\n    z = np.concatenate(z, axis=0)\n\n    if np.min(z) <= level <= np.max(z):\n\n        z = z.astype(np.float32)\n\n        verts, faces, normals, values = measure.marching_cubes(\n            volume=z.reshape(grid['xyz'][1].shape[0], grid['xyz'][0].shape[0],\n                             grid['xyz'][2].shape[0]).transpose([1, 0, 2]),\n            level=level,\n            spacing=(grid['xyz'][0][2] - grid['xyz'][0][1],\n                     grid['xyz'][0][2] - grid['xyz'][0][1],\n                     grid['xyz'][0][2] - grid['xyz'][0][1]))\n\n        verts = verts + np.array([grid['xyz'][0][0], grid['xyz'][1][0], grid['xyz'][2][0]])\n\n        I, J, K = faces.transpose()\n\n        traces = [go.Mesh3d(x=verts[:, 0], y=verts[:, 1], z=verts[:, 2],\n                            i=I, j=J, k=K, name='implicit_surface',\n                            color='#ffffff', opacity=1.0, flatshading=False,\n                            lighting=dict(diffuse=1, ambient=0, specular=0),\n                            lightposition=dict(x=0, y=0, z=-1), showlegend=True)]\n\n        meshexport = trimesh.Trimesh(verts, faces, vertex_normals=normals)\n        meshexport.export('{0}/surface_{1}.ply'.format(path, epoch), 'ply')\n\n        if return_mesh:\n            return meshexport\n        return traces\n    return None\n\n\ndef extract_fields(bound_min, bound_max, resolution, query_func):\n    N = 64\n    X = torch.linspace(bound_min[0], bound_max[0], resolution).split(N)\n    Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(N)\n    Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(N)\n\n    u = np.zeros([resolution, resolution, resolution], dtype=np.float32)\n    with torch.no_grad():\n        for xi, xs in enumerate(X):\n            for yi, ys in enumerate(Y):\n                for zi, zs in enumerate(Z):\n                    xx, yy, zz = torch.meshgrid(xs, ys, zs)\n                    pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1)\n                    val = query_func(pts).reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy()\n                    u[xi * N: xi * N + len(xs), yi * N: yi * N + len(ys), zi * N: zi * N + len(zs)] = val\n    return u\n\n\ndef extract_geometry(bound_min, bound_max, resolution, threshold, query_func):\n    print('threshold: {}'.format(threshold))\n    u = extract_fields(bound_min, bound_max, resolution, query_func)\n    vertices, triangles = mcubes.marching_cubes(u, threshold)\n    b_max_np = 
bound_max.detach().cpu().numpy()\n    b_min_np = bound_min.detach().cpu().numpy()\n\n    vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]\n    mesh = trimesh.Trimesh(vertices, triangles)\n    return mesh\n\n\ndef get_surface_high_res_mesh(sdf, resolution=100, grid_boundary=[-2.0, 2.0], level=0, take_components=True):\n    # get low res mesh to sample point cloud\n    grid = get_grid_uniform(100, grid_boundary)\n    z = []\n    points = grid['grid_points']\n\n    for i, pnts in enumerate(torch.split(points, 100000, dim=0)):\n        z.append(sdf(pnts).detach().cpu().numpy())\n    z = np.concatenate(z, axis=0)\n\n    z = z.astype(np.float32)\n\n    verts, faces, normals, values = measure.marching_cubes(\n        volume=z.reshape(grid['xyz'][1].shape[0], grid['xyz'][0].shape[0],\n                         grid['xyz'][2].shape[0]).transpose([1, 0, 2]),\n        level=level,\n        spacing=(grid['xyz'][0][2] - grid['xyz'][0][1],\n                 grid['xyz'][0][2] - grid['xyz'][0][1],\n                 grid['xyz'][0][2] - grid['xyz'][0][1]))\n\n    verts = verts + np.array([grid['xyz'][0][0], grid['xyz'][1][0], grid['xyz'][2][0]])\n\n    mesh_low_res = trimesh.Trimesh(verts, faces, vertex_normals=normals)\n    if take_components:\n        components = mesh_low_res.split(only_watertight=False)\n        areas = np.array([c.area for c in components], dtype=np.float64)\n        mesh_low_res = components[areas.argmax()]\n\n    recon_pc = trimesh.sample.sample_surface(mesh_low_res, 10000)[0]\n    recon_pc = torch.from_numpy(recon_pc).float().cuda()\n\n    # Center and align the recon pc\n    s_mean = recon_pc.mean(dim=0)\n    s_cov = recon_pc - s_mean\n    s_cov = torch.mm(s_cov.transpose(0, 1), s_cov)\n    vecs = torch.view_as_real(torch.linalg.eig(s_cov)[1].transpose(0, 1))[:, :, 0]\n    if torch.det(vecs) < 0:\n        vecs = torch.mm(torch.tensor([[1, 0, 0], [0, 0, 1], [0, 1, 0]]).cuda().float(), vecs)\n    helper = torch.bmm(vecs.unsqueeze(0).repeat(recon_pc.shape[0], 1, 1),\n                       (recon_pc - s_mean).unsqueeze(-1)).squeeze()\n\n    grid_aligned = get_grid(helper.cpu(), resolution)\n\n    grid_points = grid_aligned['grid_points']\n\n    g = []\n    for i, pnts in enumerate(torch.split(grid_points, 100000, dim=0)):\n        g.append(torch.bmm(vecs.unsqueeze(0).repeat(pnts.shape[0], 1, 1).transpose(1, 2),\n                           pnts.unsqueeze(-1)).squeeze() + s_mean)\n    grid_points = torch.cat(g, dim=0)\n\n    # MC to new grid\n    points = grid_points\n    z = []\n    for i, pnts in enumerate(torch.split(points, 100000, dim=0)):\n        z.append(sdf(pnts).detach().cpu().numpy())\n    z = np.concatenate(z, axis=0)\n\n    meshexport = None\n    if np.min(z) <= level <= np.max(z):\n\n        z = z.astype(np.float32)\n\n        verts, faces, normals, values = measure.marching_cubes(\n            volume=z.reshape(grid_aligned['xyz'][1].shape[0], grid_aligned['xyz'][0].shape[0],\n                             grid_aligned['xyz'][2].shape[0]).transpose([1, 0, 2]),\n            level=level,\n            spacing=(grid_aligned['xyz'][0][2] - grid_aligned['xyz'][0][1],\n                     grid_aligned['xyz'][0][2] - grid_aligned['xyz'][0][1],\n                     grid_aligned['xyz'][0][2] - grid_aligned['xyz'][0][1]))\n\n        verts = torch.from_numpy(verts).cuda().float()\n        verts = torch.bmm(vecs.unsqueeze(0).repeat(verts.shape[0], 1, 1).transpose(1, 2),\n                   verts.unsqueeze(-1)).squeeze()\n        verts = (verts 
+ grid_points[0]).cpu().numpy()\n\n        meshexport = trimesh.Trimesh(verts, faces, vertex_normals=normals)\n\n    return meshexport\n\n\ndef get_surface_by_grid(grid_params, sdf, resolution=100, level=0, higher_res=False):\n    grid_params = grid_params * [[1.5], [1.0]]\n\n    # params = PLOT_DICT[scan_id]\n    input_min = torch.tensor(grid_params[0]).float()\n    input_max = torch.tensor(grid_params[1]).float()\n\n    if higher_res:\n        # get low res mesh to sample point cloud\n        grid = get_grid(None, 100, input_min=input_min, input_max=input_max, eps=0.0)\n        z = []\n        points = grid['grid_points']\n\n        for i, pnts in enumerate(torch.split(points, 100000, dim=0)):\n            z.append(sdf(pnts).detach().cpu().numpy())\n        z = np.concatenate(z, axis=0)\n\n        z = z.astype(np.float32)\n\n        verts, faces, normals, values = measure.marching_cubes(\n            volume=z.reshape(grid['xyz'][1].shape[0], grid['xyz'][0].shape[0],\n                             grid['xyz'][2].shape[0]).transpose([1, 0, 2]),\n            level=level,\n            spacing=(grid['xyz'][0][2] - grid['xyz'][0][1],\n                     grid['xyz'][0][2] - grid['xyz'][0][1],\n                     grid['xyz'][0][2] - grid['xyz'][0][1]))\n\n        verts = verts + np.array([grid['xyz'][0][0], grid['xyz'][1][0], grid['xyz'][2][0]])\n\n        mesh_low_res = trimesh.Trimesh(verts, faces, vertex_normals=normals)\n        components = mesh_low_res.split(only_watertight=False)\n        areas = np.array([c.area for c in components], dtype=np.float64)\n        mesh_low_res = components[areas.argmax()]\n\n        recon_pc = trimesh.sample.sample_surface(mesh_low_res, 10000)[0]\n        recon_pc = torch.from_numpy(recon_pc).float().cuda()\n\n        # Center and align the recon pc\n        s_mean = recon_pc.mean(dim=0)\n        s_cov = recon_pc - s_mean\n        s_cov = torch.mm(s_cov.transpose(0, 1), s_cov)\n        vecs = torch.view_as_real(torch.linalg.eig(s_cov)[1].transpose(0, 1))[:, :, 0]\n        if torch.det(vecs) < 0:\n            vecs = torch.mm(torch.tensor([[1, 0, 0], [0, 0, 1], [0, 1, 0]]).cuda().float(), vecs)\n        helper = torch.bmm(vecs.unsqueeze(0).repeat(recon_pc.shape[0], 1, 1),\n                           (recon_pc - s_mean).unsqueeze(-1)).squeeze()\n\n        grid_aligned = get_grid(helper.cpu(), resolution, eps=0.01)\n    else:\n        grid_aligned = get_grid(None, resolution, input_min=input_min, input_max=input_max, eps=0.0)\n\n    grid_points = grid_aligned['grid_points']\n\n    if higher_res:\n        g = []\n        for i, pnts in enumerate(torch.split(grid_points, 100000, dim=0)):\n            g.append(torch.bmm(vecs.unsqueeze(0).repeat(pnts.shape[0], 1, 1).transpose(1, 2),\n                               pnts.unsqueeze(-1)).squeeze() + s_mean)\n        grid_points = torch.cat(g, dim=0)\n\n    # MC to new grid\n    points = grid_points\n    z = []\n    for i, pnts in enumerate(torch.split(points, 100000, dim=0)):\n        z.append(sdf(pnts).detach().cpu().numpy())\n    z = np.concatenate(z, axis=0)\n\n    meshexport = None\n    if np.min(z) <= level <= np.max(z):\n\n        z = z.astype(np.float32)\n\n        verts, faces, normals, values = measure.marching_cubes(\n            volume=z.reshape(grid_aligned['xyz'][1].shape[0], grid_aligned['xyz'][0].shape[0],\n                             grid_aligned['xyz'][2].shape[0]).transpose([1, 0, 2]),\n            level=level,\n            spacing=(grid_aligned['xyz'][0][2] - grid_aligned['xyz'][0][1],\n                     
grid_aligned['xyz'][0][2] - grid_aligned['xyz'][0][1],\n                     grid_aligned['xyz'][0][2] - grid_aligned['xyz'][0][1]))\n\n        if higher_res:\n            verts = torch.from_numpy(verts).cuda().float()\n            verts = torch.bmm(vecs.unsqueeze(0).repeat(verts.shape[0], 1, 1).transpose(1, 2),\n                       verts.unsqueeze(-1)).squeeze()\n            verts = (verts + grid_points[0]).cpu().numpy()\n        else:\n            verts = verts + np.array([grid_aligned['xyz'][0][0], grid_aligned['xyz'][1][0], grid_aligned['xyz'][2][0]])\n\n        meshexport = trimesh.Trimesh(verts, faces, vertex_normals=normals)\n\n        # CUTTING MESH ACCORDING TO THE BOUNDING BOX\n        if higher_res:\n            bb = grid_params\n            transformation = np.eye(4)\n            transformation[:3, 3] = (bb[1,:] + bb[0,:])/2.\n            bounding_box = trimesh.creation.box(extents=bb[1,:] - bb[0,:], transform=transformation)\n\n            meshexport = meshexport.slice_plane(bounding_box.facets_origin, -bounding_box.facets_normal)\n\n    return meshexport\n\ndef get_grid_uniform(resolution, grid_boundary=[-2.0, 2.0]):\n    x = np.linspace(grid_boundary[0], grid_boundary[1], resolution)\n    y = x\n    z = x\n\n    xx, yy, zz = np.meshgrid(x, y, z)\n    grid_points = torch.tensor(np.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T, dtype=torch.float)\n\n    return {\"grid_points\": grid_points.cuda(),\n            \"shortest_axis_length\": 2.0,\n            \"xyz\": [x, y, z],\n            \"shortest_axis_index\": 0}\n\ndef get_grid(points, resolution, input_min=None, input_max=None, eps=0.1):\n    if input_min is None or input_max is None:\n        input_min = torch.min(points, dim=0)[0].squeeze().numpy()\n        input_max = torch.max(points, dim=0)[0].squeeze().numpy()\n\n    bounding_box = input_max - input_min\n    shortest_axis = np.argmin(bounding_box)\n    if (shortest_axis == 0):\n        x = np.linspace(input_min[shortest_axis] - eps,\n                        input_max[shortest_axis] + eps, resolution)\n        length = np.max(x) - np.min(x)\n        y = np.arange(input_min[1] - eps, input_max[1] + length / (x.shape[0] - 1) + eps, length / (x.shape[0] - 1))\n        z = np.arange(input_min[2] - eps, input_max[2] + length / (x.shape[0] - 1) + eps, length / (x.shape[0] - 1))\n    elif (shortest_axis == 1):\n        y = np.linspace(input_min[shortest_axis] - eps,\n                        input_max[shortest_axis] + eps, resolution)\n        length = np.max(y) - np.min(y)\n        x = np.arange(input_min[0] - eps, input_max[0] + length / (y.shape[0] - 1) + eps, length / (y.shape[0] - 1))\n        z = np.arange(input_min[2] - eps, input_max[2] + length / (y.shape[0] - 1) + eps, length / (y.shape[0] - 1))\n    elif (shortest_axis == 2):\n        z = np.linspace(input_min[shortest_axis] - eps,\n                        input_max[shortest_axis] + eps, resolution)\n        length = np.max(z) - np.min(z)\n        x = np.arange(input_min[0] - eps, input_max[0] + length / (z.shape[0] - 1) + eps, length / (z.shape[0] - 1))\n        y = np.arange(input_min[1] - eps, input_max[1] + length / (z.shape[0] - 1) + eps, length / (z.shape[0] - 1))\n    xx, yy, zz = np.meshgrid(x, y, z)\n    # print(xx.shape, yy.shape, zz.shape)\n    # xx = torch.from_numpy(xx.flatten()).cuda().float()\n    # yy = torch.from_numpy(yy.flatten()).cuda().float()\n    # zz = torch.from_numpy(zz.flatten()).cuda().float()\n    # grid_points = torch.cat([xx, yy, zz], dim=1).T\n    grid_points 
= torch.tensor(np.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T, dtype=torch.float).cuda()\n    return {\"grid_points\": grid_points,\n            \"shortest_axis_length\": length,\n            \"xyz\": [x, y, z],\n            \"shortest_axis_index\": shortest_axis}\n\n\ndef plot_imgs_wo_gt(normal_maps, path, epoch, plot_nrow, img_res, path_name='normal', is_hdr=False):\n    normal_maps_plot = lin2img(normal_maps, img_res)\n    tensor = vutils.make_grid(normal_maps_plot,\n                                scale_each=False,\n                                normalize=False,\n                                nrow=plot_nrow).cpu().detach().numpy()\n    tensor = tensor.transpose(1, 2, 0)\n    if not is_hdr:\n        scale_factor = 255\n        tensor = (tensor * scale_factor).astype(np.uint8)\n        img = Image.fromarray(tensor)\n        img.save('{0}/{1}/{2}.png'.format(path, path_name, epoch))\n    else:\n        cv2.imwrite('{0}/{1}/{2}.exr'.format(path, path_name, epoch), tensor[:,:,::-1])\n\n\ndef plot_imgs_filter(rgb_points, ground_true, path, epoch, img_res, path_name='rendering'):\n    output = lin2img(rgb_points, img_res).squeeze(0) # (1, 3, h, w)\n    ground_true = lin2img(ground_true, img_res).squeeze(0)\n    output = output.permute(1, 2, 0).cpu().numpy()\n    scale_factor = 255\n    output = (output * scale_factor).astype(np.uint8)\n    ground_true = ground_true.permute(1, 2, 0).cpu().numpy()\n    ground_true = (ground_true * scale_factor).astype(np.uint8)\n    output = output[:,:,::-1]\n    ground_true = ground_true[:,:,::-1]\n    filtered = cv2.ximgproc.guidedFilter(ground_true, output, 10, 2, -1)\n    cv2.imwrite('{0}/{1}/{2}.png'.format(path, path_name, epoch), filtered)\n\n\ndef plot_colormap(mat_info, path, epoch, plot_nrow, img_res, colormap=cv2.COLORMAP_VIRIDIS, path_name='roughness'):\n    mat_info_plot = lin2img(mat_info, img_res)\n\n    tensor = vutils.make_grid(mat_info_plot,\n                                         scale_each=False,\n                                         normalize=False,\n                                         nrow=plot_nrow).cpu().detach().numpy()\n    tensor = tensor.transpose(1, 2, 0)\n    if colormap is None:\n        cv2.imwrite('{0}/{1}/{2}.exr'.format(path, path_name, epoch), tensor)\n    else:\n        tensor = (tensor * 255).astype(np.uint8)\n        img = cv2.applyColorMap(tensor, colormap)\n        cv2.imwrite('{0}/{1}/{2}.png'.format(path, path_name, epoch), img)\n\n\ndef plot_depths(depth_maps, path, epoch, plot_nrow, img_res, colormap=cv2.COLORMAP_VIRIDIS):\n    depth_maps_plot = lin2img(depth_maps, img_res)\n\n    tensor = vutils.make_grid(depth_maps_plot,\n                                         scale_each=False,\n                                         normalize=False,\n                                         nrow=plot_nrow).cpu().detach().numpy()\n    tensor = tensor.transpose(1, 2, 0)\n    # scale_factor = 255\n    # tensor = (tensor * scale_factor).astype(np.uint8)\n    if colormap is None:\n        cv2.imwrite('{0}/depth/{1}.exr'.format(path, epoch), tensor)\n    else:\n        tensor = tensor / (tensor.max() + 1e-6)\n        tensor = (tensor * 255).astype(np.uint8)\n        img = cv2.applyColorMap(tensor, colormap)\n        cv2.imwrite('{0}/depth/{1}.png'.format(path, epoch), img)\n\n    # img = Image.fromarray(tensor)\n    # img.save('{0}/normal_{1}.png'.format(path, epoch))\n\n\ndef plot_images(rgb_points, ground_true, path, epoch, plot_nrow, img_res, path_name='rendering', is_hdr=False):\n    ground_true = 
ground_true.cuda()\n\n    output_vs_gt = torch.cat((rgb_points, ground_true), dim=0)\n    output_vs_gt_plot = lin2img(output_vs_gt, img_res)\n\n    tensor = vutils.make_grid(output_vs_gt_plot,\n                                         scale_each=False,\n                                         normalize=False,\n                                         nrow=plot_nrow).cpu().detach().numpy()\n\n    tensor = tensor.transpose(1, 2, 0)\n    if not is_hdr:\n        scale_factor = 255\n        tensor = (tensor * scale_factor).astype(np.uint8)\n        img = Image.fromarray(tensor)\n        img.save('{0}/{1}/{2}.png'.format(path, path_name, epoch))\n    else:\n        cv2.imwrite('{0}/{1}/{2}.exr'.format(path, path_name, epoch), tensor[:,:,::-1])  # RGB -> BGR for OpenCV\n\n\ndef lin2img(tensor, img_res):\n    \"\"\"Reshape flat per-pixel predictions (B, H*W, C) into image tensors (B, C, H, W).\"\"\"\n    batch_size, num_samples, channels = tensor.shape\n    return tensor.permute(0, 2, 1).view(batch_size, channels, img_res[0], img_res[1])\n
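\n\nif __name__ == \"__main__\":\n    # Shape sanity-check sketch (illustrative): lin2img turns flat per-pixel\n    # predictions back into images for make_grid/PIL.\n    dummy = torch.rand(1, 64 * 48, 3)\n    img = lin2img(dummy, img_res=(64, 48))\n    print(img.shape)  # torch.Size([1, 3, 64, 48])\n"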
  },
  {
    "path": "utils/rend_util.py",
    "content": "import numpy as np\nimport imageio\nimport skimage\nimport cv2\nimport torch\nfrom torch.nn import functional as F\n\n\ndef linear_to_srgb(data):\n    return torch.where(data <= 0.0031308, data * 12.92, 1.055 * (data ** (1 / 2.4)) - 0.055)\n\n\ndef get_psnr(img1, img2, normalize_rgb=False):\n    if normalize_rgb: # [-1,1] --> [0,1]\n        img1 = (img1 + 1.) / 2.\n        img2 = (img2 + 1. ) / 2.\n\n    mse = torch.mean((img1 - img2) ** 2)\n    # psnr = -10. * torch.log(mse) / torch.log(torch.Tensor([10.]).cuda())\n    psnr = -10. * torch.log(mse) / np.log(10)\n\n    return psnr\n\n\ndef load_rgb(path, normalize_rgb = False, is_hdr = False):\n    if not is_hdr:\n        img = imageio.imread(path)\n        img = skimage.img_as_float32(img)\n    else:\n        img = cv2.imread(path, -1)[:,:,::-1].copy()\n\n    if normalize_rgb: # [-1,1] --> [0,1]\n        img -= 0.5\n        img *= 2.\n    img = img.transpose(2, 0, 1)\n    return img\n\ndef load_mask(path):\n    img = imageio.imread(path)\n    img = skimage.img_as_float32(img)\n    if len(img.shape) == 3:\n        img = img[:, :, 0]\n    return img # (h, w)\n\n\ndef load_depth(path):\n    img = cv2.imread(path, -1)\n    if len(img.shape) == 3:\n        img = img[:,:,-1]\n    return img\n\ndef load_normal(path):\n    img = cv2.imread(path, -1)[:,:,::-1]\n    return img.copy()\n\n\ndef load_K_Rt_from_P(filename, P=None):\n    if P is None:\n        lines = open(filename).read().splitlines()\n        if len(lines) == 4:\n            lines = lines[1:]\n        lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(\" \") for x in lines)]\n        P = np.asarray(lines).astype(np.float32).squeeze()\n\n    out = cv2.decomposeProjectionMatrix(P)\n    K = out[0]\n    R = out[1]\n    t = out[2]\n\n    K = K/K[2,2]\n    intrinsics = np.eye(4, dtype=np.float32)\n    intrinsics[:3, :3] = K\n\n    pose = np.eye(4, dtype=np.float32)\n    pose[:3, :3] = R.transpose()\n    pose[:3,3] = (t[:3] / t[3])[:,0]\n\n    return intrinsics, pose\n\n\ndef depth_to_world(uv, intrinsics, pose, depth, depth_mask=None):\n    x_cam, y_cam = torch.unbind(uv, dim=1)\n    z_cam = torch.ones_like(x_cam)\n    xyz_view = lift(x_cam, y_cam, z_cam, intrinsics)\n    xyz_view[:,:-1] = xyz_view[:,:-1] * depth.unsqueeze(1)\n    if depth_mask is not None:\n        xyz_view = xyz_view[depth_mask,:]\n    xyz_world = pose @ xyz_view.T\n    return xyz_world.T\n\n\ndef get_camera_params(uv, pose, intrinsics):\n    if pose.shape[1] == 7: #In case of quaternion vector representation\n        cam_loc = pose[:, 4:]\n        R = quat_to_rot(pose[:,:4])\n        p = torch.eye(4, device=pose.device).repeat(pose.shape[0],1,1).float()\n        p[:, :3, :3] = R\n        p[:, :3, 3] = cam_loc\n    else: # In case of pose matrix representation\n        cam_loc = pose[:, :3, 3]\n        p = pose\n\n    batch_size, num_samples, _ = uv.shape\n\n    depth = torch.ones((batch_size, num_samples), device=pose.device)\n    x_cam = uv[:, :, 0].view(batch_size, -1)\n    y_cam = uv[:, :, 1].view(batch_size, -1)\n    # z_cam = -depth.view(batch_size, -1)\n    z_cam = depth.view(batch_size, -1)\n\n    pixel_points_cam = lift(x_cam, y_cam, z_cam, intrinsics=intrinsics)\n    # pixel_points_cam[:,:,0] = -pixel_points_cam[:,:,0]\n    # permute for batch matrix product\n    pixel_points_cam = pixel_points_cam.permute(0, 2, 1)\n\n    world_coords = torch.bmm(p, pixel_points_cam).permute(0, 2, 1)[:, :, :3]\n    ray_dirs = world_coords - cam_loc[:, None, :]\n    # ray_dirs = F.normalize(ray_dirs, dim=2)\n\n    
# NOTE: ray_dirs are intentionally left unnormalized here; callers normalize when needed.\n    return ray_dirs, cam_loc\n\n\ndef get_camera_for_plot(pose):\n    if pose.shape[1] == 7: #In case of quaternion vector representation\n        cam_loc = pose[:, 4:].detach()\n        R = quat_to_rot(pose[:,:4].detach())\n    else: # In case of pose matrix representation\n        cam_loc = pose[:, :3, 3]\n        R = pose[:, :3, :3]\n    cam_dir = R[:, :3, 2]\n    return cam_loc, cam_dir\n\n\ndef lift(x, y, z, intrinsics):\n    # parse intrinsics\n    fx = intrinsics[..., 0, 0]\n    fy = intrinsics[..., 1, 1]\n    cx = intrinsics[..., 0, 2]\n    cy = intrinsics[..., 1, 2]\n    sk = intrinsics[..., 0, 1]\n\n    x_lift = (x - cx.unsqueeze(-1) + cy.unsqueeze(-1)*sk.unsqueeze(-1)/fy.unsqueeze(-1) - sk.unsqueeze(-1)*y/fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z\n    y_lift = (y - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z\n\n    # homogeneous\n    return torch.stack((x_lift, y_lift, z, torch.ones_like(z)), dim=-1)\n\n\ndef quat_to_rot(q):\n    batch_size, _ = q.shape\n    q = F.normalize(q, dim=1)\n    R = torch.ones((batch_size, 3, 3), device=q.device)\n    qr = q[:, 0]\n    qi = q[:, 1]\n    qj = q[:, 2]\n    qk = q[:, 3]\n    R[:, 0, 0] = 1 - 2 * (qj**2 + qk**2)\n    R[:, 0, 1] = 2 * (qj*qi - qk*qr)\n    R[:, 0, 2] = 2 * (qi*qk + qr*qj)\n    R[:, 1, 0] = 2 * (qj*qi + qk*qr)\n    R[:, 1, 1] = 1 - 2 * (qi**2 + qk**2)\n    R[:, 1, 2] = 2 * (qj*qk - qi*qr)\n    R[:, 2, 0] = 2 * (qk*qi - qj*qr)\n    R[:, 2, 1] = 2 * (qj*qk + qi*qr)\n    R[:, 2, 2] = 1 - 2 * (qi**2 + qj**2)\n    return R\n\n\ndef rot_to_quat(R):\n    batch_size, _, _ = R.shape\n    q = torch.ones((batch_size, 4), device=R.device)\n\n    R00 = R[:, 0, 0]\n    R01 = R[:, 0, 1]\n    R02 = R[:, 0, 2]\n    R10 = R[:, 1, 0]\n    R11 = R[:, 1, 1]\n    R12 = R[:, 1, 2]\n    R20 = R[:, 2, 0]\n    R21 = R[:, 2, 1]\n    R22 = R[:, 2, 2]\n\n    q[:, 0] = torch.sqrt(1.0 + R00 + R11 + R22) / 2\n    q[:, 1] = (R21 - R12) / (4 * q[:, 0])\n    q[:, 2] = (R02 - R20) / (4 * q[:, 0])\n    q[:, 3] = (R10 - R01) / (4 * q[:, 0])\n    return q\n\n\ndef get_general_sphere_intersections(cam_loc, ray_directions, center, r):\n    n_rays = cam_loc.size(0)\n    cam_loc = cam_loc - center.unsqueeze(0)\n    ray_cam_dot = torch.bmm(ray_directions.view(-1, 1, 3),\n                            cam_loc.view(-1, 3, 1)).squeeze(-1)\n    under_sqrt = ray_cam_dot ** 2 - (cam_loc.norm(2, 1, keepdim=True) ** 2 - r ** 2)\n    intersect_mask = (under_sqrt >= 0).squeeze(-1) # (n_rays,)\n    under_sqrt = under_sqrt[intersect_mask,:]\n    ray_cam_dot = ray_cam_dot[intersect_mask,:]\n    sphere_intersections = torch.sqrt(under_sqrt) * torch.tensor([-1, 1], device=cam_loc.device).float() - ray_cam_dot\n    front_mask = (sphere_intersections > 0).all(dim=-1)\n    intersect_mask[intersect_mask.clone()] &= front_mask\n    sphere_intersections = sphere_intersections[front_mask,:]\n    intersection_normals = cam_loc[intersect_mask,:] + ray_directions[intersect_mask,:] * sphere_intersections[:,:1]\n    intersection_points = intersection_normals + center.unsqueeze(0)\n    intersection_normals = F.normalize(intersection_normals, dim=1, eps=1e-8)\n    return intersection_points, intersection_normals, intersect_mask\n\n\ndef get_sphere_intersections(cam_loc, ray_directions, r = 1.0):\n    # Input: n_rays x 3 ; n_rays x 3\n    # Output: n_rays x 1, n_rays x 1 (close and far)\n\n    ray_cam_dot = torch.bmm(ray_directions.view(-1, 1, 3),\n                            cam_loc.view(-1, 3, 1)).squeeze(-1)\n    under_sqrt = ray_cam_dot ** 2 - 
(cam_loc.norm(2, 1, keepdim=True) ** 2 - r ** 2)\n\n    # sanity check: every ray must hit the bounding sphere\n    if (under_sqrt <= 0).any():\n        print('BOUNDING SPHERE PROBLEM!')\n        exit()\n\n    sphere_intersections = torch.sqrt(under_sqrt) * torch.tensor([-1, 1], device=cam_loc.device).float() - ray_cam_dot\n    sphere_intersections = sphere_intersections.clamp_min(0.0)\n\n    return sphere_intersections\n\ndef add_depth_noise(depth, depth_mask, scale=1):\n    # Quadratic depth-dependent noise model: bias (mu) and std (sigma) both grow\n    # with depth^2, in the spirit of RGB-D sensor noise models.\n    mu = 0.0001125 * depth**2 + 0.0048875\n    sigma = 0.002925 * depth**2 + 0.003325\n    noise = torch.randn_like(depth) * sigma + mu\n    return (depth + noise * scale) * depth_mask
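\n\n\nif __name__ == \"__main__\":\n    # Illustrative sketch: tone-map a linear HDR prediction to sRGB and compute\n    # PSNR against an LDR reference (random tensors stand in for real images).\n    pred_hdr = torch.rand(3, 64, 64)\n    gt_ldr = torch.rand(3, 64, 64)\n    pred_ldr = linear_to_srgb(pred_hdr).clamp(0.0, 1.0)\n    print(float(get_psnr(pred_ldr, gt_ldr)))"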
  }
]