[
  {
    "path": ".gitignore",
    "content": ".idea\n__pycache__\ndata.zip\ndata\n"
  },
  {
    "path": "LICENSE",
    "content": "Copyright 2022 Michael A. Alcorn\n\nPermission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the \"Software\"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.\n"
  },
  {
    "path": "README.md",
    "content": "# PyTorch NeRF and pixelNeRF\n\n**NeRF**: [![Open NeRF in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1oRnnlF-2YqCDIzoc-uShQm8_yymLKiqr)\n\n**Tiny NeRF**: [![Open Tiny NeRF in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1ntlbzQ121-E1BSa5EKvAyai6SMG4cylj)\n\n**pixelNeRF**: [![Open pixelNeRF in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1VEEy4VOVoQTQKo4oG3nWcfKAXjC_0fFt)\n\nThis repository contains minimal PyTorch implementations of the NeRF model described in \"[NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis](https://arxiv.org/abs/2003.08934)\" and the pixelNeRF model described in [\"pixelNeRF: Neural Radiance Fields from One or Few Images\"](https://arxiv.org/abs/2012.02190).\nWhile there are other PyTorch implementations out there (e.g., [this one](https://github.com/krrish94/nerf-pytorch) and [this one](https://github.com/yenchenlin/nerf-pytorch) for NeRF, and [the authors' official implementation](https://github.com/sxyu/pixel-nerf) for pixelNeRF), I personally found them somewhat difficult to follow, so I decided to do a complete rewrite of NeRF myself.\nI tried to stay as close to the authors' text as possible, and I added comments in the code referring back to the relevant sections/equations in the paper.\nThe final result is a tight 355 lines of heavily commented code (301 sloc—\"source lines of code\"—on GitHub) all contained in [a single file](run_nerf.py). 
For comparison, [this PyTorch implementation](https://github.com/krrish94/nerf-pytorch) has approximately 970 sloc spread across several files, while [this PyTorch implementation](https://github.com/yenchenlin/nerf-pytorch) has approximately 905 sloc.\n\n[`run_tiny_nerf.py`](run_tiny_nerf.py) trains a simplified NeRF model inspired by the \"[Tiny NeRF](https://colab.research.google.com/github/bmild/nerf/blob/master/tiny_nerf.ipynb)\" example provided by the NeRF authors.\nThis NeRF model does not use fine sampling and the MLP is smaller, but the code is otherwise identical to the full model code.\nAt only 153 sloc, it might be a good place to start for people who are completely new to NeRF.\nIf you prefer your code more object-oriented, check out [`run_nerf_alt.py`](run_nerf_alt.py) and [`run_tiny_nerf_alt.py`](run_tiny_nerf_alt.py).\n\nA Colab notebook for the full model can be found [here](https://colab.research.google.com/drive/1oRnnlF-2YqCDIzoc-uShQm8_yymLKiqr?usp=sharing), while a notebook for the tiny model can be found [here](https://colab.research.google.com/drive/1ntlbzQ121-E1BSa5EKvAyai6SMG4cylj?usp=sharing).\nThe [`generate_nerf_dataset.py`](generate_nerf_dataset.py) script was used to generate the training data for the ShapeNet car (see \"[Generating the ShapeNet datasets](#generating-the-shapenet-datasets)\" for additional details).\n\nFor the following test view:\n\n![](test_view.png)\n\n[`run_nerf.py`](run_nerf.py) generated the following after 20,100 iterations (a few hours on a P100 GPU):\n\n**Loss**: 0.00022201683896128088\n\n![](nerf.png)\n\nwhile [`run_tiny_nerf.py`](run_tiny_nerf.py) generated the following after 19,600 iterations (~35 minutes on a P100 GPU):\n\n**Loss**: 0.0004151524917688221\n\n![](tiny_nerf.png)\n\nThe advantages of streamlining NeRF's code become readily apparent when trying to extend NeRF.\nFor example, [training a pixelNeRF model](run_pixelnerf.py) only required making a few changes to [`run_nerf.py`](run_nerf.py), bringing 
it to 368 sloc (notebook [here](https://colab.research.google.com/drive/1VEEy4VOVoQTQKo4oG3nWcfKAXjC_0fFt?usp=sharing)).\nFor comparison, [the official pixelNeRF implementation](https://github.com/sxyu/pixel-nerf) has approximately 1,300 pixelNeRF-specific (i.e., not related to the image encoder or dataset) sloc spread across several files.\nThe [`generate_pixelnerf_dataset.py`](generate_pixelnerf_dataset.py) script was used to generate the training data of ShapeNet cars (see \"[Generating the ShapeNet datasets](#generating-the-shapenet-datasets)\" for additional details).\n\nFor the following source object and view:\n\n![](pixelnerf_src.png)\n\nand target view:\n\n![](pixelnerf_tgt.png)\n\n[`run_pixelnerf.py`](run_pixelnerf.py) generated the following after 73,243 iterations (~12 hours on a P100 GPU; the full pixelNeRF model was trained for 400,000 iterations, which took six days):\n\n**Loss**: 0.004468636587262154\n\n![](pixelnerf.png)\n\nThe \"smearing\" is an artifact caused by the bounding box sampling method.\n\n## Generating the ShapeNet datasets\n\n1) Download the data (the ShapeNet server is pretty slow, so this will take a while):\n\n```bash\nSHAPENET_BASE_DIR=<path/to/your/shapenet/root>\nnohup wget --quiet -P ${SHAPENET_BASE_DIR} http://shapenet.cs.stanford.edu/shapenet/obj-zip/ShapeNetCore.v2.zip > shapenet.log &\n```\n\n2) Unzip the data:\n\n```bash\ncd ${SHAPENET_BASE_DIR}\nnohup unzip -q ShapeNetCore.v2.zip > shapenet.log &\n```\n\n3) ***After*** the file is done unzipping, remove the ZIP:\n\n```bash\nrm ShapeNetCore.v2.zip\n```\n\n4) Change the `SHAPENET_DIR` variable in [`generate_nerf_dataset.py`](generate_nerf_dataset.py) and [`generate_pixelnerf_dataset.py`](generate_pixelnerf_dataset.py) to `<path/to/your/shapenet/root>/ShapeNetCore.v2`.\n"
  },
  {
    "path": "generate_nerf_dataset.py",
    "content": "import numpy as np\n\nfrom pyrr import Matrix44\nfrom renderer import gen_rotation_matrix_from_cam_pos, Renderer\nfrom renderer_settings import *\n\nSHAPENET_DIR = \"/run/media/airalcorn2/MiQ BIG/ShapeNetCore.v2\"\n\n\ndef main():\n    # Set up the renderer.\n    renderer = Renderer(\n        camera_distance=CAMERA_DISTANCE,\n        angle_of_view=ANGLE_OF_VIEW,\n        dir_light=DIR_LIGHT,\n        dif_int=DIF_INT,\n        amb_int=AMB_INT,\n        default_width=WINDOW_SIZE,\n        default_height=WINDOW_SIZE,\n        cull_faces=CULL_FACES,\n    )\n    img_size = 100\n    # Calculate focal length in pixel units. This is just geometry. See:\n    # https://en.wikipedia.org/wiki/Angle_of_view#Derivation_of_the_angle-of-view_formula.\n    focal = (img_size / 2) / np.tan(np.radians(ANGLE_OF_VIEW) / 2)\n\n    # Load the ShapeNet car object.\n    obj = \"66bdbc812bd0a196e194052f3f12cb2e\"\n    cat = \"02958343\"\n    obj_mtl_path = f\"{SHAPENET_DIR}/{cat}/{obj}/models/model_normalized\"\n    renderer.set_up_obj(f\"{obj_mtl_path}.obj\", f\"{obj_mtl_path}.mtl\")\n\n    # Generate car renders using random camera locations.\n    init_cam_pos = np.array([0, 0, CAMERA_DISTANCE])\n    target = np.zeros(3)\n    up = np.array([0.0, 1.0, 0.0])\n    samps = 800\n    imgs = []\n    poses = []\n    for idx in range(samps):\n        # See: https://stats.stackexchange.com/a/7984/81836.\n        xyz = np.random.normal(size=3)\n        xyz /= np.linalg.norm(xyz)\n        R = gen_rotation_matrix_from_cam_pos(xyz)\n        eye = tuple((R @ init_cam_pos).flatten())\n        look_at = Matrix44.look_at(eye, target, up)\n        renderer.prog[\"VP\"].write(\n            (look_at @ renderer.perspective).astype(\"f4\").tobytes()\n        )\n        renderer.prog[\"cam_pos\"].value = eye\n\n        image = renderer.render(0.5, 0.5, 0.5).resize((img_size, img_size))\n        imgs.append(np.array(image))\n\n        pose = np.eye(4)\n        pose[:3, :3] = np.array(look_at[:3, 
:3])\n        pose[:3, 3] = -look_at[:3, :3] @ look_at[3, :3]\n        poses.append(pose)\n\n    imgs = np.stack(imgs)\n    poses = np.stack(poses)\n    np.savez(\n        f\"{obj}.npz\",\n        images=imgs,\n        poses=poses,\n        focal=focal,\n        camera_distance=CAMERA_DISTANCE,\n    )\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "generate_pixelnerf_dataset.py",
    "content": "import numpy as np\nimport os\nimport sys\n\nfrom pyrr import Matrix44\nfrom renderer import gen_rotation_matrix_from_cam_pos, Renderer\nfrom renderer_settings import *\n\nSHAPENET_DIR = \"/run/media/airalcorn2/MiQ BIG/ShapeNetCore.v2\"\n\n\ndef main():\n    # Set up the renderer.\n    renderer = Renderer(\n        camera_distance=CAMERA_DISTANCE,\n        angle_of_view=ANGLE_OF_VIEW,\n        dir_light=DIR_LIGHT,\n        dif_int=DIF_INT,\n        amb_int=AMB_INT,\n        default_width=WINDOW_SIZE,\n        default_height=WINDOW_SIZE,\n        cull_faces=CULL_FACES,\n    )\n    # See Section 5.1.1.\n    img_size = 128\n    # Calculate focal length in pixel units. This is just geometry. See:\n    # https://en.wikipedia.org/wiki/Angle_of_view#Derivation_of_the_angle-of-view_formula.\n    focal = (img_size / 2) / np.tan(np.radians(ANGLE_OF_VIEW) / 2)\n\n    # Generate car renders using random camera locations.\n    init_cam_pos = np.array([0, 0, CAMERA_DISTANCE])\n    target = np.zeros(3)\n    up = np.array([0.0, 1.0, 0.0])\n    # See Section 5.1.1.\n    samps = 50\n    z_len = len(str(samps - 1))\n    data_dir = \"data\"\n    poses = []\n    os.mkdir(data_dir)\n\n    # Car category.\n    cat = \"02958343\"\n    objs = os.listdir(f\"{SHAPENET_DIR}/{cat}\")\n    used_objs = []\n    for obj in objs:\n        # Load the ShapeNet object.\n        obj_mtl_path = f\"{SHAPENET_DIR}/{cat}/{obj}/models/model_normalized\"\n        try:\n            renderer.set_up_obj(f\"{obj_mtl_path}.obj\", f\"{obj_mtl_path}.mtl\")\n            sys.stderr.flush()\n        except OSError:\n            print(f\"{SHAPENET_DIR}/{cat}/{obj} is empty.\", flush=True)\n            continue\n\n        except FloatingPointError:\n            print(f\"{SHAPENET_DIR}/{cat}/{obj} divides by zero.\", flush=True)\n\n        obj_dir = f\"{data_dir}/{obj}\"\n        os.mkdir(obj_dir)\n        obj_poses = []\n        for samp_idx in range(samps):\n            # See: 
https://stats.stackexchange.com/a/7984/81836.\n            xyz = np.random.normal(size=3)\n            xyz /= np.linalg.norm(xyz)\n            R = gen_rotation_matrix_from_cam_pos(xyz)\n            eye = tuple((R @ init_cam_pos).flatten())\n            look_at = Matrix44.look_at(eye, target, up)\n            renderer.prog[\"VP\"].write(\n                (look_at @ renderer.perspective).astype(\"f4\").tobytes()\n            )\n            renderer.prog[\"cam_pos\"].value = eye\n\n            image = renderer.render(0.5, 0.5, 0.5).resize((img_size, img_size))\n            np.save(f\"{obj_dir}/{str(samp_idx).zfill(z_len)}.npy\", np.array(image))\n\n            pose = np.eye(4)\n            pose[:3, :3] = np.array(look_at[:3, :3])\n            pose[:3, 3] = -look_at[:3, :3] @ look_at[3, :3]\n            obj_poses.append(pose)\n\n        obj_poses = np.stack(obj_poses)\n        poses.append(obj_poses)\n        renderer.release_obj()\n        used_objs.append(obj)\n\n    poses = np.stack(poses)\n    np.savez(\n        f\"{data_dir}/poses.npz\",\n        poses=poses,\n        focal=focal,\n        camera_distance=CAMERA_DISTANCE,\n    )\n    with open(f\"{data_dir}/objs.txt\", \"w\") as f:\n        print(\"\\n\".join(used_objs), file=f)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "image_encoder.py",
    "content": "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nfrom torchvision.models import resnet34\n\n\nclass ImageEncoder(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.resnet = resnet34(True)\n\n    def forward(self, x):\n        # Extract feature pyramid from image. See Section 4.1., Section B.1 in the\n        # Supplementary Materials, and: https://github.com/sxyu/pixel-nerf/blob/master/src/model/encoder.py.\n        x = self.resnet.conv1(x)\n        x = self.resnet.bn1(x)\n        feats1 = self.resnet.relu(x)\n\n        feats2 = self.resnet.layer1(self.resnet.maxpool(feats1))\n        feats3 = self.resnet.layer2(feats2)\n        feats4 = self.resnet.layer3(feats3)\n\n        latents = [feats1, feats2, feats3, feats4]\n        latent_sz = latents[0].shape[-2:]\n        for i in range(len(latents)):\n            latents[i] = F.interpolate(\n                latents[i], latent_sz, mode=\"bilinear\", align_corners=True\n            )\n\n        latents = torch.cat(latents, dim=1)\n        return latents\n"
  },
  {
    "path": "pixelnerf_dataset.py",
    "content": "import numpy as np\nimport torch\n\nfrom torch.utils.data import Dataset\n\n\nclass PixelNeRFDataset(Dataset):\n    def __init__(\n        self,\n        data_dir,\n        num_iters,\n        test_obj_idx,\n        test_source_pose_idx,\n        test_target_pose_idx,\n    ):\n        self.data_dir = data_dir\n        self.N = num_iters\n        with open(f\"{data_dir}/objs.txt\") as f:\n            self.objs = f.read().split(\"\\n\")[:-1]\n\n        self.test_obj_idx = test_obj_idx\n        self.test_source_pose_idx = test_source_pose_idx\n        self.test_target_pose_idx = test_target_pose_idx\n        data = np.load(f\"{data_dir}/poses.npz\")\n        self.poses = poses = data[\"poses\"]\n        (n_objs, n_poses) = poses.shape[:2]\n        self.z_len = len(str(n_poses - 1))\n        self.poses = torch.Tensor(poses)\n\n        self.channel_means = torch.Tensor([0.485, 0.456, 0.406])\n        self.channel_stds = torch.Tensor([0.229, 0.224, 0.225])\n\n        samp_img = np.load(f\"{data_dir}/{self.objs[0]}/{str(0).zfill(self.z_len)}.npy\")\n        img_size = samp_img.shape[0]\n        self.pix_idxs = np.arange(img_size ** 2)\n        xs = torch.arange(img_size) - (img_size / 2 - 0.5)\n        ys = torch.arange(img_size) - (img_size / 2 - 0.5)\n        (xs, ys) = torch.meshgrid(xs, -ys, indexing=\"xy\")\n        focal = float(data[\"focal\"])\n        pixel_coords = torch.stack([xs, ys, torch.full_like(xs, -focal)], dim=-1)\n        camera_coords = pixel_coords / focal\n        self.init_ds = camera_coords\n        self.camera_distance = camera_distance = float(data[\"camera_distance\"])\n        self.init_o = torch.Tensor(np.array([0, 0, camera_distance]))\n        # tan(theta) = opposite / adjacent.\n        self.scale = (img_size / 2) / focal\n\n    def __len__(self):\n        return self.N\n\n    def __getitem__(self, idx):\n        obj_idx = np.random.randint(self.poses.shape[0])\n        obj = self.objs[obj_idx]\n        obj_dir = 
f\"{self.data_dir}/{obj}\"\n\n        source_pose_idx = np.random.randint(self.poses.shape[1])\n        if obj_idx == self.test_obj_idx:\n            while source_pose_idx == self.test_source_pose_idx:\n                source_pose_idx = np.random.randint(self.poses.shape[1])\n\n        source_img_f = f\"{obj_dir}/{str(source_pose_idx).zfill(self.z_len)}.npy\"\n        source_image = torch.Tensor(np.load(source_img_f) / 255)\n        source_image = (source_image - self.channel_means) / self.channel_stds\n        source_pose = self.poses[obj_idx, source_pose_idx]\n        source_R = source_pose[:3, :3]\n\n        target_pose_idx = np.random.randint(self.poses.shape[1])\n        if obj_idx == self.test_obj_idx:\n            while (target_pose_idx == self.test_source_pose_idx) or (\n                target_pose_idx == self.test_target_pose_idx\n            ):\n                target_pose_idx = np.random.randint(self.poses.shape[1])\n\n        target_img_f = f\"{obj_dir}/{str(target_pose_idx).zfill(self.z_len)}.npy\"\n        target_image = np.load(target_img_f)\n        not_gray_pix = np.argwhere((target_image == 128).sum(-1) != 3)\n        top_row = not_gray_pix[:, 0].min()\n        bottom_row = not_gray_pix[:, 0].max()\n        left_col = not_gray_pix[:, 1].min()\n        right_col = not_gray_pix[:, 1].max()\n        bbox = (top_row, left_col, bottom_row, right_col)\n\n        target_image = np.load(target_img_f) / 255\n        target_pose = self.poses[obj_idx, target_pose_idx]\n        target_R = target_pose[:3, :3]\n\n        R = source_R.T @ target_R\n\n        return (source_image, torch.Tensor(R), torch.Tensor(target_image), bbox)\n"
  },
  {
    "path": "renderer.py",
    "content": "import logging\nimport moderngl\nimport numpy as np\n\nfrom PIL import Image, ImageOps\nfrom pyrr import Matrix44\nfrom scipy.spatial.transform import Rotation\n\nYAW_PITCH_ROLL = {\"yaw\", \"pitch\", \"roll\"}\nAZIM_ELEV_IN_PLANE = {\"azimuth\", \"elevation\", \"in_plane\"}\nTOL = 1e-6\n\n\ndef gen_rotation_matrix_from_cam_pos(xyz, in_plane=0.0):\n    assert 1 - np.linalg.norm(xyz) < TOL\n\n    cam_from = xyz\n    cam_to = np.zeros(3)\n    tmp = np.array([0.0, 1.0, 0.0])\n\n    diff = cam_from - cam_to\n    forward = diff / np.linalg.norm(diff)\n    crossed = np.cross(tmp, forward)\n    right = crossed / np.linalg.norm(crossed)\n    up = np.cross(forward, right)\n\n    R = np.stack([right, up, forward])\n    R_in_plane = Rotation.from_euler(\"Z\", in_plane).as_matrix()\n    return R_in_plane @ R\n\n\n\ndef gen_rotation_matrix_from_azim_elev_in_plane(\n    azimuth=0.0, elevation=0.0, in_plane=0.0\n):\n    # See: https://www.scratchapixel.com/lessons/mathematics-physics-for-computer-graphics/lookat-function.\n    y = np.sin(elevation)\n    radius = np.cos(elevation)\n    x = radius * np.sin(azimuth)\n    z = radius * np.cos(azimuth)\n\n    cam_from = np.array([x, y, z])\n    cam_to = np.zeros(3)\n    tmp = np.array([0.0, 1.0, 0.0])\n\n    diff = cam_from - cam_to\n    forward = diff / np.linalg.norm(diff)\n    crossed = np.cross(tmp, forward)\n    right = crossed / np.linalg.norm(crossed)\n    up = np.cross(forward, right)\n\n    R = np.stack([right, up, forward])\n    R_in_plane = Rotation.from_euler(\"Z\", in_plane).as_matrix()\n    return R_in_plane @ R\n\n\ndef parse_obj_file(input_obj):\n    \"\"\"Parse wavefront .obj file.\n\n    :param input_obj:\n    :return: Dictionary of NumPy arrays with shape (3 * num_faces, 8). 
Each row contains\n    (1) the coordinates of a vertex of a face, (2) the vertex's normal vector, and (3)\n    the texture coordinates for the vertex.\n    \"\"\"\n    data = {\"v\": [], \"vn\": [], \"vt\": []}\n    packed_arrays = {}\n    obj_f = open(input_obj)\n    current_mtl = None\n    min_vec = np.full(3, np.inf)\n    max_vec = np.full(3, -np.inf)\n    empty_vt = np.array([0.0, 0.0, 0.0])\n    for line in obj_f:\n        line = line.strip()\n        if line == \"\":\n            continue\n\n        parts = line.split()\n        elem_type = parts[0]\n        if elem_type in data:\n            vals = np.array(parts[1:4], dtype=np.float32)\n            if elem_type == \"v\":\n                min_vec = np.minimum(min_vec, vals)\n                max_vec = np.maximum(max_vec, vals)\n            elif elem_type == \"vn\":\n                vals /= np.linalg.norm(vals)\n            elif elem_type == \"vt\":\n                if len(vals) < 3:\n                    vals = np.array(list(vals) + [0.0], dtype=np.float32)\n\n            data[elem_type].append(vals)\n        elif elem_type == \"f\":\n            f = parts[1:4]\n            for fv in f:\n                (v, vt, vn) = fv.split(\"/\")\n\n                # Convert to zero-based indexing.\n                v = int(v) - 1\n                vn = int(vn) - 1\n                vt = int(vt) - 1 if vt else -1\n\n                if vt == -1:\n                    row = np.concatenate((data[\"v\"][v], data[\"vn\"][vn], empty_vt))\n                else:\n                    row = np.concatenate((data[\"v\"][v], data[\"vn\"][vn], data[\"vt\"][vt]))\n\n                packed_arrays[current_mtl].append(row)\n        elif elem_type == \"usemtl\":\n            current_mtl = parts[1]\n            if current_mtl not in packed_arrays:\n                packed_arrays[current_mtl] = []\n        elif elem_type == \"l\":\n            if current_mtl in packed_arrays:\n                packed_arrays.pop(current_mtl)\n\n    max_pos_vec = 
max_vec - min_vec\n    max_pos_val = max(max_pos_vec)\n    max_pos_vec_norm = max_pos_vec / max_pos_val\n    for (sub_obj, packed_array) in packed_arrays.items():\n        # z-coordinate of texture is always zero (if present).\n        packed_array = np.stack(packed_array)[:, :8]\n        original_vertices = packed_array[:, :3].copy()\n\n        # All coordinates greater than or equal to zero.\n        original_vertices -= min_vec\n        # All coordinates between zero and one.\n        original_vertices /= max_pos_val\n        # All coordinates between zero and two.\n        original_vertices *= 2\n        # All coordinates between negative one and positive one with the center of object\n        # at (0, 0, 0).\n        original_vertices -= max_pos_vec_norm\n\n        packed_array[:, :3] = original_vertices\n        packed_arrays[sub_obj] = packed_array\n\n    all_vertices = np.stack(data[\"v\"])\n    all_vertices -= min_vec\n    all_vertices /= max_pos_val\n    all_vertices *= 2\n    all_vertices -= max_pos_vec_norm\n    return (packed_arrays, all_vertices)\n\n\ndef parse_mtl_file(input_mtl):\n    vector_elems = {\"Ka\", \"Kd\", \"Ks\"}\n    float_elems = {\"Ns\", \"Ni\", \"d\"}\n    int_elems = {\"illum\"}\n    current_mtl = None\n    mtl_infos = {}\n    mtl_f = open(input_mtl)\n    sub_objs = []\n    for line in mtl_f:\n        line = line.strip()\n        if line == \"\":\n            continue\n\n        parts = line.split()\n        elem_type = parts[0]\n        if elem_type in vector_elems:\n            vals = np.array(parts[1:4], dtype=np.float32)\n            mtl_infos[current_mtl][elem_type] = tuple(vals)\n        elif elem_type in float_elems:\n            mtl_infos[current_mtl][elem_type] = float(parts[1])\n        elif elem_type in int_elems:\n            mtl_infos[current_mtl][elem_type] = int(parts[1])\n        elif elem_type == \"newmtl\":\n            current_mtl = parts[1]\n            sub_objs.append(current_mtl)\n            
mtl_infos[current_mtl] = {\"d\": 1.0}\n        elif elem_type == \"map_Kd\":\n            mtl_infos[current_mtl][\"map_Kd\"] = parts[1]\n\n    sub_objs.sort()\n    sub_objs.reverse()\n    non_trans = [sub_obj for sub_obj in sub_objs if mtl_infos[sub_obj][\"d\"] == 1.0]\n    trans = [\n        (sub_obj, mtl_infos[sub_obj][\"d\"])\n        for sub_obj in sub_objs\n        if mtl_infos[sub_obj][\"d\"] < 1.0\n    ]\n    trans.sort(key=lambda x: x[1], reverse=True)\n    sub_objs = non_trans + [sub_obj for (sub_obj, d) in trans]\n    return (mtl_infos, sub_objs)\n\n\ndef get_texture_data(sub_objs, packed_arrays, mtl_infos, obj_f):\n    texture_data = {}\n    texture_path_list = obj_f.split(\"/\")\n    img_str_len = len(\"images/\")\n    for sub_obj in sub_objs:\n        if sub_obj not in packed_arrays:\n            continue\n\n        if \"map_Kd\" in mtl_infos[sub_obj]:\n            texture_f = mtl_infos[sub_obj][\"map_Kd\"]\n            img_str_idx = texture_f.find(\"images/\")\n            if img_str_idx != -1:\n                texture_path = \"/\".join(texture_path_list[:-2] + [\"images\"])\n                texture_f = texture_f[img_str_idx + img_str_len :]\n            else:\n                texture_path = \"/\".join(texture_path_list[:-1])\n\n            try:\n                texture_img = (\n                    Image.open(texture_path + \"/\" + texture_f)\n                    .transpose(Image.FLIP_TOP_BOTTOM)\n                    .convert(\"RGBA\")\n                )\n            except FileNotFoundError:\n                texture_f_parts = texture_f.split(\".\")\n                ext = texture_f_parts[-1]\n                if ext.isupper():\n                    texture_f_parts[-1] = ext.lower()\n                elif ext.islower():\n                    texture_f_parts[-1] = ext.upper()\n\n                texture_f = \".\".join(texture_f_parts)\n                texture_img = (\n                    Image.open(texture_path + \"/\" + texture_f)\n                    
.transpose(Image.FLIP_TOP_BOTTOM)\n                    .convert(\"RGBA\")\n                )\n\n            texture_data[sub_obj] = {\n                \"size\": texture_img.size,\n                \"bytes\": texture_img.tobytes(),\n            }\n\n    return texture_data\n\n\nclass Renderer:\n    def __init__(\n        self,\n        background_f=None,\n        camera_distance=2.0,\n        angle_of_view=16.426,\n        dir_light=(0, 1 / np.sqrt(2), np.sqrt(2)),\n        dif_int=0.7,\n        amb_int=0.7,\n        default_width=128,\n        default_height=128,\n        cull_faces=True,\n    ):\n        # Initialize OpenGL context.\n        self.ctx = moderngl.create_standalone_context()\n        # Render depth appropriately.\n        self.ctx.enable(moderngl.DEPTH_TEST)\n        # Setting for rendering transparent objects.\n        # See: https://learnopengl.com/Advanced-OpenGL/Blending\n        # and: https://github.com/cprogrammer1994/ModernGL/blob/master/moderngl/context.py#L129.\n        self.ctx.enable(moderngl.BLEND)\n\n        # Define OpenGL program.\n        prog = self.ctx.program(\n            vertex_shader=\"\"\"\n                #version 330\n\n                uniform float x;\n                uniform float y;\n                uniform float z;\n\n                uniform mat3 R_obj;\n                uniform mat3 R_light;\n                uniform vec3 DirLight;\n                uniform mat4 VP;\n                uniform int mode;\n\n                in vec3 in_vert;\n                in vec3 in_norm;\n                in vec2 in_text;\n\n                out vec3 v_pos;\n                out vec3 v_norm;\n                out vec2 v_text;\n                out vec3 v_light;\n\n                void main() {\n                    if (mode == 0) {\n                        v_pos = R_obj * in_vert + vec3(x, y, z);\n                        gl_Position = VP * vec4(v_pos, 1.0);\n                        v_norm = R_obj * in_norm;\n                        v_text = 
in_text;\n                        v_light = R_light * DirLight;\n                    } else {\n                        gl_Position = vec4(in_vert, 1.0);\n                        v_text = in_text;\n                    }\n                }\n            \"\"\",\n            fragment_shader=\"\"\"\n                #version 330\n\n                uniform float amb_int;\n                uniform float dif_int;\n                uniform vec3 cam_pos;\n\n                uniform sampler2D Texture;\n                uniform int mode;\n                uniform bool use_texture;\n                uniform bool has_image;\n\n                uniform vec3 box_rgb;\n\n                uniform vec3 amb_rgb;\n                uniform vec3 dif_rgb;\n                uniform vec3 spc_rgb;\n                uniform float spec_exp;\n                uniform float trans;\n\n                in vec3 v_pos;\n                in vec3 v_norm;\n                in vec2 v_text;\n                in vec3 v_light;\n\n                out vec4 f_color;\n\n                void main() {\n                    if (mode == 0) {\n                        float dif = clamp(dot(v_light, v_norm), 0.0, 1.0) * dif_int;\n                        if (use_texture) {\n                            vec3 surface_rgb = dif_rgb;\n                            vec3 diffuse = dif * surface_rgb;\n                            if (has_image) {\n                                surface_rgb = texture(Texture, v_text).rgb;\n                                diffuse = dif * dif_rgb * surface_rgb;\n                            }\n                            vec3 ambient = amb_int * amb_rgb * surface_rgb;\n                            float spec = 0.0;\n                            if (dif > 0.0) {\n                                vec3 reflected = reflect(-v_light, v_norm);\n                                vec3 surface_to_camera = normalize(cam_pos - v_pos);\n                                spec = pow(clamp(dot(surface_to_camera, reflected), 0.0, 1.0), 
spec_exp);\n                            }\n                            vec3 specular = spec * spc_rgb * surface_rgb;\n                            vec3 linear = ambient + diffuse + specular;\n                            f_color = vec4(linear, trans);\n                        } else {\n                            f_color = vec4(vec3(1.0, 1.0, 1.0) * dif + amb_int, 1.0);\n                        }\n                    } else if (mode == 1) {\n                        f_color = vec4(texture(Texture, v_text).rgba);\n                    } else {\n                        f_color = vec4(box_rgb, 1.0);\n                    }\n                }\n            \"\"\",\n        )\n\n        # Lighting uniform variables.\n        prog[\"R_light\"].write(np.eye(3).astype(\"f4\").tobytes())\n        dir_light = np.array(dir_light)\n        prog[\"DirLight\"].value = tuple(dir_light / np.linalg.norm(dir_light))\n        prog[\"dif_int\"].value = dif_int\n        prog[\"amb_int\"].value = amb_int\n        prog[\"amb_rgb\"].value = (1.0, 1.0, 1.0)\n        prog[\"dif_rgb\"].value = (1.0, 1.0, 1.0)\n        prog[\"spc_rgb\"].value = (1.0, 1.0, 1.0)\n        prog[\"spec_exp\"].value = 0.0\n        self.use_spec = True\n\n        # Mode uniform variables.\n        prog[\"mode\"].value = 0\n        prog[\"use_texture\"].value = True\n        prog[\"has_image\"].value = False\n\n        # Model transformation uniform variables.\n        prog[\"R_obj\"].write(np.eye(3).astype(\"f4\").tobytes())\n        prog[\"x\"].value = 0\n        prog[\"y\"].value = 0\n        prog[\"z\"].value = 0\n\n        # Set up background.\n        self.prog = prog\n        (self.default_width, self.default_height) = (default_width, default_height)\n        self.background = None\n        (window_width, window_height) = self.set_up_background(background_f)\n\n        # Look at origin matrix.\n        eye = np.array([0.0, 0.0, camera_distance])\n        prog[\"cam_pos\"].value = tuple(eye)\n        target = 
np.zeros(3)\n        up = np.array([0.0, 1.0, 0.0])\n        self.look_at = Matrix44.look_at(eye, target, up)\n\n        # Perspective projection matrix.\n        self.ratio = window_width / window_height\n        self.angle_of_view = angle_of_view\n        self.perspective = Matrix44.perspective_projection(\n            angle_of_view, self.ratio, 0.1, 1000.0\n        )\n\n        # View-Projection uniform variable.\n        self.prog[\"VP\"].write((self.look_at @ self.perspective).astype(\"f4\").tobytes())\n\n        # Set up object.\n        self.mtl_infos = None\n        self.cull_faces = cull_faces\n        self.render_objs = []\n        self.vbos = {}\n        self.vaos = {}\n        self.textures = {}\n\n        # Initialize frame buffer.\n        size = (window_width, window_height)\n        self.window_size = size\n\n        # Set up multisample anti-aliasing.\n        self.ctx.multisample = True\n        color_rbo = self.ctx.renderbuffer(size, samples=self.ctx.max_samples)\n        depth_rbo = self.ctx.depth_renderbuffer(size, samples=self.ctx.max_samples)\n        self.fbo = self.ctx.framebuffer(color_rbo, depth_rbo)\n\n        color_rbo2 = self.ctx.renderbuffer(size)\n        depth_rbo2 = self.ctx.depth_renderbuffer(size)\n        self.fbo2 = self.ctx.framebuffer(color_rbo2, depth_rbo2)\n\n        self.fbo.use()\n\n    def set_up_obj(self, obj_f, mtl_f):\n        (packed_arrays, vertices) = parse_obj_file(obj_f)\n        packed_arrays = {\n            sub_obj: packed_array.flatten().astype(\"f4\").tobytes()\n            for (sub_obj, packed_array) in packed_arrays.items()\n        }\n        (mtl_infos, sub_objs) = parse_mtl_file(mtl_f)\n        texture_data = get_texture_data(sub_objs, packed_arrays, mtl_infos, obj_f)\n        self.load_obj(packed_arrays, vertices, mtl_infos, sub_objs, texture_data)\n\n    def load_obj(self, packed_arrays, vertices, mtl_infos, sub_objs, texture_data):\n        self.hom_vertices = np.hstack([vertices, 
np.ones(len(vertices))[:, None]])\n        render_objs = []\n        vbos = {}\n        vaos = {}\n        textures = {}\n        for sub_obj in sub_objs:\n            if sub_obj not in packed_arrays:\n                logging.info(f\"Skipping {sub_obj}.\")\n                continue\n\n            render_objs.append(sub_obj)\n            packed_array = packed_arrays[sub_obj]\n            vbo = self.ctx.buffer(packed_array)\n            vbos[sub_obj] = vbo\n            # Recall that \"in_vert\", \"in_norm\", and \"in_text\" are the inputs to the\n            # vertex shader.\n            vao = self.ctx.simple_vertex_array(\n                self.prog, vbo, \"in_vert\", \"in_norm\", \"in_text\"\n            )\n            vaos[sub_obj] = vao\n\n            if \"map_Kd\" in mtl_infos[sub_obj]:\n                # Initialize texture from image.\n                texture = self.ctx.texture(\n                    texture_data[sub_obj][\"size\"], 4, texture_data[sub_obj][\"bytes\"]\n                )\n                texture.build_mipmaps()\n                textures[sub_obj] = texture\n\n        self.mtl_infos = mtl_infos\n        self.render_objs = render_objs\n        self.vbos = vbos\n        self.vaos = vaos\n        self.textures = textures\n\n    def set_up_background(self, background_f=None):\n        if background_f:\n            background_img = (\n                Image.open(background_f)\n                .transpose(Image.FLIP_TOP_BOTTOM)\n                .convert(\"RGBA\")\n            )\n\n            # Initialize background from image.\n            background = self.ctx.texture(\n                background_img.size, 4, background_img.tobytes()\n            )\n            background.build_mipmaps()\n            self.background = background\n\n            # Create a square plane from two triangles (two sets of three points).\n            vertices = np.array(\n                [\n                    [-1.0, -1.0, 0.0],\n                    [-1.0, 1.0, 0.0],\n            
        [1.0, 1.0, 0.0],\n                    [-1.0, -1.0, 0.0],\n                    [1.0, -1.0, 0.0],\n                    [1.0, 1.0, 0.0],\n                ]\n            )\n            # These arrays are not used by the renderer, but the vertex shader expects\n            # them as input.\n            normals = np.repeat([[0.0, 0.0, 1.0]], len(vertices), axis=0)\n            # The texture (UV) coordinates corresponding to the above triangle points.\n            texture_coords = np.array(\n                [[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0], [1.0, 0.0], [1.0, 1.0]]\n            )\n\n            background_array = np.hstack((vertices, normals, texture_coords))\n            self.background_vbo = self.ctx.buffer(\n                background_array.flatten().astype(\"f4\").tobytes()\n            )\n            self.background_vao = self.ctx.simple_vertex_array(\n                self.prog, self.background_vbo, \"in_vert\", \"in_norm\", \"in_text\"\n            )\n\n            return (background_img.width, background_img.height)\n        else:\n            return (self.default_width, self.default_height)\n\n    def render(self, r=0.485, g=0.456, b=0.406, with_alpha=False):\n        if self.background is not None:\n            # See: https://computergraphics.stackexchange.com/a/4007.\n            self.ctx.disable(moderngl.DEPTH_TEST)\n            self.prog[\"mode\"].value = 1\n            self.background.use()\n            self.fbo.clear()\n            self.background_vao.render()\n\n            self.ctx.enable(moderngl.DEPTH_TEST)\n            self.prog[\"mode\"].value = 0\n        else:\n            self.fbo.clear(r, g, b)\n\n        if self.cull_faces:\n            self.ctx.enable(moderngl.CULL_FACE)\n\n        for render_obj in self.render_objs:\n            if self.prog[\"use_texture\"].value:\n                self.prog[\"amb_rgb\"].value = self.mtl_infos[render_obj][\"Ka\"]\n                self.prog[\"dif_rgb\"].value = 
self.mtl_infos[render_obj][\"Kd\"]\n                if self.use_spec:\n                    self.prog[\"spc_rgb\"].value = self.mtl_infos[render_obj][\"Ks\"]\n                    self.prog[\"spec_exp\"].value = self.mtl_infos[render_obj][\"Ns\"]\n                else:\n                    self.prog[\"spc_rgb\"].value = (0.0, 0.0, 0.0)\n\n                self.prog[\"trans\"].value = self.mtl_infos[render_obj][\"d\"]\n                if render_obj in self.textures:\n                    self.prog[\"has_image\"].value = True\n                    self.textures[render_obj].use()\n\n            self.vaos[render_obj].render()\n            self.prog[\"has_image\"].value = False\n\n        self.ctx.disable(moderngl.CULL_FACE)\n        self.ctx.copy_framebuffer(self.fbo2, self.fbo)\n        if with_alpha:\n            return Image.frombytes(\n                \"RGBA\",\n                self.fbo.size,\n                self.fbo2.read(components=4),\n                \"raw\",\n                \"RGBA\",\n                0,\n                -1,\n            )\n        else:\n            return Image.frombytes(\n                \"RGB\", self.fbo.size, self.fbo2.read(), \"raw\", \"RGB\", 0, -1\n            )\n\n    def get_vertex_screen_coordinates(self):\n        world = np.eye(4)\n        world[:3, :3] = np.array(self.prog[\"R_obj\"].value).reshape((3, 3)).T\n        world[:3, 3] = (\n            self.prog[\"x\"].value,\n            self.prog[\"y\"].value,\n            self.prog[\"z\"].value,\n        )\n        PV = np.array(self.prog[\"VP\"].value).reshape((4, 4)).T\n        pre_screen_coords = PV @ world @ self.hom_vertices.T\n\n        (window_width, window_height) = self.window_size\n        screen_xs = (\n            window_width\n            * (np.array(pre_screen_coords[0]) / np.array(pre_screen_coords[3]) + 1)\n            / 2\n        )\n        screen_ys = (\n            window_height\n            * (np.array(pre_screen_coords[1]) / np.array(pre_screen_coords[3]) + 1)\n    
        / 2\n        )\n        # Pair the x and y coordinates column-wise so each row holds one vertex's\n        # screen coordinates.\n        screen_coords = np.column_stack((screen_xs, screen_ys))\n\n        screen = np.zeros((window_height, window_width))\n        for i in range(len(screen_xs)):\n            col = x = int(screen_xs[i])\n            row = y = int(screen_ys[i])\n            if x < window_width and y < window_height:\n                screen[window_height - row - 1, col] = 1\n\n        screen_mat = np.uint8(255 * screen)\n        screen_img = Image.fromarray(screen_mat, mode=\"L\")\n        return (screen_coords, screen_img)\n\n    def __del__(self):\n        self.release()\n\n    def release_obj(self):\n        for sub_obj in self.vbos:\n            self.vbos[sub_obj].release()\n            self.vaos[sub_obj].release()\n            if sub_obj in self.textures:\n                self.textures[sub_obj].release()\n\n        self.vbos = {}\n        self.vaos = {}\n        self.textures = {}\n\n    def release_background(self):\n        if self.background is not None:\n            self.background.release()\n            self.background_vbo.release()\n            self.background_vao.release()\n            self.background = None\n\n    def release(self):\n        self.release_obj()\n        self.release_background()\n\n        self.fbo.release()\n        self.fbo2.release()\n        self.ctx.release()\n\n    def adjust_angle_of_view(self, angle_of_view):\n        self.angle_of_view = angle_of_view\n        perspective = Matrix44.perspective_projection(\n            self.angle_of_view, self.ratio, 0.1, 1000.0\n        )\n        # Use the same multiplication order as in __init__ so the VP uniform stays\n        # consistent.\n        self.prog[\"VP\"].write((self.look_at @ perspective).astype(\"f4\").tobytes())\n\n    def set_params(self, params):\n        ypr_params = {}\n        ae_params = {}\n        for (param, value) in params.items():\n            if param in self.prog:\n                self.prog[param].value = value\n            elif param == \"aov\":\n                self.adjust_angle_of_view(value)\n            elif param in YAW_PITCH_ROLL:\n                
ypr_params[param] = value\n            elif param in AZIM_ELEV_IN_PLANE:\n                ae_params[param] = value\n\n        if len(ypr_params) > 0:\n            yaw = ypr_params.get(\"yaw\", 0)\n            pitch = ypr_params.get(\"pitch\", 0)\n            roll = ypr_params.get(\"roll\", 0)\n            R_obj = Rotation.from_euler(\"YXZ\", [yaw, pitch, roll]).as_matrix()\n            self.prog[\"R_obj\"].write(R_obj.T.astype(\"f4\").tobytes())\n        elif len(ae_params) > 0:\n            R_obj = gen_rotation_matrix_from_azim_elev_in_plane(**ae_params)\n            self.prog[\"R_obj\"].write(R_obj.T.astype(\"f4\").tobytes())\n\n    def get_depth_arrays(self):\n        depth = np.frombuffer(\n            self.fbo2.read(attachment=-1, dtype=\"f4\"), dtype=np.dtype(\"f4\")\n        )\n        # The buffer is read row by row, so reshape to (height, width).\n        (window_width, window_height) = self.window_size\n        depth = 1 - depth.reshape((window_height, window_width))\n        min_pos = depth[depth > 0].min()\n        depth[depth > 0] = depth[depth > 0] - min_pos\n        depth_normed = depth / depth.max()\n        return (depth, depth_normed)\n\n    def get_depth_map(self):\n        (depth, depth_normed) = self.get_depth_arrays()\n        depth_map = np.uint8(255 * depth_normed)\n        return ImageOps.flip(Image.fromarray(depth_map, \"L\"))\n\n    def get_normal_map(self):\n        # See: https://stackoverflow.com/questions/5281261/generating-a-normal-map-from-a-height-map\n        # and: https://stackoverflow.com/questions/34644101/calculate-surface-normals-from-depth-image-using-neighboring-pixels-cross-produc\n        # and: https://en.wikipedia.org/wiki/Normal_mapping#How_it_works.\n        (depth, depth_normed) = self.get_depth_arrays()\n        depth_pad = np.pad(depth_normed, 1, \"constant\")\n        (dx, dy) = (1 / depth.shape[1], 1 / depth.shape[0])\n        dz_dx = (depth_pad[1:-1, 2:] - depth_pad[1:-1, :-2]) / (2 * dx)\n        dz_dy = (depth_pad[2:, 1:-1] - depth_pad[:-2, 1:-1]) / (2 * dy)\n        norms = np.stack([-dz_dx.flatten(), -dz_dy.flatten(), np.ones(dz_dx.size)])\n        
magnitudes = np.linalg.norm(norms, axis=0)\n        norms /= magnitudes\n        norms = norms.T\n        norms[:, :2] = 255 * (norms[:, :2] + 1) / 2\n        norms[:, 2] = 127 * norms[:, 2] + 128\n        norms = np.uint8(norms).reshape((*depth.shape, 3))\n        return ImageOps.flip(Image.fromarray(norms))\n"
  },
  {
    "path": "renderer_settings.py",
    "content": "WINDOW_SIZE = 256\nIMG_SIZE = 128\nCULL_FACES = True\n\nCAMERA_DISTANCE = 2.25\n# See: https://en.wikipedia.org/wiki/Angle_of_view#Common_lens_angles_of_view.\nANGLE_OF_VIEW = 53.962828459664856\n\n# Lighting.\nDIR_LIGHT = (0, 1 / (2 ** 0.5), 2 ** 0.5)\nDIF_INT = 0.7\nAMB_INT = 0.7\n"
  },
  {
    "path": "run_nerf.py",
    "content": "import matplotlib.pyplot as plt\nimport numpy as np\nimport torch\n\nfrom torch import nn, optim\n\n\ndef get_coarse_query_points(ds, N_c, t_i_c_bin_edges, t_i_c_gap, os):\n    # Sample depths (t_is_c). See Equation (2) in Section 4.\n    u_is_c = torch.rand(*list(ds.shape[:2]) + [N_c]).to(ds)\n    t_is_c = t_i_c_bin_edges + u_is_c * t_i_c_gap\n    # Calculate the points along the rays (r_ts_c) using the ray origins (os), sampled\n    # depths (t_is_c), and ray directions (ds). See Section 4: r(t) = o + t * d.\n    r_ts_c = os[..., None, :] + t_is_c[..., :, None] * ds[..., None, :]\n    return (r_ts_c, t_is_c)\n\n\ndef get_fine_query_points(w_is_c, N_f, t_is_c, t_f, os, ds):\n    # See text surrounding Equation (5) in Section 5.2 and:\n    # https://stephens999.github.io/fiveMinuteStats/inverse_transform_sampling.html#discrete_distributions.\n\n    # Define PDFs (pdfs) and CDFs (cdfs) from weights (w_is_c).\n    w_is_c = w_is_c + 1e-5\n    pdfs = w_is_c / torch.sum(w_is_c, dim=-1, keepdim=True)\n    cdfs = torch.cumsum(pdfs, dim=-1)\n    cdfs = torch.cat([torch.zeros_like(cdfs[..., :1]), cdfs[..., :-1]], dim=-1)\n\n    # Get uniform samples (us).\n    us = torch.rand(list(cdfs.shape[:-1]) + [N_f]).to(w_is_c)\n\n    # Use inverse transform sampling to sample the depths (t_is_f).\n    idxs = torch.searchsorted(cdfs, us, right=True)\n    t_i_f_bottom_edges = torch.gather(t_is_c, 2, idxs - 1)\n    idxs_capped = idxs.clone()\n    max_ind = cdfs.shape[-1]\n    idxs_capped[idxs_capped == max_ind] = max_ind - 1\n    t_i_f_top_edges = torch.gather(t_is_c, 2, idxs_capped)\n    t_i_f_top_edges[idxs == max_ind] = t_f\n    t_i_f_gaps = t_i_f_top_edges - t_i_f_bottom_edges\n    u_is_f = torch.rand_like(t_i_f_gaps).to(os)\n    t_is_f = t_i_f_bottom_edges + u_is_f * t_i_f_gaps\n\n    # Combine the coarse (t_is_c) and fine (t_is_f) depths and sort them.\n    (t_is_f, _) = torch.sort(torch.cat([t_is_c, t_is_f.detach()], dim=-1), dim=-1)\n    # Calculate the points 
along the rays (r_ts_f) using the ray origins (os), depths\n    # (t_is_f), and ray directions (ds). See Section 4: r(t) = o + t * d.\n    r_ts_f = os[..., None, :] + t_is_f[..., :, None] * ds[..., None, :]\n    return (r_ts_f, t_is_f)\n\n\ndef render_radiance_volume(r_ts, ds, chunk_size, F, t_is):\n    # Use the network (F) to predict colors (c_is) and volume densities (sigma_is) for\n    # 3D points along rays (r_ts) given the viewing directions (ds) of the rays. See\n    # Section 3 and Figure 7 in the Supplementary Materials.\n    r_ts_flat = r_ts.reshape((-1, 3))\n    ds_rep = ds.unsqueeze(2).repeat(1, 1, r_ts.shape[-2], 1)\n    ds_flat = ds_rep.reshape((-1, 3))\n    c_is = []\n    sigma_is = []\n    # The network processes batches of inputs to avoid running out of memory.\n    for chunk_start in range(0, r_ts_flat.shape[0], chunk_size):\n        r_ts_batch = r_ts_flat[chunk_start : chunk_start + chunk_size]\n        ds_batch = ds_flat[chunk_start : chunk_start + chunk_size]\n        preds = F(r_ts_batch, ds_batch)\n        c_is.append(preds[\"c_is\"])\n        sigma_is.append(preds[\"sigma_is\"])\n\n    c_is = torch.cat(c_is).reshape(r_ts.shape)\n    sigma_is = torch.cat(sigma_is).reshape(r_ts.shape[:-1])\n\n    # Calculate the distances (delta_is) between points along the rays. The differences\n    # in depths are scaled by the norms of the ray directions to get the final\n    # distances. See text following Equation (3) in Section 4.\n    delta_is = t_is[..., 1:] - t_is[..., :-1]\n    # \"Infinity\". Guarantees last alpha is always one.\n    one_e_10 = torch.Tensor([1e10]).expand(delta_is[..., :1].shape)\n    delta_is = torch.cat([delta_is, one_e_10.to(delta_is)], dim=-1)\n    delta_is = delta_is * ds.norm(dim=-1).unsqueeze(-1)\n\n    # Calculate the alphas (alpha_is) of the 3D points using the volume densities\n    # (sigma_is) and distances between points (delta_is). 
See text following Equation\n    # (3) in Section 4 and https://en.wikipedia.org/wiki/Alpha_compositing.\n    alpha_is = 1.0 - torch.exp(-sigma_is * delta_is)\n\n    # Calculate the accumulated transmittances (T_is) along the rays from the alphas\n    # (alpha_is). See Equation (3) in Section 4. T_i is \"the probability that the ray\n    # travels from t_n to t_i without hitting any other particle\".\n    T_is = torch.cumprod(1.0 - alpha_is + 1e-10, -1)\n    # Guarantees the ray makes it at least to the first step. See:\n    # https://github.com/bmild/nerf/blob/18b8aebda6700ed659cb27a0c348b737a5f6ab60/run_nerf.py#L142,\n    # which uses tf.math.cumprod(1.-alpha + 1e-10, axis=-1, exclusive=True).\n    T_is = torch.roll(T_is, 1, -1)\n    T_is[..., 0] = 1.0\n\n    # Calculate the weights (w_is) for the colors (c_is) along the rays using the\n    # transmittances (T_is) and alphas (alpha_is). See Equation (5) in Section 5.2:\n    # w_i = T_i * (1 - exp(-sigma_i * delta_i)).\n    w_is = T_is * alpha_is\n\n    # Calculate the pixel colors (C_rs) for the rays as weighted (w_is) sums of colors\n    # (c_is). See Equation (5) in Section 5.2: C_c_hat(r) = Σ w_i * c_i.\n    C_rs = (w_is[..., None] * c_is).sum(dim=-2)\n\n    return (C_rs, w_is)\n\n\ndef run_one_iter_of_nerf(\n    ds, N_c, t_i_c_bin_edges, t_i_c_gap, os, chunk_size, F_c, N_f, t_f, F_f\n):\n    (r_ts_c, t_is_c) = get_coarse_query_points(ds, N_c, t_i_c_bin_edges, t_i_c_gap, os)\n    (C_rs_c, w_is_c) = render_radiance_volume(r_ts_c, ds, chunk_size, F_c, t_is_c)\n\n    (r_ts_f, t_is_f) = get_fine_query_points(w_is_c, N_f, t_is_c, t_f, os, ds)\n    (C_rs_f, _) = render_radiance_volume(r_ts_f, ds, chunk_size, F_f, t_is_f)\n\n    return (C_rs_c, C_rs_f)\n\n\nclass NeRFMLP(nn.Module):\n    def __init__(self):\n        super().__init__()\n        # Number of encoding functions for positions. See Section 5.1.\n        self.L_pos = 10\n        # Number of encoding functions for viewing directions. 
See Section 5.1.\n        self.L_dir = 4\n        pos_enc_feats = 3 + 3 * 2 * self.L_pos\n        dir_enc_feats = 3 + 3 * 2 * self.L_dir\n\n        in_feats = pos_enc_feats\n        net_width = 256\n        early_mlp_layers = 5\n        early_mlp = []\n        for layer_idx in range(early_mlp_layers):\n            early_mlp.append(nn.Linear(in_feats, net_width))\n            early_mlp.append(nn.ReLU())\n            in_feats = net_width\n\n        self.early_mlp = nn.Sequential(*early_mlp)\n\n        in_feats = pos_enc_feats + net_width\n        late_mlp_layers = 3\n        late_mlp = []\n        for layer_idx in range(late_mlp_layers):\n            late_mlp.append(nn.Linear(in_feats, net_width))\n            late_mlp.append(nn.ReLU())\n            in_feats = net_width\n\n        self.late_mlp = nn.Sequential(*late_mlp)\n        self.sigma_layer = nn.Linear(net_width, net_width + 1)\n        self.pre_final_layer = nn.Sequential(\n            nn.Linear(dir_enc_feats + net_width, net_width // 2), nn.ReLU()\n        )\n        self.final_layer = nn.Sequential(nn.Linear(net_width // 2, 3), nn.Sigmoid())\n\n    def forward(self, xs, ds):\n        # Encode the inputs. See Equation (4) in Section 5.1.\n        xs_encoded = [xs]\n        for l_pos in range(self.L_pos):\n            xs_encoded.append(torch.sin(2**l_pos * torch.pi * xs))\n            xs_encoded.append(torch.cos(2**l_pos * torch.pi * xs))\n\n        xs_encoded = torch.cat(xs_encoded, dim=-1)\n\n        ds = ds / ds.norm(p=2, dim=-1).unsqueeze(-1)\n        ds_encoded = [ds]\n        for l_dir in range(self.L_dir):\n            ds_encoded.append(torch.sin(2**l_dir * torch.pi * ds))\n            ds_encoded.append(torch.cos(2**l_dir * torch.pi * ds))\n\n        ds_encoded = torch.cat(ds_encoded, dim=-1)\n\n        # Use the network to predict colors (c_is) and volume densities (sigma_is) for\n        # 3D points (xs) along rays given the viewing directions (ds) of the rays. 
See\n        # Section 3 and Figure 7 in the Supplementary Materials.\n        outputs = self.early_mlp(xs_encoded)\n        outputs = self.late_mlp(torch.cat([xs_encoded, outputs], dim=-1))\n        outputs = self.sigma_layer(outputs)\n        sigma_is = torch.relu(outputs[:, 0])\n        outputs = self.pre_final_layer(torch.cat([ds_encoded, outputs[:, 1:]], dim=-1))\n        c_is = self.final_layer(outputs)\n        return {\"c_is\": c_is, \"sigma_is\": sigma_is}\n\n\ndef main():\n    # Set seed.\n    seed = 9458\n    torch.manual_seed(seed)\n    np.random.seed(seed)\n\n    # Initialize coarse and fine MLPs.\n    device = \"cuda:0\"\n    F_c = NeRFMLP().to(device)\n    F_f = NeRFMLP().to(device)\n    # Number of query points passed through the MLP at a time. See: https://github.com/bmild/nerf/blob/18b8aebda6700ed659cb27a0c348b737a5f6ab60/run_nerf.py#L488.\n    chunk_size = 1024 * 32\n    # Number of training rays per iteration. See Section 5.3.\n    batch_img_size = 64\n    n_batch_pix = batch_img_size**2\n\n    # Initialize optimizer. See Section 5.3.\n    lr = 5e-4\n    optimizer = optim.Adam(list(F_c.parameters()) + list(F_f.parameters()), lr=lr)\n    criterion = nn.MSELoss()\n    # The learning rate decays exponentially. See Section 5.3\n    # See: https://github.com/bmild/nerf/blob/18b8aebda6700ed659cb27a0c348b737a5f6ab60/run_nerf.py#L486.\n    lrate_decay = 250\n    decay_steps = lrate_decay * 1000\n    # See: https://github.com/bmild/nerf/blob/18b8aebda6700ed659cb27a0c348b737a5f6ab60/run_nerf.py#L707.\n    decay_rate = 0.1\n\n    # Load dataset.\n    data_f = \"66bdbc812bd0a196e194052f3f12cb2e.npz\"\n    data = np.load(data_f)\n\n    # Set up initial ray origin (init_o) and ray directions (init_ds). 
These are the\n    # same across samples, we just rotate them based on the orientation of the camera.\n    # See Section 4.\n    images = data[\"images\"] / 255\n    img_size = images.shape[1]\n    xs = torch.arange(img_size) - (img_size / 2 - 0.5)\n    ys = torch.arange(img_size) - (img_size / 2 - 0.5)\n    (xs, ys) = torch.meshgrid(xs, -ys, indexing=\"xy\")\n    focal = float(data[\"focal\"])\n    pixel_coords = torch.stack([xs, ys, torch.full_like(xs, -focal)], dim=-1)\n    # We want the zs to be negative ones, so we divide everything by the focal length\n    # (which is in pixel units).\n    camera_coords = pixel_coords / focal\n    init_ds = camera_coords.to(device)\n    init_o = torch.Tensor(np.array([0, 0, float(data[\"camera_distance\"])])).to(device)\n\n    # Set up test view.\n    test_idx = 150\n    plt.imshow(images[test_idx])\n    plt.show()\n    test_img = torch.Tensor(images[test_idx]).to(device)\n    poses = data[\"poses\"]\n    test_R = torch.Tensor(poses[test_idx, :3, :3]).to(device)\n    test_ds = torch.einsum(\"ij,hwj->hwi\", test_R, init_ds)\n    test_os = (test_R @ init_o).expand(test_ds.shape)\n\n    # Initialize volume rendering hyperparameters.\n    # Near bound. See Section 4.\n    t_n = 1.0\n    # Far bound. See Section 4.\n    t_f = 4.0\n    # Number of coarse samples along a ray. See Section 5.3.\n    N_c = 64\n    # Number of fine samples along a ray. See Section 5.3.\n    N_f = 128\n    # Bins used to sample depths along a ray. 
See Equation (2) in Section 4.\n    t_i_c_gap = (t_f - t_n) / N_c\n    t_i_c_bin_edges = (t_n + torch.arange(N_c) * t_i_c_gap).to(device)\n\n    # Start training model.\n    train_idxs = np.arange(len(images)) != test_idx\n    images = torch.Tensor(images[train_idxs])\n    poses = torch.Tensor(poses[train_idxs])\n    n_pix = img_size**2\n    pixel_ps = torch.full((n_pix,), 1 / n_pix).to(device)\n    psnrs = []\n    iternums = []\n    # See Section 5.3.\n    num_iters = 300000\n    display_every = 100\n    F_c.train()\n    F_f.train()\n    for i in range(num_iters):\n        # Sample image and associated pose.\n        target_img_idx = np.random.randint(images.shape[0])\n        target_pose = poses[target_img_idx].to(device)\n        R = target_pose[:3, :3]\n\n        # Get rotated ray origins (os) and ray directions (ds). See Section 4.\n        ds = torch.einsum(\"ij,hwj->hwi\", R, init_ds)\n        os = (R @ init_o).expand(ds.shape)\n\n        # Sample a batch of rays.\n        pix_idxs = pixel_ps.multinomial(n_batch_pix, False)\n        pix_idx_rows = pix_idxs // img_size\n        pix_idx_cols = pix_idxs % img_size\n        ds_batch = ds[pix_idx_rows, pix_idx_cols].reshape(\n            batch_img_size, batch_img_size, -1\n        )\n        os_batch = os[pix_idx_rows, pix_idx_cols].reshape(\n            batch_img_size, batch_img_size, -1\n        )\n\n        # Run NeRF.\n        (C_rs_c, C_rs_f) = run_one_iter_of_nerf(\n            ds_batch,\n            N_c,\n            t_i_c_bin_edges,\n            t_i_c_gap,\n            os_batch,\n            chunk_size,\n            F_c,\n            N_f,\n            t_f,\n            F_f,\n        )\n        target_img = images[target_img_idx].to(device)\n        target_img_batch = target_img[pix_idx_rows, pix_idx_cols].reshape(C_rs_f.shape)\n        # Calculate the mean squared error for both the coarse and fine MLP models and\n        # update the weights. 
See Equation (6) in Section 5.3.\n        loss = criterion(C_rs_c, target_img_batch) + criterion(C_rs_f, target_img_batch)\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        # Exponentially decay learning rate. See Section 5.3 and:\n        # https://keras.io/api/optimizers/learning_rate_schedules/exponential_decay/.\n        for g in optimizer.param_groups:\n            g[\"lr\"] = lr * decay_rate ** (i / decay_steps)\n\n        if i % display_every == 0:\n            F_c.eval()\n            F_f.eval()\n            with torch.no_grad():\n                (_, C_rs_f) = run_one_iter_of_nerf(\n                    test_ds,\n                    N_c,\n                    t_i_c_bin_edges,\n                    t_i_c_gap,\n                    test_os,\n                    chunk_size,\n                    F_c,\n                    N_f,\n                    t_f,\n                    F_f,\n                )\n\n            loss = criterion(C_rs_f, test_img)\n            print(f\"Loss: {loss.item()}\")\n            psnr = -10.0 * torch.log10(loss)\n\n            psnrs.append(psnr.item())\n            iternums.append(i)\n\n            plt.figure(figsize=(10, 4))\n            plt.subplot(121)\n            plt.imshow(C_rs_f.detach().cpu().numpy())\n            plt.title(f\"Iteration {i}\")\n            plt.subplot(122)\n            plt.plot(iternums, psnrs)\n            plt.title(\"PSNR\")\n            plt.show()\n\n            F_c.train()\n            F_f.train()\n\n    print(\"Done!\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "run_nerf_alt.py",
    "content": "import matplotlib.pyplot as plt\nimport numpy as np\nimport torch\n\nfrom torch import nn, optim\n\n\nclass NeRFMLP(nn.Module):\n    def __init__(self):\n        super().__init__()\n        # Number of encoding functions for positions. See Section 5.1.\n        self.L_pos = 10\n        # Number of encoding functions for viewing directions. See Section 5.1.\n        self.L_dir = 4\n        pos_enc_feats = 3 + 3 * 2 * self.L_pos\n        dir_enc_feats = 3 + 3 * 2 * self.L_dir\n\n        in_feats = pos_enc_feats\n        net_width = 256\n        early_mlp_layers = 5\n        early_mlp = []\n        for layer_idx in range(early_mlp_layers):\n            early_mlp.append(nn.Linear(in_feats, net_width))\n            early_mlp.append(nn.ReLU())\n            in_feats = net_width\n\n        self.early_mlp = nn.Sequential(*early_mlp)\n\n        in_feats = pos_enc_feats + net_width\n        late_mlp_layers = 3\n        late_mlp = []\n        for layer_idx in range(late_mlp_layers):\n            late_mlp.append(nn.Linear(in_feats, net_width))\n            late_mlp.append(nn.ReLU())\n            in_feats = net_width\n\n        self.late_mlp = nn.Sequential(*late_mlp)\n        self.sigma_layer = nn.Linear(net_width, net_width + 1)\n        self.pre_final_layer = nn.Sequential(\n            nn.Linear(dir_enc_feats + net_width, net_width // 2), nn.ReLU()\n        )\n        self.final_layer = nn.Sequential(nn.Linear(net_width // 2, 3), nn.Sigmoid())\n\n    def forward(self, xs, ds):\n        # Encode the inputs. 
See Equation (4) in Section 5.1.\n        xs_encoded = [xs]\n        for l_pos in range(self.L_pos):\n            xs_encoded.append(torch.sin(2 ** l_pos * torch.pi * xs))\n            xs_encoded.append(torch.cos(2 ** l_pos * torch.pi * xs))\n\n        xs_encoded = torch.cat(xs_encoded, dim=-1)\n\n        ds = ds / ds.norm(p=2, dim=-1).unsqueeze(-1)\n        ds_encoded = [ds]\n        for l_dir in range(self.L_dir):\n            ds_encoded.append(torch.sin(2 ** l_dir * torch.pi * ds))\n            ds_encoded.append(torch.cos(2 ** l_dir * torch.pi * ds))\n\n        ds_encoded = torch.cat(ds_encoded, dim=-1)\n\n        # Use the network to predict colors (c_is) and volume densities (sigma_is) for\n        # 3D points (xs) along rays given the viewing directions (ds) of the rays. See\n        # Section 3 and Figure 7 in the Supplementary Materials.\n        outputs = self.early_mlp(xs_encoded)\n        outputs = self.late_mlp(torch.cat([xs_encoded, outputs], dim=-1))\n        outputs = self.sigma_layer(outputs)\n        sigma_is = torch.relu(outputs[:, 0])\n        outputs = self.pre_final_layer(torch.cat([ds_encoded, outputs[:, 1:]], dim=-1))\n        c_is = self.final_layer(outputs)\n        return {\"c_is\": c_is, \"sigma_is\": sigma_is}\n\n\nclass NeRF:\n    def __init__(self, device):\n        # Initialize coarse and fine MLPs.\n        self.F_c = NeRFMLP().to(device)\n        self.F_f = NeRFMLP().to(device)\n\n        # Number of query points passed through the MLPs at a time. See:\n        # https://github.com/bmild/nerf/blob/18b8aebda6700ed659cb27a0c348b737a5f6ab60/run_nerf.py#L488.\n        self.chunk_size = 1024 * 32\n\n        # Initialize volume rendering hyperparameters.\n        # Near bound. See Section 4.\n        self.t_n = t_n = 1.0\n        # Far bound. See Section 4.\n        self.t_f = t_f = 4.0\n        # Number of coarse samples along a ray. See Section 5.3.\n        self.N_c = N_c = 64\n        # Number of fine samples along a ray. 
See Section 5.3.\n        self.N_f = 128\n        # Bins used to sample depths along a ray. See Equation (2) in Section 4.\n        self.t_i_c_gap = t_i_c_gap = (t_f - t_n) / N_c\n        self.t_i_c_bin_edges = (t_n + torch.arange(N_c) * t_i_c_gap).to(device)\n\n    def get_coarse_query_points(self, ds, os):\n        # Sample depths (t_is_c). See Equation (2) in Section 4.\n        u_is_c = torch.rand(*list(ds.shape[:2]) + [self.N_c]).to(ds)\n        t_is_c = self.t_i_c_bin_edges + u_is_c * self.t_i_c_gap\n        # Calculate the points along the rays (r_ts_c) using the ray origins (os),\n        # sampled depths (t_is_c), and ray directions (ds). See Section 4:\n        # r(t) = o + t * d.\n        r_ts_c = os[..., None, :] + t_is_c[..., :, None] * ds[..., None, :]\n        return (r_ts_c, t_is_c)\n\n    def get_fine_query_points(self, w_is_c, t_is_c, os, ds):\n        # See text surrounding Equation (5) in Section 5.2 and:\n        # https://stephens999.github.io/fiveMinuteStats/inverse_transform_sampling.html#discrete_distributions.\n\n        # Define PDFs (pdfs) and CDFs (cdfs) from weights (w_is_c).\n        w_is_c = w_is_c + 1e-5\n        pdfs = w_is_c / torch.sum(w_is_c, dim=-1, keepdim=True)\n        cdfs = torch.cumsum(pdfs, dim=-1)\n        cdfs = torch.cat([torch.zeros_like(cdfs[..., :1]), cdfs[..., :-1]], dim=-1)\n\n        # Get uniform samples (us).\n        us = torch.rand(list(cdfs.shape[:-1]) + [self.N_f]).to(w_is_c)\n\n        # Use inverse transform sampling to sample the depths (t_is_f).\n        idxs = torch.searchsorted(cdfs, us, right=True)\n        t_i_f_bottom_edges = torch.gather(t_is_c, 2, idxs - 1)\n        idxs_capped = idxs.clone()\n        max_ind = cdfs.shape[-1]\n        idxs_capped[idxs_capped == max_ind] = max_ind - 1\n        t_i_f_top_edges = torch.gather(t_is_c, 2, idxs_capped)\n        t_i_f_top_edges[idxs == max_ind] = self.t_f\n        t_i_f_gaps = t_i_f_top_edges - t_i_f_bottom_edges\n        u_is_f = 
torch.rand_like(t_i_f_gaps).to(os)\n        t_is_f = t_i_f_bottom_edges + u_is_f * t_i_f_gaps\n\n        # Combine the coarse (t_is_c) and fine (t_is_f) depths and sort them.\n        (t_is_f, _) = torch.sort(torch.cat([t_is_c, t_is_f.detach()], dim=-1), dim=-1)\n        # Calculate the points along the rays (r_ts_f) using the ray origins (os),\n        # depths (t_is_f), and ray directions (ds). See Section 4: r(t) = o + t * d.\n        r_ts_f = os[..., None, :] + t_is_f[..., :, None] * ds[..., None, :]\n        return (r_ts_f, t_is_f)\n\n    def render_radiance_volume(self, r_ts, ds, F, t_is):\n        # Use the network (F) to predict colors (c_is) and volume densities (sigma_is)\n        # for 3D points along rays (r_ts) given the viewing directions (ds) of the rays.\n        # See Section 3 and Figure 7 in the Supplementary Materials.\n        r_ts_flat = r_ts.reshape((-1, 3))\n        ds_rep = ds.unsqueeze(2).repeat(1, 1, r_ts.shape[-2], 1)\n        ds_flat = ds_rep.reshape((-1, 3))\n        c_is = []\n        sigma_is = []\n        # The network processes batches of inputs to avoid running out of memory.\n        for chunk_start in range(0, r_ts_flat.shape[0], self.chunk_size):\n            r_ts_batch = r_ts_flat[chunk_start : chunk_start + self.chunk_size]\n            ds_batch = ds_flat[chunk_start : chunk_start + self.chunk_size]\n            preds = F(r_ts_batch, ds_batch)\n            c_is.append(preds[\"c_is\"])\n            sigma_is.append(preds[\"sigma_is\"])\n\n        c_is = torch.cat(c_is).reshape(r_ts.shape)\n        sigma_is = torch.cat(sigma_is).reshape(r_ts.shape[:-1])\n\n        # Calculate the distances (delta_is) between points along the rays. The\n        # differences in depths are scaled by the norms of the ray directions to get the\n        # final distances. See text following Equation (3) in Section 4.\n        delta_is = t_is[..., 1:] - t_is[..., :-1]\n        # \"Infinity\". 
Guarantees last alpha is always one.\n        one_e_10 = torch.Tensor([1e10]).expand(delta_is[..., :1].shape)\n        delta_is = torch.cat([delta_is, one_e_10.to(delta_is)], dim=-1)\n        delta_is = delta_is * ds.norm(dim=-1).unsqueeze(-1)\n\n        # Calculate the alphas (alpha_is) of the 3D points using the volume densities\n        # (sigma_is) and distances between points (delta_is). See text following\n        # Equation (3) in Section 4 and https://en.wikipedia.org/wiki/Alpha_compositing.\n        alpha_is = 1.0 - torch.exp(-sigma_is * delta_is)\n\n        # Calculate the accumulated transmittances (T_is) along the rays from the alphas\n        # (alpha_is). See Equation (3) in Section 4. T_i is \"the probability that the\n        # ray travels from t_n to t_i without hitting any other particle\".\n        T_is = torch.cumprod(1.0 - alpha_is + 1e-10, -1)\n        # Guarantees the ray makes it at least to the first step. See:\n        # https://github.com/bmild/nerf/blob/18b8aebda6700ed659cb27a0c348b737a5f6ab60/run_nerf.py#L142,\n        # which uses tf.math.cumprod(1.-alpha + 1e-10, axis=-1, exclusive=True).\n        T_is = torch.roll(T_is, 1, -1)\n        T_is[..., 0] = 1.0\n\n        # Calculate the weights (w_is) for the colors (c_is) along the rays using the\n        # transmittances (T_is) and alphas (alpha_is). See Equation (5) in Section 5.2:\n        # w_i = T_i * (1 - exp(-sigma_i * delta_i)).\n        w_is = T_is * alpha_is\n\n        # Calculate the pixel colors (C_rs) for the rays as weighted (w_is) sums of\n        # colors (c_is). 
See Equation (5) in Section 5.2: C_c_hat(r) = Σ w_i * c_i.\n        C_rs = (w_is[..., None] * c_is).sum(dim=-2)\n\n        return (C_rs, w_is)\n\n    def __call__(self, ds, os):\n        (r_ts_c, t_is_c) = self.get_coarse_query_points(ds, os)\n        (C_rs_c, w_is_c) = self.render_radiance_volume(r_ts_c, ds, self.F_c, t_is_c)\n\n        (r_ts_f, t_is_f) = self.get_fine_query_points(w_is_c, t_is_c, os, ds)\n        (C_rs_f, _) = self.render_radiance_volume(r_ts_f, ds, self.F_f, t_is_f)\n\n        return (C_rs_c, C_rs_f)\n\n\ndef load_data(device):\n    data_f = \"66bdbc812bd0a196e194052f3f12cb2e.npz\"\n    data = np.load(data_f)\n\n    # Set up initial ray origin (init_o) and ray directions (init_ds). These are the\n    # same across samples; we just rotate them based on the orientation of the camera.\n    # See Section 4.\n    images = data[\"images\"] / 255\n    img_size = images.shape[1]\n    xs = torch.arange(img_size) - (img_size / 2 - 0.5)\n    ys = torch.arange(img_size) - (img_size / 2 - 0.5)\n    (xs, ys) = torch.meshgrid(xs, -ys, indexing=\"xy\")\n    focal = float(data[\"focal\"])\n    pixel_coords = torch.stack([xs, ys, torch.full_like(xs, -focal)], dim=-1)\n    # We want the z-coordinates to be -1, so we divide everything by the focal length\n    # (which is in pixel units).\n    camera_coords = pixel_coords / focal\n    init_ds = camera_coords.to(device)\n    init_o = torch.Tensor(np.array([0, 0, float(data[\"camera_distance\"])])).to(device)\n\n    return (images, data[\"poses\"], init_ds, init_o, img_size)\n\n\ndef set_up_test_data(images, device, poses, init_ds, init_o):\n    # Set up test view.\n    test_idx = 150\n    plt.imshow(images[test_idx])\n    plt.show()\n    test_img = torch.Tensor(images[test_idx]).to(device)\n    test_R = torch.Tensor(poses[test_idx, :3, :3]).to(device)\n    test_ds = torch.einsum(\"ij,hwj->hwi\", test_R, init_ds)\n    test_os = (test_R @ init_o).expand(test_ds.shape)\n\n    train_idxs = np.arange(len(images)) != 
test_idx\n\n    return (test_ds, test_os, test_img, train_idxs)\n\n\ndef main():\n    # Set seed.\n    seed = 9458\n    torch.manual_seed(seed)\n    np.random.seed(seed)\n\n    # Initialize NeRF.\n    device = \"cuda:0\"\n    nerf = NeRF(device)\n    # Number of training rays per iteration. See Section 5.3.\n    batch_img_size = 64\n    n_batch_pix = batch_img_size ** 2\n\n    # Initialize optimizer. See Section 5.3.\n    lr = 5e-4\n    train_params = list(nerf.F_c.parameters()) + list(nerf.F_f.parameters())\n    optimizer = optim.Adam(train_params, lr=lr)\n    criterion = nn.MSELoss()\n    # The learning rate decays exponentially. See Section 5.3.\n    # See: https://github.com/bmild/nerf/blob/18b8aebda6700ed659cb27a0c348b737a5f6ab60/run_nerf.py#L486.\n    lrate_decay = 250\n    decay_steps = lrate_decay * 1000\n    # See: https://github.com/bmild/nerf/blob/18b8aebda6700ed659cb27a0c348b737a5f6ab60/run_nerf.py#L707.\n    decay_rate = 0.1\n\n    # Load dataset.\n    (images, poses, init_ds, init_o, img_size) = load_data(device)\n    (test_ds, test_os, test_img, train_idxs) = set_up_test_data(\n        images, device, poses, init_ds, init_o\n    )\n    images = torch.Tensor(images[train_idxs])\n    poses = torch.Tensor(poses[train_idxs])\n    n_pix = img_size ** 2\n    pixel_ps = torch.full((n_pix,), 1 / n_pix).to(device)\n\n    # Start training the model.\n    psnrs = []\n    iternums = []\n    # See Section 5.3.\n    num_iters = 300000\n    display_every = 100\n    nerf.F_c.train()\n    nerf.F_f.train()\n    for i in range(num_iters):\n        # Sample image and associated pose.\n        target_img_idx = np.random.randint(images.shape[0])\n        target_pose = poses[target_img_idx].to(device)\n        R = target_pose[:3, :3]\n\n        # Get rotated ray origins (os) and ray directions (ds). 
See Section 4.\n        ds = torch.einsum(\"ij,hwj->hwi\", R, init_ds)\n        os = (R @ init_o).expand(ds.shape)\n\n        # Sample a batch of rays.\n        pix_idxs = pixel_ps.multinomial(n_batch_pix, False)\n        pix_idx_rows = pix_idxs // img_size\n        pix_idx_cols = pix_idxs % img_size\n        ds_batch = ds[pix_idx_rows, pix_idx_cols].reshape(\n            batch_img_size, batch_img_size, -1\n        )\n        os_batch = os[pix_idx_rows, pix_idx_cols].reshape(\n            batch_img_size, batch_img_size, -1\n        )\n\n        # Run NeRF.\n        (C_rs_c, C_rs_f) = nerf(ds_batch, os_batch)\n        target_img = images[target_img_idx].to(device)\n        target_img_batch = target_img[pix_idx_rows, pix_idx_cols].reshape(C_rs_f.shape)\n        # Calculate the mean squared error for both the coarse and fine MLP models and\n        # update the weights. See Equation (6) in Section 5.3.\n        loss = criterion(C_rs_c, target_img_batch) + criterion(C_rs_f, target_img_batch)\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        # Exponentially decay learning rate. 
See Section 5.3 and:\n        # https://keras.io/api/optimizers/learning_rate_schedules/exponential_decay/.\n        for g in optimizer.param_groups:\n            g[\"lr\"] = lr * decay_rate ** (i / decay_steps)\n\n        if i % display_every == 0:\n            nerf.F_c.eval()\n            nerf.F_f.eval()\n            with torch.no_grad():\n                (_, C_rs_f) = nerf(test_ds, test_os)\n\n            loss = criterion(C_rs_f, test_img)\n            print(f\"Loss: {loss.item()}\")\n            psnr = -10.0 * torch.log10(loss)\n\n            psnrs.append(psnr.item())\n            iternums.append(i)\n\n            plt.figure(figsize=(10, 4))\n            plt.subplot(121)\n            plt.imshow(C_rs_f.detach().cpu().numpy())\n            plt.title(f\"Iteration {i}\")\n            plt.subplot(122)\n            plt.plot(iternums, psnrs)\n            plt.title(\"PSNR\")\n            plt.show()\n\n            nerf.F_c.train()\n            nerf.F_f.train()\n\n    print(\"Done!\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "run_pixelnerf.py",
    "content": "import matplotlib.pyplot as plt\nimport numpy as np\nimport torch\n\nfrom image_encoder import ImageEncoder\nfrom pixelnerf_dataset import PixelNeRFDataset\nfrom torch import nn, optim\n\n\ndef get_coarse_query_points(ds, N_c, t_i_c_bin_edges, t_i_c_gap, os):\n    u_is_c = torch.rand(*list(ds.shape[:2]) + [N_c]).to(ds)\n    t_is_c = t_i_c_bin_edges + u_is_c * t_i_c_gap\n    r_ts_c = os[..., None, :] + t_is_c[..., :, None] * ds[..., None, :]\n    return (r_ts_c, t_is_c)\n\n\ndef get_fine_query_points(w_is_c, N_f, t_is_c, t_f, os, ds, r_ts_c, N_d, d_std, t_n):\n    w_is_c = w_is_c + 1e-5\n    pdfs = w_is_c / torch.sum(w_is_c, dim=-1, keepdim=True)\n    cdfs = torch.cumsum(pdfs, dim=-1)\n    cdfs = torch.cat([torch.zeros_like(cdfs[..., :1]), cdfs[..., :-1]], dim=-1)\n\n    us = torch.rand(list(cdfs.shape[:-1]) + [N_f]).to(w_is_c)\n\n    idxs = torch.searchsorted(cdfs, us, right=True)\n    t_i_f_bottom_edges = torch.gather(t_is_c, 2, idxs - 1)\n    idxs_capped = idxs.clone()\n    max_ind = cdfs.shape[-1]\n    idxs_capped[idxs_capped == max_ind] = max_ind - 1\n    t_i_f_top_edges = torch.gather(t_is_c, 2, idxs_capped)\n    t_i_f_top_edges[idxs == max_ind] = t_f\n    t_i_f_gaps = t_i_f_top_edges - t_i_f_bottom_edges\n    u_is_f = torch.rand_like(t_i_f_gaps).to(os)\n    t_is_f = t_i_f_bottom_edges + u_is_f * t_i_f_gaps\n\n    # See Section B.1 in the Supplementary Materials and:\n    # https://github.com/sxyu/pixel-nerf/blob/a5a514224272a91e3ec590f215567032e1f1c260/src/render/nerf.py#L150.\n    t_is_d = (w_is_c * r_ts_c[..., 2]).sum(dim=-1)\n    t_is_d = t_is_d.unsqueeze(2).repeat((1, 1, N_d))\n    t_is_d = t_is_d + torch.normal(0, d_std, size=t_is_d.shape).to(t_is_d)\n    t_is_d = torch.clamp(t_is_d, t_n, t_f)\n\n    t_is_f = torch.cat([t_is_c, t_is_f.detach(), t_is_d], dim=-1)\n    (t_is_f, _) = torch.sort(t_is_f, dim=-1)\n    r_ts_f = os[..., None, :] + t_is_f[..., :, None] * ds[..., None, :]\n\n    return (r_ts_f, t_is_f)\n\n\ndef 
get_image_features_for_query_points(r_ts, camera_distance, scale, W_i):\n    # Get the projected image coordinates (pi_x_is) for each point along the rays\n    # (r_ts). This is just geometry. See: http://www.songho.ca/opengl/gl_projectionmatrix.html.\n    pi_x_is = r_ts[..., :2] / (camera_distance - r_ts[..., 2].unsqueeze(-1))\n    pi_x_is = pi_x_is / scale\n    # PyTorch's grid_sample function assumes (-1, -1) is the left-top pixel, but we want\n    # (-1, -1) to be the left-bottom pixel, so we negate the y-coordinates.\n    pi_x_is[..., 1] = -1 * pi_x_is[..., 1]\n    # PyTorch's grid_sample function expects the grid to have shape\n    # (N, H_out, W_out, 2).\n    pi_x_is = pi_x_is.permute(2, 0, 1, 3)\n    # PyTorch's grid_sample function expects the input to have shape (N, C, H_in, W_in).\n    W_i = W_i.repeat(pi_x_is.shape[0], 1, 1, 1)\n    # Get the image features (z_is) associated with the projected image coordinates\n    # (pi_x_is) from the encoded image features (W_i). See Section 4.2.\n    z_is = nn.functional.grid_sample(\n        W_i, pi_x_is, align_corners=True, padding_mode=\"border\"\n    )\n    # Convert shape back to match rays.\n    z_is = z_is.permute(2, 3, 0, 1)\n    return z_is\n\n\ndef render_radiance_volume(r_ts, ds, z_is, chunk_size, F, t_is):\n    r_ts_flat = r_ts.reshape((-1, 3))\n    ds_rep = ds.unsqueeze(2).repeat(1, 1, r_ts.shape[-2], 1)\n    ds_flat = ds_rep.reshape((-1, 3))\n    z_is_flat = z_is.reshape((ds_flat.shape[0], -1))\n    c_is = []\n    sigma_is = []\n    for chunk_start in range(0, r_ts_flat.shape[0], chunk_size):\n        r_ts_batch = r_ts_flat[chunk_start : chunk_start + chunk_size]\n        ds_batch = ds_flat[chunk_start : chunk_start + chunk_size]\n        w_is_batch = z_is_flat[chunk_start : chunk_start + chunk_size]\n        preds = F(r_ts_batch, ds_batch, w_is_batch)\n        c_is.append(preds[\"c_is\"])\n        sigma_is.append(preds[\"sigma_is\"])\n\n    c_is = torch.cat(c_is).reshape(r_ts.shape)\n    sigma_is = 
torch.cat(sigma_is).reshape(r_ts.shape[:-1])\n\n    delta_is = t_is[..., 1:] - t_is[..., :-1]\n    one_e_10 = torch.Tensor([1e10]).expand(delta_is[..., :1].shape)\n    delta_is = torch.cat([delta_is, one_e_10.to(delta_is)], dim=-1)\n    delta_is = delta_is * ds.norm(dim=-1).unsqueeze(-1)\n\n    alpha_is = 1.0 - torch.exp(-sigma_is * delta_is)\n\n    T_is = torch.cumprod(1.0 - alpha_is + 1e-10, -1)\n    T_is = torch.roll(T_is, 1, -1)\n    T_is[..., 0] = 1.0\n\n    w_is = T_is * alpha_is\n\n    C_rs = (w_is[..., None] * c_is).sum(dim=-2)\n\n    return (C_rs, w_is)\n\n\ndef run_one_iter_of_pixelnerf(\n    ds,\n    N_c,\n    t_i_c_bin_edges,\n    t_i_c_gap,\n    os,\n    camera_distance,\n    scale,\n    W_i,\n    chunk_size,\n    F_c,\n    N_f,\n    t_f,\n    N_d,\n    d_std,\n    t_n,\n    F_f,\n):\n    (r_ts_c, t_is_c) = get_coarse_query_points(ds, N_c, t_i_c_bin_edges, t_i_c_gap, os)\n    z_is_c = get_image_features_for_query_points(r_ts_c, camera_distance, scale, W_i)\n    (C_rs_c, w_is_c) = render_radiance_volume(\n        r_ts_c, ds, z_is_c, chunk_size, F_c, t_is_c\n    )\n\n    (r_ts_f, t_is_f) = get_fine_query_points(\n        w_is_c, N_f, t_is_c, t_f, os, ds, r_ts_c, N_d, d_std, t_n\n    )\n    z_is_f = get_image_features_for_query_points(r_ts_f, camera_distance, scale, W_i)\n    (C_rs_f, _) = render_radiance_volume(r_ts_f, ds, z_is_f, chunk_size, F_f, t_is_f)\n    return (C_rs_c, C_rs_f)\n\n\nclass PixelNeRFFCResNet(nn.Module):\n    def __init__(self):\n        super().__init__()\n        # Number of encoding functions for positions. See Section B.1 in the\n        # Supplementary Materials.\n        self.L_pos = 6\n        # Number of encoding functions for viewing directions.\n        self.L_dir = 0\n        pos_enc_feats = 3 + 3 * 2 * self.L_pos\n        dir_enc_feats = 3 + 3 * 2 * self.L_dir\n\n        # Set up ResNet MLP. 
See Section B.1 and Figure 18 in the Supplementary\n        # Materials.\n        net_width = 512\n        self.first_layer = nn.Sequential(\n            nn.Linear(pos_enc_feats + dir_enc_feats, net_width)\n        )\n        self.n_resnet_blocks = 5\n        z_linears = []\n        mlps = []\n        for resnet_block in range(self.n_resnet_blocks):\n            z_linears.append(nn.Linear(net_width, net_width))\n            mlps.append(\n                nn.Sequential(\n                    nn.Linear(net_width, net_width),\n                    nn.ReLU(),\n                    nn.Linear(net_width, net_width),\n                    nn.ReLU(),\n                )\n            )\n\n        self.z_linears = nn.ModuleList(z_linears)\n        self.mlps = nn.ModuleList(mlps)\n        self.final_layer = nn.Linear(net_width, 4)\n\n    def forward(self, xs, ds, zs):\n        xs_encoded = [xs]\n        for l_pos in range(self.L_pos):\n            xs_encoded.append(torch.sin(2**l_pos * torch.pi * xs))\n            xs_encoded.append(torch.cos(2**l_pos * torch.pi * xs))\n\n        xs_encoded = torch.cat(xs_encoded, dim=-1)\n\n        ds = ds / ds.norm(p=2, dim=-1).unsqueeze(-1)\n        ds_encoded = [ds]\n        for l_dir in range(self.L_dir):\n            ds_encoded.append(torch.sin(2**l_dir * torch.pi * ds))\n            ds_encoded.append(torch.cos(2**l_dir * torch.pi * ds))\n\n        ds_encoded = torch.cat(ds_encoded, dim=-1)\n\n        # Use the network to predict colors (c_is) and volume densities (sigma_is) for\n        # 3D points (xs) along rays given the viewing directions (ds) of the rays\n        # and the associated input image features (zs). 
See Section B.1 and Figure 18 in\n        # the Supplementary Materials and:\n        # https://github.com/sxyu/pixel-nerf/blob/master/src/model/resnetfc.py.\n        outputs = self.first_layer(torch.cat([xs_encoded, ds_encoded], dim=-1))\n        for block_idx in range(self.n_resnet_blocks):\n            resnet_zs = self.z_linears[block_idx](zs)\n            outputs = outputs + resnet_zs\n            outputs = self.mlps[block_idx](outputs) + outputs\n\n        outputs = self.final_layer(outputs)\n        sigma_is = torch.relu(outputs[:, 0])\n        c_is = torch.sigmoid(outputs[:, 1:])\n        return {\"c_is\": c_is, \"sigma_is\": sigma_is}\n\n\ndef load_data():\n    # Initialize dataset and test object/poses.\n    data_dir = \"data\"\n    # See Section B.2.1 in the Supplementary Materials.\n    num_iters = 400000\n    test_obj_idx = 5\n    test_source_pose_idx = 11\n    test_target_pose_idx = 33\n    train_dataset = PixelNeRFDataset(\n        data_dir, num_iters, test_obj_idx, test_source_pose_idx, test_target_pose_idx\n    )\n    return train_dataset\n\n\ndef set_up_test_data(train_dataset, device):\n    obj_idx = train_dataset.test_obj_idx\n    obj = train_dataset.objs[obj_idx]\n    data_dir = train_dataset.data_dir\n    obj_dir = f\"{data_dir}/{obj}\"\n\n    z_len = train_dataset.z_len\n    source_pose_idx = train_dataset.test_source_pose_idx\n    source_img_f = f\"{obj_dir}/{str(source_pose_idx).zfill(z_len)}.npy\"\n    source_image = np.load(source_img_f) / 255\n    source_pose = train_dataset.poses[obj_idx, source_pose_idx]\n    source_R = source_pose[:3, :3]\n\n    target_pose_idx = train_dataset.test_target_pose_idx\n    target_img_f = f\"{obj_dir}/{str(target_pose_idx).zfill(z_len)}.npy\"\n    target_image = np.load(target_img_f) / 255\n    target_pose = train_dataset.poses[obj_idx, target_pose_idx]\n    target_R = target_pose[:3, :3]\n\n    R = torch.Tensor(source_R.T @ target_R).to(device)\n\n    plt.imshow(source_image)\n    plt.show()\n    
source_image = torch.Tensor(source_image)\n    source_image = (\n        source_image - train_dataset.channel_means\n    ) / train_dataset.channel_stds\n    source_image = source_image.to(device).unsqueeze(0).permute(0, 3, 1, 2)\n    plt.imshow(target_image)\n    plt.show()\n    target_image = torch.Tensor(target_image).to(device)\n\n    return (source_image, R, target_image)\n\n\ndef main():\n    seed = 9458\n    torch.manual_seed(seed)\n    np.random.seed(seed)\n\n    device = \"cuda:0\"\n    F_c = PixelNeRFFCResNet().to(device)\n    F_f = PixelNeRFFCResNet().to(device)\n\n    E = ImageEncoder().to(device)\n    chunk_size = 1024 * 32\n    # See Section B.2 in the Supplementary Materials.\n    batch_img_size = 12\n    n_batch_pix = batch_img_size**2\n    n_objs = 4\n\n    # See Section B.2 in the Supplementary Materials.\n    lr = 1e-4\n    optimizer = optim.Adam(list(F_c.parameters()) + list(F_f.parameters()), lr=lr)\n    criterion = nn.MSELoss()\n\n    train_dataset = load_data()\n\n    camera_distance = train_dataset.camera_distance\n    scale = train_dataset.scale\n    t_n = 1.0\n    t_f = 4.0\n    img_size = train_dataset[0][2].shape[0]\n    # See Section B.1 in the Supplementary Materials,\n    # and: https://github.com/sxyu/pixel-nerf/blob/a5a514224272a91e3ec590f215567032e1f1c260/conf/default.conf#L50,\n    # and: https://github.com/sxyu/pixel-nerf/blob/a5a514224272a91e3ec590f215567032e1f1c260/src/render/nerf.py#L150.\n    N_c = 64\n    N_f = 16\n    N_d = 16\n    d_std = 0.01\n\n    t_i_c_gap = (t_f - t_n) / N_c\n    t_i_c_bin_edges = (t_n + torch.arange(N_c) * t_i_c_gap).to(device)\n\n    init_o = train_dataset.init_o.to(device)\n    init_ds = train_dataset.init_ds.to(device)\n\n    (test_source_image, test_R, test_target_image) = set_up_test_data(\n        train_dataset, device\n    )\n    test_ds = torch.einsum(\"ij,hwj->hwi\", test_R, init_ds)\n    test_os = (test_R @ init_o).expand(test_ds.shape)\n\n    psnrs = []\n    iternums = []\n    num_iters = 
train_dataset.N\n    use_bbox = True\n    num_bbox_iters = 300000\n    display_every = 100\n    F_c.train()\n    F_f.train()\n    E.eval()\n    for i in range(num_iters):\n        if i == num_bbox_iters:\n            use_bbox = False\n\n        loss = 0\n        for obj in range(n_objs):\n            try:\n                (source_image, R, target_image, bbox) = train_dataset[0]\n            except ValueError:\n                continue\n\n            R = R.to(device)\n            ds = torch.einsum(\"ij,hwj->hwi\", R, init_ds)\n            os = (R @ init_o).expand(ds.shape)\n\n            if use_bbox:\n                pix_rows = np.arange(bbox[0], bbox[2])\n                pix_cols = np.arange(bbox[1], bbox[3])\n            else:\n                pix_rows = np.arange(0, img_size)\n                pix_cols = np.arange(0, img_size)\n\n            pix_row_cols = np.meshgrid(pix_rows, pix_cols, indexing=\"ij\")\n            pix_row_cols = np.stack(pix_row_cols).transpose(1, 2, 0).reshape(-1, 2)\n            choices = np.arange(len(pix_row_cols))\n            try:\n                selected_pix = np.random.choice(choices, n_batch_pix, False)\n            except ValueError:\n                continue\n\n            pix_idx_rows = pix_row_cols[selected_pix, 0]\n            pix_idx_cols = pix_row_cols[selected_pix, 1]\n            ds_batch = ds[pix_idx_rows, pix_idx_cols].reshape(\n                batch_img_size, batch_img_size, -1\n            )\n            os_batch = os[pix_idx_rows, pix_idx_cols].reshape(\n                batch_img_size, batch_img_size, -1\n            )\n\n            # Extract feature pyramid from image. 
See Section 4.1, Section B.1 in the\n            # Supplementary Materials, and: https://github.com/sxyu/pixel-nerf/blob/master/src/model/encoder.py.\n            with torch.no_grad():\n                W_i = E(source_image.unsqueeze(0).permute(0, 3, 1, 2).to(device))\n\n            (C_rs_c, C_rs_f) = run_one_iter_of_pixelnerf(\n                ds_batch,\n                N_c,\n                t_i_c_bin_edges,\n                t_i_c_gap,\n                os_batch,\n                camera_distance,\n                scale,\n                W_i,\n                chunk_size,\n                F_c,\n                N_f,\n                t_f,\n                N_d,\n                d_std,\n                t_n,\n                F_f,\n            )\n            target_img = target_image.to(device)\n            target_img_batch = target_img[pix_idx_rows, pix_idx_cols].reshape(\n                C_rs_c.shape\n            )\n            loss += criterion(C_rs_c, target_img_batch)\n            loss += criterion(C_rs_f, target_img_batch)\n\n        try:\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n        except AttributeError:\n            continue\n\n        if i % display_every == 0:\n            F_c.eval()\n            F_f.eval()\n\n            with torch.no_grad():\n                test_W_i = E(test_source_image)\n\n                (_, C_rs_f) = run_one_iter_of_pixelnerf(\n                    test_ds,\n                    N_c,\n                    t_i_c_bin_edges,\n                    t_i_c_gap,\n                    test_os,\n                    camera_distance,\n                    scale,\n                    test_W_i,\n                    chunk_size,\n                    F_c,\n                    N_f,\n                    t_f,\n                    N_d,\n                    d_std,\n                    t_n,\n                    F_f,\n                )\n\n            loss = criterion(C_rs_f, test_target_image)\n            
print(f\"Loss: {loss.item()}\")\n            psnr = -10.0 * torch.log10(loss)\n\n            psnrs.append(psnr.item())\n            iternums.append(i)\n\n            plt.figure(figsize=(10, 4))\n            plt.subplot(121)\n            plt.imshow(C_rs_f.detach().cpu().numpy())\n            plt.title(f\"Iteration {i}\")\n            plt.subplot(122)\n            plt.plot(iternums, psnrs)\n            plt.title(\"PSNR\")\n            plt.show()\n\n            F_c.train()\n            F_f.train()\n\n    print(\"Done!\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "run_pixelnerf_alt.py",
    "content": "import matplotlib.pyplot as plt\nimport numpy as np\nimport torch\n\nfrom image_encoder import ImageEncoder\nfrom pixelnerf_dataset import PixelNeRFDataset\nfrom torch import nn, optim\n\n\nclass PixelNeRFFCResNet(nn.Module):\n    def __init__(self):\n        super().__init__()\n        # Number of encoding functions for positions. See Section B.1 in the\n        # Supplementary Materials.\n        self.L_pos = 6\n        # Number of encoding functions for viewing directions.\n        self.L_dir = 0\n        pos_enc_feats = 3 + 3 * 2 * self.L_pos\n        dir_enc_feats = 3 + 3 * 2 * self.L_dir\n\n        # Set up ResNet MLP. See Section B.1 and Figure 18 in the Supplementary\n        # Materials.\n        net_width = 512\n        self.first_layer = nn.Sequential(\n            nn.Linear(pos_enc_feats + dir_enc_feats, net_width)\n        )\n        self.n_resnet_blocks = 5\n        z_linears = []\n        mlps = []\n        for resnet_block in range(self.n_resnet_blocks):\n            z_linears.append(nn.Linear(net_width, net_width))\n            mlps.append(\n                nn.Sequential(\n                    nn.Linear(net_width, net_width),\n                    nn.ReLU(),\n                    nn.Linear(net_width, net_width),\n                    nn.ReLU(),\n                )\n            )\n\n        self.z_linears = nn.ModuleList(z_linears)\n        self.mlps = nn.ModuleList(mlps)\n        self.final_layer = nn.Linear(net_width, 4)\n\n    def forward(self, xs, ds, zs):\n        xs_encoded = [xs]\n        for l_pos in range(self.L_pos):\n            xs_encoded.append(torch.sin(2**l_pos * torch.pi * xs))\n            xs_encoded.append(torch.cos(2**l_pos * torch.pi * xs))\n\n        xs_encoded = torch.cat(xs_encoded, dim=-1)\n\n        ds = ds / ds.norm(p=2, dim=-1).unsqueeze(-1)\n        ds_encoded = [ds]\n        for l_dir in range(self.L_dir):\n            ds_encoded.append(torch.sin(2**l_dir * torch.pi * ds))\n            
ds_encoded.append(torch.cos(2**l_dir * torch.pi * ds))\n\n        ds_encoded = torch.cat(ds_encoded, dim=-1)\n\n        # Use the network to predict colors (c_is) and volume densities (sigma_is) for\n        # 3D points (xs) along rays given the viewing directions (ds) of the rays\n        # and the associated input image features (zs). See Section B.1 and Figure 18 in\n        # the Supplementary Materials and:\n        # https://github.com/sxyu/pixel-nerf/blob/master/src/model/resnetfc.py.\n        outputs = self.first_layer(torch.cat([xs_encoded, ds_encoded], dim=-1))\n        for block_idx in range(self.n_resnet_blocks):\n            resnet_zs = self.z_linears[block_idx](zs)\n            outputs = outputs + resnet_zs\n            outputs = self.mlps[block_idx](outputs) + outputs\n\n        outputs = self.final_layer(outputs)\n        sigma_is = torch.relu(outputs[:, 0])\n        c_is = torch.sigmoid(outputs[:, 1:])\n        return {\"c_is\": c_is, \"sigma_is\": sigma_is}\n\n\nclass PixelNeRF:\n    def __init__(self, device, camera_distance, scale):\n        self.device = device\n\n        # See Section B.1 in the Supplementary Materials,\n        # and: https://github.com/sxyu/pixel-nerf/blob/a5a514224272a91e3ec590f215567032e1f1c260/conf/default.conf#L50,\n        # and: https://github.com/sxyu/pixel-nerf/blob/a5a514224272a91e3ec590f215567032e1f1c260/src/render/nerf.py#L150.\n        self.N_c = N_c = 64\n        self.N_f = 16\n        self.N_d = 16\n        self.d_std = 0.01\n\n        self.t_n = t_n = 1.0\n        self.t_f = t_f = 4.0\n        self.t_i_c_gap = t_i_c_gap = (t_f - t_n) / N_c\n        self.t_i_c_bin_edges = (t_n + torch.arange(N_c) * t_i_c_gap).to(device)\n\n        self.F_c = PixelNeRFFCResNet().to(device)\n        self.F_f = PixelNeRFFCResNet().to(device)\n        self.E = ImageEncoder().to(device)\n\n        self.camera_distance = camera_distance\n        self.scale = scale\n\n        self.chunk_size = 1024 * 32\n\n    def 
get_coarse_query_points(self, ds, os):\n        u_is_c = torch.rand(*list(ds.shape[:2]) + [self.N_c]).to(ds)\n        t_is_c = self.t_i_c_bin_edges + u_is_c * self.t_i_c_gap\n        r_ts_c = os[..., None, :] + t_is_c[..., :, None] * ds[..., None, :]\n        return (r_ts_c, t_is_c)\n\n    def get_fine_query_points(self, w_is_c, t_is_c, os, ds, r_ts_c):\n        w_is_c = w_is_c + 1e-5\n        pdfs = w_is_c / torch.sum(w_is_c, dim=-1, keepdim=True)\n        cdfs = torch.cumsum(pdfs, dim=-1)\n        cdfs = torch.cat([torch.zeros_like(cdfs[..., :1]), cdfs[..., :-1]], dim=-1)\n\n        us = torch.rand(list(cdfs.shape[:-1]) + [self.N_f]).to(w_is_c)\n\n        idxs = torch.searchsorted(cdfs, us, right=True)\n        t_i_f_bottom_edges = torch.gather(t_is_c, 2, idxs - 1)\n        idxs_capped = idxs.clone()\n        max_ind = cdfs.shape[-1]\n        idxs_capped[idxs_capped == max_ind] = max_ind - 1\n        t_i_f_top_edges = torch.gather(t_is_c, 2, idxs_capped)\n        t_i_f_top_edges[idxs == max_ind] = self.t_f\n        t_i_f_gaps = t_i_f_top_edges - t_i_f_bottom_edges\n        u_is_f = torch.rand_like(t_i_f_gaps).to(os)\n        t_is_f = t_i_f_bottom_edges + u_is_f * t_i_f_gaps\n\n        # See Section B.1 in the Supplementary Materials and:\n        # https://github.com/sxyu/pixel-nerf/blob/a5a514224272a91e3ec590f215567032e1f1c260/src/render/nerf.py#L150.\n        t_is_d = (w_is_c * r_ts_c[..., 2]).sum(dim=-1)\n        t_is_d = t_is_d.unsqueeze(2).repeat((1, 1, self.N_d))\n        t_is_d = t_is_d + torch.normal(0, self.d_std, size=t_is_d.shape).to(t_is_d)\n        t_is_d = torch.clamp(t_is_d, self.t_n, self.t_f)\n\n        t_is_f = torch.cat([t_is_c, t_is_f.detach(), t_is_d], dim=-1)\n        (t_is_f, _) = torch.sort(t_is_f, dim=-1)\n        r_ts_f = os[..., None, :] + t_is_f[..., :, None] * ds[..., None, :]\n\n        return (r_ts_f, t_is_f)\n\n    def get_image_features_for_query_points(self, r_ts, W_i):\n        # Get the projected image coordinates (pi_x_is) for 
each point along the rays\n        # (r_ts). This is just geometry. See: http://www.songho.ca/opengl/gl_projectionmatrix.html.\n        pi_x_is = r_ts[..., :2] / (self.camera_distance - r_ts[..., 2].unsqueeze(-1))\n        pi_x_is = pi_x_is / self.scale\n        # PyTorch's grid_sample function assumes (-1, -1) is the left-top pixel, but we want\n        # (-1, -1) to be the left-bottom pixel, so we negate the y-coordinates.\n        pi_x_is[..., 1] = -1 * pi_x_is[..., 1]\n        # PyTorch's grid_sample function expects the grid to have shape\n        # (N, H_out, W_out, 2).\n        pi_x_is = pi_x_is.permute(2, 0, 1, 3)\n        # PyTorch's grid_sample function expects the input to have shape (N, C, H_in, W_in).\n        W_i = W_i.repeat(pi_x_is.shape[0], 1, 1, 1)\n        # Get the image features (z_is) associated with the projected image coordinates\n        # (pi_x_is) from the encoded image features (W_i). See Section 4.2.\n        z_is = nn.functional.grid_sample(\n            W_i, pi_x_is, align_corners=True, padding_mode=\"border\"\n        )\n        # Convert shape back to match rays.\n        z_is = z_is.permute(2, 3, 0, 1)\n        return z_is\n\n    def render_radiance_volume(self, r_ts, ds, z_is, F, t_is):\n        r_ts_flat = r_ts.reshape((-1, 3))\n        ds_rep = ds.unsqueeze(2).repeat(1, 1, r_ts.shape[-2], 1)\n        ds_flat = ds_rep.reshape((-1, 3))\n        z_is_flat = z_is.reshape((ds_flat.shape[0], -1))\n        c_is = []\n        sigma_is = []\n        for chunk_start in range(0, r_ts_flat.shape[0], self.chunk_size):\n            r_ts_batch = r_ts_flat[chunk_start : chunk_start + self.chunk_size]\n            ds_batch = ds_flat[chunk_start : chunk_start + self.chunk_size]\n            w_is_batch = z_is_flat[chunk_start : chunk_start + self.chunk_size]\n            preds = F(r_ts_batch, ds_batch, w_is_batch)\n            c_is.append(preds[\"c_is\"])\n            sigma_is.append(preds[\"sigma_is\"])\n\n        c_is = 
torch.cat(c_is).reshape(r_ts.shape)\n        sigma_is = torch.cat(sigma_is).reshape(r_ts.shape[:-1])\n\n        delta_is = t_is[..., 1:] - t_is[..., :-1]\n        one_e_10 = torch.Tensor([1e10]).expand(delta_is[..., :1].shape)\n        delta_is = torch.cat([delta_is, one_e_10.to(delta_is)], dim=-1)\n        delta_is = delta_is * ds.norm(dim=-1).unsqueeze(-1)\n\n        alpha_is = 1.0 - torch.exp(-sigma_is * delta_is)\n\n        T_is = torch.cumprod(1.0 - alpha_is + 1e-10, -1)\n        T_is = torch.roll(T_is, 1, -1)\n        T_is[..., 0] = 1.0\n\n        w_is = T_is * alpha_is\n\n        C_rs = (w_is[..., None] * c_is).sum(dim=-2)\n\n        return (C_rs, w_is)\n\n    def __call__(self, ds, os, source_image):\n        (r_ts_c, t_is_c) = self.get_coarse_query_points(ds, os)\n        # Extract feature pyramid from image. See Section 4.1, Section B.1 in the\n        # Supplementary Materials, and: https://github.com/sxyu/pixel-nerf/blob/master/src/model/encoder.py.\n        with torch.no_grad():\n            W_i = self.E(source_image.unsqueeze(0).permute(0, 3, 1, 2).to(self.device))\n\n        z_is_c = self.get_image_features_for_query_points(r_ts_c, W_i)\n        (C_rs_c, w_is_c) = self.render_radiance_volume(\n            r_ts_c, ds, z_is_c, self.F_c, t_is_c\n        )\n\n        (r_ts_f, t_is_f) = self.get_fine_query_points(w_is_c, t_is_c, os, ds, r_ts_c)\n        z_is_f = self.get_image_features_for_query_points(r_ts_f, W_i)\n        (C_rs_f, _) = self.render_radiance_volume(r_ts_f, ds, z_is_f, self.F_f, t_is_f)\n        return (C_rs_c, C_rs_f)\n\n\ndef load_data():\n    # Initialize dataset and test object/poses.\n    data_dir = \"data\"\n    # See Section B.2.1 in the Supplementary Materials.\n    num_iters = 400000\n    test_obj_idx = 5\n    test_source_pose_idx = 11\n    test_target_pose_idx = 33\n    train_dataset = PixelNeRFDataset(\n        data_dir, num_iters, test_obj_idx, test_source_pose_idx, test_target_pose_idx\n    )\n\n    return (num_iters, 
train_dataset)\n\n\ndef set_up_test_data(train_dataset, device):\n    obj_idx = train_dataset.test_obj_idx\n    obj = train_dataset.objs[obj_idx]\n    data_dir = train_dataset.data_dir\n    obj_dir = f\"{data_dir}/{obj}\"\n\n    z_len = train_dataset.z_len\n    source_pose_idx = train_dataset.test_source_pose_idx\n    source_img_f = f\"{obj_dir}/{str(source_pose_idx).zfill(z_len)}.npy\"\n    source_image = np.load(source_img_f) / 255\n    source_pose = train_dataset.poses[obj_idx, source_pose_idx]\n    source_R = source_pose[:3, :3]\n\n    target_pose_idx = train_dataset.test_target_pose_idx\n    target_img_f = f\"{obj_dir}/{str(target_pose_idx).zfill(z_len)}.npy\"\n    target_image = np.load(target_img_f) / 255\n    target_pose = train_dataset.poses[obj_idx, target_pose_idx]\n    target_R = target_pose[:3, :3]\n\n    R = torch.Tensor(source_R.T @ target_R).to(device)\n\n    plt.imshow(source_image)\n    plt.show()\n    source_image = torch.Tensor(source_image)\n    source_image = (\n        source_image - train_dataset.channel_means\n    ) / train_dataset.channel_stds\n    plt.imshow(target_image)\n    plt.show()\n    target_image = torch.Tensor(target_image).to(device)\n\n    return (source_image, R, target_image)\n\n\ndef main():\n    seed = 9458\n    torch.manual_seed(seed)\n    np.random.seed(seed)\n    device = \"cuda:0\"\n\n    (num_iters, train_dataset) = load_data()\n    img_size = train_dataset[0][2].shape[0]\n\n    pixelnerf = PixelNeRF(device, train_dataset.camera_distance, train_dataset.scale)\n    # See Section B.2 in the Supplementary Materials.\n    batch_img_size = 12\n    n_batch_pix = batch_img_size**2\n    n_objs = 4\n\n    # See Section B.2 in the Supplementary Materials.\n    lr = 1e-4\n    train_params = list(pixelnerf.F_c.parameters()) + list(pixelnerf.F_f.parameters())\n    optimizer = optim.Adam(train_params, lr=lr)\n    criterion = nn.MSELoss()\n\n    (test_source_image, test_R, test_target_image) = set_up_test_data(\n        
train_dataset, device\n    )\n    init_o = train_dataset.init_o.to(device)\n    init_ds = train_dataset.init_ds.to(device)\n    test_ds = torch.einsum(\"ij,hwj->hwi\", test_R, init_ds)\n    test_os = (test_R @ init_o).expand(test_ds.shape)\n\n    psnrs = []\n    iternums = []\n    use_bbox = True\n    num_bbox_iters = 300000\n    display_every = 100\n    pixelnerf.F_c.train()\n    pixelnerf.F_f.train()\n    pixelnerf.E.eval()\n    for i in range(num_iters):\n        if i == num_bbox_iters:\n            use_bbox = False\n\n        loss = 0\n        for obj in range(n_objs):\n            try:\n                (source_image, R, target_image, bbox) = train_dataset[0]\n            except ValueError:\n                continue\n\n            R = R.to(device)\n            ds = torch.einsum(\"ij,hwj->hwi\", R, init_ds)\n            os = (R @ init_o).expand(ds.shape)\n\n            if use_bbox:\n                pix_rows = np.arange(bbox[0], bbox[2])\n                pix_cols = np.arange(bbox[1], bbox[3])\n            else:\n                pix_rows = np.arange(0, img_size)\n                pix_cols = np.arange(0, img_size)\n\n            pix_row_cols = np.meshgrid(pix_rows, pix_cols, indexing=\"ij\")\n            pix_row_cols = np.stack(pix_row_cols).transpose(1, 2, 0).reshape(-1, 2)\n            choices = np.arange(len(pix_row_cols))\n            try:\n                selected_pix = np.random.choice(choices, n_batch_pix, False)\n            except ValueError:\n                continue\n\n            pix_idx_rows = pix_row_cols[selected_pix, 0]\n            pix_idx_cols = pix_row_cols[selected_pix, 1]\n            ds_batch = ds[pix_idx_rows, pix_idx_cols].reshape(\n                batch_img_size, batch_img_size, -1\n            )\n            os_batch = os[pix_idx_rows, pix_idx_cols].reshape(\n                batch_img_size, batch_img_size, -1\n            )\n\n            (C_rs_c, C_rs_f) = pixelnerf(ds_batch, os_batch, source_image)\n            target_img = 
target_image.to(device)\n            target_img_batch = target_img[pix_idx_rows, pix_idx_cols].reshape(\n                C_rs_c.shape\n            )\n            loss += criterion(C_rs_c, target_img_batch)\n            loss += criterion(C_rs_f, target_img_batch)\n\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        if i % display_every == 0:\n            pixelnerf.F_c.eval()\n            pixelnerf.F_f.eval()\n\n            with torch.no_grad():\n                (_, C_rs_f) = pixelnerf(test_ds, test_os, test_source_image)\n\n            loss = criterion(C_rs_f, test_target_image)\n            print(f\"Loss: {loss.item()}\")\n            psnr = -10.0 * torch.log10(loss)\n\n            psnrs.append(psnr.item())\n            iternums.append(i)\n\n            plt.figure(figsize=(10, 4))\n            plt.subplot(121)\n            plt.imshow(C_rs_f.detach().cpu().numpy())\n            plt.title(f\"Iteration {i}\")\n            plt.subplot(122)\n            plt.plot(iternums, psnrs)\n            plt.title(\"PSNR\")\n            plt.show()\n\n            pixelnerf.F_c.train()\n            pixelnerf.F_f.train()\n\n    print(\"Done!\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "run_tiny_nerf.py",
    "content": "import matplotlib.pyplot as plt\nimport numpy as np\nimport torch\n\nfrom torch import nn, optim\n\n\ndef get_coarse_query_points(ds, N_c, t_i_c_bin_edges, t_i_c_gap, os):\n    u_is_c = torch.rand(*list(ds.shape[:2]) + [N_c]).to(ds)\n    t_is_c = t_i_c_bin_edges + u_is_c * t_i_c_gap\n    r_ts_c = os[..., None, :] + t_is_c[..., :, None] * ds[..., None, :]\n    return (r_ts_c, t_is_c)\n\n\ndef render_radiance_volume(r_ts, ds, chunk_size, F, t_is):\n    r_ts_flat = r_ts.reshape((-1, 3))\n    ds_rep = ds.unsqueeze(2).repeat(1, 1, r_ts.shape[-2], 1)\n    ds_flat = ds_rep.reshape((-1, 3))\n    c_is = []\n    sigma_is = []\n    for chunk_start in range(0, r_ts_flat.shape[0], chunk_size):\n        r_ts_batch = r_ts_flat[chunk_start : chunk_start + chunk_size]\n        ds_batch = ds_flat[chunk_start : chunk_start + chunk_size]\n        preds = F(r_ts_batch, ds_batch)\n        c_is.append(preds[\"c_is\"])\n        sigma_is.append(preds[\"sigma_is\"])\n\n    c_is = torch.cat(c_is).reshape(r_ts.shape)\n    sigma_is = torch.cat(sigma_is).reshape(r_ts.shape[:-1])\n\n    delta_is = t_is[..., 1:] - t_is[..., :-1]\n    one_e_10 = torch.Tensor([1e10]).expand(delta_is[..., :1].shape)\n    delta_is = torch.cat([delta_is, one_e_10.to(delta_is)], dim=-1)\n    delta_is = delta_is * ds.norm(dim=-1).unsqueeze(-1)\n\n    alpha_is = 1.0 - torch.exp(-sigma_is * delta_is)\n\n    T_is = torch.cumprod(1.0 - alpha_is + 1e-10, -1)\n    T_is = torch.roll(T_is, 1, -1)\n    T_is[..., 0] = 1.0\n\n    w_is = T_is * alpha_is\n\n    C_rs = (w_is[..., None] * c_is).sum(dim=-2)\n\n    return C_rs\n\n\ndef run_one_iter_of_tiny_nerf(ds, N_c, t_i_c_bin_edges, t_i_c_gap, os, chunk_size, F_c):\n    (r_ts_c, t_is_c) = get_coarse_query_points(ds, N_c, t_i_c_bin_edges, t_i_c_gap, os)\n    C_rs_c = render_radiance_volume(r_ts_c, ds, chunk_size, F_c, t_is_c)\n    return C_rs_c\n\n\nclass VeryTinyNeRFMLP(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.L_pos = 6\n        
self.L_dir = 4\n        pos_enc_feats = 3 + 3 * 2 * self.L_pos\n        dir_enc_feats = 3 + 3 * 2 * self.L_dir\n\n        net_width = 256\n        self.early_mlp = nn.Sequential(\n            nn.Linear(pos_enc_feats, net_width),\n            nn.ReLU(),\n            nn.Linear(net_width, net_width + 1),\n            nn.ReLU(),\n        )\n        self.late_mlp = nn.Sequential(\n            nn.Linear(net_width + dir_enc_feats, net_width),\n            nn.ReLU(),\n            nn.Linear(net_width, 3),\n            nn.Sigmoid(),\n        )\n\n    def forward(self, xs, ds):\n        xs_encoded = [xs]\n        for l_pos in range(self.L_pos):\n            xs_encoded.append(torch.sin(2**l_pos * torch.pi * xs))\n            xs_encoded.append(torch.cos(2**l_pos * torch.pi * xs))\n\n        xs_encoded = torch.cat(xs_encoded, dim=-1)\n\n        ds = ds / ds.norm(p=2, dim=-1).unsqueeze(-1)\n        ds_encoded = [ds]\n        for l_dir in range(self.L_dir):\n            ds_encoded.append(torch.sin(2**l_dir * torch.pi * ds))\n            ds_encoded.append(torch.cos(2**l_dir * torch.pi * ds))\n\n        ds_encoded = torch.cat(ds_encoded, dim=-1)\n\n        outputs = self.early_mlp(xs_encoded)\n        sigma_is = outputs[:, 0]\n        c_is = self.late_mlp(torch.cat([outputs[:, 1:], ds_encoded], dim=-1))\n        return {\"c_is\": c_is, \"sigma_is\": sigma_is}\n\n\ndef main():\n    seed = 9458\n    torch.manual_seed(seed)\n    np.random.seed(seed)\n\n    device = \"cuda:0\"\n    F_c = VeryTinyNeRFMLP().to(device)\n    chunk_size = 16384\n\n    lr = 5e-3\n    optimizer = optim.Adam(F_c.parameters(), lr=lr)\n    criterion = nn.MSELoss()\n\n    data_f = \"66bdbc812bd0a196e194052f3f12cb2e.npz\"\n    data = np.load(data_f)\n\n    images = data[\"images\"] / 255\n    img_size = images.shape[1]\n    xs = torch.arange(img_size) - (img_size / 2 - 0.5)\n    ys = torch.arange(img_size) - (img_size / 2 - 0.5)\n    (xs, ys) = torch.meshgrid(xs, -ys, indexing=\"xy\")\n    focal = 
float(data[\"focal\"])\n    pixel_coords = torch.stack([xs, ys, torch.full_like(xs, -focal)], dim=-1)\n    camera_coords = pixel_coords / focal\n    init_ds = camera_coords.to(device)\n    init_o = torch.Tensor(np.array([0, 0, float(data[\"camera_distance\"])])).to(device)\n\n    test_idx = 150\n    plt.imshow(images[test_idx])\n    plt.show()\n    test_img = torch.Tensor(images[test_idx]).to(device)\n    poses = data[\"poses\"]\n    test_R = torch.Tensor(poses[test_idx, :3, :3]).to(device)\n    test_ds = torch.einsum(\"ij,hwj->hwi\", test_R, init_ds)\n    test_os = (test_R @ init_o).expand(test_ds.shape)\n\n    t_n = 1.0\n    t_f = 4.0\n    N_c = 32\n    t_i_c_gap = (t_f - t_n) / N_c\n    t_i_c_bin_edges = (t_n + torch.arange(N_c) * t_i_c_gap).to(device)\n\n    train_idxs = np.arange(len(images)) != test_idx\n    images = torch.Tensor(images[train_idxs])\n    poses = torch.Tensor(poses[train_idxs])\n    psnrs = []\n    iternums = []\n    num_iters = 20000\n    display_every = 100\n    F_c.train()\n    for i in range(num_iters):\n        target_img_idx = np.random.randint(images.shape[0])\n        target_pose = poses[target_img_idx].to(device)\n        R = target_pose[:3, :3]\n\n        ds = torch.einsum(\"ij,hwj->hwi\", R, init_ds)\n        os = (R @ init_o).expand(ds.shape)\n\n        C_rs_c = run_one_iter_of_tiny_nerf(\n            ds, N_c, t_i_c_bin_edges, t_i_c_gap, os, chunk_size, F_c\n        )\n        loss = criterion(C_rs_c, images[target_img_idx].to(device))\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        if i % display_every == 0:\n            F_c.eval()\n            with torch.no_grad():\n                C_rs_c = run_one_iter_of_tiny_nerf(\n                    test_ds, N_c, t_i_c_bin_edges, t_i_c_gap, test_os, chunk_size, F_c\n                )\n\n            loss = criterion(C_rs_c, test_img)\n            print(f\"Loss: {loss.item()}\")\n            psnr = -10.0 * torch.log10(loss)\n\n            
psnrs.append(psnr.item())\n            iternums.append(i)\n\n            plt.figure(figsize=(10, 4))\n            plt.subplot(121)\n            plt.imshow(C_rs_c.detach().cpu().numpy())\n            plt.title(f\"Iteration {i}\")\n            plt.subplot(122)\n            plt.plot(iternums, psnrs)\n            plt.title(\"PSNR\")\n            plt.show()\n\n            F_c.train()\n\n    print(\"Done!\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "run_tiny_nerf_alt.py",
    "content": "import matplotlib.pyplot as plt\nimport numpy as np\nimport torch\n\nfrom torch import nn, optim\n\n\nclass VeryTinyNeRFMLP(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.L_pos = 6\n        self.L_dir = 4\n        pos_enc_feats = 3 + 3 * 2 * self.L_pos\n        dir_enc_feats = 3 + 3 * 2 * self.L_dir\n\n        net_width = 256\n        self.early_mlp = nn.Sequential(\n            nn.Linear(pos_enc_feats, net_width),\n            nn.ReLU(),\n            nn.Linear(net_width, net_width + 1),\n            nn.ReLU(),\n        )\n        self.late_mlp = nn.Sequential(\n            nn.Linear(net_width + dir_enc_feats, net_width),\n            nn.ReLU(),\n            nn.Linear(net_width, 3),\n            nn.Sigmoid(),\n        )\n\n    def forward(self, xs, ds):\n        xs_encoded = [xs]\n        for l_pos in range(self.L_pos):\n            xs_encoded.append(torch.sin(2**l_pos * torch.pi * xs))\n            xs_encoded.append(torch.cos(2**l_pos * torch.pi * xs))\n\n        xs_encoded = torch.cat(xs_encoded, dim=-1)\n\n        ds = ds / ds.norm(p=2, dim=-1).unsqueeze(-1)\n        ds_encoded = [ds]\n        for l_dir in range(self.L_dir):\n            ds_encoded.append(torch.sin(2**l_dir * torch.pi * ds))\n            ds_encoded.append(torch.cos(2**l_dir * torch.pi * ds))\n\n        ds_encoded = torch.cat(ds_encoded, dim=-1)\n\n        outputs = self.early_mlp(xs_encoded)\n        sigma_is = outputs[:, 0]\n        c_is = self.late_mlp(torch.cat([outputs[:, 1:], ds_encoded], dim=-1))\n        return {\"c_is\": c_is, \"sigma_is\": sigma_is}\n\n\nclass VeryTinyNeRF:\n    def __init__(self, device):\n        self.F_c = VeryTinyNeRFMLP().to(device)\n        self.chunk_size = 16384\n        self.t_n = t_n = 1.0\n        self.t_f = t_f = 4.0\n        self.N_c = N_c = 32\n        self.t_i_c_gap = t_i_c_gap = (t_f - t_n) / N_c\n        self.t_i_c_bin_edges = (t_n + torch.arange(N_c) * t_i_c_gap).to(device)\n\n    def 
get_coarse_query_points(self, ds, os):\n        u_is_c = torch.rand(*list(ds.shape[:2]) + [self.N_c]).to(ds)\n        t_is_c = self.t_i_c_bin_edges + u_is_c * self.t_i_c_gap\n        r_ts_c = os[..., None, :] + t_is_c[..., :, None] * ds[..., None, :]\n        return (r_ts_c, t_is_c)\n\n    def render_radiance_volume(self, r_ts, ds, F, t_is):\n        r_ts_flat = r_ts.reshape((-1, 3))\n        ds_rep = ds.unsqueeze(2).repeat(1, 1, r_ts.shape[-2], 1)\n        ds_flat = ds_rep.reshape((-1, 3))\n        c_is = []\n        sigma_is = []\n        for chunk_start in range(0, r_ts_flat.shape[0], self.chunk_size):\n            r_ts_batch = r_ts_flat[chunk_start : chunk_start + self.chunk_size]\n            ds_batch = ds_flat[chunk_start : chunk_start + self.chunk_size]\n            preds = F(r_ts_batch, ds_batch)\n            c_is.append(preds[\"c_is\"])\n            sigma_is.append(preds[\"sigma_is\"])\n\n        c_is = torch.cat(c_is).reshape(r_ts.shape)\n        sigma_is = torch.cat(sigma_is).reshape(r_ts.shape[:-1])\n\n        delta_is = t_is[..., 1:] - t_is[..., :-1]\n        one_e_10 = torch.Tensor([1e10]).expand(delta_is[..., :1].shape)\n        delta_is = torch.cat([delta_is, one_e_10.to(delta_is)], dim=-1)\n        delta_is = delta_is * ds.norm(dim=-1).unsqueeze(-1)\n\n        alpha_is = 1.0 - torch.exp(-sigma_is * delta_is)\n\n        T_is = torch.cumprod(1.0 - alpha_is + 1e-10, -1)\n        T_is = torch.roll(T_is, 1, -1)\n        T_is[..., 0] = 1.0\n\n        w_is = T_is * alpha_is\n\n        C_rs = (w_is[..., None] * c_is).sum(dim=-2)\n\n        return C_rs\n\n    def __call__(self, ds, os):\n        (r_ts_c, t_is_c) = self.get_coarse_query_points(ds, os)\n        C_rs_c = self.render_radiance_volume(r_ts_c, ds, self.F_c, t_is_c)\n        return C_rs_c\n\n\ndef load_data(device):\n    data_f = \"66bdbc812bd0a196e194052f3f12cb2e.npz\"\n    data = np.load(data_f)\n\n    images = data[\"images\"] / 255\n    img_size = images.shape[1]\n    xs = 
torch.arange(img_size) - (img_size / 2 - 0.5)\n    ys = torch.arange(img_size) - (img_size / 2 - 0.5)\n    (xs, ys) = torch.meshgrid(xs, -ys, indexing=\"xy\")\n    focal = float(data[\"focal\"])\n    pixel_coords = torch.stack([xs, ys, torch.full_like(xs, -focal)], dim=-1)\n    camera_coords = pixel_coords / focal\n    init_ds = camera_coords.to(device)\n    init_o = torch.Tensor(np.array([0, 0, float(data[\"camera_distance\"])])).to(device)\n\n    return (images, data[\"poses\"], init_ds, init_o, img_size)\n\n\ndef set_up_test_data(images, device, poses, init_ds, init_o):\n    test_idx = 150\n    plt.imshow(images[test_idx])\n    plt.show()\n    test_img = torch.Tensor(images[test_idx]).to(device)\n    test_R = torch.Tensor(poses[test_idx, :3, :3]).to(device)\n    test_ds = torch.einsum(\"ij,hwj->hwi\", test_R, init_ds)\n    test_os = (test_R @ init_o).expand(test_ds.shape)\n\n    train_idxs = np.arange(len(images)) != test_idx\n\n    return (test_ds, test_os, test_img, train_idxs)\n\n\ndef main():\n    seed = 9458\n    torch.manual_seed(seed)\n    np.random.seed(seed)\n\n    device = \"cuda:0\"\n    nerf = VeryTinyNeRF(device)\n\n    lr = 5e-3\n    optimizer = optim.Adam(nerf.F_c.parameters(), lr=lr)\n    criterion = nn.MSELoss()\n\n    # load_data's fifth return value is img_size, not a test image.\n    (images, poses, init_ds, init_o, img_size) = load_data(device)\n    (test_ds, test_os, test_img, train_idxs) = set_up_test_data(\n        images, device, poses, init_ds, init_o\n    )\n    images = torch.Tensor(images[train_idxs])\n    poses = torch.Tensor(poses[train_idxs])\n\n    psnrs = []\n    iternums = []\n    num_iters = 20000\n    display_every = 100\n    nerf.F_c.train()\n    for i in range(num_iters):\n        target_img_idx = np.random.randint(images.shape[0])\n        target_pose = poses[target_img_idx].to(device)\n        R = target_pose[:3, :3]\n\n        ds = torch.einsum(\"ij,hwj->hwi\", R, init_ds)\n        os = (R @ init_o).expand(ds.shape)\n\n        C_rs_c = nerf(ds, os)\n        loss = criterion(C_rs_c, 
images[target_img_idx].to(device))\n        optimizer.zero_grad()\n        loss.backward()\n        optimizer.step()\n\n        if i % display_every == 0:\n            nerf.F_c.eval()\n            with torch.no_grad():\n                C_rs_c = nerf(test_ds, test_os)\n\n            loss = criterion(C_rs_c, test_img)\n            print(f\"Loss: {loss.item()}\")\n            psnr = -10.0 * torch.log10(loss)\n\n            psnrs.append(psnr.item())\n            iternums.append(i)\n\n            plt.figure(figsize=(10, 4))\n            plt.subplot(121)\n            plt.imshow(C_rs_c.detach().cpu().numpy())\n            plt.title(f\"Iteration {i}\")\n            plt.subplot(122)\n            plt.plot(iternums, psnrs)\n            plt.title(\"PSNR\")\n            plt.show()\n\n            nerf.F_c.train()\n\n    print(\"Done!\")\n\n\nif __name__ == \"__main__\":\n    main()\n"
  }
]