Repository: zeng-yifei/STAG4D Branch: main Commit: 9aa21c92f40b Files: 25 Total size: 257.0 KB Directory structure: gitextract_687zu6g9/ ├── .gitignore ├── README.md ├── __init__.py ├── cam_utils.py ├── configs/ │ └── stag4d.yaml ├── dataset_4d.py ├── deform.py ├── gs_renderer_4d.py ├── guidance/ │ ├── zero123_4d_utils.py │ └── zero123pp/ │ └── pipeline.py ├── main.py ├── mini_trainer.ipynb ├── requirements.txt ├── scripts/ │ ├── app.py │ └── gen_mv.py ├── sh_utils.py ├── simple-knn/ │ ├── ext.cpp │ ├── setup.py │ ├── simple_knn/ │ │ └── .gitkeep │ ├── simple_knn.cu │ ├── simple_knn.h │ ├── spatial.cu │ └── spatial.h ├── visualize.py └── zero123.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .gitignore ================================================ *.pyc ./valid/* ================================================ FILE: README.md ================================================

[ECCV 2024] STAG4D: Spatial-Temporal Anchored Generative 4D Gaussians

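Yifei Zeng<sup>1</sup>, Yanqin Jiang<sup>*2</sup>, Siyu Zhu<sup>3</sup>, Yuanxun Lu<sup>1</sup>, Youtian Lin<sup>1</sup>, Hao Zhu<sup>1</sup>, Weiming Hu<sup>2</sup>, Xun Cao<sup>1</sup>, Yao Yao<sup>1+</sup>

<sup>1</sup>Nanjing University, <sup>2</sup>CASIA, <sup>3</sup>Fudan University

<sup>*</sup>Equal contribution, <sup>+</sup>Corresponding author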

[Project Page]

# Update
7.4: Our paper has been accepted by ECCV 2024. Congrats!

6.20: IMPORTANT. Fixed a bug caused by the new version of diff_gauss. The newest version of diff_gauss returns `color, depth, norm, alpha, radii, extra` as outputs, whereas previous versions returned `color, depth, alpha, radii`. Using an older version of this code with the new diff_gauss causes an output-mismatch error and may use the normal map in place of alpha for the alpha loss, resulting in bad results.

5.26: Uploaded the Text/Image-to-4D data (see below).

5.21: Moved the RGB loss into the batch loop. Added visualization code.

# ⚙️ Installation
```bash
pip install -r requirements.txt
git clone --recursive https://github.com/slothfulxtx/diff-gaussian-rasterization.git
pip install ./diff-gaussian-rasterization
pip install ./simple-knn
```

# Video-to-4D
To generate the examples on the project page, you can download the dataset from [google drive](https://drive.google.com/file/d/1YDvhBv6z5SByF_WaTQVzzL9qz3TyEm6a/view?usp=sharing). Place them in the dataset folder, and run:
```bash
python main.py --config configs/stag4d.yaml path=dataset/minions save_path=minions
# use gui=True to turn on the visualizer (recommended)
python main.py --config configs/stag4d.yaml path=dataset/minions save_path=minions gui=True
```
To generate the spatial-temporally consistent data from scratch, place your RGBA data in the following layout:
```
├── dataset
│ | your_data
│ ├── 0_rgba.png
│ ├── 1_rgba.png
│ ├── 2_rgba.png
│ ├── ...
```
and then run:
```bash
python scripts/gen_mv.py --path dataset/your_data --pipeline_path xxx/guidance/zero123pp
python main.py --config configs/stag4d.yaml path=data_path save_path=saving_path gui=True
```
To visualize the result, replace main.py with visualize.py; the result will be saved to the valid/xxx path, e.g.:
```bash
python visualize.py --config configs/stag4d.yaml path=dataset/minions save_path=minions
```

# Text-to-4D
For text-to-4D generation, we recommend using SDXL and SVD to generate a reasonable video.
Then, after matting the video, use the commands above to generate a good 4D result. (This pipeline consists of many independent parts and is fairly complex, so we may upload the whole workflow after integrating it, if possible.) If you want to generate the examples in the paper, I have also uploaded the corresponding data to [google drive](https://drive.google.com/file/d/1EDNL7EBMR1vlfMOABdXjHzcKY7IXdcnj/view?usp=sharing). Remember to set `size` to 26 in the config, or pass `size=26` on the command line:
```bash
python main.py --config configs/stag4d.yaml path=dataset/xxx save_path=xxx size=26
```

# Tips for better quality
If you are willing to trade time for better quality, here are some tips that can further improve the generated results:
1. Use a larger batch size (`batch_size` in the config).
2. Run for more steps.

## Citation
If you find our work useful for your research, please consider citing our paper as well as Consistent4D:
```
@article{zeng2024stag4d,
  title={STAG4D: Spatial-Temporal Anchored Generative 4D Gaussians},
  author={Yifei Zeng and Yanqin Jiang and Siyu Zhu and Yuanxun Lu and Youtian Lin and Hao Zhu and Weiming Hu and Xun Cao and Yao Yao},
  year={2024},
  eprint={2403.14939},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}

@article{jiang2023consistent4d,
  title={Consistent4D: Consistent 360{\deg} Dynamic Object Generation from Monocular Video},
  author={Yanqin Jiang and Li Zhang and Jin Gao and Weiming Hu and Yao Yao},
  year={2023},
  eprint={2311.02848},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}
```

# Acknowledgment
This repo is built on [DreamGaussian](https://github.com/dreamgaussian/dreamgaussian) and [Zero123plus](https://github.com/SUDO-AI-3D/zero123plus). Thanks to all the authors for their great work.
================================================ FILE: __init__.py ================================================ ================================================ FILE: cam_utils.py ================================================ import numpy as np from scipy.spatial.transform import Rotation as R import torch def dot(x, y): if isinstance(x, np.ndarray): return np.sum(x * y, -1, keepdims=True) else: return torch.sum(x * y, -1, keepdim=True) def length(x, eps=1e-20): if isinstance(x, np.ndarray): return np.sqrt(np.maximum(np.sum(x * x, axis=-1, keepdims=True), eps)) else: return torch.sqrt(torch.clamp(dot(x, x), min=eps)) def safe_normalize(x, eps=1e-20): return x / length(x, eps) def look_at(campos, target, opengl=True): # campos: [N, 3], camera/eye position # target: [N, 3], object to look at # return: [N, 3, 3], rotation matrix if not opengl: # camera forward aligns with -z forward_vector = safe_normalize(target - campos) up_vector = np.array([0, 1, 0], dtype=np.float32) right_vector = safe_normalize(np.cross(forward_vector, up_vector)) up_vector = safe_normalize(np.cross(right_vector, forward_vector)) else: # camera forward aligns with +z forward_vector = safe_normalize(campos - target) up_vector = np.array([0, 1, 0], dtype=np.float32) right_vector = safe_normalize(np.cross(up_vector, forward_vector)) up_vector = safe_normalize(np.cross(forward_vector, right_vector)) R = np.stack([right_vector, up_vector, forward_vector], axis=1) return R # elevation & azimuth to pose (cam2world) matrix def orbit_camera(elevation, azimuth, radius=1, is_degree=True, target=None, opengl=True): # radius: scalar # elevation: scalar, in (-90, 90), from +y to -y is (-90, 90) # azimuth: scalar, in (-180, 180), from +z to +x is (0, 90) # return: [4, 4], camera pose matrix if is_degree: elevation = np.deg2rad(np.array(elevation)) azimuth = np.deg2rad(np.array(azimuth)) x = radius * np.cos(elevation) * np.sin(azimuth) y = - radius * np.sin(elevation) z = radius * np.cos(elevation) * 
np.cos(azimuth) if target is None: target = np.zeros([3], dtype=np.float32) campos = np.array([x, y, z]) + target # [3] T = np.eye(4, dtype=np.float32) T[:3, :3] = look_at(campos, target, opengl) T[:3, 3] = campos return T class OrbitCamera: def __init__(self, W, H, r=2, fovy=60, near=0.01, far=100): self.W = W self.H = H self.radius = r # camera distance from center self.fovy = np.deg2rad(fovy) # deg 2 rad self.near = near self.far = far self.center = np.array([0, 0, 0], dtype=np.float32) # look at this point self.rot = R.from_matrix(np.eye(3)) self.up = np.array([0, 1, 0], dtype=np.float32) # need to be normalized! @property def fovx(self): return 2 * np.arctan(np.tan(self.fovy / 2) * self.W / self.H) @property def campos(self): return self.pose[:3, 3] # pose (c2w) @property def pose(self): # first move camera to radius res = np.eye(4, dtype=np.float32) res[2, 3] = self.radius # opengl convention... # rotate rot = np.eye(4, dtype=np.float32) rot[:3, :3] = self.rot.as_matrix() res = rot @ res # translate res[:3, 3] -= self.center return res # view (w2c) @property def view(self): return np.linalg.inv(self.pose) # projection (perspective) @property def perspective(self): y = np.tan(self.fovy / 2) aspect = self.W / self.H return np.array( [ [1 / (y * aspect), 0, 0, 0], [0, -1 / y, 0, 0], [ 0, 0, -(self.far + self.near) / (self.far - self.near), -(2 * self.far * self.near) / (self.far - self.near), ], [0, 0, -1, 0], ], dtype=np.float32, ) # intrinsics @property def intrinsics(self): focal = self.H / (2 * np.tan(self.fovy / 2)) return np.array([focal, focal, self.W // 2, self.H // 2], dtype=np.float32) @property def mvp(self): return self.perspective @ np.linalg.inv(self.pose) # [4, 4] def orbit(self, dx, dy): # rotate along camera up/side axis! 
side = self.rot.as_matrix()[:3, 0] rotvec_x = self.up * np.radians(-0.05 * dx) rotvec_y = side * np.radians(-0.05 * dy) self.rot = R.from_rotvec(rotvec_x) * R.from_rotvec(rotvec_y) * self.rot def scale(self, delta): self.radius *= 1.1 ** (-delta) def pan(self, dx, dy, dz=0): # pan in camera coordinate system (careful on the sensitivity!) self.center += 0.0005 * self.rot.as_matrix()[:3, :3] @ np.array([-dx, -dy, dz]) ================================================ FILE: configs/stag4d.yaml ================================================ ### Input # input rgba image path (default to None, can be load in GUI too) input: # input text prompt (default to None, can be input in GUI too) prompt: a minion # input mesh for stage 2 (auto-search from stage 1 output path if None) mesh: # estimated elevation angle for input image elevation: 0 # reference image resolution ref_size: 512 # density thresh for mesh extraction density_thresh: 1 device: cuda #dynamic size: 30 path: dataset/minions # checkpoint to load for stage 1 (should be a ply file) load: ### Output outdir: logs mesh_format: obj save_path: ??? 
save_step: 8000 #checkpoint to load for stage fine (should be a path of ply with deform pth) load_path: load_step: valid_interval: 500 ### Training # guidance loss weights (0 to disable) lambda_sd: 0 mvdream: False lambda_zero123: 1 lambda_tv: 1 scale_loss_ratio: 7.5 imagedream: False # training batch size per iter batch_size: 4 # training iterations for stage 1 iters: 2000 # training iterations for stage 2 iters_refine: 50 # training camera radius radius: 2 # training camera fovy fovy: 49.1 # align with zero123 rendering setting (ref: https://github.com/cvlab-columbia/zero123/blob/main/objaverse-rendering/scripts/blender_script.py#L61 # whether allow geom training in stage 2 train_geo: False # prob to invert background color during training (0 = always black, 1 = always white) invert_bg_prob: 0.5 ### GUI gui: False force_cuda_rast: False # GUI resolution H: 800 W: 800 deformation_lr_init : 0.00016 deformation_lr_final : 0.000016 deformation_lr_delay_mult : 0.02 grid_lr_init : 0.0016 grid_lr_final : 0.00016 ### Gaussian splatting num_pts: 10000 sh_degree: 0 position_lr_init : 0.0002 position_lr_final : 0.000002 position_lr_delay_mult: 0.01 position_lr_max_steps: 2000 position_lr_max_steps2: 5000 feature_lr: 0.005 opacity_lr: 0.02 scaling_lr: 0.01 rotation_lr: 0.002 init_steps: 700 percent_dense: 0.1 density_start_iter: 1200 density_end_iter: 6000 densification_interval: 100 opacity_reset_interval: 700 densify_grad_threshold_percent: 0.025 time_smoothness_weight: 5 plane_tv_weight: 0.05 l1_time_planes: 0.05 ### Textured Mesh geom_lr: 0.0001 texture_lr: 0.2 ================================================ FILE: dataset_4d.py ================================================ import os import cv2 import glob import json import tqdm import random import numpy as np from scipy.spatial.transform import Slerp, Rotation import torch import torch.nn.functional as F from torch.utils.data import DataLoader import rembg import glob class SparseDataset: def __init__(self, opt, 
size,device='cuda', type='train', H=256, W=256): super().__init__() self.opt = opt self.device = device self.type = type # train, val, test self.size = size self.H = H self.W = W self.path = opt.path self.cx = self.H / 2 self.cy = self.W / 2 self.bg_remover=None def collate_ref(self,index): #print(index,str(index)) file = os.path.join(self.path,'ref/{}_rgba.png'.format(str(index))) #print(f'[INFO] load image from {file}...') img = cv2.imread(file, cv2.IMREAD_UNCHANGED) if img.shape[-1] == 3: if self.bg_remover is None: self.bg_remover = rembg.new_session() img = rembg.remove(img, session=self.bg_remover) img = cv2.resize(img, (self.W, self.H), interpolation=cv2.INTER_AREA) img = img.astype(np.float32) / 255.0 self.input_mask = img[..., 3:] # white bg self.input_img = img[..., :3] * self.input_mask + (1 - self.input_mask) # bgr to rgb self.input_img = self.input_img[..., ::-1].copy() return self.input_img ,self.input_mask def collate_zero123(self,index): self.pattern=os.path.join(self.path,'zero123/{}_rgba/*.png'.format(str(index))) self.input_imgs=[] self.input_masks=[] file_list = glob.glob(self.pattern) #print(self.pattern,file_list) for files in sorted(file_list): #print(f'[INFO] load image from {files}...') img = cv2.imread(files, cv2.IMREAD_UNCHANGED) if img.shape[-1] == 3: if self.bg_remover is None: self.bg_remover = rembg.new_session() img = rembg.remove(img, session=self.bg_remover) img = cv2.resize(img, (self.W, self.H), interpolation=cv2.INTER_AREA) img = img.astype(np.float32) / 255.0 self.input_mask = img[..., 3:] # white bg self.input_img = img[..., :3] * self.input_mask + (1 - self.input_mask) # bgr to rgb self.input_img = self.input_img[..., ::-1].copy() self.input_imgs.append(self.input_img) self.input_masks.append(self.input_mask) return self.input_imgs, self.input_masks def collate(self, index): ref_view_batch,input_mask_batch,zero123_view_batch,zero123_masks_batch = [],[],[],[] for index in np.arange(self.size): ref_view,input_mask = 
self.collate_ref(index) zero123_view,zero123_masks = self.collate_zero123(index) ref_view_batch.append(ref_view) input_mask_batch.append(input_mask) zero123_view_batch.append(zero123_view) zero123_masks_batch.append(zero123_masks) return ref_view_batch, input_mask_batch,zero123_view_batch,zero123_masks_batch def dataloader(self): loader = DataLoader(list(range(self.size)), batch_size=1, collate_fn=self.collate,shuffle=False, num_workers=0) return loader def dataloader_d(self): loader = DataLoader(list(range(self.size)), batch_size=1, collate_fn=self.collate_d,shuffle=False, num_workers=0) return loader ================================================ FILE: deform.py ================================================ import functools import math import os import time from tkinter import W import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.cpp_extension import load import torch.nn.init as init import abc import itertools import logging as log from typing import Optional, Union, List, Dict, Sequence, Iterable, Collection, Callable import torch import torch.nn as nn import torch.nn.functional as F class Deformation(nn.Module): def __init__(self, D=8, W=256, input_ch=27, input_ch_time=9, skips=[], args=None): super(Deformation, self).__init__() self.D = D self.W = W self.input_ch = input_ch self.input_ch_time = input_ch_time self.skips = skips self.no_grid=False self.no_ds=False self.no_dr=False self.no_do=True self.bounds = 1.6 self.kplanes_config = { 'grid_dimensions': 2, 'input_coordinate_dim': 4, 'output_coordinate_dim': 32, 'resolution': [64, 64, 64, 25] } self.multires = [1, 2, 4, 8] self.no_grid = self.no_grid self.grid = HexPlaneField(self.bounds, self.kplanes_config, self.multires) self.pos_deform, self.scales_deform, self.rotations_deform, self.opacity_deform = self.create_net() def create_net(self): mlp_out_dim = 0 if self.no_grid: self.feature_out = [nn.Linear(4,self.W)] else: self.feature_out = 
[nn.Linear(mlp_out_dim + self.grid.feat_dim ,self.W)] for i in range(self.D-1): self.feature_out.append(nn.ReLU()) self.feature_out.append(nn.Linear(self.W,self.W)) self.feature_out = nn.Sequential(*self.feature_out) output_dim = self.W return \ nn.Sequential(nn.ReLU(),nn.Linear(self.W,self.W),nn.ReLU(),nn.Linear(self.W, 3)),\ nn.Sequential(nn.ReLU(),nn.Linear(self.W,self.W),nn.ReLU(),nn.Linear(self.W, 3)),\ nn.Sequential(nn.ReLU(),nn.Linear(self.W,self.W),nn.ReLU(),nn.Linear(self.W, 4)), \ nn.Sequential(nn.ReLU(),nn.Linear(self.W,self.W),nn.ReLU(),nn.Linear(self.W, 1)) def query_time(self, rays_pts_emb, scales_emb, rotations_emb, time_emb): if self.no_grid: h = torch.cat([rays_pts_emb[:,:3],time_emb[:,:1]],-1) else: grid_feature = self.grid(rays_pts_emb[:,:3], time_emb[:,:1]) h = grid_feature h = self.feature_out(h) return h def forward(self, rays_pts_emb, scales_emb=None, rotations_emb=None, opacity = None, time_emb=None): if time_emb is None: return self.forward_static(rays_pts_emb[:,:3]) else: return self.forward_dynamic(rays_pts_emb, scales_emb, rotations_emb, opacity, time_emb) def forward_static(self, rays_pts_emb): grid_feature = self.grid(rays_pts_emb[:,:3]) dx = self.static_mlp(grid_feature) return rays_pts_emb[:, :3] + dx def forward_dynamic(self,rays_pts_emb, scales_emb, rotations_emb, opacity_emb, time_emb): hidden = self.query_time(rays_pts_emb, scales_emb, rotations_emb, time_emb).float() dx = self.pos_deform(hidden) pts = rays_pts_emb[:, :3] + dx if self.no_ds: scales = scales_emb[:,:3] else: ds = self.scales_deform(hidden) scales = scales_emb[:,:3] + ds if self.no_dr: rotations = rotations_emb[:,:4] else: dr = self.rotations_deform(hidden) rotations = rotations_emb[:,:4] + dr if self.no_do: opacity = opacity_emb[:,:1] else: do = self.opacity_deform(hidden) opacity = opacity_emb[:,:1] + do # + do # print("deformation value:","pts:",torch.abs(dx).mean(),"rotation:",torch.abs(dr).mean()) return pts, scales, rotations, opacity def 
get_mlp_parameters(self): parameter_list = [] for name, param in self.named_parameters(): if "grid" not in name: parameter_list.append(param) return parameter_list def get_grid_parameters(self): return list(self.grid.parameters() ) # + list(self.timegrid.parameters()) class deform_network(nn.Module): def __init__(self) : super(deform_network, self).__init__() net_width = 64 timebase_pe = 4 defor_depth= 1 posbase_pe= 10 scale_rotation_pe = 2 opacity_pe = 2 timenet_width = 64 timenet_output = 32 times_ch = 2*timebase_pe+1 self.timenet = nn.Sequential( nn.Linear(times_ch, timenet_width), nn.ReLU(), nn.Linear(timenet_width, timenet_output)) self.deformation_net = Deformation(W=net_width, D=defor_depth, input_ch=(4+3)+((4+3)*scale_rotation_pe)*2, input_ch_time=timenet_output, args=None) self.register_buffer('time_poc', torch.FloatTensor([(2**i) for i in range(timebase_pe)])) self.register_buffer('pos_poc', torch.FloatTensor([(2**i) for i in range(posbase_pe)])) self.register_buffer('rotation_scaling_poc', torch.FloatTensor([(2**i) for i in range(scale_rotation_pe)])) self.register_buffer('opacity_poc', torch.FloatTensor([(2**i) for i in range(opacity_pe)])) self.apply(initialize_weights) # print(self) def forward(self, point, scales=None, rotations=None, opacity=None, times_sel=None): if times_sel is not None: return self.forward_dynamic(point, scales, rotations, opacity, times_sel) else: return self.forward_static(point) def forward_static(self, points): points = self.deformation_net(points) return points def forward_dynamic(self, point, scales=None, rotations=None, opacity=None, times_sel=None): # times_emb = poc_fre(times_sel, self.time_poc) means3D, scales, rotations, opacity = self.deformation_net( point, scales, rotations, opacity, # times_feature, times_sel) return means3D, scales, rotations, opacity def get_mlp_parameters(self): return self.deformation_net.get_mlp_parameters() + list(self.timenet.parameters()) def get_grid_parameters(self): return 
self.deformation_net.get_grid_parameters() def initialize_weights(m): if isinstance(m, nn.Linear): # init.constant_(m.weight, 0) init.xavier_uniform_(m.weight,gain=1) if m.bias is not None: init.xavier_uniform_(m.weight,gain=1) # init.constant_(m.bias, 0) def get_normalized_directions(directions): """SH encoding must be in the range [0, 1] Args: directions: batch of directions """ return (directions + 1.0) / 2.0 def normalize_aabb(pts, aabb): return (pts - aabb[0]) * (2.0 / (aabb[1] - aabb[0])) - 1.0 def grid_sample_wrapper(grid: torch.Tensor, coords: torch.Tensor, align_corners: bool = True) -> torch.Tensor: grid_dim = coords.shape[-1] if grid.dim() == grid_dim + 1: # no batch dimension present, need to add it grid = grid.unsqueeze(0) if coords.dim() == 2: coords = coords.unsqueeze(0) if grid_dim == 2 or grid_dim == 3: grid_sampler = F.grid_sample else: raise NotImplementedError(f"Grid-sample was called with {grid_dim}D data but is only " f"implemented for 2 and 3D data.") coords = coords.view([coords.shape[0]] + [1] * (grid_dim - 1) + list(coords.shape[1:])) B, feature_dim = grid.shape[:2] n = coords.shape[-2] interp = grid_sampler( grid, # [B, feature_dim, reso, ...] coords, # [B, 1, ..., n, grid_dim] align_corners=align_corners, mode='bilinear', padding_mode='border') interp = interp.view(B, feature_dim, n).transpose(-1, -2) # [B, n, feature_dim] interp = interp.squeeze() # [B?, n, feature_dim?] 
return interp def init_grid_param( grid_nd: int, in_dim: int, out_dim: int, reso: Sequence[int], a: float = 0.1, b: float = 0.5): assert in_dim == len(reso), "Resolution must have same number of elements as input-dimension" has_time_planes = in_dim == 4 assert grid_nd <= in_dim coo_combs = list(itertools.combinations(range(in_dim), grid_nd)) grid_coefs = nn.ParameterList() for ci, coo_comb in enumerate(coo_combs): new_grid_coef = nn.Parameter(torch.empty( [1, out_dim] + [reso[cc] for cc in coo_comb[::-1]] )) if has_time_planes and 3 in coo_comb: # Initialize time planes to 1 nn.init.ones_(new_grid_coef) else: nn.init.uniform_(new_grid_coef, a=a, b=b) grid_coefs.append(new_grid_coef) return grid_coefs def interpolate_ms_features(pts: torch.Tensor, ms_grids: Collection[Iterable[nn.Module]], grid_dimensions: int, concat_features: bool, num_levels: Optional[int], ) -> torch.Tensor: coo_combs = list(itertools.combinations( range(pts.shape[-1]), grid_dimensions) ) if num_levels is None: num_levels = len(ms_grids) multi_scale_interp = [] if concat_features else 0. grid: nn.ParameterList for scale_id, grid in enumerate(ms_grids[:num_levels]): interp_space = 1. 
for ci, coo_comb in enumerate(coo_combs): # interpolate in plane feature_dim = grid[ci].shape[1] # shape of grid[ci]: 1, out_dim, *reso interp_out_plane = ( grid_sample_wrapper(grid[ci], pts[..., coo_comb]) .view(-1, feature_dim) ) # compute product over planes interp_space = interp_space * interp_out_plane # combine over scales if concat_features: multi_scale_interp.append(interp_space) else: multi_scale_interp = multi_scale_interp + interp_space if concat_features: multi_scale_interp = torch.cat(multi_scale_interp, dim=-1) return multi_scale_interp class HexPlaneField(nn.Module): def __init__( self, bounds, planeconfig, multires ) -> None: super().__init__() aabb = torch.tensor([[bounds,bounds,bounds], [-bounds,-bounds,-bounds]]) self.aabb = nn.Parameter(aabb, requires_grad=False) self.grid_config = [planeconfig] self.multiscale_res_multipliers = multires self.concat_features = True # 1. Init planes self.grids = nn.ModuleList() self.feat_dim = 0 for res in self.multiscale_res_multipliers: # initialize coordinate grid config = self.grid_config[0].copy() # Resolution fix: multi-res only on spatial planes config["resolution"] = [ r * res for r in config["resolution"][:3] ] + config["resolution"][3:] gp = init_grid_param( grid_nd=config["grid_dimensions"], in_dim=config["input_coordinate_dim"], out_dim=config["output_coordinate_dim"], reso=config["resolution"], ) # shape[1] is out-dim - Concatenate over feature len for each scale if self.concat_features: self.feat_dim += gp[-1].shape[1] else: self.feat_dim = gp[-1].shape[1] self.grids.append(gp) # print(f"Initialized model grids: {self.grids}") print("feature_dim:",self.feat_dim) def set_aabb(self,xyz_max, xyz_min): aabb = torch.tensor([ xyz_max, xyz_min ]) self.aabb = nn.Parameter(aabb,requires_grad=True) print("Voxel Plane: set aabb=",self.aabb) def get_density(self, pts: torch.Tensor, timestamps: Optional[torch.Tensor] = None): """Computes and returns the densities.""" pts = normalize_aabb(pts, self.aabb) pts = 
torch.cat((pts, timestamps), dim=-1) # [n_rays, n_samples, 4] pts = pts.reshape(-1, pts.shape[-1]) features = interpolate_ms_features( pts, ms_grids=self.grids, # noqa grid_dimensions=self.grid_config[0]["grid_dimensions"], concat_features=self.concat_features, num_levels=None) if len(features) < 1: features = torch.zeros((0, 1)).to(features.device) return features def forward(self, pts: torch.Tensor, timestamps: Optional[torch.Tensor] = None): features = self.get_density(pts, timestamps) return features def compute_plane_tv(t): batch_size, c, h, w = t.shape count_h = batch_size * c * (h - 1) * w count_w = batch_size * c * h * (w - 1) h_tv = torch.square(t[..., 1:, :] - t[..., :h-1, :]).sum() w_tv = torch.square(t[..., :, 1:] - t[..., :, :w-1]).sum() return 2 * (h_tv / count_h + w_tv / count_w) # This is summing over batch and c instead of avg def compute_plane_smoothness(t): batch_size, c, h, w = t.shape # Convolve with a second derivative filter, in the time dimension which is dimension 2 first_difference = t[..., 1:, :] - t[..., :h-1, :] # [batch, c, h-1, w] second_difference = first_difference[..., 1:, :] - first_difference[..., :h-2, :] # [batch, c, h-2, w] # Take the L2 norm of the result return torch.square(torch.abs(second_difference)).mean() class Regularizer(): def __init__(self, reg_type, initialization): self.reg_type = reg_type self.initialization = initialization self.weight = float(self.initialization) self.last_reg = None def step(self, global_step): pass def report(self, d): if self.last_reg is not None: d[self.reg_type].update(self.last_reg.item()) def regularize(self, *args, **kwargs) -> torch.Tensor: out = self._regularize(*args, **kwargs) * self.weight self.last_reg = out.detach() return out @abc.abstractmethod def _regularize(self, *args, **kwargs) -> torch.Tensor: raise NotImplementedError() def __str__(self): return f"Regularizer({self.reg_type}, weight={self.weight})" class PlaneTV(Regularizer): def __init__(self, initial_value, what: str = 
'field'): if what not in {'field', 'proposal_network'}: raise ValueError(f'what must be one of "field" or "proposal_network" ' f'but {what} was passed.') name = f'planeTV-{what[:2]}' super().__init__(name, initial_value) self.what = what def step(self, global_step): pass def _regularize(self, model, **kwargs): multi_res_grids: Sequence[nn.ParameterList] if self.what == 'field': multi_res_grids = model.field.grids elif self.what == 'proposal_network': multi_res_grids = [p.grids for p in model.proposal_networks] else: raise NotImplementedError(self.what) total = 0 # Note: input to compute_plane_tv should be of shape [batch_size, c, h, w] for grids in multi_res_grids: if len(grids) == 3: spatial_grids = [0, 1, 2] else: spatial_grids = [0, 1, 3] # These are the spatial grids; the others are spatiotemporal for grid_id in spatial_grids: total += compute_plane_tv(grids[grid_id]) for grid in grids: # grid: [1, c, h, w] total += compute_plane_tv(grid) return total class TimeSmoothness(Regularizer): def __init__(self, initial_value, what: str = 'field'): if what not in {'field', 'proposal_network'}: raise ValueError(f'what must be one of "field" or "proposal_network" ' f'but {what} was passed.') name = f'time-smooth-{what[:2]}' super().__init__(name, initial_value) self.what = what def _regularize(self, model, **kwargs) -> torch.Tensor: multi_res_grids: Sequence[nn.ParameterList] if self.what == 'field': multi_res_grids = model.field.grids elif self.what == 'proposal_network': multi_res_grids = [p.grids for p in model.proposal_networks] else: raise NotImplementedError(self.what) total = 0 # model.grids is 6 x [1, rank * F_dim, reso, reso] for grids in multi_res_grids: if len(grids) == 3: time_grids = [] else: time_grids = [2, 4, 5] for grid_id in time_grids: total += compute_plane_smoothness(grids[grid_id]) return torch.as_tensor(total) class L1ProposalNetwork(Regularizer): def __init__(self, initial_value): super().__init__('l1-proposal-network', initial_value) def 
_regularize(self, model, **kwargs) -> torch.Tensor: grids = [p.grids for p in model.proposal_networks] total = 0.0 for pn_grids in grids: for grid in pn_grids: total += torch.abs(grid).mean() return torch.as_tensor(total) class DepthTV(Regularizer): def __init__(self, initial_value): super().__init__('tv-depth', initial_value) def _regularize(self, model, model_out, **kwargs) -> torch.Tensor: depth = model_out['depth'] tv = compute_plane_tv( depth.reshape(64, 64)[None, None, :, :] ) return tv class L1TimePlanes(Regularizer): def __init__(self, initial_value, what='field'): if what not in {'field', 'proposal_network'}: raise ValueError(f'what must be one of "field" or "proposal_network" ' f'but {what} was passed.') super().__init__(f'l1-time-{what[:2]}', initial_value) self.what = what def _regularize(self, model, **kwargs) -> torch.Tensor: # model.grids is 6 x [1, rank * F_dim, reso, reso] multi_res_grids: Sequence[nn.ParameterList] if self.what == 'field': multi_res_grids = model.field.grids elif self.what == 'proposal_network': multi_res_grids = [p.grids for p in model.proposal_networks] else: raise NotImplementedError(self.what) total = 0.0 for grids in multi_res_grids: if len(grids) == 3: continue else: # These are the spatiotemporal grids spatiotemporal_grids = [2, 4, 5] for grid_id in spatiotemporal_grids: total += torch.abs(1 - grids[grid_id]).mean() return torch.as_tensor(total) ================================================ FILE: gs_renderer_4d.py ================================================ import os import math import numpy as np from typing import NamedTuple from plyfile import PlyData, PlyElement import torch from torch import nn from diff_gauss import ( GaussianRasterizationSettings, GaussianRasterizer, ) from simple_knn._C import distCUDA2 from sh_utils import eval_sh, SH2RGB, RGB2SH from deform import * def inverse_sigmoid(x): return torch.log(x/(1-x)) def get_expon_lr_func( lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, 
max_steps=1000000 ): def helper(step): if lr_init == lr_final: # constant lr, ignore other params return lr_init if step < 0 or (lr_init == 0.0 and lr_final == 0.0): # Disable this parameter return 0.0 if lr_delay_steps > 0: # A kind of reverse cosine decay. delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin( 0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1) ) else: delay_rate = 1.0 t = np.clip(step / max_steps, 0, 1) log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t) return delay_rate * log_lerp return helper def strip_lowerdiag(L): uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda") uncertainty[:, 0] = L[:, 0, 0] uncertainty[:, 1] = L[:, 0, 1] uncertainty[:, 2] = L[:, 0, 2] uncertainty[:, 3] = L[:, 1, 1] uncertainty[:, 4] = L[:, 1, 2] uncertainty[:, 5] = L[:, 2, 2] return uncertainty def strip_symmetric(sym): return strip_lowerdiag(sym) def gaussian_3d_coeff(xyzs, covs): # xyzs: [N, 3] # covs: [N, 6] x, y, z = xyzs[:, 0], xyzs[:, 1], xyzs[:, 2] a, b, c, d, e, f = covs[:, 0], covs[:, 1], covs[:, 2], covs[:, 3], covs[:, 4], covs[:, 5] # eps must be small enough !!! inv_det = 1 / (a * d * f + 2 * e * c * b - e**2 * a - c**2 * d - b**2 * f + 1e-24) inv_a = (d * f - e**2) * inv_det inv_b = (e * c - b * f) * inv_det inv_c = (e * b - c * d) * inv_det inv_d = (a * f - c**2) * inv_det inv_e = (b * c - e * a) * inv_det inv_f = (a * d - b**2) * inv_det power = -0.5 * (x**2 * inv_a + y**2 * inv_d + z**2 * inv_f) - x * y * inv_b - x * z * inv_c - y * z * inv_e power[power > 0] = -1e10 # abnormal values... 
make weights 0 return torch.exp(power) def build_rotation(r): norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3]) q = r / norm[:, None] R = torch.zeros((q.size(0), 3, 3), device='cuda') r = q[:, 0] x = q[:, 1] y = q[:, 2] z = q[:, 3] R[:, 0, 0] = 1 - 2 * (y*y + z*z) R[:, 0, 1] = 2 * (x*y - r*z) R[:, 0, 2] = 2 * (x*z + r*y) R[:, 1, 0] = 2 * (x*y + r*z) R[:, 1, 1] = 1 - 2 * (x*x + z*z) R[:, 1, 2] = 2 * (y*z - r*x) R[:, 2, 0] = 2 * (x*z - r*y) R[:, 2, 1] = 2 * (y*z + r*x) R[:, 2, 2] = 1 - 2 * (x*x + y*y) return R def build_scaling_rotation(s, r): L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda") R = build_rotation(r) L[:,0,0] = s[:,0] L[:,1,1] = s[:,1] L[:,2,2] = s[:,2] L = R @ L return L class BasicPointCloud(NamedTuple): points: np.array colors: np.array normals: np.array class GaussianModel: def setup_functions(self): def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation): L = build_scaling_rotation(scaling_modifier * scaling, rotation) actual_covariance = L @ L.transpose(1, 2) symm = strip_symmetric(actual_covariance) return symm self.scaling_activation = torch.exp self.scaling_inverse_activation = torch.log self.covariance_activation = build_covariance_from_scaling_rotation self.opacity_activation = torch.sigmoid self.inverse_opacity_activation = inverse_sigmoid self.rotation_activation = torch.nn.functional.normalize def initialize(self, initial_values, raw=False): # NOTE: actual initialization is done in trainer # raw stands for raw values, i.e.
not passed through activation self._xyz = nn.Parameter(initial_values["mean"].requires_grad_(True)).to('cuda') self._rotation = nn.Parameter(initial_values["qvec"].requires_grad_(True)).to('cuda') #self._scaling = nn.Parameter(initial_values["svec"].requires_grad_(True)).to('cuda') #self._features_dc = nn.Parameter(initial_values["color"].requires_grad_(True)).to('cuda') self._opacity = nn.Parameter(initial_values["alpha"].requires_grad_(True)).to('cuda') def __init__(self, sh_degree : int,args = None): self.active_sh_degree = 0 self.max_sh_degree = sh_degree self._xyz = torch.empty(0) self._features_dc = torch.empty(0) self._features_rest = torch.empty(0) self._scaling = torch.empty(0) self._rotation = torch.empty(0) self._opacity = torch.empty(0) self.max_radii2D = torch.empty(0) self.xyz_gradient_accum = torch.empty(0) self.denom = torch.empty(0) self.optimizer = None self.percent_dense = 0 self.spatial_lr_scale = 0 self._deformation_table = torch.empty(0) self._deformation = deform_network() self.setup_functions() def capture(self): return ( self.active_sh_degree, self._xyz, self._deformation.state_dict(), self._deformation_table, self._features_dc, self._features_rest, self._scaling, self._rotation, self._opacity, self.max_radii2D, self.xyz_gradient_accum, self.denom, self.optimizer.state_dict(), self.spatial_lr_scale, ) def restore(self, model_args, training_args): (self.active_sh_degree, self._xyz, self._deformation_table, self._deformation, self._features_dc, self._features_rest, self._scaling, self._rotation, self._opacity, self.max_radii2D, xyz_gradient_accum, denom, opt_dict, self.spatial_lr_scale) = model_args self.training_setup(training_args) self.xyz_gradient_accum = xyz_gradient_accum self.denom = denom self.optimizer.load_state_dict(opt_dict) @property def get_scaling(self): return self.scaling_activation(self._scaling) @property def get_rotation(self): return self.rotation_activation(self._rotation) @property def get_xyz(self): return self._xyz
# NOTE(review): the text in this region is a newline-collapsed extract.
# Many statements and comments share one physical line; for example, the
# "#self._scaling = ..." comment above swallows the rest of its line. So it is
# not valid Python as laid out; fix issues in the real source file instead.
# The FILE separators further down show that guidance/zero123_4d_utils.py and
# the start of guidance/zero123pp/pipeline.py follow. The code up to
# Renderer.render presumably comes from gs_renderer_4d.py (its FILE header
# precedes this chunk) -- TODO confirm.
# Review notes for the two lines above:
# - build_rotation(r): divides each row of r by its L2 norm and reads it as a
#   quaternion with the real part first (r, x, y, z). It returns an (N, 3, 3)
#   rotation tensor allocated on 'cuda'.
# - build_scaling_rotation(s, r) returns R @ diag(s). setup_functions turns
#   that into the covariance L @ L^T, passed through strip_symmetric. It uses
#   exp/log for scales, sigmoid/inverse_sigmoid for opacity, and normalize
#   for rotations.
# - GaussianModel.initialize() calls `nn.Parameter(...).to('cuda')`. If the
#   input is not already on CUDA, Tensor.to() returns a new tensor that is
#   no longer an nn.Parameter (nor a leaf). TODO confirm the inputs are CUDA
#   tensors, or move the .to() inside nn.Parameter(...).
# - restore() unpacks model_args in a different order than capture() packs
#   them. capture() yields (..., _xyz, _deformation.state_dict(),
#   _deformation_table, ...). restore() assigns the third item to
#   _deformation_table and the fourth to _deformation. A capture() ->
#   restore() round trip therefore swaps them and replaces the deformation
#   module with a tensor. Likely bug -- TODO fix.
@property def get_features(self): features_dc = self._features_dc features_rest = self._features_rest return torch.cat((features_dc, features_rest), dim=1) @property def get_opacity(self): return self.opacity_activation(self._opacity) def get_covariance(self, scaling_modifier = 1): return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation) def oneupSHdegree(self): if self.active_sh_degree < self.max_sh_degree: self.active_sh_degree += 1 def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float = 1): self.spatial_lr_scale = spatial_lr_scale fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda() fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda()) features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda() features[:, :3, 0 ] = fused_color features[:, 3:, 1:] = 0.0 print("Number of points at initialisation : ", fused_point_cloud.shape[0]) dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001) scales = torch.log(torch.sqrt(dist2))[...,None].repeat(1, 3) rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda") rots[:, 0] = 1 opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda")) self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True)) self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True)) self._features_rest = nn.Parameter(features[:,:,1:].transpose(1, 2).contiguous().requires_grad_(True)) self._scaling = nn.Parameter(scales.requires_grad_(True)) self._rotation = nn.Parameter(rots.requires_grad_(True)) self._opacity = nn.Parameter(opacities.requires_grad_(True)) self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") self._deformation_table = torch.gt(torch.ones((self.get_xyz.shape[0]),device="cuda"),0) #print(self._xyz.shape,self._rotation.shape) self._deformation =
self._deformation.to("cuda") self.active_sh_degree = self.max_sh_degree self._deformation_accum = torch.zeros((self.get_xyz.shape[0],3),device="cuda") self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") def training_setup(self, training_args): self.percent_dense = training_args.percent_dense self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") self._deformation_accum = torch.zeros((self.get_xyz.shape[0],3),device="cuda") l = [ {'params': [self._xyz], 'lr': training_args.position_lr_init * self.spatial_lr_scale, "name": "xyz"}, {'params': list(self._deformation.get_mlp_parameters()), 'lr': training_args.deformation_lr_init * self.spatial_lr_scale, "name": "deformation"}, {'params': list(self._deformation.get_grid_parameters()), 'lr': training_args.grid_lr_init * self.spatial_lr_scale, "name": "grid"}, {'params': [self._features_dc], 'lr': training_args.feature_lr, "name": "f_dc"}, {'params': [self._features_rest], 'lr': training_args.feature_lr / 20.0, "name": "f_rest"}, {'params': [self._opacity], 'lr': training_args.opacity_lr, "name": "opacity"}, {'params': [self._scaling], 'lr': training_args.scaling_lr, "name": "scaling"}, {'params': [self._rotation], 'lr': training_args.rotation_lr, "name": "rotation"} ] self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15) self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init*self.spatial_lr_scale, lr_final=training_args.position_lr_final*self.spatial_lr_scale, lr_delay_mult=training_args.position_lr_delay_mult, max_steps=training_args.position_lr_max_steps) self.deformation_scheduler_args = get_expon_lr_func(lr_init=training_args.deformation_lr_init*self.spatial_lr_scale, lr_final=training_args.deformation_lr_final*self.spatial_lr_scale, lr_delay_mult=training_args.deformation_lr_delay_mult, max_steps=training_args.position_lr_max_steps) self.grid_scheduler_args =
get_expon_lr_func(lr_init=training_args.grid_lr_init*self.spatial_lr_scale, lr_final=training_args.grid_lr_final*self.spatial_lr_scale, lr_delay_mult=training_args.deformation_lr_delay_mult, max_steps=training_args.position_lr_max_steps) def update_learning_rate(self, iteration): ''' Learning rate scheduling per step ''' for param_group in self.optimizer.param_groups: if param_group["name"] == "xyz": lr = self.xyz_scheduler_args(iteration) param_group['lr'] = lr # return lr if "grid" in param_group["name"]: lr = self.grid_scheduler_args(iteration) param_group['lr'] = lr # return lr elif param_group["name"] == "deformation": lr = self.deformation_scheduler_args(iteration) param_group['lr'] = lr # return lr def construct_list_of_attributes(self): l = ['x', 'y', 'z', 'nx', 'ny', 'nz'] # All channels except the 3 DC for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]): l.append('f_dc_{}'.format(i)) for i in range(self._features_rest.shape[1]*self._features_rest.shape[2]): l.append('f_rest_{}'.format(i)) l.append('opacity') for i in range(self._scaling.shape[1]): l.append('scale_{}'.format(i)) for i in range(self._rotation.shape[1]): l.append('rot_{}'.format(i)) return l def save_ply(self, path): os.makedirs(os.path.dirname(path), exist_ok=True) xyz = self._xyz.detach().cpu().numpy() normals = np.zeros_like(xyz) f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy() opacities = self._opacity.detach().cpu().numpy() scale = self._scaling.detach().cpu().numpy() rotation = self._rotation.detach().cpu().numpy() dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()] elements = np.empty(xyz.shape[0], dtype=dtype_full) attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1) elements[:] = list(map(tuple, attributes)) el = PlyElement.describe(elements,
'vertex') PlyData([el]).write(path) def compute_deformation(self,time): deform = self._deformation[:,:,:time].sum(dim=-1) xyz = self._xyz + deform return xyz def load_model(self, path): print("loading model from exists{}".format(path)) weight_dict = torch.load(os.path.join(path,"deformation.pth"),map_location="cuda") self._deformation.load_state_dict(weight_dict) self._deformation = self._deformation.to("cuda") self._deformation_table = torch.gt(torch.ones((self.get_xyz.shape[0]),device="cuda"),0) self._deformation_accum = torch.zeros((self.get_xyz.shape[0],3),device="cuda") if os.path.exists(os.path.join(path, "deformation_table.pth")): self._deformation_table = torch.load(os.path.join(path, "deformation_table.pth"),map_location="cuda") if os.path.exists(os.path.join(path, "deformation_accum.pth")): self._deformation_accum = torch.load(os.path.join(path, "deformation_accum.pth"),map_location="cuda") self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") self._deformation_table = torch.gt(torch.ones((self.get_xyz.shape[0]),device="cuda"),0) self._deformation = self._deformation.to("cuda") def save_deformation(self, path): torch.save(self._deformation.state_dict(),os.path.join(path, "deformation.pth")) torch.save(self._deformation_table,os.path.join(path, "deformation_table.pth")) torch.save(self._deformation_accum,os.path.join(path, "deformation_accum.pth")) def load_ply(self, path): plydata = PlyData.read(path) xyz = np.stack((np.asarray(plydata.elements[0]["x"]), np.asarray(plydata.elements[0]["y"]), np.asarray(plydata.elements[0]["z"])), axis=1) opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis] print("Number of points at loading : ", xyz.shape[0]) features_dc = np.zeros((xyz.shape[0], 3, 1)) features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"]) features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"]) features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"]) extra_f_names = [p.name for p in
plydata.elements[0].properties if p.name.startswith("f_rest_")] assert len(extra_f_names)==3*(self.max_sh_degree + 1) ** 2 - 3 features_extra = np.zeros((xyz.shape[0], len(extra_f_names))) for idx, attr_name in enumerate(extra_f_names): features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name]) # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC) features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1)) scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")] scales = np.zeros((xyz.shape[0], len(scale_names))) for idx, attr_name in enumerate(scale_names): scales[:, idx] = np.asarray(plydata.elements[0][attr_name]) rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")] rots = np.zeros((xyz.shape[0], len(rot_names))) for idx, attr_name in enumerate(rot_names): rots[:, idx] = np.asarray(plydata.elements[0][attr_name]) self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True)) self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True)) self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True)) self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True)) self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True)) self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True)) self._deformation_table = torch.gt(torch.ones((self.get_xyz.shape[0]),device="cuda"),0) #print(self._xyz.shape,self._rotation.shape) self._deformation = self._deformation.to("cuda") self.active_sh_degree = self.max_sh_degree self._deformation_accum = torch.zeros((self.get_xyz.shape[0],3),device="cuda") self.max_radii2D =
torch.zeros((self.get_xyz.shape[0]), device="cuda") def replace_tensor_to_optimizer(self, tensor, name): optimizable_tensors = {} for group in self.optimizer.param_groups: if group["name"] == name: stored_state = self.optimizer.state.get(group['params'][0], None) stored_state["exp_avg"] = torch.zeros_like(tensor) stored_state["exp_avg_sq"] = torch.zeros_like(tensor) del self.optimizer.state[group['params'][0]] group["params"][0] = nn.Parameter(tensor.requires_grad_(True)) self.optimizer.state[group['params'][0]] = stored_state optimizable_tensors[group["name"]] = group["params"][0] return optimizable_tensors def _prune_optimizer(self, mask): optimizable_tensors = {} for group in self.optimizer.param_groups: if len(group["params"]) > 1: continue stored_state = self.optimizer.state.get(group['params'][0], None) if stored_state is not None: stored_state["exp_avg"] = stored_state["exp_avg"][mask] stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask] del self.optimizer.state[group['params'][0]] group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True))) self.optimizer.state[group['params'][0]] = stored_state optimizable_tensors[group["name"]] = group["params"][0] else: group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True)) optimizable_tensors[group["name"]] = group["params"][0] return optimizable_tensors def prune_points(self, mask): valid_points_mask = ~mask optimizable_tensors = self._prune_optimizer(valid_points_mask) self._xyz = optimizable_tensors["xyz"] self._features_dc = optimizable_tensors["f_dc"] self._features_rest = optimizable_tensors["f_rest"] self._opacity = optimizable_tensors["opacity"] self._scaling = optimizable_tensors["scaling"] self._rotation = optimizable_tensors["rotation"] self._deformation_accum = self._deformation_accum[valid_points_mask] self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask] self._deformation_table = self._deformation_table[valid_points_mask] self.denom =
self.denom[valid_points_mask] self.max_radii2D = self.max_radii2D[valid_points_mask] def cat_tensors_to_optimizer(self, tensors_dict): optimizable_tensors = {} for group in self.optimizer.param_groups: if len(group["params"])>1:continue assert len(group["params"]) == 1 extension_tensor = tensors_dict[group["name"]] stored_state = self.optimizer.state.get(group['params'][0], None) if stored_state is not None: stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0) stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0) del self.optimizer.state[group['params'][0]] group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True)) self.optimizer.state[group['params'][0]] = stored_state optimizable_tensors[group["name"]] = group["params"][0] else: group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True)) optimizable_tensors[group["name"]] = group["params"][0] return optimizable_tensors def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation, new_deformation_table): d = {"xyz": new_xyz, "f_dc": new_features_dc, "f_rest": new_features_rest, "opacity": new_opacities, "scaling" : new_scaling, "rotation" : new_rotation, # "deformation": new_deformation } optimizable_tensors = self.cat_tensors_to_optimizer(d) self._xyz = optimizable_tensors["xyz"] self._features_dc = optimizable_tensors["f_dc"] self._features_rest = optimizable_tensors["f_rest"] self._opacity = optimizable_tensors["opacity"] self._scaling = optimizable_tensors["scaling"] self._rotation = optimizable_tensors["rotation"] # self._deformation = optimizable_tensors["deformation"] self._deformation_table = torch.cat([self._deformation_table,new_deformation_table],-1) self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
# NOTE(review): notes on the GaussianModel methods from get_features up to
# densification_postfix (above):
# - get_covariance passes the raw self._rotation rather than get_rotation.
#   This is harmless because build_rotation normalizes internally.
# - create_from_pcd: `features[:, 3:, 1:] = 0.0` is a no-op. Dim 1 of
#   `features` has size 3, so the slice is empty. Initial log-scales come
#   from distCUDA2 clamped at 1e-7; the values are presumably squared
#   nearest-neighbour distances from simple-knn (verify there).
# - update_learning_rate only re-schedules the "xyz", "grid" and
#   "deformation" groups. The other groups keep their training_setup rates.
# - compute_deformation slices self._deformation[:, :, :time]. However,
#   __init__ assigns a deform_network() module to self._deformation, so this
#   looks like dead code that would fail if called. TODO remove or fix.
# - load_model loads deformation_table.pth when present. It then
#   unconditionally resets _deformation_table to all-True at the end,
#   discarding the loaded table. TODO confirm intent.
# - replace_tensor_to_optimizer does not check stored_state for None, unlike
#   _prune_optimizer and cat_tensors_to_optimizer.
# - _prune_optimizer / cat_tensors_to_optimizer skip param groups holding
#   more than one tensor (presumably the network "deformation"/"grid"
#   groups). Only the per-point tensors are pruned or extended.
# - densification_postfix re-allocates xyz_gradient_accum (just above) and
#   _deformation_accum, denom and max_radii2D (just below) as zeros sized for
#   the new point count. This discards the statistics accumulated so far.
self._deformation_accum = torch.zeros((self.get_xyz.shape[0], 3), device="cuda") self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda") self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda") def densify_and_split(self, grads, grad_threshold, scene_extent, N=2): n_init_points = self.get_xyz.shape[0] # Extract points that satisfy the gradient condition padded_grad = torch.zeros((n_init_points), device="cuda") padded_grad[:grads.shape[0]] = grads.squeeze() selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False) selected_pts_mask = torch.logical_and(selected_pts_mask, torch.max(self.get_scaling, dim=1).values > self.percent_dense*scene_extent) if not selected_pts_mask.any(): return stds = self.get_scaling[selected_pts_mask].repeat(N,1) means =torch.zeros((stds.size(0), 3),device="cuda") samples = torch.normal(mean=means, std=stds) rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N,1,1) new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[selected_pts_mask].repeat(N, 1) new_scaling = self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(N,1) / (0.8*N)) new_rotation = self._rotation[selected_pts_mask].repeat(N,1) new_features_dc = self._features_dc[selected_pts_mask].repeat(N,1,1) new_features_rest = self._features_rest[selected_pts_mask].repeat(N,1,1) new_opacity = self._opacity[selected_pts_mask].repeat(N,1) new_deformation_table = self._deformation_table[selected_pts_mask].repeat(N) self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacity, new_scaling, new_rotation, new_deformation_table) prune_filter = torch.cat((selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool))) self.prune_points(prune_filter) def densify_and_clone(self, grads, grad_threshold, scene_extent): selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False) selected_pts_mask =
torch.logical_and(selected_pts_mask, torch.max(self.get_scaling, dim=1).values <= self.percent_dense*scene_extent) new_xyz = self._xyz[selected_pts_mask] # - 0.001 * self._xyz.grad[selected_pts_mask] new_features_dc = self._features_dc[selected_pts_mask] new_features_rest = self._features_rest[selected_pts_mask] new_opacities = self._opacity[selected_pts_mask] new_scaling = self._scaling[selected_pts_mask] new_rotation = self._rotation[selected_pts_mask] new_deformation_table = self._deformation_table[selected_pts_mask] self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation, new_deformation_table) def densify_and_prune(self, max_grad_percent, min_opacity, extent, max_screen_size): grads = self.xyz_gradient_accum / self.denom grads[grads.isnan()] = 0.0 grad_log = torch.log(grads) grad_log2=grad_log[~grad_log.isnan()] grad_log3=grad_log[~grad_log2.isinf()] max_grad_1 = torch.exp(grad_log3.mean()+grad_log3.var()) #adaptive densification with mean and var, unused max_grad_2 = torch.exp(grad_log3.squeeze(dim=1).sort(descending=True)[0][int(max_grad_percent*grad_log3.shape[0])]) #adaptive densification with relative grad max_grad = max_grad_2 #choose which to use #print('max_grad',max_grad_percent,max_grad_1,max_grad_2,grad_log3.mean(),grad_log3.var()) self.densify_and_clone(grads, max_grad, extent) self.densify_and_split(grads, max_grad, extent) prune_mask = (self.get_opacity < min_opacity).squeeze() if max_screen_size: big_points_vs = self.max_radii2D > max_screen_size big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent small_ws = self.get_scaling.max(dim=1).values<0.001 prune_mask = torch.logical_or(torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws),small_ws) self.prune_points(prune_mask) torch.cuda.empty_cache() def prune(self, min_opacity=0.01, extent=1, max_screen_size=1): prune_mask = (self.get_opacity < min_opacity).squeeze() # prune_mask_2 =
torch.logical_and(self.get_opacity <= inverse_sigmoid(0.101 , dtype=torch.float, device="cuda"), self.get_opacity >= inverse_sigmoid(0.999 , dtype=torch.float, device="cuda")) # prune_mask = torch.logical_or(prune_mask, prune_mask_2) # deformation_sum = abs(self._deformation).sum(dim=-1).mean(dim=-1) # deformation_mask = (deformation_sum < torch.quantile(deformation_sum, torch.tensor([0.5]).to("cuda"))) # prune_mask = prune_mask & deformation_mask if max_screen_size: big_points_vs = self.max_radii2D > max_screen_size big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent #prune_mask = torch.logical_or(prune_mask, big_points_vs) small_ws = self.get_scaling.min(dim=1).values<0.001 prune_mask = torch.logical_or(torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws),small_ws) self.prune_points(prune_mask) def standard_constaint(self): means3D = self._xyz.detach() scales = self._scaling.detach() rotations = self._rotation.detach() opacity = self._opacity.detach() time = torch.tensor(0).to("cuda").repeat(means3D.shape[0],1) means3D_deform, scales_deform, rotations_deform, _ = self._deformation(means3D, scales, rotations, opacity, time) position_error = (means3D_deform - means3D)**2 rotation_error = (rotations_deform - rotations)**2 scaling_erorr = (scales_deform - scales)**2 return position_error.mean() + rotation_error.mean() + scaling_erorr.mean() def add_densification_stats(self, viewspace_point_tensor, update_filter): #print(viewspace_point_tensor,update_filter) self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor[update_filter,:2], dim=-1, keepdim=True) self.denom[update_filter] += 1 @torch.no_grad() def update_deformation_table(self,threshold): # print("origin deformation point nums:",self._deformation_table.sum()) self._deformation_table = torch.gt(self._deformation_accum.max(dim=-1).values/100,threshold) def print_deformation_weight_grad(self): for name, weight in self._deformation.named_parameters(): if
weight.requires_grad: if weight.grad is None: print(name," :",weight.grad) else: if weight.grad.mean() != 0: print(name," :",weight.grad.mean(), weight.grad.min(), weight.grad.max()) print("-"*50) def _plane_regulation(self): multi_res_grids = self._deformation.deformation_net.grid.grids total = 0 # model.grids is 6 x [1, rank * F_dim, reso, reso] for grids in multi_res_grids: if len(grids) == 3: time_grids = [] else: time_grids = [0,1,3] for grid_id in time_grids: total += compute_plane_smoothness(grids[grid_id]) return total def _time_regulation(self): multi_res_grids = self._deformation.deformation_net.grid.grids total = 0 # model.grids is 6 x [1, rank * F_dim, reso, reso] for grids in multi_res_grids: if len(grids) == 3: time_grids = [] else: time_grids =[2, 4, 5] for grid_id in time_grids: total += compute_plane_smoothness(grids[grid_id]) return total def _l1_regulation(self): # model.grids is 6 x [1, rank * F_dim, reso, reso] multi_res_grids = self._deformation.deformation_net.grid.grids total = 0.0 for grids in multi_res_grids: if len(grids) == 3: continue else: # These are the spatiotemporal grids spatiotemporal_grids = [2, 4, 5] for grid_id in spatiotemporal_grids: total += torch.abs(1 - grids[grid_id]).mean() return total def compute_regulation(self, time_smoothness_weight, l1_time_planes_weight, plane_tv_weight): return plane_tv_weight * self._plane_regulation() + time_smoothness_weight * self._time_regulation() + l1_time_planes_weight * self._l1_regulation() def getProjectionMatrix(znear, zfar, fovX, fovY): tanHalfFovY = math.tan((fovY / 2)) tanHalfFovX = math.tan((fovX / 2)) P = torch.zeros(4, 4) z_sign = 1.0 P[0, 0] = 1 / tanHalfFovX P[1, 1] = 1 / tanHalfFovY P[3, 2] = z_sign P[2, 2] = z_sign * zfar / (zfar - znear) P[2, 3] = -(zfar * znear) / (zfar - znear) return P class MiniCam: def __init__(self, c2w, width, height, fovy, fovx, znear, zfar,time=0 ): # c2w (pose) should be in NeRF convention.
# NOTE(review): notes on the code from densify_and_split up to this point:
# - densify_and_split samples N new centres per selected Gaussian from
#   N(0, scale), rotated by that Gaussian's rotation. It divides the scale by
#   0.8 * N and prunes the originals. densify_and_clone duplicates Gaussians
#   with max scale <= percent_dense * scene_extent and gradient norm >=
#   threshold.
# - densify_and_prune uses an adaptive threshold. The non-infinite
#   log-gradients are sorted in descending order, and the entry at index
#   int(max_grad_percent * count) is used. Roughly the top max_grad_percent
#   fraction therefore passes the threshold.
#   - That index raises IndexError if max_grad_percent >= 1.0 or if no
#     finite gradient remains.
#   - max_grad_1 is computed but unused.
#   - grad_log3 applies a mask built from grad_log2 to grad_log. The shapes
#     only line up because NaN grads were zeroed first.
# - Small-scale pruning differs between the two methods: densify_and_prune
#   tests max(scale) < 0.001, but prune() tests min(scale) < 0.001.
#   TODO confirm which is intended.
# - standard_constaint passes time = torch.tensor(0), an integer tensor, to
#   the deformation network. Check whether float times are expected.
# - update_deformation_table keeps a point deformable when
#   max(_deformation_accum) / 100 > threshold.
# - Regularizers skip 3-plane grids. On 6-plane grids:
#   _plane_regulation smooths planes [0, 1, 3];
#   _time_regulation smooths planes [2, 4, 5];
#   _l1_regulation averages |1 - plane| over planes [2, 4, 5].
#   compute_regulation returns their weighted sum.
# - getProjectionMatrix builds a perspective matrix with z_sign = +1,
#   P[2,2] = zfar / (zfar - znear) and P[2,3] = -zfar * znear / (zfar - znear).
# - MiniCam.__init__ (continues below) computes w2c = inv(c2w), then negates
#   rotation rows 1-2 and the translation ("rectify"). Both matrices are
#   stored transposed on CUDA, and full_proj_transform =
#   world_view @ projection. camera_center is -c2w[:3, 3].
self.image_width = width self.image_height = height self.FoVy = fovy self.FoVx = fovx self.znear = znear self.zfar = zfar self.time = time w2c = np.linalg.inv(c2w) # rectify... w2c[1:3, :3] *= -1 w2c[:3, 3] *= -1 self.world_view_transform = torch.tensor(w2c).transpose(0, 1).cuda() self.projection_matrix = ( getProjectionMatrix( znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy ) .transpose(0, 1) .cuda() ) self.full_proj_transform = self.world_view_transform @ self.projection_matrix self.camera_center = -torch.tensor(c2w[:3, 3]).cuda() class Renderer: def __init__(self, sh_degree=3, white_background=True, radius=1): self.sh_degree = sh_degree self.white_background = white_background self.radius = radius self.gaussians = GaussianModel(sh_degree) self.bg_color = torch.tensor( [1, 1, 1] if white_background else [0, 0, 0], dtype=torch.float32, device="cuda", ) def initialize(self, input=None, num_pts=5000, radius=0.5,initial_values=None): # load checkpoint if input is None: # init from random point cloud phis = np.random.random((num_pts,)) * 2 * np.pi costheta = np.random.random((num_pts,)) * 2 - 1 thetas = np.arccos(costheta) mu = np.random.random((num_pts,)) radius = radius * np.cbrt(mu) x = radius * np.sin(thetas) * np.cos(phis) y = radius * np.sin(thetas) * np.sin(phis) z = radius * np.cos(thetas) xyz = np.stack((x, y, z), axis=1) if initial_values is not None: print(xyz.shape,initial_values["mean"].shape) R = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]]) xyz = np.dot(initial_values["mean"].numpy(),R) # xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3 shs = np.random.random((num_pts, 3)) / 255.0 pcd = BasicPointCloud( points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3)) ) self.gaussians.create_from_pcd(pcd, 10) elif isinstance(input, BasicPointCloud): # load from a provided pcd self.gaussians.create_from_pcd(input, 1) else: # load from saved ply self.gaussians.load_ply(input) def render( self, viewpoint_camera, scaling_modifier=1.0,
bg_color=None, override_color=None, compute_cov3D_python=False, convert_SHs_python=False, stage="fine", time_int = None, front_view=False, ): # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means screenspace_points = torch.zeros_like(self.gaussians.get_xyz, dtype=self.gaussians.get_xyz.dtype, requires_grad=True, device="cuda") + 0 try: screenspace_points.retain_grad() except: pass # Set up rasterization configuration tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) raster_settings = GaussianRasterizationSettings( image_height=int(viewpoint_camera.image_height), image_width=int(viewpoint_camera.image_width), tanfovx=tanfovx, tanfovy=tanfovy, bg=self.bg_color if bg_color is None else bg_color, scale_modifier=scaling_modifier, viewmatrix=viewpoint_camera.world_view_transform, projmatrix=viewpoint_camera.full_proj_transform, sh_degree=self.gaussians.active_sh_degree, campos=viewpoint_camera.camera_center, prefiltered=False, debug=False, ) if front_view==True: print(viewpoint_camera.world_view_transform,viewpoint_camera.full_proj_transform) rasterizer = GaussianRasterizer(raster_settings=raster_settings) means3D = self.gaussians.get_xyz time = torch.tensor(viewpoint_camera.time).to(means3D.device).repeat(means3D.shape[0],1) means2D = screenspace_points opacity = self.gaussians._opacity # If precomputed 3d covariance is provided, use it. If not, then it will be computed from # scaling / rotation by the rasterizer.
# NOTE(review): notes on Renderer (above) and the rest of render (below):
# - Renderer.initialize(input=None) samples num_pts points uniformly inside
#   a ball of the given radius. If initial_values is given, xyz is replaced
#   by initial_values["mean"] @ R, but the colours and normals are still
#   sized num_pts. "mean" must therefore have exactly num_pts rows -- TODO
#   confirm. This path calls create_from_pcd(pcd, 10), i.e.
#   spatial_lr_scale = 10.
# - render accepts override_color, convert_SHs_python and time_int but never
#   uses them. The bare `except: pass` around retain_grad() hides any error.
#   front_view=True only prints the camera matrices (debug output).
# - compute_cov3D_python=True leaves scales and rotations as None. The code
#   below nonetheless calls torch.zeros_like(rotations) and
#   torch.zeros_like(scales), and indexes scales[deformation_point]. That
#   path therefore appears unsupported -- TODO.
# - In the "coarse" stage the deformation network is bypassed. Otherwise,
#   only points whose _deformation_table entry is True are deformed; the rest
#   pass through unchanged. The |delta xyz| of deformed points is accumulated
#   into _deformation_accum under no_grad. Indentation is lost in this copy,
#   so which statements sit inside the else-branch cannot be read off here.
# - Activations (exp, normalize, sigmoid) are applied after deformation. The
#   returned 'scales', 'rot' and 'opacity' are the pre-activation values.
# - The rasterizer call unpacks six outputs (image, depth, normal, alpha,
#   radii, extra), matching the newer diff_gauss API noted in the README.
scales = None rotations = None cov3D_precomp = None if compute_cov3D_python: cov3D_precomp = self.gaussians.get_covariance(scaling_modifier) else: scales = self.gaussians._scaling rotations = self.gaussians._rotation deformation_point = self.gaussians._deformation_table if stage == "coarse" : means3D_deform, scales_deform, rotations_deform, opacity_deform = means3D, scales, rotations, opacity else: means3D_deform, scales_deform, rotations_deform, opacity_deform = self.gaussians._deformation(means3D[deformation_point], scales[deformation_point], rotations[deformation_point], opacity[deformation_point], time[deformation_point]) # print(time.max()) with torch.no_grad(): self.gaussians._deformation_accum[deformation_point] += torch.abs(means3D_deform-means3D[deformation_point]) #print(torch.abs(means3D_deform-means3D[deformation_point]).mean()) means3D_final = torch.zeros_like(means3D) rotations_final = torch.zeros_like(rotations) scales_final = torch.zeros_like(scales) opacity_final = torch.zeros_like(opacity) means3D_final[deformation_point] = means3D_deform rotations_final[deformation_point] = rotations_deform scales_final[deformation_point] = scales_deform opacity_final[deformation_point] = opacity_deform means3D_final[~deformation_point] = means3D[~deformation_point] rotations_final[~deformation_point] = rotations[~deformation_point] scales_final[~deformation_point] = scales[~deformation_point] opacity_final[~deformation_point] = opacity[~deformation_point] scales_in=scales_final rotations_in=rotations_final opacity_in = opacity_final scales_final = self.gaussians.scaling_activation(scales_final) rotations_final = self.gaussians.rotation_activation(rotations_final) opacity = self.gaussians.opacity_activation(opacity) opacity_final = self.gaussians.opacity_activation(opacity_final) # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors # from SHs in Python, do it.
If not, then SH -> RGB conversion will be done by rasterizer. shs = None colors_precomp = None shs = self.gaussians.get_features # Rasterize visible Gaussians to image, obtain their radii (on screen). rendered_image, rendered_depth, normal, rendered_alpha ,radii, _ = rasterizer( means3D = means3D_final, means2D = means2D, shs = shs, colors_precomp = colors_precomp, opacities = opacity_final, scales = scales_final, rotations = rotations_final, cov3Ds_precomp = cov3D_precomp) return { "image": rendered_image, "depth": rendered_depth, "alpha": rendered_alpha, "viewspace_points": screenspace_points, "visibility_filter": radii > 0, "radii": radii, 'xyz':means3D_final, 'rot':rotations_in, 'xy':means2D, 'color':shs, 'scales':scales_in, 'opacity':opacity_in, } ================================================ FILE: guidance/zero123_4d_utils.py ================================================ from transformers import CLIPTextModel, CLIPTokenizer, logging from diffusers import ( AutoencoderKL, UNet2DConditionModel, DDIMScheduler, StableDiffusionPipeline, ) import torchvision.transforms.functional as TF import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import sys sys.path.append('./') from zero123 import Zero123Pipeline class Zero123(nn.Module): def __init__(self, device, fp16=True, t_range=[0.02, 0.98]): super().__init__() self.device = device self.fp16 = fp16 self.dtype = torch.float16 if fp16 else torch.float32 self.pipe = Zero123Pipeline.from_pretrained( # "bennyguo/zero123-diffusers", "ashawkey/zero123-xl-diffusers", # './model_cache/zero123_xl', variant="fp16" if self.fp16 else None, torch_dtype=self.dtype, ).to(self.device) # for param in self.pipe.parameters(): # param.requires_grad = False self.pipe.image_encoder.eval() self.pipe.vae.eval() self.pipe.unet.eval() self.pipe.clip_camera_projection.eval() self.vae = self.pipe.vae self.unet = self.pipe.unet self.pipe.set_progress_bar_config(disable=True) self.scheduler =
DDIMScheduler.from_config(self.pipe.scheduler.config) self.num_train_timesteps = self.scheduler.config.num_train_timesteps self.min_step = int(self.num_train_timesteps * t_range[0]) self.max_step = int(self.num_train_timesteps * t_range[1]) self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience self.min_step_percent = [0, 0.5, 0.02, 3000] self.max_step_percent= [0, 0.95, 0.5, 3000] self.embeddings = None self.embedding_list = [] @torch.no_grad() def get_img_embeds(self, x, input_imgs=None): # x: image tensor in [0, 1] x = F.interpolate(x, (256, 256), mode='bilinear', align_corners=False) x_pil = [TF.to_pil_image(image) for image in x] x_clip = self.pipe.feature_extractor(images=x_pil, return_tensors="pt").pixel_values.to(device=self.device, dtype=self.dtype) c = self.pipe.image_encoder(x_clip).image_embeds v = self.encode_imgs(x.to(self.dtype)) / self.vae.config.scaling_factor self.embeddings = [c, v] self.additional_embeddings=[] if input_imgs!=None: for x in input_imgs: x = F.interpolate(x, (256, 256), mode='bilinear', align_corners=False) x_pil = [TF.to_pil_image(image) for image in x] x_clip = self.pipe.feature_extractor(images=x_pil, return_tensors="pt").pixel_values.to(device=self.device, dtype=self.dtype) c = self.pipe.image_encoder(x_clip).image_embeds v = self.encode_imgs(x.to(self.dtype)) / self.vae.config.scaling_factor embeddings = [c, v] self.additional_embeddings.append(embeddings) self.embedding_list.append([self.embeddings,self.additional_embeddings]) @torch.no_grad() def refine(self, pred_rgb, polar, azimuth, radius, guidance_scale=5, steps=50, strength=0.8, ): batch_size = pred_rgb.shape[0] self.scheduler.set_timesteps(steps) if strength == 0: init_step = 0 latents = torch.randn((1, 4, 32, 32), device=self.device, dtype=self.dtype) else: init_step = int(steps * strength) pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode='bilinear', align_corners=False) latents = self.encode_imgs(pred_rgb_256.to(self.dtype)) latents =
self.scheduler.add_noise(latents, torch.randn_like(latents), self.scheduler.timesteps[init_step]) T = np.stack([np.deg2rad(polar), np.sin(np.deg2rad(azimuth)), np.cos(np.deg2rad(azimuth)), radius], axis=-1) T = torch.from_numpy(T).unsqueeze(1).to(self.dtype).to(self.device) # [8, 1, 4] cc_emb = torch.cat([self.embeddings[0].repeat(batch_size, 1, 1), T], dim=-1) cc_emb = self.pipe.clip_camera_projection(cc_emb) cc_emb = torch.cat([cc_emb, torch.zeros_like(cc_emb)], dim=0) vae_emb = self.embeddings[1].repeat(batch_size, 1, 1, 1) vae_emb = torch.cat([vae_emb, torch.zeros_like(vae_emb)], dim=0) for i, t in enumerate(self.scheduler.timesteps[init_step:]): x_in = torch.cat([latents] * 2) t_in = torch.cat([t.view(1)] * 2).to(self.device) noise_pred = self.unet( torch.cat([x_in, vae_emb], dim=1), t_in.to(self.unet.dtype), encoder_hidden_states=cc_emb, ).sample noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) latents = self.scheduler.step(noise_pred, t, latents).prev_sample imgs = self.decode_latents(latents) # [1, 3, 256, 256] return imgs def train_step(self, pred_rgb, polar, azimuth, radius, step_ratio=None, guidance_scale=5, as_latent=False,idx=None,t=0): # pred_rgb: tensor [1, 3, H, W] in [0, 1] #print(polar) step_ratio = max(0.4,step_ratio) self.embeddings,self.additional_embeddings = self.embedding_list[t] batch_size = pred_rgb.shape[0] #print(self.embedding_list[1][0][0] -self.embedding_list[2][0][0]) #print(self.embedding_list[1][0][1] -self.embedding_list[2][0][1]) if idx is not None: embeddings = self.additional_embeddings[idx] else: embeddings = self.embeddings if as_latent: latents = F.interpolate(pred_rgb, (32, 32), mode='bilinear', align_corners=False) * 2 - 1 else: pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode='bilinear', align_corners=False) latents = self.encode_imgs(pred_rgb_256.to(self.dtype)) t = torch.randint(self.min_step, self.max_step + 1,
# NOTE(review): notes on Zero123 (guidance/zero123_4d_utils.py):
# - get_img_embeds appends one [embeddings, additional_embeddings] entry to
#   embedding_list per call. train_step(t=k) loads entry k into
#   self.embeddings / self.additional_embeddings, so refine() uses it
#   afterwards too; idx picks one of that entry's additional embeddings.
#   `input_imgs!=None` would read better as `is not None`.
# - refine sets init_step = int(steps * strength) and uses it to index the
#   scheduler's timestep list, so a larger strength starts later in the
#   schedule. Assuming descending timesteps (as DDIMScheduler.set_timesteps
#   produces), that adds less noise -- TODO confirm intent. The
#   strength == 0 branch draws latents with batch size 1 regardless of
#   pred_rgb's batch size.
# - train_step:
#   - `max(0.4, step_ratio)` raises TypeError for the default
#     step_ratio=None, and the clamped value is never used afterwards.
#   - The `t` argument (used as an embedding_list index) is rebound right
#     here to the sampled diffusion timesteps.
#   - The loss is 0.5 * sum((latents - target)^2) with
#     target = (latents - grad).detach(). Its gradient w.r.t. latents
#     therefore equals grad, up to dtype casts.
# - min_step_percent / max_step_percent are
#   [start_step, start_value, end_value, end_step]. get_steps interpolates
#   linearly between the endpoints, clamped to them, and update_step
#   converts the result to min_step / max_step.
(batch_size,), dtype=torch.long, device=self.device) w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1) with torch.no_grad(): noise = torch.randn_like(latents) latents_noisy = self.scheduler.add_noise(latents, noise, t) x_in = torch.cat([latents_noisy] * 2) t_in = torch.cat([t] * 2) T = np.stack([np.deg2rad(polar), np.sin(np.deg2rad(azimuth)), np.cos(np.deg2rad(azimuth)), radius], axis=-1) T = torch.from_numpy(T).unsqueeze(1).to(self.dtype).to(self.device) # [8, 1, 4] #print(self.embeddings[0].repeat(batch_size, 1, 1).shape,T.shape) cc_emb = torch.cat([embeddings[0].repeat(batch_size, 1, 1), T], dim=-1) cc_emb = self.pipe.clip_camera_projection(cc_emb) cc_emb = torch.cat([cc_emb, torch.zeros_like(cc_emb)], dim=0) vae_emb = embeddings[1].repeat(batch_size, 1, 1, 1) vae_emb = torch.cat([vae_emb, torch.zeros_like(vae_emb)], dim=0) noise_pred = self.unet( torch.cat([x_in, vae_emb], dim=1), t_in.to(self.unet.dtype), encoder_hidden_states=cc_emb, ).sample noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) grad = w * (noise_pred - noise) grad = torch.nan_to_num(grad) target = (latents - grad).detach() loss = 0.5 * F.mse_loss(latents.float(), target, reduction='sum') return loss def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False): min_step_percent = self.get_steps(self.min_step_percent, epoch, global_step) max_step_percent = self.get_steps(self.max_step_percent, epoch, global_step) self.min_step = int( self.num_train_timesteps * min_step_percent ) self.max_step = int( self.num_train_timesteps * max_step_percent ) def get_steps(self,percent,epoch, global_step): start_step, start_value, end_value, end_step = percent current_step = global_step value = start_value + (end_value - start_value) * max( min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0 ) return value def decode_latents(self, latents): latents = 1 /
self.vae.config.scaling_factor * latents imgs = self.vae.decode(latents).sample imgs = (imgs / 2 + 0.5).clamp(0, 1) return imgs def encode_imgs(self, imgs, mode=False): # imgs: [B, 3, H, W] imgs = 2 * imgs - 1 posterior = self.vae.encode(imgs).latent_dist if mode: latents = posterior.mode() else: latents = posterior.sample() latents = latents * self.vae.config.scaling_factor return latents if __name__ == '__main__': import cv2 import argparse import numpy as np import matplotlib.pyplot as plt parser = argparse.ArgumentParser() parser.add_argument('input', type=str) parser.add_argument('--polar', type=float, default=0, help='delta polar angle in [-90, 90]') parser.add_argument('--azimuth', type=float, default=0, help='delta azimuth angle in [-180, 180]') parser.add_argument('--radius', type=float, default=0, help='delta camera radius multiplier in [-0.5, 0.5]') opt = parser.parse_args() device = torch.device('cuda') print(f'[INFO] loading image from {opt.input} ...') image = cv2.imread(opt.input, cv2.IMREAD_UNCHANGED) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) image = cv2.resize(image, (256, 256), interpolation=cv2.INTER_AREA) image = image.astype(np.float32) / 255.0 image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).contiguous().to(device) print(f'[INFO] loading model ...') zero123 = Zero123(device) print(f'[INFO] running model ...') zero123.get_img_embeds(image) while True: outputs = zero123.refine(image, polar=[opt.polar], azimuth=[opt.azimuth], radius=[opt.radius], strength=0) plt.imshow(outputs.float().cpu().numpy().transpose(0, 2, 3, 1)[0]) plt.show() ================================================ FILE: guidance/zero123pp/pipeline.py ================================================ from typing import Any, Dict, Optional from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.schedulers import KarrasDiffusionSchedulers import numpy import torch import torch.nn as nn import torch.utils.checkpoint import torch.distributed
import transformers

from collections import OrderedDict
from PIL import Image
from torchvision import transforms
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

import torch
import torch.nn.functional as F
from torch import nn

import diffusers
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DiffusionPipeline,
    EulerAncestralDiscreteScheduler,
    UNet2DConditionModel,
    ImagePipelineOutput
)
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.attention_processor import Attention, AttnProcessor, XFormersAttnProcessor
from diffusers.utils.import_utils import is_xformers_available
import os

# Module-level state shared by the pipeline's __call__ and MyAttnProcessor2_0:
#   FIRST - True: record reference self-attention states; False: replay/blend them.
#   IDX   - running index of the current reference self-attention call.
#   EMBED - recorded states (kept on CPU), indexed by IDX.
FIRST = True
IDX = 0
PATH = '/home/vision/github/embeddings/'  # not referenced by live code; kept for compatibility
EMBED=[]


def to_rgb_image(maybe_rgba: Image.Image):
    """Return an RGB version of ``maybe_rgba``.

    RGB images are returned unchanged; RGBA images are composited onto a uniform
    gray background using their alpha channel; any other mode raises ValueError.
    """
    if maybe_rgba.mode == 'RGB':
        return maybe_rgba
    elif maybe_rgba.mode == 'RGBA':
        rgba = maybe_rgba
        # randint(127, 128) always yields 127 (upper bound exclusive): a flat gray canvas.
        img = numpy.random.randint(127, 128, size=[rgba.size[1], rgba.size[0], 3], dtype=numpy.uint8)
        img = Image.fromarray(img, 'RGB')
        img.paste(rgba, mask=rgba.getchannel('A'))
        return img
    else:
        raise ValueError("Unsupported image type.", maybe_rgba.mode)


class MyAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).

    Extends the stock SDPA processor with cross-call state: when called with
    ``is_self_attention=True`` it either records the key/value context into the
    module-level ``EMBED`` list (``FIRST`` is True) or averages the leading tokens
    of the current context with the recorded one (``FIRST`` is False).
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("MyAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        is_self_attention=False
    ):
        global FIRST, IDX, EMBED

        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        if is_self_attention:
            if FIRST:
                # Recording pass: keep this call's context for later replay.
                EMBED.append(encoder_hidden_states.to('cpu'))
            else:
                last_shape = encoder_hidden_states.shape[-1]
                # NOTE(review): 9600 = 80 * 120 looks like the token count of the default
                # 640x960 image at the 320-channel level, shrinking 4x per channel doubling.
                # This assumes the default width/height -- confirm before changing them.
                replace_dim = int(9600/(last_shape//320)**2)
                # Follow the live tensor's device/dtype instead of a hard-coded 'cuda'.
                encoder_hidden_states_load = EMBED[IDX].to(
                    device=encoder_hidden_states.device, dtype=encoder_hidden_states.dtype
                )
                encoder_hidden_states[:, :replace_dim, :] = (
                    encoder_hidden_states_load[:, :replace_dim, :] + encoder_hidden_states[:, :replace_dim, :]
                ) / 2
            IDX = IDX + 1

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class ReferenceOnlyAttnProc(torch.nn.Module):
    """Wrap an attention processor to share self-attention context with a reference pass.

    mode 'w' stores this layer's context in ``ref_dict``; 'r' pops it and appends it
    to the current context (and flags the call as reference self-attention); 'm'
    appends it without removing it.
    """

    def __init__(
        self,
        chained_proc,
        enabled=False,
        name=None
    ) -> None:
        super().__init__()
        self.enabled = enabled
        self.chained_proc = chained_proc
        self.name = name

    def __call__(
        self, attn: Attention,
        hidden_states, encoder_hidden_states=None, attention_mask=None,
        mode="w", ref_dict: dict = None, is_cfg_guidance = False
    ) -> Any:
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        is_self_attention = False
        if self.enabled and is_cfg_guidance:
            # The first (unconditional) sample bypasses reference sharing.
            res0 = self.chained_proc(attn, hidden_states[:1], encoder_hidden_states[:1], attention_mask)
            hidden_states = hidden_states[1:]
            encoder_hidden_states = encoder_hidden_states[1:]
        if self.enabled:
            if mode == 'w':
                ref_dict[self.name] = encoder_hidden_states
            elif mode == 'r':
                encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict.pop(self.name)], dim=1)
                is_self_attention = True
            elif mode == 'm':
                encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict[self.name]], dim=1)
            else:
                assert False, mode
        if isinstance(self.chained_proc, MyAttnProcessor2_0):
            res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask, is_self_attention=is_self_attention)
        else:
            # Stock diffusers processors (xformers / classic) in older releases reject
            # unknown keyword arguments, and they have no use for this flag anyway.
            res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask)
        if self.enabled and is_cfg_guidance:
            res = torch.cat([res0, res])
        return res


class RefOnlyNoisedUNet(torch.nn.Module):
    """UNet wrapper that first runs a noised condition latent to fill ``ref_dict``,
    then runs the real sample reading that reference context (Zero123++ style)."""

    def __init__(self, unet: UNet2DConditionModel, train_sched: DDPMScheduler, val_sched: EulerAncestralDiscreteScheduler) -> None:
        super().__init__()
        self.unet = unet
        self.train_sched = train_sched
        self.val_sched = val_sched

        unet_lora_attn_procs = dict()
        for name, _ in unet.attn_processors.items():
            if torch.__version__ >= '2.0':
                default_attn_proc = MyAttnProcessor2_0()
                print('using my attention')
            elif is_xformers_available():
                default_attn_proc = XFormersAttnProcessor()
            else:
                default_attn_proc = AttnProcessor()
            # Only self-attention layers ("attn1") take part in reference sharing.
            unet_lora_attn_procs[name] = ReferenceOnlyAttnProc(
                default_attn_proc, enabled=name.endswith("attn1.processor"), name=name
            )
        unet.set_attn_processor(unet_lora_attn_procs)

    def __getattr__(self, name: str):
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.unet, name)

    def forward_cond(self, noisy_cond_lat, timestep, encoder_hidden_states, class_labels, ref_dict, is_cfg_guidance, **kwargs):
        """Run the reference pass in write mode, populating ``ref_dict``."""
        if is_cfg_guidance:
            encoder_hidden_states = encoder_hidden_states[1:]
            class_labels = class_labels[1:]
        self.unet(
            noisy_cond_lat, timestep,
            encoder_hidden_states=encoder_hidden_states,
            class_labels=class_labels,
            cross_attention_kwargs=dict(mode="w", ref_dict=ref_dict),
            **kwargs
        )

    def forward(
        self, sample, timestep, encoder_hidden_states, class_labels=None,
        *args, cross_attention_kwargs,
        down_block_res_samples=None, mid_block_res_sample=None,
        **kwargs
    ):
        cond_lat = cross_attention_kwargs['cond_lat']
        is_cfg_guidance = cross_attention_kwargs.get('is_cfg_guidance', False)
        noise = torch.randn_like(cond_lat)
        if self.training:
            noisy_cond_lat = self.train_sched.add_noise(cond_lat, noise, timestep)
            noisy_cond_lat = self.train_sched.scale_model_input(noisy_cond_lat, timestep)
        else:
            noisy_cond_lat = self.val_sched.add_noise(cond_lat, noise, timestep.reshape(-1))
            noisy_cond_lat = self.val_sched.scale_model_input(noisy_cond_lat, timestep.reshape(-1))
        # NOTE: cross_attention_kwargs['cond_lat_back'], if present, is currently ignored.
        ref_dict = {}
        self.forward_cond(
            noisy_cond_lat, timestep,
            encoder_hidden_states, class_labels,
            ref_dict, is_cfg_guidance, **kwargs
        )
        weight_dtype = self.unet.dtype
        return self.unet(
            sample, timestep,
            encoder_hidden_states, *args,
            class_labels=class_labels,
            cross_attention_kwargs=dict(mode="r", ref_dict=ref_dict, is_cfg_guidance=is_cfg_guidance),
            down_block_additional_residuals=[
                sample.to(dtype=weight_dtype) for sample in down_block_res_samples
            ] if down_block_res_samples is not None else None,
            mid_block_additional_residual=(
                mid_block_res_sample.to(dtype=weight_dtype)
                if mid_block_res_sample is not None else None
            ),
            **kwargs
        )


def scale_latents(latents):
    """Map raw VAE latents into the pipeline's latent range."""
    latents = (latents - 0.22) * 0.75
    return latents


def unscale_latents(latents):
    """Inverse of ``scale_latents``."""
    latents = latents / 0.75 + 0.22
    return latents


def scale_image(image):
    """Scale image values by 0.5 / 0.8."""
    image = image * 0.5 / 0.8
    return image


def unscale_image(image):
    """Inverse of ``scale_image``."""
    image = image / 0.5 * 0.8
    return image


class DepthControlUNet(torch.nn.Module):
    def __init__(self, unet: RefOnlyNoisedUNet, controlnet: Optional[diffusers.ControlNetModel] = None, conditioning_scale=1.0) -> None:
        super().__init__()
        self.unet = unet
        if controlnet is None:
            self.controlnet = diffusers.ControlNetModel.from_unet(unet.unet)
        else:
            self.controlnet = controlnet
        DefaultAttnProc = MyAttnProcessor2_0
        if is_xformers_available():
            DefaultAttnProc = XFormersAttnProcessor
self.controlnet.set_attn_processor(DefaultAttnProc())
        self.conditioning_scale = conditioning_scale

    def __getattr__(self, name: str):
        # Fall back to the wrapped RefOnlyNoisedUNet for attributes not defined here.
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.unet, name)

    def forward(self, sample, timestep, encoder_hidden_states, class_labels=None, *args, cross_attention_kwargs: dict, **kwargs):
        """Run the ControlNet on ``control_depth`` and feed its residuals to the wrapped UNet.

        ``cross_attention_kwargs['control_depth']`` is required; it is removed from a
        copy of the dict before the rest is forwarded. ``class_labels``, ``*args`` and
        ``**kwargs`` are accepted but not forwarded to the wrapped UNet.
        """
        cross_attention_kwargs = dict(cross_attention_kwargs)
        control_depth = cross_attention_kwargs.pop('control_depth')
        down_block_res_samples, mid_block_res_sample = self.controlnet(
            sample,
            timestep,
            encoder_hidden_states=encoder_hidden_states,
            controlnet_cond=control_depth,
            conditioning_scale=self.conditioning_scale,
            return_dict=False,
        )
        return self.unet(
            sample,
            timestep,
            encoder_hidden_states=encoder_hidden_states,
            down_block_res_samples=down_block_res_samples,
            mid_block_res_sample=mid_block_res_sample,
            cross_attention_kwargs=cross_attention_kwargs
        )


class ModuleListDict(torch.nn.Module):
    """Register a dict of modules (sorted by key) as a ModuleList with key lookup."""

    def __init__(self, procs: dict) -> None:
        super().__init__()
        self.keys = sorted(procs.keys())
        self.values = torch.nn.ModuleList(procs[k] for k in self.keys)

    def __getitem__(self, key):
        # Linear search over the sorted key list.
        return self.values[self.keys.index(key)]


class SuperNet(torch.nn.Module):
    """Hold named modules in a ModuleList while keeping their original names in state dicts.

    State-dict hooks translate between ``layers.<i>.*`` (internal) and ``<name>.*``
    (external), so saved checkpoints use the original module names.
    """

    def __init__(self, state_dict: Dict[str, torch.nn.Module]):
        # Values must be nn.Modules: they are stored in a torch.nn.ModuleList.
        super().__init__()
        state_dict = OrderedDict((k, state_dict[k]) for k in sorted(state_dict.keys()))
        self.layers = torch.nn.ModuleList(state_dict.values())
        self.mapping = dict(enumerate(state_dict.keys()))
        self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}

        # .processor for unet, .self_attn for text encoder
        self.split_keys = [".processor", ".self_attn"]

        # we add a hook to state_dict() and load_state_dict() so that the
        # naming fits with `unet.attn_processors`
        def map_to(module, state_dict, *args, **kwargs):
            # layers.<i>.rest -> <original name>.rest
            new_state_dict = {}
            for key, value in state_dict.items():
                num = int(key.split(".")[1])  # 0 is always "layers"
                new_key = key.replace(f"layers.{num}", module.mapping[num])
                new_state_dict[new_key] = value

            return new_state_dict

        def remap_key(key, state_dict):
            # Prefix identifying which registered module `key` belongs to.
            for k in self.split_keys:
                if k in key:
                    return key.split(k)[0] + k
            return key.split('.')[0]

        def map_from(module, state_dict, *args, **kwargs):
            # <original name>.rest -> layers.<i>.rest, edited in place before loading.
            all_keys = list(state_dict.keys())
            for key in all_keys:
                replace_key = remap_key(key, state_dict)
                new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}")
                state_dict[new_key] = state_dict[key]
                del state_dict[key]

        self._register_state_dict_hook(map_to)
        self._register_load_state_dict_pre_hook(map_from, with_module=True)


class Zero123PlusPipeline(diffusers.StableDiffusionPipeline):
    """Image-conditioned StableDiffusion pipeline with reference-attention UNet.

    The conditioning image is VAE-encoded as a reference latent (consumed by
    RefOnlyNoisedUNet) and CLIP-encoded to modulate the prompt embeddings.
    """

    tokenizer: transformers.CLIPTokenizer
    text_encoder: transformers.CLIPTextModel
    vision_encoder: transformers.CLIPVisionModelWithProjection

    feature_extractor_clip: transformers.CLIPImageProcessor
    unet: UNet2DConditionModel
    scheduler: diffusers.schedulers.KarrasDiffusionSchedulers

    vae: AutoencoderKL
    ramping: nn.Linear

    feature_extractor_vae: transformers.CLIPImageProcessor

    # Depth image -> tensor normalized to [-1, 1].
    depth_transforms_multi = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        vision_encoder: transformers.CLIPVisionModelWithProjection,
        feature_extractor_clip: CLIPImageProcessor,
        feature_extractor_vae: CLIPImageProcessor,
        ramping_coefficients: Optional[list] = None,
        safety_checker=None,
    ):
        # Bypass StableDiffusionPipeline.__init__ and register only the modules used here.
        DiffusionPipeline.__init__(self)

        self.register_modules(
            vae=vae, text_encoder=text_encoder, tokenizer=tokenizer,
            unet=unet, scheduler=scheduler, safety_checker=None,
            vision_encoder=vision_encoder,
            feature_extractor_clip=feature_extractor_clip,
            feature_extractor_vae=feature_extractor_vae
        )
        self.register_to_config(ramping_coefficients=ramping_coefficients)
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

    def prepare(self):
        """Wrap a plain UNet in RefOnlyNoisedUNet (no-op if already wrapped)."""
        train_sched = DDPMScheduler.from_config(self.scheduler.config)
        if isinstance(self.unet, UNet2DConditionModel):
            self.unet = RefOnlyNoisedUNet(self.unet, train_sched, self.scheduler).eval()

    def add_controlnet(self, controlnet: Optional[diffusers.ControlNetModel] = None, conditioning_scale=1.0):
        """Wrap the UNet with a depth ControlNet; returns a SuperNet holding the ControlNet."""
        self.prepare()
        self.unet = DepthControlUNet(self.unet, controlnet, conditioning_scale)
        return SuperNet(OrderedDict([('controlnet', self.unet.controlnet)]))

    def encode_condition_image(self, image: torch.Tensor):
        """Return a sample from the VAE posterior for ``image`` (no scaling applied)."""
        image = self.vae.encode(image).latent_dist.sample()
        return image

    @torch.no_grad()
    def __call__(
        self,
        image: Image.Image = None,
        prompt = "",
        *args,
        num_images_per_prompt: Optional[int] = 1,
        guidance_scale=4.0,
        depth_image: Image.Image = None,
        output_type: Optional[str] = "pil",
        width=640,
        height=960,
        num_inference_steps=28,
        return_dict=True,
        is_first = False,
        **kwargs
    ):
        """Generate images conditioned on ``image``.

        Args:
            image: PIL conditioning image (RGBA is composited onto gray); a tensor is rejected.
            is_first: if True, clear the module-level EMBED cache and set FIRST so the
                attention processors record reference self-attention states; if False
                they blend in the states recorded by the last ``is_first=True`` call.
            output_type: "latent" returns unscaled latents; otherwise latents are
                decoded and post-processed by ``image_processor``.

        Notes:
            The generator passed to the base pipeline is a CUDA generator re-seeded
            with 42 on every call.

        Raises:
            ValueError: if ``image`` is None.
        """
        # Reset the cross-call attention state shared with MyAttnProcessor2_0.
        global FIRST
        FIRST = is_first
        global IDX
        IDX = 0
        if is_first:
            global EMBED
            EMBED=[]
        # Create a generator with the specified seed
        generator = torch.Generator(device='cuda')
        generator.manual_seed(42)
        self.prepare()
        if image is None:
            raise ValueError("Inputting embeddings not supported for this pipeline. Please pass an image.")
        assert not isinstance(image, torch.Tensor)
        image = to_rgb_image(image)
        image_1 = self.feature_extractor_vae(images=image, return_tensors="pt").pixel_values
        image_2 = self.feature_extractor_clip(images=image, return_tensors="pt").pixel_values
        if depth_image is not None and hasattr(self.unet, "controlnet"):
            depth_image = to_rgb_image(depth_image)
            depth_image = self.depth_transforms_multi(depth_image).to(
                device=self.unet.controlnet.device, dtype=self.unet.controlnet.dtype
            )
        image = image_1.to(device=self.vae.device, dtype=self.vae.dtype)
        image_2 = image_2.to(device=self.vae.device, dtype=self.vae.dtype)
        cond_lat = self.encode_condition_image(image)
        if guidance_scale > 1:
            # Unconditional reference latent (from an all-zero image) goes first for CFG.
            negative_lat = self.encode_condition_image(torch.zeros_like(image))
            cond_lat = torch.cat([negative_lat, cond_lat])
        encoded = self.vision_encoder(image_2, output_hidden_states=False)
        global_embeds = encoded.image_embeds
        global_embeds = global_embeds.unsqueeze(-2)

        encoder_hidden_states = self._encode_prompt(
            prompt,
            self.device,
            num_images_per_prompt,
            False
        )
        # Add the CLIP image embedding to each prompt token, weighted per token by the ramp.
        ramp = global_embeds.new_tensor(self.config.ramping_coefficients).unsqueeze(-1)
        encoder_hidden_states = encoder_hidden_states + global_embeds * ramp
        cak = dict(cond_lat=cond_lat)
        if hasattr(self.unet, "controlnet"):
            cak['control_depth'] = depth_image
        cak['cond_lat_back'] = None
        latents: torch.Tensor = super().__call__(
            None,
            *args,
            cross_attention_kwargs=cak,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_images_per_prompt,
            prompt_embeds=encoder_hidden_states,
            num_inference_steps=num_inference_steps,
            output_type='latent',
            width=width,
            height=height,
            generator=generator,
            **kwargs
        ).images
        latents = unscale_latents(latents)
        if not output_type == "latent":
            image = unscale_image(self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0])
        else:
            image = latents

        image = self.image_processor.postprocess(image, output_type=output_type)
        if not return_dict:
            return (image,)

        return ImagePipelineOutput(images=image)
================================================
FILE: main.py
================================================
import os
import cv2
import time
import tqdm
import numpy as np
import dearpygui.dearpygui as dpg

import torch
import torch.nn.functional as F
import torchvision.utils as vutils

from einops import rearrange, repeat

import imageio
import rembg

from cam_utils import orbit_camera, OrbitCamera
from gs_renderer_4d import Renderer, MiniCam
from dataset_4d import SparseDataset


def save_image_to_local(image_tensor, file_path):
    """Clamp ``image_tensor`` to [0, 1] and write it to ``file_path`` via torchvision."""
    # Ensure the image tensor is in the range [0, 1]
    image_tensor = image_tensor.clamp(0, 1)

    # Save the image tensor to the specified file path
    vutils.save_image(image_tensor, file_path)


class GUI:
    def __init__(self, opt):
        self.opt = opt  # shared with the trainer's opt to support in-place modification of rendering parameters.
        self.gui = opt.gui # enable gui
        self.W = opt.W
        self.H = opt.H
        self.cam = OrbitCamera(opt.W, opt.H, r=opt.radius, fovy=opt.fovy)

        self.mode = "image"
        self.seed = "random"

        self.buffer_image = np.ones((self.W, self.H, 3), dtype=np.float32)
        self.need_update = True  # update buffer_image

        # models
        self.device = torch.device("cuda")
        self.bg_remover = None

        self.guidance_sd = None
        self.guidance_zero123 = None

        self.enable_sd = False
        self.enable_zero123 = False

        # renderer
        self.renderer = Renderer(sh_degree=self.opt.sh_degree)
        self.gaussain_scale_factor = 1

        # input image
        self.input_img = None
        self.input_mask = None
        self.input_img_torch = None
        self.input_mask_torch = None
        self.overlay_input_img = False
        self.overlay_input_img_ratio = 0.5
        #self.use_depth = opt.use_depth
        # input text
        self.prompt = ""
        self.negative_prompt = ""

        # training stuff
        self.training = False
        self.optimizer = None
        self.step = 0
        self.t = 0
        self.time = 0
        self.train_steps = 1  # steps per rendering loop
        self.init = True
        self.stage = 'coarse'
        self.path = self.opt.path
        self.save_step = self.opt.save_step
        if self.opt.size is not None:
self.size = self.opt.size else: self.size = len(os.listdir(os.path.join(self.path,'ref'))) self.frames=self.size self.dataset = SparseDataset(self.opt, self.size, H=self.H, W=self.W, device=self.device) self.dataloader =self.dataset.dataloader() self.iter = iter(self.dataloader) self.ref_view_batch, self.input_mask_batch,self.zero123_view_batch,self.zero123_masks_batch = next(self.iter) self.input_img_torch_batch,self.input_mask_torch_batch,self.zero123plus_imgs_torch_batch,self.zero123plus_masks_torch_batch=[],[],[],[] # load input data from cmdline if self.opt.input is not None: self.load_input(self.opt.input) # override prompt from cmdline if self.opt.prompt is not None: self.prompt = self.opt.prompt # override if provide a checkpoint self.renderer.initialize(num_pts=self.opt.num_pts) self.point_nums = [] if self.gui: dpg.create_context() self.register_dpg() self.test_step() def __del__(self): if self.gui: dpg.destroy_context() def seed_everything(self): try: seed = int(self.seed) except: seed = np.random.randint(0, 1000000) os.environ["PYTHONHASHSEED"] = str(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = True self.last_seed = seed def prepare_image(self,idx): # input image if self.input_img is not None: self.input_img_torch = torch.from_numpy(self.input_img).permute(2, 0, 1).unsqueeze(0).to(self.device) self.input_img_torch = F.interpolate(self.input_img_torch, (self.opt.ref_size, self.opt.ref_size), mode="bilinear", align_corners=False) self.input_mask_torch = torch.from_numpy(self.input_mask).permute(2, 0, 1).unsqueeze(0).to(self.device) self.input_mask_torch = F.interpolate(self.input_mask_torch, (self.opt.ref_size, self.opt.ref_size), mode="bilinear", align_corners=False) self.zero123plus_imgs_torch=[] self.zero123plus_masks_torch=[] # input image if self.input_imgs is not None: for i in np.arange(6): #print(idx,i) 
self.input_img2_torch=(torch.from_numpy(self.input_imgs[i]).permute(2, 0, 1).unsqueeze(0).to(self.device)) self.zero123plus_imgs_torch.append(F.interpolate(self.input_img2_torch, (self.opt.ref_size, self.opt.ref_size), mode="bilinear", align_corners=False)) self.input_mask2_torch=torch.from_numpy(self.input_masks[i]).permute(2, 0, 1).unsqueeze(0).to(self.device) self.zero123plus_masks_torch.append(F.interpolate(self.input_mask2_torch, (self.opt.ref_size, self.opt.ref_size), mode="bilinear", align_corners=False)) self.input_img_torch_batch.append(self.input_img_torch) self.input_mask_torch_batch.append(self.input_mask_torch) self.zero123plus_imgs_torch_batch.append(self.zero123plus_imgs_torch) self.zero123plus_masks_torch_batch.append(self.zero123plus_masks_torch) # prepare embeddings with torch.no_grad(): self.guidance_zero123.get_img_embeds(self.input_img_torch, self.zero123plus_imgs_torch) def prepare_train(self): self.step = 0 self.end_step = self.save_step+1 ## given a load_path, load corresponding model if self.opt.load_path is not None: if self.opt.load_step is not None: self.step = self.opt.load_step else: #default loading save_step ply self.step = self.save_step auto_path = os.path.join(self.opt.outdir,self.opt.load_path + str(self.step)) ply_path = os.path.join(auto_path,'model.ply') self.renderer.gaussians.load_model(auto_path) self.renderer.gaussians.load_ply(ply_path) self.end_step =self.step+self.end_step ## setup training self.renderer.gaussians.training_setup(self.opt) ## do not do progressive sh-level self.renderer.gaussians.active_sh_degree = self.renderer.gaussians.max_sh_degree self.optimizer = self.renderer.gaussians.optimizer # default camera pose = orbit_camera(self.opt.elevation, 0, self.opt.radius) self.fixed_cam = MiniCam( pose, self.opt.ref_size, self.opt.ref_size, self.cam.fovy, self.cam.fovx, self.cam.near, self.cam.far, ) self.set_fix_cam() self.enable_sd = self.opt.lambda_sd > 0 and self.prompt != "" self.enable_zero123 = 
self.opt.lambda_zero123 > 0 and self.input_img is not None print(f"[INFO] loading zero123...") from guidance.zero123_4d_utils import Zero123 self.guidance_zero123 = Zero123(self.device) print(f"[INFO] loaded zero123!") ## load multiview reference images for i in np.arange(len(self.ref_view_batch)): self.input_img = self.ref_view_batch[i] self.input_mask = self.input_mask_batch[i] self.input_imgs = self.zero123_view_batch[i] self.input_masks = self.zero123_masks_batch[i] self.prepare_image(i) def train_step(self): starter = torch.cuda.Event(enable_timing=True) ender = torch.cuda.Event(enable_timing=True) starter.record() torch.autograd.set_detect_anomaly(True) for _ in range(self.train_steps): if self.stepself.opt.position_lr_max_steps: self.opt.position_lr_max_steps = self.opt.position_lr_max_steps2 self.step += 1 step_ratio = min(1, self.step / self.opt.iters) viewspace_point_tensor_list = [] radii_list = [] visibility_filter_list = [] # update lr self.renderer.gaussians.update_learning_rate(self.step) self.guidance_zero123.update_step(0,self.step) loss = 0 if self.step%self.opt.valid_interval == 0: self.save_renderings( 0, 0, 2 ,'front') self.save_renderings( 180, 0, 2 ,'back') render_resolution = 128 if step_ratio < 0.3 else (256 if step_ratio < 0.6 else 512) # avoid too large elevation (> 80 or < -80), and make sure it always cover [-30, 30] min_ver = max(min(-30, -30 - self.opt.elevation), -80 - self.opt.elevation) max_ver = min(max(30, 30 - self.opt.elevation), 80 - self.opt.elevation) for _ in np.arange(self.opt.batch_size): #sample time if self.init: self.t = self.frames//2 self.time = self.t/self.frames else: self.t = np.random.randint(0,self.frames) self.time = self.t/self.frames self.input_img_torch = self.input_img_torch_batch[self.t] self.input_mask_torch = self.input_mask_torch_batch[self.t] self.zero123plus_imgs_torch = self.zero123plus_imgs_torch_batch[self.t] self.zero123plus_masks_torch = self.zero123plus_masks_torch_batch[self.t] ## need to do 
rgb loss in the batch cur_cam = self.fixed_cam cur_cam.time=self.time out = self.renderer.render(cur_cam,stage=self.stage) viewspace_point_tensor, visibility_filter, radii = out["viewspace_points"], out["visibility_filter"], out["radii"] radii_list.append(radii.unsqueeze(0)) visibility_filter_list.append(visibility_filter.unsqueeze(0)) viewspace_point_tensor_list.append(viewspace_point_tensor) # rgb loss image = out["image"].unsqueeze(0) # [1, 3, H, W] in [0, 1] image_loss =step_ratio* 20000* F.mse_loss(image, self.input_img_torch) loss = loss + image_loss alpha = out["alpha"].unsqueeze(0) alpha_loss = step_ratio* 5000* F.mse_loss(alpha, self.input_mask_torch) loss = loss + alpha_loss images = [] poses = [] vers_plus, hors_plus, radii_plus = [], [], [] self.guidance_zero123.update_step(1,self.step) # render random view ver = np.random.randint(min_ver, max_ver) hor = np.random.randint(-180, 180) radius = 0 pose = orbit_camera(self.opt.elevation + ver, hor, self.opt.radius + radius) poses.append(pose) cur_cam = MiniCam( pose, render_resolution, render_resolution, self.cam.fovy, self.cam.fovx, self.cam.near, self.cam.far, ) cur_cam.time=self.time if hor<30 and hor>-30 or np.random.rand()>0.4: idx=None vers_plus.append(torch.tensor(ver,device=self.device).unsqueeze(dim=0)) hors_plus.append(torch.tensor(hor,device=self.device).unsqueeze(dim=0)) radii_plus.append(torch.tensor(radius,device=self.device).unsqueeze(dim=0)) elif hor>0: idx=hor//60 vers_plus.append(torch.tensor(ver-self.fixed_elevation[idx],device=self.device).unsqueeze(dim=0)) hors_plus.append(torch.tensor(hor-self.fixed_azimuth[idx],device=self.device).unsqueeze(dim=0)) radii_plus.append(torch.tensor(radius,device=self.device).unsqueeze(dim=0)) elif hor<0: idx = (360+hor)//60 vers_plus.append(torch.tensor(ver-self.fixed_elevation[idx],device=self.device).unsqueeze(dim=0)) hors_plus.append(torch.tensor(hor-self.fixed_azimuth[idx],device=self.device).unsqueeze(dim=0)) 
radii_plus.append(torch.tensor(radius,device=self.device).unsqueeze(dim=0)) bg_color = torch.tensor([1, 1, 1] if np.random.rand() > self.opt.invert_bg_prob else [0, 0, 0], dtype=torch.float32, device="cuda") out = self.renderer.render(cur_cam, bg_color=bg_color,stage=self.stage) viewspace_point_tensor, visibility_filter, radii_rendering = out["viewspace_points"], out["visibility_filter"], out["radii"] radii_list.append(radii_rendering.unsqueeze(0)) visibility_filter_list.append(visibility_filter.unsqueeze(0)) viewspace_point_tensor_list.append(viewspace_point_tensor) image = out["image"].unsqueeze(0)# [1, 3, H, W] in [0, 1] images.append(image) images_render = torch.cat(images, dim=0) #poses = torch.from_numpy(np.stack(poses, axis=0)).to(self.device) vers_batch = torch.cat(vers_plus, dim=0).cpu().numpy() hors_batch = torch.cat(hors_plus, dim=0).cpu().numpy() radii_batch = torch.cat(radii_plus, dim=0).cpu().numpy() # guidance loss # as we have different reference views, so each time we only pass 1 image into zero123 for guidance zero123_loss = self.opt.lambda_zero123 * self.guidance_zero123.train_step(images_render, vers_batch, hors_batch, radii_batch, step_ratio,idx=idx,t = self.t) loss = loss + zero123_loss # tv loss scales = out['scales'] tv_loss = self.renderer.gaussians.compute_regulation(self.opt.time_smoothness_weight, self.opt.plane_tv_weight, self.opt.l1_time_planes) loss += self.opt.lambda_tv * tv_loss # scale loss from physgaussian r = self.opt.scale_loss_ratio scale_loss = (torch.mean(torch.maximum(torch.max(scales,dim=1).values/ \ (torch.min(scales,dim=1).values+1e-8),\ torch.ones_like(torch.max(scales,dim=1).values)*r))-r) * scales.shape[0] loss += scale_loss # optimize step loss.backward() self.optimizer.step() self.optimizer.zero_grad() viewspace_point_tensor_grad = torch.zeros_like(viewspace_point_tensor) for idx in range(0, len(viewspace_point_tensor_list)): viewspace_point_tensor_grad = viewspace_point_tensor_grad + 
viewspace_point_tensor_list[idx].grad radii = torch.cat(radii_list,0).max(dim=0).values visibility_filter = torch.cat(visibility_filter_list).any(dim=0) if self.step >= self.opt.density_start_iter and self.step <= self.opt.density_end_iter: self.renderer.gaussians.max_radii2D[visibility_filter] = torch.max(self.renderer.gaussians.max_radii2D[visibility_filter], radii[visibility_filter]) self.renderer.gaussians.add_densification_stats(viewspace_point_tensor_grad, visibility_filter) if self.step % self.opt.densification_interval == 1 : self.renderer.gaussians.densify_and_prune(self.opt.densify_grad_threshold_percent, min_opacity=0.01, extent=1, max_screen_size=2) ender.record() torch.cuda.synchronize() t = starter.elapsed_time(ender) self.need_update = True if self.gui: dpg.set_value("_log_train_time", f"{t:.4f}ms") dpg.set_value( "_log_train_log", f"step = {self.step: 5d} (+{self.train_steps: 2d})\n loss = {loss.item():.4f}\nzero123_loss = {zero123_loss.item():.4f}image_loss ={image_loss.item():.4f}\nloss_alpha = {alpha_loss.item():.4f} scale_loss:{scale_loss.item():.4f} ", ) def set_fix_cam(self): self.fixed_cam_plus=[] self.fixed_elevation = [] self.fixed_azimuth = [] pose = orbit_camera(self.opt.elevation-30,30 , self.opt.radius) self.fixed_elevation.append(-30) self.fixed_azimuth.append(30) self.fixed_cam_plus.append(MiniCam( pose, self.opt.ref_size, self.opt.ref_size, self.cam.fovy, self.cam.fovx, self.cam.near, self.cam.far, )) pose = orbit_camera(self.opt.elevation+20, 90, self.opt.radius) self.fixed_elevation.append(20) self.fixed_azimuth.append(90) self.fixed_cam_plus.append(MiniCam( pose, self.opt.ref_size, self.opt.ref_size, self.cam.fovy, self.cam.fovx, self.cam.near, self.cam.far, )) pose = orbit_camera(self.opt.elevation-30, 150, self.opt.radius) self.fixed_elevation.append(-30) self.fixed_azimuth.append(150) self.fixed_cam_plus.append(MiniCam( pose, self.opt.ref_size, self.opt.ref_size, self.cam.fovy, self.cam.fovx, self.cam.near, self.cam.far, )) 
pose = orbit_camera(self.opt.elevation+20, 210, self.opt.radius) self.fixed_elevation.append(+20) self.fixed_azimuth.append(210) self.fixed_cam_plus.append(MiniCam( pose, self.opt.ref_size, self.opt.ref_size, self.cam.fovy, self.cam.fovx, self.cam.near, self.cam.far, )) pose = orbit_camera(self.opt.elevation-30, 270, self.opt.radius) self.fixed_elevation.append(-30) self.fixed_azimuth.append(270) self.fixed_cam_plus.append(MiniCam( pose, self.opt.ref_size, self.opt.ref_size, self.cam.fovy, self.cam.fovx, self.cam.near, self.cam.far, )) pose = orbit_camera(self.opt.elevation+20, 330, self.opt.radius) self.fixed_elevation.append(20) self.fixed_azimuth.append(330) self.fixed_cam_plus.append(MiniCam( pose, self.opt.ref_size, self.opt.ref_size, self.cam.fovy, self.cam.fovx, self.cam.near, self.cam.far, )) @torch.no_grad() def test_step(self): # ignore if no need to update if not self.need_update: return starter = torch.cuda.Event(enable_timing=True) ender = torch.cuda.Event(enable_timing=True) starter.record() # should update image if self.need_update: # render image cur_cam = MiniCam( self.cam.pose, self.W, self.H, self.cam.fovy, self.cam.fovx, self.cam.near, self.cam.far, time=self.time ) #print(cur_cam.time) out = self.renderer.render(cur_cam, self.gaussain_scale_factor,stage=self.stage) buffer_image = out[self.mode] # [3, H, W] if self.mode in ['depth', 'alpha']: buffer_image = buffer_image.repeat(3, 1, 1) if self.mode == 'depth': buffer_image = (buffer_image - buffer_image.min()) / (buffer_image.max() - buffer_image.min() + 1e-20) buffer_image = F.interpolate( buffer_image.unsqueeze(0), size=(self.H, self.W), mode="bilinear", align_corners=False, ).squeeze(0) self.buffer_image = ( buffer_image.permute(1, 2, 0) .contiguous() .clamp(0, 1) .contiguous() .detach() .cpu() .numpy() ) # display input_image if self.overlay_input_img and self.input_img is not None: self.buffer_image = ( self.buffer_image * (1 - self.overlay_input_img_ratio) + self.input_img * 
self.overlay_input_img_ratio ) self.need_update = False ender.record() torch.cuda.synchronize() t = starter.elapsed_time(ender) if self.gui: dpg.set_value("_log_infer_time", f"{t:.4f}ms ({int(1000/t)} FPS)") dpg.set_value( "_texture", self.buffer_image ) # buffer must be contiguous, else seg fault! def load_input(self, file): # load image pass # load image @torch.no_grad() def save_renderings(self, elev=0, azim=0, radius=2, name='front'): images=[] for i in np.arange(self.frames): pose = orbit_camera(elev, azim, radius) cam = MiniCam( pose, self.opt.ref_size, self.opt.ref_size, self.cam.fovy, self.cam.fovx, self.cam.near, self.cam.far, ) cam.time=float(i/self.frames) out = self.renderer.render(cam,stage=self.stage) image = out["image"].unsqueeze(0) images.append(image) os.makedirs(f'./valid/{self.opt.save_path}/{self.step}_{name}',exist_ok=True) save_image_to_local(image[0].detach(),f'./valid/{self.opt.save_path}/{self.step}_{name}/{str(i).zfill(2)}.jpg') samples=torch.cat(images,dim=0) vid = ( (rearrange(samples, "t c h w -> t h w c") * 255).clamp(0,255).detach() .cpu() .numpy() .astype(np.uint8) ) video_path = f'./valid/{self.opt.save_path}/{self.step}_{name}/video.mp4' imageio.mimwrite(video_path, vid) @torch.no_grad() def save_model(self, mode='geo', texture_size=1024): os.makedirs(self.opt.outdir, exist_ok=True) if mode == 'geo': path = os.path.join(self.opt.outdir, self.opt.save_path + '_model.ply') self.renderer.gaussians.save_ply(path) elif mode == 'geo+tex': path = os.path.join(self.opt.outdir, self.opt.save_path + '_model.ply') self.renderer.gaussians.save_ply(path) else: path = os.path.join(self.opt.outdir, self.opt.save_path + '_model.ply') self.renderer.gaussians.save_ply(path) print(f"[INFO] save model to {path}.") def register_dpg(self): ### register texture with dpg.texture_registry(show=False): dpg.add_raw_texture( self.W, self.H, self.buffer_image, format=dpg.mvFormat_Float_rgb, tag="_texture", ) ### register window # the rendered image, as the 
primary window with dpg.window( tag="_primary_window", width=self.W, height=self.H, pos=[0, 0], no_move=True, no_title_bar=True, no_scrollbar=True, ): # add the texture dpg.add_image("_texture") # dpg.set_primary_window("_primary_window", True) # control window with dpg.window( label="Control", tag="_control_window", width=600, height=self.H, pos=[self.W, 0], no_move=True, no_title_bar=True, ): # button theme with dpg.theme() as theme_button: with dpg.theme_component(dpg.mvButton): dpg.add_theme_color(dpg.mvThemeCol_Button, (23, 3, 18)) dpg.add_theme_color(dpg.mvThemeCol_ButtonHovered, (51, 3, 47)) dpg.add_theme_color(dpg.mvThemeCol_ButtonActive, (83, 18, 83)) dpg.add_theme_style(dpg.mvStyleVar_FrameRounding, 5) dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 3, 3) # timer stuff with dpg.group(horizontal=True): dpg.add_text("Infer time: ") dpg.add_text("no data", tag="_log_infer_time") def callback_setattr(sender, app_data, user_data): setattr(self, user_data, app_data) # init stuff with dpg.collapsing_header(label="Initialize", default_open=True): # seed stuff def callback_set_seed(sender, app_data): self.seed = app_data self.seed_everything() dpg.add_input_text( label="seed", default_value=self.seed, on_enter=True, callback=callback_set_seed, ) # input stuff def callback_select_input(sender, app_data): # only one item for k, v in app_data["selections"].items(): dpg.set_value("_log_input", k) self.load_input(v) self.need_update = True with dpg.file_dialog( directory_selector=False, show=False, callback=callback_select_input, file_count=1, tag="file_dialog_tag", width=700, height=400, ): dpg.add_file_extension("Images{.jpg,.jpeg,.png}") with dpg.group(horizontal=True): dpg.add_button( label="input", callback=lambda: dpg.show_item("file_dialog_tag"), ) dpg.add_text("", tag="_log_input") # overlay stuff with dpg.group(horizontal=True): def callback_toggle_overlay_input_img(sender, app_data): self.overlay_input_img = not self.overlay_input_img self.need_update = True 
dpg.add_checkbox( label="overlay image", default_value=self.overlay_input_img, callback=callback_toggle_overlay_input_img, ) def callback_set_overlay_input_img_ratio(sender, app_data): self.overlay_input_img_ratio = app_data self.need_update = True dpg.add_slider_float( label="ratio", min_value=0, max_value=1, format="%.1f", default_value=self.overlay_input_img_ratio, callback=callback_set_overlay_input_img_ratio, ) # prompt stuff dpg.add_input_text( label="prompt", default_value=self.prompt, callback=callback_setattr, user_data="prompt", ) dpg.add_input_text( label="negative", default_value=self.negative_prompt, callback=callback_setattr, user_data="negative_prompt", ) # save current model with dpg.group(horizontal=True): dpg.add_text("Save: ") def callback_save(sender, app_data, user_data): self.save_model(mode=user_data) dpg.add_button( label="model", tag="_button_save_model", callback=callback_save, user_data='model', ) dpg.bind_item_theme("_button_save_model", theme_button) dpg.add_button( label="geo", tag="_button_save_mesh", callback=callback_save, user_data='geo', ) dpg.bind_item_theme("_button_save_mesh", theme_button) dpg.add_button( label="geo+tex", tag="_button_save_mesh_with_tex", callback=callback_save, user_data='geo+tex', ) dpg.bind_item_theme("_button_save_mesh_with_tex", theme_button) dpg.add_input_text( label="", default_value=self.opt.save_path, callback=callback_setattr, user_data="save_path", ) # training stuff with dpg.collapsing_header(label="Train", default_open=True): # lr and train button with dpg.group(horizontal=True): dpg.add_text("Train: ") def callback_train(sender, app_data): if self.training: self.training = False dpg.configure_item("_button_train", label="start") else: self.prepare_train() self.training = True dpg.configure_item("_button_train", label="stop") # dpg.add_button( # label="init", tag="_button_init", callback=self.prepare_train # ) # dpg.bind_item_theme("_button_init", theme_button) dpg.add_button( label="start", 
tag="_button_train", callback=callback_train ) dpg.bind_item_theme("_button_train", theme_button) with dpg.group(horizontal=True): dpg.add_text("", tag="_log_train_time") dpg.add_text("", tag="_log_train_log") # rendering options with dpg.collapsing_header(label="Rendering", default_open=True): # mode combo def callback_change_mode(sender, app_data): self.mode = app_data self.need_update = True dpg.add_combo( ("image", "depth", "alpha"), label="mode", default_value=self.mode, callback=callback_change_mode, ) # fov slider def callback_set_fovy(sender, app_data): self.cam.fovy = np.deg2rad(app_data) self.need_update = True dpg.add_slider_int( label="FoV (vertical)", min_value=1, max_value=120, format="%d deg", default_value=np.rad2deg(self.cam.fovy), callback=callback_set_fovy, ) def callback_set_gaussain_scale(sender, app_data): self.gaussain_scale_factor = app_data self.need_update = True dpg.add_slider_float( label="gaussain scale", min_value=0, max_value=1, format="%.2f", default_value=self.gaussain_scale_factor, callback=callback_set_gaussain_scale, ) ### register camera handler def callback_camera_drag_rotate_or_draw_mask(sender, app_data): if not dpg.is_item_focused("_primary_window"): return dx = app_data[1] dy = app_data[2] self.cam.orbit(dx, dy) self.need_update = True def callback_camera_wheel_scale(sender, app_data): if not dpg.is_item_focused("_primary_window"): return delta = app_data self.cam.scale(delta) self.need_update = True def callback_camera_drag_pan(sender, app_data): if not dpg.is_item_focused("_primary_window"): return dx = app_data[1] dy = app_data[2] self.cam.pan(dx, dy) self.need_update = True def callback_set_mouse_loc(sender, app_data): if not dpg.is_item_focused("_primary_window"): return # just the pixel coordinate in image self.mouse_loc = np.array(app_data) with dpg.handler_registry(): # for camera moving dpg.add_mouse_drag_handler( button=dpg.mvMouseButton_Left, callback=callback_camera_drag_rotate_or_draw_mask, ) 
dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale) dpg.add_mouse_drag_handler( button=dpg.mvMouseButton_Middle, callback=callback_camera_drag_pan ) dpg.create_viewport( title="Gaussian3D", width=self.W + 600, height=self.H + (45 if os.name == "nt" else 0), resizable=False, ) ### global theme with dpg.theme() as theme_no_padding: with dpg.theme_component(dpg.mvAll): # set all padding to 0 to avoid scroll bar dpg.add_theme_style( dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core ) dpg.add_theme_style( dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core ) dpg.add_theme_style( dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core ) dpg.bind_item_theme("_primary_window", theme_no_padding) dpg.setup_dearpygui() ### register a larger font # get it from: https://github.com/lxgw/LxgwWenKai/releases/download/v1.300/LXGWWenKai-Regular.ttf if os.path.exists("LXGWWenKai-Regular.ttf"): with dpg.font_registry(): with dpg.font("LXGWWenKai-Regular.ttf", 18) as default_font: dpg.bind_font(default_font) # dpg.show_metrics() dpg.show_viewport() def render(self): assert self.gui while dpg.is_dearpygui_running(): # update texture every frame if self.training: self.train_step() self.test_step() dpg.render_dearpygui_frame() # no gui mode def train(self, iters=500): if iters > 0: self.prepare_train() for i in tqdm.trange(iters): self.train_step() # do a last prune #self.renderer.gaussians.prune(min_opacity=0.01, extent=1, max_screen_size=1) # save self.save_model(mode='model') self.save_model(mode='geo+tex') if __name__ == "__main__": import argparse from omegaconf import OmegaConf parser = argparse.ArgumentParser() parser.add_argument("--config", required=True, help="path to the yaml config file") args, extras = parser.parse_known_args() # override default config from cli opt = OmegaConf.merge(OmegaConf.load(args.config), OmegaConf.from_cli(extras)) gui = GUI(opt) if opt.gui: gui.render() else: gui.train(opt.save_step+1) 
================================================ FILE: mini_trainer.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from gs_renderer_4d import Renderer, MiniCam\n", "from dataset_4d import SparseDataset\n", "import os\n", "import tqdm\n", "import numpy as np\n", "import torch\n", "\n", "from cam_utils import orbit_camera, OrbitCamera\n", "from guidance.sd_utils import StableDiffusion\n", "\n", "\n", "class trainer:\n", " def __init__(self,opt) -> None:\n", " \n", " #initialize options\n", " self.opt=opt\n", " self.device=self.opt.device\n", " \n", " #initialize renderer and gaussians\n", " self.renderer = Renderer(sh_degree=self.opt.sh_degree)\n", " self.renderer.initialize(num_pts=self.opt.num_pts) \n", " self.renderer.gaussians.training_setup(self.opt)\n", " \n", " self.optimizer = self.renderer.gaussians.optimizer\n", " \n", " self.cam = OrbitCamera(opt.W, opt.H, r=opt.radius, fovy=opt.fovy)\n", " \n", " #initialize sd. 
replace with your own diffusion model if necessary.\n", " self.enable_sd = True\n", " self.guidance_sd = StableDiffusion(self.device)\n", " self.guidance_sd.get_text_embeds([self.opt.prompt],negative_prompts= [''])\n", " \n", " def save(self,save_path):\n", " #save \n", " auto_path = save_path\n", " os.makedirs(auto_path,exist_ok=True)\n", " ply_path = os.path.join(auto_path,'model.ply')\n", " self.renderer.gaussians.save_ply(ply_path)\n", " self.renderer.gaussians.save_deformation(auto_path)\n", " \n", " def load(self, load_path):\n", " #load\n", " auto_path = load_path\n", " ply_path = os.path.join(auto_path,'model.ply')\n", " self.renderer.gaussians.load_model(auto_path)\n", " self.renderer.gaussians.load_ply(ply_path)\n", " \n", " \n", " def render(self,frame_id, elevation, azimuth, radius):\n", " #render with parameters\n", " pose = orbit_camera(elevation,azimuth,radius)\n", " cam = MiniCam(\n", " pose,\n", " self.opt.ref_size,\n", " self.opt.ref_size,\n", " self.cam.fovy,\n", " self.cam.fovx,\n", " self.cam.near,\n", " self.cam.far,\n", " ) \n", " cam.time=float(frame_id/30) #30 is the total frame\n", " #use stage='coarse' for static rendering, use stage='fine' for dynamic rendering\n", " out = self.renderer.render(cam,stage='fine')\n", " image = out[\"image\"].unsqueeze(0)# [1, 3, H, W] in [0, 1]\n", " \n", " return image\n", " \n", " def train(self):\n", " self.step=0\n", " \n", " for i in tqdm.tqdm(range(10000)):\n", " self.step+=1\n", " self.renderer.gaussians.update_learning_rate(self.step)\n", " loss = 0\n", " \n", " min_ver = -30\n", " max_ver = 30\n", " vers, hors, radiis, poses = [], [], [], []\n", " images=[]\n", " viewspace_point_tensor_list, radii_list, visibility_filter_list = [], [], []\n", "\n", " render_resolution=512\n", " \n", " for _ in range(self.opt.batch_size):\n", " #sample time, vertical& horizontal angle\n", " ver = np.random.randint(min_ver, max_ver)\n", " hor = np.random.randint(-180, 180)\n", " radius=0\n", " self.t = 
np.random.randint(0,30)\n", " self.time = self.t/30\n", " \n", " vers.append(torch.tensor(self.opt.elevation + ver,device=self.device).unsqueeze(dim=0))\n", " hors.append(torch.tensor(hor,device=self.device).unsqueeze(dim=0))\n", " radiis.append(torch.tensor(self.opt.radius + radius,device=self.device).unsqueeze(dim=0))\n", " \n", " pose = orbit_camera(self.opt.elevation + ver, hor, self.opt.radius + radius)\n", " \n", " poses.append(pose)\n", "\n", "\n", " cur_cam = MiniCam(\n", " pose,\n", " render_resolution,\n", " render_resolution,\n", " self.cam.fovy,\n", " self.cam.fovx,\n", " self.cam.near,\n", " self.cam.far,\n", " )\n", " cur_cam.time=self.time\n", " \n", " bg_color = torch.tensor([1, 1, 1] if np.random.rand() > self.opt.invert_bg_prob else [0, 0, 0], dtype=torch.float32, device=\"cuda\")\n", " #use stage='coarse' for static rendering, use stage='fine' for dynamic rendering\n", " out = self.renderer.render(cur_cam, bg_color=bg_color,stage='fine')\n", " \n", " #basic values for densification\n", " viewspace_point_tensor, visibility_filter, radii = out[\"viewspace_points\"], out[\"visibility_filter\"], out[\"radii\"] \n", " radii_list.append(radii.unsqueeze(0))\n", " visibility_filter_list.append(visibility_filter.unsqueeze(0))\n", " viewspace_point_tensor_list.append(viewspace_point_tensor)\n", " \n", " image = out[\"image\"].unsqueeze(0)# [1, 3, H, W] in [0, 1]\n", " images.append(image)\n", " \n", " images_batch = torch.cat(images, dim=0)\n", " poses = torch.from_numpy(np.stack(poses, axis=0)).to(self.device)\n", " vers_batch = torch.cat(vers, dim=0).cpu().numpy()\n", " hors_batch = torch.cat(hors, dim=0).cpu().numpy()\n", " radii_batch = torch.cat(radiis, dim=0).cpu().numpy()\n", "\n", " if self.enable_sd:\n", " sd_loss = self.guidance_sd.train_step(images_batch,step_ratio=None,poses=poses)\n", " # guidance loss. 
replace with your own diffusion model if necessary.\n", " loss = loss + sd_loss\n", " else:\n", " zero123_loss = self.guidance_zero123.train_step(images_batch, vers_batch, hors_batch, radii_batch,step_ratio=None)\n", " # guidance loss.\n", " loss = loss + zero123_loss\n", " \n", " # optimize step\n", " loss.backward()\n", " self.optimizer.step()\n", " self.optimizer.zero_grad()\n", "\n", " #densifications. Adaptive densification is used here.\n", " viewspace_point_tensor_grad = torch.zeros_like(viewspace_point_tensor)\n", " for idx in range(0, len(viewspace_point_tensor_list)):\n", " viewspace_point_tensor_grad = viewspace_point_tensor_grad + viewspace_point_tensor_list[idx].grad\n", "\n", " if self.step >= self.opt.density_start_iter and self.step <= self.opt.density_end_iter:\n", " self.renderer.gaussians.max_radii2D[visibility_filter] = torch.max(self.renderer.gaussians.max_radii2D[visibility_filter], radii[visibility_filter])\n", " self.renderer.gaussians.add_densification_stats(viewspace_point_tensor_grad, visibility_filter)\n", " if self.step % self.opt.densification_interval == 0 :\n", "\n", " self.renderer.gaussians.densify_and_prune(self.opt.densify_grad_threshold, min_opacity=0.01, extent=1, max_screen_size=2)\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "feature_dim: 128\n", "Number of points at initialisation : 10000\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a4cb52b5dc0045ccba1d352fc337cbd0", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading pipeline components...: 0%| | 0/6 [00:00╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮\n", " in <module>:6 \n", " \n", " 3 opt=OmegaConf.load('./configs/image_4d_m.yaml') \n", " 4 \n", " 5 train=trainer(opt) \n", " 6 train.train() \n", " 7 \n", " \n", " in train:134 \n", " \n", " 131 │ │ │ │ loss = loss + zero123_loss 
\n", " 132 │ │ │ \n", " 133 │ │ │ # optimize step \n", " 134 │ │ │ loss.backward() \n", " 135 │ │ │ self.optimizer.step() \n", " 136 │ │ │ self.optimizer.zero_grad() \n", " 137 \n", " \n", " /home/vision/miniconda3/envs/torch0/lib/python3.8/site-packages/torch/_tensor.py:487 in backward \n", " \n", " 484 │ │ │ │ create_graph=create_graph, \n", " 485 │ │ │ │ inputs=inputs, \n", " 486 │ │ │ ) \n", " 487 │ │ torch.autograd.backward( \n", " 488 │ │ │ self, gradient, retain_graph, create_graph, inputs=inputs \n", " 489 │ │ ) \n", " 490 \n", " \n", " /home/vision/miniconda3/envs/torch0/lib/python3.8/site-packages/torch/autograd/__init__.py:200 \n", " in backward \n", " \n", " 197 # The reason we repeat same the comment below is that \n", " 198 # some Python versions print out the first line of a multi-line function \n", " 199 # calls in the traceback and some print out the last line \n", " 200 Variable._execution_engine.run_backward( # Calls into the C++ engine to run the bac \n", " 201 │ │ tensors, grad_tensors_, retain_graph, create_graph, inputs, \n", " 202 │ │ allow_unreachable=True, accumulate_grad=True) # Calls into the C++ engine to ru \n", " 203 \n", "╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\n", "KeyboardInterrupt\n", "\n" ], "text/plain": [ "\u001b[31m╭─\u001b[0m\u001b[31m──────────────────────────────\u001b[0m\u001b[31m \u001b[0m\u001b[1;31mTraceback \u001b[0m\u001b[1;2;31m(most recent call last)\u001b[0m\u001b[31m \u001b[0m\u001b[31m───────────────────────────────\u001b[0m\u001b[31m─╮\u001b[0m\n", "\u001b[31m│\u001b[0m in \u001b[92m\u001b[0m:\u001b[94m6\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m3 \u001b[0mopt=OmegaConf.load(\u001b[33m'\u001b[0m\u001b[33m./configs/image_4d_m.yaml\u001b[0m\u001b[33m'\u001b[0m) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m4 \u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m5 
\u001b[0mtrain=trainer(opt) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m6 train.train() \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m7 \u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m in \u001b[92mtrain\u001b[0m:\u001b[94m134\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m131 \u001b[0m\u001b[2m│ │ │ │ \u001b[0mloss = loss + zero123_loss \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m132 \u001b[0m\u001b[2m│ │ │ \u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m133 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[2m# optimize step\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m134 \u001b[2m│ │ │ \u001b[0mloss.backward() \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m135 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[96mself\u001b[0m.optimizer.step() \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m136 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[96mself\u001b[0m.optimizer.zero_grad() \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m137 \u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2;33m/home/vision/miniconda3/envs/torch0/lib/python3.8/site-packages/torch/\u001b[0m\u001b[1;33m_tensor.py\u001b[0m:\u001b[94m487\u001b[0m in \u001b[92mbackward\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m 484 \u001b[0m\u001b[2m│ │ │ │ \u001b[0mcreate_graph=create_graph, \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m 485 \u001b[0m\u001b[2m│ │ │ │ \u001b[0minputs=inputs, \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m 486 \u001b[0m\u001b[2m│ │ │ \u001b[0m) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m 487 \u001b[2m│ │ \u001b[0mtorch.autograd.backward( \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m 488 
\u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[96mself\u001b[0m, gradient, retain_graph, create_graph, inputs=inputs \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m 489 \u001b[0m\u001b[2m│ │ \u001b[0m) \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m 490 \u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2;33m/home/vision/miniconda3/envs/torch0/lib/python3.8/site-packages/torch/autograd/\u001b[0m\u001b[1;33m__init__.py\u001b[0m:\u001b[94m200\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m in \u001b[92mbackward\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m197 \u001b[0m\u001b[2m│ \u001b[0m\u001b[2m# The reason we repeat same the comment below is that\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m198 \u001b[0m\u001b[2m│ \u001b[0m\u001b[2m# some Python versions print out the first line of a multi-line function\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m199 \u001b[0m\u001b[2m│ \u001b[0m\u001b[2m# calls in the traceback and some print out the last line\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m200 \u001b[2m│ \u001b[0mVariable._execution_engine.run_backward( \u001b[2m# Calls into the C++ engine to run the bac\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m201 \u001b[0m\u001b[2m│ │ \u001b[0mtensors, grad_tensors_, retain_graph, create_graph, inputs, \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m202 \u001b[0m\u001b[2m│ │ \u001b[0mallow_unreachable=\u001b[94mTrue\u001b[0m, accumulate_grad=\u001b[94mTrue\u001b[0m) \u001b[2m# Calls into the C++ engine to ru\u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m│\u001b[0m \u001b[2m203 \u001b[0m \u001b[31m│\u001b[0m\n", "\u001b[31m╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n", "\u001b[1;91mKeyboardInterrupt\u001b[0m\n" ] }, 
"metadata": {}, "output_type": "display_data" } ], "source": [ "from omegaconf import OmegaConf\n", "\n", "opt=OmegaConf.load('./configs/image_4d_m.yaml')\n", "\n", "train=trainer(opt)\n", "train.train()" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.0" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import numpy as np\n", "np.deg2rad(0)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "torch0", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.17" } }, "nbformat": 4, "nbformat_minor": 2 } ================================================ FILE: requirements.txt ================================================ tqdm rich ninja numpy pandas scipy scikit-learn matplotlib opencv-python imageio imageio-ffmpeg omegaconf argparse torch einops plyfile pygltflib dearpygui accelerate rembg[gpu,cli] #zero123plus opencv-contrib-python diffusers==0.20.2 transformers==4.29.2 streamlit==1.22.0 altair<5 huggingface_hub git+https://github.com/facebookresearch/segment-anything.git gradio>=3.50 fire ================================================ FILE: scripts/app.py ================================================ import os import sys import numpy import torch import rembg import threading import urllib.request from PIL import Image import streamlit as st import huggingface_hub class SAMAPI: predictor = None @staticmethod @st.cache_resource def get_instance(sam_checkpoint=None): if SAMAPI.predictor is None: if sam_checkpoint is None: sam_checkpoint = "tmp/sam_vit_h_4b8939.pth" if not os.path.exists(sam_checkpoint): os.makedirs('tmp', exist_ok=True) 
urllib.request.urlretrieve( "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", sam_checkpoint ) device = "cuda:0" if torch.cuda.is_available() else "cpu" model_type = "default" from segment_anything import sam_model_registry, SamPredictor sam = sam_model_registry[model_type](checkpoint=sam_checkpoint) sam.to(device=device) predictor = SamPredictor(sam) SAMAPI.predictor = predictor return SAMAPI.predictor @staticmethod def segment_api(rgb, mask=None, bbox=None, sam_checkpoint=None): """ Parameters ---------- rgb : np.ndarray h,w,3 uint8 mask: np.ndarray h,w bool Returns ------- """ np = numpy predictor = SAMAPI.get_instance(sam_checkpoint) predictor.set_image(rgb) if mask is None and bbox is None: box_input = None else: # mask to bbox if bbox is None: y1, y2, x1, x2 = np.nonzero(mask)[0].min(), np.nonzero(mask)[0].max(), np.nonzero(mask)[1].min(), \ np.nonzero(mask)[1].max() else: x1, y1, x2, y2 = bbox box_input = np.array([[x1, y1, x2, y2]]) masks, scores, logits = predictor.predict( box=box_input, multimask_output=True, return_logits=False, ) mask = masks[-1] return mask def image_examples(samples, ncols, return_key=None, example_text="Examples"): global img_example_counter trigger = False with st.expander(example_text, True): for i in range(len(samples) // ncols): cols = st.columns(ncols) for j in range(ncols): idx = i * ncols + j if idx >= len(samples): continue entry = samples[idx] with cols[j]: st.image(entry['dispi']) img_example_counter += 1 with st.columns(5)[2]: this_trigger = st.button('\+', key='imgexuse%d' % img_example_counter) trigger = trigger or this_trigger if this_trigger: trigger = entry[return_key] return trigger def segment_img(img: Image): output = rembg.remove(img) mask = numpy.array(output)[:, :, 3] > 0 sam_mask = SAMAPI.segment_api(numpy.array(img)[:, :, :3], mask) segmented_img = Image.new("RGBA", img.size, (0, 0, 0, 0)) segmented_img.paste(img, mask=Image.fromarray(sam_mask)) return segmented_img def 
segment_6imgs(zero123pp_imgs): imgs = [zero123pp_imgs.crop([0, 0, 320, 320]), zero123pp_imgs.crop([320, 0, 640, 320]), zero123pp_imgs.crop([0, 320, 320, 640]), zero123pp_imgs.crop([320, 320, 640, 640]), zero123pp_imgs.crop([0, 640, 320, 960]), zero123pp_imgs.crop([320, 640, 640, 960])] segmented_imgs = [] for i, img in enumerate(imgs): output = rembg.remove(img) mask = numpy.array(output)[:, :, 3] mask = SAMAPI.segment_api(numpy.array(img)[:, :, :3], mask) data = numpy.array(img)[:,:,:3] data[mask == 0] = [255, 255, 255] segmented_imgs.append(data) result = numpy.concatenate([ numpy.concatenate([segmented_imgs[0], segmented_imgs[1]], axis=1), numpy.concatenate([segmented_imgs[2], segmented_imgs[3]], axis=1), numpy.concatenate([segmented_imgs[4], segmented_imgs[5]], axis=1) ]) return Image.fromarray(result) def expand2square(pil_img, background_color): width, height = pil_img.size if width == height: return pil_img elif width > height: result = Image.new(pil_img.mode, (width, width), background_color) result.paste(pil_img, (0, (width - height) // 2)) return result else: result = Image.new(pil_img.mode, (height, height), background_color) result.paste(pil_img, ((height - width) // 2, 0)) return result @st.cache_data def check_dependencies(): reqs = [] try: import diffusers except ImportError: import traceback traceback.print_exc() print("Error: `diffusers` not found.", file=sys.stderr) reqs.append("diffusers==0.20.2") else: if not diffusers.__version__.startswith("0.20"): print( f"Warning: You are using an unsupported version of diffusers ({diffusers.__version__}), which may lead to performance issues.", file=sys.stderr ) print("Recommended version is `diffusers==0.20.2`.", file=sys.stderr) try: import transformers except ImportError: import traceback traceback.print_exc() print("Error: `transformers` not found.", file=sys.stderr) reqs.append("transformers==4.29.2") if torch.__version__ < '2.0': try: import xformers except ImportError: print("Warning: You are using 
PyTorch 1.x without a working `xformers` installation.", file=sys.stderr) print("You may see a significant memory overhead when running the model.", file=sys.stderr) if len(reqs): print(f"Info: Fix all dependency errors with `pip install {' '.join(reqs)}`.") @st.cache_resource def load_zero123plus_pipeline(): if 'HF_TOKEN' in os.environ: huggingface_hub.login(os.environ['HF_TOKEN']) from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler pipeline = DiffusionPipeline.from_pretrained( "sudo-ai/zero123plus-v1.1", custom_pipeline="sudo-ai/zero123plus-pipeline", torch_dtype=torch.float16 ) # Feel free to tune the scheduler pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config( pipeline.scheduler.config, timestep_spacing='trailing' ) if torch.cuda.is_available(): pipeline.to('cuda:0') sys.main_lock = threading.Lock() return pipeline ================================================ FILE: scripts/gen_mv.py ================================================ import torch import requests from PIL import Image from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler import numpy import os import sys import numpy import torch import rembg import threading import urllib.request from PIL import Image import streamlit as st import huggingface_hub from app import SAMAPI def segment_img(img: Image): output = rembg.remove(img) mask = numpy.array(output)[:, :, 3] > 0 sam_mask = SAMAPI.segment_api(numpy.array(img)[:, :, :3], mask) segmented_img = Image.new("RGBA", img.size, (0, 0, 0, 0)) segmented_img.paste(img, mask=Image.fromarray(sam_mask)) return segmented_img def segment_6imgs(zero123pp_imgs): imgs = [zero123pp_imgs.crop([0, 0, 320, 320]), zero123pp_imgs.crop([320, 0, 640, 320]), zero123pp_imgs.crop([0, 320, 320, 640]), zero123pp_imgs.crop([320, 320, 640, 640]), zero123pp_imgs.crop([0, 640, 320, 960]), zero123pp_imgs.crop([320, 640, 640, 960])] segmented_imgs = [] import numpy as np for i, img in enumerate(imgs): output = rembg.remove(img) 
mask = numpy.array(output)[:, :, 3] mask = SAMAPI.segment_api(numpy.array(img)[:, :, :3], mask) data = numpy.array(img)[:,:,:3] data2 = numpy.ones([320,320,4]) data2[:,:,:3] = data for i in np.arange(data2.shape[0]): for j in np.arange(data2.shape[1]): if mask[i,j]==1: data2[i,j,3]=255 segmented_imgs.append(data2) #torch.manual_seed(42) return segmented_imgs def process_img(path,destination,pipeline, is_first): # Download an example image. print('processing:',path) #cond_whole = Image.open('output.png') cond = Image.open(path) # Run the pipeline! result = pipeline(cond, num_inference_steps=75,is_first = is_first).images[0] # for general real and synthetic images of general objects # usually it is enough to have around 28 inference steps # for images with delicate details like faces (real or anime) # you may need 75-100 steps for the details to construct #result.show() #result.save("./test_png/zero123pp/output.png") result=segment_6imgs(result) print('saving:',os.path.join(destination,'0~5.png'),'in',destination) for i in numpy.arange(6): Image.fromarray(numpy.uint8(result[i])).save(os.path.join(destination,'{}.png'.format(i))) if __name__ == "__main__": import argparse parser = argparse.ArgumentParser() parser.add_argument("--path", required=True, help="path to process") # DiffusionPipeline.from_pretrained cannot received relative path for custom pipeline parser.add_argument("--pipeline_path", required=True, help="path of pipeline code, in ../guidance/zero123pp") args, extras = parser.parse_known_args() pipeline = DiffusionPipeline.from_pretrained( "sudo-ai/zero123plus-v1.1", custom_pipeline=args.pipeline_path, torch_dtype=torch.float16 ) pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config( pipeline.scheduler.config, timestep_spacing='trailing' ) pipeline.to('cuda:0') directory = args.path+'/' os.makedirs(directory+'ref', exist_ok=True) os.system(f"cp -r {directory+'*.png'} {directory+'ref/'}") is_first = True l=sorted(os.listdir(directory+'ref')) for 
file in sorted(os.listdir(directory+'ref')): if file[-4:-1]=='.pn': filename = os.path.splitext(os.path.basename(file))[0] destination = os.path.join(directory+'zero123',filename) os.makedirs(destination, exist_ok=True) img_path = os.path.join(directory+'ref',file) process_img(img_path,destination,pipeline, is_first) is_first = False ================================================ FILE: sh_utils.py ================================================ # Copyright 2021 The PlenOctree Authors. # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are met: # # 1. Redistributions of source code must retain the above copyright notice, # this list of conditions and the following disclaimer. # # 2. Redistributions in binary form must reproduce the above copyright notice, # this list of conditions and the following disclaimer in the documentation # and/or other materials provided with the distribution. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # POSSIBILITY OF SUCH DAMAGE. 
import torch C0 = 0.28209479177387814 C1 = 0.4886025119029199 C2 = [ 1.0925484305920792, -1.0925484305920792, 0.31539156525252005, -1.0925484305920792, 0.5462742152960396 ] C3 = [ -0.5900435899266435, 2.890611442640554, -0.4570457994644658, 0.3731763325901154, -0.4570457994644658, 1.445305721320277, -0.5900435899266435 ] C4 = [ 2.5033429417967046, -1.7701307697799304, 0.9461746957575601, -0.6690465435572892, 0.10578554691520431, -0.6690465435572892, 0.47308734787878004, -1.7701307697799304, 0.6258357354491761, ] def eval_sh(deg, sh, dirs): """ Evaluate spherical harmonics at unit directions using hardcoded SH polynomials. Works with torch/np/jnp. ... Can be 0 or more batch dimensions. Args: deg: int SH deg. Currently, 0-3 supported sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2] dirs: jnp.ndarray unit directions [..., 3] Returns: [..., C] """ assert deg <= 4 and deg >= 0 coeff = (deg + 1) ** 2 assert sh.shape[-1] >= coeff result = C0 * sh[..., 0] if deg > 0: x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] result = (result - C1 * y * sh[..., 1] + C1 * z * sh[..., 2] - C1 * x * sh[..., 3]) if deg > 1: xx, yy, zz = x * x, y * y, z * z xy, yz, xz = x * y, y * z, x * z result = (result + C2[0] * xy * sh[..., 4] + C2[1] * yz * sh[..., 5] + C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + C2[3] * xz * sh[..., 7] + C2[4] * (xx - yy) * sh[..., 8]) if deg > 2: result = (result + C3[0] * y * (3 * xx - yy) * sh[..., 9] + C3[1] * xy * z * sh[..., 10] + C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] + C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + C3[5] * z * (xx - yy) * sh[..., 14] + C3[6] * x * (xx - 3 * yy) * sh[..., 15]) if deg > 3: result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + C4[1] * yz * (3 * xx - yy) * sh[..., 17] + C4[2] * xy * (7 * zz - 1) * sh[..., 18] + C4[3] * yz * (7 * zz - 3) * sh[..., 19] + C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + C4[5] * xz * (7 * zz - 3) * sh[..., 21] + C4[6] * (xx 
- yy) * (7 * zz - 1) * sh[..., 22] + C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) return result def RGB2SH(rgb): return (rgb - 0.5) / C0 def SH2RGB(sh): return sh * C0 + 0.5 ================================================ FILE: simple-knn/ext.cpp ================================================ /* * Copyright (C) 2023, Inria * GRAPHDECO research group, https://team.inria.fr/graphdeco * All rights reserved. * * This software is free for non-commercial, research and evaluation use * under the terms of the LICENSE.md file. * * For inquiries contact george.drettakis@inria.fr */ #include #include "spatial.h" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("distCUDA2", &distCUDA2); } ================================================ FILE: simple-knn/setup.py ================================================ # # Copyright (C) 2023, Inria # GRAPHDECO research group, https://team.inria.fr/graphdeco # All rights reserved. # # This software is free for non-commercial, research and evaluation use # under the terms of the LICENSE.md file. # # For inquiries contact george.drettakis@inria.fr # from setuptools import setup from torch.utils.cpp_extension import CUDAExtension, BuildExtension import os cxx_compiler_flags = [] if os.name == 'nt': cxx_compiler_flags.append("/wd4624") setup( name="simple_knn", ext_modules=[ CUDAExtension( name="simple_knn._C", sources=[ "spatial.cu", "simple_knn.cu", "ext.cpp"], extra_compile_args={"nvcc": [], "cxx": cxx_compiler_flags}) ], cmdclass={ 'build_ext': BuildExtension } ) ================================================ FILE: simple-knn/simple_knn/.gitkeep ================================================ ================================================ FILE: simple-knn/simple_knn.cu ================================================ /* * Copyright (C) 2023, Inria * GRAPHDECO research group, https://team.inria.fr/graphdeco * All rights reserved. 
*
 * This software is free for non-commercial, research and evaluation use
 * under the terms of the LICENSE.md file.
 *
 * For inquiries contact george.drettakis@inria.fr
 */

#define BOX_SIZE 1024

#include "cuda_runtime.h"
#include "device_launch_parameters.h"
#include "simple_knn.h"
#include <cfloat>   // FLT_MAX; not pulled in transitively by every CUDA/host toolchain
#include <cub/cub.cuh>
#include <cub/device/device_radix_sort.cuh>
#include <vector>
#include <cuda_runtime_api.h>
#include <thrust/device_vector.h>
#include <thrust/sequence.h>
#define __CUDACC__
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>

namespace cg = cooperative_groups;

// Component-wise minimum of two float3 values (reduction operator for cub).
struct CustomMin
{
    __device__ __forceinline__
    float3 operator()(const float3& a, const float3& b) const {
        return { min(a.x, b.x), min(a.y, b.y), min(a.z, b.z) };
    }
};

// Component-wise maximum of two float3 values (reduction operator for cub).
struct CustomMax
{
    __device__ __forceinline__
    float3 operator()(const float3& a, const float3& b) const {
        return { max(a.x, b.x), max(a.y, b.y), max(a.z, b.z) };
    }
};

// Spread the low 10 bits of x so that consecutive input bits land 3 bits
// apart (final mask 0x09249249), ready to be interleaved with two other axes.
__host__ __device__ uint32_t prepMorton(uint32_t x)
{
    x = (x | (x << 16)) & 0x030000FF;
    x = (x | (x << 8)) & 0x0300F00F;
    x = (x | (x << 4)) & 0x030C30C3;
    x = (x | (x << 2)) & 0x09249249;
    return x;
}

// 30-bit Morton code of coord: each axis is normalized into [minn, maxx],
// quantized to 10 bits, and the three axes are bit-interleaved (x, y, z).
__host__ __device__ uint32_t coord2Morton(float3 coord, float3 minn, float3 maxx)
{
    uint32_t x = prepMorton(((coord.x - minn.x) / (maxx.x - minn.x)) * ((1 << 10) - 1));
    uint32_t y = prepMorton(((coord.y - minn.y) / (maxx.y - minn.y)) * ((1 << 10) - 1));
    uint32_t z = prepMorton(((coord.z - minn.z) / (maxx.z - minn.z)) * ((1 << 10) - 1));

    return x | (y << 1) | (z << 2);
}

// Kernel: one thread per point, writes codes[idx] = Morton code of points[idx].
__global__ void coord2Morton(int P, const float3* points, float3 minn, float3 maxx, uint32_t* codes)
{
    auto idx = cg::this_grid().thread_rank();
    if (idx >= P)
        return;

    codes[idx] = coord2Morton(points[idx], minn, maxx);
}

// Axis-aligned bounding box.
struct MinMax
{
    float3 minn;
    float3 maxx;
};

// Kernel: each block computes the bounding box of BOX_SIZE consecutive points
// in `indices` order via a shared-memory tree reduction. Requires a launch
// with blockDim.x == BOX_SIZE; threads past P contribute an empty box.
__global__ void boxMinMax(uint32_t P, float3* points, uint32_t* indices, MinMax* boxes)
{
    auto idx = cg::this_grid().thread_rank();

    MinMax me;
    if (idx < P)
    {
        me.minn = points[indices[idx]];
        me.maxx = points[indices[idx]];
    }
    else
    {
        me.minn = { FLT_MAX, FLT_MAX, FLT_MAX };
        me.maxx = { -FLT_MAX, -FLT_MAX, -FLT_MAX };
    }

    __shared__ MinMax redResult[BOX_SIZE];

    for (int off = BOX_SIZE / 2; off >= 1; off /= 2)
    {
        if (threadIdx.x < 2 * off)
            redResult[threadIdx.x] = me;
        __syncthreads();

        if (threadIdx.x < off)
        {
            MinMax other = redResult[threadIdx.x + off];
            me.minn.x = min(me.minn.x, other.minn.x);
            me.minn.y = min(me.minn.y, other.minn.y);
            me.minn.z = min(me.minn.z, other.minn.z);
            me.maxx.x = max(me.maxx.x, other.maxx.x);
            me.maxx.y = max(me.maxx.y, other.maxx.y);
            me.maxx.z = max(me.maxx.z, other.maxx.z);
        }
        __syncthreads();
    }

    if (threadIdx.x == 0)
        boxes[blockIdx.x] = me;
}

// Squared distance from p to the nearest face of box along each axis on which
// p lies outside the box (0 contribution on axes where p is inside).
__device__ __host__ float distBoxPoint(const MinMax& box, const float3& p)
{
    float3 diff = { 0, 0, 0 };
    if (p.x < box.minn.x || p.x > box.maxx.x)
        diff.x = min(abs(p.x - box.minn.x), abs(p.x - box.maxx.x));
    if (p.y < box.minn.y || p.y > box.maxx.y)
        diff.y = min(abs(p.y - box.minn.y), abs(p.y - box.maxx.y));
    if (p.z < box.minn.z || p.z > box.maxx.z)
        diff.z = min(abs(p.z - box.minn.z), abs(p.z - box.maxx.z));
    return diff.x * diff.x + diff.y * diff.y + diff.z * diff.z;
}

// Insert the squared distance ref->point into knn[0..K), which is kept in
// ascending order; the largest value falls off the end.
template<int K>
__device__ void updateKBest(const float3& ref, const float3& point, float* knn)
{
    float3 d = { point.x - ref.x, point.y - ref.y, point.z - ref.z };
    float dist = d.x * d.x + d.y * d.y + d.z * d.z;
    for (int j = 0; j < K; j++)
    {
        if (knn[j] > dist)
        {
            float t = knn[j];
            knn[j] = dist;
            dist = t;
        }
    }
}

// Kernel: for each point, mean of the squared distances to its 3 nearest
// neighbours. The +-3 Morton-order neighbours give an initial bound
// (`reject`); boxes farther than that bound (or than the current 3rd best)
// are skipped. The result is written at the point's original index.
__global__ void boxMeanDist(uint32_t P, float3* points, uint32_t* indices, MinMax* boxes, float* dists)
{
    int idx = cg::this_grid().thread_rank();
    if (idx >= P)
        return;

    float3 point = points[indices[idx]];
    float best[3] = { FLT_MAX, FLT_MAX, FLT_MAX };

    for (int i = max(0, idx - 3); i <= min(P - 1, idx + 3); i++)
    {
        if (i == idx)
            continue;
        updateKBest<3>(point, points[indices[i]], best);
    }

    float reject = best[2];
    best[0] = FLT_MAX;
    best[1] = FLT_MAX;
    best[2] = FLT_MAX;

    for (int b = 0; b < (P + BOX_SIZE - 1) / BOX_SIZE; b++)
    {
        MinMax box = boxes[b];
        float dist = distBoxPoint(box, point);
        if (dist > reject || dist > best[2])
            continue;

        for (int i = b * BOX_SIZE; i < min(P, (b + 1) * BOX_SIZE); i++)
        {
            if (i == idx)
                continue;
            updateKBest<3>(point, points[indices[i]], best);
        }
    }
    dists[indices[idx]] = (best[0] + best[1] + best[2]) / 3.0f;
}

// Host entry: writes into meanDists[i] the mean squared distance from
// points[i] to its 3 nearest neighbours (points and meanDists are device
// pointers with P elements).
void SimpleKNN::knn(int P, float3* points, float* meanDists)
{
    float3* result;
    cudaMalloc(&result, sizeof(float3));
    size_t temp_storage_bytes;

    // NOTE(review): the reduction seed is {0,0,0} rather than +-FLT_MAX, so the
    // computed bounds always include the origin. Kept as-is to preserve the
    // produced Morton codes (and it avoids a zero-extent box for planar input).
    float3 init = { 0, 0, 0 }, minn, maxx;

    // First call only sizes the temporary storage (nullptr buffer).
    cub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, points, result, P, CustomMin(), init);
    thrust::device_vector<char> temp_storage(temp_storage_bytes);

    cub::DeviceReduce::Reduce(temp_storage.data().get(), temp_storage_bytes, points, result, P, CustomMin(), init);
    cudaMemcpy(&minn, result, sizeof(float3), cudaMemcpyDeviceToHost);

    cub::DeviceReduce::Reduce(temp_storage.data().get(), temp_storage_bytes, points, result, P, CustomMax(), init);
    cudaMemcpy(&maxx, result, sizeof(float3), cudaMemcpyDeviceToHost);

    thrust::device_vector<uint32_t> morton(P);
    thrust::device_vector<uint32_t> morton_sorted(P);
    coord2Morton << <(P + 255) / 256, 256 >> > (P, points, minn, maxx, morton.data().get());

    thrust::device_vector<uint32_t> indices(P);
    thrust::sequence(indices.begin(), indices.end());
    thrust::device_vector<uint32_t> indices_sorted(P);

    // Sort point indices by Morton code so spatially close points are adjacent.
    cub::DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, morton.data().get(), morton_sorted.data().get(), indices.data().get(), indices_sorted.data().get(), P);
    temp_storage.resize(temp_storage_bytes);

    cub::DeviceRadixSort::SortPairs(temp_storage.data().get(), temp_storage_bytes, morton.data().get(), morton_sorted.data().get(), indices.data().get(), indices_sorted.data().get(), P);

    uint32_t num_boxes = (P + BOX_SIZE - 1) / BOX_SIZE;
    thrust::device_vector<MinMax> boxes(num_boxes);
    boxMinMax << <num_boxes, BOX_SIZE >> > (P, points, indices_sorted.data().get(), boxes.data().get());
    boxMeanDist << <num_boxes, BOX_SIZE >> > (P, points, indices_sorted.data().get(), boxes.data().get(), meanDists);

    cudaFree(result);
}
================================================
FILE: simple-knn/simple_knn.h
================================================
/*
 * Copyright (C) 2023, Inria
 * GRAPHDECO research group, https://team.inria.fr/graphdeco
 * All rights reserved.
* * This software is free for non-commercial, research and evaluation use * under the terms of the LICENSE.md file. * * For inquiries contact george.drettakis@inria.fr */ #ifndef SIMPLEKNN_H_INCLUDED #define SIMPLEKNN_H_INCLUDED class SimpleKNN { public: static void knn(int P, float3* points, float* meanDists); }; #endif ================================================ FILE: simple-knn/spatial.cu ================================================ /* * Copyright (C) 2023, Inria * GRAPHDECO research group, https://team.inria.fr/graphdeco * All rights reserved. * * This software is free for non-commercial, research and evaluation use * under the terms of the LICENSE.md file. * * For inquiries contact george.drettakis@inria.fr */ #include "spatial.h" #include "simple_knn.h" torch::Tensor distCUDA2(const torch::Tensor& points) { const int P = points.size(0); auto float_opts = points.options().dtype(torch::kFloat32); torch::Tensor means = torch::full({P}, 0.0, float_opts); SimpleKNN::knn(P, (float3*)points.contiguous().data(), means.contiguous().data()); return means; } ================================================ FILE: simple-knn/spatial.h ================================================ /* * Copyright (C) 2023, Inria * GRAPHDECO research group, https://team.inria.fr/graphdeco * All rights reserved. * * This software is free for non-commercial, research and evaluation use * under the terms of the LICENSE.md file. 
* * For inquiries contact george.drettakis@inria.fr */ #include torch::Tensor distCUDA2(const torch::Tensor& points); ================================================ FILE: visualize.py ================================================ import os import cv2 import time import tqdm import numpy as np import dearpygui.dearpygui as dpg import torch import torch.nn.functional as F import rembg import numpy as np import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3D from cam_utils import orbit_camera, OrbitCamera from gs_renderer_4d import Renderer, MiniCam from dataset_4d import SparseDataset from einops import rearrange, repeat import torchvision.utils as vutils import imageio def save_image_to_local(image_tensor, file_path): # Ensure the image tensor is in the range [0, 1] image_tensor = image_tensor.clamp(0, 1) # Save the image tensor to the specified file path vutils.save_image(image_tensor, file_path) class GUI: def __init__(self, opt): self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters. 
self.gui = opt.gui # enable gui self.W = opt.W self.H = opt.H self.cam = OrbitCamera(opt.W, opt.H, r=opt.radius, fovy=opt.fovy) self.mode = "image" self.seed = "random" self.buffer_image = np.ones((self.W, self.H, 3), dtype=np.float32) self.need_update = True # update buffer_image # models self.device = torch.device("cuda") self.bg_remover = None self.guidance_sd = None self.guidance_zero123 = None self.enable_sd = False self.enable_zero123 = False # renderer self.renderer = Renderer(sh_degree=self.opt.sh_degree) self.gaussain_scale_factor = 1 # input image self.input_img = None self.input_mask = None self.input_img_torch = None self.input_mask_torch = None self.overlay_input_img = False self.overlay_input_img_ratio = 0.5 # input text self.prompt = "" self.negative_prompt = "" # training stuff self.training = False self.optimizer = None self.step = 0 self.t = 0 self.time =0 self.train_steps = 1 # steps per rendering loop self.path =self.opt.path if self.opt.size is not None: self.size = self.opt.size else: self.size = len(os.listdir(os.path.join(self.path,'ref'))) self.frames=self.size # override if provide a checkpoint self.renderer.initialize(num_pts=5000) if self.gui: dpg.create_context() self.register_dpg() self.test_step() def __del__(self): if self.gui: dpg.destroy_context() def seed_everything(self): try: seed = int(self.seed) except: seed = np.random.randint(0, 1000000) os.environ["PYTHONHASHSEED"] = str(seed) np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = True self.last_seed = seed # prepare embeddings #save_image_to_local(self.input_img_torch[0],'./valild2/ref_{}.jpg'.format(idx)) #save_image_to_local(self.input_img_torch_batch[idx][0],'./valild2/batch_{}.jpg'.format(idx)) #save_image_to_local(self.input_imgs_torch[0][0].detach(),'./valild2/ref0_{}.jpg'.format(idx)) def prepare_train(self): self.step = 0 # setup training 
self.renderer.gaussians.training_setup(self.opt) # do not do progressive sh-level self.renderer.gaussians.active_sh_degree = self.renderer.gaussians.max_sh_degree self.optimizer = self.renderer.gaussians.optimizer # default camera pose = orbit_camera(self.opt.elevation, 0, self.opt.radius) self.fixed_cam = MiniCam( pose, self.opt.ref_size, self.opt.ref_size, self.cam.fovy, self.cam.fovx, self.cam.near, self.cam.far, ) self.set_fix_cam2() self.enable_sd = self.opt.lambda_sd > 0 and self.prompt != "" self.enable_zero123 = self.opt.lambda_zero123 > 0 and self.input_img is not None # lazy load guidance model if self.guidance_sd is None and self.enable_sd: if self.opt.mvdream: print(f"[INFO] loading MVDream...") from guidance.mvdream_utils import MVDream self.guidance_sd = MVDream(self.device) print(f"[INFO] loaded MVDream!") else: print(f"[INFO] loading SD...") from guidance.sd_utils import StableDiffusion self.guidance_sd = StableDiffusion(self.device) print(f"[INFO] loaded SD!") #self.renderer.gaussians.initialize_post_first_timestep() def train_step(self): starter = torch.cuda.Event(enable_timing=True) ender = torch.cuda.Event(enable_timing=True) starter.record() self.stage='fine' if self.opt.load_step==None: self.step=8000 else: self.step = self.opt.load_step auto_path = os.path.join(self.opt.outdir,self.opt.save_path + str(self.step)) #os.makedirs(auto_path,exist_ok=True) ply_path = os.path.join(auto_path,'model.ply') self.renderer.gaussians.load_model(auto_path) self.renderer.gaussians.load_ply(ply_path) self.renderer.gaussians.update_learning_rate(self.step) self.save_renderings(name='front') self.save_renderings(azim=180,name='back') self.save_renderings(azim=-30,name='front_moving',interval=2) self.save_renderings(azim=150,name='back_moving',interval=2) self.save_renderings(azim=0,name='round',interval=360//self.size) ender.record() torch.cuda.synchronize() t = starter.elapsed_time(ender) self.need_update = True if self.gui: dpg.set_value("_log_train_time", 
f"{t:.4f}ms") dpg.set_value( "_log_train_log", f"step = {self.step: 5d} (+{self.train_steps: 2d}) loss = {tv_loss.item():.4f}tv_loss = {loss.item():.4f}\nzero123_loss = {zero123_loss.item():.4f}image_loss ={image_loss.item():.4f} ", ) @torch.no_grad() def save_renderings(self, elev=0, azim=0, radius=2, name='front', interval=0): if interval==0: images=[] for i in np.arange(self.frames): pose = orbit_camera(elev, azim, radius) cam = MiniCam( pose, self.opt.ref_size, self.opt.ref_size, self.cam.fovy, self.cam.fovx, self.cam.near, self.cam.far, ) cam.time=float(i/self.frames) out = self.renderer.render(cam,stage=self.stage) image = out["image"].unsqueeze(0) images.append(image) #os.makedirs(f'./valid/{self.opt.save_path}/final_{name}',exist_ok=True) #save_image_to_local(image[0].detach(),f'./valid/{self.opt.save_path}/final_{name}/{str(i).zfill(2)}.jpg') samples=torch.cat(images,dim=0) vid = ( (rearrange(samples, "t c h w -> t h w c") * 255).clamp(0,255).detach() .cpu() .numpy() .astype(np.uint8) ) video_path = f'./valid/{self.opt.save_path}/video_{name}.mp4' imageio.mimwrite(video_path, vid) else: images=[] for i in np.arange(self.frames): pose = orbit_camera(elev, (azim+interval*i)%360, radius) cam = MiniCam( pose, self.opt.ref_size, self.opt.ref_size, self.cam.fovy, self.cam.fovx, self.cam.near, self.cam.far, ) cam.time=float(i/self.frames) out = self.renderer.render(cam,stage=self.stage) image = out["image"].unsqueeze(0) images.append(image) #os.makedirs(f'./valid/{self.opt.save_path}/final_{name}',exist_ok=True) #save_image_to_local(image[0].detach(),f'./valid/{self.opt.save_path}/final_{name}/{str(i).zfill(2)}.jpg') samples=torch.cat(images,dim=0) vid = ( (rearrange(samples, "t c h w -> t h w c") * 255).clamp(0,255).detach() .cpu() .numpy() .astype(np.uint8) ) video_path = f'./valid/{self.opt.save_path}/video_{name}.mp4' imageio.mimwrite(video_path, vid) def set_fix_cam2(self): self.fixed_cam2=[] self.fixed_elevation = [] self.fixed_azimuth = [] pose = 
orbit_camera(self.opt.elevation-30,30 , self.opt.radius) self.fixed_elevation.append(-30) self.fixed_azimuth.append(30) self.fixed_cam2.append(MiniCam( pose, self.opt.ref_size, self.opt.ref_size, self.cam.fovy, self.cam.fovx, self.cam.near, self.cam.far, )) pose = orbit_camera(self.opt.elevation+20, 90, self.opt.radius) self.fixed_elevation.append(20) self.fixed_azimuth.append(90) self.fixed_cam2.append(MiniCam( pose, self.opt.ref_size, self.opt.ref_size, self.cam.fovy, self.cam.fovx, self.cam.near, self.cam.far, )) pose = orbit_camera(self.opt.elevation-30, 150, self.opt.radius) self.fixed_elevation.append(-30) self.fixed_azimuth.append(150) self.fixed_cam2.append(MiniCam( pose, self.opt.ref_size, self.opt.ref_size, self.cam.fovy, self.cam.fovx, self.cam.near, self.cam.far, )) pose = orbit_camera(self.opt.elevation+20, 210, self.opt.radius) self.fixed_elevation.append(+20) self.fixed_azimuth.append(210) self.fixed_cam2.append(MiniCam( pose, self.opt.ref_size, self.opt.ref_size, self.cam.fovy, self.cam.fovx, self.cam.near, self.cam.far, )) pose = orbit_camera(self.opt.elevation-30, 270, self.opt.radius) self.fixed_elevation.append(-30) self.fixed_azimuth.append(270) self.fixed_cam2.append(MiniCam( pose, self.opt.ref_size, self.opt.ref_size, self.cam.fovy, self.cam.fovx, self.cam.near, self.cam.far, )) pose = orbit_camera(self.opt.elevation+20, 330, self.opt.radius) self.fixed_elevation.append(20) self.fixed_azimuth.append(330) self.fixed_cam2.append(MiniCam( pose, self.opt.ref_size, self.opt.ref_size, self.cam.fovy, self.cam.fovx, self.cam.near, self.cam.far, )) @torch.no_grad() def test_step(self): # ignore if no need to update if not self.need_update: return starter = torch.cuda.Event(enable_timing=True) ender = torch.cuda.Event(enable_timing=True) starter.record() # should update image if self.need_update: # render image cur_cam = MiniCam( self.cam.pose, self.W, self.H, self.cam.fovy, self.cam.fovx, self.cam.near, self.cam.far, time=self.time ) 
#print(cur_cam.time) out = self.renderer.render(cur_cam, self.gaussain_scale_factor) buffer_image = out[self.mode] # [3, H, W] if self.mode in ['depth', 'alpha']: buffer_image = buffer_image.repeat(3, 1, 1) if self.mode == 'depth': buffer_image = (buffer_image - buffer_image.min()) / (buffer_image.max() - buffer_image.min() + 1e-20) buffer_image = F.interpolate( buffer_image.unsqueeze(0), size=(self.H, self.W), mode="bilinear", align_corners=False, ).squeeze(0) self.buffer_image = ( buffer_image.permute(1, 2, 0) .contiguous() .clamp(0, 1) .contiguous() .detach() .cpu() .numpy() ) # display input_image if self.overlay_input_img and self.input_img is not None: self.buffer_image = ( self.buffer_image * (1 - self.overlay_input_img_ratio) + self.input_img * self.overlay_input_img_ratio ) self.need_update = False ender.record() torch.cuda.synchronize() t = starter.elapsed_time(ender) if self.gui: dpg.set_value("_log_infer_time", f"{t:.4f}ms ({int(1000/t)} FPS)") dpg.set_value( "_texture", self.buffer_image ) # buffer must be contiguous, else seg fault! 
def load_input(self, file): # load image # load image import glob self.input_imgs=[] self.input_masks=[] file_list = glob.glob(self.pattern) print(self.pattern,file_list) for files in sorted(file_list): print(f"Reading file: {self.pattern}") print(f'[INFO] load image from {files}...') img = cv2.imread(files, cv2.IMREAD_UNCHANGED) if img.shape[-1] == 3: if self.bg_remover is None: self.bg_remover = rembg.new_session() img = rembg.remove(img, session=self.bg_remover) img = cv2.resize(img, (self.W, self.H), interpolation=cv2.INTER_AREA) img = img.astype(np.float32) / 255.0 self.input_mask = img[..., 3:] # white bg self.input_img = img[..., :3] * self.input_mask + (1 - self.input_mask) # bgr to rgb self.input_img = self.input_img[..., ::-1].copy() self.input_imgs.append(self.input_img) self.input_masks.append(self.input_mask) print(f'[INFO] load image from {file}...') img = cv2.imread(file, cv2.IMREAD_UNCHANGED) if img.shape[-1] == 3: if self.bg_remover is None: self.bg_remover = rembg.new_session() img = rembg.remove(img, session=self.bg_remover) img = cv2.resize(img, (self.W, self.H), interpolation=cv2.INTER_AREA) img = img.astype(np.float32) / 255.0 self.input_mask = img[..., 3:] # white bg self.input_img = img[..., :3] * self.input_mask + (1 - self.input_mask) # bgr to rgb self.input_img = self.input_img[..., ::-1].copy() # load prompt file_prompt = file.replace("_rgba.png", "_caption.txt") if os.path.exists(file_prompt): print(f'[INFO] load prompt from {file_prompt}...') with open(file_prompt, "r") as f: self.prompt = f.read().strip() @torch.no_grad() def save_model(self, mode='geo', texture_size=1024): os.makedirs(self.opt.outdir, exist_ok=True) if mode == 'geo': path = os.path.join(self.opt.outdir, self.opt.save_path + '_model.ply') self.renderer.gaussians.save_ply(path) elif mode == 'geo+tex': path = os.path.join(self.opt.outdir, self.opt.save_path + '_model.ply') self.renderer.gaussians.save_ply(path) else: path = os.path.join(self.opt.outdir, 
self.opt.save_path + '_model.ply') self.renderer.gaussians.save_ply(path) print(f"[INFO] save model to {path}.") def register_dpg(self): ### register texture with dpg.texture_registry(show=False): dpg.add_raw_texture( self.W, self.H, self.buffer_image, format=dpg.mvFormat_Float_rgb, tag="_texture", ) ### register window # the rendered image, as the primary window with dpg.window( tag="_primary_window", width=self.W, height=self.H, pos=[0, 0], no_move=True, no_title_bar=True, no_scrollbar=True, ): # add the texture dpg.add_image("_texture") # dpg.set_primary_window("_primary_window", True) # control window with dpg.window( label="Control", tag="_control_window", width=600, height=self.H, pos=[self.W, 0], no_move=True, no_title_bar=True, ): # button theme with dpg.theme() as theme_button: with dpg.theme_component(dpg.mvButton): dpg.add_theme_color(dpg.mvThemeCol_Button, (23, 3, 18)) dpg.add_theme_color(dpg.mvThemeCol_ButtonHovered, (51, 3, 47)) dpg.add_theme_color(dpg.mvThemeCol_ButtonActive, (83, 18, 83)) dpg.add_theme_style(dpg.mvStyleVar_FrameRounding, 5) dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 3, 3) # timer stuff with dpg.group(horizontal=True): dpg.add_text("Infer time: ") dpg.add_text("no data", tag="_log_infer_time") def callback_setattr(sender, app_data, user_data): setattr(self, user_data, app_data) # init stuff with dpg.collapsing_header(label="Initialize", default_open=True): # seed stuff def callback_set_seed(sender, app_data): self.seed = app_data self.seed_everything() dpg.add_input_text( label="seed", default_value=self.seed, on_enter=True, callback=callback_set_seed, ) # input stuff def callback_select_input(sender, app_data): # only one item for k, v in app_data["selections"].items(): dpg.set_value("_log_input", k) self.load_input(v) self.need_update = True with dpg.file_dialog( directory_selector=False, show=False, callback=callback_select_input, file_count=1, tag="file_dialog_tag", width=700, height=400, ): 
dpg.add_file_extension("Images{.jpg,.jpeg,.png}") with dpg.group(horizontal=True): dpg.add_button( label="input", callback=lambda: dpg.show_item("file_dialog_tag"), ) dpg.add_text("", tag="_log_input") # overlay stuff with dpg.group(horizontal=True): def callback_toggle_overlay_input_img(sender, app_data): self.overlay_input_img = not self.overlay_input_img self.need_update = True dpg.add_checkbox( label="overlay image", default_value=self.overlay_input_img, callback=callback_toggle_overlay_input_img, ) def callback_set_overlay_input_img_ratio(sender, app_data): self.overlay_input_img_ratio = app_data self.need_update = True dpg.add_slider_float( label="ratio", min_value=0, max_value=1, format="%.1f", default_value=self.overlay_input_img_ratio, callback=callback_set_overlay_input_img_ratio, ) # prompt stuff dpg.add_input_text( label="prompt", default_value=self.prompt, callback=callback_setattr, user_data="prompt", ) dpg.add_input_text( label="negative", default_value=self.negative_prompt, callback=callback_setattr, user_data="negative_prompt", ) # save current model with dpg.group(horizontal=True): dpg.add_text("Save: ") def callback_save(sender, app_data, user_data): self.save_model(mode=user_data) dpg.add_button( label="model", tag="_button_save_model", callback=callback_save, user_data='model', ) dpg.bind_item_theme("_button_save_model", theme_button) dpg.add_button( label="geo", tag="_button_save_mesh", callback=callback_save, user_data='geo', ) dpg.bind_item_theme("_button_save_mesh", theme_button) dpg.add_button( label="geo+tex", tag="_button_save_mesh_with_tex", callback=callback_save, user_data='geo+tex', ) dpg.bind_item_theme("_button_save_mesh_with_tex", theme_button) dpg.add_input_text( label="", default_value=self.opt.save_path, callback=callback_setattr, user_data="save_path", ) # training stuff with dpg.collapsing_header(label="Train", default_open=True): # lr and train button with dpg.group(horizontal=True): dpg.add_text("Train: ") def 
callback_train(sender, app_data): if self.training: self.training = False dpg.configure_item("_button_train", label="start") else: self.prepare_train() self.training = True dpg.configure_item("_button_train", label="stop") # dpg.add_button( # label="init", tag="_button_init", callback=self.prepare_train # ) # dpg.bind_item_theme("_button_init", theme_button) dpg.add_button( label="start", tag="_button_train", callback=callback_train ) dpg.bind_item_theme("_button_train", theme_button) with dpg.group(horizontal=True): dpg.add_text("", tag="_log_train_time") dpg.add_text("", tag="_log_train_log") # rendering options with dpg.collapsing_header(label="Rendering", default_open=True): # mode combo def callback_change_mode(sender, app_data): self.mode = app_data self.need_update = True dpg.add_combo( ("image", "depth", "alpha"), label="mode", default_value=self.mode, callback=callback_change_mode, ) # fov slider def callback_set_fovy(sender, app_data): self.cam.fovy = np.deg2rad(app_data) self.need_update = True dpg.add_slider_int( label="FoV (vertical)", min_value=1, max_value=120, format="%d deg", default_value=np.rad2deg(self.cam.fovy), callback=callback_set_fovy, ) def callback_set_gaussain_scale(sender, app_data): self.gaussain_scale_factor = app_data self.need_update = True dpg.add_slider_float( label="gaussain scale", min_value=0, max_value=1, format="%.2f", default_value=self.gaussain_scale_factor, callback=callback_set_gaussain_scale, ) ### register camera handler def callback_camera_drag_rotate_or_draw_mask(sender, app_data): if not dpg.is_item_focused("_primary_window"): return dx = app_data[1] dy = app_data[2] self.cam.orbit(dx, dy) self.need_update = True def callback_camera_wheel_scale(sender, app_data): if not dpg.is_item_focused("_primary_window"): return delta = app_data self.cam.scale(delta) self.need_update = True def callback_camera_drag_pan(sender, app_data): if not dpg.is_item_focused("_primary_window"): return dx = app_data[1] dy = app_data[2] 
self.cam.pan(dx, dy) self.need_update = True def callback_set_mouse_loc(sender, app_data): if not dpg.is_item_focused("_primary_window"): return # just the pixel coordinate in image self.mouse_loc = np.array(app_data) with dpg.handler_registry(): # for camera moving dpg.add_mouse_drag_handler( button=dpg.mvMouseButton_Left, callback=callback_camera_drag_rotate_or_draw_mask, ) dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale) dpg.add_mouse_drag_handler( button=dpg.mvMouseButton_Middle, callback=callback_camera_drag_pan ) dpg.create_viewport( title="Gaussian3D", width=self.W + 600, height=self.H + (45 if os.name == "nt" else 0), resizable=False, ) ### global theme with dpg.theme() as theme_no_padding: with dpg.theme_component(dpg.mvAll): # set all padding to 0 to avoid scroll bar dpg.add_theme_style( dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core ) dpg.add_theme_style( dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core ) dpg.add_theme_style( dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core ) dpg.bind_item_theme("_primary_window", theme_no_padding) dpg.setup_dearpygui() ### register a larger font # get it from: https://github.com/lxgw/LxgwWenKai/releases/download/v1.300/LXGWWenKai-Regular.ttf if os.path.exists("LXGWWenKai-Regular.ttf"): with dpg.font_registry(): with dpg.font("LXGWWenKai-Regular.ttf", 18) as default_font: dpg.bind_font(default_font) # dpg.show_metrics() dpg.show_viewport() def render(self): assert self.gui while dpg.is_dearpygui_running(): # update texture every frame if self.training: self.train_step() self.test_step() dpg.render_dearpygui_frame() # no gui mode def train(self, iters=500): self.prepare_train() self.train_step() # do a last prune #self.renderer.gaussians.prune(min_opacity=0.01, extent=1, max_screen_size=1) # save self.save_model(mode='model') self.save_model(mode='geo+tex') if __name__ == "__main__": import argparse from omegaconf import OmegaConf parser = 
argparse.ArgumentParser() parser.add_argument("--config", required=True, help="path to the yaml config file") args, extras = parser.parse_known_args() # override default config from cli opt = OmegaConf.merge(OmegaConf.load(args.config), OmegaConf.from_cli(extras)) gui = GUI(opt) if opt.gui: gui.render() else: gui.train(opt.iters) ================================================ FILE: zero123.py ================================================ # Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import inspect
import math
import warnings
from typing import Any, Callable, Dict, List, Optional, Union

import PIL
import torch
import torchvision.transforms.functional as TF
from diffusers.configuration_utils import ConfigMixin, FrozenDict, register_to_config
from diffusers.image_processor import VaeImageProcessor
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.modeling_utils import ModelMixin
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import (
    StableDiffusionSafetyChecker,
)
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import deprecate, is_accelerate_available, logging
from diffusers.utils.torch_utils import randn_tensor
from packaging import version
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class CLIPCameraProjection(ModelMixin, ConfigMixin):
    """
    A projection layer for a CLIP embedding concatenated with a camera embedding.

    Parameters:
        embedding_dim (`int`, *optional*, defaults to 768):
            The dimension of the CLIP embedding, which is also the output dimension.
        additional_embeddings (`int`, *optional*, defaults to 4):
            The number of extra (camera) features concatenated to the CLIP embedding.
            The layer's input dimension is `embedding_dim + additional_embeddings`.
    """

    @register_to_config
    def __init__(self, embedding_dim: int = 768, additional_embeddings: int = 4):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.additional_embeddings = additional_embeddings

        self.input_dim = self.embedding_dim + self.additional_embeddings
        self.output_dim = self.embedding_dim

        self.proj = torch.nn.Linear(self.input_dim, self.output_dim)

    def forward(
        self,
        embedding: torch.FloatTensor,
    ):
        """
        Project a concatenated (CLIP, camera) embedding back to the CLIP dimension.

        Args:
            embedding (`torch.FloatTensor` whose last dimension is `input_dim`):
                The concatenated input embeddings.

        Returns:
            The projected embedding (`torch.FloatTensor` whose last dimension is `output_dim`).
        """
        proj_embedding = self.proj(embedding)
        return proj_embedding


class Zero123Pipeline(DiffusionPipeline):
    r"""
    Pipeline to generate novel views of an input image with Zero123 (a camera-conditioned
    Stable Diffusion image-variation model).

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        image_encoder ([`CLIPVisionModelWithProjection`]):
            Frozen CLIP image-encoder.
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents.
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
        feature_extractor ([`CLIPImageProcessor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
        clip_camera_projection ([`CLIPCameraProjection`]):
            Projects the concatenated (CLIP image embedding, camera embedding) back to the CLIP dimension.
    """

    # TODO: feature_extractor is required to encode images (if they are in PIL format),
    # we should give a descriptive message if the pipeline doesn't have one.
    _optional_components = ["safety_checker"]

    def __init__(
        self,
        vae: AutoencoderKL,
        image_encoder: CLIPVisionModelWithProjection,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        clip_camera_projection: CLIPCameraProjection,
        requires_safety_checker: bool = True,
    ):
        super().__init__()

        if safety_checker is None and requires_safety_checker:
            # Logger.warn is a deprecated alias of Logger.warning.
            logger.warning(
                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
            )

        if safety_checker is not None and feature_extractor is None:
            # f-prefix needed so the class name is actually interpolated.
            raise ValueError(
                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
            )

        is_unet_version_less_0_9_0 = hasattr(
            unet.config, "_diffusers_version"
        ) and version.parse(
            version.parse(unet.config._diffusers_version).base_version
        ) < version.parse(
            "0.9.0.dev0"
        )
        is_unet_sample_size_less_64 = (
            hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
        )
        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
            deprecation_message = (
                "The configuration file of the unet has set the default `sample_size` to smaller than"
                " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
                " in the config might lead to incorrect results in future versions. If you have downloaded this"
                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
                " the `unet/config.json` file"
            )
            deprecate(
                "sample_size<64", "1.0.0", deprecation_message, standard_warn=False
            )
            new_config = dict(unet.config)
            new_config["sample_size"] = 64
            unet._internal_dict = FrozenDict(new_config)

        self.register_modules(
            vae=vae,
            image_encoder=image_encoder,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            clip_camera_projection=clip_camera_projection,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.register_to_config(requires_safety_checker=requires_safety_checker)

    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
        image_encoder, vae, clip_camera_projection and safety checker have their state dicts saved to CPU and then are
        moved to a `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method
        called.
        """
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device(f"cuda:{gpu_id}")

        for cpu_offloaded_model in [
            self.unet,
            self.image_encoder,
            self.vae,
            # also used in forward passes; without offloading it stays on its
            # original device while its inputs live on the execution device
            self.clip_camera_projection,
            self.safety_checker,
        ]:
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)

    @property
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
    def _execution_device(self):
        r"""
        Returns the device on which the pipeline's models will be executed. After calling
        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
        hooks.
        """
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    def _encode_image(
        self,
        image,
        elevation,
        azimuth,
        distance,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        clip_image_embeddings=None,
        image_camera_embeddings=None,
    ):
        """Build the camera-conditioned CLIP embeddings used as UNet cross-attention context.

        Unless `image_camera_embeddings` is given, the CLIP image embedding (from
        `image` or `clip_image_embeddings`) is concatenated with a 4-feature camera
        embedding: [elevation (rad), sin(azimuth), cos(azimuth), distance]. The
        result is then projected by `clip_camera_projection`. Python scalars (int
        or float) for elevation/azimuth/distance are broadcast over the batch.
        With classifier-free guidance, all-zero unconditional embeddings are
        prepended along the batch dimension.
        """
        dtype = next(self.image_encoder.parameters()).dtype

        if image_camera_embeddings is None:
            if image is None:
                assert clip_image_embeddings is not None
                image_embeddings = clip_image_embeddings.to(device=device, dtype=dtype)
            else:
                if not isinstance(image, torch.Tensor):
                    image = self.feature_extractor(
                        images=image, return_tensors="pt"
                    ).pixel_values

                image = image.to(device=device, dtype=dtype)
                image_embeddings = self.image_encoder(image).image_embeds
                image_embeddings = image_embeddings.unsqueeze(1)

            bs_embed, seq_len, _ = image_embeddings.shape

            def as_batch(value):
                # Broadcast a Python scalar to one value per batch item; tensors pass
                # through unchanged. ints are accepted too (torch.deg2rad rejects them).
                if isinstance(value, (int, float)):
                    return torch.as_tensor(
                        [float(value)] * bs_embed, dtype=dtype, device=device
                    )
                return value

            elevation = as_batch(elevation)
            azimuth = as_batch(azimuth)
            distance = as_batch(distance)

            camera_embeddings = torch.stack(
                [
                    torch.deg2rad(elevation),
                    torch.sin(torch.deg2rad(azimuth)),
                    torch.cos(torch.deg2rad(azimuth)),
                    distance,
                ],
                dim=-1,
            )[:, None, :]

            image_embeddings = torch.cat([image_embeddings, camera_embeddings], dim=-1)

            # project (image, camera) embeddings to the same dimension as clip embeddings
            image_embeddings = self.clip_camera_projection(image_embeddings)
        else:
            image_embeddings = image_camera_embeddings.to(device=device, dtype=dtype)
            bs_embed, seq_len, _ = image_embeddings.shape

        # duplicate image embeddings for each generation per prompt, using mps friendly method
        image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
        image_embeddings = image_embeddings.view(
            bs_embed * num_images_per_prompt, seq_len, -1
        )

        if do_classifier_free_guidance:
            negative_prompt_embeds = torch.zeros_like(image_embeddings)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])

        return image_embeddings

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
    def run_safety_checker(self, image, device, dtype):
        if self.safety_checker is None:
            has_nsfw_concept = None
        else:
            if torch.is_tensor(image):
                feature_extractor_input = self.image_processor.postprocess(
                    image, output_type="pil"
                )
            else:
                feature_extractor_input = self.image_processor.numpy_to_pil(image)
            safety_checker_input = self.feature_extractor(
                feature_extractor_input, return_tensors="pt"
            ).to(device)
            image, has_nsfw_concept = self.safety_checker(
                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
            )
        return image, has_nsfw_concept

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
    def decode_latents(self, latents):
        warnings.warn(
            "The decode_latents method is deprecated and will be removed in a future version. Please"
            " use VaeImageProcessor instead",
            FutureWarning,
        )
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents, return_dict=False)[0]
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        return image

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(
            inspect.signature(self.scheduler.step).parameters.keys()
        )
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(
            inspect.signature(self.scheduler.step).parameters.keys()
        )
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def check_inputs(self, image, height, width, callback_steps):
        # TODO: check image size or adjust image size to (height, width)

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(
                f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
            )

        if (callback_steps is None) or (
            callback_steps is not None
            and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        height,
        width,
        dtype,
        device,
        generator,
        latents=None,
    ):
        shape = (
            batch_size,
            num_channels_latents,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(
                shape, generator=generator, device=device, dtype=dtype
            )
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    def _encode_condition_image(
        self,
        image: Optional[
            Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor]
        ],
        like: torch.FloatTensor,
    ) -> Optional[torch.FloatTensor]:
        """Encode the conditioning image into VAE latents, or return None.

        PIL images (single or list) are converted to tensors with the device and
        dtype of `like`. Tensors are used as-is and assumed to be in [0, 1]. The
        result is the mode of the VAE posterior, so encoding is deterministic.
        Returns None when `image` is not one of the supported types.
        """
        if isinstance(image, PIL.Image.Image):
            image_pt = TF.to_tensor(image).unsqueeze(0).to(like)
        elif isinstance(image, list):
            image_pt = torch.stack([TF.to_tensor(img) for img in image], dim=0).to(
                like
            )
        elif isinstance(image, torch.Tensor):
            image_pt = image
        else:
            return None

        image_pt = image_pt * 2.0 - 1.0  # scale to [-1, 1]
        # FIXME: encoded latents should be multiplied with self.vae.config.scaling_factor
        # but zero123 was not trained this way
        return self.vae.encode(image_pt).latent_dist.mode()

    def _get_latent_model_input(
        self,
        latents: torch.FloatTensor,
        image: Optional[
            Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor]
        ],
        num_images_per_prompt: int,
        do_classifier_free_guidance: bool,
        image_latents: Optional[torch.FloatTensor] = None,
    ):
        """Concatenate the noisy latents with the conditioning-image latents along channels.

        The conditioning latents come from encoding `image` when it is a
        supported type, otherwise from `image_latents`. They are repeated per
        generated image. With classifier-free guidance, the batch is doubled and
        the unconditional half uses zero image latents.
        """
        image_pt = self._encode_condition_image(image, latents)
        if image_pt is None:
            assert image_latents is not None
            image_pt = image_latents
        image_pt = image_pt.repeat_interleave(num_images_per_prompt, dim=0)

        if do_classifier_free_guidance:
            latent_model_input = torch.cat(
                [
                    torch.cat([latents, latents], dim=0),
                    torch.cat([torch.zeros_like(image_pt), image_pt], dim=0),
                ],
                dim=1,
            )
        else:
            latent_model_input = torch.cat([latents, image_pt], dim=1)

        return latent_model_input

    @torch.no_grad()
    def __call__(
        self,
        image: Optional[
            Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor]
        ] = None,
        elevation: Optional[Union[float, torch.FloatTensor]] = None,
        azimuth: Optional[Union[float, torch.FloatTensor]] = None,
        distance: Optional[Union[float, torch.FloatTensor]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 3.0,
        num_images_per_prompt: int = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        clip_image_embeddings: Optional[torch.FloatTensor] = None,
        image_camera_embeddings: Optional[torch.FloatTensor] = None,
        image_latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
                The image or images to guide the image generation. If you provide a tensor, it needs to comply with the
                configuration of
                [this](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json)
                `CLIPImageProcessor`
            elevation, azimuth, distance (`float`, `int` or `torch.FloatTensor`, *optional*):
                Relative camera parameters (elevation/azimuth in degrees). Python scalars are broadcast over the batch.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 3.0):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages generated images to follow the conditioning image more closely,
                usually at the expense of lower image quality.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will ge generated by sampling using the supplied random `generator`.
            clip_image_embeddings, image_camera_embeddings, image_latents (`torch.FloatTensor`, *optional*):
                Precomputed conditioning, used when `image` is not provided.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generate image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        # TODO: check input elevation, azimuth, and distance
        # TODO: check image, clip_image_embeddings, image_latents
        self.check_inputs(image, height, width, callback_steps)

        # 2. Define call parameters
        if isinstance(image, PIL.Image.Image):
            batch_size = 1
        elif isinstance(image, list):
            batch_size = len(image)
        elif isinstance(image, torch.Tensor):
            batch_size = image.shape[0]
        else:
            assert image_latents is not None
            assert (
                clip_image_embeddings is not None or image_camera_embeddings is not None
            )
            batch_size = image_latents.shape[0]

        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input image
        if isinstance(image, PIL.Image.Image) or isinstance(image, list):
            pil_image = image
        elif isinstance(image, torch.Tensor):
            pil_image = [TF.to_pil_image(image[i]) for i in range(image.shape[0])]
        else:
            pil_image = None
        image_embeddings = self._encode_image(
            pil_image,
            elevation,
            azimuth,
            distance,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            clip_image_embeddings,
            image_camera_embeddings,
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        # num_channels_latents = self.unet.config.in_channels
        num_channels_latents = 4  # FIXME: hard-coded
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            image_embeddings.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # The conditioning image's VAE latents do not depend on the denoising step
        # (the posterior mode is deterministic), so encode the image once here rather
        # than on every iteration of the loop below.
        encoded_image_latents = self._encode_condition_image(image, latents)
        if encoded_image_latents is not None:
            image_latents = encoded_image_latents

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = self._get_latent_model_input(
                    latents,
                    None,
                    num_images_per_prompt,
                    do_classifier_free_guidance,
                    image_latents,
                )
                latent_model_input = self.scheduler.scale_model_input(
                    latent_model_input, t
                )

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=image_embeddings,
                    cross_attention_kwargs=cross_attention_kwargs,
                ).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (
                        noise_pred_text - noise_pred_uncond
                    )

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(
                    noise_pred, t, latents, **extra_step_kwargs
                ).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or (
                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
                ):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        if not output_type == "latent":
            image = self.vae.decode(
                latents / self.vae.config.scaling_factor, return_dict=False
            )[0]
            image, has_nsfw_concept = self.run_safety_checker(
                image, device, image_embeddings.dtype
            )
        else:
            image = latents
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(
            image, output_type=output_type, do_denormalize=do_denormalize
        )

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(
            images=image, nsfw_content_detected=has_nsfw_concept
        )