Repository: airalcorn2/pytorch-nerf
Branch: master
Commit: c9a055558e57
Files: 16
Total size: 23.1 MB
Directory structure:
gitextract_esqdzl8y/
├── .gitignore
├── 66bdbc812bd0a196e194052f3f12cb2e.npz
├── LICENSE
├── README.md
├── generate_nerf_dataset.py
├── generate_pixelnerf_dataset.py
├── image_encoder.py
├── pixelnerf_dataset.py
├── renderer.py
├── renderer_settings.py
├── run_nerf.py
├── run_nerf_alt.py
├── run_pixelnerf.py
├── run_pixelnerf_alt.py
├── run_tiny_nerf.py
└── run_tiny_nerf_alt.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
.idea
__pycache__
data.zip
data
================================================
FILE: 66bdbc812bd0a196e194052f3f12cb2e.npz
================================================
[File too large to display: 23.0 MB]
================================================
FILE: LICENSE
================================================
Copyright 2022 Michael A. Alcorn
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
================================================
FILE: README.md
================================================
# PyTorch NeRF and pixelNeRF
**NeRF**: [Open in Colab](https://colab.research.google.com/drive/1oRnnlF-2YqCDIzoc-uShQm8_yymLKiqr)
**Tiny NeRF**: [Open in Colab](https://colab.research.google.com/drive/1ntlbzQ121-E1BSa5EKvAyai6SMG4cylj)
**pixelNeRF**: [Open in Colab](https://colab.research.google.com/drive/1VEEy4VOVoQTQKo4oG3nWcfKAXjC_0fFt)
This repository contains minimal PyTorch implementations of the NeRF model described in "[NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis](https://arxiv.org/abs/2003.08934)" and the pixelNeRF model described in ["pixelNeRF: Neural Radiance Fields from One or Few Images"](https://arxiv.org/abs/2012.02190).
While there are other PyTorch implementations out there (e.g., [this one](https://github.com/krrish94/nerf-pytorch) and [this one](https://github.com/yenchenlin/nerf-pytorch) for NeRF, and [the authors' official implementation](https://github.com/sxyu/pixel-nerf) for pixelNeRF), I personally found them somewhat difficult to follow, so I decided to do a complete rewrite of NeRF myself.
I tried to stay as close to the authors' text as possible, and I added comments in the code referring back to the relevant sections/equations in the paper.
The final result is a tight 355 lines of heavily commented code (301 sloc—"source lines of code"—on GitHub) all contained in [a single file](run_nerf.py). For comparison, [this PyTorch implementation](https://github.com/krrish94/nerf-pytorch) has approximately 970 sloc spread across several files, while [this PyTorch implementation](https://github.com/yenchenlin/nerf-pytorch) has approximately 905 sloc.
[`run_tiny_nerf.py`](run_tiny_nerf.py) trains a simplified NeRF model inspired by the "[Tiny NeRF](https://colab.research.google.com/github/bmild/nerf/blob/master/tiny_nerf.ipynb)" example provided by the NeRF authors.
This NeRF model does not use fine sampling and the MLP is smaller, but the code is otherwise identical to the full model code.
At only 153 sloc, it might be a good place to start for people who are completely new to NeRF.
If you prefer your code more object-oriented, check out [`run_nerf_alt.py`](run_nerf_alt.py) and [`run_tiny_nerf_alt.py`](run_tiny_nerf_alt.py).
A Colab notebook for the full model can be found [here](https://colab.research.google.com/drive/1oRnnlF-2YqCDIzoc-uShQm8_yymLKiqr?usp=sharing), while a notebook for the tiny model can be found [here](https://colab.research.google.com/drive/1ntlbzQ121-E1BSa5EKvAyai6SMG4cylj?usp=sharing).
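To make the difference concrete, the snippet below shows the hierarchical ("fine") sampling step that the tiny model omits, using the function and variable names from [`run_nerf.py`](run_nerf.py) (a sketch of `run_one_iter_of_nerf`, not standalone code):
```python
# Coarse pass (the tiny model stops here).
(r_ts_c, t_is_c) = get_coarse_query_points(ds, N_c, t_i_c_bin_edges, t_i_c_gap, os)
(C_rs_c, w_is_c) = render_radiance_volume(r_ts_c, ds, chunk_size, F_c, t_is_c)
# Fine pass (full model only): resample depths where the coarse weights are large.
(r_ts_f, t_is_f) = get_fine_query_points(w_is_c, N_f, t_is_c, t_f, os, ds)
(C_rs_f, _) = render_radiance_volume(r_ts_f, ds, chunk_size, F_f, t_is_f)
```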
The [`generate_nerf_dataset.py`](generate_nerf_dataset.py) script was used to generate the training data of the ShapeNet car (see "[Generating the ShapeNet datasets](#generating-the-shapenet-datasets)" for additional details).
For the following test view:

[`run_nerf.py`](run_nerf.py) generated the following after 20,100 iterations (a few hours on a P100 GPU):
**Loss**: 0.00022201683896128088

while [`run_tiny_nerf.py`](run_tiny_nerf.py) generated the following after 19,600 iterations (~35 minutes on a P100 GPU):
**Loss**: 0.0004151524917688221

The advantages of streamlining NeRF's code become readily apparent when trying to extend NeRF.
For example, [training a pixelNeRF model](run_pixelnerf.py) only required making a few changes to [`run_nerf.py`](run_nerf.py), bringing it to 368 sloc (notebook [here](https://colab.research.google.com/drive/1VEEy4VOVoQTQKo4oG3nWcfKAXjC_0fFt?usp=sharing)).
For comparison, [the official pixelNeRF implementation](https://github.com/sxyu/pixel-nerf) has approximately 1,300 pixelNeRF-specific (i.e., not related to the image encoder or dataset) sloc spread across several files.
The [`generate_pixelnerf_dataset.py`](generate_pixelnerf_dataset.py) script was used to generate the training data of ShapeNet cars (see "[Generating the ShapeNet datasets](#generating-the-shapenet-datasets)" for additional details).
For the following source object and view:

and target view:

[`run_pixelnerf.py`](run_pixelnerf.py) generated the following after 73,243 iterations (~12 hours on a P100 GPU; the full pixelNeRF model was trained for 400,000 iterations, which took six days):
**Loss**: 0.004468636587262154

The "smearing" is an artifact caused by the bounding box sampling method.
## Generating the ShapeNet datasets
1) Download the data (the ShapeNet server is pretty slow, so this will take a while):
```bash
SHAPENET_BASE_DIR=<path/to/your/shapenet/root>
nohup wget --quiet -P ${SHAPENET_BASE_DIR} http://shapenet.cs.stanford.edu/shapenet/obj-zip/ShapeNetCore.v2.zip > shapenet.log &
```
2) Unzip the data:
```bash
cd ${SHAPENET_BASE_DIR}
nohup unzip -q ShapeNetCore.v2.zip > shapenet.log &
```
3) ***After*** the file is done unzipping, remove the ZIP:
```bash
rm ShapeNetCore.v2.zip
```
4) Change the `SHAPENET_DIR` variable in [`generate_nerf_dataset.py`](generate_nerf_dataset.py) and [`generate_pixelnerf_dataset.py`](generate_pixelnerf_dataset.py) to `<path/to/your/shapenet/root>/ShapeNetCore.v2`.
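5) Generate the datasets and train (a sketch assuming the default script settings; each script writes its outputs to the working directory):
```bash
python generate_nerf_dataset.py       # Writes 66bdbc812bd0a196e194052f3f12cb2e.npz.
python generate_pixelnerf_dataset.py  # Writes the data/ directory.
python run_nerf.py                    # Or run_tiny_nerf.py/run_pixelnerf.py.
```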
================================================
FILE: generate_nerf_dataset.py
================================================
import numpy as np
from pyrr import Matrix44
from renderer import gen_rotation_matrix_from_cam_pos, Renderer
from renderer_settings import *
SHAPENET_DIR = "/run/media/airalcorn2/MiQ BIG/ShapeNetCore.v2"
def main():
# Set up the renderer.
renderer = Renderer(
camera_distance=CAMERA_DISTANCE,
angle_of_view=ANGLE_OF_VIEW,
dir_light=DIR_LIGHT,
dif_int=DIF_INT,
amb_int=AMB_INT,
default_width=WINDOW_SIZE,
default_height=WINDOW_SIZE,
cull_faces=CULL_FACES,
)
img_size = 100
# Calculate focal length in pixel units. This is just geometry. See:
# https://en.wikipedia.org/wiki/Angle_of_view#Derivation_of_the_angle-of-view_formula.
focal = (img_size / 2) / np.tan(np.radians(ANGLE_OF_VIEW) / 2)
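# For example, with img_size = 100 and ANGLE_OF_VIEW ≈ 53.96 degrees (see
# renderer_settings.py), the focal length above is roughly 98.2 pixels.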
# Load the ShapeNet car object.
obj = "66bdbc812bd0a196e194052f3f12cb2e"
cat = "02958343"
obj_mtl_path = f"{SHAPENET_DIR}/{cat}/{obj}/models/model_normalized"
renderer.set_up_obj(f"{obj_mtl_path}.obj", f"{obj_mtl_path}.mtl")
# Generate car renders using random camera locations.
init_cam_pos = np.array([0, 0, CAMERA_DISTANCE])
target = np.zeros(3)
up = np.array([0.0, 1.0, 0.0])
samps = 800
imgs = []
poses = []
for idx in range(samps):
# See: https://stats.stackexchange.com/a/7984/81836.
xyz = np.random.normal(size=3)
xyz /= np.linalg.norm(xyz)
R = gen_rotation_matrix_from_cam_pos(xyz)
eye = tuple((R @ init_cam_pos).flatten())
look_at = Matrix44.look_at(eye, target, up)
renderer.prog["VP"].write(
(look_at @ renderer.perspective).astype("f4").tobytes()
)
renderer.prog["cam_pos"].value = eye
image = renderer.render(0.5, 0.5, 0.5).resize((img_size, img_size))
imgs.append(np.array(image))
pose = np.eye(4)
pose[:3, :3] = np.array(look_at[:3, :3])
pose[:3, 3] = -look_at[:3, :3] @ look_at[3, :3]
poses.append(pose)
imgs = np.stack(imgs)
poses = np.stack(poses)
np.savez(
f"{obj}.npz",
images=imgs,
poses=poses,
focal=focal,
camera_distance=CAMERA_DISTANCE,
)
if __name__ == "__main__":
main()
================================================
FILE: generate_pixelnerf_dataset.py
================================================
import numpy as np
import os
import sys
from pyrr import Matrix44
from renderer import gen_rotation_matrix_from_cam_pos, Renderer
from renderer_settings import *
SHAPENET_DIR = "/run/media/airalcorn2/MiQ BIG/ShapeNetCore.v2"
def main():
# Set up the renderer.
renderer = Renderer(
camera_distance=CAMERA_DISTANCE,
angle_of_view=ANGLE_OF_VIEW,
dir_light=DIR_LIGHT,
dif_int=DIF_INT,
amb_int=AMB_INT,
default_width=WINDOW_SIZE,
default_height=WINDOW_SIZE,
cull_faces=CULL_FACES,
)
# See Section 5.1.1.
img_size = 128
# Calculate focal length in pixel units. This is just geometry. See:
# https://en.wikipedia.org/wiki/Angle_of_view#Derivation_of_the_angle-of-view_formula.
focal = (img_size / 2) / np.tan(np.radians(ANGLE_OF_VIEW) / 2)
# Generate car renders using random camera locations.
init_cam_pos = np.array([0, 0, CAMERA_DISTANCE])
target = np.zeros(3)
up = np.array([0.0, 1.0, 0.0])
# See Section 5.1.1.
samps = 50
z_len = len(str(samps - 1))
data_dir = "data"
poses = []
os.mkdir(data_dir)
# Car category.
cat = "02958343"
objs = os.listdir(f"{SHAPENET_DIR}/{cat}")
used_objs = []
for obj in objs:
# Load the ShapeNet object.
obj_mtl_path = f"{SHAPENET_DIR}/{cat}/{obj}/models/model_normalized"
try:
renderer.set_up_obj(f"{obj_mtl_path}.obj", f"{obj_mtl_path}.mtl")
sys.stderr.flush()
except OSError:
print(f"{SHAPENET_DIR}/{cat}/{obj} is empty.", flush=True)
continue
except FloatingPointError:
print(f"{SHAPENET_DIR}/{cat}/{obj} divides by zero.", flush=True)
continue
obj_dir = f"{data_dir}/{obj}"
os.mkdir(obj_dir)
obj_poses = []
for samp_idx in range(samps):
# See: https://stats.stackexchange.com/a/7984/81836.
xyz = np.random.normal(size=3)
xyz /= np.linalg.norm(xyz)
R = gen_rotation_matrix_from_cam_pos(xyz)
eye = tuple((R @ init_cam_pos).flatten())
look_at = Matrix44.look_at(eye, target, up)
renderer.prog["VP"].write(
(look_at @ renderer.perspective).astype("f4").tobytes()
)
renderer.prog["cam_pos"].value = eye
image = renderer.render(0.5, 0.5, 0.5).resize((img_size, img_size))
np.save(f"{obj_dir}/{str(samp_idx).zfill(z_len)}.npy", np.array(image))
pose = np.eye(4)
pose[:3, :3] = np.array(look_at[:3, :3])
pose[:3, 3] = -look_at[:3, :3] @ look_at[3, :3]
obj_poses.append(pose)
obj_poses = np.stack(obj_poses)
poses.append(obj_poses)
renderer.release_obj()
used_objs.append(obj)
poses = np.stack(poses)
np.savez(
f"{data_dir}/poses.npz",
poses=poses,
focal=focal,
camera_distance=CAMERA_DISTANCE,
)
with open(f"{data_dir}/objs.txt", "w") as f:
print("\n".join(used_objs), file=f)
if __name__ == "__main__":
main()
================================================
FILE: image_encoder.py
================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet34
class ImageEncoder(nn.Module):
def __init__(self):
super().__init__()
# Use an ImageNet-pretrained ResNet34 as the backbone. See Section 4.1.
self.resnet = resnet34(True)
def forward(self, x):
# Extract feature pyramid from image. See Section 4.1., Section B.1 in the
# Supplementary Materials, and: https://github.com/sxyu/pixel-nerf/blob/master/src/model/encoder.py.
x = self.resnet.conv1(x)
x = self.resnet.bn1(x)
feats1 = self.resnet.relu(x)
feats2 = self.resnet.layer1(self.resnet.maxpool(feats1))
feats3 = self.resnet.layer2(feats2)
feats4 = self.resnet.layer3(feats3)
latents = [feats1, feats2, feats3, feats4]
latent_sz = latents[0].shape[-2:]
for i in range(len(latents)):
latents[i] = F.interpolate(
latents[i], latent_sz, mode="bilinear", align_corners=True
)
latents = torch.cat(latents, dim=1)
return latents
================================================
FILE: pixelnerf_dataset.py
================================================
import numpy as np
import torch
from torch.utils.data import Dataset
class PixelNeRFDataset(Dataset):
def __init__(
self,
data_dir,
num_iters,
test_obj_idx,
test_source_pose_idx,
test_target_pose_idx,
):
self.data_dir = data_dir
self.N = num_iters
with open(f"{data_dir}/objs.txt") as f:
self.objs = f.read().split("\n")[:-1]
self.test_obj_idx = test_obj_idx
self.test_source_pose_idx = test_source_pose_idx
self.test_target_pose_idx = test_target_pose_idx
data = np.load(f"{data_dir}/poses.npz")
self.poses = poses = data["poses"]
(n_objs, n_poses) = poses.shape[:2]
self.z_len = len(str(n_poses - 1))
self.poses = torch.Tensor(poses)
self.channel_means = torch.Tensor([0.485, 0.456, 0.406])
self.channel_stds = torch.Tensor([0.229, 0.224, 0.225])
samp_img = np.load(f"{data_dir}/{self.objs[0]}/{str(0).zfill(self.z_len)}.npy")
img_size = samp_img.shape[0]
self.pix_idxs = np.arange(img_size ** 2)
xs = torch.arange(img_size) - (img_size / 2 - 0.5)
ys = torch.arange(img_size) - (img_size / 2 - 0.5)
(xs, ys) = torch.meshgrid(xs, -ys, indexing="xy")
focal = float(data["focal"])
pixel_coords = torch.stack([xs, ys, torch.full_like(xs, -focal)], dim=-1)
camera_coords = pixel_coords / focal
self.init_ds = camera_coords
self.camera_distance = camera_distance = float(data["camera_distance"])
self.init_o = torch.Tensor(np.array([0, 0, camera_distance]))
# tan(theta) = opposite / adjacent.
self.scale = (img_size / 2) / focal
def __len__(self):
return self.N
def __getitem__(self, idx):
obj_idx = np.random.randint(self.poses.shape[0])
obj = self.objs[obj_idx]
obj_dir = f"{self.data_dir}/{obj}"
source_pose_idx = np.random.randint(self.poses.shape[1])
if obj_idx == self.test_obj_idx:
while source_pose_idx == self.test_source_pose_idx:
source_pose_idx = np.random.randint(self.poses.shape[1])
source_img_f = f"{obj_dir}/{str(source_pose_idx).zfill(self.z_len)}.npy"
source_image = torch.Tensor(np.load(source_img_f) / 255)
source_image = (source_image - self.channel_means) / self.channel_stds
source_pose = self.poses[obj_idx, source_pose_idx]
source_R = source_pose[:3, :3]
target_pose_idx = np.random.randint(self.poses.shape[1])
if obj_idx == self.test_obj_idx:
while (target_pose_idx == self.test_source_pose_idx) or (
target_pose_idx == self.test_target_pose_idx
):
target_pose_idx = np.random.randint(self.poses.shape[1])
target_img_f = f"{obj_dir}/{str(target_pose_idx).zfill(self.z_len)}.npy"
target_image = np.load(target_img_f)
not_gray_pix = np.argwhere((target_image == 128).sum(-1) != 3)
top_row = not_gray_pix[:, 0].min()
bottom_row = not_gray_pix[:, 0].max()
left_col = not_gray_pix[:, 1].min()
right_col = not_gray_pix[:, 1].max()
bbox = (top_row, left_col, bottom_row, right_col)
target_image = np.load(target_img_f) / 255
target_pose = self.poses[obj_idx, target_pose_idx]
target_R = target_pose[:3, :3]
R = source_R.T @ target_R
return (source_image, torch.Tensor(R), torch.Tensor(target_image), bbox)
================================================
FILE: renderer.py
================================================
import logging
import moderngl
import numpy as np
from PIL import Image, ImageOps
from pyrr import Matrix44
from scipy.spatial.transform import Rotation
YAW_PITCH_ROLL = {"yaw", "pitch", "roll"}
AZIM_ELEV_IN_PLANE = {"azimuth", "elevation", "in_plane"}
TOL = 1e-6
def gen_rotation_matrix_from_cam_pos(xyz, in_plane=0.0):
# The camera position should be (approximately) a unit vector.
assert abs(1 - np.linalg.norm(xyz)) < TOL
cam_from = xyz
cam_to = np.zeros(3)
tmp = np.array([0.0, 1.0, 0.0])
diff = cam_from - cam_to
forward = diff / np.linalg.norm(diff)
crossed = np.cross(tmp, forward)
right = crossed / np.linalg.norm(crossed)
up = np.cross(forward, right)
R = np.stack([right, up, forward])
R_in_plane = Rotation.from_euler("Z", in_plane).as_matrix()
return R_in_plane @ R
def gen_rotation_matrix_from_azim_elev_in_plane(
azimuth=0.0, elevation=0.0, in_plane=0.0
):
# See: https://www.scratchapixel.com/lessons/mathematics-physics-for-computer-graphics/lookat-function.
y = np.sin(elevation)
radius = np.cos(elevation)
x = radius * np.sin(azimuth)
z = radius * np.cos(azimuth)
cam_from = np.array([x, y, z])
cam_to = np.zeros(3)
tmp = np.array([0.0, 1.0, 0.0])
diff = cam_from - cam_to
forward = diff / np.linalg.norm(diff)
crossed = np.cross(tmp, forward)
right = crossed / np.linalg.norm(crossed)
up = np.cross(forward, right)
R = np.stack([right, up, forward])
R_in_plane = Rotation.from_euler("Z", in_plane).as_matrix()
return R_in_plane @ R
def parse_obj_file(input_obj):
"""Parse wavefront .obj file.
:param input_obj:
:return: Dictionary of NumPy arrays with shape (3 * num_faces, 8). Each row contains
(1) the coordinates of a vertex of a face, (2) the vertex's normal vector, and (3)
the texture coordinates for the vertex.
"""
data = {"v": [], "vn": [], "vt": []}
packed_arrays = {}
obj_f = open(input_obj)
current_mtl = None
min_vec = np.full(3, np.inf)
max_vec = np.full(3, -np.inf)
empty_vt = np.array([0.0, 0.0, 0.0])
for line in obj_f:
line = line.strip()
if line == "":
continue
parts = line.split()
elem_type = parts[0]
if elem_type in data:
vals = np.array(parts[1:4], dtype=np.float32)
if elem_type == "v":
min_vec = np.minimum(min_vec, vals)
max_vec = np.maximum(max_vec, vals)
elif elem_type == "vn":
vals /= np.linalg.norm(vals)
elif elem_type == "vt":
if len(vals) < 3:
vals = np.array(list(vals) + [0.0], dtype=np.float32)
data[elem_type].append(vals)
elif elem_type == "f":
f = parts[1:4]
for fv in f:
(v, vt, vn) = fv.split("/")
# Convert to zero-based indexing.
v = int(v) - 1
vn = int(vn) - 1
vt = int(vt) - 1 if vt else -1
if vt == -1:
row = np.concatenate((data["v"][v], data["vn"][vn], empty_vt))
else:
row = np.concatenate((data["v"][v], data["vn"][vn], data["vt"][vt]))
packed_arrays[current_mtl].append(row)
elif elem_type == "usemtl":
current_mtl = parts[1]
if current_mtl not in packed_arrays:
packed_arrays[current_mtl] = []
elif elem_type == "l":
if current_mtl in packed_arrays:
packed_arrays.pop(current_mtl)
max_pos_vec = max_vec - min_vec
max_pos_val = max(max_pos_vec)
max_pos_vec_norm = max_pos_vec / max_pos_val
for (sub_obj, packed_array) in packed_arrays.items():
# z-coordinate of texture is always zero (if present).
packed_array = np.stack(packed_array)[:, :8]
original_vertices = packed_array[:, :3].copy()
# All coordinates greater than or equal to zero.
original_vertices -= min_vec
# All coordinates between zero and one.
original_vertices /= max_pos_val
# All coordinates between zero and two.
original_vertices *= 2
# All coordinates between negative one and positive one with the center of object
# at (0, 0, 0).
original_vertices -= max_pos_vec_norm
packed_array[:, :3] = original_vertices
packed_arrays[sub_obj] = packed_array
all_vertices = np.stack(data["v"])
all_vertices -= min_vec
all_vertices /= max_pos_val
all_vertices *= 2
all_vertices -= max_pos_vec_norm
return (packed_arrays, all_vertices)
def parse_mtl_file(input_mtl):
vector_elems = {"Ka", "Kd", "Ks"}
float_elems = {"Ns", "Ni", "d"}
int_elems = {"illum"}
current_mtl = None
mtl_infos = {}
mtl_f = open(input_mtl)
sub_objs = []
for line in mtl_f:
line = line.strip()
if line == "":
continue
parts = line.split()
elem_type = parts[0]
if elem_type in vector_elems:
vals = np.array(parts[1:4], dtype=np.float32)
mtl_infos[current_mtl][elem_type] = tuple(vals)
elif elem_type in float_elems:
mtl_infos[current_mtl][elem_type] = float(parts[1])
elif elem_type in int_elems:
mtl_infos[current_mtl][elem_type] = int(parts[1])
elif elem_type == "newmtl":
current_mtl = parts[1]
sub_objs.append(current_mtl)
mtl_infos[current_mtl] = {"d": 1.0}
elif elem_type == "map_Kd":
mtl_infos[current_mtl]["map_Kd"] = parts[1]
sub_objs.sort()
sub_objs.reverse()
non_trans = [sub_obj for sub_obj in sub_objs if mtl_infos[sub_obj]["d"] == 1.0]
trans = [
(sub_obj, mtl_infos[sub_obj]["d"])
for sub_obj in sub_objs
if mtl_infos[sub_obj]["d"] < 1.0
]
trans.sort(key=lambda x: x[1], reverse=True)
sub_objs = non_trans + [sub_obj for (sub_obj, d) in trans]
return (mtl_infos, sub_objs)
def get_texture_data(sub_objs, packed_arrays, mtl_infos, obj_f):
texture_data = {}
texture_path_list = obj_f.split("/")
img_str_len = len("images/")
for sub_obj in sub_objs:
if sub_obj not in packed_arrays:
continue
if "map_Kd" in mtl_infos[sub_obj]:
texture_f = mtl_infos[sub_obj]["map_Kd"]
img_str_idx = texture_f.find("images/")
if img_str_idx != -1:
texture_path = "/".join(texture_path_list[:-2] + ["images"])
texture_f = texture_f[img_str_idx + img_str_len :]
else:
texture_path = "/".join(texture_path_list[:-1])
try:
texture_img = (
Image.open(texture_path + "/" + texture_f)
.transpose(Image.FLIP_TOP_BOTTOM)
.convert("RGBA")
)
except FileNotFoundError:
texture_f_parts = texture_f.split(".")
ext = texture_f_parts[-1]
if ext.isupper():
texture_f_parts[-1] = ext.lower()
elif ext.islower():
texture_f_parts[-1] = ext.upper()
texture_f = ".".join(texture_f_parts)
texture_img = (
Image.open(texture_path + "/" + texture_f)
.transpose(Image.FLIP_TOP_BOTTOM)
.convert("RGBA")
)
texture_data[sub_obj] = {
"size": texture_img.size,
"bytes": texture_img.tobytes(),
}
return texture_data
class Renderer:
def __init__(
self,
background_f=None,
camera_distance=2.0,
angle_of_view=16.426,
dir_light=(0, 1 / np.sqrt(2), np.sqrt(2)),
dif_int=0.7,
amb_int=0.7,
default_width=128,
default_height=128,
cull_faces=True,
):
# Initialize OpenGL context.
self.ctx = moderngl.create_standalone_context()
# Render depth appropriately.
self.ctx.enable(moderngl.DEPTH_TEST)
# Setting for rendering transparent objects.
# See: https://learnopengl.com/Advanced-OpenGL/Blending
# and: https://github.com/cprogrammer1994/ModernGL/blob/master/moderngl/context.py#L129.
self.ctx.enable(moderngl.BLEND)
# Define OpenGL program.
prog = self.ctx.program(
vertex_shader="""
#version 330
uniform float x;
uniform float y;
uniform float z;
uniform mat3 R_obj;
uniform mat3 R_light;
uniform vec3 DirLight;
uniform mat4 VP;
uniform int mode;
in vec3 in_vert;
in vec3 in_norm;
in vec2 in_text;
out vec3 v_pos;
out vec3 v_norm;
out vec2 v_text;
out vec3 v_light;
void main() {
if (mode == 0) {
v_pos = R_obj * in_vert + vec3(x, y, z);
gl_Position = VP * vec4(v_pos, 1.0);
v_norm = R_obj * in_norm;
v_text = in_text;
v_light = R_light * DirLight;
} else {
gl_Position = vec4(in_vert, 1.0);
v_text = in_text;
}
}
""",
fragment_shader="""
#version 330
uniform float amb_int;
uniform float dif_int;
uniform vec3 cam_pos;
uniform sampler2D Texture;
uniform int mode;
uniform bool use_texture;
uniform bool has_image;
uniform vec3 box_rgb;
uniform vec3 amb_rgb;
uniform vec3 dif_rgb;
uniform vec3 spc_rgb;
uniform float spec_exp;
uniform float trans;
in vec3 v_pos;
in vec3 v_norm;
in vec2 v_text;
in vec3 v_light;
out vec4 f_color;
void main() {
if (mode == 0) {
float dif = clamp(dot(v_light, v_norm), 0.0, 1.0) * dif_int;
if (use_texture) {
vec3 surface_rgb = dif_rgb;
vec3 diffuse = dif * surface_rgb;
if (has_image) {
surface_rgb = texture(Texture, v_text).rgb;
diffuse = dif * dif_rgb * surface_rgb;
}
vec3 ambient = amb_int * amb_rgb * surface_rgb;
float spec = 0.0;
if (dif > 0.0) {
vec3 reflected = reflect(-v_light, v_norm);
vec3 surface_to_camera = normalize(cam_pos - v_pos);
spec = pow(clamp(dot(surface_to_camera, reflected), 0.0, 1.0), spec_exp);
}
vec3 specular = spec * spc_rgb * surface_rgb;
vec3 linear = ambient + diffuse + specular;
f_color = vec4(linear, trans);
} else {
f_color = vec4(vec3(1.0, 1.0, 1.0) * dif + amb_int, 1.0);
}
} else if (mode == 1) {
f_color = vec4(texture(Texture, v_text).rgba);
} else {
f_color = vec4(box_rgb, 1.0);
}
}
""",
)
# Lighting uniform variables.
prog["R_light"].write(np.eye(3).astype("f4").tobytes())
dir_light = np.array(dir_light)
prog["DirLight"].value = tuple(dir_light / np.linalg.norm(dir_light))
prog["dif_int"].value = dif_int
prog["amb_int"].value = amb_int
prog["amb_rgb"].value = (1.0, 1.0, 1.0)
prog["dif_rgb"].value = (1.0, 1.0, 1.0)
prog["spc_rgb"].value = (1.0, 1.0, 1.0)
prog["spec_exp"].value = 0.0
self.use_spec = True
# Mode uniform variables.
prog["mode"].value = 0
prog["use_texture"].value = True
prog["has_image"].value = False
# Model transformation uniform variables.
prog["R_obj"].write(np.eye(3).astype("f4").tobytes())
prog["x"].value = 0
prog["y"].value = 0
prog["z"].value = 0
# Set up background.
self.prog = prog
(self.default_width, self.default_height) = (default_width, default_height)
self.background = None
(window_width, window_height) = self.set_up_background(background_f)
# Look at origin matrix.
eye = np.array([0.0, 0.0, camera_distance])
prog["cam_pos"].value = tuple(eye)
target = np.zeros(3)
up = np.array([0.0, 1.0, 0.0])
self.look_at = Matrix44.look_at(eye, target, up)
# Perspective projection matrix.
self.ratio = window_width / window_height
self.angle_of_view = angle_of_view
self.perspective = Matrix44.perspective_projection(
angle_of_view, self.ratio, 0.1, 1000.0
)
# View-Projection uniform variable.
self.prog["VP"].write((self.look_at @ self.perspective).astype("f4").tobytes())
# Set up object.
self.mtl_infos = None
self.cull_faces = cull_faces
self.render_objs = []
self.vbos = {}
self.vaos = {}
self.textures = {}
# Initialize frame buffer.
size = (window_width, window_height)
self.window_size = size
# Set up multisample anti-aliasing.
self.ctx.multisample = True
color_rbo = self.ctx.renderbuffer(size, samples=self.ctx.max_samples)
depth_rbo = self.ctx.depth_renderbuffer(size, samples=self.ctx.max_samples)
self.fbo = self.ctx.framebuffer(color_rbo, depth_rbo)
color_rbo2 = self.ctx.renderbuffer(size)
depth_rbo2 = self.ctx.depth_renderbuffer(size)
self.fbo2 = self.ctx.framebuffer(color_rbo2, depth_rbo2)
self.fbo.use()
def set_up_obj(self, obj_f, mtl_f):
(packed_arrays, vertices) = parse_obj_file(obj_f)
packed_arrays = {
sub_obj: packed_array.flatten().astype("f4").tobytes()
for (sub_obj, packed_array) in packed_arrays.items()
}
(mtl_infos, sub_objs) = parse_mtl_file(mtl_f)
texture_data = get_texture_data(sub_objs, packed_arrays, mtl_infos, obj_f)
self.load_obj(packed_arrays, vertices, mtl_infos, sub_objs, texture_data)
def load_obj(self, packed_arrays, vertices, mtl_infos, sub_objs, texture_data):
self.hom_vertices = np.hstack([vertices, np.ones(len(vertices))[:, None]])
render_objs = []
vbos = {}
vaos = {}
textures = {}
for sub_obj in sub_objs:
if sub_obj not in packed_arrays:
logging.info(f"Skipping {sub_obj}.")
continue
render_objs.append(sub_obj)
packed_array = packed_arrays[sub_obj]
vbo = self.ctx.buffer(packed_array)
vbos[sub_obj] = vbo
# Recall that "in_vert", "in_norm", and "in_text" are the inputs to the
# vertex shader.
vao = self.ctx.simple_vertex_array(
self.prog, vbo, "in_vert", "in_norm", "in_text"
)
vaos[sub_obj] = vao
if "map_Kd" in mtl_infos[sub_obj]:
# Initialize texture from image.
texture = self.ctx.texture(
texture_data[sub_obj]["size"], 4, texture_data[sub_obj]["bytes"]
)
texture.build_mipmaps()
textures[sub_obj] = texture
self.mtl_infos = mtl_infos
self.render_objs = render_objs
self.vbos = vbos
self.vaos = vaos
self.textures = textures
def set_up_background(self, background_f=None):
if background_f:
background_img = (
Image.open(background_f)
.transpose(Image.FLIP_TOP_BOTTOM)
.convert("RGBA")
)
# Initialize background from image.
background = self.ctx.texture(
background_img.size, 4, background_img.tobytes()
)
background.build_mipmaps()
self.background = background
# Create a square plane from two triangles (two sets of three points).
vertices = np.array(
[
[-1.0, -1.0, 0.0],
[-1.0, 1.0, 0.0],
[1.0, 1.0, 0.0],
[-1.0, -1.0, 0.0],
[1.0, -1.0, 0.0],
[1.0, 1.0, 0.0],
]
)
# These arrays are not used by the renderer, but the vertex shader expects
# them as input.
normals = np.repeat([[0.0, 0.0, 1.0]], len(vertices), axis=0)
# The texture (UV) coordinates corresponding to the above triangle points.
texture_coords = np.array(
[[0.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0], [1.0, 0.0], [1.0, 1.0]]
)
background_array = np.hstack((vertices, normals, texture_coords))
self.background_vbo = self.ctx.buffer(
background_array.flatten().astype("f4").tobytes()
)
self.background_vao = self.ctx.simple_vertex_array(
self.prog, self.background_vbo, "in_vert", "in_norm", "in_text"
)
return (background_img.width, background_img.height)
else:
return (self.default_width, self.default_height)
def render(self, r=0.485, g=0.456, b=0.406, with_alpha=False):
if self.background is not None:
# See: https://computergraphics.stackexchange.com/a/4007.
self.ctx.disable(moderngl.DEPTH_TEST)
self.prog["mode"].value = 1
self.background.use()
self.fbo.clear()
self.background_vao.render()
self.ctx.enable(moderngl.DEPTH_TEST)
self.prog["mode"].value = 0
else:
self.fbo.clear(r, g, b)
if self.cull_faces:
self.ctx.enable(moderngl.CULL_FACE)
for render_obj in self.render_objs:
if self.prog["use_texture"].value:
self.prog["amb_rgb"].value = self.mtl_infos[render_obj]["Ka"]
self.prog["dif_rgb"].value = self.mtl_infos[render_obj]["Kd"]
if self.use_spec:
self.prog["spc_rgb"].value = self.mtl_infos[render_obj]["Ks"]
self.prog["spec_exp"].value = self.mtl_infos[render_obj]["Ns"]
else:
self.prog["spc_rgb"].value = (0.0, 0.0, 0.0)
self.prog["trans"].value = self.mtl_infos[render_obj]["d"]
if render_obj in self.textures:
self.prog["has_image"].value = True
self.textures[render_obj].use()
self.vaos[render_obj].render()
self.prog["has_image"].value = False
self.ctx.disable(moderngl.CULL_FACE)
self.ctx.copy_framebuffer(self.fbo2, self.fbo)
if with_alpha:
return Image.frombytes(
"RGBA",
self.fbo.size,
self.fbo2.read(components=4),
"raw",
"RGBA",
0,
-1,
)
else:
return Image.frombytes(
"RGB", self.fbo.size, self.fbo2.read(), "raw", "RGB", 0, -1
)
def get_vertex_screen_coordinates(self):
world = np.eye(4)
world[:3, :3] = np.array(self.prog["R_obj"].value).reshape((3, 3)).T
world[:3, 3] = (
self.prog["x"].value,
self.prog["y"].value,
self.prog["z"].value,
)
PV = np.array(self.prog["VP"].value).reshape((4, 4)).T
pre_screen_coords = PV @ world @ self.hom_vertices.T
(window_width, window_height) = self.window_size
screen_xs = (
window_width
* (np.array(pre_screen_coords[0]) / np.array(pre_screen_coords[3]) + 1)
/ 2
)
screen_ys = (
window_height
* (np.array(pre_screen_coords[1]) / np.array(pre_screen_coords[3]) + 1)
/ 2
)
screen_coords = np.hstack((screen_xs, screen_ys))
screen = np.zeros((window_height, window_width))
for i in range(len(screen_xs)):
col = x = int(screen_xs[i])
row = y = int(screen_ys[i])
if x < window_width and y < window_height:
screen[window_height - row - 1, col] = 1
screen_mat = np.uint8(255 * screen)
screen_img = Image.fromarray(screen_mat, mode="L")
return (screen_coords, screen_img)
def __del__(self):
self.release()
def release_obj(self):
for sub_obj in self.vbos:
self.vbos[sub_obj].release()
self.vaos[sub_obj].release()
if sub_obj in self.textures:
self.textures[sub_obj].release()
self.vbos = {}
self.vaos = {}
self.textures = {}
def release_background(self):
if self.background is not None:
self.background.release()
self.background_vbo.release()
self.background_vao.release()
self.background = None
def release(self):
self.release_obj()
self.release_background()
self.fbo.release()
self.fbo2.release()
self.ctx.release()
def adjust_angle_of_view(self, angle_of_view):
self.angle_of_view = angle_of_view
perspective = Matrix44.perspective_projection(
self.angle_of_view, self.ratio, 0.1, 1000.0
)
self.prog["VP"].write((perspective * self.look_at).astype("f4").tobytes())
def set_params(self, params):
ypr_params = {}
ae_params = {}
for (param, value) in params.items():
if param in self.prog:
self.prog[param].value = value
elif param == "aov":
self.adjust_angle_of_view(value)
elif param in YAW_PITCH_ROLL:
ypr_params[param] = value
elif param in AZIM_ELEV_IN_PLANE:
ae_params[param] = value
if len(ypr_params) > 0:
yaw = ypr_params.get("yaw", 0)
pitch = ypr_params.get("pitch", 0)
roll = ypr_params.get("roll", 0)
R_obj = Rotation.from_euler("YXZ", [yaw, pitch, roll]).as_matrix()
self.prog["R_obj"].write(R_obj.T.astype("f4").tobytes())
elif len(ae_params) > 0:
R_obj = gen_rotation_matrix_from_azim_elev_in_plane(**ae_params)
self.prog["R_obj"].write(R_obj.T.astype("f4").tobytes())
def get_depth_arrays(self):
depth = np.frombuffer(
self.fbo2.read(attachment=-1, dtype="f4"), dtype=np.dtype("f4")
)
depth = 1 - depth.reshape(self.window_size)
min_pos = depth[depth > 0].min()
depth[depth > 0] = depth[depth > 0] - min_pos
depth_normed = depth / depth.max()
return (depth, depth_normed)
def get_depth_map(self):
(depth, depth_normed) = self.get_depth_arrays()
depth_map = np.uint8(255 * depth_normed)
return ImageOps.flip(Image.fromarray(depth_map, "L"))
def get_normal_map(self):
# See: https://stackoverflow.com/questions/5281261/generating-a-normal-map-from-a-height-map
# and: https://stackoverflow.com/questions/34644101/calculate-surface-normals-from-depth-image-using-neighboring-pixels-cross-produc
# and: https://en.wikipedia.org/wiki/Normal_mapping#How_it_works.
(depth, depth_normed) = self.get_depth_arrays()
depth_pad = np.pad(depth_normed, 1, "constant")
(dx, dy) = (1 / depth.shape[1], 1 / depth.shape[0])
dz_dx = (depth_pad[1:-1, 2:] - depth_pad[1:-1, :-2]) / (2 * dx)
dz_dy = (depth_pad[2:, 1:-1] - depth_pad[:-2, 1:-1]) / (2 * dy)
norms = np.stack([-dz_dx.flatten(), -dz_dy.flatten(), np.ones(dz_dx.size)])
magnitudes = np.linalg.norm(norms, axis=0)
norms /= magnitudes
norms = norms.T
norms[:, :2] = 255 * (norms[:, :2] + 1) / 2
norms[:, 2] = 127 * norms[:, 2] + 128
norms = np.uint8(norms).reshape((*depth.shape, 3))
return ImageOps.flip(Image.fromarray(norms))
================================================
FILE: renderer_settings.py
================================================
WINDOW_SIZE = 256
IMG_SIZE = 128
CULL_FACES = True
CAMERA_DISTANCE = 2.25
# See: https://en.wikipedia.org/wiki/Angle_of_view#Common_lens_angles_of_view.
ANGLE_OF_VIEW = 53.962828459664856
# Lighting.
DIR_LIGHT = (0, 1 / (2 ** 0.5), 2 ** 0.5)
DIF_INT = 0.7
AMB_INT = 0.7
================================================
FILE: run_nerf.py
================================================
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn, optim
def get_coarse_query_points(ds, N_c, t_i_c_bin_edges, t_i_c_gap, os):
# Sample depths (t_is_c). See Equation (2) in Section 4.
u_is_c = torch.rand(*list(ds.shape[:2]) + [N_c]).to(ds)
t_is_c = t_i_c_bin_edges + u_is_c * t_i_c_gap
# Calculate the points along the rays (r_ts_c) using the ray origins (os), sampled
# depths (t_is_c), and ray directions (ds). See Section 4: r(t) = o + t * d.
r_ts_c = os[..., None, :] + t_is_c[..., :, None] * ds[..., None, :]
return (r_ts_c, t_is_c)
def get_fine_query_points(w_is_c, N_f, t_is_c, t_f, os, ds):
# See text surrounding Equation (5) in Section 5.2 and:
# https://stephens999.github.io/fiveMinuteStats/inverse_transform_sampling.html#discrete_distributions.
# Define PDFs (pdfs) and CDFs (cdfs) from weights (w_is_c).
w_is_c = w_is_c + 1e-5
pdfs = w_is_c / torch.sum(w_is_c, dim=-1, keepdim=True)
cdfs = torch.cumsum(pdfs, dim=-1)
cdfs = torch.cat([torch.zeros_like(cdfs[..., :1]), cdfs[..., :-1]], dim=-1)
# Get uniform samples (us).
us = torch.rand(list(cdfs.shape[:-1]) + [N_f]).to(w_is_c)
# Use inverse transform sampling to sample the depths (t_is_f).
idxs = torch.searchsorted(cdfs, us, right=True)
t_i_f_bottom_edges = torch.gather(t_is_c, 2, idxs - 1)
idxs_capped = idxs.clone()
max_ind = cdfs.shape[-1]
idxs_capped[idxs_capped == max_ind] = max_ind - 1
t_i_f_top_edges = torch.gather(t_is_c, 2, idxs_capped)
t_i_f_top_edges[idxs == max_ind] = t_f
t_i_f_gaps = t_i_f_top_edges - t_i_f_bottom_edges
u_is_f = torch.rand_like(t_i_f_gaps).to(os)
t_is_f = t_i_f_bottom_edges + u_is_f * t_i_f_gaps
# Combine the coarse (t_is_c) and fine (t_is_f) depths and sort them.
(t_is_f, _) = torch.sort(torch.cat([t_is_c, t_is_f.detach()], dim=-1), dim=-1)
# Calculate the points along the rays (r_ts_f) using the ray origins (os), depths
# (t_is_f), and ray directions (ds). See Section 4: r(t) = o + t * d.
r_ts_f = os[..., None, :] + t_is_f[..., :, None] * ds[..., None, :]
return (r_ts_f, t_is_f)
def render_radiance_volume(r_ts, ds, chunk_size, F, t_is):
# Use the network (F) to predict colors (c_is) and volume densities (sigma_is) for
# 3D points along rays (r_ts) given the viewing directions (ds) of the rays. See
# Section 3 and Figure 7 in the Supplementary Materials.
r_ts_flat = r_ts.reshape((-1, 3))
ds_rep = ds.unsqueeze(2).repeat(1, 1, r_ts.shape[-2], 1)
ds_flat = ds_rep.reshape((-1, 3))
c_is = []
sigma_is = []
# The network processes batches of inputs to avoid running out of memory.
for chunk_start in range(0, r_ts_flat.shape[0], chunk_size):
r_ts_batch = r_ts_flat[chunk_start : chunk_start + chunk_size]
ds_batch = ds_flat[chunk_start : chunk_start + chunk_size]
preds = F(r_ts_batch, ds_batch)
c_is.append(preds["c_is"])
sigma_is.append(preds["sigma_is"])
c_is = torch.cat(c_is).reshape(r_ts.shape)
sigma_is = torch.cat(sigma_is).reshape(r_ts.shape[:-1])
# Calculate the distances (delta_is) between points along the rays. The differences
# in depths are scaled by the norms of the ray directions to get the final
# distances. See text following Equation (3) in Section 4.
delta_is = t_is[..., 1:] - t_is[..., :-1]
# "Infinity". Guarantees last alpha is always one.
one_e_10 = torch.Tensor([1e10]).expand(delta_is[..., :1].shape)
delta_is = torch.cat([delta_is, one_e_10.to(delta_is)], dim=-1)
delta_is = delta_is * ds.norm(dim=-1).unsqueeze(-1)
# Calculate the alphas (alpha_is) of the 3D points using the volume densities
# (sigma_is) and distances between points (delta_is). See text following Equation
# (3) in Section 4 and https://en.wikipedia.org/wiki/Alpha_compositing.
alpha_is = 1.0 - torch.exp(-sigma_is * delta_is)
# Calculate the accumulated transmittances (T_is) along the rays from the alphas
# (alpha_is). See Equation (3) in Section 4. T_i is "the probability that the ray
# travels from t_n to t_i without hitting any other particle".
T_is = torch.cumprod(1.0 - alpha_is + 1e-10, -1)
# Guarantees the ray makes it at least to the first step. See:
# https://github.com/bmild/nerf/blob/18b8aebda6700ed659cb27a0c348b737a5f6ab60/run_nerf.py#L142,
# which uses tf.math.cumprod(1.-alpha + 1e-10, axis=-1, exclusive=True).
T_is = torch.roll(T_is, 1, -1)
T_is[..., 0] = 1.0
# Calculate the weights (w_is) for the colors (c_is) along the rays using the
# transmittances (T_is) and alphas (alpha_is). See Equation (5) in Section 5.2:
# w_i = T_i * (1 - exp(-sigma_i * delta_i)).
w_is = T_is * alpha_is
# Calculate the pixel colors (C_rs) for the rays as weighted (w_is) sums of colors
# (c_is). See Equation (5) in Section 5.2: C_c_hat(r) = Σ w_i * c_i.
C_rs = (w_is[..., None] * c_is).sum(dim=-2)
return (C_rs, w_is)
def run_one_iter_of_nerf(
ds, N_c, t_i_c_bin_edges, t_i_c_gap, os, chunk_size, F_c, N_f, t_f, F_f
):
(r_ts_c, t_is_c) = get_coarse_query_points(ds, N_c, t_i_c_bin_edges, t_i_c_gap, os)
(C_rs_c, w_is_c) = render_radiance_volume(r_ts_c, ds, chunk_size, F_c, t_is_c)
(r_ts_f, t_is_f) = get_fine_query_points(w_is_c, N_f, t_is_c, t_f, os, ds)
(C_rs_f, _) = render_radiance_volume(r_ts_f, ds, chunk_size, F_f, t_is_f)
return (C_rs_c, C_rs_f)
class NeRFMLP(nn.Module):
def __init__(self):
super().__init__()
# Number of encoding functions for positions. See Section 5.1.
self.L_pos = 10
# Number of encoding functions for viewing directions. See Section 5.1.
self.L_dir = 4
pos_enc_feats = 3 + 3 * 2 * self.L_pos
dir_enc_feats = 3 + 3 * 2 * self.L_dir
in_feats = pos_enc_feats
net_width = 256
early_mlp_layers = 5
early_mlp = []
for layer_idx in range(early_mlp_layers):
early_mlp.append(nn.Linear(in_feats, net_width))
early_mlp.append(nn.ReLU())
in_feats = net_width
self.early_mlp = nn.Sequential(*early_mlp)
in_feats = pos_enc_feats + net_width
late_mlp_layers = 3
late_mlp = []
for layer_idx in range(late_mlp_layers):
late_mlp.append(nn.Linear(in_feats, net_width))
late_mlp.append(nn.ReLU())
in_feats = net_width
self.late_mlp = nn.Sequential(*late_mlp)
self.sigma_layer = nn.Linear(net_width, net_width + 1)
self.pre_final_layer = nn.Sequential(
nn.Linear(dir_enc_feats + net_width, net_width // 2), nn.ReLU()
)
self.final_layer = nn.Sequential(nn.Linear(net_width // 2, 3), nn.Sigmoid())
def forward(self, xs, ds):
# Encode the inputs. See Equation (4) in Section 5.1.
xs_encoded = [xs]
for l_pos in range(self.L_pos):
xs_encoded.append(torch.sin(2**l_pos * torch.pi * xs))
xs_encoded.append(torch.cos(2**l_pos * torch.pi * xs))
xs_encoded = torch.cat(xs_encoded, dim=-1)
ds = ds / ds.norm(p=2, dim=-1).unsqueeze(-1)
ds_encoded = [ds]
for l_dir in range(self.L_dir):
ds_encoded.append(torch.sin(2**l_dir * torch.pi * ds))
ds_encoded.append(torch.cos(2**l_dir * torch.pi * ds))
ds_encoded = torch.cat(ds_encoded, dim=-1)
# Use the network to predict colors (c_is) and volume densities (sigma_is) for
# 3D points (xs) along rays given the viewing directions (ds) of the rays. See
# Section 3 and Figure 7 in the Supplementary Materials.
outputs = self.early_mlp(xs_encoded)
outputs = self.late_mlp(torch.cat([xs_encoded, outputs], dim=-1))
outputs = self.sigma_layer(outputs)
sigma_is = torch.relu(outputs[:, 0])
outputs = self.pre_final_layer(torch.cat([ds_encoded, outputs[:, 1:]], dim=-1))
c_is = self.final_layer(outputs)
return {"c_is": c_is, "sigma_is": sigma_is}
def main():
# Set seed.
seed = 9458
torch.manual_seed(seed)
np.random.seed(seed)
# Initialize coarse and fine MLPs.
device = "cuda:0"
F_c = NeRFMLP().to(device)
F_f = NeRFMLP().to(device)
# Number of query points passed through the MLP at a time. See: https://github.com/bmild/nerf/blob/18b8aebda6700ed659cb27a0c348b737a5f6ab60/run_nerf.py#L488.
chunk_size = 1024 * 32
# Number of training rays per iteration. See Section 5.3.
batch_img_size = 64
n_batch_pix = batch_img_size**2
# Initialize optimizer. See Section 5.3.
lr = 5e-4
optimizer = optim.Adam(list(F_c.parameters()) + list(F_f.parameters()), lr=lr)
criterion = nn.MSELoss()
# The learning rate decays exponentially. See Section 5.3
# See: https://github.com/bmild/nerf/blob/18b8aebda6700ed659cb27a0c348b737a5f6ab60/run_nerf.py#L486.
lrate_decay = 250
decay_steps = lrate_decay * 1000
# See: https://github.com/bmild/nerf/blob/18b8aebda6700ed659cb27a0c348b737a5f6ab60/run_nerf.py#L707.
decay_rate = 0.1
# Load dataset.
data_f = "66bdbc812bd0a196e194052f3f12cb2e.npz"
data = np.load(data_f)
# Set up initial ray origin (init_o) and ray directions (init_ds). These are the
# same across samples, we just rotate them based on the orientation of the camera.
# See Section 4.
images = data["images"] / 255
img_size = images.shape[1]
xs = torch.arange(img_size) - (img_size / 2 - 0.5)
ys = torch.arange(img_size) - (img_size / 2 - 0.5)
(xs, ys) = torch.meshgrid(xs, -ys, indexing="xy")
focal = float(data["focal"])
pixel_coords = torch.stack([xs, ys, torch.full_like(xs, -focal)], dim=-1)
# We want the zs to be negative ones, so we divide everything by the focal length
# (which is in pixel units).
camera_coords = pixel_coords / focal
init_ds = camera_coords.to(device)
init_o = torch.Tensor(np.array([0, 0, float(data["camera_distance"])])).to(device)
# Set up test view.
test_idx = 150
plt.imshow(images[test_idx])
plt.show()
test_img = torch.Tensor(images[test_idx]).to(device)
poses = data["poses"]
test_R = torch.Tensor(poses[test_idx, :3, :3]).to(device)
test_ds = torch.einsum("ij,hwj->hwi", test_R, init_ds)
test_os = (test_R @ init_o).expand(test_ds.shape)
# Initialize volume rendering hyperparameters.
# Near bound. See Section 4.
t_n = 1.0
# Far bound. See Section 4.
t_f = 4.0
# Number of coarse samples along a ray. See Section 5.3.
N_c = 64
# Number of fine samples along a ray. See Section 5.3.
N_f = 128
# Bins used to sample depths along a ray. See Equation (2) in Section 4.
t_i_c_gap = (t_f - t_n) / N_c
t_i_c_bin_edges = (t_n + torch.arange(N_c) * t_i_c_gap).to(device)
# Start training model.
train_idxs = np.arange(len(images)) != test_idx
images = torch.Tensor(images[train_idxs])
poses = torch.Tensor(poses[train_idxs])
n_pix = img_size**2
pixel_ps = torch.full((n_pix,), 1 / n_pix).to(device)
psnrs = []
iternums = []
# See Section 5.3.
num_iters = 300000
display_every = 100
F_c.train()
F_f.train()
for i in range(num_iters):
# Sample image and associated pose.
target_img_idx = np.random.randint(images.shape[0])
target_pose = poses[target_img_idx].to(device)
R = target_pose[:3, :3]
# Get rotated ray origins (os) and ray directions (ds). See Section 4.
ds = torch.einsum("ij,hwj->hwi", R, init_ds)
os = (R @ init_o).expand(ds.shape)
# Sample a batch of rays.
pix_idxs = pixel_ps.multinomial(n_batch_pix, False)
pix_idx_rows = pix_idxs // img_size
pix_idx_cols = pix_idxs % img_size
ds_batch = ds[pix_idx_rows, pix_idx_cols].reshape(
batch_img_size, batch_img_size, -1
)
os_batch = os[pix_idx_rows, pix_idx_cols].reshape(
batch_img_size, batch_img_size, -1
)
# Run NeRF.
(C_rs_c, C_rs_f) = run_one_iter_of_nerf(
ds_batch,
N_c,
t_i_c_bin_edges,
t_i_c_gap,
os_batch,
chunk_size,
F_c,
N_f,
t_f,
F_f,
)
target_img = images[target_img_idx].to(device)
target_img_batch = target_img[pix_idx_rows, pix_idx_cols].reshape(C_rs_f.shape)
# Calculate the mean squared error for both the coarse and fine MLP models and
# update the weights. See Equation (6) in Section 5.3.
loss = criterion(C_rs_c, target_img_batch) + criterion(C_rs_f, target_img_batch)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# Exponentially decay learning rate. See Section 5.3 and:
# https://keras.io/api/optimizers/learning_rate_schedules/exponential_decay/.
for g in optimizer.param_groups:
g["lr"] = lr * decay_rate ** (i / decay_steps)
if i % display_every == 0:
F_c.eval()
F_f.eval()
with torch.no_grad():
(_, C_rs_f) = run_one_iter_of_nerf(
test_ds,
N_c,
t_i_c_bin_edges,
t_i_c_gap,
test_os,
chunk_size,
F_c,
N_f,
t_f,
F_f,
)
loss = criterion(C_rs_f, test_img)
print(f"Loss: {loss.item()}")
psnr = -10.0 * torch.log10(loss)
psnrs.append(psnr.item())
iternums.append(i)
plt.figure(figsize=(10, 4))
plt.subplot(121)
plt.imshow(C_rs_f.detach().cpu().numpy())
plt.title(f"Iteration {i}")
plt.subplot(122)
plt.plot(iternums, psnrs)
plt.title("PSNR")
plt.show()
F_c.train()
F_f.train()
print("Done!")
if __name__ == "__main__":
main()
================================================
FILE: run_nerf_alt.py
================================================
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn, optim
class NeRFMLP(nn.Module):
def __init__(self):
super().__init__()
# Number of encoding functions for positions. See Section 5.1.
self.L_pos = 10
# Number of encoding functions for viewing directions. See Section 5.1.
self.L_dir = 4
pos_enc_feats = 3 + 3 * 2 * self.L_pos
dir_enc_feats = 3 + 3 * 2 * self.L_dir
in_feats = pos_enc_feats
net_width = 256
early_mlp_layers = 5
early_mlp = []
for layer_idx in range(early_mlp_layers):
early_mlp.append(nn.Linear(in_feats, net_width))
early_mlp.append(nn.ReLU())
in_feats = net_width
self.early_mlp = nn.Sequential(*early_mlp)
in_feats = pos_enc_feats + net_width
late_mlp_layers = 3
late_mlp = []
for layer_idx in range(late_mlp_layers):
late_mlp.append(nn.Linear(in_feats, net_width))
late_mlp.append(nn.ReLU())
in_feats = net_width
self.late_mlp = nn.Sequential(*late_mlp)
self.sigma_layer = nn.Linear(net_width, net_width + 1)
self.pre_final_layer = nn.Sequential(
nn.Linear(dir_enc_feats + net_width, net_width // 2), nn.ReLU()
)
self.final_layer = nn.Sequential(nn.Linear(net_width // 2, 3), nn.Sigmoid())
def forward(self, xs, ds):
# Encode the inputs. See Equation (4) in Section 5.1.
xs_encoded = [xs]
for l_pos in range(self.L_pos):
xs_encoded.append(torch.sin(2 ** l_pos * torch.pi * xs))
xs_encoded.append(torch.cos(2 ** l_pos * torch.pi * xs))
xs_encoded = torch.cat(xs_encoded, dim=-1)
ds = ds / ds.norm(p=2, dim=-1).unsqueeze(-1)
ds_encoded = [ds]
for l_dir in range(self.L_dir):
ds_encoded.append(torch.sin(2 ** l_dir * torch.pi * ds))
ds_encoded.append(torch.cos(2 ** l_dir * torch.pi * ds))
ds_encoded = torch.cat(ds_encoded, dim=-1)
# Use the network to predict colors (c_is) and volume densities (sigma_is) for
# 3D points (xs) along rays given the viewing directions (ds) of the rays. See
# Section 3 and Figure 7 in the Supplementary Materials.
outputs = self.early_mlp(xs_encoded)
outputs = self.late_mlp(torch.cat([xs_encoded, outputs], dim=-1))
outputs = self.sigma_layer(outputs)
sigma_is = torch.relu(outputs[:, 0])
outputs = self.pre_final_layer(torch.cat([ds_encoded, outputs[:, 1:]], dim=-1))
c_is = self.final_layer(outputs)
return {"c_is": c_is, "sigma_is": sigma_is}
class NeRF:
def __init__(self, device):
# Initialize coarse and fine MLPs.
self.F_c = NeRFMLP().to(device)
self.F_f = NeRFMLP().to(device)
# Number of query points passed through the MLPs at a time. See:
# https://github.com/bmild/nerf/blob/18b8aebda6700ed659cb27a0c348b737a5f6ab60/run_nerf.py#L488.
self.chunk_size = 1024 * 32
# Initialize volume rendering hyperparameters.
# Near bound. See Section 4.
self.t_n = t_n = 1.0
# Far bound. See Section 4.
self.t_f = t_f = 4.0
# Number of coarse samples along a ray. See Section 5.3.
self.N_c = N_c = 64
# Number of fine samples along a ray. See Section 5.3.
self.N_f = 128
# Bins used to sample depths along a ray. See Equation (2) in Section 4.
self.t_i_c_gap = t_i_c_gap = (t_f - t_n) / N_c
self.t_i_c_bin_edges = (t_n + torch.arange(N_c) * t_i_c_gap).to(device)
def get_coarse_query_points(self, ds, os):
# Sample depths (t_is_c). See Equation (2) in Section 4.
u_is_c = torch.rand(*list(ds.shape[:2]) + [self.N_c]).to(ds)
t_is_c = self.t_i_c_bin_edges + u_is_c * self.t_i_c_gap
# Calculate the points along the rays (r_ts_c) using the ray origins (os),
# sampled depths (t_is_c), and ray directions (ds). See Section 4:
# r(t) = o + t * d.
r_ts_c = os[..., None, :] + t_is_c[..., :, None] * ds[..., None, :]
return (r_ts_c, t_is_c)
def get_fine_query_points(self, w_is_c, t_is_c, os, ds):
# See text surrounding Equation (5) in Section 5.2 and:
# https://stephens999.github.io/fiveMinuteStats/inverse_transform_sampling.html#discrete_distributions.
# Define PDFs (pdfs) and CDFs (cdfs) from weights (w_is_c).
w_is_c = w_is_c + 1e-5
pdfs = w_is_c / torch.sum(w_is_c, dim=-1, keepdim=True)
cdfs = torch.cumsum(pdfs, dim=-1)
cdfs = torch.cat([torch.zeros_like(cdfs[..., :1]), cdfs[..., :-1]], dim=-1)
# Get uniform samples (us).
us = torch.rand(list(cdfs.shape[:-1]) + [self.N_f]).to(w_is_c)
# Use inverse transform sampling to sample the depths (t_is_f).
idxs = torch.searchsorted(cdfs, us, right=True)
t_i_f_bottom_edges = torch.gather(t_is_c, 2, idxs - 1)
idxs_capped = idxs.clone()
max_ind = cdfs.shape[-1]
idxs_capped[idxs_capped == max_ind] = max_ind - 1
t_i_f_top_edges = torch.gather(t_is_c, 2, idxs_capped)
t_i_f_top_edges[idxs == max_ind] = self.t_f
t_i_f_gaps = t_i_f_top_edges - t_i_f_bottom_edges
u_is_f = torch.rand_like(t_i_f_gaps).to(os)
t_is_f = t_i_f_bottom_edges + u_is_f * t_i_f_gaps
# Combine the coarse (t_is_c) and fine (t_is_f) depths and sort them.
(t_is_f, _) = torch.sort(torch.cat([t_is_c, t_is_f.detach()], dim=-1), dim=-1)
# Calculate the points along the rays (r_ts_f) using the ray origins (os),
# depths (t_is_f), and ray directions (ds). See Section 4: r(t) = o + t * d.
r_ts_f = os[..., None, :] + t_is_f[..., :, None] * ds[..., None, :]
return (r_ts_f, t_is_f)
def render_radiance_volume(self, r_ts, ds, F, t_is):
# Use the network (F) to predict colors (c_is) and volume densities (sigma_is)
# for 3D points along rays (r_ts) given the viewing directions (ds) of the rays.
# See Section 3 and Figure 7 in the Supplementary Materials.
r_ts_flat = r_ts.reshape((-1, 3))
ds_rep = ds.unsqueeze(2).repeat(1, 1, r_ts.shape[-2], 1)
ds_flat = ds_rep.reshape((-1, 3))
c_is = []
sigma_is = []
# The network processes batches of inputs to avoid running out of memory.
for chunk_start in range(0, r_ts_flat.shape[0], self.chunk_size):
r_ts_batch = r_ts_flat[chunk_start : chunk_start + self.chunk_size]
ds_batch = ds_flat[chunk_start : chunk_start + self.chunk_size]
preds = F(r_ts_batch, ds_batch)
c_is.append(preds["c_is"])
sigma_is.append(preds["sigma_is"])
c_is = torch.cat(c_is).reshape(r_ts.shape)
sigma_is = torch.cat(sigma_is).reshape(r_ts.shape[:-1])
# Calculate the distances (delta_is) between points along the rays. The
# differences in depths are scaled by the norms of the ray directions to get the
# final distances. See text following Equation (3) in Section 4.
delta_is = t_is[..., 1:] - t_is[..., :-1]
# "Infinity". Guarantees last alpha is always one.
one_e_10 = torch.Tensor([1e10]).expand(delta_is[..., :1].shape)
delta_is = torch.cat([delta_is, one_e_10.to(delta_is)], dim=-1)
delta_is = delta_is * ds.norm(dim=-1).unsqueeze(-1)
# Calculate the alphas (alpha_is) of the 3D points using the volume densities
# (sigma_is) and distances between points (delta_is). See text following
# Equation (3) in Section 4 and https://en.wikipedia.org/wiki/Alpha_compositing.
alpha_is = 1.0 - torch.exp(-sigma_is * delta_is)
# Calculate the accumulated transmittances (T_is) along the rays from the alphas
# (alpha_is). See Equation (3) in Section 4. T_i is "the probability that the
# ray travels from t_n to t_i without hitting any other particle".
T_is = torch.cumprod(1.0 - alpha_is + 1e-10, -1)
# Guarantees the ray makes it at least to the first step. See:
# https://github.com/bmild/nerf/blob/18b8aebda6700ed659cb27a0c348b737a5f6ab60/run_nerf.py#L142,
# which uses tf.math.cumprod(1.-alpha + 1e-10, axis=-1, exclusive=True).
T_is = torch.roll(T_is, 1, -1)
T_is[..., 0] = 1.0
# Calculate the weights (w_is) for the colors (c_is) along the rays using the
# transmittances (T_is) and alphas (alpha_is). See Equation (5) in Section 5.2:
# w_i = T_i * (1 - exp(-sigma_i * delta_i)).
w_is = T_is * alpha_is
# Calculate the pixel colors (C_rs) for the rays as weighted (w_is) sums of
# colors (c_is). See Equation (5) in Section 5.2: C_c_hat(r) = Σ w_i * c_i.
C_rs = (w_is[..., None] * c_is).sum(dim=-2)
return (C_rs, w_is)
def __call__(self, ds, os):
(r_ts_c, t_is_c) = self.get_coarse_query_points(ds, os)
(C_rs_c, w_is_c) = self.render_radiance_volume(r_ts_c, ds, self.F_c, t_is_c)
(r_ts_f, t_is_f) = self.get_fine_query_points(w_is_c, t_is_c, os, ds)
(C_rs_f, _) = self.render_radiance_volume(r_ts_f, ds, self.F_f, t_is_f)
return (C_rs_c, C_rs_f)
def load_data(device):
data_f = "66bdbc812bd0a196e194052f3f12cb2e.npz"
data = np.load(data_f)
# Set up initial ray origin (init_o) and ray directions (init_ds). These are the
# same across samples, we just rotate them based on the orientation of the camera.
# See Section 4.
images = data["images"] / 255
img_size = images.shape[1]
xs = torch.arange(img_size) - (img_size / 2 - 0.5)
ys = torch.arange(img_size) - (img_size / 2 - 0.5)
(xs, ys) = torch.meshgrid(xs, -ys, indexing="xy")
focal = float(data["focal"])
pixel_coords = torch.stack([xs, ys, torch.full_like(xs, -focal)], dim=-1)
# We want the zs to be negative ones, so we divide everything by the focal length
# (which is in pixel units).
camera_coords = pixel_coords / focal
init_ds = camera_coords.to(device)
init_o = torch.Tensor(np.array([0, 0, float(data["camera_distance"])])).to(device)
return (images, data["poses"], init_ds, init_o, img_size)
def set_up_test_data(images, device, poses, init_ds, init_o):
# Set up test view.
test_idx = 150
plt.imshow(images[test_idx])
plt.show()
test_img = torch.Tensor(images[test_idx]).to(device)
test_R = torch.Tensor(poses[test_idx, :3, :3]).to(device)
test_ds = torch.einsum("ij,hwj->hwi", test_R, init_ds)
test_os = (test_R @ init_o).expand(test_ds.shape)
train_idxs = np.arange(len(images)) != test_idx
return (test_ds, test_os, test_img, train_idxs)
def main():
# Set seed.
seed = 9458
torch.manual_seed(seed)
np.random.seed(seed)
# Initialize NeRF.
device = "cuda:0"
nerf = NeRF(device)
# Number of training rays per iteration. See Section 5.3.
batch_img_size = 64
n_batch_pix = batch_img_size ** 2
# Initialize optimizer. See Section 5.3.
lr = 5e-4
train_params = list(nerf.F_c.parameters()) + list(nerf.F_f.parameters())
optimizer = optim.Adam(train_params, lr=lr)
criterion = nn.MSELoss()
# The learning rate decays exponentially. See Section 5.3
# See: https://github.com/bmild/nerf/blob/18b8aebda6700ed659cb27a0c348b737a5f6ab60/run_nerf.py#L486.
lrate_decay = 250
decay_steps = lrate_decay * 1000
# See: https://github.com/bmild/nerf/blob/18b8aebda6700ed659cb27a0c348b737a5f6ab60/run_nerf.py#L707.
decay_rate = 0.1
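    # With these settings, the learning rate at iteration i is
    # lr * decay_rate ** (i / decay_steps), i.e., it drops by a factor of 10 every
    # 250,000 iterations.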
# Load dataset.
(images, poses, init_ds, init_o, img_size) = load_data(device)
(test_ds, test_os, test_img, train_idxs) = set_up_test_data(
images, device, poses, init_ds, init_o
)
images = torch.Tensor(images[train_idxs])
poses = torch.Tensor(poses[train_idxs])
n_pix = img_size ** 2
pixel_ps = torch.full((n_pix,), 1 / n_pix).to(device)
# Start training model.
psnrs = []
iternums = []
# See Section 5.3.
num_iters = 300000
display_every = 100
nerf.F_c.train()
nerf.F_f.train()
for i in range(num_iters):
# Sample image and associated pose.
target_img_idx = np.random.randint(images.shape[0])
target_pose = poses[target_img_idx].to(device)
R = target_pose[:3, :3]
# Get rotated ray origins (os) and ray directions (ds). See Section 4.
ds = torch.einsum("ij,hwj->hwi", R, init_ds)
os = (R @ init_o).expand(ds.shape)
# Sample a batch of rays.
pix_idxs = pixel_ps.multinomial(n_batch_pix, False)
pix_idx_rows = pix_idxs // img_size
pix_idx_cols = pix_idxs % img_size
ds_batch = ds[pix_idx_rows, pix_idx_cols].reshape(
batch_img_size, batch_img_size, -1
)
os_batch = os[pix_idx_rows, pix_idx_cols].reshape(
batch_img_size, batch_img_size, -1
)
# Run NeRF.
(C_rs_c, C_rs_f) = nerf(ds_batch, os_batch)
target_img = images[target_img_idx].to(device)
target_img_batch = target_img[pix_idx_rows, pix_idx_cols].reshape(C_rs_f.shape)
# Calculate the mean squared error for both the coarse and fine MLP models and
# update the weights. See Equation (6) in Section 5.3.
loss = criterion(C_rs_c, target_img_batch) + criterion(C_rs_f, target_img_batch)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# Exponentially decay learning rate. See Section 5.3 and:
# https://keras.io/api/optimizers/learning_rate_schedules/exponential_decay/.
for g in optimizer.param_groups:
g["lr"] = lr * decay_rate ** (i / decay_steps)
if i % display_every == 0:
nerf.F_c.eval()
nerf.F_f.eval()
with torch.no_grad():
(_, C_rs_f) = nerf(test_ds, test_os)
loss = criterion(C_rs_f, test_img)
print(f"Loss: {loss.item()}")
psnr = -10.0 * torch.log10(loss)
psnrs.append(psnr.item())
iternums.append(i)
plt.figure(figsize=(10, 4))
plt.subplot(121)
plt.imshow(C_rs_f.detach().cpu().numpy())
plt.title(f"Iteration {i}")
plt.subplot(122)
plt.plot(iternums, psnrs)
plt.title("PSNR")
plt.show()
nerf.F_c.train()
nerf.F_f.train()
print("Done!")
if __name__ == "__main__":
main()
================================================
FILE: run_pixelnerf.py
================================================
import matplotlib.pyplot as plt
import numpy as np
import torch
from image_encoder import ImageEncoder
from pixelnerf_dataset import PixelNeRFDataset
from torch import nn, optim
def get_coarse_query_points(ds, N_c, t_i_c_bin_edges, t_i_c_gap, os):
u_is_c = torch.rand(*list(ds.shape[:2]) + [N_c]).to(ds)
t_is_c = t_i_c_bin_edges + u_is_c * t_i_c_gap
r_ts_c = os[..., None, :] + t_is_c[..., :, None] * ds[..., None, :]
return (r_ts_c, t_is_c)
def get_fine_query_points(w_is_c, N_f, t_is_c, t_f, os, ds, r_ts_c, N_d, d_std, t_n):
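    # Sample N_f additional points per ray using inverse transform sampling: the
    # coarse weights (w_is_c) define a piecewise-constant PDF over the coarse bins,
    # and new t values are drawn by inverting the corresponding CDF. See Section 5.2
    # in the NeRF paper. pixelNeRF additionally draws N_d samples around a depth
    # estimate computed from the coarse weights and the z-coordinates of the coarse
    # points (perturbed with standard deviation d_std and clamped to [t_n, t_f]); all
    # samples are then merged with the coarse samples and sorted.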
w_is_c = w_is_c + 1e-5
pdfs = w_is_c / torch.sum(w_is_c, dim=-1, keepdim=True)
cdfs = torch.cumsum(pdfs, dim=-1)
cdfs = torch.cat([torch.zeros_like(cdfs[..., :1]), cdfs[..., :-1]], dim=-1)
us = torch.rand(list(cdfs.shape[:-1]) + [N_f]).to(w_is_c)
idxs = torch.searchsorted(cdfs, us, right=True)
t_i_f_bottom_edges = torch.gather(t_is_c, 2, idxs - 1)
idxs_capped = idxs.clone()
max_ind = cdfs.shape[-1]
idxs_capped[idxs_capped == max_ind] = max_ind - 1
t_i_f_top_edges = torch.gather(t_is_c, 2, idxs_capped)
t_i_f_top_edges[idxs == max_ind] = t_f
t_i_f_gaps = t_i_f_top_edges - t_i_f_bottom_edges
u_is_f = torch.rand_like(t_i_f_gaps).to(os)
t_is_f = t_i_f_bottom_edges + u_is_f * t_i_f_gaps
# See Section B.1 in the Supplementary Materials and:
# https://github.com/sxyu/pixel-nerf/blob/a5a514224272a91e3ec590f215567032e1f1c260/src/render/nerf.py#L150.
t_is_d = (w_is_c * r_ts_c[..., 2]).sum(dim=-1)
t_is_d = t_is_d.unsqueeze(2).repeat((1, 1, N_d))
t_is_d = t_is_d + torch.normal(0, d_std, size=t_is_d.shape).to(t_is_d)
t_is_d = torch.clamp(t_is_d, t_n, t_f)
t_is_f = torch.cat([t_is_c, t_is_f.detach(), t_is_d], dim=-1)
(t_is_f, _) = torch.sort(t_is_f, dim=-1)
r_ts_f = os[..., None, :] + t_is_f[..., :, None] * ds[..., None, :]
return (r_ts_f, t_is_f)
def get_image_features_for_query_points(r_ts, camera_distance, scale, W_i):
    # Get the projected image coordinates (pi_x_is) for each point along the rays
    # (r_ts) via a pinhole projection: the source camera looks down the -z axis from
    # z = camera_distance, so each point's x and y are divided by its depth from the
    # camera, (camera_distance - z), and then by scale. See: http://www.songho.ca/opengl/gl_projectionmatrix.html.
pi_x_is = r_ts[..., :2] / (camera_distance - r_ts[..., 2].unsqueeze(-1))
pi_x_is = pi_x_is / scale
# PyTorch's grid_sample function assumes (-1, -1) is the left-top pixel, but we want
# (-1, -1) to be the left-bottom pixel, so we negate the y-coordinates.
pi_x_is[..., 1] = -1 * pi_x_is[..., 1]
# PyTorch's grid_sample function expects the grid to have shape
# (N, H_out, W_out, 2).
pi_x_is = pi_x_is.permute(2, 0, 1, 3)
# PyTorch's grid_sample function expects the input to have shape (N, C, H_in, W_in).
W_i = W_i.repeat(pi_x_is.shape[0], 1, 1, 1)
# Get the image features (z_is) associated with the projected image coordinates
# (pi_x_is) from the encoded image features (W_i). See Section 4.2.
z_is = nn.functional.grid_sample(
W_i, pi_x_is, align_corners=True, padding_mode="border"
)
# Convert shape back to match rays.
z_is = z_is.permute(2, 3, 0, 1)
return z_is
def render_radiance_volume(r_ts, ds, z_is, chunk_size, F, t_is):
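    # Same volume rendering procedure as in run_nerf.py, except the network F also
    # receives the per-point image features (z_is): query F in chunks, convert the
    # predicted densities (sigma_is) to alphas, accumulate transmittances (T_is), and
    # composite the predicted colors (c_is) into per-ray colors (C_rs). See Equations
    # (3) and (5) in the NeRF paper.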
r_ts_flat = r_ts.reshape((-1, 3))
ds_rep = ds.unsqueeze(2).repeat(1, 1, r_ts.shape[-2], 1)
ds_flat = ds_rep.reshape((-1, 3))
z_is_flat = z_is.reshape((ds_flat.shape[0], -1))
c_is = []
sigma_is = []
for chunk_start in range(0, r_ts_flat.shape[0], chunk_size):
r_ts_batch = r_ts_flat[chunk_start : chunk_start + chunk_size]
ds_batch = ds_flat[chunk_start : chunk_start + chunk_size]
        z_is_batch = z_is_flat[chunk_start : chunk_start + chunk_size]
        preds = F(r_ts_batch, ds_batch, z_is_batch)
c_is.append(preds["c_is"])
sigma_is.append(preds["sigma_is"])
c_is = torch.cat(c_is).reshape(r_ts.shape)
sigma_is = torch.cat(sigma_is).reshape(r_ts.shape[:-1])
delta_is = t_is[..., 1:] - t_is[..., :-1]
one_e_10 = torch.Tensor([1e10]).expand(delta_is[..., :1].shape)
delta_is = torch.cat([delta_is, one_e_10.to(delta_is)], dim=-1)
delta_is = delta_is * ds.norm(dim=-1).unsqueeze(-1)
alpha_is = 1.0 - torch.exp(-sigma_is * delta_is)
T_is = torch.cumprod(1.0 - alpha_is + 1e-10, -1)
T_is = torch.roll(T_is, 1, -1)
T_is[..., 0] = 1.0
w_is = T_is * alpha_is
C_rs = (w_is[..., None] * c_is).sum(dim=-2)
return (C_rs, w_is)
def run_one_iter_of_pixelnerf(
ds,
N_c,
t_i_c_bin_edges,
t_i_c_gap,
os,
camera_distance,
scale,
W_i,
chunk_size,
F_c,
N_f,
t_f,
N_d,
d_std,
t_n,
F_f,
):
(r_ts_c, t_is_c) = get_coarse_query_points(ds, N_c, t_i_c_bin_edges, t_i_c_gap, os)
z_is_c = get_image_features_for_query_points(r_ts_c, camera_distance, scale, W_i)
(C_rs_c, w_is_c) = render_radiance_volume(
r_ts_c, ds, z_is_c, chunk_size, F_c, t_is_c
)
(r_ts_f, t_is_f) = get_fine_query_points(
w_is_c, N_f, t_is_c, t_f, os, ds, r_ts_c, N_d, d_std, t_n
)
z_is_f = get_image_features_for_query_points(r_ts_f, camera_distance, scale, W_i)
(C_rs_f, _) = render_radiance_volume(r_ts_f, ds, z_is_f, chunk_size, F_f, t_is_f)
return (C_rs_c, C_rs_f)
class PixelNeRFFCResNet(nn.Module):
def __init__(self):
super().__init__()
# Number of encoding functions for positions. See Section B.1 in the
# Supplementary Materials.
self.L_pos = 6
# Number of encoding functions for viewing directions.
self.L_dir = 0
pos_enc_feats = 3 + 3 * 2 * self.L_pos
dir_enc_feats = 3 + 3 * 2 * self.L_dir
# Set up ResNet MLP. See Section B.1 and Figure 18 in the Supplementary
# Materials.
net_width = 512
self.first_layer = nn.Sequential(
nn.Linear(pos_enc_feats + dir_enc_feats, net_width)
)
self.n_resnet_blocks = 5
z_linears = []
mlps = []
for resnet_block in range(self.n_resnet_blocks):
z_linears.append(nn.Linear(net_width, net_width))
mlps.append(
nn.Sequential(
nn.Linear(net_width, net_width),
nn.ReLU(),
nn.Linear(net_width, net_width),
nn.ReLU(),
)
)
self.z_linears = nn.ModuleList(z_linears)
self.mlps = nn.ModuleList(mlps)
self.final_layer = nn.Linear(net_width, 4)
def forward(self, xs, ds, zs):
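        # Encode the 3D points (xs) with the sinusoidal positional encoding from
        # Equation (4) in the NeRF paper (with the raw input prepended). Because
        # L_dir = 0, the viewing directions (ds) are only normalized and passed
        # through unencoded.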
xs_encoded = [xs]
for l_pos in range(self.L_pos):
xs_encoded.append(torch.sin(2**l_pos * torch.pi * xs))
xs_encoded.append(torch.cos(2**l_pos * torch.pi * xs))
xs_encoded = torch.cat(xs_encoded, dim=-1)
ds = ds / ds.norm(p=2, dim=-1).unsqueeze(-1)
ds_encoded = [ds]
for l_dir in range(self.L_dir):
ds_encoded.append(torch.sin(2**l_dir * torch.pi * ds))
ds_encoded.append(torch.cos(2**l_dir * torch.pi * ds))
ds_encoded = torch.cat(ds_encoded, dim=-1)
# Use the network to predict colors (c_is) and volume densities (sigma_is) for
# 3D points (xs) along rays given the viewing directions (ds) of the rays
# and the associated input image features (zs). See Section B.1 and Figure 18 in
# the Supplementary Materials and:
# https://github.com/sxyu/pixel-nerf/blob/master/src/model/resnetfc.py.
outputs = self.first_layer(torch.cat([xs_encoded, ds_encoded], dim=-1))
for block_idx in range(self.n_resnet_blocks):
resnet_zs = self.z_linears[block_idx](zs)
outputs = outputs + resnet_zs
outputs = self.mlps[block_idx](outputs) + outputs
outputs = self.final_layer(outputs)
sigma_is = torch.relu(outputs[:, 0])
c_is = torch.sigmoid(outputs[:, 1:])
return {"c_is": c_is, "sigma_is": sigma_is}
def load_data():
# Initialize dataset and test object/poses.
data_dir = "data"
# See Section B.2.1 in the Supplementary Materials.
num_iters = 400000
test_obj_idx = 5
test_source_pose_idx = 11
test_target_pose_idx = 33
train_dataset = PixelNeRFDataset(
data_dir, num_iters, test_obj_idx, test_source_pose_idx, test_target_pose_idx
)
return train_dataset
def set_up_test_data(train_dataset, device):
obj_idx = train_dataset.test_obj_idx
obj = train_dataset.objs[obj_idx]
data_dir = train_dataset.data_dir
obj_dir = f"{data_dir}/{obj}"
z_len = train_dataset.z_len
source_pose_idx = train_dataset.test_source_pose_idx
source_img_f = f"{obj_dir}/{str(source_pose_idx).zfill(z_len)}.npy"
source_image = np.load(source_img_f) / 255
source_pose = train_dataset.poses[obj_idx, source_pose_idx]
source_R = source_pose[:3, :3]
target_pose_idx = train_dataset.test_target_pose_idx
target_img_f = f"{obj_dir}/{str(target_pose_idx).zfill(z_len)}.npy"
target_image = np.load(target_img_f) / 255
target_pose = train_dataset.poses[obj_idx, target_pose_idx]
target_R = target_pose[:3, :3]
R = torch.Tensor(source_R.T @ target_R).to(device)
plt.imshow(source_image)
plt.show()
source_image = torch.Tensor(source_image)
source_image = (
source_image - train_dataset.channel_means
) / train_dataset.channel_stds
source_image = source_image.to(device).unsqueeze(0).permute(0, 3, 1, 2)
plt.imshow(target_image)
plt.show()
target_image = torch.Tensor(target_image).to(device)
return (source_image, R, target_image)
def main():
seed = 9458
torch.manual_seed(seed)
np.random.seed(seed)
device = "cuda:0"
F_c = PixelNeRFFCResNet().to(device)
F_f = PixelNeRFFCResNet().to(device)
E = ImageEncoder().to(device)
chunk_size = 1024 * 32
# See Section B.2 in the Supplementary Materials.
batch_img_size = 12
n_batch_pix = batch_img_size**2
n_objs = 4
# See Section B.2 in the Supplementary Materials.
lr = 1e-4
optimizer = optim.Adam(list(F_c.parameters()) + list(F_f.parameters()), lr=lr)
criterion = nn.MSELoss()
train_dataset = load_data()
camera_distance = train_dataset.camera_distance
scale = train_dataset.scale
t_n = 1.0
t_f = 4.0
img_size = train_dataset[0][2].shape[0]
# See Section B.1 in the Supplementary Materials,
# and: https://github.com/sxyu/pixel-nerf/blob/a5a514224272a91e3ec590f215567032e1f1c260/conf/default.conf#L50,
# and: https://github.com/sxyu/pixel-nerf/blob/a5a514224272a91e3ec590f215567032e1f1c260/src/render/nerf.py#L150.
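    # N_c: stratified coarse samples per ray. N_f: importance samples drawn from the
    # coarse weights. N_d: additional samples drawn around a depth estimate computed
    # from the coarse weights, perturbed with standard deviation d_std.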
N_c = 64
N_f = 16
N_d = 16
d_std = 0.01
t_i_c_gap = (t_f - t_n) / N_c
t_i_c_bin_edges = (t_n + torch.arange(N_c) * t_i_c_gap).to(device)
init_o = train_dataset.init_o.to(device)
init_ds = train_dataset.init_ds.to(device)
(test_source_image, test_R, test_target_image) = set_up_test_data(
train_dataset, device
)
test_ds = torch.einsum("ij,hwj->hwi", test_R, init_ds)
test_os = (test_R @ init_o).expand(test_ds.shape)
psnrs = []
iternums = []
num_iters = train_dataset.N
use_bbox = True
num_bbox_iters = 300000
display_every = 100
F_c.train()
F_f.train()
E.eval()
for i in range(num_iters):
if i == num_bbox_iters:
use_bbox = False
loss = 0
for obj in range(n_objs):
try:
(source_image, R, target_image, bbox) = train_dataset[0]
except ValueError:
continue
R = R.to(device)
ds = torch.einsum("ij,hwj->hwi", R, init_ds)
os = (R @ init_o).expand(ds.shape)
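            # For the first num_bbox_iters iterations, only sample rays that fall
            # inside the object's bounding box (bbox); after that, sample rays from
            # the full image.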
if use_bbox:
pix_rows = np.arange(bbox[0], bbox[2])
pix_cols = np.arange(bbox[1], bbox[3])
else:
pix_rows = np.arange(0, img_size)
pix_cols = np.arange(0, img_size)
pix_row_cols = np.meshgrid(pix_rows, pix_cols, indexing="ij")
pix_row_cols = np.stack(pix_row_cols).transpose(1, 2, 0).reshape(-1, 2)
choices = np.arange(len(pix_row_cols))
try:
selected_pix = np.random.choice(choices, n_batch_pix, False)
except ValueError:
continue
pix_idx_rows = pix_row_cols[selected_pix, 0]
pix_idx_cols = pix_row_cols[selected_pix, 1]
ds_batch = ds[pix_idx_rows, pix_idx_cols].reshape(
batch_img_size, batch_img_size, -1
)
os_batch = os[pix_idx_rows, pix_idx_cols].reshape(
batch_img_size, batch_img_size, -1
)
# Extract feature pyramid from image. See Section 4.1, Section B.1 in the
# Supplementary Materials, and: https://github.com/sxyu/pixel-nerf/blob/master/src/model/encoder.py.
with torch.no_grad():
W_i = E(source_image.unsqueeze(0).permute(0, 3, 1, 2).to(device))
(C_rs_c, C_rs_f) = run_one_iter_of_pixelnerf(
ds_batch,
N_c,
t_i_c_bin_edges,
t_i_c_gap,
os_batch,
camera_distance,
scale,
W_i,
chunk_size,
F_c,
N_f,
t_f,
N_d,
d_std,
t_n,
F_f,
)
target_img = target_image.to(device)
target_img_batch = target_img[pix_idx_rows, pix_idx_cols].reshape(
C_rs_c.shape
)
loss += criterion(C_rs_c, target_img_batch)
loss += criterion(C_rs_f, target_img_batch)
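        # If every object in this iteration was skipped, loss is still the integer 0,
        # which has no backward() method, so the AttributeError handler below skips
        # the update.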
try:
optimizer.zero_grad()
loss.backward()
optimizer.step()
except AttributeError:
continue
if i % display_every == 0:
F_c.eval()
F_f.eval()
with torch.no_grad():
test_W_i = E(test_source_image)
(_, C_rs_f) = run_one_iter_of_pixelnerf(
test_ds,
N_c,
t_i_c_bin_edges,
t_i_c_gap,
test_os,
camera_distance,
scale,
test_W_i,
chunk_size,
F_c,
N_f,
t_f,
N_d,
d_std,
t_n,
F_f,
)
loss = criterion(C_rs_f, test_target_image)
print(f"Loss: {loss.item()}")
psnr = -10.0 * torch.log10(loss)
psnrs.append(psnr.item())
iternums.append(i)
plt.figure(figsize=(10, 4))
plt.subplot(121)
plt.imshow(C_rs_f.detach().cpu().numpy())
plt.title(f"Iteration {i}")
plt.subplot(122)
plt.plot(iternums, psnrs)
plt.title("PSNR")
plt.show()
F_c.train()
F_f.train()
print("Done!")
if __name__ == "__main__":
main()
================================================
FILE: run_pixelnerf_alt.py
================================================
import matplotlib.pyplot as plt
import numpy as np
import torch
from image_encoder import ImageEncoder
from pixelnerf_dataset import PixelNeRFDataset
from torch import nn, optim
class PixelNeRFFCResNet(nn.Module):
def __init__(self):
super().__init__()
# Number of encoding functions for positions. See Section B.1 in the
# Supplementary Materials.
self.L_pos = 6
# Number of encoding functions for viewing directions.
self.L_dir = 0
pos_enc_feats = 3 + 3 * 2 * self.L_pos
dir_enc_feats = 3 + 3 * 2 * self.L_dir
# Set up ResNet MLP. See Section B.1 and Figure 18 in the Supplementary
# Materials.
net_width = 512
self.first_layer = nn.Sequential(
nn.Linear(pos_enc_feats + dir_enc_feats, net_width)
)
self.n_resnet_blocks = 5
z_linears = []
mlps = []
for resnet_block in range(self.n_resnet_blocks):
z_linears.append(nn.Linear(net_width, net_width))
mlps.append(
nn.Sequential(
nn.Linear(net_width, net_width),
nn.ReLU(),
nn.Linear(net_width, net_width),
nn.ReLU(),
)
)
self.z_linears = nn.ModuleList(z_linears)
self.mlps = nn.ModuleList(mlps)
self.final_layer = nn.Linear(net_width, 4)
def forward(self, xs, ds, zs):
xs_encoded = [xs]
for l_pos in range(self.L_pos):
xs_encoded.append(torch.sin(2**l_pos * torch.pi * xs))
xs_encoded.append(torch.cos(2**l_pos * torch.pi * xs))
xs_encoded = torch.cat(xs_encoded, dim=-1)
ds = ds / ds.norm(p=2, dim=-1).unsqueeze(-1)
ds_encoded = [ds]
for l_dir in range(self.L_dir):
ds_encoded.append(torch.sin(2**l_dir * torch.pi * ds))
ds_encoded.append(torch.cos(2**l_dir * torch.pi * ds))
ds_encoded = torch.cat(ds_encoded, dim=-1)
# Use the network to predict colors (c_is) and volume densities (sigma_is) for
# 3D points (xs) along rays given the viewing directions (ds) of the rays
# and the associated input image features (zs). See Section B.1 and Figure 18 in
# the Supplementary Materials and:
# https://github.com/sxyu/pixel-nerf/blob/master/src/model/resnetfc.py.
outputs = self.first_layer(torch.cat([xs_encoded, ds_encoded], dim=-1))
for block_idx in range(self.n_resnet_blocks):
resnet_zs = self.z_linears[block_idx](zs)
outputs = outputs + resnet_zs
outputs = self.mlps[block_idx](outputs) + outputs
outputs = self.final_layer(outputs)
sigma_is = torch.relu(outputs[:, 0])
c_is = torch.sigmoid(outputs[:, 1:])
return {"c_is": c_is, "sigma_is": sigma_is}
class PixelNeRF:
def __init__(self, device, camera_distance, scale):
self.device = device
# See Section B.1 in the Supplementary Materials,
# and: https://github.com/sxyu/pixel-nerf/blob/a5a514224272a91e3ec590f215567032e1f1c260/conf/default.conf#L50,
# and: https://github.com/sxyu/pixel-nerf/blob/a5a514224272a91e3ec590f215567032e1f1c260/src/render/nerf.py#L150.
self.N_c = N_c = 64
self.N_f = 16
self.N_d = 16
self.d_std = 0.01
self.t_n = t_n = 1.0
self.t_f = t_f = 4.0
self.t_i_c_gap = t_i_c_gap = (t_f - t_n) / N_c
self.t_i_c_bin_edges = (t_n + torch.arange(N_c) * t_i_c_gap).to(device)
self.F_c = PixelNeRFFCResNet().to(device)
self.F_f = PixelNeRFFCResNet().to(device)
self.E = ImageEncoder().to(device)
self.camera_distance = camera_distance
self.scale = scale
self.chunk_size = 1024 * 32
def get_coarse_query_points(self, ds, os):
u_is_c = torch.rand(*list(ds.shape[:2]) + [self.N_c]).to(ds)
t_is_c = self.t_i_c_bin_edges + u_is_c * self.t_i_c_gap
r_ts_c = os[..., None, :] + t_is_c[..., :, None] * ds[..., None, :]
return (r_ts_c, t_is_c)
def get_fine_query_points(self, w_is_c, t_is_c, os, ds, r_ts_c):
w_is_c = w_is_c + 1e-5
pdfs = w_is_c / torch.sum(w_is_c, dim=-1, keepdim=True)
cdfs = torch.cumsum(pdfs, dim=-1)
cdfs = torch.cat([torch.zeros_like(cdfs[..., :1]), cdfs[..., :-1]], dim=-1)
us = torch.rand(list(cdfs.shape[:-1]) + [self.N_f]).to(w_is_c)
idxs = torch.searchsorted(cdfs, us, right=True)
t_i_f_bottom_edges = torch.gather(t_is_c, 2, idxs - 1)
idxs_capped = idxs.clone()
max_ind = cdfs.shape[-1]
idxs_capped[idxs_capped == max_ind] = max_ind - 1
t_i_f_top_edges = torch.gather(t_is_c, 2, idxs_capped)
t_i_f_top_edges[idxs == max_ind] = self.t_f
t_i_f_gaps = t_i_f_top_edges - t_i_f_bottom_edges
u_is_f = torch.rand_like(t_i_f_gaps).to(os)
t_is_f = t_i_f_bottom_edges + u_is_f * t_i_f_gaps
# See Section B.1 in the Supplementary Materials and:
# https://github.com/sxyu/pixel-nerf/blob/a5a514224272a91e3ec590f215567032e1f1c260/src/render/nerf.py#L150.
t_is_d = (w_is_c * r_ts_c[..., 2]).sum(dim=-1)
t_is_d = t_is_d.unsqueeze(2).repeat((1, 1, self.N_d))
t_is_d = t_is_d + torch.normal(0, self.d_std, size=t_is_d.shape).to(t_is_d)
t_is_d = torch.clamp(t_is_d, self.t_n, self.t_f)
t_is_f = torch.cat([t_is_c, t_is_f.detach(), t_is_d], dim=-1)
(t_is_f, _) = torch.sort(t_is_f, dim=-1)
r_ts_f = os[..., None, :] + t_is_f[..., :, None] * ds[..., None, :]
return (r_ts_f, t_is_f)
def get_image_features_for_query_points(self, r_ts, W_i):
        # Get the projected image coordinates (pi_x_is) for each point along the rays
        # (r_ts) via a pinhole projection: the source camera looks down the -z axis
        # from z = camera_distance, so each point's x and y are divided by its depth
        # from the camera, (camera_distance - z), and then by scale. See: http://www.songho.ca/opengl/gl_projectionmatrix.html.
pi_x_is = r_ts[..., :2] / (self.camera_distance - r_ts[..., 2].unsqueeze(-1))
pi_x_is = pi_x_is / self.scale
# PyTorch's grid_sample function assumes (-1, -1) is the left-top pixel, but we want
# (-1, -1) to be the left-bottom pixel, so we negate the y-coordinates.
pi_x_is[..., 1] = -1 * pi_x_is[..., 1]
# PyTorch's grid_sample function expects the grid to have shape
# (N, H_out, W_out, 2).
pi_x_is = pi_x_is.permute(2, 0, 1, 3)
# PyTorch's grid_sample function expects the input to have shape (N, C, H_in, W_in).
W_i = W_i.repeat(pi_x_is.shape[0], 1, 1, 1)
# Get the image features (z_is) associated with the projected image coordinates
# (pi_x_is) from the encoded image features (W_i). See Section 4.2.
z_is = nn.functional.grid_sample(
W_i, pi_x_is, align_corners=True, padding_mode="border"
)
# Convert shape back to match rays.
z_is = z_is.permute(2, 3, 0, 1)
return z_is
def render_radiance_volume(self, r_ts, ds, z_is, F, t_is):
r_ts_flat = r_ts.reshape((-1, 3))
ds_rep = ds.unsqueeze(2).repeat(1, 1, r_ts.shape[-2], 1)
ds_flat = ds_rep.reshape((-1, 3))
z_is_flat = z_is.reshape((ds_flat.shape[0], -1))
c_is = []
sigma_is = []
for chunk_start in range(0, r_ts_flat.shape[0], self.chunk_size):
r_ts_batch = r_ts_flat[chunk_start : chunk_start + self.chunk_size]
ds_batch = ds_flat[chunk_start : chunk_start + self.chunk_size]
            z_is_batch = z_is_flat[chunk_start : chunk_start + self.chunk_size]
            preds = F(r_ts_batch, ds_batch, z_is_batch)
c_is.append(preds["c_is"])
sigma_is.append(preds["sigma_is"])
c_is = torch.cat(c_is).reshape(r_ts.shape)
sigma_is = torch.cat(sigma_is).reshape(r_ts.shape[:-1])
delta_is = t_is[..., 1:] - t_is[..., :-1]
one_e_10 = torch.Tensor([1e10]).expand(delta_is[..., :1].shape)
delta_is = torch.cat([delta_is, one_e_10.to(delta_is)], dim=-1)
delta_is = delta_is * ds.norm(dim=-1).unsqueeze(-1)
alpha_is = 1.0 - torch.exp(-sigma_is * delta_is)
T_is = torch.cumprod(1.0 - alpha_is + 1e-10, -1)
T_is = torch.roll(T_is, 1, -1)
T_is[..., 0] = 1.0
w_is = T_is * alpha_is
C_rs = (w_is[..., None] * c_is).sum(dim=-2)
return (C_rs, w_is)
def __call__(self, ds, os, source_image):
(r_ts_c, t_is_c) = self.get_coarse_query_points(ds, os)
# Extract feature pyramid from image. See Section 4.1, Section B.1 in the
# Supplementary Materials, and: https://github.com/sxyu/pixel-nerf/blob/master/src/model/encoder.py.
with torch.no_grad():
W_i = self.E(source_image.unsqueeze(0).permute(0, 3, 1, 2).to(self.device))
z_is_c = self.get_image_features_for_query_points(r_ts_c, W_i)
(C_rs_c, w_is_c) = self.render_radiance_volume(
r_ts_c, ds, z_is_c, self.F_c, t_is_c
)
(r_ts_f, t_is_f) = self.get_fine_query_points(w_is_c, t_is_c, os, ds, r_ts_c)
z_is_f = self.get_image_features_for_query_points(r_ts_f, W_i)
(C_rs_f, _) = self.render_radiance_volume(r_ts_f, ds, z_is_f, self.F_f, t_is_f)
return (C_rs_c, C_rs_f)
def load_data():
# Initialize dataset and test object/poses.
data_dir = "data"
# See Section B.2.1 in the Supplementary Materials.
num_iters = 400000
test_obj_idx = 5
test_source_pose_idx = 11
test_target_pose_idx = 33
train_dataset = PixelNeRFDataset(
data_dir, num_iters, test_obj_idx, test_source_pose_idx, test_target_pose_idx
)
return (num_iters, train_dataset)
def set_up_test_data(train_dataset, device):
obj_idx = train_dataset.test_obj_idx
obj = train_dataset.objs[obj_idx]
data_dir = train_dataset.data_dir
obj_dir = f"{data_dir}/{obj}"
z_len = train_dataset.z_len
source_pose_idx = train_dataset.test_source_pose_idx
source_img_f = f"{obj_dir}/{str(source_pose_idx).zfill(z_len)}.npy"
source_image = np.load(source_img_f) / 255
source_pose = train_dataset.poses[obj_idx, source_pose_idx]
source_R = source_pose[:3, :3]
target_pose_idx = train_dataset.test_target_pose_idx
target_img_f = f"{obj_dir}/{str(target_pose_idx).zfill(z_len)}.npy"
target_image = np.load(target_img_f) / 255
target_pose = train_dataset.poses[obj_idx, target_pose_idx]
target_R = target_pose[:3, :3]
R = torch.Tensor(source_R.T @ target_R).to(device)
plt.imshow(source_image)
plt.show()
source_image = torch.Tensor(source_image)
source_image = (
source_image - train_dataset.channel_means
) / train_dataset.channel_stds
plt.imshow(target_image)
plt.show()
target_image = torch.Tensor(target_image).to(device)
return (source_image, R, target_image)
def main():
seed = 9458
torch.manual_seed(seed)
np.random.seed(seed)
device = "cuda:0"
(num_iters, train_dataset) = load_data()
img_size = train_dataset[0][2].shape[0]
pixelnerf = PixelNeRF(device, train_dataset.camera_distance, train_dataset.scale)
# See Section B.2 in the Supplementary Materials.
batch_img_size = 12
n_batch_pix = batch_img_size**2
n_objs = 4
# See Section B.2 in the Supplementary Materials.
lr = 1e-4
train_params = list(pixelnerf.F_c.parameters()) + list(pixelnerf.F_f.parameters())
optimizer = optim.Adam(train_params, lr=lr)
criterion = nn.MSELoss()
(test_source_image, test_R, test_target_image) = set_up_test_data(
train_dataset, device
)
init_o = train_dataset.init_o.to(device)
init_ds = train_dataset.init_ds.to(device)
test_ds = torch.einsum("ij,hwj->hwi", test_R, init_ds)
test_os = (test_R @ init_o).expand(test_ds.shape)
psnrs = []
iternums = []
use_bbox = True
num_bbox_iters = 300000
display_every = 100
pixelnerf.F_c.train()
pixelnerf.F_f.train()
pixelnerf.E.eval()
for i in range(num_iters):
if i == num_bbox_iters:
use_bbox = False
loss = 0
for obj in range(n_objs):
try:
(source_image, R, target_image, bbox) = train_dataset[0]
except ValueError:
continue
R = R.to(device)
ds = torch.einsum("ij,hwj->hwi", R, init_ds)
os = (R @ init_o).expand(ds.shape)
if use_bbox:
pix_rows = np.arange(bbox[0], bbox[2])
pix_cols = np.arange(bbox[1], bbox[3])
else:
pix_rows = np.arange(0, img_size)
pix_cols = np.arange(0, img_size)
pix_row_cols = np.meshgrid(pix_rows, pix_cols, indexing="ij")
pix_row_cols = np.stack(pix_row_cols).transpose(1, 2, 0).reshape(-1, 2)
choices = np.arange(len(pix_row_cols))
try:
selected_pix = np.random.choice(choices, n_batch_pix, False)
except ValueError:
continue
pix_idx_rows = pix_row_cols[selected_pix, 0]
pix_idx_cols = pix_row_cols[selected_pix, 1]
ds_batch = ds[pix_idx_rows, pix_idx_cols].reshape(
batch_img_size, batch_img_size, -1
)
os_batch = os[pix_idx_rows, pix_idx_cols].reshape(
batch_img_size, batch_img_size, -1
)
(C_rs_c, C_rs_f) = pixelnerf(ds_batch, os_batch, source_image)
target_img = target_image.to(device)
target_img_batch = target_img[pix_idx_rows, pix_idx_cols].reshape(
C_rs_c.shape
)
loss += criterion(C_rs_c, target_img_batch)
loss += criterion(C_rs_f, target_img_batch)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if i % display_every == 0:
pixelnerf.F_c.eval()
pixelnerf.F_f.eval()
with torch.no_grad():
(_, C_rs_f) = pixelnerf(test_ds, test_os, test_source_image)
loss = criterion(C_rs_f, test_target_image)
print(f"Loss: {loss.item()}")
psnr = -10.0 * torch.log10(loss)
psnrs.append(psnr.item())
iternums.append(i)
plt.figure(figsize=(10, 4))
plt.subplot(121)
plt.imshow(C_rs_f.detach().cpu().numpy())
plt.title(f"Iteration {i}")
plt.subplot(122)
plt.plot(iternums, psnrs)
plt.title("PSNR")
plt.show()
pixelnerf.F_c.train()
pixelnerf.F_f.train()
print("Done!")
if __name__ == "__main__":
main()
================================================
FILE: run_tiny_nerf.py
================================================
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn, optim
def get_coarse_query_points(ds, N_c, t_i_c_bin_edges, t_i_c_gap, os):
u_is_c = torch.rand(*list(ds.shape[:2]) + [N_c]).to(ds)
t_is_c = t_i_c_bin_edges + u_is_c * t_i_c_gap
r_ts_c = os[..., None, :] + t_is_c[..., :, None] * ds[..., None, :]
return (r_ts_c, t_is_c)
def render_radiance_volume(r_ts, ds, chunk_size, F, t_is):
r_ts_flat = r_ts.reshape((-1, 3))
ds_rep = ds.unsqueeze(2).repeat(1, 1, r_ts.shape[-2], 1)
ds_flat = ds_rep.reshape((-1, 3))
c_is = []
sigma_is = []
for chunk_start in range(0, r_ts_flat.shape[0], chunk_size):
r_ts_batch = r_ts_flat[chunk_start : chunk_start + chunk_size]
ds_batch = ds_flat[chunk_start : chunk_start + chunk_size]
preds = F(r_ts_batch, ds_batch)
c_is.append(preds["c_is"])
sigma_is.append(preds["sigma_is"])
c_is = torch.cat(c_is).reshape(r_ts.shape)
sigma_is = torch.cat(sigma_is).reshape(r_ts.shape[:-1])
delta_is = t_is[..., 1:] - t_is[..., :-1]
one_e_10 = torch.Tensor([1e10]).expand(delta_is[..., :1].shape)
delta_is = torch.cat([delta_is, one_e_10.to(delta_is)], dim=-1)
delta_is = delta_is * ds.norm(dim=-1).unsqueeze(-1)
alpha_is = 1.0 - torch.exp(-sigma_is * delta_is)
T_is = torch.cumprod(1.0 - alpha_is + 1e-10, -1)
T_is = torch.roll(T_is, 1, -1)
T_is[..., 0] = 1.0
w_is = T_is * alpha_is
C_rs = (w_is[..., None] * c_is).sum(dim=-2)
return C_rs
def run_one_iter_of_tiny_nerf(ds, N_c, t_i_c_bin_edges, t_i_c_gap, os, chunk_size, F_c):
(r_ts_c, t_is_c) = get_coarse_query_points(ds, N_c, t_i_c_bin_edges, t_i_c_gap, os)
C_rs_c = render_radiance_volume(r_ts_c, ds, chunk_size, F_c, t_is_c)
return C_rs_c
class VeryTinyNeRFMLP(nn.Module):
def __init__(self):
super().__init__()
self.L_pos = 6
self.L_dir = 4
pos_enc_feats = 3 + 3 * 2 * self.L_pos
dir_enc_feats = 3 + 3 * 2 * self.L_dir
net_width = 256
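        # The early MLP maps the encoded position to a density plus a feature vector
        # (net_width + 1 outputs, with the first output used as sigma_is), and the
        # late MLP maps that feature vector concatenated with the encoded viewing
        # direction to an RGB color (c_is).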
self.early_mlp = nn.Sequential(
nn.Linear(pos_enc_feats, net_width),
nn.ReLU(),
nn.Linear(net_width, net_width + 1),
nn.ReLU(),
)
self.late_mlp = nn.Sequential(
nn.Linear(net_width + dir_enc_feats, net_width),
nn.ReLU(),
nn.Linear(net_width, 3),
nn.Sigmoid(),
)
def forward(self, xs, ds):
xs_encoded = [xs]
for l_pos in range(self.L_pos):
xs_encoded.append(torch.sin(2**l_pos * torch.pi * xs))
xs_encoded.append(torch.cos(2**l_pos * torch.pi * xs))
xs_encoded = torch.cat(xs_encoded, dim=-1)
ds = ds / ds.norm(p=2, dim=-1).unsqueeze(-1)
ds_encoded = [ds]
for l_dir in range(self.L_dir):
ds_encoded.append(torch.sin(2**l_dir * torch.pi * ds))
ds_encoded.append(torch.cos(2**l_dir * torch.pi * ds))
ds_encoded = torch.cat(ds_encoded, dim=-1)
outputs = self.early_mlp(xs_encoded)
sigma_is = outputs[:, 0]
c_is = self.late_mlp(torch.cat([outputs[:, 1:], ds_encoded], dim=-1))
return {"c_is": c_is, "sigma_is": sigma_is}
def main():
seed = 9458
torch.manual_seed(seed)
np.random.seed(seed)
device = "cuda:0"
F_c = VeryTinyNeRFMLP().to(device)
chunk_size = 16384
lr = 5e-3
optimizer = optim.Adam(F_c.parameters(), lr=lr)
criterion = nn.MSELoss()
data_f = "66bdbc812bd0a196e194052f3f12cb2e.npz"
data = np.load(data_f)
images = data["images"] / 255
img_size = images.shape[1]
xs = torch.arange(img_size) - (img_size / 2 - 0.5)
ys = torch.arange(img_size) - (img_size / 2 - 0.5)
(xs, ys) = torch.meshgrid(xs, -ys, indexing="xy")
focal = float(data["focal"])
pixel_coords = torch.stack([xs, ys, torch.full_like(xs, -focal)], dim=-1)
camera_coords = pixel_coords / focal
init_ds = camera_coords.to(device)
init_o = torch.Tensor(np.array([0, 0, float(data["camera_distance"])])).to(device)
test_idx = 150
plt.imshow(images[test_idx])
plt.show()
test_img = torch.Tensor(images[test_idx]).to(device)
poses = data["poses"]
test_R = torch.Tensor(poses[test_idx, :3, :3]).to(device)
test_ds = torch.einsum("ij,hwj->hwi", test_R, init_ds)
test_os = (test_R @ init_o).expand(test_ds.shape)
t_n = 1.0
t_f = 4.0
N_c = 32
t_i_c_gap = (t_f - t_n) / N_c
t_i_c_bin_edges = (t_n + torch.arange(N_c) * t_i_c_gap).to(device)
train_idxs = np.arange(len(images)) != test_idx
images = torch.Tensor(images[train_idxs])
poses = torch.Tensor(poses[train_idxs])
psnrs = []
iternums = []
num_iters = 20000
display_every = 100
F_c.train()
for i in range(num_iters):
target_img_idx = np.random.randint(images.shape[0])
target_pose = poses[target_img_idx].to(device)
R = target_pose[:3, :3]
ds = torch.einsum("ij,hwj->hwi", R, init_ds)
os = (R @ init_o).expand(ds.shape)
C_rs_c = run_one_iter_of_tiny_nerf(
ds, N_c, t_i_c_bin_edges, t_i_c_gap, os, chunk_size, F_c
)
loss = criterion(C_rs_c, images[target_img_idx].to(device))
optimizer.zero_grad()
loss.backward()
optimizer.step()
if i % display_every == 0:
F_c.eval()
with torch.no_grad():
C_rs_c = run_one_iter_of_tiny_nerf(
test_ds, N_c, t_i_c_bin_edges, t_i_c_gap, test_os, chunk_size, F_c
)
loss = criterion(C_rs_c, test_img)
print(f"Loss: {loss.item()}")
psnr = -10.0 * torch.log10(loss)
psnrs.append(psnr.item())
iternums.append(i)
plt.figure(figsize=(10, 4))
plt.subplot(121)
plt.imshow(C_rs_c.detach().cpu().numpy())
plt.title(f"Iteration {i}")
plt.subplot(122)
plt.plot(iternums, psnrs)
plt.title("PSNR")
plt.show()
F_c.train()
print("Done!")
if __name__ == "__main__":
main()
================================================
FILE: run_tiny_nerf_alt.py
================================================
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn, optim
class VeryTinyNeRFMLP(nn.Module):
def __init__(self):
super().__init__()
self.L_pos = 6
self.L_dir = 4
pos_enc_feats = 3 + 3 * 2 * self.L_pos
dir_enc_feats = 3 + 3 * 2 * self.L_dir
net_width = 256
self.early_mlp = nn.Sequential(
nn.Linear(pos_enc_feats, net_width),
nn.ReLU(),
nn.Linear(net_width, net_width + 1),
nn.ReLU(),
)
self.late_mlp = nn.Sequential(
nn.Linear(net_width + dir_enc_feats, net_width),
nn.ReLU(),
nn.Linear(net_width, 3),
nn.Sigmoid(),
)
def forward(self, xs, ds):
xs_encoded = [xs]
for l_pos in range(self.L_pos):
xs_encoded.append(torch.sin(2**l_pos * torch.pi * xs))
xs_encoded.append(torch.cos(2**l_pos * torch.pi * xs))
xs_encoded = torch.cat(xs_encoded, dim=-1)
ds = ds / ds.norm(p=2, dim=-1).unsqueeze(-1)
ds_encoded = [ds]
for l_dir in range(self.L_dir):
ds_encoded.append(torch.sin(2**l_dir * torch.pi * ds))
ds_encoded.append(torch.cos(2**l_dir * torch.pi * ds))
ds_encoded = torch.cat(ds_encoded, dim=-1)
outputs = self.early_mlp(xs_encoded)
sigma_is = outputs[:, 0]
c_is = self.late_mlp(torch.cat([outputs[:, 1:], ds_encoded], dim=-1))
return {"c_is": c_is, "sigma_is": sigma_is}
class VeryTinyNeRF:
def __init__(self, device):
self.F_c = VeryTinyNeRFMLP().to(device)
self.chunk_size = 16384
self.t_n = t_n = 1.0
self.t_f = t_f = 4.0
self.N_c = N_c = 32
self.t_i_c_gap = t_i_c_gap = (t_f - t_n) / N_c
self.t_i_c_bin_edges = (t_n + torch.arange(N_c) * t_i_c_gap).to(device)
def get_coarse_query_points(self, ds, os):
u_is_c = torch.rand(*list(ds.shape[:2]) + [self.N_c]).to(ds)
t_is_c = self.t_i_c_bin_edges + u_is_c * self.t_i_c_gap
r_ts_c = os[..., None, :] + t_is_c[..., :, None] * ds[..., None, :]
return (r_ts_c, t_is_c)
def render_radiance_volume(self, r_ts, ds, F, t_is):
r_ts_flat = r_ts.reshape((-1, 3))
ds_rep = ds.unsqueeze(2).repeat(1, 1, r_ts.shape[-2], 1)
ds_flat = ds_rep.reshape((-1, 3))
c_is = []
sigma_is = []
for chunk_start in range(0, r_ts_flat.shape[0], self.chunk_size):
r_ts_batch = r_ts_flat[chunk_start : chunk_start + self.chunk_size]
ds_batch = ds_flat[chunk_start : chunk_start + self.chunk_size]
preds = F(r_ts_batch, ds_batch)
c_is.append(preds["c_is"])
sigma_is.append(preds["sigma_is"])
c_is = torch.cat(c_is).reshape(r_ts.shape)
sigma_is = torch.cat(sigma_is).reshape(r_ts.shape[:-1])
delta_is = t_is[..., 1:] - t_is[..., :-1]
one_e_10 = torch.Tensor([1e10]).expand(delta_is[..., :1].shape)
delta_is = torch.cat([delta_is, one_e_10.to(delta_is)], dim=-1)
delta_is = delta_is * ds.norm(dim=-1).unsqueeze(-1)
alpha_is = 1.0 - torch.exp(-sigma_is * delta_is)
T_is = torch.cumprod(1.0 - alpha_is + 1e-10, -1)
T_is = torch.roll(T_is, 1, -1)
T_is[..., 0] = 1.0
w_is = T_is * alpha_is
C_rs = (w_is[..., None] * c_is).sum(dim=-2)
return C_rs
def __call__(self, ds, os):
(r_ts_c, t_is_c) = self.get_coarse_query_points(ds, os)
C_rs_c = self.render_radiance_volume(r_ts_c, ds, self.F_c, t_is_c)
return C_rs_c
def load_data(device):
data_f = "66bdbc812bd0a196e194052f3f12cb2e.npz"
data = np.load(data_f)
images = data["images"] / 255
img_size = images.shape[1]
xs = torch.arange(img_size) - (img_size / 2 - 0.5)
ys = torch.arange(img_size) - (img_size / 2 - 0.5)
(xs, ys) = torch.meshgrid(xs, -ys, indexing="xy")
focal = float(data["focal"])
pixel_coords = torch.stack([xs, ys, torch.full_like(xs, -focal)], dim=-1)
camera_coords = pixel_coords / focal
init_ds = camera_coords.to(device)
init_o = torch.Tensor(np.array([0, 0, float(data["camera_distance"])])).to(device)
return (images, data["poses"], init_ds, init_o, img_size)
def set_up_test_data(images, device, poses, init_ds, init_o):
test_idx = 150
plt.imshow(images[test_idx])
plt.show()
test_img = torch.Tensor(images[test_idx]).to(device)
test_R = torch.Tensor(poses[test_idx, :3, :3]).to(device)
test_ds = torch.einsum("ij,hwj->hwi", test_R, init_ds)
test_os = (test_R @ init_o).expand(test_ds.shape)
train_idxs = np.arange(len(images)) != test_idx
return (test_ds, test_os, test_img, train_idxs)
def main():
seed = 9458
torch.manual_seed(seed)
np.random.seed(seed)
device = "cuda:0"
nerf = VeryTinyNeRF(device)
lr = 5e-3
optimizer = optim.Adam(nerf.F_c.parameters(), lr=lr)
criterion = nn.MSELoss()
    (images, poses, init_ds, init_o, img_size) = load_data(device)
(test_ds, test_os, test_img, train_idxs) = set_up_test_data(
images, device, poses, init_ds, init_o
)
images = torch.Tensor(images[train_idxs])
poses = torch.Tensor(poses[train_idxs])
psnrs = []
iternums = []
num_iters = 20000
display_every = 100
nerf.F_c.train()
for i in range(num_iters):
target_img_idx = np.random.randint(images.shape[0])
target_pose = poses[target_img_idx].to(device)
R = target_pose[:3, :3]
ds = torch.einsum("ij,hwj->hwi", R, init_ds)
os = (R @ init_o).expand(ds.shape)
C_rs_c = nerf(ds, os)
loss = criterion(C_rs_c, images[target_img_idx].to(device))
optimizer.zero_grad()
loss.backward()
optimizer.step()
if i % display_every == 0:
nerf.F_c.eval()
with torch.no_grad():
C_rs_c = nerf(test_ds, test_os)
loss = criterion(C_rs_c, test_img)
print(f"Loss: {loss.item()}")
psnr = -10.0 * torch.log10(loss)
psnrs.append(psnr.item())
iternums.append(i)
plt.figure(figsize=(10, 4))
plt.subplot(121)
plt.imshow(C_rs_c.detach().cpu().numpy())
plt.title(f"Iteration {i}")
plt.subplot(122)
plt.plot(iternums, psnrs)
plt.title("PSNR")
plt.show()
nerf.F_c.train()
print("Done!")
if __name__ == "__main__":
main()
================================================
SYMBOL INDEX (92 symbols across 11 files)
================================================
FILE: generate_nerf_dataset.py
function main (line 10) | def main():
FILE: generate_pixelnerf_dataset.py
function main (line 12) | def main():
FILE: image_encoder.py
class ImageEncoder (line 8) | class ImageEncoder(nn.Module):
method __init__ (line 9) | def __init__(self):
method forward (line 13) | def forward(self, x):
FILE: pixelnerf_dataset.py
class PixelNeRFDataset (line 7) | class PixelNeRFDataset(Dataset):
method __init__ (line 8) | def __init__(
method __len__ (line 48) | def __len__(self):
method __getitem__ (line 51) | def __getitem__(self, idx):
FILE: renderer.py
function gen_rotation_matrix_from_cam_pos (line 14) | def gen_rotation_matrix_from_cam_pos(xyz, in_plane=0.0):
function gen_rotation_matrix_from_azim_elev_in_plane (line 33) | def gen_rotation_matrix_from_azim_elev_in_plane(
function parse_obj_file (line 57) | def parse_obj_file(input_obj):
function parse_mtl_file (line 144) | def parse_mtl_file(input_mtl):
function get_texture_data (line 186) | def get_texture_data(sub_objs, packed_arrays, mtl_infos, obj_f):
class Renderer (line 232) | class Renderer:
method __init__ (line 233) | def __init__(
method set_up_obj (line 420) | def set_up_obj(self, obj_f, mtl_f):
method load_obj (line 430) | def load_obj(self, packed_arrays, vertices, mtl_infos, sub_objs, textu...
method set_up_background (line 466) | def set_up_background(self, background_f=None):
method render (line 512) | def render(self, r=0.485, g=0.456, b=0.406, with_alpha=False):
method get_vertex_screen_coordinates (line 564) | def get_vertex_screen_coordinates(self):
method __del__ (line 599) | def __del__(self):
method release_obj (line 602) | def release_obj(self):
method release_background (line 613) | def release_background(self):
method release (line 620) | def release(self):
method adjust_angle_of_view (line 628) | def adjust_angle_of_view(self, angle_of_view):
method set_params (line 635) | def set_params(self, params):
method get_depth_arrays (line 658) | def get_depth_arrays(self):
method get_depth_map (line 668) | def get_depth_map(self):
method get_normal_map (line 673) | def get_normal_map(self):
FILE: run_nerf.py
function get_coarse_query_points (line 8) | def get_coarse_query_points(ds, N_c, t_i_c_bin_edges, t_i_c_gap, os):
function get_fine_query_points (line 18) | def get_fine_query_points(w_is_c, N_f, t_is_c, t_f, os, ds):
function render_radiance_volume (line 51) | def render_radiance_volume(r_ts, ds, chunk_size, F, t_is):
function run_one_iter_of_nerf (line 107) | def run_one_iter_of_nerf(
class NeRFMLP (line 119) | class NeRFMLP(nn.Module):
method __init__ (line 120) | def __init__(self):
method forward (line 155) | def forward(self, xs, ds):
function main (line 184) | def main():
FILE: run_nerf_alt.py
class NeRFMLP (line 8) | class NeRFMLP(nn.Module):
method __init__ (line 9) | def __init__(self):
method forward (line 44) | def forward(self, xs, ds):
class NeRF (line 73) | class NeRF:
method __init__ (line 74) | def __init__(self, device):
method get_coarse_query_points (line 96) | def get_coarse_query_points(self, ds, os):
method get_fine_query_points (line 106) | def get_fine_query_points(self, w_is_c, t_is_c, os, ds):
method render_radiance_volume (line 138) | def render_radiance_volume(self, r_ts, ds, F, t_is):
method __call__ (line 193) | def __call__(self, ds, os):
function load_data (line 203) | def load_data(device):
function set_up_test_data (line 226) | def set_up_test_data(images, device, poses, init_ds, init_o):
function main (line 241) | def main():
FILE: run_pixelnerf.py
function get_coarse_query_points (line 10) | def get_coarse_query_points(ds, N_c, t_i_c_bin_edges, t_i_c_gap, os):
function get_fine_query_points (line 17) | def get_fine_query_points(w_is_c, N_f, t_is_c, t_f, os, ds, r_ts_c, N_d,...
function get_image_features_for_query_points (line 50) | def get_image_features_for_query_points(r_ts, camera_distance, scale, W_i):
function render_radiance_volume (line 73) | def render_radiance_volume(r_ts, ds, z_is, chunk_size, F, t_is):
function run_one_iter_of_pixelnerf (line 109) | def run_one_iter_of_pixelnerf(
class PixelNeRFFCResNet (line 141) | class PixelNeRFFCResNet(nn.Module):
method __init__ (line 142) | def __init__(self):
method forward (line 176) | def forward(self, xs, ds, zs):
function load_data (line 209) | def load_data():
function set_up_test_data (line 223) | def set_up_test_data(train_dataset, device):
function main (line 258) | def main():
FILE: run_pixelnerf_alt.py
class PixelNeRFFCResNet (line 10) | class PixelNeRFFCResNet(nn.Module):
method __init__ (line 11) | def __init__(self):
method forward (line 45) | def forward(self, xs, ds, zs):
class PixelNeRF (line 78) | class PixelNeRF:
method __init__ (line 79) | def __init__(self, device, camera_distance, scale):
method get_coarse_query_points (line 104) | def get_coarse_query_points(self, ds, os):
method get_fine_query_points (line 110) | def get_fine_query_points(self, w_is_c, t_is_c, os, ds, r_ts_c):
method get_image_features_for_query_points (line 142) | def get_image_features_for_query_points(self, r_ts, W_i):
method render_radiance_volume (line 164) | def render_radiance_volume(self, r_ts, ds, z_is, F, t_is):
method __call__ (line 199) | def __call__(self, ds, os, source_image):
function load_data (line 217) | def load_data():
function set_up_test_data (line 232) | def set_up_test_data(train_dataset, device):
function main (line 266) | def main():
FILE: run_tiny_nerf.py
function get_coarse_query_points (line 8) | def get_coarse_query_points(ds, N_c, t_i_c_bin_edges, t_i_c_gap, os):
function render_radiance_volume (line 15) | def render_radiance_volume(r_ts, ds, chunk_size, F, t_is):
function run_one_iter_of_tiny_nerf (line 49) | def run_one_iter_of_tiny_nerf(ds, N_c, t_i_c_bin_edges, t_i_c_gap, os, c...
class VeryTinyNeRFMLP (line 55) | class VeryTinyNeRFMLP(nn.Module):
method __init__ (line 56) | def __init__(self):
method forward (line 77) | def forward(self, xs, ds):
function main (line 99) | def main():
FILE: run_tiny_nerf_alt.py
class VeryTinyNeRFMLP (line 8) | class VeryTinyNeRFMLP(nn.Module):
method __init__ (line 9) | def __init__(self):
method forward (line 30) | def forward(self, xs, ds):
class VeryTinyNeRF (line 52) | class VeryTinyNeRF:
method __init__ (line 53) | def __init__(self, device):
method get_coarse_query_points (line 62) | def get_coarse_query_points(self, ds, os):
method render_radiance_volume (line 68) | def render_radiance_volume(self, r_ts, ds, F, t_is):
method __call__ (line 101) | def __call__(self, ds, os):
function load_data (line 107) | def load_data(device):
function set_up_test_data (line 125) | def set_up_test_data(images, device, poses, init_ds, init_o):
function main (line 139) | def main():