Repository: yanmin-wu/OpenGaussian
Branch: main
Commit: 1f99db1b7e7d
Files: 38
Total size: 269.9 KB

Directory structure:
gitextract_gh4x629c/

├── .gitignore
├── LICENSE.md
├── README.md
├── arguments/
│   └── __init__.py
├── convert.py
├── environment.yml
├── full_eval.py
├── gaussian_renderer/
│   ├── __init__.py
│   └── network_gui.py
├── lpipsPyTorch/
│   ├── __init__.py
│   └── modules/
│       ├── lpips.py
│       ├── networks.py
│       └── utils.py
├── metrics.py
├── render.py
├── render_lerf_by_text.py
├── scene/
│   ├── __init__.py
│   ├── cameras.py
│   ├── colmap_loader.py
│   ├── dataset_readers.py
│   ├── gaussian_model.py
│   └── kmeans_quantize.py
├── scripts/
│   ├── compute_lerf_iou.py
│   ├── eval_scannet.py
│   ├── render_by_click.py
│   ├── scannet2blender.py
│   ├── train_lerf.sh
│   ├── train_scannet.sh
│   └── vis_opengs_pts_feat.py
├── train.py
└── utils/
    ├── camera_utils.py
    ├── general_utils.py
    ├── graphics_utils.py
    ├── image_utils.py
    ├── loss_utils.py
    ├── opengs_utlis.py
    ├── sh_utils.py
    └── system_utils.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
*.pyc
# .vscode
.git---
output
build
diff_rasterization/diff_rast.egg-info
diff_rasterization/dist
tensorboard_3d
screenshots
*.ipynb_checkpoints
# submodules/
# assets/
*.npz
*.bundle
output*
*.log
log

================================================
FILE: LICENSE.md
================================================
Gaussian-Splatting License  
===========================  

**Inria** and **the Max Planck Institut for Informatik (MPII)** hold all the ownership rights on the *Software* named **gaussian-splatting**.  
The *Software* is in the process of being registered with the Agence pour la Protection des  
Programmes (APP).  

The *Software* is still being developed by the *Licensor*.  

*Licensor*'s goal is to allow the research community to use, test and evaluate  
the *Software*.  

## 1.  Definitions  

*Licensee* means any person or entity that uses the *Software* and distributes  
its *Work*.  

*Licensor* means the owners of the *Software*, i.e., Inria and MPII  

*Software* means the original work of authorship made available under this  
License, i.e., gaussian-splatting.  

*Work* means the *Software* and any additions to or derivative works of the  
*Software* that are made available under this License.  


## 2.  Purpose  
This license is intended to define the rights granted to the *Licensee* by  
Licensors under the *Software*.  

## 3.  Rights granted  

For the above reasons Licensors have decided to distribute the *Software*.  
Licensors grant non-exclusive rights to use the *Software* for research purposes  
to research users (both academic and industrial), free of charge, without right  
to sublicense. The *Software* may be used "non-commercially", i.e., for research  
and/or evaluation purposes only.  

Subject to the terms and conditions of this License, you are granted a  
non-exclusive, royalty-free, license to reproduce, prepare derivative works of,  
publicly display, publicly perform and distribute its *Work* and any resulting  
derivative works in any form.  

## 4.  Limitations  

**4.1 Redistribution.** You may reproduce or distribute the *Work* only if (a) you do  
so under this License, (b) you include a complete copy of this License with  
your distribution, and (c) you retain without modification any copyright,  
patent, trademark, or attribution notices that are present in the *Work*.  

**4.2 Derivative Works.** You may specify that additional or different terms apply  
to the use, reproduction, and distribution of your derivative works of the *Work*  
("Your Terms") only if (a) Your Terms provide that the use limitation in  
Section 2 applies to your derivative works, and (b) you identify the specific  
derivative works that are subject to Your Terms. Notwithstanding Your Terms,  
this License (including the redistribution requirements in Section 3.1) will  
continue to apply to the *Work* itself.  

**4.3** Any other use without prior consent of Licensors is prohibited. Research  
users explicitly acknowledge having received from Licensors all information  
allowing to appreciate the adequacy between the *Software* and their needs and  
to undertake all necessary precautions for its execution and use.  

**4.4** The *Software* is provided both as a compiled library file and as source  
code. In case of using the *Software* for a publication or other results obtained  
through the use of the *Software*, users are strongly encouraged to cite the  
corresponding publications as explained in the documentation of the *Software*.  

## 5.  Disclaimer  

THE USER CANNOT USE, EXPLOIT OR DISTRIBUTE THE *SOFTWARE* FOR COMMERCIAL PURPOSES  
WITHOUT PRIOR AND EXPLICIT CONSENT OF LICENSORS. YOU MUST CONTACT INRIA FOR ANY  
UNAUTHORIZED USE: stip-sophia.transfert@inria.fr . ANY SUCH ACTION WILL  
CONSTITUTE A FORGERY. THIS *SOFTWARE* IS PROVIDED "AS IS" WITHOUT ANY WARRANTIES  
OF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES, WITH REGARDS TO COMMERCIAL  
USE, PROFESSIONAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALISATION OR  
ADAPTATION. UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL INRIA OR THE  
AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR  
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE  
GOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION)  
HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT  
LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM, OUT OF OR  
IN CONNECTION WITH THE *SOFTWARE* OR THE USE OR OTHER DEALINGS IN THE *SOFTWARE*.  

## 6.  Files subject to permissive licenses
The contents of the file `utils/loss_utils.py` are based on publicly available code authored by Evan Su, which falls under the permissive MIT license. 

Title: pytorch-ssim\
Project code: https://github.com/Po-Hsun-Su/pytorch-ssim\
Copyright Evan Su, 2017\
License: https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/LICENSE.txt (MIT)

================================================
FILE: README.md
================================================
<div align="center">

# [NeurIPS2024🔥] OpenGaussian: Towards Point-Level 3D Gaussian-based Open Vocabulary Understanding

<!-- <a href="https://arxiv.org/abs/2406.02058"><strong>Paper</strong></a> |  -->

<h3>
  <strong>Paper(<a href="https://arxiv.org/abs/2406.02058">arXiv</a> / <a href="https://proceedings.neurips.cc/paper_files/paper/2024/hash/21f7b745f73ce0d1f9bcea7f40b1388e-Abstract-Conference.html">Conference</a>)</strong> | 
  <a href="https://3d-aigc.github.io/OpenGaussian/"><strong>Project Page</strong></a>
</h3>

<!-- [**Paper**](https://arxiv.org/abs/2406.02058) | [**Project Page**](https://3d-aigc.github.io/OpenGaussian/) -->
<!-- [![arXiv](https://img.shields.io/badge/arXiv-<Paper>-<COLOR>.svg)](https://arxiv.org/abs/2406.02058)
[![Project Page](https://img.shields.io/badge/Project_Page-<Website>-blue.svg)](https://3d-aigc.github.io/OpenGaussian/) -->

[Yanmin Wu](https://yanmin-wu.github.io/)<sup>1</sup>, [Jiarui Meng](https://scholar.google.com/citations?user=N_pRAVAAAAAJ&hl=en&oi=ao)<sup>1</sup>, [Haijie Li](https://villa.jianzhang.tech/people/haijie-li-%E6%9D%8E%E6%B5%B7%E6%9D%B0/)<sup>1</sup>, [Chenming Wu](https://chenming-wu.github.io/)<sup>2*</sup>, [Yahao Shi](https://scholar.google.com/citations?user=-VJZrUkAAAAJ&hl=en)<sup>3</sup>, [Xinhua Cheng](https://cxh0519.github.io/)<sup>1</sup>, 
[Chen Zhao](https://openreview.net/profile?id=~Chen_Zhao9)<sup>2</sup>, [Haocheng Feng](https://openreview.net/profile?id=~Haocheng_Feng1)<sup>2</sup>, [Errui Ding](https://scholar.google.com/citations?user=1wzEtxcAAAAJ&hl=zh-CN)<sup>2</sup>, [Jingdong Wang](https://jingdongwang2017.github.io/)<sup>2</sup>, [Jian Zhang](https://jianzhang.tech/)<sup>1*</sup>

<sup>1</sup> Peking University, <sup>2</sup> Baidu VIS, <sup>3</sup> Beihang University

</div>

## 0. Installation

The installation of OpenGaussian is similar to [3D Gaussian Splatting](https://github.com/graphdeco-inria/gaussian-splatting).
```shell
git clone https://github.com/yanmin-wu/OpenGaussian.git
cd OpenGaussian
```
Then install the dependencies:
```shell
# the rasterization lib comes from DreamGaussian;
# unzip it first, because environment.yml pip-installs it from submodules/
cd submodules
unzip ashawkey-diff-gaussian-rasterization.zip
cd ..

conda env create --file environment.yml
conda activate gaussian_splatting
pip install ./submodules/ashawkey-diff-gaussian-rasterization
```
+ Additional dependencies: bitarray, scipy, and [pytorch3d](https://anaconda.org/pytorch3d/pytorch3d/files)
    ```shell
    pip install bitarray scipy
    
    # install a pytorch3d version compatible with your PyTorch, Python, and CUDA.
    ```
+ `simple-knn` is not required

---

## 1. ToDo list

+ [x] Point feature visualization
+ [x] Data preprocessing
+ ~~[ ] Improved SAM mask extraction (extracting only one layer)~~
+ [x] Click to Select 3D Object

---

## 2. Data preparation
The expected directory layout is as follows:
```
[DATA_ROOT]
├── [1] scannet/
│   ├── scene0000_00/
│   │   ├── color/
│   │   ├── language_features/
│   │   ├── points3d.ply
│   │   ├── transforms_train.json & transforms_test.json
│   │   └── *_vh_clean_2.labels.ply
│   ├── scene0062_00/
│   └── ...
└── [2] lerf_ovs/
    ├── figurines/ & ramen/ & teatime/ & waldo_kitchen/
    │   ├── images/
    │   ├── language_features/
    │   └── sparse/
    └── label/
```
+ **[1] Prepare ScanNet Data**
    + You can directly download our pre-processed data: [**OneDrive**](https://onedrive.live.com/?authkey=%21AIgsXZy3gl%5FuKmM&id=744D3E86422BE3C9%2139813&cid=744D3E86422BE3C9) / [Baidu](https://pan.baidu.com/s/1B_tGYla5dWyJRu3jTNTMvA?pwd=u5iy). Please unzip the `color.zip` and `language_features.zip` files.
    + Use of the ScanNet dataset requires permission; follow the [ScanNet instructions](https://github.com/ScanNet/ScanNet) to apply for access.
    + **If you want to process more scenes from the ScanNet dataset, follow these steps:**
        + First, use the official `download-scannet.py` script provided by ScanNet to download the `.sens` archives of the specified scenes;
        + Then, refer to the [`preprocess_2d_scannet.py`](https://github.com/pengsongyou/openscene/blob/main/scripts/preprocess/preprocess_2d_scannet.py) script to extract the `color` and `pose` information;
        + Finally, convert the data into Blender format using the [`scripts/scannet2blender.py`](https://github.com/yanmin-wu/OpenGaussian/blob/main/scripts/scannet2blender.py) script. Check the `TODO` comments in the script to specify the paths.
+ **[2] Prepare lerf_ovs Data**
    + You can directly download our pre-processed data: [**OneDrive**](https://onedrive.live.com/?authkey=%21AIgsXZy3gl%5FuKmM&id=744D3E86422BE3C9%2139815&cid=744D3E86422BE3C9) / [Baidu](https://pan.baidu.com/s/1B_tGYla5dWyJRu3jTNTMvA?pwd=u5iy) (re-annotated by LangSplat). Please unzip the `images.zip` and `language_features.zip` files.
+ **Mask and Language Feature Extraction Details**
    + We use the tools provided by LangSplat to extract the SAM mask and CLIP features, but we only use the large-level mask.
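Before training, it can help to confirm that each scene folder matches the layout above. The helper below is a minimal sketch and is not part of the repository: `EXPECTED` and `missing_entries` are hypothetical names, and the entry lists are read off the directory tree in this section, so the actual data loaders may need more or fewer files.

```python
from pathlib import Path

# Entries expected in each scene folder, read off the directory tree above.
# Hypothetical helper: the repository's loaders may need more or fewer files.
EXPECTED = {
    "scannet": ["color", "language_features", "points3d.ply",
                "transforms_train.json", "transforms_test.json"],
    "lerf_ovs": ["images", "language_features", "sparse"],
}


def missing_entries(scene_dir, dataset):
    """Return the expected entries that do not exist inside `scene_dir`."""
    scene = Path(scene_dir)
    missing = [name for name in EXPECTED[dataset] if not (scene / name).exists()]
    # Each ScanNet scene also needs its ground-truth label mesh (*_vh_clean_2.labels.ply)
    if dataset == "scannet" and not any(scene.glob("*_vh_clean_2.labels.ply")):
        missing.append("*_vh_clean_2.labels.ply")
    return missing
```

For a fully prepared scene, `missing_entries("[DATA_ROOT]/scannet/scene0000_00", "scannet")` should return an empty list; anything it returns points at a file or folder still to be unzipped or generated.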

---

## 3. Training
### 3.1 ScanNet
```shell
chmod +x scripts/train_scannet.sh
./scripts/train_scannet.sh
```
+ Please ***check*** the script for more details and ***modify*** the dataset path.
+ You will see the following stages during training:
    ```shell
    [Stage 0] Start 3dgs pre-train ... (step 0-30k)
    [Stage 1] Start continuous instance feature learning ... (step 30-50k)
    [Stage 2.1] Start coarse-level codebook discretization ... (step 50-70k)
    [Stage 2.2] Start fine-level codebook discretization ... (step 70-90k)
    [Stage 3] Start 2D language feature - 3D cluster association ... (1 min)
    ```
+ Intermediate results from different stages can be found in subfolders `***/train_process/stage*`. (Stage 3 intermediate results are best inspected on the LeRF dataset.)

### 3.2 LeRF_ovs
```shell
chmod +x scripts/train_lerf.sh
./scripts/train_lerf.sh
```
+ Please ***check*** the script for more details and ***modify*** the dataset path.
+ You will see the following stages during training:
    ```shell
    [Stage 0] Start 3dgs pre-train ... (step 0-30k)
    [Stage 1] Start continuous instance feature learning ... (step 30-40k)
    [Stage 2.1] Start coarse-level codebook discretization ... (step 40-50k)
    [Stage 2.2] Start fine-level codebook discretization ... (step 50-70k)
    [Stage 3] Start 2D language feature - 3D cluster association ... (1 min)
    ```
+ Intermediate results from different stages can be found in subfolders `***/train_process/stage*`.
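The stage boundaries above correspond to the `--start_ins_feat_iter`, `--start_root_cb_iter`, `--start_leaf_cb_iter`, and `--iterations` options defined in `arguments/__init__.py`. The hypothetical helper below maps a training step to its stage under two assumptions: each stage covers a half-open interval `[start, next_start)`, and stage 3 runs once the final training step is reached.

```python
def training_stage(it, start_ins_feat_iter=30_000, start_root_cb_iter=40_000,
                   start_leaf_cb_iter=50_000, iterations=70_000):
    """Return the stage label ("0", "1", "2.1", "2.2", "3") for training step `it`.

    Defaults mirror the argparse defaults in arguments/__init__.py (LeRF schedule).
    Interval boundaries are an assumption of this sketch.
    """
    if it < start_ins_feat_iter:
        return "0"      # 3DGS pre-training
    if it < start_root_cb_iter:
        return "1"      # continuous instance feature learning
    if it < start_leaf_cb_iter:
        return "2.1"    # coarse-level codebook discretization
    if it < iterations:
        return "2.2"    # fine-level codebook discretization
    return "3"          # 2D language feature - 3D cluster association
```

The defaults reproduce the LeRF schedule. For ScanNet, the comments in `arguments/__init__.py` suggest `training_stage(it, 30_000, 50_000, 70_000, 90_000)`; check `scripts/train_scannet.sh` for the values it actually passes.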

### 3.3 Custom data
+ No special processing is required: capture a video, sample approximately 200 frames, and then use COLMAP to initialize the point cloud and camera poses.
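How the roughly 200 frames are chosen is not specified here. A simple option, assumed in this sketch, is to take evenly spaced frames. `uniform_frame_indices` is a hypothetical helper that only computes which frame indices to keep; decoding and saving the frames is left to your video tool.

```python
def uniform_frame_indices(total_frames, num_samples=200):
    """Pick `num_samples` evenly spaced frame indices in [0, total_frames)."""
    if total_frames <= num_samples:
        return list(range(total_frames))  # short video: keep every frame
    step = total_frames / num_samples     # > 1, so indices never repeat
    # Take the centre of each of the num_samples equal-length segments
    return [int((i + 0.5) * step) for i in range(num_samples)]
```

Save the selected frames into `<source_path>/input`, which is where `convert.py` looks for images, then run `python convert.py -s <source_path>` to obtain the COLMAP point cloud and camera poses.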

---

## 4. Render & Eval & Downstream Tasks

### 4.1 3D Instance Feature Visualization
+ Please install `open3d` first, and then run the following command on a machine with a display (GUI support):
    ```shell
    python scripts/vis_opengs_pts_feat.py
    ```
    + Set `ply_path` in the script to the PLY file `output/xxxxxxxx-x/point_cloud/iteration_x0000/point_cloud.ply` saved at the stage you want to inspect.
    + During the training process, we have saved the first three dimensions of the 6D features as colors for visualization; see [here](https://github.com/yanmin-wu/OpenGaussian/blob/2845b9c744c1b06ac6930ffa2d2a6f9167f1b843/scene/gaussian_model.py#L272).
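The exact color mapping used during training lives in `scene/gaussian_model.py` (linked above). As a rough, hypothetical equivalent for inspecting exported features, the sketch below takes the first three of the six feature dimensions and rescales each channel to 0-255 with min-max normalization; that normalization is an assumption of this sketch, not necessarily what the repository does.

```python
import numpy as np


def feats_to_rgb(ins_feat):
    """Map (N, 6) instance features to (N, 3) uint8 colors using the first 3 dims.

    Per-channel min-max normalization is an assumption made for this sketch.
    """
    rgb = ins_feat[:, :3].astype(np.float64)
    lo, hi = rgb.min(axis=0), rgb.max(axis=0)
    # Guard against constant channels to avoid division by zero
    span = np.where(hi - lo > 0, hi - lo, 1.0)
    return np.round((rgb - lo) / span * 255).astype(np.uint8)
```

Divide the result by 255 before assigning it as per-point colors in Open3D, which expects floating-point colors in [0, 1].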

### 4.2 Render 2D Feature Map
+ Feature maps are rendered in the same way that 3DGS renders colors:
    ```shell
    python render.py -m "output/xxxxxxxx-x"
    ```
    You can find the rendered feature maps in subfolders `renders_ins_feat1` and `renders_ins_feat2`.

### 4.3 ScanNet Evaluation (Open-Vocabulary Point Cloud Understanding)
> Due to code optimization and the use of more suitable hyperparameters, the latest evaluation metrics may be higher than those reported in the paper. 
+ Evaluate text-guided segmentation performance on ScanNet for 19, 15, and 10 categories.
    ```shell
    # unzip the pre-extracted text features
    cd assets
    unzip text_features.zip
    cd ..

    # 1. check that `gt_file_path` and `model_path` are correct
    # 2. specify `target_id` as 19, 15, or 10 categories.
    python scripts/eval_scannet.py
    ```
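Conceptually, text-guided segmentation compares a text feature for each category with the language feature that stage 3 associates with each 3D cluster. The sketch below is a hypothetical illustration of that idea using cosine similarity and an argmax; it is not the code of `scripts/eval_scannet.py`, and the function name, argument names, and shapes are assumptions.

```python
import numpy as np


def assign_categories(cluster_feats, text_feats, point_cluster_ids):
    """Label each point with the category whose text feature best matches its cluster.

    cluster_feats:     (K, D) language feature associated with each 3D cluster
    text_feats:        (C, D) text feature per category (e.g. 19, 15 or 10 classes)
    point_cluster_ids: (N,)   cluster index of each Gaussian
    """
    # L2-normalize rows so that the dot product equals cosine similarity
    c = cluster_feats / np.linalg.norm(cluster_feats, axis=1, keepdims=True)
    t = text_feats / np.linalg.norm(text_feats, axis=1, keepdims=True)
    cluster_label = (c @ t.T).argmax(axis=1)   # (K,) best category per cluster
    return cluster_label[point_cluster_ids]    # (N,) propagate labels to points
```

Because labels are chosen per cluster and then copied to every member point, all Gaussians in a cluster always receive the same category.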

### 4.4 LeRF Evaluation (Open-Vocabulary Object Selection in 3D Space)
+ (1) First, render text-selected 3D Gaussians into multi-view images.
    ```shell
    # unzip the pre-extracted text features
    cd assets
    unzip text_features.zip
    cd ..

    # 1. specify the model path using -m
    # 2. specify the scene name: figurines, teatime, ramen, waldo_kitchen
    python render_lerf_by_text.py -m "output/xxxxxxxx-x" --scene_name "figurines"
    ```
    The object selection results are saved in `output/xxxxxxxx-x/text2obj/ours_70000/renders_cluster`.

+ (2) Then, compute evaluation metrics.
    > Due to code optimization and the use of more suitable hyperparameters, the latest evaluation metrics may be higher than those reported in the paper. 
    > The metrics may be unstable because LeRF provides only a limited number of evaluation samples.
    ```shell
    # 1. change path_gt and path_pred in the script
    # 2. specify the scene name: figurines, teatime, ramen, waldo_kitchen
    python scripts/compute_lerf_iou.py --scene_name "figurines"
    ```
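Judging by its name and its `path_gt`/`path_pred` settings, `scripts/compute_lerf_iou.py` scores predicted masks against ground-truth masks with intersection-over-union. For reference, the standard IoU of two binary masks is sketched below; this is the generic definition rather than the script's exact implementation, and returning 1.0 when both masks are empty is a convention chosen for this sketch.

```python
import numpy as np


def mask_iou(pred, gt):
    """Intersection-over-union of two binary masks of the same shape."""
    pred, gt = pred.astype(bool), gt.astype(bool)
    union = np.logical_or(pred, gt).sum()
    if union == 0:
        return 1.0  # both masks empty: treated as a perfect match (a convention)
    return float(np.logical_and(pred, gt).sum() / union)
```

Averaging this value over the queried objects gives an mIoU-style score; how the script aggregates across queries and views is not shown here.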

### 4.5 Click to Select 3D Object

+ (1) First, render the feature maps (see Section 4.2; in practice, only the two feature maps of a single view are required).
+ (2) Then, check the [`scripts/render_by_click.py`](https://github.com/yanmin-wu/OpenGaussian/blob/main/scripts/render_by_click.py) script for `TODO` comments, including specifying the frame filename, clicked pixel coordinates, and file paths.
+ (3) Finally, run the [`scripts/render_by_click.py`](https://github.com/yanmin-wu/OpenGaussian/blob/main/scripts/render_by_click.py) script. *Note that this script has not been tested with the current version of the code and may require debugging*.

---

## 5. Acknowledgements
We are quite grateful for [3DGS](https://github.com/graphdeco-inria/gaussian-splatting), [LangSplat](https://github.com/minghanqin/LangSplat), [CompGS](https://github.com/UCDvision/compact3d), [LEGaussians](https://github.com/buaavrcg/LEGaussians), [SAGA](https://github.com/Jumpat/SegAnyGAussians), and [SAM](https://segment-anything.com/).

---

## 6. Citation

```
@inproceedings{wu2024opengaussian,
    title={OpenGaussian: Towards Point-Level 3D Gaussian-based Open Vocabulary Understanding},
    author={Wu, Yanmin and Meng, Jiarui and Li, Haijie and Wu, Chenming and Shi, Yahao and Cheng, Xinhua and Zhao, Chen and Feng, Haocheng and Ding, Errui and Wang, Jingdong and Zhang, Jian},
    booktitle={Proceedings of the Advances in Neural Information Processing Systems (NeurIPS)},
    pages={19114--19138},
    year={2024}
}
```

---

## 7. Contact
If you have any questions about this project, please feel free to contact [Yanmin Wu](https://yanmin-wu.github.io/): wuyanminmax[AT]gmail.com


================================================
FILE: arguments/__init__.py
================================================
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use 
# under the terms of the LICENSE.md file.
#
# For inquiries contact  george.drettakis@inria.fr
#

from argparse import ArgumentParser, Namespace
import sys
import os

class GroupParams:
    pass

class ParamGroup:
    def __init__(self, parser: ArgumentParser, name : str, fill_none = False):
        group = parser.add_argument_group(name)
        for key, value in vars(self).items():
            shorthand = False
            if key.startswith("_"):
                shorthand = True
                key = key[1:]
            t = type(value)
            value = value if not fill_none else None 
            if shorthand:
                if t == bool:
                    group.add_argument("--" + key, ("-" + key[0:1]), default=value, action="store_true")
                else:
                    group.add_argument("--" + key, ("-" + key[0:1]), default=value, type=t)
            else:
                if t == bool:
                    group.add_argument("--" + key, default=value, action="store_true")
                else:
                    group.add_argument("--" + key, default=value, type=t)

    def extract(self, args):
        group = GroupParams()
        for arg in vars(args).items():
            if arg[0] in vars(self) or ("_" + arg[0]) in vars(self):
                setattr(group, arg[0], arg[1])
        return group

class ModelParams(ParamGroup): 
    def __init__(self, parser, sentinel=False):
        self.sh_degree = 3
        self._source_path = ""
        self._model_path = ""
        self._images = "images"
        self._resolution = -1
        self._white_background = False
        self.data_device = "cuda"
        self.eval = False
        super().__init__(parser, "Loading Parameters", sentinel)

    def extract(self, args):
        g = super().extract(args)
        g.source_path = os.path.abspath(g.source_path)
        return g

class PipelineParams(ParamGroup):
    def __init__(self, parser):
        self.convert_SHs_python = False
        self.compute_cov3D_python = False
        self.debug = False
        super().__init__(parser, "Pipeline Parameters")

class OptimizationParams(ParamGroup):
    def __init__(self, parser):
        self.leaf_update_fr = 300           # coarse-level codebook update frequency
        self.ins_feat_dim = 6
        self.position_lr_init = 0.00016
        self.position_lr_final = 0.0000016
        self.position_lr_delay_mult = 0.01
        self.position_lr_max_steps = 30_000
        self.feature_lr = 0.0025
        self.ins_feat_lr = 0.001
        self.opacity_lr = 0.05
        self.scaling_lr = 0.005
        self.rotation_lr = 0.001
        self.percent_dense = 0.01
        self.lambda_dssim = 0.2
        self.densification_interval = 100
        self.opacity_reset_interval = 3000
        self.densify_from_iter = 500
        self.densify_until_iter = 15_000
        self.densify_grad_threshold = 0.0002
        self.random_background = False

        parser.add_argument('--root_node_num', type=int, default=64)    # k1=64
        parser.add_argument('--leaf_node_num', type=int, default=5)     # k2=5/10

        parser.add_argument('--pos_weight', type=float, default=1.0)    # position weight for coarse codebook
        parser.add_argument('--loss_weight', type=float, default=0.1)   # loss_cohesion weight

        parser.add_argument('--iterations', type=int, default=70_000)   # default 70k, ScanNet 90k
        parser.add_argument('--start_ins_feat_iter', type=int, default=30_000)  # default 30k
        parser.add_argument('--start_root_cb_iter', type=int, default=40_000)   # default 40k, ScanNet 50k
        parser.add_argument('--start_leaf_cb_iter', type=int, default=50_000)   # default 50k, ScanNet 70k

        # Freeze the positions of the initial points and disable densification (used for ScanNet)
        parser.add_argument('--frozen_init_pts', action='store_true', default=False)
        parser.add_argument('--sam_level', type=int, default=3)

        parser.add_argument('--save_memory', action='store_true', default=False)
        super().__init__(parser, "Optimization Parameters")
    
    def extract(self, args):
        g = super().extract(args)
        g.root_node_num = args.root_node_num
        g.leaf_node_num = args.leaf_node_num
        g.pos_weight = args.pos_weight
        g.loss_weight = args.loss_weight
        g.frozen_init_pts = args.frozen_init_pts
        g.sam_level = args.sam_level
        g.iterations = args.iterations
        g.start_ins_feat_iter = args.start_ins_feat_iter
        g.start_root_cb_iter = args.start_root_cb_iter
        g.start_leaf_cb_iter = args.start_leaf_cb_iter
        g.save_memory = args.save_memory

        return g

def get_combined_args(parser : ArgumentParser):
    cmdlne_string = sys.argv[1:]
    cfgfile_string = "Namespace()"
    args_cmdline = parser.parse_args(cmdlne_string)

    try:
        cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args")
        print("Looking for config file in", cfgfilepath)
        with open(cfgfilepath) as cfg_file:
            print("Config file found: {}".format(cfgfilepath))
            cfgfile_string = cfg_file.read()
    except TypeError:
        # model_path is None (no -m given), so there is no cfg_args file to read
        print("Config file not found: no model path was given")
    args_cfgfile = eval(cfgfile_string)

    merged_dict = vars(args_cfgfile).copy()
    for k,v in vars(args_cmdline).items():
        if v is not None:
            merged_dict[k] = v
    return Namespace(**merged_dict)


================================================
FILE: convert.py
================================================
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact  george.drettakis@inria.fr
#

import os
import logging
from argparse import ArgumentParser
import shutil

# This Python script is based on the shell converter script provided in the MipNerF 360 repository.
parser = ArgumentParser("Colmap converter")
parser.add_argument("--no_gpu", action='store_true')
parser.add_argument("--skip_matching", action='store_true')
parser.add_argument("--source_path", "-s", required=True, type=str)
parser.add_argument("--camera", default="OPENCV", type=str)
parser.add_argument("--colmap_executable", default="", type=str)
parser.add_argument("--resize", action="store_true")
parser.add_argument("--magick_executable", default="", type=str)
args = parser.parse_args()
colmap_command = '"{}"'.format(args.colmap_executable) if len(args.colmap_executable) > 0 else "colmap"
magick_command = '"{}"'.format(args.magick_executable) if len(args.magick_executable) > 0 else "magick"
use_gpu = 1 if not args.no_gpu else 0

if not args.skip_matching:
    os.makedirs(args.source_path + "/distorted/sparse", exist_ok=True)

    ## Feature extraction
    feat_extraction_cmd = colmap_command + " feature_extractor "\
        "--database_path " + args.source_path + "/distorted/database.db \
        --image_path " + args.source_path + "/input \
        --ImageReader.single_camera 1 \
        --ImageReader.camera_model " + args.camera + " \
        --SiftExtraction.use_gpu " + str(use_gpu)
    exit_code = os.system(feat_extraction_cmd)
    if exit_code != 0:
        logging.error(f"Feature extraction failed with code {exit_code}. Exiting.")
        exit(exit_code)

    ## Feature matching
    feat_matching_cmd = colmap_command + " exhaustive_matcher \
        --database_path " + args.source_path + "/distorted/database.db \
        --SiftMatching.use_gpu " + str(use_gpu)
    exit_code = os.system(feat_matching_cmd)
    if exit_code != 0:
        logging.error(f"Feature matching failed with code {exit_code}. Exiting.")
        exit(exit_code)

    ### Bundle adjustment
    # The default Mapper tolerance is unnecessarily large,
    # decreasing it speeds up bundle adjustment steps.
    mapper_cmd = (colmap_command + " mapper \
        --database_path " + args.source_path + "/distorted/database.db \
        --image_path "  + args.source_path + "/input \
        --output_path "  + args.source_path + "/distorted/sparse \
        --Mapper.ba_global_function_tolerance=0.000001")
    exit_code = os.system(mapper_cmd)
    if exit_code != 0:
        logging.error(f"Mapper failed with code {exit_code}. Exiting.")
        exit(exit_code)

### Image undistortion
## We need to undistort our images into ideal pinhole intrinsics.
img_undist_cmd = (colmap_command + " image_undistorter \
    --image_path " + args.source_path + "/input \
    --input_path " + args.source_path + "/distorted/sparse/0 \
    --output_path " + args.source_path + "\
    --output_type COLMAP")
exit_code = os.system(img_undist_cmd)
if exit_code != 0:
    logging.error(f"Image undistortion failed with code {exit_code}. Exiting.")
    exit(exit_code)

files = os.listdir(args.source_path + "/sparse")
os.makedirs(args.source_path + "/sparse/0", exist_ok=True)
# Copy each file from the source directory to the destination directory
for file in files:
    if file == '0':
        continue
    source_file = os.path.join(args.source_path, "sparse", file)
    destination_file = os.path.join(args.source_path, "sparse", "0", file)
    shutil.move(source_file, destination_file)

if(args.resize):
    print("Copying and resizing...")

    # Resize images.
    os.makedirs(args.source_path + "/images_2", exist_ok=True)
    os.makedirs(args.source_path + "/images_4", exist_ok=True)
    os.makedirs(args.source_path + "/images_8", exist_ok=True)
    # Get the list of files in the source directory
    files = os.listdir(args.source_path + "/images")
    # Copy each file from the source directory to the destination directory
    for file in files:
        source_file = os.path.join(args.source_path, "images", file)

        destination_file = os.path.join(args.source_path, "images_2", file)
        shutil.copy2(source_file, destination_file)
        exit_code = os.system(magick_command + " mogrify -resize 50% " + destination_file)
        if exit_code != 0:
            logging.error(f"50% resize failed with code {exit_code}. Exiting.")
            exit(exit_code)

        destination_file = os.path.join(args.source_path, "images_4", file)
        shutil.copy2(source_file, destination_file)
        exit_code = os.system(magick_command + " mogrify -resize 25% " + destination_file)
        if exit_code != 0:
            logging.error(f"25% resize failed with code {exit_code}. Exiting.")
            exit(exit_code)

        destination_file = os.path.join(args.source_path, "images_8", file)
        shutil.copy2(source_file, destination_file)
        exit_code = os.system(magick_command + " mogrify -resize 12.5% " + destination_file)
        if exit_code != 0:
            logging.error(f"12.5% resize failed with code {exit_code}. Exiting.")
            exit(exit_code)

print("Done.")


================================================
FILE: environment.yml
================================================
name: gaussian_splatting
channels:
  - pytorch
  - conda-forge
  - defaults
dependencies:
  - cudatoolkit=11.6
  - plyfile=0.8.1
  - python=3.7.13
  - pip=22.3.1
  - pytorch=1.12.1
  - torchaudio=0.12.1
  - torchvision=0.13.1
  - tqdm
  - pip:
    - bitarray
    - scipy
    - submodules/ashawkey-diff-gaussian-rasterization

================================================
FILE: full_eval.py
================================================
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use 
# under the terms of the LICENSE.md file.
#
# For inquiries contact  george.drettakis@inria.fr
#

import os
from argparse import ArgumentParser

mipnerf360_outdoor_scenes = ["bicycle", "flowers", "garden", "stump", "treehill"]
mipnerf360_indoor_scenes = ["room", "counter", "kitchen", "bonsai"]
tanks_and_temples_scenes = ["truck", "train"]
deep_blending_scenes = ["drjohnson", "playroom"]

parser = ArgumentParser(description="Full evaluation script parameters")
parser.add_argument("--skip_training", action="store_true")
parser.add_argument("--skip_rendering", action="store_true")
parser.add_argument("--skip_metrics", action="store_true")
parser.add_argument("--output_path", default="./eval")
args, _ = parser.parse_known_args()

all_scenes = []
all_scenes.extend(mipnerf360_outdoor_scenes)
all_scenes.extend(mipnerf360_indoor_scenes)
all_scenes.extend(tanks_and_temples_scenes)
all_scenes.extend(deep_blending_scenes)

if not args.skip_training or not args.skip_rendering:
    parser.add_argument('--mipnerf360', "-m360", required=True, type=str)
    parser.add_argument("--tanksandtemples", "-tat", required=True, type=str)
    parser.add_argument("--deepblending", "-db", required=True, type=str)
    args = parser.parse_args()

if not args.skip_training:
    common_args = " --quiet --eval --test_iterations -1 "
    for scene in mipnerf360_outdoor_scenes:
        source = args.mipnerf360 + "/" + scene
        os.system("python train.py -s " + source + " -i images_4 -m " + args.output_path + "/" + scene + common_args)
    for scene in mipnerf360_indoor_scenes:
        source = args.mipnerf360 + "/" + scene
        os.system("python train.py -s " + source + " -i images_2 -m " + args.output_path + "/" + scene + common_args)
    for scene in tanks_and_temples_scenes:
        source = args.tanksandtemples + "/" + scene
        os.system("python train.py -s " + source + " -m " + args.output_path + "/" + scene + common_args)
    for scene in deep_blending_scenes:
        source = args.deepblending + "/" + scene
        os.system("python train.py -s " + source + " -m " + args.output_path + "/" + scene + common_args)

if not args.skip_rendering:
    all_sources = []
    for scene in mipnerf360_outdoor_scenes:
        all_sources.append(args.mipnerf360 + "/" + scene)
    for scene in mipnerf360_indoor_scenes:
        all_sources.append(args.mipnerf360 + "/" + scene)
    for scene in tanks_and_temples_scenes:
        all_sources.append(args.tanksandtemples + "/" + scene)
    for scene in deep_blending_scenes:
        all_sources.append(args.deepblending + "/" + scene)

    common_args = " --quiet --eval --skip_train"
    for scene, source in zip(all_scenes, all_sources):
        os.system("python render.py --iteration 7000 -s " + source + " -m " + args.output_path + "/" + scene + common_args)
        os.system("python render.py --iteration 30000 -s " + source + " -m " + args.output_path + "/" + scene + common_args)

if not args.skip_metrics:
    scenes_string = ""
    for scene in all_scenes:
        scenes_string += "\"" + args.output_path + "/" + scene + "\" "

    os.system("python metrics.py -m " + scenes_string)
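
# Sketch (not part of the original script): the commands above are built by string
# concatenation and run through os.system, which breaks on paths that contain spaces.
# A hedged alternative, assuming the same train.py / render.py flags used above, is to
# build argv lists and pass them to subprocess.run. `build_train_argv`,
# `build_render_argv` and `run_all` are hypothetical helper names.
import subprocess


def build_train_argv(source, model_dir, images=None):
    """Return the train.py argv used above for one scene (images=None omits -i)."""
    argv = ["python", "train.py", "-s", source]
    if images is not None:
        argv += ["-i", images]
    argv += ["-m", model_dir, "--quiet", "--eval", "--test_iterations", "-1"]
    return argv


def build_render_argv(source, model_dir, iteration):
    """Return the render.py argv used above for one scene and iteration."""
    return ["python", "render.py", "--iteration", str(iteration), "-s", source,
            "-m", model_dir, "--quiet", "--eval", "--skip_train"]


def run_all(argvs):
    """Run each command in order; check=True stops at the first failing step."""
    for argv in argvs:
        subprocess.run(argv, check=True)

# Possible usage (hypothetical), for one outdoor Mip-NeRF 360 scene:
#   run_all([build_train_argv(os.path.join(args.mipnerf360, scene),
#                             os.path.join(args.output_path, scene), "images_4")])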

================================================
FILE: gaussian_renderer/__init__.py
================================================
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use 
# under the terms of the LICENSE.md file.
#
# For inquiries contact  george.drettakis@inria.fr
#

import torch
import math
# from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
from ashawkey_diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
from scene.gaussian_model import GaussianModel
from utils.sh_utils import eval_sh
from utils.opengs_utlis import *
# from sklearn.neighbors import NearestNeighbors
import pytorch3d.ops
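
# Sketch (illustrative, not used by render() below): render() maps the instance features
# from [-1, 1] to [0, 1] via (x + 1) / 2 and rasterizes them as "precomputed colors",
# three channels at a time (ins_feat[:, :3] and ins_feat[:, 3:6]). This assumes the
# rasterizer is compiled for 3 color channels, as the standard 3DGS rasterizer is. Under
# that assumption, a feature of any width can be split into 3-channel chunks, with the
# last chunk zero-padded. `split_feat_into_rgb_chunks` is a hypothetical helper that
# works on plain Python lists for clarity.
def split_feat_into_rgb_chunks(feat_row, chunk=3):
    """Normalize one point's feature from [-1, 1] to [0, 1] and split it into
    chunks of `chunk` channels; the last chunk is zero-padded."""
    normed = [(v + 1.0) / 2.0 for v in feat_row]
    chunks = []
    for start in range(0, len(normed), chunk):
        part = normed[start:start + chunk]
        part += [0.0] * (chunk - len(part))     # pad the tail chunk with zeros
        chunks.append(part)
    return chunks
# Each chunk would be rendered in its own rasterizer pass. The resulting [3, H, W] maps
# are then concatenated along dim 0, as render() does with torch.cat(..., dim=0).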

def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, iteration,
            scaling_modifier = 1.0, override_color = None, visible_mask = None, mask_num=0,
            cluster_idx=None,       # per-point cluster id (coarse-level)
            leaf_cluster_idx=None,  # per-point cluster id (fine-level)
            rescale=True,           # re-scale (for enhance ins_feat)
            origin_feat=False,      # origin ins_feat (not quantized)
            render_feat_map=True,   # render image-level feat map
            render_color=True,      # render rgb image
            render_cluster=False,   # render cluster, stage 2.2
            better_vis=False,       # filter some points
            selected_root_id=None,  # coarse-level cluster id
            selected_leaf_id=None,  # fine-level cluster id (possibly more than one)
            pre_mask=None,
            seg_rgb=False,          # render cluster rgb, not feat
            post_process=False,     # post
            root_num=64, leaf_num=10):  # k1, k2 
    """
    Render the scene. 
    
    Background tensor (bg_color) must be on GPU!
    """
 
    # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
    screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0
    try:
        screenspace_points.retain_grad()
    except:
        pass

    # Set up rasterization configuration
    tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
    tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)

    raster_settings = GaussianRasterizationSettings(
        image_height=int(viewpoint_camera.image_height),
        image_width=int(viewpoint_camera.image_width),
        tanfovx=tanfovx,
        tanfovy=tanfovy,
        bg=bg_color,
        scale_modifier=scaling_modifier,
        viewmatrix=viewpoint_camera.world_view_transform,
        projmatrix=viewpoint_camera.full_proj_transform,
        sh_degree=pc.active_sh_degree,
        campos=viewpoint_camera.camera_center,
        prefiltered=False,
        debug=pipe.debug
    )

    rasterizer = GaussianRasterizer(raster_settings=raster_settings)

    means3D = pc.get_xyz
    means2D = screenspace_points
    opacity = pc.get_opacity

    # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
    # scaling / rotation by the rasterizer.
    scales = None
    rotations = None
    cov3D_precomp = None
    if pipe.compute_cov3D_python:
        cov3D_precomp = pc.get_covariance(scaling_modifier)
    else:
        scales = pc.get_scaling
        rotations = pc.get_rotation

    # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
    # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
    shs = None
    colors_precomp = None
    if override_color is None:
        if pipe.convert_SHs_python:
            shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2)
            dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1))
            dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True)
            sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)
            colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
        else:
            shs = pc.get_features
    else:
        colors_precomp = override_color

    if render_color:
        rendered_image, radii, rendered_depth, rendered_alpha = rasterizer(
            means3D = means3D,
            means2D = means2D,
            shs = shs,
            colors_precomp = colors_precomp,
            opacities = opacity,
            scales = scales,
            rotations = rotations,
            cov3D_precomp = cov3D_precomp)
    else:
        rendered_image, radii, rendered_depth, rendered_alpha = None, None, None, None

    # ################################################################
    # [Stage 1, Stage 2.1] Render image-level instance feature map   #
    #   - rendered_ins_feat: image-level feat map                    #
    # ################################################################
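    # Probabilistic rescale: when rescale=True, with probability 0.5 the Gaussian scales used
    # in the feature-map and coarse-cluster passes below are multiplied by a factor drawn from U[0, 1).
    prob = torch.rand(1)
    rescale_factor = torch.tensor(1.0, dtype=torch.float32).cuda()
    if prob > 0.5 and rescale:
        rescale_factor = torch.rand(1).cuda()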
    if render_feat_map:
        # get feature
        ins_feat = (pc.get_ins_feat(origin=origin_feat) + 1) / 2   # pseudo -> norm, else -> raw
        # first three channels
        rendered_ins_feat, _, _, _ = rasterizer(
            means3D = means3D,
            means2D = means2D,
            shs = None,
            colors_precomp = ins_feat[:, :3],   # render features as pre-computed colors
            opacities = opacity,
            scales = scales * rescale_factor,
            rotations = rotations,
            cov3D_precomp = cov3D_precomp)
        # last three channels
        if ins_feat.shape[-1] > 3:
            rendered_ins_feat2, _, _, _ = rasterizer(
                means3D = means3D,
                means2D = means2D,
                shs = None,
                colors_precomp = ins_feat[:, 3:6],  # render features as pre-computed colors
                opacities = opacity,
                scales = scales * rescale_factor,
                rotations = rotations,
                cov3D_precomp = cov3D_precomp)
            rendered_ins_feat = torch.cat((rendered_ins_feat, rendered_ins_feat2), dim=0)
        # silhouette (accumulated alpha) of the rescaled Gaussians
        _, _, _, silhouette = rasterizer(
            means3D = means3D,
            means2D = means2D,
            shs = shs,
            colors_precomp = colors_precomp,
            opacities = opacity,
            scales = scales * rescale_factor,
            # opacities = opacity*0+1.0,    # 
            # scales = scales*0+0.001,   # *0.1
            rotations = rotations,
            cov3D_precomp = cov3D_precomp)
    else:
        rendered_ins_feat, silhouette = None, None


    # ########################################################################
    # [Preprocessing for Stage 2.2]: render (coarse) cluster-level feat map  #
    #   - rendered_clusters: feat map of the coarse clusters                 #
    #   - rendered_cluster_silhouettes: cluster mask                         #
    # ########################################################################
    viewed_pts = radii > 0      # ignore the invisible points
    if cluster_idx is not None:
        num_cluster = cluster_idx.max() + 1
        cluster_occur = torch.zeros(num_cluster).to(torch.bool) # [num_cluster], bool
    else:
        cluster_occur = None
    if render_cluster and cluster_idx is not None and viewed_pts.sum() != 0:
        ins_feat = (pc.get_ins_feat(origin=origin_feat) + 1) / 2   # pseudo -> norm, else -> raw
        rendered_clusters = []
        rendered_cluster_silhouettes = []
        scale_filter = (scales < 0.5).all(dim=1)    # keep Gaussians whose scales are all < 0.5 (used with better_vis)
        for idx in range(num_cluster):
            if not better_vis and idx != selected_root_id:
                continue

            # ignore the invisible coarse-level cluster
            if viewpoint_camera.bClusterOccur is not None and viewpoint_camera.bClusterOccur[idx] == False:
                continue
            
            # NOTE: Render only the idx-th coarse cluster (visible points only)
            filter_idx = (cluster_idx == idx) & viewed_pts
            # better_vis: drop large Gaussians and skip clusters with fewer than 100 points
            if better_vis:
                filter_idx = filter_idx & scale_filter
                if filter_idx.sum() < 100:
                    continue

            # render cluster-level feat map
            rendered_cluster, _, _, cluster_silhouette = rasterizer(
                means3D = means3D[filter_idx],
                means2D = means2D[filter_idx],
                shs = None,  # feat
                colors_precomp = ins_feat[:, :3][filter_idx],  # feat
                # shs = shs[filter_idx],  # rgb
                # colors_precomp = None,  # rgb
                opacities = opacity[filter_idx],
                scales = scales[filter_idx] * rescale_factor,
                rotations = rotations[filter_idx],
                cov3D_precomp = cov3D_precomp)
            if ins_feat.shape[-1] > 3:
                rendered_cluster2, _, _, cluster_silhouette = rasterizer(
                    means3D = means3D[filter_idx],
                    means2D = means2D[filter_idx],
                    shs = None,           # feat
                    colors_precomp = ins_feat[:, 3:][filter_idx],  # feat
                    # shs = shs[filter_idx],  # rgb
                    # colors_precomp = None,  # rgb
                    opacities = opacity[filter_idx],
                    scales = scales[filter_idx] * rescale_factor,
                    rotations = rotations[filter_idx],
                    cov3D_precomp = cov3D_precomp)
                rendered_cluster = torch.cat((rendered_cluster, rendered_cluster2), dim=0)

            # keep the cluster only if it is clearly visible in this view (peak alpha > 0.8)
            if cluster_silhouette.max() > 0.8:
                cluster_occur[idx] = True
                rendered_clusters.append(rendered_cluster)
                rendered_cluster_silhouettes.append(cluster_silhouette)
        if len(rendered_cluster_silhouettes) != 0:
            rendered_cluster_silhouettes = torch.vstack(rendered_cluster_silhouettes)
    else:
        rendered_clusters, rendered_cluster_silhouettes = None, None


    # ###############################################################
    # [Stage 2.2 & Stage 3] render (fine) cluster-level feat map    #
    #   - rendered_leaf_clusters: feat map of the fine clusters     #
    #   - rendered_leaf_cluster_silhouettes: fine cluster mask      #
    #   - occured_leaf_id: visible fine cluster id                  #
    # ###############################################################
    if leaf_cluster_idx is not None and leaf_cluster_idx.numel() > 0:
        ins_feat = (pc.get_ins_feat(origin=origin_feat) + 1) / 2   # pseudo -> norm, else -> raw
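        # TODO: rescale
        scale_filter = (scales < 0.1).all(dim=1)    # keep Gaussians whose scales are all < 0.1 (used with better_vis)
        # scale_filter = (scales < 0.1).all(dim=1) & (opacity > 0.1).squeeze(-1)
        re_scale_factor = torch.ones_like(opacity)  # all ones: multiplying scales by it is currently a no-op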

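        # Determine the fine cluster ID range (lerf_range): use the explicit fine IDs (selected_leaf_id) if given;
        # otherwise take the leaf_num fine clusters under the coarse cluster selected_root_id, or all root_num * leaf_num.
        # root_num = 64   # todo modify
        # leaf_num = 5    # todo modify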
        rendered_leaf_clusters = []
        rendered_leaf_cluster_silhouettes = []
        occured_leaf_id = []
        if selected_leaf_id is None:
            if selected_root_id is not None:
                start_leaf = selected_root_id * leaf_num   # fine IDs of coarse cluster r: [r*leaf_num, (r+1)*leaf_num)
                end_leaf = start_leaf + leaf_num
            else:
                start_leaf = 0
                end_leaf = root_num * leaf_num  # all fine clusters (e.g. 64 * 10)
            lerf_range = range(start_leaf, end_leaf)
        else:
            lerf_range = selected_leaf_id.tolist()
        for _, leaf_idx in enumerate(lerf_range):   # render each fine cluster
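            # skip if the selected coarse cluster is marked invisible in this view
            # (indexing bClusterOccur with selected_root_id=None would fail, so guard it)
            if (viewpoint_camera.bClusterOccur is not None and selected_root_id is not None
                    and viewpoint_camera.bClusterOccur[selected_root_id] == False):
                continue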

            if selected_leaf_id is None:
                filter_idx = leaf_cluster_idx == leaf_idx     # Render only the idx-th fine cluster
                # filter_idx = labels != value      # remove the idx-th fine cluster
            else:
                filter_idx = (leaf_cluster_idx.unsqueeze(1) == selected_leaf_id).any(dim=1)

            # pre-mask
            if pre_mask is not None:
                filter_idx = filter_idx & pre_mask

            filter_idx = filter_idx & viewed_pts
            # filter
            if better_vis:
                filter_idx = filter_idx & scale_filter
                if filter_idx.sum() < 100:
                    continue
            
            # TODO: post-process (for 3D object selection): statistical outlier removal.
            # Each point's mean squared distance to its K = sqrt(N) nearest neighbours (itself included)
            # is compared with mean + std over all neighbour distances; points above it are dropped.
            # pre_count = filter_idx.sum()
            max_time = 5    # note: this is not a loop; the branch below runs at most once
            if post_process and max_time > 0 and filter_idx.sum() > 1:   # kNN needs at least 2 points
                nearest_k_distance = pytorch3d.ops.knn_points(
                    means3D[filter_idx].unsqueeze(0),
                    means3D[filter_idx].unsqueeze(0),
                    K=int(filter_idx.sum()**0.5),
                ).dists     # squared distances, [1, N, K]
                mean_nearest_k_distance, std_nearest_k_distance = nearest_k_distance.mean(), nearest_k_distance.std()
                # print(std_nearest_k_distance, "std_nearest_k_distance")

                mask = nearest_k_distance.mean(dim = -1) < mean_nearest_k_distance + std_nearest_k_distance
                # mask = nearest_k_distance.mean(dim = -1) < mean_nearest_k_distance + 0.1 * std_nearest_k_distance

                mask = mask.squeeze()
                if filter_idx is not None:
                    filter_idx[filter_idx != 0] = mask
                max_time -= 1
            
            if filter_idx.sum() < 10:
                continue

            # record the fine cluster id appears in the current view.
            occured_leaf_id.append(leaf_idx)

            # note: render cluster rgb or feat.
            if seg_rgb:
                _shs = shs[filter_idx]
                _colors_precomp1 = None
                _colors_precomp2 = None
            else:
                _shs = None
                _colors_precomp1 = ins_feat[:, :3][filter_idx]
                _colors_precomp2 = ins_feat[:, 3:][filter_idx]
            
            rendered_leaf_cluster, _, _, leaf_cluster_silhouette = rasterizer(
                means3D = means3D[filter_idx],
                means2D = means2D[filter_idx],
                shs = _shs,                          # rgb or feat
                colors_precomp = _colors_precomp1,   # rgb or feat
                opacities = opacity[filter_idx],
                scales = (scales * re_scale_factor)[filter_idx],
                rotations = rotations[filter_idx],
                cov3D_precomp = cov3D_precomp)
            if ins_feat.shape[-1] > 3:
                rendered_leaf_cluster2, _, _, _ = rasterizer(
                    means3D = means3D[filter_idx],
                    means2D = means2D[filter_idx],
                    shs = _shs,                          # rgb or feat
                    colors_precomp = _colors_precomp2,   # rgb or feat
                    opacities = opacity[filter_idx],
                    scales = (scales * re_scale_factor)[filter_idx],
                    rotations = rotations[filter_idx],
                    cov3D_precomp = cov3D_precomp)
                rendered_leaf_cluster = torch.cat((rendered_leaf_cluster, rendered_leaf_cluster2), dim=0)
            rendered_leaf_clusters.append(rendered_leaf_cluster)
            rendered_leaf_cluster_silhouettes.append(leaf_cluster_silhouette)
            if selected_leaf_id is not None and len(rendered_leaf_clusters) > 0:
                break
        if len(rendered_leaf_cluster_silhouettes) != 0:
            rendered_leaf_cluster_silhouettes = torch.vstack(rendered_leaf_cluster_silhouettes)
    else:
        rendered_leaf_clusters = None
        rendered_leaf_cluster_silhouettes = None
        occured_leaf_id = None

    # Those Gaussians that were frustum culled or had a radius of 0 were not visible.
    # They will be excluded from value updates used in the splitting criteria.
    return {"render": rendered_image,
            "alpha": rendered_alpha,
            "depth": rendered_depth,    # not used
            "silhouette": silhouette,
            "ins_feat": rendered_ins_feat,          # image-level feat map
            "cluster_imgs": rendered_clusters,      # coarse cluster feat map/image
            "cluster_silhouettes": rendered_cluster_silhouettes,    # coarse cluster mask
            "leaf_clusters_imgs": rendered_leaf_clusters,           # fine cluster feat map/image
            "leaf_cluster_silhouettes": rendered_leaf_cluster_silhouettes,  # fine cluster mask
            "occured_leaf_id": occured_leaf_id,     # fine cluster
            "cluster_occur": cluster_occur,         # coarse cluster
            "viewspace_points": screenspace_points,
            "visibility_filter" : radii > 0,
            "radii": radii}
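
# Sketch (illustrative, pure Python, O(N^2)): the `post_process` branch of render() is a
# statistical outlier filter. knn_points compares the selected points with themselves, so
# each point's K nearest neighbours include itself (squared distance 0). Each point's mean
# squared kNN distance is compared with mean + std of all kNN distances, and points above
# that threshold are dropped. K = int(sqrt(N)), as in render(). `knn_outlier_mask` is a
# hypothetical name. torch.std's default unbiased estimate is mirrored with statistics.stdev,
# which needs at least 2 distances.
import math
import statistics


def knn_outlier_mask(points):
    """Return a keep-mask (list of bools) for a list of (x, y, z) tuples."""
    k = int(math.sqrt(len(points)))
    per_point = []      # each point's K smallest squared distances (self included)
    for p in points:
        d2 = sorted(sum((a - b) ** 2 for a, b in zip(p, q)) for q in points)
        per_point.append(d2[:k])
    flat = [d for row in per_point for d in row]
    threshold = statistics.mean(flat) + statistics.stdev(flat)
    return [statistics.mean(row) < threshold for row in per_point]
# Example: a tight 3x3 grid plus one far-away point keeps the grid and drops the outlier.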

================================================
FILE: gaussian_renderer/network_gui.py
================================================
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use 
# under the terms of the LICENSE.md file.
#
# For inquiries contact  george.drettakis@inria.fr
#

import torch
import traceback
import socket
import json
from scene.cameras import MiniCam

host = "127.0.0.1"
port = 6009

conn = None
addr = None

listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

def init(wish_host, wish_port):
    global host, port, listener
    host = wish_host
    port = wish_port
    listener.bind((host, port))
    listener.listen()
    listener.settimeout(0)

def try_connect():
    global conn, addr, listener
    try:
        conn, addr = listener.accept()
        print(f"\nConnected by {addr}")
        conn.settimeout(None)
    except Exception:
        pass    # no pending connection yet (the listener is non-blocking)

def read():
    global conn
    messageLength = conn.recv(4)
    messageLength = int.from_bytes(messageLength, 'little')
    message = conn.recv(messageLength)
    return json.loads(message.decode("utf-8"))

def send(message_bytes, verify):
    global conn
    if message_bytes is not None:
        conn.sendall(message_bytes)
    conn.sendall(len(verify).to_bytes(4, 'little'))
    conn.sendall(bytes(verify, 'ascii'))
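
# Sketch (not part of the original module): read() above assumes that a single
# conn.recv(n) returns all n bytes, but TCP may deliver a message in several pieces.
# A hedged, more robust variant of the same framing (4-byte little-endian length,
# then a UTF-8 JSON payload) loops until the requested number of bytes has arrived.
# `recv_exact` and `read_framed_json` are hypothetical helper names.
import json


def recv_exact(sock, n):
    """Receive exactly n bytes from sock, or raise ConnectionError on EOF."""
    buf = bytearray()
    while len(buf) < n:
        chunk = sock.recv(n - len(buf))
        if not chunk:
            raise ConnectionError("socket closed before the message was complete")
        buf.extend(chunk)
    return bytes(buf)


def read_framed_json(sock):
    """Read one length-prefixed JSON message, using the same framing as read()."""
    length = int.from_bytes(recv_exact(sock, 4), 'little')
    return json.loads(recv_exact(sock, length).decode("utf-8"))
# Possible usage: message = read_framed_json(conn)  # instead of read()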

def receive():
    message = read()

    width = message["resolution_x"]
    height = message["resolution_y"]

    if width != 0 and height != 0:
        try:
            do_training = bool(message["train"])
            fovy = message["fov_y"]
            fovx = message["fov_x"]
            znear = message["z_near"]
            zfar = message["z_far"]
            do_shs_python = bool(message["shs_python"])
            do_rot_scale_python = bool(message["rot_scale_python"])
            keep_alive = bool(message["keep_alive"])
            scaling_modifier = message["scaling_modifier"]
            world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda()
            world_view_transform[:,1] = -world_view_transform[:,1]
            world_view_transform[:,2] = -world_view_transform[:,2]
            full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda()
            full_proj_transform[:,1] = -full_proj_transform[:,1]
            custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform)
        except Exception as e:
            print("")
            traceback.print_exc()
            raise e
        return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier
    else:
        return None, None, None, None, None, None
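
# Sketch (illustrative): receive() above negates columns 1 and 2 of the incoming view
# matrix, and column 1 of the view-projection matrix. In 3DGS these matrices are stored
# transposed (row-vector convention), so negating a column flips the corresponding camera
# axis. The usual reading, stated here as an assumption, is that this converts the
# viewer's OpenGL-style camera (y up, looking down -z) into the COLMAP/OpenCV-style
# camera (y down, looking down +z) that the renderer uses. `negate_columns` is a
# hypothetical pure-Python helper.
def negate_columns(matrix, cols):
    """Return a copy of a 4x4 list-of-lists with the given columns negated."""
    out = [row[:] for row in matrix]
    for row in out:
        for c in cols:
            row[c] = -row[c]
    return out
# negate_columns(view, (1, 2)) mirrors the two world_view_transform lines in receive().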

================================================
FILE: lpipsPyTorch/__init__.py
================================================
import torch

from .modules.lpips import LPIPS


def lpips(x: torch.Tensor,
          y: torch.Tensor,
          net_type: str = 'alex',
          version: str = '0.1'):
    r"""Function that measures
    Learned Perceptual Image Patch Similarity (LPIPS).

    Arguments:
        x, y (torch.Tensor): the input tensors to compare.
        net_type (str): the network type to compare the features: 
                        'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
        version (str): the version of LPIPS. Default: 0.1.
    """
    device = x.device
    criterion = LPIPS(net_type, version).to(device)
    return criterion(x, y)
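
# Sketch (not part of the original package): lpips() above constructs a new LPIPS module
# on every call. That rebuilds the backbone and reloads the linear-layer weights each time,
# which is slow when metrics.py evaluates many views. One hedged option is to memoize the
# constructed criterion per (net_type, version, device). `make_cached_factory` is a
# hypothetical helper, shown with a generic factory so that it does not depend on torch.
import functools


def make_cached_factory(factory):
    """Wrap factory(*key) so that each distinct key is built only once."""
    @functools.lru_cache(maxsize=None)
    def get(*key):
        return factory(*key)
    return get

# Possible usage (the key values must be hashable, e.g. str(x.device)):
#   _get_lpips = make_cached_factory(lambda n, v, d: LPIPS(n, v).to(d))
#   criterion = _get_lpips(net_type, version, str(x.device))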


================================================
FILE: lpipsPyTorch/modules/lpips.py
================================================
import torch
import torch.nn as nn

from .networks import get_network, LinLayers
from .utils import get_state_dict


class LPIPS(nn.Module):
    r"""Creates a criterion that measures
    Learned Perceptual Image Patch Similarity (LPIPS).

    Arguments:
        net_type (str): the network type to compare the features: 
                        'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
        version (str): the version of LPIPS. Default: 0.1.
    """
    def __init__(self, net_type: str = 'alex', version: str = '0.1'):

        assert version in ['0.1'], 'only LPIPS v0.1 is supported'

        super(LPIPS, self).__init__()

        # pretrained network
        self.net = get_network(net_type)

        # linear layers
        self.lin = LinLayers(self.net.n_channels_list)
        self.lin.load_state_dict(get_state_dict(net_type, version))

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        feat_x, feat_y = self.net(x), self.net(y)

        diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
        res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]

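        # stack per-layer distances to [n_layers, B, 1, 1, 1] and sum over layers -> [B, 1, 1, 1]
        # (concatenating along dim 0 and summing would also sum over the batch when B > 1)
        return torch.sum(torch.stack(res, 0), 0)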
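
# Sketch (illustrative, pure Python): LPIPS.forward() above computes the following for
# each tapped layer l: the channel-wise unit-normalized features of x and y (done inside
# BaseNet via normalize_activation), their squared difference, a learned per-channel
# weighting (the 1x1 conv in LinLayers), and a spatial mean. It then sums over layers:
#   d(x, y) = sum_l mean_{h,w} sum_c w_{l,c} * (f^l_{x,c,h,w} - f^l_{y,c,h,w})^2
# `lpips_from_features` is a hypothetical helper. Features are nested lists indexed
# [layer][channel][position], with spatial positions already flattened.
def lpips_from_features(feats_x, feats_y, weights, eps=1e-10):
    total = 0.0
    for fx, fy, w in zip(feats_x, feats_y, weights):
        n_ch, n_pos = len(fx), len(fx[0])
        layer_sum = 0.0
        for p in range(n_pos):
            # unit-normalize across channels at this position, as normalize_activation does
            nx = sum(fx[c][p] ** 2 for c in range(n_ch)) ** 0.5 + eps
            ny = sum(fy[c][p] ** 2 for c in range(n_ch)) ** 0.5 + eps
            layer_sum += sum(w[c] * (fx[c][p] / nx - fy[c][p] / ny) ** 2 for c in range(n_ch))
        total += layer_sum / n_pos      # spatial mean, then sum over layers
    return total
# Identical features give 0, and the result is invariant to per-position feature scale.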


================================================
FILE: lpipsPyTorch/modules/networks.py
================================================
from typing import Sequence

from itertools import chain

import torch
import torch.nn as nn
from torchvision import models

from .utils import normalize_activation


def get_network(net_type: str):
    if net_type == 'alex':
        return AlexNet()
    elif net_type == 'squeeze':
        return SqueezeNet()
    elif net_type == 'vgg':
        return VGG16()
    else:
        raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')


class LinLayers(nn.ModuleList):
    def __init__(self, n_channels_list: Sequence[int]):
        super(LinLayers, self).__init__([
            nn.Sequential(
                nn.Identity(),
                nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
            ) for nc in n_channels_list
        ])

        for param in self.parameters():
            param.requires_grad = False


class BaseNet(nn.Module):
    def __init__(self):
        super(BaseNet, self).__init__()

        # register buffer
        self.register_buffer(
            'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
        self.register_buffer(
            'std', torch.Tensor([.458, .448, .450])[None, :, None, None])

    def set_requires_grad(self, state: bool):
        for param in chain(self.parameters(), self.buffers()):
            param.requires_grad = state

    def z_score(self, x: torch.Tensor):
        return (x - self.mean) / self.std

    def forward(self, x: torch.Tensor):
        x = self.z_score(x)

        output = []
        for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
            x = layer(x)
            if i in self.target_layers:
                output.append(normalize_activation(x))
            if len(output) == len(self.target_layers):
                break
        return output
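
# Sketch (illustrative, pure Python): BaseNet.forward() walks the backbone's layers with a
# 1-based counter and records the (normalized) activation after each layer whose index is
# in target_layers. It stops once all taps are collected, so later layers are never
# evaluated. `collect_taps` is a hypothetical helper that uses plain callables instead of
# nn.Modules and skips the normalization step.
def collect_taps(x, layers, target_layers):
    outputs = []
    for i, layer in enumerate(layers, 1):
        x = layer(x)
        if i in target_layers:
            outputs.append(x)
        if len(outputs) == len(target_layers):
            break       # remaining layers are skipped
    return outputs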


class SqueezeNet(BaseNet):
    def __init__(self):
        super(SqueezeNet, self).__init__()

        self.layers = models.squeezenet1_1(weights=models.SqueezeNet1_1_Weights.IMAGENET1K_V1).features
        self.target_layers = [2, 5, 8, 10, 11, 12, 13]
        self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]

        self.set_requires_grad(False)


class AlexNet(BaseNet):
    def __init__(self):
        super(AlexNet, self).__init__()

        self.layers = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1).features
        self.target_layers = [2, 5, 8, 10, 12]
        self.n_channels_list = [64, 192, 384, 256, 256]

        self.set_requires_grad(False)


class VGG16(BaseNet):
    def __init__(self):
        super(VGG16, self).__init__()

        self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features
        self.target_layers = [4, 9, 16, 23, 30]
        self.n_channels_list = [64, 128, 256, 512, 512]

        self.set_requires_grad(False)


================================================
FILE: lpipsPyTorch/modules/utils.py
================================================
from collections import OrderedDict

import torch


def normalize_activation(x, eps=1e-10):
    norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
    return x / (norm_factor + eps)


def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
    # build url
    url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
        + f'master/lpips/weights/v{version}/{net_type}.pth'

    # download
    old_state_dict = torch.hub.load_state_dict_from_url(
        url, progress=True,
        map_location=None if torch.cuda.is_available() else torch.device('cpu')
    )

    # rename keys
    new_state_dict = OrderedDict()
    for key, val in old_state_dict.items():
        new_key = key
        new_key = new_key.replace('lin', '')
        new_key = new_key.replace('model.', '')
        new_state_dict[new_key] = val

    return new_state_dict
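
# Sketch (illustrative): the upstream PerceptualSimilarity checkpoints name the linear
# layers with keys of the form 'lin0.model.1.weight'. LinLayers, an nn.ModuleList of
# nn.Sequential(Identity, Conv2d), expects '0.1.weight' instead. The two replace() calls
# in get_state_dict perform that mapping. `rename_lpips_key` is a hypothetical standalone
# version for checking keys.
def rename_lpips_key(key):
    return key.replace('lin', '').replace('model.', '')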


================================================
FILE: metrics.py
================================================
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use 
# under the terms of the LICENSE.md file.
#
# For inquiries contact  george.drettakis@inria.fr
#

from pathlib import Path
import os
from PIL import Image
import torch
import torchvision.transforms.functional as tf
from utils.loss_utils import ssim
from lpipsPyTorch import lpips
import json
from tqdm import tqdm
from utils.image_utils import psnr
from argparse import ArgumentParser

def readImages(renders_dir, gt_dir):
    renders = []
    gts = []
    image_names = []
    for fname in os.listdir(renders_dir):
        render = Image.open(renders_dir / fname)
        gt = Image.open(gt_dir / fname)
        renders.append(tf.to_tensor(render).unsqueeze(0)[:, :3, :, :].cuda())
        gts.append(tf.to_tensor(gt).unsqueeze(0)[:, :3, :, :].cuda())
        image_names.append(fname)
    return renders, gts, image_names
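
# Sketch (not part of the original script): readImages() opens gt_dir / fname for every
# file in renders_dir, so a missing ground-truth file raises an exception. evaluate()
# catches it and skips the whole scene. A hedged pre-check, using only pathlib, lists the
# names present in both folders in a stable (sorted) order and reports any mismatches.
# `pair_image_names` is a hypothetical helper.
from pathlib import Path


def pair_image_names(renders_dir, gt_dir):
    """Return (sorted common names, names only in renders, names only in gt)."""
    renders = {p.name for p in Path(renders_dir).iterdir() if p.is_file()}
    gts = {p.name for p in Path(gt_dir).iterdir() if p.is_file()}
    return sorted(renders & gts), sorted(renders - gts), sorted(gts - renders)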

def evaluate(model_paths):

    full_dict = {}
    per_view_dict = {}
    full_dict_polytopeonly = {}
    per_view_dict_polytopeonly = {}
    print("")

    for scene_dir in model_paths:
        try:
            print("Scene:", scene_dir)
            full_dict[scene_dir] = {}
            per_view_dict[scene_dir] = {}
            full_dict_polytopeonly[scene_dir] = {}
            per_view_dict_polytopeonly[scene_dir] = {}

            test_dir = Path(scene_dir) / "test"

            for method in os.listdir(test_dir):
                print("Method:", method)

                full_dict[scene_dir][method] = {}
                per_view_dict[scene_dir][method] = {}
                full_dict_polytopeonly[scene_dir][method] = {}
                per_view_dict_polytopeonly[scene_dir][method] = {}

                method_dir = test_dir / method
                gt_dir = method_dir / "gt"
                renders_dir = method_dir / "renders"
                renders, gts, image_names = readImages(renders_dir, gt_dir)

                ssims = []
                psnrs = []
                lpipss = []

                for idx in tqdm(range(len(renders)), desc="Metric evaluation progress"):
                    ssims.append(ssim(renders[idx], gts[idx]))
                    psnrs.append(psnr(renders[idx], gts[idx]))
                    lpipss.append(lpips(renders[idx], gts[idx], net_type='vgg'))

                print("  SSIM : {:>12.7f}".format(torch.tensor(ssims).mean()))
                print("  PSNR : {:>12.7f}".format(torch.tensor(psnrs).mean()))
                print("  LPIPS: {:>12.7f}".format(torch.tensor(lpipss).mean()))
                print("")

                full_dict[scene_dir][method].update({"SSIM": torch.tensor(ssims).mean().item(),
                                                        "PSNR": torch.tensor(psnrs).mean().item(),
                                                        "LPIPS": torch.tensor(lpipss).mean().item()})
                per_view_dict[scene_dir][method].update({"SSIM": {name: ssim for ssim, name in zip(torch.tensor(ssims).tolist(), image_names)},
                                                            "PSNR": {name: psnr for psnr, name in zip(torch.tensor(psnrs).tolist(), image_names)},
                                                            "LPIPS": {name: lp for lp, name in zip(torch.tensor(lpipss).tolist(), image_names)}})

            with open(scene_dir + "/results.json", 'w') as fp:
                json.dump(full_dict[scene_dir], fp, indent=True)
            with open(scene_dir + "/per_view.json", 'w') as fp:
                json.dump(per_view_dict[scene_dir], fp, indent=True)
        except Exception as e:
            print("Unable to compute metrics for model", scene_dir, "-", e)
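
# Sketch (illustrative): for images with values in [0, 1], PSNR is
#   PSNR = 10 * log10(1 / MSE) = 20 * log10(1 / sqrt(MSE)),
# so higher is better. The implementation actually used above is utils.image_utils.psnr,
# which is not shown here. `psnr_from_pixels` is a hypothetical pure-Python version over
# flat pixel lists, and it assumes MSE > 0.
import math


def psnr_from_pixels(pred, target):
    mse = sum((a - b) ** 2 for a, b in zip(pred, target)) / len(pred)
    return 10.0 * math.log10(1.0 / mse)
# Example: a uniform error of 0.1 gives MSE = 0.01, i.e. about 20 dB.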

if __name__ == "__main__":
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)

    # Set up command line argument parser
    parser = ArgumentParser(description="Metric evaluation script parameters")
    parser.add_argument('--model_paths', '-m', required=True, nargs="+", type=str, default=[])
    args = parser.parse_args()
    evaluate(args.model_paths)


================================================
FILE: render.py
================================================
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use 
# under the terms of the LICENSE.md file.
#
# For inquiries contact  george.drettakis@inria.fr
#

import torch
import torch.nn.functional as F
from scene import Scene
import os
from tqdm import tqdm
from os import makedirs
from gaussian_renderer import render
import torchvision
from utils.general_utils import safe_state
from argparse import ArgumentParser
from arguments import ModelParams, PipelineParams, get_combined_args
from gaussian_renderer import GaussianModel
import numpy as np
from utils.opengs_utlis import get_SAM_mask_and_feat, load_code_book

# Randomly initialize 300 colors for visualizing the SAM mask. [OpenGaussian]
np.random.seed(42)
colors_defined = np.random.randint(100, 256, size=(300, 3))
colors_defined[0] = np.array([0, 0, 0]) # Ignore the mask ID of -1 and set it to black.
colors_defined = torch.from_numpy(colors_defined)

def render_set(model_path, name, iteration, views, gaussians, pipeline, background):
    render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders")
    gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt")

    render_ins_feat_path1 = os.path.join(model_path, name, "ours_{}".format(iteration), "renders_ins_feat1")
    render_ins_feat_path2 = os.path.join(model_path, name, "ours_{}".format(iteration), "renders_ins_feat2")
    gt_sam_mask_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt_sam_mask")

    makedirs(render_path, exist_ok=True)
    makedirs(gts_path, exist_ok=True)
    makedirs(render_ins_feat_path1, exist_ok=True)
    makedirs(render_ins_feat_path2, exist_ok=True)
    makedirs(gt_sam_mask_path, exist_ok=True)

    # load codebook
    root_code_book_path = os.path.join(model_path, "point_cloud", f'iteration_{iteration}', "root_code_book")
    leaf_code_book_path = os.path.join(model_path, "point_cloud", f'iteration_{iteration}', "leaf_code_book")
    if os.path.exists(os.path.join(root_code_book_path, 'kmeans_inds.bin')):
        root_code_book, root_cluster_indices = load_code_book(root_code_book_path)
        root_cluster_indices = torch.from_numpy(root_cluster_indices).cuda()
    if os.path.exists(os.path.join(leaf_code_book_path, 'kmeans_inds.bin')):
        leaf_code_book, leaf_cluster_indices = load_code_book(leaf_code_book_path)
        leaf_cluster_indices = torch.from_numpy(leaf_cluster_indices).cuda()
    else:
        leaf_cluster_indices = None

    # render
    for idx, view in enumerate(tqdm(views, desc="Rendering progress")):
        render_pkg = render(view, gaussians, pipeline, background, iteration, rescale=False)

        # RGB
        rendering = render_pkg["render"]
        gt = view.original_image[0:3, :, :]

        # ins_feat
        rendered_ins_feat = render_pkg["ins_feat"]
        gt_sam_mask = view.original_sam_mask.cuda()    # [4, H, W]

        # Rendered RGB
        torchvision.utils.save_image(rendering, os.path.join(render_path, view.image_name + ".png"))
        # GT RGB
        torchvision.utils.save_image(gt, os.path.join(gts_path, view.image_name + ".png"))

        # ins_feat
        torchvision.utils.save_image(rendered_ins_feat[:3,:,:], os.path.join(render_ins_feat_path1, view.image_name + "_1.png"))
        torchvision.utils.save_image(rendered_ins_feat[3:6,:,:], os.path.join(render_ins_feat_path2, view.image_name + "_2.png"))

        # NOTE get SAM id, mask bool, mask_feat, invalid pix
        mask_id, _, _, _ = \
            get_SAM_mask_and_feat(gt_sam_mask, level=0, original_mask_feat=view.original_mask_feat)
        # mask visualization
        mask_color_rand = colors_defined[mask_id.detach().cpu().type(torch.int64)].type(torch.float64)
        mask_color_rand = mask_color_rand.permute(2, 0, 1)
        torchvision.utils.save_image(mask_color_rand/255.0, os.path.join(gt_sam_mask_path, view.image_name + ".png"))
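
# Sketch (assumption, not part of the original script): the SAM mask visualization
# in render_set is a palette lookup. An [H, W] integer mask-ID map indexes a [K, 3]
# color table (like colors_defined), giving [H, W, 3]. Permuting to channel-first
# [3, H, W] and dividing by 255 yields the float image that save_image expects.
# Written with NumPy so it runs without a GPU; mask_ids_to_rgb is a hypothetical name.
import numpy as np

def mask_ids_to_rgb(mask_id, palette):
    """Map an [H, W] integer ID map to a [3, H, W] float image in [0, 1]."""
    ids = np.asarray(mask_id, dtype=np.int64)
    colors = palette[ids].astype(np.float64)    # [H, W, 3] palette lookup
    return colors.transpose(2, 0, 1) / 255.0    # channel-first, scaled to [0, 1]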

def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool):
    with torch.no_grad():
        gaussians = GaussianModel(dataset.sh_degree)
        scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)

        bg_color = [1,1,1] if dataset.white_background else [0, 0, 0]
        background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")

        if not skip_train:
             render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background)

        if not skip_test:
             render_set(dataset.model_path, "test", scene.loaded_iter, scene.getTestCameras(), gaussians, pipeline, background)

if __name__ == "__main__":
    # Set up command line argument parser
    parser = ArgumentParser(description="Testing script parameters")
    model = ModelParams(parser, sentinel=True)
    pipeline = PipelineParams(parser)
    parser.add_argument("--iteration", default=-1, type=int)
    parser.add_argument("--skip_train", action="store_true")
    parser.add_argument("--skip_test", action="store_true")
    parser.add_argument("--quiet", action="store_true")
    args = get_combined_args(parser)
    print("Rendering " + args.model_path)

    # Initialize system state (RNG)
    safe_state(args.quiet)

    render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test)

================================================
FILE: render_lerf_by_text.py
================================================
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use 
# under the terms of the LICENSE.md file.
#
# For inquiries contact  george.drettakis@inria.fr
#

import torch
import torch.nn.functional as F
from scene import Scene
import os
from tqdm import tqdm
from os import makedirs
from gaussian_renderer import render
import torchvision
from utils.general_utils import safe_state
from argparse import ArgumentParser
from arguments import ModelParams, PipelineParams, get_combined_args
from gaussian_renderer import GaussianModel
import numpy as np
import json
from utils.opengs_utlis import mask_feature_mean, get_SAM_mask_and_feat, load_code_book

np.random.seed(42)
colors_defined = np.random.randint(100, 256, size=(300, 3))
colors_defined[0] = np.array([0, 0, 0]) # Ignore the mask ID of -1 and set it to black.
colors_defined = torch.from_numpy(colors_defined)

def render_set(model_path, name, iteration, views, gaussians, pipeline, background, scene_name):
    render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders")
    gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt")

    render_ins_feat_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders_ins_feat")
    gt_sam_mask_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt_sam_mask")

    makedirs(render_path, exist_ok=True)
    makedirs(gts_path, exist_ok=True)
    makedirs(render_ins_feat_path, exist_ok=True)
    makedirs(gt_sam_mask_path, exist_ok=True)

    # load codebook
    root_code_book, root_cluster_indices = load_code_book(os.path.join(model_path, "point_cloud", \
        f'iteration_{iteration}', "root_code_book"))
    leaf_code_book, leaf_cluster_indices = load_code_book(os.path.join(model_path, "point_cloud", \
        f'iteration_{iteration}', "leaf_code_book"))
    root_cluster_indices = torch.from_numpy(root_cluster_indices).cuda()
    leaf_cluster_indices = torch.from_numpy(leaf_cluster_indices).cuda()
    # counts = torch.bincount(torch.from_numpy(cluster_indices), minlength=64)

    # load the saved codebook(leaf id) and instance-level language feature
    # keys: 'leaf_feat', 'leaf_score', 'occu_count', 'leaf_ind'
    mapping_file = os.path.join(model_path, "cluster_lang.npz")
    saved_data = np.load(mapping_file)
    leaf_lang_feat = torch.from_numpy(saved_data["leaf_feat.npy"]).cuda()    # [num_leaf=k1*k2, 512] cluster lang feat
    leaf_score = torch.from_numpy(saved_data["leaf_score.npy"]).cuda()       # [num_leaf=k1*k2] cluster score
    leaf_occu_count = torch.from_numpy(saved_data["occu_count.npy"]).cuda()  # [num_leaf=k1*k2] 
    leaf_ind = torch.from_numpy(saved_data["leaf_ind.npy"]).cuda()           # [num_pts] fine id
    leaf_lang_feat[leaf_occu_count < 5] *= 0.0      # Filter out clusters that occur too infrequently.
    leaf_cluster_indices = leaf_ind
    
    root_num = root_cluster_indices.max() + 1
    leaf_num = leaf_lang_feat.shape[0] / root_num

    # text feature
    with open('assets/text_features.json', 'r') as f:
        data_loaded = json.load(f)
    all_texts = list(data_loaded.keys())
    text_features = torch.from_numpy(np.array(list(data_loaded.values()))).to(torch.float32)  # [num_text, 512]

    scene_texts = {
        "waldo_kitchen": ['Stainless steel pots', 'dark cup', 'refrigerator', 'frog cup', 'pot', 'spatula', 'plate', \
                'spoon', 'toaster', 'ottolenghi', 'plastic ladle', 'sink', 'ketchup', 'cabinet', 'red cup', \
                'pour-over vessel', 'knife', 'yellow desk'],
        "ramen": ['nori', 'sake cup', 'kamaboko', 'corn', 'spoon', 'egg', 'onion segments', 'plate', \
                'napkin', 'bowl', 'glass of water', 'hand', 'chopsticks', 'wavy noodles'],
        "figurines": ['jake', 'pirate hat', 'pikachu', 'rubber duck with hat', 'porcelain hand', \
                    'red apple', 'tesla door handle', 'waldo', 'bag', 'toy cat statue', 'miffy', \
                    'green apple', 'pumpkin', 'rubics cube', 'old camera', 'rubber duck with buoy', \
                    'red toy chair', 'pink ice cream', 'spatula', 'green toy chair', 'toy elephant'],
        "teatime": ['sheep', 'yellow pouf', 'stuffed bear', 'coffee mug', 'tea in a glass', 'apple', 
                'coffee', 'hooves', 'bear nose', 'dall-e brand', 'plate', 'paper napkin', 'three cookies', \
                'bag of cookies']
    }
    # note: query text
    target_text = scene_texts[scene_name]

    query_text_feats = torch.zeros(len(target_text), 512).cuda()
    for i, text in enumerate(target_text):
        feat = text_features[all_texts.index(text)].unsqueeze(0)
        query_text_feats[i] = feat

    for t_i, text_feat in enumerate(query_text_feats):
        # if target_text[t_i] != "old camera":
        #     continue

        print(f"rendering the {t_i+1}-th query of {len(target_text)} texts: {target_text[t_i]}")
        # compute cosine similarity
        text_feat = F.normalize(text_feat.unsqueeze(0), dim=1, p=2)  
        leaf_lang_feat = F.normalize(leaf_lang_feat, dim=1, p=2)  
        cosine_similarity = torch.matmul(text_feat, leaf_lang_feat.transpose(0, 1))
        max_id = torch.argmax(cosine_similarity, dim=-1) # [1], id of the best-matching leaf cluster
        text_leaf_indices = max_id

        top_values, top_indices = torch.topk(cosine_similarity, 10)
        for candidate_id in top_indices[0][1:]:
            if candidate_id - max_id < leaf_num:  # TODO: revisit; this only rejects ids >= max_id + leaf_num (a negative difference always passes)
                max_feat = leaf_code_book['ins_feat'][max_id]
                candi_feat = leaf_code_book['ins_feat'][candidate_id]
                distances = torch.norm(max_feat - candi_feat, dim=1)
                if distances < 0.9:
                    text_leaf_indices = torch.cat([text_leaf_indices, candidate_id.unsqueeze(0)])

        # render
        for idx, view in enumerate(tqdm(views, desc="Rendering progress")):
            # note: evaluation frame
            scene_gt_frames = {
                "waldo_kitchen": ["frame_00053", "frame_00066", "frame_00089", "frame_00140", "frame_00154"],
                "ramen": ["frame_00006", "frame_00024", "frame_00060", "frame_00065", "frame_00081", "frame_00119", "frame_00128"],
                "figurines": ["frame_00041", "frame_00105", "frame_00152", "frame_00195"],
                "teatime": ["frame_00002", "frame_00025", "frame_00043", "frame_00107", "frame_00129", "frame_00140"]
            }
            candidate_frames = scene_gt_frames[scene_name]
            
            if view.image_name not in candidate_frames:
                continue

            render_pkg = render(view, gaussians, pipeline, background, iteration, rescale=False)
            # RGB
            rendering = render_pkg["render"]
            gt = view.original_image[0:3, :, :]

            # ins_feat
            rendered_ins_feat = render_pkg["ins_feat"]
            gt_sam_mask = view.original_sam_mask.cuda()    # [4, H, W]

            # RGB
            torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png"))
            torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png"))

            # ins_feat
            torchvision.utils.save_image(rendered_ins_feat[:3,:,:], os.path.join(render_ins_feat_path, '{0:05d}'.format(idx) + "_1.png"))
            torchvision.utils.save_image(rendered_ins_feat[3:6,:,:], os.path.join(render_ins_feat_path, '{0:05d}'.format(idx) + "_2.png"))

            # NOTE get SAM id, mask bool, mask_feat, invalid pix
            mask_id, mask_bool, mask_feat, invalid_pix = \
                get_SAM_mask_and_feat(gt_sam_mask, level=3, original_mask_feat=view.original_mask_feat)
            
            # sam mask
            mask_color_rand = colors_defined[mask_id.detach().cpu().type(torch.int64)].type(torch.float64)
            mask_color_rand = mask_color_rand.permute(2, 0, 1)
            torchvision.utils.save_image(mask_color_rand/255.0, os.path.join(gt_sam_mask_path, '{0:05d}'.format(idx) + ".png"))
            
            # render target object
            render_pkg = render(view, gaussians, pipeline, background, iteration,
                                rescale=False,                # whether to re-scale the Gaussian scales
                                # cluster_idx=leaf_cluster_indices,     # root id
                                leaf_cluster_idx=leaf_cluster_indices,  # leaf id
                                selected_leaf_id=text_leaf_indices.cuda(),  # leaf ids selected by the text query
                                render_feat_map=False, 
                                render_cluster=False,
                                better_vis=True,
                                seg_rgb=True,
                                post_process=True,
                                root_num=root_num, leaf_num=leaf_num)
            rendered_cluster_imgs = render_pkg["leaf_clusters_imgs"]
            occured_leaf_id = render_pkg["occured_leaf_id"]
            rendered_leaf_cluster_silhouettes = render_pkg["leaf_cluster_silhouettes"]

            render_cluster_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders_cluster")
            render_cluster_silhouette_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders_cluster_silhouette")
            makedirs(render_cluster_path, exist_ok=True)
            makedirs(render_cluster_silhouette_path, exist_ok=True)
            for i, img in enumerate(rendered_cluster_imgs):
                # save object RGB
                torchvision.utils.save_image(img[:3,:,:], os.path.join(render_cluster_path, \
                    view.image_name + f"_{target_text[t_i]}.png"))
                # save object mask
                cluster_silhouette = rendered_leaf_cluster_silhouettes[i] > 0.7
                torchvision.utils.save_image(cluster_silhouette.to(torch.float32), os.path.join(render_cluster_silhouette_path, \
                    view.image_name + f"_{target_text[t_i]}.png"))
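
# Sketch (assumption, not the authors' exact code): the text query in render_set
# ranks leaf clusters by cosine similarity between an L2-normalized CLIP text
# feature and the per-leaf language features, keeps the best match, then adds
# top-k runners-up whose instance features lie within a distance threshold of the
# best match. The script's extra `candidate_id - max_id < leaf_num` check is
# deliberately omitted here. NumPy keeps the sketch CPU-only; query_leaf_clusters
# is a hypothetical name, and k / max_dist simply mirror the script's 10 and 0.9.
import numpy as np

def query_leaf_clusters(text_feat, leaf_lang_feat, leaf_ins_feat, k=10, max_dist=0.9):
    t = text_feat / np.linalg.norm(text_feat)
    norms = np.linalg.norm(leaf_lang_feat, axis=1, keepdims=True)
    leaf = leaf_lang_feat / np.maximum(norms, 1e-12)   # zeroed (filtered) rows stay zero
    sim = leaf @ t                                      # [num_leaf] cosine similarities
    order = np.argsort(-sim)[:k]                        # top-k leaf ids, best first
    best = order[0]
    selected = [int(best)]
    for cand in order[1:]:
        # keep runners-up whose instance feature is close to the best match
        if np.linalg.norm(leaf_ins_feat[best] - leaf_ins_feat[cand]) < max_dist:
            selected.append(int(cand))
    return selected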
        
def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool,
                scene_name: str):
    with torch.no_grad():
        gaussians = GaussianModel(dataset.sh_degree)
        scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)

        # bg_color = [1,1,1] if dataset.white_background else [0, 0, 0]
        bg_color = [1,1,1]
        background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")

        if not skip_train:
             render_set(dataset.model_path, "text2obj", scene.loaded_iter, scene.getTrainCameras(), 
                        gaussians, pipeline, background, scene_name)
        if not skip_test:
             render_set(dataset.model_path, "text2obj", scene.loaded_iter, scene.getTestCameras(), 
                        gaussians, pipeline, background, scene_name)

if __name__ == "__main__":
    # Set up command line argument parser
    parser = ArgumentParser(description="Testing script parameters")
    model = ModelParams(parser, sentinel=True)
    pipeline = PipelineParams(parser)
    parser.add_argument("--iteration", default=-1, type=int)
    parser.add_argument("--skip_train", action="store_true")
    parser.add_argument("--skip_test", action="store_true")
    parser.add_argument("--quiet", action="store_true")
    parser.add_argument("--scene_name", type=str, choices=["waldo_kitchen", "ramen", "figurines", "teatime"],
                        help="Specify the scene_name from: figurines, teatime, ramen, waldo_kitchen")
    args = get_combined_args(parser)
    print("Rendering " + args.model_path)

    if not args.scene_name:
        parser.error("The --scene_name argument is required and must be one of: waldo_kitchen, ramen, figurines, teatime")

    # Initialize system state (RNG)
    safe_state(args.quiet)

    render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test, args.scene_name)

================================================
FILE: scene/__init__.py
================================================
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use 
# under the terms of the LICENSE.md file.
#
# For inquiries contact  george.drettakis@inria.fr
#

import os
import random
import json
from utils.system_utils import searchForMaxIteration
from scene.dataset_readers import sceneLoadTypeCallbacks
from scene.gaussian_model import GaussianModel
from arguments import ModelParams
from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON

class Scene:

    gaussians : GaussianModel

    def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0]):
        """
        :param args: Model parameters; args.source_path points to the COLMAP or Blender scene folder.
        """
        self.model_path = args.model_path
        self.loaded_iter = None
        self.gaussians = gaussians

        if load_iteration:
            if load_iteration == -1:
                self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud"))
            else:
                self.loaded_iter = load_iteration
            print("Loading trained model at iteration {}".format(self.loaded_iter))

        self.train_cameras = {}
        self.test_cameras = {}

        if os.path.exists(os.path.join(args.source_path, "sparse")):
            scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval)
        elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")):
            print("Found transforms_train.json file, assuming Blender data set!")
            scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval)
        else:
            assert False, "Could not recognize scene type!"

        if not self.loaded_iter:
            with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply") , 'wb') as dest_file:
                dest_file.write(src_file.read())
            json_cams = []
            camlist = []
            if scene_info.test_cameras:
                camlist.extend(scene_info.test_cameras)
            if scene_info.train_cameras:
                camlist.extend(scene_info.train_cameras)
            for id, cam in enumerate(camlist):
                json_cams.append(camera_to_JSON(id, cam))
            with open(os.path.join(self.model_path, "cameras.json"), 'w') as file:
                json.dump(json_cams, file)

        if shuffle:
            random.shuffle(scene_info.train_cameras)  # Multi-res consistent random shuffling
            random.shuffle(scene_info.test_cameras)  # Multi-res consistent random shuffling

        self.cameras_extent = scene_info.nerf_normalization["radius"]

        for resolution_scale in resolution_scales:
            print("Resolution: ", resolution_scale)
            print("Loading Training Cameras")
            self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args)
            print("Loading Test Cameras")
            self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args)

        if self.loaded_iter:
            self.gaussians.load_ply(os.path.join(self.model_path,
                                                           "point_cloud",
                                                           "iteration_" + str(self.loaded_iter),
                                                           "point_cloud.ply"))
        else:
            self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent)

    def save(self, iteration, save_q=[]):
        point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration))
        self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply"), save_q)

    def getTrainCameras(self, scale=1.0):
        return self.train_cameras[scale]

    def getTestCameras(self, scale=1.0):
        return self.test_cameras[scale]
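
# Sketch (assumption): load_iteration=-1 asks searchForMaxIteration for the newest
# checkpoint under <model_path>/point_cloud, whose subfolders are named
# "iteration_<N>" (see Scene.save). A stand-alone equivalent that parses those
# folder names might look like this. find_latest_iteration is a hypothetical
# helper; the real utility lives in utils/system_utils.py, which is not shown here.
import os

def find_latest_iteration(point_cloud_dir):
    iters = []
    for name in os.listdir(point_cloud_dir):
        prefix, _, num = name.rpartition("_")
        if prefix == "iteration" and num.isdigit():
            iters.append(int(num))
    return max(iters) if iters else None   # None when no checkpoint folder exists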

================================================
FILE: scene/cameras.py
================================================
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use 
# under the terms of the LICENSE.md file.
#
# For inquiries contact  george.drettakis@inria.fr
#

import torch
from torch import nn
import numpy as np
from utils.graphics_utils import getWorld2View2, getProjectionMatrix

class Camera(nn.Module):
    def __init__(self, colmap_id, R, T, FoVx, FoVy, cx, cy, image, depth, gt_alpha_mask,
                 gt_sam_mask, gt_mask_feat,
                 image_name, uid,
                 trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda"
                 ):
        super(Camera, self).__init__()

        self.uid = uid
        self.colmap_id = colmap_id
        self.R = R
        self.T = T
        self.FoVx = FoVx
        self.FoVy = FoVy
        # modify -----
        self.cx = cx
        self.cy = cy
        # modify -----
        self.image_name = image_name

        try:
            self.data_device = torch.device(data_device)
        except Exception as e:
            print(e)
            print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" )
            self.data_device = torch.device("cuda")

        self.data_on_gpu = True     # note
        self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
        # modify -----
        self.original_mask = gt_alpha_mask.to(self.data_device) if gt_alpha_mask is not None else None
        
        # modify -----
        self.original_sam_mask = gt_sam_mask.to(self.data_device) if gt_sam_mask is not None else None
        self.original_mask_feat = gt_mask_feat.to(self.data_device) if gt_mask_feat is not None else None
        self.pesudo_ins_feat = None
        self.pesudo_mask_bool = None
        self.cluster_masks = None
        self.bClusterOccur = None

        self.image_width = self.original_image.shape[2]
        self.image_height = self.original_image.shape[1]

        if gt_alpha_mask is not None:
            self.original_image *= gt_alpha_mask.to(self.data_device)
        else:
            self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device)

        self.zfar = 100.0
        self.znear = 0.01

        self.trans = trans
        self.scale = scale

        self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda()
        self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda()
        self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
        self.camera_center = self.world_view_transform.inverse()[3, :3]
    
    # modify -----
    def to_gpu(self):
        for attr_name in dir(self):
            attr = getattr(self, attr_name)
            if isinstance(attr, torch.Tensor) and not attr.is_cuda:
                setattr(self, attr_name, attr.to('cuda'))
        self.data_on_gpu = True

    # modify -----
    def to_cpu(self):
        for attr_name in dir(self):
            attr = getattr(self, attr_name)
            if isinstance(attr, torch.Tensor) and attr.is_cuda:
                setattr(self, attr_name, attr.to('cpu'))
        self.data_on_gpu = False

class MiniCam:
    def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform):
        self.image_width = width
        self.image_height = height    
        self.FoVy = fovy
        self.FoVx = fovx
        self.znear = znear
        self.zfar = zfar
        self.world_view_transform = world_view_transform
        self.full_proj_transform = full_proj_transform
        view_inv = torch.inverse(self.world_view_transform)
        self.camera_center = view_inv[3][:3]
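
# Sketch (assumption: trans = 0, scale = 1, and getWorld2View2 behaving as in
# upstream 3DGS, i.e. rotation block R.T and translation T). world_view_transform
# is the 4x4 world-to-view matrix stored transposed (row-vector convention), so
# row 3 of its inverse holds the camera center, which works out to -R @ T.
# camera_center_from_RT is a hypothetical NumPy helper for illustration only.
import numpy as np

def camera_center_from_RT(R, T):
    Rt = np.eye(4)
    Rt[:3, :3] = R.T                        # world-to-view rotation
    Rt[:3, 3] = T                           # world-to-view translation
    world_view_transform = Rt.T             # transposed, as in Camera.__init__
    return np.linalg.inv(world_view_transform)[3, :3]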



================================================
FILE: scene/colmap_loader.py
================================================
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use 
# under the terms of the LICENSE.md file.
#
# For inquiries contact  george.drettakis@inria.fr
#

import numpy as np
import collections
import struct

CameraModel = collections.namedtuple(
    "CameraModel", ["model_id", "model_name", "num_params"])
Camera = collections.namedtuple(
    "Camera", ["id", "model", "width", "height", "params"])
BaseImage = collections.namedtuple(
    "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"])
Point3D = collections.namedtuple(
    "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"])
CAMERA_MODELS = {
    CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
    CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
    CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
    CameraModel(model_id=3, model_name="RADIAL", num_params=5),
    CameraModel(model_id=4, model_name="OPENCV", num_params=8),
    CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
    CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
    CameraModel(model_id=7, model_name="FOV", num_params=5),
    CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
    CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
    CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12)
}
CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model)
                         for camera_model in CAMERA_MODELS])
CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model)
                           for camera_model in CAMERA_MODELS])


def qvec2rotmat(qvec):
    return np.array([
        [1 - 2 * qvec[2]**2 - 2 * qvec[3]**2,
         2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
         2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]],
        [2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
         1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,
         2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]],
        [2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
         2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
         1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]])

def rotmat2qvec(R):
    Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
    K = np.array([
        [Rxx - Ryy - Rzz, 0, 0, 0],
        [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
        [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
        [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0
    eigvals, eigvecs = np.linalg.eigh(K)
    qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
    if qvec[0] < 0:
        qvec *= -1
    return qvec

class Image(BaseImage):
    def qvec2rotmat(self):
        return qvec2rotmat(self.qvec)

def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
    """Read and unpack the next bytes from a binary file.
    :param fid:
    :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
    :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
    :param endian_character: Any of {@, =, <, >, !}
    :return: Tuple of read and unpacked values.
    """
    data = fid.read(num_bytes)
    return struct.unpack(endian_character + format_char_sequence, data)

def read_points3D_text(path):
    """
    see: src/base/reconstruction.cc
        void Reconstruction::ReadPoints3DText(const std::string& path)
        void Reconstruction::WritePoints3DText(const std::string& path)
    """
    xyzs = None
    rgbs = None
    errors = None
    num_points = 0
    with open(path, "r") as fid:
        while True:
            line = fid.readline()
            if not line:
                break
            line = line.strip()
            if len(line) > 0 and line[0] != "#":
                num_points += 1


    xyzs = np.empty((num_points, 3))
    rgbs = np.empty((num_points, 3))
    errors = np.empty((num_points, 1))
    count = 0
    with open(path, "r") as fid:
        while True:
            line = fid.readline()
            if not line:
                break
            line = line.strip()
            if len(line) > 0 and line[0] != "#":
                elems = line.split()
                xyz = np.array(tuple(map(float, elems[1:4])))
                rgb = np.array(tuple(map(int, elems[4:7])))
                error = np.array(float(elems[7]))
                xyzs[count] = xyz
                rgbs[count] = rgb
                errors[count] = error
                count += 1

    return xyzs, rgbs, errors

def read_points3D_binary(path_to_model_file):
    """
    see: src/base/reconstruction.cc
        void Reconstruction::ReadPoints3DBinary(const std::string& path)
        void Reconstruction::WritePoints3DBinary(const std::string& path)
    """


    with open(path_to_model_file, "rb") as fid:
        num_points = read_next_bytes(fid, 8, "Q")[0]

        xyzs = np.empty((num_points, 3))
        rgbs = np.empty((num_points, 3))
        errors = np.empty((num_points, 1))

        for p_id in range(num_points):
            binary_point_line_properties = read_next_bytes(
                fid, num_bytes=43, format_char_sequence="QdddBBBd")
            xyz = np.array(binary_point_line_properties[1:4])
            rgb = np.array(binary_point_line_properties[4:7])
            error = np.array(binary_point_line_properties[7])
            track_length = read_next_bytes(
                fid, num_bytes=8, format_char_sequence="Q")[0]
            track_elems = read_next_bytes(
                fid, num_bytes=8*track_length,
                format_char_sequence="ii"*track_length)
            xyzs[p_id] = xyz
            rgbs[p_id] = rgb
            errors[p_id] = error
    return xyzs, rgbs, errors
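
# Sketch (assumption: mirrors the little-endian points3D.bin record layout that
# read_points3D_binary parses; not part of the original file). The "<" prefix
# means little-endian with no alignment padding, so sizes are exact:
# Q (8) + 3*d (24) + 3*B (3) + d (8) = 43 bytes per point header, followed by an
# 8-byte track length and 8 bytes ("ii") per (image_id, point2D_idx) pair.
# pack_point3d / unpack_point3d are hypothetical helpers for a round-trip check.
import struct

POINT_HEADER_FMT = "<QdddBBBd"   # struct.calcsize(POINT_HEADER_FMT) == 43

def pack_point3d(point_id, xyz, rgb, error, track):
    """Serialize one point: header, track length, then (image_id, point2D_idx) pairs."""
    buf = struct.pack(POINT_HEADER_FMT, point_id, *xyz, *rgb, error)
    buf += struct.pack("<Q", len(track))
    for image_id, p2d_idx in track:
        buf += struct.pack("<ii", image_id, p2d_idx)
    return buf

def unpack_point3d(fid):
    """Read one point back using the same byte counts as read_points3D_binary."""
    props = struct.unpack(POINT_HEADER_FMT, fid.read(43))
    track_length = struct.unpack("<Q", fid.read(8))[0]
    track = struct.unpack("<" + "ii" * track_length, fid.read(8 * track_length))
    return props[1:4], props[4:7], props[7], list(zip(track[0::2], track[1::2]))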

def read_intrinsics_text(path):
    """
    Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py
    """
    cameras = {}
    with open(path, "r") as fid:
        while True:
            line = fid.readline()
            if not line:
                break
            line = line.strip()
            if len(line) > 0 and line[0] != "#":
                elems = line.split()
                camera_id = int(elems[0])
                model = elems[1]
                assert model == "PINHOLE", "While the loader supports other types, the rest of the code assumes PINHOLE"
                width = int(elems[2])
                height = int(elems[3])
                params = np.array(tuple(map(float, elems[4:])))
                cameras[camera_id] = Camera(id=camera_id, model=model,
                                            width=width, height=height,
                                            params=params)
    return cameras

def read_extrinsics_binary(path_to_model_file):
    """
    see: src/base/reconstruction.cc
        void Reconstruction::ReadImagesBinary(const std::string& path)
        void Reconstruction::WriteImagesBinary(const std::string& path)
    """
    images = {}
    with open(path_to_model_file, "rb") as fid:
        num_reg_images = read_next_bytes(fid, 8, "Q")[0]
        for _ in range(num_reg_images):
            binary_image_properties = read_next_bytes(
                fid, num_bytes=64, format_char_sequence="idddddddi")
            image_id = binary_image_properties[0]
            qvec = np.array(binary_image_properties[1:5])
            tvec = np.array(binary_image_properties[5:8])
            camera_id = binary_image_properties[8]
            image_name = ""
            current_char = read_next_bytes(fid, 1, "c")[0]
            while current_char != b"\x00":   # look for the ASCII 0 entry
                image_name += current_char.decode("utf-8")
                current_char = read_next_bytes(fid, 1, "c")[0]
            num_points2D = read_next_bytes(fid, num_bytes=8,
                                           format_char_sequence="Q")[0]
            x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D,
                                       format_char_sequence="ddq"*num_points2D)
            xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),
                                   tuple(map(float, x_y_id_s[1::3]))])
            point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
            images[image_id] = Image(
                id=image_id, qvec=qvec, tvec=tvec,
                camera_id=camera_id, name=image_name,
                xys=xys, point3D_ids=point3D_ids)
    return images
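
# Illustrative sketch (hypothetical helpers): an images.bin entry begins with a fixed
# 64-byte "<idddddddi" header (image_id, qvec as qw, qx, qy, qz, tvec, camera_id),
# followed by the NUL-terminated image name. Instead of reading one byte at a time as
# above, an in-memory buffer can be split at the first b"\x00". Little-endian ("<") is
# assumed, as in read_next_bytes' default.
import struct as _struct


def _sketch_parse_image_header(buf, offset=0):
    """Unpack the fixed header; returns (image_id, qvec, tvec, camera_id, next_offset)."""
    fields = _struct.unpack_from("<idddddddi", buf, offset)
    next_offset = offset + _struct.calcsize("<idddddddi")  # 4 + 7*8 + 4 = 64 bytes
    return fields[0], fields[1:5], fields[5:8], fields[8], next_offset


def _sketch_read_c_string(buf, offset=0):
    """Return (decoded string, offset just past the NUL terminator)."""
    end = buf.index(b"\x00", offset)
    return buf[offset:end].decode("utf-8"), end + 1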


def read_intrinsics_binary(path_to_model_file):
    """
    see: src/base/reconstruction.cc
        void Reconstruction::WriteCamerasBinary(const std::string& path)
        void Reconstruction::ReadCamerasBinary(const std::string& path)
    """
    cameras = {}
    with open(path_to_model_file, "rb") as fid:
        num_cameras = read_next_bytes(fid, 8, "Q")[0]
        for _ in range(num_cameras):
            camera_properties = read_next_bytes(
                fid, num_bytes=24, format_char_sequence="iiQQ")
            camera_id = camera_properties[0]
            model_id = camera_properties[1]
            model_name = CAMERA_MODEL_IDS[model_id].model_name
            width = camera_properties[2]
            height = camera_properties[3]
            num_params = CAMERA_MODEL_IDS[model_id].num_params
            params = read_next_bytes(fid, num_bytes=8*num_params,
                                     format_char_sequence="d"*num_params)
            cameras[camera_id] = Camera(id=camera_id,
                                        model=model_name,
                                        width=width,
                                        height=height,
                                        params=np.array(params))
        assert len(cameras) == num_cameras
    return cameras


def read_extrinsics_text(path):
    """
    Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py
    """
    images = {}
    with open(path, "r") as fid:
        while True:
            line = fid.readline()
            if not line:
                break
            line = line.strip()
            if len(line) > 0 and line[0] != "#":
                elems = line.split()
                image_id = int(elems[0])
                qvec = np.array(tuple(map(float, elems[1:5])))
                tvec = np.array(tuple(map(float, elems[5:8])))
                camera_id = int(elems[8])
                image_name = elems[9]
                elems = fid.readline().split()
                xys = np.column_stack([tuple(map(float, elems[0::3])),
                                       tuple(map(float, elems[1::3]))])
                point3D_ids = np.array(tuple(map(int, elems[2::3])))
                images[image_id] = Image(
                    id=image_id, qvec=qvec, tvec=tvec,
                    camera_id=camera_id, name=image_name,
                    xys=xys, point3D_ids=point3D_ids)
    return images
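
# Illustrative sketch (hypothetical helpers): each image entry stores a world-to-camera
# pose as qvec = (qw, qx, qy, qz) and tvec. The camera center in world coordinates is
# C = -R^T t, where R is the rotation encoded by qvec. The standard unit-quaternion to
# rotation-matrix formula is written out here so the sketch is self-contained.
import numpy as _np


def _sketch_qvec_to_rotmat(qvec):
    w, x, y, z = _np.asarray(qvec, dtype=float) / _np.linalg.norm(qvec)
    return _np.array([
        [1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)],
        [2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)],
        [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)],
    ])


def _sketch_camera_center(qvec, tvec):
    """World-space camera center from a COLMAP world-to-camera pose."""
    R = _sketch_qvec_to_rotmat(qvec)
    return -R.T @ _np.asarray(tvec, dtype=float)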


def read_colmap_bin_array(path):
    """
    Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py

    :param path: path to the colmap binary file.
    :return: float32 ndarray of shape (height, width, channels), with singleton dimensions squeezed
    """
    with open(path, "rb") as fid:
        width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1,
                                                usecols=(0, 1, 2), dtype=int)
        fid.seek(0)
        num_delimiter = 0
        byte = fid.read(1)
        while True:
            if byte == b"":
                raise ValueError("Unexpected end of file while reading the COLMAP array header")
            if byte == b"&":
                num_delimiter += 1
                if num_delimiter >= 3:
                    break
            byte = fid.read(1)
        array = np.fromfile(fid, np.float32)
    array = array.reshape((width, height, channels), order="F")
    return np.transpose(array, (1, 0, 2)).squeeze()
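
# Illustrative sketch (hypothetical helpers): an encoder that is the inverse of
# read_colmap_bin_array, plus an in-memory decoder that mirrors it. The header is the
# ASCII text "width&height&channels&", followed by float32 values laid out so that
# reshape((width, height, channels), order="F") recovers them, as the reader assumes.
# Splitting with maxsplit=3 keeps any b"&" bytes inside the float data intact.
import numpy as _np


def _sketch_encode_colmap_array(array):
    """array: (height, width) or (height, width, channels)."""
    if array.ndim == 2:
        array = array[:, :, None]
    height, width, channels = array.shape
    header = "{}&{}&{}&".format(width, height, channels).encode("ascii")
    data = _np.transpose(array, (1, 0, 2)).astype(_np.float32).tobytes(order="F")
    return header + data


def _sketch_decode_colmap_array(buf):
    parts = buf.split(b"&", 3)
    width, height, channels = (int(p) for p in parts[:3])
    array = _np.frombuffer(parts[3], dtype=_np.float32)
    array = array.reshape((width, height, channels), order="F")
    return _np.transpose(array, (1, 0, 2)).squeeze()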


================================================
FILE: scene/dataset_readers.py
================================================
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use 
# under the terms of the LICENSE.md file.
#
# For inquiries contact  george.drettakis@inria.fr
#

import os
import sys
from PIL import Image
from typing import NamedTuple
from scene.colmap_loader import read_extrinsics_text, read_intrinsics_text, qvec2rotmat, \
    read_extrinsics_binary, read_intrinsics_binary, read_points3D_binary, read_points3D_text
from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal
import numpy as np
import json
import random
from tqdm import tqdm
from pathlib import Path
from plyfile import PlyData, PlyElement
from utils.sh_utils import SH2RGB
from scene.gaussian_model import BasicPointCloud

class CameraInfo(NamedTuple):
    uid: int
    R: np.array
    T: np.array
    FovY: np.array
    FovX: np.array
    cx: np.array
    cy: np.array
    image: np.array
    depth: np.array     # not used
    sam_mask: np.array  # modify ----- SAM segmentation masks, [level=4, H, W]
    mask_feat: np.array # modify ----- per-mask CLIP features, [num_mask, dim=512]
    image_path: str
    image_name: str
    width: int
    height: int

class SceneInfo(NamedTuple):
    point_cloud: BasicPointCloud
    train_cameras: list
    test_cameras: list
    nerf_normalization: dict
    ply_path: str

def getNerfppNorm(cam_info):
    def get_center_and_diag(cam_centers):
        cam_centers = np.hstack(cam_centers)
        avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True)
        center = avg_cam_center
        dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True)
        diagonal = np.max(dist)
        return center.flatten(), diagonal

    cam_centers = []

    for cam in cam_info:
        W2C = getWorld2View2(cam.R, cam.T)
        C2W = np.linalg.inv(W2C)
        cam_centers.append(C2W[:3, 3:4])

    center, diagonal = get_center_and_diag(cam_centers)
    radius = diagonal * 1.1

    translate = -center

    return {"translate": translate, "radius": radius}
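
# Illustrative sketch (hypothetical helper): the same normalization without building
# 4x4 matrices. It assumes getWorld2View2 (utils/graphics_utils.py, not shown here)
# builds W2C = [R^T | T] with its default translate/scale, since R is stored
# transposed. Inverting that gives the camera center C = -R @ T directly.
import numpy as _np


def _sketch_nerfpp_norm(rotations, translations):
    centers = _np.stack([-R @ T for R, T in zip(rotations, translations)])  # (N, 3)
    center = centers.mean(axis=0)
    # Radius: 10% margin over the largest camera distance from the mean center.
    radius = 1.1 * _np.linalg.norm(centers - center, axis=1).max()
    return {"translate": -center, "radius": radius}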

def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder):
    cam_infos = []

    for idx, key in enumerate(cam_extrinsics):
        # Progress indicator: rewrite the same console line for each camera.
        sys.stdout.write('\r')
        sys.stdout.write("Reading camera {}/{}".format(idx+1, len(cam_extrinsics)))
        sys.stdout.flush()

        extr = cam_extrinsics[key]
        intr = cam_intrinsics[extr.camera_id]
        height = intr.height
        width = intr.width

        uid = intr.id
        R = np.transpose(qvec2rotmat(extr.qvec))
        T = np.array(extr.tvec)

        if intr.model=="SIMPLE_PINHOLE":
            focal_length_x = intr.params[0]
            FovY = focal2fov(focal_length_x, height)
            FovX = focal2fov(focal_length_x, width)
        elif intr.model=="PINHOLE":
            focal_length_x = intr.params[0]
            focal_length_y = intr.params[1]
            FovY = focal2fov(focal_length_y, height)
            FovX = focal2fov(focal_length_x, width)
        else:
            assert False, "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!"

        image_path = os.path.join(images_folder, os.path.basename(extr.name))
        if not os.path.exists(image_path):
            # modify -----
            base, ext = os.path.splitext(image_path)
            if ext.lower() == ".jpg":
                image_path = base + ".png"
            elif ext.lower() == ".png":
                image_path = base + ".jpg"
            if not os.path.exists(image_path):
                continue
            # modify ----

        image_name = os.path.basename(image_path).split(".")[0]
        image = Image.open(image_path)

        # NOTE: load SAM mask and CLIP feat. [OpenGaussian]
        # Stored as <scene>/language_features/<image stem>_s.npy and _f.npy, where <scene> is the parent of images_folder.
        feat_dir = os.path.join(os.path.dirname(images_folder), "language_features")
        feat_stem = os.path.splitext(os.path.basename(extr.name))[0]
        mask_seg_path = os.path.join(feat_dir, feat_stem + "_s.npy")
        mask_feat_path = os.path.join(feat_dir, feat_stem + "_f.npy")
        if os.path.exists(mask_seg_path):
            sam_mask = np.load(mask_seg_path)    # [level=4, H, W]
        else:
            sam_mask = None
        if os.path.exists(mask_feat_path):
            mask_feat = np.load(mask_feat_path)  # [num_mask, dim=512]
        else:
            mask_feat = None
        # modify -----

        cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, cx=width/2, cy=height/2, image=image, 
                              depth=None, sam_mask=sam_mask, mask_feat=mask_feat,
                              image_path=image_path, image_name=image_name, width=width, height=height)
        cam_infos.append(cam_info)
    sys.stdout.write('\n')
    return cam_infos
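
# Illustrative sketch (hypothetical helpers): readColmapCameras turns focal lengths into
# fields of view with focal2fov from utils/graphics_utils.py (not shown here). This
# assumes it uses the standard pinhole relation fov = 2 * atan(pixels / (2 * focal)),
# with fov2focal as its inverse. SIMPLE_PINHOLE reuses the single focal for both axes.
import math as _math


def _sketch_focal2fov(focal, pixels):
    return 2 * _math.atan(pixels / (2 * focal))


def _sketch_fov2focal(fov, pixels):
    return pixels / (2 * _math.tan(fov / 2))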

def fetchPly(path):
    plydata = PlyData.read(path)
    vertices = plydata['vertex']
    positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T
    if {'red', 'green', 'blue'}.issubset(vertices.data.dtype.names):
        colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0
    else:
        colors = np.random.rand(positions.shape[0], 3)
    if {'nx', 'ny', 'nz'}.issubset(vertices.data.dtype.names):
        normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T
    else:
        normals = np.random.rand(positions.shape[0], 3)

    return BasicPointCloud(points=positions, colors=colors, normals=normals)

def storePly(path, xyz, rgb):
    # Define the dtype for the structured array
    dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
            ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
            ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]
    
    normals = np.zeros_like(xyz)

    elements = np.empty(xyz.shape[0], dtype=dtype)
    attributes = np.concatenate((xyz, normals, rgb), axis=1)
    elements[:] = list(map(tuple, attributes))

    # Create the PlyData object and write to file
    vertex_element = PlyElement.describe(elements, 'vertex')
    ply_data = PlyData([vertex_element])
    ply_data.write(path)
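
# Illustrative sketch (hypothetical helper): storePly fills its structured array with
# list(map(tuple, ...)), which builds one Python tuple per vertex. Assigning each field
# as a column avoids that; the result could be passed to PlyElement.describe the same
# way. Clipping colors to [0, 255] is an addition here; storePly does not clip.
import numpy as _np


def _sketch_build_vertex_array(xyz, rgb):
    dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
             ('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
             ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]
    elements = _np.empty(xyz.shape[0], dtype=dtype)
    for i, name in enumerate(('x', 'y', 'z')):
        elements[name] = xyz[:, i]
    for name in ('nx', 'ny', 'nz'):
        elements[name] = 0.0  # zero normals, as in storePly
    colors = _np.clip(rgb, 0, 255).astype(_np.uint8)  # truncates fractional values
    for i, name in enumerate(('red', 'green', 'blue')):
        elements[name] = colors[:, i]
    return elements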

def readColmapSceneInfo(path, images, eval, llffhold=8):
    try:
        cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin")
        cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin")
        cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file)
        cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file)
    except:
        cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.txt")
        cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.txt")
        cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file)
        cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file)

    reading_dir = "images" if images is None else images
    cam_infos_unsorted = readColmapCameras(cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, images_folder=os.path.join(path, reading_dir))
    cam_infos = sorted(cam_infos_unsorted, key=lambda x: x.image_name)

    if eval:
        train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0]
        test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0]
    else:
        train_cam_infos = cam_infos
        test_cam_infos = []

    nerf_normalization = getNerfppNorm(train_cam_infos)

    ply_path = os.path.join(path, "sparse/0/points3D.ply")
    bin_path = os.path.join(path, "sparse/0/points3D.bin")
    txt_path = os.path.join(path, "sparse/0/points3D.txt")
    if not os.path.exists(ply_path):
        print("Converting points3D.bin to .ply, will happen only the first time you open the scene.")
        try:
            xyz, rgb, _ = read_points3D_binary(bin_path)
        except:
            xyz, rgb, _ = read_points3D_text(txt_path)
        storePly(ply_path, xyz, rgb)
    try:
        pcd = fetchPly(ply_path)
    except:
        pcd = None

    scene_info = SceneInfo(point_cloud=pcd,
                           train_cameras=train_cam_infos,
                           test_cameras=test_cam_infos,
                           nerf_normalization=nerf_normalization,
                           ply_path=ply_path)
    return scene_info
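
# Illustrative sketch (hypothetical helper): the LLFF-style hold-out used above when
# eval is set. After sorting by image name, every llffhold-th camera (indices 0, 8, 16,
# ... for llffhold=8) goes to the test split and the rest are used for training.
def _sketch_llff_split(names, llffhold=8):
    ordered = sorted(names)
    train = [n for i, n in enumerate(ordered) if i % llffhold != 0]
    test = [n for i, n in enumerate(ordered) if i % llffhold == 0]
    return train, test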

def readCamerasFromTransforms(path, transformsfile, white_background, extension=".png"):
    cam_infos = []

    with open(os.path.join(path, transformsfile)) as json_file:
        contents = json.load(json_file)

        # ----- modify -----
        fovx = contents.get("camera_angle_x")  # None when the file has no global horizontal FoV
        # ----- modify -----

        # modify -----
        cx, cy = -1, -1
        if "cx" in contents.keys():
            cx = contents["cx"]
            cy = contents["cy"]
        elif "h" in contents.keys():
            cx = contents["w"] / 2
            cy = contents["h"] / 2
        # modify -----

        frames = contents["frames"]
        # for idx, frame in enumerate(frames):
        for idx, frame in tqdm(enumerate(frames), total=len(frames), desc="load images"):
            cam_name = os.path.join(path, frame["file_path"] + extension)

            # NeRF 'transform_matrix' is a camera-to-world transform
            c2w = np.array(frame["transform_matrix"])
            # change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)
            c2w[:3, 1:3] *= -1    # TODO

            # get the world-to-camera transform and set R, T
            w2c = np.linalg.inv(c2w)
            R = np.transpose(w2c[:3,:3])  # R is stored transposed due to 'glm' in CUDA code
            T = w2c[:3, 3]

            image_path = cam_name  # cam_name already includes `path`; joining it again would duplicate a relative path
            if not os.path.exists(image_path):
                # modify -----
                base, ext = os.path.splitext(image_path)
                if ext.lower() == ".jpg":
                    image_path = base + ".png"
                elif ext.lower() == ".png":
                    image_path = base + ".jpg"
                if not os.path.exists(image_path):
                    continue
                # modify ----

            image_name = Path(cam_name).stem
            image = Image.open(image_path)

            im_data = np.array(image.convert("RGBA"))

            bg = np.array([1,1,1]) if white_background else np.array([0, 0, 0])

            norm_data = im_data / 255.0
            arr = norm_data[:,:,:3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4])
            image = Image.fromarray(np.array(arr * 255.0, dtype=np.uint8))  # HxWx3 uint8 -> "RGB" image

            # NOTE: load SAM mask and CLIP feat. [OpenGaussian]
            mask_seg_path = os.path.join(path, "language_features/" + frame["file_path"].split('/')[-1] + "_s.npy")
            mask_feat_path = os.path.join(path, "language_features/" + frame["file_path"].split('/')[-1] + "_f.npy")
            if os.path.exists(mask_seg_path):
                sam_mask = np.load(mask_seg_path)    # [level=4, H, W]
            else:
                sam_mask = None
            if os.path.exists(mask_feat_path):
                mask_feat = np.load(mask_feat_path)  # [num_mask, dim=512]
            else:
                mask_feat = None
            # modify -----

            # ----- modify -----
            if "K" in frame.keys():
                cx = frame["K"][0][2]
                cy = frame["K"][1][2]
            if cx == -1:
                cx = image.size[0] / 2
                cy = image.size[1] / 2
            # ----- modify -----

            # ----- modify -----
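            if fovx is None:
                # No camera_angle_x: derive the FoV from a focal length (frame K, then global fl_x, then frame fl_x).
                focal_length = None
                if "K" in frame.keys():
                    focal_length = frame["K"][0][0]
                if "fl_x" in contents.keys():
                    focal_length = contents["fl_x"]
                if "fl_x" in frame.keys():
                    focal_length = frame["fl_x"]
                assert focal_length is not None, "No camera_angle_x, K or fl_x found in " + transformsfile
                FovY = focal2fov(focal_length, image.size[1])
                FovX = focal2fov(focal_length, image.size[0])
            else:
                fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1])
                FovY = fovy
                FovX = fovx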
            # ----- modify -----

            cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, cx=cx, cy=cy, image=image, 
                            depth=None, sam_mask=sam_mask, mask_feat=mask_feat,
                            image_path=image_path, image_name=image_name, width=image.size[0], height=image.size[1]))
            
    return cam_infos
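
# Illustrative sketch (hypothetical helpers) of two steps in readCamerasFromTransforms:
# (1) a Blender/NeRF transform_matrix is camera-to-world with OpenGL axes (Y up, Z back);
# negating columns 1 and 2 gives COLMAP axes, and R is stored as the transpose of the
# world-to-camera rotation; (2) RGBA images are alpha-composited over a white or black
# background. The input matrix is copied so the caller's array is not modified.
import numpy as _np


def _sketch_blender_c2w_to_RT(c2w_gl):
    c2w = _np.array(c2w_gl, dtype=float)  # copy
    c2w[:3, 1:3] *= -1                    # flip camera Y and Z axes
    w2c = _np.linalg.inv(c2w)
    R = w2c[:3, :3].T                     # stored transposed, as in the loader
    T = w2c[:3, 3]
    return R, T


def _sketch_composite_rgba(rgba_uint8, white_background):
    norm = rgba_uint8.astype(_np.float64) / 255.0
    bg = _np.ones(3) if white_background else _np.zeros(3)
    alpha = norm[..., 3:4]
    rgb = norm[..., :3] * alpha + bg * (1 - alpha)
    return (rgb * 255.0).astype(_np.uint8)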

def readNerfSyntheticInfo(path, white_background, eval, extension=".png"):
    print("Reading Training Transforms")
    train_cam_infos = readCamerasFromTransforms(path, "transforms_train.json", white_background, extension)
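    print("Reading Test Transforms")
    has_test_split = os.path.exists(os.path.join(path, "transforms_test.json"))
    if has_test_split:
        test_cam_infos = readCamerasFromTransforms(path, "transforms_test.json", white_background, extension)
    else:
        # No test split: evaluate on the training views.
        test_cam_infos = train_cam_infos

    if not eval:
        # Merge test views into training, but never extend the list with itself
        # (that would duplicate every training camera).
        if has_test_split:
            train_cam_infos.extend(test_cam_infos)
        test_cam_infos = []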

    nerf_normalization = getNerfppNorm(train_cam_infos)

    ply_path = os.path.join(path, "points3d.ply")
    if not os.path.exists(ply_path):
        # Since this data set has no colmap data, we start with random points
        num_pts = 100_000
        print(f"Generating random point cloud ({num_pts})...")
        
        # We create random points inside the bounds of the synthetic Blender scenes
        xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3
        shs = np.random.random((num_pts, 3)) / 255.0
        pcd = BasicPointCloud(points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3)))

        storePly(ply_path, xyz, SH2RGB(shs) * 255)
    try:
        pcd = fetchPly(ply_path)
    except:
        pcd = None

    scene_info = SceneInfo(point_cloud=pcd,
                           train_cameras=train_cam_infos,
                           test_cameras=test_cam_infos,
                           nerf_normalization=nerf_normalization,
                           ply_path=ply_path)
    return scene_info
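
# Illustrative sketch (hypothetical helper) of the random initialization above: points
# uniform in [-1.3, 1.3)^3 and tiny SH DC values. It assumes SH2RGB (utils/sh_utils.py,
# not shown) computes sh * C0 + 0.5 with C0 = 0.28209479177387814 = 1 / (2 * sqrt(pi)),
# so the initial colors are close to mid-gray. A seeded Generator replaces np.random
# here only to make the sketch reproducible.
import numpy as _np

_SKETCH_SH_C0 = 0.28209479177387814  # assumed degree-0 SH constant


def _sketch_random_blender_init(num_pts, seed=0):
    rng = _np.random.default_rng(seed)
    xyz = rng.random((num_pts, 3)) * 2.6 - 1.3  # uniform in [-1.3, 1.3)
    shs = rng.random((num_pts, 3)) / 255.0
    rgb = shs * _SKETCH_SH_C0 + 0.5             # assumed SH2RGB
    return xyz, rgb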

sceneLoadTypeCallbacks = {
    "Colmap": readColmapSceneInfo,
    "Blender" : readNerfSyntheticInfo
}

================================================
FILE: scene/gaussian_model.py
================================================
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use 
# under the terms of the LICENSE.md file.
#
# For inquiries contact  george.drettakis@inria.fr
#

import torch
import numpy as np
from utils.general_utils import inverse_sigmoid, get_expon_lr_func, build_rotation
from torch import nn
import os
from utils.system_utils import mkdir_p
from plyfile import PlyData, PlyElement
from utils.sh_utils import RGB2SH
# from simple_knn._C import distCUDA2   # replaced by the SciPy KDTree version of distCUDA2 below
from scipy.spatial import KDTree        # modify
from utils.graphics_utils import BasicPointCloud
from utils.general_utils import strip_symmetric, build_scaling_rotation

def sigmoid(x):  
    return 1 / (1 + np.exp(-x))  

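def distCUDA2(points):
    '''
    CPU stand-in for simple_knn's distCUDA2: for each point, the mean squared distance
    to its 3 nearest neighbours (k=4 because the first match is the point itself).
    https://github.com/graphdeco-inria/gaussian-splatting/issues/292
    '''
    points_np = points.detach().cpu().float().numpy()
    dists, _ = KDTree(points_np).query(points_np, k=4)
    meanDists = (dists[:, 1:] ** 2).mean(1)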

    return torch.tensor(meanDists, dtype=points.dtype, device=points.device)
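
# Illustrative sketch (hypothetical helper): a brute-force NumPy equivalent of the
# KDTree-based distCUDA2 above, useful for checking it on small inputs. It uses
# O(N^2) memory, so it is not a replacement for large point clouds.
import numpy as _np


def _sketch_mean_sq_knn_dist(points, k=3):
    diff = points[:, None, :] - points[None, :, :]
    d2 = (diff ** 2).sum(-1)             # (N, N) squared distances
    _np.fill_diagonal(d2, _np.inf)       # exclude each point itself
    nearest = _np.sort(d2, axis=1)[:, :k]
    return nearest.mean(axis=1)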

class GaussianModel:

    def setup_functions(self):
        def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
            L = build_scaling_rotation(scaling_modifier * scaling, rotation)
            actual_covariance = L @ L.transpose(1, 2)
            symm = strip_symmetric(actual_covariance)
            return symm
        
        self.scaling_activation = torch.exp
        self.scaling_inverse_activation = torch.log

        self.covariance_activation = build_covariance_from_scaling_rotation

        self.opacity_activation = torch.sigmoid
        self.inverse_opacity_activation = inverse_sigmoid

        self.rotation_activation = torch.nn.functional.normalize


    def __init__(self, sh_degree : int):
        self.active_sh_degree = 0
        self.max_sh_degree = sh_degree  
        self._xyz = torch.empty(0)
        self._features_dc = torch.empty(0)
        self._features_rest = torch.empty(0)
        self._scaling = torch.empty(0)
        self._rotation = torch.empty(0)
        self._opacity = torch.empty(0)
        self._ins_feat = torch.empty(0)     # Continuous instance features before quantization
        self._ins_feat_q = torch.empty(0)   # Discrete instance features after quantization
        self.iClusterSubNum = torch.empty(0)
        self.max_radii2D = torch.empty(0)
        self.xyz_gradient_accum = torch.empty(0)
        self.denom = torch.empty(0)
        self.optimizer = None
        self.percent_dense = 0
        self.spatial_lr_scale = 0
        self.setup_functions()

    def capture(self):
        return (
            self.active_sh_degree,
            self._xyz,
            self._features_dc,
            self._features_rest,
            self._scaling,
            self._rotation,
            self._opacity,
            self._ins_feat,     # Continuous instance features before quantization
            self._ins_feat_q,   # Discrete instance features after quantization
            self.max_radii2D,
            self.xyz_gradient_accum,
            self.denom,
            self.optimizer.state_dict(),
            self.spatial_lr_scale,
        )
    
    def restore(self, model_args, training_args):
        (self.active_sh_degree, 
        self._xyz, 
        self._features_dc, 
        self._features_rest,
        self._scaling, 
        self._rotation, 
        self._opacity,
        self._ins_feat,     # Continuous instance features before quantization
        self._ins_feat_q,   # Discrete instance features after quantization
        self.max_radii2D, 
        xyz_gradient_accum, 
        denom,
        opt_dict, 
        self.spatial_lr_scale) = model_args
        self.training_setup(training_args)
        self.xyz_gradient_accum = xyz_gradient_accum
        self.denom = denom
        self.optimizer.load_state_dict(opt_dict)

    @property
    def get_scaling(self):
        return self.scaling_activation(self._scaling)
    
    @property
    def get_scaling_origin(self):
        return self.scaling_activation(self._scaling)
    
    @property
    def get_rotation(self):
        return self.rotation_activation(self._rotation)
    
    @property
    def get_rotation_matrix(self):
        return build_rotation(self._rotation)
    
    @property
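    def get_eigenvector(self):
        ''' Unit direction of each Gaussian's shortest axis (smallest scale), usable as a normal estimate. '''
        scales = self.get_scaling_origin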
        N = scales.shape[0]
        idx = torch.min(scales, dim=1)[1]
        normals = self.get_rotation_matrix[np.arange(N), :, idx]
        normals = torch.nn.functional.normalize(normals, dim=1)
        return normals
    
    @property
    def get_xyz(self):
        return self._xyz
    
    @property
    def get_features(self):
        features_dc = self._features_dc
        features_rest = self._features_rest
        return torch.cat((features_dc, features_rest), dim=1)
    
    @property
    def get_opacity(self):
        return self.opacity_activation(self._opacity)
    
    # NOTE: get instance feature
    # @property
    def get_ins_feat(self, origin=False):
        if len(self._ins_feat_q) == 0 or origin:
            ins_feat = self._ins_feat
        else:
            ins_feat = self._ins_feat_q
        ins_feat = torch.nn.functional.normalize(ins_feat, dim=1)
        return ins_feat
    
    def get_covariance(self, scaling_modifier = 1):
        return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation)

    def oneupSHdegree(self):
        if self.active_sh_degree < self.max_sh_degree:
            self.active_sh_degree += 1

    def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float):
        self.spatial_lr_scale = spatial_lr_scale
        fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda()
        fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda())
        features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda() # [N, 3, (max_sh_degree+1)**2], e.g. [N, 3, 16] for degree 3
        features[:, :3, 0 ] = fused_color
        features[:, 3:, 1:] = 0.0

        print("Number of points at initialisation : ", fused_point_cloud.shape[0])

        dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001)
        scales = torch.log(torch.sqrt(dist2))[...,None].repeat(1, 3)
        rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda")
        rots[:, 0] = 1

        opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda"))

        # modify -----
        ins_feat = torch.rand((fused_point_cloud.shape[0], 6), dtype=torch.float, device="cuda")

        self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))
        self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True))
        self._features_rest = nn.Parameter(features[:,:,1:].transpose(1, 2).contiguous().requires_grad_(True))
        self._scaling = nn.Parameter(scales.requires_grad_(True))
        self._rotation = nn.Parameter(rots.requires_grad_(True))
        self._opacity = nn.Parameter(opacities.requires_grad_(True))
        # modify -----
        self._ins_feat = nn.Parameter(ins_feat.requires_grad_(True))
        self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")

    def training_setup(self, training_args):
        self.percent_dense = training_args.percent_dense
        self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
        self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")

        l = [
            {'params': [self._xyz], 'lr': training_args.position_lr_init * self.spatial_lr_scale, "name": "xyz"},
            {'params': [self._features_dc], 'lr': training_args.feature_lr, "name": "f_dc"},
            {'params': [self._features_rest], 'lr': training_args.feature_lr / 20.0, "name": "f_rest"},
            {'params': [self._opacity], 'lr': training_args.opacity_lr, "name": "opacity"},
            {'params': [self._scaling], 'lr': training_args.scaling_lr, "name": "scaling"},
            {'params': [self._rotation], 'lr': training_args.rotation_lr, "name": "rotation"},
            {'params': [self._ins_feat], 'lr': training_args.ins_feat_lr, "name": "ins_feat"}  # modify -----
        ]

        # NOTE: for the ScanNet 3DGS pre-training stage, freeze the positions of the initial points (no densification).
        if training_args.frozen_init_pts:
            self._xyz = self._xyz.detach()

        self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15)
        self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init*self.spatial_lr_scale,
                                                    lr_final=training_args.position_lr_final*self.spatial_lr_scale,
                                                    lr_delay_mult=training_args.position_lr_delay_mult,
                                                    max_steps=training_args.position_lr_max_steps)

    def update_learning_rate(self, iteration, root_start, leaf_start):
        ''' Learning rate scheduling per step '''
        for param_group in self.optimizer.param_groups:
            if param_group["name"] == "xyz":
                param_group['lr'] = self.xyz_scheduler_args(iteration)
            if param_group["name"] == "ins_feat":
                # Lower lr for instance features during (root_start, leaf_start]. TODO: update lr
                if root_start < iteration <= leaf_start:
                    param_group['lr'] = 0.0001
                else:
                    param_group['lr'] = 0.001

    def construct_list_of_attributes(self):
        l = ['x', 'y', 'z', 'nx', 'ny', 'nz', 'ins_feat_r', 'ins_feat_g', 'ins_feat_b', \
            'ins_feat_r2', 'ins_feat_g2', 'ins_feat_b2']
        # SH color coefficients: the 3 DC channels (f_dc_*), then all remaining channels (f_rest_*)
        for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]):
            l.append('f_dc_{}'.format(i))
        for i in range(self._features_rest.shape[1]*self._features_rest.shape[2]):
            l.append('f_rest_{}'.format(i))
        l.append('opacity')
        for i in range(self._scaling.shape[1]):
            l.append('scale_{}'.format(i))
        for i in range(self._rotation.shape[1]):
            l.append('rot_{}'.format(i))
        return l

    def save_ply(self, path, save_q=[]):
        mkdir_p(os.path.dirname(path))

        xyz = self._xyz.detach().cpu().numpy()
        normals = np.zeros_like(xyz)
        f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
        f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
        opacities = self._opacity.detach().cpu().numpy()
        scale = self._scaling.detach().cpu().numpy()
        rotation = self._rotation.detach().cpu().numpy()
        if "ins_feat" in save_q:
            ins_feat = self._ins_feat_q.detach().cpu().numpy()
        else:
            ins_feat = self._ins_feat.detach().cpu().numpy()

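        # NOTE: pts feat visualization: map the first 3 feature channels from [-1, 1] to RGB in [0, 255] (out-of-range values are clipped)
        vis_color = np.clip((ins_feat + 1) / 2 * 255, 0, 255)
        r, g, b = vis_color[:, 0].reshape(-1, 1), vis_color[:, 1].reshape(-1, 1), vis_color[:, 2].reshape(-1, 1)

        # TODO: some points are not fully optimized because only sampled training images are used;
        # paint low-opacity points (sigmoid(opacity) < 0.1) gray.
        ignored_ind = sigmoid(opacities) < 0.1
        r[ignored_ind], g[ignored_ind], b[ignored_ind] = 128, 128, 128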

        dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()]
        dtype_full = dtype_full + [('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]  # modify

        elements = np.empty(xyz.shape[0], dtype=dtype_full)
        attributes = np.concatenate((xyz, normals, ins_feat,\
                                    f_dc, f_rest, opacities, scale, rotation,\
                                    r, g, b), axis=1)
        elements[:] = list(map(tuple, attributes))
        el = PlyElement.describe(elements, 'vertex')
        PlyData([el]).write(path)

    def reset_opacity(self):
        opacities_new = inverse_sigmoid(torch.min(self.get_opacity, torch.ones_like(self.get_opacity)*0.01))
        optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity")
        self._opacity = optimizable_tensors["opacity"]

    def load_ply(self, path):
        plydata = PlyData.read(path)

        xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
                        np.asarray(plydata.elements[0]["y"]),
                        np.asarray(plydata.elements[0]["z"])),  axis=1)
        ins_feat = np.stack((np.asarray(plydata.elements[0]["ins_feat_r"]),
                        np.asarray(plydata.elements[0]["ins_feat_g"]),
                        np.asarray(plydata.elements[0]["ins_feat_b"]),
                        np.asarray(plydata.elements[0]["ins_feat_r2"]),
                        np.asarray(plydata.elements[0]["ins_feat_g2"]),
                        np.asarray(plydata.elements[0]["ins_feat_b2"])),  axis=1)
        opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]
        if not opacities.flags['C_CONTIGUOUS']:
            opacities = np.ascontiguousarray(opacities)

        features_dc = np.zeros((xyz.shape[0], 3, 1))
        features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
        features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"])
        features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"])

        extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")]
        extra_f_names = sorted(extra_f_names, key = lambda x: int(x.split('_')[-1]))
        assert len(extra_f_names)==3*(self.max_sh_degree + 1) ** 2 - 3
        features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))
        for idx, attr_name in enumerate(extra_f_names):
            features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name])
        # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC)
        features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1))

        scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")]
        scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1]))
        scales = np.zeros((xyz.shape[0], len(scale_names)))
        for idx, attr_name in enumerate(scale_names):
            scales[:, idx] = np.asarray(plydata.elements[0][attr_name])

        rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")]
        rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1]))
        rots = np.zeros((xyz.shape[0], len(rot_names)))
        for idx, attr_name in enumerate(rot_names):
            rots[:, idx] = np.asarray(plydata.elements[0][attr_name])

        self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True))
        self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
        self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
        self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True))
        self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True))
        self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True))
        self._ins_feat = nn.Parameter(torch.tensor(ins_feat, dtype=torch.float, device="cuda").requires_grad_(True))

        self.active_sh_degree = self.max_sh_degree

    def replace_tensor_to_optimizer(self, tensor, name):
        optimizable_tensors = {}
        for group in self.optimizer.param_groups:
            if group["name"] == name:
                stored_state = self.optimizer.state.get(group['params'][0], None)
                stored_state["exp_avg"] = torch.zeros_like(tensor)
                stored_state["exp_avg_sq"] = torch.zeros_like(tensor)

                del self.optimizer.state[group['params'][0]]
                group["params"][0] = nn.Parameter(tensor.requires_grad_(True))
                self.optimizer.state[group['params'][0]] = stored_state

                optimizable_tensors[group["name"]] = group["params"][0]
        return optimizable_tensors

    def _prune_optimizer(self, mask):
        optimizable_tensors = {}
        for group in self.optimizer.param_groups:
            stored_state = self.optimizer.state.get(group['params'][0], None)
            if stored_state is not None:
                stored_state["exp_avg"] = stored_state["exp_avg"][mask]
                stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask]

                del self.optimizer.state[group['params'][0]]
                group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True)))
                self.optimizer.state[group['params'][0]] = stored_state

                optimizable_tensors[group["name"]] = group["params"][0]
            else:
                group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True))
                optimizable_tensors[group["name"]] = group["params"][0]
        return optimizable_tensors

    def prune_points(self, mask):
        valid_points_mask = ~mask
        optimizable_tensors = self._prune_optimizer(valid_points_mask)

        self._xyz = optimizable_tensors["xyz"]
        self._features_dc = optimizable_tensors["f_dc"]
        self._features_rest = optimizable_tensors["f_rest"]
        self._opacity = optimizable_tensors["opacity"]
        self._scaling = optimizable_tensors["scaling"]
        self._rotation = optimizable_tensors["rotation"]
        self._ins_feat = optimizable_tensors["ins_feat"]

        self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask]

        self.denom = self.denom[valid_points_mask]
        self.max_radii2D = self.max_radii2D[valid_points_mask]

    def cat_tensors_to_optimizer(self, tensors_dict):
        optimizable_tensors = {}
        for group in self.optimizer.param_groups:
            assert len(group["params"]) == 1
            extension_tensor = tensors_dict[group["name"]]
            stored_state = self.optimizer.state.get(group['params'][0], None)
            if stored_state is not None:

                stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0)
                stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0)

                del self.optimizer.state[group['params'][0]]
                group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
                self.optimizer.state[group['params'][0]] = stored_state

                optimizable_tensors[group["name"]] = group["params"][0]
            else:
                group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
                optimizable_tensors[group["name"]] = group["params"][0]

        return optimizable_tensors

    def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_opacities, \
                                new_scaling, new_rotation, new_ins_feat):
        d = {"xyz": new_xyz,
        "f_dc": new_features_dc,
        "f_rest": new_features_rest,
        "opacity": new_opacities,
        "scaling" : new_scaling,
        "rotation" : new_rotation,
        "ins_feat": new_ins_feat}

        optimizable_tensors = self.cat_tensors_to_optimizer(d)
        self._xyz = optimizable_tensors["xyz"]
        self._features_dc = optimizable_tensors["f_dc"]
        self._features_rest = optimizable_tensors["f_rest"]
        self._opacity = optimizable_tensors["opacity"]
        self._scaling = optimizable_tensors["scaling"]
        self._rotation = optimizable_tensors["rotation"]
        self._ins_feat = optimizable_tensors["ins_feat"]

        self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
        self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
        self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")

    def densify_and_split(self, grads, grad_threshold, scene_extent, N=2):
        n_init_points = self.get_xyz.shape[0]
        # Extract points that satisfy the gradient condition
        padded_grad = torch.zeros((n_init_points), device="cuda")
        padded_grad[:grads.shape[0]] = grads.squeeze()
        selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False)
        selected_pts_mask = torch.logical_and(selected_pts_mask,
                                              torch.max(self.get_scaling, dim=1).values > self.percent_dense*scene_extent)

        stds = self.get_scaling[selected_pts_mask].repeat(N,1)
        means = torch.zeros((stds.size(0), 3), device="cuda")
        samples = torch.normal(mean=means, std=stds)
        rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N,1,1)
        new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[selected_pts_mask].repeat(N, 1)
        new_scaling = self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(N,1) / (0.8*N))
        new_rotation = self._rotation[selected_pts_mask].repeat(N,1)
        new_features_dc = self._features_dc[selected_pts_mask].repeat(N,1,1)
        new_features_rest = self._features_rest[selected_pts_mask].repeat(N,1,1)
        new_opacity = self._opacity[selected_pts_mask].repeat(N,1)
        new_ins_feat = self._ins_feat[selected_pts_mask].repeat(N,1)

        self.densification_postfix(new_xyz, new_features_dc, new_features_rest, \
            new_opacity, new_scaling, new_rotation, new_ins_feat)

        prune_filter = torch.cat((selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool)))
        self.prune_points(prune_filter)

    def densify_and_clone(self, grads, grad_threshold, scene_extent):
        # Extract points that satisfy the gradient condition
        selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False)
        selected_pts_mask = torch.logical_and(selected_pts_mask,
                                              torch.max(self.get_scaling, dim=1).values <= self.percent_dense*scene_extent)
        
        new_xyz = self._xyz[selected_pts_mask]
        new_features_dc = self._features_dc[selected_pts_mask]
        new_features_rest = self._features_rest[selected_pts_mask]
        new_opacities = self._opacity[selected_pts_mask]
        new_scaling = self._scaling[selected_pts_mask]
        new_rotation = self._rotation[selected_pts_mask]
        new_ins_feat = self._ins_feat[selected_pts_mask]

        self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacities, \
            new_scaling, new_rotation, new_ins_feat)

    def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size):
        grads = self.xyz_gradient_accum / self.denom
        grads[grads.isnan()] = 0.0

        self.densify_and_clone(grads, max_grad, extent)
        self.densify_and_split(grads, max_grad, extent)

        prune_mask = (self.get_opacity < min_opacity).squeeze()
        if max_screen_size:
            big_points_vs = self.max_radii2D > max_screen_size
            big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent
            prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws)
        self.prune_points(prune_mask)

        torch.cuda.empty_cache()

    def add_densification_stats(self, viewspace_point_tensor, update_filter):
        self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[update_filter,:2], dim=-1, keepdim=True)
        self.denom[update_filter] += 1
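

# ---------------------------------------------------------------------------
# Illustrative sketch only (not called by GaussianModel): the clone/split
# decision made in densify_and_prune, restated with NumPy on plain arrays.
# Assumptions: `grad_accum` and `denom` play the role of xyz_gradient_accum and
# denom (shape [N, 1]); `max_scale` plays the role of
# get_scaling.max(dim=1).values (shape [N]). The function name is hypothetical.
# Points whose average view-space gradient reaches `grad_threshold` are cloned
# if they are small (max scale <= percent_dense * extent) and split otherwise.
# ---------------------------------------------------------------------------
def sketch_densify_masks(grad_accum, denom, max_scale, grad_threshold,
                         percent_dense, extent):
    import numpy as np

    grad_accum = np.asarray(grad_accum, dtype=float)
    denom = np.asarray(denom, dtype=float)
    max_scale = np.asarray(max_scale, dtype=float)

    # Average gradient norm per point; points never seen in a view give 0/0 = NaN,
    # which is reset to 0 as in densify_and_prune.
    with np.errstate(invalid="ignore", divide="ignore"):
        grads = (grad_accum / denom).reshape(-1)
    grads[np.isnan(grads)] = 0.0

    high_grad = grads >= grad_threshold
    small = max_scale <= percent_dense * extent
    clone_mask = high_grad & small     # the case handled by densify_and_clone
    split_mask = high_grad & ~small    # the case handled by densify_and_split
    return clone_mask, split_mask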

================================================
FILE: scene/kmeans_quantize.py
================================================
import os
from tqdm import tqdm
import time

import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
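

# Illustrative sketch only (not called by Quantize_kMeans): one Lloyd k-means
# iteration written the way cluster_assign works, i.e. nearest-center
# assignment followed by a center update through a one-hot matrix product
# (one_hot(ids).T @ feat sums the features of each cluster in one matmul).
# NumPy replaces torch here and the function name is hypothetical. Unlike
# cluster_assign, which divides by a small epsilon and therefore sets the
# center of an empty cluster to zero, this sketch keeps its previous center.
def sketch_kmeans_step(feat, centers):
    import numpy as np

    feat = np.asarray(feat, dtype=float)          # [N, dim]
    centers = np.asarray(centers, dtype=float)    # [K, dim]

    # Euclidean distance from every point to every center: [N, K]
    dist = np.linalg.norm(feat[:, None, :] - centers[None, :, :], axis=-1)
    ids = np.argmin(dist, axis=1)                 # nearest center per point: [N]
    one_hot = np.eye(centers.shape[0])[ids]       # [N, K]
    sums = one_hot.T @ feat                       # per-cluster feature sums: [K, dim]
    counts = one_hot.sum(axis=0)                  # points per cluster: [K]

    new_centers = centers.copy()
    nonempty = counts > 0
    new_centers[nonempty] = sums[nonempty] / counts[nonempty, None]
    return new_centers, ids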


class Quantize_kMeans():
    def __init__(self, num_clusters=64, num_leaf_clusters=10, num_iters=10, dim=9, dim_leaf=6):
        self.num_clusters = num_clusters            # k1
        self.leaf_num_clusters = num_leaf_clusters  # k2
        self.num_kmeans_iters = num_iters           # iter
        self.vec_dim = dim                          # coarse-level, dim=9(feat+xyz)
        self.leaf_vec_dim = dim_leaf                # fine-level, dim=6(feat)
        self.centers = torch.empty(0)               # coarse center, [k1, 9]
        self.leaf_centers = torch.empty(0)          # fine center, [k2, 6]
        self.iLeafSubNum = torch.empty(0)           # Number of fine clusters per coarse cluster
        self.cls_ids = torch.empty(0)               # coarse cluster id [num_pts]
        self.leaf_cls_ids = torch.empty(0)          # fine cluster id[num_pts]
        
        self.nn_index = torch.empty(0)              # [num_pts] temporary variable

        # for update_centers
        self.cluster_ids = torch.empty(0)
        self.excl_clusters = []
        self.excl_cluster_ids = []
        self.cluster_len = torch.empty(0)
        self.max_cnt = 0                  
        self.max_cnt_th = 10000
        self.n_excl_cls = 0       

        self.pos_centers = torch.empty(0)           

    def get_dist(self, x, y, mode='sq_euclidean'):
        """Calculate distance between all vectors in x and all vectors in y.

        x: (m, dim)
        y: (n, dim)
        dist: (m, n)
        """
        if mode == 'sq_euclidean_chunk':
            step = 65536
            if x.shape[0] < step:
                step = x.shape[0]
            dist = []
            for i in range(np.ceil(x.shape[0] / step).astype(int)):
                dist.append(torch.cdist(x[(i*step): (i+1)*step, :].unsqueeze(0), y.unsqueeze(0))[0])
            dist = torch.cat(dist, 0)
        elif mode == 'sq_euclidean':
            dist = torch.cdist(x.unsqueeze(0).detach(), y.unsqueeze(0).detach())[0]
        return dist

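    # Update centers in non-cluster-assignment iterations using the cached nn indices.
    # NOTE: the recomputed centers are kept in the local variable `centers` and are not
    # written back to self.centers / self.leaf_centers, so this method currently leaves
    # the stored centers unchanged.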
    def update_centers(self, feat, mode="root", selected_leaf=-1):
        if mode == "root":
            centers = self.centers
            num_clusters = self.num_clusters
            vec_dim = self.vec_dim
        elif mode == "leaf":
            centers = self.leaf_centers
            num_clusters = self.num_clusters * self.leaf_num_clusters + 1
            vec_dim = self.leaf_vec_dim
        feat = feat.detach().reshape(-1, vec_dim)  # [num_pts, vec_dim]
        # Update all clusters except the excluded ones in a single operation
        # Add a dummy element with zeros at the end
        feat = torch.cat([feat, torch.zeros_like(feat[:1]).cuda()], 0)  # [num_pts+1, dim]
        centers = torch.sum(feat[self.cluster_ids, :].reshape(
            num_clusters, self.max_cnt, -1), dim=1)    # [num_clusters, vec_dim]
        if len(self.excl_cluster_ids) > 0:
            for i, cls in enumerate(self.excl_clusters):
                # Division by num_points in cluster is done during the one-shot averaging of all
                # clusters below. Only the extra elements in the bigger clusters are added here.
                centers[cls] += torch.sum(feat[self.excl_cluster_ids[i], :], dim=0)
        centers /= (self.cluster_len + 1e-6)

    # Update centers during cluster assignment using mask matrix multiplication
    # Mask is obtained from distance matrix
    def update_centers_(self, feat, cluster_mask=None, nn_index=None, avg=False):
        # cluster_mask is a one-hot assignment matrix [chunk, num_clusters], so
        # cluster_mask.T @ feat sums the features of each cluster -> [num_clusters, dim].
        # Division by the cluster counts is done by the caller; nn_index and avg are unused.
        centers = cluster_mask.T @ feat
        return centers

    def equalize_cluster_size(self, mode="root"):
        """Make the size of all the clusters the same by appending dummy elements.

        """
        # Find the maximum number of elements in a cluster, make size of all clusters
        # equal by appending dummy elements until size is equal to size of max cluster.
        # If max is too large, exclude it and consider the next biggest. Use for loop for
        # the excluded clusters and a single operation for the remaining ones for
        # updating the cluster centers.

        unq, n_unq = torch.unique(self.nn_index, return_counts=True)
        # Find max cluster size and exclude clusters greater than a threshold
        topk = 100
        if len(n_unq) < topk:
            topk = len(n_unq)
        max_cnt_topk, topk_idx = torch.topk(n_unq, topk)
        self.max_cnt = max_cnt_topk[0]
        idx = 0
        self.excl_clusters = []
        self.excl_cluster_ids = []
        while(self.max_cnt > self.max_cnt_th):
            self.excl_clusters.append(unq[topk_idx[idx]])
            idx += 1
            if idx < topk:
                self.max_cnt = max_cnt_topk[idx]
            else:
                break
        self.n_excl_cls = len(self.excl_clusters)
        self.excl_clusters = sorted(self.excl_clusters)
        # Store the indices of elements for each cluster
        all_ids = []
        cls_len = []
        if mode == "root":
            num_clusters = self.num_clusters
        elif mode == "leaf":
            num_clusters = self.num_clusters * self.leaf_num_clusters + 1
        for i in range(num_clusters):
            cur_cluster_ids = torch.where(self.nn_index == i)[0]
            # For excluded clusters, use only the first max_cnt elements
            # for averaging along with other clusters. Separately average the
            # remaining elements just for the excluded clusters.
            cls_len.append(torch.Tensor([len(cur_cluster_ids)]))
            if i in self.excl_clusters:
                self.excl_cluster_ids.append(cur_cluster_ids[self.max_cnt:])
                cur_cluster_ids = cur_cluster_ids[:self.max_cnt]
            # Append dummy elements to have same size for all clusters
            all_ids.append(torch.cat([cur_cluster_ids, -1 * torch.ones((self.max_cnt - len(cur_cluster_ids)),
                                                                       dtype=torch.long).cuda()]))
        all_ids = torch.cat(all_ids).type(torch.long)
        cls_len = torch.cat(cls_len).type(torch.long)
        self.cluster_ids = all_ids
        self.cluster_len = cls_len.unsqueeze(1).cuda()
        if mode == "root":
            self.cls_ids = self.nn_index
        elif mode == "leaf":
            self.leaf_cls_ids = self.nn_index

    def cluster_assign(self, feat, feat_scaled=None, mode="root", selected_leaf=-1):

        # quantize with kmeans
        feat = feat.detach()    # [N, dim]

        if feat_scaled is None:
            feat_scaled = feat
            scale = feat[0] / (feat_scaled[0] + 1e-8)
        # init. centers and ids
        if len(self.centers) == 0 and mode == "root":
            self.centers = feat[torch.randperm(feat.shape[0])[:self.num_clusters], :]
        if len(self.leaf_centers) == 0 and mode == "leaf":
            # [num_clusters * leaf_num_clusters + 1, dim_leaf], e.g. [641, 6] for k1=64, k2=10
            self.leaf_centers = feat[torch.randperm(feat.shape[0])[:self.num_clusters * self.leaf_num_clusters+1], :]
            self.leaf_cls_ids = torch.ones(feat.shape[0]).to(torch.int64).cuda() * self.num_clusters * self.leaf_num_clusters

        # start kmeans
        chunk = True
        # tmp centers
        if mode == "root":
            tmp_centers = torch.zeros_like(self.centers)
            counts = torch.zeros(self.num_clusters, dtype=torch.float32).cuda() + 1e-6
        elif mode == "leaf":
            tmp_centers = torch.zeros_like(self.leaf_centers)[:self.leaf_num_clusters, :]
            counts = torch.zeros(self.leaf_num_clusters, dtype=torch.float32).cuda() + 1e-6
            start_id = selected_leaf * self.leaf_num_clusters
            end_id = selected_leaf * self.leaf_num_clusters + self.iLeafSubNum[selected_leaf]
        for iteration in range(self.num_kmeans_iters):
            # process points in chunks to bound the memory used by the distance matrix
            if chunk:
                self.nn_index = None
                i = 0
                chunk = 10000
                if mode == "root":
                    while True:
                        dist = self.get_dist(feat[i*chunk:(i+1)*chunk, :], self.centers)
                        curr_nn_index = torch.argmin(dist, dim=-1)  # [chunk]
                        # argmin returns a single index even when several centers are equidistant; one-hot encode it
                        dist = F.one_hot(curr_nn_index, self.num_clusters).type(torch.float32)  # [chunk, num_clusters]
                        curr_centers = self.update_centers_(feat[i*chunk:(i+1)*chunk, :], dist, curr_nn_index, avg=False)   # [num_clusters, dim]
                        counts += dist.detach().sum(0) + 1e-6   # [num_clusters]
                        tmp_centers += curr_centers
                        if self.nn_index is None:
                            self.nn_index = curr_nn_index
                        else:
                            self.nn_index = torch.cat((self.nn_index, curr_nn_index), dim=0)
                        i += 1
                        if i*chunk > feat.shape[0]:
                            break
                elif mode == "leaf":
                    for idx_c in range(self.num_clusters):
                        if idx_c != selected_leaf:
                            continue
                        selected_pts = self.cls_ids == idx_c
                        dist = self.get_dist(feat[selected_pts], self.leaf_centers[start_id:end_id])
                        curr_nn_index = torch.argmin(dist, dim=-1)  # [n_selected]
                        dist = F.one_hot(curr_nn_index, self.leaf_num_clusters).type(torch.float32)  # [n_selected, leaf_num_clusters]
                        curr_centers = self.update_centers_(feat[selected_pts], dist, curr_nn_index, avg=False)   # [leaf_num_clusters, dim_leaf]
                        counts += dist.detach().sum(0) + 1e-6   # [leaf_num_clusters]
                        tmp_centers += curr_centers
                        self.leaf_cls_ids[selected_pts] = curr_nn_index + start_id
            # average centers
            if mode == "root":
                self.centers = tmp_centers / counts.unsqueeze(-1)
            elif mode == "leaf":
                self.leaf_centers[start_id: start_id+self.leaf_num_clusters] = tmp_centers / counts.unsqueeze(-1)
            # Reinitialize to 0
            tmp_centers[tmp_centers != 0] = 0.
            counts[counts > 0.1] = 0.

        # Reassign ID according to the new centers
        if chunk:
            self.nn_index = None
            i = 0
            # chunk = 100000
            if mode == "root":
                while True:
                    dist = self.get_dist(feat_scaled[i * chunk:(i + 1) * chunk, :], self.centers)
                    curr_nn_index = torch.argmin(dist, dim=-1)
                    if self.nn_index is None:
                        self.nn_index = curr_nn_index
                    else:
                        self.nn_index = torch.cat((self.nn_index, curr_nn_index), dim=0)
                    i += 1
                    if i * chunk > feat.shape[0]:
                        break
            elif mode == "leaf":
                for idx_c in range(self.num_clusters):
                    if idx_c != selected_leaf:
                        continue
                    selected_pts = self.cls_ids == idx_c
                    dist = self.get_dist(feat[selected_pts], self.leaf_centers[start_id:end_id])
                    curr_nn_index = torch.argmin(dist, dim=-1)
                    self.leaf_cls_ids[selected_pts] = curr_nn_index + start_id
                self.nn_index = self.leaf_cls_ids
        self.equalize_cluster_size(mode=mode)

    def rescale(self, feat, scale=None):
        """Scale each feature dimension to [-1, 1] by dividing it by its maximum absolute
        value (or by `scale`, if one is given).

        """
        if scale is None:
            return feat / (abs(feat).max(dim=0)[0] + 1e-8)
        else:
            return feat / (scale + 1e-8)

    def forward(self, gaussian, iteration, assign=False, mode="root", selected_leaf=-1, pos_weight=1.0):
        if mode == "root":
            # (1) coarse-level: feature + xyz
            scale = pos_weight  # weight of the xyz term relative to ins_feat (TODO)
            xyz_feat = gaussian._xyz.detach() * scale
            feat = torch.cat((gaussian._ins_feat, xyz_feat), dim=1)    # [N, 9]
        elif mode == "leaf":
            # (2) fine-level: feature only
            feat = gaussian._ins_feat

        if assign:
            self.cluster_assign(feat, mode=mode, selected_leaf=selected_leaf)   # gaussian._ins_feat
        else:
            self.update_centers(feat, mode=mode, selected_leaf=selected_leaf)   # gaussian._ins_feat

        if mode == "root":
            centers = self.centers
            vec_dim = self.vec_dim
        elif mode == "leaf":
            centers = self.leaf_centers
            vec_dim = self.leaf_vec_dim
        sampled_centers = torch.gather(centers, 0, self.nn_index.unsqueeze(-1).repeat(1, vec_dim))
        # NOTE: "During backpropagation, the gradients of the quantized features are copied to the instance features", mentioned in the paper.
        gaussian._ins_feat_q = gaussian._ins_feat - gaussian._ins_feat.detach() + sampled_centers[:,:6]

    def replace_with_centers(self, gaussian):
        deg = gaussian._features_rest.shape[1]
        sampled_centers = torch.gather(self.centers, 0, self.nn_index.unsqueeze(-1).repeat(1, self.vec_dim))
        gaussian._features_rest = gaussian._features_rest - gaussian._features_rest.detach() + sampled_centers.reshape(-1, deg, 3)
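

# Illustrative sketch only (hypothetical function name, NumPy instead of torch):
# the padding trick used by equalize_cluster_size and update_centers. Each
# cluster's member indices are padded with -1 up to the largest cluster size,
# and a zero row is appended to the features so that index -1 selects that
# zero row. One gather + reshape + sum then gives per-cluster sums without a
# Python loop over points. The max_cnt_th exclusion of oversized clusters is
# omitted here for brevity.
def sketch_padded_cluster_means(feat, ids, num_clusters):
    import numpy as np

    feat = np.asarray(feat, dtype=float)     # [N, dim]
    ids = np.asarray(ids)                    # cluster id per point: [N]
    members = [np.flatnonzero(ids == k) for k in range(num_clusters)]
    max_cnt = max(len(m) for m in members)
    padded = np.concatenate([
        np.concatenate([m, -np.ones(max_cnt - len(m), dtype=np.int64)])
        for m in members
    ])                                       # [num_clusters * max_cnt]
    feat_ext = np.concatenate([feat, np.zeros((1, feat.shape[1]))], axis=0)
    sums = feat_ext[padded].reshape(num_clusters, max_cnt, -1).sum(axis=1)
    counts = np.array([len(m) for m in members], dtype=float)[:, None]
    return sums / (counts + 1e-6)            # same epsilon as update_centers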


================================================
FILE: scripts/compute_lerf_iou.py
================================================
import os
import numpy as np
from PIL import Image
from argparse import ArgumentParser

def load_image_as_binary(image_path, is_png=False, threshold=10):
    image = Image.open(image_path)
    if is_png:
        image = image.convert('L')
    image_array = np.array(image)
    binary_image = (image_array > threshold).astype(int)
    return binary_image

def calculate_iou(mask1, mask2):
    intersection = np.logical_and(mask1, mask2).sum()
    union = np.logical_or(mask1, mask2).sum()
    if union == 0:
        return 0
    return intersection / union

def evalute(gt_base, pred_base, scene_name):
    scene_gt_frames = {
        "waldo_kitchen": ["frame_00053", "frame_00066", "frame_00089", "frame_00140", "frame_00154"],
        "ramen": ["frame_00006", "frame_00024", "frame_00060", "frame_00065", "frame_00081", "frame_00119", "frame_00128"],
        "figurines": ["frame_00041", "frame_00105", "frame_00152", "frame_00195"],
        "teatime": ["frame_00002", "frame_00025", "frame_00043", "frame_00107", "frame_00129", "frame_00140"]
    }
    frame_names = scene_gt_frames[scene_name]

    ious = []
    for frame in frame_names:
        print("frame:", frame)
        gt_folder = os.path.join(gt_base, frame)
        file_names = [f for f in os.listdir(gt_folder) if f.endswith('.jpg')]
        for file_name in file_names:
            base_name = os.path.splitext(file_name)[0]
            gt_obj_path = os.path.join(gt_folder, file_name)
            pred_obj_path = os.path.join(pred_base, frame + "_" + base_name + '.png')
            if not os.path.exists(pred_obj_path):
                # A missing prediction counts as a failed query (IoU = 0).
                print(f"Missing pred file for {file_name}; counting IoU as 0")
                ious.append(0.0)
                continue
            mask_gt = load_image_as_binary(gt_obj_path)
            mask_pred = load_image_as_binary(pred_obj_path, is_png=True)
            iou = calculate_iou(mask_gt, mask_pred)
            ious.append(iou)
            print(f"IoU for {file_name} and {os.path.basename(pred_obj_path)}: {iou:.4f}")
    
    # Accuracy: fraction of objects whose IoU exceeds 0.25 / 0.5
    total_count = len(ious)
    count_iou_025 = (np.array(ious) > 0.25).sum()
    count_iou_05 = (np.array(ious) > 0.5).sum()

    # mIoU
    average_iou = np.mean(ious)
    print(f"Average IoU: {average_iou:.4f}")
    print(f"Acc@0.25: {count_iou_025/total_count:.4f}")
    print(f"Acc@0.5: {count_iou_05/total_count:.4f}")
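

# Illustrative sketch only (hypothetical helper, not called by evalute): the
# summary printed above as a reusable function. It returns the mean IoU and,
# for each threshold, the fraction of objects whose IoU is strictly greater
# than that threshold (the same ">" comparison used for Acc@0.25 / Acc@0.5).
# An empty list yields NaN instead of a division by zero.
def summarize_ious(ious, thresholds=(0.25, 0.5)):
    import numpy as np

    ious = np.asarray(ious, dtype=float)
    if ious.size == 0:
        return float("nan"), {t: float("nan") for t in thresholds}
    accs = {t: float((ious > t).mean()) for t in thresholds}
    return float(ious.mean()), accs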

if __name__ == "__main__":
    parser = ArgumentParser("Compute LeRF IoU")
    parser.add_argument("--scene_name", type=str, required=True,
                        choices=["waldo_kitchen", "ramen", "figurines", "teatime"],
                        help="Scene to evaluate: figurines, teatime, ramen or waldo_kitchen")
    args = parser.parse_args()

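    # TODO: change these paths to your local data and output directories.
    # Ground-truth masks of the selected scene
    path_gt = os.path.join("/gdata/cold1/wuyanmin/OpenGaussian/data/lerf_ovs/label", args.scene_name, "gt")
    # renders_cluster_silhouette holds the predicted masks
    path_pred = "output/xxxxxxxx-x/text2obj/ours_70000/renders_cluster_silhouette"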
    evalute(path_gt, path_pred, args.scene_name)

================================================
FILE: scripts/eval_scannet.py
================================================
import os
from plyfile import PlyData, PlyElement
import torch.nn.functional as F
import numpy as np
import torch
import json

nyu40_dict = {
    0: "unlabeled", 1: "wall", 2: "floor", 3: "cabinet", 4: "bed", 5: "chair",
    6: "sofa", 7: "table", 8: "door", 9: "window", 10: "bookshelf",
    11: "picture", 12: "counter", 13: "blinds", 14: "desk", 15: "shelves",
    16: "curtain", 17: "dresser", 18: "pillow", 19: "mirror", 20: "floormat",
    21: "clothes", 22: "ceiling", 23: "books", 24: "refrigerator", 25: "television",
    26: "paper", 27: "towel", 28: "showercurtain", 29: "box", 30: "whiteboard",
    31: "person", 32: "nightstand", 33: "toilet", 34: "sink", 35: "lamp",
    36: "bathtub", 37: "bag", 38: "otherstructure", 39: "otherfurniture", 40: "otherprop"
}

# ScanNet 19 classes (the 20-class ScanNet benchmark set without "otherfurniture")
scannet19_dict = {
    1: "wall", 2: "floor", 3: "cabinet", 4: "bed", 5: "chair",
    6: "sofa", 7: "table", 8: "door", 9: "window", 10: "bookshelf",
    11: "picture", 12: "counter", 14: "desk", 16: "curtain",
    24: "refrigerator", 28: "shower curtain", 33: "toilet", 34: "sink",
    36: "bathtub", # 39: "otherfurniture"
}

def sigmoid(x):  
    return 1 / (1 + np.exp(-x))  

def write_ply(vertex_data, output_path):
    vertices = []
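    for vertex in vertex_data:
        # Map the first three ins_feat channels from [-1, 1] to [0, 255]; clip so that
        # out-of-range features do not wrap around when stored as uint8.
        r = int(np.clip((vertex['ins_feat_r'] + 1) / 2 * 255, 0, 255))
        g = int(np.clip((vertex['ins_feat_g'] + 1) / 2 * 255, 0, 255))
        b = int(np.clip((vertex['ins_feat_b'] + 1) / 2 * 255, 0, 255))
        new_vertex = (vertex['x'], vertex['y'], vertex['z'], r, g, b)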
        vertices.append(new_vertex)
    
    vertex_dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]
    new_vertex_data = np.array(vertices, dtype=vertex_dtype)
    
    el = PlyElement.describe(new_vertex_data, 'vertex')
    PlyData([el], text=True).write(output_path)
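

# Illustrative sketch only (hypothetical helper, not called in this script): a
# vectorized version of the per-vertex color mapping in write_ply. The first
# three instance-feature channels, assumed to lie roughly in [-1, 1], are
# mapped to [0, 255] and clipped so out-of-range values cannot wrap around
# when cast to uint8 (the cast truncates toward zero).
def ins_feat_to_rgb(feat_rgb):
    import numpy as np

    feat_rgb = np.asarray(feat_rgb, dtype=np.float32)   # [N, 3]
    rgb = np.clip((feat_rgb + 1.0) / 2.0 * 255.0, 0.0, 255.0)
    return rgb.astype(np.uint8)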

def read_labels_from_ply(file_path):
    ply_data = PlyData.read(file_path)
    vertex_data = ply_data['vertex'].data
    # Extract the coordinates and labels of the points. The labels are from 1 to 40 for the NYU40 dataset, with 0 being invalid.
    points = np.vstack([vertex_data['x'], vertex_data['y'], vertex_data['z']]).T
    labels = vertex_data['label']
    return points, labels
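

# Illustrative sketch only (hypothetical helper, not called below): the
# NYU40 -> contiguous-label remapping done with a loop in the __main__ block,
# expressed as a NumPy lookup table. Label k in target_id becomes its 1-based
# position in target_id; every other NYU40 label (including 0) becomes 0.
# Assumes labels are integers in [0, num_nyu_labels - 1].
def remap_labels(labels, target_id, num_nyu_labels=41):
    import numpy as np

    lut = np.zeros(num_nyu_labels, dtype=np.int64)
    for new_id, nyu_id in enumerate(target_id, start=1):
        lut[nyu_id] = new_id
    return lut[np.asarray(labels)]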

def calculate_metrics(gt, pred, total_classes):
    gt = gt.cpu()
    pred = pred.cpu().clone()  # copy so the caller's tensor is not modified below
    pred[gt == 0] = 0

    ious = torch.zeros(total_classes)

    intersection = torch.zeros(total_classes)
    union = torch.zeros(total_classes)
    correct = torch.zeros(total_classes)
    total = torch.zeros(total_classes)

    for cls in range(1, total_classes):
        intersection[cls] = torch.sum((gt == cls) & (pred == cls)).item()
        union[cls] = torch.sum((gt == cls) | (pred == cls)).item()
        correct[cls] = torch.sum((gt == cls) & (pred == cls)).item()
        total[cls] = torch.sum(gt == cls).item()

    valid_union = union != 0
    ious[valid_union] = intersection[valid_union] / union[valid_union]

    # Only consider the categories that exist in the current scene
    gt_classes = torch.unique(gt)
    valid_gt_classes = gt_classes[gt_classes != 0]  # ignore 0

    # miou
    mean_iou = ious[valid_gt_classes].mean().item()

    # acc
    valid_mask = gt != 0
    correct_predictions = torch.sum((gt == pred) & valid_mask).item()
    total_valid_points = torch.sum(valid_mask).item()
    accuracy = correct_predictions / total_valid_points if total_valid_points > 0 else float('nan')

    class_accuracy = correct / total
    # mAcc.
    mean_class_accuracy = class_accuracy[valid_gt_classes].mean().item()

    return ious, mean_iou, accuracy, mean_class_accuracy
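
# Sketch (hypothetical helper, NumPy only): the same mIoU / accuracy / mAcc as
# calculate_metrics, computed from a single confusion matrix built with np.bincount.
# As in calculate_metrics, label 0 is ignored, and only classes present in the
# ground truth contribute to the means. It assumes every predicted label lies in
# [0, total_classes).
import numpy as np


def confusion_metrics_sketch(gt, pred, total_classes):
    """Return (ious, mean_iou, accuracy, mean_class_accuracy) like calculate_metrics."""
    gt = np.asarray(gt, dtype=np.int64)
    pred = np.asarray(pred, dtype=np.int64).copy()
    pred[gt == 0] = 0                                   # unlabeled points never count
    conf = np.bincount(gt * total_classes + pred,
                       minlength=total_classes ** 2).reshape(total_classes, total_classes)
    tp = np.diag(conf).astype(np.float64)               # rows: gt class, cols: predicted class
    gt_count = conf.sum(axis=1).astype(np.float64)
    union = gt_count + conf.sum(axis=0) - tp
    ious = np.divide(tp, union, out=np.zeros_like(tp), where=union > 0)
    ious[0] = 0.0                                       # class 0 is "ignore", as in calculate_metrics
    present = np.unique(gt[gt != 0])                    # classes that occur in the GT
    valid = gt != 0
    accuracy = float(np.mean(gt[valid] == pred[valid])) if valid.any() else float("nan")
    mean_iou = float(ious[present].mean())
    mean_class_accuracy = float(np.mean(tp[present] / gt_count[present]))
    return ious, mean_iou, accuracy, mean_class_accuracy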

if __name__ == "__main__":
    scene_list = [  'scene0000_00', 'scene0062_00', 'scene0070_00', 'scene0097_00', 'scene0140_00', 
                    'scene0200_00', 'scene0347_00', 'scene0400_00', 'scene0590_00', 'scene0645_00']

    iteration = 90000
    for scan_name in scene_list:
        # (1) Ground-truth PLY. NOTE: change this path to your local ScanNet data.
        gt_file_path = f"/gdata/cold1/wuyanmin/OpenGaussian/data/scannet_2d_3types/{scan_name}/{scan_name}_vh_clean_2.labels.ply"
        points, labels = read_labels_from_ply(gt_file_path)

        # (2) Choose the evaluated class subset: 19, 15, or 10 classes (uncomment one target_id line).
        # Given the category ID that needs to be queried (relative to the original NYU40), obtain the corresponding category name.
        target_id = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36]   # 19
        # target_id = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 14, 16, 33, 34]   # 15
        # target_id = [1,2,4,5,6,7,8,9,10,33] # 10

        target_dict = {key: nyu40_dict[key] for key in target_id}
        target_names = list(target_dict.values())

        # (3) update gt label
        # Remap the point-cloud labels. With 19 categories, updated_labels takes the values 0 (ignored) and 1-19.
        target_id_mapping = {value: index + 1 for index, value in enumerate(target_id)}
        updated_labels = np.zeros_like(labels)
        for original_value, new_value in target_id_mapping.items():
            updated_labels[labels == original_value] = new_value
        updated_gt_labels = torch.from_numpy(updated_labels.astype(np.int64)).cuda()
        
        # (4) load gaussian ply file
        model_path = f"output/{scan_name}/"
        ply_path = os.path.join(model_path, f"point_cloud/iteration_{iteration}/point_cloud.ply")
        ply_data = PlyData.read(ply_path)
        vertex_data = ply_data['vertex'].data
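        # NOTE: Ignore points whose activated opacity is below 0.1. Indexing the GT labels with this
        # per-Gaussian mask assumes the Gaussians match the GT vertices one-to-one (--frozen_init_pts).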
        ignored_pts = sigmoid(vertex_data["opacity"]) < 0.1
        updated_gt_labels[ignored_pts] = 0

        # (5) load cluster language file
        mapping_file = os.path.join(model_path, "cluster_lang.npz")
        # Load the saved leaf assignments (codebook leaf ids) and instance-level language features.
        # Keys (stored with a .npy suffix): 'leaf_feat', 'leaf_score', 'occu_count', 'leaf_ind'
        saved_data = np.load(mapping_file)
        leaf_lang_feat = torch.from_numpy(saved_data["leaf_feat.npy"]).cuda()    # [num_leaf=k1*k2, 512] 
        leaf_score = torch.from_numpy(saved_data["leaf_score.npy"]).cuda()       # [num_leaf=k1*k2] 
        leaf_occu_count = torch.from_numpy(saved_data["occu_count.npy"]).cuda()  # [num_leaf=k1*k2] 
        leaf_ind = torch.from_numpy(saved_data["leaf_ind.npy"]).cuda()           # [num_pts] 
        leaf_lang_feat[leaf_occu_count < 2] *= 0.0   # ignore leaves that occur fewer than 2 times
        leaf_ind = leaf_ind.clamp(max=319)  # k1*k2 = 64*5 = 320 leaves, so valid leaf ids are 0..319

        # (6) Load the precomputed text features of the query class names.
        with open('assets/text_features.json', 'r') as f:
            data_loaded = json.load(f)
        all_texts = list(data_loaded.keys())
        text_features = torch.from_numpy(np.array(list(data_loaded.values()))).to(torch.float32)  # [num_text, 512]
        
        query_text_feats = torch.zeros(len(target_names), 512).cuda()
        for i, text in enumerate(target_names):
            feat = text_features[all_texts.index(text)].unsqueeze(0)
            query_text_feats[i] = feat

        # (7) Calculate the cosine similarity and return the ID of the category with the highest value.
        query_text_feats = F.normalize(query_text_feats, dim=1, p=2)  
        leaf_lang_feat = F.normalize(leaf_lang_feat, dim=1, p=2)  
        cosine_similarity = torch.matmul(query_text_feats, leaf_lang_feat.transpose(0, 1))
        # cosine_similarity = torch.mm(query_text_feats, leaf_lang_feat.t())   # [cls_num, cluster_num]
        max_id = torch.argmax(cosine_similarity, dim=0) # [cluster_num]
        pred_pts_cls_id = max_id[leaf_ind] + 1          # [num_pts] 

        ious, mean_iou, accuracy, mean_acc = calculate_metrics(updated_gt_labels, pred_pts_cls_id, total_classes=len(target_names)+1)
        print(f"Scene: {scan_name}, mIoU: {mean_iou:.4f}, mAcc.: {mean_acc:.4f}") 
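

# Sketch (hypothetical helper, NumPy only) of step (7) above. It L2-normalizes the
# query text features and the per-leaf language features, computes cosine
# similarities, and gives every leaf the class with the highest similarity. That
# label is then broadcast to the points through their leaf index, with +1 so that 0
# stays "ignore". The eps guard mirrors the default of torch.nn.functional.normalize.
# Zeroed leaves therefore score 0 for every class and fall back to class 1, as in
# the loop above.
import numpy as np


def classify_points_by_text_sketch(text_feats, leaf_feats, leaf_ind, eps=1e-12):
    text = np.asarray(text_feats, dtype=np.float64)
    leaf = np.asarray(leaf_feats, dtype=np.float64)
    text = text / np.maximum(np.linalg.norm(text, axis=1, keepdims=True), eps)
    leaf = leaf / np.maximum(np.linalg.norm(leaf, axis=1, keepdims=True), eps)
    similarity = text @ leaf.T                   # [num_classes, num_leaf]
    leaf_class = np.argmax(similarity, axis=0)   # [num_leaf], 0-based class id
    return leaf_class[np.asarray(leaf_ind)] + 1  # [num_pts], 1-based class id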

================================================
FILE: scripts/render_by_click.py
================================================
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use 
# under the terms of the LICENSE.md file.
#
# For inquiries contact  george.drettakis@inria.fr
#

import torch
import torch.nn.functional as F
from scene import Scene
import os
from tqdm import tqdm
from os import makedirs
from gaussian_renderer import render
import torchvision
from utils.general_utils import safe_state
from argparse import ArgumentParser
from arguments import ModelParams, PipelineParams, get_combined_args
from gaussian_renderer import GaussianModel
import numpy as np
from PIL import Image
import json
from utils.opengs_utlis import mask_feature_mean, get_SAM_mask_and_feat, load_code_book
import pytorch3d.ops

np.random.seed(42)
colors_defined = np.random.randint(100, 256, size=(300, 3))
colors_defined[0] = np.array([0, 0, 0])
colors_defined = torch.from_numpy(colors_defined)

def get_pixel_values(image_path, position, radius=10):
    with Image.open(image_path) as img:
        img = img.convert('RGB')
        width, height = img.size
        
        left = max(position[0] - radius, 0)
        right = min(position[0] + radius + 1, width)
        top = max(position[1] - radius, 0)
        bottom = min(position[1] + radius + 1, height)

        pixels = []
        for x in range(left, right):
            for y in range(top, bottom):
                pixels.append(img.getpixel((x, y)))

        pixels_array = np.array(pixels)
        mean_pixel = pixels_array.mean(axis=0)
    
    return tuple(mean_pixel)
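
# Sketch (hypothetical helper): the same window average as get_pixel_values, but
# computed with NumPy slicing on an H x W x 3 array instead of per-pixel getpixel()
# calls. `position` is (x, y), i.e. (column, row), as in PIL. A possible call is
# window_mean_rgb_sketch(np.asarray(Image.open(path).convert("RGB")), pix_xy, 5).
import numpy as np


def window_mean_rgb_sketch(img_array, position, radius=10):
    img_array = np.asarray(img_array, dtype=np.float64)
    height, width = img_array.shape[:2]
    x, y = position
    left, right = max(x - radius, 0), min(x + radius + 1, width)
    top, bottom = max(y - radius, 0), min(y + radius + 1, height)
    window = img_array[top:bottom, left:right, :3]       # rows are y, columns are x
    return tuple(window.reshape(-1, window.shape[-1]).mean(axis=0))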

def compute_click_values(model_path, image_name, pix_xy, radius=5):
    def compute_level_click_val(iteration, model_path, image_name, pix_xy, radius):
        # TODO: adjust these paths to where the two 3-channel instance-feature renders are saved
        img_path1 = f"{model_path}/train/ours_{iteration}/renders_ins_feat1/{image_name}_1.png"
        img_path2 = f"{model_path}/train/ours_{iteration}/renders_ins_feat2/{image_name}_2.png"
        val1 = get_pixel_values(img_path1, pix_xy, radius)
        val2 = get_pixel_values(img_path2, pix_xy, radius)
        # Concatenate both RGB means into a 6-D feature and map [0, 255] back to [-1, 1]
        click_val = (torch.tensor(list(val1) + list(val2)) / 255) * 2 - 1
        return click_val
    
    # TODO: iterations whose renders encode the coarse (root) and fine (leaf) levels;
    # 50k and 70k match the stage boundaries in scripts/train_lerf.sh
    level1_click_val = compute_level_click_val(50000, model_path, image_name, pix_xy, radius)
    level2_click_val = compute_level_click_val(70000, model_path, image_name, pix_xy, radius)
    
    return level1_click_val, level2_click_val

def render_set(model_path, name, iteration, views, gaussians, pipeline, background):
    render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders")
    gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt")

    render_ins_feat_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders_ins_feat")
    gt_sam_mask_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt_sam_mask")
    pseudo_ins_feat_path = os.path.join(model_path, name, "ours_{}".format(iteration), "pseudo_ins_feat")

    makedirs(render_path, exist_ok=True)
    makedirs(gts_path, exist_ok=True)
    makedirs(render_ins_feat_path, exist_ok=True)
    makedirs(gt_sam_mask_path, exist_ok=True)
    makedirs(pseudo_ins_feat_path, exist_ok=True)

    # load codebook
    root_code_book, root_cluster_indices = load_code_book(os.path.join(model_path, "point_cloud", \
        f'iteration_{iteration}', "root_code_book"))
    leaf_code_book, leaf_cluster_indices = load_code_book(os.path.join(model_path, "point_cloud", \
        f'iteration_{iteration}', "leaf_code_book"))
    root_cluster_indices = torch.from_numpy(root_cluster_indices).cuda()
    leaf_cluster_indices = torch.from_numpy(leaf_cluster_indices).cuda()
    # counts = torch.bincount(torch.from_numpy(cluster_indices), minlength=64)

    # Load the saved leaf assignments (codebook leaf ids) and instance-level language features.
    # Keys (stored with a .npy suffix): 'leaf_feat', 'leaf_score', 'occu_count', 'leaf_ind'
    mapping_file = os.path.join(model_path, "cluster_lang.npz")
    saved_data = np.load(mapping_file)
    leaf_lang_feat = torch.from_numpy(saved_data["leaf_feat.npy"]).cuda()    # [num_leaf=640, 512] Language feature of each instance
    leaf_score = torch.from_numpy(saved_data["leaf_score.npy"]).cuda()       # [num_leaf=640] Score of each instance
    leaf_occu_count = torch.from_numpy(saved_data["occu_count.npy"]).cuda()  # [num_leaf=640] Number of occurrences of each instance
    leaf_ind = torch.from_numpy(saved_data["leaf_ind.npy"]).cuda()           # [num_pts] Instance ID corresponding to each point
    leaf_lang_feat[leaf_occu_count < 5] *= 0.0      # ignore leaves that occur fewer than 5 times
    leaf_cluster_indices = leaf_ind                 # use the saved per-point leaf ids from here on
    
    image_name = "frame_00002"      # TODO: frame in which the click positions below were picked
    # # object_name = "apple"
    # pix_xy = (450, 217) # bag of cookies
    # pix_xy = (344, 350) # apple
    # # teatime       image_name = "frame_00002"
    # object_names = ["bear nose", "stuffed bear", "sheep", "bag of cookies", \
    #                 "plate", "three cookies", "tea in a glass", "apple", \
    #                 "coffee mug", "coffee", "paper napkin"]
    # pix_xy_list = [ (740, 80), (800, 160), (80, 240), (450, 200),
    #                 (468, 288), (438, 273), (309, 308), (343, 361),
    #                 (578, 274), (571, 260), (565, 380)]
    # figurines   image_name = "frame_00002"
    # TODO: object names for the current scene (used only in the output file names)
    object_names = ["rubber duck with buoy", "porcelain hand", "miffy", "toy elephant", "toy cat statue", \
                    "jake", "Play-Doh bucket", "rubber duck with hat", "rubics cube", "waldo", \
                    "twizzlers", "red toy chair", "green toy chair", "pink ice cream", "spatula", \
                    "pikachu", "green apple", "rabbit", "old camera", "pumpkin", \
                    "tesla door handle"]
    # TODO: click positions (x, y) in image_name, one per object name above
    pix_xy_list = [ (103, 378), (552, 390), (896, 342), (720, 257), (254, 297),
                    (451, 197), (626, 256), (760, 166), (781, 243), (896, 136),
                    (927, 241), (688, 148), (538, 160), (565, 238), (575, 257),
                    (377, 156), (156, 244), (21, 237), (283, 152), (330, 200),
                    (514, 200)]
    # # ramen           image_name = "frame_00002"
    # object_names = ["clouth", "sake cup", "chopsticks", "spoon", "plate", \
    #                 "bowl", "egg", "nori", "glass of water", "napkin"]
    # pix_xy_list = [(345, 38), (276, 424), (361, 370), (419, 285), (688, 412),
    #                (489, 119), (694, 187), (810, 154), (939, 289), (428, 462)]
    # # waldo_kitchen     image_name = "frame_00001"
    # object_names = ["knife", "pour-over vessel", "glass pot1", "glass pot2", "toaster", \
    #                 "hot water pot", "metal can", "cabinet", "ottolenghi", "waldo"]
    # pix_xy_list = [(439, 76), (410, 297), (306, 127), (349, 182), (261, 256),
    #                (201, 262), (161, 267), (80, 34), (17, 141), (76, 169)]

    for o_i, object in enumerate(object_names):
        pix_xy = pix_xy_list[o_i]
        root_click_val, leaf_click_val = compute_click_values(model_path, image_name, pix_xy)
    
        # Compute the nearest clusters with respect to the two-level codebook
        distances_root = torch.norm(root_click_val - root_code_book["ins_feat"][:, :-3].cpu(), dim=1)
        distances_leaf = torch.norm(leaf_click_val - leaf_code_book["ins_feat"][:-1, :].cpu(), dim=1)
        distances_leaf[leaf_code_book["ins_feat"][:-1].sum(-1) == 0] = 999  # Assign a large distance to unassigned (all-zero) leaf nodes
        
        # Restrict the search to the leaf (child) nodes of the nearest root node
        min_index_root = torch.argmin(distances_root).item()
        leaf_num = (leaf_code_book["ins_feat"].shape[0] - 1) / root_code_book["ins_feat"].shape[0]
        start_id = int(min_index_root*leaf_num)
        end_id = int((min_index_root + 1)*leaf_num)
        distances_leaf_sub = distances_leaf[start_id: end_id]   # [10]

        # # (1) Choose several child nodes that fulfill the requirements
        # click_leaf_indices = torch.nonzero(distances_leaf_sub < 0.9).squeeze() + start_id
        # if (click_leaf_indices.dim() == 0) and click_leaf_indices.numel() != 0:
        #     click_leaf_indices = click_leaf_indices.unsqueeze(0) 
        # elif click_leaf_indices.numel() == 0:
        #     click_leaf_indices = torch.argmin(distances_leaf_sub).unsqueeze(0)
        # (2) find the nearest root-level node, then pick the closest leaf node under it (preferred)
        click_leaf_indices = torch.argmin(distances_leaf_sub).unsqueeze(0) + start_id
        # (3) directly select the child node with the minimum distance (less precise)
        # click_leaf_indices = torch.argmin(distances_leaf).unsqueeze(0)
        # # (4) you can also directly specify a particular child node if needed
        # click_leaf_indices = torch.tensor([60, 66])     # 64 pikachu, 60, 66 toy elephant, 65 jake, 633 green apple, 639 duck
        
        # Mask of the Gaussians assigned to the selected leaf node(s)
        pre_pts_mask = (leaf_cluster_indices.unsqueeze(1) == click_leaf_indices.cuda()).any(dim=1)

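        # Post-processing: drop spatial outliers from the selection using k-nearest-neighbor
        # distance statistics. NOTE: this is an `if`, so the filter runs once; max_time is
        # decremented but never re-checked.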
        post_process = True
        max_time = 5
        if post_process and max_time > 0:
            nearest_k_distance = pytorch3d.ops.knn_points(
                gaussians._xyz[pre_pts_mask].unsqueeze(0),
                gaussians._xyz[pre_pts_mask].unsqueeze(0),
                K=int(pre_pts_mask.sum()**0.5) * 2,
            ).dists     # squared distances, shape [1, N, K]
            mean_nearest_k_distance, std_nearest_k_distance = nearest_k_distance.mean(), nearest_k_distance.std()
            # print(std_nearest_k_distance, "std_nearest_k_distance")

            # mask = nearest_k_distance.mean(dim = -1) < mean_nearest_k_distance + std_nearest_k_distance
            mask = nearest_k_distance.mean(dim = -1) < mean_nearest_k_distance + 0.1 * std_nearest_k_distance
            # mask = nearest_k_distance.mean(dim = -1) < 2 * mean_nearest_k_distance 

            mask = mask.squeeze()
            if pre_pts_mask is not None:
                pre_pts_mask[pre_pts_mask != 0] = mask
            max_time -= 1

        # out_dir = "ca9c2998-e"
        # splits = ["train", "train", "train", "train", "test"]
        # frame_name_list = ["frame_00053", "frame_00066", "frame_00140", "frame_00154", "frame_00089"]
        # for f_i, frame_name in enumerate(frame_name_list):
        #     base_path = f"/mnt/disk1/codes/wuyanmin/code/OpenGaussian/output/{out_dir}/{splits[f_i]}/ours_70000/renders_cluster_silhouette"
        #     target_path = f"/mnt/disk1/codes/wuyanmin/code/OpenGaussian/output/{out_dir}/{splits[f_i]}/ours_70000/result/{frame_name}"
        #     makedirs(target_path, exist_ok=True)
        #     for _, text in enumerate(waldo_kitchen_texts):
        #         pos_feat = text_features[query_texts.index(text)].unsqueeze(0)
        #         similarity_pos = F.cosine_similarity(pos_feat, leaf_lang_feat.cpu())    # [640]
        #         top_values, top_indices = torch.topk(similarity_pos, 10)   # [num_mask]
        #         print("text: {} | cluster id: {}".format(text, top_indices[0]))
        #         ori_img_name = base_path + f"/{frame_name}_cluster_{top_indices[0].item()}.png"
        #         new_name = target_path + f"/{text}.png"
                
        #         if not os.path.exists(ori_img_name):
        #             top = 10
        #             for i in range(top):
        #                 ori_img_name = target_path + f"/{frame_name}_cluster_{top_indices[i].item()}.png"
        #                 if os.path.exists(ori_img_name):
        #                     break
        #         if not os.path.exists(ori_img_name):
        #             print(f"No file found at {ori_img_name}. Operation skipped.")
        #             continue
        #         import shutil
        #         shutil.copy2(ori_img_name, new_name)

        # render
        for idx, view in enumerate(tqdm(views, desc="Rendering progress")):
            # render_pkg = render(view, gaussians, pipeline, background, iteration, rescale=False)
            
            # # figurines
            # if  view.image_name not in ["frame_00041", "frame_00105", "frame_00152", "frame_00195"]:
            #     continue
            # # teatime
            # if  view.image_name not in ["frame_00002", "frame_00025", "frame_00043", "frame_00107", "frame_00129", "frame_00140"]:
            #     continue
            # # ramen
            # if  view.image_name not in ["frame_00006", "frame_00024", "frame_00060", "frame_00065", "frame_00081", "frame_00119", "frame_00128"]:
            #     continue
            # # waldo_kitchen
            # if  view.image_name not in ["frame_00053", "frame_00066", "frame_00089", "frame_00140", "frame_00154"]:
            #     continue

            # NOTE render
            render_pkg = render(view, gaussians, pipeline, background, iteration,
                                rescale=False,                # whether to rescale the Gaussian scales
                                # cluster_idx=leaf_cluster_indices,     # root id 
                                leaf_cluster_idx=leaf_cluster_indices,            # leaf id               
                                selected_leaf_id=click_leaf_indices.cuda(),       # selected leaf id      
                                render_feat_map=True, 
                                render_cluster=False,
                                better_vis=True,
                                pre_mask=pre_pts_mask,
                                seg_rgb=True)
            rendering = render_pkg["render"]
            rendered_cluster_imgs = render_pkg["leaf_clusters_imgs"]
            occured_leaf_id = render_pkg["occured_leaf_id"]
            rendered_leaf_cluster_silhouettes = render_pkg["leaf_cluster_silhouettes"]

            # save Rendered RGB
            torchvision.utils.save_image(rendering, os.path.join(render_path, view.image_name + ".png"))

            render_cluster_path = os.path.join(model_path, name, "ours_{}".format(iteration), "click_cluster")
            render_cluster_silhouette_path = os.path.join(model_path, name, "ours_{}".format(iteration), "click_cluster_mask")
            makedirs(render_cluster_path, exist_ok=True)
            makedirs(render_cluster_silhouette_path, exist_ok=True)
            for i, img in enumerate(rendered_cluster_imgs):
                torchvision.utils.save_image(img[:3,:,:], os.path.join(render_cluster_path, \
                    view.image_name + f"_{object}_cluster_{occured_leaf_id[i]}.png"))
                # save mask
                cluster_silhouette = rendered_leaf_cluster_silhouettes[i] > 0.8
                torchvision.utils.save_image(cluster_silhouette.to(torch.float32), os.path.join(render_cluster_silhouette_path, \
                    view.image_name + f"_{object}_cluster_{occured_leaf_id[i]}.png"))
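

# Sketch (hypothetical helper, NumPy only) of the coarse-to-fine lookup used in
# render_set (option (2)). It finds the root code nearest to the root-level click
# feature, then searches only the leaf_num leaf codes stored contiguously for that
# root, using the leaf-level click feature. Assumptions: leaf_codes has exactly
# num_root * leaf_num rows (the extra last row of the real leaf codebook already
# removed), and all-zero leaf codes are unassigned. np.inf replaces the 999 used
# above.
import numpy as np


def two_level_nearest_leaf_sketch(root_query, leaf_query, root_codes, leaf_codes):
    """Return the global index of the nearest leaf under the nearest root."""
    root_query = np.asarray(root_query, dtype=np.float64)
    leaf_query = np.asarray(leaf_query, dtype=np.float64)
    root_codes = np.asarray(root_codes, dtype=np.float64)
    leaf_codes = np.asarray(leaf_codes, dtype=np.float64)
    leaf_num = leaf_codes.shape[0] // root_codes.shape[0]   # leaves per root (k2)
    root_id = int(np.argmin(np.linalg.norm(root_codes - root_query, axis=1)))
    start = root_id * leaf_num
    candidates = leaf_codes[start:start + leaf_num]
    dist = np.linalg.norm(candidates - leaf_query, axis=1)
    dist[candidates.sum(axis=1) == 0] = np.inf              # skip unassigned (all-zero) leaves
    return start + int(np.argmin(dist))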

def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool):
    with torch.no_grad():
        gaussians = GaussianModel(dataset.sh_degree)
        scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)

        bg_color = [1,1,1] if dataset.white_background else [0, 0, 0]
        background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")

        if not skip_train:
             render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background)

        if not skip_test:
             render_set(dataset.model_path, "test", scene.loaded_iter, scene.getTestCameras(), gaussians, pipeline, background)
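

# Sketch (hypothetical helper, NumPy only, brute force) of the post-processing step
# in render_set. For each selected Gaussian, it averages the squared distances to its
# K nearest selected neighbors. The point itself is included, as with
# pytorch3d.ops.knn_points on identical point sets. It then keeps the points whose
# average is below mean + std_scale * std over all K-NN distances. K = 2 *
# floor(sqrt(N)) follows the original, and std_scale = 0.1 matches its active
# threshold. The N x N distance matrix makes this suitable for small selections only.
import numpy as np


def knn_outlier_mask_sketch(points, std_scale=0.1):
    """Return a boolean mask keeping points whose mean K-NN squared distance is small."""
    points = np.asarray(points, dtype=np.float64)
    n = points.shape[0]
    k = min(max(int(n ** 0.5) * 2, 1), n)                # K = 2 * floor(sqrt(N)), clamped to [1, N]
    diff = points[:, None, :] - points[None, :, :]
    sq_dist = np.einsum("ijk,ijk->ij", diff, diff)       # [N, N] squared distances
    knn_sq = np.sort(sq_dist, axis=1)[:, :k]             # [N, K], column 0 is the self-distance 0
    # ddof=1 matches the default (unbiased) estimator of torch.Tensor.std.
    threshold = knn_sq.mean() + std_scale * knn_sq.std(ddof=1)
    return knn_sq.mean(axis=1) < threshold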

if __name__ == "__main__":
    # Set up command line argument parser
    parser = ArgumentParser(description="Testing script parameters")
    model = ModelParams(parser, sentinel=True)
    pipeline = PipelineParams(parser)
    parser.add_argument("--iteration", default=-1, type=int)
    parser.add_argument("--skip_train", action="store_true")
    parser.add_argument("--skip_test", action="store_true")
    parser.add_argument("--quiet", action="store_true")
    args = get_combined_args(parser)
    print("Rendering " + args.model_path)

    # Initialize system state (RNG)
    safe_state(args.quiet)

    render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test)

================================================
FILE: scripts/scannet2blender.py
================================================
import os
import json
import numpy as np

def load_transform_matrix(file_path):
    """
    Load the transform matrix from a text file.
    """
    with open(file_path, 'r') as file:
        matrix = [list(map(float, line.strip().split())) for line in file]
    return matrix
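
# Sketch (hypothetical helper): process_directory below sorts image names by their
# stems as strings, so "10.jpg" comes before "2.jpg". If frame names are numeric
# (e.g. "0.jpg", "1.jpg", ...), a key like this one orders them numerically while
# still accepting stems that mix text and digits.
import os
import re


def numeric_stem_key_sketch(file_name):
    stem = os.path.splitext(file_name)[0]
    # Split digit runs from text so that "frame_10" sorts after "frame_9".
    return [int(tok) if tok.isdigit() else tok for tok in re.split(r"(\d+)", stem)]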

def process_directory(directory_path):
    """
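    Build the Blender-style transforms dict (shared intrinsics plus per-frame camera-to-world
    matrices) for one scene directory. Returns None if the color/ or pose/ subdirectory is missing.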
    """
    color_dir = os.path.join(directory_path, "color")           # TODO: adjust to your export layout
    pose_dir = os.path.join(directory_path, "pose")             # TODO: adjust to your export layout
    intrinsic_dir = os.path.join(directory_path, "intrinsic")   # TODO: adjust to your export layout

    # Check if both directories exist
    if not os.path.isdir(color_dir) or not os.path.isdir(pose_dir):
        return

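    # ScanNet color-camera intrinsics (1296x968). When intrinsic/<frame>.txt exists, per-frame values are also added to each frame below.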
    transform_data = {
            'w': 1296,
            'h': 968,
            'fl_x': 1170.187988,
            'fl_y': 1170.187988,
            'cx': 647.75,
            'cy': 483.75,
            # 'aabb_scale': 2,
            'frames': [],
        }
    # # Alternative intrinsics for 640x512 images, kept for reference
    # transform_data = {
    #         'w': 640,
    #         'h': 512,
    #         'fl_x': 534.56,
    #         'fl_y': 534.80,
    #         'cx': 314.27,
    #         'cy': 259.96,
    #         # 'aabb_scale': 2,
    #         'frames': [],
    #     }
    # Collect all image names and sort them
    img_names = [img_name for img_name in os.listdir(color_dir) if img_name.endswith(".jpg")]
    # img_names.sort(key=lambda x: int(os.path.splitext(x)[0]))  # Sort by image number
    img_names.sort(key=lambda x: os.path.splitext(x)[0])  # Sort by file stem (lexicographic, so "10" sorts before "2")

    # Iterate over the color images
    for img_name in img_names:
        if img_name.endswith(".jpg"):
            # Construct the corresponding pose file path
            pose_file = os.path.splitext(img_name)[0] + ".txt"
            pose_file_path = os.path.join(pose_dir, pose_file)

            intrinsic_file = os.path.splitext(img_name)[0] + ".txt"
            intrinsic_file_path = os.path.join(intrinsic_dir, intrinsic_file)

            # Check if the pose file exists
            if os.path.isfile(pose_file_path):
                transform_matrix = load_transform_matrix(pose_file_path)
                
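                # Convert camera axes from the OpenCV/COLMAP convention (y down, z forward)
                # to Blender/OpenGL (y up, z backward) by negating the y and z columns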
                transform_matrix = np.array(transform_matrix)
                transform_matrix[:3, 1:3] *= -1     
                transform_matrix = transform_matrix.tolist()

                frame_data = {
                    "file_path": os.path.join("color", os.path.splitext(img_name)[0]),
                    "transform_matrix": transform_matrix
                }

                if os.path.isfile(intrinsic_file_path):
                    intrinsic_info = load_transform_matrix(intrinsic_file_path)
                    frame_data.update({
                        'fl_x': intrinsic_info[0][0],
                        'fl_y': intrinsic_info[1][1],
                        'cx':  intrinsic_info[0][2],
                        'cy': intrinsic_info[1][2]
                    })

                transform_data["frames"].append(frame_data)

    return transform_data
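
# Sketch (hypothetical helper): for a homogeneous 4x4 camera-to-world matrix whose
# last row is [0, 0, 0, 1], negating columns 1 and 2 of the rotation block, as done
# above, is the same as right-multiplying by diag(1, -1, -1, 1). This flips the
# camera's y and z axes (OpenCV/COLMAP to OpenGL/Blender) and leaves the
# translation column unchanged.
import numpy as np


def opencv_to_blender_c2w_sketch(c2w):
    c2w = np.asarray(c2w, dtype=np.float64)
    return c2w @ np.diag([1.0, -1.0, -1.0, 1.0])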

# Directory containing the scenes
base_directory = 'PATH_TO_YOUR_SCANNET'     # TODO

# Process each scene directory and create JSON files
for scene_dir in os.listdir(base_directory):
    # if scene_dir != "scene0000_00":
    #     continue
    
    scene_path = os.path.join(base_directory, scene_dir)
    if os.path.isdir(scene_path):
        # Process the directory and get the transform data
        transform_data = process_directory(scene_path)

        print(scene_path)
        
        # Create the JSON file
        if transform_data:
            json_file_path = os.path.join(scene_path, "transforms_train.json")
            with open(json_file_path, 'w') as json_file:
                json.dump(transform_data, json_file, indent=4)


================================================
FILE: scripts/train_lerf.sh
================================================
#!/bin/bash
# chmod +x scripts/train_lerf.sh
# ./scripts/train_lerf.sh

# !!! Please check the dataset path specified by -s.

# Total training steps: 70k
# 3dgs pre-train: 0~30k
# stage1: 30~40k
# stage2 (coarse-level): 40~50k
# stage2 (fine-level): 50k~70k

# ###############################################
# #              (1/4) figurines
# # Training takes approximately 70 minutes on a 24GB RTX 4090 GPU.
# # Object selection works well (recommended); point-cloud visualization is poor (not recommended).
# # k1=64, k2=10
# # --pos_weight 0.5
# # --save_memory: reduces memory use at the cost of training speed. It can be omitted if your GPU has more than 24GB of memory.
# ###############################################
scan="figurines"
gpu_num=3           # change to your GPU index
echo "Training for ${scan} ....."
CUDA_VISIBLE_DEVICES=$gpu_num python train.py --port 601$gpu_num \
    -s /gdata/cold1/wuyanmin/OpenGaussian/data/lerf_ovs/${scan} \
    --iterations 70_000 \
    --start_ins_feat_iter 30_000 \
    --start_root_cb_iter 40_000 \
    --start_leaf_cb_iter 50_000 \
    --sam_level 3 \
    --root_node_num 64 \
    --leaf_node_num 10 \
    --pos_weight 0.5 \
    --save_memory \
    --test_iterations 30000 \
    --eval


# ###############################################
# #              (2/4) waldo_kitchen
# # Training takes approximately 60 minutes on a 24GB RTX 4090 GPU.
# # Point-cloud visualization is good (recommended); object selection is suboptimal.
# # k1=64, k2=10
# # --pos_weight 0.5
# # --save_memory is not needed; 24GB is sufficient.
# ###############################################
scan="waldo_kitchen"
gpu_num=3           # change to your GPU index
echo "Training for ${scan} ....."
CUDA_VISIBLE_DEVICES=$gpu_num python train.py --port 601$gpu_num \
    -s /gdata/cold1/wuyanmin/OpenGaussian/data/lerf_ovs/${scan} \
    --iterations 70_000 \
    --start_ins_feat_iter 30_000 \
    --start_root_cb_iter 40_000 \
    --start_leaf_cb_iter 50_000 \
    --sam_level 3 \
    --root_node_num 64 \
    --leaf_node_num 10 \
    --pos_weight 0.5 \
    --test_iterations 30000 \
    --eval


# ###############################################
# #              (3/4) teatime
# # Training takes approximately 80 minutes on a 24GB RTX 4090 GPU.
# # k1=32, k2=10
# # --pos_weight 0.1
# # --save_memory: reduces memory use at the cost of training speed. It can be omitted if your GPU has more than 24GB of memory.
# ###############################################
scan="teatime"
gpu_num=3       # change to your GPU index
echo "Training for ${scan} ....."
CUDA_VISIBLE_DEVICES=$gpu_num python train.py --port 601$gpu_num \
    -s /gdata/cold1/wuyanmin/OpenGaussian/data/lerf_ovs/${scan} \
    --iterations 70_000 \
    --start_ins_feat_iter 30_000 \
    --start_root_cb_iter 40_000 \
    --start_leaf_cb_iter 50_000 \
    --sam_level 3 \
    --root_node_num 32 \
    --leaf_node_num 10 \
    --pos_weight 0.1 \
    --save_memory \
    --test_iterations 30000 \
    --eval


# ###############################################
# #              (4/4) ramen
# # Training takes approximately 40 minutes on a 24GB RTX 4090 GPU.
# # Object selection is the weakest and least stable of the four scenes (not recommended).
# # k1=64, k2=10
# # --pos_weight 0.5
# # --loss_weight 0.01: weight of the intra-mask smoothness loss (0.1 is used for the other scenes).
# # --save_memory is not needed; 24GB is sufficient.
# ###############################################
scan="ramen"
gpu_num=3
echo "Training for ${scan} ....."
CUDA_VISIBLE_DEVICES=$gpu_num python train.py --port 601$gpu_num \
    -s /gdata/cold1/wuyanmin/OpenGaussian/data/lerf_ovs/${scan} \
    --iterations 70_000 \
    --start_ins_feat_iter 30_000 \
    --start_root_cb_iter 40_000 \
    --start_leaf_cb_iter 50_000 \
    --sam_level 3 \
    --root_node_num 64 \
    --leaf_node_num 10 \
    --pos_weight 0.5 \
    --loss_weight 0.01 \
    --test_iterations 30000 \
    --eval

================================================
FILE: scripts/train_scannet.sh
================================================
#!/bin/bash
# chmod +x scripts/train_scannet.sh
# ./scripts/train_scannet.sh

# ============== [Notice] ==============
# 1. All 10 ScanNet scenes use the same hyperparameters.
# 2. Training one scene takes about 20 minutes on a 24GB RTX 4090 GPU.
# 3. Please check the dataset path specified by -s.

# ============== [Hyperparameter explanation] ==============
# Total training steps: 90k
# 3dgs pre-train: 0~30k
# stage1: 30~50k
# stage2 (coarse-level): 50~70k
# stage2 (fine-level): 70k~90k
# k1=64, k2=5
# frozen_init_pts: keep the point cloud provided by ScanNet frozen; 3DGS densification is not used.
# -r 2: train on half-resolution images.

# ============== [10 scenes] ==============
scan_list=("scene0000_00" "scene0062_00" "scene0070_00" "scene0097_00" "scene0140_00" \
"scene0200_00" "scene0347_00" "scene0400_00" "scene0590_00" "scene0645_00")

gpu_num=3     # change to your GPU index!
for scan in "${scan_list[@]}"; do
    echo "Training for ${scan} ....."
    CUDA_VISIBLE_DEVICES=$gpu_num python train.py --port 601$gpu_num \
        -s /gdata/cold1/wuyanmin/OpenGaussian/data/onedrive/scannet/${scan} \
        -r 2 \
        --frozen_init_pts \
        --iterations 90_000 \
        --start_ins_feat_iter 30_000 \
        --start_root_cb_iter 50_000 \
        --start_leaf_cb_iter 70_000 \
        --sam_level 0 \
        --root_node_num 64 \
        --leaf_node_num 5 \
        --pos_weight 1.0 \
        --test_iterations 30000 \
        --eval
done

================================================
FILE: scripts/vis_opengs_pts_feat.py
================================================
import numpy as np
from plyfile import PlyData
import open3d as o3d

def sigmoid(x):
    """Sigmoid function."""
    return 1 / (1 + np.exp(-x))

def visualize_ply(ply_path):
    # Load the PLY file
    ply_data = PlyData.read(ply_path)
    vertex_data = ply_data['vertex'].data

    # Extract the point cloud attributes
    points = np.array([vertex_data['x'], vertex_data['y'], vertex_data['z']]).T
    colors = np.array([vertex_data['red'], vertex_data['green'], vertex_data['blue']]).T / 255.0
    opacity = vertex_data['opacity']

    # Apply the opacity filter
    sigmoid_opacity = sigmoid(opacity)
    filtered_indices = sigmoid_opacity >= 0.1
    filtered_points = points[filtered_indices]
    filtered_colors = colors[filtered_indices]

    # Create an Open3D PointCloud object
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(filtered_points)
    pcd.colors = o3d.utility.Vector3dVector(filtered_colors)

    # Visualize the point cloud
    o3d.visualization.draw_geometries([pcd])

if __name__ == "__main__":
    # Replace with the path to your PLY file
    ply_path = "output/xxxxxxxx-x/point_cloud/iteration_x0000/point_cloud.ply"
    visualize_ply(ply_path)
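In 3DGS, the `opacity` field of a saved PLY holds the value before the sigmoid activation. This is why the script calls `sigmoid` before applying its threshold. The function names below are hypothetical. The sketch isolates the filtering step in NumPy so you can check it without `plyfile` or Open3D. It also converts the 0.1 opacity threshold into the equivalent threshold on the raw logit.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def filter_by_opacity(points, colors, opacity_logits, threshold=0.1):
    """Keep points whose activated opacity sigmoid(logit) is >= threshold."""
    keep = sigmoid(np.asarray(opacity_logits, dtype=np.float64)) >= threshold
    return points[keep], colors[keep]


def opacity_logit_threshold(threshold=0.1):
    # sigmoid(x) >= t  <=>  x >= log(t / (1 - t)), for 0 < t < 1
    return float(np.log(threshold / (1.0 - threshold)))
```

You can also compare the raw logits directly against `opacity_logit_threshold(0.1)` (about -2.197). This gives the same selection and avoids `np.exp` overflow warnings for very negative logits.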

================================================
FILE: train.py
================================================
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use 
# under the terms of the LICENSE.md file.
#
# For inquiries contact  george.drettakis@inria.fr
#

import os
import torch
import torch.nn.functional as F
from random import randint
from utils.loss_utils import l1_loss, ssim, l2_loss
from gaussian_renderer import render, network_gui
import sys
from scene import Scene, GaussianModel
from utils.general_utils import safe_state
import uuid
from tqdm import tqdm
from utils.image_utils import psnr
from argparse import ArgumentParser, Namespace
from arguments import ModelParams, PipelineParams, OptimizationParams
from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal
from os import makedirs
import torchvision
import numpy as np
from utils.sh_utils import RGB2SH
import math
# import faiss
from scene.kmeans_quantize import Quantize_kMeans
from bitarray import bitarray
from utils.system_utils import mkdir_p
from utils.opengs_utlis import mask_feature_mean, pair_mask_feature_mean, \
    get_SAM_mask_and_feat, load_code_book, \
    calculate_iou, calculate_distances, calculate_pairwise_distances

try:
    from torch.utils.tensorboard import SummaryWriter
    TENSORBOARD_FOUND = True
except ImportError:
    TENSORBOARD_FOUND = False

# Randomly initialize 300 colors for visualizing the SAM mask. [OpenGaussian]
np.random.seed(42)
colors_defined = np.random.randint(100, 256, size=(300, 3))
colors_defined[0] = np.array([0, 0, 0]) # Ignore the mask ID of -1 and set it to black.
colors_defined = torch.from_numpy(colors_defined)

def dec2binary(x, n_bits=None):
    """Convert decimal integer x to binary.

    Code from: https://stackoverflow.com/questions/55918468/convert-integer-to-pytorch-tensor-of-binary-bits
    """
    if n_bits is None:
        # Width of the largest value; ceil(log2(x)) would under-count exact powers of two.
        n_bits = max(int(x.max().item()).bit_length(), 1)
    mask = 2**torch.arange(n_bits-1, -1, -1).to(x.device, x.dtype)
    return x.unsqueeze(-1).bitwise_and(mask).ne(0)
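`dec2binary` expands each integer into `n_bits` booleans, most significant bit first. It does this by AND-ing the integer with the masks `2**(n_bits-1), ..., 2, 1`. The sketch below is a CPU-only NumPy transcription written for illustration. `to_bits`, `from_bits`, and `bits_needed` are hypothetical names, not repository functions. The sketch adds the inverse conversion and a bit-width helper. The helper shows why `ceil(log2(v))` is too small when `v` is an exact power of two.

```python
import numpy as np


def to_bits(x, n_bits):
    """MSB-first expansion like dec2binary: int array [...] -> bool array [..., n_bits]."""
    x = np.asarray(x, dtype=np.int64)
    mask = 2 ** np.arange(n_bits - 1, -1, -1, dtype=np.int64)
    return (x[..., None] & mask) != 0


def from_bits(bits):
    """Inverse of to_bits: bool array [..., n_bits] -> int64 array [...]."""
    bits = np.asarray(bits, dtype=np.int64)
    weights = 2 ** np.arange(bits.shape[-1] - 1, -1, -1, dtype=np.int64)
    return (bits * weights).sum(axis=-1)


def bits_needed(max_value):
    # ceil(log2(v)) under-counts exact powers of two: log2(4) = 2, but 4 = 0b100 needs 3 bits.
    # int.bit_length() returns the exact width of the largest value to be stored.
    return max(1, int(max_value).bit_length())
```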

def save_kmeans(kmeans_list, quantized_params, out_dir, mode="root"):
    """Save the KMeans codebook (cluster centers) and cluster indices (cls_ids / leaf_cls_ids)."""
    # Store the indices as a packed bitarray: saving them as .npy or .pth would spend
    # at least 8 bits per binary digit (or boolean). Each index is converted to n_bits
    # binary digits, the digits of all params are concatenated, and the result is saved.
    if mode=="root":
        out_dir = os.path.join(out_dir, 'root_code_book')
    elif mode=="leaf":
        out_dir = os.path.join(out_dir, 'leaf_code_book')
    
    mkdir_p(out_dir)
    bitarray_all = bitarray([])
    for kmeans in kmeans_list:
        if mode=="root":
            cls_ids = kmeans.cls_ids
        elif mode=="leaf":
            cls_ids = kmeans.leaf_cls_ids
        # n_bits is sized from the number of points, so every cluster id below len(cls_ids) fits
        n_bits = int(np.ceil(np.log2(len(cls_ids))))
        assignments = dec2binary(cls_ids, n_bits)
        bitarr = bitarray(list(assignments.cpu().numpy().flatten()))
        bitarray_all.extend(bitarr)
    with open(os.path.join(out_dir, 'kmeans_inds.bin'), 'wb') as file:  # cls_ids
        bitarray_all.tofile(file)

    # Save details needed for loading
    args_dict = {}
    args_dict['params'] = quantized_params
    args_dict['n_bits'] = n_bits
    args_dict['total_len'] = len(bitarray_all)
    np.save(os.path.join(out_dir, 'kmeans_args.npy'), args_dict)
    if mode=="root":
        centers_dict = {param: kmeans.centers for (kmeans, param) in zip(kmeans_list, quantized_params)}
    elif mode=="leaf":
        centers_dict = {param: kmeans.leaf_centers for (kmeans, param) in zip(kmeans_list, quantized_params)}

    # Save codebook
    torch.save(centers_dict, os.path.join(out_dir, 'kmeans_centers.pth'))
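`bitarray.tofile` writes whole bytes and zero-pads the last one. A reader therefore has to know how many bits are real, which is why `save_kmeans` also records `total_len` alongside `n_bits`. The sketch below reproduces that round trip with `numpy.packbits`/`unpackbits` in place of `bitarray`. It is an illustration of the file layout only. `pack_ids` and `unpack_ids` are hypothetical helpers, not the repository's `load_code_book`.

```python
import numpy as np


def pack_ids(ids, n_bits):
    """Pack ids into MSB-first n_bits fields; returns (payload_bytes, total_len_in_bits)."""
    ids = np.asarray(ids, dtype=np.int64).ravel()
    if ids.size and (ids.min() < 0 or ids.max() >= 2 ** n_bits):
        raise ValueError(f"ids do not fit in {n_bits} bits")
    mask = 2 ** np.arange(n_bits - 1, -1, -1, dtype=np.int64)
    bits = ((ids[:, None] & mask) != 0).astype(np.uint8).ravel()
    # np.packbits zero-pads the last byte, much like bitarray.tofile does.
    return np.packbits(bits).tobytes(), int(bits.size)


def unpack_ids(payload, n_bits, total_len):
    bits = np.unpackbits(np.frombuffer(payload, dtype=np.uint8))[:total_len]  # drop padding bits
    weights = 2 ** np.arange(n_bits - 1, -1, -1, dtype=np.int64)
    return bits.reshape(-1, n_bits).astype(np.int64) @ weights
```

For example, five 6-bit ids occupy 30 bits, which are stored in 4 bytes. Without truncating to `total_len`, the 32 unpacked bits could not be split back into 6-bit fields.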

def cohesion_loss(feat_map, gt_mask, feat_mean_stack):
    """Intra-mask smoothing loss (Eq. (1) in the paper).
    Pulls the feature of each pixel inside a mask toward the mean feature of that mask.
    """
    N, H, W = gt_mask.shape
    C = feat_map.shape[0]
    # expand feat_map [6, H, W] to [N, 6, H, W]
    feat_map_expanded = feat_map.unsqueeze(0).expand(N, C, H, W)
    # expand mean feat [N, 6] to [N, 6, H, W]
    feat_mean_stack_expanded = feat_mean_stack.unsqueeze(-1).unsqueeze(-1).expand(N, C, H, W)
    
    # feature distance
    masked_feat = feat_map_expanded * gt_mask.unsqueeze(1)           # [N, 6, H, W]
    dist = (masked_feat - feat_mean_stack_expanded).norm(p=2, dim=1) # [N, H, W]
    
    # per mask feature distance (loss)
    masked_dist = dist * gt_mask    # [N, H, W]
    loss_per_mask = masked_dist.sum(dim=[1, 2]) / gt_mask.sum(dim=[1, 2]).clamp(min=1)

    return loss_per_mask.mean()
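`cohesion_loss` computes, for each mask, the average L2 distance between each pixel's feature inside the mask and that mask's mean feature. It then averages this over all masks, and an empty mask contributes 0 rather than dividing by zero. The NumPy transcription below is an illustrative sketch for checking shapes and values on CPU. `mask_means` is a hypothetical stand-in for the repository's `mask_feature_mean`, whose full behavior is not shown here.

```python
import numpy as np


def mask_means(feat_map, gt_mask):
    """Mean feature per mask: feat_map [C, H, W], gt_mask [N, H, W] -> [N, C]."""
    counts = np.maximum(gt_mask.sum(axis=(1, 2)), 1)
    return np.einsum("chw,nhw->nc", feat_map, gt_mask) / counts[:, None]


def cohesion_loss_np(feat_map, gt_mask, feat_mean_stack):
    # Same steps as cohesion_loss above, written with NumPy broadcasting.
    masked_feat = feat_map[None] * gt_mask[:, None]                                  # [N, C, H, W]
    dist = np.linalg.norm(masked_feat - feat_mean_stack[:, :, None, None], axis=1)   # [N, H, W]
    per_mask = (dist * gt_mask).sum(axis=(1, 2)) / np.maximum(gt_mask.sum(axis=(1, 2)), 1)
    return per_mask.mean()
```

Take a 1x2 image with features 0 and 2 covered by a single mask. The mean feature is 1 and each pixel is 1 away from it, so the loss is 1.0. A feature map that is constant inside every mask gives a loss of 0.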

def separation_loss(feat_mean_stack, iteration):
    """Inter-mask contrastive loss (Eq. (2) in the paper).
    Pushes the mean instance features of different masks as far apart as possible.
    """
    N, _ = feat_mean_stack.shape

    # expand feat_mean_stack [N, C] to [N, N, C] in both orders
    feat_expanded = feat_mean_stack.unsqueeze(1).expand(-1, N, -1)
    feat_transposed = feat_mean_stack.unsqueeze(0).expand(N, -1, -1)
    
    # squared Euclidean distance between every pair of mask features, [N, N]
    diff_squared = (feat_expanded - feat_transposed).pow(2).sum(2)
    
    # Calculate the inverse of the distance to enhance discrimination
    epsilon = 1     # 1e-6
    inverse_distance = 1.0 / (diff_squared + epsilon)
    # Exclude diagonal elements (distance from itself) and calculate the mean inverse distance
    mask = torch.eye(N, device=feat_mean_stack.device).bool()
    inverse_distance.masked_fill_(mask, 0)  

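    # Rank-based weights: argsort().argsort() gives each pair's rank within its row, so
    # closer pairs (larger inverse distance) receive larger weights
    sorted_indices = inverse_distance.argsort().argsort()
    loss_weight = (sorted_indices.float() / (N - 1)) * (1.0 - 0.1) + 0.1    # scale to 0.1 - 1.0, [N, N]
    # After 35k iterations, keep only weights >= 0.9 (roughly the closest 11% of pairs per row)
    # and reduce all others to 0.1
    if iteration > 35_000:
        loss_weight[loss_weight < 0.9] = 0.1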
    inverse_distance *= loss_weight     # [N, N]

    # final loss
    loss = inverse_distance.sum() / (N * (N - 1))

    return loss
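`separation_loss` penalizes the inverse squared distance `1 / (d^2 + epsilon)` between every pair of mask-mean features. With `epsilon = 1`, this penalty is bounded by 1 for identical features. Each pair is weighted by its rank within its row, so the closest neighbors of each mask dominate the loss. The NumPy transcription below is a sketch for CPU experiments. `separation_loss_np` and its `late_iter` parameter (standing in for the hard-coded 35_000) are hypothetical. It uses a stable argsort for deterministic tie handling, which may order ties differently from `torch.argsort`. The `N < 2` guard is an addition; the original divides by `N * (N - 1)`.

```python
import numpy as np


def separation_loss_np(feat_mean_stack, iteration, epsilon=1.0, late_iter=35_000):
    f = np.asarray(feat_mean_stack, dtype=np.float64)
    n = f.shape[0]
    if n < 2:
        return 0.0  # guard added in this sketch; no pairs to separate
    diff_sq = ((f[:, None, :] - f[None, :, :]) ** 2).sum(axis=-1)   # [N, N] squared distances
    inv = 1.0 / (diff_sq + epsilon)
    np.fill_diagonal(inv, 0.0)                                       # ignore self-pairs
    # per-row rank of each pair: the closest neighbor gets rank n - 1
    ranks = np.argsort(np.argsort(inv, axis=-1, kind="stable"), axis=-1, kind="stable")
    weight = ranks / (n - 1) * 0.9 + 0.1                             # scaled to 0.1 .. 1.0
    if iteration > late_iter:
        weight = np.where(weight < 0.9, 0.1, weight)
    return float((inv * weight).sum() / (n * (n - 1)))
```

For two features at distance 5, each off-diagonal pair gets weight 1.0 and the loss is `1 / (25 + 1) = 1/26`. Moving the features closer increases the loss.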

def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, \
             checkpoint, debug_from):
    iterations = [opt.start_ins_feat_iter, opt.start_leaf_cb_iter, opt.start_root_cb_iter]
    saving_iterations.extend(iterations)
    checkpoint_iterations.extend(iterations)

    first_iter = 0
    tb_writer = prepare_output_and_logger(dataset)
    gaussians = GaussianModel(dataset.sh_degree)
    scene = Scene(dataset, gaussians)
    gaussians.training_setup(opt)
    if checkpoint:
        (model_params, first_iter) = torch.load(checkpoint)
        # NOTE: Load the original 3DGS pre-trained checkpoint and add the ins_feat attribute. [OpenGaussian]
        if len(model_params) == 12:
            # initialize the instance features randomly.
            ins_feat = torch.rand((model_params[8].shape[0], opt.ins_feat_dim), dtype=torch.float, device="cuda")
            ins_feat = torch.nn.Parameter(ins_feat.requires_grad_(True))
            to_list = list(model_params)
            # (1) replace the saved 3DGS optimizer state (no ins_feat group) with the current optimizer state
            to_list[10] = gaussians.optimizer.state_dict()
            # (2) add ins_feat 
            to_list.insert(7, ins_feat)
            # (3) add ins_feat_q (quantized ins_feat)
            ins_feat_q = torch.empty(0)
            to_list.insert(8, ins_feat_q)
            model_params = tuple(to_list)
        gaussians.restore(model_params, opt)
        ins_feat_continue = gaussians._ins_feat.clone().detach()    # not used
    else:
        ins_feat_continue = None    # not used

    # initialize the codebook
    ins_feat_codebook = Quantize_kMeans(num_clusters=opt.root_node_num,         # k1
                                        num_leaf_clusters=opt.leaf_node_num,    # k2
                                        num_iters=5, 
                                        dim=9)
    
    # note: load the saved codebook
    leaf_cluster_indices = None
    if checkpoint:
        base_dir = os.path.dirname(checkpoint)
        load_iter = checkpoint.split('/')[-1].split('.')[0][6:]  # e.g. "chkpnt30000.pth" -> "30000"
        root_code_book_path = os.path.join(base_dir, 'point_cloud', f"iteration_{load_iter}", "root_code_book")
        leaf_code_book_path = os.path.join(base_dir, 'point_cloud', f"iteration_{load_iter}", "leaf_code_book")
        if os.path.exists(os.path.join(root_code_book_path, 'kmeans_inds.bin')):
            root_center, root_indices = load_code_book(root_code_book_path)
            root_center_saved = root_center["ins_feat"]
            cluster_indices = torch.from_numpy(root_indices).cuda()
            ins_feat_codebook.centers = root_center_saved
            ins_feat_codebook.cls_ids = cluster_indices
        else:
            cluster_indices = None
        if os.path.exists(os.path.join(leaf_code_book_path, 'kmeans_inds.bin')):
            leaf_center, leaf_indices = load_code_book(leaf_code_book_path)
            leaf_center_saved = leaf_center["ins_feat"]
            leaf_cluster_indices = torch.from_numpy(leaf_indices).cuda()
            ins_feat_codebook.leaf_centers = leaf_center_saved
            ins_feat_codebook.leaf_cl
SYMBOL INDEX (189 symbols across 30 files)

FILE: arguments/__init__.py
  class GroupParams (line 16) | class GroupParams:
  class ParamGroup (line 19) | class ParamGroup:
    method __init__ (line 20) | def __init__(self, parser: ArgumentParser, name : str, fill_none = Fal...
    method extract (line 40) | def extract(self, args):
  class ModelParams (line 47) | class ModelParams(ParamGroup):
    method __init__ (line 48) | def __init__(self, parser, sentinel=False):
    method extract (line 59) | def extract(self, args):
  class PipelineParams (line 64) | class PipelineParams(ParamGroup):
    method __init__ (line 65) | def __init__(self, parser):
  class OptimizationParams (line 71) | class OptimizationParams(ParamGroup):
    method __init__ (line 72) | def __init__(self, parser):
    method extract (line 111) | def extract(self, args):
  function get_combined_args (line 127) | def get_combined_args(parser : ArgumentParser):

FILE: gaussian_renderer/__init__.py
  function render (line 22) | def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch....

FILE: gaussian_renderer/network_gui.py
  function init (line 26) | def init(wish_host, wish_port):
  function try_connect (line 34) | def try_connect():
  function read (line 43) | def read():
  function send (line 50) | def send(message_bytes, verify):
  function receive (line 57) | def receive():

FILE: lpipsPyTorch/__init__.py
  function lpips (line 6) | def lpips(x: torch.Tensor,

FILE: lpipsPyTorch/modules/lpips.py
  class LPIPS (line 8) | class LPIPS(nn.Module):
    method __init__ (line 17) | def __init__(self, net_type: str = 'alex', version: str = '0.1'):
    method forward (line 30) | def forward(self, x: torch.Tensor, y: torch.Tensor):

FILE: lpipsPyTorch/modules/networks.py
  function get_network (line 12) | def get_network(net_type: str):
  class LinLayers (line 23) | class LinLayers(nn.ModuleList):
    method __init__ (line 24) | def __init__(self, n_channels_list: Sequence[int]):
  class BaseNet (line 36) | class BaseNet(nn.Module):
    method __init__ (line 37) | def __init__(self):
    method set_requires_grad (line 46) | def set_requires_grad(self, state: bool):
    method z_score (line 50) | def z_score(self, x: torch.Tensor):
    method forward (line 53) | def forward(self, x: torch.Tensor):
  class SqueezeNet (line 66) | class SqueezeNet(BaseNet):
    method __init__ (line 67) | def __init__(self):
  class AlexNet (line 77) | class AlexNet(BaseNet):
    method __init__ (line 78) | def __init__(self):
  class VGG16 (line 88) | class VGG16(BaseNet):
    method __init__ (line 89) | def __init__(self):

FILE: lpipsPyTorch/modules/utils.py
  function normalize_activation (line 6) | def normalize_activation(x, eps=1e-10):
  function get_state_dict (line 11) | def get_state_dict(net_type: str = 'alex', version: str = '0.1'):

FILE: metrics.py
  function readImages (line 24) | def readImages(renders_dir, gt_dir):
  function evaluate (line 36) | def evaluate(model_paths):

FILE: render.py
  function render_set (line 33) | def render_set(model_path, name, iteration, views, gaussians, pipeline, ...
  function render_sets (line 88) | def render_sets(dataset : ModelParams, iteration : int, pipeline : Pipel...

FILE: render_lerf_by_text.py
  function render_set (line 33) | def render_set(model_path, name, iteration, views, gaussians, pipeline, ...
  function render_sets (line 186) | def render_sets(dataset : ModelParams, iteration : int, pipeline : Pipel...

FILE: scene/__init__.py
  class Scene (line 21) | class Scene:
    method __init__ (line 25) | def __init__(self, args : ModelParams, gaussians : GaussianModel, load...
    method save (line 86) | def save(self, iteration, save_q=[]):
    method getTrainCameras (line 90) | def getTrainCameras(self, scale=1.0):
    method getTestCameras (line 93) | def getTestCameras(self, scale=1.0):

FILE: scene/cameras.py
  class Camera (line 17) | class Camera(nn.Module):
    method __init__ (line 18) | def __init__(self, colmap_id, R, T, FoVx, FoVy, cx, cy, image, depth, ...
    method to_gpu (line 77) | def to_gpu(self):
    method to_cpu (line 85) | def to_cpu(self):
  class MiniCam (line 92) | class MiniCam:
    method __init__ (line 93) | def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_...

FILE: scene/colmap_loader.py
  function qvec2rotmat (line 43) | def qvec2rotmat(qvec):
  function rotmat2qvec (line 55) | def rotmat2qvec(R):
  class Image (line 68) | class Image(BaseImage):
    method qvec2rotmat (line 69) | def qvec2rotmat(self):
  function read_next_bytes (line 72) | def read_next_bytes(fid, num_bytes, format_char_sequence, endian_charact...
  function read_points3D_text (line 83) | def read_points3D_text(path):
  function read_points3D_binary (line 125) | def read_points3D_binary(path_to_model_file):
  function read_intrinsics_text (line 156) | def read_intrinsics_text(path):
  function read_extrinsics_binary (line 180) | def read_extrinsics_binary(path_to_model_file):
  function read_intrinsics_binary (line 215) | def read_intrinsics_binary(path_to_model_file):
  function read_extrinsics_text (line 244) | def read_extrinsics_text(path):
  function read_colmap_bin_array (line 273) | def read_colmap_bin_array(path):

FILE: scene/dataset_readers.py
  class CameraInfo (line 28) | class CameraInfo(NamedTuple):
  class SceneInfo (line 45) | class SceneInfo(NamedTuple):
  function getNerfppNorm (line 52) | def getNerfppNorm(cam_info):
  function readColmapCameras (line 75) | def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder):
  function fetchPly (line 140) | def fetchPly(path):
  function storePly (line 155) | def storePly(path, xyz, rgb):
  function readColmapSceneInfo (line 172) | def readColmapSceneInfo(path, images, eval, llffhold=8):
  function readCamerasFromTransforms (line 219) | def readCamerasFromTransforms(path, transformsfile, white_background, ex...
  function readNerfSyntheticInfo (line 324) | def readNerfSyntheticInfo(path, white_background, eval, extension=".png"):

FILE: scene/gaussian_model.py
  function sigmoid (line 25) | def sigmoid(x):
  function distCUDA2 (line 28) | def distCUDA2(points):
  class GaussianModel (line 38) | class GaussianModel:
    method setup_functions (line 40) | def setup_functions(self):
    method __init__ (line 58) | def __init__(self, sh_degree : int):
    method capture (line 78) | def capture(self):
    method restore (line 96) | def restore(self, model_args, training_args):
    method get_scaling (line 117) | def get_scaling(self):
    method get_scaling_origin (line 121) | def get_scaling_origin(self):
    method get_rotation (line 125) | def get_rotation(self):
    method get_rotation_matrix (line 129) | def get_rotation_matrix(self):
    method get_eigenvector (line 133) | def get_eigenvector(self):
    method get_xyz (line 142) | def get_xyz(self):
    method get_features (line 146) | def get_features(self):
    method get_opacity (line 152) | def get_opacity(self):
    method get_ins_feat (line 157) | def get_ins_feat(self, origin=False):
    method get_covariance (line 165) | def get_covariance(self, scaling_modifier = 1):
    method oneupSHdegree (line 168) | def oneupSHdegree(self):
    method create_from_pcd (line 172) | def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : fl...
    method training_setup (line 202) | def training_setup(self, training_args):
    method update_learning_rate (line 227) | def update_learning_rate(self, iteration, root_start, leaf_start):
    method construct_list_of_attributes (line 240) | def construct_list_of_attributes(self):
    method save_ply (line 255) | def save_ply(self, path, save_q=[]):
    method reset_opacity (line 289) | def reset_opacity(self):
    method load_ply (line 294) | def load_ply(self, path):
    method replace_tensor_to_optimizer (line 346) | def replace_tensor_to_optimizer(self, tensor, name):
    method _prune_optimizer (line 361) | def _prune_optimizer(self, mask):
    method prune_points (line 379) | def prune_points(self, mask):
    method cat_tensors_to_optimizer (line 396) | def cat_tensors_to_optimizer(self, tensors_dict):
    method densification_postfix (line 418) | def densification_postfix(self, new_xyz, new_features_dc, new_features...
    method densify_and_split (line 441) | def densify_and_split(self, grads, grad_threshold, scene_extent, N=2):
    method densify_and_clone (line 468) | def densify_and_clone(self, grads, grad_threshold, scene_extent):
    method densify_and_prune (line 485) | def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_...
    method add_densification_stats (line 501) | def add_densification_stats(self, viewspace_point_tensor, update_filter):

FILE: scene/kmeans_quantize.py
  class Quantize_kMeans (line 12) | class Quantize_kMeans():
    method __init__ (line 13) | def __init__(self, num_clusters=64, num_leaf_clusters=10, num_iters=10...
    method get_dist (line 38) | def get_dist(self, x, y, mode='sq_euclidean'):
    method update_centers (line 58) | def update_centers(self, feat, mode="root", selected_leaf=-1):
    method update_centers_ (line 82) | def update_centers_(self, feat, cluster_mask=None, nn_index=None, avg=...
    method equalize_cluster_size (line 89) | def equalize_cluster_size(self, mode="root"):
    method cluster_assign (line 146) | def cluster_assign(self, feat, feat_scaled=None, mode="root", selected...
    method rescale (line 243) | def rescale(self, feat, scale=None):
    method forward (line 252) | def forward(self, gaussian, iteration, assign=False, mode="root", sele...
    method replace_with_centers (line 277) | def replace_with_centers(self, gaussian):

FILE: scripts/compute_lerf_iou.py
  function load_image_as_binary (line 6) | def load_image_as_binary(image_path, is_png=False, threshold=10):
  function calculate_iou (line 14) | def calculate_iou(mask1, mask2):
  function evalute (line 21) | def evalute(gt_base, pred_base, scene_name):

FILE: scripts/eval_scannet.py
  function sigmoid (line 29) | def sigmoid(x):
  function write_ply (line 32) | def write_ply(vertex_data, output_path):
  function read_labels_from_ply (line 47) | def read_labels_from_ply(file_path):
  function calculate_metrics (line 55) | def calculate_metrics(gt, pred, total_classes):

FILE: scripts/render_by_click.py
  function get_pixel_values (line 35) | def get_pixel_values(image_path, position, radius=10):
  function compute_click_values (line 55) | def compute_click_values(model_path, image_name, pix_xy, radius=5):
  function render_set (line 69) | def render_set(model_path, name, iteration, views, gaussians, pipeline, ...
  function render_sets (line 266) | def render_sets(dataset : ModelParams, iteration : int, pipeline : Pipel...

FILE: scripts/scannet2blender.py
  function load_transform_matrix (line 5) | def load_transform_matrix(file_path):
  function process_directory (line 13) | def process_directory(directory_path):

FILE: scripts/vis_opengs_pts_feat.py
  function sigmoid (line 5) | def sigmoid(x):
  function visualize_ply (line 9) | def visualize_ply(ply_path):

FILE: train.py
  function dec2binary (line 52) | def dec2binary(x, n_bits=None):
  function save_kmeans (line 62) | def save_kmeans(kmeans_list, quantized_params, out_dir, mode="root"):
  function cohesion_loss (line 102) | def cohesion_loss(feat_map, gt_mask, feat_mean_stack):
  function separation_loss (line 123) | def separation_loss(feat_mean_stack, iteration):
  function training (line 157) | def training(dataset, opt, pipe, testing_iterations, saving_iterations, ...
  function prepare_output_and_logger (line 596) | def prepare_output_and_logger(args):
  function construct_pseudo_ins_feat (line 618) | def construct_pseudo_ins_feat(scene : Scene, renderFunc, renderArgs,
  function training_report (line 912) | def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, t...
  function initialize_new_params (line 952) | def initialize_new_params(new_pt_cld, mean3_sq_dist):

FILE: utils/camera_utils.py
  function loadCam (line 20) | def loadCam(args, id, cam_info, resolution_scale):
  function cameraList_from_camInfos (line 76) | def cameraList_from_camInfos(cam_infos, resolution_scale, args):
  function camera_to_JSON (line 84) | def camera_to_JSON(id, camera : Camera):

FILE: utils/general_utils.py
  function inverse_sigmoid (line 18) | def inverse_sigmoid(x):
  function PILtoTorch (line 21) | def PILtoTorch(pil_image, resolution):
  function get_expon_lr_func (line 29) | def get_expon_lr_func(
  function strip_lowerdiag (line 64) | def strip_lowerdiag(L):
  function strip_symmetric (line 75) | def strip_symmetric(sym):
  function build_rotation (line 78) | def build_rotation(r):
  function build_scaling_rotation (line 101) | def build_scaling_rotation(s, r):
  function safe_state (line 112) | def safe_state(silent):

FILE: utils/graphics_utils.py
  class BasicPointCloud (line 17) | class BasicPointCloud(NamedTuple):
  function geom_transform_points (line 22) | def geom_transform_points(points, transf_matrix):
  function getWorld2View (line 31) | def getWorld2View(R, t):
  function getWorld2View2 (line 38) | def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
  function getProjectionMatrix (line 51) | def getProjectionMatrix(znear, zfar, fovX, fovY):
  function fov2focal (line 73) | def fov2focal(fov, pixels):
  function focal2fov (line 76) | def focal2fov(focal, pixels):

FILE: utils/image_utils.py
  function mse (line 14) | def mse(img1, img2):
  function psnr (line 17) | def psnr(img1, img2):

FILE: utils/loss_utils.py
  function l1_loss (line 17) | def l1_loss(network_output, gt, mask=None, weight=None):
  function l2_loss (line 25) | def l2_loss(network_output, gt, mask=None, weight=None):
  function gaussian (line 33) | def gaussian(window_size, sigma):
  function create_window (line 37) | def create_window(window_size, channel):
  function ssim (line 43) | def ssim(img1, img2, window_size=11, size_average=True):
  function _ssim (line 53) | def _ssim(img1, img2, window, window_size, channel, size_average=True):

FILE: utils/opengs_utlis.py
  function calculate_pairwise_distances (line 8) | def calculate_pairwise_distances(tensor1, tensor2, metric=None):
  function calculate_distances (line 36) | def calculate_distances(tensor1, tensor2, metric=None):
  function bin2dec (line 61) | def bin2dec(b, bits):
  function load_code_book (line 68) | def load_code_book(base_path):
  function calculate_iou (line 90) | def calculate_iou(masks1, masks2, base=None):
  function get_SAM_mask_and_feat (line 125) | def get_SAM_mask_and_feat(gt_sam_mask, level=3, filter_th=50, original_m...
  function pair_mask_feature_mean (line 184) | def pair_mask_feature_mean(feat_map, masks):
  function process_in_chunks (line 203) | def process_in_chunks(masks_expanded, masked_feats, mean_per_channel, ch...
  function calculate_variance_in_chunks (line 216) | def calculate_variance_in_chunks(masked_for_variance, mask_counts, chunk...
  function ele_multip_in_chunks (line 228) | def ele_multip_in_chunks(feat_expanded, masks_expanded, chunk_size=5):
  function mask_feature_mean (line 240) | def mask_feature_mean(feat_map, gt_masks, image_mask=None, return_var=Fa...
  function linear_to_srgb (line 285) | def linear_to_srgb(linear):
  function srgb_to_linear (line 300) | def srgb_to_linear(srgb):

FILE: utils/sh_utils.py
  function eval_sh (line 57) | def eval_sh(deg, sh, dirs):
  function RGB2SH (line 114) | def RGB2SH(rgb):
  function SH2RGB (line 117) | def SH2RGB(sh):

FILE: utils/system_utils.py
  function mkdir_p (line 16) | def mkdir_p(folder_path):
  function searchForMaxIteration (line 26) | def searchForMaxIteration(folder):
Condensed preview: 38 files, each showing path, character count, and a content snippet.
[
  {
    "path": ".gitignore",
    "chars": 202,
    "preview": "*.pyc\n# .vscode\n.git---\noutput\nbuild\ndiff_rasterization/diff_rast.egg-info\ndiff_rasterization/dist\ntensorboard_3d\nscreen"
  },
  {
    "path": "LICENSE.md",
    "chars": 4662,
    "preview": "Gaussian-Splatting License  \n===========================  \n\n**Inria** and **the Max Planck Institut for Informatik (MPII"
  },
  {
    "path": "README.md",
    "chars": 10954,
    "preview": "<div align=\"center\">\n\n# [NeurIPS2024🔥] OpenGaussian: Towards Point-Level 3D Gaussian-based Open Vocabulary Understanding"
  },
  {
    "path": "arguments/__init__.py",
    "chars": 5613,
    "preview": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# Thi"
  },
  {
    "path": "convert.py",
    "chars": 5349,
    "preview": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# Thi"
  },
  {
    "path": "environment.yml",
    "chars": 324,
    "preview": "name: gaussian_splatting\nchannels:\n  - pytorch\n  - conda-forge\n  - defaults\ndependencies:\n  - cudatoolkit=11.6\n  - plyfi"
  },
  {
    "path": "full_eval.py",
    "chars": 3340,
    "preview": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# Thi"
  },
  {
    "path": "gaussian_renderer/__init__.py",
    "chars": 17126,
    "preview": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# Thi"
  },
  {
    "path": "gaussian_renderer/network_gui.py",
    "chars": 2716,
    "preview": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# Thi"
  },
  {
    "path": "lpipsPyTorch/__init__.py",
    "chars": 635,
    "preview": "import torch\n\nfrom .modules.lpips import LPIPS\n\n\ndef lpips(x: torch.Tensor,\n          y: torch.Tensor,\n          net_typ"
  },
  {
    "path": "lpipsPyTorch/modules/lpips.py",
    "chars": 1151,
    "preview": "import torch\nimport torch.nn as nn\n\nfrom .networks import get_network, LinLayers\nfrom .utils import get_state_dict\n\n\ncla"
  },
  {
    "path": "lpipsPyTorch/modules/networks.py",
    "chars": 2692,
    "preview": "from typing import Sequence\n\nfrom itertools import chain\n\nimport torch\nimport torch.nn as nn\nfrom torchvision import mod"
  },
  {
    "path": "lpipsPyTorch/modules/utils.py",
    "chars": 885,
    "preview": "from collections import OrderedDict\n\nimport torch\n\n\ndef normalize_activation(x, eps=1e-10):\n    norm_factor = torch.sqrt"
  },
  {
    "path": "metrics.py",
    "chars": 4143,
    "preview": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# Thi"
  },
  {
    "path": "render.py",
    "chars": 5497,
    "preview": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# Thi"
  },
  {
    "path": "render_lerf_by_text.py",
    "chars": 12114,
    "preview": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# Thi"
  },
  {
    "path": "scene/__init__.py",
    "chars": 4121,
    "preview": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# Thi"
  },
  {
    "path": "scene/cameras.py",
    "chars": 3898,
    "preview": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# Thi"
  },
  {
    "path": "scene/colmap_loader.py",
    "chars": 11859,
    "preview": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# Thi"
  },
  {
    "path": "scene/dataset_readers.py",
    "chars": 14228,
    "preview": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# Thi"
  },
  {
    "path": "scene/gaussian_model.py",
    "chars": 24240,
    "preview": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# Thi"
  },
  {
    "path": "scene/kmeans_quantize.py",
    "chars": 13750,
    "preview": "import os\nimport pdb\nfrom tqdm import tqdm\nimport time\n\nimport torch\nimport numpy as np\nfrom torch import nn\nimport torc"
  },
  {
    "path": "scripts/compute_lerf_iou.py",
    "chars": 3162,
    "preview": "import os\nimport numpy as np\nfrom PIL import Image\nfrom argparse import ArgumentParser\n\ndef load_image_as_binary(image_p"
  },
  {
    "path": "scripts/eval_scannet.py",
    "chars": 7691,
    "preview": "import os\nfrom plyfile import PlyData, PlyElement\nimport torch.nn.functional as F\nimport numpy as np\nimport torch\nimport"
  },
  {
    "path": "scripts/render_by_click.py",
    "chars": 16356,
    "preview": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# Thi"
  },
  {
    "path": "scripts/scannet2blender.py",
    "chars": 3875,
    "preview": "import os\nimport json\nimport numpy as np\n\ndef load_transform_matrix(file_path):\n    \"\"\"\n    Load the transform matrix fr"
  },
  {
    "path": "scripts/train_lerf.sh",
    "chars": 3854,
    "preview": "#!/bin/bash\n# chmod +x scripts/train_lerf.sh\n# ./scripts/train_lerf.sh\n\n# !!! Please check the dataset path specified by"
  },
  {
    "path": "scripts/train_scannet.sh",
    "chars": 1494,
    "preview": "#!/bin/bash\n# chmod +x scripts/train_scannet.sh\n# ./scripts/train_scannet.sh\n\n# ============== [Notice] ==============\n#"
  },
  {
    "path": "scripts/vis_opengs_pts_feat.py",
    "chars": 1203,
    "preview": "import numpy as np\nfrom plyfile import PlyData\nimport open3d as o3d\n\ndef sigmoid(x):\n    \"\"\"Sigmoid function.\"\"\"\n    ret"
  },
  {
    "path": "train.py",
    "chars": 57226,
    "preview": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# Thi"
  },
  {
    "path": "utils/camera_utils.py",
    "chars": 3750,
    "preview": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# Thi"
  },
  {
    "path": "utils/general_utils.py",
    "chars": 3971,
    "preview": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# Thi"
  },
  {
    "path": "utils/graphics_utils.py",
    "chars": 2052,
    "preview": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# Thi"
  },
  {
    "path": "utils/image_utils.py",
    "chars": 554,
    "preview": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# Thi"
  },
  {
    "path": "utils/loss_utils.py",
    "chars": 2641,
    "preview": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# Thi"
  },
  {
    "path": "utils/opengs_utlis.py",
    "chars": 13911,
    "preview": "import torch\nimport numpy as np\nimport torch.nn.functional as F\nimport os\nfrom bitarray import bitarray\nfrom collections"
  },
  {
    "path": "utils/sh_utils.py",
    "chars": 4371,
    "preview": "#  Copyright 2021 The PlenOctree Authors.\n#  Redistribution and use in source and binary forms, with or without\n#  modif"
  },
  {
    "path": "utils/system_utils.py",
    "chars": 785,
    "preview": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# Thi"
  }
]

About this extraction

This page contains the full source code of the yanmin-wu/OpenGaussian GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 38 files (269.9 KB), approximately 70.9k tokens, and a symbol index with 189 extracted functions, classes, methods, constants, and types.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.
