Full Code of Kunhao-Liu/StyleRF for AI

main 43611c3ec7ba cached
36 files
205.9 KB
60.2k tokens
246 symbols
1 requests
Download .txt
Showing preview only (217K chars total). Download the full file or copy to clipboard to get everything.
Repository: Kunhao-Liu/StyleRF
Branch: main
Commit: 43611c3ec7ba
Files: 36
Total size: 205.9 KB

Directory structure:
gitextract_j81cq_2h/

├── README.md
├── configs/
│   ├── llff.txt
│   ├── llff_feature.txt
│   ├── llff_style.txt
│   ├── nerf_synthetic.txt
│   ├── nerf_synthetic_feature.txt
│   └── nerf_synthetic_style.txt
├── dataLoader/
│   ├── __init__.py
│   ├── blender.py
│   ├── colmap2nerf.py
│   ├── llff.py
│   ├── nsvf.py
│   ├── ray_utils.py
│   ├── styleLoader.py
│   ├── tankstemple.py
│   └── your_own_data.py
├── extra/
│   ├── auto_run_paramsets.py
│   └── compute_metrics.py
├── models/
│   ├── VGG.py
│   ├── __init__.py
│   ├── sh.py
│   ├── styleModules.py
│   ├── tensoRF.py
│   └── tensorBase.py
├── opt.py
├── renderer.py
├── scripts/
│   ├── test.sh
│   ├── test_feature.sh
│   ├── test_style.sh
│   ├── train.sh
│   ├── train_feature.sh
│   └── train_style.sh
├── train.py
├── train_feature.py
├── train_style.py
└── utils.py

================================================
FILE CONTENTS
================================================

================================================
FILE: README.md
================================================
# [*CVPR 2023*] StyleRF: Zero-shot 3D Style Transfer of Neural Radiance Fields
## [Project page](https://kunhao-liu.github.io/StyleRF/) |  [Paper](https://arxiv.org/abs/2303.10598)

This repository contains a pytorch implementation for the paper: [StyleRF: Zero-shot 3D Style Transfer of Neural Radiance Fields](https://arxiv.org/abs/2303.10598). StyleRF is an innovative 3D style transfer technique that achieves superior 3D stylization quality with precise geometry reconstruction and it can generalize to various new styles in a zero-shot manner. 

![teaser](https://kunhao-liu.github.io/StyleRF/resources/teaser.png)

---
## Installation
> Tested on Ubuntu 20.04 + Pytorch 1.12.1

Install environment:
```
conda create -n StyleRF python=3.9
conda activate StyleRF
pip install torch torchvision
pip install tqdm scikit-image opencv-python configargparse lpips imageio-ffmpeg kornia lpips tensorboard
```

## Datasets
Please put the datasets in `./data`. You can put the datasets elsewhere if you modify the corresponding paths in the configs.

### 3D scene datasets
* [nerf_synthetic](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1) 
* [llff](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1)
### Style image dataset
* [WikiArt](https://www.kaggle.com/datasets/ipythonx/wikiart-gangogh-creating-art-gan)

## Quick Start
We provide some trained checkpoints in: [StyleRF checkpoints](https://drive.google.com/drive/folders/1nF9-6lTIhktG5JjNvnmdYOo1LTvtK7Dw?usp=share_link)

Then modify the following attributes in `scripts/test_style.sh`:
* `--config`: choose `configs/llff_style.txt` or `configs/nerf_synthetic_style.txt` according to which type of dataset is being used
* `--datadir`: dataset's path
* `--ckpt`: checkpoint's path
* `--style_img`: reference style image's path


To generate stylized novel views:
```
bash scripts/test_style.sh [GPU ID]
```
The rendered stylized images can then be found in the directory under the checkpoint's path.

## Training
> Current settings in `configs` are tested on one NVIDIA RTX A5000 Graphics Card with 24G memory. To reduce memory consumption, you can set `batch_size`, `chunk_size` or `patch_size` to a smaller number.

We follow the following 3 steps of training:
### 1. Train original TensoRF
This step is for reconstructing the density field, which contains more precise geometry details compared to mesh-based methods. You can skip this step by directly downloading pre-trained checkpoints provided by [TensoRF checkpoints](https://1drv.ms/u/s!Ard0t_p4QWIMgQ2qSEAs7MUk8hVw?e=dc6hBm).

The configs are stored in `configs/llff.txt` and `configs/nerf_synthetic.txt`. For the details of the settings, please also refer to [TensoRF](https://github.com/apchenstu/TensoRF). The checkpoints are stored in `./log` by default.

You can train the original TensoRF by:
```
bash script/train.sh [GPU ID]
```

### 2. Feature grid training stage
This step is for reconstructing the 3D gird containing the VGG features.

The configs are stored in `configs/llff_feature.txt` and `configs/nerf_synthetic_feature.txt`, in which `ckpt` specifies the checkpoints trained in the **first** step. The checkpoints are stored in `./log_feature` by default.

Then run:
```
bash script/train_feature.sh [GPU ID]
```


### 3. Stylization training stage 
This step is for training the style transfer modules.

The configs are stored in `configs/llff_style.txt` and `configs/nerf_synthetic_style.txt`, in which `ckpt` specifies the checkpoints trained in the **second** step. The checkpoints are stored in `./log_style` by default.

Then run:
```
bash script/train_style.sh [GPU ID]
```

---
## Training on 360 Unbounded Scenes
The code for training StyleRF on the Tanks&Temples dataset is available on the `360` branch. To access it, run `git checkout 360`.


## Acknowledgments
This repo is heavily based on the [TensoRF](https://github.com/apchenstu/TensoRF). Thank them for sharing their amazing work!

## Citation
If you find our code or paper helps, please consider citing:
```
@inproceedings{liu2023stylerf,
  title={StyleRF: Zero-shot 3D Style Transfer of Neural Radiance Fields},
  author={Liu, Kunhao and Zhan, Fangneng and Chen, Yiwen and Zhang, Jiahui and Yu, Yingchen and El Saddik, Abdulmotaleb and Lu, Shijian and Xing, Eric P},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={8338--8348},
  year={2023}
}
```



================================================
FILE: configs/llff.txt
================================================

dataset_name = llff
datadir = ./data/nerf_llff_data/trex
expname = trex
basedir = ./log

downsample_train = 4.0
ndc_ray = 1

n_iters = 25000
batch_size = 4096

N_voxel_init = 2097156 # 128**3
N_voxel_final = 262144000 # 640**3
upsamp_list = [2000,3000,4000,5500]
update_AlphaMask_list = [2500]

N_vis = -1 # vis all testing images
vis_every = 10000

render_test = 1
render_path = 1

n_lamb_sigma = [16,4,4]
n_lamb_sh = [48,12,12]

shadingMode = MLP_Fea
fea2denseAct = relu

view_pe = 0
fea_pe = 0

TV_weight_density = 1.0
TV_weight_app = 1.0



================================================
FILE: configs/llff_feature.txt
================================================
dataset_name = llff
datadir = ./data/nerf_llff_data/trex
ckpt = ./log/trex/trex.th
expname = trex
basedir = ./log_feature

TV_weight_feature = 80

downsample_train = 4.0
ndc_ray = 1

n_iters = 25000
patch_size = 256
batch_size = 4096
chunk_size = 4096

N_voxel_init = 2097156 # 128**3
N_voxel_final = 262144000 # 640**3
upsamp_list = [2000,3000,4000,5500]
update_AlphaMask_list = [2500]

n_lamb_sigma = [16,4,4]
n_lamb_sh = [48,12,12]

fea2denseAct = relu



================================================
FILE: configs/llff_style.txt
================================================
dataset_name = llff
datadir = ./data/nerf_llff_data/trex
ckpt = ./log_feature/trex/trex.th
expname = trex
basedir = ./log_style

nSamples = 300
patch_size = 256
chunk_size = 2048

content_weight = 1
style_weight = 20
featuremap_tv_weight = 0
image_tv_weight = 0

rm_weight_mask_thre = 0.001

downsample_train = 4.0
ndc_ray = 1

n_iters = 25000

n_lamb_sigma = [16,4,4]
n_lamb_sh = [48,12,12]
N_voxel_init = 2097156 # 128**3
N_voxel_final = 262144000 # 640**3

fea2denseAct = relu


================================================
FILE: configs/nerf_synthetic.txt
================================================

dataset_name = blender
datadir = ./data/nerf_synthetic/lego
expname =  lego
basedir = ./log

n_iters = 30000
batch_size = 4096

N_voxel_init = 2097156 # 128**3
N_voxel_final = 27000000 # 300**3
upsamp_list = [2000,3000,4000,5500,7000]
update_AlphaMask_list = [2000,4000]

N_vis = 5
vis_every = 10000

render_test = 1

n_lamb_sigma = [16,16,16]
n_lamb_sh = [48,48,48]
model_name = TensorVMSplit


shadingMode = MLP_Fea
fea2denseAct = softplus

view_pe = 2
fea_pe = 2

L1_weight_inital = 8e-5
L1_weight_rest = 4e-5
rm_weight_mask_thre = 1e-4


================================================
FILE: configs/nerf_synthetic_feature.txt
================================================
dataset_name = blender
datadir = ./data/nerf_synthetic/lego
ckpt = ./log/lego/lego.th
expname = lego
basedir = ./log_feature

TV_weight_feature = 10

n_iters = 25000
patch_size = 256
batch_size = 4096
chunk_size = 4096

N_voxel_init = 2097156 # 128**3
N_voxel_final = 27000000 # 300**3
upsamp_list = [2000,3000,4000,5500,7000]
update_AlphaMask_list = [2000,4000]

rm_weight_mask_thre = 0.01

n_lamb_sigma = [16,16,16]
n_lamb_sh = [48,48,48]

fea2denseAct = softplus


================================================
FILE: configs/nerf_synthetic_style.txt
================================================
dataset_name = blender
datadir = ./data/nerf_synthetic/lego
ckpt = ./log_feature/lego/lego.th
expname = lego
basedir = ./log_style

patch_size = 256
chunk_size = 2048

content_weight = 1
style_weight = 20

rm_weight_mask_thre = 0.01

n_iters = 25000

n_lamb_sigma = [16,16,16]
n_lamb_sh = [48,48,48]
N_voxel_init = 2097156 # 128**3
N_voxel_final = 27000000 # 300**3

fea2denseAct = softplus


================================================
FILE: dataLoader/__init__.py
================================================
from .llff import LLFFDataset
from .blender import BlenderDataset
from .nsvf import NSVF
from .tankstemple import TanksTempleDataset
from .your_own_data import YourOwnDataset



dataset_dict = {'blender': BlenderDataset,
               'llff':LLFFDataset,
               'tankstemple':TanksTempleDataset,
               'nsvf':NSVF,
                'own_data':YourOwnDataset}

================================================
FILE: dataLoader/blender.py
================================================
import torch,cv2
from torch.utils.data import Dataset
import json
from tqdm import tqdm
import os
from PIL import Image
from torchvision import transforms as T


from .ray_utils import *


class BlenderDataset(Dataset):
    def __init__(self, datadir, split='train', downsample=1.0, is_stack=False, N_vis=-1):

        self.N_vis = N_vis
        self.root_dir = datadir
        self.split = split
        self.is_stack = is_stack
        self.img_wh = (int(800/downsample),int(800/downsample))
        self.define_transforms()

        self.scene_bbox = torch.tensor([[-1.5, -1.5, -1.5], [1.5, 1.5, 1.5]])
        self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
        self.read_meta()
        self.define_proj_mat()

        self.white_bg = True
        self.near_far = [2.0,6.0]
        
        self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3)
        self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)
        self.downsample=downsample

    def read_depth(self, filename):
        depth = np.array(read_pfm(filename)[0], dtype=np.float32)  # (800, 800)
        return depth
    
    def read_meta(self):

        with open(os.path.join(self.root_dir, f"transforms_{self.split}.json"), 'r') as f:
            self.meta = json.load(f)

        w, h = self.img_wh
        self.focal = 0.5 * 800 / np.tan(0.5 * self.meta['camera_angle_x'])  # original focal length
        self.focal *= self.img_wh[0] / 800  # modify focal length to match size self.img_wh


        # ray directions for all pixels, same for all images (same H, W, focal)
        self.directions = get_ray_directions(h, w, [self.focal,self.focal])  # (h, w, 3)
        self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)
        self.intrinsics = torch.tensor([[self.focal,0,w/2],[0,self.focal,h/2],[0,0,1]]).float()

        self.image_paths = []
        self.poses = []
        self.all_rays = []
        self.all_rgbs = []
        self.all_masks = []
        self.all_depth = []
        self.downsample=1.0

        img_eval_interval = 1 if self.N_vis < 0 else len(self.meta['frames']) // self.N_vis
        idxs = list(range(0, len(self.meta['frames']), img_eval_interval))
        for i in tqdm(idxs, desc=f'Loading data {self.split} ({len(idxs)})'):#img_list:#

            frame = self.meta['frames'][i]
            pose = np.array(frame['transform_matrix']) @ self.blender2opencv
            c2w = torch.FloatTensor(pose)
            self.poses += [c2w]

            image_path = os.path.join(self.root_dir, f"{frame['file_path']}.png")
            self.image_paths += [image_path]
            img = Image.open(image_path)
            
            if self.downsample!=1.0:
                img = img.resize(self.img_wh, Image.LANCZOS)
            img = self.transform(img)  # (4, h, w)
            img = img.view(4, -1).permute(1, 0)  # (h*w, 4) RGBA
            self.all_masks.append(img[:, -1:].reshape(h,w,1)) # (h, w, 1) A
            img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:])  # blend A to RGB
            self.all_rgbs += [img]


            rays_o, rays_d = get_rays(self.directions, c2w)  # both (h*w, 3)
            self.all_rays += [torch.cat([rays_o, rays_d], 1)]  # (h*w, 6)

        
        self.all_masks = torch.stack(self.all_masks) # (n_frames, h, w, 1)
        self.poses = torch.stack(self.poses)
        all_rays = self.all_rays
        all_rgbs = self.all_rgbs

        self.all_rays = torch.cat(self.all_rays, 0) # (len(self.meta['frames])*h*w,6)
        self.all_rgbs = torch.cat(self.all_rgbs, 0) # (len(self.meta['frames])*h*w,3)

        if self.is_stack:
            self.all_rays_stack = torch.stack(all_rays, 0).reshape(-1,*self.img_wh[::-1], 6)  # (len(self.meta['frames]),h,w,6)
            avg_pool = torch.nn.AvgPool2d(4, ceil_mode=True)
            self.ds_all_rays_stack = avg_pool(self.all_rays_stack.permute(0,3,1,2)).permute(0,2,3,1) # (len(self.meta['frames]),h/4,w/4,6)
            self.all_rgbs_stack = torch.stack(all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3)  # (len(self.meta['frames]),h,w,3)

    @torch.no_grad()
    def prepare_feature_data(self, encoder, chunk=8):
        '''
        Prepare feature maps as training data.
        '''
        assert self.is_stack, 'Dataset should contain original stacked taining data!'
        print('====> prepare_feature_data ...')

        frames_num, h, w, _ = self.all_rgbs_stack.size()
        features = []

        for chunk_idx in range(frames_num // chunk + int(frames_num % chunk > 0)):
            rgbs_chunk = self.all_rgbs_stack[chunk_idx*chunk : (chunk_idx+1)*chunk].cuda()
            features_chunk = encoder(normalize_vgg(rgbs_chunk.permute(0,3,1,2))).relu3_1
            # resize to the size of rgb map so that rays can match
            features_chunk = T.functional.resize(features_chunk, size=(h,w), 
                                                 interpolation=T.InterpolationMode.BILINEAR)
            features.append(features_chunk.detach().cpu().requires_grad_(False))

        self.all_features_stack = torch.cat(features).permute(0,2,3,1) # (len(self.meta['frames]),h,w,256)
        self.all_features = self.all_features_stack.reshape(-1, 256)
        print('prepare_feature_data Done!')

    def define_transforms(self):
        self.transform = T.ToTensor()
        
    def define_proj_mat(self):
        self.proj_mat = self.intrinsics.unsqueeze(0) @ torch.inverse(self.poses)[:,:3]

    def world2ndc(self,points,lindisp=None):
        device = points.device
        return (points - self.center.to(device)) / self.radius.to(device)
        
    def __len__(self):
        return len(self.all_rgbs)

    def __getitem__(self, idx):

        if self.split == 'train':  # use data in the buffers
            sample = {'rays': self.all_rays[idx],
                      'rgbs': self.all_rgbs[idx]}

        else:  # create data for each image separately

            img = self.all_rgbs[idx]
            rays = self.all_rays[idx]
            mask = self.all_masks[idx] # for quantity evaluation

            sample = {'rays': rays,
                      'rgbs': img,
                      'mask': mask}
        return sample


================================================
FILE: dataLoader/colmap2nerf.py
================================================
#!/usr/bin/env python3

# Copyright (c) 2020-2022, NVIDIA CORPORATION.  All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto.  Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import argparse
import os
from pathlib import Path, PurePosixPath

import numpy as np
import json
import sys
import math
import cv2
import os
import shutil

def parse_args():
	parser = argparse.ArgumentParser(description="convert a text colmap export to nerf format transforms.json; optionally convert video to images, and optionally run colmap in the first place")

	parser.add_argument("--video_in", default="", help="run ffmpeg first to convert a provided video file into a set of images. uses the video_fps parameter also")
	parser.add_argument("--video_fps", default=2)
	parser.add_argument("--time_slice", default="", help="time (in seconds) in the format t1,t2 within which the images should be generated from the video. eg: \"--time_slice '10,300'\" will generate images only from 10th second to 300th second of the video")
	parser.add_argument("--run_colmap", action="store_true", help="run colmap first on the image folder")
	parser.add_argument("--colmap_matcher", default="sequential", choices=["exhaustive","sequential","spatial","transitive","vocab_tree"], help="select which matcher colmap should use. sequential for videos, exhaustive for adhoc images")
	parser.add_argument("--colmap_db", default="colmap.db", help="colmap database filename")
	parser.add_argument("--images", default="images", help="input path to the images")
	parser.add_argument("--text", default="colmap_text", help="input path to the colmap text files (set automatically if run_colmap is used)")
	parser.add_argument("--aabb_scale", default=16, choices=["1","2","4","8","16"], help="large scene scale factor. 1=scene fits in unit cube; power of 2 up to 16")
	parser.add_argument("--skip_early", default=0, help="skip this many images from the start")
	parser.add_argument("--out", default="transforms.json", help="output path")
	args = parser.parse_args()
	return args

def do_system(arg):
	print(f"==== running: {arg}")
	err = os.system(arg)
	if err:
		print("FATAL: command failed")
		sys.exit(err)

def run_ffmpeg(args):
	if not os.path.isabs(args.images):
		args.images = os.path.join(os.path.dirname(args.video_in), args.images)
	images = args.images
	video = args.video_in
	fps = float(args.video_fps) or 1.0
	print(f"running ffmpeg with input video file={video}, output image folder={images}, fps={fps}.")
	if (input(f"warning! folder '{images}' will be deleted/replaced. continue? (Y/n)").lower().strip()+"y")[:1] != "y":
		sys.exit(1)
	try:
		shutil.rmtree(images)
	except:
		pass
	do_system(f"mkdir {images}")

	time_slice_value = ""
	time_slice = args.time_slice
	if time_slice:
	    start, end = time_slice.split(",")
	    time_slice_value = f",select='between(t\,{start}\,{end})'"
	do_system(f"ffmpeg -i {video} -qscale:v 1 -qmin 1 -vf \"fps={fps}{time_slice_value}\" {images}/%04d.jpg")

def run_colmap(args):
	db=args.colmap_db
	images=args.images
	db_noext=str(Path(db).with_suffix(""))

	if args.text=="text":
		args.text=db_noext+"_text"
	text=args.text
	sparse=db_noext+"_sparse"
	print(f"running colmap with:\n\tdb={db}\n\timages={images}\n\tsparse={sparse}\n\ttext={text}")
	if (input(f"warning! folders '{sparse}' and '{text}' will be deleted/replaced. continue? (Y/n)").lower().strip()+"y")[:1] != "y":
		sys.exit(1)
	if os.path.exists(db):
		os.remove(db)
	do_system(f"colmap feature_extractor --ImageReader.camera_model OPENCV --SiftExtraction.estimate_affine_shape=true --SiftExtraction.domain_size_pooling=true --ImageReader.single_camera 1 --database_path {db} --image_path {images}")
	do_system(f"colmap {args.colmap_matcher}_matcher --SiftMatching.guided_matching=true --database_path {db}")
	try:
		shutil.rmtree(sparse)
	except:
		pass
	do_system(f"mkdir {sparse}")
	do_system(f"colmap mapper --database_path {db} --image_path {images} --output_path {sparse}")
	do_system(f"colmap bundle_adjuster --input_path {sparse}/0 --output_path {sparse}/0 --BundleAdjustment.refine_principal_point 1")
	try:
		shutil.rmtree(text)
	except:
		pass
	do_system(f"mkdir {text}")
	do_system(f"colmap model_converter --input_path {sparse}/0 --output_path {text} --output_type TXT")

def variance_of_laplacian(image):
	return cv2.Laplacian(image, cv2.CV_64F).var()

def sharpness(imagePath):
	image = cv2.imread(imagePath)
	gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
	fm = variance_of_laplacian(gray)
	return fm

def qvec2rotmat(qvec):
	return np.array([
		[
			1 - 2 * qvec[2]**2 - 2 * qvec[3]**2,
			2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
			2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]
		], [
			2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
			1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,
			2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]
		], [
			2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
			2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
			1 - 2 * qvec[1]**2 - 2 * qvec[2]**2
		]
	])

def rotmat(a, b):
	a, b = a / np.linalg.norm(a), b / np.linalg.norm(b)
	v = np.cross(a, b)
	c = np.dot(a, b)
	s = np.linalg.norm(v)
	kmat = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]])
	return np.eye(3) + kmat + kmat.dot(kmat) * ((1 - c) / (s ** 2 + 1e-10))

def closest_point_2_lines(oa, da, ob, db): # returns point closest to both rays of form o+t*d, and a weight factor that goes to 0 if the lines are parallel
	da = da / np.linalg.norm(da)
	db = db / np.linalg.norm(db)
	c = np.cross(da, db)
	denom = np.linalg.norm(c)**2
	t = ob - oa
	ta = np.linalg.det([t, db, c]) / (denom + 1e-10)
	tb = np.linalg.det([t, da, c]) / (denom + 1e-10)
	if ta > 0:
		ta = 0
	if tb > 0:
		tb = 0
	return (oa+ta*da+ob+tb*db) * 0.5, denom

if __name__ == "__main__":
	args = parse_args()
	if args.video_in != "":
		run_ffmpeg(args)
	if args.run_colmap:
		run_colmap(args)
	AABB_SCALE = int(args.aabb_scale)
	SKIP_EARLY = int(args.skip_early)
	IMAGE_FOLDER = args.images
	TEXT_FOLDER = args.text
	OUT_PATH = args.out
	print(f"outputting to {OUT_PATH}...")
	with open(os.path.join(TEXT_FOLDER,"cameras.txt"), "r") as f:
		angle_x = math.pi / 2
		for line in f:
			# 1 SIMPLE_RADIAL 2048 1536 1580.46 1024 768 0.0045691
			# 1 OPENCV 3840 2160 3178.27 3182.09 1920 1080 0.159668 -0.231286 -0.00123982 0.00272224
			# 1 RADIAL 1920 1080 1665.1 960 540 0.0672856 -0.0761443
			if line[0] == "#":
				continue
			els = line.split(" ")
			w = float(els[2])
			h = float(els[3])
			fl_x = float(els[4])
			fl_y = float(els[4])
			k1 = 0
			k2 = 0
			p1 = 0
			p2 = 0
			cx = w / 2
			cy = h / 2
			if els[1] == "SIMPLE_PINHOLE":
				cx = float(els[5])
				cy = float(els[6])
			elif els[1] == "PINHOLE":
				fl_y = float(els[5])
				cx = float(els[6])
				cy = float(els[7])
			elif els[1] == "SIMPLE_RADIAL":
				cx = float(els[5])
				cy = float(els[6])
				k1 = float(els[7])
			elif els[1] == "RADIAL":
				cx = float(els[5])
				cy = float(els[6])
				k1 = float(els[7])
				k2 = float(els[8])
			elif els[1] == "OPENCV":
				fl_y = float(els[5])
				cx = float(els[6])
				cy = float(els[7])
				k1 = float(els[8])
				k2 = float(els[9])
				p1 = float(els[10])
				p2 = float(els[11])
			else:
				print("unknown camera model ", els[1])
			# fl = 0.5 * w / tan(0.5 * angle_x);
			angle_x = math.atan(w / (fl_x * 2)) * 2
			angle_y = math.atan(h / (fl_y * 2)) * 2
			fovx = angle_x * 180 / math.pi
			fovy = angle_y * 180 / math.pi

	print(f"camera:\n\tres={w,h}\n\tcenter={cx,cy}\n\tfocal={fl_x,fl_y}\n\tfov={fovx,fovy}\n\tk={k1,k2} p={p1,p2} ")

	with open(os.path.join(TEXT_FOLDER,"images.txt"), "r") as f:
		i = 0
		bottom = np.array([0.0, 0.0, 0.0, 1.0]).reshape([1, 4])
		out = {
			"camera_angle_x": angle_x,
			"camera_angle_y": angle_y,
			"fl_x": fl_x,
			"fl_y": fl_y,
			"k1": k1,
			"k2": k2,
			"p1": p1,
			"p2": p2,
			"cx": cx,
			"cy": cy,
			"w": w,
			"h": h,
			"aabb_scale": AABB_SCALE,
			"frames": [],
		}

		up = np.zeros(3)
		for line in f:
			line = line.strip()
			if line[0] == "#":
				continue
			i = i + 1
			if i < SKIP_EARLY*2:
				continue
			if  i % 2 == 1:
				elems=line.split(" ") # 1-4 is quat, 5-7 is trans, 9ff is filename (9, if filename contains no spaces)
				#name = str(PurePosixPath(Path(IMAGE_FOLDER, elems[9])))
				# why is this requireing a relitive path while using ^
				image_rel = os.path.relpath(IMAGE_FOLDER)
				name = str(f"./{image_rel}/{'_'.join(elems[9:])}")
				b=sharpness(name)
				print(name, "sharpness=",b)
				image_id = int(elems[0])
				qvec = np.array(tuple(map(float, elems[1:5])))
				tvec = np.array(tuple(map(float, elems[5:8])))
				R = qvec2rotmat(-qvec)
				t = tvec.reshape([3,1])
				m = np.concatenate([np.concatenate([R, t], 1), bottom], 0)
				c2w = np.linalg.inv(m)
				c2w[0:3,2] *= -1 # flip the y and z axis
				c2w[0:3,1] *= -1
				c2w = c2w[[1,0,2,3],:] # swap y and z
				c2w[2,:] *= -1 # flip whole world upside down

				up += c2w[0:3,1]

				frame={"file_path":name,"sharpness":b,"transform_matrix": c2w}
				out["frames"].append(frame)
	nframes = len(out["frames"])
	up = up / np.linalg.norm(up)
	print("up vector was", up)
	R = rotmat(up,[0,0,1]) # rotate up vector to [0,0,1]
	R = np.pad(R,[0,1])
	R[-1, -1] = 1


	for f in out["frames"]:
		f["transform_matrix"] = np.matmul(R, f["transform_matrix"]) # rotate up to be the z axis

	# find a central point they are all looking at
	print("computing center of attention...")
	totw = 0.0
	totp = np.array([0.0, 0.0, 0.0])
	for f in out["frames"]:
		mf = f["transform_matrix"][0:3,:]
		for g in out["frames"]:
			mg = g["transform_matrix"][0:3,:]
			p, w = closest_point_2_lines(mf[:,3], mf[:,2], mg[:,3], mg[:,2])
			if w > 0.01:
				totp += p*w
				totw += w
	totp /= totw
	print(totp) # the cameras are looking at totp
	for f in out["frames"]:
		f["transform_matrix"][0:3,3] -= totp

	avglen = 0.
	for f in out["frames"]:
		avglen += np.linalg.norm(f["transform_matrix"][0:3,3])
	avglen /= nframes
	print("avg camera distance from origin", avglen)
	for f in out["frames"]:
		f["transform_matrix"][0:3,3] *= 4.0 / avglen # scale to "nerf sized"

	for f in out["frames"]:
		f["transform_matrix"] = f["transform_matrix"].tolist()
	print(nframes,"frames")
	print(f"writing {OUT_PATH}")
	with open(OUT_PATH, "w") as outfile:
		json.dump(out, outfile, indent=2)

================================================
FILE: dataLoader/llff.py
================================================
import torch
from torch.utils.data import Dataset
import glob
import numpy as np
import os
from PIL import Image
from torchvision import transforms as T

from .ray_utils import *


def normalize(v):
    """Normalize a vector."""
    return v / np.linalg.norm(v)


def average_poses(poses):
    """
    Calculate the average pose, which is then used to center all poses
    using @center_poses. Its computation is as follows:
    1. Compute the center: the average of pose centers.
    2. Compute the z axis: the normalized average z axis.
    3. Compute axis y': the average y axis.
    4. Compute x' = y' cross product z, then normalize it as the x axis.
    5. Compute the y axis: z cross product x.

    Note that at step 3, we cannot directly use y' as y axis since it's
    not necessarily orthogonal to z axis. We need to pass from x to y.
    Inputs:
        poses: (N_images, 3, 4)
    Outputs:
        pose_avg: (3, 4) the average pose
    """
    # 1. Compute the center
    center = poses[..., 3].mean(0)  # (3)

    # 2. Compute the z axis
    z = normalize(poses[..., 2].mean(0))  # (3)

    # 3. Compute axis y' (no need to normalize as it's not the final output)
    y_ = poses[..., 1].mean(0)  # (3)

    # 4. Compute the x axis
    x = normalize(np.cross(z, y_))  # (3)

    # 5. Compute the y axis (as z and x are normalized, y is already of norm 1)
    y = np.cross(x, z)  # (3)

    pose_avg = np.stack([x, y, z, center], 1)  # (3, 4)

    return pose_avg


def center_poses(poses, blender2opencv):
    """
    Center the poses so that we can use NDC.
    See https://github.com/bmild/nerf/issues/34
    Inputs:
        poses: (N_images, 3, 4)
    Outputs:
        poses_centered: (N_images, 3, 4) the centered poses
        pose_avg: (3, 4) the average pose
    """
    poses = poses @ blender2opencv
    pose_avg = average_poses(poses)  # (3, 4)
    pose_avg_homo = np.eye(4)
    pose_avg_homo[:3] = pose_avg  # convert to homogeneous coordinate for faster computation
    pose_avg_homo = pose_avg_homo
    # by simply adding 0, 0, 0, 1 as the last row
    last_row = np.tile(np.array([0, 0, 0, 1]), (len(poses), 1, 1))  # (N_images, 1, 4)
    poses_homo = \
        np.concatenate([poses, last_row], 1)  # (N_images, 4, 4) homogeneous coordinate

    poses_centered = np.linalg.inv(pose_avg_homo) @ poses_homo  # (N_images, 4, 4)
    #     poses_centered = poses_centered  @ blender2opencv
    poses_centered = poses_centered[:, :3]  # (N_images, 3, 4)

    return poses_centered, pose_avg_homo


def viewmatrix(z, up, pos):
    vec2 = normalize(z)
    vec1_avg = up
    vec0 = normalize(np.cross(vec1_avg, vec2))
    vec1 = normalize(np.cross(vec2, vec0))
    m = np.eye(4)
    m[:3] = np.stack([-vec0, vec1, vec2, pos], 1)
    return m


def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, N_rots=2, N=120):
    render_poses = []
    rads = np.array(list(rads) + [1.])

    for theta in np.linspace(0., 2. * np.pi * N_rots, N + 1)[:-1]:
        c = np.dot(c2w[:3, :4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.]) * rads)
        z = normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.])))
        render_poses.append(viewmatrix(z, up, c))
    return render_poses


def get_spiral(c2ws_all, near_fars, rads_scale=1.0, N_views=120):
    # center pose
    c2w = average_poses(c2ws_all)

    # Get average pose
    up = normalize(c2ws_all[:, :3, 1].sum(0))

    # Find a reasonable "focus depth" for this dataset
    dt = 0.75
    close_depth, inf_depth = near_fars.min() * 0.9, near_fars.max() * 5.0
    focal = 1.0 / (((1.0 - dt) / close_depth + dt / inf_depth))

    # Get radii for spiral path
    zdelta = near_fars.min() * .2
    tt = c2ws_all[:, :3, 3]
    rads = np.percentile(np.abs(tt), 90, 0) * rads_scale
    render_poses = render_path_spiral(c2w, up, rads, focal, zdelta, zrate=.5, N=N_views)
    return np.stack(render_poses)

def get_interpolation_path(c2ws_all, steps=30):
    
    # flower
    # idx0 = 1
    # idx1 = 10

    # trex
    # idx0 = 8
    # idx1 = 53

    # horns
    idx0 = 18
    idx1 = 47

    v = np.linspace(0,1,num=steps)

    c2w0 = c2ws_all[idx0]
    c2w1 = c2ws_all[idx1]

    c2w_ = []
    for i in range(steps):
        c2w_.append(c2w0*v[i] + c2w1*(1-v[i]))

    return np.stack(c2w_)


class LLFFDataset(Dataset):
    def __init__(self, datadir, split='train', downsample=4, is_stack=False, hold_every=8):

        self.root_dir = datadir
        self.split = split
        self.hold_every = hold_every
        self.is_stack = is_stack
        self.downsample = downsample
        self.define_transforms()

        self.blender2opencv = np.eye(4)#np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
        self.read_meta()
        self.white_bg = False

        #         self.near_far = [np.min(self.near_fars[:,0]),np.max(self.near_fars[:,1])]
        self.near_far = [0.0, 1.0]
        self.scene_bbox = torch.tensor([[-1.5, -1.67, -1.0], [1.5, 1.67, 1.0]])
        # self.scene_bbox = torch.tensor([[-1.67, -1.5, -1.0], [1.67, 1.5, 1.0]])
        self.center = torch.mean(self.scene_bbox, dim=0).float().view(1, 1, 3)
        self.invradius = 1.0 / (self.scene_bbox[1] - self.center).float().view(1, 1, 3)

    def read_meta(self):


        poses_bounds = np.load(os.path.join(self.root_dir, 'poses_bounds.npy'))  # (N_images, 17)
        self.image_paths = sorted(glob.glob(os.path.join(self.root_dir, 'images_4/*')))
        # load full resolution image then resize
        if self.split in ['train', 'test']:
            assert len(poses_bounds) == len(self.image_paths), \
                'Mismatch between number of images and number of poses! Please rerun COLMAP!'

        poses = poses_bounds[:, :15].reshape(-1, 3, 5)  # (N_images, 3, 5)
        self.near_fars = poses_bounds[:, -2:]  # (N_images, 2)
        hwf = poses[:, :, -1]

        # Step 1: rescale focal length according to training resolution
        H, W, self.focal = poses[0, :, -1]  # original intrinsics, same for all images
        self.img_wh = np.array([int(W / self.downsample), int(H / self.downsample)])
        self.focal = [self.focal * self.img_wh[0] / W, self.focal * self.img_wh[1] / H]

        # Step 2: correct poses
        # Original poses has rotation in form "down right back", change to "right up back"
        # See https://github.com/bmild/nerf/issues/34
        poses = np.concatenate([poses[..., 1:2], -poses[..., :1], poses[..., 2:4]], -1)
        # (N_images, 3, 4) exclude H, W, focal
        self.poses, self.pose_avg = center_poses(poses, self.blender2opencv)

        # Step 3: correct scale so that the nearest depth is at a little more than 1.0
        # See https://github.com/bmild/nerf/issues/34
        near_original = self.near_fars.min()
        scale_factor = near_original * 0.75  # 0.75 is the default parameter
        # the nearest depth is at 1/0.75=1.33
        self.near_fars /= scale_factor
        self.poses[..., 3] /= scale_factor

        # build rendering path
        N_views, N_rots = 120, 2
        tt = self.poses[:, :3, 3]  # ptstocam(poses[:3,3,:].T, c2w).T
        up = normalize(self.poses[:, :3, 1].sum(0))
        rads = np.percentile(np.abs(tt), 90, 0)

        self.render_path = get_spiral(self.poses, self.near_fars, N_views=N_views)
        # self.render_path = get_interpolation_path(self.poses)

        # distances_from_center = np.linalg.norm(self.poses[..., 3], axis=1)
        # val_idx = np.argmin(distances_from_center)  # choose val image as the closest to
        # center image

        # ray directions for all pixels, same for all images (same H, W, focal)
        W, H = self.img_wh
        self.directions = get_ray_directions_blender(H, W, self.focal)  # (H, W, 3)

        average_pose = average_poses(self.poses)
        dists = np.sum(np.square(average_pose[:3, 3] - self.poses[:, :3, 3]), -1)
        i_test = np.arange(0, self.poses.shape[0], self.hold_every)  # [np.argmin(dists)]
        img_list = i_test if self.split != 'train' else list(set(np.arange(len(self.poses))) - set(i_test))

        # use first N_images-1 to train, the LAST is val
        self.all_rays = []
        self.all_rgbs = []
        for i in img_list:
            image_path = self.image_paths[i]
            c2w = torch.FloatTensor(self.poses[i])

            img = Image.open(image_path).convert('RGB')
            if self.downsample != 1.0:
                img = img.resize(self.img_wh, Image.LANCZOS)
            img = self.transform(img)  # (3, h, w)

            img = img.view(3, -1).permute(1, 0)  # (h*w, 3) RGB
            self.all_rgbs += [img]
            rays_o, rays_d = get_rays(self.directions, c2w)  # both (h*w, 3)
            rays_o, rays_d = ndc_rays_blender(H, W, self.focal[0], 1.0, rays_o, rays_d)
            # viewdir = rays_d / torch.norm(rays_d, dim=-1, keepdim=True)

            self.all_rays += [torch.cat([rays_o, rays_d], 1)]  # (h*w, 6)

        all_rays = self.all_rays
        all_rgbs = self.all_rgbs

        self.all_rays = torch.cat(self.all_rays, 0) # (len(self.meta['frames])*h*w,6)
        self.all_rgbs = torch.cat(self.all_rgbs, 0) # (len(self.meta['frames])*h*w,3)

        if self.is_stack:
            self.all_rays_stack = torch.stack(all_rays, 0).reshape(-1,*self.img_wh[::-1], 6)  # (len(self.meta['frames]),h,w,6)
            avg_pool = torch.nn.AvgPool2d(4, ceil_mode=True)
            self.ds_all_rays_stack = avg_pool(self.all_rays_stack.permute(0,3,1,2)).permute(0,2,3,1) # (len(self.meta['frames]),h/4,w/4,6)
            self.all_rgbs_stack = torch.stack(all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3)  # (len(self.meta['frames]),h,w,3)

    @torch.no_grad()
    def prepare_feature_data(self, encoder, chunk=8):
        '''
        Prepare feature maps as training data.
        '''
        assert self.is_stack, 'Dataset should contain original stacked taining data!'
        print('====> prepare_feature_data ...')

        frames_num, h, w, _ = self.all_rgbs_stack.size()
        features = []

        for chunk_idx in range(frames_num // chunk + int(frames_num % chunk > 0)):
            rgbs_chunk = self.all_rgbs_stack[chunk_idx*chunk : (chunk_idx+1)*chunk].cuda()
            features_chunk = encoder(normalize_vgg(rgbs_chunk.permute(0,3,1,2))).relu3_1
            # resize to the size of rgb map so that rays can match
            features_chunk = T.functional.resize(features_chunk, size=(h,w), 
                                                 interpolation=T.InterpolationMode.BILINEAR)
            features.append(features_chunk.detach().cpu().requires_grad_(False))

        self.all_features_stack = torch.cat(features).permute(0,2,3,1) # (len(self.meta['frames]),h,w,256)
        self.all_features = self.all_features_stack.reshape(-1, 256)
        print('prepare_feature_data Done!')

    def define_transforms(self):
        self.transform = T.ToTensor()

    def __len__(self):
        return len(self.all_rgbs)

    def __getitem__(self, idx):

        sample = {'rays': self.all_rays[idx],
                  'rgbs': self.all_rgbs[idx]}

        return sample

================================================
FILE: dataLoader/nsvf.py
================================================
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
import os
from PIL import Image
from torchvision import transforms as T

from .ray_utils import *

trans_t = lambda t : torch.Tensor([
    [1,0,0,0],
    [0,1,0,0],
    [0,0,1,t],
    [0,0,0,1]]).float()

rot_phi = lambda phi : torch.Tensor([
    [1,0,0,0],
    [0,np.cos(phi),-np.sin(phi),0],
    [0,np.sin(phi), np.cos(phi),0],
    [0,0,0,1]]).float()

rot_theta = lambda th : torch.Tensor([
    [np.cos(th),0,-np.sin(th),0],
    [0,1,0,0],
    [np.sin(th),0, np.cos(th),0],
    [0,0,0,1]]).float()


def pose_spherical(theta, phi, radius):
    c2w = trans_t(radius)
    c2w = rot_phi(phi/180.*np.pi) @ c2w
    c2w = rot_theta(theta/180.*np.pi) @ c2w
    c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w
    return c2w

class NSVF(Dataset):
    """NSVF Generic Dataset."""
    def __init__(self, datadir, split='train', downsample=1.0, wh=[800,800], is_stack=False):
        self.root_dir = datadir
        self.split = split
        self.is_stack = is_stack
        self.downsample = downsample
        self.img_wh = (int(wh[0]/downsample),int(wh[1]/downsample))
        self.define_transforms()

        self.white_bg = True
        self.near_far = [0.5,6.0]
        self.scene_bbox = torch.from_numpy(np.loadtxt(f'{self.root_dir}/bbox.txt')).float()[:6].view(2,3)
        self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
        self.read_meta()
        self.define_proj_mat()
        
        self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3)
        self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)
    
    def bbox2corners(self):
        corners = self.scene_bbox.unsqueeze(0).repeat(4,1,1)
        for i in range(3):
            corners[i,[0,1],i] = corners[i,[1,0],i] 
        return corners.view(-1,3)
        
        
    def read_meta(self):
        with open(os.path.join(self.root_dir, "intrinsics.txt")) as f:
            focal = float(f.readline().split()[0])
        self.intrinsics = np.array([[focal,0,400.0],[0,focal,400.0],[0,0,1]])
        self.intrinsics[:2] *= (np.array(self.img_wh)/np.array([800,800])).reshape(2,1)

        pose_files = sorted(os.listdir(os.path.join(self.root_dir, 'pose')))
        img_files  = sorted(os.listdir(os.path.join(self.root_dir, 'rgb')))

        if self.split == 'train':
            pose_files = [x for x in pose_files if x.startswith('0_')]
            img_files = [x for x in img_files if x.startswith('0_')]
        elif self.split == 'val':
            pose_files = [x for x in pose_files if x.startswith('1_')]
            img_files = [x for x in img_files if x.startswith('1_')]
        elif self.split == 'test':
            test_pose_files = [x for x in pose_files if x.startswith('2_')]
            test_img_files = [x for x in img_files if x.startswith('2_')]
            if len(test_pose_files) == 0:
                test_pose_files = [x for x in pose_files if x.startswith('1_')]
                test_img_files = [x for x in img_files if x.startswith('1_')]
            pose_files = test_pose_files
            img_files = test_img_files

        # ray directions for all pixels, same for all images (same H, W, focal)
        self.directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsics[0,0],self.intrinsics[1,1]], center=self.intrinsics[:2,2])  # (h, w, 3)
        self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)


        self.render_path = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,40+1)[:-1]], 0)
        
        self.poses = []
        self.all_rays = []
        self.all_rgbs = []

        assert len(img_files) == len(pose_files)
        for img_fname, pose_fname in tqdm(zip(img_files, pose_files), desc=f'Loading data {self.split} ({len(img_files)})'):
            image_path = os.path.join(self.root_dir, 'rgb', img_fname)
            img = Image.open(image_path)
            if self.downsample!=1.0:
                img = img.resize(self.img_wh, Image.LANCZOS)
            img = self.transform(img)  # (4, h, w)
            img = img.view(img.shape[0], -1).permute(1, 0)  # (h*w, 4) RGBA
            if img.shape[-1]==4:
                img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:])  # blend A to RGB
            self.all_rgbs += [img]

            c2w = np.loadtxt(os.path.join(self.root_dir, 'pose', pose_fname)) #@ self.blender2opencv
            c2w = torch.FloatTensor(c2w)
            self.poses.append(c2w)  # C2W
            rays_o, rays_d = get_rays(self.directions, c2w)  # both (h*w, 3)
            self.all_rays += [torch.cat([rays_o, rays_d], 1)]  # (h*w, 8)
            
#             w2c = torch.inverse(c2w)
#

        self.poses = torch.stack(self.poses)
        if 'train' == self.split:
            if self.is_stack:
                self.all_rays = torch.stack(self.all_rays, 0).reshape(-1,*self.img_wh[::-1], 6)  # (len(self.meta['frames])*h*w, 3)
                self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3)  # (len(self.meta['frames])*h*w, 3) 
            else:
                self.all_rays = torch.cat(self.all_rays, 0)  # (len(self.meta['frames])*h*w, 3)
                self.all_rgbs = torch.cat(self.all_rgbs, 0)  # (len(self.meta['frames])*h*w, 3)
        else:
            self.all_rays = torch.stack(self.all_rays, 0)  # (len(self.meta['frames]),h*w, 3)
            self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3)  # (len(self.meta['frames]),h,w,3)

 
    def define_transforms(self):
        self.transform = T.ToTensor()
        
    def define_proj_mat(self):
        self.proj_mat = torch.from_numpy(self.intrinsics[:3,:3]).unsqueeze(0).float() @ torch.inverse(self.poses)[:,:3]

    def world2ndc(self, points):
        device = points.device
        return (points - self.center.to(device)) / self.radius.to(device)
        
    def __len__(self):
        if self.split == 'train':
            return len(self.all_rays)
        return len(self.all_rgbs)

    def __getitem__(self, idx):

        if self.split == 'train':  # use data in the buffers
            sample = {'rays': self.all_rays[idx],
                      'rgbs': self.all_rgbs[idx]}

        else:  # create data for each image separately

            img = self.all_rgbs[idx]
            rays = self.all_rays[idx]

            sample = {'rays': rays,
                      'rgbs': img}
        return sample

================================================
FILE: dataLoader/ray_utils.py
================================================
import torch, re
import numpy as np
from torch import searchsorted
from kornia import create_meshgrid


# from utils import index_point_feature

def depth2dist(z_vals, cos_angle):
    # z_vals: [N_ray N_sample]
    device = z_vals.device
    dists = z_vals[..., 1:] - z_vals[..., :-1]
    dists = torch.cat([dists, torch.Tensor([1e10]).to(device).expand(dists[..., :1].shape)], -1)  # [N_rays, N_samples]
    dists = dists * cos_angle.unsqueeze(-1)
    return dists


def ndc2dist(ndc_pts, cos_angle):
    dists = torch.norm(ndc_pts[:, 1:] - ndc_pts[:, :-1], dim=-1)
    dists = torch.cat([dists, 1e10 * cos_angle.unsqueeze(-1)], -1)  # [N_rays, N_samples]
    return dists


def get_ray_directions(H, W, focal, center=None):
    """
    Get ray directions for all pixels in camera coordinate.
    Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
               ray-tracing-generating-camera-rays/standard-coordinate-systems
    Inputs:
        H, W, focal: image height, width and focal length
    Outputs:
        directions: (H, W, 3), the direction of the rays in camera coordinate
    """
    grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5

    i, j = grid.unbind(-1)
    # the direction here is without +0.5 pixel centering as calibration is not so accurate
    # see https://github.com/bmild/nerf/issues/24
    cent = center if center is not None else [W / 2, H / 2]
    directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1)  # (H, W, 3)

    return directions


def get_ray_directions_blender(H, W, focal, center=None):
    """
    Get ray directions for all pixels in camera coordinate.
    Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
               ray-tracing-generating-camera-rays/standard-coordinate-systems
    Inputs:
        H, W, focal: image height, width and focal length
    Outputs:
        directions: (H, W, 3), the direction of the rays in camera coordinate
    """
    grid = create_meshgrid(H, W, normalized_coordinates=False)[0]+0.5
    i, j = grid.unbind(-1)
    # the direction here is without +0.5 pixel centering as calibration is not so accurate
    # see https://github.com/bmild/nerf/issues/24
    cent = center if center is not None else [W / 2, H / 2]
    directions = torch.stack([(i - cent[0]) / focal[0], -(j - cent[1]) / focal[1], -torch.ones_like(i)],
                             -1)  # (H, W, 3)

    return directions


def get_rays(directions, c2w):
    """
    Get ray origin and normalized directions in world coordinate for all pixels in one image.
    Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
               ray-tracing-generating-camera-rays/standard-coordinate-systems
    Inputs:
        directions: (H, W, 3) precomputed ray directions in camera coordinate
        c2w: (3, 4) transformation matrix from camera coordinate to world coordinate
    Outputs:
        rays_o: (H*W, 3), the origin of the rays in world coordinate
        rays_d: (H*W, 3), the normalized direction of the rays in world coordinate
    """
    # Rotate ray directions from camera coordinate to the world coordinate
    rays_d = directions @ c2w[:3, :3].T  # (H, W, 3)
    # rays_d = rays_d / torch.norm(rays_d, dim=-1, keepdim=True)
    # The origin of all rays is the camera origin in world coordinate
    rays_o = c2w[:3, 3].expand(rays_d.shape)  # (H, W, 3)

    rays_d = rays_d.view(-1, 3)
    rays_o = rays_o.view(-1, 3)

    return rays_o, rays_d


def ndc_rays_blender(H, W, focal, near, rays_o, rays_d):
    # Shift ray origins to near plane
    t = -(near + rays_o[..., 2]) / rays_d[..., 2]
    rays_o = rays_o + t[..., None] * rays_d

    # Projection
    o0 = -1. / (W / (2. * focal)) * rays_o[..., 0] / rays_o[..., 2]
    o1 = -1. / (H / (2. * focal)) * rays_o[..., 1] / rays_o[..., 2]
    o2 = 1. + 2. * near / rays_o[..., 2]

    d0 = -1. / (W / (2. * focal)) * (rays_d[..., 0] / rays_d[..., 2] - rays_o[..., 0] / rays_o[..., 2])
    d1 = -1. / (H / (2. * focal)) * (rays_d[..., 1] / rays_d[..., 2] - rays_o[..., 1] / rays_o[..., 2])
    d2 = -2. * near / rays_o[..., 2]

    rays_o = torch.stack([o0, o1, o2], -1)
    rays_d = torch.stack([d0, d1, d2], -1)

    return rays_o, rays_d

def ndc_rays(H, W, focal, near, rays_o, rays_d):
    # Shift ray origins to near plane
    t = (near - rays_o[..., 2]) / rays_d[..., 2]
    rays_o = rays_o + t[..., None] * rays_d

    # Projection
    o0 = 1. / (W / (2. * focal)) * rays_o[..., 0] / rays_o[..., 2]
    o1 = 1. / (H / (2. * focal)) * rays_o[..., 1] / rays_o[..., 2]
    o2 = 1. - 2. * near / rays_o[..., 2]

    d0 = 1. / (W / (2. * focal)) * (rays_d[..., 0] / rays_d[..., 2] - rays_o[..., 0] / rays_o[..., 2])
    d1 = 1. / (H / (2. * focal)) * (rays_d[..., 1] / rays_d[..., 2] - rays_o[..., 1] / rays_o[..., 2])
    d2 = 2. * near / rays_o[..., 2]

    rays_o = torch.stack([o0, o1, o2], -1)
    rays_d = torch.stack([d0, d1, d2], -1)

    return rays_o, rays_d

# Hierarchical sampling (section 5.2)
def sample_pdf(bins, weights, N_samples, det=False, pytest=False):
    device = weights.device
    # Get pdf
    weights = weights + 1e-5  # prevent nans
    pdf = weights / torch.sum(weights, -1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)  # (batch, len(bins))

    # Take uniform samples
    if det:
        u = torch.linspace(0., 1., steps=N_samples, device=device)
        u = u.expand(list(cdf.shape[:-1]) + [N_samples])
    else:
        u = torch.rand(list(cdf.shape[:-1]) + [N_samples], device=device)

    # Pytest, overwrite u with numpy's fixed random numbers
    if pytest:
        np.random.seed(0)
        new_shape = list(cdf.shape[:-1]) + [N_samples]
        if det:
            u = np.linspace(0., 1., N_samples)
            u = np.broadcast_to(u, new_shape)
        else:
            u = np.random.rand(*new_shape)
        u = torch.Tensor(u)

    # Invert CDF
    u = u.contiguous()
    inds = searchsorted(cdf.detach(), u, right=True)
    below = torch.max(torch.zeros_like(inds - 1), inds - 1)
    above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
    inds_g = torch.stack([below, above], -1)  # (batch, N_samples, 2)

    matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
    cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
    bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)

    denom = (cdf_g[..., 1] - cdf_g[..., 0])
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
    t = (u - cdf_g[..., 0]) / denom
    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])

    return samples


def dda(rays_o, rays_d, bbox_3D):
    inv_ray_d = 1.0 / (rays_d + 1e-6)
    t_min = (bbox_3D[:1] - rays_o) * inv_ray_d  # N_rays 3
    t_max = (bbox_3D[1:] - rays_o) * inv_ray_d
    t = torch.stack((t_min, t_max))  # 2 N_rays 3
    t_min = torch.max(torch.min(t, dim=0)[0], dim=-1, keepdim=True)[0]
    t_max = torch.min(torch.max(t, dim=0)[0], dim=-1, keepdim=True)[0]
    return t_min, t_max


def ray_marcher(rays,
                N_samples=64,
                lindisp=False,
                perturb=0,
                bbox_3D=None):
    """
    sample points along the rays
    Inputs:
        rays: ()

    Returns:

    """

    # Decompose the inputs
    N_rays = rays.shape[0]
    rays_o, rays_d = rays[:, 0:3], rays[:, 3:6]  # both (N_rays, 3)
    near, far = rays[:, 6:7], rays[:, 7:8]  # both (N_rays, 1)

    if bbox_3D is not None:
        # cal aabb boundles
        near, far = dda(rays_o, rays_d, bbox_3D)

    # Sample depth points
    z_steps = torch.linspace(0, 1, N_samples, device=rays.device)  # (N_samples)
    if not lindisp:  # use linear sampling in depth space
        z_vals = near * (1 - z_steps) + far * z_steps
    else:  # use linear sampling in disparity space
        z_vals = 1 / (1 / near * (1 - z_steps) + 1 / far * z_steps)

    z_vals = z_vals.expand(N_rays, N_samples)

    if perturb > 0:  # perturb sampling depths (z_vals)
        z_vals_mid = 0.5 * (z_vals[:, :-1] + z_vals[:, 1:])  # (N_rays, N_samples-1) interval mid points
        # get intervals between samples
        upper = torch.cat([z_vals_mid, z_vals[:, -1:]], -1)
        lower = torch.cat([z_vals[:, :1], z_vals_mid], -1)

        perturb_rand = perturb * torch.rand(z_vals.shape, device=rays.device)
        z_vals = lower + (upper - lower) * perturb_rand

    xyz_coarse_sampled = rays_o.unsqueeze(1) + \
                         rays_d.unsqueeze(1) * z_vals.unsqueeze(2)  # (N_rays, N_samples, 3)

    return xyz_coarse_sampled, rays_o, rays_d, z_vals


def read_pfm(filename):
    file = open(filename, 'rb')
    color = None
    width = None
    height = None
    scale = None
    endian = None

    header = file.readline().decode('utf-8').rstrip()
    if header == 'PF':
        color = True
    elif header == 'Pf':
        color = False
    else:
        raise Exception('Not a PFM file.')

    dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8'))
    if dim_match:
        width, height = map(int, dim_match.groups())
    else:
        raise Exception('Malformed PFM header.')

    scale = float(file.readline().rstrip())
    if scale < 0:  # little-endian
        endian = '<'
        scale = -scale
    else:
        endian = '>'  # big-endian

    data = np.fromfile(file, endian + 'f')
    shape = (height, width, 3) if color else (height, width)

    data = np.reshape(data, shape)
    data = np.flipud(data)
    file.close()
    return data, scale


def ndc_bbox(all_rays):
    near_min = torch.min(all_rays[...,:3].view(-1,3),dim=0)[0]
    near_max = torch.max(all_rays[..., :3].view(-1, 3), dim=0)[0]
    far_min = torch.min((all_rays[...,:3]+all_rays[...,3:6]).view(-1,3),dim=0)[0]
    far_max = torch.max((all_rays[...,:3]+all_rays[...,3:6]).view(-1, 3), dim=0)[0]
    print(f'===> ndc bbox near_min:{near_min} near_max:{near_max} far_min:{far_min} far_max:{far_max}')
    return torch.stack((torch.minimum(near_min,far_min),torch.maximum(near_max,far_max)))

import torchvision
normalize_vgg = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                                 std=[0.229, 0.224, 0.225])

def denormalize_vgg(img):
    im = img.clone()
    im[:, 0, :, :] *= 0.229
    im[:, 1, :, :] *= 0.224
    im[:, 2, :, :] *= 0.225
    im[:, 0, :, :] += 0.485
    im[:, 1, :, :] += 0.456
    im[:, 2, :, :] += 0.406
    return im

================================================
FILE: dataLoader/styleLoader.py
================================================
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as T


def getDataLoader(dataset_path, batch_size, sampler, image_side_length=256, num_workers=2):
    transform = T.Compose([
                T.Resize(size=(image_side_length*2, image_side_length*2)),
                T.RandomCrop(image_side_length),
                T.ToTensor(),
            ])

    train_dataset = datasets.ImageFolder(dataset_path, transform=transform)
    dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler(len(train_dataset)), num_workers=num_workers)

    return dataloader

================================================
FILE: dataLoader/tankstemple.py
================================================
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
import os
from PIL import Image
from torchvision import transforms as T
import random

from .ray_utils import *


def circle(radius=3.5, h=0.0, axis='z', t0=0, r=1):
    if axis == 'z':
        return lambda t: [radius * np.cos(r * t + t0), radius * np.sin(r * t + t0), h]
    elif axis == 'y':
        return lambda t: [radius * np.cos(r * t + t0), h, radius * np.sin(r * t + t0)]
    else:
        return lambda t: [h, radius * np.cos(r * t + t0), radius * np.sin(r * t + t0)]


def cross(x, y, axis=0):
    T = torch if isinstance(x, torch.Tensor) else np
    return T.cross(x, y, axis)


def normalize(x, axis=-1, order=2):
    if isinstance(x, torch.Tensor):
        l2 = x.norm(p=order, dim=axis, keepdim=True)
        return x / (l2 + 1e-8), l2

    else:
        l2 = np.linalg.norm(x, order, axis)
        l2 = np.expand_dims(l2, axis)
        l2[l2 == 0] = 1
        return x / l2,


def cat(x, axis=1):
    if isinstance(x[0], torch.Tensor):
        return torch.cat(x, dim=axis)
    return np.concatenate(x, axis=axis)


def look_at_rotation(camera_position, at=None, up=None, inverse=False, cv=False):
    """
    This function takes a vector 'camera_position' which specifies the location
    of the camera in world coordinates and two vectors `at` and `up` which
    indicate the position of the object and the up directions of the world
    coordinate system respectively. The object is assumed to be centered at
    the origin.
    The output is a rotation matrix representing the transformation
    from world coordinates -> view coordinates.
    Input:
        camera_position: 3
        at: 1 x 3 or N x 3  (0, 0, 0) in default
        up: 1 x 3 or N x 3  (0, 1, 0) in default
    """

    if at is None:
        at = torch.zeros_like(camera_position)
    else:
        at = torch.tensor(at).type_as(camera_position)
    if up is None:
        up = torch.zeros_like(camera_position)
        up[2] = -1
    else:
        up = torch.tensor(up).type_as(camera_position)

    z_axis = normalize(at - camera_position)[0]
    x_axis = normalize(cross(up, z_axis))[0]
    y_axis = normalize(cross(z_axis, x_axis))[0]

    R = cat([x_axis[:, None], y_axis[:, None], z_axis[:, None]], axis=1)
    return R


def gen_path(pos_gen, at=(0, 0, 0), up=(0, -1, 0), frames=180):
    c2ws = []
    for t in range(frames):
        c2w = torch.eye(4)
        cam_pos = torch.tensor(pos_gen(t * (360.0 / frames) / 180 * np.pi))
        cam_rot = look_at_rotation(cam_pos, at=at, up=up, inverse=False, cv=True)
        c2w[:3, 3], c2w[:3, :3] = cam_pos, cam_rot
        c2ws.append(c2w)
    return torch.stack(c2ws)

class TanksTempleDataset(Dataset):
    """NSVF Generic Dataset."""
    def __init__(self, datadir, split='train', downsample=4.0, wh=[1920,1080], is_stack=False):
        self.root_dir = datadir
        self.split = split
        self.is_stack = is_stack
        self.downsample = downsample
        self.img_wh = (int(wh[0]/downsample),int(wh[1]/downsample))
        self.define_transforms()

        self.white_bg = True
        self.near_far = [0.01,6.0]
        self.scene_bbox = torch.from_numpy(np.loadtxt(f'{self.root_dir}/bbox.txt')).float()[:6].view(2,3)*1.2

        self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
        self.read_meta()
        self.define_proj_mat()
        
        self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3)
        self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)
    
    def bbox2corners(self):
        corners = self.scene_bbox.unsqueeze(0).repeat(4,1,1)
        for i in range(3):
            corners[i,[0,1],i] = corners[i,[1,0],i] 
        return corners.view(-1,3)
        
        
    def read_meta(self):

        self.intrinsics = np.loadtxt(os.path.join(self.root_dir, "intrinsics.txt"))
        self.intrinsics[:2] *= (np.array(self.img_wh)/np.array([1920,1080])).reshape(2,1)
        pose_files = sorted(os.listdir(os.path.join(self.root_dir, 'pose')))
        img_files  = sorted(os.listdir(os.path.join(self.root_dir, 'rgb')))

        if self.split == 'train':
            pose_files = [x for idx,x in enumerate(pose_files) if x.startswith('0_') and idx%3==0]
            img_files = [x for idx,x in enumerate(img_files) if x.startswith('0_') and idx%3==0]
        elif self.split == 'test':
            pose_files = [x for idx,x in enumerate(pose_files) if x.startswith('2_') and idx%3==0]
            img_files = [x for idx,x in enumerate(img_files) if x.startswith('2_') and idx%3==0]
            if len(test_pose_files) == 0:
                test_pose_files = [x for idx,x in enumerate(pose_files) if x.startswith('1_') and idx%3==0]
                test_img_files = [x for idx,x in enumerate(img_files) if x.startswith('1_') and idx%3==0]
            pose_files = test_pose_files
            img_files = test_img_files

        

        # ray directions for all pixels, same for all images (same H, W, focal)
        self.directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsics[0,0],self.intrinsics[1,1]], center=self.intrinsics[:2,2])  # (h, w, 3)
        self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)

        w, h = self.img_wh
        
        self.poses = []
        self.all_rays = []
        self.all_rgbs = []
        self.all_masks = []

        assert len(img_files) == len(pose_files)
        for img_fname, pose_fname in tqdm(zip(img_files, pose_files), desc=f'Loading data {self.split} ({len(img_files)})'):
            image_path = os.path.join(self.root_dir, 'rgb', img_fname)
            img = Image.open(image_path)
            if self.downsample!=1.0:
                img = img.resize(self.img_wh, Image.LANCZOS)
            img = self.transform(img)  # (3, h, w)
            img = img.view(img.shape[0], -1).permute(1, 0)  # (h*w, 3) RGBA
            mask =  torch.where(
                img.sum(-1, keepdim=True) == 3.,
                1.,
                0.
            )
            self.all_masks.append(mask.reshape(h,w,1)) # (h, w, 1) A

            if img.shape[-1]==4:
                img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:])  # blend A to RGB
            self.all_rgbs.append(img)
            

            c2w = np.loadtxt(os.path.join(self.root_dir, 'pose', pose_fname))# @ cam_trans
            c2w = torch.FloatTensor(c2w)
            self.poses.append(c2w)  # C2W
            rays_o, rays_d = get_rays(self.directions, c2w)  # both (h*w, 3)
            self.all_rays += [torch.cat([rays_o, rays_d], 1)]  # (h*w, 8)

        self.poses = torch.stack(self.poses)

        center = torch.mean(self.scene_bbox, dim=0)
        radius = torch.norm(self.scene_bbox[1]-center)*1.2
        up = torch.mean(self.poses[:, :3, 1], dim=0).tolist()
        pos_gen = circle(radius=radius, h=-0.2*up[1], axis='y')
        self.render_path = gen_path(pos_gen, up=up,frames=100)
        self.render_path[:, :3, 3] += center


        all_rays = self.all_rays
        all_rgbs = self.all_rgbs
        self.all_masks = torch.stack(self.all_masks) # (n_frames, h, w, 1)
        self.all_rays = torch.cat(self.all_rays, 0) # (len(self.meta['frames])*h*w,6)
        self.all_rgbs = torch.cat(self.all_rgbs, 0) # (len(self.meta['frames])*h*w,3)

        if self.is_stack:
            self.all_rays_stack = torch.stack(all_rays, 0).reshape(-1,*self.img_wh[::-1], 6)  # (len(self.meta['frames]),h,w,6)
            avg_pool = torch.nn.AvgPool2d(4, ceil_mode=True)
            self.ds_all_rays_stack = avg_pool(self.all_rays_stack.permute(0,3,1,2)).permute(0,2,3,1) # (len(self.meta['frames]),h/4,w/4,6)
            self.all_rgbs_stack = torch.stack(all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3)  # (len(self.meta['frames]),h,w,3)

    @torch.no_grad()
    def prepare_feature_data(self, encoder, chunk=4):
        '''
        Prepare feature maps as training data.
        '''
        assert self.is_stack, 'Dataset should contain original stacked taining data!'
        print('====> prepare_feature_data ...')

        frames_num, h, w, _ = self.all_rgbs_stack.size()
        features = []

        for chunk_idx in tqdm(range(frames_num // chunk + int(frames_num % chunk > 0))):
            rgbs_chunk = self.all_rgbs_stack[chunk_idx*chunk : (chunk_idx+1)*chunk].cuda()
            features_chunk = encoder(normalize_vgg(rgbs_chunk.permute(0,3,1,2))).relu3_1
            # resize to the size of rgb map so that rays can match
            features_chunk = T.functional.resize(features_chunk, size=(h,w), 
                                                 interpolation=T.InterpolationMode.BILINEAR)
            features.append(features_chunk.detach().cpu().requires_grad_(False))

        self.all_features_stack = torch.cat(features).permute(0,2,3,1) # (len(self.meta['frames]),h,w,256)
        self.all_features = self.all_features_stack.reshape(-1, 256)
        print('prepare_feature_data Done!')


    def define_transforms(self):
        self.transform = T.ToTensor()
        
    def define_proj_mat(self):
        self.proj_mat = torch.from_numpy(self.intrinsics[:3,:3]).unsqueeze(0).float() @ torch.inverse(self.poses)[:,:3]

    def world2ndc(self, points):
        device = points.device
        return (points - self.center.to(device)) / self.radius.to(device)
        
    def __len__(self):
        if self.split == 'train':
            return len(self.all_rays)
        return len(self.all_rgbs)

    def __getitem__(self, idx):

        if self.split == 'train':  # use data in the buffers
            sample = {'rays': self.all_rays[idx],
                      'rgbs': self.all_rgbs[idx]}

        else:  # create data for each image separately

            img = self.all_rgbs[idx]
            rays = self.all_rays[idx]

            sample = {'rays': rays,
                      'rgbs': img}
        return sample

================================================
FILE: dataLoader/your_own_data.py
================================================
import torch,cv2
from torch.utils.data import Dataset
import json
from tqdm import tqdm
import os
from PIL import Image
from torchvision import transforms as T


from .ray_utils import *


class YourOwnDataset(Dataset):
    def __init__(self, datadir, split='train', downsample=1.0, is_stack=False, N_vis=-1):

        self.N_vis = N_vis
        self.root_dir = datadir
        self.split = split
        self.is_stack = is_stack
        self.downsample = downsample
        self.define_transforms()

        self.scene_bbox = torch.tensor([[-1.5, -1.5, -1.5], [1.5, 1.5, 1.5]])
        self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
        self.read_meta()
        self.define_proj_mat()

        self.white_bg = True
        self.near_far = [0.1,100.0]
        
        self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3)
        self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)
        self.downsample=downsample

    def read_depth(self, filename):
        depth = np.array(read_pfm(filename)[0], dtype=np.float32)  # (800, 800)
        return depth
    
    def read_meta(self):

        with open(os.path.join(self.root_dir, f"transforms_{self.split}.json"), 'r') as f:
            self.meta = json.load(f)

        w, h = int(self.meta['w']/self.downsample), int(self.meta['h']/self.downsample)
        self.img_wh = [w,h]
        self.focal_x = 0.5 * w / np.tan(0.5 * self.meta['camera_angle_x'])  # original focal length
        self.focal_y = 0.5 * h / np.tan(0.5 * self.meta['camera_angle_y'])  # original focal length
        self.cx, self.cy = self.meta['cx'],self.meta['cy']


        # ray directions for all pixels, same for all images (same H, W, focal)
        self.directions = get_ray_directions(h, w, [self.focal_x,self.focal_y], center=[self.cx, self.cy])  # (h, w, 3)
        self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)
        self.intrinsics = torch.tensor([[self.focal_x,0,self.cx],[0,self.focal_y,self.cy],[0,0,1]]).float()

        self.image_paths = []
        self.poses = []
        self.all_rays = []
        self.all_rgbs = []
        self.all_masks = []
        self.all_depth = []


        img_eval_interval = 1 if self.N_vis < 0 else len(self.meta['frames']) // self.N_vis
        idxs = list(range(0, len(self.meta['frames']), img_eval_interval))
        for i in tqdm(idxs, desc=f'Loading data {self.split} ({len(idxs)})'):#img_list:#

            frame = self.meta['frames'][i]
            pose = np.array(frame['transform_matrix']) @ self.blender2opencv
            c2w = torch.FloatTensor(pose)
            self.poses += [c2w]

            image_path = os.path.join(self.root_dir, f"{frame['file_path']}.png")
            self.image_paths += [image_path]
            img = Image.open(image_path)
            
            if self.downsample!=1.0:
                img = img.resize(self.img_wh, Image.LANCZOS)
            img = self.transform(img)  # (4, h, w)
            img = img.view(-1, w*h).permute(1, 0)  # (h*w, 4) RGBA
            if img.shape[-1]==4:
                img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:])  # blend A to RGB
            self.all_rgbs += [img]


            rays_o, rays_d = get_rays(self.directions, c2w)  # both (h*w, 3)
            self.all_rays += [torch.cat([rays_o, rays_d], 1)]  # (h*w, 6)


        self.poses = torch.stack(self.poses)
        if not self.is_stack:
            self.all_rays = torch.cat(self.all_rays, 0)  # (len(self.meta['frames])*h*w, 3)
            self.all_rgbs = torch.cat(self.all_rgbs, 0)  # (len(self.meta['frames])*h*w, 3)

#             self.all_depth = torch.cat(self.all_depth, 0)  # (len(self.meta['frames])*h*w, 3)
        else:
            self.all_rays = torch.stack(self.all_rays, 0)  # (len(self.meta['frames]),h*w, 3)
            self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3)  # (len(self.meta['frames]),h,w,3)
            # self.all_masks = torch.stack(self.all_masks, 0).reshape(-1,*self.img_wh[::-1])  # (len(self.meta['frames]),h,w,3)


    def define_transforms(self):
        self.transform = T.ToTensor()
        
    def define_proj_mat(self):
        self.proj_mat = self.intrinsics.unsqueeze(0) @ torch.inverse(self.poses)[:,:3]

    def world2ndc(self,points,lindisp=None):
        device = points.device
        return (points - self.center.to(device)) / self.radius.to(device)
        
    def __len__(self):
        return len(self.all_rgbs)

    def __getitem__(self, idx):

        if self.split == 'train':  # use data in the buffers
            sample = {'rays': self.all_rays[idx],
                      'rgbs': self.all_rgbs[idx]}

        else:  # create data for each image separately

            img = self.all_rgbs[idx]
            rays = self.all_rays[idx]
            mask = self.all_masks[idx] # for quantity evaluation

            sample = {'rays': rays,
                      'rgbs': img}
        return sample


================================================
FILE: extra/auto_run_paramsets.py
================================================
import os
import threading, queue
import numpy as np
import time


def getFolderLocker(logFolder):
    while True:
        try:
            os.makedirs(logFolder+"/lockFolder")
            break
        except: 
            time.sleep(0.01)

def releaseFolderLocker(logFolder):
    os.removedirs(logFolder+"/lockFolder")

def getStopFolder(logFolder):
    return os.path.isdir(logFolder+"/stopFolder")


def get_param_str(key, val):
    if key == 'data_name':
        return f'--datadir {datafolder}/{val} '
    else:
        return f'--{key} {val} '

def get_param_list(param_dict):
    param_keys = list(param_dict.keys())
    param_modes = len(param_keys)
    param_nums = [len(param_dict[key]) for key in param_keys]
    
    param_ids = np.zeros(param_nums+[param_modes], dtype=int)
    for i in range(param_modes):
        broad_tuple = np.ones(param_modes, dtype=int).tolist()
        broad_tuple[i] = param_nums[i]
        broad_tuple = tuple(broad_tuple)
        print(broad_tuple)
        param_ids[...,i] = np.arange(param_nums[i]).reshape(broad_tuple)
    param_ids = param_ids.reshape(-1, param_modes)
    # print(param_ids)
    print(len(param_ids))
    
    params = []
    expnames = []
    for i in range(param_ids.shape[0]):
        one = ""
        name = ""
        param_id = param_ids[i]
        for j in range(param_modes):
            key = param_keys[j]
            val = param_dict[key][param_id[j]]
            if type(key) is tuple:
                assert len(key) == len(val)
                for k in range(len(key)):
                    one += get_param_str(key[k], val[k])
                    name += f'{val[k]},'
                name=name[:-1]+'-'
            else:
                one += get_param_str(key, val)
                name += f'{val}-'
        params.append(one)
        name=name.replace(' ','')
        print(name)
        expnames.append(name[:-1])
    # print(params)
    return params, expnames







if __name__ == '__main__':
    


    # nerf
    expFolder = "nerf/"
    # parameters to iterate, use tuple to couple multiple parameters
    datafolder = '/mnt/new_disk_2/anpei/Dataset/nerf_synthetic/'
    param_dict = {
        'data_name': ['ship', 'mic', 'chair', 'lego', 'drums', 'ficus', 'hotdog', 'materials'],
        'data_dim_color': [13, 27, 54]
    }

    # n_iters = 30000
    # for data_name in ['Robot']:#'Bike','Lifestyle','Palace','Robot','Spaceship','Steamtrain','Toad','Wineholder'
    #     cmd = f'CUDA_VISIBLE_DEVICES={cuda}  python train.py ' \
    #           f'--dataset_name nsvf --datadir /mnt/new_disk_2/anpei/Dataset/TeRF/Synthetic_NSVF/{data_name} '\
    #           f'--expname {data_name} --batch_size {batch_size} ' \
    #           f'--n_iters {n_iters}  ' \
    #           f'--N_voxel_init {128**3} --N_voxel_final {300**3} '\
    #           f'--N_vis {5}  ' \
    #           f'--n_lamb_sigma "[16,16,16]" --n_lamb_sh "[48,48,48]" ' \
    #           f'--upsamp_list "[2000, 3000, 4000, 5500,7000]" --update_AlphaMask_list "[3000,4000]" ' \
    #           f'--shadingMode MLP_Fea --fea2denseAct softplus  --view_pe {2} --fea_pe {2} ' \
    #           f'--L1_weight_inital {8e-5} --L1_weight_rest {4e-5} --rm_weight_mask_thre {1e-4} --add_timestamp 0 ' \
    #           f'--render_test 1 '
    #     print(cmd)
    #     os.system(cmd)

    # nsvf
    # expFolder = "nsvf_0227/"
    # datafolder = '/mnt/new_disk_2/anpei/Dataset/TeRF/Synthetic_NSVF/'
    # param_dict = {
    #             'data_name': ['Robot','Steamtrain','Bike','Lifestyle','Palace','Spaceship','Toad','Wineholder'],#'Bike','Lifestyle','Palace','Robot','Spaceship','Steamtrain','Toad','Wineholder'
    #             'shadingMode': ['SH'],
    #             ('n_lamb_sigma', 'n_lamb_sh'): [ ("[8,8,8]", "[8,8,8]")],
    #             ('view_pe', 'fea_pe', 'featureC','fea2denseAct','N_voxel_init') : [(2, 2, 128, 'softplus',128**3)],
    #             ('L1_weight_inital', 'L1_weight_rest', 'rm_weight_mask_thre'):[(4e-5, 4e-5, 1e-4)],
    #             ('n_iters','N_voxel_final'): [(30000,300**3)],
    #             ('dataset_name','N_vis','render_test') : [("nsvf",5,1)],
    #             ('upsamp_list','update_AlphaMask_list'): [("[2000,3000,4000,5500,7000]","[3000,4000]")]
    #
    #     }

    # tankstemple
    # expFolder = "tankstemple_0304/"
    # datafolder = '/mnt/new_disk_2/anpei/Dataset/TeRF/TanksAndTemple/'
    # param_dict = {
    #             'data_name': ['Truck','Barn','Caterpillar','Family','Ignatius'],
    #             'shadingMode': ['MLP_Fea'],
    #             ('n_lamb_sigma', 'n_lamb_sh'): [("[16,16,16]", "[48,48,48]")],
    #             ('view_pe', 'fea_pe','fea2denseAct','N_voxel_init','render_test') : [(2, 2, 'softplus',128**3,1)],
    #             ('TV_weight_density','TV_weight_app'):[(0.1,0.01)],
    #             # ('L1_weight_inital', 'L1_weight_rest', 'rm_weight_mask_thre'): [(4e-5, 4e-5, 1e-4)],
    #             ('n_iters','N_voxel_final'): [(15000,300**3)],
    #             ('dataset_name','N_vis') : [("tankstemple",5)],
    #             ('upsamp_list','update_AlphaMask_list'): [("[2000,3000,4000,5500,7000]","[2000,4000]")]
    #     }

    # llff
    # expFolder = "real_iconic/"
    # datafolder = '/mnt/new_disk_2/anpei/Dataset/MVSNeRF/real_iconic/'
    # List = os.listdir(datafolder)
    # param_dict = {
    #             'data_name': List,
    #             ('shadingMode', 'view_pe', 'fea_pe','fea2denseAct', 'nSamples','N_voxel_init') : [('MLP_Fea', 0, 0, 'relu',512,128**3)],
    #             ('n_lamb_sigma', 'n_lamb_sh') : [("[16,4,4]", "[48,12,12]")],
    #             ('TV_weight_density', 'TV_weight_app'):[(1.0,1.0)],
    #             ('n_iters','N_voxel_final'): [(25000,640**3)],
    #             ('dataset_name','downsample_train','ndc_ray','N_vis','render_path') : [("llff",4.0, 1,-1,1)],
    #             ('upsamp_list','update_AlphaMask_list'): [("[2000,3000,4000,5500,7000]","[2500]")],
    #     }

    # expFolder = "llff/"
    # datafolder = '/mnt/new_disk_2/anpei/Dataset/MVSNeRF/nerf_llff_data'
    # param_dict = {
    #             'data_name': ['fern', 'flower', 'room', 'leaves', 'horns', 'trex', 'fortress', 'orchids'],#'fern', 'flower', 'room', 'leaves', 'horns', 'trex', 'fortress', 'orchids'
    #             ('n_lamb_sigma', 'n_lamb_sh'): [("[16,4,4]", "[48,12,12]")],
    #             ('shadingMode', 'view_pe', 'fea_pe', 'featureC','fea2denseAct', 'nSamples','N_voxel_init') : [('MLP_Fea', 0, 0, 128, 'relu',512,128**3),('SH', 0, 0, 128, 'relu',512,128**3)],
    #             ('TV_weight_density', 'TV_weight_app'):[(1.0,1.0)],
    #             ('n_iters','N_voxel_final'): [(25000,640**3)],
    #             ('dataset_name','downsample_train','ndc_ray','N_vis','render_test','render_path') : [("llff",4.0, 1,-1,1,1)],
    #             ('upsamp_list','update_AlphaMask_list'): [("[2000,3000,4000,5500,7000]","[2500]")],
    #     }

    #setting available gpus
    gpus_que = queue.Queue(3)
    for i in [1,2,3]:
        gpus_que.put(i)
    
    os.makedirs(f"log/{expFolder}", exist_ok=True)

    def run_program(gpu, expname, param):
        cmd = f'CUDA_VISIBLE_DEVICES={gpu}  python train.py ' \
            f'--expname {expname} --basedir ./log/{expFolder} --config configs/lego.txt ' \
            f'{param}' \
            f'> "log/{expFolder}{expname}/{expname}.txt"'
        print(cmd)
        os.system(cmd)
        gpus_que.put(gpu)

    params, expnames = get_param_list(param_dict)

    
    logFolder=f"log/{expFolder}"
    os.makedirs(logFolder, exist_ok=True)

    ths = []
    for i in range(len(params)):

        if getStopFolder(logFolder):
            break


        targetFolder = f"log/{expFolder}{expnames[i]}"
        gpu = gpus_que.get()
        getFolderLocker(logFolder)
        if os.path.isdir(targetFolder):
            releaseFolderLocker(logFolder)
            gpus_que.put(gpu)
            continue
        else:
            os.makedirs(targetFolder, exist_ok=True)
            print("making",targetFolder, "running",expnames[i], params[i])
            releaseFolderLocker(logFolder)


        t = threading.Thread(target=run_program, args=(gpu, expnames[i], params[i]), daemon=True)
        t.start()
        ths.append(t)
    
    for th in ths:
        th.join()

================================================
FILE: extra/compute_metrics.py
================================================
import os, math
import numpy as np
import scipy.signal
from typing import List, Optional
from PIL import Image
import os
import torch
import configargparse

__LPIPS__ = {}
def init_lpips(net_name, device):
    assert net_name in ['alex', 'vgg']
    import lpips
    print(f'init_lpips: lpips_{net_name}')
    return lpips.LPIPS(net=net_name, version='0.1').eval().to(device)

def rgb_lpips(np_gt, np_im, net_name, device):
    if net_name not in __LPIPS__:
        __LPIPS__[net_name] = init_lpips(net_name, device)
    gt = torch.from_numpy(np_gt).permute([2, 0, 1]).contiguous().to(device)
    im = torch.from_numpy(np_im).permute([2, 0, 1]).contiguous().to(device)
    return __LPIPS__[net_name](gt, im, normalize=True).item()


def findItem(items, target):
    for one in items:
        if one[:len(target)]==target:
            return one
    return None


''' Evaluation metrics (ssim, lpips)
'''
def rgb_ssim(img0, img1, max_val,
             filter_size=11,
             filter_sigma=1.5,
             k1=0.01,
             k2=0.03,
             return_map=False):
    # Modified from https://github.com/google/mipnerf/blob/16e73dfdb52044dcceb47cda5243a686391a6e0f/internal/math.py#L58
    assert len(img0.shape) == 3
    assert img0.shape[-1] == 3
    assert img0.shape == img1.shape

    # Construct a 1D Gaussian blur filter.
    hw = filter_size // 2
    shift = (2 * hw - filter_size + 1) / 2
    f_i = ((np.arange(filter_size) - hw + shift) / filter_sigma)**2
    filt = np.exp(-0.5 * f_i)
    filt /= np.sum(filt)

    # Blur in x and y (faster than the 2D convolution).
    def convolve2d(z, f):
        return scipy.signal.convolve2d(z, f, mode='valid')

    filt_fn = lambda z: np.stack([
        convolve2d(convolve2d(z[...,i], filt[:, None]), filt[None, :])
        for i in range(z.shape[-1])], -1)
    mu0 = filt_fn(img0)
    mu1 = filt_fn(img1)
    mu00 = mu0 * mu0
    mu11 = mu1 * mu1
    mu01 = mu0 * mu1
    sigma00 = filt_fn(img0**2) - mu00
    sigma11 = filt_fn(img1**2) - mu11
    sigma01 = filt_fn(img0 * img1) - mu01

    # Clip the variances and covariances to valid values.
    # Variance must be non-negative:
    sigma00 = np.maximum(0., sigma00)
    sigma11 = np.maximum(0., sigma11)
    sigma01 = np.sign(sigma01) * np.minimum(
        np.sqrt(sigma00 * sigma11), np.abs(sigma01))
    c1 = (k1 * max_val)**2
    c2 = (k2 * max_val)**2
    numer = (2 * mu01 + c1) * (2 * sigma01 + c2)
    denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2)
    ssim_map = numer / denom
    ssim = np.mean(ssim_map)
    return ssim_map if return_map else ssim


if __name__ == '__main__':

    parser = configargparse.ArgumentParser()
    parser.add_argument("--exp", type=str, help="folder of exps")
    parser.add_argument("--paramStr", type=str, help="str of params")
    args = parser.parse_args()


    # datanames = ['drums','hotdog','materials','ficus','lego','mic','ship','chair'] #['ship']#
    # gtFolder = "/home/code-base/user_space/codes/nerf/data/nerf_synthetic"
    # expFolder = "/home/code-base/user_space/codes/TensoRF/log/"+args.exp

    # datanames = ['room','fortress', 'flower','orchids','leaves','horns','trex','fern'] #['ship']#
    # gtFolder = "/mnt/new_disk_2/anpei/Dataset/MVSNeRF/nerf_llff_data/"
    # expFolder = "/mnt/new_disk_2/anpei/code/TensoRF/log/"+args.exp
    paramStr = args.paramStr
    fileNum = 200


    expitems = os.listdir(expFolder)
    finalFolder = f'{expFolder}/finals/{paramStr}'
    outFile = f'{finalFolder}/{paramStr}_metrics.txt'
    os.makedirs(finalFolder, exist_ok=True)

    expitems.sort(reverse=True)


    with open(outFile, 'w') as f:
        all_psnr = []
        all_ssim = []
        all_alex = []
        all_vgg = []
        for dataname in datanames:
            

            gtstr = gtFolder+"/"+dataname+"/test/r_%d.png"
            expname = findItem(expitems, f'{paramStr}-{dataname}')
            print("expname: ", expname)
            if expname is None:
                print("no ",dataname, "exists")
                continue
            resultstr = expFolder+"/"+expname+"/imgs_test_all/"+ dataname+"-"+paramStr+ "_%03d.png"
            metric_file = f'{expFolder}/{expname}/imgs_test_all/{paramStr}-{dataname}_mean.txt'
            video_file = f'{expFolder}/{expname}/imgs_test_all/{paramStr}-{dataname}_video.mp4'
            
            exist_metric=False
            if os.path.isfile(metric_file):
                metrics = np.loadtxt(metric_file)
                print(metrics, metrics.tolist())
                if metrics.size == 4:
                    psnr, ssim, l_a, l_v = metrics.tolist()
                    exist_metric = True
                    os.system(f"cp {video_file} {finalFolder}/")

            if not exist_metric:
                psnrs = []
                ssims = []
                l_alex = []
                l_vgg = []
                for i in range(fileNum):
                    gt = np.asarray(Image.open(gtstr%i),dtype=np.float32) / 255.0
                    gtmask = gt[...,[3]]
                    gt = gt[...,:3]
                    gt = gt*gtmask + (1-gtmask)
                    img = np.asarray(Image.open(resultstr%i),dtype=np.float32)[...,:3]  / 255.0
                    # print(gt[0,0],img[0,0],gt.shape, img.shape, gt.max(), img.max())


                    psnr = -10. * np.log10(np.mean(np.square(img - gt)))
                    ssim = rgb_ssim(img, gt, 1)
                    lpips_alex = rgb_lpips(gt, img, 'alex','cuda')
                    lpips_vgg = rgb_lpips(gt, img, 'vgg','cuda')

                    print(i, psnr, ssim, lpips_alex, lpips_vgg)
                    psnrs.append(psnr)
                    ssims.append(ssim)
                    l_alex.append(lpips_alex)
                    l_vgg.append(lpips_vgg)
                    psnr = np.mean(np.array(psnrs))
                    ssim = np.mean(np.array(ssims))
                    l_a  = np.mean(np.array(l_alex))
                    l_v  = np.mean(np.array(l_vgg))

            rS=f'{dataname} : psnr {psnr} ssim {ssim}  l_a {l_a} l_v {l_v}'
            print(rS)
            f.write(rS+"\n")

            all_psnr.append(psnr)
            all_ssim.append(ssim)
            all_alex.append(l_a)
            all_vgg.append(l_v)
        
        psnr = np.mean(np.array(all_psnr))
        ssim = np.mean(np.array(all_ssim))
        l_a  = np.mean(np.array(all_alex))
        l_v  = np.mean(np.array(all_vgg))

        rS=f'mean : psnr {psnr} ssim {ssim}  l_a {l_a} l_v {l_v}'
        print(rS)
        f.write(rS+"\n")

================================================
FILE: models/VGG.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import namedtuple
import torchvision.models as models

# pytorch pretrained vgg
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()

        #pretrained vgg19
        vgg19 = models.vgg19(weights='DEFAULT').features

        self.relu1_1 = vgg19[:2]
        self.relu2_1 = vgg19[2:7]
        self.relu3_1 = vgg19[7:12]
        self.relu4_1 = vgg19[12:21]

        #fix parameters
        self.requires_grad_(False)

    def forward(self, x):
        _output = namedtuple('output', ['relu1_1', 'relu2_1', 'relu3_1', 'relu4_1'])
        relu1_1 = self.relu1_1(x)
        relu2_1 = self.relu2_1(relu1_1)
        relu3_1 = self.relu3_1(relu2_1)
        relu4_1 = self.relu4_1(relu3_1)
        output = _output(relu1_1, relu2_1, relu3_1, relu4_1)

        return output


class Decoder(nn.Module): 
    """
    starting from relu 4_1
    """
    def __init__(self, ckpt_path=None):
        super().__init__()
        
        self.layers = nn.Sequential(
            # nn.Conv2d(512, 256, 3, padding=1, padding_mode='reflect'),
            # nn.ReLU(),
            # nn.Upsample(scale_factor=2, mode='nearest'), # relu4-1
            nn.Conv2d(256, 256, 3, padding=1, padding_mode='reflect'),
            nn.ReLU(), # relu3-4
            nn.Conv2d(256, 256, 3, padding=1, padding_mode='reflect'),
            nn.ReLU(), # relu3-3
            nn.Conv2d(256, 256, 3, padding=1, padding_mode='reflect'),
            nn.ReLU(), # relu3-2
            nn.Conv2d(256, 128, 3, padding=1, padding_mode='reflect'),
            nn.ReLU(), 
            nn.Upsample(scale_factor=2, mode='nearest'),# relu3-1
            nn.Conv2d(128, 128, 3, padding=1, padding_mode='reflect'),
            nn.ReLU(), # relu2-2
            nn.Conv2d(128, 64, 3, padding=1, padding_mode='reflect'),
            nn.ReLU(), 
            nn.Upsample(scale_factor=2, mode='nearest'),# relu2-1
            nn.Conv2d(64, 64, 3, padding=1, padding_mode='reflect'),
            nn.ReLU(), # relu1-2
            nn.Conv2d(64, 3, 3, padding=1, padding_mode='reflect'),
        )

        if ckpt_path is not None:
          self.load_state_dict(torch.load(ckpt_path))

    def forward(self, x):
        return self.layers(x)


### high-res unet feature map decoder


class DownBlock(nn.Module):

    def __init__(self, in_dim, out_dim, down='conv'):
        super(DownBlock, self).__init__()

        if down == 'conv':
            self.down_conv = nn.Sequential(
                nn.Conv2d(in_dim, out_dim, 3, 2, 1),
                nn.LeakyReLU(),
                nn.Conv2d(out_dim, out_dim, 3, 1, 1),
                nn.LeakyReLU(),
            )
        elif down == 'mean':
            self.down_conv = nn.AvgPool2d(2)
        else:
            raise NotImplementedError(
                '[ERROR] invalid downsampling operator: {:s}'.format(down)
            )

    def forward(self, x):
        x = self.down_conv(x)
        return x


class UpBlock(nn.Module):

    def __init__(self, in_dim, out_dim, skip_dim=None, up='nearest'):
        super(UpBlock, self).__init__()

        if up == 'conv':
            self.up_conv = nn.Sequential(
                nn.ConvTranspose2d(in_dim, out_dim, 3, 2, 1, 1),
                nn.ReLU(),
            )
        else:
            assert up in ('bilinear', 'nearest'), \
                '[ERROR] invalid upsampling mode: {:s}'.format(up)
            self.up_conv = nn.Sequential(
                nn.Upsample(scale_factor=2, mode=up),
                nn.Conv2d(in_dim, out_dim, 3, 1, 1),
                nn.ReLU(),
            )
        
        in_dim = out_dim
        if skip_dim is not None:
            in_dim += skip_dim
        self.conv = nn.Sequential(
            nn.Conv2d(in_dim, out_dim, 3, 1, 1),
            nn.ReLU(),
        )

    def _pad(self, x, y):
        dh = y.size(-2) - x.size(-2)
        dw = y.size(-1) - x.size(-1)
        if dh == 0 and dw == 0:
            return x
        if dh < 0:
            x = x[..., :dh, :]
        if dw < 0:
            x = x[..., :, :dw]
        if dh > 0 or dw > 0:
            x = F.pad(
                x, 
                pad=(dw // 2, dw - dw // 2, dh // 2, dh - dh // 2), 
                mode='reflect'
            )
        return x

    def forward(self, x, skip=None):
        x = self.up_conv(x)
        if skip is not None:
            x = torch.cat([self._pad(x, skip), skip], 1)
        x = self.conv(x)
        return x


class UNetDecoder(nn.Module):

    def __init__(self, in_dim=256):
        super(UNetDecoder, self).__init__()

        self.down_layers = nn.ModuleList()
        self.skip_convs = nn.ModuleList()
        self.up_layers = nn.ModuleList()

        in_dim = in_dim
        self.n_levels = 2
        self.up = 1

        for i in range(self.n_levels):
            self.down_layers.append(
                DownBlock(
                    in_dim, in_dim,
                )
            )
            out_dim = in_dim // 2 ** (self.n_levels - i)
            self.skip_convs.append(nn.Conv2d(in_dim, out_dim, 1))
            self.up_layers.append(
                UpBlock(
                    out_dim * 2, out_dim, out_dim,
                )
            )

        out_dim = in_dim // 2 ** self.n_levels
        self.out_conv = nn.Sequential(
            nn.Conv2d(out_dim, out_dim, 3, 1, 1),
            nn.ReLU(),
            nn.Conv2d(out_dim, 3, 1, 1),
        )

    def forward(self, feats):
        skips = []
        for i in range(self.n_levels):
            skips.append(self.skip_convs[i](feats))
            feats = self.down_layers[i](feats)
        for i in range(self.n_levels - 1, -1, -1):
            feats = self.up_layers[i](feats, skips[i])
        rgb = self.out_conv(feats)
        return rgb


### high-res feature map decoder

class PlainDecoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()

        self.layers = nn.Sequential(
            nn.Conv2d(256, 256, 3, padding=1, padding_mode='reflect'),
            nn.ReLU(), # relu3-4
            nn.Conv2d(256, 256, 3, padding=1, padding_mode='reflect'),
            nn.ReLU(), # relu3-3
            nn.Conv2d(256, 256, 3, padding=1, padding_mode='reflect'),
            nn.ReLU(), # relu3-2
            nn.Conv2d(256, 128, 3, padding=1, padding_mode='reflect'),
            nn.ReLU(), 
            nn.Conv2d(128, 128, 3, padding=1, padding_mode='reflect'),
            nn.ReLU(), # relu2-2
            nn.Conv2d(128, 64, 3, padding=1, padding_mode='reflect'),
            nn.ReLU(), 
            nn.Conv2d(64, 64, 3, padding=1, padding_mode='reflect'),
            nn.ReLU(), # relu1-2
            nn.Conv2d(64, 3, 3, padding=1, padding_mode='reflect'),
        )

    def forward(self, x):
        return self.layers(x)

================================================
FILE: models/__init__.py
================================================


================================================
FILE: models/sh.py
================================================
import torch

################## sh function ##################
C0 = 0.28209479177387814
C1 = 0.4886025119029199
C2 = [
    1.0925484305920792,
    -1.0925484305920792,
    0.31539156525252005,
    -1.0925484305920792,
    0.5462742152960396
]
C3 = [
    -0.5900435899266435,
    2.890611442640554,
    -0.4570457994644658,
    0.3731763325901154,
    -0.4570457994644658,
    1.445305721320277,
    -0.5900435899266435
]
C4 = [
    2.5033429417967046,
    -1.7701307697799304,
    0.9461746957575601,
    -0.6690465435572892,
    0.10578554691520431,
    -0.6690465435572892,
    0.47308734787878004,
    -1.7701307697799304,
    0.6258357354491761,
]

def eval_sh(deg, sh, dirs):
    """
    Evaluate spherical harmonics at unit directions
    using hardcoded SH polynomials.
    Works with torch/np/jnp.
    ... Can be 0 or more batch dimensions.
    :param deg: int SH max degree. Currently, 0-4 supported
    :param sh: torch.Tensor SH coeffs (..., C, (max degree + 1) ** 2)
    :param dirs: torch.Tensor unit directions (..., 3)
    :return: (..., C)
    """
    assert deg <= 4 and deg >= 0
    assert (deg + 1) ** 2 == sh.shape[-1]
    C = sh.shape[-2]

    result = C0 * sh[..., 0]
    if deg > 0:
        x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
        result = (result -
                C1 * y * sh[..., 1] +
                C1 * z * sh[..., 2] -
                C1 * x * sh[..., 3])
        if deg > 1:
            xx, yy, zz = x * x, y * y, z * z
            xy, yz, xz = x * y, y * z, x * z
            result = (result +
                    C2[0] * xy * sh[..., 4] +
                    C2[1] * yz * sh[..., 5] +
                    C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
                    C2[3] * xz * sh[..., 7] +
                    C2[4] * (xx - yy) * sh[..., 8])

            if deg > 2:
                result = (result +
                        C3[0] * y * (3 * xx - yy) * sh[..., 9] +
                        C3[1] * xy * z * sh[..., 10] +
                        C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] +
                        C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
                        C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
                        C3[5] * z * (xx - yy) * sh[..., 14] +
                        C3[6] * x * (xx - 3 * yy) * sh[..., 15])
                if deg > 3:
                    result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
                            C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
                            C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
                            C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
                            C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
                            C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
                            C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
                            C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
                            C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
    return result

def eval_sh_bases(deg, dirs):
    """
    Evaluate spherical harmonics bases at unit directions,
    without taking linear combination.
    At each point, the final result may the be
    obtained through simple multiplication.
    :param deg: int SH max degree. Currently, 0-4 supported
    :param dirs: torch.Tensor (..., 3) unit directions
    :return: torch.Tensor (..., (deg+1) ** 2)
    """
    assert deg <= 4 and deg >= 0
    result = torch.empty((*dirs.shape[:-1], (deg + 1) ** 2), dtype=dirs.dtype, device=dirs.device)
    result[..., 0] = C0
    if deg > 0:
        x, y, z = dirs.unbind(-1)
        result[..., 1] = -C1 * y;
        result[..., 2] = C1 * z;
        result[..., 3] = -C1 * x;
        if deg > 1:
            xx, yy, zz = x * x, y * y, z * z
            xy, yz, xz = x * y, y * z, x * z
            result[..., 4] = C2[0] * xy;
            result[..., 5] = C2[1] * yz;
            result[..., 6] = C2[2] * (2.0 * zz - xx - yy);
            result[..., 7] = C2[3] * xz;
            result[..., 8] = C2[4] * (xx - yy);

            if deg > 2:
                result[..., 9] = C3[0] * y * (3 * xx - yy);
                result[..., 10] = C3[1] * xy * z;
                result[..., 11] = C3[2] * y * (4 * zz - xx - yy);
                result[..., 12] = C3[3] * z * (2 * zz - 3 * xx - 3 * yy);
                result[..., 13] = C3[4] * x * (4 * zz - xx - yy);
                result[..., 14] = C3[5] * z * (xx - yy);
                result[..., 15] = C3[6] * x * (xx - 3 * yy);

                if deg > 3:
                    result[..., 16] = C4[0] * xy * (xx - yy);
                    result[..., 17] = C4[1] * yz * (3 * xx - yy);
                    result[..., 18] = C4[2] * xy * (7 * zz - 1);
                    result[..., 19] = C4[3] * yz * (7 * zz - 3);
                    result[..., 20] = C4[4] * (zz * (35 * zz - 30) + 3);
                    result[..., 21] = C4[5] * xz * (7 * zz - 3);
                    result[..., 22] = C4[6] * (xx - yy) * (7 * zz - 1);
                    result[..., 23] = C4[7] * xz * (xx - 3 * yy);
                    result[..., 24] = C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy));
    return result


================================================
FILE: models/styleModules.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import random

def calc_mean_std(x, eps=1e-8):
        """
        calculating channel-wise instance mean and standard variance
        x: shape of (N,C,*)
        """
        mean = torch.mean(x.flatten(2), dim=-1, keepdim=True) # size of (N, C, 1)
        std = torch.std(x.flatten(2), dim=-1, keepdim=True) + eps # size of (N, C, 1)
        
        return mean, std

def cal_adain_style_loss(x, y):
    """
    style loss in one layer

    Args:
        x, y: feature maps of size [N, C, H, W]
    """
    x_mean, x_std = calc_mean_std(x)
    y_mean, y_std = calc_mean_std(y)

    return nn.functional.mse_loss(x_mean, y_mean) \
         + nn.functional.mse_loss(x_std, y_std)

def cal_mse_content_loss(x, y):
    return nn.functional.mse_loss(x, y)



class LearnableIN(nn.Module):
    '''
    Input: (N, C, L) or (C, L)
    '''
    def __init__(self, dim=256):
        super().__init__()
        self.IN = torch.nn.InstanceNorm1d(dim, momentum=1e-4, track_running_stats=True)

    def forward(self, x):
        if x.size()[-1] <= 1:
            return x
        return self.IN(x)

class SimpleLinearStylizer(nn.Module):
    def __init__(self, input_dim=256, embed_dim=32, n_layers=3 ) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.embed_dim = embed_dim

        self.IN = LearnableIN(input_dim)

        self.q_embed = nn.Conv1d(input_dim, embed_dim, 1)
        self.k_embed = nn.Conv1d(input_dim, embed_dim, 1)
        self.v_embed = nn.Conv1d(input_dim, embed_dim, 1)

        self.unzipper = nn.Conv1d(embed_dim, input_dim, 1, bias=0)

        s_net = []
        for i in range(n_layers - 1):
            out_dim = max(embed_dim, input_dim // 2)
            s_net.append(
                nn.Sequential(
                    nn.Conv1d(input_dim, out_dim, 1),
                    nn.ReLU(inplace=True),
                )
            )
            input_dim = out_dim
        s_net.append(nn.Conv1d(input_dim, embed_dim, 1))
        self.s_net = nn.Sequential(*s_net)

        self.s_fc = nn.Linear(embed_dim ** 2, embed_dim ** 2)

    def _vectorized_covariance(self, x):
        cov = torch.bmm(x, x.transpose(2, 1)) / x.size(-1)
        cov = cov.flatten(1)
        return cov

    def get_content_matrix(self, c):
        '''
        Args:
            c: content feature [N,input_dim,S]
        Return:
            mat: [N,S,embed_dim,embed_dim]
        '''
        normalized_c = self.IN(c)
        # normalized_c = torch.nn.functional.instance_norm(c)
        q_embed = self.q_embed(normalized_c)
        k_embed = self.k_embed(normalized_c)
        
        c_cov = q_embed.transpose(1,2).unsqueeze(3) * k_embed.transpose(1,2).unsqueeze(2) # [N,S,embed_dim,embed_dim]
        attn = torch.softmax(c_cov, -1) # [N,S,embed_dim,embed_dim]

        return attn, normalized_c

    def get_style_mean_std_matrix(self, s):
        '''
        Args:
            s: style feature [N,input_dim,S]

        Return:
            mat: [N,embed_dim,embed_dim]
        '''
        s_mean = s.mean(-1, keepdim=True)
        s_std = s.std(-1, keepdim=True)
        s = s - s_mean

        s_embed = self.s_net(s)
        s_cov = self._vectorized_covariance(s_embed)
        s_mat = self.s_fc(s_cov)
        s_mat = s_mat.reshape(-1, self.embed_dim, self.embed_dim)

        return s_mean, s_std, s_mat

    def transform_content_3D(self, c):
        '''
        Args:
            c: content feature [N,input_dim,S]
        Return:
            transformed_c: [N,embed_dim,S]
        '''
        attn, normalized_c = self.get_content_matrix(c) # [N,S,embed_dim,embed_dim]
        c = self.v_embed(normalized_c) # [N,embed_dim,S]
        c = c.transpose(1,2).unsqueeze(3) # [N,S,embed_dim,1]
        c = torch.matmul(attn, c).squeeze(3) # [N,S,embed_dim]

        return c.transpose(1,2)

    def transfer_style_2D(self, s_mean_std_mat, c, acc_map):
        '''
        Agrs:
            c: content feature map after volume rendering [N,embed_dim,S]
            s_mat: style matrix [N,embed_dim,embed_dim]
            acc_map: [S]
            
            s_mean = [N,input_dim,1]
            s_std = [N,input_dim,1]
        '''
        s_mean, s_std, s_mat = s_mean_std_mat

        cs = torch.bmm(s_mat, c) # [N,embed_dim,S]
        cs = self.unzipper(cs) # [N,input_dim,S]

        cs = cs * s_std + s_mean * acc_map[None,None,...]

        return cs


class AdaAttN(nn.Module):
    """ Attention-weighted AdaIN (Liu et al., ICCV 21) """

    def __init__(self, qk_dim, v_dim):
        """
        Args:
            qk_dim (int): query and key size.
            v_dim (int): value size.
        """
        super(AdaAttN, self).__init__()

        self.q_embed = nn.Conv1d(qk_dim, qk_dim, 1)
        self.k_embed = nn.Conv1d(qk_dim, qk_dim, 1)
        self.s_embed = nn.Conv1d(v_dim, v_dim, 1)

    def forward(self, q, k):
        """
        Args:
            q (float tensor, (bs, qk, *)): query (content) features.
            k (float tensor, (bs, qk, *)): key (style) features.
            c (float tensor, (bs, v, *)): content value features.
            s (float tensor, (bs, v, *)): style value features.

        Returns:
            cs (float tensor, (bs, v, *)): stylized content features.
        """
        c, s = q, k

        shape = c.shape
        q, k = q.flatten(2), k.flatten(2)
        c, s = c.flatten(2), s.flatten(2)

        # QKV attention with projected content and style features
        q = self.q_embed(F.instance_norm(q)).transpose(2, 1)    # (bs, n, qk)
        k = self.k_embed(F.instance_norm(k))                    # (bs, qk, m)
        s = self.s_embed(s).transpose(2, 1)                     # (bs, m, v)
        attn = F.softmax(torch.bmm(q, k), -1)                   # (bs, n, m)
        
        # attention-weighted channel-wise statistics
        mean = torch.bmm(attn, s)                               # (bs, n, v)
        var = F.relu(torch.bmm(attn, s ** 2) - mean ** 2)       # (bs, n, v)
        mean = mean.transpose(2, 1)                             # (bs, v, n)
        std = torch.sqrt(var).transpose(2, 1)                   # (bs, v, n)
        
        cs = F.instance_norm(c) * std + mean                    # (bs, v, n)
        cs = cs.reshape(shape)
        return cs

class AdaAttN_new_IN(nn.Module):
    """ Attention-weighted AdaIN (Liu et al., ICCV 21) """

    def __init__(self, qk_dim, v_dim):
        """
        Args:
            qk_dim (int): query and key size.
            v_dim (int): value size.
        """
        super(AdaAttN_new_IN, self).__init__()

        self.q_embed = nn.Conv1d(qk_dim, qk_dim, 1)
        self.k_embed = nn.Conv1d(qk_dim, qk_dim, 1)
        self.s_embed = nn.Conv1d(v_dim, v_dim, 1)
        self.IN = LearnableIN(qk_dim)

    def forward(self, q, k):
        """
        Args:
            q (float tensor, (bs, qk, *)): query (content) features.
            k (float tensor, (bs, qk, *)): key (style) features.
            c (float tensor, (bs, v, *)): content value features.
            s (float tensor, (bs, v, *)): style value features.

        Returns:
            cs (float tensor, (bs, v, *)): stylized content features.
        """
        c, s = q, k

        shape = c.shape
        q, k = q.flatten(2), k.flatten(2)
        c, s = c.flatten(2), s.flatten(2)

        # QKV attention with projected content and style features
        q = self.q_embed(self.IN(q)).transpose(2, 1)    # (bs, n, qk)
        k = self.k_embed(F.instance_norm(k))                    # (bs, qk, m)
        s = self.s_embed(s).transpose(2, 1)                     # (bs, m, v)
        attn = F.softmax(torch.bmm(q, k), -1)                   # (bs, n, m)
        
        # attention-weighted channel-wise statistics
        mean = torch.bmm(attn, s)                               # (bs, n, v)
        var = F.relu(torch.bmm(attn, s ** 2) - mean ** 2)       # (bs, n, v)
        mean = mean.transpose(2, 1)                             # (bs, v, n)
        std = torch.sqrt(var).transpose(2, 1)                   # (bs, v, n)
        
        cs = self.IN(c) * std + mean                    # (bs, v, n)
        cs = cs.reshape(shape)
        return cs

class AdaAttN_woin(nn.Module):
    """ Attention-weighted AdaIN (Liu et al., ICCV 21) """

    def __init__(self, qk_dim, v_dim):
        """
        Args:
            qk_dim (int): query and key size.
            v_dim (int): value size.
        """
        super().__init__()

        self.q_embed = nn.Conv1d(qk_dim, qk_dim, 1)
        self.k_embed = nn.Conv1d(qk_dim, qk_dim, 1)
        self.s_embed = nn.Conv1d(v_dim, v_dim, 1)

    def forward(self, q, k):
        """
        Args:
            q (float tensor, (bs, qk, *)): query (content) features.
            k (float tensor, (bs, qk, *)): key (style) features.
            c (float tensor, (bs, v, *)): content value features.
            s (float tensor, (bs, v, *)): style value features.

        Returns:
            cs (float tensor, (bs, v, *)): stylized content features.
        """
        c, s = q, k

        shape = c.shape
        q, k = q.flatten(2), k.flatten(2)
        c, s = c.flatten(2), s.flatten(2)

        # QKV attention with projected content and style features
        q = self.q_embed(q).transpose(2, 1)    # (bs, n, qk)
        k = self.k_embed(k)                    # (bs, qk, m)
        s = self.s_embed(s).transpose(2, 1)                     # (bs, m, v)
        attn = F.softmax(torch.bmm(q, k), -1)                   # (bs, n, m)
        
        # attention-weighted channel-wise statistics
        mean = torch.bmm(attn, s)                               # (bs, n, v)
        var = F.relu(torch.bmm(attn, s ** 2) - mean ** 2)       # (bs, n, v)
        mean = mean.transpose(2, 1)                             # (bs, v, n)
        std = torch.sqrt(var).transpose(2, 1)                   # (bs, v, n)
        
        cs = c * std + mean                    # (bs, v, n)
        cs = cs.reshape(shape)
        return cs

================================================
FILE: models/tensoRF.py
================================================
from .tensorBase import *
from .VGG import Encoder, Decoder, UNetDecoder, PlainDecoder
from .styleModules import LearnableIN, SimpleLinearStylizer

class TensorVMSplit(TensorBase):
    def __init__(self, aabb, gridSize, device, **kargs):
        super(TensorVMSplit, self).__init__(aabb, gridSize, device, **kargs)



    def change_to_feature_mod(self, feature_n_comp, device):
        self.density_line.requires_grad_(False)
        self.density_plane.requires_grad_(False)
        self.app_line = None
        self.app_plane = None
        self.basis_mat = None
        self.renderModule = None

        # Both encoder and decoder do not require grad when initialized
        self.encoder = Encoder().to(device)
        self.decoder = PlainDecoder().to(device)

        # We need to finetune decoder when training a feature grid
        self.decoder.requires_grad_(True)

        self.feature_n_comp = feature_n_comp
        self.init_feature_svd(device)

    def change_to_style_mod(self, device='cuda'):
        assert self.feature_line is not None, 'Have to be trained in feature mod first!'
        self.feature_line.requires_grad_(False)
        self.feature_plane.requires_grad_(False)
        self.feature_basis_mat.requires_grad_(False)
        self.decoder.requires_grad_(True)

        self.IN = LearnableIN().to(device)

        # self.stylizer = NearestFeatureTransform()
        # self.stylizer = LearnableIN(1,256,device)
        # self.stylizer = AdaAttN(256, 256).to(device)
        # self.stylizer = AdaAttN_woin(256, 256).to(device)
        self.stylizer = SimpleLinearStylizer(256).to(device)


    def init_svd_volume(self, res, device):
        self.density_plane, self.density_line = self.init_one_svd(self.density_n_comp, self.gridSize, 0.1, device)
        self.app_plane, self.app_line = self.init_one_svd(self.app_n_comp, self.gridSize, 0.1, device)
        self.basis_mat = torch.nn.Linear(sum(self.app_n_comp), self.app_dim, bias=False).to(device)

    def init_feature_svd(self, device):
        self.feature_plane, self.feature_line = self.init_one_svd(self.feature_n_comp, self.gridSize, 0.1, device)
        self.feature_basis_mat = torch.nn.Linear(sum(self.feature_n_comp), 256, bias=False).to(device)

    def init_one_svd(self, n_component, gridSize, scale, device):
        plane_coef, line_coef = [], []
        for i in range(len(self.vecMode)):
            vec_id = self.vecMode[i]
            mat_id_0, mat_id_1 = self.matMode[i]
            plane_coef.append(torch.nn.Parameter(
                scale * torch.randn((1, n_component[i], gridSize[mat_id_1], gridSize[mat_id_0]))))  
            line_coef.append(
                torch.nn.Parameter(scale * torch.randn((1, n_component[i], gridSize[vec_id], 1))))

        return torch.nn.ParameterList(plane_coef).to(device), torch.nn.ParameterList(line_coef).to(device)
    
    

    def get_optparam_groups(self, lr_init_spatialxyz = 0.02, lr_init_network = 0.001):
        grad_vars = [{'params': self.density_line, 'lr': lr_init_spatialxyz}, {'params': self.density_plane, 'lr': lr_init_spatialxyz},
                     {'params': self.app_line, 'lr': lr_init_spatialxyz}, {'params': self.app_plane, 'lr': lr_init_spatialxyz},
                         {'params': self.basis_mat.parameters(), 'lr':lr_init_network}]
        if isinstance(self.renderModule, torch.nn.Module):
            grad_vars += [{'params':self.renderModule.parameters(), 'lr':lr_init_network}]
        return grad_vars

    def get_optparam_groups_feature_mod(self, lr_init_spatialxyz, lr_init_network):
        grad_vars = [{'params': self.feature_line, 'lr': lr_init_spatialxyz}, 
                     {'params': self.feature_plane, 'lr': lr_init_spatialxyz},
                     {'params': self.feature_basis_mat.parameters(), 'lr':lr_init_network},
                     {'params': self.decoder.parameters(), 'lr':lr_init_network}]
        return grad_vars

    def get_optparam_groups_style_mod(self, lr_init_network, lr_finetune):
        grad_vars = [
                        {'params': self.stylizer.parameters(), 'lr': lr_init_network},
                        {'params': self.decoder.parameters(), 'lr': lr_finetune}, 
                    ]
        return grad_vars



    def vectorDiffs(self, vector_comps):
        total = 0
        
        for idx in range(len(vector_comps)):
            n_comp, n_size = vector_comps[idx].shape[1:-1]
            
            dotp = torch.matmul(vector_comps[idx].view(n_comp,n_size), vector_comps[idx].view(n_comp,n_size).transpose(-1,-2))
            non_diagonal = dotp.view(-1)[1:].view(n_comp-1, n_comp+1)[...,:-1]
            total = total + torch.mean(torch.abs(non_diagonal))
        return total

    def vector_comp_diffs(self):
        return self.vectorDiffs(self.density_line) + self.vectorDiffs(self.app_line)
    
    def density_L1(self):
        total = 0
        for idx in range(len(self.density_plane)):
            total = total + torch.mean(torch.abs(self.density_plane[idx])) + torch.mean(torch.abs(self.density_line[idx]))# + torch.mean(torch.abs(self.app_plane[idx])) + torch.mean(torch.abs(self.density_plane[idx]))
        return total
    
    def TV_loss_density(self, reg):
        total = 0
        for idx in range(len(self.density_plane)):
            total = total + reg(self.density_plane[idx]) * 1e-2 #+ reg(self.density_line[idx]) * 1e-3
        return total
        
    def TV_loss_app(self, reg):
        total = 0
        for idx in range(len(self.app_plane)):
            total = total + reg(self.app_plane[idx]) * 1e-2 #+ reg(self.app_line[idx]) * 1e-3
        return total

    def TV_loss_feature(self, reg):
        total = 0
        for idx in range(len(self.feature_plane)):
            total = total + reg(self.feature_plane[idx]) + reg(self.feature_line[idx])
        return total



    def compute_densityfeature(self, xyz_sampled):

        # plane + line basis
        coordinate_plane = torch.stack((xyz_sampled[..., self.matMode[0]], xyz_sampled[..., self.matMode[1]], xyz_sampled[..., self.matMode[2]])).detach().view(3, -1, 1, 2)
        coordinate_line = torch.stack((xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]]))
        coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1, 1, 2)

        sigma_feature = torch.zeros((xyz_sampled.shape[0],), device=xyz_sampled.device)
        for idx_plane in range(len(self.density_plane)):
            plane_coef_point = F.grid_sample(self.density_plane[idx_plane], coordinate_plane[[idx_plane]],
                                                align_corners=True).view(-1, *xyz_sampled.shape[:1])
            line_coef_point = F.grid_sample(self.density_line[idx_plane], coordinate_line[[idx_plane]],
                                            align_corners=True).view(-1, *xyz_sampled.shape[:1])
            sigma_feature = sigma_feature + torch.sum(plane_coef_point * line_coef_point, dim=0)

        return sigma_feature

    def compute_appfeature(self, xyz_sampled):

        # plane + line basis
        coordinate_plane = torch.stack((xyz_sampled[..., self.matMode[0]], xyz_sampled[..., self.matMode[1]], xyz_sampled[..., self.matMode[2]])).detach().view(3, -1, 1, 2)
        coordinate_line = torch.stack((xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]]))
        coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1, 1, 2)

        plane_coef_point,line_coef_point = [],[]
        for idx_plane in range(len(self.app_plane)):
            plane_coef_point.append(F.grid_sample(self.app_plane[idx_plane], coordinate_plane[[idx_plane]],
                                                align_corners=True).view(-1, *xyz_sampled.shape[:1]))
            line_coef_point.append(F.grid_sample(self.app_line[idx_plane], coordinate_line[[idx_plane]],
                                            align_corners=True).view(-1, *xyz_sampled.shape[:1]))
        plane_coef_point, line_coef_point = torch.cat(plane_coef_point), torch.cat(line_coef_point)


        return self.basis_mat((plane_coef_point * line_coef_point).T)

    def compute_feature(self, xyz_sampled):

        # plane + line basis
        coordinate_plane = torch.stack((xyz_sampled[..., self.matMode[0]], xyz_sampled[..., self.matMode[1]], xyz_sampled[..., self.matMode[2]])).detach().view(3, -1, 1, 2)
        coordinate_line = torch.stack((xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]]))
        coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1, 1, 2)

        plane_coef_point,line_coef_point = [],[]
        for idx_plane in range(len(self.feature_plane)):
            plane_coef_point.append(F.grid_sample(self.feature_plane[idx_plane], coordinate_plane[[idx_plane]],
                                                align_corners=True).view(-1, *xyz_sampled.shape[:1]))
            line_coef_point.append(F.grid_sample(self.feature_line[idx_plane], coordinate_line[[idx_plane]],
                                            align_corners=True).view(-1, *xyz_sampled.shape[:1]))
        plane_coef_point, line_coef_point = torch.cat(plane_coef_point), torch.cat(line_coef_point)


        return self.feature_basis_mat((plane_coef_point * line_coef_point).T)
        


    def render_feature_map(self, rays_chunk, s_mean_std_mat=None, is_train=False, ndc_ray=False, N_samples=-1):

        # sample points
        viewdirs = rays_chunk[:, 3:6]
        if ndc_ray:
            xyz_sampled, z_vals, ray_valid = self.sample_ray_ndc(rays_chunk[:, :3], viewdirs, is_train=is_train,N_samples=N_samples)
            dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], torch.zeros_like(z_vals[:, :1])), dim=-1)
            rays_norm = torch.norm(viewdirs, dim=-1, keepdim=True)
            dists = dists * rays_norm
        else:
            xyz_sampled, z_vals, ray_valid = self.sample_ray(rays_chunk[:, :3], viewdirs, is_train=is_train,N_samples=N_samples)
            dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], torch.zeros_like(z_vals[:, :1])), dim=-1)
        
        if self.alphaMask is not None:
            alphas = self.alphaMask.sample_alpha(xyz_sampled[ray_valid])
            alpha_mask = alphas > 0
            ray_invalid = ~ray_valid
            ray_invalid[ray_valid] |= (~alpha_mask)
            ray_valid = ~ray_invalid


        sigma = torch.zeros(xyz_sampled.shape[:-1], device=xyz_sampled.device)
        if s_mean_std_mat is not None:
            features = torch.zeros((*xyz_sampled.shape[:2], self.stylizer.embed_dim), device=xyz_sampled.device)
        else:
            features = torch.zeros((*xyz_sampled.shape[:2], 256), device=xyz_sampled.device)


        if ray_valid.any():
            xyz_sampled = self.normalize_coord(xyz_sampled)
            sigma_feature = self.compute_densityfeature(xyz_sampled[ray_valid])

            validsigma = self.feature2density(sigma_feature)
            sigma[ray_valid] = validsigma


        alpha, weight, bg_weight = raw2alpha(sigma, dists * self.distance_scale)

        app_mask = weight > self.rayMarch_weight_thres

        if app_mask.any():
            valid_features = self.compute_feature(xyz_sampled[app_mask]) # [n_valid_points~40k if not specify nSamples, C=256]
            
            # transform content on 3d
            if s_mean_std_mat is not None:
                valid_features = self.stylizer.transform_content_3D(valid_features.transpose(0,1)[None,...])
                valid_features = valid_features.squeeze(0).transpose(0,1)

            features[app_mask] = valid_features

        feature_map = torch.sum(weight[..., None] * features, -2)
        acc_map = torch.sum(weight, -1)
        
        # style transfer on 2d
        if s_mean_std_mat is not None:
            feature_map = self.stylizer.transfer_style_2D(s_mean_std_mat, feature_map.transpose(0,1)[None,...], acc_map)
            feature_map = feature_map.squeeze().transpose(0,1)

        return feature_map, acc_map

    def render_depth_map(self, rays_chunk, is_train=False, ndc_ray=False, N_samples=-1):

        # sample points
        viewdirs = rays_chunk[:, 3:6]
        if ndc_ray:
            xyz_sampled, z_vals, ray_valid = self.sample_ray_ndc(rays_chunk[:, :3], viewdirs, is_train=is_train,N_samples=N_samples)
            dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], torch.zeros_like(z_vals[:, :1])), dim=-1)
            rays_norm = torch.norm(viewdirs, dim=-1, keepdim=True)
            dists = dists * rays_norm
        else:
            xyz_sampled, z_vals, ray_valid = self.sample_ray(rays_chunk[:, :3], viewdirs, is_train=is_train,N_samples=N_samples)
            dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], torch.zeros_like(z_vals[:, :1])), dim=-1)
        
        if self.alphaMask is not None:
            alphas = self.alphaMask.sample_alpha(xyz_sampled[ray_valid])
            alpha_mask = alphas > 0
            ray_invalid = ~ray_valid
            ray_invalid[ray_valid] |= (~alpha_mask)
            ray_valid = ~ray_invalid

        sigma = torch.zeros(xyz_sampled.shape[:-1], device=xyz_sampled.device)

        if ray_valid.any():
            xyz_sampled = self.normalize_coord(xyz_sampled)
            sigma_feature = self.compute_densityfeature(xyz_sampled[ray_valid])
            validsigma = self.feature2density(sigma_feature)
            sigma[ray_valid] = validsigma

        alpha, weight, bg_weight = raw2alpha(sigma, dists * self.distance_scale)

        acc_map = torch.sum(weight, -1)

        depth_map = torch.sum(weight * z_vals, -1)
        depth_map = depth_map + (1. - acc_map) * rays_chunk[..., -1]

        return depth_map # [n_rays]



    @torch.no_grad()
    def up_sampling_VM(self, plane_coef, line_coef, res_target):

        for i in range(len(self.vecMode)):
            vec_id = self.vecMode[i]
            mat_id_0, mat_id_1 = self.matMode[i]
            plane_coef[i] = torch.nn.Parameter(
                F.interpolate(plane_coef[i].data, size=(res_target[mat_id_1], res_target[mat_id_0]), mode='bilinear',
                              align_corners=True))
            line_coef[i] = torch.nn.Parameter(
                F.interpolate(line_coef[i].data, size=(res_target[vec_id], 1), mode='bilinear', align_corners=True))


        return plane_coef, line_coef

    @torch.no_grad()
    def upsample_volume_grid(self, res_target):
        self.app_plane, self.app_line = self.up_sampling_VM(self.app_plane, self.app_line, res_target)
        self.density_plane, self.density_line = self.up_sampling_VM(self.density_plane, self.density_line, res_target)

        self.update_stepSize(res_target)
        print(f'upsamping to {res_target}')

    @torch.no_grad()
    def shrink(self, new_aabb):
        print("====> shrinking ...")
        xyz_min, xyz_max = new_aabb
        t_l, b_r = (xyz_min - self.aabb[0]) / self.units, (xyz_max - self.aabb[0]) / self.units
        # print(new_aabb, self.aabb)
        # print(t_l, b_r,self.alphaMask.alpha_volume.shape)
        t_l, b_r = torch.round(torch.round(t_l)).long(), torch.round(b_r).long() + 1
        b_r = torch.stack([b_r, self.gridSize]).amin(0)

        for i in range(len(self.vecMode)):
            mode0 = self.vecMode[i]
            self.density_line[i] = torch.nn.Parameter(
                self.density_line[i].data[...,t_l[mode0]:b_r[mode0],:]
            )
            self.app_line[i] = torch.nn.Parameter(
                self.app_line[i].data[...,t_l[mode0]:b_r[mode0],:]
            )
            mode0, mode1 = self.matMode[i]
            self.density_plane[i] = torch.nn.Parameter(
                self.density_plane[i].data[...,t_l[mode1]:b_r[mode1],t_l[mode0]:b_r[mode0]]
            )
            self.app_plane[i] = torch.nn.Parameter(
                self.app_plane[i].data[...,t_l[mode1]:b_r[mode1],t_l[mode0]:b_r[mode0]]
            )


        if not torch.all(self.alphaMask.gridSize == self.gridSize):
            t_l_r, b_r_r = t_l / (self.gridSize-1), (b_r-1) / (self.gridSize-1)
            correct_aabb = torch.zeros_like(new_aabb)
            correct_aabb[0] = (1-t_l_r)*self.aabb[0] + t_l_r*self.aabb[1]
            correct_aabb[1] = (1-b_r_r)*self.aabb[0] + b_r_r*self.aabb[1]
            print("aabb", new_aabb, "\ncorrect aabb", correct_aabb)
            new_aabb = correct_aabb

        newSize = b_r - t_l
        self.aabb = new_aabb
        self.update_stepSize((newSize[0], newSize[1], newSize[2]))


================================================
FILE: models/tensorBase.py
================================================
import torch
import torch.nn
import torch.nn.functional as F
from .sh import eval_sh_bases
import numpy as np
import time


def positional_encoding(positions, freqs):
    
        freq_bands = (2**torch.arange(freqs).float()).to(positions.device)  # (F,)
        pts = (positions[..., None] * freq_bands).reshape(
            positions.shape[:-1] + (freqs * positions.shape[-1], ))  # (..., DF)
        pts = torch.cat([torch.sin(pts), torch.cos(pts)], dim=-1)
        return pts

def raw2alpha(sigma, dist):
    # sigma, dist  [N_rays, N_samples]
    alpha = 1. - torch.exp(-sigma*dist)

    T = torch.cumprod(torch.cat([torch.ones(alpha.shape[0], 1).to(alpha.device), 1. - alpha + 1e-10], -1), -1)

    weights = alpha * T[:, :-1]  # [N_rays, N_samples]
    return alpha, weights, T[:,-1:]


def SHRender(xyz_sampled, viewdirs, features):
    sh_mult = eval_sh_bases(2, viewdirs)[:, None]
    rgb_sh = features.view(-1, 3, sh_mult.shape[-1])
    rgb = torch.relu(torch.sum(sh_mult * rgb_sh, dim=-1) + 0.5)
    return rgb


def RGBRender(xyz_sampled, viewdirs, features):

    rgb = features
    return rgb

class AlphaGridMask(torch.nn.Module):
    def __init__(self, device, aabb, alpha_volume):
        super(AlphaGridMask, self).__init__()
        self.device = device

        self.aabb=aabb.to(self.device)
        self.aabbSize = self.aabb[1] - self.aabb[0]
        self.invgridSize = 1.0/self.aabbSize * 2
        self.alpha_volume = alpha_volume.view(1,1,*alpha_volume.shape[-3:])
        self.gridSize = torch.LongTensor([alpha_volume.shape[-1],alpha_volume.shape[-2],alpha_volume.shape[-3]]).to(self.device)

    def sample_alpha(self, xyz_sampled):
        xyz_sampled = self.normalize_coord(xyz_sampled)
        alpha_vals = F.grid_sample(self.alpha_volume, xyz_sampled.view(1,-1,1,1,3), align_corners=True).view(-1)

        return alpha_vals

    def normalize_coord(self, xyz_sampled):
        return (xyz_sampled-self.aabb[0]) * self.invgridSize - 1


class MLPRender_Fea(torch.nn.Module):
    def __init__(self,inChanel, viewpe=6, feape=6, featureC=128):
        super(MLPRender_Fea, self).__init__()

        self.in_mlpC = 2*viewpe*3 + 2*feape*inChanel + 3 + inChanel
        self.viewpe = viewpe
        self.feape = feape
        layer1 = torch.nn.Linear(self.in_mlpC, featureC)
        layer2 = torch.nn.Linear(featureC, featureC)
        layer3 = torch.nn.Linear(featureC,3)

        self.mlp = torch.nn.Sequential(layer1, torch.nn.ReLU(inplace=True), layer2, torch.nn.ReLU(inplace=True), layer3)
        torch.nn.init.constant_(self.mlp[-1].bias, 0)

    def forward(self, pts, viewdirs, features):
        indata = [features, viewdirs]
        if self.feape > 0:
            indata += [positional_encoding(features, self.feape)]
        if self.viewpe > 0:
            indata += [positional_encoding(viewdirs, self.viewpe)]
        mlp_in = torch.cat(indata, dim=-1)
        rgb = self.mlp(mlp_in)
        rgb = torch.sigmoid(rgb)

        return rgb

class MLPRender_PE(torch.nn.Module):
    def __init__(self,inChanel, viewpe=6, pospe=6, featureC=128):
        super(MLPRender_PE, self).__init__()

        self.in_mlpC = (3+2*viewpe*3)+ (3+2*pospe*3)  + inChanel #
        self.viewpe = viewpe
        self.pospe = pospe
        layer1 = torch.nn.Linear(self.in_mlpC, featureC)
        layer2 = torch.nn.Linear(featureC, featureC)
        layer3 = torch.nn.Linear(featureC,3)

        self.mlp = torch.nn.Sequential(layer1, torch.nn.ReLU(inplace=True), layer2, torch.nn.ReLU(inplace=True), layer3)
        torch.nn.init.constant_(self.mlp[-1].bias, 0)

    def forward(self, pts, viewdirs, features):
        indata = [features, viewdirs]
        if self.pospe > 0:
            indata += [positional_encoding(pts, self.pospe)]
        if self.viewpe > 0:
            indata += [positional_encoding(viewdirs, self.viewpe)]
        mlp_in = torch.cat(indata, dim=-1)
        rgb = self.mlp(mlp_in)
        rgb = torch.sigmoid(rgb)

        return rgb

class MLPRender(torch.nn.Module):
    def __init__(self,inChanel, viewpe=6, featureC=128):
        super(MLPRender, self).__init__()

        self.in_mlpC = (3+2*viewpe*3) + inChanel
        self.viewpe = viewpe
        
        layer1 = torch.nn.Linear(self.in_mlpC, featureC)
        layer2 = torch.nn.Linear(featureC, featureC)
        layer3 = torch.nn.Linear(featureC,3)

        self.mlp = torch.nn.Sequential(layer1, torch.nn.ReLU(inplace=True), layer2, torch.nn.ReLU(inplace=True), layer3)
        torch.nn.init.constant_(self.mlp[-1].bias, 0)

    def forward(self, pts, viewdirs, features):
        indata = [features, viewdirs]
        if self.viewpe > 0:
            indata += [positional_encoding(viewdirs, self.viewpe)]
        mlp_in = torch.cat(indata, dim=-1)
        rgb = self.mlp(mlp_in)
        rgb = torch.sigmoid(rgb)

        return rgb



class TensorBase(torch.nn.Module):
    def __init__(self, aabb, gridSize, device, density_n_comp = 8, appearance_n_comp = 24, app_dim = 27,
                    shadingMode = 'MLP_PE', alphaMask = None, near_far=[2.0,6.0],
                    density_shift = -10, alphaMask_thres=0.001, distance_scale=25, rayMarch_weight_thres=0.0001,
                    pos_pe = 6, view_pe = 6, fea_pe = 6, featureC=128, step_ratio=2.0,
                    fea2denseAct = 'softplus'):
        super(TensorBase, self).__init__()

        self.density_n_comp = density_n_comp
        self.app_n_comp = appearance_n_comp
        self.app_dim = app_dim
        self.aabb = aabb
        self.alphaMask = alphaMask
        self.device=device

        self.density_shift = density_shift
        self.alphaMask_thres = alphaMask_thres
        self.distance_scale = distance_scale
        self.rayMarch_weight_thres = rayMarch_weight_thres
        self.fea2denseAct = fea2denseAct

        self.near_far = near_far
        self.step_ratio = step_ratio


        self.update_stepSize(gridSize)

        self.matMode = [[0,1], [0,2], [1,2]]
        self.vecMode =  [2, 1, 0]
        self.comp_w = [1,1,1]


        self.init_svd_volume(gridSize[0], device)

        self.shadingMode, self.pos_pe, self.view_pe, self.fea_pe, self.featureC = shadingMode, pos_pe, view_pe, fea_pe, featureC
        self.init_render_func(shadingMode, pos_pe, view_pe, fea_pe, featureC, device)

    def init_render_func(self, shadingMode, pos_pe, view_pe, fea_pe, featureC, device):
        if shadingMode == 'MLP_PE':
            self.renderModule = MLPRender_PE(self.app_dim, view_pe, pos_pe, featureC).to(device)
        elif shadingMode == 'MLP_Fea':
            self.renderModule = MLPRender_Fea(self.app_dim, view_pe, fea_pe, featureC).to(device)
        elif shadingMode == 'MLP':
            self.renderModule = MLPRender(self.app_dim, view_pe, featureC).to(device)
        elif shadingMode == 'SH':
            self.renderModule = SHRender
        elif shadingMode == 'RGB':
            assert self.app_dim == 3
            self.renderModule = RGBRender
        else:
            print("Unrecognized shading module")
            exit()
        print("pos_pe", pos_pe, "view_pe", view_pe, "fea_pe", fea_pe)

    def update_stepSize(self, gridSize):
        print("aabb", self.aabb.view(-1))
        print("grid size", gridSize)
        self.aabbSize = self.aabb[1] - self.aabb[0]
        self.invaabbSize = 2.0/self.aabbSize
        self.gridSize= torch.LongTensor(gridSize).to(self.device)
        self.units=self.aabbSize / (self.gridSize-1)
        self.stepSize=torch.mean(self.units)*self.step_ratio
        self.aabbDiag = torch.sqrt(torch.sum(torch.square(self.aabbSize)))
        self.nSamples=int((self.aabbDiag / self.stepSize).item()) + 1
        print("sampling step size: ", self.stepSize)
        print("sampling number: ", self.nSamples)

    def init_svd_volume(self, res, device):
        pass

    def compute_features(self, xyz_sampled):
        pass
    
    def compute_densityfeature(self, xyz_sampled):
        pass
    
    def compute_appfeature(self, xyz_sampled):
        pass
    
    def normalize_coord(self, xyz_sampled):
        return (xyz_sampled-self.aabb[0]) * self.invaabbSize - 1

    def get_optparam_groups(self, lr_init_spatial = 0.02, lr_init_network = 0.001):
        pass

    def get_kwargs(self):
        return {
            'aabb': self.aabb,
            'gridSize':self.gridSize.tolist(),
            'density_n_comp': self.density_n_comp,
            'appearance_n_comp': self.app_n_comp,
            'app_dim': self.app_dim,

            'density_shift': self.density_shift,
            'alphaMask_thres': self.alphaMask_thres,
            'distance_scale': self.distance_scale,
            'rayMarch_weight_thres': self.rayMarch_weight_thres,
            'fea2denseAct': self.fea2denseAct,

            'near_far': self.near_far,
            'step_ratio': self.step_ratio,

            'shadingMode': self.shadingMode,
            'pos_pe': self.pos_pe,
            'view_pe': self.view_pe,
            'fea_pe': self.fea_pe,
            'featureC': self.featureC
        }

    def save(self, path):
        kwargs = self.get_kwargs()
        ckpt = {'kwargs': kwargs, 'state_dict': self.state_dict()}
        if self.alphaMask is not None:
            alpha_volume = self.alphaMask.alpha_volume.bool().cpu().numpy()
            ckpt.update({'alphaMask.shape':alpha_volume.shape})
            ckpt.update({'alphaMask.mask':np.packbits(alpha_volume.reshape(-1))})
            ckpt.update({'alphaMask.aabb': self.alphaMask.aabb.cpu()})
        torch.save(ckpt, path)

    def load(self, ckpt):
        if 'alphaMask.aabb' in ckpt.keys():
            length = np.prod(ckpt['alphaMask.shape'])
            alpha_volume = torch.from_numpy(np.unpackbits(ckpt['alphaMask.mask'])[:length].reshape(ckpt['alphaMask.shape']))
            self.alphaMask = AlphaGridMask(self.device, ckpt['alphaMask.aabb'].to(self.device), alpha_volume.float().to(self.device))
        self.load_state_dict(ckpt['state_dict'])


    def sample_ray_ndc(self, rays_o, rays_d, is_train=True, N_samples=-1):
        N_samples = N_samples if N_samples > 0 else self.nSamples
        near, far = self.near_far
        interpx = torch.linspace(near, far, N_samples).unsqueeze(0).to(rays_o)
        if is_train:
            interpx += torch.rand_like(interpx).to(rays_o) * ((far - near) / N_samples)

        rays_pts = rays_o[..., None, :] + rays_d[..., None, :] * interpx[..., None]
        mask_outbbox = ((self.aabb[0] > rays_pts) | (rays_pts > self.aabb[1])).any(dim=-1)
        return rays_pts, interpx, ~mask_outbbox

    def sample_ray(self, rays_o, rays_d, is_train=True, N_samples=-1):
        N_samples = N_samples if N_samples>0 else self.nSamples
        stepsize = self.stepSize
        near, far = self.near_far
        vec = torch.where(rays_d==0, torch.full_like(rays_d, 1e-6), rays_d)
        rate_a = (self.aabb[1] - rays_o) / vec
        rate_b = (self.aabb[0] - rays_o) / vec
        t_min = torch.minimum(rate_a, rate_b).amax(-1).clamp(min=near, max=far)

        rng = torch.arange(N_samples)[None].float()
        if is_train:
            rng = rng.repeat(rays_d.shape[-2],1)
            rng += torch.rand_like(rng[:,[0]])
        step = stepsize * rng.to(rays_o.device)
        interpx = (t_min[...,None] + step)

        rays_pts = rays_o[...,None,:] + rays_d[...,None,:] * interpx[...,None]
        mask_outbbox = ((self.aabb[0]>rays_pts) | (rays_pts>self.aabb[1])).any(dim=-1)

        return rays_pts, interpx, ~mask_outbbox


    def shrink(self, new_aabb, voxel_size):
        pass

    @torch.no_grad()
    def getDenseAlpha(self,gridSize=None):
        gridSize = self.gridSize if gridSize is None else gridSize

        samples = torch.stack(torch.meshgrid(
            torch.linspace(0, 1, gridSize[0]),
            torch.linspace(0, 1, gridSize[1]),
            torch.linspace(0, 1, gridSize[2]),
        ), -1).to(self.device)
        dense_xyz = self.aabb[0] * (1-samples) + self.aabb[1] * samples

        # dense_xyz = dense_xyz
        # print(self.stepSize, self.distance_scale*self.aabbDiag)
        alpha = torch.zeros_like(dense_xyz[...,0])
        for i in range(gridSize[0]):
            alpha[i] = self.compute_alpha(dense_xyz[i].view(-1,3), self.stepSize).view((gridSize[1], gridSize[2]))
        return alpha, dense_xyz

    @torch.no_grad()
    def updateAlphaMask(self, gridSize=(200,200,200)):

        alpha, dense_xyz = self.getDenseAlpha(gridSize)
        dense_xyz = dense_xyz.transpose(0,2).contiguous()
        alpha = alpha.clamp(0,1).transpose(0,2).contiguous()[None,None]
        total_voxels = gridSize[0] * gridSize[1] * gridSize[2]

        ks = 3
        alpha = F.max_pool3d(alpha, kernel_size=ks, padding=ks // 2, stride=1).view(gridSize[::-1])
        alpha[alpha>=self.alphaMask_thres] = 1
        alpha[alpha<self.alphaMask_thres] = 0

        self.alphaMask = AlphaGridMask(self.device, self.aabb, alpha)

        valid_xyz = dense_xyz[alpha>0.5]

        xyz_min = valid_xyz.amin(0)
        xyz_max = valid_xyz.amax(0)

        new_aabb = torch.stack((xyz_min, xyz_max))

        total = torch.sum(alpha)
        print(f"bbox: {xyz_min, xyz_max} alpha rest %%%f"%(total/total_voxels*100))
        return new_aabb

    @torch.no_grad()
    def filtering_rays(self, all_rays, all_rgbs, N_samples=256, chunk=10240*5, bbox_only=False):
        print('========> filtering rays ...')
        tt = time.time()

        N = torch.tensor(all_rays.shape[:-1]).prod()

        mask_filtered = []
        idx_chunks = torch.split(torch.arange(N), chunk)
        for idx_chunk in idx_chunks:
            rays_chunk = all_rays[idx_chunk].to(self.device)

            rays_o, rays_d = rays_chunk[..., :3], rays_chunk[..., 3:6]
            if bbox_only:
                vec = torch.where(rays_d == 0, torch.full_like(rays_d, 1e-6), rays_d)
                rate_a = (self.aabb[1] - rays_o) / vec
                rate_b = (self.aabb[0] - rays_o) / vec
                t_min = torch.minimum(rate_a, rate_b).amax(-1)#.clamp(min=near, max=far)
                t_max = torch.maximum(rate_a, rate_b).amin(-1)#.clamp(min=near, max=far)
                mask_inbbox = t_max > t_min

            else:
                xyz_sampled, _,_ = self.sample_ray(rays_o, rays_d, N_samples=N_samples, is_train=False)
                mask_inbbox= (self.alphaMask.sample_alpha(xyz_sampled).view(xyz_sampled.shape[:-1]) > 0).any(-1)

            mask_filtered.append(mask_inbbox.cpu())

        mask_filtered = torch.cat(mask_filtered).view(all_rgbs.shape[:-1])

        print(f'Ray filtering done! takes {time.time()-tt} s. ray mask ratio: {torch.sum(mask_filtered) / N}')
        return all_rays[mask_filtered], all_rgbs[mask_filtered]


    def feature2density(self, density_features):
        if self.fea2denseAct == "softplus":
            return F.softplus(density_features+self.density_shift)
        elif self.fea2denseAct == "relu":
            return F.relu(density_features)


    def compute_alpha(self, xyz_locs, length=1):

        if self.alphaMask is not None:
            alphas = self.alphaMask.sample_alpha(xyz_locs)
            alpha_mask = alphas > 0
        else:
            alpha_mask = torch.ones_like(xyz_locs[:,0], dtype=bool)
            

        sigma = torch.zeros(xyz_locs.shape[:-1], device=xyz_locs.device)

        if alpha_mask.any():
            xyz_sampled = self.normalize_coord(xyz_locs[alpha_mask])
            sigma_feature = self.compute_densityfeature(xyz_sampled)
            validsigma = self.feature2density(sigma_feature)
            sigma[alpha_mask] = validsigma
        

        alpha = 1 - torch.exp(-sigma*length).view(xyz_locs.shape[:-1])

        return alpha


    def forward(self, rays_chunk, white_bg=True, is_train=False, ndc_ray=False, N_samples=-1):

        # sample points
        viewdirs = rays_chunk[:, 3:6]
        if ndc_ray:
            xyz_sampled, z_vals, ray_valid = self.sample_ray_ndc(rays_chunk[:, :3], viewdirs, is_train=is_train,N_samples=N_samples)
            dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], torch.zeros_like(z_vals[:, :1])), dim=-1)
            rays_norm = torch.norm(viewdirs, dim=-1, keepdim=True)
            dists = dists * rays_norm
            viewdirs = viewdirs / rays_norm
        else:
            xyz_sampled, z_vals, ray_valid = self.sample_ray(rays_chunk[:, :3], viewdirs, is_train=is_train,N_samples=N_samples)
            dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], torch.zeros_like(z_vals[:, :1])), dim=-1)
        viewdirs = viewdirs.view(-1, 1, 3).expand(xyz_sampled.shape)
        
        if self.alphaMask is not None:
            alphas = self.alphaMask.sample_alpha(xyz_sampled[ray_valid])
            alpha_mask = alphas > 0
            ray_invalid = ~ray_valid
            ray_invalid[ray_valid] |= (~alpha_mask)
            ray_valid = ~ray_invalid


        sigma = torch.zeros(xyz_sampled.shape[:-1], device=xyz_sampled.device)
        rgb = torch.zeros((*xyz_sampled.shape[:2], 3), device=xyz_sampled.device)

        if ray_valid.any():
            xyz_sampled = self.normalize_coord(xyz_sampled)
            sigma_feature = self.compute_densityfeature(xyz_sampled[ray_valid])

            validsigma = self.feature2density(sigma_feature)
            sigma[ray_valid] = validsigma


        alpha, weight, bg_weight = raw2alpha(sigma, dists * self.distance_scale)

        app_mask = weight > self.rayMarch_weight_thres

        if app_mask.any():
            app_features = self.compute_appfeature(xyz_sampled[app_mask])
            valid_rgbs = self.renderModule(xyz_sampled[app_mask], viewdirs[app_mask], app_features)
            rgb[app_mask] = valid_rgbs

        acc_map = torch.sum(weight, -1)
        rgb_map = torch.sum(weight[..., None] * rgb, -2)

        if white_bg or (is_train and torch.rand((1,))<0.5):
            rgb_map = rgb_map + (1. - acc_map[..., None])

        
        rgb_map = rgb_map.clamp(0,1)

        with torch.no_grad():
            depth_map = torch.sum(weight * z_vals, -1)
            depth_map = depth_map + (1. - acc_map) * rays_chunk[..., -1]

        return rgb_map, depth_map # rgb, sigma, alpha, weight, bg_weight



================================================
FILE: opt.py
================================================
import configargparse

def config_parser(cmd=None):
    parser = configargparse.ArgumentParser()
    parser.add_argument('--config', is_config_file=True,
                        help='config file path')
    parser.add_argument("--expname", type=str,
                        help='experiment name')
    parser.add_argument("--basedir", type=str, default='./log',
                        help='where to store ckpts and logs')
    parser.add_argument("--add_timestamp", type=int, default=0,
                        help='add timestamp to dir')
    parser.add_argument("--datadir", type=str, default='./data/llff/fern',
                        help='input data directory')
    parser.add_argument("--wikiartdir", type=str, default='./data/WikiArt',
                        help='input data directory')
    parser.add_argument("--progress_refresh_rate", type=int, default=10,
                        help='how many iterations to show psnrs or iters')

    parser.add_argument('--with_depth', action='store_true')
    parser.add_argument('--downsample_train', type=float, default=1.0)
    parser.add_argument('--downsample_test', type=float, default=1.0)

    parser.add_argument('--model_name', type=str, default='TensorVMSplit',
                        choices=['TensorVMSplit', 'TensorCP'])

    # loader options
    parser.add_argument("--batch_size", type=int, default=4096)
    parser.add_argument("--n_iters", type=int, default=30000)
    parser.add_argument('--dataset_name', type=str, default='blender',
                        choices=['blender', 'llff', 'nsvf', 'dtu','tankstemple', 'own_data'])


    # training options
    parser.add_argument("--patch_size", type=int, default=256,
                        help='patch_size for training') 
    parser.add_argument("--chunk_size", type=int, default=4096,
                        help='chunk_size for training')           
    # learning rate
    parser.add_argument("--lr_init", type=float, default=0.02,
                        help='learning rate')    
    parser.add_argument("--lr_basis", type=float, default=1e-4,
                        help='learning rate')
    parser.add_argument("--lr_finetune", type=float, default=1e-5,
                        help='learning rate')
    parser.add_argument("--lr_decay_iters", type=int, default=-1,
                        help = 'number of iterations the lr will decay to the target ratio; -1 will set it to n_iters')
    parser.add_argument("--lr_decay_target_ratio", type=float, default=0.1,
                        help='the target decay ratio; after decay_iters inital lr decays to lr*ratio')
    parser.add_argument("--lr_upsample_reset", type=int, default=1,
                        help='reset lr to inital after upsampling')

    # loss
    parser.add_argument("--L1_weight_inital", type=float, default=0.0,
                        help='loss weight')
    parser.add_argument("--L1_weight_rest", type=float, default=0,
                        help='loss weight')
    parser.add_argument("--Ortho_weight", type=float, default=0.0,
                        help='loss weight')
    parser.add_argument("--TV_weight_density", type=float, default=0.0,
                        help='loss weight')
    parser.add_argument("--TV_weight_app", type=float, default=0.0,
                        help='loss weight')
    parser.add_argument("--TV_weight_feature", type=float, default=0.0,
                        help='loss weight')
    parser.add_argument("--style_weight", type=float, default=0,
                        help='loss weight')
    parser.add_argument("--content_weight", type=float, default=0,
                        help='loss weight')
    parser.add_argument("--image_tv_weight", type=float, default=0,
                        help='loss weight')
    parser.add_argument("--featuremap_tv_weight", type=float, default=0,
                        help='loss weight')

    
    # model
    # volume options
    parser.add_argument("--n_lamb_sigma", type=int, action="append")
    parser.add_argument("--n_lamb_sh", type=int, action="append")
    parser.add_argument("--data_dim_color", type=int, default=27)

    parser.add_argument("--rm_weight_mask_thre", type=float, default=0.0001,
                        help='mask points in ray marching')
    parser.add_argument("--alpha_mask_thre", type=float, default=0.0001,
                        help='threshold for creating alpha mask volume')
    parser.add_argument("--distance_scale", type=float, default=25,
                        help='scaling sampling distance for computation')
    parser.add_argument("--density_shift", type=float, default=-10,
                        help='shift density in softplus; making density = 0  when feature == 0')
                        
    # network decoder
    parser.add_argument("--shadingMode", type=str, default="MLP_Fea",
                        help='which shading mode to use')
    parser.add_argument("--pos_pe", type=int, default=6,
                        help='number of pe for pos')
    parser.add_argument("--view_pe", type=int, default=6,
                        help='number of pe for view')
    parser.add_argument("--fea_pe", type=int, default=6,
                        help='number of pe for features')
    parser.add_argument("--featureC", type=int, default=128,
                        help='hidden feature channel in MLP')
    

    # test option
    parser.add_argument("--ckpt", type=str, default=None,
                        help='specific weights npy file to reload for coarse network')
    parser.add_argument("--render_only", type=int, default=0)
    parser.add_argument("--render_test", type=int, default=0)
    parser.add_argument("--render_train", type=int, default=0)
    parser.add_argument("--render_path", type=int, default=0)
    parser.add_argument("--export_mesh", type=int, default=0)
    parser.add_argument("--style_img", type=str, required=False)

    # rendering options
    parser.add_argument('--lindisp', default=False, action="store_true",
                        help='use disparity depth sampling')
    parser.add_argument("--perturb", type=float, default=1.,
                        help='set to 0. for no jitter, 1. for jitter')
    parser.add_argument("--accumulate_decay", type=float, default=0.998)
    parser.add_argument("--fea2denseAct", type=str, default='softplus')
    parser.add_argument('--ndc_ray', type=int, default=0)
    parser.add_argument('--nSamples', type=int, default=1e6,
                        help='sample point each ray, pass 1e6 if automatic adjust')
    parser.add_argument('--step_ratio',type=float,default=0.5)


    ## blender flags
    parser.add_argument("--white_bkgd", action='store_true',
                        help='set to render synthetic data on a white bkgd (always use for dvoxels)')



    parser.add_argument('--N_voxel_init',
                        type=int,
                        default=100**3)
    parser.add_argument('--N_voxel_final',
                        type=int,
                        default=300**3)
    parser.add_argument("--upsamp_list", type=int, action="append")
    parser.add_argument("--update_AlphaMask_list", type=int, action="append")

    parser.add_argument('--idx_view',
                        type=int,
                        default=0)
    # logging/saving options
    parser.add_argument("--N_vis", type=int, default=5,
                        help='N images to vis')
    parser.add_argument("--vis_every", type=int, default=10000,
                        help='frequency of visualize the image')
    if cmd is not None:
        return parser.parse_args(cmd)
    else:
        return parser.parse_args()

================================================
FILE: renderer.py
================================================
import torch,os,imageio,sys
from tqdm.auto import tqdm
from dataLoader.ray_utils import get_rays
from models.tensoRF import raw2alpha, TensorVMSplit, AlphaGridMask
from utils import *
from dataLoader.ray_utils import ndc_rays_blender, denormalize_vgg, normalize_vgg


def OctreeRender_trilinear_fast(rays, tensorf, chunk=4096, N_samples=-1, ndc_ray=False, white_bg=True, is_train=False, 
                                render_feature=False, style_img=None, device='cuda'):

    rgbs, alphas, depth_maps, weights, uncertainties = [], [], [], [], []
    features, accs = [], []
    s_mean_std_mat = None
    if style_img is not None:
        with torch.no_grad():
            style_feature = tensorf.encoder(normalize_vgg(style_img))
        s_mean_std_mat = tensorf.stylizer.get_style_mean_std_matrix(style_feature.relu3_1.flatten(2))

    N_rays_all = rays.shape[0]
    for chunk_idx in range(N_rays_all // chunk + int(N_rays_all % chunk > 0)):
        rays_chunk = rays[chunk_idx * chunk:(chunk_idx + 1) * chunk].to(device)
        
        if render_feature:
            feature_map, acc_map = tensorf.render_feature_map(rays_chunk, s_mean_std_mat=s_mean_std_mat, is_train=is_train, ndc_ray=ndc_ray, N_samples=N_samples)
            features.append(feature_map)
            accs.append(acc_map)
        else:
            rgb_map, depth_map = tensorf(rays_chunk, is_train=is_train, white_bg=white_bg, ndc_ray=ndc_ray, N_samples=N_samples)
            rgbs.append(rgb_map)
            depth_maps.append(depth_map)
    
    if render_feature:
        if style_img is not None:
            return torch.cat(features), torch.cat(accs), style_feature
        return torch.cat(features), torch.cat(accs)

    return torch.cat(rgbs), None, torch.cat(depth_maps), None, None 

def OctreeRender_trilinear_fast_depth(rays, tensorf, chunk=4096, N_samples=-1, ndc_ray=False, is_train=False, device='cuda'):

    depth_maps = []
    N_rays_all = rays.shape[0]
    for chunk_idx in range(N_rays_all // chunk + int(N_rays_all % chunk > 0)):
        rays_chunk = rays[chunk_idx * chunk:(chunk_idx + 1) * chunk].to(device)
        depth_map = tensorf.render_depth_map(rays_chunk, is_train=is_train, ndc_ray=ndc_ray, N_samples=N_samples)
        depth_maps.append(depth_map)
    
    return torch.cat(depth_maps)


@torch.no_grad()
def evaluation_feature(test_dataset, tensorf, args, renderer, chunk_size=2048, savePath=None, N_vis=10, prtx='', N_samples=-1,
               white_bg=False, ndc_ray=False, compute_extra_metrics=False, style_img=None, device='cuda'):
    '''
    To see if the decoded feature map is similar to gt rgb map
    '''
    PSNRs, rgb_maps, vis_feature_maps = [], [], []
    ssims,l_alex,l_vgg=[],[],[]
    os.makedirs(savePath, exist_ok=True)
    os.makedirs(savePath + '/feature', exist_ok=True)
    W, H = test_dataset.img_wh

    try:
        tqdm._instances.clear()
    except Exception:
        pass

    near_far = test_dataset.near_far
    img_eval_interval = 1 if N_vis < 0 else max(test_dataset.all_rays_stack.shape[0] // N_vis,1)
    idxs = list(range(0, test_dataset.all_rays_stack.shape[0], img_eval_interval))
    for idx, samples in tqdm(enumerate(test_dataset.all_rays_stack[0::img_eval_interval]), file=sys.stdout):        

        rays = samples.view(-1,samples.shape[-1])

        if style_img is None:
            feature_map, _ = renderer(rays, tensorf, chunk=chunk_size, N_samples=N_samples, ndc_ray=ndc_ray, 
                                        white_bg = white_bg, render_feature=True, device=device)
        else:
            feature_map, _, _ = renderer(rays, tensorf, chunk=chunk_size, N_samples=N_samples, ndc_ray=ndc_ray, 
                                white_bg = white_bg, render_feature=True, style_img=style_img, device=device)
                            
        feature_map = feature_map.reshape(H, W, 256)[None,...].permute(0,3,1,2)

        recon_rgb = denormalize_vgg(tensorf.decoder(feature_map))
        recon_rgb = recon_rgb.permute(0,2,3,1).clamp(0,1)

        vis_feature_map = torch.sigmoid(feature_map[:, [1,2,3], :, :].permute(0,2,3,1))
        
        if test_dataset.white_bg:
            mask = test_dataset.all_masks[idx:idx+1].to(device)
            recon_rgb = mask*recon_rgb + (1.-mask)
            vis_feature_map = mask*vis_feature_map + (1.-mask)

        recon_rgb = recon_rgb.reshape(H, W, 3).cpu()
        vis_feature_map = vis_feature_map.squeeze().cpu()

        if len(test_dataset.all_rgbs_stack):
            gt_rgb = test_dataset.all_rgbs_stack[idxs[idx]].view(H, W, 3)
            loss = torch.mean((recon_rgb - gt_rgb) ** 2)
            PSNRs.append(-10.0 * np.log(loss.item()) / np.log(10.0))

            if compute_extra_metrics:
                ssim = rgb_ssim(recon_rgb, gt_rgb, 1)
                l_a = rgb_lpips(gt_rgb.numpy(), recon_rgb.numpy(), 'alex', tensorf.device)
                l_v = rgb_lpips(gt_rgb.numpy(), recon_rgb.numpy(), 'vgg', tensorf.device)
                ssims.append(ssim)
                l_alex.append(l_a)
                l_vgg.append(l_v)

        recon_rgb = (recon_rgb.numpy() * 255).astype('uint8')
        vis_feature_map = (vis_feature_map.numpy() * 255).astype('uint8')
        gt_rgb = (gt_rgb.numpy() * 255).astype('uint8')

        if savePath is not None:
            # rgb_map = np.concatenate((recon_rgb, gt_rgb), axis=1)
            rgb_maps.append(recon_rgb)
            vis_feature_maps.append(vis_feature_map)
            imageio.imwrite(f'{savePath}/{prtx}{idx:03d}.png', recon_rgb)
            imageio.imwrite(f'{savePath}/feature/feature_{idx:03d}.png', vis_feature_map)

    imageio.mimwrite(f'{savePath}/{prtx}video.mp4', np.stack(rgb_maps), fps=30, quality=8)
    imageio.mimwrite(f'{savePath}/feature/feature_video.mp4', np.stack(vis_feature_maps), fps=30, quality=8)

    if PSNRs:
        psnr = np.mean(np.asarray(PSNRs))
        if compute_extra_metrics:
            ssim = np.mean(np.asarray(ssims))
            l_a = np.mean(np.asarray(l_alex))
            l_v = np.mean(np.asarray(l_vgg))
            np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr, ssim, l_a, l_v]))
        else:
            np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr]))


    return PSNRs

@torch.no_grad()
def evaluation_feature_path(test_dataset, tensorf, c2ws, renderer, chunk_size=2048, savePath=None, N_vis=5, prtx='', N_samples=-1,
                    white_bg=False, ndc_ray=False, compute_extra_metrics=False, style_img=None, device='cuda'):
    PSNRs, rgb_maps, vis_feature_maps, depth_maps = [], [], [], []
    ssims,l_alex,l_vgg=[],[],[]
    os.makedirs(savePath, exist_ok=True)
    os.makedirs(savePath + '/feature', exist_ok=True)
    W, H = test_dataset.img_wh

    try:
        tqdm._instances.clear()
    except Exception:
        pass

    near_far = test_dataset.near_far
    for idx, c2w in tqdm(enumerate(c2ws)):

        c2w = torch.FloatTensor(c2w)
        rays_o, rays_d = get_rays(test_dataset.directions, c2w)  # both (h*w, 3)
        if ndc_ray:
            rays_o, rays_d = ndc_rays_blender(H, W, test_dataset.focal[0], 1.0, rays_o, rays_d)
        rays = torch.cat([rays_o, rays_d], 1).reshape(H, W, 6).permute(2,0,1)  # (6,H,W)
        rays = rays.permute(1,2,0).reshape(-1,6) # (H * W, 6)

        if style_img is None:
            feature_map, _ = renderer(rays, tensorf, chunk=chunk_size, N_samples=N_samples, ndc_ray=ndc_ray, 
                                        white_bg = white_bg, render_feature=True, device=device)
        else:
            feature_map, _, _ = renderer(rays, tensorf, chunk=chunk_size, N_samples=N_samples, ndc_ray=ndc_ray, 
                                white_bg = white_bg, render_feature=True, style_img=style_img, device=device)
        
        feature_map = feature_map.reshape(H, W, 256)[None,...].permute(0,3,1,2)

        recon_rgb = denormalize_vgg(tensorf.decoder(feature_map))
        recon_rgb = recon_rgb.permute(0,2,3,1).clamp(0,1)
        recon_rgb = recon_rgb.reshape(H, W, 3).cpu()
        recon_rgb = (recon_rgb.numpy() * 255).astype('uint8')
        rgb_maps.append(recon_rgb)

        vis_feature_map = torch.sigmoid(feature_map[:, [1,2,3], :, :].permute(0,2,3,1))
        vis_feature_map = vis_feature_map.squeeze().cpu()
        vis_feature_map = (vis_feature_map.numpy() * 255).astype('uint8')
        vis_feature_maps.append(vis_feature_map)
        
        if savePath is not None:
            imageio.imwrite(f'{savePath}/{prtx}{idx:03d}.png', recon_rgb)
            imageio.imwrite(f'{savePath}/feature/feature_{idx:03d}.png', vis_feature_map)

    imageio.mimwrite(f'{savePath}/{prtx}video.mp4', np.stack(rgb_maps), fps=30, quality=8)
    imageio.mimwrite(f'{savePath}/feature/feature_video.mp4', np.stack(vis_feature_maps), fps=30, quality=8)

    if PSNRs:
        psnr = np.mean(np.asarray(PSNRs))
        if compute_extra_metrics:
            ssim = np.mean(np.asarray(ssims))
            l_a = np.mean(np.asarray(l_alex))
            l_v = np.mean(np.asarray(l_vgg))
            np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr, ssim, l_a, l_v]))
        else:
            np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr]))


    return PSNRs

@torch.no_grad()
def evaluation(test_dataset,tensorf, args, renderer, savePath=None, N_vis=5, prtx='', N_samples=-1,
               white_bg=False, ndc_ray=False, compute_extra_metrics=True, device='cuda'):
    PSNRs, rgb_maps, depth_maps = [], [], []
    ssims,l_alex,l_vgg=[],[],[]
    os.makedirs(savePath, exist_ok=True)
    os.makedirs(savePath+"/rgbd", exist_ok=True)

    try:
        tqdm._instances.clear()
    except Exception:
        pass

    near_far = test_dataset.near_far
    img_eval_interval = 1 if N_vis < 0 else max(test_dataset.all_rays_stack.shape[0] // N_vis,1)
    idxs = list(range(0, test_dataset.all_rays_stack.shape[0], img_eval_interval))
    for idx, samples in tqdm(enumerate(test_dataset.all_rays_stack[0::img_eval_interval]), file=sys.stdout):

        W, H = test_dataset.img_wh
        rays = samples.view(-1,samples.shape[-1])

        rgb_map, _, depth_map, _, _ = renderer(rays, tensorf, chunk=4096, N_samples=N_samples,
                                        ndc_ray=ndc_ray, white_bg = white_bg, device=device)
        rgb_map = rgb_map.clamp(0.0, 1.0)

        rgb_map, depth_map = rgb_map.reshape(H, W, 3).cpu(), depth_map.reshape(H, W).cpu()

        depth_map, _ = visualize_depth_numpy(depth_map.numpy(),near_far)
        if len(test_dataset.all_rgbs_stack):
            gt_rgb = test_dataset.all_rgbs_stack[idxs[idx]].view(H, W, 3)
            loss = torch.mean((rgb_map - gt_rgb) ** 2)
            PSNRs.append(-10.0 * np.log(loss.item()) / np.log(10.0))

            if compute_extra_metrics:
                ssim = rgb_ssim(rgb_map, gt_rgb, 1)
                l_a = rgb_lpips(gt_rgb.numpy(), rgb_map.numpy(), 'alex', tensorf.device)
                l_v = rgb_lpips(gt_rgb.numpy(), rgb_map.numpy(), 'vgg', tensorf.device)
                ssims.append(ssim)
                l_alex.append(l_a)
                l_vgg.append(l_v)

        rgb_map = (rgb_map.numpy() * 255).astype('uint8')
        # rgb_map = np.concatenate((rgb_map, depth_map), axis=1)
        rgb_maps.append(rgb_map)
        depth_maps.append(depth_map)
        if savePath is not None:
            imageio.imwrite(f'{savePath}/{prtx}{idx:03d}.png', rgb_map)
            # rgb_map = np.concatenate((rgb_map, depth_map), axis=1)
            imageio.imwrite(f'{savePath}/rgbd/{prtx}{idx:03d}.png', depth_map)

    imageio.mimwrite(f'{savePath}/{prtx}video.mp4', np.stack(rgb_maps), fps=30, quality=10)
    imageio.mimwrite(f'{savePath}/{prtx}depthvideo.mp4', np.stack(depth_maps), fps=30, quality=10)

    if PSNRs:
        psnr = np.mean(np.asarray(PSNRs))
        if compute_extra_metrics:
            ssim = np.mean(np.asarray(ssims))
            l_a = np.mean(np.asarray(l_alex))
            l_v = np.mean(np.asarray(l_vgg))
            np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr, ssim, l_a, l_v]))
        else:
            np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr]))


    return PSNRs

@torch.no_grad()
def evaluation_path(test_dataset,tensorf, c2ws, renderer, savePath=None, N_vis=5, prtx='', N_samples=-1,
                    white_bg=False, ndc_ray=False, compute_extra_metrics=True, device='cuda'):
    PSNRs, rgb_maps, depth_maps = [], [], []
    ssims,l_alex,l_vgg=[],[],[]
    os.makedirs(savePath, exist_ok=True)
    os.makedirs(savePath+"/rgbd", exist_ok=True)

    try:
        tqdm._instances.clear()
    except Exception:
        pass

    near_far = test_dataset.near_far
    for idx, c2w in tqdm(enumerate(c2ws)):

        W, H = test_dataset.img_wh

        c2w = torch.FloatTensor(c2w)
        rays_o, rays_d = get_rays(test_dataset.directions, c2w)  # both (h*w, 3)
        if ndc_ray:
            rays_o, rays_d = ndc_rays_blender(H, W, test_dataset.focal[0], 1.0, rays_o, rays_d)
        rays = torch.cat([rays_o, rays_d], 1)  # (h*w, 6)

        rgb_map, _, depth_map, _, _ = renderer(rays, tensorf, chunk=8192, N_samples=N_samples,
                                        ndc_ray=ndc_ray, white_bg = white_bg, device=device)
        rgb_map = rgb_map.clamp(0.0, 1.0)

        rgb_map, depth_map = rgb_map.reshape(H, W, 3).cpu(), depth_map.reshape(H, W).cpu()

        depth_map, _ = visualize_depth_numpy(depth_map.numpy(),near_far)

        rgb_map = (rgb_map.numpy() * 255).astype('uint8')
        # rgb_map = np.concatenate((rgb_map, depth_map), axis=1)
        rgb_maps.append(rgb_map)
        depth_maps.append(depth_map)
        if savePath is not None:
            imageio.imwrite(f'{savePath}/{prtx}{idx:03d}.png', rgb_map)
            rgb_map = np.concatenate((rgb_map, depth_map), axis=1)
            imageio.imwrite(f'{savePath}/rgbd/{prtx}{idx:03d}.png', rgb_map)

    imageio.mimwrite(f'{savePath}/{prtx}video.mp4', np.stack(rgb_maps), fps=30, quality=8)
    imageio.mimwrite(f'{savePath}/{prtx}depthvideo.mp4', np.stack(depth_maps), fps=30, quality=8)

    if PSNRs:
        psnr = np.mean(np.asarray(PSNRs))
        if compute_extra_metrics:
            ssim = np.mean(np.asarray(ssims))
            l_a = np.mean(np.asarray(l_alex))
            l_v = np.mean(np.asarray(l_vgg))
            np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr, ssim, l_a, l_v]))
        else:
            np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr]))


    return PSNRs



================================================
FILE: scripts/test.sh
================================================
CUDA_VISIBLE_DEVICES=$1 python train.py \
--config configs/llff.txt \
--ckpt log/trex/trex.th \ 
--render_only 1 
--render_test 1 

================================================
FILE: scripts/test_feature.sh
================================================
expname=trex
CUDA_VISIBLE_DEVICES=$1 python train_feature.py \
--config configs/llff_feature.txt \
--datadir ./data/nerf_llff_data/trex \
--expname $expname \
--ckpt ./log_feature/$expname/$expname.th \
--render_only 1 \
--render_test 0 \
--render_path 1 \
--chunk_size 1024

================================================
FILE: scripts/test_style.sh
================================================
expname=trex
CUDA_VISIBLE_DEVICES=$1 python train_style.py \
--config configs/llff_style.txt \
--datadir ./data/nerf_llff_data/trex \
--expname $expname \
--ckpt log_style/$expname/$expname.th \
--style_img path/to/reference/style/image \
--render_only 1 \
--render_train 0 \
--render_test 0 \
--render_path 1 \
--chunk_size 1024 \
--rm_weight_mask_thre 0.0001 \

================================================
FILE: scripts/train.sh
================================================
CUDA_VISIBLE_DEVICES=$1 python train.py --config=configs/llff.txt

================================================
FILE: scripts/train_feature.sh
================================================
CUDA_VISIBLE_DEVICES=$1 python train_feature.py --config=configs/llff_feature.txt

================================================
FILE: scripts/train_style.sh
================================================
CUDA_VISIBLE_DEVICES=$1 python train_style.py --config=configs/llff_style.txt

================================================
FILE: train.py
================================================

import os
from tqdm.auto import tqdm
from opt import config_parser



import json, random
from renderer import *
from utils import *
from torch.utils.tensorboard import SummaryWriter
import datetime

from dataLoader import dataset_dict
import sys



device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

renderer = OctreeRender_trilinear_fast


class SimpleSampler:
    def __init__(self, total, batch):
        self.total = total
        self.batch = batch
        self.curr = total
        self.ids = None

    def nextids(self):
        self.curr+=self.batch
        if self.curr + self.batch > self.total:
            self.ids = torch.LongTensor(np.random.permutation(self.total))
            self.curr = 0
        return self.ids[self.curr:self.curr+self.batch]


@torch.no_grad()
def export_mesh(args):

    ckpt = torch.load(args.ckpt, map_location=device)
    kwargs = ckpt['kwargs']
    kwargs.update({'device': device})
    tensorf = eval(args.model_name)(**kwargs)
    tensorf.load(ckpt)

    alpha,_ = tensorf.getDenseAlpha()
    convert_sdf_samples_to_ply(alpha.cpu(), f'{args.ckpt[:-3]}.ply',bbox=tensorf.aabb.cpu(), level=0.005)


@torch.no_grad()
def render_test(args):
    # init dataset
    dataset = dataset_dict[args.dataset_name]
    test_dataset = dataset(args.datadir, split='test', downsample=args.downsample_train, is_stack=True)
    white_bg = test_dataset.white_bg
    ndc_ray = args.ndc_ray

    if not os.path.exists(args.ckpt):
        print('the ckpt path does not exists!!')
        return

    ckpt = torch.load(args.ckpt, map_location=device)
    kwargs = ckpt['kwargs']
    kwargs.update({'device': device})
    tensorf = eval(args.model_name)(**kwargs)
    tensorf.load(ckpt)

    logfolder = os.path.dirname(args.ckpt)
    if args.render_train:
        os.makedirs(f'{logfolder}/imgs_train_all', exist_ok=True)
        train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=True)
        PSNRs_test = evaluation(train_dataset,tensorf, args, renderer, f'{logfolder}/imgs_train_all/',
                                N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)
        print(f'======> {args.expname} train all psnr: {np.mean(PSNRs_test)} <========================')

    if args.render_test:
        os.makedirs(f'{logfolder}/{args.expname}/imgs_test_all', exist_ok=True)
        evaluation(test_dataset,tensorf, args, renderer, f'{logfolder}/{args.expname}/imgs_test_all/',
                                N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)

    if args.render_path:
        c2ws = test_dataset.render_path
        os.makedirs(f'{logfolder}/{args.expname}/imgs_path_all', exist_ok=True)
        evaluation_path(test_dataset,tensorf, c2ws, renderer, f'{logfolder}/{args.expname}/imgs_path_all/',
                                N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)

def reconstruction(args):

    # init dataset
    dataset = dataset_dict[args.dataset_name]
    train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=False)
    test_dataset = dataset(args.datadir, split='test', downsample=args.downsample_train, is_stack=True)
    white_bg = train_dataset.white_bg
    near_far = train_dataset.near_far
    ndc_ray = args.ndc_ray

    # init resolution
    upsamp_list = args.upsamp_list
    update_AlphaMask_list = args.update_AlphaMask_list
    n_lamb_sigma = args.n_lamb_sigma
    n_lamb_sh = args.n_lamb_sh

    
    if args.add_timestamp:
        logfolder = f'{args.basedir}/{args.expname}{datetime.datetime.now().strftime("-%Y%m%d-%H%M%S")}'
    else:
        logfolder = f'{args.basedir}/{args.expname}'
    

    # init log file
    os.makedirs(logfolder, exist_ok=True)
    os.makedirs(f'{logfolder}/imgs_vis', exist_ok=True)
    os.makedirs(f'{logfolder}/imgs_rgba', exist_ok=True)
    os.makedirs(f'{logfolder}/rgba', exist_ok=True)
    summary_writer = SummaryWriter(logfolder)



    # init parameters
    # tensorVM, renderer = init_parameters(args, train_dataset.scene_bbox.to(device), reso_list[0])
    aabb = train_dataset.scene_bbox.to(device)
    reso_cur = N_to_reso(args.N_voxel_init, aabb)
    nSamples = min(args.nSamples, cal_n_samples(reso_cur,args.step_ratio))


    if args.ckpt is not None:
        ckpt = torch.load(args.ckpt, map_location=device)
        kwargs = ckpt['kwargs']
        kwargs.update({'device':device})
        tensorf = eval(args.model_name)(**kwargs)
        tensorf.load(ckpt)
    else:
        tensorf = eval(args.model_name)(aabb, reso_cur, device,
                    density_n_comp=n_lamb_sigma, appearance_n_comp=n_lamb_sh, app_dim=args.data_dim_color, near_far=near_far,
                    shadingMode=args.shadingMode, alphaMask_thres=args.alpha_mask_thre, density_shift=args.density_shift, distance_scale=args.distance_scale,
                    pos_pe=args.pos_pe, view_pe=args.view_pe, fea_pe=args.fea_pe, featureC=args.featureC, step_ratio=args.step_ratio, fea2denseAct=args.fea2denseAct)


    grad_vars = tensorf.get_optparam_groups(args.lr_init, args.lr_basis)
    if args.lr_decay_iters > 0:
        lr_factor = args.lr_decay_target_ratio**(1/args.lr_decay_iters)
    else:
        args.lr_decay_iters = args.n_iters
        lr_factor = args.lr_decay_target_ratio**(1/args.n_iters)

    print("lr decay", args.lr_decay_target_ratio, args.lr_decay_iters)
    
    optimizer = torch.optim.Adam(grad_vars, betas=(0.9,0.99))


    #linear in logrithmic space
    if upsamp_list is not None:
        N_voxel_list = (torch.round(torch.exp(torch.linspace(np.log(args.N_voxel_init), np.log(args.N_voxel_final), len(upsamp_list)+1))).long()).tolist()[1:]


    torch.cuda.empty_cache()
    PSNRs,PSNRs_test = [],[0]

    allrays, allrgbs = train_dataset.all_rays, train_dataset.all_rgbs
    if not args.ndc_ray:
        allrays, allrgbs = tensorf.filtering_rays(allrays, allrgbs, bbox_only=True)
    trainingSampler = SimpleSampler(allrays.shape[0], args.batch_size)

    Ortho_reg_weight = args.Ortho_weight
    print("initial Ortho_reg_weight", Ortho_reg_weight)

    L1_reg_weight = args.L1_weight_inital
    print("initial L1_reg_weight", L1_reg_weight)
    TV_weight_density, TV_weight_app = args.TV_weight_density, args.TV_weight_app
    tvreg = TVLoss()
    print(f"initial TV_weight density: {TV_weight_density} appearance: {TV_weight_app}")


    pbar = tqdm(range(args.n_iters), miniters=args.progress_refresh_rate, file=sys.stdout)
    for iteration in pbar:


        ray_idx = trainingSampler.nextids()
        rays_train, rgb_train = allrays[ray_idx], allrgbs[ray_idx].to(device)

        #rgb_map, alphas_map, depth_map, weights, uncertainty
        rgb_map, alphas_map, depth_map, weights, uncertainty = renderer(rays_train, tensorf, chunk=args.batch_size,
                                N_samples=nSamples, white_bg = white_bg, ndc_ray=ndc_ray, device=device, is_train=True)

        loss = torch.mean((rgb_map - rgb_train) ** 2)


        # loss
        total_loss = loss
        if Ortho_reg_weight > 0:
            loss_reg = tensorf.vector_comp_diffs()
            total_loss += Ortho_reg_weight*loss_reg
            summary_writer.add_scalar('train/reg', loss_reg.detach().item(), global_step=iteration)
        if L1_reg_weight > 0:
            loss_reg_L1 = tensorf.density_L1()
            total_loss += L1_reg_weight*loss_reg_L1
            summary_writer.add_scalar('train/reg_l1', loss_reg_L1.detach().item(), global_step=iteration)

        if TV_weight_density>0:
            TV_weight_density *= lr_factor
            loss_tv = tensorf.TV_loss_density(tvreg) * TV_weight_density
            total_loss = total_loss + loss_tv
            summary_writer.add_scalar('train/reg_tv_density', loss_tv.detach().item(), global_step=iteration)
        if TV_weight_app>0:
            TV_weight_app *= lr_factor
            loss_tv = tensorf.TV_loss_app(tvreg)*TV_weight_app
            total_loss = total_loss + loss_tv
            summary_writer.add_scalar('train/reg_tv_app', loss_tv.detach().item(), global_step=iteration)

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        loss = loss.detach().item()
        
        PSNRs.append(-10.0 * np.log(loss) / np.log(10.0))
        summary_writer.add_scalar('train/PSNR', PSNRs[-1], global_step=iteration)
        summary_writer.add_scalar('train/mse', loss, global_step=iteration)


        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['lr'] * lr_factor

        # Print the current values of the losses.
        if iteration % args.progress_refresh_rate == 0:
            pbar.set_description(
                f'Iteration {iteration:05d}:'
                + f' train_psnr = {float(np.mean(PSNRs)):.2f}'
                + f' test_psnr = {float(np.mean(PSNRs_test)):.2f}'
                + f' mse = {loss:.6f}'
            )
            PSNRs = []


        if iteration % args.vis_every == args.vis_every - 1 and args.N_vis!=0:
            PSNRs_test = evaluation(test_dataset,tensorf, args, renderer, f'{logfolder}/imgs_vis/', N_vis=args.N_vis,
                                    prtx=f'{iteration:06d}_', N_samples=nSamples, white_bg = white_bg, ndc_ray=ndc_ray, compute_extra_metrics=False)
            summary_writer.add_scalar('test/psnr', np.mean(PSNRs_test), global_step=iteration)



        if update_AlphaMask_list is not None and iteration in update_AlphaMask_list:

            if reso_cur[0] * reso_cur[1] * reso_cur[2]<256**3:# update volume resolution
                reso_mask = reso_cur
            new_aabb = tensorf.updateAlphaMask(tuple(reso_mask))
            if iteration == update_AlphaMask_list[0]:
                tensorf.shrink(new_aabb)
                # tensorVM.alphaMask = None
                L1_reg_weight = args.L1_weight_rest
                print("continuing L1_reg_weight", L1_reg_weight)


            if not args.ndc_ray and iteration == update_AlphaMask_list[1]:
                # filter rays outside the bbox
                allrays,allrgbs = tensorf.filtering_rays(allrays,allrgbs)
                trainingSampler = SimpleSampler(allrgbs.shape[0], args.batch_size)


        if upsamp_list is not None and iteration in upsamp_list:
            n_voxels = N_voxel_list.pop(0)
            reso_cur = N_to_reso(n_voxels, tensorf.aabb)
            nSamples = min(args.nSamples, cal_n_samples(reso_cur,args.step_ratio))
            tensorf.upsample_volume_grid(reso_cur)

            if args.lr_upsample_reset:
                print("reset lr to initial")
                lr_scale = 1 #0.1 ** (iteration / args.n_iters)
            else:
                lr_scale = args.lr_decay_target_ratio ** (iteration / args.n_iters)
            grad_vars = tensorf.get_optparam_groups(args.lr_init*lr_scale, args.lr_basis*lr_scale)
            optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))
        

    tensorf.save(f'{logfolder}/{args.expname}.th')


    if args.render_train:
        os.makedirs(f'{logfolder}/imgs_train_all', exist_ok=True)
        train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=True)
        PSNRs_test = evaluation(train_dataset,tensorf, args, renderer, f'{logfolder}/imgs_train_all/',
                                N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)
        print(f'======> {args.expname} test all psnr: {np.mean(PSNRs_test)} <========================')

    if args.render_test:
        os.makedirs(f'{logfolder}/imgs_test_all', exist_ok=True)
        PSNRs_test = evaluation(test_dataset,tensorf, args, renderer, f'{logfolder}/imgs_test_all/',
                                N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)
        summary_writer.add_scalar('test/psnr_all', np.mean(PSNRs_test), global_step=iteration)
        print(f'======> {args.expname} test all psnr: {np.mean(PSNRs_test)} <========================')

    if args.render_path:
        c2ws = test_dataset.render_path
        # c2ws = test_dataset.poses
        print('========>',c2ws.shape)
        os.makedirs(f'{logfolder}/imgs_path_all', exist_ok=True)
        evaluation_path(test_dataset,tensorf, c2ws, renderer, f'{logfolder}/imgs_path_all/',
                                N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)


if __name__ == '__main__':

    torch.set_default_dtype(torch.float32)
    torch.manual_seed(20211202)
    np.random.seed(20211202)

    args = config_parser()
    print(args)

    if  args.export_mesh:
        export_mesh(args)

    if args.render_only and (args.render_test or args.render_path):
        render_test(args)
    else:
        reconstruction(args)



================================================
FILE: train_feature.py
================================================

import os
from unittest.mock import patch
from tqdm.auto import tqdm
from opt import config_parser



import json, random
from renderer import *
from utils import *
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid
import torchvision.transforms.functional as TF
import datetime

from dataLoader import dataset_dict
import sys



device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

renderer = OctreeRender_trilinear_fast


class SimpleSampler:
    def __init__(self, total, batch):
        self.total = total
        self.batch = batch
        self.curr = total
        self.ids = None

    def nextids(self):
        self.curr+=self.batch
        if self.curr + self.batch > self.total:
            self.ids = torch.LongTensor(np.random.permutation(self.total))
            self.curr = 0
        return self.ids[self.curr:self.curr+self.batch]

def InfiniteSampler(n):
    # i = 0
    i = n - 1
    order = np.random.permutation(n)
    while True:
        yield order[i]
        i += 1
        if i >= n:
            np.random.seed()
            order = np.random.permutation(n)
            i = 0

class InfiniteSamplerWrapper(torch.utils.data.sampler.Sampler):
    def __init__(self, num_samples):
        self.num_samples = num_samples

    def __iter__(self):
        return iter(InfiniteSampler(self.num_samples))

    def __len__(self):
        return 2 ** 31

@torch.no_grad()
def render_test(args):
    # init dataset
    dataset = dataset_dict[args.dataset_name]
    test_dataset = dataset(args.datadir, split='test', downsample=args.downsample_train, is_stack=True)
    white_bg = test_dataset.white_bg
    ndc_ray = args.ndc_ray

    if not os.path.exists(args.ckpt):
        print('the ckpt path does not exists!!')
        return

    ckpt = torch.load(args.ckpt, map_location=device)
    kwargs = ckpt['kwargs']
    kwargs.update({'device': device})
    tensorf = eval(args.model_name)(**kwargs)
    tensorf.change_to_feature_mod(args.n_lamb_sh ,device)
    tensorf.load(ckpt)
    tensorf.eval()
    tensorf.rayMarch_weight_thres = args.rm_weight_mask_thre

    logfolder = os.path.dirname(args.ckpt)


    if args.render_train:
        os.makedirs(f'{logfolder}/{args.expname}/imgs_train_all', exist_ok=True)
        train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=True)
        evaluation_feature(train_dataset,tensorf, args, renderer, args.chunk_size, f'{logfolder}/{args.expname}/imgs_train_all/',
                                N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)

    if args.render_test:
        os.makedirs(f'{logfolder}/{args.expname}/imgs_test_all', exist_ok=True)
        evaluation_feature(test_dataset,tensorf, args, renderer, args.chunk_size, f'{logfolder}/{args.expname}/imgs_test_all/',
                                N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)

    if args.render_path:
        c2ws = test_dataset.render_path
        os.makedirs(f'{logfolder}/{args.expname}/imgs_path_all', exist_ok=True)
        evaluation_feature_path(test_dataset,tensorf, c2ws, renderer, args.chunk_size, f'{logfolder}/{args.expname}/imgs_path_all/',
                                N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)

def reconstruction(args):

    # init dataset
    dataset = dataset_dict[args.dataset_name]
    train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=True)
    white_bg = train_dataset.white_bg
    near_far = train_dataset.near_far
    h_rays, w_rays = train_dataset.img_wh[1], train_dataset.img_wh[0]
    ndc_ray = args.ndc_ray

    patch_size = args.patch_size

    
    if args.add_timestamp:
        logfolder = f'{args.basedir}/{args.expname}{datetime.datetime.now().strftime("-%Y%m%d-%H%M%S")}'
    else:
        logfolder = f'{args.basedir}/{args.expname}'
    

    # init log file
    os.makedirs(logfolder, exist_ok=True)
    summary_writer = SummaryWriter(logfolder)



    # init parameters
    # tensorVM, renderer = init_parameters(args, train_dataset.scene_bbox.to(device), reso_list[0])
    aabb = train_dataset.scene_bbox.to(device)
    # TODO: need to update reso_cur
    reso_cur = N_to_reso(args.N_voxel_init, aabb)
    nSamples = min(args.nSamples, cal_n_samples(reso_cur,args.step_ratio))


    assert args.ckpt is not None, 'Have to be pre-trained to get density fielded!'

    ckpt = torch.load(args.ckpt, map_location=device)
    kwargs = ckpt['kwargs']
    kwargs.update({'device':device})
    tensorf = eval(args.model_name)(**kwargs)
    tensorf.load(ckpt)
   
    tensorf.change_to_feature_mod(args.n_lamb_sh ,device)
    tensorf.rayMarch_weight_thres = args.rm_weight_mask_thre

    train_dataset.prepare_feature_data(tensorf.encoder)

    grad_vars = tensorf.get_optparam_groups_feature_mod(args.lr_init, args.lr_basis)
    if args.lr_decay_iters > 0:
        lr_factor = args.lr_decay_target_ratio**(1/args.lr_decay_iters)
    else:
        args.lr_decay_iters = args.n_iters
        lr_factor = args.lr_decay_target_ratio**(1/args.n_iters)

    print("lr decay", args.lr_decay_target_ratio, args.lr_decay_iters)
    
    optimizer = torch.optim.Adam(grad_vars, betas=(0.9,0.99))


    torch.cuda.empty_cache()
    PSNRs = []

    allrays, allfeatures = train_dataset.all_rays, train_dataset.all_features
    allrays_stack, allrgbs_stack = train_dataset.all_rays_stack, train_dataset.all_rgbs_stack
    if not args.ndc_ray:
        allrays, allfeatures = tensorf.filtering_rays(allrays, allfeatures, bbox_only=True)
    trainingSampler = SimpleSampler(allrays.shape[0], args.batch_size)
    frameSampler = iter(InfiniteSamplerWrapper(allrays_stack.size(0))) # every next(sampler) returns a frame index


    TV_weight_feature = args.TV_weight_feature
    tvreg = TVLoss()
    print(f"initial TV_weight_feature: {TV_weight_feature}")


    pbar = tqdm(range(args.n_iters), miniters=args.progress_refresh_rate, file=sys.stdout)
    for iteration in pbar:

        feature_loss, pixel_loss = 0., 0.
        if iteration%2==0:
            ray_idx = trainingSampler.nextids()
            rays_train, features_train = allrays[ray_idx], allfeatures[ray_idx].to(device)

            feature_map, _ = renderer(rays_train, tensorf, chunk=args.chunk_size, N_samples=nSamples, white_bg = white_bg, 
                                ndc_ray=ndc_ray, render_feature=True, device=device, is_train=True)

            feature_loss = torch.mean((feature_map - features_train) ** 2)
        else:
            frame_idx = next(frameSampler)
            start_h = np.random.randint(0, h_rays-patch_size+1)
            start_w = np.random.randint(0, w_rays-patch_size+1)
            if white_bg:
                # move random sampled patches into center
                mid_h, mid_w = (h_rays-patch_size+1)/2, (w_rays-patch_size+1)/2
                if mid_h-start_h>=1:
                    start_h += np.random.randint(0, mid_h-start_h)
                elif mid_h-start_h<=-1:
                    start_h += np.random.randint(mid_h-start_h, 0)
                if mid_w-start_w>=1:
                    start_w += np.random.randint(0, mid_w-start_w)
                elif mid_w-start_w<=-1:
                    start_w += np.random.randint(mid_w-start_w, 0)

            rays_train = allrays_stack[frame_idx, start_h:start_h+patch_size, 
                                                    start_w:start_w+patch_size, :].reshape(-1, 6).to(device)
            # [patch*patch, 6]
            
            rgbs_train = allrgbs_stack[frame_idx, start_h:(start_h+patch_size), 
                                                  start_w:(start_w+patch_size), :].to(device)
            # [patch, patch, 3]

            feature_map, _ = renderer(rays_train, tensorf, chunk=args.chunk_size, N_samples=nSamples, white_bg=white_bg, 
                                ndc_ray=ndc_ray, render_feature=True, device=device, is_train=True)

            feature_map = feature_map.reshape(patch_size, patch_size, 256)[None,...].permute(0,3,1,2)
            recon_rgb = tensorf.decoder(feature_map)

            rgbs_train = rgbs_train[None,...].permute(0,3,1,2)
            img_enc = tensorf.encoder(normalize_vgg(rgbs_train))
            recon_rgb_enc = tensorf.encoder(recon_rgb)
            
            feature_loss =(F.mse_loss(recon_rgb_enc.relu4_1, img_enc.relu4_1) +
                           F.mse_loss(recon_rgb_enc.relu3_1, img_enc.relu3_1)) / 10

            recon_rgb = denormalize_vgg(recon_rgb)

            pixel_loss = torch.mean((recon_rgb - rgbs_train) ** 2)

        total_loss = pixel_loss + feature_loss

        # loss
        # NOTE: Calculate feature TV loss rather than appearence TV loss
        if TV_weight_feature>0:
            TV_weight_feature *= lr_factor
            loss_tv = tensorf.TV_loss_feature(tvreg)*TV_weight_feature
            total_loss = total_loss + loss_tv
            summary_writer.add_scalar('train/reg_tv_feature', loss_tv.detach().item(), global_step=iteration)

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        if pixel_loss == 0:
            feature_loss = feature_loss.detach().item()
            PSNRs.append(-10.0 * np.log(feature_loss) / np.log(10.0))
            summary_writer.add_scalar('train/PSNR_feature', PSNRs[-1], global_step=iteration)
            summary_writer.add_scalar('train/mse_feature', feature_loss, global_step=iteration)
        else:
            pixel_loss = pixel_loss.detach().item()
            PSNRs.append(-10.0 * np.log(pixel_loss) / np.log(10.0))
            summary_writer.add_scalar('train/PSNR_pixel', PSNRs[-1], global_step=iteration)
            summary_writer.add_scalar('train/mse_pixel', pixel_loss, global_step=iteration)
            summary_writer.add_scalar('train/mse_recon_feature', feature_loss.detach().item(), global_step=iteration)


        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['lr'] * lr_factor

        # Print the current values of the losses.
        if iteration % args.progress_refresh_rate == 0:
            pbar.set_description(
                f'Iteration {iteration:05d}:'
                + f' psnr = {float(np.mean(PSNRs)):.2f}'
            )
            PSNRs = []

        if iteration % (args.progress_refresh_rate*20) == 1:
            summary_writer.add_image('output', make_grid([rgbs_train.squeeze(), 
                                                        recon_rgb.clamp(0, 1).squeeze()],  
                                                        nrow=2, padding=0, normalize=False),
                                                        global_step=iteration)
        

    tensorf.save(f'{logfolder}/{args.expname}.th')

if __name__ == '__main__':

    torch.set_default_dtype(torch.float32)
    torch.manual_seed(20211202)
    np.random.seed(20211202)

    args = config_parser()
    print(args)

    if args.render_only:
        render_test(args)
    else:
        reconstruction(args)



================================================
FILE: train_style.py
================================================

import os
from tqdm.auto import tqdm
from opt import config_parser
from PIL import Image, ImageFile
from pathlib import Path
from torchvision.utils import make_grid
import torchvision.transforms.functional as TF

from renderer import *
from utils import *
from torch.utils.tensorboard import SummaryWriter
import datetime

from dataLoader import dataset_dict
from dataLoader.styleLoader import getDataLoader

from models.styleModules import cal_mse_content_loss, cal_adain_style_loss

import sys



device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

renderer = OctreeRender_trilinear_fast
depth_renderer = OctreeRender_trilinear_fast_depth


class SimpleSampler:
    def __init__(self, total, batch):
        self.total = total
        self.batch = batch
        self.curr = total
        self.ids = None

    def nextids(self):
        self.curr+=self.batch
        if self.curr + self.batch > self.total:
            self.ids = torch.LongTensor(np.random.permutation(self.total))
            self.curr = 0
        return self.ids[self.curr:self.curr+self.batch]

def InfiniteSampler(n):
    # i = 0
    i = n - 1
    order = np.random.permutation(n)
    while True:
        yield order[i]
        i += 1
        if i >= n:
            np.random.seed()
            order = np.random.permutation(n)
            i = 0

class InfiniteSamplerWrapper(torch.utils.data.sampler.Sampler):
    def __init__(self, num_samples):
        self.num_samples = num_samples

    def __iter__(self):
        return iter(InfiniteSampler(self.num_samples))

    def __len__(self):
        return 2 ** 31

@torch.no_grad()
def render_test(args):
    # init dataset
    dataset = dataset_dict[args.dataset_name]
    ndc_ray = args.ndc_ray

    if not os.path.exists(args.ckpt):
        print('the ckpt path does not exists!!')
        return

    assert args.style_img is not None, 'Must specify a style image!'

    ckpt = torch.load(args.ckpt, map_location=device)
    kwargs = ckpt['kwargs']
    kwargs.update({'device': device})
    tensorf = eval(args.model_name)(**kwargs)
    tensorf.change_to_feature_mod(args.n_lamb_sh, device)
    tensorf.change_to_style_mod(device)
    tensorf.load(ckpt)
    tensorf.eval()
    tensorf.rayMarch_weight_thres = args.rm_weight_mask_thre

    logfolder = os.path.dirname(args.ckpt)

    trans = T.Compose([T.Resize(size=(256,256)), T.ToTensor()])
    style_img = trans(Image.open(args.style_img)).cuda()[None, ...]
    style_name = Path(args.style_img).stem

    if args.render_train:
        train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=True)
        os.makedirs(f'{logfolder}/{args.expname}/imgs_train_all/{style_name}', exist_ok=True)
        evaluation_feature(train_dataset,tensorf, args, renderer, args.chunk_size, f'{logfolder}/{args.expname}/imgs_train_all/{style_name}',
                                N_vis=-1, N_samples=-1, white_bg = train_dataset.white_bg, ndc_ray=ndc_ray, style_img=style_img, device=device)
    
    if args.render_test:
        test_dataset = dataset(args.datadir, split='test', downsample=args.downsample_train, is_stack=True)
        os.makedirs(f'{logfolder}/{args.expname}/imgs_test_all/{style_name}', exist_ok=True)
        evaluation_feature(test_dataset,tensorf, args, renderer, args.chunk_size, f'{logfolder}/{args.expname}/imgs_test_all/{style_name}',
                                N_vis=-1, N_samples=-1, white_bg = test_dataset.white_bg, ndc_ray=ndc_ray, style_img=style_img, device=device)

    if args.render_path:
        test_dataset = dataset(args.datadir, split='test', downsample=args.downsample_train, is_stack=True)
        c2ws = test_dataset.render_path
        os.makedirs(f'{logfolder}/{args.expname}/imgs_path_all/{style_name}', exist_ok=True)
        evaluation_feature_path(test_dataset, tensorf, c2ws, renderer, args.chunk_size, f'{logfolder}/{args.expname}/imgs_path_all/{style_name}',
                N_vis=-1, N_samples=-1, white_bg = test_dataset.white_bg, ndc_ray=ndc_ray, style_img=style_img, device=device)

def reconstruction(args):

    # init dataset
    dataset = dataset_dict[args.dataset_name]
    train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=True)
    white_bg = train_dataset.white_bg
    near_far = train_dataset.near_far
    h_rays, w_rays = train_dataset.img_wh[1], train_dataset.img_wh[0]
    ndc_ray = args.ndc_ray

    patch_size = args.patch_size # ground truth image patch size when training

    Image.MAX_IMAGE_PIXELS = None  # Disable DecompressionBombError
    ImageFile.LOAD_TRUNCATED_IMAGES = True # Disable OSError: image file is truncated
    style_loader = getDataLoader(args.wikiartdir, batch_size=1, sampler=InfiniteSamplerWrapper, 
                    image_side_length=256, num_workers=2)
    style_iter = iter(style_loader)
    
    if args.add_timestamp:
        logfolder = f'{args.basedir}/{args.expname}{datetime.datetime.now().strftime("-%Y%m%d-%H%M%S")}'
    else:
        logfolder = f'{args.basedir}/{args.expname}'
    

    # init log file
    os.makedirs(logfolder, exist_ok=True)
    summary_writer = SummaryWriter(logfolder)

    # init parameters
    # tensorVM, renderer = init_parameters(args, train_dataset.scene_bbox.to(device), reso_list[0])
    aabb = train_dataset.scene_bbox.to(device)
    # TODO: need to update reso_cur
    reso_cur = N_to_reso(args.N_voxel_init, aabb)
    nSamples = min(args.nSamples, cal_n_samples(reso_cur,args.step_ratio))


    assert args.ckpt is not None, 'Have to be pre-trained to get density fielded!'

    ckpt = torch.load(args.ckpt, map_location=device)
    kwargs = ckpt['kwargs']
    kwargs.update({'device':device})
    tensorf = eval(args.model_name)(**kwargs)
    tensorf.change_to_feature_mod(args.n_lamb_sh, device)
    tensorf.load(ckpt)
    tensorf.change_to_style_mod(device)
    tensorf.rayMarch_weight_thres = args.rm_weight_mask_thre

    tvreg = TVLoss()

    grad_vars = tensorf.get_optparam_groups_style_mod(args.lr_basis, args.lr_finetune)
    if args.lr_decay_iters > 0:
        lr_factor = args.lr_decay_target_ratio**(1/args.lr_decay_iters)
    else:
        args.lr_decay_iters = args.n_iters
        lr_factor = args.lr_decay_target_ratio**(1/args.n_iters)

    print("lr decay", args.lr_decay_target_ratio, args.lr_decay_iters)
    
    tensorf.train()
    optimizer = torch.optim.Adam(grad_vars, betas=(0.9,0.99))

    torch.cuda.empty_cache()

    allrays_stack, allrgbs_stack = train_dataset.all_rays_stack, train_dataset.all_rgbs_stack
    frameSampler = iter(InfiniteSamplerWrapper(allrays_stack.size(0))) # every next(sampler) returns a frame index

    pbar = tqdm(range(args.n_iters), miniters=args.progress_refresh_rate, file=sys.stdout)
    for iteration in pbar:

        # get style_img, this style_img has NOT been normalized according to the pretrained VGGmodel
        style_img = next(style_iter)[0].to(device)

        # randomly sample patch_size*patch_size patch from given frame
        frame_idx = next(frameSampler)
        start_h = np.random.randint(0, h_rays-patch_size+1)
        start_w = np.random.randint(0, w_rays-patch_size+1)
        if white_bg:
            # move random sampled patches into center
            mid_h, mid_w = (h_rays-patch_size+1)/2, (w_rays-patch_size+1)/2
            if mid_h-start_h>=1:
                start_h += np.random.randint(0, mid_h-start_h)
            elif mid_h-start_h<=-1:
                start_h += np.random.randint(mid_h-start_h, 0)
            if mid_w-start_w>=1:
                start_w += np.random.randint(0, mid_w-start_w)
            elif mid_w-start_w<=-1:
                start_w += np.random.randint(mid_w-start_w, 0)

        rays_train = allrays_stack[frame_idx, start_h:start_h+patch_size, start_w:start_w+patch_size, :]\
                            .reshape(-1, 6).to(device)
        # [patch*patch, 6]
        
        rgbs_train = allrgbs_stack[frame_idx, start_h:(start_h+patch_size), 
                                            start_w:(start_w+patch_size), :].to(device)
        # [patch, patch, 3]

        feature_map, acc_map, style_feature = renderer(rays_train, tensorf, chunk=args.chunk_size, N_samples=nSamples, white_bg = white_bg, 
                                ndc_ray=ndc_ray, render_feature=True, style_img=style_img, device=device, is_train=True)

        feature_map = feature_map.reshape(patch_size, patch_size, 256)[None,...].permute(0,
Download .txt
gitextract_j81cq_2h/

├── README.md
├── configs/
│   ├── llff.txt
│   ├── llff_feature.txt
│   ├── llff_style.txt
│   ├── nerf_synthetic.txt
│   ├── nerf_synthetic_feature.txt
│   └── nerf_synthetic_style.txt
├── dataLoader/
│   ├── __init__.py
│   ├── blender.py
│   ├── colmap2nerf.py
│   ├── llff.py
│   ├── nsvf.py
│   ├── ray_utils.py
│   ├── styleLoader.py
│   ├── tankstemple.py
│   └── your_own_data.py
├── extra/
│   ├── auto_run_paramsets.py
│   └── compute_metrics.py
├── models/
│   ├── VGG.py
│   ├── __init__.py
│   ├── sh.py
│   ├── styleModules.py
│   ├── tensoRF.py
│   └── tensorBase.py
├── opt.py
├── renderer.py
├── scripts/
│   ├── test.sh
│   ├── test_feature.sh
│   ├── test_style.sh
│   ├── train.sh
│   ├── train_feature.sh
│   └── train_style.sh
├── train.py
├── train_feature.py
├── train_style.py
└── utils.py
Download .txt
SYMBOL INDEX (246 symbols across 21 files)

FILE: dataLoader/blender.py
  class BlenderDataset (line 13) | class BlenderDataset(Dataset):
    method __init__ (line 14) | def __init__(self, datadir, split='train', downsample=1.0, is_stack=Fa...
    method read_depth (line 35) | def read_depth(self, filename):
    method read_meta (line 39) | def read_meta(self):
    method prepare_feature_data (line 103) | def prepare_feature_data(self, encoder, chunk=8):
    method define_transforms (line 125) | def define_transforms(self):
    method define_proj_mat (line 128) | def define_proj_mat(self):
    method world2ndc (line 131) | def world2ndc(self,points,lindisp=None):
    method __len__ (line 135) | def __len__(self):
    method __getitem__ (line 138) | def __getitem__(self, idx):

FILE: dataLoader/colmap2nerf.py
  function parse_args (line 23) | def parse_args():
  function do_system (line 40) | def do_system(arg):
  function run_ffmpeg (line 47) | def run_ffmpeg(args):
  function run_colmap (line 69) | def run_colmap(args):
  function variance_of_laplacian (line 99) | def variance_of_laplacian(image):
  function sharpness (line 102) | def sharpness(imagePath):
  function qvec2rotmat (line 108) | def qvec2rotmat(qvec):
  function rotmat (line 125) | def rotmat(a, b):
  function closest_point_2_lines (line 133) | def closest_point_2_lines(oa, da, ob, db): # returns point closest to bo...

FILE: dataLoader/llff.py
  function normalize (line 12) | def normalize(v):
  function average_poses (line 17) | def average_poses(poses):
  function center_poses (line 54) | def center_poses(poses, blender2opencv):
  function viewmatrix (line 81) | def viewmatrix(z, up, pos):
  function render_path_spiral (line 91) | def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, N_rots=2, N=...
  function get_spiral (line 102) | def get_spiral(c2ws_all, near_fars, rads_scale=1.0, N_views=120):
  function get_interpolation_path (line 121) | def get_interpolation_path(c2ws_all, steps=30):
  class LLFFDataset (line 147) | class LLFFDataset(Dataset):
    method __init__ (line 148) | def __init__(self, datadir, split='train', downsample=4, is_stack=Fals...
    method read_meta (line 168) | def read_meta(self):
    method prepare_feature_data (line 257) | def prepare_feature_data(self, encoder, chunk=8):
    method define_transforms (line 279) | def define_transforms(self):
    method __len__ (line 282) | def __len__(self):
    method __getitem__ (line 285) | def __getitem__(self, idx):

FILE: dataLoader/nsvf.py
  function pose_spherical (line 29) | def pose_spherical(theta, phi, radius):
  class NSVF (line 36) | class NSVF(Dataset):
    method __init__ (line 38) | def __init__(self, datadir, split='train', downsample=1.0, wh=[800,800...
    method bbox2corners (line 56) | def bbox2corners(self):
    method read_meta (line 63) | def read_meta(self):
    method define_transforms (line 132) | def define_transforms(self):
    method define_proj_mat (line 135) | def define_proj_mat(self):
    method world2ndc (line 138) | def world2ndc(self, points):
    method __len__ (line 142) | def __len__(self):
    method __getitem__ (line 147) | def __getitem__(self, idx):

FILE: dataLoader/ray_utils.py
  function depth2dist (line 9) | def depth2dist(z_vals, cos_angle):
  function ndc2dist (line 18) | def ndc2dist(ndc_pts, cos_angle):
  function get_ray_directions (line 24) | def get_ray_directions(H, W, focal, center=None):
  function get_ray_directions_blender (line 45) | def get_ray_directions_blender(H, W, focal, center=None):
  function get_rays (line 66) | def get_rays(directions, c2w):
  function ndc_rays_blender (line 90) | def ndc_rays_blender(H, W, focal, near, rays_o, rays_d):
  function ndc_rays (line 109) | def ndc_rays(H, W, focal, near, rays_o, rays_d):
  function sample_pdf (line 129) | def sample_pdf(bins, weights, N_samples, det=False, pytest=False):
  function dda (line 174) | def dda(rays_o, rays_d, bbox_3D):
  function ray_marcher (line 184) | def ray_marcher(rays,
  function read_pfm (line 231) | def read_pfm(filename):
  function ndc_bbox (line 269) | def ndc_bbox(all_rays):
  function denormalize_vgg (line 281) | def denormalize_vgg(img):

FILE: dataLoader/styleLoader.py
  function getDataLoader (line 6) | def getDataLoader(dataset_path, batch_size, sampler, image_side_length=2...

FILE: dataLoader/tankstemple.py
  function circle (line 12) | def circle(radius=3.5, h=0.0, axis='z', t0=0, r=1):
  function cross (line 21) | def cross(x, y, axis=0):
  function normalize (line 26) | def normalize(x, axis=-1, order=2):
  function cat (line 38) | def cat(x, axis=1):
  function look_at_rotation (line 44) | def look_at_rotation(camera_position, at=None, up=None, inverse=False, c...
  function gen_path (line 77) | def gen_path(pos_gen, at=(0, 0, 0), up=(0, -1, 0), frames=180):
  class TanksTempleDataset (line 87) | class TanksTempleDataset(Dataset):
    method __init__ (line 89) | def __init__(self, datadir, split='train', downsample=4.0, wh=[1920,10...
    method bbox2corners (line 108) | def bbox2corners(self):
    method read_meta (line 115) | def read_meta(self):
    method prepare_feature_data (line 196) | def prepare_feature_data(self, encoder, chunk=4):
    method define_transforms (line 219) | def define_transforms(self):
    method define_proj_mat (line 222) | def define_proj_mat(self):
    method world2ndc (line 225) | def world2ndc(self, points):
    method __len__ (line 229) | def __len__(self):
    method __getitem__ (line 234) | def __getitem__(self, idx):

FILE: dataLoader/your_own_data.py
  class YourOwnDataset (line 13) | class YourOwnDataset(Dataset):
    method __init__ (line 14) | def __init__(self, datadir, split='train', downsample=1.0, is_stack=Fa...
    method read_depth (line 35) | def read_depth(self, filename):
    method read_meta (line 39) | def read_meta(self):
    method define_transforms (line 102) | def define_transforms(self):
    method define_proj_mat (line 105) | def define_proj_mat(self):
    method world2ndc (line 108) | def world2ndc(self,points,lindisp=None):
    method __len__ (line 112) | def __len__(self):
    method __getitem__ (line 115) | def __getitem__(self, idx):

FILE: extra/auto_run_paramsets.py
  function getFolderLocker (line 7) | def getFolderLocker(logFolder):
  function releaseFolderLocker (line 15) | def releaseFolderLocker(logFolder):
  function getStopFolder (line 18) | def getStopFolder(logFolder):
  function get_param_str (line 22) | def get_param_str(key, val):
  function get_param_list (line 28) | def get_param_list(param_dict):
  function run_program (line 167) | def run_program(gpu, expname, param):

FILE: extra/compute_metrics.py
  function init_lpips (line 11) | def init_lpips(net_name, device):
  function rgb_lpips (line 17) | def rgb_lpips(np_gt, np_im, net_name, device):
  function findItem (line 25) | def findItem(items, target):
  function rgb_ssim (line 34) | def rgb_ssim(img0, img1, max_val,

FILE: models/VGG.py
  class Encoder (line 8) | class Encoder(nn.Module):
    method __init__ (line 9) | def __init__(self):
    method forward (line 23) | def forward(self, x):
  class Decoder (line 34) | class Decoder(nn.Module):
    method __init__ (line 38) | def __init__(self, ckpt_path=None):
    method forward (line 67) | def forward(self, x):
  class DownBlock (line 74) | class DownBlock(nn.Module):
    method __init__ (line 76) | def __init__(self, in_dim, out_dim, down='conv'):
    method forward (line 93) | def forward(self, x):
  class UpBlock (line 98) | class UpBlock(nn.Module):
    method __init__ (line 100) | def __init__(self, in_dim, out_dim, skip_dim=None, up='nearest'):
    method _pad (line 125) | def _pad(self, x, y):
    method forward (line 142) | def forward(self, x, skip=None):
  class UNetDecoder (line 150) | class UNetDecoder(nn.Module):
    method __init__ (line 152) | def __init__(self, in_dim=256):
    method forward (line 184) | def forward(self, feats):
  class PlainDecoder (line 197) | class PlainDecoder(nn.Module):
    method __init__ (line 198) | def __init__(self) -> None:
    method forward (line 219) | def forward(self, x):

FILE: models/sh.py
  function eval_sh (line 34) | def eval_sh(deg, sh, dirs):
  function eval_sh_bases (line 87) | def eval_sh_bases(deg, dirs):

FILE: models/styleModules.py
  function calc_mean_std (line 6) | def calc_mean_std(x, eps=1e-8):
  function cal_adain_style_loss (line 16) | def cal_adain_style_loss(x, y):
  function cal_mse_content_loss (line 29) | def cal_mse_content_loss(x, y):
  class LearnableIN (line 34) | class LearnableIN(nn.Module):
    method __init__ (line 38) | def __init__(self, dim=256):
    method forward (line 42) | def forward(self, x):
  class SimpleLinearStylizer (line 47) | class SimpleLinearStylizer(nn.Module):
    method __init__ (line 48) | def __init__(self, input_dim=256, embed_dim=32, n_layers=3 ) -> None:
    method _vectorized_covariance (line 76) | def _vectorized_covariance(self, x):
    method get_content_matrix (line 81) | def get_content_matrix(self, c):
    method get_style_mean_std_matrix (line 98) | def get_style_mean_std_matrix(self, s):
    method transform_content_3D (line 117) | def transform_content_3D(self, c):
    method transfer_style_2D (line 131) | def transfer_style_2D(self, s_mean_std_mat, c, acc_map):
  class AdaAttN (line 151) | class AdaAttN(nn.Module):
    method __init__ (line 154) | def __init__(self, qk_dim, v_dim):
    method forward (line 166) | def forward(self, q, k):
  class AdaAttN_new_IN (line 199) | class AdaAttN_new_IN(nn.Module):
    method __init__ (line 202) | def __init__(self, qk_dim, v_dim):
    method forward (line 215) | def forward(self, q, k):
  class AdaAttN_woin (line 248) | class AdaAttN_woin(nn.Module):
    method __init__ (line 251) | def __init__(self, qk_dim, v_dim):
    method forward (line 263) | def forward(self, q, k):

FILE: models/tensoRF.py
  class TensorVMSplit (line 5) | class TensorVMSplit(TensorBase):
    method __init__ (line 6) | def __init__(self, aabb, gridSize, device, **kargs):
    method change_to_feature_mod (line 11) | def change_to_feature_mod(self, feature_n_comp, device):
    method change_to_style_mod (line 29) | def change_to_style_mod(self, device='cuda'):
    method init_svd_volume (line 45) | def init_svd_volume(self, res, device):
    method init_feature_svd (line 50) | def init_feature_svd(self, device):
    method init_one_svd (line 54) | def init_one_svd(self, n_component, gridSize, scale, device):
    method get_optparam_groups (line 68) | def get_optparam_groups(self, lr_init_spatialxyz = 0.02, lr_init_netwo...
    method get_optparam_groups_feature_mod (line 76) | def get_optparam_groups_feature_mod(self, lr_init_spatialxyz, lr_init_...
    method get_optparam_groups_style_mod (line 83) | def get_optparam_groups_style_mod(self, lr_init_network, lr_finetune):
    method vectorDiffs (line 92) | def vectorDiffs(self, vector_comps):
    method vector_comp_diffs (line 103) | def vector_comp_diffs(self):
    method density_L1 (line 106) | def density_L1(self):
    method TV_loss_density (line 112) | def TV_loss_density(self, reg):
    method TV_loss_app (line 118) | def TV_loss_app(self, reg):
    method TV_loss_feature (line 124) | def TV_loss_feature(self, reg):
    method compute_densityfeature (line 132) | def compute_densityfeature(self, xyz_sampled):
    method compute_appfeature (line 149) | def compute_appfeature(self, xyz_sampled):
    method compute_feature (line 167) | def compute_feature(self, xyz_sampled):
    method render_feature_map (line 187) | def render_feature_map(self, rays_chunk, s_mean_std_mat=None, is_train...
    method render_depth_map (line 247) | def render_depth_map(self, rays_chunk, is_train=False, ndc_ray=False, ...
    method up_sampling_VM (line 287) | def up_sampling_VM(self, plane_coef, line_coef, res_target):
    method upsample_volume_grid (line 302) | def upsample_volume_grid(self, res_target):
    method shrink (line 310) | def shrink(self, new_aabb):

FILE: models/tensorBase.py
  function positional_encoding (line 9) | def positional_encoding(positions, freqs):
  function raw2alpha (line 17) | def raw2alpha(sigma, dist):
  function SHRender (line 27) | def SHRender(xyz_sampled, viewdirs, features):
  function RGBRender (line 34) | def RGBRender(xyz_sampled, viewdirs, features):
  class AlphaGridMask (line 39) | class AlphaGridMask(torch.nn.Module):
    method __init__ (line 40) | def __init__(self, device, aabb, alpha_volume):
    method sample_alpha (line 50) | def sample_alpha(self, xyz_sampled):
    method normalize_coord (line 56) | def normalize_coord(self, xyz_sampled):
  class MLPRender_Fea (line 60) | class MLPRender_Fea(torch.nn.Module):
    method __init__ (line 61) | def __init__(self,inChanel, viewpe=6, feape=6, featureC=128):
    method forward (line 74) | def forward(self, pts, viewdirs, features):
  class MLPRender_PE (line 86) | class MLPRender_PE(torch.nn.Module):
    method __init__ (line 87) | def __init__(self,inChanel, viewpe=6, pospe=6, featureC=128):
    method forward (line 100) | def forward(self, pts, viewdirs, features):
  class MLPRender (line 112) | class MLPRender(torch.nn.Module):
    method __init__ (line 113) | def __init__(self,inChanel, viewpe=6, featureC=128):
    method forward (line 126) | def forward(self, pts, viewdirs, features):
  class TensorBase (line 138) | class TensorBase(torch.nn.Module):
    method __init__ (line 139) | def __init__(self, aabb, gridSize, device, density_n_comp = 8, appeara...
    method init_render_func (line 175) | def init_render_func(self, shadingMode, pos_pe, view_pe, fea_pe, featu...
    method update_stepSize (line 192) | def update_stepSize(self, gridSize):
    method init_svd_volume (line 205) | def init_svd_volume(self, res, device):
    method compute_features (line 208) | def compute_features(self, xyz_sampled):
    method compute_densityfeature (line 211) | def compute_densityfeature(self, xyz_sampled):
    method compute_appfeature (line 214) | def compute_appfeature(self, xyz_sampled):
    method normalize_coord (line 217) | def normalize_coord(self, xyz_sampled):
    method get_optparam_groups (line 220) | def get_optparam_groups(self, lr_init_spatial = 0.02, lr_init_network ...
    method get_kwargs (line 223) | def get_kwargs(self):
    method save (line 247) | def save(self, path):
    method load (line 257) | def load(self, ckpt):
    method sample_ray_ndc (line 265) | def sample_ray_ndc(self, rays_o, rays_d, is_train=True, N_samples=-1):
    method sample_ray (line 276) | def sample_ray(self, rays_o, rays_d, is_train=True, N_samples=-1):
    method shrink (line 298) | def shrink(self, new_aabb, voxel_size):
    method getDenseAlpha (line 302) | def getDenseAlpha(self,gridSize=None):
    method updateAlphaMask (line 320) | def updateAlphaMask(self, gridSize=(200,200,200)):
    method filtering_rays (line 346) | def filtering_rays(self, all_rays, all_rgbs, N_samples=256, chunk=1024...
    method feature2density (line 378) | def feature2density(self, density_features):
    method compute_alpha (line 385) | def compute_alpha(self, xyz_locs, length=1):
    method forward (line 408) | def forward(self, rays_chunk, white_bg=True, is_train=False, ndc_ray=F...

FILE: opt.py
  function config_parser (line 3) | def config_parser(cmd=None):

FILE: renderer.py
  function OctreeRender_trilinear_fast (line 9) | def OctreeRender_trilinear_fast(rays, tensorf, chunk=4096, N_samples=-1,...
  function OctreeRender_trilinear_fast_depth (line 40) | def OctreeRender_trilinear_fast_depth(rays, tensorf, chunk=4096, N_sampl...
  function evaluation_feature (line 53) | def evaluation_feature(test_dataset, tensorf, args, renderer, chunk_size...
  function evaluation_feature_path (line 139) | def evaluation_feature_path(test_dataset, tensorf, c2ws, renderer, chunk...
  function evaluation (line 203) | def evaluation(test_dataset,tensorf, args, renderer, savePath=None, N_vi...
  function evaluation_path (line 269) | def evaluation_path(test_dataset,tensorf, c2ws, renderer, savePath=None,...

FILE: train.py
  class SimpleSampler (line 24) | class SimpleSampler:
    method __init__ (line 25) | def __init__(self, total, batch):
    method nextids (line 31) | def nextids(self):
  function export_mesh (line 40) | def export_mesh(args):
  function render_test (line 53) | def render_test(args):
  function reconstruction (line 89) | def reconstruction(args):

FILE: train_feature.py
  class SimpleSampler (line 27) | class SimpleSampler:
    method __init__ (line 28) | def __init__(self, total, batch):
    method nextids (line 34) | def nextids(self):
  function InfiniteSampler (line 41) | def InfiniteSampler(n):
  class InfiniteSamplerWrapper (line 53) | class InfiniteSamplerWrapper(torch.utils.data.sampler.Sampler):
    method __init__ (line 54) | def __init__(self, num_samples):
    method __iter__ (line 57) | def __iter__(self):
    method __len__ (line 60) | def __len__(self):
  function render_test (line 64) | def render_test(args):
  function reconstruction (line 104) | def reconstruction(args):

FILE: train_style.py
  class SimpleSampler (line 30) | class SimpleSampler:
    method __init__ (line 31) | def __init__(self, total, batch):
    method nextids (line 37) | def nextids(self):
  function InfiniteSampler (line 44) | def InfiniteSampler(n):
  class InfiniteSamplerWrapper (line 56) | class InfiniteSamplerWrapper(torch.utils.data.sampler.Sampler):
    method __init__ (line 57) | def __init__(self, num_samples):
    method __iter__ (line 60) | def __iter__(self):
    method __len__ (line 63) | def __len__(self):
  function render_test (line 67) | def render_test(args):
  function reconstruction (line 113) | def reconstruction(args):

FILE: utils.py
  function visualize_depth_numpy (line 11) | def visualize_depth_numpy(depth, minmax=None, cmap=cv2.COLORMAP_JET):
  function init_log (line 28) | def init_log(log, keys):
  function visualize_depth (line 33) | def visualize_depth(depth, minmax=None, cmap=cv2.COLORMAP_JET):
  function N_to_reso (line 53) | def N_to_reso(n_voxels, bbox):
  function cal_n_samples (line 59) | def cal_n_samples(reso, step_ratio=0.5):
  function init_lpips (line 66) | def init_lpips(net_name, device):
  function rgb_lpips (line 72) | def rgb_lpips(np_gt, np_im, net_name, device):
  function findItem (line 80) | def findItem(items, target):
  function rgb_ssim (line 89) | def rgb_ssim(img0, img1, max_val,
  class TVLoss (line 139) | class TVLoss(nn.Module):
    method __init__ (line 140) | def __init__(self):
    method forward (line 143) | def forward(self,x):
    method _tensor_size (line 164) | def _tensor_size(self,t):
  function convert_sdf_samples_to_ply (line 171) | def convert_sdf_samples_to_ply(
  function get_checkerboard (line 355) | def get_checkerboard(fg, n=8):
Condensed preview — 36 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (219K chars).
[
  {
    "path": "README.md",
    "chars": 4457,
    "preview": "# [*CVPR 2023*] StyleRF: Zero-shot 3D Style Transfer of Neural Radiance Fields\n## [Project page](https://kunhao-liu.gith"
  },
  {
    "path": "configs/llff.txt",
    "chars": 544,
    "preview": "\ndataset_name = llff\ndatadir = ./data/nerf_llff_data/trex\nexpname = trex\nbasedir = ./log\n\ndownsample_train = 4.0\nndc_ray"
  },
  {
    "path": "configs/llff_feature.txt",
    "chars": 457,
    "preview": "dataset_name = llff\ndatadir = ./data/nerf_llff_data/trex\nckpt = ./log/trex/trex.th\nexpname = trex\nbasedir = ./log_featur"
  },
  {
    "path": "configs/llff_style.txt",
    "chars": 480,
    "preview": "dataset_name = llff\ndatadir = ./data/nerf_llff_data/trex\nckpt = ./log_feature/trex/trex.th\nexpname = trex\nbasedir = ./lo"
  },
  {
    "path": "configs/nerf_synthetic.txt",
    "chars": 541,
    "preview": "\ndataset_name = blender\ndatadir = ./data/nerf_synthetic/lego\nexpname =  lego\nbasedir = ./log\n\nn_iters = 30000\nbatch_size"
  },
  {
    "path": "configs/nerf_synthetic_feature.txt",
    "chars": 466,
    "preview": "dataset_name = blender\ndatadir = ./data/nerf_synthetic/lego\nckpt = ./log/lego/lego.th\nexpname = lego\nbasedir = ./log_fea"
  },
  {
    "path": "configs/nerf_synthetic_style.txt",
    "chars": 391,
    "preview": "dataset_name = blender\ndatadir = ./data/nerf_synthetic/lego\nckpt = ./log_feature/lego/lego.th\nexpname = lego\nbasedir = ."
  },
  {
    "path": "dataLoader/__init__.py",
    "chars": 375,
    "preview": "from .llff import LLFFDataset\nfrom .blender import BlenderDataset\nfrom .nsvf import NSVF\nfrom .tankstemple import TanksT"
  },
  {
    "path": "dataLoader/blender.py",
    "chars": 6247,
    "preview": "import torch,cv2\nfrom torch.utils.data import Dataset\nimport json\nfrom tqdm import tqdm\nimport os\nfrom PIL import Image\n"
  },
  {
    "path": "dataLoader/colmap2nerf.py",
    "chars": 10564,
    "preview": "#!/usr/bin/env python3\n\n# Copyright (c) 2020-2022, NVIDIA CORPORATION.  All rights reserved.\n#\n# NVIDIA CORPORATION and "
  },
  {
    "path": "dataLoader/llff.py",
    "chars": 11105,
    "preview": "import torch\nfrom torch.utils.data import Dataset\nimport glob\nimport numpy as np\nimport os\nfrom PIL import Image\nfrom to"
  },
  {
    "path": "dataLoader/nsvf.py",
    "chars": 6560,
    "preview": "import torch\nfrom torch.utils.data import Dataset\nfrom tqdm import tqdm\nimport os\nfrom PIL import Image\nfrom torchvision"
  },
  {
    "path": "dataLoader/ray_utils.py",
    "chars": 10543,
    "preview": "import torch, re\nimport numpy as np\nfrom torch import searchsorted\nfrom kornia import create_meshgrid\n\n\n# from utils imp"
  },
  {
    "path": "dataLoader/styleLoader.py",
    "chars": 626,
    "preview": "from torch.utils.data import DataLoader\nfrom torchvision import datasets\nimport torchvision.transforms as T\n\n\ndef getDat"
  },
  {
    "path": "dataLoader/tankstemple.py",
    "chars": 9953,
    "preview": "import torch\nfrom torch.utils.data import Dataset\nfrom tqdm import tqdm\nimport os\nfrom PIL import Image\nfrom torchvision"
  },
  {
    "path": "dataLoader/your_own_data.py",
    "chars": 5026,
    "preview": "import torch,cv2\nfrom torch.utils.data import Dataset\nimport json\nfrom tqdm import tqdm\nimport os\nfrom PIL import Image\n"
  },
  {
    "path": "extra/auto_run_paramsets.py",
    "chars": 8274,
    "preview": "import os\nimport threading, queue\nimport numpy as np\nimport time\n\n\ndef getFolderLocker(logFolder):\n    while True:\n     "
  },
  {
    "path": "extra/compute_metrics.py",
    "chars": 6532,
    "preview": "import os, math\nimport numpy as np\nimport scipy.signal\nfrom typing import List, Optional\nfrom PIL import Image\nimport os"
  },
  {
    "path": "models/VGG.py",
    "chars": 6818,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom collections import namedtuple\nimport torchvision"
  },
  {
    "path": "models/__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "models/sh.py",
    "chars": 5231,
    "preview": "import torch\n\n################## sh function ##################\nC0 = 0.28209479177387814\nC1 = 0.4886025119029199\nC2 = [\n"
  },
  {
    "path": "models/styleModules.py",
    "chars": 10046,
    "preview": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport random\n\ndef calc_mean_std(x, eps=1e-8):\n      "
  },
  {
    "path": "models/tensoRF.py",
    "chars": 16669,
    "preview": "from .tensorBase import *\nfrom .VGG import Encoder, Decoder, UNetDecoder, PlainDecoder\nfrom .styleModules import Learnab"
  },
  {
    "path": "models/tensorBase.py",
    "chars": 18150,
    "preview": "import torch\nimport torch.nn\nimport torch.nn.functional as F\nfrom .sh import eval_sh_bases\nimport numpy as np\nimport tim"
  },
  {
    "path": "opt.py",
    "chars": 7633,
    "preview": "import configargparse\n\ndef config_parser(cmd=None):\n    parser = configargparse.ArgumentParser()\n    parser.add_argument"
  },
  {
    "path": "renderer.py",
    "chars": 14465,
    "preview": "import torch,os,imageio,sys\nfrom tqdm.auto import tqdm\nfrom dataLoader.ray_utils import get_rays\nfrom models.tensoRF imp"
  },
  {
    "path": "scripts/test.sh",
    "chars": 130,
    "preview": "CUDA_VISIBLE_DEVICES=$1 python train.py \\\n--config configs/llff.txt \\\n--ckpt log/trex/trex.th \\ \n--render_only 1 \n--rend"
  },
  {
    "path": "scripts/test_feature.sh",
    "chars": 274,
    "preview": "expname=trex\nCUDA_VISIBLE_DEVICES=$1 python train_feature.py \\\n--config configs/llff_feature.txt \\\n--datadir ./data/nerf"
  },
  {
    "path": "scripts/test_style.sh",
    "chars": 362,
    "preview": "expname=trex\nCUDA_VISIBLE_DEVICES=$1 python train_style.py \\\n--config configs/llff_style.txt \\\n--datadir ./data/nerf_llf"
  },
  {
    "path": "scripts/train.sh",
    "chars": 65,
    "preview": "CUDA_VISIBLE_DEVICES=$1 python train.py --config=configs/llff.txt"
  },
  {
    "path": "scripts/train_feature.sh",
    "chars": 81,
    "preview": "CUDA_VISIBLE_DEVICES=$1 python train_feature.py --config=configs/llff_feature.txt"
  },
  {
    "path": "scripts/train_style.sh",
    "chars": 77,
    "preview": "CUDA_VISIBLE_DEVICES=$1 python train_style.py --config=configs/llff_style.txt"
  },
  {
    "path": "train.py",
    "chars": 12844,
    "preview": "\nimport os\nfrom tqdm.auto import tqdm\nfrom opt import config_parser\n\n\n\nimport json, random\nfrom renderer import *\nfrom u"
  },
  {
    "path": "train_feature.py",
    "chars": 11107,
    "preview": "\nimport os\nfrom unittest.mock import patch\nfrom tqdm.auto import tqdm\nfrom opt import config_parser\n\n\n\nimport json, rand"
  },
  {
    "path": "train_style.py",
    "chars": 12355,
    "preview": "\nimport os\nfrom tqdm.auto import tqdm\nfrom opt import config_parser\nfrom PIL import Image, ImageFile\nfrom pathlib import"
  },
  {
    "path": "utils.py",
    "chars": 11395,
    "preview": "import cv2,torch\nimport numpy as np\nfrom PIL import Image\nimport torchvision.transforms as T\nimport torch.nn.functional "
  }
]

About this extraction

This page contains the full source code of the Kunhao-Liu/StyleRF GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 36 files (205.9 KB), approximately 60.2k tokens, and a symbol index with 246 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!