[
  {
    "path": ".gitignore",
    "content": "*.pyc\n/valid/"
  },
  {
    "path": "README.md",
    "content": "<h1>[ECCV 2024] STAG4D: Spatial-Temporal Anchored Generative 4D Gaussians</h1>\n\n<div>\n    <a href='https://github.com/zeng-yifei?tab=repositories/' target='_blank'>Yifei Zeng</a><sup>1</sup>&emsp;\n    <a href=\"https://github.com/yanqinJiang\" target='_blank'>Yanqin Jiang*</a><sup>2</sup>&emsp;\n    <a href=\"https://sites.google.com/site/zhusiyucs/home/\" target='_blank'>Siyu Zhu</a><sup>3</sup>&emsp;\n    <a href='https://github.com/YuanxunLu' target='_blank'>Yuanxun Lu</a><sup>1</sup>&emsp;\n    <a href=\"https://linyou.github.io/\">Youtian Lin</a><sup>1</sup>&emsp;\n    <a href='https://zhuhao-nju.github.io/home/' target='_blank'>Hao Zhu</a><sup>1</sup>&emsp;\n    <a href=\"https://people.ucas.ac.cn/~huweiming\">Weiming Hu</a><sup>2</sup>&emsp;\n    <a href='https://cite.nju.edu.cn/People/Faculty/20190621/i5054.html' target='_blank'>Xun Cao</a><sup>1</sup>&emsp;\n    <a href='https://yoyo000.github.io/' target='_blank'>Yao Yao</a><sup>1+</sup>&emsp;\n</div>\n<div>\n    <sup>1</sup>Nanjing University\n    <sup>2</sup>CASIA\n    <sup>3</sup>Fudan University\n</div>\n<div>\n    <sup>*</sup>equal contribution\n    <sup>+</sup>corresponding author\n</div>\n\n<h4 align=\"center\">\n  <a href=\"https://nju-3dv.github.io/projects/STAG4D/\" target='_blank'>[Project Page]</a> •\n</h4>\n\n# Update\n7.4: Our paper has been accepted to ECCV 2024. Congrats!\n\n6.20: IMPORTANT: Fix a bug caused by the new version of diff_gauss. The newest version of diff_gauss returns `color, depth, norm, alpha, radii, extra`, whereas previous versions returned `color, depth, alpha, radii`. Running an older version of this code with the new diff_gauss causes a mismatch error and may feed the normal map into the alpha loss, producing bad results.\n\n5.26: Upload the Text/Image-to-4D data (see below).\n\n5.21: Fix the RGB loss by moving it into the batch loop. 
Add visualization code.\n\n\n# ⚙️ Installation\n\n```bash\npip install -r requirements.txt\n\ngit clone --recursive https://github.com/slothfulxtx/diff-gaussian-rasterization.git\npip install ./diff-gaussian-rasterization\n\npip install ./simple-knn\n```\n\n# Video-to-4D\nTo generate the examples on the project page, download the dataset from [Google Drive](https://drive.google.com/file/d/1YDvhBv6z5SByF_WaTQVzzL9qz3TyEm6a/view?usp=sharing), place the data in the `dataset` folder, and run:\n```bash\npython main.py --config configs/stag4d.yaml path=dataset/minions save_path=minions\n\n# add gui=True to turn on the visualizer (recommended)\npython main.py --config configs/stag4d.yaml path=dataset/minions save_path=minions gui=True\n```\n\nTo generate spatio-temporally consistent data from scratch, place your RGBA frames in the following layout:\n\n```\ndataset\n└── your_data\n    ├── 0_rgba.png\n    ├── 1_rgba.png\n    ├── 2_rgba.png\n    └── ...\n```\n\nand then run:\n```bash\npython scripts/gen_mv.py --path dataset/your_data --pipeline_path xxx/guidance/zero123pp\n\npython main.py --config configs/stag4d.yaml path=data_path save_path=saving_path gui=True\n```\n\nTo visualize the result, replace `main.py` with `visualize.py`; the result will be saved to the `valid/xxx` path, e.g.:\n```bash\npython visualize.py --config configs/stag4d.yaml path=dataset/minions save_path=minions\n```\n\n<img src='assets/videoto4d.gif' height='60%'>\n\n# Text-to-4D\nFor text-to-4D generation, we recommend using SDXL and SVD to generate a plausible video. After matting the video, use the commands above to generate a good 4D result. 
(This pipeline consists of many independent parts and is fairly complex, so we may release the whole workflow once it has been integrated.)\n\nTo generate the examples in the paper, download the corresponding data from [Google Drive](https://drive.google.com/file/d/1EDNL7EBMR1vlfMOABdXjHzcKY7IXdcnj/view?usp=sharing). Remember to set `size` to 26 in the config, or pass `size=26` on the command line:\n\n```bash\npython main.py --config configs/stag4d.yaml path=dataset/xxx save_path=xxx size=26\n```\n\n<img src='assets/textto4d3.gif' height='60%'>\n\n# Tips for better quality\nIf you are willing to trade training time for quality, here are some tips to further improve the results:\n\n1. Use a larger batch size (`batch_size` in the config).\n\n2. Train for more steps.\n\n# Citation\nIf you find our work useful for your research, please consider citing our paper as well as Consistent4D:\n```\n@article{zeng2024stag4d,\n      title={STAG4D: Spatial-Temporal Anchored Generative 4D Gaussians}, \n      author={Yifei Zeng and Yanqin Jiang and Siyu Zhu and Yuanxun Lu and Youtian Lin and Hao Zhu and Weiming Hu and Xun Cao and Yao Yao},\n      year={2024},\n      eprint={2403.14939},\n      archivePrefix={arXiv},\n      primaryClass={cs.CV}\n}\n\n@article{jiang2023consistent4d,\n      title={Consistent4D: Consistent 360{\\deg} Dynamic Object Generation from Monocular Video}, \n      author={Yanqin Jiang and Li Zhang and Jin Gao and Weiming Hu and Yao Yao},\n      year={2023},\n      eprint={2311.02848},\n      archivePrefix={arXiv},\n      primaryClass={cs.CV}\n}\n```\n\n# Acknowledgment\nThis repo is built on [DreamGaussian](https://github.com/dreamgaussian/dreamgaussian) and [Zero123plus](https://github.com/SUDO-AI-3D/zero123plus). We thank all the authors for their great work."
  },
  {
    "path": "__init__.py",
    "content": ""
  },
  {
    "path": "cam_utils.py",
    "content": "import numpy as np\nfrom scipy.spatial.transform import Rotation as R\n\nimport torch\n\ndef dot(x, y):\n    if isinstance(x, np.ndarray):\n        return np.sum(x * y, -1, keepdims=True)\n    else:\n        return torch.sum(x * y, -1, keepdim=True)\n\n\ndef length(x, eps=1e-20):\n    if isinstance(x, np.ndarray):\n        return np.sqrt(np.maximum(np.sum(x * x, axis=-1, keepdims=True), eps))\n    else:\n        return torch.sqrt(torch.clamp(dot(x, x), min=eps))\n\n\ndef safe_normalize(x, eps=1e-20):\n    return x / length(x, eps)\n\n\ndef look_at(campos, target, opengl=True):\n    # campos: [N, 3], camera/eye position\n    # target: [N, 3], object to look at\n    # return: [N, 3, 3], rotation matrix\n    if not opengl:\n        # camera forward aligns with -z\n        forward_vector = safe_normalize(target - campos)\n        up_vector = np.array([0, 1, 0], dtype=np.float32)\n        right_vector = safe_normalize(np.cross(forward_vector, up_vector))\n        up_vector = safe_normalize(np.cross(right_vector, forward_vector))\n    else:\n        # camera forward aligns with +z\n        forward_vector = safe_normalize(campos - target)\n        up_vector = np.array([0, 1, 0], dtype=np.float32)\n        right_vector = safe_normalize(np.cross(up_vector, forward_vector))\n        up_vector = safe_normalize(np.cross(forward_vector, right_vector))\n    R = np.stack([right_vector, up_vector, forward_vector], axis=1)\n    return R\n\n\n# elevation & azimuth to pose (cam2world) matrix\ndef orbit_camera(elevation, azimuth, radius=1, is_degree=True, target=None, opengl=True):\n    # radius: scalar\n    # elevation: scalar, in (-90, 90), from +y to -y is (-90, 90)\n    # azimuth: scalar, in (-180, 180), from +z to +x is (0, 90)\n    # return: [4, 4], camera pose matrix\n    if is_degree:\n        elevation = np.deg2rad(np.array(elevation))\n        azimuth = np.deg2rad(np.array(azimuth))\n    x = radius * np.cos(elevation) * np.sin(azimuth)\n    y = - radius * 
np.sin(elevation)\n    z = radius * np.cos(elevation) * np.cos(azimuth)\n    if target is None:\n        target = np.zeros([3], dtype=np.float32)\n    campos = np.array([x, y, z]) + target  # [3]\n    T = np.eye(4, dtype=np.float32)\n    T[:3, :3] = look_at(campos, target, opengl)\n    T[:3, 3] = campos\n    return T\n\n\nclass OrbitCamera:\n    def __init__(self, W, H, r=2, fovy=60, near=0.01, far=100):\n        self.W = W\n        self.H = H\n        self.radius = r  # camera distance from center\n        self.fovy = np.deg2rad(fovy)  # deg 2 rad\n        self.near = near\n        self.far = far\n        self.center = np.array([0, 0, 0], dtype=np.float32)  # look at this point\n        self.rot = R.from_matrix(np.eye(3))\n        self.up = np.array([0, 1, 0], dtype=np.float32)  # need to be normalized!\n\n    @property\n    def fovx(self):\n        return 2 * np.arctan(np.tan(self.fovy / 2) * self.W / self.H)\n\n    @property\n    def campos(self):\n        return self.pose[:3, 3]\n\n    # pose (c2w)\n    @property\n    def pose(self):\n        # first move camera to radius\n        res = np.eye(4, dtype=np.float32)\n        res[2, 3] = self.radius  # opengl convention...\n        # rotate\n        rot = np.eye(4, dtype=np.float32)\n        rot[:3, :3] = self.rot.as_matrix()\n        res = rot @ res\n        # translate\n        res[:3, 3] -= self.center\n        return res\n\n    # view (w2c)\n    @property\n    def view(self):\n        return np.linalg.inv(self.pose)\n\n    # projection (perspective)\n    @property\n    def perspective(self):\n        y = np.tan(self.fovy / 2)\n        aspect = self.W / self.H\n        return np.array(\n            [\n                [1 / (y * aspect), 0, 0, 0],\n                [0, -1 / y, 0, 0],\n                [\n                    0,\n                    0,\n                    -(self.far + self.near) / (self.far - self.near),\n                    -(2 * self.far * self.near) / (self.far - self.near),\n                ],\n 
               [0, 0, -1, 0],\n            ],\n            dtype=np.float32,\n        )\n\n    # intrinsics\n    @property\n    def intrinsics(self):\n        focal = self.H / (2 * np.tan(self.fovy / 2))\n        return np.array([focal, focal, self.W // 2, self.H // 2], dtype=np.float32)\n\n    @property\n    def mvp(self):\n        return self.perspective @ np.linalg.inv(self.pose)  # [4, 4]\n\n    def orbit(self, dx, dy):\n        # rotate along camera up/side axis!\n        side = self.rot.as_matrix()[:3, 0]\n        rotvec_x = self.up * np.radians(-0.05 * dx)\n        rotvec_y = side * np.radians(-0.05 * dy)\n        self.rot = R.from_rotvec(rotvec_x) * R.from_rotvec(rotvec_y) * self.rot\n\n    def scale(self, delta):\n        self.radius *= 1.1 ** (-delta)\n\n    def pan(self, dx, dy, dz=0):\n        # pan in camera coordinate system (careful on the sensitivity!)\n        self.center += 0.0005 * self.rot.as_matrix()[:3, :3] @ np.array([-dx, -dy, dz])"
  },
  {
    "path": "configs/stag4d.yaml",
    "content": "### Input\n# input rgba image path (defaults to None, can also be loaded in the GUI)\ninput: \n# input text prompt (defaults to None, can also be entered in the GUI)\nprompt: a minion\n# input mesh for stage 2 (auto-searched from the stage 1 output path if None)\nmesh:\n# estimated elevation angle of the input image\nelevation: 0\n# reference image resolution\nref_size: 512\n# density threshold for mesh extraction\ndensity_thresh: 1\ndevice: cuda\n\n# dynamic: number of frames and input data path\nsize: 30\npath: dataset/minions\n\n# checkpoint to load for stage 1 (should be a ply file)\nload: \n\n### Output\noutdir: logs\nmesh_format: obj\nsave_path: ???\nsave_step: 8000\n# checkpoint to load for the fine stage (path to a ply file with its deformation pth)\nload_path: \nload_step: \nvalid_interval: 500\n\n### Training\n# guidance loss weights (0 to disable)\nlambda_sd: 0\nmvdream: False\nlambda_zero123: 1\nlambda_tv: 1\nscale_loss_ratio: 7.5\nimagedream: False\n\n# training batch size per iteration\nbatch_size: 4\n# training iterations for stage 1\niters: 2000\n# training iterations for stage 2\niters_refine: 50\n# training camera radius\nradius: 2\n# training camera fovy\nfovy: 49.1 # align with zero123 rendering setting (ref: https://github.com/cvlab-columbia/zero123/blob/main/objaverse-rendering/scripts/blender_script.py#L61)\n\n# whether to allow geometry training in stage 2\ntrain_geo: False\n# probability of inverting the background color during training (0 = always black, 1 = always white)\ninvert_bg_prob: 0.5\n\n\n### GUI\ngui: False\nforce_cuda_rast: False\n# GUI resolution\nH: 800\nW: 800\n\n### Deformation network\ndeformation_lr_init : 0.00016\ndeformation_lr_final : 0.000016\ndeformation_lr_delay_mult : 0.02\ngrid_lr_init : 0.0016\ngrid_lr_final : 0.00016\n### Gaussian splatting\nnum_pts: 10000\nsh_degree: 0\nposition_lr_init : 0.0002\nposition_lr_final : 0.000002\nposition_lr_delay_mult: 
0.01\nposition_lr_max_steps: 2000\nposition_lr_max_steps2: 5000\n\nfeature_lr: 0.005\nopacity_lr: 0.02\nscaling_lr: 0.01\nrotation_lr: 0.002\ninit_steps: 700\n\npercent_dense: 
0.1\ndensity_start_iter: 1200\ndensity_end_iter: 6000\ndensification_interval: 100\nopacity_reset_interval: 700\ndensify_grad_threshold_percent: 0.025\n\ntime_smoothness_weight: 5\nplane_tv_weight: 0.05\nl1_time_planes: 0.05\n\n\n### Textured Mesh\ngeom_lr: 0.0001\ntexture_lr: 0.2"
  },
  {
    "path": "dataset_4d.py",
    "content": "import os\nimport cv2\nimport glob\nimport json\nimport tqdm\nimport random\nimport numpy as np\nfrom scipy.spatial.transform import Slerp, Rotation\n\n\nimport torch\nimport torch.nn.functional as F\nfrom torch.utils.data import DataLoader\nimport rembg\n\n\nclass SparseDataset:\n    def __init__(self, opt, size, device='cuda', type='train', H=256, W=256):\n        super().__init__()\n\n        self.opt = opt\n        self.device = device\n        self.type = type  # train, val, test\n        self.size = size  # number of frames\n        self.H = H\n        self.W = W\n        self.path = opt.path\n\n        # principal point: cx follows the width, cy follows the height\n        self.cx = self.W / 2\n        self.cy = self.H / 2\n        self.bg_remover = None\n\n    def collate_ref(self, index):\n        # load the reference (input) view of frame `index`\n        file = os.path.join(self.path, 'ref/{}_rgba.png'.format(str(index)))\n        #print(f'[INFO] load image from {file}...')\n\n        img = cv2.imread(file, cv2.IMREAD_UNCHANGED)\n        if img.shape[-1] == 3:\n            # no alpha channel: estimate a foreground mask with rembg\n            if self.bg_remover is None:\n                self.bg_remover = rembg.new_session()\n            img = rembg.remove(img, session=self.bg_remover)\n\n        img = cv2.resize(img, (self.W, self.H), interpolation=cv2.INTER_AREA)\n        img = img.astype(np.float32) / 255.0\n\n        self.input_mask = img[..., 3:]\n        # white bg\n        self.input_img = img[..., :3] * self.input_mask + (1 - self.input_mask)\n        # bgr to rgb\n        self.input_img = self.input_img[..., ::-1].copy()\n\n        return self.input_img, self.input_mask\n\n    def collate_zero123(self, index):\n        # load the generated multi-view images of frame `index`\n        self.pattern = os.path.join(self.path, 'zero123/{}_rgba/*.png'.format(str(index)))\n        self.input_imgs = []\n        self.input_masks = []\n        file_list = glob.glob(self.pattern)\n        for files in sorted(file_list):\n            #print(f'[INFO] load image from 
{files}...')\n            img = cv2.imread(files, cv2.IMREAD_UNCHANGED)\n            if img.shape[-1] == 3:\n                if self.bg_remover is None:\n                    self.bg_remover = rembg.new_session()\n                img = rembg.remove(img, session=self.bg_remover)\n\n            img = cv2.resize(img, (self.W, self.H), interpolation=cv2.INTER_AREA)\n            img = img.astype(np.float32) / 255.0\n\n            self.input_mask = img[..., 3:]\n            # white bg\n            self.input_img = img[..., :3] * self.input_mask + (1 - self.input_mask)\n            # bgr to rgb\n            self.input_img = self.input_img[..., ::-1].copy()\n\n            self.input_imgs.append(self.input_img)\n            self.input_masks.append(self.input_mask)\n        return self.input_imgs, self.input_masks\n\n    def collate(self, index):\n        # the `index` passed by the DataLoader is ignored: every call loads all frames\n        ref_view_batch, input_mask_batch, zero123_view_batch, zero123_masks_batch = [], [], [], []\n        for frame_idx in np.arange(self.size):\n            ref_view, input_mask = self.collate_ref(frame_idx)\n            zero123_view, zero123_masks = self.collate_zero123(frame_idx)\n            ref_view_batch.append(ref_view)\n            input_mask_batch.append(input_mask)\n            zero123_view_batch.append(zero123_view)\n            zero123_masks_batch.append(zero123_masks)\n        return ref_view_batch, input_mask_batch, zero123_view_batch, zero123_masks_batch\n\n    def dataloader(self):\n        loader = DataLoader(list(range(self.size)), batch_size=1, collate_fn=self.collate, shuffle=False, num_workers=0)\n        return loader\n\n    def dataloader_d(self):\n        loader = DataLoader(list(range(self.size)), batch_size=1, collate_fn=self.collate_d, shuffle=False, num_workers=0)\n        return loader"
  },
  {
    "path": "deform.py",
    "content": "\nimport functools\nimport math\nimport os\nimport time\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.utils.cpp_extension import load\nimport torch.nn.init as init\nimport abc\nimport itertools\nimport logging as log\nfrom typing import Optional, Union, List, Dict, Sequence, Iterable, Collection, Callable\n\n\nclass Deformation(nn.Module):\n    def __init__(self, D=8, W=256, input_ch=27, input_ch_time=9, skips=[], args=None):\n        super(Deformation, self).__init__()\n        self.D = D\n        self.W = W\n        self.input_ch = input_ch\n        self.input_ch_time = input_ch_time\n        self.skips = skips\n        # switches that disable parts of the deformation\n        self.no_grid = False  # True: feed raw (x, y, z, t) instead of HexPlane features\n        self.no_ds = False  # True: do not deform scales\n        self.no_dr = False  # True: do not deform rotations\n        self.no_do = True  # True: do not deform opacity\n        self.bounds = 1.6\n        # 4D (x, y, z, t) field factorized into 2D planes; the last resolution entry is time\n        self.kplanes_config = {\n                             'grid_dimensions': 2,\n                             'input_coordinate_dim': 4,\n                             'output_coordinate_dim': 32,\n                             'resolution': [64, 64, 64, 25]\n                            }\n        self.multires = [1, 2, 4, 8]\n        self.grid = HexPlaneField(self.bounds, self.kplanes_config, self.multires)\n        self.pos_deform, self.scales_deform, self.rotations_deform, self.opacity_deform = self.create_net()\n\n    def create_net(self):\n        \n        mlp_out_dim = 0\n        if self.no_grid:\n            self.feature_out = [nn.Linear(4,self.W)]\n        else:\n            self.feature_out = [nn.Linear(mlp_out_dim + self.grid.feat_dim ,self.W)]\n        \n        for i in range(self.D-1):\n            self.feature_out.append(nn.ReLU())\n            
self.feature_out.append(nn.Linear(self.W,self.W))\n        self.feature_out = nn.Sequential(*self.feature_out)\n        output_dim = self.W\n        return 
 \\\n            nn.Sequential(nn.ReLU(),nn.Linear(self.W,self.W),nn.ReLU(),nn.Linear(self.W, 3)),\\\n            nn.Sequential(nn.ReLU(),nn.Linear(self.W,self.W),nn.ReLU(),nn.Linear(self.W, 3)),\\\n            nn.Sequential(nn.ReLU(),nn.Linear(self.W,self.W),nn.ReLU(),nn.Linear(self.W, 4)), \\\n            nn.Sequential(nn.ReLU(),nn.Linear(self.W,self.W),nn.ReLU(),nn.Linear(self.W, 1))\n    \n    def query_time(self, rays_pts_emb, scales_emb, rotations_emb, time_emb):\n\n        if self.no_grid:\n            h = torch.cat([rays_pts_emb[:,:3],time_emb[:,:1]],-1)\n        else:\n            grid_feature = self.grid(rays_pts_emb[:,:3], time_emb[:,:1])\n\n            h = grid_feature\n        \n        h = self.feature_out(h)\n  \n        return h\n\n    def forward(self, rays_pts_emb, scales_emb=None, rotations_emb=None, opacity = None, time_emb=None):\n        if time_emb is None:\n            return self.forward_static(rays_pts_emb[:,:3])\n        else:\n            return self.forward_dynamic(rays_pts_emb, scales_emb, rotations_emb, opacity, time_emb)\n\n    def forward_static(self, rays_pts_emb):\n        grid_feature = self.grid(rays_pts_emb[:,:3])\n        dx = self.static_mlp(grid_feature)\n        return rays_pts_emb[:, :3] + dx\n    def forward_dynamic(self,rays_pts_emb, scales_emb, rotations_emb, opacity_emb, time_emb):\n        hidden = self.query_time(rays_pts_emb, scales_emb, rotations_emb, time_emb).float()\n        dx = self.pos_deform(hidden)\n        pts = rays_pts_emb[:, :3] + dx\n        if self.no_ds:\n            scales = scales_emb[:,:3]\n        else:\n            ds = self.scales_deform(hidden)\n            scales = scales_emb[:,:3] + ds\n        if self.no_dr:\n            rotations = rotations_emb[:,:4]\n        else:\n            dr = self.rotations_deform(hidden)\n            rotations = rotations_emb[:,:4] + dr\n        if self.no_do:\n            opacity = opacity_emb[:,:1] \n        else:\n            do = self.opacity_deform(hidden) \n 
           opacity = opacity_emb[:,:1] + do\n        # + do\n        # print(\"deformation value:\",\"pts:\",torch.abs(dx).mean(),\"rotation:\",torch.abs(dr).mean())\n\n        return pts, scales, rotations, opacity\n    def get_mlp_parameters(self):\n        parameter_list = []\n        for name, param in self.named_parameters():\n            if  \"grid\" not in name:\n                parameter_list.append(param)\n        return parameter_list\n    def get_grid_parameters(self):\n        return list(self.grid.parameters() ) \n    # + list(self.timegrid.parameters())\nclass deform_network(nn.Module):\n    def __init__(self) :\n        super(deform_network, self).__init__()\n        net_width = 64\n        timebase_pe = 4\n        defor_depth= 1\n        posbase_pe= 10\n        scale_rotation_pe = 2\n        opacity_pe = 2\n        timenet_width = 64\n        timenet_output = 32\n        times_ch = 2*timebase_pe+1\n        self.timenet = nn.Sequential(\n        nn.Linear(times_ch, timenet_width), nn.ReLU(),\n        nn.Linear(timenet_width, timenet_output))\n        self.deformation_net = Deformation(W=net_width, D=defor_depth, input_ch=(4+3)+((4+3)*scale_rotation_pe)*2, input_ch_time=timenet_output, args=None)\n        self.register_buffer('time_poc', torch.FloatTensor([(2**i) for i in range(timebase_pe)]))\n        self.register_buffer('pos_poc', torch.FloatTensor([(2**i) for i in range(posbase_pe)]))\n        self.register_buffer('rotation_scaling_poc', torch.FloatTensor([(2**i) for i in range(scale_rotation_pe)]))\n        self.register_buffer('opacity_poc', torch.FloatTensor([(2**i) for i in range(opacity_pe)]))\n        self.apply(initialize_weights)\n        # print(self)\n\n    def forward(self, point, scales=None, rotations=None, opacity=None, times_sel=None):\n        if times_sel is not None:\n            return self.forward_dynamic(point, scales, rotations, opacity, times_sel)\n        else:\n            return self.forward_static(point)\n\n        \n    
def forward_static(self, points):\n        points = self.deformation_net(points)\n        return points\n    def forward_dynamic(self, point, scales=None, rotations=None, opacity=None, times_sel=None):\n        # times_emb = poc_fre(times_sel, self.time_poc)\n\n        means3D, scales, rotations, opacity = self.deformation_net( point,\n                                                  scales,\n                                                rotations,\n                                                opacity,\n                                                # times_feature,\n                                                times_sel)\n        return means3D, scales, rotations, opacity\n    def get_mlp_parameters(self):\n        return self.deformation_net.get_mlp_parameters() + list(self.timenet.parameters())\n    def get_grid_parameters(self):\n        return self.deformation_net.get_grid_parameters()\n\ndef initialize_weights(m):\n    if isinstance(m, nn.Linear):\n        # init.constant_(m.weight, 0)\n        init.xavier_uniform_(m.weight,gain=1)\n        if m.bias is not None:\n            init.xavier_uniform_(m.weight,gain=1)\n            # init.constant_(m.bias, 0)\n\n\n\ndef get_normalized_directions(directions):\n    \"\"\"SH encoding must be in the range [0, 1]\n\n    Args:\n        directions: batch of directions\n    \"\"\"\n    return (directions + 1.0) / 2.0\n\n\ndef normalize_aabb(pts, aabb):\n    return (pts - aabb[0]) * (2.0 / (aabb[1] - aabb[0])) - 1.0\ndef grid_sample_wrapper(grid: torch.Tensor, coords: torch.Tensor, align_corners: bool = True) -> torch.Tensor:\n    grid_dim = coords.shape[-1]\n\n    if grid.dim() == grid_dim + 1:\n        # no batch dimension present, need to add it\n        grid = grid.unsqueeze(0)\n    if coords.dim() == 2:\n        coords = coords.unsqueeze(0)\n\n    if grid_dim == 2 or grid_dim == 3:\n        grid_sampler = F.grid_sample\n    else:\n        raise NotImplementedError(f\"Grid-sample was called with {grid_dim}D 
data but is only \"\n                                  f\"implemented for 2 and 3D data.\")\n\n    coords = coords.view([coords.shape[0]] + [1] * (grid_dim - 1) + list(coords.shape[1:]))\n    B, feature_dim = grid.shape[:2]\n    n = coords.shape[-2]\n    interp = grid_sampler(\n        grid,  # [B, feature_dim, reso, ...]\n        coords,  # [B, 1, ..., n, grid_dim]\n        align_corners=align_corners,\n        mode='bilinear', padding_mode='border')\n    interp = interp.view(B, feature_dim, n).transpose(-1, -2)  # [B, n, feature_dim]\n    interp = interp.squeeze()  # [B?, n, feature_dim?]\n    return interp\n\ndef init_grid_param(\n        grid_nd: int,\n        in_dim: int,\n        out_dim: int,\n        reso: Sequence[int],\n        a: float = 0.1,\n        b: float = 0.5):\n    assert in_dim == len(reso), \"Resolution must have same number of elements as input-dimension\"\n    has_time_planes = in_dim == 4\n    assert grid_nd <= in_dim\n    coo_combs = list(itertools.combinations(range(in_dim), grid_nd))\n    grid_coefs = nn.ParameterList()\n    for ci, coo_comb in enumerate(coo_combs):\n        new_grid_coef = nn.Parameter(torch.empty(\n            [1, out_dim] + [reso[cc] for cc in coo_comb[::-1]]\n        ))\n        if has_time_planes and 3 in coo_comb:  # Initialize time planes to 1\n            nn.init.ones_(new_grid_coef)\n        else:\n            nn.init.uniform_(new_grid_coef, a=a, b=b)\n        grid_coefs.append(new_grid_coef)\n\n    return grid_coefs\n\n\ndef interpolate_ms_features(pts: torch.Tensor,\n                            ms_grids: Collection[Iterable[nn.Module]],\n                            grid_dimensions: int,\n                            concat_features: bool,\n                            num_levels: Optional[int],\n                            ) -> torch.Tensor:\n    coo_combs = list(itertools.combinations(\n        range(pts.shape[-1]), grid_dimensions)\n    )\n    if num_levels is None:\n        num_levels = len(ms_grids)\n    
multi_scale_interp = [] if concat_features else 0.\n    grid: nn.ParameterList\n    for scale_id,  grid in enumerate(ms_grids[:num_levels]):\n        interp_space = 1.\n        for ci, coo_comb in enumerate(coo_combs):\n            # interpolate in plane\n            feature_dim = grid[ci].shape[1]  # shape of grid[ci]: 1, out_dim, *reso\n            interp_out_plane = (\n                grid_sample_wrapper(grid[ci], pts[..., coo_comb])\n                .view(-1, feature_dim)\n            )\n            # compute product over planes\n            interp_space = interp_space * interp_out_plane\n\n        # combine over scales\n        if concat_features:\n            multi_scale_interp.append(interp_space)\n        else:\n            multi_scale_interp = multi_scale_interp + interp_space\n\n    if concat_features:\n        multi_scale_interp = torch.cat(multi_scale_interp, dim=-1)\n    return multi_scale_interp\n\n\nclass HexPlaneField(nn.Module):\n    def __init__(\n        self,\n        \n        bounds,\n        planeconfig,\n        multires\n    ) -> None:\n        super().__init__()\n        aabb = torch.tensor([[bounds,bounds,bounds],\n                             [-bounds,-bounds,-bounds]])\n        self.aabb = nn.Parameter(aabb, requires_grad=False)\n        self.grid_config =  [planeconfig]\n        self.multiscale_res_multipliers = multires\n        self.concat_features = True\n\n        # 1. 
Init planes\n        self.grids = nn.ModuleList()\n        self.feat_dim = 0\n        for res in self.multiscale_res_multipliers:\n            # initialize coordinate grid\n            config = self.grid_config[0].copy()\n            # Resolution fix: multi-res only on spatial planes\n            config[\"resolution\"] = [\n                r * res for r in config[\"resolution\"][:3]\n            ] + config[\"resolution\"][3:]\n            gp = init_grid_param(\n                grid_nd=config[\"grid_dimensions\"],\n                in_dim=config[\"input_coordinate_dim\"],\n                out_dim=config[\"output_coordinate_dim\"],\n                reso=config[\"resolution\"],\n            )\n            # shape[1] is out-dim - Concatenate over feature len for each scale\n            if self.concat_features:\n                self.feat_dim += gp[-1].shape[1]\n            else:\n                self.feat_dim = gp[-1].shape[1]\n            self.grids.append(gp)\n        # print(f\"Initialized model grids: {self.grids}\")\n        print(\"feature_dim:\",self.feat_dim)\n\n\n    def set_aabb(self,xyz_max, xyz_min):\n        aabb = torch.tensor([\n            xyz_max,\n            xyz_min\n        ])\n        self.aabb = nn.Parameter(aabb,requires_grad=True)\n        print(\"Voxel Plane: set aabb=\",self.aabb)\n\n    def get_density(self, pts: torch.Tensor, timestamps: Optional[torch.Tensor] = None):\n        \"\"\"Computes and returns the densities.\"\"\"\n\n        pts = normalize_aabb(pts, self.aabb)\n        pts = torch.cat((pts, timestamps), dim=-1)  # [n_rays, n_samples, 4]\n\n        pts = pts.reshape(-1, pts.shape[-1])\n        features = interpolate_ms_features(\n            pts, ms_grids=self.grids,  # noqa\n            grid_dimensions=self.grid_config[0][\"grid_dimensions\"],\n            concat_features=self.concat_features, num_levels=None)\n        if len(features) < 1:\n            features = torch.zeros((0, 1)).to(features.device)\n\n\n        return 
features\n\n    def forward(self,\n                pts: torch.Tensor,\n                timestamps: Optional[torch.Tensor] = None):\n\n        features = self.get_density(pts, timestamps)\n\n        return features\n    \ndef compute_plane_tv(t):\n    batch_size, c, h, w = t.shape\n    count_h = batch_size * c * (h - 1) * w\n    count_w = batch_size * c * h * (w - 1)\n    h_tv = torch.square(t[..., 1:, :] - t[..., :h-1, :]).sum()\n    w_tv = torch.square(t[..., :, 1:] - t[..., :, :w-1]).sum()\n    return 2 * (h_tv / count_h + w_tv / count_w)  # This is summing over batch and c instead of avg\n\n\ndef compute_plane_smoothness(t):\n    batch_size, c, h, w = t.shape\n    # Convolve with a second derivative filter, in the time dimension which is dimension 2\n    first_difference = t[..., 1:, :] - t[..., :h-1, :]  # [batch, c, h-1, w]\n    second_difference = first_difference[..., 1:, :] - first_difference[..., :h-2, :]  # [batch, c, h-2, w]\n    # Take the L2 norm of the result\n    return torch.square(torch.abs(second_difference)).mean()\n\n\nclass Regularizer():\n    def __init__(self, reg_type, initialization):\n        self.reg_type = reg_type\n        self.initialization = initialization\n        self.weight = float(self.initialization)\n        self.last_reg = None\n\n    def step(self, global_step):\n        pass\n\n    def report(self, d):\n        if self.last_reg is not None:\n            d[self.reg_type].update(self.last_reg.item())\n\n    def regularize(self, *args, **kwargs) -> torch.Tensor:\n        out = self._regularize(*args, **kwargs) * self.weight\n        self.last_reg = out.detach()\n        return out\n\n    @abc.abstractmethod\n    def _regularize(self, *args, **kwargs) -> torch.Tensor:\n        raise NotImplementedError()\n\n    def __str__(self):\n        return f\"Regularizer({self.reg_type}, weight={self.weight})\"\n\n\nclass PlaneTV(Regularizer):\n    def __init__(self, initial_value, what: str = 'field'):\n        if what not in {'field', 
'proposal_network'}:\n            raise ValueError(f'what must be one of \"field\" or \"proposal_network\" '\n                             f'but {what} was passed.')\n        name = f'planeTV-{what[:2]}'\n        super().__init__(name, initial_value)\n        self.what = what\n\n    def step(self, global_step):\n        pass\n\n    def _regularize(self, model, **kwargs):\n        multi_res_grids: Sequence[nn.ParameterList]\n        if self.what == 'field':\n            multi_res_grids = model.field.grids\n        elif self.what == 'proposal_network':\n            multi_res_grids = [p.grids for p in model.proposal_networks]\n        else:\n            raise NotImplementedError(self.what)\n        total = 0\n        # Note: input to compute_plane_tv should be of shape [batch_size, c, h, w]\n        for grids in multi_res_grids:\n            if len(grids) == 3:\n                spatial_grids = [0, 1, 2]\n            else:\n                spatial_grids = [0, 1, 3]  # These are the spatial grids; the others are spatiotemporal\n            for grid_id in spatial_grids:\n                total += compute_plane_tv(grids[grid_id])\n            for grid in grids:\n                # grid: [1, c, h, w]\n                total += compute_plane_tv(grid)\n        return total\n\n\nclass TimeSmoothness(Regularizer):\n    def __init__(self, initial_value, what: str = 'field'):\n        if what not in {'field', 'proposal_network'}:\n            raise ValueError(f'what must be one of \"field\" or \"proposal_network\" '\n                             f'but {what} was passed.')\n        name = f'time-smooth-{what[:2]}'\n        super().__init__(name, initial_value)\n        self.what = what\n\n    def _regularize(self, model, **kwargs) -> torch.Tensor:\n        multi_res_grids: Sequence[nn.ParameterList]\n        if self.what == 'field':\n            multi_res_grids = model.field.grids\n        elif self.what == 'proposal_network':\n            multi_res_grids = [p.grids for p in 
model.proposal_networks]\n        else:\n            raise NotImplementedError(self.what)\n        total = 0\n        # model.grids is 6 x [1, rank * F_dim, reso, reso]\n        for grids in multi_res_grids:\n            if len(grids) == 3:\n                time_grids = []\n            else:\n                time_grids = [2, 4, 5]\n            for grid_id in time_grids:\n                total += compute_plane_smoothness(grids[grid_id])\n        return torch.as_tensor(total)\n\n\n\nclass L1ProposalNetwork(Regularizer):\n    def __init__(self, initial_value):\n        super().__init__('l1-proposal-network', initial_value)\n\n    def _regularize(self, model, **kwargs) -> torch.Tensor:\n        grids = [p.grids for p in model.proposal_networks]\n        total = 0.0\n        for pn_grids in grids:\n            for grid in pn_grids:\n                total += torch.abs(grid).mean()\n        return torch.as_tensor(total)\n\n\nclass DepthTV(Regularizer):\n    def __init__(self, initial_value):\n        super().__init__('tv-depth', initial_value)\n\n    def _regularize(self, model, model_out, **kwargs) -> torch.Tensor:\n        depth = model_out['depth']\n        tv = compute_plane_tv(\n            depth.reshape(64, 64)[None, None, :, :]\n        )\n        return tv\n\n\nclass L1TimePlanes(Regularizer):\n    def __init__(self, initial_value, what='field'):\n        if what not in {'field', 'proposal_network'}:\n            raise ValueError(f'what must be one of \"field\" or \"proposal_network\" '\n                             f'but {what} was passed.')\n        super().__init__(f'l1-time-{what[:2]}', initial_value)\n        self.what = what\n\n    def _regularize(self, model, **kwargs) -> torch.Tensor:\n        # model.grids is 6 x [1, rank * F_dim, reso, reso]\n        multi_res_grids: Sequence[nn.ParameterList]\n        if self.what == 'field':\n            multi_res_grids = model.field.grids\n        elif self.what == 'proposal_network':\n            multi_res_grids = 
[p.grids for p in model.proposal_networks]\n        else:\n            raise NotImplementedError(self.what)\n\n        total = 0.0\n        for grids in multi_res_grids:\n            if len(grids) == 3:\n                continue\n            else:\n                # These are the spatiotemporal grids\n                spatiotemporal_grids = [2, 4, 5]\n            for grid_id in spatiotemporal_grids:\n                total += torch.abs(1 - grids[grid_id]).mean()\n        return torch.as_tensor(total)\n\n"
  },
  {
    "path": "gs_renderer_4d.py",
    "content": "import os\nimport math\nimport numpy as np\nfrom typing import NamedTuple\nfrom plyfile import PlyData, PlyElement\n\nimport torch\nfrom torch import nn\n\nfrom diff_gauss import (\n    GaussianRasterizationSettings,\n    GaussianRasterizer,\n)\nfrom simple_knn._C import distCUDA2\n\nfrom sh_utils import eval_sh, SH2RGB, RGB2SH\n\nfrom deform import *\ndef inverse_sigmoid(x):\n    return torch.log(x/(1-x))\n\n\n\n\ndef get_expon_lr_func(\n    lr_init, lr_final, lr_delay_steps=0, lr_delay_mult=1.0, max_steps=1000000\n):\n    \n    def helper(step):\n        if lr_init == lr_final:\n            # constant lr, ignore other params\n            return lr_init\n        if step < 0 or (lr_init == 0.0 and lr_final == 0.0):\n            # Disable this parameter\n            return 0.0\n        if lr_delay_steps > 0:\n            # A kind of reverse cosine decay.\n            delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(\n                0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1)\n            )\n        else:\n            delay_rate = 1.0\n        t = np.clip(step / max_steps, 0, 1)\n        log_lerp = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)\n        return delay_rate * log_lerp\n\n    return helper\n\n\ndef strip_lowerdiag(L):\n    uncertainty = torch.zeros((L.shape[0], 6), dtype=torch.float, device=\"cuda\")\n\n    uncertainty[:, 0] = L[:, 0, 0]\n    uncertainty[:, 1] = L[:, 0, 1]\n    uncertainty[:, 2] = L[:, 0, 2]\n    uncertainty[:, 3] = L[:, 1, 1]\n    uncertainty[:, 4] = L[:, 1, 2]\n    uncertainty[:, 5] = L[:, 2, 2]\n    return uncertainty\n\ndef strip_symmetric(sym):\n    return strip_lowerdiag(sym)\n\ndef gaussian_3d_coeff(xyzs, covs):\n    # xyzs: [N, 3]\n    # covs: [N, 6]\n    x, y, z = xyzs[:, 0], xyzs[:, 1], xyzs[:, 2]\n    a, b, c, d, e, f = covs[:, 0], covs[:, 1], covs[:, 2], covs[:, 3], covs[:, 4], covs[:, 5]\n\n    # eps must be small enough !!!\n    inv_det = 1 / (a * d * f + 2 * e * c * b - e**2 * a - 
c**2 * d - b**2 * f + 1e-24)\n    inv_a = (d * f - e**2) * inv_det\n    inv_b = (e * c - b * f) * inv_det\n    inv_c = (e * b - c * d) * inv_det\n    inv_d = (a * f - c**2) * inv_det\n    inv_e = (b * c - e * a) * inv_det\n    inv_f = (a * d - b**2) * inv_det\n\n    power = -0.5 * (x**2 * inv_a + y**2 * inv_d + z**2 * inv_f) - x * y * inv_b - x * z * inv_c - y * z * inv_e\n\n    power[power > 0] = -1e10 # abnormal values... make weights 0\n        \n    return torch.exp(power)\n\ndef build_rotation(r):\n    norm = torch.sqrt(r[:,0]*r[:,0] + r[:,1]*r[:,1] + r[:,2]*r[:,2] + r[:,3]*r[:,3])\n\n    q = r / norm[:, None]\n\n    R = torch.zeros((q.size(0), 3, 3), device='cuda')\n\n    r = q[:, 0]\n    x = q[:, 1]\n    y = q[:, 2]\n    z = q[:, 3]\n\n    R[:, 0, 0] = 1 - 2 * (y*y + z*z)\n    R[:, 0, 1] = 2 * (x*y - r*z)\n    R[:, 0, 2] = 2 * (x*z + r*y)\n    R[:, 1, 0] = 2 * (x*y + r*z)\n    R[:, 1, 1] = 1 - 2 * (x*x + z*z)\n    R[:, 1, 2] = 2 * (y*z - r*x)\n    R[:, 2, 0] = 2 * (x*z - r*y)\n    R[:, 2, 1] = 2 * (y*z + r*x)\n    R[:, 2, 2] = 1 - 2 * (x*x + y*y)\n    return R\n\ndef build_scaling_rotation(s, r):\n    L = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device=\"cuda\")\n    R = build_rotation(r)\n\n    L[:,0,0] = s[:,0]\n    L[:,1,1] = s[:,1]\n    L[:,2,2] = s[:,2]\n\n    L = R @ L\n    return L\n\nclass BasicPointCloud(NamedTuple):\n    points: np.array\n    colors: np.array\n    normals: np.array\n\n\nclass GaussianModel:\n\n    def setup_functions(self):\n        def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):\n            L = build_scaling_rotation(scaling_modifier * scaling, rotation)\n            actual_covariance = L @ L.transpose(1, 2)\n            symm = strip_symmetric(actual_covariance)\n            return symm\n        \n        self.scaling_activation = torch.exp\n        self.scaling_inverse_activation = torch.log\n\n        self.covariance_activation = build_covariance_from_scaling_rotation\n\n        
self.opacity_activation = torch.sigmoid\n        self.inverse_opacity_activation = inverse_sigmoid\n\n        self.rotation_activation = torch.nn.functional.normalize\n        \n    def initialize(self, initial_values, raw=False):\n        # NOTE: actual initialization is done in trainer\n        # raw stands for raw values, i.e. not passed through activation\n        # Move to CUDA before wrapping: Parameter.to() returns a plain (non-leaf) tensor, not a Parameter\n        self._xyz = nn.Parameter(initial_values[\"mean\"].to('cuda').requires_grad_(True))\n        self._rotation = nn.Parameter(initial_values[\"qvec\"].to('cuda').requires_grad_(True))\n\n        #self._scaling = nn.Parameter(initial_values[\"svec\"].requires_grad_(True)).to('cuda')\n        #self._features_dc = nn.Parameter(initial_values[\"color\"].requires_grad_(True)).to('cuda')\n        self._opacity = nn.Parameter(initial_values[\"alpha\"].to('cuda').requires_grad_(True))\n\n    \n    def __init__(self, sh_degree : int,args = None):\n        self.active_sh_degree = 0\n        self.max_sh_degree = sh_degree  \n        self._xyz = torch.empty(0)\n        self._features_dc = torch.empty(0)\n        self._features_rest = torch.empty(0)\n        self._scaling = torch.empty(0)\n        self._rotation = torch.empty(0)\n        self._opacity = torch.empty(0)\n        self.max_radii2D = torch.empty(0)\n        self.xyz_gradient_accum = torch.empty(0)\n        self.denom = torch.empty(0)\n        self.optimizer = None\n        self.percent_dense = 0\n        self.spatial_lr_scale = 0\n        self._deformation_table = torch.empty(0)\n        self._deformation = deform_network()\n        self.setup_functions()\n\n    def capture(self):\n        return (\n            self.active_sh_degree,\n            self._xyz,\n            self._deformation.state_dict(),\n            self._deformation_table,\n            self._features_dc,\n            self._features_rest,\n            self._scaling,\n            self._rotation,\n            self._opacity,\n            self.max_radii2D,\n            self.xyz_gradient_accum,\n 
           self.denom,\n            self.optimizer.state_dict(),\n            self.spatial_lr_scale,\n        )\n    \n    def restore(self, model_args, training_args):\n        # Unpack in the same order as capture(); the deformation network is stored there as a state_dict\n        (self.active_sh_degree, \n        self._xyz, \n        deformation_state_dict,\n        self._deformation_table,\n        self._features_dc, \n        self._features_rest,\n        self._scaling, \n        self._rotation, \n        self._opacity,\n        self.max_radii2D, \n        xyz_gradient_accum, \n        denom,\n        opt_dict, \n        self.spatial_lr_scale) = model_args\n        self._deformation.load_state_dict(deformation_state_dict)\n        self._deformation = self._deformation.to(\"cuda\")\n        self.training_setup(training_args)\n        self.xyz_gradient_accum = xyz_gradient_accum\n        self.denom = denom\n        self.optimizer.load_state_dict(opt_dict)\n\n    @property\n    def get_scaling(self):\n        return self.scaling_activation(self._scaling)\n    \n    @property\n    def get_rotation(self):\n        return self.rotation_activation(self._rotation)\n    \n    @property\n    def get_xyz(self):\n        return self._xyz\n    \n    @property\n    def get_features(self):\n        features_dc = self._features_dc\n        features_rest = self._features_rest\n        return torch.cat((features_dc, features_rest), dim=1)\n    \n    @property\n    def get_opacity(self):\n        return self.opacity_activation(self._opacity)\n\n\n    \n\n    \n    def get_covariance(self, scaling_modifier = 1):\n        return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation)\n\n    def oneupSHdegree(self):\n        if self.active_sh_degree < self.max_sh_degree:\n            self.active_sh_degree += 1\n\n    def create_from_pcd(self, pcd : BasicPointCloud, spatial_lr_scale : float = 1):\n        self.spatial_lr_scale = spatial_lr_scale\n        fused_point_cloud = torch.tensor(np.asarray(pcd.points)).float().cuda()\n        fused_color = RGB2SH(torch.tensor(np.asarray(pcd.colors)).float().cuda())\n        features = torch.zeros((fused_color.shape[0], 3, (self.max_sh_degree + 1) ** 
2)).float().cuda()\n        features[:, :3, 0 ] = fused_color\n        features[:, 3:, 1:] = 0.0\n\n        print(\"Number of points at initialisation : \", fused_point_cloud.shape[0])\n\n        dist2 = torch.clamp_min(distCUDA2(torch.from_numpy(np.asarray(pcd.points)).float().cuda()), 0.0000001)\n        scales = torch.log(torch.sqrt(dist2))[...,None].repeat(1, 3)\n        rots = torch.zeros((fused_point_cloud.shape[0], 4), device=\"cuda\")\n        rots[:, 0] = 1\n\n        opacities = inverse_sigmoid(0.1 * torch.ones((fused_point_cloud.shape[0], 1), dtype=torch.float, device=\"cuda\"))\n\n        self._xyz = nn.Parameter(fused_point_cloud.requires_grad_(True))\n        self._features_dc = nn.Parameter(features[:,:,0:1].transpose(1, 2).contiguous().requires_grad_(True))\n        self._features_rest = nn.Parameter(features[:,:,1:].transpose(1, 2).contiguous().requires_grad_(True))\n        self._scaling = nn.Parameter(scales.requires_grad_(True))\n        self._rotation = nn.Parameter(rots.requires_grad_(True))\n        self._opacity = nn.Parameter(opacities.requires_grad_(True))\n        self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device=\"cuda\")\n        self._deformation_table = torch.gt(torch.ones((self.get_xyz.shape[0]),device=\"cuda\"),0)\n        #print(self._xyz.shape,self._rotation.shape)\n        self._deformation = self._deformation.to(\"cuda\") \n        self.active_sh_degree = self.max_sh_degree\n        self._deformation_accum = torch.zeros((self.get_xyz.shape[0],3),device=\"cuda\")\n        self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device=\"cuda\")\n\n    def training_setup(self, training_args):\n        self.percent_dense = training_args.percent_dense\n        self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device=\"cuda\")\n        self.denom = torch.zeros((self.get_xyz.shape[0], 1), device=\"cuda\")\n        self._deformation_accum = torch.zeros((self.get_xyz.shape[0],3),device=\"cuda\")\n        \n\n 
       l = [\n            {'params': [self._xyz], 'lr': training_args.position_lr_init * self.spatial_lr_scale, \"name\": \"xyz\"},\n            {'params': list(self._deformation.get_mlp_parameters()), 'lr': training_args.deformation_lr_init * self.spatial_lr_scale, \"name\": \"deformation\"},\n            {'params': list(self._deformation.get_grid_parameters()), 'lr': training_args.grid_lr_init * self.spatial_lr_scale, \"name\": \"grid\"},\n            {'params': [self._features_dc], 'lr': training_args.feature_lr, \"name\": \"f_dc\"},\n            {'params': [self._features_rest], 'lr': training_args.feature_lr / 20.0, \"name\": \"f_rest\"},\n            {'params': [self._opacity], 'lr': training_args.opacity_lr, \"name\": \"opacity\"},\n            {'params': [self._scaling], 'lr': training_args.scaling_lr, \"name\": \"scaling\"},\n            {'params': [self._rotation], 'lr': training_args.rotation_lr, \"name\": \"rotation\"}\n            \n        ]\n\n        self.optimizer = torch.optim.Adam(l, lr=0.0, eps=1e-15)\n        self.xyz_scheduler_args = get_expon_lr_func(lr_init=training_args.position_lr_init*self.spatial_lr_scale,\n                                                    lr_final=training_args.position_lr_final*self.spatial_lr_scale,\n                                                    lr_delay_mult=training_args.position_lr_delay_mult,\n                                                    max_steps=training_args.position_lr_max_steps)\n        self.deformation_scheduler_args = get_expon_lr_func(lr_init=training_args.deformation_lr_init*self.spatial_lr_scale,\n                                                    lr_final=training_args.deformation_lr_final*self.spatial_lr_scale,\n                                                    lr_delay_mult=training_args.deformation_lr_delay_mult,\n                                                    max_steps=training_args.position_lr_max_steps)    \n        self.grid_scheduler_args = 
get_expon_lr_func(lr_init=training_args.grid_lr_init*self.spatial_lr_scale,\n                                                    lr_final=training_args.grid_lr_final*self.spatial_lr_scale,\n                                                    lr_delay_mult=training_args.deformation_lr_delay_mult,\n                                                    max_steps=training_args.position_lr_max_steps)    \n    def update_learning_rate(self, iteration):\n        ''' Learning rate scheduling per step '''\n        for param_group in self.optimizer.param_groups:\n            if param_group[\"name\"] == \"xyz\":\n                lr = self.xyz_scheduler_args(iteration)\n                param_group['lr'] = lr\n                # return lr\n            if  \"grid\" in param_group[\"name\"]:\n                lr = self.grid_scheduler_args(iteration)\n                param_group['lr'] = lr\n                # return lr\n            elif param_group[\"name\"] == \"deformation\":\n                lr = self.deformation_scheduler_args(iteration)\n                param_group['lr'] = lr\n                # return lr\n\n    def construct_list_of_attributes(self):\n        l = ['x', 'y', 'z', 'nx', 'ny', 'nz']\n        # All channels except the 3 DC\n        for i in range(self._features_dc.shape[1]*self._features_dc.shape[2]):\n            l.append('f_dc_{}'.format(i))\n        for i in range(self._features_rest.shape[1]*self._features_rest.shape[2]):\n            l.append('f_rest_{}'.format(i))\n        l.append('opacity')\n        for i in range(self._scaling.shape[1]):\n            l.append('scale_{}'.format(i))\n        for i in range(self._rotation.shape[1]):\n            l.append('rot_{}'.format(i))\n        return l\n\n    def save_ply(self, path):\n        os.makedirs(os.path.dirname(path), exist_ok=True)\n\n        xyz = self._xyz.detach().cpu().numpy()\n        normals = np.zeros_like(xyz)\n        f_dc = self._features_dc.detach().transpose(1, 
2).flatten(start_dim=1).contiguous().cpu().numpy()\n        f_rest = self._features_rest.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()\n        opacities = self._opacity.detach().cpu().numpy()\n        scale = self._scaling.detach().cpu().numpy()\n        rotation = self._rotation.detach().cpu().numpy()\n        \n        dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()]\n\n        elements = np.empty(xyz.shape[0], dtype=dtype_full)\n        attributes = np.concatenate((xyz, normals, f_dc, f_rest, opacities, scale, rotation), axis=1)\n        elements[:] = list(map(tuple, attributes))\n        el = PlyElement.describe(elements, 'vertex')\n        PlyData([el]).write(path)\n\n    def compute_deformation(self,time):\n        \n        deform = self._deformation[:,:,:time].sum(dim=-1)\n        xyz = self._xyz + deform\n        return xyz\n    def load_model(self, path):\n        print(\"loading model from {}\".format(path))\n        weight_dict = torch.load(os.path.join(path,\"deformation.pth\"),map_location=\"cuda\")\n        self._deformation.load_state_dict(weight_dict)\n        self._deformation = self._deformation.to(\"cuda\")\n        self._deformation_table = torch.gt(torch.ones((self.get_xyz.shape[0]),device=\"cuda\"),0)\n        self._deformation_accum = torch.zeros((self.get_xyz.shape[0],3),device=\"cuda\")\n        if os.path.exists(os.path.join(path, \"deformation_table.pth\")):\n            self._deformation_table = torch.load(os.path.join(path, \"deformation_table.pth\"),map_location=\"cuda\")\n        if os.path.exists(os.path.join(path, \"deformation_accum.pth\")):\n            self._deformation_accum = torch.load(os.path.join(path, \"deformation_accum.pth\"),map_location=\"cuda\")\n        self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device=\"cuda\")\n        self._deformation_table = torch.gt(torch.ones((self.get_xyz.shape[0]),device=\"cuda\"),0)\n        
self._deformation = self._deformation.to(\"cuda\") \n        \n\n    def save_deformation(self, path):\n        torch.save(self._deformation.state_dict(),os.path.join(path, \"deformation.pth\"))\n        torch.save(self._deformation_table,os.path.join(path, \"deformation_table.pth\"))\n        torch.save(self._deformation_accum,os.path.join(path, \"deformation_accum.pth\"))\n        \n    def load_ply(self, path):\n        plydata = PlyData.read(path)\n\n        xyz = np.stack((np.asarray(plydata.elements[0][\"x\"]),\n                        np.asarray(plydata.elements[0][\"y\"]),\n                        np.asarray(plydata.elements[0][\"z\"])),  axis=1)\n        opacities = np.asarray(plydata.elements[0][\"opacity\"])[..., np.newaxis]\n\n        print(\"Number of points at loading : \", xyz.shape[0])\n\n        features_dc = np.zeros((xyz.shape[0], 3, 1))\n        features_dc[:, 0, 0] = np.asarray(plydata.elements[0][\"f_dc_0\"])\n        features_dc[:, 1, 0] = np.asarray(plydata.elements[0][\"f_dc_1\"])\n        features_dc[:, 2, 0] = np.asarray(plydata.elements[0][\"f_dc_2\"])\n\n        extra_f_names = [p.name for p in plydata.elements[0].properties if p.name.startswith(\"f_rest_\")]\n        assert len(extra_f_names)==3*(self.max_sh_degree + 1) ** 2 - 3\n        features_extra = np.zeros((xyz.shape[0], len(extra_f_names)))\n        for idx, attr_name in enumerate(extra_f_names):\n            features_extra[:, idx] = np.asarray(plydata.elements[0][attr_name])\n        # Reshape (P,F*SH_coeffs) to (P, F, SH_coeffs except DC)\n        features_extra = features_extra.reshape((features_extra.shape[0], 3, (self.max_sh_degree + 1) ** 2 - 1))\n\n        scale_names = [p.name for p in plydata.elements[0].properties if p.name.startswith(\"scale_\")]\n        scales = np.zeros((xyz.shape[0], len(scale_names)))\n        for idx, attr_name in enumerate(scale_names):\n            scales[:, idx] = np.asarray(plydata.elements[0][attr_name])\n\n        rot_names = [p.name for 
p in plydata.elements[0].properties if p.name.startswith(\"rot\")]\n        rots = np.zeros((xyz.shape[0], len(rot_names)))\n        for idx, attr_name in enumerate(rot_names):\n            rots[:, idx] = np.asarray(plydata.elements[0][attr_name])\n\n        self._xyz = nn.Parameter(torch.tensor(xyz, dtype=torch.float, device=\"cuda\").requires_grad_(True))\n        self._features_dc = nn.Parameter(torch.tensor(features_dc, dtype=torch.float, device=\"cuda\").transpose(1, 2).contiguous().requires_grad_(True))\n        self._features_rest = nn.Parameter(torch.tensor(features_extra, dtype=torch.float, device=\"cuda\").transpose(1, 2).contiguous().requires_grad_(True))\n        self._opacity = nn.Parameter(torch.tensor(opacities, dtype=torch.float, device=\"cuda\").requires_grad_(True))\n        self._scaling = nn.Parameter(torch.tensor(scales, dtype=torch.float, device=\"cuda\").requires_grad_(True))\n        self._rotation = nn.Parameter(torch.tensor(rots, dtype=torch.float, device=\"cuda\").requires_grad_(True))\n        self._deformation_table = torch.gt(torch.ones((self.get_xyz.shape[0]),device=\"cuda\"),0)\n        #print(self._xyz.shape,self._rotation.shape)\n        self._deformation = self._deformation.to(\"cuda\") \n        self.active_sh_degree = self.max_sh_degree\n        self._deformation_accum = torch.zeros((self.get_xyz.shape[0],3),device=\"cuda\")\n        self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device=\"cuda\")\n\n        \n    def replace_tensor_to_optimizer(self, tensor, name):\n        optimizable_tensors = {}\n        for group in self.optimizer.param_groups:\n            if group[\"name\"] == name:\n                stored_state = self.optimizer.state.get(group['params'][0], None)\n                stored_state[\"exp_avg\"] = torch.zeros_like(tensor)\n                stored_state[\"exp_avg_sq\"] = torch.zeros_like(tensor)\n\n                del self.optimizer.state[group['params'][0]]\n                group[\"params\"][0] = 
nn.Parameter(tensor.requires_grad_(True))\n                self.optimizer.state[group['params'][0]] = stored_state\n\n                optimizable_tensors[group[\"name\"]] = group[\"params\"][0]\n        return optimizable_tensors\n\n    def _prune_optimizer(self, mask):\n        optimizable_tensors = {}\n        for group in self.optimizer.param_groups:\n            if len(group[\"params\"]) > 1:\n                continue\n            stored_state = self.optimizer.state.get(group['params'][0], None)\n            if stored_state is not None:\n                stored_state[\"exp_avg\"] = stored_state[\"exp_avg\"][mask]\n                stored_state[\"exp_avg_sq\"] = stored_state[\"exp_avg_sq\"][mask]\n\n                del self.optimizer.state[group['params'][0]]\n                group[\"params\"][0] = nn.Parameter((group[\"params\"][0][mask].requires_grad_(True)))\n                self.optimizer.state[group['params'][0]] = stored_state\n\n                optimizable_tensors[group[\"name\"]] = group[\"params\"][0]\n            else:\n                group[\"params\"][0] = nn.Parameter(group[\"params\"][0][mask].requires_grad_(True))\n                optimizable_tensors[group[\"name\"]] = group[\"params\"][0]\n        return optimizable_tensors\n\n    def prune_points(self, mask):\n        valid_points_mask = ~mask\n        optimizable_tensors = self._prune_optimizer(valid_points_mask)\n\n        self._xyz = optimizable_tensors[\"xyz\"]\n        self._features_dc = optimizable_tensors[\"f_dc\"]\n        self._features_rest = optimizable_tensors[\"f_rest\"]\n        self._opacity = optimizable_tensors[\"opacity\"]\n        self._scaling = optimizable_tensors[\"scaling\"]\n        self._rotation = optimizable_tensors[\"rotation\"]\n        self._deformation_accum = self._deformation_accum[valid_points_mask]\n        self.xyz_gradient_accum = self.xyz_gradient_accum[valid_points_mask]\n        self._deformation_table = self._deformation_table[valid_points_mask]\n        
self.denom = self.denom[valid_points_mask]\n        self.max_radii2D = self.max_radii2D[valid_points_mask]\n\n    def cat_tensors_to_optimizer(self, tensors_dict):\n        optimizable_tensors = {}\n        for group in self.optimizer.param_groups:\n            if len(group[\"params\"])>1:continue\n            assert len(group[\"params\"]) == 1\n            extension_tensor = tensors_dict[group[\"name\"]]\n            stored_state = self.optimizer.state.get(group['params'][0], None)\n            if stored_state is not None:\n\n                stored_state[\"exp_avg\"] = torch.cat((stored_state[\"exp_avg\"], torch.zeros_like(extension_tensor)), dim=0)\n                stored_state[\"exp_avg_sq\"] = torch.cat((stored_state[\"exp_avg_sq\"], torch.zeros_like(extension_tensor)), dim=0)\n\n                del self.optimizer.state[group['params'][0]]\n                group[\"params\"][0] = nn.Parameter(torch.cat((group[\"params\"][0], extension_tensor), dim=0).requires_grad_(True))\n                self.optimizer.state[group['params'][0]] = stored_state\n\n                optimizable_tensors[group[\"name\"]] = group[\"params\"][0]\n            else:\n                group[\"params\"][0] = nn.Parameter(torch.cat((group[\"params\"][0], extension_tensor), dim=0).requires_grad_(True))\n                optimizable_tensors[group[\"name\"]] = group[\"params\"][0]\n\n        return optimizable_tensors\n\n    def densification_postfix(self, new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation, new_deformation_table):\n        d = {\"xyz\": new_xyz,\n        \"f_dc\": new_features_dc,\n        \"f_rest\": new_features_rest,\n        \"opacity\": new_opacities,\n        \"scaling\" : new_scaling,\n        \"rotation\" : new_rotation,\n        # \"deformation\": new_deformation\n       }\n\n        optimizable_tensors = self.cat_tensors_to_optimizer(d)\n        self._xyz = optimizable_tensors[\"xyz\"]\n        self._features_dc = 
optimizable_tensors[\"f_dc\"]\n        self._features_rest = optimizable_tensors[\"f_rest\"]\n        self._opacity = optimizable_tensors[\"opacity\"]\n        self._scaling = optimizable_tensors[\"scaling\"]\n        self._rotation = optimizable_tensors[\"rotation\"]\n        # self._deformation = optimizable_tensors[\"deformation\"]\n        \n        self._deformation_table = torch.cat([self._deformation_table,new_deformation_table],-1)\n        self.xyz_gradient_accum = torch.zeros((self.get_xyz.shape[0], 1), device=\"cuda\")\n        self._deformation_accum = torch.zeros((self.get_xyz.shape[0], 3), device=\"cuda\")\n        self.denom = torch.zeros((self.get_xyz.shape[0], 1), device=\"cuda\")\n        self.max_radii2D = torch.zeros((self.get_xyz.shape[0]), device=\"cuda\")\n\n    def densify_and_split(self, grads, grad_threshold, scene_extent, N=2):\n        n_init_points = self.get_xyz.shape[0]\n        # Extract points that satisfy the gradient condition\n        padded_grad = torch.zeros((n_init_points), device=\"cuda\")\n        padded_grad[:grads.shape[0]] = grads.squeeze()\n        selected_pts_mask = torch.where(padded_grad >= grad_threshold, True, False)\n        selected_pts_mask = torch.logical_and(selected_pts_mask,\n                                              torch.max(self.get_scaling, dim=1).values > self.percent_dense*scene_extent)\n        if not selected_pts_mask.any():\n            return\n        stds = self.get_scaling[selected_pts_mask].repeat(N,1)\n        means =torch.zeros((stds.size(0), 3),device=\"cuda\")\n        samples = torch.normal(mean=means, std=stds)\n        rots = build_rotation(self._rotation[selected_pts_mask]).repeat(N,1,1)\n        new_xyz = torch.bmm(rots, samples.unsqueeze(-1)).squeeze(-1) + self.get_xyz[selected_pts_mask].repeat(N, 1)\n        new_scaling = self.scaling_inverse_activation(self.get_scaling[selected_pts_mask].repeat(N,1) / (0.8*N))\n        new_rotation = 
self._rotation[selected_pts_mask].repeat(N,1)\n        new_features_dc = self._features_dc[selected_pts_mask].repeat(N,1,1)\n        new_features_rest = self._features_rest[selected_pts_mask].repeat(N,1,1)\n        new_opacity = self._opacity[selected_pts_mask].repeat(N,1)\n        new_deformation_table = self._deformation_table[selected_pts_mask].repeat(N)\n        self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacity, new_scaling, new_rotation, new_deformation_table)\n\n        prune_filter = torch.cat((selected_pts_mask, torch.zeros(N * selected_pts_mask.sum(), device=\"cuda\", dtype=bool)))\n        self.prune_points(prune_filter)\n        \n    def densify_and_clone(self, grads, grad_threshold, scene_extent):\n        selected_pts_mask = torch.where(torch.norm(grads, dim=-1) >= grad_threshold, True, False)\n        selected_pts_mask = torch.logical_and(selected_pts_mask,\n                                              torch.max(self.get_scaling, dim=1).values <= self.percent_dense*scene_extent)\n        \n        new_xyz = self._xyz[selected_pts_mask] \n        # - 0.001 * self._xyz.grad[selected_pts_mask]\n        new_features_dc = self._features_dc[selected_pts_mask]\n        new_features_rest = self._features_rest[selected_pts_mask]\n        new_opacities = self._opacity[selected_pts_mask]\n        new_scaling = self._scaling[selected_pts_mask]\n        new_rotation = self._rotation[selected_pts_mask]\n        new_deformation_table = self._deformation_table[selected_pts_mask]\n\n        self.densification_postfix(new_xyz, new_features_dc, new_features_rest, new_opacities, new_scaling, new_rotation, new_deformation_table)\n\n    def densify_and_prune(self, max_grad_percent, min_opacity, extent, max_screen_size):\n        grads = self.xyz_gradient_accum / self.denom\n        grads[grads.isnan()] = 0.0\n        grad_log = torch.log(grads)\n        grad_log2=grad_log[~grad_log.isnan()]\n        
grad_log3=grad_log[~grad_log2.isinf()]\n\n        max_grad_1 = torch.exp(grad_log3.mean()+grad_log3.var()) #adaptive densification with mean and var, unused\n        max_grad_2 = torch.exp(grad_log3.squeeze(dim=1).sort(descending=True)[0][int(max_grad_percent*grad_log3.shape[0])]) #adaptive densification with relative grad\n        max_grad = max_grad_2 #choose which to use\n\n        #print('max_grad',max_grad_percent,max_grad_1,max_grad_2,grad_log3.mean(),grad_log3.var())\n        self.densify_and_clone(grads, max_grad, extent)\n        self.densify_and_split(grads, max_grad, extent)\n\n        prune_mask = (self.get_opacity < min_opacity).squeeze()\n\n        if max_screen_size:\n            big_points_vs = self.max_radii2D > max_screen_size\n            big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent\n            small_ws = self.get_scaling.max(dim=1).values<0.001\n            prune_mask = torch.logical_or(torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws),small_ws)\n        self.prune_points(prune_mask)\n\n\n        torch.cuda.empty_cache()\n\n\n    \n    def prune(self, min_opacity=0.01, extent=1, max_screen_size=1):\n\n        prune_mask = (self.get_opacity < min_opacity).squeeze()\n        # prune_mask_2 = torch.logical_and(self.get_opacity <= inverse_sigmoid(0.101 , dtype=torch.float, device=\"cuda\"), self.get_opacity >= inverse_sigmoid(0.999 , dtype=torch.float, device=\"cuda\"))\n        # prune_mask = torch.logical_or(prune_mask, prune_mask_2)\n        # deformation_sum = abs(self._deformation).sum(dim=-1).mean(dim=-1) \n        # deformation_mask = (deformation_sum < torch.quantile(deformation_sum, torch.tensor([0.5]).to(\"cuda\")))\n        # prune_mask = prune_mask & deformation_mask\n        if max_screen_size:\n            big_points_vs = self.max_radii2D > max_screen_size\n            big_points_ws = self.get_scaling.max(dim=1).values > 0.1 * extent\n            #prune_mask = 
torch.logical_or(prune_mask, big_points_vs)\n            small_ws = self.get_scaling.min(dim=1).values<0.001\n            prune_mask = torch.logical_or(torch.logical_or(torch.logical_or(prune_mask, big_points_vs), big_points_ws),small_ws)\n        self.prune_points(prune_mask)\n        \n\n\n\n    def standard_constaint(self):\n        \n        means3D = self._xyz.detach()\n        scales = self._scaling.detach()\n        rotations = self._rotation.detach()\n        opacity = self._opacity.detach()\n        time =  torch.tensor(0).to(\"cuda\").repeat(means3D.shape[0],1)\n        means3D_deform, scales_deform, rotations_deform, _ = self._deformation(means3D, scales, rotations, opacity, time)\n        position_error = (means3D_deform - means3D)**2\n        rotation_error = (rotations_deform - rotations)**2 \n        scaling_error = (scales_deform - scales)**2\n        return position_error.mean() + rotation_error.mean() + scaling_error.mean()\n\n\n    def add_densification_stats(self, viewspace_point_tensor, update_filter):\n        #print(viewspace_point_tensor,update_filter)\n        self.xyz_gradient_accum[update_filter] += torch.norm(viewspace_point_tensor[update_filter,:2], dim=-1, keepdim=True)\n        self.denom[update_filter] += 1\n        \n    @torch.no_grad()\n    def update_deformation_table(self,threshold):\n        # print(\"origin deformation point nums:\",self._deformation_table.sum())\n        self._deformation_table = torch.gt(self._deformation_accum.max(dim=-1).values/100,threshold)\n    def print_deformation_weight_grad(self):\n        for name, weight in self._deformation.named_parameters():\n            if weight.requires_grad:\n                if weight.grad is None:\n                    \n                    print(name,\" :\",weight.grad)\n                else:\n                    if weight.grad.mean() != 0:\n                        print(name,\" :\",weight.grad.mean(), weight.grad.min(), weight.grad.max())\n        print(\"-\"*50)\n    def 
_plane_regulation(self):\n        multi_res_grids = self._deformation.deformation_net.grid.grids\n        total = 0\n        # model.grids is 6 x [1, rank * F_dim, reso, reso]\n        for grids in multi_res_grids:\n            if len(grids) == 3:\n                time_grids = []\n            else:\n                time_grids =  [0,1,3]\n            for grid_id in time_grids:\n                total += compute_plane_smoothness(grids[grid_id])\n        return total\n    def _time_regulation(self):\n        multi_res_grids = self._deformation.deformation_net.grid.grids\n        total = 0\n        # model.grids is 6 x [1, rank * F_dim, reso, reso]\n        for grids in multi_res_grids:\n            if len(grids) == 3:\n                time_grids = []\n            else:\n                time_grids =[2, 4, 5]\n            for grid_id in time_grids:\n                total += compute_plane_smoothness(grids[grid_id])\n        return total\n    def _l1_regulation(self):\n                # model.grids is 6 x [1, rank * F_dim, reso, reso]\n        multi_res_grids = self._deformation.deformation_net.grid.grids\n\n        total = 0.0\n        for grids in multi_res_grids:\n            if len(grids) == 3:\n                continue\n            else:\n                # These are the spatiotemporal grids\n                spatiotemporal_grids = [2, 4, 5]\n            for grid_id in spatiotemporal_grids:\n                total += torch.abs(1 - grids[grid_id]).mean()\n        return total\n    def compute_regulation(self, time_smoothness_weight, l1_time_planes_weight, plane_tv_weight):\n        return plane_tv_weight * self._plane_regulation() + time_smoothness_weight * self._time_regulation() + l1_time_planes_weight * self._l1_regulation()\n\ndef getProjectionMatrix(znear, zfar, fovX, fovY):\n    tanHalfFovY = math.tan((fovY / 2))\n    tanHalfFovX = math.tan((fovX / 2))\n\n    P = torch.zeros(4, 4)\n\n    z_sign = 1.0\n\n    P[0, 0] = 1 / tanHalfFovX\n    P[1, 1] = 1 / tanHalfFovY\n  
  P[3, 2] = z_sign\n    P[2, 2] = z_sign * zfar / (zfar - znear)\n    P[2, 3] = -(zfar * znear) / (zfar - znear)\n    return P\n\n\nclass MiniCam:\n    def __init__(self, c2w, width, height, fovy, fovx, znear, zfar,time=0 ):\n        # c2w (pose) should be in NeRF convention.\n\n        self.image_width = width\n        self.image_height = height\n        self.FoVy = fovy\n        self.FoVx = fovx\n        self.znear = znear\n        self.zfar = zfar\n        self.time = time\n        w2c = np.linalg.inv(c2w)\n\n        # rectify...\n        w2c[1:3, :3] *= -1\n        w2c[:3, 3] *= -1\n\n        self.world_view_transform = torch.tensor(w2c).transpose(0, 1).cuda()\n        self.projection_matrix = (\n            getProjectionMatrix(\n                znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy\n            )\n            .transpose(0, 1)\n            .cuda()\n        )\n        self.full_proj_transform = self.world_view_transform @ self.projection_matrix\n        self.camera_center = -torch.tensor(c2w[:3, 3]).cuda()\n\n\nclass Renderer:\n    def __init__(self, sh_degree=3, white_background=True, radius=1):\n        \n        self.sh_degree = sh_degree\n        self.white_background = white_background\n        self.radius = radius\n\n        self.gaussians = GaussianModel(sh_degree)\n\n        self.bg_color = torch.tensor(\n            [1, 1, 1] if white_background else [0, 0, 0],\n            dtype=torch.float32,\n            device=\"cuda\",\n        )\n    \n    def initialize(self, input=None, num_pts=5000, radius=0.5,initial_values=None):\n        # load checkpoint\n        if input is None:\n            # init from random point cloud\n            \n            phis = np.random.random((num_pts,)) * 2 * np.pi\n            costheta = np.random.random((num_pts,)) * 2 - 1\n            thetas = np.arccos(costheta)\n            mu = np.random.random((num_pts,))\n            radius = radius * np.cbrt(mu)\n            x = radius * np.sin(thetas) * 
np.cos(phis)\n            y = radius * np.sin(thetas) * np.sin(phis)\n            z = radius * np.cos(thetas)\n            xyz = np.stack((x, y, z), axis=1)\n            if initial_values is not None:\n                print(xyz.shape,initial_values[\"mean\"].shape)\n                R = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]])\n                xyz = np.dot(initial_values[\"mean\"].numpy(),R)\n            # xyz = np.random.random((num_pts, 3)) * 2.6 - 1.3\n            num_pts = xyz.shape[0]  # initial_values may change the point count; keep colors and normals in sync with xyz\n\n            shs = np.random.random((num_pts, 3)) / 255.0\n            pcd = BasicPointCloud(\n                points=xyz, colors=SH2RGB(shs), normals=np.zeros((num_pts, 3))\n            )\n            self.gaussians.create_from_pcd(pcd, 10)\n        elif isinstance(input, BasicPointCloud):\n            # load from a provided pcd\n            self.gaussians.create_from_pcd(input, 1)\n        else:\n            # load from saved ply\n            self.gaussians.load_ply(input)\n\n    def render(\n        self,\n        viewpoint_camera,\n        scaling_modifier=1.0,\n        bg_color=None,\n        override_color=None,\n        compute_cov3D_python=False,\n        convert_SHs_python=False,\n        stage=\"fine\",\n        time_int = None,\n        front_view=False,\n    ):\n\n \n        # Create zero tensor. 
We will use it to make pytorch return gradients of the 2D (screen-space) means\n        screenspace_points = torch.zeros_like(self.gaussians.get_xyz, dtype=self.gaussians.get_xyz.dtype, requires_grad=True, device=\"cuda\") + 0\n        try:\n            screenspace_points.retain_grad()\n        except:\n            pass\n\n        # Set up rasterization configuration\n\n        tanfovx = math.tan(viewpoint_camera.FoVx * 0.5)\n        tanfovy = math.tan(viewpoint_camera.FoVy * 0.5)\n\n        raster_settings = GaussianRasterizationSettings(\n            image_height=int(viewpoint_camera.image_height),\n            image_width=int(viewpoint_camera.image_width),\n            tanfovx=tanfovx,\n            tanfovy=tanfovy,\n            bg=self.bg_color if bg_color is None else bg_color,\n            scale_modifier=scaling_modifier,\n            viewmatrix=viewpoint_camera.world_view_transform,\n            projmatrix=viewpoint_camera.full_proj_transform,\n            sh_degree=self.gaussians.active_sh_degree,\n            campos=viewpoint_camera.camera_center,\n            prefiltered=False,\n            debug=False,\n        )\n        if front_view==True:\n            print(viewpoint_camera.world_view_transform,viewpoint_camera.full_proj_transform)\n        rasterizer = GaussianRasterizer(raster_settings=raster_settings)\n\n\n        means3D = self.gaussians.get_xyz\n        time = torch.tensor(viewpoint_camera.time).to(means3D.device).repeat(means3D.shape[0],1)\n        means2D = screenspace_points\n        opacity = self.gaussians._opacity\n\n        # If precomputed 3d covariance is provided, use it. 
If not, then it will be computed from\n        # scaling / rotation by the rasterizer.\n        scales = None\n        rotations = None\n        cov3D_precomp = None\n        if compute_cov3D_python:\n            cov3D_precomp = self.gaussians.get_covariance(scaling_modifier)\n        else:\n            scales = self.gaussians._scaling\n            rotations = self.gaussians._rotation\n        deformation_point = self.gaussians._deformation_table\n\n        if stage == \"coarse\" :\n            means3D_deform, scales_deform, rotations_deform, opacity_deform = means3D, scales, rotations, opacity\n        else:\n            means3D_deform, scales_deform, rotations_deform, opacity_deform = self.gaussians._deformation(means3D[deformation_point], scales[deformation_point], \n                                                                         rotations[deformation_point], opacity[deformation_point],\n                                                                         time[deformation_point])\n        # print(time.max())\n        with torch.no_grad():\n            self.gaussians._deformation_accum[deformation_point] += torch.abs(means3D_deform-means3D[deformation_point])\n        #print(torch.abs(means3D_deform-means3D[deformation_point]).mean())\n        means3D_final = torch.zeros_like(means3D)\n        rotations_final = torch.zeros_like(rotations)\n        scales_final = torch.zeros_like(scales)\n        opacity_final = torch.zeros_like(opacity)\n        means3D_final[deformation_point] =  means3D_deform\n        rotations_final[deformation_point] =  rotations_deform\n        scales_final[deformation_point] =  scales_deform\n        opacity_final[deformation_point] = opacity_deform\n        means3D_final[~deformation_point] = means3D[~deformation_point]\n        rotations_final[~deformation_point] = rotations[~deformation_point]\n        scales_final[~deformation_point] = scales[~deformation_point]\n        opacity_final[~deformation_point] = 
opacity[~deformation_point]\n\n        \n        scales_in=scales_final\n        rotations_in=rotations_final\n        opacity_in = opacity_final\n        \n        scales_final = self.gaussians.scaling_activation(scales_final)\n        rotations_final = self.gaussians.rotation_activation(rotations_final)\n        opacity = self.gaussians.opacity_activation(opacity)\n        opacity_final = self.gaussians.opacity_activation(opacity_final)\n\n        # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors\n        # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer.\n        shs = None\n        colors_precomp = None\n\n        shs = self.gaussians.get_features\n\n        # Rasterize visible Gaussians to image, obtain their radii (on screen).\n        rendered_image, rendered_depth, normal, rendered_alpha ,radii, _ = rasterizer(\n        means3D = means3D_final,\n        means2D = means2D,\n        shs = shs,\n        colors_precomp = colors_precomp,\n        opacities = opacity_final,\n        scales = scales_final,\n        rotations = rotations_final,\n        cov3Ds_precomp = cov3D_precomp)\n\n\n        return {\n            \"image\": rendered_image,\n            \"depth\": rendered_depth,\n            \"alpha\": rendered_alpha,\n            \"viewspace_points\": screenspace_points,\n            \"visibility_filter\": radii > 0,\n            \"radii\": radii,\n            'xyz':means3D_final,\n            'rot':rotations_in,\n            'xy':means2D,\n            'color':shs,\n            'scales':scales_in,\n            'opacity':opacity_in,\n        }\n\n"
  },
  {
    "path": "guidance/zero123_4d_utils.py",
    "content": "from transformers import CLIPTextModel, CLIPTokenizer, logging\nfrom diffusers import (\n    AutoencoderKL,\n    UNet2DConditionModel,\n    DDIMScheduler,\n    StableDiffusionPipeline,\n)\nimport torchvision.transforms.functional as TF\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\nimport sys\nsys.path.append('./')\n\nfrom zero123 import Zero123Pipeline\n\n\nclass Zero123(nn.Module):\n    def __init__(self, device, fp16=True, t_range=[0.02, 0.98]):\n        super().__init__()\n\n        self.device = device\n        self.fp16 = fp16\n        self.dtype = torch.float16 if fp16 else torch.float32\n\n        self.pipe = Zero123Pipeline.from_pretrained(            \n            # \"bennyguo/zero123-diffusers\",\n            \"ashawkey/zero123-xl-diffusers\",\n            # './model_cache/zero123_xl',\n            variant=\"fp16\" if self.fp16 else None,\n            torch_dtype=self.dtype,\n        ).to(self.device)\n\n        # for param in self.pipe.parameters():\n        #     param.requires_grad = False\n\n        self.pipe.image_encoder.eval()\n        self.pipe.vae.eval()\n        self.pipe.unet.eval()\n        self.pipe.clip_camera_projection.eval()\n\n        self.vae = self.pipe.vae\n        self.unet = self.pipe.unet\n\n        self.pipe.set_progress_bar_config(disable=True)\n\n        self.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)\n        self.num_train_timesteps = self.scheduler.config.num_train_timesteps\n\n        self.min_step = int(self.num_train_timesteps * t_range[0])\n        self.max_step = int(self.num_train_timesteps * t_range[1])\n        self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience\n        self.min_step_percent = [0, 0.5, 0.02, 3000] \n        self.max_step_percent= [0, 0.95, 0.5, 3000]\n        self.embeddings = None\n        self.embedding_list = []\n\n    @torch.no_grad()\n    def get_img_embeds(self, x, 
input_imgs=None):\n        # x: image tensor in [0, 1]\n        x = F.interpolate(x, (256, 256), mode='bilinear', align_corners=False)\n        x_pil = [TF.to_pil_image(image) for image in x]\n        x_clip = self.pipe.feature_extractor(images=x_pil, return_tensors=\"pt\").pixel_values.to(device=self.device, dtype=self.dtype)\n        c = self.pipe.image_encoder(x_clip).image_embeds\n        v = self.encode_imgs(x.to(self.dtype)) / self.vae.config.scaling_factor\n        self.embeddings = [c, v]\n        self.additional_embeddings=[]\n        if input_imgs is not None:\n            \n            for x in input_imgs:\n                    x = F.interpolate(x, (256, 256), mode='bilinear', align_corners=False)\n                    x_pil = [TF.to_pil_image(image) for image in x]\n                    x_clip = self.pipe.feature_extractor(images=x_pil, return_tensors=\"pt\").pixel_values.to(device=self.device, dtype=self.dtype)\n                    c = self.pipe.image_encoder(x_clip).image_embeds\n                    v = self.encode_imgs(x.to(self.dtype)) / self.vae.config.scaling_factor\n                    embeddings = [c, v]\n                    self.additional_embeddings.append(embeddings)\n        self.embedding_list.append([self.embeddings,self.additional_embeddings])\n\n    @torch.no_grad()\n    def refine(self, pred_rgb, polar, azimuth, radius, \n               guidance_scale=5, steps=50, strength=0.8,\n        ):\n\n        batch_size = pred_rgb.shape[0]\n\n        self.scheduler.set_timesteps(steps)\n\n        if strength == 0:\n            init_step = 0\n            latents = torch.randn((1, 4, 32, 32), device=self.device, dtype=self.dtype)\n        else:\n            init_step = int(steps * strength)\n            pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode='bilinear', align_corners=False)\n            latents = self.encode_imgs(pred_rgb_256.to(self.dtype))\n            latents = self.scheduler.add_noise(latents, torch.randn_like(latents), 
self.scheduler.timesteps[init_step])\n\n        T = np.stack([np.deg2rad(polar), np.sin(np.deg2rad(azimuth)), np.cos(np.deg2rad(azimuth)), radius], axis=-1)\n        T = torch.from_numpy(T).unsqueeze(1).to(self.dtype).to(self.device) # [8, 1, 4]\n        cc_emb = torch.cat([self.embeddings[0].repeat(batch_size, 1, 1), T], dim=-1)\n        cc_emb = self.pipe.clip_camera_projection(cc_emb)\n        cc_emb = torch.cat([cc_emb, torch.zeros_like(cc_emb)], dim=0)\n\n        vae_emb = self.embeddings[1].repeat(batch_size, 1, 1, 1)\n        vae_emb = torch.cat([vae_emb, torch.zeros_like(vae_emb)], dim=0)\n\n        for i, t in enumerate(self.scheduler.timesteps[init_step:]):\n            \n            x_in = torch.cat([latents] * 2)\n            t_in = torch.cat([t.view(1)] * 2).to(self.device)\n\n            noise_pred = self.unet(\n                torch.cat([x_in, vae_emb], dim=1),\n                t_in.to(self.unet.dtype),\n                encoder_hidden_states=cc_emb,\n            ).sample\n\n            noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)\n            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)\n            \n            latents = self.scheduler.step(noise_pred, t, latents).prev_sample\n\n        imgs = self.decode_latents(latents) # [1, 3, 256, 256]\n        return imgs\n    \n    def train_step(self, pred_rgb, polar, azimuth, radius, step_ratio=None, guidance_scale=5, as_latent=False,idx=None,t=0):\n        # pred_rgb: tensor [1, 3, H, W] in [0, 1]\n        # t: frame index into self.embedding_list (reassigned below to the diffusion timestep)\n        #print(polar)\n        step_ratio = 0.4 if step_ratio is None else max(0.4, step_ratio)  # guard the None default: max(0.4, None) raises TypeError\n        self.embeddings,self.additional_embeddings = self.embedding_list[t]\n        batch_size = pred_rgb.shape[0]\n        #print(self.embedding_list[1][0][0] -self.embedding_list[2][0][0])\n        #print(self.embedding_list[1][0][1] -self.embedding_list[2][0][1])\n        if idx is not None:\n            embeddings = self.additional_embeddings[idx]\n        else:\n            embeddings = 
self.embeddings\n\n        if as_latent:\n            latents = F.interpolate(pred_rgb, (32, 32), mode='bilinear', align_corners=False) * 2 - 1\n        else:\n            pred_rgb_256 = F.interpolate(pred_rgb, (256, 256), mode='bilinear', align_corners=False)\n            latents = self.encode_imgs(pred_rgb_256.to(self.dtype))\n\n\n        t = torch.randint(self.min_step, self.max_step + 1, (batch_size,), dtype=torch.long, device=self.device)\n\n        w = (1 - self.alphas[t]).view(batch_size, 1, 1, 1)\n\n        with torch.no_grad():\n            noise = torch.randn_like(latents)\n            latents_noisy = self.scheduler.add_noise(latents, noise, t)\n\n            x_in = torch.cat([latents_noisy] * 2)\n            t_in = torch.cat([t] * 2)\n\n            T = np.stack([np.deg2rad(polar), np.sin(np.deg2rad(azimuth)), np.cos(np.deg2rad(azimuth)), radius], axis=-1)\n            T = torch.from_numpy(T).unsqueeze(1).to(self.dtype).to(self.device) # [8, 1, 4]\n            #print(self.embeddings[0].repeat(batch_size, 1, 1).shape,T.shape)\n            cc_emb = torch.cat([embeddings[0].repeat(batch_size, 1, 1), T], dim=-1)\n            cc_emb = self.pipe.clip_camera_projection(cc_emb)\n            cc_emb = torch.cat([cc_emb, torch.zeros_like(cc_emb)], dim=0)\n\n            vae_emb = embeddings[1].repeat(batch_size, 1, 1, 1)\n            vae_emb = torch.cat([vae_emb, torch.zeros_like(vae_emb)], dim=0)\n\n            noise_pred = self.unet(\n                torch.cat([x_in, vae_emb], dim=1),\n                t_in.to(self.unet.dtype),\n                encoder_hidden_states=cc_emb,\n            ).sample\n\n        noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2)\n        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)\n\n        grad = w * (noise_pred - noise)\n        grad = torch.nan_to_num(grad)\n\n        target = (latents - grad).detach()\n        loss = 0.5 * F.mse_loss(latents.float(), target, reduction='sum')\n\n       
 return loss\n    def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):\n        min_step_percent = self.get_steps(self.min_step_percent, epoch, global_step)\n        max_step_percent = self.get_steps(self.max_step_percent, epoch, global_step)\n        self.min_step = int( self.num_train_timesteps * min_step_percent )\n        self.max_step = int( self.num_train_timesteps * max_step_percent )\n        \n        \n    def get_steps(self,percent,epoch, global_step):\n        start_step, start_value, end_value, end_step = percent\n        \n        current_step = global_step\n        value = start_value + (end_value - start_value) * max(\n                min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0\n            )\n        \n        return value\n\n    def decode_latents(self, latents):\n        latents = 1 / self.vae.config.scaling_factor * latents\n\n        imgs = self.vae.decode(latents).sample\n        imgs = (imgs / 2 + 0.5).clamp(0, 1)\n\n        return imgs\n\n    def encode_imgs(self, imgs, mode=False):\n        # imgs: [B, 3, H, W]\n\n        imgs = 2 * imgs - 1\n\n        posterior = self.vae.encode(imgs).latent_dist\n        if mode:\n            latents = posterior.mode()\n        else:\n            latents = posterior.sample() \n        latents = latents * self.vae.config.scaling_factor\n\n        return latents\n    \n    \nif __name__ == '__main__':\n    import cv2\n    import argparse\n    import numpy as np\n    import matplotlib.pyplot as plt\n    \n    parser = argparse.ArgumentParser()\n\n    parser.add_argument('input', type=str)\n    parser.add_argument('--polar', type=float, default=0, help='delta polar angle in [-90, 90]')\n    parser.add_argument('--azimuth', type=float, default=0, help='delta azimuth angle in [-180, 180]')\n    parser.add_argument('--radius', type=float, default=0, help='delta camera radius multiplier in [-0.5, 0.5]')\n\n    opt = parser.parse_args()\n\n    device = 
torch.device('cuda')\n\n    print(f'[INFO] loading image from {opt.input} ...')\n    image = cv2.imread(opt.input, cv2.IMREAD_UNCHANGED)\n    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n    image = cv2.resize(image, (256, 256), interpolation=cv2.INTER_AREA)\n    image = image.astype(np.float32) / 255.0\n    image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).contiguous().to(device)\n\n    print(f'[INFO] loading model ...')\n    zero123 = Zero123(device)\n\n    print(f'[INFO] running model ...')\n    zero123.get_img_embeds(image)\n\n    while True:\n        outputs = zero123.refine(image, polar=[opt.polar], azimuth=[opt.azimuth], radius=[opt.radius], strength=0)\n        plt.imshow(outputs.float().cpu().numpy().transpose(0, 2, 3, 1)[0])\n        plt.show()"
  },
  {
    "path": "guidance/zero123pp/pipeline.py",
    "content": "from typing import Any, Dict, Optional\nfrom diffusers.models import AutoencoderKL, UNet2DConditionModel\nfrom diffusers.schedulers import KarrasDiffusionSchedulers\n\nimport numpy\nimport torch\nimport torch.nn as nn\nimport torch.utils.checkpoint\nimport torch.distributed\nimport transformers\nfrom collections import OrderedDict\nfrom PIL import Image\nfrom torchvision import transforms\nfrom transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer\nimport torch\nimport torch.nn.functional as F\nfrom torch import nn\nimport diffusers\nfrom diffusers import (\n    AutoencoderKL,\n    DDPMScheduler,\n    DiffusionPipeline,\n    EulerAncestralDiscreteScheduler,\n    UNet2DConditionModel,\n    ImagePipelineOutput\n)\nfrom diffusers.image_processor import VaeImageProcessor\nfrom diffusers.models.attention_processor import Attention, AttnProcessor, XFormersAttnProcessor\nfrom diffusers.utils.import_utils import is_xformers_available\n\nimport os\nFIRST = True\nIDX = 0\nPATH = '/home/vision/github/embeddings/'\nEMBED=[]\n\ndef to_rgb_image(maybe_rgba: Image.Image):\n    if maybe_rgba.mode == 'RGB':\n        return maybe_rgba\n    elif maybe_rgba.mode == 'RGBA':\n        rgba = maybe_rgba\n        img = numpy.random.randint(127, 128, size=[rgba.size[1], rgba.size[0], 3], dtype=numpy.uint8)\n        img = Image.fromarray(img, 'RGB')\n        img.paste(rgba, mask=rgba.getchannel('A'))\n        return img\n    else:\n        raise ValueError(\"Unsupported image type.\", maybe_rgba.mode)\n    \nclass MyAttnProcessor2_0:\n    r\"\"\"\n    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).\n    \"\"\"\n\n    def __init__(self):\n        if not hasattr(F, \"scaled_dot_product_attention\"):\n            raise ImportError(\"MyAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.\")\n\n    def __call__(\n        self,\n        attn: Attention,\n        hidden_states,\n     
   encoder_hidden_states=None,\n        attention_mask=None,\n        temb=None,\n        is_self_attention=False\n    ):\n        residual = hidden_states\n\n        if attn.spatial_norm is not None:\n            hidden_states = attn.spatial_norm(hidden_states, temb)\n\n        input_ndim = hidden_states.ndim\n\n        if input_ndim == 4:\n            batch_size, channel, height, width = hidden_states.shape\n            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)\n\n        batch_size, sequence_length, _ = (\n            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape\n        )\n\n        if attention_mask is not None:\n            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)\n            # scaled_dot_product_attention expects attention_mask shape to be\n            # (batch, heads, source_length, target_length)\n            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])\n\n        if attn.group_norm is not None:\n            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)\n\n        query = attn.to_q(hidden_states)\n\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n        elif attn.norm_cross:\n            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)\n        \n        if is_self_attention:\n            global FIRST\n            global IDX\n            global EMBED\n            if FIRST == True:\n                EMBED.append(encoder_hidden_states.to('cpu'))\n                #print('saving to {})'.format(PATH,str(IDX)+'_hidden.pt'))\n                #os.makedirs(PATH,exist_ok=True)\n                #torch.save(encoder_hidden_states,os.path.join(PATH,str(IDX)+'_hidden.pt'))\n                IDX=IDX+1\n            else:\n                last_shape = encoder_hidden_states.shape[-1]\n     
           replace_dim = int(9600/(last_shape//320)**2)\n                encoder_hidden_states_load = EMBED[IDX].to('cuda')\n                encoder_hidden_states[:,:replace_dim,:]=(encoder_hidden_states_load[:,:replace_dim,:]+encoder_hidden_states[:,:replace_dim,:])/2\n                IDX=IDX+1\n        key = attn.to_k(encoder_hidden_states)\n        value = attn.to_v(encoder_hidden_states)\n        \n\n        inner_dim = key.shape[-1]\n        head_dim = inner_dim // attn.heads\n\n        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n\n        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)\n        #print(key.shape)\n\n\n        # the output of sdp = (batch, num_heads, seq_len, head_dim)\n        # TODO: add support for attn.scale when we move to Torch 2.1\n        hidden_states = F.scaled_dot_product_attention(\n            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False\n        )\n\n        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)\n        hidden_states = hidden_states.to(query.dtype)\n\n        # linear proj\n        hidden_states = attn.to_out[0](hidden_states)\n        # dropout\n        hidden_states = attn.to_out[1](hidden_states)\n\n        if input_ndim == 4:\n            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)\n\n        if attn.residual_connection:\n            hidden_states = hidden_states + residual\n\n        hidden_states = hidden_states / attn.rescale_output_factor\n\n        return hidden_states\n\nclass ReferenceOnlyAttnProc(torch.nn.Module):\n    def __init__(\n        self,\n        chained_proc,\n        enabled=False,\n        name=None\n    ) -> None:\n        super().__init__()\n        self.enabled = enabled\n        self.chained_proc = chained_proc\n        self.name = 
name\n\n    def __call__(\n        self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None,\n        mode=\"w\", ref_dict: dict = None, is_cfg_guidance = False\n    ) -> Any:\n        if encoder_hidden_states is None:\n            encoder_hidden_states = hidden_states\n        is_self_attention = False\n        if self.enabled and is_cfg_guidance:\n            res0 = self.chained_proc(attn, hidden_states[:1], encoder_hidden_states[:1], attention_mask)\n            hidden_states = hidden_states[1:]\n            encoder_hidden_states = encoder_hidden_states[1:]\n        if self.enabled:\n            if mode == 'w':\n                ref_dict[self.name] = encoder_hidden_states\n            elif mode == 'r':\n                encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict.pop(self.name)], dim=1)\n                is_self_attention = True\n            elif mode == 'm':\n                encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict[self.name]], dim=1)\n            else:\n                assert False, mode\n        res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask,is_self_attention=is_self_attention)\n        if self.enabled and is_cfg_guidance:\n            res = torch.cat([res0, res])\n        return res\n\n\nclass RefOnlyNoisedUNet(torch.nn.Module):\n    def __init__(self, unet: UNet2DConditionModel, train_sched: DDPMScheduler, val_sched: EulerAncestralDiscreteScheduler) -> None:\n        super().__init__()\n        self.unet = unet\n        self.train_sched = train_sched\n        self.val_sched = val_sched\n\n        unet_lora_attn_procs = dict()\n        for name, _ in unet.attn_processors.items():\n            if torch.__version__ >= '2.0':\n                default_attn_proc = MyAttnProcessor2_0()\n                print('using my attention')\n            elif is_xformers_available():\n                default_attn_proc = XFormersAttnProcessor()\n            else:\n    
            default_attn_proc = AttnProcessor()\n            unet_lora_attn_procs[name] = ReferenceOnlyAttnProc(\n                default_attn_proc, enabled=name.endswith(\"attn1.processor\"), name=name\n            )\n        unet.set_attn_processor(unet_lora_attn_procs)\n\n    def __getattr__(self, name: str):\n        try:\n            return super().__getattr__(name)\n        except AttributeError:\n            return getattr(self.unet, name)\n\n    def forward_cond(self, noisy_cond_lat, timestep, encoder_hidden_states, class_labels, ref_dict, is_cfg_guidance, **kwargs):\n        if is_cfg_guidance:\n            encoder_hidden_states = encoder_hidden_states[1:]\n            class_labels = class_labels[1:]\n        self.unet(\n            noisy_cond_lat, timestep,\n            encoder_hidden_states=encoder_hidden_states,\n            class_labels=class_labels,\n            cross_attention_kwargs=dict(mode=\"w\", ref_dict=ref_dict),\n            **kwargs\n        )\n\n    def forward(\n        self, sample, timestep, encoder_hidden_states, class_labels=None,\n        *args, cross_attention_kwargs,\n        down_block_res_samples=None, mid_block_res_sample=None,\n        **kwargs\n    ):\n        \n\n        cond_lat = cross_attention_kwargs['cond_lat']\n        is_cfg_guidance = cross_attention_kwargs.get('is_cfg_guidance', False)\n        noise = torch.randn_like(cond_lat)\n        if self.training:\n            noisy_cond_lat = self.train_sched.add_noise(cond_lat, noise, timestep)\n            noisy_cond_lat = self.train_sched.scale_model_input(noisy_cond_lat, timestep)\n        else:\n            noisy_cond_lat = self.val_sched.add_noise(cond_lat, noise, timestep.reshape(-1))\n            noisy_cond_lat = self.val_sched.scale_model_input(noisy_cond_lat, timestep.reshape(-1))\n            #if cross_attention_kwargs['cond_lat_back'] is not None:\n            #    cond_lat_back = cross_attention_kwargs['cond_lat_back']\n            #    noisy_cond_lat_back = 
self.val_sched.add_noise(cond_lat_back, noise, timestep.reshape(-1))\n            #    noisy_cond_lat_back = self.val_sched.scale_model_input(noisy_cond_lat_back, timestep.reshape(-1))\n\n        ref_dict = {}\n        self.forward_cond(\n            noisy_cond_lat, timestep,\n            encoder_hidden_states, class_labels,\n            ref_dict, is_cfg_guidance, **kwargs\n        )\n        weight_dtype = self.unet.dtype\n        return self.unet(\n            sample, timestep,\n            encoder_hidden_states, *args,\n            class_labels=class_labels,\n            cross_attention_kwargs=dict(mode=\"r\", ref_dict=ref_dict, is_cfg_guidance=is_cfg_guidance),\n            down_block_additional_residuals=[\n                sample.to(dtype=weight_dtype) for sample in down_block_res_samples\n            ] if down_block_res_samples is not None else None,\n            mid_block_additional_residual=(\n                mid_block_res_sample.to(dtype=weight_dtype)\n                if mid_block_res_sample is not None else None\n            ),\n            **kwargs\n        )\n\n\ndef scale_latents(latents):\n    latents = (latents - 0.22) * 0.75\n    return latents\n\n\ndef unscale_latents(latents):\n    latents = latents / 0.75 + 0.22\n    return latents\n\n\ndef scale_image(image):\n    image = image * 0.5 / 0.8\n    return image\n\n\ndef unscale_image(image):\n    image = image / 0.5 * 0.8\n    return image\n\n\nclass DepthControlUNet(torch.nn.Module):\n    def __init__(self, unet: RefOnlyNoisedUNet, controlnet: Optional[diffusers.ControlNetModel] = None, conditioning_scale=1.0) -> None:\n        super().__init__()\n        self.unet = unet\n        if controlnet is None:\n            self.controlnet = diffusers.ControlNetModel.from_unet(unet.unet)\n        else:\n            self.controlnet = controlnet\n        DefaultAttnProc = MyAttnProcessor2_0\n        if is_xformers_available():\n            DefaultAttnProc = XFormersAttnProcessor\n        
self.controlnet.set_attn_processor(DefaultAttnProc())\n        self.conditioning_scale = conditioning_scale\n\n    def __getattr__(self, name: str):\n        try:\n            return super().__getattr__(name)\n        except AttributeError:\n            return getattr(self.unet, name)\n\n    def forward(self, sample, timestep, encoder_hidden_states, class_labels=None, *args, cross_attention_kwargs: dict, **kwargs):\n        cross_attention_kwargs = dict(cross_attention_kwargs)\n        control_depth = cross_attention_kwargs.pop('control_depth')\n        down_block_res_samples, mid_block_res_sample = self.controlnet(\n            sample,\n            timestep,\n            encoder_hidden_states=encoder_hidden_states,\n            controlnet_cond=control_depth,\n            conditioning_scale=self.conditioning_scale,\n            return_dict=False,\n        )\n        return self.unet(\n            sample,\n            timestep,\n            encoder_hidden_states=encoder_hidden_states,\n            down_block_res_samples=down_block_res_samples,\n            mid_block_res_sample=mid_block_res_sample,\n            cross_attention_kwargs=cross_attention_kwargs\n        )\n\n\nclass ModuleListDict(torch.nn.Module):\n    def __init__(self, procs: dict) -> None:\n        super().__init__()\n        self.keys = sorted(procs.keys())\n        self.values = torch.nn.ModuleList(procs[k] for k in self.keys)\n\n    def __getitem__(self, key):\n        return self.values[self.keys.index(key)]\n\n\nclass SuperNet(torch.nn.Module):\n    def __init__(self, state_dict: Dict[str, torch.Tensor]):\n        super().__init__()\n        state_dict = OrderedDict((k, state_dict[k]) for k in sorted(state_dict.keys()))\n        self.layers = torch.nn.ModuleList(state_dict.values())\n        self.mapping = dict(enumerate(state_dict.keys()))\n        self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}\n\n        # .processor for unet, .self_attn for text encoder\n        
self.split_keys = [\".processor\", \".self_attn\"]\n\n        # we add a hook to state_dict() and load_state_dict() so that the\n        # naming fits with `unet.attn_processors`\n        def map_to(module, state_dict, *args, **kwargs):\n            new_state_dict = {}\n            for key, value in state_dict.items():\n                num = int(key.split(\".\")[1])  # 0 is always \"layers\"\n                new_key = key.replace(f\"layers.{num}\", module.mapping[num])\n                new_state_dict[new_key] = value\n\n            return new_state_dict\n\n        def remap_key(key, state_dict):\n            for k in self.split_keys:\n                if k in key:\n                    return key.split(k)[0] + k\n            return key.split('.')[0]\n\n        def map_from(module, state_dict, *args, **kwargs):\n            all_keys = list(state_dict.keys())\n            for key in all_keys:\n                replace_key = remap_key(key, state_dict)\n                new_key = key.replace(replace_key, f\"layers.{module.rev_mapping[replace_key]}\")\n                state_dict[new_key] = state_dict[key]\n                del state_dict[key]\n\n        self._register_state_dict_hook(map_to)\n        self._register_load_state_dict_pre_hook(map_from, with_module=True)\n\n\nclass Zero123PlusPipeline(diffusers.StableDiffusionPipeline):\n    tokenizer: transformers.CLIPTokenizer\n    text_encoder: transformers.CLIPTextModel\n    vision_encoder: transformers.CLIPVisionModelWithProjection\n\n    feature_extractor_clip: transformers.CLIPImageProcessor\n    unet: UNet2DConditionModel\n    scheduler: diffusers.schedulers.KarrasDiffusionSchedulers\n\n    vae: AutoencoderKL\n    ramping: nn.Linear\n\n    feature_extractor_vae: transformers.CLIPImageProcessor\n\n    depth_transforms_multi = transforms.Compose([\n        transforms.ToTensor(),\n        transforms.Normalize([0.5], [0.5])\n    ])\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        text_encoder: 
CLIPTextModel,\n        tokenizer: CLIPTokenizer,\n        unet: UNet2DConditionModel,\n        scheduler: KarrasDiffusionSchedulers,\n        vision_encoder: transformers.CLIPVisionModelWithProjection,\n        feature_extractor_clip: CLIPImageProcessor, \n        feature_extractor_vae: CLIPImageProcessor,\n        ramping_coefficients: Optional[list] = None,\n        safety_checker=None,\n    ):\n        DiffusionPipeline.__init__(self)\n\n        self.register_modules(\n            vae=vae, text_encoder=text_encoder, tokenizer=tokenizer,\n            unet=unet, scheduler=scheduler, safety_checker=None,\n            vision_encoder=vision_encoder,\n            feature_extractor_clip=feature_extractor_clip,\n            feature_extractor_vae=feature_extractor_vae\n        )\n        self.register_to_config(ramping_coefficients=ramping_coefficients)\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\n\n    def prepare(self):\n        train_sched = DDPMScheduler.from_config(self.scheduler.config)\n        if isinstance(self.unet, UNet2DConditionModel):\n            self.unet = RefOnlyNoisedUNet(self.unet, train_sched, self.scheduler).eval()\n\n    def add_controlnet(self, controlnet: Optional[diffusers.ControlNetModel] = None, conditioning_scale=1.0):\n        self.prepare()\n        self.unet = DepthControlUNet(self.unet, controlnet, conditioning_scale)\n        return SuperNet(OrderedDict([('controlnet', self.unet.controlnet)]))\n\n    def encode_condition_image(self, image: torch.Tensor):\n        image = self.vae.encode(image).latent_dist.sample()\n        return image\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        image: Image.Image = None,\n        prompt = \"\",\n        *args,\n        num_images_per_prompt: Optional[int] = 1,\n        guidance_scale=4.0,\n        depth_image: Image.Image = None,\n        output_type: 
Optional[str] = \"pil\",\n        width=640,\n        height=960,\n        num_inference_steps=28,\n        return_dict=True,\n        is_first = False,\n        **kwargs\n    ):\n        global FIRST\n        FIRST = is_first\n        global IDX\n        IDX = 0\n        if is_first:\n            global EMBED\n            EMBED=[]\n        \n        # Use a fixed seed (42) so repeated calls produce the same sample\n        generator = torch.Generator(device='cuda')\n        generator.manual_seed(42)\n        self.prepare()\n        if image is None:\n            raise ValueError(\"Inputting embeddings not supported for this pipeline. Please pass an image.\")\n        assert not isinstance(image, torch.Tensor)\n        image = to_rgb_image(image)\n        image_1 = self.feature_extractor_vae(images=image, return_tensors=\"pt\").pixel_values\n        image_2 = self.feature_extractor_clip(images=image, return_tensors=\"pt\").pixel_values\n        if depth_image is not None and hasattr(self.unet, \"controlnet\"):\n            depth_image = to_rgb_image(depth_image)\n            depth_image = self.depth_transforms_multi(depth_image).to(\n                device=self.unet.controlnet.device, dtype=self.unet.controlnet.dtype\n            )\n        image = image_1.to(device=self.vae.device, dtype=self.vae.dtype)\n        image_2 = image_2.to(device=self.vae.device, dtype=self.vae.dtype)\n        cond_lat = self.encode_condition_image(image)\n        if guidance_scale > 1:\n            negative_lat = self.encode_condition_image(torch.zeros_like(image))\n            cond_lat = torch.cat([negative_lat, cond_lat])\n        encoded = self.vision_encoder(image_2, output_hidden_states=False)\n        global_embeds = encoded.image_embeds\n        global_embeds = global_embeds.unsqueeze(-2)\n        \n        encoder_hidden_states = self._encode_prompt(\n            prompt,\n            self.device,\n            num_images_per_prompt,\n            False\n        )\n        ramp = 
global_embeds.new_tensor(self.config.ramping_coefficients).unsqueeze(-1)\n        encoder_hidden_states = encoder_hidden_states + global_embeds * ramp\n        cak = dict(cond_lat=cond_lat)\n        if hasattr(self.unet, \"controlnet\"):\n            cak['control_depth'] = depth_image\n        \n        cak['cond_lat_back'] = None\n\n        latents: torch.Tensor = super().__call__(\n            None,\n            *args,\n            cross_attention_kwargs=cak,\n            guidance_scale=guidance_scale,\n            num_images_per_prompt=num_images_per_prompt,\n            prompt_embeds=encoder_hidden_states,\n            num_inference_steps=num_inference_steps,\n            output_type='latent',\n            width=width,\n            height=height,\n            generator=generator,\n            **kwargs\n        ).images\n        latents = unscale_latents(latents)\n        if not output_type == \"latent\":\n            image = unscale_image(self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0])\n        else:\n            image = latents\n        \n        image = self.image_processor.postprocess(image, output_type=output_type)\n\n        if not return_dict:\n            return (image,)\n        \n\n        return ImagePipelineOutput(images=image)"
  },
  {
    "path": "main.py",
    "content": "import os\nimport cv2\nimport time\nimport tqdm\nimport numpy as np\nimport dearpygui.dearpygui as dpg\n\nimport torch\nimport torch.nn.functional as F\nimport torchvision.utils as vutils\nfrom einops import rearrange, repeat\nimport imageio\nimport rembg\n\nfrom cam_utils import orbit_camera, OrbitCamera\nfrom gs_renderer_4d import Renderer, MiniCam\nfrom dataset_4d import SparseDataset\n\ndef save_image_to_local(image_tensor, file_path):\n    # Ensure the image tensor is in the range [0, 1]\n    image_tensor = image_tensor.clamp(0, 1) \n\n    # Save the image tensor to the specified file path\n    vutils.save_image(image_tensor, file_path)\n\nclass GUI:\n    def __init__(self, opt):\n        self.opt = opt  # shared with the trainer's opt to support in-place modification of rendering parameters.\n        self.gui = opt.gui # enable gui\n        self.W = opt.W\n        self.H = opt.H\n        self.cam = OrbitCamera(opt.W, opt.H, r=opt.radius, fovy=opt.fovy)\n\n        self.mode = \"image\"\n        self.seed = \"random\"\n\n        self.buffer_image = np.ones((self.W, self.H, 3), dtype=np.float32)\n        self.need_update = True  # update buffer_image\n\n        # models\n        self.device = torch.device(\"cuda\")\n        self.bg_remover = None\n\n        self.guidance_sd = None\n        self.guidance_zero123 = None\n\n        self.enable_sd = False\n        self.enable_zero123 = False\n\n        # renderer\n        self.renderer = Renderer(sh_degree=self.opt.sh_degree)\n        self.gaussain_scale_factor = 1\n\n        # input image\n        self.input_img = None\n        self.input_mask = None\n        self.input_img_torch = None\n        self.input_mask_torch = None\n        self.overlay_input_img = False\n        self.overlay_input_img_ratio = 0.5\n        \n        #self.use_depth = opt.use_depth\n\n        # input text\n        self.prompt = \"\"\n        self.negative_prompt = \"\"\n\n        # training stuff\n        self.training = 
False\n        self.optimizer = None\n        self.step = 0\n        self.t = 0\n        self.time = 0\n        self.train_steps = 1  # steps per rendering loop\n        self.init = True\n        self.stage = 'coarse'\n        self.path = self.opt.path\n        self.save_step = self.opt.save_step\n        \n        if self.opt.size is not None:\n            self.size = self.opt.size\n        else:\n            self.size = len(os.listdir(os.path.join(self.path,'ref')))\n        self.frames=self.size\n        self.dataset = SparseDataset(self.opt, self.size, H=self.H, W=self.W, device=self.device)\n        self.dataloader = self.dataset.dataloader()\n        self.iter = iter(self.dataloader)\n        self.ref_view_batch, self.input_mask_batch,self.zero123_view_batch,self.zero123_masks_batch = next(self.iter)\n        self.input_img_torch_batch,self.input_mask_torch_batch,self.zero123plus_imgs_torch_batch,self.zero123plus_masks_torch_batch=[],[],[],[]\n        \n\n        # load input data from cmdline\n        if self.opt.input is not None:\n            self.load_input(self.opt.input)\n        \n        # override prompt from cmdline\n        if self.opt.prompt is not None:\n            self.prompt = self.opt.prompt\n\n        # initialize Gaussians (a checkpoint given via load_path is loaded later, in prepare_train)\n        self.renderer.initialize(num_pts=self.opt.num_pts)\n\n        self.point_nums = []\n        if self.gui:\n            dpg.create_context()\n            self.register_dpg()\n            self.test_step()\n\n    def __del__(self):\n        if self.gui:\n            dpg.destroy_context()\n\n    def seed_everything(self):\n        try:\n            seed = int(self.seed)\n        except (TypeError, ValueError):\n            seed = np.random.randint(0, 1000000)\n\n        os.environ[\"PYTHONHASHSEED\"] = str(seed)\n        np.random.seed(seed)\n        torch.manual_seed(seed)\n        torch.cuda.manual_seed(seed)\n        torch.backends.cudnn.deterministic = True\n        torch.backends.cudnn.benchmark = 
True\n\n        self.last_seed = seed\n        \n    def prepare_image(self,idx):\n        # input image\n        if self.input_img is not None:\n            self.input_img_torch = torch.from_numpy(self.input_img).permute(2, 0, 1).unsqueeze(0).to(self.device)\n            self.input_img_torch = F.interpolate(self.input_img_torch, (self.opt.ref_size, self.opt.ref_size), mode=\"bilinear\", align_corners=False)\n\n            self.input_mask_torch = torch.from_numpy(self.input_mask).permute(2, 0, 1).unsqueeze(0).to(self.device)\n            self.input_mask_torch = F.interpolate(self.input_mask_torch, (self.opt.ref_size, self.opt.ref_size), mode=\"bilinear\", align_corners=False)\n        \n        self.zero123plus_imgs_torch=[]\n        self.zero123plus_masks_torch=[]\n        # input image\n        if self.input_imgs is not None:\n            for i in np.arange(6):\n                #print(idx,i)\n                self.input_img2_torch=(torch.from_numpy(self.input_imgs[i]).permute(2, 0, 1).unsqueeze(0).to(self.device))\n                self.zero123plus_imgs_torch.append(F.interpolate(self.input_img2_torch, (self.opt.ref_size, self.opt.ref_size), mode=\"bilinear\", align_corners=False))\n\n                self.input_mask2_torch=torch.from_numpy(self.input_masks[i]).permute(2, 0, 1).unsqueeze(0).to(self.device)\n                self.zero123plus_masks_torch.append(F.interpolate(self.input_mask2_torch, (self.opt.ref_size, self.opt.ref_size), mode=\"bilinear\", align_corners=False))\n                \n        self.input_img_torch_batch.append(self.input_img_torch)\n        self.input_mask_torch_batch.append(self.input_mask_torch)\n        self.zero123plus_imgs_torch_batch.append(self.zero123plus_imgs_torch)\n        self.zero123plus_masks_torch_batch.append(self.zero123plus_masks_torch)\n        \n        # prepare embeddings\n        with torch.no_grad():\n            self.guidance_zero123.get_img_embeds(self.input_img_torch, self.zero123plus_imgs_torch)\n\n\n    def 
prepare_train(self):\n\n        self.step = 0\n        self.end_step = self.save_step+1\n        \n        ## given a load_path, load corresponding model\n        if self.opt.load_path is not None:\n           if self.opt.load_step is not None:\n               self.step = self.opt.load_step\n           else:\n               #default loading save_step ply\n               self.step = self.save_step \n           auto_path = os.path.join(self.opt.outdir,self.opt.load_path + str(self.step))\n\n           ply_path = os.path.join(auto_path,'model.ply')\n           self.renderer.gaussians.load_model(auto_path)\n           self.renderer.gaussians.load_ply(ply_path)\n           self.end_step =self.step+self.end_step\n        \n        ## setup training\n        self.renderer.gaussians.training_setup(self.opt)\n        \n        ## do not do progressive sh-level\n        self.renderer.gaussians.active_sh_degree = self.renderer.gaussians.max_sh_degree\n        self.optimizer = self.renderer.gaussians.optimizer\n        \n        # default camera\n        pose = orbit_camera(self.opt.elevation, 0, self.opt.radius)\n        self.fixed_cam = MiniCam(\n            pose,\n            self.opt.ref_size,\n            self.opt.ref_size,\n            self.cam.fovy,\n            self.cam.fovx,\n            self.cam.near,\n            self.cam.far,\n        )\n        self.set_fix_cam()\n        self.enable_sd = self.opt.lambda_sd > 0 and self.prompt != \"\"\n        self.enable_zero123 = self.opt.lambda_zero123 > 0 and self.input_img is not None\n\n\n        print(f\"[INFO] loading zero123...\")\n        from guidance.zero123_4d_utils import Zero123\n        self.guidance_zero123 = Zero123(self.device)\n        print(f\"[INFO] loaded zero123!\")\n\n        ## load multiview reference images\n        for i in np.arange(len(self.ref_view_batch)):\n                self.input_img =   self.ref_view_batch[i]\n                self.input_mask =  self.input_mask_batch[i]\n                
self.input_imgs =  self.zero123_view_batch[i]\n                self.input_masks = self.zero123_masks_batch[i]\n                self.prepare_image(i)\n\n\n    def train_step(self):\n        starter = torch.cuda.Event(enable_timing=True)\n        ender = torch.cuda.Event(enable_timing=True)\n        starter.record()\n\n\n        \n        torch.autograd.set_detect_anomaly(True)\n        for _ in range(self.train_steps):\n\n            if self.step<self.opt.init_steps:\n                self.init = True\n                self.stage = 'coarse'\n            else:\n                self.init = False\n                self.stage = 'fine'\n            \n            if self.step == self.end_step:\n                exit()\n                \n            ## save model\n            if self.step == self.save_step:\n                auto_path = os.path.join(self.opt.outdir,self.opt.save_path + str(self.step))\n                os.makedirs(auto_path,exist_ok=True)\n                ply_path = os.path.join(auto_path,'model.ply')\n                self.renderer.gaussians.save_ply(ply_path)\n                self.renderer.gaussians.save_deformation(auto_path)\n                \n            \n            if self.step>self.opt.position_lr_max_steps:\n                self.opt.position_lr_max_steps = self.opt.position_lr_max_steps2\n\n            self.step += 1\n            step_ratio = min(1, self.step / self.opt.iters)\n            viewspace_point_tensor_list = []\n            radii_list = []\n            visibility_filter_list = []\n            # update lr\n            self.renderer.gaussians.update_learning_rate(self.step)\n            self.guidance_zero123.update_step(0,self.step)\n\n            loss = 0\n            \n            if self.step%self.opt.valid_interval == 0:\n                self.save_renderings( 0, 0, 2 ,'front')\n                self.save_renderings( 180, 0, 2 ,'back')\n                \n\n            render_resolution = 128 if step_ratio < 0.3 else (256 if step_ratio < 0.6 
else 512)\n\n            # avoid too large elevation (> 80 or < -80), and make sure it always covers [-30, 30]\n            min_ver = max(min(-30, -30 - self.opt.elevation), -80 - self.opt.elevation)\n            max_ver = min(max(30, 30 - self.opt.elevation), 80 - self.opt.elevation)\n\n\n            for _ in np.arange(self.opt.batch_size):\n                \n                # sample a frame: the middle frame during the init stage, a random frame afterwards\n                if self.init:\n                    self.t = self.frames//2\n                else:\n                    self.t = np.random.randint(0,self.frames)\n                self.time = self.t/self.frames\n\n                self.input_img_torch = self.input_img_torch_batch[self.t]\n                self.input_mask_torch = self.input_mask_torch_batch[self.t]\n                self.zero123plus_imgs_torch = self.zero123plus_imgs_torch_batch[self.t]\n                self.zero123plus_masks_torch = self.zero123plus_masks_torch_batch[self.t]\n                \n                # reference-view RGB and alpha losses are computed per sample, inside the batch loop\n                cur_cam = self.fixed_cam\n                cur_cam.time=self.time\n                \n                out = self.renderer.render(cur_cam,stage=self.stage)\n                viewspace_point_tensor, visibility_filter, radii = out[\"viewspace_points\"], out[\"visibility_filter\"], out[\"radii\"]  \n                radii_list.append(radii.unsqueeze(0))\n                visibility_filter_list.append(visibility_filter.unsqueeze(0))\n                viewspace_point_tensor_list.append(viewspace_point_tensor)\n                # rgb loss\n                image = out[\"image\"].unsqueeze(0) # [1, 3, H, W] in [0, 1]\n                image_loss = step_ratio * 20000 * F.mse_loss(image, self.input_img_torch)\n                loss = loss + image_loss\n                \n                alpha = out[\"alpha\"].unsqueeze(0)\n                alpha_loss = step_ratio * 5000 * F.mse_loss(alpha, 
self.input_mask_torch)\n                loss = loss + alpha_loss\n                \n                images = []\n                poses = []\n\n                vers_plus, hors_plus, radii_plus = [], [], []\n                self.guidance_zero123.update_step(1,self.step)\n                # render random view\n                ver = np.random.randint(min_ver, max_ver)\n                hor = np.random.randint(-180, 180)\n                radius = 0\n                \n\n                \n\n\n                pose = orbit_camera(self.opt.elevation + ver, hor, self.opt.radius + radius)\n                \n                poses.append(pose)\n\n                cur_cam = MiniCam(\n                    pose,\n                    render_resolution,\n                    render_resolution,\n                    self.cam.fovy,\n                    self.cam.fovx,\n                    self.cam.near,\n                    self.cam.far,\n                )\n                cur_cam.time=self.time\n                if hor<30 and hor>-30 or np.random.rand()>0.4:\n                    idx=None\n                    vers_plus.append(torch.tensor(ver,device=self.device).unsqueeze(dim=0))\n                    hors_plus.append(torch.tensor(hor,device=self.device).unsqueeze(dim=0))\n                    radii_plus.append(torch.tensor(radius,device=self.device).unsqueeze(dim=0))\n                elif hor>0:\n                    idx=hor//60\n                    vers_plus.append(torch.tensor(ver-self.fixed_elevation[idx],device=self.device).unsqueeze(dim=0))\n                    hors_plus.append(torch.tensor(hor-self.fixed_azimuth[idx],device=self.device).unsqueeze(dim=0))\n                    radii_plus.append(torch.tensor(radius,device=self.device).unsqueeze(dim=0))\n                elif hor<0:\n                    idx = (360+hor)//60\n                    vers_plus.append(torch.tensor(ver-self.fixed_elevation[idx],device=self.device).unsqueeze(dim=0))\n                    
hors_plus.append(torch.tensor(hor-self.fixed_azimuth[idx],device=self.device).unsqueeze(dim=0))\n                    radii_plus.append(torch.tensor(radius,device=self.device).unsqueeze(dim=0))\n\n                bg_color = torch.tensor([1, 1, 1] if np.random.rand() > self.opt.invert_bg_prob else [0, 0, 0], dtype=torch.float32, device=self.device)\n                out = self.renderer.render(cur_cam, bg_color=bg_color,stage=self.stage)\n                viewspace_point_tensor, visibility_filter, radii_rendering = out[\"viewspace_points\"], out[\"visibility_filter\"], out[\"radii\"]  \n                radii_list.append(radii_rendering.unsqueeze(0))\n                visibility_filter_list.append(visibility_filter.unsqueeze(0))\n                viewspace_point_tensor_list.append(viewspace_point_tensor)\n                image = out[\"image\"].unsqueeze(0)# [1, 3, H, W] in [0, 1]\n                images.append(image)\n                \n                \n            \n                images_render = torch.cat(images, dim=0)\n                #poses = torch.from_numpy(np.stack(poses, axis=0)).to(self.device)\n                vers_batch = torch.cat(vers_plus, dim=0).cpu().numpy()\n                hors_batch = torch.cat(hors_plus, dim=0).cpu().numpy()\n                radii_batch = torch.cat(radii_plus, dim=0).cpu().numpy()\n\n                # guidance loss\n                # the reference image depends on the sampled frame t and anchor view idx, so only one rendered image is passed to zero123 per guidance call\n                zero123_loss = self.opt.lambda_zero123 * self.guidance_zero123.train_step(images_render, vers_batch, hors_batch, radii_batch, step_ratio,idx=idx,t = self.t)\n                loss = loss + zero123_loss\n            \n            # TV and time-smoothness regularization\n            scales = out['scales']\n            tv_loss = self.renderer.gaussians.compute_regulation(self.opt.time_smoothness_weight, self.opt.plane_tv_weight, self.opt.l1_time_planes)\n            loss += self.opt.lambda_tv * tv_loss\n            \n            
# scale loss from physgaussian\n            r = self.opt.scale_loss_ratio\n            scale_loss = (torch.mean(torch.maximum(torch.max(scales,dim=1).values/ \\\n                                                  (torch.min(scales,dim=1).values+1e-8),\\\n                                                    torch.ones_like(torch.max(scales,dim=1).values)*r))-r) * scales.shape[0]\n            loss += scale_loss\n\n            # optimize step\n            loss.backward()\n            self.optimizer.step()\n            self.optimizer.zero_grad()\n\n            viewspace_point_tensor_grad = torch.zeros_like(viewspace_point_tensor)\n            for idx in range(0, len(viewspace_point_tensor_list)):\n                    viewspace_point_tensor_grad = viewspace_point_tensor_grad + viewspace_point_tensor_list[idx].grad\n            radii = torch.cat(radii_list,0).max(dim=0).values\n            visibility_filter = torch.cat(visibility_filter_list).any(dim=0)\n            if self.step >= self.opt.density_start_iter and self.step <= self.opt.density_end_iter:\n                self.renderer.gaussians.max_radii2D[visibility_filter] = torch.max(self.renderer.gaussians.max_radii2D[visibility_filter], radii[visibility_filter])\n                self.renderer.gaussians.add_densification_stats(viewspace_point_tensor_grad, visibility_filter)\n                if self.step % self.opt.densification_interval == 1 :\n\n                    self.renderer.gaussians.densify_and_prune(self.opt.densify_grad_threshold_percent, min_opacity=0.01, extent=1, max_screen_size=2)\n\n\n\n        ender.record()\n        torch.cuda.synchronize()\n        t = starter.elapsed_time(ender)\n\n        self.need_update = True\n\n        if self.gui:\n            dpg.set_value(\"_log_train_time\", f\"{t:.4f}ms\")\n            dpg.set_value(\n                \"_log_train_log\",\n                 f\"step = {self.step: 5d} (+{self.train_steps: 2d})\\n loss = {loss.item():.4f}\\nzero123_loss = 
{zero123_loss.item():.4f}\\nimage_loss = {image_loss.item():.4f}\\nloss_alpha = {alpha_loss.item():.4f}\\nscale_loss = {scale_loss.item():.4f}\",\n            )\n\n    def set_fix_cam(self):\n        self.fixed_cam_plus=[]\n        self.fixed_elevation = []\n        self.fixed_azimuth = []\n        \n        pose = orbit_camera(self.opt.elevation-30, 30, self.opt.radius)\n        self.fixed_elevation.append(-30)\n        self.fixed_azimuth.append(30)\n        self.fixed_cam_plus.append(MiniCam(\n                pose,\n                self.opt.ref_size,\n                self.opt.ref_size,\n                self.cam.fovy,\n                self.cam.fovx,\n                self.cam.near,\n                self.cam.far,\n            ))\n        \n        pose = orbit_camera(self.opt.elevation+20, 90, self.opt.radius)\n        self.fixed_elevation.append(20)\n        self.fixed_azimuth.append(90)\n        self.fixed_cam_plus.append(MiniCam(\n                pose,\n                self.opt.ref_size,\n                self.opt.ref_size,\n                self.cam.fovy,\n                self.cam.fovx,\n                self.cam.near,\n                self.cam.far,\n            ))\n        pose = orbit_camera(self.opt.elevation-30, 150, self.opt.radius)\n        self.fixed_elevation.append(-30)\n        self.fixed_azimuth.append(150)\n        self.fixed_cam_plus.append(MiniCam(\n                pose,\n                self.opt.ref_size,\n                self.opt.ref_size,\n                self.cam.fovy,\n                self.cam.fovx,\n                self.cam.near,\n                self.cam.far,\n            ))\n        \n        pose = orbit_camera(self.opt.elevation+20, 210, self.opt.radius)\n        self.fixed_elevation.append(20)\n        self.fixed_azimuth.append(210)\n        self.fixed_cam_plus.append(MiniCam(\n                pose,\n                self.opt.ref_size,\n                self.opt.ref_size,\n                self.cam.fovy,\n                self.cam.fovx,\n             
   self.cam.near,\n                self.cam.far,\n            ))\n        \n        pose = orbit_camera(self.opt.elevation-30, 270, self.opt.radius)\n        self.fixed_elevation.append(-30)\n        self.fixed_azimuth.append(270)\n        self.fixed_cam_plus.append(MiniCam(\n                pose,\n                self.opt.ref_size,\n                self.opt.ref_size,\n                self.cam.fovy,\n                self.cam.fovx,\n                self.cam.near,\n                self.cam.far,\n            ))\n        \n        pose = orbit_camera(self.opt.elevation+20, 330, self.opt.radius)\n        self.fixed_elevation.append(20)\n        self.fixed_azimuth.append(330)\n        self.fixed_cam_plus.append(MiniCam(\n                pose,\n                self.opt.ref_size,\n                self.opt.ref_size,\n                self.cam.fovy,\n                self.cam.fovx,\n                self.cam.near,\n                self.cam.far,\n            ))\n        \n    @torch.no_grad()\n    def test_step(self):\n        # ignore if no need to update\n        if not self.need_update:\n            return\n\n        starter = torch.cuda.Event(enable_timing=True)\n        ender = torch.cuda.Event(enable_timing=True)\n        starter.record()\n\n        # should update image\n        if self.need_update:\n            # render image\n\n            cur_cam = MiniCam(\n                self.cam.pose,\n                self.W,\n                self.H,\n                self.cam.fovy,\n                self.cam.fovx,\n                self.cam.near,\n                self.cam.far,\n                time=self.time\n            )\n            #print(cur_cam.time)\n            out = self.renderer.render(cur_cam, self.gaussain_scale_factor,stage=self.stage)\n\n            buffer_image = out[self.mode]  # [3, H, W]\n\n            if self.mode in ['depth', 'alpha']:\n                buffer_image = buffer_image.repeat(3, 1, 1)\n                if self.mode == 'depth':\n                    
buffer_image = (buffer_image - buffer_image.min()) / (buffer_image.max() - buffer_image.min() + 1e-20)\n\n            buffer_image = F.interpolate(\n                buffer_image.unsqueeze(0),\n                size=(self.H, self.W),\n                mode=\"bilinear\",\n                align_corners=False,\n            ).squeeze(0)\n\n            self.buffer_image = (\n                buffer_image.permute(1, 2, 0)\n                .contiguous()\n                .clamp(0, 1)\n                .contiguous()\n                .detach()\n                .cpu()\n                .numpy()\n            )\n\n            # display input_image\n            if self.overlay_input_img and self.input_img is not None:\n                self.buffer_image = (\n                    self.buffer_image * (1 - self.overlay_input_img_ratio)\n                    + self.input_img * self.overlay_input_img_ratio\n                )\n\n            self.need_update = False\n\n        ender.record()\n        torch.cuda.synchronize()\n        t = starter.elapsed_time(ender)\n\n        if self.gui:\n            dpg.set_value(\"_log_infer_time\", f\"{t:.4f}ms ({int(1000/t)} FPS)\")\n            dpg.set_value(\n                \"_texture\", self.buffer_image\n            )  # buffer must be contiguous, else seg fault!\n\n    def load_input(self, file):\n        # placeholder: loading an input image from the GUI file dialog is not implemented in this script\n        pass\n\n    @torch.no_grad()\n    def save_renderings(self, elev=0, azim=0, radius=2, name='front'):\n        images=[]\n        for i in np.arange(self.frames):\n            \n            pose = orbit_camera(elev, azim, radius)\n            cam = MiniCam(\n            pose,\n            self.opt.ref_size,\n            self.opt.ref_size,\n            self.cam.fovy,\n            self.cam.fovx,\n            self.cam.near,\n            self.cam.far,\n            )   \n            cam.time=float(i/self.frames)\n            out = self.renderer.render(cam,stage=self.stage)\n            image = 
out[\"image\"].unsqueeze(0)\n            images.append(image)\n            os.makedirs(f'./valid/{self.opt.save_path}/{self.step}_{name}',exist_ok=True)\n            save_image_to_local(image[0].detach(),f'./valid/{self.opt.save_path}/{self.step}_{name}/{str(i).zfill(2)}.jpg')\n        samples=torch.cat(images,dim=0)\n        \n        vid = (\n            (rearrange(samples, \"t c h w -> t h w c\") * 255).clamp(0,255).detach()\n            .cpu()\n            .numpy()\n            .astype(np.uint8)\n        )\n        video_path = f'./valid/{self.opt.save_path}/{self.step}_{name}/video.mp4'\n        imageio.mimwrite(video_path, vid)\n\n\n    @torch.no_grad()\n    def save_model(self, mode='geo', texture_size=1024):\n        os.makedirs(self.opt.outdir, exist_ok=True)\n        if mode == 'geo':\n            path = os.path.join(self.opt.outdir, self.opt.save_path + '_model.ply')\n            self.renderer.gaussians.save_ply(path)\n\n        elif mode == 'geo+tex':\n            path = os.path.join(self.opt.outdir, self.opt.save_path + '_model.ply')\n            self.renderer.gaussians.save_ply(path)\n\n        else:\n            path = os.path.join(self.opt.outdir, self.opt.save_path + '_model.ply')\n            self.renderer.gaussians.save_ply(path)\n\n        print(f\"[INFO] save model to {path}.\")\n\n    def register_dpg(self):\n        ### register texture\n\n        with dpg.texture_registry(show=False):\n            dpg.add_raw_texture(\n                self.W,\n                self.H,\n                self.buffer_image,\n                format=dpg.mvFormat_Float_rgb,\n                tag=\"_texture\",\n            )\n\n        ### register window\n\n        # the rendered image, as the primary window\n        with dpg.window(\n            tag=\"_primary_window\",\n            width=self.W,\n            height=self.H,\n            pos=[0, 0],\n            no_move=True,\n            no_title_bar=True,\n            no_scrollbar=True,\n        ):\n            # 
add the texture\n            dpg.add_image(\"_texture\")\n\n        # dpg.set_primary_window(\"_primary_window\", True)\n\n        # control window\n        with dpg.window(\n            label=\"Control\",\n            tag=\"_control_window\",\n            width=600,\n            height=self.H,\n            pos=[self.W, 0],\n            no_move=True,\n            no_title_bar=True,\n        ):\n            # button theme\n            with dpg.theme() as theme_button:\n                with dpg.theme_component(dpg.mvButton):\n                    dpg.add_theme_color(dpg.mvThemeCol_Button, (23, 3, 18))\n                    dpg.add_theme_color(dpg.mvThemeCol_ButtonHovered, (51, 3, 47))\n                    dpg.add_theme_color(dpg.mvThemeCol_ButtonActive, (83, 18, 83))\n                    dpg.add_theme_style(dpg.mvStyleVar_FrameRounding, 5)\n                    dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 3, 3)\n\n            # timer stuff\n            with dpg.group(horizontal=True):\n                dpg.add_text(\"Infer time: \")\n                dpg.add_text(\"no data\", tag=\"_log_infer_time\")\n\n            def callback_setattr(sender, app_data, user_data):\n                setattr(self, user_data, app_data)\n\n            # init stuff\n            with dpg.collapsing_header(label=\"Initialize\", default_open=True):\n\n                # seed stuff\n                def callback_set_seed(sender, app_data):\n                    self.seed = app_data\n                    self.seed_everything()\n\n                dpg.add_input_text(\n                    label=\"seed\",\n                    default_value=self.seed,\n                    on_enter=True,\n                    callback=callback_set_seed,\n                )\n\n                # input stuff\n                def callback_select_input(sender, app_data):\n                    # only one item\n                    for k, v in app_data[\"selections\"].items():\n                        dpg.set_value(\"_log_input\", 
k)\n                        self.load_input(v)\n\n                    self.need_update = True\n\n                with dpg.file_dialog(\n                    directory_selector=False,\n                    show=False,\n                    callback=callback_select_input,\n                    file_count=1,\n                    tag=\"file_dialog_tag\",\n                    width=700,\n                    height=400,\n                ):\n                    dpg.add_file_extension(\"Images{.jpg,.jpeg,.png}\")\n\n                with dpg.group(horizontal=True):\n                    dpg.add_button(\n                        label=\"input\",\n                        callback=lambda: dpg.show_item(\"file_dialog_tag\"),\n                    )\n                    dpg.add_text(\"\", tag=\"_log_input\")\n                \n                # overlay stuff\n                with dpg.group(horizontal=True):\n\n                    def callback_toggle_overlay_input_img(sender, app_data):\n                        self.overlay_input_img = not self.overlay_input_img\n                        self.need_update = True\n\n                    dpg.add_checkbox(\n                        label=\"overlay image\",\n                        default_value=self.overlay_input_img,\n                        callback=callback_toggle_overlay_input_img,\n                    )\n\n                    def callback_set_overlay_input_img_ratio(sender, app_data):\n                        self.overlay_input_img_ratio = app_data\n                        self.need_update = True\n\n                    dpg.add_slider_float(\n                        label=\"ratio\",\n                        min_value=0,\n                        max_value=1,\n                        format=\"%.1f\",\n                        default_value=self.overlay_input_img_ratio,\n                        callback=callback_set_overlay_input_img_ratio,\n                    )\n\n                # prompt stuff\n            \n                
dpg.add_input_text(\n                    label=\"prompt\",\n                    default_value=self.prompt,\n                    callback=callback_setattr,\n                    user_data=\"prompt\",\n                )\n\n                dpg.add_input_text(\n                    label=\"negative\",\n                    default_value=self.negative_prompt,\n                    callback=callback_setattr,\n                    user_data=\"negative_prompt\",\n                )\n\n                # save current model\n                with dpg.group(horizontal=True):\n                    dpg.add_text(\"Save: \")\n\n                    def callback_save(sender, app_data, user_data):\n                        self.save_model(mode=user_data)\n\n                    dpg.add_button(\n                        label=\"model\",\n                        tag=\"_button_save_model\",\n                        callback=callback_save,\n                        user_data='model',\n                    )\n                    dpg.bind_item_theme(\"_button_save_model\", theme_button)\n\n                    dpg.add_button(\n                        label=\"geo\",\n                        tag=\"_button_save_mesh\",\n                        callback=callback_save,\n                        user_data='geo',\n                    )\n                    dpg.bind_item_theme(\"_button_save_mesh\", theme_button)\n\n                    dpg.add_button(\n                        label=\"geo+tex\",\n                        tag=\"_button_save_mesh_with_tex\",\n                        callback=callback_save,\n                        user_data='geo+tex',\n                    )\n                    dpg.bind_item_theme(\"_button_save_mesh_with_tex\", theme_button)\n\n                    dpg.add_input_text(\n                        label=\"\",\n                        default_value=self.opt.save_path,\n                        callback=callback_setattr,\n                        user_data=\"save_path\",\n                    
)\n\n            # training stuff\n            with dpg.collapsing_header(label=\"Train\", default_open=True):\n                # lr and train button\n                with dpg.group(horizontal=True):\n                    dpg.add_text(\"Train: \")\n\n                    def callback_train(sender, app_data):\n                        if self.training:\n                            self.training = False\n                            dpg.configure_item(\"_button_train\", label=\"start\")\n                        else:\n                            self.prepare_train()\n                            self.training = True\n                            dpg.configure_item(\"_button_train\", label=\"stop\")\n\n                    # dpg.add_button(\n                    #     label=\"init\", tag=\"_button_init\", callback=self.prepare_train\n                    # )\n                    # dpg.bind_item_theme(\"_button_init\", theme_button)\n\n                    dpg.add_button(\n                        label=\"start\", tag=\"_button_train\", callback=callback_train\n                    )\n                    dpg.bind_item_theme(\"_button_train\", theme_button)\n\n                with dpg.group(horizontal=True):\n                    dpg.add_text(\"\", tag=\"_log_train_time\")\n                    dpg.add_text(\"\", tag=\"_log_train_log\")\n\n            # rendering options\n            with dpg.collapsing_header(label=\"Rendering\", default_open=True):\n                # mode combo\n                def callback_change_mode(sender, app_data):\n                    self.mode = app_data\n                    self.need_update = True\n\n                dpg.add_combo(\n                    (\"image\", \"depth\", \"alpha\"),\n                    label=\"mode\",\n                    default_value=self.mode,\n                    callback=callback_change_mode,\n                )\n\n                # fov slider\n                def callback_set_fovy(sender, app_data):\n                    
self.cam.fovy = np.deg2rad(app_data)\n                    self.need_update = True\n\n                dpg.add_slider_int(\n                    label=\"FoV (vertical)\",\n                    min_value=1,\n                    max_value=120,\n                    format=\"%d deg\",\n                    default_value=np.rad2deg(self.cam.fovy),\n                    callback=callback_set_fovy,\n                )\n\n                def callback_set_gaussain_scale(sender, app_data):\n                    self.gaussain_scale_factor = app_data\n                    self.need_update = True\n\n                dpg.add_slider_float(\n                    label=\"gaussain scale\",\n                    min_value=0,\n                    max_value=1,\n                    format=\"%.2f\",\n                    default_value=self.gaussain_scale_factor,\n                    callback=callback_set_gaussain_scale,\n                )\n\n        ### register camera handler\n\n        def callback_camera_drag_rotate_or_draw_mask(sender, app_data):\n            if not dpg.is_item_focused(\"_primary_window\"):\n                return\n\n            dx = app_data[1]\n            dy = app_data[2]\n\n            self.cam.orbit(dx, dy)\n            self.need_update = True\n\n        def callback_camera_wheel_scale(sender, app_data):\n            if not dpg.is_item_focused(\"_primary_window\"):\n                return\n\n            delta = app_data\n\n            self.cam.scale(delta)\n            self.need_update = True\n\n        def callback_camera_drag_pan(sender, app_data):\n            if not dpg.is_item_focused(\"_primary_window\"):\n                return\n\n            dx = app_data[1]\n            dy = app_data[2]\n\n            self.cam.pan(dx, dy)\n            self.need_update = True\n\n        def callback_set_mouse_loc(sender, app_data):\n            if not dpg.is_item_focused(\"_primary_window\"):\n                return\n\n            # just the pixel coordinate in image\n            
self.mouse_loc = np.array(app_data)\n\n        with dpg.handler_registry():\n            # for camera moving\n            dpg.add_mouse_drag_handler(\n                button=dpg.mvMouseButton_Left,\n                callback=callback_camera_drag_rotate_or_draw_mask,\n            )\n            dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale)\n            dpg.add_mouse_drag_handler(\n                button=dpg.mvMouseButton_Middle, callback=callback_camera_drag_pan\n            )\n\n        dpg.create_viewport(\n            title=\"Gaussian3D\",\n            width=self.W + 600,\n            height=self.H + (45 if os.name == \"nt\" else 0),\n            resizable=False,\n        )\n\n        ### global theme\n        with dpg.theme() as theme_no_padding:\n            with dpg.theme_component(dpg.mvAll):\n                # set all padding to 0 to avoid scroll bar\n                dpg.add_theme_style(\n                    dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core\n                )\n                dpg.add_theme_style(\n                    dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core\n                )\n                dpg.add_theme_style(\n                    dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core\n                )\n\n        dpg.bind_item_theme(\"_primary_window\", theme_no_padding)\n\n        dpg.setup_dearpygui()\n\n        ### register a larger font\n        # get it from: https://github.com/lxgw/LxgwWenKai/releases/download/v1.300/LXGWWenKai-Regular.ttf\n        if os.path.exists(\"LXGWWenKai-Regular.ttf\"):\n            with dpg.font_registry():\n                with dpg.font(\"LXGWWenKai-Regular.ttf\", 18) as default_font:\n                    dpg.bind_font(default_font)\n\n        # dpg.show_metrics()\n\n        dpg.show_viewport()\n\n    def render(self):\n        assert self.gui\n        while dpg.is_dearpygui_running():\n            # update texture every frame\n            
if self.training:\n                self.train_step()\n            self.test_step()\n            dpg.render_dearpygui_frame()\n\n    # no gui mode\n    def train(self, iters=500):\n        if iters > 0:\n            self.prepare_train()\n            for i in tqdm.trange(iters):\n                self.train_step()\n            # optional final prune (currently disabled)\n            # self.renderer.gaussians.prune(min_opacity=0.01, extent=1, max_screen_size=1)\n\n        # save\n        self.save_model(mode='model')\n        self.save_model(mode='geo+tex')\n\n\nif __name__ == \"__main__\":\n    import argparse\n    from omegaconf import OmegaConf\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--config\", required=True, help=\"path to the yaml config file\")\n    args, extras = parser.parse_known_args()\n\n    # override default config from cli\n    opt = OmegaConf.merge(OmegaConf.load(args.config), OmegaConf.from_cli(extras))\n\n    gui = GUI(opt)\n\n    if opt.gui:\n        gui.render()\n    else:\n        gui.train(opt.save_step+1)"
  },
  {
    "path": "mini_trainer.ipynb",
    "content": "{\n \"cells\": [\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": [\n    \"from gs_renderer_4d import Renderer, MiniCam\\n\",\n    \"from dataset_4d import SparseDataset\\n\",\n    \"import os\\n\",\n    \"import tqdm\\n\",\n    \"import numpy as np\\n\",\n    \"import torch\\n\",\n    \"\\n\",\n    \"from cam_utils import orbit_camera, OrbitCamera\\n\",\n    \"from guidance.sd_utils import StableDiffusion\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"class trainer:\\n\",\n    \"    def __init__(self,opt) -> None:\\n\",\n    \"        \\n\",\n    \"        #initialize options\\n\",\n    \"        self.opt=opt\\n\",\n    \"        self.device=self.opt.device\\n\",\n    \"        \\n\",\n    \"        #initialize renderer and gaussians\\n\",\n    \"        self.renderer = Renderer(sh_degree=self.opt.sh_degree)\\n\",\n    \"        self.renderer.initialize(num_pts=self.opt.num_pts)   \\n\",\n    \"        self.renderer.gaussians.training_setup(self.opt)\\n\",\n    \"        \\n\",\n    \"        self.optimizer = self.renderer.gaussians.optimizer\\n\",\n    \"        \\n\",\n    \"        self.cam = OrbitCamera(opt.W, opt.H, r=opt.radius, fovy=opt.fovy)\\n\",\n    \"        \\n\",\n    \"        #initialize sd. 
replace with your own diffusion model if necessary.\\n\",\n    \"        self.enable_sd = True\\n\",\n    \"        self.guidance_sd = StableDiffusion(self.device)\\n\",\n    \"        self.guidance_sd.get_text_embeds([self.opt.prompt],negative_prompts= [''])\\n\",\n    \"        \\n\",\n    \"    def save(self,save_path):\\n\",\n    \"        #save \\n\",\n    \"        auto_path = save_path\\n\",\n    \"        os.makedirs(auto_path,exist_ok=True)\\n\",\n    \"        ply_path = os.path.join(auto_path,'model.ply')\\n\",\n    \"        self.renderer.gaussians.save_ply(ply_path)\\n\",\n    \"        self.renderer.gaussians.save_deformation(auto_path)\\n\",\n    \"        \\n\",\n    \"    def load(self, load_path):\\n\",\n    \"        #load\\n\",\n    \"        auto_path = load_path\\n\",\n    \"        ply_path = os.path.join(auto_path,'model.ply')\\n\",\n    \"        self.renderer.gaussians.load_model(auto_path)\\n\",\n    \"        self.renderer.gaussians.load_ply(ply_path)\\n\",\n    \"           \\n\",\n    \"           \\n\",\n    \"    def render(self,frame_id, elevation, azimuth, radius):\\n\",\n    \"        #render with parameters\\n\",\n    \"        pose = orbit_camera(elevation,azimuth,radius)\\n\",\n    \"        cam = MiniCam(\\n\",\n    \"                        pose,\\n\",\n    \"                        self.opt.ref_size,\\n\",\n    \"                        self.opt.ref_size,\\n\",\n    \"                        self.cam.fovy,\\n\",\n    \"                        self.cam.fovx,\\n\",\n    \"                        self.cam.near,\\n\",\n    \"                        self.cam.far,\\n\",\n    \"                        )   \\n\",\n    \"        cam.time=float(frame_id/30) #30 is the total frame\\n\",\n    \"        #use stage='coarse' for static rendering, use stage='fine' for dynamic rendering\\n\",\n    \"        out = self.renderer.render(cam,stage='fine')\\n\",\n    \"        image = out[\\\"image\\\"].unsqueeze(0)# [1, 3, H, W] in [0, 
1]\\n\",\n    \"        \\n\",\n    \"        return image\\n\",\n    \"    \\n\",\n    \"    def train(self):\\n\",\n    \"        self.step=0\\n\",\n    \"        \\n\",\n    \"        for i in tqdm.tqdm(range(10000)):\\n\",\n    \"            self.step+=1\\n\",\n    \"            self.renderer.gaussians.update_learning_rate(self.step)\\n\",\n    \"            loss = 0\\n\",\n    \"            \\n\",\n    \"            min_ver = -30\\n\",\n    \"            max_ver = 30\\n\",\n    \"            vers, hors, radiis, poses = [], [], [], []\\n\",\n    \"            images=[]\\n\",\n    \"            viewspace_point_tensor_list, radii_list, visibility_filter_list = [], [], []\\n\",\n    \"\\n\",\n    \"            render_resolution=512\\n\",\n    \"            \\n\",\n    \"            for _ in range(self.opt.batch_size):\\n\",\n    \"                #sample time, vertical& horizontal  angle\\n\",\n    \"                ver = np.random.randint(min_ver, max_ver)\\n\",\n    \"                hor = np.random.randint(-180, 180)\\n\",\n    \"                radius=0\\n\",\n    \"                self.t = np.random.randint(0,30)\\n\",\n    \"                self.time = self.t/30\\n\",\n    \"                \\n\",\n    \"                vers.append(torch.tensor(self.opt.elevation + ver,device=self.device).unsqueeze(dim=0))\\n\",\n    \"                hors.append(torch.tensor(hor,device=self.device).unsqueeze(dim=0))\\n\",\n    \"                radiis.append(torch.tensor(self.opt.radius + radius,device=self.device).unsqueeze(dim=0))\\n\",\n    \"                \\n\",\n    \"                pose = orbit_camera(self.opt.elevation + ver, hor, self.opt.radius + radius)\\n\",\n    \"                \\n\",\n    \"                poses.append(pose)\\n\",\n    \"\\n\",\n    \"\\n\",\n    \"                cur_cam = MiniCam(\\n\",\n    \"                        pose,\\n\",\n    \"                        render_resolution,\\n\",\n    \"                        
render_resolution,\\n\",\n    \"                        self.cam.fovy,\\n\",\n    \"                        self.cam.fovx,\\n\",\n    \"                        self.cam.near,\\n\",\n    \"                        self.cam.far,\\n\",\n    \"                    )\\n\",\n    \"                cur_cam.time=self.time\\n\",\n    \"                \\n\",\n    \"                bg_color = torch.tensor([1, 1, 1] if np.random.rand() > self.opt.invert_bg_prob else [0, 0, 0], dtype=torch.float32, device=\\\"cuda\\\")\\n\",\n    \"                #use stage='coarse' for static rendering, use stage='fine' for dynamic rendering\\n\",\n    \"                out = self.renderer.render(cur_cam, bg_color=bg_color,stage='fine')\\n\",\n    \"                \\n\",\n    \"                #basic values for densification\\n\",\n    \"                viewspace_point_tensor, visibility_filter, radii = out[\\\"viewspace_points\\\"], out[\\\"visibility_filter\\\"], out[\\\"radii\\\"]  \\n\",\n    \"                radii_list.append(radii.unsqueeze(0))\\n\",\n    \"                visibility_filter_list.append(visibility_filter.unsqueeze(0))\\n\",\n    \"                viewspace_point_tensor_list.append(viewspace_point_tensor)\\n\",\n    \"                \\n\",\n    \"                image = out[\\\"image\\\"].unsqueeze(0)# [1, 3, H, W] in [0, 1]\\n\",\n    \"                images.append(image)\\n\",\n    \"                \\n\",\n    \"            images_batch = torch.cat(images, dim=0)\\n\",\n    \"            poses = torch.from_numpy(np.stack(poses, axis=0)).to(self.device)\\n\",\n    \"            vers_batch = torch.cat(vers, dim=0).cpu().numpy()\\n\",\n    \"            hors_batch = torch.cat(hors, dim=0).cpu().numpy()\\n\",\n    \"            radii_batch = torch.cat(radiis, dim=0).cpu().numpy()\\n\",\n    \"\\n\",\n    \"            if self.enable_sd:\\n\",\n    \"                sd_loss = self.guidance_sd.train_step(images_batch,step_ratio=None,poses=poses)\\n\",\n    \"               
 # guidance loss. replace with your own diffusion model if necessary.\\n\",\n    \"                loss = loss + sd_loss\\n\",\n    \"            else:\\n\",\n    \"                zero123_loss = self.guidance_zero123.train_step(images_batch, vers_batch, hors_batch, radii_batch,step_ratio=None)\\n\",\n    \"                # guidance loss.\\n\",\n    \"                loss = loss + zero123_loss\\n\",\n    \"                \\n\",\n    \"            # optimize step\\n\",\n    \"            loss.backward()\\n\",\n    \"            self.optimizer.step()\\n\",\n    \"            self.optimizer.zero_grad()\\n\",\n    \"\\n\",\n    \"            #densifications. Adaptive densification is used here.\\n\",\n    \"            viewspace_point_tensor_grad = torch.zeros_like(viewspace_point_tensor)\\n\",\n    \"            for idx in range(0, len(viewspace_point_tensor_list)):\\n\",\n    \"                    viewspace_point_tensor_grad = viewspace_point_tensor_grad + viewspace_point_tensor_list[idx].grad\\n\",\n    \"\\n\",\n    \"            if self.step >= self.opt.density_start_iter and self.step <= self.opt.density_end_iter:\\n\",\n    \"                self.renderer.gaussians.max_radii2D[visibility_filter] = torch.max(self.renderer.gaussians.max_radii2D[visibility_filter], radii[visibility_filter])\\n\",\n    \"                self.renderer.gaussians.add_densification_stats(viewspace_point_tensor_grad, visibility_filter)\\n\",\n    \"                if self.step % self.opt.densification_interval == 0 :\\n\",\n    \"\\n\",\n    \"                    self.renderer.gaussians.densify_and_prune(self.opt.densify_grad_threshold, min_opacity=0.01, extent=1, max_screen_size=2)\\n\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 2,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"name\": \"stdout\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"feature_dim: 128\\n\",\n      \"Number of points at initialisation :  10000\\n\"\n     ]\n  
  },\n    {\n     \"data\": {\n      \"application/vnd.jupyter.widget-view+json\": {\n       \"model_id\": \"a4cb52b5dc0045ccba1d352fc337cbd0\",\n       \"version_major\": 2,\n       \"version_minor\": 0\n      },\n      \"text/plain\": [\n       \"Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]\"\n      ]\n     },\n     \"metadata\": {},\n     \"output_type\": \"display_data\"\n    },\n    {\n     \"name\": \"stderr\",\n     \"output_type\": \"stream\",\n     \"text\": [\n      \"  1%|          | 72/10000 [00:19<44:48,  3.69it/s]  \\n\"\n     ]\n    }\n   ],\n   \"source\": [\n    \"from omegaconf import OmegaConf\\n\",\n    \"\\n\",\n    \"opt=OmegaConf.load('./configs/image_4d_m.yaml')\\n\",\n    \"\\n\",\n    \"train=trainer(opt)\\n\",\n    \"train.train()\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": 1,\n   \"metadata\": {},\n   \"outputs\": [\n    {\n     \"data\": {\n      \"text/plain\": [\n       \"0.0\"\n      ]\n     },\n     \"execution_count\": 1,\n     \"metadata\": {},\n     \"output_type\": \"execute_result\"\n    }\n   ],\n   \"source\": [\n    \"import numpy as np\\n\",\n    \"np.deg2rad(0)\"\n   ]\n  },\n  {\n   \"cell_type\": \"code\",\n   \"execution_count\": null,\n   \"metadata\": {},\n   \"outputs\": [],\n   \"source\": []\n  }\n ],\n \"metadata\": {\n  \"kernelspec\": {\n   \"display_name\": \"torch0\",\n   \"language\": \"python\",\n   \"name\": \"python3\"\n  },\n  \"language_info\": {\n   \"codemirror_mode\": {\n    \"name\": \"ipython\",\n    \"version\": 3\n   },\n   \"file_extension\": \".py\",\n   \"mimetype\": \"text/x-python\",\n   \"name\": \"python\",\n   \"nbconvert_exporter\": \"python\",\n   \"pygments_lexer\": 
\"ipython3\",\n   \"version\": \"3.8.17\"\n  }\n },\n \"nbformat\": 4,\n \"nbformat_minor\": 2\n}\n"
  },
  {
    "path": "requirements.txt",
    "content": "tqdm\nrich\nninja\nnumpy\npandas\nscipy\nscikit-learn\nmatplotlib\nopencv-python\nimageio\nimageio-ffmpeg\nomegaconf\nargparse\ntorch\neinops\nplyfile\npygltflib\ndearpygui\naccelerate\nrembg[gpu,cli]\n\n#zero123plus\nopencv-contrib-python\ndiffusers==0.20.2\ntransformers==4.29.2\nstreamlit==1.22.0\naltair<5\nhuggingface_hub\ngit+https://github.com/facebookresearch/segment-anything.git\ngradio>=3.50\nfire"
  },
  {
    "path": "scripts/app.py",
    "content": "import os\nimport sys\nimport numpy\nimport torch\nimport rembg\nimport threading\nimport urllib.request\nfrom PIL import Image\nimport streamlit as st\nimport huggingface_hub\n\nclass SAMAPI:\n    predictor = None\n\n    @staticmethod\n    @st.cache_resource\n    def get_instance(sam_checkpoint=None):\n        if SAMAPI.predictor is None:\n            if sam_checkpoint is None:\n                sam_checkpoint = \"tmp/sam_vit_h_4b8939.pth\"\n            if not os.path.exists(sam_checkpoint):\n                os.makedirs('tmp', exist_ok=True)\n                urllib.request.urlretrieve(\n                    \"https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth\",\n                    sam_checkpoint\n                )\n            device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n            model_type = \"default\"\n\n            from segment_anything import sam_model_registry, SamPredictor\n\n            sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)\n            sam.to(device=device)\n\n            predictor = SamPredictor(sam)\n            SAMAPI.predictor = predictor\n        return SAMAPI.predictor\n\n    @staticmethod\n    def segment_api(rgb, mask=None, bbox=None, sam_checkpoint=None):\n        \"\"\"\n\n        Parameters\n        ----------\n        rgb : np.ndarray h,w,3 uint8\n        mask: np.ndarray h,w bool\n\n        Returns\n        -------\n\n        \"\"\"\n        np = numpy\n        predictor = SAMAPI.get_instance(sam_checkpoint)\n        predictor.set_image(rgb)\n        if mask is None and bbox is None:\n            box_input = None\n        else:\n            # mask to bbox\n            if bbox is None:\n                y1, y2, x1, x2 = np.nonzero(mask)[0].min(), np.nonzero(mask)[0].max(), np.nonzero(mask)[1].min(), \\\n                                 np.nonzero(mask)[1].max()\n            else:\n                x1, y1, x2, y2 = bbox\n            box_input = np.array([[x1, y1, x2, 
y2]])\n        masks, scores, logits = predictor.predict(\n            box=box_input,\n            multimask_output=True,\n            return_logits=False,\n        )\n        mask = masks[-1]\n        return mask\n\n\n# Counter used to give every example button a unique Streamlit widget key.\nimg_example_counter = 0\n\n\ndef image_examples(samples, ncols, return_key=None, example_text=\"Examples\"):\n    global img_example_counter\n    trigger = False\n    with st.expander(example_text, True):\n        # Round up so that a final, partially filled row is still displayed.\n        for i in range((len(samples) + ncols - 1) // ncols):\n            cols = st.columns(ncols)\n            for j in range(ncols):\n                idx = i * ncols + j\n                if idx >= len(samples):\n                    continue\n                entry = samples[idx]\n                with cols[j]:\n                    st.image(entry['dispi'])\n                    img_example_counter += 1\n                    with st.columns(5)[2]:\n                        this_trigger = st.button('\\+', key='imgexuse%d' % img_example_counter)\n                    trigger = trigger or this_trigger\n                    if this_trigger:\n                        trigger = entry[return_key]\n    return trigger\n\n\ndef segment_img(img: Image):\n    output = rembg.remove(img)\n    mask = numpy.array(output)[:, :, 3] > 0\n    sam_mask = SAMAPI.segment_api(numpy.array(img)[:, :, :3], mask)\n    segmented_img = Image.new(\"RGBA\", img.size, (0, 0, 0, 0))\n    segmented_img.paste(img, mask=Image.fromarray(sam_mask))\n    return segmented_img\n\n\ndef segment_6imgs(zero123pp_imgs):\n    imgs = [zero123pp_imgs.crop([0, 0, 320, 320]),\n            zero123pp_imgs.crop([320, 0, 640, 320]),\n            zero123pp_imgs.crop([0, 320, 320, 640]),\n            zero123pp_imgs.crop([320, 320, 640, 640]),\n            zero123pp_imgs.crop([0, 640, 320, 960]),\n            zero123pp_imgs.crop([320, 640, 640, 960])]\n    segmented_imgs = []\n    for i, img in 
enumerate(imgs):\n        output = rembg.remove(img)\n        mask = numpy.array(output)[:, :, 3]\n        mask = SAMAPI.segment_api(numpy.array(img)[:, :, :3], 
mask)\n        data = numpy.array(img)[:,:,:3]\n        data[mask == 0] = [255, 255, 255]\n        segmented_imgs.append(data)\n    result = numpy.concatenate([\n        numpy.concatenate([segmented_imgs[0], segmented_imgs[1]], axis=1),\n        numpy.concatenate([segmented_imgs[2], segmented_imgs[3]], axis=1),\n        numpy.concatenate([segmented_imgs[4], segmented_imgs[5]], axis=1)\n    ])\n    return Image.fromarray(result)\n\n\ndef expand2square(pil_img, background_color):\n    width, height = pil_img.size\n    if width == height:\n        return pil_img\n    elif width > height:\n        result = Image.new(pil_img.mode, (width, width), background_color)\n        result.paste(pil_img, (0, (width - height) // 2))\n        return result\n    else:\n        result = Image.new(pil_img.mode, (height, height), background_color)\n        result.paste(pil_img, ((height - width) // 2, 0))\n        return result\n\n\n@st.cache_data\ndef check_dependencies():\n    reqs = []\n    try:\n        import diffusers\n    except ImportError:\n        import traceback\n        traceback.print_exc()\n        print(\"Error: `diffusers` not found.\", file=sys.stderr)\n        reqs.append(\"diffusers==0.20.2\")\n    else:\n        if not diffusers.__version__.startswith(\"0.20\"):\n            print(\n                f\"Warning: You are using an unsupported version of diffusers ({diffusers.__version__}), which may lead to performance issues.\",\n                file=sys.stderr\n            )\n            print(\"Recommended version is `diffusers==0.20.2`.\", file=sys.stderr)\n    try:\n        import transformers\n    except ImportError:\n        import traceback\n        traceback.print_exc()\n        print(\"Error: `transformers` not found.\", file=sys.stderr)\n        reqs.append(\"transformers==4.29.2\")\n    if torch.__version__ < '2.0':\n        try:\n            import xformers\n        except ImportError:\n            print(\"Warning: You are using PyTorch 1.x without a 
working `xformers` installation.\", file=sys.stderr)\n            print(\"You may see a significant memory overhead when running the model.\", file=sys.stderr)\n    if len(reqs):\n        print(f\"Info: Fix all dependency errors with `pip install {' '.join(reqs)}`.\")\n\n\n@st.cache_resource\ndef load_zero123plus_pipeline():\n    if 'HF_TOKEN' in os.environ:\n        huggingface_hub.login(os.environ['HF_TOKEN'])\n    from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler\n    pipeline = DiffusionPipeline.from_pretrained(\n        \"sudo-ai/zero123plus-v1.1\", custom_pipeline=\"sudo-ai/zero123plus-pipeline\",\n        torch_dtype=torch.float16\n    )\n    # Feel free to tune the scheduler\n    pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(\n        pipeline.scheduler.config, timestep_spacing='trailing'\n    )\n    if torch.cuda.is_available():\n        pipeline.to('cuda:0')\n    sys.main_lock = threading.Lock()\n    return pipeline\n"
  },
  {
    "path": "scripts/gen_mv.py",
    "content": "import os\n\nimport numpy\nimport rembg\nimport torch\nfrom PIL import Image\nfrom diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler\n\n# app.py sits next to this script and provides the SAM wrapper\nfrom app import SAMAPI\n\n\ndef segment_img(img: Image):\n    output = rembg.remove(img)\n    mask = numpy.array(output)[:, :, 3] > 0\n    sam_mask = SAMAPI.segment_api(numpy.array(img)[:, :, :3], mask)\n    segmented_img = Image.new(\"RGBA\", img.size, (0, 0, 0, 0))\n    segmented_img.paste(img, mask=Image.fromarray(sam_mask))\n    return segmented_img\n\n\ndef segment_6imgs(zero123pp_imgs):\n    # Split the 640x960 Zero123++ output grid (2 columns x 3 rows) into six 320x320 views.\n    imgs = [zero123pp_imgs.crop([0, 0, 320, 320]),\n            zero123pp_imgs.crop([320, 0, 640, 320]),\n            zero123pp_imgs.crop([0, 320, 320, 640]),\n            zero123pp_imgs.crop([320, 320, 640, 640]),\n            zero123pp_imgs.crop([0, 640, 320, 960]),\n            zero123pp_imgs.crop([320, 640, 640, 960])]\n    segmented_imgs = []\n    for img in imgs:\n        # Coarse foreground mask from rembg, refined with SAM.\n        output = rembg.remove(img)\n        mask = numpy.array(output)[:, :, 3]\n        mask = SAMAPI.segment_api(numpy.array(img)[:, :, :3], mask)\n        data = numpy.array(img)[:, :, :3]\n        # Build an RGBA view: alpha is 255 on the SAM foreground and 1 elsewhere.\n        data2 = numpy.ones([320, 320, 4])\n        data2[:, :, :3] = data\n        data2[mask == 1, 3] = 255\n        segmented_imgs.append(data2)\n    return segmented_imgs\n\n\ndef process_img(path, destination, pipeline, is_first):\n    print('processing:', path)\n    # Load the conditioning frame.\n    cond = Image.open(path)\n    # Run the pipeline. For general real and synthetic objects, around 28\n    # inference steps is usually enough; images with delicate details such as\n    # faces (real or anime) may need 75-100 steps for the details to form.\n    result = 
pipeline(cond, num_inference_steps=75, is_first=is_first).images[0]\n    result = segment_6imgs(result)\n    print('saving:', os.path.join(destination, '0~5.png'), 'in', destination)\n    for i in numpy.arange(6):\n        Image.fromarray(numpy.uint8(result[i])).save(os.path.join(destination, '{}.png'.format(i)))\n\n\nif __name__ == \"__main__\":\n    import argparse\n    import glob\n    import shutil\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--path\", required=True, help=\"path to process\")\n    # DiffusionPipeline.from_pretrained cannot receive a relative path for a custom pipeline\n    parser.add_argument(\"--pipeline_path\", required=True, help=\"path of pipeline code, in ../guidance/zero123pp\")\n    args, extras = parser.parse_known_args()\n\n    pipeline = DiffusionPipeline.from_pretrained(\n        \"sudo-ai/zero123plus-v1.1\", custom_pipeline=args.pipeline_path,\n        torch_dtype=torch.float16\n    )\n    pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(\n        pipeline.scheduler.config, timestep_spacing='trailing'\n    )\n    pipeline.to('cuda:0')\n\n    directory = args.path + '/'\n    # Copy the input frames into <path>/ref/, then generate six views per frame.\n    os.makedirs(directory + 'ref', exist_ok=True)\n    for png in glob.glob(directory + '*.png'):\n        shutil.copy(png, directory + 'ref/')\n    is_first = True\n    for file in sorted(os.listdir(directory + 'ref')):\n        if file.endswith('.png'):\n            filename = os.path.splitext(os.path.basename(file))[0]\n            destination = os.path.join(directory + 'zero123', filename)\n            os.makedirs(destination, exist_ok=True)\n            img_path = os.path.join(directory + 'ref', file)\n            process_img(img_path, destination, pipeline, is_first)\n            is_first = False\n"
  },
  {
    "path": "sh_utils.py",
    "content": "#  Copyright 2021 The PlenOctree Authors.\n#  Redistribution and use in source and binary forms, with or without\n#  modification, are permitted provided that the following conditions are met:\n#\n#  1. Redistributions of source code must retain the above copyright notice,\n#  this list of conditions and the following disclaimer.\n#\n#  2. Redistributions in binary form must reproduce the above copyright notice,\n#  this list of conditions and the following disclaimer in the documentation\n#  and/or other materials provided with the distribution.\n#\n#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\"\n#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE\n#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE\n#  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE\n#  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR\n#  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF\n#  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS\n#  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN\n#  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)\n#  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE\n#  POSSIBILITY OF SUCH DAMAGE.\n\nimport torch\n\nC0 = 0.28209479177387814\nC1 = 0.4886025119029199\nC2 = [\n    1.0925484305920792,\n    -1.0925484305920792,\n    0.31539156525252005,\n    -1.0925484305920792,\n    0.5462742152960396\n]\nC3 = [\n    -0.5900435899266435,\n    2.890611442640554,\n    -0.4570457994644658,\n    0.3731763325901154,\n    -0.4570457994644658,\n    1.445305721320277,\n    -0.5900435899266435\n]\nC4 = [\n    2.5033429417967046,\n    -1.7701307697799304,\n    0.9461746957575601,\n    -0.6690465435572892,\n    0.10578554691520431,\n    -0.6690465435572892,\n    0.47308734787878004,\n    -1.7701307697799304,\n    
0.6258357354491761,\n]\n\n\ndef eval_sh(deg, sh, dirs):\n    \"\"\"\n    Evaluate spherical harmonics at unit directions\n    using hardcoded SH polynomials.\n    Works with torch/np/jnp.\n    ... Can be 0 or more batch dimensions.\n    Args:\n        deg: int SH deg. Currently, 0-4 supported\n        sh: jnp.ndarray SH coeffs [..., C, (deg + 1) ** 2]\n        dirs: jnp.ndarray unit directions [..., 3]\n    Returns:\n        [..., C]\n    \"\"\"\n    assert deg <= 4 and deg >= 0\n    coeff = (deg + 1) ** 2\n    assert sh.shape[-1] >= coeff\n\n    result = C0 * sh[..., 0]\n    if deg > 0:\n        x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]\n        result = (result -\n                C1 * y * sh[..., 1] +\n                C1 * z * sh[..., 2] -\n                C1 * x * sh[..., 3])\n\n        if deg > 1:\n            xx, yy, zz = x * x, y * y, z * z\n            xy, yz, xz = x * y, y * z, x * z\n            result = (result +\n                    C2[0] * xy * sh[..., 4] +\n                    C2[1] * yz * sh[..., 5] +\n                    C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +\n                    C2[3] * xz * sh[..., 7] +\n                    C2[4] * (xx - yy) * sh[..., 8])\n\n            if deg > 2:\n                result = (result +\n                C3[0] * y * (3 * xx - yy) * sh[..., 9] +\n                C3[1] * xy * z * sh[..., 10] +\n                C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] +\n                C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +\n                C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +\n                C3[5] * z * (xx - yy) * sh[..., 14] +\n                C3[6] * x * (xx - 3 * yy) * sh[..., 15])\n\n                if deg > 3:\n                    result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +\n                            C4[1] * yz * (3 * xx - yy) * sh[..., 17] +\n                            C4[2] * xy * (7 * zz - 1) * sh[..., 18] +\n                            C4[3] * yz * (7 * zz - 
3) * sh[..., 19] +\n                            C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +\n                            C4[5] * xz * (7 * zz - 3) * sh[..., 21] +\n                            C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +\n                            C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +\n                            C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])\n    return result\n\ndef RGB2SH(rgb):\n    return (rgb - 0.5) / C0\n\ndef SH2RGB(sh):\n    return sh * C0 + 0.5"
  },
  {
    "path": "simple-knn/ext.cpp",
    "content": "/*\n * Copyright (C) 2023, Inria\n * GRAPHDECO research group, https://team.inria.fr/graphdeco\n * All rights reserved.\n *\n * This software is free for non-commercial, research and evaluation use \n * under the terms of the LICENSE.md file.\n *\n * For inquiries contact  george.drettakis@inria.fr\n */\n\n#include <torch/extension.h>\n#include \"spatial.h\"\n\nPYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n  m.def(\"distCUDA2\", &distCUDA2);\n}\n"
  },
  {
    "path": "simple-knn/setup.py",
    "content": "#\n# Copyright (C) 2023, Inria\n# GRAPHDECO research group, https://team.inria.fr/graphdeco\n# All rights reserved.\n#\n# This software is free for non-commercial, research and evaluation use \n# under the terms of the LICENSE.md file.\n#\n# For inquiries contact  george.drettakis@inria.fr\n#\n\nfrom setuptools import setup\nfrom torch.utils.cpp_extension import CUDAExtension, BuildExtension\nimport os\n\ncxx_compiler_flags = []\n\nif os.name == 'nt':\n    cxx_compiler_flags.append(\"/wd4624\")\n\nsetup(\n    name=\"simple_knn\",\n    ext_modules=[\n        CUDAExtension(\n            name=\"simple_knn._C\",\n            sources=[\n            \"spatial.cu\", \n            \"simple_knn.cu\",\n            \"ext.cpp\"],\n            extra_compile_args={\"nvcc\": [], \"cxx\": cxx_compiler_flags})\n        ],\n    cmdclass={\n        'build_ext': BuildExtension\n    }\n)\n"
  },
  {
    "path": "simple-knn/simple_knn/.gitkeep",
    "content": ""
  },
  {
    "path": "simple-knn/simple_knn.cu",
    "content": "/*\n * Copyright (C) 2023, Inria\n * GRAPHDECO research group, https://team.inria.fr/graphdeco\n * All rights reserved.\n *\n * This software is free for non-commercial, research and evaluation use \n * under the terms of the LICENSE.md file.\n *\n * For inquiries contact  george.drettakis@inria.fr\n */\n\n#define BOX_SIZE 1024\n\n#include \"cuda_runtime.h\"\n#include \"device_launch_parameters.h\"\n#include \"simple_knn.h\"\n#include <cub/cub.cuh>\n#include <cub/device/device_radix_sort.cuh>\n#include <vector>\n#include <cuda_runtime_api.h>\n#include <thrust/device_vector.h>\n#include <thrust/sequence.h>\n#define __CUDACC__\n#include <cooperative_groups.h>\n#include <cooperative_groups/reduce.h>\n\nnamespace cg = cooperative_groups;\n\nstruct CustomMin\n{\n\t__device__ __forceinline__\n\t\tfloat3 operator()(const float3& a, const float3& b) const {\n\t\treturn { min(a.x, b.x), min(a.y, b.y), min(a.z, b.z) };\n\t}\n};\n\nstruct CustomMax\n{\n\t__device__ __forceinline__\n\t\tfloat3 operator()(const float3& a, const float3& b) const {\n\t\treturn { max(a.x, b.x), max(a.y, b.y), max(a.z, b.z) };\n\t}\n};\n\n__host__ __device__ uint32_t prepMorton(uint32_t x)\n{\n\tx = (x | (x << 16)) & 0x030000FF;\n\tx = (x | (x << 8)) & 0x0300F00F;\n\tx = (x | (x << 4)) & 0x030C30C3;\n\tx = (x | (x << 2)) & 0x09249249;\n\treturn x;\n}\n\n__host__ __device__ uint32_t coord2Morton(float3 coord, float3 minn, float3 maxx)\n{\n\tuint32_t x = prepMorton(((coord.x - minn.x) / (maxx.x - minn.x)) * ((1 << 10) - 1));\n\tuint32_t y = prepMorton(((coord.y - minn.y) / (maxx.y - minn.y)) * ((1 << 10) - 1));\n\tuint32_t z = prepMorton(((coord.z - minn.z) / (maxx.z - minn.z)) * ((1 << 10) - 1));\n\n\treturn x | (y << 1) | (z << 2);\n}\n\n__global__ void coord2Morton(int P, const float3* points, float3 minn, float3 maxx, uint32_t* codes)\n{\n\tauto idx = cg::this_grid().thread_rank();\n\tif (idx >= P)\n\t\treturn;\n\n\tcodes[idx] = coord2Morton(points[idx], minn, maxx);\n}\n\nstruct 
MinMax\n{\n\tfloat3 minn;\n\tfloat3 maxx;\n};\n\n__global__ void boxMinMax(uint32_t P, float3* points, uint32_t* indices, MinMax* boxes)\n{\n\tauto idx = cg::this_grid().thread_rank();\n\n\tMinMax me;\n\tif (idx < P)\n\t{\n\t\tme.minn = points[indices[idx]];\n\t\tme.maxx = points[indices[idx]];\n\t}\n\telse\n\t{\n\t\tme.minn = { FLT_MAX, FLT_MAX, FLT_MAX };\n\t\tme.maxx = { -FLT_MAX,-FLT_MAX,-FLT_MAX };\n\t}\n\n\t__shared__ MinMax redResult[BOX_SIZE];\n\n\tfor (int off = BOX_SIZE / 2; off >= 1; off /= 2)\n\t{\n\t\tif (threadIdx.x < 2 * off)\n\t\t\tredResult[threadIdx.x] = me;\n\t\t__syncthreads();\n\n\t\tif (threadIdx.x < off)\n\t\t{\n\t\t\tMinMax other = redResult[threadIdx.x + off];\n\t\t\tme.minn.x = min(me.minn.x, other.minn.x);\n\t\t\tme.minn.y = min(me.minn.y, other.minn.y);\n\t\t\tme.minn.z = min(me.minn.z, other.minn.z);\n\t\t\tme.maxx.x = max(me.maxx.x, other.maxx.x);\n\t\t\tme.maxx.y = max(me.maxx.y, other.maxx.y);\n\t\t\tme.maxx.z = max(me.maxx.z, other.maxx.z);\n\t\t}\n\t\t__syncthreads();\n\t}\n\n\tif (threadIdx.x == 0)\n\t\tboxes[blockIdx.x] = me;\n}\n\n__device__ __host__ float distBoxPoint(const MinMax& box, const float3& p)\n{\n\tfloat3 diff = { 0, 0, 0 };\n\tif (p.x < box.minn.x || p.x > box.maxx.x)\n\t\tdiff.x = min(abs(p.x - box.minn.x), abs(p.x - box.maxx.x));\n\tif (p.y < box.minn.y || p.y > box.maxx.y)\n\t\tdiff.y = min(abs(p.y - box.minn.y), abs(p.y - box.maxx.y));\n\tif (p.z < box.minn.z || p.z > box.maxx.z)\n\t\tdiff.z = min(abs(p.z - box.minn.z), abs(p.z - box.maxx.z));\n\treturn diff.x * diff.x + diff.y * diff.y + diff.z * diff.z;\n}\n\ntemplate<int K>\n__device__ void updateKBest(const float3& ref, const float3& point, float* knn)\n{\n\tfloat3 d = { point.x - ref.x, point.y - ref.y, point.z - ref.z };\n\tfloat dist = d.x * d.x + d.y * d.y + d.z * d.z;\n\tfor (int j = 0; j < K; j++)\n\t{\n\t\tif (knn[j] > dist)\n\t\t{\n\t\t\tfloat t = knn[j];\n\t\t\tknn[j] = dist;\n\t\t\tdist = t;\n\t\t}\n\t}\n}\n\n__global__ void boxMeanDist(uint32_t P, 
float3* points, uint32_t* indices, MinMax* boxes, float* dists)\n{\n\tint idx = cg::this_grid().thread_rank();\n\tif (idx >= P)\n\t\treturn;\n\n\tfloat3 point = points[indices[idx]];\n\tfloat best[3] = { FLT_MAX, FLT_MAX, FLT_MAX };\n\n\tfor (int i = max(0, idx - 3); i <= min(P - 1, idx + 3); i++)\n\t{\n\t\tif (i == idx)\n\t\t\tcontinue;\n\t\tupdateKBest<3>(point, points[indices[i]], best);\n\t}\n\n\tfloat reject = best[2];\n\tbest[0] = FLT_MAX;\n\tbest[1] = FLT_MAX;\n\tbest[2] = FLT_MAX;\n\n\tfor (int b = 0; b < (P + BOX_SIZE - 1) / BOX_SIZE; b++)\n\t{\n\t\tMinMax box = boxes[b];\n\t\tfloat dist = distBoxPoint(box, point);\n\t\tif (dist > reject || dist > best[2])\n\t\t\tcontinue;\n\n\t\tfor (int i = b * BOX_SIZE; i < min(P, (b + 1) * BOX_SIZE); i++)\n\t\t{\n\t\t\tif (i == idx)\n\t\t\t\tcontinue;\n\t\t\tupdateKBest<3>(point, points[indices[i]], best);\n\t\t}\n\t}\n\tdists[indices[idx]] = (best[0] + best[1] + best[2]) / 3.0f;\n}\n\nvoid SimpleKNN::knn(int P, float3* points, float* meanDists)\n{\n\tfloat3* result;\n\tcudaMalloc(&result, sizeof(float3));\n\tsize_t temp_storage_bytes;\n\n\tfloat3 init = { 0, 0, 0 }, minn, maxx;\n\n\tcub::DeviceReduce::Reduce(nullptr, temp_storage_bytes, points, result, P, CustomMin(), init);\n\tthrust::device_vector<char> temp_storage(temp_storage_bytes);\n\n\tcub::DeviceReduce::Reduce(temp_storage.data().get(), temp_storage_bytes, points, result, P, CustomMin(), init);\n\tcudaMemcpy(&minn, result, sizeof(float3), cudaMemcpyDeviceToHost);\n\n\tcub::DeviceReduce::Reduce(temp_storage.data().get(), temp_storage_bytes, points, result, P, CustomMax(), init);\n\tcudaMemcpy(&maxx, result, sizeof(float3), cudaMemcpyDeviceToHost);\n\n\tthrust::device_vector<uint32_t> morton(P);\n\tthrust::device_vector<uint32_t> morton_sorted(P);\n\tcoord2Morton << <(P + 255) / 256, 256 >> > (P, points, minn, maxx, morton.data().get());\n\n\tthrust::device_vector<uint32_t> indices(P);\n\tthrust::sequence(indices.begin(), 
indices.end());\n\tthrust::device_vector<uint32_t> indices_sorted(P);\n\n\tcub::DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, morton.data().get(), morton_sorted.data().get(), indices.data().get(), indices_sorted.data().get(), P);\n\ttemp_storage.resize(temp_storage_bytes);\n\n\tcub::DeviceRadixSort::SortPairs(temp_storage.data().get(), temp_storage_bytes, morton.data().get(), morton_sorted.data().get(), indices.data().get(), indices_sorted.data().get(), P);\n\n\tuint32_t num_boxes = (P + BOX_SIZE - 1) / BOX_SIZE;\n\tthrust::device_vector<MinMax> boxes(num_boxes);\n\tboxMinMax << <num_boxes, BOX_SIZE >> > (P, points, indices_sorted.data().get(), boxes.data().get());\n\tboxMeanDist << <num_boxes, BOX_SIZE >> > (P, points, indices_sorted.data().get(), boxes.data().get(), meanDists);\n\n\tcudaFree(result);\n}"
  },
  {
    "path": "simple-knn/simple_knn.h",
    "content": "/*\n * Copyright (C) 2023, Inria\n * GRAPHDECO research group, https://team.inria.fr/graphdeco\n * All rights reserved.\n *\n * This software is free for non-commercial, research and evaluation use \n * under the terms of the LICENSE.md file.\n *\n * For inquiries contact  george.drettakis@inria.fr\n */\n\n#ifndef SIMPLEKNN_H_INCLUDED\n#define SIMPLEKNN_H_INCLUDED\n\nclass SimpleKNN\n{\npublic:\n\tstatic void knn(int P, float3* points, float* meanDists);\n};\n\n#endif"
  },
  {
    "path": "simple-knn/spatial.cu",
    "content": "/*\n * Copyright (C) 2023, Inria\n * GRAPHDECO research group, https://team.inria.fr/graphdeco\n * All rights reserved.\n *\n * This software is free for non-commercial, research and evaluation use \n * under the terms of the LICENSE.md file.\n *\n * For inquiries contact  george.drettakis@inria.fr\n */\n\n#include \"spatial.h\"\n#include \"simple_knn.h\"\n\ntorch::Tensor\ndistCUDA2(const torch::Tensor& points)\n{\n  const int P = points.size(0);\n\n  auto float_opts = points.options().dtype(torch::kFloat32);\n  torch::Tensor means = torch::full({P}, 0.0, float_opts);\n  \n  SimpleKNN::knn(P, (float3*)points.contiguous().data<float>(), means.contiguous().data<float>());\n\n  return means;\n}"
  },
  {
    "path": "simple-knn/spatial.h",
    "content": "/*\n * Copyright (C) 2023, Inria\n * GRAPHDECO research group, https://team.inria.fr/graphdeco\n * All rights reserved.\n *\n * This software is free for non-commercial, research and evaluation use \n * under the terms of the LICENSE.md file.\n *\n * For inquiries contact  george.drettakis@inria.fr\n */\n\n#include <torch/extension.h>\n\ntorch::Tensor distCUDA2(const torch::Tensor& points);"
  },
  {
    "path": "visualize.py",
    "content": "import os\nimport cv2\nimport time\nimport tqdm\nimport numpy as np\nimport dearpygui.dearpygui as dpg\n\nimport torch\nimport torch.nn.functional as F\n\nimport rembg\nimport numpy as np\nimport matplotlib.pyplot as plt\nfrom mpl_toolkits.mplot3d import Axes3D\nfrom cam_utils import orbit_camera, OrbitCamera\nfrom gs_renderer_4d import Renderer, MiniCam\nfrom dataset_4d import SparseDataset\nfrom einops import rearrange, repeat\nimport torchvision.utils as vutils\nimport imageio\n\ndef save_image_to_local(image_tensor, file_path):\n    # Ensure the image tensor is in the range [0, 1]\n    image_tensor = image_tensor.clamp(0, 1) \n\n    # Save the image tensor to the specified file path\n    vutils.save_image(image_tensor, file_path)\nclass GUI:\n    def __init__(self, opt):\n        self.opt = opt  # shared with the trainer's opt to support in-place modification of rendering parameters.\n        self.gui = opt.gui # enable gui\n        self.W = opt.W\n        self.H = opt.H\n        self.cam = OrbitCamera(opt.W, opt.H, r=opt.radius, fovy=opt.fovy)\n\n        self.mode = \"image\"\n        self.seed = \"random\"\n\n        self.buffer_image = np.ones((self.W, self.H, 3), dtype=np.float32)\n        self.need_update = True  # update buffer_image\n\n        # models\n        self.device = torch.device(\"cuda\")\n        self.bg_remover = None\n\n        self.guidance_sd = None\n        self.guidance_zero123 = None\n\n        self.enable_sd = False\n        self.enable_zero123 = False\n\n        # renderer\n        self.renderer = Renderer(sh_degree=self.opt.sh_degree)\n        self.gaussain_scale_factor = 1\n    \n        # input image\n        self.input_img = None\n        self.input_mask = None\n        self.input_img_torch = None\n        self.input_mask_torch = None\n        self.overlay_input_img = False\n        self.overlay_input_img_ratio = 0.5\n\n\n        # input text\n        self.prompt = \"\"\n        self.negative_prompt = \"\"\n\n       
 # training stuff\n        self.training = False\n        self.optimizer = None\n        self.step = 0\n        self.t = 0\n        self.time = 0\n        self.train_steps = 1  # steps per rendering loop\n\n        self.path = self.opt.path\n        if self.opt.size is not None:\n            self.size = self.opt.size\n        else:\n            self.size = len(os.listdir(os.path.join(self.path, 'ref')))\n        self.frames = self.size\n\n        # initialize the Gaussians; the trained checkpoint is loaded later in train_step()\n        self.renderer.initialize(num_pts=5000)\n\n        if self.gui:\n            dpg.create_context()\n            self.register_dpg()\n            self.test_step()\n\n    def __del__(self):\n        if self.gui:\n            dpg.destroy_context()\n\n    def seed_everything(self):\n        try:\n            seed = int(self.seed)\n        except (TypeError, ValueError):\n            # non-numeric seed such as \"random\": draw one at random\n            seed = np.random.randint(0, 1000000)\n\n        os.environ[\"PYTHONHASHSEED\"] = str(seed)\n        np.random.seed(seed)\n        torch.manual_seed(seed)\n        torch.cuda.manual_seed(seed)\n        torch.backends.cudnn.deterministic = True\n        # cuDNN benchmark mode may pick different algorithms across runs, which defeats seeding\n        torch.backends.cudnn.benchmark = False\n\n        self.last_seed = seed\n\n    def prepare_train(self):\n\n        self.step = 0\n\n        # setup training\n        self.renderer.gaussians.training_setup(self.opt)\n        # do not do progressive sh-level\n        self.renderer.gaussians.active_sh_degree = self.renderer.gaussians.max_sh_degree\n        self.optimizer = self.renderer.gaussians.optimizer\n\n        # default camera\n        pose = orbit_camera(self.opt.elevation, 0, self.opt.radius)\n        self.fixed_cam = MiniCam(\n            pose,\n            self.opt.ref_size,\n            self.opt.ref_size,\n            self.cam.fovy,\n            self.cam.fovx,\n            self.cam.near,\n            self.cam.far,\n        )\n        self.set_fix_cam2()\n        self.enable_sd = self.opt.lambda_sd > 0 and self.prompt != \"\"\n        self.enable_zero123 = self.opt.lambda_zero123 > 0 and self.input_img is not None\n\n        # lazy load guidance model\n        if self.guidance_sd is None and self.enable_sd:\n            if self.opt.mvdream:\n                print(\"[INFO] loading MVDream...\")\n                from guidance.mvdream_utils import MVDream\n                self.guidance_sd = MVDream(self.device)\n                print(\"[INFO] loaded MVDream!\")\n            else:\n                print(\"[INFO] loading SD...\")\n                from guidance.sd_utils import StableDiffusion\n                self.guidance_sd = StableDiffusion(self.device)\n                print(\"[INFO] loaded SD!\")\n\n    def train_step(self):\n        starter = torch.cuda.Event(enable_timing=True)\n        ender = torch.cuda.Event(enable_timing=True)\n        starter.record()\n\n        self.stage = 'fine'\n\n        # load the checkpoint saved at opt.load_step (step 8000 if unset)\n        if self.opt.load_step is None:\n            self.step = 8000\n        else:\n            self.step = self.opt.load_step\n        auto_path = os.path.join(self.opt.outdir, self.opt.save_path + str(self.step))\n        ply_path = os.path.join(auto_path, 'model.ply')\n        self.renderer.gaussians.load_model(auto_path)\n        self.renderer.gaussians.load_ply(ply_path)\n\n        self.renderer.gaussians.update_learning_rate(self.step)\n\n        self.save_renderings(name='front')\n        self.save_renderings(azim=180, name='back')\n        self.save_renderings(azim=-30, name='front_moving', interval=2)\n        self.save_renderings(azim=150, name='back_moving', interval=2)\n        self.save_renderings(azim=0, name='round', interval=360 // self.size)\n\n        ender.record()\n        torch.cuda.synchronize()\n        t = starter.elapsed_time(ender)\n\n        self.need_update = True\n\n        if self.gui:\n            dpg.set_value(\"_log_train_time\", f\"{t:.4f}ms\")\n            # no losses are computed here: this step only loads a checkpoint and renders videos\n            dpg.set_value(\n                \"_log_train_log\",\n                f\"step = {self.step: 5d} (checkpoint loaded, renderings saved)\",\n            )\n\n    @torch.no_grad()\n    def save_renderings(self, elev=0, azim=0, radius=2, name='front', interval=0):\n        # videos are written to ./valid/{save_path}/, so make sure that folder exists\n        os.makedirs(f'./valid/{self.opt.save_path}', exist_ok=True)\n        if interval==0:\n            images=[]\n            for i in np.arange(self.frames):\n                \n                pose = orbit_camera(elev, azim, radius)\n                cam = MiniCam(\n                pose,\n                self.opt.ref_size,\n                self.opt.ref_size,\n                self.cam.fovy,\n                self.cam.fovx,\n                self.cam.near,\n                self.cam.far,\n                )   \n                cam.time=float(i/self.frames)\n                out = self.renderer.render(cam,stage=self.stage)\n                image = out[\"image\"].unsqueeze(0)\n                images.append(image)\n                #os.makedirs(f'./valid/{self.opt.save_path}/final_{name}',exist_ok=True)\n                #save_image_to_local(image[0].detach(),f'./valid/{self.opt.save_path}/final_{name}/{str(i).zfill(2)}.jpg')\n            samples=torch.cat(images,dim=0)\n            \n            vid = (\n                (rearrange(samples, \"t c h w -> t h w c\") * 255).clamp(0,255).detach()\n                .cpu()\n                .numpy()\n                .astype(np.uint8)\n            )\n            video_path = f'./valid/{self.opt.save_path}/video_{name}.mp4'\n            
imageio.mimwrite(video_path, vid)\n        else:\n            images=[]\n            for i in np.arange(self.frames):\n                \n                pose = orbit_camera(elev, (azim+interval*i)%360, radius)\n                cam = MiniCam(\n                pose,\n                self.opt.ref_size,\n                self.opt.ref_size,\n                self.cam.fovy,\n                self.cam.fovx,\n                self.cam.near,\n                self.cam.far,\n                )   \n                cam.time=float(i/self.frames)\n                out = self.renderer.render(cam,stage=self.stage)\n                image = out[\"image\"].unsqueeze(0)\n                images.append(image)\n                #os.makedirs(f'./valid/{self.opt.save_path}/final_{name}',exist_ok=True)\n                #save_image_to_local(image[0].detach(),f'./valid/{self.opt.save_path}/final_{name}/{str(i).zfill(2)}.jpg')\n            samples=torch.cat(images,dim=0)\n            \n            vid = (\n                (rearrange(samples, \"t c h w -> t h w c\") * 255).clamp(0,255).detach()\n                .cpu()\n                .numpy()\n                .astype(np.uint8)\n            )\n            video_path = f'./valid/{self.opt.save_path}/video_{name}.mp4'\n            imageio.mimwrite(video_path, vid)\n            \n        \n        \n    def set_fix_cam2(self):\n        self.fixed_cam2=[]\n        self.fixed_elevation = []\n        self.fixed_azimuth = []\n        \n        pose = orbit_camera(self.opt.elevation-30,30 , self.opt.radius)\n        self.fixed_elevation.append(-30)\n        self.fixed_azimuth.append(30)\n        self.fixed_cam2.append(MiniCam(\n                pose,\n                self.opt.ref_size,\n                self.opt.ref_size,\n                self.cam.fovy,\n                self.cam.fovx,\n                self.cam.near,\n                self.cam.far,\n            ))\n        \n        pose = orbit_camera(self.opt.elevation+20, 90, self.opt.radius)\n        
self.fixed_elevation.append(20)\n        self.fixed_azimuth.append(90)\n        self.fixed_cam2.append(MiniCam(\n                pose,\n                self.opt.ref_size,\n                self.opt.ref_size,\n                self.cam.fovy,\n                self.cam.fovx,\n                self.cam.near,\n                self.cam.far,\n            ))\n        pose = orbit_camera(self.opt.elevation-30, 150, self.opt.radius)\n        self.fixed_elevation.append(-30)\n        self.fixed_azimuth.append(150)\n        self.fixed_cam2.append(MiniCam(\n                pose,\n                self.opt.ref_size,\n                self.opt.ref_size,\n                self.cam.fovy,\n                self.cam.fovx,\n                self.cam.near,\n                self.cam.far,\n            ))\n        \n        pose = orbit_camera(self.opt.elevation+20, 210, self.opt.radius)\n        self.fixed_elevation.append(+20)\n        self.fixed_azimuth.append(210)\n        self.fixed_cam2.append(MiniCam(\n                pose,\n                self.opt.ref_size,\n                self.opt.ref_size,\n                self.cam.fovy,\n                self.cam.fovx,\n                self.cam.near,\n                self.cam.far,\n            ))\n        \n        pose = orbit_camera(self.opt.elevation-30, 270, self.opt.radius)\n        self.fixed_elevation.append(-30)\n        self.fixed_azimuth.append(270)\n        self.fixed_cam2.append(MiniCam(\n                pose,\n                self.opt.ref_size,\n                self.opt.ref_size,\n                self.cam.fovy,\n                self.cam.fovx,\n                self.cam.near,\n                self.cam.far,\n            ))\n        \n        pose = orbit_camera(self.opt.elevation+20, 330, self.opt.radius)\n        self.fixed_elevation.append(20)\n        self.fixed_azimuth.append(330)\n        self.fixed_cam2.append(MiniCam(\n                pose,\n                self.opt.ref_size,\n                self.opt.ref_size,\n                
self.cam.fovy,\n                self.cam.fovx,\n                self.cam.near,\n                self.cam.far,\n            ))\n        \n    @torch.no_grad()\n    def test_step(self):\n        # ignore if no need to update\n        if not self.need_update:\n            return\n\n        starter = torch.cuda.Event(enable_timing=True)\n        ender = torch.cuda.Event(enable_timing=True)\n        starter.record()\n\n        # should update image\n        if self.need_update:\n            # render image\n\n            cur_cam = MiniCam(\n                self.cam.pose,\n                self.W,\n                self.H,\n                self.cam.fovy,\n                self.cam.fovx,\n                self.cam.near,\n                self.cam.far,\n                time=self.time\n            )\n            #print(cur_cam.time)\n            out = self.renderer.render(cur_cam, self.gaussain_scale_factor)\n\n            buffer_image = out[self.mode]  # [3, H, W]\n\n            if self.mode in ['depth', 'alpha']:\n                buffer_image = buffer_image.repeat(3, 1, 1)\n                if self.mode == 'depth':\n                    buffer_image = (buffer_image - buffer_image.min()) / (buffer_image.max() - buffer_image.min() + 1e-20)\n\n            buffer_image = F.interpolate(\n                buffer_image.unsqueeze(0),\n                size=(self.H, self.W),\n                mode=\"bilinear\",\n                align_corners=False,\n            ).squeeze(0)\n\n            self.buffer_image = (\n                buffer_image.permute(1, 2, 0)\n                .contiguous()\n                .clamp(0, 1)\n                .contiguous()\n                .detach()\n                .cpu()\n                .numpy()\n            )\n\n            # display input_image\n            if self.overlay_input_img and self.input_img is not None:\n                self.buffer_image = (\n                    self.buffer_image * (1 - self.overlay_input_img_ratio)\n                    + 
self.input_img * self.overlay_input_img_ratio\n                )\n\n            self.need_update = False\n\n        ender.record()\n        torch.cuda.synchronize()\n        t = starter.elapsed_time(ender)\n\n        if self.gui:\n            dpg.set_value(\"_log_infer_time\", f\"{t:.4f}ms ({int(1000/t)} FPS)\")\n            dpg.set_value(\n                \"_texture\", self.buffer_image\n            )  # buffer must be contiguous, else seg fault!\n\n    def load_input(self, file):\n        import glob\n\n        # load every image matching self.pattern\n        self.input_imgs = []\n        self.input_masks = []\n        file_list = glob.glob(self.pattern)\n        print(self.pattern, file_list)\n        for files in sorted(file_list):\n            print(f'[INFO] load image from {files}...')\n            img = cv2.imread(files, cv2.IMREAD_UNCHANGED)\n            if img.shape[-1] == 3:\n                if self.bg_remover is None:\n                    self.bg_remover = rembg.new_session()\n                img = rembg.remove(img, session=self.bg_remover)\n\n            img = cv2.resize(img, (self.W, self.H), interpolation=cv2.INTER_AREA)\n            img = img.astype(np.float32) / 255.0\n\n            self.input_mask = img[..., 3:]\n            # white bg\n            self.input_img = img[..., :3] * self.input_mask + (1 - self.input_mask)\n            # bgr to rgb\n            self.input_img = self.input_img[..., ::-1].copy()\n\n            self.input_imgs.append(self.input_img)\n            self.input_masks.append(self.input_mask)\n\n        # load the single reference image\n        print(f'[INFO] load image from {file}...')\n        img = cv2.imread(file, cv2.IMREAD_UNCHANGED)\n        if img.shape[-1] == 3:\n            if self.bg_remover is None:\n                self.bg_remover = 
rembg.new_session()\n            img = rembg.remove(img, session=self.bg_remover)\n\n        img = cv2.resize(img, (self.W, self.H), interpolation=cv2.INTER_AREA)\n        img = img.astype(np.float32) / 255.0\n\n        self.input_mask = img[..., 3:]\n        # white bg\n        self.input_img = img[..., :3] * self.input_mask + (1 - self.input_mask)\n        # bgr to rgb\n        self.input_img = self.input_img[..., ::-1].copy()\n\n        # load prompt\n        file_prompt = file.replace(\"_rgba.png\", \"_caption.txt\")\n        if os.path.exists(file_prompt):\n            print(f'[INFO] load prompt from {file_prompt}...')\n            with open(file_prompt, \"r\") as f:\n                self.prompt = f.read().strip()\n\n    @torch.no_grad()\n    def save_model(self, mode='geo', texture_size=1024):\n        # every mode exports the same Gaussian .ply; mode and texture_size are currently unused\n        os.makedirs(self.opt.outdir, exist_ok=True)\n        path = os.path.join(self.opt.outdir, self.opt.save_path + '_model.ply')\n        self.renderer.gaussians.save_ply(path)\n        print(f\"[INFO] save model to {path}.\")\n\n    def register_dpg(self):\n        ### register texture\n\n        with dpg.texture_registry(show=False):\n            dpg.add_raw_texture(\n                self.W,\n                self.H,\n                self.buffer_image,\n                format=dpg.mvFormat_Float_rgb,\n                tag=\"_texture\",\n            )\n\n        ### register window\n\n        # the rendered image, as the primary window\n        with dpg.window(\n            tag=\"_primary_window\",\n            width=self.W,\n            height=self.H,\n            pos=[0, 0],\n            
no_move=True,\n            no_title_bar=True,\n            no_scrollbar=True,\n        ):\n            # add the texture\n            dpg.add_image(\"_texture\")\n\n        # dpg.set_primary_window(\"_primary_window\", True)\n\n        # control window\n        with dpg.window(\n            label=\"Control\",\n            tag=\"_control_window\",\n            width=600,\n            height=self.H,\n            pos=[self.W, 0],\n            no_move=True,\n            no_title_bar=True,\n        ):\n            # button theme\n            with dpg.theme() as theme_button:\n                with dpg.theme_component(dpg.mvButton):\n                    dpg.add_theme_color(dpg.mvThemeCol_Button, (23, 3, 18))\n                    dpg.add_theme_color(dpg.mvThemeCol_ButtonHovered, (51, 3, 47))\n                    dpg.add_theme_color(dpg.mvThemeCol_ButtonActive, (83, 18, 83))\n                    dpg.add_theme_style(dpg.mvStyleVar_FrameRounding, 5)\n                    dpg.add_theme_style(dpg.mvStyleVar_FramePadding, 3, 3)\n\n            # timer stuff\n            with dpg.group(horizontal=True):\n                dpg.add_text(\"Infer time: \")\n                dpg.add_text(\"no data\", tag=\"_log_infer_time\")\n\n            def callback_setattr(sender, app_data, user_data):\n                setattr(self, user_data, app_data)\n\n            # init stuff\n            with dpg.collapsing_header(label=\"Initialize\", default_open=True):\n\n                # seed stuff\n                def callback_set_seed(sender, app_data):\n                    self.seed = app_data\n                    self.seed_everything()\n\n                dpg.add_input_text(\n                    label=\"seed\",\n                    default_value=self.seed,\n                    on_enter=True,\n                    callback=callback_set_seed,\n                )\n\n                # input stuff\n                def callback_select_input(sender, app_data):\n                    # only one item\n                 
   for k, v in app_data[\"selections\"].items():\n                        dpg.set_value(\"_log_input\", k)\n                        self.load_input(v)\n\n                    self.need_update = True\n\n                with dpg.file_dialog(\n                    directory_selector=False,\n                    show=False,\n                    callback=callback_select_input,\n                    file_count=1,\n                    tag=\"file_dialog_tag\",\n                    width=700,\n                    height=400,\n                ):\n                    dpg.add_file_extension(\"Images{.jpg,.jpeg,.png}\")\n\n                with dpg.group(horizontal=True):\n                    dpg.add_button(\n                        label=\"input\",\n                        callback=lambda: dpg.show_item(\"file_dialog_tag\"),\n                    )\n                    dpg.add_text(\"\", tag=\"_log_input\")\n                \n                # overlay stuff\n                with dpg.group(horizontal=True):\n\n                    def callback_toggle_overlay_input_img(sender, app_data):\n                        self.overlay_input_img = not self.overlay_input_img\n                        self.need_update = True\n\n                    dpg.add_checkbox(\n                        label=\"overlay image\",\n                        default_value=self.overlay_input_img,\n                        callback=callback_toggle_overlay_input_img,\n                    )\n\n                    def callback_set_overlay_input_img_ratio(sender, app_data):\n                        self.overlay_input_img_ratio = app_data\n                        self.need_update = True\n\n                    dpg.add_slider_float(\n                        label=\"ratio\",\n                        min_value=0,\n                        max_value=1,\n                        format=\"%.1f\",\n                        default_value=self.overlay_input_img_ratio,\n                        
callback=callback_set_overlay_input_img_ratio,\n                    )\n\n                # prompt stuff\n            \n                dpg.add_input_text(\n                    label=\"prompt\",\n                    default_value=self.prompt,\n                    callback=callback_setattr,\n                    user_data=\"prompt\",\n                )\n\n                dpg.add_input_text(\n                    label=\"negative\",\n                    default_value=self.negative_prompt,\n                    callback=callback_setattr,\n                    user_data=\"negative_prompt\",\n                )\n\n                # save current model\n                with dpg.group(horizontal=True):\n                    dpg.add_text(\"Save: \")\n\n                    def callback_save(sender, app_data, user_data):\n                        self.save_model(mode=user_data)\n\n                    dpg.add_button(\n                        label=\"model\",\n                        tag=\"_button_save_model\",\n                        callback=callback_save,\n                        user_data='model',\n                    )\n                    dpg.bind_item_theme(\"_button_save_model\", theme_button)\n\n                    dpg.add_button(\n                        label=\"geo\",\n                        tag=\"_button_save_mesh\",\n                        callback=callback_save,\n                        user_data='geo',\n                    )\n                    dpg.bind_item_theme(\"_button_save_mesh\", theme_button)\n\n                    dpg.add_button(\n                        label=\"geo+tex\",\n                        tag=\"_button_save_mesh_with_tex\",\n                        callback=callback_save,\n                        user_data='geo+tex',\n                    )\n                    dpg.bind_item_theme(\"_button_save_mesh_with_tex\", theme_button)\n\n                    dpg.add_input_text(\n                        label=\"\",\n                        
default_value=self.opt.save_path,\n                        callback=callback_setattr,\n                        user_data=\"save_path\",\n                    )\n\n            # training stuff\n            with dpg.collapsing_header(label=\"Train\", default_open=True):\n                # lr and train button\n                with dpg.group(horizontal=True):\n                    dpg.add_text(\"Train: \")\n\n                    def callback_train(sender, app_data):\n                        if self.training:\n                            self.training = False\n                            dpg.configure_item(\"_button_train\", label=\"start\")\n                        else:\n                            self.prepare_train()\n                            self.training = True\n                            dpg.configure_item(\"_button_train\", label=\"stop\")\n\n                    # dpg.add_button(\n                    #     label=\"init\", tag=\"_button_init\", callback=self.prepare_train\n                    # )\n                    # dpg.bind_item_theme(\"_button_init\", theme_button)\n\n                    dpg.add_button(\n                        label=\"start\", tag=\"_button_train\", callback=callback_train\n                    )\n                    dpg.bind_item_theme(\"_button_train\", theme_button)\n\n                with dpg.group(horizontal=True):\n                    dpg.add_text(\"\", tag=\"_log_train_time\")\n                    dpg.add_text(\"\", tag=\"_log_train_log\")\n\n            # rendering options\n            with dpg.collapsing_header(label=\"Rendering\", default_open=True):\n                # mode combo\n                def callback_change_mode(sender, app_data):\n                    self.mode = app_data\n                    self.need_update = True\n\n                dpg.add_combo(\n                    (\"image\", \"depth\", \"alpha\"),\n                    label=\"mode\",\n                    default_value=self.mode,\n                    
callback=callback_change_mode,\n                )\n\n                # fov slider\n                def callback_set_fovy(sender, app_data):\n                    self.cam.fovy = np.deg2rad(app_data)\n                    self.need_update = True\n\n                dpg.add_slider_int(\n                    label=\"FoV (vertical)\",\n                    min_value=1,\n                    max_value=120,\n                    format=\"%d deg\",\n                    default_value=np.rad2deg(self.cam.fovy),\n                    callback=callback_set_fovy,\n                )\n\n                def callback_set_gaussian_scale(sender, app_data):\n                    self.gaussain_scale_factor = app_data\n                    self.need_update = True\n\n                dpg.add_slider_float(\n                    label=\"gaussian scale\",\n                    min_value=0,\n                    max_value=1,\n                    format=\"%.2f\",\n                    default_value=self.gaussain_scale_factor,\n                    callback=callback_set_gaussian_scale,\n                )\n\n        ### register camera handler\n\n        def callback_camera_drag_rotate_or_draw_mask(sender, app_data):\n            if not dpg.is_item_focused(\"_primary_window\"):\n                return\n\n            dx = app_data[1]\n            dy = app_data[2]\n\n            self.cam.orbit(dx, dy)\n            self.need_update = True\n\n        def callback_camera_wheel_scale(sender, app_data):\n            if not dpg.is_item_focused(\"_primary_window\"):\n                return\n\n            delta = app_data\n\n            self.cam.scale(delta)\n            self.need_update = True\n\n        def callback_camera_drag_pan(sender, app_data):\n            if not dpg.is_item_focused(\"_primary_window\"):\n                return\n\n            dx = app_data[1]\n            dy = app_data[2]\n\n            self.cam.pan(dx, dy)\n            self.need_update = True\n\n        def callback_set_mouse_loc(sender, 
app_data):\n            if not dpg.is_item_focused(\"_primary_window\"):\n                return\n\n            # just the pixel coordinate in image\n            self.mouse_loc = np.array(app_data)\n\n        with dpg.handler_registry():\n            # for camera moving\n            dpg.add_mouse_drag_handler(\n                button=dpg.mvMouseButton_Left,\n                callback=callback_camera_drag_rotate_or_draw_mask,\n            )\n            dpg.add_mouse_wheel_handler(callback=callback_camera_wheel_scale)\n            dpg.add_mouse_drag_handler(\n                button=dpg.mvMouseButton_Middle, callback=callback_camera_drag_pan\n            )\n\n        dpg.create_viewport(\n            title=\"Gaussian3D\",\n            width=self.W + 600,\n            height=self.H + (45 if os.name == \"nt\" else 0),\n            resizable=False,\n        )\n\n        ### global theme\n        with dpg.theme() as theme_no_padding:\n            with dpg.theme_component(dpg.mvAll):\n                # set all padding to 0 to avoid scroll bar\n                dpg.add_theme_style(\n                    dpg.mvStyleVar_WindowPadding, 0, 0, category=dpg.mvThemeCat_Core\n                )\n                dpg.add_theme_style(\n                    dpg.mvStyleVar_FramePadding, 0, 0, category=dpg.mvThemeCat_Core\n                )\n                dpg.add_theme_style(\n                    dpg.mvStyleVar_CellPadding, 0, 0, category=dpg.mvThemeCat_Core\n                )\n\n        dpg.bind_item_theme(\"_primary_window\", theme_no_padding)\n\n        dpg.setup_dearpygui()\n\n        ### register a larger font\n        # get it from: https://github.com/lxgw/LxgwWenKai/releases/download/v1.300/LXGWWenKai-Regular.ttf\n        if os.path.exists(\"LXGWWenKai-Regular.ttf\"):\n            with dpg.font_registry():\n                with dpg.font(\"LXGWWenKai-Regular.ttf\", 18) as default_font:\n                    dpg.bind_font(default_font)\n\n        # dpg.show_metrics()\n\n        
dpg.show_viewport()\n\n    def render(self):\n        assert self.gui\n        while dpg.is_dearpygui_running():\n            # update texture every frame\n            if self.training:\n                self.train_step()\n            self.test_step()\n            dpg.render_dearpygui_frame()\n    \n    # no gui mode\n    def train(self, iters=500):\n        # note: `iters` is currently not used in this method\n        self.prepare_train()\n\n        self.train_step()\n        # do a last prune (currently disabled)\n        # self.renderer.gaussians.prune(min_opacity=0.01, extent=1, max_screen_size=1)\n        # save\n        self.save_model(mode='model')\n        self.save_model(mode='geo+tex')\n        \n\nif __name__ == \"__main__\":\n    import argparse\n    from omegaconf import OmegaConf\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\"--config\", required=True, help=\"path to the yaml config file\")\n    args, extras = parser.parse_known_args()\n\n    # override default config from cli\n    opt = OmegaConf.merge(OmegaConf.load(args.config), OmegaConf.from_cli(extras))\n\n    gui = GUI(opt)\n\n    if opt.gui:\n        gui.render()\n    else:\n        gui.train(opt.iters)"
  },
  {
    "path": "zero123.py",
    "content": "# Copyright 2023 The HuggingFace Team. All rights reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\nimport inspect\nimport math\nimport warnings\nfrom typing import Any, Callable, Dict, List, Optional, Union\n\nimport PIL\nimport torch\nimport torchvision.transforms.functional as TF\nfrom diffusers.configuration_utils import ConfigMixin, FrozenDict, register_to_config\nfrom diffusers.image_processor import VaeImageProcessor\nfrom diffusers.models import AutoencoderKL, UNet2DConditionModel\nfrom diffusers.models.modeling_utils import ModelMixin\nfrom diffusers.pipelines.pipeline_utils import DiffusionPipeline\nfrom diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput\nfrom diffusers.pipelines.stable_diffusion.safety_checker import (\n    StableDiffusionSafetyChecker,\n)\nfrom diffusers.schedulers import KarrasDiffusionSchedulers\nfrom diffusers.utils import deprecate, is_accelerate_available, logging\nfrom diffusers.utils.torch_utils import randn_tensor\nfrom packaging import version\nfrom transformers import CLIPImageProcessor, CLIPVisionModelWithProjection\n\nlogger = logging.get_logger(__name__)  # pylint: disable=invalid-name\n\n\nclass CLIPCameraProjection(ModelMixin, ConfigMixin):\n    \"\"\"\n    A Projection layer for CLIP embedding and camera embedding.\n\n    Parameters:\n        embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `clip_embed`\n        
additional_embeddings (`int`, *optional*, defaults to 4): The number of extra features (the camera embedding)\n            concatenated to `clip_embed` before projection. The input dimension of the projection is `embedding_dim +\n            additional_embeddings`.\n    \"\"\"\n\n    @register_to_config\n    def __init__(self, embedding_dim: int = 768, additional_embeddings: int = 4):\n        super().__init__()\n        self.embedding_dim = embedding_dim\n        self.additional_embeddings = additional_embeddings\n\n        self.input_dim = self.embedding_dim + self.additional_embeddings\n        self.output_dim = self.embedding_dim\n\n        self.proj = torch.nn.Linear(self.input_dim, self.output_dim)\n\n    def forward(\n        self,\n        embedding: torch.FloatTensor,\n    ):\n        \"\"\"\n        The [`CLIPCameraProjection`] forward method.\n\n        Args:\n            embedding (`torch.FloatTensor` of shape `(..., input_dim)`):\n                The concatenated CLIP image embedding and camera embedding.\n\n        Returns:\n            The projected embedding (`torch.FloatTensor` of shape `(..., output_dim)`).\n        \"\"\"\n        proj_embedding = self.proj(embedding)\n        return proj_embedding\n\n\nclass Zero123Pipeline(DiffusionPipeline):\n    r\"\"\"\n    Pipeline to generate novel views of an input image, conditioned on a relative camera pose, using Stable Diffusion.\n\n    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the\n    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)\n\n    Args:\n        vae ([`AutoencoderKL`]):\n            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.\n        image_encoder ([`CLIPVisionModelWithProjection`]):\n            Frozen CLIP image-encoder. 
This pipeline uses the vision portion of\n            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModelWithProjection),\n            specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.\n        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.\n        scheduler ([`SchedulerMixin`]):\n            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of\n            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].\n        safety_checker ([`StableDiffusionSafetyChecker`]):\n            Classification module that estimates whether generated images could be considered offensive or harmful.\n            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.\n        feature_extractor ([`CLIPImageProcessor`]):\n            Preprocesses non-tensor input images for the `image_encoder` and generated images for the `safety_checker`.\n        clip_camera_projection ([`CLIPCameraProjection`]):\n            Projects the concatenated CLIP image embedding and camera embedding back to the CLIP embedding dimension.\n    \"\"\"\n    # TODO: feature_extractor is required to encode images (if they are in PIL format),\n    # we should give a descriptive message if the pipeline doesn't have one.\n    _optional_components = [\"safety_checker\"]\n\n    def __init__(\n        self,\n        vae: AutoencoderKL,\n        image_encoder: CLIPVisionModelWithProjection,\n        unet: UNet2DConditionModel,\n        scheduler: KarrasDiffusionSchedulers,\n        safety_checker: StableDiffusionSafetyChecker,\n        feature_extractor: CLIPImageProcessor,\n        clip_camera_projection: CLIPCameraProjection,\n        requires_safety_checker: bool = True,\n    ):\n        super().__init__()\n\n        if safety_checker is None and requires_safety_checker:\n            logger.warning(\n                f\"You have 
disabled the safety checker for {self.__class__} by passing `safety_checker=None`. 
Ensure\"\n                \" that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered\"\n                \" results in services or applications open to the public. Both the diffusers team and Hugging Face\"\n                \" strongly recommend keeping the safety filter enabled in all public-facing circumstances, disabling\"\n                \" it only for use-cases that involve analyzing network behavior or auditing its results. For more\"\n                \" information, please have a look at https://github.com/huggingface/diffusers/pull/254 .\"\n            )\n\n        if safety_checker is not None and feature_extractor is None:\n            raise ValueError(\n                f\"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety\"\n                \" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead.\"\n            )\n\n        is_unet_version_less_0_9_0 = hasattr(\n            unet.config, \"_diffusers_version\"\n        ) and version.parse(\n            version.parse(unet.config._diffusers_version).base_version\n        ) < version.parse(\n            \"0.9.0.dev0\"\n        )\n        is_unet_sample_size_less_64 = (\n            hasattr(unet.config, \"sample_size\") and unet.config.sample_size < 64\n        )\n        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:\n            deprecation_message = (\n                \"The configuration file of the unet has set the default `sample_size` to smaller than\"\n                \" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the\"\n                \" following: \\n- CompVis/stable-diffusion-v1-4 \\n- CompVis/stable-diffusion-v1-3 \\n-\"\n                \" CompVis/stable-diffusion-v1-2 \\n- CompVis/stable-diffusion-v1-1 \\n- runwayml/stable-diffusion-v1-5\"\n                \" \\n- runwayml/stable-diffusion-inpainting \\n 
you should change 'sample_size' to 64 in the\"\n                \" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`\"\n                \" in the config might lead to incorrect results in future versions. If you have downloaded this\"\n                \" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for\"\n                \" the `unet/config.json` file\"\n            )\n            deprecate(\n                \"sample_size<64\", \"1.0.0\", deprecation_message, standard_warn=False\n            )\n            new_config = dict(unet.config)\n            new_config[\"sample_size\"] = 64\n            unet._internal_dict = FrozenDict(new_config)\n\n        self.register_modules(\n            vae=vae,\n            image_encoder=image_encoder,\n            unet=unet,\n            scheduler=scheduler,\n            safety_checker=safety_checker,\n            feature_extractor=feature_extractor,\n            clip_camera_projection=clip_camera_projection,\n        )\n        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)\n        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)\n        self.register_to_config(requires_safety_checker=requires_safety_checker)\n\n    def enable_sequential_cpu_offload(self, gpu_id=0):\n        r\"\"\"\n        Offloads all models to CPU using accelerate, significantly reducing memory usage. 
When called, unet,\n        image_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a\n        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.\n        \"\"\"\n        if is_accelerate_available():\n            from accelerate import cpu_offload\n        else:\n            raise ImportError(\"Please install accelerate via `pip install accelerate`\")\n\n        device = torch.device(f\"cuda:{gpu_id}\")\n\n        for cpu_offloaded_model in [\n            self.unet,\n            self.image_encoder,\n            self.vae,\n            self.safety_checker,\n        ]:\n            if cpu_offloaded_model is not None:\n                cpu_offload(cpu_offloaded_model, device)\n\n    @property\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device\n    def _execution_device(self):\n        r\"\"\"\n        Returns the device on which the pipeline's models will be executed. 
After calling\n        `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module\n        hooks.\n        \"\"\"\n        if not hasattr(self.unet, \"_hf_hook\"):\n            return self.device\n        for module in self.unet.modules():\n            if (\n                hasattr(module, \"_hf_hook\")\n                and hasattr(module._hf_hook, \"execution_device\")\n                and module._hf_hook.execution_device is not None\n            ):\n                return torch.device(module._hf_hook.execution_device)\n        return self.device\n\n    def _encode_image(\n        self,\n        image,\n        elevation,\n        azimuth,\n        distance,\n        device,\n        num_images_per_prompt,\n        do_classifier_free_guidance,\n        clip_image_embeddings=None,\n        image_camera_embeddings=None,\n    ):\n        dtype = next(self.image_encoder.parameters()).dtype\n\n        if image_camera_embeddings is None:\n            if image is None:\n                assert clip_image_embeddings is not None\n                image_embeddings = clip_image_embeddings.to(device=device, dtype=dtype)\n            else:\n                if not isinstance(image, torch.Tensor):\n                    image = self.feature_extractor(\n                        images=image, return_tensors=\"pt\"\n                    ).pixel_values\n\n                image = image.to(device=device, dtype=dtype)\n                image_embeddings = self.image_encoder(image).image_embeds\n                image_embeddings = image_embeddings.unsqueeze(1)\n\n            bs_embed, seq_len, _ = image_embeddings.shape\n\n            if isinstance(elevation, float):\n                elevation = torch.as_tensor(\n                    [elevation] * bs_embed, dtype=dtype, device=device\n                )\n            if isinstance(azimuth, float):\n                azimuth = torch.as_tensor(\n                    [azimuth] * bs_embed, dtype=dtype, 
device=device\n                )\n            if isinstance(distance, float):\n                distance = torch.as_tensor(\n                    [distance] * bs_embed, dtype=dtype, device=device\n                )\n\n            camera_embeddings = torch.stack(\n                [\n                    torch.deg2rad(elevation),\n                    torch.sin(torch.deg2rad(azimuth)),\n                    torch.cos(torch.deg2rad(azimuth)),\n                    distance,\n                ],\n                dim=-1,\n            )[:, None, :]\n\n            image_embeddings = torch.cat([image_embeddings, camera_embeddings], dim=-1)\n\n            # project (image, camera) embeddings to the same dimension as clip embeddings\n            image_embeddings = self.clip_camera_projection(image_embeddings)\n        else:\n            image_embeddings = image_camera_embeddings.to(device=device, dtype=dtype)\n            bs_embed, seq_len, _ = image_embeddings.shape\n\n        # duplicate image embeddings for each generation per prompt, using mps friendly method\n        image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)\n        image_embeddings = image_embeddings.view(\n            bs_embed * num_images_per_prompt, seq_len, -1\n        )\n\n        if do_classifier_free_guidance:\n            negative_prompt_embeds = torch.zeros_like(image_embeddings)\n\n            # For classifier free guidance, we need to do two forward passes.\n            # Here we concatenate the unconditional and text embeddings into a single batch\n            # to avoid doing two forward passes\n            image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])\n\n        return image_embeddings\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker\n    def run_safety_checker(self, image, device, dtype):\n        if self.safety_checker is None:\n            has_nsfw_concept = None\n        
else:\n            if torch.is_tensor(image):\n                feature_extractor_input = self.image_processor.postprocess(\n                    image, output_type=\"pil\"\n                )\n            else:\n                feature_extractor_input = self.image_processor.numpy_to_pil(image)\n            safety_checker_input = self.feature_extractor(\n                feature_extractor_input, return_tensors=\"pt\"\n            ).to(device)\n            image, has_nsfw_concept = self.safety_checker(\n                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)\n            )\n        return image, has_nsfw_concept\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents\n    def decode_latents(self, latents):\n        warnings.warn(\n            \"The decode_latents method is deprecated and will be removed in a future version. Please\"\n            \" use VaeImageProcessor instead\",\n            FutureWarning,\n        )\n        latents = 1 / self.vae.config.scaling_factor * latents\n        image = self.vae.decode(latents, return_dict=False)[0]\n        image = (image / 2 + 0.5).clamp(0, 1)\n        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16\n        image = image.cpu().permute(0, 2, 3, 1).float().numpy()\n        return image\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs\n    def prepare_extra_step_kwargs(self, generator, eta):\n        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature\n        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.\n        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502\n        # and should be between [0, 1]\n\n        accepts_eta = \"eta\" in set(\n            
inspect.signature(self.scheduler.step).parameters.keys()\n        )\n        extra_step_kwargs = {}\n        if accepts_eta:\n            extra_step_kwargs[\"eta\"] = eta\n\n        # check if the scheduler accepts generator\n        accepts_generator = \"generator\" in set(\n            inspect.signature(self.scheduler.step).parameters.keys()\n        )\n        if accepts_generator:\n            extra_step_kwargs[\"generator\"] = generator\n        return extra_step_kwargs\n\n    def check_inputs(self, image, height, width, callback_steps):\n        # TODO: check image size or adjust image size to (height, width)\n\n        if height % 8 != 0 or width % 8 != 0:\n            raise ValueError(\n                f\"`height` and `width` have to be divisible by 8 but are {height} and {width}.\"\n            )\n\n        if (callback_steps is None) or (\n            callback_steps is not None\n            and (not isinstance(callback_steps, int) or callback_steps <= 0)\n        ):\n            raise ValueError(\n                f\"`callback_steps` has to be a positive integer but is {callback_steps} of type\"\n                f\" {type(callback_steps)}.\"\n            )\n\n    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents\n    def prepare_latents(\n        self,\n        batch_size,\n        num_channels_latents,\n        height,\n        width,\n        dtype,\n        device,\n        generator,\n        latents=None,\n    ):\n        shape = (\n            batch_size,\n            num_channels_latents,\n            height // self.vae_scale_factor,\n            width // self.vae_scale_factor,\n        )\n        if isinstance(generator, list) and len(generator) != batch_size:\n            raise ValueError(\n                f\"You have passed a list of generators of length {len(generator)}, but requested an effective batch\"\n                f\" size of {batch_size}. 
Make sure the batch size matches the length of the generators.\"\n            )\n\n        if latents is None:\n            latents = randn_tensor(\n                shape, generator=generator, device=device, dtype=dtype\n            )\n        else:\n            latents = latents.to(device)\n\n        # scale the initial noise by the standard deviation required by the scheduler\n        latents = latents * self.scheduler.init_noise_sigma\n        return latents\n\n    def _get_latent_model_input(\n        self,\n        latents: torch.FloatTensor,\n        image: Optional[\n            Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor]\n        ],\n        num_images_per_prompt: int,\n        do_classifier_free_guidance: bool,\n        image_latents: Optional[torch.FloatTensor] = None,\n    ):\n        if isinstance(image, PIL.Image.Image):\n            image_pt = TF.to_tensor(image).unsqueeze(0).to(latents)\n        elif isinstance(image, list):\n            image_pt = torch.stack([TF.to_tensor(img) for img in image], dim=0).to(\n                latents\n            )\n        elif isinstance(image, torch.Tensor):\n            image_pt = image\n        else:\n            image_pt = None\n\n        if image_pt is None:\n            assert image_latents is not None\n            image_pt = image_latents.repeat_interleave(num_images_per_prompt, dim=0)\n        else:\n            image_pt = image_pt * 2.0 - 1.0  # scale to [-1, 1]\n            # FIXME: encoded latents should be multiplied with self.vae.config.scaling_factor\n            # but zero123 was not trained this way\n            image_pt = self.vae.encode(image_pt).latent_dist.mode()\n            image_pt = image_pt.repeat_interleave(num_images_per_prompt, dim=0)\n        if do_classifier_free_guidance:\n            latent_model_input = torch.cat(\n                [\n                    torch.cat([latents, latents], dim=0),\n                    torch.cat([torch.zeros_like(image_pt), image_pt], 
dim=0),\n                ],\n                dim=1,\n            )\n        else:\n            latent_model_input = torch.cat([latents, image_pt], dim=1)\n\n        return latent_model_input\n\n    @torch.no_grad()\n    def __call__(\n        self,\n        image: Optional[\n            Union[PIL.Image.Image, List[PIL.Image.Image], torch.FloatTensor]\n        ] = None,\n        elevation: Optional[Union[float, torch.FloatTensor]] = None,\n        azimuth: Optional[Union[float, torch.FloatTensor]] = None,\n        distance: Optional[Union[float, torch.FloatTensor]] = None,\n        height: Optional[int] = None,\n        width: Optional[int] = None,\n        num_inference_steps: int = 50,\n        guidance_scale: float = 3.0,\n        num_images_per_prompt: int = 1,\n        eta: float = 0.0,\n        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,\n        latents: Optional[torch.FloatTensor] = None,\n        clip_image_embeddings: Optional[torch.FloatTensor] = None,\n        image_camera_embeddings: Optional[torch.FloatTensor] = None,\n        image_latents: Optional[torch.FloatTensor] = None,\n        output_type: Optional[str] = \"pil\",\n        return_dict: bool = True,\n        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,\n        callback_steps: int = 1,\n        cross_attention_kwargs: Optional[Dict[str, Any]] = None,\n    ):\n        r\"\"\"\n        Function invoked when calling the pipeline for generation.\n\n        Args:\n            image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):\n                The image or images to guide the image generation. 
If you provide a tensor, it needs to comply with the\n                configuration of\n                [this](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json)\n                `CLIPImageProcessor`.\n            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\n                The height in pixels of the generated image.\n            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):\n                The width in pixels of the generated image.\n            num_inference_steps (`int`, *optional*, defaults to 50):\n                The number of denoising steps. More denoising steps usually lead to a higher quality image at the\n                expense of slower inference.\n            guidance_scale (`float`, *optional*, defaults to 3.0):\n                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).\n                `guidance_scale` is defined as `w` of equation (2) of [Imagen\n                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >\n                1`. A higher guidance scale encourages the model to generate images that are closely linked to the\n                input `image`, usually at the expense of lower image quality.\n            num_images_per_prompt (`int`, *optional*, defaults to 1):\n                The number of images to generate per input image.\n            eta (`float`, *optional*, defaults to 0.0):\n                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502.
Only applies to\n                [`schedulers.DDIMScheduler`] and is ignored for other schedulers.\n            generator (`torch.Generator`, *optional*):\n                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)\n                to make generation deterministic.\n            latents (`torch.FloatTensor`, *optional*):\n                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image\n                generation. Can be used to tweak the same generation with different input images. If not provided, a\n                latents tensor will be generated by sampling using the supplied random `generator`.\n            output_type (`str`, *optional*, defaults to `\"pil\"`):\n                The output format of the generated image. Choose between\n                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.\n            return_dict (`bool`, *optional*, defaults to `True`):\n                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a\n                plain tuple.\n            callback (`Callable`, *optional*):\n                A function that will be called every `callback_steps` steps during inference. The function will be\n                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.\n            callback_steps (`int`, *optional*, defaults to 1):\n                The frequency at which the `callback` function will be called.
If not specified, the callback will be\n                called at every step.\n\n        Returns:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:\n            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.\n            When returning a tuple, the first element is a list with the generated images, and the second element is a\n            list of `bool`s denoting whether the corresponding generated image likely represents \"not-safe-for-work\"\n            (nsfw) content, according to the `safety_checker`.\n        \"\"\"\n        # 0. Default height and width to unet\n        height = height or self.unet.config.sample_size * self.vae_scale_factor\n        width = width or self.unet.config.sample_size * self.vae_scale_factor\n\n        # 1. Check inputs. Raise an error if not correct\n        # TODO: check input elevation, azimuth, and distance\n        # TODO: check image, clip_image_embeddings, image_latents\n        self.check_inputs(image, height, width, callback_steps)\n\n        # 2. Define call parameters\n        if isinstance(image, PIL.Image.Image):\n            batch_size = 1\n        elif isinstance(image, list):\n            batch_size = len(image)\n        elif isinstance(image, torch.Tensor):\n            batch_size = image.shape[0]\n        else:\n            assert image_latents is not None\n            assert (\n                clip_image_embeddings is not None or image_camera_embeddings is not None\n            )\n            batch_size = image_latents.shape[0]\n\n        device = self._execution_device\n        # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)\n        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`\n        # corresponds to doing no classifier free guidance.\n        do_classifier_free_guidance = guidance_scale > 1.0\n\n        # 3.
Encode input image\n        if isinstance(image, PIL.Image.Image) or isinstance(image, list):\n            pil_image = image\n        elif isinstance(image, torch.Tensor):\n            pil_image = [TF.to_pil_image(image[i]) for i in range(image.shape[0])]\n        else:\n            pil_image = None\n        image_embeddings = self._encode_image(\n            pil_image,\n            elevation,\n            azimuth,\n            distance,\n            device,\n            num_images_per_prompt,\n            do_classifier_free_guidance,\n            clip_image_embeddings,\n            image_camera_embeddings,\n        )\n\n        # 4. Prepare timesteps\n        self.scheduler.set_timesteps(num_inference_steps, device=device)\n        timesteps = self.scheduler.timesteps\n\n        # 5. Prepare latent variables\n        # num_channels_latents = self.unet.config.in_channels\n        num_channels_latents = 4  # FIXME: hard-coded\n        latents = self.prepare_latents(\n            batch_size * num_images_per_prompt,\n            num_channels_latents,\n            height,\n            width,\n            image_embeddings.dtype,\n            device,\n            generator,\n            latents,\n        )\n\n        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline\n        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)\n\n        # 7. 
Denoising loop\n        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order\n        with self.progress_bar(total=num_inference_steps) as progress_bar:\n            for i, t in enumerate(timesteps):\n                # expand the latents if we are doing classifier free guidance\n                latent_model_input = self._get_latent_model_input(\n                    latents,\n                    image,\n                    num_images_per_prompt,\n                    do_classifier_free_guidance,\n                    image_latents,\n                )\n                latent_model_input = self.scheduler.scale_model_input(\n                    latent_model_input, t\n                )\n\n                # predict the noise residual\n                noise_pred = self.unet(\n                    latent_model_input,\n                    t,\n                    encoder_hidden_states=image_embeddings,\n                    cross_attention_kwargs=cross_attention_kwargs,\n                ).sample\n\n                # perform guidance\n                if do_classifier_free_guidance:\n                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)\n                    noise_pred = noise_pred_uncond + guidance_scale * (\n                        noise_pred_text - noise_pred_uncond\n                    )\n\n                # compute the previous noisy sample x_t -> x_t-1\n                latents = self.scheduler.step(\n                    noise_pred, t, latents, **extra_step_kwargs\n                ).prev_sample\n\n                # call the callback, if provided\n                if i == len(timesteps) - 1 or (\n                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0\n                ):\n                    progress_bar.update()\n                    if callback is not None and i % callback_steps == 0:\n                        callback(i, t, latents)\n\n        if not output_type == \"latent\":\n            image = 
self.vae.decode(\n                latents / self.vae.config.scaling_factor, return_dict=False\n            )[0]\n            image, has_nsfw_concept = self.run_safety_checker(\n                image, device, image_embeddings.dtype\n            )\n        else:\n            image = latents\n            has_nsfw_concept = None\n\n        if has_nsfw_concept is None:\n            do_denormalize = [True] * image.shape[0]\n        else:\n            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]\n\n        image = self.image_processor.postprocess(\n            image, output_type=output_type, do_denormalize=do_denormalize\n        )\n\n        if not return_dict:\n            return (image, has_nsfw_concept)\n\n        return StableDiffusionPipelineOutput(\n            images=image, nsfw_content_detected=has_nsfw_concept\n        )"
  }
]