Repository: airalcorn2/pytorch-nerf Branch: master Commit: c9a055558e57 Files: 16 Total size: 23.1 MB Directory structure: gitextract_esqdzl8y/ ├── .gitignore ├── 66bdbc812bd0a196e194052f3f12cb2e.npz ├── LICENSE ├── README.md ├── generate_nerf_dataset.py ├── generate_pixelnerf_dataset.py ├── image_encoder.py ├── pixelnerf_dataset.py ├── renderer.py ├── renderer_settings.py ├── run_nerf.py ├── run_nerf_alt.py ├── run_pixelnerf.py ├── run_pixelnerf_alt.py ├── run_tiny_nerf.py └── run_tiny_nerf_alt.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ .idea __pycache__ data.zip data ================================================ FILE: 66bdbc812bd0a196e194052f3f12cb2e.npz ================================================ [File too large to display: 23.0 MB] ================================================ FILE: LICENSE ================================================ Copyright 2022 Michael A. Alcorn Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================ # PyTorch NeRF and pixelNeRF **NeRF**: [![Open NeRF in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1oRnnlF-2YqCDIzoc-uShQm8_yymLKiqr) **Tiny NeRF**: [![Open Tiny NeRF in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1ntlbzQ121-E1BSa5EKvAyai6SMG4cylj) **pixelNeRF**: [![Open pixelNeRF in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1VEEy4VOVoQTQKo4oG3nWcfKAXjC_0fFt) This repository contains minimal PyTorch implementations of the NeRF model described in "[NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis](https://arxiv.org/abs/2003.08934)" and the pixelNeRF model described in ["pixelNeRF: Neural Radiance Fields from One or Few Images"](https://arxiv.org/abs/2012.02190). While there are other PyTorch implementations out there (e.g., [this one](https://github.com/krrish94/nerf-pytorch) and [this one](https://github.com/yenchenlin/nerf-pytorch) for NeRF, and [the authors' official implementation](https://github.com/sxyu/pixel-nerf) for pixelNeRF), I personally found them somewhat difficult to follow, so I decided to do a complete rewrite of NeRF myself. I tried to stay as close to the authors' text as possible, and I added comments in the code referring back to the relevant sections/equations in the paper. The final result is a tight 355 lines of heavily commented code (301 sloc—"source lines of code"—on GitHub) all contained in [a single file](run_nerf.py). For comparison, [this PyTorch implementation](https://github.com/krrish94/nerf-pytorch) has approximately 970 sloc spread across several files, while [this PyTorch implementation](https://github.com/yenchenlin/nerf-pytorch) has approximately 905 sloc. [`run_tiny_nerf.py`](run_tiny_nerf.py) trains a simplified NeRF model inspired by the "[Tiny NeRF](https://colab.research.google.com/github/bmild/nerf/blob/master/tiny_nerf.ipynb)" example provided by the NeRF authors. This NeRF model does not use fine sampling and the MLP is smaller, but the code is otherwise identical to the full model code. At only 153 sloc, it might be a good place to start for people who are completely new to NeRF. If you prefer your code more object-oriented, check out [`run_nerf_alt.py`](run_nerf_alt.py) and [`run_tiny_nerf_alt.py`](run_tiny_nerf_alt.py). A Colab notebook for the full model can be found [here](https://colab.research.google.com/drive/1oRnnlF-2YqCDIzoc-uShQm8_yymLKiqr?usp=sharing), while a notebook for the tiny model can be found [here](https://colab.research.google.com/drive/1ntlbzQ121-E1BSa5EKvAyai6SMG4cylj?usp=sharing). The [`generate_nerf_dataset.py`](generate_nerf_dataset.py) script was used to generate the training data of the ShapeNet car (see "[Generating the ShapeNet datasets](#generating-the-shapenet-datasets)" for additional details). For the following test view: ![](test_view.png) [`run_nerf.py`](run_nerf.py) generated the following after 20,100 iterations (a few hours on a P100 GPU): **Loss**: 0.00022201683896128088 ![](nerf.png) while [`run_tiny_nerf.py`](run_tiny_nerf.py) generated the following after 19,600 iterations (~35 minutes on a P100 GPU): **Loss**: 0.0004151524917688221 ![](tiny_nerf.png) The advantages of streamlining NeRF's code become readily apparent when trying to extend NeRF. For example, [training a pixelNeRF model](run_pixelnerf.py) only required making a few changes to [`run_nerf.py`](run_nerf.py) bringing it to 368 sloc (notebook [here](https://colab.research.google.com/drive/1VEEy4VOVoQTQKo4oG3nWcfKAXjC_0fFt?usp=sharing)). For comparison, [the official pixelNeRF implementation](https://github.com/sxyu/pixel-nerf) has approximately 1,300 pixelNeRF-specific (i.e., not related to the image encoder or dataset) sloc spread across several files. The [`generate_pixelnerf_dataset.py`](generate_pixelnerf_dataset.py) script was used to generate the training data of ShapeNet cars (see "[Generating the ShapeNet datasets](#generating-the-shapenet-datasets)" for additional details). For the following source object and view: ![](pixelnerf_src.png) and target view: ![](pixelnerf_tgt.png) [`run_pixelnerf.py`](run_pixelnerf.py) generated the following after 73,243 iterations (~12 hours on a P100 GPU; the full pixelNeRF model was trained for 400,000 iterations, which took six days): **Loss**: 0.004468636587262154 ![](pixelnerf.png) The "smearing" is an artifact caused by the bounding box sampling method. ## Generating the ShapeNet datasets 1) Download the data (the ShapeNet server is pretty slow, so this will take a while): ```bash SHAPENET_BASE_DIR= nohup wget --quiet -P ${SHAPENET_BASE_DIR} http://shapenet.cs.stanford.edu/shapenet/obj-zip/ShapeNetCore.v2.zip > shapenet.log & ``` 2) Unzip the data: ```bash cd ${SHAPENET_BASE_DIR} nohup unzip -q ShapeNetCore.v2.zip > shapenet.log & ``` 3) ***After*** the file is done unzipping, remove the ZIP: ```bash rm ShapeNetCore.v2.zip ``` 4) Change the `SHAPENET_DIR` variable in [`generate_nerf_dataset.py`](generate_nerf_dataset.py) and [`generate_pixelnerf_dataset.py`](generate_pixelnerf_dataset.py) to `/ShapeNetCore.v2`. ================================================ FILE: generate_nerf_dataset.py ================================================ import numpy as np from pyrr import Matrix44 from renderer import gen_rotation_matrix_from_cam_pos, Renderer from renderer_settings import * SHAPENET_DIR = "/run/media/airalcorn2/MiQ BIG/ShapeNetCore.v2" def main(): # Set up the renderer. renderer = Renderer( camera_distance=CAMERA_DISTANCE, angle_of_view=ANGLE_OF_VIEW, dir_light=DIR_LIGHT, dif_int=DIF_INT, amb_int=AMB_INT, default_width=WINDOW_SIZE, default_height=WINDOW_SIZE, cull_faces=CULL_FACES, ) img_size = 100 # Calculate focal length in pixel units. This is just geometry. See: # https://en.wikipedia.org/wiki/Angle_of_view#Derivation_of_the_angle-of-view_formula. focal = (img_size / 2) / np.tan(np.radians(ANGLE_OF_VIEW) / 2) # Load the ShapeNet car object. obj = "66bdbc812bd0a196e194052f3f12cb2e" cat = "02958343" obj_mtl_path = f"{SHAPENET_DIR}/{cat}/{obj}/models/model_normalized" renderer.set_up_obj(f"{obj_mtl_path}.obj", f"{obj_mtl_path}.mtl") # Generate car renders using random camera locations. init_cam_pos = np.array([0, 0, CAMERA_DISTANCE]) target = np.zeros(3) up = np.array([0.0, 1.0, 0.0]) samps = 800 imgs = [] poses = [] for idx in range(samps): # See: https://stats.stackexchange.com/a/7984/81836. xyz = np.random.normal(size=3) xyz /= np.linalg.norm(xyz) R = gen_rotation_matrix_from_cam_pos(xyz) eye = tuple((R @ init_cam_pos).flatten()) look_at = Matrix44.look_at(eye, target, up) renderer.prog["VP"].write( (look_at @ renderer.perspective).astype("f4").tobytes() ) renderer.prog["cam_pos"].value = eye image = renderer.render(0.5, 0.5, 0.5).resize((img_size, img_size)) imgs.append(np.array(image)) pose = np.eye(4) pose[:3, :3] = np.array(look_at[:3, :3]) pose[:3, 3] = -look_at[:3, :3] @ look_at[3, :3] poses.append(pose) imgs = np.stack(imgs) poses = np.stack(poses) np.savez( f"{obj}.npz", images=imgs, poses=poses, focal=focal, camera_distance=CAMERA_DISTANCE, ) if __name__ == "__main__": main() ================================================ FILE: generate_pixelnerf_dataset.py ================================================ import numpy as np import os import sys from pyrr import Matrix44 from renderer import gen_rotation_matrix_from_cam_pos, Renderer from renderer_settings import * SHAPENET_DIR = "/run/media/airalcorn2/MiQ BIG/ShapeNetCore.v2" def main(): # Set up the renderer. renderer = Renderer( camera_distance=CAMERA_DISTANCE, angle_of_view=ANGLE_OF_VIEW, dir_light=DIR_LIGHT, dif_int=DIF_INT, amb_int=AMB_INT, default_width=WINDOW_SIZE, default_height=WINDOW_SIZE, cull_faces=CULL_FACES, ) # See Section 5.1.1. img_size = 128 # Calculate focal length in pixel units. This is just geometry. See: # https://en.wikipedia.org/wiki/Angle_of_view#Derivation_of_the_angle-of-view_formula. focal = (img_size / 2) / np.tan(np.radians(ANGLE_OF_VIEW) / 2) # Generate car renders using random camera locations. init_cam_pos = np.array([0, 0, CAMERA_DISTANCE]) target = np.zeros(3) up = np.array([0.0, 1.0, 0.0]) # See Section 5.1.1. samps = 50 z_len = len(str(samps - 1)) data_dir = "data" poses = [] os.mkdir(data_dir) # Car category. cat = "02958343" objs = os.listdir(f"{SHAPENET_DIR}/{cat}") used_objs = [] for obj in objs: # Load the ShapeNet object. obj_mtl_path = f"{SHAPENET_DIR}/{cat}/{obj}/models/model_normalized" try: renderer.set_up_obj(f"{obj_mtl_path}.obj", f"{obj_mtl_path}.mtl") sys.stderr.flush() except OSError: print(f"{SHAPENET_DIR}/{cat}/{obj} is empty.", flush=True) continue except FloatingPointError: print(f"{SHAPENET_DIR}/{cat}/{obj} divides by zero.", flush=True) obj_dir = f"{data_dir}/{obj}" os.mkdir(obj_dir) obj_poses = [] for samp_idx in range(samps): # See: https://stats.stackexchange.com/a/7984/81836. xyz = np.random.normal(size=3) xyz /= np.linalg.norm(xyz) R = gen_rotation_matrix_from_cam_pos(xyz) eye = tuple((R @ init_cam_pos).flatten()) look_at = Matrix44.look_at(eye, target, up) renderer.prog["VP"].write( (look_at @ renderer.perspective).astype("f4").tobytes() ) renderer.prog["cam_pos"].value = eye image = renderer.render(0.5, 0.5, 0.5).resize((img_size, img_size)) np.save(f"{obj_dir}/{str(samp_idx).zfill(z_len)}.npy", np.array(image)) pose = np.eye(4) pose[:3, :3] = np.array(look_at[:3, :3]) pose[:3, 3] = -look_at[:3, :3] @ look_at[3, :3] obj_poses.append(pose) obj_poses = np.stack(obj_poses) poses.append(obj_poses) renderer.release_obj() used_objs.append(obj) poses = np.stack(poses) np.savez( f"{data_dir}/poses.npz", poses=poses, focal=focal, camera_distance=CAMERA_DISTANCE, ) with open(f"{data_dir}/objs.txt", "w") as f: print("\n".join(used_objs), file=f) if __name__ == "__main__": main() ================================================ FILE: image_encoder.py ================================================ import torch import torch.nn as nn import torch.nn.functional as F from torchvision.models import resnet34 class ImageEncoder(nn.Module): def __init__(self): super().__init__() self.resnet = resnet34(True) def forward(self, x): # Extract feature pyramid from image. See Section 4.1., Section B.1 in the # Supplementary Materials, and: https://github.com/sxyu/pixel-nerf/blob/master/src/model/encoder.py. x = self.resnet.conv1(x) x = self.resnet.bn1(x) feats1 = self.resnet.relu(x) feats2 = self.resnet.layer1(self.resnet.maxpool(feats1)) feats3 = self.resnet.layer2(feats2) feats4 = self.resnet.layer3(feats3) latents = [feats1, feats2, feats3, feats4] latent_sz = latents[0].shape[-2:] for i in range(len(latents)): latents[i] = F.interpolate( latents[i], latent_sz, mode="bilinear", align_corners=True ) latents = torch.cat(latents, dim=1) return latents ================================================ FILE: pixelnerf_dataset.py ================================================ import numpy as np import torch from torch.utils.data import Dataset class PixelNeRFDataset(Dataset): def __init__( self, data_dir, num_iters, test_obj_idx, test_source_pose_idx, test_target_pose_idx, ): self.data_dir = data_dir self.N = num_iters with open(f"{data_dir}/objs.txt") as f: self.objs = f.read().split("\n")[:-1] self.test_obj_idx = test_obj_idx self.test_source_pose_idx = test_source_pose_idx self.test_target_pose_idx = test_target_pose_idx data = np.load(f"{data_dir}/poses.npz") self.poses = poses = data["poses"] (n_objs, n_poses) = poses.shape[:2] self.z_len = len(str(n_poses - 1)) self.poses = torch.Tensor(poses) self.channel_means = torch.Tensor([0.485, 0.456, 0.406]) self.channel_stds = torch.Tensor([0.229, 0.224, 0.225]) samp_img = np.load(f"{data_dir}/{self.objs[0]}/{str(0).zfill(self.z_len)}.npy") img_size = samp_img.shape[0] self.pix_idxs = np.arange(img_size ** 2) xs = torch.arange(img_size) - (img_size / 2 - 0.5) ys = torch.arange(img_size) - (img_size / 2 - 0.5) (xs, ys) = torch.meshgrid(xs, -ys, indexing="xy") focal = float(data["focal"]) pixel_coords = torch.stack([xs, ys, torch.full_like(xs, -focal)], dim=-1) camera_coords = pixel_coords / focal self.init_ds = camera_coords self.camera_distance = camera_distance = float(data["camera_distance"]) self.init_o = torch.Tensor(np.array([0, 0, camera_distance])) # tan(theta) = opposite / adjacent. self.scale = (img_size / 2) / focal def __len__(self): return self.N def __getitem__(self, idx): obj_idx = np.random.randint(self.poses.shape[0]) obj = self.objs[obj_idx] obj_dir = f"{self.data_dir}/{obj}" source_pose_idx = np.random.randint(self.poses.shape[1]) if obj_idx == self.test_obj_idx: while source_pose_idx == self.test_source_pose_idx: source_pose_idx = np.random.randint(self.poses.shape[1]) source_img_f = f"{obj_dir}/{str(source_pose_idx).zfill(self.z_len)}.npy" source_image = torch.Tensor(np.load(source_img_f) / 255) source_image = (source_image - self.channel_means) / self.channel_stds source_pose = self.poses[obj_idx, source_pose_idx] source_R = source_pose[:3, :3] target_pose_idx = np.random.randint(self.poses.shape[1]) if obj_idx == self.test_obj_idx: while (target_pose_idx == self.test_source_pose_idx) or ( target_pose_idx == self.test_target_pose_idx ): target_pose_idx = np.random.randint(self.poses.shape[1]) target_img_f = f"{obj_dir}/{str(target_pose_idx).zfill(self.z_len)}.npy" target_image = np.load(target_img_f) not_gray_pix = np.argwhere((target_image == 128).sum(-1) != 3) top_row = not_gray_pix[:, 0].min() bottom_row = not_gray_pix[:, 0].max() left_col = not_gray_pix[:, 1].min() right_col = not_gray_pix[:, 1].max() bbox = (top_row, left_col, bottom_row, right_col) target_image = np.load(target_img_f) / 255 target_pose = self.poses[obj_idx, target_pose_idx] target_R = target_pose[:3, :3] R = source_R.T @ target_R return (source_image, torch.Tensor(R), torch.Tensor(target_image), bbox) ================================================ FILE: renderer.py ================================================ import logging import moderngl import numpy as np from PIL import Image, ImageOps from pyrr import Matrix44 from scipy.spatial.transform import Rotation YAW_PITCH_ROLL = {"yaw", "pitch", "roll"} AZIM_ELEV_IN_PLANE = {"azimuth", "elevation", "in_plane"} TOL = 1e-6 def gen_rotation_matrix_from_cam_pos(xyz, in_plane=0.0): assert 1 - np.linalg.norm(xyz) < TOL cam_from = xyz cam_to = np.zeros(3) tmp = np.array([0.0, 1.0, 0.0]) diff = cam_from - cam_to forward = diff / np.linalg.norm(diff) crossed = np.cross(tmp, forward) right = crossed / np.linalg.norm(crossed) up = np.cross(forward, right) R = np.stack([right, up, forward]) R_in_plane = Rotation.from_euler("Z", in_plane).as_matrix() return R_in_plane @ R def gen_rotation_matrix_from_azim_elev_in_plane( azimuth=0.0, elevation=0.0, in_plane=0.0 ): # See: https://www.scratchapixel.com/lessons/mathematics-physics-for-computer-graphics/lookat-function. y = np.sin(elevation) radius = np.cos(elevation) x = radius * np.sin(azimuth) z = radius * np.cos(azimuth) cam_from = np.array([x, y, z]) cam_to = np.zeros(3) tmp = np.array([0.0, 1.0, 0.0]) diff = cam_from - cam_to forward = diff / np.linalg.norm(diff) crossed = np.cross(tmp, forward) right = crossed / np.linalg.norm(crossed) up = np.cross(forward, right) R = np.stack([right, up, forward]) R_in_plane = Rotation.from_euler("Z", in_plane).as_matrix() return R_in_plane @ R def parse_obj_file(input_obj): """Parse wavefront .obj file. :param input_obj: :return: Dictionary of NumPy arrays with shape (3 * num_faces, 8). Each row contains (1) the coordinates of a vertex of a face, (2) the vertex's normal vector, and (3) the texture coordinates for the vertex. """ data = {"v": [], "vn": [], "vt": []} packed_arrays = {} obj_f = open(input_obj) current_mtl = None min_vec = np.full(3, np.inf) max_vec = np.full(3, -np.inf) empty_vt = np.array([0.0, 0.0, 0.0]) for line in obj_f: line = line.strip() if line == "": continue parts = line.split() elem_type = parts[0] if elem_type in data: vals = np.array(parts[1:4], dtype=np.float32) if elem_type == "v": min_vec = np.minimum(min_vec, vals) max_vec = np.maximum(max_vec, vals) elif elem_type == "vn": vals /= np.linalg.norm(vals) elif elem_type == "vt": if len(vals) < 3: vals = np.array(list(vals) + [0.0], dtype=np.float32) data[elem_type].append(vals) elif elem_type == "f": f = parts[1:4] for fv in f: (v, vt, vn) = fv.split("/") # Convert to zero-based indexing. v = int(v) - 1 vn = int(vn) - 1 vt = int(vt) - 1 if vt else -1 if vt == -1: row = np.concatenate((data["v"][v], data["vn"][vn], empty_vt)) else: row = np.concatenate((data["v"][v], data["vn"][vn], data["vt"][vt])) packed_arrays[current_mtl].append(row) elif elem_type == "usemtl": current_mtl = parts[1] if current_mtl not in packed_arrays: packed_arrays[current_mtl] = [] elif elem_type == "l": if current_mtl in packed_arrays: packed_arrays.pop(current_mtl) max_pos_vec = max_vec - min_vec max_pos_val = max(max_pos_vec) max_pos_vec_norm = max_pos_vec / max_pos_val for (sub_obj, packed_array) in packed_arrays.items(): # z-coordinate of texture is always zero (if present). packed_array = np.stack(packed_array)[:, :8] original_vertices = packed_array[:, :3].copy() # All coordinates greater than or equal to zero. original_vertices -= min_vec # All coordinates between zero and one. original_vertices /= max_pos_val # All coordinates between zero and two. original_vertices *= 2 # All coordinates between negative one and positive one with the center of object # at (0, 0, 0). original_vertices -= max_pos_vec_norm packed_array[:, :3] = original_vertices packed_arrays[sub_obj] = packed_array all_vertices = np.stack(data["v"]) all_vertices -= min_vec all_vertices /= max_pos_val all_vertices *= 2 all_vertices -= max_pos_vec_norm return (packed_arrays, all_vertices) def parse_mtl_file(input_mtl): vector_elems = {"Ka", "Kd", "Ks"} float_elems = {"Ns", "Ni", "d"} int_elems = {"illum"} current_mtl = None mtl_infos = {} mtl_f = open(input_mtl) sub_objs = [] for line in mtl_f: line = line.strip() if line == "": continue parts = line.split() elem_type = parts[0] if elem_type in vector_elems: vals = np.array(parts[1:4], dtype=np.float32) mtl_infos[current_mtl][elem_type] = tuple(vals) elif elem_type in float_elems: mtl_infos[current_mtl][elem_type] = float(parts[1]) elif elem_type in int_elems: mtl_infos[current_mtl][elem_type] = int(parts[1]) elif elem_type == "newmtl": current_mtl = parts[1] sub_objs.append(current_mtl) mtl_infos[current_mtl] = {"d": 1.0} elif elem_type == "map_Kd": mtl_infos[current_mtl]["map_Kd"] = parts[1] sub_objs.sort() sub_objs.reverse() non_trans = [sub_obj for sub_obj in sub_objs if mtl_infos[sub_obj]["d"] == 1.0] trans = [ (sub_obj, mtl_infos[sub_obj]["d"]) for sub_obj in sub_objs if mtl_infos[sub_obj]["d"] < 1.0 ] trans.sort(key=lambda x: x[1], reverse=True) sub_objs = non_trans + [sub_obj for (sub_obj, d) in trans] return (mtl_infos, sub_objs) def get_texture_data(sub_objs, packed_arrays, mtl_infos, obj_f): texture_data = {} texture_path_list = obj_f.split("/") img_str_len = len("images/") for sub_obj in sub_objs: if sub_obj not in packed_arrays: continue if "map_Kd" in mtl_infos[sub_obj]: texture_f = mtl_infos[sub_obj]["map_Kd"] img_str_idx = texture_f.find("images/") if img_str_idx != -1: texture_path = "/".join(texture_path_list[:-2] + ["images"]) texture_f = texture_f[img_str_idx + img_str_len :] else: texture_path = "/".join(texture_path_list[:-1]) try: texture_img = ( Image.open(texture_path + "/" + texture_f) .transpose(Image.FLIP_TOP_BOTTOM) .convert("RGBA") ) except FileNotFoundError: texture_f_parts = texture_f.split(".") ext = texture_f_parts[-1] if ext.isupper(): texture_f_parts[-1] = ext.lower() elif ext.islower(): texture_f_parts[-1] = ext.upper() texture_f = ".".join(texture_f_parts) texture_img = ( Image.open(texture_path + "/" + texture_f) .transpose(Image.FLIP_TOP_BOTTOM) .convert("RGBA") ) texture_data[sub_obj] = { "size": texture_img.size, "bytes": texture_img.tobytes(), } return texture_data class Renderer: def __init__( self, background_f=None, camera_distance=2.0, angle_of_view=16.426, dir_light=(0, 1 / np.sqrt(2), np.sqrt(2)), dif_int=0.7, amb_int=0.7, default_width=128, default_height=128, cull_faces=True, ): # Initialize OpenGL context. self.ctx = moderngl.create_standalone_context() # Render depth appropriately. self.ctx.enable(moderngl.DEPTH_TEST) # Setting for rendering transparent objects. # See: https://learnopengl.com/Advanced-OpenGL/Blending # and: https://github.com/cprogrammer1994/ModernGL/blob/master/moderngl/context.py#L129. self.ctx.enable(moderngl.BLEND) # Define OpenGL program. prog = self.ctx.program( vertex_shader=""" #version 330 uniform float x; uniform float y; uniform float z; uniform mat3 R_obj; uniform mat3 R_light; uniform vec3 DirLight; uniform mat4 VP; uniform int mode; in vec3 in_vert; in vec3 in_norm; in vec2 in_text; out vec3 v_pos; out vec3 v_norm; out vec2 v_text; out vec3 v_light; void main() { if (mode == 0) { v_pos = R_obj * in_vert + vec3(x, y, z); gl_Position = VP * vec4(v_pos, 1.0); v_norm = R_obj * in_norm; v_text = in_text; v_light = R_light * DirLight; } else { gl_Position = vec4(in_vert, 1.0); v_text = in_text; } } """, fragment_shader=""" #version 330 uniform float amb_int; uniform float dif_int; uniform vec3 cam_pos; uniform sampler2D Texture; uniform int mode; uniform bool use_texture; uniform bool has_image; uniform vec3 box_rgb; uniform vec3 amb_rgb; uniform vec3 dif_rgb; uniform vec3 spc_rgb; uniform float spec_exp; uniform float trans; in vec3 v_pos; in vec3 v_norm; in vec2 v_text; in vec3 v_light; out vec4 f_color; void main() { if (mode == 0) { float dif = clamp(dot(v_light, v_norm), 0.0, 1.0) * dif_int; if (use_texture) { vec3 surface_rgb = dif_rgb; vec3 diffuse = dif * surface_rgb; if (has_image) { surface_rgb = texture(Texture, v_text).rgb; diffuse = dif * dif_rgb * surface_rgb; } vec3 ambient = amb_int * amb_rgb * surface_rgb; float spec = 0.0; if (dif > 0.0) { vec3 reflected = reflect(-v_light, v_norm); vec3 surface_to_camera = normalize(cam_pos - v_pos); spec = pow(clamp(dot(surface_to_camera, reflected), 0.0, 1.0), spec_exp); } vec3 specular = spec * spc_rgb * surface_rgb; vec3 linear = ambient + diffuse + specular; f_color = vec4(linear, trans); } else { f_color = vec4(vec3(1.0, 1.0, 1.0) * dif + amb_int, 1.0); } } else if (mode == 1) { f_color = vec4(texture(Texture, v_text).rgba); } else { f_color = vec4(box_rgb, 1.0); } } """, ) # Lighting uniform variables. prog["R_light"].write(np.eye(3).astype("f4").tobytes()) dir_light = np.array(dir_light) prog["DirLight"].value = tuple(dir_light / np.linalg.norm(dir_light)) prog["dif_int"].value = dif_int prog["amb_int"].value = amb_int prog["amb_rgb"].value = (1.0, 1.0, 1.0) prog["dif_rgb"].value = (1.0, 1.0, 1.0) prog["spc_rgb"].value = (1.0, 1.0, 1.0) prog["spec_exp"].value = 0.0 self.use_spec = True # Mode uniform variables. prog["mode"].value = 0 prog["use_texture"].value = True prog["has_image"].value = False # Model transformation uniform variables. prog["R_obj"].write(np.eye(3).astype("f4").tobytes()) prog["x"].value = 0 prog["y"].value = 0 prog["z"].value = 0 # Set up background. self.prog = prog (self.default_width, self.default_height) = (default_width, default_height) self.background = None (window_width, window_height) = self.set_up_background(background_f) # Look at origin matrix. eye = np.array([0.0, 0.0, camera_distance]) prog["cam_pos"].value = tuple(eye) target = np.zeros(3) up = np.array([0.0, 1.0, 0.0]) self.look_at = Matrix44.look_at(eye, target, up) # Perspective projection matrix. self.ratio = window_width / window_height self.angle_of_view = angle_of_view self.perspective = Matrix44.perspective_projection( angle_of_view, self.ratio, 0.1, 1000.0 ) # View-Projection uniform variable. self.prog["VP"].write((self.look_at @ self.perspective).astype("f4").tobytes()) # Set up object. self.mtl_infos = None self.cull_faces = cull_faces self.render_objs = [] self.vbos = {} self.vaos = {} self.textures = {} # Initialize frame buffer. size = (window_width, window_height) self.window_size = size # Set up multisample anti-aliasing. self.ctx.multisample = True color_rbo = self.ctx.renderbuffer(size, samples=self.ctx.max_samples) depth_rbo = self.ctx.depth_renderbuffer(size, samples=self.ctx.max_samples) self.fbo = self.ctx.framebuffer(color_rbo, depth_rbo) color_rbo2 = self.ctx.renderbuffer(size) depth_rbo2 = self.ctx.depth_renderbuffer(size) self.fbo2 = self.ctx.framebuffer(color_rbo2, depth_rbo2) self.fbo.use() def set_up_obj(self, obj_f, mtl_f): (packed_arrays, vertices) = parse_obj_file(obj_f) packed_arrays = { sub_obj: packed_array.flatten().astype("f4").tobytes() for (sub_obj, packed_array) in packed_arrays.items() } (mtl_infos, sub_objs) = parse_mtl_file(mtl_f) texture_data = get_texture_data(sub_objs, packed_arrays, mtl_infos, obj_f) self.load_obj(packed_arrays, vertices, mtl_infos, sub_objs, texture_data) def load_obj(self, packed_arrays, vertices, mtl_infos, sub_objs, texture_data): self.hom_vertices = np.hstack([vertices, np.ones(len(vertices))[:, None]]) render_objs = [] vbos = {} vaos = {} textures = {} for sub_obj in sub_objs: if sub_obj not in packed_arrays: logging.info(f"Skipping {sub_obj}.") continue render_objs.append(sub_obj) packed_array = packed_arrays[sub_obj] vbo = self.ctx.buffer(packed_array) vbos[sub_obj] = vbo # Recall that "in_vert", "in_norm", and "in_text" are the inputs to the # vertex shader. vao = self.ctx.simple_vertex_array( self.prog, vbo, "in_vert", "in_norm", "in_text" ) vaos[sub_obj] = vao if "map_Kd" in mtl_infos[sub_obj]: # Initialize texture from image. texture = self.ctx.texture( texture_data[sub_obj]["size"], 4, texture_data[sub_obj]["bytes"] ) texture.build_mipmaps() textures[sub_obj] = texture self.mtl_infos = mtl_infos self.render_objs = render_objs self.vbos = vbos self.vaos = vaos self.textures = textures def set_up_background(self, background_f=None): if background_f: background_img = ( Image.open(background_f) .transpose(Image.FLIP_TOP_BOTTOM) .convert("RGBA") ) # Initialize background from image. background = self.ctx.texture( background_img.size, 4, background_img.tobytes() ) background.build_mipmaps() self.background = background # Create a square plane from two triangles (two sets of three points). vertices = np.array( [ [-1.0, -1.0, 0.0], [-1.0, 1.0, 0.0], [1.0, 1.0, 0.0], [-1.0, -1.0, 0.0], [1.0, -1.0, 0.0], [1.0, 1.0, 0.0], ] ) # These arrays are not used by the renderer, but the vertex shader expects # them as input. normals = np.repeat([[0.0, 0.0, 1.0]], len(vertices), axis=0) # The texture (UV) coordinates corresponding to the above triangle points. texture_coords = np.array( [[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0], [1.0, 0.0], [1.0, 1.0]] ) background_array = np.hstack((vertices, normals, texture_coords)) self.background_vbo = self.ctx.buffer( background_array.flatten().astype("f4").tobytes() ) self.background_vao = self.ctx.simple_vertex_array( self.prog, self.background_vbo, "in_vert", "in_norm", "in_text" ) return (background_img.width, background_img.height) else: return (self.default_width, self.default_height) def render(self, r=0.485, g=0.456, b=0.406, with_alpha=False): if self.background is not None: # See: https://computergraphics.stackexchange.com/a/4007. self.ctx.disable(moderngl.DEPTH_TEST) self.prog["mode"].value = 1 self.background.use() self.fbo.clear() self.background_vao.render() self.ctx.enable(moderngl.DEPTH_TEST) self.prog["mode"].value = 0 else: self.fbo.clear(r, g, b) if self.cull_faces: self.ctx.enable(moderngl.CULL_FACE) for render_obj in self.render_objs: if self.prog["use_texture"].value: self.prog["amb_rgb"].value = self.mtl_infos[render_obj]["Ka"] self.prog["dif_rgb"].value = self.mtl_infos[render_obj]["Kd"] if self.use_spec: self.prog["spc_rgb"].value = self.mtl_infos[render_obj]["Ks"] self.prog["spec_exp"].value = self.mtl_infos[render_obj]["Ns"] else: self.prog["spc_rgb"].value = (0.0, 0.0, 0.0) self.prog["trans"].value = self.mtl_infos[render_obj]["d"] if render_obj in self.textures: self.prog["has_image"].value = True self.textures[render_obj].use() self.vaos[render_obj].render() self.prog["has_image"].value = False self.ctx.disable(moderngl.CULL_FACE) self.ctx.copy_framebuffer(self.fbo2, self.fbo) if with_alpha: return Image.frombytes( "RGBA", self.fbo.size, self.fbo2.read(components=4), "raw", "RGBA", 0, -1, ) else: return Image.frombytes( "RGB", self.fbo.size, self.fbo2.read(), "raw", "RGB", 0, -1 ) def get_vertex_screen_coordinates(self): world = np.eye(4) world[:3, :3] = np.array(self.prog["R_obj"].value).reshape((3, 3)).T world[:3, 3] = ( self.prog["x"].value, self.prog["y"].value, self.prog["z"].value, ) PV = np.array(self.prog["VP"].value).reshape((4, 4)).T pre_screen_coords = PV @ world @ self.hom_vertices.T (window_width, window_height) = self.window_size screen_xs = ( window_width * (np.array(pre_screen_coords[0]) / np.array(pre_screen_coords[3]) + 1) / 2 ) screen_ys = ( window_height * (np.array(pre_screen_coords[1]) / np.array(pre_screen_coords[3]) + 1) / 2 ) screen_coords = np.hstack((screen_xs, screen_ys)) screen = np.zeros((window_height, window_width)) for i in range(len(screen_xs)): col = x = int(screen_xs[i]) row = y = int(screen_ys[i]) if x < window_width and y < window_height: screen[window_height - row - 1, col] = 1 screen_mat = np.uint8(255 * screen) screen_img = Image.fromarray(screen_mat, mode="L") return (screen_coords, screen_img) def __del__(self): self.release() def release_obj(self): for sub_obj in self.vbos: self.vbos[sub_obj].release() self.vaos[sub_obj].release() if sub_obj in self.textures: self.textures[sub_obj].release() self.vbos = {} self.vaos = {} self.textures = {} def release_background(self): if self.background is not None: self.background.release() self.background_vbo.release() self.background_vao.release() self.background = None def release(self): self.release_obj() self.release_background() self.fbo.release() self.fbo2.release() self.ctx.release() def adjust_angle_of_view(self, angle_of_view): self.angle_of_view = angle_of_view perspective = Matrix44.perspective_projection( self.angle_of_view, self.ratio, 0.1, 1000.0 ) self.prog["VP"].write((perspective * self.look_at).astype("f4").tobytes()) def set_params(self, params): ypr_params = {} ae_params = {} for (param, value) in params.items(): if param in self.prog: self.prog[param].value = value elif param == "aov": self.adjust_angle_of_view(value) elif param in YAW_PITCH_ROLL: ypr_params[param] = value elif param in AZIM_ELEV_IN_PLANE: ae_params[param] = value if len(ypr_params) > 0: yaw = ypr_params.get("yaw", 0) pitch = ypr_params.get("pitch", 0) roll = ypr_params.get("roll", 0) R_obj = Rotation.from_euler("YXZ", [yaw, pitch, roll]).as_matrix() self.prog["R_obj"].write(R_obj.T.astype("f4").tobytes()) elif len(ae_params) > 0: R_obj = gen_rotation_matrix_from_azim_elev_in_plane(**ae_params) self.prog["R_obj"].write(R_obj.T.astype("f4").tobytes()) def get_depth_arrays(self): depth = np.frombuffer( self.fbo2.read(attachment=-1, dtype="f4"), dtype=np.dtype("f4") ) depth = 1 - depth.reshape(self.window_size) min_pos = depth[depth > 0].min() depth[depth > 0] = depth[depth > 0] - min_pos depth_normed = depth / depth.max() return (depth, depth_normed) def get_depth_map(self): (depth, depth_normed) = self.get_depth_arrays() depth_map = np.uint8(255 * depth_normed) return ImageOps.flip(Image.fromarray(depth_map, "L")) def get_normal_map(self): # See: https://stackoverflow.com/questions/5281261/generating-a-normal-map-from-a-height-map # and: https://stackoverflow.com/questions/34644101/calculate-surface-normals-from-depth-image-using-neighboring-pixels-cross-produc # and: https://en.wikipedia.org/wiki/Normal_mapping#How_it_works. (depth, depth_normed) = self.get_depth_arrays() depth_pad = np.pad(depth_normed, 1, "constant") (dx, dy) = (1 / depth.shape[1], 1 / depth.shape[0]) dz_dx = (depth_pad[1:-1, 2:] - depth_pad[1:-1, :-2]) / (2 * dx) dz_dy = (depth_pad[2:, 1:-1] - depth_pad[:-2, 1:-1]) / (2 * dy) norms = np.stack([-dz_dx.flatten(), -dz_dy.flatten(), np.ones(dz_dx.size)]) magnitudes = np.linalg.norm(norms, axis=0) norms /= magnitudes norms = norms.T norms[:, :2] = 255 * (norms[:, :2] + 1) / 2 norms[:, 2] = 127 * norms[:, 2] + 128 norms = np.uint8(norms).reshape((*depth.shape, 3)) return ImageOps.flip(Image.fromarray(norms)) ================================================ FILE: renderer_settings.py ================================================ WINDOW_SIZE = 256 IMG_SIZE = 128 CULL_FACES = True CAMERA_DISTANCE = 2.25 # See: https://en.wikipedia.org/wiki/Angle_of_view#Common_lens_angles_of_view. ANGLE_OF_VIEW = 53.962828459664856 # Lighting. DIR_LIGHT = (0, 1 / (2 ** 0.5), 2 ** 0.5) DIF_INT = 0.7 AMB_INT = 0.7 ================================================ FILE: run_nerf.py ================================================ import matplotlib.pyplot as plt import numpy as np import torch from torch import nn, optim def get_coarse_query_points(ds, N_c, t_i_c_bin_edges, t_i_c_gap, os): # Sample depths (t_is_c). See Equation (2) in Section 4. u_is_c = torch.rand(*list(ds.shape[:2]) + [N_c]).to(ds) t_is_c = t_i_c_bin_edges + u_is_c * t_i_c_gap # Calculate the points along the rays (r_ts_c) using the ray origins (os), sampled # depths (t_is_c), and ray directions (ds). See Section 4: r(t) = o + t * d. r_ts_c = os[..., None, :] + t_is_c[..., :, None] * ds[..., None, :] return (r_ts_c, t_is_c) def get_fine_query_points(w_is_c, N_f, t_is_c, t_f, os, ds): # See text surrounding Equation (5) in Section 5.2 and: # https://stephens999.github.io/fiveMinuteStats/inverse_transform_sampling.html#discrete_distributions. # Define PDFs (pdfs) and CDFs (cdfs) from weights (w_is_c). w_is_c = w_is_c + 1e-5 pdfs = w_is_c / torch.sum(w_is_c, dim=-1, keepdim=True) cdfs = torch.cumsum(pdfs, dim=-1) cdfs = torch.cat([torch.zeros_like(cdfs[..., :1]), cdfs[..., :-1]], dim=-1) # Get uniform samples (us). us = torch.rand(list(cdfs.shape[:-1]) + [N_f]).to(w_is_c) # Use inverse transform sampling to sample the depths (t_is_f). idxs = torch.searchsorted(cdfs, us, right=True) t_i_f_bottom_edges = torch.gather(t_is_c, 2, idxs - 1) idxs_capped = idxs.clone() max_ind = cdfs.shape[-1] idxs_capped[idxs_capped == max_ind] = max_ind - 1 t_i_f_top_edges = torch.gather(t_is_c, 2, idxs_capped) t_i_f_top_edges[idxs == max_ind] = t_f t_i_f_gaps = t_i_f_top_edges - t_i_f_bottom_edges u_is_f = torch.rand_like(t_i_f_gaps).to(os) t_is_f = t_i_f_bottom_edges + u_is_f * t_i_f_gaps # Combine the coarse (t_is_c) and fine (t_is_f) depths and sort them. (t_is_f, _) = torch.sort(torch.cat([t_is_c, t_is_f.detach()], dim=-1), dim=-1) # Calculate the points along the rays (r_ts_f) using the ray origins (os), depths # (t_is_f), and ray directions (ds). See Section 4: r(t) = o + t * d. r_ts_f = os[..., None, :] + t_is_f[..., :, None] * ds[..., None, :] return (r_ts_f, t_is_f) def render_radiance_volume(r_ts, ds, chunk_size, F, t_is): # Use the network (F) to predict colors (c_is) and volume densities (sigma_is) for # 3D points along rays (r_ts) given the viewing directions (ds) of the rays. See # Section 3 and Figure 7 in the Supplementary Materials. r_ts_flat = r_ts.reshape((-1, 3)) ds_rep = ds.unsqueeze(2).repeat(1, 1, r_ts.shape[-2], 1) ds_flat = ds_rep.reshape((-1, 3)) c_is = [] sigma_is = [] # The network processes batches of inputs to avoid running out of memory. for chunk_start in range(0, r_ts_flat.shape[0], chunk_size): r_ts_batch = r_ts_flat[chunk_start : chunk_start + chunk_size] ds_batch = ds_flat[chunk_start : chunk_start + chunk_size] preds = F(r_ts_batch, ds_batch) c_is.append(preds["c_is"]) sigma_is.append(preds["sigma_is"]) c_is = torch.cat(c_is).reshape(r_ts.shape) sigma_is = torch.cat(sigma_is).reshape(r_ts.shape[:-1]) # Calculate the distances (delta_is) between points along the rays. The differences # in depths are scaled by the norms of the ray directions to get the final # distances. See text following Equation (3) in Section 4. delta_is = t_is[..., 1:] - t_is[..., :-1] # "Infinity". Guarantees last alpha is always one. one_e_10 = torch.Tensor([1e10]).expand(delta_is[..., :1].shape) delta_is = torch.cat([delta_is, one_e_10.to(delta_is)], dim=-1) delta_is = delta_is * ds.norm(dim=-1).unsqueeze(-1) # Calculate the alphas (alpha_is) of the 3D points using the volume densities # (sigma_is) and distances between points (delta_is). See text following Equation # (3) in Section 4 and https://en.wikipedia.org/wiki/Alpha_compositing. alpha_is = 1.0 - torch.exp(-sigma_is * delta_is) # Calculate the accumulated transmittances (T_is) along the rays from the alphas # (alpha_is). See Equation (3) in Section 4. T_i is "the probability that the ray # travels from t_n to t_i without hitting any other particle". T_is = torch.cumprod(1.0 - alpha_is + 1e-10, -1) # Guarantees the ray makes it at least to the first step. See: # https://github.com/bmild/nerf/blob/18b8aebda6700ed659cb27a0c348b737a5f6ab60/run_nerf.py#L142, # which uses tf.math.cumprod(1.-alpha + 1e-10, axis=-1, exclusive=True). T_is = torch.roll(T_is, 1, -1) T_is[..., 0] = 1.0 # Calculate the weights (w_is) for the colors (c_is) along the rays using the # transmittances (T_is) and alphas (alpha_is). See Equation (5) in Section 5.2: # w_i = T_i * (1 - exp(-sigma_i * delta_i)). w_is = T_is * alpha_is # Calculate the pixel colors (C_rs) for the rays as weighted (w_is) sums of colors # (c_is). See Equation (5) in Section 5.2: C_c_hat(r) = Σ w_i * c_i. C_rs = (w_is[..., None] * c_is).sum(dim=-2) return (C_rs, w_is) def run_one_iter_of_nerf( ds, N_c, t_i_c_bin_edges, t_i_c_gap, os, chunk_size, F_c, N_f, t_f, F_f ): (r_ts_c, t_is_c) = get_coarse_query_points(ds, N_c, t_i_c_bin_edges, t_i_c_gap, os) (C_rs_c, w_is_c) = render_radiance_volume(r_ts_c, ds, chunk_size, F_c, t_is_c) (r_ts_f, t_is_f) = get_fine_query_points(w_is_c, N_f, t_is_c, t_f, os, ds) (C_rs_f, _) = render_radiance_volume(r_ts_f, ds, chunk_size, F_f, t_is_f) return (C_rs_c, C_rs_f) class NeRFMLP(nn.Module): def __init__(self): super().__init__() # Number of encoding functions for positions. See Section 5.1. self.L_pos = 10 # Number of encoding functions for viewing directions. See Section 5.1. self.L_dir = 4 pos_enc_feats = 3 + 3 * 2 * self.L_pos dir_enc_feats = 3 + 3 * 2 * self.L_dir in_feats = pos_enc_feats net_width = 256 early_mlp_layers = 5 early_mlp = [] for layer_idx in range(early_mlp_layers): early_mlp.append(nn.Linear(in_feats, net_width)) early_mlp.append(nn.ReLU()) in_feats = net_width self.early_mlp = nn.Sequential(*early_mlp) in_feats = pos_enc_feats + net_width late_mlp_layers = 3 late_mlp = [] for layer_idx in range(late_mlp_layers): late_mlp.append(nn.Linear(in_feats, net_width)) late_mlp.append(nn.ReLU()) in_feats = net_width self.late_mlp = nn.Sequential(*late_mlp) self.sigma_layer = nn.Linear(net_width, net_width + 1) self.pre_final_layer = nn.Sequential( nn.Linear(dir_enc_feats + net_width, net_width // 2), nn.ReLU() ) self.final_layer = nn.Sequential(nn.Linear(net_width // 2, 3), nn.Sigmoid()) def forward(self, xs, ds): # Encode the inputs. See Equation (4) in Section 5.1. xs_encoded = [xs] for l_pos in range(self.L_pos): xs_encoded.append(torch.sin(2**l_pos * torch.pi * xs)) xs_encoded.append(torch.cos(2**l_pos * torch.pi * xs)) xs_encoded = torch.cat(xs_encoded, dim=-1) ds = ds / ds.norm(p=2, dim=-1).unsqueeze(-1) ds_encoded = [ds] for l_dir in range(self.L_dir): ds_encoded.append(torch.sin(2**l_dir * torch.pi * ds)) ds_encoded.append(torch.cos(2**l_dir * torch.pi * ds)) ds_encoded = torch.cat(ds_encoded, dim=-1) # Use the network to predict colors (c_is) and volume densities (sigma_is) for # 3D points (xs) along rays given the viewing directions (ds) of the rays. See # Section 3 and Figure 7 in the Supplementary Materials. outputs = self.early_mlp(xs_encoded) outputs = self.late_mlp(torch.cat([xs_encoded, outputs], dim=-1)) outputs = self.sigma_layer(outputs) sigma_is = torch.relu(outputs[:, 0]) outputs = self.pre_final_layer(torch.cat([ds_encoded, outputs[:, 1:]], dim=-1)) c_is = self.final_layer(outputs) return {"c_is": c_is, "sigma_is": sigma_is} def main(): # Set seed. seed = 9458 torch.manual_seed(seed) np.random.seed(seed) # Initialize coarse and fine MLPs. device = "cuda:0" F_c = NeRFMLP().to(device) F_f = NeRFMLP().to(device) # Number of query points passed through the MLP at a time. See: https://github.com/bmild/nerf/blob/18b8aebda6700ed659cb27a0c348b737a5f6ab60/run_nerf.py#L488. chunk_size = 1024 * 32 # Number of training rays per iteration. See Section 5.3. batch_img_size = 64 n_batch_pix = batch_img_size**2 # Initialize optimizer. See Section 5.3. lr = 5e-4 optimizer = optim.Adam(list(F_c.parameters()) + list(F_f.parameters()), lr=lr) criterion = nn.MSELoss() # The learning rate decays exponentially. See Section 5.3 # See: https://github.com/bmild/nerf/blob/18b8aebda6700ed659cb27a0c348b737a5f6ab60/run_nerf.py#L486. lrate_decay = 250 decay_steps = lrate_decay * 1000 # See: https://github.com/bmild/nerf/blob/18b8aebda6700ed659cb27a0c348b737a5f6ab60/run_nerf.py#L707. decay_rate = 0.1 # Load dataset. data_f = "66bdbc812bd0a196e194052f3f12cb2e.npz" data = np.load(data_f) # Set up initial ray origin (init_o) and ray directions (init_ds). These are the # same across samples, we just rotate them based on the orientation of the camera. # See Section 4. images = data["images"] / 255 img_size = images.shape[1] xs = torch.arange(img_size) - (img_size / 2 - 0.5) ys = torch.arange(img_size) - (img_size / 2 - 0.5) (xs, ys) = torch.meshgrid(xs, -ys, indexing="xy") focal = float(data["focal"]) pixel_coords = torch.stack([xs, ys, torch.full_like(xs, -focal)], dim=-1) # We want the zs to be negative ones, so we divide everything by the focal length # (which is in pixel units). camera_coords = pixel_coords / focal init_ds = camera_coords.to(device) init_o = torch.Tensor(np.array([0, 0, float(data["camera_distance"])])).to(device) # Set up test view. test_idx = 150 plt.imshow(images[test_idx]) plt.show() test_img = torch.Tensor(images[test_idx]).to(device) poses = data["poses"] test_R = torch.Tensor(poses[test_idx, :3, :3]).to(device) test_ds = torch.einsum("ij,hwj->hwi", test_R, init_ds) test_os = (test_R @ init_o).expand(test_ds.shape) # Initialize volume rendering hyperparameters. # Near bound. See Section 4. t_n = 1.0 # Far bound. See Section 4. t_f = 4.0 # Number of coarse samples along a ray. See Section 5.3. N_c = 64 # Number of fine samples along a ray. See Section 5.3. N_f = 128 # Bins used to sample depths along a ray. See Equation (2) in Section 4. t_i_c_gap = (t_f - t_n) / N_c t_i_c_bin_edges = (t_n + torch.arange(N_c) * t_i_c_gap).to(device) # Start training model. train_idxs = np.arange(len(images)) != test_idx images = torch.Tensor(images[train_idxs]) poses = torch.Tensor(poses[train_idxs]) n_pix = img_size**2 pixel_ps = torch.full((n_pix,), 1 / n_pix).to(device) psnrs = [] iternums = [] # See Section 5.3. num_iters = 300000 display_every = 100 F_c.train() F_f.train() for i in range(num_iters): # Sample image and associated pose. target_img_idx = np.random.randint(images.shape[0]) target_pose = poses[target_img_idx].to(device) R = target_pose[:3, :3] # Get rotated ray origins (os) and ray directions (ds). See Section 4. ds = torch.einsum("ij,hwj->hwi", R, init_ds) os = (R @ init_o).expand(ds.shape) # Sample a batch of rays. pix_idxs = pixel_ps.multinomial(n_batch_pix, False) pix_idx_rows = pix_idxs // img_size pix_idx_cols = pix_idxs % img_size ds_batch = ds[pix_idx_rows, pix_idx_cols].reshape( batch_img_size, batch_img_size, -1 ) os_batch = os[pix_idx_rows, pix_idx_cols].reshape( batch_img_size, batch_img_size, -1 ) # Run NeRF. (C_rs_c, C_rs_f) = run_one_iter_of_nerf( ds_batch, N_c, t_i_c_bin_edges, t_i_c_gap, os_batch, chunk_size, F_c, N_f, t_f, F_f, ) target_img = images[target_img_idx].to(device) target_img_batch = target_img[pix_idx_rows, pix_idx_cols].reshape(C_rs_f.shape) # Calculate the mean squared error for both the coarse and fine MLP models and # update the weights. See Equation (6) in Section 5.3. loss = criterion(C_rs_c, target_img_batch) + criterion(C_rs_f, target_img_batch) optimizer.zero_grad() loss.backward() optimizer.step() # Exponentially decay learning rate. See Section 5.3 and: # https://keras.io/api/optimizers/learning_rate_schedules/exponential_decay/. for g in optimizer.param_groups: g["lr"] = lr * decay_rate ** (i / decay_steps) if i % display_every == 0: F_c.eval() F_f.eval() with torch.no_grad(): (_, C_rs_f) = run_one_iter_of_nerf( test_ds, N_c, t_i_c_bin_edges, t_i_c_gap, test_os, chunk_size, F_c, N_f, t_f, F_f, ) loss = criterion(C_rs_f, test_img) print(f"Loss: {loss.item()}") psnr = -10.0 * torch.log10(loss) psnrs.append(psnr.item()) iternums.append(i) plt.figure(figsize=(10, 4)) plt.subplot(121) plt.imshow(C_rs_f.detach().cpu().numpy()) plt.title(f"Iteration {i}") plt.subplot(122) plt.plot(iternums, psnrs) plt.title("PSNR") plt.show() F_c.train() F_f.train() print("Done!") if __name__ == "__main__": main() ================================================ FILE: run_nerf_alt.py ================================================ import matplotlib.pyplot as plt import numpy as np import torch from torch import nn, optim class NeRFMLP(nn.Module): def __init__(self): super().__init__() # Number of encoding functions for positions. See Section 5.1. self.L_pos = 10 # Number of encoding functions for viewing directions. See Section 5.1. self.L_dir = 4 pos_enc_feats = 3 + 3 * 2 * self.L_pos dir_enc_feats = 3 + 3 * 2 * self.L_dir in_feats = pos_enc_feats net_width = 256 early_mlp_layers = 5 early_mlp = [] for layer_idx in range(early_mlp_layers): early_mlp.append(nn.Linear(in_feats, net_width)) early_mlp.append(nn.ReLU()) in_feats = net_width self.early_mlp = nn.Sequential(*early_mlp) in_feats = pos_enc_feats + net_width late_mlp_layers = 3 late_mlp = [] for layer_idx in range(late_mlp_layers): late_mlp.append(nn.Linear(in_feats, net_width)) late_mlp.append(nn.ReLU()) in_feats = net_width self.late_mlp = nn.Sequential(*late_mlp) self.sigma_layer = nn.Linear(net_width, net_width + 1) self.pre_final_layer = nn.Sequential( nn.Linear(dir_enc_feats + net_width, net_width // 2), nn.ReLU() ) self.final_layer = nn.Sequential(nn.Linear(net_width // 2, 3), nn.Sigmoid()) def forward(self, xs, ds): # Encode the inputs. See Equation (4) in Section 5.1. xs_encoded = [xs] for l_pos in range(self.L_pos): xs_encoded.append(torch.sin(2 ** l_pos * torch.pi * xs)) xs_encoded.append(torch.cos(2 ** l_pos * torch.pi * xs)) xs_encoded = torch.cat(xs_encoded, dim=-1) ds = ds / ds.norm(p=2, dim=-1).unsqueeze(-1) ds_encoded = [ds] for l_dir in range(self.L_dir): ds_encoded.append(torch.sin(2 ** l_dir * torch.pi * ds)) ds_encoded.append(torch.cos(2 ** l_dir * torch.pi * ds)) ds_encoded = torch.cat(ds_encoded, dim=-1) # Use the network to predict colors (c_is) and volume densities (sigma_is) for # 3D points (xs) along rays given the viewing directions (ds) of the rays. See # Section 3 and Figure 7 in the Supplementary Materials. outputs = self.early_mlp(xs_encoded) outputs = self.late_mlp(torch.cat([xs_encoded, outputs], dim=-1)) outputs = self.sigma_layer(outputs) sigma_is = torch.relu(outputs[:, 0]) outputs = self.pre_final_layer(torch.cat([ds_encoded, outputs[:, 1:]], dim=-1)) c_is = self.final_layer(outputs) return {"c_is": c_is, "sigma_is": sigma_is} class NeRF: def __init__(self, device): # Initialize coarse and fine MLPs. self.F_c = NeRFMLP().to(device) self.F_f = NeRFMLP().to(device) # Number of query points passed through the MLPs at a time. See: # https://github.com/bmild/nerf/blob/18b8aebda6700ed659cb27a0c348b737a5f6ab60/run_nerf.py#L488. self.chunk_size = 1024 * 32 # Initialize volume rendering hyperparameters. # Near bound. See Section 4. self.t_n = t_n = 1.0 # Far bound. See Section 4. self.t_f = t_f = 4.0 # Number of coarse samples along a ray. See Section 5.3. self.N_c = N_c = 64 # Number of fine samples along a ray. See Section 5.3. self.N_f = 128 # Bins used to sample depths along a ray. See Equation (2) in Section 4. self.t_i_c_gap = t_i_c_gap = (t_f - t_n) / N_c self.t_i_c_bin_edges = (t_n + torch.arange(N_c) * t_i_c_gap).to(device) def get_coarse_query_points(self, ds, os): # Sample depths (t_is_c). See Equation (2) in Section 4. u_is_c = torch.rand(*list(ds.shape[:2]) + [self.N_c]).to(ds) t_is_c = self.t_i_c_bin_edges + u_is_c * self.t_i_c_gap # Calculate the points along the rays (r_ts_c) using the ray origins (os), # sampled depths (t_is_c), and ray directions (ds). See Section 4: # r(t) = o + t * d. r_ts_c = os[..., None, :] + t_is_c[..., :, None] * ds[..., None, :] return (r_ts_c, t_is_c) def get_fine_query_points(self, w_is_c, t_is_c, os, ds): # See text surrounding Equation (5) in Section 5.2 and: # https://stephens999.github.io/fiveMinuteStats/inverse_transform_sampling.html#discrete_distributions. # Define PDFs (pdfs) and CDFs (cdfs) from weights (w_is_c). w_is_c = w_is_c + 1e-5 pdfs = w_is_c / torch.sum(w_is_c, dim=-1, keepdim=True) cdfs = torch.cumsum(pdfs, dim=-1) cdfs = torch.cat([torch.zeros_like(cdfs[..., :1]), cdfs[..., :-1]], dim=-1) # Get uniform samples (us). us = torch.rand(list(cdfs.shape[:-1]) + [self.N_f]).to(w_is_c) # Use inverse transform sampling to sample the depths (t_is_f). idxs = torch.searchsorted(cdfs, us, right=True) t_i_f_bottom_edges = torch.gather(t_is_c, 2, idxs - 1) idxs_capped = idxs.clone() max_ind = cdfs.shape[-1] idxs_capped[idxs_capped == max_ind] = max_ind - 1 t_i_f_top_edges = torch.gather(t_is_c, 2, idxs_capped) t_i_f_top_edges[idxs == max_ind] = self.t_f t_i_f_gaps = t_i_f_top_edges - t_i_f_bottom_edges u_is_f = torch.rand_like(t_i_f_gaps).to(os) t_is_f = t_i_f_bottom_edges + u_is_f * t_i_f_gaps # Combine the coarse (t_is_c) and fine (t_is_f) depths and sort them. (t_is_f, _) = torch.sort(torch.cat([t_is_c, t_is_f.detach()], dim=-1), dim=-1) # Calculate the points along the rays (r_ts_f) using the ray origins (os), # depths (t_is_f), and ray directions (ds). See Section 4: r(t) = o + t * d. r_ts_f = os[..., None, :] + t_is_f[..., :, None] * ds[..., None, :] return (r_ts_f, t_is_f) def render_radiance_volume(self, r_ts, ds, F, t_is): # Use the network (F) to predict colors (c_is) and volume densities (sigma_is) # for 3D points along rays (r_ts) given the viewing directions (ds) of the rays. # See Section 3 and Figure 7 in the Supplementary Materials. r_ts_flat = r_ts.reshape((-1, 3)) ds_rep = ds.unsqueeze(2).repeat(1, 1, r_ts.shape[-2], 1) ds_flat = ds_rep.reshape((-1, 3)) c_is = [] sigma_is = [] # The network processes batches of inputs to avoid running out of memory. for chunk_start in range(0, r_ts_flat.shape[0], self.chunk_size): r_ts_batch = r_ts_flat[chunk_start : chunk_start + self.chunk_size] ds_batch = ds_flat[chunk_start : chunk_start + self.chunk_size] preds = F(r_ts_batch, ds_batch) c_is.append(preds["c_is"]) sigma_is.append(preds["sigma_is"]) c_is = torch.cat(c_is).reshape(r_ts.shape) sigma_is = torch.cat(sigma_is).reshape(r_ts.shape[:-1]) # Calculate the distances (delta_is) between points along the rays. The # differences in depths are scaled by the norms of the ray directions to get the # final distances. See text following Equation (3) in Section 4. delta_is = t_is[..., 1:] - t_is[..., :-1] # "Infinity". Guarantees last alpha is always one. one_e_10 = torch.Tensor([1e10]).expand(delta_is[..., :1].shape) delta_is = torch.cat([delta_is, one_e_10.to(delta_is)], dim=-1) delta_is = delta_is * ds.norm(dim=-1).unsqueeze(-1) # Calculate the alphas (alpha_is) of the 3D points using the volume densities # (sigma_is) and distances between points (delta_is). See text following # Equation (3) in Section 4 and https://en.wikipedia.org/wiki/Alpha_compositing. alpha_is = 1.0 - torch.exp(-sigma_is * delta_is) # Calculate the accumulated transmittances (T_is) along the rays from the alphas # (alpha_is). See Equation (3) in Section 4. T_i is "the probability that the # ray travels from t_n to t_i without hitting any other particle". T_is = torch.cumprod(1.0 - alpha_is + 1e-10, -1) # Guarantees the ray makes it at least to the first step. See: # https://github.com/bmild/nerf/blob/18b8aebda6700ed659cb27a0c348b737a5f6ab60/run_nerf.py#L142, # which uses tf.math.cumprod(1.-alpha + 1e-10, axis=-1, exclusive=True). T_is = torch.roll(T_is, 1, -1) T_is[..., 0] = 1.0 # Calculate the weights (w_is) for the colors (c_is) along the rays using the # transmittances (T_is) and alphas (alpha_is). See Equation (5) in Section 5.2: # w_i = T_i * (1 - exp(-sigma_i * delta_i)). w_is = T_is * alpha_is # Calculate the pixel colors (C_rs) for the rays as weighted (w_is) sums of # colors (c_is). See Equation (5) in Section 5.2: C_c_hat(r) = Σ w_i * c_i. C_rs = (w_is[..., None] * c_is).sum(dim=-2) return (C_rs, w_is) def __call__(self, ds, os): (r_ts_c, t_is_c) = self.get_coarse_query_points(ds, os) (C_rs_c, w_is_c) = self.render_radiance_volume(r_ts_c, ds, self.F_c, t_is_c) (r_ts_f, t_is_f) = self.get_fine_query_points(w_is_c, t_is_c, os, ds) (C_rs_f, _) = self.render_radiance_volume(r_ts_f, ds, self.F_f, t_is_f) return (C_rs_c, C_rs_f) def load_data(device): data_f = "66bdbc812bd0a196e194052f3f12cb2e.npz" data = np.load(data_f) # Set up initial ray origin (init_o) and ray directions (init_ds). These are the # same across samples, we just rotate them based on the orientation of the camera. # See Section 4. images = data["images"] / 255 img_size = images.shape[1] xs = torch.arange(img_size) - (img_size / 2 - 0.5) ys = torch.arange(img_size) - (img_size / 2 - 0.5) (xs, ys) = torch.meshgrid(xs, -ys, indexing="xy") focal = float(data["focal"]) pixel_coords = torch.stack([xs, ys, torch.full_like(xs, -focal)], dim=-1) # We want the zs to be negative ones, so we divide everything by the focal length # (which is in pixel units). camera_coords = pixel_coords / focal init_ds = camera_coords.to(device) init_o = torch.Tensor(np.array([0, 0, float(data["camera_distance"])])).to(device) return (images, data["poses"], init_ds, init_o, img_size) def set_up_test_data(images, device, poses, init_ds, init_o): # Set up test view. test_idx = 150 plt.imshow(images[test_idx]) plt.show() test_img = torch.Tensor(images[test_idx]).to(device) test_R = torch.Tensor(poses[test_idx, :3, :3]).to(device) test_ds = torch.einsum("ij,hwj->hwi", test_R, init_ds) test_os = (test_R @ init_o).expand(test_ds.shape) train_idxs = np.arange(len(images)) != test_idx return (test_ds, test_os, test_img, train_idxs) def main(): # Set seed. seed = 9458 torch.manual_seed(seed) np.random.seed(seed) # Initialize NeRF. device = "cuda:0" nerf = NeRF(device) # Number of training rays per iteration. See Section 5.3. batch_img_size = 64 n_batch_pix = batch_img_size ** 2 # Initialize optimizer. See Section 5.3. lr = 5e-4 train_params = list(nerf.F_c.parameters()) + list(nerf.F_f.parameters()) optimizer = optim.Adam(train_params, lr=lr) criterion = nn.MSELoss() # The learning rate decays exponentially. See Section 5.3 # See: https://github.com/bmild/nerf/blob/18b8aebda6700ed659cb27a0c348b737a5f6ab60/run_nerf.py#L486. lrate_decay = 250 decay_steps = lrate_decay * 1000 # See: https://github.com/bmild/nerf/blob/18b8aebda6700ed659cb27a0c348b737a5f6ab60/run_nerf.py#L707. decay_rate = 0.1 # Load dataset. (images, poses, init_ds, init_o, img_size) = load_data(device) (test_ds, test_os, test_img, train_idxs) = set_up_test_data( images, device, poses, init_ds, init_o ) images = torch.Tensor(images[train_idxs]) poses = torch.Tensor(poses[train_idxs]) n_pix = img_size ** 2 pixel_ps = torch.full((n_pix,), 1 / n_pix).to(device) # Start training model. psnrs = [] iternums = [] # See Section 5.3. num_iters = 300000 display_every = 100 nerf.F_c.train() nerf.F_f.train() for i in range(num_iters): # Sample image and associated pose. target_img_idx = np.random.randint(images.shape[0]) target_pose = poses[target_img_idx].to(device) R = target_pose[:3, :3] # Get rotated ray origins (os) and ray directions (ds). See Section 4. ds = torch.einsum("ij,hwj->hwi", R, init_ds) os = (R @ init_o).expand(ds.shape) # Sample a batch of rays. pix_idxs = pixel_ps.multinomial(n_batch_pix, False) pix_idx_rows = pix_idxs // img_size pix_idx_cols = pix_idxs % img_size ds_batch = ds[pix_idx_rows, pix_idx_cols].reshape( batch_img_size, batch_img_size, -1 ) os_batch = os[pix_idx_rows, pix_idx_cols].reshape( batch_img_size, batch_img_size, -1 ) # Run NeRF. (C_rs_c, C_rs_f) = nerf(ds_batch, os_batch) target_img = images[target_img_idx].to(device) target_img_batch = target_img[pix_idx_rows, pix_idx_cols].reshape(C_rs_f.shape) # Calculate the mean squared error for both the coarse and fine MLP models and # update the weights. See Equation (6) in Section 5.3. loss = criterion(C_rs_c, target_img_batch) + criterion(C_rs_f, target_img_batch) optimizer.zero_grad() loss.backward() optimizer.step() # Exponentially decay learning rate. See Section 5.3 and: # https://keras.io/api/optimizers/learning_rate_schedules/exponential_decay/. for g in optimizer.param_groups: g["lr"] = lr * decay_rate ** (i / decay_steps) if i % display_every == 0: nerf.F_c.eval() nerf.F_f.eval() with torch.no_grad(): (_, C_rs_f) = nerf(test_ds, test_os) loss = criterion(C_rs_f, test_img) print(f"Loss: {loss.item()}") psnr = -10.0 * torch.log10(loss) psnrs.append(psnr.item()) iternums.append(i) plt.figure(figsize=(10, 4)) plt.subplot(121) plt.imshow(C_rs_f.detach().cpu().numpy()) plt.title(f"Iteration {i}") plt.subplot(122) plt.plot(iternums, psnrs) plt.title("PSNR") plt.show() nerf.F_c.train() nerf.F_f.train() print("Done!") if __name__ == "__main__": main() ================================================ FILE: run_pixelnerf.py ================================================ import matplotlib.pyplot as plt import numpy as np import torch from image_encoder import ImageEncoder from pixelnerf_dataset import PixelNeRFDataset from torch import nn, optim def get_coarse_query_points(ds, N_c, t_i_c_bin_edges, t_i_c_gap, os): u_is_c = torch.rand(*list(ds.shape[:2]) + [N_c]).to(ds) t_is_c = t_i_c_bin_edges + u_is_c * t_i_c_gap r_ts_c = os[..., None, :] + t_is_c[..., :, None] * ds[..., None, :] return (r_ts_c, t_is_c) def get_fine_query_points(w_is_c, N_f, t_is_c, t_f, os, ds, r_ts_c, N_d, d_std, t_n): w_is_c = w_is_c + 1e-5 pdfs = w_is_c / torch.sum(w_is_c, dim=-1, keepdim=True) cdfs = torch.cumsum(pdfs, dim=-1) cdfs = torch.cat([torch.zeros_like(cdfs[..., :1]), cdfs[..., :-1]], dim=-1) us = torch.rand(list(cdfs.shape[:-1]) + [N_f]).to(w_is_c) idxs = torch.searchsorted(cdfs, us, right=True) t_i_f_bottom_edges = torch.gather(t_is_c, 2, idxs - 1) idxs_capped = idxs.clone() max_ind = cdfs.shape[-1] idxs_capped[idxs_capped == max_ind] = max_ind - 1 t_i_f_top_edges = torch.gather(t_is_c, 2, idxs_capped) t_i_f_top_edges[idxs == max_ind] = t_f t_i_f_gaps = t_i_f_top_edges - t_i_f_bottom_edges u_is_f = torch.rand_like(t_i_f_gaps).to(os) t_is_f = t_i_f_bottom_edges + u_is_f * t_i_f_gaps # See Section B.1 in the Supplementary Materials and: # https://github.com/sxyu/pixel-nerf/blob/a5a514224272a91e3ec590f215567032e1f1c260/src/render/nerf.py#L150. t_is_d = (w_is_c * r_ts_c[..., 2]).sum(dim=-1) t_is_d = t_is_d.unsqueeze(2).repeat((1, 1, N_d)) t_is_d = t_is_d + torch.normal(0, d_std, size=t_is_d.shape).to(t_is_d) t_is_d = torch.clamp(t_is_d, t_n, t_f) t_is_f = torch.cat([t_is_c, t_is_f.detach(), t_is_d], dim=-1) (t_is_f, _) = torch.sort(t_is_f, dim=-1) r_ts_f = os[..., None, :] + t_is_f[..., :, None] * ds[..., None, :] return (r_ts_f, t_is_f) def get_image_features_for_query_points(r_ts, camera_distance, scale, W_i): # Get the projected image coordinates (pi_x_is) for each point along the rays # (r_ts). This is just geometry. See: http://www.songho.ca/opengl/gl_projectionmatrix.html. pi_x_is = r_ts[..., :2] / (camera_distance - r_ts[..., 2].unsqueeze(-1)) pi_x_is = pi_x_is / scale # PyTorch's grid_sample function assumes (-1, -1) is the left-top pixel, but we want # (-1, -1) to be the left-bottom pixel, so we negate the y-coordinates. pi_x_is[..., 1] = -1 * pi_x_is[..., 1] # PyTorch's grid_sample function expects the grid to have shape # (N, H_out, W_out, 2). pi_x_is = pi_x_is.permute(2, 0, 1, 3) # PyTorch's grid_sample function expects the input to have shape (N, C, H_in, W_in). W_i = W_i.repeat(pi_x_is.shape[0], 1, 1, 1) # Get the image features (z_is) associated with the projected image coordinates # (pi_x_is) from the encoded image features (W_i). See Section 4.2. z_is = nn.functional.grid_sample( W_i, pi_x_is, align_corners=True, padding_mode="border" ) # Convert shape back to match rays. z_is = z_is.permute(2, 3, 0, 1) return z_is def render_radiance_volume(r_ts, ds, z_is, chunk_size, F, t_is): r_ts_flat = r_ts.reshape((-1, 3)) ds_rep = ds.unsqueeze(2).repeat(1, 1, r_ts.shape[-2], 1) ds_flat = ds_rep.reshape((-1, 3)) z_is_flat = z_is.reshape((ds_flat.shape[0], -1)) c_is = [] sigma_is = [] for chunk_start in range(0, r_ts_flat.shape[0], chunk_size): r_ts_batch = r_ts_flat[chunk_start : chunk_start + chunk_size] ds_batch = ds_flat[chunk_start : chunk_start + chunk_size] w_is_batch = z_is_flat[chunk_start : chunk_start + chunk_size] preds = F(r_ts_batch, ds_batch, w_is_batch) c_is.append(preds["c_is"]) sigma_is.append(preds["sigma_is"]) c_is = torch.cat(c_is).reshape(r_ts.shape) sigma_is = torch.cat(sigma_is).reshape(r_ts.shape[:-1]) delta_is = t_is[..., 1:] - t_is[..., :-1] one_e_10 = torch.Tensor([1e10]).expand(delta_is[..., :1].shape) delta_is = torch.cat([delta_is, one_e_10.to(delta_is)], dim=-1) delta_is = delta_is * ds.norm(dim=-1).unsqueeze(-1) alpha_is = 1.0 - torch.exp(-sigma_is * delta_is) T_is = torch.cumprod(1.0 - alpha_is + 1e-10, -1) T_is = torch.roll(T_is, 1, -1) T_is[..., 0] = 1.0 w_is = T_is * alpha_is C_rs = (w_is[..., None] * c_is).sum(dim=-2) return (C_rs, w_is) def run_one_iter_of_pixelnerf( ds, N_c, t_i_c_bin_edges, t_i_c_gap, os, camera_distance, scale, W_i, chunk_size, F_c, N_f, t_f, N_d, d_std, t_n, F_f, ): (r_ts_c, t_is_c) = get_coarse_query_points(ds, N_c, t_i_c_bin_edges, t_i_c_gap, os) z_is_c = get_image_features_for_query_points(r_ts_c, camera_distance, scale, W_i) (C_rs_c, w_is_c) = render_radiance_volume( r_ts_c, ds, z_is_c, chunk_size, F_c, t_is_c ) (r_ts_f, t_is_f) = get_fine_query_points( w_is_c, N_f, t_is_c, t_f, os, ds, r_ts_c, N_d, d_std, t_n ) z_is_f = get_image_features_for_query_points(r_ts_f, camera_distance, scale, W_i) (C_rs_f, _) = render_radiance_volume(r_ts_f, ds, z_is_f, chunk_size, F_f, t_is_f) return (C_rs_c, C_rs_f) class PixelNeRFFCResNet(nn.Module): def __init__(self): super().__init__() # Number of encoding functions for positions. See Section B.1 in the # Supplementary Materials. self.L_pos = 6 # Number of encoding functions for viewing directions. self.L_dir = 0 pos_enc_feats = 3 + 3 * 2 * self.L_pos dir_enc_feats = 3 + 3 * 2 * self.L_dir # Set up ResNet MLP. See Section B.1 and Figure 18 in the Supplementary # Materials. net_width = 512 self.first_layer = nn.Sequential( nn.Linear(pos_enc_feats + dir_enc_feats, net_width) ) self.n_resnet_blocks = 5 z_linears = [] mlps = [] for resnet_block in range(self.n_resnet_blocks): z_linears.append(nn.Linear(net_width, net_width)) mlps.append( nn.Sequential( nn.Linear(net_width, net_width), nn.ReLU(), nn.Linear(net_width, net_width), nn.ReLU(), ) ) self.z_linears = nn.ModuleList(z_linears) self.mlps = nn.ModuleList(mlps) self.final_layer = nn.Linear(net_width, 4) def forward(self, xs, ds, zs): xs_encoded = [xs] for l_pos in range(self.L_pos): xs_encoded.append(torch.sin(2**l_pos * torch.pi * xs)) xs_encoded.append(torch.cos(2**l_pos * torch.pi * xs)) xs_encoded = torch.cat(xs_encoded, dim=-1) ds = ds / ds.norm(p=2, dim=-1).unsqueeze(-1) ds_encoded = [ds] for l_dir in range(self.L_dir): ds_encoded.append(torch.sin(2**l_dir * torch.pi * ds)) ds_encoded.append(torch.cos(2**l_dir * torch.pi * ds)) ds_encoded = torch.cat(ds_encoded, dim=-1) # Use the network to predict colors (c_is) and volume densities (sigma_is) for # 3D points (xs) along rays given the viewing directions (ds) of the rays # and the associated input image features (zs). See Section B.1 and Figure 18 in # the Supplementary Materials and: # https://github.com/sxyu/pixel-nerf/blob/master/src/model/resnetfc.py. outputs = self.first_layer(torch.cat([xs_encoded, ds_encoded], dim=-1)) for block_idx in range(self.n_resnet_blocks): resnet_zs = self.z_linears[block_idx](zs) outputs = outputs + resnet_zs outputs = self.mlps[block_idx](outputs) + outputs outputs = self.final_layer(outputs) sigma_is = torch.relu(outputs[:, 0]) c_is = torch.sigmoid(outputs[:, 1:]) return {"c_is": c_is, "sigma_is": sigma_is} def load_data(): # Initialize dataset and test object/poses. data_dir = "data" # See Section B.2.1 in the Supplementary Materials. num_iters = 400000 test_obj_idx = 5 test_source_pose_idx = 11 test_target_pose_idx = 33 train_dataset = PixelNeRFDataset( data_dir, num_iters, test_obj_idx, test_source_pose_idx, test_target_pose_idx ) return train_dataset def set_up_test_data(train_dataset, device): obj_idx = train_dataset.test_obj_idx obj = train_dataset.objs[obj_idx] data_dir = train_dataset.data_dir obj_dir = f"{data_dir}/{obj}" z_len = train_dataset.z_len source_pose_idx = train_dataset.test_source_pose_idx source_img_f = f"{obj_dir}/{str(source_pose_idx).zfill(z_len)}.npy" source_image = np.load(source_img_f) / 255 source_pose = train_dataset.poses[obj_idx, source_pose_idx] source_R = source_pose[:3, :3] target_pose_idx = train_dataset.test_target_pose_idx target_img_f = f"{obj_dir}/{str(target_pose_idx).zfill(z_len)}.npy" target_image = np.load(target_img_f) / 255 target_pose = train_dataset.poses[obj_idx, target_pose_idx] target_R = target_pose[:3, :3] R = torch.Tensor(source_R.T @ target_R).to(device) plt.imshow(source_image) plt.show() source_image = torch.Tensor(source_image) source_image = ( source_image - train_dataset.channel_means ) / train_dataset.channel_stds source_image = source_image.to(device).unsqueeze(0).permute(0, 3, 1, 2) plt.imshow(target_image) plt.show() target_image = torch.Tensor(target_image).to(device) return (source_image, R, target_image) def main(): seed = 9458 torch.manual_seed(seed) np.random.seed(seed) device = "cuda:0" F_c = PixelNeRFFCResNet().to(device) F_f = PixelNeRFFCResNet().to(device) E = ImageEncoder().to(device) chunk_size = 1024 * 32 # See Section B.2 in the Supplementary Materials. batch_img_size = 12 n_batch_pix = batch_img_size**2 n_objs = 4 # See Section B.2 in the Supplementary Materials. lr = 1e-4 optimizer = optim.Adam(list(F_c.parameters()) + list(F_f.parameters()), lr=lr) criterion = nn.MSELoss() train_dataset = load_data() camera_distance = train_dataset.camera_distance scale = train_dataset.scale t_n = 1.0 t_f = 4.0 img_size = train_dataset[0][2].shape[0] # See Section B.1 in the Supplementary Materials, # and: https://github.com/sxyu/pixel-nerf/blob/a5a514224272a91e3ec590f215567032e1f1c260/conf/default.conf#L50, # and: https://github.com/sxyu/pixel-nerf/blob/a5a514224272a91e3ec590f215567032e1f1c260/src/render/nerf.py#L150. N_c = 64 N_f = 16 N_d = 16 d_std = 0.01 t_i_c_gap = (t_f - t_n) / N_c t_i_c_bin_edges = (t_n + torch.arange(N_c) * t_i_c_gap).to(device) init_o = train_dataset.init_o.to(device) init_ds = train_dataset.init_ds.to(device) (test_source_image, test_R, test_target_image) = set_up_test_data( train_dataset, device ) test_ds = torch.einsum("ij,hwj->hwi", test_R, init_ds) test_os = (test_R @ init_o).expand(test_ds.shape) psnrs = [] iternums = [] num_iters = train_dataset.N use_bbox = True num_bbox_iters = 300000 display_every = 100 F_c.train() F_f.train() E.eval() for i in range(num_iters): if i == num_bbox_iters: use_bbox = False loss = 0 for obj in range(n_objs): try: (source_image, R, target_image, bbox) = train_dataset[0] except ValueError: continue R = R.to(device) ds = torch.einsum("ij,hwj->hwi", R, init_ds) os = (R @ init_o).expand(ds.shape) if use_bbox: pix_rows = np.arange(bbox[0], bbox[2]) pix_cols = np.arange(bbox[1], bbox[3]) else: pix_rows = np.arange(0, img_size) pix_cols = np.arange(0, img_size) pix_row_cols = np.meshgrid(pix_rows, pix_cols, indexing="ij") pix_row_cols = np.stack(pix_row_cols).transpose(1, 2, 0).reshape(-1, 2) choices = np.arange(len(pix_row_cols)) try: selected_pix = np.random.choice(choices, n_batch_pix, False) except ValueError: continue pix_idx_rows = pix_row_cols[selected_pix, 0] pix_idx_cols = pix_row_cols[selected_pix, 1] ds_batch = ds[pix_idx_rows, pix_idx_cols].reshape( batch_img_size, batch_img_size, -1 ) os_batch = os[pix_idx_rows, pix_idx_cols].reshape( batch_img_size, batch_img_size, -1 ) # Extract feature pyramid from image. See Section 4.1, Section B.1 in the # Supplementary Materials, and: https://github.com/sxyu/pixel-nerf/blob/master/src/model/encoder.py. with torch.no_grad(): W_i = E(source_image.unsqueeze(0).permute(0, 3, 1, 2).to(device)) (C_rs_c, C_rs_f) = run_one_iter_of_pixelnerf( ds_batch, N_c, t_i_c_bin_edges, t_i_c_gap, os_batch, camera_distance, scale, W_i, chunk_size, F_c, N_f, t_f, N_d, d_std, t_n, F_f, ) target_img = target_image.to(device) target_img_batch = target_img[pix_idx_rows, pix_idx_cols].reshape( C_rs_c.shape ) loss += criterion(C_rs_c, target_img_batch) loss += criterion(C_rs_f, target_img_batch) try: optimizer.zero_grad() loss.backward() optimizer.step() except AttributeError: continue if i % display_every == 0: F_c.eval() F_f.eval() with torch.no_grad(): test_W_i = E(test_source_image) (_, C_rs_f) = run_one_iter_of_pixelnerf( test_ds, N_c, t_i_c_bin_edges, t_i_c_gap, test_os, camera_distance, scale, test_W_i, chunk_size, F_c, N_f, t_f, N_d, d_std, t_n, F_f, ) loss = criterion(C_rs_f, test_target_image) print(f"Loss: {loss.item()}") psnr = -10.0 * torch.log10(loss) psnrs.append(psnr.item()) iternums.append(i) plt.figure(figsize=(10, 4)) plt.subplot(121) plt.imshow(C_rs_f.detach().cpu().numpy()) plt.title(f"Iteration {i}") plt.subplot(122) plt.plot(iternums, psnrs) plt.title("PSNR") plt.show() F_c.train() F_f.train() print("Done!") if __name__ == "__main__": main() ================================================ FILE: run_pixelnerf_alt.py ================================================ import matplotlib.pyplot as plt import numpy as np import torch from image_encoder import ImageEncoder from pixelnerf_dataset import PixelNeRFDataset from torch import nn, optim class PixelNeRFFCResNet(nn.Module): def __init__(self): super().__init__() # Number of encoding functions for positions. See Section B.1 in the # Supplementary Materials. self.L_pos = 6 # Number of encoding functions for viewing directions. self.L_dir = 0 pos_enc_feats = 3 + 3 * 2 * self.L_pos dir_enc_feats = 3 + 3 * 2 * self.L_dir # Set up ResNet MLP. See Section B.1 and Figure 18 in the Supplementary # Materials. net_width = 512 self.first_layer = nn.Sequential( nn.Linear(pos_enc_feats + dir_enc_feats, net_width) ) self.n_resnet_blocks = 5 z_linears = [] mlps = [] for resnet_block in range(self.n_resnet_blocks): z_linears.append(nn.Linear(net_width, net_width)) mlps.append( nn.Sequential( nn.Linear(net_width, net_width), nn.ReLU(), nn.Linear(net_width, net_width), nn.ReLU(), ) ) self.z_linears = nn.ModuleList(z_linears) self.mlps = nn.ModuleList(mlps) self.final_layer = nn.Linear(net_width, 4) def forward(self, xs, ds, zs): xs_encoded = [xs] for l_pos in range(self.L_pos): xs_encoded.append(torch.sin(2**l_pos * torch.pi * xs)) xs_encoded.append(torch.cos(2**l_pos * torch.pi * xs)) xs_encoded = torch.cat(xs_encoded, dim=-1) ds = ds / ds.norm(p=2, dim=-1).unsqueeze(-1) ds_encoded = [ds] for l_dir in range(self.L_dir): ds_encoded.append(torch.sin(2**l_dir * torch.pi * ds)) ds_encoded.append(torch.cos(2**l_dir * torch.pi * ds)) ds_encoded = torch.cat(ds_encoded, dim=-1) # Use the network to predict colors (c_is) and volume densities (sigma_is) for # 3D points (xs) along rays given the viewing directions (ds) of the rays # and the associated input image features (zs). See Section B.1 and Figure 18 in # the Supplementary Materials and: # https://github.com/sxyu/pixel-nerf/blob/master/src/model/resnetfc.py. outputs = self.first_layer(torch.cat([xs_encoded, ds_encoded], dim=-1)) for block_idx in range(self.n_resnet_blocks): resnet_zs = self.z_linears[block_idx](zs) outputs = outputs + resnet_zs outputs = self.mlps[block_idx](outputs) + outputs outputs = self.final_layer(outputs) sigma_is = torch.relu(outputs[:, 0]) c_is = torch.sigmoid(outputs[:, 1:]) return {"c_is": c_is, "sigma_is": sigma_is} class PixelNeRF: def __init__(self, device, camera_distance, scale): self.device = device # See Section B.1 in the Supplementary Materials, # and: https://github.com/sxyu/pixel-nerf/blob/a5a514224272a91e3ec590f215567032e1f1c260/conf/default.conf#L50, # and: https://github.com/sxyu/pixel-nerf/blob/a5a514224272a91e3ec590f215567032e1f1c260/src/render/nerf.py#L150. self.N_c = N_c = 64 self.N_f = 16 self.N_d = 16 self.d_std = 0.01 self.t_n = t_n = 1.0 self.t_f = t_f = 4.0 self.t_i_c_gap = t_i_c_gap = (t_f - t_n) / N_c self.t_i_c_bin_edges = (t_n + torch.arange(N_c) * t_i_c_gap).to(device) self.F_c = PixelNeRFFCResNet().to(device) self.F_f = PixelNeRFFCResNet().to(device) self.E = ImageEncoder().to(device) self.camera_distance = camera_distance self.scale = scale self.chunk_size = 1024 * 32 def get_coarse_query_points(self, ds, os): u_is_c = torch.rand(*list(ds.shape[:2]) + [self.N_c]).to(ds) t_is_c = self.t_i_c_bin_edges + u_is_c * self.t_i_c_gap r_ts_c = os[..., None, :] + t_is_c[..., :, None] * ds[..., None, :] return (r_ts_c, t_is_c) def get_fine_query_points(self, w_is_c, t_is_c, os, ds, r_ts_c): w_is_c = w_is_c + 1e-5 pdfs = w_is_c / torch.sum(w_is_c, dim=-1, keepdim=True) cdfs = torch.cumsum(pdfs, dim=-1) cdfs = torch.cat([torch.zeros_like(cdfs[..., :1]), cdfs[..., :-1]], dim=-1) us = torch.rand(list(cdfs.shape[:-1]) + [self.N_f]).to(w_is_c) idxs = torch.searchsorted(cdfs, us, right=True) t_i_f_bottom_edges = torch.gather(t_is_c, 2, idxs - 1) idxs_capped = idxs.clone() max_ind = cdfs.shape[-1] idxs_capped[idxs_capped == max_ind] = max_ind - 1 t_i_f_top_edges = torch.gather(t_is_c, 2, idxs_capped) t_i_f_top_edges[idxs == max_ind] = self.t_f t_i_f_gaps = t_i_f_top_edges - t_i_f_bottom_edges u_is_f = torch.rand_like(t_i_f_gaps).to(os) t_is_f = t_i_f_bottom_edges + u_is_f * t_i_f_gaps # See Section B.1 in the Supplementary Materials and: # https://github.com/sxyu/pixel-nerf/blob/a5a514224272a91e3ec590f215567032e1f1c260/src/render/nerf.py#L150. t_is_d = (w_is_c * r_ts_c[..., 2]).sum(dim=-1) t_is_d = t_is_d.unsqueeze(2).repeat((1, 1, self.N_d)) t_is_d = t_is_d + torch.normal(0, self.d_std, size=t_is_d.shape).to(t_is_d) t_is_d = torch.clamp(t_is_d, self.t_n, self.t_f) t_is_f = torch.cat([t_is_c, t_is_f.detach(), t_is_d], dim=-1) (t_is_f, _) = torch.sort(t_is_f, dim=-1) r_ts_f = os[..., None, :] + t_is_f[..., :, None] * ds[..., None, :] return (r_ts_f, t_is_f) def get_image_features_for_query_points(self, r_ts, W_i): # Get the projected image coordinates (pi_x_is) for each point along the rays # (r_ts). This is just geometry. See: http://www.songho.ca/opengl/gl_projectionmatrix.html. pi_x_is = r_ts[..., :2] / (self.camera_distance - r_ts[..., 2].unsqueeze(-1)) pi_x_is = pi_x_is / self.scale # PyTorch's grid_sample function assumes (-1, -1) is the left-top pixel, but we want # (-1, -1) to be the left-bottom pixel, so we negate the y-coordinates. pi_x_is[..., 1] = -1 * pi_x_is[..., 1] # PyTorch's grid_sample function expects the grid to have shape # (N, H_out, W_out, 2). pi_x_is = pi_x_is.permute(2, 0, 1, 3) # PyTorch's grid_sample function expects the input to have shape (N, C, H_in, W_in). W_i = W_i.repeat(pi_x_is.shape[0], 1, 1, 1) # Get the image features (z_is) associated with the projected image coordinates # (pi_x_is) from the encoded image features (W_i). See Section 4.2. z_is = nn.functional.grid_sample( W_i, pi_x_is, align_corners=True, padding_mode="border" ) # Convert shape back to match rays. z_is = z_is.permute(2, 3, 0, 1) return z_is def render_radiance_volume(self, r_ts, ds, z_is, F, t_is): r_ts_flat = r_ts.reshape((-1, 3)) ds_rep = ds.unsqueeze(2).repeat(1, 1, r_ts.shape[-2], 1) ds_flat = ds_rep.reshape((-1, 3)) z_is_flat = z_is.reshape((ds_flat.shape[0], -1)) c_is = [] sigma_is = [] for chunk_start in range(0, r_ts_flat.shape[0], self.chunk_size): r_ts_batch = r_ts_flat[chunk_start : chunk_start + self.chunk_size] ds_batch = ds_flat[chunk_start : chunk_start + self.chunk_size] w_is_batch = z_is_flat[chunk_start : chunk_start + self.chunk_size] preds = F(r_ts_batch, ds_batch, w_is_batch) c_is.append(preds["c_is"]) sigma_is.append(preds["sigma_is"]) c_is = torch.cat(c_is).reshape(r_ts.shape) sigma_is = torch.cat(sigma_is).reshape(r_ts.shape[:-1]) delta_is = t_is[..., 1:] - t_is[..., :-1] one_e_10 = torch.Tensor([1e10]).expand(delta_is[..., :1].shape) delta_is = torch.cat([delta_is, one_e_10.to(delta_is)], dim=-1) delta_is = delta_is * ds.norm(dim=-1).unsqueeze(-1) alpha_is = 1.0 - torch.exp(-sigma_is * delta_is) T_is = torch.cumprod(1.0 - alpha_is + 1e-10, -1) T_is = torch.roll(T_is, 1, -1) T_is[..., 0] = 1.0 w_is = T_is * alpha_is C_rs = (w_is[..., None] * c_is).sum(dim=-2) return (C_rs, w_is) def __call__(self, ds, os, source_image): (r_ts_c, t_is_c) = self.get_coarse_query_points(ds, os) # Extract feature pyramid from image. See Section 4.1, Section B.1 in the # Supplementary Materials, and: https://github.com/sxyu/pixel-nerf/blob/master/src/model/encoder.py. with torch.no_grad(): W_i = self.E(source_image.unsqueeze(0).permute(0, 3, 1, 2).to(self.device)) z_is_c = self.get_image_features_for_query_points(r_ts_c, W_i) (C_rs_c, w_is_c) = self.render_radiance_volume( r_ts_c, ds, z_is_c, self.F_c, t_is_c ) (r_ts_f, t_is_f) = self.get_fine_query_points(w_is_c, t_is_c, os, ds, r_ts_c) z_is_f = self.get_image_features_for_query_points(r_ts_f, W_i) (C_rs_f, _) = self.render_radiance_volume(r_ts_f, ds, z_is_f, self.F_f, t_is_f) return (C_rs_c, C_rs_f) def load_data(): # Initialize dataset and test object/poses. data_dir = "data" # See Section B.2.1 in the Supplementary Materials. num_iters = 400000 test_obj_idx = 5 test_source_pose_idx = 11 test_target_pose_idx = 33 train_dataset = PixelNeRFDataset( data_dir, num_iters, test_obj_idx, test_source_pose_idx, test_target_pose_idx ) return (num_iters, train_dataset) def set_up_test_data(train_dataset, device): obj_idx = train_dataset.test_obj_idx obj = train_dataset.objs[obj_idx] data_dir = train_dataset.data_dir obj_dir = f"{data_dir}/{obj}" z_len = train_dataset.z_len source_pose_idx = train_dataset.test_source_pose_idx source_img_f = f"{obj_dir}/{str(source_pose_idx).zfill(z_len)}.npy" source_image = np.load(source_img_f) / 255 source_pose = train_dataset.poses[obj_idx, source_pose_idx] source_R = source_pose[:3, :3] target_pose_idx = train_dataset.test_target_pose_idx target_img_f = f"{obj_dir}/{str(target_pose_idx).zfill(z_len)}.npy" target_image = np.load(target_img_f) / 255 target_pose = train_dataset.poses[obj_idx, target_pose_idx] target_R = target_pose[:3, :3] R = torch.Tensor(source_R.T @ target_R).to(device) plt.imshow(source_image) plt.show() source_image = torch.Tensor(source_image) source_image = ( source_image - train_dataset.channel_means ) / train_dataset.channel_stds plt.imshow(target_image) plt.show() target_image = torch.Tensor(target_image).to(device) return (source_image, R, target_image) def main(): seed = 9458 torch.manual_seed(seed) np.random.seed(seed) device = "cuda:0" (num_iters, train_dataset) = load_data() img_size = train_dataset[0][2].shape[0] pixelnerf = PixelNeRF(device, train_dataset.camera_distance, train_dataset.scale) # See Section B.2 in the Supplementary Materials. batch_img_size = 12 n_batch_pix = batch_img_size**2 n_objs = 4 # See Section B.2 in the Supplementary Materials. lr = 1e-4 train_params = list(pixelnerf.F_c.parameters()) + list(pixelnerf.F_f.parameters()) optimizer = optim.Adam(train_params, lr=lr) criterion = nn.MSELoss() (test_source_image, test_R, test_target_image) = set_up_test_data( train_dataset, device ) init_o = train_dataset.init_o.to(device) init_ds = train_dataset.init_ds.to(device) test_ds = torch.einsum("ij,hwj->hwi", test_R, init_ds) test_os = (test_R @ init_o).expand(test_ds.shape) psnrs = [] iternums = [] use_bbox = True num_bbox_iters = 300000 display_every = 100 pixelnerf.F_c.train() pixelnerf.F_f.train() pixelnerf.E.eval() for i in range(num_iters): if i == num_bbox_iters: use_bbox = False loss = 0 for obj in range(n_objs): try: (source_image, R, target_image, bbox) = train_dataset[0] except ValueError: continue R = R.to(device) ds = torch.einsum("ij,hwj->hwi", R, init_ds) os = (R @ init_o).expand(ds.shape) if use_bbox: pix_rows = np.arange(bbox[0], bbox[2]) pix_cols = np.arange(bbox[1], bbox[3]) else: pix_rows = np.arange(0, img_size) pix_cols = np.arange(0, img_size) pix_row_cols = np.meshgrid(pix_rows, pix_cols, indexing="ij") pix_row_cols = np.stack(pix_row_cols).transpose(1, 2, 0).reshape(-1, 2) choices = np.arange(len(pix_row_cols)) try: selected_pix = np.random.choice(choices, n_batch_pix, False) except ValueError: continue pix_idx_rows = pix_row_cols[selected_pix, 0] pix_idx_cols = pix_row_cols[selected_pix, 1] ds_batch = ds[pix_idx_rows, pix_idx_cols].reshape( batch_img_size, batch_img_size, -1 ) os_batch = os[pix_idx_rows, pix_idx_cols].reshape( batch_img_size, batch_img_size, -1 ) (C_rs_c, C_rs_f) = pixelnerf(ds_batch, os_batch, source_image) target_img = target_image.to(device) target_img_batch = target_img[pix_idx_rows, pix_idx_cols].reshape( C_rs_c.shape ) loss += criterion(C_rs_c, target_img_batch) loss += criterion(C_rs_f, target_img_batch) optimizer.zero_grad() loss.backward() optimizer.step() if i % display_every == 0: pixelnerf.F_c.eval() pixelnerf.F_f.eval() with torch.no_grad(): (_, C_rs_f) = pixelnerf(test_ds, test_os, test_source_image) loss = criterion(C_rs_f, test_target_image) print(f"Loss: {loss.item()}") psnr = -10.0 * torch.log10(loss) psnrs.append(psnr.item()) iternums.append(i) plt.figure(figsize=(10, 4)) plt.subplot(121) plt.imshow(C_rs_f.detach().cpu().numpy()) plt.title(f"Iteration {i}") plt.subplot(122) plt.plot(iternums, psnrs) plt.title("PSNR") plt.show() pixelnerf.F_c.train() pixelnerf.F_f.train() print("Done!") if __name__ == "__main__": main() ================================================ FILE: run_tiny_nerf.py ================================================ import matplotlib.pyplot as plt import numpy as np import torch from torch import nn, optim def get_coarse_query_points(ds, N_c, t_i_c_bin_edges, t_i_c_gap, os): u_is_c = torch.rand(*list(ds.shape[:2]) + [N_c]).to(ds) t_is_c = t_i_c_bin_edges + u_is_c * t_i_c_gap r_ts_c = os[..., None, :] + t_is_c[..., :, None] * ds[..., None, :] return (r_ts_c, t_is_c) def render_radiance_volume(r_ts, ds, chunk_size, F, t_is): r_ts_flat = r_ts.reshape((-1, 3)) ds_rep = ds.unsqueeze(2).repeat(1, 1, r_ts.shape[-2], 1) ds_flat = ds_rep.reshape((-1, 3)) c_is = [] sigma_is = [] for chunk_start in range(0, r_ts_flat.shape[0], chunk_size): r_ts_batch = r_ts_flat[chunk_start : chunk_start + chunk_size] ds_batch = ds_flat[chunk_start : chunk_start + chunk_size] preds = F(r_ts_batch, ds_batch) c_is.append(preds["c_is"]) sigma_is.append(preds["sigma_is"]) c_is = torch.cat(c_is).reshape(r_ts.shape) sigma_is = torch.cat(sigma_is).reshape(r_ts.shape[:-1]) delta_is = t_is[..., 1:] - t_is[..., :-1] one_e_10 = torch.Tensor([1e10]).expand(delta_is[..., :1].shape) delta_is = torch.cat([delta_is, one_e_10.to(delta_is)], dim=-1) delta_is = delta_is * ds.norm(dim=-1).unsqueeze(-1) alpha_is = 1.0 - torch.exp(-sigma_is * delta_is) T_is = torch.cumprod(1.0 - alpha_is + 1e-10, -1) T_is = torch.roll(T_is, 1, -1) T_is[..., 0] = 1.0 w_is = T_is * alpha_is C_rs = (w_is[..., None] * c_is).sum(dim=-2) return C_rs def run_one_iter_of_tiny_nerf(ds, N_c, t_i_c_bin_edges, t_i_c_gap, os, chunk_size, F_c): (r_ts_c, t_is_c) = get_coarse_query_points(ds, N_c, t_i_c_bin_edges, t_i_c_gap, os) C_rs_c = render_radiance_volume(r_ts_c, ds, chunk_size, F_c, t_is_c) return C_rs_c class VeryTinyNeRFMLP(nn.Module): def __init__(self): super().__init__() self.L_pos = 6 self.L_dir = 4 pos_enc_feats = 3 + 3 * 2 * self.L_pos dir_enc_feats = 3 + 3 * 2 * self.L_dir net_width = 256 self.early_mlp = nn.Sequential( nn.Linear(pos_enc_feats, net_width), nn.ReLU(), nn.Linear(net_width, net_width + 1), nn.ReLU(), ) self.late_mlp = nn.Sequential( nn.Linear(net_width + dir_enc_feats, net_width), nn.ReLU(), nn.Linear(net_width, 3), nn.Sigmoid(), ) def forward(self, xs, ds): xs_encoded = [xs] for l_pos in range(self.L_pos): xs_encoded.append(torch.sin(2**l_pos * torch.pi * xs)) xs_encoded.append(torch.cos(2**l_pos * torch.pi * xs)) xs_encoded = torch.cat(xs_encoded, dim=-1) ds = ds / ds.norm(p=2, dim=-1).unsqueeze(-1) ds_encoded = [ds] for l_dir in range(self.L_dir): ds_encoded.append(torch.sin(2**l_dir * torch.pi * ds)) ds_encoded.append(torch.cos(2**l_dir * torch.pi * ds)) ds_encoded = torch.cat(ds_encoded, dim=-1) outputs = self.early_mlp(xs_encoded) sigma_is = outputs[:, 0] c_is = self.late_mlp(torch.cat([outputs[:, 1:], ds_encoded], dim=-1)) return {"c_is": c_is, "sigma_is": sigma_is} def main(): seed = 9458 torch.manual_seed(seed) np.random.seed(seed) device = "cuda:0" F_c = VeryTinyNeRFMLP().to(device) chunk_size = 16384 lr = 5e-3 optimizer = optim.Adam(F_c.parameters(), lr=lr) criterion = nn.MSELoss() data_f = "66bdbc812bd0a196e194052f3f12cb2e.npz" data = np.load(data_f) images = data["images"] / 255 img_size = images.shape[1] xs = torch.arange(img_size) - (img_size / 2 - 0.5) ys = torch.arange(img_size) - (img_size / 2 - 0.5) (xs, ys) = torch.meshgrid(xs, -ys, indexing="xy") focal = float(data["focal"]) pixel_coords = torch.stack([xs, ys, torch.full_like(xs, -focal)], dim=-1) camera_coords = pixel_coords / focal init_ds = camera_coords.to(device) init_o = torch.Tensor(np.array([0, 0, float(data["camera_distance"])])).to(device) test_idx = 150 plt.imshow(images[test_idx]) plt.show() test_img = torch.Tensor(images[test_idx]).to(device) poses = data["poses"] test_R = torch.Tensor(poses[test_idx, :3, :3]).to(device) test_ds = torch.einsum("ij,hwj->hwi", test_R, init_ds) test_os = (test_R @ init_o).expand(test_ds.shape) t_n = 1.0 t_f = 4.0 N_c = 32 t_i_c_gap = (t_f - t_n) / N_c t_i_c_bin_edges = (t_n + torch.arange(N_c) * t_i_c_gap).to(device) train_idxs = np.arange(len(images)) != test_idx images = torch.Tensor(images[train_idxs]) poses = torch.Tensor(poses[train_idxs]) psnrs = [] iternums = [] num_iters = 20000 display_every = 100 F_c.train() for i in range(num_iters): target_img_idx = np.random.randint(images.shape[0]) target_pose = poses[target_img_idx].to(device) R = target_pose[:3, :3] ds = torch.einsum("ij,hwj->hwi", R, init_ds) os = (R @ init_o).expand(ds.shape) C_rs_c = run_one_iter_of_tiny_nerf( ds, N_c, t_i_c_bin_edges, t_i_c_gap, os, chunk_size, F_c ) loss = criterion(C_rs_c, images[target_img_idx].to(device)) optimizer.zero_grad() loss.backward() optimizer.step() if i % display_every == 0: F_c.eval() with torch.no_grad(): C_rs_c = run_one_iter_of_tiny_nerf( test_ds, N_c, t_i_c_bin_edges, t_i_c_gap, test_os, chunk_size, F_c ) loss = criterion(C_rs_c, test_img) print(f"Loss: {loss.item()}") psnr = -10.0 * torch.log10(loss) psnrs.append(psnr.item()) iternums.append(i) plt.figure(figsize=(10, 4)) plt.subplot(121) plt.imshow(C_rs_c.detach().cpu().numpy()) plt.title(f"Iteration {i}") plt.subplot(122) plt.plot(iternums, psnrs) plt.title("PSNR") plt.show() F_c.train() print("Done!") if __name__ == "__main__": main() ================================================ FILE: run_tiny_nerf_alt.py ================================================ import matplotlib.pyplot as plt import numpy as np import torch from torch import nn, optim class VeryTinyNeRFMLP(nn.Module): def __init__(self): super().__init__() self.L_pos = 6 self.L_dir = 4 pos_enc_feats = 3 + 3 * 2 * self.L_pos dir_enc_feats = 3 + 3 * 2 * self.L_dir net_width = 256 self.early_mlp = nn.Sequential( nn.Linear(pos_enc_feats, net_width), nn.ReLU(), nn.Linear(net_width, net_width + 1), nn.ReLU(), ) self.late_mlp = nn.Sequential( nn.Linear(net_width + dir_enc_feats, net_width), nn.ReLU(), nn.Linear(net_width, 3), nn.Sigmoid(), ) def forward(self, xs, ds): xs_encoded = [xs] for l_pos in range(self.L_pos): xs_encoded.append(torch.sin(2**l_pos * torch.pi * xs)) xs_encoded.append(torch.cos(2**l_pos * torch.pi * xs)) xs_encoded = torch.cat(xs_encoded, dim=-1) ds = ds / ds.norm(p=2, dim=-1).unsqueeze(-1) ds_encoded = [ds] for l_dir in range(self.L_dir): ds_encoded.append(torch.sin(2**l_dir * torch.pi * ds)) ds_encoded.append(torch.cos(2**l_dir * torch.pi * ds)) ds_encoded = torch.cat(ds_encoded, dim=-1) outputs = self.early_mlp(xs_encoded) sigma_is = outputs[:, 0] c_is = self.late_mlp(torch.cat([outputs[:, 1:], ds_encoded], dim=-1)) return {"c_is": c_is, "sigma_is": sigma_is} class VeryTinyNeRF: def __init__(self, device): self.F_c = VeryTinyNeRFMLP().to(device) self.chunk_size = 16384 self.t_n = t_n = 1.0 self.t_f = t_f = 4.0 self.N_c = N_c = 32 self.t_i_c_gap = t_i_c_gap = (t_f - t_n) / N_c self.t_i_c_bin_edges = (t_n + torch.arange(N_c) * t_i_c_gap).to(device) def get_coarse_query_points(self, ds, os): u_is_c = torch.rand(*list(ds.shape[:2]) + [self.N_c]).to(ds) t_is_c = self.t_i_c_bin_edges + u_is_c * self.t_i_c_gap r_ts_c = os[..., None, :] + t_is_c[..., :, None] * ds[..., None, :] return (r_ts_c, t_is_c) def render_radiance_volume(self, r_ts, ds, F, t_is): r_ts_flat = r_ts.reshape((-1, 3)) ds_rep = ds.unsqueeze(2).repeat(1, 1, r_ts.shape[-2], 1) ds_flat = ds_rep.reshape((-1, 3)) c_is = [] sigma_is = [] for chunk_start in range(0, r_ts_flat.shape[0], self.chunk_size): r_ts_batch = r_ts_flat[chunk_start : chunk_start + self.chunk_size] ds_batch = ds_flat[chunk_start : chunk_start + self.chunk_size] preds = F(r_ts_batch, ds_batch) c_is.append(preds["c_is"]) sigma_is.append(preds["sigma_is"]) c_is = torch.cat(c_is).reshape(r_ts.shape) sigma_is = torch.cat(sigma_is).reshape(r_ts.shape[:-1]) delta_is = t_is[..., 1:] - t_is[..., :-1] one_e_10 = torch.Tensor([1e10]).expand(delta_is[..., :1].shape) delta_is = torch.cat([delta_is, one_e_10.to(delta_is)], dim=-1) delta_is = delta_is * ds.norm(dim=-1).unsqueeze(-1) alpha_is = 1.0 - torch.exp(-sigma_is * delta_is) T_is = torch.cumprod(1.0 - alpha_is + 1e-10, -1) T_is = torch.roll(T_is, 1, -1) T_is[..., 0] = 1.0 w_is = T_is * alpha_is C_rs = (w_is[..., None] * c_is).sum(dim=-2) return C_rs def __call__(self, ds, os): (r_ts_c, t_is_c) = self.get_coarse_query_points(ds, os) C_rs_c = self.render_radiance_volume(r_ts_c, ds, self.F_c, t_is_c) return C_rs_c def load_data(device): data_f = "66bdbc812bd0a196e194052f3f12cb2e.npz" data = np.load(data_f) images = data["images"] / 255 img_size = images.shape[1] xs = torch.arange(img_size) - (img_size / 2 - 0.5) ys = torch.arange(img_size) - (img_size / 2 - 0.5) (xs, ys) = torch.meshgrid(xs, -ys, indexing="xy") focal = float(data["focal"]) pixel_coords = torch.stack([xs, ys, torch.full_like(xs, -focal)], dim=-1) camera_coords = pixel_coords / focal init_ds = camera_coords.to(device) init_o = torch.Tensor(np.array([0, 0, float(data["camera_distance"])])).to(device) return (images, data["poses"], init_ds, init_o, img_size) def set_up_test_data(images, device, poses, init_ds, init_o): test_idx = 150 plt.imshow(images[test_idx]) plt.show() test_img = torch.Tensor(images[test_idx]).to(device) test_R = torch.Tensor(poses[test_idx, :3, :3]).to(device) test_ds = torch.einsum("ij,hwj->hwi", test_R, init_ds) test_os = (test_R @ init_o).expand(test_ds.shape) train_idxs = np.arange(len(images)) != test_idx return (test_ds, test_os, test_img, train_idxs) def main(): seed = 9458 torch.manual_seed(seed) np.random.seed(seed) device = "cuda:0" nerf = VeryTinyNeRF(device) lr = 5e-3 optimizer = optim.Adam(nerf.F_c.parameters(), lr=lr) criterion = nn.MSELoss() (images, poses, init_ds, init_o, test_img) = load_data(device) (test_ds, test_os, test_img, train_idxs) = set_up_test_data( images, device, poses, init_ds, init_o ) images = torch.Tensor(images[train_idxs]) poses = torch.Tensor(poses[train_idxs]) psnrs = [] iternums = [] num_iters = 20000 display_every = 100 nerf.F_c.train() for i in range(num_iters): target_img_idx = np.random.randint(images.shape[0]) target_pose = poses[target_img_idx].to(device) R = target_pose[:3, :3] ds = torch.einsum("ij,hwj->hwi", R, init_ds) os = (R @ init_o).expand(ds.shape) C_rs_c = nerf(ds, os) loss = criterion(C_rs_c, images[target_img_idx].to(device)) optimizer.zero_grad() loss.backward() optimizer.step() if i % display_every == 0: nerf.F_c.eval() with torch.no_grad(): C_rs_c = nerf(test_ds, test_os) loss = criterion(C_rs_c, test_img) print(f"Loss: {loss.item()}") psnr = -10.0 * torch.log10(loss) psnrs.append(psnr.item()) iternums.append(i) plt.figure(figsize=(10, 4)) plt.subplot(121) plt.imshow(C_rs_c.detach().cpu().numpy()) plt.title(f"Iteration {i}") plt.subplot(122) plt.plot(iternums, psnrs) plt.title("PSNR") plt.show() nerf.F_c.train() print("Done!") if __name__ == "__main__": main()