Repository: zeng-yifei/STAG4D Branch: main Commit: 9aa21c92f40b Files: 25 Total size: 257.0 KB Directory structure: gitextract_687zu6g9/ ├── .gitignore ├── README.md ├── __init__.py ├── cam_utils.py ├── configs/ │ └── stag4d.yaml ├── dataset_4d.py ├── deform.py ├── gs_renderer_4d.py ├── guidance/ │ ├── zero123_4d_utils.py │ └── zero123pp/ │ └── pipeline.py ├── main.py ├── mini_trainer.ipynb ├── requirements.txt ├── scripts/ │ ├── app.py │ └── gen_mv.py ├── sh_utils.py ├── simple-knn/ │ ├── ext.cpp │ ├── setup.py │ ├── simple_knn/ │ │ └── .gitkeep │ ├── simple_knn.cu │ ├── simple_knn.h │ ├── spatial.cu │ └── spatial.h ├── visualize.py └── zero123.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ *.pyc ./valid/* ================================================ FILE: README.md ================================================
# Text-to-4D
For text-to-4D generation, we recommend using SDXL and SVD to generate a reasonable video. Then, after matting the video, use
the command above to generate a good 4D result. (This pipeline contains many independent parts and is fairly complex, so we may upload the whole workflow after integration if possible.)
If you want to generate the examples in the paper, the corresponding data is also available on [google drive](https://drive.google.com/file/d/1EDNL7EBMR1vlfMOABdXjHzcKY7IXdcnj/view?usp=sharing). Remember to set size to 26 in the config or pass `size=26` on the command line:
```bash
python main.py --config configs/stag4d.yaml path=dataset/xxx save_path=xxx size=26
```
# Tips for better quality
If you are willing to trade time for better quality, here are some tips you can try to further improve the generated results.
1. Use a larger batch size.
2. Run for more steps.
## Citation
If you find our work useful for your research, please consider citing our paper as well as Consistent4D:
```
@article{zeng2024stag4d,
title={STAG4D: Spatial-Temporal Anchored Generative 4D Gaussians},
author={Yifei Zeng and Yanqin Jiang and Siyu Zhu and Yuanxun Lu and Youtian Lin and Hao Zhu and Weiming Hu and Xun Cao and Yao Yao},
year={2024},
eprint={2403.14939},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
@article{jiang2023consistent4d,
title={Consistent4D: Consistent 360{\deg} Dynamic Object Generation from Monocular Video},
author={Yanqin Jiang and Li Zhang and Jin Gao and Weimin Hu and Yao Yao},
year={2023},
eprint={2311.02848},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```
# Acknowledgment
This repo is built on [DreamGaussian](https://github.com/dreamgaussian/dreamgaussian) and [Zero123plus](https://github.com/SUDO-AI-3D/zero123plus). Thank all the authors for their great work.
================================================
FILE: __init__.py
================================================
================================================
FILE: cam_utils.py
================================================
import numpy as np
from scipy.spatial.transform import Rotation as R
import torch
def dot(x, y):
    """Return the dot product of ``x`` and ``y`` along the last axis.

    The reduced axis is kept with size 1. NumPy is used when ``x`` is an
    ``np.ndarray``; any other input is handled with torch.
    """
    product = x * y
    if not isinstance(x, np.ndarray):
        return torch.sum(product, -1, keepdim=True)
    return np.sum(product, -1, keepdims=True)
def length(x, eps=1e-20):
    """Return the Euclidean norm of ``x`` along the last axis (axis kept).

    The squared norm is floored at ``eps`` before the square root, so the
    result is never smaller than ``sqrt(eps)``.
    """
    if not isinstance(x, np.ndarray):
        # Non-NumPy inputs go through the torch path.
        return torch.sqrt(torch.clamp(dot(x, x), min=eps))
    squared_norm = np.sum(x * x, axis=-1, keepdims=True)
    return np.sqrt(np.maximum(squared_norm, eps))
def safe_normalize(x, eps=1e-20):
    """Divide ``x`` by its eps-floored length along the last axis."""
    norm = length(x, eps)
    return x / norm
def look_at(campos, target, opengl=True):
    """Build a rotation matrix for a camera at ``campos`` looking at ``target``.

    The original annotations describe ``campos``/``target`` as ``[N, 3]`` and
    the result as ``[N, 3, 3]``; ``orbit_camera`` in this file passes single
    ``[3]`` vectors. The right, up and forward vectors are stacked along
    axis 1 of the result.

    With ``opengl=True`` the forward vector points from the target to the
    camera; otherwise it points from the camera to the target.
    """
    world_up = np.array([0, 1, 0], dtype=np.float32)
    if opengl:
        forward_vector = safe_normalize(campos - target)
        right_vector = safe_normalize(np.cross(world_up, forward_vector))
        up_vector = safe_normalize(np.cross(forward_vector, right_vector))
    else:
        forward_vector = safe_normalize(target - campos)
        right_vector = safe_normalize(np.cross(forward_vector, world_up))
        up_vector = safe_normalize(np.cross(right_vector, forward_vector))
    return np.stack([right_vector, up_vector, forward_vector], axis=1)
# elevation & azimuth to pose (cam2world) matrix
def orbit_camera(elevation, azimuth, radius=1, is_degree=True, target=None, opengl=True):
    """Return a ``[4, 4]`` float32 pose matrix for a camera orbiting ``target``.

    Args:
        elevation: scalar elevation angle; positive values give a negative
            camera ``y`` offset.
        azimuth: scalar azimuth angle; 0 places the camera on ``+z`` and
            90 degrees on ``+x``.
        radius: distance from the camera to ``target``.
        is_degree: when true, both angles are converted from degrees to radians.
        target: point to look at; defaults to the origin.
        opengl: forwarded to ``look_at``.
    """
    if is_degree:
        elevation = np.deg2rad(np.array(elevation))
        azimuth = np.deg2rad(np.array(azimuth))
    horizontal = radius * np.cos(elevation)
    offset = np.array([
        horizontal * np.sin(azimuth),
        -radius * np.sin(elevation),
        horizontal * np.cos(azimuth),
    ])
    if target is None:
        target = np.zeros([3], dtype=np.float32)
    campos = offset + target  # [3]
    pose = np.eye(4, dtype=np.float32)
    pose[:3, :3] = look_at(campos, target, opengl)
    pose[:3, 3] = campos
    return pose
class OrbitCamera:
    """Interactive orbit camera described by a radius, a rotation and a centre.

    The pose, view, projection and intrinsics are exposed as properties that
    are recomputed from the current state on every access.
    """

    def __init__(self, W, H, r=2, fovy=60, near=0.01, far=100):
        self.W, self.H = W, H
        self.radius = r  # camera distance from center
        self.fovy = np.deg2rad(fovy)  # stored in radians
        self.near, self.far = near, far
        self.center = np.zeros(3, dtype=np.float32)  # look at this point
        self.rot = R.from_matrix(np.eye(3))
        self.up = np.array([0, 1, 0], dtype=np.float32)  # need to be normalized!

    @property
    def fovx(self):
        """Horizontal field of view in radians, derived from fovy and W/H."""
        half_tan = np.tan(self.fovy / 2) * self.W / self.H
        return 2 * np.arctan(half_tan)

    @property
    def campos(self):
        """Translation column of the current pose matrix."""
        return self.pose[:3, 3]

    @property
    def pose(self):
        """Camera pose (c2w) as a ``[4, 4]`` float32 matrix."""
        # Start with the camera pushed out along z by the radius.
        offset = np.eye(4, dtype=np.float32)
        offset[2, 3] = self.radius
        # Apply the current orbit rotation.
        rotation = np.eye(4, dtype=np.float32)
        rotation[:3, :3] = self.rot.as_matrix()
        c2w = rotation @ offset
        # Shift by the look-at centre.
        c2w[:3, 3] -= self.center
        return c2w

    @property
    def view(self):
        """View matrix (w2c): the inverse of ``pose``."""
        return np.linalg.inv(self.pose)

    @property
    def perspective(self):
        """Perspective projection matrix as a ``[4, 4]`` float32 array."""
        tan_half = np.tan(self.fovy / 2)
        aspect = self.W / self.H
        depth_range = self.far - self.near
        proj = np.zeros((4, 4), dtype=np.float32)
        proj[0, 0] = 1 / (tan_half * aspect)
        proj[1, 1] = -1 / tan_half
        proj[2, 2] = -(self.far + self.near) / depth_range
        proj[2, 3] = -(2 * self.far * self.near) / depth_range
        proj[3, 2] = -1
        return proj

    @property
    def intrinsics(self):
        """Return ``[focal, focal, W // 2, H // 2]`` as float32."""
        focal = self.H / (2 * np.tan(self.fovy / 2))
        return np.array([focal, focal, self.W // 2, self.H // 2], dtype=np.float32)

    @property
    def mvp(self):
        """Projection matrix multiplied by the inverse pose, ``[4, 4]``."""
        world_to_cam = np.linalg.inv(self.pose)
        return self.perspective @ world_to_cam

    def orbit(self, dx, dy):
        """Rotate about the up axis by ``dx`` and the camera side axis by ``dy``."""
        side_axis = self.rot.as_matrix()[:3, 0]
        yaw = R.from_rotvec(self.up * np.radians(-0.05 * dx))
        pitch = R.from_rotvec(side_axis * np.radians(-0.05 * dy))
        self.rot = yaw * pitch * self.rot

    def scale(self, delta):
        """Zoom by scaling the radius with ``1.1 ** (-delta)``."""
        self.radius *= 1.1 ** (-delta)

    def pan(self, dx, dy, dz=0):
        """Move the look-at centre in camera coordinates (low sensitivity)."""
        scaled_rot = 0.0005 * self.rot.as_matrix()[:3, :3]
        self.center += scaled_rot @ np.array([-dx, -dy, dz])
================================================
FILE: configs/stag4d.yaml
================================================
### Input
# input rgba image path (default to None, can be load in GUI too)
input:
# input text prompt (default to None, can be input in GUI too)
prompt: a minion
# input mesh for stage 2 (auto-search from stage 1 output path if None)
mesh:
# estimated elevation angle for input image
elevation: 0
# reference image resolution
ref_size: 512
# density thresh for mesh extraction
density_thresh: 1
device: cuda
#dynamic
size: 30
path: dataset/minions
# checkpoint to load for stage 1 (should be a ply file)
load:
### Output
outdir: logs
mesh_format: obj
save_path: ???
save_step: 8000
#checkpoint to load for stage fine (should be a path of ply with deform pth)
load_path:
load_step:
valid_interval: 500
### Training
# guidance loss weights (0 to disable)
lambda_sd: 0
mvdream: False
lambda_zero123: 1
lambda_tv: 1
scale_loss_ratio: 7.5
imagedream: False
# training batch size per iter
batch_size: 4
# training iterations for stage 1
iters: 2000
# training iterations for stage 2
iters_refine: 50
# training camera radius
radius: 2
# training camera fovy
fovy: 49.1 # align with zero123 rendering setting (ref: https://github.com/cvlab-columbia/zero123/blob/main/objaverse-rendering/scripts/blender_script.py#L61 )
# whether allow geom training in stage 2
train_geo: False
# prob to invert background color during training (0 = always black, 1 = always white)
invert_bg_prob: 0.5
### GUI
gui: False
force_cuda_rast: False
# GUI resolution
H: 800
W: 800
deformation_lr_init : 0.00016
deformation_lr_final : 0.000016
deformation_lr_delay_mult : 0.02
grid_lr_init : 0.0016
grid_lr_final : 0.00016
### Gaussian splatting
num_pts: 10000
sh_degree: 0
position_lr_init : 0.0002
position_lr_final : 0.000002
position_lr_delay_mult: 0.01
position_lr_max_steps: 2000
position_lr_max_steps2: 5000
feature_lr: 0.005
opacity_lr: 0.02
scaling_lr: 0.01
rotation_lr: 0.002
init_steps: 700
percent_dense: 0.1
density_start_iter: 1200
density_end_iter: 6000
densification_interval: 100
opacity_reset_interval: 700
densify_grad_threshold_percent: 0.025
time_smoothness_weight: 5
plane_tv_weight: 0.05
l1_time_planes: 0.05
### Textured Mesh
geom_lr: 0.0001
texture_lr: 0.2
================================================
FILE: dataset_4d.py
================================================
import os
import cv2
import glob
import json
import tqdm
import random
import numpy as np
from scipy.spatial.transform import Slerp, Rotation
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import rembg
import glob
class SparseDataset:
    """Loader for per-frame reference images and their zero123 views.

    Images are read from ``<opt.path>/ref/<i>_rgba.png`` and
    ``<opt.path>/zero123/<i>_rgba/*.png``, resized to ``(W, H)``, composited
    on a white background and converted from BGR to RGB.
    """

    def __init__(self, opt, size,device='cuda', type='train', H=256, W=256):
        super().__init__()
        self.opt = opt
        self.device = device
        self.type = type # train, val, test
        self.size = size
        self.H = H
        self.W = W
        self.path = opt.path
        # NOTE(review): cx is derived from H and cy from W; this only matters
        # when H != W -- confirm against the consumers of cx/cy.
        self.cx = self.H / 2
        self.cy = self.W / 2
        # rembg session, created lazily the first time a 3-channel image is seen.
        self.bg_remover=None

    def _load_rgba(self, file):
        """Read one image and return ``(rgb_on_white, alpha_mask)`` as float32.

        Raises:
            FileNotFoundError: if OpenCV cannot read ``file``.
        """
        img = cv2.imread(file, cv2.IMREAD_UNCHANGED)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"failed to read image: {file}")
        if img.shape[-1] == 3:
            # No alpha channel: estimate one with rembg.
            if self.bg_remover is None:
                self.bg_remover = rembg.new_session()
            img = rembg.remove(img, session=self.bg_remover)
        img = cv2.resize(img, (self.W, self.H), interpolation=cv2.INTER_AREA)
        img = img.astype(np.float32) / 255.0
        mask = img[..., 3:]
        # white bg
        rgb = img[..., :3] * mask + (1 - mask)
        # bgr to rgb
        rgb = rgb[..., ::-1].copy()
        return rgb, mask

    def collate_ref(self,index):
        """Load the reference frame ``index``; returns ``(image, mask)``."""
        file = os.path.join(self.path,'ref/{}_rgba.png'.format(str(index)))
        # The attributes are kept because the original exposed them as state.
        self.input_img, self.input_mask = self._load_rgba(file)
        return self.input_img ,self.input_mask

    def collate_zero123(self,index):
        """Load all zero123 views of frame ``index`` in sorted filename order.

        Returns ``(images, masks)`` as two parallel lists.
        """
        self.pattern=os.path.join(self.path,'zero123/{}_rgba/*.png'.format(str(index)))
        self.input_imgs=[]
        self.input_masks=[]
        for files in sorted(glob.glob(self.pattern)):
            self.input_img, self.input_mask = self._load_rgba(files)
            self.input_imgs.append(self.input_img)
            self.input_masks.append(self.input_mask)
        return self.input_imgs, self.input_masks

    def collate(self, index):
        """Load every frame in ``range(self.size)``.

        NOTE: the ``index`` argument supplied by the DataLoader is ignored;
        each call loads the full sequence (behaviour kept from the original).
        """
        ref_view_batch,input_mask_batch,zero123_view_batch,zero123_masks_batch = [],[],[],[]
        for index in np.arange(self.size):
            ref_view,input_mask = self.collate_ref(index)
            zero123_view,zero123_masks = self.collate_zero123(index)
            ref_view_batch.append(ref_view)
            input_mask_batch.append(input_mask)
            zero123_view_batch.append(zero123_view)
            zero123_masks_batch.append(zero123_masks)
        return ref_view_batch, input_mask_batch,zero123_view_batch,zero123_masks_batch

    def dataloader(self):
        """Return a DataLoader over frame indices that uses ``collate``."""
        loader = DataLoader(list(range(self.size)), batch_size=1, collate_fn=self.collate,shuffle=False, num_workers=0)
        return loader

    def dataloader_d(self):
        """Return a DataLoader that uses ``collate_d``.

        Raises:
            AttributeError: ``collate_d`` is not defined on this class; a
                subclass has to provide it.
        """
        collate_fn = getattr(self, "collate_d", None)
        if collate_fn is None:
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute 'collate_d'; "
                "define it in a subclass or use dataloader() instead"
            )
        loader = DataLoader(list(range(self.size)), batch_size=1, collate_fn=collate_fn,shuffle=False, num_workers=0)
        return loader
================================================
FILE: deform.py
================================================
import functools
import math
import os
import time
from tkinter import W
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.cpp_extension import load
import torch.nn.init as init
import abc
import itertools
import logging as log
from typing import Optional, Union, List, Dict, Sequence, Iterable, Collection, Callable
import torch
import torch.nn as nn
import torch.nn.functional as F
class Deformation(nn.Module):
    """Deformation field: HexPlane features followed by small MLP heads.

    ``forward`` returns deformed positions, scales, rotations and opacity for
    the given time input. Each quantity is the input plus a predicted delta
    unless the matching ``no_*`` flag is set.
    """

    def __init__(self, D=8, W=256, input_ch=27, input_ch_time=9, skips=None, args=None):
        super(Deformation, self).__init__()
        self.D = D
        self.W = W
        self.input_ch = input_ch
        self.input_ch_time = input_ch_time
        # A fresh list per instance avoids sharing a mutable default.
        self.skips = [] if skips is None else skips
        self.no_grid=False
        self.no_ds=False
        self.no_dr=False
        self.no_do=True
        self.bounds = 1.6
        self.kplanes_config = {
            'grid_dimensions': 2,
            'input_coordinate_dim': 4,
            'output_coordinate_dim': 32,
            'resolution': [64, 64, 64, 25]
        }
        self.multires = [1, 2, 4, 8]
        self.grid = HexPlaneField(self.bounds, self.kplanes_config, self.multires)
        self.pos_deform, self.scales_deform, self.rotations_deform, self.opacity_deform = self.create_net()

    def create_net(self):
        """Build ``self.feature_out`` and return the four delta heads.

        Returns the position (3), scale (3), rotation (4) and opacity (1)
        heads, in that order. Layers are created in the same order as before
        so the random initialisation sequence is unchanged.
        """
        in_dim = 4 if self.no_grid else self.grid.feat_dim
        layers = [nn.Linear(in_dim, self.W)]
        for _ in range(self.D - 1):
            layers.append(nn.ReLU())
            layers.append(nn.Linear(self.W, self.W))
        self.feature_out = nn.Sequential(*layers)

        def make_head(out_dim):
            return nn.Sequential(nn.ReLU(), nn.Linear(self.W, self.W), nn.ReLU(), nn.Linear(self.W, out_dim))

        return make_head(3), make_head(3), make_head(4), make_head(1)

    def query_time(self, rays_pts_emb, scales_emb, rotations_emb, time_emb):
        """Return the hidden feature for each point at the given time."""
        if self.no_grid:
            h = torch.cat([rays_pts_emb[:,:3],time_emb[:,:1]],-1)
        else:
            h = self.grid(rays_pts_emb[:,:3], time_emb[:,:1])
        return self.feature_out(h)

    def forward(self, rays_pts_emb, scales_emb=None, rotations_emb=None, opacity = None, time_emb=None):
        """Dispatch to the static path when ``time_emb`` is None, else dynamic."""
        if time_emb is None:
            return self.forward_static(rays_pts_emb[:,:3])
        return self.forward_dynamic(rays_pts_emb, scales_emb, rotations_emb, opacity, time_emb)

    def forward_static(self, rays_pts_emb):
        """Static (time-free) deformation of point positions.

        NOTE(review): this path looks unfinished. ``self.static_mlp`` is never
        created by this class, and ``HexPlaneField.get_density`` concatenates
        the timestamps, which are None here -- confirm before relying on it.
        """
        grid_feature = self.grid(rays_pts_emb[:,:3])
        dx = self.static_mlp(grid_feature)
        return rays_pts_emb[:, :3] + dx

    def forward_dynamic(self,rays_pts_emb, scales_emb, rotations_emb, opacity_emb, time_emb):
        """Return ``(pts, scales, rotations, opacity)`` after deformation."""
        hidden = self.query_time(rays_pts_emb, scales_emb, rotations_emb, time_emb).float()
        dx = self.pos_deform(hidden)
        pts = rays_pts_emb[:, :3] + dx
        if self.no_ds:
            scales = scales_emb[:,:3]
        else:
            ds = self.scales_deform(hidden)
            scales = scales_emb[:,:3] + ds
        if self.no_dr:
            rotations = rotations_emb[:,:4]
        else:
            dr = self.rotations_deform(hidden)
            rotations = rotations_emb[:,:4] + dr
        if self.no_do:
            opacity = opacity_emb[:,:1]
        else:
            do = self.opacity_deform(hidden)
            opacity = opacity_emb[:,:1] + do
        return pts, scales, rotations, opacity

    def get_mlp_parameters(self):
        """Return every parameter whose name does not contain ``"grid"``."""
        return [param for name, param in self.named_parameters() if "grid" not in name]

    def get_grid_parameters(self):
        """Return the parameters of the HexPlane grid."""
        return list(self.grid.parameters() )
        # + list(self.timegrid.parameters())
class deform_network(nn.Module):
    """Wrapper module that owns the ``Deformation`` net and frequency buffers.

    It also owns a small ``timenet`` MLP. The visible forward paths do not
    call it, but its parameters are included in ``get_mlp_parameters``.
    """

    def __init__(self):
        super().__init__()
        hidden_width, mlp_depth = 64, 1
        time_freqs, pos_freqs = 4, 10
        scale_rot_freqs, opacity_freqs = 2, 2
        timenet_hidden, timenet_out = 64, 32

        self.timenet = nn.Sequential(
            nn.Linear(2 * time_freqs + 1, timenet_hidden), nn.ReLU(),
            nn.Linear(timenet_hidden, timenet_out))
        self.deformation_net = Deformation(
            W=hidden_width,
            D=mlp_depth,
            input_ch=(4 + 3) + ((4 + 3) * scale_rot_freqs) * 2,
            input_ch_time=timenet_out,
            args=None,
        )
        # Powers of two used as frequency bases, stored as non-trainable buffers.
        for buffer_name, count in (
            ('time_poc', time_freqs),
            ('pos_poc', pos_freqs),
            ('rotation_scaling_poc', scale_rot_freqs),
            ('opacity_poc', opacity_freqs),
        ):
            self.register_buffer(buffer_name, torch.FloatTensor([(2**i) for i in range(count)]))
        self.apply(initialize_weights)

    def forward(self, point, scales=None, rotations=None, opacity=None, times_sel=None):
        """Use the static path when ``times_sel`` is None, else the dynamic one."""
        if times_sel is None:
            return self.forward_static(point)
        return self.forward_dynamic(point, scales, rotations, opacity, times_sel)

    def forward_static(self, points):
        """Run the deformation net on points only."""
        return self.deformation_net(points)

    def forward_dynamic(self, point, scales=None, rotations=None, opacity=None, times_sel=None):
        """Return deformed ``(means3D, scales, rotations, opacity)``."""
        deformed = self.deformation_net(point, scales, rotations, opacity, times_sel)
        means3D, scales, rotations, opacity = deformed
        return means3D, scales, rotations, opacity

    def get_mlp_parameters(self):
        """MLP parameters of the deformation net plus the timenet parameters."""
        timenet_params = list(self.timenet.parameters())
        return self.deformation_net.get_mlp_parameters() + timenet_params

    def get_grid_parameters(self):
        """Grid parameters of the deformation net."""
        return self.deformation_net.get_grid_parameters()
def initialize_weights(m):
    """Xavier-uniform initialise the weight of ``nn.Linear`` modules.

    Other module types are left untouched. Intended for ``Module.apply``.
    """
    if not isinstance(m, nn.Linear):
        return
    init.xavier_uniform_(m.weight, gain=1)
    if m.bias is None:
        return
    # NOTE(review): when a bias exists the *weight* is initialised a second
    # time and the bias itself is not modified. Kept as-is so the random
    # stream and resulting weights are unchanged -- confirm the intent.
    init.xavier_uniform_(m.weight, gain=1)
def get_normalized_directions(directions):
    """Map values from ``[-1, 1]`` to ``[0, 1]`` (SH encoding expects ``[0, 1]``).

    Args:
        directions: batch of directions
    """
    shifted = directions + 1.0
    return shifted / 2.0
def normalize_aabb(pts, aabb):
    """Affinely rescale ``pts`` so that ``aabb[0]`` maps to -1 and ``aabb[1]`` to +1."""
    first_corner, second_corner = aabb[0], aabb[1]
    scale = 2.0 / (second_corner - first_corner)
    return (pts - first_corner) * scale - 1.0
def grid_sample_wrapper(grid: torch.Tensor, coords: torch.Tensor, align_corners: bool = True) -> torch.Tensor:
    """Sample ``grid`` at ``coords`` with ``F.grid_sample``.

    Uses bilinear mode and border padding. A missing batch dimension is added
    to ``grid`` and to 2-D ``coords``. The result is reshaped to
    ``[B, n, feature_dim]`` and then squeezed, so size-1 dimensions are
    dropped.

    Raises:
        NotImplementedError: if the last dimension of ``coords`` is not 2 or 3.
    """
    coord_dim = coords.shape[-1]
    if grid.dim() == coord_dim + 1:
        # no batch dimension present, need to add it
        grid = grid.unsqueeze(0)
    if coords.dim() == 2:
        coords = coords.unsqueeze(0)
    if coord_dim not in (2, 3):
        raise NotImplementedError(f"Grid-sample was called with {coord_dim}D data but is only "
                                  f"implemented for 2 and 3D data.")

    # Insert singleton spatial dims so coords is [B, 1, ..., n, coord_dim].
    sample_shape = [coords.shape[0]] + [1] * (coord_dim - 1) + list(coords.shape[1:])
    coords = coords.view(sample_shape)
    batch, channels = grid.shape[:2]
    num_points = coords.shape[-2]
    sampled = F.grid_sample(
        grid,    # [B, feature_dim, reso, ...]
        coords,  # [B, 1, ..., n, grid_dim]
        align_corners=align_corners,
        mode='bilinear', padding_mode='border')
    sampled = sampled.view(batch, channels, num_points).transpose(-1, -2)  # [B, n, feature_dim]
    return sampled.squeeze()  # [B?, n, feature_dim?]
def init_grid_param(
        grid_nd: int,
        in_dim: int,
        out_dim: int,
        reso: Sequence[int],
        a: float = 0.1,
        b: float = 0.5):
    """Create one feature plane per ``grid_nd``-combination of the input axes.

    Each plane has shape ``[1, out_dim, *reso]`` with the resolutions of its
    axes listed in reverse axis order. When ``in_dim == 4``, planes that
    include axis 3 (time) are initialised to ones. All other planes are drawn
    uniformly from ``[a, b]``.

    Returns:
        ``nn.ParameterList`` of planes, in ``itertools.combinations`` order.
    """
    assert in_dim == len(reso), "Resolution must have same number of elements as input-dimension"
    has_time_planes = in_dim == 4
    assert grid_nd <= in_dim

    planes = nn.ParameterList()
    for axes in itertools.combinations(range(in_dim), grid_nd):
        plane_shape = [1, out_dim, *(reso[axis] for axis in reversed(axes))]
        plane = nn.Parameter(torch.empty(plane_shape))
        if has_time_planes and 3 in axes:
            # Initialize time planes to 1
            nn.init.ones_(plane)
        else:
            nn.init.uniform_(plane, a=a, b=b)
        planes.append(plane)
    return planes
def interpolate_ms_features(pts: torch.Tensor,
                            ms_grids: Collection[Iterable[nn.Module]],
                            grid_dimensions: int,
                            concat_features: bool,
                            num_levels: Optional[int],
                            ) -> torch.Tensor:
    """Interpolate multi-scale plane features for ``pts``.

    For each scale, the features sampled from every coordinate plane are
    multiplied together. The per-scale results are then concatenated along
    the last dimension when ``concat_features`` is true, or summed otherwise.
    Only the first ``num_levels`` scales are used; ``None`` means all of them.
    """
    plane_axes = list(itertools.combinations(range(pts.shape[-1]), grid_dimensions))
    if num_levels is None:
        num_levels = len(ms_grids)

    per_scale = []
    summed = 0.
    for level_planes in ms_grids[:num_levels]:
        level_feature = 1.
        for plane_idx, axes in enumerate(plane_axes):
            plane = level_planes[plane_idx]
            channels = plane.shape[1]  # shape of plane: 1, out_dim, *reso
            sampled = grid_sample_wrapper(plane, pts[..., axes]).view(-1, channels)
            # compute product over planes
            level_feature = level_feature * sampled
        # combine over scales
        if concat_features:
            per_scale.append(level_feature)
        else:
            summed = summed + level_feature

    if concat_features:
        return torch.cat(per_scale, dim=-1)
    return summed
class HexPlaneField(nn.Module):
    """Multi-resolution plane feature field queried with points and timestamps."""

    def __init__(
        self,
        bounds,
        planeconfig,
        multires
    ) -> None:
        super().__init__()
        # Row 0 is (+bounds)*3 and row 1 is (-bounds)*3; kept frozen here.
        corners = torch.tensor([[bounds] * 3, [-bounds] * 3])
        self.aabb = nn.Parameter(corners, requires_grad=False)
        self.grid_config = [planeconfig]
        self.multiscale_res_multipliers = multires
        self.concat_features = True

        # 1. Init planes
        self.grids = nn.ModuleList()
        self.feat_dim = 0
        for multiplier in self.multiscale_res_multipliers:
            level_cfg = self.grid_config[0].copy()
            base_res = level_cfg["resolution"]
            # Resolution fix: multi-res only on spatial planes
            level_cfg["resolution"] = [r * multiplier for r in base_res[:3]] + base_res[3:]
            planes = init_grid_param(
                grid_nd=level_cfg["grid_dimensions"],
                in_dim=level_cfg["input_coordinate_dim"],
                out_dim=level_cfg["output_coordinate_dim"],
                reso=level_cfg["resolution"],
            )
            # shape[1] is out-dim - Concatenate over feature len for each scale
            level_dim = planes[-1].shape[1]
            if self.concat_features:
                self.feat_dim += level_dim
            else:
                self.feat_dim = level_dim
            self.grids.append(planes)
        print("feature_dim:",self.feat_dim)

    def set_aabb(self,xyz_max, xyz_min):
        """Replace the bounding box with ``[xyz_max, xyz_min]``.

        NOTE(review): the new parameter is created with ``requires_grad=True``
        while ``__init__`` freezes it -- confirm this is intended.
        """
        corners = torch.tensor([xyz_max, xyz_min])
        self.aabb = nn.Parameter(corners,requires_grad=True)
        print("Voxel Plane: set aabb=",self.aabb)

    def get_density(self, pts: torch.Tensor, timestamps: Optional[torch.Tensor] = None):
        """Return interpolated plane features for ``pts`` at ``timestamps``.

        Points are normalised with the bounding box, concatenated with the
        timestamps along the last dimension and flattened before the lookup.
        """
        normalized = normalize_aabb(pts, self.aabb)
        query = torch.cat((normalized, timestamps), dim=-1)  # [n_rays, n_samples, 4]
        query = query.reshape(-1, query.shape[-1])
        features = interpolate_ms_features(
            query, ms_grids=self.grids,  # noqa
            grid_dimensions=self.grid_config[0]["grid_dimensions"],
            concat_features=self.concat_features, num_levels=None)
        if len(features) < 1:
            features = torch.zeros((0, 1)).to(features.device)
        return features

    def forward(self,
                pts: torch.Tensor,
                timestamps: Optional[torch.Tensor] = None):
        """Alias for ``get_density``."""
        return self.get_density(pts, timestamps)
def compute_plane_tv(t):
    """Total-variation penalty over the last two axes of a 4-D tensor.

    Squared neighbour differences along each axis are summed, divided by the
    number of difference terms for that axis, added together and doubled.
    """
    batch_size, c, h, w = t.shape
    num_vertical = batch_size * c * (h - 1) * w
    num_horizontal = batch_size * c * h * (w - 1)
    vertical_diff = t[..., 1:, :] - t[..., :h-1, :]
    horizontal_diff = t[..., :, 1:] - t[..., :, :w-1]
    vertical_tv = torch.square(vertical_diff).sum() / num_vertical
    horizontal_tv = torch.square(horizontal_diff).sum() / num_horizontal
    return 2 * (vertical_tv + horizontal_tv)
def compute_plane_smoothness(t):
    """Mean squared second difference along dimension 2 of a 4-D tensor.

    The original comment identifies dimension 2 as the time axis.
    """
    batch_size, c, h, w = t.shape
    first_diff = t[..., 1:, :] - t[..., :h-1, :]                      # [batch, c, h-1, w]
    second_diff = first_diff[..., 1:, :] - first_diff[..., :h-2, :]   # [batch, c, h-2, w]
    magnitude = torch.abs(second_diff)
    return torch.square(magnitude).mean()
class Regularizer():
    """Base class for weighted regularisation terms.

    Subclasses implement ``_regularize``. ``regularize`` multiplies its result
    by ``weight`` and remembers a detached copy for ``report``.
    """

    def __init__(self, reg_type, initialization):
        self.reg_type = reg_type
        self.initialization = initialization
        self.weight = float(initialization)
        self.last_reg = None

    def step(self, global_step):
        """Hook called with the global step; does nothing by default."""

    def report(self, d):
        """Push the last regulariser value into ``d[self.reg_type]``, if any."""
        if self.last_reg is None:
            return
        d[self.reg_type].update(self.last_reg.item())

    def regularize(self, *args, **kwargs) -> torch.Tensor:
        """Return the weighted regularisation value and cache it for reporting."""
        weighted = self._regularize(*args, **kwargs) * self.weight
        self.last_reg = weighted.detach()
        return weighted

    @abc.abstractmethod
    def _regularize(self, *args, **kwargs) -> torch.Tensor:
        raise NotImplementedError()

    def __str__(self):
        return f"Regularizer({self.reg_type}, weight={self.weight})"
class PlaneTV(Regularizer):
    """Total-variation regulariser over the feature planes of a model."""

    def __init__(self, initial_value, what: str = 'field'):
        if what not in {'field', 'proposal_network'}:
            raise ValueError(f'what must be one of "field" or "proposal_network" '
                             f'but {what} was passed.')
        super().__init__(f'planeTV-{what[:2]}', initial_value)
        self.what = what

    def step(self, global_step):
        """No per-step update is needed."""

    def _regularize(self, model, **kwargs):
        """Sum ``compute_plane_tv`` over the selected planes of every scale.

        The planes at the "spatial" indices are accumulated first, then every
        plane of the scale is accumulated again, as in the original loop.
        """
        if self.what == 'field':
            multi_res_grids = model.field.grids
        elif self.what == 'proposal_network':
            multi_res_grids = [p.grids for p in model.proposal_networks]
        else:
            raise NotImplementedError(self.what)

        total = 0
        # Note: input to compute_plane_tv should be of shape [batch_size, c, h, w]
        for grids in multi_res_grids:
            # These are the spatial grids; the others are spatiotemporal
            spatial_ids = [0, 1, 2] if len(grids) == 3 else [0, 1, 3]
            for plane_id in spatial_ids:
                total += compute_plane_tv(grids[plane_id])
            for plane in grids:
                # plane: [1, c, h, w]
                total += compute_plane_tv(plane)
        return total
class TimeSmoothness(Regularizer):
    """Second-difference smoothness regulariser on the time planes."""

    def __init__(self, initial_value, what: str = 'field'):
        if what not in {'field', 'proposal_network'}:
            raise ValueError(f'what must be one of "field" or "proposal_network" '
                             f'but {what} was passed.')
        super().__init__(f'time-smooth-{what[:2]}', initial_value)
        self.what = what

    def _regularize(self, model, **kwargs) -> torch.Tensor:
        """Sum ``compute_plane_smoothness`` over planes 2, 4 and 5 of each scale.

        Scales that hold exactly three planes contribute nothing.
        """
        if self.what == 'field':
            multi_res_grids = model.field.grids
        elif self.what == 'proposal_network':
            multi_res_grids = [p.grids for p in model.proposal_networks]
        else:
            raise NotImplementedError(self.what)

        total = 0
        # model.grids is 6 x [1, rank * F_dim, reso, reso]
        for grids in multi_res_grids:
            time_ids = [] if len(grids) == 3 else [2, 4, 5]
            for plane_id in time_ids:
                total += compute_plane_smoothness(grids[plane_id])
        return torch.as_tensor(total)
class L1ProposalNetwork(Regularizer):
    """Mean-absolute-value regulariser on the proposal-network planes."""

    def __init__(self, initial_value):
        super().__init__('l1-proposal-network', initial_value)

    def _regularize(self, model, **kwargs) -> torch.Tensor:
        """Sum the mean absolute value of every plane of every proposal network."""
        total = 0.0
        for network in model.proposal_networks:
            for plane in network.grids:
                total += torch.abs(plane).mean()
        return torch.as_tensor(total)
class DepthTV(Regularizer):
    """Total-variation regulariser on a rendered depth map."""

    def __init__(self, initial_value):
        super().__init__('tv-depth', initial_value)

    def _regularize(self, model, model_out, **kwargs) -> torch.Tensor:
        """Apply ``compute_plane_tv`` to ``model_out['depth']`` viewed as 64x64."""
        depth_map = model_out['depth'].reshape(64, 64)
        # Add leading batch and channel axes: [1, 1, 64, 64].
        return compute_plane_tv(depth_map[None, None, :, :])
class L1TimePlanes(Regularizer):
    """L1 penalty that pulls the time planes towards 1."""

    def __init__(self, initial_value, what='field'):
        if what not in {'field', 'proposal_network'}:
            raise ValueError(f'what must be one of "field" or "proposal_network" '
                             f'but {what} was passed.')
        super().__init__(f'l1-time-{what[:2]}', initial_value)
        self.what = what

    def _regularize(self, model, **kwargs) -> torch.Tensor:
        """Sum ``mean(|1 - plane|)`` over planes 2, 4 and 5 of each scale.

        Scales that hold exactly three planes are skipped.
        """
        # model.grids is 6 x [1, rank * F_dim, reso, reso]
        if self.what == 'field':
            multi_res_grids = model.field.grids
        elif self.what == 'proposal_network':
            multi_res_grids = [p.grids for p in model.proposal_networks]
        else:
            raise NotImplementedError(self.what)

        total = 0.0
        for grids in multi_res_grids:
            if len(grids) == 3:
                continue
            # These are the spatiotemporal grids
            for plane_id in (2, 4, 5):
                total += torch.abs(1 - grids[plane_id]).mean()
        return torch.as_tensor(total)
================================================
FILE: gs_renderer_4d.py
================================================
import os
import math
import numpy as np
from typing import NamedTuple
from plyfile import PlyData, PlyElement
import torch
from torch import nn
from diff_gauss import (
GaussianRasterizationSettings,
GaussianRasterizer,
)
from simple_knn._C import distCUDA2
from sh_utils import eval_sh, SH2RGB, RGB2SH
from deform import *
def inverse_sigmoid(x):
    """Return the logit ``log(x / (1 - x))`` of a tensor."""
    odds = x / (1 - x)
    return torch.log(odds)
def get_expon_lr_func(
    lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
):
    """Return ``helper(step)`` giving a log-linearly interpolated learning rate.

    The rate moves from ``lr_init`` to ``lr_final`` over ``max_steps`` steps.
    When ``lr_delay_steps > 0`` it is additionally scaled by a sine ramp that
    starts at ``lr_delay_mult`` and reaches 1 after ``lr_delay_steps`` steps.
    """

    def helper(step):
        if lr_init == lr_final:
            # constant lr, ignore other params
            return lr_init
        if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
            # Disable this parameter
            return 0.0

        delay_rate = 1.0
        if lr_delay_steps > 0:
            # A kind of reverse cosine decay.
            ramp = np.clip(step / lr_delay_steps, 0, 1)
            delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
                0.5 * np.pi * ramp
            )

        t = np.clip(step / max_steps, 0, 1)
        log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
        return delay_rate * log_lerp

    return helper
def strip_lowerdiag(L):
    """Pack six entries of each 3x3 matrix in ``L`` into a ``[N, 6]`` tensor.

    The entries are taken in the order (0,0), (0,1), (0,2), (1,1), (1,2),
    (2,2). The output is float32 and is allocated on the same device as ``L``
    (previously it was always allocated on ``"cuda"``).
    """
    entries = ((0, 0), (0, 1), (0, 2), (1, 1), (1, 2), (2, 2))
    uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device=L.device)
    for column, (row, col) in enumerate(entries):
        uncertainty[:, column] = L[:, row, col]
    return uncertainty
def strip_symmetric(sym):
    """Pack symmetric 3x3 matrices into six values via ``strip_lowerdiag``."""
    packed = strip_lowerdiag(sym)
    return packed
def gaussian_3d_coeff(xyzs, covs):
    """Evaluate unnormalised 3-D Gaussian weights ``exp(-0.5 * x^T C^-1 x)``.

    Args:
        xyzs: ``[N, 3]`` offsets.
        covs: ``[N, 6]`` packed symmetric covariances, in the order
            (xx, xy, xz, yy, yz, zz).

    Positive exponents are treated as abnormal and replaced with ``-1e10``,
    which makes the weight 0.
    """
    x, y, z = (xyzs[:, axis] for axis in range(3))
    a, b, c, d, e, f = (covs[:, idx] for idx in range(6))

    # eps must be small enough !!!
    determinant = a * d * f + 2 * e * c * b - e**2 * a - c**2 * d - b**2 * f + 1e-24
    inv_det = 1 / determinant
    inv_a = (d * f - e**2) * inv_det
    inv_b = (e * c - b * f) * inv_det
    inv_c = (e * b - c * d) * inv_det
    inv_d = (a * f - c**2) * inv_det
    inv_e = (b * c - e * a) * inv_det
    inv_f = (a * d - b**2) * inv_det

    power = -0.5 * (x**2 * inv_a + y**2 * inv_d + z**2 * inv_f) - x * y * inv_b - x * z * inv_c - y * z * inv_e
    power.masked_fill_(power > 0, -1e10)  # abnormal values... make weights 0
    return torch.exp(power)
def build_rotation(r):
    """Convert quaternions ``[N, 4]`` in (w, x, y, z) order to rotation matrices.

    The quaternions are normalised first, so they need not be unit length.
    The ``[N, 3, 3]`` result is allocated on the same device as ``r``
    (previously it was always allocated on ``'cuda'``).
    """
    norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3])
    q = r / norm[:, None]

    R = torch.zeros((q.size(0), 3, 3), device=r.device)

    # w is the scalar part; x, y, z the vector part.
    w = q[:, 0]
    x = q[:, 1]
    y = q[:, 2]
    z = q[:, 3]

    R[:, 0, 0] = 1 - 2 * (y*y + z*z)
    R[:, 0, 1] = 2 * (x*y - w*z)
    R[:, 0, 2] = 2 * (x*z + w*y)
    R[:, 1, 0] = 2 * (x*y + w*z)
    R[:, 1, 1] = 1 - 2 * (x*x + z*z)
    R[:, 1, 2] = 2 * (y*z - w*x)
    R[:, 2, 0] = 2 * (x*z - w*y)
    R[:, 2, 1] = 2 * (y*z + w*x)
    R[:, 2, 2] = 1 - 2 * (x*x + y*y)
    return R
def build_scaling_rotation(s, r):
    """Return ``R @ diag(s)`` for per-point scales ``s`` and quaternions ``r``.

    Args:
        s: ``[N, 3]`` scale vectors.
        r: ``[N, 4]`` quaternions, converted with ``build_rotation``.

    The diagonal scale matrix is allocated on the same device as ``s``
    (previously it was always allocated on ``"cuda"``).
    """
    L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device=s.device)
    R = build_rotation(r)

    L[:,0,0] = s[:,0]
    L[:,1,1] = s[:,1]
    L[:,2,2] = s[:,2]

    L = R @ L
    return L
class BasicPointCloud(NamedTuple):
    """Container grouping the arrays that describe a point cloud."""
    # Array shapes are not fixed anywhere in this file; presumably [N, 3] each -- TODO confirm.
    points: np.ndarray
    colors: np.ndarray
    normals: np.ndarray
class GaussianModel:
def setup_functions(self):
    """Attach the activation functions and their inverses to the model."""

    def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
        # Covariance = L @ L^T with L built from scale and rotation,
        # then packed into six values per Gaussian.
        transform = build_scaling_rotation(scaling_modifier * scaling, rotation)
        covariance = transform @ transform.transpose(1, 2)
        return strip_symmetric(covariance)

    # Scales are stored in log space.
    self.scaling_activation = torch.exp
    self.scaling_inverse_activation = torch.log

    self.covariance_activation = build_covariance_from_scaling_rotation

    # Opacity is stored as a logit.
    self.opacity_activation = torch.sigmoid
    self.inverse_opacity_activation = inverse_sigmoid

    self.rotation_activation = torch.nn.functional.normalize
def initialize(self, initial_values, raw=False):
    """Set position, rotation and opacity parameters from ``initial_values``.

    Args:
        initial_values: mapping with tensors under ``"mean"``, ``"qvec"`` and
            ``"alpha"``.
        raw: stands for raw values, i.e. not passed through activation
            (currently unused here).

    NOTE: actual initialization is done in trainer. Scaling (``"svec"``) and
    colour (``"color"``) are intentionally not set here, as in the original.
    """

    def as_cuda_parameter(tensor):
        # Move to the GPU *before* wrapping. Calling ``.to('cuda')`` on an
        # ``nn.Parameter`` that lives on another device returns a plain
        # non-leaf tensor, which an optimizer cannot update.
        return nn.Parameter(tensor.to('cuda').requires_grad_(True))

    self._xyz = as_cuda_parameter(initial_values["mean"])
    self._rotation = as_cuda_parameter(initial_values["qvec"])
    self._opacity = as_cuda_parameter(initial_values["alpha"])
def __init__(self, sh_degree: int, args=None):
    """Create an empty model; point data is filled in later.

    Args:
        sh_degree: maximum spherical-harmonics degree of the colour features.
        args: not used by this constructor.
    """
    self.max_sh_degree = sh_degree
    self.active_sh_degree = 0

    # Per-gaussian tensors start out as independent empty placeholders.
    placeholder_slots = (
        "_xyz",
        "_features_dc",
        "_features_rest",
        "_scaling",
        "_rotation",
        "_opacity",
        "max_radii2D",
        "xyz_gradient_accum",
        "denom",
        "_deformation_table",
    )
    for slot in placeholder_slots:
        setattr(self, slot, torch.empty(0))

    # Training state, configured by training_setup().
    self.optimizer = None
    self.percent_dense = 0
    self.spatial_lr_scale = 0

    self._deformation = deform_network()
    self.setup_functions()
def capture(self):
    """Return the model and optimizer state as a 14-element tuple.

    The deformation network and the optimizer are stored as state dicts;
    everything else is returned as the stored attribute itself.
    """
    deformation_state = self._deformation.state_dict()
    optimizer_state = self.optimizer.state_dict()
    snapshot = (
        self.active_sh_degree,
        self._xyz,
        deformation_state,
        self._deformation_table,
        self._features_dc,
        self._features_rest,
        self._scaling,
        self._rotation,
        self._opacity,
        self.max_radii2D,
        self.xyz_gradient_accum,
        self.denom,
        optimizer_state,
        self.spatial_lr_scale,
    )
    return snapshot
def restore(self, model_args, training_args):
    """Restore the state produced by ``capture()``.

    Args:
        model_args: the 14-element tuple returned by ``capture()``.
        training_args: forwarded to ``training_setup()`` to rebuild the
            optimizer before its state is loaded.
    """
    # The unpacking order must mirror capture(): the deformation state dict
    # comes *before* the deformation table. The state dict is loaded into the
    # existing network instead of replacing the network object.
    (self.active_sh_degree,
     self._xyz,
     deformation_state,
     self._deformation_table,
     self._features_dc,
     self._features_rest,
     self._scaling,
     self._rotation,
     self._opacity,
     self.max_radii2D,
     xyz_gradient_accum,
     denom,
     opt_dict,
     self.spatial_lr_scale) = model_args
    self._deformation.load_state_dict(deformation_state)

    # training_setup() resets the accumulators, so they are restored after it.
    self.training_setup(training_args)
    self.xyz_gradient_accum = xyz_gradient_accum
    self.denom = denom
    self.optimizer.load_state_dict(opt_dict)
@property
def get_scaling(self):
    """Scales after the scaling activation has been applied to the raw values."""
    raw_scaling = self._scaling
    return self.scaling_activation(raw_scaling)
@property
def get_rotation(self):
    """Rotations after the rotation activation has been applied to the raw values."""
    raw_rotation = self._rotation
    return self.rotation_activation(raw_rotation)
@property
def get_xyz(self):
    """Gaussian centres, returned exactly as stored (no activation)."""
    centres = self._xyz
    return centres
@property
def get_features(self):
    """Colour features: the DC block followed by the remaining coefficients, joined along dim 1."""
    return torch.cat((self._features_dc, self._features_rest), dim=1)
@property
def get_opacity(self):
    """Opacities after the opacity activation has been applied to the raw values."""
    raw_opacity = self._opacity
    return self.opacity_activation(raw_opacity)
def get_covariance(self, scaling_modifier=1):
    """Return the covariance built from the activated scales and the raw rotations.

    Args:
        scaling_modifier: factor handed to the covariance activation
            together with the scales.
    """
    activated_scales = self.get_scaling
    return self.covariance_activation(activated_scales, scaling_modifier, self._rotation)
def oneupSHdegree(self):
    """Raise the active SH degree by one, unless the maximum is already reached."""
    if self.active_sh_degree >= self.max_sh_degree:
        return
    self.active_sh_degree += 1
def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float = 1):
    """Initialise all gaussian parameters from a point cloud.

    Args:
        pcd: point cloud whose ``points`` and ``colors`` seed the positions
            and the DC colour coefficient.
        spatial_lr_scale: stored on the model for later learning-rate scaling.

    Each gaussian starts with an identity quaternion, an opacity whose
    activated value is 0.1, and an isotropic scale derived from ``distCUDA2``
    of the points. All points are flagged in the deformation table and the
    active SH degree is set to the maximum.
    """
    self.spatial_lr_scale = spatial_lr_scale

    points = torch.tensor(np.asarray(pcd.points)).float().cuda()
    sh_dc = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda())
    n_points = points.shape[0]

    # Only coefficient 0 is filled; the higher-order coefficients stay zero.
    n_coeffs = (self.max_sh_degree + 1) ** 2
    sh = torch.zeros((sh_dc.shape[0], 3, n_coeffs)).float().cuda()
    sh[:, :3, 0] = sh_dc

    print("Number of points at initialisation : ", n_points)

    # Clamp so that the log below stays finite.
    dist2 = torch.clamp_min(distCUDA2(points), 0.0000001)
    log_scales = torch.log(torch.sqrt(dist2))[..., None].repeat(1, 3)

    quats = torch.zeros((n_points, 4), device="cuda")
    quats[:, 0] = 1

    raw_opacity = inverse_sigmoid(0.1 * torch.ones((n_points, 1), dtype=torch.float, device="cuda"))

    self._xyz = nn.Parameter(points.requires_grad_(True))
    self._features_dc = nn.Parameter(sh[:, :, 0:1].transpose(1, 2).contiguous().requires_grad_(True))
    self._features_rest = nn.Parameter(sh[:, :, 1:].transpose(1, 2).contiguous().requires_grad_(True))
    self._scaling = nn.Parameter(log_scales.requires_grad_(True))
    self._rotation = nn.Parameter(quats.requires_grad_(True))
    self._opacity = nn.Parameter(raw_opacity.requires_grad_(True))

    # Per-point bookkeeping.
    self._deformation_table = torch.gt(torch.ones((n_points), device="cuda"), 0)
    self._deformation = self._deformation.to("cuda")
    self.active_sh_degree = self.max_sh_degree
    self._deformation_accum = torch.zeros((n_points, 3), device="cuda")
    self.max_radii2D = torch.zeros((n_points), device="cuda")
def training_setup(self, training_args):
    """Reset the densification accumulators and build the optimizer and LR schedules.

    Args:
        training_args: object providing ``percent_dense`` and the learning
            rate settings read below.
    """
    n_points = self.get_xyz.shape[0]
    lr_scale = self.spatial_lr_scale

    self.percent_dense = training_args.percent_dense
    self.xyz_gradient_accum = torch.zeros((n_points, 1), device="cuda")
    self.denom = torch.zeros((n_points, 1), device="cuda")
    self._deformation_accum = torch.zeros((n_points, 3), device="cuda")

    # One parameter group per tensor; the spatial groups are scaled by lr_scale.
    param_groups = [
        {'params': [self._xyz], 'lr': training_args.position_lr_init * lr_scale, "name": "xyz"},
        {'params': list(self._deformation.get_mlp_parameters()), 'lr': training_args.deformation_lr_init * lr_scale, "name": "deformation"},
        {'params': list(self._deformation.get_grid_parameters()), 'lr': training_args.grid_lr_init * lr_scale, "name": "grid"},
        {'params': [self._features_dc], 'lr': training_args.feature_lr, "name": "f_dc"},
        {'params': [self._features_rest], 'lr': training_args.feature_lr / 20.0, "name": "f_rest"},
        {'params': [self._opacity], 'lr': training_args.opacity_lr, "name": "opacity"},
        {'params': [self._scaling], 'lr': training_args.scaling_lr, "name": "scaling"},
        {'params': [self._rotation], 'lr': training_args.rotation_lr, "name": "rotation"},
    ]
    self.optimizer = torch.optim.Adam(param_groups, lr=0.0, eps=1e-15)

    def make_schedule(lr_init, lr_final, delay_mult):
        # All three schedules share position_lr_max_steps as their horizon.
        return get_expon_lr_func(lr_init=lr_init * lr_scale,
                                 lr_final=lr_final * lr_scale,
                                 lr_delay_mult=delay_mult,
                                 max_steps=training_args.position_lr_max_steps)

    self.xyz_scheduler_args = make_schedule(training_args.position_lr_init,
                                            training_args.position_lr_final,
                                            training_args.position_lr_delay_mult)
    self.deformation_scheduler_args = make_schedule(training_args.deformation_lr_init,
                                                    training_args.deformation_lr_final,
                                                    training_args.deformation_lr_delay_mult)
    # The grid schedule reuses the deformation delay multiplier.
    self.grid_scheduler_args = make_schedule(training_args.grid_lr_init,
                                             training_args.grid_lr_final,
                                             training_args.deformation_lr_delay_mult)
def update_learning_rate(self, iteration):
    ''' Learning rate scheduling per step '''
    # Only the "xyz", "grid" and "deformation" groups are scheduled;
    # every other group keeps its current learning rate.
    for group in self.optimizer.param_groups:
        group_name = group["name"]
        if group_name == "xyz":
            group['lr'] = self.xyz_scheduler_args(iteration)
        if "grid" in group_name:
            group['lr'] = self.grid_scheduler_args(iteration)
        elif group_name == "deformation":
            group['lr'] = self.deformation_scheduler_args(iteration)
def construct_list_of_attributes(self):
    """Return the PLY property names, in the column order used by ``save_ply``."""
    n_dc = self._features_dc.shape[1] * self._features_dc.shape[2]
    n_rest = self._features_rest.shape[1] * self._features_rest.shape[2]

    names = ['x', 'y', 'z', 'nx', 'ny', 'nz']
    names.extend('f_dc_{}'.format(i) for i in range(n_dc))
    names.extend('f_rest_{}'.format(i) for i in range(n_rest))
    names.append('opacity')
    names.extend('scale_{}'.format(i) for i in range(self._scaling.shape[1]))
    names.extend('rot_{}'.format(i) for i in range(self._rotation.shape[1]))
    return names
def save_ply(self, path):
    """Write the raw (un-activated) gaussian parameters to a PLY file.

    Args:
        path: destination file. Missing parent directories are created.

    Normals are written as zeros. The deformation network is not part of
    the file.
    """
    # os.makedirs('') raises, so only create a directory when the path
    # actually has a directory component.
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    xyz = self._xyz.detach().cpu().numpy()
    normals = np.zeros_like(xyz)
    f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
    f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
    opacities = self._opacity.detach().cpu().numpy()
    scale = self._scaling.detach().cpu().numpy()
    rotation = self._rotation.detach().cpu().numpy()

    dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()]

    # Column order must match construct_list_of_attributes().
    elements = np.empty(xyz.shape[0], dtype=dtype_full)
    attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1)
    elements[:] = list(map(tuple, attributes))
    el = PlyElement.describe(elements, 'vertex')
    PlyData([el]).write(path)
def compute_deformation(self, time):
    """Return the positions offset by the deformation summed over the first ``time`` entries.

    NOTE(review): this slices ``self._deformation`` like a 3-D tensor, while
    ``__init__`` assigns a ``deform_network()`` to that attribute — confirm
    whether this method is still in use.
    """
    accumulated = self._deformation[:, :, :time].sum(dim=-1)
    return self._xyz + accumulated
def load_model(self, path):
    """Load the deformation network (and optional side files) from ``path``.

    Args:
        path: directory containing "deformation.pth" and, optionally,
            "deformation_table.pth" and "deformation_accum.pth".
    """
    print("loading model from exists{}".format(path))
    weights_file = os.path.join(path, "deformation.pth")
    table_file = os.path.join(path, "deformation_table.pth")
    accum_file = os.path.join(path, "deformation_accum.pth")

    self._deformation.load_state_dict(torch.load(weights_file, map_location="cuda"))
    self._deformation = self._deformation.to("cuda")

    # Defaults, used when the optional files are absent.
    self._deformation_table = torch.ones((self.get_xyz.shape[0]), dtype=torch.bool, device="cuda")
    self._deformation_accum = torch.zeros((self.get_xyz.shape[0], 3), device="cuda")
    if os.path.exists(table_file):
        self._deformation_table = torch.load(table_file, map_location="cuda")
    if os.path.exists(accum_file):
        self._deformation_accum = torch.load(accum_file, map_location="cuda")
    self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")

    # NOTE(review): the table is reset to all-True here, which discards a
    # table loaded from "deformation_table.pth" above — confirm this is intended.
    self._deformation_table = torch.ones((self.get_xyz.shape[0]), dtype=torch.bool, device="cuda")
    self._deformation = self._deformation.to("cuda")
def save_deformation(self, path):
    """Save the deformation network weights, table and accumulator into directory ``path``."""
    torch.save(self._deformation.state_dict(), os.path.join(path, "deformation.pth"))
    # The remaining files are named after the attribute they hold.
    for stem in ("deformation_table", "deformation_accum"):
        torch.save(getattr(self, "_" + stem), os.path.join(path, stem + ".pth"))
def load_ply(self, path):
    """Load gaussian parameters from a PLY file and reset the per-point bookkeeping.

    Args:
        path: PLY file whose first element carries the x/y/z, opacity,
            f_dc_*, f_rest_*, scale_* and rot* properties.

    The number of f_rest_* properties must match ``max_sh_degree``.
    """
    vertex = PlyData.read(path).elements[0]
    property_names = [p.name for p in vertex.properties]

    def gather(names):
        # One float64 column per property, in the given order.
        table = np.zeros((n_points, len(names)))
        for col, prop in enumerate(names):
            table[:, col] = np.asarray(vertex[prop])
        return table

    xyz = np.stack((np.asarray(vertex["x"]),
                    np.asarray(vertex["y"]),
                    np.asarray(vertex["z"])), axis=1)
    n_points = xyz.shape[0]
    opacities = np.asarray(vertex["opacity"])[..., np.newaxis]

    print("Number of points at loading : ", n_points)

    features_dc = gather(["f_dc_0", "f_dc_1", "f_dc_2"])[..., np.newaxis]

    extra_f_names = [name for name in property_names if name.startswith("f_rest_")]
    assert len(extra_f_names)==3*(self.max_sh_degree + 1) ** 2 - 3
    # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC)
    features_extra = gather(extra_f_names).reshape((n_points, 3, (self.max_sh_degree + 1) ** 2 - 1))

    scales = gather([name for name in property_names if name.startswith("scale_")])
    rots = gather([name for name in property_names if name.startswith("rot")])

    def as_parameter(array, swap_feature_axes=False):
        tensor = torch.tensor(array, dtype=torch.float, device="cuda")
        if swap_feature_axes:
            tensor = tensor.transpose(1, 2).contiguous()
        return nn.Parameter(tensor.requires_grad_(True))

    self._xyz = as_parameter(xyz)
    self._features_dc = as_parameter(features_dc, swap_feature_axes=True)
    self._features_rest = as_parameter(features_extra, swap_feature_axes=True)
    self._opacity = as_parameter(opacities)
    self._scaling = as_parameter(scales)
    self._rotation = as_parameter(rots)

    # Per-point bookkeeping.
    self._deformation_table = torch.gt(torch.ones((self.get_xyz.shape[0]), device="cuda"), 0)
    self._deformation = self._deformation.to("cuda")
    self.active_sh_degree = self.max_sh_degree
    self._deformation_accum = torch.zeros((self.get_xyz.shape[0], 3), device="cuda")
    self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
def replace_tensor_to_optimizer(self, tensor, name):
    """Swap the parameter of the optimizer group ``name`` for ``tensor``.

    The Adam moments of that group are reset to zero. When the optimizer has
    not created state for the old parameter yet (no step taken), there is
    nothing to reset and only the parameter is replaced.

    Args:
        tensor: new values for the parameter.
        name: value of the ``"name"`` key of the group to update.

    Returns:
        Dict mapping the group name to the newly created ``nn.Parameter``.
        The dict is empty when no group matches.
    """
    optimizable_tensors = {}
    for group in self.optimizer.param_groups:
        if group["name"] != name:
            continue

        old_param = group["params"][0]
        new_param = nn.Parameter(tensor.requires_grad_(True))

        # pop() with a default tolerates a parameter that has no state yet.
        stored_state = self.optimizer.state.pop(old_param, None)
        group["params"][0] = new_param
        if stored_state:
            stored_state["exp_avg"] = torch.zeros_like(tensor)
            stored_state["exp_avg_sq"] = torch.zeros_like(tensor)
            self.optimizer.state[new_param] = stored_state

        optimizable_tensors[group["name"]] = new_param
    return optimizable_tensors
def _prune_optimizer(self, mask):
    """Keep only the rows selected by ``mask`` in every single-parameter group.

    Groups holding more than one parameter are left untouched. Existing Adam
    moments are sliced with the same mask.

    Returns:
        Dict mapping each processed group name to its new ``nn.Parameter``.
    """
    pruned = {}
    for group in self.optimizer.param_groups:
        if len(group["params"]) > 1:
            continue

        old_param = group["params"][0]
        state = self.optimizer.state.get(old_param, None)
        if state is not None:
            for moment in ("exp_avg", "exp_avg_sq"):
                state[moment] = state[moment][mask]
            del self.optimizer.state[old_param]

        new_param = nn.Parameter(old_param[mask].requires_grad_(True))
        group["params"][0] = new_param
        if state is not None:
            # Re-register the sliced moments under the new parameter.
            self.optimizer.state[new_param] = state
        pruned[group["name"]] = new_param
    return pruned
def prune_points(self, mask):
    """Remove the points flagged by ``mask`` from the parameters and all per-point buffers.

    Args:
        mask: boolean tensor, True for points to delete.
    """
    keep = ~mask
    pruned = self._prune_optimizer(keep)

    # Optimised parameters come back from the optimizer as new tensors.
    for attr, group_name in (("_xyz", "xyz"),
                             ("_features_dc", "f_dc"),
                             ("_features_rest", "f_rest"),
                             ("_opacity", "opacity"),
                             ("_scaling", "scaling"),
                             ("_rotation", "rotation")):
        setattr(self, attr, pruned[group_name])

    # Bookkeeping buffers are sliced directly.
    for attr in ("_deformation_accum", "xyz_gradient_accum", "_deformation_table", "denom", "max_radii2D"):
        setattr(self, attr, getattr(self, attr)[keep])
def cat_tensors_to_optimizer(self, tensors_dict):
    """Append new rows to every single-parameter optimizer group.

    Args:
        tensors_dict: maps group names to the rows to append along dim 0.

    Existing Adam moments are extended with zeros for the new rows. Groups
    holding more than one parameter are skipped.

    Returns:
        Dict mapping each processed group name to its new ``nn.Parameter``.
    """
    extended = {}
    for group in self.optimizer.param_groups:
        if len(group["params"]) > 1:
            continue
        assert len(group["params"]) == 1

        extra_rows = tensors_dict[group["name"]]
        old_param = group["params"][0]
        state = self.optimizer.state.get(old_param, None)
        if state is not None:
            for moment in ("exp_avg", "exp_avg_sq"):
                state[moment] = torch.cat((state[moment], torch.zeros_like(extra_rows)), dim=0)
            del self.optimizer.state[old_param]

        new_param = nn.Parameter(torch.cat((old_param, extra_rows), dim=0).requires_grad_(True))
        group["params"][0] = new_param
        if state is not None:
            self.optimizer.state[new_param] = state
        extended[group["name"]] = new_param
    return extended
def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation, new_deformation_table):
    """Append newly created gaussians and reset the per-point statistics.

    Args:
        new_xyz, new_features_dc, new_features_rest, new_opacities,
        new_scaling, new_rotation: rows appended to the matching optimizer
            groups through ``cat_tensors_to_optimizer``.
        new_deformation_table: entries appended to the deformation table.

    The gradient accumulator, deformation accumulator, denominator and
    maximum radii are all re-created as zeros for the enlarged point set.
    """
    d = {"xyz": new_xyz,
         "f_dc": new_features_dc,
         "f_rest": new_features_rest,
         "opacity": new_opacities,
         "scaling": new_scaling,
         "rotation": new_rotation,
         }

    optimizable_tensors = self.cat_tensors_to_optimizer(d)
    self._xyz = optimizable_tensors["xyz"]
    self._features_dc = optimizable_tensors["f_dc"]
    self._features_rest = optimizable_tensors["f_rest"]
    self._opacity = optimizable_tensors["opacity"]
    self._scaling = optimizable_tensors["scaling"]
    self._rotation = optimizable_tensors["rotation"]

    self._deformation_table = torch.cat([self._deformation_table, new_deformation_table], -1)

    # Keep the statistics on the same device as the parameters instead of
    # assuming "cuda".
    n_points = self._xyz.shape[0]
    device = self._xyz.device
    self.xyz_gradient_accum = torch.zeros((n_points, 1), device=device)
    self._deformation_accum = torch.zeros((n_points, 3), device=device)
    self.denom = torch.zeros((n_points, 1), device=device)
    self.max_radii2D = torch.zeros((n_points), device=device)
def densify_and_split(self, grads, grad_threshold, scene_extent, N=2):
    """Replace large, high-gradient gaussians by ``N`` smaller sampled copies.

    A point is selected when its gradient reaches ``grad_threshold`` and its
    largest scale exceeds ``percent_dense * scene_extent``. The copies are
    drawn from a normal distribution whose per-axis std is the parent scale,
    rotated by the parent rotation; the parents are pruned afterwards.

    Args:
        grads: per-point gradient magnitudes; may cover fewer points than
            the model currently holds.
        grad_threshold: minimum gradient for selection.
        scene_extent: scene size used to scale the size criterion.
        N: number of copies created per selected point.
    """
    n_init_points = self.get_xyz.shape[0]

    # Points beyond len(grads) get a zero gradient.
    padded_grad = torch.zeros((n_init_points), device="cuda")
    padded_grad[:grads.shape[0]] = grads.squeeze()

    is_large = torch.max(self.get_scaling, dim=1).values > self.percent_dense*scene_extent
    selected = torch.logical_and(padded_grad >= grad_threshold, is_large)
    if not selected.any():
        return

    parent_scales = self.get_scaling[selected].repeat(N, 1)
    zero_mean = torch.zeros((parent_scales.size(0), 3), device="cuda")
    offsets = torch.normal(mean=zero_mean, std=parent_scales)
    parent_rots = build_rotation(self._rotation[selected]).repeat(N, 1, 1)

    new_xyz = torch.bmm(parent_rots, offsets.unsqueeze(-1)).squeeze(-1) + self.get_xyz[selected].repeat(N, 1)
    # Children are shrunk by a factor of 0.8 * N relative to the parent.
    new_scaling = self.scaling_inverse_activation(parent_scales / (0.8*N))
    new_rotation = self._rotation[selected].repeat(N, 1)
    new_features_dc = self._features_dc[selected].repeat(N, 1, 1)
    new_features_rest = self._features_rest[selected].repeat(N, 1, 1)
    new_opacity = self._opacity[selected].repeat(N, 1)
    new_deformation_table = self._deformation_table[selected].repeat(N)

    self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacity, new_scaling, new_rotation, new_deformation_table)

    # Drop the parents; the freshly appended children are kept.
    children_keep = torch.zeros(N * selected.sum(), device="cuda", dtype=bool)
    self.prune_points(torch.cat((selected, children_keep)))
def densify_and_clone(self, grads, grad_threshold, scene_extent):
    """Duplicate small gaussians whose gradient norm reaches ``grad_threshold``.

    A point is selected when its largest scale is at most
    ``percent_dense * scene_extent``. The copies take the raw parameters of
    their source unchanged.
    """
    high_gradient = torch.norm(grads, dim=-1) >= grad_threshold
    is_small = torch.max(self.get_scaling, dim=1).values <= self.percent_dense*scene_extent
    selected = torch.logical_and(high_gradient, is_small)

    clones = [source[selected] for source in (self._xyz,
                                             self._features_dc,
                                             self._features_rest,
                                             self._opacity,
                                             self._scaling,
                                             self._rotation,
                                             self._deformation_table)]
    self.densification_postfix(*clones)
def densify_and_prune(self, max_grad_percent, min_opacity, extent, max_screen_size):
    """Densify with an adaptive gradient threshold, then prune.

    The threshold is the accumulated gradient found at the relative rank
    ``max_grad_percent`` when the finite, non-zero gradients are sorted in
    descending order.

    Args:
        max_grad_percent: relative rank in the sorted gradients that
            supplies the threshold.
        min_opacity: points whose activated opacity is below this are pruned.
        extent: scene extent used by the size criteria.
        max_screen_size: when truthy, points that are too large on screen or
            in world space, or too small, are pruned as well.
    """
    grads = self.xyz_gradient_accum / self.denom
    grads[grads.isnan()] = 0.0

    # Work in log space and keep finite entries only: zero gradients become
    # -inf, and anything invalid becomes NaN.
    grad_log = torch.log(grads)
    finite_log = grad_log[torch.isfinite(grad_log)]

    # Without a single usable gradient there is nothing to densify, and
    # indexing the empty ranking would raise.
    if finite_log.numel() > 0:
        ranked = finite_log.sort(descending=True)[0]
        rank = min(int(max_grad_percent * ranked.shape[0]), ranked.shape[0] - 1)
        max_grad = torch.exp(ranked[rank])

        self.densify_and_clone(grads, max_grad, extent)
        self.densify_and_split(grads, max_grad, extent)

    # squeeze(-1) keeps a 1-D mask even when only one point is left.
    prune_mask = (self.get_opacity < min_opacity).squeeze(-1)
    if max_screen_size:
        largest_scale = self.get_scaling.max(dim=1).values
        big_points_vs = self.max_radii2D > max_screen_size
        big_points_ws = largest_scale > 0.1 * extent
        small_ws = largest_scale < 0.001
        prune_mask = torch.logical_or(torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws), small_ws)
    self.prune_points(prune_mask)

    torch.cuda.empty_cache()
def prune(self, min_opacity=0.01, extent=1, max_screen_size=1):
    """Prune points by opacity and, optionally, by size.

    Args:
        min_opacity: points whose activated opacity is below this are removed.
        extent: scene extent; points whose largest scale exceeds
            ``0.1 * extent`` are removed when size pruning is active.
        max_screen_size: when truthy, also remove points whose recorded
            screen radius exceeds it and points whose smallest scale is
            below 0.001.
    """
    # squeeze(-1) keeps a 1-D mask even when the model holds a single point;
    # a 0-d mask would add a dimension when used as an index.
    prune_mask = (self.get_opacity < min_opacity).squeeze(-1)

    if max_screen_size:
        scaling = self.get_scaling
        big_points_vs = self.max_radii2D > max_screen_size
        big_points_ws = scaling.max(dim=1).values > 0.1 * extent
        # NOTE(review): this uses the *smallest* scale, whereas
        # densify_and_prune tests the largest one — confirm which is intended.
        small_ws = scaling.min(dim=1).values < 0.001
        prune_mask = torch.logical_or(torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws), small_ws)

    self.prune_points(prune_mask)
def standard_constaint(self):
    """Return the mean squared change the deformation network applies at time 0.

    Positions, scales and rotations are detached, passed through the
    deformation network with an all-zero time input, and the three mean
    squared differences to the inputs are summed.
    """
    means3D = self._xyz.detach()
    scales = self._scaling.detach()
    rotations = self._rotation.detach()
    opacity = self._opacity.detach()

    # Build the time input on the device of the parameters (as render()
    # does) rather than hard-coding "cuda".
    time = torch.tensor(0).to(means3D.device).repeat(means3D.shape[0], 1)

    means3D_deform, scales_deform, rotations_deform, _ = self._deformation(means3D, scales, rotations, opacity, time)

    position_error = (means3D_deform - means3D) ** 2
    rotation_error = (rotations_deform - rotations) ** 2
    scaling_error = (scales_deform - scales) ** 2
    return position_error.mean() + rotation_error.mean() + scaling_error.mean()
def add_densification_stats(self, viewspace_point_tensor, update_filter):
    """Accumulate the 2-D gradient norm of the selected points.

    Args:
        viewspace_point_tensor: per-point values; only the first two columns
            enter the norm.
        update_filter: mask or index selecting the points to update.
    """
    selected_xy = viewspace_point_tensor[update_filter, :2]
    self.xyz_gradient_accum[update_filter] += torch.norm(selected_xy, dim=-1, keepdim=True)
    # Count the contributions so that an average can be formed later.
    self.denom[update_filter] += 1
@torch.no_grad()
def update_deformation_table(self, threshold):
    """Flag the points whose largest accumulated deformation, divided by 100, exceeds ``threshold``."""
    largest_per_point = self._deformation_accum.max(dim=-1).values
    self._deformation_table = torch.gt(largest_per_point / 100, threshold)
def print_deformation_weight_grad(self):
    """Print gradient statistics of the trainable deformation-network parameters.

    Parameters without a gradient are printed as ``None``; parameters whose
    mean gradient is exactly zero are omitted. A separator line closes the
    report.
    """
    for name, weight in self._deformation.named_parameters():
        if not weight.requires_grad:
            continue
        grad = weight.grad
        if grad is None:
            print(name," :",grad)
        elif grad.mean() != 0:
            print(name," :",grad.mean(), grad.min(), grad.max())
    print("-"*50)
def _plane_regulation(self):
    """Sum ``compute_plane_smoothness`` over planes 0, 1 and 3 of every resolution level.

    Levels that hold exactly three planes contribute nothing.
    """
    levels = self._deformation.deformation_net.grid.grids
    total = 0
    # model.grids is 6 x [1, rank * F_dim, reso, reso]
    for planes in levels:
        if len(planes) == 3:
            continue
        for plane_id in (0, 1, 3):
            total += compute_plane_smoothness(planes[plane_id])
    return total
def _time_regulation(self):
    """Sum ``compute_plane_smoothness`` over planes 2, 4 and 5 of every resolution level.

    Levels that hold exactly three planes contribute nothing.
    """
    levels = self._deformation.deformation_net.grid.grids
    total = 0
    # model.grids is 6 x [1, rank * F_dim, reso, reso]
    for planes in levels:
        if len(planes) == 3:
            continue
        for plane_id in (2, 4, 5):
            total += compute_plane_smoothness(planes[plane_id])
    return total
def _l1_regulation(self):
    """Sum the mean absolute deviation from 1 of planes 2, 4 and 5 over all levels.

    Levels that hold exactly three planes are skipped. The result is the
    float ``0.0`` when no plane contributes.
    """
    # model.grids is 6 x [1, rank * F_dim, reso, reso]
    levels = self._deformation.deformation_net.grid.grids
    # These are the spatiotemporal grids
    spatiotemporal_ids = (2, 4, 5)
    deviations = (
        torch.abs(1 - planes[plane_id]).mean()
        for planes in levels
        if len(planes) != 3
        for plane_id in spatiotemporal_ids
    )
    return sum(deviations, 0.0)
def compute_regulation(self, time_smoothness_weight, l1_time_planes_weight, plane_tv_weight):
    """Return the weighted sum of the plane, time and L1 regularisation terms."""
    plane_term = plane_tv_weight * self._plane_regulation()
    time_term = time_smoothness_weight * self._time_regulation()
    l1_term = l1_time_planes_weight * self._l1_regulation()
    return plane_term + time_term + l1_term
def getProjectionMatrix(znear, zfar, fovX, fovY):
    """Build a 4x4 perspective projection matrix.

    Args:
        znear, zfar: near and far clipping distances.
        fovX, fovY: horizontal and vertical field of view, in radians
            (they are passed to ``math.tan``).

    Returns:
        A (4, 4) tensor with five non-zero entries.
    """
    inv_tan_half_x = 1 / math.tan((fovX / 2))
    inv_tan_half_y = 1 / math.tan((fovY / 2))
    z_sign = 1.0

    proj = torch.zeros(4, 4)
    proj[0, 0] = inv_tan_half_x
    proj[1, 1] = inv_tan_half_y
    proj[2, 2] = z_sign * zfar / (zfar - znear)
    proj[2, 3] = -(zfar * znear) / (zfar - znear)
    proj[3, 2] = z_sign
    return proj
class MiniCam:
    """Minimal camera description consumed by ``Renderer.render``."""

    def __init__(self, c2w, width, height, fovy, fovx, znear, zfar,time=0 ):
        """Store the intrinsics and derive the view / projection transforms.

        Args:
            c2w: camera-to-world pose; should be in NeRF convention.
            width, height: image size.
            fovy, fovx: vertical and horizontal field of view.
            znear, zfar: clipping distances.
            time: time value stored on the camera.
        """
        # c2w (pose) should be in NeRF convention.
        self.image_width, self.image_height = width, height
        self.FoVy, self.FoVx = fovy, fovx
        self.znear, self.zfar = znear, zfar
        self.time = time

        world_to_cam = np.linalg.inv(c2w)

        # rectify...
        world_to_cam[1:3, :3] *= -1
        world_to_cam[:3, 3] *= -1

        # Both matrices are stored transposed.
        self.world_view_transform = torch.tensor(world_to_cam).transpose(0, 1).cuda()
        projection = getProjectionMatrix(
            znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy
        )
        self.projection_matrix = projection.transpose(0, 1).cuda()
        self.full_proj_transform = self.world_view_transform @ self.projection_matrix
        self.camera_center = -torch.tensor(c2w[:3, 3]).cuda()
class Renderer:
def __init__(self, sh_degree=3, white_background=True, radius=1):
    """Create the gaussian model and the background colour.

    Args:
        sh_degree: SH degree handed to ``GaussianModel``.
        white_background: use a white background when True, black otherwise.
        radius: stored on the renderer.
    """
    self.sh_degree, self.white_background, self.radius = sh_degree, white_background, radius

    self.gaussians = GaussianModel(sh_degree)

    background_rgb = [1, 1, 1] if white_background else [0, 0, 0]
    self.bg_color = torch.tensor(background_rgb, dtype=torch.float32, device="cuda")
def initialize(self, input=None, num_pts=5000, radius=0.5,initial_values=None):
    """Populate the gaussian model.

    Args:
        input: ``None`` to start from generated points, a
            ``BasicPointCloud`` to start from that cloud, or anything else,
            which is handed to ``GaussianModel.load_ply``.
        num_pts: number of random points when ``input`` is None.
        radius: radius of the ball the random points are drawn from.
        initial_values: optional mapping; when given, its "mean" tensor
            (axis-remapped) replaces the random positions.
    """
    # load checkpoint
    if input is None:
        # init from random point cloud
        phis = np.random.random((num_pts,)) * 2 * np.pi
        costheta = np.random.random((num_pts,)) * 2 - 1
        thetas = np.arccos(costheta)
        mu = np.random.random((num_pts,))
        # The cube root spreads the radii uniformly over the volume.
        radii = radius * np.cbrt(mu)
        x = radii * np.sin(thetas) * np.cos(phis)
        y = radii * np.sin(thetas) * np.sin(phis)
        z = radii * np.cos(thetas)
        xyz = np.stack((x, y, z), axis=1)

        if initial_values is not None:
            print(xyz.shape,initial_values["mean"].shape)
            R = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]])
            xyz = np.dot(initial_values["mean"].numpy(),R)
            # Colours and normals must be sized by the points actually used,
            # which may differ from the requested num_pts.
            num_pts = xyz.shape[0]

        shs = np.random.random((num_pts, 3)) / 255.0
        pcd = BasicPointCloud(
            points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3))
        )
        self.gaussians.create_from_pcd(pcd, 10)
    elif isinstance(input, BasicPointCloud):
        # load from a provided pcd
        self.gaussians.create_from_pcd(input, 1)
    else:
        # load from saved ply
        self.gaussians.load_ply(input)
def render(
    self,
    viewpoint_camera,
    scaling_modifier=1.0,
    bg_color=None,
    override_color=None,
    compute_cov3D_python=False,
    convert_SHs_python=False,
    stage="fine",
    time_int = None,
    front_view=False,
):
    """Rasterise the gaussians as seen from ``viewpoint_camera``.

    Args:
        viewpoint_camera: camera providing the image size, field of view,
            transforms, centre and ``time``.
        scaling_modifier: global scale factor passed to the rasteriser.
        bg_color: background colour; defaults to the renderer's own.
        override_color, convert_SHs_python, time_int: accepted but not used.
        compute_cov3D_python: build the 3-D covariance here from the final
            scales and rotations instead of letting the rasteriser do it.
        stage: "coarse" renders the stored gaussians unchanged; any other
            value runs the deformation network on the points flagged in the
            deformation table.
        front_view: print the camera transforms when truthy.

    Returns:
        Dict with the rendered "image", "depth" and "alpha", the screen
        space points and radii, the visibility filter, and the final
        positions plus the pre-activation rotations, scales and opacities.
    """
    gaussians = self.gaussians
    means3D = gaussians.get_xyz

    # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
    screenspace_points = torch.zeros_like(means3D, dtype=means3D.dtype, requires_grad=True, device="cuda") + 0
    try:
        screenspace_points.retain_grad()
    except Exception:
        # Best effort: the gradient is only needed while training, and
        # retain_grad() can fail otherwise (e.g. with autograd disabled).
        pass

    # Set up rasterization configuration
    tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
    tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)

    raster_settings = GaussianRasterizationSettings(
        image_height=int(viewpoint_camera.image_height),
        image_width=int(viewpoint_camera.image_width),
        tanfovx=tanfovx,
        tanfovy=tanfovy,
        bg=self.bg_color if bg_color is None else bg_color,
        scale_modifier=scaling_modifier,
        viewmatrix=viewpoint_camera.world_view_transform,
        projmatrix=viewpoint_camera.full_proj_transform,
        sh_degree=gaussians.active_sh_degree,
        campos=viewpoint_camera.camera_center,
        prefiltered=False,
        debug=False,
    )
    if front_view:
        print(viewpoint_camera.world_view_transform,viewpoint_camera.full_proj_transform)
    rasterizer = GaussianRasterizer(raster_settings=raster_settings)

    time = torch.tensor(viewpoint_camera.time).to(means3D.device).repeat(means3D.shape[0],1)
    means2D = screenspace_points

    # The raw parameters are always needed for the deformation step, also
    # when the covariance is precomputed further down.
    opacity = gaussians._opacity
    scales = gaussians._scaling
    rotations = gaussians._rotation
    deformation_point = gaussians._deformation_table

    if stage == "coarse":
        # No deformation: pass the flagged points through unchanged. Using
        # the flagged subset keeps the shapes valid for a partial table.
        means3D_deform = means3D[deformation_point]
        scales_deform = scales[deformation_point]
        rotations_deform = rotations[deformation_point]
        opacity_deform = opacity[deformation_point]
    else:
        means3D_deform, scales_deform, rotations_deform, opacity_deform = gaussians._deformation(
            means3D[deformation_point], scales[deformation_point],
            rotations[deformation_point], opacity[deformation_point],
            time[deformation_point])

    # Track how far every flagged point has moved.
    with torch.no_grad():
        gaussians._deformation_accum[deformation_point] += torch.abs(means3D_deform-means3D[deformation_point])

    # Merge the deformed points with the untouched ones.
    static_point = ~deformation_point
    means3D_final = torch.zeros_like(means3D)
    rotations_final = torch.zeros_like(rotations)
    scales_final = torch.zeros_like(scales)
    opacity_final = torch.zeros_like(opacity)
    means3D_final[deformation_point] = means3D_deform
    rotations_final[deformation_point] = rotations_deform
    scales_final[deformation_point] = scales_deform
    opacity_final[deformation_point] = opacity_deform
    means3D_final[static_point] = means3D[static_point]
    rotations_final[static_point] = rotations[static_point]
    scales_final[static_point] = scales[static_point]
    opacity_final[static_point] = opacity[static_point]

    # Keep the pre-activation values for the caller.
    scales_in = scales_final
    rotations_in = rotations_final
    opacity_in = opacity_final

    scales_final = gaussians.scaling_activation(scales_final)
    rotations_final = gaussians.rotation_activation(rotations_final)
    opacity_final = gaussians.opacity_activation(opacity_final)

    # Either hand scales / rotations to the rasteriser, or a covariance
    # precomputed from the same final values.
    cov3D_precomp = None
    raster_scales = scales_final
    raster_rotations = rotations_final
    if compute_cov3D_python:
        cov3D_precomp = gaussians.covariance_activation(scales_final, scaling_modifier, rotations_final)
        raster_scales = None
        raster_rotations = None

    # SH -> RGB conversion is done by the rasterizer.
    shs = gaussians.get_features
    colors_precomp = None

    # Rasterize visible Gaussians to image, obtain their radii (on screen).
    rendered_image, rendered_depth, normal, rendered_alpha, radii, _ = rasterizer(
        means3D = means3D_final,
        means2D = means2D,
        shs = shs,
        colors_precomp = colors_precomp,
        opacities = opacity_final,
        scales = raster_scales,
        rotations = raster_rotations,
        cov3Ds_precomp = cov3D_precomp)

    return {
        "image": rendered_image,
        "depth": rendered_depth,
        "alpha": rendered_alpha,
        "viewspace_points": screenspace_points,
        "visibility_filter": radii > 0,
        "radii": radii,
        'xyz':means3D_final,
        'rot':rotations_in,
        'xy':means2D,
        'color':shs,
        'scales':scales_in,
        'opacity':opacity_in,
    }
================================================
FILE: guidance/zero123_4d_utils.py
================================================
from transformers import CLIPTextModel, CLIPTokenizer, logging
from diffusers import (
AutoencoderKL,
UNet2DConditionModel,
DDIMScheduler,
StableDiffusionPipeline,
)
import torchvision.transforms.functional as TF
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
sys.path.append('./')
from zero123 import Zero123Pipeline
class Zero123(nn.Module):
def __init__(self, device, fp16=True, t_range=[0.02, 0.98]):
super().__init__()
self.device = device
self.fp16 = fp16
self.dtype = torch.float16 if fp16 else torch.float32
self.pipe = Zero123Pipeline.from_pretrained(
# "bennyguo/zero123-diffusers",
"ashawkey/zero123-xl-diffusers",
# './model_cache/zero123_xl',
variant="fp16" if self.fp16 else None,
torch_dtype=self.dtype,
).to(self.device)
# for param in self.pipe.parameters():
# param.requires_grad = False
self.pipe.image_encoder.eval()
self.pipe.vae.eval()
self.pipe.unet.eval()
self.pipe.clip_camera_projection.eval()
self.vae = self.pipe.vae
self.unet = self.pipe.unet
self.pipe.set_progress_bar_config(disable=True)
self.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
self.num_train_timesteps = self.scheduler.config.num_train_timesteps
self.min_step = int(self.num_train_timesteps * t_range[0])
self.max_step = int(self.num_train_timesteps * t_range[1])
self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience
self.min_step_percent = [0, 0.5, 0.02, 3000]
self.max_step_percent= [0, 0.95, 0.5, 3000]
self.embeddings = None
self.embedding_list = []
@torch.no_grad()
def get_img_embeds(self, x, input_imgs=None):
# x: image tensor in [0, 1]
x = F.interpolate(x, (256, 256), mode='bilinear', align_corners=False)
x_pil = [TF.to_pil_image(image) for image in x]
x_clip = self.pipe.feature_extractor(images=x_pil, return_tensors="pt").pixel_values.to(device=self.device, dtype=self.dtype)
c = self.pipe.image_encoder(x_clip).image_embeds
v = self.encode_imgs(x.to(self.dtype)) / self.vae.config.scaling_factor
self.embeddings = [c, v]
self.additional_embeddings=[]
if input_imgs!=None:
for x in input_imgs:
x = F.interpolate(x, (256, 256), mode='bilinear', align_corners=False)
x_pil = [TF.to_pil_image(image) for image in x]
x_clip = self.pipe.feature_extractor(images=x_pil, return_tensors="pt").pixel_values.to(device=self.device, dtype=self.dtype)
c = self.pipe.image_encoder(x_clip).image_embeds
v = self.encode_imgs(x.to(self.dtype)) / self.vae.config.scaling_factor
embeddings = [c, v]
self.additional_embeddings.append(embeddings)
self.embedding_list.append([self.embeddings,self.additional_embeddings])
@torch.no_grad()
def refine(self, pred_rgb, polar, azimuth, radius,
           guidance_scale=5, steps=50, strength=0.8,
           ):
    """Denoise views with the Zero123 UNet, conditioned on ``self.embeddings``.

    Args:
        pred_rgb: image batch; its first dimension is the number of views.
        polar, azimuth: per-view camera angles in degrees (converted with
            ``np.deg2rad``), one entry per view.
        radius: per-view radius values, one entry per view.
        guidance_scale: classifier-free guidance weight.
        steps: number of scheduler timesteps to set up.
        strength: ``0`` starts from pure noise; otherwise ``pred_rgb`` is
            encoded and noised at ``timesteps[int(steps * strength)]`` and
            denoising starts from there.

    Returns:
        The result of ``self.decode_latents`` on the final latents.
    """
    batch_size = pred_rgb.shape[0]

    self.scheduler.set_timesteps(steps)

    if strength == 0:
        init_step = 0
        # One random latent per requested view. A fixed batch of 1 here would
        # not match the per-view conditioning that is repeated batch_size
        # times below.
        latents = torch.randn((batch_size, 4, 32, 32), device=self.device, dtype=self.dtype)
    else:
        init_step = int(steps * strength)
        pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode='bilinear', align_corners=False)
        latents = self.encode_imgs(pred_rgb_256.to(self.dtype))
        latents = self.scheduler.add_noise(latents, torch.randn_like(latents), self.scheduler.timesteps[init_step])

    # Camera conditioning per view: (polar [rad], sin(azimuth), cos(azimuth), radius)
    T = np.stack([np.deg2rad(polar), np.sin(np.deg2rad(azimuth)), np.cos(np.deg2rad(azimuth)), radius], axis=-1)
    T = torch.from_numpy(T).unsqueeze(1).to(self.dtype).to(self.device)  # [B, 1, 4]

    # The second half of each conditioning batch is zeroed: it is the
    # unconditional branch used for classifier-free guidance.
    cc_emb = torch.cat([self.embeddings[0].repeat(batch_size, 1, 1), T], dim=-1)
    cc_emb = self.pipe.clip_camera_projection(cc_emb)
    cc_emb = torch.cat([cc_emb, torch.zeros_like(cc_emb)], dim=0)

    vae_emb = self.embeddings[1].repeat(batch_size, 1, 1, 1)
    vae_emb = torch.cat([vae_emb, torch.zeros_like(vae_emb)], dim=0)

    for t in self.scheduler.timesteps[init_step:]:
        x_in = torch.cat([latents] * 2)
        # One timestep entry per row of x_in (2 * batch_size rows).
        t_in = t.view(1).repeat(x_in.shape[0]).to(self.device)

        noise_pred = self.unet(
            torch.cat([x_in, vae_emb], dim=1),
            t_in.to(self.unet.dtype),
            encoder_hidden_states=cc_emb,
        ).sample

        noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

        latents = self.scheduler.step(noise_pred, t, latents).prev_sample

    imgs = self.decode_latents(latents)
    return imgs
def train_step(self, pred_rgb, polar, azimuth, radius, step_ratio=None, guidance_scale=5, as_latent=False,idx=None,t=0):
    """Compute the guidance loss for a batch of rendered views.

    Args:
        pred_rgb: rendered images in [0, 1] (or, with ``as_latent=True``, a
            tensor that is resized to 32x32 and rescaled to [-1, 1] to be
            used directly as latents).
        polar, azimuth: per-view camera angles in degrees.
        radius: per-view radius values.
        step_ratio: accepted for interface compatibility only. The noise
            level is always sampled from ``[self.min_step, self.max_step]``,
            so the value is not used and ``None`` is fine.
        guidance_scale: classifier-free guidance weight.
        as_latent: skip the VAE encoder and treat ``pred_rgb`` as latents.
        idx: when given, condition on ``additional_embeddings[idx]`` of the
            selected frame instead of its main reference embedding.
        t: index into ``self.embedding_list`` selecting the frame whose
            cached embeddings are used.

    Returns:
        Scalar loss ``0.5 * sum((latents - target) ** 2)`` where
        ``target = latents - w * (noise_pred - noise)`` is detached.
    """
    # pred_rgb: tensor [1, 3, H, W] in [0, 1]
    # Select (and expose on self) the conditioning cached for frame ``t``.
    self.embeddings, self.additional_embeddings = self.embedding_list[t]

    batch_size = pred_rgb.shape[0]

    embeddings = self.embeddings if idx is None else self.additional_embeddings[idx]

    if as_latent:
        latents = F.interpolate(pred_rgb, (32, 32), mode='bilinear', align_corners=False) * 2 - 1
    else:
        pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode='bilinear', align_corners=False)
        latents = self.encode_imgs(pred_rgb_256.to(self.dtype))

    # Named ``timesteps`` so the frame index parameter ``t`` is not overwritten.
    timesteps = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device)

    w = (1 - self.alphas[timesteps]).view(batch_size, 1, 1, 1)

    with torch.no_grad():
        noise = torch.randn_like(latents)
        latents_noisy = self.scheduler.add_noise(latents, noise, timesteps)

        x_in = torch.cat([latents_noisy] * 2)
        t_in = torch.cat([timesteps] * 2)

        # Camera conditioning per view: (polar [rad], sin(azimuth), cos(azimuth), radius)
        T = np.stack([np.deg2rad(polar), np.sin(np.deg2rad(azimuth)), np.cos(np.deg2rad(azimuth)), radius], axis=-1)
        T = torch.from_numpy(T).unsqueeze(1).to(self.dtype).to(self.device)  # [B, 1, 4]

        # Second half of each conditioning batch is zeroed (unconditional
        # branch of classifier-free guidance).
        cc_emb = torch.cat([embeddings[0].repeat(batch_size, 1, 1), T], dim=-1)
        cc_emb = self.pipe.clip_camera_projection(cc_emb)
        cc_emb = torch.cat([cc_emb, torch.zeros_like(cc_emb)], dim=0)

        vae_emb = embeddings[1].repeat(batch_size, 1, 1, 1)
        vae_emb = torch.cat([vae_emb, torch.zeros_like(vae_emb)], dim=0)

        noise_pred = self.unet(
            torch.cat([x_in, vae_emb], dim=1),
            t_in.to(self.unet.dtype),
            encoder_hidden_states=cc_emb,
        ).sample

    noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

    grad = w * (noise_pred - noise)
    grad = torch.nan_to_num(grad)

    # Gradient of the loss w.r.t. ``latents`` equals ``grad`` because the
    # target is detached.
    target = (latents - grad).detach()
    loss = 0.5 * F.mse_loss(latents.float(), target, reduction='sum')

    return loss
def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
    """Refresh ``self.min_step`` / ``self.max_step`` from their schedules.

    Each bound is the scheduled fraction (``self.get_steps`` applied to
    ``self.min_step_percent`` / ``self.max_step_percent``) multiplied by
    ``self.num_train_timesteps`` and truncated to an int.
    ``on_load_weights`` is accepted but not used here.
    """
    lower_fraction = self.get_steps(self.min_step_percent, epoch, global_step)
    upper_fraction = self.get_steps(self.max_step_percent, epoch, global_step)
    total_timesteps = self.num_train_timesteps
    self.min_step = int(total_timesteps * lower_fraction)
    self.max_step = int(total_timesteps * upper_fraction)
def get_steps(self,percent,epoch, global_step):
    """Evaluate a linear schedule at ``global_step``.

    Args:
        percent: ``(start_step, start_value, end_value, end_step)``.
        epoch: unused; accepted for interface compatibility.
        global_step: current optimisation step.

    Returns:
        ``start_value`` at or before ``start_step``, ``end_value`` at or
        after ``end_step``, and the linear interpolation in between. When
        ``start_step == end_step`` the schedule is treated as a step function
        (instead of dividing by zero).
    """
    start_step, start_value, end_value, end_step = percent
    span = end_step - start_step
    if span == 0:
        # Degenerate schedule: jump straight to the end value once reached.
        progress = 1.0 if global_step >= start_step else 0.0
    else:
        progress = max(min(1.0, (global_step - start_step) / span), 0.0)
    return start_value + (end_value - start_value) * progress
def decode_latents(self, latents):
    """Decode scaled VAE latents into images clamped to [0, 1].

    The latents are first divided by the VAE scaling factor, decoded with
    ``self.vae.decode`` and the result is mapped with ``x / 2 + 0.5``.
    """
    inverse_scale = 1 / self.vae.config.scaling_factor
    decoded = self.vae.decode(inverse_scale * latents).sample
    return (decoded / 2 + 0.5).clamp(0, 1)
def encode_imgs(self, imgs, mode=False):
    """Encode images in [0, 1] into scaled VAE latents.

    Args:
        imgs: image tensor, [B, 3, H, W].
        mode: take the posterior's ``mode()`` instead of drawing a
            ``sample()``.

    Returns:
        The chosen latents multiplied by the VAE scaling factor.
    """
    # imgs: [B, 3, H, W]
    posterior = self.vae.encode(2 * imgs - 1).latent_dist
    chosen = posterior.mode() if mode else posterior.sample()
    return chosen * self.vae.config.scaling_factor
if __name__ == '__main__':
    # Command-line demo: load one image, compute its Zero123 embeddings and
    # repeatedly sample a novel view for the requested camera offset.
    import cv2
    import argparse
    import numpy as np
    import matplotlib.pyplot as plt

    parser = argparse.ArgumentParser()

    parser.add_argument('input', type=str)
    parser.add_argument('--polar', type=float, default=0, help='delta polar angle in [-90, 90]')
    parser.add_argument('--azimuth', type=float, default=0, help='delta azimuth angle in [-180, 180]')
    parser.add_argument('--radius', type=float, default=0, help='delta camera radius multiplier in [-0.5, 0.5]')

    opt = parser.parse_args()

    device = torch.device('cuda')

    print(f'[INFO] loading image from {opt.input} ...')
    image = cv2.imread(opt.input, cv2.IMREAD_UNCHANGED)
    if image is None:
        # cv2.imread reports an unreadable or missing file by returning None;
        # fail here with a clear message instead of inside cvtColor.
        raise FileNotFoundError(f'could not read image: {opt.input}')
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (256, 256), interpolation=cv2.INTER_AREA)
    image = image.astype(np.float32) / 255.0
    # HWC numpy image -> [1, 3, 256, 256] tensor on the GPU
    image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).contiguous().to(device)

    print(f'[INFO] loading model ...')
    zero123 = Zero123(device)

    print(f'[INFO] running model ...')
    zero123.get_img_embeds(image)

    # strength=0 makes refine() start from fresh noise, so every iteration
    # shows a new sample for the same camera; the loop runs until interrupted.
    while True:
        outputs = zero123.refine(image, polar=[opt.polar], azimuth=[opt.azimuth], radius=[opt.radius], strength=0)
        plt.imshow(outputs.float().cpu().numpy().transpose(0, 2, 3, 1)[0])
        plt.show()
================================================
FILE: guidance/zero123pp/pipeline.py
================================================
from typing import Any, Dict, Optional
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.schedulers import KarrasDiffusionSchedulers
import numpy
import torch
import torch.nn as nn
import torch.utils.checkpoint
import torch.distributed
import transformers
from collections import OrderedDict
from PIL import Image
from torchvision import transforms
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
import torch
import torch.nn.functional as F
from torch import nn
import diffusers
from diffusers import (
AutoencoderKL,
DDPMScheduler,
DiffusionPipeline,
EulerAncestralDiscreteScheduler,
UNet2DConditionModel,
ImagePipelineOutput
)
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.attention_processor import Attention, AttnProcessor, XFormersAttnProcessor
from diffusers.utils.import_utils import is_xformers_available
import os
# Module-level state shared between Zero123PlusPipeline.__call__ and
# MyAttnProcessor2_0.__call__.

# FIRST: overwritten with the ``is_first`` argument at the start of every
# pipeline call. While true, the attention processor records features into
# EMBED; otherwise it blends the recorded features back in.
FIRST = True
# IDX: position in EMBED; reset to 0 at the start of every pipeline call and
# incremented once per attention call made with is_self_attention=True.
IDX = 0
# PATH: only referenced by commented-out torch.save code in the attention
# processor.
PATH = '/home/vision/github/embeddings/'
# EMBED: CPU copies of encoder hidden states recorded during a "first" call;
# cleared when a pipeline call is made with is_first=True.
EMBED=[]
def to_rgb_image(maybe_rgba: Image.Image):
    """Return an RGB version of a PIL image.

    RGB images are returned unchanged. RGBA images are composited onto a
    constant gray (127, 127, 127) background using their alpha channel as the
    paste mask.

    Raises:
        ValueError: for any other image mode.
    """
    if maybe_rgba.mode == 'RGB':
        return maybe_rgba
    if maybe_rgba.mode == 'RGBA':
        # A solid background of the same size; this replaces building the
        # constant via a degenerate random draw over [127, 128).
        background = Image.new('RGB', maybe_rgba.size, (127, 127, 127))
        background.paste(maybe_rgba, mask=maybe_rgba.getchannel('A'))
        return background
    raise ValueError("Unsupported image type.", maybe_rgba.mode)
class MyAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).

    In addition to plain attention, calls made with ``is_self_attention=True``
    interact with the module-level globals ``FIRST``, ``IDX`` and ``EMBED``:
    while ``FIRST`` is true the encoder hidden states are recorded (on CPU)
    into ``EMBED``; otherwise the leading tokens of the current encoder hidden
    states are averaged with the entry recorded at position ``IDX``. ``IDX``
    advances by one on every such call.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("MyAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: "Attention",
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        is_self_attention=False
    ):
        """Run attention for ``attn`` and return the new hidden states.

        Args:
            attn: the attention module providing the projections and norms.
            hidden_states: input features, either 3-D ``(batch, tokens, dim)``
                or 4-D ``(batch, channel, height, width)``.
            encoder_hidden_states: key/value source; defaults to
                ``hidden_states``.
            attention_mask: optional mask, prepared via
                ``attn.prepare_attention_mask``.
            temb: forwarded to ``attn.spatial_norm`` when that is set.
            is_self_attention: enables the record/blend behaviour described
                in the class docstring.
        """
        global FIRST
        global IDX
        global EMBED

        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            # Flatten the spatial grid into a token axis.
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        if is_self_attention:
            if FIRST:
                # Recording pass: keep a CPU copy for later calls.
                EMBED.append(encoder_hidden_states.to('cpu'))
                IDX = IDX + 1
            else:
                last_shape = encoder_hidden_states.shape[-1]
                # Number of leading tokens to blend, shrinking with the square
                # of the channel multiple of 320.
                # NOTE(review): 9600 presumably is the token count at the
                # first UNet level for the default 640x960 output - confirm.
                replace_dim = int(9600/(last_shape//320)**2)
                # Move the cached features to wherever the live tensor is,
                # rather than to a hard-coded default CUDA device.
                encoder_hidden_states_load = EMBED[IDX].to(encoder_hidden_states.device)
                encoder_hidden_states[:, :replace_dim, :] = (
                    encoder_hidden_states_load[:, :replace_dim, :]
                    + encoder_hidden_states[:, :replace_dim, :]
                ) / 2
                IDX = IDX + 1

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
class ReferenceOnlyAttnProc(torch.nn.Module):
    """Attention-processor wrapper that records or injects reference features.

    When ``enabled``, the ``mode`` passed through ``cross_attention_kwargs``
    selects the behaviour:

    * ``'w'``: store the encoder hidden states in ``ref_dict[self.name]``;
    * ``'r'``: concatenate (and remove) the stored states along the token
      axis and flag the chained call with ``is_self_attention=True``;
    * ``'m'``: concatenate the stored states but keep them in ``ref_dict``.

    Any other mode raises ``AssertionError``. When not enabled, the call is
    forwarded to ``chained_proc`` unchanged.
    """

    def __init__(
        self,
        chained_proc,
        enabled=False,
        name=None
    ) -> None:
        super().__init__()
        self.enabled = enabled
        self.chained_proc = chained_proc
        self.name = name

    def __call__(
        self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None,
        mode="w", ref_dict: Optional[dict] = None, is_cfg_guidance = False
    ) -> Any:
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        is_self_attention = False
        if self.enabled and is_cfg_guidance:
            # The first batch entry is processed on its own, without reference
            # features, and re-attached to the result at the end.
            res0 = self.chained_proc(attn, hidden_states[:1], encoder_hidden_states[:1], attention_mask)
            hidden_states = hidden_states[1:]
            encoder_hidden_states = encoder_hidden_states[1:]
        if self.enabled:
            if mode == 'w':
                ref_dict[self.name] = encoder_hidden_states
            elif mode == 'r':
                encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict.pop(self.name)], dim=1)
                is_self_attention = True
            elif mode == 'm':
                encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict[self.name]], dim=1)
            else:
                # Explicit raise (same exception type as the former assert) so
                # the check is not stripped when Python runs with -O.
                raise AssertionError(mode)
        res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask,is_self_attention=is_self_attention)
        if self.enabled and is_cfg_guidance:
            res = torch.cat([res0, res])
        return res
class RefOnlyNoisedUNet(torch.nn.Module):
    """UNet wrapper that runs a reference pass before every real forward.

    Every attention processor of the wrapped UNet is replaced by a
    ``ReferenceOnlyAttnProc``; the wrapper is enabled only for processors whose
    name ends with ``"attn1.processor"``. ``forward`` first pushes a noised
    copy of the condition latent through the UNet in mode ``"w"`` (filling
    ``ref_dict``) and then runs the actual sample in mode ``"r"``.
    """

    def __init__(self, unet: UNet2DConditionModel, train_sched: DDPMScheduler, val_sched: EulerAncestralDiscreteScheduler) -> None:
        super().__init__()
        self.unet = unet
        self.train_sched = train_sched
        self.val_sched = val_sched
        unet_lora_attn_procs = dict()
        for name, _ in unet.attn_processors.items():
            # Pick the inner attention implementation; the message below is
            # printed once per processor when the first branch is taken.
            if torch.__version__ >= '2.0':
                default_attn_proc = MyAttnProcessor2_0()
                print('using my attention')
            elif is_xformers_available():
                default_attn_proc = XFormersAttnProcessor()
            else:
                default_attn_proc = AttnProcessor()
            unet_lora_attn_procs[name] = ReferenceOnlyAttnProc(
                default_attn_proc, enabled=name.endswith("attn1.processor"), name=name
            )
        unet.set_attn_processor(unet_lora_attn_procs)

    def __getattr__(self, name: str):
        # Fall back to the wrapped UNet for attributes this module lacks.
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.unet, name)

    def forward_cond(self, noisy_cond_lat, timestep, encoder_hidden_states, class_labels, ref_dict, is_cfg_guidance, **kwargs):
        """Reference pass: run the UNet in mode "w" so processors fill ``ref_dict``.

        With ``is_cfg_guidance`` the first batch entry of
        ``encoder_hidden_states`` and ``class_labels`` is dropped first. The
        UNet output is discarded.
        """
        if is_cfg_guidance:
            encoder_hidden_states = encoder_hidden_states[1:]
            class_labels = class_labels[1:]
        self.unet(
            noisy_cond_lat, timestep,
            encoder_hidden_states=encoder_hidden_states,
            class_labels=class_labels,
            cross_attention_kwargs=dict(mode="w", ref_dict=ref_dict),
            **kwargs
        )

    def forward(
        self, sample, timestep, encoder_hidden_states, class_labels=None,
        *args, cross_attention_kwargs,
        down_block_res_samples=None, mid_block_res_sample=None,
        **kwargs
    ):
        """Run the reference pass, then the real UNet forward in mode "r".

        ``cross_attention_kwargs`` must contain ``cond_lat``; the optional
        ``is_cfg_guidance`` entry defaults to False. Additional residuals, if
        given, are cast to the UNet dtype before being passed on.
        """
        cond_lat = cross_attention_kwargs['cond_lat']
        is_cfg_guidance = cross_attention_kwargs.get('is_cfg_guidance', False)
        noise = torch.randn_like(cond_lat)
        # Noise the condition latent to the current timestep with the
        # scheduler that matches the module's train/eval state.
        if self.training:
            noisy_cond_lat = self.train_sched.add_noise(cond_lat, noise, timestep)
            noisy_cond_lat = self.train_sched.scale_model_input(noisy_cond_lat, timestep)
        else:
            noisy_cond_lat = self.val_sched.add_noise(cond_lat, noise, timestep.reshape(-1))
            noisy_cond_lat = self.val_sched.scale_model_input(noisy_cond_lat, timestep.reshape(-1))
        #if cross_attention_kwargs['cond_lat_back'] is not None:
        #    cond_lat_back = cross_attention_kwargs['cond_lat_back']
        #    noisy_cond_lat_back = self.val_sched.add_noise(cond_lat_back, noise, timestep.reshape(-1))
        #    noisy_cond_lat_back = self.val_sched.scale_model_input(noisy_cond_lat_back, timestep.reshape(-1))
        ref_dict = {}
        self.forward_cond(
            noisy_cond_lat, timestep,
            encoder_hidden_states, class_labels,
            ref_dict, is_cfg_guidance, **kwargs
        )
        weight_dtype = self.unet.dtype
        return self.unet(
            sample, timestep,
            encoder_hidden_states, *args,
            class_labels=class_labels,
            cross_attention_kwargs=dict(mode="r", ref_dict=ref_dict, is_cfg_guidance=is_cfg_guidance),
            down_block_additional_residuals=[
                sample.to(dtype=weight_dtype) for sample in down_block_res_samples
            ] if down_block_res_samples is not None else None,
            mid_block_additional_residual=(
                mid_block_res_sample.to(dtype=weight_dtype)
                if mid_block_res_sample is not None else None
            ),
            **kwargs
        )
def scale_latents(latents):
    """Shift latents by -0.22, then multiply by 0.75 (inverse of ``unscale_latents``)."""
    return (latents - 0.22) * 0.75
def unscale_latents(latents):
    """Divide latents by 0.75, then add 0.22 (inverse of ``scale_latents``)."""
    return latents / 0.75 + 0.22
def scale_image(image):
    """Multiply the image by 0.5, then divide by 0.8 (inverse of ``unscale_image``)."""
    return image * 0.5 / 0.8
def unscale_image(image):
    """Divide the image by 0.5, then multiply by 0.8 (inverse of ``scale_image``)."""
    return image / 0.5 * 0.8
class DepthControlUNet(torch.nn.Module):
    """Pairs a ``RefOnlyNoisedUNet`` with a ControlNet fed by a depth condition.

    ``forward`` runs the ControlNet on the ``control_depth`` entry of
    ``cross_attention_kwargs`` and hands the resulting residuals to the
    wrapped UNet.
    """

    def __init__(self, unet: RefOnlyNoisedUNet, controlnet: Optional[diffusers.ControlNetModel] = None, conditioning_scale=1.0) -> None:
        super().__init__()
        self.unet = unet
        # Without an explicit ControlNet, derive one from the inner UNet.
        self.controlnet = (
            diffusers.ControlNetModel.from_unet(unet.unet) if controlnet is None else controlnet
        )
        # xformers attention takes precedence whenever it is available.
        attn_proc_cls = XFormersAttnProcessor if is_xformers_available() else MyAttnProcessor2_0
        self.controlnet.set_attn_processor(attn_proc_cls())
        self.conditioning_scale = conditioning_scale

    def __getattr__(self, name: str):
        # Attributes missing on this module are looked up on the wrapped UNet.
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.unet, name)

    def forward(self, sample, timestep, encoder_hidden_states, class_labels=None, *args, cross_attention_kwargs: dict, **kwargs):
        # Work on a copy so the caller's dict keeps its 'control_depth' entry.
        unet_attention_kwargs = dict(cross_attention_kwargs)
        depth_condition = unet_attention_kwargs.pop('control_depth')
        down_residuals, mid_residual = self.controlnet(
            sample,
            timestep,
            encoder_hidden_states=encoder_hidden_states,
            controlnet_cond=depth_condition,
            conditioning_scale=self.conditioning_scale,
            return_dict=False,
        )
        return self.unet(
            sample,
            timestep,
            encoder_hidden_states=encoder_hidden_states,
            down_block_res_samples=down_residuals,
            mid_block_res_sample=mid_residual,
            cross_attention_kwargs=unet_attention_kwargs
        )
class ModuleListDict(torch.nn.Module):
    """Holds the modules of a dict in a ``ModuleList``, ordered by sorted key.

    ``self.keys`` is the sorted key list and ``self.values`` the matching
    ``ModuleList``; indexing by key returns the module at that key's position.
    """

    def __init__(self, procs: dict) -> None:
        super().__init__()
        ordered_names = sorted(procs.keys())
        self.keys = ordered_names
        self.values = torch.nn.ModuleList([procs[name] for name in ordered_names])

    def __getitem__(self, key):
        position = self.keys.index(key)
        return self.values[position]
class SuperNet(torch.nn.Module):
    """Container exposing a dict of modules under their original key names.

    The modules are stored in ``self.layers`` (sorted by key). State-dict
    hooks translate between the internal ``layers.<index>`` prefixes and the
    original keys, in both directions.
    """

    def __init__(self, state_dict: Dict[str, torch.Tensor]):
        super().__init__()
        ordered = OrderedDict((name, state_dict[name]) for name in sorted(state_dict.keys()))
        self.layers = torch.nn.ModuleList(ordered.values())
        self.mapping = dict(enumerate(ordered.keys()))
        self.rev_mapping = {name: index for index, name in enumerate(ordered.keys())}

        # .processor for unet, .self_attn for text encoder
        self.split_keys = [".processor", ".self_attn"]

        # we add a hook to state_dict() and load_state_dict() so that the
        # naming fits with `unet.attn_processors`
        def map_to(module, state_dict, *args, **kwargs):
            # "layers.<index>..." -> "<original key>..."
            renamed = {}
            for key, value in state_dict.items():
                num = int(key.split(".")[1])  # 0 is always "layers"
                renamed[key.replace(f"layers.{num}", module.mapping[num])] = value
            return renamed

        def remap_key(key, state_dict):
            # Portion of `key` that names the stored module.
            for marker in self.split_keys:
                if marker in key:
                    return key.split(marker)[0] + marker
            return key.split('.')[0]

        def map_from(module, state_dict, *args, **kwargs):
            # "<original key>..." -> "layers.<index>...", rewritten in place.
            for key in list(state_dict.keys()):
                module_name = remap_key(key, state_dict)
                new_key = key.replace(module_name, f"layers.{module.rev_mapping[module_name]}")
                state_dict[new_key] = state_dict[key]
                del state_dict[key]

        self._register_state_dict_hook(map_to)
        self._register_load_state_dict_pre_hook(map_from, with_module=True)
class Zero123PlusPipeline(diffusers.StableDiffusionPipeline):
    """Image-conditioned pipeline built on ``StableDiffusionPipeline``.

    ``__call__`` takes a PIL image, encodes it with the VAE (as ``cond_lat``)
    and with the CLIP vision encoder (added onto the prompt embeddings via
    ``ramping_coefficients``), then delegates sampling to the parent pipeline.
    It also drives the module-level globals ``FIRST``, ``IDX`` and ``EMBED``.
    """

    tokenizer: transformers.CLIPTokenizer
    text_encoder: transformers.CLIPTextModel
    vision_encoder: transformers.CLIPVisionModelWithProjection
    feature_extractor_clip: transformers.CLIPImageProcessor
    unet: UNet2DConditionModel
    scheduler: diffusers.schedulers.KarrasDiffusionSchedulers
    vae: AutoencoderKL
    ramping: nn.Linear
    feature_extractor_vae: transformers.CLIPImageProcessor
    # Depth images are converted to tensors and normalised with mean 0.5 / std 0.5.
    depth_transforms_multi = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        vision_encoder: transformers.CLIPVisionModelWithProjection,
        feature_extractor_clip: CLIPImageProcessor,
        feature_extractor_vae: CLIPImageProcessor,
        ramping_coefficients: Optional[list] = None,
        safety_checker=None,
    ):
        # Only DiffusionPipeline.__init__ is run here; modules are registered
        # explicitly below. The safety checker is always registered as None,
        # whatever is passed in.
        DiffusionPipeline.__init__(self)
        self.register_modules(
            vae=vae, text_encoder=text_encoder, tokenizer=tokenizer,
            unet=unet, scheduler=scheduler, safety_checker=None,
            vision_encoder=vision_encoder,
            feature_extractor_clip=feature_extractor_clip,
            feature_extractor_vae=feature_extractor_vae
        )
        self.register_to_config(ramping_coefficients=ramping_coefficients)
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

    def prepare(self):
        """Wrap a plain ``UNet2DConditionModel`` in ``RefOnlyNoisedUNet`` (eval mode).

        Does nothing when ``self.unet`` is not a ``UNet2DConditionModel``
        instance, e.g. once it has already been wrapped.
        """
        train_sched = DDPMScheduler.from_config(self.scheduler.config)
        if isinstance(self.unet, UNet2DConditionModel):
            self.unet = RefOnlyNoisedUNet(self.unet, train_sched, self.scheduler).eval()

    def add_controlnet(self, controlnet: Optional[diffusers.ControlNetModel] = None, conditioning_scale=1.0):
        """Wrap the UNet in ``DepthControlUNet`` and return a ``SuperNet`` over the ControlNet."""
        self.prepare()
        self.unet = DepthControlUNet(self.unet, controlnet, conditioning_scale)
        return SuperNet(OrderedDict([('controlnet', self.unet.controlnet)]))

    def encode_condition_image(self, image: torch.Tensor):
        # Sampled VAE latent of the image; no scaling factor is applied here.
        image = self.vae.encode(image).latent_dist.sample()
        return image

    @torch.no_grad()
    def __call__(
        self,
        image: Image.Image = None,
        prompt = "",
        *args,
        num_images_per_prompt: Optional[int] = 1,
        guidance_scale=4.0,
        depth_image: Image.Image = None,
        output_type: Optional[str] = "pil",
        width=640,
        height=960,
        num_inference_steps=28,
        return_dict=True,
        is_first = False,
        **kwargs
    ):
        """Generate images conditioned on ``image``.

        Args:
            image: PIL image (required; tensors are rejected by an assert).
            prompt: text prompt encoded via ``_encode_prompt``.
            num_images_per_prompt, guidance_scale, width, height,
            num_inference_steps: forwarded to the parent pipeline call.
            depth_image: optional PIL depth image; only used when the UNet
                has a ``controlnet`` attribute.
            output_type: ``"latent"`` returns the unscaled latents; otherwise
                the latents are decoded and post-processed.
            return_dict: return ``ImagePipelineOutput`` (True) or a 1-tuple.
            is_first: stored into the module global ``FIRST``; when true the
                global ``EMBED`` cache is cleared first.

        Returns:
            ``ImagePipelineOutput`` or ``(image,)``.

        Raises:
            ValueError: if ``image`` is None.
        """
        # Reset the module-level state read by MyAttnProcessor2_0.
        global FIRST
        FIRST = is_first
        global IDX
        IDX = 0
        if is_first:
            global EMBED
            EMBED=[]
        # Create a generator with the specified seed
        # NOTE(review): device and seed (42) are fixed here, so results do not
        # vary between calls and a CUDA device is required.
        generator = torch.Generator(device='cuda')
        generator.manual_seed(42)
        self.prepare()
        if image is None:
            raise ValueError("Inputting embeddings not supported for this pipeline. Please pass an image.")
        assert not isinstance(image, torch.Tensor)
        image = to_rgb_image(image)
        # Two preprocessed copies: one for the VAE, one for the CLIP encoder.
        image_1 = self.feature_extractor_vae(images=image, return_tensors="pt").pixel_values
        image_2 = self.feature_extractor_clip(images=image, return_tensors="pt").pixel_values
        if depth_image is not None and hasattr(self.unet, "controlnet"):
            depth_image = to_rgb_image(depth_image)
            depth_image = self.depth_transforms_multi(depth_image).to(
                device=self.unet.controlnet.device, dtype=self.unet.controlnet.dtype
            )
        image = image_1.to(device=self.vae.device, dtype=self.vae.dtype)
        image_2 = image_2.to(device=self.vae.device, dtype=self.vae.dtype)
        cond_lat = self.encode_condition_image(image)
        if guidance_scale > 1:
            # Prepend the latent of an all-zero image as the negative condition.
            negative_lat = self.encode_condition_image(torch.zeros_like(image))
            cond_lat = torch.cat([negative_lat, cond_lat])
        encoded = self.vision_encoder(image_2, output_hidden_states=False)
        global_embeds = encoded.image_embeds
        global_embeds = global_embeds.unsqueeze(-2)
        encoder_hidden_states = self._encode_prompt(
            prompt,
            self.device,
            num_images_per_prompt,
            False
        )
        # Add the CLIP image embedding onto the prompt embeddings, weighted by
        # the configured ramping coefficients.
        ramp = global_embeds.new_tensor(self.config.ramping_coefficients).unsqueeze(-1)
        encoder_hidden_states = encoder_hidden_states + global_embeds * ramp
        cak = dict(cond_lat=cond_lat)
        if hasattr(self.unet, "controlnet"):
            cak['control_depth'] = depth_image
        # Only commented-out code in RefOnlyNoisedUNet.forward refers to this key.
        cak['cond_lat_back'] = None
        latents: torch.Tensor = super().__call__(
            None,
            *args,
            cross_attention_kwargs=cak,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_images_per_prompt,
            prompt_embeds=encoder_hidden_states,
            num_inference_steps=num_inference_steps,
            output_type='latent',
            width=width,
            height=height,
            generator=generator,
            **kwargs
        ).images
        latents = unscale_latents(latents)
        if not output_type == "latent":
            image = unscale_image(self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0])
        else:
            image = latents
        image = self.image_processor.postprocess(image, output_type=output_type)
        if not return_dict:
            return (image,)
        return ImagePipelineOutput(images=image)
================================================
FILE: main.py
================================================
import os
import cv2
import time
import tqdm
import numpy as np
import dearpygui.dearpygui as dpg
import torch
import torch.nn.functional as F
import torchvision.utils as vutils
from einops import rearrange, repeat
import imageio
import rembg
from cam_utils import orbit_camera, OrbitCamera
from gs_renderer_4d import Renderer, MiniCam
from dataset_4d import SparseDataset
def save_image_to_local(image_tensor, file_path):
    """Write ``image_tensor`` to ``file_path`` after clamping it to [0, 1]."""
    clamped = image_tensor.clamp(0, 1)
    vutils.save_image(clamped, file_path)
class GUI:
def __init__(self, opt):
    """Set up camera, renderer, dataset and (optionally) the GUI from ``opt``.

    Fields read from ``opt`` here: gui, W, H, radius, fovy, sh_degree, path,
    save_step, size, input, prompt and num_pts. The first batch of the
    dataset is fetched immediately and kept on the instance.
    """
    self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
    self.gui = opt.gui # enable gui
    self.W = opt.W
    self.H = opt.H
    self.cam = OrbitCamera(opt.W, opt.H, r=opt.radius, fovy=opt.fovy)
    self.mode = "image"
    self.seed = "random"
    # NOTE(review): allocated as (W, H, 3) rather than (H, W, 3) - confirm
    # this is intended for non-square windows.
    self.buffer_image = np.ones((self.W, self.H, 3), dtype=np.float32)
    self.need_update = True # update buffer_image
    # models
    self.device = torch.device("cuda")
    self.bg_remover = None
    self.guidance_sd = None
    self.guidance_zero123 = None
    self.enable_sd = False
    self.enable_zero123 = False
    # renderer
    self.renderer = Renderer(sh_degree=self.opt.sh_degree)
    self.gaussain_scale_factor = 1
    # input image
    self.input_img = None
    self.input_mask = None
    self.input_img_torch = None
    self.input_mask_torch = None
    self.overlay_input_img = False
    self.overlay_input_img_ratio = 0.5
    #self.use_depth = opt.use_depth
    # input text
    self.prompt = ""
    self.negative_prompt = ""
    # training stuff
    self.training = False
    self.optimizer = None
    self.step = 0
    self.t = 0
    self.time = 0
    self.train_steps = 1 # steps per rendering loop
    self.init = True
    self.stage = 'coarse'
    self.path = self.opt.path
    self.save_step = self.opt.save_step
    # Number of frames: explicit opt.size, else the entry count of <path>/ref.
    if self.opt.size is not None:
        self.size = self.opt.size
    else:
        self.size = len(os.listdir(os.path.join(self.path,'ref')))
    self.frames=self.size
    self.dataset = SparseDataset(self.opt, self.size, H=self.H, W=self.W, device=self.device)
    self.dataloader =self.dataset.dataloader()
    self.iter = iter(self.dataloader)
    # Only the first batch is pulled from the loader here.
    self.ref_view_batch, self.input_mask_batch,self.zero123_view_batch,self.zero123_masks_batch = next(self.iter)
    # Per-frame tensors; filled by prepare_image.
    self.input_img_torch_batch,self.input_mask_torch_batch,self.zero123plus_imgs_torch_batch,self.zero123plus_masks_torch_batch=[],[],[],[]
    # load input data from cmdline
    if self.opt.input is not None:
        self.load_input(self.opt.input)
    # override prompt from cmdline
    if self.opt.prompt is not None:
        self.prompt = self.opt.prompt
    # override if provide a checkpoint
    self.renderer.initialize(num_pts=self.opt.num_pts)
    self.point_nums = []
    if self.gui:
        dpg.create_context()
        self.register_dpg()
        self.test_step()
def __del__(self):
    """Destroy the Dear PyGui context if the GUI was enabled."""
    if not self.gui:
        return
    dpg.destroy_context()
def seed_everything(self):
    """Seed Python hashing, NumPy and PyTorch from ``self.seed``.

    ``self.seed`` is used when it converts to an int; otherwise (for example
    the default ``"random"``) a seed is drawn from ``[0, 1000000)``. The seed
    actually used is stored in ``self.last_seed``.
    """
    try:
        seed = int(self.seed)
    except (TypeError, ValueError):
        # Only "not an integer" falls back to a random seed; unrelated errors
        # (e.g. KeyboardInterrupt) are no longer swallowed.
        seed = np.random.randint(0, 1000000)

    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # NOTE(review): benchmark mode stays enabled as before; PyTorch's
    # reproducibility notes recommend disabling it for deterministic runs.
    torch.backends.cudnn.benchmark = True

    self.last_seed = seed
def prepare_image(self,idx):
    """Convert the current frame's numpy inputs to tensors and cache embeddings.

    Reads ``self.input_img`` / ``self.input_mask`` and the six entries of
    ``self.input_imgs`` / ``self.input_masks``, resizes each to
    ``opt.ref_size``, appends the results to the ``*_batch`` lists and asks
    the Zero123 guidance to compute embeddings for them. ``idx`` is not used.
    """
    def as_batched_tensor(array):
        # HWC numpy array -> [1, C, H, W] tensor on the target device
        return torch.from_numpy(array).permute(2, 0, 1).unsqueeze(0).to(self.device)

    def to_ref_size(tensor):
        return F.interpolate(tensor, (self.opt.ref_size, self.opt.ref_size), mode="bilinear", align_corners=False)

    # reference view
    if self.input_img is not None:
        self.input_img_torch = to_ref_size(as_batched_tensor(self.input_img))
        self.input_mask_torch = to_ref_size(as_batched_tensor(self.input_mask))

    view_images = []
    view_masks = []
    self.zero123plus_imgs_torch = view_images
    self.zero123plus_masks_torch = view_masks

    # additional views
    if self.input_imgs is not None:
        for view in range(6):
            self.input_img2_torch = as_batched_tensor(self.input_imgs[view])
            view_images.append(to_ref_size(self.input_img2_torch))
            self.input_mask2_torch = as_batched_tensor(self.input_masks[view])
            view_masks.append(to_ref_size(self.input_mask2_torch))

    self.input_img_torch_batch.append(self.input_img_torch)
    self.input_mask_torch_batch.append(self.input_mask_torch)
    self.zero123plus_imgs_torch_batch.append(view_images)
    self.zero123plus_masks_torch_batch.append(view_masks)

    # prepare embeddings
    with torch.no_grad():
        self.guidance_zero123.get_img_embeds(self.input_img_torch, view_images)
def prepare_train(self):
    """Prepare optimiser, fixed camera, Zero123 guidance and per-frame inputs.

    Optionally resumes from ``opt.load_path`` (model and ply under
    ``<outdir>/<load_path><step>``), then sets up Gaussian training, builds
    the default reference camera, loads the Zero123 guidance model and runs
    ``prepare_image`` for every frame of the first data batch.
    """
    self.step = 0
    # Training runs for save_step + 1 steps (offset by the resumed step below).
    self.end_step = self.save_step+1
    ## given a load_path, load corresponding model
    if self.opt.load_path is not None:
        if self.opt.load_step is not None:
            self.step = self.opt.load_step
        else:
            #default loading save_step ply
            self.step = self.save_step
        auto_path = os.path.join(self.opt.outdir,self.opt.load_path + str(self.step))
        ply_path = os.path.join(auto_path,'model.ply')
        self.renderer.gaussians.load_model(auto_path)
        self.renderer.gaussians.load_ply(ply_path)
        self.end_step =self.step+self.end_step
    ## setup training
    self.renderer.gaussians.training_setup(self.opt)
    ## do not do progressive sh-level
    self.renderer.gaussians.active_sh_degree = self.renderer.gaussians.max_sh_degree
    self.optimizer = self.renderer.gaussians.optimizer
    # default camera
    pose = orbit_camera(self.opt.elevation, 0, self.opt.radius)
    self.fixed_cam = MiniCam(
        pose,
        self.opt.ref_size,
        self.opt.ref_size,
        self.cam.fovy,
        self.cam.fovx,
        self.cam.near,
        self.cam.far,
    )
    self.set_fix_cam()
    self.enable_sd = self.opt.lambda_sd > 0 and self.prompt != ""
    self.enable_zero123 = self.opt.lambda_zero123 > 0 and self.input_img is not None
    # The Zero123 guidance is loaded unconditionally, regardless of the
    # enable_* flags computed above.
    print(f"[INFO] loading zero123...")
    from guidance.zero123_4d_utils import Zero123
    self.guidance_zero123 = Zero123(self.device)
    print(f"[INFO] loaded zero123!")
    ## load multiview reference images
    for i in np.arange(len(self.ref_view_batch)):
        # prepare_image reads these four attributes for the current frame.
        self.input_img = self.ref_view_batch[i]
        self.input_mask = self.input_mask_batch[i]
        self.input_imgs = self.zero123_view_batch[i]
        self.input_masks = self.zero123_masks_batch[i]
        self.prepare_image(i)
def train_step(self):
starter = torch.cuda.Event(enable_timing=True)
ender = torch.cuda.Event(enable_timing=True)
starter.record()
torch.autograd.set_detect_anomaly(True)
for _ in range(self.train_steps):
if self.step╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮\n", "│ in <module>:6 │\n", "│ │\n", "│ 3 opt=OmegaConf.load('./configs/image_4d_m.yaml') │\n", "│ 4 │\n", "│ 5 train=trainer(opt) │\n", "│ ❱ 6 train.train() │\n", "│ 7 │\n", "│ │\n", "│ in train:134 │\n", "│ │\n", "│ 131 │ │ │ │ loss = loss + zero123_loss │\n", "│ 132 │ │ │ │\n", "│ 133 │ │ │ # optimize step │\n", "│ ❱ 134 │ │ │ loss.backward() │\n", "│ 135 │ │ │ self.optimizer.step() │\n", "│ 136 │ │ │ self.optimizer.zero_grad() │\n", "│ 137 │\n", "│ │\n", "│ /home/vision/miniconda3/envs/torch0/lib/python3.8/site-packages/torch/_tensor.py:487 in backward │\n", "│ │\n", "│ 484 │ │ │ │ create_graph=create_graph, │\n", "│ 485 │ │ │ │ inputs=inputs, │\n", "│ 486 │ │ │ ) │\n", "│ ❱ 487 │ │ torch.autograd.backward( │\n", "│ 488 │ │ │ self, gradient, retain_graph, create_graph, inputs=inputs │\n", "│ 489 │ │ ) │\n", "│ 490 │\n", "│ │\n", "│ /home/vision/miniconda3/envs/torch0/lib/python3.8/site-packages/torch/autograd/__init__.py:200 │\n", "│ in backward │\n", "│ │\n", "│ 197 │ # The reason we repeat same the comment below is that │\n", "│ 198 │ # some Python versions print out the first line of a multi-line function │\n", "│ 199 │ # calls in the traceback and some print out the last line │\n", "│ ❱ 200 │ Variable._execution_engine.run_backward( # Calls into the C++ engine to run the bac │\n", "│ 201 │ │ tensors, grad_tensors_, retain_graph, create_graph, inputs, │\n", "│ 202 │ │ allow_unreachable=True, accumulate_grad=True) # Calls into the C++ engine to ru │\n", "│ 203 │\n", "╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\n", "KeyboardInterrupt\n", "\n" ], "text/plain": [ "\u001b[31m╭─\u001b[0m\u001b[31m──────────────────────────────\u001b[0m\u001b[31m \u001b[0m\u001b[1;31mTraceback \u001b[0m\u001b[1;2;31m(most recent call last)\u001b[0m\u001b[31m 
\u001b[0m\u001b[31m───────────────────────────────\u001b[0m\u001b[31m─╮\u001b[0m\n", "\u001b[31m│\u001b[0m in \u001b[92m