[
  {
    "path": ".gitmodules",
    "content": "[submodule \"submodules/simple-knn\"]\n\tpath = submodules/simple-knn\n\turl = https://gitlab.inria.fr/bkerbl/simple-knn.git\n[submodule \"submodules/diff-gaussian-rasterization\"]\n\tpath = submodules/diff-gaussian-rasterization\n\turl = https://github.com/HLinChen/diff-gaussian-rasterization\n[submodule \"SIBR_viewers\"]\n\tpath = SIBR_viewers\n\turl = https://gitlab.inria.fr/sibr/sibr_core.git\n[submodule \"submodules/colmap\"]\n\tpath = submodules/colmap\n\turl = https://github.com/colmap/colmap.git\n"
  },
  {
    "path": "LICENSE.md",
    "content": "Gaussian-Splatting License  \n===========================  \n\n**Inria** and **the Max Planck Institut for Informatik (MPII)** hold all the ownership rights on the *Software* named **gaussian-splatting**.  \nThe *Software* is in the process of being registered with the Agence pour la Protection des  \nProgrammes (APP).  \n\nThe *Software* is still being developed by the *Licensor*.  \n\n*Licensor*'s goal is to allow the research community to use, test and evaluate  \nthe *Software*.  \n\n## 1.  Definitions  \n\n*Licensee* means any person or entity that uses the *Software* and distributes  \nits *Work*.  \n\n*Licensor* means the owners of the *Software*, i.e Inria and MPII  \n\n*Software* means the original work of authorship made available under this  \nLicense ie gaussian-splatting.  \n\n*Work* means the *Software* and any additions to or derivative works of the  \n*Software* that are made available under this License.  \n\n\n## 2.  Purpose  \nThis license is intended to define the rights granted to the *Licensee* by  \nLicensors under the *Software*.  \n\n## 3.  Rights granted  \n\nFor the above reasons Licensors have decided to distribute the *Software*.  \nLicensors grant non-exclusive rights to use the *Software* for research purposes  \nto research users (both academic and industrial), free of charge, without right  \nto sublicense.. The *Software* may be used \"non-commercially\", i.e., for research  \nand/or evaluation purposes only.  \n\nSubject to the terms and conditions of this License, you are granted a  \nnon-exclusive, royalty-free, license to reproduce, prepare derivative works of,  \npublicly display, publicly perform and distribute its *Work* and any resulting  \nderivative works in any form.  \n\n## 4.  
Limitations  \n\n**4.1 Redistribution.** You may reproduce or distribute the *Work* only if (a) you do  \nso under this License, (b) you include a complete copy of this License with  \nyour distribution, and (c) you retain without modification any copyright,  \npatent, trademark, or attribution notices that are present in the *Work*.  \n\n**4.2 Derivative Works.** You may specify that additional or different terms apply  \nto the use, reproduction, and distribution of your derivative works of the *Work*  \n(\"Your Terms\") only if (a) Your Terms provide that the use limitation in  \nSection 2 applies to your derivative works, and (b) you identify the specific  \nderivative works that are subject to Your Terms. Notwithstanding Your Terms,  \nthis License (including the redistribution requirements in Section 3.1) will  \ncontinue to apply to the *Work* itself.  \n\n**4.3** Any other use without of prior consent of Licensors is prohibited. Research  \nusers explicitly acknowledge having received from Licensors all information  \nallowing to appreciate the adequacy between of the *Software* and their needs and  \nto undertake all necessary precautions for its execution and use.  \n\n**4.4** The *Software* is provided both as a compiled library file and as source  \ncode. In case of using the *Software* for a publication or other results obtained  \nthrough the use of the *Software*, users are strongly encouraged to cite the  \ncorresponding publications as explained in the documentation of the *Software*.  \n\n## 5.  Disclaimer  \n\nTHE USER CANNOT USE, EXPLOIT OR DISTRIBUTE THE *SOFTWARE* FOR COMMERCIAL PURPOSES  \nWITHOUT PRIOR AND EXPLICIT CONSENT OF LICENSORS. YOU MUST CONTACT INRIA FOR ANY  \nUNAUTHORIZED USE: stip-sophia.transfert@inria.fr . ANY SUCH ACTION WILL  \nCONSTITUTE A FORGERY. 
THIS *SOFTWARE* IS PROVIDED \"AS IS\" WITHOUT ANY WARRANTIES  \nOF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES, WITH REGARDS TO COMMERCIAL  \nUSE, PROFESSIONNAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALISATION OR  \nADAPTATION. UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL INRIA OR THE  \nAUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR  \nCONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE  \nGOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION)  \nHOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT  \nLIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM, OUT OF OR  \nIN CONNECTION WITH THE *SOFTWARE* OR THE USE OR OTHER DEALINGS IN THE *SOFTWARE*.  \n"
  },
  {
    "path": "README.md",
    "content": "<p align=\"center\">\n\n  <h1 align=\"center\">VCR-GauS: View Consistent Depth-Normal Regularizer for Gaussian Surface Reconstruction</h1>\n  <p align=\"center\">\n    <a href=\"https://hlinchen.github.io/\">Hanlin Chen</a>,\n    <a href=\"https://weify627.github.io/\">Fangyin Wei</a>,\n    <a href=\"https://chaneyddtt.github.io/\">Chen Li</a>,\n    <a href=\"https://tianxinhuang.github.io/\">Tianxin Huang</a>,\n    <a href=\"https://scholar.google.com/citations?user=vv1uLeUAAAAJ&hl=en\">Yunsong Wang</a>,\n    <a href=\"https://www.comp.nus.edu.sg/~leegh/\">Gim Hee Lee</a>\n\n  </p>\n\n  <h2 align=\"center\">NeurIPS 2024</h2>\n\n  <h3 align=\"center\"><a href=\"https://arxiv.org/pdf/2406.05774\">arXiv</a> | <a href=\"https://hlinchen.github.io/projects/VCR-GauS/\">Project Page</a>  </h3>\n  <div align=\"center\"></div>\n</p>\n\n\n<p align=\"center\">\n  <a href=\"\">\n    <img src=\"./media/VCR-GauS.jpg\" alt=\"Logo\" width=\"95%\">\n  </a>\n</p>\n\n<p align=\"left\">\nVCR-GauS formulates a novel multi-view D-Normal regularizer that enables full optimization of the Gaussian geometric parameters to achieve better surface reconstruction. 
We further design a confidence term to weigh our D-Normal regularizer to mitigate inconsistencies of normal predictions across multiple views.</p>\n<br>\n\n# Updates\n\n* **[2024.10.31]**: We uploaded a new version to arXiv, adding theoretical proofs and visualization results for the D-Normal Regularizer.\n* **[2024.09.24]**: VCR-GauS is accepted to NeurIPS 2024.\n\n# Installation\nClone the repository and create an anaconda environment using\n```\ngit clone https://github.com/HLinChen/VCR-GauS.git --recursive\ncd VCR-GauS\ngit pull --recurse-submodules\n\nenv=vcr\nconda create -n $env -y python=3.10\nconda activate $env\npip install -e \".[train]\"\n# you can specify your own cuda path\nexport CUDA_HOME=/usr/local/cuda-11.8\npip install -r requirements.txt\n```\nWe also uploaded a built anaconda environment [here](https://huggingface.co/hanlin-chen/VCR-GauS/resolve/main/vcr.zip?download=true); you can download it, unzip it, and put it in your_anaconda_path/envs/.\n\nTo evaluate TNT with the official scripts, you need to build a new environment with open3d==0.10:\n```\nenv=f1eval\nconda create -n $env -y python=3.8\nconda activate $env\npip install -e \".[f1eval]\"\n```\n\nTo extract normal maps based on [DSINE](https://baegwangbin.github.io/DSINE/), you need to build a new environment:\n```\nconda create --name dsine python=3.10\nconda activate dsine\n\nconda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia\npython -m pip install geffnet\n```\n\n\nSimilar to Gaussian Splatting, we use COLMAP to process data; you can follow the [COLMAP website](https://colmap.github.io/) to install it.\n\n\n# Dataset\n\n<!-- Please download the Mip-NeRF 360 dataset from the [official website](https://jonbarron.info/mipnerf360/), the preprocessed DTU dataset from [2DGS](https://surfsplatting.github.io/), the preprocessed Tanks and Temples dataset from [here](https://huggingface.co/datasets/ZehaoYu/gaussian-opacity-fields/tree/main). 
You need to download the ground truth point clouds from the [DTU dataset](https://roboimagedata.compute.dtu.dk/?page_id=36) and save to `dtu_eval/Offical_DTU_Dataset` to evaluate the geometry reconstruction. For the [Tanks and Temples](https://www.tanksandtemples.org/download/) dataset, you need to download the ground truth point clouds, alignments and cropfiles and save to `eval_tnt/TrainingSet`, such as `eval_tnt/TrainingSet/Caterpillar/Caterpillar.ply`. -->\n\n\n## Tanks and Temples dataset\nYou can download the preprocessed Tanks and Temples dataset from [here](https://huggingface.co/hanlin-chen/VCR-GauS/resolve/main/tnt.zip?download=true). Or preprocess it yourself:\nDownload the data from the [Tanks and Temples](https://tanksandtemples.org/download/) website.\nYou will also need to download additional [COLMAP/camera/alignment](https://drive.google.com/file/d/1jAr3IDvhVmmYeDWi0D_JfgiHcl70rzVE/view?resourcekey=) and the images of each scene.  \nThe file structure should look like (you need to move the downloaded images to folder `images_raw`):\n```\ntanks_and_temples\n├─ Barn\n│  ├─ Barn_COLMAP_SfM.log   (camera poses)\n│  ├─ Barn.json             (cropfiles)\n│  ├─ Barn.ply              (ground-truth point cloud)\n│  ├─ Barn_trans.txt        (colmap-to-ground-truth transformation)\n│  └─ images_raw            (raw input images downloaded from Tanks and Temples website)\n│     ├─ 000001.png\n│     ├─ 000002.png\n│     ...\n├─ Caterpillar\n│  ├─ ...\n...\n```\n#### 1. Colmap and bounding box json\nRun the following command to generate json and colmap files:\n```bash\n# Modify --tnt_path to be the Tanks and Temples root directory.\nsh bash_scripts/1_preprocess_tnt.sh\n```\n\n#### 2. Normal maps\nYou need to download the [code](https://github.com/baegwangbin/DSINE) and [model weight](https://drive.google.com/drive/folders/1t3LMJIIrSnCGwOEf53Cyg0lkSXd3M4Hm) of DSINE first. 
Then, modify **CODE_PATH** to be the DSINE root directory, **CKPT** to be the DSINE model path, and **DATADIR** to be the TNT root directory in the bash script.\nRun the following command to generate normal maps:\n\n```bash\nsh bash_scripts/2_extract_normal_dsine.sh\n```\n\n#### 3. Semantic masks (optional)\n\nIf you don't want to use the semantic masks, you can set **optim.loss_weight.semantic=0** and skip the mask generation.\n\nYou need to download the [code](https://github.com/IDEA-Research/Grounded-Segment-Anything) and model of Grounded-SAM first. Then, install the environment based on 'Install without Docker' on the [website](https://github.com/IDEA-Research/Grounded-Segment-Anything). Next, modify **GSAM_PATH** to be the GSAM root directory and **DATADIR** to be the TNT root directory in the bash script. Run the following command to generate semantic masks:\n\n```bash\nsh bash_scripts/3_extract_mask.sh\n```\n\n## Other datasets\nPlease download the Mip-NeRF 360 dataset from the official [website](https://jonbarron.info/mipnerf360/) and the preprocessed DTU dataset from [2DGS](https://drive.google.com/drive/folders/1SJFgt8qhQomHX55Q4xSvYE2C6-8tFll9). Then extract normal maps with DSINE following the scripts above. 
You can also use [GeoWizard](https://github.com/fuxiao0719/GeoWizard) to extract normal maps by following the script 'bash_scripts/4_extract_normal_geow.sh'; please install the corresponding environment and download the code as well as the model weights first.\n\n# Training and Evaluation\n## From scratch:\n```\n# you might need to update the data path in the script accordingly\n\n# Tanks and Temples dataset\npython python_scripts/run_tnt.py\n\n# Mip-NeRF 360 dataset\npython python_scripts/run_mipnerf360.py\n```\n\n## Only eval the metrics\nWe have uploaded the extracted meshes; you can download and evaluate them yourself ([TNT](https://huggingface.co/hanlin-chen/VCR-GauS/resolve/main/tnt_mesh.zip?download=true) and [DTU](https://huggingface.co/Chiller3/VCR-GauS/resolve/main/dtu_mesh.zip?download=true)). You might need to update the **mesh and data path** in the script accordingly, and set **do_train** and **do_extract_mesh** to False.\n\n```\n# Tanks and Temples dataset\npython python_scripts/run_tnt.py\n\n# DTU dataset\npython python_scripts/run_dtu.py\n```\n\n## Additional regularizations:\nWe also incorporate some regularizations, like depth distortion loss and normal consistency loss, following [2DGS](https://surfsplatting.github.io/) and [GOF](https://niujinshuchong.github.io/gaussian-opacity-fields/). You can enable them as follows:\n- normal consistency loss: set optim.loss_weight.consistent_normal > 0;\n- depth distortion loss:\n  1. set optim.loss_weight.depth_var > 0\n  2. set NUM_DIST = 1 in submodules/diff-gaussian-rasterization/cuda_rasterizer/config.h, and reinstall diff-gaussian-rasterization\n\n\n# Custom Dataset\nWe use the same data format as 3DGS; please follow [here](https://github.com/graphdeco-inria/gaussian-splatting?tab=readme-ov-file#processing-your-own-scenes) to prepare your dataset. 
Then you can train your model and extract a mesh.\n```\n# Generate bounding box\npython process_data/convert_data_to_json.py \\\n        --scene_type outdoor \\\n        --data_dir /your/data/path\n\n# Extract normal maps\n# Use DSINE:\npython -W ignore process_data/extract_normal.py \\\n    --dsine_path /your/dsine/code/path \\\n    --ckpt /your/ckpt/path \\\n    --img_path /your/data/path/images \\\n    --intrins_path /your/data/path/ \\\n    --output_path /your/data/path/normals\n\n# Or use GeoWizard\npython process_data/extract_normal_geo.py \\\n  --code_path ${CODE_PATH} \\\n  --input_dir /your/data/path/images/ \\\n  --output_dir /your/data/path/ \\\n  --ensemble_size 3 \\\n  --denoise_steps 10 \\\n  --seed 0 \\\n  --domain ${DOMAIN_TYPE} # outdoor indoor object\n\n# training\n# --model.resolution=2 for using downsampled images with factor 2\n# --model.use_decoupled_appearance=True to enable decoupled appearance modeling if your images have changing lighting conditions\npython train.py \\\n  --config=configs/reconstruct.yaml \\\n  --logdir=/your/log/path/ \\\n  --model.source_path=/your/data/path/ \\\n  --model.data_device=cpu \\\n  --model.resolution=2 \\\n  --wandb \\\n  --wandb_name vcr-gaus\n\n# extract the mesh after training\npython tools/depth2mesh.py \\\n  --voxel_size 5e-3 \\\n  --max_depth 8 \\\n  --clean \\\n  --cfg_path /your/gaussian/path/config.yaml\n```\n\n# Acknowledgements\nThis project is built upon [3DGS](https://github.com/graphdeco-inria/gaussian-splatting). Evaluation scripts for the DTU and Tanks and Temples datasets are taken from [DTUeval-python](https://github.com/jzhangbs/DTUeval-python) and [TanksAndTemples](https://github.com/isl-org/TanksAndTemples/tree/master/python_toolbox/evaluation) respectively. 
We also utilize the normal estimators [DSINE](https://github.com/baegwangbin/DSINE) and [GeoWizard](https://fuxiao0719.github.io/projects/geowizard/), and the semantic segmentation models [SAM](https://github.com/facebookresearch/segment-anything) and [Grounded-SAM](https://github.com/IDEA-Research/Grounded-Segment-Anything?tab=readme-ov-file#install-without-docker). In addition, we use the pruning method in [LightGaussian](https://lightgaussian.github.io/). We thank all the authors for their great work and repos. \n\n\n# Citation\nIf you find our code or paper useful, please cite\n```bibtex\n@article{chen2024vcr,\n  author    = {Chen, Hanlin and Wei, Fangyin and Li, Chen and Huang, Tianxin and Wang, Yunsong and Lee, Gim Hee},\n  title     = {VCR-GauS: View Consistent Depth-Normal Regularizer for Gaussian Surface Reconstruction},\n  journal   = {arXiv preprint arXiv:2406.05774},\n  year      = {2024},\n}\n```\n\nIf you find the flattened 3D Gaussians useful, please kindly cite\n```bibtex\n@article{chen2023neusg,\n  title={Neusg: Neural implicit surface reconstruction with 3d gaussian splatting guidance},\n  author={Chen, Hanlin and Li, Chen and Lee, Gim Hee},\n  journal={arXiv preprint arXiv:2312.00846},\n  year={2023}\n}\n```\n"
  },
  {
    "path": "arguments/__init__.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nfrom argparse import ArgumentParser, Namespace\nimport sys\nimport os\n\nclass GroupParams:\n    pass\n\nclass ParamGroup:\n    def __init__(self, parser: ArgumentParser, name : str, fill_none = False):\n        group = parser.add_argument_group(name)\n        for key, value in vars(self).items():\n            shorthand = False\n            if key.startswith(\"_\"):\n                shorthand = True\n                key = key[1:]\n            t = type(value)\n            value = value if not fill_none else None \n            if shorthand:\n                if t == bool:\n                    group.add_argument(\"--\" + key, (\"-\" + key[0:1]), default=value, action=\"store_true\")\n                else:\n                    group.add_argument(\"--\" + key, (\"-\" + key[0:1]), default=value, type=t)\n            else:\n                if t == bool:\n                    group.add_argument(\"--\" + key, default=value, action=\"store_true\")\n                else:\n                    group.add_argument(\"--\" + key, default=value, type=t)\n\n    def extract(self, args):\n        group = GroupParams()\n        for arg in vars(args).items():\n            if arg[0] in vars(self) or (\"_\" + arg[0]) in vars(self):\n                setattr(group, arg[0], arg[1])\n        return group\n\nclass ModelParams(ParamGroup): \n    def __init__(self, parser, sentinel=False):\n        self.sh_degree = 3\n        self._source_path = \"\"\n        self._model_path = \"\"\n        self._images = \"images\"\n        self._resolution = -1\n        self._white_background = False\n        self.data_device = \"cuda\"\n        self.eval = False\n        super().__init__(parser, 
\"Loading Parameters\", sentinel)\n\n    def extract(self, args):\n        g = super().extract(args)\n        g.source_path = os.path.abspath(g.source_path)\n        return g\n\nclass PipelineParams(ParamGroup):\n    def __init__(self, parser):\n        self.convert_SHs_python = False\n        self.compute_cov3D_python = False\n        self.debug = False\n        super().__init__(parser, \"Pipeline Parameters\")\n\nclass OptimizationParams(ParamGroup):\n    def __init__(self, parser):\n        self.iterations = 30_000\n        self.position_lr_init = 0.00016\n        self.position_lr_final = 0.0000016\n        self.position_lr_delay_mult = 0.01\n        self.position_lr_max_steps = 30_000\n        self.feature_lr = 0.0025\n        self.opacity_lr = 0.05\n        self.scaling_lr = 0.005\n        self.rotation_lr = 0.001\n        self.percent_dense = 0.01\n        self.lambda_dssim = 0.2\n        self.densification_interval = 100\n        self.opacity_reset_interval = 3000\n        self.densify_from_iter = 500\n        self.densify_until_iter = 15_000\n        self.densify_grad_threshold = 0.0002\n        self.random_background = False\n        super().__init__(parser, \"Optimization Parameters\")\n\ndef get_combined_args(parser : ArgumentParser):\n    cmdlne_string = sys.argv[1:]\n    cfgfile_string = \"Namespace()\"\n    args_cmdline = parser.parse_args(cmdlne_string)\n\n    try:\n        cfgfilepath = os.path.join(args_cmdline.model_path, \"cfg_args\")\n        print(\"Looking for config file in\", cfgfilepath)\n        with open(cfgfilepath) as cfg_file:\n            print(\"Config file found: {}\".format(cfgfilepath))\n            cfgfile_string = cfg_file.read()\n    except TypeError:\n        print(\"Config file not found at\")\n        pass\n    args_cfgfile = eval(cfgfile_string)\n\n    merged_dict = vars(args_cfgfile).copy()\n    for k,v in vars(args_cmdline).items():\n        if v != None:\n            merged_dict[k] = v\n    return 
Namespace(**merged_dict)\n"
  },
  {
    "path": "bash_scripts/0_train.sh",
    "content": "GPU=0\nexport CUDA_VISIBLE_DEVICES=${GPU}\nls\n\n\nDATASET=tnt\nSCENE=Barn\nNAME=${SCENE}\n\nPROJECT=vcr_gaus\n\nTRIAL_NAME=vcr_gaus\n\nCFG=configs/${DATASET}/${SCENE}.yaml\n\nDIR=/your/log/path/${PROJECT}/${DATASET}/${NAME}/${TRIAL_NAME}\n\npython train.py \\\n    --config=${CFG} \\\n    --port=-1 \\\n    --logdir=${DIR} \\\n    --model.source_path=/your/data/path/${DATASET}/${SCENE}/ \\\n    --model.resolution=1 \\\n    --model.data_device=cpu \\\n    --wandb \\\n    --wandb_name ${PROJECT}\n"
  },
  {
    "path": "bash_scripts/1_preprocess_tnt.sh",
    "content": "echo \"Compute intrinsics, undistort images and generate json files. This may take a while\"\npython process_data/convert_tnt_to_json.py \\\n        --tnt_path /your/data/path \\\n        --run_colmap \\\n        --export_json "
  },
  {
    "path": "bash_scripts/2_extract_normal_dsine.sh",
    "content": "export CUDA_VISIBLE_DEVICES=0\n\nDOMAIN_TYPE=indoor\nDATADIR=/your/data/path\n\nCODE_PATH=/your/dsine/code/path\nCKPT=/your/dsine/code/path/checkpoints/dsine.pt\n\nfor SCENE in Barn Caterpillar Courthouse Ignatius Meetingroom Truck;\ndo\n    SCENE_PATH=${DATADIR}/${SCENE}\n    # dsine\n    python -W ignore process_data/extract_normal.py \\\n            --dsine_path ${CODE_PATH} \\\n            --ckpt ${CKPT} \\\n            --img_path ${SCENE_PATH}/images \\\n            --intrins_path ${SCENE_PATH}/ \\\n            --output_path ${SCENE_PATH}/normals\ndone"
  },
  {
    "path": "bash_scripts/3_extract_mask.sh",
    "content": "export CUDA_VISIBLE_DEVICES=0\n\nDATADIR=/your/data/path\nGSAM_PATH=~/code/gsam\nCKPT_PATH=${GSAM_PATH}\n\nfor SCENE in Barn Caterpillar Courthouse Ignatius Meetingroom Truck;\ndo\n    SCENE_PATH=${DATADIR}/${SCENE}\n    # meething room scene_tye: indoor, others: outdoor\n        if [ ${SCENE} = \"Meetingroom\" ]; then\n            SCENE_TYPE=\"indoor\"\n        else\n            SCENE_TYPE=\"outdoor\"\n        fi\n    python -W ignore process_data/extract_mask.py \\\n            --config ${GSAM_PATH}/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py \\\n            --grounded_checkpoint ${CKPT_PATH}/groundingdino_swint_ogc.pth \\\n            --sam_hq_checkpoint ${CKPT_PATH}/sam_hq_vit_h.pth \\\n            --gsam_path ${GSAM_PATH} \\\n            --use_sam_hq \\\n            --input_image ${SCENE_PATH}/images/ \\\n            --output_dir ${SCENE_PATH}/masks \\\n            --box_threshold 0.5 \\\n            --text_threshold 0.2 \\\n            --scene ${SCENE} \\\n            --scene_type ${SCENE_TYPE} \\\n            --device \"cuda\"\ndone\n"
  },
  {
    "path": "bash_scripts/4_extract_normal_geow.sh",
    "content": "export CUDA_VISIBLE_DEVICES=0\n\n# DOMAIN_TYPE=outdoor\n# DOMAIN_TYPE=indoor\nDOMAIN_TYPE=object\nDATADIR=/your/data/path/DTU_mask\n\nCODE_PATH=/your/geowizard/path\n\n\nfor SCENE in scan106  scan114  scan122  scan37  scan55  scan65  scan83 scan105    scan110  scan118  scan24   scan40  scan63  scan69  scan97;\ndo\n    SCENE_PATH=${DATADIR}/${SCENE}\n    python process_data/extract_normal_geo.py \\\n        --code_path ${CODE_PATH} \\\n        --input_dir ${SCENE_PATH}/images/ \\\n        --output_dir ${SCENE_PATH}/ \\\n        --ensemble_size 3 \\\n        --denoise_steps 10 \\\n        --seed 0 \\\n        --domain ${DOMAIN_TYPE}\ndone"
  },
  {
    "path": "bash_scripts/convert.sh",
    "content": "SCENE=Truck\nDATA_ROOT=/your/data/path/${SCENE}\n\npython convert.py -s $DATA_ROOT # [--resize] #If not resizing, ImageMagick is not needed\n\n\n"
  },
  {
    "path": "bash_scripts/install.sh",
    "content": "env=vcr\nconda create -n $env -y python=3.10\nconda activate $env\npip install -e \".[train]\"\nexport CUDA_HOME=/usr/local/cuda-11.2\npip install -r requirements.txt"
  },
  {
    "path": "configs/360_v2/base.yaml",
    "content": "_parent_: configs/reconstruct.yaml\n\nmodel:\n    eval: True\n    llffhold: 8\n    split: False\n\noptim:\n    mask_depth_thr: 1\n    densify_large:\n        percent_dense: 5e-2\n        sample_cams:\n            random: False\n            num: 100\n    loss_weight:\n        semantic: 0\n        l1_scale: 1"
  },
  {
    "path": "configs/config.py",
    "content": "'''\n-----------------------------------------------------------------------------\nCopyright (c) 2023, NVIDIA CORPORATION. All rights reserved.\n\nNVIDIA CORPORATION and its licensors retain all intellectual property\nand proprietary rights in and to this software, related documentation\nand any modifications thereto. Any use, reproduction, disclosure or\ndistribution of this software and related documentation without an express\nlicense agreement from NVIDIA CORPORATION is strictly prohibited.\n-----------------------------------------------------------------------------\n'''\n\nimport collections\nimport functools\nimport os\nimport re\n\nimport yaml\nfrom tools.distributed import master_only_print as print\nfrom tools.termcolor import cyan, green, yellow\n\nDEBUG = False\nUSE_JIT = False\n\n\nclass AttrDict(dict):\n    \"\"\"Dict as attribute trick.\"\"\"\n\n    def __init__(self, *args, **kwargs):\n        super(AttrDict, self).__init__(*args, **kwargs)\n        self.__dict__ = self\n        for key, value in self.__dict__.items():\n            if isinstance(value, dict):\n                self.__dict__[key] = AttrDict(value)\n            elif isinstance(value, (list, tuple)):\n                if value and isinstance(value[0], dict):\n                    self.__dict__[key] = [AttrDict(item) for item in value]\n                else:\n                    self.__dict__[key] = value\n\n    def yaml(self):\n        \"\"\"Convert object to yaml dict and return.\"\"\"\n        yaml_dict = {}\n        for key, value in self.__dict__.items():\n            if isinstance(value, AttrDict):\n                yaml_dict[key] = value.yaml()\n            elif isinstance(value, list):\n                if value and isinstance(value[0], AttrDict):\n                    new_l = []\n                    for item in value:\n                        new_l.append(item.yaml())\n                    yaml_dict[key] = new_l\n                else:\n                    
yaml_dict[key] = value\n            else:\n                yaml_dict[key] = value\n        return yaml_dict\n\n    def __repr__(self):\n        \"\"\"Print all variables.\"\"\"\n        ret_str = []\n        for key, value in self.__dict__.items():\n            if isinstance(value, AttrDict):\n                ret_str.append('{}:'.format(key))\n                child_ret_str = value.__repr__().split('\\n')\n                for item in child_ret_str:\n                    ret_str.append('    ' + item)\n            elif isinstance(value, list):\n                if value and isinstance(value[0], AttrDict):\n                    ret_str.append('{}:'.format(key))\n                    for item in value:\n                        # Treat as AttrDict above.\n                        child_ret_str = item.__repr__().split('\\n')\n                        for item in child_ret_str:\n                            ret_str.append('    ' + item)\n                else:\n                    ret_str.append('{}: {}'.format(key, value))\n            else:\n                ret_str.append('{}: {}'.format(key, value))\n        return '\\n'.join(ret_str)\n\n\nclass Config(AttrDict):\n    r\"\"\"Configuration class. 
This should include every human specifiable\n    hyperparameter values for your training.\"\"\"\n\n    def __init__(self, filename=None, verbose=False):\n        super(Config, self).__init__()\n        self.source_filename = filename\n\n        # Load the base configuration file.\n        base_filename = os.path.join(\n            os.path.dirname(__file__), './config_base.yaml'\n        )\n        cfg_base = self.load_config(base_filename)\n        recursive_update(self, cfg_base)\n\n        # Update with given configurations.\n        cfg_dict = self.load_config(filename)\n        recursive_update(self, cfg_dict)\n\n        if verbose:\n            print(' imaginaire config '.center(80, '-'))\n            print(self.__repr__())\n            print(''.center(80, '-'))\n\n    def load_config(self, filename):\n        # Update with given configurations.\n        assert os.path.exists(filename), f'File {filename} not exist.'\n        yaml_loader = yaml.SafeLoader\n        yaml_loader.add_implicit_resolver(\n            u'tag:yaml.org,2002:float',\n            re.compile(u'''^(?:\n             [-+]?(?:[0-9][0-9_]*)\\\\.[0-9_]*(?:[eE][-+]?[0-9]+)?\n            |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)\n            |\\\\.[0-9_]+(?:[eE][-+][0-9]+)?\n            |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\\\.[0-9_]*\n            |[-+]?\\\\.(?:inf|Inf|INF)\n            |\\\\.(?:nan|NaN|NAN))$''', re.X),\n            list(u'-+0123456789.'))\n        try:\n            with open(filename) as file:\n                cfg_dict = yaml.load(file, Loader=yaml_loader)\n                cfg_dict = AttrDict(cfg_dict)\n        except EnvironmentError:\n            print(f'Please check the file with name of \"{filename}\"')\n        # Inherit configurations from parent\n        parent_key = \"_parent_\"\n        if parent_key in cfg_dict:\n            parent_filename = cfg_dict.pop(parent_key)\n            cfg_parent = self.load_config(parent_filename)\n            recursive_update(cfg_parent, 
cfg_dict)\n            cfg_dict = cfg_parent\n        return cfg_dict\n\n    def print_config(self, level=0):\n        \"\"\"Recursively print the configuration (with termcolor).\"\"\"\n        for key, value in sorted(self.items()):\n            if isinstance(value, (dict, Config)):\n                print(\"   \" * level + cyan(\"* \") + green(key) + \":\")\n                Config.print_config(value, level + 1)\n            else:\n                print(\"   \" * level + cyan(\"* \") + green(key) + \":\", yellow(value))\n\n    def save_config(self, logdir):\n        \"\"\"Save the final configuration to a yaml file.\"\"\"\n        cfg_fname = f\"{logdir}/config.yaml\"\n        with open(cfg_fname, \"w\") as file:\n            yaml.safe_dump(self.yaml(), file, default_flow_style=False, indent=4)\n\ndef rsetattr(obj, attr, val):\n    \"\"\"Recursively find object and set value\"\"\"\n    pre, _, post = attr.rpartition('.')\n    return setattr(rgetattr(obj, pre) if pre else obj, post, val)\n\n\ndef rgetattr(obj, attr, *args):\n    \"\"\"Recursively find object and return value\"\"\"\n\n    def _getattr(obj, attr):\n        r\"\"\"Get attribute.\"\"\"\n        return getattr(obj, attr, *args)\n\n    return functools.reduce(_getattr, [obj] + attr.split('.'))\n\n\ndef recursive_update(d, u):\n    \"\"\"Recursively update AttrDict d with AttrDict u\"\"\"\n    for key, value in u.items():\n        if isinstance(value, collections.abc.Mapping):\n            d.__dict__[key] = recursive_update(d.get(key, AttrDict({})), value)\n        elif isinstance(value, (list, tuple)):\n            if value and isinstance(value[0], dict):\n                d.__dict__[key] = [AttrDict(item) for item in value]\n            else:\n                d.__dict__[key] = value\n        else:\n            d.__dict__[key] = value\n    return d\n\n\ndef recursive_update_strict(d, u, stack=[]):\n    \"\"\"Recursively update AttrDict d with AttrDict u with strict matching\"\"\"\n    for key, value in 
u.items():\n        if key not in d:\n            key_full = \".\".join(stack + [key])\n            raise KeyError(f\"The input key '{key_full}' does not exist in the config files.\")\n        if isinstance(value, collections.abc.Mapping):\n            d.__dict__[key] = recursive_update_strict(d.get(key, AttrDict({})), value, stack + [key])\n        elif isinstance(value, (list, tuple)):\n            if value and isinstance(value[0], dict):\n                d.__dict__[key] = [AttrDict(item) for item in value]\n            else:\n                d.__dict__[key] = value\n        else:\n            d.__dict__[key] = value\n    return d\n\n\ndef parse_cmdline_arguments(args):\n    \"\"\"\n    Parse arguments from command line.\n    Syntax: --key1.key2.key3=value --> value\n            --key1.key2.key3=      --> None\n            --key1.key2.key3       --> True\n            --key1.key2.key3!      --> False\n    \"\"\"\n    cfg_cmd = {}\n    for arg in args:\n        assert arg.startswith(\"--\")\n        if \"=\" not in arg[2:]:\n            key_str, value = (arg[2:-1], \"false\") if arg[-1] == \"!\" else (arg[2:], \"true\")\n        else:\n            # Split on the first '=' only, so a value may itself contain '='.\n            key_str, value = arg[2:].split(\"=\", 1)\n        keys_sub = key_str.split(\".\")\n        cfg_sub = cfg_cmd\n        for k in keys_sub[:-1]:\n            cfg_sub.setdefault(k, {})\n            cfg_sub = cfg_sub[k]\n        assert keys_sub[-1] not in cfg_sub, keys_sub[-1]\n        cfg_sub[keys_sub[-1]] = yaml.safe_load(value)\n    return cfg_cmd\n"
  },
  {
    "path": "configs/config_base.yaml",
    "content": "logdir: \"/your/log/path/debug/\"\nip: 127.0.0.1\nport: -1\ndetect_anomaly: False\nsilent: 0\nseed: 0\n\nmodel:\n    sh_degree: 3\n    source_path: \"/your/data/path/tnt/Barn/\"\n    model_path: \"/your/log/path/\"\n    images: \"images\"\n    resolution: -1\n    white_background: False\n    data_device: \"cuda\"\n    eval: False\n    llffhold: 1\n    init_ply: \"sparse/points3D.ply\"\n    max_init_points:\n    split: False\n    sphere: False\n    load_depth: False\n    load_normal: False\n    load_mask: False\n    normal_folder: 'normals'\n    depth_folder: 'depths'\n    use_decoupled_appearance: False\n    ch_sem_feat: 0\n    num_cls: 0\n    max_mem: 22\n    load_mask: False\n    use_decoupled_appearance: False\n    use_decoupled_dnormal: False\n    ratio: 0\n    mesh:\n        voxel_size: 3e-3\n    depth_type: 'traditional'\n\noptim:\n    iterations: 30000\n    position_lr_init: 0.00016\n    position_lr_final: 0.0000016\n    position_lr_delay_mult: 0.01\n    position_lr_max_steps: 30000\n    feature_lr: 0.0025\n    sdf_lr: 0.001\n    weight_decay: 1e-2\n    opacity_lr: 0.05\n    scaling_lr: 0.005\n    rotation_lr: 0.001\n    appearance_embeddings_lr: 0.001\n    appearance_network_lr: 0.001\n    cls_lr: 5e-4\n    percent_dense: 0.01\n    densification_interval: 100\n    opacity_reset_interval: 3000\n    densify_from_iter: 500\n    densify_until_iter: 15000\n    densify_grad_threshold: 0.0005\n    random_background: False\n    rand_pts: 20000\n    edge_thr: 0\n    mask_depth_thr: 0\n    loss_weight:\n        l1: 0.8\n        ssim: 0.2\n        distortion: 0.\n        semantic: 0\n        mono_depth: 0\n        mono_normal: 0\n        depth_normal: 0\n    prune:\n        iterations: []\n        percent: 0.5\n        decay: 0.6\n        v_pow: 0.1\n\npipline:\n    convert_SHs_python: False\n    compute_cov3D_python: False\n    debug: False\n    \ndata:\n    name: dummy\n\ntrain:\n    test_iterations: [7000, 30000]\n    save_iterations: [7000, 
30000]\n    checkpoint_iterations: [30000]\n    save_splat: False\n    start_checkpoint: \n    debug_from: -1\n\n"
  },
  {
    "path": "configs/dtu/base.yaml",
    "content": "_parent_: configs/reconstruct.yaml\n\nmodel:\n    use_decoupled_appearance: False\n    use_decoupled_dnormal: False\n    normal_folder: 'normal_npz_indoor'\n    eval: False\n\noptim:\n    exp_t: 0.01\n    mask_depth_thr: 0\n    loss_weight:\n        l1_scale: 0.5\n    consistent_normal_from_iter: 15000\n    close_depth_from_iter: 15000\n    densify_large:\n        percent_dense: 1e-2\n        sample_cams:\n            random: False\n            num: 30\n    loss_weight:\n        semantic: 0\n        depth_normal: 0\n        mono_normal: 0.01\n        consistent_normal: 0.05\n        distortion: 1000\n        depth_var: 0\n    random_background: False\n        "
  },
  {
    "path": "configs/dtu/dtu_scan24.yaml",
    "content": "_parent_: configs/dtu/base.yaml\n"
  },
  {
    "path": "configs/reconstruct.yaml",
    "content": "_parent_: configs/config_base.yaml\n\n\nmodel:\n    load_mask: False\n    use_decoupled_appearance: False\n    use_decoupled_dnormal: False\n    ch_sem_feat: 2\n    num_cls: 2\n    depth_type: 'intersection'\noptim:\n    mask_depth_thr: 0.8\n    edge_thr: 0\n    exp_t: 0.01\n    cos_thr: -1\n    close_depth_from_iter: 0\n    normal_from_iter: 0\n    dnormal_from_iter: 0\n    consistent_normal_from_iter: 0\n    curv_from_iter: 0\n    loss_weight:\n        l1: 0.8\n        ssim: 0.2\n        l1_scale: 1\n        entropy: 0\n        depth_var: 0.\n        mono_depth: 0\n        mono_normal: 0.01\n        depth_normal: 0.01\n        consistent_normal: 0\n    prune:\n        iterations: [15000, 25000]\n        percent: 0.5\n        decay: 0.6\n        v_pow: 0.1\n    densify_large:\n        percent_dense: 2e-3\n        interval: 1\n        sample_cams:\n            random: True\n            num: 200\n            up: True\n            around: True\n            look_mode: 'target'\n    random_background: True\n\n\ntrain:\n    checkpoint_iterations: []\n    save_mesh: False\n    save_iterations: [30000]"
  },
  {
    "path": "configs/scannetpp/base.yaml",
    "content": "_parent_: configs/reconstruct.yaml\n\nmodel:\n    split: True\n    eval: True\n    use_decoupled_appearance: False\n    use_decoupled_dnormal: False\n    mesh:\n        voxel_size: 1.5e-2\n\noptim:\n    mask_depth_thr: 0\n    curv_from_iter: 15000\n    densify_large:\n        percent_dense: 1e-2\n        sample_cams:\n            random: False\n    loss_weight:\n        semantic: 0\n        curv: 0.05"
  },
  {
    "path": "configs/tnt/Barn.yaml",
    "content": "_parent_: configs/tnt/base.yaml\n"
  },
  {
    "path": "configs/tnt/Caterpillar.yaml",
    "content": "_parent_: configs/tnt/base.yaml\n"
  },
  {
    "path": "configs/tnt/Courthouse.yaml",
    "content": "_parent_: configs/tnt/base.yaml\n"
  },
  {
    "path": "configs/tnt/Ignatius.yaml",
    "content": "_parent_: configs/tnt/base.yaml\n"
  },
  {
    "path": "configs/tnt/Meetingroom.yaml",
    "content": "_parent_: configs/tnt/base.yaml\n\noptim:\n    exp_t: 1e-3\n    mask_depth_thr: 0\n    densify_large:\n        percent_dense: 5e-3\n        sample_cams:\n            random: False\n    loss_weight:\n        semantic: 0\nmodel:\n    num_cls: 3\n    use_decoupled_appearance: False"
  },
  {
    "path": "configs/tnt/Truck.yaml",
    "content": "_parent_: configs/tnt/base.yaml\n"
  },
  {
    "path": "configs/tnt/base.yaml",
    "content": "_parent_: configs/reconstruct.yaml\n\nmodel:\n    use_decoupled_appearance: True\n    use_decoupled_dnormal: False\n    eval: False\n    llffhold: 5\n\noptim:\n    exp_t: 5e-3\n    loss_weight:\n        depth_normal: 0.015\n        semantic: 0.005\n        l1_scale: 1"
  },
  {
    "path": "environment.yml",
    "content": "name: fast_render\nchannels:\n  - pytorch\n  - nvidia\n  - conda-forge\n  - defaults\ndependencies:\n  - python=3.10\n  - pytorch==2.0.1\n  - torchvision==0.15.2\n  - torchaudio==2.0.2\n  - pytorch-cuda=11.8\n  - pip:\n    - open3d\n    - plyfile\n    - ninja\n    - GPUtil\n    - opencv-python\n    - lpips\n    - trimesh\n    - pymeshlab\n    - termcolor\n    - wandb\n    - imageio\n    - scikit-image\n    - torchmetrics\n    - mediapy\n    - \"git+https://github.com/facebookresearch/pytorch3d.git\"\n    - submodules/diff-gaussian-rasterization\n    - submodules/simple-knn"
  },
  {
    "path": "evaluation/crop_mesh.py",
    "content": "import os\nimport json\nimport plyfile\nimport argparse\n# import open3d as o3d\nimport numpy as np\n# from tqdm import tqdm\nimport trimesh\nfrom sklearn.cluster import DBSCAN\n\n\ndef align_gt_with_cam(pts, trans):\n    trans_inv = np.linalg.inv(trans)\n    pts_aligned = pts @ trans_inv[:3, :3].transpose(-1, -2) + trans_inv[:3, -1]\n    return pts_aligned\n\n\ndef main(args):\n    assert os.path.exists(args.ply_path), f\"PLY file {args.ply_path} does not exist.\"\n    gt_trans = np.loadtxt(args.align_path)\n    \n    mesh_rec = trimesh.load(args.ply_path, process=False)\n    mesh_gt = trimesh.load(args.gt_path, process=False)\n    \n    mesh_gt.vertices = align_gt_with_cam(mesh_gt.vertices, gt_trans)\n    \n    to_align, _ = trimesh.bounds.oriented_bounds(mesh_gt)\n    mesh_gt.vertices = (to_align[:3, :3] @ mesh_gt.vertices.T + to_align[:3, 3:]).T\n    mesh_rec.vertices = (to_align[:3, :3] @ mesh_rec.vertices.T + to_align[:3, 3:]).T\n    \n    min_points = mesh_gt.vertices.min(axis=0)\n    max_points = mesh_gt.vertices.max(axis=0)\n\n    mask_min = (mesh_rec.vertices - min_points[None]) > 0\n    mask_max = (mesh_rec.vertices - max_points[None]) < 0\n\n    mask = np.concatenate((mask_min, mask_max), axis=1).all(axis=1)\n    face_mask = mask[mesh_rec.faces].all(axis=1)\n\n    mesh_rec.update_vertices(mask)\n    mesh_rec.update_faces(face_mask)\n    \n    mesh_rec.vertices = (to_align[:3, :3].T @ mesh_rec.vertices.T - to_align[:3, :3].T @ to_align[:3, 3:]).T\n    mesh_gt.vertices = (to_align[:3, :3].T @ mesh_gt.vertices.T - to_align[:3, :3].T @ to_align[:3, 3:]).T\n    \n    # save mesh_rec and mesh_rec in args.out_path\n    mesh_rec.export(args.out_path)\n    \n    # downsample mesh_gt\n    \n    idx = np.random.choice(np.arange(len(mesh_gt.vertices)), 5000000)\n    mesh_gt.vertices = mesh_gt.vertices[idx]\n    mesh_gt.colors = mesh_gt.colors[idx]\n    \n    mesh_gt.export(args.gt_path.replace('.ply', '_trans.ply'))\n    \n    \n    return\n\n\n\nif 
__name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--gt_path\",\n        type=str,\n        default='/your/path//Barn_GT.ply',\n        help=\"path to the ground-truth ply file\",\n    )\n    parser.add_argument(\n        \"--align_path\",\n        type=str,\n        default='/your/path//Barn_trans.txt',\n        help=\"path to the text file holding the alignment transformation matrix\",\n    )\n    parser.add_argument(\n        \"--ply_path\",\n        type=str,\n        default='/your/path//Barn_lowres.ply',\n        help=\"path to reconstruction ply file\",\n    )\n    parser.add_argument(\n        \"--scene\",\n        type=str,\n        default='Barn',\n        help=\"scene name\",\n    )\n    parser.add_argument(\n        \"--out_path\",\n        type=str,\n        default='/your/path//Barn_lowres_crop.ply',\n        help=\"path of the cropped reconstruction ply file to write\",\n    )\n    args = parser.parse_args()\n    \n    main(args)"
  },
  {
    "path": "evaluation/eval_dtu/eval.py",
    "content": "# adapted from https://github.com/jzhangbs/DTUeval-python\nimport numpy as np\nimport open3d as o3d\nimport sklearn.neighbors as skln\nfrom tqdm import tqdm\nfrom scipy.io import loadmat\nimport multiprocessing as mp\nimport argparse\n\ndef sample_single_tri(input_):\n    n1, n2, v1, v2, tri_vert = input_\n    c = np.mgrid[:n1+1, :n2+1]\n    c += 0.5\n    c[0] /= max(n1, 1e-7)\n    c[1] /= max(n2, 1e-7)\n    c = np.transpose(c, (1,2,0))\n    k = c[c.sum(axis=-1) < 1]  # m2\n    q = v1 * k[:,:1] + v2 * k[:,1:] + tri_vert\n    return q\n\ndef write_vis_pcd(file, points, colors):\n    pcd = o3d.geometry.PointCloud()\n    pcd.points = o3d.utility.Vector3dVector(points)\n    pcd.colors = o3d.utility.Vector3dVector(colors)\n    o3d.io.write_point_cloud(file, pcd)\n\nif __name__ == '__main__':\n    mp.freeze_support()\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--data', type=str, default='data_in.ply')\n    parser.add_argument('--scan', type=int, default=1)\n    parser.add_argument('--mode', type=str, default='mesh', choices=['mesh', 'pcd'])\n    parser.add_argument('--dataset_dir', type=str, default='.')\n    parser.add_argument('--vis_out_dir', type=str, default='.')\n    parser.add_argument('--downsample_density', type=float, default=0.2)\n    parser.add_argument('--patch_size', type=float, default=60)\n    parser.add_argument('--max_dist', type=float, default=20)\n    parser.add_argument('--visualize_threshold', type=float, default=10)\n    args = parser.parse_args()\n\n    thresh = args.downsample_density\n    if args.mode == 'mesh':\n        pbar = tqdm(total=9)\n        pbar.set_description('read data mesh')\n        data_mesh = o3d.io.read_triangle_mesh(args.data)\n\n        vertices = np.asarray(data_mesh.vertices)\n        triangles = np.asarray(data_mesh.triangles)\n        tri_vert = vertices[triangles]\n\n        pbar.update(1)\n        pbar.set_description('sample pcd from mesh')\n        v1 = tri_vert[:,1] - 
tri_vert[:,0]\n        v2 = tri_vert[:,2] - tri_vert[:,0]\n        l1 = np.linalg.norm(v1, axis=-1, keepdims=True)\n        l2 = np.linalg.norm(v2, axis=-1, keepdims=True)\n        area2 = np.linalg.norm(np.cross(v1, v2), axis=-1, keepdims=True)\n        non_zero_area = (area2 > 0)[:,0]\n        l1, l2, area2, v1, v2, tri_vert = [\n            arr[non_zero_area] for arr in [l1, l2, area2, v1, v2, tri_vert]\n        ]\n        thr = thresh * np.sqrt(l1 * l2 / area2)\n        n1 = np.floor(l1 / thr)\n        n2 = np.floor(l2 / thr)\n\n        with mp.Pool() as mp_pool:\n            new_pts = mp_pool.map(sample_single_tri, ((n1[i,0], n2[i,0], v1[i:i+1], v2[i:i+1], tri_vert[i:i+1,0]) for i in range(len(n1))), chunksize=1024)\n\n        new_pts = np.concatenate(new_pts, axis=0)\n        data_pcd = np.concatenate([vertices, new_pts], axis=0)\n    \n    elif args.mode == 'pcd':\n        pbar = tqdm(total=8)\n        pbar.set_description('read data pcd')\n        data_pcd_o3d = o3d.io.read_point_cloud(args.data)\n        data_pcd = np.asarray(data_pcd_o3d.points)\n\n    pbar.update(1)\n    pbar.set_description('random shuffle pcd index')\n    shuffle_rng = np.random.default_rng()\n    shuffle_rng.shuffle(data_pcd, axis=0)\n\n    pbar.update(1)\n    pbar.set_description('downsample pcd')\n    nn_engine = skln.NearestNeighbors(n_neighbors=1, radius=thresh, algorithm='kd_tree', n_jobs=-1)\n    nn_engine.fit(data_pcd)\n    rnn_idxs = nn_engine.radius_neighbors(data_pcd, radius=thresh, return_distance=False)\n    mask = np.ones(data_pcd.shape[0], dtype=np.bool_)\n    for curr, idxs in enumerate(rnn_idxs):\n        if mask[curr]:\n            mask[idxs] = 0\n            mask[curr] = 1\n    data_down = data_pcd[mask]\n\n    pbar.update(1)\n    pbar.set_description('masking data pcd')\n    obs_mask_file = loadmat(f'{args.dataset_dir}/ObsMask/ObsMask{args.scan}_10.mat')\n    ObsMask, BB, Res = [obs_mask_file[attr] for attr in ['ObsMask', 'BB', 'Res']]\n    BB = 
BB.astype(np.float32)\n\n    patch = args.patch_size\n    inbound = ((data_down >= BB[:1]-patch) & (data_down < BB[1:]+patch*2)).sum(axis=-1) ==3\n    data_in = data_down[inbound]\n\n    data_grid = np.around((data_in - BB[:1]) / Res).astype(np.int32)\n    grid_inbound = ((data_grid >= 0) & (data_grid < np.expand_dims(ObsMask.shape, 0))).sum(axis=-1) ==3\n    data_grid_in = data_grid[grid_inbound]\n    in_obs = ObsMask[data_grid_in[:,0], data_grid_in[:,1], data_grid_in[:,2]].astype(np.bool_)\n    data_in_obs = data_in[grid_inbound][in_obs]\n\n    pbar.update(1)\n    pbar.set_description('read STL pcd')\n    stl_pcd = o3d.io.read_point_cloud(f'{args.dataset_dir}/Points/stl/stl{args.scan:03}_total.ply')\n    stl = np.asarray(stl_pcd.points)\n\n    pbar.update(1)\n    pbar.set_description('compute data2stl')\n    nn_engine.fit(stl)\n    dist_d2s, idx_d2s = nn_engine.kneighbors(data_in_obs, n_neighbors=1, return_distance=True)\n    max_dist = args.max_dist\n    mean_d2s = dist_d2s[dist_d2s < max_dist].mean()\n\n    pbar.update(1)\n    pbar.set_description('compute stl2data')\n    ground_plane = loadmat(f'{args.dataset_dir}/ObsMask/Plane{args.scan}.mat')['P']\n\n    stl_hom = np.concatenate([stl, np.ones_like(stl[:,:1])], -1)\n    above = (ground_plane.reshape((1,4)) * stl_hom).sum(-1) > 0\n    stl_above = stl[above]\n\n    nn_engine.fit(data_in)\n    dist_s2d, idx_s2d = nn_engine.kneighbors(stl_above, n_neighbors=1, return_distance=True)\n    mean_s2d = dist_s2d[dist_s2d < max_dist].mean()\n\n    pbar.update(1)\n    pbar.set_description('visualize error')\n    vis_dist = args.visualize_threshold\n    R = np.array([[1,0,0]], dtype=np.float64)\n    G = np.array([[0,1,0]], dtype=np.float64)\n    B = np.array([[0,0,1]], dtype=np.float64)\n    W = np.array([[1,1,1]], dtype=np.float64)\n    data_color = np.tile(B, (data_down.shape[0], 1))\n    data_alpha = dist_d2s.clip(max=vis_dist) / vis_dist\n    data_color[ np.where(inbound)[0][grid_inbound][in_obs] ] = R * data_alpha + 
W * (1-data_alpha)\n    data_color[ np.where(inbound)[0][grid_inbound][in_obs][dist_d2s[:,0] >= max_dist] ] = G\n    write_vis_pcd(f'{args.vis_out_dir}/vis_{args.scan:03}_d2s.ply', data_down, data_color)\n    stl_color = np.tile(B, (stl.shape[0], 1))\n    stl_alpha = dist_s2d.clip(max=vis_dist) / vis_dist\n    stl_color[ np.where(above)[0] ] = R * stl_alpha + W * (1-stl_alpha)\n    stl_color[ np.where(above)[0][dist_s2d[:,0] >= max_dist] ] = G\n    write_vis_pcd(f'{args.vis_out_dir}/vis_{args.scan:03}_s2d.ply', stl, stl_color)\n\n    pbar.update(1)\n    pbar.set_description('done')\n    pbar.close()\n    over_all = (mean_d2s + mean_s2d) / 2\n    print(mean_d2s, mean_s2d, over_all)\n    \n    import json\n    with open(f'{args.vis_out_dir}/results.json', 'w') as fp:\n        json.dump({\n            'mean_d2s': mean_d2s,\n            'mean_s2d': mean_s2d,\n            'overall': over_all,\n        }, fp, indent=True)\n\n\n"
  },
  {
    "path": "evaluation/eval_dtu/evaluate_single_scene.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport cv2\nimport numpy as np\nimport os\nimport glob\nfrom skimage.morphology import binary_dilation, disk\nimport argparse\n\nimport trimesh\nfrom pathlib import Path\nfrom tqdm import tqdm\n\nimport sys\n\nsys.path.append(os.getcwd())\n\nimport evaluation.eval_dtu.render_utils as rend_util\n\n\ndef cull_scan(scan, mesh_path, result_mesh_file, instance_dir):\n    \n    # load poses\n    image_dir = '{0}/images'.format(instance_dir)\n    image_paths = sorted(glob.glob(os.path.join(image_dir, \"*.png\")))\n    n_images = len(image_paths)\n    cam_file = '{0}/cameras.npz'.format(instance_dir)\n    camera_dict = np.load(cam_file)\n    scale_mats = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(n_images)]\n    world_mats = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(n_images)]\n\n    intrinsics_all = []\n    pose_all = []\n    for scale_mat, world_mat in zip(scale_mats, world_mats):\n        P = world_mat @ scale_mat\n        P = P[:3, :4]\n        intrinsics, pose = rend_util.load_K_Rt_from_P(None, P)\n        intrinsics_all.append(torch.from_numpy(intrinsics).float())\n        pose_all.append(torch.from_numpy(pose).float())\n    \n    # load mask\n    mask_dir = '{0}/mask'.format(instance_dir)\n    mask_paths = sorted(glob.glob(os.path.join(mask_dir, \"*.png\")))\n    masks = []\n    for p in mask_paths:\n        mask = cv2.imread(p)\n        masks.append(mask)\n\n    # hard-coded image shape\n    W, H = 1600, 1200\n\n    # load mesh\n    mesh = trimesh.load(mesh_path)\n    \n    # load transformation matrix\n\n    vertices = mesh.vertices\n\n    # project and filter\n    vertices = torch.from_numpy(vertices).cuda()\n    vertices = torch.cat((vertices, torch.ones_like(vertices[:, :1])), dim=-1)\n    vertices = vertices.permute(1, 0)\n    vertices = vertices.float()\n\n    sampled_masks = []\n    for i in tqdm(range(n_images),  
desc=\"Culling mesh given masks\"):\n        pose = pose_all[i]\n        w2c = torch.inverse(pose).cuda()\n        intrinsic = intrinsics_all[i].cuda()\n\n        with torch.no_grad():\n            # transform and project\n            cam_points = intrinsic @ w2c @ vertices\n            pix_coords = cam_points[:2, :] / (cam_points[2, :].unsqueeze(0) + 1e-6)\n            pix_coords = pix_coords.permute(1, 0)\n            pix_coords[..., 0] /= W - 1\n            pix_coords[..., 1] /= H - 1\n            pix_coords = (pix_coords - 0.5) * 2\n            valid = ((pix_coords > -1. ) & (pix_coords < 1.)).all(dim=-1).float()\n            \n            # dilate mask similar to unisurf\n            maski = masks[i][:, :, 0].astype(np.float32) / 256.\n            maski = torch.from_numpy(binary_dilation(maski, disk(24))).float()[None, None].cuda()\n\n            sampled_mask = F.grid_sample(maski, pix_coords[None, None], mode='nearest', padding_mode='zeros', align_corners=True)[0, -1, 0]\n\n            sampled_mask = sampled_mask + (1. 
- valid)\n            sampled_masks.append(sampled_mask)\n\n    sampled_masks = torch.stack(sampled_masks, -1)\n    # filter\n    \n    mask = (sampled_masks > 0.).all(dim=-1).cpu().numpy()\n    face_mask = mask[mesh.faces].all(axis=1)\n\n    mesh.update_vertices(mask)\n    mesh.update_faces(face_mask)\n    \n    # transform vertices to world \n    scale_mat = scale_mats[0]\n    mesh.vertices = mesh.vertices * scale_mat[0, 0] + scale_mat[:3, 3][None]\n    \n    # Taking the biggest connected component\n    print(\"Taking the biggest connected component\")\n    components = mesh.split(only_watertight=False)\n    areas = np.array([c.area for c in components], dtype=np.float32)\n    mesh = components[areas.argmax()]\n    \n    mesh.export(result_mesh_file)\n    del mesh\n    \n\nif __name__ == \"__main__\":\n\n    parser = argparse.ArgumentParser(\n        description='Arguments to evaluate the mesh.'\n    )\n\n    parser.add_argument('--input_mesh', type=str,  help='path to the mesh to be evaluated')\n    parser.add_argument('--scan_id', type=str,  help='scan id of the input mesh')\n    parser.add_argument('--output_dir', type=str, default='evaluation_results_single', help='path to the output folder')\n    parser.add_argument('--mask_dir', type=str,  default='mask', help='path to uncropped mask')\n    parser.add_argument('--DTU', type=str,  default='Offical_DTU_Dataset', help='path to the GT DTU point clouds')\n    args = parser.parse_args()\n\n    Offical_DTU_Dataset = args.DTU\n    out_dir = args.output_dir\n    Path(out_dir).mkdir(parents=True, exist_ok=True)\n\n    scan = args.scan_id\n    ply_file = args.input_mesh\n    print(\"cull mesh ....\")\n    result_mesh_file = os.path.join(out_dir, \"culled_mesh.ply\")\n    cull_scan(scan, ply_file, result_mesh_file, instance_dir=os.path.join(args.mask_dir, f'scan{args.scan_id}'))\n\n    script_dir = os.path.dirname(os.path.abspath(__file__))\n    cmd = f\"python {script_dir}/eval.py --data {result_mesh_file} --scan 
{scan} --mode mesh --dataset_dir {Offical_DTU_Dataset} --vis_out_dir {out_dir}\"\n    os.system(cmd)"
  },
  {
    "path": "evaluation/eval_dtu/render_utils.py",
    "content": "import numpy as np\nimport imageio\nimport skimage\nimport cv2\nimport torch\nfrom torch.nn import functional as F\n\n\ndef get_psnr(img1, img2, normalize_rgb=False):\n    if normalize_rgb: # [-1,1] --> [0,1]\n        img1 = (img1 + 1.) / 2.\n        img2 = (img2 + 1. ) / 2.\n\n    mse = torch.mean((img1 - img2) ** 2)\n    psnr = -10. * torch.log(mse) / torch.log(torch.Tensor([10.]).cuda())\n\n    return psnr\n\n\ndef load_rgb(path, normalize_rgb = False):\n    img = imageio.imread(path)\n    img = skimage.img_as_float32(img)\n\n    if normalize_rgb: # [-1,1] --> [0,1]\n        img -= 0.5\n        img *= 2.\n    img = img.transpose(2, 0, 1)\n    return img\n\n\ndef load_K_Rt_from_P(filename, P=None):\n    if P is None:\n        lines = open(filename).read().splitlines()\n        if len(lines) == 4:\n            lines = lines[1:]\n        lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(\" \") for x in lines)]\n        P = np.asarray(lines).astype(np.float32).squeeze()\n\n    out = cv2.decomposeProjectionMatrix(P)\n    K = out[0]\n    R = out[1]\n    t = out[2]\n\n    K = K/K[2,2]\n    intrinsics = np.eye(4)\n    intrinsics[:3, :3] = K\n\n    pose = np.eye(4, dtype=np.float32)\n    pose[:3, :3] = R.transpose()\n    pose[:3,3] = (t[:3] / t[3])[:,0]\n\n    return intrinsics, pose\n\n\ndef get_camera_params(uv, pose, intrinsics):\n    if pose.shape[1] == 7: #In case of quaternion vector representation\n        cam_loc = pose[:, 4:]\n        R = quat_to_rot(pose[:,:4])\n        p = torch.eye(4).repeat(pose.shape[0],1,1).cuda().float()\n        p[:, :3, :3] = R\n        p[:, :3, 3] = cam_loc\n    else: # In case of pose matrix representation\n        cam_loc = pose[:, :3, 3]\n        p = pose\n\n    batch_size, num_samples, _ = uv.shape\n\n    depth = torch.ones((batch_size, num_samples)).cuda()\n    x_cam = uv[:, :, 0].view(batch_size, -1)\n    y_cam = uv[:, :, 1].view(batch_size, -1)\n    z_cam = depth.view(batch_size, -1)\n\n    pixel_points_cam = 
lift(x_cam, y_cam, z_cam, intrinsics=intrinsics)\n\n    # permute for batch matrix product\n    pixel_points_cam = pixel_points_cam.permute(0, 2, 1)\n\n    world_coords = torch.bmm(p, pixel_points_cam).permute(0, 2, 1)[:, :, :3]\n    ray_dirs = world_coords - cam_loc[:, None, :]\n    ray_dirs = F.normalize(ray_dirs, dim=2)\n\n    return ray_dirs, cam_loc\n\n\ndef get_camera_for_plot(pose):\n    if pose.shape[1] == 7: #In case of quaternion vector representation\n        cam_loc = pose[:, 4:].detach()\n        R = quat_to_rot(pose[:,:4].detach())\n    else: # In case of pose matrix representation\n        cam_loc = pose[:, :3, 3]\n        R = pose[:, :3, :3]\n    cam_dir = R[:, :3, 2]\n    return cam_loc, cam_dir\n\n\ndef lift(x, y, z, intrinsics):\n    # parse intrinsics\n    intrinsics = intrinsics.cuda()\n    fx = intrinsics[:, 0, 0]\n    fy = intrinsics[:, 1, 1]\n    cx = intrinsics[:, 0, 2]\n    cy = intrinsics[:, 1, 2]\n    sk = intrinsics[:, 0, 1]\n\n    x_lift = (x - cx.unsqueeze(-1) + cy.unsqueeze(-1)*sk.unsqueeze(-1)/fy.unsqueeze(-1) - sk.unsqueeze(-1)*y/fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z\n    y_lift = (y - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z\n\n    # homogeneous\n    return torch.stack((x_lift, y_lift, z, torch.ones_like(z).cuda()), dim=-1)\n\n\ndef quat_to_rot(q):\n    batch_size, _ = q.shape\n    q = F.normalize(q, dim=1)\n    R = torch.ones((batch_size, 3,3)).cuda()\n    qr=q[:,0]\n    qi = q[:, 1]\n    qj = q[:, 2]\n    qk = q[:, 3]\n    R[:, 0, 0]=1-2 * (qj**2 + qk**2)\n    R[:, 0, 1] = 2 * (qj *qi -qk*qr)\n    R[:, 0, 2] = 2 * (qi * qk + qr * qj)\n    R[:, 1, 0] = 2 * (qj * qi + qk * qr)\n    R[:, 1, 1] = 1-2 * (qi**2 + qk**2)\n    R[:, 1, 2] = 2*(qj*qk - qi*qr)\n    R[:, 2, 0] = 2 * (qk * qi-qj * qr)\n    R[:, 2, 1] = 2 * (qj*qk + qi*qr)\n    R[:, 2, 2] = 1-2 * (qi**2 + qj**2)\n    return R\n\n\ndef rot_to_quat(R):\n    batch_size, _,_ = R.shape\n    q = torch.ones((batch_size, 4)).cuda()\n\n    R00 = R[:, 0,0]\n    R01 = R[:, 0, 1]\n    
R02 = R[:, 0, 2]\n    R10 = R[:, 1, 0]\n    R11 = R[:, 1, 1]\n    R12 = R[:, 1, 2]\n    R20 = R[:, 2, 0]\n    R21 = R[:, 2, 1]\n    R22 = R[:, 2, 2]\n\n    q[:,0]=torch.sqrt(1.0+R00+R11+R22)/2\n    q[:, 1]=(R21-R12)/(4*q[:,0])\n    q[:, 2] = (R02 - R20) / (4 * q[:, 0])\n    q[:, 3] = (R10 - R01) / (4 * q[:, 0])\n    return q\n\n\ndef get_sphere_intersections(cam_loc, ray_directions, r = 1.0):\n    # Input: n_rays x 3 ; n_rays x 3\n    # Output: n_rays x 1, n_rays x 1 (close and far)\n\n    ray_cam_dot = torch.bmm(ray_directions.view(-1, 1, 3),\n                            cam_loc.view(-1, 3, 1)).squeeze(-1)\n    under_sqrt = ray_cam_dot ** 2 - (cam_loc.norm(2, 1, keepdim=True) ** 2 - r ** 2)\n\n    # sanity check\n    if (under_sqrt <= 0).sum() > 0:\n        print('BOUNDING SPHERE PROBLEM!')\n        exit()\n\n    sphere_intersections = torch.sqrt(under_sqrt) * torch.Tensor([-1, 1]).cuda().float() - ray_cam_dot\n    sphere_intersections = sphere_intersections.clamp_min(0.0)\n\n    return sphere_intersections"
  },
  {
    "path": "evaluation/eval_tnt.py",
    "content": "import os\nimport trimesh\nimport argparse\nimport numpy as np\nimport open3d as o3d\nfrom sklearn.neighbors import KDTree\n\n\ndef nn_correspondance(verts1, verts2):\n    indices = []\n    distances = []\n    if len(verts1) == 0 or len(verts2) == 0:\n        return indices, distances\n\n    kdtree = KDTree(verts1)\n    distances, indices = kdtree.query(verts2)\n    distances = distances.reshape(-1)\n\n    return distances\n\n\ndef evaluate(mesh_pred, mesh_trgt, threshold=.05, down_sample=.02):\n    pcd_trgt = o3d.geometry.PointCloud()\n    pcd_pred = o3d.geometry.PointCloud()\n    \n    pcd_trgt.points = o3d.utility.Vector3dVector(mesh_trgt.vertices[:, :3])\n    pcd_pred.points = o3d.utility.Vector3dVector(mesh_pred.vertices[:, :3])\n\n    if down_sample:\n        pcd_pred = pcd_pred.voxel_down_sample(down_sample)\n        pcd_trgt = pcd_trgt.voxel_down_sample(down_sample)\n\n    verts_pred = np.asarray(pcd_pred.points)\n    verts_trgt = np.asarray(pcd_trgt.points)\n\n    dist1 = nn_correspondance(verts_pred, verts_trgt)\n    dist2 = nn_correspondance(verts_trgt, verts_pred)\n\n    precision = np.mean((dist2 < threshold).astype('float'))\n    recal = np.mean((dist1 < threshold).astype('float'))\n    fscore = 2 * precision * recal / (precision + recal)\n    metrics = {\n        'Acc': np.mean(dist2),\n        'Comp': np.mean(dist1),\n        'Prec': precision,\n        'Recal': recal,\n        'F-score': fscore,\n    }\n    return metrics\n\n\ndef main(args):\n    assert os.path.exists(args.ply_path), f\"PLY file {args.ply_path} does not exist.\"\n    \n    mesh_rec = trimesh.load(args.ply_path, process=False)\n    mesh_gt = trimesh.load(args.gt_path, process=False)\n    \n    to_align, _ = trimesh.bounds.oriented_bounds(mesh_gt)\n    mesh_gt.vertices = (to_align[:3, :3] @ mesh_gt.vertices.T + to_align[:3, 3:]).T\n    mesh_rec.vertices = (to_align[:3, :3] @ mesh_rec.vertices.T + to_align[:3, 3:]).T\n    \n    min_points = 
mesh_gt.vertices.min(axis=0)\n    max_points = mesh_gt.vertices.max(axis=0)\n\n    mask_min = (mesh_rec.vertices - min_points[None]) > 0\n    mask_max = (mesh_rec.vertices - max_points[None]) < 0\n\n    mask = np.concatenate((mask_min, mask_max), axis=1).all(axis=1)\n    face_mask = mask[mesh_rec.faces].all(axis=1)\n\n    mesh_rec.update_vertices(mask)\n    mesh_rec.update_faces(face_mask)\n    \n    metrics = evaluate(mesh_rec, mesh_gt)\n    \n    metrics_path = os.path.join(os.path.dirname(args.ply_path), 'metrics.txt')\n    with open(metrics_path, 'w') as f:\n        for k, v in metrics.items():\n            f.write(f'{k}: {v}\\n')\n    \n    print('Scene: {} F-score: {}'.format(args.scene, metrics['F-score']))\n    \n    mesh_rec.vertices = (to_align[:3, :3].T @ mesh_rec.vertices.T - to_align[:3, :3].T @ to_align[:3, 3:]).T\n    mesh_rec.export(args.ply_path.replace('.ply', '_crop.ply'))\n    \n    return\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--gt_path\",\n        type=str,\n        default='/your/path//Barn_GT.ply',\n        help=\"path to the ground-truth ply file\",\n    )\n    parser.add_argument(\n        \"--ply_path\",\n        type=str,\n        default='/your/path//Barn_lowres.ply',\n        help=\"path to reconstruction ply file\",\n    )\n    parser.add_argument(\n        \"--scene\",\n        type=str,\n        default='Barn',\n        help=\"scene name, used when printing the F-score\",\n    )\n    args = parser.parse_args()\n    \n    main(args)\n\n"
  },
  {
    "path": "evaluation/full_eval.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport os\nfrom argparse import ArgumentParser\n\nmipnerf360_outdoor_scenes = [\"bicycle\", \"flowers\", \"garden\", \"stump\", \"treehill\"]\nmipnerf360_indoor_scenes = [\"room\", \"counter\", \"kitchen\", \"bonsai\"]\ntanks_and_temples_scenes = [\"truck\", \"train\"]\ndeep_blending_scenes = [\"drjohnson\", \"playroom\"]\n\nparser = ArgumentParser(description=\"Full evaluation script parameters\")\nparser.add_argument(\"--skip_training\", action=\"store_true\")\nparser.add_argument(\"--skip_rendering\", action=\"store_true\")\nparser.add_argument(\"--skip_metrics\", action=\"store_true\")\nparser.add_argument(\"--output_path\", default=\"./eval\")\nargs, _ = parser.parse_known_args()\n\nall_scenes = []\nall_scenes.extend(mipnerf360_outdoor_scenes)\nall_scenes.extend(mipnerf360_indoor_scenes)\nall_scenes.extend(tanks_and_temples_scenes)\nall_scenes.extend(deep_blending_scenes)\n\nif not args.skip_training or not args.skip_rendering:\n    parser.add_argument('--mipnerf360', \"-m360\", required=True, type=str)\n    parser.add_argument(\"--tanksandtemples\", \"-tat\", required=True, type=str)\n    parser.add_argument(\"--deepblending\", \"-db\", required=True, type=str)\n    args = parser.parse_args()\n\nif not args.skip_training:\n    common_args = \" --quiet --eval --test_iterations -1 \"\n    for scene in mipnerf360_outdoor_scenes:\n        source = args.mipnerf360 + \"/\" + scene\n        os.system(\"python train.py -s \" + source + \" -i images_4 -m \" + args.output_path + \"/\" + scene + common_args)\n    for scene in mipnerf360_indoor_scenes:\n        source = args.mipnerf360 + \"/\" + scene\n        os.system(\"python train.py -s \" + source + \" -i 
images_2 -m \" + args.output_path + \"/\" + scene + common_args)\n    for scene in tanks_and_temples_scenes:\n        source = args.tanksandtemples + \"/\" + scene\n        os.system(\"python train.py -s \" + source + \" -m \" + args.output_path + \"/\" + scene + common_args)\n    for scene in deep_blending_scenes:\n        source = args.deepblending + \"/\" + scene\n        os.system(\"python train.py -s \" + source + \" -m \" + args.output_path + \"/\" + scene + common_args)\n\nif not args.skip_rendering:\n    all_sources = []\n    for scene in mipnerf360_outdoor_scenes:\n        all_sources.append(args.mipnerf360 + \"/\" + scene)\n    for scene in mipnerf360_indoor_scenes:\n        all_sources.append(args.mipnerf360 + \"/\" + scene)\n    for scene in tanks_and_temples_scenes:\n        all_sources.append(args.tanksandtemples + \"/\" + scene)\n    for scene in deep_blending_scenes:\n        all_sources.append(args.deepblending + \"/\" + scene)\n\n    common_args = \" --quiet --eval --skip_train\"\n    for scene, source in zip(all_scenes, all_sources):\n        os.system(\"python render.py --iteration 7000 -s \" + source + \" -m \" + args.output_path + \"/\" + scene + common_args)\n        os.system(\"python render.py --iteration 30000 -s \" + source + \" -m \" + args.output_path + \"/\" + scene + common_args)\n\nif not args.skip_metrics:\n    scenes_string = \"\"\n    for scene in all_scenes:\n        scenes_string += \"\\\"\" + args.output_path + \"/\" + scene + \"\\\" \"\n\n    os.system(\"python metrics.py -m \" + scenes_string)"
  },
  {
    "path": "evaluation/lpipsPyTorch/__init__.py",
    "content": "import torch\n\nfrom .modules.lpips import LPIPS\n\n\ndef lpips(x: torch.Tensor,\n          y: torch.Tensor,\n          net_type: str = 'alex',\n          version: str = '0.1'):\n    r\"\"\"Function that measures\n    Learned Perceptual Image Patch Similarity (LPIPS).\n\n    Arguments:\n        x, y (torch.Tensor): the input tensors to compare.\n        net_type (str): the network type to compare the features: \n                        'alex' | 'squeeze' | 'vgg'. Default: 'alex'.\n        version (str): the version of LPIPS. Default: 0.1.\n    \"\"\"\n    device = x.device\n    criterion = LPIPS(net_type, version).to(device)\n    return criterion(x, y)\n"
  },
  {
    "path": "evaluation/lpipsPyTorch/modules/lpips.py",
    "content": "import torch\nimport torch.nn as nn\n\nfrom .networks import get_network, LinLayers\nfrom .utils import get_state_dict\n\n\nclass LPIPS(nn.Module):\n    r\"\"\"Creates a criterion that measures\n    Learned Perceptual Image Patch Similarity (LPIPS).\n\n    Arguments:\n        net_type (str): the network type to compare the features: \n                        'alex' | 'squeeze' | 'vgg'. Default: 'alex'.\n        version (str): the version of LPIPS. Default: 0.1.\n    \"\"\"\n    def __init__(self, net_type: str = 'alex', version: str = '0.1'):\n\n        assert version in ['0.1'], 'v0.1 is only supported now'\n\n        super(LPIPS, self).__init__()\n\n        # pretrained network\n        self.net = get_network(net_type)\n\n        # linear layers\n        self.lin = LinLayers(self.net.n_channels_list)\n        self.lin.load_state_dict(get_state_dict(net_type, version))\n\n    def forward(self, x: torch.Tensor, y: torch.Tensor):\n        feat_x, feat_y = self.net(x), self.net(y)\n\n        diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]\n        res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]\n\n        return torch.sum(torch.cat(res, 0), 0, True)\n"
  },
  {
    "path": "evaluation/lpipsPyTorch/modules/networks.py",
    "content": "from typing import Sequence\n\nfrom itertools import chain\n\nimport torch\nimport torch.nn as nn\nfrom torchvision import models\n\nfrom .utils import normalize_activation\n\n\ndef get_network(net_type: str):\n    if net_type == 'alex':\n        return AlexNet()\n    elif net_type == 'squeeze':\n        return SqueezeNet()\n    elif net_type == 'vgg':\n        return VGG16()\n    else:\n        raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')\n\n\nclass LinLayers(nn.ModuleList):\n    def __init__(self, n_channels_list: Sequence[int]):\n        super(LinLayers, self).__init__([\n            nn.Sequential(\n                nn.Identity(),\n                nn.Conv2d(nc, 1, 1, 1, 0, bias=False)\n            ) for nc in n_channels_list\n        ])\n\n        for param in self.parameters():\n            param.requires_grad = False\n\n\nclass BaseNet(nn.Module):\n    def __init__(self):\n        super(BaseNet, self).__init__()\n\n        # register buffer\n        self.register_buffer(\n            'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])\n        self.register_buffer(\n            'std', torch.Tensor([.458, .448, .450])[None, :, None, None])\n\n    def set_requires_grad(self, state: bool):\n        for param in chain(self.parameters(), self.buffers()):\n            param.requires_grad = state\n\n    def z_score(self, x: torch.Tensor):\n        return (x - self.mean) / self.std\n\n    def forward(self, x: torch.Tensor):\n        x = self.z_score(x)\n\n        output = []\n        for i, (_, layer) in enumerate(self.layers._modules.items(), 1):\n            x = layer(x)\n            if i in self.target_layers:\n                output.append(normalize_activation(x))\n            if len(output) == len(self.target_layers):\n                break\n        return output\n\n\nclass SqueezeNet(BaseNet):\n    def __init__(self):\n        super(SqueezeNet, self).__init__()\n\n        self.layers = 
models.squeezenet1_1(weights=models.SqueezeNet1_1_Weights.IMAGENET1K_V1).features\n        self.target_layers = [2, 5, 8, 10, 11, 12, 13]\n        self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]\n\n        self.set_requires_grad(False)\n\n\nclass AlexNet(BaseNet):\n    def __init__(self):\n        super(AlexNet, self).__init__()\n\n        self.layers = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1).features\n        self.target_layers = [2, 5, 8, 10, 12]\n        self.n_channels_list = [64, 192, 384, 256, 256]\n\n        self.set_requires_grad(False)\n\n\nclass VGG16(BaseNet):\n    def __init__(self):\n        super(VGG16, self).__init__()\n\n        self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features\n        self.target_layers = [4, 9, 16, 23, 30]\n        self.n_channels_list = [64, 128, 256, 512, 512]\n\n        self.set_requires_grad(False)\n"
  },
  {
    "path": "evaluation/lpipsPyTorch/modules/utils.py",
    "content": "from collections import OrderedDict\n\nimport torch\n\n\ndef normalize_activation(x, eps=1e-10):\n    norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))\n    return x / (norm_factor + eps)\n\n\ndef get_state_dict(net_type: str = 'alex', version: str = '0.1'):\n    # build url\n    url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \\\n        + f'master/lpips/weights/v{version}/{net_type}.pth'\n\n    # download\n    old_state_dict = torch.hub.load_state_dict_from_url(\n        url, progress=True,\n        map_location=None if torch.cuda.is_available() else torch.device('cpu')\n    )\n\n    # rename keys\n    new_state_dict = OrderedDict()\n    for key, val in old_state_dict.items():\n        new_key = key\n        new_key = new_key.replace('lin', '')\n        new_key = new_key.replace('model.', '')\n        new_state_dict[new_key] = val\n\n    return new_state_dict\n"
  },
  {
    "path": "evaluation/metrics.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport os\nimport sys\nimport json\nimport torch\nfrom PIL import Image\nfrom tqdm import tqdm\nfrom pathlib import Path\nimport torchvision.transforms.functional as tf\nsys.path.append(os.getcwd())\n\nfrom tools.loss_utils import ssim\nfrom lpipsPyTorch import lpips\nfrom tools.image_utils import psnr\nfrom argparse import ArgumentParser\nfrom configs.config import Config\nfrom tools.general_utils import set_random_seed\n\n\ndef readImages(renders_dir, gt_dir):\n    renders = []\n    gts = []\n    image_names = []\n    for fname in os.listdir(renders_dir):\n        render = Image.open(renders_dir / fname)\n        gt = Image.open(gt_dir / fname)\n        renders.append(tf.to_tensor(render).unsqueeze(0)[:, :3, :, :].cuda())\n        gts.append(tf.to_tensor(gt).unsqueeze(0)[:, :3, :, :].cuda())\n        image_names.append(fname)\n    return renders, gts, image_names\n\ndef evaluate(model_paths):\n\n    full_dict = {}\n    per_view_dict = {}\n    full_dict_polytopeonly = {}\n    per_view_dict_polytopeonly = {}\n    print(\"\")\n\n    for scene_dir in model_paths:\n        try:\n            print(\"Scene:\", scene_dir)\n            full_dict[scene_dir] = {}\n            per_view_dict[scene_dir] = {}\n            full_dict_polytopeonly[scene_dir] = {}\n            per_view_dict_polytopeonly[scene_dir] = {}\n\n            test_dir = Path(scene_dir) / \"test\"\n\n            for method in os.listdir(test_dir):\n                print(\"Method:\", method)\n\n                full_dict[scene_dir][method] = {}\n                per_view_dict[scene_dir][method] = {}\n                full_dict_polytopeonly[scene_dir][method] = {}\n                
per_view_dict_polytopeonly[scene_dir][method] = {}\n\n                method_dir = test_dir / method\n                gt_dir = method_dir / \"gt\"\n                renders_dir = method_dir / \"renders\"\n                renders, gts, image_names = readImages(renders_dir, gt_dir)\n\n                ssims = []\n                psnrs = []\n                lpipss = []\n\n                for idx in tqdm(range(len(renders)), desc=\"Metric evaluation progress\"):\n                    ssims.append(ssim(renders[idx], gts[idx]))\n                    psnrs.append(psnr(renders[idx], gts[idx]))\n                    lpipss.append(lpips(renders[idx], gts[idx], net_type='vgg'))\n\n\n                full_dict[scene_dir][method].update({\"SSIM\": torch.tensor(ssims).mean().item(),\n                                                        \"PSNR\": torch.tensor(psnrs).mean().item(),\n                                                        \"LPIPS\": torch.tensor(lpipss).mean().item()})\n                per_view_dict[scene_dir][method].update({\"SSIM\": {name: ssim for ssim, name in zip(torch.tensor(ssims).tolist(), image_names)},\n                                                            \"PSNR\": {name: psnr for psnr, name in zip(torch.tensor(psnrs).tolist(), image_names)},\n                                                            \"LPIPS\": {name: lp for lp, name in zip(torch.tensor(lpipss).tolist(), image_names)}})\n\n            with open(scene_dir + \"/results.json\", 'w') as fp:\n                json.dump(full_dict[scene_dir], fp, indent=True)\n            with open(scene_dir + \"/per_view.json\", 'w') as fp:\n                json.dump(per_view_dict[scene_dir], fp, indent=True)\n        except Exception as e:\n            # Report the underlying error instead of silently swallowing it.\n            print(\"Unable to compute metrics for model\", scene_dir, \"-\", e)\n\nif __name__ == \"__main__\":\n    device = torch.device(\"cuda:0\")\n    torch.cuda.set_device(device)\n\n    # Set up command line argument parser\n    parser = ArgumentParser(description=\"Metrics script 
parameters\")\n    parser.add_argument('--cfg_path', type=str, default='configs/config_base.yaml')\n    args = parser.parse_args()\n    \n    cfg = Config(args.cfg_path)\n    cfg.model.data_device = 'cpu'\n    cfg.model.load_normal = False\n    \n    set_random_seed(cfg.seed)\n    \n    evaluate([cfg.model.model_path])\n"
  },
  {
    "path": "evaluation/render.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport os\nimport sys\nimport torch\nimport torchvision\nfrom tqdm import tqdm\nfrom argparse import ArgumentParser\nsys.path.append(os.getcwd())\n\nfrom scene import Scene\nfrom gaussian_renderer import render, render_fast\nfrom gaussian_renderer import GaussianModel\nfrom configs.config import Config\nfrom tools.general_utils import set_random_seed\nfrom tools.loss_utils import cos_weight\n\n\ndef render_set(model_path, name, iteration, views, gaussians, cfg, background):\n    render_path = os.path.join(model_path, name, \"ours_{}\".format(iteration), \"renders\")\n    gts_path = os.path.join(model_path, name, \"ours_{}\".format(iteration), \"gt\")\n\n    os.makedirs(render_path, exist_ok=True)\n    os.makedirs(gts_path, exist_ok=True)\n    alphas = []\n\n    for idx, view in enumerate(tqdm(views, desc=\"Rendering progress\")):\n        outs = render(view, gaussians, cfg, background)\n        # outs = render_fast(view, gaussians, cfg, background)\n        \n        rendering = outs[\"render\"]\n        gt = view.original_image[0:3, :, :]\n        torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + \".png\"))\n        torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + \".png\"))\n        alphas.append(outs[\"alpha\"].detach().clone().view(-1).cpu())\n        \n        if False:\n            normal_map = outs[\"normal\"].detach().clone()\n            normal_gt = view.normal.cuda()\n            cos = cos_weight(normal_gt, normal_map, cfg.optim.exp_t, cfg.optim.cos_thr)\n            torchvision.utils.save_image(cos, os.path.join(render_path, '{0:05d}_cosine'.format(idx) + \".png\"))\n    \n    
# alphas = torch.cat(alphas, dim=0)\n    # print(\"Alpha min: {}, max: {}\".format(alphas.min(), alphas.max()))\n    # print(\"Alpha mean: {}, std: {}\".format(alphas.mean(), alphas.std()))\n    # print(\"Alpha median: {}\".format(alphas.median()))\n\n\ndef render_sets(cfg, iteration : int, skip_train : bool, skip_test : bool):\n    with torch.no_grad():\n        gaussians = GaussianModel(cfg.model)\n        scene = Scene(cfg.model, gaussians, load_iteration=iteration, shuffle=False)\n        # gaussians.extent = scene.cameras_extent\n\n        bg_color = [1,1,1] if cfg.model.white_background else [0, 0, 0]\n        background = torch.tensor(bg_color, dtype=torch.float32, device=\"cuda\")\n\n        if not skip_train:\n            render_set(cfg.model.model_path, \"train\", scene.loaded_iter, scene.getTrainCameras(), gaussians, cfg, background)\n\n        if not skip_test:\n            render_set(cfg.model.model_path, \"test\", scene.loaded_iter, scene.getTestCameras(), gaussians, cfg, background)\n            \n\nif __name__ == \"__main__\":\n    # Set up command line argument parser\n    parser = ArgumentParser()\n    parser.add_argument('--cfg_path', type=str, default='configs/config_base.yaml')\n    parser.add_argument(\"--iteration\", default=-1, type=int)\n    parser.add_argument(\"--skip_train\", action=\"store_true\")\n    parser.add_argument(\"--skip_test\", action=\"store_true\")\n    args = parser.parse_args()\n    \n    cfg = Config(args.cfg_path)\n    cfg.model.data_device = 'cuda'\n    cfg.model.load_normal = False\n    cfg.model.load_mask = False\n    \n    set_random_seed(cfg.seed)\n\n    # Initialize system state (RNG)\n    # safe_state(args.quiet)\n\n    render_sets(cfg, args.iteration, args.skip_train, args.skip_test)"
  },
  {
    "path": "evaluation/tnt_eval/README.md",
    "content": "# Python Toolbox for Evaluation\n\nThis Python script evaluates **training** dataset of TanksAndTemples benchmark.\nThe script requires ``Open3D`` and a few Python packages such as ``matplotlib``, ``json``, and ``numpy``.\n\n## How to use:\n**Step 0**. Reconstruct 3D models and recover camera poses from the training dataset.\nThe raw videos of the training dataset can be found from:\nhttps://tanksandtemples.org/download/\n\n**Step 1**. Download evaluation data (ground truth geometry + reference reconstruction) using\n[this link](https://drive.google.com/open?id=1UoKPiUUsKa0AVHFOrnMRhc5hFngjkE-t). In this example, we regard ``TanksAndTemples/evaluation/data/`` as a dataset folder.\n\n**Step 2**. Install Open3D. Follow instructions in http://open3d.org/docs/getting_started.html\n\n**Step 3**. Run the evaluation script and grab some coffee.\n```\npython run.py --dataset-dir path/to/TanksAndTemples/evaluation/data/Ignatius --traj-path path/to/TanksAndTemples/evaluation/data/Ignatius/Ignatius_COLMAP_SfM.log --ply-path path/to/TanksAndTemples/evaluation/data/Ignatius/Ignatius_COLMAP.ply\n```\nOutput (evaluation of Ignatius):\n```\n===========================\nEvaluating Ignatius\n===========================\npath/to/TanksAndTemples/evaluation/data/Ignatius/Ignatius_COLMAP.ply\nReading PLY: [========================================] 100%\nRead PointCloud: 6929586 vertices.\npath/to/TanksAndTemples/evaluation/data/Ignatius/Ignatius.ply\nReading PLY: [========================================] 100%\n:\nICP Iteration #0: Fitness 0.9980, RMSE 0.0044\nICP Iteration #1: Fitness 0.9980, RMSE 0.0043\nICP Iteration #2: Fitness 0.9980, RMSE 0.0043\nICP Iteration #3: Fitness 0.9980, RMSE 0.0043\nICP Iteration #4: Fitness 0.9980, RMSE 0.0042\nICP Iteration #5: Fitness 0.9980, RMSE 0.0042\nICP Iteration #6: Fitness 0.9979, RMSE 0.0042\nICP Iteration #7: Fitness 0.9979, RMSE 0.0042\nICP Iteration #8: Fitness 0.9979, RMSE 0.0042\nICP Iteration #9: Fitness 0.9979, RMSE 
0.0042\nICP Iteration #10: Fitness 0.9979, RMSE 0.0042\n[EvaluateHisto]\nCropping geometry: [========================================] 100%\nPointcloud down sampled from 6929586 points to 1449840 points.\nPointcloud down sampled from 1449840 points to 1365628 points.\npath/to/TanksAndTemples/evaluation/data/Ignatius/evaluation//Ignatius.precision.ply\nCropping geometry: [========================================] 100%\nPointcloud down sampled from 5016769 points to 4957123 points.\nPointcloud down sampled from 4957123 points to 4181506 points.\n[compute_point_cloud_to_point_cloud_distance]\n[compute_point_cloud_to_point_cloud_distance]\n:\n[ViewDistances] Add color coding to visualize error\n[ViewDistances] Add color coding to visualize error\n:\n[get_f1_score_histo2]\n==============================\nevaluation result : Ignatius\n==============================\ndistance tau : 0.003\nprecision : 0.7679\nrecall : 0.7937\nf-score : 0.7806\n==============================\n```\n\n**Step 4**. Go to the evaluation folder. ``TanksAndTemples/evaluation/data/Ignatius/evaluation/`` will have the following outputs.\n\n<img src=\"images/f-score.jpg\" width=\"400\">\n\n``PR_Ignatius_@d_th_0_0030.pdf`` (Precision and recall curves with an F-score)\n\n| <img src=\"images/precision.jpg\" width=\"200\"> | <img src=\"images/recall.jpg\" width=\"200\"> |\n|--|--|\n| ``Ignatius.precision.ply``  | ``Ignatius.recall.ply`` |\n\n(3D visualization of precision and recall. Each mesh is color coded using the hot colormap)\n\n# Requirements\n\n- Python 3\n- open3d v0.9.0\n- matplotlib\n"
  },
  {
    "path": "evaluation/tnt_eval/config.py",
    "content": "# ----------------------------------------------------------------------------\n# -                   TanksAndTemples Website Toolbox                        -\n# -                    http://www.tanksandtemples.org                        -\n# ----------------------------------------------------------------------------\n# The MIT License (MIT)\n#\n# Copyright (c) 2017\n# Arno Knapitsch <arno.knapitsch@gmail.com >\n# Jaesik Park <syncle@gmail.com>\n# Qian-Yi Zhou <Qianyi.Zhou@gmail.com>\n# Vladlen Koltun <vkoltun@gmail.com>\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in\n# all copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\n# THE SOFTWARE.\n# ----------------------------------------------------------------------------\n\n# some global parameters - do not modify\nscenes_tau_dict = {\n    \"Barn\": 0.01,\n    \"Caterpillar\": 0.005,\n    \"Church\": 0.025,\n    \"Courthouse\": 0.025,\n    \"Ignatius\": 0.003,\n    \"Meetingroom\": 0.01,\n    \"Truck\": 0.005,\n}\n"
  },
  {
    "path": "evaluation/tnt_eval/evaluation.py",
    "content": "# ----------------------------------------------------------------------------\n# -                   TanksAndTemples Website Toolbox                        -\n# -                    http://www.tanksandtemples.org                        -\n# ----------------------------------------------------------------------------\n# The MIT License (MIT)\n#\n# Copyright (c) 2017\n# Arno Knapitsch <arno.knapitsch@gmail.com >\n# Jaesik Park <syncle@gmail.com>\n# Qian-Yi Zhou <Qianyi.Zhou@gmail.com>\n# Vladlen Koltun <vkoltun@gmail.com>\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in\n# all copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\n# THE SOFTWARE.\n# ----------------------------------------------------------------------------\n#\n# This python script is for downloading dataset from www.tanksandtemples.org\n# The dataset has a different license, please refer to\n# https://tanksandtemples.org/license/\n\nimport json\nimport copy\nimport os\nimport numpy as np\nimport open3d as o3d\nimport matplotlib.pyplot as plt\n\n\ndef read_alignment_transformation(filename):\n    with open(filename) as data_file:\n        data = json.load(data_file)\n    return np.asarray(data[\"transformation\"]).reshape((4, 4)).transpose()\n\n\ndef write_color_distances(path, pcd, distances, max_distance):\n    o3d.utility.set_verbosity_level(o3d.utility.VerbosityLevel.Debug)\n    # cmap = plt.get_cmap(\"afmhot\")\n    cmap = plt.get_cmap(\"hot_r\")\n    distances = np.array(distances)\n    colors = cmap(np.minimum(distances, max_distance) / max_distance)[:, :3]\n    pcd.colors = o3d.utility.Vector3dVector(colors)\n    o3d.io.write_point_cloud(path, pcd)\n\n\ndef EvaluateHisto(\n    source,\n    target,\n    trans,\n    crop_volume,\n    voxel_size,\n    threshold,\n    filename_mvs,\n    plot_stretch,\n    scene_name,\n    verbose=True,\n):\n    print(\"[EvaluateHisto]\")\n    o3d.utility.set_verbosity_level(o3d.utility.VerbosityLevel.Debug)\n    s = copy.deepcopy(source)\n    s.transform(trans)\n    s = crop_volume.crop_point_cloud(s)\n    s = s.voxel_down_sample(voxel_size)\n    s.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamKNN(knn=20))\n    print(filename_mvs + \"/\" + scene_name + \".precision.ply\")\n\n    t = copy.deepcopy(target)\n    t = crop_volume.crop_point_cloud(t)\n    t = t.voxel_down_sample(voxel_size)\n    
t.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamKNN(knn=20))\n    print(\"[compute_point_cloud_to_point_cloud_distance]\")\n    distance1 = s.compute_point_cloud_distance(t)\n    print(\"[compute_point_cloud_to_point_cloud_distance]\")\n    distance2 = t.compute_point_cloud_distance(s)\n\n    # write the distances to bin files\n    # np.array(distance1).astype(\"float64\").tofile(\n    #     filename_mvs + \"/\" + scene_name + \".precision.bin\"\n    # )\n    # np.array(distance2).astype(\"float64\").tofile(\n    #     filename_mvs + \"/\" + scene_name + \".recall.bin\"\n    # )\n\n    # Colorize the point cloud files with the precision and recall values\n    # o3d.io.write_point_cloud(\n    #     filename_mvs + \"/\" + scene_name + \".precision.ply\", s\n    # )\n    # o3d.io.write_point_cloud(\n    #     filename_mvs + \"/\" + scene_name + \".precision.ncb.ply\", s\n    # )\n    # o3d.io.write_point_cloud(filename_mvs + \"/\" + scene_name + \".recall.ply\", t)\n\n    source_n_fn = filename_mvs + \"/\" + scene_name + \".precision.ply\"\n    target_n_fn = filename_mvs + \"/\" + scene_name + \".recall.ply\"\n\n    print(\"[ViewDistances] Add color coding to visualize error\")\n    # eval_str_viewDT = (\n    #     OPEN3D_EXPERIMENTAL_BIN_PATH\n    #     + \"ViewDistances \"\n    #     + source_n_fn\n    #     + \" --max_distance \"\n    #     + str(threshold * 3)\n    #     + \" --write_color_back --without_gui\"\n    # )\n    # os.system(eval_str_viewDT)\n    write_color_distances(source_n_fn, s, distance1, 3 * threshold)\n\n    print(\"[ViewDistances] Add color coding to visualize error\")\n    # eval_str_viewDT = (\n    #     OPEN3D_EXPERIMENTAL_BIN_PATH\n    #     + \"ViewDistances \"\n    #     + target_n_fn\n    #     + \" --max_distance \"\n    #     + str(threshold * 3)\n    #     + \" --write_color_back --without_gui\"\n    # )\n    # os.system(eval_str_viewDT)\n    write_color_distances(target_n_fn, t, distance2, 3 * threshold)\n\n    # get 
histogram and f-score\n    [\n        precision,\n        recall,\n        fscore,\n        edges_source,\n        cum_source,\n        edges_target,\n        cum_target,\n    ] = get_f1_score_histo2(threshold, filename_mvs, plot_stretch, distance1,\n                            distance2)\n    np.savetxt(filename_mvs + \"/\" + scene_name + \".recall.txt\", cum_target)\n    np.savetxt(filename_mvs + \"/\" + scene_name + \".precision.txt\", cum_source)\n    np.savetxt(\n        filename_mvs + \"/\" + scene_name + \".prf_tau_plotstr.txt\",\n        np.array([precision, recall, fscore, threshold, plot_stretch]),\n    )\n\n    return [\n        precision,\n        recall,\n        fscore,\n        edges_source,\n        cum_source,\n        edges_target,\n        cum_target,\n    ]\n\n\ndef get_f1_score_histo2(threshold,\n                        filename_mvs,\n                        plot_stretch,\n                        distance1,\n                        distance2,\n                        verbose=True):\n    print(\"[get_f1_score_histo2]\")\n    dist_threshold = threshold\n    if len(distance1) and len(distance2):\n\n        recall = float(sum(d < threshold for d in distance2)) / float(\n            len(distance2))\n        precision = float(sum(d < threshold for d in distance1)) / float(\n            len(distance1))\n        # Avoid division by zero when no point lies within the threshold.\n        if recall + precision > 0:\n            fscore = 2 * recall * precision / (recall + precision)\n        else:\n            fscore = 0.0\n        num = len(distance1)\n        bins = np.arange(0, dist_threshold * plot_stretch, dist_threshold / 100)\n        hist, edges_source = np.histogram(distance1, bins)\n        cum_source = np.cumsum(hist).astype(float) / num\n\n        num = len(distance2)\n        bins = np.arange(0, dist_threshold * plot_stretch, dist_threshold / 100)\n        hist, edges_target = np.histogram(distance2, bins)\n        cum_target = np.cumsum(hist).astype(float) / num\n\n    else:\n        precision = 0\n        recall = 0\n        fscore = 0\n        edges_source = np.array([0])\n        cum_source = 
np.array([0])\n        edges_target = np.array([0])\n        cum_target = np.array([0])\n\n    return [\n        precision,\n        recall,\n        fscore,\n        edges_source,\n        cum_source,\n        edges_target,\n        cum_target,\n    ]\n"
  },
  {
    "path": "evaluation/tnt_eval/plot.py",
    "content": "# ----------------------------------------------------------------------------\n# -                   TanksAndTemples Website Toolbox                        -\n# -                    http://www.tanksandtemples.org                        -\n# ----------------------------------------------------------------------------\n# The MIT License (MIT)\n#\n# Copyright (c) 2017\n# Arno Knapitsch <arno.knapitsch@gmail.com >\n# Jaesik Park <syncle@gmail.com>\n# Qian-Yi Zhou <Qianyi.Zhou@gmail.com>\n# Vladlen Koltun <vkoltun@gmail.com>\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in\n# all copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\n# THE SOFTWARE.\n# ----------------------------------------------------------------------------\n#\n# This python script is for downloading dataset from www.tanksandtemples.org\n# The dataset has a different license, please refer to\n# https://tanksandtemples.org/license/\n\nimport matplotlib.pyplot as plt\nfrom cycler import cycler\n\n\ndef plot_graph(\n    scene,\n    fscore,\n    dist_threshold,\n    edges_source,\n    cum_source,\n    edges_target,\n    cum_target,\n    plot_stretch,\n    mvs_outpath,\n    show_figure=False,\n):\n    f = plt.figure()\n    plt_size = [14, 7]\n    pfontsize = \"medium\"\n\n    ax = plt.subplot(111)\n    label_str = \"precision\"\n    ax.plot(\n        edges_source[1::],\n        cum_source * 100,\n        c=\"red\",\n        label=label_str,\n        linewidth=2.0,\n    )\n\n    label_str = \"recall\"\n    ax.plot(\n        edges_target[1::],\n        cum_target * 100,\n        c=\"blue\",\n        label=label_str,\n        linewidth=2.0,\n    )\n\n    ax.grid(True)\n    plt.rcParams[\"figure.figsize\"] = plt_size\n    plt.rc(\"axes\", prop_cycle=cycler(\"color\", [\"r\", \"g\", \"b\", \"y\"]))\n    plt.title(\"Precision and Recall: \" + scene + \", \" + \"%02.2f f-score\" %\n              (fscore * 100))\n    plt.axvline(x=dist_threshold, c=\"black\", ls=\"dashed\", linewidth=2.0)\n\n    plt.ylabel(\"# of points (%)\", fontsize=15)\n    plt.xlabel(\"Meters\", fontsize=15)\n    plt.axis([0, dist_threshold * plot_stretch, 0, 100])\n    ax.legend(shadow=True, fancybox=True, fontsize=pfontsize)\n    # plt.axis([0, dist_threshold*plot_stretch, 0, 100])\n\n    plt.setp(ax.get_legend().get_texts(), fontsize=pfontsize)\n\n    plt.legend(loc=2, borderaxespad=0.0, fontsize=pfontsize)\n    
box = ax.get_position()\n    ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])\n\n    # Put a legend to the right of the current axis; this call replaces any\n    # legend created earlier, so no further plt.legend() calls are needed\n    ax.legend(loc=\"center left\", bbox_to_anchor=(1, 0.5))\n    plt.setp(ax.get_legend().get_texts(), fontsize=pfontsize)\n    png_name = mvs_outpath + \"/PR_{0}_@d_th_0_{1}.png\".format(\n        scene, \"%04d\" % (dist_threshold * 10000))\n    pdf_name = mvs_outpath + \"/PR_{0}_@d_th_0_{1}.pdf\".format(\n        scene, \"%04d\" % (dist_threshold * 10000))\n\n    # save figure and display\n    f.savefig(png_name, format=\"png\", bbox_inches=\"tight\")\n    f.savefig(pdf_name, format=\"pdf\", bbox_inches=\"tight\")\n    if show_figure:\n        plt.show()\n"
  },
  {
    "path": "evaluation/tnt_eval/registration.py",
    "content": "# ----------------------------------------------------------------------------\n# -                   TanksAndTemples Website Toolbox                        -\n# -                    http://www.tanksandtemples.org                        -\n# ----------------------------------------------------------------------------\n# The MIT License (MIT)\n#\n# Copyright (c) 2017\n# Arno Knapitsch <arno.knapitsch@gmail.com >\n# Jaesik Park <syncle@gmail.com>\n# Qian-Yi Zhou <Qianyi.Zhou@gmail.com>\n# Vladlen Koltun <vkoltun@gmail.com>\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in\n# all copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\n# THE SOFTWARE.\n# ----------------------------------------------------------------------------\n#\n# This python script is for downloading dataset from www.tanksandtemples.org\n# The dataset has a different license, please refer to\n# https://tanksandtemples.org/license/\n\nfrom trajectory_io import read_trajectory, convert_trajectory_to_pointcloud\nimport copy\nimport numpy as np\nimport open3d as o3d\n\nMAX_POINT_NUMBER = 4e6\n\n\ndef read_mapping(filename):\n    mapping = []\n    with open(filename, \"r\") as f:\n        n_sampled_frames = int(f.readline())\n        n_total_frames = int(f.readline())\n        mapping = np.zeros(shape=(n_sampled_frames, 2))\n        metastr = f.readline()\n        for iter in range(n_sampled_frames):\n            metadata = list(map(int, metastr.split()))\n            mapping[iter, :] = metadata\n            metastr = f.readline()\n    return [n_sampled_frames, n_total_frames, mapping]\n\n\ndef gen_sparse_trajectory(mapping, f_trajectory):\n    sparse_traj = []\n    for m in mapping:\n        sparse_traj.append(f_trajectory[int(m[1] - 1)])\n    return sparse_traj\n\n\ndef trajectory_alignment(map_file, traj_to_register, gt_traj_col, gt_trans,\n                         scene):\n    traj_pcd_col = convert_trajectory_to_pointcloud(gt_traj_col)\n    traj_pcd_col.transform(gt_trans)\n    corres = o3d.utility.Vector2iVector(\n        np.asarray(list(map(lambda x: [x, x], range(len(gt_traj_col))))))\n    rr = o3d.registration.RANSACConvergenceCriteria()\n    rr.max_iteration = 100000\n    rr.max_validation = 100000\n\n    # in this case a log file was used which contains\n    # every movie frame (see tutorial for details)\n    if len(traj_to_register) > 1600:\n        n_sampled_frames, 
n_total_frames, mapping = read_mapping(map_file)\n        traj_col2 = gen_sparse_trajectory(mapping, traj_to_register)\n        traj_to_register_pcd = convert_trajectory_to_pointcloud(traj_col2)\n    else:\n        traj_to_register_pcd = convert_trajectory_to_pointcloud(\n            traj_to_register)\n    randomvar = 0.0\n    nr_of_cam_pos = len(traj_to_register_pcd.points)\n    rand_number_added = np.asanyarray(traj_to_register_pcd.points) * (\n        np.random.rand(nr_of_cam_pos, 3) * randomvar - randomvar / 2.0 + 1)\n    list_rand = list(rand_number_added)\n    traj_to_register_pcd_rand = o3d.geometry.PointCloud()\n    for elem in list_rand:\n        traj_to_register_pcd_rand.points.append(elem)\n\n    # Rough registration based on aligned colmap SfM data\n    reg = o3d.registration.registration_ransac_based_on_correspondence(\n        traj_to_register_pcd_rand,\n        traj_pcd_col,\n        corres,\n        0.2,\n        o3d.registration.TransformationEstimationPointToPoint(True),\n        6,\n        rr,\n    )\n    return reg.transformation\n\n\ndef crop_and_downsample(\n        pcd,\n        crop_volume,\n        down_sample_method=\"voxel\",\n        voxel_size=0.01,\n        trans=np.identity(4),\n):\n    pcd_copy = copy.deepcopy(pcd)\n    pcd_copy.transform(trans)\n    pcd_crop = crop_volume.crop_point_cloud(pcd_copy)\n    if down_sample_method == \"voxel\":\n        # return voxel_down_sample(pcd_crop, voxel_size)\n        return pcd_crop.voxel_down_sample(voxel_size)\n    elif down_sample_method == \"uniform\":\n        n_points = len(pcd_crop.points)\n        if n_points > MAX_POINT_NUMBER:\n            ds_rate = int(round(n_points / float(MAX_POINT_NUMBER)))\n            return pcd_crop.uniform_down_sample(ds_rate)\n    return pcd_crop\n\n\ndef registration_unif(\n    source,\n    gt_target,\n    init_trans,\n    crop_volume,\n    threshold,\n    max_itr,\n    max_size=4 * MAX_POINT_NUMBER,\n    verbose=True,\n):\n    if verbose:\n        
print(\"[Registration] threshold: %f\" % threshold)\n        o3d.utility.set_verbosity_level(o3d.utility.VerbosityLevel.Debug)\n    s = crop_and_downsample(source,\n                            crop_volume,\n                            down_sample_method=\"uniform\",\n                            trans=init_trans)\n    t = crop_and_downsample(gt_target,\n                            crop_volume,\n                            down_sample_method=\"uniform\")\n    reg = o3d.registration.registration_icp(\n        s,\n        t,\n        threshold,\n        np.identity(4),\n        o3d.registration.TransformationEstimationPointToPoint(True),\n        o3d.registration.ICPConvergenceCriteria(1e-6, max_itr),\n    )\n    reg.transformation = np.matmul(reg.transformation, init_trans)\n    return reg\n\n\ndef registration_vol_ds(\n    source,\n    gt_target,\n    init_trans,\n    crop_volume,\n    voxel_size,\n    threshold,\n    max_itr,\n    verbose=True,\n):\n    if verbose:\n        print(\"[Registration] voxel_size: %f, threshold: %f\" %\n              (voxel_size, threshold))\n        o3d.utility.set_verbosity_level(o3d.utility.VerbosityLevel.Debug)\n    s = crop_and_downsample(\n        source,\n        crop_volume,\n        down_sample_method=\"voxel\",\n        voxel_size=voxel_size,\n        trans=init_trans,\n    )\n    t = crop_and_downsample(\n        gt_target,\n        crop_volume,\n        down_sample_method=\"voxel\",\n        voxel_size=voxel_size,\n    )\n    \n    s = crop_based_target(s, t)\n    \n    reg = o3d.registration.registration_icp(\n        s,\n        t,\n        threshold,\n        np.identity(4),\n        o3d.registration.TransformationEstimationPointToPoint(True),\n        o3d.registration.ICPConvergenceCriteria(1e-6, max_itr),\n    )\n    reg.transformation = np.matmul(reg.transformation, init_trans)\n    return reg\n\n\ndef crop_based_target(s, t):\n    bbox_t = t.get_axis_aligned_bounding_box()\n\n    min_bound = bbox_t.get_min_bound()\n    
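max_bound = bbox_t.get_max_bound()\n\n    # Convert once to an (N, 3) array so the bound checks are plain NumPy comparisons\n    points = np.asarray(s.points)\n    valid = np.logical_and(np.all(points >= min_bound, axis=1),\n                           np.all(points <= max_bound, axis=1))\n\n    # Keep only the source points inside the target's axis-aligned bounding box\n    s_filtered = o3d.geometry.PointCloud()\n    s_filtered.points = o3d.utility.Vector3dVector(points[valid])\n    return s_filtered\n"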
  },
  {
    "path": "evaluation/tnt_eval/requirements.txt",
    "content": "matplotlib>=1.3\nopen3d==0.9\n"
  },
  {
    "path": "evaluation/tnt_eval/run.py",
    "content": "# ----------------------------------------------------------------------------\n# -                   TanksAndTemples Website Toolbox                        -\n# -                    http://www.tanksandtemples.org                        -\n# ----------------------------------------------------------------------------\n# The MIT License (MIT)\n#\n# Copyright (c) 2017\n# Arno Knapitsch <arno.knapitsch@gmail.com >\n# Jaesik Park <syncle@gmail.com>\n# Qian-Yi Zhou <Qianyi.Zhou@gmail.com>\n# Vladlen Koltun <vkoltun@gmail.com>\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in\n# all copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\n# THE SOFTWARE.\n# ----------------------------------------------------------------------------\n#\n# This python script is for downloading dataset from www.tanksandtemples.org\n# The dataset has a different license, please refer to\n# https://tanksandtemples.org/license/\n\n# this script requires Open3D python binding\n# please follow the intructions in setup.py before running this script.\nimport numpy as np\nimport open3d as o3d\nimport os\nimport argparse\nimport sys\nsys.path.append(os.getcwd())\n\nfrom config import scenes_tau_dict\nfrom registration import (\n    trajectory_alignment,\n    registration_vol_ds,\n    registration_unif,\n    read_trajectory,\n)\nfrom evaluation import EvaluateHisto\nfrom util import make_dir\nfrom plot import plot_graph\n\n\n\ndef run_evaluation(dataset_dir, traj_path, ply_path, out_dir):\n    scene = os.path.basename(os.path.normpath(dataset_dir))\n\n    if scene not in scenes_tau_dict:\n        print(dataset_dir, scene)\n        raise Exception(\"invalid dataset-dir, not in scenes_tau_dict\")\n\n    print(\"\")\n    print(\"===========================\")\n    print(\"Evaluating %s\" % scene)\n    print(\"===========================\")\n\n    dTau = scenes_tau_dict[scene]\n    # put the crop-file, the GT file, the COLMAP SfM log file and\n    # the alignment of the according scene in a folder of\n    # the same scene name in the dataset_dir\n    colmap_ref_logfile = os.path.join(dataset_dir, scene + \"_COLMAP_SfM.log\")\n    alignment = os.path.join(dataset_dir, scene + \"_trans.txt\")\n    gt_filen = os.path.join(dataset_dir, scene + \".ply\")\n    # gt_filen = os.path.join(dataset_dir, scene + \"_GT.ply\")\n    cropfile = os.path.join(dataset_dir, scene + \".json\")\n    
map_file = os.path.join(dataset_dir, scene + \"_mapping_reference.txt\")\n\n    make_dir(out_dir)\n    \n    assert os.path.exists(ply_path), f\"ply_path {ply_path} does not exist\"\n\n    # Load reconstruction and according GT\n    print(gt_filen)\n    gt_pcd = o3d.io.read_point_cloud(gt_filen)\n    print(ply_path)\n    # pcd = o3d.io.read_point_cloud(ply_path)\n    mesh = o3d.io.read_triangle_mesh(ply_path)\n    pcd = mesh.sample_points_uniformly(len(gt_pcd.points))\n\n    gt_trans = np.loadtxt(alignment)\n    traj_to_register = read_trajectory(traj_path)\n    gt_traj_col = read_trajectory(colmap_ref_logfile)\n\n    trajectory_transform = trajectory_alignment(map_file, traj_to_register,\n                                                gt_traj_col, gt_trans, scene)\n\n    # Refine alignment by using the actual GT and MVS pointclouds\n    vol = o3d.visualization.read_selection_polygon_volume(cropfile)\n    # big pointclouds will be downlsampled to this number to speed up alignment\n    dist_threshold = dTau\n\n    # Registration refinment in 3 iterations\n    r2 = registration_vol_ds(pcd, gt_pcd, trajectory_transform, vol, dTau,\n                             dTau * 80, 20)\n    r3 = registration_vol_ds(pcd, gt_pcd, r2.transformation, vol, dTau / 2.0,\n                             dTau * 20, 20)\n    r = registration_unif(pcd, gt_pcd, r3.transformation, vol, 2 * dTau, 20)\n\n    # Histogramms and P/R/F1\n    plot_stretch = 5\n    [\n        precision,\n        recall,\n        fscore,\n        edges_source,\n        cum_source,\n        edges_target,\n        cum_target,\n    ] = EvaluateHisto(\n        pcd,\n        gt_pcd,\n        r.transformation,\n        vol,\n        dTau / 2.0,\n        dTau,\n        out_dir,\n        plot_stretch,\n        scene,\n    )\n    eva = [precision, recall, fscore]\n    # eva = [i*100 for i in eva]\n    print(\"==============================\")\n    print(\"evaluation result : %s\" % scene)\n    
print(\"==============================\")\n    print(\"distance tau : %.3f\" % dTau)\n    print(\"precision : %.4f\" % eva[0])\n    print(\"recall : %.4f\" % eva[1])\n    print(\"f-score : %.4f\" % eva[2])\n    print(\"==============================\")\n    \n    with open(os.path.join(out_dir, \"evaluation.txt\"), \"w\") as f:\n        f.write(\"evaluation result : %s\\n\" % scene)\n        f.write(\"distance tau : %.3f\\n\" % dTau)\n        f.write(\"precision : %.4f\\n\" % eva[0])\n        f.write(\"recall : %.4f\\n\" % eva[1])\n        f.write(\"f-score : %.4f\\n\" % eva[2])\n\n    # Plotting\n    plot_graph(\n        scene,\n        fscore,\n        dist_threshold,\n        edges_source,\n        cum_source,\n        edges_target,\n        cum_target,\n        plot_stretch,\n        out_dir,\n    )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--dataset-dir\",\n        type=str,\n        required=True,\n        help=\"path to a dataset/scene directory containing X.json, X.ply, ...\",\n    )\n    parser.add_argument(\n        \"--traj-path\",\n        type=str,\n        required=True,\n        help=\n        \"path to trajectory file. 
See `convert_to_logfile.py` to create this file.\",\n    )\n    parser.add_argument(\n        \"--ply-path\",\n        type=str,\n        required=True,\n        help=\"path to reconstruction ply file\",\n    )\n    parser.add_argument(\n        \"--out-dir\",\n        type=str,\n        default=\"\",\n        help=\n        \"output directory, default: an evaluation directory is created in the directory of the ply file\",\n    )\n    args = parser.parse_args()\n\n    if args.out_dir.strip() == \"\":\n        args.out_dir = os.path.join(os.path.dirname(args.ply_path),\n                                    \"evaluation\")\n\n    run_evaluation(\n        dataset_dir=args.dataset_dir,\n        traj_path=args.traj_path,\n        ply_path=args.ply_path,\n        out_dir=args.out_dir,\n    )\n"
  },
  {
    "path": "evaluation/tnt_eval/trajectory_io.py",
    "content": "import numpy as np\nimport open3d as o3d\n\n\nclass CameraPose:\n\n    def __init__(self, meta, mat):\n        self.metadata = meta\n        self.pose = mat\n\n    def __str__(self):\n        return (\"Metadata : \" + \" \".join(map(str, self.metadata)) + \"\\n\" +\n                \"Pose : \" + \"\\n\" + np.array_str(self.pose))\n\n\ndef convert_trajectory_to_pointcloud(traj):\n    pcd = o3d.geometry.PointCloud()\n    for t in traj:\n        pcd.points.append(t.pose[:3, 3])\n    return pcd\n\n\ndef read_trajectory(filename):\n    traj = []\n    with open(filename, \"r\") as f:\n        metastr = f.readline()\n        while metastr:\n            # Materialize as a list: a bare map() iterator is exhausted after one\n            # pass, so the metadata would come out empty on a second print or write\n            metadata = list(map(int, metastr.split()))\n            mat = np.zeros(shape=(4, 4))\n            for i in range(4):\n                matstr = f.readline()\n                mat[i, :] = np.fromstring(matstr, dtype=float, sep=\" \\t\")\n            traj.append(CameraPose(metadata, mat))\n            metastr = f.readline()\n    return traj\n\n\ndef write_trajectory(traj, filename):\n    with open(filename, \"w\") as f:\n        for x in traj:\n            p = x.pose.tolist()\n            f.write(\" \".join(map(str, x.metadata)) + \"\\n\")\n            f.write(\"\\n\".join(\n                \" \".join(map(\"{0:.12f}\".format, p[i])) for i in range(4)))\n            f.write(\"\\n\")\n"
  },
  {
    "path": "evaluation/tnt_eval/util.py",
    "content": "import os\n\n\ndef make_dir(path):\n    # exist_ok=True avoids a race between the existence check and the creation\n    os.makedirs(path, exist_ok=True)\n"
  },
  {
    "path": "gaussian_renderer/__init__.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport math\nimport torch\nimport torch.nn.functional as F\n\nfrom diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer\nfrom scene.gaussian_model import GaussianModel\nfrom tools.sh_utils import eval_sh\nfrom tools.normal_utils import compute_normals\n\n\ndef render(viewpoint_camera, pc : GaussianModel, cfg, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None, \n           return_normal = True, is_all = True, dirs=None, mask_depth_thr=0.8):\n    \"\"\"\n    Render the scene. \n    \n    Background tensor (bg_color) must be on GPU!\n    \"\"\"\n\n    # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means\n    screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device=\"cuda\") + 0\n    screenspace_points_densify = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device=\"cuda\") + 0\n    try:\n        screenspace_points.retain_grad()\n        screenspace_points_densify.retain_grad()\n    except:\n        pass\n\n    # Set up rasterization configuration\n    tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)\n    tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)\n\n    raster_settings = GaussianRasterizationSettings(\n        image_height=int(viewpoint_camera.image_height),\n        image_width=int(viewpoint_camera.image_width),\n        tanfovx=tanfovx,\n        tanfovy=tanfovy,\n        bg=bg_color,\n        scale_modifier=scaling_modifier,\n        viewmatrix=viewpoint_camera.world_view_transform,\n        projmatrix=viewpoint_camera.full_proj_transform,\n        sh_degree=pc.active_sh_degree,\n       
 campos=viewpoint_camera.camera_center,\n        prefiltered=False,\n        debug=cfg.pipline.debug,\n        f_count=0,\n    )\n\n    rasterizer = GaussianRasterizer(raster_settings=raster_settings)\n\n    means3D = pc.get_xyz\n    means2D = screenspace_points\n    means2D_densify = screenspace_points_densify\n    opacity = pc.get_opacity\n\n    # If precomputed 3d covariance is provided, use it. If not, then it will be computed from\n    # scaling / rotation by the rasterizer.\n    scales = None\n    rotations = None\n    cov3D_precomp = None\n    if cfg.pipline.compute_cov3D_python:\n        cov3D_precomp = pc.get_covariance(scaling_modifier)\n    else:\n        scales = pc.get_scaling\n        rotations = pc.get_rotation\n\n    # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors\n    # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.\n    shs = None\n    colors_precomp = None\n    if override_color is None:\n        if cfg.pipline.convert_SHs_python:\n            shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2)\n            dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1))\n            dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True)\n            sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)\n            colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)\n        else:\n            shs = pc.get_features\n    else:\n        colors_precomp = override_color\n\n    normals_precomp = None\n    # inside, _ = pc.get_inside_gaus_normalized()\n    if return_normal:\n        normal = pc.get_normal(is_all=is_all)\n        # convert normal direction to the camera; calculate the normal in the camera coordinate\n        view_dir = means3D - viewpoint_camera.camera_center\n        normal   = normal * ((((view_dir * normal).sum(dim=-1) > 0) * 1 - 0.5) * 2)[..., None]\n        R_w2c = 
torch.tensor(viewpoint_camera.R.T).cuda().to(torch.float32)\n        normals_precomp = normal @ R_w2c.transpose(0, 1)        # camera coordinate\n    \n    sem_feats = pc.get_objects.squeeze(1) if cfg.optim.loss_weight.semantic > 0 else None\n    inside = None\n\n    # Rasterize visible Gaussians to image, obtain their radii (on screen). \n    rendered_out, radii = rasterizer(\n        means3D = means3D,\n        means2D = means2D,\n        means2D_densify = means2D_densify,\n        shs = shs,\n        colors_precomp = colors_precomp,\n        normals_precomp = normals_precomp,\n        semantics_precomp = sem_feats,\n        opacities = opacity,\n        scales = scales,\n        rotations = rotations,\n        cov3D_precomp = cov3D_precomp,\n        dirs = dirs,\n        inside = inside)\n    \n    chs = [3, 1, 3, 1]\n    rendered_image, rendered_depth, rendered_normal, rendered_alpha = rendered_out[:sum(chs)].split(chs, dim=0)\n    \n    with torch.no_grad():\n        mask = viewpoint_camera.mask.bool() if hasattr(viewpoint_camera, 'mask') else \\\n                torch.ones_like(rendered_depth, dtype=torch.bool).squeeze(0)\n        if cfg.optim.mask_depth_thr > 0:\n            mask1 = rendered_depth < (pc.extent * cfg.optim.mask_depth_thr)\n            mask1 = mask1.squeeze(0)\n            mask = mask & mask1\n    \n    rendered_normal = rendered_normal.permute(1, 2, 0)\n    rendered_normal = F.normalize(rendered_normal, dim = -1)\n    \n    est_normal = compute_normals(rendered_depth, viewpoint_camera.intr)\n\n    out = {\"render\": rendered_image,\n            \"depth\": rendered_depth,\n            \"normal\": rendered_normal,\n            \"est_normal\": est_normal,\n            \"alpha\": rendered_alpha,\n            \"viewspace_points\": screenspace_points,\n            \"viewspace_points_densify\": screenspace_points_densify,\n            \"visibility_filter\" : radii > 0,\n            \"mask\": mask,\n            \"radii\": radii,}\n\n    if 
cfg.optim.loss_weight.semantic > 0:\n        rendered_sem = rendered_out[sum(chs):sum(chs)+cfg.model.ch_sem_feat]\n        rendered_sem = pc.classifier(rendered_sem[None])[0].permute(1, 2, 0)    # [H, W, cls]\n        out.update({\"render_sem\": rendered_sem})\n    \n    if hasattr(cfg.optim.loss_weight, 'depth_var') and cfg.optim.loss_weight.depth_var > 0:\n        d1 = rendered_out[-2:-1]\n        d2 = rendered_out[-1:]\n        depth_var = d2 / rendered_alpha - (d1 / rendered_alpha) ** 2\n        out.update({\"depth_var\": depth_var})\n    \n    if hasattr(cfg.optim.loss_weight, 'distortion') and cfg.optim.loss_weight.distortion > 0:\n        rendered_dist = rendered_out[-1:]\n        out.update({\"distortion\": rendered_dist})\n\n    return out\n\n\ndef render_fast(viewpoint_camera, pc : GaussianModel, cfg, bg_color : torch.Tensor, scaling_modifier = 1.0, override_color = None):\n    \"\"\"\n    use the original Gaussian Splatting cuda code!!!!\n    \"\"\"\n \n    # Create zero tensor. 
We will use it to make pytorch return gradients of the 2D (screen-space) means\n    screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device=\"cuda\") + 0\n    try:\n        screenspace_points.retain_grad()\n    except:\n        pass\n\n    # Set up rasterization configuration\n    tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)\n    tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)\n\n    raster_settings = GaussianRasterizationSettings(\n        image_height=int(viewpoint_camera.image_height),\n        image_width=int(viewpoint_camera.image_width),\n        tanfovx=tanfovx,\n        tanfovy=tanfovy,\n        bg=bg_color,\n        scale_modifier=scaling_modifier,\n        viewmatrix=viewpoint_camera.world_view_transform,\n        projmatrix=viewpoint_camera.full_proj_transform,\n        sh_degree=pc.active_sh_degree,\n        campos=viewpoint_camera.camera_center,\n        prefiltered=False,\n        debug=cfg.pipline.debug\n    )\n\n    rasterizer = GaussianRasterizer(raster_settings=raster_settings)\n\n    means3D = pc.get_xyz\n    means2D = screenspace_points\n    opacity = pc.get_opacity\n\n    # If precomputed 3d covariance is provided, use it. If not, then it will be computed from\n    # scaling / rotation by the rasterizer.\n    scales = None\n    rotations = None\n    cov3D_precomp = None\n    if cfg.pipline.compute_cov3D_python:\n        cov3D_precomp = pc.get_covariance(scaling_modifier)\n    else:\n        scales = pc.get_scaling\n        rotations = pc.get_rotation\n\n    # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors\n    # from SHs in Python, do it. 
If not, then SH -> RGB conversion will be done by rasterizer.\n    shs = None\n    colors_precomp = None\n    if override_color is None:\n        if cfg.pipline.convert_SHs_python:\n            shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2)\n            dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1))\n            dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True)\n            sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)\n            colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)\n        else:\n            shs = pc.get_features\n    else:\n        colors_precomp = override_color\n\n    # Rasterize visible Gaussians to image, obtain their radii (on screen). \n    rendered_image, radii = rasterizer(\n        means3D = means3D,\n        means2D = means2D,\n        shs = shs,\n        colors_precomp = colors_precomp,\n        opacities = opacity,\n        scales = scales,\n        rotations = rotations,\n        cov3D_precomp = cov3D_precomp)\n\n    # Those Gaussians that were frustum culled or had a radius of 0 were not visible.\n    # They will be excluded from value updates used in the splitting criteria.\n    return {\"render\": rendered_image,\n            \"viewspace_points\": screenspace_points,\n            \"visibility_filter\" : radii > 0,\n            \"radii\": radii}\n\n\ndef count_render(\n    viewpoint_camera,\n    pc: GaussianModel,\n    pipe,\n    bg_color: torch.Tensor,\n    scaling_modifier=1.0,\n    override_color=None,\n):\n    \"\"\"\n    Render the scene.\n\n    Background tensor (bg_color) must be on GPU!\n    \"\"\"\n    # Create zero tensor. 
We will use it to make pytorch return gradients of the 2D (screen-space) means\n    screenspace_points = (\n        torch.zeros_like(\n            pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device=\"cuda\"\n        )\n        + 0\n    )\n    try:\n        screenspace_points.retain_grad()\n    except:\n        pass\n\n    # Set up rasterization configuration\n    tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)\n    tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)\n\n    raster_settings = GaussianRasterizationSettings(\n        image_height=int(viewpoint_camera.image_height),\n        image_width=int(viewpoint_camera.image_width),\n        tanfovx=tanfovx,\n        tanfovy=tanfovy,\n        bg=bg_color,\n        scale_modifier=scaling_modifier,\n        viewmatrix=viewpoint_camera.world_view_transform,\n        projmatrix=viewpoint_camera.full_proj_transform,\n        sh_degree=pc.active_sh_degree,\n        campos=viewpoint_camera.camera_center,\n        prefiltered=False,\n        debug=pipe.debug,\n        f_count=1,\n    )\n\n    rasterizer = GaussianRasterizer(raster_settings=raster_settings)\n    means3D = pc.get_xyz\n    means2D = screenspace_points\n    opacity = pc.get_opacity\n\n    # If precomputed 3d covariance is provided, use it. If not, then it will be computed from\n    # scaling / rotation by the rasterizer.\n    scales = None\n    rotations = None\n    cov3D_precomp = None\n    if pipe.compute_cov3D_python:\n        cov3D_precomp = pc.get_covariance(scaling_modifier)\n    else:\n        scales = pc.get_scaling\n        rotations = pc.get_rotation\n\n    # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors\n    # from SHs in Python, do it. 
If not, then SH -> RGB conversion will be done by rasterizer.\n    shs = None\n    colors_precomp = None\n    if override_color is None:\n        if pipe.convert_SHs_python:\n            shs_view = pc.get_features.transpose(1, 2).view(\n                -1, 3, (pc.max_sh_degree + 1) ** 2\n            )\n            dir_pp = pc.get_xyz - viewpoint_camera.camera_center.repeat(\n                pc.get_features.shape[0], 1\n            )\n            dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True)\n            sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)\n            colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)\n        else:\n            shs = pc.get_features\n    else:\n        colors_precomp = override_color\n\n    # Rasterize visible Gaussians to image, obtain their radii (on screen).\n    gaussians_count, important_score, rendered_image, radii = rasterizer(\n        means3D=means3D,\n        means2D=means2D,\n        means2D_densify=None,\n        shs=shs,\n        colors_precomp=colors_precomp,\n        normals_precomp = None,\n        semantics_precomp = None,\n        opacities=opacity,\n        scales=scales,\n        rotations=rotations,\n        cov3D_precomp=cov3D_precomp,\n    )\n\n    # Those Gaussians that were frustum culled or had a radius of 0 were not visible.\n    # They will be excluded from value updates used in the splitting criteria.\n    return {\n        \"render\": rendered_image,\n        \"viewspace_points\": screenspace_points,\n        \"visibility_filter\": radii > 0,\n        \"radii\": radii,\n        \"gaussians_count\": gaussians_count,\n        \"important_score\": important_score,\n    }\n\n\ndef visi_render(\n    viewpoint_camera,\n    pc: GaussianModel,\n    pipe,\n    bg_color: torch.Tensor,\n    scaling_modifier=1.0,\n    override_color=None,\n):\n    \"\"\"\n    Render the scene.\n\n    Background tensor (bg_color) must be on GPU!\n    \"\"\"\n    # Create zero tensor. 
We will use it to make pytorch return gradients of the 2D (screen-space) means\n    screenspace_points = (\n        torch.zeros_like(\n            pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device=\"cuda\"\n        )\n        + 0\n    )\n    try:\n        screenspace_points.retain_grad()\n    except:\n        pass\n\n    # Set up rasterization configuration\n    tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)\n    tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)\n\n    raster_settings = GaussianRasterizationSettings(\n        image_height=int(viewpoint_camera.image_height),\n        image_width=int(viewpoint_camera.image_width),\n        tanfovx=tanfovx,\n        tanfovy=tanfovy,\n        bg=bg_color,\n        scale_modifier=scaling_modifier,\n        viewmatrix=viewpoint_camera.world_view_transform,\n        projmatrix=viewpoint_camera.full_proj_transform,\n        sh_degree=pc.active_sh_degree,\n        campos=viewpoint_camera.camera_center,\n        prefiltered=False,\n        debug=pipe.debug,\n        f_count=2,\n    )\n\n    rasterizer = GaussianRasterizer(raster_settings=raster_settings)\n    means3D = pc.get_xyz\n    means2D = screenspace_points\n    opacity = pc.get_opacity\n\n    # If precomputed 3d covariance is provided, use it. If not, then it will be computed from\n    # scaling / rotation by the rasterizer.\n    scales = None\n    rotations = None\n    cov3D_precomp = None\n    if pipe.compute_cov3D_python:\n        cov3D_precomp = pc.get_covariance(scaling_modifier)\n    else:\n        scales = pc.get_scaling\n        rotations = pc.get_rotation\n\n    # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors\n    # from SHs in Python, do it. 
If not, then SH -> RGB conversion will be done by rasterizer.\n    shs = None\n    colors_precomp = None\n    if override_color is None:\n        if pipe.convert_SHs_python:\n            shs_view = pc.get_features.transpose(1, 2).view(\n                -1, 3, (pc.max_sh_degree + 1) ** 2\n            )\n            dir_pp = pc.get_xyz - viewpoint_camera.camera_center.repeat(\n                pc.get_features.shape[0], 1\n            )\n            dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True)\n            sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)\n            colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)\n        else:\n            shs = pc.get_features\n    else:\n        colors_precomp = override_color\n\n    # Rasterize visible Gaussians to image, obtain their radii (on screen).\n    \n    countlist, important_score, rendered_image, radii = rasterizer(\n            means3D=means3D,\n            means2D=means2D,\n            means2D_densify=None,\n            shs=shs,\n            colors_precomp=colors_precomp,\n            normals_precomp = None,\n            semantics_precomp = None,\n            opacities=opacity,\n            scales=scales,\n            rotations=rotations,\n            cov3D_precomp=cov3D_precomp,\n        )\n\n    # Those Gaussians that were frustum culled or had a radius of 0 were not visible.\n    # They will be excluded from value updates used in the splitting criteria.\n    return {\n        \"render\": rendered_image,\n        \"viewspace_points\": screenspace_points,\n        \"visibility_filter\": radii > 0,\n        \"radii\": radii,\n        \"countlist\": countlist,\n        \"important_score\": important_score,\n    }\n\n\ndef visi_acc_render(\n    viewpoint_camera,\n    pc: GaussianModel,\n    pipe,\n    bg_color: torch.Tensor,\n    scaling_modifier=1.0,\n    override_color=None,\n):\n    \"\"\"\n    Render the scene.\n\n    Background tensor (bg_color) must be on GPU!\n    \"\"\"\n   
 # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means\n    screenspace_points = (\n        torch.zeros_like(\n            pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device=\"cuda\"\n        )\n        + 0\n    )\n    try:\n        screenspace_points.retain_grad()\n    except:\n        pass\n\n    # Set up rasterization configuration\n    tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)\n    tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)\n\n    raster_settings = GaussianRasterizationSettings(\n        image_height=int(viewpoint_camera.image_height),\n        image_width=int(viewpoint_camera.image_width),\n        tanfovx=tanfovx,\n        tanfovy=tanfovy,\n        bg=bg_color,\n        scale_modifier=scaling_modifier,\n        viewmatrix=viewpoint_camera.world_view_transform,\n        projmatrix=viewpoint_camera.full_proj_transform,\n        sh_degree=pc.active_sh_degree,\n        campos=viewpoint_camera.camera_center,\n        prefiltered=False,\n        debug=pipe.debug,\n        f_count=3,\n    )\n\n    rasterizer = GaussianRasterizer(raster_settings=raster_settings)\n    means3D = pc.get_xyz\n    means2D = screenspace_points\n    opacity = pc.get_opacity\n\n    # If precomputed 3d covariance is provided, use it. If not, then it will be computed from\n    # scaling / rotation by the rasterizer.\n    scales = None\n    rotations = None\n    cov3D_precomp = None\n    if pipe.compute_cov3D_python:\n        cov3D_precomp = pc.get_covariance(scaling_modifier)\n    else:\n        scales = pc.get_scaling\n        rotations = pc.get_rotation\n\n    # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors\n    # from SHs in Python, do it. 
If not, then SH -> RGB conversion will be done by rasterizer.\n    shs = None\n    colors_precomp = None\n    if override_color is None:\n        if pipe.convert_SHs_python:\n            shs_view = pc.get_features.transpose(1, 2).view(\n                -1, 3, (pc.max_sh_degree + 1) ** 2\n            )\n            dir_pp = pc.get_xyz - viewpoint_camera.camera_center.repeat(\n                pc.get_features.shape[0], 1\n            )\n            dir_pp_normalized = dir_pp / dir_pp.norm(dim=1, keepdim=True)\n            sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)\n            colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)\n        else:\n            shs = pc.get_features\n    else:\n        colors_precomp = override_color\n\n    # Rasterize visible Gaussians to image, obtain their radii (on screen).\n    \n    countlist, radii = rasterizer(\n            means3D=means3D,\n            means2D=means2D,\n            means2D_densify=None,\n            shs=shs,\n            colors_precomp=colors_precomp,\n            normals_precomp = None,\n            semantics_precomp = None,\n            opacities=opacity,\n            scales=scales,\n            rotations=rotations,\n            cov3D_precomp=cov3D_precomp,\n        )\n\n    # Those Gaussians that were frustum culled or had a radius of 0 were not visible.\n    # They will be excluded from value updates used in the splitting criteria.\n    return {\n        \"viewspace_points\": screenspace_points,\n        \"visibility_filter\": radii > 0,\n        \"radii\": radii,\n        \"countlist\": countlist,\n    }\n\n"
  },
  {
    "path": "gaussian_renderer/network_gui.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport torch\nimport traceback\nimport socket\nimport json\nfrom scene.cameras import MiniCam\n\nhost = \"127.0.0.1\"\nport = 6009\n\nconn = None\naddr = None\n\nlistener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n\ndef init(wish_host, wish_port):\n    global host, port, listener\n    host = wish_host\n    port = wish_port\n    listener.bind((host, port))\n    listener.listen()\n    listener.settimeout(0)\n\ndef try_connect():\n    global conn, addr, listener\n    try:\n        conn, addr = listener.accept()\n        print(f\"\\nConnected by {addr}\")\n        conn.settimeout(None)\n    except Exception as inst:\n        pass\n            \ndef read():\n    global conn\n    messageLength = conn.recv(4)\n    messageLength = int.from_bytes(messageLength, 'little')\n    message = conn.recv(messageLength)\n    return json.loads(message.decode(\"utf-8\"))\n\ndef send(message_bytes, verify):\n    global conn\n    if message_bytes != None:\n        conn.sendall(message_bytes)\n    conn.sendall(len(verify).to_bytes(4, 'little'))\n    conn.sendall(bytes(verify, 'ascii'))\n\ndef receive():\n    message = read()\n\n    width = message[\"resolution_x\"]\n    height = message[\"resolution_y\"]\n\n    if width != 0 and height != 0:\n        try:\n            do_training = bool(message[\"train\"])\n            fovy = message[\"fov_y\"]\n            fovx = message[\"fov_x\"]\n            znear = message[\"z_near\"]\n            zfar = message[\"z_far\"]\n            do_shs_python = bool(message[\"shs_python\"])\n            do_rot_scale_python = bool(message[\"rot_scale_python\"])\n            keep_alive = bool(message[\"keep_alive\"])\n            
scaling_modifier = message[\"scaling_modifier\"]\n            world_view_transform = torch.reshape(torch.tensor(message[\"view_matrix\"]), (4, 4)).cuda()\n            world_view_transform[:,1] = -world_view_transform[:,1]\n            world_view_transform[:,2] = -world_view_transform[:,2]\n            full_proj_transform = torch.reshape(torch.tensor(message[\"view_projection_matrix\"]), (4, 4)).cuda()\n            full_proj_transform[:,1] = -full_proj_transform[:,1]\n            custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform)\n        except Exception as e:\n            print(\"\")\n            traceback.print_exc()\n            raise e\n        return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier\n    else:\n        return None, None, None, None, None, None"
  },
  {
    "path": "process_data/convert.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use\n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport os\nimport json\nimport logging\nfrom argparse import ArgumentParser\nimport shutil\nimport sys\nimport importlib\n\nsys.path.append(os.getcwd())\n\n\ndef create_init_files(pinhole_dict_file, db_file, out_dir):\n    # Partially adapted from https://github.com/Kai-46/nerfplusplus/blob/master/colmap_runner/run_colmap_posed.py\n    # COLMAPDatabase = getattr(importlib.import_module(f'{args.colmap_path}.scripts.python.database'), 'COLMAPDatabase')\n    from submodules.colmap.scripts.python.database import COLMAPDatabase  # NOQA\n\n    if not os.path.exists(out_dir):\n        os.mkdir(out_dir)\n\n    # create template\n    with open(pinhole_dict_file) as fp:\n        pinhole_dict = json.load(fp)\n\n    template = {}\n    cameras_line_template = '{camera_id} RADIAL {width} {height} {f} {cx} {cy} {k1} {k2}\\n'\n    images_line_template = '{image_id} {qw} {qx} {qy} {qz} {tx} {ty} {tz} {camera_id} {image_name}\\n\\n'\n\n    for img_name in pinhole_dict:\n        # w, h, fx, fy, cx, cy, qvec, t\n        params = pinhole_dict[img_name]\n        w = params[0]\n        h = params[1]\n        fx = params[2]\n        # fy = params[3]\n        cx = params[4]\n        cy = params[5]\n        qvec = params[6:10]\n        tvec = params[10:13]\n\n        cam_line = cameras_line_template.format(\n            camera_id=\"{camera_id}\", width=w, height=h, f=fx, cx=cx, cy=cy, k1=0, k2=0)\n        img_line = images_line_template.format(image_id=\"{image_id}\", qw=qvec[0], qx=qvec[1], qy=qvec[2], qz=qvec[3],\n                                               tx=tvec[0], ty=tvec[1], tz=tvec[2], camera_id=\"{camera_id}\",\n                                               
image_name=img_name)\n        template[img_name] = (cam_line, img_line)\n\n    # read database\n    db = COLMAPDatabase.connect(db_file)\n    table_images = db.execute(\"SELECT * FROM images\")\n    img_name2id_dict = {}\n    for row in table_images:\n        img_name2id_dict[row[1]] = row[0]\n\n    cameras_txt_lines = [template[img_name][0].format(camera_id=1)]\n    images_txt_lines = []\n    for img_name, img_id in img_name2id_dict.items():\n        image_line = template[img_name][1].format(image_id=img_id, camera_id=1)\n        images_txt_lines.append(image_line)\n\n    with open(os.path.join(out_dir, 'cameras.txt'), 'w') as fp:\n        fp.writelines(cameras_txt_lines)\n\n    with open(os.path.join(out_dir, 'images.txt'), 'w') as fp:\n        fp.writelines(images_txt_lines)\n        fp.write('\\n')\n\n    # create an empty points3D.txt\n    fp = open(os.path.join(out_dir, 'points3D.txt'), 'w')\n    fp.close()\n\n\ndef main(args):\n    colmap_command = '\"{}\"'.format(args.colmap_executable) if len(args.colmap_executable) > 0 else \"colmap\"\n    magick_command = '\"{}\"'.format(args.magick_executable) if len(args.magick_executable) > 0 else \"magick\"\n    use_gpu = 1 if not args.no_gpu else 0\n\n    if not args.skip_matching:\n        os.makedirs(args.source_path + \"/distorted/sparse\", exist_ok=True)\n\n        ## Feature extraction\n        feat_extracton_cmd = colmap_command + \" feature_extractor \"\\\n            \"--database_path \" + args.source_path + \"/distorted/database.db \\\n            --image_path \" + args.source_path + \"/input \\\n            --ImageReader.single_camera 1 \\\n            --ImageReader.camera_model \" + args.camera + \" \\\n            --SiftExtraction.use_gpu \" + str(use_gpu)\n        exit_code = os.system(feat_extracton_cmd)\n        if exit_code != 0:\n            logging.error(f\"Feature extraction failed with code {exit_code}. 
Exiting.\")\n            exit(exit_code)\n\n        ## Feature matching\n        feat_matching_cmd = colmap_command + \" exhaustive_matcher \\\n            --database_path \" + args.source_path + \"/distorted/database.db \\\n            --SiftMatching.use_gpu \" + str(use_gpu)\n        exit_code = os.system(feat_matching_cmd)\n        if exit_code != 0:\n            logging.error(f\"Feature matching failed with code {exit_code}. Exiting.\")\n            exit(exit_code)\n\n        if args.existing_pose:\n            db_file = os.path.join(args.source_path, 'distorted/database.db')\n            sfm_dir = os.path.join(args.source_path, 'distorted/sparse/0')\n            pinhole_dict_file = os.path.join(args.source_path, 'pinhole_dict.json')\n            create_init_files(pinhole_dict_file, db_file, sfm_dir)\n            \n        ### Bundle adjustment\n        # The default Mapper tolerance is unnecessarily large,\n        # decreasing it speeds up bundle adjustment steps.\n        mapper_cmd = (colmap_command + \" mapper \\\n            --database_path \" + args.source_path + \"/distorted/database.db \\\n            --image_path \"  + args.source_path + \"/input \\\n            --output_path \"  + args.source_path + \"/distorted/sparse \\\n            --Mapper.ba_global_function_tolerance=0.000001\")\n        exit_code = os.system(mapper_cmd)\n        if exit_code != 0:\n            logging.error(f\"Mapper failed with code {exit_code}. 
Exiting.\")\n            exit(exit_code)\n\n    if not args.skip_distorting:\n        ### Image undistortion\n        ## We need to undistort our images into ideal pinhole intrinsics.\n        img_undist_cmd = (colmap_command + \" image_undistorter \\\n            --image_path \" + args.source_path + \"/input \\\n            --input_path \" + args.source_path + \"/distorted/sparse/0 \\\n            --output_path \" + args.source_path + \"\\\n            --output_type COLMAP\")\n        exit_code = os.system(img_undist_cmd)\n        if exit_code != 0:\n            logging.error(f\"Image undistortion failed with code {exit_code}. Exiting.\")\n            exit(exit_code)\n\n    files = os.listdir(args.source_path + \"/distorted/sparse/0\")\n    os.makedirs(args.source_path + \"/sparse/0\", exist_ok=True)\n    # Move each file from the source directory to the destination directory\n    for file in files:\n        source_file = os.path.join(args.source_path, \"distorted/sparse/0\", file)\n        destination_file = os.path.join(args.source_path, \"sparse\", \"0\", file)\n        shutil.move(source_file, destination_file)\n\n    if(args.resize):\n        print(\"Copying and resizing...\")\n\n        # Resize images.\n        os.makedirs(args.source_path + \"/images_2\", exist_ok=True)\n        os.makedirs(args.source_path + \"/images_4\", exist_ok=True)\n        os.makedirs(args.source_path + \"/images_8\", exist_ok=True)\n        # Get the list of files in the source directory\n        files = os.listdir(args.source_path + \"/images\")\n        # Copy each file from the source directory to the destination directory\n        for file in files:\n            source_file = os.path.join(args.source_path, \"images\", file)\n\n            destination_file = os.path.join(args.source_path, \"images_2\", file)\n            shutil.copy2(source_file, destination_file)\n            exit_code = os.system(magick_command + \" mogrify -resize 50% \" + destination_file)\n            if exit_code != 
0:\n                logging.error(f\"50% resize failed with code {exit_code}. Exiting.\")\n                exit(exit_code)\n\n            destination_file = os.path.join(args.source_path, \"images_4\", file)\n            shutil.copy2(source_file, destination_file)\n            exit_code = os.system(magick_command + \" mogrify -resize 25% \" + destination_file)\n            if exit_code != 0:\n                logging.error(f\"25% resize failed with code {exit_code}. Exiting.\")\n                exit(exit_code)\n\n            destination_file = os.path.join(args.source_path, \"images_8\", file)\n            shutil.copy2(source_file, destination_file)\n            exit_code = os.system(magick_command + \" mogrify -resize 12.5% \" + destination_file)\n            if exit_code != 0:\n                logging.error(f\"12.5% resize failed with code {exit_code}. Exiting.\")\n                exit(exit_code)\n\n    print(\"Done.\")\n\n\nif __name__ == '__main__':\n    # This Python script is based on the shell converter script provided in the MipNerF 360 repository.\n    parser = ArgumentParser(\"Colmap converter\")\n    parser.add_argument(\"--no_gpu\", action='store_true')\n    parser.add_argument(\"--skip_matching\", action='store_true')\n    parser.add_argument(\"--skip_distorting\", action='store_true')\n    parser.add_argument(\"--source_path\", \"-s\", required=True, type=str)\n    parser.add_argument(\"--camera\", default=\"OPENCV\", type=str)\n    parser.add_argument(\"--colmap_executable\", default=\"\", type=str)\n    parser.add_argument(\"--resize\", action=\"store_true\")\n    parser.add_argument(\"--magick_executable\", default=\"\", type=str)\n    parser.add_argument(\"--existing_pose\", action='store_true')\n    parser.add_argument(\"--colmap_path\", default=\"submodules.colmap\", type=str)\n    args = parser.parse_args()\n\n    main(args)"
  },
  {
    "path": "process_data/convert_360_to_json.py",
    "content": "import os\nimport numpy as np\nimport json\nimport sys\nfrom pathlib import Path\nfrom argparse import ArgumentParser\nimport trimesh\n\ndir_path = Path(os.path.dirname(os.path.realpath(__file__))).parents[0]\nsys.path.append(dir_path.__str__())\n\nfrom process_data.convert_data_to_json import export_to_json, get_split_dict, bound_by_pose  # NOQA\n\nfrom submodules.colmap.scripts.python.database import COLMAPDatabase  # NOQA\nfrom submodules.colmap.scripts.python.read_write_model import read_model, rotmat2qvec  # NOQA\n\n\ndef create_init_files(pinhole_dict_file, db_file, out_dir):\n    # Partially adapted from https://github.com/Kai-46/nerfplusplus/blob/master/colmap_runner/run_colmap_posed.py\n\n    if not os.path.exists(out_dir):\n        os.mkdir(out_dir)\n\n    # create template\n    with open(pinhole_dict_file) as fp:\n        pinhole_dict = json.load(fp)\n\n    template = {}\n    cameras_line_template = '{camera_id} RADIAL {width} {height} {f} {cx} {cy} {k1} {k2}\\n'\n    images_line_template = '{image_id} {qw} {qx} {qy} {qz} {tx} {ty} {tz} {camera_id} {image_name}\\n\\n'\n\n    for img_name in pinhole_dict:\n        # w, h, fx, fy, cx, cy, qvec, t\n        params = pinhole_dict[img_name]\n        w = params[0]\n        h = params[1]\n        fx = params[2]\n        # fy = params[3]\n        cx = params[4]\n        cy = params[5]\n        qvec = params[6:10]\n        tvec = params[10:13]\n\n        cam_line = cameras_line_template.format(\n            camera_id=\"{camera_id}\", width=w, height=h, f=fx, cx=cx, cy=cy, k1=0, k2=0)\n        img_line = images_line_template.format(image_id=\"{image_id}\", qw=qvec[0], qx=qvec[1], qy=qvec[2], qz=qvec[3],\n                                               tx=tvec[0], ty=tvec[1], tz=tvec[2], camera_id=\"{camera_id}\",\n                                               image_name=img_name)\n        template[img_name] = (cam_line, img_line)\n\n    # read database\n    db = COLMAPDatabase.connect(db_file)\n    
table_images = db.execute(\"SELECT * FROM images\")\n    img_name2id_dict = {}\n    for row in table_images:\n        img_name2id_dict[row[1]] = row[0]\n\n    cameras_txt_lines = [template[img_name][0].format(camera_id=1)]\n    images_txt_lines = []\n    for img_name, img_id in img_name2id_dict.items():\n        image_line = template[img_name][1].format(image_id=img_id, camera_id=1)\n        images_txt_lines.append(image_line)\n\n    with open(os.path.join(out_dir, 'cameras.txt'), 'w') as fp:\n        fp.writelines(cameras_txt_lines)\n\n    with open(os.path.join(out_dir, 'images.txt'), 'w') as fp:\n        fp.writelines(images_txt_lines)\n        fp.write('\\n')\n\n    # create an empty points3D.txt\n    fp = open(os.path.join(out_dir, 'points3D.txt'), 'w')\n    fp.close()\n\n\ndef convert_cam_dict_to_pinhole_dict(cam_dict, pinhole_dict_file):\n    # Partially adapted from https://github.com/Kai-46/nerfplusplus/blob/master/colmap_runner/run_colmap_posed.py\n\n    print('Writing pinhole_dict to: ', pinhole_dict_file)\n    h = 1080\n    w = 1920\n\n    pinhole_dict = {}\n    for img_name in cam_dict:\n        W2C = cam_dict[img_name]\n\n        # params\n        fx = 0.6 * w\n        fy = 0.6 * w\n        cx = w / 2.0\n        cy = h / 2.0\n\n        qvec = rotmat2qvec(W2C[:3, :3])\n        tvec = W2C[:3, 3]\n\n        params = [w, h, fx, fy, cx, cy,\n                  qvec[0], qvec[1], qvec[2], qvec[3],\n                  tvec[0], tvec[1], tvec[2]]\n        pinhole_dict[img_name] = params\n\n    with open(pinhole_dict_file, 'w') as fp:\n        json.dump(pinhole_dict, fp, indent=2, sort_keys=True)\n\n\ndef load_COLMAP_poses(cam_file, img_dir, tf='w2c'):\n    # load the image names in img_dir\n    names = sorted(os.listdir(img_dir))\n\n    with open(cam_file) as f:\n        lines = f.readlines()\n\n    # C2W\n    poses = {}\n    for idx, line in enumerate(lines):\n        if idx % 5 == 0:  # header\n            img_idx, valid, _ = line.split(' ')\n            if valid != 
'-1':\n                poses[int(img_idx)] = np.eye(4)\n        else:\n            if int(img_idx) in poses:\n                num = np.array([float(n) for n in line.split(' ')])\n                poses[int(img_idx)][idx % 5-1, :] = num\n\n    if tf == 'c2w':\n        return poses\n    else:\n        # convert to W2C (follow nerf convention)\n        poses_w2c = {}\n        for k, v in poses.items():\n            poses_w2c[names[k]] = np.linalg.inv(v)\n        return poses_w2c\n\n\ndef load_transformation(trans_file):\n    with open(trans_file) as f:\n        lines = f.readlines()\n\n    trans = np.eye(4)\n    for idx, line in enumerate(lines):\n        num = np.array([float(n) for n in line.split(' ')])\n        trans[idx, :] = num\n\n    return trans\n\n\ndef align_gt_with_cam(pts, trans):\n    trans_inv = np.linalg.inv(trans)\n    pts_aligned = pts @ trans_inv[:3, :3].transpose(-1, -2) + trans_inv[:3, -1]\n    return pts_aligned\n\n\ndef main(args):\n    assert args.data_path, \"Provide path to 360 dataset\"\n    scene_list = os.listdir(args.data_path)\n    scene_list = sorted(scene_list)\n\n    for scene in scene_list:\n        scene_path = os.path.join(args.data_path, scene)\n        if not os.path.isdir(scene_path): continue\n        \n        cameras, images, points3D = read_model(os.path.join(scene_path, \"sparse/0\"), ext=\".bin\")\n\n        trans, scale, bounding_box = bound_by_pose(images)\n        trans = trans.tolist()\n        \n        export_to_json(trans, scale, scene_path, 'meta.json')\n        print('Writing data to json file: ', os.path.join(scene_path, 'meta.json'))\n\n\nif __name__ == '__main__':\n    parser = ArgumentParser()\n    parser.add_argument('--data_path', type=str, default=None, help='Path to 360 dataset')\n    parser.add_argument('--run_colmap', action='store_true', help='Run colmap')\n    parser.add_argument('--export_json', action='store_true', help='export json')\n\n    args = 
parser.parse_args()\n\n    main(args)\n"
  },
  {
    "path": "process_data/convert_data_to_json.py",
    "content": "'''\n-----------------------------------------------------------------------------\nCopyright (c) 2023, NVIDIA CORPORATION. All rights reserved.\n\nNVIDIA CORPORATION and its licensors retain all intellectual property\nand proprietary rights in and to this software, related documentation\nand any modifications thereto. Any use, reproduction, disclosure or\ndistribution of this software and related documentation without an express\nlicense agreement from NVIDIA CORPORATION is strictly prohibited.\n-----------------------------------------------------------------------------\n'''\n\nimport numpy as np\nfrom argparse import ArgumentParser\nimport os\nimport sys\nfrom pathlib import Path\nimport json\nimport trimesh\n\ndir_path = Path(os.path.dirname(os.path.realpath(__file__))).parents[0]\nsys.path.append(dir_path.__str__())\n\nfrom submodules.colmap.scripts.python.read_write_model import read_model, qvec2rotmat  # NOQA\n\n\ndef find_closest_point(p1, d1, p2, d2):\n    # Calculate the direction vectors of the lines\n    d1_norm = d1 / np.linalg.norm(d1)\n    d2_norm = d2 / np.linalg.norm(d2)\n\n    # Create the coefficient matrix A and the constant vector b\n    A = np.vstack((d1_norm, -d2_norm)).T\n    b = p2 - p1\n\n    # Solve the linear system to find the parameters t1 and t2\n    t1, t2 = np.linalg.lstsq(A, b, rcond=None)[0]\n\n    # Calculate the closest point on each line\n    closest_point1 = p1 + d1_norm * t1\n    closest_point2 = p2 + d2_norm * t2\n\n    # Calculate the average of the two closest points\n    closest_point = 0.5 * (closest_point1 + closest_point2)\n\n    return closest_point\n\n\ndef bound_by_pose(images):\n    poses = []\n    for img in images.values():\n        rotation = qvec2rotmat(img.qvec)\n        translation = img.tvec.reshape(3, 1)\n        w2c = np.concatenate([rotation, translation], 1)\n        w2c = np.concatenate([w2c, np.array([0, 0, 0, 1])[None]], 0)\n        c2w = np.linalg.inv(w2c)\n        
poses.append(c2w)\n\n    center = np.array([0.0, 0.0, 0.0])\n    for f in poses:\n        src_frame = f[0:3, :]\n        for g in poses:\n            tgt_frame = g[0:3, :]\n            p = find_closest_point(src_frame[:, 3], src_frame[:, 2], tgt_frame[:, 3], tgt_frame[:, 2])\n            center += p\n    center /= len(poses) ** 2\n\n    radius = 0.0\n    for f in poses:\n        radius += np.linalg.norm(f[0:3, 3])\n    radius /= len(poses)\n    bounding_box = [\n        [center[0] - radius, center[0] + radius],\n        [center[1] - radius, center[1] + radius],\n        [center[2] - radius, center[2] + radius],\n    ]\n    return center, radius, bounding_box\n\n\ndef bound_by_points(points3D):\n    if not isinstance(points3D, np.ndarray):\n        xyzs = np.stack([point.xyz for point in points3D.values()])\n    else:\n        xyzs = points3D\n    center = xyzs.mean(axis=0)\n    std = xyzs.std(axis=0)\n    # radius = float(std.max() * 2)  # use 2*std to define the region, equivalent to 95% percentile\n    radius = np.abs(xyzs).max(0) * 1.1\n    bounding_box = [\n        [center[0] - std[0] * 3, center[0] + std[0] * 3],\n        [center[1] - std[1] * 3, center[1] + std[1] * 3],\n        [center[2] - std[2] * 3, center[2] + std[2] * 3],\n    ]\n    return center, radius, bounding_box\n\n\ndef compute_oriented_bound(pts):\n    to_align, _ = trimesh.bounds.oriented_bounds(pts)\n    \n    scale = (np.abs((to_align[:3, :3] @ pts.vertices.T + to_align[:3, 3:]).T).max(0) * 1.2).tolist()\n    \n    return to_align.tolist(), scale\n\n\ndef split_data(names, split=10):\n    split_dict = {'train': [], 'test': []}\n    names = sorted(names)\n    \n    for i, name in enumerate(names):\n        if i % split == 0:\n            split_dict['test'].append(name)\n        else:\n            split_dict['train'].append(name)\n    \n    split_dict['train'] = sorted(split_dict['train'])\n    split_dict['test'] = sorted(split_dict['test'])\n    return split_dict\n\n\ndef 
get_split_dict(scene_path):\n    split_dict = None\n    \n    if os.path.exists(os.path.join(scene_path, 'train_test_lists.json')):\n        image_names = os.listdir(os.path.join(scene_path, \"images\"))\n        image_names = sorted(['{:06}'.format(int(i.split(\".\")[0])) for i in image_names])\n        \n        with open(os.path.join(scene_path, 'train_test_lists.json'), 'r') as fp:\n            split_dict = json.load(fp)\n            \n        test_split = sorted([i.split(\".\")[0] for i in split_dict['test']])\n        train_split = [i for i in image_names if i not in test_split]\n        \n        assert len(train_split) + len(test_split) == len(image_names), \"train and test split do not cover all images\"\n        \n        split_dict = {\n            'train': train_split,\n            'test': test_split,\n        }\n    \n    return split_dict\n\n\ndef check_concentric(images, ang_tol=np.pi / 6.0, radii_tol=0.5, pose_tol=0.5):\n    look_at = []\n    cam_loc = []\n    for img in images.values():\n        rotation = qvec2rotmat(img.qvec)\n        translation = img.tvec.reshape(3, 1)\n        w2c = np.concatenate([rotation, translation], 1)\n        w2c = np.concatenate([w2c, np.array([0, 0, 0, 1])[None]], 0)\n        c2w = np.linalg.inv(w2c)\n        cam_loc.append(c2w[:3, -1])\n        look_at.append(c2w[:3, 2])\n    look_at = np.stack(look_at)\n    look_at = look_at / np.linalg.norm(look_at, axis=1, keepdims=True)\n    cam_loc = np.stack(cam_loc)\n    num_images = cam_loc.shape[0]\n\n    center = cam_loc.mean(axis=0)\n    vec = center - cam_loc\n    radii = np.linalg.norm(vec, axis=1, keepdims=True)\n    vec_unit = vec / radii\n    ang = np.arccos((look_at * vec_unit).sum(axis=-1, keepdims=True))\n    ang_valid = ang < ang_tol\n    print(f\"Fraction of images looking at the center: {ang_valid.sum()/num_images:.2f}.\")\n\n    radius_mean = radii.mean()\n    radii_valid = np.isclose(radius_mean, radii, rtol=radii_tol)\n    print(f\"Fraction of images 
positioned around the center: {radii_valid.sum()/num_images:.2f}.\")\n\n    valid = ang_valid * radii_valid\n    print(f\"Valid fraction of concentric images: {valid.sum()/num_images:.2f}.\")\n\n    return valid.sum() / num_images > pose_tol\n\n\ndef export_to_json(trans, scale, scene_path, file_name, split_dict=None, do_split=False):\n    out = {\n        \"trans\": trans,\n        \"scale\": scale,\n    }\n\n    if do_split:\n        if split_dict is None:\n            image_names = os.listdir(os.path.join(scene_path, \"images\"))\n            image_names = ['{:06}'.format(int(i.split(\".\")[0])) for i in image_names]\n            split_dict = split_data(image_names, split=10)\n        \n        out.update(split_dict)\n\n    with open(os.path.join(scene_path, file_name), \"w\") as outputfile:\n        json.dump(out, outputfile, indent=4)\n\n    return\n\n\ndef data_to_json(args):\n    cameras, images, points3D = read_model(os.path.join(args.data_dir, \"sparse\"), ext=\".bin\")\n\n    # define bounding regions based on scene type\n    if args.scene_type == \"outdoor\":\n        if check_concentric(images):\n            center, scale, bounding_box = bound_by_pose(images)\n        else:\n            center, scale, bounding_box = bound_by_points(points3D)\n    elif args.scene_type == \"indoor\":\n        # use sfm points as a proxy to define bounding regions\n        center, scale, bounding_box = bound_by_points(points3D)\n    elif args.scene_type == \"object\":\n        # use poses as a proxy to define bounding regions\n        center, scale, bounding_box = bound_by_pose(images)\n    else:\n        raise TypeError(\"Unknown scene type\")\n\n    # export json file\n    export_to_json(list(center), scale, args.data_dir, \"meta.json\")\n    print(\"Writing data to json file: \", os.path.join(args.data_dir, \"meta.json\"))\n    return\n\n\nif __name__ == \"__main__\":\n    parser = ArgumentParser()\n    parser.add_argument(\"--data_dir\", type=str, default=None, 
help=\"Path to data\")\n    parser.add_argument(\n        \"--scene_type\",\n        type=str,\n        default=\"outdoor\",\n        choices=[\"outdoor\", \"indoor\", \"object\"],\n        help=\"Select scene type. Outdoor for building-scale reconstruction; \"\n        \"indoor for room-scale reconstruction; object for object-centric scene reconstruction.\",\n    )\n    args = parser.parse_args()\n    data_to_json(args)\n"
  },
  {
    "path": "process_data/convert_dtu_to_json.py",
    "content": "'''\n-----------------------------------------------------------------------------\nCopyright (c) 2023, NVIDIA CORPORATION. All rights reserved.\n\nNVIDIA CORPORATION and its licensors retain all intellectual property\nand proprietary rights in and to this software, related documentation\nand any modifications thereto. Any use, reproduction, disclosure or\ndistribution of this software and related documentation without an express\nlicense agreement from NVIDIA CORPORATION is strictly prohibited.\n-----------------------------------------------------------------------------\n'''\n\nimport numpy as np\nimport json\nfrom argparse import ArgumentParser\nimport os\nimport cv2\nfrom PIL import Image, ImageFile\nfrom glob import glob\nimport math\nimport sys\nfrom pathlib import Path\nfrom tqdm import tqdm\nimport trimesh\n\n\ndir_path = Path(os.path.dirname(os.path.realpath(__file__))).parents[0]\nsys.path.append(dir_path.__str__())\n# from process_data.convert_data_to_json import _cv_to_gl  # noqa: E402\nfrom process_data.convert_data_to_json import export_to_json, compute_oriented_bound  # NOQA\nfrom submodules.colmap.scripts.python.database import COLMAPDatabase  # NOQA\nfrom submodules.colmap.scripts.python.read_write_model import rotmat2qvec  # NOQA\n\nImageFile.LOAD_TRUNCATED_IMAGES = True\n\n\ndef load_K_Rt_from_P(filename, P=None):\n    # This function is borrowed from IDR: https://github.com/lioryariv/idr\n    if P is None:\n        lines = open(filename).read().splitlines()\n        if len(lines) == 4:\n            lines = lines[1:]\n        lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(\" \") for x in lines)]\n        P = np.asarray(lines).astype(np.float32).squeeze()\n\n    out = cv2.decomposeProjectionMatrix(P)\n    K = out[0]\n    R = out[1]\n    t = out[2]\n\n    K = K / K[2, 2]\n    intrinsics = np.eye(4)\n    intrinsics[:3, :3] = K\n\n    pose = np.eye(4, dtype=np.float32)\n    pose[:3, :3] = R.transpose()\n    pose[:3, 3] = (t[:3] 
/ t[3])[:, 0]\n\n    return intrinsics, pose\n\n\ndef dtu_to_json(args):\n    assert args.dtu_path, \"Provide path to DTU dataset\"\n    scene_list = os.listdir(args.dtu_path)\n\n    test_indexes = [8, 13, 16, 21, 26, 31, 34, 56]\n    for scene in tqdm(scene_list):\n        scene_path = os.path.join(args.dtu_path, scene)\n        if not os.path.isdir(scene_path) or 'scan' not in scene:\n            continue\n\n        # trans = [0., 0., 0.]\n        # scale = 1.\n        id = int(scene[4:])\n        pts = trimesh.load(os.path.join(args.dtu_path, f'Points/stl/stl{id:03}_total.ply'))\n        trans, scale = compute_oriented_bound(pts)\n        \n        out = {\n            \"trans\": trans,\n            \"scale\": scale,\n        }\n\n        # split_dict = None\n        if args.split:\n            images_names = os.listdir(os.path.join(scene_path, 'images'))\n            images_names = sorted([i for i in images_names if 'png' in i])\n            \n            train_images = [i.split('.')[0] for i in images_names if int(i.split('.')[0]) not in test_indexes]\n            test_images = [i.split('.')[0] for i in images_names if int(i.split('.')[0]) in test_indexes]\n            \n            train_images = sorted(train_images)\n            test_images = sorted(test_images)\n            \n            out.update({\n                    'train': train_images,\n                    'test': test_images,\n                    })\n        \n            assert len(train_images) + len(test_images) == len(images_names)\n        \n        file_path = os.path.join(scene_path, 'meta.json')\n        with open(file_path, \"w\") as outputfile:\n            json.dump(out, outputfile, indent=4)\n        # print('Writing data to json file: ', file_path)\n\n\ndef load_poses(scene_path):\n    camera_param = dict(np.load(os.path.join(scene_path, 'cameras_sphere.npz')))\n    images_lis = sorted(glob(os.path.join(scene_path, 'image/*.png')))\n    c2ws = {}\n    for idx, image in 
enumerate(images_lis):\n        image = os.path.basename(image)\n\n        world_mat = camera_param['world_mat_%d' % idx]\n        scale_mat = camera_param['scale_mat_%d' % idx]\n\n        # scale and decompose\n        P = world_mat @ scale_mat\n        P = P[:3, :4]\n        intrinsic_param, c2w = load_K_Rt_from_P(None, P)\n        c2ws[image] = c2w\n    \n    w, h = Image.open(os.path.join(scene_path, 'image', image)).size\n    \n    return c2ws, intrinsic_param, w, h\n        \n\ndef convert_cam_dict_to_pinhole_dict(scene_path, pinhole_dict_file):\n    # Partially adapted from https://github.com/Kai-46/nerfplusplus/blob/master/colmap_runner/run_colmap_posed.py\n    \n    c2ws, intrinsic_param, w, h = load_poses(scene_path)\n    \n    fx = intrinsic_param[0][0]\n    fy = intrinsic_param[1][1]\n    cx = intrinsic_param[0][2]\n    cy = intrinsic_param[1][2]\n    sk_x = intrinsic_param[0][1]\n    sk_y = intrinsic_param[1][0]\n\n    print('Writing pinhole_dict to: ', pinhole_dict_file)\n\n    pinhole_dict = {}\n    for img_name in c2ws:\n        c2w = c2ws[img_name]\n        W2C = np.linalg.inv(c2w)\n\n        # params\n        qvec = rotmat2qvec(W2C[:3, :3])\n        tvec = W2C[:3, 3]\n\n        params = [w, h, fx, fy, cx, cy, sk_x, sk_y,\n                  qvec[0], qvec[1], qvec[2], qvec[3],\n                  tvec[0], tvec[1], tvec[2]]\n        pinhole_dict[img_name] = params\n    \n    with open(pinhole_dict_file, 'w') as fp:\n        pinhole_dict = {k: [float(x) for x in v] for k, v in pinhole_dict.items()}\n        json.dump(pinhole_dict, fp, indent=2, sort_keys=True)\n\n\ndef create_init_files(pinhole_dict_file, db_file, out_dir):\n    # Partially adapted from https://github.com/Kai-46/nerfplusplus/blob/master/colmap_runner/run_colmap_posed.py\n\n    if not os.path.exists(out_dir):\n        os.mkdir(out_dir)\n\n    # create template\n    with open(pinhole_dict_file) as fp:\n        pinhole_dict = json.load(fp)\n\n    template = {}\n    cameras_line_template = 
'{camera_id} RADIAL {width} {height} {fx} {fy} {cx} {cy} {k1} {k2}\\n'\n    images_line_template = '{image_id} {qw} {qx} {qy} {qz} {tx} {ty} {tz} {camera_id} {image_name}\\n\\n'\n\n    for img_name in pinhole_dict:\n        # w, h, fx, fy, cx, cy, qvec, t\n        params = pinhole_dict[img_name]\n        w = params[0]\n        h = params[1]\n        fx = params[2]\n        fy = params[3]\n        cx = params[4]\n        cy = params[5]\n        sk_x = params[6]\n        sk_y = params[7]\n        qvec = params[8:12]\n        tvec = params[12:15]\n\n        cam_line = cameras_line_template.format(\n            camera_id=\"{camera_id}\", width=w, height=h, fx=fx, fy=fy, cx=cx, cy=cy, k1=sk_x, k2=sk_y)\n        img_line = images_line_template.format(image_id=\"{image_id}\", qw=qvec[0], qx=qvec[1], qy=qvec[2], qz=qvec[3],\n                                               tx=tvec[0], ty=tvec[1], tz=tvec[2], camera_id=\"{camera_id}\",\n                                               image_name=img_name)\n        template[img_name] = (cam_line, img_line)\n\n    # read database\n    db = COLMAPDatabase.connect(db_file)\n    table_images = db.execute(\"SELECT * FROM images\")\n    img_name2id_dict = {}\n    for row in table_images:\n        img_name2id_dict[row[1]] = row[0]\n\n    cameras_txt_lines = [template[img_name][0].format(camera_id=1)]\n    images_txt_lines = []\n    for img_name, img_id in img_name2id_dict.items():\n        image_line = template[img_name][1].format(image_id=img_id, camera_id=1)\n        images_txt_lines.append(image_line)\n\n    with open(os.path.join(out_dir, 'cameras.txt'), 'w') as fp:\n        fp.writelines(cameras_txt_lines)\n\n    with open(os.path.join(out_dir, 'images.txt'), 'w') as fp:\n        fp.writelines(images_txt_lines)\n        fp.write('\\n')\n\n    # create an empty points3D.txt\n    fp = open(os.path.join(out_dir, 'points3D.txt'), 'w')\n    fp.close()\n\n\ndef init_colmap(args):\n    assert args.dtu_path, \"Provide path to DTU 
dataset\"\n    scene_list = os.listdir(args.dtu_path)\n    scene_list = sorted([i for i in scene_list if 'scan' in i])\n\n    pbar = tqdm(total=len(scene_list))\n    for scene in scene_list:\n        pbar.set_description(desc=f'Scene: {scene}')\n        pbar.update(1)\n        scene_path = os.path.join(args.dtu_path, scene)\n\n        if not os.path.exists(f\"{scene_path}/image\"):\n            raise Exception(f\"'image` folder cannot be found in {scene_path}.\"\n                            \"Please check the expected folder structure in DATA_PREPROCESSING.md\")\n\n        # extract features\n        os.system(f\"colmap feature_extractor --database_path {scene_path}/database.db \\\n                --image_path {scene_path}/image \\\n                --ImageReader.camera_model=RADIAL \\\n                --SiftExtraction.use_gpu=true \\\n                --SiftExtraction.num_threads=32 \\\n                --ImageReader.single_camera=true\"\n                  )\n                # --ImageReader.camera_model=RADIAL \\\n\n        # match features\n        os.system(f\"colmap sequential_matcher \\\n                --database_path {scene_path}/database.db \\\n                --SiftMatching.use_gpu=true\"\n                  )\n\n        pinhole_dict_file = os.path.join(scene_path, 'pinhole_dict.json')\n        convert_cam_dict_to_pinhole_dict(scene_path, pinhole_dict_file)\n\n        db_file = os.path.join(scene_path, 'database.db')\n        sfm_dir = os.path.join(scene_path, 'sparse')\n        # sfm_dir = os.path.join(scene_path, 'colmap')\n        create_init_files(pinhole_dict_file, db_file, sfm_dir)\n\n        # bundle adjustment\n        os.system(f\"colmap point_triangulator \\\n                --database_path {scene_path}/database.db \\\n                --image_path {scene_path}/image \\\n                --input_path {scene_path}/sparse \\\n                --output_path {scene_path}/sparse \\\n                --clear_points 1 \\\n                
--Mapper.tri_ignore_two_view_tracks=true\"\n                  )\n        os.system(f\"colmap bundle_adjuster \\\n                --input_path {scene_path}/sparse \\\n                --output_path {scene_path}/sparse \\\n                --BundleAdjustment.refine_extrinsics=false\"\n                  )\n        \n        # undistortion\n        os.system(f\"colmap image_undistorter \\\n            --image_path {scene_path}/image \\\n            --input_path {scene_path}/sparse \\\n            --output_path {scene_path} \\\n            --output_type COLMAP \\\n            --max_image_size 1600\"\n                )\n\nif __name__ == '__main__':\n    parser = ArgumentParser()\n    parser.add_argument('--dtu_path', type=str, default=None, help='Path to DTU dataset')\n    parser.add_argument('--export_json', action='store_true', help='export json')\n    parser.add_argument('--run_colmap', action='store_true', help='Run colmap')\n    parser.add_argument('--split', action='store_true', help='Include the train/test split in the exported json')\n\n    args = parser.parse_args()\n\n    if args.run_colmap:\n        init_colmap(args)\n\n    if args.export_json:\n        dtu_to_json(args)"
  },
  {
    "path": "process_data/convert_tnt_to_json.py",
    "content": "import os\nimport numpy as np\nimport json\nimport sys\nfrom pathlib import Path\nfrom argparse import ArgumentParser\nimport trimesh\n\ndir_path = Path(os.path.dirname(os.path.realpath(__file__))).parents[0]\nsys.path.append(dir_path.__str__())\n\nfrom process_data.convert_data_to_json import export_to_json, get_split_dict, compute_oriented_bound  # NOQA\n\nfrom submodules.colmap.scripts.python.database import COLMAPDatabase  # NOQA\nfrom submodules.colmap.scripts.python.read_write_model import rotmat2qvec  # NOQA\n\n\ndef create_init_files(pinhole_dict_file, db_file, out_dir):\n    # Partially adapted from https://github.com/Kai-46/nerfplusplus/blob/master/colmap_runner/run_colmap_posed.py\n\n    if not os.path.exists(out_dir):\n        os.mkdir(out_dir)\n\n    # create template\n    with open(pinhole_dict_file) as fp:\n        pinhole_dict = json.load(fp)\n\n    template = {}\n    cameras_line_template = '{camera_id} RADIAL {width} {height} {f} {cx} {cy} {k1} {k2}\\n'\n    images_line_template = '{image_id} {qw} {qx} {qy} {qz} {tx} {ty} {tz} {camera_id} {image_name}\\n\\n'\n\n    for img_name in pinhole_dict:\n        # w, h, fx, fy, cx, cy, qvec, t\n        params = pinhole_dict[img_name]\n        w = params[0]\n        h = params[1]\n        fx = params[2]\n        # fy = params[3]\n        cx = params[4]\n        cy = params[5]\n        qvec = params[6:10]\n        tvec = params[10:13]\n\n        cam_line = cameras_line_template.format(\n            camera_id=\"{camera_id}\", width=w, height=h, f=fx, cx=cx, cy=cy, k1=0, k2=0)\n        img_line = images_line_template.format(image_id=\"{image_id}\", qw=qvec[0], qx=qvec[1], qy=qvec[2], qz=qvec[3],\n                                               tx=tvec[0], ty=tvec[1], tz=tvec[2], camera_id=\"{camera_id}\",\n                                               image_name=img_name)\n        template[img_name] = (cam_line, img_line)\n\n    # read database\n    db = COLMAPDatabase.connect(db_file)\n    
table_images = db.execute(\"SELECT * FROM images\")\n    img_name2id_dict = {}\n    for row in table_images:\n        img_name2id_dict[row[1]] = row[0]\n\n    cameras_txt_lines = [template[img_name][0].format(camera_id=1)]\n    images_txt_lines = []\n    for img_name, img_id in img_name2id_dict.items():\n        image_line = template[img_name][1].format(image_id=img_id, camera_id=1)\n        images_txt_lines.append(image_line)\n\n    with open(os.path.join(out_dir, 'cameras.txt'), 'w') as fp:\n        fp.writelines(cameras_txt_lines)\n\n    with open(os.path.join(out_dir, 'images.txt'), 'w') as fp:\n        fp.writelines(images_txt_lines)\n        fp.write('\\n')\n\n    # create an empty points3D.txt\n    fp = open(os.path.join(out_dir, 'points3D.txt'), 'w')\n    fp.close()\n\n\ndef convert_cam_dict_to_pinhole_dict(cam_dict, pinhole_dict_file):\n    # Partially adapted from https://github.com/Kai-46/nerfplusplus/blob/master/colmap_runner/run_colmap_posed.py\n\n    print('Writing pinhole_dict to: ', pinhole_dict_file)\n    h = 1080\n    w = 1920\n\n    pinhole_dict = {}\n    for img_name in cam_dict:\n        W2C = cam_dict[img_name]\n\n        # params\n        fx = 0.6 * w\n        fy = 0.6 * w\n        cx = w / 2.0\n        cy = h / 2.0\n\n        qvec = rotmat2qvec(W2C[:3, :3])\n        tvec = W2C[:3, 3]\n\n        params = [w, h, fx, fy, cx, cy,\n                  qvec[0], qvec[1], qvec[2], qvec[3],\n                  tvec[0], tvec[1], tvec[2]]\n        pinhole_dict[img_name] = params\n\n    with open(pinhole_dict_file, 'w') as fp:\n        json.dump(pinhole_dict, fp, indent=2, sort_keys=True)\n\n\ndef load_COLMAP_poses(cam_file, img_dir, tf='w2c'):\n    # load img_dir namges\n    names = sorted(os.listdir(img_dir))\n\n    with open(cam_file) as f:\n        lines = f.readlines()\n\n    # C2W\n    poses = {}\n    for idx, line in enumerate(lines):\n        if idx % 5 == 0:  # header\n            img_idx, valid, _ = line.split(' ')\n            if valid != 
'-1':\n                poses[int(img_idx)] = np.eye(4)\n        else:\n            if int(img_idx) in poses:\n                num = np.array([float(n) for n in line.split(' ')])\n                poses[int(img_idx)][idx % 5-1, :] = num\n\n    if tf == 'c2w':\n        return poses\n    else:\n        # convert to W2C (follow nerf convention)\n        poses_w2c = {}\n        for k, v in poses.items():\n            poses_w2c[names[k]] = np.linalg.inv(v)\n        return poses_w2c\n\n\ndef load_transformation(trans_file):\n    with open(trans_file) as f:\n        lines = f.readlines()\n\n    trans = np.eye(4)\n    for idx, line in enumerate(lines):\n        num = np.array([float(n) for n in line.split(' ')])\n        trans[idx, :] = num\n\n    return trans\n\n\ndef align_gt_with_cam(pts, trans):\n    trans_inv = np.linalg.inv(trans)\n    pts_aligned = pts @ trans_inv[:3, :3].transpose(-1, -2) + trans_inv[:3, -1]\n    return pts_aligned\n\n\ndef compute_bound(pts):\n    bounding_box = np.array([pts.min(axis=0), pts.max(axis=0)])\n    center = bounding_box.mean(axis=0)\n    # sphere radius\n    # scale = np.max(np.linalg.norm(pts - center, axis=-1)) * 1.01\n    # cube \n    # scale = (np.abs(pts - center).max(0) * 1.2).tolist() # cuboid for street\n    scale = (np.abs(pts - center).max(0) * 1.).tolist() # cuboid for street\n    return center, scale, bounding_box.T.tolist()\n\n\ndef init_colmap(args):\n    assert args.tnt_path, \"Provide path to Tanks and Temples dataset\"\n    scene_list = os.listdir(args.tnt_path)\n    if 'Church' in scene_list: scene_list.remove('Church')\n    scene_list = sorted(scene_list)\n\n    for scene in scene_list:\n        scene_path = os.path.join(args.tnt_path, scene)\n\n        if args.run_colmap:\n            if not os.path.exists(f\"{scene_path}/images_raw\"):\n                raise Exception(f\"'images_raw' folder cannot be found in {scene_path}. \"\n                                \"Please check the 
expected folder structure in DATA_PREPROCESSING.md\")\n\n\n            # extract features\n            os.system(f\"colmap feature_extractor --database_path {scene_path}/database.db \\\n                    --image_path {scene_path}/images_raw \\\n                    --ImageReader.camera_model=RADIAL \\\n                    --SiftExtraction.use_gpu=true \\\n                    --SiftExtraction.num_threads=32 \\\n                    --ImageReader.single_camera=true\"\n                    )\n\n            # match features\n            os.system(f\"colmap sequential_matcher \\\n                    --database_path {scene_path}/database.db \\\n                    --SiftMatching.use_gpu=true\"\n                  )\n\n            # read poses\n            poses = load_COLMAP_poses(os.path.join(scene_path, f'{scene}_COLMAP_SfM.log'),\n                                    os.path.join(scene_path, 'images_raw'))\n\n            # convert to colmap files\n            pinhole_dict_file = os.path.join(scene_path, 'pinhole_dict.json')\n            convert_cam_dict_to_pinhole_dict(poses, pinhole_dict_file)\n\n            db_file = os.path.join(scene_path, 'database.db')\n            sfm_dir = os.path.join(scene_path, 'sparse')\n            create_init_files(pinhole_dict_file, db_file, sfm_dir)\n\n            # bundle adjustment\n            os.system(f\"colmap point_triangulator \\\n                    --database_path {scene_path}/database.db \\\n                    --image_path {scene_path}/images_raw \\\n                    --input_path {scene_path}/sparse \\\n                    --output_path {scene_path}/sparse \\\n                    --Mapper.tri_ignore_two_view_tracks=true\"\n                    )\n            os.system(f\"colmap bundle_adjuster \\\n                    --input_path {scene_path}/sparse \\\n                    --output_path {scene_path}/sparse \\\n                    --BundleAdjustment.refine_extrinsics=false\"\n                    )\n\n            # 
undistortion\n            os.system(f\"colmap image_undistorter \\\n                --image_path {scene_path}/images_raw \\\n                --input_path {scene_path}/sparse \\\n                --output_path {scene_path} \\\n                --output_type COLMAP \\\n                --max_image_size 1500\"\n                    )\n\n        if args.export_json:\n            # read for bounding information\n            trans = load_transformation(os.path.join(scene_path, f'{scene}_trans.txt'))\n            pts = trimesh.load(os.path.join(scene_path, f'{scene}.ply'))\n            # pts = pts.vertices\n            # pts_aligned = align_gt_with_cam(pts, trans)\n            # center, scale, bounding_box = compute_bound(pts_aligned[::100])\n            pts.vertices = align_gt_with_cam(pts.vertices, trans)\n            # pts = pts.sample(20000)\n            pts.vertices = pts.vertices[::100]\n            trans, scale = compute_oriented_bound(pts)\n            \n            split_dict = get_split_dict(scene_path)\n            \n            export_to_json(trans, scale, scene_path, 'meta.json', split_dict=split_dict)\n            print('Writing data to json file: ', os.path.join(scene_path, 'meta.json'))\n\n\nif __name__ == '__main__':\n    parser = ArgumentParser()\n    parser.add_argument('--tnt_path', type=str, default=None, help='Path to tanks and temples dataset')\n    parser.add_argument('--run_colmap', action='store_true', help='Run colmap')\n    parser.add_argument('--export_json', action='store_true', help='export json')\n\n    args = parser.parse_args()\n\n    init_colmap(args)\n"
  },
  {
    "path": "process_data/extract_mask.py",
    "content": "import argparse\nimport os\nimport gc\nimport sys\n\nimport numpy as np\nimport json\nimport torch\nfrom PIL import Image\nfrom tqdm import tqdm\nimport torch.nn.functional as F\n\n# segment anything\nfrom segment_anything import (\n    sam_model_registry,\n    sam_hq_model_registry,\n    SamPredictor\n)\nimport cv2\nimport numpy as np\nimport matplotlib.pyplot as plt\n\nsys.path.append(os.getcwd())\nfrom tools.semantic_id import text_label_dict\n\n\ntext_prompt_dict = {\n    'indoor': 'window.floor.',\n    'outdoor': 'sky.',\n}\n\n\ndef load_image(image_path):\n    # load image\n    image_pil = Image.open(image_path).convert(\"RGB\")  # load image\n\n    transform = T.Compose(\n        [\n            T.RandomResize([800], max_size=1333),\n            T.ToTensor(),\n            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),\n        ]\n    )\n    image, _ = transform(image_pil, None)  # 3, h, w\n    return image_pil, image\n\n\ndef print_(a):\n    pass\n\n\ndef load_model(model_config_path, model_checkpoint_path, device):\n    args = SLConfig.fromfile(model_config_path)\n    args.device = device\n    model = build_model(args)\n    checkpoint = torch.load(model_checkpoint_path, map_location=\"cpu\")\n    load_res = model.load_state_dict(clean_state_dict(checkpoint[\"model\"]), strict=False)\n    print(load_res)\n    _ = model.eval()\n    return model\n\n\ndef get_grounding_output(model, image, caption, box_threshold, text_threshold, with_logits=True, device=\"cpu\"):\n    caption = caption.lower()\n    caption = caption.strip()\n    if not caption.endswith(\".\"):\n        caption = caption + \".\"\n    model = model.to(device)\n    image = image.to(device)\n    with torch.no_grad():\n        outputs = model(image[None], captions=[caption])\n    logits = outputs[\"pred_logits\"].cpu().sigmoid()[0]  # (nq, 256)\n    boxes = outputs[\"pred_boxes\"].cpu()[0]  # (nq, 4)\n    logits.shape[0]\n\n    # filter output\n    logits_filt = 
logits.clone()\n    boxes_filt = boxes.clone()\n    filt_mask = logits_filt.max(dim=1)[0] > box_threshold\n    logits_filt = logits_filt[filt_mask]  # num_filt, 256\n    boxes_filt = boxes_filt[filt_mask]  # num_filt, 4\n    logits_filt.shape[0]\n\n    # get phrase\n    tokenlizer = model.tokenizer\n    tokenized = tokenlizer(caption)\n    # build pred\n    pred_phrases = []\n    for logit, box in zip(logits_filt, boxes_filt):\n        pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)\n        if with_logits:\n            pred_phrases.append(pred_phrase + f\"({str(logit.max().item())[:4]})\")\n        else:\n            pred_phrases.append(pred_phrase)\n\n    return boxes_filt, pred_phrases\n\n\ndef show_mask(mask, ax, random_color=False):\n    if random_color:\n        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)\n    else:\n        color = np.array([30/255, 144/255, 255/255, 0.6])\n    h, w = mask.shape[-2:]\n    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)\n    ax.imshow(mask_image)\n\n\ndef show_box(box, ax, label):\n    x0, y0 = box[0], box[1]\n    w, h = box[2] - box[0], box[3] - box[1]\n    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))\n    ax.text(x0, y0, label)\n\n\ndef save_mask_data(output_dir, mask_list, box_list, label_list, name):\n    value = 1\n\n    mask_img = torch.ones(mask_list.shape[-2:]) * value\n    for idx, mask in enumerate(mask_list):\n        if len(label_list) == 0: break\n        sem = label_list[idx].split('(')[0]\n        try:\n            mask_img[mask.cpu().numpy()[0] == True] = text_label_dict.get(sem, value)\n        except KeyError:\n            import pdb; pdb.set_trace()\n    \n    mask_img = mask_img.numpy().astype(np.uint8)\n    cv2.imwrite(os.path.join(output_dir, f'{name}.png'), mask_img)\n\n\ndef morphology_open(x, k1=21, k2=21):\n    out = x.float()[None]\n    p1 = (k1 - 1) // 2\n    out = 
-F.max_pool2d(-out, kernel_size=k1, stride=1, padding=p1)\n    out = F.max_pool2d(out, kernel_size=k1, stride=1, padding=p1)\n    return out\n\n\ndef process_image(image_name):\n    name = image_name.split('.')[0]\n    image_path = os.path.join(image_dir, image_name)\n    # load image\n    image_pil, image = load_image(image_path)\n    # visualize raw image\n    # image_pil.save(os.path.join(output_dir, \"raw_image.jpg\"))\n\n    # run grounding dino model\n    boxes_filt, pred_phrases = get_grounding_output(\n        model, image, text_prompt, box_threshold, text_threshold, device=device\n    )\n\n    image = cv2.imread(image_path)\n    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n    predictor.set_image(image)\n\n    size = image_pil.size\n    H, W = size[1], size[0]\n    for i in range(boxes_filt.size(0)):\n        boxes_filt[i] = boxes_filt[i] * torch.Tensor([W, H, W, H])\n        boxes_filt[i][:2] -= boxes_filt[i][2:] / 2\n        boxes_filt[i][2:] += boxes_filt[i][:2]\n\n    boxes_filt = boxes_filt.cpu()\n    transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)\n\n    with torch.no_grad():\n        try:\n            masks, _, _ = predictor.predict_torch(\n                point_coords = None,\n                point_labels = None,\n                boxes = transformed_boxes.to(device),\n                multimask_output = False,\n            )\n        except RuntimeError:\n            print(f\"Error in {name}\")\n            masks = torch.zeros([1, 1, H, W]).to(device).bool()\n\n    masks = masks.cpu()\n\n    if args.vis:\n        # draw output image\n        plt.figure(figsize=(10, 10))\n        plt.imshow(image)\n        for mask in masks:\n            show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)\n        for box, label in zip(boxes_filt, pred_phrases):\n            show_box(box.numpy(), plt.gca(), label)\n\n        plt.axis('off')\n        plt.savefig(\n            os.path.join(output_dir, 
f\"{name}_output.png\"),\n            bbox_inches=\"tight\", dpi=100, pad_inches=0.0\n        )\n        plt.close()             # important!!! close the plot to release memory\n\n    save_mask_data(output_dir, masks, boxes_filt, pred_phrases, name)\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser(\"Grounded-Segment-Anything Demo\", add_help=True)\n    parser.add_argument(\"--config\", type=str, required=True, help=\"path to config file\")\n    parser.add_argument(\n        \"--grounded_checkpoint\", type=str, required=True, help=\"path to checkpoint file\"\n    )\n    parser.add_argument(\n        \"--sam_version\", type=str, default=\"vit_h\", required=False, help=\"SAM ViT version: vit_b / vit_l / vit_h\"\n    )\n    parser.add_argument(\n        \"--sam_checkpoint\", type=str, required=False, help=\"path to sam checkpoint file\"\n    )\n    parser.add_argument(\n        \"--sam_hq_checkpoint\", type=str, default=None, help=\"path to sam-hq checkpoint file\"\n    )\n    parser.add_argument(\n        \"--use_sam_hq\", action=\"store_true\", help=\"using sam-hq for prediction\"\n    )\n    parser.add_argument(\"--input_image\", type=str, required=True, help=\"path to the directory of input images\")\n    parser.add_argument(\"--text_prompt\", type=str, default=None, help=\"text prompt\")\n    parser.add_argument(\"--scene_type\", type=str, choices=['indoor', 'outdoor'], help=\"scene type; selects the default text prompt when --text_prompt is not given\")\n    parser.add_argument(\"--scene\", type=str, default=None, help=\"scene name; its entry in text_prompt_dict, if any, overrides the scene-type prompt\")\n    parser.add_argument(\n        \"--output_dir\", \"-o\", type=str, default=\"outputs\", required=True, help=\"output directory\"\n    )\n\n    parser.add_argument(\"--box_threshold\", type=float, default=0.3, help=\"box threshold\")\n    parser.add_argument(\"--text_threshold\", type=float, default=0.25, help=\"text threshold\")\n    parser.add_argument(\"--gsam_path\", dest=\"gsam_path\", help=\"path to gsam\")\n    parser.add_argument('--vis', action='store_true', help='visualize the 
output')\n\n    parser.add_argument(\"--device\", type=str, default=\"cpu\", help=\"device to run on, e.g. cpu or cuda (default: cpu)\")\n    args = parser.parse_args()\n\n    gsam_path = args.gsam_path\n\n    sys.path.append(args.gsam_path)\n    sys.path.append(os.path.join(gsam_path, \"GroundingDINO\"))\n    sys.path.append(os.path.join(gsam_path, \"segment_anything\"))\n\n    # Grounding DINO\n    import GroundingDINO.groundingdino.datasets.transforms as T\n    from GroundingDINO.groundingdino.models import build_model\n    from GroundingDINO.groundingdino.util.slconfig import SLConfig\n    from GroundingDINO.groundingdino.util.utils import clean_state_dict, get_phrases_from_posmap\n\n    # print = print_\n    seed = 0\n    np.random.seed(seed)\n    torch.manual_seed(seed)         # sets seed on the current CPU & all GPUs\n    # cfg\n    config_file = args.config  # change the path of the model config file\n    grounded_checkpoint = args.grounded_checkpoint  # change the path of the model\n    sam_version = args.sam_version\n    sam_checkpoint = args.sam_checkpoint\n    sam_hq_checkpoint = args.sam_hq_checkpoint\n    use_sam_hq = args.use_sam_hq\n    image_dir = args.input_image\n    if args.text_prompt is not None:\n        text_prompt = args.text_prompt\n    else:\n        text_prompt = text_prompt_dict[args.scene_type]\n        if args.scene is not None:\n            text_prompt = text_prompt_dict.get(args.scene, text_prompt_dict[args.scene_type])\n            \n    output_dir = args.output_dir\n    box_threshold = args.box_threshold\n    text_threshold = args.text_threshold\n    device = args.device\n\n    # make dir\n    os.makedirs(output_dir, exist_ok=True)\n    # load model\n    model = load_model(config_file, grounded_checkpoint, device=device)\n\n    image_names = os.listdir(image_dir)\n    image_names = sorted([i for i in image_names if i.endswith(\".jpg\") or i.endswith(\".png\")])\n    # initialize SAM\n    if use_sam_hq:\n        predictor = 
SamPredictor(sam_hq_model_registry[sam_version](checkpoint=sam_hq_checkpoint).to(device))\n    else:\n        predictor = SamPredictor(sam_model_registry[sam_version](checkpoint=sam_checkpoint).to(device))\n\n    for image_name in tqdm(image_names):\n        process_image(image_name)"
  },
  {
    "path": "process_data/extract_normal.py",
    "content": "import os\nimport sys\nimport glob\nimport math\nimport struct\nimport argparse\nimport numpy as np\nimport collections\n\nimport torch\nimport torch.nn.functional as F\nfrom torchvision import transforms\nfrom PIL import Image, ImageFile\nfrom tqdm import tqdm\nImageFile.LOAD_TRUNCATED_IMAGES = True\n\nsys.path.append(os.getcwd())\nfrom tools.general_utils import set_random_seed\n\n\nCamera = collections.namedtuple(\n    \"Camera\", [\"id\", \"model\", \"width\", \"height\", \"params\"])\nCameraModel = collections.namedtuple(\n    \"CameraModel\", [\"model_id\", \"model_name\", \"num_params\"])\nCAMERA_MODELS = {\n    CameraModel(model_id=0, model_name=\"SIMPLE_PINHOLE\", num_params=3),\n    CameraModel(model_id=1, model_name=\"PINHOLE\", num_params=4),\n    CameraModel(model_id=2, model_name=\"SIMPLE_RADIAL\", num_params=4),\n    CameraModel(model_id=3, model_name=\"RADIAL\", num_params=5),\n    CameraModel(model_id=4, model_name=\"OPENCV\", num_params=8),\n    CameraModel(model_id=5, model_name=\"OPENCV_FISHEYE\", num_params=8),\n    CameraModel(model_id=6, model_name=\"FULL_OPENCV\", num_params=12),\n    CameraModel(model_id=7, model_name=\"FOV\", num_params=5),\n    CameraModel(model_id=8, model_name=\"SIMPLE_RADIAL_FISHEYE\", num_params=4),\n    CameraModel(model_id=9, model_name=\"RADIAL_FISHEYE\", num_params=5),\n    CameraModel(model_id=10, model_name=\"THIN_PRISM_FISHEYE\", num_params=12)\n}\nCAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model)\n                         for camera_model in CAMERA_MODELS])\nCAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model)\n                           for camera_model in CAMERA_MODELS])\n\n\n\ndef get_args(test=False):\n    parser = get_default_parser()\n\n    #↓↓↓↓\n    #NOTE: project-specific args\n    parser.add_argument('--NNET_architecture', type=str, default='v02')\n    parser.add_argument('--NNET_output_dim', type=int, default=3, help='{3, 4}')\n    
parser.add_argument('--NNET_output_type', type=str, default='R', help='{R, G}')\n    parser.add_argument('--NNET_feature_dim', type=int, default=64)\n    parser.add_argument('--NNET_hidden_dim', type=int, default=64)\n\n    parser.add_argument('--NNET_encoder_B', type=int, default=5)\n\n    parser.add_argument('--NNET_decoder_NF', type=int, default=2048)\n    parser.add_argument('--NNET_decoder_BN', default=False, action=\"store_true\")\n    parser.add_argument('--NNET_decoder_down', type=int, default=8)\n    parser.add_argument('--NNET_learned_upsampling', default=False, action=\"store_true\")\n\n    parser.add_argument('--NRN_prop_ps', type=int, default=5)\n    parser.add_argument('--NRN_num_iter_train', type=int, default=5)\n    parser.add_argument('--NRN_num_iter_test', type=int, default=5)\n    parser.add_argument('--NRN_ray_relu', default=False, action=\"store_true\")\n\n    parser.add_argument('--loss_fn', type=str, default='AL')\n    parser.add_argument('--loss_gamma', type=float, default=0.8)\n    parser.add_argument('--outdir', type=str, default='/your/log/path/')\n    #↑↑↑↑\n\n    # read arguments from txt file\n    assert '.txt' in sys.argv[1]\n    arg_filename_with_prefix = '@' + sys.argv[1]\n    args = parser.parse_args([arg_filename_with_prefix] + sys.argv[2:])\n\n    #↓↓↓↓\n    #NOTE: update args\n    args.exp_root = os.path.join(args.outdir, 'dsine')\n    args.load_normal = True\n    args.load_intrins = True\n    #↑↑↑↑\n\n    # set working dir\n    exp_dir = os.path.join(args.exp_root, args.exp_name)\n\n    args.output_dir = os.path.join(exp_dir, args.exp_id)\n    return args\n\n\ndef focal2fov(focal, pixels):\n    return 2*math.atan(pixels/(2*focal))\n\n\ndef read_next_bytes(fid, num_bytes, format_char_sequence, endian_character=\"<\"):\n    \"\"\"Read and unpack the next bytes from a binary file.\n    :param fid:\n    :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 
2, 6, 16, 30, etc.\n    :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.\n    :param endian_character: Any of {@, =, <, >, !}\n    :return: Tuple of read and unpacked values.\n    \"\"\"\n    data = fid.read(num_bytes)\n    return struct.unpack(endian_character + format_char_sequence, data)\n\n\ndef read_intrinsics_binary(path_to_model_file):\n    \"\"\"\n    see: src/base/reconstruction.cc\n        void Reconstruction::WriteCamerasBinary(const std::string& path)\n        void Reconstruction::ReadCamerasBinary(const std::string& path)\n    \"\"\"\n    cameras = {}\n    with open(path_to_model_file, \"rb\") as fid:\n        num_cameras = read_next_bytes(fid, 8, \"Q\")[0]\n        for _ in range(num_cameras):\n            camera_properties = read_next_bytes(\n                fid, num_bytes=24, format_char_sequence=\"iiQQ\")\n            camera_id = camera_properties[0]\n            model_id = camera_properties[1]\n            model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name\n            width = camera_properties[2]\n            height = camera_properties[3]\n            num_params = CAMERA_MODEL_IDS[model_id].num_params\n            params = read_next_bytes(fid, num_bytes=8*num_params,\n                                     format_char_sequence=\"d\"*num_params)\n            cameras[camera_id] = Camera(id=camera_id,\n                                        model=model_name,\n                                        width=width,\n                                        height=height,\n                                        params=np.array(params))\n        assert len(cameras) == num_cameras\n    return cameras\n\n\ndef read_intrinsics_text(path):\n    \"\"\"\n    Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py\n    \"\"\"\n    cameras = {}\n    with open(path, \"r\") as fid:\n        while True:\n            line = fid.readline()\n            if not line:\n                break\n            
line = line.strip()\n            if len(line) > 0 and line[0] != \"#\":\n                elems = line.split()\n                camera_id = int(elems[0])\n                model = elems[1]\n                assert model == \"PINHOLE\", \"While the loader supports other types, the rest of the code assumes PINHOLE\"\n                width = int(elems[2])\n                height = int(elems[3])\n                params = np.array(tuple(map(float, elems[4:])))\n                cameras[camera_id] = Camera(id=camera_id, model=model,\n                                            width=width, height=height,\n                                            params=params)\n    return cameras\n\n\ndef load_intrinsic_colmap(path):\n    intr_dir = os.path.join(path, \"sparse\", \"0\")\n    if not os.path.exists(intr_dir):\n        intr_dir = os.path.join(path, \"sparse\")\n    # support only one camera for now\n    try:\n        cameras_intrinsic_file = os.path.join(intr_dir, \"cameras.bin\")\n        cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file)\n    except Exception:\n        # fall back to the text model if the binary file is missing or unreadable\n        cameras_intrinsic_file = os.path.join(intr_dir, \"cameras.txt\")\n        cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file)\n\n    intrinsics = []\n    for idx, key in enumerate(cam_intrinsics):\n        intrinsic = torch.eye(3, dtype=torch.float32)\n\n        intr = cam_intrinsics[key]\n        height = intr.height\n        width = intr.width\n\n        if intr.model==\"SIMPLE_PINHOLE\":\n            focal_length_x = intr.params[0]\n            focal_length_y = focal_length_x  # SIMPLE_PINHOLE has a single shared focal length\n            FovY = focal2fov(focal_length_x, height)\n            FovX = focal2fov(focal_length_x, width)\n        elif intr.model==\"PINHOLE\":\n            focal_length_x = intr.params[0]\n            focal_length_y = intr.params[1]\n            FovY = focal2fov(focal_length_y, height)\n            FovX = focal2fov(focal_length_x, width)\n        else:\n            assert False, \"Colmap camera model not handled: only undistorted 
datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!\"\n\n    \n        intrinsic[0, 0] = focal_length_x # FovX\n        intrinsic[1, 1] = focal_length_y # FovY\n        intrinsic[0, 2] = width / 2\n        intrinsic[1, 2] = height / 2\n        \n        intrinsics.append(intrinsic)\n    \n    intrinsics = torch.stack(intrinsics, axis=0)\n\n    return intrinsics\n\n\ndef test_samples(args, model, intrins=None, device='cpu'):\n    img_paths = glob.glob(f'{args.img_path}/*.png') + glob.glob(f'{args.img_path}/*.jpg') + glob.glob(f'{args.img_path}/*.JPG')\n    img_paths.sort()\n\n    # normalize\n    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n\n    intrin = load_intrinsic_colmap(args.intrins_path).to(device)\n    os.makedirs(args.output_path, exist_ok=True)\n    \n    with torch.no_grad():\n        for img_path in tqdm(img_paths):\n            ext = os.path.splitext(img_path)[1]\n            img = Image.open(img_path).convert('RGB')\n            img = np.array(img).astype(np.float32) / 255.0\n            img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).to(device)\n            _, _, orig_H, orig_W = img.shape\n\n            # zero-pad the input image so that both the width and height are multiples of 32\n            lrtb = utils.get_padding(orig_H, orig_W)\n            img = F.pad(img, lrtb, mode=\"constant\", value=0.0)\n            img = normalize(img)\n            intrins = intrin.clone()\n            intrins[:, 0, 2] += lrtb[0]\n            intrins[:, 1, 2] += lrtb[2]\n\n            pred_norm = model(img, intrins=intrins)[-1]\n            pred_norm = pred_norm[:, :, lrtb[2]:lrtb[2]+orig_H, lrtb[0]:lrtb[0]+orig_W]\n\n            # save to output folder\n            img_name = os.path.basename(img_path)\n            # NOTE: by saving the prediction as uint8 png format, you lose a lot of precision\n            # if you want to use the predicted normals for downstream tasks, we recommend saving them as float32 
NPY files\n            pred_norm_np = pred_norm.cpu().detach().numpy()[0,:,:,:].transpose(1, 2, 0) # (H, W, 3), values in [-1, 1]\n            \n            if args.vis:\n                pred_norm_np = ((pred_norm_np + 1.0) / 2.0 * 255.0).astype(np.uint8)\n                target_path = os.path.join(args.output_path, img_name.replace(ext, '.png'))\n                im = Image.fromarray(pred_norm_np)\n                im.save(target_path)\n            else:\n                target_path = os.path.join(args.output_path, img_name.replace(ext, '.npz'))\n                np.savez_compressed(target_path, pred_norm_np.astype(np.float16))\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--ckpt', default='dsine', type=str, help='path to model checkpoint')\n    parser.add_argument('--mode', default='samples', type=str, help='{samples}')\n    parser.add_argument(\"--dsine_path\", dest=\"dsine_path\", help=\"path to the DSINE code directory\")\n    parser.add_argument(\"--img_path\", dest=\"img_path\", help=\"path to the directory of RGB images (*.png, *.jpg, *.JPG)\")\n    parser.add_argument(\"--intrins_path\", dest=\"intrins_path\", help=\"path to the COLMAP scene directory (containing sparse/)\")\n    parser.add_argument(\"--output_path\", dest=\"output_path\", help=\"path to the directory where outputs should be stored\")\n    parser.add_argument('--vis', action='store_true', help='save uint8 PNG visualizations instead of float16 NPZ files')\n    args = parser.parse_args()\n    \n    dsine_path = args.dsine_path\n    dsine_path = os.path.abspath(dsine_path)\n\n    sys.path.append(dsine_path)\n\n    # define model\n    device = torch.device('cuda')\n    set_random_seed(0)\n\n    import utils.utils as utils\n    from projects import get_default_parser\n    from models.dsine.v02 import DSINE_v02 as DSINE\n    \n    cfg_path = f'{args.dsine_path}/projects/dsine/experiments/exp001_cvpr2024/dsine.txt'\n    sys.argv = [sys.argv[0], cfg_path]\n    cfg = get_args(test=True)\n    \n    model = DSINE(cfg).to(device)\n    model.pixel_coords = model.pixel_coords.to(device)\n    model = 
utils.load_checkpoint(args.ckpt, model)\n    model.eval()\n    \n    # # # Load the normal predictor model from torch hub\n    # model = torch.hub.load(\"hugoycj/DSINE-hub\", \"DSINE\", trust_repo=True)\n    \n    if args.mode == 'samples':\n        test_samples(args, model, intrins=None, device=device)\n"
  },
  {
    "path": "process_data/extract_normal_geo.py",
    "content": "# A reimplemented version in public environments by Xiao Fu and Mu Hu\n\nimport os\nimport sys\nimport logging\nimport argparse\n\nimport numpy as np\nimport torch\nfrom PIL import Image, ImageFile\nImageFile.LOAD_TRUNCATED_IMAGES = True\nfrom tqdm.auto import tqdm\n\n\nif __name__==\"__main__\":\n    \n    logging.basicConfig(level=logging.INFO)\n    \n    '''Set the Args'''\n    parser = argparse.ArgumentParser(\n        description=\"Run MonoDepthNormal Estimation using Stable Diffusion.\"\n    )\n    parser.add_argument(\"--code_path\", help=\"path to code directory\", type=str,\n                        default=\"~/code/geowizard/geowizard\")\n    parser.add_argument(\n        \"--pretrained_model_path\",\n        type=str,\n        default='lemonaddie/geowizard',\n        help=\"pretrained model path from hugging face or local dir\",\n    )    \n    parser.add_argument(\n        \"--input_dir\", type=str, required=True, help=\"Input directory.\"\n    )\n\n    parser.add_argument(\n        \"--output_dir\", type=str, required=True, help=\"Output directory.\"\n    )\n    parser.add_argument(\n        \"--domain\",\n        type=str,\n        default='indoor',\n        required=True,\n        help=\"domain prediction\",\n    )   \n\n    # inference setting\n    parser.add_argument(\n        \"--denoise_steps\",\n        type=int,\n        default=10,\n        help=\"Diffusion denoising steps, more steps results in higher accuracy but slower inference speed.\",\n    )\n    parser.add_argument(\n        \"--ensemble_size\",\n        type=int,\n        default=10,\n        help=\"Number of predictions to be ensembled, more inference gives better results but runs slower.\",\n    )\n    parser.add_argument(\n        \"--half_precision\",\n        action=\"store_true\",\n        help=\"Run with half-precision (16-bit float), might lead to suboptimal result.\",\n    )\n\n    # resolution setting\n    parser.add_argument(\n        \"--processing_res\",\n  
      type=int,\n        default=768,\n        help=\"Maximum resolution of processing. 0 for using input image resolution. Default: 768.\",\n    )\n    parser.add_argument(\n        \"--output_processing_res\",\n        action=\"store_true\",\n        help=\"When input is resized, output depth at the resized processing resolution. Default: False.\",\n    )\n\n    # depth map colormap\n    parser.add_argument(\n        \"--color_map\",\n        type=str,\n        default=\"Spectral\",\n        help=\"Colormap used to render depth predictions.\",\n    )\n    # other settings\n    parser.add_argument(\"--seed\", type=int, default=None, help=\"Random seed.\")\n    parser.add_argument(\n        \"--batch_size\",\n        type=int,\n        default=0,\n        help=\"Inference batch size. Default: 0 (falls back to 1).\",\n    )\n    \n    args = parser.parse_args()\n    # expand '~' so the default --code_path resolves to a real directory\n    sys.path.append(os.path.expanduser(args.code_path))\n    \n    from models.geowizard_pipeline import DepthNormalEstimationPipeline\n    from utils.seed_all import seed_all\n    from utils.depth2normal import *\n    \n    checkpoint_path = args.pretrained_model_path\n    output_dir = args.output_dir\n    denoise_steps = args.denoise_steps\n    ensemble_size = args.ensemble_size\n    \n    if ensemble_size>15:\n        logging.warning(\"Large ensemble size; inference will be slow.\")\n    \n    half_precision = args.half_precision\n\n    processing_res = args.processing_res\n    match_input_res = not args.output_processing_res\n    domain = args.domain\n\n    color_map = args.color_map\n    seed = args.seed\n    batch_size = args.batch_size\n    \n    if batch_size==0:\n        batch_size = 1  # set default batch size\n    \n    # -------------------- Preparation --------------------\n    # Random seed\n    if seed is None:\n        import time\n        seed = int(time.time())\n    seed_all(seed)\n\n    # Output directories\n    output_dir_color = os.path.join(output_dir, f\"depth_colored_{domain}\")\n    # output_dir_npy = 
os.path.join(output_dir, \"depth_npy\")\n    # output_dir_normal_npy = os.path.join(output_dir, \"normal_npy\")\n    output_dir_npy = os.path.join(output_dir, f\"depth_npz_{domain}\")\n    output_dir_normal_npy = os.path.join(output_dir, f\"normal_npz_{domain}\")\n    output_dir_normal_color = os.path.join(output_dir, f\"normal_colored_{domain}\")\n    os.makedirs(output_dir, exist_ok=True)\n    os.makedirs(output_dir_color, exist_ok=True)\n    os.makedirs(output_dir_npy, exist_ok=True)\n    os.makedirs(output_dir_normal_npy, exist_ok=True)\n    os.makedirs(output_dir_normal_color, exist_ok=True)\n    logging.info(f\"output dir = {output_dir}\")\n\n    # -------------------- Device --------------------\n    if torch.cuda.is_available():\n        device = torch.device(\"cuda\")\n    else:\n        device = torch.device(\"cpu\")\n        logging.warning(\"CUDA is not available. Running on CPU will be slow.\")\n    logging.info(f\"device = {device}\")\n\n    # -------------------- Data --------------------\n    input_dir = args.input_dir\n    test_files = sorted(os.listdir(input_dir))\n    n_images = len(test_files)\n    if n_images > 0:\n        logging.info(f\"Found {n_images} images\")\n    else:\n        logging.error(f\"No image found\")\n        exit(1)\n\n    # -------------------- Model --------------------\n    if half_precision:\n        dtype = torch.float16\n        logging.info(f\"Running with half precision ({dtype}).\")\n    else:\n        dtype = torch.float32\n\n    # declare a pipeline\n    pipe = DepthNormalEstimationPipeline.from_pretrained(checkpoint_path, torch_dtype=dtype)\n\n    logging.info(\"loading pipeline whole successfully.\")\n    \n    try:\n        pipe.enable_xformers_memory_efficient_attention()\n    except:\n        pass  # run without xformers\n\n    pipe = pipe.to(device)\n\n    # -------------------- Inference and saving --------------------\n    with torch.no_grad():\n        os.makedirs(output_dir, exist_ok=True)\n\n        for 
test_file in tqdm(test_files, desc=\"Estimating Depth & Normal\", leave=True):\n            rgb_path = os.path.join(input_dir, test_file)\n            rgb_name_base = os.path.splitext(os.path.basename(rgb_path))[0]\n            pred_name_base = rgb_name_base # + \"_pred\"\n\n            normal_npz_save_path = os.path.join(output_dir_normal_npy, f\"{pred_name_base}.npz\")\n            if os.path.exists(normal_npz_save_path):\n                continue\n                # logging.warning(f\"Existing file: '{normal_npz_save_path}' will be overwritten\")\n                \n            # Read input image\n            input_image = Image.open(rgb_path)\n\n            # predict the depth here\n            pipe_out = pipe(input_image,\n                denoising_steps = denoise_steps,\n                ensemble_size= ensemble_size,\n                processing_res = processing_res,\n                match_input_res = match_input_res,\n                domain = domain,\n                color_map = color_map,\n                show_progress_bar = False,\n            )\n\n            depth_pred: np.ndarray = pipe_out.depth_np\n            depth_colored: Image.Image = pipe_out.depth_colored\n            normal_pred: np.ndarray = pipe_out.normal_np\n            normal_colored: Image.Image = pipe_out.normal_colored\n\n            # Save as npy\n            # npy_save_path = os.path.join(output_dir_npy, f\"{pred_name_base}.npy\")\n            npy_save_path = os.path.join(output_dir_npy, f\"{pred_name_base}.npz\")\n            if os.path.exists(npy_save_path):\n                logging.warning(f\"Existing file: '{npy_save_path}' will be overwritten\")\n            # np.save(npy_save_path, depth_pred)\n            np.savez_compressed(npy_save_path, depth_pred)\n\n            # normal_npy_save_path = os.path.join(output_dir_normal_npy, f\"{pred_name_base}.npy\")\n            normal_npz_save_path = os.path.join(output_dir_normal_npy, f\"{pred_name_base}.npz\")\n            if 
os.path.exists(normal_npz_save_path):\n                logging.warning(f\"Existing file: '{normal_npz_save_path}' will be overwritten\")\n            # np.save(normal_npy_save_path, normal_pred)\n            np.savez_compressed(normal_npz_save_path, normal_pred)\n\n            # Colorize\n            # depth_colored_save_path = os.path.join(output_dir_color, f\"{pred_name_base}_colored.png\")\n            depth_colored_save_path = os.path.join(output_dir_color, f\"{pred_name_base}.png\")\n            if os.path.exists(depth_colored_save_path):\n                logging.warning(\n                    f\"Existing file: '{depth_colored_save_path}' will be overwritten\"\n                )\n            depth_colored.save(depth_colored_save_path)\n\n            normal_colored_save_path = os.path.join(output_dir_normal_color, f\"{pred_name_base}_colored.png\")\n            if os.path.exists(normal_colored_save_path):\n                logging.warning(\n                    f\"Existing file: '{normal_colored_save_path}' will be overwritten\"\n                )\n            normal_colored.save(normal_colored_save_path)\n"
  },
  {
    "path": "process_data/visualize_colmap.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"8b8d7b17-af50-42cd-b531-ef61c49c9e61\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Set the work directory to the imaginaire root.\\n\",\n    \"import os, sys, time\\n\",\n    \"import pathlib\\n\",\n    \"\\n\",\n    \"root_dir = pathlib.Path().absolute().parents[0]\\n\",\n    \"os.chdir(root_dir)\\n\",\n    \"print(f\\\"Root Directory Path: {root_dir}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"id\": \"2b5b9e2f-841c-4815-92e0-0c76ed46da62\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Import Python libraries.\\n\",\n    \"import numpy as np\\n\",\n    \"import torch\\n\",\n    \"import k3d\\n\",\n    \"import json\\n\",\n    \"import trimesh\\n\",\n    \"import plotly.graph_objs as go\\n\",\n    \"from collections import OrderedDict\\n\",\n    \"# Import imaginaire modules.\\n\",\n    \"from submodules.colmap.scripts.python.read_write_model import read_model\\n\",\n    \"# from tools import camera, visualize\\n\",\n    \"from tools.camera import quaternion\\n\",\n    \"from tools.visualize import k3d_visualize_pose, plotly_visualize_pose\\n\",\n    \"from process_data.convert_tnt_to_json import load_transformation, align_gt_with_cam\\n\",\n    \"from tools.camera_utils import cubic_camera, grid_camera, around_camera, up_camera, bb_camera\\n\",\n    \"from tools.math_utils import inv_normalize_pts\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"76033016-2d92-4a5d-9e50-3978553e8df4\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Read the COLMAP data.\\n\",\n    \"# colmap_path = \\\"datasets/lego_ds2\\\"\\n\",\n    \"scene = 'Barn'\\n\",\n    \"colmap_path = f\\\"/your/path/tnt/{scene}\\\"\\n\",\n    \"# read piont clouds from lidar # point cloud\\n\",\n    \"pcd = 
trimesh.load(os.path.join(colmap_path, '{}.ply'.format(colmap_path.split('/')[-1])))\\n\",\n    \"# scene = 'c49a8c6cff'\\n\",\n    \"# colmap_path = f\\\"/your/path/ScanNet++/{scene}/dslr\\\"\\n\",\n    \"# pcd = trimesh.load(os.path.join(colmap_path, '../scans/mesh_aligned_0.05.ply'))\\n\",\n    \"view_sample_camera = False\\n\",\n    \"cameras, images, points_3D = read_model(path=f\\\"{colmap_path}/sparse\\\", ext=\\\".bin\\\") # w2c extrinsics\\n\",\n    \"# Convert camera poses.\\n\",\n    \"images = OrderedDict(sorted(images.items()))\\n\",\n    \"qvecs = torch.from_numpy(np.stack([image.qvec for image in images.values()]))\\n\",\n    \"tvecs = torch.from_numpy(np.stack([image.tvec for image in images.values()]))\\n\",\n    \"# Rs = camera.quaternion.q_to_R(qvecs)\\n\",\n    \"Rs = quaternion.q_to_R(qvecs)\\n\",\n    \"poses = torch.cat([Rs, tvecs[..., None]], dim=-1)  # [N,3,4]  w2c\\n\",\n    \"print(f\\\"# images: {len(poses)}\\\")\\n\",\n    \"print(\\\"camera height: {}\\\".format(poses[:, 1, 3].mean()))\\n\",\n    \"\\n\",\n    \"# # Get the sparse 3D points and the colors. 
colmap\\n\",\n    \"# xyzs = torch.from_numpy(np.stack([point.xyz for point in points_3D.values()]))\\n\",\n    \"# rgbs = np.stack([point.rgb for point in points_3D.values()])\\n\",\n    \"# rgbs_int32 = (rgbs[:, 0] * 2**16 + rgbs[:, 1] * 2**8 + rgbs[:, 2]).astype(np.uint32)\\n\",\n    \"# print(f\\\"# points: {len(xyzs)}\\\")\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"if os.path.exists(os.path.join(colmap_path, f'{scene}_trans.txt')):\\n\",\n    \"    trans = load_transformation(os.path.join(colmap_path, f'{scene}_trans.txt'))\\n\",\n    \"    pcd.vertices = align_gt_with_cam(pcd.vertices, trans)\\n\",\n    \"    \\n\",\n    \"xyzs = pcd.vertices[::500]\\n\",\n    \"# xyzs = pcd.vertices\\n\",\n    \"rgbs = np.random.randint(0, 255, xyzs.shape)\\n\",\n    \"rgbs_int32 = (rgbs[:, 0] * 2**16 + rgbs[:, 1] * 2**8 + rgbs[:, 2]).astype(np.uint32)\\n\",\n    \"print(f\\\"# points: {len(xyzs)}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 4,\n   \"id\": \"47862ee1-286c-4877-a181-4b33b7733719\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis_depth = 0.2\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 5,\n   \"id\": \"b6cf60ec-fe6a-43ba-9aaf-e3c7afd88208\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Visualize the bounding sphere.\\n\",\n    \"json_fname = f\\\"{colmap_path}/meta.json\\\"\\n\",\n    \"with open(json_fname) as file:\\n\",\n    \"    meta = json.load(file)\\n\",\n    \"trans = np.array(meta[\\\"trans\\\"])\\n\",\n    \"scale = np.array(meta[\\\"scale\\\"])\\n\",\n    \"# ------------------------------------------------------------------------------------\\n\",\n    \"# These variables can be adjusted to make the bounding sphere fit the region of interest.\\n\",\n    \"# The adjusted values can then be set in the config as data.readjust.center and data.readjust.scale\\n\",\n    \"readjust_center = np.array([0., 0., 0.])\\n\",\n    \"readjust_scale = 
np.array([1., 1., 1.]) # * 1.1\\n\",\n    \"# save adjusted values\\n\",\n    \"readjust = {\\n\",\n    \"    'scale': readjust_scale.tolist(),\\n\",\n    \"    'trans': readjust_center.tolist()\\n\",\n    \"}\\n\",\n    \"redjust_fname = f'{colmap_path}/readjust.json'\\n\",\n    \"with open(redjust_fname, \\\"w\\\") as outputfile:\\n\",\n    \"    json.dump(readjust, outputfile, indent=2)\\n\",\n    \"# ------------------------------------------------------------------------------------\\n\",\n    \"if trans.ndim == 1:\\n\",\n    \"    trans += readjust_center\\n\",\n    \"scale *= readjust_scale\\n\",\n    \"# Make some points to hallucinate a bounding sphere.\\n\",\n    \"# sphere_points = np.random.randn(100000, 3)\\n\",\n    \"sphere_points = np.random.rand(100000, 3) * 2 - 1\\n\",\n    \"# sphere_points = sphere_points / np.linalg.norm(sphere_points, axis=-1, keepdims=True) # Unit sphere\\n\",\n    \"# sphere_points[:, 0] = -1 # up\\n\",\n    \"for i in range(3): sphere_points[i::3, i] = sphere_points[i::3, i] / np.abs(sphere_points[i::3, i]) # Unit cube\\n\",\n    \"sphere_points = np.concatenate([sphere_points, np.zeros([1, 3])], axis=0) # center point\\n\",\n    \"# sphere_points[-1, 0] = 5\\n\",\n    \"\\n\",\n    \"sphere_points = inv_normalize_pts(sphere_points, trans, scale)\\n\",\n    \"\\n\",\n    \"# sphere_points[:, 1] = -1.1\\n\",\n    \"\\n\",\n    \"# sample up cameras\\n\",\n    \"if view_sample_camera:\\n\",\n    \"    height = poses[:, 1, 3].mean()\\n\",\n    \"    # height = -1\\n\",\n    \"    # sample_poses = cubic_camera(200, trans, scale)\\n\",\n    \"    # sample_poses = around_camera(500, trans, scale, height)\\n\",\n    \"    # sample_poses = bb_camera(500, trans, scale, height, up=False, around=True)\\n\",\n    \"    sample_poses = bb_camera(200, trans, scale, height=height, up=True, around=True, bidirect=True) # , look_mode='direction'\\n\",\n    \"    # sample_poses = up_camera(500, trans, scale)\\n\",\n    \"    # sample_poses = 
grid_camera(trans, scale)\\n\",\n    \"\\n\",\n    \"    # sample_poses = torch.from_numpy(poses[:, :3])\\n\",\n    \"    sample_poses = sample_poses[:, :3]\\n\",\n    \"\\n\",\n    \"    # poses = torch.cat([poses, sample_poses], dim=0)\\n\",\n    \"    poses = sample_poses # [::6]\\n\",\n    \"    # print(f\\\"# poses: {len(poses)}\\\")\\n\",\n    \"\\n\",\n    \"    # print(f\\\"center: {trans[:3, 3:].T}\\\")\\n\",\n    \"    # print(f\\\"scale: {scale}\\\")\\n\",\n    \"    # print(\\\"up: {}\\\".format(trans[1, 3] - scale[1] * 0.5))\\n\",\n    \"    # print(f\\\"max: {sphere_points.max(0)}\\\")\\n\",\n    \"    # print(f\\\"min: {sphere_points.min(0)}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"e986aed0-1aaf-4772-937c-136db7f2eaec\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# You can choose to visualize with Plotly...\\n\",\n    \"x, y, z = *xyzs.T,\\n\",\n    \"colors = rgbs / 255.0\\n\",\n    \"sphere_x, sphere_y, sphere_z = *sphere_points.T,\\n\",\n    \"sphere_colors = [\\\"#4488ff\\\"] * len(sphere_points)\\n\",\n    \"sphere_size = [0.5] * len(sphere_points)\\n\",\n    \"sphere_colors[-1] = \\\"#ff0000\\\" # #ff4444 center point\\n\",\n    \"# sphere_size[-1] = 5\\n\",\n    \"# traces_poses = visualize.plotly_visualize_pose(poses, vis_depth=vis_depth, xyz_length=0.02, center_size=0.01, xyz_width=0.005, mesh_opacity=0.05)\\n\",\n    \"traces_poses = plotly_visualize_pose(poses, vis_depth=vis_depth, xyz_length=0.02, center_size=0.01, xyz_width=0.005, mesh_opacity=0.05)\\n\",\n    \"trace_points = go.Scatter3d(x=x, y=y, z=z, mode=\\\"markers\\\", marker=dict(size=0.4, color=colors, opacity=0.7), hoverinfo=\\\"skip\\\")\\n\",\n    \"trace_sphere = go.Scatter3d(x=sphere_x, y=sphere_y, z=sphere_z, mode=\\\"markers\\\", marker=dict(size=sphere_size, color=sphere_colors, opacity=0.7), hoverinfo=\\\"skip\\\")\\n\",\n    \"traces_all = traces_poses + [trace_points, 
trace_sphere]\\n\",\n    \"layout = go.Layout(scene=dict(xaxis=dict(showspikes=False, backgroundcolor=\\\"rgba(0,0,0,0)\\\", gridcolor=\\\"rgba(0,0,0,0.1)\\\"),\\n\",\n    \"                              yaxis=dict(showspikes=False, backgroundcolor=\\\"rgba(0,0,0,0)\\\", gridcolor=\\\"rgba(0,0,0,0.1)\\\"),\\n\",\n    \"                              zaxis=dict(showspikes=False, backgroundcolor=\\\"rgba(0,0,0,0)\\\", gridcolor=\\\"rgba(0,0,0,0.1)\\\"),\\n\",\n    \"                              xaxis_title=\\\"X\\\", yaxis_title=\\\"Y\\\", zaxis_title=\\\"Z\\\", dragmode=\\\"orbit\\\",\\n\",\n    \"                              aspectratio=dict(x=1, y=1, z=1), aspectmode=\\\"data\\\"), height=800)\\n\",\n    \"fig = go.Figure(data=traces_all, layout=layout)\\n\",\n    \"fig.show()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"fdde170b-4546-4617-9162-a9fcb936347d\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# ... or visualize with K3D.\\n\",\n    \"plot = k3d.plot(name=\\\"poses\\\", height=800, camera_rotate_speed=5.0, camera_zoom_speed=3.0, camera_pan_speed=1.0)\\n\",\n    \"# k3d_objects = visualize.k3d_visualize_pose(poses, vis_depth=vis_depth, xyz_length=0.02, center_size=0.01, xyz_width=0.005, mesh_opacity=0.05)\\n\",\n    \"k3d_objects = k3d_visualize_pose(poses, vis_depth=vis_depth, xyz_length=0.02, center_size=0.01, xyz_width=0.005, mesh_opacity=0.05)\\n\",\n    \"for k3d_object in k3d_objects:\\n\",\n    \"    plot += k3d_object\\n\",\n    \"plot += k3d.points(xyzs, colors=rgbs_int32, point_size=0.02, shader=\\\"flat\\\")\\n\",\n    \"plot += k3d.points(sphere_points, color=0x4488ff, point_size=0.01, shader=\\\"flat\\\")\\n\",\n    \"plot.display()\\n\",\n    \"plot.camera_fov = 30.0\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3 (ipykernel)\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   
\"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.13\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "process_data/visualize_transforms.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"8b8d7b17-af50-42cd-b531-ef61c49c9e61\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Set the work directory to the imaginaire root.\\n\",\n    \"import os, sys, time\\n\",\n    \"import pathlib\\n\",\n    \"root_dir = pathlib.Path().absolute().parents[2]\\n\",\n    \"os.chdir(root_dir)\\n\",\n    \"print(f\\\"Root Directory Path: {root_dir}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"2b5b9e2f-841c-4815-92e0-0c76ed46da62\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Import Python libraries.\\n\",\n    \"import numpy as np\\n\",\n    \"import torch\\n\",\n    \"import k3d\\n\",\n    \"import json\\n\",\n    \"from collections import OrderedDict\\n\",\n    \"# Import imaginaire modules.\\n\",\n    \"from projects.nerf.utils import camera, visualize\\n\",\n    \"from third_party.colmap.scripts.python.read_write_model import read_model\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"97bedecf-da68-44b1-96cf-580ef7e7f3f0\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Read the COLMAP data.\\n\",\n    \"colmap_path = \\\"datasets/lego_ds2\\\"\\n\",\n    \"json_fname = f\\\"{colmap_path}/transforms.json\\\"\\n\",\n    \"with open(json_fname) as file:\\n\",\n    \"    meta = json.load(file)\\n\",\n    \"center = meta[\\\"sphere_center\\\"]\\n\",\n    \"radius = meta[\\\"sphere_radius\\\"]\\n\",\n    \"# Convert camera poses.\\n\",\n    \"poses = []\\n\",\n    \"for frame in meta[\\\"frames\\\"]:\\n\",\n    \"    c2w = torch.tensor(frame[\\\"transform_matrix\\\"])\\n\",\n    \"    c2w[:, 1:3] *= -1\\n\",\n    \"    w2c = c2w.inverse()\\n\",\n    \"    pose = w2c[:3]  # [3,4]\\n\",\n    \"    poses.append(pose)\\n\",\n    \"poses = torch.stack(poses, dim=0)\\n\",\n    \"print(f\\\"# images: 
{len(poses)}\\\")\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"2016d20c-1e58-407f-9810-cbe76dc5ccec\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"vis_depth = 0.2\\n\",\n    \"k3d_textures = []\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"d7168a09-6654-4660-b140-66b9dfd6f1e8\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# (optional) visualize the images.\\n\",\n    \"# This block can be skipped if we don't want to visualize the image observations.\\n\",\n    \"for i, frame in enumerate(meta[\\\"frames\\\"]):\\n\",\n    \"    image_fname = frame[\\\"file_path\\\"]\\n\",\n    \"    image_path = f\\\"{colmap_path}/{image_fname}\\\"\\n\",\n    \"    with open(image_path, \\\"rb\\\") as file:\\n\",\n    \"        binary = file.read()\\n\",\n    \"    # Compute the corresponding image corners in 3D.\\n\",\n    \"    pose = poses[i]\\n\",\n    \"    corners = torch.tensor([[-0.5, 0.5, 1], [0.5, 0.5, 1], [-0.5, -0.5, 1]])\\n\",\n    \"    corners *= vis_depth\\n\",\n    \"    corners = camera.cam2world(corners, pose)\\n\",\n    \"    puv = [corners[0].tolist(), (corners[1]-corners[0]).tolist(), (corners[2]-corners[0]).tolist()]\\n\",\n    \"    k3d_texture = k3d.texture(binary, file_format=\\\"jpg\\\", puv=puv)\\n\",\n    \"    k3d_textures.append(k3d_texture)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"b6cf60ec-fe6a-43ba-9aaf-e3c7afd88208\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Visualize the bounding sphere.\\n\",\n    \"json_fname = f\\\"{colmap_path}/transforms.json\\\"\\n\",\n    \"with open(json_fname) as file:\\n\",\n    \"    meta = json.load(file)\\n\",\n    \"center = meta[\\\"sphere_center\\\"]\\n\",\n    \"radius = meta[\\\"sphere_radius\\\"]\\n\",\n    \"# 
------------------------------------------------------------------------------------\\n\",\n    \"# These variables can be adjusted to make the bounding sphere fit the region of interest.\\n\",\n    \"# The adjusted values can then be set in the config as data.readjust.center and data.readjust.scale\\n\",\n    \"readjust_center = np.array([0., 0., 0.])\\n\",\n    \"readjust_scale = 1.\\n\",\n    \"# ------------------------------------------------------------------------------------\\n\",\n    \"center += readjust_center\\n\",\n    \"radius *= readjust_scale\\n\",\n    \"# Make some points to hallucinate a bounding sphere.\\n\",\n    \"sphere_points = np.random.randn(100000, 3)\\n\",\n    \"sphere_points = sphere_points / np.linalg.norm(sphere_points, axis=-1, keepdims=True)\\n\",\n    \"sphere_points = sphere_points * radius + center\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"id\": \"fdde170b-4546-4617-9162-a9fcb936347d\",\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"# Visualize with K3D.\\n\",\n    \"plot = k3d.plot(name=\\\"poses\\\", height=800, camera_rotate_speed=5.0, camera_zoom_speed=3.0, camera_pan_speed=1.0)\\n\",\n    \"k3d_objects = visualize.k3d_visualize_pose(poses, vis_depth=vis_depth, xyz_length=0.02, center_size=0.01, xyz_width=0.005, mesh_opacity=0.)\\n\",\n    \"for k3d_object in k3d_objects:\\n\",\n    \"    plot += k3d_object\\n\",\n    \"for k3d_texture in k3d_textures:\\n\",\n    \"    plot += k3d_texture\\n\",\n    \"plot += k3d.points(sphere_points, color=0x4488ff, point_size=0.01, shader=\\\"flat\\\")\\n\",\n    \"plot.display()\\n\",\n    \"plot.camera_fov = 30.0\"\n   ]\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"Python 3 (ipykernel)\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   
\"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": \"ipython3\",\n   \"version\": \"3.8.13\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 5\n}\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[tool.black]\nline-length = 240\n\n[build-system]\nrequires = [\"setuptools>=61.0\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"vcr-gaus\"\nversion = \"0.0.0.dev0\"\ndescription = \"VCR-GauS: View Consistent Depth-Normal Regularizer for Gaussian Surface Reconstruction\"\nreadme = \"README.md\"\nrequires-python = \">=3.8\"\nclassifiers = [\n    \"Programming Language :: Python :: 3\",\n    \"License :: OSI Approved :: Apache Software License\",\n]\n\n[project.optional-dependencies]\n\nf1eval = [\n    \"open3d==0.10.0\",\n    \"numpy\"\n]\n\ntrain = [\n    \"torch==2.0.1\",\n    \"torchvision==0.15.2\",\n    \"torchaudio==2.0.2\",\n    \"numpy==1.26.1\",\n    \"open3d\",\n    \"plyfile\",\n    \"ninja\",\n    \"GPUtil\",\n    \"opencv-python\",\n    \"lpips\",\n    \"trimesh\",\n    \"pymeshlab\",\n    \"termcolor\",\n    \"wandb\",\n    \"imageio\",\n    \"scikit-image\",\n    \"torchmetrics\",\n    \"mediapy\",\n]\n\n[project.urls]\n\"Homepage\" = \"https://hlinchen.github.io/projects/VCR-GauS/\"\n\"Bug Tracker\" = \"https://github.com/HLinChen/VCR-GauS/issues\"\n\n[tool.setuptools.packages.find]\ninclude = [\"vcr*\", \"trl*\"]\nexclude = [\n    \"assets*\",\n    \"benchmark*\",\n    \"docs\",\n    \"dist*\",\n    \"playground*\",\n    \"scripts*\",\n    \"tests*\",\n    \"checkpoints*\",\n    \"project_checkpoints*\",\n    \"debug_checkpoints*\",\n    \"mlx_configs*\",\n    \"wandb*\",\n    \"notebooks*\",\n]\n\n[tool.wheel]\nexclude = [\n    \"assets*\",\n    \"benchmark*\",\n    \"docs\",\n    \"dist*\",\n    \"playground*\",\n    \"scripts*\",\n    \"tests*\",\n    \"checkpoints*\",\n    \"project_checkpoints*\",\n    \"debug_checkpoints*\",\n    \"mlx_configs*\",\n    \"wandb*\",\n    \"notebooks*\",\n]\n"
  },
  {
    "path": "python_scripts/run_base.py",
    "content": "import os\nimport time\nimport GPUtil\n\n\ndef worker(gpu, scene, factor, fn):\n    print(f\"Starting job on GPU {gpu} with scene {scene}\\n\")\n    fn(gpu, scene, factor)\n    print(f\"Finished job on GPU {gpu} with scene {scene}\\n\")\n    # This worker function starts a job and returns when it's done.\n\n\ndef dispatch_jobs(jobs, executor, excluded_gpus, fn):\n    future_to_job = {}\n    reserved_gpus = set()  # GPUs that are slated for work but may not be active yet\n\n    while jobs or future_to_job:\n        # Get the list of available GPUs, not including those that are reserved.\n        all_available_gpus = set(GPUtil.getAvailable(order=\"first\", limit=10, maxMemory=0.1, maxLoad=0.1))\n        available_gpus = list(all_available_gpus - reserved_gpus - excluded_gpus)\n        \n        # Launch new jobs on available GPUs\n        while available_gpus and jobs:\n            gpu = available_gpus.pop(0)\n            job = jobs.pop(0)\n            future = executor.submit(worker, gpu, *job, fn)  # Unpacking job as arguments to worker\n            future_to_job[future] = (gpu, job)\n\n            reserved_gpus.add(gpu)  # Reserve this GPU until its job finishes\n\n        # Check for completed jobs and remove them from the list of running jobs.\n        # Also, release the GPUs they were using.\n        done_futures = [future for future in future_to_job if future.done()]\n        for future in done_futures:\n            job = future_to_job.pop(future)  # Remove the (gpu, job) entry for the completed future\n            gpu = job[0]  # The GPU is the first element of the stored (gpu, job) tuple\n            reserved_gpus.discard(gpu)  # Release this GPU\n            print(f\"Job {job} has finished, releasing GPU {gpu}\")\n        # Sleep briefly so this loop does not spin while jobs are running or no GPUs are available.\n        time.sleep(5)\n        \n    print(\"All jobs have been 
processed.\")\n\n\ndef check_finish(scene, path, type='mesh'):\n    if not os.path.exists(path):\n        print(f\"Scene \\033[1;31m{scene}\\033[0m failed in \\033[1;31m{type}\\033[0m\")\n        return False\n    return True\n\n\ntrain_cmd = \"OMP_NUM_THREADS=4 CUDA_VISIBLE_DEVICES={gpu} \\\n        python train.py \\\n            --config=configs/{dataset}/{cfg}.yaml \\\n            --logdir={log_dir} \\\n            --model.source_path={data_dir}/{scene}/ \\\n            --train.debug_from={debug_from} \\\n            --model.data_device={data_device} \\\n            --model.resolution={resolution} \\\n            --wandb \\\n            --wandb_name {project}\"\n\n\ntrain_cmd_new = \"OMP_NUM_THREADS=4 CUDA_VISIBLE_DEVICES={gpu} \\\n        python train.py \\\n            --config={cfg} \\\n            --logdir={log_dir} \\\n            --model.source_path={data_dir}/{scene}/ \\\n            --train.debug_from={debug_from} \\\n            --model.data_device={data_device} \\\n            --model.resolution={resolution} \\\n            --wandb \\\n            --wandb_name {project}\"\n\n\nextract_mesh_cmd = \"OMP_NUM_THREADS=4 CUDA_VISIBLE_DEVICES={gpu} \\\n        python tools/depth2mesh.py \\\n                --mesh_name {ply} \\\n                --split {step} \\\n                --method {fuse_method} \\\n                --voxel_size {voxel_size} \\\n                --num_cluster {num_cluster} \\\n                --max_depth {max_depth} \\\n                --clean \\\n                --prob_thres {prob_thr} \\\n                --cfg_path {log_dir}/config.yaml\"\n\n\neval_tnt_cmd = \"OMP_NUM_THREADS={num_threads} CUDA_VISIBLE_DEVICES={gpu} \\\n                conda run -n {eval_env} \\\n                python evaluation/tnt_eval/run.py \\\n                    --dataset-dir {data_dir}/{scene}/ \\\n                    --traj-path {data_dir}/{scene}/{scene}_COLMAP_SfM.log \\\n                    --ply-path {log_dir}/{ply} > {log_dir}/fscore.txt\"\n\n\neval_cd_cmd 
= \"OMP_NUM_THREADS={num_threads}  CUDA_VISIBLE_DEVICES={gpu} \\\n                python evaluation/eval_dtu/evaluate_single_scene.py \\\n                    --input_mesh {tri_mesh_path} \\\n                    --scan_id {scan_id} --output_dir {output_dir} \\\n                    --mask_dir {data_dir} \\\n                    --DTU {data_dir}\"\n\n\nrender_cmd = \"CUDA_VISIBLE_DEVICES={gpu} \\\n                python evaluation/render.py \\\n                    --cfg_path {log_dir}/config.yaml \\\n                    --iteration 30000 \\\n                    --skip_train\"\n\neval_psnr_cmd = \"CUDA_VISIBLE_DEVICES={gpu} \\\n                    python evaluation/metrics.py \\\n                    --cfg_path {log_dir}/config.yaml\"\n\neval_replica_cmd = \"OMP_NUM_THREADS={num_threads} CUDA_VISIBLE_DEVICES={gpu} \\\n                    python evaluation/replica_eval/evaluate_single_scene.py \\\n                        --input_mesh {tri_mesh_path} \\\n                        --scene {scene} \\\n                        --output_dir {output_dir} \\\n                        --data_dir {data_dir}\""
  },
  {
    "path": "python_scripts/run_dtu.py",
    "content": "# Training script for the DTU dataset\nimport os\nimport sys\nimport time\nfrom concurrent.futures import ThreadPoolExecutor\n\nsys.path.append(os.getcwd())\nfrom python_scripts.run_base import dispatch_jobs, train_cmd, extract_mesh_cmd, eval_cd_cmd, check_finish\nfrom python_scripts.show_dtu import show_matrix\n\nTRIAL_NAME = 'vcr_gaus'\nPROJECT = 'vcr_gaus'\nPROJECT_wandb = 'vcr_gaus_dtu'\nDATASET = 'dtu'\nbase_dir = \"/your/path\"\noutput_dir = f\"{base_dir}/output/{PROJECT}/{DATASET}\"\ndata_dir = f\"{base_dir}/data/DTU_mask\"\n\ndo_train = False\ndo_extract_mesh = False\ndo_cd = True\ndry_run = False\n\nnode = 0\nmax_workers = 15\nbe = node*max_workers\nexcluded_gpus = set([])\n\ntotal_list = [\n        'scan24', 'scan37', 'scan40', 'scan55', 'scan63', 'scan65', 'scan69', \n        'scan83', 'scan97', 'scan105', 'scan106', 'scan110', 'scan114', 'scan118', 'scan122'\n]\ntraining_list = [\n        'scan24', 'scan37', 'scan40', 'scan55', 'scan63', 'scan65', 'scan69', \n        'scan83', 'scan97', 'scan105', 'scan106', 'scan110', 'scan114', 'scan118', 'scan122'\n]\n\ntraining_list = training_list[be: be + max_workers]\nscenes = training_list\n\nfactors = [-1] * len(scenes)\ndebug_from = -1\n\neval_env = 'pt'\ndata_device = 'cuda'\nvoxel_size = 0.004\nstep = 1\nPLY = \"ours.ply\"\nTOTAL_THREADS = 64\nNUM_THREADS = TOTAL_THREADS // max_workers\nprob_thr = 0.15\nnum_cluster = 1\nmax_depth = 3\nfuse_method = 'tsdf_cpu'\n\njobs = list(zip(scenes, factors))\n\ndef train_scene(gpu, scene, factor):\n    time.sleep(2*gpu)\n    os.system('ulimit -n 9000')\n    log_dir = f\"{output_dir}/{scene}/{TRIAL_NAME}\"\n    \n    fail = 0\n    \n    if not dry_run:\n        if do_train:\n            cmd = train_cmd.format(gpu=gpu, dataset=DATASET, cfg='base',\n                                scene=scene, log_dir=log_dir, \n                                data_dir=data_dir, debug_from=debug_from, \n                                data_device=data_device, 
resolution=factor, project=PROJECT_wandb)\n            print(cmd)\n            fail = os.system(cmd)\n\n        if fail == 0:\n            if not dry_run:\n                # fusion\n                if do_extract_mesh:\n                    if not check_finish(scene, f\"{log_dir}/point_cloud\", 'train'): return False\n                    cmd = extract_mesh_cmd.format(gpu=gpu, ply=PLY, step=step,  fuse_method=fuse_method, voxel_size=voxel_size, num_cluster=num_cluster, max_depth=max_depth, log_dir=log_dir, prob_thr=prob_thr)\n                    fail = os.system(cmd)\n                    print(cmd)\n        \n                # evaluation\n                # evaluate the mesh\n                scan_id = scene[4:]\n                cmd = eval_cd_cmd.format(num_threads=NUM_THREADS, gpu=gpu, tri_mesh_path=f'{log_dir}/{PLY}', scan_id=scan_id, output_dir=log_dir, data_dir=data_dir)\n                if fail == 0:\n                    if not dry_run:\n                        if do_cd: \n                            if not check_finish(scene, f\"{log_dir}/{PLY}\", 'mesh'): return False\n                            print(cmd)\n                            fail = os.system(cmd)\n                            if not check_finish(scene, f\"{log_dir}/results.json\", 'cd'): return False\n    return fail == 0\n\n\n# Using ThreadPoolExecutor to manage the thread pool\nwith ThreadPoolExecutor(max_workers) as executor:\n    dispatch_jobs(jobs, executor, excluded_gpus, train_scene)\n\nshow_matrix(total_list, [output_dir], TRIAL_NAME)\nprint(TRIAL_NAME, \" done\")"
  },
  {
    "path": "python_scripts/run_mipnerf360.py",
    "content": "# Training script for the Mip-NeRF 360 dataset\nimport os\nimport sys\nimport time\nfrom concurrent.futures import ThreadPoolExecutor\n\nsys.path.append(os.getcwd())\nfrom python_scripts.run_base import dispatch_jobs, train_cmd, extract_mesh_cmd, check_finish, render_cmd, eval_psnr_cmd\nfrom python_scripts.show_360 import show_matrix\n\n\nTRIAL_NAME = 'vcr_gaus'\nPROJECT = 'vcr_gaus'\nPROJECT_wandb = 'vcr_gaus_360'\n\ndo_train = True\ndo_render = True\ndo_eval = True\ndo_extract_mesh = True\ndry_run = False\n\nnode = 0\nmax_workers = 9\nbe = node*max_workers\nexcluded_gpus = set([])\n\ntotal_list = [\n        \"bicycle\", \"bonsai\", \"counter\", \"flowers\", \"garden\", \"stump\", \"treehill\", \"kitchen\", \"room\"\n]\ntraining_list = [\n        \"bicycle\", \"bonsai\", \"counter\", \"flowers\", \"garden\", \"stump\", \"treehill\", \"kitchen\", \"room\"\n]\n\ntraining_list = training_list[be: be + max_workers]\nscenes = training_list\n\nfactors = [-1] * len(scenes)\n\ndebug_from = -1\n\nDATASET = '360_v2'\neval_env = 'pt'\ndata_device = 'cpu'\nstep = 1\nmax_depth = 6.0\nvoxel_size = 8e-3\nPLY = f\"fused_mesh_split{step}.ply\"\nTOTAL_THREADS = 64\nNUM_THREADS = TOTAL_THREADS // max_workers\nprob_thr = 0.15\nnum_cluster = 1000\nfuse_method = 'tsdf'\n\nbase_dir = \"/your/path\"\noutput_dir = f\"{base_dir}/output/{PROJECT}/{DATASET}\"\ndata_dir = f\"{base_dir}/data/{DATASET}\"\n\njobs = list(zip(scenes, factors))\n\n\ndef train_scene(gpu, scene, factor):\n    time.sleep(2*gpu)\n    os.system('ulimit -n 9000')\n    log_dir = f\"{output_dir}/{scene}/{TRIAL_NAME}\"\n    \n    fail = 0\n    \n    if not dry_run:\n        if do_train:\n            cmd = train_cmd.format(gpu=gpu, dataset=DATASET, cfg='base',\n                                scene=scene, log_dir=log_dir, \n                                data_dir=data_dir, debug_from=debug_from, \n                                data_device=data_device, resolution=factor, project=PROJECT_wandb)\n            
print(cmd)\n            fail = os.system(cmd)\n            \n        if fail == 0:\n            if not dry_run:\n                # render\n                cmd = render_cmd.format(gpu=gpu, log_dir=log_dir)\n                if fail == 0:\n                    if not dry_run:\n                        if do_render:\n                            print(cmd)\n                            fail = os.system(cmd)\n                            if not check_finish(scene, f\"{log_dir}/test/ours_30000/renders\", 'render'): return False\n                \n                # eval\n                cmd = eval_psnr_cmd.format(gpu=gpu, log_dir=log_dir)\n                if fail == 0:\n                    if not dry_run:\n                        if do_eval:\n                            print(cmd)\n                            fail = os.system(cmd)\n                            if not check_finish(scene, f\"{log_dir}/results.json\", 'eval'): return False\n                \n                # fusion\n                if do_extract_mesh:\n                    if not check_finish(scene, f\"{log_dir}/point_cloud\", 'train'): return False\n                    cmd = extract_mesh_cmd.format(gpu=gpu, ply=PLY, step=step,  fuse_method=fuse_method, voxel_size=voxel_size, num_cluster=num_cluster, max_depth=max_depth, log_dir=log_dir, prob_thr=prob_thr)\n                    fail = os.system(cmd)\n                    print(cmd)\n        \n    return fail == 0\n\n\n# Using ThreadPoolExecutor to manage the thread pool\nwith ThreadPoolExecutor(max_workers) as executor:\n    dispatch_jobs(jobs, executor, excluded_gpus, train_scene)\n\nshow_matrix(total_list, [output_dir], TRIAL_NAME)\nprint(TRIAL_NAME, \" done\")\n"
  },
  {
    "path": "python_scripts/run_tnt.py",
    "content": "# training scripts for the TNT datasets\nimport os\nimport sys\nimport time\nfrom concurrent.futures import ThreadPoolExecutor\n\nsys.path.append(os.getcwd())\nfrom python_scripts.run_base import dispatch_jobs, train_cmd, extract_mesh_cmd, eval_tnt_cmd, check_finish\nfrom python_scripts.show_tnt import show_matrix\n\n\nTRIAL_NAME = 'vcr_gaus'\nPROJECT = 'vcr_gaus'\nDATASET = 'tnt'\nbase_dir = \"/your/path\"\noutput_dir = f\"{base_dir}/output/{PROJECT}/{DATASET}\"\ndata_dir = f\"{base_dir}/data/{DATASET}\"\n\ndo_train = True\ndo_extract_mesh = True\ndo_f1 = True\ndry_run = False\n\nnode = 0\nmax_workers = 4\nbe = node*max_workers\nexcluded_gpus = set([])\n\ntotal_list = [\n        'Barn', 'Caterpillar', 'Courthouse', 'Ignatius',\n        'Meetingroom', 'Truck'\n]\ntraining_list = [\n        'Barn', 'Caterpillar', 'Courthouse', 'Ignatius',\n        'Meetingroom', 'Truck'\n]\n\ntraining_list = training_list[be: be + max_workers]\nscenes = training_list\n\nfactors = [1] * len(scenes)\ndebug_from = -1 # enable wandb\n\neval_env = 'f1eval'\ndata_device = 'cpu'\nstep = 3\nvoxel_size = [0.02, 0.015, 0.01] + [x / 1000.0 for x in range(2, 10, 1)][::-1]\nvoxel_size = sorted(voxel_size)\nPLY = f\"ours.ply\"\nTOTAL_THREADS = 128\nNUM_THREADS = TOTAL_THREADS // max_workers\nprob_thr = 0.3\nnum_cluster = 1000\nfuse_method = 'tsdf'\nmax_depth = 8\n\n\njobs = list(zip(scenes, factors))\n\n\ndef train_scene(gpu, scene, factor):\n    time.sleep(2*gpu)\n    os.system('ulimit -n 9000')\n    log_dir = f\"{output_dir}/{scene}/{TRIAL_NAME}\"\n    \n    fail = 0\n    \n    if not dry_run:\n        if do_train:\n            cmd = train_cmd.format(gpu=gpu, dataset=DATASET, cfg=scene,\n                                scene=scene, log_dir=log_dir, \n                                data_dir=data_dir, debug_from=debug_from, \n                                data_device=data_device, resolution=factor, project=PROJECT)\n            print(cmd)\n            fail = os.system(cmd)\n\n  
      if fail == 0:\n            if not dry_run:\n                # fusion\n                if do_extract_mesh:\n                    if not check_finish(scene, f\"{log_dir}/point_cloud\", 'train'): return False\n                    for vs in voxel_size:\n                        cmd = extract_mesh_cmd.format(gpu=gpu, ply=PLY, step=step,  fuse_method=fuse_method, voxel_size=vs, num_cluster=num_cluster, max_depth=max_depth, log_dir=log_dir, prob_thr=prob_thr)\n                        fail = os.system(cmd)\n                        if fail == 0: break\n                    print(cmd)\n        \n                # evaluation\n                # You need to install open3d==0.9 for evaluation\n                # evaluate the mesh\n                cmd = eval_tnt_cmd.format(num_threads=NUM_THREADS, gpu=gpu, eval_env=eval_env, data_dir=data_dir, scene=scene, log_dir=log_dir, ply=PLY)\n                if fail == 0:\n                    if not dry_run:\n                        if do_f1: \n                            if not check_finish(scene, f\"{log_dir}/{PLY}\", 'mesh'): return False\n                            print(cmd)\n                            fail = os.system(cmd)\n                            if not check_finish(scene, f\"{log_dir}/evaluation/evaluation.txt\", 'f1'): return False\n    # return True\n    return fail == 0\n\n\n# Using ThreadPoolExecutor to manage the thread pool\nwith ThreadPoolExecutor(max_workers) as executor:\n    dispatch_jobs(jobs, executor, excluded_gpus, train_scene)\n\nshow_matrix(total_list, [output_dir], TRIAL_NAME)\nprint(TRIAL_NAME, \" done\")"
  },
  {
    "path": "python_scripts/show_360.py",
    "content": "import json\nimport numpy as np\n\nscenes = ['bicycle', 'flowers', 'garden', 'stump', 'treehill', 'room', 'counter', 'kitchen', 'bonsai']\n\noutput_dirs = [\"exp_360/release\"]\n\noutdoor_scenes = [\"bicycle\", \"flowers\", \"garden\", \"stump\", \"treehill\"]\nindoor_scenes = [\"room\", \"counter\", \"kitchen\", \"bonsai\"]\n\nall_metrics = {\"PSNR\": [], \"SSIM\": [], \"LPIPS\": [], 'scene': []}\nindoor_metrics = {\"PSNR\": [], \"SSIM\": [], \"LPIPS\": [], 'scene': []}\noutdoor_metrics = {\"PSNR\": [], \"SSIM\": [], \"LPIPS\": [], 'scene': []}\nTRIAL_NAME = 'vcr_gaus'\n\ndef show_matrix(scenes, output_dirs, TRIAL_NAME):\n\n    for scene in scenes:\n        for output in output_dirs:\n            json_file = f\"{output}/{scene}/{TRIAL_NAME}/results.json\"\n            data = json.load(open(json_file))\n            data = data['ours_30000']\n            \n            for k in [\"PSNR\", \"SSIM\", \"LPIPS\"]:\n                all_metrics[k].append(data[k])\n                if scene in indoor_scenes:\n                    indoor_metrics[k].append(data[k])\n                else:\n                    outdoor_metrics[k].append(data[k])\n            all_metrics['scene'].append(scene)\n            if scene in indoor_scenes:\n                indoor_metrics['scene'].append(scene)\n            else:\n                outdoor_metrics['scene'].append(scene)\n\n    latex = []\n    for k in [\"PSNR\", \"SSIM\", \"LPIPS\"]:\n        numbers = np.asarray(all_metrics[k]).mean(axis=0).tolist()\n        numbers = [numbers]\n        if k == \"PSNR\":\n            numbers = [f\"{x:.2f}\" for x in numbers]\n        else:\n            numbers = [f\"{x:.3f}\" for x in numbers]\n        latex.extend([k+': ', numbers[-1]+' '])\n\n    indoor_latex = []\n    for k in [\"PSNR\", \"SSIM\", \"LPIPS\"]:\n        numbers = np.asarray(indoor_metrics[k]).mean(axis=0).tolist()\n        numbers = [numbers]\n        if k == \"PSNR\":\n            numbers = [f\"{x:.2f}\" for x in 
numbers]\n        else:\n            numbers = [f\"{x:.3f}\" for x in numbers]\n        indoor_latex.extend([k+': ', numbers[-1]+' '])\n        \n    outdoor_latex = []\n    for k in [\"PSNR\", \"SSIM\", \"LPIPS\"]:\n        numbers = np.asarray(outdoor_metrics[k]).mean(axis=0).tolist()\n        numbers = [numbers]\n        if k == \"PSNR\":\n            numbers = [f\"{x:.2f}\" for x in numbers]\n        else:\n            numbers = [f\"{x:.3f}\" for x in numbers]\n        outdoor_latex.extend([k+': ', numbers[-1]+' '])\n        \n    print('Outdoor scenes')\n    for i in range(len(outdoor_metrics['scene'])):\n        print('PSNR: {:.3f}, SSIM: {:.3f}, LPIPS: {:.3f}, scene: {}'.format(outdoor_metrics['PSNR'][i], outdoor_metrics['SSIM'][i], outdoor_metrics['LPIPS'][i], outdoor_metrics['scene'][i]))\n    \n    print('Indoor scenes')\n    for i in range(len(indoor_metrics['scene'])):\n        print('PSNR: {:.3f}, SSIM: {:.3f}, LPIPS: {:.3f}, scene: {}'.format(indoor_metrics['PSNR'][i], indoor_metrics['SSIM'][i], indoor_metrics['LPIPS'][i], indoor_metrics['scene'][i]))\n        \n    print('Outdoor:')\n    print(\"\".join(outdoor_latex))\n    print('Indoor:')\n    print(\"\".join(indoor_latex))\n\nif __name__ == \"__main__\":\n    show_matrix(scenes, output_dirs, TRIAL_NAME)\n"
  },
  {
    "path": "python_scripts/show_dtu.py",
    "content": "import os\nimport json\nimport numpy as np\n\nscenes = [24, 37, 40, 55, 63, 65, 69, 83, 97, 105, 106, 110, 114, 118, 122]\noutput_dirs = [\"exp_dtu/release\"]\nTRIAL_NAME = 'vcr_gaus'\n\n\ndef show_matrix_old(scenes, output_dirs, TRIAL_NAME):\n    all_metrics = {\"mean_d2s\": [], \"mean_s2d\": [], \"overall\": []}\n    print(output_dirs)\n\n    for scene in scenes:\n        print(scene,end=\" \")\n        for output in output_dirs:\n            json_file = f\"{output}/scan{scene}/test/ours_30000/tsdf/results.json\"\n            data = json.load(open(json_file))\n            \n            for k in [\"mean_d2s\", \"mean_s2d\", \"overall\"]:\n                all_metrics[k].append(data[k])\n                print(f\"{data[k]:.3f}\", end=\" \")\n            print()\n\n    latex = []\n    for k in [\"mean_d2s\", \"mean_s2d\", \"overall\"]:\n        numbers = np.asarray(all_metrics[k]).mean(axis=0).tolist()\n        \n        numbers = all_metrics[k] + [numbers]\n        \n        numbers = [f\"{x:.2f}\" for x in numbers]\n        if k == \"overall\":\n            latex.extend(numbers)\n        \n    print(\" & \".join(latex))\n    \n\ndef show_matrix(scenes, output_dirs, TRIAL_NAME):\n    all_metrics = {\"mean_d2s\": [], \"mean_s2d\": [], \"overall\": [], 'scene': []}\n\n    for scene in scenes:\n        for output in output_dirs:\n            json_file = f\"{output}/{scene}/{TRIAL_NAME}/results.json\"\n            if not os.path.exists(json_file):\n                print(f\"Scene \\033[1;31m{scene}\\033[0m was not evaluated.\")\n                continue\n            data = json.load(open(json_file))\n            \n            for k in [\"mean_d2s\", \"mean_s2d\", \"overall\"]:\n                all_metrics[k].append(data[k])\n            all_metrics['scene'].append(scene)\n\n    latex = []\n    for k in [\"mean_d2s\", \"mean_s2d\", \"overall\"]:\n        numbers = np.asarray(all_metrics[k]).mean(axis=0).tolist()\n        \n        numbers = all_metrics[k] + 
[numbers]\n        \n        numbers = [f\"{x:.2f}\" for x in numbers]\n        latex.extend([k+': ', numbers[-1]+' '])\n        \n    for i in range(len(all_metrics['scene'])):\n        print('d2s: {:.3f}, s2d: {:.3f}, overall: {:.3f}, scene: {}'.format(all_metrics['mean_d2s'][i], all_metrics['mean_s2d'][i], all_metrics['overall'][i], all_metrics['scene'][i]))\n    \n    print(\"\".join(latex))\n\n\nif __name__ == \"__main__\":\n    show_matrix(scenes, output_dirs, TRIAL_NAME)"
  },
  {
    "path": "python_scripts/show_tnt.py",
    "content": "import os\nimport numpy as np\n\ntraining_list = [\n    'Barn', 'Caterpillar', 'Courthouse', 'Ignatius', 'Meetingroom', 'Truck'\n]\n\nscenes = training_list\n\nDATASET = 'tnt'\nbase_dir = \"/your/log/path/\"\nTRIAL_NAME = 'vcr_gaus'\nPROJECT = 'sq_gs'\noutput_dirs = [f\"{base_dir}/{PROJECT}/{DATASET}\"]\n\n\ndef show_matrix(scenes, output_dirs, TRIAL_NAME):\n    all_metrics = {\"precision\": [], \"recall\": [], \"f-score\": [], 'scene': []}\n    for scene in scenes:\n        for output in output_dirs:\n            # precision\n            eval_file = os.path.join(output, scene, f\"{TRIAL_NAME}/evaluation/evaluation.txt\")\n            \n            if not os.path.exists(eval_file):\n                print(f\"Scene \\033[1;31m{scene}\\033[0m was not evaluated.\")\n                continue\n            with open(eval_file, 'r') as f:\n                matrix = f.readlines()\n            \n            precision = float(matrix[2].split(\" \")[-1])\n            recall = float(matrix[3].split(\" \")[-1])\n            f_score = float(matrix[4].split(\" \")[-1])\n            \n            all_metrics[\"precision\"].append(precision)\n            all_metrics[\"recall\"].append(recall)\n            all_metrics[\"f-score\"].append(f_score)\n            all_metrics['scene'].append(scene)\n\n\n    latex = []\n    for k in [\"precision\",\"recall\", \"f-score\"]:\n        numbers = all_metrics[k]\n        mean = np.mean(numbers)\n        numbers = numbers + [mean]\n        \n        numbers = [f\"{x:.3f}\" for x in numbers]\n        latex.extend([k+': ', numbers[-1]+' '])\n        \n    for i in range(len(all_metrics['scene'])):\n        print('precision: {:.3f}, recall: {:.3f}, f-score: {:.3f}, scene: {}'.format(all_metrics['precision'][i], all_metrics['recall'][i], all_metrics['f-score'][i], all_metrics['scene'][i]))\n    \n    print(\"\".join(latex))\n    \n    return\n\n\nif __name__ == \"__main__\":\n    show_matrix(scenes, output_dirs, TRIAL_NAME)\n"
  },
  {
    "path": "requirements.txt",
    "content": "submodules/diff-gaussian-rasterization\nsubmodules/simple-knn/\ngit+https://github.com/facebookresearch/pytorch3d.git@stable"
  },
  {
    "path": "scene/__init__.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport os\nimport random\nimport json\nimport torch\n\nfrom arguments import ModelParams\nfrom scene.gaussian_model import GaussianModel\nfrom tools.system_utils import searchForMaxIteration\nfrom scene.dataset_readers import sceneLoadTypeCallbacks\nfrom tools.camera_utils import cameraList_from_camInfos, camera_to_JSON\nfrom tools.graphics_utils import get_all_px_dir\n\nclass Scene:\n\n    gaussians : GaussianModel\n\n    def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0]):\n        \"\"\"\n        :param args: Model parameters; args.source_path points to the COLMAP scene main folder.\n        \"\"\"\n        self.model_path = args.model_path\n        self.loaded_iter = None\n        self.gaussians = gaussians\n        self.split = args.split\n        load_depth = args.load_depth\n        load_normal = args.load_normal\n        load_mask = args.load_mask\n\n        if load_iteration:\n            if load_iteration == -1:\n                self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, \"point_cloud\"))\n            else:\n                self.loaded_iter = load_iteration\n            print(\"Loading trained model at iteration {}\".format(self.loaded_iter))\n\n        self.train_cameras = {}\n        self.test_cameras = {}\n        \n        if os.path.exists(os.path.join(args.source_path, \"sparse\")):\n            scene_info = sceneLoadTypeCallbacks[\"Colmap\"](args.source_path, args.images, args.eval, args.llffhold, args.ratio, split=self.split, load_depth=load_depth, load_normal=load_normal, load_mask=load_mask, normal_folder=args.normal_folder, depth_folder=args.depth_folder)\n        elif 
os.path.exists(os.path.join(args.source_path, \"transforms_train.json\")):\n            print(\"Found transforms_train.json file, assuming Blender data set!\")\n            scene_info = sceneLoadTypeCallbacks[\"Blender\"](args.source_path, args.white_background, args.eval)\n        else:\n            assert False, \"Could not recognize scene type!\"\n        \n        self.trans = scene_info.trans\n        self.scale = scene_info.scale\n\n        if not self.loaded_iter:\n            with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, \"input.ply\") , 'wb') as dest_file:\n                dest_file.write(src_file.read())\n            json_cams = []\n            camlist = []\n            if scene_info.test_cameras:\n                camlist.extend(scene_info.test_cameras)\n            if scene_info.train_cameras:\n                camlist.extend(scene_info.train_cameras)\n            for id, cam in enumerate(camlist):\n                json_cams.append(camera_to_JSON(id, cam))\n            with open(os.path.join(self.model_path, \"cameras.json\"), 'w') as file:\n                json.dump(json_cams, file)\n\n        if shuffle:\n            random.shuffle(scene_info.train_cameras)  # Multi-res consistent random shuffling\n            # random.shuffle(scene_info.test_cameras)  # Multi-res consistent random shuffling\n\n        self.cameras_extent = scene_info.nerf_normalization[\"radius\"]\n        gaussians.extent = self.cameras_extent\n        \n        for resolution_scale in resolution_scales:\n            print(\"Loading Training Cameras\")\n            self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args)\n            print(\"Loading Test Cameras\")\n            self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args)\n        \n            for idx, camera in enumerate(self.train_cameras[resolution_scale] + 
self.test_cameras[resolution_scale]):\n                camera.idx = idx\n\n        if self.loaded_iter:\n            self.gaussians.load_ply(os.path.join(self.model_path,\n                                                           \"point_cloud\",\n                                                           \"iteration_\" + str(self.loaded_iter),\n                                                           \"point_cloud.ply\"))\n        else:\n            self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent)\n        \n        if args.depth_type == \"traditional\":\n            self.dirs = None\n        elif args.depth_type == \"intersection\":\n            self.dirs = get_all_px_dir(self.getTrainCameras()[0].intr, self.getTrainCameras()[0].image_height, self.getTrainCameras()[0].image_width).cuda()\n        self.first_name = scene_info.first_name\n\n    def save(self, iteration, visi=None, surf=None, save_splat=False):\n        point_cloud_path = os.path.join(self.model_path, \"point_cloud/iteration_{}\".format(iteration))\n        self.gaussians.save_ply(os.path.join(point_cloud_path, \"point_cloud.ply\"))\n        self.gaussians.save_inside_ply(os.path.join(point_cloud_path, \"point_cloud_inside.ply\"))\n        \n        if visi is not None:\n            self.gaussians.save_visi_ply(os.path.join(point_cloud_path, \"visi.ply\"), visi)\n        \n        if surf is not None:\n            self.gaussians.save_visi_ply(os.path.join(point_cloud_path, \"surf.ply\"), surf)\n        \n        if save_splat:\n            self.gaussians.save_splat(os.path.join(point_cloud_path, \"pcd.splat\"))\n\n    def getTrainCameras(self, scale=1.0):\n        return self.train_cameras[scale]\n\n    def getTestCameras(self, scale=1.0):\n        return self.test_cameras[scale]\n    \n    def getFullCameras(self, scale=1.0):\n        if self.split:\n            return self.train_cameras[scale] + self.test_cameras[scale]\n        else:\n            return 
self.train_cameras[scale]\n    \n    def getUpCameras(self):\n        return self.random_cameras_up\n    \n    def getAroundCameras(self):\n        return self.random_cameras_around\n    \n    def getRandCameras(self, n, up=False, around=True, sample_mode='uniform'):\n        # Split the budget evenly when sampling from both camera sets.\n        if up and around:\n            n = n // 2\n        \n        cameras = []\n        if up:\n            up_cameras = self.getUpCameras()\n            idx = torch.randperm(len(up_cameras))[:n]\n            # Index one element at a time: a Python list cannot be indexed by a multi-element tensor.\n            cameras.extend(up_cameras[i] for i in idx.tolist())\n        if around:\n            around_cameras = self.getAroundCameras()\n            \n            if sample_mode == 'random':\n                idx = torch.randperm(len(around_cameras))[:n]\n            elif sample_mode == 'uniform':\n                # Clamp the stride to at least 1 so that n > len(around_cameras) does not produce a zero step.\n                step = max(len(around_cameras) // n, 1)\n                idx = torch.arange(len(around_cameras))[::step]\n            else:\n                raise ValueError(f\"Unknown sample_mode: {sample_mode}\")\n            \n            cameras.extend(around_cameras[i] for i in idx.tolist())\n        return cameras"
  },
  {
    "path": "scene/appearance_network.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nclass UpsampleBlock(nn.Module):\n    def __init__(self, num_input_channels, num_output_channels):\n        super(UpsampleBlock, self).__init__()\n        self.pixel_shuffle = nn.PixelShuffle(2)\n        self.conv = nn.Conv2d(num_input_channels // (2 * 2), num_output_channels, 3, stride=1, padding=1)\n        self.relu = nn.ReLU()\n        \n    def forward(self, x):\n        x = self.pixel_shuffle(x)\n        x = self.conv(x)\n        x = self.relu(x)\n        return x\n\n\nclass AppearanceNetwork(nn.Module):\n    def __init__(self, num_input_channels, num_output_channels):\n        super(AppearanceNetwork, self).__init__()\n        \n        self.conv1 = nn.Conv2d(num_input_channels, 256, 3, stride=1, padding=1)\n        self.up1 = UpsampleBlock(256, 128)\n        self.up2 = UpsampleBlock(128, 64)\n        self.up3 = UpsampleBlock(64, 32)\n        self.up4 = UpsampleBlock(32, 16)\n        \n        self.conv2 = nn.Conv2d(16, 16, 3, stride=1, padding=1)\n        self.conv3 = nn.Conv2d(16, num_output_channels, 3, stride=1, padding=1)\n        self.relu = nn.ReLU()\n        self.sigmoid = nn.Sigmoid()\n        \n    def forward(self, x):\n        x = self.conv1(x)\n        x = self.relu(x)\n        x = self.up1(x)\n        x = self.up2(x)\n        x = self.up3(x)\n        x = self.up4(x)\n        # bilinear interpolation\n        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)\n        x = self.conv2(x)\n        x = self.relu(x)\n        x = self.conv3(x)\n        x = self.sigmoid(x)\n        return x\n\n\nif __name__ == \"__main__\":\n    H, W = 1200//32, 1600//32\n    input_channels = 3 + 64\n    output_channels = 3\n    input = torch.randn(1, input_channels, H, W).cuda()\n    model = AppearanceNetwork(input_channels, output_channels).cuda()\n    \n    output = model(input)\n    print(output.shape)"
  },
  {
    "path": "scene/cameras.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport torch\nfrom torch import nn\nimport numpy as np\n\nfrom tools.graphics_utils import getWorld2View2, getProjectionMatrix, getIntrinsic\n\n\nclass Camera(nn.Module):\n    def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask,\n                 image_name, uid, depth=None, normal=None, mask=None,\n                 trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = \"cuda\"\n                 ):\n        super(Camera, self).__init__()\n\n        self.uid = uid\n        self.colmap_id = colmap_id\n        self.R = R\n        self.T = T\n        self.FoVx = FoVx\n        self.FoVy = FoVy\n        self.image_name = image_name\n\n        try:\n            self.data_device = torch.device(data_device)\n        except Exception as e:\n            print(e)\n            print(f\"[Warning] Custom device {data_device} failed, fallback to default cuda device\" )\n            self.data_device = torch.device(\"cuda\")\n\n        self.original_image = image.clamp(0.0, 1.0).to(self.data_device)\n        self.image_width = self.original_image.shape[2]\n        self.image_height = self.original_image.shape[1]\n\n        if gt_alpha_mask is not None:\n            self.gt_alpha_mask = gt_alpha_mask\n            if mask is not None:\n                mask = mask.squeeze(-1).cuda()\n                mask[self.gt_alpha_mask[0] == 0] = 0\n            else:\n                mask = self.gt_alpha_mask.bool().squeeze(0).cuda()\n        else:\n            self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device)\n            self.gt_alpha_mask = None\n\n        self.depth = depth.to(data_device) if depth is not None else 
None\n        self.normal = normal.to(data_device) if normal is not None else None\n        \n        if mask is not None:\n            self.mask = mask.squeeze(-1).cuda()\n        \n        self.zfar = 100.0\n        self.znear = 0.01\n\n        self.trans = trans\n        self.scale = scale\n\n        self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda()         # w2c\n        self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda()\n        self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)     # w2c2image\n        self.camera_center = self.world_view_transform.inverse()[3, :3]\n        intr = getIntrinsic(self.FoVx, self.FoVy, self.image_height, self.image_width).cuda()\n        self.intr = intr\n        \n        \nclass MiniCam:\n    def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform):\n        self.image_width = width\n        self.image_height = height    \n        self.FoVy = fovy\n        self.FoVx = fovx\n        self.znear = znear\n        self.zfar = zfar\n        self.world_view_transform = world_view_transform\n        self.full_proj_transform = full_proj_transform\n        view_inv = torch.inverse(self.world_view_transform)\n        self.camera_center = view_inv[3][:3]\n\n\nclass SampleCam(nn.Module):\n    def __init__(self, w2c, width, height, FoVx, FoVy, device='cuda'):\n        super(SampleCam, self).__init__()\n\n        self.FoVx = FoVx\n        self.FoVy = FoVy\n        self.image_width = width\n        self.image_height = height\n\n        self.zfar = 100.0\n        self.znear = 0.01\n\n        try:\n            self.data_device = torch.device(device)\n        except Exception as e:\n            print(e)\n            print(f\"[Warning] Custom device {device} failed, fallback to default cuda device\" )\n    
        self.data_device = torch.device(\"cuda\")\n        \n        w2c = w2c.to(self.data_device)\n        self.world_view_transform = w2c.transpose(0, 1)\n        self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).to(w2c.device)\n        self.full_proj_transform = self.world_view_transform @ self.projection_matrix\n        self.camera_center = self.world_view_transform.inverse()[3, :3]\n        \nclass MiniCam2:\n    def __init__(self, c2w, width, height, fovy, fovx, znear, zfar):\n        # c2w (pose) should be in NeRF convention.\n\n        self.image_width = width\n        self.image_height = height\n        self.FoVy = fovy\n        self.FoVx = fovx\n        self.znear = znear\n        self.zfar = zfar\n\n        w2c = np.linalg.inv(c2w)\n\n        # rectify...\n        w2c[1:3, :3] *= -1\n        w2c[:3, 3] *= -1\n\n        self.world_view_transform = torch.tensor(w2c).transpose(0, 1).cuda()\n        self.projection_matrix = (\n            getProjectionMatrix(\n                znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy\n            )\n            .transpose(0, 1)\n            .cuda()\n        )\n        self.full_proj_transform = self.world_view_transform @ self.projection_matrix\n        self.camera_center = -torch.tensor(c2w[:3, 3]).cuda()"
  },
  {
    "path": "scene/colmap_loader.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport numpy as np\nimport collections\nimport struct\n\nCameraModel = collections.namedtuple(\n    \"CameraModel\", [\"model_id\", \"model_name\", \"num_params\"])\nCamera = collections.namedtuple(\n    \"Camera\", [\"id\", \"model\", \"width\", \"height\", \"params\"])\nBaseImage = collections.namedtuple(\n    \"Image\", [\"id\", \"qvec\", \"tvec\", \"camera_id\", \"name\", \"xys\", \"point3D_ids\"])\nPoint3D = collections.namedtuple(\n    \"Point3D\", [\"id\", \"xyz\", \"rgb\", \"error\", \"image_ids\", \"point2D_idxs\"])\nCAMERA_MODELS = {\n    CameraModel(model_id=0, model_name=\"SIMPLE_PINHOLE\", num_params=3),\n    CameraModel(model_id=1, model_name=\"PINHOLE\", num_params=4),\n    CameraModel(model_id=2, model_name=\"SIMPLE_RADIAL\", num_params=4),\n    CameraModel(model_id=3, model_name=\"RADIAL\", num_params=5),\n    CameraModel(model_id=4, model_name=\"OPENCV\", num_params=8),\n    CameraModel(model_id=5, model_name=\"OPENCV_FISHEYE\", num_params=8),\n    CameraModel(model_id=6, model_name=\"FULL_OPENCV\", num_params=12),\n    CameraModel(model_id=7, model_name=\"FOV\", num_params=5),\n    CameraModel(model_id=8, model_name=\"SIMPLE_RADIAL_FISHEYE\", num_params=4),\n    CameraModel(model_id=9, model_name=\"RADIAL_FISHEYE\", num_params=5),\n    CameraModel(model_id=10, model_name=\"THIN_PRISM_FISHEYE\", num_params=12)\n}\nCAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model)\n                         for camera_model in CAMERA_MODELS])\nCAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model)\n                           for camera_model in CAMERA_MODELS])\n\n\ndef qvec2rotmat(qvec):\n    return np.array([\n        [1 - 2 * 
qvec[2]**2 - 2 * qvec[3]**2,\n         2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],\n         2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]],\n        [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],\n         1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,\n         2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]],\n        [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],\n         2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],\n         1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]])\n\ndef rotmat2qvec(R):\n    Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat\n    K = np.array([\n        [Rxx - Ryy - Rzz, 0, 0, 0],\n        [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],\n        [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],\n        [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0\n    eigvals, eigvecs = np.linalg.eigh(K)\n    qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]\n    if qvec[0] < 0:\n        qvec *= -1\n    return qvec\n\nclass Image(BaseImage):\n    def qvec2rotmat(self):\n        return qvec2rotmat(self.qvec)\n\ndef read_next_bytes(fid, num_bytes, format_char_sequence, endian_character=\"<\"):\n    \"\"\"Read and unpack the next bytes from a binary file.\n    :param fid:\n    :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 
2, 6, 16, 30, etc.\n    :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.\n    :param endian_character: Any of {@, =, <, >, !}\n    :return: Tuple of read and unpacked values.\n    \"\"\"\n    data = fid.read(num_bytes)\n    return struct.unpack(endian_character + format_char_sequence, data)\n\ndef read_points3D_text(path):\n    \"\"\"\n    see: src/base/reconstruction.cc\n        void Reconstruction::ReadPoints3DText(const std::string& path)\n        void Reconstruction::WritePoints3DText(const std::string& path)\n    \"\"\"\n    xyzs = None\n    rgbs = None\n    errors = None\n    num_points = 0\n    with open(path, \"r\") as fid:\n        while True:\n            line = fid.readline()\n            if not line:\n                break\n            line = line.strip()\n            if len(line) > 0 and line[0] != \"#\":\n                num_points += 1\n\n\n    xyzs = np.empty((num_points, 3))\n    rgbs = np.empty((num_points, 3))\n    errors = np.empty((num_points, 1))\n    count = 0\n    with open(path, \"r\") as fid:\n        while True:\n            line = fid.readline()\n            if not line:\n                break\n            line = line.strip()\n            if len(line) > 0 and line[0] != \"#\":\n                elems = line.split()\n                xyz = np.array(tuple(map(float, elems[1:4])))\n                rgb = np.array(tuple(map(int, elems[4:7])))\n                error = np.array(float(elems[7]))\n                xyzs[count] = xyz\n                rgbs[count] = rgb\n                errors[count] = error\n                count += 1\n\n    return xyzs, rgbs, errors\n\ndef read_points3D_binary(path_to_model_file):\n    \"\"\"\n    see: src/base/reconstruction.cc\n        void Reconstruction::ReadPoints3DBinary(const std::string& path)\n        void Reconstruction::WritePoints3DBinary(const std::string& path)\n    \"\"\"\n\n\n    with open(path_to_model_file, \"rb\") as fid:\n        num_points = read_next_bytes(fid, 8, 
\"Q\")[0]\n\n        xyzs = np.empty((num_points, 3))\n        rgbs = np.empty((num_points, 3))\n        errors = np.empty((num_points, 1))\n\n        for p_id in range(num_points):\n            binary_point_line_properties = read_next_bytes(\n                fid, num_bytes=43, format_char_sequence=\"QdddBBBd\")\n            xyz = np.array(binary_point_line_properties[1:4])\n            rgb = np.array(binary_point_line_properties[4:7])\n            error = np.array(binary_point_line_properties[7])\n            track_length = read_next_bytes(\n                fid, num_bytes=8, format_char_sequence=\"Q\")[0]\n            track_elems = read_next_bytes(\n                fid, num_bytes=8*track_length,\n                format_char_sequence=\"ii\"*track_length)\n            xyzs[p_id] = xyz\n            rgbs[p_id] = rgb\n            errors[p_id] = error\n    return xyzs, rgbs, errors\n\ndef read_intrinsics_text(path):\n    \"\"\"\n    Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py\n    \"\"\"\n    cameras = {}\n    with open(path, \"r\") as fid:\n        while True:\n            line = fid.readline()\n            if not line:\n                break\n            line = line.strip()\n            if len(line) > 0 and line[0] != \"#\":\n                elems = line.split()\n                camera_id = int(elems[0])\n                model = elems[1]\n                assert model == \"PINHOLE\", \"While the loader support other types, the rest of the code assumes PINHOLE\"\n                width = int(elems[2])\n                height = int(elems[3])\n                params = np.array(tuple(map(float, elems[4:])))\n                cameras[camera_id] = Camera(id=camera_id, model=model,\n                                            width=width, height=height,\n                                            params=params)\n    return cameras\n\ndef read_extrinsics_binary(path_to_model_file):\n    \"\"\"\n    see: src/base/reconstruction.cc\n    
    void Reconstruction::ReadImagesBinary(const std::string& path)\n        void Reconstruction::WriteImagesBinary(const std::string& path)\n    \"\"\"\n    images = {}\n    with open(path_to_model_file, \"rb\") as fid:\n        num_reg_images = read_next_bytes(fid, 8, \"Q\")[0]\n        for _ in range(num_reg_images):\n            binary_image_properties = read_next_bytes(\n                fid, num_bytes=64, format_char_sequence=\"idddddddi\")\n            image_id = binary_image_properties[0]\n            qvec = np.array(binary_image_properties[1:5])\n            tvec = np.array(binary_image_properties[5:8])\n            camera_id = binary_image_properties[8]\n            image_name = \"\"\n            current_char = read_next_bytes(fid, 1, \"c\")[0]\n            while current_char != b\"\\x00\":   # look for the ASCII 0 entry\n                image_name += current_char.decode(\"utf-8\")\n                current_char = read_next_bytes(fid, 1, \"c\")[0]\n            num_points2D = read_next_bytes(fid, num_bytes=8,\n                                           format_char_sequence=\"Q\")[0]\n            x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D,\n                                       format_char_sequence=\"ddq\"*num_points2D)\n            xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),\n                                   tuple(map(float, x_y_id_s[1::3]))])\n            point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))\n            images[image_id] = Image(\n                id=image_id, qvec=qvec, tvec=tvec,\n                camera_id=camera_id, name=image_name,\n                xys=xys, point3D_ids=point3D_ids)\n    return images\n\n\ndef read_intrinsics_binary(path_to_model_file):\n    \"\"\"\n    see: src/base/reconstruction.cc\n        void Reconstruction::WriteCamerasBinary(const std::string& path)\n        void Reconstruction::ReadCamerasBinary(const std::string& path)\n    \"\"\"\n    cameras = {}\n    with 
open(path_to_model_file, \"rb\") as fid:\n        num_cameras = read_next_bytes(fid, 8, \"Q\")[0]\n        for _ in range(num_cameras):\n            camera_properties = read_next_bytes(\n                fid, num_bytes=24, format_char_sequence=\"iiQQ\")\n            camera_id = camera_properties[0]\n            model_id = camera_properties[1]\n            model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name\n            width = camera_properties[2]\n            height = camera_properties[3]\n            num_params = CAMERA_MODEL_IDS[model_id].num_params\n            params = read_next_bytes(fid, num_bytes=8*num_params,\n                                     format_char_sequence=\"d\"*num_params)\n            cameras[camera_id] = Camera(id=camera_id,\n                                        model=model_name,\n                                        width=width,\n                                        height=height,\n                                        params=np.array(params))\n        assert len(cameras) == num_cameras\n    return cameras\n\n\ndef read_extrinsics_text(path):\n    \"\"\"\n    Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py\n    \"\"\"\n    images = {}\n    with open(path, \"r\") as fid:\n        while True:\n            line = fid.readline()\n            if not line:\n                break\n            line = line.strip()\n            if len(line) > 0 and line[0] != \"#\":\n                elems = line.split()\n                image_id = int(elems[0])\n                qvec = np.array(tuple(map(float, elems[1:5])))\n                tvec = np.array(tuple(map(float, elems[5:8])))\n                camera_id = int(elems[8])\n                image_name = elems[9]\n                elems = fid.readline().split()\n                xys = np.column_stack([tuple(map(float, elems[0::3])),\n                                       tuple(map(float, elems[1::3]))])\n                point3D_ids = 
np.array(tuple(map(int, elems[2::3])))\n                images[image_id] = Image(\n                    id=image_id, qvec=qvec, tvec=tvec,\n                    camera_id=camera_id, name=image_name,\n                    xys=xys, point3D_ids=point3D_ids)\n    return images\n\n\ndef read_colmap_bin_array(path):\n    \"\"\"\n    Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py\n\n    :param path: path to the colmap binary file.\n    :return: nd array with the floating point values in the value\n    \"\"\"\n    with open(path, \"rb\") as fid:\n        width, height, channels = np.genfromtxt(fid, delimiter=\"&\", max_rows=1,\n                                                usecols=(0, 1, 2), dtype=int)\n        fid.seek(0)\n        num_delimiter = 0\n        byte = fid.read(1)\n        while True:\n            if byte == b\"&\":\n                num_delimiter += 1\n                if num_delimiter >= 3:\n                    break\n            byte = fid.read(1)\n        array = np.fromfile(fid, np.float32)\n    array = array.reshape((width, height, channels), order=\"F\")\n    return np.transpose(array, (1, 0, 2)).squeeze()\n"
  },
  {
    "path": "scene/dataset_readers.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport os\nimport sys\nimport cv2\nimport json\nimport numpy as np\nimport open3d as o3d\nfrom PIL import Image, ImageFile\nfrom pathlib import Path\nfrom typing import NamedTuple\nfrom plyfile import PlyData, PlyElement\nImageFile.LOAD_TRUNCATED_IMAGES = True\n\nfrom scene.colmap_loader import read_extrinsics_text, read_intrinsics_text, qvec2rotmat, \\\n    read_extrinsics_binary, read_intrinsics_binary, read_points3D_binary, read_points3D_text\nfrom tools.graphics_utils import getWorld2View2, focal2fov, fov2focal\nfrom tools.sh_utils import SH2RGB\nfrom scene.gaussian_model import BasicPointCloud\nfrom tools.math_utils import normalize_pts\nfrom process_data.convert_data_to_json import bound_by_points\n\nclass CameraInfo(NamedTuple):\n    uid: int\n    R: np.array\n    T: np.array\n    FovY: np.array\n    FovX: np.array\n    image: np.array\n    image_path: str\n    image_name: str\n    width: int\n    height: int\n    depth: None\n    normal: None\n    mask: None\n\nclass SceneInfo(NamedTuple):\n    point_cloud: BasicPointCloud\n    train_cameras: list\n    test_cameras: list\n    nerf_normalization: dict\n    ply_path: str\n    trans: np.array\n    scale: np.array\n    first_name: str\n\ndef getNerfppNorm(cam_info):\n    def get_center_and_diag(cam_centers):\n        cam_centers = np.hstack(cam_centers)\n        avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True)\n        center = avg_cam_center\n        dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True)\n        diagonal = np.max(dist)\n        return center.flatten(), diagonal\n\n    cam_centers = []\n\n    for cam in cam_info:\n        W2C = getWorld2View2(cam.R, cam.T)\n      
  C2W = np.linalg.inv(W2C)\n        cam_centers.append(C2W[:3, 3:4])\n\n    center, diagonal = get_center_and_diag(cam_centers)\n    radius = diagonal * 1.1\n\n    translate = -center\n\n    return {\"translate\": translate, \"radius\": radius}\n\ndef readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder, load_depth=False, load_normal=False, load_mask=False, normal_folder='normals', depth_folder='depths'):\n    if load_depth:\n        depths_folder = images_folder.replace('images', depth_folder)\n    \n    if load_normal:\n        normals_folder = images_folder.replace('images', normal_folder)\n    if load_mask:\n        mask_folder = images_folder.replace('images', 'masks')\n    cam_infos = []\n    for idx, key in enumerate(cam_extrinsics):\n        sys.stdout.write('\\r')\n        # the exact output you're looking for:\n        sys.stdout.write(\"Reading camera {}/{}\".format(idx+1, len(cam_extrinsics)))\n        sys.stdout.flush()\n\n        extr = cam_extrinsics[key]\n        intr = cam_intrinsics[extr.camera_id]\n        height = intr.height\n        width = intr.width\n\n        uid = intr.id\n        R = np.transpose(qvec2rotmat(extr.qvec))\n        T = np.array(extr.tvec)\n        \n        if intr.model==\"SIMPLE_PINHOLE\":\n            focal_length_x = intr.params[0]\n            FovY = focal2fov(focal_length_x, height)\n            FovX = focal2fov(focal_length_x, width)\n        elif intr.model==\"PINHOLE\":\n            focal_length_x = intr.params[0]\n            focal_length_y = intr.params[1]\n            FovY = focal2fov(focal_length_y, height)\n            FovX = focal2fov(focal_length_x, width)\n        else:\n            assert False, \"Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!\"\n\n        image_path = os.path.join(images_folder, os.path.basename(extr.name))\n        image_name = os.path.basename(image_path).split(\".\")[0]\n        image = Image.open(image_path)\n        
\n        depth = None\n        if load_depth:\n            depth_path = os.path.join(depths_folder, os.path.basename(extr.name).replace('jpg', 'npz').replace('png', 'npz'))\n            if os.path.exists(depth_path):\n                depth = np.load(depth_path)['arr_0']\n            else:\n                depth_path = os.path.join(depths_folder, os.path.basename(extr.name).replace('jpg', 'png'))\n                depth = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED)\n            \n            if depth.ndim == 2: depth = depth[..., None]\n        \n        normal = None\n        if load_normal:\n            normal_path = os.path.join(normals_folder, os.path.basename(extr.name).replace('png', 'npz').replace('jpg', 'npz').replace('JPG', 'npz'))\n            normal = np.load(normal_path)['arr_0'] # -1, 1\n\n        mask = None\n        if load_mask:\n            mask_path = os.path.join(mask_folder, os.path.basename(extr.name).replace('jpg', 'png'))\n            mask_path = mask_path if os.path.exists(mask_path) else \\\n                        os.path.join(mask_folder, os.path.basename(extr.name)[1:])\n            mask = Image.open(mask_path)\n\n        cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, image=image,\n                              image_path=image_path, image_name=image_name, width=width, height=height, depth=depth, normal=normal, mask=mask)\n        cam_infos.append(cam_info)\n    sys.stdout.write('\\n')\n    return cam_infos\n\ndef fetchPly(path):\n    plydata = PlyData.read(path)\n    vertices = plydata['vertex']\n    positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T\n    colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0\n    normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T\n    return BasicPointCloud(points=positions, colors=colors, normals=normals)\n\ndef storePly(path, xyz, rgb, normals=None):\n    # Define the dtype for the structured array\n    dtype = 
[('x', 'f4'), ('y', 'f4'), ('z', 'f4'),\n            ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),\n            ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]\n    \n    normals = np.zeros_like(xyz) if normals is None else normals\n\n    elements = np.empty(xyz.shape[0], dtype=dtype)\n    attributes = np.concatenate((xyz, normals, rgb), axis=1)\n    elements[:] = list(map(tuple, attributes))\n\n    # Create the PlyData object and write to file\n    vertex_element = PlyElement.describe(elements, 'vertex')\n    ply_data = PlyData([vertex_element])\n    ply_data.write(path)\n\ndef get_inside_mask(pts, trans, scale):\n    pts = normalize_pts(pts, trans, scale)\n    \n    inside = np.all(np.abs(pts) < 1.5, axis=-1)\n    return inside\n\ndef filter_point_cloud(trans, scale, xyz, rgb, nb_points=5, radius=0.1):\n    inside = get_inside_mask(xyz, trans, scale)\n    xyz_inside = xyz[inside]\n    rgb_inside = rgb[inside]\n    xyz_outside = xyz[~inside]\n    rgb_outside = rgb[~inside]\n    \n    pcd_inside = o3d.geometry.PointCloud()\n    pcd_inside.points = o3d.utility.Vector3dVector(xyz_inside)\n    pcd_inside.colors = o3d.utility.Vector3dVector(rgb_inside)\n    \n    pcd_inside_filter, ind = pcd_inside.remove_radius_outlier(nb_points, radius)\n    \n    xyz_inside = np.asarray(pcd_inside_filter.points)\n    rgb_inside = np.asarray(pcd_inside_filter.colors)\n    \n    xyz = np.concatenate((xyz_inside, xyz_outside), axis=0)\n    rgb = np.concatenate((rgb_inside, rgb_outside), axis=0)\n    \n    return xyz, rgb\n\ndef readColmapSceneInfo(path, images, eval, llffhold=8, ratio=0, split=False, load_depth=False, load_normal=False, load_mask=False, normal_folder='normals', depth_folder='depths'):\n    colmap_dir = os.path.join(path, \"sparse/0\")\n    if not os.path.exists(colmap_dir):\n        colmap_dir = os.path.join(path, \"sparse\")\n    try:\n        cameras_extrinsic_file = os.path.join(colmap_dir, \"images.bin\")\n        cameras_intrinsic_file = os.path.join(colmap_dir, 
\"cameras.bin\")\n        cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file)\n        cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file)\n    except:\n        cameras_extrinsic_file = os.path.join(colmap_dir, \"images.txt\")\n        cameras_intrinsic_file = os.path.join(colmap_dir, \"cameras.txt\")\n        cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file)\n        cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file)\n    \n    ply_path = os.path.join(colmap_dir, \"points3D.ply\")\n    bin_path = os.path.join(colmap_dir, \"points3D.bin\")\n    txt_path = os.path.join(colmap_dir, \"points3D.txt\")\n    \n    reading_dir = \"images\" if images == None else images\n    cam_infos_unsorted = readColmapCameras(cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, images_folder=os.path.join(path, reading_dir), load_depth=load_depth, load_normal=load_normal, load_mask=load_mask, normal_folder=normal_folder, depth_folder=depth_folder)\n    cam_infos = sorted(cam_infos_unsorted.copy(), key = lambda x : x.image_name)\n    \n    meta_fname = f\"{path}/meta.json\"\n    if os.path.exists(meta_fname):\n        with open(meta_fname) as file:\n            meta = json.load(file)\n        trans = np.array(meta[\"trans\"], dtype=np.float32)\n        scale = np.array(meta[\"scale\"], dtype=np.float32)\n    else:\n        print(\"No meta.json file found, using default values.\")\n        \n        if not os.path.exists(ply_path):\n            print(\"Converting point3d.bin to .ply, will happen only the first time you open the scene.\")\n            try:\n                xyz, rgb, _ = read_points3D_binary(bin_path)\n            except:\n                xyz, rgb, _ = read_points3D_text(txt_path)\n        #     xyz, rgb = filter_point_cloud(trans, scale, xyz, rgb)\n        #     storePly(ply_path, xyz, rgb)\n        # try:\n        #     pcd = fetchPly(ply_path)\n        # except:\n        #     pcd = None\n        \n        trans, 
scale, bounding_box = bound_by_points(xyz)\n        meta = {\n            'trans': trans.tolist(),\n            'scale': scale.tolist()\n        }\n        with open(meta_fname, \"w\") as file:\n            json.dump(meta, file, indent=4)\n\n    if ratio > 0:\n        len_train = int(len(cam_infos) * ratio)\n        llffhold = len(cam_infos) // len_train\n        train_idx = set([int(i * llffhold) for i in range(len_train)])\n        test_idx = set(range(len(cam_infos))) - train_idx\n        train_cam_infos = [cam_infos[i] for i in train_idx]\n        test_cam_infos = [cam_infos[i] for i in test_idx]\n    elif eval:\n        if split and \"test\" in meta:\n            train_cam_infos = [c for c in cam_infos if c.image_name in meta[\"train\"]]\n            test_cam_infos = [c for c in cam_infos if c.image_name in meta[\"test\"]]\n        else:\n            train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0]\n            test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0]\n    else:\n        train_cam_infos = cam_infos\n        test_cam_infos = []\n        \n    print(f\"Train: {len(train_cam_infos)}, Test: {len(test_cam_infos)}\")\n\n    first_name = test_cam_infos[0].image_name if eval else cam_infos[0].image_name\n    nerf_normalization = getNerfppNorm(train_cam_infos)\n\n    if not os.path.exists(ply_path):\n        print(\"Converting point3d.bin to .ply, will happen only the first time you open the scene.\")\n        try:\n            xyz, rgb, _ = read_points3D_binary(bin_path)\n        except:\n            xyz, rgb, _ = read_points3D_text(txt_path)\n        xyz, rgb = filter_point_cloud(trans, scale, xyz, rgb)\n        storePly(ply_path, xyz, rgb)\n    try:\n        pcd = fetchPly(ply_path)\n    except:\n        pcd = None\n\n    scene_info = SceneInfo(point_cloud=pcd,\n                           train_cameras=train_cam_infos,\n                           test_cameras=test_cam_infos,\n                       
    nerf_normalization=nerf_normalization,\n                           ply_path=ply_path,\n                           trans=trans,\n                           scale=scale,\n                           first_name=first_name)\n    return scene_info\n\ndef readCamerasFromTransforms(path, transformsfile, white_background, extension=\".png\"):\n    cam_infos = []\n\n    with open(os.path.join(path, transformsfile)) as json_file:\n        contents = json.load(json_file)\n        fovx = contents[\"camera_angle_x\"]\n\n        frames = contents[\"frames\"]\n        for idx, frame in enumerate(frames):\n            cam_name = os.path.join(path, frame[\"file_path\"] + extension)\n\n            # NeRF 'transform_matrix' is a camera-to-world transform\n            c2w = np.array(frame[\"transform_matrix\"])\n            # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)\n            c2w[:3, 1:3] *= -1\n\n            # get the world-to-camera transform and set R, T\n            w2c = np.linalg.inv(c2w)\n            R = np.transpose(w2c[:3,:3])  # R is stored transposed due to 'glm' in CUDA code\n            T = w2c[:3, 3]\n\n            image_path = os.path.join(path, cam_name)\n            image_name = Path(cam_name).stem\n            image = Image.open(image_path)\n\n            im_data = np.array(image.convert(\"RGBA\"))\n\n            bg = np.array([1,1,1]) if white_background else np.array([0, 0, 0])\n\n            norm_data = im_data / 255.0\n            arr = norm_data[:,:,:3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4])\n            image = Image.fromarray(np.array(arr*255.0, dtype=np.uint8), \"RGB\")\n\n            fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1])\n            FovY = fovy \n            FovX = fovx\n\n            cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, image=image,\n                            image_path=image_path, image_name=image_name, width=image.size[0], 
height=image.size[1]))\n            \n    return cam_infos\n\ndef readNerfSyntheticInfo(path, white_background, eval, extension=\".png\"):\n    print(\"Reading Training Transforms\")\n    train_cam_infos = readCamerasFromTransforms(path, \"transforms_train.json\", white_background, extension)\n    print(\"Reading Test Transforms\")\n    test_cam_infos = readCamerasFromTransforms(path, \"transforms_test.json\", white_background, extension)\n    \n    if not eval:\n        train_cam_infos.extend(test_cam_infos)\n        test_cam_infos = []\n\n    nerf_normalization = getNerfppNorm(train_cam_infos)\n\n    ply_path = os.path.join(path, \"points3d.ply\")\n    if not os.path.exists(ply_path):\n        # Since this data set has no colmap data, we start with random points\n        num_pts = 100_000\n        print(f\"Generating random point cloud ({num_pts})...\")\n        \n        # We create random points inside the bounds of the synthetic Blender scenes\n        xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3\n        shs = np.random.random((num_pts, 3)) / 255.0\n        pcd = BasicPointCloud(points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3)))\n\n        storePly(ply_path, xyz, SH2RGB(shs) * 255)\n    try:\n        pcd = fetchPly(ply_path)\n    except:\n        pcd = None\n\n    scene_info = SceneInfo(point_cloud=pcd,\n                           train_cameras=train_cam_infos,\n                           test_cameras=test_cam_infos,\n                           nerf_normalization=nerf_normalization,\n                           ply_path=ply_path)\n    return scene_info\n\nsceneLoadTypeCallbacks = {\n    \"Colmap\": readColmapSceneInfo,\n    \"Blender\" : readNerfSyntheticInfo\n}"
  },
  {
    "path": "scene/gaussian_model.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport os\nimport torch\nimport numpy as np\nfrom torch import nn\nfrom copy import deepcopy\ntry:\n    from simple_knn._C import distCUDA2\nexcept ModuleNotFoundError:\n    pass\nfrom plyfile import PlyData, PlyElement\nfrom io import BytesIO\nfrom tqdm import trange\n\nfrom tools.sh_utils import RGB2SH\nfrom tools.system_utils import mkdir_p\nfrom tools.graphics_utils import BasicPointCloud\nfrom tools.math_utils import normalize_pts, get_inside_normalized\nfrom tools.general_utils import strip_symmetric, build_scaling_rotation\nfrom tools.general_utils import inverse_sigmoid, get_expon_lr_func, build_rotation\nfrom tools.denoise_pcd import remove_radius_outlier\nfrom scene.appearance_network import AppearanceNetwork\nfrom tools.semantic_id import BACKGROUND\n\n\nclass GaussianModel:\n    def setup_functions(self):\n        def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):\n            L = build_scaling_rotation(scaling_modifier * scaling, rotation)\n            actual_covariance = L @ L.transpose(1, 2)\n            symm = strip_symmetric(actual_covariance)\n            return symm\n        \n        self.scaling_activation = torch.exp\n        self.scaling_inverse_activation = torch.log\n\n        self.covariance_activation = build_covariance_from_scaling_rotation\n\n        self.opacity_activation = torch.sigmoid\n        self.inverse_opacity_activation = inverse_sigmoid\n\n        self.rotation_activation = torch.nn.functional.normalize\n        \n    def __init__(self, cfg):\n        self.active_sh_degree = 0\n        self.max_sh_degree = cfg.sh_degree  \n        self._xyz = torch.empty(0)\n        self._features_dc = 
torch.empty(0)\n        self._features_rest = torch.empty(0)\n        self._scaling = torch.empty(0)\n        self._rotation = torch.empty(0)\n        self._opacity = torch.empty(0)\n        self.max_radii2D = torch.empty(0)\n        self.xyz_gradient_accum = torch.empty(0)\n        self.denom = torch.empty(0)\n        self.optimizer = None\n        self.percent_dense = 0\n        self.spatial_lr_scale = 0\n        self.setup_functions()\n        self.max_mem = cfg.max_mem\n        \n        self.use_decoupled_appearance = cfg.use_decoupled_appearance\n        if self.use_decoupled_appearance:\n            # appearance network and appearance embedding\n            self.appearance_network = AppearanceNetwork(3+64, 3).cuda()\n            std = 1e-4\n            num_embedding = len(os.listdir(os.path.join(cfg.source_path, 'images')))\n            self._appearance_embeddings = nn.Parameter(torch.empty(num_embedding, 64).cuda())\n            self._appearance_embeddings.data.normal_(0, std)\n        \n        self.enable_semantic = cfg.enable_semantic\n        self._objects_dc = torch.empty(0)\n        if self.enable_semantic:\n            self.ch_sem_feat = cfg.ch_sem_feat\n            self.num_cls = cfg.num_cls\n            self.classifier = torch.nn.Conv2d(self.ch_sem_feat, self.num_cls, kernel_size=1).cuda()\n\n    def capture(self):\n        return (\n            self.active_sh_degree,\n            self._xyz,\n            self._features_dc,\n            self._features_rest,\n            self._scaling,\n            self._rotation,\n            self._opacity,\n            self._objects_dc,\n            self.max_radii2D,\n            self.xyz_gradient_accum,\n            self.denom,\n            self.optimizer.state_dict(),\n            self.spatial_lr_scale,\n        )\n    \n    def restore(self, model_args, training_args):\n        (self.active_sh_degree, \n        self._xyz, \n        self._features_dc, \n        self._features_rest,\n        self._scaling, \n      
  self._rotation, \n        self._opacity,\n        self._objects_dc,\n        self.max_radii2D, \n        xyz_gradient_accum, \n        denom,\n        opt_dict, \n        self.spatial_lr_scale,\n        ) = model_args\n        self.training_setup(training_args)\n        self.xyz_gradient_accum = xyz_gradient_accum\n        self.denom = denom\n        self.optimizer.load_state_dict(opt_dict)\n\n    @property\n    def get_scaling(self):\n        scaling = self._scaling\n        return self.scaling_activation(scaling)\n    \n    @property\n    def get_rotation(self):\n        return self.rotation_activation(self._rotation)\n    \n    @property\n    def get_xyz(self):\n        return self._xyz\n    \n    @property\n    def get_features(self):\n        features_dc = self._features_dc\n        features_rest = self._features_rest\n        return torch.cat((features_dc, features_rest), dim=1)\n    \n    @property\n    def get_objects(self):\n        return self._objects_dc\n    \n    def get_cls(self, idx=None):\n        assert self.enable_semantic, \"Semantic feature is not enabled\"\n        feats = self.get_objects.permute(0, 2, 1)[..., None]\n        if idx is not None: feats = feats[idx]\n        return self.classifier(feats).view(-1, self.num_cls).argmax(-1)\n    \n    def logits_2_label(self, logits):\n        return torch.argmax(self.logits2prob(logits), dim=-1)\n    \n    def logits2prob(self, logits):\n        return torch.nn.functional.softmax(logits, dim=-1)\n    \n    @property\n    def get_opacity(self):\n        return self.opacity_activation(self._opacity)\n    \n    def get_apperance_embedding(self, idx):\n        return self._appearance_embeddings[idx]\n    \n    # @property\n    def get_normal(self, valid=None, idx=None, refine_sign=True, is_all=False):\n        '''\n            rots: N, 3, 3\n        '''\n        normal = None\n        if valid is None:\n            if is_all:\n                valid = torch.ones(self.get_xyz.shape[0], device='cuda', 
dtype=torch.bool)\n            else:\n                valid = self.get_inside_gaus_normalized()[0]\n                normal = torch.zeros_like(self.get_xyz)\n        \n        _rot = self.get_rotation[valid]\n        if idx is not None: _rot = _rot[idx]\n        \n        rots = build_rotation(_rot)\n        scaling = self.get_scaling[valid]\n        if idx is not None: scaling = scaling[idx]\n        axis = torch.argmin(scaling, dim=-1)\n        normals = rots.gather(2, axis[:, None, None].expand(-1, 3, -1)).squeeze(-1)\n        \n        if normal is not None:\n            normal[valid] = normals\n            normals = normal\n        return normals\n    \n    def get_covariance(self, scaling_modifier = 1):\n        return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation)\n\n    def oneupSHdegree(self):\n        if self.active_sh_degree < self.max_sh_degree:\n            self.active_sh_degree += 1\n\n    def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float):\n        self.spatial_lr_scale = spatial_lr_scale\n        fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda()\n        fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda())\n        features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda()\n        features[:, :3, 0 ] = fused_color\n        features[:, 3:, 1:] = 0.0\n\n        print(\"Number of points at initialisation : \", fused_point_cloud.shape[0])\n\n        dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001)\n        scales = torch.log(torch.sqrt(dist2))[...,None].repeat(1, 3)\n        rots = torch.zeros((fused_point_cloud.shape[0], 4), device=\"cuda\")\n        rots[:, 0] = 1\n\n        opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device=\"cuda\"))\n\n        self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))\n  
      self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True))\n        self._features_rest = nn.Parameter(features[:,:,1:].transpose(1, 2).contiguous().requires_grad_(True))\n        self._scaling = nn.Parameter(scales.requires_grad_(True))\n        self._rotation = nn.Parameter(rots.requires_grad_(True))\n        self._opacity = nn.Parameter(opacities.requires_grad_(True))\n        self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device=\"cuda\")\n        \n        if self.enable_semantic:\n            # random init obj_id now\n            fused_objects = RGB2SH(torch.rand((fused_point_cloud.shape[0], self.ch_sem_feat), device=\"cuda\"))\n            fused_objects = fused_objects[:,:,None]\n            self._objects_dc = nn.Parameter(fused_objects.transpose(1, 2).contiguous().requires_grad_(True))\n\n    def training_setup(self, training_args, neural_sdf_params=None):\n        self.percent_dense = training_args.percent_dense\n        self.large_percent_dense = None\n        if hasattr(training_args, 'densify_large'):\n            self.large_percent_dense = training_args.densify_large.percent_dense if \\\n                getattr(training_args.densify_large, 'percent_dense', 0) > 0 else None\n        self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device=\"cuda\")\n        self.denom = torch.zeros((self.get_xyz.shape[0], 1), device=\"cuda\")\n\n        l = [\n            {'params': [self._xyz], 'lr': training_args.position_lr_init * self.spatial_lr_scale, \"name\": \"xyz\"},\n            {'params': [self._features_dc], 'lr': training_args.feature_lr, \"name\": \"f_dc\"},\n            {'params': [self._features_rest], 'lr': training_args.feature_lr / 20.0, \"name\": \"f_rest\"},\n            {'params': [self._opacity], 'lr': training_args.opacity_lr, \"name\": \"opacity\"},\n            {'params': [self._scaling], 'lr': training_args.scaling_lr, \"name\": \"scaling\"},\n            {'params': 
[self._rotation], 'lr': training_args.rotation_lr, \"name\": \"rotation\"},\n        ]\n        if self.use_decoupled_appearance:\n            l.append({'params': [self._appearance_embeddings], 'lr': training_args.appearance_embeddings_lr, \"name\": \"appearance_embeddings\"})\n            l.append({'params': self.appearance_network.parameters(), 'lr': training_args.appearance_network_lr, \"name\": \"appearance_network\"})\n        if self.enable_semantic:\n            l.append({'params': [self._objects_dc], 'lr': training_args.feature_lr, \"name\": \"obj_dc\"})\n            l.append({'params': self.classifier.parameters(), 'lr': training_args.cls_lr, \"name\": \"classifier\"})\n        if neural_sdf_params is not None:\n            l.append({'params': neural_sdf_params.parameters(), 'lr': training_args.sdf_lr, \"name\": \"neural_sdf\"})\n\n        self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15)\n        self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init*self.spatial_lr_scale,\n                                                    lr_final=training_args.position_lr_final*self.spatial_lr_scale,\n                                                    lr_delay_mult=training_args.position_lr_delay_mult,\n                                                    max_steps=training_args.position_lr_max_steps)\n\n    def update_learning_rate(self, iteration):\n        ''' Learning rate scheduling per step '''\n        for param_group in self.optimizer.param_groups:\n            if param_group[\"name\"] == \"xyz\":\n                lr = self.xyz_scheduler_args(iteration)\n                param_group['lr'] = lr\n                return lr\n\n    def construct_list_of_attributes(self):\n        l = ['x', 'y', 'z', 'nx', 'ny', 'nz']\n        # All channels except the 3 DC\n        for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]):\n            l.append('f_dc_{}'.format(i))\n        for i in 
range(self._features_rest.shape[1]*self._features_rest.shape[2]):\n            l.append('f_rest_{}'.format(i))\n        l.append('opacity')\n        for i in range(self._scaling.shape[1]):\n            l.append('scale_{}'.format(i))\n        for i in range(self._rotation.shape[1]):\n            l.append('rot_{}'.format(i))\n        if self.enable_semantic:\n            for i in range(self._objects_dc.shape[1]*self._objects_dc.shape[2]):\n                l.append('obj_dc_{}'.format(i))\n        return l\n    \n    def save_ply(self, path):\n        mkdir_p(os.path.dirname(path))\n\n        xyz = self._xyz.detach().cpu().numpy()\n        normals = np.zeros_like(xyz)\n        f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()\n        f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()\n        opacities = self._opacity.detach().cpu().numpy()\n        scale = self._scaling.detach().cpu().numpy()\n        rotation = self._rotation.detach().cpu().numpy()\n        if self.enable_semantic:\n            obj_dc = self._objects_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()\n\n        dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()]\n\n        elements = np.empty(xyz.shape[0], dtype=dtype_full)\n        attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1)\n        if self.enable_semantic:\n            attributes = np.concatenate((attributes, obj_dc), axis=1)\n        \n        elements[:] = list(map(tuple, attributes))\n        el = PlyElement.describe(elements, 'vertex')\n        PlyData([el]).write(path)\n        \n        state_dict = {}\n        if self.use_decoupled_appearance:\n            state_dict[\"appearance_network\"] = self.appearance_network.state_dict()\n            state_dict[\"appearance_embeddings\"] = self._appearance_embeddings\n        if 
self.enable_semantic:\n            state_dict[\"classifier\"] = self.classifier.state_dict()\n        if len(state_dict) > 0:\n            torch.save(state_dict, os.path.join(os.path.dirname(path), 'model.pth'))\n\n    @torch.no_grad()\n    def save_inside_ply(self, path, inside=None):\n        mkdir_p(os.path.dirname(path))\n        \n        if inside is None:\n            inside = self.get_inside_gaus_normalized()[0]\n        \n        xyz = self._xyz[inside].detach()\n        _normals = self.get_normal(inside, refine_sign=True).detach()\n        normals = _normals\n        \n        inside = inside.cpu().numpy()\n        xyz = xyz.cpu().numpy()\n        normals = normals.cpu().numpy()\n        \n        f_dc = self._features_dc[inside].detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()\n        f_rest = self._features_rest[inside].detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()\n        opacities = self._opacity[inside].detach().cpu().numpy()\n        scale = self._scaling[inside].detach().cpu().numpy()\n        rotation = self._rotation[inside].detach().cpu().numpy()\n        if self.enable_semantic:\n            obj_dc = self._objects_dc[inside].detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()\n\n        dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()]\n\n        elements = np.empty(xyz.shape[0], dtype=dtype_full)\n        attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1)\n        if self.enable_semantic:\n            attributes = np.concatenate((attributes, obj_dc), axis=1)\n        elements[:] = list(map(tuple, attributes))\n        el = PlyElement.describe(elements, 'vertex')\n        PlyData([el]).write(path)\n\n    def save_visi_ply(self, path, visi):\n        inside = self.get_inside_gaus_normalized()[0]\n        inside = inside & visi\n        \n        self.save_inside_ply(path, inside)\n\n    def 
reset_opacity(self):\n        opacities_new = inverse_sigmoid(torch.min(self.get_opacity, torch.ones_like(self.get_opacity)*0.01))\n        optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, \"opacity\")\n        self._opacity = optimizable_tensors[\"opacity\"]\n\n    def load_ply(self, path):\n        plydata = PlyData.read(path)\n\n        xyz = np.stack((np.asarray(plydata.elements[0][\"x\"]),\n                        np.asarray(plydata.elements[0][\"y\"]),\n                        np.asarray(plydata.elements[0][\"z\"])),  axis=1)\n        opacities = np.asarray(plydata.elements[0][\"opacity\"])[..., np.newaxis]\n\n        features_dc = np.zeros((xyz.shape[0], 3, 1))\n        features_dc[:, 0, 0] = np.asarray(plydata.elements[0][\"f_dc_0\"])\n        features_dc[:, 1, 0] = np.asarray(plydata.elements[0][\"f_dc_1\"])\n        features_dc[:, 2, 0] = np.asarray(plydata.elements[0][\"f_dc_2\"])\n\n        extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith(\"f_rest_\")]\n        extra_f_names = sorted(extra_f_names, key = lambda x: int(x.split('_')[-1]))\n        assert len(extra_f_names)==3*(self.max_sh_degree + 1) ** 2 - 3\n        features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))\n        for idx, attr_name in enumerate(extra_f_names):\n            features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name])\n        # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC)\n        features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1))\n\n        scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith(\"scale_\")]\n        scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1]))\n        scales = np.zeros((xyz.shape[0], len(scale_names)))\n        for idx, attr_name in enumerate(scale_names):\n            scales[:, idx] = np.asarray(plydata.elements[0][attr_name])\n\n        rot_names = [p.name for p 
in plydata.elements[0].properties if p.name.startswith(\"rot\")]\n        rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1]))\n        rots = np.zeros((xyz.shape[0], len(rot_names)))\n        for idx, attr_name in enumerate(rot_names):\n            rots[:, idx] = np.asarray(plydata.elements[0][attr_name])\n        \n        if self.enable_semantic:\n            objects_dc = np.zeros((xyz.shape[0], self.ch_sem_feat, 1))\n            for idx in range(self.ch_sem_feat):\n                objects_dc[:,idx,0] = np.asarray(plydata.elements[0][\"obj_dc_\"+str(idx)])\n            \n        self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device=\"cuda\").requires_grad_(True))\n        self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device=\"cuda\").transpose(1, 2).contiguous().requires_grad_(True))\n        self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device=\"cuda\").transpose(1, 2).contiguous().requires_grad_(True))\n        self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device=\"cuda\").requires_grad_(True))\n        self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device=\"cuda\").requires_grad_(True))\n        self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device=\"cuda\").requires_grad_(True))\n        if self.enable_semantic:\n            self._objects_dc = nn.Parameter(torch.tensor(objects_dc, dtype=torch.float, device=\"cuda\").transpose(1, 2).contiguous().requires_grad_(True))\n        \n        self.active_sh_degree = self.max_sh_degree\n        \n        ckpt_path = os.path.join(os.path.dirname(path), 'model.pth')\n        if os.path.exists(ckpt_path):\n            state_dict = torch.load(ckpt_path)\n            if self.enable_semantic:\n                self.classifier.load_state_dict(state_dict[\"classifier\"])\n            if self.use_decoupled_appearance:\n                
self.appearance_network.load_state_dict(state_dict[\"appearance_network\"])\n                self._appearance_embeddings = nn.Parameter(state_dict[\"appearance_embeddings\"].cuda())\n\n    def replace_tensor_to_optimizer(self, tensor, name):\n        optimizable_tensors = {}\n        for group in self.optimizer.param_groups:\n            if group[\"name\"] in [\"appearance_embeddings\", \"appearance_network\", \"classifier\"]:\n                continue\n            if group[\"name\"] == name:\n                stored_state = self.optimizer.state.get(group['params'][0], None)\n                stored_state[\"exp_avg\"] = torch.zeros_like(tensor)\n                stored_state[\"exp_avg_sq\"] = torch.zeros_like(tensor)\n\n                del self.optimizer.state[group['params'][0]]\n                group[\"params\"][0] = nn.Parameter(tensor.requires_grad_(True))\n                self.optimizer.state[group['params'][0]] = stored_state\n\n                optimizable_tensors[group[\"name\"]] = group[\"params\"][0]\n        return optimizable_tensors\n\n    def _prune_optimizer(self, mask):\n        optimizable_tensors = {}\n        for group in self.optimizer.param_groups:\n            if group[\"name\"] in [\"appearance_embeddings\", \"appearance_network\", \"classifier\"]:\n                continue\n            stored_state = self.optimizer.state.get(group['params'][0], None)\n            if stored_state is not None:\n                stored_state[\"exp_avg\"] = stored_state[\"exp_avg\"][mask]\n                stored_state[\"exp_avg_sq\"] = stored_state[\"exp_avg_sq\"][mask]\n\n                del self.optimizer.state[group['params'][0]]\n                group[\"params\"][0] = nn.Parameter((group[\"params\"][0][mask].requires_grad_(True)))\n                self.optimizer.state[group['params'][0]] = stored_state\n\n                optimizable_tensors[group[\"name\"]] = group[\"params\"][0]\n            else:\n                group[\"params\"][0] = 
nn.Parameter(group[\"params\"][0][mask].requires_grad_(True))\n                optimizable_tensors[group[\"name\"]] = group[\"params\"][0]\n        return optimizable_tensors\n\n    def prune_points(self, mask):\n        valid_points_mask = ~mask\n        optimizable_tensors = self._prune_optimizer(valid_points_mask)\n\n        self._xyz = optimizable_tensors[\"xyz\"]\n        self._features_dc = optimizable_tensors[\"f_dc\"]\n        self._features_rest = optimizable_tensors[\"f_rest\"]\n        self._opacity = optimizable_tensors[\"opacity\"]\n        self._scaling = optimizable_tensors[\"scaling\"]\n        self._rotation = optimizable_tensors[\"rotation\"]\n        if self.enable_semantic:\n            self._objects_dc = optimizable_tensors[\"obj_dc\"]\n\n        self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask]\n\n        self.denom = self.denom[valid_points_mask]\n        self.max_radii2D = self.max_radii2D[valid_points_mask]\n\n    def cat_tensors_to_optimizer(self, tensors_dict):\n        optimizable_tensors = {}\n        for group in self.optimizer.param_groups:\n            if group[\"name\"] in [\"appearance_embeddings\", \"appearance_network\", \"classifier\"]:\n                continue\n            assert len(group[\"params\"]) == 1\n            extension_tensor = tensors_dict[group[\"name\"]]\n            stored_state = self.optimizer.state.get(group['params'][0], None)\n            if stored_state is not None:\n\n                stored_state[\"exp_avg\"] = torch.cat((stored_state[\"exp_avg\"], torch.zeros_like(extension_tensor)), dim=0)\n                stored_state[\"exp_avg_sq\"] = torch.cat((stored_state[\"exp_avg_sq\"], torch.zeros_like(extension_tensor)), dim=0)\n\n                del self.optimizer.state[group['params'][0]]\n                group[\"params\"][0] = nn.Parameter(torch.cat((group[\"params\"][0], extension_tensor), dim=0).requires_grad_(True))\n                self.optimizer.state[group['params'][0]] = 
stored_state\n\n                optimizable_tensors[group[\"name\"]] = group[\"params\"][0]\n            else:\n                group[\"params\"][0] = nn.Parameter(torch.cat((group[\"params\"][0], extension_tensor), dim=0).requires_grad_(True))\n                optimizable_tensors[group[\"name\"]] = group[\"params\"][0]\n\n        return optimizable_tensors\n\n    def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation, new_objects_dc=None, reset=True):\n        d = {\"xyz\": new_xyz,\n        \"f_dc\": new_features_dc,\n        \"f_rest\": new_features_rest,\n        \"opacity\": new_opacities,\n        \"scaling\" : new_scaling,\n        \"rotation\" : new_rotation}\n        if self.enable_semantic:\n            d[\"obj_dc\"] = new_objects_dc\n\n        optimizable_tensors = self.cat_tensors_to_optimizer(d)\n        self._xyz = optimizable_tensors[\"xyz\"]\n        self._features_dc = optimizable_tensors[\"f_dc\"]\n        self._features_rest = optimizable_tensors[\"f_rest\"]\n        self._opacity = optimizable_tensors[\"opacity\"]\n        self._scaling = optimizable_tensors[\"scaling\"]\n        self._rotation = optimizable_tensors[\"rotation\"]\n        if self.enable_semantic:\n            self._objects_dc = optimizable_tensors[\"obj_dc\"]\n\n        if reset:\n            self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device=\"cuda\")\n            self.denom = torch.zeros((self.get_xyz.shape[0], 1), device=\"cuda\")\n            self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device=\"cuda\")\n        else:\n            self.xyz_gradient_accum = torch.cat((self.xyz_gradient_accum, torch.zeros((new_xyz.shape[0], 1), device=\"cuda\")), dim=0)\n            self.denom = torch.cat((self.denom, torch.zeros((new_xyz.shape[0], 1), device=\"cuda\")), dim=0)\n            self.max_radii2D = torch.cat((self.max_radii2D, torch.zeros((new_xyz.shape[0]), device=\"cuda\")), dim=0)\n\n  
  def densify_and_split(self, grads, grad_threshold, scene_extent, visi=None, N=2):\n        n_init_points = self.get_xyz.shape[0]\n        # Extract points that satisfy the gradient condition\n        padded_grad = torch.zeros((n_init_points), device=\"cuda\")\n        padded_grad[:grads.shape[0]] = grads.squeeze()\n        selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False)\n        selected_pts_mask = torch.logical_and(selected_pts_mask,\n                                              torch.max(self.get_scaling, dim=1).values > self.percent_dense*scene_extent)\n\n        if self.large_percent_dense is not None:\n            densify_pts_mask = torch.max(self.get_scaling, dim=1).values > self.large_percent_dense * scene_extent\n            inside, _ = self.get_inside_gaus_normalized()\n            densify_pts_mask = torch.logical_and(densify_pts_mask, inside)\n            if visi is not None:\n                padded_vis = torch.zeros((n_init_points), device=\"cuda\").bool()\n                padded_vis[:visi.shape[0]] = visi\n                densify_pts_mask = torch.logical_and(densify_pts_mask, padded_vis)\n            selected_pts_mask = torch.logical_or(selected_pts_mask, densify_pts_mask)\n            \n        stds = self.get_scaling[selected_pts_mask].repeat(N,1)\n        means =torch.zeros((stds.size(0), 3),device=\"cuda\")\n        samples = torch.normal(mean=means, std=stds)\n        rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N,1,1)\n        new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[selected_pts_mask].repeat(N, 1)\n        new_scaling = self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(N,1) / (0.8*N))\n        new_rotation = self._rotation[selected_pts_mask].repeat(N,1)\n        new_features_dc = self._features_dc[selected_pts_mask].repeat(N,1,1)\n        new_features_rest = self._features_rest[selected_pts_mask].repeat(N,1,1)\n        new_opacity = 
self._opacity[selected_pts_mask].repeat(N,1)\n        new_objects_dc = self._objects_dc[selected_pts_mask].repeat(N,1,1) if self.enable_semantic else None\n\n        self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacity, new_scaling, new_rotation, new_objects_dc)\n\n        prune_filter = torch.cat((selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device=\"cuda\", dtype=bool)))\n        self.prune_points(prune_filter)\n    \n    def get_dir_max_scaling(self, scaling, rots):\n        '''\n            rots: N, 3, 3\n        '''\n        axis = torch.argmax(scaling, dim=-1)\n        max_scaling = scaling[torch.arange(scaling.shape[0]), axis]\n        dirs = rots.gather(2, axis[:, None, None].expand(-1, 3, -1)).squeeze(-1)\n        \n        return dirs, max_scaling, axis\n    \n    def densify_and_split_along_maxscaling(self, grads, grad_threshold, scene_extent, visi=None, N=2, n_std=2):\n        n_init_points = self.get_xyz.shape[0]\n        # Extract points that satisfy the gradient condition\n        padded_grad = torch.zeros((n_init_points), device=\"cuda\")\n        padded_grad[:grads.shape[0]] = grads.squeeze()\n        selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False)\n        selected_pts_mask = torch.logical_and(selected_pts_mask,\n                                              torch.max(self.get_scaling, dim=1).values > self.percent_dense*scene_extent)\n\n        if self.large_percent_dense is not None and (torch.cuda.memory_allocated(0) / 1024**3 < self.max_mem):\n            densify_pts_mask = torch.max(self.get_scaling, dim=1).values > self.large_percent_dense * scene_extent\n            inside, _ = self.get_inside_gaus_normalized()\n            densify_pts_mask = torch.logical_and(densify_pts_mask, inside)\n            if visi is not None:\n                padded_vis = torch.zeros((n_init_points), device=\"cuda\").bool()\n                padded_vis[:visi.shape[0]] = visi\n                
densify_pts_mask = torch.logical_and(densify_pts_mask, padded_vis)\n            selected_pts_mask = torch.logical_or(selected_pts_mask, densify_pts_mask)\n            \n        scaling = self.get_scaling[selected_pts_mask]\n        rots = build_rotation(self._rotation[selected_pts_mask])\n        dirs, max_scaling, axis = self.get_dir_max_scaling(scaling, rots)\n        radii = (n_std * max_scaling / 3.)[..., None] # 3 std\n        new_xyz1 = self.get_xyz[selected_pts_mask] + dirs * radii\n        new_xyz2 = self.get_xyz[selected_pts_mask] - dirs * radii\n        new_xyz = torch.cat((new_xyz1, new_xyz2), dim=0)\n        \n        new_scaling = scaling.detach().clone()\n        new_scaling[torch.arange(new_scaling.shape[0]), axis] = max_scaling / (0.8*N)\n        new_scaling = self.scaling_inverse_activation(new_scaling)\n        new_scaling = torch.cat((new_scaling, new_scaling), dim=0)\n        \n        new_rotation = self._rotation[selected_pts_mask]\n        new_rotation = torch.cat((new_rotation, new_rotation), dim=0)\n        \n        new_features_dc = self._features_dc[selected_pts_mask]\n        new_features_dc = torch.cat((new_features_dc, new_features_dc), dim=0)\n        new_features_rest = self._features_rest[selected_pts_mask]\n        new_features_rest = torch.cat((new_features_rest, new_features_rest), dim=0)\n        \n        new_opacity = self._opacity[selected_pts_mask]\n        new_opacity = torch.cat((new_opacity, new_opacity), dim=0)\n        new_opacity = self._opacity[selected_pts_mask].repeat(N,1)\n        new_objects_dc = self._objects_dc[selected_pts_mask].repeat(N,1,1) if self.enable_semantic else None\n        \n        self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacity, new_scaling, new_rotation, new_objects_dc)\n\n        prune_filter = torch.cat((selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device=\"cuda\", dtype=bool)))\n        self.prune_points(prune_filter)\n    \n    def 
densify_and_clone(self, grads, grad_threshold, scene_extent):\n        # Extract points that satisfy the gradient condition\n        selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False)\n        selected_pts_mask = torch.logical_and(selected_pts_mask,\n                                              torch.max(self.get_scaling, dim=1).values <= self.percent_dense*scene_extent)\n        \n        new_xyz = self._xyz[selected_pts_mask]\n        new_features_dc = self._features_dc[selected_pts_mask]\n        new_features_rest = self._features_rest[selected_pts_mask]\n        new_opacities = self._opacity[selected_pts_mask]\n        new_scaling = self._scaling[selected_pts_mask]\n        new_rotation = self._rotation[selected_pts_mask]\n        new_objects_dc = self._objects_dc[selected_pts_mask] if self.enable_semantic else None\n\n        self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation, new_objects_dc)\n\n    def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size, visi=None):\n        grads = self.xyz_gradient_accum / self.denom\n        grads[grads.isnan()] = 0.0\n\n        self.densify_and_clone(grads, max_grad, extent)\n        self.densify_and_split_along_maxscaling(grads, max_grad, extent, visi=visi)\n        \n        prune_mask = (self.get_opacity < min_opacity).squeeze()\n        if max_screen_size:\n            big_points_vs = self.max_radii2D > max_screen_size\n            big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent\n            prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws)\n        self.prune_points(prune_mask)\n\n        torch.cuda.empty_cache()\n\n    def prune_gaussians(self, percent, import_score: list):\n        sorted_tensor, _ = torch.sort(import_score, dim=0)\n        index_nth_percentile = int(percent * (sorted_tensor.shape[0] - 1))\n        value_nth_percentile = 
sorted_tensor[index_nth_percentile]\n        prune_mask = (import_score <= value_nth_percentile).squeeze()\n        # TODO(Kevin) Emergent, change it back. This is just for testing\n        self.prune_points(prune_mask)\n\n    def add_densification_stats(self, viewspace_point_tensor, update_filter):\n        self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[update_filter,:2], dim=-1, keepdim=True)\n        self.denom[update_filter] += 1\n    \n    def get_inside_gaus_normalized(self):\n        inside, pts = get_inside_normalized(self.get_xyz, self.trans, self.scale)\n        return inside, pts\n    \n    def normalize_pts(self, pts):\n        pts = normalize_pts(pts, self.trans, self.scale)\n        return pts\n    \n    def filter_points(self, nb_points=5, radius=0.01, std_ratio=0.01):\n        inside, _ = self.get_inside_gaus_normalized()\n        \n        xyz = self.get_xyz[inside]\n        filte_valid = remove_radius_outlier(xyz, nb_points, radius*self.extent)\n        inside[inside.clone()] = filte_valid\n        return inside\n    \n    def prune_outside(self):\n        inside, _ = self.get_inside_gaus_normalized()\n        self.prune_points(~inside)\n\n    def prune_outliers(self):\n        mask = torch.ones(self.get_xyz.shape[0], dtype=torch.bool, device=\"cuda\")\n        valid = self.filter_points()\n        mask[valid] = False\n        self.prune_points(mask)\n\n    def prune_semantics(self, cls=BACKGROUND):\n        mask = torch.ones(self.get_xyz.shape[0], dtype=torch.bool, device=\"cuda\")\n        mask[self.get_cls() != cls] = False\n        self.prune_points(mask)\n    \n\nif __name__ == '__main__':\n    model = GaussianModel(2)\n    m2 = deepcopy(model)\n    \n"
  },
  {
    "path": "tools/__init__.py",
    "content": ""
  },
  {
    "path": "tools/camera.py",
    "content": "'''\n-----------------------------------------------------------------------------\nCopyright (c) 2023, NVIDIA CORPORATION. All rights reserved.\n\nNVIDIA CORPORATION and its licensors retain all intellectual property\nand proprietary rights in and to this software, related documentation\nand any modifications thereto. Any use, reproduction, disclosure or\ndistribution of this software and related documentation without an express\nlicense agreement from NVIDIA CORPORATION is strictly prohibited.\n-----------------------------------------------------------------------------\n'''\n\nimport numpy as np\nimport torch\n\n\nclass Pose():\n    \"\"\"\n    A class of operations on camera poses (PyTorch tensors with shape [...,3,4]).\n    Each [3,4] camera pose takes the form of [R|t].\n    \"\"\"\n\n    def __call__(self, R=None, t=None):\n        # Construct a camera pose from the given R and/or t.\n        assert R is not None or t is not None\n        if R is None:\n            if not isinstance(t, torch.Tensor):\n                t = torch.tensor(t)\n            R = torch.eye(3, device=t.device).repeat(*t.shape[:-1], 1, 1)\n        elif t is None:\n            if not isinstance(R, torch.Tensor):\n                R = torch.tensor(R)\n            t = torch.zeros(R.shape[:-1], device=R.device)\n        else:\n            if not isinstance(R, torch.Tensor):\n                R = torch.tensor(R)\n            if not isinstance(t, torch.Tensor):\n                t = torch.tensor(t)\n        assert R.shape[:-1] == t.shape and R.shape[-2:] == (3, 3)\n        R = R.float()\n        t = t.float()\n        pose = torch.cat([R, t[..., None]], dim=-1)\n        assert pose.shape[-2:] == (3, 4)\n        return pose\n\n    def invert(self, pose, use_inverse=False):\n        # Invert a camera pose.\n        R, t = pose[..., :3], pose[..., 3:]\n        R_inv = R.inverse() if use_inverse else R.transpose(-1, -2)\n        t_inv = (-R_inv @ t)[..., 0]\n        pose_inv = 
self(R=R_inv, t=t_inv)\n        return pose_inv\n\n    def compose(self, pose_list):\n        # Compose a sequence of poses together.\n        # pose_new(x) = poseN o ... o pose2 o pose1(x)\n        pose_new = pose_list[0]\n        for pose in pose_list[1:]:\n            pose_new = self.compose_pair(pose_new, pose)\n        return pose_new\n\n    def compose_pair(self, pose_a, pose_b):\n        R_a, t_a = pose_a[..., :3], pose_a[..., 3:]\n        R_b, t_b = pose_b[..., :3], pose_b[..., 3:]\n        R_new = R_b @ R_a\n        t_new = (R_b @ t_a + t_b)[..., 0]\n        pose_new = self(R=R_new, t=t_new)\n        return pose_new\n\n    def scale_center(self, pose, scale):\n        \"\"\"Scale the camera center from the origin.\n        0 = R@c+t --> c = -R^T@t (camera center in world coordinates)\n        0 = R@(sc)+t' --> t' = -R@(sc) = -R@(-R^T@st) = st\n        \"\"\"\n        R, t = pose[..., :3], pose[..., 3:]\n        pose_new = torch.cat([R, t * scale], dim=-1)\n        return pose_new\n\n    def interpolate(self, pose_a, pose_b, alpha):\n        \"\"\"Interpolate between two poses with Slerp.\n        Args:\n            pose_a (tensor [...,3,4]): Pose at time t=0.\n            pose_b (tensor [...,3,4]): Pose at time t=1.\n            alpha (tensor [...,1]): Interpolation parameter.\n        Returns:\n            pose (tensor [...,3,4]): Pose at time t.\n        \"\"\"\n        R_a, t_a = pose_a[..., :3], pose_a[..., 3:]\n        R_b, t_b = pose_b[..., :3], pose_b[..., 3:]\n        q_a = quaternion.R_to_q(R_a)  # [...,4]\n        q_b = quaternion.R_to_q(R_b)  # [...,4]\n        q_intp = quaternion.interpolate(q_a, q_b, alpha)  # [...,4]\n        R_intp = quaternion.q_to_R(q_intp)  # [...,3,3]\n        t_intp = (1 - alpha) * t_a + alpha * t_b  # [...,3]\n        pose_intp = torch.cat([R_intp, t_intp], dim=-1)  # [...,3,4]\n        return pose_intp\n\n\nclass Lie():\n    \"\"\"\n    Lie algebra for SO(3) and SE(3) operations in PyTorch.\n    \"\"\"\n\n    def 
so3_to_SO3(self, w):  # [..., 3]\n        wx = self.skew_symmetric(w)\n        theta = w.norm(dim=-1)[..., None, None]\n        eye = torch.eye(3, device=w.device, dtype=torch.float32)\n        A = self.taylor_A(theta)\n        B = self.taylor_B(theta)\n        R = eye + A * wx + B * wx @ wx\n        return R\n\n    def SO3_to_so3(self, R, eps=1e-7):  # [..., 3, 3]\n        trace = R[..., 0, 0] + R[..., 1, 1] + R[..., 2, 2]\n        theta = ((trace - 1) / 2).clamp(-1 + eps, 1 - eps).acos_()[\n                    ..., None, None] % np.pi  # ln(R) will explode if theta==pi\n        lnR = 1 / (2 * self.taylor_A(theta) + 1e-8) * (R - R.transpose(-2, -1))  # FIXME: wei-chiu finds it weird\n        w0, w1, w2 = lnR[..., 2, 1], lnR[..., 0, 2], lnR[..., 1, 0]\n        w = torch.stack([w0, w1, w2], dim=-1)\n        return w\n\n    def se3_to_SE3(self, wu):  # [...,3]\n        w, u = wu.split([3, 3], dim=-1)\n        wx = self.skew_symmetric(w)\n        theta = w.norm(dim=-1)[..., None, None]\n        eye = torch.eye(3, device=w.device, dtype=torch.float32)\n        A = self.taylor_A(theta)\n        B = self.taylor_B(theta)\n        C = self.taylor_C(theta)\n        R = eye + A * wx + B * wx @ wx\n        V = eye + B * wx + C * wx @ wx\n        Rt = torch.cat([R, (V @ u[..., None])], dim=-1)\n        return Rt\n\n    def SE3_to_se3(self, Rt, eps=1e-8):  # [...,3,4]\n        R, t = Rt.split([3, 1], dim=-1)\n        w = self.SO3_to_so3(R)\n        wx = self.skew_symmetric(w)\n        theta = w.norm(dim=-1)[..., None, None]\n        eye = torch.eye(3, device=w.device, dtype=torch.float32)\n        A = self.taylor_A(theta)\n        B = self.taylor_B(theta)\n        invV = eye - 0.5 * wx + (1 - A / (2 * B)) / (theta ** 2 + eps) * wx @ wx\n        u = (invV @ t)[..., 0]\n        wu = torch.cat([w, u], dim=-1)\n        return wu\n\n    def skew_symmetric(self, w):\n        w0, w1, w2 = w.unbind(dim=-1)\n        zero = torch.zeros_like(w0)\n        wx = 
torch.stack([torch.stack([zero, -w2, w1], dim=-1),\n                          torch.stack([w2, zero, -w0], dim=-1),\n                          torch.stack([-w1, w0, zero], dim=-1)], dim=-2)\n        return wx\n\n    def taylor_A(self, x, nth=10):\n        # Taylor expansion of sin(x)/x.\n        ans = torch.zeros_like(x)\n        denom = 1.\n        for i in range(nth + 1):\n            if i > 0:\n                denom *= (2 * i) * (2 * i + 1)\n            ans = ans + (-1) ** i * x ** (2 * i) / denom\n        return ans\n\n    def taylor_B(self, x, nth=10):\n        # Taylor expansion of (1-cos(x))/x**2.\n        ans = torch.zeros_like(x)\n        denom = 1.\n        for i in range(nth + 1):\n            denom *= (2 * i + 1) * (2 * i + 2)\n            ans = ans + (-1) ** i * x ** (2 * i) / denom\n        return ans\n\n    def taylor_C(self, x, nth=10):\n        # Taylor expansion of (x-sin(x))/x**3.\n        ans = torch.zeros_like(x)\n        denom = 1.\n        for i in range(nth + 1):\n            denom *= (2 * i + 2) * (2 * i + 3)\n            ans = ans + (-1) ** i * x ** (2 * i) / denom\n        return ans\n\n\nclass Quaternion():\n\n    def q_to_R(self, q):  # [...,4]\n        # https://en.wikipedia.org/wiki/Rotation_matrix#Quaternion\n        qa, qb, qc, qd = q.unbind(dim=-1)\n        R = torch.stack(\n            [torch.stack([1 - 2 * (qc ** 2 + qd ** 2), 2 * (qb * qc - qa * qd), 2 * (qa * qc + qb * qd)], dim=-1),\n             torch.stack([2 * (qb * qc + qa * qd), 1 - 2 * (qb ** 2 + qd ** 2), 2 * (qc * qd - qa * qb)], dim=-1),\n             torch.stack([2 * (qb * qd - qa * qc), 2 * (qa * qb + qc * qd), 1 - 2 * (qb ** 2 + qc ** 2)], dim=-1)],\n            dim=-2)\n        return R\n\n    def R_to_q(self, R, eps=1e-6):  # [...,3,3]\n        # https://en.wikipedia.org/wiki/Rotation_matrix#Quaternion\n        row0, row1, row2 = R.unbind(dim=-2)\n        R00, R01, R02 = row0.unbind(dim=-1)\n        R10, R11, R12 = row1.unbind(dim=-1)\n        R20, R21, R22 = 
row2.unbind(dim=-1)\n        t = R[..., 0, 0] + R[..., 1, 1] + R[..., 2, 2]\n        r = (1 + t + eps).sqrt()\n        qa = 0.5 * r\n        qb = (R21 - R12).sign() * 0.5 * (1 + R00 - R11 - R22 + eps).sqrt()\n        qc = (R02 - R20).sign() * 0.5 * (1 - R00 + R11 - R22 + eps).sqrt()\n        qd = (R10 - R01).sign() * 0.5 * (1 - R00 - R11 + R22 + eps).sqrt()\n        q = torch.stack([qa, qb, qc, qd], dim=-1)\n        return q\n\n    def invert(self, q):  # [...,4]\n        qa, qb, qc, qd = q.unbind(dim=-1)\n        norm = q.norm(dim=-1, keepdim=True)\n        q_inv = torch.stack([qa, -qb, -qc, -qd], dim=-1) / norm ** 2\n        return q_inv\n\n    def product(self, q1, q2):  # [...,4]\n        q1a, q1b, q1c, q1d = q1.unbind(dim=-1)\n        q2a, q2b, q2c, q2d = q2.unbind(dim=-1)\n        hamil_prod = torch.stack([q1a * q2a - q1b * q2b - q1c * q2c - q1d * q2d,\n                                  q1a * q2b + q1b * q2a + q1c * q2d - q1d * q2c,\n                                  q1a * q2c - q1b * q2d + q1c * q2a + q1d * q2b,\n                                  q1a * q2d + q1b * q2c - q1c * q2b + q1d * q2a], dim=-1)\n        return hamil_prod\n\n    def interpolate(self, q1, q2, alpha):  # [...,4],[...,4],[...,1]\n        # https://en.wikipedia.org/wiki/Slerp\n        cos_angle = (q1 * q2).sum(dim=-1, keepdim=True)  # [...,1]\n        flip = cos_angle < 0\n        q1 = q1 * (~flip) - q1 * flip  # [...,4]\n        theta = cos_angle.abs().acos()  # [...,1]\n        slerp = (((1 - alpha) * theta).sin() * q1 + (alpha * theta).sin() * q2) / theta.sin()  # [...,4]\n        return slerp\n\n\npose = Pose()\nlie = Lie()\nquaternion = Quaternion()\n\n\ndef to_hom(X):\n    # Get homogeneous coordinates of the input.\n    X_hom = torch.cat([X, torch.ones_like(X[..., :1])], dim=-1)\n    return X_hom\n\n\n# Basic operations of transforming 3D points between world/camera/image coordinates.\ndef world2cam(X, pose):  # [B,N,3]\n    X_hom = to_hom(X)\n    return X_hom @ pose.transpose(-1, 
-2)\n\n\ndef cam2img(X, cam_intr):\n    return X @ cam_intr.transpose(-1, -2)\n\n\ndef img2cam(X, cam_intr):\n    return X @ cam_intr.inverse().transpose(-1, -2)\n\n\ndef cam2world(X, pose):\n    X_hom = to_hom(X)\n    pose_inv = Pose().invert(pose)\n    return X_hom @ pose_inv.transpose(-1, -2)\n\n\ndef angle_to_rotation_matrix(a, axis):\n    # Get the rotation matrix from Euler angle around specific axis.\n    roll = dict(X=1, Y=2, Z=0)[axis]\n    if isinstance(a, float):\n        a = torch.tensor(a)\n    zero = torch.zeros_like(a)\n    eye = torch.ones_like(a)\n    M = torch.stack([torch.stack([a.cos(), -a.sin(), zero], dim=-1),\n                     torch.stack([a.sin(), a.cos(), zero], dim=-1),\n                     torch.stack([zero, zero, eye], dim=-1)], dim=-2)\n    M = M.roll((roll, roll), dims=(-2, -1))\n    return M\n\n\ndef get_center_and_ray(pose, intr, image_size):\n    \"\"\"\n    Args:\n        pose (tensor [3,4]/[B,3,4]): Camera pose.\n        intr (tensor [3,3]/[B,3,3]): Camera intrinsics.\n        image_size (list of int): Image size.\n    Returns:\n        center_3D (tensor [HW,3]/[B,HW,3]): Center of the camera.\n        ray (tensor [HW,3]/[B,HW,3]): Ray of the camera with depth=1 (note: not unit ray).\n    \"\"\"\n    H, W = image_size\n    # Given the intrinsic/extrinsic matrices, get the camera center and ray directions.\n    with torch.no_grad():\n        # Compute image coordinate grid.\n        y_range = torch.arange(H, dtype=torch.float32, device=pose.device).add_(0.5)\n        x_range = torch.arange(W, dtype=torch.float32, device=pose.device).add_(0.5)\n        Y, X = torch.meshgrid(y_range, x_range, indexing=\"ij\")  # [H,W]\n        xy_grid = torch.stack([X, Y], dim=-1).view(-1, 2)  # [HW,2]\n    # Compute center and ray.\n    if len(pose.shape) == 3:\n        batch_size = len(pose)\n        xy_grid = xy_grid.repeat(batch_size, 1, 1)  # [B,HW,2]\n    grid_3D = img2cam(to_hom(xy_grid), intr)  # [HW,3]/[B,HW,3]\n    center_3D = 
torch.zeros_like(grid_3D)  # [HW,3]/[B,HW,3]\n    # Transform from camera to world coordinates.\n    grid_3D = cam2world(grid_3D, pose)  # [HW,3]/[B,HW,3]\n    center_3D = cam2world(center_3D, pose)  # [HW,3]/[B,HW,3]\n    ray = grid_3D - center_3D  # [B,HW,3]\n    return center_3D, ray\n\n\ndef get_3D_points_from_dist(center, ray_unit, dist, multi=True):\n    # Two possible use cases: (1) center + ray_unit * dist, or (2) center + ray * depth\n    if multi:\n        center, ray_unit = center[..., None, :], ray_unit[..., None, :]  # [...,1,3]\n    # x = c+dv\n    points_3D = center + ray_unit * dist  # [...,3]/[...,N,3]\n    return points_3D\n\n\ndef convert_NDC(center, ray, intr, near=1):\n    # Shift camera center (ray origins) to near plane (z=1).\n    # (Unlike conventional NDC, we assume the cameras are facing towards the +z direction.)\n    center = center + (near - center[..., 2:]) / ray[..., 2:] * ray\n    # Projection.\n    cx, cy, cz = center.unbind(dim=-1)  # [...,R]\n    rx, ry, rz = ray.unbind(dim=-1)  # [...,R]\n    scale_x = intr[..., 0, 0] / intr[..., 0, 2]  # [...]\n    scale_y = intr[..., 1, 1] / intr[..., 1, 2]  # [...]\n    cnx = scale_x[..., None] * (cx / cz)\n    cny = scale_y[..., None] * (cy / cz)\n    cnz = 1 - 2 * near / cz\n    rnx = scale_x[..., None] * (rx / rz - cx / cz)\n    rny = scale_y[..., None] * (ry / rz - cy / cz)\n    rnz = 2 * near / cz\n    center_ndc = torch.stack([cnx, cny, cnz], dim=-1)  # [...,R,3]\n    ray_ndc = torch.stack([rnx, rny, rnz], dim=-1)  # [...,R,3]\n    return center_ndc, ray_ndc\n\n\ndef convert_NDC2(center, ray, intr):\n    # Similar to convert_NDC() but shift the ray origins to its own image plane instead of the global near plane.\n    # Also this version is much more interpretable.\n    scale_x = intr[..., 0, 0] / intr[..., 0, 2]  # [...]\n    scale_y = intr[..., 1, 1] / intr[..., 1, 2]  # [...]\n    # Get the metric image plane (i.e. 
new \"center\"): (sx*cx/cz, sy*cy/cz, 1-2/cz).\n    center = center + ray  # This is the key difference.\n    cx, cy, cz = center.unbind(dim=-1)  # [...,R]\n    image_plane = torch.stack([scale_x[..., None] * cx / cz,\n                               scale_y[..., None] * cy / cz,\n                               1 - 2 / cz], dim=-1)\n    # Get the infinity plane: (sx*rx/rz, sy*ry/rz, 1).\n    rx, ry, rz = ray.unbind(dim=-1)  # [...,R]\n    inf_plane = torch.stack([scale_x[..., None] * rx / rz,\n                             scale_y[..., None] * ry / rz,\n                             torch.ones_like(rz)], dim=-1)\n    # The NDC ray is the difference between the two planes, assuming t \\in [0,1].\n    ndc_ray = inf_plane - image_plane\n    return image_plane, ndc_ray\n\n\ndef rotation_distance(R1, R2, eps=1e-7):\n    # http://www.boris-belousov.net/2016/12/01/quat-dist/\n    R_diff = R1 @ R2.transpose(-2, -1)\n    trace = R_diff[..., 0, 0] + R_diff[..., 1, 1] + R_diff[..., 2, 2]\n    angle = ((trace - 1) / 2).clamp(-1 + eps, 1 - eps).acos_()  # numerical stability near -1/+1\n    return angle\n\n\ndef get_oscil_novel_view_poses(N=60, angle=0.05, dist=5):\n    # Create circular viewpoints (small oscillations).\n    theta = torch.arange(N) / N * 2 * np.pi\n    R_x = angle_to_rotation_matrix((theta.sin() * angle).asin(), \"X\")\n    R_y = angle_to_rotation_matrix((theta.cos() * angle).asin(), \"Y\")\n    pose_rot = pose(R=R_y @ R_x)\n    pose_shift = pose(t=[0, 0, dist])\n    pose_oscil = pose.compose([pose.invert(pose_shift), pose_rot, pose_shift])\n    return pose_oscil\n\n\ndef cross_product_matrix(x):\n    \"\"\"Matrix form of cross product operation.\n\n    param x: [3,] tensor.\n    return: [3, 3] tensor representing the matrix form of cross product.\n    \"\"\"\n    return torch.tensor(\n        [[0, -x[2], x[1]],\n         [x[2], 0, -x[0]],\n         [-x[1], x[0], 0, ]]\n    )\n\n\ndef essential_matrix(poses):\n    \"\"\"Compute Essential Matrix from a relative 
pose.\n\n    param poses: [views, 3, 4] tensor representing relative poses.\n    return: [views, 3, 3] tensor representing Essential Matrix.\n    \"\"\"\n    r = poses[..., 0:3]\n    t = poses[..., 3]\n    tx = torch.stack([cross_product_matrix(tt) for tt in t], axis=0)\n    return tx @ r\n\n\ndef fundamental_matrix(poses, intr1, intr2):\n    \"\"\"Compute Fundamental Matrix from a relative pose and intrinsics.\n\n    param poses: [views, 3, 4] tensor representing relative poses.\n          intr1: [3, 3] tensor. Camera intrinsic of reference image.\n          intr2: [views, 3, 3] tensor. Camera Intrinsic of target image.\n    return: [views, 3, 3] tensor representing Fundamental Matrix.\n    \"\"\"\n    return intr2.inverse().transpose(-1, -2) @ essential_matrix(poses) @ intr1.inverse()\n\n\ndef get_ray_depth_plane_intersection(center, ray, depths):\n    \"\"\"Compute the intersection of a ray with a depth plane.\n    Args:\n        center (tensor [B,HW,3]): Camera center of the target pose.\n        ray (tensor [B,HW,3]): Ray direction of the target pose.\n        depth (tensor [L]): The depth values from the source view (e.g. 
for MPI planes).\n    Returns:\n        intsc_points (tensor [B,HW,L,3]): Intersecting 3D points with the MPI.\n    \"\"\"\n    # Each 3D point x along the ray v from center c can be written as x = c+t*v.\n    # Plane equation: n@x = d, where normal n = (0,0,1), d = depth.\n    # --> t = (d-n@c)/(n@v).\n    # --> x = c+t*v = c+(d-n@c)/(n@v)*v.\n    center, ray = center[:, :, None], ray[:, :, None]  # [B,HW,1,3], [B,HW,1,3]\n    depths = depths[None, None, :, None]  # [1,1,L,1]\n    intsc_points = center + (depths - center[..., 2:]) / ray[..., 2:] * ray  # [B,HW,L,3]\n    return intsc_points\n\n\ndef unit_view_vector_to_rotation_matrix(v, axes=\"ZYZ\"):\n    \"\"\"\n    Args:\n        v (tensor [...,3]): Unit vectors on the view sphere.\n        axes: rotation axis order.\n\n    Returns:\n        rotation_matrix (tensor [...,3,3]): rotation matrix R @ v + [0, 0, 1] = 0.\n    \"\"\"\n    alpha = torch.arctan2(v[..., 1], v[..., 0])  # [...]\n    beta = np.pi - v[..., 2].arccos()  # [...]\n    euler_angles = torch.stack([torch.ones_like(alpha) * np.pi / 2, -beta, alpha], dim=-1)  # [...,3]\n    rot2 = angle_to_rotation_matrix(euler_angles[..., 2], axes[2])  # [...,3,3]\n    rot1 = angle_to_rotation_matrix(euler_angles[..., 1], axes[1])  # [...,3,3]\n    rot0 = angle_to_rotation_matrix(euler_angles[..., 0], axes[0])  # [...,3,3]\n    rot = rot2 @ rot1 @ rot0  # [...,3,3]\n    return rot.transpose(-2, -1)\n\n\ndef sample_on_spherical_cap(anchor, N, max_angle):\n    \"\"\"Sample N points on the view hemisphere within max_angle of the anchor.\n    Args:\n        anchor (tensor [...,3]): Reference 3-D unit vector on the view hemisphere.\n        N (int): Number of sampled points.\n        max_angle (float): Maximum angle between a sampled point and the anchor.\n    Returns:\n        sampled_points (tensor [...,N,3]): Sampled points on the spherical cap.\n    \"\"\"\n    batch_shape = anchor.shape[:-1]\n    # First, sample uniformly on a unit 2D disk.\n    radius = torch.rand(*batch_shape, N, 
device=anchor.device)  # [...,N]\n    theta = torch.rand(*batch_shape, N, device=anchor.device) * 2 * np.pi  # [...,N]\n    x = radius.sqrt() * theta.cos()  # [...,N]\n    y = radius.sqrt() * theta.sin()  # [...,N]\n    # Reparametrize to a unit spherical cap with height h.\n    # http://marc-b-reynolds.github.io/distribution/2016/11/28/Uniform.html\n    h = 1 - np.cos(max_angle)  # spherical cap height\n    k = h * radius  # [...,N]\n    s = (h * (2 - k)).sqrt()  # [...,N]\n    points = torch.stack([s * x, s * y, 1 - k], dim=-1)  # [...,N,3]\n    # Transform to center around the anchor.\n    ref_z = torch.tensor([0., 0., 1.], device=anchor.device)\n    v = -anchor.cross(ref_z)  # [...,3]\n    ss_v = lie.skew_symmetric(v)  # [...,3,3]\n    R = torch.eye(3, device=anchor.device) + ss_v + ss_v @ ss_v / (1 + anchor @ ref_z)[..., None, None]  # [...,3,3]\n    points = points @ R.transpose(-2, -1)  # [...,N,3]\n    return points\n\n\ndef sample_on_spherical_cap_northern(anchor, N, max_angle, away_from=None, max_reject_count=None):\n    \"\"\"Sample n points only the northern view hemisphere within the angle to x.\"\"\"\n\n    def find_invalid_points(points):\n        southern = points[..., 2] < 0  # [...,N]\n        if away_from is not None:\n            cosine_ab = (away_from * anchor).sum(dim=-1, keepdim=True)  # [...,1]\n            cosine_ac = (away_from[..., None, :] * points).sum(dim=-1)  # [...,N]\n            not_outwards = cosine_ab < cosine_ac  # [...,N]\n            invalid = southern | not_outwards\n        else:\n            invalid = southern\n        return invalid\n\n    assert (anchor[..., 2] > 0).all()\n    assert anchor.norm(dim=-1).allclose(torch.ones_like(anchor[..., 0]))\n    points = sample_on_spherical_cap(anchor, N, max_angle)  # [...,N,3]\n    invalid = find_invalid_points(points)\n    count = 0\n    while invalid.any():\n        # Reject and resample.\n        points_resample = sample_on_spherical_cap(anchor, N, max_angle)\n        
points[invalid] = points_resample[invalid]\n        invalid = find_invalid_points(points)\n        count += 1\n        if max_reject_count and count > max_reject_count:\n            points = anchor.repeat(N, 1)\n            break\n    return points\n"
  },
  {
    "path": "tools/camera_utils.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport math\nimport torch\nimport numpy as np\nfrom tqdm import tqdm\nfrom scipy.spatial.transform import Rotation as R\n\ntry:\n    from scene.cameras import Camera\nexcept ImportError:\n    pass\nfrom tools.general_utils import PILtoTorch, NumpytoTorch\nfrom tools.graphics_utils import fov2focal\nfrom tools.math_utils import inv_normalize_pts\nfrom scene.cameras import SampleCam\n\nWARNED = False\n\n\ndef loadCam(args, id, cam_info, resolution_scale):\n    orig_w, orig_h = cam_info.image.size\n\n    if args.resolution in [1, 2, 4, 8]:\n        resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution))\n    else:  # should be a type that converts to float\n        if args.resolution == -1:\n            if orig_w > 1600:\n                global WARNED\n                if not WARNED:\n                    print(\"[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\\n \"\n                        \"If this is not desired, please explicitly specify '--resolution/-r' as 1\")\n                    WARNED = True\n                global_down = orig_w / 1600\n            else:\n                global_down = 1\n        else:\n            global_down = orig_w / args.resolution\n\n        scale = float(global_down) * float(resolution_scale)\n        resolution = (int(orig_w / scale), int(orig_h / scale))\n\n    resized_image_rgb = PILtoTorch(cam_info.image, resolution) / 255.\n\n    gt_image = resized_image_rgb[:3, ...]\n    loaded_mask = None\n\n    if resized_image_rgb.shape[0] == 4:\n        loaded_mask = resized_image_rgb[3:4, ...]\n    \n    depth = None\n    if cam_info.depth 
is not None:\n        size = list(resolution)[::-1]\n        depth = NumpytoTorch(cam_info.depth, size)\n    normal = None\n    if cam_info.normal is not None:\n        size = list(resolution)[::-1]\n        normal = NumpytoTorch(cam_info.normal, size).permute(1, 2, 0)   # H, W, 3\n    mask = None\n    if cam_info.mask is not None:\n        mask = PILtoTorch(cam_info.mask, resolution).squeeze(0)\n        if mask.dim() == 3: mask = mask[0]\n\n    return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, \n                  FoVx=cam_info.FovX, FoVy=cam_info.FovY, \n                  image=gt_image, gt_alpha_mask=loaded_mask,\n                  image_name=cam_info.image_name, uid=id, data_device=args.data_device, depth=depth, normal=normal, mask=mask)\n\n\ndef cameraList_from_camInfos(cam_infos, resolution_scale, args):\n    camera_list = []\n\n    for id, c in tqdm(enumerate(cam_infos), total=len(cam_infos), desc=\"Processing data\", leave=False):\n        camera_list.append(loadCam(args, id, c, resolution_scale))\n\n    return camera_list\n\n\ndef camera_to_JSON(id, camera):\n    Rt = np.zeros((4, 4))\n    Rt[:3, :3] = camera.R.transpose()\n    Rt[:3, 3] = camera.T\n    Rt[3, 3] = 1.0\n\n    W2C = np.linalg.inv(Rt)\n    pos = W2C[:3, 3]\n    rot = W2C[:3, :3]\n    serializable_array_2d = [x.tolist() for x in rot]\n    camera_entry = {\n        'id' : id,\n        'img_name' : camera.image_name,\n        'width' : camera.width,\n        'height' : camera.height,\n        'position': pos.tolist(),\n        'rotation': serializable_array_2d,\n        'fy' : fov2focal(camera.FovY, camera.height),\n        'fx' : fov2focal(camera.FovX, camera.width)\n    }\n    return camera_entry\n\n\ndef find_up_axis(R):\n    '''\n    R: world to bounding box coordinate system\n    '''\n    \n    up_vector = torch.tensor([0, -1, 0], dtype=torch.float32, device=R.device) # world colmap\n    up_vector = R @ up_vector                                      # bounding box coordinate 
system\n    up_axis = torch.argmax(torch.abs(up_vector))\n    up_sign = torch.sign(up_vector[up_axis])\n    \n    return up_axis, up_sign\n\n\ndef find_axis(R, axis_name='up'):\n    '''\n    colmap coordinate system\n    R: world to bounding box coordinate system\n    '''\n    if axis_name == 'up':\n        axis_w=[0, -1, 0]\n    elif axis_name == 'front':\n        axis_w=[0, 0, 1]\n    elif axis_name == 'right':\n        axis_w=[1, 0, 0]\n    else:\n        raise ValueError(f'axis_name: \"{axis_name}\" should be one of [up, front, right]')\n    axis_w = torch.tensor(axis_w, dtype=torch.float32, device=R.device) # world colmap\n    axis_c = R @ axis_w                                      # bounding box coordinate system\n    axis = torch.argmax(torch.abs(axis_c))\n    sign = torch.sign(axis_c[axis])\n    \n    return axis, sign\n\ndef dot(x, y):\n    if isinstance(x, np.ndarray):\n        return np.sum(x * y, -1, keepdims=True)\n    else:\n        return torch.sum(x * y, -1, keepdim=True)\n\n\ndef length(x, eps=1e-20):\n    if isinstance(x, np.ndarray):\n        return np.sqrt(np.maximum(np.sum(x * x, axis=-1, keepdims=True), eps))\n    else:\n        return torch.sqrt(torch.clamp(dot(x, x), min=eps))\n\n\ndef safe_normalize(x, eps=1e-20):\n    return x / length(x, eps)\n\n\ndef look_at_np(campos, target, opengl=True):\n    # campos: [N, 3], camera/eye position\n    # target: [N, 3], object to look at\n    # return: [N, 3, 3], rotation matrix\n    if not opengl:\n        # camera forward aligns with -z colmap\n        forward_vector = safe_normalize(target - campos)\n        up_vector = np.array([0, 1, 0], dtype=np.float32)\n        right_vector = safe_normalize(np.cross(forward_vector, up_vector)) # z x up\n        up_vector = safe_normalize(np.cross(right_vector, forward_vector))\n    else:\n        # camera forward aligns with +z\n        forward_vector = safe_normalize(campos - target)\n        up_vector = np.array([0, 1, 0], dtype=np.float32)\n        
right_vector = safe_normalize(np.cross(up_vector, forward_vector)) # up x z\n        up_vector = safe_normalize(np.cross(forward_vector, right_vector))\n    R = np.stack([right_vector, up_vector, forward_vector], axis=1) # axis=1 !!!!! stacks the vectors as rows (w2c)\n    return R\n\n\ndef look_at(campos, target, opengl=True):\n    # campos: [N, 3], camera/eye position\n    # target: [N, 3], object to look at\n    # return: [N, 3, 3], rotation matrix\n    up_vector = torch.tensor([0, 1, 0], dtype=torch.float32, device=campos.device)\n    if campos.dim() == 2: up_vector = up_vector[None, :]\n    # dim=-1 is passed explicitly: without it torch.cross picks the first dimension of size 3,\n    # which is the batch dimension when N == 3\n    if not opengl:\n        # camera forward aligns with -z colmap\n        forward_vector = safe_normalize(target - campos)\n        right_vector = safe_normalize(torch.cross(forward_vector, up_vector, dim=-1)) # z x up\n        up_vector = safe_normalize(torch.cross(right_vector, forward_vector, dim=-1))\n    else:\n        # camera forward aligns with +z\n        forward_vector = safe_normalize(campos - target)\n        right_vector = safe_normalize(torch.cross(up_vector, forward_vector, dim=-1)) # up x z\n        up_vector = safe_normalize(torch.cross(forward_vector, right_vector, dim=-1))\n    R = torch.stack([right_vector, up_vector, forward_vector], dim=1) # axis=1 !!!!! 
stacks the vectors as rows (w2c)\n    return R\n\n\n# elevation & azimuth to pose (cam2world) matrix\ndef orbit_camera(elevation, azimuth, radius=1, is_degree=True, target=None, opengl=True):\n    # radius: scalar\n    # elevation: scalar, in (-90, 90), from +y to -y is (-90, 90)\n    # azimuth: scalar, in (-180, 180), from +z to +x is (0, 90)\n    # return: [4, 4], camera pose matrix\n    if is_degree:\n        elevation = np.deg2rad(elevation)\n        azimuth = np.deg2rad(azimuth)\n    x = radius * np.cos(elevation) * np.sin(azimuth)\n    y = - radius * np.sin(elevation)\n    z = radius * np.cos(elevation) * np.cos(azimuth)\n    if target is None:\n        target = np.zeros([3], dtype=np.float32)\n    campos = np.array([x, y, z]) + target  # [3]\n    T = np.eye(4, dtype=np.float32)\n    T[:3, :3] = look_at_np(campos, target, opengl)                             # ??? should be look_at(campos, target, opengl).transpose(0, 2, 1)\n    T[:3, 3] = campos\n    return T\n\n\ndef cubic_camera(n, trans, scale, target=None, opengl=False):\n    xyz = np.random.rand(n, 3) * 2 - 1\n    for i in range(3): xyz[i::3, i] = xyz[i::3, i] / np.abs(xyz[i::3, i]) # Unit cube\n    \n    if target is None: target = np.zeros([1, 3], dtype=np.float32)\n    \n    xyz = inv_normalize_pts(xyz, trans, scale)\n    target = inv_normalize_pts(target, trans, scale)\n    \n    T = np.zeros((n, 4, 4))\n    up_vector = [1, 0, 0]\n    # FIXME: look_at() takes (campos, target, opengl) and expects torch tensors;\n    # passing up_vector and NumPy arrays makes this call fail as written.\n    T[:, :3, :3] = look_at(xyz, target, opengl, up_vector) # c2w\n    T[:, :3, 3] = xyz\n    T[:, 3, 3] = 1\n    \n    T = np.linalg.inv(T) # w2c\n    \n    return T\n\n\ndef check_tensor(x):\n    if isinstance(x, np.ndarray):\n        return torch.from_numpy(x).to(torch.float32)\n    else: return x\n\n\ndef up_camera(n, trans, scale, target=None, opengl=False): # colmap\n    trans = check_tensor(trans)\n    scale = check_tensor(scale)\n    device = trans.device\n    \n    up_axis, up_sign = find_up_axis(trans[:3, :3])\n    v_axis = [i for i in [0, 1, 2] if i != up_axis]\n    \n    xyz = 
torch.rand(n, 3).to(device) * 2 - 1\n    xyz[:, up_axis] = up_sign # up\n    \n    # default to the origin; convert a user-supplied NumPy target to a tensor\n    if target is None:\n        target = torch.zeros([1, 3], dtype=torch.float32, device=device)\n    else:\n        target = check_tensor(target)\n    \n    target[:, up_axis] = 1 * -up_sign # 5\n    \n    xyz = inv_normalize_pts(xyz, trans, scale)\n    target = inv_normalize_pts(target, trans, scale)\n    \n    T = torch.zeros((xyz.shape[0], 4, 4), device=device)      # w2c\n    R = look_at(xyz, target, opengl) # w2c\n    T[:, :3, :3] = R\n    T[:, :3, 3] = - (R @ xyz[..., None]).squeeze(-1) # w2c\n    T[:, 3, 3] = 1\n    \n    return T\n\n\ndef around_camera(n, trans, scale, height=None, target=None, opengl=False):\n    trans = check_tensor(trans)\n    scale = check_tensor(scale)\n    \n    device = trans.device\n    grid_points = torch.Tensor([\n        [-1, -1, -1],\n        [1, 1, 1],\n    ]).to(device)\n    \n    up_axis, up_sign = find_up_axis(trans[:3, :3])\n    v_axis = [i for i in [0, 1, 2] if i != up_axis]\n    \n    xyz = torch.rand(n, 3).to(device) * 2 - 1\n    for i in v_axis: xyz[i-1::2, i] = xyz[i-1::2, i] / torch.abs(xyz[i-1::2, i])\n    \n    # default to the origin; convert a user-supplied NumPy target to a tensor\n    if target is None:\n        target = torch.zeros([1, 3], dtype=torch.float32, device=device)\n    else:\n        target = check_tensor(target)\n    \n    xyz = inv_normalize_pts(xyz, trans, scale)\n    target = inv_normalize_pts(target, trans, scale)\n    grid_points = inv_normalize_pts(grid_points, trans, scale)\n    \n    if height is None: height = target[0, 1]\n    \n    xyz[:, 1] = height\n    \n    T = torch.zeros((xyz.shape[0], 4, 4), device=device)      # w2c\n    R = look_at(xyz, target, opengl) # w2c\n    T[:, :3, :3] = R\n    T[:, :3, 3] = - (R @ xyz[..., None]).squeeze(-1) # w2c\n    T[:, 3, 3] = 1\n    \n    return T\n\n\ndef bb_camera(n, trans, scale, height=None, target=None, opengl=False, up=True, around=True, look_mode='target', sample_mode='grid', boundary=0.9, bidirect=False): # colmap 0.8\n    trans = check_tensor(trans)\n    scale 
= check_tensor(scale)\n    device = trans.device\n    \n    if scale.ndim == 0: scale = torch.ones(3, dtype=torch.float32, device=device) * scale\n    rot = trans[:3, :3] if trans.ndim == 2 else torch.eye(3, device=device)\n    up_axis, up_sign = find_axis(rot, axis_name='up')\n    if sample_mode == 'grid' or (up and around):\n        right_axis, right_sign = find_axis(rot, axis_name='right')\n        front_axis, front_sign = find_axis(rot, axis_name='front')\n    v_axis = [i for i in [0, 1, 2] if i != up_axis]\n    \n    up_n = around_n = n\n    if up and around:\n        h = scale[up_axis]\n        l = scale[right_axis]\n        w = scale[front_axis]\n        \n        around_area = 2 * (l * h + h * w)\n        up_area = l * w\n        total_area = around_area + up_area\n        \n        up_n = int((n * up_area / total_area) * 1)\n    \n    xyz = []\n    if target is None:\n        if look_mode == 'target':\n            target = torch.zeros([1, 3], dtype=torch.float32, device=device)\n            target[:, up_axis] = 1 * -up_sign # 5\n        else:\n            target = []\n    else:\n        target = check_tensor(target)\n    \n    if up:\n        if sample_mode == 'random':\n            xyz_up = torch.rand(up_n, 3).to(device) * 2 - 1\n        elif sample_mode == 'grid':\n            xyz_up = up_grid_posi(up_n, scale, right_axis, up_axis, front_axis).to(device)\n            around_n = n - up_n\n        xyz_up[:, up_axis] = up_sign\n        xyz.append(xyz_up)\n        \n        if look_mode == 'direction':\n            tgt_up = xyz_up.clone()\n            tgt_up[:, up_axis] *= -1\n            target.append(tgt_up)\n    \n    if around:\n        if sample_mode == 'random':\n            xyz_around = torch.rand(around_n, 3).to(device) * 2 - 1\n        elif sample_mode == 'grid':\n            if not bidirect:\n                xyz_around = around_grid_posi(around_n, scale, right_axis, up_axis, front_axis, up_sign=up_sign).to(device)\n            else:\n               
 n1 = around_n // 2\n                xyz1 = around_grid_posi(n1, scale, right_axis, up_axis, front_axis, sign=1, up_sign=up_sign).to(device)\n                n2 = around_n - xyz1.shape[0]\n                xyz2 = around_grid_posi(n2, scale, right_axis, up_axis, front_axis, sign=-1, up_sign=up_sign).to(device)\n                xyz_around = torch.cat([xyz1, xyz2], 0)\n                n_trg = xyz_up.shape[0] + xyz_around.shape[0] if up else xyz_around.shape[0]\n                target = target.repeat(n_trg, 1)\n                target[-xyz2.shape[0]:, up_axis] *= -1\n                \n        xyz_around[:, up_axis] = xyz_around[:, up_axis] * boundary + (1 - boundary) * up_sign\n        xyz.append(xyz_around)\n        \n        if look_mode == 'direction':\n            trg_around = xyz_around.clone()\n            for i in v_axis: trg_around[i-1::2, i] *= -1\n            target.append(trg_around)\n    \n    xyz = torch.cat(xyz, 0)\n    if look_mode == 'direction':\n        target = torch.cat(target, 0)\n    \n    xyz = inv_normalize_pts(xyz, trans, scale)\n    target = inv_normalize_pts(target, trans, scale)\n    \n    T = torch.zeros((xyz.shape[0], 4, 4), device=device)      # w2c\n    R = look_at(xyz, target, opengl) # w2c\n    T[:, :3, :3] = R\n    T[:, :3, 3] = - (R @ xyz[..., None]).squeeze(-1) # w2c\n    T[:, 3, 3] = 1\n    \n    return T\n\n\ndef around_grid_posi(num_points, scale, right_axis, up_axis, front_axis, sign=1, up_sign=1):\n    device = scale.device\n    indexing = 'xy'\n    h = scale[up_axis]\n    l = scale[right_axis]\n    w = scale[front_axis]\n    \n    total_area = 2 * (l * h + h * w)\n    ratio = (num_points / total_area).sqrt()\n    h_points = torch.round(h * ratio).int()\n    l_points = torch.round(l * ratio).int()\n    w_points = torch.round(w * ratio).int()\n    \n    total_points = []\n    h_coord = torch.arange(start=-1, end=1, step=2 / h_points, device=device) * up_sign\n    \n    step = 2 / l_points\n    st = -1 if sign == 1 else -1 + step\n 
   l_coord = torch.arange(start=st, end=1, step=step, device=device) # * sign\n    grid_l, grid_h = torch.meshgrid([l_coord, h_coord], indexing=indexing)\n    lh = torch.stack([grid_l.flatten(), grid_h.flatten()], dim=1)\n    points = torch.ones([lh.shape[0], 3], dtype=torch.float32, device=device) * 1\n    points[:, [right_axis, up_axis]] = lh\n    total_points.append(points)\n    \n    # back\n    step = - 2 / l_points\n    st = 1 if sign == 1 else 1 + step\n    l_coord = torch.arange(start=st, end=-1, step=step, device=device) # * sign\n    grid_l, grid_h = torch.meshgrid([l_coord, h_coord], indexing=indexing)\n    lh = torch.stack([grid_l.flatten(), grid_h.flatten()], dim=1)\n    points = torch.ones([lh.shape[0], 3], dtype=torch.float32, device=device) * -1\n    points[:, [right_axis, up_axis]] = lh\n    total_points.append(points)\n    \n    # right\n    step = - 2 / w_points\n    st = 1 if sign == 1 else 1 + step\n    w_coord = torch.arange(start=st, end=-1, step=step, device=device)\n    grid_h, grid_w = torch.meshgrid([h_coord, w_coord], indexing=indexing)\n    hw = torch.stack([grid_h.flatten(), grid_w.flatten()], dim=1)\n    points = torch.ones([hw.shape[0], 3], dtype=torch.float32, device=device) * 1\n    points[:, [up_axis, front_axis]] = hw\n    total_points.append(points)\n    \n    # left\n    step = 2 / w_points\n    st = -1 if sign == 1 else -1 + step\n    w_coord = torch.arange(start=st, end=1, step = step, device=device)\n    grid_h, grid_w = torch.meshgrid([h_coord, w_coord], indexing=indexing)\n    hw = torch.stack([grid_h.flatten(), grid_w.flatten()], dim=1)\n    points = torch.ones([hw.shape[0], 3], dtype=torch.float32, device=device) * -1\n    points[:, [up_axis, front_axis]] = hw\n    total_points.append(points)\n    \n    points = torch.cat(total_points, 0)\n    return points\n\n\ndef up_grid_posi(num_points, scale, right_axis, up_axis, front_axis):\n    h = scale[up_axis]\n    l = scale[right_axis]\n    w = scale[front_axis]\n    \n    
total_area = l * w\n    ratio = math.sqrt(num_points / total_area)\n    l_points = torch.round(l * ratio).int()\n    w_points = torch.round(w * ratio).int()\n    \n    # up\n    l_coord = torch.linspace(start=-1, end=1, steps=l_points) # * 0.9\n    w_coord = torch.linspace(start=-1, end=1, steps=w_points) # * 0.9\n    grid_l, grid_w = torch.meshgrid([l_coord, w_coord], indexing='xy')\n    lw = torch.stack([grid_l.flatten(), grid_w.flatten()], dim=1)\n    points = torch.ones([lw.shape[0], 3], dtype=torch.float32) * 1\n    points[:, [right_axis, front_axis]] = lw\n    \n    return points\n\n\ndef grid_camera(trans, scale, opengl=False, target=None):\n    # target was referenced below but never defined; it is now an optional keyword argument\n    trans = check_tensor(trans)\n    scale = check_tensor(scale)\n    device = trans.device\n    \n    xyz = torch.tensor(\n        [\n            [-1, -1, -1],\n            [1, 1, 1],\n            [-1, 1, 1],\n            [1, -1, -1],\n            [-1, 1, -1],\n            [1, -1, 1],\n            [1, 1, -1],\n            [-1, -1, 1],\n        ], dtype=torch.float32, device=device\n    )\n    \n    \n    if target is None:\n        target = torch.zeros([1, 3], dtype=torch.float32, device=device)\n    else:\n        target = check_tensor(target)\n    \n    xyz = inv_normalize_pts(xyz, trans, scale)\n    target = inv_normalize_pts(target, trans, scale)\n    \n    T = torch.zeros((xyz.shape[0], 4, 4), device=device)      # w2c\n    R = look_at(xyz, target, opengl) # w2c\n    T[:, :3, :3] = R\n    T[:, :3, 3] = - (R @ xyz[..., None]).squeeze(-1) # w2c\n    T[:, 3, 3] = 1\n    \n    return T\n\n\ndef sample_cameras(model, n, up=False, around=True, look_mode='target', sample_mode='grid', bidirect=True):\n    cam_height = None\n    w2cs = bb_camera(n, model.trans, model.scale, cam_height, up=up, around=around, \\\n        look_mode=look_mode, sample_mode=sample_mode, bidirect=bidirect)\n    # traincam = self.scene.getTrainCameras()[0]\n    # FoVx = traincam.FoVx        # 1.3990553440909452\n    # FoVy = traincam.FoVy        # 0.8764846384037163\n    
# width = traincam.image_width    # 1500\n    # height = traincam.image_height  # 835\n    FoVx = FoVy = 2.5 # 3.14 / 2\n    width = height = 1500\n    cams = []\n    \n    for i in range(w2cs.shape[0]):\n        w2c = w2cs[i]\n        cam = SampleCam(w2c, width, height, FoVx, FoVy)\n        cams.append(cam)\n    \n    return cams\n\n\nclass OrbitCamera:\n    def __init__(self, W, H, r=2, fovy=60, near=0.01, far=100):\n        self.W = W\n        self.H = H\n        self.radius = r  # camera distance from center\n        self.fovy = np.deg2rad(fovy)  # deg 2 rad\n        self.near = near\n        self.far = far\n        self.center = np.array([0, 0, 0], dtype=np.float32)  # look at this point\n        self.rot = R.from_matrix(np.eye(3))\n        self.up = np.array([0, 1, 0], dtype=np.float32)  # need to be normalized!\n\n    @property\n    def fovx(self):\n        return 2 * np.arctan(np.tan(self.fovy / 2) * self.W / self.H)\n\n    @property\n    def campos(self):\n        return self.pose[:3, 3]\n\n    # pose (c2w)\n    @property\n    def pose(self):\n        # first move camera to radius\n        res = np.eye(4, dtype=np.float32)\n        res[2, 3] = self.radius  # opengl convention...\n        # rotate\n        rot = np.eye(4, dtype=np.float32)\n        rot[:3, :3] = self.rot.as_matrix()\n        res = rot @ res\n        # translate\n        res[:3, 3] -= self.center\n        return res\n\n    # view (w2c)\n    @property\n    def view(self):\n        return np.linalg.inv(self.pose)\n\n    # projection (perspective)\n    @property\n    def perspective(self):\n        y = np.tan(self.fovy / 2)\n        aspect = self.W / self.H\n        return np.array(\n            [\n                [1 / (y * aspect), 0, 0, 0],\n                [0, -1 / y, 0, 0],\n                [\n                    0,\n                    0,\n                    -(self.far + self.near) / (self.far - self.near),\n                    -(2 * self.far * self.near) / (self.far - self.near),\n       
         ],\n                [0, 0, -1, 0],\n            ],\n            dtype=np.float32,\n        )\n\n    # intrinsics\n    @property\n    def intrinsics(self):\n        focal = self.H / (2 * np.tan(self.fovy / 2))\n        return np.array([focal, focal, self.W // 2, self.H // 2], dtype=np.float32)\n\n    @property\n    def mvp(self):\n        return self.perspective @ np.linalg.inv(self.pose)  # [4, 4]\n\n    def orbit(self, dx, dy):\n        # rotate along camera up/side axis!\n        side = self.rot.as_matrix()[:3, 0]\n        rotvec_x = self.up * np.radians(-0.05 * dx)\n        rotvec_y = side * np.radians(-0.05 * dy)\n        self.rot = R.from_rotvec(rotvec_x) * R.from_rotvec(rotvec_y) * self.rot\n\n    def scale(self, delta):\n        self.radius *= 1.1 ** (-delta)\n\n    def pan(self, dx, dy, dz=0):\n        # pan in camera coordinate system (careful on the sensitivity!)\n        self.center += 0.0005 * self.rot.as_matrix()[:3, :3] @ np.array([-dx, -dy, dz])\n\n\n"
  },
  {
    "path": "tools/crop_mesh.py",
    "content": "import os\nimport argparse\nimport numpy as np\nimport trimesh\n\n\ndef align_gt_with_cam(pts, trans):\n    trans_inv = np.linalg.inv(trans)\n    pts_aligned = pts @ trans_inv[:3, :3].transpose(-1, -2) + trans_inv[:3, -1]\n    return pts_aligned\n\n\ndef filter_largest_cc(mesh):\n    components = mesh.split(only_watertight=False)\n    areas = np.array([c.area for c in components], dtype=float)\n    if len(areas) > 0 and mesh.vertices.shape[0] > 0:\n        new_mesh = components[areas.argmax()]\n    else:\n        new_mesh = trimesh.Trimesh()\n    return new_mesh\n\n\ndef main(args):\n    assert os.path.exists(args.ply_path), f\"PLY file {args.ply_path} does not exist.\"\n    gt_trans = np.loadtxt(args.align_path)\n    \n    mesh_rec = trimesh.load(args.ply_path, process=False)\n    mesh_gt = trimesh.load(args.gt_path, process=False)\n    \n    mesh_gt.vertices = align_gt_with_cam(mesh_gt.vertices, gt_trans)\n    \n    to_align, _ = trimesh.bounds.oriented_bounds(mesh_gt)\n    mesh_gt.vertices = (to_align[:3, :3] @ mesh_gt.vertices.T + to_align[:3, 3:]).T\n    mesh_rec.vertices = (to_align[:3, :3] @ mesh_rec.vertices.T + to_align[:3, 3:]).T\n    \n    min_points = mesh_gt.vertices.min(axis=0)\n    max_points = mesh_gt.vertices.max(axis=0)\n\n    mask_min = (mesh_rec.vertices - min_points[None]) > 0\n    mask_max = (mesh_rec.vertices - max_points[None]) < 0\n\n    mask = np.concatenate((mask_min, mask_max), axis=1).all(axis=1)\n    face_mask = mask[mesh_rec.faces].all(axis=1)\n\n    mesh_rec.update_vertices(mask)\n    mesh_rec.update_faces(face_mask)\n    \n    mesh_rec.vertices = (to_align[:3, :3].T @ mesh_rec.vertices.T - to_align[:3, :3].T @ to_align[:3, 3:]).T\n    mesh_gt.vertices = (to_align[:3, :3].T @ mesh_gt.vertices.T - to_align[:3, :3].T @ to_align[:3, 3:]).T\n    \n    \n    # save mesh_rec and mesh_rec in args.out_path\n    mesh_rec.export(args.out_path)\n    \n    # downsample mesh_gt\n    \n    idx = 
np.random.choice(np.arange(len(mesh_gt.vertices)), 5000000)\n    mesh_gt.vertices = mesh_gt.vertices[idx]\n    mesh_gt.colors = mesh_gt.colors[idx]\n    \n    mesh_gt.export(args.gt_path.replace('.ply', '_trans.ply'))\n    \n    \n    return\n\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--gt_path\",\n        type=str,\n        default='/your/path//Barn_GT.ply',\n        help=\"path to the ground-truth ply file\",\n    )\n    parser.add_argument(\n        \"--align_path\",\n        type=str,\n        default='/your/path//Barn_trans.txt',\n        help=\"path to the text file with the alignment transform (read with np.loadtxt)\",\n    )\n    parser.add_argument(\n        \"--ply_path\",\n        type=str,\n        default='/your/path//Barn_lowres.ply',\n        help=\"path to reconstruction ply file\",\n    )\n    parser.add_argument(\n        \"--scene\",\n        type=str,\n        default='Barn',\n        help=\"scene name\",\n    )\n    parser.add_argument(\n        \"--out_path\",\n        type=str,\n        default='/your/path//Barn_lowres_crop.ply',\n        help=\n        \"path of the cropped reconstruction ply file to write\",\n    )\n    args = parser.parse_args()\n    \n    main(args)"
  },
  {
    "path": "tools/denoise_pcd.py",
    "content": "from pytorch3d.ops import ball_query, knn_points\n\n\ndef remove_radius_outlier(xyz, nb_points=5, radius=0.1):\n    # A point is kept when ball_query finds nb_points+1 points (itself included) within radius;\n    # missing neighbors are padded with index -1.\n    if xyz.dim() == 2: xyz = xyz[None]\n    nn_dists, nn_idx, nn = ball_query(xyz, xyz, K=nb_points+1, radius=radius)\n    valid = ~(nn_idx[0]==-1).any(-1)\n    \n    return valid\n\n\ndef remove_statistical_outlier(xyz, nb_points=20, std_ratio=20.):\n    if xyz.dim() == 2: xyz = xyz[None]\n    nn_dists, nn_idx, nn = knn_points(xyz, xyz, K=nb_points, return_sorted=False)\n    \n    # Squared distances to the nearest neighbors (knn_points returns squared distances)\n    distances = nn_dists.squeeze(0)  # Shape: (N, nb_points)\n\n    # Compute mean and standard deviation of distances\n    mean_distances = distances.mean(dim=-1)\n    std_distances = distances.std(dim=-1)\n\n    # Identify points that are not outliers\n    threshold = mean_distances + std_ratio * std_distances\n    valid = (distances <= threshold.unsqueeze(1)).any(dim=1)\n    \n    return valid\n\n\nif __name__ == '__main__':\n    import torch\n    import time\n    \n    gpu = 0\n    device = torch.device('cuda:{:d}'.format(gpu) if torch.cuda.is_available() else 'cpu')\n    t1 = time.time()\n    xyz = torch.rand(int(1e7), 3).to(device)\n    remove_statistical_outlier(xyz)\n    print('time:', time.time()-t1, 's')\n\n"
  },
  {
    "path": "tools/depth2mesh.py",
    "content": "import os\nimport sys\nimport math\nimport torch\nimport argparse\nimport numpy as np\nimport open3d as o3d\nimport open3d.core as o3c\n\nsys.path.append(os.getcwd())\nfrom configs.config import Config\nfrom gaussian_renderer import render\nfrom scene import Scene, GaussianModel\nfrom tools.semantic_id import BACKGROUND\nfrom tools.graphics_utils import depth2point\nfrom tools.general_utils import set_random_seed\nfrom tools.math_utils import get_inside_normalized\nfrom tools.mesh_utils import GaussianExtractor, post_process_mesh\n\n\n@torch.no_grad()\ndef tsdf_fusion(args, cfg, model, cameras, dirs, bg, outdir, mesh_name='fused_mesh.ply', max_depth=5.0):\n    o3d_device = o3d.core.Device(\"CUDA:0\")\n    \n    vbg = o3d.t.geometry.VoxelBlockGrid(\n            attr_names=('tsdf', 'weight', 'color'),\n            attr_dtypes=(o3c.float32, o3c.float32, o3c.float32),\n            attr_channels=((1), (1), (3)),\n            voxel_size=args.voxel_size,\n            block_resolution=16,\n            block_count=60000,\n            device=o3d_device)\n    \n    with torch.no_grad():\n        for _, view in enumerate(cameras):\n            \n            render_pkg = render(view, model, cfg, bg, dirs=dirs)\n            if args.depth_mode == 'mean':\n                depth = render_pkg[\"depth\"]\n            elif args.depth_mode == 'median':\n                depth = render_pkg[\"median_depth\"]\n            rgb = render_pkg[\"render\"]\n            alpha = render_pkg[\"alpha\"]\n            \n            if view.gt_alpha_mask is not None:\n                depth[(view.gt_alpha_mask < 0.5)] = 0\n            \n            depth[(alpha < args.alpha_thres)] = 0\n            \n            rendered_pcd_world = depth2point(depth[0], view.intr, view.world_view_transform.transpose(0, 1))[1]\n            inside = get_inside_normalized(rendered_pcd_world.view(-1, 3), model.trans, model.scale)[0]\n            depth.view(-1)[~inside] = 0\n            \n            if 
'render_sem' in render_pkg:\n                semantic = render_pkg[\"render_sem\"]\n                prob = model.logits2prob(semantic)\n                mask = (prob[..., BACKGROUND] > args.prob_thres)[None]\n                depth[mask] = 0\n            \n            intrinsic=o3d.camera.PinholeCameraIntrinsic(width=view.image_width, \n                    height=view.image_height, \n                    cx = view.image_width/2,\n                    cy = view.image_height/2,\n                    fx = view.image_width / (2 * math.tan(view.FoVx / 2.)),\n                    fy = view.image_height / (2 * math.tan(view.FoVy / 2.)))\n            extrinsic = np.asarray((view.world_view_transform.T).cpu().numpy())\n            \n            rgb = rgb.clamp(0, 1)\n            o3d_color = o3d.t.geometry.Image(np.asarray(rgb.permute(1,2,0).cpu().numpy(), order=\"C\"))\n            o3d_depth = o3d.t.geometry.Image(np.asarray(depth.permute(1,2,0).cpu().numpy(), order=\"C\"))\n            o3d_color = o3d_color.to(o3d_device)\n            o3d_depth = o3d_depth.to(o3d_device)\n\n            intrinsic = o3d.core.Tensor(intrinsic.intrinsic_matrix, o3d.core.Dtype.Float64)#.to(o3d_device)\n            extrinsic = o3d.core.Tensor(extrinsic, o3d.core.Dtype.Float64)#.to(o3d_device)\n            \n            frustum_block_coords = vbg.compute_unique_block_coordinates(\n                o3d_depth, intrinsic, extrinsic, 1.0, max_depth)\n\n            vbg.integrate(frustum_block_coords, o3d_depth, o3d_color, intrinsic,\n                          intrinsic, extrinsic, 1.0, max_depth)\n        \n        mesh = vbg.extract_triangle_mesh().to_legacy()\n        \n        # write mesh\n        o3d.io.write_triangle_mesh(os.path.join(outdir, mesh_name), mesh)\n\n        # Clean Mesh\n        if args.clean:\n            import pymeshlab\n            ms = pymeshlab.MeshSet()\n            ms.load_new_mesh(os.path.join(outdir, mesh_name))\n            ms.meshing_remove_unreferenced_vertices()\n            
ms.meshing_remove_duplicate_faces()\n            ms.meshing_remove_null_faces()\n            ms.meshing_remove_connected_component_by_face_number(mincomponentsize=20000)\n            ms.save_current_mesh(os.path.join(outdir, mesh_name))\n        \n        with open(os.path.join(outdir, 'voxel_size.txt'), 'w') as f:\n            f.write(f'voxel_size: {args.voxel_size}')\n\n\ndef tsdf_cpu(args, cfg, model, cameras, dirs, bg, outdir, mesh_name='fused_mesh.ply', max_depth=5.0):\n    gaussExtractor = GaussianExtractor(model, render, cfg, bg_color=bg, dirs=dirs, prob_thres=args.prob_thres, alpha_thres=args.alpha_thres)\n    gaussExtractor.gaussians.active_sh_degree = 0\n    gaussExtractor.reconstruction(cameras)\n    # extract the mesh and save\n    if args.unbounded:\n        mesh = gaussExtractor.extract_mesh_unbounded(resolution=args.mesh_res)\n    else:\n        mesh = gaussExtractor.extract_mesh_bounded(voxel_size=args.voxel_size, sdf_trunc=5*args.voxel_size, depth_trunc=max_depth)\n    \n    o3d.io.write_triangle_mesh(os.path.join(outdir, mesh_name), mesh)\n    print(\"mesh saved at {}\".format(os.path.join(outdir, mesh_name)))\n    # post-process the mesh and save, saving the largest N clusters\n    mesh_post = post_process_mesh(mesh, cluster_to_keep=args.num_cluster)\n    o3d.io.write_triangle_mesh(os.path.join(outdir, mesh_name), mesh_post)\n    \n    \n    return\n\n\ndef main(args):\n    cfg = Config(args.cfg_path)\n    cfg.model.data_device = 'cpu'\n    cfg.model.load_normal = False\n    cfg.model.load_mask = False\n    args.voxel_size = cfg.model.mesh.voxel_size if args.voxel_size == 0 else args.voxel_size\n    \n    set_random_seed(cfg.seed)\n    \n    model = GaussianModel(cfg.model)\n    \n    scene = Scene(cfg.model, model, load_iteration=-1, shuffle=False)\n    model.trans = torch.from_numpy(scene.trans).cuda()\n    model.scale = torch.from_numpy(scene.scale).cuda() * 1.1\n    model.extent = scene.cameras_extent\n    cameras = 
scene.getTrainCameras().copy()[::args.split]\n    \n    model.training_setup(cfg.optim)\n    model.max_radii2D = torch.zeros((model.get_xyz.shape[0]), device=\"cuda\")\n    model.scale = torch.from_numpy(scene.scale).cuda()\n    \n    model.prune_outliers()\n    \n    bg_color = [1, 1, 1] if cfg.model.white_background else [0, 0, 0]\n    background = torch.tensor(bg_color, dtype=torch.float32, device=\"cuda\")\n    \n    print(f'Fusing into {args.mesh_name} vs: {args.voxel_size}...')\n    if args.method == 'tsdf':\n        dirs = scene.dirs\n        max_depth = (model.scale ** 2).sum().sqrt().item()\n        max_depth = args.max_depth\n        tsdf_fusion(args, cfg, model, cameras, dirs, background, cfg.logdir, args.mesh_name, max_depth)\n    elif args.method == 'tsdf_cpu':\n        dirs = scene.dirs\n        max_depth = args.max_depth\n        tsdf_cpu(args, cfg, model, cameras, dirs, background, cfg.logdir, args.mesh_name, max_depth)\n    \n    return\n\n\nif __name__ == '__main__':\n    parser = argparse.ArgumentParser()\n    parser.add_argument('--input', type=str, default='Barn')\n    parser.add_argument('--outdir', type=str, default=None)\n    parser.add_argument('--mesh_name', type=str, default='vcr_gaus.ply')\n    parser.add_argument('--scene', type=str, default='Barn')\n    parser.add_argument('--data_path', type=str, default='Barn')\n    parser.add_argument('--method', type=str, default='tsdf', choices=['tsdf', 'point2mesh', 'tsdf_cpu'])\n    parser.add_argument('--depth_mode', type=str, default='mean', choices=['mean', 'median'])\n    parser.add_argument('--rec_method', type=str, default='poisson', choices=['nksr', 'poisson'])\n    parser.add_argument('--split', type=int, default=3)\n    parser.add_argument('--resolution', type=float, default=1.0)\n    parser.add_argument('--detail_level', type=float, default=1.0)\n    parser.add_argument('--voxel_size', type=float, default=5e-3)\n    parser.add_argument('--sdf_trunc', type=float, default=0.08)\n    
parser.add_argument('--alpha_thres', type=float, default=0.5)\n    parser.add_argument('--prob_thres', type=float, default=0.15)\n    parser.add_argument('--mise_iter', type=int, default=1)\n    parser.add_argument('--depth', type=int, default=9)\n    parser.add_argument('--max_depth', type=float, default=6.0)\n    parser.add_argument('--est_normal', action='store_true')\n    parser.add_argument('--cfg_path', type=str, default='configs/config_base.yaml')\n    parser.add_argument('--clean', action='store_true', help='perform a clean operation')\n    parser.add_argument(\"--unbounded\", action=\"store_true\", help='Mesh: using unbounded mode for meshing')\n    parser.add_argument(\"--num_cluster\", default=1000, type=int, help='Mesh: number of connected clusters to export')\n    args = parser.parse_args()\n    \n    main(args)"
  },
  {
    "path": "tools/distributed.py",
    "content": "'''\n-----------------------------------------------------------------------------\nCopyright (c) 2023, NVIDIA CORPORATION. All rights reserved.\n\nNVIDIA CORPORATION and its licensors retain all intellectual property\nand proprietary rights in and to this software, related documentation\nand any modifications thereto. Any use, reproduction, disclosure or\ndistribution of this software and related documentation without an express\nlicense agreement from NVIDIA CORPORATION is strictly prohibited.\n-----------------------------------------------------------------------------\n'''\n\nimport functools\nimport ctypes\n\nimport torch\nimport torch.distributed as dist\nfrom contextlib import contextmanager\n\n\ndef init_dist(local_rank, backend='nccl', **kwargs):\n    r\"\"\"Initialize distributed training\"\"\"\n    if dist.is_available():\n        if dist.is_initialized():\n            return torch.cuda.current_device()\n        torch.cuda.set_device(local_rank)\n        dist.init_process_group(backend=backend, init_method='env://', **kwargs)\n\n    # Increase the L2 fetch granularity for faster speed.\n    _libcudart = ctypes.CDLL('libcudart.so')\n    # Set device limit on the current device\n    # cudaLimitMaxL2FetchGranularity = 0x05\n    pValue = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))\n    _libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128))\n    _libcudart.cudaDeviceGetLimit(pValue, ctypes.c_int(0x05))\n\n\ndef get_rank():\n    r\"\"\"Get rank of the thread.\"\"\"\n    rank = 0\n    if dist.is_available():\n        if dist.is_initialized():\n            rank = dist.get_rank()\n    return rank\n\n\ndef get_world_size():\n    r\"\"\"Get world size. 
How many GPUs are available in this job.\"\"\"\n    world_size = 1\n    if dist.is_available():\n        if dist.is_initialized():\n            world_size = dist.get_world_size()\n    return world_size\n\n\ndef broadcast_object_list(message, src=0):\n    r\"\"\"Broadcast object list from the master to the others\"\"\"\n    # Send logdir from master to all workers.\n    if dist.is_available():\n        if dist.is_initialized():\n            torch.distributed.broadcast_object_list(message, src=src)\n    return message\n\n\ndef master_only(func):\n    r\"\"\"Apply this function only to the master GPU.\"\"\"\n    @functools.wraps(func)\n    def wrapper(*args, **kwargs):\n        r\"\"\"Simple function wrapper for the master function\"\"\"\n        if get_rank() == 0:\n            return func(*args, **kwargs)\n        else:\n            return None\n    return wrapper\n\n\ndef is_master():\n    r\"\"\"Check if the current process is the master.\"\"\"\n    return get_rank() == 0\n\n\ndef is_dist():\n    r\"\"\"Check if the distributed process group is available and initialized.\"\"\"\n    return dist.is_available() and dist.is_initialized()\n\n\ndef barrier():\n    r\"\"\"Synchronize all processes; no-op when not running distributed.\"\"\"\n    if is_dist():\n        dist.barrier()\n\n\n@contextmanager\ndef master_first():\n    r\"\"\"Run the wrapped block on the master first, then on the other processes.\"\"\"\n    if not is_master():\n        barrier()\n    yield\n    if is_dist() and is_master():\n        barrier()\n\n\ndef is_local_master():\n    return torch.cuda.current_device() == 0\n\n\n@master_only\ndef master_only_print(*args):\n    r\"\"\"master-only print\"\"\"\n    print(*args)\n\n\ndef dist_reduce_tensor(tensor, rank=0, reduce='mean'):\n    r\"\"\" Reduce to the given rank (default 0) \"\"\"\n    world_size = get_world_size()\n    if world_size < 2:\n        return tensor\n    with torch.no_grad():\n        dist.reduce(tensor, dst=rank)\n        if get_rank() == rank:\n            if reduce == 'mean':\n                tensor /= world_size\n            elif reduce == 'sum':\n                pass\n            else:\n                raise NotImplementedError\n    return tensor\n\n\ndef dist_all_reduce_tensor(tensor, reduce='mean'):\n    r\"\"\" 
Reduce to all ranks \"\"\"\n    world_size = get_world_size()\n    if world_size < 2:\n        return tensor\n    with torch.no_grad():\n        dist.all_reduce(tensor)\n        if reduce == 'mean':\n            tensor /= world_size\n        elif reduce == 'sum':\n            pass\n        else:\n            raise NotImplementedError\n    return tensor\n\n\ndef dist_all_gather_tensor(tensor):\n    r\"\"\" gather to all ranks \"\"\"\n    world_size = get_world_size()\n    if world_size < 2:\n        return [tensor]\n    tensor_list = [\n        torch.ones_like(tensor) for _ in range(dist.get_world_size())]\n    with torch.no_grad():\n        dist.all_gather(tensor_list, tensor)\n    return tensor_list\n"
  },
  {
    "path": "tools/general_utils.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport torch\nimport sys\nfrom datetime import datetime\nimport numpy as np\nimport random\nimport torchvision.transforms.functional as torchvision_F\nfrom PIL import ImageFile\nImageFile.LOAD_TRUNCATED_IMAGES = True\n\n\ndef inverse_sigmoid(x):\n    return torch.log(x/(1-x))\n\n\ndef PILtoTorch(pil_image, resolution):\n    resized_image_PIL = pil_image.resize(resolution)\n    resized_image = torch.from_numpy(np.array(resized_image_PIL))\n    if len(resized_image.shape) == 3:\n        return resized_image.permute(2, 0, 1)\n    else:\n        return resized_image.unsqueeze(dim=-1).permute(2, 0, 1)\n\n\ndef NumpytoTorch(image, resolution):\n    image = torch.from_numpy(image)\n    if image.ndim == 4: image = image.squeeze(0)\n    if image.shape[-1] == 3 or image.shape[-1] == 1:\n        image = image.permute(2, 0, 1)\n    _, orig_h, orig_w = image.shape\n    if resolution == [orig_h, orig_w]:\n        resized_image = image\n    else:\n        resized_image = torchvision_F.resize(image, resolution, antialias=True)\n\n    return resized_image\n    \n\ndef get_expon_lr_func(\n    lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000\n):\n    \"\"\"\n    Copied from Plenoxels\n\n    Continuous learning rate decay function. 
Adapted from JaxNeRF\n    The returned rate is lr_init when step=0 and lr_final when step=max_steps, and\n    is log-linearly interpolated elsewhere (equivalent to exponential decay).\n    If lr_delay_steps>0 then the learning rate will be scaled by some smooth\n    function of lr_delay_mult, such that the initial learning rate is\n    lr_init*lr_delay_mult at the beginning of optimization but will be eased back\n    to the normal learning rate when steps>lr_delay_steps.\n    :param lr_init: float, the learning rate at step=0.\n    :param lr_final: float, the learning rate at step=max_steps.\n    :param lr_delay_steps: int, the number of warm-up steps (0 disables the delay).\n    :param lr_delay_mult: float, the multiplier applied to lr_init at step=0 when the delay is enabled.\n    :param max_steps: int, the number of steps during optimization.\n    :return: higher-order function which takes step as input and returns the learning rate\n    \"\"\"\n\n    def helper(step):\n        if step < 0 or (lr_init == 0.0 and lr_final == 0.0):\n            # Disable this parameter\n            return 0.0\n        if lr_delay_steps > 0:\n            # A kind of reverse cosine decay.\n            delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(\n                0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)\n            )\n        else:\n            delay_rate = 1.0\n        t = np.clip(step / max_steps, 0, 1)\n        log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)\n        return delay_rate * log_lerp\n\n    return helper\n\ndef strip_lowerdiag(L):\n    # keep the 6 unique entries (upper triangle, row-major) of each symmetric 3x3 matrix\n    uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device=\"cuda\")\n\n    uncertainty[:, 0] = L[:, 0, 0]\n    uncertainty[:, 1] = L[:, 0, 1]\n    uncertainty[:, 2] = L[:, 0, 2]\n    uncertainty[:, 3] = L[:, 1, 1]\n    uncertainty[:, 4] = L[:, 1, 2]\n    uncertainty[:, 5] = L[:, 2, 2]\n    return uncertainty\n\ndef strip_symmetric(sym):\n    return strip_lowerdiag(sym)\n\ndef build_rotation(r):\n    # r: (N, 4) quaternions ordered (real, x, y, z); normalized before conversion\n    norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3])\n\n    q = r / norm[:, None]\n\n    R = torch.zeros((q.size(0), 3, 3), device='cuda')\n\n    r = q[:, 0]\n    x = q[:, 1]\n    y = q[:, 2]\n    z = q[:, 3]\n\n    R[:, 0, 0] = 1 - 2 * (y*y + z*z)\n    R[:, 0, 1] = 
2 * (x*y - r*z)\n    R[:, 0, 2] = 2 * (x*z + r*y)\n    R[:, 1, 0] = 2 * (x*y + r*z)\n    R[:, 1, 1] = 1 - 2 * (x*x + z*z)\n    R[:, 1, 2] = 2 * (y*z - r*x)\n    R[:, 2, 0] = 2 * (x*z - r*y)\n    R[:, 2, 1] = 2 * (y*z + r*x)\n    R[:, 2, 2] = 1 - 2 * (x*x + y*y)\n    return R\n\ndef build_scaling_rotation(s, r):\n    L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device=\"cuda\")\n    R = build_rotation(r)\n\n    L[:,0,0] = s[:,0]\n    L[:,1,1] = s[:,1]\n    L[:,2,2] = s[:,2]\n\n    L = R @ L\n    return L\n\ndef safe_state(silent):\n    old_f = sys.stdout\n    class F:\n        def __init__(self, silent):\n            self.silent = silent\n\n        def write(self, x):\n            if not self.silent:\n                if x.endswith(\"\\n\"):\n                    old_f.write(x.replace(\"\\n\", \" [{}]\\n\".format(str(datetime.now().strftime(\"%d/%m %H:%M:%S\")))))\n                else:\n                    old_f.write(x)\n\n        def flush(self):\n            old_f.flush()\n\n    sys.stdout = F(silent)\n\n\ndef set_random_seed(seed):\n    r\"\"\"Set random seeds for everything, including random, numpy, torch.manual_seed, torch.cuda_manual_seed.\n    torch.cuda.manual_seed_all is not necessary (included in torch.manual_seed)\n\n    Args:\n        seed (int): Random seed.\n    \"\"\"\n    print(f\"Using random seed {seed}\")\n    random.seed(seed)\n    np.random.seed(seed)\n    torch.manual_seed(seed)         # sets seed on the current CPU & all GPUs\n    torch.cuda.manual_seed(seed)    # sets seed on current GPU\n    # torch.cuda.manual_seed_all(seed)  # included in torch.manual_seed\n    torch.cuda.set_device(torch.device(\"cuda:0\"))\n"
  },
  {
    "path": "tools/graphics_utils.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport torch\nimport math\nimport numpy as np\nfrom typing import NamedTuple\n\nclass BasicPointCloud(NamedTuple):\n    points : np.array\n    colors : np.array\n    normals : np.array\n\ndef geom_transform_points(points, transf_matrix):\n    P, _ = points.shape\n    ones = torch.ones(P, 1, dtype=points.dtype, device=points.device)\n    points_hom = torch.cat([points, ones], dim=1)\n    points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0))\n\n    denom = points_out[..., 3:] + 0.0000001\n    return (points_out[..., :3] / denom).squeeze(dim=0)\n\ndef getWorld2View(R, t):\n    Rt = np.zeros((4, 4))\n    Rt[:3, :3] = R.transpose()\n    Rt[:3, 3] = t\n    Rt[3, 3] = 1.0\n    return np.float32(Rt)\n\ndef getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0):\n    Rt = np.zeros((4, 4))   # w2c\n    Rt[:3, :3] = R.transpose()      # w2c\n    Rt[:3, 3] = t                   # w2c\n    Rt[3, 3] = 1.0\n\n    C2W = np.linalg.inv(Rt)         # c2w\n    cam_center = C2W[:3, 3]\n    cam_center = (cam_center + translate) * scale\n    C2W[:3, 3] = cam_center\n    Rt = np.linalg.inv(C2W)         # w2c\n    return np.float32(Rt)\n\ndef getView2World(R, t):\n    '''\n    R: w2c\n    t: w2c\n    '''\n    Rt = np.zeros((4, 4))\n    Rt[:3, :3] = R.transpose()      # c2w\n    Rt[:3, 3] = -R.transpose() @ t  # c2w\n    Rt[3, 3] = 1.0\n\n    return Rt\n\ndef getProjectionMatrix(znear, zfar, fovX, fovY):\n    '''\n    normalized intrinsics\n    '''\n    tanHalfFovY = math.tan((fovY / 2))\n    tanHalfFovX = math.tan((fovX / 2))\n\n    top = tanHalfFovY * znear\n    bottom = -top\n    right = tanHalfFovX * znear\n    left = -right\n\n    P = torch.zeros(4, 4)\n\n 
   z_sign = 1.0\n\n    P[0, 0] = 2.0 * znear / (right - left)\n    P[1, 1] = 2.0 * znear / (top - bottom)\n    P[0, 2] = (right + left) / (right - left)\n    P[1, 2] = (top + bottom) / (top - bottom)\n    P[3, 2] = z_sign\n    P[2, 2] = z_sign * zfar / (zfar - znear)\n    P[2, 3] = -(zfar * znear) / (zfar - znear)\n    return P\n\n\ndef getIntrinsic(fovX, fovY, h, w):\n    focal_length_y = fov2focal(fovY, h)\n    focal_length_x = fov2focal(fovX, w)\n    \n    intrinsic = np.eye(3)\n    intrinsic = torch.eye(3, dtype=torch.float32)\n    \n    intrinsic[0, 0] = focal_length_x # FovX\n    intrinsic[1, 1] = focal_length_y # FovY\n    intrinsic[0, 2] = w / 2\n    intrinsic[1, 2] = h / 2\n    \n    return intrinsic\n\n\ndef fov2focal(fov, pixels):\n    return pixels / (2 * math.tan(fov / 2))\n\ndef focal2fov(focal, pixels):\n    return 2*math.atan(pixels/(2*focal))\n\n\ndef ndc_2_cam(ndc_xyz, intrinsic, W, H):\n    inv_scale = torch.tensor([[W - 1, H - 1]], device=ndc_xyz.device)\n    cam_z = ndc_xyz[..., 2:3]\n    cam_xy = ndc_xyz[..., :2] * inv_scale * cam_z\n    cam_xyz = torch.cat([cam_xy, cam_z], dim=-1)\n    cam_xyz = cam_xyz @ torch.inverse(intrinsic[0, ...].t())\n    return cam_xyz\n\n\ndef depth2point_cam(sampled_depth, ref_intrinsic):\n    B, N, C, H, W = sampled_depth.shape\n    valid_z = sampled_depth\n    valid_x = torch.arange(W, dtype=torch.float32, device=sampled_depth.device).add_(0.5) / (W - 1)\n    valid_y = torch.arange(H, dtype=torch.float32, device=sampled_depth.device).add_(0.5) / (H - 1)\n    valid_y, valid_x = torch.meshgrid(valid_y, valid_x, indexing='ij')\n    # B,N,H,W\n    valid_x = valid_x[None, None, None, ...].expand(B, N, C, -1, -1)\n    valid_y = valid_y[None, None, None, ...].expand(B, N, C, -1, -1)\n    ndc_xyz = torch.stack([valid_x, valid_y, valid_z], dim=-1).view(B, N, C, H, W, 3)  # 1, 1, 5, 512, 640, 3\n    cam_xyz = ndc_2_cam(ndc_xyz, ref_intrinsic, W, H) # 1, 1, 5, 512, 640, 3\n    return ndc_xyz, cam_xyz\n\n\ndef 
depth2point(depth_image, intrinsic_matrix, extrinsic_matrix):\n    _, xyz_cam = depth2point_cam(depth_image[None,None,None,...], intrinsic_matrix[None,...])\n    xyz_cam = xyz_cam.reshape(-1,3)\n    xyz_world = torch.cat([xyz_cam, torch.ones_like(xyz_cam[...,0:1])], dim=-1) @ torch.inverse(extrinsic_matrix).transpose(0,1)\n    xyz_world = xyz_world[...,:3]\n\n    return xyz_cam.reshape(*depth_image.shape, 3), xyz_world.reshape(*depth_image.shape, 3)\n\n\n@torch.no_grad()\ndef get_all_px_dir(intrinsics, height, width):\n    \"\"\"\n    Calculate the view direction for all pixels/rays in the image.\n    This is used for intersection calculation between ray and voxel textures.\n    \"\"\"\n\n    _, ray_dir = depth2point_cam(torch.ones(1, 1, 1, height, width).cuda(), intrinsics[None])\n    ray_dir = ray_dir.squeeze()\n    ray_dir = torch.nn.functional.normalize(ray_dir, dim=-1)\n    \n    ray_dir = ray_dir.permute(2, 0, 1) # 3, H, W\n    return ray_dir"
  },
  {
    "path": "tools/image_utils.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport torch\n\ndef mse(img1, img2):\n    return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)\n\ndef psnr(img1, img2):\n    mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)\n    return 20 * torch.log10(1.0 / torch.sqrt(mse))\n"
  },
  {
    "path": "tools/loss_utils.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\"\"\"\n[1] Feature Preserving Point Set Surfaces based on Non-Linear Kernel Regression\nCengiz Oztireli, Gaël Guennebaud, Markus Gross\n[2] Consolidation of Unorganized Point Clouds for Surface Reconstruction\nHui Huang, Dan Li, Hao Zhang, Uri Ascher Daniel Cohen-Or\n[3] Differentiable Surface Splatting for Point-based Geometry Processing\nWang Yifan, Felice Serena, Shihao Wu, Cengiz Oeztireli, Olga Sorkine-Hornung\n[4] 3D Gaussian Splatting for Real-Time Radiance Field Rendering\nBernhard Kerbl, Georgios Kopanas, Thomas Leimkühler, George Drettakis\n\"\"\"\n\nfrom typing import Optional\nfrom math import exp\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.autograd import Variable\n\n\ndef entropy_loss(opacity):\n    loss = (- opacity * torch.log(opacity + 1e-6) - \\\n        (1 - opacity) * torch.log(1 - opacity + 1e-6)).mean()\n    return loss\n\n\ndef l1_loss(network_output, gt):\n    return torch.abs((network_output - gt)).mean()\n\n\ndef log_l1_loss(network_output, gt):\n    loss = torch.log(1 + torch.abs((network_output - gt))).mean()\n    return loss\n\n\ndef l2_loss(network_output, gt):\n    return ((network_output - gt) ** 2).mean()\n\n\ndef gaussian(window_size, sigma):\n    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])\n    return gauss / gauss.sum()\n\n\ndef create_window(window_size, channel):\n    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)\n    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)\n    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())\n    return window\n\n\ndef 
ssim(img1, img2, window_size=11, size_average=True):\n    channel = img1.size(-3)\n    window = create_window(window_size, channel)\n\n    if img1.is_cuda:\n        window = window.cuda(img1.get_device())\n    window = window.type_as(img1)\n\n    return _ssim(img1, img2, window, window_size, channel, size_average)\n\n\ndef _ssim(img1, img2, window, window_size, channel, size_average=True):\n    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)\n    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)\n\n    mu1_sq = mu1.pow(2)\n    mu2_sq = mu2.pow(2)\n    mu1_mu2 = mu1 * mu2\n\n    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq\n    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq\n    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2\n\n    C1 = 0.01 ** 2\n    C2 = 0.03 ** 2\n\n    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))\n\n    if size_average:\n        return ssim_map.mean()\n    else:\n        return ssim_map.mean(1).mean(1).mean(1)\n\n\ndef eikonal_loss(gradients):\n    gradient_error = (gradients.norm(dim=-1) - 1.0) ** 2  # [B,R,N]\n    gradient_error = gradient_error.nan_to_num(nan=0.0, posinf=0.0, neginf=0.0)  # [B,R,N]\n    \n    return gradient_error.mean()\n\n\ndef curvature_loss(hessian):\n    laplacian = hessian.sum(dim=-1).abs()  # [B,R,N]\n    laplacian = laplacian.nan_to_num(nan=0.0, posinf=0.0, neginf=0.0)  # [B,R,N]\n    \n    return laplacian.mean()\n\n\ndef compute_normal_loss(normal_pred, normal_gt, weight=None):\n    if weight is not None:\n        weight = weight.view(-1, 1)\n    else:\n        weight = 1.0\n    normal_pred = normal_pred.view(-1, 3)\n    normal_gt = normal_gt.view(-1, 3)\n    \n    cos = (1.0 - torch.sum(normal_pred * normal_gt * weight, dim=-1).abs()).mean()\n    \n    return 
cos\n\n\ndef monosdf_normal_loss(normal_pred: torch.Tensor, normal_gt: torch.Tensor, weight: Optional[torch.Tensor] = None):\n    \"\"\"normal consistency loss as monosdf\n\n    Args:\n        normal_pred (torch.Tensor): volume rendered normal\n        normal_gt (torch.Tensor): monocular normal\n    \"\"\"\n    if weight is None: weight = 1.0\n    l1 = (weight * torch.abs(normal_pred - normal_gt).sum(dim=-1)).mean()\n    cos = (weight * (1.0 - torch.sum(normal_pred * normal_gt, dim=-1))).mean()\n    return l1 + cos\n\n\ndef cos_weight(render_normal, gt_normal, exp_t=1.0):\n    cos = torch.sum(render_normal * gt_normal, dim=-1)\n    if exp_t > 0:\n        cos = torch.exp((cos - 1) / exp_t)\n    else:\n        cos = torch.ones_like(cos)\n\n    return cos.detach()\n\n\n# copy from MiDaS\ndef compute_scale_and_shift(prediction, target, mask):\n    # system matrix: A = [[a_00, a_01], [a_10, a_11]]\n    a_00 = torch.sum(mask * prediction * prediction, (1, 2))\n    a_01 = torch.sum(mask * prediction, (1, 2))\n    a_11 = torch.sum(mask, (1, 2))\n\n    # right hand side: b = [b_0, b_1]\n    b_0 = torch.sum(mask * prediction * target, (1, 2))\n    b_1 = torch.sum(mask * target, (1, 2))\n\n    # solution: x = A^-1 . b = [[a_11, -a_01], [-a_10, a_00]] / (a_00 * a_11 - a_01 * a_10) . 
b\n    x_0 = torch.zeros_like(b_0)\n    x_1 = torch.zeros_like(b_1)\n\n    det = a_00 * a_11 - a_01 * a_01\n    valid = det.nonzero()\n\n    x_0[valid] = (a_11[valid] * b_0[valid] - a_01[valid] * b_1[valid]) / det[valid]\n    x_1[valid] = (-a_01[valid] * b_0[valid] + a_00[valid] * b_1[valid]) / det[valid]\n\n    return x_0, x_1\n\n\ndef reduction_batch_based(image_loss, M):\n    # average of all valid pixels of the batch\n\n    # avoid division by 0 (if sum(M) = sum(sum(mask)) = 0: sum(image_loss) = 0)\n    divisor = torch.sum(M)\n\n    if divisor == 0:\n        return 0\n    else:\n        return torch.sum(image_loss) / divisor\n\n\ndef reduction_image_based(image_loss, M):\n    # mean of average of valid pixels of an image\n\n    # avoid division by 0 (if M = sum(mask) = 0: image_loss = 0)\n    valid = M.nonzero()\n\n    image_loss[valid] = image_loss[valid] / M[valid]\n\n    return torch.mean(image_loss)\n\n\ndef mse_loss(prediction, target, mask, reduction=reduction_batch_based):\n\n    M = torch.sum(mask, (1, 2))\n    res = prediction - target\n    image_loss = torch.sum(mask * res * res, (1, 2))\n\n    return reduction(image_loss, 2 * M)\n\n\ndef gradient_loss(prediction, target, mask, reduction=reduction_batch_based):\n\n    M = torch.sum(mask, (1, 2))\n\n    diff = prediction - target\n    diff = torch.mul(mask, diff)\n\n    grad_x = torch.abs(diff[:, :, 1:] - diff[:, :, :-1])\n    mask_x = torch.mul(mask[:, :, 1:], mask[:, :, :-1])\n    grad_x = torch.mul(mask_x, grad_x)\n\n    grad_y = torch.abs(diff[:, 1:, :] - diff[:, :-1, :])\n    mask_y = torch.mul(mask[:, 1:, :], mask[:, :-1, :])\n    grad_y = torch.mul(mask_y, grad_y)\n\n    image_loss = torch.sum(grad_x, (1, 2)) + torch.sum(grad_y, (1, 2))\n\n    return reduction(image_loss, M)\n\n\nclass MSELoss(nn.Module):\n    def __init__(self, reduction='batch-based'):\n        super().__init__()\n\n        if reduction == 'batch-based':\n            self.__reduction = reduction_batch_based\n        else:\n    
        self.__reduction = reduction_image_based\n\n    def forward(self, prediction, target, mask):\n        return mse_loss(prediction, target, mask, reduction=self.__reduction)\n\n\nclass GradientLoss(nn.Module):\n    def __init__(self, scales=4, reduction='batch-based'):\n        super().__init__()\n\n        if reduction == 'batch-based':\n            self.__reduction = reduction_batch_based\n        else:\n            self.__reduction = reduction_image_based\n\n        self.__scales = scales\n\n    def forward(self, prediction, target, mask):\n        total = 0\n\n        for scale in range(self.__scales):\n            step = pow(2, scale)\n\n            total += gradient_loss(prediction[:, ::step, ::step], target[:, ::step, ::step],\n                                   mask[:, ::step, ::step], reduction=self.__reduction)\n\n        return total\n\n\nclass ScaleAndShiftInvariantLoss(nn.Module):\n    def __init__(self, alpha=0.5, scales=1, reduction='batch-based'):\n        super().__init__()\n\n        self.__data_loss = MSELoss(reduction=reduction)\n        self.__regularization_loss = GradientLoss(scales=scales, reduction=reduction)\n        self.__alpha = alpha\n\n        self.__prediction_ssi = None\n\n    def forward(self, prediction, target, mask=None):\n        target = target * 50 + 0.5\n        if mask is None: mask = torch.ones_like(target)\n\n        scale, shift = compute_scale_and_shift(prediction, target, mask)\n        self.__prediction_ssi = scale.view(-1, 1, 1) * prediction + shift.view(-1, 1, 1)\n\n        total = self.__data_loss(self.__prediction_ssi, target, mask)\n        if self.__alpha > 0:\n            total += self.__alpha * self.__regularization_loss(self.__prediction_ssi, target, mask)\n\n        return total\n\n    def __get_prediction_ssi(self):\n        return self.__prediction_ssi\n\n    prediction_ssi = property(__get_prediction_ssi)\n# end copy\n    \n\ndef normal2curv(normal, mask = None):\n    n = normal\n    m = mask\n    n 
= torch.nn.functional.pad(n[None], [0, 0, 1, 1, 1, 1], mode='replicate')\n    m = torch.nn.functional.pad(m[None].to(torch.float32), [0, 0, 1, 1, 1, 1], mode='replicate').to(torch.bool)\n    n_c = (n[:, 1:-1, 1:-1, :]      ) * m[:, 1:-1, 1:-1, :]\n    n_u = (n[:,  :-2, 1:-1, :] - n_c) * m[:,  :-2, 1:-1, :]\n    n_l = (n[:, 1:-1,  :-2, :] - n_c) * m[:, 1:-1,  :-2, :]\n    n_b = (n[:, 2:  , 1:-1, :] - n_c) * m[:, 2:  , 1:-1, :]\n    n_r = (n[:, 1:-1, 2:  , :] - n_c) * m[:, 1:-1, 2:  , :]\n    curv = (n_u + n_l + n_b + n_r)[0]\n    curv = curv * mask\n    curv = curv.norm(1, -1, True)\n    return curv\n\n\ndef L1_loss_appearance(image, gt_image, gaussians, view_idx, return_transformed_image=False):\n    appearance_embedding = gaussians.get_apperance_embedding(view_idx)\n    # center crop the image\n    origH, origW = image.shape[1:]\n    H = origH // 32 * 32\n    W = origW // 32 * 32\n    left = origW // 2 - W // 2\n    top = origH // 2 - H // 2\n    crop_image = image[:, top:top+H, left:left+W]\n    crop_gt_image = gt_image[:, top:top+H, left:left+W]\n    \n    # down sample the image\n    crop_image_down = torch.nn.functional.interpolate(crop_image[None], size=(H//32, W//32), mode=\"bilinear\", align_corners=True)[0]\n    \n    crop_image_down = torch.cat([crop_image_down, appearance_embedding[None].repeat(H//32, W//32, 1).permute(2, 0, 1)], dim=0)[None]\n    mapping_image = gaussians.appearance_network(crop_image_down)\n    transformed_image = mapping_image * crop_image\n    if not return_transformed_image:\n        return l1_loss(transformed_image, crop_gt_image)\n    else:\n        transformed_image = torch.nn.functional.interpolate(transformed_image, size=(origH, origW), mode=\"bilinear\", align_corners=True)[0]\n        return transformed_image\n"
  },
  {
    "path": "tools/math_utils.py",
    "content": "import torch\n\n\ndef eps_sqrt(squared, eps=1e-17):\n    \"\"\"\n    Prepare for the input for sqrt, make sure the input positive and\n    larger than eps\n    \"\"\"\n    return torch.clamp(squared.abs(), eps)\n\n\ndef ndc_to_pix(p, resolution):\n    \"\"\"\n    Reverse of pytorch3d pix_to_ndc function\n    Args:\n        p (float tensor): (..., 3)\n        resolution (scalar): image resolution (for now, supports only aspectratio = 1)\n    Returns:\n        pix (long tensor): (..., 2)\n    \"\"\"\n    pix = resolution - ((p[..., :2] + 1.0) * resolution - 1.0) / 2\n    return pix\n\n\ndef decompose_to_R_and_t(transform_mat, row_major=True):\n    \"\"\" decompose a 4x4 transform matrix to R (3,3) and t (1,3)\"\"\"\n    assert(transform_mat.shape[-2:] == (4, 4)), \\\n        \"Expecting batches of 4x4 matrice\"\n    # ... 3x3\n    if not row_major:\n        transform_mat = transform_mat.transpose(-2, -1)\n\n    R = transform_mat[..., :3, :3]\n    t = transform_mat[..., -1, :3]\n\n    return R, t\n\n\ndef to_homogen(x, dim=-1):\n    \"\"\" append one to the specified dimension \"\"\"\n    if dim < 0:\n        dim = x.ndim + dim\n    shp = x.shape\n    new_shp = shp[:dim] + (1, ) + shp[dim + 1:]\n    x_homogen = x.new_ones(new_shp)\n    x_homogen = torch.cat([x, x_homogen], dim=dim)\n    return x_homogen\n\n\ndef normalize_pts(pts, trans, scale):\n    '''\n    trans: (4, 4), world to \n    '''\n    if trans.ndim == 1:\n        pts = (pts - trans) / scale\n    else:\n        pts = ((trans[:3, :3] @ pts.T + trans[:3, 3:]).T) / scale\n    return pts\n\n\ndef inv_normalize_pts(pts, trans, scale):\n    if trans.ndim == 1:\n        pts = pts * scale + trans\n    else:\n        pts = (pts * scale[None] - trans[:3, 3:].T) @ trans[:3, :3]\n    \n    return pts\n\n\ndef get_inside_normalized(xyz, trans, scale):\n    pts = normalize_pts(xyz, trans, scale)\n    with torch.no_grad():\n        inside = torch.all(torch.abs(pts) < 1, dim=-1)\n    return inside, pts"
  },
  {
    "path": "tools/mcube_utils.py",
    "content": "#\n# Copyright (C) 2024, ShanghaiTech\n# SVIP research group, https://github.com/svip-lab\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  huangbb@shanghaitech.edu.cn\n#\n\nimport numpy as np\nimport torch\nimport trimesh\nfrom skimage import measure\n# modified from here https://github.com/autonomousvision/sdfstudio/blob/370902a10dbef08cb3fe4391bd3ed1e227b5c165/nerfstudio/utils/marching_cubes.py#L201\ndef marching_cubes_with_contraction(\n    sdf,\n    resolution=512,\n    bounding_box_min=(-1.0, -1.0, -1.0),\n    bounding_box_max=(1.0, 1.0, 1.0),\n    return_mesh=False,\n    level=0,\n    simplify_mesh=True,\n    inv_contraction=None,\n    max_range=32.0,\n):\n    assert resolution % 512 == 0\n\n    resN = resolution\n    cropN = 512\n    level = 0\n    N = resN // cropN\n\n    grid_min = bounding_box_min\n    grid_max = bounding_box_max\n    xs = np.linspace(grid_min[0], grid_max[0], N + 1)\n    ys = np.linspace(grid_min[1], grid_max[1], N + 1)\n    zs = np.linspace(grid_min[2], grid_max[2], N + 1)\n\n    meshes = []\n    for i in range(N):\n        for j in range(N):\n            for k in range(N):\n                print(i, j, k)\n                x_min, x_max = xs[i], xs[i + 1]\n                y_min, y_max = ys[j], ys[j + 1]\n                z_min, z_max = zs[k], zs[k + 1]\n\n                x = np.linspace(x_min, x_max, cropN)\n                y = np.linspace(y_min, y_max, cropN)\n                z = np.linspace(z_min, z_max, cropN)\n\n                xx, yy, zz = np.meshgrid(x, y, z, indexing=\"ij\")\n                points = torch.tensor(np.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T, dtype=torch.float).cuda()\n\n                @torch.no_grad()\n                def evaluate(points):\n                    z = []\n                    for _, pnts in enumerate(torch.split(points, 256**3, dim=0)):\n                        
z.append(sdf(pnts))\n                    z = torch.cat(z, axis=0)\n                    return z\n\n                # construct point pyramids\n                points = points.reshape(cropN, cropN, cropN, 3)\n                points = points.reshape(-1, 3)\n                pts_sdf = evaluate(points.contiguous())\n                z = pts_sdf.detach().cpu().numpy()\n                if not (np.min(z) > level or np.max(z) < level):\n                    z = z.astype(np.float32)\n                    verts, faces, normals, _ = measure.marching_cubes(\n                        volume=z.reshape(cropN, cropN, cropN),\n                        level=level,\n                        spacing=(\n                            (x_max - x_min) / (cropN - 1),\n                            (y_max - y_min) / (cropN - 1),\n                            (z_max - z_min) / (cropN - 1),\n                        ),\n                    )\n                    verts = verts + np.array([x_min, y_min, z_min])\n                    meshcrop = trimesh.Trimesh(verts, faces, normals)\n                    meshes.append(meshcrop)\n                \n                print(\"finished one block\")\n\n    combined = trimesh.util.concatenate(meshes)\n    combined.merge_vertices(digits_vertex=6)\n\n    # inverse contraction and clipping the points range\n    if inv_contraction is not None:\n        combined.vertices = inv_contraction(torch.from_numpy(combined.vertices).float().cuda()).cpu().numpy()\n        combined.vertices = np.clip(combined.vertices, -max_range, max_range)\n    \n    return combined"
  },
  {
    "path": "tools/mesh_utils.py",
    "content": "import torch\nimport numpy as np\nimport os\nimport math\nfrom tqdm import tqdm\nfrom functools import partial\nimport open3d as o3d\n\nfrom tools.render_utils import save_img_f32, save_img_u8\nfrom tools.semantic_id import BACKGROUND\nfrom tools.graphics_utils import depth2point\nfrom tools.math_utils import get_inside_normalized\n\n\ndef post_process_mesh(mesh, cluster_to_keep=1000):\n    \"\"\"\n    Post-process a mesh to filter out floaters and disconnected parts\n    \"\"\"\n    import copy\n    print(\"post processing the mesh to have {} clusterscluster_to_kep\".format(cluster_to_keep))\n    mesh_0 = copy.deepcopy(mesh)\n    with o3d.utility.VerbosityContextManager(o3d.utility.VerbosityLevel.Debug) as cm:\n            triangle_clusters, cluster_n_triangles, cluster_area = (mesh_0.cluster_connected_triangles())\n\n    triangle_clusters = np.asarray(triangle_clusters)\n    cluster_n_triangles = np.asarray(cluster_n_triangles)\n    cluster_area = np.asarray(cluster_area)\n    n_cluster = np.sort(cluster_n_triangles.copy())[-cluster_to_keep]\n    n_cluster = max(n_cluster, 50) # filter meshes smaller than 50\n    triangles_to_remove = cluster_n_triangles[triangle_clusters] < n_cluster\n    mesh_0.remove_triangles_by_mask(triangles_to_remove)\n    mesh_0.remove_unreferenced_vertices()\n    mesh_0.remove_degenerate_triangles()\n    print(\"num vertices raw {}\".format(len(mesh.vertices)))\n    print(\"num vertices post {}\".format(len(mesh_0.vertices)))\n    return mesh_0\n\ndef to_cam_open3d(viewpoint_stack):\n    camera_traj = []\n    for i, viewpoint_cam in enumerate(viewpoint_stack):\n        intrinsic=o3d.camera.PinholeCameraIntrinsic(width=viewpoint_cam.image_width, \n                    height=viewpoint_cam.image_height, \n                    cx = viewpoint_cam.image_width/2,\n                    cy = viewpoint_cam.image_height/2,\n                    fx = viewpoint_cam.image_width / (2 * math.tan(viewpoint_cam.FoVx / 2.)),\n                  
  fy = viewpoint_cam.image_height / (2 * math.tan(viewpoint_cam.FoVy / 2.)))\n\n        extrinsic=np.asarray((viewpoint_cam.world_view_transform.T).cpu().numpy())\n        camera = o3d.camera.PinholeCameraParameters()\n        camera.extrinsic = extrinsic\n        camera.intrinsic = intrinsic\n        camera_traj.append(camera)\n\n    return camera_traj\n\n\nclass GaussianExtractor(object):\n    def __init__(self, gaussians, render, cfg, bg_color=None, dirs=None, prob_thres=0.2, alpha_thres=0.5):\n        \"\"\"\n        a class that extracts attributes of a scene represented by 2DGS\n\n        Usage example:\n        >>> gaussExtractor = GaussianExtractor(gaussians, render, cfg)\n        >>> gaussExtractor.reconstruction(view_points)\n        >>> mesh = gaussExtractor.extract_mesh_bounded(...)\n        \"\"\"\n        if bg_color is None:\n            bg_color = [0, 0, 0]\n        if isinstance(bg_color, torch.Tensor): background = bg_color.clone().detach()\n        else: background = torch.tensor(bg_color, dtype=torch.float32, device=\"cuda\")\n        self.gaussians = gaussians\n        self.render = partial(render, cfg=cfg, bg_color=background, dirs=dirs)\n        self.prob_thres = prob_thres\n        self.alpha_thres = alpha_thres\n        self.clean()\n\n    @torch.no_grad()\n    def clean(self):\n        self.depthmaps = []\n        self.alphamaps = []\n        self.rgbmaps = []\n        self.normals = []\n        self.depth_normals = []\n        self.viewpoint_stack = []\n\n    @torch.no_grad()\n    def reconstruction(self, viewpoint_stack):\n        \"\"\"\n        reconstruct radiance field given cameras\n        \"\"\"\n        self.clean()\n        self.viewpoint_stack = viewpoint_stack\n        for i, viewpoint_cam in tqdm(enumerate(self.viewpoint_stack), desc=\"reconstruct radiance fields\", total=len(self.viewpoint_stack)):\n            render_pkg = self.render(viewpoint_cam, self.gaussians)\n            rgb = render_pkg['render']\n            alpha = 
render_pkg['alpha']\n            normal = torch.nn.functional.normalize(render_pkg['normal'], dim=0)\n            # permute the normalized map (previously the normalization was discarded)\n            normal = normal.permute(1, 2, 0)\n            depth = render_pkg['depth']\n            \n            if 'render_sem' in render_pkg:\n                semantic = render_pkg[\"render_sem\"]\n                prob = self.gaussians.logits2prob(semantic)\n                mask = (prob[..., BACKGROUND] > self.prob_thres)[None]\n                depth[mask] = 0\n                \n            rendered_pcd_world = depth2point(depth[0], viewpoint_cam.intr, viewpoint_cam.world_view_transform.transpose(0, 1))[1]\n            inside = get_inside_normalized(rendered_pcd_world.view(-1, 3), self.gaussians.trans, self.gaussians.scale)[0]\n            depth.view(-1)[~inside] = 0\n            \n                \n            depth_normal = render_pkg['est_normal'].permute(1, 2, 0)\n            self.rgbmaps.append(rgb.cpu())\n            self.depthmaps.append(depth.cpu())\n            self.alphamaps.append(alpha.cpu())\n            self.normals.append(normal.cpu())\n            self.depth_normals.append(depth_normal.cpu())\n        \n        self.rgbmaps = torch.stack(self.rgbmaps, dim=0)\n        self.depthmaps = torch.stack(self.depthmaps, dim=0)\n        self.alphamaps = torch.stack(self.alphamaps, dim=0)\n        self.depth_normals = torch.stack(self.depth_normals, dim=0)\n\n    @torch.no_grad()\n    def extract_mesh_bounded(self, voxel_size=0.004, sdf_trunc=0.02, depth_trunc=3, mask_backgrond=True):\n        \"\"\"\n        Perform TSDF fusion given a fixed depth range, used in the paper.\n        \n        voxel_size: the voxel size of the volume\n        sdf_trunc: truncation value\n        depth_trunc: maximum depth range; should depend on the scene's scale\n        mask_backgrond: whether to mask the background; only works when the dataset has masks\n\n        return o3d.mesh\n        \"\"\"\n        print(\"Running tsdf volume integration ...\")\n        
print(f'voxel_size: {voxel_size}')\n        print(f'sdf_trunc: {sdf_trunc}')\n        print(f'depth_trunc: {depth_trunc}')\n\n        volume = o3d.pipelines.integration.ScalableTSDFVolume(\n            voxel_length= voxel_size,\n            sdf_trunc=sdf_trunc,\n            color_type=o3d.pipelines.integration.TSDFVolumeColorType.RGB8\n        )\n\n        for i, cam_o3d in tqdm(enumerate(to_cam_open3d(self.viewpoint_stack)), desc=\"TSDF integration progress\", total=len(self.viewpoint_stack)):\n            rgb = self.rgbmaps[i]\n            depth = self.depthmaps[i]\n            \n            # if we have mask provided, use it\n            if mask_backgrond and (self.viewpoint_stack[i].gt_alpha_mask is not None):\n                depth[(self.viewpoint_stack[i].gt_alpha_mask < 0.5)] = 0\n\n            # make open3d rgbd (clip colors to [0, 1] so the uint8 cast cannot wrap around)\n            rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth(\n                o3d.geometry.Image(np.asarray(np.clip(rgb.permute(1,2,0).cpu().numpy(), 0.0, 1.0) * 255, order=\"C\", dtype=np.uint8)),\n                o3d.geometry.Image(np.asarray(depth.permute(1,2,0).cpu().numpy(), order=\"C\")),\n                depth_trunc = depth_trunc, convert_rgb_to_intensity=False,\n                depth_scale = 1.0\n            )\n\n            volume.integrate(rgbd, intrinsic=cam_o3d.intrinsic, extrinsic=cam_o3d.extrinsic)\n\n        mesh = volume.extract_triangle_mesh()\n        return mesh\n\n    @torch.no_grad()\n    def extract_mesh_unbounded(self, resolution=1024):\n        \"\"\"\n        Experimental feature: extracts meshes from unbounded scenes; not fully tested across datasets. 
\n        #TODO: support color mesh exporting\n\n        resolution: grid resolution per axis; must be a multiple of 512 (asserted in marching_cubes_with_contraction)\n        return o3d.mesh\n        \"\"\"\n        def contract(x):\n            mag = torch.linalg.norm(x, ord=2, dim=-1)[..., None]\n            return torch.where(mag < 1, x, (2 - (1 / mag)) * (x / mag))\n        \n        def uncontract(y):\n            mag = torch.linalg.norm(y, ord=2, dim=-1)[..., None]\n            return torch.where(mag < 1, y, (1 / (2-mag) * (y/mag)))\n\n        def compute_sdf_perframe(i, points, depthmap, rgbmap, normalmap, viewpoint_cam):\n            \"\"\"\n                compute per frame sdf\n            \"\"\"\n            new_points = torch.cat([points, torch.ones_like(points[...,:1])], dim=-1) @ viewpoint_cam.full_proj_transform\n            z = new_points[..., -1:]\n            pix_coords = (new_points[..., :2] / new_points[..., -1:])\n            mask_proj = ((pix_coords > -1. ) & (pix_coords < 1.) & (z > 0)).all(dim=-1)\n            sampled_depth = torch.nn.functional.grid_sample(depthmap.cuda()[None], pix_coords[None, None], mode='bilinear', padding_mode='border', align_corners=True).reshape(-1, 1)\n            sampled_rgb = torch.nn.functional.grid_sample(rgbmap.cuda()[None], pix_coords[None, None], mode='bilinear', padding_mode='border', align_corners=True).reshape(3,-1).T\n            sampled_normal = torch.nn.functional.grid_sample(normalmap.cuda()[None], pix_coords[None, None], mode='bilinear', padding_mode='border', align_corners=True).reshape(3,-1).T\n            sdf = (sampled_depth-z)\n            return sdf, sampled_rgb, sampled_normal, mask_proj\n\n        def compute_unbounded_tsdf(samples, inv_contraction, voxel_size, return_rgb=False):\n            \"\"\"\n                Fuse all frames, applying an adaptive sdf truncation in the contracted space.\n            \"\"\"\n            if inv_contraction is not None:\n                samples = inv_contraction(samples)\n                mask = torch.linalg.norm(samples, dim=-1) > 1\n       
         # adaptive sdf_truncation\n                sdf_trunc = 5 * voxel_size * torch.ones_like(samples[:, 0])\n                sdf_trunc[mask] *= 1/(2-torch.linalg.norm(samples, dim=-1)[mask].clamp(max=1.9))\n            else:\n                sdf_trunc = 5 * voxel_size\n\n            tsdfs = torch.ones_like(samples[:,0]) * 1\n            rgbs = torch.zeros((samples.shape[0], 3)).cuda()\n\n            weights = torch.ones_like(samples[:,0])\n            for i, viewpoint_cam in tqdm(enumerate(self.viewpoint_stack), desc=\"TSDF integration progress\"):\n                sdf, rgb, normal, mask_proj = compute_sdf_perframe(i, samples,\n                    depthmap = self.depthmaps[i],\n                    rgbmap = self.rgbmaps[i],\n                    normalmap = self.depth_normals[i],\n                    viewpoint_cam=self.viewpoint_stack[i],\n                )\n\n                # volume integration\n                sdf = sdf.flatten()\n                mask_proj = mask_proj & (sdf > -sdf_trunc)\n                sdf = torch.clamp(sdf / sdf_trunc, min=-1.0, max=1.0)[mask_proj]\n                w = weights[mask_proj]\n                wp = w + 1\n                tsdfs[mask_proj] = (tsdfs[mask_proj] * w + sdf) / wp\n                rgbs[mask_proj] = (rgbs[mask_proj] * w[:,None] + rgb[mask_proj]) / wp[:,None]\n                # update weight\n                weights[mask_proj] = wp\n            \n            if return_rgb:\n                return tsdfs, rgbs\n\n            return tsdfs\n\n        from tools.render_utils import focus_point_fn\n        torch.cuda.empty_cache()\n        c2ws = np.array([np.linalg.inv(np.asarray((cam.world_view_transform.T).cpu().numpy())) for cam in self.viewpoint_stack])\n        poses = c2ws[:,:3,:] @ np.diag([1, -1, -1, 1])\n        center = (focus_point_fn(poses))\n        radius = np.linalg.norm(c2ws[:,:3,3] - center, axis=-1).min()\n        center = torch.from_numpy(center).float().cuda()\n        normalize = lambda x: (x - center) / 
radius\n        unnormalize = lambda x: (x * radius) + center\n        inv_contraction = lambda x: unnormalize(uncontract(x))\n\n        N = resolution\n        voxel_size = (radius * 2 / N)\n        print(f\"Computing sdf grid resolution {N} x {N} x {N}\")\n        print(f\"Define the voxel_size as {voxel_size}\")\n        sdf_function = lambda x: compute_unbounded_tsdf(x, inv_contraction, voxel_size)\n        from tools.mcube_utils import marching_cubes_with_contraction\n        R = contract(normalize(self.gaussians.get_xyz)).norm(dim=-1).cpu().numpy()\n        R = np.quantile(R, q=0.95)\n        R = min(R+0.01, 1.9)\n\n        mesh = marching_cubes_with_contraction(\n            sdf=sdf_function,\n            bounding_box_min=(-R, -R, -R),\n            bounding_box_max=(R, R, R),\n            level=0,\n            resolution=N,\n            inv_contraction=inv_contraction,\n        )\n        \n        # coloring the mesh\n        torch.cuda.empty_cache()\n        mesh = mesh.as_open3d\n        print(\"texturing mesh ... 
\")\n        _, rgbs = compute_unbounded_tsdf(torch.tensor(np.asarray(mesh.vertices)).float().cuda(), inv_contraction=None, voxel_size=voxel_size, return_rgb=True)\n        mesh.vertex_colors = o3d.utility.Vector3dVector(rgbs.cpu().numpy())\n        return mesh\n\n    @torch.no_grad()\n    def export_image(self, path):\n        render_path = os.path.join(path, \"renders\")\n        gts_path = os.path.join(path, \"gt\")\n        vis_path = os.path.join(path, \"vis\")\n        os.makedirs(render_path, exist_ok=True)\n        os.makedirs(vis_path, exist_ok=True)\n        os.makedirs(gts_path, exist_ok=True)\n        for idx, viewpoint_cam in tqdm(enumerate(self.viewpoint_stack), desc=\"export images\"):\n            gt = viewpoint_cam.original_image[0:3, :, :]\n            save_img_u8(gt.permute(1,2,0).cpu().numpy(), os.path.join(gts_path, '{0:05d}'.format(idx) + \".png\"))\n            save_img_u8(self.rgbmaps[idx].permute(1,2,0).cpu().numpy(), os.path.join(render_path, '{0:05d}'.format(idx) + \".png\"))\n            save_img_f32(self.depthmaps[idx][0].cpu().numpy(), os.path.join(vis_path, 'depth_{0:05d}'.format(idx) + \".tiff\"))\n            save_img_u8(self.normals[idx].permute(1,2,0).cpu().numpy() * 0.5 + 0.5, os.path.join(vis_path, 'normal_{0:05d}'.format(idx) + \".png\"))\n            save_img_u8(self.depth_normals[idx].permute(1,2,0).cpu().numpy() * 0.5 + 0.5, os.path.join(vis_path, 'depth_normal_{0:05d}'.format(idx) + \".png\"))"
  },
  {
    "path": "tools/normal_utils.py",
    "content": "import torch\nimport torch.nn.functional as F\n\nfrom tools.graphics_utils import depth2point_cam\n\n\ndef get_normal_sign(normals, begin=None, end=None, trans=None, mode='origin', vec=None):\n    if mode == 'origin':\n        if vec is None:\n            if begin is None:\n                # center\n                if trans is not None:\n                    begin = - trans[:3, :3].T @ trans[:3, 3] \\\n                        if trans.ndim != 1 else trans\n                else:\n                    begin = end.mean(0)\n                begin[1] += 1\n            vec = end - begin\n        cos = (normals * vec).sum(-1, keepdim=True)\n    \n    return cos\n\n\ndef compute_gradient(img):\n    dy = torch.gradient(img, dim=0)[0]\n    dx = torch.gradient(img, dim=1)[0]\n    return dx, dy\n\n\ndef compute_normals(depth_map, K):\n    # Assuming depth_map is a PyTorch tensor of shape [H, W]\n    # K_inv is the inverse of the intrinsic matrix\n    \n    _, cam_coords = depth2point_cam(depth_map[None, None], K[None])\n    cam_coords = cam_coords.squeeze(0).squeeze(0).squeeze(0)        # [H, W, 3]\n    \n    dx, dy = compute_gradient(cam_coords)\n    # Cross product of gradients gives normal\n    normals = torch.cross(dx, dy, dim=-1)\n    normals = F.normalize(normals, p=2, dim=-1)\n    return normals\n    \n\ndef compute_edge(image, k=11, thr=0.01):\n    dx, dy = compute_gradient(image)\n    \n    edge = torch.sqrt(dx**2 + dy**2)\n    edge = edge / edge.max()\n    \n    p = (k - 1) // 2\n    edge = F.max_pool2d(edge[None], kernel_size=k, stride=1, padding=p)[0]\n        \n    edge[edge>thr] = 1\n    return edge\n\n\ndef get_edge_aware_distortion_map(gt_image, distortion_map):\n    grad_img_left = torch.mean(torch.abs(gt_image[:, 1:-1, 1:-1] - gt_image[:, 1:-1, :-2]), 0)\n    grad_img_right = torch.mean(torch.abs(gt_image[:, 1:-1, 1:-1] - gt_image[:, 1:-1, 2:]), 0)\n    grad_img_top = torch.mean(torch.abs(gt_image[:, 1:-1, 1:-1] - gt_image[:, :-2, 1:-1]), 0)\n    
grad_img_bottom = torch.mean(torch.abs(gt_image[:, 1:-1, 1:-1] - gt_image[:, 2:, 1:-1]), 0)\n    max_grad = torch.max(torch.stack([grad_img_left, grad_img_right, grad_img_top, grad_img_bottom], dim=-1), dim=-1)[0]\n    # pad\n    max_grad = torch.exp(-max_grad)\n    max_grad = torch.nn.functional.pad(max_grad, (1, 1, 1, 1), mode=\"constant\", value=0)\n    return distortion_map * max_grad"
  },
  {
    "path": "tools/prune.py",
    "content": "import torch\n\nfrom gaussian_renderer import count_render, visi_acc_render\n\n\ndef calculate_v_imp_score(gaussians, imp_list, v_pow):\n    \"\"\"\n    :param gaussians: A data structure containing Gaussian components with a get_scaling method.\n    :param imp_list: The importance scores for each Gaussian component.\n    :param v_pow: The power to which the volume ratios are raised.\n    :return: A list of adjusted values (v_list) used for pruning.\n    \"\"\"\n    # Calculate the volume of each Gaussian component\n    volume = torch.prod(gaussians.get_scaling, dim=1)\n    # Determine the kth_percent_largest value\n    index = int(len(volume) * 0.9)\n    sorted_volume, _ = torch.sort(volume, descending=True)\n    kth_percent_largest = sorted_volume[index]\n    # Calculate v_list\n    v_list = torch.pow(volume / kth_percent_largest, v_pow)\n    v_list = v_list * imp_list\n    return v_list\n\n\ndef prune_list(gaussians, viewpoint_stack, pipe, background):\n    gaussian_list, imp_list = None, None\n    viewpoint_cam = viewpoint_stack.pop()\n    render_pkg = count_render(viewpoint_cam, gaussians, pipe, background)\n    gaussian_list, imp_list = (\n        render_pkg[\"gaussians_count\"],\n        render_pkg[\"important_score\"],\n    )\n\n        \n    for iteration in range(len(viewpoint_stack)):\n        # Pick a random Camera\n        # prunning\n        viewpoint_cam = viewpoint_stack.pop()\n        render_pkg = count_render(viewpoint_cam, gaussians, pipe, background)\n        gaussians_count, important_score = (\n            render_pkg[\"gaussians_count\"].detach(),\n            render_pkg[\"important_score\"].detach(),\n        )\n        gaussian_list += gaussians_count\n        imp_list += important_score\n        \n    return gaussian_list, imp_list\n\n\nv_render = visi_acc_render\ndef get_visi_list(gaussians, viewpoint_stack, pipe, background):\n    out = {}\n    gaussian_list = None\n    viewpoint_cam = viewpoint_stack.pop()\n    render_pkg 
= v_render(viewpoint_cam, gaussians, pipe, background)\n    gaussian_list = render_pkg[\"countlist\"]\n    \n    for i in range(len(viewpoint_stack)):\n        # Take the next camera off the stack and accumulate its per-Gaussian counts\n        viewpoint_cam = viewpoint_stack.pop()\n        render_pkg = v_render(viewpoint_cam, gaussians, pipe, background)\n        gaussians_count = render_pkg[\"countlist\"].detach()\n        gaussian_list += gaussians_count\n        \n    visi = gaussian_list > 0\n        \n    out[\"visi\"] = visi\n    return out\n\n"
  },
  {
    "path": "tools/render_utils.py",
    "content": "# Copyright 2022 Google LLC\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     https://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport numpy as np\nimport os\nfrom typing import Tuple\nimport copy\nfrom PIL import Image\nimport mediapy as media\nfrom matplotlib import cm\nfrom tqdm import tqdm\n\nimport torch\n\ndef normalize(x: np.ndarray) -> np.ndarray:\n  \"\"\"Normalization helper function.\"\"\"\n  return x / np.linalg.norm(x)\n\ndef pad_poses(p: np.ndarray) -> np.ndarray:\n  \"\"\"Pad [..., 3, 4] pose matrices with a homogeneous bottom row [0,0,0,1].\"\"\"\n  bottom = np.broadcast_to([0, 0, 0, 1.], p[..., :1, :4].shape)\n  return np.concatenate([p[..., :3, :4], bottom], axis=-2)\n\n\ndef unpad_poses(p: np.ndarray) -> np.ndarray:\n  \"\"\"Remove the homogeneous bottom row from [..., 4, 4] pose matrices.\"\"\"\n  return p[..., :3, :4]\n\n\ndef recenter_poses(poses: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:\n  \"\"\"Recenter poses around the origin.\"\"\"\n  cam2world = average_pose(poses)\n  transform = np.linalg.inv(pad_poses(cam2world))\n  poses = transform @ pad_poses(poses)\n  return unpad_poses(poses), transform\n\n\ndef average_pose(poses: np.ndarray) -> np.ndarray:\n  \"\"\"New pose using average position, z-axis, and up vector of input poses.\"\"\"\n  position = poses[:, :3, 3].mean(0)\n  z_axis = poses[:, :3, 2].mean(0)\n  up = poses[:, :3, 1].mean(0)\n  cam2world = viewmatrix(z_axis, up, position)\n  return cam2world\n\ndef viewmatrix(lookdir: np.ndarray, up: np.ndarray,\n   
            position: np.ndarray) -> np.ndarray:\n  \"\"\"Construct lookat view matrix.\"\"\"\n  vec2 = normalize(lookdir)\n  vec0 = normalize(np.cross(up, vec2))\n  vec1 = normalize(np.cross(vec2, vec0))\n  m = np.stack([vec0, vec1, vec2, position], axis=1)\n  return m\n\ndef focus_point_fn(poses: np.ndarray) -> np.ndarray:\n  \"\"\"Calculate nearest point to all focal axes in poses.\"\"\"\n  directions, origins = poses[:, :3, 2:3], poses[:, :3, 3:4]\n  m = np.eye(3) - directions * np.transpose(directions, [0, 2, 1])\n  mt_m = np.transpose(m, [0, 2, 1]) @ m\n  focus_pt = np.linalg.inv(mt_m.mean(0)) @ (mt_m @ origins).mean(0)[:, 0]\n  return focus_pt\n\ndef transform_poses_pca(poses: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:\n  \"\"\"Transforms poses so principal components lie on XYZ axes.\n\n  Args:\n    poses: a (N, 3, 4) array containing the cameras' camera to world transforms.\n\n  Returns:\n    A tuple (poses, transform), with the transformed poses and the applied\n    camera_to_world transforms.\n  \"\"\"\n  t = poses[:, :3, 3]\n  t_mean = t.mean(axis=0)\n  t = t - t_mean\n\n  eigval, eigvec = np.linalg.eig(t.T @ t)\n  # Sort eigenvectors in order of largest to smallest eigenvalue.\n  inds = np.argsort(eigval)[::-1]\n  eigvec = eigvec[:, inds]\n  rot = eigvec.T\n  if np.linalg.det(rot) < 0:\n    rot = np.diag(np.array([1, 1, -1])) @ rot\n\n  transform = np.concatenate([rot, rot @ -t_mean[:, None]], -1)\n  poses_recentered = unpad_poses(transform @ pad_poses(poses))\n  transform = np.concatenate([transform, np.eye(4)[3:]], axis=0)\n\n  # Flip coordinate system if z component of y-axis is negative\n  if poses_recentered.mean(axis=0)[2, 1] < 0:\n    poses_recentered = np.diag(np.array([1, -1, -1])) @ poses_recentered\n    transform = np.diag(np.array([1, -1, -1, 1])) @ transform\n\n  return poses_recentered, transform\n  \n\ndef generate_ellipse_path(poses: np.ndarray,\n                          n_frames: int = 120,\n                          const_speed: 
bool = True,\n                          z_variation: float = 0.,\n                          z_phase: float = 0.) -> np.ndarray:\n  \"\"\"Generate an elliptical render path based on the given poses.\"\"\"\n  # Calculate the focal point for the path (cameras point toward this).\n  center = focus_point_fn(poses)\n  # Path height sits at z=0 (in middle of zero-mean capture pattern).\n  offset = np.array([center[0], center[1], 0])\n\n  # Calculate scaling for ellipse axes based on input camera positions.\n  sc = np.percentile(np.abs(poses[:, :3, 3] - offset), 90, axis=0)\n  # Use ellipse that is symmetric about the focal point in xy.\n  low = -sc + offset\n  high = sc + offset\n  # Optional height variation need not be symmetric\n  z_low = np.percentile((poses[:, :3, 3]), 10, axis=0)\n  z_high = np.percentile((poses[:, :3, 3]), 90, axis=0)\n\n  def get_positions(theta):\n    # Interpolate between bounds with trig functions to get ellipse in x-y.\n    # Optionally also interpolate in z to change camera height along path.\n    return np.stack([\n        low[0] + (high - low)[0] * (np.cos(theta) * .5 + .5),\n        low[1] + (high - low)[1] * (np.sin(theta) * .5 + .5),\n        z_variation * (z_low[2] + (z_high - z_low)[2] *\n                       (np.cos(theta + 2 * np.pi * z_phase) * .5 + .5)),\n    ], -1)\n\n  theta = np.linspace(0, 2. 
* np.pi, n_frames + 1, endpoint=True)\n  positions = get_positions(theta)\n\n  # Throw away duplicated last position.\n  positions = positions[:-1]\n\n  # Set path's up vector to axis closest to average of input pose up vectors.\n  avg_up = poses[:, :3, 1].mean(0)\n  avg_up = avg_up / np.linalg.norm(avg_up)\n  ind_up = np.argmax(np.abs(avg_up))\n  up = np.eye(3)[ind_up] * np.sign(avg_up[ind_up])\n\n  return np.stack([viewmatrix(p - center, up, p) for p in positions])\n\n\ndef generate_path(viewpoint_cameras, n_frames=480):\n  c2ws = np.array([np.linalg.inv(np.asarray((cam.world_view_transform.T).cpu().numpy())) for cam in viewpoint_cameras])\n  pose = c2ws[:,:3,:] @ np.diag([1, -1, -1, 1])\n  pose_recenter, colmap_to_world_transform = transform_poses_pca(pose)\n\n  # generate new poses\n  new_poses = generate_ellipse_path(poses=pose_recenter, n_frames=n_frames)\n  # warp back to original scale\n  new_poses = np.linalg.inv(colmap_to_world_transform) @ pad_poses(new_poses)\n\n  traj = []\n  for c2w in new_poses:\n      c2w = c2w @ np.diag([1, -1, -1, 1])\n      cam = copy.deepcopy(viewpoint_cameras[0])\n      cam.image_height = int(cam.image_height / 2) * 2\n      cam.image_width = int(cam.image_width / 2) * 2\n      cam.world_view_transform = torch.from_numpy(np.linalg.inv(c2w).T).float().cuda()\n      cam.full_proj_transform = (cam.world_view_transform.unsqueeze(0).bmm(cam.projection_matrix.unsqueeze(0))).squeeze(0)\n      cam.camera_center = cam.world_view_transform.inverse()[3, :3]\n      traj.append(cam)\n\n  return traj\n\ndef load_img(pth: str) -> np.ndarray:\n  \"\"\"Load an image and cast to float32.\"\"\"\n  with open(pth, 'rb') as f:\n    image = np.array(Image.open(f), dtype=np.float32)\n  return image\n\n\ndef create_videos(base_dir, input_dir, out_name, num_frames=480):\n  \"\"\"Creates videos out of the images saved to disk.\"\"\"\n  # Output video files are prefixed with out_name.\n  video_prefix = f'{out_name}'\n\n  zpad = 
max(5, len(str(num_frames - 1)))\n  idx_to_str = lambda idx: str(idx).zfill(zpad)\n\n  os.makedirs(base_dir, exist_ok=True)\n  render_dist_curve_fn = np.log\n  \n  # Load one example frame to get image shape and depth range.\n  depth_file = os.path.join(input_dir, 'vis', f'depth_{idx_to_str(0)}.tiff')\n  depth_frame = load_img(depth_file)\n  shape = depth_frame.shape\n  p = 3\n  distance_limits = np.percentile(depth_frame.flatten(), [p, 100 - p])\n  lo, hi = [render_dist_curve_fn(x) for x in distance_limits]\n  print(f'Video shape is {shape[:2]}')\n\n  video_kwargs = {\n      'shape': shape[:2],\n      'codec': 'h264',\n      'fps': 60,\n      'crf': 18,\n  }\n  \n  for k in ['depth', 'normal', 'color']:\n    video_file = os.path.join(base_dir, f'{video_prefix}_{k}.mp4')\n    input_format = 'gray' if k == 'alpha' else 'rgb'\n    \n\n    file_ext = 'png' if k in ['color', 'normal'] else 'tiff'\n    idx = 0\n\n    if k == 'color':\n      file0 = os.path.join(input_dir, 'renders', f'{idx_to_str(0)}.{file_ext}')\n    else:\n      file0 = os.path.join(input_dir, 'vis', f'{k}_{idx_to_str(0)}.{file_ext}')\n\n    if not os.path.exists(file0):\n      print(f'Images missing for tag {k}')\n      continue\n    print(f'Making video {video_file}...')\n    with media.VideoWriter(\n        video_file, **video_kwargs, input_format=input_format) as writer:\n      for idx in tqdm(range(num_frames)):\n        if k == 'color':\n          img_file = os.path.join(input_dir, 'renders', f'{idx_to_str(idx)}.{file_ext}')\n        else:\n          img_file = os.path.join(input_dir, 'vis', f'{k}_{idx_to_str(idx)}.{file_ext}')\n\n        if not os.path.exists(img_file):\n          raise ValueError(f'Image file {img_file} does not exist.')\n        img = load_img(img_file)\n        if k in ['color', 'normal']:\n          img = img / 255.\n        elif k.startswith('depth'):\n          img = render_dist_curve_fn(img)\n          img = np.clip((img - np.minimum(lo, hi)) / np.abs(hi - lo), 0, 1)\n         
 img = cm.get_cmap('turbo')(img)[..., :3]\n\n        frame = (np.clip(np.nan_to_num(img), 0., 1.) * 255.).astype(np.uint8)\n        writer.add_image(frame)\n        idx += 1\n\ndef save_img_u8(img, pth):\n  \"\"\"Save an image (probably RGB) in [0, 1] to disk as a uint8 PNG.\"\"\"\n  with open(pth, 'wb') as f:\n    Image.fromarray(\n        (np.clip(np.nan_to_num(img), 0., 1.) * 255.).astype(np.uint8)).save(\n            f, 'PNG')\n\n\ndef save_img_f32(depthmap, pth):\n  \"\"\"Save an image (probably a depthmap) to disk as a float32 TIFF.\"\"\"\n  with open(pth, 'wb') as f:\n    Image.fromarray(np.nan_to_num(depthmap).astype(np.float32)).save(f, 'TIFF')"
  },
  {
    "path": "tools/semantic_id.py",
    "content": "\nBACKGROUND = 0\ntext_label_dict = {\n    'window': BACKGROUND,\n    'sky': BACKGROUND,\n    'sky window': BACKGROUND,\n    'window sky': BACKGROUND,\n    'floor': 2,\n}\n"
  },
  {
    "path": "tools/sh_utils.py",
    "content": "#  Copyright 2021 The PlenOctree Authors.\n#  Redistribution and use in source and binary forms, with or without\n#  modification, are permitted provided that the following conditions are met:\n#\n#  1. Redistributions of source code must retain the above copyright notice,\n#  this list of conditions and the following disclaimer.\n#\n#  2. Redistributions in binary form must reproduce the above copyright notice,\n#  this list of conditions and the following disclaimer in the documentation\n#  and/or other materials provided with the distribution.\n#\n#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE\n#  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE\n#  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR\n#  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF\n#  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS\n#  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN\n#  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)\n#  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE\n#  POSSIBILITY OF SUCH DAMAGE.\n\nimport torch\n\nC0 = 0.28209479177387814\nC1 = 0.4886025119029199\nC2 = [\n    1.0925484305920792,\n    -1.0925484305920792,\n    0.31539156525252005,\n    -1.0925484305920792,\n    0.5462742152960396\n]\nC3 = [\n    -0.5900435899266435,\n    2.890611442640554,\n    -0.4570457994644658,\n    0.3731763325901154,\n    -0.4570457994644658,\n    1.445305721320277,\n    -0.5900435899266435\n]\nC4 = [\n    2.5033429417967046,\n    -1.7701307697799304,\n    0.9461746957575601,\n    -0.6690465435572892,\n    0.10578554691520431,\n    -0.6690465435572892,\n    0.47308734787878004,\n    -1.7701307697799304,\n    
0.6258357354491761,\n]   \n\n\ndef eval_sh(deg, sh, dirs):\n    \"\"\"\n    Evaluate spherical harmonics at unit directions\n    using hardcoded SH polynomials.\n    Works with torch/np/jnp.\n    ... Can be 0 or more batch dimensions.\n    Args:\n        deg: int SH deg. Currently, 0-4 supported\n        sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]\n        dirs: jnp.ndarray unit directions [..., 3]\n    Returns:\n        [..., C]\n    \"\"\"\n    assert deg <= 4 and deg >= 0\n    coeff = (deg + 1) ** 2\n    assert sh.shape[-1] >= coeff\n\n    result = C0 * sh[..., 0]\n    if deg > 0:\n        x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]\n        result = (result -\n                C1 * y * sh[..., 1] +\n                C1 * z * sh[..., 2] -\n                C1 * x * sh[..., 3])\n\n        if deg > 1:\n            xx, yy, zz = x * x, y * y, z * z\n            xy, yz, xz = x * y, y * z, x * z\n            result = (result +\n                    C2[0] * xy * sh[..., 4] +\n                    C2[1] * yz * sh[..., 5] +\n                    C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +\n                    C2[3] * xz * sh[..., 7] +\n                    C2[4] * (xx - yy) * sh[..., 8])\n\n            if deg > 2:\n                result = (result +\n                C3[0] * y * (3 * xx - yy) * sh[..., 9] +\n                C3[1] * xy * z * sh[..., 10] +\n                C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] +\n                C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +\n                C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +\n                C3[5] * z * (xx - yy) * sh[..., 14] +\n                C3[6] * x * (xx - 3 * yy) * sh[..., 15])\n\n                if deg > 3:\n                    result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +\n                            C4[1] * yz * (3 * xx - yy) * sh[..., 17] +\n                            C4[2] * xy * (7 * zz - 1) * sh[..., 18] +\n                            C4[3] * yz * (7 * zz - 
3) * sh[..., 19] +\n                            C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +\n                            C4[5] * xz * (7 * zz - 3) * sh[..., 21] +\n                            C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +\n                            C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +\n                            C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])\n    return result\n\ndef RGB2SH(rgb):\n    return (rgb - 0.5) / C0\n\ndef SH2RGB(sh):\n    return sh * C0 + 0.5"
  },
  {
    "path": "tools/system_utils.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nfrom errno import EEXIST\nfrom os import makedirs, path\nimport os\n\ndef mkdir_p(folder_path):\n    # Create a directory; equivalent to running `mkdir -p` on the command line.\n    try:\n        makedirs(folder_path)\n    except OSError as exc: # Python >2.5\n        if exc.errno == EEXIST and path.isdir(folder_path):\n            pass\n        else:\n            raise\n\ndef searchForMaxIteration(folder):\n    # Each entry in `folder` is expected to end in \"_<integer>\"; return the largest such integer.\n    saved_iters = [int(fname.split(\"_\")[-1]) for fname in os.listdir(folder)]\n    return max(saved_iters)\n"
  },
  {
    "path": "tools/termcolor.py",
    "content": "'''\n-----------------------------------------------------------------------------\nCopyright (c) 2023, NVIDIA CORPORATION. All rights reserved.\n\nNVIDIA CORPORATION and its licensors retain all intellectual property\nand proprietary rights in and to this software, related documentation\nand any modifications thereto. Any use, reproduction, disclosure or\ndistribution of this software and related documentation without an express\nlicense agreement from NVIDIA CORPORATION is strictly prohibited.\n-----------------------------------------------------------------------------\n'''\n\nimport pprint\n\nimport termcolor\n\n\ndef red(x): return termcolor.colored(str(x), color=\"red\")\ndef green(x): return termcolor.colored(str(x), color=\"green\")\ndef blue(x): return termcolor.colored(str(x), color=\"blue\")\ndef cyan(x): return termcolor.colored(str(x), color=\"cyan\")\ndef yellow(x): return termcolor.colored(str(x), color=\"yellow\")\ndef magenta(x): return termcolor.colored(str(x), color=\"magenta\")\ndef grey(x): return termcolor.colored(str(x), color=\"grey\")\n\n\nCOLORS = {\n    'red': red, 'green': green, 'blue': blue, 'cyan': cyan, 'yellow': yellow, 'magenta': magenta, 'grey': grey\n}\n\n\ndef PP(x):\n    string = pprint.pformat(x, indent=2)\n    if isinstance(x, dict):\n        string = '{\\n ' + string[1:-1] + '\\n}'\n    return string\n\n\ndef alert(x, color='red'):\n    color = COLORS[color]\n    print(color('-' * 32))\n    print(color(f'* {x}'))\n    print(color('-' * 32))\n"
  },
  {
    "path": "tools/visualization.py",
    "content": "import wandb\nimport imageio\nimport torch\nimport torchvision\n\nfrom matplotlib import pyplot as plt\nfrom torchvision.transforms import functional as torchvision_F\n\n\nPALETTE = [\n            (0, 0, 0),\n            (174, 199, 232), (152, 223, 138), (31, 119, 180), (255, 187, 120), (188, 189, 34),\n            (140, 86, 75), (255, 152, 150), (214, 39, 40), (197, 176, 213), (148, 103, 189),\n            (196, 156, 148), (23, 190, 207), (247, 182, 210), (219, 219, 141), (255, 127, 14),\n            (158, 218, 229), (44, 160, 44), (112, 128, 144), (227, 119, 194), (82, 84, 163),\n        ]\nPALETTE = torch.tensor(PALETTE, dtype=torch.uint8)\n\n\ndef wandb_image(images, from_range=(0, 1)):\n    images = preprocess_image(images, from_range=from_range)\n    wandb_image = wandb.Image(images)\n    return wandb_image\n\n\ndef preprocess_image(images, from_range=(0, 1), cmap=\"viridis\"):\n    # Normalize from `from_range` to [0, 1]; the names avoid shadowing the min/max builtins.\n    vmin, vmax = from_range\n    images = (images - vmin) / (vmax - vmin)\n    images = images.detach().cpu().float().clamp_(min=0, max=1)\n    if images.shape[0] == 1:\n        images = get_heatmap(images, cmap=cmap)\n    images = tensor2pil(images)\n    return images\n\n\ndef wandb_sem(image, palette=PALETTE):\n    image = image.detach().long().cpu()\n    # Index the palette passed by the caller (defaults to the module-level PALETTE).\n    image = palette[image].float().permute(2, 0, 1)[None]\n    image = tensor2pil(image)\n    wandb_image = wandb.Image(image)\n    return wandb_image\n\n\ndef tensor2pil(images):\n    image_grid = torchvision.utils.make_grid(images, nrow=1, pad_value=1)\n    image_grid = torchvision_F.to_pil_image(image_grid)\n    return image_grid\n\n\ndef get_heatmap(gray, cmap):  # [N,H,W]\n    color = plt.get_cmap(cmap)(gray.numpy())\n    color = torch.from_numpy(color[..., :3]).permute(0, 3, 1, 2).float()  # [N,3,H,W]\n    return color\n\n\ndef save_render(render, path):\n    image = torch.clamp(render, 0.0, 1.0).detach().cpu()\n    image = (image.permute(1, 2, 0).numpy() * 255).astype('uint8') # [..., ::-1]\n    imageio.imsave(path, 
image)\n\n\n"
  },
  {
    "path": "tools/visualize.py",
    "content": "'''\n-----------------------------------------------------------------------------\nCopyright (c) 2023, NVIDIA CORPORATION. All rights reserved.\n\nNVIDIA CORPORATION and its licensors retain all intellectual property\nand proprietary rights in and to this software, related documentation\nand any modifications thereto. Any use, reproduction, disclosure or\ndistribution of this software and related documentation without an express\nlicense agreement from NVIDIA CORPORATION is strictly prohibited.\n-----------------------------------------------------------------------------\n'''\n\nimport numpy as np\nimport torch\nimport matplotlib.pyplot as plt\nimport plotly.graph_objs as go\nimport k3d\n\nfrom tools import camera\n\n\ndef get_camera_mesh(pose, depth=1):\n    vertices = torch.tensor([[-0.5, -0.5, 1],\n                             [0.5, -0.5, 1],\n                             [0.5, 0.5, 1],\n                             [-0.5, 0.5, 1],\n                             [0, 0, 0]]) * depth  # [5,3]\n    faces = torch.tensor([[0, 1, 2],\n                          [0, 2, 3],\n                          [0, 1, 4],\n                          [1, 2, 4],\n                          [2, 3, 4],\n                          [3, 0, 4]])  # [6,3]\n    vertices = camera.cam2world(vertices[None], pose)  # [N,5,3]\n    wireframe = vertices[:, [0, 1, 2, 3, 0, 4, 1, 2, 4, 3]]  # [N,10,3]\n    return vertices, faces, wireframe\n\n\ndef merge_meshes(vertices, faces):\n    mesh_N, vertex_N = vertices.shape[:2]\n    faces_merged = torch.cat([faces + i * vertex_N for i in range(mesh_N)], dim=0)\n    vertices_merged = vertices.view(-1, vertices.shape[-1])\n    return vertices_merged, faces_merged\n\n\ndef merge_wireframes_k3d(wireframe):\n    wf_first, wf_last, wf_dummy = wireframe[:, :1], wireframe[:, -1:], wireframe[:, :1] * np.nan\n    wireframe_merged = torch.cat([wf_first, wireframe, wf_last, wf_dummy], dim=1)\n    return wireframe_merged\n\n\ndef 
merge_wireframes_plotly(wireframe):\n    wf_dummy = wireframe[:, :1] * np.nan\n    wireframe_merged = torch.cat([wireframe, wf_dummy], dim=1).view(-1, 3)\n    return wireframe_merged\n\n\ndef get_xyz_indicators(pose, length=0.1):\n    xyz = torch.eye(4, 3)[None] * length\n    xyz = camera.cam2world(xyz, pose)\n    return xyz\n\n\ndef merge_xyz_indicators_k3d(xyz):  # [N,4,3]\n    xyz = xyz[:, [[-1, 0], [-1, 1], [-1, 2]]]  # [N,3,2,3]\n    xyz_0, xyz_1 = xyz.unbind(dim=2)  # [N,3,3]\n    xyz_dummy = xyz_0 * np.nan\n    xyz_merged = torch.stack([xyz_0, xyz_0, xyz_1, xyz_1, xyz_dummy], dim=2)  # [N,3,5,3]\n    return xyz_merged\n\n\ndef merge_xyz_indicators_plotly(xyz):  # [N,4,3]\n    xyz = xyz[:, [[-1, 0], [-1, 1], [-1, 2]]]  # [N,3,2,3]\n    xyz_0, xyz_1 = xyz.unbind(dim=2)  # [N,3,3]\n    xyz_dummy = xyz_0 * np.nan\n    xyz_merged = torch.stack([xyz_0, xyz_1, xyz_dummy], dim=2)  # [N,3,3,3]\n    xyz_merged = xyz_merged.view(-1, 3)\n    return xyz_merged\n\n\ndef k3d_visualize_pose(poses, vis_depth=0.5, xyz_length=0.1, center_size=0.1, xyz_width=0.02, mesh_opacity=0.05):\n    # poses has shape [N,3,4] potentially in sequential order\n    N = len(poses)\n    centers_cam = torch.zeros(N, 1, 3)\n    centers_world = camera.cam2world(centers_cam, poses)\n    centers_world = centers_world[:, 0]\n    # Get the camera wireframes.\n    vertices, faces, wireframe = get_camera_mesh(poses, depth=vis_depth)\n    xyz = get_xyz_indicators(poses, length=xyz_length)\n    vertices_merged, faces_merged = merge_meshes(vertices, faces)\n    wireframe_merged = merge_wireframes_k3d(wireframe)\n    xyz_merged = merge_xyz_indicators_k3d(xyz)\n    # Set the color map for the camera trajectory and the xyz indicators.\n    color_map = plt.get_cmap(\"gist_rainbow\")\n    center_color = []\n    vertices_merged_color = []\n    wireframe_color = []\n    xyz_color = []\n    x_hex, y_hex, z_hex = int(255) << 16, int(255) << 8, int(255)\n    for i in range(N):\n        # Set the camera pose colors 
(with a smooth gradient color map).\n        # max(N - 1, 1) avoids a division by zero when there is only one pose.\n        r, g, b, _ = color_map(i / max(N - 1, 1))\n        r, g, b = r * 0.8, g * 0.8, b * 0.8\n        pose_rgb_hex = (int(r * 255) << 16) + (int(g * 255) << 8) + int(b * 255)\n        center_color += [pose_rgb_hex]\n        vertices_merged_color += [pose_rgb_hex] * 5\n        wireframe_color += [pose_rgb_hex] * 13\n        # Set the xyz indicator colors.\n        xyz_color += [x_hex] * 5 + [y_hex] * 5 + [z_hex] * 5\n    # Plot in K3D.\n    k3d_objects = [\n        k3d.points(centers_world, colors=center_color, point_size=center_size, shader=\"3d\"),\n        k3d.mesh(vertices_merged, faces_merged, colors=vertices_merged_color, side=\"double\", opacity=mesh_opacity),\n        k3d.line(wireframe_merged, colors=wireframe_color, shader=\"simple\"),\n        k3d.line(xyz_merged, colors=xyz_color, shader=\"thick\", width=xyz_width),\n    ]\n    return k3d_objects\n\n\ndef plotly_visualize_pose(poses, vis_depth=0.5, xyz_length=0.5, center_size=2, xyz_width=5, mesh_opacity=0.05):\n    # poses has shape [N,3,4] potentially in sequential order\n    N = len(poses)\n    centers_cam = torch.zeros(N, 1, 3)\n    centers_world = camera.cam2world(centers_cam, poses)\n    centers_world = centers_world[:, 0]\n    # Get the camera wireframes.\n    vertices, faces, wireframe = get_camera_mesh(poses, depth=vis_depth)\n    xyz = get_xyz_indicators(poses, length=xyz_length)\n    vertices_merged, faces_merged = merge_meshes(vertices, faces)\n    wireframe_merged = merge_wireframes_plotly(wireframe)\n    xyz_merged = merge_xyz_indicators_plotly(xyz)\n    # Break up (x,y,z) coordinates.\n    wireframe_x, wireframe_y, wireframe_z = wireframe_merged.unbind(dim=-1)\n    xyz_x, xyz_y, xyz_z = xyz_merged.unbind(dim=-1)\n    centers_x, centers_y, centers_z = centers_world.unbind(dim=-1)\n    vertices_x, vertices_y, vertices_z = vertices_merged.unbind(dim=-1)\n    # Set the color map for the camera trajectory and the xyz indicators.\n    color_map = 
plt.get_cmap(\"gist_rainbow\")\n    center_color = []\n    faces_merged_color = []\n    wireframe_color = []\n    xyz_color = []\n    x_color, y_color, z_color = *np.eye(3).T,\n    for i in range(N):\n        # Set the camera pose colors (with a smooth gradient color map).\n        # max(N - 1, 1) avoids a division by zero when there is only one pose.\n        r, g, b, _ = color_map(i / max(N - 1, 1))\n        rgb = np.array([r, g, b]) * 0.8\n        wireframe_color += [rgb] * 11\n        center_color += [rgb]\n        faces_merged_color += [rgb] * 6\n        xyz_color += [x_color] * 3 + [y_color] * 3 + [z_color] * 3\n    # Plot in plotly.\n    plotly_traces = [\n        go.Scatter3d(x=wireframe_x, y=wireframe_y, z=wireframe_z, mode=\"lines\",\n                     line=dict(color=wireframe_color, width=1)),\n        go.Scatter3d(x=xyz_x, y=xyz_y, z=xyz_z, mode=\"lines\", line=dict(color=xyz_color, width=xyz_width)),\n        go.Scatter3d(x=centers_x, y=centers_y, z=centers_z, mode=\"markers\",\n                     marker=dict(color=center_color, size=center_size, opacity=1)),\n        go.Mesh3d(x=vertices_x, y=vertices_y, z=vertices_z,\n                  i=[f[0] for f in faces_merged], j=[f[1] for f in faces_merged], k=[f[2] for f in faces_merged],\n                  facecolor=faces_merged_color, opacity=mesh_opacity),\n    ]\n    return plotly_traces\n"
  },
  {
    "path": "train.py",
    "content": "import os\nimport sys\nimport argparse\nsys.path.append(os.getcwd())\n\nfrom configs.config import Config, recursive_update_strict, parse_cmdline_arguments\nfrom trainer import Trainer\n\n\ndef parse_args():\n    parser = argparse.ArgumentParser(description='Training')\n    parser.add_argument('--config', help='Path to the training config file.', required=True)\n    parser.add_argument('--wandb', action='store_true', help=\"Enable using Weights & Biases as the logger\")\n    parser.add_argument('--wandb_name', default='default', type=str)\n    args, cfg_cmd = parser.parse_known_args()\n    return args, cfg_cmd\n\n\ndef main():\n    args, cfg_cmd = parse_args()\n    cfg = Config(args.config)\n\n    cfg_cmd = parse_cmdline_arguments(cfg_cmd)\n    recursive_update_strict(cfg, cfg_cmd)\n    \n    trainer = Trainer(cfg)\n    cfg.save_config(cfg.logdir)\n    \n    trainer.init_wandb(cfg,\n                       project=args.wandb_name,\n                       mode=\"disabled\" if cfg.train.debug_from > -1 or not args.wandb else \"online\",\n                       use_group=True)\n    \n    trainer.train()\n    trainer.finalize()\n    \n    return\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "trainer.py",
    "content": "import os\nimport json\nimport uuid\nimport math\nimport wandb\nimport imageio\nimport numpy as np\nfrom torch import nn\nfrom tqdm import tqdm\nimport math\nfrom random import randint\nimport torch.nn.functional as F\nfrom argparse import Namespace\nfrom pytorch3d.ops import knn_points\nfrom torchmetrics import JaccardIndex\n\nimport torch\nimport matplotlib.pyplot as plt\nfrom copy import deepcopy\n\nfrom tools.loss_utils import l1_loss, ssim, cos_weight, entropy_loss, monosdf_normal_loss, ScaleAndShiftInvariantLoss\nfrom gaussian_renderer import render, network_gui\nfrom scene import Scene, GaussianModel\nfrom tools.image_utils import psnr\nfrom configs.config import Config\nfrom tools.visualization import wandb_image, preprocess_image, wandb_sem\nfrom tools.prune import prune_list, calculate_v_imp_score, get_visi_list\nfrom tools.loss_utils import compute_normal_loss, L1_loss_appearance, normal2curv\nfrom tools.camera_utils import bb_camera\nfrom tools.general_utils import safe_state, set_random_seed\nfrom scene.cameras import SampleCam\nfrom tools.normal_utils import get_normal_sign, get_edge_aware_distortion_map\n# from process_data.extract_mask import text_label_dict\n\ntry:\n    from torch.utils.tensorboard import SummaryWriter\n    TENSORBOARD_FOUND = True\nexcept ImportError:\n    TENSORBOARD_FOUND = False\n\n\nclass Trainer(object):\n    def __init__(self, cfg):\n        self.cfg = cfg\n        set_random_seed(cfg.seed)\n        cfg.model.model_path = cfg.logdir\n        self.sphere = getattr(cfg.model, 'sphere', False)\n        cfg.model.load_normal = cfg.optim.loss_weight.mono_normal > 0 \\\n                                    or cfg.optim.loss_weight.depth_normal > 0\n        cfg.model.load_depth = cfg.optim.loss_weight.mono_depth > 0\n        self.enable_semantic = getattr(cfg.optim.loss_weight, 'semantic', 0) > 0\n        cfg.model.enable_semantic = self.enable_semantic\n        cfg.model.load_mask = self.enable_semantic or 
cfg.model.load_mask\n        cfg.print_config()\n        safe_state(cfg.silent)\n        \n        self.setup_model(cfg.model)\n        self.setup_dataset(cfg.model)\n        self.setup_optimizer(cfg.optim)\n        self.init_attributes()\n        self.init_losses()\n        \n        # Start GUI server, configure and run training\n        if cfg.port > 0:\n            network_gui.init(cfg.ip, cfg.port)\n        torch.autograd.set_detect_anomaly(cfg.detect_anomaly)\n\n    def setup_model(self, cfg):\n        self.model = GaussianModel(cfg)\n\n    def setup_dataset(self, cfg):\n        os.makedirs(cfg.model_path, exist_ok = True)\n        self.scene = Scene(cfg, self.model)\n        self.model.trans = torch.from_numpy(self.scene.trans).cuda()\n        self.model.scale = torch.from_numpy(self.scene.scale).cuda()\n        self.model.extent = self.scene.cameras_extent\n    \n    def init_writer(self, cfg):\n        if not cfg.model.model_path:\n            if os.getenv('OAR_JOB_ID'):\n                unique_str=os.getenv('OAR_JOB_ID')\n            else:\n                unique_str = str(uuid.uuid4())\n            cfg.model.model_path = os.path.join(\"./output/\", unique_str[0:10])\n            \n        # Set up output folder\n        print(\"Output folder: {}\".format(cfg.model.model_path))\n        os.makedirs(cfg.model.model_path, exist_ok = True)\n        with open(os.path.join(cfg.model.model_path, \"cfg_args\"), 'w') as cfg_log_f:\n            cfg_log_f.write(str(Namespace(**vars(cfg))))\n\n        # Create Tensorboard writer\n        if TENSORBOARD_FOUND:\n            self.writer = SummaryWriter(cfg.model.model_path)\n        else:\n            print(\"Tensorboard not available: not logging progress\")\n    \n    def init_wandb(self, cfg, wandb_id=None, project=\"\", run_name=None, mode=\"online\", resume=\"allow\", use_group=False):\n        r\"\"\"Initialize Weights & Biases (wandb) logger.\n\n        Args:\n            cfg (obj): Global configuration.\n       
     wandb_id (str): A unique ID for this run, used for resuming.\n            project (str): The name of the project where you're sending the new run.\n                If the project is not specified, the run is put in an \"Uncategorized\" project.\n            run_name (str): name for each wandb run (useful for logging changes)\n            mode (str): online/offline/disabled\n        \"\"\"\n        print('Initialize wandb')\n        if not wandb_id:\n            wandb_path = os.path.join(cfg.logdir, \"wandb_id.txt\")\n            if os.path.exists(wandb_path):\n                with open(wandb_path, \"r\") as f:\n                    wandb_id = f.read()\n            else:\n                wandb_id = wandb.util.generate_id()\n                with open(wandb_path, \"w\") as f:\n                    f.write(wandb_id)\n        if use_group:\n            group, name = cfg.logdir.split(\"/\")[-2:]\n        else:\n            group, name = None, os.path.basename(cfg.logdir)\n\n        if run_name is not None:\n            name = run_name\n\n        wandb.init(id=wandb_id,\n                    project=project,\n                    config=cfg,\n                    group=group,\n                    name=name,\n                    dir=cfg.logdir,\n                    resume=resume,\n                    settings=wandb.Settings(start_method=\"fork\"),\n                    mode=mode)\n        wandb.config.update({'dataset': cfg.data.name})\n\n    def init_losses(self):\n        r\"\"\"Initialize loss functions. All loss names have weights. 
Some have criterion modules.\"\"\"\n        self.losses = dict()\n        \n        self.weights = {key: value for key, value in self.cfg.optim.loss_weight.items() if value}\n        \n        if 'mono_depth' in self.weights:\n            self.depth_loss = ScaleAndShiftInvariantLoss(alpha=0.5, scales=1)\n    \n    def setup_optimizer(self, cfg):\n        self.model.training_setup(cfg)\n    \n    def init_attributes(self):\n        self.iter_start = torch.cuda.Event(enable_timing = True)\n        self.iter_end = torch.cuda.Event(enable_timing = True)\n\n        self.viewpoint_stack = None\n        self.ema_loss_for_log = 0.0\n        \n        self.current_iteration = 0\n        self.max_iters = self.cfg.optim.iterations\n        self.saving_iterations = self.cfg.train.save_iterations\n        self.testing_iterations = self.cfg.train.test_iterations\n        self.checkpoint_iterations = self.cfg.train.checkpoint_iterations\n        \n        self.debug_from = self.cfg.train.debug_from\n        self.checkpoint = self.cfg.train.start_checkpoint\n        self.star_ft_iter = None\n        \n        self.visi_list = None\n        \n        self.first_iter = 0\n        if self.checkpoint:\n            (model_params, self.first_iter) = torch.load(self.checkpoint)\n            self.model.restore(model_params, self.cfg.optim)\n            \n        bg_color = [1, 1, 1] if self.cfg.model.white_background else [0, 0, 0]\n        self.background = torch.tensor(bg_color, dtype=torch.float32, device=\"cuda\")\n        self.writer = None\n        \n        with open(os.path.join(self.cfg.model.model_path, \"cfg_args\"), 'w') as cfg_log_f:\n            cfg_log_f.write(str(Namespace(**vars(self.cfg))))\n        \n        self.vis_path = os.path.join(self.cfg.logdir, \"vis\")\n        self.vis_color_path = os.path.join(self.vis_path, \"color\")\n        self.vis_depth_path = os.path.join(self.vis_path, \"depth\")\n        self.vis_normal_path = os.path.join(self.vis_path, 
\"normal\")\n        self.vis_dnormal_path = os.path.join(self.vis_path, \"dnormal\")\n        self.vis_cos_path = os.path.join(self.vis_path, \"cos\")\n        \n        for mode in ['train', 'test']:\n            os.makedirs(os.path.join(self.vis_color_path, mode), exist_ok=True)\n            os.makedirs(os.path.join(self.vis_depth_path, mode), exist_ok=True)\n            os.makedirs(os.path.join(self.vis_normal_path, mode), exist_ok=True)\n            os.makedirs(os.path.join(self.vis_dnormal_path, mode), exist_ok=True)\n            os.makedirs(os.path.join(self.vis_cos_path, mode), exist_ok=True)\n        \n        \n        if self.enable_semantic:\n            self.calc_miou = JaccardIndex(num_classes=self.model.num_cls, task='multiclass').cuda()\n    \n    def train(self):\n        progress_bar = tqdm(range(self.first_iter, self.max_iters), desc=\"Training progress\")\n        self.current_iteration += self.first_iter\n        self.first_iter += 1\n        for iteration in range(self.first_iter, self.max_iters  + 1):\n            self.current_iteration += 1\n            \n            self.start_of_iteration()\n            \n            output = self.train_step(mode='train')\n            \n            self.end_of_iteration(output, render, progress_bar)\n    \n    def get_center_scale(self):\n        meta_fname = f\"{self.cfg.model.source_path}/meta.json\"\n        with open(meta_fname) as file:\n            meta = json.load(file)\n        # center scene\n        trans = np.array(meta[\"trans\"], dtype=np.float32)\n        trans = torch.from_numpy(trans.astype(np.float32)).to(\"cuda\")\n        self.model.trans = torch.nn.parameter.Parameter(trans, requires_grad=False)\n        # scale scene\n        scale = np.array(meta[\"scale\"], dtype=np.float32)\n        scale = torch.from_numpy(scale.astype(np.float32)).to(\"cuda\")\n        self.model.scale = torch.nn.parameter.Parameter(scale, requires_grad=False)\n\n    def model_forward(self, data, mode):\n        
render_pkg = render(data['viewpoint_cam'], self.model, self.cfg, data.pop('bg'), dirs=self.scene.dirs)\n        data.update(render_pkg)\n        self._compute_loss(data, mode)\n        loss = self._get_total_loss()\n        \n        return loss\n    \n    def _compute_loss(self, data, mode=None):\n        if mode == 'train':\n            gt_image = data['viewpoint_cam'].original_image.cuda()\n            self.losses['l1'] = l1_loss(data['render'], gt_image) if not self.cfg.model.use_decoupled_appearance \\\n                else L1_loss_appearance(data['render'], gt_image, self.model, data['viewpoint_cam'].idx)\n            self.losses['ssim'] = 1.0 - ssim(data['render'], gt_image)\n            \n            if 'l1_scale' in self.weights or 'entropy' in self.weights or 'proj' in self.weights or 'repul' in self.weights:\n                mask, _ = self.model.get_inside_gaus_normalized()\n            \n            if 'l1_scale' in self.weights and not self.sphere:\n                scaling = self.model.get_scaling[mask].min(-1)[0]\n                self.losses['l1_scale'] = l1_loss(scaling, torch.zeros_like(scaling))\n            \n            if 'entropy' in self.weights:\n                opacity = self.model.get_opacity[mask]\n                self.losses['entropy'] = entropy_loss(opacity)\n            \n            if 'mono_depth' in self.weights:\n                render_depth = data['depth']\n                gt_depth = data['viewpoint_cam'].depth.cuda().float()\n                mask = None\n                if self.cfg.model.load_mask:\n                    mask = data['viewpoint_cam'].mask\n                \n                mask = render_depth > 0\n                self.losses['mono_depth'] = self.depth_loss(render_depth, gt_depth, mask)\n            \n            if 'mono_normal' in self.weights and self.current_iteration > self.cfg.optim.normal_from_iter:\n                render_normal = data['normal']\n                gt_normal = 
data['viewpoint_cam'].normal.cuda()\n                self.losses['mono_normal'] = monosdf_normal_loss(render_normal, gt_normal)\n            \n            if 'depth_normal' in self.weights and self.current_iteration > self.cfg.optim.dnormal_from_iter:\n                est_normal = data['est_normal']\n                gt_normal = data['viewpoint_cam'].normal.cuda()\n                render_normal = data['normal'].detach()\n                mask = data['mask']\n                \n                with torch.no_grad():\n                    weights = cos_weight(render_normal, gt_normal, self.cfg.optim.exp_t)\n                \n                if mask.sum() != 0:\n                    est_normal, gt_normal = est_normal[mask], gt_normal[mask]\n                    render_normal = render_normal[mask]\n                    weights = weights[mask]\n                    self.losses['depth_normal'] = monosdf_normal_loss(est_normal, gt_normal, weights)\n                else: self.losses['depth_normal'] = 0\n                \n                if 'curv' in self.weights and self.current_iteration > self.cfg.optim.curv_from_iter:\n                    est_normal = data['est_normal']         # h, w, 3\n                    mask = data['mask'][..., None].clone()          # h, w, 1\n                    mask = mask.float()\n                    curv = normal2curv(est_normal, mask)\n                    self.losses['curv'] = l1_loss(curv, 0)\n                \n            if 'consistent_normal' in self.weights and self.current_iteration > self.cfg.optim.consistent_normal_from_iter:\n                est_normal = data['est_normal']\n                render_normal = data['normal']\n                mask = data['mask']\n                self.losses['consistent_normal'] = monosdf_normal_loss(est_normal, render_normal)\n                \n            if 'distortion' in self.weights and self.current_iteration > self.cfg.optim.close_depth_from_iter:\n                distortion_map = data['distortion']\n         
       distortion_map = get_edge_aware_distortion_map(gt_image, distortion_map)\n                self.losses['distortion'] = distortion_map.mean()\n            \n            if 'depth_var' in self.weights and self.current_iteration > self.cfg.optim.close_depth_from_iter:\n                depth_var = data['depth_var']\n                depth_var = get_edge_aware_distortion_map(gt_image, depth_var)\n                self.losses['depth_var'] = depth_var.mean()\n                \n            if 'semantic' in self.weights:\n                sem_logits = data['render_sem']\n                sem_trg = data['viewpoint_cam'].mask.view(-1)\n                self.losses['semantic'] = F.cross_entropy(sem_logits.view(-1, self.model.num_cls), sem_trg) / torch.log(torch.tensor(self.model.num_cls)) # normalize to (0,1)\n    \n    def _get_total_loss(self):\n        r\"\"\"Return the total loss to be backpropagated.\n        \"\"\"\n        total_loss = torch.tensor(0., device=torch.device('cuda'))\n        \n        # Iterates over all possible losses.\n        for loss_name in self.weights:\n            if loss_name in self.losses:\n                # Multiply it with the corresponding weight and add it to the total loss.\n                total_loss += self.losses[loss_name] * self.weights[loss_name]\n        self.losses['total'] = total_loss  # logging purpose\n        return total_loss\n    \n    def train_step(self, mode='train'):\n        data = dict()\n        # Pick a random Camera\n        if not self.viewpoint_stack:\n            self.viewpoint_stack = self.scene.getTrainCameras().copy()\n        data['viewpoint_cam'] = self.viewpoint_stack.pop(randint(0, len(self.viewpoint_stack)-1))\n\n        # Render\n        if (self.current_iteration - 1) == self.debug_from:\n            self.cfg.pipline.debug = True\n\n        data['bg'] = torch.rand((3), device=\"cuda\") if self.cfg.optim.random_background else self.background\n\n        loss = self.model_forward(data, mode)\n        \n 
       loss.backward()\n        viewspace_point_tensor, visibility_filter, radii = data.pop(\"viewspace_points\"), data.pop(\"visibility_filter\"), data.pop(\"radii\")\n\n        with torch.no_grad():\n            # Densification\n            if self.current_iteration < self.cfg.optim.densify_until_iter:\n                # Keep track of max radii in image-space for pruning\n                self.model.max_radii2D[visibility_filter] = torch.max(self.model.max_radii2D[visibility_filter], radii[visibility_filter])\n                viewspace_point_tensor_densify = data[\"viewspace_points_densify\"]\n                self.model.add_densification_stats(viewspace_point_tensor_densify, visibility_filter)\n                # self.model.add_densification_stats(viewspace_point_tensor, visibility_filter)\n\n                if self.current_iteration > self.cfg.optim.densify_from_iter \\\n                    and hasattr(self.cfg.optim, 'densify_large'):\n                        \n                    if 'countlist' in data:\n                        visi_list_each = data['countlist']\n                        self.visi_list = visi_list_each if self.visi_list is None else self.visi_list + visi_list_each\n\n                if self.current_iteration > self.cfg.optim.densify_from_iter and self.current_iteration % self.cfg.optim.densification_interval == 0:\n                    size_threshold = 20 if self.current_iteration > self.cfg.optim.opacity_reset_interval else None\n                    visi = None\n                    \n                    if getattr(self.cfg.optim, 'densify_large', False) and self.cfg.optim.densify_large.sample_cams.num > 0 \\\n                        and getattr(self.cfg.optim.densify_large, 'percent_dense', 0):\n                        visi = self.get_visi_mask_acc(self.cfg.optim.densify_large.sample_cams.num,\n                                                self.cfg.optim.densify_large.sample_cams.up,\n                                                
self.cfg.optim.densify_large.sample_cams.around,\n                                                sample_mode='random')\n                        if self.visi_list is not None:\n                            # Parenthesized because `&` binds tighter than `>` in Python.\n                            visi = visi & (self.visi_list > 0)\n                    self.model.densify_and_prune(self.cfg.optim.densify_grad_threshold, 0.005, self.scene.cameras_extent, size_threshold, visi)\n                    self.visi_list = None\n                    \n                if self.current_iteration % self.cfg.optim.opacity_reset_interval == 0 or \\\n                    (self.cfg.model.white_background and self.current_iteration == self.cfg.optim.densify_from_iter):\n                    self.model.reset_opacity()\n\n            if self.current_iteration in self.cfg.optim.prune.iterations:\n                # TODO Add pruning types\n                n = int(len(self.scene.getFullCameras()) * 1.2)\n                viewpoint_stack = self.scene.getFullCameras().copy()\n                gaussian_list, imp_list = prune_list(self.model, viewpoint_stack, self.cfg.pipline, self.background)\n                i = self.cfg.optim.prune.iterations.index(self.current_iteration)\n                v_list = calculate_v_imp_score(self.model, imp_list, self.cfg.optim.prune.v_pow)\n                self.model.prune_gaussians(\n                    (self.cfg.optim.prune.decay**i) * self.cfg.optim.prune.percent, v_list\n                )\n            \n            \n            # Optimizer step\n            self.model.optimizer.step()\n            self.model.optimizer.zero_grad(set_to_none = True)\n        \n        return data\n\n    def start_of_iteration(self):\n        self.iter_start.record()\n        \n        # train or fine-tune\n        iter = self.current_iteration if self.star_ft_iter is None \\\n            else self.current_iteration - self.star_ft_iter\n        self.model.update_learning_rate(iter)\n\n        # Every 1000 iterations we increase the levels of SH up to a maximum degree\n        if 
self.current_iteration % 1000 == 0:\n            self.model.oneupSHdegree()\n        \n    def end_of_iteration(self, output, render, progress_bar):\n        self.iter_end.record()\n        \n        with torch.no_grad():\n            # Progress bar\n            self.ema_loss_for_log = 0.4 * self.losses['total'].item() + 0.6 * self.ema_loss_for_log\n            if self.current_iteration % 10 == 0:\n                progress_bar.set_postfix({\"Loss\": f\"{self.ema_loss_for_log:.{7}f}\"})\n                progress_bar.update(10)\n            if self.current_iteration == self.max_iters :\n                progress_bar.close()\n\n            # Log and save\n            if self.writer:\n                self.log_writer(output, mode=\"train\")\n            else:\n                output.update(self.test(render))\n                self.log_wandb_scalars(output, mode=\"train\")\n            \n            if (self.current_iteration in self.saving_iterations) or (self.current_iteration == self.max_iters):\n                self.save_gaussians()\n\n            if (self.current_iteration in self.checkpoint_iterations) or (self.current_iteration == self.max_iters):\n                print(\"\\n[ITER {}] Saving Checkpoint\".format(self.current_iteration))\n                torch.save((self.model.capture(), self.current_iteration), self.scene.model_path + \"/chkpnt\" + str(self.current_iteration) + \".pth\")\n                \n                if len(self.cfg.optim.prune.iterations) > 0 and self.current_iteration == self.max_iters:\n                    viewpoint_stack = self.scene.getFullCameras().copy()\n                    gaussian_list, imp_list = prune_list(self.model, viewpoint_stack, self.cfg.pipline, self.background)\n                    v_list = calculate_v_imp_score(self.model, imp_list, self.cfg.optim.prune.v_pow)\n                    np.savez(os.path.join(self.scene.model_path, \"imp_score\"), v_list.cpu().detach().numpy())\n    \n    def log_wandb_scalars(self, output, 
mode=None):\n        scalars = dict()\n        if mode == \"train\":\n            for param_group in self.model.optimizer.param_groups:\n                scalars.update({\"optim/lr_{}\".format(param_group[\"name\"]): param_group['lr']})\n                \n        scalars.update({\"time/iteration\": self.iter_start.elapsed_time(self.iter_end)})\n        scalars.update({f\"loss/{mode}_{key}\": value for key, value in self.losses.items()})\n        scalars.update(iteration=self.current_iteration)\n        \n        scalars.update({k: v for k, v in output.items() if isinstance(v, (int, float))})\n        \n        wandb.log(scalars, step=self.current_iteration)\n\n    def log_wandb_images(self, data, mode=None):\n        image = torch.cat([data[\"rgb_map\"], data[\"image\"]], dim=1)\n        depth = data[\"depth_map\"]\n        inv_depth = depth.max() - depth\n        images = {f'vis/{mode}': wandb_image(image),\n                  f'vis/{mode}_depth': wandb_image(depth, from_range=(depth.min(), depth.max())),\n                  f'vis/{mode}_inv_depth': wandb_image(inv_depth, from_range=(inv_depth.min(), inv_depth.max()))}\n        if 'depth_var' in data:\n            depth_var = data['depth_var']\n            images.update({f'vis/{mode}_depth_var': wandb_image(depth_var, from_range=(depth_var.min(), depth_var.max()))})\n        if 'depth' in data:\n            depth = data[\"depth\"].detach().clone()\n            images.update({f'vis/{mode}_depth_gt': wandb_image(depth, from_range=(depth.min(), depth.max()))})\n        if 'mask' in data:\n            mask = data['mask'].detach().clone().float()\n            images.update({f'vis/{mode}_mask': wandb_image(mask)})\n        if 'normal_map' in data:\n            normal_map = data[\"normal_map\"]\n            images.update({f'vis/{mode}_normal': wandb_image(normal_map.permute(2, 0, 1), from_range=(-1, 1))})\n            if 'normal' in data:\n                normal = data[\"normal\"].detach().clone()\n                
images.update({f'vis/{mode}_normal_gt': wandb_image(normal.permute(2, 0, 1), from_range=(-1, 1))})\n                cos = cos_weight(normal.cuda(), normal_map, self.cfg.optim.exp_t)\n                images.update({f'vis/{mode}_normal_cos': wandb_image(cos, from_range=(0, 1))})\n            if 'est_normal' in data:\n                est_normal = data[\"est_normal\"].permute(2, 0, 1).detach().clone()\n                images.update({f'vis/{mode}_est_normal': wandb_image(est_normal, from_range=(-1, 1))})\n            if 'transformed_est_normal' in data:\n                transformed_est_normal = data[\"transformed_est_normal\"].permute(2, 0, 1).detach().clone()\n                images.update({f'vis/{mode}_trans_est_normal': wandb_image(transformed_est_normal, from_range=(-1, 1))})\n        if 'sem' in data:\n            sem = data['sem']\n            images.update({f'vis/{mode}_sem': wandb_sem(sem)})\n        if 'distortion' in data:\n            distortion = data['distortion']\n            images.update({f'vis/{mode}_distortion': wandb_image(distortion, from_range=(distortion.min(), distortion.max()))})\n        if 'trans_image' in data:\n            trans_image = data['trans_image']\n            images.update({f'vis/{mode}_trans': wandb_image(trans_image)})\n        wandb.log(images, step=self.current_iteration)\n    \n    def log_hist(self, tensor, name, num_bin=10):\n        counts, bins = np.histogram(tensor, bins=num_bin)\n        density = counts / counts.sum()\n        plt.stairs(density, bins)\n        plt.title('Histogram {}'.format(name))\n        wandb.log({f'statistic/{name}': wandb.Image(plt)}, step=self.current_iteration)\n        plt.close()\n    \n    @torch.no_grad()\n    def test(self, renderFunc):\n        output = dict()\n        # Report test and 
samples of training set\n        if (self.current_iteration in self.testing_iterations) or (self.current_iteration == self.max_iters):\n            torch.cuda.empty_cache()\n            validation_configs = ({'name': 'test', 'cameras' : self.scene.getTestCameras()}, \n                                {'name': 'train', 'cameras' : self.scene.getTrainCameras()})\n\n            for config in validation_configs:\n                if config['cameras'] and len(config['cameras']) > 0:\n                    l1_test = 0.0\n                    psnr_test = 0.0\n                    for idx, viewpoint in enumerate(config['cameras']):\n                        out = renderFunc(viewpoint, self.model, self.cfg, self.background, dirs=self.scene.dirs)\n                        image = torch.clamp(out[\"render\"], 0.0, 1.0)\n                        gt_image = torch.clamp(viewpoint.original_image.to(\"cuda\"), 0.0, 1.0)\n                        if config['name'] == 'train' and self.cfg.model.use_decoupled_appearance:\n                            trans_image = L1_loss_appearance(image, gt_image, self.model, viewpoint.idx, return_transformed_image=True)\n                            \n                        depth = out[\"depth\"]\n                        normal = out[\"normal\"] if \"normal\" in out else None\n                        est_normal = out[\"est_normal\"] if \"est_normal\" in out else None\n                        if 'render_sem' in out:\n                            pred = self.model.logits_2_label(out['render_sem'])\n                            sem_mask = viewpoint.mask.cuda()\n                            self.calc_miou.update(pred, sem_mask)\n                        if viewpoint.image_name == self.scene.first_name:\n                            data = {\"image\": gt_image, \"rgb_map\": image, \"depth_map\": depth}\n                            if config['name'] == 'train' and self.cfg.model.use_decoupled_appearance:\n                                data['trans_image'] = 
trans_image\n                            if 'mask' in out: data['mask'] = out['mask']\n                            if viewpoint.depth is not None: data['depth'] = viewpoint.depth\n                            if 'depth_var' in out: data['depth_var'] = out['depth_var']\n                            if 'distortion' in out: data['distortion'] = out['distortion']\n                            if normal is not None:\n                                data[\"normal_map\"] = normal\n                                if viewpoint.normal is not None: data['normal'] = viewpoint.normal\n                                if est_normal is not None:\n                                    data['est_normal'] = est_normal\n                            if 'render_sem' in out:\n                                pred = self.model.logits_2_label(out['render_sem']).to(torch.uint8)\n                                data['sem'] = torch.cat([pred, sem_mask], dim=0)\n                                \n                            self.log_wandb_images(data, mode=config['name'])\n                        \n                        if False:\n                            data = {\"image\": gt_image, \"rgb_map\": image, \"depth_map\": depth}\n                            if 'mask' in out: data['mask'] = out['mask']\n                            if viewpoint.depth is not None: data['depth'] = viewpoint.depth\n                            if 'depth_var' in out: data['depth_var'] = out['depth_var']\n                            if normal is not None:\n                                data[\"normal_map\"] = normal\n                                if viewpoint.normal is not None: data['normal'] = viewpoint.normal\n                                if est_normal is not None: data['est_normal'] = est_normal\n                            cos = cos_weight(normal.cuda(), normal, self.cfg.optim.exp_t)\n                            data['normal_cos'] = cos\n                            self.save_vis(data, viewpoint.image_name, 
mode=config['name'])\n                        \n                        l1_test += l1_loss(image, gt_image).mean().double()\n                        psnr_test += psnr(image, gt_image).mean().double()\n                    psnr_test /= len(config['cameras'])\n                    l1_test /= len(config['cameras'])\n                    \n                    if self.enable_semantic:\n                        miou = self.calc_miou.compute()\n                        self.calc_miou.reset()\n                    \n                    output.update({\n                        f'statistic/{config[\"name\"]}_PSNR': psnr_test.item(),\n                        f'loss/{config[\"name\"]}_l1': l1_test.item(),\n                    })\n                    if self.enable_semantic:\n                        output[f'statistic/{config[\"name\"]}_mIoU'] = miou.item()\n            \n            output.update({\n                'statistic/total_points': self.scene.gaussians.get_xyz.shape[0],\n            })\n            \n            self.log_hist(self.model.get_opacity.cpu().numpy(), \"opacity\")\n            \n            torch.cuda.empty_cache()\n        \n        return output\n\n    def finalize(self):\n        # Finish the W&B logger.\n        wandb.finish()\n\n    def log_writer(self, output=None, mode=None):\n        if self.writer:\n            for key, value in self.losses.items():\n                self.writer.add_scalar(f\"loss/{mode}_{key}\", value, global_step=self.current_iteration)\n    \n    def save_vis(self, data, name, mode='train'):\n        image = torch.clamp(data[\"rgb_map\"], 0.0, 1.0).detach().cpu()\n        image = (image.permute(1, 2, 0).numpy() * 255).astype('uint8')\n        imageio.imsave(os.path.join(self.vis_color_path, mode, f\"{name}.png\"), image)\n        \n        normal = preprocess_image(data[\"normal_map\"].permute(2, 0, 1), from_range=(-1, 1))\n        normal.save(os.path.join(self.vis_normal_path, mode, f\"{name}.png\"))\n        \n        if False:\n            
normal_gt = preprocess_image(data[\"normal\"].permute(2, 0, 1), from_range=(-1, 1))\n            gt_normal_path = os.path.join(self.vis_normal_path+'_gt', mode)\n            if not os.path.exists(gt_normal_path):\n                os.makedirs(gt_normal_path, exist_ok=True)\n            normal_gt.save(os.path.join(gt_normal_path, f\"{name}.png\"))\n        \n        dnormal = preprocess_image(data[\"est_normal\"].permute(2, 0, 1), from_range=(-1, 1))\n        dnormal.save(os.path.join(self.vis_dnormal_path, mode, f\"{name}.png\"))\n        \n        cos = preprocess_image(data[\"normal_cos\"], from_range=(0, 1))\n        cos.save(os.path.join(self.vis_cos_path, mode, f\"{name}.png\"))\n        \n        return\n\n    def sample_cameras(self, n, up=False, around=True, look_mode='target', sample_mode='grid', bidirect=True): # direction target\n        cam_height = None\n        w2cs = bb_camera(n, self.model.trans, self.model.scale, cam_height, up=up, around=around, \\\n            look_mode=look_mode, sample_mode=sample_mode, bidirect=bidirect)\n        FoVx = FoVy = 2.5\n        width = height = 1500\n        cams = []\n        \n        for i in range(w2cs.shape[0]):\n            w2c = w2cs[i]\n            cam = SampleCam(w2c, width, height, FoVx, FoVy)\n            cams.append(cam)\n        \n        return cams\n    \n    @torch.no_grad()\n    def get_visi_mask(self, n=500, up=False, around=True, denoise_after=False, \\\n        denoise_before=True, nb_points=10, viewpoint_stack=None, sample_mode='grid', cat_cams=False): # direction target\n        if viewpoint_stack is None:\n            if self.cfg.optim.densify_large.sample_cams.random:\n                viewpoint_stack = self.sample_cameras(n, up, around, sample_mode=sample_mode)\n                if cat_cams:\n                    viewpoint_stack += self.scene.getTrainCameras().copy()\n            else:\n                viewpoint_stack = self.scene.getTrainCameras().copy()\n        \n        model = 
deepcopy(self.model)\n        \n        if denoise_before:\n            mask = torch.ones(model.get_xyz.shape[0], dtype=torch.bool, device=\"cuda\")\n            valid = model.filter_points()\n            mask[valid] = False\n        \n            model.prune_points(mask)\n        else:\n            mask = torch.zeros(model.get_xyz.shape[0], dtype=torch.bool, device=\"cuda\")\n        \n        xyz = model.get_xyz[None]\n        dist2 = knn_points(xyz, xyz, K=nb_points+1, return_sorted=True).dists # 1, N, K\n        dist2 = dist2[0, :, 1:]\n        dist2 = torch.clamp_min(dist2, 0.0000001)\n        dist = (torch.sqrt(dist2)).mean(-1)\n        scaling = dist\n        \n        scales = torch.log(scaling)[...,None].repeat(1, 3)\n        \n        idx = torch.argmin(model.get_scaling, dim=-1)\n        scales[torch.arange(scales.shape[0]), idx] = math.log(1e-7)\n        model._scaling = nn.Parameter(scales.requires_grad_(True))\n        \n        out = get_visi_list(model, viewpoint_stack, self.cfg.pipline, self.background)\n        \n        visi = out['visi']\n        \n        valid = ~mask\n        if denoise_after:\n            model.prune_points(~visi)\n            filted = model.filter_points()\n            visi[visi.clone()] = filted\n            \n        valid[~mask] = visi\n        \n        del model\n        \n        return valid\n\n    @torch.no_grad()\n    def get_visi_mask_acc(self, n=500, up=False, around=True, sample_mode='grid', viewpoint_stack=None):\n        if viewpoint_stack is None:\n            if self.cfg.optim.densify_large.sample_cams.random:\n                viewpoint_stack = self.sample_cameras(n, up, around, sample_mode=sample_mode)\n            else:\n                fullcam = self.scene.getTrainCameras().copy()\n                idx = torch.randint(0, len(fullcam), (n,))\n                viewpoint_stack = [fullcam[i] for i in idx]\n            \n        out = get_visi_list(self.model, viewpoint_stack, self.cfg.pipline, 
self.background)\n        visi = out['visi']\n        inside = self.model.get_inside_gaus_normalized()[0]\n        valid = visi & inside\n        \n        return valid\n\n    @torch.no_grad()\n    def save_gaussians(self):\n        print(\"\\n[ITER {}] Saving Gaussians\".format(self.current_iteration))\n        \n        surfmask = None\n        visi = None\n        self.scene.save(self.current_iteration, visi=visi, surf=surfmask, save_splat=self.cfg.train.save_splat)\n    \n\nif __name__ == \"__main__\":\n    from configs.config import Config\n    import sys\n    sys.path.append(os.getcwd())\n    \n    cfg_path = 'projects/gaussain_splatting/configs/base.yaml'\n    \n    cfg = Config(cfg_path)\n    \n    trainer = Trainer(cfg)\n    \n    trainer.get_center_scale()\n    \n    for thr in np.linspace(0.9, 1., 11):\n        trainer.save_pts_thr(thr)\n    \n"
  }
]