[
  {
    "path": ".gitignore",
    "content": "*.pyc\n# .vscode\n.git---\noutput\nbuild\ndiff_rasterization/diff_rast.egg-info\ndiff_rasterization/dist\ntensorboard_3d\nscreenshots\n*.ipynb_checkpoints\n# submodules/\n# assets/\n*.npz\n*.bundle\noutput*\n*.log\nlog"
  },
  {
    "path": "LICENSE.md",
    "content": "Gaussian-Splatting License  \n===========================  \n\n**Inria** and **the Max Planck Institut for Informatik (MPII)** hold all the ownership rights on the *Software* named **gaussian-splatting**.  \nThe *Software* is in the process of being registered with the Agence pour la Protection des  \nProgrammes (APP).  \n\nThe *Software* is still being developed by the *Licensor*.  \n\n*Licensor*'s goal is to allow the research community to use, test and evaluate  \nthe *Software*.  \n\n## 1.  Definitions  \n\n*Licensee* means any person or entity that uses the *Software* and distributes  \nits *Work*.  \n\n*Licensor* means the owners of the *Software*, i.e Inria and MPII  \n\n*Software* means the original work of authorship made available under this  \nLicense ie gaussian-splatting.  \n\n*Work* means the *Software* and any additions to or derivative works of the  \n*Software* that are made available under this License.  \n\n\n## 2.  Purpose  \nThis license is intended to define the rights granted to the *Licensee* by  \nLicensors under the *Software*.  \n\n## 3.  Rights granted  \n\nFor the above reasons Licensors have decided to distribute the *Software*.  \nLicensors grant non-exclusive rights to use the *Software* for research purposes  \nto research users (both academic and industrial), free of charge, without right  \nto sublicense.. The *Software* may be used \"non-commercially\", i.e., for research  \nand/or evaluation purposes only.  \n\nSubject to the terms and conditions of this License, you are granted a  \nnon-exclusive, royalty-free, license to reproduce, prepare derivative works of,  \npublicly display, publicly perform and distribute its *Work* and any resulting  \nderivative works in any form.  \n\n## 4.  
Limitations  \n\n**4.1 Redistribution.** You may reproduce or distribute the *Work* only if (a) you do  \nso under this License, (b) you include a complete copy of this License with  \nyour distribution, and (c) you retain without modification any copyright,  \npatent, trademark, or attribution notices that are present in the *Work*.  \n\n**4.2 Derivative Works.** You may specify that additional or different terms apply  \nto the use, reproduction, and distribution of your derivative works of the *Work*  \n(\"Your Terms\") only if (a) Your Terms provide that the use limitation in  \nSection 2 applies to your derivative works, and (b) you identify the specific  \nderivative works that are subject to Your Terms. Notwithstanding Your Terms,  \nthis License (including the redistribution requirements in Section 3.1) will  \ncontinue to apply to the *Work* itself.  \n\n**4.3** Any other use without of prior consent of Licensors is prohibited. Research  \nusers explicitly acknowledge having received from Licensors all information  \nallowing to appreciate the adequacy between of the *Software* and their needs and  \nto undertake all necessary precautions for its execution and use.  \n\n**4.4** The *Software* is provided both as a compiled library file and as source  \ncode. In case of using the *Software* for a publication or other results obtained  \nthrough the use of the *Software*, users are strongly encouraged to cite the  \ncorresponding publications as explained in the documentation of the *Software*.  \n\n## 5.  Disclaimer  \n\nTHE USER CANNOT USE, EXPLOIT OR DISTRIBUTE THE *SOFTWARE* FOR COMMERCIAL PURPOSES  \nWITHOUT PRIOR AND EXPLICIT CONSENT OF LICENSORS. YOU MUST CONTACT INRIA FOR ANY  \nUNAUTHORIZED USE: stip-sophia.transfert@inria.fr . ANY SUCH ACTION WILL  \nCONSTITUTE A FORGERY. 
THIS *SOFTWARE* IS PROVIDED \"AS IS\" WITHOUT ANY WARRANTIES  \nOF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES, WITH REGARDS TO COMMERCIAL  \nUSE, PROFESSIONNAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALISATION OR  \nADAPTATION. UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL INRIA OR THE  \nAUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR  \nCONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE  \nGOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION)  \nHOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT  \nLIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM, OUT OF OR  \nIN CONNECTION WITH THE *SOFTWARE* OR THE USE OR OTHER DEALINGS IN THE *SOFTWARE*.  \n\n## 6.  Files subject to permissive licenses\nThe contents of the file ```utils/loss_utils.py``` are based on publicly available code authored by Evan Su, which falls under the permissive MIT license. \n\nTitle: pytorch-ssim\\\nProject code: https://github.com/Po-Hsun-Su/pytorch-ssim\\\nCopyright Evan Su, 2017\\\nLicense: https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/LICENSE.txt (MIT)"
  },
  {
    "path": "README.md",
    "content": "<div align=\"center\">\n\n# [NeurIPS2024🔥] OpenGaussian: Towards Point-Level 3D Gaussian-based Open Vocabulary Understanding\n\n<!-- <a href=\"https://arxiv.org/abs/2406.02058\"><strong>Paper</strong></a> |  -->\n\n<h3>\n  <strong>Paper(<a href=\"https://arxiv.org/abs/2406.02058\">arXiv</a> / <a href=\"https://proceedings.neurips.cc/paper_files/paper/2024/hash/21f7b745f73ce0d1f9bcea7f40b1388e-Abstract-Conference.html\">Conference</a>)</strong> | \n  <a href=\"https://3d-aigc.github.io/OpenGaussian/\"><strong>Project Page</strong></a>\n</h3>\n\n<!-- [**Paper**](https://arxiv.org/abs/2406.02058) | [**Project Page**](https://3d-aigc.github.io/OpenGaussian/) -->\n<!-- [![arXiv](https://img.shields.io/badge/arXiv-<Paper>-<COLOR>.svg)](https://arxiv.org/abs/2406.02058)\n[![Project Page](https://img.shields.io/badge/Project_Page-<Website>-blue.svg)](https://3d-aigc.github.io/OpenGaussian/) -->\n\n[Yanmin Wu](https://yanmin-wu.github.io/)<sup>1</sup>, [Jiarui Meng](https://scholar.google.com/citations?user=N_pRAVAAAAAJ&hl=en&oi=ao)<sup>1</sup>, [Haijie Li](https://villa.jianzhang.tech/people/haijie-li-%E6%9D%8E%E6%B5%B7%E6%9D%B0/)<sup>1</sup>, [Chenming Wu](https://chenming-wu.github.io/)<sup>2*</sup>, [Yahao Shi](https://scholar.google.com/citations?user=-VJZrUkAAAAJ&hl=en)<sup>3</sup>, [Xinhua Cheng](https://cxh0519.github.io/)<sup>1</sup>, \n[Chen Zhao](https://openreview.net/profile?id=~Chen_Zhao9)<sup>2</sup>, [Haocheng Feng](https://openreview.net/profile?id=~Haocheng_Feng1)<sup>2</sup>, [Errui Ding](https://scholar.google.com/citations?user=1wzEtxcAAAAJ&hl=zh-CN)<sup>2</sup>, [Jingdong Wang](https://jingdongwang2017.github.io/)<sup>2</sup>, [Jian Zhang](https://jianzhang.tech/)<sup>1*</sup>\n\n<sup>1</sup> Peking University, <sup>2</sup> Baidu VIS, <sup>3</sup> Beihang University\n\n</div>\n\n## 0. 
Installation\n\nThe installation of OpenGaussian is similar to [3D Gaussian Splatting](https://github.com/graphdeco-inria/gaussian-splatting).\n```\ngit clone https://github.com/yanmin-wu/OpenGaussian.git\n```\nThen install the dependencies:\n```shell\nconda env create --file environment.yml\nconda activate gaussian_splatting\n\n# the rasterization lib comes from DreamGaussian\ncd OpenGaussian/submodules\nunzip ashawkey-diff-gaussian-rasterization.zip\npip install ./ashawkey-diff-gaussian-rasterization\n```\n+ other additional dependencies: bitarray, scipy, [pytorch3d](https://anaconda.org/pytorch3d/pytorch3d/files)\n    ```shell\n    pip install bitarray scipy\n    \n    # install a pytorch3d version compatible with your PyTorch, Python, and CUDA.\n    ```\n+ `simple-knn` is not required\n\n---\n\n## 1. ToDo list\n\n+ [x] Point feature visualization\n+ [x] Data preprocessing\n+ ~~[ ] Improved SAM mask extraction (extracting only one layer)~~\n+ [x] Click to Select 3D Object\n\n---\n\n## 2. Data preparation\nThe files are as follows:\n```\n[DATA_ROOT]\n├── [1] scannet/\n│   │   ├── scene0000_00/\n|   |   |   |── color/\n|   |   |   |── language_features/\n|   |   |   |── points3d.ply\n|   |   |   |── transforms_train/test.json\n|   |   |   |── *_vh_clean_2.labels.ply\n│   │   ├── scene0062_00/\n│   │   └── ...\n├── [2] lerf_ovs/\n│   │   ├── figurines/ & ramen/ & teatime/ & waldo_kitchen/\n|   |   |   |── images/\n|   |   |   |── language_features/\n|   |   |   |── sparse/\n│   │   ├── label/\n```\n+ **[1] Prepare ScanNet Data**\n    + You can directly download our pre-processed data: [**OneDrive**](https://onedrive.live.com/?authkey=%21AIgsXZy3gl%5FuKmM&id=744D3E86422BE3C9%2139813&cid=744D3E86422BE3C9) / [Baidu](https://pan.baidu.com/s/1B_tGYla5dWyJRu3jTNTMvA?pwd=u5iy). 
Please unzip the `color.zip` and `language_features.zip` files.\n    + The ScanNet dataset requires permission for use; follow the [ScanNet instructions](https://github.com/ScanNet/ScanNet) to apply for access.\n    + **If you want to process more scenes from the ScanNet dataset, you can follow these steps:**\n\t    + First, use the official `download-scannet.py` script provided by ScanNet to download the `.sens` archive of the specified scenes;\n\t    + Then, refer to the [`preprocess_2d_scannet.py`](https://github.com/pengsongyou/openscene/blob/main/scripts/preprocess/preprocess_2d_scannet.py) script to extract the `color` and `pose` information;\n\t    + Finally, convert the data into Blender format using the [`scripts/scannet2blender.py`](https://github.com/yanmin-wu/OpenGaussian/blob/main/scripts/scannet2blender.py) script. Please check the `TODO` comments in the script to specify the paths.\n+ **[2] Prepare lerf_ovs Data**\n    + You can directly download our pre-processed data: [**OneDrive**](https://onedrive.live.com/?authkey=%21AIgsXZy3gl%5FuKmM&id=744D3E86422BE3C9%2139815&cid=744D3E86422BE3C9) / [Baidu](https://pan.baidu.com/s/1B_tGYla5dWyJRu3jTNTMvA?pwd=u5iy) (re-annotated by LangSplat). Please unzip the `images.zip` and `language_features.zip` files.\n+ **Mask and Language Feature Extraction Details**\n    + We use the tools provided by LangSplat to extract the SAM mask and CLIP features, but we only use the large-level mask.\n\n---\n\n## 3. Training\n### 3.1 ScanNet\n```shell\nchmod +x scripts/train_scannet.sh\n./scripts/train_scannet.sh\n```\n+ Please ***check*** the script for more details and ***modify*** the dataset path.\n+ you will see the following processes during training:\n    ```shell\n    [Stage 0] Start 3dgs pre-train ... (step 0-30k)\n    [Stage 1] Start continuous instance feature learning ... (step 30-50k)\n    [Stage 2.1] Start coarse-level codebook discretization ... 
(step 50-70k)\n    [Stage 2.2] Start fine-level codebook discretization ... (step 70-90k)\n    [Stage 3] Start 2D language feature - 3D cluster association ... (1 min)\n    ```\n+ Intermediate results from different stages can be found in subfolders `***/train_process/stage*`. (The intermediate results of stage 3 are recommended to be observed in the LeRF dataset.)\n\n### 3.2 LeRF_ovs\n```shell\nchmod +x scripts/train_lerf.sh\n./scripts/train_lerf.sh\n```\n+ Please ***check*** the script for more details and ***modify*** the dataset path.\n+ you will see the following processes during training:\n    ```shell\n    [Stage 0] Start 3dgs pre-train ... (step 0-30k)\n    [Stage 1] Start continuous instance feature learning ... (step 30-40k)\n    [Stage 2.1] Start coarse-level codebook discretization ... (step 40-50k)\n    [Stage 2.2] Start fine-level codebook discretization ... (step 50-70k)\n    [Stage 3] Start 2D language feature - 3D cluster association ... (1 min)\n    ```\n+ Intermediate results from different stages can be found in subfolders `***/train_process/stage*`.\n\n### 3.3 Custom data\n+ Without any special processing, videos are first captured, approximately 200 frames are sampled, and COLMAP is then used to initialize the point cloud and camera poses.\n\n---\n\n## 4. 
Render & Eval & Downstream Tasks\n\n### 4.1 3D Instance Feature Visualization\n+ Please install `open3d` first, and then execute the following command on a system with UI support:\n    ```shell\n    python scripts/vis_opengs_pts_feat.py\n    ```\n    + Please specify `ply_path` in the script as the PLY file `output/xxxxxxxx-x/point_cloud/iteration_x0000/point_cloud.ply` saved at different stages.\n    + During the training process, we have saved the first three dimensions of the 6D features as colors for visualization; see [here](https://github.com/yanmin-wu/OpenGaussian/blob/2845b9c744c1b06ac6930ffa2d2a6f9167f1b843/scene/gaussian_model.py#L272).\n\n### 4.2 Render 2D Feature Map\n+ Feature maps are rendered the same way 3DGS renders colors.\n    ```shell\n    python render.py -m \"output/xxxxxxxx-x\"\n    ```\n    You can find the rendered feature maps in subfolders `renders_ins_feat1` and `renders_ins_feat2`.\n\n### 4.3 ScanNet Evaluation (Open-Vocabulary Point Cloud Understanding)\n> Due to code optimization and the use of more suitable hyperparameters, the latest evaluation metrics may be higher than those reported in the paper. \n+ Evaluate text-guided segmentation performance on ScanNet for 19, 15, and 10 categories.\n    ```shell\n    # unzip the pre-extracted text features\n    cd assets\n    unzip text_features.zip\n\n    # 1. check that `gt_file_path` and `model_path` are correct\n    # 2. specify `target_id` as 19, 15, or 10 categories.\n    python scripts/eval_scannet.py\n    ```\n\n### 4.4 LeRF Evaluation (Open-Vocabulary Object Selection in 3D Space)\n+ (1) First, render text-selected 3D Gaussians into multi-view images.\n    ```shell\n    # unzip the pre-extracted text features\n    cd assets\n    unzip text_features.zip\n\n    # 1. specify the model path using -m\n    # 2. 
specify the scene name: figurines, teatime, ramen, waldo_kitchen\n    python render_lerf_by_text.py -m \"output/xxxxxxxx-x\" --scene_name \"figurines\"\n    ```\n    The object selection results are saved in `output/xxxxxxxx-x/text2obj/ours_70000/renders_cluster`.\n\n+ (2) Then, compute evaluation metrics.\n    > Due to code optimization and the use of more suitable hyperparameters, the latest evaluation metrics may be higher than those reported in the paper. \n    > The metrics may be unstable due to the limited evaluation samples of LeRF.\n    ```shell\n    # 1. change path_gt and path_pred in the script\n    # 2. specify the scene name: figurines, teatime, ramen, waldo_kitchen\n    python scripts/compute_lerf_iou.py --scene_name \"figurines\"\n    ```\n\n### 4.5 Click to Select 3D Object\n\n+ (1) First, you need to render the feature maps (refer to Step 4.2; in practice, only two feature maps from a single view are required).\n+ (2) Then, check the [`scripts/render_by_click.py`](https://github.com/yanmin-wu/OpenGaussian/blob/main/scripts/render_by_click.py) script for `TODO` comments, including specifying the frame filename, clicked pixel coordinates, and file paths.\n+ (3) Finally, run the [`scripts/render_by_click.py`](https://github.com/yanmin-wu/OpenGaussian/blob/main/scripts/render_by_click.py) script. *Note that this script has not been tested with the current version of the code and may require debugging*.\n\n---\n\n## 5. Acknowledgements\nWe are quite grateful for [3DGS](https://github.com/graphdeco-inria/gaussian-splatting), [LangSplat](https://github.com/minghanqin/LangSplat), [CompGS](https://github.com/UCDvision/compact3d), [LEGaussians](https://github.com/buaavrcg/LEGaussians), [SAGA](https://github.com/Jumpat/SegAnyGAussians), and [SAM](https://segment-anything.com/).\n\n---\n\n## 6. 
Citation\n\n```\n@inproceedings{wu2024opengaussian,\n    title={OpenGaussian: Towards Point-Level 3D Gaussian-based Open Vocabulary Understanding},\n    author={Wu, Yanmin and Meng, Jiarui and Li, Haijie and Wu, Chenming and Shi, Yahao and Cheng, Xinhua and Zhao, Chen and Feng, Haocheng and Ding, Errui and Wang, Jingdong and Zhang, Jian},\n    booktitle={Proceedings of the Advances in Neural Information Processing Systems (NeurIPS)},\n    pages={19114--19138},\n    year={2024}\n}\n```\n\n---\n\n## 7. Contact\nIf you have any questions about this project, please feel free to contact [Yanmin Wu](https://yanmin-wu.github.io/): wuyanminmax[AT]gmail.com\n"
  },
  {
    "path": "arguments/__init__.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nfrom argparse import ArgumentParser, Namespace\nimport sys\nimport os\n\nclass GroupParams:\n    pass\n\nclass ParamGroup:\n    def __init__(self, parser: ArgumentParser, name : str, fill_none = False):\n        group = parser.add_argument_group(name)\n        for key, value in vars(self).items():\n            shorthand = False\n            if key.startswith(\"_\"):\n                shorthand = True\n                key = key[1:]\n            t = type(value)\n            value = value if not fill_none else None \n            if shorthand:\n                if t == bool:\n                    group.add_argument(\"--\" + key, (\"-\" + key[0:1]), default=value, action=\"store_true\")\n                else:\n                    group.add_argument(\"--\" + key, (\"-\" + key[0:1]), default=value, type=t)\n            else:\n                if t == bool:\n                    group.add_argument(\"--\" + key, default=value, action=\"store_true\")\n                else:\n                    group.add_argument(\"--\" + key, default=value, type=t)\n\n    def extract(self, args):\n        group = GroupParams()\n        for arg in vars(args).items():\n            if arg[0] in vars(self) or (\"_\" + arg[0]) in vars(self):\n                setattr(group, arg[0], arg[1])\n        return group\n\nclass ModelParams(ParamGroup): \n    def __init__(self, parser, sentinel=False):\n        self.sh_degree = 3\n        self._source_path = \"\"\n        self._model_path = \"\"\n        self._images = \"images\"\n        self._resolution = -1\n        self._white_background = False\n        self.data_device = \"cuda\"\n        self.eval = False\n        super().__init__(parser, 
\"Loading Parameters\", sentinel)\n\n    def extract(self, args):\n        g = super().extract(args)\n        g.source_path = os.path.abspath(g.source_path)\n        return g\n\nclass PipelineParams(ParamGroup):\n    def __init__(self, parser):\n        self.convert_SHs_python = False\n        self.compute_cov3D_python = False\n        self.debug = False\n        super().__init__(parser, \"Pipeline Parameters\")\n\nclass OptimizationParams(ParamGroup):\n    def __init__(self, parser):\n        self.leaf_update_fr = 300           # coarse-level codebook update frequency\n        self.ins_feat_dim = 6\n        self.position_lr_init = 0.00016\n        self.position_lr_final = 0.0000016\n        self.position_lr_delay_mult = 0.01\n        self.position_lr_max_steps = 30_000\n        self.feature_lr = 0.0025\n        self.ins_feat_lr = 0.001\n        self.opacity_lr = 0.05\n        self.scaling_lr = 0.005\n        self.rotation_lr = 0.001\n        self.percent_dense = 0.01\n        self.lambda_dssim = 0.2\n        self.densification_interval = 100\n        self.opacity_reset_interval = 3000\n        self.densify_from_iter = 500\n        self.densify_until_iter = 15_000\n        self.densify_grad_threshold = 0.0002\n        self.random_background = False\n\n        parser.add_argument('--root_node_num', type=int, default=64)    # k1=64\n        parser.add_argument('--leaf_node_num', type=int, default=5)     # k2=5/10\n\n        parser.add_argument('--pos_weight', type=float, default=1.0)    # position weight for coarse codebook\n        parser.add_argument('--loss_weight', type=float, default=0.1)   # loss_cohesion weight\n\n        parser.add_argument('--iterations', type=int, default=70_000)   # default 7w, scannet 9w\n        parser.add_argument('--start_ins_feat_iter', type=int, default=30_000)  # default 3w\n        parser.add_argument('--start_root_cb_iter', type=int, default=40_000)   # default 4w, scannet 5w\n        parser.add_argument('--start_leaf_cb_iter', 
type=int, default=50_000)   # default 5w, scannet 7w\n\n        # Note: for ScanNet, freeze the positions of the initial points and do not densify.\n        parser.add_argument('--frozen_init_pts', action='store_true', default=False)\n        parser.add_argument('--sam_level', type=int, default=3)\n\n        parser.add_argument('--save_memory', action='store_true', default=False)\n        super().__init__(parser, \"Optimization Parameters\")\n    \n    def extract(self, args):\n        g = super().extract(args)\n        g.root_node_num = args.root_node_num\n        g.leaf_node_num = args.leaf_node_num\n        g.pos_weight = args.pos_weight\n        g.loss_weight = args.loss_weight\n        g.frozen_init_pts = args.frozen_init_pts\n        g.sam_level = args.sam_level\n        g.iterations = args.iterations\n        g.start_ins_feat_iter = args.start_ins_feat_iter\n        g.start_root_cb_iter = args.start_root_cb_iter\n        g.start_leaf_cb_iter = args.start_leaf_cb_iter\n        g.save_memory = args.save_memory\n\n        return g\n\ndef get_combined_args(parser : ArgumentParser):\n    cmdlne_string = sys.argv[1:]\n    cfgfile_string = \"Namespace()\"\n    args_cmdline = parser.parse_args(cmdlne_string)\n\n    try:\n        cfgfilepath = os.path.join(args_cmdline.model_path, \"cfg_args\")\n        print(\"Looking for config file in\", cfgfilepath)\n        with open(cfgfilepath) as cfg_file:\n            print(\"Config file found: {}\".format(cfgfilepath))\n            cfgfile_string = cfg_file.read()\n    except TypeError:\n        print(\"Config file not found\")\n        pass\n    args_cfgfile = eval(cfgfile_string)\n\n    merged_dict = vars(args_cfgfile).copy()\n    for k,v in vars(args_cmdline).items():\n        if v is not None:\n            merged_dict[k] = v\n    return Namespace(**merged_dict)\n"
  },
  {
    "path": "convert.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use\n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport os\nimport logging\nfrom argparse import ArgumentParser\nimport shutil\n\n# This Python script is based on the shell converter script provided in the MipNerF 360 repository.\nparser = ArgumentParser(\"Colmap converter\")\nparser.add_argument(\"--no_gpu\", action='store_true')\nparser.add_argument(\"--skip_matching\", action='store_true')\nparser.add_argument(\"--source_path\", \"-s\", required=True, type=str)\nparser.add_argument(\"--camera\", default=\"OPENCV\", type=str)\nparser.add_argument(\"--colmap_executable\", default=\"\", type=str)\nparser.add_argument(\"--resize\", action=\"store_true\")\nparser.add_argument(\"--magick_executable\", default=\"\", type=str)\nargs = parser.parse_args()\ncolmap_command = '\"{}\"'.format(args.colmap_executable) if len(args.colmap_executable) > 0 else \"colmap\"\nmagick_command = '\"{}\"'.format(args.magick_executable) if len(args.magick_executable) > 0 else \"magick\"\nuse_gpu = 1 if not args.no_gpu else 0\n\nif not args.skip_matching:\n    os.makedirs(args.source_path + \"/distorted/sparse\", exist_ok=True)\n\n    ## Feature extraction\n    feat_extracton_cmd = colmap_command + \" feature_extractor \"\\\n        \"--database_path \" + args.source_path + \"/distorted/database.db \\\n        --image_path \" + args.source_path + \"/input \\\n        --ImageReader.single_camera 1 \\\n        --ImageReader.camera_model \" + args.camera + \" \\\n        --SiftExtraction.use_gpu \" + str(use_gpu)\n    exit_code = os.system(feat_extracton_cmd)\n    if exit_code != 0:\n        logging.error(f\"Feature extraction failed with code {exit_code}. 
Exiting.\")\n        exit(exit_code)\n\n    ## Feature matching\n    feat_matching_cmd = colmap_command + \" exhaustive_matcher \\\n        --database_path \" + args.source_path + \"/distorted/database.db \\\n        --SiftMatching.use_gpu \" + str(use_gpu)\n    exit_code = os.system(feat_matching_cmd)\n    if exit_code != 0:\n        logging.error(f\"Feature matching failed with code {exit_code}. Exiting.\")\n        exit(exit_code)\n\n    ### Bundle adjustment\n    # The default Mapper tolerance is unnecessarily large,\n    # decreasing it speeds up bundle adjustment steps.\n    mapper_cmd = (colmap_command + \" mapper \\\n        --database_path \" + args.source_path + \"/distorted/database.db \\\n        --image_path \"  + args.source_path + \"/input \\\n        --output_path \"  + args.source_path + \"/distorted/sparse \\\n        --Mapper.ba_global_function_tolerance=0.000001\")\n    exit_code = os.system(mapper_cmd)\n    if exit_code != 0:\n        logging.error(f\"Mapper failed with code {exit_code}. Exiting.\")\n        exit(exit_code)\n\n### Image undistortion\n## We need to undistort our images into ideal pinhole intrinsics.\nimg_undist_cmd = (colmap_command + \" image_undistorter \\\n    --image_path \" + args.source_path + \"/input \\\n    --input_path \" + args.source_path + \"/distorted/sparse/0 \\\n    --output_path \" + args.source_path + \"\\\n    --output_type COLMAP\")\nexit_code = os.system(img_undist_cmd)\nif exit_code != 0:\n    logging.error(f\"Image undistortion failed with code {exit_code}. 
Exiting.\")\n    exit(exit_code)\n\nfiles = os.listdir(args.source_path + \"/sparse\")\nos.makedirs(args.source_path + \"/sparse/0\", exist_ok=True)\n# Copy each file from the source directory to the destination directory\nfor file in files:\n    if file == '0':\n        continue\n    source_file = os.path.join(args.source_path, \"sparse\", file)\n    destination_file = os.path.join(args.source_path, \"sparse\", \"0\", file)\n    shutil.move(source_file, destination_file)\n\nif(args.resize):\n    print(\"Copying and resizing...\")\n\n    # Resize images.\n    os.makedirs(args.source_path + \"/images_2\", exist_ok=True)\n    os.makedirs(args.source_path + \"/images_4\", exist_ok=True)\n    os.makedirs(args.source_path + \"/images_8\", exist_ok=True)\n    # Get the list of files in the source directory\n    files = os.listdir(args.source_path + \"/images\")\n    # Copy each file from the source directory to the destination directory\n    for file in files:\n        source_file = os.path.join(args.source_path, \"images\", file)\n\n        destination_file = os.path.join(args.source_path, \"images_2\", file)\n        shutil.copy2(source_file, destination_file)\n        exit_code = os.system(magick_command + \" mogrify -resize 50% \" + destination_file)\n        if exit_code != 0:\n            logging.error(f\"50% resize failed with code {exit_code}. Exiting.\")\n            exit(exit_code)\n\n        destination_file = os.path.join(args.source_path, \"images_4\", file)\n        shutil.copy2(source_file, destination_file)\n        exit_code = os.system(magick_command + \" mogrify -resize 25% \" + destination_file)\n        if exit_code != 0:\n            logging.error(f\"25% resize failed with code {exit_code}. 
Exiting.\")\n            exit(exit_code)\n\n        destination_file = os.path.join(args.source_path, \"images_8\", file)\n        shutil.copy2(source_file, destination_file)\n        exit_code = os.system(magick_command + \" mogrify -resize 12.5% \" + destination_file)\n        if exit_code != 0:\n            logging.error(f\"12.5% resize failed with code {exit_code}. Exiting.\")\n            exit(exit_code)\n\nprint(\"Done.\")\n"
  },
  {
    "path": "environment.yml",
    "content": "name: gaussian_splatting\nchannels:\n  - pytorch\n  - conda-forge\n  - defaults\ndependencies:\n  - cudatoolkit=11.6\n  - plyfile=0.8.1\n  - python=3.7.13\n  - pip=22.3.1\n  - pytorch=1.12.1\n  - torchaudio=0.12.1\n  - torchvision=0.13.1\n  - tqdm\n  - pip:\n    - bitarray\n    - scipy\n    - submodules/ashawkey-diff-gaussian-rasterization"
  },
  {
    "path": "full_eval.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport os\nfrom argparse import ArgumentParser\n\nmipnerf360_outdoor_scenes = [\"bicycle\", \"flowers\", \"garden\", \"stump\", \"treehill\"]\nmipnerf360_indoor_scenes = [\"room\", \"counter\", \"kitchen\", \"bonsai\"]\ntanks_and_temples_scenes = [\"truck\", \"train\"]\ndeep_blending_scenes = [\"drjohnson\", \"playroom\"]\n\nparser = ArgumentParser(description=\"Full evaluation script parameters\")\nparser.add_argument(\"--skip_training\", action=\"store_true\")\nparser.add_argument(\"--skip_rendering\", action=\"store_true\")\nparser.add_argument(\"--skip_metrics\", action=\"store_true\")\nparser.add_argument(\"--output_path\", default=\"./eval\")\nargs, _ = parser.parse_known_args()\n\nall_scenes = []\nall_scenes.extend(mipnerf360_outdoor_scenes)\nall_scenes.extend(mipnerf360_indoor_scenes)\nall_scenes.extend(tanks_and_temples_scenes)\nall_scenes.extend(deep_blending_scenes)\n\nif not args.skip_training or not args.skip_rendering:\n    parser.add_argument('--mipnerf360', \"-m360\", required=True, type=str)\n    parser.add_argument(\"--tanksandtemples\", \"-tat\", required=True, type=str)\n    parser.add_argument(\"--deepblending\", \"-db\", required=True, type=str)\n    args = parser.parse_args()\n\nif not args.skip_training:\n    common_args = \" --quiet --eval --test_iterations -1 \"\n    for scene in mipnerf360_outdoor_scenes:\n        source = args.mipnerf360 + \"/\" + scene\n        os.system(\"python train.py -s \" + source + \" -i images_4 -m \" + args.output_path + \"/\" + scene + common_args)\n    for scene in mipnerf360_indoor_scenes:\n        source = args.mipnerf360 + \"/\" + scene\n        os.system(\"python train.py -s \" + source + \" -i 
images_2 -m \" + args.output_path + \"/\" + scene + common_args)\n    for scene in tanks_and_temples_scenes:\n        source = args.tanksandtemples + \"/\" + scene\n        os.system(\"python train.py -s \" + source + \" -m \" + args.output_path + \"/\" + scene + common_args)\n    for scene in deep_blending_scenes:\n        source = args.deepblending + \"/\" + scene\n        os.system(\"python train.py -s \" + source + \" -m \" + args.output_path + \"/\" + scene + common_args)\n\nif not args.skip_rendering:\n    all_sources = []\n    for scene in mipnerf360_outdoor_scenes:\n        all_sources.append(args.mipnerf360 + \"/\" + scene)\n    for scene in mipnerf360_indoor_scenes:\n        all_sources.append(args.mipnerf360 + \"/\" + scene)\n    for scene in tanks_and_temples_scenes:\n        all_sources.append(args.tanksandtemples + \"/\" + scene)\n    for scene in deep_blending_scenes:\n        all_sources.append(args.deepblending + \"/\" + scene)\n\n    common_args = \" --quiet --eval --skip_train\"\n    for scene, source in zip(all_scenes, all_sources):\n        os.system(\"python render.py --iteration 7000 -s \" + source + \" -m \" + args.output_path + \"/\" + scene + common_args)\n        os.system(\"python render.py --iteration 30000 -s \" + source + \" -m \" + args.output_path + \"/\" + scene + common_args)\n\nif not args.skip_metrics:\n    scenes_string = \"\"\n    for scene in all_scenes:\n        scenes_string += \"\\\"\" + args.output_path + \"/\" + scene + \"\\\" \"\n\n    os.system(\"python metrics.py -m \" + scenes_string)"
  },
  {
    "path": "gaussian_renderer/__init__.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport torch\nimport math\n# from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer\nfrom ashawkey_diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer\nfrom scene.gaussian_model import GaussianModel\nfrom utils.sh_utils import eval_sh\nfrom utils.opengs_utlis import *\n# from sklearn.neighbors import NearestNeighbors\nimport pytorch3d.ops\n\ndef render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, iteration,\n            scaling_modifier = 1.0, override_color = None, visible_mask = None, mask_num=0,\n            cluster_idx=None,       # per-point cluster id (coarse-level)\n            leaf_cluster_idx=None,  # per-point cluster id (fine-level)\n            rescale=True,           # re-scale (for enhance ins_feat)\n            origin_feat=False,      # origin ins_feat (not quantized)\n            render_feat_map=True,   # render image-level feat map\n            render_color=True,      # render rgb image\n            render_cluster=False,   # render cluster, stage 2.2\n            better_vis=False,       # filter some points\n            selected_root_id=None,  # coarse-level cluster id\n            selected_leaf_id=None,  # fine-level cluster id (possibly more than one)\n            pre_mask=None,\n            seg_rgb=False,          # render cluster rgb, not feat\n            post_process=False,     # post\n            root_num=64, leaf_num=10):  # k1, k2 \n    \"\"\"\n    Render the scene. \n    \n    Background tensor (bg_color) must be on GPU!\n    \"\"\"\n \n    # Create zero tensor. 
We will use it to make pytorch return gradients of the 2D (screen-space) means\n    screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device=\"cuda\") + 0\n    try:\n        screenspace_points.retain_grad()\n    except:\n        pass\n\n    # Set up rasterization configuration\n    tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)\n    tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)\n\n    raster_settings = GaussianRasterizationSettings(\n        image_height=int(viewpoint_camera.image_height),\n        image_width=int(viewpoint_camera.image_width),\n        tanfovx=tanfovx,\n        tanfovy=tanfovy,\n        bg=bg_color,\n        scale_modifier=scaling_modifier,\n        viewmatrix=viewpoint_camera.world_view_transform,\n        projmatrix=viewpoint_camera.full_proj_transform,\n        sh_degree=pc.active_sh_degree,\n        campos=viewpoint_camera.camera_center,\n        prefiltered=False,\n        debug=pipe.debug\n    )\n\n    rasterizer = GaussianRasterizer(raster_settings=raster_settings)\n\n    means3D = pc.get_xyz\n    means2D = screenspace_points\n    opacity = pc.get_opacity\n\n    # If precomputed 3d covariance is provided, use it. If not, then it will be computed from\n    # scaling / rotation by the rasterizer.\n    scales = None\n    rotations = None\n    cov3D_precomp = None\n    if pipe.compute_cov3D_python:\n        cov3D_precomp = pc.get_covariance(scaling_modifier)\n    else:\n        scales = pc.get_scaling\n        rotations = pc.get_rotation\n\n    # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors\n    # from SHs in Python, do it. 
If not, then SH -> RGB conversion will be done by rasterizer.\n    shs = None\n    colors_precomp = None\n    if override_color is None:\n        if pipe.convert_SHs_python:\n            shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2)\n            dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1))\n            dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True)\n            sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)\n            colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)\n        else:\n            shs = pc.get_features\n    else:\n        colors_precomp = override_color\n\n    if render_color:\n        rendered_image, radii, rendered_depth, rendered_alpha = rasterizer(\n            means3D = means3D,\n            means2D = means2D,\n            shs = shs,\n            colors_precomp = colors_precomp,\n            opacities = opacity,\n            scales = scales,\n            rotations = rotations,\n            cov3D_precomp = cov3D_precomp)\n    else:\n        rendered_image, radii, rendered_depth, rendered_alpha = None, None, None, None\n\n    # ################################################################\n    # [Stage 1, Stage 2.1] Render image-level instance feature map   #\n    #   - rendered_ins_feat: image-level feat map                    #\n    # ################################################################\n    # probabilistically rescale\n    prob = torch.rand(1)\n    rescale_factor = torch.tensor(1.0, dtype=torch.float32).cuda()\n    if prob > 0.5 and rescale:\n        rescale_factor = torch.rand(1).cuda()\n    if render_feat_map:\n        # get feature\n        ins_feat = (pc.get_ins_feat(origin=origin_feat) + 1) / 2   # pseudo -> norm, else -> raw\n        # first three channels\n        rendered_ins_feat, _, _, _ = rasterizer(\n            means3D = means3D,\n            means2D = means2D,\n            shs = None,\n            
colors_precomp = ins_feat[:, :3],   # render features as pre-computed colors\n            opacities = opacity,\n            scales = scales * rescale_factor,\n\n            rotations = rotations,\n            cov3D_precomp = cov3D_precomp)\n        # last three channels\n        if ins_feat.shape[-1] > 3:\n            rendered_ins_feat2, _, _, _ = rasterizer(\n                means3D = means3D,\n                means2D = means2D,\n                shs = None,\n                colors_precomp = ins_feat[:, 3:6],  # render features as pre-computed colors\n                opacities = opacity,\n                scales = scales * rescale_factor,\n\n                rotations = rotations,\n                cov3D_precomp = cov3D_precomp)\n            rendered_ins_feat = torch.cat((rendered_ins_feat, rendered_ins_feat2), dim=0)\n        # mask\n        _, _, _, silhouette = rasterizer(\n            means3D = means3D,\n            means2D = means2D,\n            shs = shs,\n            colors_precomp = colors_precomp,\n            opacities = opacity,\n            scales = scales * rescale_factor,\n            # opacities = opacity*0+1.0,    # \n            # scales = scales*0+0.001,   # *0.1\n            rotations = rotations,\n            cov3D_precomp = cov3D_precomp)\n    else:\n        rendered_ins_feat, silhouette = None, None\n\n\n    # ########################################################################\n    # [Preprocessing for Stage 2.2]: render (coarse) cluster-level feat map  #\n    #   - rendered_clusters: feat map of the coarse clusters                 #\n    #   - rendered_cluster_silhouettes: cluster mask                         #\n    # ########################################################################\n    viewed_pts = radii > 0      # ignore the invisible points\n    if cluster_idx is not None:\n        num_cluster = cluster_idx.max() + 1\n        cluster_occur = torch.zeros(num_cluster).to(torch.bool) # [num_cluster], bool\n    else:\n        
cluster_occur = None\n    if render_cluster and cluster_idx is not None and viewed_pts.sum() != 0:\n        ins_feat = (pc.get_ins_feat(origin=origin_feat) + 1) / 2   # pseudo -> norm, else -> raw\n        rendered_clusters = []\n        rendered_cluster_silhouettes = []\n        scale_filter = (scales < 0.5).all(dim=1)    #   filter\n        for idx in range(num_cluster):\n            if not better_vis and idx != selected_root_id:\n                continue\n\n            # ignore the invisible coarse-level cluster\n            if viewpoint_camera.bClusterOccur is not None and viewpoint_camera.bClusterOccur[idx] == False:\n                continue\n            \n            # NOTE: Render only the idx-th coarse cluster\n            filter_idx = cluster_idx == idx\n            \n            filter_idx = filter_idx & viewed_pts\n            # todo: filter\n            if better_vis:\n                filter_idx = filter_idx & scale_filter\n                if filter_idx.sum() < 100:\n                    continue\n                    \n            # render cluster-level feat map\n            rendered_cluster, _, _, cluster_silhouette = rasterizer(\n                means3D = means3D[filter_idx],\n                means2D = means2D[filter_idx],\n                shs = None,  # feat\n                colors_precomp = ins_feat[:, :3][filter_idx],  # feat\n                # shs = shs[filter_idx],  # rgb\n                # colors_precomp = None,  # rgb\n                opacities = opacity[filter_idx],\n                scales = scales[filter_idx] * rescale_factor,\n                rotations = rotations[filter_idx],\n                cov3D_precomp = cov3D_precomp)\n            if ins_feat.shape[-1] > 3:\n                rendered_cluster2, _, _, cluster_silhouette = rasterizer(\n                    means3D = means3D[filter_idx],\n                    means2D = means2D[filter_idx],\n                    shs = None,           # feat\n                    colors_precomp = ins_feat[:, 
3:][filter_idx],  # feat\n                    # shs = shs[filter_idx],  # rgb\n                    # colors_precomp = None,  # rgb\n                    opacities = opacity[filter_idx],\n                    scales = scales[filter_idx] * rescale_factor,\n                    rotations = rotations[filter_idx],\n                    cov3D_precomp = cov3D_precomp)\n                rendered_cluster = torch.cat((rendered_cluster, rendered_cluster2), dim=0)\n\n            # alpha --> mask\n            if cluster_silhouette.max() > 0.8:\n                cluster_occur[idx] = True\n                rendered_clusters.append(rendered_cluster)\n                rendered_cluster_silhouettes.append(cluster_silhouette)\n        if len(rendered_cluster_silhouettes) != 0:\n            rendered_cluster_silhouettes = torch.vstack(rendered_cluster_silhouettes)\n    else:\n        rendered_clusters, rendered_cluster_silhouettes = None, None\n\n\n    # ###############################################################\n    # [Stage 2.2 & Stage 3] render (fine) cluster-level feat map    #\n    #   - rendered_leaf_clusters: feat map of the fine clusters     #\n    #   - rendered_leaf_cluster_silhouettes: fine cluster mask      #\n    #   - occured_leaf_id: visible fine cluster id                  #\n    # ###############################################################\n    if leaf_cluster_idx is not None and leaf_cluster_idx.numel() > 0:\n        ins_feat = (pc.get_ins_feat(origin=origin_feat) + 1) / 2   # pseudo -> norm, else -> raw\n        # todo: rescale\n        scale_filter = (scales < 0.1).all(dim=1)\n        # scale_filter = (scales < 0.1).all(dim=1) & (opacity > 0.1).squeeze(-1)\n        re_scale_factor = torch.ones_like(opacity)  # not used\n\n        # determine the fine cluster ID range (lerf_range) based on the coarse cluster ID (selected_leaf_id).\n        # root_num = 64   # todo modify\n        # leaf_num = 5    # todo modify\n        rendered_leaf_clusters = []\n        
rendered_leaf_cluster_silhouettes = []\n        occured_leaf_id = []\n        if selected_leaf_id is None:\n            if selected_root_id is not None:\n                start_leaf = selected_root_id * leaf_num   # todo 10\n                end_leaf = start_leaf + leaf_num   # todo 10\n            else:\n                start_leaf = 0\n                end_leaf = root_num * leaf_num  # todo 64 * 10\n            lerf_range = range(start_leaf, end_leaf)\n        else:\n            lerf_range = selected_leaf_id.tolist()\n        for _, leaf_idx in enumerate(lerf_range):   # render each fine cluster\n            # ignore the invisible clusters\n            if viewpoint_camera.bClusterOccur is not None and viewpoint_camera.bClusterOccur[selected_root_id] == False:\n                continue\n\n            if selected_leaf_id is None:\n                filter_idx = leaf_cluster_idx == leaf_idx     # Render only the idx-th fine cluster\n                # filter_idx = labels != value      # remove the idx-th fine cluster\n            else:\n                filter_idx = (leaf_cluster_idx.unsqueeze(1) == selected_leaf_id).any(dim=1)\n\n            # pre-mask\n            if pre_mask is not None:\n                filter_idx = filter_idx & pre_mask\n\n            filter_idx = filter_idx & viewed_pts\n            # filter\n            if better_vis:\n                filter_idx = filter_idx & scale_filter\n                if filter_idx.sum() < 100:\n                    continue\n            \n            # TODO post process (for 3D object selection)\n            # pre_count = filter_idx.sum()\n            max_time = 5\n            if post_process and max_time > 0:\n                nearest_k_distance = pytorch3d.ops.knn_points(\n                    means3D[filter_idx].unsqueeze(0),\n                    means3D[filter_idx].unsqueeze(0),\n                    # K=int(filter_idx.sum()**0.5),\n                    K=int(filter_idx.sum()**0.5),\n                ).dists\n                
mean_nearest_k_distance, std_nearest_k_distance = nearest_k_distance.mean(), nearest_k_distance.std()\n                # print(std_nearest_k_distance, \"std_nearest_k_distance\")\n\n                mask = nearest_k_distance.mean(dim = -1) < mean_nearest_k_distance + std_nearest_k_distance\n                # mask = nearest_k_distance.mean(dim = -1) < mean_nearest_k_distance + 0.1 * std_nearest_k_distance\n\n                mask = mask.squeeze()\n                if filter_idx is not None:\n                    filter_idx[filter_idx != 0] = mask\n                max_time -= 1\n            \n            if filter_idx.sum() < 10:\n                continue\n\n            # record the fine cluster id appears in the current view.\n            occured_leaf_id.append(leaf_idx)\n\n            # note: render cluster rgb or feat.\n            if seg_rgb:\n                _shs = shs[filter_idx]\n                _colors_precomp1 = None\n                _colors_precomp2 = None\n            else:\n                _shs = None\n                _colors_precomp1 = ins_feat[:, :3][filter_idx]\n                _colors_precomp2 = ins_feat[:, 3:][filter_idx]\n            \n            rendered_leaf_cluster, _, _, leaf_cluster_silhouette = rasterizer(\n                means3D = means3D[filter_idx],\n                means2D = means2D[filter_idx],\n                shs = _shs,                          # rgb or feat\n                colors_precomp = _colors_precomp1,   # rgb or feat\n                opacities = opacity[filter_idx],\n                scales = (scales * re_scale_factor)[filter_idx],\n                rotations = rotations[filter_idx],\n                cov3D_precomp = cov3D_precomp)\n            if ins_feat.shape[-1] > 3:\n                rendered_leaf_cluster2, _, _, _ = rasterizer(\n                    means3D = means3D[filter_idx],\n                    means2D = means2D[filter_idx],\n                    shs = _shs,                          # rgb or feat\n                    
colors_precomp = _colors_precomp2,   # rgb or feat\n                    opacities = opacity[filter_idx],\n                    scales = (scales * re_scale_factor)[filter_idx],\n                    rotations = rotations[filter_idx],\n                    cov3D_precomp = cov3D_precomp)\n                rendered_leaf_cluster = torch.cat((rendered_leaf_cluster, rendered_leaf_cluster2), dim=0)\n            rendered_leaf_clusters.append(rendered_leaf_cluster)\n            rendered_leaf_cluster_silhouettes.append(leaf_cluster_silhouette)\n            if selected_leaf_id is not None and len(rendered_leaf_clusters) > 0:\n                break\n        if len(rendered_leaf_cluster_silhouettes) != 0:\n            rendered_leaf_cluster_silhouettes = torch.vstack(rendered_leaf_cluster_silhouettes)\n    else:\n        rendered_leaf_clusters = None\n        rendered_leaf_cluster_silhouettes =  None\n        occured_leaf_id = None\n\n    # Those Gaussians that were frustum culled or had a radius of 0 were not visible.\n    # They will be excluded from value updates used in the splitting criteria.\n    return {\"render\": rendered_image,\n            \"alpha\": rendered_alpha,\n            \"depth\": rendered_depth,    # not used\n            \"silhouette\": silhouette,\n            \"ins_feat\": rendered_ins_feat,          # image-level feat map\n            \"cluster_imgs\": rendered_clusters,      # coarse cluster feat map/image\n            \"cluster_silhouettes\": rendered_cluster_silhouettes,    # coarse cluster mask\n            \"leaf_clusters_imgs\": rendered_leaf_clusters,           # fine cluster feat map/image\n            \"leaf_cluster_silhouettes\": rendered_leaf_cluster_silhouettes,  # fine cluster mask\n            \"occured_leaf_id\": occured_leaf_id,     # fine cluster\n            \"cluster_occur\": cluster_occur,         # coarse cluster\n            \"viewspace_points\": screenspace_points,\n            \"visibility_filter\" : radii > 0,\n            \"radii\": 
radii}"
  },
  {
    "path": "gaussian_renderer/network_gui.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport torch\nimport traceback\nimport socket\nimport json\nfrom scene.cameras import MiniCam\n\nhost = \"127.0.0.1\"\nport = 6009\n\nconn = None\naddr = None\n\nlistener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)\n\ndef init(wish_host, wish_port):\n    global host, port, listener\n    host = wish_host\n    port = wish_port\n    listener.bind((host, port))\n    listener.listen()\n    listener.settimeout(0)\n\ndef try_connect():\n    global conn, addr, listener\n    try:\n        conn, addr = listener.accept()\n        print(f\"\\nConnected by {addr}\")\n        conn.settimeout(None)\n    except Exception as inst:\n        pass\n            \ndef read():\n    global conn\n    messageLength = conn.recv(4)\n    messageLength = int.from_bytes(messageLength, 'little')\n    message = conn.recv(messageLength)\n    return json.loads(message.decode(\"utf-8\"))\n\ndef send(message_bytes, verify):\n    global conn\n    if message_bytes != None:\n        conn.sendall(message_bytes)\n    conn.sendall(len(verify).to_bytes(4, 'little'))\n    conn.sendall(bytes(verify, 'ascii'))\n\ndef receive():\n    message = read()\n\n    width = message[\"resolution_x\"]\n    height = message[\"resolution_y\"]\n\n    if width != 0 and height != 0:\n        try:\n            do_training = bool(message[\"train\"])\n            fovy = message[\"fov_y\"]\n            fovx = message[\"fov_x\"]\n            znear = message[\"z_near\"]\n            zfar = message[\"z_far\"]\n            do_shs_python = bool(message[\"shs_python\"])\n            do_rot_scale_python = bool(message[\"rot_scale_python\"])\n            keep_alive = bool(message[\"keep_alive\"])\n            
scaling_modifier = message[\"scaling_modifier\"]\n            world_view_transform = torch.reshape(torch.tensor(message[\"view_matrix\"]), (4, 4)).cuda()\n            world_view_transform[:,1] = -world_view_transform[:,1]\n            world_view_transform[:,2] = -world_view_transform[:,2]\n            full_proj_transform = torch.reshape(torch.tensor(message[\"view_projection_matrix\"]), (4, 4)).cuda()\n            full_proj_transform[:,1] = -full_proj_transform[:,1]\n            custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform)\n        except Exception as e:\n            print(\"\")\n            traceback.print_exc()\n            raise e\n        return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier\n    else:\n        return None, None, None, None, None, None"
  },
  {
    "path": "lpipsPyTorch/__init__.py",
    "content": "import torch\n\nfrom .modules.lpips import LPIPS\n\n\ndef lpips(x: torch.Tensor,\n          y: torch.Tensor,\n          net_type: str = 'alex',\n          version: str = '0.1'):\n    r\"\"\"Function that measures\n    Learned Perceptual Image Patch Similarity (LPIPS).\n\n    Arguments:\n        x, y (torch.Tensor): the input tensors to compare.\n        net_type (str): the network type to compare the features: \n                        'alex' | 'squeeze' | 'vgg'. Default: 'alex'.\n        version (str): the version of LPIPS. Default: 0.1.\n    \"\"\"\n    device = x.device\n    criterion = LPIPS(net_type, version).to(device)\n    return criterion(x, y)\n"
  },
  {
    "path": "lpipsPyTorch/modules/lpips.py",
    "content": "import torch\nimport torch.nn as nn\n\nfrom .networks import get_network, LinLayers\nfrom .utils import get_state_dict\n\n\nclass LPIPS(nn.Module):\n    r\"\"\"Creates a criterion that measures\n    Learned Perceptual Image Patch Similarity (LPIPS).\n\n    Arguments:\n        net_type (str): the network type to compare the features: \n                        'alex' | 'squeeze' | 'vgg'. Default: 'alex'.\n        version (str): the version of LPIPS. Default: 0.1.\n    \"\"\"\n    def __init__(self, net_type: str = 'alex', version: str = '0.1'):\n\n        assert version in ['0.1'], 'v0.1 is only supported now'\n\n        super(LPIPS, self).__init__()\n\n        # pretrained network\n        self.net = get_network(net_type)\n\n        # linear layers\n        self.lin = LinLayers(self.net.n_channels_list)\n        self.lin.load_state_dict(get_state_dict(net_type, version))\n\n    def forward(self, x: torch.Tensor, y: torch.Tensor):\n        feat_x, feat_y = self.net(x), self.net(y)\n\n        diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]\n        res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]\n\n        return torch.sum(torch.cat(res, 0), 0, True)\n"
  },
  {
    "path": "lpipsPyTorch/modules/networks.py",
    "content": "from typing import Sequence\n\nfrom itertools import chain\n\nimport torch\nimport torch.nn as nn\nfrom torchvision import models\n\nfrom .utils import normalize_activation\n\n\ndef get_network(net_type: str):\n    if net_type == 'alex':\n        return AlexNet()\n    elif net_type == 'squeeze':\n        return SqueezeNet()\n    elif net_type == 'vgg':\n        return VGG16()\n    else:\n        raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')\n\n\nclass LinLayers(nn.ModuleList):\n    def __init__(self, n_channels_list: Sequence[int]):\n        super(LinLayers, self).__init__([\n            nn.Sequential(\n                nn.Identity(),\n                nn.Conv2d(nc, 1, 1, 1, 0, bias=False)\n            ) for nc in n_channels_list\n        ])\n\n        for param in self.parameters():\n            param.requires_grad = False\n\n\nclass BaseNet(nn.Module):\n    def __init__(self):\n        super(BaseNet, self).__init__()\n\n        # register buffer\n        self.register_buffer(\n            'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])\n        self.register_buffer(\n            'std', torch.Tensor([.458, .448, .450])[None, :, None, None])\n\n    def set_requires_grad(self, state: bool):\n        for param in chain(self.parameters(), self.buffers()):\n            param.requires_grad = state\n\n    def z_score(self, x: torch.Tensor):\n        return (x - self.mean) / self.std\n\n    def forward(self, x: torch.Tensor):\n        x = self.z_score(x)\n\n        output = []\n        for i, (_, layer) in enumerate(self.layers._modules.items(), 1):\n            x = layer(x)\n            if i in self.target_layers:\n                output.append(normalize_activation(x))\n            if len(output) == len(self.target_layers):\n                break\n        return output\n\n\nclass SqueezeNet(BaseNet):\n    def __init__(self):\n        super(SqueezeNet, self).__init__()\n\n        self.layers = 
models.squeezenet1_1(weights=models.SqueezeNet1_1_Weights.IMAGENET1K_V1).features\n        self.target_layers = [2, 5, 8, 10, 11, 12, 13]\n        self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]\n\n        self.set_requires_grad(False)\n\n\nclass AlexNet(BaseNet):\n    def __init__(self):\n        super(AlexNet, self).__init__()\n\n        self.layers = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1).features\n        self.target_layers = [2, 5, 8, 10, 12]\n        self.n_channels_list = [64, 192, 384, 256, 256]\n\n        self.set_requires_grad(False)\n\n\nclass VGG16(BaseNet):\n    def __init__(self):\n        super(VGG16, self).__init__()\n\n        self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features\n        self.target_layers = [4, 9, 16, 23, 30]\n        self.n_channels_list = [64, 128, 256, 512, 512]\n\n        self.set_requires_grad(False)\n"
  },
  {
    "path": "lpipsPyTorch/modules/utils.py",
    "content": "from collections import OrderedDict\n\nimport torch\n\n\ndef normalize_activation(x, eps=1e-10):\n    norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))\n    return x / (norm_factor + eps)\n\n\ndef get_state_dict(net_type: str = 'alex', version: str = '0.1'):\n    # build url\n    url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \\\n        + f'master/lpips/weights/v{version}/{net_type}.pth'\n\n    # download\n    old_state_dict = torch.hub.load_state_dict_from_url(\n        url, progress=True,\n        map_location=None if torch.cuda.is_available() else torch.device('cpu')\n    )\n\n    # rename keys\n    new_state_dict = OrderedDict()\n    for key, val in old_state_dict.items():\n        new_key = key\n        new_key = new_key.replace('lin', '')\n        new_key = new_key.replace('model.', '')\n        new_state_dict[new_key] = val\n\n    return new_state_dict\n"
  },
  {
    "path": "metrics.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nfrom pathlib import Path\nimport os\nfrom PIL import Image\nimport torch\nimport torchvision.transforms.functional as tf\nfrom utils.loss_utils import ssim\nfrom lpipsPyTorch import lpips\nimport json\nfrom tqdm import tqdm\nfrom utils.image_utils import psnr\nfrom argparse import ArgumentParser\n\ndef readImages(renders_dir, gt_dir):\n    renders = []\n    gts = []\n    image_names = []\n    for fname in os.listdir(renders_dir):\n        render = Image.open(renders_dir / fname)\n        gt = Image.open(gt_dir / fname)\n        renders.append(tf.to_tensor(render).unsqueeze(0)[:, :3, :, :].cuda())\n        gts.append(tf.to_tensor(gt).unsqueeze(0)[:, :3, :, :].cuda())\n        image_names.append(fname)\n    return renders, gts, image_names\n\ndef evaluate(model_paths):\n\n    full_dict = {}\n    per_view_dict = {}\n    full_dict_polytopeonly = {}\n    per_view_dict_polytopeonly = {}\n    print(\"\")\n\n    for scene_dir in model_paths:\n        try:\n            print(\"Scene:\", scene_dir)\n            full_dict[scene_dir] = {}\n            per_view_dict[scene_dir] = {}\n            full_dict_polytopeonly[scene_dir] = {}\n            per_view_dict_polytopeonly[scene_dir] = {}\n\n            test_dir = Path(scene_dir) / \"test\"\n\n            for method in os.listdir(test_dir):\n                print(\"Method:\", method)\n\n                full_dict[scene_dir][method] = {}\n                per_view_dict[scene_dir][method] = {}\n                full_dict_polytopeonly[scene_dir][method] = {}\n                per_view_dict_polytopeonly[scene_dir][method] = {}\n\n                method_dir = test_dir / method\n                gt_dir = method_dir/ \"gt\"\n         
       renders_dir = method_dir / \"renders\"\n                renders, gts, image_names = readImages(renders_dir, gt_dir)\n\n                ssims = []\n                psnrs = []\n                lpipss = []\n\n                for idx in tqdm(range(len(renders)), desc=\"Metric evaluation progress\"):\n                    ssims.append(ssim(renders[idx], gts[idx]))\n                    psnrs.append(psnr(renders[idx], gts[idx]))\n                    lpipss.append(lpips(renders[idx], gts[idx], net_type='vgg'))\n\n                print(\"  SSIM : {:>12.7f}\".format(torch.tensor(ssims).mean(), \".5\"))\n                print(\"  PSNR : {:>12.7f}\".format(torch.tensor(psnrs).mean(), \".5\"))\n                print(\"  LPIPS: {:>12.7f}\".format(torch.tensor(lpipss).mean(), \".5\"))\n                print(\"\")\n\n                full_dict[scene_dir][method].update({\"SSIM\": torch.tensor(ssims).mean().item(),\n                                                        \"PSNR\": torch.tensor(psnrs).mean().item(),\n                                                        \"LPIPS\": torch.tensor(lpipss).mean().item()})\n                per_view_dict[scene_dir][method].update({\"SSIM\": {name: ssim for ssim, name in zip(torch.tensor(ssims).tolist(), image_names)},\n                                                            \"PSNR\": {name: psnr for psnr, name in zip(torch.tensor(psnrs).tolist(), image_names)},\n                                                            \"LPIPS\": {name: lp for lp, name in zip(torch.tensor(lpipss).tolist(), image_names)}})\n\n            with open(scene_dir + \"/results.json\", 'w') as fp:\n                json.dump(full_dict[scene_dir], fp, indent=True)\n            with open(scene_dir + \"/per_view.json\", 'w') as fp:\n                json.dump(per_view_dict[scene_dir], fp, indent=True)\n        except:\n            print(\"Unable to compute metrics for model\", scene_dir)\n\nif __name__ == \"__main__\":\n    device = 
torch.device(\"cuda:0\")\n    torch.cuda.set_device(device)\n\n    # Set up command line argument parser\n    parser = ArgumentParser(description=\"Metrics evaluation script parameters\")\n    parser.add_argument('--model_paths', '-m', required=True, nargs=\"+\", type=str, default=[])\n    args = parser.parse_args()\n    evaluate(args.model_paths)\n"
  },
  {
    "path": "render.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport torch\nimport torch.nn.functional as F\nfrom scene import Scene\nimport os\nfrom tqdm import tqdm\nfrom os import makedirs\nfrom gaussian_renderer import render\nimport torchvision\nfrom utils.general_utils import safe_state\nfrom argparse import ArgumentParser\nfrom arguments import ModelParams, PipelineParams, get_combined_args\nfrom gaussian_renderer import GaussianModel\nimport numpy as np\nfrom utils.opengs_utlis import get_SAM_mask_and_feat, load_code_book\n\n# Randomly initialize 300 colors for visualizing the SAM mask. [OpenGaussian]\nnp.random.seed(42)\ncolors_defined = np.random.randint(100, 256, size=(300, 3))\ncolors_defined[0] = np.array([0, 0, 0]) # Ignore the mask ID of -1 and set it to black.\ncolors_defined = torch.from_numpy(colors_defined)\n\ndef render_set(model_path, name, iteration, views, gaussians, pipeline, background):\n    render_path = os.path.join(model_path, name, \"ours_{}\".format(iteration), \"renders\")\n    gts_path = os.path.join(model_path, name, \"ours_{}\".format(iteration), \"gt\")\n\n    render_ins_feat_path1 = os.path.join(model_path, name, \"ours_{}\".format(iteration), \"renders_ins_feat1\")\n    render_ins_feat_path2 = os.path.join(model_path, name, \"ours_{}\".format(iteration), \"renders_ins_feat2\")\n    gt_sam_mask_path = os.path.join(model_path, name, \"ours_{}\".format(iteration), \"gt_sam_mask\")\n\n    makedirs(render_path, exist_ok=True)\n    makedirs(gts_path, exist_ok=True)\n    makedirs(render_ins_feat_path1, exist_ok=True)\n    makedirs(render_ins_feat_path2, exist_ok=True)\n    makedirs(gt_sam_mask_path, exist_ok=True)\n\n    # load codebook\n    root_code_book_path = os.path.join(model_path, 
\"point_cloud\", f'iteration_{iteration}', \"root_code_book\")\n    leaf_code_book_path = os.path.join(model_path, \"point_cloud\", f'iteration_{iteration}', \"leaf_code_book\")\n    if os.path.exists(os.path.join(root_code_book_path, 'kmeans_inds.bin')):\n        root_code_book, root_cluster_indices = load_code_book(root_code_book_path)\n        root_cluster_indices = torch.from_numpy(root_cluster_indices).cuda()\n    if os.path.exists(os.path.join(leaf_code_book_path, 'kmeans_inds.bin')):\n        leaf_code_book, leaf_cluster_indices = load_code_book(leaf_code_book_path)\n        leaf_cluster_indices = torch.from_numpy(leaf_cluster_indices).cuda()\n    else:\n        leaf_cluster_indices = None\n\n    # render\n    for idx, view in enumerate(tqdm(views, desc=\"Rendering progress\")):\n        render_pkg = render(view, gaussians, pipeline, background, iteration, rescale=False)\n\n        # RGB\n        rendering = render_pkg[\"render\"]\n        gt = view.original_image[0:3, :, :]\n\n        # ins_feat\n        rendered_ins_feat = render_pkg[\"ins_feat\"]\n        gt_sam_mask = view.original_sam_mask.cuda()    # [4, H, W]\n\n        # Rendered RGB\n        torchvision.utils.save_image(rendering, os.path.join(render_path, view.image_name + \".png\"))\n        # GT RGB\n        torchvision.utils.save_image(gt, os.path.join(gts_path, view.image_name + \".png\"))\n\n        # ins_feat\n        torchvision.utils.save_image(rendered_ins_feat[:3,:,:], os.path.join(render_ins_feat_path1, view.image_name + \"_1.png\"))\n        torchvision.utils.save_image(rendered_ins_feat[3:6,:,:], os.path.join(render_ins_feat_path2, view.image_name + \"_2.png\"))\n\n        # NOTE get SAM id, mask bool, mask_feat, invalid pix\n        mask_id, _, _, _ = \\\n            get_SAM_mask_and_feat(gt_sam_mask, level=0, original_mask_feat=view.original_mask_feat)\n        # mask visualization\n        mask_color_rand = 
colors_defined[mask_id.detach().cpu().type(torch.int64)].type(torch.float64)\n        mask_color_rand = mask_color_rand.permute(2, 0, 1)\n        torchvision.utils.save_image(mask_color_rand/255.0, os.path.join(gt_sam_mask_path, view.image_name + \".png\"))\n\ndef render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool):\n    with torch.no_grad():\n        gaussians = GaussianModel(dataset.sh_degree)\n        scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)\n\n        bg_color = [1,1,1] if dataset.white_background else [0, 0, 0]\n        background = torch.tensor(bg_color, dtype=torch.float32, device=\"cuda\")\n\n        if not skip_train:\n             render_set(dataset.model_path, \"train\", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background)\n\n        if not skip_test:\n             render_set(dataset.model_path, \"test\", scene.loaded_iter, scene.getTestCameras(), gaussians, pipeline, background)\n\nif __name__ == \"__main__\":\n    # Set up command line argument parser\n    parser = ArgumentParser(description=\"Testing script parameters\")\n    model = ModelParams(parser, sentinel=True)\n    pipeline = PipelineParams(parser)\n    parser.add_argument(\"--iteration\", default=-1, type=int)\n    parser.add_argument(\"--skip_train\", action=\"store_true\")\n    parser.add_argument(\"--skip_test\", action=\"store_true\")\n    parser.add_argument(\"--quiet\", action=\"store_true\")\n    args = get_combined_args(parser)\n    print(\"Rendering \" + args.model_path)\n\n    # Initialize system state (RNG)\n    safe_state(args.quiet)\n\n    render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test)"
  },
  {
    "path": "render_lerf_by_text.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport torch\nimport torch.nn.functional as F\nfrom scene import Scene\nimport os\nfrom tqdm import tqdm\nfrom os import makedirs\nfrom gaussian_renderer import render\nimport torchvision\nfrom utils.general_utils import safe_state\nfrom argparse import ArgumentParser\nfrom arguments import ModelParams, PipelineParams, get_combined_args\nfrom gaussian_renderer import GaussianModel\nimport numpy as np\nimport json\nfrom utils.opengs_utlis import mask_feature_mean, get_SAM_mask_and_feat, load_code_book\n\nnp.random.seed(42)\ncolors_defined = np.random.randint(100, 256, size=(300, 3))\ncolors_defined[0] = np.array([0, 0, 0]) # Ignore the mask ID of -1 and set it to black.\ncolors_defined = torch.from_numpy(colors_defined)\n\ndef render_set(model_path, name, iteration, views, gaussians, pipeline, background, scene_name):\n    render_path = os.path.join(model_path, name, \"ours_{}\".format(iteration), \"renders\")\n    gts_path = os.path.join(model_path, name, \"ours_{}\".format(iteration), \"gt\")\n\n    render_ins_feat_path = os.path.join(model_path, name, \"ours_{}\".format(iteration), \"renders_ins_feat\")\n    gt_sam_mask_path = os.path.join(model_path, name, \"ours_{}\".format(iteration), \"gt_sam_mask\")\n\n    makedirs(render_path, exist_ok=True)\n    makedirs(gts_path, exist_ok=True)\n    makedirs(render_ins_feat_path, exist_ok=True)\n    makedirs(gt_sam_mask_path, exist_ok=True)\n\n    # load codebook\n    root_code_book, root_cluster_indices = load_code_book(os.path.join(model_path, \"point_cloud\", \\\n        f'iteration_{iteration}', \"root_code_book\"))\n    leaf_code_book, leaf_cluster_indices = load_code_book(os.path.join(model_path, 
\"point_cloud\", \\\n        f'iteration_{iteration}', \"leaf_code_book\"))\n    root_cluster_indices = torch.from_numpy(root_cluster_indices).cuda()\n    leaf_cluster_indices = torch.from_numpy(leaf_cluster_indices).cuda()\n    # counts = torch.bincount(torch.from_numpy(cluster_indices), minlength=64)\n\n    # load the saved codebook(leaf id) and instance-level language feature\n    # 'leaf_feat', 'leaf_acore', 'occu_count', 'leaf_ind'\n    mapping_file = os.path.join(model_path, \"cluster_lang.npz\")\n    saved_data = np.load(mapping_file)\n    leaf_lang_feat = torch.from_numpy(saved_data[\"leaf_feat.npy\"]).cuda()    # [num_leaf=k1*k2, 512] cluster lang feat\n    leaf_score = torch.from_numpy(saved_data[\"leaf_score.npy\"]).cuda()       # [num_leaf=k1*k2] cluster score\n    leaf_occu_count = torch.from_numpy(saved_data[\"occu_count.npy\"]).cuda()  # [num_leaf=k1*k2] \n    leaf_ind = torch.from_numpy(saved_data[\"leaf_ind.npy\"]).cuda()           # [num_pts] fine id\n    leaf_lang_feat[leaf_occu_count < 5] *= 0.0      # Filter out clusters that occur too infrequently.\n    leaf_cluster_indices = leaf_ind\n    \n    root_num = root_cluster_indices.max() + 1\n    leaf_num = leaf_lang_feat.shape[0] / root_num\n\n    # text feature\n    with open('assets/text_features.json', 'r') as f:\n        data_loaded = json.load(f)\n    all_texts = list(data_loaded.keys())\n    text_features = torch.from_numpy(np.array(list(data_loaded.values()))).to(torch.float32)  # [num_text, 512]\n\n    scene_texts = {\n        \"waldo_kitchen\": ['Stainless steel pots', 'dark cup', 'refrigerator', 'frog cup', 'pot', 'spatula', 'plate', \\\n                'spoon', 'toaster', 'ottolenghi', 'plastic ladle', 'sink', 'ketchup', 'cabinet', 'red cup', \\\n                'pour-over vessel', 'knife', 'yellow desk'],\n        \"ramen\": ['nori', 'sake cup', 'kamaboko', 'corn', 'spoon', 'egg', 'onion segments', 'plate', \\\n                'napkin', 'bowl', 'glass of water', 'hand', 'chopsticks', 
'wavy noodles'],\n        \"figurines\": ['jake', 'pirate hat', 'pikachu', 'rubber duck with hat', 'porcelain hand', \\\n                    'red apple', 'tesla door handle', 'waldo', 'bag', 'toy cat statue', 'miffy', \\\n                    'green apple', 'pumpkin', 'rubics cube', 'old camera', 'rubber duck with buoy', \\\n                    'red toy chair', 'pink ice cream', 'spatula', 'green toy chair', 'toy elephant'],\n        \"teatime\": ['sheep', 'yellow pouf', 'stuffed bear', 'coffee mug', 'tea in a glass', 'apple', \n                'coffee', 'hooves', 'bear nose', 'dall-e brand', 'plate', 'paper napkin', 'three cookies', \\\n                'bag of cookies']\n    }\n    # note: query text\n    target_text = scene_texts[scene_name]\n\n    query_text_feats = torch.zeros(len(target_text), 512).cuda()\n    for i, text in enumerate(target_text):\n        feat = text_features[all_texts.index(text)].unsqueeze(0)\n        query_text_feats[i] = feat\n\n    for t_i, text_feat in enumerate(query_text_feats):\n        # if target_text[t_i] != \"old camera\":\n        #     continue\n\n        print(f\"rendering the {t_i+1}-th query of {len(target_text)} texts: {target_text[t_i]}\")\n        # compute cosine similarity\n        text_feat = F.normalize(text_feat.unsqueeze(0), dim=1, p=2)  \n        leaf_lang_feat = F.normalize(leaf_lang_feat, dim=1, p=2)  \n        cosine_similarity = torch.matmul(text_feat, leaf_lang_feat.transpose(0, 1))\n        max_id = torch.argmax(cosine_similarity, dim=-1) # [cluster_num]\n        text_leaf_indices = max_id\n\n        top_values, top_indices = torch.topk(cosine_similarity, 10)\n        for candidate_id in top_indices[0][1:]:\n            if candidate_id - max_id < leaf_num:  # TODO !!!!!!\n                max_feat = leaf_code_book['ins_feat'][max_id]\n                candi_feat = leaf_code_book['ins_feat'][candidate_id]\n                distances = torch.norm(max_feat - candi_feat, dim=1)\n                if distances < 0.9:\n 
                   text_leaf_indices = torch.cat([text_leaf_indices, candidate_id.unsqueeze(0)])\n\n        # render\n        for idx, view in enumerate(tqdm(views, desc=\"Rendering progress\")):\n            # note: evaluation frame\n            scene_gt_frames = {\n                \"waldo_kitchen\": [\"frame_00053\", \"frame_00066\", \"frame_00089\", \"frame_00140\", \"frame_00154\"],\n                \"ramen\": [\"frame_00006\", \"frame_00024\", \"frame_00060\", \"frame_00065\", \"frame_00081\", \"frame_00119\", \"frame_00128\"],\n                \"figurines\": [\"frame_00041\", \"frame_00105\", \"frame_00152\", \"frame_00195\"],\n                \"teatime\": [\"frame_00002\", \"frame_00025\", \"frame_00043\", \"frame_00107\", \"frame_00129\", \"frame_00140\"]\n            }\n            candidate_frames = scene_gt_frames[scene_name]\n            \n            if  view.image_name not in candidate_frames:\n                continue\n\n            render_pkg = render(view, gaussians, pipeline, background, iteration, rescale=False)\n            # RGB\n            rendering = render_pkg[\"render\"]\n            gt = view.original_image[0:3, :, :]\n\n            # ins_feat\n            rendered_ins_feat = render_pkg[\"ins_feat\"]\n            gt_sam_mask = view.original_sam_mask.cuda()    # [4, H, W]\n\n            # RGB\n            torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + \".png\"))\n            torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + \".png\"))\n\n            # ins_feat\n            torchvision.utils.save_image(rendered_ins_feat[:3,:,:], os.path.join(render_ins_feat_path, '{0:05d}'.format(idx) + \"_1.png\"))\n            torchvision.utils.save_image(rendered_ins_feat[3:6,:,:], os.path.join(render_ins_feat_path, '{0:05d}'.format(idx) + \"_2.png\"))\n\n            # NOTE get SAM id, mask bool, mask_feat, invalid pix\n            mask_id, mask_bool, mask_feat, invalid_pix = \\\n   
             get_SAM_mask_and_feat(gt_sam_mask, level=3, original_mask_feat=view.original_mask_feat)\n            \n            # sam mask\n            mask_color_rand = colors_defined[mask_id.detach().cpu().type(torch.int64)].type(torch.float64)\n            mask_color_rand = mask_color_rand.permute(2, 0, 1)\n            torchvision.utils.save_image(mask_color_rand/255.0, os.path.join(gt_sam_mask_path, '{0:05d}'.format(idx) + \".png\"))\n            \n            # render target object\n            render_pkg = render(view, gaussians, pipeline, background, iteration,\n                                rescale=False,                # whether to re-scale the Gaussian scale\n                                # cluster_idx=leaf_cluster_indices,     # root id\n                                leaf_cluster_idx=leaf_cluster_indices,  # leaf id\n                                selected_leaf_id=text_leaf_indices.cuda(),  # leaf ids selected by the text query\n                                render_feat_map=False, \n                                render_cluster=False,\n                                better_vis=True,\n                                seg_rgb=True,\n                                post_process=True,\n                                root_num=root_num, leaf_num=leaf_num)\n            rendered_cluster_imgs = render_pkg[\"leaf_clusters_imgs\"]\n            occured_leaf_id = render_pkg[\"occured_leaf_id\"]\n            rendered_leaf_cluster_silhouettes = render_pkg[\"leaf_cluster_silhouettes\"]\n\n            render_cluster_path = os.path.join(model_path, name, \"ours_{}\".format(iteration), \"renders_cluster\")\n            render_cluster_silhouette_path = os.path.join(model_path, name, \"ours_{}\".format(iteration), \"renders_cluster_silhouette\")\n            makedirs(render_cluster_path, exist_ok=True)\n            makedirs(render_cluster_silhouette_path, exist_ok=True)\n            for i, img in enumerate(rendered_cluster_imgs):\n                # save object RGB\n      
          torchvision.utils.save_image(img[:3,:,:], os.path.join(render_cluster_path, \\\n                    view.image_name + f\"_{target_text[t_i]}.png\"))\n                # save object mask\n                cluster_silhouette = rendered_leaf_cluster_silhouettes[i] > 0.7\n                torchvision.utils.save_image(cluster_silhouette.to(torch.float32), os.path.join(render_cluster_silhouette_path, \\\n                    view.image_name + f\"_{target_text[t_i]}.png\"))\n        \ndef render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool,\n                scene_name: str):\n    with torch.no_grad():\n        gaussians = GaussianModel(dataset.sh_degree)\n        scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)\n\n        # bg_color = [1,1,1] if dataset.white_background else [0, 0, 0]\n        bg_color = [1,1,1]\n        background = torch.tensor(bg_color, dtype=torch.float32, device=\"cuda\")\n\n        if not skip_train:\n             render_set(dataset.model_path, \"text2obj\", scene.loaded_iter, scene.getTrainCameras(), \n                        gaussians, pipeline, background, scene_name)\n        if not skip_test:\n             render_set(dataset.model_path, \"text2obj\", scene.loaded_iter, scene.getTestCameras(), \n                        gaussians, pipeline, background, scene_name)\n\nif __name__ == \"__main__\":\n    # Set up command line argument parser\n    parser = ArgumentParser(description=\"Testing script parameters\")\n    model = ModelParams(parser, sentinel=True)\n    pipeline = PipelineParams(parser)\n    parser.add_argument(\"--iteration\", default=-1, type=int)\n    parser.add_argument(\"--skip_train\", action=\"store_true\")\n    parser.add_argument(\"--skip_test\", action=\"store_true\")\n    parser.add_argument(\"--quiet\", action=\"store_true\")\n    parser.add_argument(\"--scene_name\", type=str, choices=[\"waldo_kitchen\", \"ramen\", \"figurines\", 
\"teatime\"],\n                        help=\"Specify the scene_name from: figurines, teatime, ramen, waldo_kitchen\")\n    args = get_combined_args(parser)\n    print(\"Rendering \" + args.model_path)\n\n    if not args.scene_name:\n        parser.error(\"The --scene_name argument is required and must be one of: waldo_kitchen, ramen, figurines, teatime\")\n\n    # Initialize system state (RNG)\n    safe_state(args.quiet)\n\n    render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test, args.scene_name)"
  },
  {
    "path": "scene/__init__.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport os\nimport random\nimport json\nfrom utils.system_utils import searchForMaxIteration\nfrom scene.dataset_readers import sceneLoadTypeCallbacks\nfrom scene.gaussian_model import GaussianModel\nfrom arguments import ModelParams\nfrom utils.camera_utils import cameraList_from_camInfos, camera_to_JSON\n\nclass Scene:\n\n    gaussians : GaussianModel\n\n    def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0]):\n        \"\"\"b\n        :param path: Path to colmap scene main folder.\n        \"\"\"\n        self.model_path = args.model_path\n        self.loaded_iter = None\n        self.gaussians = gaussians\n\n        if load_iteration:\n            if load_iteration == -1:\n                self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, \"point_cloud\"))\n            else:\n                self.loaded_iter = load_iteration\n            print(\"Loading trained model at iteration {}\".format(self.loaded_iter))\n\n        self.train_cameras = {}\n        self.test_cameras = {}\n\n        if os.path.exists(os.path.join(args.source_path, \"sparse\")):\n            scene_info = sceneLoadTypeCallbacks[\"Colmap\"](args.source_path, args.images, args.eval)\n        elif os.path.exists(os.path.join(args.source_path, \"transforms_train.json\")):\n            print(\"Found transforms_train.json file, assuming Blender data set!\")\n            scene_info = sceneLoadTypeCallbacks[\"Blender\"](args.source_path, args.white_background, args.eval)\n        else:\n            assert False, \"Could not recognize scene type!\"\n\n        if not self.loaded_iter:\n            with 
open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, \"input.ply\") , 'wb') as dest_file:\n                dest_file.write(src_file.read())\n            json_cams = []\n            camlist = []\n            if scene_info.test_cameras:\n                camlist.extend(scene_info.test_cameras)\n            if scene_info.train_cameras:\n                camlist.extend(scene_info.train_cameras)\n            for id, cam in enumerate(camlist):\n                json_cams.append(camera_to_JSON(id, cam))\n            with open(os.path.join(self.model_path, \"cameras.json\"), 'w') as file:\n                json.dump(json_cams, file)\n\n        if shuffle:\n            random.shuffle(scene_info.train_cameras)  # Multi-res consistent random shuffling\n            random.shuffle(scene_info.test_cameras)  # Multi-res consistent random shuffling\n\n        self.cameras_extent = scene_info.nerf_normalization[\"radius\"]\n\n        for resolution_scale in resolution_scales:\n            print(\"Resolution: \", resolution_scale)\n            print(\"Loading Training Cameras\")\n            self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args)\n            print(\"Loading Test Cameras\")\n            self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args)\n\n        if self.loaded_iter:\n            self.gaussians.load_ply(os.path.join(self.model_path,\n                                                           \"point_cloud\",\n                                                           \"iteration_\" + str(self.loaded_iter),\n                                                           \"point_cloud.ply\"))\n        else:\n            self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent)\n\n    def save(self, iteration, save_q=[]):\n        point_cloud_path = os.path.join(self.model_path, 
\"point_cloud/iteration_{}\".format(iteration))\n        self.gaussians.save_ply(os.path.join(point_cloud_path, \"point_cloud.ply\"), save_q)\n\n    def getTrainCameras(self, scale=1.0):\n        return self.train_cameras[scale]\n\n    def getTestCameras(self, scale=1.0):\n        return self.test_cameras[scale]"
  },
  {
    "path": "scene/cameras.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport torch\nfrom torch import nn\nimport numpy as np\nfrom utils.graphics_utils import getWorld2View2, getProjectionMatrix\n\nclass Camera(nn.Module):\n    def __init__(self, colmap_id, R, T, FoVx, FoVy, cx, cy, image, depth, gt_alpha_mask,\n                 gt_sam_mask, gt_mask_feat,\n                 image_name, uid,\n                 trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = \"cuda\"\n                 ):\n        super(Camera, self).__init__()\n\n        self.uid = uid\n        self.colmap_id = colmap_id\n        self.R = R\n        self.T = T\n        self.FoVx = FoVx\n        self.FoVy = FoVy\n        # modify -----\n        self.cx = cx\n        self.cy = cy\n        # modify -----\n        self.image_name = image_name\n\n        try:\n            self.data_device = torch.device(data_device)\n        except Exception as e:\n            print(e)\n            print(f\"[Warning] Custom device {data_device} failed, fallback to default cuda device\" )\n            self.data_device = torch.device(\"cuda\")\n\n        self.data_on_gpu = True     # note\n        self.original_image = image.clamp(0.0, 1.0).to(self.data_device)\n        # modify -----\n        self.original_mask = gt_alpha_mask.to(self.data_device) if gt_alpha_mask is not None else None\n        \n        # modify -----\n        self.original_sam_mask = gt_sam_mask.to(self.data_device) if gt_sam_mask is not None else None\n        self.original_mask_feat = gt_mask_feat.to(self.data_device) if gt_mask_feat is not None else None\n        self.pesudo_ins_feat = None\n        self.pesudo_mask_bool = None\n        self.cluster_masks = None\n        self.bClusterOccur = None\n\n      
  self.image_width = self.original_image.shape[2]\n        self.image_height = self.original_image.shape[1]\n\n        if gt_alpha_mask is not None:\n            self.original_image *= gt_alpha_mask.to(self.data_device)\n        else:\n            self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device)\n\n        self.zfar = 100.0\n        self.znear = 0.01\n\n        self.trans = trans\n        self.scale = scale\n\n        self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda()\n        self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda()\n        self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)\n        self.camera_center = self.world_view_transform.inverse()[3, :3]\n    \n    # modify -----\n    def to_gpu(self):\n        for attr_name in dir(self):\n            attr = getattr(self, attr_name)\n            if isinstance(attr, torch.Tensor) and not attr.is_cuda:\n                setattr(self, attr_name, attr.to('cuda'))\n        self.data_on_gpu = True\n\n    # modify -----\n    def to_cpu(self):\n        for attr_name in dir(self):\n            attr = getattr(self, attr_name)\n            if isinstance(attr, torch.Tensor) and attr.is_cuda:\n                setattr(self, attr_name, attr.to('cpu'))\n        self.data_on_gpu = False\n\nclass MiniCam:\n    def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform):\n        self.image_width = width\n        self.image_height = height    \n        self.FoVy = fovy\n        self.FoVx = fovx\n        self.znear = znear\n        self.zfar = zfar\n        self.world_view_transform = world_view_transform\n        self.full_proj_transform = full_proj_transform\n        view_inv = torch.inverse(self.world_view_transform)\n        
self.camera_center = view_inv[3][:3]\n\n"
  },
  {
    "path": "scene/colmap_loader.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport numpy as np\nimport collections\nimport struct\n\nCameraModel = collections.namedtuple(\n    \"CameraModel\", [\"model_id\", \"model_name\", \"num_params\"])\nCamera = collections.namedtuple(\n    \"Camera\", [\"id\", \"model\", \"width\", \"height\", \"params\"])\nBaseImage = collections.namedtuple(\n    \"Image\", [\"id\", \"qvec\", \"tvec\", \"camera_id\", \"name\", \"xys\", \"point3D_ids\"])\nPoint3D = collections.namedtuple(\n    \"Point3D\", [\"id\", \"xyz\", \"rgb\", \"error\", \"image_ids\", \"point2D_idxs\"])\nCAMERA_MODELS = {\n    CameraModel(model_id=0, model_name=\"SIMPLE_PINHOLE\", num_params=3),\n    CameraModel(model_id=1, model_name=\"PINHOLE\", num_params=4),\n    CameraModel(model_id=2, model_name=\"SIMPLE_RADIAL\", num_params=4),\n    CameraModel(model_id=3, model_name=\"RADIAL\", num_params=5),\n    CameraModel(model_id=4, model_name=\"OPENCV\", num_params=8),\n    CameraModel(model_id=5, model_name=\"OPENCV_FISHEYE\", num_params=8),\n    CameraModel(model_id=6, model_name=\"FULL_OPENCV\", num_params=12),\n    CameraModel(model_id=7, model_name=\"FOV\", num_params=5),\n    CameraModel(model_id=8, model_name=\"SIMPLE_RADIAL_FISHEYE\", num_params=4),\n    CameraModel(model_id=9, model_name=\"RADIAL_FISHEYE\", num_params=5),\n    CameraModel(model_id=10, model_name=\"THIN_PRISM_FISHEYE\", num_params=12)\n}\nCAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model)\n                         for camera_model in CAMERA_MODELS])\nCAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model)\n                           for camera_model in CAMERA_MODELS])\n\n\ndef qvec2rotmat(qvec):\n    return np.array([\n        [1 - 2 * 
qvec[2]**2 - 2 * qvec[3]**2,\n         2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],\n         2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]],\n        [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],\n         1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,\n         2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]],\n        [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],\n         2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],\n         1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]])\n\ndef rotmat2qvec(R):\n    Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat\n    K = np.array([\n        [Rxx - Ryy - Rzz, 0, 0, 0],\n        [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],\n        [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],\n        [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0\n    eigvals, eigvecs = np.linalg.eigh(K)\n    qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]\n    if qvec[0] < 0:\n        qvec *= -1\n    return qvec\n\nclass Image(BaseImage):\n    def qvec2rotmat(self):\n        return qvec2rotmat(self.qvec)\n\ndef read_next_bytes(fid, num_bytes, format_char_sequence, endian_character=\"<\"):\n    \"\"\"Read and unpack the next bytes from a binary file.\n    :param fid:\n    :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 
2, 6, 16, 30, etc.\n    :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.\n    :param endian_character: Any of {@, =, <, >, !}\n    :return: Tuple of read and unpacked values.\n    \"\"\"\n    data = fid.read(num_bytes)\n    return struct.unpack(endian_character + format_char_sequence, data)\n\ndef read_points3D_text(path):\n    \"\"\"\n    see: src/base/reconstruction.cc\n        void Reconstruction::ReadPoints3DText(const std::string& path)\n        void Reconstruction::WritePoints3DText(const std::string& path)\n    \"\"\"\n    xyzs = None\n    rgbs = None\n    errors = None\n    num_points = 0\n    with open(path, \"r\") as fid:\n        while True:\n            line = fid.readline()\n            if not line:\n                break\n            line = line.strip()\n            if len(line) > 0 and line[0] != \"#\":\n                num_points += 1\n\n\n    xyzs = np.empty((num_points, 3))\n    rgbs = np.empty((num_points, 3))\n    errors = np.empty((num_points, 1))\n    count = 0\n    with open(path, \"r\") as fid:\n        while True:\n            line = fid.readline()\n            if not line:\n                break\n            line = line.strip()\n            if len(line) > 0 and line[0] != \"#\":\n                elems = line.split()\n                xyz = np.array(tuple(map(float, elems[1:4])))\n                rgb = np.array(tuple(map(int, elems[4:7])))\n                error = np.array(float(elems[7]))\n                xyzs[count] = xyz\n                rgbs[count] = rgb\n                errors[count] = error\n                count += 1\n\n    return xyzs, rgbs, errors\n\ndef read_points3D_binary(path_to_model_file):\n    \"\"\"\n    see: src/base/reconstruction.cc\n        void Reconstruction::ReadPoints3DBinary(const std::string& path)\n        void Reconstruction::WritePoints3DBinary(const std::string& path)\n    \"\"\"\n\n\n    with open(path_to_model_file, \"rb\") as fid:\n        num_points = read_next_bytes(fid, 8, 
\"Q\")[0]\n\n        xyzs = np.empty((num_points, 3))\n        rgbs = np.empty((num_points, 3))\n        errors = np.empty((num_points, 1))\n\n        for p_id in range(num_points):\n            binary_point_line_properties = read_next_bytes(\n                fid, num_bytes=43, format_char_sequence=\"QdddBBBd\")\n            xyz = np.array(binary_point_line_properties[1:4])\n            rgb = np.array(binary_point_line_properties[4:7])\n            error = np.array(binary_point_line_properties[7])\n            track_length = read_next_bytes(\n                fid, num_bytes=8, format_char_sequence=\"Q\")[0]\n            track_elems = read_next_bytes(\n                fid, num_bytes=8*track_length,\n                format_char_sequence=\"ii\"*track_length)\n            xyzs[p_id] = xyz\n            rgbs[p_id] = rgb\n            errors[p_id] = error\n    return xyzs, rgbs, errors\n\ndef read_intrinsics_text(path):\n    \"\"\"\n    Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py\n    \"\"\"\n    cameras = {}\n    with open(path, \"r\") as fid:\n        while True:\n            line = fid.readline()\n            if not line:\n                break\n            line = line.strip()\n            if len(line) > 0 and line[0] != \"#\":\n                elems = line.split()\n                camera_id = int(elems[0])\n                model = elems[1]\n                assert model == \"PINHOLE\", \"While the loader support other types, the rest of the code assumes PINHOLE\"\n                width = int(elems[2])\n                height = int(elems[3])\n                params = np.array(tuple(map(float, elems[4:])))\n                cameras[camera_id] = Camera(id=camera_id, model=model,\n                                            width=width, height=height,\n                                            params=params)\n    return cameras\n\ndef read_extrinsics_binary(path_to_model_file):\n    \"\"\"\n    see: src/base/reconstruction.cc\n    
    void Reconstruction::ReadImagesBinary(const std::string& path)\n        void Reconstruction::WriteImagesBinary(const std::string& path)\n    \"\"\"\n    images = {}\n    with open(path_to_model_file, \"rb\") as fid:\n        num_reg_images = read_next_bytes(fid, 8, \"Q\")[0]\n        for _ in range(num_reg_images):\n            binary_image_properties = read_next_bytes(\n                fid, num_bytes=64, format_char_sequence=\"idddddddi\")\n            image_id = binary_image_properties[0]\n            qvec = np.array(binary_image_properties[1:5])\n            tvec = np.array(binary_image_properties[5:8])\n            camera_id = binary_image_properties[8]\n            image_name = \"\"\n            current_char = read_next_bytes(fid, 1, \"c\")[0]\n            while current_char != b\"\\x00\":   # look for the ASCII 0 entry\n                image_name += current_char.decode(\"utf-8\")\n                current_char = read_next_bytes(fid, 1, \"c\")[0]\n            num_points2D = read_next_bytes(fid, num_bytes=8,\n                                           format_char_sequence=\"Q\")[0]\n            x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D,\n                                       format_char_sequence=\"ddq\"*num_points2D)\n            xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),\n                                   tuple(map(float, x_y_id_s[1::3]))])\n            point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))\n            images[image_id] = Image(\n                id=image_id, qvec=qvec, tvec=tvec,\n                camera_id=camera_id, name=image_name,\n                xys=xys, point3D_ids=point3D_ids)\n    return images\n\n\ndef read_intrinsics_binary(path_to_model_file):\n    \"\"\"\n    see: src/base/reconstruction.cc\n        void Reconstruction::WriteCamerasBinary(const std::string& path)\n        void Reconstruction::ReadCamerasBinary(const std::string& path)\n    \"\"\"\n    cameras = {}\n    with 
open(path_to_model_file, \"rb\") as fid:\n        num_cameras = read_next_bytes(fid, 8, \"Q\")[0]\n        for _ in range(num_cameras):\n            camera_properties = read_next_bytes(\n                fid, num_bytes=24, format_char_sequence=\"iiQQ\")\n            camera_id = camera_properties[0]\n            model_id = camera_properties[1]\n            model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name\n            width = camera_properties[2]\n            height = camera_properties[3]\n            num_params = CAMERA_MODEL_IDS[model_id].num_params\n            params = read_next_bytes(fid, num_bytes=8*num_params,\n                                     format_char_sequence=\"d\"*num_params)\n            cameras[camera_id] = Camera(id=camera_id,\n                                        model=model_name,\n                                        width=width,\n                                        height=height,\n                                        params=np.array(params))\n        assert len(cameras) == num_cameras\n    return cameras\n\n\ndef read_extrinsics_text(path):\n    \"\"\"\n    Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py\n    \"\"\"\n    images = {}\n    with open(path, \"r\") as fid:\n        while True:\n            line = fid.readline()\n            if not line:\n                break\n            line = line.strip()\n            if len(line) > 0 and line[0] != \"#\":\n                elems = line.split()\n                image_id = int(elems[0])\n                qvec = np.array(tuple(map(float, elems[1:5])))\n                tvec = np.array(tuple(map(float, elems[5:8])))\n                camera_id = int(elems[8])\n                image_name = elems[9]\n                elems = fid.readline().split()\n                xys = np.column_stack([tuple(map(float, elems[0::3])),\n                                       tuple(map(float, elems[1::3]))])\n                point3D_ids = 
np.array(tuple(map(int, elems[2::3])))\n                images[image_id] = Image(\n                    id=image_id, qvec=qvec, tvec=tvec,\n                    camera_id=camera_id, name=image_name,\n                    xys=xys, point3D_ids=point3D_ids)\n    return images\n\n\ndef read_colmap_bin_array(path):\n    \"\"\"\n    Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py\n\n    :param path: path to the colmap binary file.\n    :return: nd array with the floating point values in the value\n    \"\"\"\n    with open(path, \"rb\") as fid:\n        width, height, channels = np.genfromtxt(fid, delimiter=\"&\", max_rows=1,\n                                                usecols=(0, 1, 2), dtype=int)\n        fid.seek(0)\n        num_delimiter = 0\n        byte = fid.read(1)\n        while True:\n            if byte == b\"&\":\n                num_delimiter += 1\n                if num_delimiter >= 3:\n                    break\n            byte = fid.read(1)\n        array = np.fromfile(fid, np.float32)\n    array = array.reshape((width, height, channels), order=\"F\")\n    return np.transpose(array, (1, 0, 2)).squeeze()\n"
  },
  {
    "path": "scene/dataset_readers.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport os\nimport sys\nfrom PIL import Image\nfrom typing import NamedTuple\nfrom scene.colmap_loader import read_extrinsics_text, read_intrinsics_text, qvec2rotmat, \\\n    read_extrinsics_binary, read_intrinsics_binary, read_points3D_binary, read_points3D_text\nfrom utils.graphics_utils import getWorld2View2, focal2fov, fov2focal\nimport numpy as np\nimport json\nimport random\nfrom tqdm import tqdm\nfrom pathlib import Path\nfrom plyfile import PlyData, PlyElement\nfrom utils.sh_utils import SH2RGB\nfrom scene.gaussian_model import BasicPointCloud\n\nclass CameraInfo(NamedTuple):\n    uid: int\n    R: np.array\n    T: np.array\n    FovY: np.array\n    FovX: np.array\n    cx: np.array\n    cy: np.array\n    image: np.array\n    depth: np.array     # not used\n    sam_mask: np.array  # modify -----\n    mask_feat: np.array # modify -----\n    image_path: str\n    image_name: str\n    width: int\n    height: int\n\nclass SceneInfo(NamedTuple):\n    point_cloud: BasicPointCloud\n    train_cameras: list\n    test_cameras: list\n    nerf_normalization: dict\n    ply_path: str\n\ndef getNerfppNorm(cam_info):\n    def get_center_and_diag(cam_centers):\n        cam_centers = np.hstack(cam_centers)\n        avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True)\n        center = avg_cam_center\n        dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True)\n        diagonal = np.max(dist)\n        return center.flatten(), diagonal\n\n    cam_centers = []\n\n    for cam in cam_info:\n        W2C = getWorld2View2(cam.R, cam.T)\n        C2W = np.linalg.inv(W2C)\n        cam_centers.append(C2W[:3, 3:4])\n\n    center, diagonal = 
get_center_and_diag(cam_centers)\n    radius = diagonal * 1.1\n\n    translate = -center\n\n    return {\"translate\": translate, \"radius\": radius}\n\ndef readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder):\n    cam_infos = []\n\n    for idx, key in enumerate(cam_extrinsics):\n        sys.stdout.write('\\r')\n        # the exact output you're looking for:\n        sys.stdout.write(\"Reading camera {}/{}\".format(idx+1, len(cam_extrinsics)))\n        sys.stdout.flush()\n\n        extr = cam_extrinsics[key]\n        intr = cam_intrinsics[extr.camera_id]\n        height = intr.height\n        width = intr.width\n\n        uid = intr.id\n        R = np.transpose(qvec2rotmat(extr.qvec))\n        T = np.array(extr.tvec)\n\n        if intr.model==\"SIMPLE_PINHOLE\":\n            focal_length_x = intr.params[0]\n            FovY = focal2fov(focal_length_x, height)\n            FovX = focal2fov(focal_length_x, width)\n        elif intr.model==\"PINHOLE\":\n            focal_length_x = intr.params[0]\n            focal_length_y = intr.params[1]\n            FovY = focal2fov(focal_length_y, height)\n            FovX = focal2fov(focal_length_x, width)\n        else:\n            assert False, \"Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!\"\n\n        image_path = os.path.join(images_folder, os.path.basename(extr.name))\n        if not os.path.exists(image_path):\n            # modify -----\n            base, ext = os.path.splitext(image_path)\n            if ext.lower() == \".jpg\":\n                image_path = base + \".png\"\n            elif ext.lower() == \".png\":\n                image_path = base + \".jpg\"\n            if not os.path.exists(image_path):\n                continue\n            # modify ----\n\n        image_name = os.path.basename(image_path).split(\".\")[0]\n        image = Image.open(image_path)\n\n        # NOTE: load SAM mask and CLIP feat. 
[OpenGaussian]\n        mask_seg_path = os.path.join(images_folder[:-6], \"language_features/\" + extr.name.split('/')[-1][:-4] + \"_s.npy\")\n        mask_feat_path = os.path.join(images_folder[:-6], \"language_features/\" + extr.name.split('/')[-1][:-4] + \"_f.npy\")\n        if os.path.exists(mask_seg_path):\n            sam_mask = np.load(mask_seg_path)    # [level=4, H, W]\n        else:\n            sam_mask = None\n        if mask_feat_path is not None and os.path.exists(mask_feat_path):\n            mask_feat = np.load(mask_feat_path)    # [num_mask, dim=512]\n        else:\n            mask_feat = None\n        # modify -----\n\n        cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, cx=width/2, cy=height/2, image=image, \n                              depth=None, sam_mask=sam_mask, mask_feat=mask_feat,\n                              image_path=image_path, image_name=image_name, width=width, height=height)\n        cam_infos.append(cam_info)\n    sys.stdout.write('\\n')\n    return cam_infos\n\ndef fetchPly(path):\n    plydata = PlyData.read(path)\n    vertices = plydata['vertex']\n    positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T\n    if {'red', 'green', 'blue'}.issubset(vertices.data.dtype.names):\n        colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0\n    else:\n        colors = np.random.rand(positions.shape[0], 3)\n    if {'nx', 'ny', 'nz'}.issubset(vertices.data.dtype.names):\n        normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T\n    else:\n        normals = np.random.rand(positions.shape[0], 3)\n\n    return BasicPointCloud(points=positions, colors=colors, normals=normals)\n\ndef storePly(path, xyz, rgb):\n    # Define the dtype for the structured array\n    dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'),\n            ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),\n            ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]\n    \n    normals = 
np.zeros_like(xyz)\n\n    elements = np.empty(xyz.shape[0], dtype=dtype)\n    attributes = np.concatenate((xyz, normals, rgb), axis=1)\n    elements[:] = list(map(tuple, attributes))\n\n    # Create the PlyData object and write to file\n    vertex_element = PlyElement.describe(elements, 'vertex')\n    ply_data = PlyData([vertex_element])\n    ply_data.write(path)\n\ndef readColmapSceneInfo(path, images, eval, llffhold=8):\n    try:\n        cameras_extrinsic_file = os.path.join(path, \"sparse/0\", \"images.bin\")\n        cameras_intrinsic_file = os.path.join(path, \"sparse/0\", \"cameras.bin\")\n        cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file)\n        cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file)\n    except:\n        cameras_extrinsic_file = os.path.join(path, \"sparse/0\", \"images.txt\")\n        cameras_intrinsic_file = os.path.join(path, \"sparse/0\", \"cameras.txt\")\n        cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file)\n        cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file)\n\n    reading_dir = \"images\" if images == None else images\n    cam_infos_unsorted = readColmapCameras(cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, images_folder=os.path.join(path, reading_dir))\n    cam_infos = sorted(cam_infos_unsorted.copy(), key = lambda x : x.image_name)\n\n    if eval:\n        train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0]\n        test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0]\n    else:\n        train_cam_infos = cam_infos\n        test_cam_infos = []\n\n    nerf_normalization = getNerfppNorm(train_cam_infos)\n\n    ply_path = os.path.join(path, \"sparse/0/points3D.ply\")\n    bin_path = os.path.join(path, \"sparse/0/points3D.bin\")\n    txt_path = os.path.join(path, \"sparse/0/points3D.txt\")\n    if not os.path.exists(ply_path):\n        print(\"Converting point3d.bin to .ply, will happen only the first 
time you open the scene.\")\n        try:\n            xyz, rgb, _ = read_points3D_binary(bin_path)\n        except:\n            xyz, rgb, _ = read_points3D_text(txt_path)\n        storePly(ply_path, xyz, rgb)\n    try:\n        pcd = fetchPly(ply_path)\n    except:\n        pcd = None\n\n    scene_info = SceneInfo(point_cloud=pcd,\n                           train_cameras=train_cam_infos,\n                           test_cameras=test_cam_infos,\n                           nerf_normalization=nerf_normalization,\n                           ply_path=ply_path)\n    return scene_info\n\ndef readCamerasFromTransforms(path, transformsfile, white_background, extension=\".png\"):\n    cam_infos = []\n\n    with open(os.path.join(path, transformsfile)) as json_file:\n        contents = json.load(json_file)\n\n        # ----- modify -----\n        if \"camera_angle_x\" not in contents.keys():\n            fovx = None\n        else:\n            fovx = contents[\"camera_angle_x\"] \n        # ----- modify -----\n\n        # modify -----\n        cx, cy = -1, -1\n        if \"cx\" in contents.keys():\n            cx = contents[\"cx\"]\n            cy = contents[\"cy\"]\n        elif \"h\" in contents.keys():\n            cx = contents[\"w\"] / 2\n            cy = contents[\"h\"] / 2\n        # modify -----\n\n        frames = contents[\"frames\"]\n        # for idx, frame in enumerate(frames):\n        for idx, frame in tqdm(enumerate(frames), total=len(frames), desc=\"load images\"):\n            cam_name = os.path.join(path, frame[\"file_path\"] + extension)\n\n            # NeRF 'transform_matrix' is a camera-to-world transform\n            c2w = np.array(frame[\"transform_matrix\"])\n            # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)\n            c2w[:3, 1:3] *= -1    # TODO\n\n            # get the world-to-camera transform and set R, T\n            w2c = np.linalg.inv(c2w)\n            R = np.transpose(w2c[:3,:3])  # R is 
stored transposed due to 'glm' in CUDA code\n            T = w2c[:3, 3]\n\n            image_path = os.path.join(path, cam_name)\n            if not os.path.exists(image_path):\n                # modify -----\n                base, ext = os.path.splitext(image_path)\n                if ext.lower() == \".jpg\":\n                    image_path = base + \".png\"\n                elif ext.lower() == \".png\":\n                    image_path = base + \".jpg\"\n                if not os.path.exists(image_path):\n                    continue\n                # modify ----\n\n            image_name = Path(cam_name).stem\n            image = Image.open(image_path)\n\n            im_data = np.array(image.convert(\"RGBA\"))\n\n            bg = np.array([1,1,1]) if white_background else np.array([0, 0, 0])\n\n            norm_data = im_data / 255.0\n            arr = norm_data[:,:,:3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4])\n            image = Image.fromarray(np.array(arr*255.0, dtype=np.byte), \"RGB\")\n\n            # NOTE: load SAM mask and CLIP feat. 
[OpenGaussian]\n            mask_seg_path = os.path.join(path, \"language_features/\" + frame[\"file_path\"].split('/')[-1] + \"_s.npy\")\n            mask_feat_path = os.path.join(path, \"language_features/\" + frame[\"file_path\"].split('/')[-1] + \"_f.npy\")\n            if os.path.exists(mask_seg_path):\n                sam_mask = np.load(mask_seg_path)    # [level=4, H, W]\n            else:\n                sam_mask = None\n            if os.path.exists(mask_feat_path):\n                mask_feat = np.load(mask_feat_path)  # [num_mask, dim=512]\n            else:\n                mask_feat = None\n            # modify -----\n\n            # ----- modify -----\n            if \"K\" in frame.keys():\n                cx = frame[\"K\"][0][2]\n                cy = frame[\"K\"][1][2]\n            if cx == -1:\n                cx = image.size[0] / 2\n                cy = image.size[1] / 2\n            # ----- modify -----\n\n            # ----- modify -----\n            if fovx == None:\n                if \"K\" in frame.keys():\n                    focal_length = frame[\"K\"][0][0]\n                if \"fl_x\" in contents.keys():\n                    focal_length = contents[\"fl_x\"]\n                if \"fl_x\" in frame.keys():\n                    focal_length = frame[\"fl_x\"]\n                FovY = focal2fov(focal_length, image.size[1])\n                FovX = focal2fov(focal_length, image.size[0])\n            else:\n                # camera_angle_x is the horizontal FoV; derive the vertical FoV from the image height\n                fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1])\n                FovY = fovy\n                FovX = fovx\n            # ----- modify -----\n\n            cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, cx=cx, cy=cy, image=image, \n                            depth=None, sam_mask=sam_mask, mask_feat=mask_feat,\n                            image_path=image_path, image_name=image_name, width=image.size[0], height=image.size[1]))\n            \n    return cam_infos\n\ndef 
readNerfSyntheticInfo(path, white_background, eval, extension=\".png\"):\n    print(\"Reading Training Transforms\")\n    train_cam_infos = readCamerasFromTransforms(path, \"transforms_train.json\", white_background, extension)\n    print(\"Reading Test Transforms\")\n    if os.path.exists(os.path.join(path, \"transforms_test.json\")):\n        test_cam_infos = readCamerasFromTransforms(path, \"transforms_test.json\", white_background, extension)\n    else:\n        test_cam_infos = train_cam_infos\n    \n    if not eval:\n        train_cam_infos.extend(test_cam_infos)\n        test_cam_infos = []\n\n    nerf_normalization = getNerfppNorm(train_cam_infos)\n\n    ply_path = os.path.join(path, \"points3d.ply\")\n    if not os.path.exists(ply_path):\n        # Since this data set has no colmap data, we start with random points\n        num_pts = 100_000\n        print(f\"Generating random point cloud ({num_pts})...\")\n        \n        # We create random points inside the bounds of the synthetic Blender scenes\n        xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3\n        shs = np.random.random((num_pts, 3)) / 255.0\n        pcd = BasicPointCloud(points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3)))\n\n        storePly(ply_path, xyz, SH2RGB(shs) * 255)\n    try:\n        pcd = fetchPly(ply_path)\n    except:\n        pcd = None\n\n    scene_info = SceneInfo(point_cloud=pcd,\n                           train_cameras=train_cam_infos,\n                           test_cameras=test_cam_infos,\n                           nerf_normalization=nerf_normalization,\n                           ply_path=ply_path)\n    return scene_info\n\nsceneLoadTypeCallbacks = {\n    \"Colmap\": readColmapSceneInfo,\n    \"Blender\" : readNerfSyntheticInfo\n}"
  },
  {
    "path": "scene/gaussian_model.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport torch\nimport numpy as np\nfrom utils.general_utils import inverse_sigmoid, get_expon_lr_func, build_rotation\nfrom torch import nn\nimport os\nfrom utils.system_utils import mkdir_p\nfrom plyfile import PlyData, PlyElement\nfrom utils.sh_utils import RGB2SH\n# from simple_knn._C import distCUDA2   # no need\nfrom scipy.spatial import KDTree        # modify\nfrom utils.graphics_utils import BasicPointCloud\nfrom utils.general_utils import strip_symmetric, build_scaling_rotation\n\ndef sigmoid(x):  \n    return 1 / (1 + np.exp(-x))  \n\ndef distCUDA2(points):\n    '''\n    https://github.com/graphdeco-inria/gaussian-splatting/issues/292\n    '''\n    points_np = points.detach().cpu().float().numpy()\n    dists, inds = KDTree(points_np).query(points_np, k=4)\n    meanDists = (dists[:, 1:] ** 2).mean(1)\n\n    return torch.tensor(meanDists, dtype=points.dtype, device=points.device)\n\nclass GaussianModel:\n\n    def setup_functions(self):\n        def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):\n            L = build_scaling_rotation(scaling_modifier * scaling, rotation)\n            actual_covariance = L @ L.transpose(1, 2)\n            symm = strip_symmetric(actual_covariance)\n            return symm\n        \n        self.scaling_activation = torch.exp\n        self.scaling_inverse_activation = torch.log\n\n        self.covariance_activation = build_covariance_from_scaling_rotation\n\n        self.opacity_activation = torch.sigmoid\n        self.inverse_opacity_activation = inverse_sigmoid\n\n        self.rotation_activation = torch.nn.functional.normalize\n\n\n    def __init__(self, sh_degree : int):\n        
self.active_sh_degree = 0\n        self.max_sh_degree = sh_degree  \n        self._xyz = torch.empty(0)\n        self._features_dc = torch.empty(0)\n        self._features_rest = torch.empty(0)\n        self._scaling = torch.empty(0)\n        self._rotation = torch.empty(0)\n        self._opacity = torch.empty(0)\n        self._ins_feat = torch.empty(0)     # Continuous instance features before quantization\n        self._ins_feat_q = torch.empty(0)   # Discrete instance features after quantization\n        self.iClusterSubNum = torch.empty(0)\n        self.max_radii2D = torch.empty(0)\n        self.xyz_gradient_accum = torch.empty(0)\n        self.denom = torch.empty(0)\n        self.optimizer = None\n        self.percent_dense = 0\n        self.spatial_lr_scale = 0\n        self.setup_functions()\n\n    def capture(self):\n        return (\n            self.active_sh_degree,\n            self._xyz,\n            self._features_dc,\n            self._features_rest,\n            self._scaling,\n            self._rotation,\n            self._opacity,\n            self._ins_feat,     # Continuous instance features before quantization\n            self._ins_feat_q,   # Discrete instance features after quantization\n            self.max_radii2D,\n            self.xyz_gradient_accum,\n            self.denom,\n            self.optimizer.state_dict(),\n            self.spatial_lr_scale,\n        )\n    \n    def restore(self, model_args, training_args):\n        (self.active_sh_degree, \n        self._xyz, \n        self._features_dc, \n        self._features_rest,\n        self._scaling, \n        self._rotation, \n        self._opacity,\n        self._ins_feat,     # Continuous instance features before quantization\n        self._ins_feat_q,   # Discrete instance features after quantization\n        self.max_radii2D, \n        xyz_gradient_accum, \n        denom,\n        opt_dict, \n        self.spatial_lr_scale) = model_args\n        
self.training_setup(training_args)\n        self.xyz_gradient_accum = xyz_gradient_accum\n        self.denom = denom\n        self.optimizer.load_state_dict(opt_dict)\n\n    @property\n    def get_scaling(self):\n        return self.scaling_activation(self._scaling)\n    \n    @property\n    def get_scaling_origin(self):\n        return self.scaling_activation(self._scaling)\n    \n    @property\n    def get_rotation(self):\n        return self.rotation_activation(self._rotation)\n    \n    @property\n    def get_rotation_matrix(self):\n        return build_rotation(self._rotation)\n    \n    @property\n    def get_eigenvector(self):\n        scales = self.get_scaling_origin\n        N = scales.shape[0]\n        idx = torch.min(scales, dim=1)[1]\n        normals = self.get_rotation_matrix[np.arange(N), :, idx]\n        normals = torch.nn.functional.normalize(normals, dim=1)\n        return normals\n    \n    @property\n    def get_xyz(self):\n        return self._xyz\n    \n    @property\n    def get_features(self):\n        features_dc = self._features_dc\n        features_rest = self._features_rest\n        return torch.cat((features_dc, features_rest), dim=1)\n    \n    @property\n    def get_opacity(self):\n        return self.opacity_activation(self._opacity)\n    \n    # NOTE: get instance feature\n    # @property\n    def get_ins_feat(self, origin=False):\n        if len(self._ins_feat_q) == 0 or origin:\n            ins_feat = self._ins_feat\n        else:\n            ins_feat = self._ins_feat_q\n        ins_feat = torch.nn.functional.normalize(ins_feat, dim=1)\n        return ins_feat\n    \n    def get_covariance(self, scaling_modifier = 1):\n        return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation)\n\n    def oneupSHdegree(self):\n        if self.active_sh_degree < self.max_sh_degree:\n            self.active_sh_degree += 1\n\n    def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float):\n        
self.spatial_lr_scale = spatial_lr_scale\n        fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda()\n        fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda())\n        features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda() # [N, 3, 16]\n        features[:, :3, 0 ] = fused_color\n        features[:, 3:, 1:] = 0.0\n\n        print(\"Number of points at initialisation : \", fused_point_cloud.shape[0])\n\n        dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001)\n        scales = torch.log(torch.sqrt(dist2))[...,None].repeat(1, 3)\n        rots = torch.zeros((fused_point_cloud.shape[0], 4), device=\"cuda\")\n        rots[:, 0] = 1\n\n        opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device=\"cuda\"))\n\n        # modify -----\n        ins_feat = torch.rand((fused_point_cloud.shape[0], 6), dtype=torch.float, device=\"cuda\")\n\n        self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))\n        self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True))\n        self._features_rest = nn.Parameter(features[:,:,1:].transpose(1, 2).contiguous().requires_grad_(True))\n        self._scaling = nn.Parameter(scales.requires_grad_(True))\n        self._rotation = nn.Parameter(rots.requires_grad_(True))\n        self._opacity = nn.Parameter(opacities.requires_grad_(True))\n        # modify -----\n        self._ins_feat = nn.Parameter(ins_feat.requires_grad_(True))\n        self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device=\"cuda\")\n\n    def training_setup(self, training_args):\n        self.percent_dense = training_args.percent_dense\n        self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device=\"cuda\")\n        self.denom = torch.zeros((self.get_xyz.shape[0], 1), device=\"cuda\")\n\n        l = 
[\n            {'params': [self._xyz], 'lr': training_args.position_lr_init * self.spatial_lr_scale, \"name\": \"xyz\"},\n            {'params': [self._features_dc], 'lr': training_args.feature_lr, \"name\": \"f_dc\"},\n            {'params': [self._features_rest], 'lr': training_args.feature_lr / 20.0, \"name\": \"f_rest\"},\n            {'params': [self._opacity], 'lr': training_args.opacity_lr, \"name\": \"opacity\"},\n            {'params': [self._scaling], 'lr': training_args.scaling_lr, \"name\": \"scaling\"},\n            {'params': [self._rotation], 'lr': training_args.rotation_lr, \"name\": \"rotation\"},\n            {'params': [self._ins_feat], 'lr': training_args.ins_feat_lr, \"name\": \"ins_feat\"}  # modify -----\n        ]\n\n        # note: Freeze the position of the initial point, do not densify. for ScanNet 3DGS pre-train stage\n        if training_args.frozen_init_pts:\n            self._xyz = self._xyz.detach()\n\n        self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15)\n        self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init*self.spatial_lr_scale,\n                                                    lr_final=training_args.position_lr_final*self.spatial_lr_scale,\n                                                    lr_delay_mult=training_args.position_lr_delay_mult,\n                                                    max_steps=training_args.position_lr_max_steps)\n\n    def update_learning_rate(self, iteration, root_start, leaf_start):\n        ''' Learning rate scheduling per step '''\n        for param_group in self.optimizer.param_groups:\n            if param_group[\"name\"] == \"xyz\":\n                lr = self.xyz_scheduler_args(iteration)\n                param_group['lr'] = lr\n                # return lr\n            if param_group[\"name\"] == \"ins_feat\":\n                if iteration > root_start and iteration <= leaf_start:      # TODO: update lr\n                    param_group['lr'] = 
param_group['lr'] * 0 + 0.0001\n                else:\n                    param_group['lr'] = param_group['lr'] * 0 + 0.001\n\n    def construct_list_of_attributes(self):\n        l = ['x', 'y', 'z', 'nx', 'ny', 'nz', 'ins_feat_r', 'ins_feat_g', 'ins_feat_b', \\\n            'ins_feat_r2', 'ins_feat_g2', 'ins_feat_b2']\n        # All channels except the 3 DC\n        for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]):\n            l.append('f_dc_{}'.format(i))\n        for i in range(self._features_rest.shape[1]*self._features_rest.shape[2]):\n            l.append('f_rest_{}'.format(i))\n        l.append('opacity')\n        for i in range(self._scaling.shape[1]):\n            l.append('scale_{}'.format(i))\n        for i in range(self._rotation.shape[1]):\n            l.append('rot_{}'.format(i))\n        return l\n\n    def save_ply(self, path, save_q=[]):\n        mkdir_p(os.path.dirname(path))\n\n        xyz = self._xyz.detach().cpu().numpy()\n        normals = np.zeros_like(xyz)\n        f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()\n        f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()\n        opacities = self._opacity.detach().cpu().numpy()\n        scale = self._scaling.detach().cpu().numpy()\n        rotation = self._rotation.detach().cpu().numpy()\n        if \"ins_feat\" in save_q:\n            ins_feat = self._ins_feat_q.detach().cpu().numpy()\n        else:\n            ins_feat = self._ins_feat.detach().cpu().numpy()\n\n        # NOTE: pts feat visualization\n        vis_color = (ins_feat + 1) / 2 * 255\n        r, g, b = vis_color[:, 0].reshape(-1, 1), vis_color[:, 1].reshape(-1, 1), vis_color[:, 2].reshape(-1, 1)\n\n        # todo: points not fully optimized due to sampled training images.\n        ignored_ind = sigmoid(opacities) < 0.1\n        r[ignored_ind], g[ignored_ind], b[ignored_ind] = 128, 128, 128\n\n        
dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()]\n        dtype_full = dtype_full + [('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]  # modify\n\n        elements = np.empty(xyz.shape[0], dtype=dtype_full)\n        attributes = np.concatenate((xyz, normals, ins_feat,\\\n                                    f_dc, f_rest, opacities, scale, rotation,\\\n                                    r, g, b), axis=1)\n        elements[:] = list(map(tuple, attributes))\n        el = PlyElement.describe(elements, 'vertex')\n        PlyData([el]).write(path)\n\n    def reset_opacity(self):\n        opacities_new = inverse_sigmoid(torch.min(self.get_opacity, torch.ones_like(self.get_opacity)*0.01))\n        optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, \"opacity\")\n        self._opacity = optimizable_tensors[\"opacity\"]\n\n    def load_ply(self, path):\n        plydata = PlyData.read(path)\n\n        xyz = np.stack((np.asarray(plydata.elements[0][\"x\"]),\n                        np.asarray(plydata.elements[0][\"y\"]),\n                        np.asarray(plydata.elements[0][\"z\"])),  axis=1)\n        ins_feat = np.stack((np.asarray(plydata.elements[0][\"ins_feat_r\"]),\n                        np.asarray(plydata.elements[0][\"ins_feat_g\"]),\n                        np.asarray(plydata.elements[0][\"ins_feat_b\"]),\n                        np.asarray(plydata.elements[0][\"ins_feat_r2\"]),\n                        np.asarray(plydata.elements[0][\"ins_feat_g2\"]),\n                        np.asarray(plydata.elements[0][\"ins_feat_b2\"])),  axis=1)\n        opacities = np.asarray(plydata.elements[0][\"opacity\"])[..., np.newaxis]\n        if not opacities.flags['C_CONTIGUOUS']:\n            opacities = np.ascontiguousarray(opacities)\n\n        features_dc = np.zeros((xyz.shape[0], 3, 1))\n        features_dc[:, 0, 0] = np.asarray(plydata.elements[0][\"f_dc_0\"])\n        features_dc[:, 1, 0] = 
np.asarray(plydata.elements[0][\"f_dc_1\"])\n        features_dc[:, 2, 0] = np.asarray(plydata.elements[0][\"f_dc_2\"])\n\n        extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith(\"f_rest_\")]\n        extra_f_names = sorted(extra_f_names, key = lambda x: int(x.split('_')[-1]))\n        assert len(extra_f_names)==3*(self.max_sh_degree + 1) ** 2 - 3\n        features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))\n        for idx, attr_name in enumerate(extra_f_names):\n            features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name])\n        # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC)\n        features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1))\n\n        scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith(\"scale_\")]\n        scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1]))\n        scales = np.zeros((xyz.shape[0], len(scale_names)))\n        for idx, attr_name in enumerate(scale_names):\n            scales[:, idx] = np.asarray(plydata.elements[0][attr_name])\n\n        rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith(\"rot\")]\n        rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1]))\n        rots = np.zeros((xyz.shape[0], len(rot_names)))\n        for idx, attr_name in enumerate(rot_names):\n            rots[:, idx] = np.asarray(plydata.elements[0][attr_name])\n\n        self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device=\"cuda\").requires_grad_(True))\n        self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device=\"cuda\").transpose(1, 2).contiguous().requires_grad_(True))\n        self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device=\"cuda\").transpose(1, 2).contiguous().requires_grad_(True))\n        self._opacity = 
nn.Parameter(torch.tensor(opacities, dtype=torch.float, device=\"cuda\").requires_grad_(True))\n        self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device=\"cuda\").requires_grad_(True))\n        self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device=\"cuda\").requires_grad_(True))\n        self._ins_feat = nn.Parameter(torch.tensor(ins_feat, dtype=torch.float, device=\"cuda\").requires_grad_(True))\n\n        self.active_sh_degree = self.max_sh_degree\n\n    def replace_tensor_to_optimizer(self, tensor, name):\n        optimizable_tensors = {}\n        for group in self.optimizer.param_groups:\n            if group[\"name\"] == name:\n                stored_state = self.optimizer.state.get(group['params'][0], None)\n                stored_state[\"exp_avg\"] = torch.zeros_like(tensor)\n                stored_state[\"exp_avg_sq\"] = torch.zeros_like(tensor)\n\n                del self.optimizer.state[group['params'][0]]\n                group[\"params\"][0] = nn.Parameter(tensor.requires_grad_(True))\n                self.optimizer.state[group['params'][0]] = stored_state\n\n                optimizable_tensors[group[\"name\"]] = group[\"params\"][0]\n        return optimizable_tensors\n\n    def _prune_optimizer(self, mask):\n        optimizable_tensors = {}\n        for group in self.optimizer.param_groups:\n            stored_state = self.optimizer.state.get(group['params'][0], None)\n            if stored_state is not None:\n                stored_state[\"exp_avg\"] = stored_state[\"exp_avg\"][mask]\n                stored_state[\"exp_avg_sq\"] = stored_state[\"exp_avg_sq\"][mask]\n\n                del self.optimizer.state[group['params'][0]]\n                group[\"params\"][0] = nn.Parameter((group[\"params\"][0][mask].requires_grad_(True)))\n                self.optimizer.state[group['params'][0]] = stored_state\n\n                optimizable_tensors[group[\"name\"]] = group[\"params\"][0]\n            else:\n     
           group[\"params\"][0] = nn.Parameter(group[\"params\"][0][mask].requires_grad_(True))\n                optimizable_tensors[group[\"name\"]] = group[\"params\"][0]\n        return optimizable_tensors\n\n    def prune_points(self, mask):\n        valid_points_mask = ~mask\n        optimizable_tensors = self._prune_optimizer(valid_points_mask)\n\n        self._xyz = optimizable_tensors[\"xyz\"]\n        self._features_dc = optimizable_tensors[\"f_dc\"]\n        self._features_rest = optimizable_tensors[\"f_rest\"]\n        self._opacity = optimizable_tensors[\"opacity\"]\n        self._scaling = optimizable_tensors[\"scaling\"]\n        self._rotation = optimizable_tensors[\"rotation\"]\n        self._ins_feat = optimizable_tensors[\"ins_feat\"]\n\n        self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask]\n\n        self.denom = self.denom[valid_points_mask]\n        self.max_radii2D = self.max_radii2D[valid_points_mask]\n\n    def cat_tensors_to_optimizer(self, tensors_dict):\n        optimizable_tensors = {}\n        for group in self.optimizer.param_groups:\n            assert len(group[\"params\"]) == 1\n            extension_tensor = tensors_dict[group[\"name\"]]\n            stored_state = self.optimizer.state.get(group['params'][0], None)\n            if stored_state is not None:\n\n                stored_state[\"exp_avg\"] = torch.cat((stored_state[\"exp_avg\"], torch.zeros_like(extension_tensor)), dim=0)\n                stored_state[\"exp_avg_sq\"] = torch.cat((stored_state[\"exp_avg_sq\"], torch.zeros_like(extension_tensor)), dim=0)\n\n                del self.optimizer.state[group['params'][0]]\n                group[\"params\"][0] = nn.Parameter(torch.cat((group[\"params\"][0], extension_tensor), dim=0).requires_grad_(True))\n                self.optimizer.state[group['params'][0]] = stored_state\n\n                optimizable_tensors[group[\"name\"]] = group[\"params\"][0]\n            else:\n                
group[\"params\"][0] = nn.Parameter(torch.cat((group[\"params\"][0], extension_tensor), dim=0).requires_grad_(True))\n                optimizable_tensors[group[\"name\"]] = group[\"params\"][0]\n\n        return optimizable_tensors\n\n    def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_opacities, \\\n                                new_scaling, new_rotation, new_ins_feat):\n        d = {\"xyz\": new_xyz,\n        \"f_dc\": new_features_dc,\n        \"f_rest\": new_features_rest,\n        \"opacity\": new_opacities,\n        \"scaling\" : new_scaling,\n        \"rotation\" : new_rotation,\n        \"ins_feat\": new_ins_feat}\n\n        optimizable_tensors = self.cat_tensors_to_optimizer(d)\n        self._xyz = optimizable_tensors[\"xyz\"]\n        self._features_dc = optimizable_tensors[\"f_dc\"]\n        self._features_rest = optimizable_tensors[\"f_rest\"]\n        self._opacity = optimizable_tensors[\"opacity\"]\n        self._scaling = optimizable_tensors[\"scaling\"]\n        self._rotation = optimizable_tensors[\"rotation\"]\n        self._ins_feat = optimizable_tensors[\"ins_feat\"]\n\n        self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device=\"cuda\")\n        self.denom = torch.zeros((self.get_xyz.shape[0], 1), device=\"cuda\")\n        self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device=\"cuda\")\n\n    def densify_and_split(self, grads, grad_threshold, scene_extent, N=2):\n        n_init_points = self.get_xyz.shape[0]\n        # Extract points that satisfy the gradient condition\n        padded_grad = torch.zeros((n_init_points), device=\"cuda\")\n        padded_grad[:grads.shape[0]] = grads.squeeze()\n        selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False)\n        selected_pts_mask = torch.logical_and(selected_pts_mask,\n                                              torch.max(self.get_scaling, dim=1).values > self.percent_dense*scene_extent)\n\n        stds = 
self.get_scaling[selected_pts_mask].repeat(N,1)\n        means =torch.zeros((stds.size(0), 3),device=\"cuda\")\n        samples = torch.normal(mean=means, std=stds)\n        rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N,1,1)\n        new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[selected_pts_mask].repeat(N, 1)\n        new_scaling = self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(N,1) / (0.8*N))\n        new_rotation = self._rotation[selected_pts_mask].repeat(N,1)\n        new_features_dc = self._features_dc[selected_pts_mask].repeat(N,1,1)\n        new_features_rest = self._features_rest[selected_pts_mask].repeat(N,1,1)\n        new_opacity = self._opacity[selected_pts_mask].repeat(N,1)\n        new_ins_feat = self._ins_feat[selected_pts_mask].repeat(N,1)\n\n        self.densification_postfix(new_xyz, new_features_dc, new_features_rest, \\\n            new_opacity, new_scaling, new_rotation, new_ins_feat)\n\n        prune_filter = torch.cat((selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device=\"cuda\", dtype=bool)))\n        self.prune_points(prune_filter)\n\n    def densify_and_clone(self, grads, grad_threshold, scene_extent):\n        # Extract points that satisfy the gradient condition\n        selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False)\n        selected_pts_mask = torch.logical_and(selected_pts_mask,\n                                              torch.max(self.get_scaling, dim=1).values <= self.percent_dense*scene_extent)\n        \n        new_xyz = self._xyz[selected_pts_mask]\n        new_features_dc = self._features_dc[selected_pts_mask]\n        new_features_rest = self._features_rest[selected_pts_mask]\n        new_opacities = self._opacity[selected_pts_mask]\n        new_scaling = self._scaling[selected_pts_mask]\n        new_rotation = self._rotation[selected_pts_mask]\n        new_ins_feat = 
self._ins_feat[selected_pts_mask]\n\n        self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacities, \\\n            new_scaling, new_rotation, new_ins_feat)\n\n    def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size):\n        grads = self.xyz_gradient_accum / self.denom\n        grads[grads.isnan()] = 0.0\n\n        self.densify_and_clone(grads, max_grad, extent)\n        self.densify_and_split(grads, max_grad, extent)\n\n        prune_mask = (self.get_opacity < min_opacity).squeeze()\n        if max_screen_size:\n            big_points_vs = self.max_radii2D > max_screen_size\n            big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent\n            prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws)\n        self.prune_points(prune_mask)\n\n        torch.cuda.empty_cache()\n\n    def add_densification_stats(self, viewspace_point_tensor, update_filter):\n        self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[update_filter,:2], dim=-1, keepdim=True)\n        self.denom[update_filter] += 1"
  },
  {
    "path": "scene/kmeans_quantize.py",
    "content": "import os\nimport pdb\nfrom tqdm import tqdm\nimport time\n\nimport torch\nimport numpy as np\nfrom torch import nn\nimport torch.nn.functional as F\n\n\nclass Quantize_kMeans():\n    def __init__(self, num_clusters=64, num_leaf_clusters=10, num_iters=10, dim=9, dim_leaf=6):\n        self.num_clusters = num_clusters            # k1\n        self.leaf_num_clusters = num_leaf_clusters  # k2\n        self.num_kmeans_iters = num_iters           # iter\n        self.vec_dim = dim                          # coarse-level, dim=9(feat+xyz)\n        self.leaf_vec_dim = dim_leaf                # fine-level, dim=6(feat)\n        self.centers = torch.empty(0)               # coarse center， [k1, 9]\n        self.leaf_centers = torch.empty(0)          # fine center， [k2, 6]\n        self.iLeafSubNum = torch.empty(0)           # Number of fine clusters per coarse cluster\n        self.cls_ids = torch.empty(0)               # coarse cluster id [num_pts]\n        self.leaf_cls_ids = torch.empty(0)          # fine cluster id[num_pts]\n        \n        self.nn_index = torch.empty(0)              # [num_pts] temporary variable\n\n        # for update_centers\n        self.cluster_ids = torch.empty(0)\n        self.excl_clusters = []\n        self.excl_cluster_ids = []\n        self.cluster_len = torch.empty(0)\n        self.max_cnt = 0                  \n        self.max_cnt_th = 10000\n        self.n_excl_cls = 0       \n\n        self.pos_centers = torch.empty(0)           \n\n    def get_dist(self, x, y, mode='sq_euclidean'):\n        \"\"\"Calculate distance between all vectors in x and all vectors in y.\n\n        x: (m, dim)\n        y: (n, dim)\n        dist: (m, n)\n        \"\"\"\n        if mode == 'sq_euclidean_chunk':\n            step = 65536\n            if x.shape[0] < step:\n                step = x.shape[0]\n            dist = []\n            for i in range(np.ceil(x.shape[0] / step).astype(int)):\n                dist.append(torch.cdist(x[(i*step): 
(i+1)*step, :].unsqueeze(0), y.unsqueeze(0))[0])\n            dist = torch.cat(dist, 0)\n        elif mode == 'sq_euclidean':\n            dist = torch.cdist(x.unsqueeze(0).detach(), y.unsqueeze(0).detach())[0]\n        return dist\n\n    # Update centers in non-cluster assignment iters using cached nn indices.\n    def update_centers(self, feat, mode=\"root\", selected_leaf=-1):\n        if mode == \"root\":\n            centers = self.centers\n            num_clusters = self.num_clusters\n            vec_dim = self.vec_dim\n        elif mode == \"leaf\":\n            centers = self.leaf_centers\n            num_clusters = self.num_clusters * self.leaf_num_clusters + 1\n            vec_dim = self.leaf_vec_dim\n        feat = feat.detach().reshape(-1, vec_dim)  # [num_pts, dim] [766267, 9]\n        # Update all clusters except the excluded ones in a single operation\n        # Add a dummy element with zeros at the end\n        feat = torch.cat([feat, torch.zeros_like(feat[:1]).cuda()], 0)  # [num_pts+1, dim]\n        centers = torch.sum(feat[self.cluster_ids, :].reshape(\n            num_clusters, self.max_cnt, -1), dim=1)    # [num_clusters, vec_dim]\n        if len(self.excl_cluster_ids) > 0:\n            for i, cls in enumerate(self.excl_clusters):\n                # Division by num_points in cluster is done during the one-shot averaging of all\n                # clusters below. 
Only the extra elements in the bigger clusters are added here.\n                centers[cls] += torch.sum(feat[self.excl_cluster_ids[i], :], dim=0)\n        centers /= (self.cluster_len + 1e-6)\n\n    # Sum the features of each cluster during cluster assignment using mask matrix multiplication.\n    # The mask is the one-hot nearest-center assignment obtained from the distance matrix;\n    # the caller divides the sums by the per-cluster counts.\n    def update_centers_(self, feat, cluster_mask=None, nn_index=None, avg=False):\n        # feat = feat.detach().reshape(-1, self.vec_dim)\n        centers = (cluster_mask.T @ feat)   # [chunk, num_cluster].T @ [chunk, dim] -> [num_cluster, dim]\n        # if avg:\n        #     self.centers /= counts.unsqueeze(-1)\n        return centers\n\n    def equalize_cluster_size(self, mode=\"root\"):\n        \"\"\"Make the size of all the clusters the same by appending dummy elements.\n\n        \"\"\"\n        # Find the maximum number of elements in a cluster, make size of all clusters\n        # equal by appending dummy elements until size is equal to size of max cluster.\n        # If max is too large, exclude it and consider the next biggest. 
Use for loop for\n        # the excluded clusters and a single operation for the remaining ones for\n        # updating the cluster centers.\n\n        unq, n_unq = torch.unique(self.nn_index, return_counts=True)\n        # Find max cluster size and exclude clusters greater than a threshold\n        topk = 100\n        if len(n_unq) < topk:\n            topk = len(n_unq)\n        max_cnt_topk, topk_idx = torch.topk(n_unq, topk)\n        self.max_cnt = max_cnt_topk[0]\n        idx = 0\n        self.excl_clusters = []\n        self.excl_cluster_ids = []\n        while(self.max_cnt > self.max_cnt_th):\n            self.excl_clusters.append(unq[topk_idx[idx]])\n            idx += 1\n            if idx < topk:\n                self.max_cnt = max_cnt_topk[idx]\n            else:\n                break\n        self.n_excl_cls = len(self.excl_clusters)\n        self.excl_clusters = sorted(self.excl_clusters)\n        # Store the indices of elements for each cluster\n        all_ids = []\n        cls_len = []\n        if mode == \"root\":\n            num_clusters = self.num_clusters\n        elif mode == \"leaf\":\n            num_clusters = self.num_clusters * self.leaf_num_clusters + 1\n        for i in range(num_clusters):\n            cur_cluster_ids = torch.where(self.nn_index == i)[0]\n            # For excluded clusters, use only the first max_cnt elements\n            # for averaging along with other clusters. 
Separately average the\n            # remaining elements just for the excluded clusters.\n            cls_len.append(torch.Tensor([len(cur_cluster_ids)]))\n            if i in self.excl_clusters:\n                self.excl_cluster_ids.append(cur_cluster_ids[self.max_cnt:])\n                cur_cluster_ids = cur_cluster_ids[:self.max_cnt]\n            # Append dummy elements to have same size for all clusters\n            all_ids.append(torch.cat([cur_cluster_ids, -1 * torch.ones((self.max_cnt - len(cur_cluster_ids)),\n                                                                       dtype=torch.long).cuda()]))\n        all_ids = torch.cat(all_ids).type(torch.long)\n        cls_len = torch.cat(cls_len).type(torch.long)\n        self.cluster_ids = all_ids\n        self.cluster_len = cls_len.unsqueeze(1).cuda()\n        if mode == \"root\":\n            self.cls_ids = self.nn_index\n        elif mode == \"leaf\":\n            self.leaf_cls_ids = self.nn_index\n\n    def cluster_assign(self, feat, feat_scaled=None, mode=\"root\", selected_leaf=-1):\n\n        # quantize with kmeans\n        feat = feat.detach()    # [N, dim]\n\n        if feat_scaled is None:\n            feat_scaled = feat\n            scale = feat[0] / (feat_scaled[0] + 1e-8)\n        # init. centers and ids\n        if len(self.centers) == 0 and mode == \"root\":\n            self.centers = feat[torch.randperm(feat.shape[0])[:self.num_clusters], :]\n        if len(self.leaf_centers) == 0 and mode == \"leaf\":\n            # [num_clusters, leaf_num_clusters, dim_leaf] eg. 
[640, 6]\n            self.leaf_centers = feat[torch.randperm(feat.shape[0])[:self.num_clusters * self.leaf_num_clusters+1], :]\n            self.leaf_cls_ids = torch.ones(feat.shape[0]).to(torch.int64).cuda() * self.num_clusters * self.leaf_num_clusters\n\n        # start kmeans\n        chunk = True\n        # tmp centers\n        if mode == \"root\":\n            tmp_centers = torch.zeros_like(self.centers)\n            counts = torch.zeros(self.num_clusters, dtype=torch.float32).cuda() + 1e-6\n        elif mode == \"leaf\":\n            tmp_centers = torch.zeros_like(self.leaf_centers)[:self.leaf_num_clusters, :]\n            counts = torch.zeros(self.leaf_num_clusters, dtype=torch.float32).cuda() + 1e-6\n            start_id = selected_leaf * self.leaf_num_clusters\n            end_id = selected_leaf * self.leaf_num_clusters + self.iLeafSubNum[selected_leaf]\n        for iteration in range(self.num_kmeans_iters):\n            # chunk for memory issues\n            if chunk:\n                self.nn_index = None\n                i = 0\n                chunk = 10000\n                if mode == \"root\":\n                    while True:\n                        dist = self.get_dist(feat[i*chunk:(i+1)*chunk, :], self.centers)\n                        curr_nn_index = torch.argmin(dist, dim=-1)  # [chunk]\n                        # Assign a single cluster when distance to multiple clusters is same\n                        dist = F.one_hot(curr_nn_index, self.num_clusters).type(torch.float32)  # [chunk, num_clusters]\n                        curr_centers = self.update_centers_(feat[i*chunk:(i+1)*chunk, :], dist, curr_nn_index, avg=False)   # [num_clusters, dim]\n                        counts += dist.detach().sum(0) + 1e-6   # [num_clusters]\n                        tmp_centers += curr_centers\n                        if self.nn_index is None:\n                            self.nn_index = curr_nn_index\n                        else:\n                            self.nn_index = 
torch.cat((self.nn_index, curr_nn_index), dim=0)\n                        i += 1\n                        if i*chunk > feat.shape[0]:\n                            break\n                elif mode == \"leaf\":\n                    for idx_c in range(self.num_clusters):\n                        if idx_c != selected_leaf:\n                            continue\n                        selected_pts = self.cls_ids == idx_c\n                        dist = self.get_dist(feat[selected_pts], self.leaf_centers[start_id:end_id])\n                        curr_nn_index = torch.argmin(dist, dim=-1)  # [num_selected]\n                        dist = F.one_hot(curr_nn_index, self.leaf_num_clusters).type(torch.float32)  # [num_selected, leaf_num_clusters]\n                        curr_centers = self.update_centers_(feat[selected_pts], dist, curr_nn_index, avg=False)   # [leaf_num_clusters, dim_leaf]\n                        counts += dist.detach().sum(0) + 1e-6   # [leaf_num_clusters]\n                        tmp_centers += curr_centers\n                        self.leaf_cls_ids[selected_pts] = curr_nn_index + start_id\n            # average centers\n            if mode == \"root\":\n                self.centers = tmp_centers / counts.unsqueeze(-1)   \n            elif mode == \"leaf\":\n                self.leaf_centers[start_id: start_id+self.leaf_num_clusters] = tmp_centers / counts.unsqueeze(-1)   \n            # Reinitialize to 0\n            tmp_centers[tmp_centers != 0] = 0.\n            counts[counts > 0.1] = 0.\n\n        # Reassign ID according to the new centers\n        if chunk:\n            self.nn_index = None\n            i = 0\n            # chunk = 100000\n            if mode == \"root\":\n                while True:\n                    dist = self.get_dist(feat_scaled[i * chunk:(i + 1) * chunk, :], self.centers)\n                    curr_nn_index = torch.argmin(dist, dim=-1)\n                    if self.nn_index is None:\n                        self.nn_index = curr_nn_index\n                    else:\n                        
self.nn_index = torch.cat((self.nn_index, curr_nn_index), dim=0)\n                    i += 1\n                    if i * chunk > feat.shape[0]:\n                        break\n            elif mode == \"leaf\":\n                for idx_c in range(self.num_clusters):\n                    if idx_c != selected_leaf:\n                        continue\n                    selected_pts = self.cls_ids == idx_c\n                    dist = self.get_dist(feat[selected_pts], self.leaf_centers[start_id:end_id])\n                    curr_nn_index = torch.argmin(dist, dim=-1)\n                    self.leaf_cls_ids[selected_pts] = curr_nn_index + start_id\n                self.nn_index = self.leaf_cls_ids\n        self.equalize_cluster_size(mode=mode)\n\n    def rescale(self, feat, scale=None):\n        \"\"\"Scale the feature to be in the range [-1, 1] by dividing by its max value.\n\n        \"\"\"\n        if scale is None:\n            return feat / (abs(feat).max(dim=0)[0] + 1e-8)\n        else:\n            return feat / (scale + 1e-8)\n\n    def forward(self, gaussian, iteration, assign=False, mode=\"root\", selected_leaf=-1, pos_weight=1.0):\n        if mode == \"root\":\n            # (1) coarse-level: feature + xyz\n            scale = pos_weight     # TODO\n            xyz_feat = gaussian._xyz.detach() * scale\n            feat = torch.cat((gaussian._ins_feat, xyz_feat), dim=1)    # [N, 9]\n        elif mode == \"leaf\":\n            # (2) fine-level: feature only\n            feat = gaussian._ins_feat\n\n        if assign:\n            self.cluster_assign(feat, mode=mode, selected_leaf=selected_leaf)   # gaussian._ins_feat\n        else:\n            self.update_centers(feat, mode=mode, selected_leaf=selected_leaf)   # gaussian._ins_feat\n\n        if mode == \"root\":\n            centers = self.centers\n            vec_dim = self.vec_dim\n        elif mode == \"leaf\":\n            centers = self.leaf_centers\n            vec_dim = self.leaf_vec_dim\n        
sampled_centers = torch.gather(centers, 0, self.nn_index.unsqueeze(-1).repeat(1, vec_dim))\n        # NOTE: \"During backpropagation, the gradients of the quantized features are copied to the instance features\", mentioned in the paper.\n        gaussian._ins_feat_q = gaussian._ins_feat - gaussian._ins_feat.detach() + sampled_centers[:,:6]\n\n    def replace_with_centers(self, gaussian):\n        deg = gaussian._features_rest.shape[1]\n        sampled_centers = torch.gather(self.centers, 0, self.nn_index.unsqueeze(-1).repeat(1, self.vec_dim))\n        gaussian._features_rest = gaussian._features_rest - gaussian._features_rest.detach() + sampled_centers.reshape(-1, deg, 3)\n"
  },
  {
    "path": "scripts/compute_lerf_iou.py",
    "content": "import os\nimport numpy as np\nfrom PIL import Image\nfrom argparse import ArgumentParser\n\ndef load_image_as_binary(image_path, is_png=False, threshold=10):\n    image = Image.open(image_path)\n    if is_png:\n        image = image.convert('L')\n    image_array = np.array(image)\n    binary_image = (image_array > threshold).astype(int)\n    return binary_image\n\ndef calculate_iou(mask1, mask2):\n    intersection = np.logical_and(mask1, mask2).sum()\n    union = np.logical_or(mask1, mask2).sum()\n    if union == 0:\n        return 0\n    return intersection / union\n\ndef evalute(gt_base, pred_base, scene_name):\n    scene_gt_frames = {\n        \"waldo_kitchen\": [\"frame_00053\", \"frame_00066\", \"frame_00089\", \"frame_00140\", \"frame_00154\"],\n        \"ramen\": [\"frame_00006\", \"frame_00024\", \"frame_00060\", \"frame_00065\", \"frame_00081\", \"frame_00119\", \"frame_00128\"],\n        \"figurines\": [\"frame_00041\", \"frame_00105\", \"frame_00152\", \"frame_00195\"],\n        \"teatime\": [\"frame_00002\", \"frame_00025\", \"frame_00043\", \"frame_00107\", \"frame_00129\", \"frame_00140\"]\n    }\n    frame_names = scene_gt_frames[scene_name]\n\n    ious = []\n    for frame in frame_names:\n        print(\"frame:\", frame)\n        gt_floder = os.path.join(gt_base, frame)\n        file_names = [f for f in os.listdir(gt_floder) if f.endswith('.jpg')]\n        for file_name in file_names:\n            base_name = os.path.splitext(file_name)[0]\n            gt_obj_path = os.path.join(gt_floder, file_name)\n            pred_obj_path = os.path.join(pred_base, frame + \"_\" + base_name + '.png')\n            if not os.path.exists(pred_obj_path):\n                print(f\"Missing pred file for {file_name}, skipping...\")\n                print(f\"IoU for {file_name}: 0\")\n                ious.append(0.0)\n                continue\n            mask_gt = load_image_as_binary(gt_obj_path)\n            mask_pred = 
load_image_as_binary(pred_obj_path, is_png=True)\n            iou = calculate_iou(mask_gt, mask_pred)\n            ious.append(iou)\n            print(f\"IoU for {file_name} and {base_name + '.png'}: {iou:.4f}\")\n    \n    # Acc.\n    total_count = len(ious)\n    count_iou_025 = (np.array(ious) > 0.25).sum()\n    count_iou_05 = (np.array(ious) > 0.5).sum()\n\n    # mIoU\n    average_iou = np.mean(ious)\n    print(f\"Average IoU: {average_iou:.4f}\")\n    print(f\"Acc@0.25: {count_iou_025/total_count:.4f}\")\n    print(f\"Acc@0.5: {count_iou_05/total_count:.4f}\")\n\nif __name__ == \"__main__\":\n    parser = ArgumentParser(\"Compute LeRF IoU\")\n    parser.add_argument(\"--scene_name\", type=str, choices=[\"waldo_kitchen\", \"ramen\", \"figurines\", \"teatime\"],\n                        help=\"Specify the scene_name from: figurines, teatime, ramen, waldo_kitchen\")\n    args = parser.parse_args()\n    if not args.scene_name:\n        parser.error(\"The --scene_name argument is required and must be one of: waldo_kitchen, ramen, figurines, teatime\")\n\n    # TODO: change both paths to match your data; path_gt must point at the scene given by --scene_name\n    path_gt = \"/gdata/cold1/wuyanmin/OpenGaussian/data/lerf_ovs/label/waldo_kitchen/gt\"\n    # renders_cluster_silhouette holds the predicted masks\n    path_pred = \"output/xxxxxxxx-x/text2obj/ours_70000/renders_cluster_silhouette\"\n    evalute(path_gt, path_pred, args.scene_name)"
  },
  {
    "path": "scripts/eval_scannet.py",
    "content": "import os\nfrom plyfile import PlyData, PlyElement\nimport torch.nn.functional as F\nimport numpy as np\nimport torch\nimport json\n\nnyu40_dict = {\n    0: \"unlabeled\", 1: \"wall\", 2: \"floor\", 3: \"cabinet\", 4: \"bed\", 5: \"chair\",\n    6: \"sofa\", 7: \"table\", 8: \"door\", 9: \"window\", 10: \"bookshelf\",\n    11: \"picture\", 12: \"counter\", 13: \"blinds\", 14: \"desk\", 15: \"shelves\",\n    16: \"curtain\", 17: \"dresser\", 18: \"pillow\", 19: \"mirror\", 20: \"floormat\",\n    21: \"clothes\", 22: \"ceiling\", 23: \"books\", 24: \"refrigerator\", 25: \"television\",\n    26: \"paper\", 27: \"towel\", 28: \"showercurtain\", 29: \"box\", 30: \"whiteboard\",\n    31: \"person\", 32: \"nightstand\", 33: \"toilet\", 34: \"sink\", 35: \"lamp\",\n    36: \"bathtub\", 37: \"bag\", 38: \"otherstructure\", 39: \"otherfurniture\", 40: \"otherprop\"\n}\n\n# ScanNet 20 classes\nscannet19_dict = {\n    1: \"wall\", 2: \"floor\", 3: \"cabinet\", 4: \"bed\", 5: \"chair\",\n    6: \"sofa\", 7: \"table\", 8: \"door\", 9: \"window\", 10: \"bookshelf\",\n    11: \"picture\", 12: \"counter\", 14: \"desk\", 16: \"curtain\",\n    24: \"refrigerator\", 28: \"shower curtain\", 33: \"toilet\", 34: \"sink\",\n    36: \"bathtub\", # 39: \"otherfurniture\"\n}\n\nimport numpy as np  \ndef sigmoid(x):  \n    return 1 / (1 + np.exp(-x))  \n\ndef write_ply(vertex_data, output_path):\n    vertices = []\n    for vertex in vertex_data:\n        r = (vertex['ins_feat_r'] + 1)/2 * 255\n        g = (vertex['ins_feat_g'] + 1)/2 * 255\n        b = (vertex['ins_feat_b'] + 1)/2 * 255\n        new_vertex = (vertex['x'], vertex['y'], vertex['z'], r, g, b)\n        vertices.append(new_vertex)\n    \n    vertex_dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]\n    new_vertex_data = np.array(vertices, dtype=vertex_dtype)\n    \n    el = PlyElement.describe(new_vertex_data, 'vertex')\n    PlyData([el], 
text=True).write(output_path)\n\ndef read_labels_from_ply(file_path):\n    ply_data = PlyData.read(file_path)\n    vertex_data = ply_data['vertex'].data\n    # Extract the coordinates and labels of the points. The labels are from 1 to 40 for the NYU40 dataset, with 0 being invalid.\n    points = np.vstack([vertex_data['x'], vertex_data['y'], vertex_data['z']]).T\n    labels = vertex_data['label']\n    return points, labels\n\ndef calculate_metrics(gt, pred, total_classes):\n    gt = gt.cpu()\n    pred = pred.cpu()\n    pred[gt == 0] = 0\n\n    ious = torch.zeros(total_classes)\n\n    intersection = torch.zeros(total_classes)\n    union = torch.zeros(total_classes)\n    correct = torch.zeros(total_classes)\n    total = torch.zeros(total_classes)\n\n    for cls in range(1, total_classes):\n        intersection[cls] = torch.sum((gt == cls) & (pred == cls)).item()\n        union[cls] = torch.sum((gt == cls) | (pred == cls)).item()\n        correct[cls] = torch.sum((gt == cls) & (pred == cls)).item()\n        total[cls] = torch.sum(gt == cls).item()\n\n    valid_union = union != 0\n    ious[valid_union] = intersection[valid_union] / union[valid_union]\n\n    # Only consider the categories that exist in the current scene\n    gt_classes = torch.unique(gt)\n    valid_gt_classes = gt_classes[gt_classes != 0]  # ignore 0\n\n    # miou\n    mean_iou = ious[valid_gt_classes].mean().item()\n\n    # acc\n    valid_mask = gt != 0\n    correct_predictions = torch.sum((gt == pred) & valid_mask).item()\n    total_valid_points = torch.sum(valid_mask).item()\n    accuracy = correct_predictions / total_valid_points if total_valid_points > 0 else float('nan')\n\n    class_accuracy = correct / total\n    # mAcc.\n    mean_class_accuracy = class_accuracy[valid_gt_classes].mean().item()\n\n    return ious, mean_iou, accuracy, mean_class_accuracy\n\nif __name__ == \"__main__\":\n    scene_list = [  'scene0000_00', 'scene0062_00', 'scene0070_00', 'scene0097_00', 'scene0140_00', \n           
         'scene0200_00', 'scene0347_00', 'scene0400_00', 'scene0590_00', 'scene0645_00']\n\n    iteration = 90000\n    for scan_name in scene_list:\n        # (1) GT ply    change!\n        gt_file_path = f\"/gdata/cold1/wuyanmin/OpenGaussian/data/scannet_2d_3types/{scan_name}/{scan_name}_vh_clean_2.labels.ply\"\n        points, labels = read_labels_from_ply(gt_file_path)\n\n        # (2) note: 19 & 15 & 10 classes\n        # Given the category ID that needs to be queried (relative to the original NYU40), obtain the corresponding category name.\n        target_id = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36]   # 19\n        # target_id = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 14, 16, 33, 34]   # 15\n        # target_id = [1,2,4,5,6,7,8,9,10,33] # 10\n\n        target_dict = {key: nyu40_dict[key] for key in target_id}\n        target_names = list(target_dict.values())\n\n        # (3) update gt label\n        # Obtained new point cloud labels, taking 19 categories as an example, where updated_labels are labels 0, 1-19.\n        target_id_mapping = {value: index + 1 for index, value in enumerate(target_id)}\n        updated_labels = np.zeros_like(labels)\n        for original_value, new_value in target_id_mapping.items():\n            updated_labels[labels == original_value] = new_value\n        updated_gt_labels = torch.from_numpy(updated_labels.astype(np.int64)).cuda()\n        \n        # (4) load gaussian ply file\n        model_path = f\"output/{scan_name}/\"\n        ply_path = os.path.join(model_path, f\"point_cloud/iteration_{iteration}/point_cloud.ply\")\n        ply_data = PlyData.read(ply_path)\n        vertex_data = ply_data['vertex'].data\n        # NOTE Filter out points based on their opacity values.\n        ignored_pts = sigmoid(vertex_data[\"opacity\"]) < 0.1\n        updated_gt_labels[ignored_pts] = 0\n\n        # (5) load cluster language file\n        mapping_file = os.path.join(model_path, \"cluster_lang.npz\")\n        # 
load the saved codebook(leaf id) and instance-level language feature\n        # 'leaf_feat', 'leaf_score', 'occu_count', 'leaf_ind'\n        saved_data = np.load(mapping_file)\n        leaf_lang_feat = torch.from_numpy(saved_data[\"leaf_feat.npy\"]).cuda()    # [num_leaf=k1*k2, 512] \n        leaf_score = torch.from_numpy(saved_data[\"leaf_score.npy\"]).cuda()       # [num_leaf=k1*k2] \n        leaf_occu_count = torch.from_numpy(saved_data[\"occu_count.npy\"]).cuda()  # [num_leaf=k1*k2] \n        leaf_ind = torch.from_numpy(saved_data[\"leaf_ind.npy\"]).cuda()           # [num_pts] \n        leaf_lang_feat[leaf_occu_count < 2] *= 0.0\n        leaf_ind = leaf_ind.clamp(max=319)  # valid leaf ids are 0..319 (k1*k2 = 64*5 = 320)\n\n        # (6) load query text feat.\n        with open('assets/text_features.json', 'r') as f:\n            data_loaded = json.load(f)\n        all_texts = list(data_loaded.keys())\n        text_features = torch.from_numpy(np.array(list(data_loaded.values()))).to(torch.float32)  # [num_text, 512]\n        \n        query_text_feats = torch.zeros(len(target_names), 512).cuda()\n        for i, text in enumerate(target_names):\n            feat = text_features[all_texts.index(text)].unsqueeze(0)\n            query_text_feats[i] = feat\n\n        # (7) Calculate the cosine similarity and return the ID of the category with the highest value.\n        query_text_feats = F.normalize(query_text_feats, dim=1, p=2)  \n        leaf_lang_feat = F.normalize(leaf_lang_feat, dim=1, p=2)  \n        cosine_similarity = torch.matmul(query_text_feats, leaf_lang_feat.transpose(0, 1))\n        # cosine_similarity = torch.mm(query_text_feats, leaf_lang_feat.t())   # [cls_num, cluster_num]\n        max_id = torch.argmax(cosine_similarity, dim=0) # [cluster_num]\n        pred_pts_cls_id = max_id[leaf_ind] + 1          # [num_pts] \n\n        ious, mean_iou, accuracy, mean_acc = calculate_metrics(updated_gt_labels, pred_pts_cls_id, total_classes=len(target_names)+1)\n        print(f\"Scene: 
{scan_name}, mIoU: {mean_iou:.4f}, mAcc.: {mean_acc:.4f}\") "
  },
  {
    "path": "scripts/render_by_click.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport torch\nimport torch.nn.functional as F\nfrom scene import Scene\nimport os\nfrom tqdm import tqdm\nfrom os import makedirs\nfrom gaussian_renderer import render\nimport torchvision\nfrom utils.general_utils import safe_state\nfrom argparse import ArgumentParser\nfrom arguments import ModelParams, PipelineParams, get_combined_args\nfrom gaussian_renderer import GaussianModel\nimport numpy as np\nfrom PIL import Image\nimport json\nfrom utils.opengs_utlis import mask_feature_mean, get_SAM_mask_and_feat, load_code_book\nimport pytorch3d.ops\n\nnp.random.seed(42)\ncolors_defined = np.random.randint(100, 256, size=(300, 3))\ncolors_defined[0] = np.array([0, 0, 0])\ncolors_defined = torch.from_numpy(colors_defined)\n\ndef get_pixel_values(image_path, position, radius=10):\n    with Image.open(image_path) as img:\n        img = img.convert('RGB')\n        width, height = img.size\n        \n        left = max(position[0] - radius, 0)\n        right = min(position[0] + radius + 1, width)\n        top = max(position[1] - radius, 0)\n        bottom = min(position[1] + radius + 1, height)\n\n        pixels = []\n        for x in range(left, right):\n            for y in range(top, bottom):\n                pixels.append(img.getpixel((x, y)))\n\n        pixels_array = np.array(pixels)\n        mean_pixel = pixels_array.mean(axis=0)\n    \n    return tuple(mean_pixel)\n\ndef compute_click_values(model_path, image_name, pix_xy, radius=5):\n    def compute_level_click_val(iter, model_path, image_name, pix_xy, radius):\n        img_path1 = f\"{model_path}/train/ours_{iter}/renders_ins_feat1/{image_name}_1.png\"      # TODO\n        img_path2 = 
f\"{model_path}/train/ours_{iter}/renders_ins_feat2/{image_name}_2.png\"      # TODO\n        val1 = get_pixel_values(img_path1, pix_xy, radius)\n        val2 = get_pixel_values(img_path2, pix_xy, radius)\n        click_val = (torch.tensor(list(val1) + list(val2)) / 255) * 2 - 1\n        return click_val\n    \n    level1_click_val = compute_level_click_val(50000, model_path, image_name, pix_xy, radius)   # TODO\n    level2_click_val = compute_level_click_val(70000, model_path, image_name, pix_xy, radius)   # TODO\n    \n    return level1_click_val, level2_click_val\n\ndef render_set(model_path, name, iteration, views, gaussians, pipeline, background):\n    render_path = os.path.join(model_path, name, \"ours_{}\".format(iteration), \"renders\")\n    gts_path = os.path.join(model_path, name, \"ours_{}\".format(iteration), \"gt\")\n\n    render_ins_feat_path = os.path.join(model_path, name, \"ours_{}\".format(iteration), \"renders_ins_feat\")\n    gt_sam_mask_path = os.path.join(model_path, name, \"ours_{}\".format(iteration), \"gt_sam_mask\")\n    pseudo_ins_feat_path = os.path.join(model_path, name, \"ours_{}\".format(iteration), \"pseudo_ins_feat\")\n\n    makedirs(render_path, exist_ok=True)\n    makedirs(gts_path, exist_ok=True)\n    makedirs(render_ins_feat_path, exist_ok=True)\n    makedirs(gt_sam_mask_path, exist_ok=True)\n    makedirs(pseudo_ins_feat_path, exist_ok=True)\n\n    # load codebook\n    root_code_book, root_cluster_indices = load_code_book(os.path.join(model_path, \"point_cloud\", \\\n        f'iteration_{iteration}', \"root_code_book\"))\n    leaf_code_book, leaf_cluster_indices = load_code_book(os.path.join(model_path, \"point_cloud\", \\\n        f'iteration_{iteration}', \"leaf_code_book\"))\n    root_cluster_indices = torch.from_numpy(root_cluster_indices).cuda()\n    leaf_cluster_indices = torch.from_numpy(leaf_cluster_indices).cuda()\n    # counts = torch.bincount(torch.from_numpy(cluster_indices), minlength=64)\n\n    # load the saved 
codebook(leaf id) and instance-level language feature\n    # 'leaf_feat', 'leaf_score', 'occu_count', 'leaf_ind'       leaf_figurines_cluster_lang\n    mapping_file = os.path.join(model_path, \"cluster_lang.npz\")\n    saved_data = np.load(mapping_file)\n    leaf_lang_feat = torch.from_numpy(saved_data[\"leaf_feat.npy\"]).cuda()    # [num_leaf=640, 512] Language feature of each instance\n    leaf_score = torch.from_numpy(saved_data[\"leaf_score.npy\"]).cuda()       # [num_leaf=640] Score of each instance\n    leaf_occu_count = torch.from_numpy(saved_data[\"occu_count.npy\"]).cuda()  # [num_leaf=640] Number of occurrences of each instance\n    leaf_ind = torch.from_numpy(saved_data[\"leaf_ind.npy\"]).cuda()           # [num_pts] Instance ID corresponding to each point\n    leaf_lang_feat[leaf_occu_count < 5] *= 0.0      # ignore leaves that occur fewer than 5 times\n    leaf_cluster_indices = leaf_ind\n    \n    image_name = \"frame_00002\"      # TODO\n    # # object_name = \"apple\"\n    # pix_xy = (450, 217) # bag of cookies\n    # pix_xy = (344, 350) # apple\n    # # teatime       image_name = \"frame_00002\"\n    # object_names = [\"bear nose\", \"stuffed bear\", \"sheep\", \"bag of cookies\", \\\n    #                 \"plate\", \"three cookies\", \"tea in a glass\", \"apple\", \\\n    #                 \"coffee mug\", \"coffee\", \"paper napkin\"]\n    # pix_xy_list = [ (740, 80), (800, 160), (80, 240), (450, 200),\n    #                 (468, 288), (438, 273), (309, 308), (343, 361),\n    #                 (578, 274), (571, 260), (565, 380)]\n    # figurines   image_name = \"frame_00002\"\n    # TODO\n    object_names = [\"rubber duck with buoy\", \"porcelain hand\", \"miffy\", \"toy elephant\", \"toy cat statue\", \\\n                    \"jake\", \"Play-Doh bucket\", \"rubber duck with hat\", \"rubics cube\", \"waldo\", \\\n                    \"twizzlers\", \"red toy chair\", \"green toy chair\", \"pink ice cream\", \"spatula\", \\\n                    \"pikachu\", \"green apple\", 
\"rabbit\", \"old camera\", \"pumpkin\", \\\n                    \"tesla door handle\"]\n    # TODO\n    pix_xy_list = [ (103, 378), (552, 390), (896, 342), (720, 257), (254, 297),\n                    (451, 197), (626, 256), (760, 166), (781, 243), (896, 136),\n                    (927, 241), (688, 148), (538, 160), (565, 238), (575, 257),\n                    (377, 156), (156, 244), (21, 237), (283, 152), (330, 200),\n                    (514, 200)]\n    # # ramen           image_name = \"frame_00002\"\n    # object_names = [\"clouth\", \"sake cup\", \"chopsticks\", \"spoon\", \"plate\", \\\n    #                 \"bowl\", \"egg\", \"nori\", \"glass of water\", \"napkin\"]\n    # pix_xy_list = [(345, 38), (276, 424), (361, 370), (419, 285), (688, 412),\n    #                (489, 119), (694, 187), (810, 154), (939, 289), (428, 462)]\n    # # waldo_kitchen     image_name = \"frame_00001\"\n    # object_names = [\"knife\", \"pour-over vessel\", \"glass pot1\", \"glass pot2\", \"toaster\", \\\n    #                 \"hot water pot\", \"metal can\", \"cabinet\", \"ottolenghi\", \"waldo\"]\n    # pix_xy_list = [(439, 76), (410, 297), (306, 127), (349, 182), (261, 256),\n    #                (201, 262), (161, 267), (80, 34), (17, 141), (76, 169)]\n\n    for o_i, object in enumerate(object_names):\n        pix_xy = pix_xy_list[o_i]\n        root_click_val, leaf_click_val = compute_click_values(model_path, image_name, pix_xy)\n    \n        # Compute the nearest clusters with respect to the two-level codebook\n        distances_root = torch.norm(root_click_val - root_code_book[\"ins_feat\"][:, :-3].cpu(), dim=1)\n        distances_leaf = torch.norm(leaf_click_val - leaf_code_book[\"ins_feat\"][:-1, :].cpu(), dim=1)\n        distances_leaf[leaf_code_book[\"ins_feat\"][:-1].sum(-1) == 0] = 999  # Assign a large value to dis for nodes that remain unassigned\n        \n        # Retrieve the candidate child nodes linked to each selected root node\n        min_index_root = 
torch.argmin(distances_root).item()\n        leaf_num = (leaf_code_book[\"ins_feat\"].shape[0] - 1) / root_code_book[\"ins_feat\"].shape[0]\n        start_id = int(min_index_root*leaf_num)\n        end_id = int((min_index_root + 1)*leaf_num)\n        distances_leaf_sub = distances_leaf[start_id: end_id]   # [10]\n\n        # # (1) Choose several child nodes that fulfill the requirements\n        # click_leaf_indices = torch.nonzero(distances_leaf_sub < 0.9).squeeze() + start_id\n        # if (click_leaf_indices.dim() == 0) and click_leaf_indices.numel() != 0:\n        #     click_leaf_indices = click_leaf_indices.unsqueeze(0) \n        # elif click_leaf_indices.numel() == 0:\n        #     click_leaf_indices = torch.argmin(distances_leaf_sub).unsqueeze(0)\n        # (2) identify the root-level codebook and then pick the closest leaf node inside it (preferred)\n        click_leaf_indices = torch.argmin(distances_leaf_sub).unsqueeze(0) + start_id\n        # (3) directly select the child node with the minimum distance (less precise)\n        # click_leaf_indices = torch.argmin(distances_leaf).unsqueeze(0)\n        # # (4) you can also directly specify a particular child node if needed\n        # click_leaf_indices = torch.tensor([60, 66])     # 64 picachu, 60, 66 toy elephant, 65 jake, 633 green apple, 639 duck\n        \n        # Get the mask linked to the child node\n        pre_pts_mask = (leaf_cluster_indices.unsqueeze(1) == click_leaf_indices.cuda()).any(dim=1)\n\n        # post process  modify-----\n        post_process = True\n        max_time = 5\n        if post_process and max_time > 0:\n            nearest_k_distance = pytorch3d.ops.knn_points(\n                gaussians._xyz[pre_pts_mask].unsqueeze(0),\n                gaussians._xyz[pre_pts_mask].unsqueeze(0),\n                K=int(pre_pts_mask.sum()**0.5) * 2,\n            ).dists\n            mean_nearest_k_distance, std_nearest_k_distance = nearest_k_distance.mean(), nearest_k_distance.std()\n        
    # print(std_nearest_k_distance, \"std_nearest_k_distance\")\n\n            # mask = nearest_k_distance.mean(dim = -1) < mean_nearest_k_distance + std_nearest_k_distance\n            mask = nearest_k_distance.mean(dim = -1) < mean_nearest_k_distance + 0.1 * std_nearest_k_distance\n            # mask = nearest_k_distance.mean(dim = -1) < 2 * mean_nearest_k_distance \n\n            mask = mask.squeeze()\n            if pre_pts_mask is not None:\n                pre_pts_mask[pre_pts_mask != 0] = mask\n            max_time -= 1\n\n        # out_dir = \"ca9c2998-e\"\n        # splits = [\"train\", \"train\", \"train\", \"train\", \"test\"]\n        # frame_name_list = [\"frame_00053\", \"frame_00066\", \"frame_00140\", \"frame_00154\", \"frame_00089\"]\n        # for f_i, frame_name in enumerate(frame_name_list):\n        #     base_path = f\"/mnt/disk1/codes/wuyanmin/code/OpenGaussian/output/{out_dir}/{splits[f_i]}/ours_70000/renders_cluster_silhouette\"\n        #     target_path = f\"/mnt/disk1/codes/wuyanmin/code/OpenGaussian/output/{out_dir}/{splits[f_i]}/ours_70000/result/{frame_name}\"\n        #     makedirs(target_path, exist_ok=True)\n        #     for _, text in enumerate(waldo_kitchen_texts):\n        #         pos_feat = text_features[query_texts.index(text)].unsqueeze(0)\n        #         similarity_pos = F.cosine_similarity(pos_feat, leaf_lang_feat.cpu())    # [640]\n        #         top_values, top_indices = torch.topk(similarity_pos, 10)   # [num_mask]\n        #         print(\"text: {} | cluster id: {}\".format(text, top_indices[0]))\n        #         ori_img_name = base_path + f\"/{frame_name}_cluster_{top_indices[0].item()}.png\"\n        #         new_name = target_path + f\"/{text}.png\"\n                \n        #         if not os.path.exists(ori_img_name):\n        #             top = 10\n        #             for i in range(top):\n        #                 ori_img_name = target_path + 
f\"/{frame_name}_cluster_{top_indices[i].item()}.png\"\n        #                 if os.path.exists(ori_img_name):\n        #                     break\n        #         if not os.path.exists(ori_img_name):\n        #             print(f\"No file found at {ori_img_name}. Operation skipped.\")\n        #             continue\n        #         import shutil\n        #         shutil.copy2(ori_img_name, new_name)\n\n        # render\n        for idx, view in enumerate(tqdm(views, desc=\"Rendering progress\")):\n            # render_pkg = render(view, gaussians, pipeline, background, iteration, rescale=False)\n            \n            # # figurines\n            # if  view.image_name not in [\"frame_00041\", \"frame_00105\", \"frame_00152\", \"frame_00195\"]:\n            #     continue\n            # # teatime\n            # if  view.image_name not in [\"frame_00002\", \"frame_00025\", \"frame_00043\", \"frame_00107\", \"frame_00129\", \"frame_00140\"]:\n            #     continue\n            # # ramen\n            # if  view.image_name not in [\"frame_00006\", \"frame_00024\", \"frame_00060\", \"frame_00065\", \"frame_00081\", \"frame_00119\", \"frame_00128\"]:\n            #     continue\n            # # waldo_kitchen\n            # if  view.image_name not in [\"frame_00053\", \"frame_00066\", \"frame_00089\", \"frame_00140\", \"frame_00154\"]:\n            #     continue\n\n            # NOTE render\n            render_pkg = render(view, gaussians, pipeline, background, iteration,\n                                rescale=False,                #)  # whether to re-scale the gaussian scale\n                                # cluster_idx=leaf_cluster_indices,     # root id \n                                leaf_cluster_idx=leaf_cluster_indices,            # leaf id               \n                                selected_leaf_id=click_leaf_indices.cuda(),       # selected leaf id      \n                                render_feat_map=True, \n                         
       render_cluster=False,\n                                better_vis=True,\n                                pre_mask=pre_pts_mask,\n                                seg_rgb=True)\n            rendering = render_pkg[\"render\"]\n            rendered_cluster_imgs = render_pkg[\"leaf_clusters_imgs\"]\n            occured_leaf_id = render_pkg[\"occured_leaf_id\"]\n            rendered_leaf_cluster_silhouettes = render_pkg[\"leaf_cluster_silhouettes\"]\n\n            # save Rendered RGB\n            torchvision.utils.save_image(rendering, os.path.join(render_path, view.image_name + \".png\"))\n\n            render_cluster_path = os.path.join(model_path, name, \"ours_{}\".format(iteration), \"click_cluster\")\n            render_cluster_silhouette_path = os.path.join(model_path, name, \"ours_{}\".format(iteration), \"click_cluster_mask\")\n            makedirs(render_cluster_path, exist_ok=True)\n            makedirs(render_cluster_silhouette_path, exist_ok=True)\n            for i, img in enumerate(rendered_cluster_imgs):\n                torchvision.utils.save_image(img[:3,:,:], os.path.join(render_cluster_path, \\\n                    view.image_name + f\"_{object}_cluster_{occured_leaf_id[i]}.png\"))\n                # save mask\n                cluster_silhouette = rendered_leaf_cluster_silhouettes[i] > 0.8\n                torchvision.utils.save_image(cluster_silhouette.to(torch.float32), os.path.join(render_cluster_silhouette_path, \\\n                    view.image_name + f\"_{object}_cluster_{occured_leaf_id[i]}.png\"))\n\ndef render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool):\n    with torch.no_grad():\n        gaussians = GaussianModel(dataset.sh_degree)\n        scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)\n\n        bg_color = [1,1,1] if dataset.white_background else [0, 0, 0]\n        background = torch.tensor(bg_color, dtype=torch.float32, 
device=\"cuda\")\n\n        if not skip_train:\n             render_set(dataset.model_path, \"train\", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background)\n\n        if not skip_test:\n             render_set(dataset.model_path, \"test\", scene.loaded_iter, scene.getTestCameras(), gaussians, pipeline, background)\n\nif __name__ == \"__main__\":\n    # Set up command line argument parser\n    parser = ArgumentParser(description=\"Testing script parameters\")\n    model = ModelParams(parser, sentinel=True)\n    pipeline = PipelineParams(parser)\n    parser.add_argument(\"--iteration\", default=-1, type=int)\n    parser.add_argument(\"--skip_train\", action=\"store_true\")\n    parser.add_argument(\"--skip_test\", action=\"store_true\")\n    parser.add_argument(\"--quiet\", action=\"store_true\")\n    args = get_combined_args(parser)\n    print(\"Rendering \" + args.model_path)\n\n    # Initialize system state (RNG)\n    safe_state(args.quiet)\n\n    render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test)"
  },
  {
    "path": "scripts/scannet2blender.py",
    "content": "import os\nimport json\nimport numpy as np\n\ndef load_transform_matrix(file_path):\n    \"\"\"\n    Load the transform matrix from a text file.\n    \"\"\"\n    with open(file_path, 'r') as file:\n        matrix = [list(map(float, line.strip().split())) for line in file]\n    return matrix\n\ndef process_directory(directory_path):\n    \"\"\"\n    Process each directory and create a JSON file with the transform matrices.\n    \"\"\"\n    color_dir = os.path.join(directory_path, \"color\")           # TODO\n    pose_dir = os.path.join(directory_path, \"pose\")             # TODO\n    intrinsic_dir = os.path.join(directory_path, \"intrinsic\")   # TODO\n\n    # Check if both directories exist\n    if not os.path.isdir(color_dir) or not os.path.isdir(pose_dir):\n        return\n\n    # scannet\n    transform_data = {\n            'w': 1296,\n            'h': 968,\n            'fl_x': 1170.187988,\n            'fl_y': 1170.187988,\n            'cx': 647.75,\n            'cy': 483.75,\n            # 'aabb_scale': 2,\n            'frames': [],\n        }\n    # # scannet\n    # transform_data = {\n    #         'w': 640,\n    #         'h': 512,\n    #         'fl_x': 534.56,\n    #         'fl_y': 534.80,\n    #         'cx': 314.27,\n    #         'cy': 259.96,\n    #         # 'aabb_scale': 2,\n    #         'frames': [],\n    #     }\n    # Collect all image names and sort them\n    img_names = [img_name for img_name in os.listdir(color_dir) if img_name.endswith(\".jpg\")]\n    # img_names.sort(key=lambda x: int(os.path.splitext(x)[0]))  # Sort by image number\n    img_names.sort(key=lambda x: os.path.splitext(x)[0])  # Sort by image number\n\n    # Iterate over the color images\n    for img_name in img_names:\n        if img_name.endswith(\".jpg\"):\n            # Construct the corresponding pose file path\n            pose_file = os.path.splitext(img_name)[0] + \".txt\"\n            pose_file_path = os.path.join(pose_dir, pose_file)\n\n            
intrinsic_file = os.path.splitext(img_name)[0] + \".txt\"\n            intrinsic_file_path = os.path.join(intrinsic_dir, intrinsic_file)\n\n            # Check if the pose file exists\n            if os.path.isfile(pose_file_path):\n                transform_matrix = load_transform_matrix(pose_file_path)\n                \n                # note: colmap --> blender\n                transform_matrix = np.array(transform_matrix)\n                transform_matrix[:3, 1:3] *= -1     \n                transform_matrix = transform_matrix.tolist()\n\n                frame_data = {\n                    \"file_path\": os.path.join(\"color\", os.path.splitext(img_name)[0]),\n                    \"transform_matrix\": transform_matrix\n                }\n\n                if os.path.isfile(intrinsic_file_path):\n                    intrinsic_info = load_transform_matrix(intrinsic_file_path)\n                    frame_data.update({\n                        'fl_x': intrinsic_info[0][0],\n                        'fl_y': intrinsic_info[1][1],\n                        'cx':  intrinsic_info[0][2],\n                        'cy': intrinsic_info[1][2]\n                    })\n\n                transform_data[\"frames\"].append(frame_data)\n\n    return transform_data\n\n# Directory containing the scenes\nbase_directory = 'PATH_TO_YOUR_SCANNET'     # TODO\n\n# Process each scene directory and create JSON files\nfor scene_dir in os.listdir(base_directory):\n    # if scene_dir != \"scene0000_00\":\n    #     continue\n    \n    scene_path = os.path.join(base_directory, scene_dir)\n    if os.path.isdir(scene_path):\n        # Process the directory and get the transform data\n        transform_data = process_directory(scene_path)\n\n        print(scene_path)\n        \n        # Create the JSON file\n        if transform_data:\n            json_file_path = os.path.join(scene_path, \"transforms_train.json\")\n            with open(json_file_path, 'w') as json_file:\n                
json.dump(transform_data, json_file, indent=4)\n"
  },
  {
    "path": "scripts/train_lerf.sh",
    "content": "#!/bin/bash\n# chmod +x scripts/train_lerf.sh\n# ./scripts/train_lerf.sh\n\n# !!! Please check the dataset path specified by -s.\n\n# Total training steps: 70k\n# 3dgs pre-train: 0~30k\n# stage1: 30~40k\n# stage2 (coarse-level): 40~50k\n# stage2 (fine-level): 50k~70k\n\n# ###############################################\n# #              (1/4) figurines\n# # Training takes approximately 70 minutes on a 24G 4090 GPU.\n# # The object selection effect is better (recommended), the point cloud visualization is poor (not recommended).\n# # k1=64, k2=10\n# # --pos_weight 0.5\n# # --save_memory: Saves memory, but will reduce training speed. If your GPU memory > 24GB, you can omit this flag\n# ###############################################\nscan=\"figurines\"\ngpu_num=3           # change\necho \"Training for ${scan} .....\"\nCUDA_VISIBLE_DEVICES=$gpu_num python train.py --port 601$gpu_num \\\n    -s /gdata/cold1/wuyanmin/OpenGaussian/data/lerf_ovs/${scan} \\\n    --iterations 70_000 \\\n    --start_ins_feat_iter 30_000 \\\n    --start_root_cb_iter 40_000 \\\n    --start_leaf_cb_iter 50_000 \\\n    --sam_level 3 \\\n    --root_node_num 64 \\\n    --leaf_node_num 10 \\\n    --pos_weight 0.5 \\\n    --save_memory \\\n    --test_iterations 30000 \\\n    --eval\n\n\n# ###############################################\n# #              (2/4) waldo_kitchen\n# # Training takes approximately 60 minutes on a 24G 4090 GPU.\n# # Good point cloud visualization result (recommended), suboptimal object selection effect.\n# # k1=64, k2=10\n# # --pos_weight 0.5\n# # No need to set save_memory, 24G is sufficient.\n# ###############################################\nscan=\"waldo_kitchen\"\ngpu_num=3           # change\necho \"Training for ${scan} .....\"\nCUDA_VISIBLE_DEVICES=$gpu_num python train.py --port 601$gpu_num \\\n    -s /gdata/cold1/wuyanmin/OpenGaussian/data/lerf_ovs/${scan} \\\n    --iterations 70_000 \\\n    --start_ins_feat_iter 30_000 \\\n    --start_root_cb_iter 
40_000 \\\n    --start_leaf_cb_iter 50_000 \\\n    --sam_level 3 \\\n    --root_node_num 64 \\\n    --leaf_node_num 10 \\\n    --pos_weight 0.5 \\\n    --test_iterations 30000 \\\n    --eval\n\n\n# ###############################################\n# #              (3/4) teatime\n# # Training takes approximately 80 minutes on a 24G 4090 GPU.\n# # k1=32, k2=10\n# # --pos_weight 0.1\n# # --save_memory: Saves memory, but will reduce training speed. If your GPU memory > 24GB, you can omit this flag\n# ###############################################\nscan=\"teatime\"\ngpu_num=3       # change\necho \"Training for ${scan} .....\"\nCUDA_VISIBLE_DEVICES=$gpu_num python train.py --port 601$gpu_num \\\n    -s /gdata/cold1/wuyanmin/OpenGaussian/data/lerf_ovs/${scan} \\\n    --iterations 70_000 \\\n    --start_ins_feat_iter 30_000 \\\n    --start_root_cb_iter 40_000 \\\n    --start_leaf_cb_iter 50_000 \\\n    --sam_level 3 \\\n    --root_node_num 32 \\\n    --leaf_node_num 10 \\\n    --pos_weight 0.1 \\\n    --save_memory \\\n    --test_iterations 30000 \\\n    --eval\n\n\n# ###############################################\n# #              (4/4) ramen\n# # Training takes approximately 40 minutes on a 24G 4090 GPU.\n# # The object selection effect is the worst and unstable (not recommended).\n# # k1=64, k2=10\n# # --pos_weight 0.5\n# # --loss_weight 0.01: the weight of intra-mask smooth loss. 
0.1 is used for the other scenes.\n# # No need to set save_memory, 24G is sufficient.\n# ###############################################\nscan=\"ramen\"\ngpu_num=3\necho \"Training for ${scan} .....\"\nCUDA_VISIBLE_DEVICES=$gpu_num python train.py --port 601$gpu_num \\\n    -s /gdata/cold1/wuyanmin/OpenGaussian/data/lerf_ovs/${scan} \\\n    --iterations 70_000 \\\n    --start_ins_feat_iter 30_000 \\\n    --start_root_cb_iter 40_000 \\\n    --start_leaf_cb_iter 50_000 \\\n    --sam_level 3 \\\n    --root_node_num 64 \\\n    --leaf_node_num 10 \\\n    --pos_weight 0.5 \\\n    --loss_weight 0.01 \\\n    --test_iterations 30000 \\\n    --eval"
  },
  {
    "path": "scripts/train_scannet.sh",
    "content": "#!/bin/bash\n# chmod +x scripts/train_scannet.sh\n# ./scripts/train_scannet.sh\n\n# ============== [Notice] ==============\n# 1. The 10 scene hyperparameters in the ScanNet dataset are consistent.\n# 2. Train a scene for about 20 minutes on a 24G 4090 GPU.\n# 3. Please check the dataset path specified by -s.\n\n# ============== [Hyperparameter explanation] ==============\n# Total training steps: 90k\n# 3dgs pre-train: 0~30k\n# stage1: 30~50k\n# stage2 (coarse-level): 50~70k\n# stage2 (fine-level): 70k~90k\n# k1=64, k2=5\n# frozen_init_pts: The point clouds provided by the ScanNet dataset are frozen, without using the densification scheme of 3DGS.\n# -r 2 : We use half-resolution data for training.\n\n# ============== [10 scenes] ==============\nscan_list=(\"scene0000_00\" \"scene0062_00\" \"scene0070_00\" \"scene0097_00\" \"scene0140_00\" \\\n\"scene0200_00\" \"scene0347_00\" \"scene0400_00\" \"scene0590_00\" \"scene0645_00\")\n\ngpu_num=3     # change!\nfor scan in \"${scan_list[@]}\"; do\n    echo \"Training for ${scan} .....\"\n    CUDA_VISIBLE_DEVICES=$gpu_num python train.py --port 601$gpu_num \\\n        -s /gdata/cold1/wuyanmin/OpenGaussian/data/onedrive/scannet/${scan} \\\n        -r 2 \\\n        --frozen_init_pts \\\n        --iterations 90_000 \\\n        --start_ins_feat_iter 30_000 \\\n        --start_root_cb_iter 50_000 \\\n        --start_leaf_cb_iter 70_000 \\\n        --sam_level 0 \\\n        --root_node_num 64 \\\n        --leaf_node_num 5 \\\n        --pos_weight 1.0 \\\n        --test_iterations 30000 \\\n        --eval\ndone"
  },
  {
    "path": "scripts/vis_opengs_pts_feat.py",
    "content": "import numpy as np\nfrom plyfile import PlyData\nimport open3d as o3d\n\ndef sigmoid(x):\n    \"\"\"Sigmoid function.\"\"\"\n    return 1 / (1 + np.exp(-x))\n\ndef visualize_ply(ply_path):\n    # Load the PLY file\n    ply_data = PlyData.read(ply_path)\n    vertex_data = ply_data['vertex'].data\n\n    # Extract the point cloud attributes\n    points = np.array([vertex_data['x'], vertex_data['y'], vertex_data['z']]).T\n    colors = np.array([vertex_data['red'], vertex_data['green'], vertex_data['blue']]).T / 255.0\n    opacity = vertex_data['opacity']\n\n    # Apply the opacity filter\n    sigmoid_opacity = sigmoid(opacity)\n    filtered_indices = sigmoid_opacity >= 0.1\n    filtered_points = points[filtered_indices]\n    filtered_colors = colors[filtered_indices]\n\n    # Create an Open3D PointCloud object\n    pcd = o3d.geometry.PointCloud()\n    pcd.points = o3d.utility.Vector3dVector(filtered_points)\n    pcd.colors = o3d.utility.Vector3dVector(filtered_colors)\n\n    # Visualize the point cloud\n    o3d.visualization.draw_geometries([pcd])\n\nif __name__ == \"__main__\":\n    # Replace with the path to your PLY file\n    ply_path = \"output/xxxxxxxx-x/point_cloud/iteration_x0000/point_cloud.ply\"\n    visualize_ply(ply_path)"
  },
  {
    "path": "train.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport os\nimport torch\nimport torch.nn.functional as F\nfrom random import randint\nfrom utils.loss_utils import l1_loss, ssim, l2_loss\nfrom gaussian_renderer import render, network_gui\nimport sys\nfrom scene import Scene, GaussianModel\nfrom utils.general_utils import safe_state\nimport uuid\nfrom tqdm import tqdm\nfrom utils.image_utils import psnr\nfrom argparse import ArgumentParser, Namespace\nfrom arguments import ModelParams, PipelineParams, OptimizationParams\nfrom utils.graphics_utils import getWorld2View2, focal2fov, fov2focal\nfrom os import makedirs\nimport torchvision\nimport numpy as np\nfrom utils.sh_utils import RGB2SH\nimport math\n# import faiss\nfrom scene.kmeans_quantize import Quantize_kMeans\nfrom bitarray import bitarray\nfrom utils.system_utils import mkdir_p\nfrom utils.opengs_utlis import mask_feature_mean, pair_mask_feature_mean, \\\n    get_SAM_mask_and_feat, load_code_book, \\\n    calculate_iou, calculate_distances, calculate_pairwise_distances\n\ntry:\n    from torch.utils.tensorboard import SummaryWriter\n    TENSORBOARD_FOUND = True\nexcept ImportError:\n    TENSORBOARD_FOUND = False\n\n# Randomly initialize 300 colors for visualizing the SAM mask. 
[OpenGaussian]\nnp.random.seed(42)\ncolors_defined = np.random.randint(100, 256, size=(300, 3))\ncolors_defined[0] = np.array([0, 0, 0]) # Ignore the mask ID of -1 and set it to black.\ncolors_defined = torch.from_numpy(colors_defined)\n\ndef dec2binary(x, n_bits=None):\n    \"\"\"Convert decimal integer x to binary.\n\n    Code from: https://stackoverflow.com/questions/55918468/convert-integer-to-pytorch-tensor-of-binary-bits\n    \"\"\"\n    if n_bits is None:\n        n_bits = torch.ceil(torch.log2(x)).type(torch.int64)\n    mask = 2**torch.arange(n_bits-1, -1, -1).to(x.device, x.dtype)\n    return x.unsqueeze(-1).bitwise_and(mask).ne(0)\n\ndef save_kmeans(kmeans_list, quantized_params, out_dir, mode=\"root\"):\n    \"\"\"Save the codebook and indices of KMeans.\n\n    \"\"\"\n    # Convert to bitarray object to save compressed version\n    # saving as npy or pth will use 8bits per digit (or boolean) for the indices\n    # Convert to binary, concat the indices for all params and save.\n    if mode==\"root\":\n        out_dir = os.path.join(out_dir, 'root_code_book')\n    elif mode==\"leaf\":\n        out_dir = os.path.join(out_dir, 'leaf_code_book')\n    \n    mkdir_p(out_dir)\n    bitarray_all = bitarray([])\n    for kmeans in kmeans_list:\n        if mode==\"root\":\n            cls_ids = kmeans.cls_ids\n        elif mode==\"leaf\":\n            cls_ids = kmeans.leaf_cls_ids\n        n_bits = int(np.ceil(np.log2(len(cls_ids))))\n        assignments = dec2binary(cls_ids, n_bits)\n        bitarr = bitarray(list(assignments.cpu().numpy().flatten()))\n        bitarray_all.extend(bitarr)\n    with open(os.path.join(out_dir, 'kmeans_inds.bin'), 'wb') as file:  # cls_ids\n        bitarray_all.tofile(file)\n\n    # Save details needed for loading\n    args_dict = {}\n    args_dict['params'] = quantized_params\n    args_dict['n_bits'] = n_bits\n    args_dict['total_len'] = len(bitarray_all)\n    np.save(os.path.join(out_dir, 'kmeans_args.npy'), args_dict)\n    if 
mode==\"root\":\n        centers_dict = {param: kmeans.centers for (kmeans, param) in zip(kmeans_list, quantized_params)}\n    elif mode==\"leaf\":\n        centers_dict = {param: kmeans.leaf_centers for (kmeans, param) in zip(kmeans_list, quantized_params)}\n\n    # Save codebook\n    torch.save(centers_dict, os.path.join(out_dir, 'kmeans_centers.pth'))\n\ndef cohesion_loss(feat_map, gt_mask, feat_mean_stack):\n    \"\"\"intra-mask smoothing loss. Eq.(1) in the paper\n    Constrain the feature of each pixel within the mask to be close to the mean feature of that mask.\n    \"\"\"\n    N, H, W = gt_mask.shape\n    C = feat_map.shape[0]\n    # expand feat_map [6, H, W] to [N, 6, H, W]\n    feat_map_expanded = feat_map.unsqueeze(0).expand(N, C, H, W)\n    # expand mean feat [N, 6] to [N, 6, H, W]\n    feat_mean_stack_expanded = feat_mean_stack.unsqueeze(-1).unsqueeze(-1).expand(N, C, H, W)\n    \n    # feature distance\n    masked_feat = feat_map_expanded * gt_mask.unsqueeze(1)           # [N, 6, H, W]\n    dist = (masked_feat - feat_mean_stack_expanded).norm(p=2, dim=1) # [N, H, W]\n    \n    # per mask feature distance (loss)\n    masked_dist = dist * gt_mask    # [N, H, W]\n    loss_per_mask = masked_dist.sum(dim=[1, 2]) / gt_mask.sum(dim=[1, 2]).clamp(min=1)\n\n    return loss_per_mask.mean()\n\ndef separation_loss(feat_mean_stack, iteration):\n    \"\"\" inter-mask contrastive loss Eq.(2) in the paper\n    Constrain the instance features within different masks to be as far apart as possible.\n    \"\"\"\n    N, _ = feat_mean_stack.shape\n\n    # expand feat_mean_stack[N, 6] to [N, N, C]\n    feat_expanded = feat_mean_stack.unsqueeze(1).expand(-1, N, -1)\n    feat_transposed = feat_mean_stack.unsqueeze(0).expand(N, -1, -1)\n    \n    # distance\n    diff_squared = (feat_expanded - feat_transposed).pow(2).sum(2)\n    \n    # Calculate the inverse of the distance to enhance discrimination\n    epsilon = 1     # 1e-6\n    inverse_distance = 1.0 / (diff_squared + 
epsilon)\n    # Exclude diagonal elements (distance from itself) and calculate the mean inverse distance\n    mask = torch.eye(N, device=feat_mean_stack.device).bool()\n    inverse_distance.masked_fill_(mask, 0)  \n\n    # note: weight\n    # sorted by distance\n    sorted_indices = inverse_distance.argsort().argsort()\n    loss_weight = (sorted_indices.float() / (N - 1)) * (1.0 - 0.1) + 0.1    # scale to 0.1 - 1.0, [N, N]\n    # small weight\n    if iteration > 35_000:\n        loss_weight[loss_weight < 0.9] = 0.1\n    inverse_distance *= loss_weight     # [N, N]\n\n    # final loss\n    loss = inverse_distance.sum() / (N * (N - 1))\n\n    return loss\n\ndef training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, \\\n             checkpoint, debug_from):\n    iterations = [opt.start_ins_feat_iter, opt.start_leaf_cb_iter, opt.start_root_cb_iter]\n    saving_iterations.extend(iterations)\n    checkpoint_iterations.extend(iterations)\n\n    first_iter = 0\n    tb_writer = prepare_output_and_logger(dataset)\n    gaussians = GaussianModel(dataset.sh_degree)\n    scene = Scene(dataset, gaussians)\n    gaussians.training_setup(opt)\n    if checkpoint:\n        (model_params, first_iter) = torch.load(checkpoint)\n        # NOTE: Load the original 3DGS pre-trained checkpoint and add the ins_feat attribute. 
[OpenGaussian]\n        if len(model_params) == 12:\n            # initialize instance color.\n            ins_feat = torch.rand((model_params[8].shape[0], opt.ins_feat_dim), dtype=torch.float, device=\"cuda\")\n            ins_feat = torch.nn.Parameter(ins_feat.requires_grad_(True))\n            to_list = list(model_params)\n            # (1) replace optimizer\n            to_list[10] = gaussians.optimizer.state_dict()\n            # (2) add ins_feat \n            to_list.insert(7, ins_feat)\n            # (3) add ins_feat_q (quantized ins_feat)\n            ins_feat_q = torch.empty(0)\n            to_list.insert(8, ins_feat_q)\n            model_params = tuple(to_list)\n        gaussians.restore(model_params, opt)\n        ins_feat_continue = gaussians._ins_feat.clone().detach()    # not used\n    else:\n        ins_feat_continue = None    # not used\n\n    # initialize the codebook\n    ins_feat_codebook = Quantize_kMeans(num_clusters=opt.root_node_num,         # k1\n                                        num_leaf_clusters=opt.leaf_node_num,    # k2\n                                        num_iters=5, \n                                        dim=9)\n    \n    # note: load the saved codebook\n    leaf_cluster_indices = None\n    if checkpoint:\n        base_dir = os.path.dirname(checkpoint)\n        load_iter = checkpoint.split('/')[-1].split('.')[0][6:]\n        root_code_book_path = os.path.join(base_dir, 'point_cloud', f\"iteration_{load_iter}\", \"root_code_book\")\n        leaf_code_book_path = os.path.join(base_dir, 'point_cloud', f\"iteration_{load_iter}\", \"leaf_code_book\")\n        if os.path.exists(os.path.join(root_code_book_path, 'kmeans_inds.bin')):\n            root_center, root_indices = load_code_book(root_code_book_path)\n            root_center_saved = root_center[\"ins_feat\"]\n            cluster_indices = torch.from_numpy(root_indices).cuda()\n            ins_feat_codebook.centers = root_center_saved\n            
ins_feat_codebook.cls_ids = cluster_indices\n        else:\n            cluster_indices = None\n        if os.path.exists(os.path.join(leaf_code_book_path, 'kmeans_inds.bin')):\n            leaf_center, leaf_indices = load_code_book(leaf_code_book_path)\n            leaf_center_saved = leaf_center[\"ins_feat\"]\n            leaf_cluster_indices = torch.from_numpy(leaf_indices).cuda()\n            ins_feat_codebook.leaf_centers = leaf_center_saved\n            ins_feat_codebook.leaf_cls_ids = leaf_cluster_indices\n        else:\n            leaf_cluster_indices = None\n\n    bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]\n    background = torch.tensor(bg_color, dtype=torch.float32, device=\"cuda\")\n\n    iter_start = torch.cuda.Event(enable_timing = True)\n    iter_end = torch.cuda.Event(enable_timing = True)\n\n    viewpoint_stack = None\n    ema_loss_for_log = 0.0\n    progress_bar = tqdm(range(first_iter, opt.iterations), desc=\"Training progress\")\n    first_iter += 1\n    root_id = 0                 # for stage 2.2\n    loss = torch.tensor(0.0)\n    Ll1 = torch.tensor(0.0)\n    for iteration in range(first_iter, opt.iterations + 1):        \n        no_need_bk = False\n        \n        if network_gui.conn == None:\n            network_gui.try_connect()\n        while network_gui.conn != None:\n            try:\n                net_image_bytes = None\n                custom_cam, do_training, pipe.convert_SHs_python, pipe.compute_cov3D_python, keep_alive, scaling_modifer = network_gui.receive()\n                if custom_cam != None:\n                    net_image = render(custom_cam, gaussians, pipe, background, iteration, scaling_modifer)[\"render\"]\n                    net_image_bytes = memoryview((torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2, 0).contiguous().cpu().numpy())\n                network_gui.send(net_image_bytes, dataset.source_path)\n                if do_training and ((iteration < int(opt.iterations)) or 
not keep_alive):\n                    break\n            except Exception as e:\n                network_gui.conn = None\n\n        iter_start.record()\n\n        gaussians.update_learning_rate(iteration, opt.start_root_cb_iter, opt.start_leaf_cb_iter)\n\n        # Every 1000 its we increase the levels of SH up to a maximum degree\n        if iteration % 1000 == 0:\n            gaussians.oneupSHdegree()\n\n        # Pick a random Camera\n        if not viewpoint_stack:\n            viewpoint_stack = scene.getTrainCameras().copy()\n        viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1))\n        if not viewpoint_cam.data_on_gpu:\n            viewpoint_cam.to_gpu()\n\n        cb_mode = None  # Current status: No launch codebook discretization\n        if iteration == 1:\n            print(\"[Stage 0] Start 3dgs pre-train ...\")\n            sys.stdout.flush()\n        if iteration == opt.start_ins_feat_iter + 1:\n            print(\"[Stage 1] Start continuous instance feature learning ...\")\n            sys.stdout.flush()\n        # Stage 2.1: Coarse-level codebook\n        if iteration > opt.start_root_cb_iter and iteration <= opt.start_leaf_cb_iter:\n            cb_mode = \"root\"\n            if iteration == opt.start_root_cb_iter + 1:\n                print(\"[Stage 2.1] Start coarse-level codebook discretization ...\")\n                sys.stdout.flush()\n        elif iteration > opt.start_leaf_cb_iter:\n            cb_mode = \"leaf\"\n            # Stage 2.2: Fine-level codebook\n            if iteration == opt.start_leaf_cb_iter + 1:\n                print(\"[Stage 2.2] Start fine-level codebook discretization ...\")\n                sys.stdout.flush()\n            # note Update a coarse cluster every leaf_update_fr(default 300) steps.\n            if (iteration - opt.start_leaf_cb_iter) % opt.leaf_update_fr == 0:\n                root_id += 1    # 0 ~ k1-1\n                if root_id > (opt.root_node_num-1):\n                    
root_id = 0\n        \n        # ###########################################################################\n        # [Stage 2]: Two-Level Codebook for Discretization                          #\n        #   - Preprocessing: construct pseudo labels (instance features of stage 1) #\n        #     Will execute twice, before coarse-level and fine-level clustering     #\n        # ###########################################################################\n        if (cb_mode is not None and viewpoint_cam.pesudo_ins_feat is None) or \\\n           ((iteration == opt.start_root_cb_iter + 1) or (iteration == opt.start_leaf_cb_iter + 1)):\n            with torch.no_grad():\n                if cb_mode == \"leaf\" and cluster_indices is None:\n                    cluster_indices = ins_feat_codebook.cls_ids # [num_pts], Coarse-level ID of each point (0 ~ k1-1)\n                construct_pseudo_ins_feat(scene, render, (pipe, background, iteration),\n                                          cluster_indices=cluster_indices, mode=cb_mode,\n                                          root_num=opt.root_node_num, leaf_num=opt.leaf_node_num,\n                                          sam_level=opt.sam_level,\n                                          save_memory=opt.save_memory)\n                if not viewpoint_cam.data_on_gpu:\n                    viewpoint_cam.to_gpu()\n                if cb_mode == \"leaf\":\n                    # Number of leaves per root\n                    ins_feat_codebook.iLeafSubNum = gaussians.iClusterSubNum\n\n        # Render\n        if (iteration - 1) == debug_from:\n            pipe.debug = True\n\n        bg = torch.rand((3), device=\"cuda\") if opt.random_background else background\n        \n        # ####################################################\n        # [Stage 2]: Two-Level Codebook for Discretization   #\n        #   - Update codebook                                #\n        # ####################################################\n    
    freq_k_means = 200       # coarse-level codebook update frequency\n        if cb_mode == \"leaf\":\n            freq_k_means = 50    # todo fine-level codebook update frequency\n        if cb_mode is not None:\n            if (iteration % freq_k_means == 1) or iteration == opt.start_root_cb_iter + 1:\n                assign = True   # Reassign cluster centers\n            else:\n                assign = False  #  update cluster centers\n            ins_feat_codebook.forward(gaussians, iteration, assign=assign, \\\n                                      mode=cb_mode, selected_leaf=root_id, \\\n                                      pos_weight=opt.pos_weight)   # note: position weight\n\n        # render function\n        if iteration <= opt.start_ins_feat_iter:    # stage 0\n            render_feat=False\n            render_cluster=False\n            cluster_indices=None\n        elif iteration > opt.start_leaf_cb_iter:  # stage 2.2 (fine-level)\n            render_feat=False   \n            render_cluster=True\n        else:   # stage 1, stage 2.1(coarse-level)\n            render_feat=True\n            render_cluster=False\n            cluster_indices=None\n        # rescale\n        if iteration > opt.start_root_cb_iter:  # stage 2, rescale\n            rescale=True\n        else:\n            rescale=False\n\n        render_pkg = render(viewpoint_cam, gaussians, pipe, bg, iteration,\n                            rescale=rescale,                # whether to re-scale the gaussian scale\n                            cluster_idx=cluster_indices,    # coarse-level cluster id\n                            leaf_cluster_idx=ins_feat_codebook.leaf_cls_ids,    # fine-level cluster id\n                            render_feat_map=render_feat, \n                            render_cluster=render_cluster,\n                            selected_root_id=root_id)       # coarse id (stage 2.2)\n        # rendered results\n        image, viewspace_point_tensor, visibility_filter, 
radii = \\\n            render_pkg[\"render\"], render_pkg[\"viewspace_points\"], render_pkg[\"visibility_filter\"], render_pkg[\"radii\"]\n        alpha = render_pkg[\"alpha\"]\n        rendered_silhouette = render_pkg[\"silhouette\"] if render_pkg[\"silhouette\"] is not None else alpha\n        rendered_silhouette = (rendered_silhouette > 0.7) * 1.0 # mask after re-scale\n        rendered_ins_feat = render_pkg[\"ins_feat\"]\n        rendered_cluster_imgs = render_pkg[\"cluster_imgs\"]  # [num_cl, 6, H, W]\n        rendered_leaf_cluster_imgs = render_pkg[\"leaf_clusters_imgs\"]\n        rendered_cluster_silhouettes = render_pkg[\"cluster_silhouettes\"]\n        if render_cluster:\n            if rendered_cluster_silhouettes is not None and len(rendered_cluster_silhouettes) > 0:\n                rendered_cluster_silhouettes = rendered_cluster_silhouettes > 0.7\n            else:\n                # root_id-th coarse cluster not visible in current view\n                no_need_bk = True\n\n        # gt supervision: rgb image & SAM mask\n        gt_image = viewpoint_cam.original_image.cuda()\n        if viewpoint_cam.original_sam_mask is not None:\n            gt_sam_mask = viewpoint_cam.original_sam_mask.cuda()    # [4, H, W]\n        \n        # ##################################################\n        # [Stage 0]: 0 to 3w steps, Standard 3DGS RGB loss #\n        # ##################################################\n        if iteration <= opt.start_ins_feat_iter:\n            Ll1 = l1_loss(image, gt_image)\n            loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))\n\n        # Start learning instance features after 3W steps.\n        if iteration > opt.start_ins_feat_iter:\n            # NOTE: Freeze the pre-trained Gaussian parameters and only train the instance features.\n            scene.gaussians._xyz = scene.gaussians._xyz.detach()\n            scene.gaussians._features_dc = scene.gaussians._features_dc.detach()\n  
          scene.gaussians._features_rest = scene.gaussians._features_rest.detach()\n            scene.gaussians._opacity = scene.gaussians._opacity.detach()\n            scene.gaussians._scaling = scene.gaussians._scaling.detach()\n            scene.gaussians._rotation = scene.gaussians._rotation.detach()\n\n            # construct boolean masks [num_mask, H, W]\n            # sam_level, leaf:3, scannet:0\n            sam_level = opt.sam_level\n            mask_id, mask_bool, invalid_pix = get_SAM_mask_and_feat(gt_sam_mask, level=sam_level, filter_th=50)\n\n            # #################################################\n            # [Stage 1]: Continuous instance feature learning #\n            #           LERF 3W-4W steps; ScanNet 3w-5w steps #\n            #           see Sec.3.1 in the paper              #\n            # #################################################\n            if cb_mode is None:\n                # (0) compute the average instance features within each mask. [num_mask, 6]\n                feat_mean_stack = mask_feature_mean(rendered_ins_feat, mask_bool, image_mask=rendered_silhouette)\n                # (1) intra-mask smoothing loss. 
Eq.(1) in the paper\n                loss_cohesion = cohesion_loss(rendered_ins_feat, mask_bool, feat_mean_stack)\n                # (2) inter-mask contrastive loss Eq.(2) in the paper\n                loss_separation = separation_loss(feat_mean_stack, iteration)\n                # total loss, opt.loss_weight: 0.1\n                loss = loss_separation + opt.loss_weight * loss_cohesion\n        \n        # ####################################################\n        # [Stage 2]: Two-Level Codebook for Discretization \n        #   - coarse-level(root) loss computation\n        #   - fine-level(leaf) loss computation\n        # ####################################################\n        # 2.1 coarse-level\n        if cb_mode == \"root\":   \n            # Only consider valid pixels\n            keeped_pix = viewpoint_cam.pesudo_ins_feat.sum(dim=(0)) > 0     # Invalid pixels of pseudo-labels\n            keeped_pix = keeped_pix.bool()&rendered_silhouette.bool()       # Empty regions after rescaling\n            keeped_pix = keeped_pix&(~invalid_pix.unsqueeze(0))             # Invalid area of the original mask\n            keeped_pix = rendered_silhouette.bool()\n            # loss  Eq.(4) in the paper.\n            feat_loss = l1_loss(rendered_ins_feat, viewpoint_cam.pesudo_ins_feat, keeped_pix)  \n            # feat_loss = l2_loss(rendered_ins_feat, viewpoint_cam.pesudo_ins_feat, keeped_pix)\n            loss = feat_loss\n        # 2.2 fine-level\n        if cb_mode == \"leaf\" and no_need_bk == False:   \n            total_pix = gt_image.shape[1] * gt_image.shape[2]\n            for i in range(len(rendered_cluster_imgs)):\n                cluster_pred = rendered_cluster_imgs[i]\n                cluster_silhouette = rendered_cluster_silhouettes[i]    # [H, W] bool\n                rendered_ins_feat = cluster_pred                    # \n                # cluster_mask = viewpoint_cam.cluster_masks[i]     # [H, W] bool\n                # cluster_silhouette = 
cluster_silhouette & cluster_mask\n                feat_loss = l2_loss(cluster_pred, viewpoint_cam.pesudo_ins_feat, cluster_silhouette)\n                if i == 0:\n                    # loss = feat_loss * (cluster_silhouette.sum() / total_pix)\n                    loss = feat_loss\n                else:\n                    # loss += (feat_loss * (cluster_silhouette.sum() / total_pix))\n                    loss += feat_loss\n\n        # mask loss. modify -----\n        if viewpoint_cam.original_mask is not None:\n            gt_mask = viewpoint_cam.original_mask.cuda()\n            mask_loss = F.mse_loss(alpha, gt_mask)\n            loss = loss + mask_loss\n        \n        if no_need_bk == False:\n            loss.backward()\n\n        iter_end.record()\n\n        # Save the intermediate training results. [OpenGaussian]\n        save_intermediate = True\n        save_fre = 1000\n        if iteration > opt.start_leaf_cb_iter:\n            save_fre = 100\n        if (iteration % save_fre == 0) and save_intermediate:\n            gts_path = os.path.join(scene.model_path, \"train_process\", \"gt\")\n            makedirs(gts_path, exist_ok=True)\n            torchvision.utils.save_image(gt_image.detach().cpu(), os.path.join(gts_path, '{0:05d}'.format(iteration) + \".png\"))\n            \n            render_path = os.path.join(scene.model_path, \"train_process\", \"renders\")\n            makedirs(render_path, exist_ok=True)\n            torchvision.utils.save_image(image.detach().cpu(), os.path.join(render_path, '{0:05d}'.format(iteration) + \".png\"))\n\n            # alpha_path = os.path.join(scene.model_path, \"train_process\", \"alpha\")\n            # makedirs(alpha_path, exist_ok=True)\n            # torchvision.utils.save_image(alpha.detach().cpu(), os.path.join(alpha_path, '{0:05d}'.format(iteration) + \".png\"))\n            \n            if iteration > opt.start_ins_feat_iter:\n                if cb_mode is None:\n                    sub_floader = 
\"stage1\"\n                elif cb_mode == \"root\":\n                    sub_floader = \"stage2_1\"\n                elif cb_mode == \"leaf\":\n                    sub_floader = \"stage2_2\"\n                # Visualize the SAM mask. [OpenGaussian]\n                if gt_sam_mask is not None and iteration > opt.start_ins_feat_iter:\n                    # read predefined mask color\n                    mask_color_rand = colors_defined[mask_id.detach().cpu()].type(torch.float64)\n                    mask_color_rand = mask_color_rand.permute(2, 0, 1)\n                    gt_sam_path = os.path.join(scene.model_path, \"train_process\", sub_floader, \"gt_sam_mask_\" + str(opt.sam_level))\n                    makedirs(gt_sam_path, exist_ok=True)\n                    torchvision.utils.save_image(mask_color_rand/255.0, os.path.join(gt_sam_path, '{0:05d}'.format(iteration) + \".png\"))\n                \n                # TODO \n                if viewpoint_cam.pesudo_ins_feat is not None:\n                    feat = viewpoint_cam.pesudo_ins_feat\n                    pseudo_ins_feat_path = os.path.join(scene.model_path, \"train_process\", sub_floader, \"pseudo_ins_feat\")\n                    makedirs(pseudo_ins_feat_path, exist_ok=True)\n                    torchvision.utils.save_image(feat.detach().cpu()[:3, :, :], os.path.join(pseudo_ins_feat_path, '{0:05d}'.format(iteration) + \"_1.png\"))\n                    torchvision.utils.save_image(feat.detach().cpu()[3:6, :, :], os.path.join(pseudo_ins_feat_path, '{0:05d}'.format(iteration) + \"_2.png\"))\n\n                if cb_mode is not None:\n                    # silhouette (alpha to mask) [OpenGaussian] stage 2\n                    silhouette_path = os.path.join(scene.model_path, \"train_process\", sub_floader, \"silhouette\")\n                    makedirs(silhouette_path, exist_ok=True)\n                    torchvision.utils.save_image(rendered_silhouette.detach().cpu(), os.path.join(silhouette_path, 
'{0:05d}'.format(iteration) + \".png\"))\n\n                # Visualize the 6-dimensional instance feature. [OpenGaussian]\n                if rendered_ins_feat is not None:\n                    # dim 0:3\n                    ins_feat_path = os.path.join(scene.model_path, \"train_process\", sub_floader, \"ins_feat\")\n                    makedirs(ins_feat_path, exist_ok=True)\n                    torchvision.utils.save_image(rendered_ins_feat.detach().cpu()[:3, :, :], os.path.join(ins_feat_path, '{0:05d}'.format(iteration) + \".png\"))\n                    # dim 3:6\n                    ins_feat_path2 = os.path.join(scene.model_path, \"train_process\", sub_floader, \"ins_feat2\")\n                    makedirs(ins_feat_path2, exist_ok=True)\n                    torchvision.utils.save_image(rendered_ins_feat.detach().cpu()[3:6, :, :], os.path.join(ins_feat_path2, '{0:05d}'.format(iteration) + \".png\"))\n\n                # # fine-level cluster\n                # if rendered_leaf_cluster_imgs is not None:\n                #     leaf_cluster_path = os.path.join(scene.model_path, \"train_process\", sub_floader, \"cluster_leaf\")\n                #     makedirs(leaf_cluster_path, exist_ok=True)\n                #     for i, leaf_img in enumerate(rendered_leaf_cluster_imgs):\n                #         torchvision.utils.save_image(leaf_img.detach().cpu()[:3, :, :], os.path.join(leaf_cluster_path, '{0:05d}'.format(iteration) + \"leaf_{}.png\".format(i)))\n\n        with torch.no_grad():\n            # Progress bar\n            ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log\n            if iteration % 10 == 0:\n                progress_bar.set_postfix({\"Loss\": f\"{ema_loss_for_log:.{7}f}\"})\n                progress_bar.update(10)\n            if iteration == opt.iterations:\n                progress_bar.close()\n\n            # Log and save .ply\n            # training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), 
\\\n            #     testing_iterations, opt.start_root_cb_iter, scene, render, (pipe, background, iteration))\n            if (iteration in saving_iterations):\n                print(\"\\n[ITER {}] Saving Gaussians\".format(iteration))\n                sys.stdout.flush()\n                if iteration > opt.start_root_cb_iter:\n                    # note: save codebook [OpenGaussian]\n                    out_dir = os.path.join(scene.model_path, 'point_cloud/iteration_%d' % iteration)\n                    save_kmeans([ins_feat_codebook], [\"ins_feat\"], out_dir, mode=\"root\")\n                    if cb_mode == \"leaf\":\n                        save_kmeans([ins_feat_codebook], [\"ins_feat\"], out_dir, mode=\"leaf\")\n                    scene.save(iteration, [\"ins_feat\"])\n                else:\n                    scene.save(iteration)\n\n            # Densification\n            if iteration < opt.densify_until_iter and \\\n                not opt.frozen_init_pts: # note: ScanNet dataset is not densified [OpenGaussian]\n                # Keep track of max radii in image-space for pruning\n                gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])\n                gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)\n\n                if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:\n                    size_threshold = 20 if iteration > opt.opacity_reset_interval else None\n                    gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold)\n\n                if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter):\n                    gaussians.reset_opacity()\n\n            # Optimizer step\n            if iteration < opt.iterations:\n                gaussians.optimizer.step()\n                
gaussians.optimizer.zero_grad(set_to_none = True)\n                torch.cuda.empty_cache()\n\n            if (iteration in checkpoint_iterations):\n                print(\"\\n[ITER {}] Saving Checkpoint\".format(iteration))\n                sys.stdout.flush()\n                torch.save((gaussians.capture(), iteration), scene.model_path + \"/chkpnt\" + str(iteration) + \".pth\")\n            \n            # ###########################################################\n            # Stage 3. associate language feature (training-free stage) #\n            #   - Performed after training.                             #\n            # ###########################################################\n            if iteration == opt.iterations and iteration > opt.start_leaf_cb_iter:\n                print(\"[Stage 3] Start 2D language feature - 3D cluster association ...\")\n                sys.stdout.flush()\n                if leaf_cluster_indices is None:\n                    leaf_cluster_indices = ins_feat_codebook.leaf_cls_ids   # fine-level cluster id\n                construct_pseudo_ins_feat(scene, render, (pipe, background, first_iter),\n                                          cluster_indices=leaf_cluster_indices, mode=\"lang\",\n                                          root_num=opt.root_node_num, leaf_num=opt.leaf_node_num,\n                                          sam_level=opt.sam_level,\n                                          save_memory=opt.save_memory)\n        \n        # note: save memory (only stage 2, 3)\n        if viewpoint_cam.data_on_gpu and opt.save_memory and cb_mode is not None:\n            viewpoint_cam.to_cpu()\n\ndef prepare_output_and_logger(args):    \n    if not args.model_path:\n        if os.getenv('OAR_JOB_ID'):\n            unique_str=os.getenv('OAR_JOB_ID')\n        else:\n            unique_str = str(uuid.uuid4())\n        args.model_path = os.path.join(\"./output/\", unique_str[0:10])\n        \n    # Set up output folder\n    
print(\"Output folder: {}\".format(args.model_path))\n    os.makedirs(args.model_path, exist_ok = True)\n    with open(os.path.join(args.model_path, \"cfg_args\"), 'w') as cfg_log_f:\n        cfg_log_f.write(str(Namespace(**vars(args))))\n\n    # Create Tensorboard writer\n    tb_writer = None\n    if TENSORBOARD_FOUND:\n        tb_writer = SummaryWriter(args.model_path)\n    else:\n        print(\"Tensorboard not available: not logging progress\")\n    return tb_writer\n\ndef construct_pseudo_ins_feat(scene : Scene, renderFunc, renderArgs, \n                            filter=True,            # filter pseudo features\n                            cluster_indices=None,   # coarse-level ID of each point (0 ~ k1-1)\n                            mode=\"root\",            # root, leaf, lang\n                            root_num=64, leaf_num=10,   # k1, k2\n                            sam_level=3,\n                            save_memory=False):\n    torch.cuda.empty_cache()\n    # ##############################################################################################\n    # [Stage 2.1, 2.2] Render all training views once to construct pseudo-instance feature labels. 
#\n    #   - view.pesudo_ins_feat  [C=6, H, W]                                                        #\n    #   - view.pesudo_mask_bool [num_mask, H, W]                                                   #\n    # ##############################################################################################\n    sorted_train_cameras = sorted(scene.getTrainCameras(), key=lambda Camera: Camera.image_name)\n    for idx, view in enumerate(tqdm(sorted_train_cameras, desc=\"construct pseudo feat\")):\n        if not view.data_on_gpu:\n            view.to_gpu()\n\n        # render\n        render_pkg = renderFunc(view, scene.gaussians, *renderArgs, rescale=False, origin_feat=True)\n        rendered_ins_feat = render_pkg[\"ins_feat\"]\n        \n        # get gt sam mask\n        mask_id, mask_bool, invalid_pix = \\\n            get_SAM_mask_and_feat(view.original_sam_mask.cuda(), level=sam_level)\n\n        # construct pseudo ins_feat, mask level\n        pseudo_mask_ins_feat_, mask_var, pix_count = mask_feature_mean(rendered_ins_feat, mask_bool, return_var=True)   # [num_mask, 6]\n        pseudo_mask_ins_feat = torch.cat((torch.zeros((1, 6)).cuda(), pseudo_mask_ins_feat_), dim=0)# [num_mask+1, 6]\n        # Filter out masks with high variance. 
Potentially incorrect segmentation.\n        filter_mask = mask_var > 0.006   # True->del\n        filter_mask = torch.cat((torch.tensor([False]).cuda(), filter_mask), dim=0)  # [num_mask+1]\n        # Masks with large pixel ratio may be background points, inevitably leading to a large variance， Keep them.\n        ignored_mask_ind = torch.nonzero(pix_count > pix_count.max() * 0.8).squeeze()\n        filter_mask[ignored_mask_ind + 1] = False\n        filtered_mask_pseudo_ins_feat = pseudo_mask_ins_feat.clone()\n        filtered_mask_pseudo_ins_feat[filter_mask] *= 0\n\n        # pseudo ins_feat, image level\n        pseudo_ins_feat = pseudo_mask_ins_feat[mask_id]     # Retrieve corresponding ins_feat by mask ID\n        pseudo_ins_feat = pseudo_ins_feat.permute(2, 0, 1)  # [H, W, 6]->[6, H, W]\n\n        # filterd pseudo ins_feat, image level\n        filter_pseudo_ins_feat = filtered_mask_pseudo_ins_feat[mask_id]\n        filter_pseudo_ins_feat = filter_pseudo_ins_feat.permute(2, 0, 1)\n\n        # filtered mask [1+num_mask, H, W]\n        mask_bool_filtered = torch.cat((torch.zeros_like(mask_bool[0].unsqueeze(0)), mask_bool), dim=0)\n        mask_bool_filtered[filter_mask] *= 0\n\n        # NOTE: save the construct pesudo_ins_feat\n        # total_feat.append(pseudo_mask_ins_feat[1:,:])\n        # if view.pesudo_ins_feat is None:\n        view.pesudo_ins_feat = filter_pseudo_ins_feat if filter else pseudo_ins_feat\n        # view.pesudo_ins_feat = rendered_ins_feat\n        view.pesudo_mask_bool = mask_bool_filtered.to(torch.bool)\n\n        # Save some results for visualization.\n        pseudo_debug = True\n        if idx % 20 == 0 and pseudo_debug:\n            pseudo_ins_feat_path = os.path.join(scene.model_path, \"train_process\", \"debug_pseudo_label\", \"all_pseudo_ins_feat\")\n            filter_pseudo_ins_feat_path = os.path.join(scene.model_path, \"train_process\", \"debug_pseudo_label\", \"all_filter_pseudo_ins_feat\")\n            
rendered_ins_feat_path = os.path.join(scene.model_path, \"train_process\", \"debug_pseudo_label\", \"all_render_ins_feat\")\n            sam_mask_path = os.path.join(scene.model_path, \"train_process\", \"debug_pseudo_label\", \"all_sam_mask\")\n            makedirs(pseudo_ins_feat_path, exist_ok=True)\n            makedirs(filter_pseudo_ins_feat_path, exist_ok=True)\n            makedirs(rendered_ins_feat_path, exist_ok=True)\n            makedirs(sam_mask_path, exist_ok=True)\n\n            # pseudo ins_feat\n            torchvision.utils.save_image(pseudo_ins_feat[:3,:,:], os.path.join(pseudo_ins_feat_path, '{0:05d}'.format(idx) + \"_1.png\"))\n            # torchvision.utils.save_image(pseudo_ins_feat[3:6,:,:], os.path.join(pseudo_ins_feat_path, '{0:05d}'.format(idx) + \"_2.png\"))\n            # filtered pseudo ins_feat\n            torchvision.utils.save_image(filter_pseudo_ins_feat[:3,:,:], os.path.join(filter_pseudo_ins_feat_path, '{0:05d}'.format(idx) + \"_1.png\"))\n            # torchvision.utils.save_image(filter_pseudo_ins_feat[3:6,:,:], os.path.join(filter_pseudo_ins_feat_path, '{0:05d}'.format(idx) + \"_2.png\"))\n            # rendered ins_feat\n            torchvision.utils.save_image(rendered_ins_feat[:3,:,:], os.path.join(rendered_ins_feat_path, '{0:05d}'.format(idx) + \"_1.png\"))\n            # torchvision.utils.save_image(rendered_ins_feat[3:6,:,:], os.path.join(rendered_ins_feat_path, '{0:05d}'.format(idx) + \"_2.png\"))\n            # gt SAM mask, read predefined mask color\n            mask_color_rand = colors_defined[mask_id.detach().cpu()].type(torch.float64)\n            mask_color_rand = mask_color_rand.permute(2, 0, 1)\n            torchvision.utils.save_image(mask_color_rand/255.0, os.path.join(sam_mask_path, '{0:05d}'.format(idx) + \".png\"))\n        # to cpu\n        if view.data_on_gpu and save_memory:\n            view.to_cpu()\n    \n    # 
##################################################################################################\n    # Preprocessing for Stage 2.2\n    # determine how many objects are in each coarse cluster, not just setting a fixed k2 value.\n    # ##################################################################################################\n    torch.cuda.empty_cache()\n    if mode==\"leaf\":\n        iClusterSubNum = torch.ones(cluster_indices.max()+1).to(torch.int32)\n        for idx, view in enumerate(tqdm(sorted_train_cameras, desc=\"render coarse-level cluster\")):\n            if not view.data_on_gpu:\n                view.to_gpu()\n            render_pkg = renderFunc(view, scene.gaussians, *renderArgs, cluster_idx=cluster_indices, rescale=False,\\\n                                    render_feat_map=False, render_cluster=True, origin_feat=True, better_vis=True,\n                                    root_num=root_num, leaf_num=leaf_num)\n            rendered_cluster_imgs = render_pkg[\"cluster_imgs\"]  # coarse cluster feature map\n            rendered_cluster_silhouettes = render_pkg[\"cluster_silhouettes\"] # coarse cluster mask\n            cluster_occur = render_pkg[\"cluster_occur\"] # bool [k1] Whether coarse clusters visible in the current view\n\n            pser_cluster_pesudo_mask = []\n            i = -1\n            for cluster_idx in range(cluster_indices.max()+1):\n                if not cluster_occur[cluster_idx]:  # Process only coarse clusters visible in the current view\n                    continue\n\n                i += 1\n                rendered_ins_feat = rendered_cluster_imgs[i]    # cluster feat map\n                rendered_silhouette = (rendered_cluster_silhouettes[i] > 0.9).unsqueeze(0)  # cluster mask\n\n                # (1) compute the IoU of this cluster with pseudo masks.\n                ious = calculate_iou(view.pesudo_mask_bool, rendered_silhouette, base=\"former\")\n                # pseudo masks with IoU above threshold\n      
          inters_mask = view.pesudo_mask_bool[ious[0] > 0.2]  # [num_mask, H, W]\n                inters_mask_ = inters_mask.sum(0).to(torch.bool)   # [H, W] bool\n                # pseudo mask features, noly for visalization [6, H, W]\n                inters_pesudo_ins_feat = view.pesudo_ins_feat * inters_mask_.unsqueeze(0) \n\n                # (2) compute the distance between coarse cluster features and pseudo features\n                # mean feature of the pesudo mask, [num_mask, 6]\n                inters_mask_feat_mean = mask_feature_mean(view.pesudo_ins_feat, inters_mask) \n                # mean feature of the cluster, [num_mask, 6]\n                cluster_mask_feat_mean = mask_feature_mean(rendered_ins_feat, inters_mask, image_mask=rendered_silhouette) \n                # distance\n                l1_dis, l2_dis = calculate_distances(inters_mask_feat_mean, cluster_mask_feat_mean)   # metric=\"l1\"\n\n                # (3) filter out some pseudo masks\n                inters_mask_filter = inters_mask[(l1_dis < 0.9) & (l2_dis < 0.5)]  # l2_disk < 0.8\n                if inters_mask_filter.shape[0] > 10:    # TODO 10? 
--> leaf_num\n                    smallest_10 = torch.topk(l1_dis, 10, largest=False)[1]\n                    inters_mask_filter = inters_mask[smallest_10]\n                inters_mask_filter_ = inters_mask_filter.sum(0).to(torch.bool) \n                inters_pesudo_ins_feat2 = view.pesudo_ins_feat * inters_mask_filter_.unsqueeze(0) # only for visualization\n                if not inters_mask_filter_.any():  # Skip if the cluster doesn’t intersect with any pseudo masks.\n                    cluster_occur[cluster_idx] = False\n                    continue\n                \n                pser_cluster_pesudo_mask.append(inters_mask_filter_)    # valid mask\n                # NOTE: (4) Determine the number of masks (i.e., objects) in each coarse cluster.\n                iClusterSubNum[cluster_idx] = max(iClusterSubNum[cluster_idx], inters_mask_filter.shape[0])\n\n                # (5) save some intermediate results for debugging\n                coarse_debug = False\n                if coarse_debug:\n                    cluster_path = os.path.join(scene.model_path, \"train_process\", \"debug_coarse_cluster\", \"cluster\")\n                    cluster_silhouette_path = os.path.join(scene.model_path, \"train_process\", \"debug_coarse_cluster\", \"cluster_silhouette\")\n                    cluster_inters_pesudo_path = os.path.join(scene.model_path, \"train_process\", \"debug_coarse_cluster\", \"cluster_inters_pesudo\")\n                    makedirs(cluster_path, exist_ok=True)\n                    makedirs(cluster_silhouette_path, exist_ok=True)\n                    makedirs(cluster_inters_pesudo_path, exist_ok=True)\n\n                    # coarse-level cluster feature map\n                    torchvision.utils.save_image(rendered_ins_feat[:3,:,:].cpu(), os.path.join(cluster_path, '{0:05d}'.format(idx) + f\"_c_{cluster_idx}\" + \"_1.png\"))\n                    # torchvision.utils.save_image(rendered_ins_feat[3:,:,:].cpu(), os.path.join(cluster_path, 
'{0:05d}'.format(idx) + f\"_c_{cluster_idx}\" + \"_2.png\"))\n                    torchvision.utils.save_image(rendered_silhouette.to(torch.float32).cpu(), os.path.join(cluster_silhouette_path, '{0:05d}'.format(idx) + f\"_c_{cluster_idx}\" + \"_1.png\"))\n\n                    # pseudo masks of coarse cluster (_f represents the filtered.)\n                    torchvision.utils.save_image(inters_pesudo_ins_feat[:3,:,:].cpu(), os.path.join(cluster_inters_pesudo_path, '{0:05d}'.format(idx) + f\"_c_{cluster_idx}\" + \"_1.png\"))\n                    # torchvision.utils.save_image(inters_pesudo_ins_feat[3:,:,:].cpu(), os.path.join(cluster_inters_pesudo_path, '{0:05d}'.format(idx) + f\"_c_{cluster_idx}\" + \"_2.png\"))\n                    torchvision.utils.save_image(inters_pesudo_ins_feat2[:3,:,:].cpu(), os.path.join(cluster_inters_pesudo_path, '{0:05d}'.format(idx) + f\"_c_{cluster_idx}\" + \"_1_f.png\"))\n                    # torchvision.utils.save_image(inters_pesudo_ins_feat2[3:,:,:].cpu(), os.path.join(cluster_inters_pesudo_path, '{0:05d}'.format(idx) + f\"_c_{cluster_idx}\" + \"_2_f.png\"))\n\n            if view.cluster_masks is None:\n                view.cluster_masks = pser_cluster_pesudo_mask   # pseudo masks of coarse cluster\n                view.bClusterOccur = cluster_occur              # whether visible in the current view\n\n            if view.data_on_gpu and save_memory:\n                view.to_cpu()\n\n        # update\n        scene.gaussians.iClusterSubNum = (iClusterSubNum + 1).clamp(max=leaf_num)\n        torch.cuda.empty_cache()\n    \n    # ###########################################################################\n    # [Stage 3] 2D mask(and language feat) - 3D fine level cluster association  # \n    #   - Sec. 
3.3 in the paper                                                 #\n    # ###########################################################################\n    if mode == \"lang\":\n        # [leaf_num, view_num, (matched_mask_id, matched_score, b_matched)]\n        match_info = torch.zeros(root_num * leaf_num, len(sorted_train_cameras), 3).cuda()  # [k1*k2, num_imgs, 3]\n        # iterate over the coarse-level clusters\n        for root_id, _ in enumerate(tqdm(range(root_num), desc=\"mapping\")):\n            # iterate over all training views\n            for v_id, view in enumerate(sorted_train_cameras):\n                if not view.data_on_gpu:\n                    view.to_gpu()\n\n                # (0) render\n                render_pkg = renderFunc(view, scene.gaussians, *renderArgs, leaf_cluster_idx=cluster_indices, rescale=False,\\\n                                        render_feat_map=False, render_cluster=True, origin_feat=True, better_vis=False,\\\n                                        selected_root_id=root_id,\\\n                                        root_num=root_num, leaf_num=leaf_num)\n                rendered_leaf_cluster_imgs = render_pkg[\"leaf_clusters_imgs\"]   # all fine-level clusters of the root_id-th coarse-level.\n                rendered_leaf_cluster_silhouettes = render_pkg[\"leaf_cluster_silhouettes\"]\n                occured_leaf_id = render_pkg[\"occured_leaf_id\"]\n                if len(occured_leaf_id) > 0:\n                    occured_leaf_id = torch.tensor(occured_leaf_id).cuda()\n                    rendered_leaf_cluster_imgs = torch.stack(rendered_leaf_cluster_imgs, dim=0) # [N, C, H, W]\n                    rendered_leaf_cluster_silhouettes = rendered_leaf_cluster_silhouettes > 0.8 # [N, H, W]\n                else:\n                    if view.data_on_gpu and save_memory:\n                        view.to_cpu()\n                    continue    # root_id not visible in current view\n\n                # (1) iou  
[num_rendered_leaf, num_mask]\n                ious = calculate_iou(view.pesudo_mask_bool, rendered_leaf_cluster_silhouettes)\n\n                # (2) feature distance\n                # cluster mean feat, [num_leaf, dim]\n                pred_mask_feat_mean = pair_mask_feature_mean(rendered_leaf_cluster_imgs, rendered_leaf_cluster_silhouettes) \n                # pseudo mean feat, [num_pseudo_mask, dim]\n                pesudo_mask_feat_mean = mask_feature_mean(view.pesudo_ins_feat, view.pesudo_mask_bool)\n                # only for visualization, [num_pseudo_mask, dim, H, W]\n                pesudo_mask_feat = view.pesudo_ins_feat * view.pesudo_mask_bool.unsqueeze(1)\n                # distance\n                l1_dis, _ = calculate_pairwise_distances(pred_mask_feat_mean, pesudo_mask_feat_mean, metric=\"l1\")   # method=\"l1\"\n\n                # (3) iou-feature distance joint score\n                scores = ious * (1-l1_dis)      # Eq.(5) in the paper\n\n                # (4) save the association result\n                max_score, max_ind = torch.max(scores, dim=-1)  # [num_leaf]\n                b_matched = max_score > 0.2     # todo\n                max_score[~b_matched] *= 0\n                max_ind[~b_matched] *= 0\n                match_info[occured_leaf_id, v_id] = torch.stack((max_ind, max_score, b_matched), dim=1)\n\n                # (5) save matching results for visualization. 
(only save the paired mask)\n                association_debug = True\n                if association_debug:\n                    leaf_cluster_path = os.path.join(scene.model_path, \"train_process\", \"stage3\", \"leaf_cluster\")\n                    leaf_cluster_silhouette_path = os.path.join(scene.model_path, \"train_process\", \"stage3\", \"leaf_cluster_silhouettes\")\n                    leaf_pesudo_mask_path = os.path.join(scene.model_path, \"train_process\", \"stage3\", \"leaf_pesudo_mask\")\n                    makedirs(leaf_cluster_path, exist_ok=True)\n                    makedirs(leaf_cluster_silhouette_path, exist_ok=True)\n                    makedirs(leaf_pesudo_mask_path, exist_ok=True)\n                    if b_matched.sum() > 0:\n                        for i, img in enumerate(rendered_leaf_cluster_imgs):\n                            if not b_matched[i]:\n                                continue\n                            if max_score[i] < 0.8:  # note: 0.8 is just for visualization\n                                continue\n                            torchvision.utils.save_image(img[:3,:,:], os.path.join(leaf_cluster_path, \\\n                                                            f\"r{root_id}_l{i}_v{v_id}.png\"))\n                            torchvision.utils.save_image(rendered_leaf_cluster_silhouettes[i].to(torch.float32), \\\n                                                    os.path.join(leaf_cluster_silhouette_path, f\"r{root_id}_l{i}_v{v_id}.png\"))\n                            torchvision.utils.save_image(pesudo_mask_feat[max_ind[i]][:3,:,:], os.path.join(leaf_pesudo_mask_path, \\\n                                                                f\"r{root_id}_l{i}_v{v_id}.png\"))\n                    # print(\"end one root cluster of one view\")\n                if view.data_on_gpu and save_memory:\n                    view.to_cpu()\n        # print(\"end matching\")\n        torch.cuda.empty_cache()\n\n        # count the matches 
of each leaf (fine-level cluster) across all viewpoints.\n        leaf_per_view_matched_mask = match_info[:, :, 0].to(torch.int64) # [k1*k2, num_cam] matched mask id\n        match_info_sum = match_info.sum(dim=1)  # [k1*k2, (matched_mask_id, matched_score, b_matched)]\n        leaf_ave_score = match_info_sum[:, 1] / (match_info_sum[:, 2]+ 1e-6)    # [k1*k2] ave score\n        leaf_occu_count = match_info_sum[:, 2]          # [k1*k2] number of matches for each leaf\n        \n        # accumulated 2D features of each leaf\n        per_leaf_feat_sum = torch.zeros(root_num * leaf_num, 512).cuda()  # [k1*k2] \n        for v_id, view in enumerate(sorted_train_cameras):\n            if not view.data_on_gpu:\n                view.to_gpu()\n            if sam_level == 0:\n                strat_id = 0\n                end_id = view.original_sam_mask[sam_level].max().to(torch.int64) + 1\n            else:\n                strat_id = view.original_sam_mask[sam_level-1].max().to(torch.int64) + 1\n                end_id = view.original_sam_mask[sam_level].max().to(torch.int64) + 1\n            curr_view_lang_feat = view.original_mask_feat[strat_id:end_id, :]   # [num_mask, 512]\n            curr_view_lang_feat = torch.cat((torch.zeros_like(curr_view_lang_feat[0]).unsqueeze(0), \\\n                curr_view_lang_feat))   # note: [num_mask+1, 512] add a feature with all 0s, i.e., the feature with id=0.\n            # current feat [k1*k2, 512]\n            single_view_leaf_feat = curr_view_lang_feat[leaf_per_view_matched_mask[:, v_id]]\n            # accumulate\n            per_leaf_feat_sum += single_view_leaf_feat\n\n            if view.data_on_gpu and save_memory:\n                view.to_cpu()\n\n        # average language features [k1*k2, 512] \n        per_leaf_feat = per_leaf_feat_sum / (leaf_occu_count + 1e-4).unsqueeze(1)\n\n        # save per_leaf_feat[k1*k2, 512], leaf_ave_score[k1*k2], leaf_occu_count[k1*k2], cluster_indices[num_pts]\n        
np.savez(f'{scene.model_path}/cluster_lang.npz',leaf_feat=per_leaf_feat.cpu().numpy(), \\\n                                    leaf_score=leaf_ave_score.cpu().numpy(), \\\n                                    occu_count=leaf_occu_count.cpu().numpy(), \\\n                                    leaf_ind=cluster_indices.cpu().numpy())\n\ndef training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, \\\n    start_root_cb_iter, scene : Scene, renderFunc, renderArgs):\n    if tb_writer:\n        tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration)\n        tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration)\n        tb_writer.add_scalar('iter_time', elapsed, iteration)\n\n    # Report test and samples of training set\n    if iteration in testing_iterations:\n        torch.cuda.empty_cache()\n        validation_configs = ({'name': 'test', 'cameras' : scene.getTestCameras()}, \n                              {'name': 'train', 'cameras' : [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in range(5, 30, 5)]})\n\n        for config in validation_configs:\n            if config['cameras'] and len(config['cameras']) > 0:\n                l1_test = 0.0\n                psnr_test = 0.0\n                for idx, viewpoint in enumerate(config['cameras']):\n                    image = torch.clamp(renderFunc(viewpoint, scene.gaussians, *renderArgs)[\"render\"], 0.0, 1.0)\n                    gt_image = torch.clamp(viewpoint.original_image.to(\"cuda\"), 0.0, 1.0)\n                    if tb_writer and (idx < 5):\n                        tb_writer.add_images(config['name'] + \"_view_{}/render\".format(viewpoint.image_name), image[None], global_step=iteration)\n                        if iteration == testing_iterations[0]:\n                            tb_writer.add_images(config['name'] + \"_view_{}/ground_truth\".format(viewpoint.image_name), gt_image[None], global_step=iteration)\n                
    l1_test += l1_loss(image, gt_image).mean().double()\n                    psnr_test += psnr(image, gt_image).mean().double()\n                psnr_test /= len(config['cameras'])\n                l1_test /= len(config['cameras'])          \n                print(\"\\n[ITER {}] Evaluating {}: L1 {} PSNR {}\".format(iteration, config['name'], l1_test, psnr_test))\n                sys.stdout.flush()\n                if tb_writer:\n                    tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration)\n                    tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration)\n\n        if tb_writer:\n            tb_writer.add_histogram(\"scene/opacity_histogram\", scene.gaussians.get_opacity, iteration)\n            tb_writer.add_scalar('total_points', scene.gaussians.get_xyz.shape[0], iteration)\n        torch.cuda.empty_cache()\n\n# initialize new gaussian parameters. modify -----\ndef initialize_new_params(new_pt_cld, mean3_sq_dist):\n    num_pts = new_pt_cld.shape[0]\n    means3D = new_pt_cld[:, :3] # [num_gaussians, 3]\n    unnorm_rots = np.tile([1, 0, 0, 0], (num_pts, 1)) # [num_gaussians, 3]\n    logit_opacities = torch.zeros((num_pts, 1), dtype=torch.float, device=\"cuda\")\n    logit_ins_feat = torch.zeros((num_pts, 3), dtype=torch.float, device=\"cuda\")\n    # color [N, 3, 16]\n    max_sh_degree = 3\n    fused_color = RGB2SH(new_pt_cld[:, 3:6])\n    features = torch.zeros((fused_color.shape[0], 3, (max_sh_degree + 1) ** 2)).float().cuda() # [N, 3, 16]\n    features[:, :3, 0 ] = fused_color\n    features[:, 3:, 1:] = 0.0\n    params = {\n        'new_xyz': means3D,\n        'new_features_dc': features[:,:,0:1].transpose(1, 2).contiguous(),\n        'new_features_rest':features[:,:,1:].transpose(1, 2).contiguous(),\n        'new_opacities': logit_opacities,\n        # 'new_scaling': torch.tile(torch.log(torch.sqrt(mean3_sq_dist))[..., None], (1, 1)),\n        'new_scaling': 
torch.tile(torch.log(torch.sqrt(mean3_sq_dist))[..., None], (1, 3)),\n        'new_rotation': unnorm_rots,\n        'new_ins_feat': logit_ins_feat,\n    }\n\n    for k, v in params.items():\n        # Check if value is already a torch tensor\n        if not isinstance(v, torch.Tensor):\n            params[k] = torch.nn.Parameter(torch.tensor(v).cuda().float().contiguous().requires_grad_(True))\n        else:\n            params[k] = torch.nn.Parameter(v.cuda().float().contiguous().requires_grad_(True))\n\n    return params\n# modify -----\n\nif __name__ == \"__main__\":\n    # Set up command line argument parser\n    parser = ArgumentParser(description=\"Training script parameters\")\n    lp = ModelParams(parser)\n    op = OptimizationParams(parser)\n    pp = PipelineParams(parser)\n    parser.add_argument('--ip', type=str, default=\"127.0.0.1\")\n    parser.add_argument('--port', type=int, default=6009)\n    parser.add_argument('--debug_from', type=int, default=-1)\n    parser.add_argument('--detect_anomaly', action='store_true', default=False)\n    parser.add_argument(\"--test_iterations\", nargs=\"+\", type=int, default=[30_000])\n    parser.add_argument(\"--save_iterations\", nargs=\"+\", type=int, default=[30_000])\n    parser.add_argument(\"--quiet\", action=\"store_true\")\n    parser.add_argument(\"--checkpoint_iterations\", nargs=\"+\", type=int, default=[])\n    parser.add_argument(\"--start_checkpoint\", type=str, default = None)\n    args = parser.parse_args(sys.argv[1:])\n    args.save_iterations.append(args.iterations)\n    args.checkpoint_iterations.append(args.iterations)\n    \n    print(\"Optimizing \" + args.model_path)\n\n    # Initialize system state (RNG)\n    safe_state(args.quiet)\n\n    # Start GUI server, configure and run training\n    network_gui.init(args.ip, args.port)\n    torch.autograd.set_detect_anomaly(args.detect_anomaly)\n    training(lp.extract(args), op.extract(args), pp.extract(args), \\\n             args.test_iterations, 
args.save_iterations, args.checkpoint_iterations, \\\n             args.start_checkpoint, args.debug_from)\n\n    # All done\n    print(\"\\nTraining complete.\")\n"
  },
  {
    "path": "utils/camera_utils.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nfrom scene.cameras import Camera\nimport numpy as np\nfrom utils.general_utils import PILtoTorch\nfrom utils.graphics_utils import fov2focal\nimport torch\n\nWARNED = False\n\ndef loadCam(args, id, cam_info, resolution_scale):\n    orig_w, orig_h = cam_info.image.size\n\n    if args.resolution in [1, 2, 4, 8]:\n        resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution))\n    else:  # should be a type that converts to float\n        if args.resolution == -1:\n            if orig_w > 1600:\n                global WARNED\n                if not WARNED:\n                    print(\"[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\\n \"\n                        \"If this is not desired, please explicitly specify '--resolution/-r' as 1\")\n                    WARNED = True\n                global_down = orig_w / 1600\n            else:\n                global_down = 1\n        else:\n            global_down = orig_w / args.resolution\n\n        scale = float(global_down) * float(resolution_scale)\n        resolution = (int(orig_w / scale), int(orig_h / scale))\n\n    resized_image_rgb = PILtoTorch(cam_info.image, resolution)  # [C, H, W]\n    \n    # NOTE: load SAM mask. 
modify -----\n    if cam_info.sam_mask is not None:\n        # step = int(args.resolution/2)     \n        step = int(max(args.resolution, 1))\n        gt_sam_mask = cam_info.sam_mask[:, ::step, ::step]  # downsample for mask\n        gt_sam_mask = torch.from_numpy(gt_sam_mask)\n        # align resolution\n        if resized_image_rgb.shape[1] != gt_sam_mask.shape[1]:\n            resolution = (gt_sam_mask.shape[2], gt_sam_mask.shape[1])   # modify -----\n            resized_image_rgb = PILtoTorch(cam_info.image, resolution)  # [C, H, W]\n    else:\n        gt_sam_mask = None\n    if cam_info.mask_feat is not None:\n        mask_feat = torch.from_numpy(cam_info.mask_feat)\n    else:\n        mask_feat = None\n    # modify -----\n\n    gt_image = resized_image_rgb[:3, ...]\n    loaded_mask = None\n\n    # if resized_image_rgb.shape[1] == 4:\n    if resized_image_rgb.shape[0] == 4:\n        loaded_mask = resized_image_rgb[3:4, ...]\n\n    return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T, \n                  FoVx=cam_info.FovX, FoVy=cam_info.FovY, \n                  cx=cam_info.cx/args.resolution, cy=cam_info.cy/args.resolution,\n                  image=gt_image, depth=None, gt_alpha_mask=loaded_mask,\n                  gt_sam_mask=gt_sam_mask, gt_mask_feat=mask_feat,\n                  image_name=cam_info.image_name, uid=id, data_device=args.data_device)\n\ndef cameraList_from_camInfos(cam_infos, resolution_scale, args):\n    camera_list = []\n\n    for id, c in enumerate(cam_infos):\n        camera_list.append(loadCam(args, id, c, resolution_scale))\n\n    return camera_list\n\ndef camera_to_JSON(id, camera : Camera):\n    Rt = np.zeros((4, 4))\n    Rt[:3, :3] = camera.R.transpose()\n    Rt[:3, 3] = camera.T\n    Rt[3, 3] = 1.0\n\n    W2C = np.linalg.inv(Rt)\n    pos = W2C[:3, 3]\n    rot = W2C[:3, :3]\n    serializable_array_2d = [x.tolist() for x in rot]\n    camera_entry = {\n        'id' : id,\n        'img_name' : camera.image_name,\n        
'width' : camera.width,\n        'height' : camera.height,\n        'position': pos.tolist(),\n        'rotation': serializable_array_2d,\n        'fy' : fov2focal(camera.FovY, camera.height),\n        'fx' : fov2focal(camera.FovX, camera.width)\n    }\n    return camera_entry\n"
  },
  {
    "path": "utils/general_utils.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport torch\nimport sys\nfrom datetime import datetime\nimport numpy as np\nimport random\n\ndef inverse_sigmoid(x):\n    return torch.log(x/(1-x))\n\ndef PILtoTorch(pil_image, resolution):\n    resized_image_PIL = pil_image.resize(resolution)\n    resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0\n    if len(resized_image.shape) == 3:\n        return resized_image.permute(2, 0, 1)\n    else:\n        return resized_image.unsqueeze(dim=-1).permute(2, 0, 1)\n\ndef get_expon_lr_func(\n    lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000\n):\n    \"\"\"\n    Copied from Plenoxels\n\n    Continuous learning rate decay function. Adapted from JaxNeRF\n    The returned rate is lr_init when step=0 and lr_final when step=max_steps, and\n    is log-linearly interpolated elsewhere (equivalent to exponential decay).\n    If lr_delay_steps>0 then the learning rate will be scaled by some smooth\n    function of lr_delay_mult, such that the initial learning rate is\n    lr_init*lr_delay_mult at the beginning of optimization but will be eased back\n    to the normal learning rate when steps>lr_delay_steps.\n    :param conf: config subtree 'lr' or similar\n    :param max_steps: int, the number of steps during optimization.\n    :return HoF which takes step as input\n    \"\"\"\n\n    def helper(step):\n        if step < 0 or (lr_init == 0.0 and lr_final == 0.0):\n            # Disable this parameter\n            return 0.0\n        if lr_delay_steps > 0:\n            # A kind of reverse cosine decay.\n            delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(\n                0.5 * np.pi * np.clip(step / 
lr_delay_steps, 0, 1)\n            )\n        else:\n            delay_rate = 1.0\n        t = np.clip(step / max_steps, 0, 1)\n        log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)\n        return delay_rate * log_lerp\n\n    return helper\n\ndef strip_lowerdiag(L):\n    uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device=\"cuda\")\n\n    uncertainty[:, 0] = L[:, 0, 0]\n    uncertainty[:, 1] = L[:, 0, 1]\n    uncertainty[:, 2] = L[:, 0, 2]\n    uncertainty[:, 3] = L[:, 1, 1]\n    uncertainty[:, 4] = L[:, 1, 2]\n    uncertainty[:, 5] = L[:, 2, 2]\n    return uncertainty\n\ndef strip_symmetric(sym):\n    return strip_lowerdiag(sym)\n\ndef build_rotation(r):\n    norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3])\n\n    q = r / norm[:, None]\n\n    R = torch.zeros((q.size(0), 3, 3), device='cuda')\n\n    r = q[:, 0]\n    x = q[:, 1]\n    y = q[:, 2]\n    z = q[:, 3]\n\n    R[:, 0, 0] = 1 - 2 * (y*y + z*z)\n    R[:, 0, 1] = 2 * (x*y - r*z)\n    R[:, 0, 2] = 2 * (x*z + r*y)\n    R[:, 1, 0] = 2 * (x*y + r*z)\n    R[:, 1, 1] = 1 - 2 * (x*x + z*z)\n    R[:, 1, 2] = 2 * (y*z - r*x)\n    R[:, 2, 0] = 2 * (x*z - r*y)\n    R[:, 2, 1] = 2 * (y*z + r*x)\n    R[:, 2, 2] = 1 - 2 * (x*x + y*y)\n    return R\n\ndef build_scaling_rotation(s, r):\n    L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device=\"cuda\")\n    R = build_rotation(r)\n\n    L[:,0,0] = s[:,0]\n    L[:,1,1] = s[:,1]\n    L[:,2,2] = s[:,2]\n\n    L = R @ L\n    return L\n\ndef safe_state(silent):\n    old_f = sys.stdout\n    class F:\n        def __init__(self, silent):\n            self.silent = silent\n\n        def write(self, x):\n            if not self.silent:\n                if x.endswith(\"\\n\"):\n                    old_f.write(x.replace(\"\\n\", \" [{}]\\n\".format(str(datetime.now().strftime(\"%d/%m %H:%M:%S\")))))\n                else:\n                    old_f.write(x)\n\n        def flush(self):\n            
old_f.flush()\n\n    sys.stdout = F(silent)\n\n    random.seed(0)\n    np.random.seed(0)\n    torch.manual_seed(0)\n    torch.cuda.set_device(torch.device(\"cuda:0\"))\n"
  },
  {
    "path": "utils/graphics_utils.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport torch\nimport math\nimport numpy as np\nfrom typing import NamedTuple\n\nclass BasicPointCloud(NamedTuple):\n    points : np.array\n    colors : np.array\n    normals : np.array\n\ndef geom_transform_points(points, transf_matrix):\n    P, _ = points.shape\n    ones = torch.ones(P, 1, dtype=points.dtype, device=points.device)\n    points_hom = torch.cat([points, ones], dim=1)\n    points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0))\n\n    denom = points_out[..., 3:] + 0.0000001\n    return (points_out[..., :3] / denom).squeeze(dim=0)\n\ndef getWorld2View(R, t):\n    Rt = np.zeros((4, 4))\n    Rt[:3, :3] = R.transpose()\n    Rt[:3, 3] = t\n    Rt[3, 3] = 1.0\n    return np.float32(Rt)\n\ndef getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0):\n    Rt = np.zeros((4, 4))\n    Rt[:3, :3] = R.transpose()\n    Rt[:3, 3] = t\n    Rt[3, 3] = 1.0\n\n    C2W = np.linalg.inv(Rt)\n    cam_center = C2W[:3, 3]\n    cam_center = (cam_center + translate) * scale\n    C2W[:3, 3] = cam_center\n    Rt = np.linalg.inv(C2W)\n    return np.float32(Rt)\n\ndef getProjectionMatrix(znear, zfar, fovX, fovY):\n    tanHalfFovY = math.tan((fovY / 2))\n    tanHalfFovX = math.tan((fovX / 2))\n\n    top = tanHalfFovY * znear\n    bottom = -top\n    right = tanHalfFovX * znear\n    left = -right\n\n    P = torch.zeros(4, 4)\n\n    z_sign = 1.0\n\n    P[0, 0] = 2.0 * znear / (right - left)\n    P[1, 1] = 2.0 * znear / (top - bottom)\n    P[0, 2] = (right + left) / (right - left)\n    P[1, 2] = (top + bottom) / (top - bottom)\n    P[3, 2] = z_sign\n    P[2, 2] = z_sign * zfar / (zfar - znear)\n    P[2, 3] = -(zfar * znear) / (zfar - znear)\n    return 
P\n\ndef fov2focal(fov, pixels):\n    return pixels / (2 * math.tan(fov / 2))\n\ndef focal2fov(focal, pixels):\n    return 2*math.atan(pixels/(2*focal))"
  },
  {
    "path": "utils/image_utils.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport torch\n\ndef mse(img1, img2):\n    return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)\n\ndef psnr(img1, img2):\n    mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)\n    return 20 * torch.log10(1.0 / torch.sqrt(mse))\n"
  },
  {
    "path": "utils/loss_utils.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nimport torch\nimport torch.nn.functional as F\nfrom torch.autograd import Variable\nfrom math import exp\n\ndef l1_loss(network_output, gt, mask=None, weight=None):    \n    if mask == None:\n        return torch.abs((network_output - gt)).mean()\n    else:\n        if weight is None:\n            weight = torch.ones_like(mask)\n        return torch.abs((network_output - gt) * mask * weight).sum() / mask.sum().clamp(min=1)\n\ndef l2_loss(network_output, gt, mask=None, weight=None):\n    if mask == None:\n        return ((network_output - gt) ** 2).mean()\n    else:\n        if weight is None:\n            weight = torch.ones_like(mask)\n        return ((network_output - gt) ** 2 * mask * weight).sum() / mask.sum().clamp(min=1)\n\ndef gaussian(window_size, sigma):\n    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])\n    return gauss / gauss.sum()\n\ndef create_window(window_size, channel):\n    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)\n    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)\n    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())\n    return window\n\ndef ssim(img1, img2, window_size=11, size_average=True):\n    channel = img1.size(-3)\n    window = create_window(window_size, channel)\n\n    if img1.is_cuda:\n        window = window.cuda(img1.get_device())\n    window = window.type_as(img1)\n\n    return _ssim(img1, img2, window, window_size, channel, size_average)\n\ndef _ssim(img1, img2, window, window_size, channel, size_average=True):\n    mu1 = F.conv2d(img1, window, padding=window_size // 2, 
groups=channel)\n    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)\n\n    mu1_sq = mu1.pow(2)\n    mu2_sq = mu2.pow(2)\n    mu1_mu2 = mu1 * mu2\n\n    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq\n    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq\n    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2\n\n    C1 = 0.01 ** 2\n    C2 = 0.03 ** 2\n\n    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))\n\n    if size_average:\n        return ssim_map.mean()\n    else:\n        return ssim_map.mean(1).mean(1).mean(1)\n\n"
  },
  {
    "path": "utils/opengs_utlis.py",
    "content": "import torch\nimport numpy as np\nimport torch.nn.functional as F\nimport os\nfrom bitarray import bitarray\nfrom collections import OrderedDict\n\ndef calculate_pairwise_distances(tensor1, tensor2, metric=None):\n    \"\"\"\n    Calculate L1 (Manhattan) and L2 (Euclidean) distances between every pair of vectors\n    in two tensors of shape [m, 6] and [n, 6].\n    Args:\n        tensor1 (torch.Tensor): A tensor of shape [m, 6].\n        tensor2 (torch.Tensor): Another tensor of shape [n, 6].\n    Returns:\n        torch.Tensor: L1 distances of shape [m, n].\n        torch.Tensor: L2 distances of shape [m, n].\n    \"\"\"\n    # Reshape tensors to allow broadcasting\n    # tensor1 shape becomes [m, 1, 6] and tensor2 shape becomes [1, n, 6]\n    tensor1 = tensor1.unsqueeze(1)  # Now tensor1 is [m, 1, 6]\n    tensor2 = tensor2.unsqueeze(0)  # Now tensor2 is [1, n, 6]\n\n    # Compute the L1 distance\n    if metric == \"l1\":\n        return torch.abs(tensor1 - tensor2).sum(dim=2), None  # Result is [m, n]\n\n    # Compute the L2 distance\n    if metric == \"l2\":\n        return None, torch.sqrt((tensor1 - tensor2).pow(2).sum(dim=2))  # Result is [m, n]\n\n    l1_distances = torch.abs(tensor1 - tensor2).sum(dim=2)\n    l2_distances = torch.sqrt((tensor1 - tensor2).pow(2).sum(dim=2))\n    return l1_distances, l2_distances\n\ndef calculate_distances(tensor1, tensor2, metric=None):\n    \"\"\"\n    Calculate L1 (Manhattan) and L2 (Euclidean) distances between corresponding vectors\n    in two tensors of shape [N, dim].\n    Args:\n        tensor1 (torch.Tensor): A tensor of shape [N, dim].\n        tensor2 (torch.Tensor): Another tensor of shape [N, dim].\n    Returns:\n        torch.Tensor: L1 distances of shape [N].\n        torch.Tensor: L2 distances of shape [N].\n    \"\"\"\n    # Compute L1 distance\n    if metric == \"l1\":\n        return torch.abs(tensor1 - tensor2).sum(dim=1)\n    \n    # Compute L2 distance\n    if metric == \"l2\":\n        
return torch.sqrt((tensor1 - tensor2).pow(2).sum(dim=1))\n    \n    l1_distances = torch.abs(tensor1 - tensor2).sum(dim=1)\n    l2_distances = torch.sqrt((tensor1 - tensor2).pow(2).sum(dim=1))\n\n    return l1_distances, l2_distances\n    \n\ndef bin2dec(b, bits):\n    \"\"\"Convert binary b to decimal integer.\n    Code from: https://stackoverflow.com/questions/55918468/convert-integer-to-pytorch-tensor-of-binary-bits\n    \"\"\"\n    mask = 2 ** torch.arange(bits - 1, -1, -1).to(b.device, torch.int64)\n    return torch.sum(mask * b, -1)\n\ndef load_code_book(base_path):\n    inds_file = os.path.join(base_path, 'kmeans_inds.bin')\n    codebook_file = os.path.join(base_path, 'kmeans_centers.pth')\n    args_file = os.path.join(base_path, 'kmeans_args.npy')\n    codebook = torch.load(codebook_file)    # [num_cluster, dim]\n    args_dict = np.load(args_file, allow_pickle=True).item()\n    quant_params = args_dict['params']\n    loaded_bitarray = bitarray()\n    with open(inds_file, 'rb') as file:\n        loaded_bitarray.fromfile(file)\n    # bitarray pads 0s if array is not divisible by 8. 
ignore extra 0s at end when loading\n    total_len = args_dict['total_len']\n    loaded_bitarray = loaded_bitarray[:total_len].tolist()\n    indices = np.reshape(loaded_bitarray, (-1, args_dict['n_bits']))\n    indices = bin2dec(torch.from_numpy(indices), args_dict['n_bits'])\n    indices = np.reshape(indices.cpu().numpy(), (len(quant_params), -1))\n    indices_dict = OrderedDict()\n    for i, key in enumerate(args_dict['params']):\n        indices_dict[key] = indices[i]\n    \n    return codebook, indices_dict['ins_feat']\n\ndef calculate_iou(masks1, masks2, base=None):\n    \"\"\"\n    Calculate the Intersection over Union (IoU) between two sets of masks.\n    Args:\n        masks1: PyTorch tensor of shape [n, H, W], torch.int32.\n        masks2: PyTorch tensor of shape [m, H, W], torch.int32.\n    Returns:\n        iou_matrix: PyTorch tensor of shape [m, n], containing IoU values.\n    \"\"\"\n    # Ensure the masks are of type torch.bool\n    if masks1.dtype != torch.bool:\n        masks1 = masks1.to(torch.bool)\n    if masks2.dtype != torch.bool:\n        masks2 = masks2.to(torch.bool)\n    \n    # Expand masks to broadcastable shapes\n    masks1_expanded = masks1.unsqueeze(0)  # [1, n, H, W]\n    masks2_expanded = masks2.unsqueeze(1)  # [m, 1, H, W]\n    \n    # Compute intersection\n    intersection = (masks1_expanded & masks2_expanded).float().sum(dim=(2, 3))  # [m, n]\n    \n    # Compute union\n    if base == \"former\":\n        union = (masks1_expanded).float().sum(dim=(2, 3)) + 1e-6  # [1, n], broadcasts to [m, n]\n    elif base == \"later\":\n        union = (masks2_expanded).float().sum(dim=(2, 3)) + 1e-6  # [m, 1], broadcasts to [m, n]\n    else:\n        union = (masks1_expanded | masks2_expanded).float().sum(dim=(2, 3)) + 1e-6  # [m, n]\n    \n    # Compute IoU\n    iou_matrix = intersection / union\n    \n    return iou_matrix\n\ndef get_SAM_mask_and_feat(gt_sam_mask, level=3, filter_th=50, original_mask_feat=None, sample_mask=False):\n    \"\"\"\n    input: \n        gt_sam_mask[4, H, 
W]: mask id\n    output:\n        mask_id[H, W]: The ID of the mask each pixel belongs to (0 indicates invalid pixels)\n        mask_bool[num_mask+1, H, W]: Boolean, note that the return value excludes the 0th mask (invalid points)\n        invalid_pix[H, W]: Boolean, invalid pixels\n    \"\"\"\n    # (1) mask id: -1, 1, 2, 3,...\n    mask_id = gt_sam_mask[level].clone()\n    if level > 0:\n        # subtract the maximum mask ID of the previous level\n        mask_id = mask_id - (gt_sam_mask[level-1].max().detach().cpu()+1)\n    if mask_id.min() < 0:\n        mask_id = mask_id.clamp_min(-1)    # -1, 0~num_mask\n    mask_id += 1    # 0, 1~num_mask+1\n    invalid_pix = mask_id==0    # invalid pixels\n\n    # (2) mask id[H, W] -> one-hot/mask_bool [num_mask+1, H, W]\n    instance_num = mask_id.max()\n    one_hot = F.one_hot(mask_id.type(torch.int64), num_classes=int(instance_num.item() + 1))\n    # bool mask [num+1, H, W]\n    mask_bool = one_hot.permute(2, 0, 1)\n    \n    # # TODO modify -------- only keep the largest 50\n    # if instance_num > 50:\n    #     top50_values, _ = torch.topk(mask_bool.sum(dim=(1,2)), 50, largest=True)\n    #     filter_th = top50_values[-1].item()\n    # # modify --------\n\n    # # TODO: not used\n    # # (3) delete small mask \n    # saved_idx = mask_bool.sum(dim=(1,2)) >= filter_th  # default 50 pixels\n    # # Random sampling, not actually used\n    # if sample_mask:\n    #     prob = torch.rand(saved_idx.shape[0])\n    #     sample_ind = prob > 0.5\n    #     saved_idx = saved_idx & sample_ind.cuda()\n    # saved_idx[0] = True  # Keep the mask for invalid points, ensuring that mask_id == 0 corresponds to invalid pixels.\n    # mask_bool = mask_bool[saved_idx]    # [num_filt, H, W]\n\n    # update mask id\n    mask_id = torch.argmax(mask_bool, dim=0)  # [H, W] The ID of the pixels after filtering is 0\n    invalid_pix = mask_id==0\n\n    # TODO not used!\n    # (4) Get the language features corresponding to the masks (used for 
2D-3D association in the third stage)\n    if original_mask_feat is not None:\n        mask_feat = original_mask_feat.clone()       # [num_mask, 512]\n        max_ind = int(gt_sam_mask[level].max())+1\n        min_ind = int(gt_sam_mask[level-1].max())+1 if level > 0 else 0\n        mask_feat = mask_feat[min_ind:max_ind, :]\n        # # update mask feat\n        # mask_feat = mask_feat[saved_idx[1:]]    # The 0th element of saved_idx is the mask corresponding to invalid pixels and has no features\n\n        return mask_id, mask_bool[1:, :, :], mask_feat, invalid_pix\n    return mask_id, mask_bool[1:, :, :], invalid_pix\n\ndef pair_mask_feature_mean(feat_map, masks):\n    \"\"\" mean feat of N masks\n    feat_map: [N, C, H, W]\n    masks: [N, H, W]\n    mean_values: [N, C]\n    \"\"\"\n    N, C, H, W = feat_map.shape\n\n    # [N, H, W] -> [N, C, H, W]\n    expanded_masks = masks.unsqueeze(1).expand(-1, C, -1, -1)\n    # [N, C, H, W]\n    masked_features = feat_map * expanded_masks.float()\n    # pixels\n    mask_counts = expanded_masks.sum(dim=[2, 3]) + 1e-6\n    # mean feat [N, C]\n    mean_values = masked_features.sum(dim=[2, 3]) / mask_counts\n\n    return mean_values\n\ndef process_in_chunks(masks_expanded, masked_feats, mean_per_channel, chunk_size=5):\n    result = torch.zeros_like(masked_feats)\n    for i in range(0, masks_expanded.size(0), chunk_size):\n        end_i = min(i + chunk_size, masks_expanded.size(0))\n        for j in range(0, masks_expanded.size(1), chunk_size):\n            end_j = min(j + chunk_size, masks_expanded.size(1))\n            chunk_mask = masks_expanded[i:end_i, j:end_j]\n            chunk_feats = masked_feats[i:end_i, j:end_j]\n            chunk_mean = mean_per_channel[i:end_i, j:end_j].unsqueeze(-1).unsqueeze(-1)\n\n            result[i:end_i, j:end_j] = torch.where(chunk_mask.bool(), chunk_feats - chunk_mean, torch.zeros_like(chunk_feats))\n    return result\n\ndef calculate_variance_in_chunks(masked_for_variance, mask_counts, 
chunk_size=5):\n    variance_per_channel = torch.zeros(masked_for_variance.size(0), masked_for_variance.size(1), device=masked_for_variance.device)\n    for i in range(0, masked_for_variance.size(0), chunk_size):\n        end_i = min(i + chunk_size, masked_for_variance.size(0))\n        for j in range(0, masked_for_variance.size(1), chunk_size):\n            end_j = min(j + chunk_size, masked_for_variance.size(1))\n            chunk_masked_for_variance = masked_for_variance[i:end_i, j:end_j]\n\n            chunk_variance = (chunk_masked_for_variance ** 2).sum(dim=[2, 3]) / mask_counts[i:end_i, j:end_j]\n            variance_per_channel[i:end_i, j:end_j] = chunk_variance\n    return variance_per_channel\n\ndef ele_multip_in_chunks(feat_expanded, masks_expanded, chunk_size=5):\n    result = torch.zeros_like(feat_expanded)\n    for i in range(0, feat_expanded.size(0), chunk_size):\n        end_i = min(i + chunk_size, feat_expanded.size(0))\n        for j in range(0, feat_expanded.size(1), chunk_size):\n            end_j = min(j + chunk_size, feat_expanded.size(1))\n            chunk_feat = feat_expanded[i:end_i, j:end_j]\n            chunk_mask = masks_expanded[i:end_i, j:end_j].float()\n\n            result[i:end_i, j:end_j] = chunk_feat * chunk_mask\n    return result\n\ndef mask_feature_mean(feat_map, gt_masks, image_mask=None, return_var=False):\n    \"\"\"Compute the average instance features within each mask.\n    feat_map: [C=6, H, W]         the instance features of the entire image\n    gt_masks: [num_mask, H, W]  num_mask boolean masks\n    \"\"\"\n    num_mask, H, W = gt_masks.shape\n\n    # expand feat and masks for batch processing\n    feat_expanded = feat_map.unsqueeze(0).expand(num_mask, *feat_map.shape)  # [num_mask, C, H, W]\n    masks_expanded = gt_masks.unsqueeze(1).expand(-1, feat_map.shape[0], -1, -1)  # [num_mask, C, H, W]\n    if image_mask is not None:  # image level mask\n        image_mask_expanded = image_mask.unsqueeze(0).expand(num_mask, 
feat_map.shape[0], -1, -1)\n\n    # average features within each mask\n    if image_mask is not None:\n        masked_feats = feat_expanded * masks_expanded.float() * image_mask_expanded.float()\n        mask_counts = (masks_expanded * image_mask_expanded.float()).sum(dim=(2, 3))\n    else:\n        # masked_feats = feat_expanded * masks_expanded.float()  # [num_mask, C, H, W] may cause OOM\n        masked_feats = ele_multip_in_chunks(feat_expanded, masks_expanded, chunk_size=5)   # in chunks to avoid OOM\n        mask_counts = masks_expanded.sum(dim=(2, 3))  # [num_mask, C]\n\n    # the number of pixels within each mask\n    mask_counts = mask_counts.clamp(min=1)\n\n    # the mean features of each mask\n    sum_per_channel = masked_feats.sum(dim=[2, 3])\n    mean_per_channel = sum_per_channel / mask_counts    # [num_mask, C]\n\n    if not return_var:\n        return mean_per_channel   # [num_mask, C]\n    else:\n        # calculate variance\n        # masked_for_variance = torch.where(masks_expanded.bool(), masked_feats - mean_per_channel.unsqueeze(-1).unsqueeze(-1), torch.zeros_like(masked_feats))\n        masked_for_variance = process_in_chunks(masks_expanded, masked_feats, mean_per_channel, chunk_size=5) # in chunks to avoid OOM\n\n        # variance_per_channel = (masked_for_variance ** 2).sum(dim=[2, 3]) / mask_counts    # [num_mask, 6]\n        variance_per_channel = calculate_variance_in_chunks(masked_for_variance, mask_counts, chunk_size=5)   # in chunks to avoid OOM\n\n        # mean and variance\n        mean = mean_per_channel.mean(dim=1)          # [num_mask], not used\n        variance = variance_per_channel.mean(dim=1)  # [num_mask]\n\n        return mean_per_channel, variance, mask_counts[:, 0]   # [num_mask, C], [num_mask], [num_mask]\n\ndef linear_to_srgb(linear):\n    if isinstance(linear, torch.Tensor):\n        \"\"\"Assumes `linear` is in [0, 1], see https://en.wikipedia.org/wiki/SRGB.\"\"\"\n        eps = torch.finfo(torch.float32).eps\n        
srgb0 = 323 / 25 * linear\n        srgb1 = (211 * torch.clamp(linear, min=eps)**(5 / 12) - 11) / 200\n        return torch.where(linear <= 0.0031308, srgb0, srgb1)\n    elif isinstance(linear, np.ndarray):\n        eps = np.finfo(np.float32).eps\n        srgb0 = 323 / 25 * linear\n        srgb1 = (211 * np.maximum(eps, linear) ** (5 / 12) - 11) / 200\n        return np.where(linear <= 0.0031308, srgb0, srgb1)\n    else:\n        raise NotImplementedError\n\ndef srgb_to_linear(srgb):\n    if isinstance(srgb, torch.Tensor):\n        \"\"\"Assumes `srgb` is in [0, 1], see https://en.wikipedia.org/wiki/SRGB.\"\"\"\n        eps = torch.finfo(torch.float32).eps\n        linear0 = 25 / 323 * srgb\n        linear1 = torch.clamp(((200 * srgb + 11) / (211)), min=eps)**(12 / 5)\n        return torch.where(srgb <= 0.04045, linear0, linear1)\n    elif isinstance(srgb, np.ndarray):\n        \"\"\"Assumes `srgb` is in [0, 1], see https://en.wikipedia.org/wiki/SRGB.\"\"\"\n        eps = np.finfo(np.float32).eps\n        linear0 = 25 / 323 * srgb\n        linear1 = np.maximum(((200 * srgb + 11) / (211)), eps)**(12 / 5)\n        return np.where(srgb <= 0.04045, linear0, linear1)\n    else:\n        raise NotImplementedError"
  },
  {
    "path": "utils/sh_utils.py",
    "content": "#  Copyright 2021 The PlenOctree Authors.\n#  Redistribution and use in source and binary forms, with or without\n#  modification, are permitted provided that the following conditions are met:\n#\n#  1. Redistributions of source code must retain the above copyright notice,\n#  this list of conditions and the following disclaimer.\n#\n#  2. Redistributions in binary form must reproduce the above copyright notice,\n#  this list of conditions and the following disclaimer in the documentation\n#  and/or other materials provided with the distribution.\n#\n#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE\n#  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE\n#  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR\n#  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF\n#  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS\n#  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN\n#  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)\n#  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE\n#  POSSIBILITY OF SUCH DAMAGE.\n\nimport torch\n\nC0 = 0.28209479177387814\nC1 = 0.4886025119029199\nC2 = [\n    1.0925484305920792,\n    -1.0925484305920792,\n    0.31539156525252005,\n    -1.0925484305920792,\n    0.5462742152960396\n]\nC3 = [\n    -0.5900435899266435,\n    2.890611442640554,\n    -0.4570457994644658,\n    0.3731763325901154,\n    -0.4570457994644658,\n    1.445305721320277,\n    -0.5900435899266435\n]\nC4 = [\n    2.5033429417967046,\n    -1.7701307697799304,\n    0.9461746957575601,\n    -0.6690465435572892,\n    0.10578554691520431,\n    -0.6690465435572892,\n    0.47308734787878004,\n    -1.7701307697799304,\n    
0.6258357354491761,\n]   \n\n\ndef eval_sh(deg, sh, dirs):\n    \"\"\"\n    Evaluate spherical harmonics at unit directions\n    using hardcoded SH polynomials.\n    Works with torch/np/jnp.\n    ... Can be 0 or more batch dimensions.\n    Args:\n        deg: int SH deg. Currently, 0-4 supported\n        sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]\n        dirs: jnp.ndarray unit directions [..., 3]\n    Returns:\n        [..., C]\n    \"\"\"\n    assert deg <= 4 and deg >= 0\n    coeff = (deg + 1) ** 2\n    assert sh.shape[-1] >= coeff\n\n    result = C0 * sh[..., 0]\n    if deg > 0:\n        x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]\n        result = (result -\n                C1 * y * sh[..., 1] +\n                C1 * z * sh[..., 2] -\n                C1 * x * sh[..., 3])\n\n        if deg > 1:\n            xx, yy, zz = x * x, y * y, z * z\n            xy, yz, xz = x * y, y * z, x * z\n            result = (result +\n                    C2[0] * xy * sh[..., 4] +\n                    C2[1] * yz * sh[..., 5] +\n                    C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +\n                    C2[3] * xz * sh[..., 7] +\n                    C2[4] * (xx - yy) * sh[..., 8])\n\n            if deg > 2:\n                result = (result +\n                C3[0] * y * (3 * xx - yy) * sh[..., 9] +\n                C3[1] * xy * z * sh[..., 10] +\n                C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] +\n                C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +\n                C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +\n                C3[5] * z * (xx - yy) * sh[..., 14] +\n                C3[6] * x * (xx - 3 * yy) * sh[..., 15])\n\n                if deg > 3:\n                    result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +\n                            C4[1] * yz * (3 * xx - yy) * sh[..., 17] +\n                            C4[2] * xy * (7 * zz - 1) * sh[..., 18] +\n                            C4[3] * yz * (7 * zz - 
3) * sh[..., 19] +\n                            C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +\n                            C4[5] * xz * (7 * zz - 3) * sh[..., 21] +\n                            C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +\n                            C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +\n                            C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])\n    return result\n\ndef RGB2SH(rgb):\n    return (rgb - 0.5) / C0\n\ndef SH2RGB(sh):\n    return sh * C0 + 0.5"
  },
  {
    "path": "utils/system_utils.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nfrom errno import EEXIST\nfrom os import makedirs, path\nimport os\n\ndef mkdir_p(folder_path):\n    # Creates a directory. equivalent to using mkdir -p on the command line\n    try:\n        makedirs(folder_path)\n    except OSError as exc: # Python >2.5\n        if exc.errno == EEXIST and path.isdir(folder_path):\n            pass\n        else:\n            raise\n\ndef searchForMaxIteration(folder):\n    saved_iters = [int(fname.split(\"_\")[-1]) for fname in os.listdir(folder)]\n    return max(saved_iters)\n"
  }
]