Repository: autonomousvision/factor-fields
Branch: main
Commit: 2a43edea0d17
Files: 48
Total size: 1.7 MB

Directory structure:
gitextract_iica752a/

├── .gitignore
├── 2D_regression.py
├── LICENSE
├── README.md
├── README_FactorField.md
├── configs/
│   ├── 360_v2.yaml
│   ├── defaults.yaml
│   ├── image.yaml
│   ├── image_intro.yaml
│   ├── image_set.yaml
│   ├── nerf.yaml
│   ├── nerf_ft.yaml
│   ├── nerf_set.yaml
│   ├── sdf.yaml
│   └── tnt.yaml
├── dataLoader/
│   ├── __init__.py
│   ├── blender.py
│   ├── blender_set.py
│   ├── colmap.py
│   ├── colmap2nerf.py
│   ├── dtu_objs.py
│   ├── dtu_objs2.py
│   ├── google_objs.py
│   ├── image.py
│   ├── image_set.py
│   ├── llff.py
│   ├── nsvf.py
│   ├── ray_utils.py
│   ├── sdf.py
│   ├── tankstemple.py
│   └── your_own_data.py
├── models/
│   ├── FactorFields.py
│   ├── __init__.py
│   └── sh.py
├── renderer.py
├── requirements.txt
├── run_batch.py
├── scripts/
│   ├── 2D_regression.ipynb
│   ├── 2D_set_regression.ipynb
│   ├── 2D_set_regression.py
│   ├── __init__.py
│   ├── formula_demostration.ipynb
│   ├── mesh2SDF_data_process.ipynb
│   └── sdf_regression.ipynb
├── train_across_scene.py
├── train_across_scene_ft.py
├── train_per_scene.py
└── utils.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# poetry
#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
#   This is especially recommended for binary packages to ensure reproducibility, and is more
#   commonly ignored for libraries.
#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
#   in version control.
#   https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
#  and can be added to the global gitignore or merged into this file.  For a more nuclear
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

data/
slurm/
logs/



================================================
FILE: 2D_regression.py
================================================
import torch,imageio,sys,time,os,cmapy
import scipy.signal
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from omegaconf import OmegaConf
import torch.nn.functional as F

device = 'cuda'

sys.path.append('..')
from models.sparseCoding import sparseCoding

from dataLoader import dataset_dict
from torch.utils.data import DataLoader


def PSNR(a, b):
    if type(a).__module__ == np.__name__:
        mse = np.mean((a - b) ** 2)
    else:
        mse = torch.mean((a - b) ** 2).item()
    psnr = -10.0 * np.log(mse) / np.log(10.0)
    return psnr


def rgb_ssim(img0, img1, max_val,
             filter_size=11,
             filter_sigma=1.5,
             k1=0.01,
             k2=0.03,
             return_map=False):
    # Modified from https://github.com/google/mipnerf/blob/16e73dfdb52044dcceb47cda5243a686391a6e0f/internal/math.py#L58
    assert len(img0.shape) == 3
    assert img0.shape[-1] == 3
    assert img0.shape == img1.shape

    # Construct a 1D Gaussian blur filter.
    hw = filter_size // 2
    shift = (2 * hw - filter_size + 1) / 2
    f_i = ((np.arange(filter_size) - hw + shift) / filter_sigma) ** 2
    filt = np.exp(-0.5 * f_i)
    filt /= np.sum(filt)

    # Blur in x and y (faster than the 2D convolution).
    def convolve2d(z, f):
        return scipy.signal.convolve2d(z, f, mode='valid')

    filt_fn = lambda z: np.stack([
        convolve2d(convolve2d(z[..., i], filt[:, None]), filt[None, :])
        for i in range(z.shape[-1])], -1)
    mu0 = filt_fn(img0)
    mu1 = filt_fn(img1)
    mu00 = mu0 * mu0
    mu11 = mu1 * mu1
    mu01 = mu0 * mu1
    sigma00 = filt_fn(img0 ** 2) - mu00
    sigma11 = filt_fn(img1 ** 2) - mu11
    sigma01 = filt_fn(img0 * img1) - mu01

    # Clip the variances and covariances to valid values.
    # Variance must be non-negative:
    sigma00 = np.maximum(0., sigma00)
    sigma11 = np.maximum(0., sigma11)
    sigma01 = np.sign(sigma01) * np.minimum(
        np.sqrt(sigma00 * sigma11), np.abs(sigma01))
    c1 = (k1 * max_val) ** 2
    c2 = (k2 * max_val) ** 2
    numer = (2 * mu01 + c1) * (2 * sigma01 + c2)
    denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2)
    ssim_map = numer / denom
    ssim = np.mean(ssim_map)
    return ssim_map if return_map else ssim


@torch.no_grad()
def eval_img(aabb, reso, shiftment=[0.5, 0.5], chunk=10240):
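    # Reconstruct the full image by querying the model over a regular grid of pixel centers in chunks,
    # then assembling the predictions into a (reso[0], reso[1], C) tensor on the CPU.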
    y = torch.linspace(0, aabb[0] - 1, reso[0])
    x = torch.linspace(0, aabb[1] - 1, reso[1])
    yy, xx = torch.meshgrid((y, x), indexing='ij')

    idx = 0
    res = torch.empty(reso[0] * reso[1], train_dataset.img.shape[-1])
    coordiantes = torch.stack((xx, yy), dim=-1).reshape(-1, 2) + torch.tensor(
        shiftment)  # /(torch.FloatTensor(reso[::-1])-1)*2-1
    for coordiante in tqdm(torch.split(coordiantes, chunk, dim=0)):
        feats, _ = model.get_coding(coordiante.to(model.device))
        y_recon = model.linear_mat(feats, is_train=False)
        # y_recon = torch.sum(feats,dim=-1,keepdim=True)

        res[idx:idx + y_recon.shape[0]] = y_recon.cpu()
        idx += y_recon.shape[0]
    return res.view(reso[0], reso[1], -1), coordiantes


def linear_to_srgb(img):
    limit = 0.0031308
    return np.where(img > limit, 1.055 * (img ** (1.0 / 2.4)) - 0.055, 12.92 * img)


def write_image_imageio(img_file, img, colormap=None, quality=100):
    if colormap == 'turbo':
        shape = img.shape
        img = interpolate(turbo_colormap_data, img.reshape(-1)).reshape(*shape, -1)
    elif colormap is not None:
        img = cmapy.colorize((img * 255).astype('uint8'), colormap)

    if img.dtype != 'uint8':
        img = (img - np.min(img)) / (np.max(img) - np.min(img))
        img = (img * 255.0).astype(np.uint8)

    kwargs = {}
    if os.path.splitext(img_file)[1].lower() in [".jpg", ".jpeg"]:
        if img.ndim >= 3 and img.shape[2] > 3:
            img = img[:, :, :3]
        kwargs["quality"] = quality
        kwargs["subsampling"] = 0
    imageio.imwrite(img_file, img, **kwargs)

if __name__ == '__main__':

    torch.set_default_dtype(torch.float32)
    torch.manual_seed(20211202)
    np.random.seed(20211202)

    base_conf = OmegaConf.load('configs/defaults.yaml')
    cli_conf = OmegaConf.from_cli()
    second_conf = OmegaConf.load('configs/image.yaml')
    cfg = OmegaConf.merge(base_conf, second_conf, cli_conf)
    print(cfg)


    folder = cfg.defaults.expname
    save_root = f'/vlg-nfs/anpei/project/NeuBasis/ours/images/'

    dataset = dataset_dict[cfg.dataset.dataset_name]

    delete_region = [[290,350,48,48],[300,380,48,48],[180, 407, 48, 48], [223, 263, 48, 48], [233, 150, 48, 48], [374, 119, 48, 48], [4, 199, 48, 48], [180, 234, 48, 48], [173, 39, 48, 48], [408, 308, 48, 48], [227, 177, 48, 48], [46, 330, 48, 48], [213, 26, 48, 48], [90, 44, 48, 48], [295, 61, 48, 48]]
    continue_sampling = False

    psnrs,ssims = [],[]
    for i in  range(1,257):
        cfg.dataset.datadir = f'/vlg-nfs/anpei/dataset/Images/crop//{i:04d}.png'
        name = os.path.basename(cfg.dataset.datadir).split('.')[0]
        if os.path.exists(f'{save_root}/{folder}/{int(name):04d}.png'):
            continue


        train_dataset = dataset(cfg.dataset, cfg.training.batch_size, split='train',tolinear=True, perscent=1.0,HW=1024)#, continue_sampling=continue_sampling,delete_region=delete_region
        train_loader = DataLoader(train_dataset,
                      num_workers=2,
                      persistent_workers=True,
                      batch_size=None,
                      pin_memory=False)
        # train_dataset.img = train_dataset.img.to(device)

        cfg.model.out_dim = train_dataset.img.shape[-1]
        batch_size = cfg.training.batch_size
        n_iter = cfg.training.n_iters

        H,W = train_dataset.HW
        train_dataset.scene_bbox = [[0., 0.], [W, H]]
        cfg.dataset.aabb = train_dataset.scene_bbox

        model = sparseCoding(cfg, device)
        if 1==i:
            print(model)
            print('total parameters: ',model.n_parameters())

        # tvreg = TVLoss()
        # trainingSampler = SimpleSampler(len(train_dataset), cfg.training.batch_size)

        grad_vars = model.get_optparam_groups(lr_small=cfg.training.lr_small,lr_large=cfg.training.lr_large)
        optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))#


        loss_scale = 1.0
        lr_factor = 0.1 ** (1 / n_iter)
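        # loss_scale shrinks geometrically by lr_factor each step (reaching 0.1 of its initial value
        # after n_iter iterations) and is used below to rescale the loss as a simple decay schedule.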
        # pbar = tqdm(range(10000))
        start = time.time()
        # for iteration in pbar:
        for (iteration, sample) in zip(range(10000),train_loader):
            loss_scale *= lr_factor

            # if iteration==5000:
            #     model.coeffs = torch.nn.Parameter(F.interpolate(model.coeffs.data, size=None, scale_factor=2.0, align_corners=True,mode='bilinear'))
            #     grad_vars = model.get_optparam_groups(lr_small=cfg.training.lr_small,lr_large=cfg.training.lr_large)
            #     optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))#
            #     model.set_optimizable(['mlp','basis'], False)

            coordiantes, pixel_rgb = sample['xy'], sample['rgb']
            feats,coeff = model.get_coding(coordiantes.to(device))
            # tv_loss = model.TV_loss(tvreg)

            y_recon = model.linear_mat(feats,is_train=True)
            # y_recon = torch.sum(feats,dim=-1,keepdim=True)
            loss = torch.mean((y_recon.squeeze()-pixel_rgb.squeeze().to(device))**2)


            psnr = -10.0 * np.log(loss.item()) / np.log(10.0)
            # if iteration%100==0:
            #     pbar.set_description(
            #                 f'Iteration {iteration:05d}:'
            #                 + f' loss_dist = {loss.item():.8f}'
            #                 # + f' tv_loss = {tv_loss.item():.6f}'
            #                 + f' psnr = {psnr:.3f}'
            #             )

            # loss = loss + tv_loss
            # loss = loss + torch.mean(coeff.abs())*1e-2
            loss = loss * loss_scale
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # if iteration%100==0:
            #     model.normalize_basis()
        iteration_time = time.time()-start

        H,W = train_dataset.HW
        img,coordinate = eval_img(train_dataset.HW,[1024,1024])
        if continue_sampling:
            import torch.nn.functional as F
            coordinate_tmp = (coordinate.view(1,1,-1,2))/torch.tensor([W,H])*2-1.0
            img_gt = F.grid_sample(train_dataset.img.view(1,H,W,-1).permute(0,3,1,2),coordinate_tmp, mode='bilinear',
                                   align_corners=False, padding_mode='border').reshape(-1,H,W).permute(1,2,0)
        else:
            img_gt = train_dataset.img.view(H,W,-1)
        psnrs.append(PSNR(img.clamp(0,1.),img_gt))
        ssims.append(rgb_ssim(img.clamp(0,1.),img_gt,1.0))
        # print(PSNR(img.clamp(0,1.),img_gt),iteration_time)
        # plt.figure(figsize=(10, 10))
        # plt.imshow(linear_to_srgb(img.clamp(0,1.)))

        print(i, psnrs[-1], ssims[-1])


        os.makedirs(f'{save_root}/{folder}',exist_ok=True)
        write_image_imageio(f'{save_root}/{folder}/{int(name):04d}.png',linear_to_srgb(img.clamp(0,1.)))
        np.savetxt(f'{save_root}/{folder}/{int(name):04d}.txt',[psnrs[-1],ssims[-1],iteration_time,model.n_parameters()])




================================================
FILE: LICENSE
================================================
MIT License

Copyright (c) 2023 autonomousvision

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README.md
================================================
# Factor Fields
## [Project page](https://apchenstu.github.io/FactorFields/) |  [Paper](https://arxiv.org/abs/2302.01226)
This repository contains a PyTorch implementation of the papers [Factor Fields: A Unified Framework for Neural Fields and Beyond](https://arxiv.org/abs/2302.01226) and [Dictionary Fields: Learning a Neural Basis Decomposition](https://arxiv.org/abs/2302.01226). Our work presents a novel framework for modeling and representing signals.
We have also observed that Dictionary Fields offer benefits such as improved **approximation quality**, **compactness**, **faster training speed**, and the ability to **generalize** to unseen images and 3D scenes.<br><br>


## Installation

#### Tested on Ubuntu 20.04 + PyTorch 1.13.0

Install environment:
```sh
conda create -n FactorFields python=3.9
conda activate FactorFields
conda install -c "nvidia/label/cuda-11.7.1" cuda-toolkit
conda install pytorch==1.13.0 torchvision==0.14.0 torchaudio==0.13.0 pytorch-cuda=11.7 -c pytorch -c nvidia
pip install -r requirements.txt 
```

Optionally install [tiny-cuda-nn](https://github.com/NVlabs/tiny-cuda-nn); it is only needed if you want to run hash-grid-based representations.
```sh
conda install -c "nvidia/label/cuda-11.7.1" cuda-toolkit
pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch
```
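
After installation you can optionally run a quick sanity check (a minimal sketch; the `tinycudann` import only succeeds if tiny-cuda-nn was installed):

```python
import torch
print(torch.__version__, torch.cuda.is_available())  # expect a 1.13.x build and True on a CUDA machine

import tinycudann as tcnn  # only needed for the optional hash-grid representations
```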


# Quick Start
Please ensure that you download the corresponding dataset and extract its contents into the `data` folder.

## Image
* [Data - Image Set](https://huggingface.co/apchen/Factor_Fields/blob/main/images.zip)

The training script can be found at `scripts/2D_regression.ipynb`, and the configuration file is located at `configs/image.yaml`.

<p align="left">
  <img src="media/Girl_with_a_Pearl_Earring.jpg" alt="Girl with a Pearl Earring" width="320">
</p>
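
Both the notebook and `2D_regression.py` follow the same pattern: merge the default and image configs with OmegaConf, then apply any command-line overrides (a minimal sketch mirroring `2D_regression.py`; the data path is illustrative):

```python
from omegaconf import OmegaConf

# Merge priority: defaults.yaml < image.yaml < command-line overrides.
cfg = OmegaConf.merge(OmegaConf.load('configs/defaults.yaml'),
                      OmegaConf.load('configs/image.yaml'),
                      OmegaConf.from_cli())
cfg.dataset.datadir = './data/image/albert.exr'  # illustrative path; see configs/image.yaml
```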

## SDF
* [Data - Mesh set](https://huggingface.co/apchen/Factor_Fields/blob/main/SDFs.zip)

The training script can be found at `scripts/sdf_regression.ipynb`, and the configuration file is located at `configs/sdf.yaml`.

<img src="https://github.com/apchenstu/GIFs/blob/main/FactorField-statuette.gif" alt="GIF" width="500px">



## NeRF
* [Data - Synthetic-NeRF](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1) 
* [Data - Tanks&Temples](https://dl.fbaipublicfiles.com/nsvf/dataset/TanksAndTemple.zip)

The training script can be found at `train_per_scene.py`:

```sh
python train_per_scene.py configs/nerf.yaml defaults.expname=lego dataset.datadir=./data/nerf_synthetic/lego
```

<img src="https://github.com/apchenstu/GIFs/blob/main/FactorField-mic.gif" alt="GIF" width="500px">


## Generalization Image
* [Data - FFHQ](https://github.com/NVlabs/ffhq-dataset)

The training script can be found at `scripts/2D_set_regression.ipynb`, and the configuration file is located at `configs/image_set.yaml`.

<p align="left">
  <img src="media/inpainting.png" alt="Inpainting" width="640">
</p>



## Generalization NeRF
* [Data - Google Scanned Objects](https://drive.google.com/file/d/1w1Cs0yztH6kE3JIz7mdggvPGCwIKkVi2/view)

```sh
python train_across_scene.py configs/nerf_set.yaml
```

<img src="https://github.com/apchenstu/GIFs/blob/main/FactorField-few-shot.gif" alt="GIF" width="500px">
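
A trained cross-scene model can then be fine-tuned on a new scene with `train_across_scene_ft.py` and `configs/nerf_ft.yaml` (a minimal sketch; the scene id, view count, and checkpoint path below are illustrative, see `README_FactorField.md` for the full sweep):

```python
import os

scene, views = 183, 5  # illustrative scene id and number of training views
os.system(
    f'python train_across_scene_ft.py configs/nerf_ft.yaml '
    f'defaults.expname=google_objs_{scene}_{views}_views '
    f'dataset.train_views={views} dataset.train_scene_list=[{scene}] dataset.test_scene_list=[{scene}] '
    f'dataset.datadir=./data/google_scanned_objects/ '
    f'defaults.ckpt=./logs/google-obj/google-obj.th'  # illustrative path to the cross-scene checkpoint
)
```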


## More examples

Command-line options explained with a NeRF example:
* `model.basis_dims=[4, 4, 4, 2, 2, 2]` adjusts the number of levels and channels at each level, with a total of 6 levels and 18 channels.
* `model.basis_resos=[32, 51, 70, 89, 108, 128]` represents the resolution of the feature embeddings.
* `model.freq_bands=[2.0, 3.2, 4.4, 5.6, 6.8, 8.0]` indicates the frequency parameters applied at each level of the coordinate transformation function.
* `model.coeff_type` represents the coefficient field representations and can be one of the following: [none, x, grid, mlp, vec, cp, vm].
* `model.basis_type` represents the basis field representation and can be one of the following: [none, x, grid, mlp, vec, cp, vm, hash].
* `model.basis_mapping` represents the coordinate transformation and can be one of the following: [x, triangle, sawtooth, trigonometric]. Please note that if you want to use orthogonal projection, choose the cp or vm basis type, as they automatically utilize the orthogonal projection functions.
* `model.total_params` controls the total model size. Note that the capacity of the basis is determined by `model.basis_resos` and `model.basis_dims`; `total_params` mainly affects the capacity of the coefficient field.
* `exportation.render_only` lets you render items after training by setting this flag to 1. Please also specify the `defaults.ckpt` option.
* `exportation.*` specifies whether to render the items `[render_test, render_train, render_path, export_mesh]` after training by setting the corresponding flag to 1.

Some pre-defined configurations (such as occNet, DVGO, NeRF, iNGP, EG3D) can be found in `README_FactorField.md`.
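
For instance, the options above can be combined into a single training command (a hedged sketch; the experiment name `lego_vm` and the chosen values are illustrative):

```python
import os

# Illustrative override combination; the option names follow the explanation above.
cmd = ('python train_per_scene.py configs/nerf.yaml defaults.expname=lego_vm '
       'dataset.datadir=./data/nerf_synthetic/lego '
       'model.coeff_type=vm model.basis_type=vm '
       'model.basis_dims=[4,4,4,2,2,2] model.basis_resos=[32,51,70,89,108,128] '
       'model.freq_bands=[2.0,3.2,4.4,5.6,6.8,8.0] '
       'exportation.render_test=1')
os.system(cmd)
```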


## Copyright
* [Summer Day](https://www.rijksmuseum.nl/en/collection/SK-A-3005) - Credit goes to Johan Hendrik Weissenbruch and the Rijksmuseum.
* [Mars](https://solarsystem.nasa.gov/resources/933/true-colors-of-pluto/) - Credit goes to NASA.
* [Albert](https://cdn.loc.gov/service/pnp/cph/3b40000/3b46000/3b46000/3b46036v.jpg) - Credit goes to Orren Jack Turner.
* [Girl With a Pearl Earring](http://profoundism.com/free_licenses.html) - Renovation copyright Koorosh Orooj (CC BY-SA 4.0).


## Citation
If you find our code or paper helpful, please consider citing both of these papers:
```
@article{Chen2023factor,
  title={Factor Fields: A Unified Framework for Neural Fields and Beyond},
  author={Chen, Anpei and Xu, Zexiang and Wei, Xinyue and Tang, Siyu and Su, Hao and Geiger, Andreas},
  journal={arXiv preprint arXiv:2302.01226},
  year={2023}
}

@inproceedings{Chen2023SIGGRAPH,
  title={{Dictionary Fields: Learning a Neural Basis Decomposition}},
  author={Chen, Anpei and Xu, Zexiang and Wei, Xinyue and Tang, Siyu and Su, Hao and Geiger, Andreas},
  booktitle={International Conference on Computer Graphics and Interactive Techniques (SIGGRAPH)},
  year={2023}}
```


================================================
FILE: README_FactorField.md
================================================
## nerf reconstruction with Dictionary Fields

```python
	import os

	for scene in ['ship', 'mic', 'chair', 'lego', 'drums', 'ficus', 'hotdog', 'materials']:
		cmd = f'python train_per_scene.py configs/nerf.yaml defaults.expname={scene} ' \
			f'dataset.datadir=./data/nerf_synthetic/{scene}'
		os.system(cmd)
```

## different model design choices

```python
	choice_dict = {
			'-grid': '', \
			'-DVGO-like': 'model.basis_type=none model.coeff_reso=80', \
			'-noC': 'model.coeff_type=none', \
			'-SL':'model.basis_dims=[18]  model.basis_resos=[70] model.freq_bands=[8.]', \
			'-CP': f'model.coeff_type=vec model.basis_type=cp model.freq_bands=[1.,1.,1.,1.,1.,1.] model.basis_resos=[512,512,512,512,512,512] model.basis_dims=[32,32,32,32,32,32]', \
			'-iNGP-like': 'model.basis_type=hash  model.coeff_type=none', \
			'-hash': f'model.basis_type=hash model.coef_init=1.0 ', \
			'-sinc': f'model.basis_mapping=sinc', \
			'-tria': f'model.basis_mapping=triangle', \
			'-vm': f'model.coeff_type=vm model.basis_type=vm', \
			'-mlpB': 'model.basis_type=mlp', \
			'-mlpC': 'model.coeff_type=mlp', \
			'-occNet': f'model.basis_type=x model.coeff_type=none model.basis_mapping=x model.num_layers=8 model.hidden_dim=256 ', \
			'-nerf': f'model.basis_type=x model.coeff_type=none model.basis_mapping=trigonometric ' \
					f'model.num_layers=8 model.hidden_dim=256 ' \
					f'model.freq_bands=[1.,2.,4.,8.,16.,32.,64,128,256.,512.] model.basis_dims=[1,1,1,1,1,1,1,1,1,1] model.basis_resos=[1024,512,256,128,64,32,16,8,4,2]', \
			'-hash-sl': f'model.basis_type=hash model.coef_init=1.0 model.basis_dims=[16] model.freq_bands=[8.] model.basis_resos=[64] ', \
			'-vm-sl': f'model.coeff_type=vm model.basis_type=vm model.coef_init=1.0 model.basis_dims=[18] model.freq_bands=[1.] model.basis_resos=[64] model.total_params=1308416 ', \
		'-DCT':'model.basis_type=fix-grid', \
		}

	for name in choice_dict.keys():
		config = choice_dict[name]
		for scene in ['ship', 'mic', 'chair', 'lego', 'drums', 'ficus', 'hotdog', 'materials']:
			cmd = f"python train_per_scene.py configs/nerf.yaml defaults.expname={scene}{name} dataset.datadir=./data/nerf_synthetic/{scene} {config}"
			os.system(cmd)
```

## generalized nerf
You can choose any of the following design choices for testing.
```python
	choice_dict = {
			'-grid': '', \
			'-DVGO-like': 'model.basis_type=none model.coeff_reso=48',
			'-SL':'model.basis_dims=[72]  model.basis_resos=[48] model.freq_bands=[6.]', \
			'-CP': f'model.coeff_type=vec model.basis_type=cp model.freq_bands=[1.,1.,1.,1.,1.,1.] model.basis_resos=[512,512,512,512,512,512] model.basis_dims=[32,32,32,32,32,32]', \
			'-hash': f'model.basis_type=hash model.coef_init=1.0 ', \
			'-sinc': f'model.basis_mapping=sinc', \
			'-tria': f'model.basis_mapping=triangle', \
			'-vm': f'model.coeff_type=vm model.basis_type=vm', \
			'-mlpB': 'model.basis_type=mlp', \
			'-mlpC': 'model.coeff_type=mlp', \
			'-hash-sl': f'model.basis_type=hash model.coef_init=1.0 model.basis_dims=[16] model.freq_bands=[8.] model.basis_resos=[64] ', \
			'-vm-sl': f'model.coeff_type=vm model.basis_type=vm model.coef_init=1.0 model.basis_dims=[18] model.freq_bands=[1.] model.basis_resos=[64] model.total_params=1308416 ', \
		  '-DCT':'model.basis_type=fix-grid', \
		}
    
	for name in choice_dict.keys():
		config = choice_dict[name]
		cmd = f'python train_across_scene.py configs/nerf_set.yaml defaults.expname=google-obj{name} {config} ' \
				f'training.volume_resoFinal=128 dataset.datadir=./data/google_scanned_objects/'
		os.system(cmd)
```

You can also fine-tune a trained model on a new scene:

```python
	for views in [5]:  # 3,
		for name in choice_dict.keys():
			config = choice_dict[name]
			for scene in [183]:  # 183,199,298,467,957,244,963,527,
				cmd = f'python train_across_scene_ft.py configs/nerf_ft.yaml defaults.expname=google_objs_{name}_{scene}_{views}_views ' \
					f'{config} training.n_iters=10000 ' \
					f'dataset.train_views={views} ' \
					f'dataset.train_scene_list=[{scene}] ' \
					f'dataset.test_scene_list=[{scene}] ' \
					f'dataset.datadir=./data/google_scanned_objects/ ' \
					f'defaults.ckpt=./logs/google-obj{name}//google-obj{name}.th'
				os.system(cmd)
```

## render path after optimization
```python
	for views in [5]:
		for name in choice_dict.keys():
			config = choice_dict[name].replace(",", "','")
			for scene in [183]:  # 183,199,298,467,957,244,963,527,681,948
				cmd = f'python train_across_scene.py configs/nerf_ft.yaml defaults.expname=google_objs_{name}_{scene}_{views}_views ' \
					f'{config} training.n_iters=10000 ' \
					f'dataset.train_views={views} exportation.render_only=True exportation.render_path=True exportation.render_test=False ' \
					f'dataset.train_scene_list=[{scene}] ' \
					f'dataset.test_scene_list=[{scene}] ' \
					f'dataset.datadir=./data/google_scanned_objects/ ' \
					f'defaults.ckpt=./logs/google_objs_{name}_{scene}_{views}_views//google_objs_{name}_{scene}_{views}_views.th'
				os.system(cmd)
```

================================================
FILE: configs/360_v2.yaml
================================================

defaults:
  expname: basis_room_real_mask
  logdir: ./logs

  ckpt: null                  # help='specific weights npy file to reload for coarse network'

model:
  basis_dims: [5,5,5,2,2,2]
  basis_resos: [ 64,  83, 102, 121, 140, 160]
  coeff_reso: 16
  coef_init: 0.01
  phases: [0.0]

  coef_mode: bilinear
  basis_mode: bilinear

  freq_bands: [ 1.0000, 1.7689, 2.3526, 3.1290, 4.1616, 6.]

  kernel_mapping_type: 'sawtooth'

  in_dim: 3
  out_dim: 32
  num_layers: 2
  hidden_dim: 128

dataset:
  # loader options
  dataset_name: llff # choices=['blender', 'llff', 'nsvf', 'dtu','tankstemple', 'own_data']
  datadir: /home/anpei/code/NeuBasis/data/360_v2/room/
  ndc_ray: 0
  is_unbound: True

  with_depth: 0
  downsample_train: 4.0
  downsample_test: 4.0

  N_vis: 5
  vis_every: 5000

training:

  n_iters: 30000
  batch_size: 4096

  volume_resoInit: 128 # 128**3:
  volume_resoFinal: 320 # 300**3

  upsamp_list: [2000,3000,4000,5500]
  update_AlphaMask_list: [2500]
  shrinking_list: [-1]

  L1_weight_inital: 0.0
  L1_weight_rest: 0.0

  TV_weight_density: 0.0
  TV_weight_app: 0.00

exportation:
  render_only: 0
  render_test: 1
  render_train: 0
  render_path: 0
  export_mesh: 0
  export_mesh_only: 0

renderer:
  shadingMode: MLP_Fea
  num_layers: 3
  hidden_dim: 128

  fea2denseAct: 'relu'
  density_shift: -10
  distance_scale: 25.0

  view_pe: 6
  fea_pe: 2

  lindisp: 0
  perturb: 1          # help='set to 0. for no jitter, 1. for jitter'

  step_ratio: 0.5
  max_samples: 1600

  alphaMask_thres: 0.04
  rayMarch_weight_thres: 1e-3












================================================
FILE: configs/defaults.yaml
================================================

defaults:
  expname:  basis_lego
  logedir: ./logs

  mode: 'reconstruction'

  progress_refresh_rate: 10

  add_timestamp: 0

model:
  basis_dims: [4,4,4,2,2,2]
  basis_resos: [32,51,70,89,108,128]
  coeff_reso: 32
  total_params: 10744166
  T_basis: 0
  T_coeff: 0

  coef_init: 1.0
  coef_mode: bilinear
  basis_mode: bilinear

  freq_bands: [1.0000, 1.3300, 1.7689, 2.3526, 3.1290, 4.1616]


  basis_mapping: 'sawtooth'
  with_dropout: False

  in_dim: 3
  out_dim: 32
  num_layers: 2
  hidden_dim: 128

dataset:
  # loader options
  dataset_name: blender # choices=['blender', 'llff', 'nsvf', 'dtu','tankstemple', 'own_data']
  datadir: ./data/nerf_synthetic/lego

  with_depth: 0
  downsample_train: 1.0
  downsample_test: 1.0

  is_unbound: False

training:
  # training options
  batch_size: 4096
  n_iters: 30000

  # learning rate
  lr_small: 0.001
  lr_large: 0.02

  lr_decay_iters: -1
  lr_decay_target_ratio: 0.1    # help = 'number of iterations the lr will decay to the target ratio; -1 will set it to n_iters'
  lr_upsample_reset: 1          # help='reset lr to initial after upsampling'

  # loss
  L1_weight_inital: 0.0         # help='loss weight'
  L1_weight_rest: 0
  Ortho_weight: 0.0
  TV_weight_density: 0.0
  TV_weight_app: 0.0

  # optimizable
  coeff: True
  basis: True
  linear_mat: True
  renderModule: True




================================================
FILE: configs/image.yaml
================================================

defaults:
  expname: basis_image
  logdir: ./logs

  mode: 'image'

  ckpt: null                  # help='specific weights npy file to reload for coarse network'

model:
  basis_dims: [32,32,32,16,16,16]
  basis_resos: [32,51,70,89,108,128]
  freq_bands: [2. , 3.2, 4.4, 5.6, 6.8, 8.]

  total_params: 1426063 # albert
  # total_params: 61445328 # pluto
  # total_params: 71848800 #Girl_with_a_Pearl_Earring
  # total_params: 37138096 # Weissenbruch_Jan_Hendrik_The_Shipping_Canal_at_Rijswijk.jpeg_base
  
  coeff_type: 'grid'
  basis_type: 'grid'
  
  coef_init: 0.001

  coef_mode: nearest
  basis_mode: nearest
  basis_mapping: 'sawtooth'


  in_dim: 2
  out_dim: 3
  num_layers: 2
  hidden_dim: 64
  with_dropout: False
  
dataset:
  # loader options
  dataset_name: image
  datadir: "../data/image/albert.exr"
  # datadir: "../data/image//pluto.jpeg"
  # datadir: "../data/image//Girl_with_a_Pearl_Earring.jpeg"
  # datadir: "../data/image//Weissenbruch_Jan_Hendrik_The_Shipping_Canal_at_Rijswijk.jpeg"


training:
  n_iters: 10000
  batch_size: 102400

  # learning rate
  lr_small: 0.002
  lr_large: 0.002














 

================================================
FILE: configs/image_intro.yaml
================================================

defaults:
  expname: basis_image
  logdir: ./logs

  mode: 'demo'

  ckpt: null                  # help='specific weights npy file to reload for coarse network'

model:
  in_dim: 2
  out_dim: 1

  basis_dims: [32,32,32,16,16,16]
  basis_resos: [32,51,70,89,108,128]
  freq_bands: [2. , 3.2, 4.4, 5.6, 6.8, 8.]


  


  # occNet
  coeff_type: 'none'
  basis_type: 'x'
  basis_mapping: 'x'
  num_layers: 8
  hidden_dim: 256

  
  # coef_init: 0.001

  # coef_mode: nearest
  # basis_mode: nearest
  # basis_mapping: 'sawtooth'

  with_dropout: False
  
dataset:
  # loader options
  dataset_name: image
  datadir: ../data/image/cat_occupancy.png


training:
  n_iters: 10000
  batch_size: 102400

  # learning rate
  lr_small: 0.0002
  lr_large: 0.0002














 

================================================
FILE: configs/image_set.yaml
================================================

defaults:
  expname: basis_image
  logdir: ./logs

  mode: 'images'

  ckpt: null                  # help='specific weights npy file to reload for coarse network'

model:
  basis_dims: [32,32,32,16,16,16]
  basis_resos:  [32,51,70,89,108,128]
  freq_bands: [2. , 3.2, 4.4, 5.6, 6.8, 8.]
  
  
  coeff_reso: 32
  total_params: 1024000 
  
  coef_init: 0.001

  coef_mode: bilinear
  basis_mode: bilinear


  coeff_type: 'grid'
  basis_type: 'grid'

  in_dim: 3
  out_dim: 3
  num_layers: 2
  hidden_dim: 64
  with_dropout: True

dataset:
  # loader options
  dataset_name: images 
  datadir: data/ffhq/ffhq_512.npy

training:
  n_iters: 300000
  batch_size: 40960

  # learning rate
  lr_small: 0.002
  lr_large: 0.002
















================================================
FILE: configs/nerf.yaml
================================================
defaults:
  expname: basis_lego
  logdir: ./logs

  mode: 'reconstruction'

  ckpt: null                  # help='specific weights npy file to reload for coarse network'

model:
  coeff_reso: 32

  basis_dims: [4,4,4,2,2,2]
  basis_resos: [32,51,70,89,108,128]
  freq_bands:  [2. , 3.2, 4.4, 5.6, 6.8, 8.]


  coef_init: 1.0
  phases: [0.0]
  total_params: 5308416

  coef_mode: bilinear
  basis_mode: bilinear

  coeff_type: 'grid'
  basis_type: 'grid'
  basis_mapping: 'sawtooth'

  in_dim: 3
  out_dim: 32
  num_layers: 2
  hidden_dim: 64

dataset:
  # loader options
  dataset_name: blender # choices=['blender', 'llff', 'nsvf', 'dtu','tankstemple', 'own_data']
  datadir: ./data/nerf_synthetic/lego
  ndc_ray: 0

  with_depth: 0
  downsample_train: 1.0
  downsample_test: 1.0

  N_vis: 5
  vis_every: 100000
  scene_reso: 768

training:

  n_iters: 30000
  batch_size: 4096

  volume_resoInit: 128 # 128**3:
  volume_resoFinal: 300 # 300**3

  upsamp_list: [2000,3000,4000,5500,7000]
  update_AlphaMask_list: [2500,4000]
  shrinking_list: [500]

  L1_weight_inital: 0.0
  L1_weight_rest: 0.0

  TV_weight_density: 0.000
  TV_weight_app: 0.00

exportation:
  render_only: 0
  render_test: 1
  render_train: 0
  render_path: 0
  export_mesh: 0
  export_mesh_only: 0

renderer:
  shadingMode: MLP_Fea
  num_layers: 3
  hidden_dim: 128

  fea2denseAct: 'softplus'
  density_shift: -10
  distance_scale: 25.0

  view_pe: 6
  fea_pe: 2

  lindisp: 0
  perturb: 1          # help='set to 0. for no jitter, 1. for jitter'

  step_ratio: 0.5
  max_samples: 1200

  alphaMask_thres: 0.02
  rayMarch_weight_thres: 1e-3


================================================
FILE: configs/nerf_ft.yaml
================================================
defaults:
  expname: basis
  logdir: ./logs

  mode: 'reconstructions'

  ckpt: null                  # help='specific weights npy file to reload for coarse network'

model:
  coeff_reso: 16

  basis_dims: [16,16,16,8,8,8]
  basis_resos: [32,51,70,89,108,128]
  freq_bands:  [2. , 3.2, 4.4, 5.6, 6.8, 8.]

  with_dropout: True

  coef_init: 1.0
  phases: [0.0]
  total_params: 5308416

  coef_mode: bilinear
  basis_mode: bilinear

  coeff_type: 'grid'
  basis_type: 'grid'
  basis_mapping: 'sawtooth'

  in_dim: 3
  out_dim: 32
  num_layers: 2
  hidden_dim: 64

dataset:
  # loader options
  dataset_name: google_objs # choices=['blender', 'llff', 'nsvf', 'dtu','tankstemple', 'own_data']
  datadir: /vlg-nfs/anpei/dataset/google_scanned_objects
  ndc_ray: 0
  train_scene_list: [100]
  test_scene_list: [100]
  train_views: 5

  with_depth: 0
  downsample_train: 1.0
  downsample_test: 1.0

  N_vis: 5
  vis_every: 100000
  scene_reso: 768

training:

  n_iters: 5000
  batch_size: 4096

  volume_resoInit: 128 # 128**3:
  volume_resoFinal: 300 # 300**3

  upsamp_list: [2000,3000,4000]
  update_AlphaMask_list: [1500]
  shrinking_list: [-1]

  L1_weight_inital: 0.0
  L1_weight_rest: 0.0

  TV_weight_density: 0.000
  TV_weight_app: 0.00

  # optimizable
  coeff: True
  basis: False
  linear_mat: False
  renderModule: False

exportation:
  render_only: 0
  render_test: 1
  render_train: 0
  render_path: 0
  export_mesh: 0
  export_mesh_only: 0

renderer:
  shadingMode: MLP_Fea
  num_layers: 3
  hidden_dim: 128

  fea2denseAct: 'softplus'
  density_shift: -10
  distance_scale: 25.0

  view_pe: 6
  fea_pe: 2

  lindisp: 0
  perturb: 1          # help='set to 0. for no jitter, 1. for jitter'

  step_ratio: 0.5
  max_samples: 1200

  alphaMask_thres: 0.02
  rayMarch_weight_thres: 1e-3

================================================
FILE: configs/nerf_set.yaml
================================================
defaults:
  expname: basis_no_relu_lego
  logdir: ./logs

  mode: 'reconstructions'

  ckpt: null                  # help='specific weights npy file to reload for coarse network'

model:
  coeff_reso: 16

  basis_dims: [16,16,16,8,8,8]
#  basis_resos: [32,51,70,89,108,128]
  basis_resos: [32,51,70,89,108,128]
#  freq_bands:  [1.0000, 1.7689, 2.3526, 3.1290, 4.1616, 6.]
  freq_bands:  [2. , 3.2, 4.4, 5.6, 6.8, 8.]

  with_dropout: True

  coef_init: 1.0
  phases: [0.0]
  total_params: 5308416

  coef_mode: bilinear
  basis_mode: bilinear

  coeff_type: 'grid'
  basis_type: 'grid'
  basis_mapping: 'sawtooth'

  in_dim: 3
  out_dim: 32
  num_layers: 2
  hidden_dim: 64

dataset:
  # loader options
  dataset_name: google_objs # choices=['blender', 'llff', 'nsvf', 'dtu','tankstemple', 'own_data']
  datadir: /vlg-nfs/anpei/dataset/google_scanned_objects
  ndc_ray: 0
  train_scene_list: [0,100]
  test_scene_list: [0]
  train_views: 100

  with_depth: 0
  downsample_train: 1.0
  downsample_test: 1.0

  N_vis: 5
  vis_every: 100000
  scene_reso: 768

training:

  n_iters: 50000
  batch_size: 4096

  volume_resoInit: 128 # 128**3:
  volume_resoFinal: 256 # 300**3

  upsamp_list: [2000,3000,4000,5500,7000]
  update_AlphaMask_list: [-1]
  shrinking_list: [-1]

  L1_weight_inital: 0.0
  L1_weight_rest: 0.0

  TV_weight_density: 0.000
  TV_weight_app: 0.00

exportation:
  render_only: 0
  render_test: 0
  render_train: 0
  render_path: 0
  export_mesh: 0
  export_mesh_only: 0

renderer:
  shadingMode: MLP_Fea
  num_layers: 3
  hidden_dim: 128

  fea2denseAct: 'softplus'
  density_shift: -10
  distance_scale: 25.0

  view_pe: 6
  fea_pe: 2

  lindisp: 0
  perturb: 1          # help='set to 0. for no jitter, 1. for jitter'

  step_ratio: 0.5
  max_samples: 1200

  alphaMask_thres: 0.005
  rayMarch_weight_thres: 1e-3


================================================
FILE: configs/sdf.yaml
================================================
defaults:
  expname: basis_sdf
  logdir: ./logs

  mode: 'sdf'

  ckpt: null                  # help='specific weights npy file to reload for coarse network'

model:
  basis_dims: [4,4,4,2,2,2]
  basis_resos: [32,51,70,89,108,128]
  freq_bands:  [2. , 3.2, 4.4, 5.6, 6.8, 8.]
  
  total_params: 5313942
  
  coeff_reso: 32
  coef_init: 0.05

  coef_mode: bilinear
  basis_mode: bilinear


  coeff_type: 'grid'
  basis_type: 'grid'
  kernel_mapping_type: 'sawtooth'

  in_dim: 3
  out_dim: 1
  num_layers: 1
  hidden_dim: 64

dataset:
  # loader options
  dataset_name: sdf 
  datadir: "../data/mesh/statuette_close.npy"
  
  scene_reso: 384
  

training:
  n_iters: 10000
  batch_size: 40960

  # learning rate
  lr_small: 0.002
  lr_large: 0.02

================================================
FILE: configs/tnt.yaml
================================================

defaults:
  expname: basis_truck
  logdir: ./logs

  ckpt: null                  # help='specific weights npy file to reload for coarse network'

model:
  coeff_reso: 32

#  basis_dims: [8,4,2]
#  basis_resos: [32,64,128]
#  freq_bands:  [3.0,4.7,6.8]

## 32.88
#  basis_dims: [3, 3, 3, 3, 3, 3, 3]
#  basis_resos: [64, 64, 64, 64, 64, 64, 64]
##  freq_bands: [1.52727273, 2.58181818, 3.63636364, 4.69090909, 5.21818182, 6.27272727, 6.8]
#  freq_bands: [2., 3., 4.,5.,6.,7.,8.]

#  basis_dims: [5,5,5,2,2,2]
#  basis_resos: [32,51,70,89,108,128]
#  coeff_reso: 26
#  freq_bands:  [1.0000, 1.7689, 2.3526, 3.1290, 4.1616, 6.]

  basis_dims: [4,4,4,2,2,2]
  basis_resos: [32,51,70,89,108,128]
  freq_bands:  [2. , 3.2, 4.4, 5.6, 6.8, 8.]
#  freq_bands:  [2. , 2.8, 3.6, 4.4, 5.2, 6.]

  coef_init: 1.0
  phases: [0.0]
  total_params: 5744166

  coef_mode: bilinear
  basis_mode: bilinear

  kernel_mapping_type: 'sawtooth'

  in_dim: 3
  out_dim: 32
  num_layers: 2
  hidden_dim: 64

dataset:
  # loader options
  dataset_name: tankstemple # choices=['blender', 'llff', 'nsvf', 'dtu','tankstemple', 'own_data']
  datadir: ./data/TanksAndTemple/Truck
  ndc_ray: 0

  with_depth: 0
  downsample_train: 1.0
  downsample_test: 1.0

  N_vis: 5
  vis_every: 100000
  scene_reso: 768

training:

  n_iters: 30000
  batch_size: 4096

  volume_resoInit: 128 # 128**3:
  volume_resoFinal: 320 # 300**3

#  upsamp_list: [2000,4000,7000]
#  update_AlphaMask_list: [2000,3000]
  upsamp_list: [2000,3000,4000,5500,7000]
  update_AlphaMask_list: [2500,4000]
  shrinking_list: [500]

  L1_weight_inital: 0.0
  L1_weight_rest: 0.0

  TV_weight_density: 0.0
  TV_weight_app: 0.00

exportation:
  render_only: 0
  render_test: 1
  render_train: 0
  render_path: 0
  export_mesh: 0
  export_mesh_only: 0

renderer:
  shadingMode: MLP_Fea
  num_layers: 3
  hidden_dim: 128

  fea2denseAct: 'softplus'
  density_shift: -10
  distance_scale: 25.0

  view_pe: 2
  fea_pe: 2

  lindisp: 0
  perturb: 1          # help='set to 0. for no jitter, 1. for jitter'

  step_ratio: 0.5
  max_samples: 1200

  alphaMask_thres: 0.005
  rayMarch_weight_thres: 1e-3












================================================
FILE: dataLoader/__init__.py
================================================
from .llff import LLFFDataset
from .blender import BlenderDataset
from .nsvf import NSVF
from .tankstemple import TanksTempleDataset
from .your_own_data import YourOwnDataset
from .image import ImageDataset
from .image_set import ImageSetDataset
from .colmap import ColmapDataset
from .sdf import SDFDataset
from .blender_set import BlenderDatasetSet
from .google_objs import GoogleObjsDataset
from .dtu_objs import DTUDataset


dataset_dict = {'blender': BlenderDataset,
                'blender_set': BlenderDatasetSet,
               'llff':LLFFDataset,
               'tankstemple':TanksTempleDataset,
               'nsvf':NSVF,
                'own_data':YourOwnDataset,
                'image':ImageDataset,
                'images':ImageSetDataset,
                'sdf':SDFDataset,
                'colmap':ColmapDataset,
                'google_objs':GoogleObjsDataset,
                'dtu':DTUDataset}
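
# Datasets are looked up by name at runtime, e.g. in 2D_regression.py:
#   dataset = dataset_dict[cfg.dataset.dataset_name]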

================================================
FILE: dataLoader/blender.py
================================================
import torch, cv2
from torch.utils.data import Dataset
import json
from tqdm import tqdm
import os
from PIL import Image
from torchvision import transforms as T

from .ray_utils import *


class BlenderDataset(Dataset):
    def __init__(self, cfg, split='train', batch_size=4096, is_stack=None):

        # self.N_vis = N_vis
        self.split = split
        self.batch_size = batch_size
        self.root_dir = cfg.datadir
        self.is_stack = is_stack if is_stack is not None else 'train'!=split
        self.downsample = cfg.get(f'downsample_{self.split}')
        self.img_wh = (int(800 / self.downsample), int(800 / self.downsample))
        self.define_transforms()

        self.rot = torch.tensor([[0.65561799, -0.65561799, 0.37460659],
                                 [0.73729737, 0.44876192, -0.50498052],
                                 [0.16296514, 0.60727077, 0.77760181]])

        self.scene_bbox = (np.array([[-1.0, -1.0, -1.0], [1.0, 1.0, 1.0]])).tolist()
        # self.scene_bbox = [[-0.8,-0.8,-0.22],[0.8,0.8,0.2]]
        self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
        self.read_meta()
        self.define_proj_mat()

        self.white_bg = True
        self.near_far = [2.0, 6.0]

        # self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3)
        # self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)

    def read_depth(self, filename):
        depth = np.array(read_pfm(filename)[0], dtype=np.float32)  # (800, 800)
        return depth

    def read_meta(self):

        with open(os.path.join(self.root_dir, f"transforms_{self.split}.json"), 'r') as f:
            self.meta = json.load(f)

        w, h = self.img_wh
        self.focal = 0.5 * 800 / np.tan(0.5 * self.meta['camera_angle_x'])  # original focal length
        self.focal *= self.img_wh[0] / 800  # modify focal length to match size self.img_wh

        # ray directions for all pixels, same for all images (same H, W, focal)
        self.directions = get_ray_directions(h, w, [self.focal, self.focal])  # (h, w, 3)
        self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)
        self.intrinsics = torch.tensor([[self.focal, 0, w / 2], [0, self.focal, h / 2], [0, 0, 1]]).float()

        self.image_paths = []
        self.poses = []
        self.all_rays = []
        self.all_rgbs = []
        self.all_masks = []
        self.all_depth = []
        self.downsample = 1.0

        img_eval_interval = 1 #if self.N_vis < 0 else len(self.meta['frames']) // self.N_vis
        idxs = list(range(0, len(self.meta['frames']), img_eval_interval))
        # idxs = idxs[:10] if self.split=='train' else idxs
        for i in tqdm(idxs, desc=f'Loading data {self.split} ({len(idxs)})'):  # img_list:#

            frame = self.meta['frames'][i]
            pose = np.array(frame['transform_matrix']) @ self.blender2opencv
            c2w = torch.FloatTensor(pose)
            c2w[:3,-1] /= 1.5
            self.poses += [c2w]

            image_path = os.path.join(self.root_dir, f"{frame['file_path']}.png")
            self.image_paths += [image_path]
            img = Image.open(image_path)

            if self.downsample != 1.0:
                img = img.resize(self.img_wh, Image.LANCZOS)
            img = self.transform(img)  # (4, h, w)
            img = img.view(4, -1).permute(1, 0)  # (h*w, 4) RGBA
            img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:])  # blend A to RGB
            self.all_rgbs += [img]

            rays_o, rays_d = get_rays(self.directions, c2w)  # both (h*w, 3)
            # rays_o, rays_d = rays_o@self.rot, rays_d@self.rot
            self.all_rays += [torch.cat([rays_o, rays_d], 1)]  # (h*w, 6)

        self.poses = torch.stack(self.poses)
        if not self.is_stack:
            self.all_rays = torch.cat(self.all_rays, 0)  # (len(self.meta['frames])*h*w, 3)
            self.all_rgbs = torch.cat(self.all_rgbs, 0)  # (len(self.meta['frames])*h*w, 3)

        #             self.all_depth = torch.cat(self.all_depth, 0)  # (len(self.meta['frames])*h*w, 3)
        else:
            self.all_rays = torch.stack(self.all_rays, 0)  # (len(self.meta['frames]),h*w, 3)
            self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1, *self.img_wh[::-1],3)  # (len(self.meta['frames]),h,w,3)
            # self.all_masks = torch.stack(self.all_masks, 0).reshape(-1,*self.img_wh[::-1])  # (len(self.meta['frames]),h,w,3)

        self.sampler = SimpleSampler(np.prod(self.all_rgbs.shape[:-1]),self.batch_size)

    def define_transforms(self):
        self.transform = T.ToTensor()

    def define_proj_mat(self):
        self.proj_mat = self.intrinsics.unsqueeze(0) @ torch.inverse(self.poses)[:, :3]

    # def world2ndc(self,points,lindisp=None):
    #     device = points.device
    #     return (points - self.center.to(device)) / self.radius.to(device)

    def __len__(self):
        return len(self.all_rgbs) if self.split=='test' else 300000

    def __getitem__(self, idx):
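        # The idx from the DataLoader is ignored; SimpleSampler draws a fresh random batch of ray indices.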
        idx_rand = self.sampler.nextids() #torch.randint(0,len(self.all_rays),(self.batch_size,))
        sample = {'rays': self.all_rays[idx_rand], 'rgbs': self.all_rgbs[idx_rand]}
        return sample

================================================
FILE: dataLoader/blender_set.py
================================================
import torch, cv2
from torch.utils.data import Dataset
import json
from tqdm import tqdm
import os
from PIL import Image
from torchvision import transforms as T

from .ray_utils import *


class BlenderDatasetSet(Dataset):
    def __init__(self, cfg, split='train'):

        # self.N_vis = N_vis
        self.root_dir = cfg.datadir
        self.split = split
        self.is_stack = False if 'train'==split else True
        self.downsample = cfg.get(f'downsample_{self.split}')
        self.img_wh = (int(800 / self.downsample), int(800 / self.downsample))
        self.define_transforms()

        self.rot = torch.tensor([[0.65561799, -0.65561799, 0.37460659],
                                 [0.73729737, 0.44876192, -0.50498052],
                                 [0.16296514, 0.60727077, 0.77760181]])

        self.scene_bbox = (np.array([[-1.0, -1.0, -1.0, 0], [1.0, 1.0, 1.0, 2]])).tolist()
        # self.scene_bbox = [[-0.8,-0.8,-0.22],[0.8,0.8,0.2]]
        self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
        self.read_meta()
        self.define_proj_mat()

        self.white_bg = True
        self.near_far = [2.0, 6.0]

        # self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3)
        # self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)

    def read_depth(self, filename):
        depth = np.array(read_pfm(filename)[0], dtype=np.float32)  # (800, 800)
        return depth

    def read_meta(self):

        with open(os.path.join(self.root_dir, f"transforms_{self.split}.json"), 'r') as f:
            self.meta = json.load(f)

        w, h = self.img_wh
        self.focal = 0.5 * 800 / np.tan(0.5 * self.meta['camera_angle_x'])  # original focal length
        self.focal *= self.img_wh[0] / 800  # modify focal length to match size self.img_wh

        # ray directions for all pixels, same for all images (same H, W, focal)
        self.directions = get_ray_directions(h, w, [self.focal, self.focal])  # (h, w, 3)
        self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)
        self.intrinsics = torch.tensor([[self.focal, 0, w / 2], [0, self.focal, h / 2], [0, 0, 1]]).float()

        self.image_paths = []
        self.poses = []
        self.all_rays = []
        self.all_rgbs = []
        self.all_masks = []
        self.all_depth = []
        self.downsample = 1.0

        img_eval_interval = 1 #if self.N_vis < 0 else len(self.meta['frames']) // self.N_vis
        idxs = list(range(0, len(self.meta['frames']), img_eval_interval))
        for i in tqdm(idxs, desc=f'Loading data {self.split} ({len(idxs)})'):  # img_list:#

            frame = self.meta['frames'][i]
            pose = np.array(frame['transform_matrix']) @ self.blender2opencv
            c2w = torch.FloatTensor(pose)
            c2w[:3,-1] /= 1.5
            self.poses += [c2w]

            image_path = os.path.join(self.root_dir, f"{frame['file_path']}.png")
            self.image_paths += [image_path]
            img = Image.open(image_path)

            if self.downsample != 1.0:
                img = img.resize(self.img_wh, Image.LANCZOS)
            img = self.transform(img)  # (4, h, w)
            img = img.view(4, -1).permute(1, 0)  # (h*w, 4) RGBA
            img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:])  # blend A to RGB
            self.all_rgbs += [img]

            rays_o, rays_d = get_rays(self.directions, c2w)  # both (h*w, 3)
            # rays_o, rays_d = rays_o@self.rot, rays_d@self.rot

            scene_id = torch.ones_like(rays_o[...,:1])*0
            self.all_rays += [torch.cat([rays_o, rays_d, scene_id], 1)]  # (h*w, 6)

        self.poses = torch.stack(self.poses)
        if not self.is_stack:
            self.all_rays = torch.cat(self.all_rays, 0)  # (len(self.meta['frames])*h*w, 3)
            self.all_rgbs = torch.cat(self.all_rgbs, 0)  # (len(self.meta['frames])*h*w, 3)

        #             self.all_depth = torch.cat(self.all_depth, 0)  # (len(self.meta['frames])*h*w, 3)
        else:
            self.all_rays = torch.stack(self.all_rays, 0)  # (len(self.meta['frames]),h*w, 3)
            self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1, *self.img_wh[::-1],3)  # (len(self.meta['frames]),h,w,3)
            # self.all_masks = torch.stack(self.all_masks, 0).reshape(-1,*self.img_wh[::-1])  # (len(self.meta['frames]),h,w,3)

    def define_transforms(self):
        self.transform = T.ToTensor()

    def define_proj_mat(self):
        self.proj_mat = self.intrinsics.unsqueeze(0) @ torch.inverse(self.poses)[:, :3]

    # def world2ndc(self,points,lindisp=None):
    #     device = points.device
    #     return (points - self.center.to(device)) / self.radius.to(device)

    def __len__(self):
        return len(self.all_rgbs)

    def __getitem__(self, idx):
        rays = torch.cat((self.all_rays[idx],torch.tensor([0+0.5])),dim=-1)
        sample = {'rays': rays, 'rgbs': self.all_rgbs[idx]}
        return sample

================================================
FILE: dataLoader/colmap.py
================================================
import torch, cv2
from torch.utils.data import Dataset

from tqdm import tqdm
import os
from PIL import Image
from torchvision import transforms as T

from .ray_utils import *


class ColmapDataset(Dataset):
    def __init__(self, cfg, split='train'):

        self.cfg = cfg
        self.root_dir = cfg.datadir
        self.split = split
        self.is_stack = False if 'train'==split else True
        self.downsample = cfg.get(f'downsample_{self.split}')
        self.define_transforms()
        self.img_eval_interval = 8

        self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])#np.eye(4)#
        self.read_meta()

        self.white_bg = cfg.get('white_bg')
        
        # self.near_far = [0.1,2.0]


    def read_meta(self):

        if os.path.exists(f'{self.root_dir}/transforms.json'):
            self.meta = load_json(f'{self.root_dir}/transforms.json')
            i_test = np.arange(0, len(self.meta['frames']), self.img_eval_interval)  # [np.argmin(dists)]
            idxs = i_test if self.split != 'train' else list(set(np.arange(len(self.meta['frames']))) - set(i_test))
        else:
            self.meta = load_json(f'{self.root_dir}/transforms_{self.split}.json')
            idxs = np.arange(0, len(self.meta['frames']))  # [np.argmin(dists)]
            inv_split = 'train' if self.split!='train' else 'test'
            self.meta['frames'] += load_json(f'{self.root_dir}/transforms_{inv_split}.json')['frames']
            print(len(self.meta['frames']),len(idxs))


        self.scale = self.meta.get('scale', 0.5)
        self.offset = torch.FloatTensor(self.meta.get('offset', [0.0,0.0,0.0]))
        # self.scene_bbox = (torch.tensor([[-6.,-7.,-10.0],[6.,7.,10.]])/5).tolist()
        # self.scene_bbox = [[-1., -1., -1.0], [1., 1., 1.]]

        # center, radius = torch.tensor([-0.082157, 2.415426,-3.703080]), torch.tensor([7.36916, 11.34958, 20.1616])/2
        # self.scene_bbox = torch.stack([center-radius, center+radius]).tolist()

        h, w = int(self.meta.get('h')), int(self.meta.get('w'))
        cx, cy = self.meta.get('cx'), self.meta.get('cy')
        self.focal = [self.meta.get('fl_x'), self.meta.get('fl_y')]

        # ray directions for all pixels, same for all images (same H, W, focal)
        self.directions = get_ray_directions(h, w, self.focal, center=[cx, cy])  # (h, w, 3)
        self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)
        # self.intrinsics = torch.FloatTensor([[self.focal[0], 0, cx], [0, self.focal[1], cy], [0, 0, 1]])
        # self.intrinsics[:2] /= self.downsample

        poses = pose_from_json(self.meta, self.blender2opencv)
        poses, self.scene_bbox = orientation(poses, f'{self.root_dir}/colmap_text/points3D.txt')

        self.image_paths = []
        self.poses = []
        self.all_rays = []
        self.all_rgbs = []
        self.all_masks = []
        self.all_depth = []

        self.img_wh = [w,h]

        for i in tqdm(idxs, desc=f'Loading data {self.split} ({len(idxs)})'):  # img_list:#

            frame = self.meta['frames'][i]
            c2w = torch.FloatTensor(poses[i])
            # c2w[:3,3] = (c2w[:3,3]*self.scale + self.offset)*2-1
            self.poses += [c2w]

            image_path = os.path.join(self.root_dir, frame['file_path'])
            self.image_paths += [image_path]
            img = Image.open(image_path)

            if self.downsample != 1.0:
                img = img.resize(self.img_wh, Image.LANCZOS)

            img = self.transform(img)
            if img.shape[0]==4:
                img = img[:3] * img[-1:] + (1 - img[-1:])  # blend A to RGB
            img = img.view(3, -1).permute(1, 0)
            self.all_rgbs += [img]


            rays_o, rays_d = get_rays(self.directions, c2w)  # both (h*w, 3)
            self.all_rays += [torch.cat([rays_o, rays_d], 1)]  # (h*w, 6)

        self.poses = torch.stack(self.poses)
        if not self.is_stack:
            self.all_rays = torch.cat(self.all_rays, 0)  # (len(self.meta['frames'])*h*w, 6)
            self.all_rgbs = torch.cat(self.all_rgbs, 0)  # (len(self.meta['frames'])*h*w, 3)
        else:
            self.all_rays = torch.stack(self.all_rays, 0)  # (len(self.meta['frames']), h*w, 6)
            self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1, *self.img_wh[::-1],3)  # (len(self.meta['frames']), h, w, 3)

    def define_transforms(self):
        self.transform = T.ToTensor()


    def __len__(self):
        return len(self.all_rgbs)

    def __getitem__(self, idx):

        if self.split == 'train':  # use data in the buffers
            sample = {'rays': self.all_rays[idx],
                      'rgbs': self.all_rgbs[idx]}

        else:  # create data for each image separately

            img = self.all_rgbs[idx]
            rays = self.all_rays[idx]

            sample = {'rays': rays,
                      'rgbs': img}
        return sample

================================================
FILE: dataLoader/colmap2nerf.py
================================================
#!/usr/bin/env python3

# Copyright (c) 2020-2022, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import argparse
import os
from pathlib import Path, PurePosixPath

import numpy as np
import json
import sys
import math
import cv2
import os
import shutil

def parse_args():
	parser = argparse.ArgumentParser(description="convert a text colmap export to nerf format transforms.json; optionally convert video to images, and optionally run colmap in the first place")

	parser.add_argument("--video_in", default="", help="run ffmpeg first to convert a provided video file into a set of images. uses the video_fps parameter also")
	parser.add_argument("--video_fps", default=2)
	parser.add_argument("--time_slice", default="", help="time (in seconds) in the format t1,t2 within which the images should be generated from the video. eg: \"--time_slice '10,300'\" will generate images only from 10th second to 300th second of the video")
	parser.add_argument("--run_colmap", action="store_true", help="run colmap first on the image folder")
	parser.add_argument("--colmap_matcher", default="sequential", choices=["exhaustive","sequential","spatial","transitive","vocab_tree"], help="select which matcher colmap should use. sequential for videos, exhaustive for adhoc images")
	parser.add_argument("--colmap_db", default="colmap.db", help="colmap database filename")
	parser.add_argument("--colmap_camera_model", default="OPENCV", choices=["SIMPLE_PINHOLE", "PINHOLE", "SIMPLE_RADIAL", "RADIAL","OPENCV"], help="camera model")
	parser.add_argument("--colmap_camera_params", default="", help="intrinsic parameters, depending on the chosen model.  Format: fx,fy,cx,cy,dist")
	parser.add_argument("--images", default="images", help="input path to the images")
	parser.add_argument("--text", default="colmap_text", help="input path to the colmap text files (set automatically if run_colmap is used)")
	parser.add_argument("--aabb_scale", default=16, choices=["1","2","4","8","16"], help="large scene scale factor. 1=scene fits in unit cube; power of 2 up to 16")
	parser.add_argument("--skip_early", default=0, help="skip this many images from the start")
	parser.add_argument("--keep_colmap_coords", action="store_true", help="keep transforms.json in COLMAP's original frame of reference (this will avoid reorienting and repositioning the scene for preview and rendering)")
	parser.add_argument("--out", default="transforms.json", help="output path")
	parser.add_argument("--vocab_path", default="", help="vocabulary tree path")
	args = parser.parse_args()
	return args

def do_system(arg):
	print(f"==== running: {arg}")
	err = os.system(arg)
	if err:
		print("FATAL: command failed")
		sys.exit(err)

def run_ffmpeg(args):
	if not os.path.isabs(args.images):
		args.images = os.path.join(os.path.dirname(args.video_in), args.images)
	images = args.images
	video = args.video_in
	fps = float(args.video_fps) or 1.0
	print(f"running ffmpeg with input video file={video}, output image folder={images}, fps={fps}.")
	if (input(f"warning! folder '{images}' will be deleted/replaced. continue? (Y/n)").lower().strip()+"y")[:1] != "y":
		sys.exit(1)
	try:
		shutil.rmtree(images)
	except:
		pass
	do_system(f"mkdir {images}")

	time_slice_value = ""
	time_slice = args.time_slice
	if time_slice:
	    start, end = time_slice.split(",")
	    time_slice_value = f",select='between(t\,{start}\,{end})'"
	do_system(f"ffmpeg -i {video} -qscale:v 1 -qmin 1 -vf \"fps={fps}{time_slice_value}\" {images}/%04d.jpg")

def run_colmap(args):
	db = args.colmap_db
	images = "\"" + args.images + "\""
	db_noext=str(Path(db).with_suffix(""))

	if args.text=="text":
		args.text=db_noext+"_text"
	text=args.text
	sparse=db_noext+"_sparse"
	print(f"running colmap with:\n\tdb={db}\n\timages={images}\n\tsparse={sparse}\n\ttext={text}")
	if (input(f"warning! folders '{sparse}' and '{text}' will be deleted/replaced. continue? (Y/n)").lower().strip()+"y")[:1] != "y":
		sys.exit(1)
	if os.path.exists(db):
		os.remove(db)
	do_system(f"colmap feature_extractor --ImageReader.camera_model {args.colmap_camera_model} --ImageReader.camera_params \"{args.colmap_camera_params}\" --SiftExtraction.estimate_affine_shape=true --SiftExtraction.domain_size_pooling=true --ImageReader.single_camera 1 --database_path {db} --image_path {images}")
	match_cmd = f"colmap {args.colmap_matcher}_matcher --SiftMatching.guided_matching=true --database_path {db}"
	if args.vocab_path:
		match_cmd += f" --VocabTreeMatching.vocab_tree_path {args.vocab_path}"
	do_system(match_cmd)
	try:
		shutil.rmtree(sparse)
	except:
		pass
	do_system(f"mkdir {sparse}")
	do_system(f"colmap mapper --database_path {db} --image_path {images} --output_path {sparse}")
	do_system(f"colmap bundle_adjuster --input_path {sparse}/0 --output_path {sparse}/0 --BundleAdjustment.refine_principal_point 1")
	try:
		shutil.rmtree(text)
	except:
		pass
	do_system(f"mkdir {text}")
	do_system(f"colmap model_converter --input_path {sparse}/0 --output_path {text} --output_type TXT")


def variance_of_laplacian(image):
	return cv2.Laplacian(image, cv2.CV_64F).var()

def sharpness(imagePath):
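	# blur/sharpness score: variance of the Laplacian of the grayscale image (higher = sharper)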
	image = cv2.imread(imagePath)
	gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
	fm = variance_of_laplacian(gray)
	return fm

def qvec2rotmat(qvec):
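	# COLMAP quaternion (w, x, y, z) -> 3x3 rotation matrix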
	return np.array([
		[
			1 - 2 * qvec[2]**2 - 2 * qvec[3]**2,
			2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
			2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]
		], [
			2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
			1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,
			2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]
		], [
			2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
			2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
			1 - 2 * qvec[1]**2 - 2 * qvec[2]**2
		]
	])

def rotmat(a, b):
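	# rotation matrix that rotates unit vector a onto unit vector b (Rodrigues' formula)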
	a, b = a / np.linalg.norm(a), b / np.linalg.norm(b)
	v = np.cross(a, b)
	c = np.dot(a, b)
	s = np.linalg.norm(v)
	kmat = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]])
	return np.eye(3) + kmat + kmat.dot(kmat) * ((1 - c) / (s ** 2 + 1e-10))

def closest_point_2_lines(oa, da, ob, db): # returns point closest to both rays of form o+t*d, and a weight factor that goes to 0 if the lines are parallel
	da = da / np.linalg.norm(da)
	db = db / np.linalg.norm(db)
	c = np.cross(da, db)
	denom = np.linalg.norm(c)**2
	t = ob - oa
	ta = np.linalg.det([t, db, c]) / (denom + 1e-10)
	tb = np.linalg.det([t, da, c]) / (denom + 1e-10)
	if ta > 0:
		ta = 0
	if tb > 0:
		tb = 0
	return (oa+ta*da+ob+tb*db) * 0.5, denom
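# e.g. for the x-axis ray (oa=[0,0,0], da=[1,0,0]) and the vertical ray through
# [0,1,1] (ob=[0,1,1], db=[0,0,1]) this returns the midpoint [0, 0.5, 0] of the two
# closest points, with weight 1 since the rays are orthogonal.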

############ orientation  ##############
def normalize(x):
    return x / np.linalg.norm(x)

def rotation_matrix_from_vectors(vec1, vec2):
	""" Find the rotation matrix that aligns vec1 to vec2
    :param vec1: A 3d "source" vector
    :param vec2: A 3d "destination" vector
    :return mat: A transform matrix (3x3) which when applied to vec1, aligns it with vec2.
    """
	a, b = (vec1 / np.linalg.norm(vec1)).reshape(3), (vec2 / np.linalg.norm(vec2)).reshape(3)
	v = np.cross(a, b)
	if any(v):  # if not all zeros then
		c = np.dot(a, b)
		s = np.linalg.norm(v)
		kmat = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]])
		return np.eye(3) + kmat + kmat.dot(kmat) * ((1 - c) / (s ** 2))

	else:
		return np.eye(3)  # cross of all zeros only occurs on identical directions

def rotation_up(poses):
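    # least-squares estimate of the shared "up" direction from the camera y-axes,
    # returned as the rotation that aligns it with +y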
    up = normalize(np.linalg.lstsq(poses[:, :3, 1],np.ones((poses.shape[0],1)),rcond=None)[0])
    rot = rotation_matrix_from_vectors(up,np.array([0.,1.,0.]))
    return rot

def search_orientation(points):
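    # brute-force search over y-axis rotations in [-45, 45] degrees for the one that
    # gives the tightest (smallest-volume) axis-aligned bounding box of the points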
    from scipy.spatial.transform import Rotation as R
    bbox_sizes,rot_mats,bboxs = [],[],[]
    for y_angle in np.linspace(-45,45,15):
        rotvec = np.array([0,y_angle,0])/180*np.pi
        rot = R.from_rotvec(rotvec).as_matrix()
        point_orientation = rot@points
        bbox = np.max(point_orientation,axis=1) - np.min(point_orientation,axis=1)
        bbox_sizes.append(np.prod(bbox))
        rot_mats.append(rot)
        bboxs.append(bbox)
    rot = rot_mats[np.argmin(bbox_sizes)]
    bbox = bboxs[np.argmin(bbox_sizes)]
    return rot,bbox

def load_point_txt(path):
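    # read the x, y, z columns from a COLMAP points3D.txt export, skipping comment lines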
    points = []
    with open(path, "r") as f:
        for line in f:
            if line[0] == "#":
                continue
            els = line.split(" ")
            points.append([float(els[1]),float(els[2]),float(els[3])])
    return np.stack(points)

# def load_c2ws(frames):
# 	c2ws =
# 	for f in frames:
# 		f["transform_matrix"] -= totp
#
# def oritation(transform_matrix, point_txt):


if __name__ == "__main__":
	args = parse_args()
	if args.video_in != "":
		run_ffmpeg(args)
	if args.run_colmap:
		run_colmap(args)
	AABB_SCALE = int(args.aabb_scale)
	SKIP_EARLY = int(args.skip_early)
	IMAGE_FOLDER = args.images
	TEXT_FOLDER = args.text
	OUT_PATH = args.out
	print(f"outputting to {OUT_PATH}...")
	with open(os.path.join(TEXT_FOLDER,"cameras.txt"), "r") as f:
		angle_x = math.pi / 2
		for line in f:
			# 1 SIMPLE_RADIAL 2048 1536 1580.46 1024 768 0.0045691
			# 1 OPENCV 3840 2160 3178.27 3182.09 1920 1080 0.159668 -0.231286 -0.00123982 0.00272224
			# 1 RADIAL 1920 1080 1665.1 960 540 0.0672856 -0.0761443
			if line[0] == "#":
				continue
			els = line.split(" ")
			w = float(els[2])
			h = float(els[3])
			fl_x = float(els[4])
			fl_y = float(els[4])
			k1 = 0
			k2 = 0
			p1 = 0
			p2 = 0
			cx = w / 2
			cy = h / 2
			if els[1] == "SIMPLE_PINHOLE":
				cx = float(els[5])
				cy = float(els[6])
			elif els[1] == "PINHOLE":
				fl_y = float(els[5])
				cx = float(els[6])
				cy = float(els[7])
			elif els[1] == "SIMPLE_RADIAL":
				cx = float(els[5])
				cy = float(els[6])
				k1 = float(els[7])
			elif els[1] == "RADIAL":
				cx = float(els[5])
				cy = float(els[6])
				k1 = float(els[7])
				k2 = float(els[8])
			elif els[1] == "OPENCV":
				fl_y = float(els[5])
				cx = float(els[6])
				cy = float(els[7])
				k1 = float(els[8])
				k2 = float(els[9])
				p1 = float(els[10])
				p2 = float(els[11])
			else:
				print("unknown camera model ", els[1])
			# fl = 0.5 * w / tan(0.5 * angle_x);
			angle_x = math.atan(w / (fl_x * 2)) * 2
			angle_y = math.atan(h / (fl_y * 2)) * 2
			fovx = angle_x * 180 / math.pi
			fovy = angle_y * 180 / math.pi

	print(f"camera:\n\tres={w,h}\n\tcenter={cx,cy}\n\tfocal={fl_x,fl_y}\n\tfov={fovx,fovy}\n\tk={k1,k2} p={p1,p2} ")

	with open(os.path.join(TEXT_FOLDER,"images.txt"), "r") as f:
		i = 0
		bottom = np.array([0.0, 0.0, 0.0, 1.0]).reshape([1, 4])
		out = {
			"camera_angle_x": angle_x,
			"camera_angle_y": angle_y,
			"fl_x": fl_x,
			"fl_y": fl_y,
			"k1": k1,
			"k2": k2,
			"p1": p1,
			"p2": p2,
			"cx": cx,
			"cy": cy,
			"w": w,
			"h": h,
			"aabb_scale": AABB_SCALE,
			"frames": [],
		}

		up = np.zeros(3)
		for line in f:
			line = line.strip()
			if line[0] == "#":
				continue
			i = i + 1
			if i < SKIP_EARLY*2:
				continue
			if  i % 2 == 1:
				elems=line.split(" ") # 1-4 is quat, 5-7 is trans, 9ff is filename (9, if filename contains no spaces)
				#name = str(PurePosixPath(Path(IMAGE_FOLDER, elems[9])))
				# why is this requiring a relative path while using ^
				image_rel = os.path.relpath(IMAGE_FOLDER)
				name = str(f"./{image_rel}/{'_'.join(elems[9:])}")
				b=sharpness(name)
				print(name, "sharpness=",b)
				image_id = int(elems[0])
				qvec = np.array(tuple(map(float, elems[1:5])))
				tvec = np.array(tuple(map(float, elems[5:8])))
				R = qvec2rotmat(-qvec)
				t = tvec.reshape([3,1])
				m = np.concatenate([np.concatenate([R, t], 1), bottom], 0)
				c2w = np.linalg.inv(m)
				# c2w[0:3,2] *= -1 # flip the y and z axis
				# c2w[0:3,1] *= -1
				# c2w = c2w[[1,0,2,3],:] # swap y and z
				# c2w[2,:] *= -1 # flip whole world upside down

				# up += c2w[0:3,1]

				frame={"file_path":name,"sharpness":b,"transform_matrix": c2w}
				out["frames"].append(frame)

	# up = up / np.linalg.norm(up)
	# print("up vector was", up)
	# R = rotmat(up,[0,0,1]) # rotate up vector to [0,0,1]
	# R = np.pad(R,[0,1])
	# R[-1, -1] = 1
	# for f in out["frames"]:
	# 	f["transform_matrix"] = np.matmul(R, f["transform_matrix"]) # rotate up to be the z axis

	nframes = len(out["frames"])

	# find a central point they are all looking at
	print("computing center of attention...")
	totw = 0.0
	totp = np.array([0.0, 0.0, 0.0])
	for f in out["frames"]:
		mf = f["transform_matrix"][0:3,:]
		for g in out["frames"]:
			mg = g["transform_matrix"][0:3,:]
			p, w = closest_point_2_lines(mf[:,3], mf[:,2], mg[:,3], mg[:,2])
			if w > 0.01:
				totp += p*w
				totw += w
	totp /= totw
	print(totp) # the cameras are looking at totp
	for f in out["frames"]:
		f["transform_matrix"][0:3,3] -= totp

	avglen = 0.
	for f in out["frames"]:
		avglen += np.linalg.norm(f["transform_matrix"][0:3,3])
	avglen /= nframes
	print("avg camera distance from origin", avglen)
	for f in out["frames"]:
		f["transform_matrix"][0:3,3] *= 4.0 / avglen # scale to "nerf sized"



	for f in out["frames"]:
		f["transform_matrix"] = f["transform_matrix"].tolist()
	print(nframes,"frames")
	print(f"writing {OUT_PATH}")
	with open(OUT_PATH, "w") as outfile:
		json.dump(out, outfile, indent=2)

================================================
FILE: dataLoader/dtu_objs.py
================================================

import torch
import cv2 as cv
import numpy as np
import os
from glob import glob
from .ray_utils import *
from torch.utils.data import Dataset

# This function is borrowed from IDR: https://github.com/lioryariv/idr
def load_K_Rt_from_P(filename, P=None):
    if P is None:
        lines = open(filename).read().splitlines()
        if len(lines) == 4:
            lines = lines[1:]
        lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
        P = np.asarray(lines).astype(np.float32).squeeze()

    out = cv.decomposeProjectionMatrix(P)
    K = out[0]
    R = out[1]
    t = out[2]

    K = K / K[2, 2]
    intrinsics = np.eye(4)
    intrinsics[:3, :3] = K

    pose = np.eye(4, dtype=np.float32)
    pose[:3, :3] = R.transpose()
    pose[:3, 3] = (t[:3] / t[3])[:, 0]

    return intrinsics, pose

def fps_downsample(points, n_points_to_sample):
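    # greedy farthest-point sampling: repeatedly pick the point farthest from the
    # already selected set and return the chosen indices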
    selected_points = np.zeros((n_points_to_sample, 3))
    selected_idxs = []
    dist = np.ones(points.shape[0]) * 100
    for i in range(n_points_to_sample):
        idx = np.argmax(dist)
        selected_points[i] = points[idx]
        selected_idxs.append(idx)
        dist_ = ((points - selected_points[i]) ** 2).sum(-1)
        dist = np.minimum(dist, dist_)

    return selected_idxs

class DTUDataset(Dataset):
    def __init__(self, cfg, split='train', batch_size=4096, is_stack=None):
        """
        img_wh should be set to a tuple ex: (1152, 864) to enable test mode!
        """
        # self.N_vis = N_vis
        self.split = split
        self.batch_size = batch_size
        self.root_dir = cfg.datadir
        self.is_stack = is_stack if is_stack is not None else 'train'!=split
        self.downsample = cfg.get(f'downsample_{self.split}')
        self.img_wh = (int(400 / self.downsample), int(300 / self.downsample))

        train_scene_idxs = sorted(cfg.train_scene_list)
        test_scene_idxs = cfg.test_scene_list
        if len(train_scene_idxs)==2:
            train_scene_idxs = list(range(train_scene_idxs[0],train_scene_idxs[1]))
        self.scene_idxs = train_scene_idxs if self.split=='train' else test_scene_idxs
        print(self.scene_idxs)
        self.train_views = cfg.train_views
        self.scene_num = len(self.scene_idxs)
        self.test_index = test_scene_idxs
        # if 'test' == self.split:
        #     self.test_index = train_scene_idxs.index(test_scene_idxs[0])

        self.scene_path_list = [os.path.join(self.root_dir, f"scan{i}") for i in self.scene_idxs]
        # self.scene_path_list = sorted(glob(os.path.join(self.root_dir, "scan*")))

        self.read_meta()
        self.white_bg = False

    def read_meta(self):
        self.aabbs = []
        self.all_rgb_files,self.all_pose_files,self.all_intrinsics_files = {},{},{}
        for i, scene_idx in enumerate(self.scene_idxs):

            scene_path = self.scene_path_list[i]
            camera_dict = np.load(os.path.join(scene_path, 'cameras.npz'))

            self.all_rgb_files[scene_idx] = [
                os.path.join(scene_path, "image", f)
                for f in sorted(os.listdir(os.path.join(scene_path, "image")))
            ]

            # world_mat is a projection matrix from world to image
            n_images = len(self.all_rgb_files[scene_idx])
            world_mats_np = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(n_images)]
            scale_mats_np = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(n_images)]
            object_scale_mat = camera_dict['scale_mat_0']
            self.aabbs.append(self.get_bbox(scale_mats_np,object_scale_mat))

            # W,H = self.img_wh
            intrinsics_scene, poses_scene = [], []
            for img_idx, (scale_mat, world_mat) in enumerate(zip(scale_mats_np, world_mats_np)):
                P = world_mat @ scale_mat
                P = P[:3, :4]
                intrinsic, c2w = load_K_Rt_from_P(None, P)

                c2w = torch.from_numpy(c2w).float()
                intrinsic = torch.from_numpy(intrinsic).float()
                intrinsic[:2] /= self.downsample

                poses_scene.append(c2w)
                intrinsics_scene.append(intrinsic)

            self.all_pose_files[scene_idx] = np.stack(poses_scene)
            self.all_intrinsics_files[scene_idx] = np.stack(intrinsics_scene)

        self.aabbs[0][0].append(0)
        self.aabbs[0][1].append(self.scene_num)
        self.scene_bbox = self.aabbs[0]
        print(self.scene_bbox)
        if self.split=='test' or self.scene_num==1:
            self.load_data(self.scene_idxs[0],range(49))

    def load_data(self, scene_idx, img_idx=None):
        self.all_rays = []

        n_views = len(self.all_pose_files[scene_idx])
        cam_xyzs = self.all_pose_files[scene_idx][:,:3, -1]
        idxs = fps_downsample(cam_xyzs, min(self.train_views, n_views)) if img_idx is None else img_idx
        # if "test" == self.split:
        #     idxs = [item for item in list(range(n_views)) if item not in idxs]
        #     if len(idxs)==0:
        #         idxs = list(range(n_views))

        images_np = np.stack([cv.resize(cv.imread(self.all_rgb_files[scene_idx][idx]), self.img_wh) for idx in idxs]) / 255.0
        self.all_rgbs = torch.from_numpy(images_np.astype(np.float32)[..., [2, 1, 0]])  # [n_images, H, W, 3]

        for c2w,intrinsic in zip(self.all_pose_files[scene_idx][idxs],self.all_intrinsics_files[scene_idx][idxs]):
            rays_o, rays_d = self.gen_rays_at(torch.from_numpy(intrinsic).float(), torch.from_numpy(c2w).float())
            self.all_rays += [torch.cat([rays_o, rays_d], 1)]  # (h*w, 6)

        if not self.is_stack:
            self.all_rays = torch.cat(self.all_rays, 0)  # (n_views*h*w, 6)
            self.all_rgbs = self.all_rgbs.reshape(-1, 3)
        else:
            self.all_rays = torch.stack(self.all_rays, 0)  # (n_views, h*w, 6)
            self.all_rgbs = self.all_rgbs.reshape(-1, *self.img_wh[::-1], 3)  # (n_views, h, w, 3)

            # self.sampler = SimpleSampler(np.prod(self.all_rgbs.shape[:-1]), self.batch_size)


        # def read_meta(self):
    #
    #     images_lis = sorted(glob(os.path.join(self.root_dir, 'image/*.png')))
    #     images_np = np.stack([cv.resize(cv.imread(im_name), self.img_wh) for im_name in images_lis]) / 255.0
    #     # masks_lis = sorted(glob(os.path.join(self.root_dir, 'mask/*.png')))
    #     # masks_np = np.stack([cv.resize(cv.imread(im_name),self.img_wh) for im_name in masks_lis])>128
    #
    #     self.all_rgbs = torch.from_numpy(images_np.astype(np.float32)[..., [2, 1, 0]])  # [n_images, H, W, 3]
    #     # self.all_masks  = torch.from_numpy(masks_np>0)   # [n_images, H, W, 3]
    #     self.img_wh = [self.all_rgbs.shape[2], self.all_rgbs.shape[1]]
    #
    #     # world_mat is a projection matrix from world to image
    #     n_images = len(images_lis)
    #     world_mats_np = [self.camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(n_images)]
    #     self.scale_mats_np = [self.camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(n_images)]
    #
    #     # W,H = self.img_wh
    #     self.all_rays = []
    #     self.intrinsics, self.poses = [], []
    #     for img_idx, (scale_mat, world_mat) in enumerate(zip(self.scale_mats_np, world_mats_np)):
    #         P = world_mat @ scale_mat
    #         P = P[:3, :4]
    #         intrinsic, c2w = load_K_Rt_from_P(None, P)
    #
    #         c2w = torch.from_numpy(c2w).float()
    #         intrinsic = torch.from_numpy(intrinsic).float()
    #         intrinsic[:2] /= self.downsample
    #
    #         self.poses.append(c2w)
    #         self.intrinsics.append(intrinsic)
    #
    #         rays_o, rays_d = self.gen_rays_at(intrinsic, c2w)
    #         self.all_rays += [torch.cat([rays_o, rays_d], 1)]  # (h*w, 6)
    #
    #     self.intrinsics, self.poses = torch.stack(self.intrinsics), torch.stack(self.poses)
    #
    #     # self.all_rgbs[~self.all_masks] = 1.0
    #     if not self.is_stack:
    #         self.all_rays = torch.cat(self.all_rays, 0)  # (len(self.meta['frames])*h*w, 3)
    #         self.all_rgbs = self.all_rgbs.reshape(-1, 3)
    #     else:
    #         self.all_rays = torch.stack(self.all_rays, 0)  # (len(self.meta['frames]),h*w, 3)
    #         self.all_rgbs = self.all_rgbs.reshape(-1, *self.img_wh[::-1], 3)  # (len(self.meta['frames]),h,w,3)
    #
    #     self.sampler = SimpleSampler(np.prod(self.all_rgbs.shape[:-1]), self.batch_size)

    def get_bbox(self, scale_mats_np, object_scale_mat):
        object_bbox_min = np.array([-1.0, -1.0, -1.0, 1.0])
        object_bbox_max = np.array([ 1.0,  1.0,  1.0, 1.0])
        # Object scale mat: region of interest to **extract mesh**
        # object_scale_mat = np.load(os.path.join(scene_path, 'cameras.npz'))
        object_bbox_min = np.linalg.inv(scale_mats_np[0]) @ object_scale_mat @ object_bbox_min[:, None]
        object_bbox_max = np.linalg.inv(scale_mats_np[0]) @ object_scale_mat @ object_bbox_max[:, None]
        return [object_bbox_min[:3, 0].tolist(),object_bbox_max[:3, 0].tolist()]
        # self.near_far = [2.125, 4.525]

    def gen_rays_at(self, intrinsic, c2w, resolution_level=1):
        """
        Generate rays at world space from one camera.
        """
        l = resolution_level
        W,H = self.img_wh
        tx = torch.linspace(0, W - 1, W // l)+0.5
        ty = torch.linspace(0, H - 1, H // l)+0.5
        pixels_x, pixels_y = torch.meshgrid(tx, ty)
        p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1) # W, H, 3
        intrinsic_inv = torch.inverse(intrinsic)
        p = torch.matmul(intrinsic_inv[None, None, :3, :3], p[:, :, :, None]).squeeze()  # W, H, 3
        rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True)  # W, H, 3
        rays_v = torch.matmul(c2w[None, None, :3, :3], rays_v[:, :, :, None]).squeeze()  # W, H, 3
        rays_o = c2w[None, None, :3, 3].expand(rays_v.shape)  # W, H, 3
        return rays_o.transpose(0, 1).reshape(-1,3), rays_v.transpose(0, 1).reshape(-1,3)


    def __len__(self):
        return 1000000 #len(self.all_rays)

    def __getitem__(self, idx):
        idx = torch.randint(self.scene_num,(1,)).item()
        if self.scene_num >= 1:
            scene_name = self.scene_idxs[idx]
            img_idx = np.random.choice(len(self.all_rgb_files[scene_name]), size=6)
            self.load_data(scene_name, img_idx)

        idxs = np.random.choice(self.all_rays.shape[0], size=self.batch_size)

        return {'rays': self.all_rays[idxs], 'rgbs': self.all_rgbs[idxs], 'idx': idx}

================================================
FILE: dataLoader/dtu_objs2.py
================================================

import torch
import cv2 as cv
import numpy as np
import os
from glob import glob
from .ray_utils import *
from torch.utils.data import Dataset


# This function is borrowed from IDR: https://github.com/lioryariv/idr
def load_K_Rt_from_P(filename, P=None):
    if P is None:
        lines = open(filename).read().splitlines()
        if len(lines) == 4:
            lines = lines[1:]
        lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
        P = np.asarray(lines).astype(np.float32).squeeze()

    out = cv.decomposeProjectionMatrix(P)
    K = out[0]
    R = out[1]
    t = out[2]

    K = K / K[2, 2]
    intrinsics = np.eye(4)
    intrinsics[:3, :3] = K

    pose = np.eye(4, dtype=np.float32)
    pose[:3, :3] = R.transpose()
    pose[:3, 3] = (t[:3] / t[3])[:, 0]

    return intrinsics, pose

class DTUDataset(Dataset):
    def __init__(self, cfg, split='train', batch_size=4096, is_stack=None):
        """
        img_wh should be set to a tuple ex: (1152, 864) to enable test mode!
        """
        # self.N_vis = N_vis
        self.split = split
        self.batch_size = batch_size
        self.root_dir = cfg.datadir
        self.is_stack = is_stack if is_stack is not None else 'train'!=split
        self.downsample = cfg.get(f'downsample_{self.split}')
        self.img_wh = (int(400 / self.downsample), int(300 / self.downsample))

        self.white_bg = False
        self.camera_dict = np.load(os.path.join(self.root_dir, 'cameras.npz'))

        self.read_meta()
        self.get_bbox()

    # def define_transforms(self):
    #     self.transform = T.ToTensor()

    def get_bbox(self):
        object_bbox_min = np.array([-1.0, -1.0, -1.0, 1.0])
        object_bbox_max = np.array([ 1.0,  1.0,  1.0, 1.0])
        # Object scale mat: region of interest to **extract mesh**
        object_scale_mat = np.load(os.path.join(self.root_dir, 'cameras.npz'))['scale_mat_0']
        object_bbox_min = np.linalg.inv(self.scale_mats_np[0]) @ object_scale_mat @ object_bbox_min[:, None]
        object_bbox_max = np.linalg.inv(self.scale_mats_np[0]) @ object_scale_mat @ object_bbox_max[:, None]
        self.scene_bbox = [object_bbox_min[:3, 0].tolist(),object_bbox_max[:3, 0].tolist()]
        self.scene_bbox[0].append(0)
        self.scene_bbox[1].append(1)

    def gen_rays_at(self, intrinsic, c2w, resolution_level=1):
        """
        Generate rays at world space from one camera.
        """
        l = resolution_level
        W,H = self.img_wh
        tx = torch.linspace(0, W - 1, W // l)+0.5
        ty = torch.linspace(0, H - 1, H // l)+0.5
        pixels_x, pixels_y = torch.meshgrid(tx, ty)
        p = torch.stack([pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1) # W, H, 3
        intrinsic_inv = torch.inverse(intrinsic)
        p = torch.matmul(intrinsic_inv[None, None, :3, :3], p[:, :, :, None]).squeeze()  # W, H, 3
        rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True)  # W, H, 3
        rays_v = torch.matmul(c2w[None, None, :3, :3], rays_v[:, :, :, None]).squeeze()  # W, H, 3
        rays_o = c2w[None, None, :3, 3].expand(rays_v.shape)  # W, H, 3
        return rays_o.transpose(0, 1).reshape(-1,3), rays_v.transpose(0, 1).reshape(-1,3)

    def read_meta(self):

        images_lis = sorted(glob(os.path.join(self.root_dir, 'image/*.png')))
        images_np = np.stack([cv.resize(cv.imread(im_name),self.img_wh) for im_name in images_lis]) / 255.0
        # masks_lis = sorted(glob(os.path.join(self.root_dir, 'mask/*.png')))
        # masks_np = np.stack([cv.resize(cv.imread(im_name),self.img_wh) for im_name in masks_lis])>128

        self.all_rgbs = torch.from_numpy(images_np.astype(np.float32)[...,[2,1,0]])  # [n_images, H, W, 3]
        # self.all_masks  = torch.from_numpy(masks_np>0)   # [n_images, H, W, 3]
        self.img_wh = [self.all_rgbs.shape[2],self.all_rgbs.shape[1]]

        # world_mat is a projection matrix from world to image
        n_images = len(images_lis)
        world_mats_np = [self.camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in range(n_images)]
        self.scale_mats_np = [self.camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in range(n_images)]

        # W,H = self.img_wh
        self.all_rays = []
        self.intrinsics, self.poses = [],[]
        for img_idx, (scale_mat, world_mat) in enumerate(zip(self.scale_mats_np, world_mats_np)):
            P = world_mat @ scale_mat
            P = P[:3, :4]
            intrinsic, c2w = load_K_Rt_from_P(None, P)

            c2w = torch.from_numpy(c2w).float()
            intrinsic = torch.from_numpy(intrinsic).float()
            intrinsic[:2] /= self.downsample

            self.poses.append(c2w)
            self.intrinsics.append(intrinsic)

            rays_o, rays_d = self.gen_rays_at(intrinsic,c2w)
            self.all_rays += [torch.cat([rays_o, rays_d], 1)]  # (h*w, 6)

        self.intrinsics, self.poses = torch.stack(self.intrinsics), torch.stack(self.poses)

        # self.all_rgbs[~self.all_masks] = 1.0
        if not self.is_stack:
            self.all_rays = torch.cat(self.all_rays, 0)  # (n_images*h*w, 6)
            self.all_rgbs = self.all_rgbs.reshape(-1,3)
        else:
            self.all_rays = torch.stack(self.all_rays, 0)  # (n_images, h*w, 6)
            self.all_rgbs = self.all_rgbs.reshape(-1, *self.img_wh[::-1],3)  # (n_images, h, w, 3)

        self.sampler = SimpleSampler(np.prod(self.all_rgbs.shape[:-1]), self.batch_size)

    def __len__(self):
        return len(self.all_rays)

    def __getitem__(self, idx):
        idx_rand = self.sampler.nextids() #torch.randint(0,len(self.all_rays),(self.batch_size,))
        sample = {'rays': self.all_rays[idx_rand], 'rgbs': self.all_rgbs[idx_rand]}
        return sample

================================================
FILE: dataLoader/google_objs.py
================================================
import torch, cv2
from torch.utils.data import Dataset
import json
from tqdm import tqdm
import os
from PIL import Image
from torchvision import transforms as T
import glob
from scipy.spatial.transform import Rotation as R

from .ray_utils import *

def fps_downsample(points, n_points_to_sample):
    selected_points = np.zeros((n_points_to_sample, 3))
    selected_idxs = []
    dist = np.ones(points.shape[0]) * 100
    for i in range(n_points_to_sample):
        idx = np.argmax(dist).tolist()
        selected_points[i] = points[idx]
        print(idx)
        selected_idxs.extend([idx])
        dist_ = ((points - selected_points[i]) ** 2).sum(-1)
        dist = np.minimum(dist, dist_).tolist()

    return selected_idxs

############################# Get Spherical Path #############################

def pose_spherical_nerf(euler, radius=1.8, ep=1):
    c2ws_render = np.eye(4)
    c2ws_render[:3,:3] =  R.from_euler('xyz', euler, degrees=True).as_matrix()
    # the last column of the rotation matrix, scaled by ep*radius, gives the camera position
    c2ws_render[:3,3]  = c2ws_render[:3,:3] @ np.array([0.0,0.0,ep*radius])
    return c2ws_render

def nerf_video_path(c2ws, theta_range=10,phi_range=20,N_views=120,radius=1.3,ep=-1):
    c2ws = torch.tensor(c2ws)
    mean_position = torch.mean(c2ws[:,:3, 3],dim=0).reshape(1,3).cpu().numpy()
    rotvec = []
    for i in range(c2ws.shape[0]):
        r = R.from_matrix(c2ws[i, :3, :3])
        euler_ange = r.as_euler('xyz', degrees=True).reshape(1, 3)
        if i:
            mask = np.abs(euler_ange - rotvec[0])>180
            euler_ange[mask] += 360.0
        rotvec.append(euler_ange)
    # average the rotations by converting them to Euler angles and averaging those
    rotvec = np.mean(np.stack(rotvec), axis=0)
    render_poses = [pose_spherical_nerf(rotvec+np.array([angle,0.0,-phi_range]), radius=radius, ep=ep) for angle in np.linspace(-theta_range,theta_range,N_views//4, endpoint=False)]
    render_poses += [pose_spherical_nerf(rotvec+np.array([theta_range,0.0,angle]), radius=radius, ep=ep) for angle in np.linspace(-phi_range,phi_range,N_views//4, endpoint=False)]
    render_poses += [pose_spherical_nerf(rotvec+np.array([angle,0.0,phi_range]), radius=radius, ep=ep) for angle in np.linspace(theta_range,-theta_range,N_views//4, endpoint=False)]
    render_poses += [pose_spherical_nerf(rotvec+np.array([-theta_range,0.0,angle]), radius=radius, ep=ep) for angle in np.linspace(phi_range,-phi_range,N_views//4, endpoint=False)]

    return render_poses

def _interpolate_trajectory(c2ws, num_views: int = 300):
    """calculate interpolate path"""

    from scipy.interpolate import interp1d
    from scipy.spatial.transform import Rotation, Slerp

    key_rots = Rotation.from_matrix(c2ws[:, :3, :3])
    key_times = list(range(len(c2ws)))
    slerp = Slerp(key_times, key_rots)
    interp = interp1d(key_times, c2ws[:, :3, 3], axis=0)
    render_c2ws = []
    for i in range(num_views):
        time = float(i) / num_views * (len(c2ws) - 1)
        cam_location = interp(time)
        cam_rot = slerp(time).as_matrix()
        c2w = np.eye(4)
        c2w[:3, :3] = cam_rot
        c2w[:3, 3] = cam_location
        render_c2ws.append(c2w)
    return np.stack(render_c2ws, axis=0)

def google_objs_path(c2ws, N_views=150):
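    # render path: pick 3 well-spread camera positions via farthest-point sampling,
    # close the loop, and interpolate N_views poses along it (Slerp + linear)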
    positions = c2ws[:, :3, 3]
    selected_idxs = fps_downsample(positions, 3)
    selected_idxs.append(selected_idxs[0])
    return _interpolate_trajectory(c2ws[selected_idxs].numpy(), N_views)



class GoogleObjsDataset(Dataset):
    def __init__(self, cfg, split="train", batch_size=4096):

        # self.N_vis = N_vis
        self.cfg = cfg
        self.root_dir = cfg.datadir
        self.split = split
        self.batch_size = batch_size
        self.is_stack = False if "train" == split else True
        self.downsample = cfg.get(f"downsample_{self.split}")
        self.img_wh = (int(512 / self.downsample), int(512 / self.downsample))
        self.define_transforms()
        train_scene_idxs = sorted(cfg.train_scene_list)
        test_scene_idxs = cfg.test_scene_list
        if len(train_scene_idxs)==2:
            train_scene_idxs = list(range(train_scene_idxs[0],train_scene_idxs[1]))
        self.scene_idxs = train_scene_idxs if self.split=='train' else test_scene_idxs
        self.train_views = cfg.train_views
        self.scene_num = len(self.scene_idxs)

        if 'test' == self.split:
            self.test_index = train_scene_idxs.index(test_scene_idxs[0])


        # self.rot = torch.tensor([[0.65561799, -0.65561799, 0.37460659],
        #                          [0.73729737, 0.44876192, -0.50498052],
        #                          [0.16296514, 0.60727077, 0.77760181]])

        self.scene_bbox = [[-1.0, -1.0, -1.0, 0.0], [1.0, 1.0, 1.0, self.scene_num]]
        # self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])

        self.white_bg = True
        self.near_far = [0.4, 1.6]

        #################################

        # self.folder_path = datadir
        # self.num_source_views = args.num_source_views
        # self.rectify_inplane_rotation = args.rectify_inplane_rotation
        self.scene_path_list = sorted(glob.glob(os.path.join(self.root_dir, "*/")))

        all_rgb_files = []
        # all_depth_files = []
        all_pose_files = []
        all_intrinsics_files = []
        num_files = 250

        for i, scene_idx in enumerate(self.scene_idxs):

            scene_path = self.scene_path_list[scene_idx]
            # print(scene_idx,scene_path)

            rgb_files = [
                os.path.join(scene_path, "rgb", f)
                for f in sorted(os.listdir(os.path.join(scene_path, "rgb")))
            ]
            # depth_files = [os.path.join(scene_path, 'depth', f)
            #              for f in sorted(os.listdir(os.path.join(scene_path, 'depth')))]
            pose_files = [
                f.replace("rgb", "pose").replace("png", "txt") for f in rgb_files
            ]
            intrinsics_files = [
                f.replace("rgb", "intrinsics").replace("png", "txt") for f in rgb_files
            ]

            if (
                np.min([len(rgb_files), len(pose_files), len(intrinsics_files)])
                < num_files
            ):
                print(scene_path)
                continue

            all_rgb_files.append(rgb_files)
            # all_depth_files.append(depth_files)
            all_pose_files.append(pose_files)
            all_intrinsics_files.append(intrinsics_files)

        index = np.arange(len(all_rgb_files))
        self.all_rgb_files = np.array(all_rgb_files)[index]
        # self.all_depth_files = np.array(all_depth_files)[index]
        self.all_pose_files = np.array(all_pose_files)[index]
        self.all_intrinsics_files = np.array(all_intrinsics_files)[index]

        if self.split=='test' or self.scene_num==1:
            self.read_meta()
        else:
            self.train_idxs = range(self.train_views)
        # self.define_proj_mat()

    def read_depth(self, filename):
        depth = np.array(read_pfm(filename)[0], dtype=np.float32)  # (800, 800)
        return depth

    def read_meta(self):

        assert (
            len(self.all_rgb_files)
            == len(self.all_pose_files)
            == len(self.all_intrinsics_files)
        )

        self.all_image_paths = []
        self.all_poses = []
        self.all_rays = []
        self.all_rgbs = []
        for idx in range(len(self.all_rgb_files)):
            rgb_files = self.all_rgb_files[idx]
            pose_files = self.all_pose_files[idx]
            intrinsics_files = self.all_intrinsics_files[idx]

            intrinsics = np.loadtxt(intrinsics_files[0])
            index_3x3 = np.array([0, 1, 2, 4, 5, 6, 8, 9, 10])
            self.intrinsics = intrinsics[index_3x3]

            w, h = self.img_wh
            self.focal = self.intrinsics[0]
            self.focal *= (
                self.img_wh[0] / 512
            )  # modify focal length to match size self.img_wh

            # ray directions for all pixels, same for all images (same H, W, focal)
            self.directions = get_ray_directions(
                h, w, [self.focal, self.focal]
            )  # (h, w, 3)
            self.directions = self.directions / torch.norm(
                self.directions, dim=-1, keepdim=True
            )
            self.intrinsics = torch.tensor(
                [[self.focal, 0, w / 2], [0, self.focal, h / 2], [0, 0, 1]]
            ).float()

            self.scene_image_paths = []
            self.scene_poses = []
            self.scene_rays = []
            self.scene_rgbs = []
            # self.downsample = 1.0

            img_eval_interval = (
                1  # if self.N_vis < 0 else len(self.meta['frames']) // self.N_vis
            )
            if "train" == self.split:
                cam_xyzs = []
                # for i in range(len(pose_files)):
                for i in range(100):
                    pose = np.loadtxt(pose_files[i])
                    cam_xyzs.append([pose[3], pose[7], pose[11]])
                cam_xyzs = np.array(cam_xyzs)
                idxs = fps_downsample(cam_xyzs, min(self.train_views, len(rgb_files)))
                self.train_idxs = idxs
                print("train idxs:", idxs)
            else:
                idxs = list(range(100, 200))


            for i in tqdm(
                idxs, desc=f"Loading data {self.split} ({len(idxs)})"
            ):  # img_list:#


                pose = np.loadtxt(pose_files[i])
                pose = torch.FloatTensor(pose).view(4, 4)
                self.scene_poses += [pose]

                image_path = rgb_files[i]
                self.scene_image_paths += [image_path]
                img = Image.open(image_path)

                if self.downsample != 1.0:
                    img = img.resize(self.img_wh, Image.LANCZOS)
                img = self.transform(img)  # (3, h, w)
                img = img.view(3, -1).permute(1, 0)  # (h*w, 3) RGBA
                # img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:])  # blend A to RGB
                self.scene_rgbs += [img]

                rays_o, rays_d = get_rays(self.directions, pose)  # both (h*w, 3)
                # rays_o, rays_d = rays_o@self.rot, rays_d@self.rot
                self.scene_rays += [torch.cat([rays_o, rays_d], 1)]  # (h*w, 6)

            self.scene_poses = torch.stack(self.scene_poses)

            views = 180
            radius = {183: 1.3, 199: 1.7, 298: 1.5, 467: 1.1, 957: 1.9, 244: 1.2, 963: 1.2, 527: 1.2, 681:1.9,948:1.2}
            self.render_path = google_objs_path(self.scene_poses, N_views=views)

            name = self.scene_path_list[self.scene_idxs[0]].split('/')[-2]
            np.save(f'{self.root_dir}/{name}_render_path.npy', self.render_path)
            # self.render_path = nerf_video_path(self.scene_poses, N_views=views, theta_range=45,phi_range=90, radius=radius[self.scene_idxs[0]],ep=-1)

            if not self.is_stack:
                self.scene_rays = torch.cat(self.scene_rays, 0)  # (n_views*h*w, 6)
                self.scene_rgbs = torch.cat(self.scene_rgbs, 0)  # (n_views*h*w, 3)

            #             self.all_depth = torch.cat(self.all_depth, 0)  # (n_views*h*w, 3)
            else:
                self.scene_rays = torch.stack(self.scene_rays, 0)  # (n_views, h*w, 6)
                self.scene_rgbs = torch.stack(self.scene_rgbs, 0).reshape(-1, *self.img_wh[::-1], 3)  # (n_views, h, w, 3)
                # self.all_masks = torch.stack(self.all_masks, 0).reshape(-1, *self.img_wh[::-1])  # (n_views, h, w)

            ######################## pre-generate and save in mem ########################

            self.all_image_paths.append(self.scene_image_paths)
            self.all_poses.append(self.scene_poses)
            self.all_rays.append(self.scene_rays)
            self.all_rgbs.append(self.scene_rgbs)

        self.all_rays = torch.cat(self.all_rays)
        self.all_rgbs = torch.cat(self.all_rgbs)

    def get_rays(self, idx):
        ######################## pre-generate and save in mem ########################
        return self.all_rays[idx]


    def get_rgbs(self, idx):

        ######################## pre-generate and save in mem ########################
        return self.all_rgbs[idx]


    def define_transforms(self):
        self.transform = T.ToTensor()

    def define_proj_mat(self):
        self.proj_mat = (
            self.intrinsics.unsqueeze(0) @ torch.inverse(self.scene_poses)[:, :3]
        )

    def world2ndc(self, points, lindisp=None):
        device = points.device
        return (points - self.center.to(device)) / self.radius.to(device)

    def update_index(self):
        self.scene_idx = torch.randint(0, len(self.all_rgb_files), (1,)).item()

    def __len__(self):
        return 10000000 #len(self.all_rgb_files)

    def __getitem__(self, idx):
        #
        # self.update_index()
        # idx =  self.scene_idx #
        idx = idx % len(self.all_rgb_files)
        # idx = torch.randint(len(self.all_rgb_files), (1,)).item()

        ######################## generate rays on the fly ########################
        rgb_files = self.all_rgb_files[idx]
        pose_files = self.all_pose_files[idx]
        intrinsics_files = self.all_intrinsics_files[idx]

        intrinsics = np.loadtxt(intrinsics_files[0])
        index_3x3 = np.array([0, 1, 2, 4, 5, 6, 8, 9, 10])
        intrinsics = intrinsics[index_3x3]

        w, h = self.img_wh
        focal = intrinsics[0]
        focal *= self.img_wh[0] / 512  # modify focal length to match size self.img_wh

        # ray directions for all pixels, same for all images (same H, W, focal)
        directions = get_ray_directions(h, w, [focal, focal])  # (h, w, 3)
        directions = directions / torch.norm(directions, dim=-1, keepdim=True)
        intrinsics = torch.tensor(
            [[focal, 0, w / 2], [0, focal, h / 2], [0, 0, 1]]
        ).float()

        scene_poses = []
        scene_rays = []
        scene_image_paths = []
        scene_rgbs = []
        # downsample = 1.0

        sample_views = 5
        if self.scene_num>1:
            ids = np.random.choice(self.train_idxs, size=sample_views)

            for i in ids:
                image_path = rgb_files[i]
                scene_image_paths += [image_path]
                img = Image.open(image_path)

                idxs = torch.randint(0, w*h, (self.batch_size // sample_views,))

                if self.downsample != 1.0:
                    img = img.resize(self.img_wh, Image.LANCZOS)
                img = self.transform(img)  # (3, h, w)
                img = img.view(3, -1).permute(1, 0)  # (h*w, 3) RGBA
                # img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:])  # blend A to RGB
                scene_rgbs += [img[idxs]]

                pose = np.loadtxt(pose_files[i])
                pose = torch.FloatTensor(pose).view(4, 4)
                scene_poses += [pose]

                rays_o, rays_d = get_rays(directions, pose)  # both (h*w, 3)
                scene_rays += [torch.cat([rays_o[idxs], rays_d[idxs]], 1)]  # (h*w, 6)

            scene_poses = torch.stack(scene_poses)
            if not self.is_stack:
                scene_rays = torch.cat(scene_rays, 0)  # (sample_views*rays_per_view, 6)
                scene_rgbs = torch.cat(scene_rgbs, 0)  # (sample_views*rays_per_view, 3)
            else:
                scene_rays = torch.stack(scene_rays, 0)  # (sample_views, rays_per_view, 6)
                scene_rgbs = torch.stack(scene_rgbs, 0)  # (sample_views, rays_per_view, 3)


            return {'rays': scene_rays, 'rgbs': scene_rgbs, 'idx': idx}
        else:
            idx_rand = torch.randint(0, len(self.all_rays), (self.batch_size,))
            return {'rays': self.all_rays[idx_rand], 'rgbs': self.all_rgbs[idx_rand]}




================================================
FILE: dataLoader/image.py
================================================
import torch,imageio,cv2
from PIL import Image 
Image.MAX_IMAGE_PIXELS = 1000000000 
import numpy as np
import torch.nn.functional as F
from torch.utils.data import Dataset

_img_suffix = ['png','jpg','jpeg','bmp','tif']

def load(path):
    suffix = path.split('.')[-1]
    if suffix in _img_suffix:
        img =  np.array(Image.open(path))#.convert('L')
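        # infer the bit depth from the maximum pixel value and normalise to [0, 1]
        # (divides by 255 for 8-bit inputs, 65535 for 16-bit)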
        scale = 256.**(1+np.log2(np.max(img))//8)-1
        return img/scale
    elif 'exr' == suffix:
        return imageio.imread(path)
    elif 'npy' == suffix:
        return np.load(path)

    
def srgb_to_linear(img):
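	# inverse sRGB gamma (sRGB -> linear): linear segment below 0.04045, power-2.4 curve above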
	limit = 0.04045
	return np.where(img > limit, np.power((img + 0.055) / 1.055, 2.4), img / 12.92)

class ImageDataset(Dataset):
    def __init__(self, cfg, batchsize, split='train', continue_sampling=False, tolinear=False, HW=-1, perscent=1.0, delete_region=None,mask=None):

        datadir = cfg.datadir
        self.batchsize = batchsize
        self.continue_sampling = continue_sampling
        img = load(datadir).astype(np.float32)
        if HW>0:
            img = cv2.resize(img,[HW,HW])
            
        if tolinear:
            img = srgb_to_linear(img)
        self.img = torch.from_numpy(img)

        H,W = self.img.shape[:2]

        y, x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W), indexing='ij')
        self.coordiante = torch.stack((x,y),-1).float()+0.5

        n_channel = self.img.shape[-1]
        self.image =  self.img
        self.img, self.coordiante = self.img.reshape(H*W,-1), self.coordiante.reshape(H*W,2)
      
        # if continue_sampling:
        #     coordiante_tmp = self.coordiante.view(1,1,-1,2)/torch.tensor([W,H])*2-1.0
        #     self.img = F.grid_sample(self.img.view(1,H,W,-1).permute(0,3,1,2),coordiante_tmp, mode='bilinear', align_corners=True).reshape(self.img.shape[-1],-1).t()
            
            
        if 'train'==split:
            self.mask = torch.ones_like(y)>0
            if mask is not None:
                self.mask = mask>0
                print(torch.sum(mask)/1.0/HW/HW)
            elif delete_region is not None:
                
                if isinstance(delete_region[0], list):
                    for item in delete_region:
                        t_l_x,t_l_y,width,height = item
                        self.mask[t_l_y:t_l_y+height,t_l_x:t_l_x+width] = False
                else:
                    t_l_x,t_l_y,width,height = delete_region
                    self.mask[t_l_y:t_l_y+height,t_l_x:t_l_x+width] = False
            else:
                index = torch.randperm(len(self.img))[:int(len(self.img)*perscent)] 
                self.mask[:] = False
                self.mask.view(-1)[index] = True
            self.mask = self.mask.view(-1)
            self.image, self.coordiante = self.img[self.mask], self.coordiante[self.mask]
        else:
            self.image = self.img
            

        self.HW = [H,W]

        self.scene_bbox = [[0., 0.], [W, H]]
        cfg.aabb = self.scene_bbox
        #

    def __len__(self):
        return 10000

    def __getitem__(self, idx):
        H,W = self.HW 
        device = self.image.device
        idx = torch.randint(0,len(self.image),(self.batchsize,), device=device)
        
        if self.continue_sampling:
            coordinate = self.coordiante[idx] +  torch.rand((self.batchsize,2))-0.5
            coordinate_tmp = (coordinate.view(1,1,self.batchsize,2))/torch.tensor([W,H],device=device)*2-1.0
            rgb = F.grid_sample(self.img.view(1,H,W,-1).permute(0,3,1,2),coordinate_tmp, mode='bilinear', 
                                align_corners=False, padding_mode='border').reshape(self.img.shape[-1],-1).t()
            sample = {'rgb': rgb,
                      'xy': coordinate}
        else:
            sample = {'rgb': self.image[idx],
                      'xy': self.coordiante[idx]}

        return sample
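
# A minimal usage sketch (not part of the training scripts), assuming an OmegaConf-style
# config whose `datadir` points at an image file (the path below is hypothetical):
#
#   from omegaconf import OmegaConf
#   cfg = OmegaConf.create({'datadir': 'data/example.png'})
#   dataset = ImageDataset(cfg, batchsize=4096, split='train')
#   batch = dataset[0]  # {'rgb': (4096, C), 'xy': (4096, 2)} random pixel samples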

================================================
FILE: dataLoader/image_set.py
================================================
import torch,cv2
import numpy as np
import torch.nn.functional as F
from torch.utils.data import Dataset

def srgb_to_linear(img):
	limit = 0.04045
	return torch.where(img > limit, torch.pow((img + 0.055) / 1.055, 2.4), img / 12.92)

def load(path, HW=512):
    suffix = path.split('.')[-1]

    if 'npy' == suffix:
        img = np.load(path)
        # img = 0.3*img[...,:1] + 0.59*img[...,1:2] + 0.11*img[...,2:]
    else:
        raise ValueError(f'unsupported file suffix: {suffix}')

    if img.shape[-2] != HW:
        # resize every image to HW x HW (into a new array, since the shapes differ)
        img = np.stack([cv2.resize(im, (HW, HW)) for im in img])

    return img


class ImageSetDataset(Dataset):
    def __init__(self, cfg, batchsize, split='train', continue_sampling=False, HW=512, N=10, tolinear=True):

        datadir = cfg.datadir
        self.batchsize = batchsize
        self.continue_sampling = continue_sampling
        imgs = load(datadir,HW=HW)[:N]
        
            
        self.imgs = torch.from_numpy(imgs).float()/255
        if tolinear:
            self.imgs = srgb_to_linear(self.imgs)

        D,H,W = self.imgs.shape[:3]
        
        y, x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W), indexing='ij')
        self.coordinate = torch.stack((x,y),-1).float()+0.5

        self.imgs, self.coordinate = self.imgs.reshape(D,H*W,-1), self.coordinate.reshape(H*W,2)
        self.DHW = [D,H,W]

        self.scene_bbox = [[0., 0., 0.], [W, H, D]]
        cfg.aabb = self.scene_bbox
        # self.down_scale = 512.0/H


    def __len__(self):
        return 1000000

    def __getitem__(self, idx):
        D,H,W = self.DHW
        pix_idx = torch.randint(0,H*W,(self.batchsize,))
        img_idx = torch.randint(0,D,(self.batchsize,))
        

        if self.continue_sampling:
            coordinate = self.coordinate[pix_idx] +  torch.rand((self.batchsize,2)) - 0.5
            coordinate = torch.cat((coordinate,img_idx.unsqueeze(-1)+0.5),dim=-1)
            coordinate_tmp = (coordinate.view(1,1,1,self.batchsize,3))/torch.tensor([W,H,D])*2-1.0
            rgb = F.grid_sample(self.imgs.view(1,D,H,W,-1).permute(0,4,1,2,3),coordinate_tmp, mode='bilinear',
                                align_corners=False, padding_mode='border').reshape(self.imgs.shape[-1],-1).t()
            # coordinate[:,:2] *= self.down_scale
            sample = {'rgb': rgb,
                      'xy': coordinate}
        else:
            sample = {'rgb': self.imgs[img_idx,pix_idx],
                      'xy': torch.cat((self.coordinate[pix_idx],img_idx.unsqueeze(-1)+0.5),dim=-1)}
                      # 'xy': torch.cat((self.coordiante[pix_idx],img_idx.expand_as(pix_idx).unsqueeze(-1)),dim=-1)}



        return sample

================================================
FILE: dataLoader/llff.py
================================================
import torch
from torch.utils.data import Dataset
import glob
import numpy as np
import os
from PIL import Image
from torchvision import transforms as T

from .ray_utils import *


def normalize(v):
    """Normalize a vector."""
    return v / np.linalg.norm(v)


def average_poses(poses):
    """
    Calculate the average pose, which is then used to center all poses
    using @center_poses. Its computation is as follows:
    1. Compute the center: the average of pose centers.
    2. Compute the z axis: the normalized average z axis.
    3. Compute axis y': the average y axis.
    4. Compute x' = y' cross product z, then normalize it as the x axis.
    5. Compute the y axis: z cross product x.

    Note that at step 3, we cannot directly use y' as the y axis since it is
    not necessarily orthogonal to the z axis; we therefore compute x from y' and z
    first, and then recompute y from z and x.
    Inputs:
        poses: (N_images, 3, 4)
    Outputs:
        pose_avg: (3, 4) the average pose
    """
    # 1. Compute the center
    center = poses[..., 3].mean(0)  # (3)

    # 2. Compute the z axis
    z = normalize(poses[..., 2].mean(0))  # (3)

    # 3. Compute axis y' (no need to normalize as it's not the final output)
    y_ = poses[..., 1].mean(0)  # (3)

    # 4. Compute the x axis
    x = normalize(np.cross(z, y_))  # (3)

    # 5. Compute the y axis (as z and x are normalized, y is already of norm 1)
    y = np.cross(x, z)  # (3)

    pose_avg = np.stack([x, y, z, center], 1)  # (3, 4)

    return pose_avg


def center_poses(poses, blender2opencv):
    """
    Center the poses so that we can use NDC.
    See https://github.com/bmild/nerf/issues/34
    Inputs:
        poses: (N_images, 3, 4)
    Outputs:
        poses_centered: (N_images, 3, 4) the centered poses
        pose_avg: (3, 4) the average pose
    """
    poses = poses @ blender2opencv
    pose_avg = average_poses(poses)  # (3, 4)
    pose_avg_homo = np.eye(4)
    pose_avg_homo[:3] = pose_avg  # convert to homogeneous coordinate for faster computation
    pose_avg_homo = pose_avg_homo
    # by simply adding 0, 0, 0, 1 as the last row
    last_row = np.tile(np.array([0, 0, 0, 1]), (len(poses), 1, 1))  # (N_images, 1, 4)
    poses_homo = \
        np.concatenate([poses, last_row], 1)  # (N_images, 4, 4) homogeneous coordinate

    poses_centered = np.linalg.inv(pose_avg_homo) @ poses_homo  # (N_images, 4, 4)
    #     poses_centered = poses_centered  @ blender2opencv
    poses_centered = poses_centered[:, :3]  # (N_images, 3, 4)

    return poses_centered, pose_avg_homo
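

def _example_center_poses(poses_raw):
    """Minimal usage sketch (not called anywhere in this repo): `poses_raw` is
    assumed to be a (N_images, 3, 4) numpy array of camera-to-world matrices.
    Recentering expresses all poses relative to their average pose, which is
    what the LLFF loader relies on before switching to NDC.
    """
    blender2opencv = np.eye(4)
    poses_centered, pose_avg_homo = center_poses(poses_raw, blender2opencv)
    # poses_centered: (N_images, 3, 4) poses in the average-pose frame
    # pose_avg_homo : (4, 4) homogeneous average pose used for the centering
    return poses_centered, pose_avg_homo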


def viewmatrix(z, up, pos):
    vec2 = normalize(z)
    vec1_avg = up
    vec0 = normalize(np.cross(vec1_avg, vec2))
    vec1 = normalize(np.cross(vec2, vec0))
    m = np.eye(4)
    m[:3] = np.stack([-vec0, vec1, vec2, pos], 1)
    return m


def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, N_rots=2, N=120):
    render_poses = []
    rads = np.array(list(rads) + [1.])

    for theta in np.linspace(0., 2. * np.pi * N_rots, N + 1)[:-1]:
        c = np.dot(c2w[:3, :4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.]) * rads)
        z = normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.])))
        render_poses.append(viewmatrix(z, up, c))
    return render_poses


def get_spiral(c2ws_all, near_fars, rads_scale=1.0, N_views=120):
    # center pose
    c2w = average_poses(c2ws_all)

    # Get average pose
    up = normalize(c2ws_all[:, :3, 1].sum(0))

    # Find a reasonable "focus depth" for this dataset
    dt = 0.75
    close_depth, inf_depth = near_fars.min() * 0.9, near_fars.max() * 5.0
    focal = 1.0 / (((1.0 - dt) / close_depth + dt / inf_depth))

    # Get radii for spiral path
    zdelta = near_fars.min() * .2
    tt = c2ws_all[:, :3, 3]
    rads = np.percentile(np.abs(tt), 90, 0) * rads_scale
    render_poses = render_path_spiral(c2w, up, rads, focal, zdelta, zrate=.5, N=N_views)
    return np.stack(render_poses)


class LLFFDataset(Dataset):
    def __init__(self, cfg , split='train', hold_every=8):
        """
        spheric_poses: whether the images are taken in a spheric inward-facing manner
                       default: False (forward-facing)
        val_num: number of val images (used for multigpu training, validate same image for all gpus)
        """

        self.root_dir = cfg.datadir
        self.split = split
        self.hold_every = hold_every
        self.is_stack = False if 'train' == split else True
        self.downsample = cfg.get(f'downsample_{self.split}')
        self.define_transforms()

        self.blender2opencv = np.eye(4)#np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
        self.read_meta()
        self.white_bg = False

        #         self.near_far = [np.min(self.near_fars[:,0]),np.max(self.near_fars[:,1])]
        self.near_far = [0.0, 1.0]
        self.scene_bbox = [[-1.5, -1.67, -1.0], [1.5, 1.67, 1.0]]
        # self.scene_bbox = torch.tensor([[-1.67, -1.5, -1.0], [1.67, 1.5, 1.0]])
        # self.center = torch.mean(self.scene_bbox, dim=0).float().view(1, 1, 3)
        # self.invradius = 1.0 / (self.scene_bbox[1] - self.center).float().view(1, 1, 3)

    def read_meta(self):

        print(self.root_dir)
        poses_bounds = np.load(os.path.join(self.root_dir, 'poses_bounds.npy'))  # (N_images, 17)
        self.image_paths = sorted(glob.glob(os.path.join(self.root_dir, 'images_4/*')))
        # load full resolution image then resize
        if self.split in ['train', 'test']:
            assert len(poses_bounds) == len(self.image_paths), \
                'Mismatch between number of images and number of poses! Please rerun COLMAP!'

        poses = poses_bounds[:, :15].reshape(-1, 3, 5)  # (N_images, 3, 5)
        self.near_fars = poses_bounds[:, -2:]  # (N_images, 2)
        hwf = poses[:, :, -1]

        # Step 1: rescale focal length according to training resolution
        H, W, self.focal = poses[0, :, -1]  # original intrinsics, same for all images
        self.img_wh = np.array([int(W / self.downsample), int(H / self.downsample)])
        self.focal = [self.focal * self.img_wh[0] / W, self.focal * self.img_wh[1] / H]

        # Step 2: correct poses
        # Original poses has rotation in form "down right back", change to "right up back"
        # See https://github.com/bmild/nerf/issues/34
        poses = np.concatenate([poses[..., 1:2], -poses[..., :1], poses[..., 2:4]], -1)
        # (N_images, 3, 4) exclude H, W, focal
        self.poses, self.pose_avg = center_poses(poses, self.blender2opencv)

        # Step 3: correct scale so that the nearest depth is at a little more than 1.0
        # See https://github.com/bmild/nerf/issues/34
        near_original = self.near_fars.min()
        scale_factor = near_original * 0.75  # 0.75 is the default parameter
        # the nearest depth is at 1/0.75=1.33
        self.near_fars /= scale_factor
        self.poses[..., 3] /= scale_factor

        # build rendering path
        N_views, N_rots = 120, 2
        tt = self.poses[:, :3, 3]  # ptstocam(poses[:3,3,:].T, c2w).T
        up = normalize(self.poses[:, :3, 1].sum(0))
        rads = np.percentile(np.abs(tt), 90, 0)

        self.render_path = get_spiral(self.poses, self.near_fars, N_views=N_views)

        # distances_from_center = np.linalg.norm(self.poses[..., 3], axis=1)
        # val_idx = np.argmin(distances_from_center)  # choose val image as the closest to
        # center image

        # ray directions for all pixels, same for all images (same H, W, focal)
        W, H = self.img_wh
        self.directions = get_ray_directions_blender(H, W, self.focal)  # (H, W, 3)

        average_pose = average_poses(self.poses)
        dists = np.sum(np.square(average_pose[:3, 3] - self.poses[:, :3, 3]), -1)
        i_test = np.arange(0, self.poses.shape[0], self.hold_every)  # [np.argmin(dists)]
        img_list = i_test if self.split != 'train' else list(set(np.arange(len(self.poses))) - set(i_test))

        # load images and build rays/rgbs for the selected split
        self.all_rays = []
        self.all_rgbs = []
        for i in img_list:
            image_path = self.image_paths[i]
            c2w = torch.FloatTensor(self.poses[i])

            img = Image.open(image_path).convert('RGB')
            if self.downsample != 1.0:
                img = img.resize(self.img_wh, Image.LANCZOS)
            img = self.transform(img)  # (3, h, w)

            img = img.view(3, -1).permute(1, 0)  # (h*w, 3) RGB
            self.all_rgbs += [img]
            rays_o, rays_d = get_rays(self.directions, c2w)  # both (h*w, 3)
            rays_o, rays_d = ndc_rays_blender(H, W, self.focal[0], 1.0, rays_o, rays_d)
            # viewdir = rays_d / torch.norm(rays_d, dim=-1, keepdim=True)

            self.all_rays += [torch.cat([rays_o, rays_d], 1)]  # (h*w, 6)

        if not self.is_stack:
            self.all_rays = torch.cat(self.all_rays, 0) # (len(self.meta['frames])*h*w, 3)
            self.all_rgbs = torch.cat(self.all_rgbs, 0) # (len(self.meta['frames])*h*w,3)
        else:
            self.all_rays = torch.stack(self.all_rays, 0)   # (len(self.meta['frames]),h,w, 3)
            self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3)  # (len(self.meta['frames]),h,w,3)


    def define_transforms(self):
        self.transform = T.ToTensor()

    def __len__(self):
        return len(self.all_rgbs)

    def __getitem__(self, idx):

        sample = {'rays': self.all_rays[idx],
                  'rgbs': self.all_rgbs[idx]}

        return sample

================================================
FILE: dataLoader/nsvf.py
================================================
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
import os
from PIL import Image
from torchvision import transforms as T

from .ray_utils import *

trans_t = lambda t : torch.Tensor([
    [1,0,0,0],
    [0,1,0,0],
    [0,0,1,t],
    [0,0,0,1]]).float()

rot_phi = lambda phi : torch.Tensor([
    [1,0,0,0],
    [0,np.cos(phi),-np.sin(phi),0],
    [0,np.sin(phi), np.cos(phi),0],
    [0,0,0,1]]).float()

rot_theta = lambda th : torch.Tensor([
    [np.cos(th),0,-np.sin(th),0],
    [0,1,0,0],
    [np.sin(th),0, np.cos(th),0],
    [0,0,0,1]]).float()


def pose_spherical(theta, phi, radius):
    c2w = trans_t(radius)
    c2w = rot_phi(phi/180.*np.pi) @ c2w
    c2w = rot_theta(theta/180.*np.pi) @ c2w
    c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w
    return c2w

class NSVF(Dataset):
    """NSVF Generic Dataset."""
    def __init__(self, datadir, split='train', downsample=1.0, wh=[800,800], is_stack=False):
        self.root_dir = datadir
        self.split = split
        self.is_stack = is_stack
        self.downsample = downsample
        self.img_wh = (int(wh[0]/downsample),int(wh[1]/downsample))
        self.define_transforms()

        self.white_bg = True
        self.near_far = [0.5,6.0]
        self.scene_bbox = torch.from_numpy(np.loadtxt(f'{self.root_dir}/bbox.txt')).float()[:6].view(2,3)
        self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
        self.read_meta()
        self.define_proj_mat()
        
        self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3)
        self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)
    
    def bbox2corners(self):
        corners = self.scene_bbox.unsqueeze(0).repeat(4,1,1)
        for i in range(3):
            corners[i,[0,1],i] = corners[i,[1,0],i] 
        return corners.view(-1,3)
        
        
    def read_meta(self):
        with open(os.path.join(self.root_dir, "intrinsics.txt")) as f:
            focal = float(f.readline().split()[0])
        self.intrinsics = np.array([[focal,0,400.0],[0,focal,400.0],[0,0,1]])
        self.intrinsics[:2] *= (np.array(self.img_wh)/np.array([800,800])).reshape(2,1)

        pose_files = sorted(os.listdir(os.path.join(self.root_dir, 'pose')))
        img_files  = sorted(os.listdir(os.path.join(self.root_dir, 'rgb')))

        if self.split == 'train':
            pose_files = [x for x in pose_files if x.startswith('0_')]
            img_files = [x for x in img_files if x.startswith('0_')]
        elif self.split == 'val':
            pose_files = [x for x in pose_files if x.startswith('1_')]
            img_files = [x for x in img_files if x.startswith('1_')]
        elif self.split == 'test':
            test_pose_files = [x for x in pose_files if x.startswith('2_')]
            test_img_files = [x for x in img_files if x.startswith('2_')]
            if len(test_pose_files) == 0:
                test_pose_files = [x for x in pose_files if x.startswith('1_')]
                test_img_files = [x for x in img_files if x.startswith('1_')]
            pose_files = test_pose_files
            img_files = test_img_files

        # ray directions for all pixels, same for all images (same H, W, focal)
        self.directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsics[0,0],self.intrinsics[1,1]], center=self.intrinsics[:2,2])  # (h, w, 3)
        self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)

        frames = 200
        self.render_path = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,frames+1)[:-1]], 0)
        
        self.poses = []
        self.all_rays = []
        self.all_rgbs = []

        assert len(img_files) == len(pose_files)
        for img_fname, pose_fname in tqdm(zip(img_files, pose_files), desc=f'Loading data {self.split} ({len(img_files)})'):
            image_path = os.path.join(self.root_dir, 'rgb', img_fname)
            img = Image.open(image_path)
            if self.downsample!=1.0:
                img = img.resize(self.img_wh, Image.LANCZOS)
            img = self.transform(img)  # (4, h, w)
            img = img.view(img.shape[0], -1).permute(1, 0)  # (h*w, 4) RGBA
            if img.shape[-1]==4:
                img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:])  # blend A to RGB
            self.all_rgbs += [img]

            c2w = np.loadtxt(os.path.join(self.root_dir, 'pose', pose_fname)) #@ self.blender2opencv
            c2w = torch.FloatTensor(c2w)
            self.poses.append(c2w)  # C2W
            rays_o, rays_d = get_rays(self.directions, c2w)  # both (h*w, 3)
            self.all_rays += [torch.cat([rays_o, rays_d], 1)]  # (h*w, 6)
            
#             w2c = torch.inverse(c2w)
#

        self.poses = torch.stack(self.poses)
        if 'train' == self.split:
            if self.is_stack:
                self.all_rays = torch.stack(self.all_rays, 0).reshape(-1,*self.img_wh[::-1], 6)  # (len(self.meta['frames])*h*w, 3)
                self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3)  # (len(self.meta['frames])*h*w, 3) 
            else:
                self.all_rays = torch.cat(self.all_rays, 0)  # (len(self.meta['frames])*h*w, 3)
                self.all_rgbs = torch.cat(self.all_rgbs, 0)  # (len(self.meta['frames])*h*w, 3)
        else:
            self.all_rays = torch.stack(self.all_rays, 0)  # (len(self.meta['frames]),h*w, 3)
            self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3)  # (len(self.meta['frames]),h,w,3)

 
    def define_transforms(self):
        self.transform = T.ToTensor()
        
    def define_proj_mat(self):
        self.proj_mat = torch.from_numpy(self.intrinsics[:3,:3]).unsqueeze(0).float() @ torch.inverse(self.poses)[:,:3]

    def world2ndc(self, points):
        device = points.device
        return (points - self.center.to(device)) / self.radius.to(device)
        
    def __len__(self):
        if self.split == 'train':
            return len(self.all_rays)
        return len(self.all_rgbs)

    def __getitem__(self, idx):

        if self.split == 'train':  # use data in the buffers
            sample = {'rays': self.all_rays[idx],
                      'rgbs': self.all_rgbs[idx]}

        else:  # create data for each image separately

            img = self.all_rgbs[idx]
            rays = self.all_rays[idx]

            sample = {'rays': rays,
                      'rgbs': img}
        return sample

================================================
FILE: dataLoader/ray_utils.py
================================================
import torch, re, json
import numpy as np
from torch import searchsorted
from kornia import create_meshgrid


def load_json(path):
    with open(path, 'r') as f:
        return json.load(f)

# from utils import index_point_feature
def depth2dist(z_vals, cos_angle):
    # z_vals: [N_ray N_sample]
    device = z_vals.device
    dists = z_vals[..., 1:] - z_vals[..., :-1]
    dists = torch.cat([dists, torch.Tensor([1e10]).to(device).expand(dists[..., :1].shape)], -1)  # [N_rays, N_samples]
    dists = dists * cos_angle.unsqueeze(-1)
    return dists


def ndc2dist(ndc_pts, cos_angle):
    dists = torch.norm(ndc_pts[:, 1:] - ndc_pts[:, :-1], dim=-1)
    dists = torch.cat([dists, 1e10 * cos_angle.unsqueeze(-1)], -1)  # [N_rays, N_samples]
    return dists


def get_ray_directions(H, W, focal, center=None):
    """
    Get ray directions for all pixels in camera coordinate.
    Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
               ray-tracing-generating-camera-rays/standard-coordinate-systems
    Inputs:
        H, W, focal: image height, width and focal length
    Outputs:
        directions: (H, W, 3), the direction of the rays in camera coordinate
    """
    grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5

    i, j = grid.unbind(-1)
    # note: the +0.5 above shifts samples to pixel centers
    # see https://github.com/bmild/nerf/issues/24
    cent = center if center is not None else [W / 2, H / 2]
    directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1)  # (H, W, 3)

    return directions


def get_ray_directions_blender(H, W, focal, center=None):
    """
    Get ray directions for all pixels in camera coordinate.
    Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
               ray-tracing-generating-camera-rays/standard-coordinate-systems
    Inputs:
        H, W, focal: image height, width and focal length
    Outputs:
        directions: (H, W, 3), the direction of the rays in camera coordinate
    """
    grid = create_meshgrid(H, W, normalized_coordinates=False)[0]+0.5
    i, j = grid.unbind(-1)
    # note: the +0.5 above shifts samples to pixel centers
    # see https://github.com/bmild/nerf/issues/24
    cent = center if center is not None else [W / 2, H / 2]
    directions = torch.stack([(i - cent[0]) / focal[0], -(j - cent[1]) / focal[1], -torch.ones_like(i)],
                             -1)  # (H, W, 3)

    return directions


def get_rays(directions, c2w):
    """
    Get ray origin and normalized directions in world coordinate for all pixels in one image.
    Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
               ray-tracing-generating-camera-rays/standard-coordinate-systems
    Inputs:
        directions: (H, W, 3) precomputed ray directions in camera coordinate
        c2w: (3, 4) transformation matrix from camera coordinate to world coordinate
    Outputs:
        rays_o: (H*W, 3), the origin of the rays in world coordinate
        rays_d: (H*W, 3), the normalized direction of the rays in world coordinate
    """
    # Rotate ray directions from camera coordinate to the world coordinate
    rays_d = directions @ c2w[:3, :3].T  # (H, W, 3)
    # rays_d = rays_d / torch.norm(rays_d, dim=-1, keepdim=True)
    # The origin of all rays is the camera origin in world coordinate
    rays_o = c2w[:3, 3].expand(rays_d.shape)  # (H, W, 3)

    rays_d = rays_d.view(-1, 3)
    rays_o = rays_o.view(-1, 3)

    return rays_o, rays_d
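

def _example_generate_rays(c2w, H=800, W=800, focal=1111.0):
    """Minimal usage sketch of the per-image ray generation used by the dataset
    loaders; H, W and focal are illustrative values (not repo defaults) and
    `c2w` is assumed to be a (3, 4) or (4, 4) camera-to-world torch tensor.
    """
    directions = get_ray_directions_blender(H, W, [focal, focal])  # (H, W, 3)
    rays_o, rays_d = get_rays(directions, c2w)                     # (H*W, 3) each
    return torch.cat([rays_o, rays_d], dim=1)                      # (H*W, 6)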


def ndc_rays_blender(H, W, focal, near, rays_o, rays_d):
    # Shift ray origins to near plane
    t = -(near + rays_o[..., 2]) / rays_d[..., 2]
    rays_o = rays_o + t[..., None] * rays_d

    # Projection
    o0 = -1. / (W / (2. * focal)) * rays_o[..., 0] / rays_o[..., 2]
    o1 = -1. / (H / (2. * focal)) * rays_o[..., 1] / rays_o[..., 2]
    o2 = 1. + 2. * near / rays_o[..., 2]

    d0 = -1. / (W / (2. * focal)) * (rays_d[..., 0] / rays_d[..., 2] - rays_o[..., 0] / rays_o[..., 2])
    d1 = -1. / (H / (2. * focal)) * (rays_d[..., 1] / rays_d[..., 2] - rays_o[..., 1] / rays_o[..., 2])
    d2 = -2. * near / rays_o[..., 2]

    rays_o = torch.stack([o0, o1, o2], -1)
    rays_d = torch.stack([d0, d1, d2], -1)

    return rays_o, rays_d
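
def _example_ndc_rays(rays_o, rays_d, H=756, W=1008, focal=815.0):
    """Minimal sketch: convert forward-facing rays to NDC so the visible
    frustum maps into the [-1, 1]^3 cube; H, W and focal are illustrative
    values, and near is fixed to 1.0 as in the LLFF loader.
    """
    return ndc_rays_blender(H, W, focal, 1.0, rays_o, rays_d)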

def ndc_rays(H, W, focal, near, rays_o, rays_d):
    # Shift ray origins to near plane
    t = (near - rays_o[..., 2]) / rays_d[..., 2]
    rays_o = rays_o + t[..., None] * rays_d

    # Projection
    o0 = 1. / (W / (2. * focal)) * rays_o[..., 0] / rays_o[..., 2]
    o1 = 1. / (H / (2. * focal)) * rays_o[..., 1] / rays_o[..., 2]
    o2 = 1. - 2. * near / rays_o[..., 2]

    d0 = 1. / (W / (2. * focal)) * (rays_d[..., 0] / rays_d[..., 2] - rays_o[..., 0] / rays_o[..., 2])
    d1 = 1. / (H / (2. * focal)) * (rays_d[..., 1] / rays_d[..., 2] - rays_o[..., 1] / rays_o[..., 2])
    d2 = 2. * near / rays_o[..., 2]

    rays_o = torch.stack([o0, o1, o2], -1)
    rays_d = torch.stack([d0, d1, d2], -1)

    return rays_o, rays_d

# Hierarchical sampling (section 5.2)
def sample_pdf(bins, weights, N_samples, det=False, pytest=False):
    device = weights.device
    # Get pdf
    weights = weights + 1e-5  # prevent nans
    pdf = weights / torch.sum(weights, -1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)  # (batch, len(bins))

    # Take uniform samples
    if det:
        u = torch.linspace(0., 1., steps=N_samples, device=device)
        u = u.expand(list(cdf.shape[:-1]) + [N_samples])
    else:
        u = torch.rand(list(cdf.shape[:-1]) + [N_samples], device=device)

    # Pytest, overwrite u with numpy's fixed random numbers
    if pytest:
        np.random.seed(0)
        new_shape = list(cdf.shape[:-1]) + [N_samples]
        if det:
            u = np.linspace(0., 1., N_samples)
            u = np.broadcast_to(u, new_shape)
        else:
            u = np.random.rand(*new_shape)
        u = torch.Tensor(u)

    # Invert CDF
    u = u.contiguous()
    inds = searchsorted(cdf.detach(), u, right=True)
    below = torch.max(torch.zeros_like(inds - 1), inds - 1)
    above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
    inds_g = torch.stack([below, above], -1)  # (batch, N_samples, 2)

    matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
    cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
    bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)

    denom = (cdf_g[..., 1] - cdf_g[..., 0])
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
    t = (u - cdf_g[..., 0]) / denom
    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])

    return samples
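

def _example_sample_pdf(bins, weights, N_importance=64):
    """Minimal usage sketch of hierarchical (importance) sampling: `bins` is
    assumed to be the coarse depth bin edges of shape (N_rays, N_bins) and
    `weights` the coarse compositing weights of shape (N_rays, N_bins - 1).
    """
    z_samples = sample_pdf(bins, weights, N_importance, det=False)
    return z_samples.detach()  # (N_rays, N_importance) extra depths per ray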


def dda(rays_o, rays_d, bbox_3D):
    inv_ray_d = 1.0 / (rays_d + 1e-6)
    t_min = (bbox_3D[:1] - rays_o) * inv_ray_d  # N_rays 3
    t_max = (bbox_3D[1:] - rays_o) * inv_ray_d
    t = torch.stack((t_min, t_max))  # 2 N_rays 3
    t_min = torch.max(torch.min(t, dim=0)[0], dim=-1, keepdim=True)[0]
    t_max = torch.min(torch.max(t, dim=0)[0], dim=-1, keepdim=True)[0]
    return t_min, t_max


def ray_marcher(rays,
                N_samples=64,
                lindisp=False,
                perturb=0,
                bbox_3D=None):
    """
    sample points along the rays
    Inputs:
        rays: ()

    Returns:

    """

    # Decompose the inputs
    N_rays = rays.shape[0]
    rays_o, rays_d = rays[:, 0:3], rays[:, 3:6]  # both (N_rays, 3)
    near, far = rays[:, 6:7], rays[:, 7:8]  # both (N_rays, 1)

    if bbox_3D is not None:
        # compute near/far from the scene AABB (ray/box intersection)
        near, far = dda(rays_o, rays_d, bbox_3D)

    # Sample depth points
    z_steps = torch.linspace(0, 1, N_samples, device=rays.device)  # (N_samples)
    if not lindisp:  # use linear sampling in depth space
        z_vals = near * (1 - z_steps) + far * z_steps
    else:  # use linear sampling in disparity space
        z_vals = 1 / (1 / near * (1 - z_steps) + 1 / far * z_steps)

    z_vals = z_vals.expand(N_rays, N_samples)

    if perturb > 0:  # perturb sampling depths (z_vals)
        z_vals_mid = 0.5 * (z_vals[:, :-1] + z_vals[:, 1:])  # (N_rays, N_samples-1) interval mid points
        # get intervals between samples
        upper = torch.cat([z_vals_mid, z_vals[:, -1:]], -1)
        lower = torch.cat([z_vals[:, :1], z_vals_mid], -1)

        perturb_rand = perturb * torch.rand(z_vals.shape, device=rays.device)
        z_vals = lower + (upper - lower) * perturb_rand

    xyz_coarse_sampled = rays_o.unsqueeze(1) + \
                         rays_d.unsqueeze(1) * z_vals.unsqueeze(2)  # (N_rays, N_samples, 3)

    return xyz_coarse_sampled, rays_o, rays_d, z_vals


def read_pfm(filename):
    file = open(filename, 'rb')
    color = None
    width = None
    height = None
    scale = None
    endian = None

    header = file.readline().decode('utf-8').rstrip()
    if header == 'PF':
        color = True
    elif header == 'Pf':
        color = False
    else:
        raise Exception('Not a PFM file.')

    dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8'))
    if dim_match:
        width, height = map(int, dim_match.groups())
    else:
        raise Exception('Malformed PFM header.')

    scale = float(file.readline().rstrip())
    if scale < 0:  # little-endian
        endian = '<'
        scale = -scale
    else:
        endian = '>'  # big-endian

    data = np.fromfile(file, endian + 'f')
    shape = (height, width, 3) if color else (height, width)

    data = np.reshape(data, shape)
    data = np.flipud(data)
    file.close()
    return data, scale
    
class SimpleSampler:
    def __init__(self, total, batch):
        self.total = total
        self.batch = batch
        self.curr = total
        self.ids = None

    def nextids(self):
        self.curr += self.batch
        if self.curr + self.batch > self.total:
            self.ids = torch.LongTensor(np.random.permutation(self.total))
            self.curr = 0
        return self.ids[self.curr:self.curr + self.batch]
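
def _example_simple_sampler(all_rays, all_rgbs, batch_size=4096):
    """Minimal usage sketch: draw shuffled mini-batches over flattened ray
    buffers with SimpleSampler; `all_rays` and `all_rgbs` are assumed to share
    their first dimension (the total number of rays).
    """
    sampler = SimpleSampler(all_rays.shape[0], batch_size)
    ids = sampler.nextids()              # LongTensor with `batch_size` indices
    return all_rays[ids], all_rgbs[ids]  # one training batch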

def ndc_bbox(all_rays):
    near_min = torch.min(all_rays[...,:3].view(-1,3),dim=0)[0]
    near_max = torch.max(all_rays[..., :3].view(-1, 3), dim=0)[0]
    far_min = torch.min((all_rays[...,:3]+all_rays[...,3:6]).view(-1,3),dim=0)[0]
    far_max = torch.max((all_rays[...,:3]+all_rays[...,3:6]).view(-1, 3), dim=0)[0]
    print(f'===> ndc bbox near_min:{near_min} near_max:{near_max} far_min:{far_min} far_max:{far_max}')
    return torch.stack((torch.minimum(near_min,far_min),torch.maximum(near_max,far_max)))

def pose_from_json(meta, transpose):
    c2ws = []
    for frame in meta['frames']:
        c2ws.append(np.array(frame['transform_matrix'])@transpose)
    return np.stack(c2ws)


def rotation_matrix_from_vectors(vec1, vec2):
    """ Find the rotation matrix that aligns vec1 to vec2
    :param vec1: A 3d "source" vector
    :param vec2: A 3d "destination" vector
    :return mat: A transform matrix (3x3) which when applied to vec1, aligns it with vec2.
    """
    a, b = (vec1 / np.linalg.norm(vec1)).reshape(3), (vec2 / np.linalg.norm(vec2)).reshape(3)
    v = np.cross(a, b)
    if any(v):  # if not all zeros then
        c = np.dot(a, b)
        s = np.linalg.norm(v)
        kmat = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]])
        return np.eye(3) + kmat + kmat.dot(kmat) * ((1 - c) / (s ** 2))

    else:
        return np.eye(3)  # cross of all zeros only occurs on identical directions

def normalize(x):
    return x / np.linalg.norm(x)

def rotation_up(poses):
    up = normalize(np.linalg.lstsq(poses[:, :3, 1],np.ones((poses.shape[0],1)))[0])
    rot = rotation_matrix_from_vectors(up,np.array([0.,1.,0.]))
    return rot

def search_orientation(points):
    from scipy.spatial.transform import Rotation as R
    bbox_sizes,rot_mats,bboxs = [],[],[]
    for y_angle in np.linspace(-45,45,15):
        rotvec = np.array([0,y_angle,0])/180*np.pi
        rot = R.from_rotvec(rotvec).as_matrix()
        point_orientation = rot@points
        bbox = np.max(point_orientation,axis=1) - np.min(point_orientation,axis=1)
        bbox_sizes.append(np.prod(bbox))
        rot_mats.append(rot)
        bboxs.append(bbox)
    rot = rot_mats[np.argmin(bbox_sizes)]
    bbox = bboxs[np.argmin(bbox_sizes)]
    return rot,bbox


def load_point_txt(path):
    points = []
    with open(path, "r") as f:
        for line in f:
            if line[0] == "#":
                continue
            els = line.split(" ")
            err = float(els[7])
            if err > 0.25:
                continue
            points.append([float(els[1]), float(els[2]), float(els[3])])

    points = np.stack(points)

    # mask out outliers: keep the 95% of points closest to the centroid
    center = np.mean(points, axis=0)
    dist = np.sqrt(np.sum((points - center) ** 2, axis=1))
    radius_idx = np.argsort(dist)[:int(points.shape[0] * 0.95)]
    points = points[radius_idx]

    return points

p34_to_44 = lambda p: np.concatenate([p, np.tile(np.reshape(np.eye(4)[-1, :], [1, 1, 4]), [p.shape[0], 1, 1])], 1)

def orientation(poses, point_path=None):
    if point_path is not None:
        points = load_point_txt(point_path)
    else:
        points = poses[:,:3,-1]

    # np.savetxt('points_1.txt', points)
    # np.savetxt('poses_center_1.txt', poses[:, :3, -1])

    rot_up = rotation_up(poses)
    poses_new = rot_up[None] @ poses[:, :3]
    points = rot_up[:3, :3] @ points.T

    center = np.mean(poses_new[:,:3,-1], axis=0, keepdims=True)
    poses_new[:, :3, -1] -= center
    points -= center.T

    # scale = 1.5
    rot, bbox = search_orientation(poses_new[:, :3, -1].T)
    poses_new[:, :3, :4] = rot[None] @ poses_new[:, :3, :4]
    points = rot @ points
    aabb = np.stack((np.min(points,axis=1),np.max(points,axis=1)))
    poses_new[:, :3, -1] /= np.min(aabb[1]-aabb[0])
    points /= np.min(aabb[1] - aabb[0])
    aabb = (aabb/np.min(aabb[1]-aabb[0])*1.5).tolist()


    np.savetxt('points.txt', points.T)
    np.savetxt('poses_center.txt', poses_new[:, :3, -1])
    return p34_to_44(poses_new), aabb


def spherify_poses(poses, radus=1):
    p34_to_44 = lambda p: np.concatenate([p, np.tile(np.reshape(np.eye(4)[-1, :], [1, 1, 4]), [p.shape[0], 1, 1])], 1)

    rays_d = poses[:, :3, 2:3]
    rays_o = poses[:, :3, 3:4]

    def min_line_dist(rays_o, rays_d):
        A_i = np.eye(3) - rays_d * np.transpose(rays_d, [0, 2, 1])
        b_i = -A_i @ rays_o
        pt_mindist = np.squeeze(-np.linalg.inv((np.transpose(A_i, [0, 2, 1]) @ A_i).mean(0)) @ (b_i).mean(0))
        return pt_mindist

    pt_mindist = min_line_dist(rays_o, rays_d)

    center = pt_mindist
    up = (poses[:, :3, 3] - center).mean(0)

    vec0 = normalize(up)
    vec1 = normalize(np.cross([.1, .2, .3], vec0))
    vec2 = normalize(np.cross(vec0, vec1))
    pos = center
    c2w = np.stack([vec1, vec2, vec0, pos], 1)

    poses_reset = np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:, :3, :4])

    rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:, :3, 3]), -1)))*radus
    # print(poses_reset,center,rad)

    scale = 1. / rad
    poses_reset[:, :3, 3] *= scale
    # bds *= sc
    rad *= scale
    # print ('========>',poses_reset)

    centroid = np.mean(poses_reset[:, :3, 3], 0)
    zh = centroid[2]
    radcircle = np.sqrt(rad ** 2 - zh ** 2)
    render_poses = []

    for th in np.linspace(0., 2. * np.pi, 120):
        camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh])
        up = np.array([0, 0, -1.])

        vec2 = normalize(camorigin)
        vec0 = normalize(np.cross(vec2, up))
        vec1 = normalize(np.cross(vec2, vec0))
        pos = camorigin
        p = np.stack([vec0, vec1, vec2, pos], 1)

        render_poses.append(p)

    render_poses = np.stack(render_poses, 0)

    # render_poses = np.concatenate([render_poses, np.broadcast_to(poses[0, :3, -1:], render_poses[:, :3, -1:].shape)], -1)
    # poses_reset = np.concatenate(
    #     [poses_reset[:, :3, :4], np.broadcast_to(poses[0, :3, -1:], poses_reset[:, :3, -1:].shape)], -1)

    return poses_reset, render_poses, scale


================================================
FILE: dataLoader/sdf.py
================================================
import torch
import numpy as np
from torch.utils.data import Dataset

def N_to_reso(avg_reso, bbox):
    xyz_min, xyz_max = bbox
    dim = len(xyz_min)
    n_voxels = avg_reso**dim
    voxel_size = ((xyz_max - xyz_min).prod() / n_voxels).pow(1 / dim)
    return torch.ceil((xyz_max - xyz_min) / voxel_size).long().tolist()

def load(path, split, dtype='points'):

    if 'grid' == dtype:
        sdf = torch.from_numpy(np.load(path).astype(np.float32))
        D, H, W = sdf.shape
        z, y, x = torch.meshgrid(torch.arange(0, D), torch.arange(0, H), torch.arange(0, W), indexing='ij')
        coordiante = torch.stack((x,y,z),-1).reshape(D*H*W,3)#*2-1 # normalize to [-1,1]
        sdf = sdf.reshape(D*H*W,-1)
        DHW = [D,H,W]
    elif 'points' == dtype:
        DHW = [640] * 3
        sdf_dict = np.load(path, allow_pickle=True).item()
        sdf = torch.from_numpy(sdf_dict[f'sdfs_{split}'].astype(np.float32)).reshape(-1,1)
        coordiante = torch.from_numpy(sdf_dict[f'points_{split}'].astype(np.float32))
        aabb = [[-1,-1,-1],[1,1,1]]
        coordiante = ((coordiante + 1) / 2 * (torch.tensor(DHW[::-1]))).reshape(-1,3)
        DHW = DHW[::-1]
    return coordiante, sdf, DHW

class SDFDataset(Dataset):
    def __init__(self, cfg, split='train'):

        datadir = cfg.datadir
        self.coordiante, self.sdf, self.DHW = load(datadir, split)

        [D, H, W] = self.DHW

        self.scene_bbox = [[0., 0., 0.], [W, H, D]]
        cfg.aabb = self.scene_bbox

    def __len__(self):
        return len(self.sdf)

    def __getitem__(self, idx):
        sample = {'rgb': self.sdf[idx],
                  'xy': self.coordiante[idx]}

        return sample

================================================
FILE: dataLoader/tankstemple.py
================================================
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
import os
from PIL import Image
from torchvision import transforms as T

from .ray_utils import *


def circle(radius=3.5, h=0.0, axis='z', t0=0, r=1):
    if axis == 'z':
        return lambda t: [radius * np.cos(r * t + t0), radius * np.sin(r * t + t0), h]
    elif axis == 'y':
        return lambda t: [radius * np.cos(r * t + t0), h, radius * np.sin(r * t + t0)]
    else:
        return lambda t: [h, radius * np.cos(r * t + t0), radius * np.sin(r * t + t0)]


def cross(x, y, axis=0):
    T = torch if isinstance(x, torch.Tensor) else np
    return T.cross(x, y, axis)


def normalize(x, axis=-1, order=2):
    if isinstance(x, torch.Tensor):
        l2 = x.norm(p=order, dim=axis, keepdim=True)
        return x / (l2 + 1e-8), l2

    else:
        l2 = np.linalg.norm(x, order, axis)
        l2 = np.expand_dims(l2, axis)
        l2[l2 == 0] = 1
        return x / l2,


def cat(x, axis=1):
    if isinstance(x[0], torch.Tensor):
        return torch.cat(x, dim=axis)
    return np.concatenate(x, axis=axis)


def look_at_rotation(camera_position, at=None, up=None, inverse=False, cv=False):
    """
    This function takes a vector 'camera_position' which specifies the location
    of the camera in world coordinates and two vectors `at` and `up` which
    indicate the position of the object and the up directions of the world
    coordinate system respectively. The object is assumed to be centered at
    the origin.
    The output is a rotation matrix representing the transformation
    from world coordinates -> view coordinates.
    Input:
        camera_position: 3
        at: 1 x 3 or N x 3, (0, 0, 0) by default
        up: 1 x 3 or N x 3, (0, 0, -1) by default
    """

    if at is None:
        at = torch.zeros_like(camera_position)
    else:
        at = torch.tensor(at).type_as(camera_position)
    if up is None:
        up = torch.zeros_like(camera_position)
        up[2] = -1
    else:
        up = torch.tensor(up).type_as(camera_position)

    z_axis = normalize(at - camera_position)[0]
    x_axis = normalize(cross(up, z_axis))[0]
    y_axis = normalize(cross(z_axis, x_axis))[0]

    R = cat([x_axis[:, None], y_axis[:, None], z_axis[:, None]], axis=1)
    return R


def gen_path(pos_gen, at=(0, 0, 0), up=(0, -1, 0), frames=180):
    c2ws = []
    for t in range(frames):
        c2w = torch.eye(4)
        cam_pos = torch.tensor(pos_gen(t * (360.0 / frames) / 180 * np.pi))
        cam_rot = look_at_rotation(cam_pos, at=at, up=up, inverse=False, cv=True)
        c2w[:3, 3], c2w[:3, :3] = cam_pos, cam_rot
        c2ws.append(c2w)
    return torch.stack(c2ws)


class TanksTempleDataset(Dataset):
    """NSVF Generic Dataset."""

    def __init__(self, cfg, split='train'):
        self.root_dir = cfg.datadir
        self.split = split
        self.is_stack = False if 'train'==split else True
        self.downsample = cfg.get(f'downsample_{self.split}')
        self.img_wh = (int(1920 / self.downsample), int(1080 / self.downsample))
        self.define_transforms()

        self.white_bg = True
        self.near_far = [0.01, 6.0]
        self.scene_bbox = (torch.from_numpy(np.loadtxt(f'{self.root_dir}/bbox.txt')).float()[:6].view(2, 3) * 1.2).tolist()

        self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
        self.read_meta()
        self.define_proj_mat()

        # self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3)
        # self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)

    def bbox2corners(self):
        corners = self.scene_bbox.unsqueeze(0).repeat(4, 1, 1)
        for i in range(3):
            corners[i, [0, 1], i] = corners[i, [1, 0], i]
        return corners.view(-1, 3)

    def read_meta(self):

        self.intrinsics = np.loadtxt(os.path.join(self.root_dir, "intrinsics.txt"))
        self.intrinsics[:2] *= (np.array(self.img_wh) / np.array([1920, 1080])).reshape(2, 1)
        pose_files = sorted(os.listdir(os.path.join(self.root_dir, 'pose')))
        img_files = sorted(os.listdir(os.path.join(self.root_dir, 'rgb')))

        if self.split == 'train':
            pose_files = [x for x in pose_files if x.startswith('0_')]
            img_files = [x for x in img_files if x.startswith('0_')]
        elif self.split == 'val':
            pose_files = [x for x in pose_files if x.startswith('1_')]
            img_files = [x for x in img_files if x.startswith('1_')]
        elif self.split == 'test':
            test_pose_files = [x for x in pose_files if x.startswith('2_')]
            test_img_files = [x for x in img_files if x.startswith('2_')]
            if len(test_pose_files) == 0:
                test_pose_files = [x for x in pose_files if x.startswith('1_')]
                test_img_files = [x for x in img_files if x.startswith('1_')]
            pose_files = test_pose_files
            img_files = test_img_files

        # ray directions for all pixels, same for all images (same H, W, focal)
        self.directions = get_ray_directions(self.img_wh[1], self.img_wh[0],
                                             [self.intrinsics[0, 0], self.intrinsics[1, 1]],
                                             center=self.intrinsics[:2, 2])  # (h, w, 3)
        self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)

        self.poses = []

        ray_per_frame = self.img_wh[0]*self.img_wh[1]
        self.all_rays = torch.empty(len(pose_files),ray_per_frame,6)
        self.all_rgbs = torch.empty(len(pose_files),ray_per_frame,3)
        assert len(img_files) == len(pose_files)
        for i in tqdm(range(len(pose_files)),desc=f'Loading data {self.split} ({len(img_files)})'):

            img_fname, pose_fname = img_files[i], pose_files[i]
            image_path = os.path.join(self.root_dir, 'rgb', img_fname)
            img = Image.open(image_path)
            if self.downsample != 1.0:
                img = img.resize(self.img_wh, Image.LANCZOS)
            img = self.transform(img)  # (4, h, w)
            img = img.view(img.shape[0], -1).permute(1, 0)  # (h*w, 4) RGBA
            if img.shape[-1] == 4:
                img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:])  # blend A to RGB
            # self.all_rgbs.append(img)

            c2w = np.loadtxt(os.path.join(self.root_dir, 'pose', pose_fname))  # @ cam_trans
            c2w = torch.FloatTensor(c2w)
            self.poses.append(c2w)  # C2W
            rays_o, rays_d = get_rays(self.directions, c2w)  # both (h*w, 3)

            self.all_rays[i] = torch.cat([rays_o, rays_d], 1)  # (h*w, 6)
            self.all_rgbs[i] = img

        self.poses = torch.stack(self.poses)

        frames = 200
        scene_bbox = torch.tensor(self.scene_bbox).float()
        center = torch.mean(scene_bbox, dim=0)
        radius = torch.norm(scene_bbox[1] - center) * 1.2
        up = torch.mean(self.poses[:, :3, 1], dim=0).tolist()
        pos_gen = circle(radius=radius, h=-0.2 * up[1], axis='y')
        self.render_path = gen_path(pos_gen, up=up, frames=frames)
        self.render_path[:, :3, 3] += center

        if 'train' == self.split:
            if not self.is_stack:
            #     self.all_rays = torch.stack(self.all_rays, 0).reshape(-1, *self.img_wh[::-1], 6)  # (len(self.meta['frames])*h*w, 3)
            #     self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1, *self.img_wh[::-1], 3)  # (len(self.meta['frames])*h*w, 3)
            # else:
                self.all_rays = self.all_rays.reshape(-1,6)  # (len(self.meta['frames])*h*w, 3)
                self.all_rgbs = self.all_rgbs.reshape(-1,3)  # (len(self.meta['frames])*h*w, 3)
        else:
            # self.all_rays = torch.stack(self.all_rays, 0)  # (len(self.meta['frames]),h*w, 3)
            self.all_rgbs = self.all_rgbs.reshape(-1, *self.img_wh[::-1], 3)  # (len(self.meta['frames]),h,w,3)

    def define_transforms(self):
        self.transform = T.ToTensor()

    def define_proj_mat(self):
        self.proj_mat = torch.from_numpy(self.intrinsics[:3, :3]).unsqueeze(0).float() @ torch.inverse(self.poses)[:, :3]

    def world2ndc(self, points):
        device = points.device
        return (points - self.center.to(device)) / self.radius.to(device)

    def __len__(self):
        if self.split == 'train':
            return len(self.all_rays)
        return len(self.all_rgbs)

    def __getitem__(self, idx):

        if self.split == 'train':  # use data in the buffers
            sample = {'rays': self.all_rays[idx],
                      'rgbs': self.all_rgbs[idx]}

        else:  # create data for each image separately

            img = self.all_rgbs[idx]
            rays = self.all_rays[idx]

            sample = {'rays': rays,
                      'rgbs': img}
        return sample

================================================
FILE: dataLoader/your_own_data.py
================================================
import torch,cv2
from torch.utils.data import Dataset
import json
from tqdm import tqdm
import os
from PIL import Image
from torchvision import transforms as T


from .ray_utils import *


class YourOwnDataset(Dataset):
    def __init__(self, datadir, split='train', downsample=1.0, is_stack=False, N_vis=-1):

        self.N_vis = N_vis
        self.root_dir = datadir
        self.split = split
        self.is_stack = is_stack
        self.downsample = downsample
        self.define_transforms()

        self.scene_bbox = torch.tensor([[-1.5, -1.5, -1.5], [1.5, 1.5, 1.5]])
        self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
        self.read_meta()
        self.define_proj_mat()

        self.white_bg = True
        self.near_far = [0.1,100.0]
        
        self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3)
        self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)
        self.downsample=downsample

    def read_depth(self, filename):
        depth = np.array(read_pfm(filename)[0], dtype=np.float32)  # (800, 800)
        return depth
    
    def read_meta(self):

        with open(os.path.join(self.root_dir, f"transforms_{self.split}.json"), 'r') as f:
            self.meta = json.load(f)

        w, h = int(self.meta['w']/self.downsample), int(self.meta['h']/self.downsample)
        self.img_wh = [w,h]
        self.focal_x = 0.5 * w / np.tan(0.5 * self.meta['camera_angle_x'])  # original focal length
        self.focal_y = 0.5 * h / np.tan(0.5 * self.meta['camera_angle_y'])  # original focal length
        self.cx, self.cy = self.meta['cx'],self.meta['cy']


        # ray directions for all pixels, same for all images (same H, W, focal)
        self.directions = get_ray_directions(h, w, [self.focal_x,self.focal_y], center=[self.cx, self.cy])  # (h, w, 3)
        self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)
        self.intrinsics = torch.tensor([[self.focal_x,0,self.cx],[0,self.focal_y,self.cy],[0,0,1]]).float()

        self.image_paths = []
        self.poses = []
        self.all_rays = []
        self.all_rgbs = []
        self.all_masks = []
        self.all_depth = []


        img_eval_interval = 1 if self.N_vis < 0 else len(self.meta['frames']) // self.N_vis
        idxs = list(range(0, len(self.meta['frames']), img_eval_interval))
        for i in tqdm(idxs, desc=f'Loading data {self.split} ({len(idxs)})'):#img_list:#

            frame = self.meta['frames'][i]
            pose = np.array(frame['transform_matrix']) @ self.blender2opencv
            c2w = torch.FloatTensor(pose)
            self.poses += [c2w]

            image_path = os.path.join(self.root_dir, f"{frame['file_path']}.png")
            self.image_paths += [image_path]
            img = Image.open(image_path)
            
            if self.downsample!=1.0:
                img = img.resize(self.img_wh, Image.LANCZOS)
            img = self.transform(img)  # (4, h, w)
            img = img.view(-1, w*h).permute(1, 0)  # (h*w, 4) RGBA
            if img.shape[-1]==4:
                img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:])  # blend A to RGB
            self.all_rgbs += [img]


            rays_o, rays_d = get_rays(self.directions, c2w)  # both (h*w, 3)
            self.all_rays += [torch.cat([rays_o, rays_d], 1)]  # (h*w, 6)


        self.poses = torch.stack(self.poses)
        if not self.is_stack:
            self.all_rays = torch.cat(self.all_rays, 0)  # (len(self.meta['frames])*h*w, 3)
            self.all_rgbs = torch.cat(self.all_rgbs, 0)  # (len(self.meta['frames])*h*w, 3)

#             self.all_depth = torch.cat(self.all_depth, 0)  # (len(self.meta['frames])*h*w, 3)
        else:
            self.all_rays = torch.stack(self.all_rays, 0)  # (len(self.meta['frames]),h*w, 3)
            self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3)  # (len(self.meta['frames]),h,w,3)
            # self.all_masks = torch.stack(self.all_masks, 0).reshape(-1,*self.img_wh[::-1])  # (len(self.meta['frames]),h,w,3)


    def define_transforms(self):
        self.transform = T.ToTensor()
        
    def define_proj_mat(self):
        self.proj_mat = self.intrinsics.unsqueeze(0) @ torch.inverse(self.poses)[:,:3]

    def world2ndc(self,points,lindisp=None):
        device = points.device
        return (points - self.center.to(device)) / self.radius.to(device)
        
    def __len__(self):
        return len(self.all_rgbs)

    def __getitem__(self, idx):

        if self.split == 'train':  # use data in the buffers
            sample = {'rays': self.all_rays[idx],
                      'rgbs': self.all_rgbs[idx]}

        else:  # create data for each image separately

            img = self.all_rgbs[idx]
            rays = self.all_rays[idx]
            # mask = self.all_masks[idx]  # masks are never filled in read_meta; disabled to avoid an IndexError

            sample = {'rays': rays,
                      'rgbs': img}
        return sample


================================================
FILE: models/FactorFields.py
================================================
import torch, math
import torch.nn
import torch.nn.functional as F
import numpy as np
import time, skimage
from utils import N_to_reso, N_to_vm_reso


# import BasisCoding

def grid_mapping(positions, freq_bands, aabb, basis_mapping='sawtooth'):
    aabbSize = max(aabb[1] - aabb[0])
    scale = aabbSize[..., None] / freq_bands
    if basis_mapping == 'triangle':
        pts_local = (positions - aabb[0]).unsqueeze(-1) % scale
        pts_local_int = ((positions - aabb[0]).unsqueeze(-1) // scale) % 2
        pts_local = pts_local / (scale / 2) - 1
        pts_local = torch.where(pts_local_int == 1, -pts_local, pts_local)
    elif basis_mapping == 'sawtooth':
        pts_local = (positions - aabb[0])[..., None] % scale
        pts_local = pts_local / (scale / 2) - 1
        pts_local = pts_local.clamp(-1., 1.)
    elif basis_mapping == 'sinc':
        pts_local = torch.sin((positions - aabb[0])[..., None] / (scale / np.pi) - np.pi / 2)
    elif basis_mapping == 'trigonometric':
        pts_local = (positions - aabb[0])[..., None] / scale * 2 * np.pi
        pts_local = torch.cat((torch.sin(pts_local), torch.cos(pts_local)), dim=-1)
    elif basis_mapping == 'x':
        pts_local = (positions - aabb[0]).unsqueeze(-1) / scale
    # elif basis_mapping=='hash':
    #     pts_local = (positions - aabb[0])/max(aabbSize)

    return pts_local
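

def _example_grid_mapping():
    """Minimal sketch of the default 'sawtooth' basis mapping on a toy 1D
    problem; all values below are illustrative assumptions.
    """
    positions = torch.tensor([[0.0], [0.25], [0.75]])   # (N, 1) query points
    freq_bands = torch.tensor([1.0, 2.0, 4.0])          # per-level frequencies
    aabb = torch.tensor([[0.0], [1.0]])                 # scene bounds
    # Result has shape (N, 1, len(freq_bands)); every entry lies in [-1, 1].
    return grid_mapping(positions, freq_bands, aabb, basis_mapping='sawtooth')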


def dct_dict(n_atoms_fre, size, n_selete, dim=2):
    """
    Create a dictionary from the Discrete Cosine Transform (DCT) basis.
    :param n_atoms_fre:
        Number of 1D DCT frequencies per spatial dimension
    :param size:
        Size of the basis along each spatial dimension
    :param n_selete:
        Number of atoms to keep (rows are evenly subsampled if this is
        smaller than n_atoms_fre**dim)
    :param dim:
        Spatial dimensionality of the basis (2 or 3)
    :return:
        DCT dictionary, shape (min(n_selete, n_atoms_fre**dim), size**dim),
        with L2-normalized rows
    """
    # todo flip arguments to match random_dictionary
    p = n_atoms_fre  # int(math.ceil(math.sqrt(n_atoms)))
    dct = np.zeros((p, size))

    for k in range(p):
        basis = np.cos(np.arange(size) * k * math.pi / p)
        if k > 0:
            basis = basis - np.mean(basis)

        dct[k] = basis

    kron = np.kron(dct, dct)
    if 3 == dim:
        kron = np.kron(kron, dct)

    if n_selete < kron.shape[0]:
        idx = [x[0] for x in np.array_split(np.arange(kron.shape[0]), n_selete)]
        kron = kron[idx]

    for col in range(kron.shape[0]):
        norm = np.linalg.norm(kron[col]) or 1
        kron[col] /= norm

    kron = torch.FloatTensor(kron)
    return kron
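

def _example_dct_dict():
    """Minimal usage sketch: build a small 2D DCT basis dictionary; the sizes
    are illustrative and unrelated to the config defaults.
    """
    basis = dct_dict(n_atoms_fre=8, size=16, n_selete=32, dim=2)
    # basis: FloatTensor of shape (32, 256) = (n_selete, size**2), rows L2-normalized
    return basis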


def positional_encoding(positions, freqs):
    freq_bands = (2 ** torch.arange(freqs).float()).to(positions.device)  # (F,)
    pts = (positions[..., None] * freq_bands).reshape(
        positions.shape[:-1] + (freqs * positions.shape[-1],))  # (..., DF)
    pts = torch.cat([torch.sin(pts), torch.cos(pts)], dim=-1)
    return pts
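

def _example_positional_encoding():
    """Minimal sketch: with F frequency bands the encoding maps (..., D)
    inputs to (..., 2 * F * D) sin/cos features; values are illustrative.
    """
    x = torch.rand(4, 3)                    # e.g. 4 points in 3D
    return positional_encoding(x, freqs=6)  # shape (4, 2 * 6 * 3) = (4, 36)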


def raw2alpha(sigma, dist):
    # sigma, dist  [N_rays, N_samples]
    alpha = 1. - torch.exp(-sigma * dist)

    T = torch.cumprod(torch.cat([torch.ones_like(alpha[..., :1]), 1. - alpha + 1e-10], -1), -1)
    weights = alpha * T[..., :-1]  # [N_rays, N_samples]
    return alpha, weights, T[..., -1:]
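

def _example_raw2alpha():
    """Minimal sketch of the volume-rendering step: alpha_i = 1 - exp(-sigma_i * dist_i)
    and weight_i = alpha_i * prod_{j<i}(1 - alpha_j); shapes are illustrative.
    """
    sigma = torch.rand(2, 8)               # (N_rays, N_samples) densities
    dist = torch.full_like(sigma, 0.1)     # per-sample step lengths
    alpha, weights, bg_T = raw2alpha(sigma, dist)
    return alpha, weights, bg_T            # bg_T: (N_rays, 1) residual transmittance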


class AlphaGridMask(torch.nn.Module):
    def __init__(self, device, aabb, alpha_volume):
        super(AlphaGridMask, self).__init__()
        self.device = device

        self.aabb = aabb.to(self.device)
        self.aabbSize = self.aabb[1] - self.aabb[0]
        self.invgridSize = 1.0 / self.aabbSize * 2
        self.alpha_volume = alpha_volume.view(1, 1, *alpha_volume.shape[-3:])
        self.gridSize = torch.LongTensor([alpha_volume.shape[-1], alpha_volume.shape[-2], alpha_volume.shape[-3]]).to(
            self.device)

    def sample_alpha(self, xyz_sampled):
        xyz_sampled = self.normalize_coord(xyz_sampled)
        alpha_vals = F.grid_sample(self.alpha_volume, xyz_sampled.view(1, -1, 1, 1, 3), align_corners=True).view(-1)

        return alpha_vals

    def normalize_coord(self, xyz_sampled):
        return (xyz_sampled - self.aabb[0]) * self.invgridSize - 1


class MLPMixer(torch.nn.Module):
    def __init__(self,
                 in_dim,
                 out_dim=16,
                 num_layers=2,
                 hidden_dim=64, pe=0, with_dropout=False):
        super().__init__()

        self.with_dropout = with_dropout
        self.in_dim = in_dim + 2 * in_dim * pe
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.pe = pe

        backbone = []
        for l in range(num_layers):
            if l == 0:
                layer_in_dim = self.in_dim
            else:
                layer_in_dim = self.hidden_dim

            if l == num_layers - 1:
                layer_out_dim, bias = out_dim, False
            else:
                layer_out_dim, bias = self.hidden_dim, True

            backbone.append(torch.nn.Linear(layer_in_dim, layer_out_dim, bias=bias))

        self.backbone = torch.nn.ModuleList(backbone)
        # torch.nn.init.constant_(backbone[0].weight.data, 1.0/self.in_dim)

    def forward(self, x, is_train=False):
        # x: [B, 3]
        h = x
        if self.pe > 0:
            h = torch.cat([h, positional_encoding(h, self.pe)], dim=-1)

        if self.with_dropout and is_train:
            h = F.dropout(h, p=0.1)

        for l in range(self.num_layers):
            h = self.backbone[l](h)
            if l != self.num_layers - 1:  # l!=0 and
                h = F.relu(h, inplace=True)
                # h = torch.sin(h)
        # sigma, feat = h[...,0], h[...,1:]
        return h


class MLPRender_Fea(torch.nn.Module):
    def __init__(self, inChanel, num_layers=3, hidden_dim=64, viewpe=6, feape=2):
        super(MLPRender_Fea, self).__init__()

        self.in_mlpC = 3 + inChanel + 2 * viewpe * 3 + 2 * feape * inChanel
        self.num_layers = num_layers
        self.viewpe = viewpe
        self.feape = feape

        mlp = []
        for l in range(num_layers):
            if l == 0:
                in_dim = self.in_mlpC
            else:
                in_dim = hidden_dim

            if l == num_layers - 1:
                out_dim, bias = 3, False  # 3 rgb
            else:
                out_dim, bias = hidden_dim, True

            mlp.append(torch.nn.Linear(in_dim, out_dim, bias=bias))

        self.mlp = torch.nn.ModuleList(mlp)
        # torch.nn.init.constant_(self.mlp[-1].bias, 0)

    def forward(self, viewdirs, features):

        indata = [features, viewdirs]
        if self.feape > 0:
            indata += [positional_encoding(features, self.feape)]
        if self.viewpe > 0:
            indata += [positional_encoding(viewdirs, self.viewpe)]

        h = torch.cat(indata, dim=-1)
        for l in range(self.num_layers):
            h = self.mlp[l](h)
            if l != self.num_layers - 1:
                h = F.relu(h, inplace=True)

        rgb = torch.sigmoid(h)
        return rgb


class FactorFields(torch.nn.Module):
    def __init__(self, cfg, device):
        super(FactorFields, self).__init__()

        self.cfg = cfg
        self.device = device

        self.matMode = [[0, 1], [0, 2], [1, 2]]
        self.vecMode = [2, 1, 0]
        self.n_scene, self.scene_idx = 1, 0

        self.alphaMask = None
        self.coeff_type, self.basis_type = cfg.model.coeff_type, cfg.model.basis_type

        self.setup_params(self.cfg.dataset.aabb)
        if self.cfg.model.coeff_type != 'none':
            self.coeffs = self.init_coef()

        if self.cfg.model.basis_type != 'none':
            self.basises = self.init_basis()

        out_dim = cfg.model.out_dim
        if 'vm' in self.coeff_type:
            in_dim = sum(cfg.model.basis_dims) * 3
        elif 'x' in self.cfg.model.basis_type:
            in_dim = len(
                cfg.model.basis_dims) * 2 * self.in_dim if self.cfg.model.basis_mapping == 'trigonometric' else len(
                cfg.model.basis_dims) * self.in_dim
        else:
            in_dim = sum(cfg.model.basis_dims)
        self.linear_mat = MLPMixer(in_dim, out_dim, num_layers=cfg.model.num_layers, hidden_dim=cfg.model.hidden_dim,
                                   with_dropout=cfg.model.with_dropout).to(device)

        if 'reconstruction' in cfg.defaults.mode:
            # self.cur_volumeSize = N_to_reso(cfg.training.volume_resoInit, self.aabb)
            # self.update_renderParams(self.cur_volumeSize)

            view_pe, fea_pe = cfg.renderer.view_pe, cfg.renderer.fea_pe
            num_layers, hidden_dim = cfg.renderer.num_layers, cfg.renderer.hidden_dim
            self.renderModule = MLPRender_Fea(inChanel=out_dim - 1, num_layers=num_layers, hidden_dim=hidden_dim,
                                              viewpe=view_pe, feape=fea_pe).to(device)

            self.is_unbound = self.cfg.dataset.is_unbound
            if self.is_unbound:
                self.bg_len = 0.2
                self.inward_aabb = torch.tensor([[-1.0, -1.0, -1.0], [1.0, 1.0, 1.0]]).to(device)
                self.aabb = self.inward_aabb * (1 + self.bg_len)
            else:
                self.inward_aabb = self.aabb

            # self.freq_bands = torch.FloatTensor(cfg.model.freq_bands).to(device)
            self.cur_volumeSize = N_to_reso(cfg.training.volume_resoInit ** self.in_dim, self.aabb)
            self.update_renderParams(self.cur_volumeSize)

        print('=====> total parameters: ', self.n_parameters())

    def setup_params(self, aabb):

        self.in_dim = len(aabb[0]) - 1 if (
                    'images' == self.cfg.defaults.mode or 'reconstructions' == self.cfg.defaults.mode) else len(aabb[0])
        self.aabb = torch.FloatTensor(aabb)[:, :self.in_dim].to(self.device)

        self.basis_dims = self.cfg.model.basis_dims
        if 'reconstruction' not in self.cfg.defaults.mode:
            self.basis_reso = self.cfg.model.basis_resos if 'image' in self.cfg.defaults.mode else np.round(
                np.array(self.cfg.model.basis_resos) * (min(aabb[1][:self.in_dim]) + 1) / 1024.0).astype('int').tolist()
            self.T_basis = self.cfg.model.T_basis if self.cfg.model.T_basis>0 else sum(np.power(np.array(self.basis_reso), self.in_dim) * np.array(self.cfg.model.basis_dims))
            self.T_coeff = self.cfg.model.T_coeff if self.cfg.model.T_coeff>0 else self.cfg.model.total_params - self.T_basis
            self.T_coeff = self.T_coeff if self.T_coeff > 0 else 8 ** self.in_dim * sum(self.basis_dims)

            if 'image' == self.cfg.defaults.mode:
                self.freq_bands = max(aabb[1][:self.in_dim]) / torch.FloatTensor(self.basis_reso).to(self.device)
            else:
                self.freq_bands = torch.FloatTensor(self.cfg.model.freq_bands).to(self.device)

            self.coeff_reso = N_to_reso(self.T_coeff // sum(self.basis_dims), self.aabb[:, :self.in_dim])[::-1]  # DHW
            self.n_scene = 1  # int(aabb[1][-1]) if 'images' == self.cfg.defaults.mode else 1

            if 'sdf' == self.cfg.defaults.mode:
                self.freq_bands *= 0.5
            elif 'images' == self.cfg.defaults.mode:
                self.coeff_reso = [aabb[1][-1]] + self.coeff_reso
                self.aabb = torch.FloatTensor(aabb).to(self.device)
            if 'vec' in self.coeff_type or 'cp' in self.coeff_type or 'vm' in self.coeff_type:
                self.coeff_reso = aabb[1]
        else:
            self.coeff_reso = N_to_reso(self.cfg.model.coeff_reso ** self.in_dim, self.aabb[:, :self.in_dim])[::-1]
            self.T_coeff = sum(self.cfg.model.basis_dims) * np.prod(self.coeff_reso)
            self.T_basis = self.cfg.model.total_params - self.T_coeff
            scale = self.T_basis / sum(
                np.power(np.array(self.cfg.model.basis_resos), self.in_dim) * np.array(self.cfg.model.basis_dims))
            scale = np.power(scale, 1.0 / self.in_dim)
            self.basis_reso = self.cfg.model.basis_resos if (
                        'vec' in self.basis_type or 'cp' in self.basis_type) else np.round(
                np.array(self.cfg.model.basis_resos) * scale).astype('int').tolist()
            # self.freq_bands = self.cfg.dataset.scene_reso / torch.FloatTensor(self.basis_resos).to(self.device)
            if ('reconstructions' == self.cfg.defaults.mode or 'x' in self.basis_type
                    or 'vec' in self.basis_type or 'cp' in self.basis_type):
                self.freq_bands = torch.FloatTensor(self.cfg.model.freq_bands).to(self.device)
            else:
                self.freq_bands = torch.FloatTensor(self.cfg.model.freq_bands).to(self.device) * (
                        self.cfg.dataset.scene_reso / float(max(self.basis_reso)) / max(self.cfg.model.freq_bands))
            self.n_scene = int(aabb[1][-1]) if 'reconstructions' == self.cfg.defaults.mode else 1

        # print(self.coeff_reso,self.basis_reso,self.freq_bands)

    def init_coef(self):
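        # Allocate the coefficient field according to cfg.model.coeff_type: a dense grid per
        # scene ('grid' / 'hash'), factored vectors or matrices ('vec' / 'cp' / 'vm'), or a
        # small per-scene MLP ('mlp').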
        n_scene = self.n_scene if 'reconstructions' == self.cfg.defaults.mode or 'images' == self.cfg.defaults.mode else 1
        if 'hash' in self.coeff_type or 'grid' in self.coeff_type:
            coeffs = [
                self.cfg.model.coef_init * torch.ones((1, sum(self.basis_dims), *self.coeff_reso), device=self.device)
                for _ in range(n_scene)]
            coeffs = torch.nn.ParameterList(coeffs)
        elif 'cp' in self.coeff_type or 'vm' in self.coeff_type:
            coeffs = []
            for i in range(len(self.coeff_reso)):
                coeffs.append(self.cfg.model.coef_init * torch.ones(
                    (1, sum(self.basis_dims), max(256, self.coeff_reso[i]), n_scene), device=self.device))
            coeffs = torch.nn.ParameterList(coeffs)
        elif 'vec' in self.coeff_type:
            coeffs = self.cfg.model.coef_init * torch.ones(
                (1, sum(self.basis_dims), max(256, max(self.coeff_reso)), n_scene), device=self.device)
            coeffs = torch.nn.ParameterList([coeffs])
        elif 'mlp' in self.coeff_type:
            coeffs = torch.nn.ParameterList(
                [MLPMixer(self.in_dim, sum(self.basis_dims), num_layers=2, hidden_dim=64, pe=4).to(self.device) for _ in
                 range(n_scene)])
        return coeffs

    def init_basis(self):
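        # Allocate the multi-level basis field: tinycudann hash encodings ('hash'),
        # DCT-initialized dense grids ('grid'), VM / CP factorizations, per-level MLPs
        # ('mlp'), or nothing for pure coordinate mappings ('x').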

        if 'hash' in self.basis_type:
            import tinycudann as tcnn
            n_levels = len(self.basis_reso)
            if 'reconstruction' not in self.cfg.defaults.mode:
                base_resolution_low, base_resolution_high = 32, int(max(self.aabb[1]).item()) // 2
                per_level_scale = torch.pow(self.freq_bands[-1] / self.freq_bands[0], 1.0 / n_levels).item()
            else:
                base_resolution_low = torch.round(self.basis_reso[0] * self.freq_bands[0]).long().item()
                per_level_scale = torch.pow(self.freq_bands[0] / self.freq_bands[-1], 1.0 / n_levels).item()
                base_resolution_high = torch.round(
                    self.basis_reso[n_levels // 2] * self.freq_bands[n_levels // 2]).long().item()
            log2_hashmap_size = np.round(np.log2(self.T_basis / 3 / np.mean(self.basis_dims)))
            # per_level_scale = torch.pow((self.basis_reso[-1] * self.freq_bands[-1])/(self.basis_reso[0] * self.freq_bands[0]),1.0/n_levels).item()
            # per_level_scale = torch.pow(self.freq_bands[0] / self.freq_bands[-1], 1.0 / n_levels).item()

            basises = []
            if len(self.basis_dims) == 1 or sum(self.basis_dims) > 32:
                encoding_config_low = {
                    "otype": "HashGrid",
                    "n_levels": n_levels,
                    "n_features_per_level": sum(self.basis_dims) // n_levels,
                    "log2_hashmap_size": log2_hashmap_size,
                    "base_resolution": base_resolution_low,
                    "per_level_scale": per_level_scale  # 1.25992 #1.38191 #
                }

                basises.append(tcnn.Encoding(
                    n_input_dims=self.in_dim,
                    encoding_config=encoding_config_low))
            else:
                encoding_config_low = {
                    "otype": "HashGrid",
                    "n_levels": n_levels // 2,
                    "n_features_per_level": min(16, self.basis_dims[0]),
                    "log2_hashmap_size": log2_hashmap_size,
                    "base_resolution": base_resolution_low,
                    "per_level_scale": per_level_scale  # 1.25992 #1.38191 #
                }

                encoding_config_high = {
                    "otype": "HashGrid",
                    "n_levels": n_levels // 2,
                    "n_features_per_level": self.basis_dims[-1],
                    "log2_hashmap_size": log2_hashmap_size,
                    "base_resolution": base_resolution_high,
                    "per_level_scale": per_level_scale  # 1.25992 #1.38191 #
                }

                basises = []
                basises.append(tcnn.Encoding(
                    n_input_dims=self.in_dim,
                    encoding_config=encoding_config_low))
                if self.basis_dims[0] == 32:
                    basises.append(tcnn.Encoding(
                        n_input_dims=self.in_dim,
                        encoding_config=encoding_config_low))
                basises.append(tcnn.Encoding(
                    n_input_dims=self.in_dim,
                    encoding_config=encoding_config_high))

            return torch.nn.ParameterList(basises)
        else:
            basises, coeffs, n_params_basis = [], [], 0
            # in_dim = self.in_dim if 'images' != self.cfg.defaults.mode else self.in_dim - 1
            for i, (basis_dim, reso) in enumerate(zip(self.basis_dims, self.basis_reso)):
                # reso_cur = N_to_reso(reso, aabb)[::-1]
                if 'mlp' in self.basis_type:
                    basises.append(MLPMixer(self.in_dim, basis_dim, num_layers=2, \
                                            hidden_dim=64, pe=4).to(self.device))
                elif 'grid' in self.basis_type:
                    basises.append(torch.nn.Parameter(dct_dict(int(np.power(basis_dim, 1. / self.in_dim) + 1), reso,
                                                               n_selete=basis_dim, dim=self.in_dim).reshape(
                        [1, basis_dim] + [reso] * self.in_dim).to(self.device)))
                    # basises.append(torch.nn.Parameter(torch.ones([1, basis_dim] + [reso] * self.in_dim).to(self.device)))
                elif 'vm' in self.basis_type:
                    reso_level = N_to_vm_reso(reso ** self.in_dim, self.aabb[:, :self.in_dim])
                    for i in range(len(self.matMode)):
                        mat_id_0, mat_id_1 = self.matMode[i]
                        basises.append(torch.nn.Parameter(
                            0.1 * torch.randn((1, basis_dim, reso_level[mat_id_1], reso_level[mat_id_0]),
                                              device=self.device)))
                elif 'cp' in self.basis_type:
                    for _ in range(self.in_dim - 1):
                        basises.append(torch.nn.Parameter(
                            0.1 * torch.randn((1, basis_dim, max(reso, 128), 1), device=self.device)))
                elif 'x' in self.basis_type:
                    continue
        return torch.nn.ParameterList(basises)

    def get_coeff(self, xyz_sampled):
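        # Sample the coefficient field at the query points (coordinates are normalized
        # first) and return one coefficient vector per point.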
        N_points, dim = xyz_sampled.shape
        in_dim = self.in_dim
        pts = self.normalize_coord(xyz_sampled).view([1, -1] + [1] * (dim - 1) + [dim])

        if 'hash' == self.coeff_type:
            coeffs = self.coeffs(pts * 0.5 + 0.5).float()
        elif 'grid' in self.coeff_type:
            coeffs = F.grid_sample(self.coeffs[self.scene_idx], pts, mode=self.cfg.model.coef_mode, align_corners=False,
                                   padding_mode='border').view(-1, N_points).t()
        elif 'vec' in self.coeff_type:
            pts = pts.view(1, -1, 1, in_dim)
            idx = (self.scene_idx + 0.5) / self.n_scene * 2 - 1
            pts = torch.stack((torch.ones_like(pts[..., 0]) * idx, pts[..., 0]), dim=-1)
            coeffs = F.grid_sample(self.coeffs[0], pts, mode=self.cfg.model.coef_mode,
                                   align_corners=False, padding_mode='border').view(-1, N_points).t()
        elif 'cp' in self.coeff_type:
            pts = pts.squeeze(2)
            idx = (self.scene_idx + 0.5) / self.n_scene * 2 - 1
            pts = torch.stack((torch.ones_like(pts) * idx, pts), dim=-2)

            coeffs = F.grid_sample(self.coeffs[0], pts[..., 0], mode=self.cfg.model.coef_mode,
                                   align_corners=False, padding_mode='border').view(-1, N_points).t()
            for i in range(1, in_dim):
                coeffs = coeffs * F.grid_sample(self.coeffs[i], pts[..., i], mode=self.cfg.model.coef_mode,
                                                align_corners=False, padding_mode='border').view(-1, N_points).t()
        elif 'vm' in self.coeff_type:
            pts = pts.squeeze(2)
            idx = (self.scene_idx + 0.5) / self.n_scene * 2 - 1
            pts = torch.stack((torch.ones_like(pts) * idx, pts), dim=-2)

            coeffs = []
            for i in range(in_dim):
                coeffs.append(F.grid_sample(self.coeffs[i], pts[..., self.vecMode[i]], mode=self.cfg.model.coef_mode,
                                            align_corners=False, padding_mode='border').view(-1, N_points).t())
            coeffs = torch.cat(coeffs, dim=-1)
        elif 'mlp' in self.coeff_type:
            coeffs = self.coeffs[self.scene_idx](pts.view(N_points, in_dim))
        elif 'hash' in self.coeff_type:
            coeffs = self.coeffs[self.scene_idx]((pts.view(N_points, in_dim) + 1) / 2)
        return coeffs

    def get_basis(self, x):
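        # Evaluate the basis field at the query coordinates: hash bases are queried through
        # the tinycudann encodings directly; other basis types are mapped with the per-level
        # frequency bands (grid_mapping), sampled level by level, and concatenated along the
        # feature dimension.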
        N_points = x.shape[0]
        if 'images' == self.cfg.defaults.mode:
            x = x[..., :-1]
        if 'hash' in self.basis_type:
            x = (x - self.aabb[0]) / torch.max(self.aabb[1] - self.aabb[0])
            if len(self.basises) == 1:
                basises = self.basises[0](x).float()
            if len(self.basises) == 2:
                basises = torch.cat((self.basises[0](x), self.basises[1](x)), dim=-1).float()
            elif len(self.basises) == 3:
                basises = torch.cat((self.basises[0](x), self.basises[1](x), self.basises[2](x)), dim=-1).float()
        else:
            freq_len = len(self.freq_bands)
            xyz = grid_mapping(x, self.freq_bands, self.aabb[:, :self.in_dim], self.cfg.model.basis_mapping).view(1, *(
                        [1] * (self.in_dim - 1)), -1, self.in_dim, freq_len)
            basises = []
            for i in range(freq_len):
                if 'mlp' in self.basis_type:
                    basises.append(self.basises[i](xyz[..., i].view(-1, self.in_dim)))
                elif 'grid' in self.basis_type:
                    basises.append(
                        F.grid_sample(self.basises[i], xyz[..., i], mode=self.cfg.model.basis_mode,
                                      align_corners=True).view(-1, N_points).T)
                elif 'vm' in self.basis_type:
                    coordinate_mat = torch.stack((xyz[..., self.matMode[0], i], xyz[..., self.matMode[1], i],
                                                  xyz[..., self.matMode[2], i])).view(3, -1, 1, 2)
                    for idx_mat in range(self.in_dim):
                        basises.append(F.grid_sample(self.basises[i * self.in_dim + idx_mat], coordinate_mat[[idx_mat]],
                                                     align_corners=True).view(-1, x.shape[0]).t())
                elif 'cp' in self.basis_type:
                    for idx_axis in range(self.in_dim - 1):
                        coordinate_vec = torch.stack(
                            (torch.zeros_like(xyz[..., idx_axis + 1, i]), xyz[..., idx_axis + 1, i]), dim=-1).squeeze(2)
                        if 0 == idx_axis:
                            basises_level = F.grid_sample(self.basises[i * (self.in_dim - 1) + idx_axis],
                                                          coordinate_vec,
                                                          align_corners=True).view(-1, x.shape[0]).t()
                        else:
                            basises_level = basises_level * F.grid_sample(
                                self.basises[i * (self.in_dim - 1) + idx_axis], coordinate_vec,
                                align_corners=True).view(-1, x.shape[0]).t()
                    basises.append(basises_level)
                elif 'x' in self.basis_type:
                    basises.append(xyz[..., i].view(x.shape[0], -1))
            if isinstance(basises, list):
                basises = torch.cat(basises, dim=-1)
            if 'vm' in self.basis_type:  # switch order
                basises = basises.view(x.shape[0], freq_len, -1).permute(0, 2, 1).reshape(x.shape[0], -1)
        return basises

    @torch.no_grad()
    def normalize_basis(self):
        for basis in self.basises:
            basis.data = basis.data / torch.norm(basis.data, dim=(2, 3), keepdim=True)

    def get_coding(self, x):
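        # Return the factorized feature: the element-wise product of basis and coefficient
        # features when both fields are enabled, otherwise whichever of the two is active
        # (the coefficients are returned separately as well, e.g. for regularization).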
        if self.cfg.model.coeff_type != 'none' and self.cfg.model.basis_type != 'none':
            coeff = self.get_coeff(x)
            basises = self.get_basis(x)
            return basises * coeff, coeff
        elif self.cfg.model.coeff_type != 'none':
            coeff = self.get_coeff(x)
            return coeff, coeff
        elif self.cfg.model.basis_type != 'none':
            basises = self.get_basis(x)
            return basises, basises

    def n_parameters(self):
        total = sum(p.numel() for p in self.parameters())
        if 'fix' in self.cfg.model.basis_type:
            total -= self.T_basis
        return total

    def get_optparam_groups(self, lr_small=0.001, lr_large=0.02):
        grad_vars = []
        if self.cfg.training.linear_mat:
            grad_vars += [{'params': self.linear_mat.parameters(), 'lr': lr_small}]

        if 'none' != self.coeff_type and self.cfg.training.coeff:
            grad_vars += [{'params': self.coeffs.parameters(), 'lr': lr_large}]

        if 'fix' not in self.cfg.model.basis_type and 'none' != self.cfg.model.basis_type and self.cfg.training.basis:
            grad_vars += [{'params': self.basises.parameters(), 'lr': lr_large}]

        if 'reconstruction' in self.cfg.defaults.mode and self.cfg.training.renderModule:
            grad_vars += [{'params': self.renderModule.parameters(), 'lr': lr_small}]
        return grad_vars

    def set_optimizable(self, items, status):
        # Toggle requires_grad for the named components: 'basis', 'coeff', 'proj', 'renderer'.
        for item in items:
            if item == 'basis' and self.cfg.model.basis_type != 'none':
                for basis in self.basises:
                    basis.requires_grad_(status)
            elif item == 'coeff' and self.cfg.model.coeff_type != 'none':
                for coeff in self.coeffs:
                    coeff.requires_grad_(status)
            elif item == 'proj':
                self.linear_mat.requires_grad_(status)
            elif item == 'renderer':
                self.renderModule.requires_grad_(status)

    def TV_loss(self, reg):
        total = 0
        for idx in range(len(self.basises)):
            total = total + reg(self.basises[idx]) * 1e-2
        return total

    def sample_point_ndc(self, rays_o, rays_d, is_train=True, N_samples=-1):
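        # Uniformly sample depths between the configured near / far planes (used for NDC
        # rays), adding per-sample jitter during training.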
        N_samples = N_samples if N_samples > 0 else self.nSamples
        near, far = self.cfg.dataset.near_far
        interpx = torch.linspace(near, far, N_samples).unsqueeze(0).to(rays_o)
        if is_train:
            interpx += torch.rand_like(interpx).to(rays_o) * ((far - near) / N_samples)

        rays_pts = rays_o[..., None, :] + rays_d[..., None, :] * interpx[..., None]
        mask_outbbox = ((self.aabb[0, :self.in_dim] > rays_pts) | (rays_pts > self.aabb[1, :self.in_dim])).any(dim=-1)
        return rays_pts, interpx, ~mask_outbbox

    def sample_point(self, rays_o, rays_d, is_train=True, N_samples=-1):
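        # Intersect each ray with the AABB and march from the entry point with a fixed step
        # size; returns sample positions, depths, and an in-bounds mask.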
        N_samples = N_samples if N_samples > 0 else self.nSamples
        vec = torch.where(rays_d == 0, torch.full_like(rays_d, 1e-6), rays_d)
        rate_a = (self.aabb[1, :self.in_dim] - rays_o) / vec
        rate_b = (self.aabb[0, :self.in_dim] - rays_o) / vec
        t_min = torch.minimum(rate_a, rate_b).amax(-1).clamp(min=0.05, max=1e3)
        rng = torch.arange(N_samples)[None].float()
        if is_train:
            rng = rng.repeat(rays_d.shape[-2], 1)
            rng += torch.rand_like(rng[:, [0]])
        step = self.stepSize * rng.to(rays_o.device)
        interpx = (t_min[..., None] + step)

        rays_pts = rays_o[..., None, :] + rays_d[..., None, :] * interpx[..., None]
        mask_outbbox = ((self.aabb[0, :self.in_dim] > rays_pts) | (rays_pts > self.aabb[1, :self.in_dim])).any(dim=-1)

        return rays_pts, interpx, ~mask_outbbox

    def sample_point_unbound(self, rays_o, rays_d, is_train=True, N_samples=-1):
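        # Sampling for unbounded scenes: linearly spaced depths over [0, 2] for the
        # foreground plus inversely spaced depths out to 32 for the background; sample
        # positions outside the unit cube are contracted into the (1 + bg_len) shell.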
        N_samples = N_samples if N_samples > 0 else self.nSamples

        N_inner, N_outer = 3 * N_samples // 4, N_samples // 4
        b_inner = torch.linspace(0, 2, N_inner + 1).to(self.device)
        b_outer = 2 / torch.linspace(1, 1 / 16, N_outer + 1).to(self.device)

        if is_train:
            rng = torch.rand((N_inner + N_outer), device=self.device)
            interpx = torch.cat([
                b_inner[1:] * rng[:N_inner] + b_inner[:-1] * (1 - rng[:N_inner]),
                b_outer[1:] * rng[N_inner:] + b_outer[:-1] * (1 - rng[N_inner:]),
            ])[None]
        else:
            interpx = torch.cat([
                (b_inner[1:] + b_inner[:-1]) * 0.5,
                (b_outer[1:] + b_outer[:-1]) * 0.5,
            ])[None]

        rays_pts = rays_o[:, None, :] + rays_d[:, None, :] * interpx[..., None]

        norm = rays_pts.abs().amax(dim=-1, keepdim=True)
        inner_mask = (norm <= 1)
        rays_pts = torch.where(
            inner_mask,
            rays_pts,
            rays_pts / norm * ((1 + self.bg_len) - self.bg_len / norm)
        )

        return rays_pts, interpx, inner_mask.squeeze(-1)

    def normalize_coord(self, xyz_sampled):
        invaabbSize = 2.0 / (self.aabb[1] - self.aabb[0])
        return (xyz_sampled - self.aabb[0]) * invaabbSize - 1

    def basis2density(self, density_features):
        if self.cfg.renderer.fea2denseAct == "softplus":
            return F.softplus(density_features + self.cfg.renderer.density_shift)
        elif self.cfg.renderer.fea2denseAct == "relu":
            return F.relu(density_features + self.cfg.renderer.density_shift)

    @torch.no_grad()
    def cal_mean_coef(self, state_dict):
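        # Average the per-scene coefficient entries of a state_dict into a single shared
        # set, so that across-scene checkpoints store only one coefficient field.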
        if 'grid' in self.coeff_type or 'mlp' in self.coeff_type:
            key_list = []
            for item in state_dict.keys():
                if 'coeffs.0' in item:
                    key_list.append(item)

            for key in key_list:
                average = torch.zeros_like(state_dict[key])
                for i in range(self.n_scene):
                    item = key.replace('0', f'{i}', 1)
                    average += state_dict[item]
                    state_dict.pop(item, None)
                average /= self.n_scene
                state_dict[key] = average
        elif 'vec' in self.coeff_type:
            state_dict['coeffs.0'] = torch.mean(state_dict['coeffs.0'], dim=-1, keepdim=True)
        elif 'cp' in self.coeff_type or 'vm' in self.coeff_type:
            for i in range(3):
                state_dict[f'coeffs.{i}'] = torch.mean(state_dict[f'coeffs.{i}'], dim=-1, keepdim=True)

        return state_dict

    def save(self, path):
        ckpt = {'state_dict': self.state_dict(), 'cfg': self.cfg}
        if self.alphaMask is not None:
            alpha_volume = self.alphaMask.alpha_volume.bool().cpu().numpy()
            ckpt.update({'alphaMask.shape': alpha_volume.shape})
            ckpt.update({'alphaMask.mask': np.packbits(alpha_volume.reshape(-1))})
            ckpt.update({'alphaMask.aabb': self.alphaMask.aabb.cpu()})

        # average the coeff for saving if batch training
        if 'reconstruction' in self.cfg.defaults.mode:
            ckpt['state_dict'] = self.cal_mean_coef(ckpt['state_dict'])
        torch.save(ckpt, path)

    def load(self, ckpt):
        if 'alphaMask.aabb' in ckpt.keys():
            length = np.prod(ckpt['alphaMask.shape'])
            alpha_volume = torch.from_numpy(
                np.unpackbits(ckpt['alphaMask.mask'])[:length].reshape(ckpt['alphaMask.shape']))
            self.alphaMask = AlphaGridMask(self.device, ckpt['alphaMask.aabb'].to(self.device),
                                           alpha_volume.float().to(self.device))
        self.load_state_dict(ckpt['state_dict'])
        volumeSize = N_to_reso(self.cfg.training.volume_resoFinal ** self.in_dim, self.aabb)
        self.update_renderParams(volumeSize)

    def update_renderParams(self, gridSize):
        self.aabbSize = self.aabb[1] - self.aabb[0]
        self.gridSize = torch.LongTensor(gridSize).to(self.device)
        units = self.aabbSize / (self.gridSize - 1)
        self.stepSize = torch.mean(units) * self.cfg.renderer.step_ratio
        aabbDiag = torch.sqrt(torch.sum(torch.square(self.aabbSize)))
        self.nSamples = int((aabbDiag / self.stepSize).item()) + 1

    @torch.no_grad()
    def upsample_volume_grid(self, res_target):
        self.update_renderParams(res_target)

        if self.cfg.dataset.dataset_name == 'google_objs' and self.n_scene == 1 and self.cfg.model.coeff_type == 'grid':
            coeffs = [
                F.interpolate(self.coeffs[0].data, size=None, scale_factor=1.3, align_corners=True, mode='trilinear')]
            self.coeffs = torch.nn.ParameterList(coeffs)

    def compute_alpha(self, xyz_locs, length=1):
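        # Evaluate the opacity of individual points (used when building the alpha mask):
        # query the feature field wherever the current mask is non-zero and convert the
        # predicted density to alpha.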

        if self.alphaMask is not None:
            alphas = self.alphaMask.sample_alpha(xyz_locs)
            alpha_mask = alphas > 0
        else:
            alpha_mask = torch.ones_like(xyz_locs[:, 0], dtype=bool)

        sigma = torch.zeros(xyz_locs.shape[:-1], device=xyz_locs.device)

        if alpha_mask.any():
            feats, _ = self.get_coding(xyz_locs[alpha_mask])
            validsigma = self.linear_mat(feats, is_train=False)[..., 0]
            sigma[alpha_mask] = self.basis2density(validsigma)

        alpha = 1 - torch.exp(-sigma * length).view(xyz_locs.shape[:-1])

        return alpha

    @torch.no_grad()
    def getDenseAlpha(self, gridSize=None, times=16):
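        # Evaluate alpha on a dense grid spanning the inward AABB, averaging `times`
        # jittered evaluations per cell.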

        gridSize = self.gridSize.tolist() if gridSize is None else gridSize

        aabbSize = self.inward_aabb[1] - self.inward_aabb[0]
        units = aabbSize / (torch.LongTensor(gridSize).to(self.device) - 1)
        units_half = 1.0 / (torch.LongTensor(gridSize) - 1) * 0.5
        stepSize = torch.mean(units)

        samples = torch.stack(torch.meshgrid(
            [torch.linspace(units_half[0], 1 - units_half[0], gridSize[0]),
             torch.linspace(units_half[1], 1 - units_half[1], gridSize[1]),
             torch.linspace(units_half[2], 1 - units_half[2], gridSize[2])], indexing='ij'
        ), -1).to(self.device)
        dense_xyz = self.inward_aabb[0] * (1 - samples) + self.inward_aabb[1] * samples

        dense_xyz = dense_xyz.transpose(0, 2).contiguous()
        alpha = torch.zeros_like(dense_xyz[..., 0])
        for _ in range(times):
            for i in range(gridSize[2]):
                shiftment = (torch.rand(dense_xyz[i].shape) * 2 - 1).to(self.device) * (
                            units / 2 * 1.2) if times > 1 else 0.0
                alpha[i] += self.compute_alpha((dense_xyz[i] + shiftment).view(-1, 3),
                                               stepSize * self.cfg.renderer.distance_scale).view(
                    (gridSize[1], gridSize[0]))
        return alpha / times, dense_xyz

    @torch.no_grad()
    def updateAlphaMask(self, gridSize=(200, 200, 200), is_update_alphaMask=False):
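        # Build a binary occupancy grid from the dense alpha values (max-pooled, thresholded
        # and cleaned of small floaters), optionally store it as the alpha mask, and return a
        # tightened bounding box around the occupied region.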

        alpha, dense_xyz = self.getDenseAlpha(gridSize)
        total_voxels = gridSize[0] * gridSize[1] * gridSize[2]

        ks = 3
        alpha = alpha.clamp(0, 1)[None, None]
        alpha = F.max_pool3d(alpha, kernel_size=ks, padding=ks // 2, stride=1).view(gridSize[::-1])

        # filter floaters
        min_size = np.mean(alpha.shape[-3:]).item()
        alphaMask_thres = self.cfg.renderer.alphaMask_thres if is_update_alphaMask else 0.08
        if self.is_unbound:
            alphaMask_thres = 0.04
            alpha = (alpha >= alphaMask_thres).float()
        else:
            alpha = skimage.morphology.remove_small_objects(alpha.cpu().numpy() >= alphaMask_thres, min_size=min_size,
                                                            connectivity=1)
            alpha = torch.FloatTensor(alpha).to(self.device)

        if is_update_alphaMask:
            self.alphaMask = AlphaGridMask(self.device, self.inward_aabb, alpha)

        valid_xyz = dense_xyz[alpha > 0.5]

        xyz_min = valid_xyz.amin(0)
        xyz_max = valid_xyz.amax(0)
        if not self.is_unbound:
            pad = (xyz_max - xyz_min) / 20
            xyz_min -= pad
            xyz_max += pad

        new_aabb = torch.stack((xyz_min, xyz_max))

        total = torch.sum(alpha)
        return new_aabb

    @torch.no_grad()
    def shrink(self, new_aabb):
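        # Re-initialize the coefficient and basis fields for a tightened bounding box.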

        self.setup_params(new_aabb.tolist())
        if self.cfg.model.coeff_type != 'none':
            del self.coeffs
            self.coeffs = self.init_coef()

        if self.cfg.model.basis_type != 'none':
            del self.basises
            self.basises = self.init_basis()

        self.aabb = self.inward_aabb = new_aabb
        self.cfg.dataset.aabb = self.aabb.tolist()
        self.update_renderParams(self.gridSize.tolist())

    @torch.no_grad()
    def filtering_rays(self, all_rays, all_rgbs, N_samples=256, chunk=10240 * 5, bbox_only=False):
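        # Discard training rays that never hit the AABB (or the alpha mask), compacting
        # all_rays / all_rgbs in place and returning the valid prefix.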
        tt = time.time()
        N = torch.tensor(all_rays.shape[:-1]).prod()

        mask_filtered = []
        length_current = 0
        idx_chunks = torch.split(torch.arange(N), chunk)
        for idx_chunk in idx_chunks:
            rays_chunk = all_rays[idx_chunk].to(self.device)

            rays_o, rays_d = rays_chunk[..., :3], rays_chunk[..., 3:6]
            if bbox_only:
                vec = torch.where(rays_d == 0, torch.full_like(rays_d, 1e-6), rays_d)
                rate_a = (self.aabb[1] - rays_o) / vec
                rate_b = (self.aabb[0] - rays_o) / vec
                t_min = torch.minimum(rate_a, rate_b).amax(-1)  # .clamp(min=near, max=far)
                t_max = torch.maximum(rate_a, rate_b).amin(-1)  # .clamp(min=near, max=far)
                mask_inbbox = t_max > t_min

            else:
                xyz_sampled, _, _ = self.sample_point(rays_o, rays_d, N_samples=N_samples, is_train=False)
                mask_inbbox = (self.alphaMask.sample_alpha(xyz_sampled).view(xyz_sampled.shape[:-1]) > 0).any(-1)

            # mask_filtered.append(mask_inbbox.cpu())
            length = torch.sum(mask_inbbox)
            all_rays[length_current:length_current + length], all_rgbs[length_current:length_current + length] = \
            rays_chunk[mask_inbbox].cpu(), all_rgbs[idx_chunk][mask_inbbox.cpu()]
            length_current += length

        return all_rays[:length_current], all_rgbs[:length_current]

    def forward(self, rays_chunk, white_bg=True, is_train=False, ndc_ray=False, N_samples=-1):
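        # Volume rendering for a chunk of rays: sample points along each ray, evaluate the
        # factorized feature field, convert features to density and color, and
        # alpha-composite along the ray.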

        # sample points
        viewdirs = rays_chunk[:, 3:6]
        if self.is_unbound:
            xyz_sampled, z_vals, inner_mask = self.sample_point_unbound(rays_chunk[:, :3], viewdirs, is_train=is_train,
                                                                        N_samples=N_samples)
            dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], z_vals[:, -1:] - z_vals[:, -2:-1]), dim=-1)
        elif ndc_ray:
            xyz_sampled, z_vals, inner_mask = self.sample_point_ndc(rays_chunk[:, :3], viewdirs, is_train=is_train,
                                                                    N_samples=N_samples)
            dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], torch.zeros_like(z_vals[:, :1])), dim=-1)
            rays_norm = torch.norm(viewdirs, dim=-1, keepdim=True)
            dists = dists * rays_norm
            viewdirs = viewdirs / rays_norm
        else:
            xyz_sampled, z_vals, inner_mask = self.sample_point(rays_chunk[:, :3], viewdirs, is_train=is_train,
                                                                N_samples=N_samples)
            dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], torch.zeros_like(z_vals[:, :1])), dim=-1)

        viewdirs = viewdirs.view(-1, 1, 3).expand(xyz_sampled.shape)
        ray_valid = torch.ones_like(xyz_sampled[..., 0]).bool() if self.is_unbound else inner_mask
        if self.alphaMask is not None:
            alpha_inner_valid = self.alphaMask.sample_alpha(xyz_sampled[inner_mask]) > 0.5
            ray_valid[inner_mask.clone()] = alpha_inner_valid

        sigma = torch.zeros(xyz_sampled.shape[:-1], device=xyz_sampled.device)
        rgb = torch.zeros((*xyz_sampled.shape[:2], 3), device=xyz_sampled.device)

        coeffs = torch.zeros((1, sum(self.cfg.model.basis_dims)), device=xyz_sampled.device)
        if ray_valid.any():
            feats, coeffs = self.get_coding(xyz_sampled[ray_valid])
            feat = self.linear_mat(feats, is_train=is_train)
            sigma[ray_valid] = self.basis2density(feat[..., 0])

        alpha, weight, bg_weight = raw2alpha(sigma, dists * self.cfg.renderer.distance_scale)
SYMBOL INDEX (257 symbols across 25 files)

FILE: 2D_regression.py
  function PSNR (line 17) | def PSNR(a, b):
  function rgb_ssim (line 26) | def rgb_ssim(img0, img1, max_val,
  function eval_img (line 76) | def eval_img(aabb, reso, shiftment=[0.5, 0.5], chunk=10240):
  function linear_to_srgb (line 95) | def linear_to_srgb(img):
  function write_image_imageio (line 100) | def write_image_imageio(img_file, img, colormap=None, quality=100):

FILE: dataLoader/blender.py
  class BlenderDataset (line 12) | class BlenderDataset(Dataset):
    method __init__ (line 13) | def __init__(self, cfg, split='train', batch_size=4096, is_stack=None):
    method read_depth (line 40) | def read_depth(self, filename):
    method read_meta (line 44) | def read_meta(self):
    method define_transforms (line 105) | def define_transforms(self):
    method define_proj_mat (line 108) | def define_proj_mat(self):
    method __len__ (line 115) | def __len__(self):
    method __getitem__ (line 118) | def __getitem__(self, idx):

FILE: dataLoader/blender_set.py
  class BlenderDatasetSet (line 12) | class BlenderDatasetSet(Dataset):
    method __init__ (line 13) | def __init__(self, cfg, split='train'):
    method read_depth (line 39) | def read_depth(self, filename):
    method read_meta (line 43) | def read_meta(self):
    method define_transforms (line 103) | def define_transforms(self):
    method define_proj_mat (line 106) | def define_proj_mat(self):
    method __len__ (line 113) | def __len__(self):
    method __getitem__ (line 116) | def __getitem__(self, idx):

FILE: dataLoader/colmap.py
  class ColmapDataset (line 12) | class ColmapDataset(Dataset):
    method __init__ (line 13) | def __init__(self, cfg, split='train'):
    method read_meta (line 31) | def read_meta(self):
    method define_transforms (line 107) | def define_transforms(self):
    method __len__ (line 111) | def __len__(self):
    method __getitem__ (line 114) | def __getitem__(self, idx):

FILE: dataLoader/colmap2nerf.py
  function parse_args (line 23) | def parse_args():
  function do_system (line 44) | def do_system(arg):
  function run_ffmpeg (line 51) | def run_ffmpeg(args):
  function run_colmap (line 73) | def run_colmap(args):
  function variance_of_laplacian (line 107) | def variance_of_laplacian(image):
  function sharpness (line 110) | def sharpness(imagePath):
  function qvec2rotmat (line 116) | def qvec2rotmat(qvec):
  function rotmat (line 133) | def rotmat(a, b):
  function closest_point_2_lines (line 141) | def closest_point_2_lines(oa, da, ob, db): # returns point closest to bo...
  function normalize (line 156) | def normalize(x):
  function rotation_matrix_from_vectors (line 159) | def rotation_matrix_from_vectors(vec1, vec2):
  function rotation_up (line 176) | def rotation_up(poses):
  function search_orientation (line 181) | def search_orientation(points):
  function load_point_txt (line 196) | def load_point_txt(path):

FILE: dataLoader/dtu_objs.py
  function load_K_Rt_from_P (line 11) | def load_K_Rt_from_P(filename, P=None):
  function fps_downsample (line 34) | def fps_downsample(points, n_points_to_sample):
  class DTUDataset (line 47) | class DTUDataset(Dataset):
    method __init__ (line 48) | def __init__(self, cfg, split='train', batch_size=4096, is_stack=None):
    method read_meta (line 78) | def read_meta(self):
    method load_data (line 122) | def load_data(self, scene_idx, img_idx=None):
    method get_bbox (line 196) | def get_bbox(self, scale_mats_np, object_scale_mat):
    method gen_rays_at (line 206) | def gen_rays_at(self, intrinsic, c2w, resolution_level=1):
    method __len__ (line 224) | def __len__(self):
    method __getitem__ (line 227) | def __getitem__(self, idx):

FILE: dataLoader/dtu_objs2.py
  function load_K_Rt_from_P (line 12) | def load_K_Rt_from_P(filename, P=None):
  class DTUDataset (line 35) | class DTUDataset(Dataset):
    method __init__ (line 36) | def __init__(self, cfg, split='train', batch_size=4096, is_stack=None):
    method get_bbox (line 57) | def get_bbox(self):
    method gen_rays_at (line 68) | def gen_rays_at(self, intrinsic, c2w, resolution_level=1):
    method read_meta (line 85) | def read_meta(self):
    method __len__ (line 131) | def __len__(self):
    method __getitem__ (line 134) | def __getitem__(self, idx):

FILE: dataLoader/google_objs.py
  function fps_downsample (line 13) | def fps_downsample(points, n_points_to_sample):
  function pose_spherical_nerf (line 29) | def pose_spherical_nerf(euler, radius=1.8, ep=1):
  function nerf_video_path (line 36) | def nerf_video_path(c2ws, theta_range=10,phi_range=20,N_views=120,radius...
  function _interpolate_trajectory (line 56) | def _interpolate_trajectory(c2ws, num_views: int = 300):
  function google_objs_path (line 77) | def google_objs_path(c2ws, N_views=150):
  class GoogleObjsDataset (line 85) | class GoogleObjsDataset(Dataset):
    method __init__ (line 86) | def __init__(self, cfg, split="train", batch_size=4096):
    method read_depth (line 174) | def read_depth(self, filename):
    method read_meta (line 178) | def read_meta(self):
    method get_rays (line 301) | def get_rays(self, idx):
    method get_rgbs (line 306) | def get_rgbs(self, idx):
    method define_transforms (line 312) | def define_transforms(self):
    method define_proj_mat (line 315) | def define_proj_mat(self):
    method world2ndc (line 320) | def world2ndc(self, points, lindisp=None):
    method update_index (line 324) | def update_index(self):
    method __len__ (line 327) | def __len__(self):
    method __getitem__ (line 330) | def __getitem__(self, idx):

FILE: dataLoader/image.py
  function load (line 10) | def load(path):
  function srgb_to_linear (line 22) | def srgb_to_linear(img):
  class ImageDataset (line 26) | class ImageDataset(Dataset):
    method __init__ (line 27) | def __init__(self, cfg, batchsize, split='train', continue_sampling=Fa...
    method __len__ (line 84) | def __len__(self):
    method __getitem__ (line 87) | def __getitem__(self, idx):

FILE: dataLoader/image_set.py
  function srgb_to_linear (line 6) | def srgb_to_linear(img):
  function load (line 10) | def load(path, HW=512):
  class ImageSetDataset (line 24) | class ImageSetDataset(Dataset):
    method __init__ (line 25) | def __init__(self, cfg, batchsize, split='train', continue_sampling=Fa...
    method __len__ (line 50) | def __len__(self):
    method __getitem__ (line 53) | def __getitem__(self, idx):

FILE: dataLoader/llff.py
  function normalize (line 12) | def normalize(v):
  function average_poses (line 17) | def average_poses(poses):
  function center_poses (line 54) | def center_poses(poses, blender2opencv):
  function viewmatrix (line 81) | def viewmatrix(z, up, pos):
  function render_path_spiral (line 91) | def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, N_rots=2, N=...
  function get_spiral (line 102) | def get_spiral(c2ws_all, near_fars, rads_scale=1.0, N_views=120):
  class LLFFDataset (line 122) | class LLFFDataset(Dataset):
    method __init__ (line 123) | def __init__(self, cfg , split='train', hold_every=8):
    method read_meta (line 148) | def read_meta(self):
    method define_transforms (line 231) | def define_transforms(self):
    method __len__ (line 234) | def __len__(self):
    method __getitem__ (line 237) | def __getitem__(self, idx):

FILE: dataLoader/nsvf.py
  function pose_spherical (line 29) | def pose_spherical(theta, phi, radius):
  class NSVF (line 36) | class NSVF(Dataset):
    method __init__ (line 38) | def __init__(self, datadir, split='train', downsample=1.0, wh=[800,800...
    method bbox2corners (line 56) | def bbox2corners(self):
    method read_meta (line 63) | def read_meta(self):
    method define_transforms (line 132) | def define_transforms(self):
    method define_proj_mat (line 135) | def define_proj_mat(self):
    method world2ndc (line 138) | def world2ndc(self, points):
    method __len__ (line 142) | def __len__(self):
    method __getitem__ (line 147) | def __getitem__(self, idx):

FILE: dataLoader/ray_utils.py
  function load_json (line 7) | def load_json(path):
  function depth2dist (line 12) | def depth2dist(z_vals, cos_angle):
  function ndc2dist (line 21) | def ndc2dist(ndc_pts, cos_angle):
  function get_ray_directions (line 27) | def get_ray_directions(H, W, focal, center=None):
  function get_ray_directions_blender (line 48) | def get_ray_directions_blender(H, W, focal, center=None):
  function get_rays (line 69) | def get_rays(directions, c2w):
  function ndc_rays_blender (line 93) | def ndc_rays_blender(H, W, focal, near, rays_o, rays_d):
  function ndc_rays (line 112) | def ndc_rays(H, W, focal, near, rays_o, rays_d):
  function sample_pdf (line 132) | def sample_pdf(bins, weights, N_samples, det=False, pytest=False):
  function dda (line 177) | def dda(rays_o, rays_d, bbox_3D):
  function ray_marcher (line 187) | def ray_marcher(rays,
  function read_pfm (line 234) | def read_pfm(filename):
  class SimpleSampler (line 271) | class SimpleSampler:
    method __init__ (line 272) | def __init__(self, total, batch):
    method nextids (line 278) | def nextids(self):
  function ndc_bbox (line 285) | def ndc_bbox(all_rays):
  function pose_from_json (line 293) | def pose_from_json(meta, transpose):
  function rotation_matrix_from_vectors (line 300) | def rotation_matrix_from_vectors(vec1, vec2):
  function normalize (line 317) | def normalize(x):
  function rotation_up (line 320) | def rotation_up(poses):
  function search_orientation (line 325) | def search_orientation(points):
  function load_point_txt (line 341) | def load_point_txt(path):
  function orientation (line 365) | def orientation(poses, point_path=None):
  function spherify_poses (line 397) | def spherify_poses(poses, radus=1):

FILE: dataLoader/sdf.py
  function N_to_reso (line 5) | def N_to_reso(avg_reso, bbox):
  function load (line 12) | def load(path, split, dtype='points'):
  class SDFDataset (line 31) | class SDFDataset(Dataset):
    method __init__ (line 32) | def __init__(self, cfg, split='train'):
    method __len__ (line 42) | def __len__(self):
    method __getitem__ (line 45) | def __getitem__(self, idx):

FILE: dataLoader/tankstemple.py
  function circle (line 11) | def circle(radius=3.5, h=0.0, axis='z', t0=0, r=1):
  function cross (line 20) | def cross(x, y, axis=0):
  function normalize (line 25) | def normalize(x, axis=-1, order=2):
  function cat (line 37) | def cat(x, axis=1):
  function look_at_rotation (line 43) | def look_at_rotation(camera_position, at=None, up=None, inverse=False, c...
  function gen_path (line 76) | def gen_path(pos_gen, at=(0, 0, 0), up=(0, -1, 0), frames=180):
  class TanksTempleDataset (line 87) | class TanksTempleDataset(Dataset):
    method __init__ (line 90) | def __init__(self, cfg, split='train'):
    method bbox2corners (line 109) | def bbox2corners(self):
    method read_meta (line 115) | def read_meta(self):
    method define_transforms (line 192) | def define_transforms(self):
    method define_proj_mat (line 195) | def define_proj_mat(self):
    method world2ndc (line 198) | def world2ndc(self, points):
    method __len__ (line 202) | def __len__(self):
    method __getitem__ (line 207) | def __getitem__(self, idx):

FILE: dataLoader/your_own_data.py
  class YourOwnDataset (line 13) | class YourOwnDataset(Dataset):
    method __init__ (line 14) | def __init__(self, datadir, split='train', downsample=1.0, is_stack=Fa...
    method read_depth (line 35) | def read_depth(self, filename):
    method read_meta (line 39) | def read_meta(self):
    method define_transforms (line 102) | def define_transforms(self):
    method define_proj_mat (line 105) | def define_proj_mat(self):
    method world2ndc (line 108) | def world2ndc(self,points,lindisp=None):
    method __len__ (line 112) | def __len__(self):
    method __getitem__ (line 115) | def __getitem__(self, idx):

FILE: models/FactorFields.py
  function grid_mapping (line 11) | def grid_mapping(positions, freq_bands, aabb, basis_mapping='sawtooth'):
  function dct_dict (line 36) | def dct_dict(n_atoms_fre, size, n_selete, dim=2):
  function positional_encoding (line 74) | def positional_encoding(positions, freqs):
  function raw2alpha (line 82) | def raw2alpha(sigma, dist):
  class AlphaGridMask (line 91) | class AlphaGridMask(torch.nn.Module):
    method __init__ (line 92) | def __init__(self, device, aabb, alpha_volume):
    method sample_alpha (line 103) | def sample_alpha(self, xyz_sampled):
    method normalize_coord (line 109) | def normalize_coord(self, xyz_sampled):
  class MLPMixer (line 113) | class MLPMixer(torch.nn.Module):
    method __init__ (line 114) | def __init__(self,
    method forward (line 144) | def forward(self, x, is_train=False):
  class MLPRender_Fea (line 162) | class MLPRender_Fea(torch.nn.Module):
    method __init__ (line 163) | def __init__(self, inChanel, num_layers=3, hidden_dim=64, viewpe=6, fe...
    method forward (line 188) | def forward(self, viewdirs, features):
  class FactorFields (line 206) | class FactorFields(torch.nn.Module):
    method __init__ (line 207) | def __init__(self, cfg, device):
    method setup_params (line 262) | def setup_params(self, aabb):
    method init_coef (line 311) | def init_coef(self):
    method init_basis (line 334) | def init_basis(self):
    method get_coeff (line 425) | def get_coeff(self, xyz_sampled):
    method get_basis (line 467) | def get_basis(self, x):
    method normalize_basis (line 519) | def normalize_basis(self):
    method get_coding (line 523) | def get_coding(self, x):
    method n_parameters (line 535) | def n_parameters(self):
    method get_optparam_groups (line 541) | def get_optparam_groups(self, lr_small=0.001, lr_large=0.02):
    method set_optimizable (line 556) | def set_optimizable(self, items, statue):
    method TV_loss (line 569) | def TV_loss(self, reg):
    method sample_point_ndc (line 575) | def sample_point_ndc(self, rays_o, rays_d, is_train=True, N_samples=-1):
    method sample_point (line 586) | def sample_point(self, rays_o, rays_d, is_train=True, N_samples=-1):
    method sample_point_unbound (line 604) | def sample_point_unbound(self, rays_o, rays_d, is_train=True, N_sample...
    method normalize_coord (line 635) | def normalize_coord(self, xyz_sampled):
    method basis2density (line 639) | def basis2density(self, density_features):
    method cal_mean_coef (line 646) | def cal_mean_coef(self, state_dict):
    method save (line 669) | def save(self, path):
    method load (line 682) | def load(self, ckpt):
    method update_renderParams (line 693) | def update_renderParams(self, gridSize):
    method upsample_volume_grid (line 702) | def upsample_volume_grid(self, res_target):
    method compute_alpha (line 710) | def compute_alpha(self, xyz_locs, length=1):
    method getDenseAlpha (line 730) | def getDenseAlpha(self, gridSize=None, times=16):
    method updateAlphaMask (line 758) | def updateAlphaMask(self, gridSize=(200, 200, 200), is_update_alphaMas...
    method shrink (line 796) | def shrink(self, new_aabb):
    method filtering_rays (line 812) | def filtering_rays(self, all_rays, all_rgbs, N_samples=256, chunk=1024...
    method forward (line 843) | def forward(self, rays_chunk, white_bg=True, is_train=False, ndc_ray=F...

FILE: models/sh.py
  function eval_sh (line 34) | def eval_sh(deg, sh, dirs):
  function eval_sh_bases (line 87) | def eval_sh_bases(deg, dirs):

FILE: renderer.py
  function render_ray (line 8) | def render_ray(rays, factor_fields, chunk=4096, N_samples=-1, ndc_ray=Fa...
  function evaluation (line 30) | def evaluation(test_dataset,factor_fields, renderer, savePath=None, N_vi...
  function evaluation_path (line 101) | def evaluation_path(test_dataset,factor_fields, c2ws, renderer, savePath...

FILE: run_batch.py
  function run_program (line 123) | def run_program(gpu, cmd):

FILE: scripts/2D_set_regression.py
  function PSNR (line 19) | def PSNR(a, b):
  function eval_img (line 29) | def eval_img(aabb, reso, idx, shiftment=[0.5, 0.5, 0.5], chunk=10240):
  function eval_img_single (line 49) | def eval_img_single(aabb, reso, chunk=10240):
  function linear_to_srgb (line 67) | def linear_to_srgb(img):
  function srgb_to_linear (line 72) | def srgb_to_linear(img):
  function write_image_imageio (line 77) | def write_image_imageio(img_file, img, colormap=None, quality=100):
  function interpolate (line 97) | def interpolate(colormap, x):

FILE: train_across_scene.py
  class SimpleSampler (line 17) | class SimpleSampler:
    method __init__ (line 18) | def __init__(self, total, batch):
    method nextids (line 24) | def nextids(self):
  function export_mesh (line 33) | def export_mesh(cfg):
  function render_test (line 55) | def render_test(cfg):
  function reconstruction (line 99) | def reconstruction(cfg):

FILE: train_across_scene_ft.py
  class SimpleSampler (line 17) | class SimpleSampler:
    method __init__ (line 18) | def __init__(self, total, batch):
    method nextids (line 24) | def nextids(self):
  function export_mesh (line 33) | def export_mesh(cfg):
  function render_test (line 42) | def render_test(cfg):
  function reconstruction (line 88) | def reconstruction(cfg):

FILE: train_per_scene.py
  class SimpleSampler (line 16) | class SimpleSampler:
    method __init__ (line 17) | def __init__(self, total, batch):
    method nextids (line 23) | def nextids(self):
  function export_mesh (line 32) | def export_mesh(ckpt_path):
  function render_test (line 44) | def render_test(cfg):
  function reconstruction (line 89) | def reconstruction(cfg):

FILE: utils.py
  function visualize_depth_numpy (line 11) | def visualize_depth_numpy(depth, minmax=None, cmap=cv2.COLORMAP_JET):
  function init_log (line 28) | def init_log(log, keys):
  function visualize_depth (line 33) | def visualize_depth(depth, minmax=None, cmap=cv2.COLORMAP_JET):
  function N_to_reso (line 53) | def N_to_reso(n_voxels, bbox):
  function N_to_vm_reso (line 59) | def N_to_vm_reso(n_voxels, bbox):
  function cal_n_samples (line 69) | def cal_n_samples(reso, step_ratio=0.5):
  class SimpleSampler (line 72) | class SimpleSampler:
    method __init__ (line 73) | def __init__(self, total, batch):
    method nextids (line 79) | def nextids(self):
  function init_lpips (line 87) | def init_lpips(net_name, device):
  function rgb_lpips (line 93) | def rgb_lpips(np_gt, np_im, net_name, device):
  function findItem (line 101) | def findItem(items, target):
  function rgb_ssim (line 110) | def rgb_ssim(img0, img1, max_val,
  class TVLoss (line 160) | class TVLoss(nn.Module):
    method __init__ (line 161) | def __init__(self,TVLoss_weight=1):
    method forward (line 165) | def forward(self,x):
    method _tensor_size (line 175) | def _tensor_size(self,t):
  function marchcude_to_world (line 178) | def marchcude_to_world(vertices, reso_WHD):
  function convert_sdf_samples_to_ply (line 183) | def convert_sdf_samples_to_ply(
Condensed preview — 48 files, each showing path, character count, and a content snippet.
[
  {
    "path": ".gitignore",
    "chars": 3099,
    "preview": "# Byte-compiled / optimized / DLL files\n__pycache__/\n*.py[cod]\n*$py.class\n\n# C extensions\n*.so\n\n# Distribution / packagi"
  },
  {
    "path": "2D_regression.py",
    "chars": 9265,
    "preview": "import torch,imageio,sys,time,os,cmapy,scipy\nimport numpy as np\nfrom tqdm import tqdm\nimport matplotlib.pyplot as plt\nfr"
  },
  {
    "path": "LICENSE",
    "chars": 1073,
    "preview": "MIT License\n\nCopyright (c) 2023 autonomousvision\n\nPermission is hereby granted, free of charge, to any person obtaining "
  },
  {
    "path": "README.md",
    "chars": 6031,
    "preview": "# Factor Fields\n## [Project page](https://apchenstu.github.io/FactorFields/) |  [Paper](https://arxiv.org/abs/2302.01226"
  },
  {
    "path": "README_FactorField.md",
    "chars": 4832,
    "preview": "## nerf reconstruction with Dictionary field\n\n```python\n\tfor scene in ['ship', 'mic', 'chair', 'lego', 'drums', 'ficus',"
  },
  {
    "path": "configs/360_v2.yaml",
    "chars": 1568,
    "preview": "\ndefaults:\n  expname: basis_room_real_mask\n  logdir: ./logs\n\n  ckpt: null                  # help='specific weights npy "
  },
  {
    "path": "configs/defaults.yaml",
    "chars": 1342,
    "preview": "\ndefaults:\n  expname:  basis_lego\n  logedir: ./logs\n\n  mode: 'reconstruction'\n\n  progress_refresh_rate: 10\n\n  add_timest"
  },
  {
    "path": "configs/image.yaml",
    "chars": 1129,
    "preview": "\ndefaults:\n  expname: basis_image\n  logdir: ./logs\n\n  mode: 'image'\n\n  ckpt: null                  # help='specific weig"
  },
  {
    "path": "configs/image_intro.yaml",
    "chars": 767,
    "preview": "\ndefaults:\n  expname: basis_image\n  logdir: ./logs\n\n  mode: 'demo'\n\n  ckpt: null                  # help='specific weigh"
  },
  {
    "path": "configs/image_set.yaml",
    "chars": 733,
    "preview": "\ndefaults:\n  expname: basis_image\n  logdir: ./logs\n\n  mode: 'images'\n\n  ckpt: null                  # help='specific wei"
  },
  {
    "path": "configs/nerf.yaml",
    "chars": 1613,
    "preview": "defaults:\n  expname: basis_lego\n  logdir: ./logs\n\n  mode: 'reconstruction'\n\n  ckpt: null                  # help='specif"
  },
  {
    "path": "configs/nerf_ft.yaml",
    "chars": 1795,
    "preview": "defaults:\n  expname: basis\n  logdir: ./logs\n\n  mode: 'reconstructions'\n\n  ckpt: null                  # help='specific w"
  },
  {
    "path": "configs/nerf_set.yaml",
    "chars": 1831,
    "preview": "defaults:\n  expname: basis_no_relu_lego\n  logdir: ./logs\n\n  mode: 'reconstructions'\n\n  ckpt: null                  # hel"
  },
  {
    "path": "configs/sdf.yaml",
    "chars": 745,
    "preview": "defaults:\n  expname: basis_sdf\n  logdir: ./logs\n\n  mode: 'sdf'\n\n  ckpt: null                  # help='specific weights n"
  },
  {
    "path": "configs/tnt.yaml",
    "chars": 2138,
    "preview": "\ndefaults:\n  expname: basis_truck\n  logdir: ./logs\n\n  ckpt: null                  # help='specific weights npy file to r"
  },
  {
    "path": "dataLoader/__init__.py",
    "chars": 913,
    "preview": "from .llff import LLFFDataset\nfrom .blender import BlenderDataset\nfrom .nsvf import NSVF\nfrom .tankstemple import TanksT"
  },
  {
    "path": "dataLoader/blender.py",
    "chars": 5265,
    "preview": "import torch, cv2\nfrom torch.utils.data import Dataset\nimport json\nfrom tqdm import tqdm\nimport os\nfrom PIL import Image"
  },
  {
    "path": "dataLoader/blender_set.py",
    "chars": 5025,
    "preview": "import torch, cv2\nfrom torch.utils.data import Dataset\nimport json\nfrom tqdm import tqdm\nimport os\nfrom PIL import Image"
  },
  {
    "path": "dataLoader/colmap.py",
    "chars": 4958,
    "preview": "import torch, cv2\nfrom torch.utils.data import Dataset\n\nfrom tqdm import tqdm\nimport os\nfrom PIL import Image\nfrom torch"
  },
  {
    "path": "dataLoader/colmap2nerf.py",
    "chars": 13415,
    "preview": "#!/usr/bin/env python3\n\n# Copyright (c) 2020-2022, NVIDIA CORPORATION.  All rights reserved.\n#\n# NVIDIA CORPORATION and "
  },
  {
    "path": "dataLoader/dtu_objs.py",
    "chars": 10686,
    "preview": "\nimport torch\nimport cv2 as cv\nimport numpy as np\nimport os\nfrom glob import glob\nfrom .ray_utils import *\nfrom torch.ut"
  },
  {
    "path": "dataLoader/dtu_objs2.py",
    "chars": 5829,
    "preview": "\nimport torch\nimport cv2 as cv\nimport numpy as np\nimport os\nfrom glob import glob\nfrom .ray_utils import *\nfrom torch.ut"
  },
  {
    "path": "dataLoader/google_objs.py",
    "chars": 16048,
    "preview": "import torch, cv2\nfrom torch.utils.data import Dataset\nimport json\nfrom tqdm import tqdm\nimport os\nfrom PIL import Image"
  },
  {
    "path": "dataLoader/image.py",
    "chars": 3881,
    "preview": "import torch,imageio,cv2\nfrom PIL import Image \nImage.MAX_IMAGE_PIXELS = 1000000000 \nimport numpy as np\nimport torch.nn."
  },
  {
    "path": "dataLoader/image_set.py",
    "chars": 2628,
    "preview": "import torch,cv2\nimport numpy as np\nimport torch.nn.functional as F\nfrom torch.utils.data import Dataset\n\ndef srgb_to_li"
  },
  {
    "path": "dataLoader/llff.py",
    "chars": 9522,
    "preview": "import torch\nfrom torch.utils.data import Dataset\nimport glob\nimport numpy as np\nimport os\nfrom PIL import Image\nfrom to"
  },
  {
    "path": "dataLoader/nsvf.py",
    "chars": 6584,
    "preview": "import torch\nfrom torch.utils.data import Dataset\nfrom tqdm import tqdm\nimport os\nfrom PIL import Image\nfrom torchvision"
  },
  {
    "path": "dataLoader/ray_utils.py",
    "chars": 16205,
    "preview": "import torch, re, json\nimport numpy as np\nfrom torch import searchsorted\nfrom kornia import create_meshgrid\n\n\ndef load_j"
  },
  {
    "path": "dataLoader/sdf.py",
    "chars": 1684,
    "preview": "import torch\nimport numpy as np\nfrom torch.utils.data import Dataset\n\ndef N_to_reso(avg_reso, bbox):\n    xyz_min, xyz_ma"
  },
  {
    "path": "dataLoader/tankstemple.py",
    "chars": 8866,
    "preview": "import torch\nfrom torch.utils.data import Dataset\nfrom tqdm import tqdm\nimport os\nfrom PIL import Image\nfrom torchvision"
  },
  {
    "path": "dataLoader/your_own_data.py",
    "chars": 5026,
    "preview": "import torch,cv2\nfrom torch.utils.data import Dataset\nimport json\nfrom tqdm import tqdm\nimport os\nfrom PIL import Image\n"
  },
  {
    "path": "models/FactorFields.py",
    "chars": 42185,
    "preview": "import torch, math\nimport torch.nn\nimport torch.nn.functional as F\nimport numpy as np\nimport time, skimage\nfrom utils im"
  },
  {
    "path": "models/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "models/sh.py",
    "chars": 5231,
    "preview": "import torch\n\n################## sh function ##################\nC0 = 0.28209479177387814\nC1 = 0.4886025119029199\nC2 = [\n"
  },
  {
    "path": "renderer.py",
    "chars": 6556,
    "preview": "import torch,os,imageio,sys\nfrom tqdm.auto import tqdm\nfrom dataLoader.ray_utils import get_rays\nfrom utils import *\nfro"
  },
  {
    "path": "requirements.txt",
    "chars": 170,
    "preview": "imageio==2.26.0\nkornia==0.6.10\nlpips==0.1.4\nmatplotlib==3.7.1\nomegaconf==2.3.0\nopencv_python==4.7.0.72\nplyfile==0.7.4\nsc"
  },
  {
    "path": "run_batch.py",
    "chars": 8462,
    "preview": "\nimport os\nimport threading, queue\nimport numpy as np\nimport time\n\n\nif __name__ == '__main__':\n\n\t################  per s"
  },
  {
    "path": "scripts/2D_regression.ipynb",
    "chars": 690058,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"id\": \"f969c229-5a8a-44b6-91a3-bba55968b202\",\n   \""
  },
  {
    "path": "scripts/2D_set_regression.ipynb",
    "chars": 744891,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"id\": \"f969c229-5a8a-44b6-91a3-bba55968b202\",\n   \""
  },
  {
    "path": "scripts/2D_set_regression.py",
    "chars": 5635,
    "preview": "import torch,imageio,sys,cmapy,time,os\nimport numpy as np\nfrom tqdm import tqdm\n# from .autonotebook import tqdm as tqdm"
  },
  {
    "path": "scripts/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "scripts/formula_demostration.ipynb",
    "chars": 33703,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"id\": \"f969c229-5a8a-44b6-91a3-bba55968b202\",\n   \""
  },
  {
    "path": "scripts/mesh2SDF_data_process.ipynb",
    "chars": 3598,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"id\": \"47120857-b2aa-4a36-8733-e04776ca7a80\",\n   \""
  },
  {
    "path": "scripts/sdf_regression.ipynb",
    "chars": 15828,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"id\": \"b4ee17ad-e38b-4d49-a8d6-0ec910d9e528\",\n   \""
  },
  {
    "path": "train_across_scene.py",
    "chars": 10562,
    "preview": "from tqdm.auto import tqdm\nfrom omegaconf import OmegaConf\nfrom models.FactorFields import FactorFields\n\nimport json, ra"
  },
  {
    "path": "train_across_scene_ft.py",
    "chars": 11993,
    "preview": "from tqdm.auto import tqdm\nfrom omegaconf import OmegaConf\nfrom models.FactorFields import FactorFields\n\nimport json, ra"
  },
  {
    "path": "train_per_scene.py",
    "chars": 10729,
    "preview": "from tqdm.auto import tqdm\nfrom omegaconf import OmegaConf\nfrom models.FactorFields import FactorFields\n\nimport json, ra"
  },
  {
    "path": "utils.py",
    "chars": 7990,
    "preview": "import cv2,torch,math\nimport numpy as np\nfrom PIL import Image\nimport torchvision.transforms as T\nimport torch.nn.functi"
  }
]

About this extraction

This page contains the full source code of the autonomousvision/factor-fields GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 48 files (1.7 MB), approximately 1.1M tokens, and a symbol index with 257 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
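The file index above is a JSON array in which each entry records a file's path, its size in characters ("chars"), and a short preview of its opening lines. As a minimal sketch, assuming the index has been saved to a local file named file_index.json (a hypothetical name, not part of the extraction), Python's standard json module is enough to inspect it, for example to see which files dominate the ~1.1M-token budget or to drop the large notebooks before pasting code into an LLM context:

import json

# Load the extracted file index (assumed to be saved as file_index.json).
with open("file_index.json", "r", encoding="utf-8") as f:
    index = json.load(f)

# Each entry has "path", "chars" and "preview" keys, as in the listing above.
# Sort by character count to see which files take up most of the extraction.
for entry in sorted(index, key=lambda e: e["chars"], reverse=True)[:5]:
    print(f'{entry["chars"]:>8} chars  {entry["path"]}')

# The two large regression notebooks account for most of the 1.7 MB;
# filtering *.ipynb out leaves a much smaller, code-only subset for prompting.
code_only = [e for e in index if not e["path"].endswith(".ipynb")]
print(len(code_only), "non-notebook files,",
      sum(e["chars"] for e in code_only), "chars total")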

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.
