Showing preview only (1,762K chars total). Download the full file or copy to clipboard to get everything.
Repository: autonomousvision/factor-fields
Branch: main
Commit: 2a43edea0d17
Files: 48
Total size: 1.7 MB
Directory structure:
gitextract_iica752a/
├── .gitignore
├── 2D_regression.py
├── LICENSE
├── README.md
├── README_FactorField.md
├── configs/
│ ├── 360_v2.yaml
│ ├── defaults.yaml
│ ├── image.yaml
│ ├── image_intro.yaml
│ ├── image_set.yaml
│ ├── nerf.yaml
│ ├── nerf_ft.yaml
│ ├── nerf_set.yaml
│ ├── sdf.yaml
│ └── tnt.yaml
├── dataLoader/
│ ├── __init__.py
│ ├── blender.py
│ ├── blender_set.py
│ ├── colmap.py
│ ├── colmap2nerf.py
│ ├── dtu_objs.py
│ ├── dtu_objs2.py
│ ├── google_objs.py
│ ├── image.py
│ ├── image_set.py
│ ├── llff.py
│ ├── nsvf.py
│ ├── ray_utils.py
│ ├── sdf.py
│ ├── tankstemple.py
│ └── your_own_data.py
├── models/
│ ├── FactorFields.py
│ ├── __init__.py
│ └── sh.py
├── renderer.py
├── requirements.txt
├── run_batch.py
├── scripts/
│ ├── 2D_regression.ipynb
│ ├── 2D_set_regression.ipynb
│ ├── 2D_set_regression.py
│ ├── __init__.py
│ ├── formula_demostration.ipynb
│ ├── mesh2SDF_data_process.ipynb
│ └── sdf_regression.ipynb
├── train_across_scene.py
├── train_across_scene_ft.py
├── train_per_scene.py
└── utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
data/
slurm/
logs/
================================================
FILE: 2D_regression.py
================================================
import torch,imageio,sys,time,os,cmapy,scipy
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from omegaconf import OmegaConf
import torch.nn.functional as F
device = 'cuda'
sys.path.append('..')
from models.sparseCoding import sparseCoding
from dataLoader import dataset_dict
from torch.utils.data import DataLoader
def PSNR(a, b):
    """Compute the peak signal-to-noise ratio (in dB) between two images.

    Accepts either numpy arrays or torch tensors (dispatch is on the type of
    ``a``); pixel values are assumed to lie in [0, 1] (peak value 1.0).

    Args:
        a, b: images of identical shape (both numpy arrays or both tensors).

    Returns:
        float: PSNR in decibels; ``inf`` when the inputs are identical.
    """
    # Dispatch on array type: torch tensors need .item(), numpy does not.
    if type(a).__module__ == np.__name__:
        mse = np.mean((a - b) ** 2)
    else:
        mse = torch.mean((a - b) ** 2).item()
    if mse == 0:
        # Identical inputs: avoid the log-of-zero runtime warning.
        return float('inf')
    return -10.0 * np.log10(mse)
def rgb_ssim(img0, img1, max_val,
             filter_size=11,
             filter_sigma=1.5,
             k1=0.01,
             k2=0.03,
             return_map=False):
    """Compute SSIM between two HxWx3 numpy images.

    Modified from
    https://github.com/google/mipnerf/blob/16e73dfdb52044dcceb47cda5243a686391a6e0f/internal/math.py#L58

    Args:
        img0, img1: HxWx3 float arrays of identical shape.
        max_val: dynamic range of the pixel values (1.0 for [0, 1] images).
        filter_size: side length of the Gaussian window.
        filter_sigma: standard deviation of the Gaussian window.
        k1, k2: SSIM stabilization constants.
        return_map: if True, return the per-pixel SSIM map (valid region)
            instead of its mean.

    Returns:
        float mean SSIM, or the SSIM map when ``return_map`` is True.
    """
    # `import scipy` alone does not reliably expose the `signal` submodule;
    # import it explicitly so convolve2d is always available.
    from scipy import signal
    assert len(img0.shape) == 3
    assert img0.shape[-1] == 3
    assert img0.shape == img1.shape
    # Construct a 1D Gaussian blur filter.
    hw = filter_size // 2
    shift = (2 * hw - filter_size + 1) / 2
    f_i = ((np.arange(filter_size) - hw + shift) / filter_sigma) ** 2
    filt = np.exp(-0.5 * f_i)
    filt /= np.sum(filt)

    # Blur in x and y separately (faster than one 2D convolution).
    def convolve2d(z, f):
        return signal.convolve2d(z, f, mode='valid')

    filt_fn = lambda z: np.stack([
        convolve2d(convolve2d(z[..., i], filt[:, None]), filt[None, :])
        for i in range(z.shape[-1])], -1)
    mu0 = filt_fn(img0)
    mu1 = filt_fn(img1)
    mu00 = mu0 * mu0
    mu11 = mu1 * mu1
    mu01 = mu0 * mu1
    sigma00 = filt_fn(img0 ** 2) - mu00
    sigma11 = filt_fn(img1 ** 2) - mu11
    sigma01 = filt_fn(img0 * img1) - mu01
    # Clip the variances and covariances to valid values.
    # Variance must be non-negative:
    sigma00 = np.maximum(0., sigma00)
    sigma11 = np.maximum(0., sigma11)
    sigma01 = np.sign(sigma01) * np.minimum(
        np.sqrt(sigma00 * sigma11), np.abs(sigma01))
    c1 = (k1 * max_val) ** 2
    c2 = (k2 * max_val) ** 2
    numer = (2 * mu01 + c1) * (2 * sigma01 + c2)
    denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2)
    ssim_map = numer / denom
    ssim = np.mean(ssim_map)
    return ssim_map if return_map else ssim
@torch.no_grad()
def eval_img(aabb, reso, shiftment=[0.5, 0.5], chunk=10240):
    """Evaluate the model over a dense pixel grid and return the image.

    Builds a (reso[0] * reso[1], 2) grid of (x, y) sample coordinates over
    ``aabb``, feeds them through the model in chunks, and assembles the
    predictions into an image.

    NOTE(review): relies on the module-level ``model`` and ``train_dataset``
    globals set in the training script below.

    Args:
        aabb: (height, width) extent of the coordinate domain.
        reso: (rows, cols) output resolution.
        shiftment: per-coordinate offset added to each sample (pixel center).
        chunk: number of coordinates evaluated per forward pass.

    Returns:
        (image, coordinates): the rendered (reso[0], reso[1], C) image and
        the flat (N, 2) coordinate tensor used to produce it.
    """
    rows = torch.linspace(0, aabb[0] - 1, reso[0])
    cols = torch.linspace(0, aabb[1] - 1, reso[1])
    grid_y, grid_x = torch.meshgrid((rows, cols), indexing='ij')
    coords = torch.stack((grid_x, grid_y), dim=-1).reshape(-1, 2) + torch.tensor(shiftment)

    out = torch.empty(reso[0] * reso[1], train_dataset.img.shape[-1])
    write_pos = 0
    for batch in tqdm(torch.split(coords, chunk, dim=0)):
        feats, _ = model.get_coding(batch.to(model.device))
        pred = model.linear_mat(feats, is_train=False)
        out[write_pos:write_pos + pred.shape[0]] = pred.cpu()
        write_pos += pred.shape[0]
    return out.view(reso[0], reso[1], -1), coords
def linear_to_srgb(img):
    """Convert linear-light values to sRGB via the standard piecewise curve."""
    threshold = 0.0031308
    # Below the threshold the transfer function is linear; above it, a
    # gamma curve with exponent 1/2.4 is used.
    linear_part = 12.92 * img
    gamma_part = 1.055 * np.power(img, 1.0 / 2.4) - 0.055
    return np.where(img > threshold, gamma_part, linear_part)
def write_image_imageio(img_file, img, colormap=None, quality=100):
    """Write an image array to disk, optionally applying a colormap.

    Args:
        img_file: output path; the extension selects the format.
        img: image array; non-uint8 input is min-max normalized to uint8.
        colormap: None, 'turbo', or a cmapy/OpenCV colormap name.
        quality: JPEG quality, only used for .jpg/.jpeg output.
    """
    if colormap == 'turbo':
        # NOTE(review): `interpolate` and `turbo_colormap_data` are not
        # defined anywhere in this file -- this branch would raise a
        # NameError if it were ever reached. Confirm the intended helpers.
        shape = img.shape
        img = interpolate(turbo_colormap_data, img.reshape(-1)).reshape(*shape, -1)
    elif colormap is not None:
        # cmapy.colorize expects a uint8 image; presumably img is in [0, 1]
        # here -- verify against callers.
        img = cmapy.colorize((img * 255).astype('uint8'), colormap)
    if img.dtype != 'uint8':
        # Min-max normalize to [0, 1], then quantize to 8 bits.
        img = (img - np.min(img)) / (np.max(img) - np.min(img))
        img = (img * 255.0).astype(np.uint8)
    kwargs = {}
    if os.path.splitext(img_file)[1].lower() in [".jpg", ".jpeg"]:
        if img.ndim >= 3 and img.shape[2] > 3:
            # JPEG has no alpha channel; keep only the first three channels.
            img = img[:, :, :3]
        kwargs["quality"] = quality
        kwargs["subsampling"] = 0
    imageio.imwrite(img_file, img, **kwargs)
if __name__ == '__main__':
    # Per-image 2D regression benchmark: fits a sparseCoding model to each
    # image of a 256-image set, then writes the reconstruction and
    # PSNR/SSIM/runtime metrics for every image.
    torch.set_default_dtype(torch.float32)
    # Fixed seeds for reproducible sampling and initialization.
    torch.manual_seed(20211202)
    np.random.seed(20211202)

    # Config precedence: defaults.yaml < image.yaml < command-line overrides.
    base_conf = OmegaConf.load('configs/defaults.yaml')
    cli_conf = OmegaConf.from_cli()
    second_conf = OmegaConf.load('configs/image.yaml')
    cfg = OmegaConf.merge(base_conf, second_conf, cli_conf)
    print(cfg)

    folder = cfg.defaults.expname
    # NOTE(review): hard-coded, machine-specific output directory.
    save_root = f'/vlg-nfs/anpei/project/NeuBasis/ours/images/'
    dataset = dataset_dict[cfg.dataset.dataset_name]

    # Candidate crop regions for inpainting-style experiments; only used if
    # the dataset is constructed with delete_region (commented out below).
    delete_region = [[290,350,48,48],[300,380,48,48],[180, 407, 48, 48], [223, 263, 48, 48], [233, 150, 48, 48], [374, 119, 48, 48], [4, 199, 48, 48], [180, 234, 48, 48], [173, 39, 48, 48], [408, 308, 48, 48], [227, 177, 48, 48], [46, 330, 48, 48], [213, 26, 48, 48], [90, 44, 48, 48], [295, 61, 48, 48]]
    continue_sampling = False

    psnrs,ssims = [],[]
    for i in range(1,257):
        # NOTE(review): hard-coded dataset path.
        cfg.dataset.datadir = f'/vlg-nfs/anpei/dataset/Images/crop//{i:04d}.png'
        name = os.path.basename(cfg.dataset.datadir).split('.')[0]
        # Skip images that were already processed in a previous run.
        if os.path.exists(f'{save_root}/{folder}/{int(name):04d}.png'):
            continue

        train_dataset = dataset(cfg.dataset, cfg.training.batch_size, split='train',tolinear=True, perscent=1.0,HW=1024)#, continue_sampling=continue_sampling,delete_region=delete_region
        train_loader = DataLoader(train_dataset,
                                  num_workers=2,
                                  persistent_workers=True,
                                  batch_size=None,
                                  pin_memory=False)
        # train_dataset.img = train_dataset.img.to(device)

        # The output channel count follows the loaded image (e.g. RGB -> 3).
        cfg.model.out_dim = train_dataset.img.shape[-1]
        batch_size = cfg.training.batch_size
        n_iter = cfg.training.n_iters

        H,W = train_dataset.HW
        train_dataset.scene_bbox = [[0., 0.], [W, H]]
        cfg.dataset.aabb = train_dataset.scene_bbox

        model = sparseCoding(cfg, device)
        if 1==i:
            # Print the model summary only once, on the first image.
            print(model)
            print('total parameters: ',model.n_parameters())

        # tvreg = TVLoss()
        # trainingSampler = SimpleSampler(len(train_dataset), cfg.training.batch_size)
        grad_vars = model.get_optparam_groups(lr_small=cfg.training.lr_small,lr_large=cfg.training.lr_large)
        optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))#

        # Exponential decay: the loss is rescaled each step so that the
        # effective learning rate decays to 10% over n_iter steps.
        # NOTE(review): the loop below runs up to 10000 iterations regardless
        # of cfg.training.n_iters -- confirm the two are meant to agree.
        loss_scale = 1.0
        lr_factor = 0.1 ** (1 / n_iter)

        # pbar = tqdm(range(10000))
        start = time.time()
        # for iteration in pbar:
        for (iteration, sample) in zip(range(10000),train_loader):
            loss_scale *= lr_factor
            # if iteration==5000:
            #     model.coeffs = torch.nn.Parameter(F.interpolate(model.coeffs.data, size=None, scale_factor=2.0, align_corners=True,mode='bilinear'))
            #     grad_vars = model.get_optparam_groups(lr_small=cfg.training.lr_small,lr_large=cfg.training.lr_large)
            #     optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))#
            #     model.set_optimizable(['mlp','basis'], False)

            coordiantes, pixel_rgb = sample['xy'], sample['rgb']
            feats,coeff = model.get_coding(coordiantes.to(device))
            # tv_loss = model.TV_loss(tvreg)
            y_recon = model.linear_mat(feats,is_train=True)
            # y_recon = torch.sum(feats,dim=-1,keepdim=True)

            # Plain MSE reconstruction loss on the sampled pixels.
            loss = torch.mean((y_recon.squeeze()-pixel_rgb.squeeze().to(device))**2)
            psnr = -10.0 * np.log(loss.item()) / np.log(10.0)

            # if iteration%100==0:
            #     pbar.set_description(
            #         f'Iteration {iteration:05d}:'
            #         + f' loss_dist = {loss.item():.8f}'
            #         # + f' tv_loss = {tv_loss.item():.6f}'
            #         + f' psnr = {psnr:.3f}'
            #     )

            # loss = loss + tv_loss
            # loss = loss + torch.mean(coeff.abs())*1e-2
            loss = loss * loss_scale

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # if iteration%100==0:
            #     model.normalize_basis()
        iteration_time = time.time()-start

        # Render the fitted model at a fixed 1024x1024 resolution.
        H,W = train_dataset.HW
        img,coordinate = eval_img(train_dataset.HW,[1024,1024])
        if continue_sampling:
            import torch.nn.functional as F
            # Resample the ground truth at the evaluated (continuous)
            # coordinates; grid_sample expects coordinates in [-1, 1].
            coordinate_tmp = (coordinate.view(1,1,-1,2))/torch.tensor([W,H])*2-1.0
            img_gt = F.grid_sample(train_dataset.img.view(1,H,W,-1).permute(0,3,1,2),coordinate_tmp, mode='bilinear',
                                   align_corners=False, padding_mode='border').reshape(-1,H,W).permute(1,2,0)
        else:
            img_gt = train_dataset.img.view(H,W,-1)

        psnrs.append(PSNR(img.clamp(0,1.),img_gt))
        ssims.append(rgb_ssim(img.clamp(0,1.),img_gt,1.0))
        # print(PSNR(img.clamp(0,1.),img_gt),iteration_time)

        # plt.figure(figsize=(10, 10))
        # plt.imshow(linear_to_srgb(img.clamp(0,1.)))
        print(i, psnrs[-1], ssims[-1])

        # Persist the reconstruction (converted to sRGB) and the metrics.
        os.makedirs(f'{save_root}/{folder}',exist_ok=True)
        write_image_imageio(f'{save_root}/{folder}/{int(name):04d}.png',linear_to_srgb(img.clamp(0,1.)))
        np.savetxt(f'{save_root}/{folder}/{int(name):04d}.txt',[psnrs[-1],ssims[-1],iteration_time,model.n_parameters()])
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2023 autonomousvision
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# Factor Fields
## [Project page](https://apchenstu.github.io/FactorFields/) | [Paper](https://arxiv.org/abs/2302.01226)
This repository contains a pytorch implementation for the paper: [Factor Fields: A Unified Framework for Neural Fields and Beyond](https://arxiv.org/abs/2302.01226) and [Dictionary Fields: Learning a Neural Basis Decomposition](https://arxiv.org/abs/2302.01226). Our work presents a novel framework for modeling and representing signals;
we have also observed that Dictionary Fields offer benefits such as improved **approximation quality**, **compactness**, **faster training speed**, and the ability to **generalize** to unseen images and 3D scenes.<br><br>
## Installation
#### Tested on Ubuntu 20.04 + Pytorch 1.13.0
Install environment:
```sh
conda create -n FactorFields python=3.9
conda activate FactorFields
conda install -c "nvidia/label/cuda-11.7.1" cuda-toolkit
conda install pytorch==1.13.0 torchvision==0.14.0 torchaudio==0.13.0 pytorch-cuda=11.7 -c pytorch -c nvidia
pip install -r requirements.txt
```
Optionally install [tiny-cuda-nn](https://github.com/NVlabs/tiny-cuda-nn), only needed if you want to run hash grid based representations.
```sh
conda install -c "nvidia/label/cuda-11.7.1" cuda-toolkit
pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch
```
# Quick Start
Please ensure that you download the corresponding dataset and extract its contents into the `data` folder.
## Image
* [Data - Image Set](https://huggingface.co/apchen/Factor_Fields/blob/main/images.zip)
The training script can be found at `scripts/2D_regression.ipynb`, and the configuration file is located at `configs/image.yaml`.
<p align="left">
<img src="media/Girl_with_a_Pearl_Earring.jpg" alt="Girl with a Pearl Earring" width="320">
</p>
## SDF
* [Data - Mesh set](https://huggingface.co/apchen/Factor_Fields/blob/main/SDFs.zip)
The training script can be found at `scripts/sdf_regression.ipynb`, and the configuration file is located at `configs/sdf.yaml`.
<img src="https://github.com/apchenstu/GIFs/blob/main/FactorField-statuette.gif" alt="GIF" width="500px">
## NeRF
* [Data - Synthetic-NeRF](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1)
* [Data-Tanks&Temples](https://dl.fbaipublicfiles.com/nsvf/dataset/TanksAndTemple.zip)
The training script can be found at `train_per_scene.py`:
```python
python train_per_scene.py configs/nerf.yaml defaults.expname=lego dataset.datadir=./data/nerf_synthetic/lego
```
<img src="https://github.com/apchenstu/GIFs/blob/main/FactorField-mic.gif" alt="GIF" width="500px">
## Generalization Image
* [Data - FFHQ](https://github.com/NVlabs/ffhq-dataset)
The training script can be found at `2D_set_regression.ipynb`
<p align="left">
<img src="media/inpainting.png" alt="Inpainting" width="640">
</p>
## Generalization NeRF
* [Data - Google Scanned Objects](https://drive.google.com/file/d/1w1Cs0yztH6kE3JIz7mdggvPGCwIKkVi2/view)
```python
python train_across_scene.py configs/nerf_set.yaml
```
<img src="https://github.com/apchenstu/GIFs/blob/main/FactorField-few-shot.gif" alt="GIF" width="500px">
## More examples
Command explanation with a nerf example:
* `model.basis_dims=[4, 4, 4, 2, 2, 2]` adjusts the number of levels and channels at each level, with a total of 6 levels and 18 channels.
* `model.basis_resos=[32, 51, 70, 89, 108, 128]` represents the resolution of the feature embeddings.
* `model.freq_bands=[2.0, 3.2, 4.4, 5.6, 6.8, 8.0]` indicates the frequency parameters applied at each level of the coordinate transformation function.
* `model.coeff_type` represents the coefficient field representations and can be one of the following: [none, x, grid, mlp, vec, cp, vm].
* `model.basis_type` represents the basis field representation and can be one of the following: [none, x, grid, mlp, vec, cp, vm, hash].
* `model.basis_mapping` represents the coordinate transformation and can be one of the following: [x, triangle, sawtooth, trigonometric]. Please note that if you want to use orthogonal projection, choose the cp or vm basis type, as they automatically utilize the orthogonal projection functions.
* `model.total_params` controls the total model size. It is important to note that the model's size capability is determined by model.basis_resos and model.basis_dims. The total_params parameter mainly affects the capability of the coefficients.
* `exportation.render_only` lets you render items after training by setting this flag to 1. Please also specify the `defaults.ckpt` label.
* `exportation....` you can specify whether to render the items of `[render_test, render_train, render_path, export_mesh]` after training by enabling the corresponding flag (setting it to 1).
Some pre-defined configurations (such as occNet, DVGO, nerf, iNGP, EG3D) can be found in `README_FactorField.md`.
## COPY RIGHT
* [Summer Day](https://www.rijksmuseum.nl/en/collection/SK-A-3005) - Credit goes to Johan Hendrik Weissenbruch and rijksmuseum.
* [Mars](https://solarsystem.nasa.gov/resources/933/true-colors-of-pluto/) - Credit goes to NASA.
* [Albert](https://cdn.loc.gov/service/pnp/cph/3b40000/3b46000/3b46000/3b46036v.jpg) - Credit goes to Orren Jack Turner.
* [Girl With a Pearl Earring](http://profoundism.com/free_licenses.html) - Renovation copyright Koorosh Orooj (CC BY-SA 4.0).
## Citation
If you find our code or paper helpful, please consider citing both of these papers:
```
@article{Chen2023factor,
title={Factor Fields: A Unified Framework for Neural Fields and Beyond},
author={Chen, Anpei and Xu, Zexiang and Wei, Xinyue and Tang, Siyu and Su, Hao and Geiger, Andreas},
journal={arXiv preprint arXiv:2302.01226},
year={2023}
}
@article{Chen2023SIGGRAPH,
title={{Dictionary Fields: Learning a Neural Basis Decomposition}},
author={Anpei, Chen and Zexiang, Xu and Xinyue, Wei and Siyu, Tang and Hao, Su and Andreas, Geiger},
booktitle={International Conference on Computer Graphics and Interactive Techniques (SIGGRAPH)},
year={2023}}
```
================================================
FILE: README_FactorField.md
================================================
## nerf reconstruction with Dictionary field
```python
for scene in ['ship', 'mic', 'chair', 'lego', 'drums', 'ficus', 'hotdog', 'materials']:
cmd = f'python train_basis.py configs/nerf.yaml defaults.expname={scene} ' \
f'dataset.datadir=./data/nerf_synthetic/{scene} '
```
## different model design choices
```python
choice_dict = {
'-grid': '', \
'-DVGO-like': 'model.basis_type=none model.coeff_reso=80', \
'-noC': 'model.coeff_type=none', \
'-SL':'model.basis_dims=[18] model.basis_resos=[70] model.freq_bands=[8.]', \
'-CP': f'model.coeff_type=vec model.basis_type=cp model.freq_bands=[1.,1.,1.,1.,1.,1.] model.basis_resos=[512,512,512,512,512,512] model.basis_dims=[32,32,32,32,32,32]', \
'-iNGP-like': 'model.basis_type=hash model.coeff_type=none', \
'-hash': f'model.basis_type=hash model.coef_init=1.0 ', \
'-sinc': f'model.basis_mapping=sinc', \
'-tria': f'model.basis_mapping=triangle', \
'-vm': f'model.coeff_type=vm model.basis_type=vm', \
'-mlpB': 'model.basis_type=mlp', \
'-mlpC': 'model.coeff_type=mlp', \
'-occNet': f'model.basis_type=x model.coeff_type=none model.basis_mapping=x model.num_layers=8 model.hidden_dim=256 ', \
'-nerf': f'model.basis_type=x model.coeff_type=none model.basis_mapping=trigonometric ' \
f'model.num_layers=8 model.hidden_dim=256 ' \
f'model.freq_bands=[1.,2.,4.,8.,16.,32.,64,128,256.,512.] model.basis_dims=[1,1,1,1,1,1,1,1,1,1] model.basis_resos=[1024,512,256,128,64,32,16,8,4,2]', \
'-hash-sl': f'model.basis_type=hash model.coef_init=1.0 model.basis_dims=[16] model.freq_bands=[8.] model.basis_resos=[64] ', \
'-vm-sl': f'model.coeff_type=vm model.basis_type=vm model.coef_init=1.0 model.basis_dims=[18] model.freq_bands=[1.] model.basis_resos=[64] model.total_params=1308416 ', \
'-DCT':'model.basis_type=fix-grid', \
}
for name in choice_dict.keys():
for scene in [ 'ship', 'mic', 'chair', 'lego', 'drums', 'ficus', 'hotdog', 'materials']:
cmd = f"python train_per_scene.py configs/nerf.yaml defaults.expname={scene}{name} dataset.datadir=./data/nerf_synthetic/{scene} {config}"
```
## generalized nerf
You can choose any of the following design choices for testing.
```python
choice_dict = {
'-grid': '', \
'-DVGO-like': 'model.basis_type=none model.coeff_reso=48',
'-SL':'model.basis_dims=[72] model.basis_resos=[48] model.freq_bands=[6.]', \
'-CP': f'model.coeff_type=vec model.basis_type=cp model.freq_bands=[1.,1.,1.,1.,1.,1.] model.basis_resos=[512,512,512,512,512,512] model.basis_dims=[32,32,32,32,32,32]', \
'-hash': f'model.basis_type=hash model.coef_init=1.0 ', \
'-sinc': f'model.basis_mapping=sinc', \
'-tria': f'model.basis_mapping=triangle', \
'-vm': f'model.coeff_type=vm model.basis_type=vm', \
'-mlpB': 'model.basis_type=mlp', \
'-mlpC': 'model.coeff_type=mlp', \
'-hash-sl': f'model.basis_type=hash model.coef_init=1.0 model.basis_dims=[16] model.freq_bands=[8.] model.basis_resos=[64] ', \
'-vm-sl': f'model.coeff_type=vm model.basis_type=vm model.coef_init=1.0 model.basis_dims=[18] model.freq_bands=[1.] model.basis_resos=[64] model.total_params=1308416 ', \
'-DCT':'model.basis_type=fix-grid', \
}
for name in choice_dict.keys(): #
cmd = f'python train_across_scene.py configs/nerf_set.yaml defaults.expname=google-obj{name} {config} ' \
f'training.volume_resoFinal=128 dataset.datadir=./data/google_scanned_objects/'
```
You can also fine tune of the trained model for a new scene:
```python
for views in [5]:#3,
for name in choice_dict.keys(): #
for scene in [183]:#183,199,298,467,957,244,963,527,
cmd = f'python train_across_scene_ft.py configs/nerf_ft.yaml defaults.expname=google_objs_{name}_{scene}_{views}_views ' \
f'{config} training.n_iters=10000 ' \
f'dataset.train_views={views} ' \
f'dataset.train_scene_list=[{scene}] ' \
f'dataset.test_scene_list=[{scene}] ' \
f'dataset.datadir=./data/google_scanned_objects/ ' \
f'defaults.ckpt=./logs/google-obj{name}//google-obj{name}.th'
```
# render path after optimization
```python
for views in [5]:
for name in choice_dict.keys(): #
config = commands[name].replace(",", "','")
for scene in [183]:#183,199,298,467,957,244,963,527,681,948
cmd = f'python train_across_scene.py configs/nerf_ft.yaml defaults.expname=google_objs_{name}_{scene}_{views}_views ' \
f'{config} training.n_iters=10000 ' \
f'dataset.train_views={views} exportation.render_only=True exportation.render_path=True exportation.render_test=False ' \
f'dataset.train_scene_list=[{scene}] ' \
f'dataset.test_scene_list=[{scene}] ' \
f'dataset.datadir=./data/google_scanned_objects/ ' \
f'defaults.ckpt=./logs/google_objs_{name}_{scene}_{views}_views//google_objs_{name}_{scene}_{views}_views.th'
```
================================================
FILE: configs/360_v2.yaml
================================================
defaults:
expname: basis_room_real_mask
logdir: ./logs
ckpt: null # help='specific weights npy file to reload for coarse network'
model:
basis_dims: [5,5,5,2,2,2]
basis_resos: [ 64, 83, 102, 121, 140, 160]
coeff_reso: 16
coef_init: 0.01
phases: [0.0]
coef_mode: bilinear
basis_mode: bilinear
freq_bands: [ 1.0000, 1.7689, 2.3526, 3.1290, 4.1616, 6.]
kernel_mapping_type: 'sawtooth'
in_dim: 3
out_dim: 32
num_layers: 2
hidden_dim: 128
dataset:
# loader options
dataset_name: llff # choices=['blender', 'llff', 'nsvf', 'dtu','tankstemple', 'own_data']
datadir: /home/anpei/code/NeuBasis/data/360_v2/room/
ndc_ray: 0
is_unbound: True
with_depth: 0
downsample_train: 4.0
downsample_test: 4.0
N_vis: 5
vis_every: 5000
training:
n_iters: 30000
batch_size: 4096
volume_resoInit: 128 # 128**3:
volume_resoFinal: 320 # 300**3
upsamp_list: [2000,3000,4000,5500]
update_AlphaMask_list: [2500]
shrinking_list: [-1]
L1_weight_inital: 0.0
L1_weight_rest: 0.0
TV_weight_density: 0.0
TV_weight_app: 0.00
exportation:
render_only: 0
render_test: 1
render_train: 0
render_path: 0
export_mesh: 0
export_mesh_only: 0
renderer:
shadingMode: MLP_Fea
num_layers: 3
hidden_dim: 128
fea2denseAct: 'relu'
density_shift: -10
distance_scale: 25.0
view_pe: 6
fea_pe: 2
lindisp: 0
perturb: 1 # help='set to 0. for no jitter, 1. for jitter'
step_ratio: 0.5
max_samples: 1600
alphaMask_thres: 0.04
rayMarch_weight_thres: 1e-3
================================================
FILE: configs/defaults.yaml
================================================
defaults:
expname: basis_lego
logedir: ./logs
mode: 'reconstruction'
progress_refresh_rate: 10
add_timestamp: 0
model:
basis_dims: [4,4,4,2,2,2]
basis_resos: [32,51,70,89,108,128]
coeff_reso: 32
total_params: 10744166
T_basis: 0
T_coeff: 0
coef_init: 1.0
coef_mode: bilinear
basis_mode: bilinear
freq_bands: [1.0000, 1.3300, 1.7689, 2.3526, 3.1290, 4.1616]
basis_mapping: 'sawtooth'
with_dropout: False
in_dim: 3
out_dim: 32
num_layers: 2
hidden_dim: 128
dataset:
# loader options
dataset_name: blender # choices=['blender', 'llff', 'nsvf', 'dtu','tankstemple', 'own_data']
datadir: ./data/nerf_synthetic/lego
with_depth: 0
downsample_train: 1.0
downsample_test: 1.0
is_unbound: False
training:
# training options
batch_size: 4096
n_iters: 30000
# learning rate
lr_small: 0.001
lr_large: 0.02
lr_decay_iters: -1
lr_decay_target_ratio: 0.1 # help = 'number of iterations the lr will decay to the target ratio; -1 will set it to n_iters'
lr_upsample_reset: 1 # help='reset lr to inital after upsampling'
# loss
L1_weight_inital: 0.0 # help='loss weight'
L1_weight_rest: 0
Ortho_weight: 0.0
TV_weight_density: 0.0
TV_weight_app: 0.0
# optimiziable
coeff: True
basis: True
linear_mat: True
renderModule: True
================================================
FILE: configs/image.yaml
================================================
defaults:
expname: basis_image
logdir: ./logs
mode: 'image'
ckpt: null # help='specific weights npy file to reload for coarse network'
model:
basis_dims: [32,32,32,16,16,16]
basis_resos: [32,51,70,89,108,128]
freq_bands: [2. , 3.2, 4.4, 5.6, 6.8, 8.]
total_params: 1426063 # albert
# total_params: 61445328 # pluto
# total_params: 71848800 #Girl_with_a_Pearl_Earring
# total_params: 37138096 # Weissenbruch_Jan_Hendrik_The_Shipping_Canal_at_Rijswijk.jpeg_base
coeff_type: 'grid'
basis_type: 'grid'
coef_init: 0.001
coef_mode: nearest
basis_mode: nearest
basis_mapping: 'sawtooth'
in_dim: 2
out_dim: 3
num_layers: 2
hidden_dim: 64
with_dropout: False
dataset:
# loader options
dataset_name: image
datadir: "../data/image/albert.exr"
# datadir: "../data/image//pluto.jpeg"
# datadir: "../data/image//Girl_with_a_Pearl_Earring.jpeg"
# datadir: "../data/image//Weissenbruch_Jan_Hendrik_The_Shipping_Canal_at_Rijswijk.jpeg"
training:
n_iters: 10000
batch_size: 102400
# learning rate
lr_small: 0.002
lr_large: 0.002
================================================
FILE: configs/image_intro.yaml
================================================
defaults:
expname: basis_image
logdir: ./logs
mode: 'demo'
ckpt: null # help='specific weights npy file to reload for coarse network'
model:
in_dim: 2
out_dim: 1
basis_dims: [32,32,32,16,16,16]
basis_resos: [32,51,70,89,108,128]
freq_bands: [2. , 3.2, 4.4, 5.6, 6.8, 8.]
# occNet
coeff_type: 'none'
basis_type: 'x'
basis_mapping: 'x'
num_layers: 8
hidden_dim: 256
# coef_init: 0.001
# coef_mode: nearest
# basis_mode: nearest
# basis_mapping: 'sawtooth'
with_dropout: False
dataset:
# loader options
dataset_name: image
datadir: ../data/image/cat_occupancy.png
training:
n_iters: 10000
batch_size: 102400
# learning rate
lr_small: 0.0002
lr_large: 0.0002
================================================
FILE: configs/image_set.yaml
================================================
defaults:
expname: basis_image
logdir: ./logs
mode: 'images'
ckpt: null # help='specific weights npy file to reload for coarse network'
model:
basis_dims: [32,32,32,16,16,16]
basis_resos: [32,51,70,89,108,128]
freq_bands: [2. , 3.2, 4.4, 5.6, 6.8, 8.]
coeff_reso: 32
total_params: 1024000
coef_init: 0.001
coef_mode: bilinear
basis_mode: bilinear
coeff_type: 'grid'
basis_type: 'grid'
in_dim: 3
out_dim: 3
num_layers: 2
hidden_dim: 64
with_dropout: True
dataset:
# loader options
dataset_name: images
datadir: data/ffhq/ffhq_512.npy
training:
n_iters: 300000
batch_size: 40960
# learning rate
lr_small: 0.002
lr_large: 0.002
================================================
FILE: configs/nerf.yaml
================================================
defaults:
expname: basis_lego
logdir: ./logs
mode: 'reconstruction'
ckpt: null # help='specific weights npy file to reload for coarse network'
model:
coeff_reso: 32
basis_dims: [4,4,4,2,2,2]
basis_resos: [32,51,70,89,108,128]
freq_bands: [2. , 3.2, 4.4, 5.6, 6.8, 8.]
coef_init: 1.0
phases: [0.0]
total_params: 5308416
coef_mode: bilinear
basis_mode: bilinear
coeff_type: 'grid'
basis_type: 'grid'
basis_mapping: 'sawtooth'
in_dim: 3
out_dim: 32
num_layers: 2
hidden_dim: 64
dataset:
# loader options
dataset_name: blender # choices=['blender', 'llff', 'nsvf', 'dtu','tankstemple', 'own_data']
datadir: ./data/nerf_synthetic/lego
ndc_ray: 0
with_depth: 0
downsample_train: 1.0
downsample_test: 1.0
N_vis: 5
vis_every: 100000
scene_reso: 768
training:
n_iters: 30000
batch_size: 4096
volume_resoInit: 128 # 128**3:
volume_resoFinal: 300 # 300**3
upsamp_list: [2000,3000,4000,5500,7000]
update_AlphaMask_list: [2500,4000]
shrinking_list: [500]
L1_weight_inital: 0.0
L1_weight_rest: 0.0
TV_weight_density: 0.000
TV_weight_app: 0.00
exportation:
render_only: 0
render_test: 1
render_train: 0
render_path: 0
export_mesh: 0
export_mesh_only: 0
renderer:
shadingMode: MLP_Fea
num_layers: 3
hidden_dim: 128
fea2denseAct: 'softplus'
density_shift: -10
distance_scale: 25.0
view_pe: 6
fea_pe: 2
lindisp: 0
perturb: 1 # help='set to 0. for no jitter, 1. for jitter'
step_ratio: 0.5
max_samples: 1200
alphaMask_thres: 0.02
rayMarch_weight_thres: 1e-3
================================================
FILE: configs/nerf_ft.yaml
================================================
defaults:
expname: basis
logdir: ./logs
mode: 'reconstructions'
ckpt: null # help='specific weights npy file to reload for coarse network'
model:
coeff_reso: 16
basis_dims: [16,16,16,8,8,8]
basis_resos: [32,51,70,89,108,128]
freq_bands: [2. , 3.2, 4.4, 5.6, 6.8, 8.]
with_dropout: True
coef_init: 1.0
phases: [0.0]
total_params: 5308416
coef_mode: bilinear
basis_mode: bilinear
coeff_type: 'grid'
basis_type: 'grid'
basis_mapping: 'sawtooth'
in_dim: 3
out_dim: 32
num_layers: 2
hidden_dim: 64
dataset:
# loader options
dataset_name: google_objs # choices=['blender', 'llff', 'nsvf', 'dtu','tankstemple', 'own_data']
datadir: /vlg-nfs/anpei/dataset/google_scanned_objects
ndc_ray: 0
train_scene_list: [100]
test_scene_list: [100]
train_views: 5
with_depth: 0
downsample_train: 1.0
downsample_test: 1.0
N_vis: 5
vis_every: 100000
scene_reso: 768
training:
n_iters: 5000
batch_size: 4096
  volume_resoInit: 128 # 128**3
volume_resoFinal: 300 # 300**3
upsamp_list: [2000,3000,4000]
update_AlphaMask_list: [1500]
shrinking_list: [-1]
L1_weight_inital: 0.0
L1_weight_rest: 0.0
TV_weight_density: 0.000
TV_weight_app: 0.00
  # optimizable parameter groups
coeff: True
basis: False
linear_mat: False
renderModule: False
exportation:
render_only: 0
render_test: 1
render_train: 0
render_path: 0
export_mesh: 0
export_mesh_only: 0
renderer:
shadingMode: MLP_Fea
num_layers: 3
hidden_dim: 128
fea2denseAct: 'softplus'
density_shift: -10
distance_scale: 25.0
view_pe: 6
fea_pe: 2
lindisp: 0
perturb: 1 # help='set to 0. for no jitter, 1. for jitter'
step_ratio: 0.5
max_samples: 1200
alphaMask_thres: 0.02
rayMarch_weight_thres: 1e-3
================================================
FILE: configs/nerf_set.yaml
================================================
defaults:
expname: basis_no_relu_lego
logdir: ./logs
mode: 'reconstructions'
ckpt: null # help='specific weights npy file to reload for coarse network'
model:
coeff_reso: 16
basis_dims: [16,16,16,8,8,8]
# basis_resos: [32,51,70,89,108,128]
basis_resos: [32,51,70,89,108,128]
# freq_bands: [1.0000, 1.7689, 2.3526, 3.1290, 4.1616, 6.]
freq_bands: [2. , 3.2, 4.4, 5.6, 6.8, 8.]
with_dropout: True
coef_init: 1.0
phases: [0.0]
total_params: 5308416
coef_mode: bilinear
basis_mode: bilinear
coeff_type: 'grid'
basis_type: 'grid'
basis_mapping: 'sawtooth'
in_dim: 3
out_dim: 32
num_layers: 2
hidden_dim: 64
dataset:
# loader options
dataset_name: google_objs # choices=['blender', 'llff', 'nsvf', 'dtu','tankstemple', 'own_data']
datadir: /vlg-nfs/anpei/dataset/google_scanned_objects
ndc_ray: 0
train_scene_list: [0,100]
test_scene_list: [0]
train_views: 100
with_depth: 0
downsample_train: 1.0
downsample_test: 1.0
N_vis: 5
vis_every: 100000
scene_reso: 768
training:
n_iters: 50000
batch_size: 4096
volume_resoInit: 128 # 128**3:
  volume_resoFinal: 256 # 256**3
upsamp_list: [2000,3000,4000,5500,7000]
update_AlphaMask_list: [-1]
shrinking_list: [-1]
L1_weight_inital: 0.0
L1_weight_rest: 0.0
TV_weight_density: 0.000
TV_weight_app: 0.00
exportation:
render_only: 0
render_test: 0
render_train: 0
render_path: 0
export_mesh: 0
export_mesh_only: 0
renderer:
shadingMode: MLP_Fea
num_layers: 3
hidden_dim: 128
fea2denseAct: 'softplus'
density_shift: -10
distance_scale: 25.0
view_pe: 6
fea_pe: 2
lindisp: 0
perturb: 1 # help='set to 0. for no jitter, 1. for jitter'
step_ratio: 0.5
max_samples: 1200
alphaMask_thres: 0.005
rayMarch_weight_thres: 1e-3
================================================
FILE: configs/sdf.yaml
================================================
defaults:
expname: basis_sdf
logdir: ./logs
mode: 'sdf'
ckpt: null # help='specific weights npy file to reload for coarse network'
model:
basis_dims: [4,4,4,2,2,2]
basis_resos: [32,51,70,89,108,128]
freq_bands: [2. , 3.2, 4.4, 5.6, 6.8, 8.]
total_params: 5313942
coeff_reso: 32
coef_init: 0.05
coef_mode: bilinear
basis_mode: bilinear
coeff_type: 'grid'
basis_type: 'grid'
kernel_mapping_type: 'sawtooth'
in_dim: 3
out_dim: 1
num_layers: 1
hidden_dim: 64
dataset:
# loader options
dataset_name: sdf
datadir: "../data/mesh/statuette_close.npy"
scene_reso: 384
training:
n_iters: 10000
batch_size: 40960
# learning rate
lr_small: 0.002
lr_large: 0.02
================================================
FILE: configs/tnt.yaml
================================================
defaults:
expname: basis_truck
logdir: ./logs
ckpt: null # help='specific weights npy file to reload for coarse network'
model:
coeff_reso: 32
# basis_dims: [8,4,2]
# basis_resos: [32,64,128]
# freq_bands: [3.0,4.7,6.8]
## 32.88
# basis_dims: [3, 3, 3, 3, 3, 3, 3]
# basis_resos: [64, 64, 64, 64, 64, 64, 64]
## freq_bands: [1.52727273, 2.58181818, 3.63636364, 4.69090909, 5.21818182, 6.27272727, 6.8]
# freq_bands: [2., 3., 4.,5.,6.,7.,8.]
# basis_dims: [5,5,5,2,2,2]
# basis_resos: [32,51,70,89,108,128]
# coeff_reso: 26
# freq_bands: [1.0000, 1.7689, 2.3526, 3.1290, 4.1616, 6.]
basis_dims: [4,4,4,2,2,2]
basis_resos: [32,51,70,89,108,128]
freq_bands: [2. , 3.2, 4.4, 5.6, 6.8, 8.]
# freq_bands: [2. , 2.8, 3.6, 4.4, 5.2, 6.]
coef_init: 1.0
phases: [0.0]
total_params: 5744166
coef_mode: bilinear
basis_mode: bilinear
kernel_mapping_type: 'sawtooth'
in_dim: 3
out_dim: 32
num_layers: 2
hidden_dim: 64
dataset:
# loader options
dataset_name: tankstemple # choices=['blender', 'llff', 'nsvf', 'dtu','tankstemple', 'own_data']
datadir: ./data/TanksAndTemple/Truck
ndc_ray: 0
with_depth: 0
downsample_train: 1.0
downsample_test: 1.0
N_vis: 5
vis_every: 100000
scene_reso: 768
training:
n_iters: 30000
batch_size: 4096
volume_resoInit: 128 # 128**3:
  volume_resoFinal: 320 # 320**3
# upsamp_list: [2000,4000,7000]
# update_AlphaMask_list: [2000,3000]
upsamp_list: [2000,3000,4000,5500,7000]
update_AlphaMask_list: [2500,4000]
shrinking_list: [500]
L1_weight_inital: 0.0
L1_weight_rest: 0.0
TV_weight_density: 0.0
TV_weight_app: 0.00
exportation:
render_only: 0
render_test: 1
render_train: 0
render_path: 0
export_mesh: 0
export_mesh_only: 0
renderer:
shadingMode: MLP_Fea
num_layers: 3
hidden_dim: 128
fea2denseAct: 'softplus'
density_shift: -10
distance_scale: 25.0
view_pe: 2
fea_pe: 2
lindisp: 0
perturb: 1 # help='set to 0. for no jitter, 1. for jitter'
step_ratio: 0.5
max_samples: 1200
alphaMask_thres: 0.005
rayMarch_weight_thres: 1e-3
================================================
FILE: dataLoader/__init__.py
================================================
from .llff import LLFFDataset
from .blender import BlenderDataset
from .nsvf import NSVF
from .tankstemple import TanksTempleDataset
from .your_own_data import YourOwnDataset
from .image import ImageDataset
from .image_set import ImageSetDataset
from .colmap import ColmapDataset
from .sdf import SDFDataset
from .blender_set import BlenderDatasetSet
from .google_objs import GoogleObjsDataset
from .dtu_objs import DTUDataset
# Registry mapping the `dataset_name` string from the config files to the
# Dataset class that loads it.  Note that the key 'images' (plural) selects
# ImageSetDataset, while 'image' selects the single-image ImageDataset.
dataset_dict = {'blender': BlenderDataset,
                'blender_set': BlenderDatasetSet,
                'llff': LLFFDataset,
                'tankstemple': TanksTempleDataset,
                'nsvf': NSVF,
                'own_data': YourOwnDataset,
                'image': ImageDataset,
                'images': ImageSetDataset,
                'sdf': SDFDataset,
                'colmap': ColmapDataset,
                'google_objs': GoogleObjsDataset,
                'dtu': DTUDataset}
================================================
FILE: dataLoader/blender.py
================================================
import torch, cv2
from torch.utils.data import Dataset
import json
from tqdm import tqdm
import os
from PIL import Image
from torchvision import transforms as T
from .ray_utils import *
class BlenderDataset(Dataset):
    """NeRF-synthetic (Blender) scenes.

    Loads camera poses and RGBA renders from transforms_<split>.json and
    precomputes one ray (origin + direction) per pixel.  With
    is_stack=False (the training default) all rays/colors are flattened
    into one big buffer that is sampled in random batches; with
    is_stack=True they stay grouped per image for evaluation.
    """

    def __init__(self, cfg, split='train', batch_size=4096, is_stack=None):
        # self.N_vis = N_vis
        self.split = split
        self.batch_size = batch_size
        self.root_dir = cfg.datadir
        # Default: flat ray buffer for training, per-image stacks otherwise.
        self.is_stack = is_stack if is_stack is not None else 'train'!=split
        self.downsample = cfg.get(f'downsample_{self.split}')
        # Blender renders are 800x800.
        self.img_wh = (int(800 / self.downsample), int(800 / self.downsample))
        self.define_transforms()
        # Unused here (see the commented-out rotation below in read_meta).
        self.rot = torch.tensor([[0.65561799, -0.65561799, 0.37460659],
                                 [0.73729737, 0.44876192, -0.50498052],
                                 [0.16296514, 0.60727077, 0.77760181]])
        self.scene_bbox = (np.array([[-1.0, -1.0, -1.0], [1.0, 1.0, 1.0]])).tolist()
        # self.scene_bbox = [[-0.8,-0.8,-0.22],[0.8,0.8,0.2]]
        # Blender uses a different camera convention than OpenCV (y/z flipped).
        self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
        self.read_meta()
        self.define_proj_mat()
        self.white_bg = True
        self.near_far = [2.0, 6.0]
        # self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3)
        # self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)

    def read_depth(self, filename):
        # read_pfm comes from .ray_utils
        depth = np.array(read_pfm(filename)[0], dtype=np.float32)  # (800, 800)
        return depth

    def read_meta(self):
        """Parse transforms_<split>.json, load every image and build the ray buffers."""
        with open(os.path.join(self.root_dir, f"transforms_{self.split}.json"), 'r') as f:
            self.meta = json.load(f)
        w, h = self.img_wh
        self.focal = 0.5 * 800 / np.tan(0.5 * self.meta['camera_angle_x'])  # original focal length
        self.focal *= self.img_wh[0] / 800  # modify focal length to match size self.img_wh
        # ray directions for all pixels, same for all images (same H, W, focal)
        self.directions = get_ray_directions(h, w, [self.focal, self.focal])  # (h, w, 3)
        self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)
        self.intrinsics = torch.tensor([[self.focal, 0, w / 2], [0, self.focal, h / 2], [0, 0, 1]]).float()
        self.image_paths = []
        self.poses = []
        self.all_rays = []
        self.all_rgbs = []
        self.all_masks = []
        self.all_depth = []
        # NOTE(review): this overwrites the configured downsample factor, so the
        # resize branch below never runs even though img_wh above was scaled —
        # looks unintended; confirm before relying on downsample != 1.
        self.downsample = 1.0
        img_eval_interval = 1  # if self.N_vis < 0 else len(self.meta['frames']) // self.N_vis
        idxs = list(range(0, len(self.meta['frames']), img_eval_interval))
        # idxs = idxs[:10] if self.split=='train' else idxs
        for i in tqdm(idxs, desc=f'Loading data {self.split} ({len(idxs)})'):  # img_list:#
            frame = self.meta['frames'][i]
            pose = np.array(frame['transform_matrix']) @ self.blender2opencv
            c2w = torch.FloatTensor(pose)
            # shrink camera distances so the scene fits the [-1,1] bbox
            c2w[:3,-1] /= 1.5
            self.poses += [c2w]
            image_path = os.path.join(self.root_dir, f"{frame['file_path']}.png")
            self.image_paths += [image_path]
            img = Image.open(image_path)
            if self.downsample != 1.0:
                img = img.resize(self.img_wh, Image.LANCZOS)
            img = self.transform(img)  # (4, h, w)
            img = img.view(4, -1).permute(1, 0)  # (h*w, 4) RGBA
            img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:])  # blend A to RGB
            self.all_rgbs += [img]
            rays_o, rays_d = get_rays(self.directions, c2w)  # both (h*w, 3)
            # rays_o, rays_d = rays_o@self.rot, rays_d@self.rot
            self.all_rays += [torch.cat([rays_o, rays_d], 1)]  # (h*w, 6)
        self.poses = torch.stack(self.poses)
        if not self.is_stack:
            self.all_rays = torch.cat(self.all_rays, 0)  # (len(self.meta['frames])*h*w, 3)
            self.all_rgbs = torch.cat(self.all_rgbs, 0)  # (len(self.meta['frames])*h*w, 3)
            # self.all_depth = torch.cat(self.all_depth, 0)  # (len(self.meta['frames])*h*w, 3)
        else:
            self.all_rays = torch.stack(self.all_rays, 0)  # (len(self.meta['frames]),h*w, 3)
            self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1, *self.img_wh[::-1],3)  # (len(self.meta['frames]),h,w,3)
            # self.all_masks = torch.stack(self.all_masks, 0).reshape(-1,*self.img_wh[::-1])  # (len(self.meta['frames]),h,w,3)
        # SimpleSampler (from .ray_utils) draws shuffled batches of pixel ids.
        self.sampler = SimpleSampler(np.prod(self.all_rgbs.shape[:-1]),self.batch_size)

    def define_transforms(self):
        # PIL image -> float tensor in [0, 1]
        self.transform = T.ToTensor()

    def define_proj_mat(self):
        # Per-view world-to-image projection: K @ w2c[:3]
        self.proj_mat = self.intrinsics.unsqueeze(0) @ torch.inverse(self.poses)[:, :3]

    # def world2ndc(self,points,lindisp=None):
    #     device = points.device
    #     return (points - self.center.to(device)) / self.radius.to(device)

    def __len__(self):
        # Training epochs have a fixed nominal length; test iterates real images.
        return len(self.all_rgbs) if self.split=='test' else 300000

    def __getitem__(self, idx):
        # idx is ignored during training: a fresh random ray batch is drawn.
        idx_rand = self.sampler.nextids()  # torch.randint(0,len(self.all_rays),(self.batch_size,))
        sample = {'rays': self.all_rays[idx_rand], 'rgbs': self.all_rgbs[idx_rand]}
        return sample
================================================
FILE: dataLoader/blender_set.py
================================================
import torch, cv2
from torch.utils.data import Dataset
import json
from tqdm import tqdm
import os
from PIL import Image
from torchvision import transforms as T
from .ray_utils import *
class BlenderDatasetSet(Dataset):
    """Blender-synthetic scenes for cross-scene (set) training.

    Identical loading pipeline to BlenderDataset, except each ray carries an
    extra scene-id coordinate and the scene bbox has a 4th (scene index)
    dimension.  There is no internal batch sampler; __getitem__ returns one
    image's worth of rays.
    """

    def __init__(self, cfg, split='train'):
        # self.N_vis = N_vis
        self.root_dir = cfg.datadir
        self.split = split
        # flat ray buffer for training, per-image stacks for evaluation
        self.is_stack = False if 'train'==split else True
        self.downsample = cfg.get(f'downsample_{self.split}')
        self.img_wh = (int(800 / self.downsample), int(800 / self.downsample))
        self.define_transforms()
        # Unused here (see the commented-out rotation below in read_meta).
        self.rot = torch.tensor([[0.65561799, -0.65561799, 0.37460659],
                                 [0.73729737, 0.44876192, -0.50498052],
                                 [0.16296514, 0.60727077, 0.77760181]])
        # 4th bbox component spans the scene-index range [0, 2).
        self.scene_bbox = (np.array([[-1.0, -1.0, -1.0, 0], [1.0, 1.0, 1.0, 2]])).tolist()
        # self.scene_bbox = [[-0.8,-0.8,-0.22],[0.8,0.8,0.2]]
        self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
        self.read_meta()
        self.define_proj_mat()
        self.white_bg = True
        self.near_far = [2.0, 6.0]
        # self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3)
        # self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)

    def read_depth(self, filename):
        # read_pfm comes from .ray_utils
        depth = np.array(read_pfm(filename)[0], dtype=np.float32)  # (800, 800)
        return depth

    def read_meta(self):
        """Parse transforms_<split>.json, load every image and build ray buffers
        where each ray is [origin(3), direction(3), scene_id(1)]."""
        with open(os.path.join(self.root_dir, f"transforms_{self.split}.json"), 'r') as f:
            self.meta = json.load(f)
        w, h = self.img_wh
        self.focal = 0.5 * 800 / np.tan(0.5 * self.meta['camera_angle_x'])  # original focal length
        self.focal *= self.img_wh[0] / 800  # modify focal length to match size self.img_wh
        # ray directions for all pixels, same for all images (same H, W, focal)
        self.directions = get_ray_directions(h, w, [self.focal, self.focal])  # (h, w, 3)
        self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)
        self.intrinsics = torch.tensor([[self.focal, 0, w / 2], [0, self.focal, h / 2], [0, 0, 1]]).float()
        self.image_paths = []
        self.poses = []
        self.all_rays = []
        self.all_rgbs = []
        self.all_masks = []
        self.all_depth = []
        # NOTE(review): overrides the configured downsample factor, making the
        # resize branch below dead code — confirm intended.
        self.downsample = 1.0
        img_eval_interval = 1  # if self.N_vis < 0 else len(self.meta['frames']) // self.N_vis
        idxs = list(range(0, len(self.meta['frames']), img_eval_interval))
        for i in tqdm(idxs, desc=f'Loading data {self.split} ({len(idxs)})'):  # img_list:#
            frame = self.meta['frames'][i]
            pose = np.array(frame['transform_matrix']) @ self.blender2opencv
            c2w = torch.FloatTensor(pose)
            # shrink camera distances so the scene fits the [-1,1] bbox
            c2w[:3,-1] /= 1.5
            self.poses += [c2w]
            image_path = os.path.join(self.root_dir, f"{frame['file_path']}.png")
            self.image_paths += [image_path]
            img = Image.open(image_path)
            if self.downsample != 1.0:
                img = img.resize(self.img_wh, Image.LANCZOS)
            img = self.transform(img)  # (4, h, w)
            img = img.view(4, -1).permute(1, 0)  # (h*w, 4) RGBA
            img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:])  # blend A to RGB
            self.all_rgbs += [img]
            rays_o, rays_d = get_rays(self.directions, c2w)  # both (h*w, 3)
            # rays_o, rays_d = rays_o@self.rot, rays_d@self.rot
            # This loader serves a single scene, so the scene id is always 0.
            scene_id = torch.ones_like(rays_o[...,:1])*0
            self.all_rays += [torch.cat([rays_o, rays_d, scene_id], 1)]  # (h*w, 6)
        self.poses = torch.stack(self.poses)
        if not self.is_stack:
            self.all_rays = torch.cat(self.all_rays, 0)  # (len(self.meta['frames])*h*w, 3)
            self.all_rgbs = torch.cat(self.all_rgbs, 0)  # (len(self.meta['frames])*h*w, 3)
            # self.all_depth = torch.cat(self.all_depth, 0)  # (len(self.meta['frames])*h*w, 3)
        else:
            self.all_rays = torch.stack(self.all_rays, 0)  # (len(self.meta['frames]),h*w, 3)
            self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1, *self.img_wh[::-1],3)  # (len(self.meta['frames]),h,w,3)
            # self.all_masks = torch.stack(self.all_masks, 0).reshape(-1,*self.img_wh[::-1])  # (len(self.meta['frames]),h,w,3)

    def define_transforms(self):
        # PIL image -> float tensor in [0, 1]
        self.transform = T.ToTensor()

    def define_proj_mat(self):
        # Per-view world-to-image projection: K @ w2c[:3]
        self.proj_mat = self.intrinsics.unsqueeze(0) @ torch.inverse(self.poses)[:, :3]

    # def world2ndc(self,points,lindisp=None):
    #     device = points.device
    #     return (points - self.center.to(device)) / self.radius.to(device)

    def __len__(self):
        return len(self.all_rgbs)

    def __getitem__(self, idx):
        # Appends the value 0.5 (presumably the centre of scene slot 0) to the
        # ray tensor.  NOTE(review): self.all_rays[idx] can be 2-D in the
        # stacked case while the appended tensor is 1-D — confirm shapes with
        # the consumer of this sample.
        rays = torch.cat((self.all_rays[idx],torch.tensor([0+0.5])),dim=-1)
        sample = {'rays': rays, 'rgbs': self.all_rgbs[idx]}
        return sample
================================================
FILE: dataLoader/colmap.py
================================================
import torch, cv2
from torch.utils.data import Dataset
from tqdm import tqdm
import os
from PIL import Image
from torchvision import transforms as T
from .ray_utils import *
class ColmapDataset(Dataset):
    """Scenes reconstructed with COLMAP and exported to NeRF-style JSON.

    Supports either a single transforms.json (in which case every 8th frame
    forms the test split) or separate transforms_train/transforms_test files.
    Camera poses are re-oriented using the COLMAP sparse point cloud
    (`orientation` from .ray_utils).
    """

    def __init__(self, cfg, split='train'):
        self.cfg = cfg
        self.root_dir = cfg.datadir
        self.split = split
        # flat ray buffer for training, per-image stacks for evaluation
        self.is_stack = False if 'train'==split else True
        self.downsample = cfg.get(f'downsample_{self.split}')
        self.define_transforms()
        # hold out every 8th frame when train/test share one transforms.json
        self.img_eval_interval = 8
        self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])#np.eye(4)#
        self.read_meta()
        self.white_bg = cfg.get('white_bg')
        # self.near_far = [0.1,2.0]

    def read_meta(self):
        """Load frame metadata and images, orient poses, and build ray buffers."""
        if os.path.exists(f'{self.root_dir}/transforms.json'):
            # Single shared file: every img_eval_interval-th frame is test,
            # the rest are train.
            self.meta = load_json(f'{self.root_dir}/transforms.json')
            i_test = np.arange(0, len(self.meta['frames']), self.img_eval_interval)  # [np.argmin(dists)]
            idxs = i_test if self.split != 'train' else list(set(np.arange(len(self.meta['frames']))) - set(i_test))
        else:
            # Separate files: idxs covers this split only, but the other
            # split's frames are appended so all cameras are available below.
            self.meta = load_json(f'{self.root_dir}/transforms_{self.split}.json')
            idxs = np.arange(0, len(self.meta['frames']))  # [np.argmin(dists)]
            inv_split = 'train' if self.split!='train' else 'test'
            self.meta['frames'] += load_json(f'{self.root_dir}/transforms_{inv_split}.json')['frames']
        print(len(self.meta['frames']),len(idxs))
        self.scale = self.meta.get('scale', 0.5)
        self.offset = torch.FloatTensor(self.meta.get('offset', [0.0,0.0,0.0]))
        # self.scene_bbox = (torch.tensor([[-6.,-7.,-10.0],[6.,7.,10.]])/5).tolist()
        # self.scene_bbox = [[-1., -1., -1.0], [1., 1., 1.]]
        # center, radius = torch.tensor([-0.082157, 2.415426,-3.703080]), torch.tensor([7.36916, 11.34958, 20.1616])/2
        # self.scene_bbox = torch.stack([center-radius, center+radius]).tolist()
        h, w = int(self.meta.get('h')), int(self.meta.get('w'))
        cx, cy = self.meta.get('cx'), self.meta.get('cy')
        self.focal = [self.meta.get('fl_x'), self.meta.get('fl_y')]
        # ray directions for all pixels, same for all images (same H, W, focal)
        self.directions = get_ray_directions(h, w, self.focal, center=[cx, cy])  # (h, w, 3)
        self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)
        # self.intrinsics = torch.FloatTensor([[self.focal[0], 0, cx], [0, self.focal[1], cy], [0, 0, 1]])
        # self.intrinsics[:2] /= self.downsample
        poses = pose_from_json(self.meta, self.blender2opencv)
        # Re-orient poses and derive the scene bbox from the sparse point cloud.
        poses, self.scene_bbox = orientation(poses, f'{self.root_dir}/colmap_text/points3D.txt')
        self.image_paths = []
        self.poses = []
        self.all_rays = []
        self.all_rgbs = []
        self.all_masks = []
        self.all_depth = []
        self.img_wh = [w,h]
        for i in tqdm(idxs, desc=f'Loading data {self.split} ({len(idxs)})'):  # img_list:#
            frame = self.meta['frames'][i]
            c2w = torch.FloatTensor(poses[i])
            # c2w[:3,3] = (c2w[:3,3]*self.scale + self.offset)*2-1
            self.poses += [c2w]
            image_path = os.path.join(self.root_dir, frame['file_path'])
            self.image_paths += [image_path]
            img = Image.open(image_path)
            if self.downsample != 1.0:
                img = img.resize(self.img_wh, Image.LANCZOS)
            img = self.transform(img)
            if img.shape[0]==4:
                img = img[:3] * img[-1:] + (1 - img[-1:])  # blend A to RGB
            img = img.view(3, -1).permute(1, 0)
            self.all_rgbs += [img]
            rays_o, rays_d = get_rays(self.directions, c2w)  # both (h*w, 3)
            self.all_rays += [torch.cat([rays_o, rays_d], 1)]  # (h*w, 6)
        self.poses = torch.stack(self.poses)
        if not self.is_stack:
            self.all_rays = torch.cat(self.all_rays, 0)  # (len(self.meta['frames])*h*w, 3)
            self.all_rgbs = torch.cat(self.all_rgbs, 0)  # (len(self.meta['frames])*h*w, 3)
        else:
            self.all_rays = torch.stack(self.all_rays, 0)  # (len(self.meta['frames]),h*w, 3)
            self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1, *self.img_wh[::-1],3)  # (len(self.meta['frames]),h,w,3)

    def define_transforms(self):
        # PIL image -> float tensor in [0, 1]
        self.transform = T.ToTensor()

    def __len__(self):
        return len(self.all_rgbs)

    def __getitem__(self, idx):
        if self.split == 'train':  # use data in the buffers
            sample = {'rays': self.all_rays[idx],
                      'rgbs': self.all_rgbs[idx]}
        else:  # create data for each image separately
            img = self.all_rgbs[idx]
            rays = self.all_rays[idx]
            sample = {'rays': rays,
                      'rgbs': img}
        return sample
================================================
FILE: dataLoader/colmap2nerf.py
================================================
#!/usr/bin/env python3
# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import argparse
import os
from pathlib import Path, PurePosixPath
import numpy as np
import json
import sys
import math
import cv2
import os
import shutil
def parse_args():
    """Build and evaluate the command-line interface for colmap2nerf.

    Returns the parsed argparse namespace; all options have defaults so the
    script can run without arguments on an existing colmap_text export.
    """
    ap = argparse.ArgumentParser(description="convert a text colmap export to nerf format transforms.json; optionally convert video to images, and optionally run colmap in the first place")
    # video -> image extraction
    ap.add_argument("--video_in", default="", help="run ffmpeg first to convert a provided video file into a set of images. uses the video_fps parameter also")
    ap.add_argument("--video_fps", default=2)
    ap.add_argument("--time_slice", default="", help="time (in seconds) in the format t1,t2 within which the images should be generated from the video. eg: \"--time_slice '10,300'\" will generate images only from 10th second to 300th second of the video")
    # COLMAP invocation
    ap.add_argument("--run_colmap", action="store_true", help="run colmap first on the image folder")
    ap.add_argument("--colmap_matcher", default="sequential", choices=["exhaustive","sequential","spatial","transitive","vocab_tree"], help="select which matcher colmap should use. sequential for videos, exhaustive for adhoc images")
    ap.add_argument("--colmap_db", default="colmap.db", help="colmap database filename")
    ap.add_argument("--colmap_camera_model", default="OPENCV", choices=["SIMPLE_PINHOLE", "PINHOLE", "SIMPLE_RADIAL", "RADIAL","OPENCV"], help="camera model")
    ap.add_argument("--colmap_camera_params", default="", help="intrinsic parameters, depending on the chosen model. Format: fx,fy,cx,cy,dist")
    # input/output paths and conversion options
    ap.add_argument("--images", default="images", help="input path to the images")
    ap.add_argument("--text", default="colmap_text", help="input path to the colmap text files (set automatically if run_colmap is used)")
    ap.add_argument("--aabb_scale", default=16, choices=["1","2","4","8","16"], help="large scene scale factor. 1=scene fits in unit cube; power of 2 up to 16")
    ap.add_argument("--skip_early", default=0, help="skip this many images from the start")
    ap.add_argument("--keep_colmap_coords", action="store_true", help="keep transforms.json in COLMAP's original frame of reference (this will avoid reorienting and repositioning the scene for preview and rendering)")
    ap.add_argument("--out", default="transforms.json", help="output path")
    ap.add_argument("--vocab_path", default="", help="vocabulary tree path")
    parsed = ap.parse_args()
    return parsed
def do_system(arg):
    """Echo and execute a shell command; terminate the process if it fails."""
    print(f"==== running: {arg}")
    status = os.system(arg)
    if not status:
        return
    # Non-zero exit status: abort with the same status code.
    print("FATAL: command failed")
    sys.exit(status)
def run_ffmpeg(args):
    """Extract frames from args.video_in into args.images using ffmpeg.

    Samples args.video_fps frames per second; if args.time_slice is set
    ("t1,t2" in seconds) only that window is extracted.  The output folder
    is deleted and recreated after interactive confirmation.
    """
    # Interpret a relative image folder as relative to the video's directory.
    if not os.path.isabs(args.images):
        args.images = os.path.join(os.path.dirname(args.video_in), args.images)
    images = args.images
    video = args.video_in
    fps = float(args.video_fps) or 1.0
    print(f"running ffmpeg with input video file={video}, output image folder={images}, fps={fps}.")
    if (input(f"warning! folder '{images}' will be deleted/replaced. continue? (Y/n)").lower().strip()+"y")[:1] != "y":
        sys.exit(1)
    # Remove any previous output.  ignore_errors covers the folder-missing
    # case; the original bare `except: pass` also swallowed KeyboardInterrupt
    # and SystemExit, which was a bug.
    shutil.rmtree(images, ignore_errors=True)
    do_system(f"mkdir {images}")
    time_slice_value = ""
    time_slice = args.time_slice
    if time_slice:
        start, end = time_slice.split(",")
        time_slice_value = f",select='between(t\,{start}\,{end})'"
    do_system(f"ffmpeg -i {video} -qscale:v 1 -qmin 1 -vf \"fps={fps}{time_slice_value}\" {images}/%04d.jpg")
def run_colmap(args):
    """Run the COLMAP pipeline (feature extraction, matching, mapping, bundle
    adjustment) on args.images and convert the model to text files.

    The sparse and text output folders are deleted and recreated after
    interactive confirmation; an existing database file is removed.
    """
    db = args.colmap_db
    images = "\"" + args.images + "\""
    db_noext=str(Path(db).with_suffix(""))

    if args.text=="text":
        args.text=db_noext+"_text"
    text=args.text
    sparse=db_noext+"_sparse"
    print(f"running colmap with:\n\tdb={db}\n\timages={images}\n\tsparse={sparse}\n\ttext={text}")
    if (input(f"warning! folders '{sparse}' and '{text}' will be deleted/replaced. continue? (Y/n)").lower().strip()+"y")[:1] != "y":
        sys.exit(1)
    if os.path.exists(db):
        os.remove(db)
    do_system(f"colmap feature_extractor --ImageReader.camera_model {args.colmap_camera_model} --ImageReader.camera_params \"{args.colmap_camera_params}\" --SiftExtraction.estimate_affine_shape=true --SiftExtraction.domain_size_pooling=true --ImageReader.single_camera 1 --database_path {db} --image_path {images}")
    match_cmd = f"colmap {args.colmap_matcher}_matcher --SiftMatching.guided_matching=true --database_path {db}"
    if args.vocab_path:
        match_cmd += f" --VocabTreeMatching.vocab_tree_path {args.vocab_path}"
    do_system(match_cmd)
    # ignore_errors replaces the original bare `except: pass` around rmtree,
    # which also swallowed KeyboardInterrupt/SystemExit.
    shutil.rmtree(sparse, ignore_errors=True)
    do_system(f"mkdir {sparse}")
    do_system(f"colmap mapper --database_path {db} --image_path {images} --output_path {sparse}")
    do_system(f"colmap bundle_adjuster --input_path {sparse}/0 --output_path {sparse}/0 --BundleAdjustment.refine_principal_point 1")
    shutil.rmtree(text, ignore_errors=True)
    do_system(f"mkdir {text}")
    do_system(f"colmap model_converter --input_path {sparse}/0 --output_path {text} --output_type TXT")
def variance_of_laplacian(image):
    """Focus measure for an image: variance of its Laplacian response."""
    laplacian = cv2.Laplacian(image, cv2.CV_64F)
    return laplacian.var()
def sharpness(imagePath):
    """Load an image from disk and return its sharpness score
    (variance of the Laplacian of the grayscale image)."""
    bgr = cv2.imread(imagePath)
    gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
    return variance_of_laplacian(gray)
def qvec2rotmat(qvec):
    """Convert a quaternion (w, x, y, z) — COLMAP's ordering — into a 3x3 rotation matrix."""
    w, x, y, z = qvec[0], qvec[1], qvec[2], qvec[3]
    return np.array([
        [1 - 2 * y * y - 2 * z * z, 2 * x * y - 2 * w * z, 2 * z * x + 2 * w * y],
        [2 * x * y + 2 * w * z, 1 - 2 * x * x - 2 * z * z, 2 * y * z - 2 * w * x],
        [2 * z * x - 2 * w * y, 2 * y * z + 2 * w * x, 1 - 2 * x * x - 2 * y * y],
    ])
def rotmat(a, b):
    """Rotation matrix taking direction *a* onto direction *b* (Rodrigues' formula).

    A small epsilon keeps the division finite when a and b are parallel.
    """
    a = a / np.linalg.norm(a)
    b = b / np.linalg.norm(b)
    axis = np.cross(a, b)
    cos_ab = np.dot(a, b)
    sin_ab = np.linalg.norm(axis)
    skew = np.array([[0, -axis[2], axis[1]],
                     [axis[2], 0, -axis[0]],
                     [-axis[1], axis[0], 0]])
    return np.eye(3) + skew + skew @ skew * ((1 - cos_ab) / (sin_ab ** 2 + 1e-10))
def closest_point_2_lines(oa, da, ob, db):
    """Midpoint of the closest approach between rays oa+t*da and ob+t*db.

    Returns (point, weight); the weight is the squared sine of the angle
    between the rays and goes to 0 when they are (nearly) parallel.  The
    per-ray parameters are clamped to be non-positive.
    """
    da, db = da / np.linalg.norm(da), db / np.linalg.norm(db)
    cross = np.cross(da, db)
    weight = np.linalg.norm(cross) ** 2
    diff = ob - oa
    # Solve for the parameters along each ray; epsilon guards parallel rays.
    ta = min(np.linalg.det([diff, db, cross]) / (weight + 1e-10), 0)
    tb = min(np.linalg.det([diff, da, cross]) / (weight + 1e-10), 0)
    return 0.5 * (oa + ta * da + ob + tb * db), weight
############ orientation ##############
def normalize(x):
    """Scale *x* to unit Euclidean norm."""
    length = np.linalg.norm(x)
    return x / length
def rotation_matrix_from_vectors(vec1, vec2):
    """ Find the rotation matrix that aligns vec1 to vec2
    :param vec1: A 3d "source" vector
    :param vec2: A 3d "destination" vector
    :return mat: A transform matrix (3x3) which when applied to vec1, aligns it with vec2.
    """
    a, b = (vec1 / np.linalg.norm(vec1)).reshape(3), (vec2 / np.linalg.norm(vec2)).reshape(3)
    v = np.cross(a, b)
    if any(v):  # general case: the rotation axis is well defined
        c = np.dot(a, b)
        s = np.linalg.norm(v)
        kmat = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]])
        return np.eye(3) + kmat + kmat.dot(kmat) * ((1 - c) / (s ** 2))
    if np.dot(a, b) > 0:
        return np.eye(3)  # identical directions: no rotation needed
    # Antiparallel vectors.  The original code fell through to the identity
    # here, which is wrong (it leaves vec1 pointing opposite to vec2).
    # Rotate 180 degrees about any axis perpendicular to a: R = 2*u*u^T - I.
    axis = np.cross(a, [1.0, 0.0, 0.0])
    if not any(axis):  # a is collinear with x; pick another reference
        axis = np.cross(a, [0.0, 1.0, 0.0])
    axis = axis / np.linalg.norm(axis)
    return 2.0 * np.outer(axis, axis) - np.eye(3)
def rotation_up(poses):
    """Rotation that aligns the cameras' average 'up' direction with world +y.

    The up direction is the least-squares solution x of poses[:, :3, 1] @ x = 1
    (the y-axis columns of the camera-to-world matrices), normalized.
    """
    y_axes = poses[:, :3, 1]
    ones = np.ones((poses.shape[0], 1))
    up = normalize(np.linalg.lstsq(y_axes, ones, rcond=None)[0])
    return rotation_matrix_from_vectors(up, np.array([0., 1., 0.]))
def search_orientation(points):
    """Search rotations about the y axis in [-45°, 45°] (15 samples) for the
    one minimising the axis-aligned bounding-box volume of *points* (3, N).

    Returns (rotation_matrix, bbox_extents) for the best candidate.
    """
    from scipy.spatial.transform import Rotation as R
    candidates = []  # (bbox volume, rotation matrix, bbox extents)
    for y_angle in np.linspace(-45, 45, 15):
        rotvec = np.array([0, y_angle, 0]) / 180 * np.pi
        rot = R.from_rotvec(rotvec).as_matrix()
        rotated = rot @ points
        extents = np.max(rotated, axis=1) - np.min(rotated, axis=1)
        candidates.append((np.prod(extents), rot, extents))
    # First minimal-volume candidate wins, matching np.argmin semantics.
    best = min(range(len(candidates)), key=lambda i: candidates[i][0])
    return candidates[best][1], candidates[best][2]
def load_point_txt(path):
    """Parse a COLMAP points3D.txt file and return the XYZ coordinates as an (N, 3) array.

    Data lines have the form `POINT3D_ID X Y Z R G B ERROR TRACK[...]`;
    lines starting with '#' are comments.  Blank lines and runs of
    whitespace are tolerated (the original `line[0]` / `split(" ")`
    handling raised IndexError on blank lines and broke on double spaces).
    """
    points = []
    with open(path, "r") as f:
        for raw in f:
            line = raw.strip()
            if not line or line.startswith("#"):
                continue
            els = line.split()
            points.append([float(els[1]), float(els[2]), float(els[3])])
    return np.stack(points)
# def load_c2ws(frames):
# c2ws =
# for f in frames:
# f["transform_matrix"] -= totp
#
# def oritation(transform_matrix, point_txt):
if __name__ == "__main__":
    # Pipeline: (optional) video->images via ffmpeg, (optional) COLMAP, then
    # convert the COLMAP text export to a NeRF-style transforms.json.
    args = parse_args()
    if args.video_in != "":
        run_ffmpeg(args)
    if args.run_colmap:
        run_colmap(args)
    AABB_SCALE = int(args.aabb_scale)
    SKIP_EARLY = int(args.skip_early)
    IMAGE_FOLDER = args.images
    TEXT_FOLDER = args.text
    OUT_PATH = args.out
    print(f"outputting to {OUT_PATH}...")
    # ---- Parse camera intrinsics from cameras.txt (single camera assumed). ----
    with open(os.path.join(TEXT_FOLDER,"cameras.txt"), "r") as f:
        angle_x = math.pi / 2
        for line in f:
            # 1 SIMPLE_RADIAL 2048 1536 1580.46 1024 768 0.0045691
            # 1 OPENCV 3840 2160 3178.27 3182.09 1920 1080 0.159668 -0.231286 -0.00123982 0.00272224
            # 1 RADIAL 1920 1080 1665.1 960 540 0.0672856 -0.0761443
            if line[0] == "#":
                continue
            els = line.split(" ")
            w = float(els[2])
            h = float(els[3])
            fl_x = float(els[4])
            fl_y = float(els[4])
            k1 = 0
            k2 = 0
            p1 = 0
            p2 = 0
            cx = w / 2
            cy = h / 2
            # Per-model parameter layouts (see COLMAP camera model docs).
            if els[1] == "SIMPLE_PINHOLE":
                cx = float(els[5])
                cy = float(els[6])
            elif els[1] == "PINHOLE":
                fl_y = float(els[5])
                cx = float(els[6])
                cy = float(els[7])
            elif els[1] == "SIMPLE_RADIAL":
                cx = float(els[5])
                cy = float(els[6])
                k1 = float(els[7])
            elif els[1] == "RADIAL":
                cx = float(els[5])
                cy = float(els[6])
                k1 = float(els[7])
                k2 = float(els[8])
            elif els[1] == "OPENCV":
                fl_y = float(els[5])
                cx = float(els[6])
                cy = float(els[7])
                k1 = float(els[8])
                k2 = float(els[9])
                p1 = float(els[10])
                p2 = float(els[11])
            else:
                print("unknown camera model ", els[1])
            # fl = 0.5 * w / tan(0.5 * angle_x);
            angle_x = math.atan(w / (fl_x * 2)) * 2
            angle_y = math.atan(h / (fl_y * 2)) * 2
            fovx = angle_x * 180 / math.pi
            fovy = angle_y * 180 / math.pi
    print(f"camera:\n\tres={w,h}\n\tcenter={cx,cy}\n\tfocal={fl_x,fl_y}\n\tfov={fovx,fovy}\n\tk={k1,k2} p={p1,p2} ")
    # ---- Parse per-image extrinsics from images.txt. ----
    with open(os.path.join(TEXT_FOLDER,"images.txt"), "r") as f:
        i = 0
        bottom = np.array([0.0, 0.0, 0.0, 1.0]).reshape([1, 4])
        out = {
            "camera_angle_x": angle_x,
            "camera_angle_y": angle_y,
            "fl_x": fl_x,
            "fl_y": fl_y,
            "k1": k1,
            "k2": k2,
            "p1": p1,
            "p2": p2,
            "cx": cx,
            "cy": cy,
            "w": w,
            "h": h,
            "aabb_scale": AABB_SCALE,
            "frames": [],
        }
        up = np.zeros(3)
        for line in f:
            line = line.strip()
            if line[0] == "#":
                continue
            i = i + 1
            if i < SKIP_EARLY*2:
                continue
            # images.txt alternates pose lines and 2D-point lines; only the
            # odd lines carry the camera pose.
            if i % 2 == 1:
                elems=line.split(" ")  # 1-4 is quat, 5-7 is trans, 9ff is filename (9, if filename contains no spaces)
                #name = str(PurePosixPath(Path(IMAGE_FOLDER, elems[9])))
                # NOTE: a relative path is required here, unlike the
                # PurePosixPath variant commented out above.
                image_rel = os.path.relpath(IMAGE_FOLDER)
                name = str(f"./{image_rel}/{'_'.join(elems[9:])}")
                b=sharpness(name)
                print(name, "sharpness=",b)
                image_id = int(elems[0])
                qvec = np.array(tuple(map(float, elems[1:5])))
                tvec = np.array(tuple(map(float, elems[5:8])))
                # world-to-camera from quaternion + translation, then invert.
                R = qvec2rotmat(-qvec)
                t = tvec.reshape([3,1])
                m = np.concatenate([np.concatenate([R, t], 1), bottom], 0)
                c2w = np.linalg.inv(m)
                # c2w[0:3,2] *= -1 # flip the y and z axis
                # c2w[0:3,1] *= -1
                # c2w = c2w[[1,0,2,3],:] # swap y and z
                # c2w[2,:] *= -1 # flip whole world upside down
                # up += c2w[0:3,1]
                frame={"file_path":name,"sharpness":b,"transform_matrix": c2w}
                out["frames"].append(frame)
    # up = up / np.linalg.norm(up)
    # print("up vector was", up)
    # R = rotmat(up,[0,0,1]) # rotate up vector to [0,0,1]
    # R = np.pad(R,[0,1])
    # R[-1, -1] = 1
    # for f in out["frames"]:
    #     f["transform_matrix"] = np.matmul(R, f["transform_matrix"]) # rotate up to be the z axis
    nframes = len(out["frames"])
    # find a central point they are all looking at
    print("computing center of attention...")
    totw = 0.0
    totp = np.array([0.0, 0.0, 0.0])
    # Weighted average of pairwise closest points along each camera's optical
    # axis; near-parallel pairs (w <= 0.01) are ignored.
    for f in out["frames"]:
        mf = f["transform_matrix"][0:3,:]
        for g in out["frames"]:
            mg = g["transform_matrix"][0:3,:]
            p, w = closest_point_2_lines(mf[:,3], mf[:,2], mg[:,3], mg[:,2])
            if w > 0.01:
                totp += p*w
                totw += w
    totp /= totw
    print(totp)  # the cameras are looking at totp
    # Recenter all cameras on the center of attention.
    for f in out["frames"]:
        f["transform_matrix"][0:3,3] -= totp
    avglen = 0.
    for f in out["frames"]:
        avglen += np.linalg.norm(f["transform_matrix"][0:3,3])
    avglen /= nframes
    print("avg camera distance from origin", avglen)
    for f in out["frames"]:
        f["transform_matrix"][0:3,3] *= 4.0 / avglen  # scale to "nerf sized"
    # JSON cannot serialize numpy arrays; convert to nested lists.
    for f in out["frames"]:
        f["transform_matrix"] = f["transform_matrix"].tolist()
    print(nframes,"frames")
    print(f"writing {OUT_PATH}")
    with open(OUT_PATH, "w") as outfile:
        json.dump(out, outfile, indent=2)
================================================
FILE: dataLoader/dtu_objs.py
================================================
import torch
import cv2 as cv
import numpy as np
import os
from glob import glob
from .ray_utils import *
from torch.utils.data import Dataset
# This function is borrowed from IDR: https://github.com/lioryariv/idr
def load_K_Rt_from_P(filename, P=None):
    """Decompose a 3x4 projection matrix into intrinsics and a camera pose.

    If ``P`` is None the matrix is parsed from ``filename`` (one row per
    line; a 4-line file is assumed to have a header line that is dropped).
    Returns a 4x4 intrinsics matrix and a 4x4 float32 camera-to-world pose.
    """
    if P is None:
        rows = open(filename).read().splitlines()
        if len(rows) == 4:
            rows = rows[1:]
        rows = [[tok[0], tok[1], tok[2], tok[3]] for tok in (row.split(" ") for row in rows)]
        P = np.asarray(rows).astype(np.float32).squeeze()

    # OpenCV returns (camera matrix, rotation, homogeneous translation, ...).
    cam_mat, rot, trans = cv.decomposeProjectionMatrix(P)[:3]
    cam_mat = cam_mat / cam_mat[2, 2]  # normalize so K[2,2] == 1

    intrinsics = np.eye(4)
    intrinsics[:3, :3] = cam_mat

    pose = np.eye(4, dtype=np.float32)
    pose[:3, :3] = rot.transpose()
    # De-homogenize the camera center.
    pose[:3, 3] = (trans[:3] / trans[3])[:, 0]

    return intrinsics, pose
def fps_downsample(points, n_points_to_sample):
    """Greedy farthest-point sampling over a (N, 3) point array.

    Repeatedly picks the point farthest (in squared Euclidean distance)
    from the set selected so far and returns the list of chosen indices.
    """
    selected_points = np.zeros((n_points_to_sample, 3))
    selected_idxs = []
    # Initialize with +inf so the first argmax picks index 0 and no real
    # distance is ever clipped.  (The previous finite init of 100 capped
    # squared distances via np.minimum, which broke the farthest-point
    # choice whenever points were more than 10 units apart.)
    dist = np.full(points.shape[0], np.inf)
    for i in range(n_points_to_sample):
        idx = np.argmax(dist)
        selected_points[i] = points[idx]
        selected_idxs.append(idx)
        # Track, for every point, its distance to the nearest selected point.
        dist_ = ((points - selected_points[i]) ** 2).sum(-1)
        dist = np.minimum(dist, dist_)
    return selected_idxs
class DTUDataset(Dataset):
    """Multi-scene DTU dataset: loads cameras for many `scan*` folders up
    front, then loads images/rays for one scene at a time on demand.

    Each __getitem__ call picks a random scene, reloads a random subset of
    its views from disk, and returns a random ray batch from them.
    """

    def __init__(self, cfg, split='train', batch_size=4096, is_stack=None):
        """
        img_wh should be set to a tuple ex: (1152, 864) to enable test mode!
        """
        # self.N_vis = N_vis
        self.split = split
        self.batch_size = batch_size
        self.root_dir = cfg.datadir
        # Training flattens rays across images; other splits keep per-image stacks.
        self.is_stack = is_stack if is_stack is not None else 'train'!=split
        self.downsample = cfg.get(f'downsample_{self.split}')
        # Working resolution: base 400x300 divided by the downsample factor.
        self.img_wh = (int(400 / self.downsample), int(300 / self.downsample))

        train_scene_idxs = sorted(cfg.train_scene_list)
        test_scene_idxs = cfg.test_scene_list
        # A 2-element train list is shorthand for the [start, end) range.
        if len(train_scene_idxs)==2:
            train_scene_idxs = list(range(train_scene_idxs[0],train_scene_idxs[1]))
        self.scene_idxs = train_scene_idxs if self.split=='train' else test_scene_idxs
        print(self.scene_idxs)

        self.train_views = cfg.train_views
        self.scene_num = len(self.scene_idxs)
        self.test_index = test_scene_idxs
        # if 'test' == self.split:
        #     self.test_index = train_scene_idxs.index(test_scene_idxs[0])

        self.scene_path_list = [os.path.join(self.root_dir, f"scan{i}") for i in self.scene_idxs]
        # self.scene_path_list = sorted(glob(os.path.join(self.root_dir, "scan*")))
        self.read_meta()
        self.white_bg = False

    def read_meta(self):
        """Load per-scene camera matrices and derive the scene bounding box."""
        self.aabbs = []
        self.all_rgb_files,self.all_pose_files,self.all_intrinsics_files = {},{},{}
        for i, scene_idx in enumerate(self.scene_idxs):
            scene_path = self.scene_path_list[i]
            camera_dict = np.load(os.path.join(scene_path, 'cameras.npz'))
            self.all_rgb_files[scene_idx] = [
                os.path.join(scene_path, "image", f)
                for f in sorted(os.listdir(os.path.join(scene_path, "image")))
            ]

            # world_mat is a projection matrix from world to image
            n_images = len(self.all_rgb_files[scene_idx])
            world_mats_np = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(n_images)]
            scale_mats_np = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(n_images)]
            object_scale_mat = camera_dict['scale_mat_0']
            self.aabbs.append(self.get_bbox(scale_mats_np,object_scale_mat))

            # W,H = self.img_wh
            intrinsics_scene, poses_scene = [], []
            for img_idx, (scale_mat, world_mat) in enumerate(zip(scale_mats_np, world_mats_np)):
                P = world_mat @ scale_mat
                P = P[:3, :4]
                intrinsic, c2w = load_K_Rt_from_P(None, P)

                c2w = torch.from_numpy(c2w).float()
                intrinsic = torch.from_numpy(intrinsic).float()
                # Rescale focal length / principal point to the working resolution.
                intrinsic[:2] /= self.downsample

                poses_scene.append(c2w)
                intrinsics_scene.append(intrinsic)
            self.all_pose_files[scene_idx] = np.stack(poses_scene)
            self.all_intrinsics_files[scene_idx] = np.stack(intrinsics_scene)

        # Extend the FIRST scene's bbox with a 4th "scene index" axis spanning
        # [0, scene_num); that extended box is used as the global scene_bbox.
        self.aabbs[0][0].append(0)
        self.aabbs[0][1].append(self.scene_num)
        self.scene_bbox = self.aabbs[0]
        print(self.scene_bbox)

        # For test / single-scene use, eagerly load the first scene's 49 views.
        if self.split=='test' or self.scene_num==1:
            self.load_data(self.scene_idxs[0],range(49))

    def load_data(self, scene_idx, img_idx=None):
        """Load images and generate rays for one scene.

        When ``img_idx`` is None, views are chosen by farthest-point
        sampling over camera positions; otherwise the given indices are used.
        Populates self.all_rays / self.all_rgbs.
        """
        self.all_rays = []
        n_views = len(self.all_pose_files[scene_idx])
        cam_xyzs = self.all_pose_files[scene_idx][:,:3, -1]
        idxs = fps_downsample(cam_xyzs, min(self.train_views, n_views)) if img_idx is None else img_idx
        # if "test" == self.split:
        #     idxs = [item for item in list(range(n_views)) if item not in idxs]
        #     if len(idxs)==0:
        #         idxs = list(range(n_views))

        images_np = np.stack([cv.resize(cv.imread(self.all_rgb_files[scene_idx][idx]), self.img_wh) for idx in idxs]) / 255.0
        # Channel reorder [2,1,0]: OpenCV reads BGR, tensors are kept as RGB.
        self.all_rgbs = torch.from_numpy(images_np.astype(np.float32)[..., [2, 1, 0]])  # [n_images, H, W, 3]

        for c2w,intrinsic in zip(self.all_pose_files[scene_idx][idxs],self.all_intrinsics_files[scene_idx][idxs]):
            rays_o, rays_d = self.gen_rays_at(torch.from_numpy(intrinsic).float(), torch.from_numpy(c2w).float())
            self.all_rays += [torch.cat([rays_o, rays_d], 1)]  # (h*w, 6)

        if not self.is_stack:
            self.all_rays = torch.cat(self.all_rays, 0)  # (len(self.meta['frames])*h*w, 3)
            self.all_rgbs = self.all_rgbs.reshape(-1, 3)
        else:
            self.all_rays = torch.stack(self.all_rays, 0)  # (len(self.meta['frames]),h*w, 3)
            self.all_rgbs = self.all_rgbs.reshape(-1, *self.img_wh[::-1], 3)  # (len(self.meta['frames]),h,w,3)
        # self.sampler = SimpleSampler(np.prod(self.all_rgbs.shape[:-1]), self.batch_size)

    # def read_meta(self):
    #
    #     images_lis = sorted(glob(os.path.join(self.root_dir, 'image/*.png')))
    #     images_np = np.stack([cv.resize(cv.imread(im_name), self.img_wh) for im_name in images_lis]) / 255.0
    #     # masks_lis = sorted(glob(os.path.join(self.root_dir, 'mask/*.png')))
    #     # masks_np = np.stack([cv.resize(cv.imread(im_name),self.img_wh) for im_name in masks_lis])>128
    #
    #     self.all_rgbs = torch.from_numpy(images_np.astype(np.float32)[..., [2, 1, 0]])  # [n_images, H, W, 3]
    #     # self.all_masks = torch.from_numpy(masks_np>0)  # [n_images, H, W, 3]
    #     self.img_wh = [self.all_rgbs.shape[2], self.all_rgbs.shape[1]]
    #
    #     # world_mat is a projection matrix from world to image
    #     n_images = len(images_lis)
    #     world_mats_np = [self.camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(n_images)]
    #     self.scale_mats_np = [self.camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(n_images)]
    #
    #     # W,H = self.img_wh
    #     self.all_rays = []
    #     self.intrinsics, self.poses = [], []
    #     for img_idx, (scale_mat, world_mat) in enumerate(zip(self.scale_mats_np, world_mats_np)):
    #         P = world_mat @ scale_mat
    #         P = P[:3, :4]
    #         intrinsic, c2w = load_K_Rt_from_P(None, P)
    #
    #         c2w = torch.from_numpy(c2w).float()
    #         intrinsic = torch.from_numpy(intrinsic).float()
    #         intrinsic[:2] /= self.downsample
    #
    #         self.poses.append(c2w)
    #         self.intrinsics.append(intrinsic)
    #
    #         rays_o, rays_d = self.gen_rays_at(intrinsic, c2w)
    #         self.all_rays += [torch.cat([rays_o, rays_d], 1)]  # (h*w, 6)
    #
    #     self.intrinsics, self.poses = torch.stack(self.intrinsics), torch.stack(self.poses)
    #
    #     # self.all_rgbs[~self.all_masks] = 1.0
    #     if not self.is_stack:
    #         self.all_rays = torch.cat(self.all_rays, 0)  # (len(self.meta['frames])*h*w, 3)
    #         self.all_rgbs = self.all_rgbs.reshape(-1, 3)
    #     else:
    #         self.all_rays = torch.stack(self.all_rays, 0)  # (len(self.meta['frames]),h*w, 3)
    #         self.all_rgbs = self.all_rgbs.reshape(-1, *self.img_wh[::-1], 3)  # (len(self.meta['frames]),h,w,3)
    #
    #     self.sampler = SimpleSampler(np.prod(self.all_rgbs.shape[:-1]), self.batch_size)

    def get_bbox(self, scale_mats_np, object_scale_mat):
        """Scene AABB: the unit cube mapped through object_scale_mat, then
        back through the inverse of the first scale matrix.
        Returns [min_xyz, max_xyz] as two 3-element lists."""
        object_bbox_min = np.array([-1.0, -1.0, -1.0, 1.0])
        object_bbox_max = np.array([ 1.0, 1.0, 1.0, 1.0])
        # Object scale mat: region of interest to **extract mesh**
        # object_scale_mat = np.load(os.path.join(scene_path, 'cameras.npz'))
        object_bbox_min = np.linalg.inv(scale_mats_np[0]) @ object_scale_mat @ object_bbox_min[:, None]
        object_bbox_max = np.linalg.inv(scale_mats_np[0]) @ object_scale_mat @ object_bbox_max[:, None]
        return [object_bbox_min[:3, 0].tolist(),object_bbox_max[:3, 0].tolist()]
        # self.near_far = [2.125, 4.525]

    def gen_rays_at(self, intrinsic, c2w, resolution_level=1):
        """
        Generate rays at world space from one camera.
        """
        l = resolution_level
        W,H = self.img_wh
        # Pixel centers (+0.5), optionally subsampled by resolution_level.
        tx = torch.linspace(0, W - 1, W // l)+0.5
        ty = torch.linspace(0, H - 1, H // l)+0.5
        pixels_x, pixels_y = torch.meshgrid(tx, ty)
        p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1)  # W, H, 3
        # Unproject pixels to camera space, normalize, then rotate to world space.
        intrinsic_inv = torch.inverse(intrinsic)
        p = torch.matmul(intrinsic_inv[None, None, :3, :3], p[:, :, :, None]).squeeze()  # W, H, 3
        rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True)  # W, H, 3
        rays_v = torch.matmul(c2w[None, None, :3, :3], rays_v[:, :, :, None]).squeeze()  # W, H, 3
        rays_o = c2w[None, None, :3, 3].expand(rays_v.shape)  # W, H, 3
        return rays_o.transpose(0, 1).reshape(-1,3), rays_v.transpose(0, 1).reshape(-1,3)

    def __len__(self):
        # Effectively infinite: sampling is random per __getitem__ call.
        return 1000000  #len(self.all_rays)

    def __getitem__(self, idx):
        # The incoming index is ignored; a random scene and 6 random views are
        # (re)loaded from disk, then a random ray batch is drawn from them.
        idx = torch.randint(self.scene_num,(1,)).item()
        if self.scene_num >= 1:
            scene_name = self.scene_idxs[idx]
            img_idx = np.random.choice(len(self.all_rgb_files[scene_name]), size=6)
            self.load_data(scene_name, img_idx)
        idxs = np.random.choice(self.all_rays.shape[0], size=self.batch_size)
        return {'rays': self.all_rays[idxs], 'rgbs': self.all_rgbs[idxs], 'idx': idx}
================================================
FILE: dataLoader/dtu_objs2.py
================================================
import torch
import cv2 as cv
import numpy as np
import os
from glob import glob
from .ray_utils import *
from torch.utils.data import Dataset
# This function is borrowed from IDR: https://github.com/lioryariv/idr
def load_K_Rt_from_P(filename, P=None):
    """Recover intrinsics and a camera-to-world pose from a 3x4 projection matrix.

    If ``P`` is None it is parsed from ``filename`` (a 4-line file is assumed
    to carry a header line, which is dropped). Returns a 4x4 intrinsics matrix
    and a 4x4 float32 pose matrix.
    """
    if P is None:
        lines = open(filename).read().splitlines()
        if len(lines) == 4:
            lines = lines[1:]
        lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
        P = np.asarray(lines).astype(np.float32).squeeze()

    # OpenCV decomposition: K (camera matrix), R (rotation), t (homogeneous center).
    out = cv.decomposeProjectionMatrix(P)
    K = out[0]
    R = out[1]
    t = out[2]

    K = K / K[2, 2]  # normalize so K[2,2] == 1
    intrinsics = np.eye(4)
    intrinsics[:3, :3] = K

    pose = np.eye(4, dtype=np.float32)
    pose[:3, :3] = R.transpose()
    pose[:3, 3] = (t[:3] / t[3])[:, 0]  # de-homogenize the camera center

    return intrinsics, pose
class DTUDataset(Dataset):
    """Single-scene DTU dataset: loads every image of one scan folder,
    pre-generates all rays, and serves random ray batches via a sampler."""

    def __init__(self, cfg, split='train', batch_size=4096, is_stack=None):
        """
        img_wh should be set to a tuple ex: (1152, 864) to enable test mode!
        """
        # self.N_vis = N_vis
        self.split = split
        self.batch_size = batch_size
        self.root_dir = cfg.datadir
        # Training flattens rays across images; other splits keep per-image stacks.
        self.is_stack = is_stack if is_stack is not None else 'train'!=split
        self.downsample = cfg.get(f'downsample_{self.split}')
        self.img_wh = (int(400 / self.downsample), int(300 / self.downsample))
        self.white_bg = False

        self.camera_dict = np.load(os.path.join(self.root_dir, 'cameras.npz'))
        self.read_meta()
        self.get_bbox()

    # def define_transforms(self):
    #     self.transform = T.ToTensor()

    def get_bbox(self):
        """Compute self.scene_bbox: the unit cube mapped through
        scale_mat_0 and back through the first scale matrix, plus a 4th
        "scene index" axis fixed to [0, 1] for this single scene."""
        object_bbox_min = np.array([-1.0, -1.0, -1.0, 1.0])
        object_bbox_max = np.array([ 1.0, 1.0, 1.0, 1.0])
        # Object scale mat: region of interest to **extract mesh**
        object_scale_mat = np.load(os.path.join(self.root_dir, 'cameras.npz'))['scale_mat_0']
        object_bbox_min = np.linalg.inv(self.scale_mats_np[0]) @ object_scale_mat @ object_bbox_min[:, None]
        object_bbox_max = np.linalg.inv(self.scale_mats_np[0]) @ object_scale_mat @ object_bbox_max[:, None]
        self.scene_bbox = [object_bbox_min[:3, 0].tolist(),object_bbox_max[:3, 0].tolist()]
        self.scene_bbox[0].append(0)
        self.scene_bbox[1].append(1)

    def gen_rays_at(self, intrinsic, c2w, resolution_level=1):
        """
        Generate rays at world space from one camera.
        """
        l = resolution_level
        W,H = self.img_wh
        # Pixel centers (+0.5), optionally subsampled by resolution_level.
        tx = torch.linspace(0, W - 1, W // l)+0.5
        ty = torch.linspace(0, H - 1, H // l)+0.5
        pixels_x, pixels_y = torch.meshgrid(tx, ty)
        p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1)  # W, H, 3
        # Unproject to camera space, normalize, rotate into world space.
        intrinsic_inv = torch.inverse(intrinsic)
        p = torch.matmul(intrinsic_inv[None, None, :3, :3], p[:, :, :, None]).squeeze()  # W, H, 3
        rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True)  # W, H, 3
        rays_v = torch.matmul(c2w[None, None, :3, :3], rays_v[:, :, :, None]).squeeze()  # W, H, 3
        rays_o = c2w[None, None, :3, 3].expand(rays_v.shape)  # W, H, 3
        return rays_o.transpose(0, 1).reshape(-1,3), rays_v.transpose(0, 1).reshape(-1,3)

    def read_meta(self):
        """Load all images, decompose all cameras, and pre-generate rays."""
        images_lis = sorted(glob(os.path.join(self.root_dir, 'image/*.png')))
        images_np = np.stack([cv.resize(cv.imread(im_name),self.img_wh) for im_name in images_lis]) / 255.0
        # masks_lis = sorted(glob(os.path.join(self.root_dir, 'mask/*.png')))
        # masks_np = np.stack([cv.resize(cv.imread(im_name),self.img_wh) for im_name in masks_lis])>128

        # Channel reorder [2,1,0]: OpenCV reads BGR, tensors are kept as RGB.
        self.all_rgbs = torch.from_numpy(images_np.astype(np.float32)[...,[2,1,0]])  # [n_images, H, W, 3]
        # self.all_masks = torch.from_numpy(masks_np>0)  # [n_images, H, W, 3]
        self.img_wh = [self.all_rgbs.shape[2],self.all_rgbs.shape[1]]

        # world_mat is a projection matrix from world to image
        n_images = len(images_lis)
        world_mats_np = [self.camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(n_images)]
        self.scale_mats_np = [self.camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(n_images)]

        # W,H = self.img_wh
        self.all_rays = []
        self.intrinsics, self.poses = [],[]
        for img_idx, (scale_mat, world_mat) in enumerate(zip(self.scale_mats_np, world_mats_np)):
            P = world_mat @ scale_mat
            P = P[:3, :4]
            intrinsic, c2w = load_K_Rt_from_P(None, P)

            c2w = torch.from_numpy(c2w).float()
            intrinsic = torch.from_numpy(intrinsic).float()
            # Rescale focal length / principal point to the working resolution.
            intrinsic[:2] /= self.downsample

            self.poses.append(c2w)
            self.intrinsics.append(intrinsic)

            rays_o, rays_d = self.gen_rays_at(intrinsic,c2w)
            self.all_rays += [torch.cat([rays_o, rays_d], 1)]  # (h*w, 6)
        self.intrinsics, self.poses = torch.stack(self.intrinsics), torch.stack(self.poses)

        # self.all_rgbs[~self.all_masks] = 1.0
        if not self.is_stack:
            self.all_rays = torch.cat(self.all_rays, 0)  # (len(self.meta['frames])*h*w, 3)
            self.all_rgbs = self.all_rgbs.reshape(-1,3)
        else:
            self.all_rays = torch.stack(self.all_rays, 0)  # (len(self.meta['frames]),h*w, 3)
            self.all_rgbs = self.all_rgbs.reshape(-1, *self.img_wh[::-1],3)  # (len(self.meta['frames]),h,w,3)

        # NOTE(review): SimpleSampler is expected to come from the
        # `from .ray_utils import *` at the top of this file — confirm.
        self.sampler = SimpleSampler(np.prod(self.all_rgbs.shape[:-1]), self.batch_size)

    def __len__(self):
        return len(self.all_rays)

    def __getitem__(self, idx):
        # The incoming index is ignored; the sampler provides the next batch
        # of random ray indices.
        idx_rand = self.sampler.nextids()  #torch.randint(0,len(self.all_rays),(self.batch_size,))
        sample = {'rays': self.all_rays[idx_rand], 'rgbs': self.all_rgbs[idx_rand]}
        return sample
================================================
FILE: dataLoader/google_objs.py
================================================
import torch, cv2
from torch.utils.data import Dataset
import json
from tqdm import tqdm
import os
from PIL import Image
from torchvision import transforms as T
import glob
from scipy.spatial.transform import Rotation as R
from .ray_utils import *
def fps_downsample(points, n_points_to_sample):
    """Greedy farthest-point sampling over a (N, 3) point array.

    Repeatedly picks the point farthest (in squared Euclidean distance)
    from the set selected so far and returns the list of chosen indices.
    """
    selected_points = np.zeros((n_points_to_sample, 3))
    selected_idxs = []
    # Initialize with +inf so the first argmax picks index 0 and no real
    # distance is ever clipped (the previous finite init of 100 capped
    # squared distances above 100 and could mis-pick for spread-out points).
    dist = np.full(points.shape[0], np.inf)
    for i in range(n_points_to_sample):
        # Fixes: removed the leftover debug print(idx) and the needless
        # .tolist() round-trips that rebuilt arrays every iteration.
        idx = int(np.argmax(dist))
        selected_points[i] = points[idx]
        selected_idxs.append(idx)
        # Track, for every point, its distance to the nearest selected point.
        dist = np.minimum(dist, ((points - selected_points[i]) ** 2).sum(-1))
    return selected_idxs
############################# Get Spherical Path #############################
def pose_spherical_nerf(euler, radius=1.8, ep=1):
    """Build a 4x4 camera-to-world matrix from XYZ Euler angles (degrees).

    The camera is positioned at ``rot @ [0, 0, ep * radius]`` — i.e. pushed
    out along its own rotated z-axis, so it sits on a sphere of the given
    radius while keeping the specified orientation.
    """
    rot = R.from_euler('xyz', euler, degrees=True).as_matrix()
    c2w = np.eye(4)
    c2w[:3, :3] = rot
    # The rotated z-axis scaled by ep*radius serves as the camera position.
    c2w[:3, 3] = rot @ np.array([0.0, 0.0, ep * radius])
    return c2w
def nerf_video_path(c2ws, theta_range=10,phi_range=20,N_views=120,radius=1.3,ep=-1):
    """Render path tracing a closed theta/phi rectangle around the average
    orientation of the input camera poses.

    Returns a list of N_views 4x4 c2w matrices (N_views//4 per rectangle side),
    each built by pose_spherical_nerf at the given radius.
    """
    c2ws = torch.tensor(c2ws)
    # NOTE(review): computed but never used below — presumably a leftover.
    mean_position = torch.mean(c2ws[:,:3, 3],dim=0).reshape(1,3).cpu().numpy()
    rotvec = []
    for i in range(c2ws.shape[0]):
        r = R.from_matrix(c2ws[i, :3, :3])
        euler_ange = r.as_euler('xyz', degrees=True).reshape(1, 3)
        if i:
            # Unwrap angles relative to the first pose to avoid ±180° jumps.
            # NOTE(review): only +360 is ever applied; an angle more than
            # 180° ABOVE the reference is not wrapped down — confirm intended.
            mask = np.abs(euler_ange - rotvec[0])>180
            euler_ange[mask] += 360.0
        rotvec.append(euler_ange)
    # Average the rotations by averaging their (unwrapped) Euler angles.
    rotvec = np.mean(np.stack(rotvec), axis=0)

    # Four sides of the rectangle: sweep theta at -phi, phi at +theta,
    # theta back at +phi, and phi back at -theta.
    render_poses = [pose_spherical_nerf(rotvec+np.array([angle,0.0,-phi_range]), radius=radius, ep=ep) for angle in np.linspace(-theta_range,theta_range,N_views//4, endpoint=False)]
    render_poses += [pose_spherical_nerf(rotvec+np.array([theta_range,0.0,angle]), radius=radius, ep=ep) for angle in np.linspace(-phi_range,phi_range,N_views//4, endpoint=False)]
    render_poses += [pose_spherical_nerf(rotvec+np.array([angle,0.0,phi_range]), radius=radius, ep=ep) for angle in np.linspace(theta_range,-theta_range,N_views//4, endpoint=False)]
    render_poses += [pose_spherical_nerf(rotvec+np.array([-theta_range,0.0,angle]), radius=radius, ep=ep) for angle in np.linspace(phi_range,-phi_range,N_views//4, endpoint=False)]

    return render_poses
def _interpolate_trajectory(c2ws, num_views: int = 300):
    """Interpolate a smooth camera path through the given key poses.

    Rotations are slerped, translations linearly interpolated; returns a
    (num_views, 4, 4) array of camera-to-world matrices.
    """
    from scipy.interpolate import interp1d
    from scipy.spatial.transform import Rotation, Slerp

    key_times = list(range(len(c2ws)))
    rot_interp = Slerp(key_times, Rotation.from_matrix(c2ws[:, :3, :3]))
    pos_interp = interp1d(key_times, c2ws[:, :3, 3], axis=0)

    span = len(c2ws) - 1
    poses = []
    for view_idx in range(num_views):
        t = float(view_idx) / num_views * span
        pose = np.eye(4)
        pose[:3, :3] = rot_interp(t).as_matrix()
        pose[:3, 3] = pos_interp(t)
        poses.append(pose)
    return np.stack(poses, axis=0)
def google_objs_path(c2ws, N_views=150):
    """Closed render path through three farthest-point-sampled camera poses.

    Picks 3 well-spread cameras, appends the first again to close the loop,
    and interpolates N_views poses along the resulting trajectory.
    """
    anchor_idxs = fps_downsample(c2ws[:, :3, 3], 3)
    anchor_idxs.append(anchor_idxs[0])  # close the loop back to the start
    return _interpolate_trajectory(c2ws[anchor_idxs].numpy(), N_views)
class GoogleObjsDataset(Dataset):
    """Google scanned-objects dataset.

    Scans scene folders under cfg.datadir, keeps per-scene rgb/pose/intrinsics
    file lists, and either pre-loads everything (test / single scene) or
    generates ray batches on the fly in __getitem__ (multi-scene training).
    """

    def __init__(self, cfg, split="train", batch_size=4096):
        # self.N_vis = N_vis
        self.cfg = cfg
        self.root_dir = cfg.datadir
        self.split = split
        self.batch_size = batch_size
        # Training flattens rays across images; other splits keep stacks.
        self.is_stack = False if "train" == split else True
        self.downsample = cfg.get(f"downsample_{self.split}")
        # Base resolution is 512x512; downsample divides both sides.
        self.img_wh = (int(512 / self.downsample), int(512 / self.downsample))
        self.define_transforms()

        train_scene_idxs = sorted(cfg.train_scene_list)
        test_scene_idxs = cfg.test_scene_list
        # A 2-element train list is shorthand for the [start, end) range.
        if len(train_scene_idxs)==2:
            train_scene_idxs = list(range(train_scene_idxs[0],train_scene_idxs[1]))
        self.scene_idxs = train_scene_idxs if self.split=='train' else test_scene_idxs
        self.train_views = cfg.train_views
        self.scene_num = len(self.scene_idxs)
        if 'test' == self.split:
            self.test_index = train_scene_idxs.index(test_scene_idxs[0])

        # self.rot = torch.tensor([[0.65561799, -0.65561799, 0.37460659],
        #                          [0.73729737, 0.44876192, -0.50498052],
        #                          [0.16296514, 0.60727077, 0.77760181]])
        # 4D bbox: unit cube in xyz plus a "scene index" axis over [0, scene_num].
        self.scene_bbox = [[-1.0, -1.0, -1.0, 0.0], [1.0, 1.0, 1.0, self.scene_num]]
        # self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
        self.white_bg = True
        self.near_far = [0.4, 1.6]

        #################################
        # self.folder_path = datadir
        # self.num_source_views = args.num_source_views
        # self.rectify_inplane_rotation = args.rectify_inplane_rotation
        self.scene_path_list = sorted(glob.glob(os.path.join(self.root_dir, "*/")))

        all_rgb_files = []
        # all_depth_files = []
        all_pose_files = []
        all_intrinsics_files = []
        num_files = 250
        for i, scene_idx in enumerate(self.scene_idxs):
            scene_path = self.scene_path_list[scene_idx]
            # print(scene_idx,scene_path)
            rgb_files = [
                os.path.join(scene_path, "rgb", f)
                for f in sorted(os.listdir(os.path.join(scene_path, "rgb")))
            ]
            # depth_files = [os.path.join(scene_path, 'depth', f)
            #                for f in sorted(os.listdir(os.path.join(scene_path, 'depth')))]
            # pose/intrinsics files mirror the rgb file names.
            pose_files = [
                f.replace("rgb", "pose").replace("png", "txt") for f in rgb_files
            ]
            intrinsics_files = [
                f.replace("rgb", "intrinsics").replace("png", "txt") for f in rgb_files
            ]
            # Skip scenes that do not have the expected number of files.
            if (
                np.min([len(rgb_files), len(pose_files), len(intrinsics_files)])
                < num_files
            ):
                print(scene_path)
                continue
            all_rgb_files.append(rgb_files)
            # all_depth_files.append(depth_files)
            all_pose_files.append(pose_files)
            all_intrinsics_files.append(intrinsics_files)

        index = np.arange(len(all_rgb_files))
        self.all_rgb_files = np.array(all_rgb_files)[index]
        # self.all_depth_files = np.array(all_depth_files)[index]
        self.all_pose_files = np.array(all_pose_files)[index]
        self.all_intrinsics_files = np.array(all_intrinsics_files)[index]

        # Pre-load everything for test or single-scene use; otherwise views
        # are sampled per batch in __getitem__.
        if self.split=='test' or self.scene_num==1:
            self.read_meta()
        else:
            self.train_idxs = range(self.train_views)
        # self.define_proj_mat()

    def read_depth(self, filename):
        # NOTE(review): read_pfm is expected to come from `from .ray_utils
        # import *` at the top of this file — confirm.
        depth = np.array(read_pfm(filename)[0], dtype=np.float32)  # (800, 800)
        return depth

    def read_meta(self):
        """Pre-load images, poses and rays for every kept scene into memory."""
        assert (
            len(self.all_rgb_files)
            == len(self.all_pose_files)
            == len(self.all_intrinsics_files)
        )
        self.all_image_paths = []
        self.all_poses = []
        self.all_rays = []
        self.all_rgbs = []
        for idx in range(len(self.all_rgb_files)):
            rgb_files = self.all_rgb_files[idx]
            pose_files = self.all_pose_files[idx]
            intrinsics_files = self.all_intrinsics_files[idx]

            # Flattened 4x4 intrinsics on disk; keep the 3x3 upper-left part.
            intrinsics = np.loadtxt(intrinsics_files[0])
            index_3x3 = np.array([0, 1, 2, 4, 5, 6, 8, 9, 10])
            self.intrinsics = intrinsics[index_3x3]
            w, h = self.img_wh
            self.focal = self.intrinsics[0]
            self.focal *= (
                self.img_wh[0] / 512
            )  # modify focal length to match size self.img_wh

            # ray directions for all pixels, same for all images (same H, W, focal)
            self.directions = get_ray_directions(
                h, w, [self.focal, self.focal]
            )  # (h, w, 3)
            self.directions = self.directions / torch.norm(
                self.directions, dim=-1, keepdim=True
            )
            self.intrinsics = torch.tensor(
                [[self.focal, 0, w / 2], [0, self.focal, h / 2], [0, 0, 1]]
            ).float()

            self.scene_image_paths = []
            self.scene_poses = []
            self.scene_rays = []
            self.scene_rgbs = []
            # self.downsample = 1.0

            img_eval_interval = (
                1  # if self.N_vis < 0 else len(self.meta['frames']) // self.N_vis
            )
            if "train" == self.split:
                # Train views: farthest-point sampling over the first 100
                # camera positions (translation entries of the flat 4x4 pose).
                cam_xyzs = []
                # for i in range(len(pose_files)):
                for i in range(100):
                    pose = np.loadtxt(pose_files[i])
                    cam_xyzs.append([pose[3], pose[7], pose[11]])
                cam_xyzs = np.array(cam_xyzs)
                idxs = fps_downsample(cam_xyzs, min(self.train_views, len(rgb_files)))
                self.train_idxs = idxs
                print("train idxs:", idxs)
            else:
                # Test views: images 100-199.
                idxs = list(range(100, 200))

            for i in tqdm(
                idxs, desc=f"Loading data {self.split} ({len(idxs)})"
            ):  # img_list:#
                pose = np.loadtxt(pose_files[i])
                pose = torch.FloatTensor(pose).view(4, 4)
                self.scene_poses += [pose]

                image_path = rgb_files[i]
                self.scene_image_paths += [image_path]
                img = Image.open(image_path)
                if self.downsample != 1.0:
                    img = img.resize(self.img_wh, Image.LANCZOS)
                img = self.transform(img)  # (3, h, w)
                img = img.view(3, -1).permute(1, 0)  # (h*w, 3) RGBA
                # img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:])  # blend A to RGB
                self.scene_rgbs += [img]

                rays_o, rays_d = get_rays(self.directions, pose)  # both (h*w, 3)
                # rays_o, rays_d = rays_o@self.rot, rays_d@self.rot
                self.scene_rays += [torch.cat([rays_o, rays_d], 1)]  # (h*w, 6)

            self.scene_poses = torch.stack(self.scene_poses)

            views = 180
            # Per-scene radii for the alternative nerf_video_path below (unused
            # with the current google_objs_path call).
            radius = {183: 1.3, 199: 1.7, 298: 1.5, 467: 1.1, 957: 1.9, 244: 1.2, 963: 1.2, 527: 1.2, 681:1.9,948:1.2}
            self.render_path = google_objs_path(self.scene_poses, N_views=views)
            name = self.scene_path_list[self.scene_idxs[0]].split('/')[-2]
            np.save(f'{self.root_dir}/{name}_render_path.npy', self.render_path)
            # self.render_path = nerf_video_path(self.scene_poses, N_views=views, theta_range=45,phi_range=90, radius=radius[self.scene_idxs[0]],ep=-1)

            if not self.is_stack:
                self.scene_rays = torch.cat(
                    self.scene_rays, 0
                )  # (len(self.meta['frames])*h*w, 3)
                self.scene_rgbs = torch.cat(
                    self.scene_rgbs, 0
                )  # (len(self.meta['frames])*h*w, 3)
                # self.all_depth = torch.cat(self.all_depth, 0)  # (len(self.meta['frames])*h*w, 3)
            else:
                self.scene_rays = torch.stack(
                    self.scene_rays, 0
                )  # (len(self.meta['frames]),h*w, 3)
                self.scene_rgbs = torch.stack(self.scene_rgbs, 0).reshape(
                    -1, *self.img_wh[::-1], 3
                )  # (len(self.meta['frames]),h,w,3)
                # self.all_masks = torch.stack(self.all_masks, 0).reshape(-1,*self.img_wh[::-1])  # (len(self.meta['frames]),h,w,3)

            ######################## pre-generate and save in mem ########################
            self.all_image_paths.append(self.scene_image_paths)
            self.all_poses.append(self.scene_poses)
            self.all_rays.append(self.scene_rays)
            self.all_rgbs.append(self.scene_rgbs)
        self.all_rays = torch.cat(self.all_rays)
        self.all_rgbs = torch.cat(self.all_rgbs)

    def get_rays(self, idx):
        ######################## pre-generate and save in mem ########################
        return self.all_rays[idx]

    def get_rgbs(self, idx):
        ######################## pre-generate and save in mem ########################
        return self.all_rgbs[idx]

    def define_transforms(self):
        # PIL image -> float tensor in [0, 1].
        self.transform = T.ToTensor()

    def define_proj_mat(self):
        self.proj_mat = (
            self.intrinsics.unsqueeze(0) @ torch.inverse(self.scene_poses)[:, :3]
        )

    def world2ndc(self, points, lindisp=None):
        device = points.device
        return (points - self.center.to(device)) / self.radius.to(device)

    def update_index(self):
        self.scene_idx = torch.randint(0, len(self.all_rgb_files), (1,)).item()

    def __len__(self):
        # Effectively infinite: __getitem__ wraps idx around the scene count.
        return 10000000  #len(self.all_rgb_files)

    def __getitem__(self, idx):
        #
        # self.update_index()
        # idx = self.scene_idx  #
        idx = idx % len(self.all_rgb_files)
        # idx = torch.randint(len(self.all_rgb_files), (1,)).item()

        ######################## generate rays on the fly ########################
        rgb_files = self.all_rgb_files[idx]
        pose_files = self.all_pose_files[idx]
        intrinsics_files = self.all_intrinsics_files[idx]

        intrinsics = np.loadtxt(intrinsics_files[0])
        index_3x3 = np.array([0, 1, 2, 4, 5, 6, 8, 9, 10])
        intrinsics = intrinsics[index_3x3]
        w, h = self.img_wh
        focal = intrinsics[0]
        focal *= self.img_wh[0] / 512  # modify focal length to match size self.img_wh

        # ray directions for all pixels, same for all images (same H, W, focal)
        directions = get_ray_directions(h, w, [focal, focal])  # (h, w, 3)
        directions = directions / torch.norm(directions, dim=-1, keepdim=True)
        intrinsics = torch.tensor(
            [[focal, 0, w / 2], [0, focal, h / 2], [0, 0, 1]]
        ).float()

        scene_poses = []
        scene_rays = []
        scene_image_paths = []
        scene_rgbs = []
        # downsample = 1.0
        sample_views = 5
        if self.scene_num>1:
            # Multi-scene training: draw 5 random train views and a
            # batch_size//5 random pixel subset from each.
            ids = np.random.choice(self.train_idxs, size=sample_views)
            for i in ids:
                image_path = rgb_files[i]
                scene_image_paths += [image_path]
                img = Image.open(image_path)
                idxs = torch.randint(0, w*h, (self.batch_size // sample_views,))
                if self.downsample != 1.0:
                    img = img.resize(self.img_wh, Image.LANCZOS)
                img = self.transform(img)  # (3, h, w)
                img = img.view(3, -1).permute(1, 0)  # (h*w, 3) RGBA
                # img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:])  # blend A to RGB
                scene_rgbs += [img[idxs]]

                pose = np.loadtxt(pose_files[i])
                pose = torch.FloatTensor(pose).view(4, 4)
                scene_poses += [pose]
                rays_o, rays_d = get_rays(directions, pose)  # both (h*w, 3)
                scene_rays += [torch.cat([rays_o[idxs], rays_d[idxs]], 1)]  # (h*w, 6)
            scene_poses = torch.stack(scene_poses)
            if not self.is_stack:
                scene_rays = torch.cat(scene_rays, 0)  # (len(self.meta['frames])*h*w, 3)
                scene_rgbs = torch.cat(scene_rgbs, 0)  # (len(self.meta['frames])*h*w, 3)
            else:
                scene_rays = torch.stack(scene_rays, 0)  # (len(self.meta['frames]),h*w, 3)
                scene_rgbs = torch.stack(scene_rgbs, 0)  # (len(self.meta['frames]),h*w, 3)
            return {'rays': scene_rays, 'rgbs': scene_rgbs, 'idx': idx}
        else:
            # Single scene: sample directly from the pre-loaded ray pool.
            idx_rand = torch.randint(0, len(self.all_rays), (self.batch_size,))
            return {'rays': self.all_rays[idx_rand], 'rgbs': self.all_rgbs[idx_rand]}
================================================
FILE: dataLoader/image.py
================================================
import torch,imageio,cv2
from PIL import Image
Image.MAX_IMAGE_PIXELS = 1000000000
import numpy as np
import torch.nn.functional as F
from torch.utils.data import Dataset
_img_suffix = ['png','jpg','jpeg','bmp','tif']

def load(path):
    """Load an image or array from `path`.

    Integer images are normalized to [0, 1] using a bit depth inferred from
    the maximum pixel value (255 for 8-bit, 65535 for 16-bit, ...). EXR and
    .npy files are returned as stored; unknown extensions yield None.
    """
    ext = path.split('.')[-1]
    if ext in _img_suffix:
        pixels = np.array(Image.open(path))  # .convert('L')
        # 256**1 - 1 = 255 when max < 256, 256**2 - 1 = 65535 when max < 65536, ...
        denom = 256.**(1+np.log2(np.max(pixels))//8)-1
        return pixels / denom
    if 'exr' == ext:
        return imageio.imread(path)
    if 'npy' == ext:
        return np.load(path)
def srgb_to_linear(img):
    """Convert sRGB-encoded values to linear radiance (numpy, elementwise)."""
    threshold = 0.04045
    linear_lo = img / 12.92
    linear_hi = np.power((img + 0.055) / 1.055, 2.4)
    return np.where(img > threshold, linear_hi, linear_lo)
class ImageDataset(Dataset):
    """Single-image 2D regression dataset: serves random (coordinate, rgb)
    batches from one image, optionally restricted by a mask / deleted
    regions / a random pixel subset, with optional continuous resampling."""

    def __init__(self, cfg, batchsize, split='train', continue_sampling=False, tolinear=False, HW=-1, perscent=1.0, delete_region=None,mask=None):
        datadir = cfg.datadir
        self.batchsize = batchsize
        self.continue_sampling = continue_sampling

        img = load(datadir).astype(np.float32)
        if HW>0:
            img = cv2.resize(img,[HW,HW])
        if tolinear:
            img = srgb_to_linear(img)

        self.img = torch.from_numpy(img)
        H,W = self.img.shape[:2]
        # Pixel-center coordinates (x, y) with a +0.5 offset.
        y, x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W), indexing='ij')
        self.coordiante = torch.stack((x,y),-1).float()+0.5
        n_channel = self.img.shape[-1]

        self.image = self.img
        # Flatten to (H*W, C) pixels and (H*W, 2) coordinates.
        self.img, self.coordiante = self.img.reshape(H*W,-1), self.coordiante.reshape(H*W,2)

        # if continue_sampling:
        #     coordiante_tmp = self.coordiante.view(1,1,-1,2)/torch.tensor([W,H])*2-1.0
        #     self.img = F.grid_sample(self.img.view(1,H,W,-1).permute(0,3,1,2),coordiante_tmp, mode='bilinear', align_corners=True).reshape(self.img.shape[-1],-1).t()

        if 'train'==split:
            self.mask = torch.ones_like(y)>0
            if mask is not None:
                # Caller-supplied mask selects the trainable pixels.
                self.mask = mask>0
                # NOTE(review): HW may still be -1 here if the caller did not
                # request a resize, which makes this printed coverage ratio
                # wrong — confirm callers always pass HW with mask.
                print(torch.sum(mask)/1.0/HW/HW)
            elif delete_region is not None:
                # Mask out rectangular region(s) given as [x, y, width, height].
                if isinstance(delete_region[0], list):
                    for item in delete_region:
                        t_l_x,t_l_y,width,height = item
                        self.mask[t_l_y:t_l_y+height,t_l_x:t_l_x+width] = False
                else:
                    t_l_x,t_l_y,width,height = delete_region
                    self.mask[t_l_y:t_l_y+height,t_l_x:t_l_x+width] = False
            else:
                # Keep a random `perscent` fraction of all pixels.
                index = torch.randperm(len(self.img))[:int(len(self.img)*perscent)]
                self.mask[:] = False
                self.mask.view(-1)[index] = True

            self.mask = self.mask.view(-1)
            self.image, self.coordiante = self.img[self.mask], self.coordiante[self.mask]
        else:
            self.image = self.img

        self.HW = [H,W]
        self.scene_bbox = [[0., 0.], [W, H]]
        cfg.aabb = self.scene_bbox
        #

    def __len__(self):
        # Nominal epoch length; batches are drawn randomly each call.
        return 10000

    def __getitem__(self, idx):
        # The incoming index is ignored; a fresh random batch is drawn.
        H,W = self.HW
        device = self.image.device
        idx = torch.randint(0,len(self.image),(self.batchsize,), device=device)
        if self.continue_sampling:
            # Jitter coordinates within a pixel and bilinearly resample the image.
            coordinate = self.coordiante[idx] + torch.rand((self.batchsize,2))-0.5
            coordinate_tmp = (coordinate.view(1,1,self.batchsize,2))/torch.tensor([W,H],device=device)*2-1.0
            rgb = F.grid_sample(self.img.view(1,H,W,-1).permute(0,3,1,2),coordinate_tmp, mode='bilinear',
                                align_corners=False, padding_mode='border').reshape(self.img.shape[-1],-1).t()
            sample = {'rgb': rgb,
                      'xy': coordinate}
        else:
            sample = {'rgb': self.image[idx],
                      'xy': self.coordiante[idx]}
        return sample
================================================
FILE: dataLoader/image_set.py
================================================
import torch,cv2
import numpy as np
import torch.nn.functional as F
from torch.utils.data import Dataset
def srgb_to_linear(img):
    """Convert sRGB-encoded values to linear radiance (torch, elementwise)."""
    cutoff = 0.04045
    linear_lo = img / 12.92
    linear_hi = torch.pow((img + 0.055) / 1.055, 2.4)
    return torch.where(img > cutoff, linear_hi, linear_lo)
def load(path, HW=512):
    """Load an image stack stored as a .npy file, resizing frames to HW x HW.

    Returns None for any other file extension. Frames are resized in place
    only when the stored height differs from HW.
    """
    if path.split('.')[-1] != 'npy':
        return None
    stack = np.load(path)
    # img = 0.3*img[...,:1] + 0.59*img[...,1:2] + 0.11*img[...,2:]
    if stack.shape[-2] != HW:
        # Resize frame by frame; cv2.resize handles one HxW image at a time.
        for frame_idx in range(stack.shape[0]):
            stack[frame_idx] = cv2.resize(stack[frame_idx],[HW,HW])
    return stack
class ImageSetDataset(Dataset):
    """Image-stack regression dataset: serves random (coordinate, rgb)
    batches from a set of D images, where the 3rd coordinate is the image
    index; optionally resamples continuously via trilinear grid_sample."""

    def __init__(self, cfg, batchsize, split='train', continue_sampling=False, HW=512, N=10, tolinear=True):
        datadir = cfg.datadir
        self.batchsize = batchsize
        self.continue_sampling = continue_sampling

        # Keep the first N images; stored uint8 values are scaled to [0, 1].
        imgs = load(datadir,HW=HW)[:N]
        self.imgs = torch.from_numpy(imgs).float()/255
        if tolinear:
            self.imgs = srgb_to_linear(self.imgs)

        D,H,W = self.imgs.shape[:3]
        # Pixel-center coordinates (x, y) with a +0.5 offset.
        y, x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W), indexing='ij')
        self.coordinate = torch.stack((x,y),-1).float()+0.5
        # Flatten each image to (H*W, C); coordinates are shared by all images.
        self.imgs, self.coordinate = self.imgs.reshape(D,H*W,-1), self.coordinate.reshape(H*W,2)

        self.DHW = [D,H,W]
        self.scene_bbox = [[0., 0., 0.], [W, H, D]]
        cfg.aabb = self.scene_bbox
        # self.down_scale = 512.0/H

    def __len__(self):
        # Nominal epoch length; batches are drawn randomly each call.
        return 1000000

    def __getitem__(self, idx):
        # The incoming index is ignored; random pixels and image indices are
        # drawn for each batch.
        D,H,W = self.DHW
        pix_idx = torch.randint(0,H*W,(self.batchsize,))
        img_idx = torch.randint(0,D,(self.batchsize,))
        if self.continue_sampling:
            # Jitter (x, y) within a pixel, append the image index (+0.5 for
            # the slice center), and trilinearly resample the stack.
            coordinate = self.coordinate[pix_idx] + torch.rand((self.batchsize,2)) - 0.5
            coordinate = torch.cat((coordinate,img_idx.unsqueeze(-1)+0.5),dim=-1)
            coordinate_tmp = (coordinate.view(1,1,1,self.batchsize,3))/torch.tensor([W,H,D])*2-1.0
            rgb = F.grid_sample(self.imgs.view(1,D,H,W,-1).permute(0,4,1,2,3),coordinate_tmp, mode='bilinear',
                                align_corners=False, padding_mode='border').reshape(self.imgs.shape[-1],-1).t()
            # coordinate[:,:2] *= self.down_scale
            sample = {'rgb': rgb,
                      'xy': coordinate}
        else:
            sample = {'rgb': self.imgs[img_idx,pix_idx],
                      'xy': torch.cat((self.coordinate[pix_idx],img_idx.unsqueeze(-1)+0.5),dim=-1)}
            # 'xy': torch.cat((self.coordiante[pix_idx],img_idx.expand_as(pix_idx).unsqueeze(-1)),dim=-1)}
        return sample
================================================
FILE: dataLoader/llff.py
================================================
import torch
from torch.utils.data import Dataset
import glob
import numpy as np
import os
from PIL import Image
from torchvision import transforms as T
from .ray_utils import *
def normalize(v):
    """Return *v* rescaled to unit Euclidean length."""
    length = np.linalg.norm(v)
    return v / length
def average_poses(poses):
    """Compute the average camera pose used to recenter a pose set.

    Built from: (1) the mean camera center, (2) the normalized mean z axis,
    (3) an x axis orthogonal to z derived from the mean y axis, and (4) a y
    axis completing the right-handed frame (the raw mean y axis is not
    necessarily orthogonal to z, hence the detour through x).

    Inputs:
        poses: (N_images, 3, 4)
    Outputs:
        (3, 4) average pose.
    """
    center = poses[..., 3].mean(0)
    z_axis = normalize(poses[..., 2].mean(0))
    y_mean = poses[..., 1].mean(0)  # only a hint; not yet orthogonal to z
    x_axis = normalize(np.cross(z_axis, y_mean))
    y_axis = np.cross(x_axis, z_axis)  # already unit length: x and z are
    return np.stack([x_axis, y_axis, z_axis, center], 1)
def center_poses(poses, blender2opencv):
    """Re-express all poses relative to their average pose (for NDC use).

    See https://github.com/bmild/nerf/issues/34

    Inputs:
        poses: (N_images, 3, 4)
    Outputs:
        poses_centered: (N_images, 3, 4) the centered poses
        pose_avg_homo: (4, 4) the average pose in homogeneous form
    """
    poses = poses @ blender2opencv
    # homogeneous average pose, so a single matrix inverse recenters everything
    pose_avg_homo = np.eye(4)
    pose_avg_homo[:3] = average_poses(poses)
    # append [0, 0, 0, 1] to every pose as well
    bottom_rows = np.tile(np.array([0, 0, 0, 1]), (len(poses), 1, 1))
    poses_homo = np.concatenate([poses, bottom_rows], 1)
    poses_centered = np.linalg.inv(pose_avg_homo) @ poses_homo
    return poses_centered[:, :3], pose_avg_homo
def viewmatrix(z, up, pos):
    """Build a 4x4 camera matrix from a forward vector, an up hint and a position.

    The x column is negated, matching the LLFF spiral-path convention.
    """
    forward = normalize(z)
    right = normalize(np.cross(up, forward))
    true_up = normalize(np.cross(forward, right))
    m = np.eye(4)
    m[:3] = np.stack([-right, true_up, forward, pos], 1)
    return m
def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, N_rots=2, N=120):
    """Generate N poses on a spiral of N_rots turns around the pose c2w.

    `zdelta` is accepted for interface compatibility but unused.
    Returns a list of 4x4 pose matrices.
    """
    poses = []
    rads = np.array(list(rads) + [1.])
    for theta in np.linspace(0., 2. * np.pi * N_rots, N + 1)[:-1]:
        offset = np.array([np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.]) * rads
        cam_center = np.dot(c2w[:3, :4], offset)
        # look toward a point `focal` units in front of the central camera
        forward = normalize(cam_center - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.])))
        poses.append(viewmatrix(forward, up, cam_center))
    return poses
def get_spiral(c2ws_all, near_fars, rads_scale=1.0, N_views=120):
    """Build a spiral render path around the average of the given camera poses.

    Returns an (N_views, 4, 4) array of poses.
    """
    # spiral is centered on the average pose
    c2w = average_poses(c2ws_all)
    up = normalize(c2ws_all[:, :3, 1].sum(0))
    # a reasonable "focus depth": weighted harmonic mean of near and far
    dt = 0.75
    close_depth, inf_depth = near_fars.min() * 0.9, near_fars.max() * 5.0
    focal = 1.0 / (((1.0 - dt) / close_depth + dt / inf_depth))
    zdelta = near_fars.min() * .2
    # spiral radii: robust (90th percentile) extent of the camera centers
    cam_centers = c2ws_all[:, :3, 3]
    rads = np.percentile(np.abs(cam_centers), 90, 0) * rads_scale
    return np.stack(render_path_spiral(c2w, up, rads, focal, zdelta, zrate=.5, N=N_views))
class LLFFDataset(Dataset):
    """Forward-facing LLFF scene: images + poses_bounds.npy, rays warped to NDC."""
    def __init__(self, cfg , split='train', hold_every=8):
        """
        spheric_poses: whether the images are taken in a spheric inward-facing manner
                       default: False (forward-facing)
        val_num: number of val images (used for multigpu training, validate same image for all gpus)
        """
        self.root_dir = cfg.datadir
        self.split = split
        self.hold_every = hold_every  # every hold_every-th view is held out for testing
        self.is_stack = False if 'train' == split else True
        self.downsample = cfg.get(f'downsample_{self.split}')
        self.define_transforms()
        self.blender2opencv = np.eye(4)#np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
        self.read_meta()
        self.white_bg = False
        # rays live in NDC space, so the usable depth range is [0, 1]
        # self.near_far = [np.min(self.near_fars[:,0]),np.max(self.near_fars[:,1])]
        self.near_far = [0.0, 1.0]
        self.scene_bbox = [[-1.5, -1.67, -1.0], [1.5, 1.67, 1.0]]
        # self.scene_bbox = torch.tensor([[-1.67, -1.5, -1.0], [1.67, 1.5, 1.0]])
        # self.center = torch.mean(self.scene_bbox, dim=0).float().view(1, 1, 3)
        # self.invradius = 1.0 / (self.scene_bbox[1] - self.center).float().view(1, 1, 3)
    def read_meta(self):
        """Load poses_bounds.npy and images, recenter/rescale poses, build NDC rays."""
        print(self.root_dir)
        poses_bounds = np.load(os.path.join(self.root_dir, 'poses_bounds.npy'))  # (N_images, 17)
        # NOTE(review): images are always read from images_4/ regardless of
        # self.downsample -- confirm downsample is meant relative to that folder.
        self.image_paths = sorted(glob.glob(os.path.join(self.root_dir, 'images_4/*')))
        # load full resolution image then resize
        if self.split in ['train', 'test']:
            assert len(poses_bounds) == len(self.image_paths), \
                'Mismatch between number of images and number of poses! Please rerun COLMAP!'
        poses = poses_bounds[:, :15].reshape(-1, 3, 5)  # (N_images, 3, 5)
        self.near_fars = poses_bounds[:, -2:]  # (N_images, 2)
        hwf = poses[:, :, -1]
        # Step 1: rescale focal length according to training resolution
        H, W, self.focal = poses[0, :, -1]  # original intrinsics, same for all images
        self.img_wh = np.array([int(W / self.downsample), int(H / self.downsample)])
        self.focal = [self.focal * self.img_wh[0] / W, self.focal * self.img_wh[1] / H]
        # Step 2: correct poses
        # Original poses has rotation in form "down right back", change to "right up back"
        # See https://github.com/bmild/nerf/issues/34
        poses = np.concatenate([poses[..., 1:2], -poses[..., :1], poses[..., 2:4]], -1)
        # (N_images, 3, 4) exclude H, W, focal
        self.poses, self.pose_avg = center_poses(poses, self.blender2opencv)
        # Step 3: correct scale so that the nearest depth is at a little more than 1.0
        # See https://github.com/bmild/nerf/issues/34
        near_original = self.near_fars.min()
        scale_factor = near_original * 0.75  # 0.75 is the default parameter
        # the nearest depth is at 1/0.75=1.33
        self.near_fars /= scale_factor
        self.poses[..., 3] /= scale_factor
        # build rendering path
        N_views, N_rots = 120, 2
        tt = self.poses[:, :3, 3]  # ptstocam(poses[:3,3,:].T, c2w).T
        up = normalize(self.poses[:, :3, 1].sum(0))
        rads = np.percentile(np.abs(tt), 90, 0)
        self.render_path = get_spiral(self.poses, self.near_fars, N_views=N_views)
        # distances_from_center = np.linalg.norm(self.poses[..., 3], axis=1)
        # val_idx = np.argmin(distances_from_center)  # choose val image as the closest to
        # center image
        # ray directions for all pixels, same for all images (same H, W, focal)
        W, H = self.img_wh
        self.directions = get_ray_directions_blender(H, W, self.focal)  # (H, W, 3)
        average_pose = average_poses(self.poses)
        dists = np.sum(np.square(average_pose[:3, 3] - self.poses[:, :3, 3]), -1)
        # hold out every hold_every-th view for test; train on the rest
        i_test = np.arange(0, self.poses.shape[0], self.hold_every)  # [np.argmin(dists)]
        img_list = i_test if self.split != 'train' else list(set(np.arange(len(self.poses))) - set(i_test))
        # use first N_images-1 to train, the LAST is val
        self.all_rays = []
        self.all_rgbs = []
        for i in img_list:
            image_path = self.image_paths[i]
            c2w = torch.FloatTensor(self.poses[i])
            img = Image.open(image_path).convert('RGB')
            if self.downsample != 1.0:
                img = img.resize(self.img_wh, Image.LANCZOS)
            img = self.transform(img)  # (3, h, w)
            img = img.view(3, -1).permute(1, 0)  # (h*w, 3) RGB
            self.all_rgbs += [img]
            rays_o, rays_d = get_rays(self.directions, c2w)  # both (h*w, 3)
            # warp rays into normalized device coordinates (forward-facing trick)
            rays_o, rays_d = ndc_rays_blender(H, W, self.focal[0], 1.0, rays_o, rays_d)
            # viewdir = rays_d / torch.norm(rays_d, dim=-1, keepdim=True)
            self.all_rays += [torch.cat([rays_o, rays_d], 1)]  # (h*w, 6)
        if not self.is_stack:
            # training: flatten everything into one big ray buffer
            self.all_rays = torch.cat(self.all_rays, 0)  # (len(self.meta['frames])*h*w, 3)
            self.all_rgbs = torch.cat(self.all_rgbs, 0)  # (len(self.meta['frames])*h*w,3)
        else:
            # evaluation: keep per-image structure for whole-image rendering
            self.all_rays = torch.stack(self.all_rays, 0)  # (len(self.meta['frames]),h,w, 3)
            self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3)  # (len(self.meta['frames]),h,w,3)
    def define_transforms(self):
        # PIL image -> float tensor in [0, 1]
        self.transform = T.ToTensor()
    def __len__(self):
        return len(self.all_rgbs)
    def __getitem__(self, idx):
        sample = {'rays': self.all_rays[idx],
                  'rgbs': self.all_rgbs[idx]}
        return sample
================================================
FILE: dataLoader/nsvf.py
================================================
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
import os
from PIL import Image
from torchvision import transforms as T
from .ray_utils import *
def trans_t(t):
    """4x4 translation by *t* along +z (float tensor)."""
    m = torch.eye(4)
    m[2, 3] = t
    return m

def rot_phi(phi):
    """4x4 rotation by *phi* radians about the x axis."""
    c, s = np.cos(phi), np.sin(phi)
    return torch.Tensor([
        [1, 0, 0, 0],
        [0, c, -s, 0],
        [0, s, c, 0],
        [0, 0, 0, 1]]).float()

def rot_theta(th):
    """4x4 rotation by *th* radians about the y axis (NeRF sign convention)."""
    c, s = np.cos(th), np.sin(th)
    return torch.Tensor([
        [c, 0, -s, 0],
        [0, 1, 0, 0],
        [s, 0, c, 0],
        [0, 0, 0, 1]]).float()
def pose_spherical(theta, phi, radius):
    """Camera-to-world matrix for spherical coordinates.

    theta, phi are in degrees; radius is the distance from the origin.
    """
    c2w = trans_t(radius)
    c2w = rot_phi(phi / 180. * np.pi) @ c2w
    c2w = rot_theta(theta / 180. * np.pi) @ c2w
    # final axis swap/flip into the dataset's world convention
    world_flip = torch.Tensor(np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]))
    return world_flip @ c2w
class NSVF(Dataset):
    """NSVF Generic Dataset."""
    # NOTE(review): `wh` is a mutable default argument; it is only read here,
    # but mutating it at a call site would leak across instances.
    def __init__(self, datadir, split='train', downsample=1.0, wh=[800,800], is_stack=False):
        self.root_dir = datadir
        self.split = split
        self.is_stack = is_stack
        self.downsample = downsample
        self.img_wh = (int(wh[0]/downsample),int(wh[1]/downsample))
        self.define_transforms()
        self.white_bg = True
        self.near_far = [0.5,6.0]
        # scene bounds come from the dataset's bbox.txt (first 6 numbers)
        self.scene_bbox = torch.from_numpy(np.loadtxt(f'{self.root_dir}/bbox.txt')).float()[:6].view(2,3)
        self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
        self.read_meta()
        self.define_proj_mat()
        self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3)
        self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)
    def bbox2corners(self):
        """Return the 8 corner points of the scene bounding box, shape (8, 3)."""
        corners = self.scene_bbox.unsqueeze(0).repeat(4,1,1)
        for i in range(3):
            corners[i,[0,1],i] = corners[i,[1,0],i]
        return corners.view(-1,3)
    def read_meta(self):
        """Load intrinsics, per-image poses and RGB(A) images; build rays."""
        with open(os.path.join(self.root_dir, "intrinsics.txt")) as f:
            focal = float(f.readline().split()[0])
        # intrinsics assume an 800x800 image; rescale to the working resolution
        self.intrinsics = np.array([[focal,0,400.0],[0,focal,400.0],[0,0,1]])
        self.intrinsics[:2] *= (np.array(self.img_wh)/np.array([800,800])).reshape(2,1)
        pose_files = sorted(os.listdir(os.path.join(self.root_dir, 'pose')))
        img_files = sorted(os.listdir(os.path.join(self.root_dir, 'rgb')))
        # NSVF filename convention: '0_' train, '1_' val, '2_' test
        if self.split == 'train':
            pose_files = [x for x in pose_files if x.startswith('0_')]
            img_files = [x for x in img_files if x.startswith('0_')]
        elif self.split == 'val':
            pose_files = [x for x in pose_files if x.startswith('1_')]
            img_files = [x for x in img_files if x.startswith('1_')]
        elif self.split == 'test':
            # fall back to the val split when no dedicated test images exist
            test_pose_files = [x for x in pose_files if x.startswith('2_')]
            test_img_files = [x for x in img_files if x.startswith('2_')]
            if len(test_pose_files) == 0:
                test_pose_files = [x for x in pose_files if x.startswith('1_')]
                test_img_files = [x for x in img_files if x.startswith('1_')]
            pose_files = test_pose_files
            img_files = test_img_files
        # ray directions for all pixels, same for all images (same H, W, focal)
        self.directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsics[0,0],self.intrinsics[1,1]], center=self.intrinsics[:2,2])  # (h, w, 3)
        self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)
        # circular render path for video output
        frames = 200
        self.render_path = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,frames+1)[:-1]], 0)
        self.poses = []
        self.all_rays = []
        self.all_rgbs = []
        assert len(img_files) == len(pose_files)
        for img_fname, pose_fname in tqdm(zip(img_files, pose_files), desc=f'Loading data {self.split} ({len(img_files)})'):
            image_path = os.path.join(self.root_dir, 'rgb', img_fname)
            img = Image.open(image_path)
            if self.downsample!=1.0:
                img = img.resize(self.img_wh, Image.LANCZOS)
            img = self.transform(img)  # (4, h, w)
            img = img.view(img.shape[0], -1).permute(1, 0)  # (h*w, 4) RGBA
            if img.shape[-1]==4:
                img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:])  # blend A to RGB
            self.all_rgbs += [img]
            c2w = np.loadtxt(os.path.join(self.root_dir, 'pose', pose_fname)) #@ self.blender2opencv
            c2w = torch.FloatTensor(c2w)
            self.poses.append(c2w)  # C2W
            rays_o, rays_d = get_rays(self.directions, c2w)  # both (h*w, 3)
            self.all_rays += [torch.cat([rays_o, rays_d], 1)]  # (h*w, 8)
            # w2c = torch.inverse(c2w)
            #
        self.poses = torch.stack(self.poses)
        if 'train' == self.split:
            if self.is_stack:
                self.all_rays = torch.stack(self.all_rays, 0).reshape(-1,*self.img_wh[::-1], 6)  # (len(self.meta['frames])*h*w, 3)
                self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3)  # (len(self.meta['frames])*h*w, 3)
            else:
                self.all_rays = torch.cat(self.all_rays, 0)  # (len(self.meta['frames])*h*w, 3)
                self.all_rgbs = torch.cat(self.all_rgbs, 0)  # (len(self.meta['frames])*h*w, 3)
        else:
            self.all_rays = torch.stack(self.all_rays, 0)  # (len(self.meta['frames]),h*w, 3)
            self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3)  # (len(self.meta['frames]),h,w,3)
    def define_transforms(self):
        # PIL image -> float tensor in [0, 1]
        self.transform = T.ToTensor()
    def define_proj_mat(self):
        # per-image world-to-pixel projection: K @ W2C (top 3 rows)
        self.proj_mat = torch.from_numpy(self.intrinsics[:3,:3]).unsqueeze(0).float() @ torch.inverse(self.poses)[:,:3]
    def world2ndc(self, points):
        """Map world points into [-1, 1] relative to the scene bbox center/radius."""
        device = points.device
        return (points - self.center.to(device)) / self.radius.to(device)
    def __len__(self):
        if self.split == 'train':
            return len(self.all_rays)
        return len(self.all_rgbs)
    def __getitem__(self, idx):
        if self.split == 'train':  # use data in the buffers
            sample = {'rays': self.all_rays[idx],
                      'rgbs': self.all_rgbs[idx]}
        else:  # create data for each image separately
            img = self.all_rgbs[idx]
            rays = self.all_rays[idx]
            sample = {'rays': rays,
                      'rgbs': img}
        return sample
================================================
FILE: dataLoader/ray_utils.py
================================================
import torch, re, json
import numpy as np
from torch import searchsorted
from kornia import create_meshgrid
def load_json(path):
    """Read and parse the JSON file at `path`."""
    with open(path, 'r') as fp:
        return json.load(fp)
# from utils import index_point_feature
def depth2dist(z_vals, cos_angle):
    """Convert depth samples to inter-sample distances along each ray.

    z_vals: [N_ray, N_sample]; the last interval is padded with 1e10, and all
    distances are scaled by the per-ray cos_angle.
    """
    device = z_vals.device
    deltas = z_vals[..., 1:] - z_vals[..., :-1]
    far_pad = torch.Tensor([1e10]).to(device).expand(deltas[..., :1].shape)
    deltas = torch.cat([deltas, far_pad], -1)  # [N_rays, N_samples]
    return deltas * cos_angle.unsqueeze(-1)
def ndc2dist(ndc_pts, cos_angle):
    """Distances between consecutive NDC points; last slot padded with 1e10*cos."""
    segment_len = torch.norm(ndc_pts[:, 1:] - ndc_pts[:, :-1], dim=-1)
    return torch.cat([segment_len, 1e10 * cos_angle.unsqueeze(-1)], -1)  # [N_rays, N_samples]
def get_ray_directions(H, W, focal, center=None):
    """
    Get ray directions for all pixels in camera coordinate (+z forward).
    Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
               ray-tracing-generating-camera-rays/standard-coordinate-systems
    Inputs:
        H, W: image height and width
        focal: (fx, fy) focal lengths
        center: optional principal point; defaults to the image center
    Outputs:
        directions: (H, W, 3), the direction of the rays in camera coordinate
    """
    grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5  # pixel centers
    u, v = grid.unbind(-1)
    cx, cy = center if center is not None else [W / 2, H / 2]
    return torch.stack([(u - cx) / focal[0], (v - cy) / focal[1], torch.ones_like(u)], -1)  # (H, W, 3)
def get_ray_directions_blender(H, W, focal, center=None):
    """
    Get ray directions for all pixels in camera coordinate, Blender convention
    (y up, -z forward).
    Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
               ray-tracing-generating-camera-rays/standard-coordinate-systems
    Inputs:
        H, W: image height and width
        focal: (fx, fy) focal lengths
        center: optional principal point; defaults to the image center
    Outputs:
        directions: (H, W, 3), the direction of the rays in camera coordinate
    """
    grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5  # pixel centers
    u, v = grid.unbind(-1)
    cx, cy = center if center is not None else [W / 2, H / 2]
    return torch.stack([(u - cx) / focal[0], -(v - cy) / focal[1], -torch.ones_like(u)], -1)  # (H, W, 3)
def get_rays(directions, c2w):
    """
    Get ray origins and directions in world coordinates for one image.
    Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
               ray-tracing-generating-camera-rays/standard-coordinate-systems
    Inputs:
        directions: (H, W, 3) precomputed ray directions in camera coordinate
        c2w: (3, 4) transformation matrix from camera coordinate to world coordinate
    Outputs:
        rays_o: (H*W, 3), the origin of the rays in world coordinate
        rays_d: (H*W, 3), the direction of the rays in world coordinate
    """
    # rotate camera-frame directions into the world frame
    rotation = c2w[:3, :3]
    rays_d = (directions @ rotation.T).reshape(-1, 3)
    # every ray starts at the camera center
    rays_o = c2w[:3, 3].expand(directions.shape).reshape(-1, 3)
    return rays_o, rays_d
def ndc_rays_blender(H, W, focal, near, rays_o, rays_d):
    """Warp rays into NDC space (Blender/OpenGL convention, forward = -z)."""
    # slide ray origins onto the near plane
    t = -(near + rays_o[..., 2]) / rays_d[..., 2]
    rays_o = rays_o + t[..., None] * rays_d
    ox, oy, oz = rays_o[..., 0], rays_o[..., 1], rays_o[..., 2]
    dx, dy, dz = rays_d[..., 0], rays_d[..., 1], rays_d[..., 2]
    # perspective projection into the NDC cube
    sx = -1. / (W / (2. * focal))
    sy = -1. / (H / (2. * focal))
    o0 = sx * ox / oz
    o1 = sy * oy / oz
    o2 = 1. + 2. * near / oz
    d0 = sx * (dx / dz - ox / oz)
    d1 = sy * (dy / dz - oy / oz)
    d2 = -2. * near / oz
    return torch.stack([o0, o1, o2], -1), torch.stack([d0, d1, d2], -1)
def ndc_rays(H, W, focal, near, rays_o, rays_d):
    """Warp rays into NDC space (+z forward convention)."""
    # slide ray origins onto the near plane
    t = (near - rays_o[..., 2]) / rays_d[..., 2]
    rays_o = rays_o + t[..., None] * rays_d
    ox, oy, oz = rays_o[..., 0], rays_o[..., 1], rays_o[..., 2]
    dx, dy, dz = rays_d[..., 0], rays_d[..., 1], rays_d[..., 2]
    # perspective projection into the NDC cube
    sx = 1. / (W / (2. * focal))
    sy = 1. / (H / (2. * focal))
    o0 = sx * ox / oz
    o1 = sy * oy / oz
    o2 = 1. - 2. * near / oz
    d0 = sx * (dx / dz - ox / oz)
    d1 = sy * (dy / dz - oy / oz)
    d2 = 2. * near / oz
    return torch.stack([o0, o1, o2], -1), torch.stack([d0, d1, d2], -1)
# Hierarchical sampling (section 5.2)
def sample_pdf(bins, weights, N_samples, det=False, pytest=False):
    """Draw N_samples depths per ray by inverting the CDF of `weights` over `bins`.

    bins: (batch, M+1) bin edges; weights: (batch, M) unnormalized weights.
    det=True places the CDF query points uniformly instead of randomly.
    pytest=True overwrites the draws with numpy's seeded RNG for reproducible tests.
    Returns: (batch, N_samples) sampled depths.
    """
    device = weights.device
    # Get pdf
    weights = weights + 1e-5  # prevent nans
    pdf = weights / torch.sum(weights, -1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)  # (batch, len(bins))
    # Take uniform samples
    if det:
        u = torch.linspace(0., 1., steps=N_samples, device=device)
        u = u.expand(list(cdf.shape[:-1]) + [N_samples])
    else:
        u = torch.rand(list(cdf.shape[:-1]) + [N_samples], device=device)
    # Pytest, overwrite u with numpy's fixed random numbers
    if pytest:
        np.random.seed(0)
        new_shape = list(cdf.shape[:-1]) + [N_samples]
        if det:
            u = np.linspace(0., 1., N_samples)
            u = np.broadcast_to(u, new_shape)
        else:
            u = np.random.rand(*new_shape)
        # NOTE(review): torch.Tensor(u) lands on the CPU regardless of `device`
        u = torch.Tensor(u)
    # Invert CDF: locate each u between two CDF values and lerp the bin edges
    u = u.contiguous()
    inds = searchsorted(cdf.detach(), u, right=True)
    below = torch.max(torch.zeros_like(inds - 1), inds - 1)  # clamp at the first bin
    above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)  # clamp at the last
    inds_g = torch.stack([below, above], -1)  # (batch, N_samples, 2)
    matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
    cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
    bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
    denom = (cdf_g[..., 1] - cdf_g[..., 0])
    # guard against zero-width CDF intervals
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
    t = (u - cdf_g[..., 0]) / denom
    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
    return samples
def dda(rays_o, rays_d, bbox_3D):
    """Ray/AABB entry and exit distances via the slab method.

    rays_o, rays_d: (N_rays, 3); bbox_3D: (2, 3) min/max corners.
    Returns (t_min, t_max), each (N_rays, 1).
    """
    inv_d = 1.0 / (rays_d + 1e-6)  # epsilon avoids division by zero
    t_lo = (bbox_3D[:1] - rays_o) * inv_d  # N_rays 3
    t_hi = (bbox_3D[1:] - rays_o) * inv_d
    both = torch.stack((t_lo, t_hi))  # 2 N_rays 3
    t_min = torch.max(torch.min(both, dim=0)[0], dim=-1, keepdim=True)[0]
    t_max = torch.min(torch.max(both, dim=0)[0], dim=-1, keepdim=True)[0]
    return t_min, t_max
def ray_marcher(rays,
                N_samples=64,
                lindisp=False,
                perturb=0,
                bbox_3D=None):
    """Sample N_samples 3D points along each ray.

    rays: (N_rays, 8) packed as [origin(3), direction(3), near, far].
    lindisp samples linearly in disparity instead of depth; perturb > 0
    jitters samples within their intervals; bbox_3D overrides near/far with
    ray/AABB intersection distances.
    Returns (points (N_rays, N_samples, 3), rays_o, rays_d, z_vals).
    """
    n_rays = rays.shape[0]
    rays_o, rays_d = rays[:, 0:3], rays[:, 3:6]  # both (N_rays, 3)
    near, far = rays[:, 6:7], rays[:, 7:8]  # both (N_rays, 1)
    if bbox_3D is not None:
        # clip the sampling range to the bounding box
        near, far = dda(rays_o, rays_d, bbox_3D)
    steps = torch.linspace(0, 1, N_samples, device=rays.device)  # (N_samples)
    if lindisp:
        # linear in disparity space
        z_vals = 1 / (1 / near * (1 - steps) + 1 / far * steps)
    else:
        # linear in depth space
        z_vals = near * (1 - steps) + far * steps
    z_vals = z_vals.expand(n_rays, N_samples)
    if perturb > 0:
        # jitter each sample inside its interval (stratified sampling)
        mids = 0.5 * (z_vals[:, :-1] + z_vals[:, 1:])  # (N_rays, N_samples-1)
        upper = torch.cat([mids, z_vals[:, -1:]], -1)
        lower = torch.cat([z_vals[:, :1], mids], -1)
        jitter = perturb * torch.rand(z_vals.shape, device=rays.device)
        z_vals = lower + (upper - lower) * jitter
    points = rays_o.unsqueeze(1) + rays_d.unsqueeze(1) * z_vals.unsqueeze(2)  # (N_rays, N_samples, 3)
    return points, rays_o, rays_d, z_vals
def read_pfm(filename):
    """Read a PFM (Portable Float Map) image file.

    Returns (data, scale): `data` is (H, W, 3) for color 'PF' files or (H, W)
    for grayscale 'Pf', flipped so row 0 is the top of the image; `scale` is
    the absolute scale factor from the header.

    Raises:
        Exception: on a malformed header (type kept for caller compatibility).
    """
    # `with` guarantees the descriptor is closed even when a parse error is
    # raised below (the previous version leaked it on every exception path).
    with open(filename, 'rb') as file:
        header = file.readline().decode('utf-8').rstrip()
        if header == 'PF':
            color = True
        elif header == 'Pf':
            color = False
        else:
            raise Exception('Not a PFM file.')
        dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8'))
        if not dim_match:
            raise Exception('Malformed PFM header.')
        width, height = map(int, dim_match.groups())
        scale = float(file.readline().rstrip())
        if scale < 0:  # a negative scale marks little-endian data
            endian = '<'
            scale = -scale
        else:
            endian = '>'  # big-endian
        data = np.fromfile(file, endian + 'f')
    shape = (height, width, 3) if color else (height, width)
    data = np.reshape(data, shape)
    data = np.flipud(data)  # PFM stores rows bottom-to-top
    return data, scale
class SimpleSampler:
    """Yields index batches drawn from repeated random permutations of range(total)."""

    def __init__(self, total, batch):
        self.total = total
        self.batch = batch
        self.curr = total  # past the end, so the first nextids() reshuffles
        self.ids = None

    def nextids(self):
        """Return the next `batch` indices, drawing a fresh permutation when exhausted."""
        self.curr += self.batch
        if self.curr + self.batch > self.total:
            self.ids = torch.LongTensor(np.random.permutation(self.total))
            self.curr = 0
        start = self.curr
        return self.ids[start:start + self.batch]
def ndc_bbox(all_rays):
    """Axis-aligned bounds covering both ray origins and origin+direction endpoints.

    all_rays: (..., 6) packed [origin(3), direction(3)].  Returns a (2, 3) stack
    of (min, max) corners; also prints the intermediate bounds.
    """
    origins = all_rays[..., :3].view(-1, 3)
    endpoints = (all_rays[..., :3] + all_rays[..., 3:6]).view(-1, 3)
    near_min = torch.min(origins, dim=0)[0]
    near_max = torch.max(origins, dim=0)[0]
    far_min = torch.min(endpoints, dim=0)[0]
    far_max = torch.max(endpoints, dim=0)[0]
    print(f'===> ndc bbox near_min:{near_min} near_max:{near_max} far_min:{far_min} far_max:{far_max}')
    return torch.stack((torch.minimum(near_min, far_min), torch.maximum(near_max, far_max)))
def pose_from_json(meta, transpose):
    """Stack the c2w matrices of a transforms.json dict, each right-multiplied by `transpose`."""
    return np.stack([np.array(frame['transform_matrix']) @ transpose
                     for frame in meta['frames']])
def rotation_matrix_from_vectors(vec1, vec2):
    """Find the rotation matrix that aligns vec1 to vec2 (Rodrigues' formula).

    :param vec1: A 3d "source" vector
    :param vec2: A 3d "destination" vector
    :return mat: A (3, 3) rotation which, applied to vec1, aligns it with vec2.
    """
    a = (vec1 / np.linalg.norm(vec1)).reshape(3)
    b = (vec2 / np.linalg.norm(vec2)).reshape(3)
    axis = np.cross(a, b)
    if not any(axis):
        # zero cross product: identical directions -> identity.
        # NOTE(review): anti-parallel inputs also land here; confirm callers
        # never pass opposite vectors.
        return np.eye(3)
    cos_ab = np.dot(a, b)
    sin_ab = np.linalg.norm(axis)
    K = np.array([[0, -axis[2], axis[1]],
                  [axis[2], 0, -axis[0]],
                  [-axis[1], axis[0], 0]])
    return np.eye(3) + K + K.dot(K) * ((1 - cos_ab) / (sin_ab ** 2))
def normalize(x):
    """Return x divided by its Euclidean norm."""
    magnitude = np.linalg.norm(x)
    return x / magnitude
def rotation_up(poses):
    """Estimate the shared 'up' direction from the camera y axes and return
    the rotation mapping it onto world +y.

    poses: (N, 3+, 4) camera-to-world matrices; column 1 holds each camera's
    y axis.  The up vector is the least-squares fit of a common direction to
    all y axes.
    """
    # rcond=None selects NumPy's documented default cutoff and silences the
    # FutureWarning emitted when rcond is omitted.
    up = normalize(np.linalg.lstsq(poses[:, :3, 1], np.ones((poses.shape[0], 1)), rcond=None)[0])
    rot = rotation_matrix_from_vectors(up, np.array([0., 1., 0.]))
    return rot
def search_orientation(points):
    """Grid-search a y-axis rotation in [-45, 45] degrees minimizing the
    rotated axis-aligned bounding-box volume.

    points: (3, N) array of coordinates.
    Returns (rot, bbox): the best 3x3 rotation and the per-axis extents of
    the rotated points.
    """
    from scipy.spatial.transform import Rotation as R
    candidates = []
    for y_angle in np.linspace(-45, 45, 15):
        rotvec = np.array([0, y_angle, 0]) / 180 * np.pi
        rot = R.from_rotvec(rotvec).as_matrix()
        rotated = rot @ points
        extent = np.max(rotated, axis=1) - np.min(rotated, axis=1)
        candidates.append((np.prod(extent), rot, extent))
    # first minimal volume wins, matching np.argmin's tie-breaking
    best = min(range(len(candidates)), key=lambda i: candidates[i][0])
    _, rot, bbox = candidates[best]
    return rot, bbox
def load_point_txt(path):
    """Parse a COLMAP points3D.txt file into an (M, 3) array.

    Skips comment lines and points with reprojection error > 0.25, then
    drops the 5% of points farthest from the centroid as outliers.
    """
    pts = []
    with open(path, "r") as fh:
        for row in fh:
            if row[0] == "#":
                continue
            fields = row.split(" ")
            if float(fields[7]) > 0.25:  # reprojection error filter
                continue
            pts.append([float(fields[1]), float(fields[2]), float(fields[3])])
    pts = np.stack(pts)
    # mask out outliers: keep the 95% of points closest to the centroid
    centroid = np.mean(pts, axis=0)
    dist = np.sqrt(np.sum((pts - centroid) ** 2, axis=1))
    keep = np.argsort(dist)[:int(pts.shape[0] * 0.95)]
    return pts[keep]
def p34_to_44(p):
    """Append a [0, 0, 0, 1] row to each (3, 4) pose, yielding (N, 4, 4)."""
    bottom = np.tile(np.reshape(np.eye(4)[-1, :], [1, 1, 4]), [p.shape[0], 1, 1])
    return np.concatenate([p, bottom], 1)
def orientation(poses, point_path=None):
    """Re-orient and normalize camera poses (and optionally a COLMAP point cloud).

    Rotates the set so the common 'up' is +y, centers on the mean camera
    position, searches a y rotation minimizing the point bounding box, and
    rescales so the smallest bbox side is 1.  Side effect: writes points.txt
    and poses_center.txt to the working directory (debug output).

    Returns (poses (N, 4, 4), aabb as nested lists scaled by 1.5).
    """
    if point_path is not None:
        points = load_point_txt(point_path)
    else:
        points = poses[:,:3,-1]  # no point cloud: use the camera centers
    # np.savetxt('points_1.txt', points)
    # np.savetxt('poses_center_1.txt', poses[:, :3, -1])
    rot_up = rotation_up(poses)
    poses_new = rot_up[None] @ poses[:, :3]
    points = rot_up[:3, :3] @ points.T  # points are (3, M) from here on
    # center the scene on the mean camera position
    center = np.mean(poses_new[:,:3,-1], axis=0, keepdims=True)
    poses_new[:, :3, -1] -= center
    points -= center.T
    # scale = 1.5
    # find the y rotation that tightens the bbox of the camera centers
    rot, bbox = search_orientation(poses_new[:, :3, -1].T)
    poses_new[:, :3, :4] = rot[None] @ poses_new[:, :3, :4]
    points = rot @ points
    aabb = np.stack((np.min(points,axis=1),np.max(points,axis=1)))
    # normalize so the smallest aabb extent becomes 1
    poses_new[:, :3, -1] /= np.min(aabb[1]-aabb[0])
    points /= np.min(aabb[1] - aabb[0])
    aabb = (aabb/np.min(aabb[1]-aabb[0])*1.5).tolist()
    np.savetxt('points.txt', points.T)
    np.savetxt('poses_center.txt', poses_new[:, :3, -1])
    return p34_to_44(poses_new), aabb
def spherify_poses(poses, radus=1):
    """Recenter poses around the point nearest to all camera z axes and scale
    them onto a unit(ish) sphere; also build a circular render path.

    poses: (N, 3, 4); `radus` (sic) scales the fitted sphere radius.
    Returns (poses_reset (N, 4, 4), render_poses (120, 3, 4), scale).
    """
    p34_to_44 = lambda p: np.concatenate([p, np.tile(np.reshape(np.eye(4)[-1, :], [1, 1, 4]), [p.shape[0], 1, 1])], 1)
    rays_d = poses[:, :3, 2:3]
    rays_o = poses[:, :3, 3:4]
    def min_line_dist(rays_o, rays_d):
        # least-squares point minimizing distance to all camera viewing lines
        A_i = np.eye(3) - rays_d * np.transpose(rays_d, [0, 2, 1])
        b_i = -A_i @ rays_o
        pt_mindist = np.squeeze(-np.linalg.inv((np.transpose(A_i, [0, 2, 1]) @ A_i).mean(0)) @ (b_i).mean(0))
        return pt_mindist
    pt_mindist = min_line_dist(rays_o, rays_d)
    center = pt_mindist
    up = (poses[:, :3, 3] - center).mean(0)
    # build an orthonormal frame with `up` as its z axis
    vec0 = normalize(up)
    vec1 = normalize(np.cross([.1, .2, .3], vec0))  # arbitrary seed vector
    vec2 = normalize(np.cross(vec0, vec1))
    pos = center
    c2w = np.stack([vec1, vec2, vec0, pos], 1)
    # express every pose in this recentered frame
    poses_reset = np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:, :3, :4])
    # RMS distance of camera centers, scaled by `radus`
    rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:, :3, 3]), -1)))*radus
    # print(poses_reset,center,rad)
    scale = 1. / rad
    poses_reset[:, :3, 3] *= scale
    # bds *= sc
    rad *= scale
    # print ('========>',poses_reset)
    centroid = np.mean(poses_reset[:, :3, 3], 0)
    zh = centroid[2]
    # circle of cameras at height zh on the fitted sphere
    radcircle = np.sqrt(rad ** 2 - zh ** 2)
    render_poses = []
    for th in np.linspace(0., 2. * np.pi, 120):
        camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh])
        up = np.array([0, 0, -1.])
        vec2 = normalize(camorigin)
        vec0 = normalize(np.cross(vec2, up))
        vec1 = normalize(np.cross(vec2, vec0))
        pos = camorigin
        p = np.stack([vec0, vec1, vec2, pos], 1)
        render_poses.append(p)
    render_poses = np.stack(render_poses, 0)
    # render_poses = np.concatenate([render_poses, np.broadcast_to(poses[0, :3, -1:], render_poses[:, :3, -1:].shape)], -1)
    # poses_reset = np.concatenate(
    #     [poses_reset[:, :3, :4], np.broadcast_to(poses[0, :3, -1:], poses_reset[:, :3, -1:].shape)], -1)
    return poses_reset, render_poses, scale
================================================
FILE: dataLoader/sdf.py
================================================
import torch
import numpy as np
from torch.utils.data import Dataset
def N_to_reso(avg_reso, bbox):
    """Distribute avg_reso**dim voxels over bbox; return the per-axis grid size.

    bbox: pair of torch tensors (xyz_min, xyz_max).
    """
    xyz_min, xyz_max = bbox
    dim = len(xyz_min)
    extent = xyz_max - xyz_min
    # cube-root (dim-root) of the volume per voxel gives the voxel edge length
    voxel_size = (extent.prod() / avg_reso ** dim).pow(1 / dim)
    return torch.ceil(extent / voxel_size).long().tolist()
def load(path, split, dtype='points'):
    """Load SDF supervision data.

    dtype='grid': `path` is a dense (D, H, W) .npy SDF volume; returns one
    integer (x, y, z) coordinate per voxel.
    dtype='points': `path` is a pickled dict .npy holding 'sdfs_{split}' and
    'points_{split}' with points in [-1, 1]; coordinates are rescaled into a
    640^3 voxel grid.

    Returns (coordiante, sdf, DHW).

    Raises:
        ValueError: for an unknown dtype (previously this fell through and
        crashed with a NameError at the return statement).
    """
    if 'grid' == dtype:
        sdf = torch.from_numpy(np.load(path).astype(np.float32))
        D, H, W = sdf.shape
        z, y, x = torch.meshgrid(torch.arange(0, D), torch.arange(0, H), torch.arange(0, W), indexing='ij')
        coordiante = torch.stack((x, y, z), -1).reshape(D * H * W, 3)  # *2-1 # normalize to [-1,1]
        sdf = sdf.reshape(D * H * W, -1)
        DHW = [D, H, W]
    elif 'points' == dtype:
        DHW = [640] * 3
        sdf_dict = np.load(path, allow_pickle=True).item()
        sdf = torch.from_numpy(sdf_dict[f'sdfs_{split}'].astype(np.float32)).reshape(-1, 1)
        coordiante = torch.from_numpy(sdf_dict[f'points_{split}'].astype(np.float32))
        # map [-1, 1] point coordinates into voxel units
        coordiante = ((coordiante + 1) / 2 * (torch.tensor(DHW[::-1]))).reshape(-1, 3)
        DHW = DHW[::-1]
    else:
        raise ValueError(f"load(): unknown dtype '{dtype}' (expected 'grid' or 'points')")
    return coordiante, sdf, DHW
class SDFDataset(Dataset):
    """Map-style dataset over SDF samples loaded from cfg.datadir."""

    def __init__(self, cfg, split='train'):
        self.coordiante, self.sdf, self.DHW = load(cfg.datadir, split)
        [depth, height, width] = self.DHW
        self.scene_bbox = [[0., 0., 0.], [width, height, depth]]
        cfg.aabb = self.scene_bbox  # publish the bounds back to the config

    def __len__(self):
        return len(self.sdf)

    def __getitem__(self, idx):
        """Return one {'rgb': sdf value, 'xy': coordinate} pair."""
        return {'rgb': self.sdf[idx], 'xy': self.coordiante[idx]}
================================================
FILE: dataLoader/tankstemple.py
================================================
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
import os
from PIL import Image
from torchvision import transforms as T
from .ray_utils import *
def circle(radius=3.5, h=0.0, axis='z', t0=0, r=1):
    """Return f(t) -> 3D point on a circle of `radius` about `axis` at offset `h`.

    t0 is a phase offset; r is a frequency multiplier on t.
    """
    def point_at(t):
        a = radius * np.cos(r * t + t0)
        b = radius * np.sin(r * t + t0)
        if axis == 'z':
            return [a, b, h]
        if axis == 'y':
            return [a, h, b]
        return [h, a, b]
    return point_at
def cross(x, y, axis=0):
    """Cross product dispatched to torch or numpy based on input type."""
    backend = torch if isinstance(x, torch.Tensor) else np
    return backend.cross(x, y, axis)
def normalize(x, axis=-1, order=2):
    """Normalize `x` along `axis` with the given `order` norm.

    Returns (normalized, norm).  The torch branch guards against zero norms
    with a 1e-8 epsilon; the numpy branch maps zero norms to 1 instead.

    Bug fix: the numpy branch previously returned a 1-tuple ``(x / l2,)``
    while the torch branch returned ``(normalized, l2)``.  Callers index
    ``[0]``, so returning the full pair here is backward-compatible and makes
    the two branches consistent.
    """
    if isinstance(x, torch.Tensor):
        l2 = x.norm(p=order, dim=axis, keepdim=True)
        return x / (l2 + 1e-8), l2
    else:
        l2 = np.linalg.norm(x, order, axis)
        l2 = np.expand_dims(l2, axis)
        l2[l2 == 0] = 1
        return x / l2, l2
def cat(x, axis=1):
    """Concatenate a sequence of torch tensors or numpy arrays along `axis`."""
    first = x[0]
    if isinstance(first, torch.Tensor):
        return torch.cat(x, dim=axis)
    return np.concatenate(x, axis=axis)
def look_at_rotation(camera_position, at=None, up=None, inverse=False, cv=False):
    """
    This function takes a vector 'camera_position' which specifies the location
    of the camera in world coordinates and two vectors `at` and `up` which
    indicate the position of the object and the up directions of the world
    coordinate system respectively. The object is assumed to be centered at
    the origin.
    The output is a rotation matrix representing the transformation
    from world coordinates -> view coordinates.
    Input:
        camera_position: 3
        at: 1 x 3 or N x 3 (0, 0, 0) in default
        up: 1 x 3 or N x 3 (0, 1, 0) in default

    NOTE(review): the `inverse` and `cv` arguments are accepted but never
    used in this body.
    """
    # Default look-at target is the world origin.
    if at is None:
        at = torch.zeros_like(camera_position)
    else:
        at = torch.tensor(at).type_as(camera_position)
    # Default up vector is -z (note: docstring above says (0, 1, 0)).
    if up is None:
        up = torch.zeros_like(camera_position)
        up[2] = -1
    else:
        up = torch.tensor(up).type_as(camera_position)

    # Build an orthonormal frame: forward, then right, then recomputed up.
    z_axis = normalize(at - camera_position)[0]
    x_axis = normalize(cross(up, z_axis))[0]
    y_axis = normalize(cross(z_axis, x_axis))[0]

    # Stack the axes as columns to form the 3x3 rotation matrix.
    R = cat([x_axis[:, None], y_axis[:, None], z_axis[:, None]], axis=1)
    return R
def gen_path(pos_gen, at=(0, 0, 0), up=(0, -1, 0), frames=180):
    """Build `frames` camera-to-world matrices along the trajectory `pos_gen`.

    `pos_gen` maps an angle (radians) to a 3D camera position; each pose
    looks at `at` with the given `up` direction.
    """
    poses = []
    for frame_idx in range(frames):
        angle = frame_idx * (360.0 / frames) / 180 * np.pi
        cam_pos = torch.tensor(pos_gen(angle))
        cam_rot = look_at_rotation(cam_pos, at=at, up=up, inverse=False, cv=True)
        pose = torch.eye(4)
        pose[:3, 3], pose[:3, :3] = cam_pos, cam_rot
        poses.append(pose)
    return torch.stack(poses)
class TanksTempleDataset(Dataset):
    """NSVF Generic Dataset.

    Loads a Tanks&Temples scene in NSVF layout: `intrinsics.txt`, a `pose/`
    directory with one 4x4 camera-to-world matrix per frame, and an `rgb/`
    directory with the images. Frames are assigned to splits by filename
    prefix: '0_' train, '1_' val, '2_' test (falls back to '1_' if no '2_').
    """

    def __init__(self, cfg, split='train'):
        self.root_dir = cfg.datadir
        self.split = split
        # Training flattens rays into one big buffer; eval keeps per-image stacks.
        self.is_stack = False if 'train'==split else True
        self.downsample = cfg.get(f'downsample_{self.split}')
        # Source images are 1920x1080.
        self.img_wh = (int(1920 / self.downsample), int(1080 / self.downsample))
        self.define_transforms()

        self.white_bg = True
        self.near_far = [0.01, 6.0]
        # bbox.txt: first six values are [xmin ymin zmin xmax ymax zmax]; the
        # box is enlarged by 20% and kept as a plain nested list.
        self.scene_bbox = (torch.from_numpy(np.loadtxt(f'{self.root_dir}/bbox.txt')).float()[:6].view(2, 3) * 1.2).tolist()
        self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
        self.read_meta()
        self.define_proj_mat()
        # self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3)
        # self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)

    def bbox2corners(self):
        # NOTE(review): self.scene_bbox is a plain list after __init__
        # (.tolist() above), but .unsqueeze() requires a tensor — verify
        # callers convert it before using this helper.
        corners = self.scene_bbox.unsqueeze(0).repeat(4, 1, 1)
        # Swap min/max per axis to enumerate corner combinations.
        for i in range(3):
            corners[i, [0, 1], i] = corners[i, [1, 0], i]
        return corners.view(-1, 3)

    def read_meta(self):
        """Load intrinsics, per-frame poses and images; fill ray/rgb buffers."""
        self.intrinsics = np.loadtxt(os.path.join(self.root_dir, "intrinsics.txt"))
        # Rescale focal lengths / principal point to the downsampled resolution.
        self.intrinsics[:2] *= (np.array(self.img_wh) / np.array([1920, 1080])).reshape(2, 1)

        pose_files = sorted(os.listdir(os.path.join(self.root_dir, 'pose')))
        img_files = sorted(os.listdir(os.path.join(self.root_dir, 'rgb')))

        # Filename prefixes encode the split membership.
        if self.split == 'train':
            pose_files = [x for x in pose_files if x.startswith('0_')]
            img_files = [x for x in img_files if x.startswith('0_')]
        elif self.split == 'val':
            pose_files = [x for x in pose_files if x.startswith('1_')]
            img_files = [x for x in img_files if x.startswith('1_')]
        elif self.split == 'test':
            test_pose_files = [x for x in pose_files if x.startswith('2_')]
            test_img_files = [x for x in img_files if x.startswith('2_')]
            # Fall back to the val split if no dedicated test frames exist.
            if len(test_pose_files) == 0:
                test_pose_files = [x for x in pose_files if x.startswith('1_')]
                test_img_files = [x for x in img_files if x.startswith('1_')]
            pose_files = test_pose_files
            img_files = test_img_files

        # ray directions for all pixels, same for all images (same H, W, focal)
        self.directions = get_ray_directions(self.img_wh[1], self.img_wh[0],
                                             [self.intrinsics[0, 0], self.intrinsics[1, 1]],
                                             center=self.intrinsics[:2, 2])  # (h, w, 3)
        self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)

        self.poses = []
        ray_per_frame = self.img_wh[0]*self.img_wh[1]
        # Pre-allocated buffers: one row of rays/colors per frame.
        self.all_rays = torch.empty(len(pose_files),ray_per_frame,6)
        self.all_rgbs = torch.empty(len(pose_files),ray_per_frame,3)
        assert len(img_files) == len(pose_files)
        for i in tqdm(range(len(pose_files)),desc=f'Loading data {self.split} ({len(img_files)})'):
            img_fname, pose_fname = img_files[i], pose_files[i]
            image_path = os.path.join(self.root_dir, 'rgb', img_fname)
            img = Image.open(image_path)
            if self.downsample != 1.0:
                img = img.resize(self.img_wh, Image.LANCZOS)
            img = self.transform(img)  # (4, h, w)
            img = img.view(img.shape[0], -1).permute(1, 0)  # (h*w, 4) RGBA
            if img.shape[-1] == 4:
                img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:])  # blend A to RGB
            # self.all_rgbs.append(img)

            c2w = np.loadtxt(os.path.join(self.root_dir, 'pose', pose_fname))  # @ cam_trans
            c2w = torch.FloatTensor(c2w)
            self.poses.append(c2w)  # C2W
            rays_o, rays_d = get_rays(self.directions, c2w)  # both (h*w, 3)
            self.all_rays[i] = torch.cat([rays_o, rays_d], 1)  # (h*w, 6)
            self.all_rgbs[i] = img

        self.poses = torch.stack(self.poses)

        # Build a circular fly-through path around the scene center for rendering.
        frames = 200
        scene_bbox = torch.tensor(self.scene_bbox).float()
        center = torch.mean(scene_bbox, dim=0)
        radius = torch.norm(scene_bbox[1] - center) * 1.2
        up = torch.mean(self.poses[:, :3, 1], dim=0).tolist()
        pos_gen = circle(radius=radius, h=-0.2 * up[1], axis='y')
        self.render_path = gen_path(pos_gen, up=up, frames=frames)
        self.render_path[:, :3, 3] += center

        if 'train' == self.split:
            if not self.is_stack:
                # self.all_rays = torch.stack(self.all_rays, 0).reshape(-1, *self.img_wh[::-1], 6)  # (len(self.meta['frames])*h*w, 3)
                # self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1, *self.img_wh[::-1], 3)  # (len(self.meta['frames])*h*w, 3)
                # else:
                # Flatten all frames into one pool of rays for random batching.
                self.all_rays = self.all_rays.reshape(-1,6)  # (len(self.meta['frames])*h*w, 3)
                self.all_rgbs = self.all_rgbs.reshape(-1,3)  # (len(self.meta['frames])*h*w, 3)
        else:
            # self.all_rays = torch.stack(self.all_rays, 0)  # (len(self.meta['frames]),h*w, 3)
            self.all_rgbs = self.all_rgbs.reshape(-1, *self.img_wh[::-1], 3)  # (len(self.meta['frames]),h,w,3)

    def define_transforms(self):
        # PIL image -> float tensor in [0, 1].
        self.transform = T.ToTensor()

    def define_proj_mat(self):
        # World-to-image projection: K @ W2C[:3] for every pose.
        self.proj_mat = torch.from_numpy(self.intrinsics[:3, :3]).unsqueeze(0).float() @ torch.inverse(self.poses)[:, :3]

    def world2ndc(self, points):
        # NOTE(review): self.center / self.radius are commented out in
        # __init__, so calling this as-is raises AttributeError — confirm
        # intended usage before relying on it.
        device = points.device
        return (points - self.center.to(device)) / self.radius.to(device)

    def __len__(self):
        if self.split == 'train':
            return len(self.all_rays)
        return len(self.all_rgbs)

    def __getitem__(self, idx):
        if self.split == 'train':  # use data in the buffers
            sample = {'rays': self.all_rays[idx],
                      'rgbs': self.all_rgbs[idx]}
        else:  # create data for each image separately
            img = self.all_rgbs[idx]
            rays = self.all_rays[idx]
            sample = {'rays': rays,
                      'rgbs': img}
        return sample
================================================
FILE: dataLoader/your_own_data.py
================================================
import torch,cv2
from torch.utils.data import Dataset
import json
from tqdm import tqdm
import os
from PIL import Image
from torchvision import transforms as T
from .ray_utils import *
class YourOwnDataset(Dataset):
    """Blender-style custom dataset loaded from `transforms_{split}.json`.

    Expects per-frame camera-to-world matrices plus global intrinsics
    (camera_angle_x/y, cx, cy, w, h) in the JSON file, and PNG images
    referenced by each frame's `file_path`.
    """

    def __init__(self, datadir, split='train', downsample=1.0, is_stack=False, N_vis=-1):
        # N_vis limits how many frames are loaded (-1 = all).
        self.N_vis = N_vis
        self.root_dir = datadir
        self.split = split
        self.is_stack = is_stack
        self.downsample = downsample
        self.define_transforms()

        self.scene_bbox = torch.tensor([[-1.5, -1.5, -1.5], [1.5, 1.5, 1.5]])
        # Blender uses a different camera convention than OpenCV; poses are
        # converted with this flip matrix in read_meta.
        self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
        self.read_meta()
        self.define_proj_mat()

        self.white_bg = True
        self.near_far = [0.1,100.0]

        self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3)
        self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)
        self.downsample=downsample

    def read_depth(self, filename):
        # NOTE(review): read_pfm is presumably provided by the `ray_utils`
        # star import — confirm it is exported there.
        depth = np.array(read_pfm(filename)[0], dtype=np.float32)  # (800, 800)
        return depth

    def read_meta(self):
        """Parse the transforms JSON and populate pose/ray/rgb buffers."""
        with open(os.path.join(self.root_dir, f"transforms_{self.split}.json"), 'r') as f:
            self.meta = json.load(f)

        w, h = int(self.meta['w']/self.downsample), int(self.meta['h']/self.downsample)
        self.img_wh = [w,h]
        self.focal_x = 0.5 * w / np.tan(0.5 * self.meta['camera_angle_x'])  # original focal length
        self.focal_y = 0.5 * h / np.tan(0.5 * self.meta['camera_angle_y'])  # original focal length
        self.cx, self.cy = self.meta['cx'],self.meta['cy']

        # ray directions for all pixels, same for all images (same H, W, focal)
        self.directions = get_ray_directions(h, w, [self.focal_x,self.focal_y], center=[self.cx, self.cy])  # (h, w, 3)
        self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)
        self.intrinsics = torch.tensor([[self.focal_x,0,self.cx],[0,self.focal_y,self.cy],[0,0,1]]).float()

        self.image_paths = []
        self.poses = []
        self.all_rays = []
        self.all_rgbs = []
        # NOTE(review): all_masks / all_depth are declared but never filled in
        # this method; see __getitem__ below.
        self.all_masks = []
        self.all_depth = []

        # Subsample frames evenly when N_vis caps the number of frames loaded.
        img_eval_interval = 1 if self.N_vis < 0 else len(self.meta['frames']) // self.N_vis
        idxs = list(range(0, len(self.meta['frames']), img_eval_interval))
        for i in tqdm(idxs, desc=f'Loading data {self.split} ({len(idxs)})'):#img_list:#

            frame = self.meta['frames'][i]
            pose = np.array(frame['transform_matrix']) @ self.blender2opencv
            c2w = torch.FloatTensor(pose)
            self.poses += [c2w]

            image_path = os.path.join(self.root_dir, f"{frame['file_path']}.png")
            self.image_paths += [image_path]
            img = Image.open(image_path)

            if self.downsample!=1.0:
                img = img.resize(self.img_wh, Image.LANCZOS)
            img = self.transform(img)  # (4, h, w)
            img = img.view(-1, w*h).permute(1, 0)  # (h*w, 4) RGBA
            if img.shape[-1]==4:
                img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:])  # blend A to RGB
            self.all_rgbs += [img]

            rays_o, rays_d = get_rays(self.directions, c2w)  # both (h*w, 3)
            self.all_rays += [torch.cat([rays_o, rays_d], 1)]  # (h*w, 6)

        self.poses = torch.stack(self.poses)
        if not self.is_stack:
            # Flat pools of rays/colors for random training batches.
            self.all_rays = torch.cat(self.all_rays, 0)  # (len(self.meta['frames])*h*w, 3)
            self.all_rgbs = torch.cat(self.all_rgbs, 0)  # (len(self.meta['frames])*h*w, 3)

            # self.all_depth = torch.cat(self.all_depth, 0)  # (len(self.meta['frames])*h*w, 3)
        else:
            self.all_rays = torch.stack(self.all_rays, 0)  # (len(self.meta['frames]),h*w, 3)
            self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3)  # (len(self.meta['frames]),h,w,3)
            # self.all_masks = torch.stack(self.all_masks, 0).reshape(-1,*self.img_wh[::-1])  # (len(self.meta['frames]),h,w,3)

    def define_transforms(self):
        # PIL image -> float tensor in [0, 1].
        self.transform = T.ToTensor()

    def define_proj_mat(self):
        # World-to-image projection: K @ W2C[:3] for every pose.
        self.proj_mat = self.intrinsics.unsqueeze(0) @ torch.inverse(self.poses)[:,:3]

    def world2ndc(self,points,lindisp=None):
        device = points.device
        return (points - self.center.to(device)) / self.radius.to(device)

    def __len__(self):
        return len(self.all_rgbs)

    def __getitem__(self, idx):
        if self.split == 'train':  # use data in the buffers
            sample = {'rays': self.all_rays[idx],
                      'rgbs': self.all_rgbs[idx]}
        else:  # create data for each image separately
            img = self.all_rgbs[idx]
            rays = self.all_rays[idx]
            # NOTE(review): self.all_masks is never populated in read_meta, so
            # this lookup raises IndexError for non-train splits; `mask` is
            # also unused below — confirm whether mask support was abandoned.
            mask = self.all_masks[idx]  # for quantity evaluation

            sample = {'rays': rays,
                      'rgbs': img}
        return sample
================================================
FILE: models/FactorFields.py
================================================
import torch, math
import torch.nn
import torch.nn.functional as F
import numpy as np
import time, skimage
from utils import N_to_reso, N_to_vm_reso
# import BasisCoding
def grid_mapping(positions, freq_bands, aabb, basis_mapping='sawtooth'):
    """Map world `positions` to per-frequency local coordinates, mostly in [-1, 1].

    Each entry of `freq_bands` defines a tiling period (largest aabb extent
    divided by the frequency); `basis_mapping` selects how positions are
    folded into one tile. Output gains a trailing frequency axis (doubled
    for 'trigonometric', which emits sin and cos).
    """
    aabbSize = max(aabb[1] - aabb[0])
    period = aabbSize[..., None] / freq_bands
    offset = positions - aabb[0]
    if basis_mapping == 'triangle':
        # Fold with alternating direction so the mapping stays continuous.
        local = offset.unsqueeze(-1) % period
        parity = (offset.unsqueeze(-1) // period) % 2
        local = local / (period / 2) - 1
        local = torch.where(parity == 1, -local, local)
    elif basis_mapping == 'sawtooth':
        local = offset[..., None] % period
        local = local / (period / 2) - 1
        local = local.clamp(-1., 1.)
    elif basis_mapping == 'sinc':
        local = torch.sin(offset[..., None] / (period / np.pi) - np.pi / 2)
    elif basis_mapping == 'trigonometric':
        local = offset[..., None] / period * 2 * np.pi
        local = torch.cat((torch.sin(local), torch.cos(local)), dim=-1)
    elif basis_mapping == 'x':
        # Plain linear rescaling, no folding.
        local = offset.unsqueeze(-1) / period
    return local
def dct_dict(n_atoms_fre, size, n_selete, dim=2):
    """
    Create a dictionary of DCT (Discrete Cosine Transform) basis atoms.

    :param n_atoms_fre: number of 1D frequencies per axis
    :param size: spatial resolution of one axis of an atom
    :param n_selete: keep at most this many atoms (evenly subsampled)
    :param dim: 2 for image atoms, 3 for volume atoms
    :return: FloatTensor of shape (n_atoms, size**dim) with L2-normalized rows
    """
    dct_1d = np.zeros((n_atoms_fre, size))
    for freq in range(n_atoms_fre):
        atom = np.cos(np.arange(size) * freq * math.pi / n_atoms_fre)
        # Remove the DC component of every non-constant atom.
        if freq > 0:
            atom = atom - np.mean(atom)
        dct_1d[freq] = atom

    # Kronecker products lift the 1D basis to 2D (and optionally 3D) atoms.
    atoms = np.kron(dct_1d, dct_1d)
    if 3 == dim:
        atoms = np.kron(atoms, dct_1d)

    if n_selete < atoms.shape[0]:
        # Evenly subsample rows down to n_selete atoms.
        keep = [chunk[0] for chunk in np.array_split(np.arange(atoms.shape[0]), n_selete)]
        atoms = atoms[keep]

    for row in range(atoms.shape[0]):
        # `or 1` guards against all-zero rows.
        norm = np.linalg.norm(atoms[row]) or 1
        atoms[row] /= norm
    return torch.FloatTensor(atoms)
def positional_encoding(positions, freqs):
    """Sinusoidal encoding: [sin(x * 2^i), cos(x * 2^i)] for i in [0, freqs)."""
    octaves = (2 ** torch.arange(freqs).float()).to(positions.device)  # (F,)
    # Interleave each input channel with its F scaled copies, then flatten.
    scaled = (positions[..., None] * octaves).reshape(
        positions.shape[:-1] + (freqs * positions.shape[-1],))  # (..., D*F)
    return torch.cat([torch.sin(scaled), torch.cos(scaled)], dim=-1)
def raw2alpha(sigma, dist):
    """Standard volume-rendering conversion from densities to alphas/weights.

    sigma, dist: [N_rays, N_samples]
    :return: (alpha, per-sample weights, background transmittance [N_rays, 1])
    """
    alpha = 1. - torch.exp(-sigma * dist)
    ones = torch.ones_like(alpha[..., :1])
    # Transmittance up to (but excluding) each sample; epsilon keeps cumprod stable.
    trans = torch.cumprod(torch.cat([ones, 1. - alpha + 1e-10], -1), -1)
    weights = alpha * trans[..., :-1]  # [N_rays, N_samples]
    return alpha, weights, trans[..., -1:]
class AlphaGridMask(torch.nn.Module):
    """Occupancy grid over an aabb, queried with trilinear interpolation."""

    def __init__(self, device, aabb, alpha_volume):
        super(AlphaGridMask, self).__init__()
        self.device = device
        self.aabb = aabb.to(self.device)
        self.aabbSize = self.aabb[1] - self.aabb[0]
        # Scale factor mapping world coords to grid_sample's [-1, 1] range.
        self.invgridSize = 1.0 / self.aabbSize * 2
        # Stored as (1, 1, D, H, W) so it feeds F.grid_sample directly.
        self.alpha_volume = alpha_volume.view(1, 1, *alpha_volume.shape[-3:])
        # Grid extents in (W, H, D) order.
        self.gridSize = torch.LongTensor(
            [alpha_volume.shape[-1], alpha_volume.shape[-2], alpha_volume.shape[-3]]).to(self.device)

    def sample_alpha(self, xyz_sampled):
        """Interpolate the alpha volume at world-space points of shape (N, 3)."""
        grid_coords = self.normalize_coord(xyz_sampled)
        return F.grid_sample(self.alpha_volume, grid_coords.view(1, -1, 1, 1, 3),
                             align_corners=True).view(-1)

    def normalize_coord(self, xyz_sampled):
        """Map world coordinates into [-1, 1] relative to the aabb."""
        return (xyz_sampled - self.aabb[0]) * self.invgridSize - 1
class MLPMixer(torch.nn.Module):
    """Plain MLP with optional positional encoding and train-time dropout.

    The final layer is bias-free; all hidden layers use ReLU.
    """

    def __init__(self,
                 in_dim,
                 out_dim=16,
                 num_layers=2,
                 hidden_dim=64, pe=0, with_dropout=False):
        super().__init__()
        self.with_dropout = with_dropout
        # Positional encoding appends 2*pe extra channels per input dimension.
        self.in_dim = in_dim + 2 * in_dim * pe
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.pe = pe

        layers = []
        for layer_idx in range(num_layers):
            fan_in = self.in_dim if layer_idx == 0 else self.hidden_dim
            is_last = layer_idx == num_layers - 1
            fan_out = out_dim if is_last else self.hidden_dim
            # Only the output layer drops its bias.
            layers.append(torch.nn.Linear(fan_in, fan_out, bias=not is_last))
        self.backbone = torch.nn.ModuleList(layers)

    def forward(self, x, is_train=False):
        # x: [B, in_dim]
        h = x
        if self.pe > 0:
            h = torch.cat([h, positional_encoding(h, self.pe)], dim=-1)
        if self.with_dropout and is_train:
            h = F.dropout(h, p=0.1)
        for layer_idx, layer in enumerate(self.backbone):
            h = layer(h)
            if layer_idx != self.num_layers - 1:
                h = F.relu(h, inplace=True)
        return h
class MLPRender_Fea(torch.nn.Module):
    """View-dependent RGB decoder over [features, viewdirs] plus their encodings."""

    def __init__(self, inChanel, num_layers=3, hidden_dim=64, viewpe=6, feape=2):
        super(MLPRender_Fea, self).__init__()
        # 3 raw view-dir channels + features + both positional encodings.
        self.in_mlpC = 3 + inChanel + 2 * viewpe * 3 + 2 * feape * inChanel
        self.num_layers = num_layers
        self.viewpe = viewpe
        self.feape = feape

        layers = []
        for layer_idx in range(num_layers):
            fan_in = self.in_mlpC if layer_idx == 0 else hidden_dim
            is_last = layer_idx == num_layers - 1
            # Final layer emits 3 RGB channels and carries no bias.
            fan_out = 3 if is_last else hidden_dim
            layers.append(torch.nn.Linear(fan_in, fan_out, bias=not is_last))
        self.mlp = torch.nn.ModuleList(layers)

    def forward(self, viewdirs, features):
        parts = [features, viewdirs]
        if self.feape > 0:
            parts.append(positional_encoding(features, self.feape))
        if self.viewpe > 0:
            parts.append(positional_encoding(viewdirs, self.viewpe))
        h = torch.cat(parts, dim=-1)
        for layer_idx, layer in enumerate(self.mlp):
            h = layer(h)
            if layer_idx != self.num_layers - 1:
                h = F.relu(h, inplace=True)
        # Sigmoid keeps the predicted color in [0, 1].
        return torch.sigmoid(h)
class FactorFields(torch.nn.Module):
    def __init__(self, cfg, device):
        """Factor-Fields model: a coefficient field modulating a basis field.

        :param cfg: experiment config (model / dataset / training / renderer sections)
        :param device: torch device for all parameters
        """
        super(FactorFields, self).__init__()
        self.cfg = cfg
        self.device = device

        # Plane/axis index layouts used by the 'vm' (vector-matrix) factorizations.
        self.matMode = [[0, 1], [0, 2], [1, 2]]
        self.vecMode = [2, 1, 0]

        self.n_scene, self.scene_idx = 1, 0
        self.alphaMask = None
        self.coeff_type, self.basis_type = cfg.model.coeff_type, cfg.model.basis_type

        self.setup_params(self.cfg.dataset.aabb)

        # Either factor can be disabled with 'none'.
        if self.cfg.model.coeff_type != 'none':
            self.coeffs = self.init_coef()
        if self.cfg.model.basis_type != 'none':
            self.basises = self.init_basis()

        out_dim = cfg.model.out_dim
        # Input width of the projection MLP depends on how features are produced.
        if 'vm' in self.coeff_type:
            in_dim = sum(cfg.model.basis_dims) * 3
        elif 'x' in self.cfg.model.basis_type:
            # 'trigonometric' mapping emits sin+cos, doubling the channel count.
            in_dim = len(
                cfg.model.basis_dims) * 2 * self.in_dim if self.cfg.model.basis_mapping == 'trigonometric' else len(
                cfg.model.basis_dims) * self.in_dim
        else:
            in_dim = sum(cfg.model.basis_dims)
        self.linear_mat = MLPMixer(in_dim, out_dim, num_layers=cfg.model.num_layers, hidden_dim=cfg.model.hidden_dim,
                                   with_dropout=cfg.model.with_dropout).to(device)

        if 'reconstruction' in cfg.defaults.mode:
            # self.cur_volumeSize = N_to_reso(cfg.training.volume_resoInit, self.aabb)
            # self.update_renderParams(self.cur_volumeSize)
            view_pe, fea_pe = cfg.renderer.view_pe, cfg.renderer.fea_pe
            num_layers, hidden_dim = cfg.renderer.num_layers, cfg.renderer.hidden_dim
            # One feature channel is consumed as density; the rest feed the shader.
            self.renderModule = MLPRender_Fea(inChanel=out_dim - 1, num_layers=num_layers, hidden_dim=hidden_dim,
                                              viewpe=view_pe, feape=fea_pe).to(device)

            self.is_unbound = self.cfg.dataset.is_unbound
            if self.is_unbound:
                # Unbounded scene: grow the aabb by a background shell of 20%.
                self.bg_len = 0.2
                self.inward_aabb = torch.tensor([[-1.0, -1.0, -1.0], [1.0, 1.0, 1.0]]).to(device)
                self.aabb = self.inward_aabb * (1 + self.bg_len)
            else:
                self.inward_aabb = self.aabb
            # self.freq_bands = torch.FloatTensor(cfg.model.freq_bands).to(device)

            # update_renderParams is defined outside this view; it is assumed to
            # set nSamples/stepSize used by the sampling helpers below.
            self.cur_volumeSize = N_to_reso(cfg.training.volume_resoInit ** self.in_dim, self.aabb)
            self.update_renderParams(self.cur_volumeSize)

        print('=====> total parameters: ', self.n_parameters())
    def setup_params(self, aabb):
        """Derive resolutions, frequency bands and parameter budgets from the aabb.

        Behavior branches on cfg.defaults.mode ('image', 'images', 'sdf',
        'reconstruction', 'reconstructions'). Sets: in_dim, aabb, basis_dims,
        basis_reso, coeff_reso, freq_bands, T_basis, T_coeff, n_scene.
        """
        # Multi-item modes ('images'/'reconstructions') reserve the last aabb
        # column as the scene index, so the spatial dimensionality is one less.
        self.in_dim = len(aabb[0]) - 1 if (
                'images' == self.cfg.defaults.mode or 'reconstructions' == self.cfg.defaults.mode) else len(aabb[0])
        self.aabb = torch.FloatTensor(aabb)[:, :self.in_dim].to(self.device)

        self.basis_dims = self.cfg.model.basis_dims
        if 'reconstruction' not in self.cfg.defaults.mode:
            # Scale the configured basis resolutions to the scene extent
            # (1024 is the reference extent the defaults were tuned for).
            self.basis_reso = self.cfg.model.basis_resos if 'image' in self.cfg.defaults.mode else np.round(
                np.array(self.cfg.model.basis_resos) * (min(aabb[1][:self.in_dim]) + 1) / 1024.0).astype('int').tolist()
            # Parameter budgets: explicit cfg values win; otherwise derive the
            # basis budget from the grids and give the remainder to the coeffs.
            self.T_basis = self.cfg.model.T_basis if self.cfg.model.T_basis>0 else sum(np.power(np.array(self.basis_reso), self.in_dim) * np.array(self.cfg.model.basis_dims))
            self.T_coeff = self.cfg.model.T_coeff if self.cfg.model.T_coeff>0 else self.cfg.model.total_params - self.T_basis
            self.T_coeff = self.T_coeff if self.T_coeff > 0 else 8 ** self.in_dim * sum(self.basis_dims)

            if 'image' == self.cfg.defaults.mode:
                # Frequencies chosen so each basis level tiles the image extent.
                self.freq_bands = max(aabb[1][:self.in_dim]) / torch.FloatTensor(self.basis_reso).to(self.device)
            else:
                self.freq_bands = torch.FloatTensor(self.cfg.model.freq_bands).to(self.device)
            self.coeff_reso = N_to_reso(self.T_coeff // sum(self.basis_dims), self.aabb[:, :self.in_dim])[::-1]  # DHW

            self.n_scene = 1  # int(aabb[1][-1]) if 'images' == self.cfg.defaults.mode else 1
            if 'sdf' == self.cfg.defaults.mode:
                self.freq_bands *= 0.5
            elif 'images' == self.cfg.defaults.mode:
                # Prepend the scene count as an extra leading coefficient axis.
                self.coeff_reso = [aabb[1][-1]] + self.coeff_reso
                self.aabb = torch.FloatTensor(aabb).to(self.device)
        else:
            if 'vec' in self.coeff_type or 'cp' in self.coeff_type or 'vm' in self.coeff_type:
                self.coeff_reso = aabb[1]
            else:
                self.coeff_reso = N_to_reso(self.cfg.model.coeff_reso ** self.in_dim, self.aabb[:, :self.in_dim])[::-1]
            self.T_coeff = sum(self.cfg.model.basis_dims) * np.prod(self.coeff_reso)
            self.T_basis = self.cfg.model.total_params - self.T_coeff
            # Rescale basis resolutions so their total size matches T_basis.
            scale = self.T_basis / sum(
                np.power(np.array(self.cfg.model.basis_resos), self.in_dim) * np.array(self.cfg.model.basis_dims))
            scale = np.power(scale, 1.0 / self.in_dim)
            self.basis_reso = self.cfg.model.basis_resos if (
                    'vec' in self.basis_type or 'cp' in self.basis_type) else np.round(
                np.array(self.cfg.model.basis_resos) * scale).astype('int').tolist()

            # self.freq_bands = self.cfg.dataset.scene_reso / torch.FloatTensor(self.basis_resos).to(self.device)
            # Frequencies are renormalized to the scene resolution unless the
            # basis type handles scale itself ('x'/'vec'/'cp') or we are in
            # the multi-scene 'reconstructions' mode.
            self.freq_bands = torch.FloatTensor(self.cfg.model.freq_bands).to(self.device) \
                if (
                    'reconstructions' == self.cfg.defaults.mode or 'x' in self.basis_type or 'vec' in self.basis_type or 'cp' in self.basis_type) else torch.FloatTensor(
                self.cfg.model.freq_bands).to(self.device) * (self.cfg.dataset.scene_reso / float(
                max(self.basis_reso)) / max(self.cfg.model.freq_bands))
            self.n_scene = int(aabb[1][-1]) if 'reconstructions' == self.cfg.defaults.mode else 1
        # print(self.coeff_reso,self.basis_reso,self.freq_bands)
    def init_coef(self):
        """Allocate the coefficient field according to cfg.model.coeff_type.

        :return: a ParameterList of dense grids ('hash'/'grid'), per-axis
                 vectors ('cp'/'vm'), a single shared vector ('vec'), or
                 small MLPs ('mlp'); one entry set per scene in multi-scene modes.
        """
        # Only the multi-item modes keep one coefficient set per scene.
        n_scene = self.n_scene if 'reconstructions' == self.cfg.defaults.mode or 'images' == self.cfg.defaults.mode else 1
        if 'hash' in self.coeff_type or 'grid' in self.coeff_type:
            # Constant initialization at cfg.model.coef_init.
            coeffs = [
                self.cfg.model.coef_init * torch.ones((1, sum(self.basis_dims), *self.coeff_reso), device=self.device)
                for _ in range(n_scene)]
            coeffs = torch.nn.ParameterList(coeffs)
        elif 'cp' in self.coeff_type or 'vm' in self.coeff_type:
            # One 1D line per coordinate axis, at least 256 samples long.
            coeffs = []
            for i in range(len(self.coeff_reso)):
                coeffs.append(self.cfg.model.coef_init * torch.ones(
                    (1, sum(self.basis_dims), max(256, self.coeff_reso[i]), n_scene), device=self.device))
            coeffs = torch.nn.ParameterList(coeffs)
        elif 'vec' in self.coeff_type:
            # A single shared vector indexed by scene along its last axis.
            coeffs = self.cfg.model.coef_init * torch.ones(
                (1, sum(self.basis_dims), max(256, max(self.coeff_reso)), n_scene), device=self.device)
            coeffs = torch.nn.ParameterList([coeffs])
        elif 'mlp' in self.coeff_type:
            coeffs = torch.nn.ParameterList(
                [MLPMixer(self.in_dim, sum(self.basis_dims), num_layers=2, hidden_dim=64, pe=4).to(self.device) for _ in
                 range(n_scene)])
        return coeffs
    def init_basis(self):
        """Allocate the basis field according to cfg.model.basis_type.

        'hash' uses tinycudann multiresolution hash encodings; otherwise one
        entry per frequency level is built as an MLP, DCT-initialized grid,
        'vm' plane set, or 'cp' line set. 'x' uses the raw mapped coordinates
        and allocates nothing.
        """
        if 'hash' in self.basis_type:
            # Deferred import: tinycudann is only needed for the hash variant.
            import tinycudann as tcnn
            n_levels = len(self.basis_reso)
            if 'reconstruction' not in self.cfg.defaults.mode:
                base_resolution_low, base_resolution_high = 32, int(max(self.aabb[1]).item()) // 2
                per_level_scale = torch.pow(self.freq_bands[-1] / self.freq_bands[0], 1.0 / n_levels).item()
            else:
                base_resolution_low = torch.round(self.basis_reso[0] * self.freq_bands[0]).long().item()
                per_level_scale = torch.pow(self.freq_bands[0] / self.freq_bands[-1], 1.0 / n_levels).item()
                base_resolution_high = torch.round(
                    self.basis_reso[n_levels // 2] * self.freq_bands[n_levels // 2]).long().item()

            # Hash-table size chosen so the total stays within the T_basis budget.
            log2_hashmap_size = np.round(np.log2(self.T_basis / 3 / np.mean(self.basis_dims)))
            # per_level_scale = torch.pow((self.basis_reso[-1] * self.freq_bands[-1])/(self.basis_reso[0] * self.freq_bands[0]),1.0/n_levels).item()
            # per_level_scale = torch.pow(self.freq_bands[0] / self.freq_bands[-1], 1.0 / n_levels).item()
            basises = []
            if len(self.basis_dims) == 1 or sum(self.basis_dims) > 32:
                # Single encoding covering all levels.
                encoding_config_low = {
                    "otype": "HashGrid",
                    "n_levels": n_levels,
                    "n_features_per_level": sum(self.basis_dims) // n_levels,
                    "log2_hashmap_size": log2_hashmap_size,
                    "base_resolution": base_resolution_low,
                    "per_level_scale": per_level_scale  # 1.25992 #1.38191 #
                }
                basises.append(tcnn.Encoding(
                    n_input_dims=self.in_dim,
                    encoding_config=encoding_config_low))
            else:
                # Split into a low- and a high-resolution encoding.
                encoding_config_low = {
                    "otype": "HashGrid",
                    "n_levels": n_levels // 2,
                    "n_features_per_level": min(16, self.basis_dims[0]),
                    "log2_hashmap_size": log2_hashmap_size,
                    "base_resolution": base_resolution_low,
                    "per_level_scale": per_level_scale  # 1.25992 #1.38191 #
                }
                encoding_config_high = {
                    "otype": "HashGrid",
                    "n_levels": n_levels // 2,
                    "n_features_per_level": self.basis_dims[-1],
                    "log2_hashmap_size": log2_hashmap_size,
                    "base_resolution": base_resolution_high,
                    "per_level_scale": per_level_scale  # 1.25992 #1.38191 #
                }
                basises = []
                basises.append(tcnn.Encoding(
                    n_input_dims=self.in_dim,
                    encoding_config=encoding_config_low))
                # With 32-wide first level, duplicate the low encoding (two
                # 16-feature tables); see get_basis which concatenates them.
                if self.basis_dims[0] == 32:
                    basises.append(tcnn.Encoding(
                        n_input_dims=self.in_dim,
                        encoding_config=encoding_config_low))
                basises.append(tcnn.Encoding(
                    n_input_dims=self.in_dim,
                    encoding_config=encoding_config_high))
            return torch.nn.ParameterList(basises)
        else:
            basises, coeffs, n_params_basis = [], [], 0
            # in_dim = self.in_dim if 'images' != self.cfg.defaults.mode else self.in_dim - 1
            # One basis entry (or entry group for vm/cp) per frequency level.
            for i, (basis_dim, reso) in enumerate(zip(self.basis_dims, self.basis_reso)):
                # reso_cur = N_to_reso(reso, aabb)[::-1]
                if 'mlp' in self.basis_type:
                    basises.append(MLPMixer(self.in_dim, basis_dim, num_layers=2, \
                                            hidden_dim=64, pe=4).to(self.device))
                elif 'grid' in self.basis_type:
                    # Dense grid initialized with DCT atoms.
                    basises.append(torch.nn.Parameter(dct_dict(int(np.power(basis_dim, 1. / self.in_dim) + 1), reso,
                                                               n_selete=basis_dim, dim=self.in_dim).reshape(
                        [1, basis_dim] + [reso] * self.in_dim).to(self.device)))
                    # basises.append(torch.nn.Parameter(torch.ones([1, basis_dim] + [reso] * self.in_dim).to(self.device)))
                elif 'vm' in self.basis_type:
                    # Three 2D planes per level (one per matMode pair).
                    reso_level = N_to_vm_reso(reso ** self.in_dim, self.aabb[:, :self.in_dim])
                    for i in range(len(self.matMode)):
                        mat_id_0, mat_id_1 = self.matMode[i]
                        basises.append(torch.nn.Parameter(
                            0.1 * torch.randn((1, basis_dim, reso_level[mat_id_1], reso_level[mat_id_0]),
                                              device=self.device)))
                elif 'cp' in self.basis_type:
                    # (in_dim - 1) 1D lines per level, at least 128 samples.
                    for _ in range(self.in_dim - 1):
                        basises.append(torch.nn.Parameter(
                            0.1 * torch.randn((1, basis_dim, max(reso, 128), 1), device=self.device)))
                elif 'x' in self.basis_type:
                    # Identity basis: nothing to allocate.
                    continue
            return torch.nn.ParameterList(basises)
    def get_coeff(self, xyz_sampled):
        """Sample the coefficient field at the given points.

        :param xyz_sampled: (N, dim) world-space points
        :return: (N, sum(basis_dims)) coefficients (x3 concatenated for 'vm')
        """
        N_points, dim = xyz_sampled.shape
        in_dim = self.in_dim
        # normalize_coord is defined outside this view; presumably maps points
        # into grid_sample's [-1, 1] range — confirm.
        pts = self.normalize_coord(xyz_sampled).view([1, -1] + [1] * (dim - 1) + [dim])
        if self.coeff_type in 'hash':
            # NOTE(review): this tests whether coeff_type is a SUBSTRING of
            # 'hash' (true for 'h', 'has', 'hash', ...), and it makes the
            # `'hash' in self.coeff_type` branch below unreachable for
            # coeff_type == 'hash'. Presumably `'hash' in self.coeff_type`
            # was intended — confirm before relying on either branch.
            coeffs = self.coeffs(pts * 0.5 + 0.5).float()
        elif 'grid' in self.coeff_type:
            coeffs = F.grid_sample(self.coeffs[self.scene_idx], pts, mode=self.cfg.model.coef_mode, align_corners=False,
                                   padding_mode='border').view(-1, N_points).t()
        elif 'vec' in self.coeff_type:
            # Look up along the shared vector; the scene index selects a column.
            pts = pts.view(1, -1, 1, in_dim)
            idx = (self.scene_idx + 0.5) / self.n_scene * 2 - 1
            pts = torch.stack((torch.ones_like(pts[..., 0]) * idx, pts[..., 0]), dim=-1)
            coeffs = F.grid_sample(self.coeffs[0], pts, mode=self.cfg.model.coef_mode,
                                   align_corners=False, padding_mode='border').view(-1, N_points).t()
        elif 'cp' in self.coeff_type:
            # CP decomposition: product of per-axis line lookups.
            pts = pts.squeeze(2)
            idx = (self.scene_idx + 0.5) / self.n_scene * 2 - 1
            pts = torch.stack((torch.ones_like(pts) * idx, pts), dim=-2)
            coeffs = F.grid_sample(self.coeffs[0], pts[..., 0], mode=self.cfg.model.coef_mode,
                                   align_corners=False, padding_mode='border').view(-1, N_points).t()
            for i in range(1, in_dim):
                coeffs = coeffs * F.grid_sample(self.coeffs[i], pts[..., i], mode=self.cfg.model.coef_mode,
                                                align_corners=False, padding_mode='border').view(-1, N_points).t()
        elif 'vm' in self.coeff_type:
            # Vector part of the VM decomposition: concatenate per-axis lookups.
            pts = pts.squeeze(2)
            idx = (self.scene_idx + 0.5) / self.n_scene * 2 - 1
            pts = torch.stack((torch.ones_like(pts) * idx, pts), dim=-2)
            coeffs = []
            for i in range(in_dim):
                coeffs.append(F.grid_sample(self.coeffs[i], pts[..., self.vecMode[i]], mode=self.cfg.model.coef_mode,
                                            align_corners=False, padding_mode='border').view(-1, N_points).t())
            coeffs = torch.cat(coeffs, dim=-1)
        elif 'mlp' in self.coeff_type:
            coeffs = self.coeffs[self.scene_idx](pts.view(N_points, in_dim))
        elif 'hash' in self.coeff_type:
            coeffs = self.coeffs[self.scene_idx]((pts.view(N_points, in_dim) + 1) / 2)
        return coeffs
    def get_basis(self, x):
        """Evaluate the basis field at points x.

        :param x: (N, dim) world-space points; in 'images' mode the last
                  column is the scene index and is stripped.
        :return: (N, sum(basis_dims)) basis features.
        """
        N_points = x.shape[0]
        if 'images' == self.cfg.defaults.mode:
            x = x[..., :-1]
        if 'hash' in self.basis_type:
            # Hash encodings expect coordinates in [0, 1].
            x = (x - self.aabb[0]) / torch.max(self.aabb[1] - self.aabb[0])
            # NOTE(review): the first two length checks use `if`/`if` (not
            # elif), so len == 2 overwrites the len == 1 result; harmless but
            # worth confirming intent.
            if len(self.basises) == 1:
                basises = self.basises[0](x).float()
            if len(self.basises) == 2:
                basises = torch.cat((self.basises[0](x), self.basises[1](x)), dim=-1).float()
            elif len(self.basises) == 3:
                basises = torch.cat((self.basises[0](x), self.basises[1](x), self.basises[2](x)), dim=-1).float()
        else:
            freq_len = len(self.freq_bands)
            # Map points into per-frequency local coordinates, shaped for grid_sample.
            xyz = grid_mapping(x, self.freq_bands, self.aabb[:, :self.in_dim], self.cfg.model.basis_mapping).view(1, *(
                    [1] * (self.in_dim - 1)), -1, self.in_dim, freq_len)
            basises = []
            for i in range(freq_len):
                if 'mlp' in self.basis_type:
                    basises.append(self.basises[i](xyz[..., i].view(-1, self.in_dim)))
                elif 'grid' in self.basis_type:
                    basises.append(
                        F.grid_sample(self.basises[i], xyz[..., i], mode=self.cfg.model.basis_mode,
                                      align_corners=True).view(-1, N_points).T)
                elif 'vm' in self.basis_type:
                    # Sample each of the three planes with its coordinate pair.
                    coordinate_mat = torch.stack((xyz[..., self.matMode[0], i], xyz[..., self.matMode[1], i],
                                                  xyz[..., self.matMode[2], i])).view(3, -1, 1, 2)
                    for idx_mat in range(self.in_dim):
                        basises.append(F.grid_sample(self.basises[i * self.in_dim + idx_mat], coordinate_mat[[idx_mat]],
                                                     align_corners=True).view(-1, x.shape[0]).t())
                elif 'cp' in self.basis_type:
                    # Product of per-axis line lookups (axis 0 is skipped).
                    for idx_axis in range(self.in_dim - 1):
                        coordinate_vec = torch.stack(
                            (torch.zeros_like(xyz[..., idx_axis + 1, i]), xyz[..., idx_axis + 1, i]), dim=-1).squeeze(2)
                        if 0 == idx_axis:
                            basises_level = F.grid_sample(self.basises[i * (self.in_dim - 1) + idx_axis],
                                                          coordinate_vec,
                                                          align_corners=True).view(-1, x.shape[0]).t()
                        else:
                            basises_level = basises_level * F.grid_sample(
                                self.basises[i * (self.in_dim - 1) + idx_axis], coordinate_vec,
                                align_corners=True).view(-1, x.shape[0]).t()
                    basises.append(basises_level)
                elif 'x' in self.basis_type:
                    # Identity basis: the mapped coordinates themselves.
                    basises.append(xyz[..., i].view(x.shape[0], -1))
            if isinstance(basises, list):
                basises = torch.cat(basises, dim=-1)
        if 'vm' in self.basis_type:  # switch order
            # Regroup so channels of the same plane are contiguous.
            basises = basises.view(x.shape[0], freq_len, -1).permute(0, 2, 1).reshape(x.shape[0], -1)
        return basises
@torch.no_grad()
def normalize_basis(self):
for basis in self.basises:
basis.data = basis.data / torch.norm(basis.data, dim=(2, 3), keepdim=True)
def get_coding(self, x):
if self.cfg.model.coeff_type != 'none' and self.cfg.model.basis_type != 'none':
coeff = self.get_coeff(x)
basises = self.get_basis(x)
return basises * coeff, coeff
elif self.cfg.model.coeff_type != 'none':
coeff = self.get_coeff(x)
return coeff, coeff
elif self.cfg.model.basis_type != 'none':
basises = self.get_basis(x)
return basises, basises
def n_parameters(self):
total = sum(p.numel() for p in self.parameters())
if 'fix' in self.cfg.model.basis_type:
total -= self.T_basis
return total
def get_optparam_groups(self, lr_small=0.001, lr_large=0.02):
grad_vars = []
if self.cfg.training.linear_mat:
grad_vars += [{'params': self.linear_mat.parameters(), 'lr': lr_small}]
if 'none' != self.coeff_type and self.cfg.training.coeff:
grad_vars += [{'params': self.coeffs.parameters(), 'lr': lr_large}]
if 'fix' not in self.cfg.model.basis_type and 'none' != self.cfg.model.basis_type and self.cfg.training.basis:
grad_vars += [{'params': self.basises.parameters(), 'lr': lr_large}]
if 'reconstruction' in self.cfg.defaults.mode and self.cfg.training.renderModule:
grad_vars += [{'params': self.renderModule.parameters(), 'lr': lr_small}]
return grad_vars
def set_optimizable(self, items, statue):
for item in items:
if item == 'basis' and self.cfg.model.basis_type != 'none':
for item in self.basises:
item.requires_grad = statue
elif item == 'coeff' and self.cfg.model.coeff_type != 'none':
for item in self.basises:
item.requires_grad = statue
elif item == 'proj':
self.linear_mat.requires_grad = statue
elif item == 'renderer':
self.renderModule.requires_grad = statue
def TV_loss(self, reg):
total = 0
for idx in range(len(self.basises)):
total = total + reg(self.basises[idx]) * 1e-2
return total
def sample_point_ndc(self, rays_o, rays_d, is_train=True, N_samples=-1):
N_samples = N_samples if N_samples > 0 else self.nSamples
near, far = self.cfg.dataset.near_far
interpx = torch.linspace(near, far, N_samples).unsqueeze(0).to(rays_o)
if is_train:
interpx += torch.rand_like(interpx).to(rays_o) * ((far - near) / N_samples)
rays_pts = rays_o[..., None, :] + rays_d[..., None, :] * interpx[..., None]
mask_outbbox = ((self.aabb[0, :self.in_dim] > rays_pts) | (rays_pts > self.aabb[1, :self.in_dim])).any(dim=-1)
return rays_pts, interpx, ~mask_outbbox
def sample_point(self, rays_o, rays_d, is_train=True, N_samples=-1):
N_samples = N_samples if N_samples > 0 else self.nSamples
vec = torch.where(rays_d == 0, torch.full_like(rays_d, 1e-6), rays_d)
rate_a = (self.aabb[1, :self.in_dim] - rays_o) / vec
rate_b = (self.aabb[0, :self.in_dim] - rays_o) / vec
t_min = torch.minimum(rate_a, rate_b).amax(-1).clamp(min=0.05, max=1e3)
rng = torch.arange(N_samples)[None].float()
if is_train:
rng = rng.repeat(rays_d.shape[-2], 1)
rng += torch.rand_like(rng[:, [0]])
step = self.stepSize * rng.to(rays_o.device)
interpx = (t_min[..., None] + step)
rays_pts = rays_o[..., None, :] + rays_d[..., None, :] * interpx[..., None]
mask_outbbox = ((self.aabb[0, :self.in_dim] > rays_pts) | (rays_pts > self.aabb[1, :self.in_dim])).any(dim=-1)
return rays_pts, interpx, ~mask_outbbox
def sample_point_unbound(self, rays_o, rays_d, is_train=True, N_samples=-1):
N_samples = N_samples if N_samples > 0 else self.nSamples
N_inner, N_outer = 3 * N_samples // 4, N_samples // 4
b_inner = torch.linspace(0, 2, N_inner + 1).to(self.device)
b_outer = 2 / torch.linspace(1, 1 / 16, N_outer + 1).to(self.device)
if is_train:
rng = torch.rand((N_inner + N_outer), device=self.device)
interpx = torch.cat([
b_inner[1:] * rng[:N_inner] + b_inner[:-1] * (1 - rng[:N_inner]),
b_outer[1:] * rng[N_inner:] + b_outer[:-1] * (1 - rng[N_inner:]),
])[None]
else:
interpx = torch.cat([
(b_inner[1:] + b_inner[:-1]) * 0.5,
(b_outer[1:] + b_outer[:-1]) * 0.5,
])[None]
rays_pts = rays_o[:, None, :] + rays_d[:, None, :] * interpx[..., None]
norm = rays_pts.abs().amax(dim=-1, keepdim=True)
inner_mask = (norm <= 1)
rays_pts = torch.where(
inner_mask,
rays_pts,
rays_pts / norm * ((1 + self.bg_len) - self.bg_len / norm)
)
return rays_pts, interpx, inner_mask.squeeze(-1)
def normalize_coord(self, xyz_sampled):
invaabbSize = 2.0 / (self.aabb[1] - self.aabb[0])
return (xyz_sampled - self.aabb[0]) * invaabbSize - 1
def basis2density(self, density_features):
if self.cfg.renderer.fea2denseAct == "softplus":
return F.softplus(density_features + self.cfg.renderer.density_shift)
elif self.cfg.renderer.fea2denseAct == "relu":
return F.relu(density_features + self.cfg.renderer.density_shift)
    @torch.no_grad()
    def cal_mean_coef(self, state_dict):
        """Collapse per-scene coefficient entries in `state_dict` into one average.

        Used before saving a batch-trained (multi-scene) model so the checkpoint
        stores a single shared coefficient set instead of `n_scene` copies.
        Mutates `state_dict` in place and returns it.
        """
        if 'grid' in self.coeff_type or 'mlp' in self.coeff_type:
            # collect all keys belonging to scene 0 (e.g. 'coeffs.0', 'coeffs.0.weight')
            key_list = []
            for item in state_dict.keys():
                if 'coeffs.0' in item:
                    key_list.append(item)
            for key in key_list:
                average = torch.zeros_like(state_dict[key])
                for i in range(self.n_scene):
                    # replace only the FIRST '0' so 'coeffs.0...' -> 'coeffs.<i>...'
                    item = key.replace('0', f'{i}', 1)
                    average += state_dict[item]
                    # drop the per-scene entry; scene 0's slot is re-filled below
                    state_dict.pop(item, None)
                average /= self.n_scene
                state_dict[key] = average
        elif 'vec' in self.coeff_type:
            # presumably scenes are stacked along the last dim — TODO confirm
            state_dict['coeffs.0'] = torch.mean(state_dict['coeffs.0'], dim=-1, keepdim=True)
        elif 'cp' in self.coeff_type or 'vm' in self.coeff_type:
            # one coefficient tensor per axis/plane factor
            for i in range(3):
                state_dict[f'coeffs.{i}'] = torch.mean(state_dict[f'coeffs.{i}'], dim=-1, keepdim=True)
        return state_dict
def save(self, path):
ckpt = {'state_dict': self.state_dict(), 'cfg': self.cfg}
if self.alphaMask is not None:
alpha_volume = self.alphaMask.alpha_volume.bool().cpu().numpy()
ckpt.update({'alphaMask.shape': alpha_volume.shape})
ckpt.update({'alphaMask.mask': np.packbits(alpha_volume.reshape(-1))})
ckpt.update({'alphaMask.aabb': self.alphaMask.aabb.cpu()})
# average the coeff for saving if batch training
if 'reconstruction' in self.cfg.defaults.mode:
ckpt['state_dict'] = self.cal_mean_coef(ckpt['state_dict'])
torch.save(ckpt, path)
def load(self, ckpt):
if 'alphaMask.aabb' in ckpt.keys():
length = np.prod(ckpt['alphaMask.shape'])
alpha_volume = torch.from_numpy(
np.unpackbits(ckpt['alphaMask.mask'])[:length].reshape(ckpt['alphaMask.shape']))
self.alphaMask = AlphaGridMask(self.device, ckpt['alphaMask.aabb'].to(self.device),
alpha_volume.float().to(self.device))
self.load_state_dict(ckpt['state_dict'])
volumeSize = N_to_reso(self.cfg.training.volume_resoFinal ** self.in_dim, self.aabb)
self.update_renderParams(volumeSize)
def update_renderParams(self, gridSize):
self.aabbSize = self.aabb[1] - self.aabb[0]
self.gridSize = torch.LongTensor(gridSize).to(self.device)
units = self.aabbSize / (self.gridSize - 1)
self.stepSize = torch.mean(units) * self.cfg.renderer.step_ratio
aabbDiag = torch.sqrt(torch.sum(torch.square(self.aabbSize)))
self.nSamples = int((aabbDiag / self.stepSize).item()) + 1
@torch.no_grad()
def upsample_volume_grid(self, res_target):
self.update_renderParams(res_target)
if self.cfg.dataset.dataset_name == 'google_objs' and self.n_scene == 1 and self.cfg.model.coeff_type == 'grid':
coeffs = [
F.interpolate(self.coeffs[0].data, size=None, scale_factor=1.3, align_corners=True, mode='trilinear')]
self.coeffs = torch.nn.ParameterList(coeffs)
def compute_alpha(self, xyz_locs, length=1):
if self.alphaMask is not None:
alphas = self.alphaMask.sample_alpha(xyz_locs)
alpha_mask = alphas > 0
else:
alpha_mask = torch.ones_like(xyz_locs[:, 0], dtype=bool)
sigma = torch.zeros(xyz_locs.shape[:-1], device=xyz_locs.device)
if alpha_mask.any():
feats, _ = self.get_coding(xyz_locs[alpha_mask])
validsigma = self.linear_mat(feats, is_train=False)[..., 0]
sigma[alpha_mask] = self.basis2density(validsigma)
alpha = 1 - torch.exp(-sigma * length).view(xyz_locs.shape[:-1])
return alpha
    @torch.no_grad()
    def getDenseAlpha(self, gridSize=None, times=16):
        """Evaluate opacity on a dense grid spanning `inward_aabb`.

        Args:
            gridSize: [X, Y, Z] resolution; defaults to self.gridSize.
            times: number of jittered evaluations averaged per voxel
                (times == 1 disables the jitter).

        Returns:
            (alpha, dense_xyz): averaged alpha and the un-jittered sample
            coordinates; both are transposed so dim 0 indexes Z slabs.
        """
        gridSize = self.gridSize.tolist() if gridSize is None else gridSize
        aabbSize = self.inward_aabb[1] - self.inward_aabb[0]
        units = aabbSize / (torch.LongTensor(gridSize).to(self.device) - 1)
        # half a voxel in normalized [0, 1] coordinates (samples at voxel centres)
        units_half = 1.0 / (torch.LongTensor(gridSize) - 1) * 0.5
        stepSize = torch.mean(units)
        samples = torch.stack(torch.meshgrid(
            [torch.linspace(units_half[0], 1 - units_half[0], gridSize[0]),
             torch.linspace(units_half[1], 1 - units_half[1], gridSize[1]),
             torch.linspace(units_half[2], 1 - units_half[2], gridSize[2])], indexing='ij'
        ), -1).to(self.device)
        # lerp the normalized sample grid into world coordinates
        dense_xyz = self.inward_aabb[0] * (1 - samples) + self.inward_aabb[1] * samples
        dense_xyz = dense_xyz.transpose(0, 2).contiguous()
        alpha = torch.zeros_like(dense_xyz[..., 0])
        # evaluate one Z slab at a time to bound peak memory
        for _ in range(times):
            for i in range(gridSize[2]):
                # random sub-voxel shift (slightly over-dilated by 1.2) per pass
                shiftment = (torch.rand(dense_xyz[i].shape) * 2 - 1).to(self.device) * (
                        units / 2 * 1.2) if times > 1 else 0.0
                alpha[i] += self.compute_alpha((dense_xyz[i] + shiftment).view(-1, 3),
                                               stepSize * self.cfg.renderer.distance_scale).view(
                    (gridSize[1], gridSize[0]))
        return alpha / times, dense_xyz
    @torch.no_grad()
    def updateAlphaMask(self, gridSize=(200, 200, 200), is_update_alphaMask=False):
        """Estimate the occupied region and return a tightened bounding box.

        When `is_update_alphaMask` is True, additionally installs a binary
        AlphaGridMask (used to skip empty space during sampling) and uses the
        configured threshold instead of the looser default 0.08.
        """
        alpha, dense_xyz = self.getDenseAlpha(gridSize)
        total_voxels = gridSize[0] * gridSize[1] * gridSize[2]  # NOTE(review): unused
        ks = 3
        alpha = alpha.clamp(0, 1)[None, None]
        # dilate occupancy with a 3^3 max-pool so thin structures survive thresholding
        alpha = F.max_pool3d(alpha, kernel_size=ks, padding=ks // 2, stride=1).view(gridSize[::-1])
        # filter floaters
        min_size = np.mean(alpha.shape[-3:]).item()
        alphaMask_thres = self.cfg.renderer.alphaMask_thres if is_update_alphaMask else 0.08
        if self.is_unbound:
            # unbounded scenes: fixed looser threshold, no connected-component filtering
            alphaMask_thres = 0.04
            alpha = (alpha >= alphaMask_thres).float()
        else:
            # drop small disconnected blobs ("floaters") before building the mask
            alpha = skimage.morphology.remove_small_objects(alpha.cpu().numpy() >= alphaMask_thres, min_size=min_size,
                                                            connectivity=1)
            alpha = torch.FloatTensor(alpha).to(self.device)
        if is_update_alphaMask:
            self.alphaMask = AlphaGridMask(self.device, self.inward_aabb, alpha)
        # tight box around all occupied voxels
        valid_xyz = dense_xyz[alpha > 0.5]
        xyz_min = valid_xyz.amin(0)
        xyz_max = valid_xyz.amax(0)
        if not self.is_unbound:
            # pad by 5% per side so surface geometry is not clipped
            pad = (xyz_max - xyz_min) / 20
            xyz_min -= pad
            xyz_max += pad
        new_aabb = torch.stack((xyz_min, xyz_max))
        total = torch.sum(alpha)  # NOTE(review): unused
        return new_aabb
@torch.no_grad()
def shrink(self, new_aabb):
self.setup_params(new_aabb.tolist())
if self.cfg.model.coeff_type != 'none':
del self.coeffs
self.coeffs = self.init_coef()
if self.cfg.model.basis_type != 'none':
del self.basises
self.basises = self.init_basis()
self.aabb = self.inward_aabb = new_aabb
self.cfg.dataset.aabb = self.aabb.tolist()
self.update_renderParams(self.gridSize.tolist())
@torch.no_grad()
def filtering_rays(self, all_rays, all_rgbs, N_samples=256, chunk=10240 * 5, bbox_only=False):
tt = time.time()
N = torch.tensor(all_rays.shape[:-1]).prod()
mask_filtered = []
length_current = 0
idx_chunks = torch.split(torch.arange(N), chunk)
for idx_chunk in idx_chunks:
rays_chunk = all_rays[idx_chunk].to(self.device)
rays_o, rays_d = rays_chunk[..., :3], rays_chunk[..., 3:6]
if bbox_only:
vec = torch.where(rays_d == 0, torch.full_like(rays_d, 1e-6), rays_d)
rate_a = (self.aabb[1] - rays_o) / vec
rate_b = (self.aabb[0] - rays_o) / vec
t_min = torch.minimum(rate_a, rate_b).amax(-1) # .clamp(min=near, max=far)
t_max = torch.maximum(rate_a, rate_b).amin(-1) # .clamp(min=near, max=far)
mask_inbbox = t_max > t_min
else:
xyz_sampled, _, _ = self.sample_point(rays_o, rays_d, N_samples=N_samples, is_train=False)
mask_inbbox = (self.alphaMask.sample_alpha(xyz_sampled).view(xyz_sampled.shape[:-1]) > 0).any(-1)
# mask_filtered.append(mask_inbbox.cpu())
length = torch.sum(mask_inbbox)
all_rays[length_current:length_current + length], all_rgbs[length_current:length_current + length] = \
rays_chunk[mask_inbbox].cpu(), all_rgbs[idx_chunk][mask_inbbox.cpu()]
length_current += length
return all_rays[:length_current], all_rgbs[:length_current]
def forward(self, rays_chunk, white_bg=True, is_train=False, ndc_ray=False, N_samples=-1):
# sample points
viewdirs = rays_chunk[:, 3:6]
if self.is_unbound:
xyz_sampled, z_vals, inner_mask = self.sample_point_unbound(rays_chunk[:, :3], viewdirs, is_train=is_train,
N_samples=N_samples)
dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], z_vals[:, -1:] - z_vals[:, -2:-1]), dim=-1)
elif ndc_ray:
xyz_sampled, z_vals, inner_mask = self.sample_point_ndc(rays_chunk[:, :3], viewdirs, is_train=is_train,
N_samples=N_samples)
dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], torch.zeros_like(z_vals[:, :1])), dim=-1)
rays_norm = torch.norm(viewdirs, dim=-1, keepdim=True)
dists = dists * rays_norm
viewdirs = viewdirs / rays_norm
else:
xyz_sampled, z_vals, inner_mask = self.sample_point(rays_chunk[:, :3], viewdirs, is_train=is_train,
N_samples=N_samples)
dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], torch.zeros_like(z_vals[:, :1])), dim=-1)
viewdirs = viewdirs.view(-1, 1, 3).expand(xyz_sampled.shape)
ray_valid = torch.ones_like(xyz_sampled[..., 0]).bool() if self.is_unbound else inner_mask
if self.alphaMask is not None:
alpha_inner_valid = self.alphaMask.sample_alpha(xyz_sampled[inner_mask]) > 0.5
ray_valid[inner_mask.clone()] = alpha_inner_valid
sigma = torch.zeros(xyz_sampled.shape[:-1], device=xyz_sampled.device)
rgb = torch.zeros((*xyz_sampled.shape[:2], 3), device=xyz_sampled.device)
coeffs = torch.zeros((1, sum(self.cfg.model.basis_dims)), device=xyz_sampled.device)
if ray_valid.any():
feats, coeffs = self.get_coding(xyz_sampled[ray_valid])
feat = self.linear_mat(feats, is_train=is_train)
sigma[ray_valid] = self.basis2density(feat[..., 0])
alpha, weight, bg_weight = raw2alpha(sigma, dists * self.cfg.renderer.di
gitextract_iica752a/ ├── .gitignore ├── 2D_regression.py ├── LICENSE ├── README.md ├── README_FactorField.md ├── configs/ │ ├── 360_v2.yaml │ ├── defaults.yaml │ ├── image.yaml │ ├── image_intro.yaml │ ├── image_set.yaml │ ├── nerf.yaml │ ├── nerf_ft.yaml │ ├── nerf_set.yaml │ ├── sdf.yaml │ └── tnt.yaml ├── dataLoader/ │ ├── __init__.py │ ├── blender.py │ ├── blender_set.py │ ├── colmap.py │ ├── colmap2nerf.py │ ├── dtu_objs.py │ ├── dtu_objs2.py │ ├── google_objs.py │ ├── image.py │ ├── image_set.py │ ├── llff.py │ ├── nsvf.py │ ├── ray_utils.py │ ├── sdf.py │ ├── tankstemple.py │ └── your_own_data.py ├── models/ │ ├── FactorFields.py │ ├── __init__.py │ └── sh.py ├── renderer.py ├── requirements.txt ├── run_batch.py ├── scripts/ │ ├── 2D_regression.ipynb │ ├── 2D_set_regression.ipynb │ ├── 2D_set_regression.py │ ├── __init__.py │ ├── formula_demostration.ipynb │ ├── mesh2SDF_data_process.ipynb │ └── sdf_regression.ipynb ├── train_across_scene.py ├── train_across_scene_ft.py ├── train_per_scene.py └── utils.py
SYMBOL INDEX (257 symbols across 25 files)
FILE: 2D_regression.py
function PSNR (line 17) | def PSNR(a, b):
function rgb_ssim (line 26) | def rgb_ssim(img0, img1, max_val,
function eval_img (line 76) | def eval_img(aabb, reso, shiftment=[0.5, 0.5], chunk=10240):
function linear_to_srgb (line 95) | def linear_to_srgb(img):
function write_image_imageio (line 100) | def write_image_imageio(img_file, img, colormap=None, quality=100):
FILE: dataLoader/blender.py
class BlenderDataset (line 12) | class BlenderDataset(Dataset):
method __init__ (line 13) | def __init__(self, cfg, split='train', batch_size=4096, is_stack=None):
method read_depth (line 40) | def read_depth(self, filename):
method read_meta (line 44) | def read_meta(self):
method define_transforms (line 105) | def define_transforms(self):
method define_proj_mat (line 108) | def define_proj_mat(self):
method __len__ (line 115) | def __len__(self):
method __getitem__ (line 118) | def __getitem__(self, idx):
FILE: dataLoader/blender_set.py
class BlenderDatasetSet (line 12) | class BlenderDatasetSet(Dataset):
method __init__ (line 13) | def __init__(self, cfg, split='train'):
method read_depth (line 39) | def read_depth(self, filename):
method read_meta (line 43) | def read_meta(self):
method define_transforms (line 103) | def define_transforms(self):
method define_proj_mat (line 106) | def define_proj_mat(self):
method __len__ (line 113) | def __len__(self):
method __getitem__ (line 116) | def __getitem__(self, idx):
FILE: dataLoader/colmap.py
class ColmapDataset (line 12) | class ColmapDataset(Dataset):
method __init__ (line 13) | def __init__(self, cfg, split='train'):
method read_meta (line 31) | def read_meta(self):
method define_transforms (line 107) | def define_transforms(self):
method __len__ (line 111) | def __len__(self):
method __getitem__ (line 114) | def __getitem__(self, idx):
FILE: dataLoader/colmap2nerf.py
function parse_args (line 23) | def parse_args():
function do_system (line 44) | def do_system(arg):
function run_ffmpeg (line 51) | def run_ffmpeg(args):
function run_colmap (line 73) | def run_colmap(args):
function variance_of_laplacian (line 107) | def variance_of_laplacian(image):
function sharpness (line 110) | def sharpness(imagePath):
function qvec2rotmat (line 116) | def qvec2rotmat(qvec):
function rotmat (line 133) | def rotmat(a, b):
function closest_point_2_lines (line 141) | def closest_point_2_lines(oa, da, ob, db): # returns point closest to bo...
function normalize (line 156) | def normalize(x):
function rotation_matrix_from_vectors (line 159) | def rotation_matrix_from_vectors(vec1, vec2):
function rotation_up (line 176) | def rotation_up(poses):
function search_orientation (line 181) | def search_orientation(points):
function load_point_txt (line 196) | def load_point_txt(path):
FILE: dataLoader/dtu_objs.py
function load_K_Rt_from_P (line 11) | def load_K_Rt_from_P(filename, P=None):
function fps_downsample (line 34) | def fps_downsample(points, n_points_to_sample):
class DTUDataset (line 47) | class DTUDataset(Dataset):
method __init__ (line 48) | def __init__(self, cfg, split='train', batch_size=4096, is_stack=None):
method read_meta (line 78) | def read_meta(self):
method load_data (line 122) | def load_data(self, scene_idx, img_idx=None):
method get_bbox (line 196) | def get_bbox(self, scale_mats_np, object_scale_mat):
method gen_rays_at (line 206) | def gen_rays_at(self, intrinsic, c2w, resolution_level=1):
method __len__ (line 224) | def __len__(self):
method __getitem__ (line 227) | def __getitem__(self, idx):
FILE: dataLoader/dtu_objs2.py
function load_K_Rt_from_P (line 12) | def load_K_Rt_from_P(filename, P=None):
class DTUDataset (line 35) | class DTUDataset(Dataset):
method __init__ (line 36) | def __init__(self, cfg, split='train', batch_size=4096, is_stack=None):
method get_bbox (line 57) | def get_bbox(self):
method gen_rays_at (line 68) | def gen_rays_at(self, intrinsic, c2w, resolution_level=1):
method read_meta (line 85) | def read_meta(self):
method __len__ (line 131) | def __len__(self):
method __getitem__ (line 134) | def __getitem__(self, idx):
FILE: dataLoader/google_objs.py
function fps_downsample (line 13) | def fps_downsample(points, n_points_to_sample):
function pose_spherical_nerf (line 29) | def pose_spherical_nerf(euler, radius=1.8, ep=1):
function nerf_video_path (line 36) | def nerf_video_path(c2ws, theta_range=10,phi_range=20,N_views=120,radius...
function _interpolate_trajectory (line 56) | def _interpolate_trajectory(c2ws, num_views: int = 300):
function google_objs_path (line 77) | def google_objs_path(c2ws, N_views=150):
class GoogleObjsDataset (line 85) | class GoogleObjsDataset(Dataset):
method __init__ (line 86) | def __init__(self, cfg, split="train", batch_size=4096):
method read_depth (line 174) | def read_depth(self, filename):
method read_meta (line 178) | def read_meta(self):
method get_rays (line 301) | def get_rays(self, idx):
method get_rgbs (line 306) | def get_rgbs(self, idx):
method define_transforms (line 312) | def define_transforms(self):
method define_proj_mat (line 315) | def define_proj_mat(self):
method world2ndc (line 320) | def world2ndc(self, points, lindisp=None):
method update_index (line 324) | def update_index(self):
method __len__ (line 327) | def __len__(self):
method __getitem__ (line 330) | def __getitem__(self, idx):
FILE: dataLoader/image.py
function load (line 10) | def load(path):
function srgb_to_linear (line 22) | def srgb_to_linear(img):
class ImageDataset (line 26) | class ImageDataset(Dataset):
method __init__ (line 27) | def __init__(self, cfg, batchsize, split='train', continue_sampling=Fa...
method __len__ (line 84) | def __len__(self):
method __getitem__ (line 87) | def __getitem__(self, idx):
FILE: dataLoader/image_set.py
function srgb_to_linear (line 6) | def srgb_to_linear(img):
function load (line 10) | def load(path, HW=512):
class ImageSetDataset (line 24) | class ImageSetDataset(Dataset):
method __init__ (line 25) | def __init__(self, cfg, batchsize, split='train', continue_sampling=Fa...
method __len__ (line 50) | def __len__(self):
method __getitem__ (line 53) | def __getitem__(self, idx):
FILE: dataLoader/llff.py
function normalize (line 12) | def normalize(v):
function average_poses (line 17) | def average_poses(poses):
function center_poses (line 54) | def center_poses(poses, blender2opencv):
function viewmatrix (line 81) | def viewmatrix(z, up, pos):
function render_path_spiral (line 91) | def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, N_rots=2, N=...
function get_spiral (line 102) | def get_spiral(c2ws_all, near_fars, rads_scale=1.0, N_views=120):
class LLFFDataset (line 122) | class LLFFDataset(Dataset):
method __init__ (line 123) | def __init__(self, cfg , split='train', hold_every=8):
method read_meta (line 148) | def read_meta(self):
method define_transforms (line 231) | def define_transforms(self):
method __len__ (line 234) | def __len__(self):
method __getitem__ (line 237) | def __getitem__(self, idx):
FILE: dataLoader/nsvf.py
function pose_spherical (line 29) | def pose_spherical(theta, phi, radius):
class NSVF (line 36) | class NSVF(Dataset):
method __init__ (line 38) | def __init__(self, datadir, split='train', downsample=1.0, wh=[800,800...
method bbox2corners (line 56) | def bbox2corners(self):
method read_meta (line 63) | def read_meta(self):
method define_transforms (line 132) | def define_transforms(self):
method define_proj_mat (line 135) | def define_proj_mat(self):
method world2ndc (line 138) | def world2ndc(self, points):
method __len__ (line 142) | def __len__(self):
method __getitem__ (line 147) | def __getitem__(self, idx):
FILE: dataLoader/ray_utils.py
function load_json (line 7) | def load_json(path):
function depth2dist (line 12) | def depth2dist(z_vals, cos_angle):
function ndc2dist (line 21) | def ndc2dist(ndc_pts, cos_angle):
function get_ray_directions (line 27) | def get_ray_directions(H, W, focal, center=None):
function get_ray_directions_blender (line 48) | def get_ray_directions_blender(H, W, focal, center=None):
function get_rays (line 69) | def get_rays(directions, c2w):
function ndc_rays_blender (line 93) | def ndc_rays_blender(H, W, focal, near, rays_o, rays_d):
function ndc_rays (line 112) | def ndc_rays(H, W, focal, near, rays_o, rays_d):
function sample_pdf (line 132) | def sample_pdf(bins, weights, N_samples, det=False, pytest=False):
function dda (line 177) | def dda(rays_o, rays_d, bbox_3D):
function ray_marcher (line 187) | def ray_marcher(rays,
function read_pfm (line 234) | def read_pfm(filename):
class SimpleSampler (line 271) | class SimpleSampler:
method __init__ (line 272) | def __init__(self, total, batch):
method nextids (line 278) | def nextids(self):
function ndc_bbox (line 285) | def ndc_bbox(all_rays):
function pose_from_json (line 293) | def pose_from_json(meta, transpose):
function rotation_matrix_from_vectors (line 300) | def rotation_matrix_from_vectors(vec1, vec2):
function normalize (line 317) | def normalize(x):
function rotation_up (line 320) | def rotation_up(poses):
function search_orientation (line 325) | def search_orientation(points):
function load_point_txt (line 341) | def load_point_txt(path):
function orientation (line 365) | def orientation(poses, point_path=None):
function spherify_poses (line 397) | def spherify_poses(poses, radus=1):
FILE: dataLoader/sdf.py
function N_to_reso (line 5) | def N_to_reso(avg_reso, bbox):
function load (line 12) | def load(path, split, dtype='points'):
class SDFDataset (line 31) | class SDFDataset(Dataset):
method __init__ (line 32) | def __init__(self, cfg, split='train'):
method __len__ (line 42) | def __len__(self):
method __getitem__ (line 45) | def __getitem__(self, idx):
FILE: dataLoader/tankstemple.py
function circle (line 11) | def circle(radius=3.5, h=0.0, axis='z', t0=0, r=1):
function cross (line 20) | def cross(x, y, axis=0):
function normalize (line 25) | def normalize(x, axis=-1, order=2):
function cat (line 37) | def cat(x, axis=1):
function look_at_rotation (line 43) | def look_at_rotation(camera_position, at=None, up=None, inverse=False, c...
function gen_path (line 76) | def gen_path(pos_gen, at=(0, 0, 0), up=(0, -1, 0), frames=180):
class TanksTempleDataset (line 87) | class TanksTempleDataset(Dataset):
method __init__ (line 90) | def __init__(self, cfg, split='train'):
method bbox2corners (line 109) | def bbox2corners(self):
method read_meta (line 115) | def read_meta(self):
method define_transforms (line 192) | def define_transforms(self):
method define_proj_mat (line 195) | def define_proj_mat(self):
method world2ndc (line 198) | def world2ndc(self, points):
method __len__ (line 202) | def __len__(self):
method __getitem__ (line 207) | def __getitem__(self, idx):
FILE: dataLoader/your_own_data.py
class YourOwnDataset (line 13) | class YourOwnDataset(Dataset):
method __init__ (line 14) | def __init__(self, datadir, split='train', downsample=1.0, is_stack=Fa...
method read_depth (line 35) | def read_depth(self, filename):
method read_meta (line 39) | def read_meta(self):
method define_transforms (line 102) | def define_transforms(self):
method define_proj_mat (line 105) | def define_proj_mat(self):
method world2ndc (line 108) | def world2ndc(self,points,lindisp=None):
method __len__ (line 112) | def __len__(self):
method __getitem__ (line 115) | def __getitem__(self, idx):
FILE: models/FactorFields.py
function grid_mapping (line 11) | def grid_mapping(positions, freq_bands, aabb, basis_mapping='sawtooth'):
function dct_dict (line 36) | def dct_dict(n_atoms_fre, size, n_selete, dim=2):
function positional_encoding (line 74) | def positional_encoding(positions, freqs):
function raw2alpha (line 82) | def raw2alpha(sigma, dist):
class AlphaGridMask (line 91) | class AlphaGridMask(torch.nn.Module):
method __init__ (line 92) | def __init__(self, device, aabb, alpha_volume):
method sample_alpha (line 103) | def sample_alpha(self, xyz_sampled):
method normalize_coord (line 109) | def normalize_coord(self, xyz_sampled):
class MLPMixer (line 113) | class MLPMixer(torch.nn.Module):
method __init__ (line 114) | def __init__(self,
method forward (line 144) | def forward(self, x, is_train=False):
class MLPRender_Fea (line 162) | class MLPRender_Fea(torch.nn.Module):
method __init__ (line 163) | def __init__(self, inChanel, num_layers=3, hidden_dim=64, viewpe=6, fe...
method forward (line 188) | def forward(self, viewdirs, features):
class FactorFields (line 206) | class FactorFields(torch.nn.Module):
method __init__ (line 207) | def __init__(self, cfg, device):
method setup_params (line 262) | def setup_params(self, aabb):
method init_coef (line 311) | def init_coef(self):
method init_basis (line 334) | def init_basis(self):
method get_coeff (line 425) | def get_coeff(self, xyz_sampled):
method get_basis (line 467) | def get_basis(self, x):
method normalize_basis (line 519) | def normalize_basis(self):
method get_coding (line 523) | def get_coding(self, x):
method n_parameters (line 535) | def n_parameters(self):
method get_optparam_groups (line 541) | def get_optparam_groups(self, lr_small=0.001, lr_large=0.02):
method set_optimizable (line 556) | def set_optimizable(self, items, statue):
method TV_loss (line 569) | def TV_loss(self, reg):
method sample_point_ndc (line 575) | def sample_point_ndc(self, rays_o, rays_d, is_train=True, N_samples=-1):
method sample_point (line 586) | def sample_point(self, rays_o, rays_d, is_train=True, N_samples=-1):
method sample_point_unbound (line 604) | def sample_point_unbound(self, rays_o, rays_d, is_train=True, N_sample...
method normalize_coord (line 635) | def normalize_coord(self, xyz_sampled):
method basis2density (line 639) | def basis2density(self, density_features):
method cal_mean_coef (line 646) | def cal_mean_coef(self, state_dict):
method save (line 669) | def save(self, path):
method load (line 682) | def load(self, ckpt):
method update_renderParams (line 693) | def update_renderParams(self, gridSize):
method upsample_volume_grid (line 702) | def upsample_volume_grid(self, res_target):
method compute_alpha (line 710) | def compute_alpha(self, xyz_locs, length=1):
method getDenseAlpha (line 730) | def getDenseAlpha(self, gridSize=None, times=16):
method updateAlphaMask (line 758) | def updateAlphaMask(self, gridSize=(200, 200, 200), is_update_alphaMas...
method shrink (line 796) | def shrink(self, new_aabb):
method filtering_rays (line 812) | def filtering_rays(self, all_rays, all_rgbs, N_samples=256, chunk=1024...
method forward (line 843) | def forward(self, rays_chunk, white_bg=True, is_train=False, ndc_ray=F...
FILE: models/sh.py
function eval_sh (line 34) | def eval_sh(deg, sh, dirs):
function eval_sh_bases (line 87) | def eval_sh_bases(deg, dirs):
FILE: renderer.py
function render_ray (line 8) | def render_ray(rays, factor_fields, chunk=4096, N_samples=-1, ndc_ray=Fa...
function evaluation (line 30) | def evaluation(test_dataset,factor_fields, renderer, savePath=None, N_vi...
function evaluation_path (line 101) | def evaluation_path(test_dataset,factor_fields, c2ws, renderer, savePath...
FILE: run_batch.py
function run_program (line 123) | def run_program(gpu, cmd):
FILE: scripts/2D_set_regression.py
function PSNR (line 19) | def PSNR(a, b):
function eval_img (line 29) | def eval_img(aabb, reso, idx, shiftment=[0.5, 0.5, 0.5], chunk=10240):
function eval_img_single (line 49) | def eval_img_single(aabb, reso, chunk=10240):
function linear_to_srgb (line 67) | def linear_to_srgb(img):
function srgb_to_linear (line 72) | def srgb_to_linear(img):
function write_image_imageio (line 77) | def write_image_imageio(img_file, img, colormap=None, quality=100):
function interpolate (line 97) | def interpolate(colormap, x):
FILE: train_across_scene.py
class SimpleSampler (line 17) | class SimpleSampler:
method __init__ (line 18) | def __init__(self, total, batch):
method nextids (line 24) | def nextids(self):
function export_mesh (line 33) | def export_mesh(cfg):
function render_test (line 55) | def render_test(cfg):
function reconstruction (line 99) | def reconstruction(cfg):
FILE: train_across_scene_ft.py
class SimpleSampler (line 17) | class SimpleSampler:
method __init__ (line 18) | def __init__(self, total, batch):
method nextids (line 24) | def nextids(self):
function export_mesh (line 33) | def export_mesh(cfg):
function render_test (line 42) | def render_test(cfg):
function reconstruction (line 88) | def reconstruction(cfg):
FILE: train_per_scene.py
class SimpleSampler (line 16) | class SimpleSampler:
method __init__ (line 17) | def __init__(self, total, batch):
method nextids (line 23) | def nextids(self):
function export_mesh (line 32) | def export_mesh(ckpt_path):
function render_test (line 44) | def render_test(cfg):
function reconstruction (line 89) | def reconstruction(cfg):
FILE: utils.py
function visualize_depth_numpy (line 11) | def visualize_depth_numpy(depth, minmax=None, cmap=cv2.COLORMAP_JET):
function init_log (line 28) | def init_log(log, keys):
function visualize_depth (line 33) | def visualize_depth(depth, minmax=None, cmap=cv2.COLORMAP_JET):
function N_to_reso (line 53) | def N_to_reso(n_voxels, bbox):
function N_to_vm_reso (line 59) | def N_to_vm_reso(n_voxels, bbox):
function cal_n_samples (line 69) | def cal_n_samples(reso, step_ratio=0.5):
class SimpleSampler (line 72) | class SimpleSampler:
method __init__ (line 73) | def __init__(self, total, batch):
method nextids (line 79) | def nextids(self):
function init_lpips (line 87) | def init_lpips(net_name, device):
function rgb_lpips (line 93) | def rgb_lpips(np_gt, np_im, net_name, device):
function findItem (line 101) | def findItem(items, target):
function rgb_ssim (line 110) | def rgb_ssim(img0, img1, max_val,
class TVLoss (line 160) | class TVLoss(nn.Module):
method __init__ (line 161) | def __init__(self,TVLoss_weight=1):
method forward (line 165) | def forward(self,x):
method _tensor_size (line 175) | def _tensor_size(self,t):
function marchcude_to_world (line 178) | def marchcude_to_world(vertices, reso_WHD):
function convert_sdf_samples_to_ply (line 183) | def convert_sdf_samples_to_ply(
Condensed preview — 48 files, each showing its path, character count, and a content snippet. Download the .json file, or copy it to your clipboard, to get the full structured content (1,773K chars).
[
{
"path": ".gitignore",
"chars": 3099,
"preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packagi"
},
{
"path": "2D_regression.py",
"chars": 9265,
"preview": "import torch,imageio,sys,time,os,cmapy,scipy\nimport numpy as np\nfrom tqdm import tqdm\nimport matplotlib.pyplot as plt\nfr"
},
{
"path": "LICENSE",
"chars": 1073,
"preview": "MIT License\n\nCopyright (c) 2023 autonomousvision\n\nPermission is hereby granted, free of charge, to any person obtaining "
},
{
"path": "README.md",
"chars": 6031,
"preview": "# Factor Fields\n## [Project page](https://apchenstu.github.io/FactorFields/) | [Paper](https://arxiv.org/abs/2302.01226"
},
{
"path": "README_FactorField.md",
"chars": 4832,
"preview": "## nerf reconstruction with Dictionary field\n\n```python\n\tfor scene in ['ship', 'mic', 'chair', 'lego', 'drums', 'ficus',"
},
{
"path": "configs/360_v2.yaml",
"chars": 1568,
"preview": "\ndefaults:\n expname: basis_room_real_mask\n logdir: ./logs\n\n ckpt: null # help='specific weights npy "
},
{
"path": "configs/defaults.yaml",
"chars": 1342,
"preview": "\ndefaults:\n expname: basis_lego\n logedir: ./logs\n\n mode: 'reconstruction'\n\n progress_refresh_rate: 10\n\n add_timest"
},
{
"path": "configs/image.yaml",
"chars": 1129,
"preview": "\ndefaults:\n expname: basis_image\n logdir: ./logs\n\n mode: 'image'\n\n ckpt: null # help='specific weig"
},
{
"path": "configs/image_intro.yaml",
"chars": 767,
"preview": "\ndefaults:\n expname: basis_image\n logdir: ./logs\n\n mode: 'demo'\n\n ckpt: null # help='specific weigh"
},
{
"path": "configs/image_set.yaml",
"chars": 733,
"preview": "\ndefaults:\n expname: basis_image\n logdir: ./logs\n\n mode: 'images'\n\n ckpt: null # help='specific wei"
},
{
"path": "configs/nerf.yaml",
"chars": 1613,
"preview": "defaults:\n expname: basis_lego\n logdir: ./logs\n\n mode: 'reconstruction'\n\n ckpt: null # help='specif"
},
{
"path": "configs/nerf_ft.yaml",
"chars": 1795,
"preview": "defaults:\n expname: basis\n logdir: ./logs\n\n mode: 'reconstructions'\n\n ckpt: null # help='specific w"
},
{
"path": "configs/nerf_set.yaml",
"chars": 1831,
"preview": "defaults:\n expname: basis_no_relu_lego\n logdir: ./logs\n\n mode: 'reconstructions'\n\n ckpt: null # hel"
},
{
"path": "configs/sdf.yaml",
"chars": 745,
"preview": "defaults:\n expname: basis_sdf\n logdir: ./logs\n\n mode: 'sdf'\n\n ckpt: null # help='specific weights n"
},
{
"path": "configs/tnt.yaml",
"chars": 2138,
"preview": "\ndefaults:\n expname: basis_truck\n logdir: ./logs\n\n ckpt: null # help='specific weights npy file to r"
},
{
"path": "dataLoader/__init__.py",
"chars": 913,
"preview": "from .llff import LLFFDataset\nfrom .blender import BlenderDataset\nfrom .nsvf import NSVF\nfrom .tankstemple import TanksT"
},
{
"path": "dataLoader/blender.py",
"chars": 5265,
"preview": "import torch, cv2\nfrom torch.utils.data import Dataset\nimport json\nfrom tqdm import tqdm\nimport os\nfrom PIL import Image"
},
{
"path": "dataLoader/blender_set.py",
"chars": 5025,
"preview": "import torch, cv2\nfrom torch.utils.data import Dataset\nimport json\nfrom tqdm import tqdm\nimport os\nfrom PIL import Image"
},
{
"path": "dataLoader/colmap.py",
"chars": 4958,
"preview": "import torch, cv2\nfrom torch.utils.data import Dataset\n\nfrom tqdm import tqdm\nimport os\nfrom PIL import Image\nfrom torch"
},
{
"path": "dataLoader/colmap2nerf.py",
"chars": 13415,
"preview": "#!/usr/bin/env python3\n\n# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.\n#\n# NVIDIA CORPORATION and "
},
{
"path": "dataLoader/dtu_objs.py",
"chars": 10686,
"preview": "\nimport torch\nimport cv2 as cv\nimport numpy as np\nimport os\nfrom glob import glob\nfrom .ray_utils import *\nfrom torch.ut"
},
{
"path": "dataLoader/dtu_objs2.py",
"chars": 5829,
"preview": "\nimport torch\nimport cv2 as cv\nimport numpy as np\nimport os\nfrom glob import glob\nfrom .ray_utils import *\nfrom torch.ut"
},
{
"path": "dataLoader/google_objs.py",
"chars": 16048,
"preview": "import torch, cv2\nfrom torch.utils.data import Dataset\nimport json\nfrom tqdm import tqdm\nimport os\nfrom PIL import Image"
},
{
"path": "dataLoader/image.py",
"chars": 3881,
"preview": "import torch,imageio,cv2\nfrom PIL import Image \nImage.MAX_IMAGE_PIXELS = 1000000000 \nimport numpy as np\nimport torch.nn."
},
{
"path": "dataLoader/image_set.py",
"chars": 2628,
"preview": "import torch,cv2\nimport numpy as np\nimport torch.nn.functional as F\nfrom torch.utils.data import Dataset\n\ndef srgb_to_li"
},
{
"path": "dataLoader/llff.py",
"chars": 9522,
"preview": "import torch\nfrom torch.utils.data import Dataset\nimport glob\nimport numpy as np\nimport os\nfrom PIL import Image\nfrom to"
},
{
"path": "dataLoader/nsvf.py",
"chars": 6584,
"preview": "import torch\nfrom torch.utils.data import Dataset\nfrom tqdm import tqdm\nimport os\nfrom PIL import Image\nfrom torchvision"
},
{
"path": "dataLoader/ray_utils.py",
"chars": 16205,
"preview": "import torch, re, json\nimport numpy as np\nfrom torch import searchsorted\nfrom kornia import create_meshgrid\n\n\ndef load_j"
},
{
"path": "dataLoader/sdf.py",
"chars": 1684,
"preview": "import torch\nimport numpy as np\nfrom torch.utils.data import Dataset\n\ndef N_to_reso(avg_reso, bbox):\n xyz_min, xyz_ma"
},
{
"path": "dataLoader/tankstemple.py",
"chars": 8866,
"preview": "import torch\nfrom torch.utils.data import Dataset\nfrom tqdm import tqdm\nimport os\nfrom PIL import Image\nfrom torchvision"
},
{
"path": "dataLoader/your_own_data.py",
"chars": 5026,
"preview": "import torch,cv2\nfrom torch.utils.data import Dataset\nimport json\nfrom tqdm import tqdm\nimport os\nfrom PIL import Image\n"
},
{
"path": "models/FactorFields.py",
"chars": 42185,
"preview": "import torch, math\nimport torch.nn\nimport torch.nn.functional as F\nimport numpy as np\nimport time, skimage\nfrom utils im"
},
{
"path": "models/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "models/sh.py",
"chars": 5231,
"preview": "import torch\n\n################## sh function ##################\nC0 = 0.28209479177387814\nC1 = 0.4886025119029199\nC2 = [\n"
},
{
"path": "renderer.py",
"chars": 6556,
"preview": "import torch,os,imageio,sys\nfrom tqdm.auto import tqdm\nfrom dataLoader.ray_utils import get_rays\nfrom utils import *\nfro"
},
{
"path": "requirements.txt",
"chars": 170,
"preview": "imageio==2.26.0\nkornia==0.6.10\nlpips==0.1.4\nmatplotlib==3.7.1\nomegaconf==2.3.0\nopencv_python==4.7.0.72\nplyfile==0.7.4\nsc"
},
{
"path": "run_batch.py",
"chars": 8462,
"preview": "\nimport os\nimport threading, queue\nimport numpy as np\nimport time\n\n\nif __name__ == '__main__':\n\n\t################ per s"
},
{
"path": "scripts/2D_regression.ipynb",
"chars": 690058,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"code\",\n \"execution_count\": 1,\n \"id\": \"f969c229-5a8a-44b6-91a3-bba55968b202\",\n \""
},
{
"path": "scripts/2D_set_regression.ipynb",
"chars": 744891,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"code\",\n \"execution_count\": 1,\n \"id\": \"f969c229-5a8a-44b6-91a3-bba55968b202\",\n \""
},
{
"path": "scripts/2D_set_regression.py",
"chars": 5635,
"preview": "import torch,imageio,sys,cmapy,time,os\nimport numpy as np\nfrom tqdm import tqdm\n# from .autonotebook import tqdm as tqdm"
},
{
"path": "scripts/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "scripts/formula_demostration.ipynb",
"chars": 33703,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"code\",\n \"execution_count\": 1,\n \"id\": \"f969c229-5a8a-44b6-91a3-bba55968b202\",\n \""
},
{
"path": "scripts/mesh2SDF_data_process.ipynb",
"chars": 3598,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"code\",\n \"execution_count\": 1,\n \"id\": \"47120857-b2aa-4a36-8733-e04776ca7a80\",\n \""
},
{
"path": "scripts/sdf_regression.ipynb",
"chars": 15828,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"code\",\n \"execution_count\": 1,\n \"id\": \"b4ee17ad-e38b-4d49-a8d6-0ec910d9e528\",\n \""
},
{
"path": "train_across_scene.py",
"chars": 10562,
"preview": "from tqdm.auto import tqdm\nfrom omegaconf import OmegaConf\nfrom models.FactorFields import FactorFields\n\nimport json, ra"
},
{
"path": "train_across_scene_ft.py",
"chars": 11993,
"preview": "from tqdm.auto import tqdm\nfrom omegaconf import OmegaConf\nfrom models.FactorFields import FactorFields\n\nimport json, ra"
},
{
"path": "train_per_scene.py",
"chars": 10729,
"preview": "from tqdm.auto import tqdm\nfrom omegaconf import OmegaConf\nfrom models.FactorFields import FactorFields\n\nimport json, ra"
},
{
"path": "utils.py",
"chars": 7990,
"preview": "import cv2,torch,math\nimport numpy as np\nfrom PIL import Image\nimport torchvision.transforms as T\nimport torch.nn.functi"
}
]
About this extraction
This page contains the full source code of the autonomousvision/factor-fields GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 48 files (1.7 MB), approximately 1.1M tokens, and a symbol index with 257 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.