Repository: yanmin-wu/OpenGaussian
Branch: main
Commit: 1f99db1b7e7d
Files: 38
Total size: 269.9 KB
Directory structure:
gitextract_gh4x629c/
├── .gitignore
├── LICENSE.md
├── README.md
├── arguments/
│   └── __init__.py
├── convert.py
├── environment.yml
├── full_eval.py
├── gaussian_renderer/
│   ├── __init__.py
│   └── network_gui.py
├── lpipsPyTorch/
│   ├── __init__.py
│   └── modules/
│       ├── lpips.py
│       ├── networks.py
│       └── utils.py
├── metrics.py
├── render.py
├── render_lerf_by_text.py
├── scene/
│   ├── __init__.py
│   ├── cameras.py
│   ├── colmap_loader.py
│   ├── dataset_readers.py
│   ├── gaussian_model.py
│   └── kmeans_quantize.py
├── scripts/
│   ├── compute_lerf_iou.py
│   ├── eval_scannet.py
│   ├── render_by_click.py
│   ├── scannet2blender.py
│   ├── train_lerf.sh
│   ├── train_scannet.sh
│   └── vis_opengs_pts_feat.py
├── train.py
└── utils/
    ├── camera_utils.py
    ├── general_utils.py
    ├── graphics_utils.py
    ├── image_utils.py
    ├── loss_utils.py
    ├── opengs_utlis.py
    ├── sh_utils.py
    └── system_utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
*.pyc
# .vscode
.git---
output
build
diff_rasterization/diff_rast.egg-info
diff_rasterization/dist
tensorboard_3d
screenshots
*.ipynb_checkpoints
# submodules/
# assets/
*.npz
*.bundle
output*
*.log
log
================================================
FILE: LICENSE.md
================================================
Gaussian-Splatting License
===========================
**Inria** and **the Max Planck Institut for Informatik (MPII)** hold all the ownership rights on the *Software* named **gaussian-splatting**.
The *Software* is in the process of being registered with the Agence pour la Protection des
Programmes (APP).
The *Software* is still being developed by the *Licensor*.
*Licensor*'s goal is to allow the research community to use, test and evaluate
the *Software*.
## 1. Definitions
*Licensee* means any person or entity that uses the *Software* and distributes
its *Work*.
*Licensor* means the owners of the *Software*, i.e Inria and MPII
*Software* means the original work of authorship made available under this
License ie gaussian-splatting.
*Work* means the *Software* and any additions to or derivative works of the
*Software* that are made available under this License.
## 2. Purpose
This license is intended to define the rights granted to the *Licensee* by
Licensors under the *Software*.
## 3. Rights granted
For the above reasons Licensors have decided to distribute the *Software*.
Licensors grant non-exclusive rights to use the *Software* for research purposes
to research users (both academic and industrial), free of charge, without right
to sublicense.. The *Software* may be used "non-commercially", i.e., for research
and/or evaluation purposes only.
Subject to the terms and conditions of this License, you are granted a
non-exclusive, royalty-free, license to reproduce, prepare derivative works of,
publicly display, publicly perform and distribute its *Work* and any resulting
derivative works in any form.
## 4. Limitations
**4.1 Redistribution.** You may reproduce or distribute the *Work* only if (a) you do
so under this License, (b) you include a complete copy of this License with
your distribution, and (c) you retain without modification any copyright,
patent, trademark, or attribution notices that are present in the *Work*.
**4.2 Derivative Works.** You may specify that additional or different terms apply
to the use, reproduction, and distribution of your derivative works of the *Work*
("Your Terms") only if (a) Your Terms provide that the use limitation in
Section 2 applies to your derivative works, and (b) you identify the specific
derivative works that are subject to Your Terms. Notwithstanding Your Terms,
this License (including the redistribution requirements in Section 3.1) will
continue to apply to the *Work* itself.
**4.3** Any other use without of prior consent of Licensors is prohibited. Research
users explicitly acknowledge having received from Licensors all information
allowing to appreciate the adequacy between of the *Software* and their needs and
to undertake all necessary precautions for its execution and use.
**4.4** The *Software* is provided both as a compiled library file and as source
code. In case of using the *Software* for a publication or other results obtained
through the use of the *Software*, users are strongly encouraged to cite the
corresponding publications as explained in the documentation of the *Software*.
## 5. Disclaimer
THE USER CANNOT USE, EXPLOIT OR DISTRIBUTE THE *SOFTWARE* FOR COMMERCIAL PURPOSES
WITHOUT PRIOR AND EXPLICIT CONSENT OF LICENSORS. YOU MUST CONTACT INRIA FOR ANY
UNAUTHORIZED USE: stip-sophia.transfert@inria.fr . ANY SUCH ACTION WILL
CONSTITUTE A FORGERY. THIS *SOFTWARE* IS PROVIDED "AS IS" WITHOUT ANY WARRANTIES
OF ANY NATURE AND ANY EXPRESS OR IMPLIED WARRANTIES, WITH REGARDS TO COMMERCIAL
USE, PROFESSIONNAL USE, LEGAL OR NOT, OR OTHER, OR COMMERCIALISATION OR
ADAPTATION. UNLESS EXPLICITLY PROVIDED BY LAW, IN NO EVENT, SHALL INRIA OR THE
AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
GOODS OR SERVICES, LOSS OF USE, DATA, OR PROFITS OR BUSINESS INTERRUPTION)
HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING FROM, OUT OF OR
IN CONNECTION WITH THE *SOFTWARE* OR THE USE OR OTHER DEALINGS IN THE *SOFTWARE*.
## 6. Files subject to permissive licenses
The contents of the file ```utils/loss_utils.py``` are based on publicly available code authored by Evan Su, which falls under the permissive MIT license.
Title: pytorch-ssim\
Project code: https://github.com/Po-Hsun-Su/pytorch-ssim\
Copyright Evan Su, 2017\
License: https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/LICENSE.txt (MIT)
================================================
FILE: README.md
================================================
# [NeurIPS2024🔥] OpenGaussian: Towards Point-Level 3D Gaussian-based Open Vocabulary Understanding
[Yanmin Wu](https://yanmin-wu.github.io/)<sup>1</sup>, [Jiarui Meng](https://scholar.google.com/citations?user=N_pRAVAAAAAJ&hl=en&oi=ao)<sup>1</sup>, [Haijie Li](https://villa.jianzhang.tech/people/haijie-li-%E6%9D%8E%E6%B5%B7%E6%9D%B0/)<sup>1</sup>, [Chenming Wu](https://chenming-wu.github.io/)<sup>2*</sup>, [Yahao Shi](https://scholar.google.com/citations?user=-VJZrUkAAAAJ&hl=en)<sup>3</sup>, [Xinhua Cheng](https://cxh0519.github.io/)<sup>1</sup>, [Chen Zhao](https://openreview.net/profile?id=~Chen_Zhao9)<sup>2</sup>, [Haocheng Feng](https://openreview.net/profile?id=~Haocheng_Feng1)<sup>2</sup>, [Errui Ding](https://scholar.google.com/citations?user=1wzEtxcAAAAJ&hl=zh-CN)<sup>2</sup>, [Jingdong Wang](https://jingdongwang2017.github.io/)<sup>2</sup>, [Jian Zhang](https://jianzhang.tech/)<sup>1*</sup>

<sup>1</sup> Peking University, <sup>2</sup> Baidu VIS, <sup>3</sup> Beihang University
## 0. Installation
The installation of OpenGaussian is similar to [3D Gaussian Splatting](https://github.com/graphdeco-inria/gaussian-splatting).
```
git clone https://github.com/yanmin-wu/OpenGaussian.git
```
Then install the dependencies:
```shell
conda env create --file environment.yml
conda activate gaussian_splatting
# the rasterization lib comes from DreamGaussian
cd OpenGaussian/submodules
unzip ashawkey-diff-gaussian-rasterization.zip
pip install ./ashawkey-diff-gaussian-rasterization
```
+ Additional dependencies: bitarray, scipy, and [pytorch3d](https://anaconda.org/pytorch3d/pytorch3d/files)
```shell
pip install bitarray scipy
# install a pytorch3d version compatible with your PyTorch, Python, and CUDA.
```
+ `simple-knn` is not required
---
## 1. ToDo list
+ [x] Point feature visualization
+ [x] Data preprocessing
+ ~~[ ] Improved SAM mask extraction (extracting only one layer)~~
+ [x] Click to Select 3D Object
---
## 2. Data preparation
The files are as follows:
```
[DATA_ROOT]
├── [1] scannet/
│   ├── scene0000_00/
│   │   ├── color/
│   │   ├── language_features/
│   │   ├── points3d.ply
│   │   ├── transforms_train/test.json
│   │   └── *_vh_clean_2.labels.ply
│   ├── scene0062_00/
│   └── ...
└── [2] lerf_ovs/
    ├── figurines/ & ramen/ & teatime/ & waldo_kitchen/
    │   ├── images/
    │   ├── language_features/
    │   └── sparse/
    └── label/
+ **[1] Prepare ScanNet Data**
+ You can directly download our pre-processed data: [**OneDrive**](https://onedrive.live.com/?authkey=%21AIgsXZy3gl%5FuKmM&id=744D3E86422BE3C9%2139813&cid=744D3E86422BE3C9) / [Baidu](https://pan.baidu.com/s/1B_tGYla5dWyJRu3jTNTMvA?pwd=u5iy). Please unzip the `color.zip` and `language_features.zip` files.
+ The ScanNet dataset requires permission for use; follow the [ScanNet instructions](https://github.com/ScanNet/ScanNet) to apply for access.
+ **If you want to process more scenes from the ScanNet dataset, you can follow these steps:**
+ First, use the official `download-scannet.py` script provided by ScanNet to download the `.sens` archive of the specified scenes;
+ Then, refer to the [`preprocess_2d_scannet.py`](https://github.com/pengsongyou/openscene/blob/main/scripts/preprocess/preprocess_2d_scannet.py) script to extract the `color` and `pose` information;
+ Finally, convert the data into Blender format using the [`scripts/scannet2blender.py`](https://github.com/yanmin-wu/OpenGaussian/blob/main/scripts/scannet2blender.py) script. Please check the `TODO` comments in the script to specify the paths.
+ **[2] Prepare lerf_ovs Data**
+ You can directly download our pre-processed data: [**OneDrive**](https://onedrive.live.com/?authkey=%21AIgsXZy3gl%5FuKmM&id=744D3E86422BE3C9%2139815&cid=744D3E86422BE3C9) / [Baidu](https://pan.baidu.com/s/1B_tGYla5dWyJRu3jTNTMvA?pwd=u5iy) (re-annotated by LangSplat). Please unzip the `images.zip` and `language_features.zip` files.
+ **Mask and Language Feature Extraction Details**
+ We use the tools provided by LangSplat to extract the SAM mask and CLIP features, but we only use the large-level mask.
---
## 3. Training
### 3.1 ScanNet
```shell
chmod +x scripts/train_scannet.sh
./scripts/train_scannet.sh
```
+ Please ***check*** the script for more details and ***modify*** the dataset path.
+ You will see the following stages during training:
```shell
[Stage 0] Start 3dgs pre-train ... (step 0-30k)
[Stage 1] Start continuous instance feature learning ... (step 30-50k)
[Stage 2.1] Start coarse-level codebook discretization ... (step 50-70k)
[Stage 2.2] Start fine-level codebook discretization ... (step 70-90k)
[Stage 3] Start 2D language feature - 3D cluster association ... (1 min)
```
+ Intermediate results from different stages can be found in subfolders `***/train_process/stage*`. (Stage 3 intermediate results are best inspected on the LeRF dataset.)
### 3.2 LeRF_ovs
```shell
chmod +x scripts/train_lerf.sh
./scripts/train_lerf.sh
```
+ Please ***check*** the script for more details and ***modify*** the dataset path.
+ You will see the following stages during training:
```shell
[Stage 0] Start 3dgs pre-train ... (step 0-30k)
[Stage 1] Start continuous instance feature learning ... (step 30-40k)
[Stage 2.1] Start coarse-level codebook discretization ... (step 40-50k)
[Stage 2.2] Start fine-level codebook discretization ... (step 50-70k)
[Stage 3] Start 2D language feature - 3D cluster association ... (1 min)
```
+ Intermediate results from different stages can be found in subfolders `***/train_process/stage*`.
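
The two schedules above differ only in where each stage starts. As a minimal sketch, the mapping from iteration to stage can be written as below. The defaults mirror the `--start_ins_feat_iter`, `--start_root_cb_iter`, `--start_leaf_cb_iter`, and `--iterations` defaults in `arguments/__init__.py`; the helper name `training_stage` and the exact boundary handling (a stage starts at its `start_*` iteration) are assumptions, since `train.py` may switch stages one step earlier or later.

```python
def training_stage(iteration,
                   start_ins_feat_iter=30_000,
                   start_root_cb_iter=40_000,
                   start_leaf_cb_iter=50_000,
                   iterations=70_000):
    """Return the stage label for a training iteration (hypothetical helper)."""
    if not 0 <= iteration <= iterations:
        raise ValueError(f"iteration must be in [0, {iterations}]")
    if iteration < start_ins_feat_iter:
        return "Stage 0"    # 3DGS pre-training
    if iteration < start_root_cb_iter:
        return "Stage 1"    # continuous instance feature learning
    if iteration < start_leaf_cb_iter:
        return "Stage 2.1"  # coarse-level codebook discretization
    return "Stage 2.2"      # fine-level codebook discretization


# LeRF schedule (the defaults) versus the ScanNet schedule from section 3.1.
print(training_stage(45_000))
print(training_stage(45_000, 30_000, 50_000, 70_000, 90_000))
```

Stage 3 (2D language feature to 3D cluster association) runs after the last iteration, so it is not part of this mapping.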
### 3.3 Custom data
+ No special processing is required: capture a video, sample approximately 200 frames from it, and run COLMAP to initialize the point cloud and camera poses.
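
The frame-sampling step can be sketched as follows. This is an illustration only: the repository does not ship a sampling script, so the helper name `sample_frame_indices` and the choice of evenly spaced frames are assumptions.

```python
def sample_frame_indices(total_frames, num_samples=200):
    """Pick evenly spaced frame indices from a video (hypothetical helper)."""
    if total_frames <= 0 or num_samples <= 0:
        raise ValueError("total_frames and num_samples must be positive")
    if total_frames <= num_samples:
        # Short video: keep every frame.
        return list(range(total_frames))
    if num_samples == 1:
        return [0]
    # Spread the samples from the first frame to the last one.
    last = total_frames - 1
    return [round(i * last / (num_samples - 1)) for i in range(num_samples)]


indices = sample_frame_indices(total_frames=3000)
print(len(indices), indices[0], indices[-1])
```

The selected frames would then go into the `input/` folder that `convert.py` passes to COLMAP.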
---
## 4. Render & Eval & Downstream Tasks
### 4.1 3D Instance Feature Visualization
+ Please install `open3d` first, and then execute the following command on a system with UI support:
```shell
python scripts/vis_opengs_pts_feat.py
```
+ Set `ply_path` in the script to a PLY file saved at one of the training stages: `output/xxxxxxxx-x/point_cloud/iteration_x0000/point_cloud.ply`.
+ During training, the first three dimensions of the 6D features are saved as colors for visualization; see [here](https://github.com/yanmin-wu/OpenGaussian/blob/2845b9c744c1b06ac6930ffa2d2a6f9167f1b843/scene/gaussian_model.py#L272).
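
As a rough sketch of the feature-to-color idea, the snippet below maps the first three channels of a 6D feature to RGB. It assumes features lie in `[-1, 1]` and uses the same `(x + 1) / 2` mapping that `gaussian_renderer/__init__.py` applies before rendering; whether the saved PLY uses exactly this mapping is not verified here, and `ins_feat_to_rgb` is a hypothetical name.

```python
import numpy as np


def ins_feat_to_rgb(ins_feat):
    """Map the first three feature channels to RGB in [0, 1] (hypothetical helper)."""
    feat = np.asarray(ins_feat, dtype=np.float32)
    rgb = (feat[:, :3] + 1.0) / 2.0  # assumed range [-1, 1] -> [0, 1]
    return np.clip(rgb, 0.0, 1.0)


# Two points with 6D features; only the first three channels are used.
feats = np.array([[-1.0, 0.0, 1.0, 0.3, 0.3, 0.3],
                  [0.5, 0.5, 0.5, -0.2, 0.1, 0.9]])
print(ins_feat_to_rgb(feats))
```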
### 4.2 Render 2D Feature Map
+ Feature maps are rendered with the same method 3DGS uses to render colors.
```shell
python render.py -m "output/xxxxxxxx-x"
```
You can find the rendered feature maps in subfolders `renders_ins_feat1` and `renders_ins_feat2`.
### 4.3 ScanNet Evaluation (Open-Vocabulary Point Cloud Understanding)
> Due to code optimization and the use of more suitable hyperparameters, the latest evaluation metrics may be higher than those reported in the paper.
+ Evaluate text-guided segmentation performance on ScanNet for 19, 15, and 10 categories.
```shell
# unzip the pre-extracted text features
cd assets
unzip text_features.zip
# 1. check that `gt_file_path` and `model_path` are correct
# 2. specify `target_id` as 19, 15, or 10 categories.
python scripts/eval_scannet.py
```
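
One common way to turn per-point language features and text features into per-point labels is cosine-similarity matching. The sketch below illustrates that generic technique only; it is not taken from `scripts/eval_scannet.py`, and `assign_labels`, `point_feats`, and `text_feats` are hypothetical names.

```python
import numpy as np


def assign_labels(point_feats, text_feats, eps=1e-12):
    """Label each point with the index of its most similar text feature."""
    p = np.asarray(point_feats, dtype=np.float64)
    t = np.asarray(text_feats, dtype=np.float64)
    # L2-normalize so the dot product equals cosine similarity.
    p = p / np.maximum(np.linalg.norm(p, axis=1, keepdims=True), eps)
    t = t / np.maximum(np.linalg.norm(t, axis=1, keepdims=True), eps)
    return np.argmax(p @ t.T, axis=1)


points = np.array([[1.0, 0.0], [0.0, 2.0], [3.0, 3.1]])
texts = np.array([[1.0, 0.0], [0.0, 1.0]])  # one row per category
print(assign_labels(points, texts))
```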
### 4.4 LeRF Evaluation (Open-Vocabulary Object Selection in 3D Space)
+ (1) First, render text-selected 3D Gaussians into multi-view images.
```shell
# unzip the pre-extracted text features
cd assets
unzip text_features.zip
# 1. specify the model path using -m
# 2. specify the scene name: figurines, teatime, ramen, waldo_kitchen
python render_lerf_by_text.py -m "output/xxxxxxxx-x" --scene_name "figurines"
```
The object selection results are saved in `output/xxxxxxxx-x/text2obj/ours_70000/renders_cluster`.
+ (2) Then, compute evaluation metrics.
> Due to code optimization and the use of more suitable hyperparameters, the latest evaluation metrics may be higher than those reported in the paper.
> The metrics may be unstable due to the limited evaluation samples of LeRF.
```shell
# 1. change path_gt and path_pred in the script
# 2. specify the scene name: figurines, teatime, ramen, waldo_kitchen
python scripts/compute_lerf_iou.py --scene_name "figurines"
```
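
For reference, intersection over union between a predicted and a ground-truth binary mask can be computed as in the sketch below. This is the generic metric, not a copy of `scripts/compute_lerf_iou.py`; the helper name `mask_iou` and the convention of returning 0 for two empty masks are assumptions.

```python
import numpy as np


def mask_iou(pred, gt):
    """Intersection over union of two binary masks (hypothetical helper)."""
    pred = np.asarray(pred, dtype=bool)
    gt = np.asarray(gt, dtype=bool)
    union = np.logical_or(pred, gt).sum()
    if union == 0:
        return 0.0  # both masks empty: treated as no overlap by convention
    return float(np.logical_and(pred, gt).sum() / union)


pred = np.array([[1, 1], [0, 0]])
gt = np.array([[1, 0], [1, 0]])
print(mask_iou(pred, gt))
```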
### 4.5 Click to Select 3D Object
+ (1) First, render the feature maps (refer to Step 4.2; in practice, only two feature maps from a single view are required).
+ (2) Then, check the [`scripts/render_by_click.py`](https://github.com/yanmin-wu/OpenGaussian/blob/main/scripts/render_by_click.py) script for `TODO` comments, including specifying the frame filename, clicked pixel coordinates, and file paths.
+ (3) Finally, run the [`scripts/render_by_click.py`](https://github.com/yanmin-wu/OpenGaussian/blob/main/scripts/render_by_click.py) script. *Note that this script has not been tested with the current version of the code and may require debugging*.
---
## 5. Acknowledgements
We are grateful to [3DGS](https://github.com/graphdeco-inria/gaussian-splatting), [LangSplat](https://github.com/minghanqin/LangSplat), [CompGS](https://github.com/UCDvision/compact3d), [LEGaussians](https://github.com/buaavrcg/LEGaussians), [SAGA](https://github.com/Jumpat/SegAnyGAussians), and [SAM](https://segment-anything.com/).
---
## 6. Citation
```
@inproceedings{wu2024opengaussian,
title={OpenGaussian: Towards Point-Level 3D Gaussian-based Open Vocabulary Understanding},
author={Wu, Yanmin and Meng, Jiarui and Li, Haijie and Wu, Chenming and Shi, Yahao and Cheng, Xinhua and Zhao, Chen and Feng, Haocheng and Ding, Errui and Wang, Jingdong and Zhang, Jian},
booktitle={Proceedings of the Advances in Neural Information Processing Systems (NeurIPS)},
pages={19114--19138},
year={2024}
}
```
---
## 7. Contact
If you have any questions about this project, please feel free to contact [Yanmin Wu](https://yanmin-wu.github.io/): wuyanminmax[AT]gmail.com
================================================
FILE: arguments/__init__.py
================================================
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact george.drettakis@inria.fr
#
from argparse import ArgumentParser, Namespace
import sys
import os

class GroupParams:
    pass

class ParamGroup:
    def __init__(self, parser: ArgumentParser, name : str, fill_none = False):
        group = parser.add_argument_group(name)
        for key, value in vars(self).items():
            shorthand = False
            if key.startswith("_"):
                shorthand = True
                key = key[1:]
            t = type(value)
            value = value if not fill_none else None
            if shorthand:
                if t == bool:
                    group.add_argument("--" + key, ("-" + key[0:1]), default=value, action="store_true")
                else:
                    group.add_argument("--" + key, ("-" + key[0:1]), default=value, type=t)
            else:
                if t == bool:
                    group.add_argument("--" + key, default=value, action="store_true")
                else:
                    group.add_argument("--" + key, default=value, type=t)

    def extract(self, args):
        group = GroupParams()
        for arg in vars(args).items():
            if arg[0] in vars(self) or ("_" + arg[0]) in vars(self):
                setattr(group, arg[0], arg[1])
        return group

class ModelParams(ParamGroup):
    def __init__(self, parser, sentinel=False):
        self.sh_degree = 3
        self._source_path = ""
        self._model_path = ""
        self._images = "images"
        self._resolution = -1
        self._white_background = False
        self.data_device = "cuda"
        self.eval = False
        super().__init__(parser, "Loading Parameters", sentinel)

    def extract(self, args):
        g = super().extract(args)
        g.source_path = os.path.abspath(g.source_path)
        return g

class PipelineParams(ParamGroup):
    def __init__(self, parser):
        self.convert_SHs_python = False
        self.compute_cov3D_python = False
        self.debug = False
        super().__init__(parser, "Pipeline Parameters")

class OptimizationParams(ParamGroup):
    def __init__(self, parser):
        self.leaf_update_fr = 300   # coarse-level codebook update frequency
        self.ins_feat_dim = 6
        self.position_lr_init = 0.00016
        self.position_lr_final = 0.0000016
        self.position_lr_delay_mult = 0.01
        self.position_lr_max_steps = 30_000
        self.feature_lr = 0.0025
        self.ins_feat_lr = 0.001
        self.opacity_lr = 0.05
        self.scaling_lr = 0.005
        self.rotation_lr = 0.001
        self.percent_dense = 0.01
        self.lambda_dssim = 0.2
        self.densification_interval = 100
        self.opacity_reset_interval = 3000
        self.densify_from_iter = 500
        self.densify_until_iter = 15_000
        self.densify_grad_threshold = 0.0002
        self.random_background = False

        parser.add_argument('--root_node_num', type=int, default=64)    # k1=64
        parser.add_argument('--leaf_node_num', type=int, default=5)     # k2=5/10
        parser.add_argument('--pos_weight', type=float, default=1.0)    # position weight for coarse codebook
        parser.add_argument('--loss_weight', type=float, default=0.1)   # loss_cohesion weight

        parser.add_argument('--iterations', type=int, default=70_000)           # default 70k, ScanNet 90k
        parser.add_argument('--start_ins_feat_iter', type=int, default=30_000)  # default 30k
        parser.add_argument('--start_root_cb_iter', type=int, default=40_000)   # default 40k, ScanNet 50k
        parser.add_argument('--start_leaf_cb_iter', type=int, default=50_000)   # default 50k, ScanNet 70k
        # Note (ScanNet): freeze the positions of the initial points and do not densify.
        parser.add_argument('--frozen_init_pts', action='store_true', default=False)
        parser.add_argument('--sam_level', type=int, default=3)
        parser.add_argument('--save_memory', action='store_true', default=False)
        super().__init__(parser, "Optimization Parameters")

    def extract(self, args):
        g = super().extract(args)
        g.root_node_num = args.root_node_num
        g.leaf_node_num = args.leaf_node_num
        g.pos_weight = args.pos_weight
        g.loss_weight = args.loss_weight
        g.frozen_init_pts = args.frozen_init_pts
        g.sam_level = args.sam_level
        g.iterations = args.iterations
        g.start_ins_feat_iter = args.start_ins_feat_iter
        g.start_root_cb_iter = args.start_root_cb_iter
        g.start_leaf_cb_iter = args.start_leaf_cb_iter
        g.save_memory = args.save_memory
        return g
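
def get_combined_args(parser : ArgumentParser):
    cmdlne_string = sys.argv[1:]
    cfgfile_string = "Namespace()"
    args_cmdline = parser.parse_args(cmdlne_string)

    try:
        cfgfilepath = os.path.join(args_cmdline.model_path, "cfg_args")
        print("Looking for config file in", cfgfilepath)
        with open(cfgfilepath) as cfg_file:
            print("Config file found: {}".format(cfgfilepath))
            cfgfile_string = cfg_file.read()
    except (TypeError, FileNotFoundError):
        # TypeError: model_path is None; FileNotFoundError: no cfg_args was saved there.
        print("Config file not found; using command-line arguments only")
    # cfg_args holds the repr of a Namespace written at training time;
    # eval() is only safe because that file is trusted.
    args_cfgfile = eval(cfgfile_string)

    # Command-line values override the saved config unless they are None.
    merged_dict = vars(args_cfgfile).copy()
    for k, v in vars(args_cmdline).items():
        if v is not None:
            merged_dict[k] = v
    return Namespace(**merged_dict)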
================================================
FILE: convert.py
================================================
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact george.drettakis@inria.fr
#
import os
import logging
from argparse import ArgumentParser
import shutil
# This Python script is based on the shell converter script provided in the MipNerF 360 repository.
parser = ArgumentParser("Colmap converter")
parser.add_argument("--no_gpu", action='store_true')
parser.add_argument("--skip_matching", action='store_true')
parser.add_argument("--source_path", "-s", required=True, type=str)
parser.add_argument("--camera", default="OPENCV", type=str)
parser.add_argument("--colmap_executable", default="", type=str)
parser.add_argument("--resize", action="store_true")
parser.add_argument("--magick_executable", default="", type=str)
args = parser.parse_args()
colmap_command = '"{}"'.format(args.colmap_executable) if len(args.colmap_executable) > 0 else "colmap"
magick_command = '"{}"'.format(args.magick_executable) if len(args.magick_executable) > 0 else "magick"
use_gpu = 1 if not args.no_gpu else 0

if not args.skip_matching:
    os.makedirs(args.source_path + "/distorted/sparse", exist_ok=True)

    ## Feature extraction
    feat_extracton_cmd = colmap_command + " feature_extractor "\
        "--database_path " + args.source_path + "/distorted/database.db \
        --image_path " + args.source_path + "/input \
        --ImageReader.single_camera 1 \
        --ImageReader.camera_model " + args.camera + " \
        --SiftExtraction.use_gpu " + str(use_gpu)
    exit_code = os.system(feat_extracton_cmd)
    if exit_code != 0:
        logging.error(f"Feature extraction failed with code {exit_code}. Exiting.")
        exit(exit_code)

    ## Feature matching
    feat_matching_cmd = colmap_command + " exhaustive_matcher \
        --database_path " + args.source_path + "/distorted/database.db \
        --SiftMatching.use_gpu " + str(use_gpu)
    exit_code = os.system(feat_matching_cmd)
    if exit_code != 0:
        logging.error(f"Feature matching failed with code {exit_code}. Exiting.")
        exit(exit_code)

    ### Bundle adjustment
    # The default Mapper tolerance is unnecessarily large,
    # decreasing it speeds up bundle adjustment steps.
    mapper_cmd = (colmap_command + " mapper \
        --database_path " + args.source_path + "/distorted/database.db \
        --image_path "  + args.source_path + "/input \
        --output_path "  + args.source_path + "/distorted/sparse \
        --Mapper.ba_global_function_tolerance=0.000001")
    exit_code = os.system(mapper_cmd)
    if exit_code != 0:
        logging.error(f"Mapper failed with code {exit_code}. Exiting.")
        exit(exit_code)

### Image undistortion
## We need to undistort our images into ideal pinhole intrinsics.
img_undist_cmd = (colmap_command + " image_undistorter \
    --image_path " + args.source_path + "/input \
    --input_path " + args.source_path + "/distorted/sparse/0 \
    --output_path " + args.source_path + "\
    --output_type COLMAP")
exit_code = os.system(img_undist_cmd)
if exit_code != 0:
    logging.error(f"Image undistortion failed with code {exit_code}. Exiting.")
    exit(exit_code)

files = os.listdir(args.source_path + "/sparse")
os.makedirs(args.source_path + "/sparse/0", exist_ok=True)
# Move each file from sparse/ into sparse/0/
for file in files:
    if file == '0':
        continue
    source_file = os.path.join(args.source_path, "sparse", file)
    destination_file = os.path.join(args.source_path, "sparse", "0", file)
    shutil.move(source_file, destination_file)

if(args.resize):
    print("Copying and resizing...")

    # Resize images.
    os.makedirs(args.source_path + "/images_2", exist_ok=True)
    os.makedirs(args.source_path + "/images_4", exist_ok=True)
    os.makedirs(args.source_path + "/images_8", exist_ok=True)
    # Get the list of files in the source directory
    files = os.listdir(args.source_path + "/images")
    # Copy each file from the source directory to the destination directory
    for file in files:
        source_file = os.path.join(args.source_path, "images", file)

        destination_file = os.path.join(args.source_path, "images_2", file)
        shutil.copy2(source_file, destination_file)
        exit_code = os.system(magick_command + " mogrify -resize 50% " + destination_file)
        if exit_code != 0:
            logging.error(f"50% resize failed with code {exit_code}. Exiting.")
            exit(exit_code)

        destination_file = os.path.join(args.source_path, "images_4", file)
        shutil.copy2(source_file, destination_file)
        exit_code = os.system(magick_command + " mogrify -resize 25% " + destination_file)
        if exit_code != 0:
            logging.error(f"25% resize failed with code {exit_code}. Exiting.")
            exit(exit_code)

        destination_file = os.path.join(args.source_path, "images_8", file)
        shutil.copy2(source_file, destination_file)
        exit_code = os.system(magick_command + " mogrify -resize 12.5% " + destination_file)
        if exit_code != 0:
            logging.error(f"12.5% resize failed with code {exit_code}. Exiting.")
            exit(exit_code)

print("Done.")
================================================
FILE: environment.yml
================================================
name: gaussian_splatting
channels:
- pytorch
- conda-forge
- defaults
dependencies:
- cudatoolkit=11.6
- plyfile=0.8.1
- python=3.7.13
- pip=22.3.1
- pytorch=1.12.1
- torchaudio=0.12.1
- torchvision=0.13.1
- tqdm
- pip:
  - bitarray
  - scipy
  - submodules/ashawkey-diff-gaussian-rasterization
================================================
FILE: full_eval.py
================================================
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact george.drettakis@inria.fr
#
import os
from argparse import ArgumentParser
mipnerf360_outdoor_scenes = ["bicycle", "flowers", "garden", "stump", "treehill"]
mipnerf360_indoor_scenes = ["room", "counter", "kitchen", "bonsai"]
tanks_and_temples_scenes = ["truck", "train"]
deep_blending_scenes = ["drjohnson", "playroom"]
parser = ArgumentParser(description="Full evaluation script parameters")
parser.add_argument("--skip_training", action="store_true")
parser.add_argument("--skip_rendering", action="store_true")
parser.add_argument("--skip_metrics", action="store_true")
parser.add_argument("--output_path", default="./eval")
args, _ = parser.parse_known_args()
all_scenes = []
all_scenes.extend(mipnerf360_outdoor_scenes)
all_scenes.extend(mipnerf360_indoor_scenes)
all_scenes.extend(tanks_and_temples_scenes)
all_scenes.extend(deep_blending_scenes)

if not args.skip_training or not args.skip_rendering:
    parser.add_argument('--mipnerf360', "-m360", required=True, type=str)
    parser.add_argument("--tanksandtemples", "-tat", required=True, type=str)
    parser.add_argument("--deepblending", "-db", required=True, type=str)
    args = parser.parse_args()

if not args.skip_training:
    common_args = " --quiet --eval --test_iterations -1 "
    for scene in mipnerf360_outdoor_scenes:
        source = args.mipnerf360 + "/" + scene
        os.system("python train.py -s " + source + " -i images_4 -m " + args.output_path + "/" + scene + common_args)
    for scene in mipnerf360_indoor_scenes:
        source = args.mipnerf360 + "/" + scene
        os.system("python train.py -s " + source + " -i images_2 -m " + args.output_path + "/" + scene + common_args)
    for scene in tanks_and_temples_scenes:
        source = args.tanksandtemples + "/" + scene
        os.system("python train.py -s " + source + " -m " + args.output_path + "/" + scene + common_args)
    for scene in deep_blending_scenes:
        source = args.deepblending + "/" + scene
        os.system("python train.py -s " + source + " -m " + args.output_path + "/" + scene + common_args)

if not args.skip_rendering:
    all_sources = []
    for scene in mipnerf360_outdoor_scenes:
        all_sources.append(args.mipnerf360 + "/" + scene)
    for scene in mipnerf360_indoor_scenes:
        all_sources.append(args.mipnerf360 + "/" + scene)
    for scene in tanks_and_temples_scenes:
        all_sources.append(args.tanksandtemples + "/" + scene)
    for scene in deep_blending_scenes:
        all_sources.append(args.deepblending + "/" + scene)

    common_args = " --quiet --eval --skip_train"
    for scene, source in zip(all_scenes, all_sources):
        os.system("python render.py --iteration 7000 -s " + source + " -m " + args.output_path + "/" + scene + common_args)
        os.system("python render.py --iteration 30000 -s " + source + " -m " + args.output_path + "/" + scene + common_args)

if not args.skip_metrics:
    scenes_string = ""
    for scene in all_scenes:
        scenes_string += "\"" + args.output_path + "/" + scene + "\" "

    os.system("python metrics.py -m " + scenes_string)
================================================
FILE: gaussian_renderer/__init__.py
================================================
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact george.drettakis@inria.fr
#
import torch
import math
# from diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
from ashawkey_diff_gaussian_rasterization import GaussianRasterizationSettings, GaussianRasterizer
from scene.gaussian_model import GaussianModel
from utils.sh_utils import eval_sh
from utils.opengs_utlis import *
# from sklearn.neighbors import NearestNeighbors
import pytorch3d.ops

def render(viewpoint_camera, pc : GaussianModel, pipe, bg_color : torch.Tensor, iteration,
           scaling_modifier = 1.0, override_color = None, visible_mask = None, mask_num=0,
           cluster_idx=None,        # per-point cluster id (coarse-level)
           leaf_cluster_idx=None,   # per-point cluster id (fine-level)
           rescale=True,            # re-scale (for enhance ins_feat)
           origin_feat=False,       # origin ins_feat (not quantized)
           render_feat_map=True,    # render image-level feat map
           render_color=True,       # render rgb image
           render_cluster=False,    # render cluster, stage 2.2
           better_vis=False,        # filter some points
           selected_root_id=None,   # coarse-level cluster id
           selected_leaf_id=None,   # fine-level cluster id (possibly more than one)
           pre_mask=None,
           seg_rgb=False,           # render cluster rgb, not feat
           post_process=False,      # post
           root_num=64, leaf_num=10):   # k1, k2
    """
    Render the scene.

    Background tensor (bg_color) must be on GPU!
    """

    # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
    screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0
    try:
        screenspace_points.retain_grad()
    except:
        pass

    # Set up rasterization configuration
    tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
    tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)

    raster_settings = GaussianRasterizationSettings(
        image_height=int(viewpoint_camera.image_height),
        image_width=int(viewpoint_camera.image_width),
        tanfovx=tanfovx,
        tanfovy=tanfovy,
        bg=bg_color,
        scale_modifier=scaling_modifier,
        viewmatrix=viewpoint_camera.world_view_transform,
        projmatrix=viewpoint_camera.full_proj_transform,
        sh_degree=pc.active_sh_degree,
        campos=viewpoint_camera.camera_center,
        prefiltered=False,
        debug=pipe.debug
    )

    rasterizer = GaussianRasterizer(raster_settings=raster_settings)

    means3D = pc.get_xyz
    means2D = screenspace_points
    opacity = pc.get_opacity

    # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
    # scaling / rotation by the rasterizer.
    scales = None
    rotations = None
    cov3D_precomp = None
    if pipe.compute_cov3D_python:
        cov3D_precomp = pc.get_covariance(scaling_modifier)
    else:
        scales = pc.get_scaling
        rotations = pc.get_rotation

    # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
    # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
    shs = None
    colors_precomp = None
    if override_color is None:
        if pipe.convert_SHs_python:
            shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2)
            dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1))
            dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True)
            sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized)
            colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0)
        else:
            shs = pc.get_features
    else:
        colors_precomp = override_color

    if render_color:
        rendered_image, radii, rendered_depth, rendered_alpha = rasterizer(
            means3D = means3D,
            means2D = means2D,
            shs = shs,
            colors_precomp = colors_precomp,
            opacities = opacity,
            scales = scales,
            rotations = rotations,
            cov3D_precomp = cov3D_precomp)
    else:
        rendered_image, radii, rendered_depth, rendered_alpha = None, None, None, None

    # ################################################################
    # [Stage 1, Stage 2.1] Render image-level instance feature map   #
    #   - rendered_ins_feat: image-level feat map                    #
    # ################################################################
    # probabilistically rescale
    prob = torch.rand(1)
    rescale_factor = torch.tensor(1.0, dtype=torch.float32).cuda()
    if prob > 0.5 and rescale:
        rescale_factor = torch.rand(1).cuda()
    if render_feat_map:
        # get feature
        ins_feat = (pc.get_ins_feat(origin=origin_feat) + 1) / 2   # pseudo -> norm, else -> raw
        # first three channels
        rendered_ins_feat, _, _, _ = rasterizer(
            means3D = means3D,
            means2D = means2D,
            shs = None,
            colors_precomp = ins_feat[:, :3],   # render features as pre-computed colors
            opacities = opacity,
            scales = scales * rescale_factor,
rotations = rotations,
cov3D_precomp = cov3D_precomp)
# last three channels
if ins_feat.shape[-1] > 3:
rendered_ins_feat2, _, _, _ = rasterizer(
means3D = means3D,
means2D = means2D,
shs = None,
colors_precomp = ins_feat[:, 3:6], # render features as pre-computed colors
opacities = opacity,
scales = scales * rescale_factor,
rotations = rotations,
cov3D_precomp = cov3D_precomp)
rendered_ins_feat = torch.cat((rendered_ins_feat, rendered_ins_feat2), dim=0)
# mask
_, _, _, silhouette = rasterizer(
means3D = means3D,
means2D = means2D,
shs = shs,
colors_precomp = colors_precomp,
opacities = opacity,
scales = scales * rescale_factor,
# opacities = opacity*0+1.0, #
# scales = scales*0+0.001, # *0.1
rotations = rotations,
cov3D_precomp = cov3D_precomp)
else:
rendered_ins_feat, silhouette = None, None
# ########################################################################
# [Preprocessing for Stage 2.2]: render (coarse) cluster-level feat map #
# - rendered_clusters: feat map of the coarse clusters #
# - rendered_cluster_silhouettes: cluster mask #
# ########################################################################
viewed_pts = radii > 0 # ignore the invisible points
if cluster_idx is not None:
num_cluster = cluster_idx.max() + 1
cluster_occur = torch.zeros(num_cluster).to(torch.bool) # [num_cluster], bool
else:
cluster_occur = None
if render_cluster and cluster_idx is not None and viewed_pts.sum() != 0:
ins_feat = (pc.get_ins_feat(origin=origin_feat) + 1) / 2 # pseudo -> norm, else -> raw
rendered_clusters = []
rendered_cluster_silhouettes = []
scale_filter = (scales < 0.5).all(dim=1) # filter
for idx in range(num_cluster):
if not better_vis and idx != selected_root_id:
continue
# ignore the invisible coarse-level cluster
if viewpoint_camera.bClusterOccur is not None and viewpoint_camera.bClusterOccur[idx] == False:
continue
# NOTE: Render only the idx-th coarse cluster
filter_idx = cluster_idx == idx
filter_idx = filter_idx & viewed_pts
# todo: filter
if better_vis:
filter_idx = filter_idx & scale_filter
if filter_idx.sum() < 100:
continue
# render cluster-level feat map
rendered_cluster, _, _, cluster_silhouette = rasterizer(
means3D = means3D[filter_idx],
means2D = means2D[filter_idx],
shs = None, # feat
colors_precomp = ins_feat[:, :3][filter_idx], # feat
# shs = shs[filter_idx], # rgb
# colors_precomp = None, # rgb
opacities = opacity[filter_idx],
scales = scales[filter_idx] * rescale_factor,
rotations = rotations[filter_idx],
cov3D_precomp = cov3D_precomp)
if ins_feat.shape[-1] > 3:
rendered_cluster2, _, _, cluster_silhouette = rasterizer(
means3D = means3D[filter_idx],
means2D = means2D[filter_idx],
shs = None, # feat
colors_precomp = ins_feat[:, 3:][filter_idx], # feat
# shs = shs[filter_idx], # rgb
# colors_precomp = None, # rgb
opacities = opacity[filter_idx],
scales = scales[filter_idx] * rescale_factor,
rotations = rotations[filter_idx],
cov3D_precomp = cov3D_precomp)
rendered_cluster = torch.cat((rendered_cluster, rendered_cluster2), dim=0)
# alpha --> mask
if cluster_silhouette.max() > 0.8:
cluster_occur[idx] = True
rendered_clusters.append(rendered_cluster)
rendered_cluster_silhouettes.append(cluster_silhouette)
if len(rendered_cluster_silhouettes) != 0:
rendered_cluster_silhouettes = torch.vstack(rendered_cluster_silhouettes)
else:
rendered_clusters, rendered_cluster_silhouettes = None, None
# ###############################################################
# [Stage 2.2 & Stage 3] render (fine) cluster-level feat map #
# - rendered_leaf_clusters: feat map of the fine clusters #
# - rendered_leaf_cluster_silhouettes: fine cluster mask #
# - occured_leaf_id: visible fine cluster id #
# ###############################################################
if leaf_cluster_idx is not None and leaf_cluster_idx.numel() > 0:
ins_feat = (pc.get_ins_feat(origin=origin_feat) + 1) / 2 # pseudo -> norm, else -> raw
# todo: rescale
scale_filter = (scales < 0.1).all(dim=1)
# scale_filter = (scales < 0.1).all(dim=1) & (opacity > 0.1).squeeze(-1)
re_scale_factor = torch.ones_like(opacity) # not used
        # Determine the fine cluster ID range (lerf_range) based on the coarse cluster ID (selected_root_id).
# root_num = 64 # todo modify
# leaf_num = 5 # todo modify
rendered_leaf_clusters = []
rendered_leaf_cluster_silhouettes = []
occured_leaf_id = []
if selected_leaf_id is None:
if selected_root_id is not None:
start_leaf = selected_root_id * leaf_num # todo 10
end_leaf = start_leaf + leaf_num # todo 10
else:
start_leaf = 0
end_leaf = root_num * leaf_num # todo 64 * 10
lerf_range = range(start_leaf, end_leaf)
else:
lerf_range = selected_leaf_id.tolist()
for _, leaf_idx in enumerate(lerf_range): # render each fine cluster
# ignore the invisible clusters
            if selected_root_id is not None and viewpoint_camera.bClusterOccur is not None and viewpoint_camera.bClusterOccur[selected_root_id] == False:
continue
if selected_leaf_id is None:
filter_idx = leaf_cluster_idx == leaf_idx # Render only the idx-th fine cluster
# filter_idx = labels != value # remove the idx-th fine cluster
else:
filter_idx = (leaf_cluster_idx.unsqueeze(1) == selected_leaf_id).any(dim=1)
# pre-mask
if pre_mask is not None:
filter_idx = filter_idx & pre_mask
filter_idx = filter_idx & viewed_pts
# filter
if better_vis:
filter_idx = filter_idx & scale_filter
if filter_idx.sum() < 100:
continue
            # Post-processing (for 3D object selection): remove outlier points whose mean
            # squared distance to their K nearest neighbours exceeds the cluster-wide mean
            # by more than one standard deviation.
            if post_process:
                nearest_k_distance = pytorch3d.ops.knn_points(
                    means3D[filter_idx].unsqueeze(0),
                    means3D[filter_idx].unsqueeze(0),
                    K=int(filter_idx.sum()**0.5),
                ).dists  # [1, N, K], squared distances
                mean_nearest_k_distance, std_nearest_k_distance = nearest_k_distance.mean(), nearest_k_distance.std()
                mask = nearest_k_distance.mean(dim = -1) < mean_nearest_k_distance + std_nearest_k_distance
                # mask = nearest_k_distance.mean(dim = -1) < mean_nearest_k_distance + 0.1 * std_nearest_k_distance
                mask = mask.squeeze()
                filter_idx[filter_idx != 0] = mask
if filter_idx.sum() < 10:
continue
            # Record the fine cluster ids that appear in the current view.
occured_leaf_id.append(leaf_idx)
# note: render cluster rgb or feat.
if seg_rgb:
_shs = shs[filter_idx]
_colors_precomp1 = None
_colors_precomp2 = None
else:
_shs = None
_colors_precomp1 = ins_feat[:, :3][filter_idx]
_colors_precomp2 = ins_feat[:, 3:][filter_idx]
rendered_leaf_cluster, _, _, leaf_cluster_silhouette = rasterizer(
means3D = means3D[filter_idx],
means2D = means2D[filter_idx],
shs = _shs, # rgb or feat
colors_precomp = _colors_precomp1, # rgb or feat
opacities = opacity[filter_idx],
scales = (scales * re_scale_factor)[filter_idx],
rotations = rotations[filter_idx],
cov3D_precomp = cov3D_precomp)
if ins_feat.shape[-1] > 3:
rendered_leaf_cluster2, _, _, _ = rasterizer(
means3D = means3D[filter_idx],
means2D = means2D[filter_idx],
shs = _shs, # rgb or feat
colors_precomp = _colors_precomp2, # rgb or feat
opacities = opacity[filter_idx],
scales = (scales * re_scale_factor)[filter_idx],
rotations = rotations[filter_idx],
cov3D_precomp = cov3D_precomp)
rendered_leaf_cluster = torch.cat((rendered_leaf_cluster, rendered_leaf_cluster2), dim=0)
rendered_leaf_clusters.append(rendered_leaf_cluster)
rendered_leaf_cluster_silhouettes.append(leaf_cluster_silhouette)
if selected_leaf_id is not None and len(rendered_leaf_clusters) > 0:
break
if len(rendered_leaf_cluster_silhouettes) != 0:
rendered_leaf_cluster_silhouettes = torch.vstack(rendered_leaf_cluster_silhouettes)
else:
rendered_leaf_clusters = None
rendered_leaf_cluster_silhouettes = None
occured_leaf_id = None
# Those Gaussians that were frustum culled or had a radius of 0 were not visible.
# They will be excluded from value updates used in the splitting criteria.
return {"render": rendered_image,
"alpha": rendered_alpha,
"depth": rendered_depth, # not used
"silhouette": silhouette,
"ins_feat": rendered_ins_feat, # image-level feat map
"cluster_imgs": rendered_clusters, # coarse cluster feat map/image
"cluster_silhouettes": rendered_cluster_silhouettes, # coarse cluster mask
"leaf_clusters_imgs": rendered_leaf_clusters, # fine cluster feat map/image
"leaf_cluster_silhouettes": rendered_leaf_cluster_silhouettes, # fine cluster mask
"occured_leaf_id": occured_leaf_id, # fine cluster
"cluster_occur": cluster_occur, # coarse cluster
"viewspace_points": screenspace_points,
"visibility_filter" : radii > 0,
"radii": radii}
================================================
FILE: gaussian_renderer/network_gui.py
================================================
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact george.drettakis@inria.fr
#
import torch
import traceback
import socket
import json
from scene.cameras import MiniCam
host = "127.0.0.1"
port = 6009
conn = None
addr = None
listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
def init(wish_host, wish_port):
global host, port, listener
host = wish_host
port = wish_port
listener.bind((host, port))
listener.listen()
listener.settimeout(0)
def try_connect():
global conn, addr, listener
try:
conn, addr = listener.accept()
print(f"\nConnected by {addr}")
conn.settimeout(None)
except Exception as inst:
pass
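def read():
    global conn
    messageLength = conn.recv(4)
    messageLength = int.from_bytes(messageLength, 'little')
    # recv() may return fewer bytes than requested, so loop until the full message arrives
    message = b""
    while len(message) < messageLength:
        chunk = conn.recv(messageLength - len(message))
        if not chunk:
            raise ConnectionError("connection closed while reading message")
        message += chunk
    return json.loads(message.decode("utf-8"))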
def send(message_bytes, verify):
global conn
    if message_bytes is not None:
conn.sendall(message_bytes)
conn.sendall(len(verify).to_bytes(4, 'little'))
conn.sendall(bytes(verify, 'ascii'))
def receive():
message = read()
width = message["resolution_x"]
height = message["resolution_y"]
if width != 0 and height != 0:
try:
do_training = bool(message["train"])
fovy = message["fov_y"]
fovx = message["fov_x"]
znear = message["z_near"]
zfar = message["z_far"]
do_shs_python = bool(message["shs_python"])
do_rot_scale_python = bool(message["rot_scale_python"])
keep_alive = bool(message["keep_alive"])
scaling_modifier = message["scaling_modifier"]
world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda()
world_view_transform[:,1] = -world_view_transform[:,1]
world_view_transform[:,2] = -world_view_transform[:,2]
full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda()
full_proj_transform[:,1] = -full_proj_transform[:,1]
custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform)
except Exception as e:
print("")
traceback.print_exc()
raise e
return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier
else:
return None, None, None, None, None, None
================================================
FILE: lpipsPyTorch/__init__.py
================================================
import torch
from .modules.lpips import LPIPS
def lpips(x: torch.Tensor,
y: torch.Tensor,
net_type: str = 'alex',
version: str = '0.1'):
r"""Function that measures
Learned Perceptual Image Patch Similarity (LPIPS).
Arguments:
x, y (torch.Tensor): the input tensors to compare.
net_type (str): the network type to compare the features:
'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
version (str): the version of LPIPS. Default: 0.1.
"""
device = x.device
criterion = LPIPS(net_type, version).to(device)
return criterion(x, y)
================================================
FILE: lpipsPyTorch/modules/lpips.py
================================================
import torch
import torch.nn as nn
from .networks import get_network, LinLayers
from .utils import get_state_dict
class LPIPS(nn.Module):
r"""Creates a criterion that measures
Learned Perceptual Image Patch Similarity (LPIPS).
Arguments:
net_type (str): the network type to compare the features:
'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
version (str): the version of LPIPS. Default: 0.1.
"""
def __init__(self, net_type: str = 'alex', version: str = '0.1'):
        assert version in ['0.1'], 'only v0.1 is supported now'
super(LPIPS, self).__init__()
# pretrained network
self.net = get_network(net_type)
# linear layers
self.lin = LinLayers(self.net.n_channels_list)
self.lin.load_state_dict(get_state_dict(net_type, version))
def forward(self, x: torch.Tensor, y: torch.Tensor):
feat_x, feat_y = self.net(x), self.net(y)
diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]
return torch.sum(torch.cat(res, 0), 0, True)
================================================
FILE: lpipsPyTorch/modules/networks.py
================================================
from typing import Sequence
from itertools import chain
import torch
import torch.nn as nn
from torchvision import models
from .utils import normalize_activation
def get_network(net_type: str):
if net_type == 'alex':
return AlexNet()
elif net_type == 'squeeze':
return SqueezeNet()
elif net_type == 'vgg':
return VGG16()
else:
raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')
class LinLayers(nn.ModuleList):
def __init__(self, n_channels_list: Sequence[int]):
super(LinLayers, self).__init__([
nn.Sequential(
nn.Identity(),
nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
) for nc in n_channels_list
])
for param in self.parameters():
param.requires_grad = False
class BaseNet(nn.Module):
def __init__(self):
super(BaseNet, self).__init__()
# register buffer
self.register_buffer(
'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
self.register_buffer(
'std', torch.Tensor([.458, .448, .450])[None, :, None, None])
def set_requires_grad(self, state: bool):
for param in chain(self.parameters(), self.buffers()):
param.requires_grad = state
def z_score(self, x: torch.Tensor):
return (x - self.mean) / self.std
def forward(self, x: torch.Tensor):
x = self.z_score(x)
output = []
for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
x = layer(x)
if i in self.target_layers:
output.append(normalize_activation(x))
if len(output) == len(self.target_layers):
break
return output
class SqueezeNet(BaseNet):
def __init__(self):
super(SqueezeNet, self).__init__()
        self.layers = models.squeezenet1_1(weights=models.SqueezeNet1_1_Weights.IMAGENET1K_V1).features
self.target_layers = [2, 5, 8, 10, 11, 12, 13]
self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]
self.set_requires_grad(False)
class AlexNet(BaseNet):
def __init__(self):
super(AlexNet, self).__init__()
        self.layers = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1).features
self.target_layers = [2, 5, 8, 10, 12]
self.n_channels_list = [64, 192, 384, 256, 256]
self.set_requires_grad(False)
class VGG16(BaseNet):
def __init__(self):
super(VGG16, self).__init__()
self.layers = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features
self.target_layers = [4, 9, 16, 23, 30]
self.n_channels_list = [64, 128, 256, 512, 512]
self.set_requires_grad(False)
================================================
FILE: lpipsPyTorch/modules/utils.py
================================================
from collections import OrderedDict
import torch
def normalize_activation(x, eps=1e-10):
norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
return x / (norm_factor + eps)
def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
# build url
url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
+ f'master/lpips/weights/v{version}/{net_type}.pth'
# download
old_state_dict = torch.hub.load_state_dict_from_url(
url, progress=True,
map_location=None if torch.cuda.is_available() else torch.device('cpu')
)
# rename keys
new_state_dict = OrderedDict()
for key, val in old_state_dict.items():
new_key = key
new_key = new_key.replace('lin', '')
new_key = new_key.replace('model.', '')
new_state_dict[new_key] = val
return new_state_dict
================================================
FILE: metrics.py
================================================
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact george.drettakis@inria.fr
#
from pathlib import Path
import os
from PIL import Image
import torch
import torchvision.transforms.functional as tf
from utils.loss_utils import ssim
from lpipsPyTorch import lpips
import json
from tqdm import tqdm
from utils.image_utils import psnr
from argparse import ArgumentParser
def readImages(renders_dir, gt_dir):
renders = []
gts = []
image_names = []
for fname in os.listdir(renders_dir):
render = Image.open(renders_dir / fname)
gt = Image.open(gt_dir / fname)
renders.append(tf.to_tensor(render).unsqueeze(0)[:, :3, :, :].cuda())
gts.append(tf.to_tensor(gt).unsqueeze(0)[:, :3, :, :].cuda())
image_names.append(fname)
return renders, gts, image_names
def evaluate(model_paths):
full_dict = {}
per_view_dict = {}
full_dict_polytopeonly = {}
per_view_dict_polytopeonly = {}
print("")
for scene_dir in model_paths:
try:
print("Scene:", scene_dir)
full_dict[scene_dir] = {}
per_view_dict[scene_dir] = {}
full_dict_polytopeonly[scene_dir] = {}
per_view_dict_polytopeonly[scene_dir] = {}
test_dir = Path(scene_dir) / "test"
for method in os.listdir(test_dir):
print("Method:", method)
full_dict[scene_dir][method] = {}
per_view_dict[scene_dir][method] = {}
full_dict_polytopeonly[scene_dir][method] = {}
per_view_dict_polytopeonly[scene_dir][method] = {}
method_dir = test_dir / method
gt_dir = method_dir/ "gt"
renders_dir = method_dir / "renders"
renders, gts, image_names = readImages(renders_dir, gt_dir)
ssims = []
psnrs = []
lpipss = []
for idx in tqdm(range(len(renders)), desc="Metric evaluation progress"):
ssims.append(ssim(renders[idx], gts[idx]))
psnrs.append(psnr(renders[idx], gts[idx]))
lpipss.append(lpips(renders[idx], gts[idx], net_type='vgg'))
                print(" SSIM : {:>12.7f}".format(torch.tensor(ssims).mean()))
                print(" PSNR : {:>12.7f}".format(torch.tensor(psnrs).mean()))
                print(" LPIPS: {:>12.7f}".format(torch.tensor(lpipss).mean()))
print("")
full_dict[scene_dir][method].update({"SSIM": torch.tensor(ssims).mean().item(),
"PSNR": torch.tensor(psnrs).mean().item(),
"LPIPS": torch.tensor(lpipss).mean().item()})
per_view_dict[scene_dir][method].update({"SSIM": {name: ssim for ssim, name in zip(torch.tensor(ssims).tolist(), image_names)},
"PSNR": {name: psnr for psnr, name in zip(torch.tensor(psnrs).tolist(), image_names)},
"LPIPS": {name: lp for lp, name in zip(torch.tensor(lpipss).tolist(), image_names)}})
with open(scene_dir + "/results.json", 'w') as fp:
json.dump(full_dict[scene_dir], fp, indent=True)
with open(scene_dir + "/per_view.json", 'w') as fp:
json.dump(per_view_dict[scene_dir], fp, indent=True)
        except Exception as e:
            print("Unable to compute metrics for model", scene_dir, "-", e)
if __name__ == "__main__":
device = torch.device("cuda:0")
torch.cuda.set_device(device)
# Set up command line argument parser
    parser = ArgumentParser(description="Metric evaluation script parameters")
parser.add_argument('--model_paths', '-m', required=True, nargs="+", type=str, default=[])
args = parser.parse_args()
evaluate(args.model_paths)
================================================
FILE: render.py
================================================
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact george.drettakis@inria.fr
#
import torch
import torch.nn.functional as F
from scene import Scene
import os
from tqdm import tqdm
from os import makedirs
from gaussian_renderer import render
import torchvision
from utils.general_utils import safe_state
from argparse import ArgumentParser
from arguments import ModelParams, PipelineParams, get_combined_args
from gaussian_renderer import GaussianModel
import numpy as np
from utils.opengs_utlis import get_SAM_mask_and_feat, load_code_book
# Randomly initialize 300 colors for visualizing the SAM mask. [OpenGaussian]
np.random.seed(42)
colors_defined = np.random.randint(100, 256, size=(300, 3))
colors_defined[0] = np.array([0, 0, 0]) # Ignore the mask ID of -1 and set it to black.
colors_defined = torch.from_numpy(colors_defined)
def render_set(model_path, name, iteration, views, gaussians, pipeline, background):
render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders")
gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt")
render_ins_feat_path1 = os.path.join(model_path, name, "ours_{}".format(iteration), "renders_ins_feat1")
render_ins_feat_path2 = os.path.join(model_path, name, "ours_{}".format(iteration), "renders_ins_feat2")
gt_sam_mask_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt_sam_mask")
makedirs(render_path, exist_ok=True)
makedirs(gts_path, exist_ok=True)
makedirs(render_ins_feat_path1, exist_ok=True)
makedirs(render_ins_feat_path2, exist_ok=True)
makedirs(gt_sam_mask_path, exist_ok=True)
# load codebook
root_code_book_path = os.path.join(model_path, "point_cloud", f'iteration_{iteration}', "root_code_book")
leaf_code_book_path = os.path.join(model_path, "point_cloud", f'iteration_{iteration}', "leaf_code_book")
if os.path.exists(os.path.join(root_code_book_path, 'kmeans_inds.bin')):
root_code_book, root_cluster_indices = load_code_book(root_code_book_path)
root_cluster_indices = torch.from_numpy(root_cluster_indices).cuda()
if os.path.exists(os.path.join(leaf_code_book_path, 'kmeans_inds.bin')):
leaf_code_book, leaf_cluster_indices = load_code_book(leaf_code_book_path)
leaf_cluster_indices = torch.from_numpy(leaf_cluster_indices).cuda()
else:
leaf_cluster_indices = None
# render
for idx, view in enumerate(tqdm(views, desc="Rendering progress")):
render_pkg = render(view, gaussians, pipeline, background, iteration, rescale=False)
# RGB
rendering = render_pkg["render"]
gt = view.original_image[0:3, :, :]
# ins_feat
rendered_ins_feat = render_pkg["ins_feat"]
gt_sam_mask = view.original_sam_mask.cuda() # [4, H, W]
# Rendered RGB
torchvision.utils.save_image(rendering, os.path.join(render_path, view.image_name + ".png"))
# GT RGB
torchvision.utils.save_image(gt, os.path.join(gts_path, view.image_name + ".png"))
# ins_feat
torchvision.utils.save_image(rendered_ins_feat[:3,:,:], os.path.join(render_ins_feat_path1, view.image_name + "_1.png"))
torchvision.utils.save_image(rendered_ins_feat[3:6,:,:], os.path.join(render_ins_feat_path2, view.image_name + "_2.png"))
# NOTE get SAM id, mask bool, mask_feat, invalid pix
mask_id, _, _, _ = \
get_SAM_mask_and_feat(gt_sam_mask, level=0, original_mask_feat=view.original_mask_feat)
# mask visualization
mask_color_rand = colors_defined[mask_id.detach().cpu().type(torch.int64)].type(torch.float64)
mask_color_rand = mask_color_rand.permute(2, 0, 1)
torchvision.utils.save_image(mask_color_rand/255.0, os.path.join(gt_sam_mask_path, view.image_name + ".png"))
def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool):
with torch.no_grad():
gaussians = GaussianModel(dataset.sh_degree)
scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)
bg_color = [1,1,1] if dataset.white_background else [0, 0, 0]
background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
if not skip_train:
render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background)
if not skip_test:
render_set(dataset.model_path, "test", scene.loaded_iter, scene.getTestCameras(), gaussians, pipeline, background)
if __name__ == "__main__":
# Set up command line argument parser
parser = ArgumentParser(description="Testing script parameters")
model = ModelParams(parser, sentinel=True)
pipeline = PipelineParams(parser)
parser.add_argument("--iteration", default=-1, type=int)
parser.add_argument("--skip_train", action="store_true")
parser.add_argument("--skip_test", action="store_true")
parser.add_argument("--quiet", action="store_true")
args = get_combined_args(parser)
print("Rendering " + args.model_path)
# Initialize system state (RNG)
safe_state(args.quiet)
render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test)
================================================
FILE: render_lerf_by_text.py
================================================
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact george.drettakis@inria.fr
#
import torch
import torch.nn.functional as F
from scene import Scene
import os
from tqdm import tqdm
from os import makedirs
from gaussian_renderer import render
import torchvision
from utils.general_utils import safe_state
from argparse import ArgumentParser
from arguments import ModelParams, PipelineParams, get_combined_args
from gaussian_renderer import GaussianModel
import numpy as np
import json
from utils.opengs_utlis import mask_feature_mean, get_SAM_mask_and_feat, load_code_book
np.random.seed(42)
colors_defined = np.random.randint(100, 256, size=(300, 3))
colors_defined[0] = np.array([0, 0, 0]) # Ignore the mask ID of -1 and set it to black.
colors_defined = torch.from_numpy(colors_defined)
def render_set(model_path, name, iteration, views, gaussians, pipeline, background, scene_name):
render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders")
gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt")
render_ins_feat_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders_ins_feat")
gt_sam_mask_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt_sam_mask")
makedirs(render_path, exist_ok=True)
makedirs(gts_path, exist_ok=True)
makedirs(render_ins_feat_path, exist_ok=True)
makedirs(gt_sam_mask_path, exist_ok=True)
# load codebook
root_code_book, root_cluster_indices = load_code_book(os.path.join(model_path, "point_cloud", \
f'iteration_{iteration}', "root_code_book"))
leaf_code_book, leaf_cluster_indices = load_code_book(os.path.join(model_path, "point_cloud", \
f'iteration_{iteration}', "leaf_code_book"))
root_cluster_indices = torch.from_numpy(root_cluster_indices).cuda()
leaf_cluster_indices = torch.from_numpy(leaf_cluster_indices).cuda()
# counts = torch.bincount(torch.from_numpy(cluster_indices), minlength=64)
    # load the saved codebook (leaf id) and instance-level language features
    # keys: 'leaf_feat', 'leaf_score', 'occu_count', 'leaf_ind'
mapping_file = os.path.join(model_path, "cluster_lang.npz")
saved_data = np.load(mapping_file)
leaf_lang_feat = torch.from_numpy(saved_data["leaf_feat.npy"]).cuda() # [num_leaf=k1*k2, 512] cluster lang feat
leaf_score = torch.from_numpy(saved_data["leaf_score.npy"]).cuda() # [num_leaf=k1*k2] cluster score
leaf_occu_count = torch.from_numpy(saved_data["occu_count.npy"]).cuda() # [num_leaf=k1*k2]
leaf_ind = torch.from_numpy(saved_data["leaf_ind.npy"]).cuda() # [num_pts] fine id
leaf_lang_feat[leaf_occu_count < 5] *= 0.0 # Filter out clusters that occur too infrequently.
leaf_cluster_indices = leaf_ind
root_num = root_cluster_indices.max() + 1
    leaf_num = leaf_lang_feat.shape[0] // root_num   # integer number of fine clusters per coarse cluster
# text feature
with open('assets/text_features.json', 'r') as f:
data_loaded = json.load(f)
all_texts = list(data_loaded.keys())
text_features = torch.from_numpy(np.array(list(data_loaded.values()))).to(torch.float32) # [num_text, 512]
scene_texts = {
"waldo_kitchen": ['Stainless steel pots', 'dark cup', 'refrigerator', 'frog cup', 'pot', 'spatula', 'plate', \
'spoon', 'toaster', 'ottolenghi', 'plastic ladle', 'sink', 'ketchup', 'cabinet', 'red cup', \
'pour-over vessel', 'knife', 'yellow desk'],
"ramen": ['nori', 'sake cup', 'kamaboko', 'corn', 'spoon', 'egg', 'onion segments', 'plate', \
'napkin', 'bowl', 'glass of water', 'hand', 'chopsticks', 'wavy noodles'],
"figurines": ['jake', 'pirate hat', 'pikachu', 'rubber duck with hat', 'porcelain hand', \
'red apple', 'tesla door handle', 'waldo', 'bag', 'toy cat statue', 'miffy', \
'green apple', 'pumpkin', 'rubics cube', 'old camera', 'rubber duck with buoy', \
'red toy chair', 'pink ice cream', 'spatula', 'green toy chair', 'toy elephant'],
"teatime": ['sheep', 'yellow pouf', 'stuffed bear', 'coffee mug', 'tea in a glass', 'apple',
'coffee', 'hooves', 'bear nose', 'dall-e brand', 'plate', 'paper napkin', 'three cookies', \
'bag of cookies']
}
# note: query text
target_text = scene_texts[scene_name]
query_text_feats = torch.zeros(len(target_text), 512).cuda()
for i, text in enumerate(target_text):
feat = text_features[all_texts.index(text)].unsqueeze(0)
query_text_feats[i] = feat
for t_i, text_feat in enumerate(query_text_feats):
# if target_text[t_i] != "old camera":
# continue
print(f"rendering the {t_i+1}-th query of {len(target_text)} texts: {target_text[t_i]}")
# compute cosine similarity
text_feat = F.normalize(text_feat.unsqueeze(0), dim=1, p=2)
leaf_lang_feat = F.normalize(leaf_lang_feat, dim=1, p=2)
cosine_similarity = torch.matmul(text_feat, leaf_lang_feat.transpose(0, 1))
        max_id = torch.argmax(cosine_similarity, dim=-1)    # [1], id of the best-matching leaf cluster
text_leaf_indices = max_id
top_values, top_indices = torch.topk(cosine_similarity, 10)
for candidate_id in top_indices[0][1:]:
if candidate_id - max_id < leaf_num: # TODO !!!!!!
max_feat = leaf_code_book['ins_feat'][max_id]
candi_feat = leaf_code_book['ins_feat'][candidate_id]
distances = torch.norm(max_feat - candi_feat, dim=1)
if distances < 0.9:
text_leaf_indices = torch.cat([text_leaf_indices, candidate_id.unsqueeze(0)])
# render
for idx, view in enumerate(tqdm(views, desc="Rendering progress")):
# note: evaluation frame
scene_gt_frames = {
"waldo_kitchen": ["frame_00053", "frame_00066", "frame_00089", "frame_00140", "frame_00154"],
"ramen": ["frame_00006", "frame_00024", "frame_00060", "frame_00065", "frame_00081", "frame_00119", "frame_00128"],
"figurines": ["frame_00041", "frame_00105", "frame_00152", "frame_00195"],
"teatime": ["frame_00002", "frame_00025", "frame_00043", "frame_00107", "frame_00129", "frame_00140"]
}
candidate_frames = scene_gt_frames[scene_name]
if view.image_name not in candidate_frames:
continue
render_pkg = render(view, gaussians, pipeline, background, iteration, rescale=False)
# RGB
rendering = render_pkg["render"]
gt = view.original_image[0:3, :, :]
# ins_feat
rendered_ins_feat = render_pkg["ins_feat"]
gt_sam_mask = view.original_sam_mask.cuda() # [4, H, W]
# RGB
torchvision.utils.save_image(rendering, os.path.join(render_path, '{0:05d}'.format(idx) + ".png"))
torchvision.utils.save_image(gt, os.path.join(gts_path, '{0:05d}'.format(idx) + ".png"))
# ins_feat
torchvision.utils.save_image(rendered_ins_feat[:3,:,:], os.path.join(render_ins_feat_path, '{0:05d}'.format(idx) + "_1.png"))
torchvision.utils.save_image(rendered_ins_feat[3:6,:,:], os.path.join(render_ins_feat_path, '{0:05d}'.format(idx) + "_2.png"))
# NOTE get SAM id, mask bool, mask_feat, invalid pix
mask_id, mask_bool, mask_feat, invalid_pix = \
get_SAM_mask_and_feat(gt_sam_mask, level=3, original_mask_feat=view.original_mask_feat)
# sam mask
mask_color_rand = colors_defined[mask_id.detach().cpu().type(torch.int64)].type(torch.float64)
mask_color_rand = mask_color_rand.permute(2, 0, 1)
torchvision.utils.save_image(mask_color_rand/255.0, os.path.join(gt_sam_mask_path, '{0:05d}'.format(idx) + ".png"))
# render target object
render_pkg = render(view, gaussians, pipeline, background, iteration,
rescale=False, # whether to re-scale the Gaussian scale
# cluster_idx=leaf_cluster_indices, # root id
leaf_cluster_idx=leaf_cluster_indices, # leaf id
selected_leaf_id=text_leaf_indices.cuda(), # leaf ids selected by the text query
render_feat_map=False,
render_cluster=False,
better_vis=True,
seg_rgb=True,
post_process=True,
root_num=root_num, leaf_num=leaf_num)
rendered_cluster_imgs = render_pkg["leaf_clusters_imgs"]
occured_leaf_id = render_pkg["occured_leaf_id"]
rendered_leaf_cluster_silhouettes = render_pkg["leaf_cluster_silhouettes"]
render_cluster_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders_cluster")
render_cluster_silhouette_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders_cluster_silhouette")
makedirs(render_cluster_path, exist_ok=True)
makedirs(render_cluster_silhouette_path, exist_ok=True)
for i, img in enumerate(rendered_cluster_imgs):
# save object RGB
torchvision.utils.save_image(img[:3,:,:], os.path.join(render_cluster_path, \
view.image_name + f"_{target_text[t_i]}.png"))
# save object mask
cluster_silhouette = rendered_leaf_cluster_silhouettes[i] > 0.7
torchvision.utils.save_image(cluster_silhouette.to(torch.float32), os.path.join(render_cluster_silhouette_path, \
view.image_name + f"_{target_text[t_i]}.png"))
def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool,
scene_name: str):
with torch.no_grad():
gaussians = GaussianModel(dataset.sh_degree)
scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)
# bg_color = [1,1,1] if dataset.white_background else [0, 0, 0]
bg_color = [1,1,1]
background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
if not skip_train:
render_set(dataset.model_path, "text2obj", scene.loaded_iter, scene.getTrainCameras(),
gaussians, pipeline, background, scene_name)
if not skip_test:
render_set(dataset.model_path, "text2obj", scene.loaded_iter, scene.getTestCameras(),
gaussians, pipeline, background, scene_name)
if __name__ == "__main__":
# Set up command line argument parser
parser = ArgumentParser(description="Testing script parameters")
model = ModelParams(parser, sentinel=True)
pipeline = PipelineParams(parser)
parser.add_argument("--iteration", default=-1, type=int)
parser.add_argument("--skip_train", action="store_true")
parser.add_argument("--skip_test", action="store_true")
parser.add_argument("--quiet", action="store_true")
parser.add_argument("--scene_name", type=str, choices=["waldo_kitchen", "ramen", "figurines", "teatime"],
help="Specify the scene_name from: figurines, teatime, ramen, waldo_kitchen")
args = get_combined_args(parser)
print("Rendering " + args.model_path)
if not args.scene_name:
parser.error("The --scene_name argument is required and must be one of: waldo_kitchen, ramen, figurines, teatime")
# Initialize system state (RNG)
safe_state(args.quiet)
render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test, args.scene_name)
================================================
FILE: scene/__init__.py
================================================
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact george.drettakis@inria.fr
#
import os
import random
import json
from utils.system_utils import searchForMaxIteration
from scene.dataset_readers import sceneLoadTypeCallbacks
from scene.gaussian_model import GaussianModel
from arguments import ModelParams
from utils.camera_utils import cameraList_from_camInfos, camera_to_JSON
class Scene:
gaussians : GaussianModel
def __init__(self, args : ModelParams, gaussians : GaussianModel, load_iteration=None, shuffle=True, resolution_scales=[1.0]):
"""
:param args: ModelParams; args.source_path is the path to the scene's main folder (COLMAP or Blender layout).
"""
self.model_path = args.model_path
self.loaded_iter = None
self.gaussians = gaussians
if load_iteration:
if load_iteration == -1:
self.loaded_iter = searchForMaxIteration(os.path.join(self.model_path, "point_cloud"))
else:
self.loaded_iter = load_iteration
print("Loading trained model at iteration {}".format(self.loaded_iter))
self.train_cameras = {}
self.test_cameras = {}
if os.path.exists(os.path.join(args.source_path, "sparse")):
scene_info = sceneLoadTypeCallbacks["Colmap"](args.source_path, args.images, args.eval)
elif os.path.exists(os.path.join(args.source_path, "transforms_train.json")):
print("Found transforms_train.json file, assuming Blender data set!")
scene_info = sceneLoadTypeCallbacks["Blender"](args.source_path, args.white_background, args.eval)
else:
assert False, "Could not recognize scene type!"
if not self.loaded_iter:
with open(scene_info.ply_path, 'rb') as src_file, open(os.path.join(self.model_path, "input.ply") , 'wb') as dest_file:
dest_file.write(src_file.read())
json_cams = []
camlist = []
if scene_info.test_cameras:
camlist.extend(scene_info.test_cameras)
if scene_info.train_cameras:
camlist.extend(scene_info.train_cameras)
for id, cam in enumerate(camlist):
json_cams.append(camera_to_JSON(id, cam))
with open(os.path.join(self.model_path, "cameras.json"), 'w') as file:
json.dump(json_cams, file)
if shuffle:
random.shuffle(scene_info.train_cameras) # Multi-res consistent random shuffling
random.shuffle(scene_info.test_cameras) # Multi-res consistent random shuffling
self.cameras_extent = scene_info.nerf_normalization["radius"]
for resolution_scale in resolution_scales:
print("Resolution: ", resolution_scale)
print("Loading Training Cameras")
self.train_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.train_cameras, resolution_scale, args)
print("Loading Test Cameras")
self.test_cameras[resolution_scale] = cameraList_from_camInfos(scene_info.test_cameras, resolution_scale, args)
if self.loaded_iter:
self.gaussians.load_ply(os.path.join(self.model_path,
"point_cloud",
"iteration_" + str(self.loaded_iter),
"point_cloud.ply"))
else:
self.gaussians.create_from_pcd(scene_info.point_cloud, self.cameras_extent)
def save(self, iteration, save_q=[]):
point_cloud_path = os.path.join(self.model_path, "point_cloud/iteration_{}".format(iteration))
self.gaussians.save_ply(os.path.join(point_cloud_path, "point_cloud.ply"), save_q)
def getTrainCameras(self, scale=1.0):
return self.train_cameras[scale]
def getTestCameras(self, scale=1.0):
return self.test_cameras[scale]
================================================
FILE: scene/cameras.py
================================================
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact george.drettakis@inria.fr
#
import torch
from torch import nn
import numpy as np
from utils.graphics_utils import getWorld2View2, getProjectionMatrix
class Camera(nn.Module):
def __init__(self, colmap_id, R, T, FoVx, FoVy, cx, cy, image, depth, gt_alpha_mask,
gt_sam_mask, gt_mask_feat,
image_name, uid,
trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda"
):
super(Camera, self).__init__()
self.uid = uid
self.colmap_id = colmap_id
self.R = R
self.T = T
self.FoVx = FoVx
self.FoVy = FoVy
# modify -----
self.cx = cx
self.cy = cy
# modify -----
self.image_name = image_name
try:
self.data_device = torch.device(data_device)
except Exception as e:
print(e)
print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" )
self.data_device = torch.device("cuda")
self.data_on_gpu = True # note
self.original_image = image.clamp(0.0, 1.0).to(self.data_device)
# modify -----
self.original_mask = gt_alpha_mask.to(self.data_device) if gt_alpha_mask is not None else None
# modify -----
self.original_sam_mask = gt_sam_mask.to(self.data_device) if gt_sam_mask is not None else None
self.original_mask_feat = gt_mask_feat.to(self.data_device) if gt_mask_feat is not None else None
self.pesudo_ins_feat = None
self.pesudo_mask_bool = None
self.cluster_masks = None
self.bClusterOccur = None
self.image_width = self.original_image.shape[2]
self.image_height = self.original_image.shape[1]
if gt_alpha_mask is not None:
self.original_image *= gt_alpha_mask.to(self.data_device)
else:
self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device)
self.zfar = 100.0
self.znear = 0.01
self.trans = trans
self.scale = scale
self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda()
self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda()
self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0)
self.camera_center = self.world_view_transform.inverse()[3, :3]
# modify -----
def to_gpu(self):
for attr_name in dir(self):
attr = getattr(self, attr_name)
if isinstance(attr, torch.Tensor) and not attr.is_cuda:
setattr(self, attr_name, attr.to('cuda'))
self.data_on_gpu = True
# modify -----
def to_cpu(self):
for attr_name in dir(self):
attr = getattr(self, attr_name)
if isinstance(attr, torch.Tensor) and attr.is_cuda:
setattr(self, attr_name, attr.to('cpu'))
self.data_on_gpu = False
class MiniCam:
def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform):
self.image_width = width
self.image_height = height
self.FoVy = fovy
self.FoVx = fovx
self.znear = znear
self.zfar = zfar
self.world_view_transform = world_view_transform
self.full_proj_transform = full_proj_transform
view_inv = torch.inverse(self.world_view_transform)
self.camera_center = view_inv[3][:3]
================================================
FILE: scene/colmap_loader.py
================================================
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact george.drettakis@inria.fr
#
import numpy as np
import collections
import struct
CameraModel = collections.namedtuple(
"CameraModel", ["model_id", "model_name", "num_params"])
Camera = collections.namedtuple(
"Camera", ["id", "model", "width", "height", "params"])
BaseImage = collections.namedtuple(
"Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"])
Point3D = collections.namedtuple(
"Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"])
CAMERA_MODELS = {
CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
CameraModel(model_id=3, model_name="RADIAL", num_params=5),
CameraModel(model_id=4, model_name="OPENCV", num_params=8),
CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
CameraModel(model_id=7, model_name="FOV", num_params=5),
CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12)
}
CAMERA_MODEL_IDS = dict([(camera_model.model_id, camera_model)
for camera_model in CAMERA_MODELS])
CAMERA_MODEL_NAMES = dict([(camera_model.model_name, camera_model)
for camera_model in CAMERA_MODELS])
def qvec2rotmat(qvec):
return np.array([
[1 - 2 * qvec[2]**2 - 2 * qvec[3]**2,
2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]],
[2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,
2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]],
[2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
1 - 2 * qvec[1]**2 - 2 * qvec[2]**2]])
def rotmat2qvec(R):
Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
K = np.array([
[Rxx - Ryy - Rzz, 0, 0, 0],
[Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
[Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
[Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz]]) / 3.0
eigvals, eigvecs = np.linalg.eigh(K)
qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
if qvec[0] < 0:
qvec *= -1
return qvec
class Image(BaseImage):
def qvec2rotmat(self):
return qvec2rotmat(self.qvec)
def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
"""Read and unpack the next bytes from a binary file.
:param fid:
:param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
:param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
:param endian_character: Any of {@, =, <, >, !}
:return: Tuple of read and unpacked values.
"""
data = fid.read(num_bytes)
return struct.unpack(endian_character + format_char_sequence, data)
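# Illustrative sketch (not part of the original repository): how the format
# strings passed to read_next_bytes map onto byte counts. The helper name
# `_sketch_unpack_point3d` and any sample values are hypothetical; only the
# standard-library `struct` and `io` modules are used. The fixed-size part of
# a points3D.bin record is read with "QdddBBBd" (id, xyz, rgb, error), which
# is 43 bytes when unpacked little-endian without padding. This is why
# read_points3D_binary passes num_bytes=43.
import io
import struct


def _sketch_unpack_point3d(buffer):
    """Unpack one fixed-size point record from a binary file-like object."""
    fmt = "<QdddBBBd"
    size = struct.calcsize(fmt)  # 8 + 3*8 + 3*1 + 8 = 43 bytes
    point_id, x, y, z, r, g, b, error = struct.unpack(fmt, buffer.read(size))
    return point_id, (x, y, z), (r, g, b), error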
def read_points3D_text(path):
"""
see: src/base/reconstruction.cc
void Reconstruction::ReadPoints3DText(const std::string& path)
void Reconstruction::WritePoints3DText(const std::string& path)
"""
xyzs = None
rgbs = None
errors = None
num_points = 0
with open(path, "r") as fid:
while True:
line = fid.readline()
if not line:
break
line = line.strip()
if len(line) > 0 and line[0] != "#":
num_points += 1
xyzs = np.empty((num_points, 3))
rgbs = np.empty((num_points, 3))
errors = np.empty((num_points, 1))
count = 0
with open(path, "r") as fid:
while True:
line = fid.readline()
if not line:
break
line = line.strip()
if len(line) > 0 and line[0] != "#":
elems = line.split()
xyz = np.array(tuple(map(float, elems[1:4])))
rgb = np.array(tuple(map(int, elems[4:7])))
error = np.array(float(elems[7]))
xyzs[count] = xyz
rgbs[count] = rgb
errors[count] = error
count += 1
return xyzs, rgbs, errors
def read_points3D_binary(path_to_model_file):
"""
see: src/base/reconstruction.cc
void Reconstruction::ReadPoints3DBinary(const std::string& path)
void Reconstruction::WritePoints3DBinary(const std::string& path)
"""
with open(path_to_model_file, "rb") as fid:
num_points = read_next_bytes(fid, 8, "Q")[0]
xyzs = np.empty((num_points, 3))
rgbs = np.empty((num_points, 3))
errors = np.empty((num_points, 1))
for p_id in range(num_points):
binary_point_line_properties = read_next_bytes(
fid, num_bytes=43, format_char_sequence="QdddBBBd")
xyz = np.array(binary_point_line_properties[1:4])
rgb = np.array(binary_point_line_properties[4:7])
error = np.array(binary_point_line_properties[7])
track_length = read_next_bytes(
fid, num_bytes=8, format_char_sequence="Q")[0]
track_elems = read_next_bytes(
fid, num_bytes=8*track_length,
format_char_sequence="ii"*track_length)
xyzs[p_id] = xyz
rgbs[p_id] = rgb
errors[p_id] = error
return xyzs, rgbs, errors
def read_intrinsics_text(path):
"""
Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py
"""
cameras = {}
with open(path, "r") as fid:
while True:
line = fid.readline()
if not line:
break
line = line.strip()
if len(line) > 0 and line[0] != "#":
elems = line.split()
camera_id = int(elems[0])
model = elems[1]
assert model == "PINHOLE", "While the loader supports other types, the rest of the code assumes PINHOLE"
width = int(elems[2])
height = int(elems[3])
params = np.array(tuple(map(float, elems[4:])))
cameras[camera_id] = Camera(id=camera_id, model=model,
width=width, height=height,
params=params)
return cameras
def read_extrinsics_binary(path_to_model_file):
"""
see: src/base/reconstruction.cc
void Reconstruction::ReadImagesBinary(const std::string& path)
void Reconstruction::WriteImagesBinary(const std::string& path)
"""
images = {}
with open(path_to_model_file, "rb") as fid:
num_reg_images = read_next_bytes(fid, 8, "Q")[0]
for _ in range(num_reg_images):
binary_image_properties = read_next_bytes(
fid, num_bytes=64, format_char_sequence="idddddddi")
image_id = binary_image_properties[0]
qvec = np.array(binary_image_properties[1:5])
tvec = np.array(binary_image_properties[5:8])
camera_id = binary_image_properties[8]
image_name = ""
current_char = read_next_bytes(fid, 1, "c")[0]
while current_char != b"\x00": # look for the ASCII 0 entry
image_name += current_char.decode("utf-8")
current_char = read_next_bytes(fid, 1, "c")[0]
num_points2D = read_next_bytes(fid, num_bytes=8,
format_char_sequence="Q")[0]
x_y_id_s = read_next_bytes(fid, num_bytes=24*num_points2D,
format_char_sequence="ddq"*num_points2D)
xys = np.column_stack([tuple(map(float, x_y_id_s[0::3])),
tuple(map(float, x_y_id_s[1::3]))])
point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
images[image_id] = Image(
id=image_id, qvec=qvec, tvec=tvec,
camera_id=camera_id, name=image_name,
xys=xys, point3D_ids=point3D_ids)
return images
def read_intrinsics_binary(path_to_model_file):
"""
see: src/base/reconstruction.cc
void Reconstruction::WriteCamerasBinary(const std::string& path)
void Reconstruction::ReadCamerasBinary(const std::string& path)
"""
cameras = {}
with open(path_to_model_file, "rb") as fid:
num_cameras = read_next_bytes(fid, 8, "Q")[0]
for _ in range(num_cameras):
camera_properties = read_next_bytes(
fid, num_bytes=24, format_char_sequence="iiQQ")
camera_id = camera_properties[0]
model_id = camera_properties[1]
model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
width = camera_properties[2]
height = camera_properties[3]
num_params = CAMERA_MODEL_IDS[model_id].num_params
params = read_next_bytes(fid, num_bytes=8*num_params,
format_char_sequence="d"*num_params)
cameras[camera_id] = Camera(id=camera_id,
model=model_name,
width=width,
height=height,
params=np.array(params))
assert len(cameras) == num_cameras
return cameras
def read_extrinsics_text(path):
"""
Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py
"""
images = {}
with open(path, "r") as fid:
while True:
line = fid.readline()
if not line:
break
line = line.strip()
if len(line) > 0 and line[0] != "#":
elems = line.split()
image_id = int(elems[0])
qvec = np.array(tuple(map(float, elems[1:5])))
tvec = np.array(tuple(map(float, elems[5:8])))
camera_id = int(elems[8])
image_name = elems[9]
elems = fid.readline().split()
xys = np.column_stack([tuple(map(float, elems[0::3])),
tuple(map(float, elems[1::3]))])
point3D_ids = np.array(tuple(map(int, elems[2::3])))
images[image_id] = Image(
id=image_id, qvec=qvec, tvec=tvec,
camera_id=camera_id, name=image_name,
xys=xys, point3D_ids=point3D_ids)
return images
def read_colmap_bin_array(path):
"""
Taken from https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py
:param path: path to the colmap binary file.
:return: ndarray holding the floating point values stored in the file
"""
with open(path, "rb") as fid:
width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1,
usecols=(0, 1, 2), dtype=int)
fid.seek(0)
num_delimiter = 0
byte = fid.read(1)
while True:
if byte == b"&":
num_delimiter += 1
if num_delimiter >= 3:
break
byte = fid.read(1)
array = np.fromfile(fid, np.float32)
array = array.reshape((width, height, channels), order="F")
return np.transpose(array, (1, 0, 2)).squeeze()
================================================
FILE: scene/dataset_readers.py
================================================
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact george.drettakis@inria.fr
#
import os
import sys
from PIL import Image
from typing import NamedTuple
from scene.colmap_loader import read_extrinsics_text, read_intrinsics_text, qvec2rotmat, \
read_extrinsics_binary, read_intrinsics_binary, read_points3D_binary, read_points3D_text
from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal
import numpy as np
import json
import random
from tqdm import tqdm
from pathlib import Path
from plyfile import PlyData, PlyElement
from utils.sh_utils import SH2RGB
from scene.gaussian_model import BasicPointCloud
class CameraInfo(NamedTuple):
uid: int
R: np.array
T: np.array
FovY: np.array
FovX: np.array
cx: np.array
cy: np.array
image: np.array
depth: np.array # not used
sam_mask: np.array # modify -----
mask_feat: np.array # modify -----
image_path: str
image_name: str
width: int
height: int
class SceneInfo(NamedTuple):
point_cloud: BasicPointCloud
train_cameras: list
test_cameras: list
nerf_normalization: dict
ply_path: str
def getNerfppNorm(cam_info):
def get_center_and_diag(cam_centers):
cam_centers = np.hstack(cam_centers)
avg_cam_center = np.mean(cam_centers, axis=1, keepdims=True)
center = avg_cam_center
dist = np.linalg.norm(cam_centers - center, axis=0, keepdims=True)
diagonal = np.max(dist)
return center.flatten(), diagonal
cam_centers = []
for cam in cam_info:
W2C = getWorld2View2(cam.R, cam.T)
C2W = np.linalg.inv(W2C)
cam_centers.append(C2W[:3, 3:4])
center, diagonal = get_center_and_diag(cam_centers)
radius = diagonal * 1.1
translate = -center
return {"translate": translate, "radius": radius}
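# Illustrative sketch (not part of the original repository): getNerfppNorm
# above reduces to "mean of the camera centers" plus "largest distance from
# that mean, padded by 10%" (the function then returns the negated center as
# the translation). The helper below restates this in plain Python for a list
# of (x, y, z) camera centers; its name and inputs are hypothetical.
import math


def _sketch_center_and_radius(cam_centers):
    """Return (center, radius) for a non-empty list of 3D camera centers."""
    n = len(cam_centers)
    center = tuple(sum(c[i] for c in cam_centers) / n for i in range(3))
    max_dist = max(math.dist(c, center) for c in cam_centers)
    return center, max_dist * 1.1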
def readColmapCameras(cam_extrinsics, cam_intrinsics, images_folder):
cam_infos = []
for idx, key in enumerate(cam_extrinsics):
sys.stdout.write('\r')
# the exact output you're looking for:
sys.stdout.write("Reading camera {}/{}".format(idx+1, len(cam_extrinsics)))
sys.stdout.flush()
extr = cam_extrinsics[key]
intr = cam_intrinsics[extr.camera_id]
height = intr.height
width = intr.width
uid = intr.id
R = np.transpose(qvec2rotmat(extr.qvec))
T = np.array(extr.tvec)
if intr.model=="SIMPLE_PINHOLE":
focal_length_x = intr.params[0]
FovY = focal2fov(focal_length_x, height)
FovX = focal2fov(focal_length_x, width)
elif intr.model=="PINHOLE":
focal_length_x = intr.params[0]
focal_length_y = intr.params[1]
FovY = focal2fov(focal_length_y, height)
FovX = focal2fov(focal_length_x, width)
else:
assert False, "Colmap camera model not handled: only undistorted datasets (PINHOLE or SIMPLE_PINHOLE cameras) supported!"
image_path = os.path.join(images_folder, os.path.basename(extr.name))
if not os.path.exists(image_path):
# modify -----
base, ext = os.path.splitext(image_path)
if ext.lower() == ".jpg":
image_path = base + ".png"
elif ext.lower() == ".png":
image_path = base + ".jpg"
if not os.path.exists(image_path):
continue
# modify ----
image_name = os.path.basename(image_path).split(".")[0]
image = Image.open(image_path)
# NOTE: load SAM mask and CLIP feat. [OpenGaussian]
mask_seg_path = os.path.join(images_folder[:-6], "language_features/" + extr.name.split('/')[-1][:-4] + "_s.npy")
mask_feat_path = os.path.join(images_folder[:-6], "language_features/" + extr.name.split('/')[-1][:-4] + "_f.npy")
if os.path.exists(mask_seg_path):
sam_mask = np.load(mask_seg_path) # [level=4, H, W]
else:
sam_mask = None
if mask_feat_path is not None and os.path.exists(mask_feat_path):
mask_feat = np.load(mask_feat_path) # [num_mask, dim=512]
else:
mask_feat = None
# modify -----
cam_info = CameraInfo(uid=uid, R=R, T=T, FovY=FovY, FovX=FovX, cx=width/2, cy=height/2, image=image,
depth=None, sam_mask=sam_mask, mask_feat=mask_feat,
image_path=image_path, image_name=image_name, width=width, height=height)
cam_infos.append(cam_info)
sys.stdout.write('\n')
return cam_infos
def fetchPly(path):
plydata = PlyData.read(path)
vertices = plydata['vertex']
positions = np.vstack([vertices['x'], vertices['y'], vertices['z']]).T
if {'red', 'green', 'blue'}.issubset(vertices.data.dtype.names):
colors = np.vstack([vertices['red'], vertices['green'], vertices['blue']]).T / 255.0
else:
colors = np.random.rand(positions.shape[0], 3)
if {'nx', 'ny', 'nz'}.issubset(vertices.data.dtype.names):
normals = np.vstack([vertices['nx'], vertices['ny'], vertices['nz']]).T
else:
normals = np.random.rand(positions.shape[0], 3)
return BasicPointCloud(points=positions, colors=colors, normals=normals)
def storePly(path, xyz, rgb):
# Define the dtype for the structured array
dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
('nx', 'f4'), ('ny', 'f4'), ('nz', 'f4'),
('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]
normals = np.zeros_like(xyz)
elements = np.empty(xyz.shape[0], dtype=dtype)
attributes = np.concatenate((xyz, normals, rgb), axis=1)
elements[:] = list(map(tuple, attributes))
# Create the PlyData object and write to file
vertex_element = PlyElement.describe(elements, 'vertex')
ply_data = PlyData([vertex_element])
ply_data.write(path)
def readColmapSceneInfo(path, images, eval, llffhold=8):
try:
cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.bin")
cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.bin")
cam_extrinsics = read_extrinsics_binary(cameras_extrinsic_file)
cam_intrinsics = read_intrinsics_binary(cameras_intrinsic_file)
except Exception:
cameras_extrinsic_file = os.path.join(path, "sparse/0", "images.txt")
cameras_intrinsic_file = os.path.join(path, "sparse/0", "cameras.txt")
cam_extrinsics = read_extrinsics_text(cameras_extrinsic_file)
cam_intrinsics = read_intrinsics_text(cameras_intrinsic_file)
reading_dir = "images" if images == None else images
cam_infos_unsorted = readColmapCameras(cam_extrinsics=cam_extrinsics, cam_intrinsics=cam_intrinsics, images_folder=os.path.join(path, reading_dir))
cam_infos = sorted(cam_infos_unsorted.copy(), key = lambda x : x.image_name)
if eval:
train_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold != 0]
test_cam_infos = [c for idx, c in enumerate(cam_infos) if idx % llffhold == 0]
else:
train_cam_infos = cam_infos
test_cam_infos = []
nerf_normalization = getNerfppNorm(train_cam_infos)
ply_path = os.path.join(path, "sparse/0/points3D.ply")
bin_path = os.path.join(path, "sparse/0/points3D.bin")
txt_path = os.path.join(path, "sparse/0/points3D.txt")
if not os.path.exists(ply_path):
print("Converting points3D.bin to .ply; this happens only the first time you open the scene.")
try:
xyz, rgb, _ = read_points3D_binary(bin_path)
except Exception:
xyz, rgb, _ = read_points3D_text(txt_path)
storePly(ply_path, xyz, rgb)
try:
pcd = fetchPly(ply_path)
except Exception:
pcd = None
scene_info = SceneInfo(point_cloud=pcd,
train_cameras=train_cam_infos,
test_cameras=test_cam_infos,
nerf_normalization=nerf_normalization,
ply_path=ply_path)
return scene_info
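# Illustrative sketch (not part of the original repository): the eval split in
# readColmapSceneInfo sorts cameras by image name and holds out every
# llffhold-th one (indices 0, 8, 16, ... for the default of 8) as a test view.
# The helper name and the plain list-of-names input are hypothetical.
def _sketch_llff_split(image_names, llffhold=8):
    """Split image names into (train, test) with every llffhold-th held out."""
    ordered = sorted(image_names)
    train = [name for i, name in enumerate(ordered) if i % llffhold != 0]
    test = [name for i, name in enumerate(ordered) if i % llffhold == 0]
    return train, test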
def readCamerasFromTransforms(path, transformsfile, white_background, extension=".png"):
cam_infos = []
with open(os.path.join(path, transformsfile)) as json_file:
contents = json.load(json_file)
# ----- modify -----
if "camera_angle_x" not in contents.keys():
fovx = None
else:
fovx = contents["camera_angle_x"]
# ----- modify -----
# modify -----
cx, cy = -1, -1
if "cx" in contents.keys():
cx = contents["cx"]
cy = contents["cy"]
elif "h" in contents.keys():
cx = contents["w"] / 2
cy = contents["h"] / 2
# modify -----
frames = contents["frames"]
# for idx, frame in enumerate(frames):
for idx, frame in tqdm(enumerate(frames), total=len(frames), desc="load images"):
cam_name = os.path.join(path, frame["file_path"] + extension)
# NeRF 'transform_matrix' is a camera-to-world transform
c2w = np.array(frame["transform_matrix"])
# change from OpenGL/Blender camera axes (Y up, Z back) to COLMAP (Y down, Z forward)
c2w[:3, 1:3] *= -1 # TODO
# get the world-to-camera transform and set R, T
w2c = np.linalg.inv(c2w)
R = np.transpose(w2c[:3,:3]) # R is stored transposed due to 'glm' in CUDA code
T = w2c[:3, 3]
image_path = os.path.join(path, cam_name)
if not os.path.exists(image_path):
# modify -----
base, ext = os.path.splitext(image_path)
if ext.lower() == ".jpg":
image_path = base + ".png"
elif ext.lower() == ".png":
image_path = base + ".jpg"
if not os.path.exists(image_path):
continue
# modify ----
image_name = Path(cam_name).stem
image = Image.open(image_path)
im_data = np.array(image.convert("RGBA"))
bg = np.array([1,1,1]) if white_background else np.array([0, 0, 0])
norm_data = im_data / 255.0
arr = norm_data[:,:,:3] * norm_data[:, :, 3:4] + bg * (1 - norm_data[:, :, 3:4])
image = Image.fromarray(np.array(arr*255.0, dtype=np.uint8), "RGB") # uint8 avoids overflow when casting values above 127
# NOTE: load SAM mask and CLIP feat. [OpenGaussian]
mask_seg_path = os.path.join(path, "language_features/" + frame["file_path"].split('/')[-1] + "_s.npy")
mask_feat_path = os.path.join(path, "language_features/" + frame["file_path"].split('/')[-1] + "_f.npy")
if os.path.exists(mask_seg_path):
sam_mask = np.load(mask_seg_path) # [level=4, H, W]
else:
sam_mask = None
if os.path.exists(mask_feat_path):
mask_feat = np.load(mask_feat_path) # [num_mask, dim=512]
else:
mask_feat = None
# modify -----
# ----- modify -----
if "K" in frame.keys():
cx = frame["K"][0][2]
cy = frame["K"][1][2]
if cx == -1:
cx = image.size[0] / 2
cy = image.size[1] / 2
# ----- modify -----
# ----- modify -----
if fovx == None:
if "K" in frame.keys():
focal_length = frame["K"][0][0]
if "fl_x" in contents.keys():
focal_length = contents["fl_x"]
if "fl_x" in frame.keys():
focal_length = frame["fl_x"]
FovY = focal2fov(focal_length, image.size[1])
FovX = focal2fov(focal_length, image.size[0])
else:
# camera_angle_x is the horizontal FoV; derive the vertical FoV from the shared focal length
fovy = focal2fov(fov2focal(fovx, image.size[0]), image.size[1])
FovY = fovy
FovX = fovx
# ----- modify -----
cam_infos.append(CameraInfo(uid=idx, R=R, T=T, FovY=FovY, FovX=FovX, cx=cx, cy=cy, image=image,
depth=None, sam_mask=sam_mask, mask_feat=mask_feat,
image_path=image_path, image_name=image_name, width=image.size[0], height=image.size[1]))
return cam_infos
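# Illustrative sketch (not part of the original repository): converting a
# horizontal field of view (such as "camera_angle_x") into a vertical one for
# a non-square image. The two helpers assume the usual pinhole relation
# fov = 2 * atan(pixels / (2 * focal)); they are hypothetical stand-ins for
# focal2fov / fov2focal from utils.graphics_utils, which are not shown here.
import math


def _sketch_focal2fov(focal, pixels):
    """Field of view (radians) spanned by `pixels` at focal length `focal`."""
    return 2.0 * math.atan(pixels / (2.0 * focal))


def _sketch_fov2focal(fov, pixels):
    """Focal length (pixels) that yields `fov` radians across `pixels`."""
    return pixels / (2.0 * math.tan(fov / 2.0))


def _sketch_vertical_fov(fov_x, width, height):
    """Derive the vertical FoV from the horizontal one, assuming square pixels."""
    return _sketch_focal2fov(_sketch_fov2focal(fov_x, width), height)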
def readNerfSyntheticInfo(path, white_background, eval, extension=".png"):
print("Reading Training Transforms")
train_cam_infos = readCamerasFromTransforms(path, "transforms_train.json", white_background, extension)
print("Reading Test Transforms")
if os.path.exists(os.path.join(path, "transforms_test.json")):
test_cam_infos = readCamerasFromTransforms(path, "transforms_test.json", white_background, extension)
else:
test_cam_infos = train_cam_infos
if not eval:
train_cam_infos.extend(test_cam_infos)
test_cam_infos = []
nerf_normalization = getNerfppNorm(train_cam_infos)
ply_path = os.path.join(path, "points3d.ply")
if not os.path.exists(ply_path):
# Since this data set has no colmap data, we start with random points
num_pts = 100_000
print(f"Generating random point cloud ({num_pts})...")
# We create random points inside the bounds of the synthetic Blender scenes
xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3
shs = np.random.random((num_pts, 3)) / 255.0
pcd = BasicPointCloud(points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3)))
storePly(ply_path, xyz, SH2RGB(shs) * 255)
try:
pcd = fetchPly(ply_path)
except Exception:
pcd = None
scene_info = SceneInfo(point_cloud=pcd,
train_cameras=train_cam_infos,
test_cameras=test_cam_infos,
nerf_normalization=nerf_normalization,
ply_path=ply_path)
return scene_info
sceneLoadTypeCallbacks = {
"Colmap": readColmapSceneInfo,
"Blender" : readNerfSyntheticInfo
}
================================================
FILE: scene/gaussian_model.py
================================================
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact george.drettakis@inria.fr
#
import torch
import numpy as np
from utils.general_utils import inverse_sigmoid, get_expon_lr_func, build_rotation
from torch import nn
import os
from utils.system_utils import mkdir_p
from plyfile import PlyData, PlyElement
from utils.sh_utils import RGB2SH
# from simple_knn._C import distCUDA2 # no need
from scipy.spatial import KDTree # modify
from utils.graphics_utils import BasicPointCloud
from utils.general_utils import strip_symmetric, build_scaling_rotation
def sigmoid(x):
return 1 / (1 + np.exp(-x))
def distCUDA2(points):
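'''
CPU replacement for simple_knn's distCUDA2: for each point, the mean squared
distance to its 3 nearest neighbours (k=4 because the closest hit is the point itself).
https://github.com/graphdeco-inria/gaussian-splatting/issues/292
'''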
points_np = points.detach().cpu().float().numpy()
dists, inds = KDTree(points_np).query(points_np, k=4)
meanDists = (dists[:, 1:] ** 2).mean(1)
return torch.tensor(meanDists, dtype=points.dtype, device=points.device)
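# Illustrative sketch, not part of the repository: a brute-force, standard-library version
# of the quantity distCUDA2 computes above, useful for checking the KD-tree result on a
# handful of points. The function name mean_sq_dist_to_knn is hypothetical, and the O(N^2)
# loop is only meant for tiny inputs with at least two points.

```python
def mean_sq_dist_to_knn(points, k=3):
    """For each point, mean squared distance to its k nearest other points."""
    result = []
    for i, p in enumerate(points):
        # Squared distances to every other point, smallest first.
        sq_dists = sorted(
            sum((a - b) ** 2 for a, b in zip(p, q))
            for j, q in enumerate(points)
            if j != i
        )
        nearest = sq_dists[:k]
        result.append(sum(nearest) / len(nearest))
    return result
```

# Usage: mean_sq_dist_to_knn([(0, 0, 0), (1, 0, 0), (2, 0, 0), (3, 0, 0)]) gives one value
# per point; create_from_pcd takes the square root of such values as the initial scale.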
class GaussianModel:
def setup_functions(self):
def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
L = build_scaling_rotation(scaling_modifier * scaling, rotation)
actual_covariance = L @ L.transpose(1, 2)
symm = strip_symmetric(actual_covariance)
return symm
self.scaling_activation = torch.exp
self.scaling_inverse_activation = torch.log
self.covariance_activation = build_covariance_from_scaling_rotation
self.opacity_activation = torch.sigmoid
self.inverse_opacity_activation = inverse_sigmoid
self.rotation_activation = torch.nn.functional.normalize
def __init__(self, sh_degree : int):
self.active_sh_degree = 0
self.max_sh_degree = sh_degree
self._xyz = torch.empty(0)
self._features_dc = torch.empty(0)
self._features_rest = torch.empty(0)
self._scaling = torch.empty(0)
self._rotation = torch.empty(0)
self._opacity = torch.empty(0)
self._ins_feat = torch.empty(0) # Continuous instance features before quantization
self._ins_feat_q = torch.empty(0) # Discrete instance features after quantization
self.iClusterSubNum = torch.empty(0)
self.max_radii2D = torch.empty(0)
self.xyz_gradient_accum = torch.empty(0)
self.denom = torch.empty(0)
self.optimizer = None
self.percent_dense = 0
self.spatial_lr_scale = 0
self.setup_functions()
def capture(self):
return (
self.active_sh_degree,
self._xyz,
self._features_dc,
self._features_rest,
self._scaling,
self._rotation,
self._opacity,
self._ins_feat, # Continuous instance features before quantization
self._ins_feat_q, # Discrete instance features after quantization
self.max_radii2D,
self.xyz_gradient_accum,
self.denom,
self.optimizer.state_dict(),
self.spatial_lr_scale,
)
def restore(self, model_args, training_args):
(self.active_sh_degree,
self._xyz,
self._features_dc,
self._features_rest,
self._scaling,
self._rotation,
self._opacity,
self._ins_feat, # Continuous instance features before quantization
self._ins_feat_q, # Discrete instance features after quantization
self.max_radii2D,
xyz_gradient_accum,
denom,
opt_dict,
self.spatial_lr_scale) = model_args
self.training_setup(training_args)
self.xyz_gradient_accum = xyz_gradient_accum
self.denom = denom
self.optimizer.load_state_dict(opt_dict)
@property
def get_scaling(self):
return self.scaling_activation(self._scaling)
@property
def get_scaling_origin(self):
return self.scaling_activation(self._scaling)
@property
def get_rotation(self):
return self.rotation_activation(self._rotation)
@property
def get_rotation_matrix(self):
return build_rotation(self._rotation)
@property
def get_eigenvector(self):
scales = self.get_scaling_origin
N = scales.shape[0]
idx = torch.min(scales, dim=1)[1]
normals = self.get_rotation_matrix[np.arange(N), :, idx]
normals = torch.nn.functional.normalize(normals, dim=1)
return normals
@property
def get_xyz(self):
return self._xyz
@property
def get_features(self):
features_dc = self._features_dc
features_rest = self._features_rest
return torch.cat((features_dc, features_rest), dim=1)
@property
def get_opacity(self):
return self.opacity_activation(self._opacity)
# NOTE: get instance feature
# @property
def get_ins_feat(self, origin=False):
if len(self._ins_feat_q) == 0 or origin:
ins_feat = self._ins_feat
else:
ins_feat = self._ins_feat_q
ins_feat = torch.nn.functional.normalize(ins_feat, dim=1)
return ins_feat
def get_covariance(self, scaling_modifier = 1):
return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation)
def oneupSHdegree(self):
if self.active_sh_degree < self.max_sh_degree:
self.active_sh_degree += 1
def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float):
self.spatial_lr_scale = spatial_lr_scale
fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda()
fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda())
features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda() # [N, 3, 16]
features[:, :3, 0 ] = fused_color
features[:, 3:, 1:] = 0.0
print("Number of points at initialisation : ", fused_point_cloud.shape[0])
dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001)
scales = torch.log(torch.sqrt(dist2))[...,None].repeat(1, 3)
rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda")
rots[:, 0] = 1
opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda"))
# modify -----
ins_feat = torch.rand((fused_point_cloud.shape[0], 6), dtype=torch.float, device="cuda")
self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))
self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True))
self._features_rest = nn.Parameter(features[:,:,1:].transpose(1, 2).contiguous().requires_grad_(True))
self._scaling = nn.Parameter(scales.requires_grad_(True))
self._rotation = nn.Parameter(rots.requires_grad_(True))
self._opacity = nn.Parameter(opacities.requires_grad_(True))
# modify -----
self._ins_feat = nn.Parameter(ins_feat.requires_grad_(True))
self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
def training_setup(self, training_args):
self.percent_dense = training_args.percent_dense
self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
l = [
{'params': [self._xyz], 'lr': training_args.position_lr_init * self.spatial_lr_scale, "name": "xyz"},
{'params': [self._features_dc], 'lr': training_args.feature_lr, "name": "f_dc"},
{'params': [self._features_rest], 'lr': training_args.feature_lr / 20.0, "name": "f_rest"},
{'params': [self._opacity], 'lr': training_args.opacity_lr, "name": "opacity"},
{'params': [self._scaling], 'lr': training_args.scaling_lr, "name": "scaling"},
{'params': [self._rotation], 'lr': training_args.rotation_lr, "name": "rotation"},
{'params': [self._ins_feat], 'lr': training_args.ins_feat_lr, "name": "ins_feat"} # modify -----
]
# NOTE: freeze the positions of the initial points (no densification); used for the ScanNet 3DGS pre-training stage
if training_args.frozen_init_pts:
self._xyz = self._xyz.detach()
self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15)
self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init*self.spatial_lr_scale,
lr_final=training_args.position_lr_final*self.spatial_lr_scale,
lr_delay_mult=training_args.position_lr_delay_mult,
max_steps=training_args.position_lr_max_steps)
def update_learning_rate(self, iteration, root_start, leaf_start):
''' Learning rate scheduling per step '''
for param_group in self.optimizer.param_groups:
if param_group["name"] == "xyz":
lr = self.xyz_scheduler_args(iteration)
param_group['lr'] = lr
# return lr
if param_group["name"] == "ins_feat":
if iteration > root_start and iteration <= leaf_start: # TODO: update lr
param_group['lr'] = 0.0001
else:
param_group['lr'] = 0.001
def construct_list_of_attributes(self):
l = ['x', 'y', 'z', 'nx', 'ny', 'nz', 'ins_feat_r', 'ins_feat_g', 'ins_feat_b', \
'ins_feat_r2', 'ins_feat_g2', 'ins_feat_b2']
# All channels except the 3 DC
for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]):
l.append('f_dc_{}'.format(i))
for i in range(self._features_rest.shape[1]*self._features_rest.shape[2]):
l.append('f_rest_{}'.format(i))
l.append('opacity')
for i in range(self._scaling.shape[1]):
l.append('scale_{}'.format(i))
for i in range(self._rotation.shape[1]):
l.append('rot_{}'.format(i))
return l
def save_ply(self, path, save_q=[]):
mkdir_p(os.path.dirname(path))
xyz = self._xyz.detach().cpu().numpy()
normals = np.zeros_like(xyz)
f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
opacities = self._opacity.detach().cpu().numpy()
scale = self._scaling.detach().cpu().numpy()
rotation = self._rotation.detach().cpu().numpy()
if "ins_feat" in save_q:
ins_feat = self._ins_feat_q.detach().cpu().numpy()
else:
ins_feat = self._ins_feat.detach().cpu().numpy()
# NOTE: pts feat visualization
vis_color = (ins_feat + 1) / 2 * 255
r, g, b = vis_color[:, 0].reshape(-1, 1), vis_color[:, 1].reshape(-1, 1), vis_color[:, 2].reshape(-1, 1)
# todo: points not fully optimized due to sampled training images.
ignored_ind = sigmoid(opacities) < 0.1
r[ignored_ind], g[ignored_ind], b[ignored_ind] = 128, 128, 128
dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()]
dtype_full = dtype_full + [('red', 'u1'), ('green', 'u1'), ('blue', 'u1')] # modify
elements = np.empty(xyz.shape[0], dtype=dtype_full)
attributes = np.concatenate((xyz, normals, ins_feat,\
f_dc, f_rest, opacities, scale, rotation,\
r, g, b), axis=1)
elements[:] = list(map(tuple, attributes))
el = PlyElement.describe(elements, 'vertex')
PlyData([el]).write(path)
def reset_opacity(self):
opacities_new = inverse_sigmoid(torch.min(self.get_opacity, torch.ones_like(self.get_opacity)*0.01))
optimizable_tensors = self.replace_tensor_to_optimizer(opacities_new, "opacity")
self._opacity = optimizable_tensors["opacity"]
def load_ply(self, path):
plydata = PlyData.read(path)
xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
np.asarray(plydata.elements[0]["y"]),
np.asarray(plydata.elements[0]["z"])), axis=1)
ins_feat = np.stack((np.asarray(plydata.elements[0]["ins_feat_r"]),
np.asarray(plydata.elements[0]["ins_feat_g"]),
np.asarray(plydata.elements[0]["ins_feat_b"]),
np.asarray(plydata.elements[0]["ins_feat_r2"]),
np.asarray(plydata.elements[0]["ins_feat_g2"]),
np.asarray(plydata.elements[0]["ins_feat_b2"])), axis=1)
opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]
if not opacities.flags['C_CONTIGUOUS']:
opacities = np.ascontiguousarray(opacities)
features_dc = np.zeros((xyz.shape[0], 3, 1))
features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"])
features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"])
extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")]
extra_f_names = sorted(extra_f_names, key = lambda x: int(x.split('_')[-1]))
assert len(extra_f_names)==3*(self.max_sh_degree + 1) ** 2 - 3
features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))
for idx, attr_name in enumerate(extra_f_names):
features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name])
# Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC)
features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1))
scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")]
scale_names = sorted(scale_names, key = lambda x: int(x.split('_')[-1]))
scales = np.zeros((xyz.shape[0], len(scale_names)))
for idx, attr_name in enumerate(scale_names):
scales[:, idx] = np.asarray(plydata.elements[0][attr_name])
rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")]
rot_names = sorted(rot_names, key = lambda x: int(x.split('_')[-1]))
rots = np.zeros((xyz.shape[0], len(rot_names)))
for idx, attr_name in enumerate(rot_names):
rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True))
self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True))
self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True))
self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True))
self._ins_feat = nn.Parameter(torch.tensor(ins_feat, dtype=torch.float, device="cuda").requires_grad_(True))
self.active_sh_degree = self.max_sh_degree
def replace_tensor_to_optimizer(self, tensor, name):
optimizable_tensors = {}
for group in self.optimizer.param_groups:
if group["name"] == name:
stored_state = self.optimizer.state.get(group['params'][0], None)
stored_state["exp_avg"] = torch.zeros_like(tensor)
stored_state["exp_avg_sq"] = torch.zeros_like(tensor)
del self.optimizer.state[group['params'][0]]
group["params"][0] = nn.Parameter(tensor.requires_grad_(True))
self.optimizer.state[group['params'][0]] = stored_state
optimizable_tensors[group["name"]] = group["params"][0]
return optimizable_tensors
def _prune_optimizer(self, mask):
optimizable_tensors = {}
for group in self.optimizer.param_groups:
stored_state = self.optimizer.state.get(group['params'][0], None)
if stored_state is not None:
stored_state["exp_avg"] = stored_state["exp_avg"][mask]
stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask]
del self.optimizer.state[group['params'][0]]
group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True)))
self.optimizer.state[group['params'][0]] = stored_state
optimizable_tensors[group["name"]] = group["params"][0]
else:
group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True))
optimizable_tensors[group["name"]] = group["params"][0]
return optimizable_tensors
def prune_points(self, mask):
valid_points_mask = ~mask
optimizable_tensors = self._prune_optimizer(valid_points_mask)
self._xyz = optimizable_tensors["xyz"]
self._features_dc = optimizable_tensors["f_dc"]
self._features_rest = optimizable_tensors["f_rest"]
self._opacity = optimizable_tensors["opacity"]
self._scaling = optimizable_tensors["scaling"]
self._rotation = optimizable_tensors["rotation"]
self._ins_feat = optimizable_tensors["ins_feat"]
self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask]
self.denom = self.denom[valid_points_mask]
self.max_radii2D = self.max_radii2D[valid_points_mask]
def cat_tensors_to_optimizer(self, tensors_dict):
optimizable_tensors = {}
for group in self.optimizer.param_groups:
assert len(group["params"]) == 1
extension_tensor = tensors_dict[group["name"]]
stored_state = self.optimizer.state.get(group['params'][0], None)
if stored_state is not None:
stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0)
stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0)
del self.optimizer.state[group['params'][0]]
group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
self.optimizer.state[group['params'][0]] = stored_state
optimizable_tensors[group["name"]] = group["params"][0]
else:
group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
optimizable_tensors[group["name"]] = group["params"][0]
return optimizable_tensors
def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_opacities, \
new_scaling, new_rotation, new_ins_feat):
d = {"xyz": new_xyz,
"f_dc": new_features_dc,
"f_rest": new_features_rest,
"opacity": new_opacities,
"scaling" : new_scaling,
"rotation" : new_rotation,
"ins_feat": new_ins_feat}
optimizable_tensors = self.cat_tensors_to_optimizer(d)
self._xyz = optimizable_tensors["xyz"]
self._features_dc = optimizable_tensors["f_dc"]
self._features_rest = optimizable_tensors["f_rest"]
self._opacity = optimizable_tensors["opacity"]
self._scaling = optimizable_tensors["scaling"]
self._rotation = optimizable_tensors["rotation"]
self._ins_feat = optimizable_tensors["ins_feat"]
self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
def densify_and_split(self, grads, grad_threshold, scene_extent, N=2):
n_init_points = self.get_xyz.shape[0]
# Extract points that satisfy the gradient condition
padded_grad = torch.zeros((n_init_points), device="cuda")
padded_grad[:grads.shape[0]] = grads.squeeze()
selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False)
selected_pts_mask = torch.logical_and(selected_pts_mask,
torch.max(self.get_scaling, dim=1).values > self.percent_dense*scene_extent)
stds = self.get_scaling[selected_pts_mask].repeat(N,1)
means =torch.zeros((stds.size(0), 3),device="cuda")
samples = torch.normal(mean=means, std=stds)
rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N,1,1)
new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[selected_pts_mask].repeat(N, 1)
new_scaling = self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(N,1) / (0.8*N))
new_rotation = self._rotation[selected_pts_mask].repeat(N,1)
new_features_dc = self._features_dc[selected_pts_mask].repeat(N,1,1)
new_features_rest = self._features_rest[selected_pts_mask].repeat(N,1,1)
new_opacity = self._opacity[selected_pts_mask].repeat(N,1)
new_ins_feat = self._ins_feat[selected_pts_mask].repeat(N,1)
self.densification_postfix(new_xyz, new_features_dc, new_features_rest, \
new_opacity, new_scaling, new_rotation, new_ins_feat)
prune_filter = torch.cat((selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool)))
self.prune_points(prune_filter)
def densify_and_clone(self, grads, grad_threshold, scene_extent):
# Extract points that satisfy the gradient condition
selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False)
selected_pts_mask = torch.logical_and(selected_pts_mask,
torch.max(self.get_scaling, dim=1).values <= self.percent_dense*scene_extent)
new_xyz = self._xyz[selected_pts_mask]
new_features_dc = self._features_dc[selected_pts_mask]
new_features_rest = self._features_rest[selected_pts_mask]
new_opacities = self._opacity[selected_pts_mask]
new_scaling = self._scaling[selected_pts_mask]
new_rotation = self._rotation[selected_pts_mask]
new_ins_feat = self._ins_feat[selected_pts_mask]
self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacities, \
new_scaling, new_rotation, new_ins_feat)
def densify_and_prune(self, max_grad, min_opacity, extent, max_screen_size):
grads = self.xyz_gradient_accum / self.denom
grads[grads.isnan()] = 0.0
self.densify_and_clone(grads, max_grad, extent)
self.densify_and_split(grads, max_grad, extent)
prune_mask = (self.get_opacity < min_opacity).squeeze()
if max_screen_size:
big_points_vs = self.max_radii2D > max_screen_size
big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent
prune_mask = torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws)
self.prune_points(prune_mask)
torch.cuda.empty_cache()
def add_densification_stats(self, viewspace_point_tensor, update_filter):
self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor.grad[update_filter,:2], dim=-1, keepdim=True)
self.denom[update_filter] += 1
================================================
FILE: scene/kmeans_quantize.py
================================================
import os
import pdb
from tqdm import tqdm
import time
import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
class Quantize_kMeans():
def __init__(self, num_clusters=64, num_leaf_clusters=10, num_iters=10, dim=9, dim_leaf=6):
self.num_clusters = num_clusters # k1
self.leaf_num_clusters = num_leaf_clusters # k2
self.num_kmeans_iters = num_iters # iter
self.vec_dim = dim # coarse-level, dim=9(feat+xyz)
self.leaf_vec_dim = dim_leaf # fine-level, dim=6(feat)
self.centers = torch.empty(0) # coarse center, [k1, 9]
self.leaf_centers = torch.empty(0) # fine centers, [k1*k2+1, 6]
self.iLeafSubNum = torch.empty(0) # Number of fine clusters per coarse cluster
self.cls_ids = torch.empty(0) # coarse cluster id [num_pts]
self.leaf_cls_ids = torch.empty(0) # fine cluster id[num_pts]
self.nn_index = torch.empty(0) # [num_pts] temporary variable
# for update_centers
self.cluster_ids = torch.empty(0)
self.excl_clusters = []
self.excl_cluster_ids = []
self.cluster_len = torch.empty(0)
self.max_cnt = 0
self.max_cnt_th = 10000
self.n_excl_cls = 0
self.pos_centers = torch.empty(0)
def get_dist(self, x, y, mode='sq_euclidean'):
"""Calculate distance between all vectors in x and all vectors in y.
x: (m, dim)
y: (n, dim)
dist: (m, n)
"""
if mode == 'sq_euclidean_chunk':
step = 65536
if x.shape[0] < step:
step = x.shape[0]
dist = []
for i in range(np.ceil(x.shape[0] / step).astype(int)):
dist.append(torch.cdist(x[(i*step): (i+1)*step, :].unsqueeze(0), y.unsqueeze(0))[0])
dist = torch.cat(dist, 0)
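elif mode == 'sq_euclidean':
# torch.cdist returns the (non-squared) Euclidean distance; the argmin is the same
dist = torch.cdist(x.unsqueeze(0).detach(), y.unsqueeze(0).detach())[0]
else:
raise ValueError(f"Unknown distance mode: {mode}")
return dist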
# Update centers in non-cluster assignment iters using cached nn indices.
def update_centers(self, feat, mode="root", selected_leaf=-1):
if mode == "root":
centers = self.centers
num_clusters = self.num_clusters
vec_dim = self.vec_dim
elif mode == "leaf":
centers = self.leaf_centers
num_clusters = self.num_clusters * self.leaf_num_clusters + 1
vec_dim = self.leaf_vec_dim
feat = feat.detach().reshape(-1, vec_dim) # [num_pts, dim] [766267, 9]
# Update all clusters except the excluded ones in a single operation
# Add a dummy element with zeros at the end
feat = torch.cat([feat, torch.zeros_like(feat[:1]).cuda()], 0) # [num_pts+1, dim]
centers = torch.sum(feat[self.cluster_ids, :].reshape(
num_clusters, self.max_cnt, -1), dim=1) # [num_clusters, vec_dim]
if len(self.excl_cluster_ids) > 0:
for i, cls in enumerate(self.excl_clusters):
# Division by num_points in cluster is done during the one-shot averaging of all
# clusters below. Only the extra elements in the bigger clusters are added here.
centers[cls] += torch.sum(feat[self.excl_cluster_ids[i], :], dim=0)
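centers /= (self.cluster_len + 1e-6)
# NOTE: `centers` was rebound to a new tensor above, so self.centers / self.leaf_centers are not modified by this method.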
# Update centers during cluster assignment using mask matrix multiplication
# Mask is obtained from distance matrix
def update_centers_(self, feat, cluster_mask=None, nn_index=None, avg=False):
# feat = feat.detach().reshape(-1, self.vec_dim)
centers = (cluster_mask.T @ feat) # [num_pts, num_cluster]^T @ [num_pts, dim] -> [num_cluster, dim]
# if avg:
# self.centers /= counts.unsqueeze(-1)
return centers
def equalize_cluster_size(self, mode="root"):
"""Make the size of all the clusters the same by appending dummy elements.
"""
# Find the maximum number of elements in a cluster, make size of all clusters
# equal by appending dummy elements until size is equal to size of max cluster.
# If max is too large, exclude it and consider the next biggest. Use for loop for
# the excluded clusters and a single operation for the remaining ones for
# updating the cluster centers.
unq, n_unq = torch.unique(self.nn_index, return_counts=True)
# Find max cluster size and exclude clusters greater than a threshold
topk = 100
if len(n_unq) < topk:
topk = len(n_unq)
max_cnt_topk, topk_idx = torch.topk(n_unq, topk)
self.max_cnt = max_cnt_topk[0]
idx = 0
self.excl_clusters = []
self.excl_cluster_ids = []
while(self.max_cnt > self.max_cnt_th):
self.excl_clusters.append(unq[topk_idx[idx]])
idx += 1
if idx < topk:
self.max_cnt = max_cnt_topk[idx]
else:
break
self.n_excl_cls = len(self.excl_clusters)
self.excl_clusters = sorted(self.excl_clusters)
# Store the indices of elements for each cluster
all_ids = []
cls_len = []
if mode == "root":
num_clusters = self.num_clusters
elif mode == "leaf":
num_clusters = self.num_clusters * self.leaf_num_clusters + 1
for i in range(num_clusters):
cur_cluster_ids = torch.where(self.nn_index == i)[0]
# For excluded clusters, use only the first max_cnt elements
# for averaging along with other clusters. Separately average the
# remaining elements just for the excluded clusters.
cls_len.append(torch.Tensor([len(cur_cluster_ids)]))
if i in self.excl_clusters:
self.excl_cluster_ids.append(cur_cluster_ids[self.max_cnt:])
cur_cluster_ids = cur_cluster_ids[:self.max_cnt]
# Append dummy elements to have same size for all clusters
all_ids.append(torch.cat([cur_cluster_ids, -1 * torch.ones((self.max_cnt - len(cur_cluster_ids)),
dtype=torch.long).cuda()]))
all_ids = torch.cat(all_ids).type(torch.long)
cls_len = torch.cat(cls_len).type(torch.long)
self.cluster_ids = all_ids
self.cluster_len = cls_len.unsqueeze(1).cuda()
if mode == "root":
self.cls_ids = self.nn_index
elif mode == "leaf":
self.leaf_cls_ids = self.nn_index
def cluster_assign(self, feat, feat_scaled=None, mode="root", selected_leaf=-1):
# quantize with kmeans
feat = feat.detach() # [N, dim]
if feat_scaled is None:
feat_scaled = feat
scale = feat[0] / (feat_scaled[0] + 1e-8)
# init. centers and ids
if len(self.centers) == 0 and mode == "root":
self.centers = feat[torch.randperm(feat.shape[0])[:self.num_clusters], :]
if len(self.leaf_centers) == 0 and mode == "leaf":
# [num_clusters * leaf_num_clusters + 1, dim_leaf], e.g. [641, 6]
self.leaf_centers = feat[torch.randperm(feat.shape[0])[:self.num_clusters * self.leaf_num_clusters+1], :]
self.leaf_cls_ids = torch.ones(feat.shape[0]).to(torch.int64).cuda() * self.num_clusters * self.leaf_num_clusters
# start kmeans
chunk = True
# tmp centers
if mode == "root":
tmp_centers = torch.zeros_like(self.centers)
counts = torch.zeros(self.num_clusters, dtype=torch.float32).cuda() + 1e-6
elif mode == "leaf":
tmp_centers = torch.zeros_like(self.leaf_centers)[:self.leaf_num_clusters, :]
counts = torch.zeros(self.leaf_num_clusters, dtype=torch.float32).cuda() + 1e-6
start_id = selected_leaf * self.leaf_num_clusters
end_id = selected_leaf * self.leaf_num_clusters + self.iLeafSubNum[selected_leaf]
for iteration in range(self.num_kmeans_iters):
# chunk for memory issues
if chunk:
self.nn_index = None
i = 0
chunk = 10000
if mode == "root":
while True:
dist = self.get_dist(feat[i*chunk:(i+1)*chunk, :], self.centers)
curr_nn_index = torch.argmin(dist, dim=-1) # [chunk]
# Assign a single cluster when the distance to multiple clusters is the same
dist = F.one_hot(curr_nn_index, self.num_clusters).type(torch.float32) # [chunk, num_clusters]
curr_centers = self.update_centers_(feat[i*chunk:(i+1)*chunk, :], dist, curr_nn_index, avg=False) # [num_clusters, dim]
counts += dist.detach().sum(0) + 1e-6 # [num_clusters]
tmp_centers += curr_centers
if self.nn_index is None:
self.nn_index = curr_nn_index
else:
self.nn_index = torch.cat((self.nn_index, curr_nn_index), dim=0)
i += 1
if i*chunk >= feat.shape[0]: # stop once every point has been processed (avoids an empty trailing chunk)
break
elif mode == "leaf":
for idx_c in range(self.num_clusters):
if idx_c != selected_leaf:
continue
selected_pts = self.cls_ids == idx_c
dist = self.get_dist(feat[selected_pts], self.leaf_centers[start_id:end_id])
curr_nn_index = torch.argmin(dist, dim=-1) # [num_selected]
dist = F.one_hot(curr_nn_index, self.leaf_num_clusters).type(torch.float32) # [num_selected, leaf_num_clusters]
curr_centers = self.update_centers_(feat[selected_pts], dist, curr_nn_index, avg=False) # [leaf_num_clusters, dim_leaf]
counts += dist.detach().sum(0) + 1e-6 # [leaf_num_clusters]
tmp_centers += curr_centers
self.leaf_cls_ids[selected_pts] = curr_nn_index + start_id
# average centers
if mode == "root":
self.centers = tmp_centers / counts.unsqueeze(-1)
elif mode == "leaf":
self.leaf_centers[start_id: start_id+self.leaf_num_clusters] = tmp_centers / counts.unsqueeze(-1)
# Reinitialize to 0
tmp_centers[tmp_centers != 0] = 0.
counts[counts > 0.1] = 0.
# Reassign ID according to the new centers
if chunk:
self.nn_index = None
i = 0
# chunk = 100000
if mode == "root":
while True:
dist = self.get_dist(feat_scaled[i * chunk:(i + 1) * chunk, :], self.centers)
curr_nn_index = torch.argmin(dist, dim=-1)
if self.nn_index is None:
self.nn_index = curr_nn_index
else:
self.nn_index = torch.cat((self.nn_index, curr_nn_index), dim=0)
i += 1
if i * chunk >= feat.shape[0]: # stop once every point has been processed (avoids an empty trailing chunk)
break
elif mode == "leaf":
for idx_c in range(self.num_clusters):
if idx_c != selected_leaf:
continue
selected_pts = self.cls_ids == idx_c
dist = self.get_dist(feat[selected_pts], self.leaf_centers[start_id:end_id])
curr_nn_index = torch.argmin(dist, dim=-1)
self.leaf_cls_ids[selected_pts] = curr_nn_index + start_id
self.nn_index = self.leaf_cls_ids
self.equalize_cluster_size(mode=mode)
def rescale(self, feat, scale=None):
"""Scale the feature to the range [-1, 1] by dividing each dimension by its maximum absolute value (or by `scale` if given).
"""
if scale is None:
return feat / (abs(feat).max(dim=0)[0] + 1e-8)
else:
return feat / (scale + 1e-8)
def forward(self, gaussian, iteration, assign=False, mode="root", selected_leaf=-1, pos_weight=1.0):
if mode == "root":
# (1) coarse-level: feature + xyz
scale = pos_weight # TODO
xyz_feat = gaussian._xyz.detach() * scale
feat = torch.cat((gaussian._ins_feat, xyz_feat), dim=1) # [N, 9]
elif mode == "leaf":
# (2) fine-level: feature only
feat = gaussian._ins_feat
if assign:
self.cluster_assign(feat, mode=mode, selected_leaf=selected_leaf) # gaussian._ins_feat
else:
self.update_centers(feat, mode=mode, selected_leaf=selected_leaf) # gaussian._ins_feat
if mode == "root":
centers = self.centers
vec_dim = self.vec_dim
elif mode == "leaf":
centers = self.leaf_centers
vec_dim = self.leaf_vec_dim
sampled_centers = torch.gather(centers, 0, self.nn_index.unsqueeze(-1).repeat(1, vec_dim))
# NOTE: "During backpropagation, the gradients of the quantized features are copied to the instance features", mentioned in the paper.
gaussian._ins_feat_q = gaussian._ins_feat - gaussian._ins_feat.detach() + sampled_centers[:,:6]
def replace_with_centers(self, gaussian):
deg = gaussian._features_rest.shape[1]
sampled_centers = torch.gather(self.centers, 0, self.nn_index.unsqueeze(-1).repeat(1, self.vec_dim))
gaussian._features_rest = gaussian._features_rest - gaussian._features_rest.detach() + sampled_centers.reshape(-1, deg, 3)
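# Illustrative sketch, not part of the repository: the "root" branch of cluster_assign
# above is, in essence, a chunked Lloyd (k-means) iteration that assigns every point to
# its nearest center and then averages the points of each cluster. The standard-library
# function below shows one such iteration on plain Python lists. The name kmeans_step is
# hypothetical, and unlike the repository code (which divides by counts + epsilon) this
# sketch keeps the previous center of an empty cluster, which is a simplifying assumption.

```python
def kmeans_step(points, centers):
    """One Lloyd iteration: assign to the nearest center, then average per cluster."""
    dim = len(centers[0])
    sums = [[0.0] * dim for _ in centers]
    counts = [0] * len(centers)
    labels = []
    for p in points:
        # Squared Euclidean distance to every center; ties go to the lowest index.
        d2 = [sum((a - b) ** 2 for a, b in zip(p, c)) for c in centers]
        k = d2.index(min(d2))
        labels.append(k)
        counts[k] += 1
        for j in range(dim):
            sums[k][j] += p[j]
    new_centers = [
        [s / counts[k] for s in sums[k]] if counts[k] else list(centers[k])
        for k in range(len(centers))
    ]
    return labels, new_centers
```

# Usage: calling kmeans_step repeatedly, feeding new_centers back in, mirrors the loop over
# num_kmeans_iters; the final labels play the role of nn_index.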
================================================
FILE: scripts/compute_lerf_iou.py
================================================
import os
import numpy as np
from PIL import Image
from argparse import ArgumentParser
def load_image_as_binary(image_path, is_png=False, threshold=10):
image = Image.open(image_path)
if is_png:
image = image.convert('L')
image_array = np.array(image)
binary_image = (image_array > threshold).astype(int)
return binary_image
def calculate_iou(mask1, mask2):
intersection = np.logical_and(mask1, mask2).sum()
union = np.logical_or(mask1, mask2).sum()
if union == 0:
return 0
return intersection / union
def evalute(gt_base, pred_base, scene_name):
scene_gt_frames = {
"waldo_kitchen": ["frame_00053", "frame_00066", "frame_00089", "frame_00140", "frame_00154"],
"ramen": ["frame_00006", "frame_00024", "frame_00060", "frame_00065", "frame_00081", "frame_00119", "frame_00128"],
"figurines": ["frame_00041", "frame_00105", "frame_00152", "frame_00195"],
"teatime": ["frame_00002", "frame_00025", "frame_00043", "frame_00107", "frame_00129", "frame_00140"]
}
frame_names = scene_gt_frames[scene_name]
ious = []
for frame in frame_names:
print("frame:", frame)
gt_folder = os.path.join(gt_base, frame)
file_names = [f for f in os.listdir(gt_folder) if f.endswith('.jpg')]
for file_name in file_names:
base_name = os.path.splitext(file_name)[0]
gt_obj_path = os.path.join(gt_folder, file_name)
pred_obj_path = os.path.join(pred_base, frame + "_" + base_name + '.png')
if not os.path.exists(pred_obj_path):
print(f"Missing pred file for {file_name}, skipping...")
print(f"IoU for {file_name}: 0")
ious.append(0.0)
continue
mask_gt = load_image_as_binary(gt_obj_path)
mask_pred = load_image_as_binary(pred_obj_path, is_png=True)
iou = calculate_iou(mask_gt, mask_pred)
ious.append(iou)
print(f"IoU for {file_name} and {base_name + '.png'}: {iou:.4f}")
# Acc.
total_count = len(ious)
count_iou_025 = (np.array(ious) > 0.25).sum()
count_iou_05 = (np.array(ious) > 0.5).sum()
# mIoU
average_iou = np.mean(ious)
print(f"Average IoU: {average_iou:.4f}")
print(f"Acc@0.25: {count_iou_025/total_count:.4f}")
print(f"Acc@0.5: {count_iou_05/total_count:.4f}")
if __name__ == "__main__":
parser = ArgumentParser("Compute LeRF IoU")
parser.add_argument("--scene_name", type=str, choices=["waldo_kitchen", "ramen", "figurines", "teatime"],
help="Specify the scene_name from: figurines, teatime, ramen, waldo_kitchen")
args = parser.parse_args()
if not args.scene_name:
parser.error("The --scene_name argument is required and must be one of: waldo_kitchen, ramen, figurines, teatime")
# TODO: change this to your own LeRF-OVS label directory (the scene folder follows --scene_name)
path_gt = f"/gdata/cold1/wuyanmin/OpenGaussian/data/lerf_ovs/label/{args.scene_name}/gt"
# renders_cluster_silhouette is the predicted mask
path_pred = "output/xxxxxxxx-x/text2obj/ours_70000/renders_cluster_silhouette"
evalute(path_gt, path_pred, args.scene_name)
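# A minimal sketch of the metrics reported by this script, assuming masks are
# flat sequences of 0/1 values rather than image arrays. The helper names
# `iou_from_flat_masks` and `summarize_ious` are illustrative and are not part
# of this repository.

```python
def iou_from_flat_masks(mask_a, mask_b):
    """IoU of two equally sized binary masks given as flat sequences of 0/1."""
    if len(mask_a) != len(mask_b):
        raise ValueError("masks must have the same number of pixels")
    intersection = sum(1 for a, b in zip(mask_a, mask_b) if a and b)
    union = sum(1 for a, b in zip(mask_a, mask_b) if a or b)
    # Same convention as calculate_iou: an empty union scores 0.
    return intersection / union if union else 0.0


def summarize_ious(ious, thresholds=(0.25, 0.5)):
    """Return the mean IoU and the fraction of IoUs strictly above each threshold."""
    if not ious:
        raise ValueError("no IoU values to summarize")
    mean_iou = sum(ious) / len(ious)
    acc = {t: sum(1 for v in ious if v > t) / len(ious) for t in thresholds}
    return mean_iou, acc
```

# A missing prediction contributes an IoU of 0.0 to the list, so it lowers both
# the mean IoU and the Acc@threshold values, as in the loop above.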
================================================
FILE: scripts/eval_scannet.py
================================================
import os
from plyfile import PlyData, PlyElement
import torch.nn.functional as F
import numpy as np
import torch
import json
nyu40_dict = {
0: "unlabeled", 1: "wall", 2: "floor", 3: "cabinet", 4: "bed", 5: "chair",
6: "sofa", 7: "table", 8: "door", 9: "window", 10: "bookshelf",
11: "picture", 12: "counter", 13: "blinds", 14: "desk", 15: "shelves",
16: "curtain", 17: "dresser", 18: "pillow", 19: "mirror", 20: "floormat",
21: "clothes", 22: "ceiling", 23: "books", 24: "refrigerator", 25: "television",
26: "paper", 27: "towel", 28: "showercurtain", 29: "box", 30: "whiteboard",
31: "person", 32: "nightstand", 33: "toilet", 34: "sink", 35: "lamp",
36: "bathtub", 37: "bag", 38: "otherstructure", 39: "otherfurniture", 40: "otherprop"
}
# ScanNet 19 classes (the 20-class benchmark set without "otherfurniture")
scannet19_dict = {
1: "wall", 2: "floor", 3: "cabinet", 4: "bed", 5: "chair",
6: "sofa", 7: "table", 8: "door", 9: "window", 10: "bookshelf",
11: "picture", 12: "counter", 14: "desk", 16: "curtain",
24: "refrigerator", 28: "shower curtain", 33: "toilet", 34: "sink",
36: "bathtub", # 39: "otherfurniture"
}
def sigmoid(x):
return 1 / (1 + np.exp(-x))
def write_ply(vertex_data, output_path):
vertices = []
for vertex in vertex_data:
r = (vertex['ins_feat_r'] + 1)/2 * 255
g = (vertex['ins_feat_g'] + 1)/2 * 255
b = (vertex['ins_feat_b'] + 1)/2 * 255
new_vertex = (vertex['x'], vertex['y'], vertex['z'], r, g, b)
vertices.append(new_vertex)
vertex_dtype = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]
new_vertex_data = np.array(vertices, dtype=vertex_dtype)
el = PlyElement.describe(new_vertex_data, 'vertex')
PlyData([el], text=True).write(output_path)
def read_labels_from_ply(file_path):
ply_data = PlyData.read(file_path)
vertex_data = ply_data['vertex'].data
# Extract the coordinates and labels of the points. The labels are from 1 to 40 for the NYU40 dataset, with 0 being invalid.
points = np.vstack([vertex_data['x'], vertex_data['y'], vertex_data['z']]).T
labels = vertex_data['label']
return points, labels
def calculate_metrics(gt, pred, total_classes):
gt = gt.cpu()
pred = pred.cpu()
pred[gt == 0] = 0
ious = torch.zeros(total_classes)
intersection = torch.zeros(total_classes)
union = torch.zeros(total_classes)
correct = torch.zeros(total_classes)
total = torch.zeros(total_classes)
for cls in range(1, total_classes):
intersection[cls] = torch.sum((gt == cls) & (pred == cls)).item()
union[cls] = torch.sum((gt == cls) | (pred == cls)).item()
correct[cls] = torch.sum((gt == cls) & (pred == cls)).item()
total[cls] = torch.sum(gt == cls).item()
valid_union = union != 0
ious[valid_union] = intersection[valid_union] / union[valid_union]
# Only consider the categories that exist in the current scene
gt_classes = torch.unique(gt)
valid_gt_classes = gt_classes[gt_classes != 0] # ignore 0
# miou
mean_iou = ious[valid_gt_classes].mean().item()
# acc
valid_mask = gt != 0
correct_predictions = torch.sum((gt == pred) & valid_mask).item()
total_valid_points = torch.sum(valid_mask).item()
accuracy = correct_predictions / total_valid_points if total_valid_points > 0 else float('nan')
class_accuracy = correct / total
# mAcc.
mean_class_accuracy = class_accuracy[valid_gt_classes].mean().item()
return ious, mean_iou, accuracy, mean_class_accuracy
if __name__ == "__main__":
scene_list = [ 'scene0000_00', 'scene0062_00', 'scene0070_00', 'scene0097_00', 'scene0140_00',
'scene0200_00', 'scene0347_00', 'scene0400_00', 'scene0590_00', 'scene0645_00']
iteration = 90000
for scan_name in scene_list:
# (1) Ground-truth label .ply; change this path to your own data location.
gt_file_path = f"/gdata/cold1/wuyanmin/OpenGaussian/data/scannet_2d_3types/{scan_name}/{scan_name}_vh_clean_2.labels.ply"
points, labels = read_labels_from_ply(gt_file_path)
# (2) note: 19 & 15 & 10 classes
# Given the category ID that needs to be queried (relative to the original NYU40), obtain the corresponding category name.
target_id = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36] # 19
# target_id = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 14, 16, 33, 34] # 15
# target_id = [1,2,4,5,6,7,8,9,10,33] # 10
target_dict = {key: nyu40_dict[key] for key in target_id}
target_names = list(target_dict.values())
# (3) update gt label
# Remap the ground-truth labels to contiguous ids. With 19 categories, updated_labels holds 1-19 for the target classes and 0 for everything else.
target_id_mapping = {value: index + 1 for index, value in enumerate(target_id)}
updated_labels = np.zeros_like(labels)
for original_value, new_value in target_id_mapping.items():
updated_labels[labels == original_value] = new_value
updated_gt_labels = torch.from_numpy(updated_labels.astype(np.int64)).cuda()
# (4) load gaussian ply file
model_path = f"output/{scan_name}/"
ply_path = os.path.join(model_path, f"point_cloud/iteration_{iteration}/point_cloud.ply")
ply_data = PlyData.read(ply_path)
vertex_data = ply_data['vertex'].data
# NOTE Filter out points based on their opacity values.
ignored_pts = sigmoid(vertex_data["opacity"]) < 0.1
updated_gt_labels[ignored_pts] = 0
# (5) load cluster language file
mapping_file = os.path.join(model_path, "cluster_lang.npz")
# load the saved codebook(leaf id) and instance-level language feature
# 'leaf_feat', 'leaf_score', 'occu_count', 'leaf_ind'
saved_data = np.load(mapping_file)
leaf_lang_feat = torch.from_numpy(saved_data["leaf_feat.npy"]).cuda() # [num_leaf=k1*k2, 512]
leaf_score = torch.from_numpy(saved_data["leaf_score.npy"]).cuda() # [num_leaf=k1*k2]
leaf_occu_count = torch.from_numpy(saved_data["occu_count.npy"]).cuda() # [num_leaf=k1*k2]
leaf_ind = torch.from_numpy(saved_data["leaf_ind.npy"]).cuda() # [num_pts]
leaf_lang_feat[leaf_occu_count < 2] *= 0.0
leaf_ind = leaf_ind.clamp(max=319) # 64*5=320
# (6) load query text feat.
with open('assets/text_features.json', 'r') as f:
data_loaded = json.load(f)
all_texts = list(data_loaded.keys())
text_features = torch.from_numpy(np.array(list(data_loaded.values()))).to(torch.float32) # [num_text, 512]
query_text_feats = torch.zeros(len(target_names), 512).cuda()
for i, text in enumerate(target_names):
feat = text_features[all_texts.index(text)].unsqueeze(0)
query_text_feats[i] = feat
# (7) Calculate the cosine similarity and return the ID of the category with the highest value.
query_text_feats = F.normalize(query_text_feats, dim=1, p=2)
leaf_lang_feat = F.normalize(leaf_lang_feat, dim=1, p=2)
cosine_similarity = torch.matmul(query_text_feats, leaf_lang_feat.transpose(0, 1))
# cosine_similarity = torch.mm(query_text_feats, leaf_lang_feat.t()) # [cls_num, cluster_num]
max_id = torch.argmax(cosine_similarity, dim=0) # [cluster_num]
pred_pts_cls_id = max_id[leaf_ind] + 1 # [num_pts]
ious, mean_iou, accuracy, mean_acc = calculate_metrics(updated_gt_labels, pred_pts_cls_id, total_classes=len(target_names)+1)
print(f"Scene: {scan_name}, mIoU: {mean_iou:.4f}, mAcc.: {mean_acc:.4f}")
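# A minimal sketch of steps (3) and the mIoU part of calculate_metrics, written
# with plain Python lists instead of tensors. `remap_labels` and
# `mean_iou_over_present_classes` are illustrative names, not functions from
# this repository; the sketch assumes label 0 means "ignore".

```python
def remap_labels(labels, target_ids):
    """Map original label ids to 1..len(target_ids); any other id becomes 0."""
    mapping = {orig: new for new, orig in enumerate(target_ids, start=1)}
    return [mapping.get(label, 0) for label in labels]


def mean_iou_over_present_classes(gt, pred):
    """mIoU over the classes that occur in gt; points with gt == 0 are ignored."""
    # Dropping the ignored points has the same effect as setting pred to 0 there.
    pairs = [(g, p) for g, p in zip(gt, pred) if g != 0]
    classes = sorted({g for g, _ in pairs})
    if not classes:
        return float("nan")
    ious = []
    for cls in classes:
        inter = sum(1 for g, p in pairs if g == cls and p == cls)
        union = sum(1 for g, p in pairs if g == cls or p == cls)
        ious.append(inter / union)  # union > 0 because cls occurs in gt
    return sum(ious) / len(ious)
```

# Averaging only over classes present in the ground truth keeps absent classes
# from pulling the per-scene mIoU toward zero.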
================================================
FILE: scripts/render_by_click.py
================================================
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact george.drettakis@inria.fr
#
import torch
import torch.nn.functional as F
from scene import Scene
import os
from tqdm import tqdm
from os import makedirs
from gaussian_renderer import render
import torchvision
from utils.general_utils import safe_state
from argparse import ArgumentParser
from arguments import ModelParams, PipelineParams, get_combined_args
from gaussian_renderer import GaussianModel
import numpy as np
from PIL import Image
import json
from utils.opengs_utlis import mask_feature_mean, get_SAM_mask_and_feat, load_code_book
import pytorch3d.ops
np.random.seed(42)
colors_defined = np.random.randint(100, 256, size=(300, 3))
colors_defined[0] = np.array([0, 0, 0])
colors_defined = torch.from_numpy(colors_defined)
def get_pixel_values(image_path, position, radius=10):
with Image.open(image_path) as img:
img = img.convert('RGB')
width, height = img.size
left = max(position[0] - radius, 0)
right = min(position[0] + radius + 1, width)
top = max(position[1] - radius, 0)
bottom = min(position[1] + radius + 1, height)
pixels = []
for x in range(left, right):
for y in range(top, bottom):
pixels.append(img.getpixel((x, y)))
pixels_array = np.array(pixels)
mean_pixel = pixels_array.mean(axis=0)
return tuple(mean_pixel)
def compute_click_values(model_path, image_name, pix_xy, radius=5):
def compute_level_click_val(iter, model_path, image_name, pix_xy, radius):
img_path1 = f"{model_path}/train/ours_{iter}/renders_ins_feat1/{image_name}_1.png" # TODO
img_path2 = f"{model_path}/train/ours_{iter}/renders_ins_feat2/{image_name}_2.png" # TODO
val1 = get_pixel_values(img_path1, pix_xy, radius)
val2 = get_pixel_values(img_path2, pix_xy, radius)
click_val = (torch.tensor(list(val1) + list(val2)) / 255) * 2 - 1
return click_val
level1_click_val = compute_level_click_val(50000, model_path, image_name, pix_xy, radius) # TODO
level2_click_val = compute_level_click_val(70000, model_path, image_name, pix_xy, radius) # TODO
return level1_click_val, level2_click_val
def render_set(model_path, name, iteration, views, gaussians, pipeline, background):
render_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders")
gts_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt")
render_ins_feat_path = os.path.join(model_path, name, "ours_{}".format(iteration), "renders_ins_feat")
gt_sam_mask_path = os.path.join(model_path, name, "ours_{}".format(iteration), "gt_sam_mask")
pseudo_ins_feat_path = os.path.join(model_path, name, "ours_{}".format(iteration), "pseudo_ins_feat")
makedirs(render_path, exist_ok=True)
makedirs(gts_path, exist_ok=True)
makedirs(render_ins_feat_path, exist_ok=True)
makedirs(gt_sam_mask_path, exist_ok=True)
makedirs(pseudo_ins_feat_path, exist_ok=True)
# load codebook
root_code_book, root_cluster_indices = load_code_book(os.path.join(model_path, "point_cloud", \
f'iteration_{iteration}', "root_code_book"))
leaf_code_book, leaf_cluster_indices = load_code_book(os.path.join(model_path, "point_cloud", \
f'iteration_{iteration}', "leaf_code_book"))
root_cluster_indices = torch.from_numpy(root_cluster_indices).cuda()
leaf_cluster_indices = torch.from_numpy(leaf_cluster_indices).cuda()
# counts = torch.bincount(torch.from_numpy(cluster_indices), minlength=64)
# load the saved codebook(leaf id) and instance-level language feature
# 'leaf_feat', 'leaf_score', 'occu_count', 'leaf_ind' leaf_figurines_cluster_lang
mapping_file = os.path.join(model_path, "cluster_lang.npz")
saved_data = np.load(mapping_file)
leaf_lang_feat = torch.from_numpy(saved_data["leaf_feat.npy"]).cuda() # [num_leaf=640, 512] Language feature of each instance
leaf_score = torch.from_numpy(saved_data["leaf_score.npy"]).cuda() # [num_leaf=640] Score of each instance
leaf_occu_count = torch.from_numpy(saved_data["occu_count.npy"]).cuda() # [num_leaf=640] Number of occurrences of each instance
leaf_ind = torch.from_numpy(saved_data["leaf_ind.npy"]).cuda() # [num_pts] Instance ID corresponding to each point
leaf_lang_feat[leaf_occu_count < 5] *= 0.0 # ignore
leaf_cluster_indices = leaf_ind
image_name = "frame_00002" # TODO
# # object_name = "apple"
# pix_xy = (450, 217) # bag of cookies
# pix_xy = (344, 350) # apple
# # teatime image_name = "frame_00002"
# object_names = ["bear nose", "stuffed bear", "sheep", "bag of cookies", \
# "plate", "three cookies", "tea in a glass", "apple", \
# "coffee mug", "coffee", "paper napkin"]
# pix_xy_list = [ (740, 80), (800, 160), (80, 240), (450, 200),
# (468, 288), (438, 273), (309, 308), (343, 361),
# (578, 274), (571, 260), (565, 380)]
# figurines image_name = "frame_00002"
# TODO
object_names = ["rubber duck with buoy", "porcelain hand", "miffy", "toy elephant", "toy cat statue", \
"jake", "Play-Doh bucket", "rubber duck with hat", "rubics cube", "waldo", \
"twizzlers", "red toy chair", "green toy chair", "pink ice cream", "spatula", \
"pikachu", "green apple", "rabbit", "old camera", "pumpkin", \
"tesla door handle"]
# TODO
pix_xy_list = [ (103, 378), (552, 390), (896, 342), (720, 257), (254, 297),
(451, 197), (626, 256), (760, 166), (781, 243), (896, 136),
(927, 241), (688, 148), (538, 160), (565, 238), (575, 257),
(377, 156), (156, 244), (21, 237), (283, 152), (330, 200),
(514, 200)]
# # ramen image_name = "frame_00002"
# object_names = ["clouth", "sake cup", "chopsticks", "spoon", "plate", \
# "bowl", "egg", "nori", "glass of water", "napkin"]
# pix_xy_list = [(345, 38), (276, 424), (361, 370), (419, 285), (688, 412),
# (489, 119), (694, 187), (810, 154), (939, 289), (428, 462)]
# # waldo_kitchen image_name = "frame_00001"
# object_names = ["knife", "pour-over vessel", "glass pot1", "glass pot2", "toaster", \
# "hot water pot", "metal can", "cabinet", "ottolenghi", "waldo"]
# pix_xy_list = [(439, 76), (410, 297), (306, 127), (349, 182), (261, 256),
# (201, 262), (161, 267), (80, 34), (17, 141), (76, 169)]
for o_i, object in enumerate(object_names):
pix_xy = pix_xy_list[o_i]
root_click_val, leaf_click_val = compute_click_values(model_path, image_name, pix_xy)
# Compute the nearest clusters with respect to the two-level codebook
distances_root = torch.norm(root_click_val - root_code_book["ins_feat"][:, :-3].cpu(), dim=1)
distances_leaf = torch.norm(leaf_click_val - leaf_code_book["ins_feat"][:-1, :].cpu(), dim=1)
distances_leaf[leaf_code_book["ins_feat"][:-1].sum(-1) == 0] = 999 # Assign a large distance to leaf nodes that remain unassigned (all-zero features)
# Retrieve the candidate child nodes linked to each selected root node
min_index_root = torch.argmin(distances_root).item()
leaf_num = (leaf_code_book["ins_feat"].shape[0] - 1) / root_code_book["ins_feat"].shape[0]
start_id = int(min_index_root*leaf_num)
end_id = int((min_index_root + 1)*leaf_num)
distances_leaf_sub = distances_leaf[start_id: end_id] # [10]
# # (1) Choose several child nodes that fulfill the requirements
# click_leaf_indices = torch.nonzero(distances_leaf_sub < 0.9).squeeze() + start_id
# if (click_leaf_indices.dim() == 0) and click_leaf_indices.numel() != 0:
# click_leaf_indices = click_leaf_indices.unsqueeze(0)
# elif click_leaf_indices.numel() == 0:
# click_leaf_indices = torch.argmin(distances_leaf_sub).unsqueeze(0)
# (2) identify the root-level codebook and then pick the closest leaf node inside it (preferred)
click_leaf_indices = torch.argmin(distances_leaf_sub).unsqueeze(0) + start_id
# (3) directly select the child node with the minimum distance (less precise)
# click_leaf_indices = torch.argmin(distances_leaf).unsqueeze(0)
# # (4) you can also directly specify a particular child node if needed
# click_leaf_indices = torch.tensor([60, 66]) # 64 pikachu, 60, 66 toy elephant, 65 jake, 633 green apple, 639 duck
# Get the mask linked to the child node
pre_pts_mask = (leaf_cluster_indices.unsqueeze(1) == click_leaf_indices.cuda()).any(dim=1)
# Post-process: drop outlier points whose mean k-NN distance is unusually large
post_process = True
max_time = 5
if post_process and max_time > 0:
nearest_k_distance = pytorch3d.ops.knn_points(
gaussians._xyz[pre_pts_mask].unsqueeze(0),
gaussians._xyz[pre_pts_mask].unsqueeze(0),
K=int(pre_pts_mask.sum()**0.5) * 2,
).dists
mean_nearest_k_distance, std_nearest_k_distance = nearest_k_distance.mean(), nearest_k_distance.std()
# print(std_nearest_k_distance, "std_nearest_k_distance")
# mask = nearest_k_distance.mean(dim = -1) < mean_nearest_k_distance + std_nearest_k_distance
mask = nearest_k_distance.mean(dim = -1) < mean_nearest_k_distance + 0.1 * std_nearest_k_distance
# mask = nearest_k_distance.mean(dim = -1) < 2 * mean_nearest_k_distance
mask = mask.squeeze()
if pre_pts_mask is not None:
pre_pts_mask[pre_pts_mask != 0] = mask
max_time -= 1
# out_dir = "ca9c2998-e"
# splits = ["train", "train", "train", "train", "test"]
# frame_name_list = ["frame_00053", "frame_00066", "frame_00140", "frame_00154", "frame_00089"]
# for f_i, frame_name in enumerate(frame_name_list):
# base_path = f"/mnt/disk1/codes/wuyanmin/code/OpenGaussian/output/{out_dir}/{splits[f_i]}/ours_70000/renders_cluster_silhouette"
# target_path = f"/mnt/disk1/codes/wuyanmin/code/OpenGaussian/output/{out_dir}/{splits[f_i]}/ours_70000/result/{frame_name}"
# makedirs(target_path, exist_ok=True)
# for _, text in enumerate(waldo_kitchen_texts):
# pos_feat = text_features[query_texts.index(text)].unsqueeze(0)
# similarity_pos = F.cosine_similarity(pos_feat, leaf_lang_feat.cpu()) # [640]
# top_values, top_indices = torch.topk(similarity_pos, 10) # [num_mask]
# print("text: {} | cluster id: {}".format(text, top_indices[0]))
# ori_img_name = base_path + f"/{frame_name}_cluster_{top_indices[0].item()}.png"
# new_name = target_path + f"/{text}.png"
# if not os.path.exists(ori_img_name):
# top = 10
# for i in range(top):
# ori_img_name = target_path + f"/{frame_name}_cluster_{top_indices[i].item()}.png"
# if os.path.exists(ori_img_name):
# break
# if not os.path.exists(ori_img_name):
# print(f"No file found at {ori_img_name}. Operation skipped.")
# continue
# import shutil
# shutil.copy2(ori_img_name, new_name)
# render
for idx, view in enumerate(tqdm(views, desc="Rendering progress")):
# render_pkg = render(view, gaussians, pipeline, background, iteration, rescale=False)
# # figurines
# if view.image_name not in ["frame_00041", "frame_00105", "frame_00152", "frame_00195"]:
# continue
# # teatime
# if view.image_name not in ["frame_00002", "frame_00025", "frame_00043", "frame_00107", "frame_00129", "frame_00140"]:
# continue
# # ramen
# if view.image_name not in ["frame_00006", "frame_00024", "frame_00060", "frame_00065", "frame_00081", "frame_00119", "frame_00128"]:
# continue
# # waldo_kitchen
# if view.image_name not in ["frame_00053", "frame_00066", "frame_00089", "frame_00140", "frame_00154"]:
# continue
# NOTE render
render_pkg = render(view, gaussians, pipeline, background, iteration,
rescale=False, # whether to re-scale the Gaussian scale
# cluster_idx=leaf_cluster_indices, # root id
leaf_cluster_idx=leaf_cluster_indices, # leaf id
selected_leaf_id=click_leaf_indices.cuda(), # selected leaf id
render_feat_map=True,
render_cluster=False,
better_vis=True,
pre_mask=pre_pts_mask,
seg_rgb=True)
rendering = render_pkg["render"]
rendered_cluster_imgs = render_pkg["leaf_clusters_imgs"]
occured_leaf_id = render_pkg["occured_leaf_id"]
rendered_leaf_cluster_silhouettes = render_pkg["leaf_cluster_silhouettes"]
# save Rendered RGB
torchvision.utils.save_image(rendering, os.path.join(render_path, view.image_name + ".png"))
render_cluster_path = os.path.join(model_path, name, "ours_{}".format(iteration), "click_cluster")
render_cluster_silhouette_path = os.path.join(model_path, name, "ours_{}".format(iteration), "click_cluster_mask")
makedirs(render_cluster_path, exist_ok=True)
makedirs(render_cluster_silhouette_path, exist_ok=True)
for i, img in enumerate(rendered_cluster_imgs):
torchvision.utils.save_image(img[:3,:,:], os.path.join(render_cluster_path, \
view.image_name + f"_{object}_cluster_{occured_leaf_id[i]}.png"))
# save mask
cluster_silhouette = rendered_leaf_cluster_silhouettes[i] > 0.8
torchvision.utils.save_image(cluster_silhouette.to(torch.float32), os.path.join(render_cluster_silhouette_path, \
view.image_name + f"_{object}_cluster_{occured_leaf_id[i]}.png"))
def render_sets(dataset : ModelParams, iteration : int, pipeline : PipelineParams, skip_train : bool, skip_test : bool):
with torch.no_grad():
gaussians = GaussianModel(dataset.sh_degree)
scene = Scene(dataset, gaussians, load_iteration=iteration, shuffle=False)
bg_color = [1,1,1] if dataset.white_background else [0, 0, 0]
background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
if not skip_train:
render_set(dataset.model_path, "train", scene.loaded_iter, scene.getTrainCameras(), gaussians, pipeline, background)
if not skip_test:
render_set(dataset.model_path, "test", scene.loaded_iter, scene.getTestCameras(), gaussians, pipeline, background)
if __name__ == "__main__":
# Set up command line argument parser
parser = ArgumentParser(description="Testing script parameters")
model = ModelParams(parser, sentinel=True)
pipeline = PipelineParams(parser)
parser.add_argument("--iteration", default=-1, type=int)
parser.add_argument("--skip_train", action="store_true")
parser.add_argument("--skip_test", action="store_true")
parser.add_argument("--quiet", action="store_true")
args = get_combined_args(parser)
print("Rendering " + args.model_path)
# Initialize system state (RNG)
safe_state(args.quiet)
render_sets(model.extract(args), args.iteration, pipeline.extract(args), args.skip_train, args.skip_test)
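# A minimal sketch of the two-level lookup used in render_set (option 2 above):
# pick the nearest root, then the nearest leaf inside that root's block. It
# assumes every root owns an equal, contiguous block of leaf ids.
# `leaf_range_for_root` and `nearest_leaf_under_root` are illustrative names,
# not functions from this repository.

```python
def leaf_range_for_root(root_index, num_roots, num_leaves_total):
    """Half-open [start, end) range of leaf ids that belong to root_index."""
    if num_leaves_total % num_roots != 0:
        raise ValueError("leaves must divide evenly among roots")
    if not 0 <= root_index < num_roots:
        raise ValueError("root_index out of range")
    leaves_per_root = num_leaves_total // num_roots
    start = root_index * leaves_per_root
    return start, start + leaves_per_root


def nearest_leaf_under_root(distances_leaf, root_index, num_roots):
    """Global id of the closest leaf among the leaves owned by root_index."""
    start, end = leaf_range_for_root(root_index, num_roots, len(distances_leaf))
    sub = distances_leaf[start:end]
    # argmin over the sub-range, shifted back to a global leaf id
    return start + min(range(len(sub)), key=sub.__getitem__)
```

# Restricting the search to one root's block avoids matching a leaf that is
# close in feature space but belongs to a spatially distant root cluster.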
================================================
FILE: scripts/scannet2blender.py
================================================
import os
import json
import numpy as np
def load_transform_matrix(file_path):
"""
Load the transform matrix from a text file.
"""
with open(file_path, 'r') as file:
matrix = [list(map(float, line.strip().split())) for line in file]
return matrix
def process_directory(directory_path):
"""
Process each directory and create a JSON file with the transform matrices.
"""
color_dir = os.path.join(directory_path, "color") # TODO
pose_dir = os.path.join(directory_path, "pose") # TODO
intrinsic_dir = os.path.join(directory_path, "intrinsic") # TODO
# Check if both directories exist
if not os.path.isdir(color_dir) or not os.path.isdir(pose_dir):
return
# scannet
transform_data = {
'w': 1296,
'h': 968,
'fl_x': 1170.187988,
'fl_y': 1170.187988,
'cx': 647.75,
'cy': 483.75,
# 'aabb_scale': 2,
'frames': [],
}
# # scannet
# transform_data = {
# 'w': 640,
# 'h': 512,
# 'fl_x': 534.56,
# 'fl_y': 534.80,
# 'cx': 314.27,
# 'cy': 259.96,
# # 'aabb_scale': 2,
# 'frames': [],
# }
# Collect all image names and sort them
img_names = [img_name for img_name in os.listdir(color_dir) if img_name.endswith(".jpg")]
# img_names.sort(key=lambda x: int(os.path.splitext(x)[0])) # Sort by image number
img_names.sort(key=lambda x: os.path.splitext(x)[0]) # Sort by file name (lexicographic order; use the int() key above for numeric order)
# Iterate over the color images
for img_name in img_names:
if img_name.endswith(".jpg"):
# Construct the corresponding pose file path
pose_file = os.path.splitext(img_name)[0] + ".txt"
pose_file_path = os.path.join(pose_dir, pose_file)
intrinsic_file = os.path.splitext(img_name)[0] + ".txt"
intrinsic_file_path = os.path.join(intrinsic_dir, intrinsic_file)
# Check if the pose file exists
if os.path.isfile(pose_file_path):
transform_matrix = load_transform_matrix(pose_file_path)
# Convert the camera convention from COLMAP/OpenCV to Blender/OpenGL by flipping the Y and Z axes
transform_matrix = np.array(transform_matrix)
transform_matrix[:3, 1:3] *= -1
transform_matrix = transform_matrix.tolist()
frame_data = {
"file_path": os.path.join("color", os.path.splitext(img_name)[0]),
"transform_matrix": transform_matrix
}
if os.path.isfile(intrinsic_file_path):
intrinsic_info = load_transform_matrix(intrinsic_file_path)
frame_data.update({
'fl_x': intrinsic_info[0][0],
'fl_y': intrinsic_info[1][1],
'cx': intrinsic_info[0][2],
'cy': intrinsic_info[1][2]
})
transform_data["frames"].append(frame_data)
return transform_data
# Directory containing the scenes
base_directory = 'PATH_TO_YOUR_SCANNET' # TODO
# Process each scene directory and create JSON files
for scene_dir in os.listdir(base_directory):
# if scene_dir != "scene0000_00":
# continue
scene_path = os.path.join(base_directory, scene_dir)
if os.path.isdir(scene_path):
# Process the directory and get the transform data
transform_data = process_directory(scene_path)
print(scene_path)
# Create the JSON file
if transform_data:
json_file_path = os.path.join(scene_path, "transforms_train.json")
with open(json_file_path, 'w') as json_file:
json.dump(transform_data, json_file, indent=4)
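# A minimal sketch of the axis flip applied to each pose in process_directory,
# written without NumPy. It assumes the pose is a 4x4 camera-to-world matrix
# stored as nested lists; `flip_yz_axes` is an illustrative name, not a
# function from this repository.

```python
def flip_yz_axes(c2w):
    """Return a copy of a 4x4 matrix with columns 1 and 2 of the top three rows negated."""
    flipped = [list(row) for row in c2w]
    for r in range(3):
        flipped[r][1] = -flipped[r][1]  # camera Y axis
        flipped[r][2] = -flipped[r][2]  # camera Z axis
    return flipped
```

# The translation column and the bottom row are left untouched, and applying
# the flip twice returns the original matrix.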
================================================
FILE: scripts/train_lerf.sh
================================================
#!/bin/bash
# chmod +x scripts/train_lerf.sh
# ./scripts/train_lerf.sh
# !!! Please check the dataset path specified by -s.
# Total training steps: 70k
# 3dgs pre-train: 0~30k
# stage1: 30~40k
# stage2 (coarse-level): 40~50k
# stage2 (fine-level): 50k~70k
# ###############################################
# # (1/4) figurines
# # Training takes approximately 70 minutes on a 24G 4090 GPU.
# # Object selection works well on this scene (recommended); point cloud visualization is poor (not recommended).
# # k1=64, k2=10
# # --pos_weight 0.5
# # --save_memory: Saves memory, but will reduce training speed. If your GPU memory > 24GB, you can omit this flag
# ###############################################
scan="figurines"
gpu_num=3 # change
echo "Training for ${scan} ....."
CUDA_VISIBLE_DEVICES=$gpu_num python train.py --port 601$gpu_num \
-s /gdata/cold1/wuyanmin/OpenGaussian/data/lerf_ovs/${scan} \
--iterations 70_000 \
--start_ins_feat_iter 30_000 \
--start_root_cb_iter 40_000 \
--start_leaf_cb_iter 50_000 \
--sam_level 3 \
--root_node_num 64 \
--leaf_node_num 10 \
--pos_weight 0.5 \
--save_memory \
--test_iterations 30000 \
--eval
# ###############################################
# # (2/4) waldo_kitchen
# # Training takes approximately 60 minutes on a 24G 4090 GPU.
# # Good point cloud visualization result (recommended), suboptimal object selection effect.
# # k1=64, k2=10
# # --pos_weight 0.5
# # No need to set save_memory, 24G is sufficient.
# ###############################################
scan="waldo_kitchen"
gpu_num=3 # change
echo "Training for ${scan} ....."
CUDA_VISIBLE_DEVICES=$gpu_num python train.py --port 601$gpu_num \
-s /gdata/cold1/wuyanmin/OpenGaussian/data/lerf_ovs/${scan} \
--iterations 70_000 \
--start_ins_feat_iter 30_000 \
--start_root_cb_iter 40_000 \
--start_leaf_cb_iter 50_000 \
--sam_level 3 \
--root_node_num 64 \
--leaf_node_num 10 \
--pos_weight 0.5 \
--test_iterations 30000 \
--eval
# ###############################################
# # (3/4) teatime
# # Training takes approximately 80 minutes on a 24G 4090 GPU.
# # k1=32, k2=10
# # --pos_weight 0.1
# # --save_memory: Saves memory, but will reduce training speed. If your GPU memory > 24GB, you can omit this flag
# ###############################################
scan="teatime"
gpu_num=3 # change
echo "Training for ${scan} ....."
CUDA_VISIBLE_DEVICES=$gpu_num python train.py --port 601$gpu_num \
-s /gdata/cold1/wuyanmin/OpenGaussian/data/lerf_ovs/${scan} \
--iterations 70_000 \
--start_ins_feat_iter 30_000 \
--start_root_cb_iter 40_000 \
--start_leaf_cb_iter 50_000 \
--sam_level 3 \
--root_node_num 32 \
--leaf_node_num 10 \
--pos_weight 0.1 \
--save_memory \
--test_iterations 30000 \
--eval
# ###############################################
# # (4/4) ramen
# # Training takes approximately 40 minutes on a 24G 4090 GPU.
# # Object selection is the worst and is unstable on this scene (not recommended).
# # k1=64, k2=10
# # --pos_weight 0.5
# # --loss_weight 0.01: the weight of intra-mask smooth loss. 0.1 is used for the other scenes.
# # No need to set save_memory, 24G is sufficient.
# ###############################################
scan="ramen"
gpu_num=3
echo "Training for ${scan} ....."
CUDA_VISIBLE_DEVICES=$gpu_num python train.py --port 601$gpu_num \
-s /gdata/cold1/wuyanmin/OpenGaussian/data/lerf_ovs/${scan} \
--iterations 70_000 \
--start_ins_feat_iter 30_000 \
--start_root_cb_iter 40_000 \
--start_leaf_cb_iter 50_000 \
--sam_level 3 \
--root_node_num 64 \
--leaf_node_num 10 \
--pos_weight 0.5 \
--loss_weight 0.01 \
--test_iterations 30000 \
--eval
================================================
FILE: scripts/train_scannet.sh
================================================
#!/bin/bash
# chmod +x scripts/train_scannet.sh
# ./scripts/train_scannet.sh
# ============== [Notice] ==============
# 1. All 10 ScanNet scenes share the same hyperparameters.
# 2. Training one scene takes about 20 minutes on a 24G 4090 GPU.
# 3. Please check the dataset path specified by -s.
# ============== [Hyperparameter explanation] ==============
# Total training steps: 90k
# 3dgs pre-train: 0~30k
# stage1: 30~50k
# stage2 (coarse-level): 50~70k
# stage2 (fine-level): 70k~90k
# k1=64, k2=5
# frozen_init_pts: The point clouds provided by the ScanNet dataset are frozen, without using the densification scheme of 3DGS.
# -r 2 : We use half-resolution data for training.
# ============== [10 scenes] ==============
scan_list=("scene0000_00" "scene0062_00" "scene0070_00" "scene0097_00" "scene0140_00" \
"scene0200_00" "scene0347_00" "scene0400_00" "scene0590_00" "scene0645_00")
gpu_num=3 # change!
for scan in "${scan_list[@]}"; do
echo "Training for ${scan} ....."
CUDA_VISIBLE_DEVICES=$gpu_num python train.py --port 601$gpu_num \
-s /gdata/cold1/wuyanmin/OpenGaussian/data/onedrive/scannet/${scan} \
-r 2 \
--frozen_init_pts \
--iterations 90_000 \
--start_ins_feat_iter 30_000 \
--start_root_cb_iter 50_000 \
--start_leaf_cb_iter 70_000 \
--sam_level 0 \
--root_node_num 64 \
--leaf_node_num 5 \
--pos_weight 1.0 \
--test_iterations 30000 \
--eval
done
================================================
FILE: scripts/vis_opengs_pts_feat.py
================================================
import numpy as np
from plyfile import PlyData
import open3d as o3d
def sigmoid(x):
"""Sigmoid function."""
return 1 / (1 + np.exp(-x))
def visualize_ply(ply_path):
# Load the PLY file
ply_data = PlyData.read(ply_path)
vertex_data = ply_data['vertex'].data
# Extract the point cloud attributes
points = np.array([vertex_data['x'], vertex_data['y'], vertex_data['z']]).T
colors = np.array([vertex_data['red'], vertex_data['green'], vertex_data['blue']]).T / 255.0
opacity = vertex_data['opacity']
# Apply the opacity filter
sigmoid_opacity = sigmoid(opacity)
filtered_indices = sigmoid_opacity >= 0.1
filtered_points = points[filtered_indices]
filtered_colors = colors[filtered_indices]
# Create an Open3D PointCloud object
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(filtered_points)
pcd.colors = o3d.utility.Vector3dVector(filtered_colors)
# Visualize the point cloud
o3d.visualization.draw_geometries([pcd])
if __name__ == "__main__":
# Replace with the path to your PLY file
ply_path = "output/xxxxxxxx-x/point_cloud/iteration_x0000/point_cloud.ply"
visualize_ply(ply_path)
================================================
FILE: train.py
================================================
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact george.drettakis@inria.fr
#
import os
import torch
import torch.nn.functional as F
from random import randint
from utils.loss_utils import l1_loss, ssim, l2_loss
from gaussian_renderer import render, network_gui
import sys
from scene import Scene, GaussianModel
from utils.general_utils import safe_state
import uuid
from tqdm import tqdm
from utils.image_utils import psnr
from argparse import ArgumentParser, Namespace
from arguments import ModelParams, PipelineParams, OptimizationParams
from utils.graphics_utils import getWorld2View2, focal2fov, fov2focal
from os import makedirs
import torchvision
import numpy as np
from utils.sh_utils import RGB2SH
import math
# import faiss
from scene.kmeans_quantize import Quantize_kMeans
from bitarray import bitarray
from utils.system_utils import mkdir_p
from utils.opengs_utlis import mask_feature_mean, pair_mask_feature_mean, \
get_SAM_mask_and_feat, load_code_book, \
calculate_iou, calculate_distances, calculate_pairwise_distances
try:
from torch.utils.tensorboard import SummaryWriter
TENSORBOARD_FOUND = True
except ImportError:
TENSORBOARD_FOUND = False
# Randomly initialize 300 colors for visualizing the SAM mask. [OpenGaussian]
np.random.seed(42)
colors_defined = np.random.randint(100, 256, size=(300, 3))
colors_defined[0] = np.array([0, 0, 0]) # Ignore the mask ID of -1 and set it to black.
colors_defined = torch.from_numpy(colors_defined)
def dec2binary(x, n_bits=None):
"""Convert decimal integer x to binary.
Code from: https://stackoverflow.com/questions/55918468/convert-integer-to-pytorch-tensor-of-binary-bits
"""
if n_bits is None:
# x holds many IDs, so size the bit width from the largest one (at least 1 bit).
n_bits = max(int(x.max().item()).bit_length(), 1)
mask = 2**torch.arange(n_bits-1, -1, -1).to(x.device, x.dtype)
return x.unsqueeze(-1).bitwise_and(mask).ne(0)
def save_kmeans(kmeans_list, quantized_params, out_dir, mode="root"):
"""Save the codebook and indices of KMeans.
"""
# Convert to bitarray object to save compressed version
# Saving as .npy or .pth would use 8 bits per digit (or boolean) for the indices.
# Convert to binary, concat the indices for all params and save.
if mode=="root":
out_dir = os.path.join(out_dir, 'root_code_book')
elif mode=="leaf":
out_dir = os.path.join(out_dir, 'leaf_code_book')
mkdir_p(out_dir)
bitarray_all = bitarray([])
for kmeans in kmeans_list:
if mode=="root":
cls_ids = kmeans.cls_ids
elif mode=="leaf":
cls_ids = kmeans.leaf_cls_ids
n_bits = int(np.ceil(np.log2(len(cls_ids))))
assignments = dec2binary(cls_ids, n_bits)
bitarr = bitarray(list(assignments.cpu().numpy().flatten()))
bitarray_all.extend(bitarr)
with open(os.path.join(out_dir, 'kmeans_inds.bin'), 'wb') as file: # cls_ids
bitarray_all.tofile(file)
# Save details needed for loading
args_dict = {}
args_dict['params'] = quantized_params
args_dict['n_bits'] = n_bits
args_dict['total_len'] = len(bitarray_all)
np.save(os.path.join(out_dir, 'kmeans_args.npy'), args_dict)
if mode=="root":
centers_dict = {param: kmeans.centers for (kmeans, param) in zip(kmeans_list, quantized_params)}
elif mode=="leaf":
centers_dict = {param: kmeans.leaf_centers for (kmeans, param) in zip(kmeans_list, quantized_params)}
# Save codebook
torch.save(centers_dict, os.path.join(out_dir, 'kmeans_centers.pth'))
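# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original training code).
# save_kmeans() writes every cluster ID as n_bits MSB-first bits and stores
# n_bits and total_len next to the file. Both are needed for decoding because
# the bit stream is padded to a whole number of bytes when written to disk.
# The pure-Python round trip below shows the idea. The helper names are
# hypothetical, and the big-endian bit order within each byte is an assumption
# (bitarray's default); the repository's actual loader is load_code_book.
# ---------------------------------------------------------------------------
def _pack_ids_sketch(ids, n_bits):
    """Pack non-negative ints into bytes, n_bits per ID, MSB first."""
    bits = [(v >> shift) & 1 for v in ids for shift in range(n_bits - 1, -1, -1)]
    total_len = len(bits)
    bits += [0] * (-total_len % 8)  # zero-pad to a byte boundary
    data = bytes(
        sum(bit << (7 - k) for k, bit in enumerate(bits[i:i + 8]))
        for i in range(0, len(bits), 8)
    )
    return data, total_len


def _unpack_ids_sketch(data, n_bits, total_len):
    """Invert _pack_ids_sketch: drop the padding, then regroup n_bits at a time."""
    bits = [(byte >> shift) & 1 for byte in data for shift in range(7, -1, -1)]
    bits = bits[:total_len]
    return [
        sum(bit << (n_bits - 1 - k) for k, bit in enumerate(bits[i:i + n_bits]))
        for i in range(0, total_len, n_bits)
    ]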
def cohesion_loss(feat_map, gt_mask, feat_mean_stack):
"""Intra-mask smoothing loss, Eq. (1) in the paper.
Constrain the feature of each pixel within the mask to be close to the mean feature of that mask.
"""
N, H, W = gt_mask.shape
C = feat_map.shape[0]
# expand feat_map [6, H, W] to [N, 6, H, W]
feat_map_expanded = feat_map.unsqueeze(0).expand(N, C, H, W)
# expand mean feat [N, 6] to [N, 6, H, W]
feat_mean_stack_expanded = feat_mean_stack.unsqueeze(-1).unsqueeze(-1).expand(N, C, H, W)
# feature distance
masked_feat = feat_map_expanded * gt_mask.unsqueeze(1) # [N, 6, H, W]
dist = (masked_feat - feat_mean_stack_expanded).norm(p=2, dim=1) # [N, H, W]
# per mask feature distance (loss)
masked_dist = dist * gt_mask # [N, H, W]
loss_per_mask = masked_dist.sum(dim=[1, 2]) / gt_mask.sum(dim=[1, 2]).clamp(min=1)
return loss_per_mask.mean()
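# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original training code).
# cohesion_loss() averages, per mask, the L2 distance between each pixel
# feature inside the mask and that mask's mean feature, then averages over
# masks. The hypothetical pure-Python helper below restates this on flat
# lists (one feature tuple per pixel, one boolean list per mask) so the
# arithmetic can be checked by hand. It is a simplification, not the tensor
# implementation above.
# ---------------------------------------------------------------------------
def _cohesion_loss_sketch(pixel_feats, masks, mask_means):
    """Mean over masks of the mean in-mask distance to the mask's mean feature."""
    import math

    losses = []
    for mask, mean in zip(masks, mask_means):
        dists = [math.dist(f, mean) for f, inside in zip(pixel_feats, mask) if inside]
        # Guard against empty masks, mirroring clamp(min=1) in cohesion_loss().
        losses.append(sum(dists) / max(len(dists), 1))
    return sum(losses) / len(losses)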
def separation_loss(feat_mean_stack, iteration):
"""Inter-mask contrastive loss, Eq. (2) in the paper.
Constrain the instance features within different masks to be as far apart as possible.
"""
N, _ = feat_mean_stack.shape
# expand feat_mean_stack [N, C] to [N, N, C]
feat_expanded = feat_mean_stack.unsqueeze(1).expand(-1, N, -1)
feat_transposed = feat_mean_stack.unsqueeze(0).expand(N, -1, -1)
# pairwise squared L2 distance, [N, N]
diff_squared = (feat_expanded - feat_transposed).pow(2).sum(2)
# Calculate the inverse of the distance to enhance discrimination
epsilon = 1 # 1e-6
inverse_distance = 1.0 / (diff_squared + epsilon)
# Exclude diagonal elements (distance from itself) and calculate the mean inverse distance
mask = torch.eye(N, device=feat_mean_stack.device).bool()
inverse_distance.masked_fill_(mask, 0)
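# Rank-based weights: argsort().argsort() gives each entry's rank within its row,
# so closer mask pairs (larger inverse distance) receive larger weights.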
sorted_indices = inverse_distance.argsort().argsort()
loss_weight = (sorted_indices.float() / (N - 1)) * (1.0 - 0.1) + 0.1 # scale to 0.1 - 1.0, [N, N]
# After 35k iterations, keep only weights >= 0.9 and reset all others to 0.1.
if iteration > 35_000:
loss_weight[loss_weight < 0.9] = 0.1
inverse_distance *= loss_weight # [N, N]
# final loss
loss = inverse_distance.sum() / (N * (N - 1))
return loss
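# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original training code).
# separation_loss() turns inverse distances into weights with a double
# argsort: the first argsort orders the entries, the second yields the rank
# of every entry, and the ranks are rescaled linearly to [0.1, 1.0]. The
# hypothetical pure-Python helper below does this for a single row. Ties are
# broken by position here, which may differ from torch.argsort.
# ---------------------------------------------------------------------------
def _rank_weights_sketch(row, low=0.1, high=1.0):
    """Map each entry of row (len >= 2) to a weight in [low, high] by ascending rank."""
    n = len(row)
    order = sorted(range(n), key=lambda i: row[i])  # first argsort: indices by value
    ranks = [0] * n
    for rank, idx in enumerate(order):  # second argsort: rank of every position
        ranks[idx] = rank
    return [(r / (n - 1)) * (high - low) + low for r in ranks]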
def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations, \
checkpoint, debug_from):
iterations = [opt.start_ins_feat_iter, opt.start_leaf_cb_iter, opt.start_root_cb_iter]
saving_iterations.extend(iterations)
checkpoint_iterations.extend(iterations)
first_iter = 0
tb_writer = prepare_output_and_logger(dataset)
gaussians = GaussianModel(dataset.sh_degree)
scene = Scene(dataset, gaussians)
gaussians.training_setup(opt)
if checkpoint:
(model_params, first_iter) = torch.load(checkpoint)
# NOTE: Load the original 3DGS pre-trained checkpoint and add the ins_feat attribute. [OpenGaussian]
if len(model_params) == 12:
# initialize instance color.
ins_feat = torch.rand((model_params[8].shape[0], opt.ins_feat_dim), dtype=torch.float, device="cuda")
ins_feat = torch.nn.Parameter(ins_feat.requires_grad_(True))
to_list = list(model_params)
# (1) replace optimizer
to_list[10] = gaussians.optimizer.state_dict()
# (2) add ins_feat
to_list.insert(7, ins_feat)
# (3) add ins_feat_q (quantized ins_feat)
ins_feat_q = torch.empty(0)
to_list.insert(8, ins_feat_q)
model_params = tuple(to_list)
gaussians.restore(model_params, opt)
ins_feat_continue = gaussians._ins_feat.clone().detach() # not used
else:
ins_feat_continue = None # not used
# initialize the codebook
ins_feat_codebook = Quantize_kMeans(num_clusters=opt.root_node_num, # k1
num_leaf_clusters=opt.leaf_node_num, # k2
num_iters=5,
dim=9)
# note: load the saved codebook
leaf_cluster_indices = None
if checkpoint:
base_dir = os.path.dirname(checkpoint)
load_iter = checkpoint.split('/')[-1].split('.')[0][6:]
root_code_book_path = os.path.join(base_dir, 'point_cloud', f"iteration_{load_iter}", "root_code_book")
leaf_code_book_path = os.path.join(base_dir, 'point_cloud', f"iteration_{load_iter}", "leaf_code_book")
if os.path.exists(os.path.join(root_code_book_path, 'kmeans_inds.bin')):
root_center, root_indices = load_code_book(root_code_book_path)
root_center_saved = root_center["ins_feat"]
cluster_indices = torch.from_numpy(root_indices).cuda()
ins_feat_codebook.centers = root_center_saved
ins_feat_codebook.cls_ids = cluster_indices
else:
cluster_indices = None
if os.path.exists(os.path.join(leaf_code_book_path, 'kmeans_inds.bin')):
leaf_center, leaf_indices = load_code_book(leaf_code_book_path)
leaf_center_saved = leaf_center["ins_feat"]
leaf_cluster_indices = torch.from_numpy(leaf_indices).cuda()
ins_feat_codebook.leaf_centers = leaf_center_saved
ins_feat_codebook.leaf_cls_ids = leaf_cluster_indices
else:
leaf_cluster_indices = None
bg_color = [1, 1, 1] if dataset.white_background else [0, 0, 0]
background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
iter_start = torch.cuda.Event(enable_timing = True)
iter_end = torch.cuda.Event(enable_timing = True)
viewpoint_stack = None
ema_loss_for_log = 0.0
progress_bar = tqdm(range(first_iter, opt.iterations), desc="Training progress")
first_iter += 1
root_id = 0 # for stage 2.2
loss = torch.tensor(0.0)
Ll1 = torch.tensor(0.0)
for iteration in range(first_iter, opt.iterations + 1):
no_need_bk = False
if network_gui.conn == None:
network_gui.try_connect()
while network_gui.conn != None:
try:
net_image_bytes = None
custom_cam, do_training, pipe.convert_SHs_python, pipe.compute_cov3D_python, keep_alive, scaling_modifer = network_gui.receive()
if custom_cam != None:
net_image = render(custom_cam, gaussians, pipe, background, iteration, scaling_modifer)["render"]
net_image_bytes = memoryview((torch.clamp(net_image, min=0, max=1.0) * 255).byte().permute(1, 2, 0).contiguous().cpu().numpy())
network_gui.send(net_image_bytes, dataset.source_path)
if do_training and ((iteration < int(opt.iterations)) or not keep_alive):
break
except Exception as e:
network_gui.conn = None
iter_start.record()
gaussians.update_learning_rate(iteration, opt.start_root_cb_iter, opt.start_leaf_cb_iter)
# Every 1000 iterations we increase the levels of SH up to a maximum degree
if iteration % 1000 == 0:
gaussians.oneupSHdegree()
# Pick a random Camera
if not viewpoint_stack:
viewpoint_stack = scene.getTrainCameras().copy()
viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1))
if not viewpoint_cam.data_on_gpu:
viewpoint_cam.to_gpu()
cb_mode = None # Default: codebook discretization has not started yet
if iteration == 1:
print("[Stage 0] Start 3dgs pre-train ...")
sys.stdout.flush()
if iteration == opt.start_ins_feat_iter + 1:
print("[Stage 1] Start continuous instance feature learning ...")
sys.stdout.flush()
# Stage 2.1: Coarse-level codebook
if iteration > opt.start_root_cb_iter and iteration <= opt.start_leaf_cb_iter:
cb_mode = "root"
if iteration == opt.start_root_cb_iter + 1:
print("[Stage 2.1] Start coarse-level codebook discretization ...")
sys.stdout.flush()
elif iteration > opt.start_leaf_cb_iter:
cb_mode = "leaf"
# Stage 2.2: Fine-level codebook
if iteration == opt.start_leaf_cb_iter + 1:
print("[Stage 2.2] Start fine-level codebook discretization ...")
sys.stdout.flush()
# Note: move on to the next coarse cluster every leaf_update_fr (default 300) steps.
if (iteration - opt.start_leaf_cb_iter) % opt.leaf_update_fr == 0:
root_id += 1 # 0 ~ k1-1
if root_id > (opt.root_node_num-1):
root_id = 0
# ###########################################################################
# [Stage 2]: Two-Level Codebook for Discretization #
# - Preprocessing: construct pseudo labels (instance features of stage 1) #
# Will execute twice, before coarse-level and fine-level clustering #
# ###########################################################################
if (cb_mode is not None and viewpoint_cam.pesudo_ins_feat is None) or \
((iteration == opt.start_root_cb_iter + 1) or (iteration == opt.start_leaf_cb_iter + 1)):
with torch.no_grad():
if cb_mode == "leaf" and cluster_indices is None:
cluster_indices = ins_feat_codebook.cls_ids # [num_pts], Coarse-level ID of each point (0 ~ k1-1)
construct_pseudo_ins_feat(scene, render, (pipe, background, iteration),
cluster_indices=cluster_indices, mode=cb_mode,
root_num=opt.root_node_num, leaf_num=opt.leaf_node_num,
sam_level=opt.sam_level,
save_memory=opt.save_memory)
if not viewpoint_cam.data_on_gpu:
viewpoint_cam.to_gpu()
if cb_mode == "leaf":
# Number of leaves per root
ins_feat_codebook.iLeafSubNum = gaussians.iClusterSubNum
# Render
if (iteration - 1) == debug_from:
pipe.debug = True
bg = torch.rand((3), device="cuda") if opt.random_background else background
# ####################################################
# [Stage 2]: Two-Level Codebook for Discretization #
# - Update codebook #
# ####################################################
freq_k_means = 200 # coarse-level codebook update frequency
if cb_mode == "leaf":
freq_k_means = 50 # todo fine-level codebook update frequency
if cb_mode is not None:
if (iteration % freq_k_means == 1) or iteration == opt.start_root_cb_iter + 1:
assign = True # Reassign cluster centers
else:
assign = False # update cluster centers
ins_feat_codebook.forward(gaussians, iteration, assign=assign, \
mode=cb_mode, selected_leaf=root_id, \
pos_weight=opt.pos_weight) # note: position weight
# render function
if iteration <= opt.start_ins_feat_iter: # stage 0
render_feat=False
render_cluster=False
cluster_indices=None
elif iteration > opt.start_leaf_cb_iter: # stage 2.2 (fine-level)
render_feat=False
render_cluster=True
else: # stage 1, stage 2.1(coarse-level)
render_feat=True
render_cluster=False
cluster_indices=None
# rescale
if iteration > opt.start_root_cb_iter: # stage 2, rescale
rescale=True
else:
rescale=False
render_pkg = render(viewpoint_cam, gaussians, pipe, bg, iteration,
rescale=rescale, # whether to rescale the Gaussian scales
cluster_idx=cluster_indices, # coarse-level cluster id
leaf_cluster_idx=ins_feat_codebook.leaf_cls_ids, # fine-level cluster id
render_feat_map=render_feat,
render_cluster=render_cluster,
selected_root_id=root_id) # coarse id (stage 2.2)
# rendered results
image, viewspace_point_tensor, visibility_filter, radii = \
render_pkg["render"], render_pkg["viewspace_points"], render_pkg["visibility_filter"], render_pkg["radii"]
alpha = render_pkg["alpha"]
rendered_silhouette = render_pkg["silhouette"] if render_pkg["silhouette"] is not None else alpha
rendered_silhouette = (rendered_silhouette > 0.7) * 1.0 # mask after re-scale
rendered_ins_feat = render_pkg["ins_feat"]
rendered_cluster_imgs = render_pkg["cluster_imgs"] # [num_cl, 6, H, W]
rendered_leaf_cluster_imgs = render_pkg["leaf_clusters_imgs"]
rendered_cluster_silhouettes = render_pkg["cluster_silhouettes"]
if render_cluster:
if rendered_cluster_silhouettes is not None and len(rendered_cluster_silhouettes) > 0:
rendered_cluster_silhouettes = rendered_cluster_silhouettes > 0.7
else:
# root_id-th coarse cluster not visible in current view
no_need_bk = True
# gt supervision: rgb image & SAM mask
gt_image = viewpoint_cam.original_image.cuda()
if viewpoint_cam.original_sam_mask is not None:
gt_sam_mask = viewpoint_cam.original_sam_mask.cuda() # [4, H, W]
# ##################################################
# [Stage 0]: 0-30k steps, standard 3DGS RGB loss   #
# ##################################################
if iteration <= opt.start_ins_feat_iter:
Ll1 = l1_loss(image, gt_image)
loss = (1.0 - opt.lambda_dssim) * Ll1 + opt.lambda_dssim * (1.0 - ssim(image, gt_image))
# Start learning instance features after 30k steps (opt.start_ins_feat_iter).
if iteration > opt.start_ins_feat_iter:
# NOTE: Freeze the pre-trained Gaussian parameters and only train the instance features.
scene.gaussians._xyz = scene.gaussians._xyz.detach()
scene.gaussians._features_dc = scene.gaussians._features_dc.detach()
scene.gaussians._features_rest = scene.gaussians._features_rest.detach()
scene.gaussians._opacity = scene.gaussians._opacity.detach()
scene.gaussians._scaling = scene.gaussians._scaling.detach()
scene.gaussians._rotation = scene.gaussians._rotation.detach()
# construct boolean masks [num_mask, H, W]
# sam_level: 3 for LERF, 0 for ScanNet
sam_level = opt.sam_level
mask_id, mask_bool, invalid_pix = get_SAM_mask_and_feat(gt_sam_mask, level=sam_level, filter_th=50)
# #################################################
# [Stage 1]: Continuous instance feature learning #
# LERF: 30k-40k steps; ScanNet: 30k-50k steps     #
# see Sec.3.1 in the paper #
# #################################################
if cb_mode is None:
# (0) compute the average instance features within each mask. [num_mask, 6]
feat_mean_stack = mask_feature_mean(rendered_ins_feat, mask_bool, image_mask=rendered_silhouette)
# (1) intra-mask smoothing loss. Eq.(1) in the paper
loss_cohesion = cohesion_loss(rendered_ins_feat, mask_bool, feat_mean_stack)
# (2) inter-mask contrastive loss Eq.(2) in the paper
loss_separation = separation_loss(feat_mean_stack, iteration)
# total loss, opt.loss_weight: 0.1
loss = loss_separation + opt.loss_weight * loss_cohesion
# ####################################################
# [Stage 2]: Two-Level Codebook for Discretization
# - coarse-level(root) loss computation
# - fine-level(leaf) loss computation
# ####################################################
# 2.1 coarse-level
if cb_mode == "root":
# Only consider valid pixels
keeped_pix = viewpoint_cam.pesudo_ins_feat.sum(dim=(0)) > 0 # Invalid pixels of pseudo-labels
keeped_pix = keeped_pix.bool()&rendered_silhouette.bool() # Empty regions after rescaling
keeped_pix = keeped_pix&(~invalid_pix.unsqueeze(0)) # Invalid area of the original mask
keeped_pix = rendered_silhouette.bool() # NOTE: this overrides the three filters above, so only the silhouette mask is applied
# loss Eq.(4) in the paper.
feat_loss = l1_loss(rendered_ins_feat, viewpoint_cam.pesudo_ins_feat, keeped_pix)
# feat_loss = l2_loss(rendered_ins_feat, viewpoint_cam.pesudo_ins_feat, keeped_pix)
loss = feat_loss
# 2.2 fine-level
if cb_mode == "leaf" and no_need_bk == False:
total_pix = gt_image.shape[1] * gt_image.shape[2]
for i in range(len(rendered_cluster_imgs)):
cluster_pred = rendered_cluster_imgs[i]
cluster_silhouette = rendered_cluster_silhouettes[i] # [H, W] bool
rendered_ins_feat = cluster_pred #
# cluster_mask = viewpoint_cam.cluster_masks[i] # [H, W] bool
# cluster_silhouette = cluster_silhouette & cluster_mask
feat_loss = l2_loss(cluster_pred, viewpoint_cam.pesudo_ins_feat, cluster_silhouette)
if i == 0:
# loss = feat_loss * (cluster_silhouette.sum() / total_pix)
loss = feat_loss
else:
# loss += (feat_loss * (cluster_silhouette.sum() / total_pix))
loss += feat_loss
# mask loss. modify -----
if viewpoint_cam.original_mask is not None:
gt_mask = viewpoint_cam.original_mask.cuda()
mask_loss = F.mse_loss(alpha, gt_mask)
loss = loss + mask_loss
if no_need_bk == False:
loss.backward()
iter_end.record()
# Save the intermediate training results. [OpenGaussian]
save_intermediate = True
save_fre = 1000
if iteration > opt.start_leaf_cb_iter:
save_fre = 100
if (iteration % save_fre == 0) and save_intermediate:
gts_path = os.path.join(scene.model_path, "train_process", "gt")
makedirs(gts_path, exist_ok=True)
torchvision.utils.save_image(gt_image.detach().cpu(), os.path.join(gts_path, '{0:05d}'.format(iteration) + ".png"))
render_path = os.path.join(scene.model_path, "train_process", "renders")
makedirs(render_path, exist_ok=True)
torchvision.utils.save_image(image.detach().cpu(), os.path.join(render_path, '{0:05d}'.format(iteration) + ".png"))
# alpha_path = os.path.join(scene.model_path, "train_process", "alpha")
# makedirs(alpha_path, exist_ok=True)
# torchvision.utils.save_image(alpha.detach().cpu(), os.path.join(alpha_path, '{0:05d}'.format(iteration) + ".png"))
if iteration > opt.start_ins_feat_iter:
if cb_mode is None:
sub_floader = "stage1"
elif cb_mode == "root":
sub_floader = "stage2_1"
elif cb_mode == "leaf":
sub_floader = "stage2_2"
# Visualize the SAM mask. [OpenGaussian]
if gt_sam_mask is not None and iteration > opt.start_ins_feat_iter:
# read predefined mask color
mask_color_rand = colors_defined[mask_id.detach().cpu()].type(torch.float64)
mask_color_rand = mask_color_rand.permute(2, 0, 1)
gt_sam_path = os.path.join(scene.model_path, "train_process", sub_floader, "gt_sam_mask_" + str(opt.sam_level))
makedirs(gt_sam_path, exist_ok=True)
torchvision.utils.save_image(mask_color_rand/255.0, os.path.join(gt_sam_path, '{0:05d}'.format(iteration) + ".png"))
# TODO
if viewpoint_cam.pesudo_ins_feat is not None:
feat = viewpoint_cam.pesudo_ins_feat
pseudo_ins_feat_path = os.path.join(scene.model_path, "train_process", sub_floader, "pseudo_ins_feat")
makedirs(pseudo_ins_feat_path, exist_ok=True)
torchvision.utils.save_image(feat.detach().cpu()[:3, :, :], os.path.join(pseudo_ins_feat_path, '{0:05d}'.format(iteration) + "_1.png"))
torchvision.utils.save_image(feat.detach().cpu()[3:6, :, :], os.path.join(pseudo_ins_feat_path, '{0:05d}'.format(iteration) + "_2.png"))
if cb_mode is not None:
# silhouette (alpha to mask) [OpenGaussian] stage 2
silhouette_path = os.path.join(scene.model_path, "train_process", sub_floader, "silhouette")
makedirs(silhouette_path, exist_ok=True)
torchvision.utils.save_image(rendered_silhouette.detach().cpu(), os.path.join(silhouette_path, '{0:05d}'.format(iteration) + ".png"))
# Visualize the 6-dimensional instance feature. [OpenGaussian]
if rendered_ins_feat is not None:
# dim 0:3
ins_feat_path = os.path.join(scene.model_path, "train_process", sub_floader, "ins_feat")
makedirs(ins_feat_path, exist_ok=True)
torchvision.utils.save_image(rendered_ins_feat.detach().cpu()[:3, :, :], os.path.join(ins_feat_path, '{0:05d}'.format(iteration) + ".png"))
# dim 3:6
ins_feat_path2 = os.path.join(scene.model_path, "train_process", sub_floader, "ins_feat2")
makedirs(ins_feat_path2, exist_ok=True)
torchvision.utils.save_image(rendered_ins_feat.detach().cpu()[3:6, :, :], os.path.join(ins_feat_path2, '{0:05d}'.format(iteration) + ".png"))
# # fine-level cluster
# if rendered_leaf_cluster_imgs is not None:
# leaf_cluster_path = os.path.join(scene.model_path, "train_process", sub_floader, "cluster_leaf")
# makedirs(leaf_cluster_path, exist_ok=True)
# for i, leaf_img in enumerate(rendered_leaf_cluster_imgs):
# torchvision.utils.save_image(leaf_img.detach().cpu()[:3, :, :], os.path.join(leaf_cluster_path, '{0:05d}'.format(iteration) + "leaf_{}.png".format(i)))
with torch.no_grad():
# Progress bar
ema_loss_for_log = 0.4 * loss.item() + 0.6 * ema_loss_for_log
if iteration % 10 == 0:
progress_bar.set_postfix({"Loss": f"{ema_loss_for_log:.{7}f}"})
progress_bar.update(10)
if iteration == opt.iterations:
progress_bar.close()
# Log and save .ply
# training_report(tb_writer, iteration, Ll1, loss, l1_loss, iter_start.elapsed_time(iter_end), \
# testing_iterations, opt.start_root_cb_iter, scene, render, (pipe, background, iteration))
if (iteration in saving_iterations):
print("\n[ITER {}] Saving Gaussians".format(iteration))
sys.stdout.flush()
if iteration > opt.start_root_cb_iter:
# note: save codebook [OpenGaussian]
out_dir = os.path.join(scene.model_path, 'point_cloud/iteration_%d' % iteration)
save_kmeans([ins_feat_codebook], ["ins_feat"], out_dir, mode="root")
if cb_mode == "leaf":
save_kmeans([ins_feat_codebook], ["ins_feat"], out_dir, mode="leaf")
scene.save(iteration, ["ins_feat"])
else:
scene.save(iteration)
# Densification
if iteration < opt.densify_until_iter and \
not opt.frozen_init_pts: # note: ScanNet dataset is not densified [OpenGaussian]
# Keep track of max radii in image-space for pruning
gaussians.max_radii2D[visibility_filter] = torch.max(gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)
if iteration > opt.densify_from_iter and iteration % opt.densification_interval == 0:
size_threshold = 20 if iteration > opt.opacity_reset_interval else None
gaussians.densify_and_prune(opt.densify_grad_threshold, 0.005, scene.cameras_extent, size_threshold)
if iteration % opt.opacity_reset_interval == 0 or (dataset.white_background and iteration == opt.densify_from_iter):
gaussians.reset_opacity()
# Optimizer step
if iteration < opt.iterations:
gaussians.optimizer.step()
gaussians.optimizer.zero_grad(set_to_none = True)
torch.cuda.empty_cache()
if (iteration in checkpoint_iterations):
print("\n[ITER {}] Saving Checkpoint".format(iteration))
sys.stdout.flush()
torch.save((gaussians.capture(), iteration), scene.model_path + "/chkpnt" + str(iteration) + ".pth")
# ###########################################################
# Stage 3. associate language feature (training-free stage) #
# - Performed after training. #
# ###########################################################
if iteration == opt.iterations and iteration > opt.start_leaf_cb_iter:
print("[Stage 3] Start 2D language feature - 3D cluster association ...")
sys.stdout.flush()
if leaf_cluster_indices is None:
leaf_cluster_indices = ins_feat_codebook.leaf_cls_ids # fine-level cluster id
construct_pseudo_ins_feat(scene, render, (pipe, background, first_iter),
cluster_indices=leaf_cluster_indices, mode="lang",
root_num=opt.root_node_num, leaf_num=opt.leaf_node_num,
sam_level=opt.sam_level,
save_memory=opt.save_memory)
# note: save memory (only stage 2, 3)
if viewpoint_cam.data_on_gpu and opt.save_memory and cb_mode is not None:
viewpoint_cam.to_cpu()
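# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original training code).
# training() switches stages by comparing the iteration counter with three
# thresholds. The hypothetical helper below restates those comparisons in one
# place. The stage labels are made up for illustration, and the default
# thresholds are the ScanNet values from scripts/train_scannet.sh.
# ---------------------------------------------------------------------------
def _stage_for_iteration_sketch(iteration, start_ins_feat_iter=30_000,
                                start_root_cb_iter=50_000, start_leaf_cb_iter=70_000):
    """Return a label for the stage that a 1-based iteration falls into."""
    if iteration <= start_ins_feat_iter:
        return "stage0"      # 3DGS pre-training, RGB loss only
    if iteration <= start_root_cb_iter:
        return "stage1"      # continuous instance features (cb_mode is None)
    if iteration <= start_leaf_cb_iter:
        return "stage2.1"    # coarse-level codebook (cb_mode == "root")
    return "stage2.2"        # fine-level codebook (cb_mode == "leaf")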
def prepare_output_and_logger(args):
if not args.model_path:
if os.getenv('OAR_JOB_ID'):
unique_str=os.getenv('OAR_JOB_ID')
else:
unique_str = str(uuid.uuid4())
args.model_path = os.path.join("./output/", unique_str[0:10])
# Set up output folder
print("Output folder: {}".format(args.model_path))
os.makedirs(args.model_path, exist_ok = True)
with open(os.path.join(args.model_path, "cfg_args"), 'w') as cfg_log_f:
cfg_log_f.write(str(Namespace(**vars(args))))
# Create Tensorboard writer
tb_writer = None
if TENSORBOARD_FOUND:
tb_writer = SummaryWriter(args.model_path)
else:
print("Tensorboard not available: not logging progress")
return tb_writer
def construct_pseudo_ins_feat(scene : Scene, renderFunc, renderArgs,
filter=True, # filter pseudo features
cluster_indices=None, # coarse-level ID of each point (0 ~ k1-1)
mode="root", # root, leaf, lang
root_num=64, leaf_num=10, # k1, k2
sam_level=3,
save_memory=False):
torch.cuda.empty_cache()
# ##############################################################################################
# [Stage 2.1, 2.2] Render all training views once to construct pseudo-instance feature labels. #
# - view.pesudo_ins_feat [C=6, H, W] #
# - view.pesudo_mask_bool [num_mask, H, W] #
# ##############################################################################################
sorted_train_cameras = sorted(scene.getTrainCameras(), key=lambda Camera: Camera.image_name)
for idx, view in enumerate(tqdm(sorted_train_cameras, desc="construct pseudo feat")):
if not view.data_on_gpu:
view.to_gpu()
# render
render_pkg = renderFunc(view, scene.gaussians, *renderArgs, rescale=False, origin_feat=True)
rendered_ins_feat = render_pkg["ins_feat"]
# get gt sam mask
mask_id, mask_bool, invalid_pix = \
get_SAM_mask_and_feat(view.original_sam_mask.cuda(), level=sam_level)
# construct pseudo ins_feat, mask level
pseudo_mask_ins_feat_, mask_var, pix_count = mask_feature_mean(rendered_ins_feat, mask_bool, return_var=True) # [num_mask, 6]
pseudo_mask_ins_feat = torch.cat((torch.zeros((1, 6)).cuda(), pseudo_mask_ins_feat_), dim=0)# [num_mask+1, 6]
# Filter out masks with high variance. Potentially incorrect segmentation.
filter_mask = mask_var > 0.006 # True->del
filter_mask = torch.cat((torch.tensor([False]).cuda(), filter_mask), dim=0) # [num_mask+1]
# Masks covering a large share of pixels are likely background; their variance is inevitably high, so keep them.
ignored_mask_ind = torch.nonzero(pix_count > pix_count.max() * 0.8).squeeze()
filter_mask[ignored_mask_ind + 1] = False
filtered_mask_pseudo_ins_feat = pseudo_mask_ins_feat.clone()
filtered_mask_pseudo_ins_feat[filter_mask] *= 0
# pseudo ins_feat, image level
pseudo_ins_feat = pseudo_mask_ins_feat[mask_id] # Retrieve corresponding ins_feat by mask ID
pseudo_ins_feat = pseudo_ins_feat.permute(2, 0, 1) # [H, W, 6]->[6, H, W]
# filtered pseudo ins_feat, image level
filter_pseudo_ins_feat = filtered_mask_pseudo_ins_feat[mask_id]
filter_pseudo_ins_feat = filter_pseudo_ins_feat.permute(2, 0, 1)
# filtered mask [1+num_mask, H, W]
mask_bool_filtered = torch.cat((torch.zeros_like(mask_bool[0].unsqueeze(0)), mask_bool), dim=0)
mask_bool_filtered[filter_mask] *= 0
# NOTE: save the constructed pesudo_ins_feat
# total_feat.append(pseudo_mask_ins_feat[1:,:])
# if view.pesudo_ins_feat is None:
view.pesudo_ins_feat = filter_pseudo_ins_feat if filter else pseudo_ins_feat
# view.pesudo_ins_feat = rendered_ins_feat
view.pesudo_mask_bool = mask_bool_filtered.to(torch.bool)
# Save some results for visualization.
pseudo_debug = True
if idx % 20 == 0 and pseudo_debug:
pseudo_ins_feat_path = os.path.join(scene.model_path, "train_process", "debug_pseudo_label", "all_pseudo_ins_feat")
filter_pseudo_ins_feat_path = os.path.join(scene.model_path, "train_process", "debug_pseudo_label", "all_filter_pseudo_ins_feat")
rendered_ins_feat_path = os.path.join(scene.model_path, "train_process", "debug_pseudo_label", "all_render_ins_feat")
sam_mask_path = os.path.join(scene.model_path, "train_process", "debug_pseudo_label", "all_sam_mask")
makedirs(pseudo_ins_feat_path, exist_ok=True)
makedirs(filter_pseudo_ins_feat_path, exist_ok=True)
makedirs(rendered_ins_feat_path, exist_ok=True)
makedirs(sam_mask_path, exist_ok=True)
# pseudo ins_feat
torchvision.utils.save_image(pseudo_ins_feat[:3,:,:], os.path.join(pseudo_ins_feat_path, '{0:05d}'.format(idx) + "_1.png"))
# torchvision.utils.save_image(pseudo_ins_feat[3:6,:,:], os.path.join(pseudo_ins_feat_path, '{0:05d}'.format(idx) + "_2.png"))
# filtered pseudo ins_feat
torchvision.utils.save_image(filter_pseudo_ins_feat[:3,:,:], os.path.join(filter_pseudo_ins_feat_path, '{0:05d}'.format(idx) + "_1.png"))
# torchvision.utils.save_image(filter_pseudo_ins_feat[3:6,:,:], os.path.join(filter_pseudo_ins_feat_path, '{0:05d}'.format(idx) + "_2.png"))
# rendered ins_feat
torchvision.utils.save_image(rendered_ins_feat[:3,:,:], os.path.join(rendered_ins_feat_path, '{0:05d}'.format(idx) + "_1.png"))
# torchvision.utils.save_image(rendered_ins_feat[3:6,:,:], os.path.join(rendered_ins_feat_path, '{0:05d}'.format(idx) + "_2.png"))
# gt SAM mask, read predefined mask color
mask_color_rand = colors_defined[mask_id.detach().cpu()].type(torch.float64)
mask_color_rand = mask_color_rand.permute(2, 0, 1)
torchvision.utils.save_image(mask_color_rand/255.0, os.path.join(sam_mask_path, '{0:05d}'.format(idx) + ".png"))
# to cpu
if view.data_on_gpu and save_memory:
view.to_cpu()
# ##################################################################################################
# Preprocessing for Stage 2.2
# determine how many objects are in each coarse cluster, not just setting a fixed k2 value.
# ##################################################################################################
torch.cuda.empty_cache()
if mode=="leaf":
iClusterSubNum = torch.ones(cluster_indices.max()+1).to(torch.int32)
for idx, view in enumerate(tqdm(sorted_train_cameras, desc="render coarse-level cluster")):
if not view.data_on_gpu:
view.to_gpu()
render_pkg = renderFunc(view, scene.gaussians, *renderArgs, cluster_idx=cluster_indices, rescale=False,\
render_feat_map=False, render_cluster=True, origin_feat=True, better_vis=True,
root_num=root_num, leaf_num=leaf_num)
rendered_cluster_imgs = render_pkg["cluster_imgs"] # coarse cluster feature map
rendered_cluster_silhouettes = render_pkg["cluster_silhouettes"] # coarse cluster mask
cluster_occur = render_pkg["cluster_occur"] # bool [k1] Whether coarse clusters visible in the current view
pser_cluster_pesudo_mask = []
i = -1
for cluster_idx in range(cluster_indices.max()+1):
if not cluster_occur[cluster_idx]: # Process only coarse clusters visible in the current view
continue
i += 1
rendered_ins_feat = rendered_cluster_imgs[i] # cluster feat map
rendered_silhouette = (rendered_cluster_silhouettes[i] > 0.9).unsqueeze(0) # cluster mask
# (1) compute the overlap of this cluster with the pseudo masks (base="former": intersection / pseudo-mask area).
ious = calculate_iou(view.pesudo_mask_bool, rendered_silhouette, base="former")
# pseudo masks with IoU above threshold
inters_mask = view.pesudo_mask_bool[ious[0] > 0.2] # [num_mask, H, W]
inters_mask_ = inters_mask.sum(0).to(torch.bool) # [H, W] bool
# pseudo mask features, only for visualization [6, H, W]
inters_pesudo_ins_feat = view.pesudo_ins_feat * inters_mask_.unsqueeze(0)
# (2) compute the distance between coarse cluster features and pseudo features
# mean feature of the pseudo mask, [num_mask, 6]
inters_mask_feat_mean = mask_feature_mean(view.pesudo_ins_feat, inters_mask)
# mean feature of the cluster, [num_mask, 6]
cluster_mask_feat_mean = mask_feature_mean(rendered_ins_feat, inters_mask, image_mask=rendered_silhouette)
# distance
l1_dis, l2_dis = calculate_distances(inters_mask_feat_mean, cluster_mask_feat_mean) # metric="l1"
# (3) filter out some pseudo masks
inters_mask_filter = inters_mask[(l1_dis < 0.9) & (l2_dis < 0.5)] # l2_dis < 0.8
if inters_mask_filter.shape[0] > 10: # TODO 10? --> leaf_num
smallest_10 = torch.topk(l1_dis, 10, largest=False)[1]
inters_mask_filter = inters_mask[smallest_10]
inters_mask_filter_ = inters_mask_filter.sum(0).to(torch.bool)
inters_pesudo_ins_feat2 = view.pesudo_ins_feat * inters_mask_filter_.unsqueeze(0) # only for visualization
if not inters_mask_filter_.any(): # Skip if the cluster doesn’t intersect with any pseudo masks.
cluster_occur[cluster_idx] = False
continue
pser_cluster_pesudo_mask.append(inters_mask_filter_) # valid mask
# NOTE: (4) Determine the number of masks (i.e., objects) in each coarse cluster.
iClusterSubNum[cluster_idx] = max(iClusterSubNum[cluster_idx], inters_mask_filter.shape[0])
# (5) save some intermediate results for debugging
coarse_debug = False
if coarse_debug:
cluster_path = os.path.join(scene.model_path, "train_process", "debug_coarse_cluster", "cluster")
cluster_silhouette_path = os.path.join(scene.model_path, "train_process", "debug_coarse_cluster", "cluster_silhouette")
cluster_inters_pesudo_path = os.path.join(scene.model_path, "train_process", "debug_coarse_cluster", "cluster_inters_pesudo")
makedirs(cluster_path, exist_ok=True)
makedirs(cluster_silhouette_path, exist_ok=True)
makedirs(cluster_inters_pesudo_path, exist_ok=True)
# coarse-level cluster feature map
torchvision.utils.save_image(rendered_ins_feat[:3,:,:].cpu(), os.path.join(cluster_path, '{0:05d}'.format(idx) + f"_c_{cluster_idx}" + "_1.png"))
# torchvision.utils.save_image(rendered_ins_feat[3:,:,:].cpu(), os.path.join(cluster_path, '{0:05d}'.format(idx) + f"_c_{cluster_idx}" + "_2.png"))
torchvision.utils.save_image(rendered_silhouette.to(torch.float32).cpu(), os.path.join(cluster_silhouette_path, '{0:05d}'.format(idx) + f"_c_{cluster_idx}" + "_1.png"))
# pseudo masks of coarse cluster (_f represents the filtered.)
torchvision.utils.save_image(inters_pesudo_ins_feat[:3,:,:].cpu(), os.path.join(cluster_inters_pesudo_path, '{0:05d}'.format(idx) + f"_c_{cluster_idx}" + "_1.png"))
# torchvision.utils.save_image(inters_pesudo_ins_feat[3:,:,:].cpu(), os.path.join(cluster_inters_pesudo_path, '{0:05d}'.format(idx) + f"_c_{cluster_idx}" + "_2.png"))
torchvision.utils.save_image(inters_pesudo_ins_feat2[:3,:,:].cpu(), os.path.join(cluster_inters_pesudo_path, '{0:05d}'.format(idx) + f"_c_{cluster_idx}" + "_1_f.png"))
# torchvision.utils.save_image(inters_pesudo_ins_feat2[3:,:,:].cpu(), os.path.join(cluster_inters_pesudo_path, '{0:05d}'.format(idx) + f"_c_{cluster_idx}" + "_2_f.png"))
if view.cluster_masks is None:
view.cluster_masks = pser_cluster_pesudo_mask # pseudo masks of coarse cluster
view.bClusterOccur = cluster_occur # whether visible in the current view
if view.data_on_gpu and save_memory:
view.to_cpu()
# update
scene.gaussians.iClusterSubNum = (iClusterSubNum + 1).clamp(max=leaf_num)
torch.cuda.empty_cache()
# ###########################################################################
# [Stage 3] 2D mask(and language feat) - 3D fine level cluster association #
# - Sec. 3.3 in the paper #
# ###########################################################################
if mode == "lang":
# [leaf_num, view_num, (matched_mask_id, matched_score, b_matched)]
match_info = torch.zeros(root_num * leaf_num, len(sorted_train_cameras), 3).cuda() # [k1*k2, num_imgs, 3]
# iterate over the coarse-level clusters
for root_id, _ in enumerate(tqdm(range(root_num), desc="mapping")):
# iterate over all training views
for v_id, view in enumerate(sorted_train_cameras):
if not view.data_on_gpu:
view.to_gpu()
# (0) render
render_pkg = renderFunc(view, scene.gaussians, *renderArgs, leaf_cluster_idx=cluster_indices, rescale=False,\
render_feat_map=False, render_cluster=True, origin_feat=True, better_vis=False,\
selected_root_id=root_id,\
root_num=root_num, leaf_num=leaf_num)
rendered_leaf_cluster_imgs = render_pkg["leaf_clusters_imgs"] # all fine-level clusters of the root_id-th coarse-level.
rendered_leaf_cluster_silhouettes = render_pkg["leaf_cluster_silhouettes"]
occured_leaf_id = render_pkg["occured_leaf_id"]
if len(occured_leaf_id) > 0:
occured_leaf_id = torch.tensor(occured_leaf_id).cuda()
rendered_leaf_cluster_imgs = torch.stack(rendered_leaf_cluster_imgs, dim=0) # [N, C, H, W]
rendered_leaf_cluster_silhouettes = rendered_leaf_cluster_silhouettes > 0.8 # [N, H, W]
else:
if view.data_on_gpu and save_memory:
view.to_cpu()
continue # root_id not visible in current view
# (1) iou [num_rendered_leaf, num_mask]
ious = calculate_iou(view.pesudo_mask_bool, rendered_leaf_cluster_silhouettes)
# (2) feature distance
# cluster mean feat, [num_leaf, dim]
pred_mask_feat_mean = pair_mask_feature_mean(rendered_leaf_cluster_imgs, rendered_leaf_cluster_silhouettes)
# pseudo mean feat, [num_pseudo_mask, dim]
pesudo_mask_feat_mean = mask_feature_mean(view.pesudo_ins_feat, view.pesudo_mask_bool)
# only for visualization, [num_pseudo_mask, dim, H, W]
pesudo_mask_feat = view.pesudo_ins_feat * view.pesudo_mask_bool.unsqueeze(1)
# distance
l1_dis, _ = calculate_pairwise_distances(pred_mask_feat_mean, pesudo_mask_feat_mean, metric="l1") # method="l1"
# (3) iou-feature distance joint score
scores = ious * (1-l1_dis) # Eq.(5) in the paper
# (4) save the association result
max_score, max_ind = torch.max(scores, dim=-1) # [num_leaf]
b_matched = max_score > 0.2 # todo
max_score[~b_matched] *= 0
max_ind[~b_matched] *= 0
match_info[occured_leaf_id, v_id] = torch.stack((max_ind, max_score, b_matched), dim=1)
# (5) save matching results for visualization. (only save the paired mask)
association_debug = True
if association_debug:
leaf_cluster_path = os.path.join(scene.model_path, "train_process", "stage3", "leaf_cluster")
leaf_cluster_silhouette_path = os.path.join(scene.model_path, "train_process", "stage3", "leaf_cluster_silhouettes")
leaf_pesudo_mask_path = os.path.join(scene.model_path, "train_process", "stage3", "leaf_pesudo_mask")
makedirs(leaf_cluster_path, exist_ok=True)
makedirs(leaf_cluster_silhouette_path, exist_ok=True)
makedirs(leaf_pesudo_mask_path, exist_ok=True)
if b_matched.sum() > 0:
for i, img in enumerate(rendered_leaf_cluster_imgs):
if not b_matched[i]:
continue
if max_score[i] < 0.8: # note: 0.8 is just for visualization
continue
torchvision.utils.save_image(img[:3,:,:], os.path.join(leaf_cluster_path, \
f"r{root_id}_l{i}_v{v_id}.png"))
torchvision.utils.save_image(rendered_leaf_cluster_silhouettes[i].to(torch.float32), \
os.path.join(leaf_cluster_silhouette_path, f"r{root_id}_l{i}_v{v_id}.png"))
torchvision.utils.save_image(pesudo_mask_feat[max_ind[i]][:3,:,:], os.path.join(leaf_pesudo_mask_path, \
f"r{root_id}_l{i}_v{v_id}.png"))
# print("end one root cluster of one view")
if view.data_on_gpu and save_memory:
view.to_cpu()
# print("end matching")
torch.cuda.empty_cache()
# count the matches of each leaf (fine-level cluster) across all viewpoints.
leaf_per_view_matched_mask = match_info[:, :, 0].to(torch.int64) # [k1*k2, num_cam] matched mask id
match_info_sum = match_info.sum(dim=1) # [k1*k2, (matched_mask_id, matched_score, b_matched)]
leaf_ave_score = match_info_sum[:, 1] / (match_info_sum[:, 2]+ 1e-6) # [k1*k2] ave score
leaf_occu_count = match_info_sum[:, 2] # [k1*k2] number of matches for each leaf
# accumulated 2D features of each leaf
per_leaf_feat_sum = torch.zeros(root_num * leaf_num, 512).cuda() # [k1*k2, 512]
for v_id, view in enumerate(sorted_train_cameras):
if not view.data_on_gpu:
view.to_gpu()
if sam_level == 0:
start_id = 0
end_id = view.original_sam_mask[sam_level].max().to(torch.int64) + 1
else:
start_id = view.original_sam_mask[sam_level-1].max().to(torch.int64) + 1
end_id = view.original_sam_mask[sam_level].max().to(torch.int64) + 1
curr_view_lang_feat = view.original_mask_feat[start_id:end_id, :] # [num_mask, 512]
curr_view_lang_feat = torch.cat((torch.zeros_like(curr_view_lang_feat[0]).unsqueeze(0), \
curr_view_lang_feat)) # note: [num_mask+1, 512] add a feature with all 0s, i.e., the feature with id=0.
# current feat [k1*k2, 512]
single_view_leaf_feat = curr_view_lang_feat[leaf_per_view_matched_mask[:, v_id]]
# accumulate
per_leaf_feat_sum += single_view_leaf_feat
if view.data_on_gpu and save_memory:
view.to_cpu()
# average language features [k1*k2, 512]
per_leaf_feat = per_leaf_feat_sum / (leaf_occu_count + 1e-4).unsqueeze(1)
# save per_leaf_feat[k1*k2, 512], leaf_ave_score[k1*k2], leaf_occu_count[k1*k2], cluster_indices[num_pts]
np.savez(f'{scene.model_path}/cluster_lang.npz',leaf_feat=per_leaf_feat.cpu().numpy(), \
leaf_score=leaf_ave_score.cpu().numpy(), \
occu_count=leaf_occu_count.cpu().numpy(), \
leaf_ind=cluster_indices.cpu().numpy())
def training_report(tb_writer, iteration, Ll1, loss, l1_loss, elapsed, testing_iterations, \
start_root_cb_iter, scene : Scene, renderFunc, renderArgs):
if tb_writer:
tb_writer.add_scalar('train_loss_patches/l1_loss', Ll1.item(), iteration)
tb_writer.add_scalar('train_loss_patches/total_loss', loss.item(), iteration)
tb_writer.add_scalar('iter_time', elapsed, iteration)
# Report test and samples of training set
if iteration in testing_iterations:
torch.cuda.empty_cache()
validation_configs = ({'name': 'test', 'cameras' : scene.getTestCameras()},
{'name': 'train', 'cameras' : [scene.getTrainCameras()[idx % len(scene.getTrainCameras())] for idx in range(5, 30, 5)]})
for config in validation_configs:
if config['cameras'] and len(config['cameras']) > 0:
l1_test = 0.0
psnr_test = 0.0
for idx, viewpoint in enumerate(config['cameras']):
image = torch.clamp(renderFunc(viewpoint, scene.gaussians, *renderArgs)["render"], 0.0, 1.0)
gt_image = torch.clamp(viewpoint.original_image.to("cuda"), 0.0, 1.0)
if tb_writer and (idx < 5):
tb_writer.add_images(config['name'] + "_view_{}/render".format(viewpoint.image_name), image[None], global_step=iteration)
if iteration == testing_iterations[0]:
tb_writer.add_images(config['name'] + "_view_{}/ground_truth".format(viewpoint.image_name), gt_image[None], global_step=iteration)
l1_test += l1_loss(image, gt_image).mean().double()
psnr_test += psnr(image, gt_image).mean().double()
psnr_test /= len(config['cameras'])
l1_test /= len(config['cameras'])
print("\n[ITER {}] Evaluating {}: L1 {} PSNR {}".format(iteration, config['name'], l1_test, psnr_test))
sys.stdout.flush()
if tb_writer:
tb_writer.add_scalar(config['name'] + '/loss_viewpoint - l1_loss', l1_test, iteration)
tb_writer.add_scalar(config['name'] + '/loss_viewpoint - psnr', psnr_test, iteration)
if tb_writer:
tb_writer.add_histogram("scene/opacity_histogram", scene.gaussians.get_opacity, iteration)
tb_writer.add_scalar('total_points', scene.gaussians.get_xyz.shape[0], iteration)
torch.cuda.empty_cache()
# initialize new gaussian parameters. modify -----
def initialize_new_params(new_pt_cld, mean3_sq_dist):
num_pts = new_pt_cld.shape[0]
means3D = new_pt_cld[:, :3] # [num_gaussians, 3]
unnorm_rots = np.tile([1, 0, 0, 0], (num_pts, 1)) # [num_gaussians, 4]
logit_opacities = torch.zeros((num_pts, 1), dtype=torch.float, device="cuda")
logit_ins_feat = torch.zeros((num_pts, 3), dtype=torch.float, device="cuda")
# color [N, 3, 16]
max_sh_degree = 3
fused_color = RGB2SH(new_pt_cld[:, 3:6])
features = torch.zeros((fused_color.shape[0], 3, (max_sh_degree + 1) ** 2)).float().cuda() # [N, 3, 16]
features[:, :3, 0 ] = fused_color
features[:, 3:, 1:] = 0.0
params = {
'new_xyz': means3D,
'new_features_dc': features[:,:,0:1].transpose(1, 2).contiguous(),
'new_features_rest':features[:,:,1:].transpose(1, 2).contiguous(),
'new_opacities': logit_opacities,
# 'new_scaling': torch.tile(torch.log(torch.sqrt(mean3_sq_dist))[..., None], (1, 1)),
'new_scaling': torch.tile(torch.log(torch.sqrt(mean3_sq_dist))[..., None], (1, 3)),
'new_rotation': unnorm_rots,
'new_ins_feat': logit_ins_feat,
}
for k, v in params.items():
# Check if value is already a torch tensor
if not isinstance(v, torch.Tensor):
params[k] = torch.nn.Parameter(torch.tensor(v).cuda().float().contiguous().requires_grad_(True))
else:
params[k] = torch.nn.Parameter(v.cuda().float().contiguous().requires_grad_(True))
return params
# modify -----
if __name__ == "__main__":
# Set up command line argument parser
parser = ArgumentParser(description="Training script parameters")
lp = ModelParams(parser)
op = OptimizationParams(parser)
pp = PipelineParams(parser)
parser.add_argument('--ip', type=str, default="127.0.0.1")
parser.add_argument('--port', type=int, default=6009)
parser.add_argument('--debug_from', type=int, default=-1)
parser.add_argument('--detect_anomaly', action='store_true', default=False)
parser.add_argument("--test_iterations", nargs="+", type=int, default=[30_000])
parser.add_argument("--save_iterations", nargs="+", type=int, default=[30_000])
parser.add_argument("--quiet", action="store_true")
parser.add_argument("--checkpoint_iterations", nargs="+", type=int, default=[])
parser.add_argument("--start_checkpoint", type=str, default = None)
args = parser.parse_args(sys.argv[1:])
args.save_iterations.append(args.iterations)
args.checkpoint_iterations.append(args.iterations)
print("Optimizing " + args.model_path)
# Initialize system state (RNG)
safe_state(args.quiet)
# Start GUI server, configure and run training
network_gui.init(args.ip, args.port)
torch.autograd.set_detect_anomaly(args.detect_anomaly)
training(lp.extract(args), op.extract(args), pp.extract(args), \
args.test_iterations, args.save_iterations, args.checkpoint_iterations, \
args.start_checkpoint, args.debug_from)
# All done
print("\nTraining complete.")
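# NOTE (illustrative sketch, not part of train.py): the Stage 3 association above
# scores every (leaf cluster, pseudo mask) pair as IoU * (1 - L1 feature distance)
# and keeps the best mask per leaf only if the score exceeds 0.2. The snippet below
# replays that rule on plain Python lists. The helper names `joint_scores` and
# `best_match` and the sample numbers are assumptions made for this example only.

```python
def joint_scores(ious, l1_dis):
    """Element-wise IoU * (1 - L1 distance), mirroring `scores = ious * (1-l1_dis)`."""
    return [[iou * (1.0 - d) for iou, d in zip(iou_row, d_row)]
            for iou_row, d_row in zip(ious, l1_dis)]


def best_match(scores, threshold=0.2):
    """Per leaf (row): best mask index, its score, and whether it passed the threshold.

    Rows that fail the threshold are reset to index 0 and score 0, as in the
    `max_score[~b_matched] *= 0` / `max_ind[~b_matched] *= 0` lines above.
    """
    result = []
    for row in scores:
        max_score = max(row)
        max_ind = row.index(max_score)
        matched = max_score > threshold
        if not matched:
            max_ind, max_score = 0, 0.0
        result.append((max_ind, max_score, matched))
    return result


# Two leaf clusters (rows) against two pseudo masks (columns); made-up values.
ious = [[0.9, 0.1], [0.2, 0.3]]
l1_dis = [[0.1, 0.5], [0.6, 0.5]]
matches = best_match(joint_scores(ious, l1_dis))
```

# In this toy input the first leaf matches mask 0, while the second leaf has no
# pair above the threshold and is therefore recorded as unmatched.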
================================================
FILE: utils/camera_utils.py
================================================
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact george.drettakis@inria.fr
#
from scene.cameras import Camera
import numpy as np
from utils.general_utils import PILtoTorch
from utils.graphics_utils import fov2focal
import torch
WARNED = False
def loadCam(args, id, cam_info, resolution_scale):
orig_w, orig_h = cam_info.image.size
if args.resolution in [1, 2, 4, 8]:
resolution = round(orig_w/(resolution_scale * args.resolution)), round(orig_h/(resolution_scale * args.resolution))
else: # should be a type that converts to float
if args.resolution == -1:
if orig_w > 1600:
global WARNED
if not WARNED:
print("[ INFO ] Encountered quite large input images (>1.6K pixels width), rescaling to 1.6K.\n "
"If this is not desired, please explicitly specify '--resolution/-r' as 1")
WARNED = True
global_down = orig_w / 1600
else:
global_down = 1
else:
global_down = orig_w / args.resolution
scale = float(global_down) * float(resolution_scale)
resolution = (int(orig_w / scale), int(orig_h / scale))
resized_image_rgb = PILtoTorch(cam_info.image, resolution) # [C, H, W]
# NOTE: load SAM mask. modify -----
if cam_info.sam_mask is not None:
# step = int(args.resolution/2)
step = int(max(args.resolution, 1))
gt_sam_mask = cam_info.sam_mask[:, ::step, ::step] # downsample for mask
gt_sam_mask = torch.from_numpy(gt_sam_mask)
# align resolution
if resized_image_rgb.shape[1] != gt_sam_mask.shape[1]:
resolution = (gt_sam_mask.shape[2], gt_sam_mask.shape[1]) # modify -----
resized_image_rgb = PILtoTorch(cam_info.image, resolution) # [C, H, W]
else:
gt_sam_mask = None
if cam_info.mask_feat is not None:
mask_feat = torch.from_numpy(cam_info.mask_feat)
else:
mask_feat = None
# modify -----
gt_image = resized_image_rgb[:3, ...]
loaded_mask = None
# if resized_image_rgb.shape[1] == 4:
if resized_image_rgb.shape[0] == 4:
loaded_mask = resized_image_rgb[3:4, ...]
return Camera(colmap_id=cam_info.uid, R=cam_info.R, T=cam_info.T,
FoVx=cam_info.FovX, FoVy=cam_info.FovY,
cx=cam_info.cx/args.resolution, cy=cam_info.cy/args.resolution,
image=gt_image, depth=None, gt_alpha_mask=loaded_mask,
gt_sam_mask=gt_sam_mask, gt_mask_feat=mask_feat,
image_name=cam_info.image_name, uid=id, data_device=args.data_device)
def cameraList_from_camInfos(cam_infos, resolution_scale, args):
camera_list = []
for id, c in enumerate(cam_infos):
camera_list.append(loadCam(args, id, c, resolution_scale))
return camera_list
def camera_to_JSON(id, camera : Camera):
Rt = np.zeros((4, 4))
Rt[:3, :3] = camera.R.transpose()
Rt[:3, 3] = camera.T
Rt[3, 3] = 1.0
W2C = np.linalg.inv(Rt)
pos = W2C[:3, 3]
rot = W2C[:3, :3]
serializable_array_2d = [x.tolist() for x in rot]
camera_entry = {
'id' : id,
'img_name' : camera.image_name,
'width' : camera.width,
'height' : camera.height,
'position': pos.tolist(),
'rotation': serializable_array_2d,
'fy' : fov2focal(camera.FovY, camera.height),
'fx' : fov2focal(camera.FovX, camera.width)
}
return camera_entry
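# NOTE (illustrative sketch, not part of camera_utils.py): the resolution logic at
# the top of loadCam can be read in isolation as the function below. The name
# `target_resolution` is an assumption for this example; the sketch leaves out the
# one-time warning print and the later re-alignment to the SAM mask size.

```python
def target_resolution(orig_w, orig_h, resolution, resolution_scale=1.0):
    """Return (width, height) following the branch structure of loadCam."""
    if resolution in [1, 2, 4, 8]:
        # Fixed integer downsampling factor.
        return (round(orig_w / (resolution_scale * resolution)),
                round(orig_h / (resolution_scale * resolution)))
    if resolution == -1:
        # Automatic mode: cap the width at 1600 pixels.
        global_down = orig_w / 1600 if orig_w > 1600 else 1
    else:
        # Any other value is interpreted as the desired width in pixels.
        global_down = orig_w / resolution
    scale = float(global_down) * float(resolution_scale)
    return (int(orig_w / scale), int(orig_h / scale))


half = target_resolution(1920, 1080, 2)
capped = target_resolution(3200, 2400, -1)
small = target_resolution(800, 600, -1)
by_width = target_resolution(2000, 1000, 500)
```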
================================================
FILE: utils/general_utils.py
================================================
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact george.drettakis@inria.fr
#
import torch
import sys
from datetime import datetime
import numpy as np
import random
def inverse_sigmoid(x):
return torch.log(x/(1-x))
def PILtoTorch(pil_image, resolution):
resized_image_PIL = pil_image.resize(resolution)
resized_image = torch.from_numpy(np.array(resized_image_PIL)) / 255.0
if len(resized_image.shape) == 3:
return resized_image.permute(2, 0, 1)
else:
return resized_image.unsqueeze(dim=-1).permute(2, 0, 1)
def get_expon_lr_func(
lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
):
"""
Copied from Plenoxels
Continuous learning rate decay function. Adapted from JaxNeRF
The returned rate is lr_init when step=0 and lr_final when step=max_steps, and
is log-linearly interpolated elsewhere (equivalent to exponential decay).
If lr_delay_steps>0 then the learning rate will be scaled by some smooth
function of lr_delay_mult, such that the initial learning rate is
lr_init*lr_delay_mult at the beginning of optimization but will be eased back
to the normal learning rate when steps>lr_delay_steps.
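:param lr_init: float, learning rate at step 0.
:param lr_final: float, learning rate at step max_steps.
:param lr_delay_steps: int, number of delay (warm-up) steps; 0 disables the delay.
:param lr_delay_mult: float, multiplier applied to the rate at the start of the delay.
:param max_steps: int, the number of steps during optimization.
:return: a function that takes step as input and returns the learning rate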
"""
def helper(step):
if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
# Disable this parameter
return 0.0
if lr_delay_steps > 0:
# A kind of reverse cosine decay.
delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)
)
else:
delay_rate = 1.0
t = np.clip(step / max_steps, 0, 1)
log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
return delay_rate * log_lerp
return helper
def strip_lowerdiag(L):
uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda")
uncertainty[:, 0] = L[:, 0, 0]
uncertainty[:, 1] = L[:, 0, 1]
uncertainty[:, 2] = L[:, 0, 2]
uncertainty[:, 3] = L[:, 1, 1]
uncertainty[:, 4] = L[:, 1, 2]
uncertainty[:, 5] = L[:, 2, 2]
return uncertainty
def strip_symmetric(sym):
return strip_lowerdiag(sym)
def build_rotation(r):
norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3])
q = r / norm[:, None]
R = torch.zeros((q.size(0), 3, 3), device='cuda')
r = q[:, 0]
x = q[:, 1]
y = q[:, 2]
z = q[:, 3]
R[:, 0, 0] = 1 - 2 * (y*y + z*z)
R[:, 0, 1] = 2 * (x*y - r*z)
R[:, 0, 2] = 2 * (x*z + r*y)
R[:, 1, 0] = 2 * (x*y + r*z)
R[:, 1, 1] = 1 - 2 * (x*x + z*z)
R[:, 1, 2] = 2 * (y*z - r*x)
R[:, 2, 0] = 2 * (x*z - r*y)
R[:, 2, 1] = 2 * (y*z + r*x)
R[:, 2, 2] = 1 - 2 * (x*x + y*y)
return R
def build_scaling_rotation(s, r):
L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
R = build_rotation(r)
L[:,0,0] = s[:,0]
L[:,1,1] = s[:,1]
L[:,2,2] = s[:,2]
L = R @ L
return L
def safe_state(silent):
old_f = sys.stdout
class F:
def __init__(self, silent):
self.silent = silent
def write(self, x):
if not self.silent:
if x.endswith("\n"):
old_f.write(x.replace("\n", " [{}]\n".format(str(datetime.now().strftime("%d/%m %H:%M:%S")))))
else:
old_f.write(x)
def flush(self):
old_f.flush()
sys.stdout = F(silent)
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.cuda.set_device(torch.device("cuda:0"))
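# NOTE (illustrative sketch, not part of general_utils.py): the schedule built by
# get_expon_lr_func interpolates log-linearly between lr_init and lr_final. The
# stdlib-only re-statement below (the name `expon_lr` and the sample rates are
# assumptions for this example) makes the end points and the midpoint easy to check.

```python
import math


def expon_lr(step, lr_init, lr_final, max_steps, lr_delay_steps=0, lr_delay_mult=1.0):
    """Learning rate at `step`, following the same formula as `helper` above."""
    if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
        return 0.0
    if lr_delay_steps > 0:
        ramp = min(max(step / lr_delay_steps, 0.0), 1.0)
        delay_rate = lr_delay_mult + (1 - lr_delay_mult) * math.sin(0.5 * math.pi * ramp)
    else:
        delay_rate = 1.0
    t = min(max(step / max_steps, 0.0), 1.0)
    log_lerp = math.exp(math.log(lr_init) * (1 - t) + math.log(lr_final) * t)
    return delay_rate * log_lerp


lr_start = expon_lr(0, 1e-2, 1e-4, 1000)
lr_mid = expon_lr(500, 1e-2, 1e-4, 1000)   # geometric mean of the two end points
lr_end = expon_lr(1000, 1e-2, 1e-4, 1000)
lr_delayed = expon_lr(0, 1e-2, 1e-4, 1000, lr_delay_steps=100, lr_delay_mult=0.01)
```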
================================================
FILE: utils/graphics_utils.py
================================================
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact george.drettakis@inria.fr
#
import torch
import math
import numpy as np
from typing import NamedTuple
class BasicPointCloud(NamedTuple):
points : np.array
colors : np.array
normals : np.array
def geom_transform_points(points, transf_matrix):
P, _ = points.shape
ones = torch.ones(P, 1, dtype=points.dtype, device=points.device)
points_hom = torch.cat([points, ones], dim=1)
points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0))
denom = points_out[..., 3:] + 0.0000001
return (points_out[..., :3] / denom).squeeze(dim=0)
def getWorld2View(R, t):
Rt = np.zeros((4, 4))
Rt[:3, :3] = R.transpose()
Rt[:3, 3] = t
Rt[3, 3] = 1.0
return np.float32(Rt)
def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0):
Rt = np.zeros((4, 4))
Rt[:3, :3] = R.transpose()
Rt[:3, 3] = t
Rt[3, 3] = 1.0
C2W = np.linalg.inv(Rt)
cam_center = C2W[:3, 3]
cam_center = (cam_center + translate) * scale
C2W[:3, 3] = cam_center
Rt = np.linalg.inv(C2W)
return np.float32(Rt)
def getProjectionMatrix(znear, zfar, fovX, fovY):
tanHalfFovY = math.tan((fovY / 2))
tanHalfFovX = math.tan((fovX / 2))
top = tanHalfFovY * znear
bottom = -top
right = tanHalfFovX * znear
left = -right
P = torch.zeros(4, 4)
z_sign = 1.0
P[0, 0] = 2.0 * znear / (right - left)
P[1, 1] = 2.0 * znear / (top - bottom)
P[0, 2] = (right + left) / (right - left)
P[1, 2] = (top + bottom) / (top - bottom)
P[3, 2] = z_sign
P[2, 2] = z_sign * zfar / (zfar - znear)
P[2, 3] = -(zfar * znear) / (zfar - znear)
return P
def fov2focal(fov, pixels):
return pixels / (2 * math.tan(fov / 2))
def focal2fov(focal, pixels):
return 2*math.atan(pixels/(2*focal))
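# NOTE (illustrative sketch, not part of graphics_utils.py): fov2focal and focal2fov
# are inverses of each other for a fixed image size. The check below repeats the two
# definitions so it runs on its own; the 90-degree field of view and the 800-pixel
# width are made-up sample values.

```python
import math


def fov2focal(fov, pixels):
    return pixels / (2 * math.tan(fov / 2))


def focal2fov(focal, pixels):
    return 2 * math.atan(pixels / (2 * focal))


fov = math.radians(90.0)
focal = fov2focal(fov, 800)          # tan(45 deg) = 1, so focal is about 400 px
fov_back = focal2fov(focal, 800)     # round trip back to the original angle
```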
================================================
FILE: utils/image_utils.py
================================================
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact george.drettakis@inria.fr
#
import torch
def mse(img1, img2):
return (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
def psnr(img1, img2):
mse = (((img1 - img2)) ** 2).view(img1.shape[0], -1).mean(1, keepdim=True)
return 20 * torch.log10(1.0 / torch.sqrt(mse))
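# NOTE (illustrative sketch, not part of image_utils.py): for images in [0, 1] the
# psnr function above evaluates 20 * log10(1 / sqrt(MSE)), which equals
# -10 * log10(MSE). The scalar version below (helper names and pixel values are
# assumptions for this example) shows both forms agreeing on a tiny two-pixel case.

```python
import math


def mse_flat(img1, img2):
    """Mean squared error between two equally long lists of pixel values."""
    return sum((a - b) ** 2 for a, b in zip(img1, img2)) / len(img1)


def psnr_from_mse(mse_value):
    """Same expression as in psnr above, applied to a scalar MSE."""
    return 20 * math.log10(1.0 / math.sqrt(mse_value))


error = mse_flat([0.5, 0.5], [0.4, 0.6])   # each pixel is off by 0.1
psnr_db = psnr_from_mse(error)
psnr_alt = -10 * math.log10(error)
```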
================================================
FILE: utils/loss_utils.py
================================================
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact george.drettakis@inria.fr
#
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from math import exp
def l1_loss(network_output, gt, mask=None, weight=None):
if mask is None:
return torch.abs((network_output - gt)).mean()
else:
if weight is None:
weight = torch.ones_like(mask)
return torch.abs((network_output - gt) * mask * weight).sum() / mask.sum().clamp(min=1)
def l2_loss(network_output, gt, mask=None, weight=None):
if mask is None:
return ((network_output - gt) ** 2).mean()
else:
if weight is None:
weight = torch.ones_like(mask)
return ((network_output - gt) ** 2 * mask * weight).sum() / mask.sum().clamp(min=1)
def gaussian(window_size, sigma):
gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
return gauss / gauss.sum()
def create_window(window_size, channel):
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
return window
def ssim(img1, img2, window_size=11, size_average=True):
channel = img1.size(-3)
window = create_window(window_size, channel)
if img1.is_cuda:
window = window.cuda(img1.get_device())
window = window.type_as(img1)
return _ssim(img1, img2, window, window_size, channel, size_average)
def _ssim(img1, img2, window, window_size, channel, size_average=True):
mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1 * mu2
sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
C1 = 0.01 ** 2
C2 = 0.03 ** 2
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
if size_average:
return ssim_map.mean()
else:
return ssim_map.mean(1).mean(1).mean(1)
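# NOTE (illustrative sketch, not part of loss_utils.py): the SSIM window is the outer
# product of a normalized 1D Gaussian with itself. The list-based version of
# `gaussian` below (an assumption for this example, avoiding the torch dependency)
# shows the properties the window relies on: it sums to 1, peaks at the center,
# and is symmetric.

```python
from math import exp


def gaussian(window_size, sigma):
    """Normalized 1D Gaussian weights centered at window_size // 2."""
    weights = [exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2))
               for x in range(window_size)]
    total = sum(weights)
    return [w / total for w in weights]


# Same window size and sigma as used by ssim / create_window above.
window = gaussian(11, 1.5)
peak_index = window.index(max(window))
```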
================================================
FILE: utils/opengs_utlis.py
================================================
import torch
import numpy as np
import torch.nn.functional as F
import os
from bitarray import bitarray
from collections import OrderedDict
def calculate_pairwise_distances(tensor1, tensor2, metric=None):
"""
Calculate L1 (Manhattan) and L2 (Euclidean) distances between every pair of vectors
in two tensors of shape [m, 6] and [n, 6].
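Args:
tensor1 (torch.Tensor): A tensor of shape [m, 6].
tensor2 (torch.Tensor): Another tensor of shape [n, 6].
metric (str, optional): "l1" or "l2" to compute only that distance; None computes both.
Returns:
torch.Tensor: L1 distances of shape [m, n] (None if metric == "l2").
torch.Tensor: L2 distances of shape [m, n] (None if metric == "l1").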
"""
# Reshape tensors to allow broadcasting
# tensor1 shape becomes [m, 1, 6] and tensor2 shape becomes [1, n, 6]
tensor1 = tensor1.unsqueeze(1) # Now tensor1 is [m, 1, 6]
tensor2 = tensor2.unsqueeze(0) # Now tensor2 is [1, n, 6]
# Compute the L1 distance
if metric == "l1":
return torch.abs(tensor1 - tensor2).sum(dim=2), None # Result is [m, n]
# Compute the L2 distance
if metric == "l2":
return None, torch.sqrt((tensor1 - tensor2).pow(2).sum(dim=2)) # Result is [m, n]
l1_distances = torch.abs(tensor1 - tensor2).sum(dim=2)
l2_distances = torch.sqrt((tensor1 - tensor2).pow(2).sum(dim=2))
return l1_distances, l2_distances
def calculate_distances(tensor1, tensor2, metric=None):
"""
Calculate L1 (Manhattan) and L2 (Euclidean) distances between corresponding vectors
in two tensors of shape [N, dim].
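Args:
tensor1 (torch.Tensor): A tensor of shape [N, dim].
tensor2 (torch.Tensor): Another tensor of shape [N, dim].
metric (str, optional): "l1" or "l2" to return only that distance as a single tensor; None returns both.
Returns:
torch.Tensor: L1 distances of shape [N].
torch.Tensor: L2 distances of shape [N].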
"""
# Compute L1 distance
if metric == "l1":
return torch.abs(tensor1 - tensor2).sum(dim=1)
# Compute L2 distance
if metric == "l2":
return torch.sqrt((tensor1 - tensor2).pow(2).sum(dim=1))
l1_distances = torch.abs(tensor1 - tensor2).sum(dim=1)
l2_distances = torch.sqrt((tensor1 - tensor2).pow(2).sum(dim=1))
return l1_distances, l2_distances
def bin2dec(b, bits):
"""Convert binary b to decimal integer.
Code from: https://stackoverflow.com/questions/55918468/convert-integer-to-pytorch-tensor-of-binary-bits
"""
mask = 2 ** torch.arange(bits - 1, -1, -1).to(b.device, torch.int64)
return torch.sum(mask * b, -1)
def load_code_book(base_path):
inds_file = os.path.join(base_path, 'kmeans_inds.bin')
codebook_file = os.path.join(base_path, 'kmeans_centers.pth')
args_file = os.path.join(base_path, 'kmeans_args.npy')
codebook = torch.load(codebook_file) # [num_cluster, dim]
args_dict = np.load(args_file, allow_pickle=True).item()
quant_params = args_dict['params']
loaded_bitarray = bitarray()
with open(inds_file, 'rb') as file:
loaded_bitarray.fromfile(file)
# bitarray pads 0s if array is not divisible by 8. ignore extra 0s at end when loading
total_len = args_dict['total_len']
loaded_bitarray = loaded_bitarray[:total_len].tolist()
indices = np.reshape(loaded_bitarray, (-1, args_dict['n_bits']))
indices = bin2dec(torch.from_numpy(indices), args_dict['n_bits'])
indices = np.reshape(indices.cpu().numpy(), (len(quant_params), -1))
indices_dict = OrderedDict()
for i, key in enumerate(args_dict['params']):
indices_dict[key] = indices[i]
return codebook, indices_dict['ins_feat']
def calculate_iou(masks1, masks2, base=None):
"""
Calculate the Intersection over Union (IoU) between two sets of masks.
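Args:
masks1: PyTorch tensor of shape [n, H, W]; cast to torch.bool if needed.
masks2: PyTorch tensor of shape [m, H, W]; cast to torch.bool if needed.
base: None for the standard IoU; "former" divides the intersection by the area of masks1, "later" by the area of masks2.
Returns:
iou_matrix: PyTorch tensor of shape [m, n], containing IoU values.
"""
# Ensure the masks are of type torch.bool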
if masks1.dtype != torch.bool:
masks1 = masks1.to(torch.bool)
if masks2.dtype != torch.bool:
masks2 = masks2.to(torch.bool)
# Expand masks to broadcastable shapes
masks1_expanded = masks1.unsqueeze(0) # [1, n, H, W]
masks2_expanded = masks2.unsqueeze(1) # [m, 1, H, W]
# Compute intersection
intersection = (masks1_expanded & masks2_expanded).float().sum(dim=(2, 3)) # [m, n]
# Compute the denominator (the union by default, or the area of one mask set)
if base == "former":
union = (masks1_expanded).float().sum(dim=(2, 3)) + 1e-6 # [1, n], area of masks1
elif base == "later":
union = (masks2_expanded).float().sum(dim=(2, 3)) + 1e-6 # [m, 1], area of masks2
else:
union = (masks1_expanded | masks2_expanded).float().sum(dim=(2, 3)) + 1e-6 # [m, n]
# Compute IoU
iou_matrix = intersection / union
return iou_matrix
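# NOTE (illustrative sketch, not part of opengs_utlis.py): the `base` argument of
# calculate_iou only changes the denominator. The single-pair version below works on
# flat 0/1 lists instead of tensors; the name `mask_overlap` and the sample masks are
# assumptions for this example.

```python
def mask_overlap(mask_a, mask_b, base=None):
    """Overlap ratio of two flat binary masks, with the same `base` options."""
    inter = sum(1 for a, b in zip(mask_a, mask_b) if a and b)
    if base == "former":
        denom = sum(mask_a)                      # area of the first mask
    elif base == "later":
        denom = sum(mask_b)                      # area of the second mask
    else:
        denom = sum(1 for a, b in zip(mask_a, mask_b) if a or b)  # union
    return inter / (denom + 1e-6)


mask_a = [1, 1, 0, 0]
mask_b = [0, 1, 1, 1]
iou = mask_overlap(mask_a, mask_b)                     # 1 / 4
cover_a = mask_overlap(mask_a, mask_b, base="former")  # 1 / 2
cover_b = mask_overlap(mask_a, mask_b, base="later")   # 1 / 3
```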
def get_SAM_mask_and_feat(gt_sam_mask, level=3, filter_th=50, original_mask_feat=None, sample_mask=False):
"""
input:
gt_sam_mask[4, H, W]: mask id
output:
mask_id[H, W]: The ID of the mask each pixel belongs to (0 indicates invalid pixels)
mask_bool[num_mask, H, W]: Boolean; the 0th mask (invalid pixels) is excluded from the return value
mask_feat[num_mask, 512]: language features of this level's masks (returned only if original_mask_feat is given)
invalid_pix[H, W]: Boolean, invalid pixels
"""
# (1) mask id: -1, 0, 1, 2,...
mask_id = gt_sam_mask[level].clone()
if level > 0:
# subtract the maximum mask ID of the previous level
mask_id = mask_id - (gt_sam_mask[level-1].max().detach().cpu()+1)
if mask_id.min() < 0:
mask_id = mask_id.clamp_min(-1) # -1, 0~num_mask
mask_id += 1 # 0, 1~num_mask+1
invalid_pix = mask_id==0 # invalid pixels
# (2) mask id[H, W] -> one-hot/mask_bool [num_mask+1, H, W]
instance_num = mask_id.max()
one_hot = F.one_hot(mask_id.type(torch.int64), num_classes=int(instance_num.item() + 1))
# bool mask [num+1, H, W]
mask_bool = one_hot.permute(2, 0, 1)
# # TODO modify -------- only keep the largest 50
# if instance_num > 50:
# top50_values, _ = torch.topk(mask_bool.sum(dim=(1,2)), 50, largest=True)
# filter_th = top50_values[-1].item()
# # modify --------
# # TODO: not used
# # (3) delete small mask
# saved_idx = mask_bool.sum(dim=(1,2)) >= filter_th # default 50 pixels
# # Random sampling, not actually used
# if sample_mask:
# prob = torch.rand(saved_idx.shape[0])
# sample_ind = prob > 0.5
# saved_idx = saved_idx & sample_ind.cuda()
# saved_idx[0] = True # Keep the mask for invalid points, ensuring that mask_id == 0 corresponds to invalid pixels.
# mask_bool = mask_bool[saved_idx] # [num_filt, H, W]
# update mask id
mask_id = torch.argmax(mask_bool, dim=0) # [H, W]; pixels of any filtered-out mask get ID 0 (the filter above is currently disabled)
invalid_pix = mask_id==0
# TODO not used!
# (4) Get the language features corresponding to the masks (used for 2D-3D association in the third stage)
if original_mask_feat is not None:
mask_feat = original_mask_feat.clone() # [num_mask, 512]
max_ind = int(gt_sam_mask[level].max())+1
min_ind = int(gt_sam_mask[level-1].max())+1 if level > 0 else 0
mask_feat = mask_feat[min_ind:max_ind, :]
# # update mask feat
# mask_feat = mask_feat[saved_idx[1:]] # The 0th element of saved_idx is the mask corresponding to invalid pixels and has no features
return mask_id, mask_bool[1:, :, :], mask_feat, invalid_pix
return mask_id, mask_bool[1:, :, :], invalid_pix
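# Illustrative sketch (not part of the original file): the ID remapping done in
# step (1) above, on a flat list instead of an [H, W] tensor. It assumes IDs
# are global across levels and that -1 marks pixels without a mask;
# `remap_level_ids` is a hypothetical helper name.

```python
def remap_level_ids(ids, prev_level_max):
    """ids: global mask IDs for one level (-1 = no mask).
    prev_level_max: largest ID of the previous level (use -1 for level 0)."""
    out = []
    for i in ids:
        local = i - (prev_level_max + 1)   # shift so this level starts at 0
        local = max(local, -1)             # anything below collapses to -1
        out.append(local + 1)              # 0 = invalid, 1.. = mask index
    return out


# The previous level used IDs 0..4, so this level's masks are 5, 6 and 7.
local_ids = remap_level_ids([-1, 5, 6, 7, 5], prev_level_max=4)
invalid = [i == 0 for i in local_ids]
```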
def pair_mask_feature_mean(feat_map, masks):
""" mean feat of N masks
feat_map: [N, C, H, W]
masks: [N, H, W]
mean_values: [N, C]
"""
N, C, H, W = feat_map.shape
# [N, H, W] -> [N, C, H, W]
expanded_masks = masks.unsqueeze(1).expand(-1, C, -1, -1)
# [N, C, H, W]
masked_features = feat_map * expanded_masks.float()
# pixels
mask_counts = expanded_masks.sum(dim=[2, 3]) + 1e-6
# mean feat [N, C]
mean_values = masked_features.sum(dim=[2, 3]) / mask_counts
return mean_values
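# Chunked equivalent of torch.where(mask, feats - mean[..., None, None], 0): walks the [num_mask, C] grid in chunk_size x chunk_size blocks to limit peak memory.
def process_in_chunks(masks_expanded, masked_feats, mean_per_channel, chunk_size=5):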
result = torch.zeros_like(masked_feats)
for i in range(0, masks_expanded.size(0), chunk_size):
end_i = min(i + chunk_size, masks_expanded.size(0))
for j in range(0, masks_expanded.size(1), chunk_size):
end_j = min(j + chunk_size, masks_expanded.size(1))
chunk_mask = masks_expanded[i:end_i, j:end_j]
chunk_feats = masked_feats[i:end_i, j:end_j]
chunk_mean = mean_per_channel[i:end_i, j:end_j].unsqueeze(-1).unsqueeze(-1)
result[i:end_i, j:end_j] = torch.where(chunk_mask.bool(), chunk_feats - chunk_mean, torch.zeros_like(chunk_feats))
return result
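# Chunked per-(mask, channel) variance: the sum of squared deviations over H and W divided by the pixel count of each mask.
def calculate_variance_in_chunks(masked_for_variance, mask_counts, chunk_size=5):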
variance_per_channel = torch.zeros(masked_for_variance.size(0), masked_for_variance.size(1), device=masked_for_variance.device)
for i in range(0, masked_for_variance.size(0), chunk_size):
end_i = min(i + chunk_size, masked_for_variance.size(0))
for j in range(0, masked_for_variance.size(1), chunk_size):
end_j = min(j + chunk_size, masked_for_variance.size(1))
chunk_masked_for_variance = masked_for_variance[i:end_i, j:end_j]
chunk_variance = (chunk_masked_for_variance ** 2).sum(dim=[2, 3]) / mask_counts[i:end_i, j:end_j]
variance_per_channel[i:end_i, j:end_j] = chunk_variance
return variance_per_channel
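# Chunked element-wise product feat * mask over the [num_mask, C] grid, so the float copy of the mask is only materialized one block at a time.
def ele_multip_in_chunks(feat_expanded, masks_expanded, chunk_size=5):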
result = torch.zeros_like(feat_expanded)
for i in range(0, feat_expanded.size(0), chunk_size):
end_i = min(i + chunk_size, feat_expanded.size(0))
for j in range(0, feat_expanded.size(1), chunk_size):
end_j = min(j + chunk_size, feat_expanded.size(1))
chunk_feat = feat_expanded[i:end_i, j:end_j]
chunk_mask = masks_expanded[i:end_i, j:end_j].float()
result[i:end_i, j:end_j] = chunk_feat * chunk_mask
return result
def mask_feature_mean(feat_map, gt_masks, image_mask=None, return_var=False):
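"""Compute the average instance features within each mask.
feat_map: [C=6, H, W] the instance features of the entire image
gt_masks: [num_mask, H, W] num_mask boolean masks
image_mask: optional image-level mask; when given, only pixels inside it are counted
return_var: if True, also return the per-mask variance (averaged over channels) and the per-mask pixel count
returns: mean_per_channel [num_mask, C], or (mean_per_channel, variance [num_mask], mask_counts [num_mask]) when return_var is True
"""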
num_mask, H, W = gt_masks.shape
# expand feat and masks for batch processing
feat_expanded = feat_map.unsqueeze(0).expand(num_mask, *feat_map.shape) # [num_mask, C, H, W]
masks_expanded = gt_masks.unsqueeze(1).expand(-1, feat_map.shape[0], -1, -1) # [num_mask, C, H, W]
if image_mask is not None: # image level mask
image_mask_expanded = image_mask.unsqueeze(0).expand(num_mask, feat_map.shape[0], -1, -1)
# average features within each mask
if image_mask is not None:
masked_feats = feat_expanded * masks_expanded.float() * image_mask_expanded.float()
mask_counts = (masks_expanded * image_mask_expanded.float()).sum(dim=(2, 3))
else:
# masked_feats = feat_expanded * masks_expanded.float() # [num_mask, C, H, W] may cause OOM
masked_feats = ele_multip_in_chunks(feat_expanded, masks_expanded, chunk_size=5) # in chunks to avoid OOM
mask_counts = masks_expanded.sum(dim=(2, 3)) # [num_mask, C]
# the number of pixels within each mask
mask_counts = mask_counts.clamp(min=1)
# the mean features of each mask
sum_per_channel = masked_feats.sum(dim=[2, 3])
mean_per_channel = sum_per_channel / mask_counts # [num_mask, C]
if not return_var:
return mean_per_channel # [num_mask, C]
else:
# calculate variance
# masked_for_variance = torch.where(masks_expanded.bool(), masked_feats - mean_per_channel.unsqueeze(-1).unsqueeze(-1), torch.zeros_like(masked_feats))
masked_for_variance = process_in_chunks(masks_expanded, masked_feats, mean_per_channel, chunk_size=5) # in chunks to avoid OOM
# variance_per_channel = (masked_for_variance ** 2).sum(dim=[2, 3]) / mask_counts # [num_mask, 6]
variance_per_channel = calculate_variance_in_chunks(masked_for_variance, mask_counts, chunk_size=5) # in chunks to avoid OOM
# mean and variance
mean = mean_per_channel.mean(dim=1) # [num_mask], not used
variance = variance_per_channel.mean(dim=1) # [num_mask]
return mean_per_channel, variance, mask_counts[:, 0] # [num_mask, C], [num_mask], [num_mask]
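# Illustrative sketch (not part of the original file): the per-mask mean and
# variance computed above, written for one channel and one mask with plain
# lists. `masked_mean_var` is a hypothetical helper name; the count is clamped
# to at least 1, as in mask_feature_mean, so an empty mask yields zeros.

```python
def masked_mean_var(feat, mask):
    """feat: per-pixel values of one channel; mask: matching 0/1 flags."""
    count = max(sum(mask), 1)  # mirrors mask_counts.clamp(min=1)
    mean = sum(f * m for f, m in zip(feat, mask)) / count
    # Squared deviations are accumulated only over pixels inside the mask.
    var = sum(m * (f - mean) ** 2 for f, m in zip(feat, mask)) / count
    return mean, var, count


feat = [1.0, 3.0, 10.0, 20.0]
mean, var, count = masked_mean_var(feat, [1, 1, 0, 0])
empty = masked_mean_var(feat, [0, 0, 0, 0])
```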
def linear_to_srgb(linear):
if isinstance(linear, torch.Tensor):
"""Assumes `linear` is in [0, 1], see https://en.wikipedia.org/wiki/SRGB."""
eps = torch.finfo(torch.float32).eps
srgb0 = 323 / 25 * linear
srgb1 = (211 * torch.clamp(linear, min=eps)**(5 / 12) - 11) / 200
return torch.where(linear <= 0.0031308, srgb0, srgb1)
elif isinstance(linear, np.ndarray):
eps = np.finfo(np.float32).eps
srgb0 = 323 / 25 * linear
srgb1 = (211 * np.maximum(eps, linear) ** (5 / 12) - 11) / 200
return np.where(linear <= 0.0031308, srgb0, srgb1)
else:
raise NotImplementedError
def srgb_to_linear(srgb):
if isinstance(srgb, torch.Tensor):
"""Assumes `srgb` is in [0, 1], see https://en.wikipedia.org/wiki/SRGB."""
eps = torch.finfo(torch.float32).eps
linear0 = 25 / 323 * srgb
linear1 = torch.clamp(((200 * srgb + 11) / (211)), min=eps)**(12 / 5)
return torch.where(srgb <= 0.04045, linear0, linear1)
elif isinstance(srgb, np.ndarray):
"""Assumes `srgb` is in [0, 1], see https://en.wikipedia.org/wiki/SRGB."""
eps = np.finfo(np.float32).eps
linear0 = 25 / 323 * srgb
linear1 = np.maximum(((200 * srgb + 11) / (211)), eps)**(12 / 5)
return np.where(srgb <= 0.04045, linear0, linear1)
else:
raise NotImplementedError
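# Illustrative sketch (not part of the original file): scalar versions of the
# two transfer functions above, using the same rational constants
# (323/25 = 12.92, 211/200 = 1.055, 11/200 = 0.055). The `_scalar` names are
# hypothetical; inputs are assumed to lie in [0, 1].

```python
def linear_to_srgb_scalar(x):
    if x <= 0.0031308:
        return 323 / 25 * x                      # linear segment near black
    return (211 * x ** (5 / 12) - 11) / 200      # gamma segment


def srgb_to_linear_scalar(s):
    if s <= 0.04045:
        return 25 / 323 * s                      # inverse of the linear segment
    return ((200 * s + 11) / 211) ** (12 / 5)    # inverse of the gamma segment


mid_gray = linear_to_srgb_scalar(0.5)            # roughly 0.735
round_trip = srgb_to_linear_scalar(mid_gray)     # back to 0.5 up to rounding
```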
================================================
FILE: utils/sh_utils.py
================================================
# Copyright 2021 The PlenOctree Authors.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
import torch
C0 = 0.28209479177387814
C1 = 0.4886025119029199
C2 = [
1.0925484305920792,
-1.0925484305920792,
0.31539156525252005,
-1.0925484305920792,
0.5462742152960396
]
C3 = [
-0.5900435899266435,
2.890611442640554,
-0.4570457994644658,
0.3731763325901154,
-0.4570457994644658,
1.445305721320277,
-0.5900435899266435
]
C4 = [
2.5033429417967046,
-1.7701307697799304,
0.9461746957575601,
-0.6690465435572892,
0.10578554691520431,
-0.6690465435572892,
0.47308734787878004,
-1.7701307697799304,
0.6258357354491761,
]
def eval_sh(deg, sh, dirs):
"""
Evaluate spherical harmonics at unit directions
using hardcoded SH polynomials.
Works with torch/np/jnp.
... Can be 0 or more batch dimensions.
Args:
deg: int SH deg. Currently, 0-4 supported
sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]
dirs: jnp.ndarray unit directions [..., 3]
Returns:
[..., C]
"""
assert deg <= 4 and deg >= 0
coeff = (deg + 1) ** 2
assert sh.shape[-1] >= coeff
result = C0 * sh[..., 0]
if deg > 0:
x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
result = (result -
C1 * y * sh[..., 1] +
C1 * z * sh[..., 2] -
C1 * x * sh[..., 3])
if deg > 1:
xx, yy, zz = x * x, y * y, z * z
xy, yz, xz = x * y, y * z, x * z
result = (result +
C2[0] * xy * sh[..., 4] +
C2[1] * yz * sh[..., 5] +
C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
C2[3] * xz * sh[..., 7] +
C2[4] * (xx - yy) * sh[..., 8])
if deg > 2:
result = (result +
C3[0] * y * (3 * xx - yy) * sh[..., 9] +
C3[1] * xy * z * sh[..., 10] +
C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] +
C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
C3[5] * z * (xx - yy) * sh[..., 14] +
C3[6] * x * (xx - 3 * yy) * sh[..., 15])
if deg > 3:
result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
return result
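# Illustrative sketch (not part of the original file): the degree-0 and
# degree-1 terms of eval_sh for a single channel and a single direction, in
# plain Python. The constants are rebuilt from their closed forms to show
# where C0 and C1 come from; `eval_sh_deg1` is a hypothetical helper name.

```python
import math

C0 = 0.5 * math.sqrt(1.0 / math.pi)          # 0.2820947917...
C1 = math.sqrt(3.0 / (4.0 * math.pi))        # 0.4886025119...


def eval_sh_deg1(sh, direction):
    """sh: 4 coefficients of one channel; direction: unit vector (x, y, z)."""
    x, y, z = direction
    # Same sign pattern as eval_sh: -y, +z, -x for coefficients 1, 2, 3.
    return C0 * sh[0] - C1 * y * sh[1] + C1 * z * sh[2] - C1 * x * sh[3]


# Looking along +z, only the constant term and the z term contribute.
value_up = eval_sh_deg1([1.0, 0.0, 2.0, 0.0], (0.0, 0.0, 1.0))
```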
def RGB2SH(rgb):
return (rgb - 0.5) / C0
def SH2RGB(sh):
return sh * C0 + 0.5
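# Illustrative sketch (not part of the original file): RGB2SH and SH2RGB are
# inverses of each other and map a colour to and from the degree-0 (DC) SH
# coefficient. Scalar re-implementation with hypothetical lower-case names.

```python
C0 = 0.28209479177387814


def rgb_to_sh(rgb):
    return (rgb - 0.5) / C0


def sh_to_rgb(sh):
    return sh * C0 + 0.5


dc = rgb_to_sh(0.8)        # DC coefficient that encodes the value 0.8
back = sh_to_rgb(dc)       # recovers 0.8 up to floating-point rounding
gray = sh_to_rgb(0.0)      # a zero coefficient decodes to mid-gray
```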
================================================
FILE: utils/system_utils.py
================================================
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact george.drettakis@inria.fr
#
from errno import EEXIST
from os import makedirs, path
import os
def mkdir_p(folder_path):
# Creates a directory, including missing parents; equivalent to running `mkdir -p` on the command line
try:
makedirs(folder_path)
except OSError as exc: # Python >2.5
if exc.errno == EEXIST and path.isdir(folder_path):
pass
else:
raise
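# Illustrative sketch (not part of the original file): on Python 3 the same
# "create if missing" behaviour is available through os.makedirs with
# exist_ok=True. `mkdir_p_modern` is a hypothetical name for this variant.

```python
import os
import tempfile


def mkdir_p_modern(folder_path):
    # No error if the directory already exists; still raises if the path
    # exists but is not a directory.
    os.makedirs(folder_path, exist_ok=True)


with tempfile.TemporaryDirectory() as root:
    target = os.path.join(root, "a", "b")
    mkdir_p_modern(target)
    mkdir_p_modern(target)          # second call is a no-op
    created = os.path.isdir(target)
```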
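# Returns the largest iteration number among the entries of `folder`, each of which must be named "<prefix>_<integer>"; any other entry makes int() raise ValueError.
def searchForMaxIteration(folder):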
saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)]
return max(saved_iters)
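# Illustrative sketch (not part of the original file): a more tolerant variant
# that skips entries without a numeric suffix instead of raising ValueError.
# The function name and the "iteration_<n>" folder names below are assumptions
# made only for this example.

```python
import os
import re
import tempfile


def search_for_max_iteration_safe(folder):
    iters = []
    for name in os.listdir(folder):
        match = re.search(r"_(\d+)$", name)   # keep only "<prefix>_<digits>"
        if match:
            iters.append(int(match.group(1)))
    if not iters:
        raise FileNotFoundError(f"no '<prefix>_<iteration>' entries in {folder}")
    return max(iters)


with tempfile.TemporaryDirectory() as root:
    for name in ("iteration_7000", "iteration_30000", ".DS_Store"):
        os.mkdir(os.path.join(root, name))
    latest = search_for_max_iteration_safe(root)
```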