[
  {
    "path": "README.md",
    "content": "# [*CVPR 2023*] StyleRF: Zero-shot 3D Style Transfer of Neural Radiance Fields\n## [Project page](https://kunhao-liu.github.io/StyleRF/) |  [Paper](https://arxiv.org/abs/2303.10598)\n\nThis repository contains a pytorch implementation for the paper: [StyleRF: Zero-shot 3D Style Transfer of Neural Radiance Fields](https://arxiv.org/abs/2303.10598). StyleRF is an innovative 3D style transfer technique that achieves superior 3D stylization quality with precise geometry reconstruction and it can generalize to various new styles in a zero-shot manner. \n\n![teaser](https://kunhao-liu.github.io/StyleRF/resources/teaser.png)\n\n---\n## Installation\n> Tested on Ubuntu 20.04 + Pytorch 1.12.1\n\nInstall environment:\n```\nconda create -n StyleRF python=3.9\nconda activate StyleRF\npip install torch torchvision\npip install tqdm scikit-image opencv-python configargparse lpips imageio-ffmpeg kornia lpips tensorboard\n```\n\n## Datasets\nPlease put the datasets in `./data`. You can put the datasets elsewhere if you modify the corresponding paths in the configs.\n\n### 3D scene datasets\n* [nerf_synthetic](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1) \n* [llff](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1)\n### Style image dataset\n* [WikiArt](https://www.kaggle.com/datasets/ipythonx/wikiart-gangogh-creating-art-gan)\n\n## Quick Start\nWe provide some trained checkpoints in: [StyleRF checkpoints](https://drive.google.com/drive/folders/1nF9-6lTIhktG5JjNvnmdYOo1LTvtK7Dw?usp=share_link)\n\nThen modify the following attributes in `scripts/test_style.sh`:\n* `--config`: choose `configs/llff_style.txt` or `configs/nerf_synthetic_style.txt` according to which type of dataset is being used\n* `--datadir`: dataset's path\n* `--ckpt`: checkpoint's path\n* `--style_img`: reference style image's path\n\n\nTo generate stylized novel views:\n```\nbash scripts/test_style.sh [GPU ID]\n```\nThe rendered stylized images can 
then be found in the directory under the checkpoint's path.\n\n## Training\n> Current settings in `configs` are tested on one NVIDIA RTX A5000 Graphics Card with 24 GB of memory. To reduce memory consumption, you can set `batch_size`, `chunk_size` or `patch_size` to a smaller number.\n\nTraining proceeds in three steps:\n### 1. Train original TensoRF\nThis step reconstructs the density field, which contains more precise geometry details compared to mesh-based methods. You can skip this step by directly downloading the pre-trained checkpoints from [TensoRF checkpoints](https://1drv.ms/u/s!Ard0t_p4QWIMgQ2qSEAs7MUk8hVw?e=dc6hBm).\n\nThe configs are stored in `configs/llff.txt` and `configs/nerf_synthetic.txt`. For the details of the settings, please also refer to [TensoRF](https://github.com/apchenstu/TensoRF). The checkpoints are stored in `./log` by default.\n\nYou can train the original TensoRF by:\n```\nbash scripts/train.sh [GPU ID]\n```\n\n### 2. Feature grid training stage\nThis step reconstructs the 3D grid containing the VGG features.\n\nThe configs are stored in `configs/llff_feature.txt` and `configs/nerf_synthetic_feature.txt`, in which `ckpt` specifies the checkpoints trained in the **first** step. The checkpoints are stored in `./log_feature` by default.\n\nThen run:\n```\nbash scripts/train_feature.sh [GPU ID]\n```\n\n\n### 3. Stylization training stage \nThis step trains the style transfer modules.\n\nThe configs are stored in `configs/llff_style.txt` and `configs/nerf_synthetic_style.txt`, in which `ckpt` specifies the checkpoints trained in the **second** step. The checkpoints are stored in `./log_style` by default.\n\nThen run:\n```\nbash scripts/train_style.sh [GPU ID]\n```\n\n---\n## Training on 360 Unbounded Scenes\nThe code for training StyleRF on the Tanks&Temples dataset is available on the `360` branch. 
To access it, run `git checkout 360`.\n\n\n## Acknowledgments\nThis repo is heavily based on [TensoRF](https://github.com/apchenstu/TensoRF). We thank the authors for sharing their amazing work!\n\n## Citation\nIf you find our code or paper helpful, please consider citing:\n```\n@inproceedings{liu2023stylerf,\n  title={StyleRF: Zero-shot 3D Style Transfer of Neural Radiance Fields},\n  author={Liu, Kunhao and Zhan, Fangneng and Chen, Yiwen and Zhang, Jiahui and Yu, Yingchen and El Saddik, Abdulmotaleb and Lu, Shijian and Xing, Eric P},\n  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},\n  pages={8338--8348},\n  year={2023}\n}\n```\n\n"
  },
  {
    "path": "configs/llff.txt",
    "content": "\ndataset_name = llff\ndatadir = ./data/nerf_llff_data/trex\nexpname = trex\nbasedir = ./log\n\ndownsample_train = 4.0\nndc_ray = 1\n\nn_iters = 25000\nbatch_size = 4096\n\nN_voxel_init = 2097156 # 128**3\nN_voxel_final = 262144000 # 640**3\nupsamp_list = [2000,3000,4000,5500]\nupdate_AlphaMask_list = [2500]\n\nN_vis = -1 # vis all testing images\nvis_every = 10000\n\nrender_test = 1\nrender_path = 1\n\nn_lamb_sigma = [16,4,4]\nn_lamb_sh = [48,12,12]\n\nshadingMode = MLP_Fea\nfea2denseAct = relu\n\nview_pe = 0\nfea_pe = 0\n\nTV_weight_density = 1.0\nTV_weight_app = 1.0\n\n"
  },
  {
    "path": "configs/llff_feature.txt",
    "content": "dataset_name = llff\ndatadir = ./data/nerf_llff_data/trex\nckpt = ./log/trex/trex.th\nexpname = trex\nbasedir = ./log_feature\n\nTV_weight_feature = 80\n\ndownsample_train = 4.0\nndc_ray = 1\n\nn_iters = 25000\npatch_size = 256\nbatch_size = 4096\nchunk_size = 4096\n\nN_voxel_init = 2097156 # 128**3\nN_voxel_final = 262144000 # 640**3\nupsamp_list = [2000,3000,4000,5500]\nupdate_AlphaMask_list = [2500]\n\nn_lamb_sigma = [16,4,4]\nn_lamb_sh = [48,12,12]\n\nfea2denseAct = relu\n\n"
  },
  {
    "path": "configs/llff_style.txt",
    "content": "dataset_name = llff\ndatadir = ./data/nerf_llff_data/trex\nckpt = ./log_feature/trex/trex.th\nexpname = trex\nbasedir = ./log_style\n\nnSamples = 300\npatch_size = 256\nchunk_size = 2048\n\ncontent_weight = 1\nstyle_weight = 20\nfeaturemap_tv_weight = 0\nimage_tv_weight = 0\n\nrm_weight_mask_thre = 0.001\n\ndownsample_train = 4.0\nndc_ray = 1\n\nn_iters = 25000\n\nn_lamb_sigma = [16,4,4]\nn_lamb_sh = [48,12,12]\nN_voxel_init = 2097156 # 128**3\nN_voxel_final = 262144000 # 640**3\n\nfea2denseAct = relu\n"
  },
  {
    "path": "configs/nerf_synthetic.txt",
    "content": "\ndataset_name = blender\ndatadir = ./data/nerf_synthetic/lego\nexpname =  lego\nbasedir = ./log\n\nn_iters = 30000\nbatch_size = 4096\n\nN_voxel_init = 2097156 # 128**3\nN_voxel_final = 27000000 # 300**3\nupsamp_list = [2000,3000,4000,5500,7000]\nupdate_AlphaMask_list = [2000,4000]\n\nN_vis = 5\nvis_every = 10000\n\nrender_test = 1\n\nn_lamb_sigma = [16,16,16]\nn_lamb_sh = [48,48,48]\nmodel_name = TensorVMSplit\n\n\nshadingMode = MLP_Fea\nfea2denseAct = softplus\n\nview_pe = 2\nfea_pe = 2\n\nL1_weight_inital = 8e-5\nL1_weight_rest = 4e-5\nrm_weight_mask_thre = 1e-4\n"
  },
  {
    "path": "configs/nerf_synthetic_feature.txt",
    "content": "dataset_name = blender\ndatadir = ./data/nerf_synthetic/lego\nckpt = ./log/lego/lego.th\nexpname = lego\nbasedir = ./log_feature\n\nTV_weight_feature = 10\n\nn_iters = 25000\npatch_size = 256\nbatch_size = 4096\nchunk_size = 4096\n\nN_voxel_init = 2097156 # 128**3\nN_voxel_final = 27000000 # 300**3\nupsamp_list = [2000,3000,4000,5500,7000]\nupdate_AlphaMask_list = [2000,4000]\n\nrm_weight_mask_thre = 0.01\n\nn_lamb_sigma = [16,16,16]\nn_lamb_sh = [48,48,48]\n\nfea2denseAct = softplus\n"
  },
  {
    "path": "configs/nerf_synthetic_style.txt",
    "content": "dataset_name = blender\ndatadir = ./data/nerf_synthetic/lego\nckpt = ./log_feature/lego/lego.th\nexpname = lego\nbasedir = ./log_style\n\npatch_size = 256\nchunk_size = 2048\n\ncontent_weight = 1\nstyle_weight = 20\n\nrm_weight_mask_thre = 0.01\n\nn_iters = 25000\n\nn_lamb_sigma = [16,16,16]\nn_lamb_sh = [48,48,48]\nN_voxel_init = 2097156 # 128**3\nN_voxel_final = 27000000 # 300**3\n\nfea2denseAct = softplus\n"
  },
  {
    "path": "dataLoader/__init__.py",
    "content": "from .llff import LLFFDataset\nfrom .blender import BlenderDataset\nfrom .nsvf import NSVF\nfrom .tankstemple import TanksTempleDataset\nfrom .your_own_data import YourOwnDataset\n\n\n\ndataset_dict = {'blender': BlenderDataset,\n               'llff':LLFFDataset,\n               'tankstemple':TanksTempleDataset,\n               'nsvf':NSVF,\n                'own_data':YourOwnDataset}"
  },
  {
    "path": "dataLoader/blender.py",
    "content": "import torch,cv2\nfrom torch.utils.data import Dataset\nimport json\nfrom tqdm import tqdm\nimport os\nfrom PIL import Image\nfrom torchvision import transforms as T\n\n\nfrom .ray_utils import *\n\n\nclass BlenderDataset(Dataset):\n    def __init__(self, datadir, split='train', downsample=1.0, is_stack=False, N_vis=-1):\n\n        self.N_vis = N_vis\n        self.root_dir = datadir\n        self.split = split\n        self.is_stack = is_stack\n        self.img_wh = (int(800/downsample),int(800/downsample))\n        self.define_transforms()\n\n        self.scene_bbox = torch.tensor([[-1.5, -1.5, -1.5], [1.5, 1.5, 1.5]])\n        self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])\n        self.read_meta()\n        self.define_proj_mat()\n\n        self.white_bg = True\n        self.near_far = [2.0,6.0]\n        \n        self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3)\n        self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)\n        self.downsample=downsample\n\n    def read_depth(self, filename):\n        depth = np.array(read_pfm(filename)[0], dtype=np.float32)  # (800, 800)\n        return depth\n    \n    def read_meta(self):\n\n        with open(os.path.join(self.root_dir, f\"transforms_{self.split}.json\"), 'r') as f:\n            self.meta = json.load(f)\n\n        w, h = self.img_wh\n        self.focal = 0.5 * 800 / np.tan(0.5 * self.meta['camera_angle_x'])  # original focal length\n        self.focal *= self.img_wh[0] / 800  # modify focal length to match size self.img_wh\n\n\n        # ray directions for all pixels, same for all images (same H, W, focal)\n        self.directions = get_ray_directions(h, w, [self.focal,self.focal])  # (h, w, 3)\n        self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)\n        self.intrinsics = torch.tensor([[self.focal,0,w/2],[0,self.focal,h/2],[0,0,1]]).float()\n\n        
self.image_paths = []\n        self.poses = []\n        self.all_rays = []\n        self.all_rgbs = []\n        self.all_masks = []\n        self.all_depth = []\n\n        img_eval_interval = 1 if self.N_vis < 0 else len(self.meta['frames']) // self.N_vis\n        idxs = list(range(0, len(self.meta['frames']), img_eval_interval))\n        for i in tqdm(idxs, desc=f'Loading data {self.split} ({len(idxs)})'):#img_list:#\n\n            frame = self.meta['frames'][i]\n            pose = np.array(frame['transform_matrix']) @ self.blender2opencv\n            c2w = torch.FloatTensor(pose)\n            self.poses += [c2w]\n\n            image_path = os.path.join(self.root_dir, f\"{frame['file_path']}.png\")\n            self.image_paths += [image_path]\n            img = Image.open(image_path)\n            \n            # resize whenever the source image does not match the target resolution,\n            # so that pixels stay aligned with the rays generated for self.img_wh\n            if img.size != self.img_wh:\n                img = img.resize(self.img_wh, Image.LANCZOS)\n            img = self.transform(img)  # (4, h, w)\n            img = img.view(4, -1).permute(1, 0)  # (h*w, 4) RGBA\n            self.all_masks.append(img[:, -1:].reshape(h,w,1)) # (h, w, 1) A\n            img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:])  # blend A to RGB\n            self.all_rgbs += [img]\n\n\n            rays_o, rays_d = get_rays(self.directions, c2w)  # both (h*w, 3)\n            self.all_rays += [torch.cat([rays_o, rays_d], 1)]  # (h*w, 6)\n\n        \n        self.all_masks = torch.stack(self.all_masks) # (n_frames, h, w, 1)\n        self.poses = torch.stack(self.poses)\n        all_rays = self.all_rays\n        all_rgbs = self.all_rgbs\n\n        self.all_rays = torch.cat(self.all_rays, 0) # (len(self.meta['frames'])*h*w,6)\n        self.all_rgbs = torch.cat(self.all_rgbs, 0) # (len(self.meta['frames'])*h*w,3)\n\n        if self.is_stack:\n            self.all_rays_stack = torch.stack(all_rays, 0).reshape(-1,*self.img_wh[::-1], 6)  # (len(self.meta['frames']),h,w,6)\n            avg_pool = torch.nn.AvgPool2d(4, 
ceil_mode=True)\n            self.ds_all_rays_stack = avg_pool(self.all_rays_stack.permute(0,3,1,2)).permute(0,2,3,1) # (len(self.meta['frames']),h/4,w/4,6)\n            self.all_rgbs_stack = torch.stack(all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3)  # (len(self.meta['frames']),h,w,3)\n\n    @torch.no_grad()\n    def prepare_feature_data(self, encoder, chunk=8):\n        '''\n        Prepare feature maps as training data.\n        '''\n        assert self.is_stack, 'Dataset should contain original stacked training data!'\n        print('====> prepare_feature_data ...')\n\n        frames_num, h, w, _ = self.all_rgbs_stack.size()\n        features = []\n\n        for chunk_idx in range(frames_num // chunk + int(frames_num % chunk > 0)):\n            rgbs_chunk = self.all_rgbs_stack[chunk_idx*chunk : (chunk_idx+1)*chunk].cuda()\n            features_chunk = encoder(normalize_vgg(rgbs_chunk.permute(0,3,1,2))).relu3_1\n            # resize to the size of rgb map so that rays can match\n            features_chunk = T.functional.resize(features_chunk, size=(h,w), \n                                                 interpolation=T.InterpolationMode.BILINEAR)\n            features.append(features_chunk.detach().cpu().requires_grad_(False))\n\n        self.all_features_stack = torch.cat(features).permute(0,2,3,1) # (len(self.meta['frames']),h,w,256)\n        self.all_features = self.all_features_stack.reshape(-1, 256)\n        print('prepare_feature_data Done!')\n\n    def define_transforms(self):\n        self.transform = T.ToTensor()\n        \n    def define_proj_mat(self):\n        self.proj_mat = self.intrinsics.unsqueeze(0) @ torch.inverse(self.poses)[:,:3]\n\n    def world2ndc(self,points,lindisp=None):\n        device = points.device\n        return (points - self.center.to(device)) / self.radius.to(device)\n        \n    def __len__(self):\n        return len(self.all_rgbs)\n\n    def __getitem__(self, idx):\n\n        if self.split == 'train':  # use data in the 
buffers\n            sample = {'rays': self.all_rays[idx],\n                      'rgbs': self.all_rgbs[idx]}\n\n        else:  # create data for each image separately\n\n            img = self.all_rgbs[idx]\n            rays = self.all_rays[idx]\n            mask = self.all_masks[idx] # for quantitative evaluation\n\n            sample = {'rays': rays,\n                      'rgbs': img,\n                      'mask': mask}\n        return sample\n"
  },
  {
    "path": "dataLoader/colmap2nerf.py",
    "content": "#!/usr/bin/env python3\n\n# Copyright (c) 2020-2022, NVIDIA CORPORATION.  All rights reserved.\n#\n# NVIDIA CORPORATION and its licensors retain all intellectual property\n# and proprietary rights in and to this software, related documentation\n# and any modifications thereto.  Any use, reproduction, disclosure or\n# distribution of this software and related documentation without an express\n# license agreement from NVIDIA CORPORATION is strictly prohibited.\n\nimport argparse\nimport os\nfrom pathlib import Path, PurePosixPath\n\nimport numpy as np\nimport json\nimport sys\nimport math\nimport cv2\nimport os\nimport shutil\n\ndef parse_args():\n\tparser = argparse.ArgumentParser(description=\"convert a text colmap export to nerf format transforms.json; optionally convert video to images, and optionally run colmap in the first place\")\n\n\tparser.add_argument(\"--video_in\", default=\"\", help=\"run ffmpeg first to convert a provided video file into a set of images. uses the video_fps parameter also\")\n\tparser.add_argument(\"--video_fps\", default=2)\n\tparser.add_argument(\"--time_slice\", default=\"\", help=\"time (in seconds) in the format t1,t2 within which the images should be generated from the video. eg: \\\"--time_slice '10,300'\\\" will generate images only from 10th second to 300th second of the video\")\n\tparser.add_argument(\"--run_colmap\", action=\"store_true\", help=\"run colmap first on the image folder\")\n\tparser.add_argument(\"--colmap_matcher\", default=\"sequential\", choices=[\"exhaustive\",\"sequential\",\"spatial\",\"transitive\",\"vocab_tree\"], help=\"select which matcher colmap should use. 
sequential for videos, exhaustive for adhoc images\")\n\tparser.add_argument(\"--colmap_db\", default=\"colmap.db\", help=\"colmap database filename\")\n\tparser.add_argument(\"--images\", default=\"images\", help=\"input path to the images\")\n\tparser.add_argument(\"--text\", default=\"colmap_text\", help=\"input path to the colmap text files (set automatically if run_colmap is used)\")\n\tparser.add_argument(\"--aabb_scale\", default=16, choices=[\"1\",\"2\",\"4\",\"8\",\"16\"], help=\"large scene scale factor. 1=scene fits in unit cube; power of 2 up to 16\")\n\tparser.add_argument(\"--skip_early\", default=0, help=\"skip this many images from the start\")\n\tparser.add_argument(\"--out\", default=\"transforms.json\", help=\"output path\")\n\targs = parser.parse_args()\n\treturn args\n\ndef do_system(arg):\n\tprint(f\"==== running: {arg}\")\n\terr = os.system(arg)\n\tif err:\n\t\tprint(\"FATAL: command failed\")\n\t\tsys.exit(err)\n\ndef run_ffmpeg(args):\n\tif not os.path.isabs(args.images):\n\t\targs.images = os.path.join(os.path.dirname(args.video_in), args.images)\n\timages = args.images\n\tvideo = args.video_in\n\tfps = float(args.video_fps) or 1.0\n\tprint(f\"running ffmpeg with input video file={video}, output image folder={images}, fps={fps}.\")\n\tif (input(f\"warning! folder '{images}' will be deleted/replaced. continue? 
(Y/n)\").lower().strip()+\"y\")[:1] != \"y\":\n\t\tsys.exit(1)\n\ttry:\n\t\tshutil.rmtree(images)\n\texcept:\n\t\tpass\n\tdo_system(f\"mkdir {images}\")\n\n\ttime_slice_value = \"\"\n\ttime_slice = args.time_slice\n\tif time_slice:\n\t    start, end = time_slice.split(\",\")\n\t    time_slice_value = f\",select='between(t\\,{start}\\,{end})'\"\n\tdo_system(f\"ffmpeg -i {video} -qscale:v 1 -qmin 1 -vf \\\"fps={fps}{time_slice_value}\\\" {images}/%04d.jpg\")\n\ndef run_colmap(args):\n\tdb=args.colmap_db\n\timages=args.images\n\tdb_noext=str(Path(db).with_suffix(\"\"))\n\n\tif args.text==\"text\":\n\t\targs.text=db_noext+\"_text\"\n\ttext=args.text\n\tsparse=db_noext+\"_sparse\"\n\tprint(f\"running colmap with:\\n\\tdb={db}\\n\\timages={images}\\n\\tsparse={sparse}\\n\\ttext={text}\")\n\tif (input(f\"warning! folders '{sparse}' and '{text}' will be deleted/replaced. continue? (Y/n)\").lower().strip()+\"y\")[:1] != \"y\":\n\t\tsys.exit(1)\n\tif os.path.exists(db):\n\t\tos.remove(db)\n\tdo_system(f\"colmap feature_extractor --ImageReader.camera_model OPENCV --SiftExtraction.estimate_affine_shape=true --SiftExtraction.domain_size_pooling=true --ImageReader.single_camera 1 --database_path {db} --image_path {images}\")\n\tdo_system(f\"colmap {args.colmap_matcher}_matcher --SiftMatching.guided_matching=true --database_path {db}\")\n\ttry:\n\t\tshutil.rmtree(sparse)\n\texcept:\n\t\tpass\n\tdo_system(f\"mkdir {sparse}\")\n\tdo_system(f\"colmap mapper --database_path {db} --image_path {images} --output_path {sparse}\")\n\tdo_system(f\"colmap bundle_adjuster --input_path {sparse}/0 --output_path {sparse}/0 --BundleAdjustment.refine_principal_point 1\")\n\ttry:\n\t\tshutil.rmtree(text)\n\texcept:\n\t\tpass\n\tdo_system(f\"mkdir {text}\")\n\tdo_system(f\"colmap model_converter --input_path {sparse}/0 --output_path {text} --output_type TXT\")\n\ndef variance_of_laplacian(image):\n\treturn cv2.Laplacian(image, cv2.CV_64F).var()\n\ndef sharpness(imagePath):\n\timage = 
cv2.imread(imagePath)\n\tgray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)\n\tfm = variance_of_laplacian(gray)\n\treturn fm\n\ndef qvec2rotmat(qvec):\n\treturn np.array([\n\t\t[\n\t\t\t1 - 2 * qvec[2]**2 - 2 * qvec[3]**2,\n\t\t\t2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],\n\t\t\t2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]\n\t\t], [\n\t\t\t2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],\n\t\t\t1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,\n\t\t\t2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]\n\t\t], [\n\t\t\t2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],\n\t\t\t2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],\n\t\t\t1 - 2 * qvec[1]**2 - 2 * qvec[2]**2\n\t\t]\n\t])\n\ndef rotmat(a, b):\n\ta, b = a / np.linalg.norm(a), b / np.linalg.norm(b)\n\tv = np.cross(a, b)\n\tc = np.dot(a, b)\n\ts = np.linalg.norm(v)\n\tkmat = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]])\n\treturn np.eye(3) + kmat + kmat.dot(kmat) * ((1 - c) / (s ** 2 + 1e-10))\n\ndef closest_point_2_lines(oa, da, ob, db): # returns point closest to both rays of form o+t*d, and a weight factor that goes to 0 if the lines are parallel\n\tda = da / np.linalg.norm(da)\n\tdb = db / np.linalg.norm(db)\n\tc = np.cross(da, db)\n\tdenom = np.linalg.norm(c)**2\n\tt = ob - oa\n\tta = np.linalg.det([t, db, c]) / (denom + 1e-10)\n\ttb = np.linalg.det([t, da, c]) / (denom + 1e-10)\n\tif ta > 0:\n\t\tta = 0\n\tif tb > 0:\n\t\ttb = 0\n\treturn (oa+ta*da+ob+tb*db) * 0.5, denom\n\nif __name__ == \"__main__\":\n\targs = parse_args()\n\tif args.video_in != \"\":\n\t\trun_ffmpeg(args)\n\tif args.run_colmap:\n\t\trun_colmap(args)\n\tAABB_SCALE = int(args.aabb_scale)\n\tSKIP_EARLY = int(args.skip_early)\n\tIMAGE_FOLDER = args.images\n\tTEXT_FOLDER = args.text\n\tOUT_PATH = args.out\n\tprint(f\"outputting to {OUT_PATH}...\")\n\twith open(os.path.join(TEXT_FOLDER,\"cameras.txt\"), \"r\") as f:\n\t\tangle_x = math.pi / 2\n\t\tfor line in f:\n\t\t\t# 1 SIMPLE_RADIAL 2048 1536 1580.46 1024 768 0.0045691\n\t\t\t# 1 OPENCV 3840 
2160 3178.27 3182.09 1920 1080 0.159668 -0.231286 -0.00123982 0.00272224\n\t\t\t# 1 RADIAL 1920 1080 1665.1 960 540 0.0672856 -0.0761443\n\t\t\tif line[0] == \"#\":\n\t\t\t\tcontinue\n\t\t\tels = line.split(\" \")\n\t\t\tw = float(els[2])\n\t\t\th = float(els[3])\n\t\t\tfl_x = float(els[4])\n\t\t\tfl_y = float(els[4])\n\t\t\tk1 = 0\n\t\t\tk2 = 0\n\t\t\tp1 = 0\n\t\t\tp2 = 0\n\t\t\tcx = w / 2\n\t\t\tcy = h / 2\n\t\t\tif els[1] == \"SIMPLE_PINHOLE\":\n\t\t\t\tcx = float(els[5])\n\t\t\t\tcy = float(els[6])\n\t\t\telif els[1] == \"PINHOLE\":\n\t\t\t\tfl_y = float(els[5])\n\t\t\t\tcx = float(els[6])\n\t\t\t\tcy = float(els[7])\n\t\t\telif els[1] == \"SIMPLE_RADIAL\":\n\t\t\t\tcx = float(els[5])\n\t\t\t\tcy = float(els[6])\n\t\t\t\tk1 = float(els[7])\n\t\t\telif els[1] == \"RADIAL\":\n\t\t\t\tcx = float(els[5])\n\t\t\t\tcy = float(els[6])\n\t\t\t\tk1 = float(els[7])\n\t\t\t\tk2 = float(els[8])\n\t\t\telif els[1] == \"OPENCV\":\n\t\t\t\tfl_y = float(els[5])\n\t\t\t\tcx = float(els[6])\n\t\t\t\tcy = float(els[7])\n\t\t\t\tk1 = float(els[8])\n\t\t\t\tk2 = float(els[9])\n\t\t\t\tp1 = float(els[10])\n\t\t\t\tp2 = float(els[11])\n\t\t\telse:\n\t\t\t\tprint(\"unknown camera model \", els[1])\n\t\t\t# fl = 0.5 * w / tan(0.5 * angle_x);\n\t\t\tangle_x = math.atan(w / (fl_x * 2)) * 2\n\t\t\tangle_y = math.atan(h / (fl_y * 2)) * 2\n\t\t\tfovx = angle_x * 180 / math.pi\n\t\t\tfovy = angle_y * 180 / math.pi\n\n\tprint(f\"camera:\\n\\tres={w,h}\\n\\tcenter={cx,cy}\\n\\tfocal={fl_x,fl_y}\\n\\tfov={fovx,fovy}\\n\\tk={k1,k2} p={p1,p2} \")\n\n\twith open(os.path.join(TEXT_FOLDER,\"images.txt\"), \"r\") as f:\n\t\ti = 0\n\t\tbottom = np.array([0.0, 0.0, 0.0, 1.0]).reshape([1, 4])\n\t\tout = {\n\t\t\t\"camera_angle_x\": angle_x,\n\t\t\t\"camera_angle_y\": angle_y,\n\t\t\t\"fl_x\": fl_x,\n\t\t\t\"fl_y\": fl_y,\n\t\t\t\"k1\": k1,\n\t\t\t\"k2\": k2,\n\t\t\t\"p1\": p1,\n\t\t\t\"p2\": p2,\n\t\t\t\"cx\": cx,\n\t\t\t\"cy\": cy,\n\t\t\t\"w\": w,\n\t\t\t\"h\": h,\n\t\t\t\"aabb_scale\": 
AABB_SCALE,\n\t\t\t\"frames\": [],\n\t\t}\n\n\t\tup = np.zeros(3)\n\t\tfor line in f:\n\t\t\tline = line.strip()\n\t\t\tif line[0] == \"#\":\n\t\t\t\tcontinue\n\t\t\ti = i + 1\n\t\t\tif i < SKIP_EARLY*2:\n\t\t\t\tcontinue\n\t\t\tif  i % 2 == 1:\n\t\t\t\telems=line.split(\" \") # 1-4 is quat, 5-7 is trans, 9ff is filename (9, if filename contains no spaces)\n\t\t\t\t#name = str(PurePosixPath(Path(IMAGE_FOLDER, elems[9])))\n\t\t\t\t# why is this requireing a relitive path while using ^\n\t\t\t\timage_rel = os.path.relpath(IMAGE_FOLDER)\n\t\t\t\tname = str(f\"./{image_rel}/{'_'.join(elems[9:])}\")\n\t\t\t\tb=sharpness(name)\n\t\t\t\tprint(name, \"sharpness=\",b)\n\t\t\t\timage_id = int(elems[0])\n\t\t\t\tqvec = np.array(tuple(map(float, elems[1:5])))\n\t\t\t\ttvec = np.array(tuple(map(float, elems[5:8])))\n\t\t\t\tR = qvec2rotmat(-qvec)\n\t\t\t\tt = tvec.reshape([3,1])\n\t\t\t\tm = np.concatenate([np.concatenate([R, t], 1), bottom], 0)\n\t\t\t\tc2w = np.linalg.inv(m)\n\t\t\t\tc2w[0:3,2] *= -1 # flip the y and z axis\n\t\t\t\tc2w[0:3,1] *= -1\n\t\t\t\tc2w = c2w[[1,0,2,3],:] # swap y and z\n\t\t\t\tc2w[2,:] *= -1 # flip whole world upside down\n\n\t\t\t\tup += c2w[0:3,1]\n\n\t\t\t\tframe={\"file_path\":name,\"sharpness\":b,\"transform_matrix\": c2w}\n\t\t\t\tout[\"frames\"].append(frame)\n\tnframes = len(out[\"frames\"])\n\tup = up / np.linalg.norm(up)\n\tprint(\"up vector was\", up)\n\tR = rotmat(up,[0,0,1]) # rotate up vector to [0,0,1]\n\tR = np.pad(R,[0,1])\n\tR[-1, -1] = 1\n\n\n\tfor f in out[\"frames\"]:\n\t\tf[\"transform_matrix\"] = np.matmul(R, f[\"transform_matrix\"]) # rotate up to be the z axis\n\n\t# find a central point they are all looking at\n\tprint(\"computing center of attention...\")\n\ttotw = 0.0\n\ttotp = np.array([0.0, 0.0, 0.0])\n\tfor f in out[\"frames\"]:\n\t\tmf = f[\"transform_matrix\"][0:3,:]\n\t\tfor g in out[\"frames\"]:\n\t\t\tmg = g[\"transform_matrix\"][0:3,:]\n\t\t\tp, w = closest_point_2_lines(mf[:,3], mf[:,2], mg[:,3], 
mg[:,2])\n\t\t\tif w > 0.01:\n\t\t\t\ttotp += p*w\n\t\t\t\ttotw += w\n\ttotp /= totw\n\tprint(totp) # the cameras are looking at totp\n\tfor f in out[\"frames\"]:\n\t\tf[\"transform_matrix\"][0:3,3] -= totp\n\n\tavglen = 0.\n\tfor f in out[\"frames\"]:\n\t\tavglen += np.linalg.norm(f[\"transform_matrix\"][0:3,3])\n\tavglen /= nframes\n\tprint(\"avg camera distance from origin\", avglen)\n\tfor f in out[\"frames\"]:\n\t\tf[\"transform_matrix\"][0:3,3] *= 4.0 / avglen # scale to \"nerf sized\"\n\n\tfor f in out[\"frames\"]:\n\t\tf[\"transform_matrix\"] = f[\"transform_matrix\"].tolist()\n\tprint(nframes,\"frames\")\n\tprint(f\"writing {OUT_PATH}\")\n\twith open(OUT_PATH, \"w\") as outfile:\n\t\tjson.dump(out, outfile, indent=2)"
  },
  {
    "path": "dataLoader/llff.py",
    "content": "import torch\nfrom torch.utils.data import Dataset\nimport glob\nimport numpy as np\nimport os\nfrom PIL import Image\nfrom torchvision import transforms as T\n\nfrom .ray_utils import *\n\n\ndef normalize(v):\n    \"\"\"Normalize a vector.\"\"\"\n    return v / np.linalg.norm(v)\n\n\ndef average_poses(poses):\n    \"\"\"\n    Calculate the average pose, which is then used to center all poses\n    using @center_poses. Its computation is as follows:\n    1. Compute the center: the average of pose centers.\n    2. Compute the z axis: the normalized average z axis.\n    3. Compute axis y': the average y axis.\n    4. Compute x' = y' cross product z, then normalize it as the x axis.\n    5. Compute the y axis: z cross product x.\n\n    Note that at step 3, we cannot directly use y' as y axis since it's\n    not necessarily orthogonal to z axis. We need to pass from x to y.\n    Inputs:\n        poses: (N_images, 3, 4)\n    Outputs:\n        pose_avg: (3, 4) the average pose\n    \"\"\"\n    # 1. Compute the center\n    center = poses[..., 3].mean(0)  # (3)\n\n    # 2. Compute the z axis\n    z = normalize(poses[..., 2].mean(0))  # (3)\n\n    # 3. Compute axis y' (no need to normalize as it's not the final output)\n    y_ = poses[..., 1].mean(0)  # (3)\n\n    # 4. Compute the x axis\n    x = normalize(np.cross(z, y_))  # (3)\n\n    # 5. 
Compute the y axis (as z and x are normalized, y is already of norm 1)\n    y = np.cross(x, z)  # (3)\n\n    pose_avg = np.stack([x, y, z, center], 1)  # (3, 4)\n\n    return pose_avg\n\n\ndef center_poses(poses, blender2opencv):\n    \"\"\"\n    Center the poses so that we can use NDC.\n    See https://github.com/bmild/nerf/issues/34\n    Inputs:\n        poses: (N_images, 3, 4)\n    Outputs:\n        poses_centered: (N_images, 3, 4) the centered poses\n        pose_avg: (3, 4) the average pose\n    \"\"\"\n    poses = poses @ blender2opencv\n    pose_avg = average_poses(poses)  # (3, 4)\n    pose_avg_homo = np.eye(4)\n    pose_avg_homo[:3] = pose_avg  # convert to homogeneous coordinate for faster computation\n    pose_avg_homo = pose_avg_homo\n    # by simply adding 0, 0, 0, 1 as the last row\n    last_row = np.tile(np.array([0, 0, 0, 1]), (len(poses), 1, 1))  # (N_images, 1, 4)\n    poses_homo = \\\n        np.concatenate([poses, last_row], 1)  # (N_images, 4, 4) homogeneous coordinate\n\n    poses_centered = np.linalg.inv(pose_avg_homo) @ poses_homo  # (N_images, 4, 4)\n    #     poses_centered = poses_centered  @ blender2opencv\n    poses_centered = poses_centered[:, :3]  # (N_images, 3, 4)\n\n    return poses_centered, pose_avg_homo\n\n\ndef viewmatrix(z, up, pos):\n    vec2 = normalize(z)\n    vec1_avg = up\n    vec0 = normalize(np.cross(vec1_avg, vec2))\n    vec1 = normalize(np.cross(vec2, vec0))\n    m = np.eye(4)\n    m[:3] = np.stack([-vec0, vec1, vec2, pos], 1)\n    return m\n\n\ndef render_path_spiral(c2w, up, rads, focal, zdelta, zrate, N_rots=2, N=120):\n    render_poses = []\n    rads = np.array(list(rads) + [1.])\n\n    for theta in np.linspace(0., 2. 
* np.pi * N_rots, N + 1)[:-1]:\n        c = np.dot(c2w[:3, :4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.]) * rads)\n        z = normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.])))\n        render_poses.append(viewmatrix(z, up, c))\n    return render_poses\n\n\ndef get_spiral(c2ws_all, near_fars, rads_scale=1.0, N_views=120):\n    # center pose\n    c2w = average_poses(c2ws_all)\n\n    # Get average pose\n    up = normalize(c2ws_all[:, :3, 1].sum(0))\n\n    # Find a reasonable \"focus depth\" for this dataset\n    dt = 0.75\n    close_depth, inf_depth = near_fars.min() * 0.9, near_fars.max() * 5.0\n    focal = 1.0 / (((1.0 - dt) / close_depth + dt / inf_depth))\n\n    # Get radii for spiral path\n    zdelta = near_fars.min() * .2\n    tt = c2ws_all[:, :3, 3]\n    rads = np.percentile(np.abs(tt), 90, 0) * rads_scale\n    render_poses = render_path_spiral(c2w, up, rads, focal, zdelta, zrate=.5, N=N_views)\n    return np.stack(render_poses)\n\ndef get_interpolation_path(c2ws_all, steps=30):\n    \n    # flower\n    # idx0 = 1\n    # idx1 = 10\n\n    # trex\n    # idx0 = 8\n    # idx1 = 53\n\n    # horns\n    idx0 = 18\n    idx1 = 47\n\n    v = np.linspace(0,1,num=steps)\n\n    c2w0 = c2ws_all[idx0]\n    c2w1 = c2ws_all[idx1]\n\n    c2w_ = []\n    for i in range(steps):\n        c2w_.append(c2w0*v[i] + c2w1*(1-v[i]))\n\n    return np.stack(c2w_)\n\n\nclass LLFFDataset(Dataset):\n    def __init__(self, datadir, split='train', downsample=4, is_stack=False, hold_every=8):\n\n        self.root_dir = datadir\n        self.split = split\n        self.hold_every = hold_every\n        self.is_stack = is_stack\n        self.downsample = downsample\n        self.define_transforms()\n\n        self.blender2opencv = np.eye(4)#np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])\n        self.read_meta()\n        self.white_bg = False\n\n        #         self.near_far = 
[np.min(self.near_fars[:,0]),np.max(self.near_fars[:,1])]\n        self.near_far = [0.0, 1.0]\n        self.scene_bbox = torch.tensor([[-1.5, -1.67, -1.0], [1.5, 1.67, 1.0]])\n        # self.scene_bbox = torch.tensor([[-1.67, -1.5, -1.0], [1.67, 1.5, 1.0]])\n        self.center = torch.mean(self.scene_bbox, dim=0).float().view(1, 1, 3)\n        self.invradius = 1.0 / (self.scene_bbox[1] - self.center).float().view(1, 1, 3)\n\n    def read_meta(self):\n\n\n        poses_bounds = np.load(os.path.join(self.root_dir, 'poses_bounds.npy'))  # (N_images, 17)\n        self.image_paths = sorted(glob.glob(os.path.join(self.root_dir, 'images_4/*')))\n        # load full resolution image then resize\n        if self.split in ['train', 'test']:\n            assert len(poses_bounds) == len(self.image_paths), \\\n                'Mismatch between number of images and number of poses! Please rerun COLMAP!'\n\n        poses = poses_bounds[:, :15].reshape(-1, 3, 5)  # (N_images, 3, 5)\n        self.near_fars = poses_bounds[:, -2:]  # (N_images, 2)\n        hwf = poses[:, :, -1]\n\n        # Step 1: rescale focal length according to training resolution\n        H, W, self.focal = poses[0, :, -1]  # original intrinsics, same for all images\n        self.img_wh = np.array([int(W / self.downsample), int(H / self.downsample)])\n        self.focal = [self.focal * self.img_wh[0] / W, self.focal * self.img_wh[1] / H]\n\n        # Step 2: correct poses\n        # Original poses has rotation in form \"down right back\", change to \"right up back\"\n        # See https://github.com/bmild/nerf/issues/34\n        poses = np.concatenate([poses[..., 1:2], -poses[..., :1], poses[..., 2:4]], -1)\n        # (N_images, 3, 4) exclude H, W, focal\n        self.poses, self.pose_avg = center_poses(poses, self.blender2opencv)\n\n        # Step 3: correct scale so that the nearest depth is at a little more than 1.0\n        # See https://github.com/bmild/nerf/issues/34\n        near_original = 
self.near_fars.min()\n        scale_factor = near_original * 0.75  # 0.75 is the default parameter\n        # the nearest depth is at 1/0.75=1.33\n        self.near_fars /= scale_factor\n        self.poses[..., 3] /= scale_factor\n\n        # build rendering path\n        N_views, N_rots = 120, 2\n        tt = self.poses[:, :3, 3]  # ptstocam(poses[:3,3,:].T, c2w).T\n        up = normalize(self.poses[:, :3, 1].sum(0))\n        rads = np.percentile(np.abs(tt), 90, 0)\n\n        self.render_path = get_spiral(self.poses, self.near_fars, N_views=N_views)\n        # self.render_path = get_interpolation_path(self.poses)\n\n        # distances_from_center = np.linalg.norm(self.poses[..., 3], axis=1)\n        # val_idx = np.argmin(distances_from_center)  # choose val image as the closest to\n        # center image\n\n        # ray directions for all pixels, same for all images (same H, W, focal)\n        W, H = self.img_wh\n        self.directions = get_ray_directions_blender(H, W, self.focal)  # (H, W, 3)\n\n        average_pose = average_poses(self.poses)\n        dists = np.sum(np.square(average_pose[:3, 3] - self.poses[:, :3, 3]), -1)\n        i_test = np.arange(0, self.poses.shape[0], self.hold_every)  # [np.argmin(dists)]\n        img_list = i_test if self.split != 'train' else list(set(np.arange(len(self.poses))) - set(i_test))\n\n        # use first N_images-1 to train, the LAST is val\n        self.all_rays = []\n        self.all_rgbs = []\n        for i in img_list:\n            image_path = self.image_paths[i]\n            c2w = torch.FloatTensor(self.poses[i])\n\n            img = Image.open(image_path).convert('RGB')\n            if self.downsample != 1.0:\n                img = img.resize(self.img_wh, Image.LANCZOS)\n            img = self.transform(img)  # (3, h, w)\n\n            img = img.view(3, -1).permute(1, 0)  # (h*w, 3) RGB\n            self.all_rgbs += [img]\n            rays_o, rays_d = get_rays(self.directions, c2w)  # both (h*w, 3)\n            
rays_o, rays_d = ndc_rays_blender(H, W, self.focal[0], 1.0, rays_o, rays_d)\n            # viewdir = rays_d / torch.norm(rays_d, dim=-1, keepdim=True)\n\n            self.all_rays += [torch.cat([rays_o, rays_d], 1)]  # (h*w, 6)\n\n        all_rays = self.all_rays\n        all_rgbs = self.all_rgbs\n\n        self.all_rays = torch.cat(self.all_rays, 0) # (len(self.meta['frames])*h*w,6)\n        self.all_rgbs = torch.cat(self.all_rgbs, 0) # (len(self.meta['frames])*h*w,3)\n\n        if self.is_stack:\n            self.all_rays_stack = torch.stack(all_rays, 0).reshape(-1,*self.img_wh[::-1], 6)  # (len(self.meta['frames]),h,w,6)\n            avg_pool = torch.nn.AvgPool2d(4, ceil_mode=True)\n            self.ds_all_rays_stack = avg_pool(self.all_rays_stack.permute(0,3,1,2)).permute(0,2,3,1) # (len(self.meta['frames]),h/4,w/4,6)\n            self.all_rgbs_stack = torch.stack(all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3)  # (len(self.meta['frames]),h,w,3)\n\n    @torch.no_grad()\n    def prepare_feature_data(self, encoder, chunk=8):\n        '''\n        Prepare feature maps as training data.\n        '''\n        assert self.is_stack, 'Dataset should contain original stacked training data!'\n        print('====> prepare_feature_data ...')\n\n        frames_num, h, w, _ = self.all_rgbs_stack.size()\n        features = []\n\n        for chunk_idx in range(frames_num // chunk + int(frames_num % chunk > 0)):\n            rgbs_chunk = self.all_rgbs_stack[chunk_idx*chunk : (chunk_idx+1)*chunk].cuda()\n            features_chunk = encoder(normalize_vgg(rgbs_chunk.permute(0,3,1,2))).relu3_1\n            # resize to the size of rgb map so that rays can match\n            features_chunk = T.functional.resize(features_chunk, size=(h,w), \n                                                 interpolation=T.InterpolationMode.BILINEAR)\n            features.append(features_chunk.detach().cpu().requires_grad_(False))\n\n        self.all_features_stack = torch.cat(features).permute(0,2,3,1) 
# (len(self.meta['frames]),h,w,256)\n        self.all_features = self.all_features_stack.reshape(-1, 256)\n        print('prepare_feature_data Done!')\n\n    def define_transforms(self):\n        self.transform = T.ToTensor()\n\n    def __len__(self):\n        return len(self.all_rgbs)\n\n    def __getitem__(self, idx):\n\n        sample = {'rays': self.all_rays[idx],\n                  'rgbs': self.all_rgbs[idx]}\n\n        return sample"
  },
  {
    "path": "dataLoader/nsvf.py",
    "content": "import torch\nfrom torch.utils.data import Dataset\nfrom tqdm import tqdm\nimport os\nfrom PIL import Image\nfrom torchvision import transforms as T\n\nfrom .ray_utils import *\n\ntrans_t = lambda t : torch.Tensor([\n    [1,0,0,0],\n    [0,1,0,0],\n    [0,0,1,t],\n    [0,0,0,1]]).float()\n\nrot_phi = lambda phi : torch.Tensor([\n    [1,0,0,0],\n    [0,np.cos(phi),-np.sin(phi),0],\n    [0,np.sin(phi), np.cos(phi),0],\n    [0,0,0,1]]).float()\n\nrot_theta = lambda th : torch.Tensor([\n    [np.cos(th),0,-np.sin(th),0],\n    [0,1,0,0],\n    [np.sin(th),0, np.cos(th),0],\n    [0,0,0,1]]).float()\n\n\ndef pose_spherical(theta, phi, radius):\n    c2w = trans_t(radius)\n    c2w = rot_phi(phi/180.*np.pi) @ c2w\n    c2w = rot_theta(theta/180.*np.pi) @ c2w\n    c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w\n    return c2w\n\nclass NSVF(Dataset):\n    \"\"\"NSVF Generic Dataset.\"\"\"\n    def __init__(self, datadir, split='train', downsample=1.0, wh=[800,800], is_stack=False):\n        self.root_dir = datadir\n        self.split = split\n        self.is_stack = is_stack\n        self.downsample = downsample\n        self.img_wh = (int(wh[0]/downsample),int(wh[1]/downsample))\n        self.define_transforms()\n\n        self.white_bg = True\n        self.near_far = [0.5,6.0]\n        self.scene_bbox = torch.from_numpy(np.loadtxt(f'{self.root_dir}/bbox.txt')).float()[:6].view(2,3)\n        self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])\n        self.read_meta()\n        self.define_proj_mat()\n        \n        self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3)\n        self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)\n    \n    def bbox2corners(self):\n        corners = self.scene_bbox.unsqueeze(0).repeat(4,1,1)\n        for i in range(3):\n            corners[i,[0,1],i] = corners[i,[1,0],i] \n        return corners.view(-1,3)\n        \n      
  \n    def read_meta(self):\n        with open(os.path.join(self.root_dir, \"intrinsics.txt\")) as f:\n            focal = float(f.readline().split()[0])\n        self.intrinsics = np.array([[focal,0,400.0],[0,focal,400.0],[0,0,1]])\n        self.intrinsics[:2] *= (np.array(self.img_wh)/np.array([800,800])).reshape(2,1)\n\n        pose_files = sorted(os.listdir(os.path.join(self.root_dir, 'pose')))\n        img_files  = sorted(os.listdir(os.path.join(self.root_dir, 'rgb')))\n\n        if self.split == 'train':\n            pose_files = [x for x in pose_files if x.startswith('0_')]\n            img_files = [x for x in img_files if x.startswith('0_')]\n        elif self.split == 'val':\n            pose_files = [x for x in pose_files if x.startswith('1_')]\n            img_files = [x for x in img_files if x.startswith('1_')]\n        elif self.split == 'test':\n            test_pose_files = [x for x in pose_files if x.startswith('2_')]\n            test_img_files = [x for x in img_files if x.startswith('2_')]\n            if len(test_pose_files) == 0:\n                test_pose_files = [x for x in pose_files if x.startswith('1_')]\n                test_img_files = [x for x in img_files if x.startswith('1_')]\n            pose_files = test_pose_files\n            img_files = test_img_files\n\n        # ray directions for all pixels, same for all images (same H, W, focal)\n        self.directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsics[0,0],self.intrinsics[1,1]], center=self.intrinsics[:2,2])  # (h, w, 3)\n        self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)\n\n\n        self.render_path = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,40+1)[:-1]], 0)\n        \n        self.poses = []\n        self.all_rays = []\n        self.all_rgbs = []\n\n        assert len(img_files) == len(pose_files)\n        for img_fname, pose_fname in tqdm(zip(img_files, pose_files), 
desc=f'Loading data {self.split} ({len(img_files)})'):\n            image_path = os.path.join(self.root_dir, 'rgb', img_fname)\n            img = Image.open(image_path)\n            if self.downsample!=1.0:\n                img = img.resize(self.img_wh, Image.LANCZOS)\n            img = self.transform(img)  # (4, h, w)\n            img = img.view(img.shape[0], -1).permute(1, 0)  # (h*w, 4) RGBA\n            if img.shape[-1]==4:\n                img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:])  # blend A to RGB\n            self.all_rgbs += [img]\n\n            c2w = np.loadtxt(os.path.join(self.root_dir, 'pose', pose_fname)) #@ self.blender2opencv\n            c2w = torch.FloatTensor(c2w)\n            self.poses.append(c2w)  # C2W\n            rays_o, rays_d = get_rays(self.directions, c2w)  # both (h*w, 3)\n            self.all_rays += [torch.cat([rays_o, rays_d], 1)]  # (h*w, 8)\n            \n#             w2c = torch.inverse(c2w)\n#\n\n        self.poses = torch.stack(self.poses)\n        if 'train' == self.split:\n            if self.is_stack:\n                self.all_rays = torch.stack(self.all_rays, 0).reshape(-1,*self.img_wh[::-1], 6)  # (len(self.meta['frames])*h*w, 3)\n                self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3)  # (len(self.meta['frames])*h*w, 3) \n            else:\n                self.all_rays = torch.cat(self.all_rays, 0)  # (len(self.meta['frames])*h*w, 3)\n                self.all_rgbs = torch.cat(self.all_rgbs, 0)  # (len(self.meta['frames])*h*w, 3)\n        else:\n            self.all_rays = torch.stack(self.all_rays, 0)  # (len(self.meta['frames]),h*w, 3)\n            self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3)  # (len(self.meta['frames]),h,w,3)\n\n \n    def define_transforms(self):\n        self.transform = T.ToTensor()\n        \n    def define_proj_mat(self):\n        self.proj_mat = torch.from_numpy(self.intrinsics[:3,:3]).unsqueeze(0).float() @ 
torch.inverse(self.poses)[:,:3]\n\n    def world2ndc(self, points):\n        device = points.device\n        return (points - self.center.to(device)) / self.radius.to(device)\n        \n    def __len__(self):\n        if self.split == 'train':\n            return len(self.all_rays)\n        return len(self.all_rgbs)\n\n    def __getitem__(self, idx):\n\n        if self.split == 'train':  # use data in the buffers\n            sample = {'rays': self.all_rays[idx],\n                      'rgbs': self.all_rgbs[idx]}\n\n        else:  # create data for each image separately\n\n            img = self.all_rgbs[idx]\n            rays = self.all_rays[idx]\n\n            sample = {'rays': rays,\n                      'rgbs': img}\n        return sample"
  },
  {
    "path": "dataLoader/ray_utils.py",
    "content": "import torch, re\nimport numpy as np\nfrom torch import searchsorted\nfrom kornia import create_meshgrid\n\n\n# from utils import index_point_feature\n\ndef depth2dist(z_vals, cos_angle):\n    # z_vals: [N_ray N_sample]\n    device = z_vals.device\n    dists = z_vals[..., 1:] - z_vals[..., :-1]\n    dists = torch.cat([dists, torch.Tensor([1e10]).to(device).expand(dists[..., :1].shape)], -1)  # [N_rays, N_samples]\n    dists = dists * cos_angle.unsqueeze(-1)\n    return dists\n\n\ndef ndc2dist(ndc_pts, cos_angle):\n    dists = torch.norm(ndc_pts[:, 1:] - ndc_pts[:, :-1], dim=-1)\n    dists = torch.cat([dists, 1e10 * cos_angle.unsqueeze(-1)], -1)  # [N_rays, N_samples]\n    return dists\n\n\ndef get_ray_directions(H, W, focal, center=None):\n    \"\"\"\n    Get ray directions for all pixels in camera coordinate.\n    Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/\n               ray-tracing-generating-camera-rays/standard-coordinate-systems\n    Inputs:\n        H, W, focal: image height, width and focal length\n    Outputs:\n        directions: (H, W, 3), the direction of the rays in camera coordinate\n    \"\"\"\n    grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5\n\n    i, j = grid.unbind(-1)\n    # the direction here is without +0.5 pixel centering as calibration is not so accurate\n    # see https://github.com/bmild/nerf/issues/24\n    cent = center if center is not None else [W / 2, H / 2]\n    directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1)  # (H, W, 3)\n\n    return directions\n\n\ndef get_ray_directions_blender(H, W, focal, center=None):\n    \"\"\"\n    Get ray directions for all pixels in camera coordinate.\n    Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/\n               ray-tracing-generating-camera-rays/standard-coordinate-systems\n    Inputs:\n        H, W, focal: image height, width and focal length\n    Outputs:\n   
     directions: (H, W, 3), the direction of the rays in camera coordinate\n    \"\"\"\n    grid = create_meshgrid(H, W, normalized_coordinates=False)[0]+0.5\n    i, j = grid.unbind(-1)\n    # the direction here is without +0.5 pixel centering as calibration is not so accurate\n    # see https://github.com/bmild/nerf/issues/24\n    cent = center if center is not None else [W / 2, H / 2]\n    directions = torch.stack([(i - cent[0]) / focal[0], -(j - cent[1]) / focal[1], -torch.ones_like(i)],\n                             -1)  # (H, W, 3)\n\n    return directions\n\n\ndef get_rays(directions, c2w):\n    \"\"\"\n    Get ray origin and normalized directions in world coordinate for all pixels in one image.\n    Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/\n               ray-tracing-generating-camera-rays/standard-coordinate-systems\n    Inputs:\n        directions: (H, W, 3) precomputed ray directions in camera coordinate\n        c2w: (3, 4) transformation matrix from camera coordinate to world coordinate\n    Outputs:\n        rays_o: (H*W, 3), the origin of the rays in world coordinate\n        rays_d: (H*W, 3), the normalized direction of the rays in world coordinate\n    \"\"\"\n    # Rotate ray directions from camera coordinate to the world coordinate\n    rays_d = directions @ c2w[:3, :3].T  # (H, W, 3)\n    # rays_d = rays_d / torch.norm(rays_d, dim=-1, keepdim=True)\n    # The origin of all rays is the camera origin in world coordinate\n    rays_o = c2w[:3, 3].expand(rays_d.shape)  # (H, W, 3)\n\n    rays_d = rays_d.view(-1, 3)\n    rays_o = rays_o.view(-1, 3)\n\n    return rays_o, rays_d\n\n\ndef ndc_rays_blender(H, W, focal, near, rays_o, rays_d):\n    # Shift ray origins to near plane\n    t = -(near + rays_o[..., 2]) / rays_d[..., 2]\n    rays_o = rays_o + t[..., None] * rays_d\n\n    # Projection\n    o0 = -1. / (W / (2. * focal)) * rays_o[..., 0] / rays_o[..., 2]\n    o1 = -1. / (H / (2. 
* focal)) * rays_o[..., 1] / rays_o[..., 2]\n    o2 = 1. + 2. * near / rays_o[..., 2]\n\n    d0 = -1. / (W / (2. * focal)) * (rays_d[..., 0] / rays_d[..., 2] - rays_o[..., 0] / rays_o[..., 2])\n    d1 = -1. / (H / (2. * focal)) * (rays_d[..., 1] / rays_d[..., 2] - rays_o[..., 1] / rays_o[..., 2])\n    d2 = -2. * near / rays_o[..., 2]\n\n    rays_o = torch.stack([o0, o1, o2], -1)\n    rays_d = torch.stack([d0, d1, d2], -1)\n\n    return rays_o, rays_d\n\ndef ndc_rays(H, W, focal, near, rays_o, rays_d):\n    # Shift ray origins to near plane\n    t = (near - rays_o[..., 2]) / rays_d[..., 2]\n    rays_o = rays_o + t[..., None] * rays_d\n\n    # Projection\n    o0 = 1. / (W / (2. * focal)) * rays_o[..., 0] / rays_o[..., 2]\n    o1 = 1. / (H / (2. * focal)) * rays_o[..., 1] / rays_o[..., 2]\n    o2 = 1. - 2. * near / rays_o[..., 2]\n\n    d0 = 1. / (W / (2. * focal)) * (rays_d[..., 0] / rays_d[..., 2] - rays_o[..., 0] / rays_o[..., 2])\n    d1 = 1. / (H / (2. * focal)) * (rays_d[..., 1] / rays_d[..., 2] - rays_o[..., 1] / rays_o[..., 2])\n    d2 = 2. 
* near / rays_o[..., 2]\n\n    rays_o = torch.stack([o0, o1, o2], -1)\n    rays_d = torch.stack([d0, d1, d2], -1)\n\n    return rays_o, rays_d\n\n# Hierarchical sampling (section 5.2)\ndef sample_pdf(bins, weights, N_samples, det=False, pytest=False):\n    device = weights.device\n    # Get pdf\n    weights = weights + 1e-5  # prevent nans\n    pdf = weights / torch.sum(weights, -1, keepdim=True)\n    cdf = torch.cumsum(pdf, -1)\n    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)  # (batch, len(bins))\n\n    # Take uniform samples\n    if det:\n        u = torch.linspace(0., 1., steps=N_samples, device=device)\n        u = u.expand(list(cdf.shape[:-1]) + [N_samples])\n    else:\n        u = torch.rand(list(cdf.shape[:-1]) + [N_samples], device=device)\n\n    # Pytest, overwrite u with numpy's fixed random numbers\n    if pytest:\n        np.random.seed(0)\n        new_shape = list(cdf.shape[:-1]) + [N_samples]\n        if det:\n            u = np.linspace(0., 1., N_samples)\n            u = np.broadcast_to(u, new_shape)\n        else:\n            u = np.random.rand(*new_shape)\n        u = torch.Tensor(u)\n\n    # Invert CDF\n    u = u.contiguous()\n    inds = searchsorted(cdf.detach(), u, right=True)\n    below = torch.max(torch.zeros_like(inds - 1), inds - 1)\n    above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)\n    inds_g = torch.stack([below, above], -1)  # (batch, N_samples, 2)\n\n    matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]\n    cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)\n    bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)\n\n    denom = (cdf_g[..., 1] - cdf_g[..., 0])\n    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)\n    t = (u - cdf_g[..., 0]) / denom\n    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])\n\n    return samples\n\n\ndef dda(rays_o, rays_d, bbox_3D):\n    inv_ray_d = 1.0 / (rays_d + 1e-6)\n    
t_min = (bbox_3D[:1] - rays_o) * inv_ray_d  # N_rays 3\n    t_max = (bbox_3D[1:] - rays_o) * inv_ray_d\n    t = torch.stack((t_min, t_max))  # 2 N_rays 3\n    t_min = torch.max(torch.min(t, dim=0)[0], dim=-1, keepdim=True)[0]\n    t_max = torch.min(torch.max(t, dim=0)[0], dim=-1, keepdim=True)[0]\n    return t_min, t_max\n\n\ndef ray_marcher(rays,\n                N_samples=64,\n                lindisp=False,\n                perturb=0,\n                bbox_3D=None):\n    \"\"\"\n    sample points along the rays\n    Inputs:\n        rays: ()\n\n    Returns:\n\n    \"\"\"\n\n    # Decompose the inputs\n    N_rays = rays.shape[0]\n    rays_o, rays_d = rays[:, 0:3], rays[:, 3:6]  # both (N_rays, 3)\n    near, far = rays[:, 6:7], rays[:, 7:8]  # both (N_rays, 1)\n\n    if bbox_3D is not None:\n        # cal aabb boundles\n        near, far = dda(rays_o, rays_d, bbox_3D)\n\n    # Sample depth points\n    z_steps = torch.linspace(0, 1, N_samples, device=rays.device)  # (N_samples)\n    if not lindisp:  # use linear sampling in depth space\n        z_vals = near * (1 - z_steps) + far * z_steps\n    else:  # use linear sampling in disparity space\n        z_vals = 1 / (1 / near * (1 - z_steps) + 1 / far * z_steps)\n\n    z_vals = z_vals.expand(N_rays, N_samples)\n\n    if perturb > 0:  # perturb sampling depths (z_vals)\n        z_vals_mid = 0.5 * (z_vals[:, :-1] + z_vals[:, 1:])  # (N_rays, N_samples-1) interval mid points\n        # get intervals between samples\n        upper = torch.cat([z_vals_mid, z_vals[:, -1:]], -1)\n        lower = torch.cat([z_vals[:, :1], z_vals_mid], -1)\n\n        perturb_rand = perturb * torch.rand(z_vals.shape, device=rays.device)\n        z_vals = lower + (upper - lower) * perturb_rand\n\n    xyz_coarse_sampled = rays_o.unsqueeze(1) + \\\n                         rays_d.unsqueeze(1) * z_vals.unsqueeze(2)  # (N_rays, N_samples, 3)\n\n    return xyz_coarse_sampled, rays_o, rays_d, z_vals\n\n\ndef read_pfm(filename):\n    file = 
open(filename, 'rb')\n    color = None\n    width = None\n    height = None\n    scale = None\n    endian = None\n\n    header = file.readline().decode('utf-8').rstrip()\n    if header == 'PF':\n        color = True\n    elif header == 'Pf':\n        color = False\n    else:\n        raise Exception('Not a PFM file.')\n\n    dim_match = re.match(r'^(\\d+)\\s(\\d+)\\s$', file.readline().decode('utf-8'))\n    if dim_match:\n        width, height = map(int, dim_match.groups())\n    else:\n        raise Exception('Malformed PFM header.')\n\n    scale = float(file.readline().rstrip())\n    if scale < 0:  # little-endian\n        endian = '<'\n        scale = -scale\n    else:\n        endian = '>'  # big-endian\n\n    data = np.fromfile(file, endian + 'f')\n    shape = (height, width, 3) if color else (height, width)\n\n    data = np.reshape(data, shape)\n    data = np.flipud(data)\n    file.close()\n    return data, scale\n\n\ndef ndc_bbox(all_rays):\n    near_min = torch.min(all_rays[...,:3].view(-1,3),dim=0)[0]\n    near_max = torch.max(all_rays[..., :3].view(-1, 3), dim=0)[0]\n    far_min = torch.min((all_rays[...,:3]+all_rays[...,3:6]).view(-1,3),dim=0)[0]\n    far_max = torch.max((all_rays[...,:3]+all_rays[...,3:6]).view(-1, 3), dim=0)[0]\n    print(f'===> ndc bbox near_min:{near_min} near_max:{near_max} far_min:{far_min} far_max:{far_max}')\n    return torch.stack((torch.minimum(near_min,far_min),torch.maximum(near_max,far_max)))\n\nimport torchvision\nnormalize_vgg = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], \n                                                 std=[0.229, 0.224, 0.225])\n\ndef denormalize_vgg(img):\n    im = img.clone()\n    im[:, 0, :, :] *= 0.229\n    im[:, 1, :, :] *= 0.224\n    im[:, 2, :, :] *= 0.225\n    im[:, 0, :, :] += 0.485\n    im[:, 1, :, :] += 0.456\n    im[:, 2, :, :] += 0.406\n    return im"
  },
  {
    "path": "dataLoader/styleLoader.py",
    "content": "from torch.utils.data import DataLoader\nfrom torchvision import datasets\nimport torchvision.transforms as T\n\n\ndef getDataLoader(dataset_path, batch_size, sampler, image_side_length=256, num_workers=2):\n    transform = T.Compose([\n                T.Resize(size=(image_side_length*2, image_side_length*2)),\n                T.RandomCrop(image_side_length),\n                T.ToTensor(),\n            ])\n\n    train_dataset = datasets.ImageFolder(dataset_path, transform=transform)\n    dataloader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler(len(train_dataset)), num_workers=num_workers)\n\n    return dataloader"
  },
  {
    "path": "dataLoader/tankstemple.py",
    "content": "import torch\nfrom torch.utils.data import Dataset\nfrom tqdm import tqdm\nimport os\nfrom PIL import Image\nfrom torchvision import transforms as T\nimport random\n\nfrom .ray_utils import *\n\n\ndef circle(radius=3.5, h=0.0, axis='z', t0=0, r=1):\n    if axis == 'z':\n        return lambda t: [radius * np.cos(r * t + t0), radius * np.sin(r * t + t0), h]\n    elif axis == 'y':\n        return lambda t: [radius * np.cos(r * t + t0), h, radius * np.sin(r * t + t0)]\n    else:\n        return lambda t: [h, radius * np.cos(r * t + t0), radius * np.sin(r * t + t0)]\n\n\ndef cross(x, y, axis=0):\n    T = torch if isinstance(x, torch.Tensor) else np\n    return T.cross(x, y, axis)\n\n\ndef normalize(x, axis=-1, order=2):\n    if isinstance(x, torch.Tensor):\n        l2 = x.norm(p=order, dim=axis, keepdim=True)\n        return x / (l2 + 1e-8), l2\n\n    else:\n        l2 = np.linalg.norm(x, order, axis)\n        l2 = np.expand_dims(l2, axis)\n        l2[l2 == 0] = 1\n        return x / l2,\n\n\ndef cat(x, axis=1):\n    if isinstance(x[0], torch.Tensor):\n        return torch.cat(x, dim=axis)\n    return np.concatenate(x, axis=axis)\n\n\ndef look_at_rotation(camera_position, at=None, up=None, inverse=False, cv=False):\n    \"\"\"\n    This function takes a vector 'camera_position' which specifies the location\n    of the camera in world coordinates and two vectors `at` and `up` which\n    indicate the position of the object and the up directions of the world\n    coordinate system respectively. 
The object is assumed to be centered at\n    the origin.\n    The output is a rotation matrix representing the transformation\n    from world coordinates -> view coordinates.\n    Input:\n        camera_position: 3\n        at: 1 x 3 or N x 3  (0, 0, 0) in default\n        up: 1 x 3 or N x 3  (0, 1, 0) in default\n    \"\"\"\n\n    if at is None:\n        at = torch.zeros_like(camera_position)\n    else:\n        at = torch.tensor(at).type_as(camera_position)\n    if up is None:\n        up = torch.zeros_like(camera_position)\n        up[2] = -1\n    else:\n        up = torch.tensor(up).type_as(camera_position)\n\n    z_axis = normalize(at - camera_position)[0]\n    x_axis = normalize(cross(up, z_axis))[0]\n    y_axis = normalize(cross(z_axis, x_axis))[0]\n\n    R = cat([x_axis[:, None], y_axis[:, None], z_axis[:, None]], axis=1)\n    return R\n\n\ndef gen_path(pos_gen, at=(0, 0, 0), up=(0, -1, 0), frames=180):\n    c2ws = []\n    for t in range(frames):\n        c2w = torch.eye(4)\n        cam_pos = torch.tensor(pos_gen(t * (360.0 / frames) / 180 * np.pi))\n        cam_rot = look_at_rotation(cam_pos, at=at, up=up, inverse=False, cv=True)\n        c2w[:3, 3], c2w[:3, :3] = cam_pos, cam_rot\n        c2ws.append(c2w)\n    return torch.stack(c2ws)\n\nclass TanksTempleDataset(Dataset):\n    \"\"\"NSVF Generic Dataset.\"\"\"\n    def __init__(self, datadir, split='train', downsample=4.0, wh=[1920,1080], is_stack=False):\n        self.root_dir = datadir\n        self.split = split\n        self.is_stack = is_stack\n        self.downsample = downsample\n        self.img_wh = (int(wh[0]/downsample),int(wh[1]/downsample))\n        self.define_transforms()\n\n        self.white_bg = True\n        self.near_far = [0.01,6.0]\n        self.scene_bbox = torch.from_numpy(np.loadtxt(f'{self.root_dir}/bbox.txt')).float()[:6].view(2,3)*1.2\n\n        self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])\n        self.read_meta()\n        
self.define_proj_mat()\n        \n        self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3)\n        self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)\n    \n    def bbox2corners(self):\n        corners = self.scene_bbox.unsqueeze(0).repeat(4,1,1)\n        for i in range(3):\n            corners[i,[0,1],i] = corners[i,[1,0],i] \n        return corners.view(-1,3)\n        \n        \n    def read_meta(self):\n\n        self.intrinsics = np.loadtxt(os.path.join(self.root_dir, \"intrinsics.txt\"))\n        self.intrinsics[:2] *= (np.array(self.img_wh)/np.array([1920,1080])).reshape(2,1)\n        pose_files = sorted(os.listdir(os.path.join(self.root_dir, 'pose')))\n        img_files  = sorted(os.listdir(os.path.join(self.root_dir, 'rgb')))\n\n        if self.split == 'train':\n            pose_files = [x for idx,x in enumerate(pose_files) if x.startswith('0_') and idx%3==0]\n            img_files = [x for idx,x in enumerate(img_files) if x.startswith('0_') and idx%3==0]\n        elif self.split == 'test':\n            # prefer the '2_' (test) files; fall back to the '1_' (val) files if none exist\n            test_pose_files = [x for idx,x in enumerate(pose_files) if x.startswith('2_') and idx%3==0]\n            test_img_files = [x for idx,x in enumerate(img_files) if x.startswith('2_') and idx%3==0]\n            if len(test_pose_files) == 0:\n                test_pose_files = [x for idx,x in enumerate(pose_files) if x.startswith('1_') and idx%3==0]\n                test_img_files = [x for idx,x in enumerate(img_files) if x.startswith('1_') and idx%3==0]\n            pose_files = test_pose_files\n            img_files = test_img_files\n\n        \n\n        # ray directions for all pixels, same for all images (same H, W, focal)\n        self.directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsics[0,0],self.intrinsics[1,1]], center=self.intrinsics[:2,2])  # (h, w, 3)\n        self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)\n\n        w, h = self.img_wh\n        \n    
    self.poses = []\n        self.all_rays = []\n        self.all_rgbs = []\n        self.all_masks = []\n\n        assert len(img_files) == len(pose_files)\n        for img_fname, pose_fname in tqdm(zip(img_files, pose_files), desc=f'Loading data {self.split} ({len(img_files)})'):\n            image_path = os.path.join(self.root_dir, 'rgb', img_fname)\n            img = Image.open(image_path)\n            if self.downsample!=1.0:\n                img = img.resize(self.img_wh, Image.LANCZOS)\n            img = self.transform(img)  # (3, h, w)\n            img = img.view(img.shape[0], -1).permute(1, 0)  # (h*w, 3) RGBA\n            mask =  torch.where(\n                img.sum(-1, keepdim=True) == 3.,\n                1.,\n                0.\n            )\n            self.all_masks.append(mask.reshape(h,w,1)) # (h, w, 1) A\n\n            if img.shape[-1]==4:\n                img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:])  # blend A to RGB\n            self.all_rgbs.append(img)\n            \n\n            c2w = np.loadtxt(os.path.join(self.root_dir, 'pose', pose_fname))# @ cam_trans\n            c2w = torch.FloatTensor(c2w)\n            self.poses.append(c2w)  # C2W\n            rays_o, rays_d = get_rays(self.directions, c2w)  # both (h*w, 3)\n            self.all_rays += [torch.cat([rays_o, rays_d], 1)]  # (h*w, 8)\n\n        self.poses = torch.stack(self.poses)\n\n        center = torch.mean(self.scene_bbox, dim=0)\n        radius = torch.norm(self.scene_bbox[1]-center)*1.2\n        up = torch.mean(self.poses[:, :3, 1], dim=0).tolist()\n        pos_gen = circle(radius=radius, h=-0.2*up[1], axis='y')\n        self.render_path = gen_path(pos_gen, up=up,frames=100)\n        self.render_path[:, :3, 3] += center\n\n\n        all_rays = self.all_rays\n        all_rgbs = self.all_rgbs\n        self.all_masks = torch.stack(self.all_masks) # (n_frames, h, w, 1)\n        self.all_rays = torch.cat(self.all_rays, 0) # (len(self.meta['frames])*h*w,6)\n        
self.all_rgbs = torch.cat(self.all_rgbs, 0) # (len(self.meta['frames])*h*w,3)\n\n        if self.is_stack:\n            self.all_rays_stack = torch.stack(all_rays, 0).reshape(-1,*self.img_wh[::-1], 6)  # (len(self.meta['frames]),h,w,6)\n            avg_pool = torch.nn.AvgPool2d(4, ceil_mode=True)\n            self.ds_all_rays_stack = avg_pool(self.all_rays_stack.permute(0,3,1,2)).permute(0,2,3,1) # (len(self.meta['frames]),h/4,w/4,6)\n            self.all_rgbs_stack = torch.stack(all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3)  # (len(self.meta['frames]),h,w,3)\n\n    @torch.no_grad()\n    def prepare_feature_data(self, encoder, chunk=4):\n        '''\n        Prepare feature maps as training data.\n        '''\n        assert self.is_stack, 'Dataset should contain original stacked training data!'\n        print('====> prepare_feature_data ...')\n\n        frames_num, h, w, _ = self.all_rgbs_stack.size()\n        features = []\n\n        for chunk_idx in tqdm(range(frames_num // chunk + int(frames_num % chunk > 0))):\n            rgbs_chunk = self.all_rgbs_stack[chunk_idx*chunk : (chunk_idx+1)*chunk].cuda()\n            features_chunk = encoder(normalize_vgg(rgbs_chunk.permute(0,3,1,2))).relu3_1\n            # resize to the size of rgb map so that rays can match\n            features_chunk = T.functional.resize(features_chunk, size=(h,w), \n                                                 interpolation=T.InterpolationMode.BILINEAR)\n            features.append(features_chunk.detach().cpu().requires_grad_(False))\n\n        self.all_features_stack = torch.cat(features).permute(0,2,3,1) # (len(self.meta['frames]),h,w,256)\n        self.all_features = self.all_features_stack.reshape(-1, 256)\n        print('prepare_feature_data Done!')\n\n\n    def define_transforms(self):\n        self.transform = T.ToTensor()\n        \n    def define_proj_mat(self):\n        self.proj_mat = torch.from_numpy(self.intrinsics[:3,:3]).unsqueeze(0).float() @ 
torch.inverse(self.poses)[:,:3]\n\n    def world2ndc(self, points):\n        device = points.device\n        return (points - self.center.to(device)) / self.radius.to(device)\n        \n    def __len__(self):\n        if self.split == 'train':\n            return len(self.all_rays)\n        return len(self.all_rgbs)\n\n    def __getitem__(self, idx):\n\n        if self.split == 'train':  # use data in the buffers\n            sample = {'rays': self.all_rays[idx],\n                      'rgbs': self.all_rgbs[idx]}\n\n        else:  # create data for each image separately\n\n            img = self.all_rgbs[idx]\n            rays = self.all_rays[idx]\n\n            sample = {'rays': rays,\n                      'rgbs': img}\n        return sample"
  },
  {
    "path": "dataLoader/your_own_data.py",
    "content": "import torch,cv2\nfrom torch.utils.data import Dataset\nimport json\nfrom tqdm import tqdm\nimport os\nfrom PIL import Image\nfrom torchvision import transforms as T\n\n\nfrom .ray_utils import *\n\n\nclass YourOwnDataset(Dataset):\n    def __init__(self, datadir, split='train', downsample=1.0, is_stack=False, N_vis=-1):\n\n        self.N_vis = N_vis\n        self.root_dir = datadir\n        self.split = split\n        self.is_stack = is_stack\n        self.downsample = downsample\n        self.define_transforms()\n\n        self.scene_bbox = torch.tensor([[-1.5, -1.5, -1.5], [1.5, 1.5, 1.5]])\n        self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])\n        self.read_meta()\n        self.define_proj_mat()\n\n        self.white_bg = True\n        self.near_far = [0.1,100.0]\n        \n        self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3)\n        self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)\n        self.downsample=downsample\n\n    def read_depth(self, filename):\n        depth = np.array(read_pfm(filename)[0], dtype=np.float32)  # (800, 800)\n        return depth\n    \n    def read_meta(self):\n\n        with open(os.path.join(self.root_dir, f\"transforms_{self.split}.json\"), 'r') as f:\n            self.meta = json.load(f)\n\n        w, h = int(self.meta['w']/self.downsample), int(self.meta['h']/self.downsample)\n        self.img_wh = [w,h]\n        self.focal_x = 0.5 * w / np.tan(0.5 * self.meta['camera_angle_x'])  # original focal length\n        self.focal_y = 0.5 * h / np.tan(0.5 * self.meta['camera_angle_y'])  # original focal length\n        self.cx, self.cy = self.meta['cx'],self.meta['cy']\n\n\n        # ray directions for all pixels, same for all images (same H, W, focal)\n        self.directions = get_ray_directions(h, w, [self.focal_x,self.focal_y], center=[self.cx, self.cy])  # (h, w, 3)\n        self.directions = self.directions / 
torch.norm(self.directions, dim=-1, keepdim=True)\n        self.intrinsics = torch.tensor([[self.focal_x,0,self.cx],[0,self.focal_y,self.cy],[0,0,1]]).float()\n\n        self.image_paths = []\n        self.poses = []\n        self.all_rays = []\n        self.all_rgbs = []\n        self.all_masks = []\n        self.all_depth = []\n\n\n        img_eval_interval = 1 if self.N_vis < 0 else len(self.meta['frames']) // self.N_vis\n        idxs = list(range(0, len(self.meta['frames']), img_eval_interval))\n        for i in tqdm(idxs, desc=f'Loading data {self.split} ({len(idxs)})'):#img_list:#\n\n            frame = self.meta['frames'][i]\n            pose = np.array(frame['transform_matrix']) @ self.blender2opencv\n            c2w = torch.FloatTensor(pose)\n            self.poses += [c2w]\n\n            image_path = os.path.join(self.root_dir, f\"{frame['file_path']}.png\")\n            self.image_paths += [image_path]\n            img = Image.open(image_path)\n            \n            if self.downsample!=1.0:\n                img = img.resize(self.img_wh, Image.LANCZOS)\n            img = self.transform(img)  # (4, h, w)\n            img = img.view(-1, w*h).permute(1, 0)  # (h*w, 4) RGBA\n            if img.shape[-1]==4:\n                img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:])  # blend A to RGB\n            self.all_rgbs += [img]\n\n\n            rays_o, rays_d = get_rays(self.directions, c2w)  # both (h*w, 3)\n            self.all_rays += [torch.cat([rays_o, rays_d], 1)]  # (h*w, 6)\n\n\n        self.poses = torch.stack(self.poses)\n        if not self.is_stack:\n            self.all_rays = torch.cat(self.all_rays, 0)  # (len(self.meta['frames])*h*w, 3)\n            self.all_rgbs = torch.cat(self.all_rgbs, 0)  # (len(self.meta['frames])*h*w, 3)\n\n#             self.all_depth = torch.cat(self.all_depth, 0)  # (len(self.meta['frames])*h*w, 3)\n        else:\n            self.all_rays = torch.stack(self.all_rays, 0)  # (len(self.meta['frames]),h*w, 3)\n       
     self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3)  # (len(self.meta['frames]),h,w,3)\n            # self.all_masks = torch.stack(self.all_masks, 0).reshape(-1,*self.img_wh[::-1])  # (len(self.meta['frames]),h,w,3)\n\n\n    def define_transforms(self):\n        self.transform = T.ToTensor()\n        \n    def define_proj_mat(self):\n        self.proj_mat = self.intrinsics.unsqueeze(0) @ torch.inverse(self.poses)[:,:3]\n\n    def world2ndc(self,points,lindisp=None):\n        device = points.device\n        return (points - self.center.to(device)) / self.radius.to(device)\n        \n    def __len__(self):\n        return len(self.all_rgbs)\n\n    def __getitem__(self, idx):\n\n        if self.split == 'train':  # use data in the buffers\n            sample = {'rays': self.all_rays[idx],\n                      'rgbs': self.all_rgbs[idx]}\n\n        else:  # create data for each image separately\n\n            img = self.all_rgbs[idx]\n            rays = self.all_rays[idx]\n            # self.all_masks is never populated in read_meta, so indexing it here would raise IndexError; no mask is returned\n\n            sample = {'rays': rays,\n                      'rgbs': img}\n        return sample\n"
  },
  {
    "path": "extra/auto_run_paramsets.py",
    "content": "import os\nimport threading, queue\nimport numpy as np\nimport time\n\n\ndef getFolderLocker(logFolder):\n    while True:\n        try:\n            os.makedirs(logFolder+\"/lockFolder\")\n            break\n        except: \n            time.sleep(0.01)\n\ndef releaseFolderLocker(logFolder):\n    os.removedirs(logFolder+\"/lockFolder\")\n\ndef getStopFolder(logFolder):\n    return os.path.isdir(logFolder+\"/stopFolder\")\n\n\ndef get_param_str(key, val):\n    if key == 'data_name':\n        return f'--datadir {datafolder}/{val} '\n    else:\n        return f'--{key} {val} '\n\ndef get_param_list(param_dict):\n    param_keys = list(param_dict.keys())\n    param_modes = len(param_keys)\n    param_nums = [len(param_dict[key]) for key in param_keys]\n    \n    param_ids = np.zeros(param_nums+[param_modes], dtype=int)\n    for i in range(param_modes):\n        broad_tuple = np.ones(param_modes, dtype=int).tolist()\n        broad_tuple[i] = param_nums[i]\n        broad_tuple = tuple(broad_tuple)\n        print(broad_tuple)\n        param_ids[...,i] = np.arange(param_nums[i]).reshape(broad_tuple)\n    param_ids = param_ids.reshape(-1, param_modes)\n    # print(param_ids)\n    print(len(param_ids))\n    \n    params = []\n    expnames = []\n    for i in range(param_ids.shape[0]):\n        one = \"\"\n        name = \"\"\n        param_id = param_ids[i]\n        for j in range(param_modes):\n            key = param_keys[j]\n            val = param_dict[key][param_id[j]]\n            if type(key) is tuple:\n                assert len(key) == len(val)\n                for k in range(len(key)):\n                    one += get_param_str(key[k], val[k])\n                    name += f'{val[k]},'\n                name=name[:-1]+'-'\n            else:\n                one += get_param_str(key, val)\n                name += f'{val}-'\n        params.append(one)\n        name=name.replace(' ','')\n        print(name)\n        expnames.append(name[:-1])\n    # 
print(params)\n    return params, expnames\n\n\n\n\n\n\n\nif __name__ == '__main__':\n    \n\n\n    # nerf\n    expFolder = \"nerf/\"\n    # parameters to iterate, use tuple to couple multiple parameters\n    datafolder = '/mnt/new_disk_2/anpei/Dataset/nerf_synthetic/'\n    param_dict = {\n        'data_name': ['ship', 'mic', 'chair', 'lego', 'drums', 'ficus', 'hotdog', 'materials'],\n        'data_dim_color': [13, 27, 54]\n    }\n\n    # n_iters = 30000\n    # for data_name in ['Robot']:#'Bike','Lifestyle','Palace','Robot','Spaceship','Steamtrain','Toad','Wineholder'\n    #     cmd = f'CUDA_VISIBLE_DEVICES={cuda}  python train.py ' \\\n    #           f'--dataset_name nsvf --datadir /mnt/new_disk_2/anpei/Dataset/TeRF/Synthetic_NSVF/{data_name} '\\\n    #           f'--expname {data_name} --batch_size {batch_size} ' \\\n    #           f'--n_iters {n_iters}  ' \\\n    #           f'--N_voxel_init {128**3} --N_voxel_final {300**3} '\\\n    #           f'--N_vis {5}  ' \\\n    #           f'--n_lamb_sigma \"[16,16,16]\" --n_lamb_sh \"[48,48,48]\" ' \\\n    #           f'--upsamp_list \"[2000, 3000, 4000, 5500,7000]\" --update_AlphaMask_list \"[3000,4000]\" ' \\\n    #           f'--shadingMode MLP_Fea --fea2denseAct softplus  --view_pe {2} --fea_pe {2} ' \\\n    #           f'--L1_weight_inital {8e-5} --L1_weight_rest {4e-5} --rm_weight_mask_thre {1e-4} --add_timestamp 0 ' \\\n    #           f'--render_test 1 '\n    #     print(cmd)\n    #     os.system(cmd)\n\n    # nsvf\n    # expFolder = \"nsvf_0227/\"\n    # datafolder = '/mnt/new_disk_2/anpei/Dataset/TeRF/Synthetic_NSVF/'\n    # param_dict = {\n    #             'data_name': ['Robot','Steamtrain','Bike','Lifestyle','Palace','Spaceship','Toad','Wineholder'],#'Bike','Lifestyle','Palace','Robot','Spaceship','Steamtrain','Toad','Wineholder'\n    #             'shadingMode': ['SH'],\n    #             ('n_lamb_sigma', 'n_lamb_sh'): [ (\"[8,8,8]\", \"[8,8,8]\")],\n    #             ('view_pe', 'fea_pe', 
'featureC','fea2denseAct','N_voxel_init') : [(2, 2, 128, 'softplus',128**3)],\n    #             ('L1_weight_inital', 'L1_weight_rest', 'rm_weight_mask_thre'):[(4e-5, 4e-5, 1e-4)],\n    #             ('n_iters','N_voxel_final'): [(30000,300**3)],\n    #             ('dataset_name','N_vis','render_test') : [(\"nsvf\",5,1)],\n    #             ('upsamp_list','update_AlphaMask_list'): [(\"[2000,3000,4000,5500,7000]\",\"[3000,4000]\")]\n    #\n    #     }\n\n    # tankstemple\n    # expFolder = \"tankstemple_0304/\"\n    # datafolder = '/mnt/new_disk_2/anpei/Dataset/TeRF/TanksAndTemple/'\n    # param_dict = {\n    #             'data_name': ['Truck','Barn','Caterpillar','Family','Ignatius'],\n    #             'shadingMode': ['MLP_Fea'],\n    #             ('n_lamb_sigma', 'n_lamb_sh'): [(\"[16,16,16]\", \"[48,48,48]\")],\n    #             ('view_pe', 'fea_pe','fea2denseAct','N_voxel_init','render_test') : [(2, 2, 'softplus',128**3,1)],\n    #             ('TV_weight_density','TV_weight_app'):[(0.1,0.01)],\n    #             # ('L1_weight_inital', 'L1_weight_rest', 'rm_weight_mask_thre'): [(4e-5, 4e-5, 1e-4)],\n    #             ('n_iters','N_voxel_final'): [(15000,300**3)],\n    #             ('dataset_name','N_vis') : [(\"tankstemple\",5)],\n    #             ('upsamp_list','update_AlphaMask_list'): [(\"[2000,3000,4000,5500,7000]\",\"[2000,4000]\")]\n    #     }\n\n    # llff\n    # expFolder = \"real_iconic/\"\n    # datafolder = '/mnt/new_disk_2/anpei/Dataset/MVSNeRF/real_iconic/'\n    # List = os.listdir(datafolder)\n    # param_dict = {\n    #             'data_name': List,\n    #             ('shadingMode', 'view_pe', 'fea_pe','fea2denseAct', 'nSamples','N_voxel_init') : [('MLP_Fea', 0, 0, 'relu',512,128**3)],\n    #             ('n_lamb_sigma', 'n_lamb_sh') : [(\"[16,4,4]\", \"[48,12,12]\")],\n    #             ('TV_weight_density', 'TV_weight_app'):[(1.0,1.0)],\n    #             ('n_iters','N_voxel_final'): [(25000,640**3)],\n    #             
('dataset_name','downsample_train','ndc_ray','N_vis','render_path') : [(\"llff\",4.0, 1,-1,1)],\n    #             ('upsamp_list','update_AlphaMask_list'): [(\"[2000,3000,4000,5500,7000]\",\"[2500]\")],\n    #     }\n\n    # expFolder = \"llff/\"\n    # datafolder = '/mnt/new_disk_2/anpei/Dataset/MVSNeRF/nerf_llff_data'\n    # param_dict = {\n    #             'data_name': ['fern', 'flower', 'room', 'leaves', 'horns', 'trex', 'fortress', 'orchids'],#'fern', 'flower', 'room', 'leaves', 'horns', 'trex', 'fortress', 'orchids'\n    #             ('n_lamb_sigma', 'n_lamb_sh'): [(\"[16,4,4]\", \"[48,12,12]\")],\n    #             ('shadingMode', 'view_pe', 'fea_pe', 'featureC','fea2denseAct', 'nSamples','N_voxel_init') : [('MLP_Fea', 0, 0, 128, 'relu',512,128**3),('SH', 0, 0, 128, 'relu',512,128**3)],\n    #             ('TV_weight_density', 'TV_weight_app'):[(1.0,1.0)],\n    #             ('n_iters','N_voxel_final'): [(25000,640**3)],\n    #             ('dataset_name','downsample_train','ndc_ray','N_vis','render_test','render_path') : [(\"llff\",4.0, 1,-1,1,1)],\n    #             ('upsamp_list','update_AlphaMask_list'): [(\"[2000,3000,4000,5500,7000]\",\"[2500]\")],\n    #     }\n\n    #setting available gpus\n    gpus_que = queue.Queue(3)\n    for i in [1,2,3]:\n        gpus_que.put(i)\n    \n    os.makedirs(f\"log/{expFolder}\", exist_ok=True)\n\n    def run_program(gpu, expname, param):\n        cmd = f'CUDA_VISIBLE_DEVICES={gpu}  python train.py ' \\\n            f'--expname {expname} --basedir ./log/{expFolder} --config configs/lego.txt ' \\\n            f'{param}' \\\n            f'> \"log/{expFolder}{expname}/{expname}.txt\"'\n        print(cmd)\n        os.system(cmd)\n        gpus_que.put(gpu)\n\n    params, expnames = get_param_list(param_dict)\n\n    \n    logFolder=f\"log/{expFolder}\"\n    os.makedirs(logFolder, exist_ok=True)\n\n    ths = []\n    for i in range(len(params)):\n\n        if getStopFolder(logFolder):\n            break\n\n\n        
targetFolder = f\"log/{expFolder}{expnames[i]}\"\n        gpu = gpus_que.get()\n        getFolderLocker(logFolder)\n        if os.path.isdir(targetFolder):\n            releaseFolderLocker(logFolder)\n            gpus_que.put(gpu)\n            continue\n        else:\n            os.makedirs(targetFolder, exist_ok=True)\n            print(\"making\",targetFolder, \"running\",expnames[i], params[i])\n            releaseFolderLocker(logFolder)\n\n\n        t = threading.Thread(target=run_program, args=(gpu, expnames[i], params[i]), daemon=True)\n        t.start()\n        ths.append(t)\n    \n    for th in ths:\n        th.join()"
  },
  {
    "path": "extra/compute_metrics.py",
    "content": "import os, math\nimport numpy as np\nimport scipy.signal\nfrom typing import List, Optional\nfrom PIL import Image\nimport os\nimport torch\nimport configargparse\n\n__LPIPS__ = {}\ndef init_lpips(net_name, device):\n    assert net_name in ['alex', 'vgg']\n    import lpips\n    print(f'init_lpips: lpips_{net_name}')\n    return lpips.LPIPS(net=net_name, version='0.1').eval().to(device)\n\ndef rgb_lpips(np_gt, np_im, net_name, device):\n    if net_name not in __LPIPS__:\n        __LPIPS__[net_name] = init_lpips(net_name, device)\n    gt = torch.from_numpy(np_gt).permute([2, 0, 1]).contiguous().to(device)\n    im = torch.from_numpy(np_im).permute([2, 0, 1]).contiguous().to(device)\n    return __LPIPS__[net_name](gt, im, normalize=True).item()\n\n\ndef findItem(items, target):\n    for one in items:\n        if one[:len(target)]==target:\n            return one\n    return None\n\n\n''' Evaluation metrics (ssim, lpips)\n'''\ndef rgb_ssim(img0, img1, max_val,\n             filter_size=11,\n             filter_sigma=1.5,\n             k1=0.01,\n             k2=0.03,\n             return_map=False):\n    # Modified from https://github.com/google/mipnerf/blob/16e73dfdb52044dcceb47cda5243a686391a6e0f/internal/math.py#L58\n    assert len(img0.shape) == 3\n    assert img0.shape[-1] == 3\n    assert img0.shape == img1.shape\n\n    # Construct a 1D Gaussian blur filter.\n    hw = filter_size // 2\n    shift = (2 * hw - filter_size + 1) / 2\n    f_i = ((np.arange(filter_size) - hw + shift) / filter_sigma)**2\n    filt = np.exp(-0.5 * f_i)\n    filt /= np.sum(filt)\n\n    # Blur in x and y (faster than the 2D convolution).\n    def convolve2d(z, f):\n        return scipy.signal.convolve2d(z, f, mode='valid')\n\n    filt_fn = lambda z: np.stack([\n        convolve2d(convolve2d(z[...,i], filt[:, None]), filt[None, :])\n        for i in range(z.shape[-1])], -1)\n    mu0 = filt_fn(img0)\n    mu1 = filt_fn(img1)\n    mu00 = mu0 * mu0\n    mu11 = mu1 * mu1\n    mu01 = 
mu0 * mu1\n    sigma00 = filt_fn(img0**2) - mu00\n    sigma11 = filt_fn(img1**2) - mu11\n    sigma01 = filt_fn(img0 * img1) - mu01\n\n    # Clip the variances and covariances to valid values.\n    # Variance must be non-negative:\n    sigma00 = np.maximum(0., sigma00)\n    sigma11 = np.maximum(0., sigma11)\n    sigma01 = np.sign(sigma01) * np.minimum(\n        np.sqrt(sigma00 * sigma11), np.abs(sigma01))\n    c1 = (k1 * max_val)**2\n    c2 = (k2 * max_val)**2\n    numer = (2 * mu01 + c1) * (2 * sigma01 + c2)\n    denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2)\n    ssim_map = numer / denom\n    ssim = np.mean(ssim_map)\n    return ssim_map if return_map else ssim\n\n\nif __name__ == '__main__':\n\n    parser = configargparse.ArgumentParser()\n    parser.add_argument(\"--exp\", type=str, help=\"folder of exps\")\n    parser.add_argument(\"--paramStr\", type=str, help=\"str of params\")\n    args = parser.parse_args()\n\n\n    # datanames = ['drums','hotdog','materials','ficus','lego','mic','ship','chair'] #['ship']#\n    # gtFolder = \"/home/code-base/user_space/codes/nerf/data/nerf_synthetic\"\n    # expFolder = \"/home/code-base/user_space/codes/TensoRF/log/\"+args.exp\n\n    # datanames = ['room','fortress', 'flower','orchids','leaves','horns','trex','fern'] #['ship']#\n    # gtFolder = \"/mnt/new_disk_2/anpei/Dataset/MVSNeRF/nerf_llff_data/\"\n    # expFolder = \"/mnt/new_disk_2/anpei/code/TensoRF/log/\"+args.exp\n    paramStr = args.paramStr\n    fileNum = 200\n\n\n    expitems = os.listdir(expFolder)\n    finalFolder = f'{expFolder}/finals/{paramStr}'\n    outFile = f'{finalFolder}/{paramStr}_metrics.txt'\n    os.makedirs(finalFolder, exist_ok=True)\n\n    expitems.sort(reverse=True)\n\n\n    with open(outFile, 'w') as f:\n        all_psnr = []\n        all_ssim = []\n        all_alex = []\n        all_vgg = []\n        for dataname in datanames:\n            \n\n            gtstr = gtFolder+\"/\"+dataname+\"/test/r_%d.png\"\n            expname = 
findItem(expitems, f'{paramStr}-{dataname}')\n            print(\"expname: \", expname)\n            if expname is None:\n                print(\"no \",dataname, \"exists\")\n                continue\n            resultstr = expFolder+\"/\"+expname+\"/imgs_test_all/\"+ dataname+\"-\"+paramStr+ \"_%03d.png\"\n            metric_file = f'{expFolder}/{expname}/imgs_test_all/{paramStr}-{dataname}_mean.txt'\n            video_file = f'{expFolder}/{expname}/imgs_test_all/{paramStr}-{dataname}_video.mp4'\n            \n            exist_metric=False\n            if os.path.isfile(metric_file):\n                metrics = np.loadtxt(metric_file)\n                print(metrics, metrics.tolist())\n                if metrics.size == 4:\n                    psnr, ssim, l_a, l_v = metrics.tolist()\n                    exist_metric = True\n                    os.system(f\"cp {video_file} {finalFolder}/\")\n\n            if not exist_metric:\n                psnrs = []\n                ssims = []\n                l_alex = []\n                l_vgg = []\n                for i in range(fileNum):\n                    gt = np.asarray(Image.open(gtstr%i),dtype=np.float32) / 255.0\n                    gtmask = gt[...,[3]]\n                    gt = gt[...,:3]\n                    gt = gt*gtmask + (1-gtmask)\n                    img = np.asarray(Image.open(resultstr%i),dtype=np.float32)[...,:3]  / 255.0\n                    # print(gt[0,0],img[0,0],gt.shape, img.shape, gt.max(), img.max())\n\n\n                    psnr = -10. 
* np.log10(np.mean(np.square(img - gt)))\n                    ssim = rgb_ssim(img, gt, 1)\n                    lpips_alex = rgb_lpips(gt, img, 'alex','cuda')\n                    lpips_vgg = rgb_lpips(gt, img, 'vgg','cuda')\n\n                    print(i, psnr, ssim, lpips_alex, lpips_vgg)\n                    psnrs.append(psnr)\n                    ssims.append(ssim)\n                    l_alex.append(lpips_alex)\n                    l_vgg.append(lpips_vgg)\n                    psnr = np.mean(np.array(psnrs))\n                    ssim = np.mean(np.array(ssims))\n                    l_a  = np.mean(np.array(l_alex))\n                    l_v  = np.mean(np.array(l_vgg))\n\n            rS=f'{dataname} : psnr {psnr} ssim {ssim}  l_a {l_a} l_v {l_v}'\n            print(rS)\n            f.write(rS+\"\\n\")\n\n            all_psnr.append(psnr)\n            all_ssim.append(ssim)\n            all_alex.append(l_a)\n            all_vgg.append(l_v)\n        \n        psnr = np.mean(np.array(all_psnr))\n        ssim = np.mean(np.array(all_ssim))\n        l_a  = np.mean(np.array(all_alex))\n        l_v  = np.mean(np.array(all_vgg))\n\n        rS=f'mean : psnr {psnr} ssim {ssim}  l_a {l_a} l_v {l_v}'\n        print(rS)\n        f.write(rS+\"\\n\")"
  },
  {
    "path": "models/VGG.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom collections import namedtuple\nimport torchvision.models as models\n\n# pytorch pretrained vgg\nclass Encoder(nn.Module):\n    def __init__(self):\n        super().__init__()\n\n        #pretrained vgg19\n        vgg19 = models.vgg19(weights='DEFAULT').features\n\n        self.relu1_1 = vgg19[:2]\n        self.relu2_1 = vgg19[2:7]\n        self.relu3_1 = vgg19[7:12]\n        self.relu4_1 = vgg19[12:21]\n\n        #fix parameters\n        self.requires_grad_(False)\n\n    def forward(self, x):\n        _output = namedtuple('output', ['relu1_1', 'relu2_1', 'relu3_1', 'relu4_1'])\n        relu1_1 = self.relu1_1(x)\n        relu2_1 = self.relu2_1(relu1_1)\n        relu3_1 = self.relu3_1(relu2_1)\n        relu4_1 = self.relu4_1(relu3_1)\n        output = _output(relu1_1, relu2_1, relu3_1, relu4_1)\n\n        return output\n\n\nclass Decoder(nn.Module): \n    \"\"\"\n    starting from relu 4_1\n    \"\"\"\n    def __init__(self, ckpt_path=None):\n        super().__init__()\n        \n        self.layers = nn.Sequential(\n            # nn.Conv2d(512, 256, 3, padding=1, padding_mode='reflect'),\n            # nn.ReLU(),\n            # nn.Upsample(scale_factor=2, mode='nearest'), # relu4-1\n            nn.Conv2d(256, 256, 3, padding=1, padding_mode='reflect'),\n            nn.ReLU(), # relu3-4\n            nn.Conv2d(256, 256, 3, padding=1, padding_mode='reflect'),\n            nn.ReLU(), # relu3-3\n            nn.Conv2d(256, 256, 3, padding=1, padding_mode='reflect'),\n            nn.ReLU(), # relu3-2\n            nn.Conv2d(256, 128, 3, padding=1, padding_mode='reflect'),\n            nn.ReLU(), \n            nn.Upsample(scale_factor=2, mode='nearest'),# relu3-1\n            nn.Conv2d(128, 128, 3, padding=1, padding_mode='reflect'),\n            nn.ReLU(), # relu2-2\n            nn.Conv2d(128, 64, 3, padding=1, padding_mode='reflect'),\n            nn.ReLU(), \n            
nn.Upsample(scale_factor=2, mode='nearest'),# relu2-1\n            nn.Conv2d(64, 64, 3, padding=1, padding_mode='reflect'),\n            nn.ReLU(), # relu1-2\n            nn.Conv2d(64, 3, 3, padding=1, padding_mode='reflect'),\n        )\n\n        if ckpt_path is not None:\n          self.load_state_dict(torch.load(ckpt_path))\n\n    def forward(self, x):\n        return self.layers(x)\n\n\n### high-res unet feature map decoder\n\n\nclass DownBlock(nn.Module):\n\n    def __init__(self, in_dim, out_dim, down='conv'):\n        super(DownBlock, self).__init__()\n\n        if down == 'conv':\n            self.down_conv = nn.Sequential(\n                nn.Conv2d(in_dim, out_dim, 3, 2, 1),\n                nn.LeakyReLU(),\n                nn.Conv2d(out_dim, out_dim, 3, 1, 1),\n                nn.LeakyReLU(),\n            )\n        elif down == 'mean':\n            self.down_conv = nn.AvgPool2d(2)\n        else:\n            raise NotImplementedError(\n                '[ERROR] invalid downsampling operator: {:s}'.format(down)\n            )\n\n    def forward(self, x):\n        x = self.down_conv(x)\n        return x\n\n\nclass UpBlock(nn.Module):\n\n    def __init__(self, in_dim, out_dim, skip_dim=None, up='nearest'):\n        super(UpBlock, self).__init__()\n\n        if up == 'conv':\n            self.up_conv = nn.Sequential(\n                nn.ConvTranspose2d(in_dim, out_dim, 3, 2, 1, 1),\n                nn.ReLU(),\n            )\n        else:\n            assert up in ('bilinear', 'nearest'), \\\n                '[ERROR] invalid upsampling mode: {:s}'.format(up)\n            self.up_conv = nn.Sequential(\n                nn.Upsample(scale_factor=2, mode=up),\n                nn.Conv2d(in_dim, out_dim, 3, 1, 1),\n                nn.ReLU(),\n            )\n        \n        in_dim = out_dim\n        if skip_dim is not None:\n            in_dim += skip_dim\n        self.conv = nn.Sequential(\n            nn.Conv2d(in_dim, out_dim, 3, 1, 1),\n            
nn.ReLU(),\n        )\n\n    def _pad(self, x, y):\n        dh = y.size(-2) - x.size(-2)\n        dw = y.size(-1) - x.size(-1)\n        if dh == 0 and dw == 0:\n            return x\n        if dh < 0:\n            x = x[..., :dh, :]\n        if dw < 0:\n            x = x[..., :, :dw]\n        if dh > 0 or dw > 0:\n            x = F.pad(\n                x, \n                pad=(dw // 2, dw - dw // 2, dh // 2, dh - dh // 2), \n                mode='reflect'\n            )\n        return x\n\n    def forward(self, x, skip=None):\n        x = self.up_conv(x)\n        if skip is not None:\n            x = torch.cat([self._pad(x, skip), skip], 1)\n        x = self.conv(x)\n        return x\n\n\nclass UNetDecoder(nn.Module):\n\n    def __init__(self, in_dim=256):\n        super(UNetDecoder, self).__init__()\n\n        self.down_layers = nn.ModuleList()\n        self.skip_convs = nn.ModuleList()\n        self.up_layers = nn.ModuleList()\n\n        in_dim = in_dim\n        self.n_levels = 2\n        self.up = 1\n\n        for i in range(self.n_levels):\n            self.down_layers.append(\n                DownBlock(\n                    in_dim, in_dim,\n                )\n            )\n            out_dim = in_dim // 2 ** (self.n_levels - i)\n            self.skip_convs.append(nn.Conv2d(in_dim, out_dim, 1))\n            self.up_layers.append(\n                UpBlock(\n                    out_dim * 2, out_dim, out_dim,\n                )\n            )\n\n        out_dim = in_dim // 2 ** self.n_levels\n        self.out_conv = nn.Sequential(\n            nn.Conv2d(out_dim, out_dim, 3, 1, 1),\n            nn.ReLU(),\n            nn.Conv2d(out_dim, 3, 1, 1),\n        )\n\n    def forward(self, feats):\n        skips = []\n        for i in range(self.n_levels):\n            skips.append(self.skip_convs[i](feats))\n            feats = self.down_layers[i](feats)\n        for i in range(self.n_levels - 1, -1, -1):\n            feats = self.up_layers[i](feats, skips[i])\n     
   rgb = self.out_conv(feats)\n        return rgb\n\n\n### high-res feature map decoder\n\nclass PlainDecoder(nn.Module):\n    def __init__(self) -> None:\n        super().__init__()\n\n        self.layers = nn.Sequential(\n            nn.Conv2d(256, 256, 3, padding=1, padding_mode='reflect'),\n            nn.ReLU(), # relu3-4\n            nn.Conv2d(256, 256, 3, padding=1, padding_mode='reflect'),\n            nn.ReLU(), # relu3-3\n            nn.Conv2d(256, 256, 3, padding=1, padding_mode='reflect'),\n            nn.ReLU(), # relu3-2\n            nn.Conv2d(256, 128, 3, padding=1, padding_mode='reflect'),\n            nn.ReLU(), \n            nn.Conv2d(128, 128, 3, padding=1, padding_mode='reflect'),\n            nn.ReLU(), # relu2-2\n            nn.Conv2d(128, 64, 3, padding=1, padding_mode='reflect'),\n            nn.ReLU(), \n            nn.Conv2d(64, 64, 3, padding=1, padding_mode='reflect'),\n            nn.ReLU(), # relu1-2\n            nn.Conv2d(64, 3, 3, padding=1, padding_mode='reflect'),\n        )\n\n    def forward(self, x):\n        return self.layers(x)"
  },
  {
    "path": "models/__init__.py",
    "content": ""
  },
  {
    "path": "models/sh.py",
    "content": "import torch\n\n################## sh function ##################\nC0 = 0.28209479177387814\nC1 = 0.4886025119029199\nC2 = [\n    1.0925484305920792,\n    -1.0925484305920792,\n    0.31539156525252005,\n    -1.0925484305920792,\n    0.5462742152960396\n]\nC3 = [\n    -0.5900435899266435,\n    2.890611442640554,\n    -0.4570457994644658,\n    0.3731763325901154,\n    -0.4570457994644658,\n    1.445305721320277,\n    -0.5900435899266435\n]\nC4 = [\n    2.5033429417967046,\n    -1.7701307697799304,\n    0.9461746957575601,\n    -0.6690465435572892,\n    0.10578554691520431,\n    -0.6690465435572892,\n    0.47308734787878004,\n    -1.7701307697799304,\n    0.6258357354491761,\n]\n\ndef eval_sh(deg, sh, dirs):\n    \"\"\"\n    Evaluate spherical harmonics at unit directions\n    using hardcoded SH polynomials.\n    Works with torch/np/jnp.\n    ... Can be 0 or more batch dimensions.\n    :param deg: int SH max degree. Currently, 0-4 supported\n    :param sh: torch.Tensor SH coeffs (..., C, (max degree + 1) ** 2)\n    :param dirs: torch.Tensor unit directions (..., 3)\n    :return: (..., C)\n    \"\"\"\n    assert deg <= 4 and deg >= 0\n    assert (deg + 1) ** 2 == sh.shape[-1]\n    C = sh.shape[-2]\n\n    result = C0 * sh[..., 0]\n    if deg > 0:\n        x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]\n        result = (result -\n                C1 * y * sh[..., 1] +\n                C1 * z * sh[..., 2] -\n                C1 * x * sh[..., 3])\n        if deg > 1:\n            xx, yy, zz = x * x, y * y, z * z\n            xy, yz, xz = x * y, y * z, x * z\n            result = (result +\n                    C2[0] * xy * sh[..., 4] +\n                    C2[1] * yz * sh[..., 5] +\n                    C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +\n                    C2[3] * xz * sh[..., 7] +\n                    C2[4] * (xx - yy) * sh[..., 8])\n\n            if deg > 2:\n                result = (result +\n                        C3[0] * y * (3 * xx 
- yy) * sh[..., 9] +\n                        C3[1] * xy * z * sh[..., 10] +\n                        C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] +\n                        C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +\n                        C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +\n                        C3[5] * z * (xx - yy) * sh[..., 14] +\n                        C3[6] * x * (xx - 3 * yy) * sh[..., 15])\n                if deg > 3:\n                    result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +\n                            C4[1] * yz * (3 * xx - yy) * sh[..., 17] +\n                            C4[2] * xy * (7 * zz - 1) * sh[..., 18] +\n                            C4[3] * yz * (7 * zz - 3) * sh[..., 19] +\n                            C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +\n                            C4[5] * xz * (7 * zz - 3) * sh[..., 21] +\n                            C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +\n                            C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +\n                            C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])\n    return result\n\ndef eval_sh_bases(deg, dirs):\n    \"\"\"\n    Evaluate spherical harmonics bases at unit directions,\n    without taking linear combination.\n    At each point, the final result may then be\n    obtained through simple multiplication.\n    :param deg: int SH max degree. 
Currently, 0-4 supported\n    :param dirs: torch.Tensor (..., 3) unit directions\n    :return: torch.Tensor (..., (deg+1) ** 2)\n    \"\"\"\n    assert deg <= 4 and deg >= 0\n    result = torch.empty((*dirs.shape[:-1], (deg + 1) ** 2), dtype=dirs.dtype, device=dirs.device)\n    result[..., 0] = C0\n    if deg > 0:\n        x, y, z = dirs.unbind(-1)\n        result[..., 1] = -C1 * y\n        result[..., 2] = C1 * z\n        result[..., 3] = -C1 * x\n        if deg > 1:\n            xx, yy, zz = x * x, y * y, z * z\n            xy, yz, xz = x * y, y * z, x * z\n            result[..., 4] = C2[0] * xy\n            result[..., 5] = C2[1] * yz\n            result[..., 6] = C2[2] * (2.0 * zz - xx - yy)\n            result[..., 7] = C2[3] * xz\n            result[..., 8] = C2[4] * (xx - yy)\n\n            if deg > 2:\n                result[..., 9] = C3[0] * y * (3 * xx - yy)\n                result[..., 10] = C3[1] * xy * z\n                result[..., 11] = C3[2] * y * (4 * zz - xx - yy)\n                result[..., 12] = C3[3] * z * (2 * zz - 3 * xx - 3 * yy)\n                result[..., 13] = C3[4] * x * (4 * zz - xx - yy)\n                result[..., 14] = C3[5] * z * (xx - yy)\n                result[..., 15] = C3[6] * x * (xx - 3 * yy)\n\n                if deg > 3:\n                    result[..., 16] = C4[0] * xy * (xx - yy)\n                    result[..., 17] = C4[1] * yz * (3 * xx - yy)\n                    result[..., 18] = C4[2] * xy * (7 * zz - 1)\n                    result[..., 19] = C4[3] * yz * (7 * zz - 3)\n                    result[..., 20] = C4[4] * (zz * (35 * zz - 30) + 3)\n                    result[..., 21] = C4[5] * xz * (7 * zz - 3)\n                    result[..., 22] = C4[6] * (xx - yy) * (7 * zz - 1)\n                    result[..., 23] = C4[7] * xz * (xx - 3 * yy)\n                    result[..., 24] = C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy))\n    return result\n"
  },
  {
    "path": "models/styleModules.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport random\n\ndef calc_mean_std(x, eps=1e-8):\n        \"\"\"\n        calculating channel-wise instance mean and standard variance\n        x: shape of (N,C,*)\n        \"\"\"\n        mean = torch.mean(x.flatten(2), dim=-1, keepdim=True) # size of (N, C, 1)\n        std = torch.std(x.flatten(2), dim=-1, keepdim=True) + eps # size of (N, C, 1)\n        \n        return mean, std\n\ndef cal_adain_style_loss(x, y):\n    \"\"\"\n    style loss in one layer\n\n    Args:\n        x, y: feature maps of size [N, C, H, W]\n    \"\"\"\n    x_mean, x_std = calc_mean_std(x)\n    y_mean, y_std = calc_mean_std(y)\n\n    return nn.functional.mse_loss(x_mean, y_mean) \\\n         + nn.functional.mse_loss(x_std, y_std)\n\ndef cal_mse_content_loss(x, y):\n    return nn.functional.mse_loss(x, y)\n\n\n\nclass LearnableIN(nn.Module):\n    '''\n    Input: (N, C, L) or (C, L)\n    '''\n    def __init__(self, dim=256):\n        super().__init__()\n        self.IN = torch.nn.InstanceNorm1d(dim, momentum=1e-4, track_running_stats=True)\n\n    def forward(self, x):\n        if x.size()[-1] <= 1:\n            return x\n        return self.IN(x)\n\nclass SimpleLinearStylizer(nn.Module):\n    def __init__(self, input_dim=256, embed_dim=32, n_layers=3 ) -> None:\n        super().__init__()\n        self.input_dim = input_dim\n        self.embed_dim = embed_dim\n\n        self.IN = LearnableIN(input_dim)\n\n        self.q_embed = nn.Conv1d(input_dim, embed_dim, 1)\n        self.k_embed = nn.Conv1d(input_dim, embed_dim, 1)\n        self.v_embed = nn.Conv1d(input_dim, embed_dim, 1)\n\n        self.unzipper = nn.Conv1d(embed_dim, input_dim, 1, bias=0)\n\n        s_net = []\n        for i in range(n_layers - 1):\n            out_dim = max(embed_dim, input_dim // 2)\n            s_net.append(\n                nn.Sequential(\n                    nn.Conv1d(input_dim, out_dim, 1),\n                    
nn.ReLU(inplace=True),\n                )\n            )\n            input_dim = out_dim\n        s_net.append(nn.Conv1d(input_dim, embed_dim, 1))\n        self.s_net = nn.Sequential(*s_net)\n\n        self.s_fc = nn.Linear(embed_dim ** 2, embed_dim ** 2)\n\n    def _vectorized_covariance(self, x):\n        cov = torch.bmm(x, x.transpose(2, 1)) / x.size(-1)\n        cov = cov.flatten(1)\n        return cov\n\n    def get_content_matrix(self, c):\n        '''\n        Args:\n            c: content feature [N,input_dim,S]\n        Return:\n            mat: [N,S,embed_dim,embed_dim]\n        '''\n        normalized_c = self.IN(c)\n        # normalized_c = torch.nn.functional.instance_norm(c)\n        q_embed = self.q_embed(normalized_c)\n        k_embed = self.k_embed(normalized_c)\n        \n        c_cov = q_embed.transpose(1,2).unsqueeze(3) * k_embed.transpose(1,2).unsqueeze(2) # [N,S,embed_dim,embed_dim]\n        attn = torch.softmax(c_cov, -1) # [N,S,embed_dim,embed_dim]\n\n        return attn, normalized_c\n\n    def get_style_mean_std_matrix(self, s):\n        '''\n        Args:\n            s: style feature [N,input_dim,S]\n\n        Return:\n            mat: [N,embed_dim,embed_dim]\n        '''\n        s_mean = s.mean(-1, keepdim=True)\n        s_std = s.std(-1, keepdim=True)\n        s = s - s_mean\n\n        s_embed = self.s_net(s)\n        s_cov = self._vectorized_covariance(s_embed)\n        s_mat = self.s_fc(s_cov)\n        s_mat = s_mat.reshape(-1, self.embed_dim, self.embed_dim)\n\n        return s_mean, s_std, s_mat\n\n    def transform_content_3D(self, c):\n        '''\n        Args:\n            c: content feature [N,input_dim,S]\n        Return:\n            transformed_c: [N,embed_dim,S]\n        '''\n        attn, normalized_c = self.get_content_matrix(c) # [N,S,embed_dim,embed_dim]\n        c = self.v_embed(normalized_c) # [N,embed_dim,S]\n        c = c.transpose(1,2).unsqueeze(3) # [N,S,embed_dim,1]\n        c = torch.matmul(attn, 
c).squeeze(3) # [N,S,embed_dim]\n\n        return c.transpose(1,2)\n\n    def transfer_style_2D(self, s_mean_std_mat, c, acc_map):\n        '''\n        Args:\n            s_mean_std_mat: tuple (s_mean, s_std, s_mat), where\n                s_mean: [N,input_dim,1]\n                s_std: [N,input_dim,1]\n                s_mat: style matrix [N,embed_dim,embed_dim]\n            c: content feature map after volume rendering [N,embed_dim,S]\n            acc_map: [S]\n        '''\n        s_mean, s_std, s_mat = s_mean_std_mat\n\n        cs = torch.bmm(s_mat, c) # [N,embed_dim,S]\n        cs = self.unzipper(cs) # [N,input_dim,S]\n\n        cs = cs * s_std + s_mean * acc_map[None,None,...]\n\n        return cs\n\n\nclass AdaAttN(nn.Module):\n    \"\"\" Attention-weighted AdaIN (Liu et al., ICCV 21) \"\"\"\n\n    def __init__(self, qk_dim, v_dim):\n        \"\"\"\n        Args:\n            qk_dim (int): query and key size.\n            v_dim (int): value size.\n        \"\"\"\n        super(AdaAttN, self).__init__()\n\n        self.q_embed = nn.Conv1d(qk_dim, qk_dim, 1)\n        self.k_embed = nn.Conv1d(qk_dim, qk_dim, 1)\n        self.s_embed = nn.Conv1d(v_dim, v_dim, 1)\n\n    def forward(self, q, k):\n        \"\"\"\n        Args:\n            q (float tensor, (bs, qk, *)): query (content) features.\n            k (float tensor, (bs, qk, *)): key (style) features.\n\n        The content and style value features c and s are not separate\n        arguments; they are aliased to q and k below.\n\n        Returns:\n            cs (float tensor, (bs, v, *)): stylized content features.\n        \"\"\"\n        c, s = q, k\n\n        shape = c.shape\n        q, k = q.flatten(2), k.flatten(2)\n        c, s = c.flatten(2), s.flatten(2)\n\n        # QKV attention with projected content and style features\n        q = self.q_embed(F.instance_norm(q)).transpose(2, 1)    # (bs, n, qk)\n        k = self.k_embed(F.instance_norm(k))                    # (bs, qk, m)\n        s = self.s_embed(s).transpose(2, 1)                     # (bs, m, 
v)\n        attn = F.softmax(torch.bmm(q, k), -1)                   # (bs, n, m)\n        \n        # attention-weighted channel-wise statistics\n        mean = torch.bmm(attn, s)                               # (bs, n, v)\n        var = F.relu(torch.bmm(attn, s ** 2) - mean ** 2)       # (bs, n, v)\n        mean = mean.transpose(2, 1)                             # (bs, v, n)\n        std = torch.sqrt(var).transpose(2, 1)                   # (bs, v, n)\n        \n        cs = F.instance_norm(c) * std + mean                    # (bs, v, n)\n        cs = cs.reshape(shape)\n        return cs\n\nclass AdaAttN_new_IN(nn.Module):\n    \"\"\" Attention-weighted AdaIN (Liu et al., ICCV 21) \"\"\"\n\n    def __init__(self, qk_dim, v_dim):\n        \"\"\"\n        Args:\n            qk_dim (int): query and key size.\n            v_dim (int): value size.\n        \"\"\"\n        super(AdaAttN_new_IN, self).__init__()\n\n        self.q_embed = nn.Conv1d(qk_dim, qk_dim, 1)\n        self.k_embed = nn.Conv1d(qk_dim, qk_dim, 1)\n        self.s_embed = nn.Conv1d(v_dim, v_dim, 1)\n        self.IN = LearnableIN(qk_dim)\n\n    def forward(self, q, k):\n        \"\"\"\n        Args:\n            q (float tensor, (bs, qk, *)): query (content) features.\n            k (float tensor, (bs, qk, *)): key (style) features.\n            c (float tensor, (bs, v, *)): content value features.\n            s (float tensor, (bs, v, *)): style value features.\n\n        Returns:\n            cs (float tensor, (bs, v, *)): stylized content features.\n        \"\"\"\n        c, s = q, k\n\n        shape = c.shape\n        q, k = q.flatten(2), k.flatten(2)\n        c, s = c.flatten(2), s.flatten(2)\n\n        # QKV attention with projected content and style features\n        q = self.q_embed(self.IN(q)).transpose(2, 1)    # (bs, n, qk)\n        k = self.k_embed(F.instance_norm(k))                    # (bs, qk, m)\n        s = self.s_embed(s).transpose(2, 1)                     # (bs, m, v)\n        
attn = F.softmax(torch.bmm(q, k), -1)                   # (bs, n, m)\n        \n        # attention-weighted channel-wise statistics\n        mean = torch.bmm(attn, s)                               # (bs, n, v)\n        var = F.relu(torch.bmm(attn, s ** 2) - mean ** 2)       # (bs, n, v)\n        mean = mean.transpose(2, 1)                             # (bs, v, n)\n        std = torch.sqrt(var).transpose(2, 1)                   # (bs, v, n)\n        \n        cs = self.IN(c) * std + mean                    # (bs, v, n)\n        cs = cs.reshape(shape)\n        return cs\n\nclass AdaAttN_woin(nn.Module):\n    \"\"\" Attention-weighted AdaIN (Liu et al., ICCV 21) \"\"\"\n\n    def __init__(self, qk_dim, v_dim):\n        \"\"\"\n        Args:\n            qk_dim (int): query and key size.\n            v_dim (int): value size.\n        \"\"\"\n        super().__init__()\n\n        self.q_embed = nn.Conv1d(qk_dim, qk_dim, 1)\n        self.k_embed = nn.Conv1d(qk_dim, qk_dim, 1)\n        self.s_embed = nn.Conv1d(v_dim, v_dim, 1)\n\n    def forward(self, q, k):\n        \"\"\"\n        Args:\n            q (float tensor, (bs, qk, *)): query (content) features.\n            k (float tensor, (bs, qk, *)): key (style) features.\n            c (float tensor, (bs, v, *)): content value features.\n            s (float tensor, (bs, v, *)): style value features.\n\n        Returns:\n            cs (float tensor, (bs, v, *)): stylized content features.\n        \"\"\"\n        c, s = q, k\n\n        shape = c.shape\n        q, k = q.flatten(2), k.flatten(2)\n        c, s = c.flatten(2), s.flatten(2)\n\n        # QKV attention with projected content and style features\n        q = self.q_embed(q).transpose(2, 1)    # (bs, n, qk)\n        k = self.k_embed(k)                    # (bs, qk, m)\n        s = self.s_embed(s).transpose(2, 1)                     # (bs, m, v)\n        attn = F.softmax(torch.bmm(q, k), -1)                   # (bs, n, m)\n        \n        # attention-weighted 
channel-wise statistics\n        mean = torch.bmm(attn, s)                               # (bs, n, v)\n        var = F.relu(torch.bmm(attn, s ** 2) - mean ** 2)       # (bs, n, v)\n        mean = mean.transpose(2, 1)                             # (bs, v, n)\n        std = torch.sqrt(var).transpose(2, 1)                   # (bs, v, n)\n        \n        cs = c * std + mean                    # (bs, v, n)\n        cs = cs.reshape(shape)\n        return cs"
  },
  {
    "path": "models/tensoRF.py",
    "content": "from .tensorBase import *\nfrom .VGG import Encoder, Decoder, UNetDecoder, PlainDecoder\nfrom .styleModules import LearnableIN, SimpleLinearStylizer\n\nclass TensorVMSplit(TensorBase):\n    def __init__(self, aabb, gridSize, device, **kargs):\n        super(TensorVMSplit, self).__init__(aabb, gridSize, device, **kargs)\n\n\n\n    def change_to_feature_mod(self, feature_n_comp, device):\n        self.density_line.requires_grad_(False)\n        self.density_plane.requires_grad_(False)\n        self.app_line = None\n        self.app_plane = None\n        self.basis_mat = None\n        self.renderModule = None\n\n        # Both encoder and decoder do not require grad when initialized\n        self.encoder = Encoder().to(device)\n        self.decoder = PlainDecoder().to(device)\n\n        # We need to finetune decoder when training a feature grid\n        self.decoder.requires_grad_(True)\n\n        self.feature_n_comp = feature_n_comp\n        self.init_feature_svd(device)\n\n    def change_to_style_mod(self, device='cuda'):\n        assert self.feature_line is not None, 'Have to be trained in feature mod first!'\n        self.feature_line.requires_grad_(False)\n        self.feature_plane.requires_grad_(False)\n        self.feature_basis_mat.requires_grad_(False)\n        self.decoder.requires_grad_(True)\n\n        self.IN = LearnableIN().to(device)\n\n        # self.stylizer = NearestFeatureTransform()\n        # self.stylizer = LearnableIN(1,256,device)\n        # self.stylizer = AdaAttN(256, 256).to(device)\n        # self.stylizer = AdaAttN_woin(256, 256).to(device)\n        self.stylizer = SimpleLinearStylizer(256).to(device)\n\n\n    def init_svd_volume(self, res, device):\n        self.density_plane, self.density_line = self.init_one_svd(self.density_n_comp, self.gridSize, 0.1, device)\n        self.app_plane, self.app_line = self.init_one_svd(self.app_n_comp, self.gridSize, 0.1, device)\n        self.basis_mat = 
torch.nn.Linear(sum(self.app_n_comp), self.app_dim, bias=False).to(device)\n\n    def init_feature_svd(self, device):\n        self.feature_plane, self.feature_line = self.init_one_svd(self.feature_n_comp, self.gridSize, 0.1, device)\n        self.feature_basis_mat = torch.nn.Linear(sum(self.feature_n_comp), 256, bias=False).to(device)\n\n    def init_one_svd(self, n_component, gridSize, scale, device):\n        plane_coef, line_coef = [], []\n        for i in range(len(self.vecMode)):\n            vec_id = self.vecMode[i]\n            mat_id_0, mat_id_1 = self.matMode[i]\n            plane_coef.append(torch.nn.Parameter(\n                scale * torch.randn((1, n_component[i], gridSize[mat_id_1], gridSize[mat_id_0]))))  \n            line_coef.append(\n                torch.nn.Parameter(scale * torch.randn((1, n_component[i], gridSize[vec_id], 1))))\n\n        return torch.nn.ParameterList(plane_coef).to(device), torch.nn.ParameterList(line_coef).to(device)\n    \n    \n\n    def get_optparam_groups(self, lr_init_spatialxyz = 0.02, lr_init_network = 0.001):\n        grad_vars = [{'params': self.density_line, 'lr': lr_init_spatialxyz}, {'params': self.density_plane, 'lr': lr_init_spatialxyz},\n                     {'params': self.app_line, 'lr': lr_init_spatialxyz}, {'params': self.app_plane, 'lr': lr_init_spatialxyz},\n                         {'params': self.basis_mat.parameters(), 'lr':lr_init_network}]\n        if isinstance(self.renderModule, torch.nn.Module):\n            grad_vars += [{'params':self.renderModule.parameters(), 'lr':lr_init_network}]\n        return grad_vars\n\n    def get_optparam_groups_feature_mod(self, lr_init_spatialxyz, lr_init_network):\n        grad_vars = [{'params': self.feature_line, 'lr': lr_init_spatialxyz}, \n                     {'params': self.feature_plane, 'lr': lr_init_spatialxyz},\n                     {'params': self.feature_basis_mat.parameters(), 'lr':lr_init_network},\n                     {'params': 
self.decoder.parameters(), 'lr':lr_init_network}]\n        return grad_vars\n\n    def get_optparam_groups_style_mod(self, lr_init_network, lr_finetune):\n        grad_vars = [\n                        {'params': self.stylizer.parameters(), 'lr': lr_init_network},\n                        {'params': self.decoder.parameters(), 'lr': lr_finetune}, \n                    ]\n        return grad_vars\n\n\n\n    def vectorDiffs(self, vector_comps):\n        total = 0\n        \n        for idx in range(len(vector_comps)):\n            n_comp, n_size = vector_comps[idx].shape[1:-1]\n            \n            dotp = torch.matmul(vector_comps[idx].view(n_comp,n_size), vector_comps[idx].view(n_comp,n_size).transpose(-1,-2))\n            non_diagonal = dotp.view(-1)[1:].view(n_comp-1, n_comp+1)[...,:-1]\n            total = total + torch.mean(torch.abs(non_diagonal))\n        return total\n\n    def vector_comp_diffs(self):\n        return self.vectorDiffs(self.density_line) + self.vectorDiffs(self.app_line)\n    \n    def density_L1(self):\n        total = 0\n        for idx in range(len(self.density_plane)):\n            total = total + torch.mean(torch.abs(self.density_plane[idx])) + torch.mean(torch.abs(self.density_line[idx]))# + torch.mean(torch.abs(self.app_plane[idx])) + torch.mean(torch.abs(self.density_plane[idx]))\n        return total\n    \n    def TV_loss_density(self, reg):\n        total = 0\n        for idx in range(len(self.density_plane)):\n            total = total + reg(self.density_plane[idx]) * 1e-2 #+ reg(self.density_line[idx]) * 1e-3\n        return total\n        \n    def TV_loss_app(self, reg):\n        total = 0\n        for idx in range(len(self.app_plane)):\n            total = total + reg(self.app_plane[idx]) * 1e-2 #+ reg(self.app_line[idx]) * 1e-3\n        return total\n\n    def TV_loss_feature(self, reg):\n        total = 0\n        for idx in range(len(self.feature_plane)):\n            total = total + reg(self.feature_plane[idx]) + 
reg(self.feature_line[idx])\n        return total\n\n\n\n    def compute_densityfeature(self, xyz_sampled):\n\n        # plane + line basis\n        coordinate_plane = torch.stack((xyz_sampled[..., self.matMode[0]], xyz_sampled[..., self.matMode[1]], xyz_sampled[..., self.matMode[2]])).detach().view(3, -1, 1, 2)\n        coordinate_line = torch.stack((xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]]))\n        coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1, 1, 2)\n\n        sigma_feature = torch.zeros((xyz_sampled.shape[0],), device=xyz_sampled.device)\n        for idx_plane in range(len(self.density_plane)):\n            plane_coef_point = F.grid_sample(self.density_plane[idx_plane], coordinate_plane[[idx_plane]],\n                                                align_corners=True).view(-1, *xyz_sampled.shape[:1])\n            line_coef_point = F.grid_sample(self.density_line[idx_plane], coordinate_line[[idx_plane]],\n                                            align_corners=True).view(-1, *xyz_sampled.shape[:1])\n            sigma_feature = sigma_feature + torch.sum(plane_coef_point * line_coef_point, dim=0)\n\n        return sigma_feature\n\n    def compute_appfeature(self, xyz_sampled):\n\n        # plane + line basis\n        coordinate_plane = torch.stack((xyz_sampled[..., self.matMode[0]], xyz_sampled[..., self.matMode[1]], xyz_sampled[..., self.matMode[2]])).detach().view(3, -1, 1, 2)\n        coordinate_line = torch.stack((xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]]))\n        coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1, 1, 2)\n\n        plane_coef_point,line_coef_point = [],[]\n        for idx_plane in range(len(self.app_plane)):\n            plane_coef_point.append(F.grid_sample(self.app_plane[idx_plane], 
coordinate_plane[[idx_plane]],\n                                                align_corners=True).view(-1, *xyz_sampled.shape[:1]))\n            line_coef_point.append(F.grid_sample(self.app_line[idx_plane], coordinate_line[[idx_plane]],\n                                            align_corners=True).view(-1, *xyz_sampled.shape[:1]))\n        plane_coef_point, line_coef_point = torch.cat(plane_coef_point), torch.cat(line_coef_point)\n\n\n        return self.basis_mat((plane_coef_point * line_coef_point).T)\n\n    def compute_feature(self, xyz_sampled):\n\n        # plane + line basis\n        coordinate_plane = torch.stack((xyz_sampled[..., self.matMode[0]], xyz_sampled[..., self.matMode[1]], xyz_sampled[..., self.matMode[2]])).detach().view(3, -1, 1, 2)\n        coordinate_line = torch.stack((xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]]))\n        coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1, 1, 2)\n\n        plane_coef_point,line_coef_point = [],[]\n        for idx_plane in range(len(self.feature_plane)):\n            plane_coef_point.append(F.grid_sample(self.feature_plane[idx_plane], coordinate_plane[[idx_plane]],\n                                                align_corners=True).view(-1, *xyz_sampled.shape[:1]))\n            line_coef_point.append(F.grid_sample(self.feature_line[idx_plane], coordinate_line[[idx_plane]],\n                                            align_corners=True).view(-1, *xyz_sampled.shape[:1]))\n        plane_coef_point, line_coef_point = torch.cat(plane_coef_point), torch.cat(line_coef_point)\n\n\n        return self.feature_basis_mat((plane_coef_point * line_coef_point).T)\n        \n\n\n    def render_feature_map(self, rays_chunk, s_mean_std_mat=None, is_train=False, ndc_ray=False, N_samples=-1):\n\n        # sample points\n        viewdirs = rays_chunk[:, 3:6]\n        if ndc_ray:\n            
xyz_sampled, z_vals, ray_valid = self.sample_ray_ndc(rays_chunk[:, :3], viewdirs, is_train=is_train,N_samples=N_samples)\n            dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], torch.zeros_like(z_vals[:, :1])), dim=-1)\n            rays_norm = torch.norm(viewdirs, dim=-1, keepdim=True)\n            dists = dists * rays_norm\n        else:\n            xyz_sampled, z_vals, ray_valid = self.sample_ray(rays_chunk[:, :3], viewdirs, is_train=is_train,N_samples=N_samples)\n            dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], torch.zeros_like(z_vals[:, :1])), dim=-1)\n        \n        if self.alphaMask is not None:\n            alphas = self.alphaMask.sample_alpha(xyz_sampled[ray_valid])\n            alpha_mask = alphas > 0\n            ray_invalid = ~ray_valid\n            ray_invalid[ray_valid] |= (~alpha_mask)\n            ray_valid = ~ray_invalid\n\n\n        sigma = torch.zeros(xyz_sampled.shape[:-1], device=xyz_sampled.device)\n        if s_mean_std_mat is not None:\n            features = torch.zeros((*xyz_sampled.shape[:2], self.stylizer.embed_dim), device=xyz_sampled.device)\n        else:\n            features = torch.zeros((*xyz_sampled.shape[:2], 256), device=xyz_sampled.device)\n\n\n        if ray_valid.any():\n            xyz_sampled = self.normalize_coord(xyz_sampled)\n            sigma_feature = self.compute_densityfeature(xyz_sampled[ray_valid])\n\n            validsigma = self.feature2density(sigma_feature)\n            sigma[ray_valid] = validsigma\n\n\n        alpha, weight, bg_weight = raw2alpha(sigma, dists * self.distance_scale)\n\n        app_mask = weight > self.rayMarch_weight_thres\n\n        if app_mask.any():\n            valid_features = self.compute_feature(xyz_sampled[app_mask]) # [n_valid_points~40k if not specify nSamples, C=256]\n            \n            # transform content on 3d\n            if s_mean_std_mat is not None:\n                valid_features = 
self.stylizer.transform_content_3D(valid_features.transpose(0,1)[None,...])\n                valid_features = valid_features.squeeze(0).transpose(0,1)\n\n            features[app_mask] = valid_features\n\n        feature_map = torch.sum(weight[..., None] * features, -2)\n        acc_map = torch.sum(weight, -1)\n        \n        # style transfer on 2d\n        if s_mean_std_mat is not None:\n            feature_map = self.stylizer.transfer_style_2D(s_mean_std_mat, feature_map.transpose(0,1)[None,...], acc_map)\n            feature_map = feature_map.squeeze().transpose(0,1)\n\n        return feature_map, acc_map\n\n    def render_depth_map(self, rays_chunk, is_train=False, ndc_ray=False, N_samples=-1):\n\n        # sample points\n        viewdirs = rays_chunk[:, 3:6]\n        if ndc_ray:\n            xyz_sampled, z_vals, ray_valid = self.sample_ray_ndc(rays_chunk[:, :3], viewdirs, is_train=is_train,N_samples=N_samples)\n            dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], torch.zeros_like(z_vals[:, :1])), dim=-1)\n            rays_norm = torch.norm(viewdirs, dim=-1, keepdim=True)\n            dists = dists * rays_norm\n        else:\n            xyz_sampled, z_vals, ray_valid = self.sample_ray(rays_chunk[:, :3], viewdirs, is_train=is_train,N_samples=N_samples)\n            dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], torch.zeros_like(z_vals[:, :1])), dim=-1)\n        \n        if self.alphaMask is not None:\n            alphas = self.alphaMask.sample_alpha(xyz_sampled[ray_valid])\n            alpha_mask = alphas > 0\n            ray_invalid = ~ray_valid\n            ray_invalid[ray_valid] |= (~alpha_mask)\n            ray_valid = ~ray_invalid\n\n        sigma = torch.zeros(xyz_sampled.shape[:-1], device=xyz_sampled.device)\n\n        if ray_valid.any():\n            xyz_sampled = self.normalize_coord(xyz_sampled)\n            sigma_feature = self.compute_densityfeature(xyz_sampled[ray_valid])\n            validsigma = 
self.feature2density(sigma_feature)\n            sigma[ray_valid] = validsigma\n\n        alpha, weight, bg_weight = raw2alpha(sigma, dists * self.distance_scale)\n\n        acc_map = torch.sum(weight, -1)\n\n        depth_map = torch.sum(weight * z_vals, -1)\n        depth_map = depth_map + (1. - acc_map) * rays_chunk[..., -1]\n\n        return depth_map # [n_rays]\n\n\n\n    @torch.no_grad()\n    def up_sampling_VM(self, plane_coef, line_coef, res_target):\n\n        for i in range(len(self.vecMode)):\n            vec_id = self.vecMode[i]\n            mat_id_0, mat_id_1 = self.matMode[i]\n            plane_coef[i] = torch.nn.Parameter(\n                F.interpolate(plane_coef[i].data, size=(res_target[mat_id_1], res_target[mat_id_0]), mode='bilinear',\n                              align_corners=True))\n            line_coef[i] = torch.nn.Parameter(\n                F.interpolate(line_coef[i].data, size=(res_target[vec_id], 1), mode='bilinear', align_corners=True))\n\n\n        return plane_coef, line_coef\n\n    @torch.no_grad()\n    def upsample_volume_grid(self, res_target):\n        self.app_plane, self.app_line = self.up_sampling_VM(self.app_plane, self.app_line, res_target)\n        self.density_plane, self.density_line = self.up_sampling_VM(self.density_plane, self.density_line, res_target)\n\n        self.update_stepSize(res_target)\n        print(f'upsampling to {res_target}')\n\n    @torch.no_grad()\n    def shrink(self, new_aabb):\n        print(\"====> shrinking ...\")\n        xyz_min, xyz_max = new_aabb\n        t_l, b_r = (xyz_min - self.aabb[0]) / self.units, (xyz_max - self.aabb[0]) / self.units\n        # print(new_aabb, self.aabb)\n        # print(t_l, b_r,self.alphaMask.alpha_volume.shape)\n        t_l, b_r = torch.round(t_l).long(), torch.round(b_r).long() + 1\n        b_r = torch.stack([b_r, self.gridSize]).amin(0)\n\n        for i in range(len(self.vecMode)):\n            mode0 = self.vecMode[i]\n            
self.density_line[i] = torch.nn.Parameter(\n                self.density_line[i].data[...,t_l[mode0]:b_r[mode0],:]\n            )\n            self.app_line[i] = torch.nn.Parameter(\n                self.app_line[i].data[...,t_l[mode0]:b_r[mode0],:]\n            )\n            mode0, mode1 = self.matMode[i]\n            self.density_plane[i] = torch.nn.Parameter(\n                self.density_plane[i].data[...,t_l[mode1]:b_r[mode1],t_l[mode0]:b_r[mode0]]\n            )\n            self.app_plane[i] = torch.nn.Parameter(\n                self.app_plane[i].data[...,t_l[mode1]:b_r[mode1],t_l[mode0]:b_r[mode0]]\n            )\n\n\n        if not torch.all(self.alphaMask.gridSize == self.gridSize):\n            t_l_r, b_r_r = t_l / (self.gridSize-1), (b_r-1) / (self.gridSize-1)\n            correct_aabb = torch.zeros_like(new_aabb)\n            correct_aabb[0] = (1-t_l_r)*self.aabb[0] + t_l_r*self.aabb[1]\n            correct_aabb[1] = (1-b_r_r)*self.aabb[0] + b_r_r*self.aabb[1]\n            print(\"aabb\", new_aabb, \"\\ncorrect aabb\", correct_aabb)\n            new_aabb = correct_aabb\n\n        newSize = b_r - t_l\n        self.aabb = new_aabb\n        self.update_stepSize((newSize[0], newSize[1], newSize[2]))\n"
  },
  {
    "path": "models/tensorBase.py",
    "content": "import torch\nimport torch.nn\nimport torch.nn.functional as F\nfrom .sh import eval_sh_bases\nimport numpy as np\nimport time\n\n\ndef positional_encoding(positions, freqs):\n    \n        freq_bands = (2**torch.arange(freqs).float()).to(positions.device)  # (F,)\n        pts = (positions[..., None] * freq_bands).reshape(\n            positions.shape[:-1] + (freqs * positions.shape[-1], ))  # (..., DF)\n        pts = torch.cat([torch.sin(pts), torch.cos(pts)], dim=-1)\n        return pts\n\ndef raw2alpha(sigma, dist):\n    # sigma, dist  [N_rays, N_samples]\n    alpha = 1. - torch.exp(-sigma*dist)\n\n    T = torch.cumprod(torch.cat([torch.ones(alpha.shape[0], 1).to(alpha.device), 1. - alpha + 1e-10], -1), -1)\n\n    weights = alpha * T[:, :-1]  # [N_rays, N_samples]\n    return alpha, weights, T[:,-1:]\n\n\ndef SHRender(xyz_sampled, viewdirs, features):\n    sh_mult = eval_sh_bases(2, viewdirs)[:, None]\n    rgb_sh = features.view(-1, 3, sh_mult.shape[-1])\n    rgb = torch.relu(torch.sum(sh_mult * rgb_sh, dim=-1) + 0.5)\n    return rgb\n\n\ndef RGBRender(xyz_sampled, viewdirs, features):\n\n    rgb = features\n    return rgb\n\nclass AlphaGridMask(torch.nn.Module):\n    def __init__(self, device, aabb, alpha_volume):\n        super(AlphaGridMask, self).__init__()\n        self.device = device\n\n        self.aabb=aabb.to(self.device)\n        self.aabbSize = self.aabb[1] - self.aabb[0]\n        self.invgridSize = 1.0/self.aabbSize * 2\n        self.alpha_volume = alpha_volume.view(1,1,*alpha_volume.shape[-3:])\n        self.gridSize = torch.LongTensor([alpha_volume.shape[-1],alpha_volume.shape[-2],alpha_volume.shape[-3]]).to(self.device)\n\n    def sample_alpha(self, xyz_sampled):\n        xyz_sampled = self.normalize_coord(xyz_sampled)\n        alpha_vals = F.grid_sample(self.alpha_volume, xyz_sampled.view(1,-1,1,1,3), align_corners=True).view(-1)\n\n        return alpha_vals\n\n    def normalize_coord(self, xyz_sampled):\n        return 
(xyz_sampled-self.aabb[0]) * self.invgridSize - 1\n\n\nclass MLPRender_Fea(torch.nn.Module):\n    def __init__(self,inChanel, viewpe=6, feape=6, featureC=128):\n        super(MLPRender_Fea, self).__init__()\n\n        self.in_mlpC = 2*viewpe*3 + 2*feape*inChanel + 3 + inChanel\n        self.viewpe = viewpe\n        self.feape = feape\n        layer1 = torch.nn.Linear(self.in_mlpC, featureC)\n        layer2 = torch.nn.Linear(featureC, featureC)\n        layer3 = torch.nn.Linear(featureC,3)\n\n        self.mlp = torch.nn.Sequential(layer1, torch.nn.ReLU(inplace=True), layer2, torch.nn.ReLU(inplace=True), layer3)\n        torch.nn.init.constant_(self.mlp[-1].bias, 0)\n\n    def forward(self, pts, viewdirs, features):\n        indata = [features, viewdirs]\n        if self.feape > 0:\n            indata += [positional_encoding(features, self.feape)]\n        if self.viewpe > 0:\n            indata += [positional_encoding(viewdirs, self.viewpe)]\n        mlp_in = torch.cat(indata, dim=-1)\n        rgb = self.mlp(mlp_in)\n        rgb = torch.sigmoid(rgb)\n\n        return rgb\n\nclass MLPRender_PE(torch.nn.Module):\n    def __init__(self,inChanel, viewpe=6, pospe=6, featureC=128):\n        super(MLPRender_PE, self).__init__()\n\n        self.in_mlpC = (3+2*viewpe*3)+ (3+2*pospe*3)  + inChanel #\n        self.viewpe = viewpe\n        self.pospe = pospe\n        layer1 = torch.nn.Linear(self.in_mlpC, featureC)\n        layer2 = torch.nn.Linear(featureC, featureC)\n        layer3 = torch.nn.Linear(featureC,3)\n\n        self.mlp = torch.nn.Sequential(layer1, torch.nn.ReLU(inplace=True), layer2, torch.nn.ReLU(inplace=True), layer3)\n        torch.nn.init.constant_(self.mlp[-1].bias, 0)\n\n    def forward(self, pts, viewdirs, features):\n        indata = [features, viewdirs]\n        if self.pospe > 0:\n            indata += [positional_encoding(pts, self.pospe)]\n        if self.viewpe > 0:\n            indata += [positional_encoding(viewdirs, self.viewpe)]\n        mlp_in = 
torch.cat(indata, dim=-1)\n        rgb = self.mlp(mlp_in)\n        rgb = torch.sigmoid(rgb)\n\n        return rgb\n\nclass MLPRender(torch.nn.Module):\n    def __init__(self,inChanel, viewpe=6, featureC=128):\n        super(MLPRender, self).__init__()\n\n        self.in_mlpC = (3+2*viewpe*3) + inChanel\n        self.viewpe = viewpe\n        \n        layer1 = torch.nn.Linear(self.in_mlpC, featureC)\n        layer2 = torch.nn.Linear(featureC, featureC)\n        layer3 = torch.nn.Linear(featureC,3)\n\n        self.mlp = torch.nn.Sequential(layer1, torch.nn.ReLU(inplace=True), layer2, torch.nn.ReLU(inplace=True), layer3)\n        torch.nn.init.constant_(self.mlp[-1].bias, 0)\n\n    def forward(self, pts, viewdirs, features):\n        indata = [features, viewdirs]\n        if self.viewpe > 0:\n            indata += [positional_encoding(viewdirs, self.viewpe)]\n        mlp_in = torch.cat(indata, dim=-1)\n        rgb = self.mlp(mlp_in)\n        rgb = torch.sigmoid(rgb)\n\n        return rgb\n\n\n\nclass TensorBase(torch.nn.Module):\n    def __init__(self, aabb, gridSize, device, density_n_comp = 8, appearance_n_comp = 24, app_dim = 27,\n                    shadingMode = 'MLP_PE', alphaMask = None, near_far=[2.0,6.0],\n                    density_shift = -10, alphaMask_thres=0.001, distance_scale=25, rayMarch_weight_thres=0.0001,\n                    pos_pe = 6, view_pe = 6, fea_pe = 6, featureC=128, step_ratio=2.0,\n                    fea2denseAct = 'softplus'):\n        super(TensorBase, self).__init__()\n\n        self.density_n_comp = density_n_comp\n        self.app_n_comp = appearance_n_comp\n        self.app_dim = app_dim\n        self.aabb = aabb\n        self.alphaMask = alphaMask\n        self.device=device\n\n        self.density_shift = density_shift\n        self.alphaMask_thres = alphaMask_thres\n        self.distance_scale = distance_scale\n        self.rayMarch_weight_thres = rayMarch_weight_thres\n        self.fea2denseAct = fea2denseAct\n\n        
self.near_far = near_far\n        self.step_ratio = step_ratio\n\n\n        self.update_stepSize(gridSize)\n\n        self.matMode = [[0,1], [0,2], [1,2]]\n        self.vecMode =  [2, 1, 0]\n        self.comp_w = [1,1,1]\n\n\n        self.init_svd_volume(gridSize[0], device)\n\n        self.shadingMode, self.pos_pe, self.view_pe, self.fea_pe, self.featureC = shadingMode, pos_pe, view_pe, fea_pe, featureC\n        self.init_render_func(shadingMode, pos_pe, view_pe, fea_pe, featureC, device)\n\n    def init_render_func(self, shadingMode, pos_pe, view_pe, fea_pe, featureC, device):\n        if shadingMode == 'MLP_PE':\n            self.renderModule = MLPRender_PE(self.app_dim, view_pe, pos_pe, featureC).to(device)\n        elif shadingMode == 'MLP_Fea':\n            self.renderModule = MLPRender_Fea(self.app_dim, view_pe, fea_pe, featureC).to(device)\n        elif shadingMode == 'MLP':\n            self.renderModule = MLPRender(self.app_dim, view_pe, featureC).to(device)\n        elif shadingMode == 'SH':\n            self.renderModule = SHRender\n        elif shadingMode == 'RGB':\n            assert self.app_dim == 3\n            self.renderModule = RGBRender\n        else:\n            print(\"Unrecognized shading module\")\n            exit()\n        print(\"pos_pe\", pos_pe, \"view_pe\", view_pe, \"fea_pe\", fea_pe)\n\n    def update_stepSize(self, gridSize):\n        print(\"aabb\", self.aabb.view(-1))\n        print(\"grid size\", gridSize)\n        self.aabbSize = self.aabb[1] - self.aabb[0]\n        self.invaabbSize = 2.0/self.aabbSize\n        self.gridSize= torch.LongTensor(gridSize).to(self.device)\n        self.units=self.aabbSize / (self.gridSize-1)\n        self.stepSize=torch.mean(self.units)*self.step_ratio\n        self.aabbDiag = torch.sqrt(torch.sum(torch.square(self.aabbSize)))\n        self.nSamples=int((self.aabbDiag / self.stepSize).item()) + 1\n        print(\"sampling step size: \", self.stepSize)\n        print(\"sampling number: \", 
self.nSamples)\n\n    def init_svd_volume(self, res, device):\n        pass\n\n    def compute_features(self, xyz_sampled):\n        pass\n    \n    def compute_densityfeature(self, xyz_sampled):\n        pass\n    \n    def compute_appfeature(self, xyz_sampled):\n        pass\n    \n    def normalize_coord(self, xyz_sampled):\n        return (xyz_sampled-self.aabb[0]) * self.invaabbSize - 1\n\n    def get_optparam_groups(self, lr_init_spatial = 0.02, lr_init_network = 0.001):\n        pass\n\n    def get_kwargs(self):\n        return {\n            'aabb': self.aabb,\n            'gridSize':self.gridSize.tolist(),\n            'density_n_comp': self.density_n_comp,\n            'appearance_n_comp': self.app_n_comp,\n            'app_dim': self.app_dim,\n\n            'density_shift': self.density_shift,\n            'alphaMask_thres': self.alphaMask_thres,\n            'distance_scale': self.distance_scale,\n            'rayMarch_weight_thres': self.rayMarch_weight_thres,\n            'fea2denseAct': self.fea2denseAct,\n\n            'near_far': self.near_far,\n            'step_ratio': self.step_ratio,\n\n            'shadingMode': self.shadingMode,\n            'pos_pe': self.pos_pe,\n            'view_pe': self.view_pe,\n            'fea_pe': self.fea_pe,\n            'featureC': self.featureC\n        }\n\n    def save(self, path):\n        kwargs = self.get_kwargs()\n        ckpt = {'kwargs': kwargs, 'state_dict': self.state_dict()}\n        if self.alphaMask is not None:\n            alpha_volume = self.alphaMask.alpha_volume.bool().cpu().numpy()\n            ckpt.update({'alphaMask.shape':alpha_volume.shape})\n            ckpt.update({'alphaMask.mask':np.packbits(alpha_volume.reshape(-1))})\n            ckpt.update({'alphaMask.aabb': self.alphaMask.aabb.cpu()})\n        torch.save(ckpt, path)\n\n    def load(self, ckpt):\n        if 'alphaMask.aabb' in ckpt.keys():\n            length = np.prod(ckpt['alphaMask.shape'])\n            alpha_volume = 
torch.from_numpy(np.unpackbits(ckpt['alphaMask.mask'])[:length].reshape(ckpt['alphaMask.shape']))\n            self.alphaMask = AlphaGridMask(self.device, ckpt['alphaMask.aabb'].to(self.device), alpha_volume.float().to(self.device))\n        self.load_state_dict(ckpt['state_dict'])\n\n\n    def sample_ray_ndc(self, rays_o, rays_d, is_train=True, N_samples=-1):\n        N_samples = N_samples if N_samples > 0 else self.nSamples\n        near, far = self.near_far\n        interpx = torch.linspace(near, far, N_samples).unsqueeze(0).to(rays_o)\n        if is_train:\n            interpx += torch.rand_like(interpx).to(rays_o) * ((far - near) / N_samples)\n\n        rays_pts = rays_o[..., None, :] + rays_d[..., None, :] * interpx[..., None]\n        mask_outbbox = ((self.aabb[0] > rays_pts) | (rays_pts > self.aabb[1])).any(dim=-1)\n        return rays_pts, interpx, ~mask_outbbox\n\n    def sample_ray(self, rays_o, rays_d, is_train=True, N_samples=-1):\n        N_samples = N_samples if N_samples>0 else self.nSamples\n        stepsize = self.stepSize\n        near, far = self.near_far\n        vec = torch.where(rays_d==0, torch.full_like(rays_d, 1e-6), rays_d)\n        rate_a = (self.aabb[1] - rays_o) / vec\n        rate_b = (self.aabb[0] - rays_o) / vec\n        t_min = torch.minimum(rate_a, rate_b).amax(-1).clamp(min=near, max=far)\n\n        rng = torch.arange(N_samples)[None].float()\n        if is_train:\n            rng = rng.repeat(rays_d.shape[-2],1)\n            rng += torch.rand_like(rng[:,[0]])\n        step = stepsize * rng.to(rays_o.device)\n        interpx = (t_min[...,None] + step)\n\n        rays_pts = rays_o[...,None,:] + rays_d[...,None,:] * interpx[...,None]\n        mask_outbbox = ((self.aabb[0]>rays_pts) | (rays_pts>self.aabb[1])).any(dim=-1)\n\n        return rays_pts, interpx, ~mask_outbbox\n\n\n    def shrink(self, new_aabb, voxel_size):\n        pass\n\n    @torch.no_grad()\n    def getDenseAlpha(self,gridSize=None):\n        gridSize = self.gridSize 
if gridSize is None else gridSize\n\n        samples = torch.stack(torch.meshgrid(\n            torch.linspace(0, 1, gridSize[0]),\n            torch.linspace(0, 1, gridSize[1]),\n            torch.linspace(0, 1, gridSize[2]),\n        ), -1).to(self.device)\n        dense_xyz = self.aabb[0] * (1-samples) + self.aabb[1] * samples\n\n        # dense_xyz = dense_xyz\n        # print(self.stepSize, self.distance_scale*self.aabbDiag)\n        alpha = torch.zeros_like(dense_xyz[...,0])\n        for i in range(gridSize[0]):\n            alpha[i] = self.compute_alpha(dense_xyz[i].view(-1,3), self.stepSize).view((gridSize[1], gridSize[2]))\n        return alpha, dense_xyz\n\n    @torch.no_grad()\n    def updateAlphaMask(self, gridSize=(200,200,200)):\n\n        alpha, dense_xyz = self.getDenseAlpha(gridSize)\n        dense_xyz = dense_xyz.transpose(0,2).contiguous()\n        alpha = alpha.clamp(0,1).transpose(0,2).contiguous()[None,None]\n        total_voxels = gridSize[0] * gridSize[1] * gridSize[2]\n\n        ks = 3\n        alpha = F.max_pool3d(alpha, kernel_size=ks, padding=ks // 2, stride=1).view(gridSize[::-1])\n        alpha[alpha>=self.alphaMask_thres] = 1\n        alpha[alpha<self.alphaMask_thres] = 0\n\n        self.alphaMask = AlphaGridMask(self.device, self.aabb, alpha)\n\n        valid_xyz = dense_xyz[alpha>0.5]\n\n        xyz_min = valid_xyz.amin(0)\n        xyz_max = valid_xyz.amax(0)\n\n        new_aabb = torch.stack((xyz_min, xyz_max))\n\n        total = torch.sum(alpha)\n        print(f\"bbox: {xyz_min, xyz_max} alpha rest %%%f\"%(total/total_voxels*100))\n        return new_aabb\n\n    @torch.no_grad()\n    def filtering_rays(self, all_rays, all_rgbs, N_samples=256, chunk=10240*5, bbox_only=False):\n        print('========> filtering rays ...')\n        tt = time.time()\n\n        N = torch.tensor(all_rays.shape[:-1]).prod()\n\n        mask_filtered = []\n        idx_chunks = torch.split(torch.arange(N), chunk)\n        for idx_chunk in idx_chunks:\n     
       rays_chunk = all_rays[idx_chunk].to(self.device)\n\n            rays_o, rays_d = rays_chunk[..., :3], rays_chunk[..., 3:6]\n            if bbox_only:\n                vec = torch.where(rays_d == 0, torch.full_like(rays_d, 1e-6), rays_d)\n                rate_a = (self.aabb[1] - rays_o) / vec\n                rate_b = (self.aabb[0] - rays_o) / vec\n                t_min = torch.minimum(rate_a, rate_b).amax(-1)#.clamp(min=near, max=far)\n                t_max = torch.maximum(rate_a, rate_b).amin(-1)#.clamp(min=near, max=far)\n                mask_inbbox = t_max > t_min\n\n            else:\n                xyz_sampled, _,_ = self.sample_ray(rays_o, rays_d, N_samples=N_samples, is_train=False)\n                mask_inbbox= (self.alphaMask.sample_alpha(xyz_sampled).view(xyz_sampled.shape[:-1]) > 0).any(-1)\n\n            mask_filtered.append(mask_inbbox.cpu())\n\n        mask_filtered = torch.cat(mask_filtered).view(all_rgbs.shape[:-1])\n\n        print(f'Ray filtering done! takes {time.time()-tt} s. 
ray mask ratio: {torch.sum(mask_filtered) / N}')\n        return all_rays[mask_filtered], all_rgbs[mask_filtered]\n\n\n    def feature2density(self, density_features):\n        if self.fea2denseAct == \"softplus\":\n            return F.softplus(density_features+self.density_shift)\n        elif self.fea2denseAct == \"relu\":\n            return F.relu(density_features)\n\n\n    def compute_alpha(self, xyz_locs, length=1):\n\n        if self.alphaMask is not None:\n            alphas = self.alphaMask.sample_alpha(xyz_locs)\n            alpha_mask = alphas > 0\n        else:\n            alpha_mask = torch.ones_like(xyz_locs[:,0], dtype=bool)\n            \n\n        sigma = torch.zeros(xyz_locs.shape[:-1], device=xyz_locs.device)\n\n        if alpha_mask.any():\n            xyz_sampled = self.normalize_coord(xyz_locs[alpha_mask])\n            sigma_feature = self.compute_densityfeature(xyz_sampled)\n            validsigma = self.feature2density(sigma_feature)\n            sigma[alpha_mask] = validsigma\n        \n\n        alpha = 1 - torch.exp(-sigma*length).view(xyz_locs.shape[:-1])\n\n        return alpha\n\n\n    def forward(self, rays_chunk, white_bg=True, is_train=False, ndc_ray=False, N_samples=-1):\n\n        # sample points\n        viewdirs = rays_chunk[:, 3:6]\n        if ndc_ray:\n            xyz_sampled, z_vals, ray_valid = self.sample_ray_ndc(rays_chunk[:, :3], viewdirs, is_train=is_train,N_samples=N_samples)\n            dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], torch.zeros_like(z_vals[:, :1])), dim=-1)\n            rays_norm = torch.norm(viewdirs, dim=-1, keepdim=True)\n            dists = dists * rays_norm\n            viewdirs = viewdirs / rays_norm\n        else:\n            xyz_sampled, z_vals, ray_valid = self.sample_ray(rays_chunk[:, :3], viewdirs, is_train=is_train,N_samples=N_samples)\n            dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], torch.zeros_like(z_vals[:, :1])), dim=-1)\n        viewdirs = viewdirs.view(-1, 1, 
3).expand(xyz_sampled.shape)\n        \n        if self.alphaMask is not None:\n            alphas = self.alphaMask.sample_alpha(xyz_sampled[ray_valid])\n            alpha_mask = alphas > 0\n            ray_invalid = ~ray_valid\n            ray_invalid[ray_valid] |= (~alpha_mask)\n            ray_valid = ~ray_invalid\n\n\n        sigma = torch.zeros(xyz_sampled.shape[:-1], device=xyz_sampled.device)\n        rgb = torch.zeros((*xyz_sampled.shape[:2], 3), device=xyz_sampled.device)\n\n        if ray_valid.any():\n            xyz_sampled = self.normalize_coord(xyz_sampled)\n            sigma_feature = self.compute_densityfeature(xyz_sampled[ray_valid])\n\n            validsigma = self.feature2density(sigma_feature)\n            sigma[ray_valid] = validsigma\n\n\n        alpha, weight, bg_weight = raw2alpha(sigma, dists * self.distance_scale)\n\n        app_mask = weight > self.rayMarch_weight_thres\n\n        if app_mask.any():\n            app_features = self.compute_appfeature(xyz_sampled[app_mask])\n            valid_rgbs = self.renderModule(xyz_sampled[app_mask], viewdirs[app_mask], app_features)\n            rgb[app_mask] = valid_rgbs\n\n        acc_map = torch.sum(weight, -1)\n        rgb_map = torch.sum(weight[..., None] * rgb, -2)\n\n        if white_bg or (is_train and torch.rand((1,))<0.5):\n            rgb_map = rgb_map + (1. - acc_map[..., None])\n\n        \n        rgb_map = rgb_map.clamp(0,1)\n\n        with torch.no_grad():\n            depth_map = torch.sum(weight * z_vals, -1)\n            depth_map = depth_map + (1. - acc_map) * rays_chunk[..., -1]\n\n        return rgb_map, depth_map # rgb, sigma, alpha, weight, bg_weight\n\n"
  },
  {
    "path": "opt.py",
    "content": "import configargparse\n\ndef config_parser(cmd=None):\n    parser = configargparse.ArgumentParser()\n    parser.add_argument('--config', is_config_file=True,\n                        help='config file path')\n    parser.add_argument(\"--expname\", type=str,\n                        help='experiment name')\n    parser.add_argument(\"--basedir\", type=str, default='./log',\n                        help='where to store ckpts and logs')\n    parser.add_argument(\"--add_timestamp\", type=int, default=0,\n                        help='add timestamp to dir')\n    parser.add_argument(\"--datadir\", type=str, default='./data/llff/fern',\n                        help='input data directory')\n    parser.add_argument(\"--wikiartdir\", type=str, default='./data/WikiArt',\n                        help='WikiArt style image dataset directory')\n    parser.add_argument(\"--progress_refresh_rate\", type=int, default=10,\n                        help='number of iterations between progress (PSNR/iteration) updates')\n\n    parser.add_argument('--with_depth', action='store_true')\n    parser.add_argument('--downsample_train', type=float, default=1.0)\n    parser.add_argument('--downsample_test', type=float, default=1.0)\n\n    parser.add_argument('--model_name', type=str, default='TensorVMSplit',\n                        choices=['TensorVMSplit', 'TensorCP'])\n\n    # loader options\n    parser.add_argument(\"--batch_size\", type=int, default=4096)\n    parser.add_argument(\"--n_iters\", type=int, default=30000)\n    parser.add_argument('--dataset_name', type=str, default='blender',\n                        choices=['blender', 'llff', 'nsvf', 'dtu','tankstemple', 'own_data'])\n\n\n    # training options\n    parser.add_argument(\"--patch_size\", type=int, default=256,\n                        help='patch_size for training') \n    parser.add_argument(\"--chunk_size\", type=int, default=4096,\n                        help='chunk_size for training')           \n    # learning rate\n    
parser.add_argument(\"--lr_init\", type=float, default=0.02,\n                        help='learning rate')    \n    parser.add_argument(\"--lr_basis\", type=float, default=1e-4,\n                        help='learning rate')\n    parser.add_argument(\"--lr_finetune\", type=float, default=1e-5,\n                        help='learning rate')\n    parser.add_argument(\"--lr_decay_iters\", type=int, default=-1,\n                        help = 'number of iterations over which the lr decays to the target ratio; -1 sets it to n_iters')\n    parser.add_argument(\"--lr_decay_target_ratio\", type=float, default=0.1,\n                        help='the target decay ratio; after decay_iters the initial lr has decayed to lr*ratio')\n    parser.add_argument(\"--lr_upsample_reset\", type=int, default=1,\n                        help='reset lr to its initial value after upsampling')\n\n    # loss\n    parser.add_argument(\"--L1_weight_inital\", type=float, default=0.0,\n                        help='loss weight')\n    parser.add_argument(\"--L1_weight_rest\", type=float, default=0,\n                        help='loss weight')\n    parser.add_argument(\"--Ortho_weight\", type=float, default=0.0,\n                        help='loss weight')\n    parser.add_argument(\"--TV_weight_density\", type=float, default=0.0,\n                        help='loss weight')\n    parser.add_argument(\"--TV_weight_app\", type=float, default=0.0,\n                        help='loss weight')\n    parser.add_argument(\"--TV_weight_feature\", type=float, default=0.0,\n                        help='loss weight')\n    parser.add_argument(\"--style_weight\", type=float, default=0,\n                        help='loss weight')\n    parser.add_argument(\"--content_weight\", type=float, default=0,\n                        help='loss weight')\n    parser.add_argument(\"--image_tv_weight\", type=float, default=0,\n                        help='loss weight')\n    parser.add_argument(\"--featuremap_tv_weight\", type=float, default=0,\n    
                    help='loss weight')\n\n    \n    # model\n    # volume options\n    parser.add_argument(\"--n_lamb_sigma\", type=int, action=\"append\")\n    parser.add_argument(\"--n_lamb_sh\", type=int, action=\"append\")\n    parser.add_argument(\"--data_dim_color\", type=int, default=27)\n\n    parser.add_argument(\"--rm_weight_mask_thre\", type=float, default=0.0001,\n                        help='weight threshold below which points are masked out in ray marching')\n    parser.add_argument(\"--alpha_mask_thre\", type=float, default=0.0001,\n                        help='threshold for creating alpha mask volume')\n    parser.add_argument(\"--distance_scale\", type=float, default=25,\n                        help='scaling sampling distance for computation')\n    parser.add_argument(\"--density_shift\", type=float, default=-10,\n                        help='shift density in softplus; makes density close to 0 when feature == 0')\n                        \n    # network decoder\n    parser.add_argument(\"--shadingMode\", type=str, default=\"MLP_Fea\",\n                        help='which shading mode to use')\n    parser.add_argument(\"--pos_pe\", type=int, default=6,\n                        help='number of positional-encoding frequencies for positions')\n    parser.add_argument(\"--view_pe\", type=int, default=6,\n                        help='number of positional-encoding frequencies for view directions')\n    parser.add_argument(\"--fea_pe\", type=int, default=6,\n                        help='number of positional-encoding frequencies for features')\n    parser.add_argument(\"--featureC\", type=int, default=128,\n                        help='hidden feature channel in MLP')\n    \n\n    # test option\n    parser.add_argument(\"--ckpt\", type=str, default=None,\n                        help='path of the checkpoint to reload')\n    parser.add_argument(\"--render_only\", type=int, default=0)\n    parser.add_argument(\"--render_test\", type=int, default=0)\n    parser.add_argument(\"--render_train\", type=int, default=0)\n    parser.add_argument(\"--render_path\", type=int, 
default=0)\n    parser.add_argument(\"--export_mesh\", type=int, default=0)\n    parser.add_argument(\"--style_img\", type=str, required=False)\n\n    # rendering options\n    parser.add_argument('--lindisp', default=False, action=\"store_true\",\n                        help='use disparity depth sampling')\n    parser.add_argument(\"--perturb\", type=float, default=1.,\n                        help='set to 0. for no jitter, 1. for jitter')\n    parser.add_argument(\"--accumulate_decay\", type=float, default=0.998)\n    parser.add_argument(\"--fea2denseAct\", type=str, default='softplus')\n    parser.add_argument('--ndc_ray', type=int, default=0)\n    # argparse does not apply `type` to non-string defaults, so keep the default an int\n    parser.add_argument('--nSamples', type=int, default=int(1e6),\n                        help='number of sample points per ray; pass 1e6 to adjust automatically')\n    parser.add_argument('--step_ratio',type=float,default=0.5)\n\n\n    ## blender flags\n    parser.add_argument(\"--white_bkgd\", action='store_true',\n                        help='set to render synthetic data on a white bkgd (always use for dvoxels)')\n\n\n\n    parser.add_argument('--N_voxel_init',\n                        type=int,\n                        default=100**3)\n    parser.add_argument('--N_voxel_final',\n                        type=int,\n                        default=300**3)\n    parser.add_argument(\"--upsamp_list\", type=int, action=\"append\")\n    parser.add_argument(\"--update_AlphaMask_list\", type=int, action=\"append\")\n\n    parser.add_argument('--idx_view',\n                        type=int,\n                        default=0)\n    # logging/saving options\n    parser.add_argument(\"--N_vis\", type=int, default=5,\n                        help='number of images to visualize')\n    parser.add_argument(\"--vis_every\", type=int, default=10000,\n                        help='how often to visualize images')\n    if cmd is not None:\n        return parser.parse_args(cmd)\n    else:\n        return parser.parse_args()"
  },
  {
    "path": "renderer.py",
    "content": "import torch,os,imageio,sys\nfrom tqdm.auto import tqdm\nfrom dataLoader.ray_utils import get_rays\nfrom models.tensoRF import raw2alpha, TensorVMSplit, AlphaGridMask\nfrom utils import *\nfrom dataLoader.ray_utils import ndc_rays_blender, denormalize_vgg, normalize_vgg\n\n\ndef OctreeRender_trilinear_fast(rays, tensorf, chunk=4096, N_samples=-1, ndc_ray=False, white_bg=True, is_train=False, \n                                render_feature=False, style_img=None, device='cuda'):\n\n    rgbs, alphas, depth_maps, weights, uncertainties = [], [], [], [], []\n    features, accs = [], []\n    s_mean_std_mat = None\n    if style_img is not None:\n        with torch.no_grad():\n            style_feature = tensorf.encoder(normalize_vgg(style_img))\n        s_mean_std_mat = tensorf.stylizer.get_style_mean_std_matrix(style_feature.relu3_1.flatten(2))\n\n    N_rays_all = rays.shape[0]\n    for chunk_idx in range(N_rays_all // chunk + int(N_rays_all % chunk > 0)):\n        rays_chunk = rays[chunk_idx * chunk:(chunk_idx + 1) * chunk].to(device)\n        \n        if render_feature:\n            feature_map, acc_map = tensorf.render_feature_map(rays_chunk, s_mean_std_mat=s_mean_std_mat, is_train=is_train, ndc_ray=ndc_ray, N_samples=N_samples)\n            features.append(feature_map)\n            accs.append(acc_map)\n        else:\n            rgb_map, depth_map = tensorf(rays_chunk, is_train=is_train, white_bg=white_bg, ndc_ray=ndc_ray, N_samples=N_samples)\n            rgbs.append(rgb_map)\n            depth_maps.append(depth_map)\n    \n    if render_feature:\n        if style_img is not None:\n            return torch.cat(features), torch.cat(accs), style_feature\n        return torch.cat(features), torch.cat(accs)\n\n    return torch.cat(rgbs), None, torch.cat(depth_maps), None, None \n\ndef OctreeRender_trilinear_fast_depth(rays, tensorf, chunk=4096, N_samples=-1, ndc_ray=False, is_train=False, device='cuda'):\n\n    depth_maps = []\n    N_rays_all = 
rays.shape[0]\n    for chunk_idx in range(N_rays_all // chunk + int(N_rays_all % chunk > 0)):\n        rays_chunk = rays[chunk_idx * chunk:(chunk_idx + 1) * chunk].to(device)\n        depth_map = tensorf.render_depth_map(rays_chunk, is_train=is_train, ndc_ray=ndc_ray, N_samples=N_samples)\n        depth_maps.append(depth_map)\n    \n    return torch.cat(depth_maps)\n\n\n@torch.no_grad()\ndef evaluation_feature(test_dataset, tensorf, args, renderer, chunk_size=2048, savePath=None, N_vis=10, prtx='', N_samples=-1,\n               white_bg=False, ndc_ray=False, compute_extra_metrics=False, style_img=None, device='cuda'):\n    '''\n    To see if the decoded feature map is similar to gt rgb map\n    '''\n    PSNRs, rgb_maps, vis_feature_maps = [], [], []\n    ssims,l_alex,l_vgg=[],[],[]\n    os.makedirs(savePath, exist_ok=True)\n    os.makedirs(savePath + '/feature', exist_ok=True)\n    W, H = test_dataset.img_wh\n\n    try:\n        tqdm._instances.clear()\n    except Exception:\n        pass\n\n    near_far = test_dataset.near_far\n    img_eval_interval = 1 if N_vis < 0 else max(test_dataset.all_rays_stack.shape[0] // N_vis,1)\n    idxs = list(range(0, test_dataset.all_rays_stack.shape[0], img_eval_interval))\n    for idx, samples in tqdm(enumerate(test_dataset.all_rays_stack[0::img_eval_interval]), file=sys.stdout):        \n\n        rays = samples.view(-1,samples.shape[-1])\n\n        if style_img is None:\n            feature_map, _ = renderer(rays, tensorf, chunk=chunk_size, N_samples=N_samples, ndc_ray=ndc_ray, \n                                        white_bg = white_bg, render_feature=True, device=device)\n        else:\n            feature_map, _, _ = renderer(rays, tensorf, chunk=chunk_size, N_samples=N_samples, ndc_ray=ndc_ray, \n                                white_bg = white_bg, render_feature=True, style_img=style_img, device=device)\n                            \n        feature_map = feature_map.reshape(H, W, 256)[None,...].permute(0,3,1,2)\n\n       
 recon_rgb = denormalize_vgg(tensorf.decoder(feature_map))\n        recon_rgb = recon_rgb.permute(0,2,3,1).clamp(0,1)\n\n        vis_feature_map = torch.sigmoid(feature_map[:, [1,2,3], :, :].permute(0,2,3,1))\n        \n        if test_dataset.white_bg:\n            mask = test_dataset.all_masks[idx:idx+1].to(device)\n            recon_rgb = mask*recon_rgb + (1.-mask)\n            vis_feature_map = mask*vis_feature_map + (1.-mask)\n\n        recon_rgb = recon_rgb.reshape(H, W, 3).cpu()\n        vis_feature_map = vis_feature_map.squeeze().cpu()\n\n        if len(test_dataset.all_rgbs_stack):\n            gt_rgb = test_dataset.all_rgbs_stack[idxs[idx]].view(H, W, 3)\n            loss = torch.mean((recon_rgb - gt_rgb) ** 2)\n            PSNRs.append(-10.0 * np.log(loss.item()) / np.log(10.0))\n\n            if compute_extra_metrics:\n                ssim = rgb_ssim(recon_rgb, gt_rgb, 1)\n                l_a = rgb_lpips(gt_rgb.numpy(), recon_rgb.numpy(), 'alex', tensorf.device)\n                l_v = rgb_lpips(gt_rgb.numpy(), recon_rgb.numpy(), 'vgg', tensorf.device)\n                ssims.append(ssim)\n                l_alex.append(l_a)\n                l_vgg.append(l_v)\n\n            # convert only when ground truth exists; gt_rgb is undefined otherwise\n            gt_rgb = (gt_rgb.numpy() * 255).astype('uint8')\n\n        recon_rgb = (recon_rgb.numpy() * 255).astype('uint8')\n        vis_feature_map = (vis_feature_map.numpy() * 255).astype('uint8')\n\n        if savePath is not None:\n            # rgb_map = np.concatenate((recon_rgb, gt_rgb), axis=1)\n            rgb_maps.append(recon_rgb)\n            vis_feature_maps.append(vis_feature_map)\n            imageio.imwrite(f'{savePath}/{prtx}{idx:03d}.png', recon_rgb)\n            imageio.imwrite(f'{savePath}/feature/feature_{idx:03d}.png', vis_feature_map)\n\n    imageio.mimwrite(f'{savePath}/{prtx}video.mp4', np.stack(rgb_maps), fps=30, quality=8)\n    imageio.mimwrite(f'{savePath}/feature/feature_video.mp4', np.stack(vis_feature_maps), fps=30, quality=8)\n\n    if PSNRs:\n        psnr = 
np.mean(np.asarray(PSNRs))\n        if compute_extra_metrics:\n            ssim = np.mean(np.asarray(ssims))\n            l_a = np.mean(np.asarray(l_alex))\n            l_v = np.mean(np.asarray(l_vgg))\n            np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr, ssim, l_a, l_v]))\n        else:\n            np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr]))\n\n\n    return PSNRs\n\n@torch.no_grad()\ndef evaluation_feature_path(test_dataset, tensorf, c2ws, renderer, chunk_size=2048, savePath=None, N_vis=5, prtx='', N_samples=-1,\n                    white_bg=False, ndc_ray=False, compute_extra_metrics=False, style_img=None, device='cuda'):\n    PSNRs, rgb_maps, vis_feature_maps, depth_maps = [], [], [], []\n    ssims,l_alex,l_vgg=[],[],[]\n    os.makedirs(savePath, exist_ok=True)\n    os.makedirs(savePath + '/feature', exist_ok=True)\n    W, H = test_dataset.img_wh\n\n    try:\n        tqdm._instances.clear()\n    except Exception:\n        pass\n\n    near_far = test_dataset.near_far\n    for idx, c2w in tqdm(enumerate(c2ws)):\n\n        c2w = torch.FloatTensor(c2w)\n        rays_o, rays_d = get_rays(test_dataset.directions, c2w)  # both (h*w, 3)\n        if ndc_ray:\n            rays_o, rays_d = ndc_rays_blender(H, W, test_dataset.focal[0], 1.0, rays_o, rays_d)\n        rays = torch.cat([rays_o, rays_d], 1).reshape(H, W, 6).permute(2,0,1)  # (6,H,W)\n        rays = rays.permute(1,2,0).reshape(-1,6) # (H * W, 6)\n\n        if style_img is None:\n            feature_map, _ = renderer(rays, tensorf, chunk=chunk_size, N_samples=N_samples, ndc_ray=ndc_ray, \n                                        white_bg = white_bg, render_feature=True, device=device)\n        else:\n            feature_map, _, _ = renderer(rays, tensorf, chunk=chunk_size, N_samples=N_samples, ndc_ray=ndc_ray, \n                                white_bg = white_bg, render_feature=True, style_img=style_img, device=device)\n        \n        feature_map = feature_map.reshape(H, W, 
256)[None,...].permute(0,3,1,2)\n\n        recon_rgb = denormalize_vgg(tensorf.decoder(feature_map))\n        recon_rgb = recon_rgb.permute(0,2,3,1).clamp(0,1)\n        recon_rgb = recon_rgb.reshape(H, W, 3).cpu()\n        recon_rgb = (recon_rgb.numpy() * 255).astype('uint8')\n        rgb_maps.append(recon_rgb)\n\n        vis_feature_map = torch.sigmoid(feature_map[:, [1,2,3], :, :].permute(0,2,3,1))\n        vis_feature_map = vis_feature_map.squeeze().cpu()\n        vis_feature_map = (vis_feature_map.numpy() * 255).astype('uint8')\n        vis_feature_maps.append(vis_feature_map)\n        \n        if savePath is not None:\n            imageio.imwrite(f'{savePath}/{prtx}{idx:03d}.png', recon_rgb)\n            imageio.imwrite(f'{savePath}/feature/feature_{idx:03d}.png', vis_feature_map)\n\n    imageio.mimwrite(f'{savePath}/{prtx}video.mp4', np.stack(rgb_maps), fps=30, quality=8)\n    imageio.mimwrite(f'{savePath}/feature/feature_video.mp4', np.stack(vis_feature_maps), fps=30, quality=8)\n\n    if PSNRs:\n        psnr = np.mean(np.asarray(PSNRs))\n        if compute_extra_metrics:\n            ssim = np.mean(np.asarray(ssims))\n            l_a = np.mean(np.asarray(l_alex))\n            l_v = np.mean(np.asarray(l_vgg))\n            np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr, ssim, l_a, l_v]))\n        else:\n            np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr]))\n\n\n    return PSNRs\n\n@torch.no_grad()\ndef evaluation(test_dataset,tensorf, args, renderer, savePath=None, N_vis=5, prtx='', N_samples=-1,\n               white_bg=False, ndc_ray=False, compute_extra_metrics=True, device='cuda'):\n    PSNRs, rgb_maps, depth_maps = [], [], []\n    ssims,l_alex,l_vgg=[],[],[]\n    os.makedirs(savePath, exist_ok=True)\n    os.makedirs(savePath+\"/rgbd\", exist_ok=True)\n\n    try:\n        tqdm._instances.clear()\n    except Exception:\n        pass\n\n    near_far = test_dataset.near_far\n    img_eval_interval = 1 if N_vis < 0 else 
max(test_dataset.all_rays_stack.shape[0] // N_vis,1)\n    idxs = list(range(0, test_dataset.all_rays_stack.shape[0], img_eval_interval))\n    for idx, samples in tqdm(enumerate(test_dataset.all_rays_stack[0::img_eval_interval]), file=sys.stdout):\n\n        W, H = test_dataset.img_wh\n        rays = samples.view(-1,samples.shape[-1])\n\n        rgb_map, _, depth_map, _, _ = renderer(rays, tensorf, chunk=4096, N_samples=N_samples,\n                                        ndc_ray=ndc_ray, white_bg = white_bg, device=device)\n        rgb_map = rgb_map.clamp(0.0, 1.0)\n\n        rgb_map, depth_map = rgb_map.reshape(H, W, 3).cpu(), depth_map.reshape(H, W).cpu()\n\n        depth_map, _ = visualize_depth_numpy(depth_map.numpy(),near_far)\n        if len(test_dataset.all_rgbs_stack):\n            gt_rgb = test_dataset.all_rgbs_stack[idxs[idx]].view(H, W, 3)\n            loss = torch.mean((rgb_map - gt_rgb) ** 2)\n            PSNRs.append(-10.0 * np.log(loss.item()) / np.log(10.0))\n\n            if compute_extra_metrics:\n                ssim = rgb_ssim(rgb_map, gt_rgb, 1)\n                l_a = rgb_lpips(gt_rgb.numpy(), rgb_map.numpy(), 'alex', tensorf.device)\n                l_v = rgb_lpips(gt_rgb.numpy(), rgb_map.numpy(), 'vgg', tensorf.device)\n                ssims.append(ssim)\n                l_alex.append(l_a)\n                l_vgg.append(l_v)\n\n        rgb_map = (rgb_map.numpy() * 255).astype('uint8')\n        # rgb_map = np.concatenate((rgb_map, depth_map), axis=1)\n        rgb_maps.append(rgb_map)\n        depth_maps.append(depth_map)\n        if savePath is not None:\n            imageio.imwrite(f'{savePath}/{prtx}{idx:03d}.png', rgb_map)\n            # rgb_map = np.concatenate((rgb_map, depth_map), axis=1)\n            imageio.imwrite(f'{savePath}/rgbd/{prtx}{idx:03d}.png', depth_map)\n\n    imageio.mimwrite(f'{savePath}/{prtx}video.mp4', np.stack(rgb_maps), fps=30, quality=10)\n    imageio.mimwrite(f'{savePath}/{prtx}depthvideo.mp4', np.stack(depth_maps), 
fps=30, quality=10)\n\n    if PSNRs:\n        psnr = np.mean(np.asarray(PSNRs))\n        if compute_extra_metrics:\n            ssim = np.mean(np.asarray(ssims))\n            l_a = np.mean(np.asarray(l_alex))\n            l_v = np.mean(np.asarray(l_vgg))\n            np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr, ssim, l_a, l_v]))\n        else:\n            np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr]))\n\n\n    return PSNRs\n\n@torch.no_grad()\ndef evaluation_path(test_dataset,tensorf, c2ws, renderer, savePath=None, N_vis=5, prtx='', N_samples=-1,\n                    white_bg=False, ndc_ray=False, compute_extra_metrics=True, device='cuda'):\n    PSNRs, rgb_maps, depth_maps = [], [], []\n    ssims,l_alex,l_vgg=[],[],[]\n    os.makedirs(savePath, exist_ok=True)\n    os.makedirs(savePath+\"/rgbd\", exist_ok=True)\n\n    try:\n        tqdm._instances.clear()\n    except Exception:\n        pass\n\n    near_far = test_dataset.near_far\n    for idx, c2w in tqdm(enumerate(c2ws)):\n\n        W, H = test_dataset.img_wh\n\n        c2w = torch.FloatTensor(c2w)\n        rays_o, rays_d = get_rays(test_dataset.directions, c2w)  # both (h*w, 3)\n        if ndc_ray:\n            rays_o, rays_d = ndc_rays_blender(H, W, test_dataset.focal[0], 1.0, rays_o, rays_d)\n        rays = torch.cat([rays_o, rays_d], 1)  # (h*w, 6)\n\n        rgb_map, _, depth_map, _, _ = renderer(rays, tensorf, chunk=8192, N_samples=N_samples,\n                                        ndc_ray=ndc_ray, white_bg = white_bg, device=device)\n        rgb_map = rgb_map.clamp(0.0, 1.0)\n\n        rgb_map, depth_map = rgb_map.reshape(H, W, 3).cpu(), depth_map.reshape(H, W).cpu()\n\n        depth_map, _ = visualize_depth_numpy(depth_map.numpy(),near_far)\n\n        rgb_map = (rgb_map.numpy() * 255).astype('uint8')\n        # rgb_map = np.concatenate((rgb_map, depth_map), axis=1)\n        rgb_maps.append(rgb_map)\n        depth_maps.append(depth_map)\n        if savePath is not None:\n       
     imageio.imwrite(f'{savePath}/{prtx}{idx:03d}.png', rgb_map)\n            rgb_map = np.concatenate((rgb_map, depth_map), axis=1)\n            imageio.imwrite(f'{savePath}/rgbd/{prtx}{idx:03d}.png', rgb_map)\n\n    imageio.mimwrite(f'{savePath}/{prtx}video.mp4', np.stack(rgb_maps), fps=30, quality=8)\n    imageio.mimwrite(f'{savePath}/{prtx}depthvideo.mp4', np.stack(depth_maps), fps=30, quality=8)\n\n    if PSNRs:\n        psnr = np.mean(np.asarray(PSNRs))\n        if compute_extra_metrics:\n            ssim = np.mean(np.asarray(ssims))\n            l_a = np.mean(np.asarray(l_alex))\n            l_v = np.mean(np.asarray(l_vgg))\n            np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr, ssim, l_a, l_v]))\n        else:\n            np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr]))\n\n\n    return PSNRs\n\n"
  },
  {
    "path": "scripts/test.sh",
    "content": "CUDA_VISIBLE_DEVICES=$1 python train.py \\\n--config configs/llff.txt \\\n--ckpt log/trex/trex.th \\ \n--render_only 1 \n--render_test 1 "
  },
  {
    "path": "scripts/test_feature.sh",
    "content": "expname=trex\nCUDA_VISIBLE_DEVICES=$1 python train_feature.py \\\n--config configs/llff_feature.txt \\\n--datadir ./data/nerf_llff_data/trex \\\n--expname $expname \\\n--ckpt ./log_feature/$expname/$expname.th \\\n--render_only 1 \\\n--render_test 0 \\\n--render_path 1 \\\n--chunk_size 1024"
  },
  {
    "path": "scripts/test_style.sh",
    "content": "expname=trex\nCUDA_VISIBLE_DEVICES=$1 python train_style.py \\\n--config configs/llff_style.txt \\\n--datadir ./data/nerf_llff_data/trex \\\n--expname $expname \\\n--ckpt log_style/$expname/$expname.th \\\n--style_img path/to/reference/style/image \\\n--render_only 1 \\\n--render_train 0 \\\n--render_test 0 \\\n--render_path 1 \\\n--chunk_size 1024 \\\n--rm_weight_mask_thre 0.0001 \\"
  },
  {
    "path": "scripts/train.sh",
    "content": "CUDA_VISIBLE_DEVICES=$1 python train.py --config=configs/llff.txt"
  },
  {
    "path": "scripts/train_feature.sh",
    "content": "CUDA_VISIBLE_DEVICES=$1 python train_feature.py --config=configs/llff_feature.txt"
  },
  {
    "path": "scripts/train_style.sh",
    "content": "CUDA_VISIBLE_DEVICES=$1 python train_style.py --config=configs/llff_style.txt"
  },
  {
    "path": "train.py",
    "content": "\nimport os\nfrom tqdm.auto import tqdm\nfrom opt import config_parser\n\n\n\nimport json, random\nfrom renderer import *\nfrom utils import *\nfrom torch.utils.tensorboard import SummaryWriter\nimport datetime\n\nfrom dataLoader import dataset_dict\nimport sys\n\n\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\nrenderer = OctreeRender_trilinear_fast\n\n\nclass SimpleSampler:\n    def __init__(self, total, batch):\n        self.total = total\n        self.batch = batch\n        self.curr = total\n        self.ids = None\n\n    def nextids(self):\n        self.curr+=self.batch\n        if self.curr + self.batch > self.total:\n            self.ids = torch.LongTensor(np.random.permutation(self.total))\n            self.curr = 0\n        return self.ids[self.curr:self.curr+self.batch]\n\n\n@torch.no_grad()\ndef export_mesh(args):\n\n    ckpt = torch.load(args.ckpt, map_location=device)\n    kwargs = ckpt['kwargs']\n    kwargs.update({'device': device})\n    tensorf = eval(args.model_name)(**kwargs)\n    tensorf.load(ckpt)\n\n    alpha,_ = tensorf.getDenseAlpha()\n    convert_sdf_samples_to_ply(alpha.cpu(), f'{args.ckpt[:-3]}.ply',bbox=tensorf.aabb.cpu(), level=0.005)\n\n\n@torch.no_grad()\ndef render_test(args):\n    # init dataset\n    dataset = dataset_dict[args.dataset_name]\n    test_dataset = dataset(args.datadir, split='test', downsample=args.downsample_train, is_stack=True)\n    white_bg = test_dataset.white_bg\n    ndc_ray = args.ndc_ray\n\n    if not os.path.exists(args.ckpt):\n        print('the ckpt path does not exists!!')\n        return\n\n    ckpt = torch.load(args.ckpt, map_location=device)\n    kwargs = ckpt['kwargs']\n    kwargs.update({'device': device})\n    tensorf = eval(args.model_name)(**kwargs)\n    tensorf.load(ckpt)\n\n    logfolder = os.path.dirname(args.ckpt)\n    if args.render_train:\n        os.makedirs(f'{logfolder}/imgs_train_all', exist_ok=True)\n        train_dataset = 
dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=True)\n        PSNRs_test = evaluation(train_dataset,tensorf, args, renderer, f'{logfolder}/imgs_train_all/',\n                                N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)\n        print(f'======> {args.expname} train all psnr: {np.mean(PSNRs_test)} <========================')\n\n    if args.render_test:\n        os.makedirs(f'{logfolder}/{args.expname}/imgs_test_all', exist_ok=True)\n        evaluation(test_dataset,tensorf, args, renderer, f'{logfolder}/{args.expname}/imgs_test_all/',\n                                N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)\n\n    if args.render_path:\n        c2ws = test_dataset.render_path\n        os.makedirs(f'{logfolder}/{args.expname}/imgs_path_all', exist_ok=True)\n        evaluation_path(test_dataset,tensorf, c2ws, renderer, f'{logfolder}/{args.expname}/imgs_path_all/',\n                                N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)\n\ndef reconstruction(args):\n\n    # init dataset\n    dataset = dataset_dict[args.dataset_name]\n    train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=False)\n    test_dataset = dataset(args.datadir, split='test', downsample=args.downsample_train, is_stack=True)\n    white_bg = train_dataset.white_bg\n    near_far = train_dataset.near_far\n    ndc_ray = args.ndc_ray\n\n    # init resolution\n    upsamp_list = args.upsamp_list\n    update_AlphaMask_list = args.update_AlphaMask_list\n    n_lamb_sigma = args.n_lamb_sigma\n    n_lamb_sh = args.n_lamb_sh\n\n    \n    if args.add_timestamp:\n        logfolder = f'{args.basedir}/{args.expname}{datetime.datetime.now().strftime(\"-%Y%m%d-%H%M%S\")}'\n    else:\n        logfolder = f'{args.basedir}/{args.expname}'\n    \n\n    # init log file\n    os.makedirs(logfolder, exist_ok=True)\n    
os.makedirs(f'{logfolder}/imgs_vis', exist_ok=True)\n    os.makedirs(f'{logfolder}/imgs_rgba', exist_ok=True)\n    os.makedirs(f'{logfolder}/rgba', exist_ok=True)\n    summary_writer = SummaryWriter(logfolder)\n\n\n\n    # init parameters\n    # tensorVM, renderer = init_parameters(args, train_dataset.scene_bbox.to(device), reso_list[0])\n    aabb = train_dataset.scene_bbox.to(device)\n    reso_cur = N_to_reso(args.N_voxel_init, aabb)\n    nSamples = min(args.nSamples, cal_n_samples(reso_cur,args.step_ratio))\n\n\n    if args.ckpt is not None:\n        ckpt = torch.load(args.ckpt, map_location=device)\n        kwargs = ckpt['kwargs']\n        kwargs.update({'device':device})\n        tensorf = eval(args.model_name)(**kwargs)\n        tensorf.load(ckpt)\n    else:\n        tensorf = eval(args.model_name)(aabb, reso_cur, device,\n                    density_n_comp=n_lamb_sigma, appearance_n_comp=n_lamb_sh, app_dim=args.data_dim_color, near_far=near_far,\n                    shadingMode=args.shadingMode, alphaMask_thres=args.alpha_mask_thre, density_shift=args.density_shift, distance_scale=args.distance_scale,\n                    pos_pe=args.pos_pe, view_pe=args.view_pe, fea_pe=args.fea_pe, featureC=args.featureC, step_ratio=args.step_ratio, fea2denseAct=args.fea2denseAct)\n\n\n    grad_vars = tensorf.get_optparam_groups(args.lr_init, args.lr_basis)\n    if args.lr_decay_iters > 0:\n        lr_factor = args.lr_decay_target_ratio**(1/args.lr_decay_iters)\n    else:\n        args.lr_decay_iters = args.n_iters\n        lr_factor = args.lr_decay_target_ratio**(1/args.n_iters)\n\n    print(\"lr decay\", args.lr_decay_target_ratio, args.lr_decay_iters)\n    \n    optimizer = torch.optim.Adam(grad_vars, betas=(0.9,0.99))\n\n\n    #linear in logrithmic space\n    if upsamp_list is not None:\n        N_voxel_list = (torch.round(torch.exp(torch.linspace(np.log(args.N_voxel_init), np.log(args.N_voxel_final), len(upsamp_list)+1))).long()).tolist()[1:]\n\n\n    
torch.cuda.empty_cache()\n    PSNRs,PSNRs_test = [],[0]\n\n    allrays, allrgbs = train_dataset.all_rays, train_dataset.all_rgbs\n    if not args.ndc_ray:\n        allrays, allrgbs = tensorf.filtering_rays(allrays, allrgbs, bbox_only=True)\n    trainingSampler = SimpleSampler(allrays.shape[0], args.batch_size)\n\n    Ortho_reg_weight = args.Ortho_weight\n    print(\"initial Ortho_reg_weight\", Ortho_reg_weight)\n\n    L1_reg_weight = args.L1_weight_inital\n    print(\"initial L1_reg_weight\", L1_reg_weight)\n    TV_weight_density, TV_weight_app = args.TV_weight_density, args.TV_weight_app\n    tvreg = TVLoss()\n    print(f\"initial TV_weight density: {TV_weight_density} appearance: {TV_weight_app}\")\n\n\n    pbar = tqdm(range(args.n_iters), miniters=args.progress_refresh_rate, file=sys.stdout)\n    for iteration in pbar:\n\n\n        ray_idx = trainingSampler.nextids()\n        rays_train, rgb_train = allrays[ray_idx], allrgbs[ray_idx].to(device)\n\n        #rgb_map, alphas_map, depth_map, weights, uncertainty\n        rgb_map, alphas_map, depth_map, weights, uncertainty = renderer(rays_train, tensorf, chunk=args.batch_size,\n                                N_samples=nSamples, white_bg = white_bg, ndc_ray=ndc_ray, device=device, is_train=True)\n\n        loss = torch.mean((rgb_map - rgb_train) ** 2)\n\n\n        # loss\n        total_loss = loss\n        if Ortho_reg_weight > 0:\n            loss_reg = tensorf.vector_comp_diffs()\n            total_loss += Ortho_reg_weight*loss_reg\n            summary_writer.add_scalar('train/reg', loss_reg.detach().item(), global_step=iteration)\n        if L1_reg_weight > 0:\n            loss_reg_L1 = tensorf.density_L1()\n            total_loss += L1_reg_weight*loss_reg_L1\n            summary_writer.add_scalar('train/reg_l1', loss_reg_L1.detach().item(), global_step=iteration)\n\n        if TV_weight_density>0:\n            TV_weight_density *= lr_factor\n            loss_tv = tensorf.TV_loss_density(tvreg) * 
TV_weight_density\n            total_loss = total_loss + loss_tv\n            summary_writer.add_scalar('train/reg_tv_density', loss_tv.detach().item(), global_step=iteration)\n        if TV_weight_app>0:\n            TV_weight_app *= lr_factor\n            loss_tv = tensorf.TV_loss_app(tvreg)*TV_weight_app\n            total_loss = total_loss + loss_tv\n            summary_writer.add_scalar('train/reg_tv_app', loss_tv.detach().item(), global_step=iteration)\n\n        optimizer.zero_grad()\n        total_loss.backward()\n        optimizer.step()\n\n        loss = loss.detach().item()\n        \n        PSNRs.append(-10.0 * np.log(loss) / np.log(10.0))\n        summary_writer.add_scalar('train/PSNR', PSNRs[-1], global_step=iteration)\n        summary_writer.add_scalar('train/mse', loss, global_step=iteration)\n\n\n        for param_group in optimizer.param_groups:\n            param_group['lr'] = param_group['lr'] * lr_factor\n\n        # Print the current values of the losses.\n        if iteration % args.progress_refresh_rate == 0:\n            pbar.set_description(\n                f'Iteration {iteration:05d}:'\n                + f' train_psnr = {float(np.mean(PSNRs)):.2f}'\n                + f' test_psnr = {float(np.mean(PSNRs_test)):.2f}'\n                + f' mse = {loss:.6f}'\n            )\n            PSNRs = []\n\n\n        if iteration % args.vis_every == args.vis_every - 1 and args.N_vis!=0:\n            PSNRs_test = evaluation(test_dataset,tensorf, args, renderer, f'{logfolder}/imgs_vis/', N_vis=args.N_vis,\n                                    prtx=f'{iteration:06d}_', N_samples=nSamples, white_bg = white_bg, ndc_ray=ndc_ray, compute_extra_metrics=False)\n            summary_writer.add_scalar('test/psnr', np.mean(PSNRs_test), global_step=iteration)\n\n\n\n        if update_AlphaMask_list is not None and iteration in update_AlphaMask_list:\n\n            if reso_cur[0] * reso_cur[1] * reso_cur[2]<256**3:# update volume resolution\n                
reso_mask = reso_cur\n            new_aabb = tensorf.updateAlphaMask(tuple(reso_mask))\n            if iteration == update_AlphaMask_list[0]:\n                tensorf.shrink(new_aabb)\n                # tensorVM.alphaMask = None\n                L1_reg_weight = args.L1_weight_rest\n                print(\"continuing L1_reg_weight\", L1_reg_weight)\n\n\n            if not args.ndc_ray and iteration == update_AlphaMask_list[1]:\n                # filter rays outside the bbox\n                allrays,allrgbs = tensorf.filtering_rays(allrays,allrgbs)\n                trainingSampler = SimpleSampler(allrgbs.shape[0], args.batch_size)\n\n\n        if upsamp_list is not None and iteration in upsamp_list:\n            n_voxels = N_voxel_list.pop(0)\n            reso_cur = N_to_reso(n_voxels, tensorf.aabb)\n            nSamples = min(args.nSamples, cal_n_samples(reso_cur,args.step_ratio))\n            tensorf.upsample_volume_grid(reso_cur)\n\n            if args.lr_upsample_reset:\n                print(\"reset lr to initial\")\n                lr_scale = 1 #0.1 ** (iteration / args.n_iters)\n            else:\n                lr_scale = args.lr_decay_target_ratio ** (iteration / args.n_iters)\n            grad_vars = tensorf.get_optparam_groups(args.lr_init*lr_scale, args.lr_basis*lr_scale)\n            optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))\n        \n\n    tensorf.save(f'{logfolder}/{args.expname}.th')\n\n\n    if args.render_train:\n        os.makedirs(f'{logfolder}/imgs_train_all', exist_ok=True)\n        train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=True)\n        PSNRs_test = evaluation(train_dataset,tensorf, args, renderer, f'{logfolder}/imgs_train_all/',\n                                N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)\n        print(f'======> {args.expname} test all psnr: {np.mean(PSNRs_test)} <========================')\n\n    if args.render_test:\n    
    os.makedirs(f'{logfolder}/imgs_test_all', exist_ok=True)\n        PSNRs_test = evaluation(test_dataset,tensorf, args, renderer, f'{logfolder}/imgs_test_all/',\n                                N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)\n        summary_writer.add_scalar('test/psnr_all', np.mean(PSNRs_test), global_step=iteration)\n        print(f'======> {args.expname} test all psnr: {np.mean(PSNRs_test)} <========================')\n\n    if args.render_path:\n        c2ws = test_dataset.render_path\n        # c2ws = test_dataset.poses\n        print('========>',c2ws.shape)\n        os.makedirs(f'{logfolder}/imgs_path_all', exist_ok=True)\n        evaluation_path(test_dataset,tensorf, c2ws, renderer, f'{logfolder}/imgs_path_all/',\n                                N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)\n\n\nif __name__ == '__main__':\n\n    torch.set_default_dtype(torch.float32)\n    torch.manual_seed(20211202)\n    np.random.seed(20211202)\n\n    args = config_parser()\n    print(args)\n\n    if  args.export_mesh:\n        export_mesh(args)\n\n    if args.render_only and (args.render_test or args.render_path):\n        render_test(args)\n    else:\n        reconstruction(args)\n\n"
  },
  {
    "path": "train_feature.py",
    "content": "\nimport os\nfrom unittest.mock import patch\nfrom tqdm.auto import tqdm\nfrom opt import config_parser\n\n\n\nimport json, random\nfrom renderer import *\nfrom utils import *\nfrom torch.utils.tensorboard import SummaryWriter\nfrom torchvision.utils import make_grid\nimport torchvision.transforms.functional as TF\nimport datetime\n\nfrom dataLoader import dataset_dict\nimport sys\n\n\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\nrenderer = OctreeRender_trilinear_fast\n\n\nclass SimpleSampler:\n    def __init__(self, total, batch):\n        self.total = total\n        self.batch = batch\n        self.curr = total\n        self.ids = None\n\n    def nextids(self):\n        self.curr+=self.batch\n        if self.curr + self.batch > self.total:\n            self.ids = torch.LongTensor(np.random.permutation(self.total))\n            self.curr = 0\n        return self.ids[self.curr:self.curr+self.batch]\n\ndef InfiniteSampler(n):\n    # i = 0\n    i = n - 1\n    order = np.random.permutation(n)\n    while True:\n        yield order[i]\n        i += 1\n        if i >= n:\n            np.random.seed()\n            order = np.random.permutation(n)\n            i = 0\n\nclass InfiniteSamplerWrapper(torch.utils.data.sampler.Sampler):\n    def __init__(self, num_samples):\n        self.num_samples = num_samples\n\n    def __iter__(self):\n        return iter(InfiniteSampler(self.num_samples))\n\n    def __len__(self):\n        return 2 ** 31\n\n@torch.no_grad()\ndef render_test(args):\n    # init dataset\n    dataset = dataset_dict[args.dataset_name]\n    test_dataset = dataset(args.datadir, split='test', downsample=args.downsample_train, is_stack=True)\n    white_bg = test_dataset.white_bg\n    ndc_ray = args.ndc_ray\n\n    if not os.path.exists(args.ckpt):\n        print('the ckpt path does not exists!!')\n        return\n\n    ckpt = torch.load(args.ckpt, map_location=device)\n    kwargs = ckpt['kwargs']\n    
kwargs.update({'device': device})\n    tensorf = eval(args.model_name)(**kwargs)\n    tensorf.change_to_feature_mod(args.n_lamb_sh ,device)\n    tensorf.load(ckpt)\n    tensorf.eval()\n    tensorf.rayMarch_weight_thres = args.rm_weight_mask_thre\n\n    logfolder = os.path.dirname(args.ckpt)\n\n\n    if args.render_train:\n        os.makedirs(f'{logfolder}/{args.expname}/imgs_train_all', exist_ok=True)\n        train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=True)\n        evaluation_feature(train_dataset,tensorf, args, renderer, args.chunk_size, f'{logfolder}/{args.expname}/imgs_train_all/',\n                                N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)\n\n    if args.render_test:\n        os.makedirs(f'{logfolder}/{args.expname}/imgs_test_all', exist_ok=True)\n        evaluation_feature(test_dataset,tensorf, args, renderer, args.chunk_size, f'{logfolder}/{args.expname}/imgs_test_all/',\n                                N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)\n\n    if args.render_path:\n        c2ws = test_dataset.render_path\n        os.makedirs(f'{logfolder}/{args.expname}/imgs_path_all', exist_ok=True)\n        evaluation_feature_path(test_dataset,tensorf, c2ws, renderer, args.chunk_size, f'{logfolder}/{args.expname}/imgs_path_all/',\n                                N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)\n\ndef reconstruction(args):\n\n    # init dataset\n    dataset = dataset_dict[args.dataset_name]\n    train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=True)\n    white_bg = train_dataset.white_bg\n    near_far = train_dataset.near_far\n    h_rays, w_rays = train_dataset.img_wh[1], train_dataset.img_wh[0]\n    ndc_ray = args.ndc_ray\n\n    patch_size = args.patch_size\n\n    \n    if args.add_timestamp:\n        logfolder = 
f'{args.basedir}/{args.expname}{datetime.datetime.now().strftime(\"-%Y%m%d-%H%M%S\")}'\n    else:\n        logfolder = f'{args.basedir}/{args.expname}'\n    \n\n    # init log file\n    os.makedirs(logfolder, exist_ok=True)\n    summary_writer = SummaryWriter(logfolder)\n\n\n\n    # init parameters\n    # tensorVM, renderer = init_parameters(args, train_dataset.scene_bbox.to(device), reso_list[0])\n    aabb = train_dataset.scene_bbox.to(device)\n    # TODO: need to update reso_cur\n    reso_cur = N_to_reso(args.N_voxel_init, aabb)\n    nSamples = min(args.nSamples, cal_n_samples(reso_cur,args.step_ratio))\n\n\n    assert args.ckpt is not None, 'Have to be pre-trained to get density fielded!'\n\n    ckpt = torch.load(args.ckpt, map_location=device)\n    kwargs = ckpt['kwargs']\n    kwargs.update({'device':device})\n    tensorf = eval(args.model_name)(**kwargs)\n    tensorf.load(ckpt)\n   \n    tensorf.change_to_feature_mod(args.n_lamb_sh ,device)\n    tensorf.rayMarch_weight_thres = args.rm_weight_mask_thre\n\n    train_dataset.prepare_feature_data(tensorf.encoder)\n\n    grad_vars = tensorf.get_optparam_groups_feature_mod(args.lr_init, args.lr_basis)\n    if args.lr_decay_iters > 0:\n        lr_factor = args.lr_decay_target_ratio**(1/args.lr_decay_iters)\n    else:\n        args.lr_decay_iters = args.n_iters\n        lr_factor = args.lr_decay_target_ratio**(1/args.n_iters)\n\n    print(\"lr decay\", args.lr_decay_target_ratio, args.lr_decay_iters)\n    \n    optimizer = torch.optim.Adam(grad_vars, betas=(0.9,0.99))\n\n\n    torch.cuda.empty_cache()\n    PSNRs = []\n\n    allrays, allfeatures = train_dataset.all_rays, train_dataset.all_features\n    allrays_stack, allrgbs_stack = train_dataset.all_rays_stack, train_dataset.all_rgbs_stack\n    if not args.ndc_ray:\n        allrays, allfeatures = tensorf.filtering_rays(allrays, allfeatures, bbox_only=True)\n    trainingSampler = SimpleSampler(allrays.shape[0], args.batch_size)\n    frameSampler = 
iter(InfiniteSamplerWrapper(allrays_stack.size(0))) # every next(sampler) returns a frame index\n\n\n    TV_weight_feature = args.TV_weight_feature\n    tvreg = TVLoss()\n    print(f\"initial TV_weight_feature: {TV_weight_feature}\")\n\n\n    pbar = tqdm(range(args.n_iters), miniters=args.progress_refresh_rate, file=sys.stdout)\n    for iteration in pbar:\n\n        feature_loss, pixel_loss = 0., 0.\n        if iteration%2==0:\n            ray_idx = trainingSampler.nextids()\n            rays_train, features_train = allrays[ray_idx], allfeatures[ray_idx].to(device)\n\n            feature_map, _ = renderer(rays_train, tensorf, chunk=args.chunk_size, N_samples=nSamples, white_bg = white_bg, \n                                ndc_ray=ndc_ray, render_feature=True, device=device, is_train=True)\n\n            feature_loss = torch.mean((feature_map - features_train) ** 2)\n        else:\n            frame_idx = next(frameSampler)\n            start_h = np.random.randint(0, h_rays-patch_size+1)\n            start_w = np.random.randint(0, w_rays-patch_size+1)\n            if white_bg:\n                # move random sampled patches into center\n                mid_h, mid_w = (h_rays-patch_size+1)/2, (w_rays-patch_size+1)/2\n                if mid_h-start_h>=1:\n                    start_h += np.random.randint(0, mid_h-start_h)\n                elif mid_h-start_h<=-1:\n                    start_h += np.random.randint(mid_h-start_h, 0)\n                if mid_w-start_w>=1:\n                    start_w += np.random.randint(0, mid_w-start_w)\n                elif mid_w-start_w<=-1:\n                    start_w += np.random.randint(mid_w-start_w, 0)\n\n            rays_train = allrays_stack[frame_idx, start_h:start_h+patch_size, \n                                                    start_w:start_w+patch_size, :].reshape(-1, 6).to(device)\n            # [patch*patch, 6]\n            \n            rgbs_train = allrgbs_stack[frame_idx, start_h:(start_h+patch_size), \n                
                                  start_w:(start_w+patch_size), :].to(device)\n            # [patch, patch, 3]\n\n            feature_map, _ = renderer(rays_train, tensorf, chunk=args.chunk_size, N_samples=nSamples, white_bg=white_bg, \n                                ndc_ray=ndc_ray, render_feature=True, device=device, is_train=True)\n\n            feature_map = feature_map.reshape(patch_size, patch_size, 256)[None,...].permute(0,3,1,2)\n            recon_rgb = tensorf.decoder(feature_map)\n\n            rgbs_train = rgbs_train[None,...].permute(0,3,1,2)\n            img_enc = tensorf.encoder(normalize_vgg(rgbs_train))\n            recon_rgb_enc = tensorf.encoder(recon_rgb)\n            \n            feature_loss =(F.mse_loss(recon_rgb_enc.relu4_1, img_enc.relu4_1) +\n                           F.mse_loss(recon_rgb_enc.relu3_1, img_enc.relu3_1)) / 10\n\n            recon_rgb = denormalize_vgg(recon_rgb)\n\n            pixel_loss = torch.mean((recon_rgb - rgbs_train) ** 2)\n\n        total_loss = pixel_loss + feature_loss\n\n        # loss\n        # NOTE: Calculate feature TV loss rather than appearence TV loss\n        if TV_weight_feature>0:\n            TV_weight_feature *= lr_factor\n            loss_tv = tensorf.TV_loss_feature(tvreg)*TV_weight_feature\n            total_loss = total_loss + loss_tv\n            summary_writer.add_scalar('train/reg_tv_feature', loss_tv.detach().item(), global_step=iteration)\n\n        optimizer.zero_grad()\n        total_loss.backward()\n        optimizer.step()\n\n        if pixel_loss == 0:\n            feature_loss = feature_loss.detach().item()\n            PSNRs.append(-10.0 * np.log(feature_loss) / np.log(10.0))\n            summary_writer.add_scalar('train/PSNR_feature', PSNRs[-1], global_step=iteration)\n            summary_writer.add_scalar('train/mse_feature', feature_loss, global_step=iteration)\n        else:\n            pixel_loss = pixel_loss.detach().item()\n            PSNRs.append(-10.0 * np.log(pixel_loss) 
/ np.log(10.0))\n            summary_writer.add_scalar('train/PSNR_pixel', PSNRs[-1], global_step=iteration)\n            summary_writer.add_scalar('train/mse_pixel', pixel_loss, global_step=iteration)\n            summary_writer.add_scalar('train/mse_recon_feature', feature_loss.detach().item(), global_step=iteration)\n\n\n        for param_group in optimizer.param_groups:\n            param_group['lr'] = param_group['lr'] * lr_factor\n\n        # Print the current values of the losses.\n        if iteration % args.progress_refresh_rate == 0:\n            pbar.set_description(\n                f'Iteration {iteration:05d}:'\n                + f' psnr = {float(np.mean(PSNRs)):.2f}'\n            )\n            PSNRs = []\n\n        if iteration % (args.progress_refresh_rate*20) == 1:\n            summary_writer.add_image('output', make_grid([rgbs_train.squeeze(), \n                                                        recon_rgb.clamp(0, 1).squeeze()],  \n                                                        nrow=2, padding=0, normalize=False),\n                                                        global_step=iteration)\n        \n\n    tensorf.save(f'{logfolder}/{args.expname}.th')\n\nif __name__ == '__main__':\n\n    torch.set_default_dtype(torch.float32)\n    torch.manual_seed(20211202)\n    np.random.seed(20211202)\n\n    args = config_parser()\n    print(args)\n\n    if args.render_only:\n        render_test(args)\n    else:\n        reconstruction(args)\n\n"
  },
  {
    "path": "train_style.py",
    "content": "\nimport os\nfrom tqdm.auto import tqdm\nfrom opt import config_parser\nfrom PIL import Image, ImageFile\nfrom pathlib import Path\nfrom torchvision.utils import make_grid\nimport torchvision.transforms.functional as TF\n\nfrom renderer import *\nfrom utils import *\nfrom torch.utils.tensorboard import SummaryWriter\nimport datetime\n\nfrom dataLoader import dataset_dict\nfrom dataLoader.styleLoader import getDataLoader\n\nfrom models.styleModules import cal_mse_content_loss, cal_adain_style_loss\n\nimport sys\n\n\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\nrenderer = OctreeRender_trilinear_fast\ndepth_renderer = OctreeRender_trilinear_fast_depth\n\n\nclass SimpleSampler:\n    def __init__(self, total, batch):\n        self.total = total\n        self.batch = batch\n        self.curr = total\n        self.ids = None\n\n    def nextids(self):\n        self.curr+=self.batch\n        if self.curr + self.batch > self.total:\n            self.ids = torch.LongTensor(np.random.permutation(self.total))\n            self.curr = 0\n        return self.ids[self.curr:self.curr+self.batch]\n\ndef InfiniteSampler(n):\n    # i = 0\n    i = n - 1\n    order = np.random.permutation(n)\n    while True:\n        yield order[i]\n        i += 1\n        if i >= n:\n            np.random.seed()\n            order = np.random.permutation(n)\n            i = 0\n\nclass InfiniteSamplerWrapper(torch.utils.data.sampler.Sampler):\n    def __init__(self, num_samples):\n        self.num_samples = num_samples\n\n    def __iter__(self):\n        return iter(InfiniteSampler(self.num_samples))\n\n    def __len__(self):\n        return 2 ** 31\n\n@torch.no_grad()\ndef render_test(args):\n    # init dataset\n    dataset = dataset_dict[args.dataset_name]\n    ndc_ray = args.ndc_ray\n\n    if not os.path.exists(args.ckpt):\n        print('the ckpt path does not exists!!')\n        return\n\n    assert args.style_img is not None, 'Must specify a style 
image!'\n\n    ckpt = torch.load(args.ckpt, map_location=device)\n    kwargs = ckpt['kwargs']\n    kwargs.update({'device': device})\n    tensorf = eval(args.model_name)(**kwargs)\n    tensorf.change_to_feature_mod(args.n_lamb_sh, device)\n    tensorf.change_to_style_mod(device)\n    tensorf.load(ckpt)\n    tensorf.eval()\n    tensorf.rayMarch_weight_thres = args.rm_weight_mask_thre\n\n    logfolder = os.path.dirname(args.ckpt)\n\n    trans = T.Compose([T.Resize(size=(256,256)), T.ToTensor()])\n    style_img = trans(Image.open(args.style_img)).cuda()[None, ...]\n    style_name = Path(args.style_img).stem\n\n    if args.render_train:\n        train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=True)\n        os.makedirs(f'{logfolder}/{args.expname}/imgs_train_all/{style_name}', exist_ok=True)\n        evaluation_feature(train_dataset,tensorf, args, renderer, args.chunk_size, f'{logfolder}/{args.expname}/imgs_train_all/{style_name}',\n                                N_vis=-1, N_samples=-1, white_bg = train_dataset.white_bg, ndc_ray=ndc_ray, style_img=style_img, device=device)\n    \n    if args.render_test:\n        test_dataset = dataset(args.datadir, split='test', downsample=args.downsample_train, is_stack=True)\n        os.makedirs(f'{logfolder}/{args.expname}/imgs_test_all/{style_name}', exist_ok=True)\n        evaluation_feature(test_dataset,tensorf, args, renderer, args.chunk_size, f'{logfolder}/{args.expname}/imgs_test_all/{style_name}',\n                                N_vis=-1, N_samples=-1, white_bg = test_dataset.white_bg, ndc_ray=ndc_ray, style_img=style_img, device=device)\n\n    if args.render_path:\n        test_dataset = dataset(args.datadir, split='test', downsample=args.downsample_train, is_stack=True)\n        c2ws = test_dataset.render_path\n        os.makedirs(f'{logfolder}/{args.expname}/imgs_path_all/{style_name}', exist_ok=True)\n        evaluation_feature_path(test_dataset, tensorf, c2ws, renderer, 
args.chunk_size, f'{logfolder}/{args.expname}/imgs_path_all/{style_name}',\n                N_vis=-1, N_samples=-1, white_bg = test_dataset.white_bg, ndc_ray=ndc_ray, style_img=style_img, device=device)\n\ndef reconstruction(args):\n\n    # init dataset\n    dataset = dataset_dict[args.dataset_name]\n    train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=True)\n    white_bg = train_dataset.white_bg\n    near_far = train_dataset.near_far\n    h_rays, w_rays = train_dataset.img_wh[1], train_dataset.img_wh[0]\n    ndc_ray = args.ndc_ray\n\n    patch_size = args.patch_size # ground truth image patch size when training\n\n    Image.MAX_IMAGE_PIXELS = None  # Disable DecompressionBombError\n    ImageFile.LOAD_TRUNCATED_IMAGES = True # Disable OSError: image file is truncated\n    style_loader = getDataLoader(args.wikiartdir, batch_size=1, sampler=InfiniteSamplerWrapper, \n                    image_side_length=256, num_workers=2)\n    style_iter = iter(style_loader)\n    \n    if args.add_timestamp:\n        logfolder = f'{args.basedir}/{args.expname}{datetime.datetime.now().strftime(\"-%Y%m%d-%H%M%S\")}'\n    else:\n        logfolder = f'{args.basedir}/{args.expname}'\n    \n\n    # init log file\n    os.makedirs(logfolder, exist_ok=True)\n    summary_writer = SummaryWriter(logfolder)\n\n    # init parameters\n    # tensorVM, renderer = init_parameters(args, train_dataset.scene_bbox.to(device), reso_list[0])\n    aabb = train_dataset.scene_bbox.to(device)\n    # TODO: need to update reso_cur\n    reso_cur = N_to_reso(args.N_voxel_init, aabb)\n    nSamples = min(args.nSamples, cal_n_samples(reso_cur,args.step_ratio))\n\n\n    assert args.ckpt is not None, 'A pre-trained checkpoint is required to provide the density field!'\n\n    ckpt = torch.load(args.ckpt, map_location=device)\n    kwargs = ckpt['kwargs']\n    kwargs.update({'device':device})\n    tensorf = eval(args.model_name)(**kwargs)\n    tensorf.change_to_feature_mod(args.n_lamb_sh, 
device)\n    tensorf.load(ckpt)\n    tensorf.change_to_style_mod(device)\n    tensorf.rayMarch_weight_thres = args.rm_weight_mask_thre\n\n    tvreg = TVLoss()\n\n    grad_vars = tensorf.get_optparam_groups_style_mod(args.lr_basis, args.lr_finetune)\n    if args.lr_decay_iters > 0:\n        lr_factor = args.lr_decay_target_ratio**(1/args.lr_decay_iters)\n    else:\n        args.lr_decay_iters = args.n_iters\n        lr_factor = args.lr_decay_target_ratio**(1/args.n_iters)\n\n    print(\"lr decay\", args.lr_decay_target_ratio, args.lr_decay_iters)\n    \n    tensorf.train()\n    optimizer = torch.optim.Adam(grad_vars, betas=(0.9,0.99))\n\n    torch.cuda.empty_cache()\n\n    allrays_stack, allrgbs_stack = train_dataset.all_rays_stack, train_dataset.all_rgbs_stack\n    frameSampler = iter(InfiniteSamplerWrapper(allrays_stack.size(0))) # every next(sampler) returns a frame index\n\n    pbar = tqdm(range(args.n_iters), miniters=args.progress_refresh_rate, file=sys.stdout)\n    for iteration in pbar:\n\n        # get style_img, this style_img has NOT been normalized according to the pretrained VGGmodel\n        style_img = next(style_iter)[0].to(device)\n\n        # randomly sample patch_size*patch_size patch from given frame\n        frame_idx = next(frameSampler)\n        start_h = np.random.randint(0, h_rays-patch_size+1)\n        start_w = np.random.randint(0, w_rays-patch_size+1)\n        if white_bg:\n            # move random sampled patches into center\n            mid_h, mid_w = (h_rays-patch_size+1)/2, (w_rays-patch_size+1)/2\n            if mid_h-start_h>=1:\n                start_h += np.random.randint(0, mid_h-start_h)\n            elif mid_h-start_h<=-1:\n                start_h += np.random.randint(mid_h-start_h, 0)\n            if mid_w-start_w>=1:\n                start_w += np.random.randint(0, mid_w-start_w)\n            elif mid_w-start_w<=-1:\n                start_w += np.random.randint(mid_w-start_w, 0)\n\n        rays_train = 
allrays_stack[frame_idx, start_h:start_h+patch_size, start_w:start_w+patch_size, :]\\\n                            .reshape(-1, 6).to(device)\n        # [patch*patch, 6]\n        \n        rgbs_train = allrgbs_stack[frame_idx, start_h:(start_h+patch_size), \n                                            start_w:(start_w+patch_size), :].to(device)\n        # [patch, patch, 3]\n\n        feature_map, acc_map, style_feature = renderer(rays_train, tensorf, chunk=args.chunk_size, N_samples=nSamples, white_bg = white_bg, \n                                ndc_ray=ndc_ray, render_feature=True, style_img=style_img, device=device, is_train=True)\n\n        feature_map = feature_map.reshape(patch_size, patch_size, 256)[None,...].permute(0,3,1,2)\n        rgb_map = tensorf.decoder(feature_map)\n\n        # rgb_map is decoded from a feature_map trained with VGG-normalized rgb maps, so it is already normalized;\n        # only the ground-truth patch is normalized here.\n        rgbs_train = normalize_vgg(rgbs_train[None,...].permute(0,3,1,2))\n        \n        out_image_feature = tensorf.encoder(rgb_map)\n        content_feature = tensorf.encoder(rgbs_train)\n\n        if white_bg:\n            mask = acc_map.reshape(patch_size, patch_size, 1)[None,...].permute(0,3,1,2)\n            if not (mask>0.5).any(): continue\n            \n            # content loss\n            _mask = F.interpolate(mask, size=content_feature.relu4_1.size()[-2:], mode='bilinear').ge(1e-5)\n            content_loss = cal_mse_content_loss(torch.masked_select(content_feature.relu4_1, _mask), \n                                                torch.masked_select(out_image_feature.relu4_1, _mask))\n            # style loss\n            style_loss = 0.\n            for style_feature, image_feature in zip(style_feature, out_image_feature):\n                _mask = F.interpolate(mask, size=image_feature.size()[-2:], mode='bilinear').ge(1e-5)\n                C = image_feature.size()[1]\n                masked_img_feature = torch.masked_select(image_feature, 
_mask).reshape(1,C,-1)\n                style_loss += cal_adain_style_loss(style_feature, masked_img_feature)\n\n            content_loss *= args.content_weight\n            style_loss *= args.style_weight\n        else:\n            # content loss\n            content_loss = cal_mse_content_loss(content_feature.relu4_1, out_image_feature.relu4_1)\n            # style loss\n            style_loss = 0.\n            for style_feature, image_feature in zip(style_feature, out_image_feature):\n                style_loss += cal_adain_style_loss(style_feature, image_feature)\n\n            content_loss *= args.content_weight\n            style_loss *= args.style_weight\n\n        feature_tv_loss = tvreg(feature_map) * args.featuremap_tv_weight\n        image_tv_loss = tvreg(denormalize_vgg(rgb_map)) * args.image_tv_weight\n\n        total_loss = content_loss + style_loss + feature_tv_loss + image_tv_loss\n\n        optimizer.zero_grad()\n        total_loss.backward()\n        optimizer.step()\n\n        for param_group in optimizer.param_groups:\n            param_group['lr'] = param_group['lr'] * lr_factor\n      \n        # Print the current values of the losses.\n        if iteration%args.progress_refresh_rate==0:\n            summary_writer.add_scalar('train/content_loss', content_loss, global_step=iteration)\n            summary_writer.add_scalar('train/style_loss', style_loss, global_step=iteration)\n            summary_writer.add_scalar('train/feature_tv_loss', feature_tv_loss, global_step=iteration)\n            summary_writer.add_scalar('train/image_tv_loss', image_tv_loss, global_step=iteration)\n            pbar.set_description(\n                f'Iteration {iteration:05d}:'\n                + f' content_loss = {content_loss.item():.2f}'\n                + f' style_loss = {style_loss.item():.2f}'\n            )\n       \n        if iteration % (args.progress_refresh_rate*20) == 0:\n            summary_writer.add_image('output', 
make_grid([denormalize_vgg(rgbs_train).squeeze(), \\\n                                                denormalize_vgg(rgb_map).clamp(0, 1).squeeze(), \\\n                                                TF.resize(style_img, (patch_size,patch_size)).squeeze()],  \n                                                nrow=3, padding=0, normalize=False),\n                                                global_step=iteration)\n        \n    tensorf.save(f'{logfolder}/{args.expname}.th')\n\n\nif __name__ == '__main__':\n\n    torch.set_default_dtype(torch.float32)\n    torch.manual_seed(20211202)\n    np.random.seed(20211202)\n\n    args = config_parser()\n    print(args)\n\n    if args.render_only:\n        render_test(args)\n    else:\n        reconstruction(args)\n\n"
  },
  {
    "path": "utils.py",
    "content": "import cv2,torch\nimport numpy as np\nfrom PIL import Image\nimport torchvision.transforms as T\nimport torch.nn.functional as F\nimport scipy.signal\n\nmse2psnr = lambda x : -10. * torch.log(x) / torch.log(torch.Tensor([10.]))\n\n\ndef visualize_depth_numpy(depth, minmax=None, cmap=cv2.COLORMAP_JET):\n    \"\"\"\n    depth: (H, W)\n    \"\"\"\n\n    x = np.nan_to_num(depth) # change nan to 0\n    if minmax is None:\n        mi = np.min(x[x>0]) # get minimum positive depth (ignore background)\n        ma = np.max(x)\n    else:\n        mi,ma = minmax\n\n    x = (x-mi)/(ma-mi+1e-8) # normalize to 0~1\n    x = (255*x).astype(np.uint8)\n    x_ = cv2.applyColorMap(x, cmap)\n    return x_, [mi,ma]\n\ndef init_log(log, keys):\n    for key in keys:\n        log[key] = torch.tensor([0.0], dtype=float)\n    return log\n\ndef visualize_depth(depth, minmax=None, cmap=cv2.COLORMAP_JET):\n    \"\"\"\n    depth: (H, W)\n    \"\"\"\n    if type(depth) is not np.ndarray:\n        depth = depth.cpu().numpy()\n\n    x = np.nan_to_num(depth) # change nan to 0\n    if minmax is None:\n        mi = np.min(x[x>0]) # get minimum positive depth (ignore background)\n        ma = np.max(x)\n    else:\n        mi,ma = minmax\n\n    x = (x-mi)/(ma-mi+1e-8) # normalize to 0~1\n    x = (255*x).astype(np.uint8)\n    x_ = Image.fromarray(cv2.applyColorMap(x, cmap))\n    x_ = T.ToTensor()(x_)  # (3, H, W)\n    return x_, [mi,ma]\n\ndef N_to_reso(n_voxels, bbox):\n    xyz_min, xyz_max = bbox\n    dim = len(xyz_min)\n    voxel_size = ((xyz_max - xyz_min).prod() / n_voxels).pow(1 / dim)\n    return ((xyz_max - xyz_min) / voxel_size).long().tolist()\n\ndef cal_n_samples(reso, step_ratio=0.5):\n    return int(np.linalg.norm(reso)/step_ratio)\n\n\n\n\n__LPIPS__ = {}\ndef init_lpips(net_name, device):\n    assert net_name in ['alex', 'vgg']\n    import lpips\n    print(f'init_lpips: lpips_{net_name}')\n    return lpips.LPIPS(net=net_name, version='0.1').eval().to(device)\n\ndef 
rgb_lpips(np_gt, np_im, net_name, device):\n    if net_name not in __LPIPS__:\n        __LPIPS__[net_name] = init_lpips(net_name, device)\n    gt = torch.from_numpy(np_gt).permute([2, 0, 1]).contiguous().to(device)\n    im = torch.from_numpy(np_im).permute([2, 0, 1]).contiguous().to(device)\n    return __LPIPS__[net_name](gt, im, normalize=True).item()\n\n\ndef findItem(items, target):\n    for one in items:\n        if one[:len(target)]==target:\n            return one\n    return None\n\n\n''' Evaluation metrics (ssim, lpips)\n'''\ndef rgb_ssim(img0, img1, max_val,\n             filter_size=11,\n             filter_sigma=1.5,\n             k1=0.01,\n             k2=0.03,\n             return_map=False):\n    # Modified from https://github.com/google/mipnerf/blob/16e73dfdb52044dcceb47cda5243a686391a6e0f/internal/math.py#L58\n    assert len(img0.shape) == 3\n    assert img0.shape[-1] == 3\n    assert img0.shape == img1.shape\n\n    # Construct a 1D Gaussian blur filter.\n    hw = filter_size // 2\n    shift = (2 * hw - filter_size + 1) / 2\n    f_i = ((np.arange(filter_size) - hw + shift) / filter_sigma)**2\n    filt = np.exp(-0.5 * f_i)\n    filt /= np.sum(filt)\n\n    # Blur in x and y (faster than the 2D convolution).\n    def convolve2d(z, f):\n        return scipy.signal.convolve2d(z, f, mode='valid')\n\n    filt_fn = lambda z: np.stack([\n        convolve2d(convolve2d(z[...,i], filt[:, None]), filt[None, :])\n        for i in range(z.shape[-1])], -1)\n    mu0 = filt_fn(img0)\n    mu1 = filt_fn(img1)\n    mu00 = mu0 * mu0\n    mu11 = mu1 * mu1\n    mu01 = mu0 * mu1\n    sigma00 = filt_fn(img0**2) - mu00\n    sigma11 = filt_fn(img1**2) - mu11\n    sigma01 = filt_fn(img0 * img1) - mu01\n\n    # Clip the variances and covariances to valid values.\n    # Variance must be non-negative:\n    sigma00 = np.maximum(0., sigma00)\n    sigma11 = np.maximum(0., sigma11)\n    sigma01 = np.sign(sigma01) * np.minimum(\n        np.sqrt(sigma00 * sigma11), np.abs(sigma01))\n    
c1 = (k1 * max_val)**2\n    c2 = (k2 * max_val)**2\n    numer = (2 * mu01 + c1) * (2 * sigma01 + c2)\n    denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2)\n    ssim_map = numer / denom\n    ssim = np.mean(ssim_map)\n    return ssim_map if return_map else ssim\n\n\nimport torch.nn as nn\nclass TVLoss(nn.Module):\n    def __init__(self):\n        super(TVLoss,self).__init__()\n\n    def forward(self,x):\n        batch_size = x.size()[0]\n        h_x = x.size()[2]\n        w_x = x.size()[3]\n\n        if w_x==1:\n            count_h = self._tensor_size(x[:,:,1:,:])\n            h_tv = torch.pow((x[:,:,1:,:]-x[:,:,:h_x-1,:]),2).sum()\n            return 2*(h_tv/count_h)/batch_size\n\n        if h_x==1:\n            count_w = self._tensor_size(x[:,:,:,1:])\n            w_tv = torch.pow((x[:,:,:,1:]-x[:,:,:,:w_x-1]),2).sum()\n            return 2*(w_tv/count_w)/batch_size\n\n        count_h = self._tensor_size(x[:,:,1:,:])\n        count_w = self._tensor_size(x[:,:,:,1:])\n        h_tv = torch.pow((x[:,:,1:,:]-x[:,:,:h_x-1,:]),2).sum()\n        w_tv = torch.pow((x[:,:,:,1:]-x[:,:,:,:w_x-1]),2).sum()\n        return 2*(h_tv/count_h+w_tv/count_w)/batch_size\n\n    def _tensor_size(self,t):\n        return t.size()[1]*t.size()[2]*t.size()[3]\n\n\n\nimport plyfile\nimport skimage.measure\ndef convert_sdf_samples_to_ply(\n    pytorch_3d_sdf_tensor,\n    ply_filename_out,\n    bbox,\n    level=0.5,\n    offset=None,\n    scale=None,\n):\n    \"\"\"\n    Convert sdf samples to a .ply mesh using marching cubes\n\n    :param pytorch_3d_sdf_tensor: a torch.FloatTensor of shape (n,n,n)\n    :param ply_filename_out: string, path of the file to save to\n    :param bbox: array of shape (2, 3), the min and max corners of the voxel grid\n    :param level: float, iso-surface value passed to marching cubes\n    :param offset: optional offset subtracted from the mesh points (applied after scale)\n    :param scale: optional factor the mesh points are divided by\n\n    This function is adapted from: https://github.com/RobotLocomotion/spartan\n    \"\"\"\n\n    numpy_3d_sdf_tensor = pytorch_3d_sdf_tensor.numpy()\n    voxel_size = list((bbox[1]-bbox[0]) / 
np.array(pytorch_3d_sdf_tensor.shape))\n\n    verts, faces, normals, values = skimage.measure.marching_cubes(\n        numpy_3d_sdf_tensor, level=level, spacing=voxel_size\n    )\n    faces = faces[...,::-1] # inverse face orientation\n\n    # transform from voxel coordinates to camera coordinates\n    # note x and y are flipped in the output of marching_cubes\n    mesh_points = np.zeros_like(verts)\n    mesh_points[:, 0] = bbox[0,0] + verts[:, 0]\n    mesh_points[:, 1] = bbox[0,1] + verts[:, 1]\n    mesh_points[:, 2] = bbox[0,2] + verts[:, 2]\n\n    # apply additional offset and scale\n    if scale is not None:\n        mesh_points = mesh_points / scale\n    if offset is not None:\n        mesh_points = mesh_points - offset\n\n    # try writing to the ply file\n\n    num_verts = verts.shape[0]\n    num_faces = faces.shape[0]\n\n    verts_tuple = np.zeros((num_verts,), dtype=[(\"x\", \"f4\"), (\"y\", \"f4\"), (\"z\", \"f4\")])\n\n    for i in range(0, num_verts):\n        verts_tuple[i] = tuple(mesh_points[i, :])\n\n    faces_building = []\n    for i in range(0, num_faces):\n        faces_building.append(((faces[i, :].tolist(),)))\n    faces_tuple = np.array(faces_building, dtype=[(\"vertex_indices\", \"i4\", (3,))])\n\n    el_verts = plyfile.PlyElement.describe(verts_tuple, \"vertex\")\n    el_faces = plyfile.PlyElement.describe(faces_tuple, \"face\")\n\n    ply_data = plyfile.PlyData([el_verts, el_faces])\n    print(\"saving mesh to %s\" % (ply_filename_out))\n    ply_data.write(ply_filename_out)\n\n\n# Point cloud operations\n# import pytorch3d\n# import pytorch3d.transforms as transforms\n# from pytorch3d.structures import Pointclouds\n# from pytorch3d.renderer import (\n#     look_at_view_transform,\n#     FoVOrthographicCameras, \n#     PointsRasterizationSettings,\n#     PointsRenderer,\n#     PointsRasterizer,\n#     AlphaCompositor,\n# )\n# import imageio\n\n# def construct_points_coordinates(rays, depth):\n#     '''\n#     Construct points' coordinates of 
a point cloud, every point corresponds to \n#         a point on one ray with specified depth.\n\n#     Args:\n#         rays: [n_rays, 6]\n#         depth: [n_rays] \n    \n#     Return:\n#         point_cloud: [n_rays, 3]\n#     '''\n#     rays_o, rays_d = rays[:, :3], rays[:, 3:6]\n\n#     points_coordinates = rays_o + rays_d * depth[...,None]\n\n#     return points_coordinates\n\n# def plot_point_cloud(points, rgbs, prtx=''):\n#     '''\n#     Args:\n#         points: coordinates of points [n, 3]\n#         rgbs: color of points [n, 3]\n#     '''\n\n#     point_cloud = Pointclouds(points=[points], features=[rgbs])\n\n#     images = []\n#     for degree in range(0,60,2):\n\n#         R, T = look_at_view_transform(20, 10, degree)\n#         cameras = FoVOrthographicCameras(device='cuda', R=R, T=T, znear=0.01)\n\n#         raster_settings = PointsRasterizationSettings(\n#             image_size=512, \n#             radius = 0.003,\n#             points_per_pixel = 10\n#         )\n\n#         rasterizer = PointsRasterizer(cameras=cameras, raster_settings=raster_settings)\n#         renderer = PointsRenderer(\n#             rasterizer=rasterizer,\n#             compositor=AlphaCompositor()\n#         )\n        \n#         image = renderer(point_cloud).squeeze()\n#         image = (image.detach().cpu().numpy() * 255).astype('uint8')\n#         images.append(image)\n\n#         # imageio.imwrite(f'./visualization/{prtx}pointcloud_{degree:02d}.png', image)\n\n#     imageio.mimwrite(f'./visualization/{prtx}pointcloud_video.mp4', np.stack(images), fps=10, quality=8)\n\n# def knn_loss(p1, f1, p2, f2, thres=0.0005):\n#     '''\n#     Args:\n#         p1, p2: points [p1,3] [p2,3]\n#         f1, f2: features [p1,c] [p2,c]\n#         thres: if dist > thres then the two points are not adjacent\n#     '''\n#     dists, idx, _ = pytorch3d.ops.knn_points(p1[None,...], p2[None,...], K=1)\n#     # [N=1, P1, K=1]\n\n#     match_feature = f2[idx.squeeze()] # [p1, C]\n\n#     
square_err = torch.sum((f1 - match_feature)**2, dim=-1) # [p1]\n\n#     return torch.masked_select(square_err, dists.squeeze()<thres).mean()\n\n\n# def random_rotate_rays(rays, ds_rays, depth, max_degree=0.2, device='cuda'):\n#     '''\n#     Randomly rotate a batch of rays around their main view point.\n\n#     Args:\n#         rays: [n_rays, 6]\n#         ds_rays: [n_rays/16, 6]\n#         depth: [n_rays] \n    \n#     Return:\n#         rotate_rays: [n_rays, 6]\n#     '''\n#     mean_ray = torch.mean(rays, dim=0) #[6]\n#     mean_depth = torch.mean(depth) #[1]\n#     main_view_point = mean_ray[:3] + mean_depth * mean_ray[3:] #[3]\n\n#     # construct transformation\n#     t_forward = transforms.Translate(*(-main_view_point), device=device)\n#     t_back = transforms.Translate(*main_view_point, device=device)\n\n#     axis = torch.nn.functional.normalize(torch.rand((3,), device=device), dim=0)\n#     angle = torch.rand((1,), device=device) * max_degree\n\n#     R = transforms.axis_angle_to_matrix(axis*angle)\n#     rotation = transforms.Rotate(R, device=device)\n\n#     transform = t_forward.compose(rotation).compose(t_back)\n\n#     rays_o = transform.transform_points(rays[:,:3])\n#     rays_d = transform.transform_points(rays[:,3:])\n\n#     ds_rays_o = transform.transform_points(ds_rays[:,:3])\n#     ds_rays_d = transform.transform_points(ds_rays[:,3:])\n\n#     return torch.cat([rays_o, rays_d], dim=-1), torch.cat([ds_rays_o, ds_rays_d], dim=-1)\n\n\ndef get_checkerboard(fg, n=8):\n    \"\"\"\n    Build a random two-color checkerboard background with n x n tiles.\n\n    Args:\n        fg: [B, C, H, W], only its shape, dtype and device are used;\n            H and W must be divisible by n, otherwise the final view() fails\n        n: number of tiles along each image axis\n\n    Return:\n        bg: [B, 3, H, W]\n    \"\"\"\n    B, _, H, W = fg.shape\n    # two random RGB colors per batch element\n    colors = torch.rand(2, B, 3, 1, 1, 1, 1, dtype=fg.dtype, device=fg.device)\n    h = H // n\n    w = W // n\n    bg = torch.ones(B, 3, n, h, n, w, dtype=fg.dtype, device=fg.device) * colors[0]\n    # tiles whose row and column indices have different parity get the second color\n    bg[:, :, ::2, :, 1::2] = colors[1]\n    bg[:, :, 1::2, :, ::2] = colors[1]\n    bg = bg.view(B, 3, H, W)\n    return bg\n"
  }
]