Showing preview only (268K chars total). Download the full file or copy to clipboard to get everything.
Repository: zeng-yifei/STAG4D
Branch: main
Commit: 9aa21c92f40b
Files: 25
Total size: 257.0 KB
Directory structure:
gitextract_687zu6g9/
├── .gitignore
├── README.md
├── __init__.py
├── cam_utils.py
├── configs/
│ └── stag4d.yaml
├── dataset_4d.py
├── deform.py
├── gs_renderer_4d.py
├── guidance/
│ ├── zero123_4d_utils.py
│ └── zero123pp/
│ └── pipeline.py
├── main.py
├── mini_trainer.ipynb
├── requirements.txt
├── scripts/
│ ├── app.py
│ └── gen_mv.py
├── sh_utils.py
├── simple-knn/
│ ├── ext.cpp
│ ├── setup.py
│ ├── simple_knn/
│ │ └── .gitkeep
│ ├── simple_knn.cu
│ ├── simple_knn.h
│ ├── spatial.cu
│ └── spatial.h
├── visualize.py
└── zero123.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
*.pyc
./valid/*
================================================
FILE: README.md
================================================
<h1>[ECCV 2024] STAG4D: Spatial-Temporal Anchored Generative 4D Gaussians</h1>
<div>
<a href='https://github.com/zeng-yifei?tab=repositories/' target='_blank'>Yifei Zeng</a><sup>1</sup> 
<a href="https://github.com/yanqinJiang" target='_blank'>Yanqin Jiang*</a><sup>2</sup> 
<a href="https://sites.google.com/site/zhusiyucs/home/" target='_blank'>Siyu Zhu</a><sup>3</sup> 
<a href='https://github.com/YuanxunLu' target='_blank'>Yuanxun Lu</a><sup>1</sup> 
<a href="https://linyou.github.io/">Youtian Lin</a><sup>1</sup> 
<a href='https://zhuhao-nju.github.io/home/' target='_blank'>Hao Zhu</a><sup>1</sup> 
<a href="https://people.ucas.ac.cn/~huweiming">Weiming Hu</a><sup>2</sup> 
<a href='https://cite.nju.edu.cn/People/Faculty/20190621/i5054.html' target='_blank'>Xun Cao</a><sup>1</sup> 
<a href='https://yoyo000.github.io/' target='_blank'>Yao Yao</a><sup>1+</sup> 
</div>
<div>
<sup>1</sup>Nanjing University
<sup>2</sup>CASIA
<sup>3</sup>Fudan University
</div>
<div>
<sup>*</sup>equal contribution
<sup>+</sup>corresponding author
</div>
<h4 align="center">
<a href="https://nju-3dv.github.io/projects/STAG4D/" target='_blank'>[Project Page]</a> •
</h4>
# Update
7.4 Our paper has been accepted by ECCV 2024. Congrats!
6.20: IMPORTANT. Fix the bug caused by new version of diff_gauss. Newest version of diff_gauss use `color, depth, norm, alpha, radii, extra` as an output. However, previous version use `color, depth, alpha, radii` as an output. Using older version of this code will cause mismatch error and may misuse normal for the alpha loss, resulting in bad results.
5.26: Update Text/Image to 4D data below.
5.21: Fix RGB loss into the batch loop. Add visualize code.
# ⚙️ Installation
```bash
pip install -r requirements.txt
git clone --recursive https://github.com/slothfulxtx/diff-gaussian-rasterization.git
pip install ./diff-gaussian-rasterization
pip install ./simple-knn
```
# Video-to-4D
To generate the examples in the project page, you can download the dataset from [google drive](https://drive.google.com/file/d/1YDvhBv6z5SByF_WaTQVzzL9qz3TyEm6a/view?usp=sharing). Place them in the dataset folder, and run:
```bash
python main.py --config configs/stag4d.yaml path=dataset/minions save_path=minions
#use --gui=True to turn on the visualizer (recommend)
python main.py --config configs/stag4d.yaml path=dataset/minions save_path=minions gui=True
```
To generate the spatial-temporal consistent data from stratch, your should place your rgba data in the form of
```
├── dataset
│ | your_data
│ ├── 0_rgba.png
│ ├── 1_rgba.png
│ ├── 2_rgba.png
│ ├── ...
```
and then run
```bash
python scripts/gen_mv.py --path dataset/your_data --pipeline_path xxx/guidance/zero123pp
python main.py --config configs/stag4d.yaml path=data_path save_path=saving_path gui=True
```
To visualize the result, use you can replace the main.py with visualize.py, and the result will be saved to the valid/xxx path, e.g.:
```bash
python visualize.py --config configs/stag4d.yaml path=dataset/minions save_path=minions
```
<img src='assets/videoto4d.gif' height='60%'>
# Text-to-4D
For Text to 4D generation, we recommend using SDXL and SVD to generate a reasonable video. Then, after matting the video, use
the command above to generate a good 4D result. (This pipeline contains many independent parts and is kind of complex, so we may upload the whole workflow after integration if possible.)
If you want generate the examples in the paper, I also updated the corresponding data here in [google drive](https://drive.google.com/file/d/1EDNL7EBMR1vlfMOABdXjHzcKY7IXdcnj/view?usp=sharing). Remember to set size to 26 in config or use `size=26` in the command:
```bash
python main.py --config configs/stag4d.yaml path=dataset/xxx save_path=xxx size=26
```
<img src='assets/textto4d3.gif' height='60%'>
# Tips for better quality
If you want sacrifice time for better quality, here is some tips you can try to further improve the generated quality.
1, Use larger batch size.
2, Run for more steps.
## Citation
If you find our work useful for your research, please consider citing our paper as well as Consistent4D:
```
@article{zeng2024stag4d,
title={STAG4D: Spatial-Temporal Anchored Generative 4D Gaussians},
author={Yifei Zeng and Yanqin Jiang and Siyu Zhu and Yuanxun Lu and Youtian Lin and Hao Zhu and Weiming Hu and Xun Cao and Yao Yao},
year={2024},
eprint={2403.14939},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
@article{jiang2023consistent4d,
title={Consistent4D: Consistent 360{\deg} Dynamic Object Generation from Monocular Video},
author={Yanqin Jiang and Li Zhang and Jin Gao and Weimin Hu and Yao Yao},
year={2023},
eprint={2311.02848},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```
# Acknowledgment
This repo is built on [DreamGaussian](https://github.com/dreamgaussian/dreamgaussian) and [Zero123plus](https://github.com/SUDO-AI-3D/zero123plus). Thank all the authors for their great work.
================================================
FILE: __init__.py
================================================
================================================
FILE: cam_utils.py
================================================
import numpy as np
from scipy.spatial.transform import Rotation as R
import torch
def dot(x, y):
if isinstance(x, np.ndarray):
return np.sum(x * y, -1, keepdims=True)
else:
return torch.sum(x * y, -1, keepdim=True)
def length(x, eps=1e-20):
if isinstance(x, np.ndarray):
return np.sqrt(np.maximum(np.sum(x * x, axis=-1, keepdims=True), eps))
else:
return torch.sqrt(torch.clamp(dot(x, x), min=eps))
def safe_normalize(x, eps=1e-20):
return x / length(x, eps)
def look_at(campos, target, opengl=True):
# campos: [N, 3], camera/eye position
# target: [N, 3], object to look at
# return: [N, 3, 3], rotation matrix
if not opengl:
# camera forward aligns with -z
forward_vector = safe_normalize(target - campos)
up_vector = np.array([0, 1, 0], dtype=np.float32)
right_vector = safe_normalize(np.cross(forward_vector, up_vector))
up_vector = safe_normalize(np.cross(right_vector, forward_vector))
else:
# camera forward aligns with +z
forward_vector = safe_normalize(campos - target)
up_vector = np.array([0, 1, 0], dtype=np.float32)
right_vector = safe_normalize(np.cross(up_vector, forward_vector))
up_vector = safe_normalize(np.cross(forward_vector, right_vector))
R = np.stack([right_vector, up_vector, forward_vector], axis=1)
return R
# elevation & azimuth to pose (cam2world) matrix
def orbit_camera(elevation, azimuth, radius=1, is_degree=True, target=None, opengl=True):
# radius: scalar
# elevation: scalar, in (-90, 90), from +y to -y is (-90, 90)
# azimuth: scalar, in (-180, 180), from +z to +x is (0, 90)
# return: [4, 4], camera pose matrix
if is_degree:
elevation = np.deg2rad(np.array(elevation))
azimuth = np.deg2rad(np.array(azimuth))
x = radius * np.cos(elevation) * np.sin(azimuth)
y = - radius * np.sin(elevation)
z = radius * np.cos(elevation) * np.cos(azimuth)
if target is None:
target = np.zeros([3], dtype=np.float32)
campos = np.array([x, y, z]) + target # [3]
T = np.eye(4, dtype=np.float32)
T[:3, :3] = look_at(campos, target, opengl)
T[:3, 3] = campos
return T
class OrbitCamera:
def __init__(self, W, H, r=2, fovy=60, near=0.01, far=100):
self.W = W
self.H = H
self.radius = r # camera distance from center
self.fovy = np.deg2rad(fovy) # deg 2 rad
self.near = near
self.far = far
self.center = np.array([0, 0, 0], dtype=np.float32) # look at this point
self.rot = R.from_matrix(np.eye(3))
self.up = np.array([0, 1, 0], dtype=np.float32) # need to be normalized!
@property
def fovx(self):
return 2 * np.arctan(np.tan(self.fovy / 2) * self.W / self.H)
@property
def campos(self):
return self.pose[:3, 3]
# pose (c2w)
@property
def pose(self):
# first move camera to radius
res = np.eye(4, dtype=np.float32)
res[2, 3] = self.radius # opengl convention...
# rotate
rot = np.eye(4, dtype=np.float32)
rot[:3, :3] = self.rot.as_matrix()
res = rot @ res
# translate
res[:3, 3] -= self.center
return res
# view (w2c)
@property
def view(self):
return np.linalg.inv(self.pose)
# projection (perspective)
@property
def perspective(self):
y = np.tan(self.fovy / 2)
aspect = self.W / self.H
return np.array(
[
[1 / (y * aspect), 0, 0, 0],
[0, -1 / y, 0, 0],
[
0,
0,
-(self.far + self.near) / (self.far - self.near),
-(2 * self.far * self.near) / (self.far - self.near),
],
[0, 0, -1, 0],
],
dtype=np.float32,
)
# intrinsics
@property
def intrinsics(self):
focal = self.H / (2 * np.tan(self.fovy / 2))
return np.array([focal, focal, self.W // 2, self.H // 2], dtype=np.float32)
@property
def mvp(self):
return self.perspective @ np.linalg.inv(self.pose) # [4, 4]
def orbit(self, dx, dy):
# rotate along camera up/side axis!
side = self.rot.as_matrix()[:3, 0]
rotvec_x = self.up * np.radians(-0.05 * dx)
rotvec_y = side * np.radians(-0.05 * dy)
self.rot = R.from_rotvec(rotvec_x) * R.from_rotvec(rotvec_y) * self.rot
def scale(self, delta):
self.radius *= 1.1 ** (-delta)
def pan(self, dx, dy, dz=0):
# pan in camera coordinate system (careful on the sensitivity!)
self.center += 0.0005 * self.rot.as_matrix()[:3, :3] @ np.array([-dx, -dy, dz])
================================================
FILE: configs/stag4d.yaml
================================================
### Input
# input rgba image path (default to None, can be load in GUI too)
input:
# input text prompt (default to None, can be input in GUI too)
prompt: a minion
# input mesh for stage 2 (auto-search from stage 1 output path if None)
mesh:
# estimated elevation angle for input image
elevation: 0
# reference image resolution
ref_size: 512
# density thresh for mesh extraction
density_thresh: 1
device: cuda
#dynamic
size: 30
path: dataset/minions
# checkpoint to load for stage 1 (should be a ply file)
load:
### Output
outdir: logs
mesh_format: obj
save_path: ???
save_step: 8000
#checkpoint to load for stage fine (should be a path of ply with deform pth)
load_path:
load_step:
valid_interval: 500
### Training
# guidance loss weights (0 to disable)
lambda_sd: 0
mvdream: False
lambda_zero123: 1
lambda_tv: 1
scale_loss_ratio: 7.5
imagedream: False
# training batch size per iter
batch_size: 4
# training iterations for stage 1
iters: 2000
# training iterations for stage 2
iters_refine: 50
# training camera radius
radius: 2
# training camera fovy
fovy: 49.1 # align with zero123 rendering setting (ref: https://github.com/cvlab-columbia/zero123/blob/main/objaverse-rendering/scripts/blender_script.py#L61
# whether allow geom training in stage 2
train_geo: False
# prob to invert background color during training (0 = always black, 1 = always white)
invert_bg_prob: 0.5
### GUI
gui: False
force_cuda_rast: False
# GUI resolution
H: 800
W: 800
deformation_lr_init : 0.00016
deformation_lr_final : 0.000016
deformation_lr_delay_mult : 0.02
grid_lr_init : 0.0016
grid_lr_final : 0.00016
### Gaussian splatting
num_pts: 10000
sh_degree: 0
position_lr_init : 0.0002
position_lr_final : 0.000002
position_lr_delay_mult: 0.01
position_lr_max_steps: 2000
position_lr_max_steps2: 5000
feature_lr: 0.005
opacity_lr: 0.02
scaling_lr: 0.01
rotation_lr: 0.002
init_steps: 700
percent_dense: 0.1
density_start_iter: 1200
density_end_iter: 6000
densification_interval: 100
opacity_reset_interval: 700
densify_grad_threshold_percent: 0.025
time_smoothness_weight: 5
plane_tv_weight: 0.05
l1_time_planes: 0.05
### Textured Mesh
geom_lr: 0.0001
texture_lr: 0.2
================================================
FILE: dataset_4d.py
================================================
import os
import cv2
import glob
import json
import tqdm
import random
import numpy as np
from scipy.spatial.transform import Slerp, Rotation
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import rembg
import glob
class SparseDataset:
def __init__(self, opt, size,device='cuda', type='train', H=256, W=256):
super().__init__()
self.opt = opt
self.device = device
self.type = type # train, val, test
self.size = size
self.H = H
self.W = W
self.path = opt.path
self.cx = self.H / 2
self.cy = self.W / 2
self.bg_remover=None
def collate_ref(self,index):
#print(index,str(index))
file = os.path.join(self.path,'ref/{}_rgba.png'.format(str(index)))
#print(f'[INFO] load image from {file}...')
img = cv2.imread(file, cv2.IMREAD_UNCHANGED)
if img.shape[-1] == 3:
if self.bg_remover is None:
self.bg_remover = rembg.new_session()
img = rembg.remove(img, session=self.bg_remover)
img = cv2.resize(img, (self.W, self.H), interpolation=cv2.INTER_AREA)
img = img.astype(np.float32) / 255.0
self.input_mask = img[..., 3:]
# white bg
self.input_img = img[..., :3] * self.input_mask + (1 - self.input_mask)
# bgr to rgb
self.input_img = self.input_img[..., ::-1].copy()
return self.input_img ,self.input_mask
def collate_zero123(self,index):
self.pattern=os.path.join(self.path,'zero123/{}_rgba/*.png'.format(str(index)))
self.input_imgs=[]
self.input_masks=[]
file_list = glob.glob(self.pattern)
#print(self.pattern,file_list)
for files in sorted(file_list):
#print(f'[INFO] load image from {files}...')
img = cv2.imread(files, cv2.IMREAD_UNCHANGED)
if img.shape[-1] == 3:
if self.bg_remover is None:
self.bg_remover = rembg.new_session()
img = rembg.remove(img, session=self.bg_remover)
img = cv2.resize(img, (self.W, self.H), interpolation=cv2.INTER_AREA)
img = img.astype(np.float32) / 255.0
self.input_mask = img[..., 3:]
# white bg
self.input_img = img[..., :3] * self.input_mask + (1 - self.input_mask)
# bgr to rgb
self.input_img = self.input_img[..., ::-1].copy()
self.input_imgs.append(self.input_img)
self.input_masks.append(self.input_mask)
return self.input_imgs, self.input_masks
def collate(self, index):
ref_view_batch,input_mask_batch,zero123_view_batch,zero123_masks_batch = [],[],[],[]
for index in np.arange(self.size):
ref_view,input_mask = self.collate_ref(index)
zero123_view,zero123_masks = self.collate_zero123(index)
ref_view_batch.append(ref_view)
input_mask_batch.append(input_mask)
zero123_view_batch.append(zero123_view)
zero123_masks_batch.append(zero123_masks)
return ref_view_batch, input_mask_batch,zero123_view_batch,zero123_masks_batch
def dataloader(self):
loader = DataLoader(list(range(self.size)), batch_size=1, collate_fn=self.collate,shuffle=False, num_workers=0)
return loader
def dataloader_d(self):
loader = DataLoader(list(range(self.size)), batch_size=1, collate_fn=self.collate_d,shuffle=False, num_workers=0)
return loader
================================================
FILE: deform.py
================================================
import functools
import math
import os
import time
from tkinter import W
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.cpp_extension import load
import torch.nn.init as init
import abc
import itertools
import logging as log
from typing import Optional, Union, List, Dict, Sequence, Iterable, Collection, Callable
import torch
import torch.nn as nn
import torch.nn.functional as F
class Deformation(nn.Module):
def __init__(self, D=8, W=256, input_ch=27, input_ch_time=9, skips=[], args=None):
super(Deformation, self).__init__()
self.D = D
self.W = W
self.input_ch = input_ch
self.input_ch_time = input_ch_time
self.skips = skips
self.no_grid=False
self.no_ds=False
self.no_dr=False
self.no_do=True
self.bounds = 1.6
self.kplanes_config = {
'grid_dimensions': 2,
'input_coordinate_dim': 4,
'output_coordinate_dim': 32,
'resolution': [64, 64, 64, 25]
}
self.multires = [1, 2, 4, 8]
self.no_grid = self.no_grid
self.grid = HexPlaneField(self.bounds, self.kplanes_config, self.multires)
self.pos_deform, self.scales_deform, self.rotations_deform, self.opacity_deform = self.create_net()
def create_net(self):
mlp_out_dim = 0
if self.no_grid:
self.feature_out = [nn.Linear(4,self.W)]
else:
self.feature_out = [nn.Linear(mlp_out_dim + self.grid.feat_dim ,self.W)]
for i in range(self.D-1):
self.feature_out.append(nn.ReLU())
self.feature_out.append(nn.Linear(self.W,self.W))
self.feature_out = nn.Sequential(*self.feature_out)
output_dim = self.W
return \
nn.Sequential(nn.ReLU(),nn.Linear(self.W,self.W),nn.ReLU(),nn.Linear(self.W, 3)),\
nn.Sequential(nn.ReLU(),nn.Linear(self.W,self.W),nn.ReLU(),nn.Linear(self.W, 3)),\
nn.Sequential(nn.ReLU(),nn.Linear(self.W,self.W),nn.ReLU(),nn.Linear(self.W, 4)), \
nn.Sequential(nn.ReLU(),nn.Linear(self.W,self.W),nn.ReLU(),nn.Linear(self.W, 1))
def query_time(self, rays_pts_emb, scales_emb, rotations_emb, time_emb):
if self.no_grid:
h = torch.cat([rays_pts_emb[:,:3],time_emb[:,:1]],-1)
else:
grid_feature = self.grid(rays_pts_emb[:,:3], time_emb[:,:1])
h = grid_feature
h = self.feature_out(h)
return h
def forward(self, rays_pts_emb, scales_emb=None, rotations_emb=None, opacity = None, time_emb=None):
if time_emb is None:
return self.forward_static(rays_pts_emb[:,:3])
else:
return self.forward_dynamic(rays_pts_emb, scales_emb, rotations_emb, opacity, time_emb)
def forward_static(self, rays_pts_emb):
grid_feature = self.grid(rays_pts_emb[:,:3])
dx = self.static_mlp(grid_feature)
return rays_pts_emb[:, :3] + dx
def forward_dynamic(self,rays_pts_emb, scales_emb, rotations_emb, opacity_emb, time_emb):
hidden = self.query_time(rays_pts_emb, scales_emb, rotations_emb, time_emb).float()
dx = self.pos_deform(hidden)
pts = rays_pts_emb[:, :3] + dx
if self.no_ds:
scales = scales_emb[:,:3]
else:
ds = self.scales_deform(hidden)
scales = scales_emb[:,:3] + ds
if self.no_dr:
rotations = rotations_emb[:,:4]
else:
dr = self.rotations_deform(hidden)
rotations = rotations_emb[:,:4] + dr
if self.no_do:
opacity = opacity_emb[:,:1]
else:
do = self.opacity_deform(hidden)
opacity = opacity_emb[:,:1] + do
# + do
# print("deformation value:","pts:",torch.abs(dx).mean(),"rotation:",torch.abs(dr).mean())
return pts, scales, rotations, opacity
def get_mlp_parameters(self):
parameter_list = []
for name, param in self.named_parameters():
if "grid" not in name:
parameter_list.append(param)
return parameter_list
def get_grid_parameters(self):
return list(self.grid.parameters() )
# + list(self.timegrid.parameters())
class deform_network(nn.Module):
def __init__(self) :
super(deform_network, self).__init__()
net_width = 64
timebase_pe = 4
defor_depth= 1
posbase_pe= 10
scale_rotation_pe = 2
opacity_pe = 2
timenet_width = 64
timenet_output = 32
times_ch = 2*timebase_pe+1
self.timenet = nn.Sequential(
nn.Linear(times_ch, timenet_width), nn.ReLU(),
nn.Linear(timenet_width, timenet_output))
self.deformation_net = Deformation(W=net_width, D=defor_depth, input_ch=(4+3)+((4+3)*scale_rotation_pe)*2, input_ch_time=timenet_output, args=None)
self.register_buffer('time_poc', torch.FloatTensor([(2**i) for i in range(timebase_pe)]))
self.register_buffer('pos_poc', torch.FloatTensor([(2**i) for i in range(posbase_pe)]))
self.register_buffer('rotation_scaling_poc', torch.FloatTensor([(2**i) for i in range(scale_rotation_pe)]))
self.register_buffer('opacity_poc', torch.FloatTensor([(2**i) for i in range(opacity_pe)]))
self.apply(initialize_weights)
# print(self)
def forward(self, point, scales=None, rotations=None, opacity=None, times_sel=None):
if times_sel is not None:
return self.forward_dynamic(point, scales, rotations, opacity, times_sel)
else:
return self.forward_static(point)
def forward_static(self, points):
points = self.deformation_net(points)
return points
def forward_dynamic(self, point, scales=None, rotations=None, opacity=None, times_sel=None):
# times_emb = poc_fre(times_sel, self.time_poc)
means3D, scales, rotations, opacity = self.deformation_net( point,
scales,
rotations,
opacity,
# times_feature,
times_sel)
return means3D, scales, rotations, opacity
def get_mlp_parameters(self):
return self.deformation_net.get_mlp_parameters() + list(self.timenet.parameters())
def get_grid_parameters(self):
return self.deformation_net.get_grid_parameters()
def initialize_weights(m):
if isinstance(m, nn.Linear):
# init.constant_(m.weight, 0)
init.xavier_uniform_(m.weight,gain=1)
if m.bias is not None:
init.xavier_uniform_(m.weight,gain=1)
# init.constant_(m.bias, 0)
def get_normalized_directions(directions):
"""SH encoding must be in the range [0, 1]
Args:
directions: batch of directions
"""
return (directions + 1.0) / 2.0
def normalize_aabb(pts, aabb):
return (pts - aabb[0]) * (2.0 / (aabb[1] - aabb[0])) - 1.0
def grid_sample_wrapper(grid: torch.Tensor, coords: torch.Tensor, align_corners: bool = True) -> torch.Tensor:
grid_dim = coords.shape[-1]
if grid.dim() == grid_dim + 1:
# no batch dimension present, need to add it
grid = grid.unsqueeze(0)
if coords.dim() == 2:
coords = coords.unsqueeze(0)
if grid_dim == 2 or grid_dim == 3:
grid_sampler = F.grid_sample
else:
raise NotImplementedError(f"Grid-sample was called with {grid_dim}D data but is only "
f"implemented for 2 and 3D data.")
coords = coords.view([coords.shape[0]] + [1] * (grid_dim - 1) + list(coords.shape[1:]))
B, feature_dim = grid.shape[:2]
n = coords.shape[-2]
interp = grid_sampler(
grid, # [B, feature_dim, reso, ...]
coords, # [B, 1, ..., n, grid_dim]
align_corners=align_corners,
mode='bilinear', padding_mode='border')
interp = interp.view(B, feature_dim, n).transpose(-1, -2) # [B, n, feature_dim]
interp = interp.squeeze() # [B?, n, feature_dim?]
return interp
def init_grid_param(
grid_nd: int,
in_dim: int,
out_dim: int,
reso: Sequence[int],
a: float = 0.1,
b: float = 0.5):
assert in_dim == len(reso), "Resolution must have same number of elements as input-dimension"
has_time_planes = in_dim == 4
assert grid_nd <= in_dim
coo_combs = list(itertools.combinations(range(in_dim), grid_nd))
grid_coefs = nn.ParameterList()
for ci, coo_comb in enumerate(coo_combs):
new_grid_coef = nn.Parameter(torch.empty(
[1, out_dim] + [reso[cc] for cc in coo_comb[::-1]]
))
if has_time_planes and 3 in coo_comb: # Initialize time planes to 1
nn.init.ones_(new_grid_coef)
else:
nn.init.uniform_(new_grid_coef, a=a, b=b)
grid_coefs.append(new_grid_coef)
return grid_coefs
def interpolate_ms_features(pts: torch.Tensor,
ms_grids: Collection[Iterable[nn.Module]],
grid_dimensions: int,
concat_features: bool,
num_levels: Optional[int],
) -> torch.Tensor:
coo_combs = list(itertools.combinations(
range(pts.shape[-1]), grid_dimensions)
)
if num_levels is None:
num_levels = len(ms_grids)
multi_scale_interp = [] if concat_features else 0.
grid: nn.ParameterList
for scale_id, grid in enumerate(ms_grids[:num_levels]):
interp_space = 1.
for ci, coo_comb in enumerate(coo_combs):
# interpolate in plane
feature_dim = grid[ci].shape[1] # shape of grid[ci]: 1, out_dim, *reso
interp_out_plane = (
grid_sample_wrapper(grid[ci], pts[..., coo_comb])
.view(-1, feature_dim)
)
# compute product over planes
interp_space = interp_space * interp_out_plane
# combine over scales
if concat_features:
multi_scale_interp.append(interp_space)
else:
multi_scale_interp = multi_scale_interp + interp_space
if concat_features:
multi_scale_interp = torch.cat(multi_scale_interp, dim=-1)
return multi_scale_interp
class HexPlaneField(nn.Module):
def __init__(
self,
bounds,
planeconfig,
multires
) -> None:
super().__init__()
aabb = torch.tensor([[bounds,bounds,bounds],
[-bounds,-bounds,-bounds]])
self.aabb = nn.Parameter(aabb, requires_grad=False)
self.grid_config = [planeconfig]
self.multiscale_res_multipliers = multires
self.concat_features = True
# 1. Init planes
self.grids = nn.ModuleList()
self.feat_dim = 0
for res in self.multiscale_res_multipliers:
# initialize coordinate grid
config = self.grid_config[0].copy()
# Resolution fix: multi-res only on spatial planes
config["resolution"] = [
r * res for r in config["resolution"][:3]
] + config["resolution"][3:]
gp = init_grid_param(
grid_nd=config["grid_dimensions"],
in_dim=config["input_coordinate_dim"],
out_dim=config["output_coordinate_dim"],
reso=config["resolution"],
)
# shape[1] is out-dim - Concatenate over feature len for each scale
if self.concat_features:
self.feat_dim += gp[-1].shape[1]
else:
self.feat_dim = gp[-1].shape[1]
self.grids.append(gp)
# print(f"Initialized model grids: {self.grids}")
print("feature_dim:",self.feat_dim)
def set_aabb(self,xyz_max, xyz_min):
aabb = torch.tensor([
xyz_max,
xyz_min
])
self.aabb = nn.Parameter(aabb,requires_grad=True)
print("Voxel Plane: set aabb=",self.aabb)
def get_density(self, pts: torch.Tensor, timestamps: Optional[torch.Tensor] = None):
"""Computes and returns the densities."""
pts = normalize_aabb(pts, self.aabb)
pts = torch.cat((pts, timestamps), dim=-1) # [n_rays, n_samples, 4]
pts = pts.reshape(-1, pts.shape[-1])
features = interpolate_ms_features(
pts, ms_grids=self.grids, # noqa
grid_dimensions=self.grid_config[0]["grid_dimensions"],
concat_features=self.concat_features, num_levels=None)
if len(features) < 1:
features = torch.zeros((0, 1)).to(features.device)
return features
def forward(self,
pts: torch.Tensor,
timestamps: Optional[torch.Tensor] = None):
features = self.get_density(pts, timestamps)
return features
def compute_plane_tv(t):
batch_size, c, h, w = t.shape
count_h = batch_size * c * (h - 1) * w
count_w = batch_size * c * h * (w - 1)
h_tv = torch.square(t[..., 1:, :] - t[..., :h-1, :]).sum()
w_tv = torch.square(t[..., :, 1:] - t[..., :, :w-1]).sum()
return 2 * (h_tv / count_h + w_tv / count_w) # This is summing over batch and c instead of avg
def compute_plane_smoothness(t):
batch_size, c, h, w = t.shape
# Convolve with a second derivative filter, in the time dimension which is dimension 2
first_difference = t[..., 1:, :] - t[..., :h-1, :] # [batch, c, h-1, w]
second_difference = first_difference[..., 1:, :] - first_difference[..., :h-2, :] # [batch, c, h-2, w]
# Take the L2 norm of the result
return torch.square(torch.abs(second_difference)).mean()
class Regularizer():
def __init__(self, reg_type, initialization):
self.reg_type = reg_type
self.initialization = initialization
self.weight = float(self.initialization)
self.last_reg = None
def step(self, global_step):
pass
def report(self, d):
if self.last_reg is not None:
d[self.reg_type].update(self.last_reg.item())
def regularize(self, *args, **kwargs) -> torch.Tensor:
out = self._regularize(*args, **kwargs) * self.weight
self.last_reg = out.detach()
return out
@abc.abstractmethod
def _regularize(self, *args, **kwargs) -> torch.Tensor:
raise NotImplementedError()
def __str__(self):
return f"Regularizer({self.reg_type}, weight={self.weight})"
class PlaneTV(Regularizer):
def __init__(self, initial_value, what: str = 'field'):
if what not in {'field', 'proposal_network'}:
raise ValueError(f'what must be one of "field" or "proposal_network" '
f'but {what} was passed.')
name = f'planeTV-{what[:2]}'
super().__init__(name, initial_value)
self.what = what
def step(self, global_step):
pass
def _regularize(self, model, **kwargs):
multi_res_grids: Sequence[nn.ParameterList]
if self.what == 'field':
multi_res_grids = model.field.grids
elif self.what == 'proposal_network':
multi_res_grids = [p.grids for p in model.proposal_networks]
else:
raise NotImplementedError(self.what)
total = 0
# Note: input to compute_plane_tv should be of shape [batch_size, c, h, w]
for grids in multi_res_grids:
if len(grids) == 3:
spatial_grids = [0, 1, 2]
else:
spatial_grids = [0, 1, 3] # These are the spatial grids; the others are spatiotemporal
for grid_id in spatial_grids:
total += compute_plane_tv(grids[grid_id])
for grid in grids:
# grid: [1, c, h, w]
total += compute_plane_tv(grid)
return total
class TimeSmoothness(Regularizer):
def __init__(self, initial_value, what: str = 'field'):
if what not in {'field', 'proposal_network'}:
raise ValueError(f'what must be one of "field" or "proposal_network" '
f'but {what} was passed.')
name = f'time-smooth-{what[:2]}'
super().__init__(name, initial_value)
self.what = what
def _regularize(self, model, **kwargs) -> torch.Tensor:
multi_res_grids: Sequence[nn.ParameterList]
if self.what == 'field':
multi_res_grids = model.field.grids
elif self.what == 'proposal_network':
multi_res_grids = [p.grids for p in model.proposal_networks]
else:
raise NotImplementedError(self.what)
total = 0
# model.grids is 6 x [1, rank * F_dim, reso, reso]
for grids in multi_res_grids:
if len(grids) == 3:
time_grids = []
else:
time_grids = [2, 4, 5]
for grid_id in time_grids:
total += compute_plane_smoothness(grids[grid_id])
return torch.as_tensor(total)
class L1ProposalNetwork(Regularizer):
def __init__(self, initial_value):
super().__init__('l1-proposal-network', initial_value)
def _regularize(self, model, **kwargs) -> torch.Tensor:
grids = [p.grids for p in model.proposal_networks]
total = 0.0
for pn_grids in grids:
for grid in pn_grids:
total += torch.abs(grid).mean()
return torch.as_tensor(total)
class DepthTV(Regularizer):
def __init__(self, initial_value):
super().__init__('tv-depth', initial_value)
def _regularize(self, model, model_out, **kwargs) -> torch.Tensor:
depth = model_out['depth']
tv = compute_plane_tv(
depth.reshape(64, 64)[None, None, :, :]
)
return tv
class L1TimePlanes(Regularizer):
def __init__(self, initial_value, what='field'):
if what not in {'field', 'proposal_network'}:
raise ValueError(f'what must be one of "field" or "proposal_network" '
f'but {what} was passed.')
super().__init__(f'l1-time-{what[:2]}', initial_value)
self.what = what
def _regularize(self, model, **kwargs) -> torch.Tensor:
# model.grids is 6 x [1, rank * F_dim, reso, reso]
multi_res_grids: Sequence[nn.ParameterList]
if self.what == 'field':
multi_res_grids = model.field.grids
elif self.what == 'proposal_network':
multi_res_grids = [p.grids for p in model.proposal_networks]
else:
raise NotImplementedError(self.what)
total = 0.0
for grids in multi_res_grids:
if len(grids) == 3:
continue
else:
# These are the spatiotemporal grids
spatiotemporal_grids = [2, 4, 5]
for grid_id in spatiotemporal_grids:
total += torch.abs(1 - grids[grid_id]).mean()
return torch.as_tensor(total)
================================================
FILE: gs_renderer_4d.py
================================================
import os
import math
import numpy as np
from typing import NamedTuple
from plyfile import PlyData, PlyElement
import torch
from torch import nn
from diff_gauss import (
GaussianRasterizationSettings,
GaussianRasterizer,
)
from simple_knn._C import distCUDA2
from sh_utils import eval_sh, SH2RGB, RGB2SH
from deform import *
def inverse_sigmoid(x):
return torch.log(x/(1-x))
def get_expon_lr_func(
lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
):
def helper(step):
if lr_init == lr_final:
# constant lr, ignore other params
return lr_init
if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
# Disable this parameter
return 0.0
if lr_delay_steps > 0:
# A kind of reverse cosine decay.
delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)
)
else:
delay_rate = 1.0
t = np.clip(step / max_steps, 0, 1)
log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
return delay_rate * log_lerp
return helper
def strip_lowerdiag(L):
uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda")
uncertainty[:, 0] = L[:, 0, 0]
uncertainty[:, 1] = L[:, 0, 1]
uncertainty[:, 2] = L[:, 0, 2]
uncertainty[:, 3] = L[:, 1, 1]
uncertainty[:, 4] = L[:, 1, 2]
uncertainty[:, 5] = L[:, 2, 2]
return uncertainty
def strip_symmetric(sym):
return strip_lowerdiag(sym)
def gaussian_3d_coeff(xyzs, covs):
# xyzs: [N, 3]
# covs: [N, 6]
x, y, z = xyzs[:, 0], xyzs[:, 1], xyzs[:, 2]
a, b, c, d, e, f = covs[:, 0], covs[:, 1], covs[:, 2], covs[:, 3], covs[:, 4], covs[:, 5]
# eps must be small enough !!!
inv_det = 1 / (a * d * f + 2 * e * c * b - e**2 * a - c**2 * d - b**2 * f + 1e-24)
inv_a = (d * f - e**2) * inv_det
inv_b = (e * c - b * f) * inv_det
inv_c = (e * b - c * d) * inv_det
inv_d = (a * f - c**2) * inv_det
inv_e = (b * c - e * a) * inv_det
inv_f = (a * d - b**2) * inv_det
power = -0.5 * (x**2 * inv_a + y**2 * inv_d + z**2 * inv_f) - x * y * inv_b - x * z * inv_c - y * z * inv_e
power[power > 0] = -1e10 # abnormal values... make weights 0
return torch.exp(power)
def build_rotation(r):
norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3])
q = r / norm[:, None]
R = torch.zeros((q.size(0), 3, 3), device='cuda')
r = q[:, 0]
x = q[:, 1]
y = q[:, 2]
z = q[:, 3]
R[:, 0, 0] = 1 - 2 * (y*y + z*z)
R[:, 0, 1] = 2 * (x*y - r*z)
R[:, 0, 2] = 2 * (x*z + r*y)
R[:, 1, 0] = 2 * (x*y + r*z)
R[:, 1, 1] = 1 - 2 * (x*x + z*z)
R[:, 1, 2] = 2 * (y*z - r*x)
R[:, 2, 0] = 2 * (x*z - r*y)
R[:, 2, 1] = 2 * (y*z + r*x)
R[:, 2, 2] = 1 - 2 * (x*x + y*y)
return R
def build_scaling_rotation(s, r):
L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
R = build_rotation(r)
L[:,0,0] = s[:,0]
L[:,1,1] = s[:,1]
L[:,2,2] = s[:,2]
L = R @ L
return L
class BasicPointCloud(NamedTuple):
points: np.array
colors: np.array
normals: np.array
class GaussianModel:
def setup_functions(self):
def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
L = build_scaling_rotation(scaling_modifier * scaling, rotation)
actual_covariance = L @ L.transpose(1, 2)
symm = strip_symmetric(actual_covariance)
return symm
self.scaling_activation = torch.exp
self.scaling_inverse_activation = torch.log
self.covariance_activation = build_covariance_from_scaling_rotation
self.opacity_activation = torch.sigmoid
self.inverse_opacity_activation = inverse_sigmoid
self.rotation_activation = torch.nn.functional.normalize
def initialize(self, initial_values, raw=False):
# NOTE: actual initialization is done in trainer
# raw stands for raw values, i.e. not passed through activation
self._xyz = nn.Parameter(initial_values["mean"].requires_grad_(True)).to('cuda')
self._rotation = nn.Parameter(initial_values["qvec"].requires_grad_(True)).to('cuda')
#self._scaling = nn.Parameter(initial_values["svec"].requires_grad_(True)).to('cuda')
#self._features_dc = nn.Parameter(initial_values["color"].requires_grad_(True)).to('cuda')
self._opacity = nn.Parameter(initial_values["alpha"].requires_grad_(True)).to('cuda')
def __init__(self, sh_degree : int,args = None):
self.active_sh_degree = 0
self.max_sh_degree = sh_degree
self._xyz = torch.empty(0)
self._features_dc = torch.empty(0)
self._features_rest = torch.empty(0)
self._scaling = torch.empty(0)
self._rotation = torch.empty(0)
self._opacity = torch.empty(0)
self.max_radii2D = torch.empty(0)
self.xyz_gradient_accum = torch.empty(0)
self.denom = torch.empty(0)
self.optimizer = None
self.percent_dense = 0
self.spatial_lr_scale = 0
self._deformation_table = torch.empty(0)
self._deformation = deform_network()
self.setup_functions()
def capture(self):
return (
self.active_sh_degree,
self._xyz,
self._deformation.state_dict(),
self._deformation_table,
self._features_dc,
self._features_rest,
self._scaling,
self._rotation,
self._opacity,
self.max_radii2D,
self.xyz_gradient_accum,
self.denom,
self.optimizer.state_dict(),
self.spatial_lr_scale,
)
def restore(self, model_args, training_args):
(self.active_sh_degree,
self._xyz,
self._deformation_table,
self._deformation,
self._features_dc,
self._features_rest,
self._scaling,
self._rotation,
self._opacity,
self.max_radii2D,
xyz_gradient_accum,
denom,
opt_dict,
self.spatial_lr_scale) = model_args
self.training_setup(training_args)
self.xyz_gradient_accum = xyz_gradient_accum
self.denom = denom
self.optimizer.load_state_dict(opt_dict)
@property
def get_scaling(self):
return self.scaling_activation(self._scaling)
@property
def get_rotation(self):
return self.rotation_activation(self._rotation)
@property
def get_xyz(self):
return self._xyz
@property
def get_features(self):
features_dc = self._features_dc
features_rest = self._features_rest
return torch.cat((features_dc, features_rest), dim=1)
@property
def get_opacity(self):
return self.opacity_activation(self._opacity)
def get_covariance(self, scaling_modifier = 1):
return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation)
def oneupSHdegree(self):
if self.active_sh_degree < self.max_sh_degree:
self.active_sh_degree += 1
def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float = 1):
self.spatial_lr_scale = spatial_lr_scale
fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda()
fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda())
features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda()
features[:, :3, 0 ] = fused_color
features[:, 3:, 1:] = 0.0
print("Number of points at initialisation : ", fused_point_cloud.shape[0])
dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001)
scales = torch.log(torch.sqrt(dist2))[...,None].repeat(1, 3)
rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda")
rots[:, 0] = 1
opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda"))
self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))
self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True))
self._features_rest = nn.Parameter(features[:,:,1:].transpose(1, 2).contiguous().requires_grad_(True))
self._scaling = nn.Parameter(scales.requires_grad_(True))
self._rotation = nn.Parameter(rots.requires_grad_(True))
self._opacity = nn.Parameter(opacities.requires_grad_(True))
self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
self._deformation_table = torch.gt(torch.ones((self.get_xyz.shape[0]),device="cuda"),0)
#print(self._xyz.shape,self._rotation.shape)
self._deformation = self._deformation.to("cuda")
self.active_sh_degree = self.max_sh_degree
self._deformation_accum = torch.zeros((self.get_xyz.shape[0],3),device="cuda")
self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
def training_setup(self, training_args):
self.percent_dense = training_args.percent_dense
self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
self._deformation_accum = torch.zeros((self.get_xyz.shape[0],3),device="cuda")
l = [
{'params': [self._xyz], 'lr': training_args.position_lr_init * self.spatial_lr_scale, "name": "xyz"},
{'params': list(self._deformation.get_mlp_parameters()), 'lr': training_args.deformation_lr_init * self.spatial_lr_scale, "name": "deformation"},
{'params': list(self._deformation.get_grid_parameters()), 'lr': training_args.grid_lr_init * self.spatial_lr_scale, "name": "grid"},
{'params': [self._features_dc], 'lr': training_args.feature_lr, "name": "f_dc"},
{'params': [self._features_rest], 'lr': training_args.feature_lr / 20.0, "name": "f_rest"},
{'params': [self._opacity], 'lr': training_args.opacity_lr, "name": "opacity"},
{'params': [self._scaling], 'lr': training_args.scaling_lr, "name": "scaling"},
{'params': [self._rotation], 'lr': training_args.rotation_lr, "name": "rotation"}
]
self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15)
self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init*self.spatial_lr_scale,
lr_final=training_args.position_lr_final*self.spatial_lr_scale,
lr_delay_mult=training_args.position_lr_delay_mult,
max_steps=training_args.position_lr_max_steps)
self.deformation_scheduler_args = get_expon_lr_func(lr_init=training_args.deformation_lr_init*self.spatial_lr_scale,
lr_final=training_args.deformation_lr_final*self.spatial_lr_scale,
lr_delay_mult=training_args.deformation_lr_delay_mult,
max_steps=training_args.position_lr_max_steps)
self.grid_scheduler_args = get_expon_lr_func(lr_init=training_args.grid_lr_init*self.spatial_lr_scale,
lr_final=training_args.grid_lr_final*self.spatial_lr_scale,
lr_delay_mult=training_args.deformation_lr_delay_mult,
max_steps=training_args.position_lr_max_steps)
def update_learning_rate(self, iteration):
''' Learning rate scheduling per step '''
for param_group in self.optimizer.param_groups:
if param_group["name"] == "xyz":
lr = self.xyz_scheduler_args(iteration)
param_group['lr'] = lr
# return lr
if "grid" in param_group["name"]:
lr = self.grid_scheduler_args(iteration)
param_group['lr'] = lr
# return lr
elif param_group["name"] == "deformation":
lr = self.deformation_scheduler_args(iteration)
param_group['lr'] = lr
# return lr
def construct_list_of_attributes(self):
l = ['x', 'y', 'z', 'nx', 'ny', 'nz']
# All channels except the 3 DC
for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]):
l.append('f_dc_{}'.format(i))
for i in range(self._features_rest.shape[1]*self._features_rest.shape[2]):
l.append('f_rest_{}'.format(i))
l.append('opacity')
for i in range(self._scaling.shape[1]):
l.append('scale_{}'.format(i))
for i in range(self._rotation.shape[1]):
l.append('rot_{}'.format(i))
return l
def save_ply(self, path):
os.makedirs(os.path.dirname(path), exist_ok=True)
xyz = self._xyz.detach().cpu().numpy()
normals = np.zeros_like(xyz)
f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
opacities = self._opacity.detach().cpu().numpy()
scale = self._scaling.detach().cpu().numpy()
rotation = self._rotation.detach().cpu().numpy()
dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()]
elements = np.empty(xyz.shape[0], dtype=dtype_full)
attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1)
elements[:] = list(map(tuple, attributes))
el = PlyElement.describe(elements, 'vertex')
PlyData([el]).write(path)
def compute_deformation(self,time):
deform = self._deformation[:,:,:time].sum(dim=-1)
xyz = self._xyz + deform
return xyz
def load_model(self, path):
print("loading model from exists{}".format(path))
weight_dict = torch.load(os.path.join(path,"deformation.pth"),map_location="cuda")
self._deformation.load_state_dict(weight_dict)
self._deformation = self._deformation.to("cuda")
self._deformation_table = torch.gt(torch.ones((self.get_xyz.shape[0]),device="cuda"),0)
self._deformation_accum = torch.zeros((self.get_xyz.shape[0],3),device="cuda")
if os.path.exists(os.path.join(path, "deformation_table.pth")):
self._deformation_table = torch.load(os.path.join(path, "deformation_table.pth"),map_location="cuda")
if os.path.exists(os.path.join(path, "deformation_accum.pth")):
self._deformation_accum = torch.load(os.path.join(path, "deformation_accum.pth"),map_location="cuda")
self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
self._deformation_table = torch.gt(torch.ones((self.get_xyz.shape[0]),device="cuda"),0)
self._deformation = self._deformation.to("cuda")
def save_deformation(self, path):
torch.save(self._deformation.state_dict(),os.path.join(path, "deformation.pth"))
torch.save(self._deformation_table,os.path.join(path, "deformation_table.pth"))
torch.save(self._deformation_accum,os.path.join(path, "deformation_accum.pth"))
def load_ply(self, path):
plydata = PlyData.read(path)
xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
np.asarray(plydata.elements[0]["y"]),
np.asarray(plydata.elements[0]["z"])), axis=1)
opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]
print("Number of points at loading : ", xyz.shape[0])
features_dc = np.zeros((xyz.shape[0], 3, 1))
features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"])
features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"])
extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")]
assert len(extra_f_names)==3*(self.max_sh_degree + 1) ** 2 - 3
features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))
for idx, attr_name in enumerate(extra_f_names):
features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name])
# Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC)
features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1))
scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")]
scales = np.zeros((xyz.shape[0], len(scale_names)))
for idx, attr_name in enumerate(scale_names):
scales[:, idx] = np.asarray(plydata.elements[0][attr_name])
rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")]
rots = np.zeros((xyz.shape[0], len(rot_names)))
for idx, attr_name in enumerate(rot_names):
rots[:, idx] = np.asarray(plydata.elements[0][attr_name])
self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True))
self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True))
self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True))
self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True))
self._deformation_table = torch.gt(torch.ones((self.get_xyz.shape[0]),device="cuda"),0)
#print(self._xyz.shape,self._rotation.shape)
self._deformation = self._deformation.to("cuda")
self.active_sh_degree = self.max_sh_degree
self._deformation_accum = torch.zeros((self.get_xyz.shape[0],3),device="cuda")
self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
def replace_tensor_to_optimizer(self, tensor, name):
optimizable_tensors = {}
for group in self.optimizer.param_groups:
if group["name"] == name:
stored_state = self.optimizer.state.get(group['params'][0], None)
stored_state["exp_avg"] = torch.zeros_like(tensor)
stored_state["exp_avg_sq"] = torch.zeros_like(tensor)
del self.optimizer.state[group['params'][0]]
group["params"][0] = nn.Parameter(tensor.requires_grad_(True))
self.optimizer.state[group['params'][0]] = stored_state
optimizable_tensors[group["name"]] = group["params"][0]
return optimizable_tensors
def _prune_optimizer(self, mask):
optimizable_tensors = {}
for group in self.optimizer.param_groups:
if len(group["params"]) > 1:
continue
stored_state = self.optimizer.state.get(group['params'][0], None)
if stored_state is not None:
stored_state["exp_avg"] = stored_state["exp_avg"][mask]
stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask]
del self.optimizer.state[group['params'][0]]
group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True)))
self.optimizer.state[group['params'][0]] = stored_state
optimizable_tensors[group["name"]] = group["params"][0]
else:
group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True))
optimizable_tensors[group["name"]] = group["params"][0]
return optimizable_tensors
def prune_points(self, mask):
valid_points_mask = ~mask
optimizable_tensors = self._prune_optimizer(valid_points_mask)
self._xyz = optimizable_tensors["xyz"]
self._features_dc = optimizable_tensors["f_dc"]
self._features_rest = optimizable_tensors["f_rest"]
self._opacity = optimizable_tensors["opacity"]
self._scaling = optimizable_tensors["scaling"]
self._rotation = optimizable_tensors["rotation"]
self._deformation_accum = self._deformation_accum[valid_points_mask]
self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask]
self._deformation_table = self._deformation_table[valid_points_mask]
self.denom = self.denom[valid_points_mask]
self.max_radii2D = self.max_radii2D[valid_points_mask]
def cat_tensors_to_optimizer(self, tensors_dict):
optimizable_tensors = {}
for group in self.optimizer.param_groups:
if len(group["params"])>1:continue
assert len(group["params"]) == 1
extension_tensor = tensors_dict[group["name"]]
stored_state = self.optimizer.state.get(group['params'][0], None)
if stored_state is not None:
stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0)
stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0)
del self.optimizer.state[group['params'][0]]
group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
self.optimizer.state[group['params'][0]] = stored_state
optimizable_tensors[group["name"]] = group["params"][0]
else:
group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
optimizable_tensors[group["name"]] = group["params"][0]
return optimizable_tensors
def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation, new_deformation_table):
d = {"xyz": new_xyz,
"f_dc": new_features_dc,
"f_rest": new_features_rest,
"opacity": new_opacities,
"scaling" : new_scaling,
"rotation" : new_rotation,
# "deformation": new_deformation
}
optimizable_tensors = self.cat_tensors_to_optimizer(d)
self._xyz = optimizable_tensors["xyz"]
self._features_dc = optimizable_tensors["f_dc"]
self._features_rest = optimizable_tensors["f_rest"]
self._opacity = optimizable_tensors["opacity"]
self._scaling = optimizable_tensors["scaling"]
self._rotation = optimizable_tensors["rotation"]
# self._deformation = optimizable_tensors["deformation"]
self._deformation_table = torch.cat([self._deformation_table,new_deformation_table],-1)
self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
self._deformation_accum = torch.zeros((self.get_xyz.shape[0], 3), device="cuda")
self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
def densify_and_split(self, grads, grad_threshold, scene_extent, N=2):
n_init_points = self.get_xyz.shape[0]
# Extract points that satisfy the gradient condition
padded_grad = torch.zeros((n_init_points), device="cuda")
padded_grad[:grads.shape[0]] = grads.squeeze()
selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False)
selected_pts_mask = torch.logical_and(selected_pts_mask,
torch.max(self.get_scaling, dim=1).values > self.percent_dense*scene_extent)
if not selected_pts_mask.any():
return
stds = self.get_scaling[selected_pts_mask].repeat(N,1)
means =torch.zeros((stds.size(0), 3),device="cuda")
samples = torch.normal(mean=means, std=stds)
rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N,1,1)
new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[selected_pts_mask].repeat(N, 1)
new_scaling = self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(N,1) / (0.8*N))
new_rotation = self._rotation[selected_pts_mask].repeat(N,1)
new_features_dc = self._features_dc[selected_pts_mask].repeat(N,1,1)
new_features_rest = self._features_rest[selected_pts_mask].repeat(N,1,1)
new_opacity = self._opacity[selected_pts_mask].repeat(N,1)
new_deformation_table = self._deformation_table[selected_pts_mask].repeat(N)
self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacity, new_scaling, new_rotation, new_deformation_table)
prune_filter = torch.cat((selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool)))
self.prune_points(prune_filter)
def densify_and_clone(self, grads, grad_threshold, scene_extent):
selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False)
selected_pts_mask = torch.logical_and(selected_pts_mask,
torch.max(self.get_scaling, dim=1).values <= self.percent_dense*scene_extent)
new_xyz = self._xyz[selected_pts_mask]
# - 0.001 * self._xyz.grad[selected_pts_mask]
new_features_dc = self._features_dc[selected_pts_mask]
new_features_rest = self._features_rest[selected_pts_mask]
new_opacities = self._opacity[selected_pts_mask]
new_scaling = self._scaling[selected_pts_mask]
new_rotation = self._rotation[selected_pts_mask]
new_deformation_table = self._deformation_table[selected_pts_mask]
self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation, new_deformation_table)
def densify_and_prune(self, max_grad_percent, min_opacity, extent, max_screen_size):
grads = self.xyz_gradient_accum / self.denom
grads[grads.isnan()] = 0.0
grad_log = torch.log(grads)
grad_log2=grad_log[~grad_log.isnan()]
grad_log3=grad_log[~grad_log2.isinf()]
max_grad_1 = torch.exp(grad_log3.mean()+grad_log3.var()) #adaptive densification with mean and var, unused
max_grad_2 = torch.exp(grad_log3.squeeze(dim=1).sort(descending=True)[0][int(max_grad_percent*grad_log3.shape[0])]) #adaptive densification with relative grad
max_grad = max_grad_2 #choose which to use
#print('max_grad',max_grad_percent,max_grad_1,max_grad_2,grad_log3.mean(),grad_log3.var())
self.densify_and_clone(grads, max_grad, extent)
self.densify_and_split(grads, max_grad, extent)
prune_mask = (self.get_opacity < min_opacity).squeeze()
if max_screen_size:
big_points_vs = self.max_radii2D > max_screen_size
big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent
small_ws = self.get_scaling.max(dim=1).values<0.001
prune_mask = torch.logical_or(torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws),small_ws)
self.prune_points(prune_mask)
torch.cuda.empty_cache()
def prune(self, min_opacity=0.01, extent=1, max_screen_size=1):
prune_mask = (self.get_opacity < min_opacity).squeeze()
# prune_mask_2 = torch.logical_and(self.get_opacity <= inverse_sigmoid(0.101 , dtype=torch.float, device="cuda"), self.get_opacity >= inverse_sigmoid(0.999 , dtype=torch.float, device="cuda"))
# prune_mask = torch.logical_or(prune_mask, prune_mask_2)
# deformation_sum = abs(self._deformation).sum(dim=-1).mean(dim=-1)
# deformation_mask = (deformation_sum < torch.quantile(deformation_sum, torch.tensor([0.5]).to("cuda")))
# prune_mask = prune_mask & deformation_mask
if max_screen_size:
big_points_vs = self.max_radii2D > max_screen_size
big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent
#prune_mask = torch.logical_or(prune_mask, big_points_vs)
small_ws = self.get_scaling.min(dim=1).values<0.001
prune_mask = torch.logical_or(torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws),small_ws)
self.prune_points(prune_mask)
def standard_constaint(self):
means3D = self._xyz.detach()
scales = self._scaling.detach()
rotations = self._rotation.detach()
opacity = self._opacity.detach()
time = torch.tensor(0).to("cuda").repeat(means3D.shape[0],1)
means3D_deform, scales_deform, rotations_deform, _ = self._deformation(means3D, scales, rotations, opacity, time)
position_error = (means3D_deform - means3D)**2
rotation_error = (rotations_deform - rotations)**2
scaling_erorr = (scales_deform - scales)**2
return position_error.mean() + rotation_error.mean() + scaling_erorr.mean()
def add_densification_stats(self, viewspace_point_tensor, update_filter):
#print(viewspace_point_tensor,update_filter)
self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor[update_filter,:2], dim=-1, keepdim=True)
self.denom[update_filter] += 1
@torch.no_grad()
def update_deformation_table(self,threshold):
# print("origin deformation point nums:",self._deformation_table.sum())
self._deformation_table = torch.gt(self._deformation_accum.max(dim=-1).values/100,threshold)
def print_deformation_weight_grad(self):
for name, weight in self._deformation.named_parameters():
if weight.requires_grad:
if weight.grad is None:
print(name," :",weight.grad)
else:
if weight.grad.mean() != 0:
print(name," :",weight.grad.mean(), weight.grad.min(), weight.grad.max())
print("-"*50)
def _plane_regulation(self):
multi_res_grids = self._deformation.deformation_net.grid.grids
total = 0
# model.grids is 6 x [1, rank * F_dim, reso, reso]
for grids in multi_res_grids:
if len(grids) == 3:
time_grids = []
else:
time_grids = [0,1,3]
for grid_id in time_grids:
total += compute_plane_smoothness(grids[grid_id])
return total
def _time_regulation(self):
multi_res_grids = self._deformation.deformation_net.grid.grids
total = 0
# model.grids is 6 x [1, rank * F_dim, reso, reso]
for grids in multi_res_grids:
if len(grids) == 3:
time_grids = []
else:
time_grids =[2, 4, 5]
for grid_id in time_grids:
total += compute_plane_smoothness(grids[grid_id])
return total
def _l1_regulation(self):
# model.grids is 6 x [1, rank * F_dim, reso, reso]
multi_res_grids = self._deformation.deformation_net.grid.grids
total = 0.0
for grids in multi_res_grids:
if len(grids) == 3:
continue
else:
# These are the spatiotemporal grids
spatiotemporal_grids = [2, 4, 5]
for grid_id in spatiotemporal_grids:
total += torch.abs(1 - grids[grid_id]).mean()
return total
def compute_regulation(self, time_smoothness_weight, l1_time_planes_weight, plane_tv_weight):
return plane_tv_weight * self._plane_regulation() + time_smoothness_weight * self._time_regulation() + l1_time_planes_weight * self._l1_regulation()
def getProjectionMatrix(znear, zfar, fovX, fovY):
tanHalfFovY = math.tan((fovY / 2))
tanHalfFovX = math.tan((fovX / 2))
P = torch.zeros(4, 4)
z_sign = 1.0
P[0, 0] = 1 / tanHalfFovX
P[1, 1] = 1 / tanHalfFovY
P[3, 2] = z_sign
P[2, 2] = z_sign * zfar / (zfar - znear)
P[2, 3] = -(zfar * znear) / (zfar - znear)
return P
class MiniCam:
def __init__(self, c2w, width, height, fovy, fovx, znear, zfar,time=0 ):
# c2w (pose) should be in NeRF convention.
self.image_width = width
self.image_height = height
self.FoVy = fovy
self.FoVx = fovx
self.znear = znear
self.zfar = zfar
self.time = time
w2c = np.linalg.inv(c2w)
# rectify...
w2c[1:3, :3] *= -1
w2c[:3, 3] *= -1
self.world_view_transform = torch.tensor(w2c).transpose(0, 1).cuda()
self.projection_matrix = (
getProjectionMatrix(
znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy
)
.transpose(0, 1)
.cuda()
)
self.full_proj_transform = self.world_view_transform @ self.projection_matrix
self.camera_center = -torch.tensor(c2w[:3, 3]).cuda()
class Renderer:
def __init__(self, sh_degree=3, white_background=True, radius=1):
self.sh_degree = sh_degree
self.white_background = white_background
self.radius = radius
self.gaussians = GaussianModel(sh_degree)
self.bg_color = torch.tensor(
[1, 1, 1] if white_background else [0, 0, 0],
dtype=torch.float32,
device="cuda",
)
def initialize(self, input=None, num_pts=5000, radius=0.5,initial_values=None):
# load checkpoint
if input is None:
# init from random point cloud
phis = np.random.random((num_pts,)) * 2 * np.pi
costheta = np.random.random((num_pts,)) * 2 - 1
thetas = np.arccos(costheta)
mu = np.random.random((num_pts,))
radius = radius * np.cbrt(mu)
x = radius * np.sin(thetas) * np.cos(phis)
y = radius * np.sin(thetas) * np.sin(phis)
z = radius * np.cos(thetas)
xyz = np.stack((x, y, z), axis=1)
if initial_values is not None:
print(xyz.shape,initial_values["mean"].shape)
R = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]])
xyz = np.dot(initial_values["mean"].numpy(),R)
# xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3
shs = np.random.random((num_pts, 3)) / 255.0
pcd = BasicPointCloud(
points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3))
)
self.gaussians.create_from_pcd(pcd, 10)
elif isinstance(input, BasicPointCloud):
# load from a provided pcd
self.gaussians.create_from_pcd(input, 1)
else:
# load from saved ply
self.gaussians.load_ply(input)
def render(
self,
viewpoint_camera,
scaling_modifier=1.0,
bg_color=None,
override_color=None,
compute_cov3D_python=False,
convert_SHs_python=False,
stage="fine",
time_int = None,
front_view=False,
):
# Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
screenspace_points = torch.zeros_like(self.gaussians.get_xyz, dtype=self.gaussians.get_xyz.dtype, requires_grad=True, device="cuda") + 0
try:
screenspace_points.retain_grad()
except:
pass
# Set up rasterization configuration
tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)
raster_settings = GaussianRasterizationSettings(
image_height=int(viewpoint_camera.image_height),
image_width=int(viewpoint_camera.image_width),
tanfovx=tanfovx,
tanfovy=tanfovy,
bg=self.bg_color if bg_color is None else bg_color,
scale_modifier=scaling_modifier,
viewmatrix=viewpoint_camera.world_view_transform,
projmatrix=viewpoint_camera.full_proj_transform,
sh_degree=self.gaussians.active_sh_degree,
campos=viewpoint_camera.camera_center,
prefiltered=False,
debug=False,
)
if front_view==True:
print(viewpoint_camera.world_view_transform,viewpoint_camera.full_proj_transform)
rasterizer = GaussianRasterizer(raster_settings=raster_settings)
means3D = self.gaussians.get_xyz
time = torch.tensor(viewpoint_camera.time).to(means3D.device).repeat(means3D.shape[0],1)
means2D = screenspace_points
opacity = self.gaussians._opacity
# If precomputed 3d covariance is provided, use it. If not, then it will be computed from
# scaling / rotation by the rasterizer.
scales = None
rotations = None
cov3D_precomp = None
if compute_cov3D_python:
cov3D_precomp = self.gaussians.get_covariance(scaling_modifier)
else:
scales = self.gaussians._scaling
rotations = self.gaussians._rotation
deformation_point = self.gaussians._deformation_table
if stage == "coarse" :
means3D_deform, scales_deform, rotations_deform, opacity_deform = means3D, scales, rotations, opacity
else:
means3D_deform, scales_deform, rotations_deform, opacity_deform = self.gaussians._deformation(means3D[deformation_point], scales[deformation_point],
rotations[deformation_point], opacity[deformation_point],
time[deformation_point])
# print(time.max())
with torch.no_grad():
self.gaussians._deformation_accum[deformation_point] += torch.abs(means3D_deform-means3D[deformation_point])
#print(torch.abs(means3D_deform-means3D[deformation_point]).mean())
means3D_final = torch.zeros_like(means3D)
rotations_final = torch.zeros_like(rotations)
scales_final = torch.zeros_like(scales)
opacity_final = torch.zeros_like(opacity)
means3D_final[deformation_point] = means3D_deform
rotations_final[deformation_point] = rotations_deform
scales_final[deformation_point] = scales_deform
opacity_final[deformation_point] = opacity_deform
means3D_final[~deformation_point] = means3D[~deformation_point]
rotations_final[~deformation_point] = rotations[~deformation_point]
scales_final[~deformation_point] = scales[~deformation_point]
opacity_final[~deformation_point] = opacity[~deformation_point]
scales_in=scales_final
rotations_in=rotations_final
opacity_in = opacity_final
scales_final = self.gaussians.scaling_activation(scales_final)
rotations_final = self.gaussians.rotation_activation(rotations_final)
opacity = self.gaussians.opacity_activation(opacity)
opacity_final = self.gaussians.opacity_activation(opacity_final)
# If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
# from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
shs = None
colors_precomp = None
shs = self.gaussians.get_features
# Rasterize visible Gaussians to image, obtain their radii (on screen).
rendered_image, rendered_depth, normal, rendered_alpha ,radii, _ = rasterizer(
means3D = means3D_final,
means2D = means2D,
shs = shs,
colors_precomp = colors_precomp,
opacities = opacity_final,
scales = scales_final,
rotations = rotations_final,
cov3Ds_precomp = cov3D_precomp)
return {
"image": rendered_image,
"depth": rendered_depth,
"alpha": rendered_alpha,
"viewspace_points": screenspace_points,
"visibility_filter": radii > 0,
"radii": radii,
'xyz':means3D_final,
'rot':rotations_in,
'xy':means2D,
'color':shs,
'scales':scales_in,
'opacity':opacity_in,
}
================================================
FILE: guidance/zero123_4d_utils.py
================================================
from transformers import CLIPTextModel, CLIPTokenizer, logging
from diffusers import (
AutoencoderKL,
UNet2DConditionModel,
DDIMScheduler,
StableDiffusionPipeline,
)
import torchvision.transforms.functional as TF
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
sys.path.append('./')
from zero123 import Zero123Pipeline
class Zero123(nn.Module):
def __init__(self, device, fp16=True, t_range=[0.02, 0.98]):
super().__init__()
self.device = device
self.fp16 = fp16
self.dtype = torch.float16 if fp16 else torch.float32
self.pipe = Zero123Pipeline.from_pretrained(
# "bennyguo/zero123-diffusers",
"ashawkey/zero123-xl-diffusers",
# './model_cache/zero123_xl',
variant="fp16" if self.fp16 else None,
torch_dtype=self.dtype,
).to(self.device)
# for param in self.pipe.parameters():
# param.requires_grad = False
self.pipe.image_encoder.eval()
self.pipe.vae.eval()
self.pipe.unet.eval()
self.pipe.clip_camera_projection.eval()
self.vae = self.pipe.vae
self.unet = self.pipe.unet
self.pipe.set_progress_bar_config(disable=True)
self.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
self.num_train_timesteps = self.scheduler.config.num_train_timesteps
self.min_step = int(self.num_train_timesteps * t_range[0])
self.max_step = int(self.num_train_timesteps * t_range[1])
self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience
self.min_step_percent = [0, 0.5, 0.02, 3000]
self.max_step_percent= [0, 0.95, 0.5, 3000]
self.embeddings = None
self.embedding_list = []
@torch.no_grad()
def get_img_embeds(self, x, input_imgs=None):
# x: image tensor in [0, 1]
x = F.interpolate(x, (256, 256), mode='bilinear', align_corners=False)
x_pil = [TF.to_pil_image(image) for image in x]
x_clip = self.pipe.feature_extractor(images=x_pil, return_tensors="pt").pixel_values.to(device=self.device, dtype=self.dtype)
c = self.pipe.image_encoder(x_clip).image_embeds
v = self.encode_imgs(x.to(self.dtype)) / self.vae.config.scaling_factor
self.embeddings = [c, v]
self.additional_embeddings=[]
if input_imgs!=None:
for x in input_imgs:
x = F.interpolate(x, (256, 256), mode='bilinear', align_corners=False)
x_pil = [TF.to_pil_image(image) for image in x]
x_clip = self.pipe.feature_extractor(images=x_pil, return_tensors="pt").pixel_values.to(device=self.device, dtype=self.dtype)
c = self.pipe.image_encoder(x_clip).image_embeds
v = self.encode_imgs(x.to(self.dtype)) / self.vae.config.scaling_factor
embeddings = [c, v]
self.additional_embeddings.append(embeddings)
self.embedding_list.append([self.embeddings,self.additional_embeddings])
@torch.no_grad()
def refine(self, pred_rgb, polar, azimuth, radius,
guidance_scale=5, steps=50, strength=0.8,
):
batch_size = pred_rgb.shape[0]
self.scheduler.set_timesteps(steps)
if strength == 0:
init_step = 0
latents = torch.randn((1, 4, 32, 32), device=self.device, dtype=self.dtype)
else:
init_step = int(steps * strength)
pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode='bilinear', align_corners=False)
latents = self.encode_imgs(pred_rgb_256.to(self.dtype))
latents = self.scheduler.add_noise(latents, torch.randn_like(latents), self.scheduler.timesteps[init_step])
T = np.stack([np.deg2rad(polar), np.sin(np.deg2rad(azimuth)), np.cos(np.deg2rad(azimuth)), radius], axis=-1)
T = torch.from_numpy(T).unsqueeze(1).to(self.dtype).to(self.device) # [8, 1, 4]
cc_emb = torch.cat([self.embeddings[0].repeat(batch_size, 1, 1), T], dim=-1)
cc_emb = self.pipe.clip_camera_projection(cc_emb)
cc_emb = torch.cat([cc_emb, torch.zeros_like(cc_emb)], dim=0)
vae_emb = self.embeddings[1].repeat(batch_size, 1, 1, 1)
vae_emb = torch.cat([vae_emb, torch.zeros_like(vae_emb)], dim=0)
for i, t in enumerate(self.scheduler.timesteps[init_step:]):
x_in = torch.cat([latents] * 2)
t_in = torch.cat([t.view(1)] * 2).to(self.device)
noise_pred = self.unet(
torch.cat([x_in, vae_emb], dim=1),
t_in.to(self.unet.dtype),
encoder_hidden_states=cc_emb,
).sample
noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
latents = self.scheduler.step(noise_pred, t, latents).prev_sample
imgs = self.decode_latents(latents) # [1, 3, 256, 256]
return imgs
def train_step(self, pred_rgb, polar, azimuth, radius, step_ratio=None, guidance_scale=5, as_latent=False,idx=None,t=0):
# pred_rgb: tensor [1, 3, H, W] in [0, 1]
#print(polar)
step_ratio = max(0.4,step_ratio)
self.embeddings,self.additional_embeddings = self.embedding_list[t]
batch_size = pred_rgb.shape[0]
#print(self.embedding_list[1][0][0] -self.embedding_list[2][0][0])
#print(self.embedding_list[1][0][1] -self.embedding_list[2][0][1])
if idx is not None:
embeddings = self.additional_embeddings[idx]
else:
embeddings = self.embeddings
if as_latent:
latents = F.interpolate(pred_rgb, (32, 32), mode='bilinear', align_corners=False) * 2 - 1
else:
pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode='bilinear', align_corners=False)
latents = self.encode_imgs(pred_rgb_256.to(self.dtype))
t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device)
w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1)
with torch.no_grad():
noise = torch.randn_like(latents)
latents_noisy = self.scheduler.add_noise(latents, noise, t)
x_in = torch.cat([latents_noisy] * 2)
t_in = torch.cat([t] * 2)
T = np.stack([np.deg2rad(polar), np.sin(np.deg2rad(azimuth)), np.cos(np.deg2rad(azimuth)), radius], axis=-1)
T = torch.from_numpy(T).unsqueeze(1).to(self.dtype).to(self.device) # [8, 1, 4]
#print(self.embeddings[0].repeat(batch_size, 1, 1).shape,T.shape)
cc_emb = torch.cat([embeddings[0].repeat(batch_size, 1, 1), T], dim=-1)
cc_emb = self.pipe.clip_camera_projection(cc_emb)
cc_emb = torch.cat([cc_emb, torch.zeros_like(cc_emb)], dim=0)
vae_emb = embeddings[1].repeat(batch_size, 1, 1, 1)
vae_emb = torch.cat([vae_emb, torch.zeros_like(vae_emb)], dim=0)
noise_pred = self.unet(
torch.cat([x_in, vae_emb], dim=1),
t_in.to(self.unet.dtype),
encoder_hidden_states=cc_emb,
).sample
noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
grad = w * (noise_pred - noise)
grad = torch.nan_to_num(grad)
target = (latents - grad).detach()
loss = 0.5 * F.mse_loss(latents.float(), target, reduction='sum')
return loss
def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
min_step_percent = self.get_steps(self.min_step_percent, epoch, global_step)
max_step_percent = self.get_steps(self.max_step_percent, epoch, global_step)
self.min_step = int( self.num_train_timesteps * min_step_percent )
self.max_step = int( self.num_train_timesteps * max_step_percent )
def get_steps(self,percent,epoch, global_step):
start_step, start_value, end_value, end_step = percent
current_step = global_step
value = start_value + (end_value - start_value) * max(
min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0
)
return value
def decode_latents(self, latents):
latents = 1 / self.vae.config.scaling_factor * latents
imgs = self.vae.decode(latents).sample
imgs = (imgs / 2 + 0.5).clamp(0, 1)
return imgs
def encode_imgs(self, imgs, mode=False):
# imgs: [B, 3, H, W]
imgs = 2 * imgs - 1
posterior = self.vae.encode(imgs).latent_dist
if mode:
latents = posterior.mode()
else:
latents = posterior.sample()
latents = latents * self.vae.config.scaling_factor
return latents
if __name__ == '__main__':
import cv2
import argparse
import numpy as np
import matplotlib.pyplot as plt
parser = argparse.ArgumentParser()
parser.add_argument('input', type=str)
parser.add_argument('--polar', type=float, default=0, help='delta polar angle in [-90, 90]')
parser.add_argument('--azimuth', type=float, default=0, help='delta azimuth angle in [-180, 180]')
parser.add_argument('--radius', type=float, default=0, help='delta camera radius multiplier in [-0.5, 0.5]')
opt = parser.parse_args()
device = torch.device('cuda')
print(f'[INFO] loading image from {opt.input} ...')
image = cv2.imread(opt.input, cv2.IMREAD_UNCHANGED)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = cv2.resize(image, (256, 256), interpolation=cv2.INTER_AREA)
image = image.astype(np.float32) / 255.0
image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).contiguous().to(device)
print(f'[INFO] loading model ...')
zero123 = Zero123(device)
print(f'[INFO] running model ...')
zero123.get_img_embeds(image)
while True:
outputs = zero123.refine(image, polar=[opt.polar], azimuth=[opt.azimuth], radius=[opt.radius], strength=0)
plt.imshow(outputs.float().cpu().numpy().transpose(0, 2, 3, 1)[0])
plt.show()
================================================
FILE: guidance/zero123pp/pipeline.py
================================================
from typing import Any, Dict, Optional
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.schedulers import KarrasDiffusionSchedulers
import numpy
import torch
import torch.nn as nn
import torch.utils.checkpoint
import torch.distributed
import transformers
from collections import OrderedDict
from PIL import Image
from torchvision import transforms
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
import torch
import torch.nn.functional as F
from torch import nn
import diffusers
from diffusers import (
AutoencoderKL,
DDPMScheduler,
DiffusionPipeline,
EulerAncestralDiscreteScheduler,
UNet2DConditionModel,
ImagePipelineOutput
)
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.attention_processor import Attention, AttnProcessor, XFormersAttnProcessor
from diffusers.utils.import_utils import is_xformers_available
import os
FIRST = True
IDX = 0
PATH = '/home/vision/github/embeddings/'
EMBED=[]
def to_rgb_image(maybe_rgba: Image.Image):
if maybe_rgba.mode == 'RGB':
return maybe_rgba
elif maybe_rgba.mode == 'RGBA':
rgba = maybe_rgba
img = numpy.random.randint(127, 128, size=[rgba.size[1], rgba.size[0], 3], dtype=numpy.uint8)
img = Image.fromarray(img, 'RGB')
img.paste(rgba, mask=rgba.getchannel('A'))
return img
else:
raise ValueError("Unsupported image type.", maybe_rgba.mode)
class MyAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("MyAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
is_self_attention=False
):
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
if is_self_attention:
global FIRST
global IDX
global EMBED
if FIRST == True:
EMBED.append(encoder_hidden_states.to('cpu'))
#print('saving to {})'.format(PATH,str(IDX)+'_hidden.pt'))
#os.makedirs(PATH,exist_ok=True)
#torch.save(encoder_hidden_states,os.path.join(PATH,str(IDX)+'_hidden.pt'))
IDX=IDX+1
else:
last_shape = encoder_hidden_states.shape[-1]
replace_dim = int(9600/(last_shape//320)**2)
encoder_hidden_states_load = EMBED[IDX].to('cuda')
encoder_hidden_states[:,:replace_dim,:]=(encoder_hidden_states_load[:,:replace_dim,:]+encoder_hidden_states[:,:replace_dim,:])/2
IDX=IDX+1
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
#print(key.shape)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class ReferenceOnlyAttnProc(torch.nn.Module):
def __init__(
self,
chained_proc,
enabled=False,
name=None
) -> None:
super().__init__()
self.enabled = enabled
self.chained_proc = chained_proc
self.name = name
def __call__(
self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None,
mode="w", ref_dict: dict = None, is_cfg_guidance = False
) -> Any:
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
is_self_attention = False
if self.enabled and is_cfg_guidance:
res0 = self.chained_proc(attn, hidden_states[:1], encoder_hidden_states[:1], attention_mask)
hidden_states = hidden_states[1:]
encoder_hidden_states = encoder_hidden_states[1:]
if self.enabled:
if mode == 'w':
ref_dict[self.name] = encoder_hidden_states
elif mode == 'r':
encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict.pop(self.name)], dim=1)
is_self_attention = True
elif mode == 'm':
encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict[self.name]], dim=1)
else:
assert False, mode
res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask,is_self_attention=is_self_attention)
if self.enabled and is_cfg_guidance:
res = torch.cat([res0, res])
return res
class RefOnlyNoisedUNet(torch.nn.Module):
def __init__(self, unet: UNet2DConditionModel, train_sched: DDPMScheduler, val_sched: EulerAncestralDiscreteScheduler) -> None:
super().__init__()
self.unet = unet
self.train_sched = train_sched
self.val_sched = val_sched
unet_lora_attn_procs = dict()
for name, _ in unet.attn_processors.items():
if torch.__version__ >= '2.0':
default_attn_proc = MyAttnProcessor2_0()
print('using my attention')
elif is_xformers_available():
default_attn_proc = XFormersAttnProcessor()
else:
default_attn_proc = AttnProcessor()
unet_lora_attn_procs[name] = ReferenceOnlyAttnProc(
default_attn_proc, enabled=name.endswith("attn1.processor"), name=name
)
unet.set_attn_processor(unet_lora_attn_procs)
def __getattr__(self, name: str):
try:
return super().__getattr__(name)
except AttributeError:
return getattr(self.unet, name)
def forward_cond(self, noisy_cond_lat, timestep, encoder_hidden_states, class_labels, ref_dict, is_cfg_guidance, **kwargs):
if is_cfg_guidance:
encoder_hidden_states = encoder_hidden_states[1:]
class_labels = class_labels[1:]
self.unet(
noisy_cond_lat, timestep,
encoder_hidden_states=encoder_hidden_states,
class_labels=class_labels,
cross_attention_kwargs=dict(mode="w", ref_dict=ref_dict),
**kwargs
)
def forward(
self, sample, timestep, encoder_hidden_states, class_labels=None,
*args, cross_attention_kwargs,
down_block_res_samples=None, mid_block_res_sample=None,
**kwargs
):
cond_lat = cross_attention_kwargs['cond_lat']
is_cfg_guidance = cross_attention_kwargs.get('is_cfg_guidance', False)
noise = torch.randn_like(cond_lat)
if self.training:
noisy_cond_lat = self.train_sched.add_noise(cond_lat, noise, timestep)
noisy_cond_lat = self.train_sched.scale_model_input(noisy_cond_lat, timestep)
else:
noisy_cond_lat = self.val_sched.add_noise(cond_lat, noise, timestep.reshape(-1))
noisy_cond_lat = self.val_sched.scale_model_input(noisy_cond_lat, timestep.reshape(-1))
#if cross_attention_kwargs['cond_lat_back'] is not None:
# cond_lat_back = cross_attention_kwargs['cond_lat_back']
# noisy_cond_lat_back = self.val_sched.add_noise(cond_lat_back, noise, timestep.reshape(-1))
# noisy_cond_lat_back = self.val_sched.scale_model_input(noisy_cond_lat_back, timestep.reshape(-1))
ref_dict = {}
self.forward_cond(
noisy_cond_lat, timestep,
encoder_hidden_states, class_labels,
ref_dict, is_cfg_guidance, **kwargs
)
weight_dtype = self.unet.dtype
return self.unet(
sample, timestep,
encoder_hidden_states, *args,
class_labels=class_labels,
cross_attention_kwargs=dict(mode="r", ref_dict=ref_dict, is_cfg_guidance=is_cfg_guidance),
down_block_additional_residuals=[
sample.to(dtype=weight_dtype) for sample in down_block_res_samples
] if down_block_res_samples is not None else None,
mid_block_additional_residual=(
mid_block_res_sample.to(dtype=weight_dtype)
if mid_block_res_sample is not None else None
),
**kwargs
)
def scale_latents(latents):
latents = (latents - 0.22) * 0.75
return latents
def unscale_latents(latents):
latents = latents / 0.75 + 0.22
return latents
def scale_image(image):
image = image * 0.5 / 0.8
return image
def unscale_image(image):
image = image / 0.5 * 0.8
return image
class DepthControlUNet(torch.nn.Module):
def __init__(self, unet: RefOnlyNoisedUNet, controlnet: Optional[diffusers.ControlNetModel] = None, conditioning_scale=1.0) -> None:
super().__init__()
self.unet = unet
if controlnet is None:
self.controlnet = diffusers.ControlNetModel.from_unet(unet.unet)
else:
self.controlnet = controlnet
DefaultAttnProc = MyAttnProcessor2_0
if is_xformers_available():
DefaultAttnProc = XFormersAttnProcessor
self.controlnet.set_attn_processor(DefaultAttnProc())
self.conditioning_scale = conditioning_scale
def __getattr__(self, name: str):
try:
return super().__getattr__(name)
except AttributeError:
return getattr(self.unet, name)
def forward(self, sample, timestep, encoder_hidden_states, class_labels=None, *args, cross_attention_kwargs: dict, **kwargs):
cross_attention_kwargs = dict(cross_attention_kwargs)
control_depth = cross_attention_kwargs.pop('control_depth')
down_block_res_samples, mid_block_res_sample = self.controlnet(
sample,
timestep,
encoder_hidden_states=encoder_hidden_states,
controlnet_cond=control_depth,
conditioning_scale=self.conditioning_scale,
return_dict=False,
)
return self.unet(
sample,
timestep,
encoder_hidden_states=encoder_hidden_states,
down_block_res_samples=down_block_res_samples,
mid_block_res_sample=mid_block_res_sample,
cross_attention_kwargs=cross_attention_kwargs
)
class ModuleListDict(torch.nn.Module):
def __init__(self, procs: dict) -> None:
super().__init__()
self.keys = sorted(procs.keys())
self.values = torch.nn.ModuleList(procs[k] for k in self.keys)
def __getitem__(self, key):
return self.values[self.keys.index(key)]
class SuperNet(torch.nn.Module):
def __init__(self, state_dict: Dict[str, torch.Tensor]):
super().__init__()
state_dict = OrderedDict((k, state_dict[k]) for k in sorted(state_dict.keys()))
self.layers = torch.nn.ModuleList(state_dict.values())
self.mapping = dict(enumerate(state_dict.keys()))
self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}
# .processor for unet, .self_attn for text encoder
self.split_keys = [".processor", ".self_attn"]
# we add a hook to state_dict() and load_state_dict() so that the
# naming fits with `unet.attn_processors`
def map_to(module, state_dict, *args, **kwargs):
new_state_dict = {}
for key, value in state_dict.items():
num = int(key.split(".")[1]) # 0 is always "layers"
new_key = key.replace(f"layers.{num}", module.mapping[num])
new_state_dict[new_key] = value
return new_state_dict
def remap_key(key, state_dict):
for k in self.split_keys:
if k in key:
return key.split(k)[0] + k
return key.split('.')[0]
def map_from(module, state_dict, *args, **kwargs):
all_keys = list(state_dict.keys())
for key in all_keys:
replace_key = remap_key(key, state_dict)
new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}")
state_dict[new_key] = state_dict[key]
del state_dict[key]
self._register_state_dict_hook(map_to)
self._register_load_state_dict_pre_hook(map_from, with_module=True)
class Zero123PlusPipeline(diffusers.StableDiffusionPipeline):
tokenizer: transformers.CLIPTokenizer
text_encoder: transformers.CLIPTextModel
vision_encoder: transformers.CLIPVisionModelWithProjection
feature_extractor_clip: transformers.CLIPImageProcessor
unet: UNet2DConditionModel
scheduler: diffusers.schedulers.KarrasDiffusionSchedulers
vae: AutoencoderKL
ramping: nn.Linear
feature_extractor_vae: transformers.CLIPImageProcessor
depth_transforms_multi = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])
])
def __init__(
self,
vae: AutoencoderKL,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers,
vision_encoder: transformers.CLIPVisionModelWithProjection,
feature_extractor_clip: CLIPImageProcessor,
feature_extractor_vae: CLIPImageProcessor,
ramping_coefficients: Optional[list] = None,
safety_checker=None,
):
DiffusionPipeline.__init__(self)
self.register_modules(
vae=vae, text_encoder=text_encoder, tokenizer=tokenizer,
unet=unet, scheduler=scheduler, safety_checker=None,
vision_encoder=vision_encoder,
feature_extractor_clip=feature_extractor_clip,
feature_extractor_vae=feature_extractor_vae
)
self.register_to_config(ramping_coefficients=ramping_coefficients)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
def prepare(self):
train_sched = DDPMScheduler.from_config(self.scheduler.config)
if isinstance(self.unet, UNet2DConditionModel):
self.unet = RefOnlyNoisedUNet(self.unet, train_sched, self.scheduler).eval()
def add_controlnet(self, controlnet: Optional[diffusers.ControlNetModel] = None, conditioning_scale=1.0):
self.prepare()
self.unet = DepthControlUNet(self.unet, controlnet, conditioning_scale)
return SuperNet(OrderedDict([('controlnet', self.unet.controlnet)]))
def encode_condition_image(self, image: torch.Tensor):
image = self.vae.encode(image).latent_dist.sample()
return image
@torch.no_grad()
def __call__(
self,
image: Image.Image = None,
prompt = "",
*args,
num_images_per_prompt: Optional[int] = 1,
guidance_scale=4.0,
depth_image: Image.Image = None,
output_type: Optional[str] = "pil",
width=640,
height=960,
num_inference_steps=28,
return_dict=True,
is_first = False,
**kwargs
):
global FIRST
FIRST = is_first
global IDX
IDX = 0
if is_first:
global EMBED
EMBED=[]
# Create a generator with the specified seed
generator = torch.Generator(device='cuda')
generator.manual_seed(42)
self.prepare()
if image is None:
raise ValueError("Inputting embeddings not supported for this pipeline. Please pass an image.")
assert not isinstance(image, torch.Tensor)
image = to_rgb_image(image)
image_1 = self.feature_extractor_vae(images=image, return_tensors="pt").pixel_values
image_2 = self.feature_extractor_clip(images=image, return_tensors="pt").pixel_values
if depth_image is not None and hasattr(self.unet, "controlnet"):
depth_image = to_rgb_image(depth_image)
depth_image = self.depth_transforms_multi(depth_image).to(
device=self.unet.controlnet.device, dtype=self.unet.controlnet.dtype
)
image = image_1.to(device=self.vae.device, dtype=self.vae.dtype)
image_2 = image_2.to(device=self.vae.device, dtype=self.vae.dtype)
cond_lat = self.encode_condition_image(image)
if guidance_scale > 1:
negative_lat = self.encode_condition_image(torch.zeros_like(image))
cond_lat = torch.cat([negative_lat, cond_lat])
encoded = self.vision_encoder(image_2, output_hidden_states=False)
global_embeds = encoded.image_embeds
global_embeds = global_embeds.unsqueeze(-2)
encoder_hidden_states = self._encode_prompt(
prompt,
self.device,
num_images_per_prompt,
False
)
ramp = global_embeds.new_tensor(self.config.ramping_coefficients).unsqueeze(-1)
encoder_hidden_states = encoder_hidden_states + global_embeds * ramp
cak = dict(cond_lat=cond_lat)
if hasattr(self.unet, "controlnet"):
cak['control_depth'] = depth_image
cak['cond_lat_back'] = None
latents: torch.Tensor = super().__call__(
None,
*args,
cross_attention_kwargs=cak,
guidance_scale=guidance_scale,
num_images_per_prompt=num_images_per_prompt,
prompt_embeds=encoder_hidden_states,
num_inference_steps=num_inference_steps,
output_type='latent',
width=width,
height=height,
generator=generator,
**kwargs
).images
latents = unscale_latents(latents)
if not output_type == "latent":
image = unscale_image(self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0])
else:
image = latents
image = self.image_processor.postprocess(image, output_type=output_type)
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)
================================================
FILE: main.py
================================================
import os
import cv2
import time
import tqdm
import numpy as np
import dearpygui.dearpygui as dpg
import torch
import torch.nn.functional as F
import torchvision.utils as vutils
from einops import rearrange, repeat
import imageio
import rembg
from cam_utils import orbit_camera, OrbitCamera
from gs_renderer_4d import Renderer, MiniCam
from dataset_4d import SparseDataset
def save_image_to_local(image_tensor, file_path):
# Ensure the image tensor is in the range [0, 1]
image_tensor = image_tensor.clamp(0, 1)
# Save the image tensor to the specified file path
vutils.save_image(image_tensor, file_path)
class GUI:
def __init__(self, opt):
self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
self.gui = opt.gui # enable gui
self.W = opt.W
self.H = opt.H
self.cam = OrbitCamera(opt.W, opt.H, r=opt.radius, fovy=opt.fovy)
self.mode = "image"
self.seed = "random"
self.buffer_image = np.ones((self.W, self.H, 3), dtype=np.float32)
self.need_update = True # update buffer_image
# models
self.device = torch.device("cuda")
self.bg_remover = None
self.guidance_sd = None
self.guidance_zero123 = None
self.enable_sd = False
self.enable_zero123 = False
# renderer
self.renderer = Renderer(sh_degree=self.opt.sh_degree)
self.gaussain_scale_factor = 1
# input image
self.input_img = None
self.input_mask = None
self.input_img_torch = None
self.input_mask_torch = None
self.overlay_input_img = False
self.overlay_input_img_ratio = 0.5
#self.use_depth = opt.use_depth
# input text
self.prompt = ""
self.negative_prompt = ""
# training stuff
self.training = False
self.optimizer = None
self.step = 0
self.t = 0
self.time = 0
self.train_steps = 1 # steps per rendering loop
self.init = True
self.stage = 'coarse'
self.path = self.opt.path
self.save_step = self.opt.save_step
if self.opt.size is not None:
self.size = self.opt.size
else:
self.size = len(os.listdir(os.path.join(self.path,'ref')))
self.frames=self.size
self.dataset = SparseDataset(self.opt, self.size, H=self.H, W=self.W, device=self.device)
self.dataloader =self.dataset.dataloader()
self.iter = iter(self.dataloader)
self.ref_view_batch, self.input_mask_batch,self.zero123_view_batch,self.zero123_masks_batch = next(self.iter)
self.input_img_torch_batch,self.input_mask_torch_batch,self.zero123plus_imgs_torch_batch,self.zero123plus_masks_torch_batch=[],[],[],[]
# load input data from cmdline
if self.opt.input is not None:
self.load_input(self.opt.input)
# override prompt from cmdline
if self.opt.prompt is not None:
self.prompt = self.opt.prompt
# override if provide a checkpoint
self.renderer.initialize(num_pts=self.opt.num_pts)
self.point_nums = []
if self.gui:
dpg.create_context()
self.register_dpg()
self.test_step()
def __del__(self):
if self.gui:
dpg.destroy_context()
def seed_everything(self):
try:
seed = int(self.seed)
except:
seed = np.random.randint(0, 1000000)
os.environ["PYTHONHASHSEED"] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True
self.last_seed = seed
def prepare_image(self,idx):
# input image
if self.input_img is not None:
self.input_img_torch = torch.from_numpy(self.input_img).permute(2, 0, 1).unsqueeze(0).to(self.device)
self.input_img_torch = F.interpolate(self.input_img_torch, (self.opt.ref_size, self.opt.ref_size), mode="bilinear", align_corners=False)
self.input_mask_torch = torch.from_numpy(self.input_mask).permute(2, 0, 1).unsqueeze(0).to(self.device)
self.input_mask_torch = F.interpolate(self.input_mask_torch, (self.opt.ref_size, self.opt.ref_size), mode="bilinear", align_corners=False)
self.zero123plus_imgs_torch=[]
self.zero123plus_masks_torch=[]
# input image
if self.input_imgs is not None:
for i in np.arange(6):
#print(idx,i)
self.input_img2_torch=(torch.from_numpy(self.input_imgs[i]).permute(2, 0, 1).unsqueeze(0).to(self.device))
self.zero123plus_imgs_torch.append(F.interpolate(self.input_img2_torch, (self.opt.ref_size, self.opt.ref_size), mode="bilinear", align_corners=False))
self.input_mask2_torch=torch.from_numpy(self.input_masks[i]).permute(2, 0, 1).unsqueeze(0).to(self.device)
self.zero123plus_masks_torch.append(F.interpolate(self.input_mask2_torch, (self.opt.ref_size, self.opt.ref_size), mode="bilinear", align_corners=False))
self.input_img_torch_batch.append(self.input_img_torch)
self.input_mask_torch_batch.append(self.input_mask_torch)
self.zero123plus_imgs_torch_batch.append(self.zero123plus_imgs_torch)
self.zero123plus_masks_torch_batch.append(self.zero123plus_masks_torch)
# prepare embeddings
with torch.no_grad():
self.guidance_zero123.get_img_embeds(self.input_img_torch, self.zero123plus_imgs_torch)
def prepare_train(self):
self.step = 0
self.end_step = self.save_step+1
## given a load_path, load corresponding model
if self.opt.load_path is not None:
if self.opt.load_step is not None:
self.step = self.opt.load_step
else:
#default loading save_step ply
self.step = self.save_step
auto_path = os.path.join(self.opt.outdir,self.opt.load_path + str(self.step))
ply_path = os.path.join(auto_path,'model.ply')
self.renderer.gaussians.load_model(auto_path)
self.renderer.gaussians.load_ply(ply_path)
self.end_step =self.step+self.end_step
## setup training
self.renderer.gaussians.training_setup(self.opt)
## do not do progressive sh-level
self.renderer.gaussians.active_sh_degree = self.renderer.gaussians.max_sh_degree
self.optimizer = self.renderer.gaussians.optimizer
# default camera
pose = orbit_camera(self.opt.elevation, 0, self.opt.radius)
self.fixed_cam = MiniCam(
pose,
self.opt.ref_size,
self.opt.ref_size,
self.cam.fovy,
self.cam.fovx,
self.cam.near,
self.cam.far,
)
self.set_fix_cam()
self.enable_sd = self.opt.lambda_sd > 0 and self.prompt != ""
self.enable_zero123 = self.opt.lambda_zero123 > 0 and self.input_img is not None
print(f"[INFO] loading zero123...")
from guidance.zero123_4d_utils import Zero123
self.guidance_zero123 = Zero123(self.device)
print(f"[INFO] loaded zero123!")
## load multiview reference images
for i in np.arange(len(self.ref_view_batch)):
self.input_img = self.ref_view_batch[i]
self.input_mask = self.input_mask_batch[i]
self.input_imgs = self.zero123_view_batch[i]
self.input_masks = self.zero123_masks_batch[i]
self.prepare_image(i)
def train_step(self):
starter = torch.cuda.Event(enable_timing=True)
ender = torch.cuda.Event(enable_timing=True)
starter.record()
torch.autograd.set_detect_anomaly(True)
for _ in range(self.train_steps):
if self.step<self.opt.init_steps:
self.init = True
self.stage = 'coarse'
else:
self.init = False
self.stage = 'fine'
if self.step == self.end_step:
exit()
## save model
if self.step == self.save_step:
auto_path = os.path.join(self.opt.outdir,self.opt.save_path + str(self.step))
os.makedirs(auto_path,exist_ok=True)
ply_path = os.path.join(auto_path,'model.ply')
self.renderer.gaussians.save_ply(ply_path)
self.renderer.gaussians.save_deformation(auto_path)
if self.step>self.opt.position_lr_max_steps:
self.opt.position_lr_max_steps = self.opt.position_lr_max_steps2
self.step += 1
step_ratio = min(1, self.step / self.opt.iters)
viewspace_point_tensor_list = []
radii_list = []
visibility_filter_list = []
# update lr
self.renderer.gaussians.update_learning_rate(self.step)
self.guidance_zero123.update_step(0,self.step)
loss = 0
if self.step%self.opt.valid_interval == 0:
self.save_renderings( 0, 0, 2 ,'front')
self.save_renderings( 180, 0, 2 ,'back')
render_resolution = 128 if step_ratio < 0.3 else (256 if step_ratio < 0.6 else 512)
# avoid too large elevation (> 80 or < -80), and make sure it always cover [-30, 30]
min_ver = max(min(-30, -30 - self.opt.elevation), -80 - self.opt.elevation)
max_ver = min(max(30, 30 - self.opt.elevation), 80 - self.opt.elevation)
for _ in np.arange(self.opt.batch_size):
#sample time
if self.init:
self.t = self.frames//2
self.time = self.t/self.frames
else:
self.t = np.random.randint(0,self.frames)
self.time = self.t/self.frames
self.input_img_torch = self.input_img_torch_batch[self.t]
self.input_mask_torch = self.input_mask_torch_batch[self.t]
self.zero123plus_imgs_torch = self.zero123plus_imgs_torch_batch[self.t]
self.zero123plus_masks_torch = self.zero123plus_masks_torch_batch[self.t]
## need to do rgb loss in the batch
cur_cam = self.fixed_cam
cur_cam.time=self.time
out = self.renderer.render(cur_cam,stage=self.stage)
viewspace_point_tensor, visibility_filter, radii = out["viewspace_points"], out["visibility_filter"], out["radii"]
radii_list.append(radii.unsqueeze(0))
visibility_filter_list.append(visibility_filter.unsqueeze(0))
viewspace_point_tensor_list.append(viewspace_point_tensor)
# rgb loss
image = out["image"].unsqueeze(0) # [1, 3, H, W] in [0, 1]
image_loss =step_ratio* 20000* F.mse_loss(image, self.input_img_torch)
loss = loss + image_loss
alpha = out["alpha"].unsqueeze(0)
alpha_loss = step_ratio* 5000* F.mse_loss(alpha, self.input_mask_torch)
loss = loss + alpha_loss
images = []
poses = []
vers_plus, hors_plus, radii_plus = [], [], []
self.guidance_zero123.update_step(1,self.step)
# render random view
ver = np.random.randint(min_ver, max_ver)
hor = np.random.randint(-180, 180)
radius = 0
pose = orbit_camera(self.opt.elevation + ver, hor, self.opt.radius + radius)
poses.append(pose)
cur_cam = MiniCam(
pose,
render_resolution,
render_resolution,
self.cam.fovy,
self.cam.fovx,
self.cam.near,
self.cam.far,
)
cur_cam.time=self.time
if hor<30 and hor>-30 or np.random.rand()>0.4:
idx=None
vers_plus.append(torch.tensor(ver,device=self.device).unsqueeze(dim=0))
hors_plus.append(torch.tensor(hor,device=self.device).unsqueeze(dim=0))
radii_plus.append(torch.tensor(radius,device=self.device).unsqueeze(dim=0))
elif hor>0:
idx=hor//60
vers_plus.append(torch.tensor(ver-self.fixed_elevation[idx],device=self.device).unsqueeze(dim=0))
hors_plus.append(torch.tensor(hor-self.fixed_azimuth[idx],device=self.device).unsqueeze(dim=0))
radii_plus.append(torch.tensor(radius,device=self.device).unsqueeze(dim=0))
elif hor<0:
idx = (360+hor)//60
vers_plus.append(torch.tensor(ver-self.fixed_elevation[idx],device=self.device).unsqueeze(dim=0))
hors_plus.append(torch.tensor(hor-self.fixed_azimuth[idx],device=self.device).unsqueeze(dim=0))
radii_plus.append(torch.tensor(radius,device=self.device).unsqueeze(dim=0))
bg_color = torch.tensor([1, 1, 1] if np.random.rand() > self.opt.invert_bg_prob else [0, 0, 0], dtype=torch.float32, device="cuda")
out = self.renderer.render(cur_cam, bg_color=bg_color,stage=self.stage)
viewspace_point_tensor, visibility_filter, radii_rendering = out["viewspace_points"], out["visibility_filter"], out["radii"]
radii_list.append(radii_rendering.unsqueeze(0))
visibility_filter_list.append(visibility_filter.unsqueeze(0))
viewspace_point_tensor_list.append(viewspace_point_tensor)
image = out["image"].unsqueeze(0)# [1, 3, H, W] in [0, 1]
images.append(image)
images_render = torch.cat(images, dim=0)
#poses = torch.from_numpy(np.stack(poses, axis=0)).to(self.device)
vers_batch = torch.cat(vers_plus, dim=0).cpu().numpy()
hors_batch = torch.cat(hors_plus, dim=0).cpu().numpy()
radii_batch = torch.cat(radii_plus, dim=0).cpu().numpy()
# guidance loss
# as we have different reference views, so each time we only pass 1 image into zero123 for guidance
zero123_loss = self.opt.lambda_zero123 * self.guidance_zero123.train_step(images_render, vers_batch, hors_batch, radii_batch, step_ratio,idx=idx,t = self.t)
loss = loss + zero123_loss
# tv loss
scales = out['scales']
tv_loss = self.renderer.gaussians.compute_regulation(self.opt.time_smoothness_weight, self.opt.plane_tv_weight, self.opt.l1_time_planes)
loss += self.opt.lambda_tv * tv_loss
# scale loss from physgaussian
r = self.opt.scale_loss_ratio
scale_loss = (torch.mean(torch.maximum(torch.max(scales,dim=1).values/ \
(torch.min(scales,dim=1).values+1e-8),\
torch.ones_like(torch.max(scales,dim=1).values)*r))-r) * scales.shape[0]
loss += scale_loss
# optimize step
loss.backward()
self.optimizer.step()
self.optimizer.zero_grad()
viewspace_point_tensor_grad = torch.zeros_like(viewspace_point_tensor)
for idx in range(0, len(viewspace_point_tensor_list)):
viewspace_point_tensor_grad = viewspace_point_tensor_grad + viewspace_point_tensor_list[idx].grad
radii = torch.cat(radii_list,0).max(dim=0).values
visibility_filter = torch.cat(visibility_filter_list).any(dim=0)
if self.step >= self.opt.density_start_iter and self.step <= self.opt.density_end_iter:
self.renderer.gaussians.max_radii2D[visibility_filter] = torch.max(self.renderer.gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
self.renderer.gaussians.add_densification_stats(viewspace_point_tensor_grad, visibility_filter)
if self.step % self.opt.densification_interval == 1 :
self.renderer.gaussians.densify_and_prune(self.opt.densify_grad_threshold_percent, min_opacity=0.01, extent=1, max_screen_size=2)
ender.record()
torch.cuda.synchronize()
t = starter.elapsed_time(ender)
self.need_update = True
if self.gui:
dpg.set_value("_log_train_time", f"{t:.4f}ms")
dpg.set_value(
"_log_train_log",
f"step = {self.step: 5d} (+{self.train_steps: 2d})\n loss = {loss.item():.4f}\nzero123_loss = {zero123_loss.item():.4f}image_loss ={image_loss.item():.4f}\nloss_alpha = {alpha_loss.item():.4f} scale_loss:{scale_loss.item():.4f} ",
)
def set_fix_cam(self):
self.fixed_cam_plus=[]
self.fixed_elevation = []
self.fixed_azimuth = []
pose = orbit_camera(self.opt.elevation-30,30 , self.opt.radius)
self.fixed_elevation.append(-30)
self.fixed_azimuth.append(30)
self.fixed_cam_plus.append(MiniCam(
pose,
self.opt.ref_size,
self.opt.ref_size,
self.cam.fovy,
self.cam.fovx,
self.cam.near,
self.cam.far,
))
pose = orbit_camera(self.opt.elevation+20, 90, self.opt.radius)
self.fixed_elevation.append(20)
self.fixed_azimuth.append(90)
self.fixed_cam_plus.append(MiniCam(
pose,
self.opt.ref_size,
self.opt.ref_size,
self.cam.fovy,
self.cam.fovx,
self.cam.near,
self.cam.far,
))
pose = orbit_camera(self.opt.elevation-30, 150, self.opt.radius)
self.fixed_elevation.append(-30)
self.fixed_azimuth.append(150)
self.fixed_cam_plus.append(MiniCam(
pose,
self.opt.ref_size,
self.opt.ref_size,
self.cam.fovy,
self.cam.fovx,
self.cam.near,
self.cam.far,
))
pose = orbit_camera(self.opt.elevation+20, 210, self.opt.radius)
self.fixed_elevation.append(+20)
self.fixed_azimuth.append(210)
self.fixed_cam_plus.append(MiniCam(
pose,
self.opt.ref_size,
self.opt.ref_size,
self.cam.fovy,
self.cam.fovx,
self.cam.near,
self.cam.far,
))
pose = orbit_camera(self.opt.elevation-30, 270, self.opt.radius)
self.fixed_elevation.append(-30)
self.fixed_azimuth.append(270)
self.fixed_cam_plus.append(MiniCam(
pose,
self.opt.ref_size,
self.opt.ref_size,
self.cam.fovy,
self.cam.fovx,
self.cam.near,
self.cam.far,
))
pose = orbit_camera(self.opt.elevation+20, 330, self.opt.radius)
self.fixed_elevation.append(20)
self.fixed_azimuth.append(330)
self.fixed_cam_plus.append(MiniCam(
pose,
self.opt.ref_size,
self.opt.ref_size,
self.cam.fovy,
self.cam.fovx,
self.cam.near,
self.cam.far,
))
@torch.no_grad()
def test_step(self):
# ignore if no need to update
if not self.need_update:
return
starter = torch.cuda.Event(enable_timing=True)
ender = torch.cuda.Event(enable_timing=True)
starter.record()
# should update image
if self.need_update:
# render image
cur_cam = MiniCam(
self.cam.pose,
self.W,
self.H,
self.cam.fovy,
self.cam.fovx,
self.cam.near,
self.cam.far,
time=self.time
)
#print(cur_cam.time)
out = self.renderer.render(cur_cam, self.gaussain_scale_factor,stage=self.stage)
buffer_image = out[self.mode] # [3, H, W]
if self.mode in ['depth', 'alpha']:
buffer_image = buffer_image.repeat(3, 1, 1)
if self.mode == 'depth':
buffer_image = (buffer_image - buffer_image.min()) / (buffer_image.max() - buffer_image.min() + 1e-20)
buffer_image = F.interpolate(
buffer_image.unsqueeze(0),
size=(self.H, self.W),
mode="bilinear",
align_corners=False,
).squeeze(0)
self.buffer_image = (
buffer_image.permute(1, 2, 0)
.contiguous()
.clamp(0, 1)
.contiguous()
.detach()
.cpu()
.numpy()
)
# display input_image
if self.overlay_input_img and self.input_img is not None:
self.buffer_image = (
self.buffer_image * (1 - self.overlay_input_img_ratio)
+ self.input_img * self.overlay_input_img_ratio
)
self.need_update = False
ender.record()
torch.cuda.synchronize()
t = starter.elapsed_time(ender)
if self.gui:
dpg.set_value("_log_infer_time", f"{t:.4f}ms ({int(1000/t)} FPS)")
dpg.set_value(
"_texture", self.buffer_image
) # buffer must be contiguous, else seg fault!
def load_input(self, file):
# load image
pass
# load image
@torch.no_grad()
def save_renderings(self, elev=0, azim=0, radius=2, name='front'):
images=[]
for i in np.arange(self.frames):
pose = orbit_camera(elev, azim, radius)
cam = MiniCam(
pose,
self.opt.ref_size,
self.opt.ref_size,
self.cam.fovy,
self.cam.fovx,
self.cam.near,
self.cam.far,
)
cam.time=float(i/self.frames)
out = self.renderer.render(cam,stage=self.stage)
image = out["image"].unsqueeze(0)
images.append(image)
os.makedirs(f'./valid/{self.opt.save_path}/{self.step}_{name}',exist_ok=True)
save_image_to_local(image[0].detach(),f'./valid/{self.opt.save_path}/{self.step}_{name}/{str(i).zfill(2)}.jpg')
samples=torch.cat(images,dim=0)
vid = (
(rearrange(samples, "t c h w -> t h w c") * 255).clamp(0,255).detach()
.cpu()
.numpy()
.astype(np.uint8)
)
video_path = f'./valid/{self.opt.save_path}/{self.step}_{name}/video.mp4'
imageio.mimwrite(video_path, vid)
@torch.no_grad()
def save_model(self, mode='geo', texture_size=1024):
os.makedirs(self.opt.outdir, exist_ok=True)
if mode == 'geo':
path = os.path.join(self.opt.outdir, self.opt.save_path + '_model.ply')
self.renderer.gaussians.save_ply(path)
elif mode == 'geo+tex':
path = os.path.join(self.opt.outdir, self.opt.save_path + '_model.ply')
self.renderer.gaussians.save_ply(path)
else:
path = os.path.join(self.opt.outdir, self.opt.save_path + '_model.ply')
self.renderer.gaussians.save_ply(path)
print(f"[INFO] save model to {path}.")
def register_dpg(self):
### register texture
with dpg.texture_registry(show=False):
dpg.add_raw_texture(
self.W,
self.H,
self.buffer_image,
format=dpg.mvFormat_Float_rgb,
tag="_texture",
)
### register window
# the rendered image, as the primary window
with dpg.window(
tag="_primary_window",
width=self.W,
height=self.H,
pos=[0, 0],
no_move=True,
no_title_bar=True,
no_scrollbar=True,
):
# add the texture
dpg.add_image("_texture")
# dpg.set_primary_window("_primary_window", True)
# control window
with dpg.window(
label="Control",
tag="_control_window",
width=600,
height=self.H,
pos=[self.W, 0],
no_move=True,
no_title_bar=True,
):
# button theme
with dpg.theme() as theme_button:
with dpg.theme_component(dpg.mvButton):
dpg.add_theme_color(dpg.mvThemeCol_Button, (23, 3, 18))
dpg.add_theme_color(dpg.mvThemeCol_ButtonHovered, (51, 3, 47))
dpg.add_theme_color(dpg.mvThemeCol_ButtonActive, (83, 18, 83))
dpg.add_theme_style(dpg.mvStyleVar_FrameRounding, 5)
dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 3, 3)
# timer stuff
with dpg.group(horizontal=True):
dpg.add_text("Infer time: ")
dpg.add_text("no data", tag="_log_infer_time")
def callback_setattr(sender, app_data, user_data):
setattr(self, user_data, app_data)
# init stuff
with dpg.collapsing_header(label="Initialize", default_open=True):
# seed stuff
def callback_set_seed(sender, app_data):
self.seed = app_data
self.seed_everything()
dpg.add_input_text(
label="seed",
default_value=self.seed,
on_enter=True,
callback=callback_set_seed,
)
# input stuff
def callback_select_input(sender, app_data):
# only one item
for k, v in app_data["selections"].items():
dpg.set_value("_log_input", k)
self.load_input(v)
self.need_update = True
with dpg.file_dialog(
directory_selector=False,
show=False,
callback=callback_select_input,
file_count=1,
tag="file_dialog_tag",
width=700,
height=400,
):
dpg.add_file_extension("Images{.jpg,.jpeg,.png}")
with dpg.group(horizontal=True):
dpg.add_button(
label="input",
callback=lambda: dpg.show_item("file_dialog_tag"),
)
dpg.add_text("", tag="_log_input")
# overlay stuff
with dpg.group(horizontal=True):
def callback_toggle_overlay_input_img(sender, app_data):
self.overlay_input_img = not self.overlay_input_img
self.need_update = True
dpg.add_checkbox(
label="overlay image",
default_value=self.overlay_input_img,
callback=callback_toggle_overlay_input_img,
)
def callback_set_overlay_input_img_ratio(sender, app_data):
self.overlay_input_img_ratio = app_data
self.need_update = True
dpg.add_slider_float(
label="ratio",
min_value=0,
max_value=1,
format="%.1f",
default_value=self.overlay_input_img_ratio,
callback=callback_set_overlay_input_img_ratio,
)
# prompt stuff
dpg.add_input_text(
label="prompt",
default_value=self.prompt,
callback=callback_setattr,
user_data="prompt",
)
dpg.add_input_text(
label="negative",
default_value=self.negative_prompt,
callback=callback_setattr,
user_data="negative_prompt",
)
# save current model
with dpg.group(horizontal=True):
dpg.add_text("Save: ")
def callback_save(sender, app_data, user_data):
self.save_model(mode=user_data)
dpg.add_button(
label="model",
tag="_button_save_model",
callback=callback_save,
user_data='model',
)
dpg.bind_item_theme("_button_save_model", theme_button)
dpg.add_button(
label="geo",
tag="_button_save_mesh",
callback=callback_save,
user_data='geo',
)
dpg.bind_item_theme("_button_save_mesh", theme_button)
dpg.add_button(
label="geo+tex",
tag="_button_save_mesh_with_tex",
callback=callback_save,
user_data='geo+tex',
)
dpg.bind_item_theme("_button_save_mesh_with_tex", theme_button)
dpg.add_input_text(
label="",
default_value=self.opt.save_path,
callback=callback_setattr,
user_data="save_path",
)
# training stuff
with dpg.collapsing_header(label="Train", default_open=True):
# lr and train button
with dpg.group(horizontal=True):
dpg.add_text("Train: ")
def callback_train(sender, app_data):
if self.training:
self.training = False
dpg.configure_item("_button_train", label="start")
else:
self.prepare_train()
self.training = True
dpg.configure_item("_button_train", label="stop")
# dpg.add_button(
# label="init", tag="_button_init", callback=self.prepare_train
# )
# dpg.bind_item_theme("_button_init", theme_button)
dpg.add_button(
label="start", tag="_button_train", callback=callback_train
)
dpg.bind_item_theme("_button_train", theme_button)
with dpg.group(horizontal=True):
dpg.add_text("", tag="_log_train_time")
dpg.add_text("", tag="_log_train_log")
# rendering options
with dpg.collapsing_header(label="Rendering", default_open=True):
# mode combo
def callback_change_mode(sender, app_data):
self.mode = app_data
self.need_update = True
dpg.add_combo(
("image", "depth", "alpha"),
label="mode",
default_value=self.mode,
callback=callback_change_mode,
)
# fov slider
def callback_set_fovy(sender, app_data):
self.cam.fovy = np.deg2rad(app_data)
self.need_update = True
dpg.add_slider_int(
label="FoV (vertical)",
min_value=1,
max_value=120,
format="%d deg",
default_value=np.rad2deg(self.cam.fovy),
callback=callback_set_fovy,
)
def callback_set_gaussain_scale(sender, app_data):
self.gaussain_scale_factor = app_data
self.need_update = True
dpg.add_slider_float(
label="gaussain scale",
min_value=0,
max_value=1,
format="%.2f",
default_value=self.gaussain_scale_factor,
callback=callback_set_gaussain_scale,
)
### register camera handler
def callback_camera_drag_rotate_or_draw_mask(sender, app_data):
if not dpg.is_item_focused("_primary_window"):
return
dx = app_data[1]
dy = app_data[2]
self.cam.orbit(dx, dy)
self.need_update = True
def callback_camera_wheel_scale(sender, app_data):
if not dpg.is_item_focused("_primary_window"):
return
delta = app_data
self.cam.scale(delta)
self.need_update = True
def callback_camera_drag_pan(sender, app_data):
if not dpg.is_item_focused("_primary_window"):
return
dx = app_data[1]
dy = app_data[2]
self.cam.pan(dx, dy)
self.need_update = True
def callback_set_mouse_loc(sender, app_data):
if not dpg.is_item_focused("_primary_window"):
return
# just the pixel coordinate in image
self.mouse_loc = np.array(app_data)
with dpg.handler_registry():
# for camera moving
dpg.add_mouse_drag_handler(
button=dpg.mvMouseButton_Left,
callback=callback_camera_drag_rotate_or_draw_mask,
)
dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale)
dpg.add_mouse_drag_handler(
button=dpg.mvMouseButton_Middle, callback=callback_camera_drag_pan
)
dpg.create_viewport(
title="Gaussian3D",
width=self.W + 600,
height=self.H + (45 if os.name == "nt" else 0),
resizable=False,
)
### global theme
with dpg.theme() as theme_no_padding:
with dpg.theme_component(dpg.mvAll):
# set all padding to 0 to avoid scroll bar
dpg.add_theme_style(
dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core
)
dpg.add_theme_style(
dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core
)
dpg.add_theme_style(
dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core
)
dpg.bind_item_theme("_primary_window", theme_no_padding)
dpg.setup_dearpygui()
### register a larger font
# get it from: https://github.com/lxgw/LxgwWenKai/releases/download/v1.300/LXGWWenKai-Regular.ttf
if os.path.exists("LXGWWenKai-Regular.ttf"):
with dpg.font_registry():
with dpg.font("LXGWWenKai-Regular.ttf", 18) as default_font:
dpg.bind_font(default_font)
# dpg.show_metrics()
dpg.show_viewport()
def render(self):
assert self.gui
while dpg.is_dearpygui_running():
# update texture every frame
if self.training:
self.train_step()
self.test_step()
dpg.render_dearpygui_frame()
# no gui mode
def train(self, iters=500):
if iters > 0:
self.prepare_train()
for i in tqdm.trange(iters):
self.train_step()
# do a last prune
#self.renderer.gaussians.prune(min_opacity=0.01, extent=1, max_screen_size=1)
# save
self.save_model(mode='model')
self.save_model(mode='geo+tex')
if __name__ == "__main__":
import argparse
from omegaconf import OmegaConf
parser = argparse.ArgumentParser()
parser.add_argument("--config", required=True, help="path to the yaml config file")
args, extras = parser.parse_known_args()
# override default config from cli
opt = OmegaConf.merge(OmegaConf.load(args.config), OmegaConf.from_cli(extras))
gui = GUI(opt)
if opt.gui:
gui.render()
else:
gui.train(opt.save_step+1)
================================================
FILE: mini_trainer.ipynb
================================================
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from gs_renderer_4d import Renderer, MiniCam\n",
"from dataset_4d import SparseDataset\n",
"import os\n",
"import tqdm\n",
"import numpy as np\n",
"import torch\n",
"\n",
"from cam_utils import orbit_camera, OrbitCamera\n",
"from guidance.sd_utils import StableDiffusion\n",
"\n",
"\n",
"class trainer:\n",
" def __init__(self,opt) -> None:\n",
" \n",
" #initialize options\n",
" self.opt=opt\n",
" self.device=self.opt.device\n",
" \n",
" #initialize renderer and gaussians\n",
" self.renderer = Renderer(sh_degree=self.opt.sh_degree)\n",
" self.renderer.initialize(num_pts=self.opt.num_pts) \n",
" self.renderer.gaussians.training_setup(self.opt)\n",
" \n",
" self.optimizer = self.renderer.gaussians.optimizer\n",
" \n",
" self.cam = OrbitCamera(opt.W, opt.H, r=opt.radius, fovy=opt.fovy)\n",
" \n",
" #initialize sd. replace with your own diffusion model if necessary.\n",
" self.enable_sd = True\n",
" self.guidance_sd = StableDiffusion(self.device)\n",
" self.guidance_sd.get_text_embeds([self.opt.prompt],negative_prompts= [''])\n",
" \n",
" def save(self,save_path):\n",
" #save \n",
" auto_path = save_path\n",
" os.makedirs(auto_path,exist_ok=True)\n",
" ply_path = os.path.join(auto_path,'model.ply')\n",
" self.renderer.gaussians.save_ply(ply_path)\n",
" self.renderer.gaussians.save_deformation(auto_path)\n",
" \n",
" def load(self, load_path):\n",
" #load\n",
" auto_path = load_path\n",
" ply_path = os.path.join(auto_path,'model.ply')\n",
" self.renderer.gaussians.load_model(auto_path)\n",
" self.renderer.gaussians.load_ply(ply_path)\n",
" \n",
" \n",
" def render(self,frame_id, elevation, azimuth, radius):\n",
" #render with parameters\n",
" pose = orbit_camera(elevation,azimuth,radius)\n",
" cam = MiniCam(\n",
" pose,\n",
" self.opt.ref_size,\n",
" self.opt.ref_size,\n",
" self.cam.fovy,\n",
" self.cam.fovx,\n",
" self.cam.near,\n",
" self.cam.far,\n",
" ) \n",
" cam.time=float(frame_id/30) #30 is the total frame\n",
" #use stage='coarse' for static rendering, use stage='fine' for dynamic rendering\n",
" out = self.renderer.render(cam,stage='fine')\n",
" image = out[\"image\"].unsqueeze(0)# [1, 3, H, W] in [0, 1]\n",
" \n",
" return image\n",
" \n",
" def train(self):\n",
" self.step=0\n",
" \n",
" for i in tqdm.tqdm(range(10000)):\n",
" self.step+=1\n",
" self.renderer.gaussians.update_learning_rate(self.step)\n",
" loss = 0\n",
" \n",
" min_ver = -30\n",
" max_ver = 30\n",
" vers, hors, radiis, poses = [], [], [], []\n",
" images=[]\n",
" viewspace_point_tensor_list, radii_list, visibility_filter_list = [], [], []\n",
"\n",
" render_resolution=512\n",
" \n",
" for _ in range(self.opt.batch_size):\n",
" #sample time, vertical& horizontal angle\n",
" ver = np.random.randint(min_ver, max_ver)\n",
" hor = np.random.randint(-180, 180)\n",
" radius=0\n",
" self.t = np.random.randint(0,30)\n",
" self.time = self.t/30\n",
" \n",
" vers.append(torch.tensor(self.opt.elevation + ver,device=self.device).unsqueeze(dim=0))\n",
" hors.append(torch.tensor(hor,device=self.device).unsqueeze(dim=0))\n",
" radiis.append(torch.tensor(self.opt.radius + radius,device=self.device).unsqueeze(dim=0))\n",
" \n",
" pose = orbit_camera(self.opt.elevation + ver, hor, self.opt.radius + radius)\n",
" \n",
" poses.append(pose)\n",
"\n",
"\n",
" cur_cam = MiniCam(\n",
" pose,\n",
" render_resolution,\n",
" render_resolution,\n",
" self.cam.fovy,\n",
" self.cam.fovx,\n",
" self.cam.near,\n",
" self.cam.far,\n",
" )\n",
" cur_cam.time=self.time\n",
" \n",
" bg_color = torch.tensor([1, 1, 1] if np.random.rand() > self.opt.invert_bg_prob else [0, 0, 0], dtype=torch.float32, device=\"cuda\")\n",
" #use stage='coarse' for static rendering, use stage='fine' for dynamic rendering\n",
" out = self.renderer.render(cur_cam, bg_color=bg_color,stage='fine')\n",
" \n",
" #basic values for densification\n",
" viewspace_point_tensor, visibility_filter, radii = out[\"viewspace_points\"], out[\"visibility_filter\"], out[\"radii\"] \n",
" radii_list.append(radii.unsqueeze(0))\n",
" visibility_filter_list.append(visibility_filter.unsqueeze(0))\n",
" viewspace_point_tensor_list.append(viewspace_point_tensor)\n",
" \n",
" image = out[\"image\"].unsqueeze(0)# [1, 3, H, W] in [0, 1]\n",
" images.append(image)\n",
" \n",
" images_batch = torch.cat(images, dim=0)\n",
" poses = torch.from_numpy(np.stack(poses, axis=0)).to(self.device)\n",
" vers_batch = torch.cat(vers, dim=0).cpu().numpy()\n",
" hors_batch = torch.cat(hors, dim=0).cpu().numpy()\n",
" radii_batch = torch.cat(radiis, dim=0).cpu().numpy()\n",
"\n",
" if self.enable_sd:\n",
" sd_loss = self.guidance_sd.train_step(images_batch,step_ratio=None,poses=poses)\n",
" # guidance loss. replace with your own diffusion model if necessary.\n",
" loss = loss + sd_loss\n",
" else:\n",
" zero123_loss = self.guidance_zero123.train_step(images_batch, vers_batch, hors_batch, radii_batch,step_ratio=None)\n",
" # guidance loss.\n",
" loss = loss + zero123_loss\n",
" \n",
" # optimize step\n",
" loss.backward()\n",
" self.optimizer.step()\n",
" self.optimizer.zero_grad()\n",
"\n",
" #densifications. Adaptive densification is used here.\n",
" viewspace_point_tensor_grad = torch.zeros_like(viewspace_point_tensor)\n",
" for idx in range(0, len(viewspace_point_tensor_list)):\n",
" viewspace_point_tensor_grad = viewspace_point_tensor_grad + viewspace_point_tensor_list[idx].grad\n",
"\n",
" if self.step >= self.opt.density_start_iter and self.step <= self.opt.density_end_iter:\n",
" self.renderer.gaussians.max_radii2D[visibility_filter] = torch.max(self.renderer.gaussians.max_radii2D[visibility_filter], radii[visibility_filter])\n",
" self.renderer.gaussians.add_densification_stats(viewspace_point_tensor_grad, visibility_filter)\n",
" if self.step % self.opt.densification_interval == 0 :\n",
"\n",
" self.renderer.gaussians.densify_and_prune(self.opt.densify_grad_threshold, min_opacity=0.01, extent=1, max_screen_size=2)\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"feature_dim: 128\n",
"Number of points at initialisation : 10000\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a4cb52b5dc0045ccba1d352fc337cbd0",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading pipeline components...: 0%| | 0/6 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 1%| | 72/10000 [00:19<44:48, 3.69it/s] \n"
]
},
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #800000; text-decoration-color: #800000\">╭─────────────────────────────── </span><span style=\"color: #800000; text-decoration-color: #800000; font-weight: bold\">Traceback </span><span style=\"color: #bf7f7f; text-decoration-color: #bf7f7f; font-weight: bold\">(most recent call last)</span><span style=\"color: #800000; text-decoration-color: #800000\"> ────────────────────────────────╮</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> in <span style=\"color: #00ff00; text-decoration-color: #00ff00\"><module></span>:<span style=\"color: #0000ff; text-decoration-color: #0000ff\">6</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">3 </span>opt=OmegaConf.load(<span style=\"color: #808000; text-decoration-color: #808000\">'./configs/image_4d_m.yaml'</span>) <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">4 </span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">5 </span>train=trainer(opt) <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">❱ </span>6 train.train() <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">7 </span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> in <span style=\"color: #00ff00; text-decoration-color: #00ff00\">train</span>:<span style=\"color: #0000ff; text-decoration-color: #0000ff\">134</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">131 </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│ │ │ │ </span>loss = loss + zero123_loss <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">132 </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│ │ │ </span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">133 </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│ │ │ </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"># optimize step</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">❱ </span>134 <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│ │ │ </span>loss.backward() <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">135 </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│ │ │ </span><span style=\"color: #00ffff; text-decoration-color: #00ffff\">self</span>.optimizer.step() <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">136 </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│ │ │ </span><span style=\"color: #00ffff; text-decoration-color: #00ffff\">self</span>.optimizer.zero_grad() <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">137 </span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #bfbf7f; text-decoration-color: #bfbf7f\">/home/vision/miniconda3/envs/torch0/lib/python3.8/site-packages/torch/</span><span style=\"color: #808000; text-decoration-color: #808000; font-weight: bold\">_tensor.py</span>:<span style=\"color: #0000ff; text-decoration-color: #0000ff\">487</span> in <span style=\"color: #00ff00; text-decoration-color: #00ff00\">backward</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 484 </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│ │ │ │ </span>create_graph=create_graph, <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 485 </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│ │ │ │ </span>inputs=inputs, <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 486 </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│ │ │ </span>) <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">❱ </span> 487 <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│ │ </span>torch.autograd.backward( <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 488 </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│ │ │ </span><span style=\"color: #00ffff; text-decoration-color: #00ffff\">self</span>, gradient, retain_graph, create_graph, inputs=inputs <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 489 </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│ │ </span>) <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 490 </span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #bfbf7f; text-decoration-color: #bfbf7f\">/home/vision/miniconda3/envs/torch0/lib/python3.8/site-packages/torch/autograd/</span><span style=\"color: #808000; text-decoration-color: #808000; font-weight: bold\">__init__.py</span>:<span style=\"color: #0000ff; text-decoration-color: #0000ff\">200</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> in <span style=\"color: #00ff00; text-decoration-color: #00ff00\">backward</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">197 </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│ </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"># The reason we repeat same the comment below is that</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">198 </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│ </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"># some Python versions print out the first line of a multi-line function</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">199 </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│ </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"># calls in the traceback and some print out the last line</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">❱ </span>200 <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│ </span>Variable._execution_engine.run_backward( <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"># Calls into the C++ engine to run the bac</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">201 </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│ │ </span>tensors, grad_tensors_, retain_graph, create_graph, inputs, <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">202 </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│ │ </span>allow_unreachable=<span style=\"color: #0000ff; text-decoration-color: #0000ff\">True</span>, accumulate_grad=<span style=\"color: #0000ff; text-decoration-color: #0000ff\">True</span>) <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"># Calls into the C++ engine to ru</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">203 </span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">╰──────────────────────────────────────────────────────────────────────────────────────────────────╯</span>\n",
"<span style=\"color: #ff0000; text-decoration-color: #ff0000; font-weight: bold\">KeyboardInterrupt</span>\n",
"</pre>\n"
],
"text/plain": [
"\u001b[31m╭─\u001b[0m\u001b[31m──────────────────────────────\u001b[0m\u001b[31m \u001b[0m\u001b[1;31mTraceback \u001b[0m\u001b[1;2;31m(most recent call last)\u001b[0m\u001b[31m \u001b[0m\u001b[31m───────────────────────────────\u001b[0m\u001b[31m─╮\u001b[0m\n",
"\u001b[31m│\u001b[0m in \u001b[92m<module>\u001b[0m:\u001b[94m6\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m3 \u001b[0mopt=OmegaConf.load(\u001b[33m'\u001b[0m\u001b[33m./configs/image_4d_m.yaml\u001b[0m\u001b[33m'\u001b[0m) \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m4 \u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m5 \u001b[0mtrain=trainer(opt) \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m6 train.train() \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m7 \u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m in \u001b[92mtrain\u001b[0m:\u001b[94m134\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m131 \u001b[0m\u001b[2m│ │ │ │ \u001b[0mloss = loss + zero123_loss \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m132 \u001b[0m\u001b[2m│ │ │ \u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m133 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[2m# optimize step\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m134 \u001b[2m│ │ │ \u001b[0mloss.backward() \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m135 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[96mself\u001b[0m.optimizer.step() \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m136 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[96mself\u001b[0m.optimizer.zero_grad() \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m137 \u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2;33m/home/vision/miniconda3/envs/torch0/lib/python3.8/site-packages/torch/\u001b[0m\u001b[1;33m_tensor.py\u001b[0m:\u001b[94m487\u001b[0m in \u001b[92mbackward\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m 484 \u001b[0m\u001b[2m│ │ │ │ \u001b[0mcreate_graph=create_graph, \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m 485 \u001b[0m\u001b[2m│ │ │ │ \u001b[0minputs=inputs, \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m 486 \u001b[0m\u001b[2m│ │ │ \u001b[0m) \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m 487 \u001b[2m│ │ \u001b[0mtorch.autograd.backward( \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m 488 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[96mself\u001b[0m, gradient, retain_graph, create_graph, inputs=inputs \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m 489 \u001b[0m\u001b[2m│ │ \u001b[0m) \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m 490 \u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2;33m/home/vision/miniconda3/envs/torch0/lib/python3.8/site-packages/torch/autograd/\u001b[0m\u001b[1;33m__init__.py\u001b[0m:\u001b[94m200\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m in \u001b[92mbackward\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m197 \u001b[0m\u001b[2m│ \u001b[0m\u001b[2m# The reason we repeat same the comment below is that\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m198 \u001b[0m\u001b[2m│ \u001b[0m\u001b[2m# some Python versions print out the first line of a multi-line function\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m199 \u001b[0m\u001b[2m│ \u001b[0m\u001b[2m# calls in the traceback and some print out the last line\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m200 \u001b[2m│ \u001b[0mVariable._execution_engine.run_backward( \u001b[2m# Calls into the C++ engine to run the bac\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m201 \u001b[0m\u001b[2m│ │ \u001b[0mtensors, grad_tensors_, retain_graph, create_graph, inputs, \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m202 \u001b[0m\u001b[2m│ │ \u001b[0mallow_unreachable=\u001b[94mTrue\u001b[0m, accumulate_grad=\u001b[94mTrue\u001b[0m) \u001b[2m# Calls into the C++ engine to ru\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m203 \u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n",
"\u001b[1;91mKeyboardInterrupt\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from omegaconf import OmegaConf\n",
"\n",
"opt=OmegaConf.load('./configs/image_4d_m.yaml')\n",
"\n",
"train=trainer(opt)\n",
"train.train()"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.0"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"np.deg2rad(0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "torch0",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.17"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
================================================
FILE: requirements.txt
================================================
tqdm
rich
ninja
numpy
pandas
scipy
scikit-learn
matplotlib
opencv-python
imageio
imageio-ffmpeg
omegaconf
argparse
torch
einops
plyfile
pygltflib
dearpygui
accelerate
rembg[gpu,cli]
#zero123plus
opencv-contrib-python
diffusers==0.20.2
transformers==4.29.2
streamlit==1.22.0
altair<5
huggingface_hub
git+https://github.com/facebookresearch/segment-anything.git
gradio>=3.50
fire
================================================
FILE: scripts/app.py
================================================
import os
import sys
import numpy
import torch
import rembg
import threading
import urllib.request
from PIL import Image
import streamlit as st
import huggingface_hub
class SAMAPI:
predictor = None
@staticmethod
@st.cache_resource
def get_instance(sam_checkpoint=None):
if SAMAPI.predictor is None:
if sam_checkpoint is None:
sam_checkpoint = "tmp/sam_vit_h_4b8939.pth"
if not os.path.exists(sam_checkpoint):
os.makedirs('tmp', exist_ok=True)
urllib.request.urlretrieve(
"https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
sam_checkpoint
)
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model_type = "default"
from segment_anything import sam_model_registry, SamPredictor
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)
SAMAPI.predictor = predictor
return SAMAPI.predictor
@staticmethod
def segment_api(rgb, mask=None, bbox=None, sam_checkpoint=None):
"""
Parameters
----------
rgb : np.ndarray h,w,3 uint8
mask: np.ndarray h,w bool
Returns
-------
"""
np = numpy
predictor = SAMAPI.get_instance(sam_checkpoint)
predictor.set_image(rgb)
if mask is None and bbox is None:
box_input = None
else:
# mask to bbox
if bbox is None:
y1, y2, x1, x2 = np.nonzero(mask)[0].min(), np.nonzero(mask)[0].max(), np.nonzero(mask)[1].min(), \
np.nonzero(mask)[1].max()
else:
x1, y1, x2, y2 = bbox
box_input = np.array([[x1, y1, x2, y2]])
masks, scores, logits = predictor.predict(
box=box_input,
multimask_output=True,
return_logits=False,
)
mask = masks[-1]
return mask
def image_examples(samples, ncols, return_key=None, example_text="Examples"):
global img_example_counter
trigger = False
with st.expander(example_text, True):
for i in range(len(samples) // ncols):
cols = st.columns(ncols)
for j in range(ncols):
idx = i * ncols + j
if idx >= len(samples):
continue
entry = samples[idx]
with cols[j]:
st.image(entry['dispi'])
img_example_counter += 1
with st.columns(5)[2]:
this_trigger = st.button('\+', key='imgexuse%d' % img_example_counter)
trigger = trigger or this_trigger
if this_trigger:
trigger = entry[return_key]
return trigger
def segment_img(img: Image):
output = rembg.remove(img)
mask = numpy.array(output)[:, :, 3] > 0
sam_mask = SAMAPI.segment_api(numpy.array(img)[:, :, :3], mask)
segmented_img = Image.new("RGBA", img.size, (0, 0, 0, 0))
segmented_img.paste(img, mask=Image.fromarray(sam_mask))
return segmented_img
def segment_6imgs(zero123pp_imgs):
imgs = [zero123pp_imgs.crop([0, 0, 320, 320]),
zero123pp_imgs.crop([320, 0, 640, 320]),
zero123pp_imgs.crop([0, 320, 320, 640]),
zero123pp_imgs.crop([320, 320, 640, 640]),
zero123pp_imgs.crop([0, 640, 320, 960]),
zero123pp_imgs.crop([320, 640, 640, 960])]
segmented_imgs = []
for i, img in enumerate(imgs):
output = rembg.remove(img)
mask = numpy.array(output)[:, :, 3]
mask = SAMAPI.segment_api(numpy.array(img)[:, :, :3], mask)
data = numpy.array(img)[:,:,:3]
data[mask == 0] = [255, 255, 255]
segmented_imgs.append(data)
result = numpy.concatenate([
numpy.concatenate([segmented_imgs[0], segmented_imgs[1]], axis=1),
numpy.concatenate([segmented_imgs[2], segmented_imgs[3]], axis=1),
numpy.concatenate([segmented_imgs[4], segmented_imgs[5]], axis=1)
])
return Image.fromarray(result)
def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
@st.cache_data
def check_dependencies():
reqs = []
try:
import diffusers
except ImportError:
import traceback
traceback.print_exc()
print("Error: `diffusers` not found.", file=sys.stderr)
reqs.append("diffusers==0.20.2")
else:
if not diffusers.__version__.startswith("0.20"):
print(
f"Warning: You are using an unsupported version of diffusers ({diffusers.__version__}), which may lead to performance issues.",
file=sys.stderr
)
print("Recommended version is `diffusers==0.20.2`.", file=sys.stderr)
try:
import transformers
except ImportError:
import traceback
traceback.print_exc()
print("Error: `transformers` not found.", file=sys.stderr)
reqs.append("transformers==4.29.2")
if torch.__version__ < '2.0':
try:
import xformers
except ImportError:
print("Warning: You are using PyTorch 1.x without a working `xformers` installation.", file=sys.stderr)
print("You may see a significant memory overhead when running the model.", file=sys.stderr)
if len(reqs):
print(f"Info: Fix all dependency errors with `pip install {' '.join(reqs)}`.")
@st.cache_resource
def load_zero123plus_pipeline():
if 'HF_TOKEN' in os.environ:
huggingface_hub.login(os.environ['HF_TOKEN'])
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
pipeline = DiffusionPipeline.from_pretrained(
"sudo-ai/zero123plus-v1.1", custom_pipeline="sudo-ai/zero123plus-pipeline",
torch_dtype=torch.float16
)
# Feel free to tune the scheduler
pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
pipeline.scheduler.config, timestep_spacing='trailing'
)
if torch.cuda.is_available():
pipeline.to('cuda:0')
sys.main_lock = threading.Lock()
return pipeline
================================================
FILE: scripts/gen_mv.py
================================================
import torch
import requests
from PIL import Image
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
import numpy
import os
import sys
import numpy
import torch
import rembg
import threading
import urllib.request
from PIL import Image
import streamlit as st
import huggingface_hub
from app import SAMAPI
def segment_img(img: Image):
output = rembg.remove(img)
mask = numpy.array(output)[:, :, 3] > 0
sam_mask = SAMAPI.segment_api(numpy.array(img)[:, :, :3], mask)
segmented_img = Image.new("RGBA", img.size, (0, 0, 0, 0))
segmented_img.paste(img, mask=Image.fromarray(sam_mask))
return segmented_img
def segment_6imgs(zero123pp_imgs):
imgs = [zero123pp_imgs.crop([0, 0, 320, 320]),
zero123pp_imgs.crop([320, 0, 640, 320]),
zero123pp_imgs.crop([0, 320, 320, 640]),
zero123pp_imgs.crop([320, 320, 640, 640]),
zero123pp_imgs.crop([0, 640, 320, 960]),
zero123pp_imgs.crop([320, 640, 640, 960])]
segmented_imgs = []
import numpy as np
for i, img in enumerate(imgs):
output = rembg.remove(img)
mask = numpy.array(output)[:, :, 3]
mask = SAMAPI.segment_api(numpy.array(img)[:, :, :3], mask)
data = numpy.array(img)[:,:,:3]
data2 = numpy.ones([320,320,4])
data2[:,:,:3] = data
for i in np.arange(data2.shape[0]):
for j in np.arange(data2.shape[1]):
if mask[i,j]==1:
data2[i,j,3]=255
segmented_imgs.append(data2)
#torch.manual_seed(42)
return segmented_imgs
def process_img(path,destination,pipeline, is_first):
# Download an example image.
print('processing:',path)
#cond_whole = Image.open('output.png')
cond = Image.open(path)
# Run the pipeline!
result = pipeline(cond, num_inference_steps=75,is_first = is_first).images[0]
# for general real and synthetic images of general objects
# usually it is enough to have around 28 inference steps
# for images with delicate details like faces (real or anime)
# you may need 75-100 steps for the details to construct
#result.show()
#result.save("./test_png/zero123pp/output.png")
result=segment_6imgs(result)
print('saving:',os.path.join(destination,'0~5.png'),'in',destination)
for i in numpy.arange(6):
Image.fromarray(numpy.uint8(result[i])).save(os.path.join(destination,'{}.png'.format(i)))
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--path", required=True, help="path to process")
# DiffusionPipeline.from_pretrained cannot received relative path for custom pipeline
parser.add_argument("--pipeline_path", required=True, help="path of pipeline code, in ../guidance/zero123pp")
args, extras = parser.parse_known_args()
pipeline = DiffusionPipeline.from_pretrained(
"sudo-ai/zero123plus-v1.1", custom_pipeline=args.pipeline_path,
torch_dtype=torch.float16
)
pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
pipeline.scheduler.config, timestep_spacing='trailing'
)
pipeline.to('cuda:0')
directory = args.path+'/'
os.makedirs(directory+'ref', exist_ok=True)
os.system(f"cp -r {directory+'*.png'} {directory+'ref/'}")
is_first = True
l=sorted(os.listdir(directory+'ref'))
for file in sorted(os.listdir(directory+'ref')):
if file[-4:-1]=='.pn':
filename = os.path.splitext(os.path.basename(file))[0]
destination = os.path.join(directory+'zero123',filename)
os.makedirs(destination, exist_ok=True)
img_path = os.path.join(directory+'ref',file)
process_img(img_path,destination,pipeline, is_first)
is_first = False
================================================
FILE: sh_utils.py
================================================
# Copyright 2021 The PlenOctree Authors.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
import torch
C0 = 0.28209479177387814
C1 = 0.4886025119029199
C2 = [
1.0925484305920792,
-1.0925484305920792,
0.31539156525252005,
-1.0925484305920792,
0.5462742152960396
]
C3 = [
-0.5900435899266435,
2.890611442640554,
-0.4570457994644658,
0.3731763325901154,
-0.4570457994644658,
1.445305721320277,
-0.5900435899266435
]
C4 = [
2.5033429417967046,
-1.7701307697799304,
0.9461746957575601,
-0.6690465435572892,
0.10578554691520431,
-0.6690465435572892,
0.47308734787878004,
-1.7701307697799304,
0.6258357354491761,
]
def eval_sh(deg, sh, dirs):
"""
Evaluate spherical harmonics at unit directions
using hardcoded SH polynomials.
Works with torch/np/jnp.
... Can be 0 or more batch dimensions.
Args:
deg: int SH deg. Currently, 0-3 supported
sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]
dirs: jnp.ndarray unit directions [..., 3]
Returns:
[..., C]
"""
assert deg <= 4 and deg >= 0
coeff = (deg + 1) ** 2
assert sh.shape[-1] >= coeff
result = C0 * sh[..., 0]
if deg > 0:
x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
result = (result -
C1 * y * sh[..., 1] +
C1 * z * sh[..., 2] -
C1 * x * sh[..., 3])
if deg > 1:
xx, yy, zz = x * x, y * y, z * z
xy, yz, xz = x * y, y * z, x * z
result = (result +
C2[0] * xy * sh[..., 4] +
C2[1] * yz * sh[..., 5] +
C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
C2[3] * xz * sh[..., 7] +
C2[4] * (xx - yy) * sh[..., 8])
if deg > 2:
result = (result +
C3[0] * y * (3 * xx - yy) * sh[..., 9] +
C3[1] * xy * z * sh[..., 10] +
C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] +
C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
C3[5] * z * (xx - yy) * sh[..., 14] +
C3[6] * x * (xx - 3 * yy) * sh[..., 15])
if deg > 3:
result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
return result
def RGB2SH(rgb):
return (rgb - 0.5) / C0
def SH2RGB(sh):
return sh * C0 + 0.5
================================================
FILE: simple-knn/ext.cpp
================================================
/*
* Copyright (C) 2023, Inria
* GRAPHDECO research group, https://team.inria.fr/graphdeco
* All rights reserved.
*
* This software is free for non-commercial, research and evaluation use
* under the terms of the LICENSE.md file.
*
* For inquiries contact george.drettakis@inria.fr
*/
#include <torch/extension.h>
#include "spatial.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("distCUDA2", &distCUDA2);
}
================================================
FILE: simple-knn/setup.py
================================================
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact george.drettakis@inria.fr
#
from setuptools import setup
from torch.utils.cpp_extension import CUDAExtension, BuildExtension
import os
cxx_compiler_flags = []
if os.name == 'nt':
cxx_compiler_flags.append("/wd4624")
setup(
name="simple_knn",
ext_modules=[
CUDAExtension(
name="simple_knn._C",
sources=[
"spatial.cu",
"simple_knn.cu",
"ext.cpp"],
extra_compile_args={"nvcc": [], "cxx": cxx_compiler_flags})
],
cmdclass={
'build_ext': BuildExtension
}
)
================================================
FILE: simple-knn/simple_knn/.gitkeep
================================================
================================================
FILE: simple-knn/simple_knn.cu
================================================
/*
* Copyright (C) 2023, Inria
* GRAPHDECO research group, https://team.inria.fr/graphdeco
* All rights reserved.
*
* This software is free for non-commercial, research and evaluation use
* under the terms of the LICENSE.md file.
*
* For inquiries contact george.drettakis@inria.fr
*/
#define BOX_SIZE 1024
#include "cuda_runtime.h"
#include "device_launch_parameters.h"
#include "simple_knn.h"
#include <cub/cub.cuh>
#include <cub/device/device_radix_sort.cuh>
#include <vector>
#include <cuda_runtime_api.h>
#include <thrust/device_vector.h>
#include <thrust/sequence.h>
#define __CUDACC__
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
namespace cg = cooperative_groups;
struct CustomMin
{
__device__ __forceinline__
float3 operator()(const float3& a, const float3& b) const {
return { min(a.x, b.x), min(a.y, b.y), min(a.z, b.z) };
}
};
struct CustomMax
{
__device__ __forceinline__
float3 operator()(const float3& a, const float3& b) const {
return { max(a.x, b.x), max(a.y, b.y), max(a.z, b.z) };
}
};
__host__ __device__ uint32_t prepMorton(uint32_t x)
{
x = (x | (x << 16)) & 0x030000FF;
x = (x | (x << 8)) & 0x0300F00F;
x = (x | (x << 4)) & 0x030C30C3;
x = (x | (x << 2)) & 0x09249249;
return x;
}
__host__ __device__ uint32_t coord2Morton(float3 coord, float3 minn, float3 maxx)
{
uint32_t x = prepMorton(((coord.x - minn.x) / (maxx.x - minn.x)) * ((1 << 10) - 1));
uint32_t y = prepMorton(((coord.y - minn.y) / (maxx.y - minn.y)) * ((1 << 10) - 1));
uint32_t z = prepMorton(((coord.z - minn.z) / (maxx.z - minn.z)) * ((1 << 10) - 1));
return x | (y << 1) | (z << 2);
}
__global__ void coord2Morton(int P, const float3* points, float3 minn, float3 maxx, uint32_t* codes)
{
auto idx = cg::this_grid().thread_rank();
if (idx >= P)
return;
codes[idx] = coord2Morton(points[idx], minn, maxx);
}
struct MinMax
{
float3 minn;
float3 maxx;
};
__global__ void boxMinMax(uint32_t P, float3* points, uint32_t* indices, MinMax* boxes)
{
auto idx = cg::this_grid().thread_rank();
MinMax me;
if (idx < P)
{
me.minn = points[indices[idx]];
me.maxx = points[indices[idx]];
}
else
{
me.minn = { FLT_MAX, FLT_MAX, FLT_MAX };
me.maxx = { -FLT_MAX,-FLT_MAX,-FLT_MAX };
}
__shared__ MinMax redResult[BOX_SIZE];
for (int off = BOX_SIZE / 2; off >= 1; off /= 2)
{
if (threadIdx.x < 2 * off)
redResult[threadIdx.x] = me;
__syncthreads();
if (threadIdx.x < off)
{
MinMax other = redResult[threadIdx.x + off];
me.minn.x = min(me.minn.x, other.minn.x);
me.minn.y = min(me.minn.y, other.minn.y);
me.minn.z = min(me.minn.z, other.minn.z);
me.maxx.x = max(me.maxx.x, other.maxx.x);
me.maxx.y = max(me.maxx.y, o
gitextract_687zu6g9/ ├── .gitignore ├── README.md ├── __init__.py ├── cam_utils.py ├── configs/ │ └── stag4d.yaml ├── dataset_4d.py ├── deform.py ├── gs_renderer_4d.py ├── guidance/ │ ├── zero123_4d_utils.py │ └── zero123pp/ │ └── pipeline.py ├── main.py ├── mini_trainer.ipynb ├── requirements.txt ├── scripts/ │ ├── app.py │ └── gen_mv.py ├── sh_utils.py ├── simple-knn/ │ ├── ext.cpp │ ├── setup.py │ ├── simple_knn/ │ │ └── .gitkeep │ ├── simple_knn.cu │ ├── simple_knn.h │ ├── spatial.cu │ └── spatial.h ├── visualize.py └── zero123.py
SYMBOL INDEX (233 symbols across 14 files)
FILE: cam_utils.py
function dot (line 6) | def dot(x, y):
function length (line 13) | def length(x, eps=1e-20):
function safe_normalize (line 20) | def safe_normalize(x, eps=1e-20):
function look_at (line 24) | def look_at(campos, target, opengl=True):
function orbit_camera (line 45) | def orbit_camera(elevation, azimuth, radius=1, is_degree=True, target=No...
class OrbitCamera (line 65) | class OrbitCamera:
method __init__ (line 66) | def __init__(self, W, H, r=2, fovy=60, near=0.01, far=100):
method fovx (line 78) | def fovx(self):
method campos (line 82) | def campos(self):
method pose (line 87) | def pose(self):
method view (line 101) | def view(self):
method perspective (line 106) | def perspective(self):
method intrinsics (line 126) | def intrinsics(self):
method mvp (line 131) | def mvp(self):
method orbit (line 134) | def orbit(self, dx, dy):
method scale (line 141) | def scale(self, delta):
method pan (line 144) | def pan(self, dx, dy, dz=0):
FILE: dataset_4d.py
class SparseDataset (line 16) | class SparseDataset:
method __init__ (line 17) | def __init__(self, opt, size,device='cuda', type='train', H=256, W=256):
method collate_ref (line 32) | def collate_ref(self,index):
method collate_zero123 (line 55) | def collate_zero123(self,index):
method collate (line 85) | def collate(self, index):
method dataloader (line 97) | def dataloader(self):
method dataloader_d (line 101) | def dataloader_d(self):
FILE: deform.py
class Deformation (line 23) | class Deformation(nn.Module):
method __init__ (line 24) | def __init__(self, D=8, W=256, input_ch=27, input_ch_time=9, skips=[],...
method create_net (line 47) | def create_net(self):
method query_time (line 66) | def query_time(self, rays_pts_emb, scales_emb, rotations_emb, time_emb):
method forward (line 79) | def forward(self, rays_pts_emb, scales_emb=None, rotations_emb=None, o...
method forward_static (line 85) | def forward_static(self, rays_pts_emb):
method forward_dynamic (line 89) | def forward_dynamic(self,rays_pts_emb, scales_emb, rotations_emb, opac...
method get_mlp_parameters (line 112) | def get_mlp_parameters(self):
method get_grid_parameters (line 118) | def get_grid_parameters(self):
class deform_network (line 121) | class deform_network(nn.Module):
method __init__ (line 122) | def __init__(self) :
method forward (line 144) | def forward(self, point, scales=None, rotations=None, opacity=None, ti...
method forward_static (line 151) | def forward_static(self, points):
method forward_dynamic (line 154) | def forward_dynamic(self, point, scales=None, rotations=None, opacity=...
method get_mlp_parameters (line 164) | def get_mlp_parameters(self):
method get_grid_parameters (line 166) | def get_grid_parameters(self):
function initialize_weights (line 169) | def initialize_weights(m):
function get_normalized_directions (line 179) | def get_normalized_directions(directions):
function normalize_aabb (line 188) | def normalize_aabb(pts, aabb):
function grid_sample_wrapper (line 190) | def grid_sample_wrapper(grid: torch.Tensor, coords: torch.Tensor, align_...
function init_grid_param (line 217) | def init_grid_param(
function interpolate_ms_features (line 242) | def interpolate_ms_features(pts: torch.Tensor,
class HexPlaneField (line 278) | class HexPlaneField(nn.Module):
method __init__ (line 279) | def __init__(
method set_aabb (line 320) | def set_aabb(self,xyz_max, xyz_min):
method get_density (line 328) | def get_density(self, pts: torch.Tensor, timestamps: Optional[torch.Te...
method forward (line 345) | def forward(self,
function compute_plane_tv (line 353) | def compute_plane_tv(t):
function compute_plane_smoothness (line 362) | def compute_plane_smoothness(t):
class Regularizer (line 371) | class Regularizer():
method __init__ (line 372) | def __init__(self, reg_type, initialization):
method step (line 378) | def step(self, global_step):
method report (line 381) | def report(self, d):
method regularize (line 385) | def regularize(self, *args, **kwargs) -> torch.Tensor:
method _regularize (line 391) | def _regularize(self, *args, **kwargs) -> torch.Tensor:
method __str__ (line 394) | def __str__(self):
class PlaneTV (line 398) | class PlaneTV(Regularizer):
method __init__ (line 399) | def __init__(self, initial_value, what: str = 'field'):
method step (line 407) | def step(self, global_step):
method _regularize (line 410) | def _regularize(self, model, **kwargs):
class TimeSmoothness (line 433) | class TimeSmoothness(Regularizer):
method __init__ (line 434) | def __init__(self, initial_value, what: str = 'field'):
method _regularize (line 442) | def _regularize(self, model, **kwargs) -> torch.Tensor:
class L1ProposalNetwork (line 463) | class L1ProposalNetwork(Regularizer):
method __init__ (line 464) | def __init__(self, initial_value):
method _regularize (line 467) | def _regularize(self, model, **kwargs) -> torch.Tensor:
class DepthTV (line 476) | class DepthTV(Regularizer):
method __init__ (line 477) | def __init__(self, initial_value):
method _regularize (line 480) | def _regularize(self, model, model_out, **kwargs) -> torch.Tensor:
class L1TimePlanes (line 488) | class L1TimePlanes(Regularizer):
method __init__ (line 489) | def __init__(self, initial_value, what='field'):
method _regularize (line 496) | def _regularize(self, model, **kwargs) -> torch.Tensor:
FILE: gs_renderer_4d.py
function inverse_sigmoid (line 19) | def inverse_sigmoid(x):
function get_expon_lr_func (line 25) | def get_expon_lr_func(
function strip_lowerdiag (line 50) | def strip_lowerdiag(L):
function strip_symmetric (line 61) | def strip_symmetric(sym):
function gaussian_3d_coeff (line 64) | def gaussian_3d_coeff(xyzs, covs):
function build_rotation (line 85) | def build_rotation(r):
function build_scaling_rotation (line 108) | def build_scaling_rotation(s, r):
class BasicPointCloud (line 119) | class BasicPointCloud(NamedTuple):
class GaussianModel (line 125) | class GaussianModel:
method setup_functions (line 127) | def setup_functions(self):
method initialize (line 144) | def initialize(self, initial_values, raw=False):
method __init__ (line 155) | def __init__(self, sh_degree : int,args = None):
method capture (line 174) | def capture(self):
method restore (line 192) | def restore(self, model_args, training_args):
method get_scaling (line 213) | def get_scaling(self):
method get_rotation (line 217) | def get_rotation(self):
method get_xyz (line 221) | def get_xyz(self):
method get_features (line 225) | def get_features(self):
method get_opacity (line 231) | def get_opacity(self):
method get_covariance (line 238) | def get_covariance(self, scaling_modifier = 1):
method oneupSHdegree (line 241) | def oneupSHdegree(self):
method create_from_pcd (line 245) | def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : fl...
method training_setup (line 276) | def training_setup(self, training_args):
method update_learning_rate (line 308) | def update_learning_rate(self, iteration):
method construct_list_of_attributes (line 324) | def construct_list_of_attributes(self):
method save_ply (line 338) | def save_ply(self, path):
method compute_deformation (line 357) | def compute_deformation(self,time):
method load_model (line 362) | def load_model(self, path):
method save_deformation (line 378) | def save_deformation(self, path):
method load_ply (line 383) | def load_ply(self, path):
method replace_tensor_to_optimizer (line 430) | def replace_tensor_to_optimizer(self, tensor, name):
method _prune_optimizer (line 445) | def _prune_optimizer(self, mask):
method prune_points (line 465) | def prune_points(self, mask):
method cat_tensors_to_optimizer (line 481) | def cat_tensors_to_optimizer(self, tensors_dict):
method densification_postfix (line 504) | def densification_postfix(self, new_xyz, new_features_dc, new_features...
method densify_and_split (line 529) | def densify_and_split(self, grads, grad_threshold, scene_extent, N=2):
method densify_and_clone (line 555) | def densify_and_clone(self, grads, grad_threshold, scene_extent):
method densify_and_prune (line 571) | def densify_and_prune(self, max_grad_percent, min_opacity, extent, max...
method prune (line 600) | def prune(self, min_opacity=0.01, extent=1, max_screen_size=1):
method standard_constaint (line 619) | def standard_constaint(self):
method add_densification_stats (line 633) | def add_densification_stats(self, viewspace_point_tensor, update_filter):
method update_deformation_table (line 639) | def update_deformation_table(self,threshold):
method print_deformation_weight_grad (line 642) | def print_deformation_weight_grad(self):
method _plane_regulation (line 652) | def _plane_regulation(self):
method _time_regulation (line 664) | def _time_regulation(self):
method _l1_regulation (line 676) | def _l1_regulation(self):
method compute_regulation (line 690) | def compute_regulation(self, time_smoothness_weight, l1_time_planes_we...
function getProjectionMatrix (line 693) | def getProjectionMatrix(znear, zfar, fovX, fovY):
class MiniCam (line 709) | class MiniCam:
method __init__ (line 710) | def __init__(self, c2w, width, height, fovy, fovx, znear, zfar,time=0 ):
class Renderer (line 738) | class Renderer:
method __init__ (line 739) | def __init__(self, sh_degree=3, white_background=True, radius=1):
method initialize (line 753) | def initialize(self, input=None, num_pts=5000, radius=0.5,initial_valu...
method render (line 785) | def render(
FILE: guidance/zero123_4d_utils.py
class Zero123 (line 21) | class Zero123(nn.Module):
method __init__ (line 22) | def __init__(self, device, fp16=True, t_range=[0.02, 0.98]):
method get_img_embeds (line 62) | def get_img_embeds(self, x, input_imgs=None):
method refine (line 84) | def refine(self, pred_rgb, polar, azimuth, radius,
method train_step (line 129) | def train_step(self, pred_rgb, polar, azimuth, radius, step_ratio=None...
method update_step (line 186) | def update_step(self, epoch: int, global_step: int, on_load_weights: b...
method get_steps (line 193) | def get_steps(self,percent,epoch, global_step):
method decode_latents (line 203) | def decode_latents(self, latents):
method encode_imgs (line 211) | def encode_imgs(self, imgs, mode=False):
FILE: guidance/zero123pp/pipeline.py
function to_rgb_image (line 37) | def to_rgb_image(maybe_rgba: Image.Image):
class MyAttnProcessor2_0 (line 49) | class MyAttnProcessor2_0:
method __init__ (line 54) | def __init__(self):
method __call__ (line 58) | def __call__(
class ReferenceOnlyAttnProc (line 152) | class ReferenceOnlyAttnProc(torch.nn.Module):
method __init__ (line 153) | def __init__(
method __call__ (line 164) | def __call__(
class RefOnlyNoisedUNet (line 191) | class RefOnlyNoisedUNet(torch.nn.Module):
method __init__ (line 192) | def __init__(self, unet: UNet2DConditionModel, train_sched: DDPMSchedu...
method __getattr__ (line 212) | def __getattr__(self, name: str):
method forward_cond (line 218) | def forward_cond(self, noisy_cond_lat, timestep, encoder_hidden_states...
method forward (line 230) | def forward(
function scale_latents (line 275) | def scale_latents(latents):
function unscale_latents (line 280) | def unscale_latents(latents):
function scale_image (line 285) | def scale_image(image):
function unscale_image (line 290) | def unscale_image(image):
class DepthControlUNet (line 295) | class DepthControlUNet(torch.nn.Module):
method __init__ (line 296) | def __init__(self, unet: RefOnlyNoisedUNet, controlnet: Optional[diffu...
method __getattr__ (line 309) | def __getattr__(self, name: str):
method forward (line 315) | def forward(self, sample, timestep, encoder_hidden_states, class_label...
class ModuleListDict (line 336) | class ModuleListDict(torch.nn.Module):
method __init__ (line 337) | def __init__(self, procs: dict) -> None:
method __getitem__ (line 342) | def __getitem__(self, key):
class SuperNet (line 346) | class SuperNet(torch.nn.Module):
method __init__ (line 347) | def __init__(self, state_dict: Dict[str, torch.Tensor]):
class Zero123PlusPipeline (line 386) | class Zero123PlusPipeline(diffusers.StableDiffusionPipeline):
method __init__ (line 405) | def __init__(
method prepare (line 431) | def prepare(self):
method add_controlnet (line 436) | def add_controlnet(self, controlnet: Optional[diffusers.ControlNetMode...
method encode_condition_image (line 441) | def encode_condition_image(self, image: torch.Tensor):
method __call__ (line 446) | def __call__(
FILE: main.py
function save_image_to_local (line 19) | def save_image_to_local(image_tensor, file_path):
class GUI (line 26) | class GUI:
method __init__ (line 27) | def __init__(self, opt):
method __del__ (line 110) | def __del__(self):
method seed_everything (line 114) | def seed_everything(self):
method prepare_image (line 129) | def prepare_image(self,idx):
method prepare_train (line 160) | def prepare_train(self):
method train_step (line 216) | def train_step(self):
method set_fix_cam (line 414) | def set_fix_cam(self):
method test_step (line 497) | def test_step(self):
method load_input (line 567) | def load_input(self, file):
method save_renderings (line 573) | def save_renderings(self, elev=0, azim=0, radius=2, name='front'):
method save_model (line 606) | def save_model(self, mode='geo', texture_size=1024):
method register_dpg (line 622) | def register_dpg(self):
method render (line 955) | def render(self):
method train (line 965) | def train(self, iters=500):
FILE: scripts/app.py
class SAMAPI (line 12) | class SAMAPI:
method get_instance (line 17) | def get_instance(sam_checkpoint=None):
method segment_api (line 40) | def segment_api(rgb, mask=None, bbox=None, sam_checkpoint=None):
function image_examples (line 74) | def image_examples(samples, ncols, return_key=None, example_text="Exampl...
function segment_img (line 96) | def segment_img(img: Image):
function segment_6imgs (line 105) | def segment_6imgs(zero123pp_imgs):
function expand2square (line 128) | def expand2square(pil_img, background_color):
function check_dependencies (line 143) | def check_dependencies():
function load_zero123plus_pipeline (line 177) | def load_zero123plus_pipeline():
FILE: scripts/gen_mv.py
function segment_img (line 17) | def segment_img(img: Image):
function segment_6imgs (line 26) | def segment_6imgs(zero123pp_imgs):
function process_img (line 51) | def process_img(path,destination,pipeline, is_first):
FILE: sh_utils.py
function eval_sh (line 57) | def eval_sh(deg, sh, dirs):
function RGB2SH (line 114) | def RGB2SH(rgb):
function SH2RGB (line 117) | def SH2RGB(sh):
FILE: simple-knn/ext.cpp
function PYBIND11_MODULE (line 15) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
FILE: simple-knn/simple_knn.h
function class (line 15) | class SimpleKNN
FILE: visualize.py
function save_image_to_local (line 22) | def save_image_to_local(image_tensor, file_path):
class GUI (line 28) | class GUI:
method __init__ (line 29) | def __init__(self, opt):
method __del__ (line 95) | def __del__(self):
method seed_everything (line 99) | def seed_everything(self):
method prepare_train (line 120) | def prepare_train(self):
method train_step (line 164) | def train_step(self):
method save_renderings (line 206) | def save_renderings(self, elev=0, azim=0, radius=2, name='front', inte...
method set_fix_cam2 (line 270) | def set_fix_cam2(self):
method test_step (line 353) | def test_step(self):
method load_input (line 423) | def load_input(self, file):
method save_model (line 479) | def save_model(self, mode='geo', texture_size=1024):
method register_dpg (line 495) | def register_dpg(self):
method render (line 828) | def render(self):
method train (line 838) | def train(self, iters=500):
FILE: zero123.py
class CLIPCameraProjection (line 41) | class CLIPCameraProjection(ModelMixin, ConfigMixin):
method __init__ (line 53) | def __init__(self, embedding_dim: int = 768, additional_embeddings: in...
method forward (line 63) | def forward(
class Zero123Pipeline (line 81) | class Zero123Pipeline(DiffusionPipeline):
method __init__ (line 109) | def __init__(
method enable_sequential_cpu_offload (line 180) | def enable_sequential_cpu_offload(self, gpu_id=0):
method _execution_device (line 204) | def _execution_device(self):
method _encode_image (line 221) | def _encode_image(
method run_safety_checker (line 299) | def run_safety_checker(self, image, device, dtype):
method decode_latents (line 318) | def decode_latents(self, latents):
method prepare_extra_step_kwargs (line 332) | def prepare_extra_step_kwargs(self, generator, eta):
method check_inputs (line 353) | def check_inputs(self, image, height, width, callback_steps):
method prepare_latents (line 371) | def prepare_latents(
method _get_latent_model_input (line 405) | def _get_latent_model_input(
method __call__ (line 449) | def __call__(
Condensed preview — 25 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (275K chars).
[
{
"path": ".gitignore",
"chars": 15,
"preview": "*.pyc\n./valid/*"
},
{
"path": "README.md",
"chars": 5135,
"preview": "<h1>[ECCV 2024] STAG4D: Spatial-Temporal Anchored Generative 4D Gaussians</h1>\n\n<div>\n <a href='https://github.com/ze"
},
{
"path": "__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "cam_utils.py",
"chars": 4804,
"preview": "import numpy as np\nfrom scipy.spatial.transform import Rotation as R\n\nimport torch\n\ndef dot(x, y):\n if isinstance(x, "
},
{
"path": "configs/stag4d.yaml",
"chars": 2166,
"preview": "### Input\n# input rgba image path (default to None, can be load in GUI too)\ninput: \n# input text prompt (default to None"
},
{
"path": "dataset_4d.py",
"chars": 3770,
"preview": "import os\nimport cv2\nimport glob\nimport json\nimport tqdm\nimport random\nimport numpy as np\nfrom scipy.spatial.transform i"
},
{
"path": "deform.py",
"chars": 19304,
"preview": "\nimport functools\nimport math\nimport os\nimport time\nfrom tkinter import W\n\nimport numpy as np\nimport torch\nimport torch."
},
{
"path": "gs_renderer_4d.py",
"chars": 41064,
"preview": "import os\nimport math\nimport numpy as np\nfrom typing import NamedTuple\nfrom plyfile import PlyData, PlyElement\n\nimport t"
},
{
"path": "guidance/zero123_4d_utils.py",
"chars": 10453,
"preview": "from transformers import CLIPTextModel, CLIPTokenizer, logging\nfrom diffusers import (\n AutoencoderKL,\n UNet2DCond"
},
{
"path": "guidance/zero123pp/pipeline.py",
"chars": 20621,
"preview": "from typing import Any, Dict, Optional\nfrom diffusers.models import AutoencoderKL, UNet2DConditionModel\nfrom diffusers.s"
},
{
"path": "main.py",
"chars": 37702,
"preview": "import os\nimport cv2\nimport time\nimport tqdm\nimport numpy as np\nimport dearpygui.dearpygui as dpg\n\nimport torch\nimport t"
},
{
"path": "mini_trainer.ipynb",
"chars": 33086,
"preview": "{\n \"cells\": [\n {\n \"cell_type\": \"code\",\n \"execution_count\": 1,\n \"metadata\": {},\n \"outputs\": [],\n \"source\": [\n "
},
{
"path": "requirements.txt",
"chars": 378,
"preview": "tqdm\nrich\nninja\nnumpy\npandas\nscipy\nscikit-learn\nmatplotlib\nopencv-python\nimageio\nimageio-ffmpeg\nomegaconf\nargparse\ntorch"
},
{
"path": "scripts/app.py",
"chars": 6720,
"preview": "import os\nimport sys\nimport numpy\nimport torch\nimport rembg\nimport threading\nimport urllib.request\nfrom PIL import Image"
},
{
"path": "scripts/gen_mv.py",
"chars": 3966,
"preview": "import torch\nimport requests\nfrom PIL import Image\nfrom diffusers import DiffusionPipeline, EulerAncestralDiscreteSchedu"
},
{
"path": "sh_utils.py",
"chars": 4371,
"preview": "# Copyright 2021 The PlenOctree Authors.\n# Redistribution and use in source and binary forms, with or without\n# modif"
},
{
"path": "simple-knn/ext.cpp",
"chars": 427,
"preview": "/*\n * Copyright (C) 2023, Inria\n * GRAPHDECO research group, https://team.inria.fr/graphdeco\n * All rights reserved.\n *\n"
},
{
"path": "simple-knn/setup.py",
"chars": 830,
"preview": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# Thi"
},
{
"path": "simple-knn/simple_knn/.gitkeep",
"chars": 0,
"preview": ""
},
{
"path": "simple-knn/simple_knn.cu",
"chars": 6352,
"preview": "/*\n * Copyright (C) 2023, Inria\n * GRAPHDECO research group, https://team.inria.fr/graphdeco\n * All rights reserved.\n *\n"
},
{
"path": "simple-knn/simple_knn.h",
"chars": 451,
"preview": "/*\n * Copyright (C) 2023, Inria\n * GRAPHDECO research group, https://team.inria.fr/graphdeco\n * All rights reserved.\n *\n"
},
{
"path": "simple-knn/spatial.cu",
"chars": 671,
"preview": "/*\n * Copyright (C) 2023, Inria\n * GRAPHDECO research group, https://team.inria.fr/graphdeco\n * All rights reserved.\n *\n"
},
{
"path": "simple-knn/spatial.h",
"chars": 380,
"preview": "/*\n * Copyright (C) 2023, Inria\n * GRAPHDECO research group, https://team.inria.fr/graphdeco\n * All rights reserved.\n *\n"
},
{
"path": "visualize.py",
"chars": 29874,
"preview": "import os\nimport cv2\nimport time\nimport tqdm\nimport numpy as np\nimport dearpygui.dearpygui as dpg\n\nimport torch\nimport t"
},
{
"path": "zero123.py",
"chars": 30625,
"preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
}
]
About this extraction
This page contains the full source code of the zeng-yifei/STAG4D GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 25 files (257.0 KB), approximately 64.7k tokens, and a symbol index with 233 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.