Full Code of zeng-yifei/STAG4D for AI

main 9aa21c92f40b cached
25 files
257.0 KB
64.7k tokens
233 symbols
1 requests
Download .txt
Showing preview only (268K chars total). Download the full file or copy to clipboard to get everything.
Repository: zeng-yifei/STAG4D
Branch: main
Commit: 9aa21c92f40b
Files: 25
Total size: 257.0 KB

Directory structure:
gitextract_687zu6g9/

├── .gitignore
├── README.md
├── __init__.py
├── cam_utils.py
├── configs/
│   └── stag4d.yaml
├── dataset_4d.py
├── deform.py
├── gs_renderer_4d.py
├── guidance/
│   ├── zero123_4d_utils.py
│   └── zero123pp/
│       └── pipeline.py
├── main.py
├── mini_trainer.ipynb
├── requirements.txt
├── scripts/
│   ├── app.py
│   └── gen_mv.py
├── sh_utils.py
├── simple-knn/
│   ├── ext.cpp
│   ├── setup.py
│   ├── simple_knn/
│   │   └── .gitkeep
│   ├── simple_knn.cu
│   ├── simple_knn.h
│   ├── spatial.cu
│   └── spatial.h
├── visualize.py
└── zero123.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
*.pyc
./valid/*

================================================
FILE: README.md
================================================
<h1>[ECCV 2024] STAG4D: Spatial-Temporal Anchored Generative 4D Gaussians</h1>

<div>
    <a href='https://github.com/zeng-yifei?tab=repositories/' target='_blank'>Yifei Zeng</a><sup>1</sup>&emsp;
    <a href="https://github.com/yanqinJiang" target='_blank'>Yanqin Jiang*</a><sup>2</sup>&emsp;
    <a href="https://sites.google.com/site/zhusiyucs/home/" target='_blank'>Siyu Zhu</a><sup>3</sup>&emsp;
    <a href='https://github.com/YuanxunLu' target='_blank'>Yuanxun Lu</a><sup>1</sup>&emsp;
    <a href="https://linyou.github.io/">Youtian Lin</a><sup>1</sup>&emsp;
    <a href='https://zhuhao-nju.github.io/home/' target='_blank'>Hao Zhu</a><sup>1</sup>&emsp;
    <a href="https://people.ucas.ac.cn/~huweiming">Weiming Hu</a><sup>2</sup>&emsp;
    <a href='https://cite.nju.edu.cn/People/Faculty/20190621/i5054.html' target='_blank'>Xun Cao</a><sup>1</sup>&emsp;
    <a href='https://yoyo000.github.io/' target='_blank'>Yao Yao</a><sup>1+</sup>&emsp;
</div>
<div>
    <sup>1</sup>Nanjing University
    <sup>2</sup>CASIA
    <sup>3</sup>Fudan University
</div>
<div>
    <sup>*</sup>equal contribution
    <sup>+</sup>corresponding author
</div>

<h4 align="center">
  <a href="https://nju-3dv.github.io/projects/STAG4D/" target='_blank'>[Project Page]</a> •
</h4>

# Update
7.4 Our paper has been accepted by ECCV 2024. Congrats!

6.20: IMPORTANT. Fix the bug caused by new version of diff_gauss. Newest version of diff_gauss use `color, depth, norm, alpha, radii, extra` as an output. However, previous version use `color, depth, alpha, radii` as an output. Using older version of this code will cause mismatch error and may misuse normal for the alpha loss, resulting in bad results.

5.26: Update Text/Image to 4D data below.

5.21: Fix RGB loss into the batch loop. Add visualize code.


# ⚙️ Installation

```bash
pip install -r requirements.txt

git clone --recursive https://github.com/slothfulxtx/diff-gaussian-rasterization.git
pip install ./diff-gaussian-rasterization

pip install ./simple-knn
```

# Video-to-4D
To generate the examples in the project page, you can download the dataset from [google drive](https://drive.google.com/file/d/1YDvhBv6z5SByF_WaTQVzzL9qz3TyEm6a/view?usp=sharing). Place them in the dataset folder, and run:
```bash
python main.py --config configs/stag4d.yaml path=dataset/minions save_path=minions

#use --gui=True to turn on the visualizer (recommend)
python main.py --config configs/stag4d.yaml path=dataset/minions save_path=minions gui=True

```

To generate the spatial-temporal consistent data from stratch, your should place your rgba data in the form of 

```
├── dataset
│   | your_data 
│     ├── 0_rgba.png
│     ├── 1_rgba.png
│     ├── 2_rgba.png
│     ├── ...

```

and then run 
```bash
python scripts/gen_mv.py --path dataset/your_data --pipeline_path xxx/guidance/zero123pp

python main.py --config configs/stag4d.yaml path=data_path save_path=saving_path gui=True
```

To visualize the result, use you can replace the main.py with visualize.py, and the result will be saved to the valid/xxx path, e.g.:
```bash
python visualize.py --config configs/stag4d.yaml path=dataset/minions save_path=minions
```

<img src='assets/videoto4d.gif' height='60%'>

# Text-to-4D
For Text to 4D generation, we recommend using SDXL and SVD to generate a reasonable video. Then, after matting the video, use
the command above to generate a good 4D result. (This pipeline contains many independent parts and is kind of complex, so we may upload the whole workflow after integration if possible.)

If you want generate the examples in the paper, I also updated the corresponding data here in [google drive](https://drive.google.com/file/d/1EDNL7EBMR1vlfMOABdXjHzcKY7IXdcnj/view?usp=sharing). Remember to set size to 26 in config or use `size=26` in the command:

```bash
python main.py --config configs/stag4d.yaml path=dataset/xxx save_path=xxx size=26
```

<img src='assets/textto4d3.gif' height='60%'>

# Tips for better quality
If you want sacrifice time for better quality, here is some tips you can try to further improve the generated quality.

1, Use larger batch size.

2, Run for more steps.

## Citation
If you find our work useful for your research, please consider citing our paper as well as Consistent4D:
```
@article{zeng2024stag4d,
      title={STAG4D: Spatial-Temporal Anchored Generative 4D Gaussians}, 
      author={Yifei Zeng and Yanqin Jiang and Siyu Zhu and Yuanxun Lu and Youtian Lin and Hao Zhu and Weiming Hu and Xun Cao and Yao Yao},
      year={2024},
      eprint={2403.14939},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

@article{jiang2023consistent4d,
      title={Consistent4D: Consistent 360{\deg} Dynamic Object Generation from Monocular Video}, 
      author={Yanqin Jiang and Li Zhang and Jin Gao and Weimin Hu and Yao Yao},
      year={2023},
      eprint={2311.02848},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
```

# Acknowledgment
This repo is built on [DreamGaussian](https://github.com/dreamgaussian/dreamgaussian) and [Zero123plus](https://github.com/SUDO-AI-3D/zero123plus). Thank all the authors for their great work.

================================================
FILE: __init__.py
================================================


================================================
FILE: cam_utils.py
================================================
import numpy as np
from scipy.spatial.transform import Rotation as R

import torch

def dot(x, y):
    if isinstance(x, np.ndarray):
        return np.sum(x * y, -1, keepdims=True)
    else:
        return torch.sum(x * y, -1, keepdim=True)


def length(x, eps=1e-20):
    if isinstance(x, np.ndarray):
        return np.sqrt(np.maximum(np.sum(x * x, axis=-1, keepdims=True), eps))
    else:
        return torch.sqrt(torch.clamp(dot(x, x), min=eps))


def safe_normalize(x, eps=1e-20):
    return x / length(x, eps)


def look_at(campos, target, opengl=True):
    # campos: [N, 3], camera/eye position
    # target: [N, 3], object to look at
    # return: [N, 3, 3], rotation matrix
    if not opengl:
        # camera forward aligns with -z
        forward_vector = safe_normalize(target - campos)
        up_vector = np.array([0, 1, 0], dtype=np.float32)
        right_vector = safe_normalize(np.cross(forward_vector, up_vector))
        up_vector = safe_normalize(np.cross(right_vector, forward_vector))
    else:
        # camera forward aligns with +z
        forward_vector = safe_normalize(campos - target)
        up_vector = np.array([0, 1, 0], dtype=np.float32)
        right_vector = safe_normalize(np.cross(up_vector, forward_vector))
        up_vector = safe_normalize(np.cross(forward_vector, right_vector))
    R = np.stack([right_vector, up_vector, forward_vector], axis=1)
    return R


# elevation & azimuth to pose (cam2world) matrix
def orbit_camera(elevation, azimuth, radius=1, is_degree=True, target=None, opengl=True):
    # radius: scalar
    # elevation: scalar, in (-90, 90), from +y to -y is (-90, 90)
    # azimuth: scalar, in (-180, 180), from +z to +x is (0, 90)
    # return: [4, 4], camera pose matrix
    if is_degree:
        elevation = np.deg2rad(np.array(elevation))
        azimuth = np.deg2rad(np.array(azimuth))
    x = radius * np.cos(elevation) * np.sin(azimuth)
    y = - radius * np.sin(elevation)
    z = radius * np.cos(elevation) * np.cos(azimuth)
    if target is None:
        target = np.zeros([3], dtype=np.float32)
    campos = np.array([x, y, z]) + target  # [3]
    T = np.eye(4, dtype=np.float32)
    T[:3, :3] = look_at(campos, target, opengl)
    T[:3, 3] = campos
    return T


class OrbitCamera:
    def __init__(self, W, H, r=2, fovy=60, near=0.01, far=100):
        self.W = W
        self.H = H
        self.radius = r  # camera distance from center
        self.fovy = np.deg2rad(fovy)  # deg 2 rad
        self.near = near
        self.far = far
        self.center = np.array([0, 0, 0], dtype=np.float32)  # look at this point
        self.rot = R.from_matrix(np.eye(3))
        self.up = np.array([0, 1, 0], dtype=np.float32)  # need to be normalized!

    @property
    def fovx(self):
        return 2 * np.arctan(np.tan(self.fovy / 2) * self.W / self.H)

    @property
    def campos(self):
        return self.pose[:3, 3]

    # pose (c2w)
    @property
    def pose(self):
        # first move camera to radius
        res = np.eye(4, dtype=np.float32)
        res[2, 3] = self.radius  # opengl convention...
        # rotate
        rot = np.eye(4, dtype=np.float32)
        rot[:3, :3] = self.rot.as_matrix()
        res = rot @ res
        # translate
        res[:3, 3] -= self.center
        return res

    # view (w2c)
    @property
    def view(self):
        return np.linalg.inv(self.pose)

    # projection (perspective)
    @property
    def perspective(self):
        y = np.tan(self.fovy / 2)
        aspect = self.W / self.H
        return np.array(
            [
                [1 / (y * aspect), 0, 0, 0],
                [0, -1 / y, 0, 0],
                [
                    0,
                    0,
                    -(self.far + self.near) / (self.far - self.near),
                    -(2 * self.far * self.near) / (self.far - self.near),
                ],
                [0, 0, -1, 0],
            ],
            dtype=np.float32,
        )

    # intrinsics
    @property
    def intrinsics(self):
        focal = self.H / (2 * np.tan(self.fovy / 2))
        return np.array([focal, focal, self.W // 2, self.H // 2], dtype=np.float32)

    @property
    def mvp(self):
        return self.perspective @ np.linalg.inv(self.pose)  # [4, 4]

    def orbit(self, dx, dy):
        # rotate along camera up/side axis!
        side = self.rot.as_matrix()[:3, 0]
        rotvec_x = self.up * np.radians(-0.05 * dx)
        rotvec_y = side * np.radians(-0.05 * dy)
        self.rot = R.from_rotvec(rotvec_x) * R.from_rotvec(rotvec_y) * self.rot

    def scale(self, delta):
        self.radius *= 1.1 ** (-delta)

    def pan(self, dx, dy, dz=0):
        # pan in camera coordinate system (careful on the sensitivity!)
        self.center += 0.0005 * self.rot.as_matrix()[:3, :3] @ np.array([-dx, -dy, dz])

================================================
FILE: configs/stag4d.yaml
================================================
### Input
# input rgba image path (default to None, can be load in GUI too)
input: 
# input text prompt (default to None, can be input in GUI too)
prompt: a minion
# input mesh for stage 2 (auto-search from stage 1 output path if None)
mesh:
# estimated elevation angle for input image 
elevation: 0
# reference image resolution
ref_size: 512
# density thresh for mesh extraction
density_thresh: 1
device: cuda

#dynamic
size: 30
path: dataset/minions

# checkpoint to load for stage 1 (should be a ply file)
load: 

### Output
outdir: logs
mesh_format: obj
save_path: ???
save_step: 8000
#checkpoint to load for stage fine (should be a path of ply with deform pth)
load_path: 
load_step: 
valid_interval: 500

### Training
# guidance loss weights (0 to disable)
lambda_sd: 0
mvdream: False
lambda_zero123: 1
lambda_tv: 1
scale_loss_ratio: 7.5
imagedream: False

# training batch size per iter
batch_size: 4
# training iterations for stage 1
iters: 2000
# training iterations for stage 2
iters_refine: 50
# training camera radius
radius: 2
# training camera fovy
fovy: 49.1 # align with zero123 rendering setting (ref: https://github.com/cvlab-columbia/zero123/blob/main/objaverse-rendering/scripts/blender_script.py#L61

# whether allow geom training in stage 2
train_geo: False
# prob to invert background color during training (0 = always black, 1 = always white)
invert_bg_prob: 0.5


### GUI
gui: False
force_cuda_rast: False
# GUI resolution
H: 800
W: 800
deformation_lr_init : 0.00016
deformation_lr_final : 0.000016
deformation_lr_delay_mult : 0.02
grid_lr_init : 0.0016
grid_lr_final : 0.00016
### Gaussian splatting
num_pts: 10000
sh_degree: 0
position_lr_init : 0.0002
position_lr_final : 0.000002
position_lr_delay_mult: 0.01
position_lr_max_steps: 2000
position_lr_max_steps2: 5000

feature_lr: 0.005
opacity_lr: 0.02
scaling_lr: 0.01
rotation_lr: 0.002
init_steps: 700

percent_dense: 0.1
density_start_iter: 1200
density_end_iter: 6000
densification_interval: 100
opacity_reset_interval: 700
densify_grad_threshold_percent: 0.025

time_smoothness_weight: 5
plane_tv_weight: 0.05
l1_time_planes: 0.05


### Textured Mesh
geom_lr: 0.0001
texture_lr: 0.2

================================================
FILE: dataset_4d.py
================================================
import os
import cv2
import glob
import json
import tqdm
import random
import numpy as np
from scipy.spatial.transform import Slerp, Rotation


import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import rembg
import glob
class SparseDataset:
    def __init__(self, opt, size,device='cuda', type='train', H=256, W=256):
        super().__init__()
        
        self.opt = opt
        self.device = device
        self.type = type # train, val, test
        self.size = size
        self.H = H
        self.W = W
        self.path = opt.path
        
        self.cx = self.H / 2
        self.cy = self.W / 2
        self.bg_remover=None

    def collate_ref(self,index):
        #print(index,str(index))
        file = os.path.join(self.path,'ref/{}_rgba.png'.format(str(index)))
        #print(f'[INFO] load image from {file}...')

        img = cv2.imread(file, cv2.IMREAD_UNCHANGED)
        if img.shape[-1] == 3:
            if self.bg_remover is None:
                self.bg_remover = rembg.new_session()
            img = rembg.remove(img, session=self.bg_remover)

        img = cv2.resize(img, (self.W, self.H), interpolation=cv2.INTER_AREA)
        img = img.astype(np.float32) / 255.0

        self.input_mask = img[..., 3:]
        # white bg
        self.input_img = img[..., :3] * self.input_mask + (1 - self.input_mask)
        # bgr to rgb
        self.input_img = self.input_img[..., ::-1].copy()

        return self.input_img ,self.input_mask

                
    def collate_zero123(self,index):

        self.pattern=os.path.join(self.path,'zero123/{}_rgba/*.png'.format(str(index)))
        self.input_imgs=[]
        self.input_masks=[]
        file_list = glob.glob(self.pattern)
        #print(self.pattern,file_list)
        for files in sorted(file_list):
                    
                   
                    #print(f'[INFO] load image from {files}...')
                    img = cv2.imread(files, cv2.IMREAD_UNCHANGED)
                    if img.shape[-1] == 3:
                        if self.bg_remover is None:
                            self.bg_remover = rembg.new_session()
                        img = rembg.remove(img, session=self.bg_remover)

                    img = cv2.resize(img, (self.W, self.H), interpolation=cv2.INTER_AREA)
                    img = img.astype(np.float32) / 255.0

                    self.input_mask = img[..., 3:]
                    # white bg
                    self.input_img = img[..., :3] * self.input_mask + (1 - self.input_mask)
                    # bgr to rgb
                    self.input_img = self.input_img[..., ::-1].copy()
                    
                    self.input_imgs.append(self.input_img)
                    self.input_masks.append(self.input_mask)
        return self.input_imgs, self.input_masks
    
    def collate(self, index):
        ref_view_batch,input_mask_batch,zero123_view_batch,zero123_masks_batch = [],[],[],[]
        for index in np.arange(self.size):
            ref_view,input_mask = self.collate_ref(index)
            zero123_view,zero123_masks = self.collate_zero123(index)
            ref_view_batch.append(ref_view)
            input_mask_batch.append(input_mask)
            zero123_view_batch.append(zero123_view)
            zero123_masks_batch.append(zero123_masks)
        return ref_view_batch, input_mask_batch,zero123_view_batch,zero123_masks_batch
    

    def dataloader(self):
        loader = DataLoader(list(range(self.size)), batch_size=1, collate_fn=self.collate,shuffle=False, num_workers=0)
        return loader

    def dataloader_d(self):
        loader = DataLoader(list(range(self.size)), batch_size=1, collate_fn=self.collate_d,shuffle=False, num_workers=0)
        return loader

================================================
FILE: deform.py
================================================

import functools
import math
import os
import time
from tkinter import W

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.cpp_extension import load
import torch.nn.init as init
import abc
            
import itertools
import logging as log
from typing import Optional, Union, List, Dict, Sequence, Iterable, Collection, Callable

import torch
import torch.nn as nn
import torch.nn.functional as F
class Deformation(nn.Module):
    def __init__(self, D=8, W=256, input_ch=27, input_ch_time=9, skips=[], args=None):
        super(Deformation, self).__init__()
        self.D = D
        self.W = W
        self.input_ch = input_ch
        self.input_ch_time = input_ch_time
        self.skips = skips
        self.no_grid=False
        self.no_ds=False
        self.no_dr=False
        self.no_do=True
        self.bounds = 1.6
        self.kplanes_config = {
                             'grid_dimensions': 2,
                             'input_coordinate_dim': 4,
                             'output_coordinate_dim': 32,
                             'resolution': [64, 64, 64, 25]
                            }
        self.multires = [1, 2, 4, 8]
        self.no_grid = self.no_grid
        self.grid = HexPlaneField(self.bounds, self.kplanes_config, self.multires)
        self.pos_deform, self.scales_deform, self.rotations_deform, self.opacity_deform = self.create_net()

    def create_net(self):
        
        mlp_out_dim = 0
        if self.no_grid:
            self.feature_out = [nn.Linear(4,self.W)]
        else:
            self.feature_out = [nn.Linear(mlp_out_dim + self.grid.feat_dim ,self.W)]
        
        for i in range(self.D-1):
            self.feature_out.append(nn.ReLU())
            self.feature_out.append(nn.Linear(self.W,self.W))
        self.feature_out = nn.Sequential(*self.feature_out)
        output_dim = self.W
        return  \
            nn.Sequential(nn.ReLU(),nn.Linear(self.W,self.W),nn.ReLU(),nn.Linear(self.W, 3)),\
            nn.Sequential(nn.ReLU(),nn.Linear(self.W,self.W),nn.ReLU(),nn.Linear(self.W, 3)),\
            nn.Sequential(nn.ReLU(),nn.Linear(self.W,self.W),nn.ReLU(),nn.Linear(self.W, 4)), \
            nn.Sequential(nn.ReLU(),nn.Linear(self.W,self.W),nn.ReLU(),nn.Linear(self.W, 1))
    
    def query_time(self, rays_pts_emb, scales_emb, rotations_emb, time_emb):

        if self.no_grid:
            h = torch.cat([rays_pts_emb[:,:3],time_emb[:,:1]],-1)
        else:
            grid_feature = self.grid(rays_pts_emb[:,:3], time_emb[:,:1])

            h = grid_feature
        
        h = self.feature_out(h)
  
        return h

    def forward(self, rays_pts_emb, scales_emb=None, rotations_emb=None, opacity = None, time_emb=None):
        if time_emb is None:
            return self.forward_static(rays_pts_emb[:,:3])
        else:
            return self.forward_dynamic(rays_pts_emb, scales_emb, rotations_emb, opacity, time_emb)

    def forward_static(self, rays_pts_emb):
        grid_feature = self.grid(rays_pts_emb[:,:3])
        dx = self.static_mlp(grid_feature)
        return rays_pts_emb[:, :3] + dx
    def forward_dynamic(self,rays_pts_emb, scales_emb, rotations_emb, opacity_emb, time_emb):
        hidden = self.query_time(rays_pts_emb, scales_emb, rotations_emb, time_emb).float()
        dx = self.pos_deform(hidden)
        pts = rays_pts_emb[:, :3] + dx
        if self.no_ds:
            scales = scales_emb[:,:3]
        else:
            ds = self.scales_deform(hidden)
            scales = scales_emb[:,:3] + ds
        if self.no_dr:
            rotations = rotations_emb[:,:4]
        else:
            dr = self.rotations_deform(hidden)
            rotations = rotations_emb[:,:4] + dr
        if self.no_do:
            opacity = opacity_emb[:,:1] 
        else:
            do = self.opacity_deform(hidden) 
            opacity = opacity_emb[:,:1] + do
        # + do
        # print("deformation value:","pts:",torch.abs(dx).mean(),"rotation:",torch.abs(dr).mean())

        return pts, scales, rotations, opacity
    def get_mlp_parameters(self):
        parameter_list = []
        for name, param in self.named_parameters():
            if  "grid" not in name:
                parameter_list.append(param)
        return parameter_list
    def get_grid_parameters(self):
        return list(self.grid.parameters() ) 
    # + list(self.timegrid.parameters())
class deform_network(nn.Module):
    def __init__(self) :
        super(deform_network, self).__init__()
        net_width = 64
        timebase_pe = 4
        defor_depth= 1
        posbase_pe= 10
        scale_rotation_pe = 2
        opacity_pe = 2
        timenet_width = 64
        timenet_output = 32
        times_ch = 2*timebase_pe+1
        self.timenet = nn.Sequential(
        nn.Linear(times_ch, timenet_width), nn.ReLU(),
        nn.Linear(timenet_width, timenet_output))
        self.deformation_net = Deformation(W=net_width, D=defor_depth, input_ch=(4+3)+((4+3)*scale_rotation_pe)*2, input_ch_time=timenet_output, args=None)
        self.register_buffer('time_poc', torch.FloatTensor([(2**i) for i in range(timebase_pe)]))
        self.register_buffer('pos_poc', torch.FloatTensor([(2**i) for i in range(posbase_pe)]))
        self.register_buffer('rotation_scaling_poc', torch.FloatTensor([(2**i) for i in range(scale_rotation_pe)]))
        self.register_buffer('opacity_poc', torch.FloatTensor([(2**i) for i in range(opacity_pe)]))
        self.apply(initialize_weights)
        # print(self)

    def forward(self, point, scales=None, rotations=None, opacity=None, times_sel=None):
        if times_sel is not None:
            return self.forward_dynamic(point, scales, rotations, opacity, times_sel)
        else:
            return self.forward_static(point)

        
    def forward_static(self, points):
        points = self.deformation_net(points)
        return points
    def forward_dynamic(self, point, scales=None, rotations=None, opacity=None, times_sel=None):
        # times_emb = poc_fre(times_sel, self.time_poc)

        means3D, scales, rotations, opacity = self.deformation_net( point,
                                                  scales,
                                                rotations,
                                                opacity,
                                                # times_feature,
                                                times_sel)
        return means3D, scales, rotations, opacity
    def get_mlp_parameters(self):
        return self.deformation_net.get_mlp_parameters() + list(self.timenet.parameters())
    def get_grid_parameters(self):
        return self.deformation_net.get_grid_parameters()

def initialize_weights(m):
    if isinstance(m, nn.Linear):
        # init.constant_(m.weight, 0)
        init.xavier_uniform_(m.weight,gain=1)
        if m.bias is not None:
            init.xavier_uniform_(m.weight,gain=1)
            # init.constant_(m.bias, 0)



def get_normalized_directions(directions):
    """SH encoding must be in the range [0, 1]

    Args:
        directions: batch of directions
    """
    return (directions + 1.0) / 2.0


def normalize_aabb(pts, aabb):
    return (pts - aabb[0]) * (2.0 / (aabb[1] - aabb[0])) - 1.0
def grid_sample_wrapper(grid: torch.Tensor, coords: torch.Tensor, align_corners: bool = True) -> torch.Tensor:
    grid_dim = coords.shape[-1]

    if grid.dim() == grid_dim + 1:
        # no batch dimension present, need to add it
        grid = grid.unsqueeze(0)
    if coords.dim() == 2:
        coords = coords.unsqueeze(0)

    if grid_dim == 2 or grid_dim == 3:
        grid_sampler = F.grid_sample
    else:
        raise NotImplementedError(f"Grid-sample was called with {grid_dim}D data but is only "
                                  f"implemented for 2 and 3D data.")

    coords = coords.view([coords.shape[0]] + [1] * (grid_dim - 1) + list(coords.shape[1:]))
    B, feature_dim = grid.shape[:2]
    n = coords.shape[-2]
    interp = grid_sampler(
        grid,  # [B, feature_dim, reso, ...]
        coords,  # [B, 1, ..., n, grid_dim]
        align_corners=align_corners,
        mode='bilinear', padding_mode='border')
    interp = interp.view(B, feature_dim, n).transpose(-1, -2)  # [B, n, feature_dim]
    interp = interp.squeeze()  # [B?, n, feature_dim?]
    return interp

def init_grid_param(
        grid_nd: int,
        in_dim: int,
        out_dim: int,
        reso: Sequence[int],
        a: float = 0.1,
        b: float = 0.5):
    assert in_dim == len(reso), "Resolution must have same number of elements as input-dimension"
    has_time_planes = in_dim == 4
    assert grid_nd <= in_dim
    coo_combs = list(itertools.combinations(range(in_dim), grid_nd))
    grid_coefs = nn.ParameterList()
    for ci, coo_comb in enumerate(coo_combs):
        new_grid_coef = nn.Parameter(torch.empty(
            [1, out_dim] + [reso[cc] for cc in coo_comb[::-1]]
        ))
        if has_time_planes and 3 in coo_comb:  # Initialize time planes to 1
            nn.init.ones_(new_grid_coef)
        else:
            nn.init.uniform_(new_grid_coef, a=a, b=b)
        grid_coefs.append(new_grid_coef)

    return grid_coefs


def interpolate_ms_features(pts: torch.Tensor,
                            ms_grids: Collection[Iterable[nn.Module]],
                            grid_dimensions: int,
                            concat_features: bool,
                            num_levels: Optional[int],
                            ) -> torch.Tensor:
    coo_combs = list(itertools.combinations(
        range(pts.shape[-1]), grid_dimensions)
    )
    if num_levels is None:
        num_levels = len(ms_grids)
    multi_scale_interp = [] if concat_features else 0.
    grid: nn.ParameterList
    for scale_id,  grid in enumerate(ms_grids[:num_levels]):
        interp_space = 1.
        for ci, coo_comb in enumerate(coo_combs):
            # interpolate in plane
            feature_dim = grid[ci].shape[1]  # shape of grid[ci]: 1, out_dim, *reso
            interp_out_plane = (
                grid_sample_wrapper(grid[ci], pts[..., coo_comb])
                .view(-1, feature_dim)
            )
            # compute product over planes
            interp_space = interp_space * interp_out_plane

        # combine over scales
        if concat_features:
            multi_scale_interp.append(interp_space)
        else:
            multi_scale_interp = multi_scale_interp + interp_space

    if concat_features:
        multi_scale_interp = torch.cat(multi_scale_interp, dim=-1)
    return multi_scale_interp


class HexPlaneField(nn.Module):
    def __init__(
        self,
        
        bounds,
        planeconfig,
        multires
    ) -> None:
        super().__init__()
        aabb = torch.tensor([[bounds,bounds,bounds],
                             [-bounds,-bounds,-bounds]])
        self.aabb = nn.Parameter(aabb, requires_grad=False)
        self.grid_config =  [planeconfig]
        self.multiscale_res_multipliers = multires
        self.concat_features = True

        # 1. Init planes
        self.grids = nn.ModuleList()
        self.feat_dim = 0
        for res in self.multiscale_res_multipliers:
            # initialize coordinate grid
            config = self.grid_config[0].copy()
            # Resolution fix: multi-res only on spatial planes
            config["resolution"] = [
                r * res for r in config["resolution"][:3]
            ] + config["resolution"][3:]
            gp = init_grid_param(
                grid_nd=config["grid_dimensions"],
                in_dim=config["input_coordinate_dim"],
                out_dim=config["output_coordinate_dim"],
                reso=config["resolution"],
            )
            # shape[1] is out-dim - Concatenate over feature len for each scale
            if self.concat_features:
                self.feat_dim += gp[-1].shape[1]
            else:
                self.feat_dim = gp[-1].shape[1]
            self.grids.append(gp)
        # print(f"Initialized model grids: {self.grids}")
        print("feature_dim:",self.feat_dim)


    def set_aabb(self,xyz_max, xyz_min):
        aabb = torch.tensor([
            xyz_max,
            xyz_min
        ])
        self.aabb = nn.Parameter(aabb,requires_grad=True)
        print("Voxel Plane: set aabb=",self.aabb)

    def get_density(self, pts: torch.Tensor, timestamps: Optional[torch.Tensor] = None):
        """Computes and returns the densities."""

        pts = normalize_aabb(pts, self.aabb)
        pts = torch.cat((pts, timestamps), dim=-1)  # [n_rays, n_samples, 4]

        pts = pts.reshape(-1, pts.shape[-1])
        features = interpolate_ms_features(
            pts, ms_grids=self.grids,  # noqa
            grid_dimensions=self.grid_config[0]["grid_dimensions"],
            concat_features=self.concat_features, num_levels=None)
        if len(features) < 1:
            features = torch.zeros((0, 1)).to(features.device)


        return features

    def forward(self,
                pts: torch.Tensor,
                timestamps: Optional[torch.Tensor] = None):

        features = self.get_density(pts, timestamps)

        return features
    
def compute_plane_tv(t):
    batch_size, c, h, w = t.shape
    count_h = batch_size * c * (h - 1) * w
    count_w = batch_size * c * h * (w - 1)
    h_tv = torch.square(t[..., 1:, :] - t[..., :h-1, :]).sum()
    w_tv = torch.square(t[..., :, 1:] - t[..., :, :w-1]).sum()
    return 2 * (h_tv / count_h + w_tv / count_w)  # This is summing over batch and c instead of avg


def compute_plane_smoothness(t):
    batch_size, c, h, w = t.shape
    # Convolve with a second derivative filter, in the time dimension which is dimension 2
    first_difference = t[..., 1:, :] - t[..., :h-1, :]  # [batch, c, h-1, w]
    second_difference = first_difference[..., 1:, :] - first_difference[..., :h-2, :]  # [batch, c, h-2, w]
    # Take the L2 norm of the result
    return torch.square(torch.abs(second_difference)).mean()


class Regularizer():
    def __init__(self, reg_type, initialization):
        self.reg_type = reg_type
        self.initialization = initialization
        self.weight = float(self.initialization)
        self.last_reg = None

    def step(self, global_step):
        pass

    def report(self, d):
        if self.last_reg is not None:
            d[self.reg_type].update(self.last_reg.item())

    def regularize(self, *args, **kwargs) -> torch.Tensor:
        out = self._regularize(*args, **kwargs) * self.weight
        self.last_reg = out.detach()
        return out

    @abc.abstractmethod
    def _regularize(self, *args, **kwargs) -> torch.Tensor:
        raise NotImplementedError()

    def __str__(self):
        return f"Regularizer({self.reg_type}, weight={self.weight})"


class PlaneTV(Regularizer):
    def __init__(self, initial_value, what: str = 'field'):
        if what not in {'field', 'proposal_network'}:
            raise ValueError(f'what must be one of "field" or "proposal_network" '
                             f'but {what} was passed.')
        name = f'planeTV-{what[:2]}'
        super().__init__(name, initial_value)
        self.what = what

    def step(self, global_step):
        pass

    def _regularize(self, model, **kwargs):
        multi_res_grids: Sequence[nn.ParameterList]
        if self.what == 'field':
            multi_res_grids = model.field.grids
        elif self.what == 'proposal_network':
            multi_res_grids = [p.grids for p in model.proposal_networks]
        else:
            raise NotImplementedError(self.what)
        total = 0
        # Note: input to compute_plane_tv should be of shape [batch_size, c, h, w]
        for grids in multi_res_grids:
            if len(grids) == 3:
                spatial_grids = [0, 1, 2]
            else:
                spatial_grids = [0, 1, 3]  # These are the spatial grids; the others are spatiotemporal
            for grid_id in spatial_grids:
                total += compute_plane_tv(grids[grid_id])
            for grid in grids:
                # grid: [1, c, h, w]
                total += compute_plane_tv(grid)
        return total


class TimeSmoothness(Regularizer):
    def __init__(self, initial_value, what: str = 'field'):
        if what not in {'field', 'proposal_network'}:
            raise ValueError(f'what must be one of "field" or "proposal_network" '
                             f'but {what} was passed.')
        name = f'time-smooth-{what[:2]}'
        super().__init__(name, initial_value)
        self.what = what

    def _regularize(self, model, **kwargs) -> torch.Tensor:
        multi_res_grids: Sequence[nn.ParameterList]
        if self.what == 'field':
            multi_res_grids = model.field.grids
        elif self.what == 'proposal_network':
            multi_res_grids = [p.grids for p in model.proposal_networks]
        else:
            raise NotImplementedError(self.what)
        total = 0
        # model.grids is 6 x [1, rank * F_dim, reso, reso]
        for grids in multi_res_grids:
            if len(grids) == 3:
                time_grids = []
            else:
                time_grids = [2, 4, 5]
            for grid_id in time_grids:
                total += compute_plane_smoothness(grids[grid_id])
        return torch.as_tensor(total)



class L1ProposalNetwork(Regularizer):
    def __init__(self, initial_value):
        super().__init__('l1-proposal-network', initial_value)

    def _regularize(self, model, **kwargs) -> torch.Tensor:
        grids = [p.grids for p in model.proposal_networks]
        total = 0.0
        for pn_grids in grids:
            for grid in pn_grids:
                total += torch.abs(grid).mean()
        return torch.as_tensor(total)


class DepthTV(Regularizer):
    def __init__(self, initial_value):
        super().__init__('tv-depth', initial_value)

    def _regularize(self, model, model_out, **kwargs) -> torch.Tensor:
        depth = model_out['depth']
        tv = compute_plane_tv(
            depth.reshape(64, 64)[None, None, :, :]
        )
        return tv


class L1TimePlanes(Regularizer):
    def __init__(self, initial_value, what='field'):
        if what not in {'field', 'proposal_network'}:
            raise ValueError(f'what must be one of "field" or "proposal_network" '
                             f'but {what} was passed.')
        super().__init__(f'l1-time-{what[:2]}', initial_value)
        self.what = what

    def _regularize(self, model, **kwargs) -> torch.Tensor:
        # model.grids is 6 x [1, rank * F_dim, reso, reso]
        multi_res_grids: Sequence[nn.ParameterList]
        if self.what == 'field':
            multi_res_grids = model.field.grids
        elif self.what == 'proposal_network':
            multi_res_grids = [p.grids for p in model.proposal_networks]
        else:
            raise NotImplementedError(self.what)

        total = 0.0
        for grids in multi_res_grids:
            if len(grids) == 3:
                continue
            else:
                # These are the spatiotemporal grids
                spatiotemporal_grids = [2, 4, 5]
            for grid_id in spatiotemporal_grids:
                total += torch.abs(1 - grids[grid_id]).mean()
        return torch.as_tensor(total)



================================================
FILE: gs_renderer_4d.py
================================================
import os
import math
import numpy as np
from typing import NamedTuple
from plyfile import PlyData, PlyElement

import torch
from torch import nn

from diff_gauss import (
    GaussianRasterizationSettings,
    GaussianRasterizer,
)
from simple_knn._C import distCUDA2

from sh_utils import eval_sh, SH2RGB, RGB2SH

from deform import *
def inverse_sigmoid(x):
    return torch.log(x/(1-x))




def get_expon_lr_func(
    lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000
):
    
    def helper(step):
        if lr_init == lr_final:
            # constant lr, ignore other params
            return lr_init
        if step < 0 or (lr_init == 0.0 and lr_final == 0.0):
            # Disable this parameter
            return 0.0
        if lr_delay_steps > 0:
            # A kind of reverse cosine decay.
            delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
                0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)
            )
        else:
            delay_rate = 1.0
        t = np.clip(step / max_steps, 0, 1)
        log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
        return delay_rate * log_lerp

    return helper


def strip_lowerdiag(L):
    uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device="cuda")

    uncertainty[:, 0] = L[:, 0, 0]
    uncertainty[:, 1] = L[:, 0, 1]
    uncertainty[:, 2] = L[:, 0, 2]
    uncertainty[:, 3] = L[:, 1, 1]
    uncertainty[:, 4] = L[:, 1, 2]
    uncertainty[:, 5] = L[:, 2, 2]
    return uncertainty

def strip_symmetric(sym):
    return strip_lowerdiag(sym)

def gaussian_3d_coeff(xyzs, covs):
    # xyzs: [N, 3]
    # covs: [N, 6]
    x, y, z = xyzs[:, 0], xyzs[:, 1], xyzs[:, 2]
    a, b, c, d, e, f = covs[:, 0], covs[:, 1], covs[:, 2], covs[:, 3], covs[:, 4], covs[:, 5]

    # eps must be small enough !!!
    inv_det = 1 / (a * d * f + 2 * e * c * b - e**2 * a - c**2 * d - b**2 * f + 1e-24)
    inv_a = (d * f - e**2) * inv_det
    inv_b = (e * c - b * f) * inv_det
    inv_c = (e * b - c * d) * inv_det
    inv_d = (a * f - c**2) * inv_det
    inv_e = (b * c - e * a) * inv_det
    inv_f = (a * d - b**2) * inv_det

    power = -0.5 * (x**2 * inv_a + y**2 * inv_d + z**2 * inv_f) - x * y * inv_b - x * z * inv_c - y * z * inv_e

    power[power > 0] = -1e10 # abnormal values... make weights 0
        
    return torch.exp(power)

def build_rotation(r):
    norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3])

    q = r / norm[:, None]

    R = torch.zeros((q.size(0), 3, 3), device='cuda')

    r = q[:, 0]
    x = q[:, 1]
    y = q[:, 2]
    z = q[:, 3]

    R[:, 0, 0] = 1 - 2 * (y*y + z*z)
    R[:, 0, 1] = 2 * (x*y - r*z)
    R[:, 0, 2] = 2 * (x*z + r*y)
    R[:, 1, 0] = 2 * (x*y + r*z)
    R[:, 1, 1] = 1 - 2 * (x*x + z*z)
    R[:, 1, 2] = 2 * (y*z - r*x)
    R[:, 2, 0] = 2 * (x*z - r*y)
    R[:, 2, 1] = 2 * (y*z + r*x)
    R[:, 2, 2] = 1 - 2 * (x*x + y*y)
    return R

def build_scaling_rotation(s, r):
    L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device="cuda")
    R = build_rotation(r)

    L[:,0,0] = s[:,0]
    L[:,1,1] = s[:,1]
    L[:,2,2] = s[:,2]

    L = R @ L
    return L

class BasicPointCloud(NamedTuple):
    points: np.array
    colors: np.array
    normals: np.array


class GaussianModel:

    def setup_functions(self):
        def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
            L = build_scaling_rotation(scaling_modifier * scaling, rotation)
            actual_covariance = L @ L.transpose(1, 2)
            symm = strip_symmetric(actual_covariance)
            return symm
        
        self.scaling_activation = torch.exp
        self.scaling_inverse_activation = torch.log

        self.covariance_activation = build_covariance_from_scaling_rotation

        self.opacity_activation = torch.sigmoid
        self.inverse_opacity_activation = inverse_sigmoid

        self.rotation_activation = torch.nn.functional.normalize
        
    def initialize(self, initial_values, raw=False):
        # NOTE: actual initialization is done in trainer
        # raw stands for raw values, i.e. not passed through activation
        self._xyz = nn.Parameter(initial_values["mean"].requires_grad_(True)).to('cuda')
        self._rotation = nn.Parameter(initial_values["qvec"].requires_grad_(True)).to('cuda')

        #self._scaling = nn.Parameter(initial_values["svec"].requires_grad_(True)).to('cuda')
        #self._features_dc = nn.Parameter(initial_values["color"].requires_grad_(True)).to('cuda')
        self._opacity = nn.Parameter(initial_values["alpha"].requires_grad_(True)).to('cuda')

    
    def __init__(self, sh_degree : int,args = None):
        self.active_sh_degree = 0
        self.max_sh_degree = sh_degree  
        self._xyz = torch.empty(0)
        self._features_dc = torch.empty(0)
        self._features_rest = torch.empty(0)
        self._scaling = torch.empty(0)
        self._rotation = torch.empty(0)
        self._opacity = torch.empty(0)
        self.max_radii2D = torch.empty(0)
        self.xyz_gradient_accum = torch.empty(0)
        self.denom = torch.empty(0)
        self.optimizer = None
        self.percent_dense = 0
        self.spatial_lr_scale = 0
        self._deformation_table = torch.empty(0)
        self._deformation = deform_network()
        self.setup_functions()

    def capture(self):
        return (
            self.active_sh_degree,
            self._xyz,
            self._deformation.state_dict(),
            self._deformation_table,
            self._features_dc,
            self._features_rest,
            self._scaling,
            self._rotation,
            self._opacity,
            self.max_radii2D,
            self.xyz_gradient_accum,
            self.denom,
            self.optimizer.state_dict(),
            self.spatial_lr_scale,
        )
    
    def restore(self, model_args, training_args):
        (self.active_sh_degree, 
        self._xyz, 
        self._deformation_table,
        self._deformation,
        self._features_dc, 
        self._features_rest,
        self._scaling, 
        self._rotation, 
        self._opacity,
        self.max_radii2D, 
        xyz_gradient_accum, 
        denom,
        opt_dict, 
        self.spatial_lr_scale) = model_args
        self.training_setup(training_args)
        self.xyz_gradient_accum = xyz_gradient_accum
        self.denom = denom
        self.optimizer.load_state_dict(opt_dict)

    @property
    def get_scaling(self):
        return self.scaling_activation(self._scaling)
    
    @property
    def get_rotation(self):
        return self.rotation_activation(self._rotation)
    
    @property
    def get_xyz(self):
        return self._xyz
    
    @property
    def get_features(self):
        features_dc = self._features_dc
        features_rest = self._features_rest
        return torch.cat((features_dc, features_rest), dim=1)
    
    @property
    def get_opacity(self):
        return self.opacity_activation(self._opacity)


    

    
    def get_covariance(self, scaling_modifier = 1):
        return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation)

    def oneupSHdegree(self):
        if self.active_sh_degree < self.max_sh_degree:
            self.active_sh_degree += 1

    def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float = 1):
        self.spatial_lr_scale = spatial_lr_scale
        fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda()
        fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda())
        features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 2)).float().cuda()
        features[:, :3, 0 ] = fused_color
        features[:, 3:, 1:] = 0.0

        print("Number of points at initialisation : ", fused_point_cloud.shape[0])

        dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001)
        scales = torch.log(torch.sqrt(dist2))[...,None].repeat(1, 3)
        rots = torch.zeros((fused_point_cloud.shape[0], 4), device="cuda")
        rots[:, 0] = 1

        opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device="cuda"))

        self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))
        self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True))
        self._features_rest = nn.Parameter(features[:,:,1:].transpose(1, 2).contiguous().requires_grad_(True))
        self._scaling = nn.Parameter(scales.requires_grad_(True))
        self._rotation = nn.Parameter(rots.requires_grad_(True))
        self._opacity = nn.Parameter(opacities.requires_grad_(True))
        self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
        self._deformation_table = torch.gt(torch.ones((self.get_xyz.shape[0]),device="cuda"),0)
        #print(self._xyz.shape,self._rotation.shape)
        self._deformation = self._deformation.to("cuda") 
        self.active_sh_degree = self.max_sh_degree
        self._deformation_accum = torch.zeros((self.get_xyz.shape[0],3),device="cuda")
        self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")

    def training_setup(self, training_args):
        self.percent_dense = training_args.percent_dense
        self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
        self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
        self._deformation_accum = torch.zeros((self.get_xyz.shape[0],3),device="cuda")
        

        l = [
            {'params': [self._xyz], 'lr': training_args.position_lr_init * self.spatial_lr_scale, "name": "xyz"},
            {'params': list(self._deformation.get_mlp_parameters()), 'lr': training_args.deformation_lr_init * self.spatial_lr_scale, "name": "deformation"},
            {'params': list(self._deformation.get_grid_parameters()), 'lr': training_args.grid_lr_init * self.spatial_lr_scale, "name": "grid"},
            {'params': [self._features_dc], 'lr': training_args.feature_lr, "name": "f_dc"},
            {'params': [self._features_rest], 'lr': training_args.feature_lr / 20.0, "name": "f_rest"},
            {'params': [self._opacity], 'lr': training_args.opacity_lr, "name": "opacity"},
            {'params': [self._scaling], 'lr': training_args.scaling_lr, "name": "scaling"},
            {'params': [self._rotation], 'lr': training_args.rotation_lr, "name": "rotation"}
            
        ]

        self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15)
        self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init*self.spatial_lr_scale,
                                                    lr_final=training_args.position_lr_final*self.spatial_lr_scale,
                                                    lr_delay_mult=training_args.position_lr_delay_mult,
                                                    max_steps=training_args.position_lr_max_steps)
        self.deformation_scheduler_args = get_expon_lr_func(lr_init=training_args.deformation_lr_init*self.spatial_lr_scale,
                                                    lr_final=training_args.deformation_lr_final*self.spatial_lr_scale,
                                                    lr_delay_mult=training_args.deformation_lr_delay_mult,
                                                    max_steps=training_args.position_lr_max_steps)    
        self.grid_scheduler_args = get_expon_lr_func(lr_init=training_args.grid_lr_init*self.spatial_lr_scale,
                                                    lr_final=training_args.grid_lr_final*self.spatial_lr_scale,
                                                    lr_delay_mult=training_args.deformation_lr_delay_mult,
                                                    max_steps=training_args.position_lr_max_steps)    
    def update_learning_rate(self, iteration):
        ''' Learning rate scheduling per step '''
        for param_group in self.optimizer.param_groups:
            if param_group["name"] == "xyz":
                lr = self.xyz_scheduler_args(iteration)
                param_group['lr'] = lr
                # return lr
            if  "grid" in param_group["name"]:
                lr = self.grid_scheduler_args(iteration)
                param_group['lr'] = lr
                # return lr
            elif param_group["name"] == "deformation":
                lr = self.deformation_scheduler_args(iteration)
                param_group['lr'] = lr
                # return lr

    def construct_list_of_attributes(self):
        l = ['x', 'y', 'z', 'nx', 'ny', 'nz']
        # All channels except the 3 DC
        for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]):
            l.append('f_dc_{}'.format(i))
        for i in range(self._features_rest.shape[1]*self._features_rest.shape[2]):
            l.append('f_rest_{}'.format(i))
        l.append('opacity')
        for i in range(self._scaling.shape[1]):
            l.append('scale_{}'.format(i))
        for i in range(self._rotation.shape[1]):
            l.append('rot_{}'.format(i))
        return l

    def save_ply(self, path):
        os.makedirs(os.path.dirname(path), exist_ok=True)

        xyz = self._xyz.detach().cpu().numpy()
        normals = np.zeros_like(xyz)
        f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
        f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
        opacities = self._opacity.detach().cpu().numpy()
        scale = self._scaling.detach().cpu().numpy()
        rotation = self._rotation.detach().cpu().numpy()
        
        dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()]

        elements = np.empty(xyz.shape[0], dtype=dtype_full)
        attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1)
        elements[:] = list(map(tuple, attributes))
        el = PlyElement.describe(elements, 'vertex')
        PlyData([el]).write(path)

    def compute_deformation(self,time):
        
        deform = self._deformation[:,:,:time].sum(dim=-1)
        xyz = self._xyz + deform
        return xyz
    def load_model(self, path):
        print("loading model from exists{}".format(path))
        weight_dict = torch.load(os.path.join(path,"deformation.pth"),map_location="cuda")
        self._deformation.load_state_dict(weight_dict)
        self._deformation = self._deformation.to("cuda")
        self._deformation_table = torch.gt(torch.ones((self.get_xyz.shape[0]),device="cuda"),0)
        self._deformation_accum = torch.zeros((self.get_xyz.shape[0],3),device="cuda")
        if os.path.exists(os.path.join(path, "deformation_table.pth")):
            self._deformation_table = torch.load(os.path.join(path, "deformation_table.pth"),map_location="cuda")
        if os.path.exists(os.path.join(path, "deformation_accum.pth")):
            self._deformation_accum = torch.load(os.path.join(path, "deformation_accum.pth"),map_location="cuda")
        self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")
        self._deformation_table = torch.gt(torch.ones((self.get_xyz.shape[0]),device="cuda"),0)
        self._deformation = self._deformation.to("cuda") 
        

    def save_deformation(self, path):
        torch.save(self._deformation.state_dict(),os.path.join(path, "deformation.pth"))
        torch.save(self._deformation_table,os.path.join(path, "deformation_table.pth"))
        torch.save(self._deformation_accum,os.path.join(path, "deformation_accum.pth"))
        
    def load_ply(self, path):
        plydata = PlyData.read(path)

        xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
                        np.asarray(plydata.elements[0]["y"]),
                        np.asarray(plydata.elements[0]["z"])),  axis=1)
        opacities = np.asarray(plydata.elements[0]["opacity"])[..., np.newaxis]

        print("Number of points at loading : ", xyz.shape[0])

        features_dc = np.zeros((xyz.shape[0], 3, 1))
        features_dc[:, 0, 0] = np.asarray(plydata.elements[0]["f_dc_0"])
        features_dc[:, 1, 0] = np.asarray(plydata.elements[0]["f_dc_1"])
        features_dc[:, 2, 0] = np.asarray(plydata.elements[0]["f_dc_2"])

        extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("f_rest_")]
        assert len(extra_f_names)==3*(self.max_sh_degree + 1) ** 2 - 3
        features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))
        for idx, attr_name in enumerate(extra_f_names):
            features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name])
        # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC)
        features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1))

        scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("scale_")]
        scales = np.zeros((xyz.shape[0], len(scale_names)))
        for idx, attr_name in enumerate(scale_names):
            scales[:, idx] = np.asarray(plydata.elements[0][attr_name])

        rot_names = [p.name for p in plydata.elements[0].properties if p.name.startswith("rot")]
        rots = np.zeros((xyz.shape[0], len(rot_names)))
        for idx, attr_name in enumerate(rot_names):
            rots[:, idx] = np.asarray(plydata.elements[0][attr_name])

        self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device="cuda").requires_grad_(True))
        self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
        self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device="cuda").transpose(1, 2).contiguous().requires_grad_(True))
        self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device="cuda").requires_grad_(True))
        self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device="cuda").requires_grad_(True))
        self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device="cuda").requires_grad_(True))
        self._deformation_table = torch.gt(torch.ones((self.get_xyz.shape[0]),device="cuda"),0)
        #print(self._xyz.shape,self._rotation.shape)
        self._deformation = self._deformation.to("cuda") 
        self.active_sh_degree = self.max_sh_degree
        self._deformation_accum = torch.zeros((self.get_xyz.shape[0],3),device="cuda")
        self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")

        
    def replace_tensor_to_optimizer(self, tensor, name):
        optimizable_tensors = {}
        for group in self.optimizer.param_groups:
            if group["name"] == name:
                stored_state = self.optimizer.state.get(group['params'][0], None)
                stored_state["exp_avg"] = torch.zeros_like(tensor)
                stored_state["exp_avg_sq"] = torch.zeros_like(tensor)

                del self.optimizer.state[group['params'][0]]
                group["params"][0] = nn.Parameter(tensor.requires_grad_(True))
                self.optimizer.state[group['params'][0]] = stored_state

                optimizable_tensors[group["name"]] = group["params"][0]
        return optimizable_tensors

    def _prune_optimizer(self, mask):
        optimizable_tensors = {}
        for group in self.optimizer.param_groups:
            if len(group["params"]) > 1:
                continue
            stored_state = self.optimizer.state.get(group['params'][0], None)
            if stored_state is not None:
                stored_state["exp_avg"] = stored_state["exp_avg"][mask]
                stored_state["exp_avg_sq"] = stored_state["exp_avg_sq"][mask]

                del self.optimizer.state[group['params'][0]]
                group["params"][0] = nn.Parameter((group["params"][0][mask].requires_grad_(True)))
                self.optimizer.state[group['params'][0]] = stored_state

                optimizable_tensors[group["name"]] = group["params"][0]
            else:
                group["params"][0] = nn.Parameter(group["params"][0][mask].requires_grad_(True))
                optimizable_tensors[group["name"]] = group["params"][0]
        return optimizable_tensors

    def prune_points(self, mask):
        valid_points_mask = ~mask
        optimizable_tensors = self._prune_optimizer(valid_points_mask)

        self._xyz = optimizable_tensors["xyz"]
        self._features_dc = optimizable_tensors["f_dc"]
        self._features_rest = optimizable_tensors["f_rest"]
        self._opacity = optimizable_tensors["opacity"]
        self._scaling = optimizable_tensors["scaling"]
        self._rotation = optimizable_tensors["rotation"]
        self._deformation_accum = self._deformation_accum[valid_points_mask]
        self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask]
        self._deformation_table = self._deformation_table[valid_points_mask]
        self.denom = self.denom[valid_points_mask]
        self.max_radii2D = self.max_radii2D[valid_points_mask]

    def cat_tensors_to_optimizer(self, tensors_dict):
        optimizable_tensors = {}
        for group in self.optimizer.param_groups:
            if len(group["params"])>1:continue
            assert len(group["params"]) == 1
            extension_tensor = tensors_dict[group["name"]]
            stored_state = self.optimizer.state.get(group['params'][0], None)
            if stored_state is not None:

                stored_state["exp_avg"] = torch.cat((stored_state["exp_avg"], torch.zeros_like(extension_tensor)), dim=0)
                stored_state["exp_avg_sq"] = torch.cat((stored_state["exp_avg_sq"], torch.zeros_like(extension_tensor)), dim=0)

                del self.optimizer.state[group['params'][0]]
                group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
                self.optimizer.state[group['params'][0]] = stored_state

                optimizable_tensors[group["name"]] = group["params"][0]
            else:
                group["params"][0] = nn.Parameter(torch.cat((group["params"][0], extension_tensor), dim=0).requires_grad_(True))
                optimizable_tensors[group["name"]] = group["params"][0]

        return optimizable_tensors

    def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation, new_deformation_table):
        d = {"xyz": new_xyz,
        "f_dc": new_features_dc,
        "f_rest": new_features_rest,
        "opacity": new_opacities,
        "scaling" : new_scaling,
        "rotation" : new_rotation,
        # "deformation": new_deformation
       }

        optimizable_tensors = self.cat_tensors_to_optimizer(d)
        self._xyz = optimizable_tensors["xyz"]
        self._features_dc = optimizable_tensors["f_dc"]
        self._features_rest = optimizable_tensors["f_rest"]
        self._opacity = optimizable_tensors["opacity"]
        self._scaling = optimizable_tensors["scaling"]
        self._rotation = optimizable_tensors["rotation"]
        # self._deformation = optimizable_tensors["deformation"]
        
        self._deformation_table = torch.cat([self._deformation_table,new_deformation_table],-1)
        self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
        self._deformation_accum = torch.zeros((self.get_xyz.shape[0], 3), device="cuda")
        self.denom = torch.zeros((self.get_xyz.shape[0], 1), device="cuda")
        self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device="cuda")

    def densify_and_split(self, grads, grad_threshold, scene_extent, N=2):
        n_init_points = self.get_xyz.shape[0]
        # Extract points that satisfy the gradient condition
        padded_grad = torch.zeros((n_init_points), device="cuda")
        padded_grad[:grads.shape[0]] = grads.squeeze()
        selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False)
        selected_pts_mask = torch.logical_and(selected_pts_mask,
                                              torch.max(self.get_scaling, dim=1).values > self.percent_dense*scene_extent)
        if not selected_pts_mask.any():
            return
        stds = self.get_scaling[selected_pts_mask].repeat(N,1)
        means =torch.zeros((stds.size(0), 3),device="cuda")
        samples = torch.normal(mean=means, std=stds)
        rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N,1,1)
        new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[selected_pts_mask].repeat(N, 1)
        new_scaling = self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(N,1) / (0.8*N))
        new_rotation = self._rotation[selected_pts_mask].repeat(N,1)
        new_features_dc = self._features_dc[selected_pts_mask].repeat(N,1,1)
        new_features_rest = self._features_rest[selected_pts_mask].repeat(N,1,1)
        new_opacity = self._opacity[selected_pts_mask].repeat(N,1)
        new_deformation_table = self._deformation_table[selected_pts_mask].repeat(N)
        self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacity, new_scaling, new_rotation, new_deformation_table)

        prune_filter = torch.cat((selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device="cuda", dtype=bool)))
        self.prune_points(prune_filter)
        
    def densify_and_clone(self, grads, grad_threshold, scene_extent):
        selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False)
        selected_pts_mask = torch.logical_and(selected_pts_mask,
                                              torch.max(self.get_scaling, dim=1).values <= self.percent_dense*scene_extent)
        
        new_xyz = self._xyz[selected_pts_mask] 
        # - 0.001 * self._xyz.grad[selected_pts_mask]
        new_features_dc = self._features_dc[selected_pts_mask]
        new_features_rest = self._features_rest[selected_pts_mask]
        new_opacities = self._opacity[selected_pts_mask]
        new_scaling = self._scaling[selected_pts_mask]
        new_rotation = self._rotation[selected_pts_mask]
        new_deformation_table = self._deformation_table[selected_pts_mask]

        self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation, new_deformation_table)

    def densify_and_prune(self, max_grad_percent, min_opacity, extent, max_screen_size):
        grads = self.xyz_gradient_accum / self.denom
        grads[grads.isnan()] = 0.0
        grad_log = torch.log(grads)
        grad_log2=grad_log[~grad_log.isnan()]
        grad_log3=grad_log[~grad_log2.isinf()]

        max_grad_1 = torch.exp(grad_log3.mean()+grad_log3.var()) #adaptive densification with mean and var, unused
        max_grad_2 = torch.exp(grad_log3.squeeze(dim=1).sort(descending=True)[0][int(max_grad_percent*grad_log3.shape[0])]) #adaptive densification with relative grad
        max_grad = max_grad_2 #choose which to use

        #print('max_grad',max_grad_percent,max_grad_1,max_grad_2,grad_log3.mean(),grad_log3.var())
        self.densify_and_clone(grads, max_grad, extent)
        self.densify_and_split(grads, max_grad, extent)

        prune_mask = (self.get_opacity < min_opacity).squeeze()

        if max_screen_size:
            big_points_vs = self.max_radii2D > max_screen_size
            big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent
            small_ws = self.get_scaling.max(dim=1).values<0.001
            prune_mask = torch.logical_or(torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws),small_ws)
        self.prune_points(prune_mask)


        torch.cuda.empty_cache()


    
    def prune(self, min_opacity=0.01, extent=1, max_screen_size=1):

        prune_mask = (self.get_opacity < min_opacity).squeeze()
        # prune_mask_2 = torch.logical_and(self.get_opacity <= inverse_sigmoid(0.101 , dtype=torch.float, device="cuda"), self.get_opacity >= inverse_sigmoid(0.999 , dtype=torch.float, device="cuda"))
        # prune_mask = torch.logical_or(prune_mask, prune_mask_2)
        # deformation_sum = abs(self._deformation).sum(dim=-1).mean(dim=-1) 
        # deformation_mask = (deformation_sum < torch.quantile(deformation_sum, torch.tensor([0.5]).to("cuda")))
        # prune_mask = prune_mask & deformation_mask
        if max_screen_size:
            big_points_vs = self.max_radii2D > max_screen_size
            big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent
            #prune_mask = torch.logical_or(prune_mask, big_points_vs)
            small_ws = self.get_scaling.min(dim=1).values<0.001
            prune_mask = torch.logical_or(torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws),small_ws)
        self.prune_points(prune_mask)
        



    def standard_constaint(self):
        
        means3D = self._xyz.detach()
        scales = self._scaling.detach()
        rotations = self._rotation.detach()
        opacity = self._opacity.detach()
        time =  torch.tensor(0).to("cuda").repeat(means3D.shape[0],1)
        means3D_deform, scales_deform, rotations_deform, _ = self._deformation(means3D, scales, rotations, opacity, time)
        position_error = (means3D_deform - means3D)**2
        rotation_error = (rotations_deform - rotations)**2 
        scaling_erorr = (scales_deform - scales)**2
        return position_error.mean() + rotation_error.mean() + scaling_erorr.mean()


    def add_densification_stats(self, viewspace_point_tensor, update_filter):
        #print(viewspace_point_tensor,update_filter)
        self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor[update_filter,:2], dim=-1, keepdim=True)
        self.denom[update_filter] += 1
        
    @torch.no_grad()
    def update_deformation_table(self,threshold):
        # print("origin deformation point nums:",self._deformation_table.sum())
        self._deformation_table = torch.gt(self._deformation_accum.max(dim=-1).values/100,threshold)
    def print_deformation_weight_grad(self):
        for name, weight in self._deformation.named_parameters():
            if weight.requires_grad:
                if weight.grad is None:
                    
                    print(name," :",weight.grad)
                else:
                    if weight.grad.mean() != 0:
                        print(name," :",weight.grad.mean(), weight.grad.min(), weight.grad.max())
        print("-"*50)
    def _plane_regulation(self):
        multi_res_grids = self._deformation.deformation_net.grid.grids
        total = 0
        # model.grids is 6 x [1, rank * F_dim, reso, reso]
        for grids in multi_res_grids:
            if len(grids) == 3:
                time_grids = []
            else:
                time_grids =  [0,1,3]
            for grid_id in time_grids:
                total += compute_plane_smoothness(grids[grid_id])
        return total
    def _time_regulation(self):
        multi_res_grids = self._deformation.deformation_net.grid.grids
        total = 0
        # model.grids is 6 x [1, rank * F_dim, reso, reso]
        for grids in multi_res_grids:
            if len(grids) == 3:
                time_grids = []
            else:
                time_grids =[2, 4, 5]
            for grid_id in time_grids:
                total += compute_plane_smoothness(grids[grid_id])
        return total
    def _l1_regulation(self):
                # model.grids is 6 x [1, rank * F_dim, reso, reso]
        multi_res_grids = self._deformation.deformation_net.grid.grids

        total = 0.0
        for grids in multi_res_grids:
            if len(grids) == 3:
                continue
            else:
                # These are the spatiotemporal grids
                spatiotemporal_grids = [2, 4, 5]
            for grid_id in spatiotemporal_grids:
                total += torch.abs(1 - grids[grid_id]).mean()
        return total
    def compute_regulation(self, time_smoothness_weight, l1_time_planes_weight, plane_tv_weight):
        return plane_tv_weight * self._plane_regulation() + time_smoothness_weight * self._time_regulation() + l1_time_planes_weight * self._l1_regulation()

def getProjectionMatrix(znear, zfar, fovX, fovY):
    tanHalfFovY = math.tan((fovY / 2))
    tanHalfFovX = math.tan((fovX / 2))

    P = torch.zeros(4, 4)

    z_sign = 1.0

    P[0, 0] = 1 / tanHalfFovX
    P[1, 1] = 1 / tanHalfFovY
    P[3, 2] = z_sign
    P[2, 2] = z_sign * zfar / (zfar - znear)
    P[2, 3] = -(zfar * znear) / (zfar - znear)
    return P


class MiniCam:
    def __init__(self, c2w, width, height, fovy, fovx, znear, zfar,time=0 ):
        # c2w (pose) should be in NeRF convention.

        self.image_width = width
        self.image_height = height
        self.FoVy = fovy
        self.FoVx = fovx
        self.znear = znear
        self.zfar = zfar
        self.time = time
        w2c = np.linalg.inv(c2w)

        # rectify...
        w2c[1:3, :3] *= -1
        w2c[:3, 3] *= -1

        self.world_view_transform = torch.tensor(w2c).transpose(0, 1).cuda()
        self.projection_matrix = (
            getProjectionMatrix(
                znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy
            )
            .transpose(0, 1)
            .cuda()
        )
        self.full_proj_transform = self.world_view_transform @ self.projection_matrix
        self.camera_center = -torch.tensor(c2w[:3, 3]).cuda()


class Renderer:
    def __init__(self, sh_degree=3, white_background=True, radius=1):
        
        self.sh_degree = sh_degree
        self.white_background = white_background
        self.radius = radius

        self.gaussians = GaussianModel(sh_degree)

        self.bg_color = torch.tensor(
            [1, 1, 1] if white_background else [0, 0, 0],
            dtype=torch.float32,
            device="cuda",
        )
    
    def initialize(self, input=None, num_pts=5000, radius=0.5,initial_values=None):
        # load checkpoint
        if input is None:
            # init from random point cloud
            
            phis = np.random.random((num_pts,)) * 2 * np.pi
            costheta = np.random.random((num_pts,)) * 2 - 1
            thetas = np.arccos(costheta)
            mu = np.random.random((num_pts,))
            radius = radius * np.cbrt(mu)
            x = radius * np.sin(thetas) * np.cos(phis)
            y = radius * np.sin(thetas) * np.sin(phis)
            z = radius * np.cos(thetas)
            xyz = np.stack((x, y, z), axis=1)
            if initial_values is not None:
                print(xyz.shape,initial_values["mean"].shape)
                R = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]])
                xyz = np.dot(initial_values["mean"].numpy(),R)
            # xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3

            shs = np.random.random((num_pts, 3)) / 255.0
            pcd = BasicPointCloud(
                points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3))
            )
            self.gaussians.create_from_pcd(pcd, 10)
        elif isinstance(input, BasicPointCloud):
            # load from a provided pcd
            self.gaussians.create_from_pcd(input, 1)
        else:
            # load from saved ply
            self.gaussians.load_ply(input)

    def render(
        self,
        viewpoint_camera,
        scaling_modifier=1.0,
        bg_color=None,
        override_color=None,
        compute_cov3D_python=False,
        convert_SHs_python=False,
        stage="fine",
        time_int = None,
        front_view=False,
    ):

 
        # Create zero tensor. We will use it to make pytorch return gradients of the 2D (screen-space) means
        screenspace_points = torch.zeros_like(self.gaussians.get_xyz, dtype=self.gaussians.get_xyz.dtype, requires_grad=True, device="cuda") + 0
        try:
            screenspace_points.retain_grad()
        except:
            pass

        # Set up rasterization configuration

        tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)
        tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)

        raster_settings = GaussianRasterizationSettings(
            image_height=int(viewpoint_camera.image_height),
            image_width=int(viewpoint_camera.image_width),
            tanfovx=tanfovx,
            tanfovy=tanfovy,
            bg=self.bg_color if bg_color is None else bg_color,
            scale_modifier=scaling_modifier,
            viewmatrix=viewpoint_camera.world_view_transform,
            projmatrix=viewpoint_camera.full_proj_transform,
            sh_degree=self.gaussians.active_sh_degree,
            campos=viewpoint_camera.camera_center,
            prefiltered=False,
            debug=False,
        )
        if front_view==True:
            print(viewpoint_camera.world_view_transform,viewpoint_camera.full_proj_transform)
        rasterizer = GaussianRasterizer(raster_settings=raster_settings)


        means3D = self.gaussians.get_xyz
        time = torch.tensor(viewpoint_camera.time).to(means3D.device).repeat(means3D.shape[0],1)
        means2D = screenspace_points
        opacity = self.gaussians._opacity

        # If precomputed 3d covariance is provided, use it. If not, then it will be computed from
        # scaling / rotation by the rasterizer.
        scales = None
        rotations = None
        cov3D_precomp = None
        if compute_cov3D_python:
            cov3D_precomp = self.gaussians.get_covariance(scaling_modifier)
        else:
            scales = self.gaussians._scaling
            rotations = self.gaussians._rotation
        deformation_point = self.gaussians._deformation_table

        if stage == "coarse" :
            means3D_deform, scales_deform, rotations_deform, opacity_deform = means3D, scales, rotations, opacity
        else:
            means3D_deform, scales_deform, rotations_deform, opacity_deform = self.gaussians._deformation(means3D[deformation_point], scales[deformation_point], 
                                                                         rotations[deformation_point], opacity[deformation_point],
                                                                         time[deformation_point])
        # print(time.max())
        with torch.no_grad():
            self.gaussians._deformation_accum[deformation_point] += torch.abs(means3D_deform-means3D[deformation_point])
        #print(torch.abs(means3D_deform-means3D[deformation_point]).mean())
        means3D_final = torch.zeros_like(means3D)
        rotations_final = torch.zeros_like(rotations)
        scales_final = torch.zeros_like(scales)
        opacity_final = torch.zeros_like(opacity)
        means3D_final[deformation_point] =  means3D_deform
        rotations_final[deformation_point] =  rotations_deform
        scales_final[deformation_point] =  scales_deform
        opacity_final[deformation_point] = opacity_deform
        means3D_final[~deformation_point] = means3D[~deformation_point]
        rotations_final[~deformation_point] = rotations[~deformation_point]
        scales_final[~deformation_point] = scales[~deformation_point]
        opacity_final[~deformation_point] = opacity[~deformation_point]

        
        scales_in=scales_final
        rotations_in=rotations_final
        opacity_in = opacity_final
        
        scales_final = self.gaussians.scaling_activation(scales_final)
        rotations_final = self.gaussians.rotation_activation(rotations_final)
        opacity = self.gaussians.opacity_activation(opacity)
        opacity_final = self.gaussians.opacity_activation(opacity_final)

        # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors
        # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.
        shs = None
        colors_precomp = None

        shs = self.gaussians.get_features

        # Rasterize visible Gaussians to image, obtain their radii (on screen).
        rendered_image, rendered_depth, normal, rendered_alpha ,radii, _ = rasterizer(
        means3D = means3D_final,
        means2D = means2D,
        shs = shs,
        colors_precomp = colors_precomp,
        opacities = opacity_final,
        scales = scales_final,
        rotations = rotations_final,
        cov3Ds_precomp = cov3D_precomp)


        return {
            "image": rendered_image,
            "depth": rendered_depth,
            "alpha": rendered_alpha,
            "viewspace_points": screenspace_points,
            "visibility_filter": radii > 0,
            "radii": radii,
            'xyz':means3D_final,
            'rot':rotations_in,
            'xy':means2D,
            'color':shs,
            'scales':scales_in,
            'opacity':opacity_in,
        }



================================================
FILE: guidance/zero123_4d_utils.py
================================================
from transformers import CLIPTextModel, CLIPTokenizer, logging
from diffusers import (
    AutoencoderKL,
    UNet2DConditionModel,
    DDIMScheduler,
    StableDiffusionPipeline,
)
import torchvision.transforms.functional as TF

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import sys
sys.path.append('./')

from zero123 import Zero123Pipeline


class Zero123(nn.Module):
    def __init__(self, device, fp16=True, t_range=[0.02, 0.98]):
        super().__init__()

        self.device = device
        self.fp16 = fp16
        self.dtype = torch.float16 if fp16 else torch.float32

        self.pipe = Zero123Pipeline.from_pretrained(            
            # "bennyguo/zero123-diffusers",
            "ashawkey/zero123-xl-diffusers",
            # './model_cache/zero123_xl',
            variant="fp16" if self.fp16 else None,
            torch_dtype=self.dtype,
        ).to(self.device)

        # for param in self.pipe.parameters():
        #     param.requires_grad = False

        self.pipe.image_encoder.eval()
        self.pipe.vae.eval()
        self.pipe.unet.eval()
        self.pipe.clip_camera_projection.eval()

        self.vae = self.pipe.vae
        self.unet = self.pipe.unet

        self.pipe.set_progress_bar_config(disable=True)

        self.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
        self.num_train_timesteps = self.scheduler.config.num_train_timesteps

        self.min_step = int(self.num_train_timesteps * t_range[0])
        self.max_step = int(self.num_train_timesteps * t_range[1])
        self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience
        self.min_step_percent = [0, 0.5, 0.02, 3000] 
        self.max_step_percent= [0, 0.95, 0.5, 3000]
        self.embeddings = None
        self.embedding_list = []

    @torch.no_grad()
    def get_img_embeds(self, x, input_imgs=None):
        # x: image tensor in [0, 1]
        x = F.interpolate(x, (256, 256), mode='bilinear', align_corners=False)
        x_pil = [TF.to_pil_image(image) for image in x]
        x_clip = self.pipe.feature_extractor(images=x_pil, return_tensors="pt").pixel_values.to(device=self.device, dtype=self.dtype)
        c = self.pipe.image_encoder(x_clip).image_embeds
        v = self.encode_imgs(x.to(self.dtype)) / self.vae.config.scaling_factor
        self.embeddings = [c, v]
        self.additional_embeddings=[]
        if input_imgs!=None:
            
            for x in input_imgs:
                    x = F.interpolate(x, (256, 256), mode='bilinear', align_corners=False)
                    x_pil = [TF.to_pil_image(image) for image in x]
                    x_clip = self.pipe.feature_extractor(images=x_pil, return_tensors="pt").pixel_values.to(device=self.device, dtype=self.dtype)
                    c = self.pipe.image_encoder(x_clip).image_embeds
                    v = self.encode_imgs(x.to(self.dtype)) / self.vae.config.scaling_factor
                    embeddings = [c, v]
                    self.additional_embeddings.append(embeddings)
        self.embedding_list.append([self.embeddings,self.additional_embeddings])

    @torch.no_grad()
    def refine(self, pred_rgb, polar, azimuth, radius, 
               guidance_scale=5, steps=50, strength=0.8,
        ):

        batch_size = pred_rgb.shape[0]

        self.scheduler.set_timesteps(steps)

        if strength == 0:
            init_step = 0
            latents = torch.randn((1, 4, 32, 32), device=self.device, dtype=self.dtype)
        else:
            init_step = int(steps * strength)
            pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode='bilinear', align_corners=False)
            latents = self.encode_imgs(pred_rgb_256.to(self.dtype))
            latents = self.scheduler.add_noise(latents, torch.randn_like(latents), self.scheduler.timesteps[init_step])

        T = np.stack([np.deg2rad(polar), np.sin(np.deg2rad(azimuth)), np.cos(np.deg2rad(azimuth)), radius], axis=-1)
        T = torch.from_numpy(T).unsqueeze(1).to(self.dtype).to(self.device) # [8, 1, 4]
        cc_emb = torch.cat([self.embeddings[0].repeat(batch_size, 1, 1), T], dim=-1)
        cc_emb = self.pipe.clip_camera_projection(cc_emb)
        cc_emb = torch.cat([cc_emb, torch.zeros_like(cc_emb)], dim=0)

        vae_emb = self.embeddings[1].repeat(batch_size, 1, 1, 1)
        vae_emb = torch.cat([vae_emb, torch.zeros_like(vae_emb)], dim=0)

        for i, t in enumerate(self.scheduler.timesteps[init_step:]):
            
            x_in = torch.cat([latents] * 2)
            t_in = torch.cat([t.view(1)] * 2).to(self.device)

            noise_pred = self.unet(
                torch.cat([x_in, vae_emb], dim=1),
                t_in.to(self.unet.dtype),
                encoder_hidden_states=cc_emb,
            ).sample

            noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
            
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        imgs = self.decode_latents(latents) # [1, 3, 256, 256]
        return imgs
    
    def train_step(self, pred_rgb, polar, azimuth, radius, step_ratio=None, guidance_scale=5, as_latent=False,idx=None,t=0):
        # pred_rgb: tensor [1, 3, H, W] in [0, 1]
        #print(polar)
        step_ratio = max(0.4,step_ratio)
        self.embeddings,self.additional_embeddings = self.embedding_list[t]
        batch_size = pred_rgb.shape[0]
        #print(self.embedding_list[1][0][0] -self.embedding_list[2][0][0])
        #print(self.embedding_list[1][0][1] -self.embedding_list[2][0][1])
        if idx is not None:
            embeddings = self.additional_embeddings[idx]
        else:
            embeddings = self.embeddings

        if as_latent:
            latents = F.interpolate(pred_rgb, (32, 32), mode='bilinear', align_corners=False) * 2 - 1
        else:
            pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode='bilinear', align_corners=False)
            latents = self.encode_imgs(pred_rgb_256.to(self.dtype))


        t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device)

        w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1)

        with torch.no_grad():
            noise = torch.randn_like(latents)
            latents_noisy = self.scheduler.add_noise(latents, noise, t)

            x_in = torch.cat([latents_noisy] * 2)
            t_in = torch.cat([t] * 2)

            T = np.stack([np.deg2rad(polar), np.sin(np.deg2rad(azimuth)), np.cos(np.deg2rad(azimuth)), radius], axis=-1)
            T = torch.from_numpy(T).unsqueeze(1).to(self.dtype).to(self.device) # [8, 1, 4]
            #print(self.embeddings[0].repeat(batch_size, 1, 1).shape,T.shape)
            cc_emb = torch.cat([embeddings[0].repeat(batch_size, 1, 1), T], dim=-1)
            cc_emb = self.pipe.clip_camera_projection(cc_emb)
            cc_emb = torch.cat([cc_emb, torch.zeros_like(cc_emb)], dim=0)

            vae_emb = embeddings[1].repeat(batch_size, 1, 1, 1)
            vae_emb = torch.cat([vae_emb, torch.zeros_like(vae_emb)], dim=0)

            noise_pred = self.unet(
                torch.cat([x_in, vae_emb], dim=1),
                t_in.to(self.unet.dtype),
                encoder_hidden_states=cc_emb,
            ).sample

        noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

        grad = w * (noise_pred - noise)
        grad = torch.nan_to_num(grad)

        target = (latents - grad).detach()
        loss = 0.5 * F.mse_loss(latents.float(), target, reduction='sum')

        return loss
    def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
        min_step_percent = self.get_steps(self.min_step_percent, epoch, global_step)
        max_step_percent = self.get_steps(self.max_step_percent, epoch, global_step)
        self.min_step = int( self.num_train_timesteps * min_step_percent )
        self.max_step = int( self.num_train_timesteps * max_step_percent )
        
        
    def get_steps(self,percent,epoch, global_step):
        start_step, start_value, end_value, end_step = percent
        
        current_step = global_step
        value = start_value + (end_value - start_value) * max(
                min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0
            )
        
        return value

    def decode_latents(self, latents):
        latents = 1 / self.vae.config.scaling_factor * latents

        imgs = self.vae.decode(latents).sample
        imgs = (imgs / 2 + 0.5).clamp(0, 1)

        return imgs

    def encode_imgs(self, imgs, mode=False):
        # imgs: [B, 3, H, W]

        imgs = 2 * imgs - 1

        posterior = self.vae.encode(imgs).latent_dist
        if mode:
            latents = posterior.mode()
        else:
            latents = posterior.sample() 
        latents = latents * self.vae.config.scaling_factor

        return latents
    
    
if __name__ == '__main__':
    import cv2
    import argparse
    import numpy as np
    import matplotlib.pyplot as plt
    
    parser = argparse.ArgumentParser()

    parser.add_argument('input', type=str)
    parser.add_argument('--polar', type=float, default=0, help='delta polar angle in [-90, 90]')
    parser.add_argument('--azimuth', type=float, default=0, help='delta azimuth angle in [-180, 180]')
    parser.add_argument('--radius', type=float, default=0, help='delta camera radius multiplier in [-0.5, 0.5]')

    opt = parser.parse_args()

    device = torch.device('cuda')

    print(f'[INFO] loading image from {opt.input} ...')
    image = cv2.imread(opt.input, cv2.IMREAD_UNCHANGED)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (256, 256), interpolation=cv2.INTER_AREA)
    image = image.astype(np.float32) / 255.0
    image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).contiguous().to(device)

    print(f'[INFO] loading model ...')
    zero123 = Zero123(device)

    print(f'[INFO] running model ...')
    zero123.get_img_embeds(image)

    while True:
        outputs = zero123.refine(image, polar=[opt.polar], azimuth=[opt.azimuth], radius=[opt.radius], strength=0)
        plt.imshow(outputs.float().cpu().numpy().transpose(0, 2, 3, 1)[0])
        plt.show()

================================================
FILE: guidance/zero123pp/pipeline.py
================================================
from typing import Any, Dict, Optional
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.schedulers import KarrasDiffusionSchedulers

import numpy
import torch
import torch.nn as nn
import torch.utils.checkpoint
import torch.distributed
import transformers
from collections import OrderedDict
from PIL import Image
from torchvision import transforms
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
import torch
import torch.nn.functional as F
from torch import nn
import diffusers
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DiffusionPipeline,
    EulerAncestralDiscreteScheduler,
    UNet2DConditionModel,
    ImagePipelineOutput
)
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.attention_processor import Attention, AttnProcessor, XFormersAttnProcessor
from diffusers.utils.import_utils import is_xformers_available

import os
FIRST = True
IDX = 0
PATH = '/home/vision/github/embeddings/'
EMBED=[]

def to_rgb_image(maybe_rgba: Image.Image):
    if maybe_rgba.mode == 'RGB':
        return maybe_rgba
    elif maybe_rgba.mode == 'RGBA':
        rgba = maybe_rgba
        img = numpy.random.randint(127, 128, size=[rgba.size[1], rgba.size[0], 3], dtype=numpy.uint8)
        img = Image.fromarray(img, 'RGB')
        img.paste(rgba, mask=rgba.getchannel('A'))
        return img
    else:
        raise ValueError("Unsupported image type.", maybe_rgba.mode)
    
class MyAttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("MyAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        is_self_attention=False
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
        
        if is_self_attention:
            global FIRST
            global IDX
            global EMBED
            if FIRST == True:
                EMBED.append(encoder_hidden_states.to('cpu'))
                #print('saving to {})'.format(PATH,str(IDX)+'_hidden.pt'))
                #os.makedirs(PATH,exist_ok=True)
                #torch.save(encoder_hidden_states,os.path.join(PATH,str(IDX)+'_hidden.pt'))
                IDX=IDX+1
            else:
                last_shape = encoder_hidden_states.shape[-1]
                replace_dim = int(9600/(last_shape//320)**2)
                encoder_hidden_states_load = EMBED[IDX].to('cuda')
                encoder_hidden_states[:,:replace_dim,:]=(encoder_hidden_states_load[:,:replace_dim,:]+encoder_hidden_states[:,:replace_dim,:])/2
                IDX=IDX+1
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        #print(key.shape)


        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

class ReferenceOnlyAttnProc(torch.nn.Module):
    def __init__(
        self,
        chained_proc,
        enabled=False,
        name=None
    ) -> None:
        super().__init__()
        self.enabled = enabled
        self.chained_proc = chained_proc
        self.name = name

    def __call__(
        self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None,
        mode="w", ref_dict: dict = None, is_cfg_guidance = False
    ) -> Any:
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        is_self_attention = False
        if self.enabled and is_cfg_guidance:
            res0 = self.chained_proc(attn, hidden_states[:1], encoder_hidden_states[:1], attention_mask)
            hidden_states = hidden_states[1:]
            encoder_hidden_states = encoder_hidden_states[1:]
        if self.enabled:
            if mode == 'w':
                ref_dict[self.name] = encoder_hidden_states
            elif mode == 'r':
                encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict.pop(self.name)], dim=1)
                is_self_attention = True
            elif mode == 'm':
                encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict[self.name]], dim=1)
            else:
                assert False, mode
        res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask,is_self_attention=is_self_attention)
        if self.enabled and is_cfg_guidance:
            res = torch.cat([res0, res])
        return res


class RefOnlyNoisedUNet(torch.nn.Module):
    def __init__(self, unet: UNet2DConditionModel, train_sched: DDPMScheduler, val_sched: EulerAncestralDiscreteScheduler) -> None:
        super().__init__()
        self.unet = unet
        self.train_sched = train_sched
        self.val_sched = val_sched

        unet_lora_attn_procs = dict()
        for name, _ in unet.attn_processors.items():
            if torch.__version__ >= '2.0':
                default_attn_proc = MyAttnProcessor2_0()
                print('using my attention')
            elif is_xformers_available():
                default_attn_proc = XFormersAttnProcessor()
            else:
                default_attn_proc = AttnProcessor()
            unet_lora_attn_procs[name] = ReferenceOnlyAttnProc(
                default_attn_proc, enabled=name.endswith("attn1.processor"), name=name
            )
        unet.set_attn_processor(unet_lora_attn_procs)

    def __getattr__(self, name: str):
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.unet, name)

    def forward_cond(self, noisy_cond_lat, timestep, encoder_hidden_states, class_labels, ref_dict, is_cfg_guidance, **kwargs):
        if is_cfg_guidance:
            encoder_hidden_states = encoder_hidden_states[1:]
            class_labels = class_labels[1:]
        self.unet(
            noisy_cond_lat, timestep,
            encoder_hidden_states=encoder_hidden_states,
            class_labels=class_labels,
            cross_attention_kwargs=dict(mode="w", ref_dict=ref_dict),
            **kwargs
        )

    def forward(
        self, sample, timestep, encoder_hidden_states, class_labels=None,
        *args, cross_attention_kwargs,
        down_block_res_samples=None, mid_block_res_sample=None,
        **kwargs
    ):
        

        cond_lat = cross_attention_kwargs['cond_lat']
        is_cfg_guidance = cross_attention_kwargs.get('is_cfg_guidance', False)
        noise = torch.randn_like(cond_lat)
        if self.training:
            noisy_cond_lat = self.train_sched.add_noise(cond_lat, noise, timestep)
            noisy_cond_lat = self.train_sched.scale_model_input(noisy_cond_lat, timestep)
        else:
            noisy_cond_lat = self.val_sched.add_noise(cond_lat, noise, timestep.reshape(-1))
            noisy_cond_lat = self.val_sched.scale_model_input(noisy_cond_lat, timestep.reshape(-1))
            #if cross_attention_kwargs['cond_lat_back'] is not None:
            #    cond_lat_back = cross_attention_kwargs['cond_lat_back']
            #    noisy_cond_lat_back = self.val_sched.add_noise(cond_lat_back, noise, timestep.reshape(-1))
            #    noisy_cond_lat_back = self.val_sched.scale_model_input(noisy_cond_lat_back, timestep.reshape(-1))

        ref_dict = {}
        self.forward_cond(
            noisy_cond_lat, timestep,
            encoder_hidden_states, class_labels,
            ref_dict, is_cfg_guidance, **kwargs
        )
        weight_dtype = self.unet.dtype
        return self.unet(
            sample, timestep,
            encoder_hidden_states, *args,
            class_labels=class_labels,
            cross_attention_kwargs=dict(mode="r", ref_dict=ref_dict, is_cfg_guidance=is_cfg_guidance),
            down_block_additional_residuals=[
                sample.to(dtype=weight_dtype) for sample in down_block_res_samples
            ] if down_block_res_samples is not None else None,
            mid_block_additional_residual=(
                mid_block_res_sample.to(dtype=weight_dtype)
                if mid_block_res_sample is not None else None
            ),
            **kwargs
        )


def scale_latents(latents):
    latents = (latents - 0.22) * 0.75
    return latents


def unscale_latents(latents):
    latents = latents / 0.75 + 0.22
    return latents


def scale_image(image):
    image = image * 0.5 / 0.8
    return image


def unscale_image(image):
    image = image / 0.5 * 0.8
    return image


class DepthControlUNet(torch.nn.Module):
    def __init__(self, unet: RefOnlyNoisedUNet, controlnet: Optional[diffusers.ControlNetModel] = None, conditioning_scale=1.0) -> None:
        super().__init__()
        self.unet = unet
        if controlnet is None:
            self.controlnet = diffusers.ControlNetModel.from_unet(unet.unet)
        else:
            self.controlnet = controlnet
        DefaultAttnProc = MyAttnProcessor2_0
        if is_xformers_available():
            DefaultAttnProc = XFormersAttnProcessor
        self.controlnet.set_attn_processor(DefaultAttnProc())
        self.conditioning_scale = conditioning_scale

    def __getattr__(self, name: str):
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.unet, name)

    def forward(self, sample, timestep, encoder_hidden_states, class_labels=None, *args, cross_attention_kwargs: dict, **kwargs):
        cross_attention_kwargs = dict(cross_attention_kwargs)
        control_depth = cross_attention_kwargs.pop('control_depth')
        down_block_res_samples, mid_block_res_sample = self.controlnet(
            sample,
            timestep,
            encoder_hidden_states=encoder_hidden_states,
            controlnet_cond=control_depth,
            conditioning_scale=self.conditioning_scale,
            return_dict=False,
        )
        return self.unet(
            sample,
            timestep,
            encoder_hidden_states=encoder_hidden_states,
            down_block_res_samples=down_block_res_samples,
            mid_block_res_sample=mid_block_res_sample,
            cross_attention_kwargs=cross_attention_kwargs
        )


class ModuleListDict(torch.nn.Module):
    def __init__(self, procs: dict) -> None:
        super().__init__()
        self.keys = sorted(procs.keys())
        self.values = torch.nn.ModuleList(procs[k] for k in self.keys)

    def __getitem__(self, key):
        return self.values[self.keys.index(key)]


class SuperNet(torch.nn.Module):
    def __init__(self, state_dict: Dict[str, torch.Tensor]):
        super().__init__()
        state_dict = OrderedDict((k, state_dict[k]) for k in sorted(state_dict.keys()))
        self.layers = torch.nn.ModuleList(state_dict.values())
        self.mapping = dict(enumerate(state_dict.keys()))
        self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}

        # .processor for unet, .self_attn for text encoder
        self.split_keys = [".processor", ".self_attn"]

        # we add a hook to state_dict() and load_state_dict() so that the
        # naming fits with `unet.attn_processors`
        def map_to(module, state_dict, *args, **kwargs):
            new_state_dict = {}
            for key, value in state_dict.items():
                num = int(key.split(".")[1])  # 0 is always "layers"
                new_key = key.replace(f"layers.{num}", module.mapping[num])
                new_state_dict[new_key] = value

            return new_state_dict

        def remap_key(key, state_dict):
            for k in self.split_keys:
                if k in key:
                    return key.split(k)[0] + k
            return key.split('.')[0]

        def map_from(module, state_dict, *args, **kwargs):
            all_keys = list(state_dict.keys())
            for key in all_keys:
                replace_key = remap_key(key, state_dict)
                new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}")
                state_dict[new_key] = state_dict[key]
                del state_dict[key]

        self._register_state_dict_hook(map_to)
        self._register_load_state_dict_pre_hook(map_from, with_module=True)


class Zero123PlusPipeline(diffusers.StableDiffusionPipeline):
    tokenizer: transformers.CLIPTokenizer
    text_encoder: transformers.CLIPTextModel
    vision_encoder: transformers.CLIPVisionModelWithProjection

    feature_extractor_clip: transformers.CLIPImageProcessor
    unet: UNet2DConditionModel
    scheduler: diffusers.schedulers.KarrasDiffusionSchedulers

    vae: AutoencoderKL
    ramping: nn.Linear

    feature_extractor_vae: transformers.CLIPImageProcessor

    depth_transforms_multi = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        vision_encoder: transformers.CLIPVisionModelWithProjection,
        feature_extractor_clip: CLIPImageProcessor, 
        feature_extractor_vae: CLIPImageProcessor,
        ramping_coefficients: Optional[list] = None,
        safety_checker=None,
    ):
        DiffusionPipeline.__init__(self)

        self.register_modules(
            vae=vae, text_encoder=text_encoder, tokenizer=tokenizer,
            unet=unet, scheduler=scheduler, safety_checker=None,
            vision_encoder=vision_encoder,
            feature_extractor_clip=feature_extractor_clip,
            feature_extractor_vae=feature_extractor_vae
        )
        self.register_to_config(ramping_coefficients=ramping_coefficients)
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

    def prepare(self):
        train_sched = DDPMScheduler.from_config(self.scheduler.config)
        if isinstance(self.unet, UNet2DConditionModel):
            self.unet = RefOnlyNoisedUNet(self.unet, train_sched, self.scheduler).eval()

    def add_controlnet(self, controlnet: Optional[diffusers.ControlNetModel] = None, conditioning_scale=1.0):
        self.prepare()
        self.unet = DepthControlUNet(self.unet, controlnet, conditioning_scale)
        return SuperNet(OrderedDict([('controlnet', self.unet.controlnet)]))

    def encode_condition_image(self, image: torch.Tensor):
        image = self.vae.encode(image).latent_dist.sample()
        return image

    @torch.no_grad()
    def __call__(
        self,
        image: Image.Image = None,
        prompt = "",
        *args,
        num_images_per_prompt: Optional[int] = 1,
        guidance_scale=4.0,
        depth_image: Image.Image = None,
        output_type: Optional[str] = "pil",
        width=640,
        height=960,
        num_inference_steps=28,
        return_dict=True,
        is_first = False,
        **kwargs
    ):
        global FIRST
        FIRST = is_first
        global IDX
        IDX = 0
        if is_first:
            global EMBED
            EMBED=[]
        
        # Create a generator with the specified seed
        generator = torch.Generator(device='cuda')
        generator.manual_seed(42)
        self.prepare()
        if image is None:
            raise ValueError("Inputting embeddings not supported for this pipeline. Please pass an image.")
        assert not isinstance(image, torch.Tensor)
        image = to_rgb_image(image)
        image_1 = self.feature_extractor_vae(images=image, return_tensors="pt").pixel_values
        image_2 = self.feature_extractor_clip(images=image, return_tensors="pt").pixel_values
        if depth_image is not None and hasattr(self.unet, "controlnet"):
            depth_image = to_rgb_image(depth_image)
            depth_image = self.depth_transforms_multi(depth_image).to(
                device=self.unet.controlnet.device, dtype=self.unet.controlnet.dtype
            )
        image = image_1.to(device=self.vae.device, dtype=self.vae.dtype)
        image_2 = image_2.to(device=self.vae.device, dtype=self.vae.dtype)
        cond_lat = self.encode_condition_image(image)
        if guidance_scale > 1:
            negative_lat = self.encode_condition_image(torch.zeros_like(image))
            cond_lat = torch.cat([negative_lat, cond_lat])
        encoded = self.vision_encoder(image_2, output_hidden_states=False)
        global_embeds = encoded.image_embeds
        global_embeds = global_embeds.unsqueeze(-2)
        
        encoder_hidden_states = self._encode_prompt(
            prompt,
            self.device,
            num_images_per_prompt,
            False
        )
        ramp = global_embeds.new_tensor(self.config.ramping_coefficients).unsqueeze(-1)
        encoder_hidden_states = encoder_hidden_states + global_embeds * ramp
        cak = dict(cond_lat=cond_lat)
        if hasattr(self.unet, "controlnet"):
            cak['control_depth'] = depth_image
        
        cak['cond_lat_back'] = None

        latents: torch.Tensor = super().__call__(
            None,
            *args,
            cross_attention_kwargs=cak,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_images_per_prompt,
            prompt_embeds=encoder_hidden_states,
            num_inference_steps=num_inference_steps,
            output_type='latent',
            width=width,
            height=height,
            generator=generator,
            **kwargs
        ).images
        latents = unscale_latents(latents)
        if not output_type == "latent":
            image = unscale_image(self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0])
        else:
            image = latents
        
        image = self.image_processor.postprocess(image, output_type=output_type)

        if not return_dict:
            return (image,)
        

        return ImagePipelineOutput(images=image)

================================================
FILE: main.py
================================================
import os
import cv2
import time
import tqdm
import numpy as np
import dearpygui.dearpygui as dpg

import torch
import torch.nn.functional as F
import torchvision.utils as vutils
from einops import rearrange, repeat
import imageio
import rembg

from cam_utils import orbit_camera, OrbitCamera
from gs_renderer_4d import Renderer, MiniCam
from dataset_4d import SparseDataset

def save_image_to_local(image_tensor, file_path):
    # Ensure the image tensor is in the range [0, 1]
    image_tensor = image_tensor.clamp(0, 1) 

    # Save the image tensor to the specified file path
    vutils.save_image(image_tensor, file_path)

class GUI:
    def __init__(self, opt):
        self.opt = opt  # shared with the trainer's opt to support in-place modification of rendering parameters.
        self.gui = opt.gui # enable gui
        self.W = opt.W
        self.H = opt.H
        self.cam = OrbitCamera(opt.W, opt.H, r=opt.radius, fovy=opt.fovy)

        self.mode = "image"
        self.seed = "random"

        self.buffer_image = np.ones((self.W, self.H, 3), dtype=np.float32)
        self.need_update = True  # update buffer_image

        # models
        self.device = torch.device("cuda")
        self.bg_remover = None

        self.guidance_sd = None
        self.guidance_zero123 = None

        self.enable_sd = False
        self.enable_zero123 = False

        # renderer
        self.renderer = Renderer(sh_degree=self.opt.sh_degree)
        self.gaussain_scale_factor = 1

        # input image
        self.input_img = None
        self.input_mask = None
        self.input_img_torch = None
        self.input_mask_torch = None
        self.overlay_input_img = False
        self.overlay_input_img_ratio = 0.5
        
        #self.use_depth = opt.use_depth

        # input text
        self.prompt = ""
        self.negative_prompt = ""

        # training stuff
        self.training = False
        self.optimizer = None
        self.step = 0
        self.t = 0
        self.time = 0
        self.train_steps = 1  # steps per rendering loop
        self.init = True
        self.stage = 'coarse'
        self.path = self.opt.path
        self.save_step = self.opt.save_step
        
        if self.opt.size is not None:
            self.size = self.opt.size
        else:
            self.size = len(os.listdir(os.path.join(self.path,'ref')))
        self.frames=self.size
        self.dataset = SparseDataset(self.opt, self.size, H=self.H, W=self.W, device=self.device)
        self.dataloader =self.dataset.dataloader()
        self.iter = iter(self.dataloader)
        self.ref_view_batch, self.input_mask_batch,self.zero123_view_batch,self.zero123_masks_batch = next(self.iter)
        self.input_img_torch_batch,self.input_mask_torch_batch,self.zero123plus_imgs_torch_batch,self.zero123plus_masks_torch_batch=[],[],[],[]
        

        # load input data from cmdline
        if self.opt.input is not None:
            self.load_input(self.opt.input)
        
        # override prompt from cmdline
        if self.opt.prompt is not None:
            self.prompt = self.opt.prompt

        # override if provide a checkpoint
        
        self.renderer.initialize(num_pts=self.opt.num_pts)            

        self.point_nums = []
        if self.gui:
            dpg.create_context()
            self.register_dpg()
            self.test_step()

    def __del__(self):
        if self.gui:
            dpg.destroy_context()

    def seed_everything(self):
        try:
            seed = int(self.seed)
        except:
            seed = np.random.randint(0, 1000000)

        os.environ["PYTHONHASHSEED"] = str(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = True

        self.last_seed = seed
        
    def prepare_image(self,idx):
        # input image
        if self.input_img is not None:
            self.input_img_torch = torch.from_numpy(self.input_img).permute(2, 0, 1).unsqueeze(0).to(self.device)
            self.input_img_torch = F.interpolate(self.input_img_torch, (self.opt.ref_size, self.opt.ref_size), mode="bilinear", align_corners=False)

            self.input_mask_torch = torch.from_numpy(self.input_mask).permute(2, 0, 1).unsqueeze(0).to(self.device)
            self.input_mask_torch = F.interpolate(self.input_mask_torch, (self.opt.ref_size, self.opt.ref_size), mode="bilinear", align_corners=False)
        
        self.zero123plus_imgs_torch=[]
        self.zero123plus_masks_torch=[]
        # input image
        if self.input_imgs is not None:
            for i in np.arange(6):
                #print(idx,i)
                self.input_img2_torch=(torch.from_numpy(self.input_imgs[i]).permute(2, 0, 1).unsqueeze(0).to(self.device))
                self.zero123plus_imgs_torch.append(F.interpolate(self.input_img2_torch, (self.opt.ref_size, self.opt.ref_size), mode="bilinear", align_corners=False))

                self.input_mask2_torch=torch.from_numpy(self.input_masks[i]).permute(2, 0, 1).unsqueeze(0).to(self.device)
                self.zero123plus_masks_torch.append(F.interpolate(self.input_mask2_torch, (self.opt.ref_size, self.opt.ref_size), mode="bilinear", align_corners=False))
                
        self.input_img_torch_batch.append(self.input_img_torch)
        self.input_mask_torch_batch.append(self.input_mask_torch)
        self.zero123plus_imgs_torch_batch.append(self.zero123plus_imgs_torch)
        self.zero123plus_masks_torch_batch.append(self.zero123plus_masks_torch)
        
        # prepare embeddings
        with torch.no_grad():
            self.guidance_zero123.get_img_embeds(self.input_img_torch, self.zero123plus_imgs_torch)


    def prepare_train(self):

        self.step = 0
        self.end_step = self.save_step+1
        
        ## given a load_path, load corresponding model
        if self.opt.load_path is not None:
           if self.opt.load_step is not None:
               self.step = self.opt.load_step
           else:
               #default loading save_step ply
               self.step = self.save_step 
           auto_path = os.path.join(self.opt.outdir,self.opt.load_path + str(self.step))

           ply_path = os.path.join(auto_path,'model.ply')
           self.renderer.gaussians.load_model(auto_path)
           self.renderer.gaussians.load_ply(ply_path)
           self.end_step =self.step+self.end_step
        
        ## setup training
        self.renderer.gaussians.training_setup(self.opt)
        
        ## do not do progressive sh-level
        self.renderer.gaussians.active_sh_degree = self.renderer.gaussians.max_sh_degree
        self.optimizer = self.renderer.gaussians.optimizer
        
        # default camera
        pose = orbit_camera(self.opt.elevation, 0, self.opt.radius)
        self.fixed_cam = MiniCam(
            pose,
            self.opt.ref_size,
            self.opt.ref_size,
            self.cam.fovy,
            self.cam.fovx,
            self.cam.near,
            self.cam.far,
        )
        self.set_fix_cam()
        self.enable_sd = self.opt.lambda_sd > 0 and self.prompt != ""
        self.enable_zero123 = self.opt.lambda_zero123 > 0 and self.input_img is not None


        print(f"[INFO] loading zero123...")
        from guidance.zero123_4d_utils import Zero123
        self.guidance_zero123 = Zero123(self.device)
        print(f"[INFO] loaded zero123!")

        ## load multiview reference images
        for i in np.arange(len(self.ref_view_batch)):
                self.input_img =   self.ref_view_batch[i]
                self.input_mask =  self.input_mask_batch[i]
                self.input_imgs =  self.zero123_view_batch[i]
                self.input_masks = self.zero123_masks_batch[i]
                self.prepare_image(i)


    def train_step(self):
        starter = torch.cuda.Event(enable_timing=True)
        ender = torch.cuda.Event(enable_timing=True)
        starter.record()


        
        torch.autograd.set_detect_anomaly(True)
        for _ in range(self.train_steps):

            if self.step<self.opt.init_steps:
                self.init = True
                self.stage = 'coarse'
            else:
                self.init = False
                self.stage = 'fine'
            
            if self.step == self.end_step:
                exit()
                
            ## save model
            if self.step == self.save_step:
                auto_path = os.path.join(self.opt.outdir,self.opt.save_path + str(self.step))
                os.makedirs(auto_path,exist_ok=True)
                ply_path = os.path.join(auto_path,'model.ply')
                self.renderer.gaussians.save_ply(ply_path)
                self.renderer.gaussians.save_deformation(auto_path)
                
            
            if self.step>self.opt.position_lr_max_steps:
                self.opt.position_lr_max_steps = self.opt.position_lr_max_steps2

            self.step += 1
            step_ratio = min(1, self.step / self.opt.iters)
            viewspace_point_tensor_list = []
            radii_list = []
            visibility_filter_list = []
            # update lr
            self.renderer.gaussians.update_learning_rate(self.step)
            self.guidance_zero123.update_step(0,self.step)

            loss = 0
            
            if self.step%self.opt.valid_interval == 0:
                self.save_renderings( 0, 0, 2 ,'front')
                self.save_renderings( 180, 0, 2 ,'back')
                

            render_resolution = 128 if step_ratio < 0.3 else (256 if step_ratio < 0.6 else 512)

            # avoid too large elevation (> 80 or < -80), and make sure it always cover [-30, 30]
            min_ver = max(min(-30, -30 - self.opt.elevation), -80 - self.opt.elevation)
            max_ver = min(max(30, 30 - self.opt.elevation), 80 - self.opt.elevation)


            for _ in np.arange(self.opt.batch_size):
                
                #sample time
                if self.init:
                        self.t = self.frames//2
                        self.time = self.t/self.frames
                else:   
                        self.t = np.random.randint(0,self.frames)
                        self.time = self.t/self.frames

                self.input_img_torch =   self.input_img_torch_batch[self.t]
                self.input_mask_torch =  self.input_mask_torch_batch[self.t]
                self.zero123plus_imgs_torch =  self.zero123plus_imgs_torch_batch[self.t]
                self.zero123plus_masks_torch = self.zero123plus_masks_torch_batch[self.t]
                
                ## need to do rgb loss in the batch
                cur_cam = self.fixed_cam
                cur_cam.time=self.time
                
                out = self.renderer.render(cur_cam,stage=self.stage)
                viewspace_point_tensor, visibility_filter, radii = out["viewspace_points"], out["visibility_filter"], out["radii"]  
                radii_list.append(radii.unsqueeze(0))
                visibility_filter_list.append(visibility_filter.unsqueeze(0))
                viewspace_point_tensor_list.append(viewspace_point_tensor)
                # rgb loss
                image = out["image"].unsqueeze(0) # [1, 3, H, W] in [0, 1]
                image_loss =step_ratio* 20000*  F.mse_loss(image, self.input_img_torch)
                loss = loss + image_loss
                
                alpha = out["alpha"].unsqueeze(0)
                alpha_loss = step_ratio* 5000*  F.mse_loss(alpha, self.input_mask_torch)
                loss = loss + alpha_loss
                
                images = []
                poses = []

                vers_plus, hors_plus, radii_plus = [], [], []
                self.guidance_zero123.update_step(1,self.step)
                # render random view
                ver = np.random.randint(min_ver, max_ver)
                hor = np.random.randint(-180, 180)
                radius = 0
                

                


                pose = orbit_camera(self.opt.elevation + ver, hor, self.opt.radius + radius)
                
                poses.append(pose)

                cur_cam = MiniCam(
                    pose,
                    render_resolution,
                    render_resolution,
                    self.cam.fovy,
                    self.cam.fovx,
                    self.cam.near,
                    self.cam.far,
                )
                cur_cam.time=self.time
                if hor<30 and hor>-30 or np.random.rand()>0.4:
                    idx=None
                    vers_plus.append(torch.tensor(ver,device=self.device).unsqueeze(dim=0))
                    hors_plus.append(torch.tensor(hor,device=self.device).unsqueeze(dim=0))
                    radii_plus.append(torch.tensor(radius,device=self.device).unsqueeze(dim=0))
                elif hor>0:
                    idx=hor//60
                    vers_plus.append(torch.tensor(ver-self.fixed_elevation[idx],device=self.device).unsqueeze(dim=0))
                    hors_plus.append(torch.tensor(hor-self.fixed_azimuth[idx],device=self.device).unsqueeze(dim=0))
                    radii_plus.append(torch.tensor(radius,device=self.device).unsqueeze(dim=0))
                elif hor<0:
                    idx = (360+hor)//60
                    vers_plus.append(torch.tensor(ver-self.fixed_elevation[idx],device=self.device).unsqueeze(dim=0))
                    hors_plus.append(torch.tensor(hor-self.fixed_azimuth[idx],device=self.device).unsqueeze(dim=0))
                    radii_plus.append(torch.tensor(radius,device=self.device).unsqueeze(dim=0))

                bg_color = torch.tensor([1, 1, 1] if np.random.rand() > self.opt.invert_bg_prob else [0, 0, 0], dtype=torch.float32, device="cuda")
                out = self.renderer.render(cur_cam, bg_color=bg_color,stage=self.stage)
                viewspace_point_tensor, visibility_filter, radii_rendering = out["viewspace_points"], out["visibility_filter"], out["radii"]  
                radii_list.append(radii_rendering.unsqueeze(0))
                visibility_filter_list.append(visibility_filter.unsqueeze(0))
                viewspace_point_tensor_list.append(viewspace_point_tensor)
                image = out["image"].unsqueeze(0)# [1, 3, H, W] in [0, 1]
                images.append(image)
                
                
            
                images_render = torch.cat(images, dim=0)
                #poses = torch.from_numpy(np.stack(poses, axis=0)).to(self.device)
                vers_batch = torch.cat(vers_plus, dim=0).cpu().numpy()
                hors_batch = torch.cat(hors_plus, dim=0).cpu().numpy()
                radii_batch = torch.cat(radii_plus, dim=0).cpu().numpy()

                # guidance loss
                # as we have different reference views, so each time we only pass 1 image into zero123 for guidance
                zero123_loss = self.opt.lambda_zero123 * self.guidance_zero123.train_step(images_render, vers_batch, hors_batch, radii_batch, step_ratio,idx=idx,t = self.t)
                loss = loss + zero123_loss
            
            # tv loss
            scales = out['scales']
            tv_loss = self.renderer.gaussians.compute_regulation(self.opt.time_smoothness_weight, self.opt.plane_tv_weight, self.opt.l1_time_planes)
            loss += self.opt.lambda_tv * tv_loss
            
            # scale loss from physgaussian
            r = self.opt.scale_loss_ratio
            scale_loss = (torch.mean(torch.maximum(torch.max(scales,dim=1).values/ \
                                                  (torch.min(scales,dim=1).values+1e-8),\
                                                    torch.ones_like(torch.max(scales,dim=1).values)*r))-r) * scales.shape[0]
            loss += scale_loss

            # optimize step
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

            viewspace_point_tensor_grad = torch.zeros_like(viewspace_point_tensor)
            for idx in range(0, len(viewspace_point_tensor_list)):
                    viewspace_point_tensor_grad = viewspace_point_tensor_grad + viewspace_point_tensor_list[idx].grad
            radii = torch.cat(radii_list,0).max(dim=0).values
            visibility_filter = torch.cat(visibility_filter_list).any(dim=0)
            if self.step >= self.opt.density_start_iter and self.step <= self.opt.density_end_iter:
                self.renderer.gaussians.max_radii2D[visibility_filter] = torch.max(self.renderer.gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
                self.renderer.gaussians.add_densification_stats(viewspace_point_tensor_grad, visibility_filter)
                if self.step % self.opt.densification_interval == 1 :

                    self.renderer.gaussians.densify_and_prune(self.opt.densify_grad_threshold_percent, min_opacity=0.01, extent=1, max_screen_size=2)



        ender.record()
        torch.cuda.synchronize()
        t = starter.elapsed_time(ender)

        self.need_update = True

        if self.gui:
            dpg.set_value("_log_train_time", f"{t:.4f}ms")
            dpg.set_value(
                "_log_train_log",
                 f"step = {self.step: 5d} (+{self.train_steps: 2d})\n loss = {loss.item():.4f}\nzero123_loss = {zero123_loss.item():.4f}image_loss ={image_loss.item():.4f}\nloss_alpha = {alpha_loss.item():.4f} scale_loss:{scale_loss.item():.4f} ",
            )

    def set_fix_cam(self):
        self.fixed_cam_plus=[]
        self.fixed_elevation = []
        self.fixed_azimuth = []
        
        pose = orbit_camera(self.opt.elevation-30,30 , self.opt.radius)
        self.fixed_elevation.append(-30)
        self.fixed_azimuth.append(30)
        self.fixed_cam_plus.append(MiniCam(
                pose,
                self.opt.ref_size,
                self.opt.ref_size,
                self.cam.fovy,
                self.cam.fovx,
                self.cam.near,
                self.cam.far,
            ))
        
        pose = orbit_camera(self.opt.elevation+20, 90, self.opt.radius)
        self.fixed_elevation.append(20)
        self.fixed_azimuth.append(90)
        self.fixed_cam_plus.append(MiniCam(
                pose,
                self.opt.ref_size,
                self.opt.ref_size,
                self.cam.fovy,
                self.cam.fovx,
                self.cam.near,
                self.cam.far,
            ))
        pose = orbit_camera(self.opt.elevation-30, 150, self.opt.radius)
        self.fixed_elevation.append(-30)
        self.fixed_azimuth.append(150)
        self.fixed_cam_plus.append(MiniCam(
                pose,
                self.opt.ref_size,
                self.opt.ref_size,
                self.cam.fovy,
                self.cam.fovx,
                self.cam.near,
                self.cam.far,
            ))
        
        pose = orbit_camera(self.opt.elevation+20, 210, self.opt.radius)
        self.fixed_elevation.append(+20)
        self.fixed_azimuth.append(210)
        self.fixed_cam_plus.append(MiniCam(
                pose,
                self.opt.ref_size,
                self.opt.ref_size,
                self.cam.fovy,
                self.cam.fovx,
                self.cam.near,
                self.cam.far,
            ))
        
        pose = orbit_camera(self.opt.elevation-30, 270, self.opt.radius)
        self.fixed_elevation.append(-30)
        self.fixed_azimuth.append(270)
        self.fixed_cam_plus.append(MiniCam(
                pose,
                self.opt.ref_size,
                self.opt.ref_size,
                self.cam.fovy,
                self.cam.fovx,
                self.cam.near,
                self.cam.far,
            ))
        
        pose = orbit_camera(self.opt.elevation+20, 330, self.opt.radius)
        self.fixed_elevation.append(20)
        self.fixed_azimuth.append(330)
        self.fixed_cam_plus.append(MiniCam(
                pose,
                self.opt.ref_size,
                self.opt.ref_size,
                self.cam.fovy,
                self.cam.fovx,
                self.cam.near,
                self.cam.far,
            ))
        
    @torch.no_grad()
    def test_step(self):
        # ignore if no need to update
        if not self.need_update:
            return

        starter = torch.cuda.Event(enable_timing=True)
        ender = torch.cuda.Event(enable_timing=True)
        starter.record()

        # should update image
        if self.need_update:
            # render image

            cur_cam = MiniCam(
                self.cam.pose,
                self.W,
                self.H,
                self.cam.fovy,
                self.cam.fovx,
                self.cam.near,
                self.cam.far,
                time=self.time
            )
            #print(cur_cam.time)
            out = self.renderer.render(cur_cam, self.gaussain_scale_factor,stage=self.stage)

            buffer_image = out[self.mode]  # [3, H, W]

            if self.mode in ['depth', 'alpha']:
                buffer_image = buffer_image.repeat(3, 1, 1)
                if self.mode == 'depth':
                    buffer_image = (buffer_image - buffer_image.min()) / (buffer_image.max() - buffer_image.min() + 1e-20)

            buffer_image = F.interpolate(
                buffer_image.unsqueeze(0),
                size=(self.H, self.W),
                mode="bilinear",
                align_corners=False,
            ).squeeze(0)

            self.buffer_image = (
                buffer_image.permute(1, 2, 0)
                .contiguous()
                .clamp(0, 1)
                .contiguous()
                .detach()
                .cpu()
                .numpy()
            )

            # display input_image
            if self.overlay_input_img and self.input_img is not None:
                self.buffer_image = (
                    self.buffer_image * (1 - self.overlay_input_img_ratio)
                    + self.input_img * self.overlay_input_img_ratio
                )

            self.need_update = False

        ender.record()
        torch.cuda.synchronize()
        t = starter.elapsed_time(ender)

        if self.gui:
            dpg.set_value("_log_infer_time", f"{t:.4f}ms ({int(1000/t)} FPS)")
            dpg.set_value(
                "_texture", self.buffer_image
            )  # buffer must be contiguous, else seg fault!

    
    def load_input(self, file):
        # load image
        pass
        # load image

    @torch.no_grad()
    def save_renderings(self, elev=0, azim=0, radius=2, name='front'):
        images=[]
        for i in np.arange(self.frames):
            
            pose = orbit_camera(elev, azim, radius)
            cam = MiniCam(
            pose,
            self.opt.ref_size,
            self.opt.ref_size,
            self.cam.fovy,
            self.cam.fovx,
            self.cam.near,
            self.cam.far,
            )   
            cam.time=float(i/self.frames)
            out = self.renderer.render(cam,stage=self.stage)
            image = out["image"].unsqueeze(0)
            images.append(image)
            os.makedirs(f'./valid/{self.opt.save_path}/{self.step}_{name}',exist_ok=True)
            save_image_to_local(image[0].detach(),f'./valid/{self.opt.save_path}/{self.step}_{name}/{str(i).zfill(2)}.jpg')
        samples=torch.cat(images,dim=0)
        
        vid = (
            (rearrange(samples, "t c h w -> t h w c") * 255).clamp(0,255).detach()
            .cpu()
            .numpy()
            .astype(np.uint8)
        )
        video_path = f'./valid/{self.opt.save_path}/{self.step}_{name}/video.mp4'
        imageio.mimwrite(video_path, vid)


    @torch.no_grad()
    def save_model(self, mode='geo', texture_size=1024):
        os.makedirs(self.opt.outdir, exist_ok=True)
        if mode == 'geo':
            path = os.path.join(self.opt.outdir, self.opt.save_path + '_model.ply')
            self.renderer.gaussians.save_ply(path)

        elif mode == 'geo+tex':
            path = os.path.join(self.opt.outdir, self.opt.save_path + '_model.ply')
            self.renderer.gaussians.save_ply(path)

        else:
            path = os.path.join(self.opt.outdir, self.opt.save_path + '_model.ply')
            self.renderer.gaussians.save_ply(path)

        print(f"[INFO] save model to {path}.")

    def register_dpg(self):
        ### register texture

        with dpg.texture_registry(show=False):
            dpg.add_raw_texture(
                self.W,
                self.H,
                self.buffer_image,
                format=dpg.mvFormat_Float_rgb,
                tag="_texture",
            )

        ### register window

        # the rendered image, as the primary window
        with dpg.window(
            tag="_primary_window",
            width=self.W,
            height=self.H,
            pos=[0, 0],
            no_move=True,
            no_title_bar=True,
            no_scrollbar=True,
        ):
            # add the texture
            dpg.add_image("_texture")

        # dpg.set_primary_window("_primary_window", True)

        # control window
        with dpg.window(
            label="Control",
            tag="_control_window",
            width=600,
            height=self.H,
            pos=[self.W, 0],
            no_move=True,
            no_title_bar=True,
        ):
            # button theme
            with dpg.theme() as theme_button:
                with dpg.theme_component(dpg.mvButton):
                    dpg.add_theme_color(dpg.mvThemeCol_Button, (23, 3, 18))
                    dpg.add_theme_color(dpg.mvThemeCol_ButtonHovered, (51, 3, 47))
                    dpg.add_theme_color(dpg.mvThemeCol_ButtonActive, (83, 18, 83))
                    dpg.add_theme_style(dpg.mvStyleVar_FrameRounding, 5)
                    dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 3, 3)

            # timer stuff
            with dpg.group(horizontal=True):
                dpg.add_text("Infer time: ")
                dpg.add_text("no data", tag="_log_infer_time")

            def callback_setattr(sender, app_data, user_data):
                setattr(self, user_data, app_data)

            # init stuff
            with dpg.collapsing_header(label="Initialize", default_open=True):

                # seed stuff
                def callback_set_seed(sender, app_data):
                    self.seed = app_data
                    self.seed_everything()

                dpg.add_input_text(
                    label="seed",
                    default_value=self.seed,
                    on_enter=True,
                    callback=callback_set_seed,
                )

                # input stuff
                def callback_select_input(sender, app_data):
                    # only one item
                    for k, v in app_data["selections"].items():
                        dpg.set_value("_log_input", k)
                        self.load_input(v)

                    self.need_update = True

                with dpg.file_dialog(
                    directory_selector=False,
                    show=False,
                    callback=callback_select_input,
                    file_count=1,
                    tag="file_dialog_tag",
                    width=700,
                    height=400,
                ):
                    dpg.add_file_extension("Images{.jpg,.jpeg,.png}")

                with dpg.group(horizontal=True):
                    dpg.add_button(
                        label="input",
                        callback=lambda: dpg.show_item("file_dialog_tag"),
                    )
                    dpg.add_text("", tag="_log_input")
                
                # overlay stuff
                with dpg.group(horizontal=True):

                    def callback_toggle_overlay_input_img(sender, app_data):
                        self.overlay_input_img = not self.overlay_input_img
                        self.need_update = True

                    dpg.add_checkbox(
                        label="overlay image",
                        default_value=self.overlay_input_img,
                        callback=callback_toggle_overlay_input_img,
                    )

                    def callback_set_overlay_input_img_ratio(sender, app_data):
                        self.overlay_input_img_ratio = app_data
                        self.need_update = True

                    dpg.add_slider_float(
                        label="ratio",
                        min_value=0,
                        max_value=1,
                        format="%.1f",
                        default_value=self.overlay_input_img_ratio,
                        callback=callback_set_overlay_input_img_ratio,
                    )

                # prompt stuff
            
                dpg.add_input_text(
                    label="prompt",
                    default_value=self.prompt,
                    callback=callback_setattr,
                    user_data="prompt",
                )

                dpg.add_input_text(
                    label="negative",
                    default_value=self.negative_prompt,
                    callback=callback_setattr,
                    user_data="negative_prompt",
                )

                # save current model
                with dpg.group(horizontal=True):
                    dpg.add_text("Save: ")

                    def callback_save(sender, app_data, user_data):
                        self.save_model(mode=user_data)

                    dpg.add_button(
                        label="model",
                        tag="_button_save_model",
                        callback=callback_save,
                        user_data='model',
                    )
                    dpg.bind_item_theme("_button_save_model", theme_button)

                    dpg.add_button(
                        label="geo",
                        tag="_button_save_mesh",
                        callback=callback_save,
                        user_data='geo',
                    )
                    dpg.bind_item_theme("_button_save_mesh", theme_button)

                    dpg.add_button(
                        label="geo+tex",
                        tag="_button_save_mesh_with_tex",
                        callback=callback_save,
                        user_data='geo+tex',
                    )
                    dpg.bind_item_theme("_button_save_mesh_with_tex", theme_button)

                    dpg.add_input_text(
                        label="",
                        default_value=self.opt.save_path,
                        callback=callback_setattr,
                        user_data="save_path",
                    )

            # training stuff
            with dpg.collapsing_header(label="Train", default_open=True):
                # lr and train button
                with dpg.group(horizontal=True):
                    dpg.add_text("Train: ")

                    def callback_train(sender, app_data):
                        if self.training:
                            self.training = False
                            dpg.configure_item("_button_train", label="start")
                        else:
                            self.prepare_train()
                            self.training = True
                            dpg.configure_item("_button_train", label="stop")

                    # dpg.add_button(
                    #     label="init", tag="_button_init", callback=self.prepare_train
                    # )
                    # dpg.bind_item_theme("_button_init", theme_button)

                    dpg.add_button(
                        label="start", tag="_button_train", callback=callback_train
                    )
                    dpg.bind_item_theme("_button_train", theme_button)

                with dpg.group(horizontal=True):
                    dpg.add_text("", tag="_log_train_time")
                    dpg.add_text("", tag="_log_train_log")

            # rendering options
            with dpg.collapsing_header(label="Rendering", default_open=True):
                # mode combo
                def callback_change_mode(sender, app_data):
                    self.mode = app_data
                    self.need_update = True

                dpg.add_combo(
                    ("image", "depth", "alpha"),
                    label="mode",
                    default_value=self.mode,
                    callback=callback_change_mode,
                )

                # fov slider
                def callback_set_fovy(sender, app_data):
                    self.cam.fovy = np.deg2rad(app_data)
                    self.need_update = True

                dpg.add_slider_int(
                    label="FoV (vertical)",
                    min_value=1,
                    max_value=120,
                    format="%d deg",
                    default_value=np.rad2deg(self.cam.fovy),
                    callback=callback_set_fovy,
                )

                def callback_set_gaussain_scale(sender, app_data):
                    self.gaussain_scale_factor = app_data
                    self.need_update = True

                dpg.add_slider_float(
                    label="gaussain scale",
                    min_value=0,
                    max_value=1,
                    format="%.2f",
                    default_value=self.gaussain_scale_factor,
                    callback=callback_set_gaussain_scale,
                )

        ### register camera handler

        def callback_camera_drag_rotate_or_draw_mask(sender, app_data):
            if not dpg.is_item_focused("_primary_window"):
                return

            dx = app_data[1]
            dy = app_data[2]

            self.cam.orbit(dx, dy)
            self.need_update = True

        def callback_camera_wheel_scale(sender, app_data):
            if not dpg.is_item_focused("_primary_window"):
                return

            delta = app_data

            self.cam.scale(delta)
            self.need_update = True

        def callback_camera_drag_pan(sender, app_data):
            if not dpg.is_item_focused("_primary_window"):
                return

            dx = app_data[1]
            dy = app_data[2]

            self.cam.pan(dx, dy)
            self.need_update = True

        def callback_set_mouse_loc(sender, app_data):
            if not dpg.is_item_focused("_primary_window"):
                return

            # just the pixel coordinate in image
            self.mouse_loc = np.array(app_data)

        with dpg.handler_registry():
            # for camera moving
            dpg.add_mouse_drag_handler(
                button=dpg.mvMouseButton_Left,
                callback=callback_camera_drag_rotate_or_draw_mask,
            )
            dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale)
            dpg.add_mouse_drag_handler(
                button=dpg.mvMouseButton_Middle, callback=callback_camera_drag_pan
            )

        dpg.create_viewport(
            title="Gaussian3D",
            width=self.W + 600,
            height=self.H + (45 if os.name == "nt" else 0),
            resizable=False,
        )

        ### global theme
        with dpg.theme() as theme_no_padding:
            with dpg.theme_component(dpg.mvAll):
                # set all padding to 0 to avoid scroll bar
                dpg.add_theme_style(
                    dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core
                )
                dpg.add_theme_style(
                    dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core
                )
                dpg.add_theme_style(
                    dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core
                )

        dpg.bind_item_theme("_primary_window", theme_no_padding)

        dpg.setup_dearpygui()

        ### register a larger font
        # get it from: https://github.com/lxgw/LxgwWenKai/releases/download/v1.300/LXGWWenKai-Regular.ttf
        if os.path.exists("LXGWWenKai-Regular.ttf"):
            with dpg.font_registry():
                with dpg.font("LXGWWenKai-Regular.ttf", 18) as default_font:
                    dpg.bind_font(default_font)

        # dpg.show_metrics()

        dpg.show_viewport()

    def render(self):
        assert self.gui
        while dpg.is_dearpygui_running():
            # update texture every frame
            if self.training:
                self.train_step()
            self.test_step()
            dpg.render_dearpygui_frame()
    
    # no gui mode
    def train(self, iters=500):
        
        if iters > 0:
            self.prepare_train()
            for i in tqdm.trange(iters):
                self.train_step()
            # do a last prune
            #self.renderer.gaussians.prune(min_opacity=0.01, extent=1, max_screen_size=1)
            

            
            
        # save
        self.save_model(mode='model')
        self.save_model(mode='geo+tex')
        

if __name__ == "__main__":
    import argparse
    from omegaconf import OmegaConf

    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True, help="path to the yaml config file")
    args, extras = parser.parse_known_args()

    # override default config from cli
    opt = OmegaConf.merge(OmegaConf.load(args.config), OmegaConf.from_cli(extras))

    gui = GUI(opt)

    if opt.gui:
        gui.render()
    else:
        gui.train(opt.save_step+1)

================================================
FILE: mini_trainer.ipynb
================================================
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gs_renderer_4d import Renderer, MiniCam\n",
    "from dataset_4d import SparseDataset\n",
    "import os\n",
    "import tqdm\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "from cam_utils import orbit_camera, OrbitCamera\n",
    "from guidance.sd_utils import StableDiffusion\n",
    "\n",
    "\n",
    "class trainer:\n",
    "    def __init__(self,opt) -> None:\n",
    "        \n",
    "        #initialize options\n",
    "        self.opt=opt\n",
    "        self.device=self.opt.device\n",
    "        \n",
    "        #initialize renderer and gaussians\n",
    "        self.renderer = Renderer(sh_degree=self.opt.sh_degree)\n",
    "        self.renderer.initialize(num_pts=self.opt.num_pts)   \n",
    "        self.renderer.gaussians.training_setup(self.opt)\n",
    "        \n",
    "        self.optimizer = self.renderer.gaussians.optimizer\n",
    "        \n",
    "        self.cam = OrbitCamera(opt.W, opt.H, r=opt.radius, fovy=opt.fovy)\n",
    "        \n",
    "        #initialize sd. replace with your own diffusion model if necessary.\n",
    "        self.enable_sd = True\n",
    "        self.guidance_sd = StableDiffusion(self.device)\n",
    "        self.guidance_sd.get_text_embeds([self.opt.prompt],negative_prompts= [''])\n",
    "        \n",
    "    def save(self,save_path):\n",
    "        #save \n",
    "        auto_path = save_path\n",
    "        os.makedirs(auto_path,exist_ok=True)\n",
    "        ply_path = os.path.join(auto_path,'model.ply')\n",
    "        self.renderer.gaussians.save_ply(ply_path)\n",
    "        self.renderer.gaussians.save_deformation(auto_path)\n",
    "        \n",
    "    def load(self, load_path):\n",
    "        #load\n",
    "        auto_path = load_path\n",
    "        ply_path = os.path.join(auto_path,'model.ply')\n",
    "        self.renderer.gaussians.load_model(auto_path)\n",
    "        self.renderer.gaussians.load_ply(ply_path)\n",
    "           \n",
    "           \n",
    "    def render(self,frame_id, elevation, azimuth, radius):\n",
    "        #render with parameters\n",
    "        pose = orbit_camera(elevation,azimuth,radius)\n",
    "        cam = MiniCam(\n",
    "                        pose,\n",
    "                        self.opt.ref_size,\n",
    "                        self.opt.ref_size,\n",
    "                        self.cam.fovy,\n",
    "                        self.cam.fovx,\n",
    "                        self.cam.near,\n",
    "                        self.cam.far,\n",
    "                        )   \n",
    "        cam.time=float(frame_id/30) #30 is the total frame\n",
    "        #use stage='coarse' for static rendering, use stage='fine' for dynamic rendering\n",
    "        out = self.renderer.render(cam,stage='fine')\n",
    "        image = out[\"image\"].unsqueeze(0)# [1, 3, H, W] in [0, 1]\n",
    "        \n",
    "        return image\n",
    "    \n",
    "    def train(self):\n",
    "        self.step=0\n",
    "        \n",
    "        for i in tqdm.tqdm(range(10000)):\n",
    "            self.step+=1\n",
    "            self.renderer.gaussians.update_learning_rate(self.step)\n",
    "            loss = 0\n",
    "            \n",
    "            min_ver = -30\n",
    "            max_ver = 30\n",
    "            vers, hors, radiis, poses = [], [], [], []\n",
    "            images=[]\n",
    "            viewspace_point_tensor_list, radii_list, visibility_filter_list = [], [], []\n",
    "\n",
    "            render_resolution=512\n",
    "            \n",
    "            for _ in range(self.opt.batch_size):\n",
    "                #sample time, vertical& horizontal  angle\n",
    "                ver = np.random.randint(min_ver, max_ver)\n",
    "                hor = np.random.randint(-180, 180)\n",
    "                radius=0\n",
    "                self.t = np.random.randint(0,30)\n",
    "                self.time = self.t/30\n",
    "                \n",
    "                vers.append(torch.tensor(self.opt.elevation + ver,device=self.device).unsqueeze(dim=0))\n",
    "                hors.append(torch.tensor(hor,device=self.device).unsqueeze(dim=0))\n",
    "                radiis.append(torch.tensor(self.opt.radius + radius,device=self.device).unsqueeze(dim=0))\n",
    "                \n",
    "                pose = orbit_camera(self.opt.elevation + ver, hor, self.opt.radius + radius)\n",
    "                \n",
    "                poses.append(pose)\n",
    "\n",
    "\n",
    "                cur_cam = MiniCam(\n",
    "                        pose,\n",
    "                        render_resolution,\n",
    "                        render_resolution,\n",
    "                        self.cam.fovy,\n",
    "                        self.cam.fovx,\n",
    "                        self.cam.near,\n",
    "                        self.cam.far,\n",
    "                    )\n",
    "                cur_cam.time=self.time\n",
    "                \n",
    "                bg_color = torch.tensor([1, 1, 1] if np.random.rand() > self.opt.invert_bg_prob else [0, 0, 0], dtype=torch.float32, device=\"cuda\")\n",
    "                #use stage='coarse' for static rendering, use stage='fine' for dynamic rendering\n",
    "                out = self.renderer.render(cur_cam, bg_color=bg_color,stage='fine')\n",
    "                \n",
    "                #basic values for densification\n",
    "                viewspace_point_tensor, visibility_filter, radii = out[\"viewspace_points\"], out[\"visibility_filter\"], out[\"radii\"]  \n",
    "                radii_list.append(radii.unsqueeze(0))\n",
    "                visibility_filter_list.append(visibility_filter.unsqueeze(0))\n",
    "                viewspace_point_tensor_list.append(viewspace_point_tensor)\n",
    "                \n",
    "                image = out[\"image\"].unsqueeze(0)# [1, 3, H, W] in [0, 1]\n",
    "                images.append(image)\n",
    "                \n",
    "            images_batch = torch.cat(images, dim=0)\n",
    "            poses = torch.from_numpy(np.stack(poses, axis=0)).to(self.device)\n",
    "            vers_batch = torch.cat(vers, dim=0).cpu().numpy()\n",
    "            hors_batch = torch.cat(hors, dim=0).cpu().numpy()\n",
    "            radii_batch = torch.cat(radiis, dim=0).cpu().numpy()\n",
    "\n",
    "            if self.enable_sd:\n",
    "                sd_loss = self.guidance_sd.train_step(images_batch,step_ratio=None,poses=poses)\n",
    "                # guidance loss. replace with your own diffusion model if necessary.\n",
    "                loss = loss + sd_loss\n",
    "            else:\n",
    "                zero123_loss = self.guidance_zero123.train_step(images_batch, vers_batch, hors_batch, radii_batch,step_ratio=None)\n",
    "                # guidance loss.\n",
    "                loss = loss + zero123_loss\n",
    "                \n",
    "            # optimize step\n",
    "            loss.backward()\n",
    "            self.optimizer.step()\n",
    "            self.optimizer.zero_grad()\n",
    "\n",
    "            #densifications. Adaptive densification is used here.\n",
    "            viewspace_point_tensor_grad = torch.zeros_like(viewspace_point_tensor)\n",
    "            for idx in range(0, len(viewspace_point_tensor_list)):\n",
    "                    viewspace_point_tensor_grad = viewspace_point_tensor_grad + viewspace_point_tensor_list[idx].grad\n",
    "\n",
    "            if self.step >= self.opt.density_start_iter and self.step <= self.opt.density_end_iter:\n",
    "                self.renderer.gaussians.max_radii2D[visibility_filter] = torch.max(self.renderer.gaussians.max_radii2D[visibility_filter], radii[visibility_filter])\n",
    "                self.renderer.gaussians.add_densification_stats(viewspace_point_tensor_grad, visibility_filter)\n",
    "                if self.step % self.opt.densification_interval == 0 :\n",
    "\n",
    "                    self.renderer.gaussians.densify_and_prune(self.opt.densify_grad_threshold, min_opacity=0.01, extent=1, max_screen_size=2)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "feature_dim: 128\n",
      "Number of points at initialisation :  10000\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a4cb52b5dc0045ccba1d352fc337cbd0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  1%|          | 72/10000 [00:19<44:48,  3.69it/s]  \n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #800000; text-decoration-color: #800000\">╭─────────────────────────────── </span><span style=\"color: #800000; text-decoration-color: #800000; font-weight: bold\">Traceback </span><span style=\"color: #bf7f7f; text-decoration-color: #bf7f7f; font-weight: bold\">(most recent call last)</span><span style=\"color: #800000; text-decoration-color: #800000\"> ────────────────────────────────╮</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> in <span style=\"color: #00ff00; text-decoration-color: #00ff00\">&lt;module&gt;</span>:<span style=\"color: #0000ff; text-decoration-color: #0000ff\">6</span>                                                                                    <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>                                                                                                  <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">3 </span>opt=OmegaConf.load(<span style=\"color: #808000; text-decoration-color: #808000\">'./configs/image_4d_m.yaml'</span>)                                              <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">4 </span>                                                                                             <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">5 </span>train=trainer(opt)                                                                           <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">❱ </span>6 train.train()                                                                                <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">7 </span>                                                                                             <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>                                                                                                  <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> in <span style=\"color: #00ff00; text-decoration-color: #00ff00\">train</span>:<span style=\"color: #0000ff; text-decoration-color: #0000ff\">134</span>                                                                                     <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>                                                                                                  <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">131 </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│   │   │   │   </span>loss = loss + zero123_loss                                                 <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">132 </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│   │   │   </span>                                                                               <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">133 </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│   │   │   </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"># optimize step</span>                                                                <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">❱ </span>134 <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│   │   │   </span>loss.backward()                                                                <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">135 </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│   │   │   </span><span style=\"color: #00ffff; text-decoration-color: #00ffff\">self</span>.optimizer.step()                                                          <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">136 </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│   │   │   </span><span style=\"color: #00ffff; text-decoration-color: #00ffff\">self</span>.optimizer.zero_grad()                                                     <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">137 </span>                                                                                           <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>                                                                                                  <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #bfbf7f; text-decoration-color: #bfbf7f\">/home/vision/miniconda3/envs/torch0/lib/python3.8/site-packages/torch/</span><span style=\"color: #808000; text-decoration-color: #808000; font-weight: bold\">_tensor.py</span>:<span style=\"color: #0000ff; text-decoration-color: #0000ff\">487</span> in <span style=\"color: #00ff00; text-decoration-color: #00ff00\">backward</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>                                                                                                  <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 484 </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│   │   │   │   </span>create_graph=create_graph,                                                <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 485 </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│   │   │   │   </span>inputs=inputs,                                                            <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 486 </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│   │   │   </span>)                                                                             <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">❱ </span> 487 <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│   │   </span>torch.autograd.backward(                                                          <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 488 </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│   │   │   </span><span style=\"color: #00ffff; text-decoration-color: #00ffff\">self</span>, gradient, retain_graph, create_graph, inputs=inputs                     <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 489 </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│   │   </span>)                                                                                 <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 490 </span>                                                                                          <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>                                                                                                  <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #bfbf7f; text-decoration-color: #bfbf7f\">/home/vision/miniconda3/envs/torch0/lib/python3.8/site-packages/torch/autograd/</span><span style=\"color: #808000; text-decoration-color: #808000; font-weight: bold\">__init__.py</span>:<span style=\"color: #0000ff; text-decoration-color: #0000ff\">200</span>   <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> in <span style=\"color: #00ff00; text-decoration-color: #00ff00\">backward</span>                                                                                      <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>                                                                                                  <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">197 </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│   </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"># The reason we repeat same the comment below is that</span>                                  <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">198 </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│   </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"># some Python versions print out the first line of a multi-line function</span>               <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">199 </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│   </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"># calls in the traceback and some print out the last line</span>                              <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">❱ </span>200 <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│   </span>Variable._execution_engine.run_backward(  <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"># Calls into the C++ engine to run the bac</span>   <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">201 </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│   │   </span>tensors, grad_tensors_, retain_graph, create_graph, inputs,                        <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">202 </span><span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│   │   </span>allow_unreachable=<span style=\"color: #0000ff; text-decoration-color: #0000ff\">True</span>, accumulate_grad=<span style=\"color: #0000ff; text-decoration-color: #0000ff\">True</span>)  <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"># Calls into the C++ engine to ru</span>   <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">│</span>   <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">203 </span>                                                                                           <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
       "<span style=\"color: #800000; text-decoration-color: #800000\">╰──────────────────────────────────────────────────────────────────────────────────────────────────╯</span>\n",
       "<span style=\"color: #ff0000; text-decoration-color: #ff0000; font-weight: bold\">KeyboardInterrupt</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[31m╭─\u001b[0m\u001b[31m──────────────────────────────\u001b[0m\u001b[31m \u001b[0m\u001b[1;31mTraceback \u001b[0m\u001b[1;2;31m(most recent call last)\u001b[0m\u001b[31m \u001b[0m\u001b[31m───────────────────────────────\u001b[0m\u001b[31m─╮\u001b[0m\n",
       "\u001b[31m│\u001b[0m in \u001b[92m<module>\u001b[0m:\u001b[94m6\u001b[0m                                                                                    \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m                                                                                                  \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m   \u001b[2m3 \u001b[0mopt=OmegaConf.load(\u001b[33m'\u001b[0m\u001b[33m./configs/image_4d_m.yaml\u001b[0m\u001b[33m'\u001b[0m)                                              \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m   \u001b[2m4 \u001b[0m                                                                                             \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m   \u001b[2m5 \u001b[0mtrain=trainer(opt)                                                                           \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m6 train.train()                                                                                \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m   \u001b[2m7 \u001b[0m                                                                                             \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m                                                                                                  \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m in \u001b[92mtrain\u001b[0m:\u001b[94m134\u001b[0m                                                                                     \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m                                                                                                  \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m   \u001b[2m131 \u001b[0m\u001b[2m│   │   │   │   \u001b[0mloss = loss + zero123_loss                                                 \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m   \u001b[2m132 \u001b[0m\u001b[2m│   │   │   \u001b[0m                                                                               \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m   \u001b[2m133 \u001b[0m\u001b[2m│   │   │   \u001b[0m\u001b[2m# optimize step\u001b[0m                                                                \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m134 \u001b[2m│   │   │   \u001b[0mloss.backward()                                                                \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m   \u001b[2m135 \u001b[0m\u001b[2m│   │   │   \u001b[0m\u001b[96mself\u001b[0m.optimizer.step()                                                          \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m   \u001b[2m136 \u001b[0m\u001b[2m│   │   │   \u001b[0m\u001b[96mself\u001b[0m.optimizer.zero_grad()                                                     \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m   \u001b[2m137 \u001b[0m                                                                                           \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m                                                                                                  \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m \u001b[2;33m/home/vision/miniconda3/envs/torch0/lib/python3.8/site-packages/torch/\u001b[0m\u001b[1;33m_tensor.py\u001b[0m:\u001b[94m487\u001b[0m in \u001b[92mbackward\u001b[0m \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m                                                                                                  \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m   \u001b[2m 484 \u001b[0m\u001b[2m│   │   │   │   \u001b[0mcreate_graph=create_graph,                                                \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m   \u001b[2m 485 \u001b[0m\u001b[2m│   │   │   │   \u001b[0minputs=inputs,                                                            \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m   \u001b[2m 486 \u001b[0m\u001b[2m│   │   │   \u001b[0m)                                                                             \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m 487 \u001b[2m│   │   \u001b[0mtorch.autograd.backward(                                                          \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m   \u001b[2m 488 \u001b[0m\u001b[2m│   │   │   \u001b[0m\u001b[96mself\u001b[0m, gradient, retain_graph, create_graph, inputs=inputs                     \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m   \u001b[2m 489 \u001b[0m\u001b[2m│   │   \u001b[0m)                                                                                 \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m   \u001b[2m 490 \u001b[0m                                                                                          \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m                                                                                                  \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m \u001b[2;33m/home/vision/miniconda3/envs/torch0/lib/python3.8/site-packages/torch/autograd/\u001b[0m\u001b[1;33m__init__.py\u001b[0m:\u001b[94m200\u001b[0m   \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m in \u001b[92mbackward\u001b[0m                                                                                      \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m                                                                                                  \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m   \u001b[2m197 \u001b[0m\u001b[2m│   \u001b[0m\u001b[2m# The reason we repeat same the comment below is that\u001b[0m                                  \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m   \u001b[2m198 \u001b[0m\u001b[2m│   \u001b[0m\u001b[2m# some Python versions print out the first line of a multi-line function\u001b[0m               \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m   \u001b[2m199 \u001b[0m\u001b[2m│   \u001b[0m\u001b[2m# calls in the traceback and some print out the last line\u001b[0m                              \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m200 \u001b[2m│   \u001b[0mVariable._execution_engine.run_backward(  \u001b[2m# Calls into the C++ engine to run the bac\u001b[0m   \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m   \u001b[2m201 \u001b[0m\u001b[2m│   │   \u001b[0mtensors, grad_tensors_, retain_graph, create_graph, inputs,                        \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m   \u001b[2m202 \u001b[0m\u001b[2m│   │   \u001b[0mallow_unreachable=\u001b[94mTrue\u001b[0m, accumulate_grad=\u001b[94mTrue\u001b[0m)  \u001b[2m# Calls into the C++ engine to ru\u001b[0m   \u001b[31m│\u001b[0m\n",
       "\u001b[31m│\u001b[0m   \u001b[2m203 \u001b[0m                                                                                           \u001b[31m│\u001b[0m\n",
       "\u001b[31m╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n",
       "\u001b[1;91mKeyboardInterrupt\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from omegaconf import OmegaConf\n",
    "\n",
    "opt=OmegaConf.load('./configs/image_4d_m.yaml')\n",
    "\n",
    "train=trainer(opt)\n",
    "train.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.0"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "np.deg2rad(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torch0",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}


================================================
FILE: requirements.txt
================================================
tqdm
rich
ninja
numpy
pandas
scipy
scikit-learn
matplotlib
opencv-python
imageio
imageio-ffmpeg
omegaconf
argparse
torch
einops
plyfile
pygltflib
dearpygui
accelerate
rembg[gpu,cli]

#zero123plus
opencv-contrib-python
diffusers==0.20.2
transformers==4.29.2
streamlit==1.22.0
altair<5
huggingface_hub
git+https://github.com/facebookresearch/segment-anything.git
gradio>=3.50
fire

================================================
FILE: scripts/app.py
================================================
import os
import sys
import numpy
import torch
import rembg
import threading
import urllib.request
from PIL import Image
import streamlit as st
import huggingface_hub

class SAMAPI:
    predictor = None

    @staticmethod
    @st.cache_resource
    def get_instance(sam_checkpoint=None):
        if SAMAPI.predictor is None:
            if sam_checkpoint is None:
                sam_checkpoint = "tmp/sam_vit_h_4b8939.pth"
            if not os.path.exists(sam_checkpoint):
                os.makedirs('tmp', exist_ok=True)
                urllib.request.urlretrieve(
                    "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
                    sam_checkpoint
                )
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
            model_type = "default"

            from segment_anything import sam_model_registry, SamPredictor

            sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
            sam.to(device=device)

            predictor = SamPredictor(sam)
            SAMAPI.predictor = predictor
        return SAMAPI.predictor

    @staticmethod
    def segment_api(rgb, mask=None, bbox=None, sam_checkpoint=None):
        """

        Parameters
        ----------
        rgb : np.ndarray h,w,3 uint8
        mask: np.ndarray h,w bool

        Returns
        -------

        """
        np = numpy
        predictor = SAMAPI.get_instance(sam_checkpoint)
        predictor.set_image(rgb)
        if mask is None and bbox is None:
            box_input = None
        else:
            # mask to bbox
            if bbox is None:
                y1, y2, x1, x2 = np.nonzero(mask)[0].min(), np.nonzero(mask)[0].max(), np.nonzero(mask)[1].min(), \
                                 np.nonzero(mask)[1].max()
            else:
                x1, y1, x2, y2 = bbox
            box_input = np.array([[x1, y1, x2, y2]])
        masks, scores, logits = predictor.predict(
            box=box_input,
            multimask_output=True,
            return_logits=False,
        )
        mask = masks[-1]
        return mask


def image_examples(samples, ncols, return_key=None, example_text="Examples"):
    global img_example_counter
    trigger = False
    with st.expander(example_text, True):
        for i in range(len(samples) // ncols):
            cols = st.columns(ncols)
            for j in range(ncols):
                idx = i * ncols + j
                if idx >= len(samples):
                    continue
                entry = samples[idx]
                with cols[j]:
                    st.image(entry['dispi'])
                    img_example_counter += 1
                    with st.columns(5)[2]:
                        this_trigger = st.button('\+', key='imgexuse%d' % img_example_counter)
                    trigger = trigger or this_trigger
                    if this_trigger:
                        trigger = entry[return_key]
    return trigger


def segment_img(img: Image):
    output = rembg.remove(img)
    mask = numpy.array(output)[:, :, 3] > 0
    sam_mask = SAMAPI.segment_api(numpy.array(img)[:, :, :3], mask)
    segmented_img = Image.new("RGBA", img.size, (0, 0, 0, 0))
    segmented_img.paste(img, mask=Image.fromarray(sam_mask))
    return segmented_img


def segment_6imgs(zero123pp_imgs):
    imgs = [zero123pp_imgs.crop([0, 0, 320, 320]),
            zero123pp_imgs.crop([320, 0, 640, 320]),
            zero123pp_imgs.crop([0, 320, 320, 640]),
            zero123pp_imgs.crop([320, 320, 640, 640]),
            zero123pp_imgs.crop([0, 640, 320, 960]),
            zero123pp_imgs.crop([320, 640, 640, 960])]
    segmented_imgs = []
    for i, img in enumerate(imgs):
        output = rembg.remove(img)
        mask = numpy.array(output)[:, :, 3]
        mask = SAMAPI.segment_api(numpy.array(img)[:, :, :3], mask)
        data = numpy.array(img)[:,:,:3]
        data[mask == 0] = [255, 255, 255]
        segmented_imgs.append(data)
    result = numpy.concatenate([
        numpy.concatenate([segmented_imgs[0], segmented_imgs[1]], axis=1),
        numpy.concatenate([segmented_imgs[2], segmented_imgs[3]], axis=1),
        numpy.concatenate([segmented_imgs[4], segmented_imgs[5]], axis=1)
    ])
    return Image.fromarray(result)


def expand2square(pil_img, background_color):
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result


@st.cache_data
def check_dependencies():
    reqs = []
    try:
        import diffusers
    except ImportError:
        import traceback
        traceback.print_exc()
        print("Error: `diffusers` not found.", file=sys.stderr)
        reqs.append("diffusers==0.20.2")
    else:
        if not diffusers.__version__.startswith("0.20"):
            print(
                f"Warning: You are using an unsupported version of diffusers ({diffusers.__version__}), which may lead to performance issues.",
                file=sys.stderr
            )
            print("Recommended version is `diffusers==0.20.2`.", file=sys.stderr)
    try:
        import transformers
    except ImportError:
        import traceback
        traceback.print_exc()
        print("Error: `transformers` not found.", file=sys.stderr)
        reqs.append("transformers==4.29.2")
    if torch.__version__ < '2.0':
        try:
            import xformers
        except ImportError:
            print("Warning: You are using PyTorch 1.x without a working `xformers` installation.", file=sys.stderr)
            print("You may see a significant memory overhead when running the model.", file=sys.stderr)
    if len(reqs):
        print(f"Info: Fix all dependency errors with `pip install {' '.join(reqs)}`.")


@st.cache_resource
def load_zero123plus_pipeline():
    if 'HF_TOKEN' in os.environ:
        huggingface_hub.login(os.environ['HF_TOKEN'])
    from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
    pipeline = DiffusionPipeline.from_pretrained(
        "sudo-ai/zero123plus-v1.1", custom_pipeline="sudo-ai/zero123plus-pipeline",
        torch_dtype=torch.float16
    )
    # Feel free to tune the scheduler
    pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
        pipeline.scheduler.config, timestep_spacing='trailing'
    )
    if torch.cuda.is_available():
        pipeline.to('cuda:0')
    sys.main_lock = threading.Lock()
    return pipeline


================================================
FILE: scripts/gen_mv.py
================================================
import torch
import requests
from PIL import Image
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
import numpy
import os
import sys
import numpy
import torch
import rembg
import threading
import urllib.request
from PIL import Image
import streamlit as st
import huggingface_hub
from app import SAMAPI
def segment_img(img: Image):
    output = rembg.remove(img)
    mask = numpy.array(output)[:, :, 3] > 0
    sam_mask = SAMAPI.segment_api(numpy.array(img)[:, :, :3], mask)
    segmented_img = Image.new("RGBA", img.size, (0, 0, 0, 0))
    segmented_img.paste(img, mask=Image.fromarray(sam_mask))
    return segmented_img


def segment_6imgs(zero123pp_imgs):
    imgs = [zero123pp_imgs.crop([0, 0, 320, 320]),
            zero123pp_imgs.crop([320, 0, 640, 320]),
            zero123pp_imgs.crop([0, 320, 320, 640]),
            zero123pp_imgs.crop([320, 320, 640, 640]),
            zero123pp_imgs.crop([0, 640, 320, 960]),
            zero123pp_imgs.crop([320, 640, 640, 960])]
    segmented_imgs = []
    import numpy as np
    for i, img in enumerate(imgs):
        output = rembg.remove(img)
        mask = numpy.array(output)[:, :, 3]
        mask = SAMAPI.segment_api(numpy.array(img)[:, :, :3], mask)
        data = numpy.array(img)[:,:,:3]
        data2 = numpy.ones([320,320,4])
        data2[:,:,:3] = data
        for i in np.arange(data2.shape[0]):
                for j in np.arange(data2.shape[1]):
                        if mask[i,j]==1:
                                data2[i,j,3]=255
        segmented_imgs.append(data2)

        #torch.manual_seed(42)
    return segmented_imgs

def process_img(path,destination,pipeline, is_first):
    # Download an example image.
        print('processing:',path)
        #cond_whole = Image.open('output.png')
        cond = Image.open(path)
        # Run the pipeline!
        result = pipeline(cond, num_inference_steps=75,is_first = is_first).images[0]
        # for general real and synthetic images of general objects
        # usually it is enough to have around 28 inference steps
        # for images with delicate details like faces (real or anime)
        # you may need 75-100 steps for the details to construct

        #result.show()
        #result.save("./test_png/zero123pp/output.png")
        result=segment_6imgs(result)
        print('saving:',os.path.join(destination,'0~5.png'),'in',destination)
        for i in numpy.arange(6):
            Image.fromarray(numpy.uint8(result[i])).save(os.path.join(destination,'{}.png'.format(i)))



if __name__ == "__main__":
    import argparse


    parser = argparse.ArgumentParser()
    parser.add_argument("--path", required=True, help="path to process")
    # DiffusionPipeline.from_pretrained cannot received relative path for custom pipeline
    parser.add_argument("--pipeline_path", required=True, help="path of pipeline code, in ../guidance/zero123pp")
    args, extras = parser.parse_known_args()


    pipeline = DiffusionPipeline.from_pretrained(
        "sudo-ai/zero123plus-v1.1", custom_pipeline=args.pipeline_path,
        torch_dtype=torch.float16
    )

    pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
        pipeline.scheduler.config, timestep_spacing='trailing'
    )
    pipeline.to('cuda:0')
    
    
    directory = args.path+'/'
    os.makedirs(directory+'ref', exist_ok=True)
    os.system(f"cp -r {directory+'*.png'} {directory+'ref/'}")
    is_first = True
    l=sorted(os.listdir(directory+'ref'))
        

    for file in  sorted(os.listdir(directory+'ref')):
        if  file[-4:-1]=='.pn':
                
            filename =  os.path.splitext(os.path.basename(file))[0]
            destination = os.path.join(directory+'zero123',filename)
            
            os.makedirs(destination, exist_ok=True)
            img_path = os.path.join(directory+'ref',file)
            process_img(img_path,destination,pipeline, is_first)
            is_first = False

    


================================================
FILE: sh_utils.py
================================================
#  Copyright 2021 The PlenOctree Authors.
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions are met:
#
#  1. Redistributions of source code must retain the above copyright notice,
#  this list of conditions and the following disclaimer.
#
#  2. Redistributions in binary form must reproduce the above copyright notice,
#  this list of conditions and the following disclaimer in the documentation
#  and/or other materials provided with the distribution.
#
#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
#  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
#  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
#  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
#  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
#  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
#  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
#  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
#  POSSIBILITY OF SUCH DAMAGE.

import torch

C0 = 0.28209479177387814
C1 = 0.4886025119029199
C2 = [
    1.0925484305920792,
    -1.0925484305920792,
    0.31539156525252005,
    -1.0925484305920792,
    0.5462742152960396
]
C3 = [
    -0.5900435899266435,
    2.890611442640554,
    -0.4570457994644658,
    0.3731763325901154,
    -0.4570457994644658,
    1.445305721320277,
    -0.5900435899266435
]
C4 = [
    2.5033429417967046,
    -1.7701307697799304,
    0.9461746957575601,
    -0.6690465435572892,
    0.10578554691520431,
    -0.6690465435572892,
    0.47308734787878004,
    -1.7701307697799304,
    0.6258357354491761,
]   


def eval_sh(deg, sh, dirs):
    """
    Evaluate spherical harmonics at unit directions
    using hardcoded SH polynomials.
    Works with torch/np/jnp.
    ... Can be 0 or more batch dimensions.
    Args:
        deg: int SH deg. Currently, 0-3 supported
        sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]
        dirs: jnp.ndarray unit directions [..., 3]
    Returns:
        [..., C]
    """
    assert deg <= 4 and deg >= 0
    coeff = (deg + 1) ** 2
    assert sh.shape[-1] >= coeff

    result = C0 * sh[..., 0]
    if deg > 0:
        x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
        result = (result -
                C1 * y * sh[..., 1] +
                C1 * z * sh[..., 2] -
                C1 * x * sh[..., 3])

        if deg > 1:
            xx, yy, zz = x * x, y * y, z * z
            xy, yz, xz = x * y, y * z, x * z
            result = (result +
                    C2[0] * xy * sh[..., 4] +
                    C2[1] * yz * sh[..., 5] +
                    C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
                    C2[3] * xz * sh[..., 7] +
                    C2[4] * (xx - yy) * sh[..., 8])

            if deg > 2:
                result = (result +
                C3[0] * y * (3 * xx - yy) * sh[..., 9] +
                C3[1] * xy * z * sh[..., 10] +
                C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] +
                C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
                C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
                C3[5] * z * (xx - yy) * sh[..., 14] +
                C3[6] * x * (xx - 3 * yy) * sh[..., 15])

                if deg > 3:
                    result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
                            C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
                            C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
                            C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
                            C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
                            C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
                            C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
                            C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
                            C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
    return result

def RGB2SH(rgb):
    return (rgb - 0.5) / C0

def SH2RGB(sh):
    return sh * C0 + 0.5

================================================
FILE: simple-knn/ext.cpp
================================================
/*
 * Copyright (C) 2023, Inria
 * GRAPHDECO research group, https://team.inria.fr/graphdeco
 * All rights reserved.
 *
 * This software is free for non-commercial, research and evaluation use 
 * under the terms of the LICENSE.md file.
 *
 * For inquiries contact  george.drettakis@inria.fr
 */

#include <torch/extension.h>
#include "spatial.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("distCUDA2", &distCUDA2);
}


================================================
FILE: simple-knn/setup.py
================================================
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use 
# under the terms of the LICENSE.md file.
#
# For inquiries contact  george.drettakis@inria.fr
#

from setuptools import setup
from torch.utils.cpp_extension import CUDAExtension, BuildExtension
import os

cxx_compiler_flags = []

if os.name == 'nt':
    cxx_compiler_flags.append("/wd4624")

setup(
    name="simple_knn",
    ext_modules=[
        CUDAExtension(
            name="simple_knn._C",
            sources=[
            "spatial.cu", 
            "simple_knn.cu",
            "ext.cpp"],
            extra_compile_args={"nvcc": [], "cxx": cxx_compiler_flags})
        ],
    cmdclass={
        'build_ext': BuildExtension
    }
)


================================================
FILE: simple-knn/simple_knn/.gitkeep
================================================


================================================
FILE: simple-knn/simple_knn.cu
================================================
/*
 * Copyright (C) 2023, Inria
 * GRAPHDECO research group, https://team.inria.fr/graphdeco
 * All rights reserved.
 *
 * This software is free for non-commercial, research and evaluation use 
 * under the terms of the LICENSE.md file.
 *
 * For inquiries contact  george.drettakis@inria.fr
 */

#define BOX_SIZE 1024

#include "cuda_runtime.h"
#include "device_launch_parameters.h"
#include "simple_knn.h"
#include <cub/cub.cuh>
#include <cub/device/device_radix_sort.cuh>
#include <vector>
#include <cuda_runtime_api.h>
#include <thrust/device_vector.h>
#include <thrust/sequence.h>
#define __CUDACC__
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>

namespace cg = cooperative_groups;

struct CustomMin
{
	__device__ __forceinline__
		float3 operator()(const float3& a, const float3& b) const {
		return { min(a.x, b.x), min(a.y, b.y), min(a.z, b.z) };
	}
};

struct CustomMax
{
	__device__ __forceinline__
		float3 operator()(const float3& a, const float3& b) const {
		return { max(a.x, b.x), max(a.y, b.y), max(a.z, b.z) };
	}
};

__host__ __device__ uint32_t prepMorton(uint32_t x)
{
	x = (x | (x << 16)) & 0x030000FF;
	x = (x | (x << 8)) & 0x0300F00F;
	x = (x | (x << 4)) & 0x030C30C3;
	x = (x | (x << 2)) & 0x09249249;
	return x;
}

__host__ __device__ uint32_t coord2Morton(float3 coord, float3 minn, float3 maxx)
{
	uint32_t x = prepMorton(((coord.x - minn.x) / (maxx.x - minn.x)) * ((1 << 10) - 1));
	uint32_t y = prepMorton(((coord.y - minn.y) / (maxx.y - minn.y)) * ((1 << 10) - 1));
	uint32_t z = prepMorton(((coord.z - minn.z) / (maxx.z - minn.z)) * ((1 << 10) - 1));

	return x | (y << 1) | (z << 2);
}

__global__ void coord2Morton(int P, const float3* points, float3 minn, float3 maxx, uint32_t* codes)
{
	auto idx = cg::this_grid().thread_rank();
	if (idx >= P)
		return;

	codes[idx] = coord2Morton(points[idx], minn, maxx);
}

struct MinMax
{
	float3 minn;
	float3 maxx;
};

__global__ void boxMinMax(uint32_t P, float3* points, uint32_t* indices, MinMax* boxes)
{
	auto idx = cg::this_grid().thread_rank();

	MinMax me;
	if (idx < P)
	{
		me.minn = points[indices[idx]];
		me.maxx = points[indices[idx]];
	}
	else
	{
		me.minn = { FLT_MAX, FLT_MAX, FLT_MAX };
		me.maxx = { -FLT_MAX,-FLT_MAX,-FLT_MAX };
	}

	__shared__ MinMax redResult[BOX_SIZE];

	for (int off = BOX_SIZE / 2; off >= 1; off /= 2)
	{
		if (threadIdx.x < 2 * off)
			redResult[threadIdx.x] = me;
		__syncthreads();

		if (threadIdx.x < off)
		{
			MinMax other = redResult[threadIdx.x + off];
			me.minn.x = min(me.minn.x, other.minn.x);
			me.minn.y = min(me.minn.y, other.minn.y);
			me.minn.z = min(me.minn.z, other.minn.z);
			me.maxx.x = max(me.maxx.x, other.maxx.x);
			me.maxx.y = max(me.maxx.y, o
Download .txt
gitextract_687zu6g9/

├── .gitignore
├── README.md
├── __init__.py
├── cam_utils.py
├── configs/
│   └── stag4d.yaml
├── dataset_4d.py
├── deform.py
├── gs_renderer_4d.py
├── guidance/
│   ├── zero123_4d_utils.py
│   └── zero123pp/
│       └── pipeline.py
├── main.py
├── mini_trainer.ipynb
├── requirements.txt
├── scripts/
│   ├── app.py
│   └── gen_mv.py
├── sh_utils.py
├── simple-knn/
│   ├── ext.cpp
│   ├── setup.py
│   ├── simple_knn/
│   │   └── .gitkeep
│   ├── simple_knn.cu
│   ├── simple_knn.h
│   ├── spatial.cu
│   └── spatial.h
├── visualize.py
└── zero123.py
Download .txt
SYMBOL INDEX (233 symbols across 14 files)

FILE: cam_utils.py
  function dot (line 6) | def dot(x, y):
  function length (line 13) | def length(x, eps=1e-20):
  function safe_normalize (line 20) | def safe_normalize(x, eps=1e-20):
  function look_at (line 24) | def look_at(campos, target, opengl=True):
  function orbit_camera (line 45) | def orbit_camera(elevation, azimuth, radius=1, is_degree=True, target=No...
  class OrbitCamera (line 65) | class OrbitCamera:
    method __init__ (line 66) | def __init__(self, W, H, r=2, fovy=60, near=0.01, far=100):
    method fovx (line 78) | def fovx(self):
    method campos (line 82) | def campos(self):
    method pose (line 87) | def pose(self):
    method view (line 101) | def view(self):
    method perspective (line 106) | def perspective(self):
    method intrinsics (line 126) | def intrinsics(self):
    method mvp (line 131) | def mvp(self):
    method orbit (line 134) | def orbit(self, dx, dy):
    method scale (line 141) | def scale(self, delta):
    method pan (line 144) | def pan(self, dx, dy, dz=0):

FILE: dataset_4d.py
  class SparseDataset (line 16) | class SparseDataset:
    method __init__ (line 17) | def __init__(self, opt, size,device='cuda', type='train', H=256, W=256):
    method collate_ref (line 32) | def collate_ref(self,index):
    method collate_zero123 (line 55) | def collate_zero123(self,index):
    method collate (line 85) | def collate(self, index):
    method dataloader (line 97) | def dataloader(self):
    method dataloader_d (line 101) | def dataloader_d(self):

FILE: deform.py
  class Deformation (line 23) | class Deformation(nn.Module):
    method __init__ (line 24) | def __init__(self, D=8, W=256, input_ch=27, input_ch_time=9, skips=[],...
    method create_net (line 47) | def create_net(self):
    method query_time (line 66) | def query_time(self, rays_pts_emb, scales_emb, rotations_emb, time_emb):
    method forward (line 79) | def forward(self, rays_pts_emb, scales_emb=None, rotations_emb=None, o...
    method forward_static (line 85) | def forward_static(self, rays_pts_emb):
    method forward_dynamic (line 89) | def forward_dynamic(self,rays_pts_emb, scales_emb, rotations_emb, opac...
    method get_mlp_parameters (line 112) | def get_mlp_parameters(self):
    method get_grid_parameters (line 118) | def get_grid_parameters(self):
  class deform_network (line 121) | class deform_network(nn.Module):
    method __init__ (line 122) | def __init__(self) :
    method forward (line 144) | def forward(self, point, scales=None, rotations=None, opacity=None, ti...
    method forward_static (line 151) | def forward_static(self, points):
    method forward_dynamic (line 154) | def forward_dynamic(self, point, scales=None, rotations=None, opacity=...
    method get_mlp_parameters (line 164) | def get_mlp_parameters(self):
    method get_grid_parameters (line 166) | def get_grid_parameters(self):
  function initialize_weights (line 169) | def initialize_weights(m):
  function get_normalized_directions (line 179) | def get_normalized_directions(directions):
  function normalize_aabb (line 188) | def normalize_aabb(pts, aabb):
  function grid_sample_wrapper (line 190) | def grid_sample_wrapper(grid: torch.Tensor, coords: torch.Tensor, align_...
  function init_grid_param (line 217) | def init_grid_param(
  function interpolate_ms_features (line 242) | def interpolate_ms_features(pts: torch.Tensor,
  class HexPlaneField (line 278) | class HexPlaneField(nn.Module):
    method __init__ (line 279) | def __init__(
    method set_aabb (line 320) | def set_aabb(self,xyz_max, xyz_min):
    method get_density (line 328) | def get_density(self, pts: torch.Tensor, timestamps: Optional[torch.Te...
    method forward (line 345) | def forward(self,
  function compute_plane_tv (line 353) | def compute_plane_tv(t):
  function compute_plane_smoothness (line 362) | def compute_plane_smoothness(t):
  class Regularizer (line 371) | class Regularizer():
    method __init__ (line 372) | def __init__(self, reg_type, initialization):
    method step (line 378) | def step(self, global_step):
    method report (line 381) | def report(self, d):
    method regularize (line 385) | def regularize(self, *args, **kwargs) -> torch.Tensor:
    method _regularize (line 391) | def _regularize(self, *args, **kwargs) -> torch.Tensor:
    method __str__ (line 394) | def __str__(self):
  class PlaneTV (line 398) | class PlaneTV(Regularizer):
    method __init__ (line 399) | def __init__(self, initial_value, what: str = 'field'):
    method step (line 407) | def step(self, global_step):
    method _regularize (line 410) | def _regularize(self, model, **kwargs):
  class TimeSmoothness (line 433) | class TimeSmoothness(Regularizer):
    method __init__ (line 434) | def __init__(self, initial_value, what: str = 'field'):
    method _regularize (line 442) | def _regularize(self, model, **kwargs) -> torch.Tensor:
  class L1ProposalNetwork (line 463) | class L1ProposalNetwork(Regularizer):
    method __init__ (line 464) | def __init__(self, initial_value):
    method _regularize (line 467) | def _regularize(self, model, **kwargs) -> torch.Tensor:
  class DepthTV (line 476) | class DepthTV(Regularizer):
    method __init__ (line 477) | def __init__(self, initial_value):
    method _regularize (line 480) | def _regularize(self, model, model_out, **kwargs) -> torch.Tensor:
  class L1TimePlanes (line 488) | class L1TimePlanes(Regularizer):
    method __init__ (line 489) | def __init__(self, initial_value, what='field'):
    method _regularize (line 496) | def _regularize(self, model, **kwargs) -> torch.Tensor:

FILE: gs_renderer_4d.py
  function inverse_sigmoid (line 19) | def inverse_sigmoid(x):
  function get_expon_lr_func (line 25) | def get_expon_lr_func(
  function strip_lowerdiag (line 50) | def strip_lowerdiag(L):
  function strip_symmetric (line 61) | def strip_symmetric(sym):
  function gaussian_3d_coeff (line 64) | def gaussian_3d_coeff(xyzs, covs):
  function build_rotation (line 85) | def build_rotation(r):
  function build_scaling_rotation (line 108) | def build_scaling_rotation(s, r):
  class BasicPointCloud (line 119) | class BasicPointCloud(NamedTuple):
  class GaussianModel (line 125) | class GaussianModel:
    method setup_functions (line 127) | def setup_functions(self):
    method initialize (line 144) | def initialize(self, initial_values, raw=False):
    method __init__ (line 155) | def __init__(self, sh_degree : int,args = None):
    method capture (line 174) | def capture(self):
    method restore (line 192) | def restore(self, model_args, training_args):
    method get_scaling (line 213) | def get_scaling(self):
    method get_rotation (line 217) | def get_rotation(self):
    method get_xyz (line 221) | def get_xyz(self):
    method get_features (line 225) | def get_features(self):
    method get_opacity (line 231) | def get_opacity(self):
    method get_covariance (line 238) | def get_covariance(self, scaling_modifier = 1):
    method oneupSHdegree (line 241) | def oneupSHdegree(self):
    method create_from_pcd (line 245) | def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : fl...
    method training_setup (line 276) | def training_setup(self, training_args):
    method update_learning_rate (line 308) | def update_learning_rate(self, iteration):
    method construct_list_of_attributes (line 324) | def construct_list_of_attributes(self):
    method save_ply (line 338) | def save_ply(self, path):
    method compute_deformation (line 357) | def compute_deformation(self,time):
    method load_model (line 362) | def load_model(self, path):
    method save_deformation (line 378) | def save_deformation(self, path):
    method load_ply (line 383) | def load_ply(self, path):
    method replace_tensor_to_optimizer (line 430) | def replace_tensor_to_optimizer(self, tensor, name):
    method _prune_optimizer (line 445) | def _prune_optimizer(self, mask):
    method prune_points (line 465) | def prune_points(self, mask):
    method cat_tensors_to_optimizer (line 481) | def cat_tensors_to_optimizer(self, tensors_dict):
    method densification_postfix (line 504) | def densification_postfix(self, new_xyz, new_features_dc, new_features...
    method densify_and_split (line 529) | def densify_and_split(self, grads, grad_threshold, scene_extent, N=2):
    method densify_and_clone (line 555) | def densify_and_clone(self, grads, grad_threshold, scene_extent):
    method densify_and_prune (line 571) | def densify_and_prune(self, max_grad_percent, min_opacity, extent, max...
    method prune (line 600) | def prune(self, min_opacity=0.01, extent=1, max_screen_size=1):
    method standard_constaint (line 619) | def standard_constaint(self):
    method add_densification_stats (line 633) | def add_densification_stats(self, viewspace_point_tensor, update_filter):
    method update_deformation_table (line 639) | def update_deformation_table(self,threshold):
    method print_deformation_weight_grad (line 642) | def print_deformation_weight_grad(self):
    method _plane_regulation (line 652) | def _plane_regulation(self):
    method _time_regulation (line 664) | def _time_regulation(self):
    method _l1_regulation (line 676) | def _l1_regulation(self):
    method compute_regulation (line 690) | def compute_regulation(self, time_smoothness_weight, l1_time_planes_we...
  function getProjectionMatrix (line 693) | def getProjectionMatrix(znear, zfar, fovX, fovY):
  class MiniCam (line 709) | class MiniCam:
    method __init__ (line 710) | def __init__(self, c2w, width, height, fovy, fovx, znear, zfar,time=0 ):
  class Renderer (line 738) | class Renderer:
    method __init__ (line 739) | def __init__(self, sh_degree=3, white_background=True, radius=1):
    method initialize (line 753) | def initialize(self, input=None, num_pts=5000, radius=0.5,initial_valu...
    method render (line 785) | def render(

FILE: guidance/zero123_4d_utils.py
  class Zero123 (line 21) | class Zero123(nn.Module):
    method __init__ (line 22) | def __init__(self, device, fp16=True, t_range=[0.02, 0.98]):
    method get_img_embeds (line 62) | def get_img_embeds(self, x, input_imgs=None):
    method refine (line 84) | def refine(self, pred_rgb, polar, azimuth, radius,
    method train_step (line 129) | def train_step(self, pred_rgb, polar, azimuth, radius, step_ratio=None...
    method update_step (line 186) | def update_step(self, epoch: int, global_step: int, on_load_weights: b...
    method get_steps (line 193) | def get_steps(self,percent,epoch, global_step):
    method decode_latents (line 203) | def decode_latents(self, latents):
    method encode_imgs (line 211) | def encode_imgs(self, imgs, mode=False):

FILE: guidance/zero123pp/pipeline.py
  function to_rgb_image (line 37) | def to_rgb_image(maybe_rgba: Image.Image):
  class MyAttnProcessor2_0 (line 49) | class MyAttnProcessor2_0:
    method __init__ (line 54) | def __init__(self):
    method __call__ (line 58) | def __call__(
  class ReferenceOnlyAttnProc (line 152) | class ReferenceOnlyAttnProc(torch.nn.Module):
    method __init__ (line 153) | def __init__(
    method __call__ (line 164) | def __call__(
  class RefOnlyNoisedUNet (line 191) | class RefOnlyNoisedUNet(torch.nn.Module):
    method __init__ (line 192) | def __init__(self, unet: UNet2DConditionModel, train_sched: DDPMSchedu...
    method __getattr__ (line 212) | def __getattr__(self, name: str):
    method forward_cond (line 218) | def forward_cond(self, noisy_cond_lat, timestep, encoder_hidden_states...
    method forward (line 230) | def forward(
  function scale_latents (line 275) | def scale_latents(latents):
  function unscale_latents (line 280) | def unscale_latents(latents):
  function scale_image (line 285) | def scale_image(image):
  function unscale_image (line 290) | def unscale_image(image):
  class DepthControlUNet (line 295) | class DepthControlUNet(torch.nn.Module):
    method __init__ (line 296) | def __init__(self, unet: RefOnlyNoisedUNet, controlnet: Optional[diffu...
    method __getattr__ (line 309) | def __getattr__(self, name: str):
    method forward (line 315) | def forward(self, sample, timestep, encoder_hidden_states, class_label...
  class ModuleListDict (line 336) | class ModuleListDict(torch.nn.Module):
    method __init__ (line 337) | def __init__(self, procs: dict) -> None:
    method __getitem__ (line 342) | def __getitem__(self, key):
  class SuperNet (line 346) | class SuperNet(torch.nn.Module):
    method __init__ (line 347) | def __init__(self, state_dict: Dict[str, torch.Tensor]):
  class Zero123PlusPipeline (line 386) | class Zero123PlusPipeline(diffusers.StableDiffusionPipeline):
    method __init__ (line 405) | def __init__(
    method prepare (line 431) | def prepare(self):
    method add_controlnet (line 436) | def add_controlnet(self, controlnet: Optional[diffusers.ControlNetMode...
    method encode_condition_image (line 441) | def encode_condition_image(self, image: torch.Tensor):
    method __call__ (line 446) | def __call__(

FILE: main.py
  function save_image_to_local (line 19) | def save_image_to_local(image_tensor, file_path):
  class GUI (line 26) | class GUI:
    method __init__ (line 27) | def __init__(self, opt):
    method __del__ (line 110) | def __del__(self):
    method seed_everything (line 114) | def seed_everything(self):
    method prepare_image (line 129) | def prepare_image(self,idx):
    method prepare_train (line 160) | def prepare_train(self):
    method train_step (line 216) | def train_step(self):
    method set_fix_cam (line 414) | def set_fix_cam(self):
    method test_step (line 497) | def test_step(self):
    method load_input (line 567) | def load_input(self, file):
    method save_renderings (line 573) | def save_renderings(self, elev=0, azim=0, radius=2, name='front'):
    method save_model (line 606) | def save_model(self, mode='geo', texture_size=1024):
    method register_dpg (line 622) | def register_dpg(self):
    method render (line 955) | def render(self):
    method train (line 965) | def train(self, iters=500):

FILE: scripts/app.py
  class SAMAPI (line 12) | class SAMAPI:
    method get_instance (line 17) | def get_instance(sam_checkpoint=None):
    method segment_api (line 40) | def segment_api(rgb, mask=None, bbox=None, sam_checkpoint=None):
  function image_examples (line 74) | def image_examples(samples, ncols, return_key=None, example_text="Exampl...
  function segment_img (line 96) | def segment_img(img: Image):
  function segment_6imgs (line 105) | def segment_6imgs(zero123pp_imgs):
  function expand2square (line 128) | def expand2square(pil_img, background_color):
  function check_dependencies (line 143) | def check_dependencies():
  function load_zero123plus_pipeline (line 177) | def load_zero123plus_pipeline():

FILE: scripts/gen_mv.py
  function segment_img (line 17) | def segment_img(img: Image):
  function segment_6imgs (line 26) | def segment_6imgs(zero123pp_imgs):
  function process_img (line 51) | def process_img(path,destination,pipeline, is_first):

FILE: sh_utils.py
  function eval_sh (line 57) | def eval_sh(deg, sh, dirs):
  function RGB2SH (line 114) | def RGB2SH(rgb):
  function SH2RGB (line 117) | def SH2RGB(sh):

FILE: simple-knn/ext.cpp
  function PYBIND11_MODULE (line 15) | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {

FILE: simple-knn/simple_knn.h
  function class (line 15) | class SimpleKNN

FILE: visualize.py
  function save_image_to_local (line 22) | def save_image_to_local(image_tensor, file_path):
  class GUI (line 28) | class GUI:
    method __init__ (line 29) | def __init__(self, opt):
    method __del__ (line 95) | def __del__(self):
    method seed_everything (line 99) | def seed_everything(self):
    method prepare_train (line 120) | def prepare_train(self):
    method train_step (line 164) | def train_step(self):
    method save_renderings (line 206) | def save_renderings(self, elev=0, azim=0, radius=2, name='front', inte...
    method set_fix_cam2 (line 270) | def set_fix_cam2(self):
    method test_step (line 353) | def test_step(self):
    method load_input (line 423) | def load_input(self, file):
    method save_model (line 479) | def save_model(self, mode='geo', texture_size=1024):
    method register_dpg (line 495) | def register_dpg(self):
    method render (line 828) | def render(self):
    method train (line 838) | def train(self, iters=500):

FILE: zero123.py
  class CLIPCameraProjection (line 41) | class CLIPCameraProjection(ModelMixin, ConfigMixin):
    method __init__ (line 53) | def __init__(self, embedding_dim: int = 768, additional_embeddings: in...
    method forward (line 63) | def forward(
  class Zero123Pipeline (line 81) | class Zero123Pipeline(DiffusionPipeline):
    method __init__ (line 109) | def __init__(
    method enable_sequential_cpu_offload (line 180) | def enable_sequential_cpu_offload(self, gpu_id=0):
    method _execution_device (line 204) | def _execution_device(self):
    method _encode_image (line 221) | def _encode_image(
    method run_safety_checker (line 299) | def run_safety_checker(self, image, device, dtype):
    method decode_latents (line 318) | def decode_latents(self, latents):
    method prepare_extra_step_kwargs (line 332) | def prepare_extra_step_kwargs(self, generator, eta):
    method check_inputs (line 353) | def check_inputs(self, image, height, width, callback_steps):
    method prepare_latents (line 371) | def prepare_latents(
    method _get_latent_model_input (line 405) | def _get_latent_model_input(
    method __call__ (line 449) | def __call__(
Condensed preview — 25 files, each showing path, character count, and a content snippet. Download the .json file or copy for the full structured content (275K chars).
[
  {
    "path": ".gitignore",
    "chars": 15,
    "preview": "*.pyc\n./valid/*"
  },
  {
    "path": "README.md",
    "chars": 5135,
    "preview": "<h1>[ECCV 2024] STAG4D: Spatial-Temporal Anchored Generative 4D Gaussians</h1>\n\n<div>\n    <a href='https://github.com/ze"
  },
  {
    "path": "__init__.py",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "cam_utils.py",
    "chars": 4804,
    "preview": "import numpy as np\nfrom scipy.spatial.transform import Rotation as R\n\nimport torch\n\ndef dot(x, y):\n    if isinstance(x, "
  },
  {
    "path": "configs/stag4d.yaml",
    "chars": 2166,
    "preview": "### Input\n# input rgba image path (default to None, can be load in GUI too)\ninput: \n# input text prompt (default to None"
  },
  {
    "path": "dataset_4d.py",
    "chars": 3770,
    "preview": "import os\nimport cv2\nimport glob\nimport json\nimport tqdm\nimport random\nimport numpy as np\nfrom scipy.spatial.transform i"
  },
  {
    "path": "deform.py",
    "chars": 19304,
    "preview": "\nimport functools\nimport math\nimport os\nimport time\nfrom tkinter import W\n\nimport numpy as np\nimport torch\nimport torch."
  },
  {
    "path": "gs_renderer_4d.py",
    "chars": 41064,
    "preview": "import os\nimport math\nimport numpy as np\nfrom typing import NamedTuple\nfrom plyfile import PlyData, PlyElement\n\nimport t"
  },
  {
    "path": "guidance/zero123_4d_utils.py",
    "chars": 10453,
    "preview": "from transformers import CLIPTextModel, CLIPTokenizer, logging\nfrom diffusers import (\n    AutoencoderKL,\n    UNet2DCond"
  },
  {
    "path": "guidance/zero123pp/pipeline.py",
    "chars": 20621,
    "preview": "from typing import Any, Dict, Optional\nfrom diffusers.models import AutoencoderKL, UNet2DConditionModel\nfrom diffusers.s"
  },
  {
    "path": "main.py",
    "chars": 37702,
    "preview": "import os\nimport cv2\nimport time\nimport tqdm\nimport numpy as np\nimport dearpygui.dearpygui as dpg\n\nimport torch\nimport t"
  },
  {
    "path": "mini_trainer.ipynb",
    "chars": 33086,
    "preview": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n "
  },
  {
    "path": "requirements.txt",
    "chars": 378,
    "preview": "tqdm\nrich\nninja\nnumpy\npandas\nscipy\nscikit-learn\nmatplotlib\nopencv-python\nimageio\nimageio-ffmpeg\nomegaconf\nargparse\ntorch"
  },
  {
    "path": "scripts/app.py",
    "chars": 6720,
    "preview": "import os\nimport sys\nimport numpy\nimport torch\nimport rembg\nimport threading\nimport urllib.request\nfrom PIL import Image"
  },
  {
    "path": "scripts/gen_mv.py",
    "chars": 3966,
    "preview": "import torch\nimport requests\nfrom PIL import Image\nfrom diffusers import DiffusionPipeline, EulerAncestralDiscreteSchedu"
  },
  {
    "path": "sh_utils.py",
    "chars": 4371,
    "preview": "#  Copyright 2021 The PlenOctree Authors.\n#  Redistribution and use in source and binary forms, with or without\n#  modif"
  },
  {
    "path": "simple-knn/ext.cpp",
    "chars": 427,
    "preview": "/*\n * Copyright (C) 2023, Inria\n * GRAPHDECO research group, https://team.inria.fr/graphdeco\n * All rights reserved.\n *\n"
  },
  {
    "path": "simple-knn/setup.py",
    "chars": 830,
    "preview": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# Thi"
  },
  {
    "path": "simple-knn/simple_knn/.gitkeep",
    "chars": 0,
    "preview": ""
  },
  {
    "path": "simple-knn/simple_knn.cu",
    "chars": 6352,
    "preview": "/*\n * Copyright (C) 2023, Inria\n * GRAPHDECO research group, https://team.inria.fr/graphdeco\n * All rights reserved.\n *\n"
  },
  {
    "path": "simple-knn/simple_knn.h",
    "chars": 451,
    "preview": "/*\n * Copyright (C) 2023, Inria\n * GRAPHDECO research group, https://team.inria.fr/graphdeco\n * All rights reserved.\n *\n"
  },
  {
    "path": "simple-knn/spatial.cu",
    "chars": 671,
    "preview": "/*\n * Copyright (C) 2023, Inria\n * GRAPHDECO research group, https://team.inria.fr/graphdeco\n * All rights reserved.\n *\n"
  },
  {
    "path": "simple-knn/spatial.h",
    "chars": 380,
    "preview": "/*\n * Copyright (C) 2023, Inria\n * GRAPHDECO research group, https://team.inria.fr/graphdeco\n * All rights reserved.\n *\n"
  },
  {
    "path": "visualize.py",
    "chars": 29874,
    "preview": "import os\nimport cv2\nimport time\nimport tqdm\nimport numpy as np\nimport dearpygui.dearpygui as dpg\n\nimport torch\nimport t"
  },
  {
    "path": "zero123.py",
    "chars": 30625,
    "preview": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"Lic"
  }
]

About this extraction

This page contains the full source code of the zeng-yifei/STAG4D GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 25 files (257.0 KB), approximately 64.7k tokens, and a symbol index with 233 extracted functions, classes, methods, constants, and types. Use this with OpenClaw, Claude, ChatGPT, Cursor, Windsurf, or any other AI tool that accepts text input. You can copy the full output to your clipboard or download it as a .txt file.

Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.

Copied to clipboard!