Repository: zoomin-lee/scene-scale-diffusion
Branch: main
Commit: a23745f37edd
Files: 30
Total size: 163.3 KB

Directory structure:
scene-scale-diffusion/

├── .gitignore
├── LICENSE.txt
├── README.md
├── SSC_train.py
├── __init__.py
├── datasets/
│   ├── carla.yaml
│   ├── carla_dataset.py
│   └── data.py
├── layers/
│   ├── Ablation/
│   │   └── wo_diffusion.py
│   ├── Latent_Level/
│   │   ├── stage1/
│   │   │   ├── model.py
│   │   │   ├── vector_quantizer.py
│   │   │   └── vqvae.py
│   │   └── stage2/
│   │       ├── Gen_diffusion.py
│   │       └── gen_denoise.py
│   ├── Voxel_Level/
│   │   ├── Con_Diffusion.py
│   │   ├── Gen_Diffusion.py
│   │   ├── denoise.py
│   │   └── gen_denoise.py
│   └── __init__.py
├── requirements.txt
├── setup.py
├── simple_visualize.py
├── train.py
├── utils/
│   ├── cuda.py
│   ├── dicts.py
│   ├── intermediate_vis.py
│   ├── loss.py
│   ├── multistep.py
│   └── tables.py
└── visualization.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
build/
dist/
*.egg-info/
.eggs/

# Virtual environment
venv/
env/
.venv/
.env/

# Jupyter Notebook checkpoints
.ipynb_checkpoints/

# PyInstaller
*.manifest
*.spec

# pytest
.cache/
.pytest_cache/

# mypy
.mypy_cache/

# coverage
htmlcov/
.coverage
.coverage.*

# logs and temporary files
*.log
*.tmp
*.bak

# IDEs and editors
.vscode/
.idea/
*.sublime-workspace
*.sublime-project

# OS files
.DS_Store
Thumbs.db

# dotenv / secrets
.env
.env.*



================================================
FILE: LICENSE.txt
================================================
MIT License

Copyright (c) 2023 jumin

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


================================================
FILE: README.md
================================================
# Diffusion Probabilistic Models for Scene-Scale 3D Categorical Data

📌[Paper](http://arxiv.org/abs/2301.00527)        

<img src=https://user-images.githubusercontent.com/65997635/210452550-2c7c7c6d-7260-43ce-b4b6-18d3f15fccde.png width="480"
  height="400">

Comparison of object-scale and scene-scale generation (ours). Our result includes multiple objects in a generated scene,
while object-scale generation crafts one object at a time. (a) is obtained with [Point-E](https://github.com/openai/point-e).

## Abstract
In this paper, we learn a diffusion model to generate 3D data on a scene-scale. Specifically, our model crafts a 3D scene consisting of multiple objects, while recent diffusion research has focused on a single object. To realize our goal, we represent a scene with discrete class labels, i.e., categorical distribution, to assign multiple objects into semantic categories. Thus, we extend discrete diffusion models to learn scene-scale categorical distributions. In addition, we validate that a latent diffusion model can reduce computation costs for training and deploying. To the best of our knowledge, our work is the first to apply discrete and latent diffusion for 3D categorical data on a scene-scale. We further propose to perform semantic scene completion (SSC) by learning a conditional distribution using our diffusion model, where the condition is a partial observation in a sparse point cloud. In experiments, we empirically show that our diffusion models not only generate reasonable scenes, but also perform the scene completion task better than a discriminative model. 


## Instructions
### Dataset
: We use [CarlaSC](https://umich-curly.github.io/CarlaSC.github.io/download/) cartesian dataset.
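
Based on the loaders in `datasets/data.py` and `datasets/carla_dataset.py`, the dataset directory is expected to look roughly as follows (the split and sub-directory names come from the code; scene names are illustrative):

    CarlaSC/
    ├── Train/
    │   └── <scene>/
    │       └── cartesian/
    │           ├── velodyne/      # 000000.bin, ... (point clouds)
    │           ├── labels/        # 000000.label, ...
    │           ├── predictions/   # 000000.bin, ...
    │           ├── evaluation/    # params.json, 000000.label, 000000.bin, ...
    │           ├── times.txt
    │           └── poses.txt
    ├── Val/
    └── Test/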

### Training
: Training options are defined via `argparse` in `SSC_train.py`. Note that `--dataset_dir` (the path to the CarlaSC dataset) is required.

    python SSC_train.py --dataset_dir /path/to/CarlaSC

- For **multi-GPU**: `--distribution True`
- For **Discrete Diffusion Model**: `--mode gen/con/vis`
- For **Latent Diffusion Model**: `--mode l_vae/l_gen --l_size 882/16162/32322 --init_size 32 --l_attention True --vq_size 100`

Example for training in `l_gen` mode:

    python SSC_train.py --mode l_gen --vq_size 100 --l_size 32322 --init_size 32 --l_attention True --dataset_dir /path/to/CarlaSC --log_path ./result --vqvae_path ./lst_stage.tar
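
Since `l_gen` loads a pretrained first-stage VQ-VAE checkpoint via `--vqvae_path`, that checkpoint can presumably be produced first by training in `l_vae` mode with matching latent settings (the dataset path below is a placeholder):

    python SSC_train.py --mode l_vae --vq_size 100 --l_size 32322 --init_size 32 --l_attention True --dataset_dir /path/to/CarlaSC --log_path ./result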


### Visualization
: We save the result to a txt file using the `visualization` function in `utils/tables.py`.
If you use [Open3D](http://www.open3d.org/), you will be able to easily visualize it.
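
As a sketch of that workflow (this assumes each row of the saved txt file is `x y z label`, which is a guess about the format produced by `utils/tables.py`; the colors come from `remap_color_map` in `datasets/carla.yaml`):

```python
import numpy as np

# remap_color_map from datasets/carla.yaml (label -> RGB)
REMAP_COLOR_MAP = {
    0: [255, 255, 255], 1: [255, 200, 0], 2: [255, 120, 50],
    3: [55, 90, 80],    4: [255, 30, 30], 5: [255, 240, 150],
    6: [255, 0, 255],   7: [175, 0, 75],  8: [75, 0, 75],
    9: [0, 175, 0],     10: [100, 150, 245],
}

def load_voxels(path):
    """Load saved voxels; assumes one 'x y z label' row per occupied voxel."""
    data = np.loadtxt(path).reshape(-1, 4)
    points = data[:, :3]
    labels = data[:, 3].astype(np.int64)
    # Map each semantic label to its RGB color, normalized to [0, 1]
    colors = np.asarray([REMAP_COLOR_MAP[int(l)] for l in labels], dtype=np.float64) / 255.0
    return points, labels, colors

# With Open3D installed, the voxels can then be shown as a colored point cloud:
# import open3d as o3d
# points, labels, colors = load_voxels("result.txt")
# pcd = o3d.geometry.PointCloud()
# pcd.points = o3d.utility.Vector3dVector(points)
# pcd.colors = o3d.utility.Vector3dVector(colors)
# o3d.visualization.draw_geometries([pcd])
```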

## Result
### 3D Scene Generation
![image](https://github.com/zoomin-lee/scene-scale-diffusion/blob/main/images/3D_scene_generation.png?raw=true)

### Semantic Scene Completion
![image](https://github.com/zoomin-lee/scene-scale-diffusion/blob/main/images/table4.PNG?raw=true)


![image](https://github.com/zoomin-lee/scene-scale-diffusion/blob/main/images/semantic_scene_completion.png?raw=true)


## Acknowledgments
This project is based on the following codebases.
- [Multinomial Diffusion](https://github.com/ehoogeboom/multinomial_diffusion/tree/9d907a60536ad793efd6d2a6067b3c3d6ba9fce7)
- [MotionSC](https://github.com/UMich-CURLY/3DMapping)
- [Cylinder3D](https://github.com/xinge008/Cylinder3D)


================================================
FILE: SSC_train.py
================================================
import argparse
import os
import warnings
import time
import torch
from utils.intermediate_vis import Vis_iter

from datasets.data import *
from utils.cuda import launch
from utils.multistep import get_optim
from train import Experiment

from layers.Voxel_Level.Gen_Diffusion import Diffusion
from layers.Voxel_Level.Con_Diffusion import Con_Diffusion

from layers.Latent_Level.stage1.vqvae import vqvae
from layers.Latent_Level.stage2.Gen_diffusion import latent_diffusion

from layers.Ablation.wo_diffusion import wo_diff

# environment variables
NODE_RANK = int(os.environ.get('AZ_BATCHAI_TASK_INDEX', 0))
MASTER_ADDR, MASTER_PORT = os.environ['AZ_BATCH_MASTER_NODE'].split(':') if 'AZ_BATCH_MASTER_NODE' in os.environ else ("127.0.0.1", 29500)
MASTER_PORT = int(MASTER_PORT)
DIST_URL = 'tcp://%s:%s' % (MASTER_ADDR, MASTER_PORT)

def get_args():
    ###########
    ## Setup ##
    ###########
    parser = argparse.ArgumentParser()

    parser.add_argument('--gpu', type=int, default=None, help='GPU id to use. If given, only the specific gpu will be used, and ddp will be disabled')
    parser.add_argument('--distribution', type=eval, default=True, choices=[True, False])
    parser.add_argument('--num_node', type=int, default=1, help='number of nodes for distributed training')
    parser.add_argument('--node_rank', type=int, default=0, help='node rank for distributed training')
    parser.add_argument('--dist_url', type=str, default='tcp://127.0.0.1:29500', help='url used to set up distributed training')
    
    # Data params
    parser.add_argument('--dataset', type=str, default='carla', choices=['carla'])
    parser.add_argument('--dataset_dir', type=str, required=True, help='Path to the dataset directory')
    # Train params
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--pin_memory', type=eval, default=False)
    parser.add_argument('--augmentation', type=str, default=None)

    # Experiment params
    parser.add_argument('--clip_value', type=float, default=None)
    parser.add_argument('--clip_norm', type=float, default=None)
    parser.add_argument('--recon_loss', default=False)
    parser.add_argument('--mode', default='wo_diff', choices=['gen', 'con', 'vis', 'l_vae', 'l_gen', 'wo_diff'])
    parser.add_argument('--l_size', default='32322', choices=['882', '16162', '32322'])
    parser.add_argument('--init_size', type=int, default=8)
    parser.add_argument('--l_attention', default=True)
    parser.add_argument('--vq_size', type=int, default=50)

    # Model params
    parser.add_argument('--auxiliary_loss_weight', type=float, default=0.0005)
    parser.add_argument('--diffusion_steps', type=int, default=100)
    parser.add_argument('--diffusion_dim', type=int, default=32)
    parser.add_argument('--dp_rate', type=float, default=0.)

    # Optim params
    parser.add_argument('--optimizer', type=str, default='adam')
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--warmup', type=int, default=None)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--momentum_sqr', type=float, default=0.999)
    parser.add_argument('--milestones', type=eval, default=[])
    parser.add_argument('--gamma', type=float, default=0.1)

    # Train params
    parser.add_argument('--epochs', type=int, default=5000)
    parser.add_argument('--resume', type=str, default=False)
    parser.add_argument('--resume_path', type=str, default='')
    parser.add_argument('--vqvae_path', type=str, default='')

    # Logging params
    parser.add_argument('--eval_every', type=int, default=10)
    parser.add_argument('--check_every', type=int, default=5)
    parser.add_argument('--completion_epoch', type=int, default=20)
    parser.add_argument('--log_tb', type=eval, default=True)
    parser.add_argument('--log_home', type=str, default=None)
    parser.add_argument('--log_path', type=str, default='')

    args = parser.parse_args()
    return args


def main():
    print('start!')
    args = get_args()

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely disable ddp.')
        torch.cuda.set_device(args.gpu)
        args.ngpus_per_node = 1
        args.world_size = 1
    else:
        if args.num_node == 1:
            args.dist_url = "auto"
        else:
            assert args.num_node > 1
        args.ngpus_per_node = torch.cuda.device_count()
        args.world_size = args.ngpus_per_node * args.num_node

    launch(start, args.ngpus_per_node, args.num_node, args.node_rank, args.dist_url, args=(args,))


def start(local_rank, args):
    args.local_rank = local_rank
    args.global_rank = args.local_rank + args.node_rank * args.ngpus_per_node
    args.distributed = args.world_size > 1

    ##################
    ## Specify data ##
    ##################
    train_loader, eval_loader, test_loader, num_classes, comp_weights, seg_weights, train_sampler = get_data(args)
    args.num_classes = num_classes

    completion_criterion = torch.nn.CrossEntropyLoss(weight=comp_weights)
    seg_criterion = torch.nn.CrossEntropyLoss(weight=seg_weights, ignore_index=0)
    similarity_criterion = torch.nn.MSELoss()

    #######################
    ## Without Diffusion ##
    #######################
    if args.mode == 'wo_diff':
        model = wo_diff(args, completion_criterion).cuda()
        if args.distribution :
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False)

    ########################
    ## Discrete Diffusion ##
    ########################
    elif args.mode == 'gen':
        model = Diffusion(args, completion_criterion, auxiliary_loss_weight=args.auxiliary_loss_weight).cuda()
        if args.distribution :
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False)

    elif args.mode == 'con':
        model = Con_Diffusion(args, completion_criterion, auxiliary_loss_weight=args.auxiliary_loss_weight).cuda()
        if args.distribution :
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False)
    
    ######################
    ## Latent Diffusion ##
    ######################
    elif args.mode == 'l_vae':
        model = vqvae(args, completion_criterion).cuda()
        if args.distribution:
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False)

    elif args.mode == 'l_gen':
        Dense = vqvae(args, completion_criterion).cuda()
        dense_check = torch.load(args.vqvae_path)
        model = latent_diffusion(args, Dense, completion_criterion, auxiliary_loss_weight=args.auxiliary_loss_weight).cuda()
        if args.distribution:
            Dense = torch.nn.parallel.DistributedDataParallel(Dense, device_ids=[args.gpu], find_unused_parameters=False)
            Dense.module.load_state_dict(dense_check['model'])
            for p in Dense.module.parameters():
                p.requires_grad = False
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False)
        else:
            # Also load and freeze the pretrained VQ-VAE when not using DDP
            Dense.load_state_dict(dense_check['model'])
            for p in Dense.parameters():
                p.requires_grad = False
            
    ###################
    ## Visualization ##
    ###################
    elif args.mode == 'vis':
        model = Con_Diffusion(args, completion_criterion, auxiliary_loss_weight=args.auxiliary_loss_weight).cuda()
        if args.distribution :
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False)

    optimizer, scheduler_iter, scheduler_epoch = get_optim(args, model)
    if args.mode == 'vis':
        exp = Vis_iter(args, model, optimizer, scheduler_iter, scheduler_epoch, test_loader, args.log_path)
    
    else : 
        exp = Experiment(args, model, optimizer, scheduler_iter, scheduler_epoch,
                        train_loader, eval_loader, test_loader, train_sampler, 
                        args.log_path, args.eval_every, args.check_every)
    
    exp.run(epochs = args.epochs)

if __name__ == '__main__':
    main()


================================================
FILE: __init__.py
================================================


================================================
FILE: datasets/carla.yaml
================================================
color_map :
  0 : [255, 255, 255]  # None
  1 : [70, 70, 70]     # Building
  2 : [100, 40, 40]    # Fences
  3 : [55, 90, 80]     # Other
  4 : [255, 255, 0 ]   # Pedestrian
  5 : [153, 153, 153]  # Pole
  6 : [157, 234, 50]   # RoadLines
  7 : [0, 0, 255]      # Road
  8 : [255, 255, 255]  # Sidewalk
  9 : [0, 155, 0]      # Vegetation
  10 : [255, 0, 0]     # Vehicle
  11 : [102, 102, 156] # Wall
  12 : [220, 220, 0]   # TrafficSign
  13 : [70, 130, 180]  # Sky
  14 : [255, 255, 255] # Ground
  15 : [150, 100, 100] # Bridge
  16 : [230, 150, 140] # RailTrack
  17 : [180, 165, 180] # GuardRail
  18 : [250, 170, 30]  # TrafficLight
  19 : [110, 190, 160] # Static
  20 : [170, 120, 50]  # Dynamic
  21 : [45, 60, 150]   # Water
  22 : [145, 170, 100] # Terrain

learning_map :
  0 : 0
  1 : 1
  2 : 2
  3 : 3
  4 : 4
  5 : 5
  6 : 6
  7 : 6
  8 : 8
  9 : 9
  10: 10
  11 : 2
  12 : 5
  13 : 3
  14 : 7
  15 : 3
  16 : 3
  17 : 2
  18 : 5
  19 : 3
  20 : 3
  21 : 3
  22 : 7

remap_color_map:
  0 : [255, 255, 255]  # None
  1 : [255, 200, 0]    # Building
  2 : [255, 120, 50]   # Fences
  3 : [55, 90, 80]     # Other
  4 : [255, 30, 30]    # Pedestrian
  5 : [255, 240, 150]  # Pole
  6 : [255, 0, 255]    # Road
  7 : [175, 0, 75]     # Ground
  8 : [75, 0, 75]      # Sidewalk
  9 : [0, 175, 0]      # Vegetation
  10 : [100, 150, 245] # Vehicle

label_to_names:
  0 : Free
  1 : Building
  2 : Barrier
  3 : Other
  4 : Pedestrian
  5 : Pole
  6 : Road
  7 : Ground
  8 : Sidewalk
  9 : Vegetation
  10 : Vehicle

content :
  0 : 4166593275
  1 : 42309744
  2 : 8550180
  3 : 478193
  4 : 905663
  5 : 2801091
  6 : 6452733
  7 : 229316930
  8 : 112863867
  9 : 29816894
  10: 13839655
  11 : 15581458
  12 : 221821
  13 : 0
  14 : 7931550
  15 : 467989
  16 : 3354
  17 : 9201043
  18 : 61011
  19 : 3796746
  20 : 3217865
  21 : 215372
  22 : 79669695

remap_content : 
  0 : 4.16659328e+09
  1 : 4.23097440e+07
  2 : 3.33326810e+07
  3 : 8.17951900e+06
  4 : 9.05663000e+05
  5 : 3.08392300e+06
  6 : 2.35769663e+08
  7 : 8.76012450e+07
  8 : 1.12863867e+08
  9 : 2.98168940e+07
  10 : 1.38396550e+07

================================================
FILE: datasets/carla_dataset.py
================================================
import os
import numpy as np
import random
import json
import yaml
import torch
import numba as nb
from torch.utils.data import Dataset

base_dir = os.path.dirname(__file__)
config_file = os.path.join(base_dir, 'carla.yaml')
carla_config = yaml.safe_load(open(config_file, 'r'))
LABELS_REMAP = carla_config["learning_map"]
REMAP_FREQUENCIES = carla_config["remap_content"]
FREQUENCIES= carla_config["content"]

LABELS_REMAP = np.asarray(list(LABELS_REMAP.values()))
frequencies_cartesian = np.asarray(list(FREQUENCIES.values()))
remap_frequencies_cartesian = np.asarray(list(REMAP_FREQUENCIES.values()))

class CarlaDataset(Dataset):
    """Carla Simulation Dataset for 3D mapping project
    Access to the processed data, including evaluation labels predictions velodyne poses times
    """
    def __init__(self, directory,
        voxelize_input=True,
        binary_counts=True,
        random_flips=False,
        remap=True,
        num_frames=1,
        transform_pose=True,
        get_gt=True,
        ):
        '''Constructor.
        Parameters:
            directory: directory to the dataset
        '''
        self.get_gt = get_gt
        self.voxelize_input = voxelize_input
        self.binary_counts = binary_counts
        self._directory = directory
        self._num_frames = num_frames
        self.random_flips = random_flips
        self.remap = remap
        self.transform_pose = transform_pose
        self.sparse_output = True
        
        self._scenes = sorted(os.listdir(self._directory))
        self._scenes = [os.path.join(scene, "cartesian") for scene in self._scenes]

        self._num_scenes = len(self._scenes)
        self._num_frames_scene = []

        param_file = os.path.join(self._directory, self._scenes[0], 'evaluation', 'params.json')
        with open(param_file) as f:
            self._eval_param = json.load(f)
        
        self._out_dim = self._eval_param['num_channels']
        self._grid_size = self._eval_param['grid_size']
        self.grid_dims = np.asarray(self._grid_size)
        self._eval_size = list(np.uint32(self._grid_size))
        
        self.coor_ranges = self._eval_param['min_bound'] + self._eval_param['max_bound']
        self.voxel_sizes = [abs(self.coor_ranges[3] - self.coor_ranges[0]) / self._grid_size[0], 
                      abs(self.coor_ranges[4] - self.coor_ranges[1]) / self._grid_size[1],
                      abs(self.coor_ranges[5] - self.coor_ranges[2]) / self._grid_size[2]]
        self.min_bound = np.asarray(self.coor_ranges[:3])
        self.max_bound = np.asarray(self.coor_ranges[3:])
        self.voxel_sizes = np.asarray(self.voxel_sizes)

        self._velodyne_list = []
        self._label_list = []
        self._pred_list = []
        self._eval_labels = []
        self._eval_counts = []
        self._frames_list = []
        self._timestamps = []
        self._poses = [] 

        for scene in self._scenes:
            velodyne_dir = os.path.join(self._directory, scene, 'velodyne')
            label_dir = os.path.join(self._directory, scene, 'labels')
            pred_dir = os.path.join(self._directory, scene, 'predictions')
            eval_dir = os.path.join(self._directory, scene, 'evaluation')
            
            self._num_frames_scene.append(len(os.listdir(velodyne_dir)))

            frames_list = [os.path.splitext(filename)[0] for filename in sorted(os.listdir(velodyne_dir))]
            self._frames_list.extend(frames_list)
            self._velodyne_list.extend([os.path.join(velodyne_dir, str(frame).zfill(6)+'.bin') for frame in frames_list])
            self._label_list.extend([os.path.join(label_dir, str(frame).zfill(6)+'.label') for frame in frames_list])
            self._pred_list.extend([os.path.join(pred_dir, str(frame).zfill(6)+'.bin') for frame in frames_list])
            self._eval_labels.extend([os.path.join(eval_dir, str(frame).zfill(6)+'.label') for frame in frames_list])
            self._eval_counts.extend([os.path.join(eval_dir, str(frame).zfill(6) + '.bin') for frame in frames_list])
            self._timestamps.append(np.loadtxt(os.path.join(self._directory, scene, 'times.txt')))
            self._poses.append(np.loadtxt(os.path.join(self._directory, scene, 'poses.txt')))
            # for poses and timestamps
        self._timestamps = np.array(self._timestamps).reshape(sum(self._num_frames_scene))
        self._poses = np.array(self._poses).reshape(sum(self._num_frames_scene), 12)
        
        self._cum_num_frames = np.cumsum(np.array(self._num_frames_scene) - self._num_frames + 1)

    # Use all frames, if there is no data then zero pad
    def __len__(self):
        return sum(self._num_frames_scene)
    
    def collate_fn(self, data):
        voxel_batch = [bi[0] for bi in data]
        output_batch = [bi[1] for bi in data]
        counts_batch = [bi[2] for bi in data]
        return voxel_batch, output_batch, counts_batch
    
    def points_to_voxels(self, voxel_grid, points, t_i):
        # Valid voxels (make sure to clip)
        voxels = np.floor((points - self.min_bound) / self.voxel_sizes).astype(np.int32)
        # Clamp to account for any floating point errors
        maxes = np.reshape(self.grid_dims - 1, (1, 3))
        mins = np.zeros_like(maxes)
        voxels = np.clip(voxels, mins, maxes).astype(np.int32)
        if self.binary_counts:
            # Fancy indexing writes duplicate voxel indices only once,
            # so this produces a binary occupancy grid
            voxel_grid[t_i, voxels[:, 0], voxels[:, 1], voxels[:, 2]] += 1
        else:
            # Accumulate the number of points falling into each voxel
            unique_voxels, counts = np.unique(voxels, return_counts=True, axis=0)
            unique_voxels = unique_voxels.astype(np.int32)
            voxel_grid[t_i, unique_voxels[:, 0], unique_voxels[:, 1], unique_voxels[:, 2]] += counts
        return voxel_grid

    def get_pose(self, idx):
        pose = np.zeros((4, 4))
        pose[3, 3] = 1
        pose[:3, :4] = self._poses[idx].reshape(3, 4)
        return pose

    def __getitem__(self, idx):
        # -1 indicates no data
        # the final index is the output
        idx_range = self.find_horizon(idx)
        if self.transform_pose:
            ego_pose = self.get_pose(idx_range[-1])
            to_ego = np.linalg.inv(ego_pose)
         
        if self.voxelize_input:
            voxel_input = np.zeros((idx_range.shape[0], int(self.grid_dims[0]), int(self.grid_dims[1]), int(self.grid_dims[2])), dtype=np.float32)
        t_i = 0

        for i in idx_range:
            if i == -1: # Zero pad
                points = np.zeros((1, 3), dtype=np.float32)
                valid_points = points  # nothing to filter for the padded frame

            else:
                points = np.fromfile(self._velodyne_list[i],dtype=np.float32).reshape(-1, 4)[:, :3]

                if self.transform_pose:
                    to_world = self.get_pose(i)
                    relative_pose = np.matmul(to_ego, to_world)
                    points = np.dot(relative_pose[:3, :3], points.T).T + relative_pose[:3, 3]

                valid_point_mask= np.all((points < self.max_bound) & (points >= self.min_bound), axis=1)
                valid_points = points[valid_point_mask, :]

            if self.voxelize_input:
                voxel_input = self.points_to_voxels(voxel_input, valid_points, t_i)

            t_i += 1

        if self.get_gt:
            output = np.fromfile(self._eval_labels[idx_range[-1]],dtype=np.uint32).reshape(self._eval_size).astype(np.uint8)
            counts = np.fromfile(self._eval_counts[idx_range[-1]],dtype=np.float32).reshape(self._eval_size)
        else:
            output = None
            counts = None

        if self.voxelize_input and self.random_flips:
            # X flip
            if np.random.randint(2):
                output = np.flip(output, axis=0)
                counts = np.flip(counts, axis=0)
                voxel_input = np.flip(voxel_input, axis=1) # Because there is a time dimension
            # Y Flip
            if np.random.randint(2):
                output = np.flip(output, axis=1)
                counts = np.flip(counts, axis=1)
                voxel_input = np.flip(voxel_input, axis=2) # Because there is a time dimension
                
        if self.remap:
            output = LABELS_REMAP[output].astype(np.uint8)            

        return voxel_input, output, counts
        
    def find_horizon(self, idx):
        # Indices without enough preceding frames are set to -1 (zero-padded later)
        end_idx = idx
        idx_range = np.arange(idx-self._num_frames, idx)+1
        diffs = np.asarray([int(self._frames_list[end_idx]) - int(self._frames_list[i]) for i in idx_range])
        good_difs = -1 * (np.arange(-self._num_frames, 0) + 1)
        
        idx_range[good_difs != diffs] = -1

        return idx_range


================================================
FILE: datasets/data.py
================================================
import os
import math
import torch
import numpy as np
from torch.utils.data import DataLoader
from datasets.carla_dataset import *

dataset_choices = {'carla', 'kitti'}


def get_data_id(args):
    return '{}'.format(args.dataset)

def get_class_weights(freq):
    '''
    Class weights computed as 1/log(f_c + eps) (https://arxiv.org/pdf/2008.10559.pdf)
    '''
    epsilon_w = 0.001  # eps to avoid zero division
    weights = torch.from_numpy(1 / np.log(freq + epsilon_w))

    return weights

def get_data(args):
    assert args.dataset in dataset_choices
    if args.dataset == 'carla':
        train_dir = os.path.join(args.dataset_dir, "Train")
        val_dir   = os.path.join(args.dataset_dir, "Val")
        test_dir  = os.path.join(args.dataset_dir, "Test")

        x_dim = 128
        y_dim = 128
        z_dim = 8
        data_shape = [x_dim, y_dim, z_dim]
        args.data_shape= data_shape

        binary_counts = True
        transform_pose = True
        remap = True
        if remap:
            class_frequencies = remap_frequencies_cartesian
            args.num_classes = 11
        else:
            args.num_classes = 23

        comp_weights = get_class_weights(class_frequencies).to(torch.float32)
        seg_weights = get_class_weights(class_frequencies[1:]).to(torch.float32)

        train_ds = CarlaDataset(directory=train_dir, random_flips=True, remap=remap, binary_counts=binary_counts, transform_pose=transform_pose)
        coor_ranges = train_ds._eval_param['min_bound'] + train_ds._eval_param['max_bound']
        voxel_sizes = [abs(coor_ranges[3] - coor_ranges[0]) / x_dim,
                    abs(coor_ranges[4] - coor_ranges[1]) / y_dim,
                    abs(coor_ranges[5] - coor_ranges[2]) / z_dim] # since BEV
        val_ds = CarlaDataset(directory=val_dir, remap=remap, binary_counts=binary_counts, transform_pose=transform_pose)
        test_ds = CarlaDataset(directory=test_dir, remap=remap, binary_counts=binary_counts, transform_pose=transform_pose)

        if args is not None and args.distributed:
            train_sampler = torch.utils.data.distributed.DistributedSampler(train_ds, shuffle=True)
            val_sampler = torch.utils.data.distributed.DistributedSampler(val_ds, shuffle=False)
            train_iters = len(train_sampler) // args.batch_size
            val_iters = len(val_sampler) // args.batch_size
        else:
            train_sampler = None
            val_sampler = None
            train_iters = len(train_ds) // args.batch_size
            val_iters = len(val_ds) // args.batch_size
        
        dataloader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=(train_sampler is None), sampler=train_sampler, collate_fn=train_ds.collate_fn, num_workers=args.num_workers)
        dataloader_val = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, sampler=val_sampler, collate_fn=val_ds.collate_fn, num_workers=args.num_workers)
        dataloader_test = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False, collate_fn=test_ds.collate_fn, num_workers=args.num_workers)
    else:
        raise NotImplementedError("Unsupported dataset: only 'carla' is currently implemented.")
    
    
    return dataloader, dataloader_val, dataloader_test, args.num_classes, comp_weights, seg_weights, train_sampler


================================================
FILE: layers/Ablation/wo_diffusion.py
================================================
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
from layers.Latent_Level.stage1.model import C_Encoder, C_Decoder

class wo_diff(torch.nn.Module):
    def __init__(self, args, multi_criterion) -> None:
        super(wo_diff, self).__init__()
        self.args = args

        init_size = args.init_size  # identical for both 'kitti' and 'carla'
        
        self.encoder = C_Encoder(args, nclasses=self.args.num_classes, init_size=init_size, l_size=args.l_size, attention=args.l_attention)
        self.decoder = C_Decoder(args, nclasses=self.args.num_classes, init_size=init_size, l_size=args.l_size, attention=args.l_attention)
        
        self.multi_criterion = multi_criterion

    def device(self):
        return self.encoder.device

    def forward(self, x, input_ten):
        latent = self.encoder(input_ten, out_conv=False) 
        recons = self.decoder(latent, in_conv=False)
        recons_loss = self.multi_criterion(recons, x)
        return recons_loss 

    def sample(self, x):
        latent = self.encoder(x, out_conv=False) 
        recons = self.decoder(latent, in_conv=False)
        recons = recons.argmax(1)
        return recons


================================================
FILE: layers/Latent_Level/stage1/model.py
================================================
import numpy as np
import math
import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange, reduce, repeat
from torch import nn, einsum


def conv3x3x3(in_planes, out_planes, stride=1):
    return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)

def conv1x3x3(in_planes, out_planes, stride=1):
    return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 3, 3), stride=stride,padding=(0, 1, 1), bias=False)

def conv1x1x3(in_planes, out_planes, stride=1):
    return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 1, 3), stride=stride, padding=(0, 0, 1), bias=False)

def conv1x3x1(in_planes, out_planes, stride=1):
    return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 3, 1), stride=stride, padding=(0, 1, 0), bias=False)

def conv3x1x1(in_planes, out_planes, stride=1):
    return nn.Conv3d(in_planes, out_planes, kernel_size=(3, 1, 1), stride=stride, padding=(1, 0, 0), bias=False)

def conv3x1x3(in_planes, out_planes, stride=1):
    return nn.Conv3d(in_planes, out_planes, kernel_size=(3, 1, 3), stride=stride, padding=(1, 0, 1), bias=False)

def conv1x1(in_planes, out_planes, stride=1):
    return nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride)


class Asymmetric_Residual_Block(nn.Module):
    def __init__(self, in_filters, out_filters):
        super(Asymmetric_Residual_Block, self).__init__()
        self.conv1 = conv1x3x3(in_filters, out_filters)
        self.act1 = nn.LeakyReLU()          
        self.conv1_2 = conv3x1x3(out_filters, out_filters)
        self.act1_2 = nn.LeakyReLU()

        self.conv2 = conv3x1x3(in_filters, out_filters)
        self.act2 = nn.LeakyReLU()

        self.conv3 = conv1x3x3(out_filters, out_filters)
        self.act3 = nn.LeakyReLU()

        if in_filters<32 :
            self.GroupNorm = nn.GroupNorm(8, in_filters)
            self.bn0 = nn.GroupNorm(8, out_filters)
            self.bn0_2 = nn.GroupNorm(8, out_filters)
            self.bn1 = nn.GroupNorm(8, out_filters)
            self.bn2 = nn.GroupNorm(8, out_filters)
        else :
            self.GroupNorm = nn.GroupNorm(32, in_filters)
            self.bn0 = nn.GroupNorm(32, out_filters)
            self.bn0_2 = nn.GroupNorm(32, out_filters)
            self.bn1 = nn.GroupNorm(32, out_filters)
            self.bn2 = nn.GroupNorm(32, out_filters)


    def forward(self, x):
        shortcut = self.conv1(x)
        shortcut = self.act1(shortcut)
        shortcut = self.bn0(shortcut)

        shortcut = self.conv1_2(shortcut)
        shortcut = self.act1_2(shortcut)
        shortcut = self.bn0_2(shortcut)

        resA = self.conv2(x) 
        resA = self.act2(resA)
        resA = self.bn1(resA)

        resA = self.conv3(resA) 
        resA = self.act3(resA)
        resA = self.bn2(resA)
        resA += shortcut

        return resA


class DownBlock(nn.Module):
    def __init__(self, in_filters, out_filters, pooling=True, drop_out=True, height_pooling=False):
        super(DownBlock, self).__init__()
        self.pooling = pooling
        self.drop_out = drop_out
        self.residual_block = Asymmetric_Residual_Block(in_filters, out_filters)
        if pooling:
            if height_pooling:
                self.pool = nn.Conv3d(out_filters, out_filters, kernel_size=3, stride=2,padding=1, bias=False)
            else:
                self.pool = nn.Conv3d(out_filters, out_filters, kernel_size=3, stride=(2, 2, 1),padding=1, bias=False)

    def forward(self, x):
        resA = self.residual_block(x)
        if self.pooling:
            resB = self.pool(resA) 
            return resB, resA
        else:
            return resA


class UpBlock(nn.Module):
    def __init__(self, in_filters, out_filters, height_pooling):
        super(UpBlock, self).__init__()
        # self.drop_out = drop_out
        self.trans_dilao = conv3x3x3(in_filters, out_filters)
        self.trans_act = nn.LeakyReLU()

        self.conv1 = conv1x3x3(out_filters, out_filters)
        self.act1 = nn.LeakyReLU()

        self.conv2 = conv3x1x3(out_filters, out_filters)
        self.act2 = nn.LeakyReLU()

        self.conv3 = conv3x3x3(out_filters, out_filters)
        self.act3 = nn.LeakyReLU()

        if out_filters<32 :
            self.trans_bn = nn.GroupNorm(8, out_filters)
            self.bn1 = nn.GroupNorm(8, out_filters)
            self.bn2 = nn.GroupNorm(8, out_filters)
            self.bn3 = nn.GroupNorm(8, out_filters)
        else :
            self.trans_bn = nn.GroupNorm(32, out_filters)
            self.bn1 = nn.GroupNorm(32, out_filters)
            self.bn2 = nn.GroupNorm(32, out_filters)
            self.bn3 = nn.GroupNorm(32, out_filters)
        
        if height_pooling :
            self.up_subm = nn.ConvTranspose3d(out_filters, out_filters, kernel_size=3, bias=False, stride=2, padding=1, output_padding=1, dilation=1)
        else : 
            self.up_subm = nn.ConvTranspose3d(out_filters, out_filters, kernel_size=(3,3,1), bias=False, stride=(2,2,1), padding=(1,1,0), output_padding=(1,1,0), dilation=1)


    def forward(self, x, skip=False): 
        if skip :
            x, residual = x
        upA = self.trans_dilao(x)
        upA = self.trans_act(upA)
        upA = self.trans_bn(upA) 

        upA = self.up_subm(upA)
        if skip :
            upA += residual
        upE = self.conv1(upA)
        upE = self.act1(upE)
        upE = self.bn1(upE)

        upE = self.conv2(upE)
        upE = self.act2(upE)
        upE = self.bn2(upE)

        upE = self.conv3(upE)
        upE = self.act3(upE)
        upE = self.bn3(upE)
        return upE


class DDCM(nn.Module):
    def __init__(self, in_filters, out_filters, kernel_size=(3, 3, 3), stride=1):
        super(DDCM, self).__init__()
        self.conv1 = conv3x1x1(in_filters, out_filters)
        self.act1 = nn.Sigmoid()

        self.conv1_2 = conv1x3x1(in_filters, out_filters)
        self.act1_2 = nn.Sigmoid()

        self.conv1_3 = conv1x1x3(in_filters, out_filters)
        self.act1_3 = nn.Sigmoid()

        if in_filters<32 :
            self.bn0 = nn.GroupNorm(8, out_filters)
            self.bn0_2 = nn.GroupNorm(8, out_filters)
            self.bn0_3 = nn.GroupNorm(8, out_filters)
        else :
            self.bn0 = nn.GroupNorm(32, out_filters)
            self.bn0_2 = nn.GroupNorm(32, out_filters)
            self.bn0_3 = nn.GroupNorm(32, out_filters)

    def forward(self, x):
        shortcut = self.conv1(x)
        shortcut = self.bn0(shortcut)
        shortcut = self.act1(shortcut)

        shortcut2 = self.conv1_2(x)
        shortcut2 = self.bn0_2(shortcut2)
        shortcut2 = self.act1_2(shortcut2)

        shortcut3 = self.conv1_3(x)
        shortcut3 = self.bn0_3(shortcut3)
        shortcut3 = self.act1_3(shortcut3)
        shortcut = shortcut + shortcut2 + shortcut3

        shortcut = shortcut * x
        return shortcut

def l2norm(t):
    return F.normalize(t, dim = -1)

class Attention(nn.Module):
    def __init__(self, dim, heads = 4, scale = 10):
        super().__init__()
        self.scale = scale
        self.heads = heads
        self.to_qkv = conv1x1(dim, dim*3, stride=1)
        self.to_out = conv1x1(dim, dim, stride=1)

    def forward(self, x):
        b, c, h, w, Z = x.shape
        qkv = self.to_qkv(x).chunk(3, dim = 1)
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y z-> b h c (x y z)', h = self.heads), qkv)

        q, k = map(l2norm, (q, k))

        sim = einsum('b h d i, b h d j -> b h i j', q, k) * self.scale
        attn = sim.softmax(dim = -1)
        out = einsum('b h i j, b h d j -> b h i d', attn, v)
        out = rearrange(out, 'b h (x y z) d -> b (h d) x y z', x = h, y = w, z = Z)
        return self.to_out(out)
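The Attention block above is cosine-style attention over all voxels: q and k are l2-normalized and scaled by a fixed factor rather than 1/sqrt(d). A minimal standalone sketch of the same math with plain tensor ops (the qkv/output 1x1 convolutions are omitted, so q = k = v = x; shapes are illustrative):

```python
import torch
import torch.nn.functional as F

def cosine_attention(x, heads=4, scale=10.0):
    """Sketch of Attention.forward; projections omitted for self-containment."""
    b, c, h, w, z = x.shape
    n = h * w * z
    # 'b (h c) x y z -> b h c (x y z)'
    q = k = v = x.reshape(b, heads, c // heads, n)
    # l2norm over the last axis, matching the source's l2norm helper
    q, k = F.normalize(q, dim=-1), F.normalize(k, dim=-1)
    sim = torch.einsum('b h d i, b h d j -> b h i j', q, k) * scale
    attn = sim.softmax(dim=-1)
    out = torch.einsum('b h i j, b h d j -> b h i d', attn, v)
    # 'b h (x y z) d -> b (h d) x y z'
    return out.permute(0, 1, 3, 2).reshape(b, c, h, w, z)

x = torch.randn(2, 8, 4, 4, 2)
y = cosine_attention(x)
```

The block preserves the input shape, so it can sit between the residual mid-blocks without any reshaping.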

class C_Encoder(nn.Module):
    def __init__(self, args,  nclasses=20, init_size=16, l_size='882', attention=True):
        super(C_Encoder, self).__init__()
        self.nclasses = nclasses
        self.args = args
        self.l_size = l_size
        self.attention = attention

        self.embedding = nn.Embedding(nclasses, init_size)

        self.A = Asymmetric_Residual_Block(init_size, init_size)

        self.downBlock1 = DownBlock(init_size, 2 * init_size, height_pooling=True)
        self.downBlock2 = DownBlock(2 * init_size, 4 * init_size, height_pooling=True)
        self.downBlock3 = DownBlock(4 * init_size, 8 * init_size, height_pooling=False)
        self.downBlock4 = DownBlock(8 * init_size, 16 * init_size, height_pooling=False)
        
        self.use_attention = attention  # `self.attention` is reassigned to the module below; keep the flag separately
        if self.l_size == '32322':
            mid_size = 4 * init_size
        elif self.l_size == '16162':
            mid_size = 8 * init_size
        elif self.l_size == '882':
            mid_size = 16 * init_size
        else:
            raise NotImplementedError(f"Unsupported l_size: {l_size}")
        self.midBlock1 = Asymmetric_Residual_Block(mid_size, mid_size)
        self.attention = Attention(mid_size, 32)
        self.midBlock2 = Asymmetric_Residual_Block(mid_size, mid_size)
        self.out = nn.Conv3d(mid_size, nclasses, kernel_size=3, stride=1, padding=1, bias=True)

    def forward(self, x, out_conv=True):
        x = self.embedding(x)
        x = x.permute(0, 4, 1, 2, 3)

        x = self.A(x)
        x, down1b = self.downBlock1(x)
        x, down2b = self.downBlock2(x)

        if self.l_size == '882':
            x, down3b = self.downBlock3(x)
            x, down4b = self.downBlock4(x)
        elif self.l_size == '16162':
            x, down3b = self.downBlock3(x)

        if self.use_attention:
            x = self.midBlock1(x)  # e.g. (4, 128, 32, 32, 2)
            x = self.attention(x)
            x = self.midBlock2(x)
        if out_conv:
            x = self.out(x)
        return x

class C_Decoder(nn.Module):
    def __init__(self, args, nclasses=20, init_size=16, l_size='882', attention=True):
        super(C_Decoder, self).__init__()
        self.nclasses = nclasses
        self.args = args
        self.l_size = l_size
        self.attention = attention

        self.use_attention = attention  # `self.attention` is reassigned to the module below; keep the flag separately
        if l_size == '882':
            mid_size = 16 * init_size
        elif l_size == '16162':
            mid_size = 8 * init_size
        elif l_size == '32322':
            mid_size = 4 * init_size
        else:
            raise NotImplementedError(f"Unsupported l_size: {l_size}")
        self.conv_in = nn.Conv3d(nclasses, mid_size, kernel_size=3, stride=1, padding=1, bias=True)
        self.midBlock1 = Asymmetric_Residual_Block(mid_size, mid_size)
        self.attention = Attention(mid_size, 32)
        self.midBlock2 = Asymmetric_Residual_Block(mid_size, mid_size)

        self.upBlock4 = UpBlock(16 * init_size, 8 * init_size, height_pooling=False)
        self.upBlock3 = UpBlock(8 * init_size, 4 * init_size, height_pooling=False)
        self.upBlock2 = UpBlock(4 * init_size, 2 * init_size, height_pooling=True)
        self.upBlock1 = UpBlock(2 * init_size, 2 * init_size, height_pooling=True)
        self.DDCM = DDCM(2 * init_size, 2 * init_size)
        self.logits = nn.Conv3d(4 * init_size, self.nclasses, kernel_size=3, stride=1, padding=1, bias=True)

    def forward(self, x, in_conv=True):
        if in_conv:
            x = self.conv_in(x)

        if self.use_attention:
            x = self.midBlock1(x)
            x = self.attention(x)
            x = self.midBlock2(x)

        if self.l_size == '882':
            x = self.upBlock4(x)
            x = self.upBlock3(x)
            
        elif self.l_size == '16162':
            x = self.upBlock3(x)

        x = self.upBlock2(x)
        up1 = self.upBlock1(x)

        up0 = self.DDCM(up1) 
        up = torch.cat((up1, up0), 1) 
        logits = self.logits(up) 
        return logits

class Completion(nn.Module):
    def __init__(self, args, num_class = 11, init_size=32):
        super(Completion, self).__init__()
        self.args = args
        self.num_class = num_class
        self.init_size = init_size

        self.embedding = nn.Embedding(self.num_class, init_size)

        self.A = Asymmetric_Residual_Block(init_size, init_size)

        self.downBlock1 = DownBlock(init_size, 2 * init_size, height_pooling=True)
        self.downBlock2 = DownBlock(2 * init_size, 4 * init_size, height_pooling=True)
        self.downBlock3 = DownBlock(4 * init_size, 8 * init_size, height_pooling=False)
        self.downBlock4 = DownBlock(8 * init_size, 16 * init_size, height_pooling=False)
        
        self.midBlock1 = Asymmetric_Residual_Block(16 * init_size, 16 * init_size)
        self.attention = Attention(16 * init_size, 32)
        self.midBlock2 = Asymmetric_Residual_Block(16 * init_size, 16 * init_size)

        self.upBlock4 = UpBlock(16 * init_size, 8 * init_size, height_pooling=False)
        self.upBlock3 = UpBlock(8 * init_size, 4 * init_size, height_pooling=False)
        self.upBlock2 = UpBlock(4 * init_size, 2 * init_size, height_pooling=True)
        self.upBlock1 = UpBlock(2 * init_size, 2 * init_size, height_pooling=True)

        self.DDCM = DDCM(2 * init_size, 2 * init_size)
        self.logits = nn.Conv3d(4 * init_size, self.num_class, kernel_size=3, stride=1, padding=1, bias=True)
        

    def forward(self, x):
        x = self.embedding(x)
        x = x.permute(0, 4, 1, 2, 3)

        x = self.A(x)
        down1c, down1b = self.downBlock1(x)
        down2c, down2b = self.downBlock2(down1c) 
        down3c, down3b = self.downBlock3(down2c)
        down4c, down4b = self.downBlock4(down3c) 

        down4c = self.midBlock1(down4c) 
        down4c = self.attention(down4c)
        down4c = self.midBlock2(down4c) 
        
        up4 = self.upBlock4((down4c, down4b), skip=True)
        up3 = self.upBlock3((up4, down3b), skip=True)
        up2 = self.upBlock2((up3, down2b), skip=True)
        up1 = self.upBlock1((up2, down1b), skip=True)

        up0 = self.DDCM(up1) 
        up = torch.cat((up1, up0), 1) 
        logits = self.logits(up) 
        return logits


================================================
FILE: layers/Latent_Level/stage1/vector_quantizer.py
================================================
import torch
from torch import nn
from torch.nn import functional as F

class VectorQuantizer(nn.Module):

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 beta: float = 0.25):
        super(VectorQuantizer, self).__init__()
        self.K = num_embeddings
        self.D = embedding_dim
        self.beta = beta

        self.embedding = nn.Embedding(self.K, self.D)
        self.embedding.weight.data.uniform_(-1 / self.K, 1 / self.K)

    def forward(self, z: torch.Tensor, point=False) -> torch.Tensor:  # z: [B x D x H x W x Z]
        z = z.permute(0, 2, 3, 4, 1).contiguous()  # [B x D x H x W x Z] -> [B x H x W x Z x D]
        latents_shape = z.shape # ( 8, 8, 8, 2, 128 )
        flat_latents = z.view(-1, self.D)  # [BHWZ x D] = [1024, 128]

        # Compute L2 distance between latents and embedding weights
        dist = torch.sum(flat_latents ** 2, dim=1, keepdim=True) + torch.sum(self.embedding.weight ** 2, dim=1) - \
               2 * torch.matmul(flat_latents, self.embedding.weight.t())  # [BHWZ x K]

        # Get the encoding that has the min distance
        min_encoding_indices = torch.argmin(dist, dim=1).unsqueeze(1)  # [BHWZ, 1]

        z_q = self.embedding(min_encoding_indices).view(z.shape)

        # Compute the VQ Losses
        commitment_loss = F.mse_loss(z_q.detach(), z)
        embedding_loss = F.mse_loss(z_q, z.detach())
        if point :
            vq_loss = commitment_loss * self.beta
        else :
            vq_loss = commitment_loss * self.beta + embedding_loss

        # Add the residue back to the latents
        z_q = z + (z_q - z).detach()

        return z_q.permute(0, 4, 1, 2, 3).contiguous(), vq_loss, min_encoding_indices, latents_shape

    def codebook_to_embedding(self, encoding_inds, latents_shape): # latents (16, 512, 8, 8, 2)
        # Convert to one-hot encodings
        z_q = self.embedding(encoding_inds).view(latents_shape)
        return z_q.permute(0, 4, 1, 2, 3).contiguous()
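As a sanity check, the nearest-codebook lookup and straight-through estimator used above can be reproduced on flat latents in a few lines (K, D and the tensors are illustrative):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
K, D = 16, 8                     # codebook size / embedding dim (illustrative)
codebook = torch.randn(K, D)
z = torch.randn(5, D)            # flattened latents, [N x D]

# squared L2 distance ||z||^2 + ||e||^2 - 2 z.e^T, as in VectorQuantizer.forward
dist = (z ** 2).sum(1, keepdim=True) + (codebook ** 2).sum(1) - 2 * z @ codebook.t()
idx = dist.argmin(dim=1)         # nearest codebook entry per latent
z_q = codebook[idx]

# straight-through: the forward pass uses z_q, gradients flow through z unchanged
z_st = z + (z_q - z).detach()
commitment_loss = F.mse_loss(z_q.detach(), z)
```

The argmin over the expanded squared distance selects the same entries as a direct Euclidean nearest-neighbor search.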


================================================
FILE: layers/Latent_Level/stage1/vqvae.py
================================================
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
import math
from utils.loss import lovasz_softmax
from layers.Latent_Level.stage1.model import C_Encoder, C_Decoder
from layers.Latent_Level.stage1.vector_quantizer import VectorQuantizer

class vqvae(torch.nn.Module):
    def __init__(self, args, multi_criterion) -> None:
        super(vqvae, self).__init__()
        self.args = args

        init_size = args.init_size
        embedding_dim = int(self.args.num_classes)
        
        self.VQ = VectorQuantizer(num_embeddings = int(self.args.num_classes)*int(self.args.vq_size), embedding_dim = embedding_dim)

        self.encoder = C_Encoder(args, nclasses=self.args.num_classes, init_size=init_size, l_size=args.l_size, attention=args.l_attention)
        self.quant_conv = nn.Conv3d(self.args.num_classes, self.args.num_classes, kernel_size=1, stride=1)

        self.decoder = C_Decoder(args, nclasses=self.args.num_classes, init_size=init_size, l_size=args.l_size, attention=args.l_attention)
        self.post_quant_conv = nn.Conv3d(self.args.num_classes, self.args.num_classes, kernel_size=1, stride=1)

        self.multi_criterion = multi_criterion

    def device(self):
        # nn.Module has no `.device` attribute; infer it from the parameters.
        return next(self.encoder.parameters()).device

    def encode(self, x):
        latent = self.encoder(x) 
        latent = self.quant_conv(latent)
        return latent

    def vector_quantize(self, latent):
        quantized_latent, vq_loss, quantized_latent_ind, latents_shape = self.VQ(latent)
        return quantized_latent, vq_loss, quantized_latent_ind, latents_shape

    def codebook(self, quantized_latent_ind, latents_shape):
        quantized_latent = self.VQ.codebook_to_embedding(quantized_latent_ind.view(-1, 1), latents_shape)
        return quantized_latent

    # Backward-compatible alias: external callers use the original misspelling.
    coodbook = codebook

    def decode(self, quantized_latent):
        quantized_latent = self.post_quant_conv(quantized_latent)
        recons = self.decoder(quantized_latent)
        return recons

    def forward(self, x, input_ten):
        latent = self.encode(x) 
        quantized_latent, vq_loss, _, _ = self.vector_quantize(latent) 
        recons = self.decode(quantized_latent)

        recons_loss = self.multi_criterion(recons, x)
        loss = recons_loss + vq_loss 
        return loss 

    def sample(self, x):
        latent = self.encode(x)
        quantized_latent, _, _, _ = self.vector_quantize(latent)
        recons = self.decode(quantized_latent)
        recons = recons.argmax(1)
        return recons


================================================
FILE: layers/Latent_Level/stage2/Gen_diffusion.py
================================================
import torch
import torch.nn.functional as F
import numpy as np
import math
from inspect import isfunction
from layers.Latent_Level.stage2.gen_denoise import Denoise

"""
Based in part on: https://github.com/lucidrains/denoising-diffusion-pytorch/blob/5989f4c77eafcdc6be0fb4739f0f277a6dd7f7d8/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L281
"""
eps = 1e-8


def sum_except_batch(x, num_dims=1):
    return x.reshape(*x.shape[:num_dims], -1).sum(-1)


def log_1_min_a(a):
    return torch.log(1 - a.exp() + 1e-40)


def log_add_exp(a, b):
    maximum = torch.max(a, b)
    return maximum + torch.log(torch.exp(a - maximum) + torch.exp(b - maximum))


def exists(x):
    return x is not None


def extract(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
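`extract` gathers one schedule value per batch element and reshapes it so it broadcasts against the data tensor. A quick illustration:

```python
import torch

def extract(a, t, x_shape):
    # one schedule entry per batch element, reshaped for broadcasting
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))

schedule = torch.linspace(0.9, 0.1, 10)   # one value per timestep
t = torch.tensor([0, 9])                  # a timestep per batch element
out = extract(schedule, t, (2, 11, 8, 8, 2))
```

The trailing singleton dimensions let `out` multiply a (B, C, H, W, Z) tensor directly.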


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def log_categorical(log_x_start, log_prob):
    return (log_x_start.exp() * log_prob).sum(dim=1)


def index_to_log_onehot(x, num_classes):
    assert x.max().item() < num_classes, f'Error: {x.max().item()} >= {num_classes}'
    
    x_onehot = F.one_hot(x, num_classes)
    permute_order = (0, -1) + tuple(range(1, len(x.size())))
    x_onehot = x_onehot.permute(permute_order)
    log_x = torch.log(x_onehot.float().clamp(min=1e-30))

    return log_x


def log_onehot_to_index(log_x):
    return log_x.argmax(1)
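`index_to_log_onehot` and `log_onehot_to_index` are inverses: class indices map to a log one-hot tensor with the class axis moved to dim 1, and an argmax over that axis recovers them. A round-trip sketch with illustrative shapes:

```python
import torch
import torch.nn.functional as F

def index_to_log_onehot(x, num_classes):
    # one-hot, move the class axis to dim 1, then take a clamped log
    x_onehot = F.one_hot(x, num_classes)
    permute_order = (0, -1) + tuple(range(1, len(x.size())))
    return torch.log(x_onehot.permute(permute_order).float().clamp(min=1e-30))

x = torch.randint(0, 11, (2, 4, 4, 2))   # (B, H, W, Z) class indices
log_x = index_to_log_onehot(x, 11)       # (B, 11, H, W, Z)
recovered = log_x.argmax(1)              # same as log_onehot_to_index(log_x)
```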


def cosine_beta_schedule(timesteps, s = 0.008):
    """
    cosine schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    steps = timesteps + 1
    x = np.linspace(0, steps, steps)
    alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    alphas = (alphas_cumprod[1:] / alphas_cumprod[:-1])

    alphas = np.clip(alphas, a_min=0.001, a_max=1.)
    alphas = np.sqrt(alphas)

    return alphas
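The schedule returns per-step sqrt(alpha) values in (0, 1]; their squared cumulative product (the effective alpha-bar) decays toward zero at the final step, i.e. the signal is almost fully replaced by uniform noise. A quick check, copying the function so the snippet stands alone:

```python
import numpy as np

def cosine_beta_schedule(timesteps, s=0.008):
    steps = timesteps + 1
    x = np.linspace(0, steps, steps)
    alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    alphas = np.clip(alphas_cumprod[1:] / alphas_cumprod[:-1], 0.001, 1.0)
    return np.sqrt(alphas)

a = cosine_beta_schedule(100)
alpha_bar = np.cumprod(a ** 2)   # fraction of original signal retained per step
```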

class latent_diffusion(torch.nn.Module):
    def __init__(self, args, VAE_DENSE, multi_criterion,
                 auxiliary_loss_weight=0.0005, adaptive_auxiliary_loss=True):
        super(latent_diffusion, self).__init__()
        self.args = args
        self.num_classes = self.args.num_classes * self.args.vq_size
        self.denoise = Denoise(args= self.args,  num_class = self.num_classes)
        
        self.num_timesteps = self.args.diffusion_steps
        self.auxiliary_loss_weight = auxiliary_loss_weight
        self.adaptive_auxiliary_loss = adaptive_auxiliary_loss

        self.VAE_DENSE = VAE_DENSE
        self.multi_criterion = multi_criterion

        alphas = cosine_beta_schedule(self.num_timesteps )

        alphas = torch.tensor(alphas.astype('float64'))
        log_alpha = torch.log(alphas)
        log_cumprod_alpha = torch.cumsum(log_alpha, dim=0)

        log_1_min_alpha = log_1_min_a(log_alpha)
        log_1_min_cumprod_alpha = log_1_min_a(log_cumprod_alpha)

        assert log_add_exp(log_alpha, log_1_min_alpha).abs().sum().item() < 1.e-5
        assert log_add_exp(log_cumprod_alpha, log_1_min_cumprod_alpha).abs().sum().item() < 1e-5
        assert (np.cumsum(log_alpha) - log_cumprod_alpha).abs().sum().item() < 1.e-5

        # Convert to float32 and register buffers.
        self.register_buffer('log_alpha', log_alpha.float())
        self.register_buffer('log_1_min_alpha', log_1_min_alpha.float())
        self.register_buffer('log_cumprod_alpha', log_cumprod_alpha.float())
        self.register_buffer('log_1_min_cumprod_alpha', log_1_min_cumprod_alpha.float())

        self.register_buffer('Lt_history', torch.zeros(self.num_timesteps ))
        self.register_buffer('Lt_count', torch.zeros(self.num_timesteps ))
    
    def device(self):
        return self.denoise.device

    def multinomial_kl(self, log_prob1, log_prob2):
        kl = (log_prob1.exp() * (log_prob1 - log_prob2)).sum(dim=1)
        return kl

    def q_pred_one_timestep(self, log_x_t, t):
        log_alpha_t = extract(self.log_alpha, t, log_x_t.shape)
        log_1_min_alpha_t = extract(self.log_1_min_alpha, t, log_x_t.shape)

        # alpha_t * E[xt] + (1 - alpha_t) 1 / K
        
        log_probs = log_add_exp(
            log_x_t + log_alpha_t,
            log_1_min_alpha_t - np.log(self.num_classes)
        )

        return log_probs

    def q_pred(self, log_x_start, t):
        log_cumprod_alpha_t = extract(self.log_cumprod_alpha, t, log_x_start.shape)
        log_1_min_cumprod_alpha = extract(self.log_1_min_cumprod_alpha, t, log_x_start.shape)

        log_probs = log_add_exp(
            log_x_start + log_cumprod_alpha_t,
            log_1_min_cumprod_alpha - np.log(self.num_classes)
        )

        return log_probs
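`q_pred` implements the closed-form categorical forward process in log space: with probability alpha-bar_t the original class survives, otherwise the class is resampled uniformly from the K options. A small numeric check with K = 5 and alpha-bar = 0.7 (illustrative values):

```python
import math
import torch

def log_add_exp(a, b):
    maximum = torch.max(a, b)
    return maximum + torch.log(torch.exp(a - maximum) + torch.exp(b - maximum))

K = 5
log_x0 = torch.log(torch.eye(K)[0].clamp(min=1e-30))   # one-hot on class 0
log_cumprod_alpha = torch.tensor(math.log(0.7))
log_1_min_cumprod_alpha = torch.tensor(math.log(0.3))

# alpha_bar * x0 + (1 - alpha_bar) / K, computed stably in log space
log_probs = log_add_exp(log_x0 + log_cumprod_alpha,
                        log_1_min_cumprod_alpha - math.log(K))
probs = log_probs.exp()
```

The surviving class gets 0.7 + 0.3/5 = 0.76, every other class 0.3/5 = 0.06, and the result is a valid distribution.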

    def predict_start(self, log_x_t, t):
        x_t = log_onehot_to_index(log_x_t)

        out = self.denoise(x_t, t)

        assert out.size(0) == x_t.size(0)
        assert out.size(1) == self.num_classes
        assert out.size()[2:] == x_t.size()[1:]

        log_pred = F.log_softmax(out, dim=1)
        return log_pred

    def q_posterior(self, log_x_start, log_x_t, t):
        # q(xt-1 | xt, x0) = q(xt | xt-1, x0) * q(xt-1 | x0) / q(xt | x0)
        # where q(xt | xt-1, x0) = q(xt | xt-1).

        t_minus_1 = t - 1
        # Remove negative values, will not be used anyway for final decoder
        t_minus_1 = torch.where(t_minus_1 < 0, torch.zeros_like(t_minus_1), t_minus_1)
        log_EV_qxtmin_x0 = self.q_pred(log_x_start, t_minus_1)

        num_axes = (1,) * (len(log_x_start.size()) - 1)
        t_broadcast = t.view(-1, *num_axes) * torch.ones_like(log_x_start)
        log_EV_qxtmin_x0 = torch.where(t_broadcast == 0, log_x_start, log_EV_qxtmin_x0)


        # Note: _NOT_ x_tmin1, which is how the formula is typically used!!!
        # Not very easy to see why this is true. But it is :)
        unnormed_logprobs = log_EV_qxtmin_x0 + self.q_pred_one_timestep(log_x_t, t)

        log_EV_xtmin_given_xt_given_xstart = \
            unnormed_logprobs \
            - torch.logsumexp(unnormed_logprobs, dim=1, keepdim=True)

        return log_EV_xtmin_given_xt_given_xstart

    def p_pred(self, log_x, t):
        log_x0_recon = self.predict_start(log_x, t=t)
        log_model_pred = self.q_posterior(log_x_start=log_x0_recon, log_x_t=log_x, t=t)
        return log_model_pred, log_x0_recon

    def log_sample_categorical(self, logits):
        uniform = torch.rand_like(logits)
        gumbel_noise = -torch.log(-torch.log(uniform + 1e-30) + 1e-30)
        sample = (gumbel_noise + logits).argmax(dim=1)
        log_sample = index_to_log_onehot(sample, self.num_classes)
        return log_sample
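`log_sample_categorical` draws an exact categorical sample via the Gumbel-max trick: add Gumbel(0, 1) noise to the logits and take the argmax. A sketch (seeded so the result is deterministic; with these strongly peaked logits the argmax lands on the dominant class):

```python
import torch

torch.manual_seed(0)
logits = torch.tensor([[0.999, 0.0005, 0.0005]]).clamp(min=1e-30).log()
uniform = torch.rand_like(logits)
# Gumbel(0, 1) noise via the inverse-CDF transform, as in the source
gumbel_noise = -torch.log(-torch.log(uniform + 1e-30) + 1e-30)
sample = (gumbel_noise + logits).argmax(dim=1)
```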

    def q_sample(self, log_x_start, t):
        log_EV_qxt_x0 = self.q_pred(log_x_start, t)
        log_sample = self.log_sample_categorical(log_EV_qxt_x0)
        return log_sample

    def kl_prior(self, log_x_start):
        b = log_x_start.size(0)
        device = log_x_start.device
        ones = torch.ones(b, device=device).long()

        log_qxT_prob = self.q_pred(log_x_start, t=(self.num_timesteps - 1) * ones)
        # Uniform distribution over the K classes (the original name `log_half_prob` dates from the binary case).
        log_uniform_prob = -torch.log(self.num_classes * torch.ones_like(log_qxT_prob))

        kl_prior = self.multinomial_kl(log_qxT_prob, log_uniform_prob)
        return sum_except_batch(kl_prior)

    def sample_time(self, b, device, method='uniform'):
        if method == 'importance':
            if not (self.Lt_count > 10).all():
                return self.sample_time(b, device, method='uniform')

            Lt_sqrt = torch.sqrt(self.Lt_history + 1e-10) + 0.0001
            Lt_sqrt[0] = Lt_sqrt[1]  # Overwrite decoder term with L1.
            pt_all = Lt_sqrt / Lt_sqrt.sum()

            t = torch.multinomial(pt_all, num_samples=b, replacement=True)

            pt = pt_all.gather(dim=0, index=t)

            return t, pt

        elif method == 'uniform':
            t = torch.randint(0, self.num_timesteps, (b,), device=device).long()

            pt = torch.ones_like(t).float() / self.num_timesteps
            return t, pt
        else:
            raise ValueError(f"Unknown time sampling method: {method}")

    def forward(self, x, input_data):
        b, device = x.size(0), x.device
        self.shape = x.size()[1:]
        
        latent = self.VAE_DENSE.encode(x)
        _, _, dense_ind, latents_shape = self.VAE_DENSE.vector_quantize(latent)
        reshape_size = [latent.size()[0], latent.size()[2], latent.size()[3], latent.size()[4]]

        t, pt = self.sample_time(b, device, 'importance')

        log_x_start = index_to_log_onehot(dense_ind.view(reshape_size), self.num_classes)
        log_x_t = self.q_sample(log_x_start=log_x_start, t=t) # log_x_t : (8,551,8,8,2)

        log_true_prob = self.q_posterior(log_x_start=log_x_start, log_x_t=log_x_t, t=t)

        log_model_prob, log_x0_recon = self.p_pred(log_x=log_x_t, t=t)

        kl = self.multinomial_kl(log_true_prob, log_model_prob)
        kl = sum_except_batch(kl)

        decoder_nll = -log_categorical(log_x_start, log_model_prob)
        decoder_nll = sum_except_batch(decoder_nll)

        mask = (t == torch.zeros_like(t)).float()
        kl_loss = mask * decoder_nll + (1. - mask) * kl
        
        if self.training:
            Lt2 = kl_loss.pow(2)
            Lt2_prev = self.Lt_history.gather(dim=0, index=t)
            new_Lt_history = (0.1 * Lt2 + 0.9 * Lt2_prev).detach()
            self.Lt_history.scatter_(dim=0, index=t, src=new_Lt_history)
            self.Lt_count.scatter_add_(dim=0, index=t, src=torch.ones_like(Lt2))

        kl_prior = self.kl_prior(log_x_start)

        # Upweigh loss term of the kl
        loss = kl_loss / pt + kl_prior

        kl_aux = self.multinomial_kl(log_x_start[:,:-1,:,:,:], log_x0_recon[:,:-1,:,:,:])
        kl_aux = sum_except_batch(kl_aux)
        kl_aux_loss = mask * decoder_nll + (1. - mask) * kl_aux
        if self.adaptive_auxiliary_loss:
            addition_loss_weight = (1-t/self.num_timesteps) + 1.0
        else:
            addition_loss_weight = 1.0

        aux_loss = addition_loss_weight * self.auxiliary_loss_weight * kl_aux_loss / pt
        loss += aux_loss
        # Normalize to bits per dimension. The original negated `loss` here and
        # again at the return statement, which cancels out; return it directly.
        loss = loss.sum() / (math.log(2) * dense_ind.view(reshape_size).shape.numel())

        return loss

    def sample(self, x):
        device = self.log_alpha.device
        self.shape = x.size()[1:]
        
        x = torch.randint(self.args.num_classes, size=x.size()).to(device)
        latent = self.VAE_DENSE.encode(x)
        _, _, sparse_ind, latents_shape = self.VAE_DENSE.vector_quantize(latent)
        reshape_size = [latent.size()[0], latent.size()[2], latent.size()[3], latent.size()[4]]

        log_z = index_to_log_onehot(sparse_ind.view(reshape_size), self.num_classes) # log_x_t : (8,551,8,8,2)

        for i in reversed(range(0, self.num_timesteps)):
            print(f'Sample timestep {i:4d}', end='\r')

            t = torch.full((self.args.batch_size,), i, device=device, dtype=torch.long)

            log_model_prob, log_x0_recon = self.p_pred(log_x=log_z, t=t)

            uniform = torch.rand_like(log_model_prob)
            gumbel_noise = -torch.log(-torch.log(uniform + 1e-30) + 1e-30)
            pre_sample = gumbel_noise + log_model_prob

            sample = pre_sample.argmax(dim=1)                  # (32,  1, 32, 64)
            log_z = index_to_log_onehot(sample, self.num_classes)

        vq_ind = log_onehot_to_index(log_z)
        vq_latent = self.VAE_DENSE.coodbook(vq_ind.view(-1,1), latents_shape)
        recons = self.VAE_DENSE.decode(vq_latent)
        recons = recons.argmax(1)
        return recons


================================================
FILE: layers/Latent_Level/stage2/gen_denoise.py
================================================
import math
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from einops import rearrange, reduce, repeat
from torch import einsum


def conv3x3x3(in_planes, out_planes, stride=1):
    return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)

def conv1x3x3(in_planes, out_planes, stride=1):
    return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 3, 3), stride=stride,padding=(0, 1, 1), bias=False)


def conv1x1x3(in_planes, out_planes, stride=1):
    return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 1, 3), stride=stride, padding=(0, 0, 1), bias=False)


def conv1x3x1(in_planes, out_planes, stride=1):
    return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 3, 1), stride=stride, padding=(0, 1, 0), bias=False)


def conv3x1x1(in_planes, out_planes, stride=1):
    return nn.Conv3d(in_planes, out_planes, kernel_size=(3, 1, 1), stride=stride, padding=(1, 0, 0), bias=False)


def conv3x1x3(in_planes, out_planes, stride=1):
    return nn.Conv3d(in_planes, out_planes, kernel_size=(3, 1, 3), stride=stride, padding=(1, 0, 1), bias=False)


def conv1x1(in_planes, out_planes, stride=1):
    return nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride)


class Asymmetric_Residual_Block(nn.Module):
    def __init__(self, in_filters, out_filters, time_filters=128):
        super(Asymmetric_Residual_Block, self).__init__()
        self.GroupNorm = nn.GroupNorm(32, in_filters)
        self.time_layers = nn.Sequential(
                            nn.SiLU(),
                            nn.Linear(time_filters, in_filters*2)
                        )

        self.conv1 = conv1x3x3(in_filters, out_filters)
        self.bn0 = nn.GroupNorm(32, out_filters)
        self.act1 = nn.LeakyReLU()
          
        self.conv1_2 = conv3x1x3(out_filters, out_filters)
        self.bn0_2 = nn.GroupNorm(32, out_filters)
        self.act1_2 = nn.LeakyReLU()

        self.conv2 = conv3x1x3(in_filters, out_filters)
        self.act2 = nn.LeakyReLU()
        self.bn1 = nn.GroupNorm(32, out_filters)

        self.conv3 = conv1x3x3(out_filters, out_filters)
        self.act3 = nn.LeakyReLU()
        self.bn2 = nn.GroupNorm(32, out_filters)


    def forward(self, x, t):
        t = self.time_layers(t)
        while len(t.shape) < len(x.shape):
            t = t[..., None]
        scale, shift = torch.chunk(t, 2, dim=1)
        
        x = self.GroupNorm(x) * (1 + scale) + shift

        shortcut = self.conv1(x)
        shortcut = self.act1(shortcut)
        shortcut = self.bn0(shortcut)

        shortcut = self.conv1_2(shortcut)
        shortcut = self.act1_2(shortcut)
        shortcut = self.bn0_2(shortcut)

        resA = self.conv2(x) 
        resA = self.act2(resA)
        resA = self.bn1(resA)

        resA = self.conv3(resA)
        resA = self.act3(resA)
        resA = self.bn2(resA)
        resA += shortcut

        return resA

class DDCM(nn.Module):
    def __init__(self, in_filters, out_filters, kernel_size=(3, 3, 3), stride=1):
        super(DDCM, self).__init__()
        self.conv1 = conv3x1x1(in_filters, out_filters)
        self.bn0 = nn.GroupNorm(32, out_filters)
        self.act1 = nn.Sigmoid()

        self.conv1_2 = conv1x3x1(in_filters, out_filters)
        self.bn0_2 = nn.GroupNorm(32, out_filters)
        self.act1_2 = nn.Sigmoid()

        self.conv1_3 = conv1x1x3(in_filters, out_filters)
        self.bn0_3 = nn.GroupNorm(32, out_filters)
        self.act1_3 = nn.Sigmoid()

    def forward(self, x):
        shortcut = self.conv1(x)
        shortcut = self.bn0(shortcut)
        shortcut = self.act1(shortcut)

        shortcut2 = self.conv1_2(x)
        shortcut2 = self.bn0_2(shortcut2)
        shortcut2 = self.act1_2(shortcut2)

        shortcut3 = self.conv1_3(x)
        shortcut3 = self.bn0_3(shortcut3)
        shortcut3 = self.act1_3(shortcut3)
        shortcut = shortcut + shortcut2 + shortcut3

        shortcut = shortcut * x

        return shortcut

def l2norm(t):
    return F.normalize(t, dim = -1)

class Attention(nn.Module):
    def __init__(self, dim, heads = 4, scale = 10):
        super().__init__()
        self.scale = scale
        self.heads = heads
        self.to_qkv = conv1x1(dim, dim*3, stride=1)
        self.to_out = conv1x1(dim, dim, stride=1)

    def forward(self, x):
        b, c, h, w, Z = x.shape
        qkv = self.to_qkv(x).chunk(3, dim = 1)
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y z-> b h c (x y z)', h = self.heads), qkv)

        q, k = map(l2norm, (q, k))

        sim = einsum('b h d i, b h d j -> b h i j', q, k) * self.scale
        attn = sim.softmax(dim = -1)
        out = einsum('b h i j, b h d j -> b h i d', attn, v)
        out = rearrange(out, 'b h (x y z) d -> b (h d) x y z', x = h, y = w, z = Z)
        return self.to_out(out)

class Cross_Attention(nn.Module):
    def __init__(self, dim, heads = 4, scale = 10):
        super().__init__()
        self.scale = scale
        self.heads = heads
        self.to_q = conv1x1(dim, dim, stride=1)
        self.to_k = conv1x1(dim, dim, stride=1)
        self.to_v = conv1x1(dim, dim, stride=1)

        self.to_out = conv1x1(dim, dim, stride=1)

    def forward(self, x, cond_x):
        b, c, h, w, Z = x.shape
        q = self.to_q(x)
        k = self.to_k(cond_x)
        v = self.to_v(cond_x)

        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y z-> b h c (x y z)', h = self.heads), (q, k, v))

        q, k = map(l2norm, (q, k))

        sim = einsum('b h d i, b h d j -> b h i j', q, k) * self.scale
        attn = sim.softmax(dim = -1)
        out = einsum('b h i j, b h d j -> b h i d', attn, v)
        out = rearrange(out, 'b h (x y z) d -> b (h d) x y z', x = h, y = w, z = Z)
        return self.to_out(out)

class DownBlock(nn.Module):
    def __init__(self, in_filters, out_filters, dropout_rate, kernel_size=(3, 3, 3), stride=1,
                 pooling=True, drop_out=True, height_pooling=False):
        super(DownBlock, self).__init__()
        self.pooling = pooling
        self.drop_out = drop_out

        self.residual_block = Asymmetric_Residual_Block(in_filters, out_filters)

        if pooling:
            if height_pooling:
                self.pool = nn.Conv3d(out_filters, out_filters, kernel_size=3, stride=2,
                                                padding=1, bias=False)
            else:
                self.pool = nn.Conv3d(out_filters, out_filters, kernel_size=3, stride=(2, 2, 1),
                                                padding=1, bias=False)


    def forward(self, x, t):
        resA = self.residual_block(x, t)
        if self.pooling:
            resB = self.pool(resA) 
            return resB, resA
        else:
            return resA

class UpBlock(nn.Module):
    def __init__(self, in_filters, out_filters, height_pooling, time_filters=32*4):
        super(UpBlock, self).__init__()
        # self.drop_out = drop_out
        self.trans_dilao = conv3x3x3(in_filters, in_filters)
        self.trans_act = nn.LeakyReLU()
        self.trans_bn = nn.GroupNorm(32, in_filters)
        self.time_layers = nn.Sequential(
                            nn.SiLU(),
                            nn.Linear(time_filters, in_filters*2)
                        )

        self.conv1 = conv1x3x3(in_filters, out_filters)
        self.act1 = nn.LeakyReLU()
        self.bn1 = nn.GroupNorm(32, out_filters)

        self.conv2 = conv3x1x3(out_filters, out_filters)
        self.act2 = nn.LeakyReLU()
        self.bn2 = nn.GroupNorm(32, out_filters)

        self.conv3 = conv3x3x3(out_filters, out_filters)
        self.act3 = nn.LeakyReLU()
        self.bn3 = nn.GroupNorm(32, out_filters)
        
        if height_pooling :
            self.up_subm = nn.ConvTranspose3d(in_filters, in_filters, kernel_size=3, bias=False, stride=2, padding=1, output_padding=1, dilation=1)
        else : 
            self.up_subm = nn.ConvTranspose3d(in_filters, in_filters, kernel_size=(3,3,1), bias=False, stride=(2,2,1), padding=(1,1,0), output_padding=(1,1,0), dilation=1)
    

    def forward(self, x, residual, t):
        upA = self.trans_dilao(x) 
        upA = self.trans_act(upA)

        t = self.time_layers(t)
        while len(t.shape) < len(x.shape):
            t = t[..., None]
        scale, shift = torch.chunk(t, 2, dim=1)
        
        upA = self.trans_bn(upA) * (1 + scale) + shift
        ## upsample
        upA = self.up_subm(upA)
        upA += residual
        upE = self.conv1(upA)
        upE = self.act1(upE)
        upE = self.bn1(upE)

        upE = self.conv2(upE)
        upE = self.act2(upE)
        upE = self.bn2(upE)

        upE = self.conv3(upE)
        upE = self.act3(upE)
        upE = self.bn3(upE)

        return upE

def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if not repeat_only:
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=timesteps.device)
        args = timesteps[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    else:
        embedding = repeat(timesteps, 'b -> b d', d=dim)
    return embedding
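A minimal pure-Python sketch of the non-`repeat_only` branch of `timestep_embedding` above, for a single scalar timestep (the helper name `sinusoidal_embedding` is hypothetical):

```python
import math

def sinusoidal_embedding(t, dim, max_period=10000):
    # Mirrors timestep_embedding above for one scalar timestep: half the
    # output carries cos terms, half sin terms, at geometrically spaced
    # frequencies from 1 down to roughly 1/max_period.
    half = dim // 2
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    return ([math.cos(t * f) for f in freqs] +
            [math.sin(t * f) for f in freqs])
```

Low indices oscillate quickly with `t` while high indices vary slowly, which is what lets the network distinguish nearby timesteps.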

class Denoise(nn.Module):
    def __init__(self, args, num_class = 11, init_size=32, discrete=True):
        super(Denoise, self).__init__()
        self.args = args
        self.discrete = discrete
        self.num_class = num_class
        self.init_size = init_size
        self.time_size = init_size*4

        self.time_embed = nn.Sequential(
            nn.Linear(init_size, self.time_size),
            nn.SiLU(),
            nn.Linear(self.time_size, self.time_size),
        )

        self.embedding = nn.Embedding(self.num_class, init_size)
        self.conv_in = nn.Conv3d(init_size, init_size, kernel_size=1, stride=1)

        self.A = Asymmetric_Residual_Block(init_size, init_size)

        self.midBlock1_1 = Asymmetric_Residual_Block(init_size, 2 * init_size)
        self.attention1 = Attention(2 * init_size, 4)
        self.midBlock1_2 = Asymmetric_Residual_Block(2 * init_size, 2 * init_size)

        self.downBlock2 = DownBlock(init_size*2, 2 * init_size, 0.2, height_pooling=False)
        self.downBlock3 = DownBlock(2 * init_size, 4 * init_size, 0.2, height_pooling=False)
        
        self.midBlock2_1 = Asymmetric_Residual_Block(4 * init_size, 4 * init_size)
        self.attention2 = Attention(4 * init_size, 4)
        self.midBlock2_2 = Asymmetric_Residual_Block(4 * init_size, 4 * init_size)

        self.upBlock0 = UpBlock(4 * init_size, 2 * init_size, height_pooling=False)
        self.upBlock1 = UpBlock(2 * init_size, init_size, height_pooling=False)

        self.midBlock3_1 = Asymmetric_Residual_Block(init_size, init_size)
        self.attention3 = Attention(init_size, 4)
        self.midBlock3_2 = Asymmetric_Residual_Block(init_size, init_size)

        self.DDCM = DDCM(init_size, init_size)

        self.logits = nn.Sequential(
            nn.Conv3d(2 * init_size, self.num_class, kernel_size=3, stride=1, padding=1, bias=True),
        )

    def forward(self, x, t):
        x = self.embedding(x)
        x = x.permute(0, 4, 1, 2, 3)
        x = self.conv_in(x)
        t = self.time_embed(timestep_embedding(t, self.init_size))

        ret = self.A(x, t)

        mid1 = self.midBlock1_1(ret, t)
        att = self.attention1(mid1)
        mid2 = self.midBlock1_2(att, t)

        down1c, down1b = self.downBlock2(mid2, t) 
        down2c, down2b = self.downBlock3(down1c, t) 

        d_mid2 = self.midBlock2_1(down2c, t) 
        d_att = self.attention2(d_mid2)
        d_mid1 = self.midBlock2_2(d_att, t) 

        up3e = self.upBlock0(d_mid1, down2b, t)
        up2e = self.upBlock1(up3e, down1b, t)

        u_mid2 = self.midBlock3_1(up2e, t) 
        u_att = self.attention3(u_mid2)
        u_mid1 = self.midBlock3_2(u_att, t) 

        up0e = self.DDCM(u_mid1) 
        up0e = torch.cat((up0e, up2e), 1) 
        logits = self.logits(up0e) 
        
        return logits


================================================
FILE: layers/Voxel_Level/Con_Diffusion.py
================================================
import torch
import torch.nn.functional as F
import numpy as np
import math
from inspect import isfunction
from layers.Voxel_Level.denoise import Denoise
from utils.loss import *
"""
Based in part on: https://github.com/lucidrains/denoising-diffusion-pytorch/blob/5989f4c77eafcdc6be0fb4739f0f277a6dd7f7d8/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L281
"""
eps = 1e-8


def sum_except_batch(x, num_dims=1):
    return x.reshape(*x.shape[:num_dims], -1).sum(-1)


def log_1_min_a(a):
    return torch.log(1 - a.exp() + 1e-40)


def log_add_exp(a, b):
    maximum = torch.max(a, b)
    return maximum + torch.log(torch.exp(a - maximum) + torch.exp(b - maximum))
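`log_add_exp` is a two-operand log-sum-exp: it computes `log(e^a + e^b)` while factoring out the larger exponent so neither `exp` overflows. A scalar sketch (hypothetical name `log_add_exp_scalar`):

```python
import math

def log_add_exp_scalar(a, b):
    # Scalar analogue of log_add_exp above: log(e^a + e^b) computed
    # stably by pulling the larger exponent outside the log.
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))
```

A naive `math.log(math.exp(1000) + math.exp(1000))` would overflow, while the stable form returns `1000 + log 2`.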


def exists(x):
    return x is not None


def extract(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def log_categorical(log_x_start, log_prob):
    return (log_x_start.exp() * log_prob).sum(dim=1)


def index_to_log_onehot(x, num_classes):
    assert x.max().item() < num_classes, f'Error: {x.max().item()} >= {num_classes}'
    
    x_onehot = F.one_hot(x, num_classes)
    permute_order = (0, -1) + tuple(range(1, len(x.size())))
    x_onehot = x_onehot.permute(permute_order)
    log_x = torch.log(x_onehot.float().clamp(min=1e-30))
    return log_x
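`index_to_log_onehot` and `log_onehot_to_index` (defined just below) are inverses of each other: the clamp turns `log(0)` into `log(1e-30)`, so the argmax still recovers the label. A list-based sketch with hypothetical names `labels_to_log_onehot` / `log_onehot_to_labels`:

```python
import math

LOG_EPS = math.log(1e-30)  # stand-in for the clamp in index_to_log_onehot

def labels_to_log_onehot(labels, num_classes):
    # log(1) = 0 at the label position, log(1e-30) everywhere else.
    return [[0.0 if c == x else LOG_EPS for c in range(num_classes)]
            for x in labels]

def log_onehot_to_labels(log_x):
    # Argmax over the class axis, matching log_onehot_to_index.
    return [row.index(max(row)) for row in log_x]
```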


def log_onehot_to_index(log_x):
    return log_x.argmax(1)


def cosine_beta_schedule(timesteps, s = 0.008):
    """
    cosine schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    steps = timesteps + 1
    x = np.linspace(0, steps, steps)
    alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    alphas = (alphas_cumprod[1:] / alphas_cumprod[:-1])

    alphas = np.clip(alphas, a_min=0.001, a_max=1.)
    alphas = np.sqrt(alphas)

    return alphas
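The schedule above can be mirrored in plain Python to inspect its shape (hypothetical helper `cosine_alphas`; it follows the numpy code line by line, including the final clip and square root):

```python
import math

def cosine_alphas(timesteps, s=0.008):
    # Mirror of cosine_beta_schedule above, without numpy.
    steps = timesteps + 1
    xs = [i * steps / (steps - 1) for i in range(steps)]   # linspace(0, steps, steps)
    ac = [math.cos(((x / steps) + s) / (1 + s) * math.pi / 2) ** 2 for x in xs]
    a0 = ac[0]
    ac = [a / a0 for a in ac]                              # normalise so ac[0] == 1
    ratios = [min(max(ac[i + 1] / ac[i], 0.001), 1.0)      # clip to [0.001, 1]
              for i in range(timesteps)]
    return [math.sqrt(r) for r in ratios]
```

Per-step alphas start near 1 (little noise added early) and decay toward the clipped floor at the final steps.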

class Con_Diffusion(torch.nn.Module):
    def __init__(self, args, multi_criterion,auxiliary_loss_weight=0.05, adaptive_auxiliary_loss=True):
        super(Con_Diffusion, self).__init__()

        #self._denoise_fn = SSCNet(num_classes=args.num_classes*50, num_steps=args.diffusion_steps)
        self.args = args
        self.num_classes = self.args.num_classes
        self.num_timesteps = self.args.diffusion_steps
        self.recon_loss = self.args.recon_loss
        if args.dataset == 'carla':
            self._denoise_fn = Denoise(args= self.args,  num_class = self.num_classes)
        elif args.dataset=='kitti':
            self._denoise_fn = Denoise(args= self.args,  num_class = self.num_classes, init_size=16)
        self.auxiliary_loss_weight = auxiliary_loss_weight
        self.adaptive_auxiliary_loss = adaptive_auxiliary_loss

        self.multi_criterion = multi_criterion

        alphas = cosine_beta_schedule(self.num_timesteps )
        alphas = torch.tensor(alphas.astype('float64'))

        log_alpha = torch.log(alphas)
        log_cumprod_alpha = torch.cumsum(log_alpha, dim=0)

        log_1_min_alpha = log_1_min_a(log_alpha)
        log_1_min_cumprod_alpha = log_1_min_a(log_cumprod_alpha)

        assert log_add_exp(log_alpha, log_1_min_alpha).abs().sum().item() < 1.e-5
        assert log_add_exp(log_cumprod_alpha, log_1_min_cumprod_alpha).abs().sum().item() < 1e-5
        assert (torch.cumsum(log_alpha, dim=0) - log_cumprod_alpha).abs().sum().item() < 1.e-5

        # Convert to float32 and register buffers.
        self.register_buffer('log_alpha', log_alpha.float())
        self.register_buffer('log_1_min_alpha', log_1_min_alpha.float())
        self.register_buffer('log_cumprod_alpha', log_cumprod_alpha.float())
        self.register_buffer('log_1_min_cumprod_alpha', log_1_min_cumprod_alpha.float())

        self.register_buffer('Lt_history', torch.zeros(self.num_timesteps ))
        self.register_buffer('Lt_count', torch.zeros(self.num_timesteps ))
    
    def device(self):
        # nn.Module has no `.device` attribute and the denoiser is stored as
        # `_denoise_fn`; report the device of a registered buffer instead.
        return self.log_alpha.device

    def multinomial_kl(self, log_prob1, log_prob2):
        kl = (log_prob1.exp() * (log_prob1 - log_prob2)).sum(dim=1)
        return kl
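`multinomial_kl` computes KL(p‖q) between categorical distributions given their log-probabilities along the class axis. A scalar-list sketch working directly in probability space (hypothetical name `categorical_kl`):

```python
import math

def categorical_kl(p, q):
    # KL(p || q) for two categorical distributions, the quantity that
    # multinomial_kl above evaluates from log-probabilities.
    return sum(pi * (math.log(pi) - math.log(qi)) for pi, qi in zip(p, q))
```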

    def q_pred_one_timestep(self, log_x_t, t):
        log_alpha_t = extract(self.log_alpha, t, log_x_t.shape)
        log_1_min_alpha_t = extract(self.log_1_min_alpha, t, log_x_t.shape)

        # alpha_t * E[xt] + (1 - alpha_t) 1 / K
        
        log_probs = log_add_exp(
            log_x_t + log_alpha_t,
            log_1_min_alpha_t - np.log(self.num_classes)
        )

        return log_probs

    def q_pred(self, log_x_start, t):
        log_cumprod_alpha_t = extract(self.log_cumprod_alpha, t, log_x_start.shape)
        log_1_min_cumprod_alpha = extract(self.log_1_min_cumprod_alpha, t, log_x_start.shape)

        log_probs = log_add_exp(
            log_x_start + log_cumprod_alpha_t,
            log_1_min_cumprod_alpha - np.log(self.num_classes)
        )

        return log_probs
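In probability space, `q_pred` is a convex mixture of the clean distribution with a uniform over the K classes: q(x_t | x_0) = ᾱ_t · x_0 + (1 − ᾱ_t)/K. A tiny sketch (hypothetical `q_marginal`) shows the marginal stays normalised:

```python
def q_marginal(x0_probs, cum_alpha):
    # Probability-space analogue of q_pred above: mix the clean
    # distribution with a uniform over K classes.
    K = len(x0_probs)
    return [cum_alpha * p + (1.0 - cum_alpha) / K for p in x0_probs]
```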

    def predict_start(self, log_x_t, t, cond):
        x_t = log_onehot_to_index(log_x_t)

        out = self._denoise_fn(x_t, cond, t)

        log_pred = F.log_softmax(out, dim=1)
        return log_pred

    def q_posterior(self, log_x_start, log_x_t, t):
        # q(xt-1 | xt, x0) = q(xt | xt-1, x0) * q(xt-1 | x0) / q(xt | x0)
        # where q(xt | xt-1, x0) = q(xt | xt-1).

        t_minus_1 = t - 1
        # Remove negative values, will not be used anyway for final decoder
        t_minus_1 = torch.where(t_minus_1 < 0, torch.zeros_like(t_minus_1), t_minus_1)
        log_EV_qxtmin_x0 = self.q_pred(log_x_start, t_minus_1)

        num_axes = (1,) * (len(log_x_start.size()) - 1)
        t_broadcast = t.view(-1, *num_axes) * torch.ones_like(log_x_start)
        log_EV_qxtmin_x0 = torch.where(t_broadcast == 0, log_x_start, log_EV_qxtmin_x0)

        # Note: q_pred_one_timestep is applied to x_t, _not_ x_{t-1}. This is
        # valid because the uniform transition matrix is symmetric, so
        # q(x_t | x_{t-1}), viewed as a function of x_{t-1}, has the same form.
        unnormed_logprobs = log_EV_qxtmin_x0 + self.q_pred_one_timestep(log_x_t, t)

        log_EV_xtmin_given_xt_given_xstart = unnormed_logprobs - torch.logsumexp(unnormed_logprobs, dim=1, keepdim=True)

        return log_EV_xtmin_given_xt_given_xstart

    def p_pred(self, log_x, t, cond):
        log_x0_recon = self.predict_start(log_x, t, cond)
        log_model_pred = self.q_posterior(log_x_start=log_x0_recon, log_x_t=log_x, t=t)
        return log_model_pred, log_x0_recon

    def log_sample_categorical(self, logits):
        uniform = torch.rand_like(logits)
        gumbel_noise = -torch.log(-torch.log(uniform + 1e-30) + 1e-30)
        sample = (gumbel_noise + logits).argmax(dim=1)
        log_sample = index_to_log_onehot(sample, self.num_classes)
        return log_sample
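`log_sample_categorical` relies on the Gumbel-max trick: adding Gumbel(0, 1) noise (computed as `-log(-log(u))` for uniform `u`) to logits and taking the argmax draws an index with probability `softmax(logits)`. A pure-Python sketch (hypothetical `gumbel_argmax`) checks the sampled frequencies:

```python
import math
import random

def gumbel_argmax(logits, rng):
    # Gumbel-max trick as used by log_sample_categorical above.
    noisy = [l - math.log(-math.log(rng.random() + 1e-30) + 1e-30)
             for l in logits]
    return noisy.index(max(noisy))
```

Over many draws the empirical class frequencies match the softmax of the logits, without ever normalising explicitly.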

    def q_sample(self, log_x_start, t):
        log_EV_qxt_x0 = self.q_pred(log_x_start, t)
        log_sample = self.log_sample_categorical(log_EV_qxt_x0)
        return log_sample

    def kl_prior(self, log_x_start):
        b = log_x_start.size(0)
        device = log_x_start.device
        ones = torch.ones(b, device=device).long()

        log_qxT_prob = self.q_pred(log_x_start, t=(self.num_timesteps - 1) * ones)
        # log-probability of the uniform distribution over num_classes
        log_half_prob = -torch.log(self.num_classes * torch.ones_like(log_qxT_prob))

        kl_prior = self.multinomial_kl(log_qxT_prob, log_half_prob)
        return sum_except_batch(kl_prior)

    def sample_time(self, b, device, method='uniform'):
        if method == 'importance':
            if not (self.Lt_count > 10).all():
                return self.sample_time(b, device, method='uniform')

            Lt_sqrt = torch.sqrt(self.Lt_history + 1e-10) + 0.0001
            Lt_sqrt[0] = Lt_sqrt[1]  # Overwrite decoder term with L1.
            pt_all = Lt_sqrt / Lt_sqrt.sum()

            t = torch.multinomial(pt_all, num_samples=b, replacement=True)

            pt = pt_all.gather(dim=0, index=t)

            return t, pt

        elif method == 'uniform':
            t = torch.randint(0, self.num_timesteps, (b,), device=device).long()

            pt = torch.ones_like(t).float() / self.num_timesteps
            return t, pt
        else:
            raise ValueError(f'Unknown time sampling method: {method}')

    def forward(self, x, voxel_input):
        b, device = x.size(0), x.device
        self.shape = x.size()[1:]        
        t, pt = self.sample_time(b, device, 'importance')

        log_x_start = index_to_log_onehot(x, self.num_classes)
        log_x_t = self.q_sample(log_x_start, t) # log_x_t : (batch, #class, 128, 128, 8)

        log_true_prob = self.q_posterior(log_x_start=log_x_start, log_x_t=log_x_t, t=t)
        log_model_prob, log_x0_recon = self.p_pred(log_x=log_x_t, t=t, cond=voxel_input)

        kl = self.multinomial_kl(log_true_prob, log_model_prob)
        kl = sum_except_batch(kl)

        decoder_nll = -log_categorical(log_x_start, log_model_prob)
        decoder_nll = sum_except_batch(decoder_nll)

        mask = (t == torch.zeros_like(t)).float()
        kl_loss = mask * decoder_nll + (1. - mask) * kl
        
        if self.training:
            Lt2 = kl_loss.pow(2)
            Lt2_prev = self.Lt_history.gather(dim=0, index=t)
            new_Lt_history = (0.1 * Lt2 + 0.9 * Lt2_prev).detach()
            self.Lt_history.scatter_(dim=0, index=t, src=new_Lt_history)
            self.Lt_count.scatter_add_(dim=0, index=t, src=torch.ones_like(Lt2))

        kl_prior = self.kl_prior(log_x_start)

        # Upweight the KL term by the importance-sampling weight 1/pt
        loss = kl_loss / pt + kl_prior

        kl_aux = self.multinomial_kl(log_x_start[:,:-1,:,:,:], log_x0_recon[:,:-1,:,:,:])
        kl_aux = sum_except_batch(kl_aux)
        if self.recon_loss : 
            kl_aux += self.multi_criterion(log_x0_recon.exp(), x)
            #kl_aux += lovasz_softmax(torch.nn.functional.softmax(log_x0_recon.exp(), dim=1), x)

        kl_aux_loss = mask * decoder_nll + (1. - mask) * kl_aux
        if self.adaptive_auxiliary_loss:
            addition_loss_weight = (1-t/self.num_timesteps) + 1.0
        else:
            addition_loss_weight = 1.0

        aux_loss = addition_loss_weight * self.auxiliary_loss_weight * kl_aux_loss / pt
        
        loss += aux_loss
        # Average over the first two spatial dimensions of the grid.
        loss = loss.sum() / (self.shape[0] * self.shape[1])

        return loss

    def sample(self, voxel_input, intermediate=False):
        device = self.log_alpha.device
        self.shape = voxel_input.size()[1:]
        uniform_logits = torch.zeros((self.args.batch_size, self.num_classes) + self.shape, device=device)
        log_z = self.log_sample_categorical(uniform_logits)
        diffusion = []

        for i in reversed(range(0, self.num_timesteps)):
            print(f'Sample timestep {i:4d}', end='\r')

            t = torch.full((self.args.batch_size,), i, device=device, dtype=torch.long)

            log_model_prob, log_x0_recon = self.p_pred(log_x=log_z, t=t, cond=voxel_input)

            log_z = self.log_sample_categorical(log_model_prob)

            if i % 10 == 0:
                diffusion.append(log_onehot_to_index(log_z))

        result = log_onehot_to_index(log_z)
        if intermediate : 
            return result, diffusion
        else : 
            return result



================================================
FILE: layers/Voxel_Level/Gen_Diffusion.py
================================================
import torch
import torch.nn.functional as F
import numpy as np
import math
from inspect import isfunction
from layers.Voxel_Level.gen_denoise import Denoise
from utils.loss import *
"""
Based in part on: https://github.com/lucidrains/denoising-diffusion-pytorch/blob/5989f4c77eafcdc6be0fb4739f0f277a6dd7f7d8/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L281
"""
eps = 1e-8


def sum_except_batch(x, num_dims=1):
    return x.reshape(*x.shape[:num_dims], -1).sum(-1)


def log_1_min_a(a):
    return torch.log(1 - a.exp() + 1e-40)


def log_add_exp(a, b):
    maximum = torch.max(a, b)
    return maximum + torch.log(torch.exp(a - maximum) + torch.exp(b - maximum))


def exists(x):
    return x is not None


def extract(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def log_categorical(log_x_start, log_prob):
    return (log_x_start.exp() * log_prob).sum(dim=1)


def index_to_log_onehot(x, num_classes):
    assert x.max().item() < num_classes, f'Error: {x.max().item()} >= {num_classes}'
    
    x_onehot = F.one_hot(x, num_classes)
    permute_order = (0, -1) + tuple(range(1, len(x.size())))
    x_onehot = x_onehot.permute(permute_order)
    log_x = torch.log(x_onehot.float().clamp(min=1e-30))
    return log_x


def log_onehot_to_index(log_x):
    return log_x.argmax(1)


def cosine_beta_schedule(timesteps, s = 0.008):
    """
    cosine schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    steps = timesteps + 1
    x = np.linspace(0, steps, steps)
    alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    alphas = (alphas_cumprod[1:] / alphas_cumprod[:-1])

    alphas = np.clip(alphas, a_min=0.001, a_max=1.)
    alphas = np.sqrt(alphas)

    return alphas

class Diffusion(torch.nn.Module):
    def __init__(self, args, multi_criterion,auxiliary_loss_weight=0.05, adaptive_auxiliary_loss=True):
        super(Diffusion, self).__init__()

        #self._denoise_fn = SSCNet(num_classes=args.num_classes*50, num_steps=args.diffusion_steps)
        self.args = args
        self.num_classes = self.args.num_classes
        self.num_timesteps = self.args.diffusion_steps
        self.recon_loss = self.args.recon_loss
        self._denoise_fn = Denoise(args= self.args,  num_class = self.num_classes)
        self.auxiliary_loss_weight = auxiliary_loss_weight
        self.adaptive_auxiliary_loss = adaptive_auxiliary_loss

        self.multi_criterion = multi_criterion

        alphas = cosine_beta_schedule(self.num_timesteps )

        alphas = torch.tensor(alphas.astype('float64'))
        log_alpha = torch.log(alphas)
        log_cumprod_alpha = torch.cumsum(log_alpha, dim=0)

        log_1_min_alpha = log_1_min_a(log_alpha)
        log_1_min_cumprod_alpha = log_1_min_a(log_cumprod_alpha)

        assert log_add_exp(log_alpha, log_1_min_alpha).abs().sum().item() < 1.e-5
        assert log_add_exp(log_cumprod_alpha, log_1_min_cumprod_alpha).abs().sum().item() < 1e-5
        assert (torch.cumsum(log_alpha, dim=0) - log_cumprod_alpha).abs().sum().item() < 1.e-5

        # Convert to float32 and register buffers.
        self.register_buffer('log_alpha', log_alpha.float())
        self.register_buffer('log_1_min_alpha', log_1_min_alpha.float())
        self.register_buffer('log_cumprod_alpha', log_cumprod_alpha.float())
        self.register_buffer('log_1_min_cumprod_alpha', log_1_min_cumprod_alpha.float())

        self.register_buffer('Lt_history', torch.zeros(self.num_timesteps ))
        self.register_buffer('Lt_count', torch.zeros(self.num_timesteps ))
    
    def device(self):
        # nn.Module has no `.device` attribute and the denoiser is stored as
        # `_denoise_fn`; report the device of a registered buffer instead.
        return self.log_alpha.device

    def multinomial_kl(self, log_prob1, log_prob2):
        kl = (log_prob1.exp() * (log_prob1 - log_prob2)).sum(dim=1)
        return kl

    def q_pred_one_timestep(self, log_x_t, t):
        log_alpha_t = extract(self.log_alpha, t, log_x_t.shape)
        log_1_min_alpha_t = extract(self.log_1_min_alpha, t, log_x_t.shape)

        # alpha_t * E[xt] + (1 - alpha_t) 1 / K
        
        log_probs = log_add_exp(
            log_x_t + log_alpha_t,
            log_1_min_alpha_t - np.log(self.num_classes)
        )

        return log_probs

    def q_pred(self, log_x_start, t):
        log_cumprod_alpha_t = extract(self.log_cumprod_alpha, t, log_x_start.shape)
        log_1_min_cumprod_alpha = extract(self.log_1_min_cumprod_alpha, t, log_x_start.shape)

        log_probs = log_add_exp(
            log_x_start + log_cumprod_alpha_t,
            log_1_min_cumprod_alpha - np.log(self.num_classes)
        )

        return log_probs

    def predict_start(self, log_x_t, t):
        x_t = log_onehot_to_index(log_x_t)

        out = self._denoise_fn(x_t, t)

        log_pred = F.log_softmax(out, dim=1)
        return log_pred

    def q_posterior(self, log_x_start, log_x_t, t):
        # q(xt-1 | xt, x0) = q(xt | xt-1, x0) * q(xt-1 | x0) / q(xt | x0)
        # where q(xt | xt-1, x0) = q(xt | xt-1).

        t_minus_1 = t - 1
        # Remove negative values, will not be used anyway for final decoder
        t_minus_1 = torch.where(t_minus_1 < 0, torch.zeros_like(t_minus_1), t_minus_1)
        log_EV_qxtmin_x0 = self.q_pred(log_x_start, t_minus_1)

        num_axes = (1,) * (len(log_x_start.size()) - 1)
        t_broadcast = t.view(-1, *num_axes) * torch.ones_like(log_x_start)
        log_EV_qxtmin_x0 = torch.where(t_broadcast == 0, log_x_start, log_EV_qxtmin_x0)

        # Note: q_pred_one_timestep is applied to x_t, _not_ x_{t-1}. This is
        # valid because the uniform transition matrix is symmetric, so
        # q(x_t | x_{t-1}), viewed as a function of x_{t-1}, has the same form.
        unnormed_logprobs = log_EV_qxtmin_x0 + self.q_pred_one_timestep(log_x_t, t)

        log_EV_xtmin_given_xt_given_xstart = unnormed_logprobs - torch.logsumexp(unnormed_logprobs, dim=1, keepdim=True)

        return log_EV_xtmin_given_xt_given_xstart

    def p_pred(self, log_x, t):
        log_x0_recon = self.predict_start(log_x, t)
        log_model_pred = self.q_posterior(log_x_start=log_x0_recon, log_x_t=log_x, t=t)
        return log_model_pred, log_x0_recon

    def log_sample_categorical(self, logits):
        uniform = torch.rand_like(logits)
        gumbel_noise = -torch.log(-torch.log(uniform + 1e-30) + 1e-30)
        sample = (gumbel_noise + logits).argmax(dim=1)
        log_sample = index_to_log_onehot(sample, self.num_classes)
        return log_sample

    def q_sample(self, log_x_start, t):
        log_EV_qxt_x0 = self.q_pred(log_x_start, t)
        log_sample = self.log_sample_categorical(log_EV_qxt_x0)
        return log_sample

    def kl_prior(self, log_x_start):
        b = log_x_start.size(0)
        device = log_x_start.device
        ones = torch.ones(b, device=device).long()

        log_qxT_prob = self.q_pred(log_x_start, t=(self.num_timesteps - 1) * ones)
        # log-probability of the uniform distribution over num_classes
        log_half_prob = -torch.log(self.num_classes * torch.ones_like(log_qxT_prob))

        kl_prior = self.multinomial_kl(log_qxT_prob, log_half_prob)
        return sum_except_batch(kl_prior)

    def sample_time(self, b, device, method='uniform'):
        if method == 'importance':
            if not (self.Lt_count > 10).all():
                return self.sample_time(b, device, method='uniform')

            Lt_sqrt = torch.sqrt(self.Lt_history + 1e-10) + 0.0001
            Lt_sqrt[0] = Lt_sqrt[1]  # Overwrite decoder term with L1.
            pt_all = Lt_sqrt / Lt_sqrt.sum()

            t = torch.multinomial(pt_all, num_samples=b, replacement=True)

            pt = pt_all.gather(dim=0, index=t)

            return t, pt

        elif method == 'uniform':
            t = torch.randint(0, self.num_timesteps, (b,), device=device).long()

            pt = torch.ones_like(t).float() / self.num_timesteps
            return t, pt
        else:
            raise ValueError(f'Unknown time sampling method: {method}')

    def forward(self, x, voxel_input):
        b, device = x.size(0), x.device
        self.shape = x.size()[1:]        
        t, pt = self.sample_time(b, device, 'importance')

        log_x_start = index_to_log_onehot(x, self.num_classes)
        log_x_t = self.q_sample(log_x_start, t) # log_x_t : (batch, #class, 128, 128, 8)

        log_true_prob = self.q_posterior(log_x_start=log_x_start, log_x_t=log_x_t, t=t)
        log_model_prob, log_x0_recon = self.p_pred(log_x=log_x_t, t=t)

        kl = self.multinomial_kl(log_true_prob, log_model_prob)
        kl = sum_except_batch(kl)

        decoder_nll = -log_categorical(log_x_start, log_model_prob)
        decoder_nll = sum_except_batch(decoder_nll)

        mask = (t == torch.zeros_like(t)).float()
        kl_loss = mask * decoder_nll + (1. - mask) * kl
        
        if self.training:
            Lt2 = kl_loss.pow(2)
            Lt2_prev = self.Lt_history.gather(dim=0, index=t)
            new_Lt_history = (0.1 * Lt2 + 0.9 * Lt2_prev).detach()
            self.Lt_history.scatter_(dim=0, index=t, src=new_Lt_history)
            self.Lt_count.scatter_add_(dim=0, index=t, src=torch.ones_like(Lt2))

        kl_prior = self.kl_prior(log_x_start)

        # Upweight the KL term by the importance-sampling weight 1/pt
        loss = kl_loss / pt + kl_prior

        kl_aux = self.multinomial_kl(log_x_start[:,:-1,:,:,:], log_x0_recon[:,:-1,:,:,:])
        kl_aux = sum_except_batch(kl_aux)
        '''if self.recon_loss : 
            kl_aux += self.multi_criterion(log_x0_recon.exp(), x)
            kl_aux += lovasz_softmax(torch.nn.functional.softmax(log_x0_recon.exp(), dim=1), x)'''

        kl_aux_loss = mask * decoder_nll + (1. - mask) * kl_aux
        if self.adaptive_auxiliary_loss:
            addition_loss_weight = (1-t/self.num_timesteps) + 1.0
        else:
            addition_loss_weight = 1.0

        aux_loss = addition_loss_weight * self.auxiliary_loss_weight * kl_aux_loss / pt
        
        loss += aux_loss
        loss = loss.sum() / (self.shape[0]*self.shape[1])
        #loss += seg_loss

        return loss

    def sample(self, voxel_input):
        device = self.log_alpha.device
        self.shape = voxel_input.size()[1:]
        uniform_logits = torch.zeros((self.args.batch_size, self.num_classes) + self.shape, device=device)
        log_z = self.log_sample_categorical(uniform_logits)

        for i in reversed(range(0, self.num_timesteps)):
            print(f'Sample timestep {i:4d}', end='\r')

            t = torch.full((self.args.batch_size,), i, device=device, dtype=torch.long)

            log_model_prob, log_x0_recon = self.p_pred(log_x=log_z, t=t)

            log_z = self.log_sample_categorical(log_model_prob)

        result = log_onehot_to_index(log_z)
        return result
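The `'importance'` branch of `sample_time` above can be sketched standalone. The values of `num_timesteps`, `Lt_history`, and the batch size below are illustrative stand-ins for the module's buffers, not the real training state:

```python
import torch

# Standalone sketch of importance-sampled timestep selection: timesteps whose
# running squared-loss estimates (Lt_history) are larger get sampled more
# often, and pt is returned so the per-sample loss can be divided by it,
# keeping the overall estimator unbiased.
num_timesteps = 100
Lt_history = torch.rand(num_timesteps)      # stand-in for the E[L_t^2] buffer

Lt_sqrt = torch.sqrt(Lt_history + 1e-10) + 0.0001
Lt_sqrt[0] = Lt_sqrt[1]                     # t=0 uses the decoder NLL; reuse L1's weight
pt_all = Lt_sqrt / Lt_sqrt.sum()            # normalize into a distribution over t

b = 4                                       # illustrative batch size
t = torch.multinomial(pt_all, num_samples=b, replacement=True)
pt = pt_all.gather(dim=0, index=t)          # probability of each drawn timestep
```

Dividing `kl_loss` by `pt`, as the forward pass does, compensates for the non-uniform draw of `t`.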



================================================
FILE: layers/Voxel_Level/denoise.py
================================================
import math
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from einops import rearrange, reduce, repeat
from torch import nn, einsum


def conv3x3x3(in_planes, out_planes, stride=1):
    return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)

def conv1x3x3(in_planes, out_planes, stride=1):
    return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 3, 3), stride=stride,padding=(0, 1, 1), bias=False)


def conv1x1x3(in_planes, out_planes, stride=1):
    return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 1, 3), stride=stride, padding=(0, 0, 1), bias=False)


def conv1x3x1(in_planes, out_planes, stride=1):
    return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 3, 1), stride=stride, padding=(0, 1, 0), bias=False)


def conv3x1x1(in_planes, out_planes, stride=1):
    return nn.Conv3d(in_planes, out_planes, kernel_size=(3, 1, 1), stride=stride, padding=(1, 0, 0), bias=False)


def conv3x1x3(in_planes, out_planes, stride=1):
    return nn.Conv3d(in_planes, out_planes, kernel_size=(3, 1, 3), stride=stride, padding=(1, 0, 1), bias=False)


def conv1x1(in_planes, out_planes, stride=1):
    return nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride)


class Asymmetric_Residual_Block(nn.Module):
    def __init__(self, in_filters, out_filters, time_filters=32*4):
        super(Asymmetric_Residual_Block, self).__init__()
        # Group counts must divide the channel counts they normalize, so each
        # norm is gated on its own channel width (in_filters vs. out_filters).
        self.GroupNorm = nn.GroupNorm(16 if in_filters < 32 else 32, in_filters)
        n_groups = 16 if out_filters < 32 else 32
        self.bn0 = nn.GroupNorm(n_groups, out_filters)
        self.bn0_2 = nn.GroupNorm(n_groups, out_filters)
        self.bn1 = nn.GroupNorm(n_groups, out_filters)
        self.bn2 = nn.GroupNorm(n_groups, out_filters)
        self.time_layers = nn.Sequential(
                            nn.SiLU(),
                            nn.Linear(time_filters, in_filters*2)
                        )

        self.conv1 = conv1x3x3(in_filters, out_filters)
        self.act1 = nn.LeakyReLU()
          
        self.conv1_2 = conv3x1x3(out_filters, out_filters)
        self.act1_2 = nn.LeakyReLU()

        self.conv2 = conv3x1x3(in_filters, out_filters)
        self.act2 = nn.LeakyReLU()

        self.conv3 = conv1x3x3(out_filters, out_filters)
        self.act3 = nn.LeakyReLU()


    def forward(self, x, t):
        t = self.time_layers(t)
        while len(t.shape) < len(x.shape):
            t = t[..., None]
        scale, shift = torch.chunk(t, 2, dim=1)
        
        x = self.GroupNorm(x) * (1 + scale) + shift

        shortcut = self.conv1(x)
        shortcut = self.act1(shortcut)
        shortcut = self.bn0(shortcut)

        shortcut = self.conv1_2(shortcut)
        shortcut = self.act1_2(shortcut)
        shortcut = self.bn0_2(shortcut)

        resA = self.conv2(x) 
        resA = self.act2(resA)
        resA = self.bn1(resA)

        resA = self.conv3(resA) 
        resA = self.act3(resA)
        resA = self.bn2(resA)
        resA += shortcut

        return resA

class DDCM(nn.Module):
    def __init__(self, in_filters, out_filters, kernel_size=(3, 3, 3), stride=1):
        super(DDCM, self).__init__()
        self.conv1 = conv3x1x1(in_filters, out_filters)
        # Gate the group count on out_filters, the channel width being normalized.
        n_groups = 16 if out_filters < 32 else 32
        self.bn0 = nn.GroupNorm(n_groups, out_filters)
        self.bn0_2 = nn.GroupNorm(n_groups, out_filters)
        self.bn0_3 = nn.GroupNorm(n_groups, out_filters)
        self.act1 = nn.Sigmoid()

        self.conv1_2 = conv1x3x1(in_filters, out_filters)
        self.act1_2 = nn.Sigmoid()

        self.conv1_3 = conv1x1x3(in_filters, out_filters)
        self.act1_3 = nn.Sigmoid()

    def forward(self, x):
        shortcut = self.conv1(x)
        shortcut = self.bn0(shortcut)
        shortcut = self.act1(shortcut)

        shortcut2 = self.conv1_2(x)
        shortcut2 = self.bn0_2(shortcut2)
        shortcut2 = self.act1_2(shortcut2)

        shortcut3 = self.conv1_3(x)
        shortcut3 = self.bn0_3(shortcut3)
        shortcut3 = self.act1_3(shortcut3)
        shortcut = shortcut + shortcut2 + shortcut3

        shortcut = shortcut * x

        return shortcut

def l2norm(t):
    return F.normalize(t, dim = -1)

class Attention(nn.Module):
    def __init__(self, dim, heads = 4, scale = 10):
        super().__init__()
        self.scale = scale
        self.heads = heads
        self.to_qkv = conv1x1(dim, dim*3, stride=1)
        self.to_out = conv1x1(dim, dim, stride=1)

    def forward(self, x):
        b, c, h, w, Z = x.shape
        qkv = self.to_qkv(x).chunk(3, dim = 1)
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y z-> b h c (x y z)', h = self.heads), qkv)

        q, k = map(l2norm, (q, k))

        sim = einsum('b h d i, b h d j -> b h i j', q, k) * self.scale
        attn = sim.softmax(dim = -1)
        out = einsum('b h i j, b h d j -> b h i d', attn, v)
        out = rearrange(out, 'b h (x y z) d -> b (h d) x y z', x = h, y = w, z = Z)
        return self.to_out(out)

class Cross_Attention(nn.Module):
    def __init__(self, dim, heads = 4, scale = 10):
        super().__init__()
        self.scale = scale
        self.heads = heads
        self.to_q = conv1x1(dim, dim, stride=1)
        self.to_k = conv1x1(dim, dim, stride=1)
        self.to_v = conv1x1(dim, dim, stride=1)

        self.to_out = conv1x1(dim, dim, stride=1)

    def forward(self, x, cond_x):
        b, c, h, w, Z = x.shape
        q = self.to_q(x)
        k = self.to_k(cond_x)
        v = self.to_v(cond_x)

        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y z-> b h c (x y z)', h = self.heads), (q, k, v))

        q, k = map(l2norm, (q, k))

        sim = einsum('b h d i, b h d j -> b h i j', q, k) * self.scale
        attn = sim.softmax(dim = -1)
        out = einsum('b h i j, b h d j -> b h i d', attn, v)
        out = rearrange(out, 'b h (x y z) d -> b (h d) x y z', x = h, y = w, z = Z)
        return self.to_out(out)

class DownBlock(nn.Module):
    def __init__(self, in_filters, out_filters, time_filters=32*4, kernel_size=(3, 3, 3), stride=1,
                 pooling=True, height_pooling=False):
        super(DownBlock, self).__init__()
        self.pooling = pooling

        self.residual_block = Asymmetric_Residual_Block(in_filters, out_filters, time_filters=time_filters)

        if pooling:
            if height_pooling:
                self.pool = nn.Conv3d(out_filters, out_filters, kernel_size=3, stride=2,padding=1, bias=False)
            else:
                self.pool = nn.Conv3d(out_filters, out_filters, kernel_size=3, stride=(2, 2, 1),padding=1, bias=False)

    def forward(self, x, t):
        resA = self.residual_block(x, t)
        if self.pooling:
            resB = self.pool(resA) 
            return resB, resA
        else:
            return resA

class UpBlock(nn.Module):
    def __init__(self, in_filters, out_filters, height_pooling, time_filters=32*4):
        super(UpBlock, self).__init__()
        # self.drop_out = drop_out
        # trans_bn normalizes in_filters channels and the bn* layers out_filters,
        # so each group count is gated on its own channel width.
        self.trans_bn = nn.GroupNorm(16 if in_filters < 32 else 32, in_filters)
        n_groups = 16 if out_filters < 32 else 32
        self.bn1 = nn.GroupNorm(n_groups, out_filters)
        self.bn2 = nn.GroupNorm(n_groups, out_filters)
        self.bn3 = nn.GroupNorm(n_groups, out_filters)
        self.trans_dilao = conv3x3x3(in_filters, in_filters)
        self.trans_act = nn.LeakyReLU()
        self.time_layers = nn.Sequential(
                            nn.SiLU(),
                            nn.Linear(time_filters, in_filters*2)
                        )

        self.conv1 = conv1x3x3(in_filters, out_filters)
        self.act1 = nn.LeakyReLU()

        self.conv2 = conv3x1x3(out_filters, out_filters)
        self.act2 = nn.LeakyReLU()

        self.conv3 = conv3x3x3(out_filters, out_filters)
        self.act3 = nn.LeakyReLU()
        
        if height_pooling :
            self.up_subm = nn.ConvTranspose3d(in_filters, in_filters, kernel_size=3, bias=False, stride=2, padding=1, output_padding=1, dilation=1)
        else : 
            self.up_subm = nn.ConvTranspose3d(in_filters, in_filters, kernel_size=(3,3,1), bias=False, stride=(2,2,1), padding=(1,1,0), output_padding=(1,1,0), dilation=1)
    

    def forward(self, x, residual, t): 
        upA = self.trans_dilao(x) 
        upA = self.trans_act(upA)

        t = self.time_layers(t)
        while len(t.shape) < len(x.shape):
            t = t[..., None]
        scale, shift = torch.chunk(t, 2, dim=1)
        
        upA = self.trans_bn(upA) * (1 + scale) + shift
        ## upsample
        upA = self.up_subm(upA)
        upA += residual
        upE = self.conv1(upA)
        upE = self.act1(upE)
        upE = self.bn1(upE)

        upE = self.conv2(upE)
        upE = self.act2(upE)
        upE = self.bn2(upE)

        upE = self.conv3(upE)
        upE = self.act3(upE)
        upE = self.bn3(upE)

        return upE

def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    if not repeat_only:
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=timesteps.device)
        args = timesteps[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    else:
        embedding = repeat(timesteps, 'b -> b d', d=dim)
    return embedding

class Denoise(nn.Module):
    def __init__(self, args, num_class = 11, init_size=32, discrete=True):
        super(Denoise, self).__init__()
        self.args = args
        self.discrete = discrete
        self.num_class = num_class
        self.init_size = init_size
        self.time_size = self.init_size*4

        self.time_embed = nn.Sequential(
            nn.Linear(init_size, self.time_size),
            nn.SiLU(),
            nn.Linear(self.time_size, self.time_size),
        )

        self.embedding = nn.Embedding(self.num_class, init_size)
        self.conv_in = nn.Conv3d(init_size+1, init_size, kernel_size=1, stride=1)

        self.A = Asymmetric_Residual_Block(init_size, init_size, time_filters=init_size*4)

        self.downBlock1 = DownBlock(init_size, 2 * init_size, height_pooling=True, time_filters=init_size*4)
        self.downBlock2 = DownBlock(2 * init_size, 4 * init_size, height_pooling=True, time_filters=init_size*4)
        self.downBlock3 = DownBlock(4 * init_size, 8 * init_size, height_pooling=False, time_filters=init_size*4)
        self.downBlock4 = DownBlock(8 * init_size, 16 * init_size, height_pooling=False, time_filters=init_size*4)
        self.midBlock1 = Asymmetric_Residual_Block(16 * init_size, 16 * init_size, time_filters=init_size*4)
        self.attention = Attention(16 * init_size, 32)
        self.midBlock2 = Asymmetric_Residual_Block(16 * init_size, 16 * init_size, time_filters=init_size*4)

        self.upBlock4 = UpBlock(16 * init_size, 8 * init_size, height_pooling=False, time_filters=init_size*4)
        self.upBlock3 = UpBlock(8 * init_size, 4 * init_size, height_pooling=False, time_filters=init_size*4)
        self.upBlock2 = UpBlock(4 * init_size, 2 * init_size, height_pooling=True, time_filters=init_size*4)
        self.upBlock1 = UpBlock(2 * init_size, 2 * init_size, height_pooling=True, time_filters=init_size*4)

        self.DDCM = DDCM(2 * init_size, 2 * init_size)
        self.logits = nn.Conv3d(4 * init_size, self.num_class, kernel_size=3, stride=1, padding=1, bias=True)
        
    def forward(self, x, x_cond, t):
        x = self.embedding(x)
        x = x.permute(0, 4, 1, 2, 3)
        x_cond = x_cond.unsqueeze(1)
        x = torch.cat([x, x_cond], dim=1)
        x = self.conv_in(x)

        t = self.time_embed(timestep_embedding(t, self.init_size))

        x = self.A(x, t)

        down1c, down1b = self.downBlock1(x, t)
        down2c, down2b = self.downBlock2(down1c, t)
        down3c, down3b = self.downBlock3(down2c, t)
        down4c, down4b = self.downBlock4(down3c, t)

        down4c = self.midBlock1(down4c, t)
        down4c = self.attention(down4c)
        down4c = self.midBlock2(down4c, t)
        
        up4 = self.upBlock4(down4c, down4b, t)
        up3 = self.upBlock3(up4, down3b, t)
        up2 = self.upBlock2(up3, down2b, t)
        up1 = self.upBlock1(up2, down1b, t)
        up0 = self.DDCM(up1)
        up = torch.cat((up1, up0), 1)
        logits = self.logits(up) 
       
        return logits
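As a quick shape check, the sinusoidal `timestep_embedding` used above can be exercised on its own. The function below recomputes the same formula standalone (the `repeat_only` branch omitted); the dimensions are arbitrary:

```python
import math
import torch

# Recomputes the sinusoidal timestep embedding: cos terms first, then sin,
# matching the concatenation order in timestep_embedding above.
def sinusoidal_embedding(timesteps, dim, max_period=10000):
    half = dim // 2
    freqs = torch.exp(
        -math.log(max_period) * torch.arange(half, dtype=torch.float32) / half
    )
    args = timesteps[:, None].float() * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

t = torch.tensor([0, 10, 99])
emb = sinusoidal_embedding(t, 32)
print(emb.shape)  # torch.Size([3, 32])
# at t=0 every cosine is 1 and every sine is 0
```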


================================================
FILE: layers/Voxel_Level/gen_denoise.py
================================================
import math
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from einops import rearrange, reduce, repeat
from torch import nn, einsum


def conv3x3x3(in_planes, out_planes, stride=1):
    return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)

def conv1x3x3(in_planes, out_planes, stride=1):
    return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 3, 3), stride=stride,padding=(0, 1, 1), bias=False)


def conv1x1x3(in_planes, out_planes, stride=1):
    return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 1, 3), stride=stride, padding=(0, 0, 1), bias=False)


def conv1x3x1(in_planes, out_planes, stride=1):
    return nn.Conv3d(in_planes, out_planes, kernel_size=(1, 3, 1), stride=stride, padding=(0, 1, 0), bias=False)


def conv3x1x1(in_planes, out_planes, stride=1):
    return nn.Conv3d(in_planes, out_planes, kernel_size=(3, 1, 1), stride=stride, padding=(1, 0, 0), bias=False)


def conv3x1x3(in_planes, out_planes, stride=1):
    return nn.Conv3d(in_planes, out_planes, kernel_size=(3, 1, 3), stride=stride, padding=(1, 0, 1), bias=False)


def conv1x1(in_planes, out_planes, stride=1):
    return nn.Conv3d(in_planes, out_planes, kernel_size=1, stride=stride)


class Asymmetric_Residual_Block(nn.Module):
    def __init__(self, in_filters, out_filters, time_filters=128):
        super(Asymmetric_Residual_Block, self).__init__()
        if in_filters < 32:
            n_ng = in_filters
        else:
            n_ng = 32
        self.GroupNorm = nn.GroupNorm(n_ng, in_filters)
        self.time_layers = nn.Sequential(
                            nn.SiLU(),
                            nn.Linear(time_filters, in_filters*2)
                        )

        self.conv1 = conv1x3x3(in_filters, out_filters)
        if out_filters < 32:
            n_ng = out_filters
        else:
            n_ng = 32
        self.bn0 = nn.GroupNorm(n_ng, out_filters)
        self.act1 = nn.LeakyReLU()
          
        self.conv1_2 = conv3x1x3(out_filters, out_filters)
        self.bn0_2 = nn.GroupNorm(n_ng, out_filters)
        self.act1_2 = nn.LeakyReLU()

        self.conv2 = conv3x1x3(in_filters, out_filters)
        self.act2 = nn.LeakyReLU()
        self.bn1 = nn.GroupNorm(n_ng, out_filters)

        self.conv3 = conv1x3x3(out_filters, out_filters)
        self.act3 = nn.LeakyReLU()
        self.bn2 = nn.GroupNorm(n_ng, out_filters)


    def forward(self, x, t):
        t = self.time_layers(t)
        while len(t.shape) < len(x.shape):
            t = t[..., None]
        scale, shift = torch.chunk(t, 2, dim=1)
        
        x = self.GroupNorm(x) * (1 + scale) + shift

        shortcut = self.conv1(x) 
        shortcut = self.act1(shortcut)
        shortcut = self.bn0(shortcut)

        shortcut = self.conv1_2(shortcut) 
        shortcut = self.act1_2(shortcut)
        shortcut = self.bn0_2(shortcut)

        resA = self.conv2(x)
        resA = self.act2(resA)
        resA = self.bn1(resA)

        resA = self.conv3(resA) 
        resA = self.act3(resA)
        resA = self.bn2(resA)
        resA += shortcut

        return resA

class DDCM(nn.Module):
    def __init__(self, in_filters, out_filters, kernel_size=(3, 3, 3), stride=1):
        super(DDCM, self).__init__()
        self.conv1 = conv3x1x1(in_filters, out_filters)
        if out_filters < 32:
            n_ng = out_filters
        else:
            n_ng = 32
        self.bn0 = nn.GroupNorm(n_ng, out_filters)
        self.act1 = nn.Sigmoid()

        self.conv1_2 = conv1x3x1(in_filters, out_filters)
        self.bn0_2 = nn.GroupNorm(n_ng, out_filters)
        self.act1_2 = nn.Sigmoid()

        self.conv1_3 = conv1x1x3(in_filters, out_filters)
        self.bn0_3 = nn.GroupNorm(n_ng, out_filters)
        self.act1_3 = nn.Sigmoid()

    def forward(self, x):
        shortcut = self.conv1(x)
        shortcut = self.bn0(shortcut)
        shortcut = self.act1(shortcut)

        shortcut2 = self.conv1_2(x)
        shortcut2 = self.bn0_2(shortcut2)
        shortcut2 = self.act1_2(shortcut2)

        shortcut3 = self.conv1_3(x)
        shortcut3 = self.bn0_3(shortcut3)
        shortcut3 = self.act1_3(shortcut3)
        shortcut = shortcut + shortcut2 + shortcut3
        shortcut = shortcut * x

        return shortcut

def l2norm(t):
    return F.normalize(t, dim = -1)

class Attention(nn.Module):
    def __init__(self, dim, heads = 4, scale = 10):
        super().__init__()
        self.scale = scale
        self.heads = heads
        self.to_qkv = conv1x1(dim, dim*3, stride=1)
        self.to_out = conv1x1(dim, dim, stride=1)

    def forward(self, x):
        b, c, h, w, Z = x.shape
        qkv = self.to_qkv(x).chunk(3, dim = 1)
        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y z-> b h c (x y z)', h = self.heads), qkv)

        q, k = map(l2norm, (q, k))

        sim = einsum('b h d i, b h d j -> b h i j', q, k) * self.scale
        attn = sim.softmax(dim = -1)
        out = einsum('b h i j, b h d j -> b h i d', attn, v)
        out = rearrange(out, 'b h (x y z) d -> b (h d) x y z', x = h, y = w, z = Z)
        return self.to_out(out)

class Cross_Attention(nn.Module):
    def __init__(self, dim, heads = 4, scale = 10):
        super().__init__()
        self.scale = scale
        self.heads = heads
        self.to_q = conv1x1(dim, dim, stride=1)
        self.to_k = conv1x1(dim, dim, stride=1)
        self.to_v = conv1x1(dim, dim, stride=1)

        self.to_out = conv1x1(dim, dim, stride=1)

    def forward(self, x, cond_x):
        b, c, h, w, Z = x.shape
        q = self.to_q(x)
        k = self.to_k(cond_x)
        v = self.to_v(cond_x)

        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y z-> b h c (x y z)', h = self.heads), (q, k, v))

        q, k = map(l2norm, (q, k))

        sim = einsum('b h d i, b h d j -> b h i j', q, k) * self.scale
        attn = sim.softmax(dim = -1)
        out = einsum('b h i j, b h d j -> b h i d', attn, v)
        out = rearrange(out, 'b h (x y z) d -> b (h d) x y z', x = h, y = w, z = Z)
        return self.to_out(out)

class DownBlock(nn.Module):
    def __init__(self, in_filters, out_filters, time_filters, kernel_size=(3, 3, 3), stride=1,
                 pooling=True, drop_out=True, height_pooling=False):
        super(DownBlock, self).__init__()
        self.pooling = pooling
        self.drop_out = drop_out
        self.residual_block = Asymmetric_Residual_Block(in_filters, out_filters, time_filters)

        if pooling:
            if height_pooling:
                self.pool = nn.Conv3d(out_filters, out_filters, kernel_size=3, stride=2,
                                                padding=1, bias=False)
            else:
                self.pool = nn.Conv3d(out_filters, out_filters, kernel_size=3, stride=(2, 2, 1),
                                                padding=1, bias=False)


    def forward(self, x, t):
        resA = self.residual_block(x, t)
        if self.pooling:
            resB = self.pool(resA) 
            return resB, resA
        else:
            return resA

class UpBlock(nn.Module):
    def __init__(self, in_filters, out_filters, height_pooling, time_filters):
        super(UpBlock, self).__init__()
        # self.drop_out = drop_out
        self.trans_dilao = conv3x3x3(in_filters, in_filters)
        self.trans_act = nn.LeakyReLU()
        # trans_bn normalizes in_filters channels, so its group count must
        # divide in_filters (the original set n_ng = out_filters here).
        if in_filters < 32:
            n_ng = in_filters
        else:
            n_ng = 32
        self.trans_bn = nn.GroupNorm(n_ng, in_filters)
        self.time_layers = nn.Sequential(
                            nn.SiLU(),
                            nn.Linear(time_filters, in_filters*2)
                        )

        self.conv1 = conv1x3x3(in_filters, out_filters)
        self.act1 = nn.LeakyReLU()
        if out_filters < 32:
            n_ng = out_filters
        else:
            n_ng = 32
        self.bn1 = nn.GroupNorm(n_ng, out_filters)

        self.conv2 = conv3x1x3(out_filters, out_filters)
        self.act2 = nn.LeakyReLU()
        self.bn2 = nn.GroupNorm(n_ng, out_filters)

        self.conv3 = conv3x3x3(out_filters, out_filters)
        self.act3 = nn.LeakyReLU()
        self.bn3 = nn.GroupNorm(n_ng, out_filters)
        
        if height_pooling :
            self.up_subm = nn.ConvTranspose3d(in_filters, in_filters, kernel_size=3, bias=False, stride=2, padding=1, output_padding=1, dilation=1)
        else : 
            self.up_subm = nn.ConvTranspose3d(in_filters, in_filters, kernel_size=(3,3,1), bias=False, stride=(2,2,1), padding=(1,1,0), output_padding=(1,1,0), dilation=1)
    

    def forward(self, x, residual, t):
        upA = self.trans_dilao(x)
        upA = self.trans_act(upA)

        t = self.time_layers(t)
        while len(t.shape) < len(x.shape):
            t = t[..., None]
        scale, shift = torch.chunk(t, 2, dim=1)
        
        upA = self.trans_bn(upA) * (1 + scale) + shift
        ## upsample
        upA = self.up_subm(upA)
        upA += residual
        upE = self.conv1(upA)
        upE = self.act1(upE)
        upE = self.bn1(upE)

        upE = self.conv2(upE)
        upE = self.act2(upE)
        upE = self.bn2(upE)

        upE = self.conv3(upE)
        upE = self.act3(upE)
        upE = self.bn3(upE)

        return upE

def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    if not repeat_only:
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=timesteps.device)
        args = timesteps[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    else:
        embedding = repeat(timesteps, 'b -> b d', d=dim)
    return embedding

class Denoise(nn.Module):
    def __init__(self, args, num_class = 11, init_size=32, discrete=True):
        super(Denoise, self).__init__()
        self.args = args
        self.discrete = discrete
        self.num_class = num_class
        self.init_size = init_size
        self.time_size = init_size*4

        self.time_embed = nn.Sequential(
            nn.Linear(self.init_size, self.time_size),
            nn.SiLU(),
            nn.Linear(self.time_size, self.time_size),
        )

        self.embedding = nn.Embedding(self.num_class, self.init_size)
        self.conv_in = nn.Conv3d(self.init_size, self.init_size, kernel_size=1, stride=1)

        self.A = Asymmetric_Residual_Block(self.init_size, self.init_size, self.time_size)

        self.downBlock1 = DownBlock(init_size, 2 * init_size, self.time_size, height_pooling=True)
        self.downBlock2 = DownBlock(2 * init_size, 4 * init_size, self.time_size, height_pooling=True)
        self.downBlock3 = DownBlock(4 * init_size, 8 * init_size, self.time_size, height_pooling=False)
        self.downBlock4 = DownBlock(8 * init_size, 16 * init_size, self.time_size, height_pooling=False)
        self.midBlock1 = Asymmetric_Residual_Block(16 * init_size, 16 * init_size, self.time_size)
        self.attention = Attention(16 * init_size, 32)
        self.midBlock2 = Asymmetric_Residual_Block(16 * init_size, 16 * init_size, self.time_size)
        
        self.upBlock4 = UpBlock(16 * init_size, 8 * init_size, height_pooling=False, time_filters=self.time_size)
        self.upBlock3 = UpBlock(8 * init_size, 4 * init_size, height_pooling=False, time_filters=self.time_size)
        self.upBlock2 = UpBlock(4 * init_size, 2 * init_size, height_pooling=True, time_filters=self.time_size)
        self.upBlock1 = UpBlock(2 * init_size, 2 * init_size, height_pooling=True, time_filters=self.time_size)
        self.DDCM = DDCM(2 * init_size, 2 * init_size)
        self.logits = nn.Conv3d(4 * init_size, self.num_class, kernel_size=3, stride=1, padding=1, bias=True)
 

    def forward(self, x, t):
        x = self.embedding(x)
        x = x.permute(0, 4, 1, 2, 3)
        x = self.conv_in(x)

        t = self.time_embed(timestep_embedding(t, self.init_size))

        x = self.A(x, t)

        down1c, down1b = self.downBlock1(x, t) 
        down2c, down2b = self.downBlock2(down1c, t) 
        down3c, down3b = self.downBlock3(down2c, t) 
        
        down4c, down4b = self.downBlock4(down3c, t) 
        down4c = self.midBlock1(down4c, t) 
        down4c = self.attention(down4c)
        down4c = self.midBlock2(down4c, t) 
        up4 = self.upBlock4(down4c, down4b, t)
        up3 = self.upBlock3(up4, down3b, t)


        up2 = self.upBlock2(up3, down2b, t)
        up1 = self.upBlock1(up2, down1b, t)

        up0 = self.DDCM(up1) 

        up = torch.cat((up1, up0), 1)

        logits = self.logits(up) 
        
        return logits
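The time conditioning shared by `Asymmetric_Residual_Block` and `UpBlock` is a FiLM-style per-channel modulation: a SiLU + Linear head maps the time embedding to `(scale, shift)` pairs, applied after GroupNorm as `norm(x) * (1 + scale) + shift`. A minimal sketch with illustrative sizes (the real blocks use the module's `in_filters` and `time_filters`):

```python
import torch
import torch.nn as nn

# Illustrative channel/time sizes; the real blocks derive these from init_size.
channels, time_filters = 16, 128
norm = nn.GroupNorm(channels, channels)  # group count equals channels here
time_layers = nn.Sequential(nn.SiLU(), nn.Linear(time_filters, channels * 2))

x = torch.randn(2, channels, 4, 4, 4)    # (batch, C, X, Y, Z) voxel features
t_emb = torch.randn(2, time_filters)     # per-sample time embedding

t = time_layers(t_emb)                   # (batch, 2*C)
while len(t.shape) < len(x.shape):       # append singleton spatial dims
    t = t[..., None]
scale, shift = torch.chunk(t, 2, dim=1)  # (batch, C, 1, 1, 1) each
out = norm(x) * (1 + scale) + shift      # FiLM-style modulation
print(out.shape)  # torch.Size([2, 16, 4, 4, 4])
```

The `(1 + scale)` form keeps the modulation near identity when the time head outputs values close to zero.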


================================================
FILE: layers/__init__.py
================================================


================================================
FILE: requirements.txt
================================================
numpy
torch
scipy
scikit-learn
matplotlib
tqdm
open3d
pyyaml
prettytable
tensorboard
numba
einops


================================================
FILE: setup.py
================================================
from setuptools import setup, find_packages

setup(
    name="scene_scale_diffusion",
    version="0.1",
    author="Lee Jumin, Im Woobin, Lee Sebin, Yoon Sung-Eui",
    author_email="",
    description="Experiments in PyTorch",
    long_description="",
    packages=find_packages(),
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
    ],
)


================================================
FILE: simple_visualize.py
================================================
import os
import numpy as np
import open3d as o3d
import argparse
import yaml

def load_config(yaml_path):
    with open(yaml_path, 'r') as f:
        config = yaml.safe_load(f)
    return config["learning_map"], config["remap_color_map"]

def load_pointcloud(filepath, learning_map, color_map):
    data = np.loadtxt(filepath, delimiter=' ')
    if data.shape[1] < 4:
        raise ValueError(f"Expected at least 4 columns (label + x y z), got shape {data.shape}")

    raw_labels = data[:, 0].astype(int)
    points = data[:, 1:4]

    # Map raw labels → remapped labels → colors
    remapped_labels = np.array([learning_map.get(int(l), 0) for l in raw_labels])
    colors = np.array([color_map.get(int(l), [255, 255, 255]) for l in remapped_labels]) / 255.0

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)
    pcd.colors = o3d.utility.Vector3dVector(colors)
    return pcd

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--file', default='result_for_l_gen/Completion/result_0.txt',
                        help='Path to the point cloud .txt file')
    parser.add_argument('--config', default='datasets/carla.yaml',
                        help='Path to Carla YAML config file')
    args = parser.parse_args()

    if not os.path.exists(args.file):
        raise FileNotFoundError(f"Point cloud file not found: {args.file}")
    if not os.path.exists(args.config):
        raise FileNotFoundError(f"YAML config file not found: {args.config}")

    learning_map, color_map = load_config(args.config)
    pcd = load_pointcloud(args.file, learning_map, color_map)

    o3d.visualization.draw([pcd])

if __name__ == "__main__":
    main()
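The two-stage mapping in `load_pointcloud` (raw label → `learning_map` → `remap_color_map`) can be illustrated with made-up dictionary entries; the real tables live in `datasets/carla.yaml`:

```python
import numpy as np

# Made-up stand-ins for the YAML maps, for illustration only.
learning_map = {0: 0, 40: 1, 48: 2}
color_map = {0: [255, 255, 255], 1: [255, 0, 255], 2: [75, 0, 75]}

raw_labels = np.array([40, 48, 7])  # 7 is absent -> falls back to class 0
remapped = np.array([learning_map.get(int(l), 0) for l in raw_labels])
colors = np.array([color_map.get(int(l), [255, 255, 255]) for l in remapped]) / 255.0
print(remapped.tolist())  # [1, 2, 0]
```

Unknown raw labels fall back to class 0, and unknown remapped classes fall back to white, mirroring the `.get(..., default)` calls in the script.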


================================================
FILE: train.py
================================================
from dataclasses import astuple
import torch
import argparse
import numpy as np
import os
import pickle
import torch
import torch.nn.functional as F
import yaml

from prettytable import PrettyTable
from torch.utils.tensorboard import SummaryWriter

from utils.tables import *
from utils.dicts import clean_dict
from utils.loss import lovasz_softmax


class Experiment(object):
    no_log_keys = ['project', 'name','log_tb', 'log_wandb','check_every', 'eval_every','device', 'parallel', 'pin_memory', 'num_workers']
                   
    def __init__(self, args, model, optimizer, scheduler_iter, scheduler_epoch,
                 train_loader, eval_loader, test_loader, train_sampler,
                 log_path, eval_every, check_every):

        # Objects
        self.model = model

        self.loss_fun = torch.nn.CrossEntropyLoss(ignore_index=0)
        self.optimizer, self.scheduler_iter, self.scheduler_epoch= optimizer, scheduler_iter, scheduler_epoch
        # Paths
        self.log_path = log_path

        if args.dataset =='carla':
            config_file = os.path.join('./datasets/carla.yaml')
            carla_config = yaml.safe_load(open(config_file, 'r'))
            self.color_map = carla_config["remap_color_map"]
            self.remap = None
            LABEL_TO_NAMES = carla_config["label_to_names"]
            self.label_to_names = np.asarray(list(LABEL_TO_NAMES.values()))

        # Intervals
        self.eval_every, self.check_every = eval_every, check_every

        # Initialize
        self.current_epoch = 0
        self.train_metrics, self.eval_metrics, self.ssc_metrics, self.seg_metrics = {}, {}, {}, {}
        self.eval_epochs = []
        self.completion_epochs = []

        # Store data loaders
        self.train_loader, self.eval_loader, self.test_loader, self.train_sampler = train_loader, eval_loader, test_loader, train_sampler

        # Store args
        create_folders(args)
        save_args(args)
        self.args = args

        # Init logging
        args_dict = clean_dict(vars(args), keys=self.no_log_keys)
        if args.log_tb:
            self.writer = SummaryWriter(os.path.join(self.log_path, 'tb'))
            self.writer.add_text("args", get_args_table(args_dict).get_html_string(), global_step=0)

    def run(self, epochs):
        if self.args.resume: 
            self.resume()
        
        for epoch in range(self.current_epoch, epochs): 
            
            # Train
            train_dict = self.train_fn(epoch)
            self.log_metrics(train_dict, self.train_metrics)

            # Checkpoint
            self.current_epoch += 1
            if (epoch+1) % self.check_every == 0:
                self.checkpoint_save(epoch)

            # Eval
            if (epoch+1) % self.eval_every == 0:
                eval_dict = self.eval_fn(epoch)
                self.log_metrics(eval_dict, self.eval_metrics)
                self.eval_epochs.append(epoch)
            else:
                eval_dict = None

            if (epoch+1) % self.args.completion_epoch == 0:
                ssc_dict, miou, seg_dict, seg_miou = self.sample()
                self.log_metrics(ssc_dict, self.ssc_metrics)
                self.log_metrics(seg_dict, self.seg_metrics)
                self.completion_epochs.append(epoch)
            else :
                ssc_dict, seg_dict = None, None

            # Log
            #self.save_metrics()
            if self.args.log_tb:
                for metric_name, metric_value in train_dict.items():
                    self.writer.add_scalar('base/{}'.format(metric_name), metric_value, global_step=epoch+1)
                if eval_dict:
                    for metric_name, metric_value in eval_dict.items():
                        self.writer.add_scalar('eval/{}'.format(metric_name), metric_value, global_step=epoch+1)
                if ssc_dict:
                    for metric_name, metric_value in ssc_dict.items():
                        self.writer.add_scalar('SSC/{}'.format(metric_name), metric_value, global_step=epoch+1)
                    self.writer.add_text("SSC_mIoU", get_miou_table(self.args, self.label_to_names, miou).get_html_string(), global_step=epoch+1)
                    for metric_name, metric_value in seg_dict.items():
                        self.writer.add_scalar('Seg/{}'.format(metric_name), metric_value, global_step=epoch+1)
                    self.writer.add_text("Seg_mIoU", get_miou_table(self.args, self.label_to_names, seg_miou).get_html_string(), global_step=epoch+1)

    def train_fn(self, epoch):
        self.model.train()
        loss_sum = 0.0
        loss_count = 0
        if self.args.distribution :
            self.train_sampler.set_epoch(epoch)

        for voxel_input, output, counts in self.train_loader:
            self.optimizer.zero_grad()
            voxel_input = torch.from_numpy(np.asarray(voxel_input)).long().squeeze(1).cuda() # (4,1,256,256,32)
            output = torch.from_numpy(np.asarray(output)).long().cuda()            
            if self.args.distribution:
                loss = self.model.module(output, voxel_input)
            else : 
                loss = self.model(output, voxel_input)
            loss.backward()

            if self.args.clip_value: torch.nn.utils.clip_grad_value_(self.model.parameters(), self.args.clip_value)
            if self.args.clip_norm: torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip_norm)

            self.optimizer.step()
            if self.scheduler_iter: self.scheduler_iter.step()
            loss_sum += loss.detach().cpu().item() * len(output)
            loss_count += len(output)
            print('Training. Epoch: {}/{}, Datapoint: {}/{}, Bits/dim: {:.3f}'.format(epoch+1, self.args.epochs, loss_count, len(self.train_loader.dataset), loss_sum/loss_count), end='\r')
        print('')
        if self.scheduler_epoch: self.scheduler_epoch.step()
        return {'loss': loss_sum/loss_count}


    def eval_fn(self, epoch):
        self.model.eval()

        with torch.no_grad():
            loss_sum = 0.0
            loss_count = 0
            for voxel_input, output, counts in self.eval_loader:
                voxel_input = torch.from_numpy(np.asarray(voxel_input)).long().squeeze(1).cuda() # (4,1,256,256,32)
                output = torch.from_numpy(np.asarray(output)).long().cuda()            
                if self.args.distribution:
                    loss = self.model.module(output, voxel_input)
                else : 
                    loss = self.model(output, voxel_input)
                loss_sum += loss.detach().cpu().item() * len(output)
                loss_count += len(output)
                print('Train evaluating. Epoch: {}/{}, Datapoint: {}/{}, Bits/dim: {:.3f}'.format(epoch+1, self.args.epochs, loss_count, len(self.eval_loader.dataset), loss_sum/loss_count), end='\r')
            print('')
        return {'loss': loss_sum/loss_count}


    def sample(self):
        self.model.eval()
        with torch.no_grad():
            TP, FP, TN, FN, num_correct, num_total = 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
            s_TP, s_FP, s_TN, s_FN, s_num_correct, s_num_total = 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
            all_intersections, all_unions = np.zeros(self.args.num_classes), np.zeros(self.args.num_classes) + 1e-6
            s_all_intersections, s_all_unions = np.zeros(self.args.num_classes), np.zeros(self.args.num_classes) + 1e-6
            if self.args.dataset == 'carla':
                dataloader = self.test_loader
            else :
                dataloader = self.eval_loader
            for iterate, (voxel_input, output, counts) in enumerate(dataloader):
                if len(voxel_input) == self.args.batch_size :
                    voxel_input = torch.from_numpy(np.asarray(voxel_input)).long().squeeze(1).cuda() # (4,1,256,256,32)
                    output = torch.from_numpy(np.asarray(output)).long().cuda()            
                    invalid = torch.from_numpy(np.asarray(counts)).cuda()

                    if self.args.mode == 'l_vae':
                        if self.args.distribution:
                            recons = self.model.module.sample(output) 
                        else : 
                            recons = self.model.sample(output) 
                    else :
                        if self.args.distribution:
                            recons = self.model.module.sample(voxel_input) 
                        else : 
                            recons = self.model.sample(voxel_input)   

                    visualization(self.args, recons, voxel_input, output, invalid, iteration = iterate)
                    correct, total, pred_TP, pred_FP, pred_TN, pred_FN, intersection, union = get_result(self.args, invalid, output, recons)
                    all_intersections += intersection
                    all_unions += union
                    num_correct += correct
                    num_total += total
                    TP += pred_TP
                    FP += pred_FP
                    TN += pred_TN
                    FN += pred_FN

                    s_correct, s_total, s_pred_TP, s_pred_FP, s_pred_TN, s_pred_FN, s_intersection, s_union = get_result(self.args, voxel_input, output, recons, SSC=False)
                    s_all_intersections += s_intersection
                    s_all_unions += s_union
                    s_num_correct += s_correct
                    s_num_total += s_total
                    s_TP += s_pred_TP
                    s_FP += s_pred_FP
                    s_TN += s_pred_TN
                    s_FN += s_pred_FN
                   
            iou, miou = print_result(self.args, self.label_to_names, num_correct, num_total, all_intersections, all_unions, TP, FP, FN)
            s_iou, seg_miou = print_result(self.args, self.label_to_names, s_num_correct, s_num_total, s_all_intersections, s_all_unions, s_TP, s_FP, s_FN, SSC=False)
            return {"IoU" : iou, "mIoU": np.mean(miou)*100 }, miou, {"IoU" : s_iou, "mIoU": np.mean(seg_miou)*100 }, seg_miou

    def resume(self):
        self.checkpoint_load(self.args.resume_path)
        for epoch in range(self.current_epoch):
            train_dict = {}
            for metric_name, metric_values in self.train_metrics.items():
                train_dict[metric_name] = metric_values[epoch]

            if epoch in self.eval_epochs:
                eval_dict = {}
                for metric_name, metric_values in self.eval_metrics.items():
                    eval_dict[metric_name] = metric_values[self.eval_epochs.index(epoch)]
            else: 
                eval_dict = None
            
            if epoch in self.completion_epochs:
                sample_dict = {}
                for metric_name, metric_values in self.ssc_metrics.items():
                    sample_dict[metric_name] = metric_values[self.completion_epochs.index(epoch)]
            else:
                sample_dict = None

            for metric_name, metric_value in train_dict.items():
                self.writer.add_scalar('base/{}'.format(metric_name), metric_value, global_step=epoch+1)
            if eval_dict:
                for metric_name, metric_value in eval_dict.items():
                    self.writer.add_scalar('eval/{}'.format(metric_name), metric_value, global_step=epoch+1)
            if sample_dict:
                for metric_name, metric_value in sample_dict.items():
                    self.writer.add_scalar('sample/{}'.format(metric_name), metric_value, global_step=epoch+1)


    def log_metrics(self, metrics, store):
        # Append each metric value to its running history list
        if len(store) == 0:
            for metric_name, metric_value in metrics.items():
                store[metric_name] = [metric_value]
        else:
            for metric_name, metric_value in metrics.items():
                store[metric_name].append(metric_value)

    def save_metrics(self):
        # Save metrics
        with open(os.path.join(self.log_path,'metrics_train.pickle'), 'wb') as f:
            pickle.dump(self.train_metrics, f)
        with open(os.path.join(self.log_path,'metrics_eval.pickle'), 'wb') as f:
            pickle.dump(self.eval_metrics, f)

        # Save metrics table
        metric_table = get_metric_table(self.train_metrics, epochs=list(range(1, self.current_epoch+2)))
        with open(os.path.join(self.log_path,'metrics_train.txt'), "w") as f:
            f.write(str(metric_table))
        metric_table = get_metric_table(self.eval_metrics, epochs=[e+1 for e in self.eval_epochs])
        with open(os.path.join(self.log_path,'metrics_eval.txt'), "w") as f:
            f.write(str(metric_table))


    def checkpoint_save(self, epoch):    
        if self.args.distribution:
            checkpoint = {'current_epoch': self.current_epoch,
                          'train_metrics': self.train_metrics,
                          'eval_metrics': self.eval_metrics,
                          'eval_epochs': self.eval_epochs,
                          'optimizer': self.optimizer.state_dict(),
                          'model': self.model.module.state_dict(),
                          'scheduler_iter': self.scheduler_iter.state_dict() if self.scheduler_iter else None,
                          'scheduler_epoch': self.scheduler_epoch.state_dict() if self.scheduler_epoch else None,}
        else : 
            checkpoint = {'current_epoch': self.current_epoch,
                          'train_metrics': self.train_metrics,
                          'eval_metrics': self.eval_metrics,
                          'eval_epochs': self.eval_epochs,
                          'optimizer': self.optimizer.state_dict(),
                          'model': self.model.state_dict(),
                          'scheduler_iter': self.scheduler_iter.state_dict() if self.scheduler_iter else None,
                          'scheduler_epoch': self.scheduler_epoch.state_dict() if self.scheduler_epoch else None,}
        epoch_name = 'epoch{}.tar'.format(epoch)
        torch.save(checkpoint, os.path.join(self.log_path, epoch_name))

    def checkpoint_load(self, resume_path):
        checkpoint = torch.load(resume_path)
        
        if self.args.distribution:
            self.model.module.load_state_dict(checkpoint['model'])
        else :
            self.model.load_state_dict(checkpoint['model'])
        
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        if self.scheduler_iter: self.scheduler_iter.load_state_dict(checkpoint['scheduler_iter'])
        if self.scheduler_epoch: self.scheduler_epoch.load_state_dict(checkpoint['scheduler_epoch'])

        self.current_epoch = checkpoint['current_epoch']
        self.train_metrics = checkpoint['train_metrics']
        self.eval_metrics = checkpoint['eval_metrics']
        self.eval_epochs = checkpoint['eval_epochs']


================================================
FILE: utils/cuda.py
================================================
import os

import torch
from torch import distributed as dist
from torch import multiprocessing as mp

import utils.dicts as dist_fn

def find_free_port():
    import socket

    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

    sock.bind(("", 0))
    port = sock.getsockname()[1]
    sock.close()

    return port


def launch(fn, n_gpu_per_machine, n_machine=1, machine_rank=0, dist_url=None, args=()):
    world_size = n_machine * n_gpu_per_machine

    if world_size > 1:
        # if "OMP_NUM_THREADS" not in os.environ:
        #     os.environ["OMP_NUM_THREADS"] = "1"

        if dist_url == "auto":
            if n_machine != 1:
                raise ValueError('dist_url="auto" not supported in multi-machine jobs')

            port = find_free_port()
            dist_url = f"tcp://127.0.0.1:{port}"

        if n_machine > 1 and dist_url.startswith("file://"):
            raise ValueError(
                "file:// is not a reliable init method in multi-machine jobs. Prefer tcp://"
            )

        mp.spawn(
            distributed_worker,
            nprocs=n_gpu_per_machine,
            args=(fn, world_size, n_gpu_per_machine, machine_rank, dist_url, args),
            daemon=False,
        )

    else:
        local_rank = 0
        fn(local_rank, *args)


def distributed_worker(local_rank, fn, world_size, n_gpu_per_machine, machine_rank, dist_url, args):
    if not torch.cuda.is_available():
        raise OSError("CUDA is not available. Please check your environment.")

    global_rank = machine_rank * n_gpu_per_machine + local_rank

    try:
        dist.init_process_group(
            backend="NCCL",
            init_method=dist_url,
            world_size=world_size,
            rank=global_rank,
        )

    except Exception as e:
        raise OSError("failed to initialize NCCL groups") from e

    dist_fn.synchronize()

    if n_gpu_per_machine > torch.cuda.device_count():
        raise ValueError(
            f"specified n_gpu_per_machine larger than available device ({torch.cuda.device_count()})"
        )

    torch.cuda.set_device(local_rank)

    if dist_fn.LOCAL_PROCESS_GROUP is not None:
        raise ValueError("torch.distributed.LOCAL_PROCESS_GROUP is not None")

    n_machine = world_size // n_gpu_per_machine

    for i in range(n_machine):
        ranks_on_i = list(range(i * n_gpu_per_machine, (i + 1) * n_gpu_per_machine))
        pg = dist.new_group(ranks_on_i)

        if i == machine_rank:
            dist_fn.LOCAL_PROCESS_GROUP = pg

    fn(local_rank, *args)

def set_cuda_vd(gpu_ids, verbose=True):
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(id) for id in gpu_ids)
    if verbose: print("CUDA_VISIBLE_DEVICES = {}".format(os.environ["CUDA_VISIBLE_DEVICES"]))
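The rank arithmetic in `distributed_worker` places each machine's GPUs on a contiguous block of global ranks, which is also the block handed to `dist.new_group`. A small sketch of that layout (function names are illustrative):

```python
def global_rank(machine_rank, n_gpu_per_machine, local_rank):
    # Global rank = machine offset plus local GPU index, as computed
    # in distributed_worker above.
    return machine_rank * n_gpu_per_machine + local_rank

def local_group_ranks(machine_rank, n_gpu_per_machine):
    # The contiguous rank block one machine passes to dist.new_group.
    start = machine_rank * n_gpu_per_machine
    return list(range(start, start + n_gpu_per_machine))
```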



================================================
FILE: utils/dicts.py
================================================
import copy
import math
import pickle

import torch
from torch import distributed as dist
from torch.utils import data


LOCAL_PROCESS_GROUP = None


def is_primary():
    return get_rank() == 0


def get_rank():
    if not dist.is_available():
        return 0

    if not dist.is_initialized():
        return 0

    return dist.get_rank()


def get_local_rank():
    if not dist.is_available():
        return 0

    if not dist.is_initialized():
        return 0

    if LOCAL_PROCESS_GROUP is None:
        raise ValueError("utils.dicts.LOCAL_PROCESS_GROUP is None")

    return dist.get_rank(group=LOCAL_PROCESS_GROUP)


def synchronize():
    if not dist.is_available():
        return

    if not dist.is_initialized():
        return

    world_size = dist.get_world_size()

    if world_size == 1:
        return

    dist.barrier()


def get_world_size():
    if not dist.is_available():
        return 1

    if not dist.is_initialized():
        return 1

    return dist.get_world_size()


def is_distributed():
    return get_world_size() > 1


def all_reduce(tensor, op=dist.ReduceOp.SUM, async_op=False):
    world_size = get_world_size()

    if world_size == 1:
        return tensor
    dist.all_reduce(tensor, op=op, async_op=async_op)

    return tensor


def all_gather(data):
    world_size = get_world_size()

    if world_size == 1:
        return [data]

    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to("cuda")

    local_size = torch.IntTensor([tensor.numel()]).to("cuda")
    size_list = [torch.IntTensor([1]).to("cuda") for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    tensor_list = []
    for _ in size_list:
        tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda"))

    if local_size != max_size:
        padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda")
        tensor = torch.cat((tensor, padding), 0)

    dist.all_gather(tensor_list, tensor)

    data_list = []

    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list


def reduce_dict(input_dict, average=True):
    world_size = get_world_size()

    if world_size < 2:
        return input_dict

    with torch.no_grad():
        keys = []
        values = []

        for k in sorted(input_dict.keys()):
            keys.append(k)
            values.append(input_dict[k])

        values = torch.stack(values, 0)
        dist.reduce(values, dst=0)

        if dist.get_rank() == 0 and average:
            values /= world_size

        reduced_dict = {k: v for k, v in zip(keys, values)}

    return reduced_dict


def data_sampler(dataset, shuffle, distributed):
    if distributed:
        return data.distributed.DistributedSampler(dataset, shuffle=shuffle)

    if shuffle:
        return data.RandomSampler(dataset)

    else:
        return data.SequentialSampler(dataset)

def clean_dict(d, keys):
    d2 = copy.deepcopy(d)
    for key in keys:
        if key in d2:
            del d2[key]
    return d2
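`all_gather` above works around NCCL's requirement that gathered tensors share a shape by exchanging byte sizes first, zero-padding every pickled buffer to the longest one, and trimming before unpickling. That pad-and-trim round-trip can be checked locally without a process group (the helper name is illustrative):

```python
import pickle

def pad_gather_roundtrip(payloads):
    # Simulate the size-exchange / pad / trim steps of all_gather on one
    # process: pickle each payload, pad to the longest buffer, then
    # recover the originals from (size, padded buffer) pairs.
    buffers = [pickle.dumps(p) for p in payloads]
    sizes = [len(b) for b in buffers]
    max_size = max(sizes)
    padded = [b + b"\x00" * (max_size - len(b)) for b in buffers]
    return [pickle.loads(p[:s]) for s, p in zip(sizes, padded)]
```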


================================================
FILE: utils/intermediate_vis.py
================================================
import torch
import argparse
import numpy as np
import os
import pickle
import torch.nn.functional as F
import yaml

from prettytable import PrettyTable
from torch.utils.tensorboard import SummaryWriter

from utils.tables import *
from utils.dicts import clean_dict
from utils.loss import lovasz_softmax


class Vis_iter(object):
    no_log_keys = ['project', 'name','log_tb', 'log_wandb','check_every', 'eval_every','device', 'parallel', 'pin_memory', 'num_workers']
                   
    def __init__(self, args, model, optimizer, scheduler_iter, scheduler_epoch, test_loader,log_path):

        # Objects
        self.model = model
        self.optimizer, self.scheduler_iter, self.scheduler_epoch= optimizer, scheduler_iter, scheduler_epoch
        # Paths
        self.log_path = log_path

        if args.dataset =='kitti':
            config_file = os.path.join('/home/jumin/multinomial_diffusion/datasets/semantic_kitti.yaml')
            with open(config_file, 'r') as f:
                kitti_config = yaml.safe_load(f)
            self.remap = kitti_config['learning_map_inv']
            self.color_map = kitti_config["color_map"]
            label = kitti_config['labels']
            map_index = np.asarray([self.remap[i] for i in range(20)])
            self.label_to_names = np.asarray([label[map_i] for map_i in map_index])

        elif args.dataset =='carla':
            base_dir = os.path.dirname(__file__)
            config_file = os.path.join(base_dir, '../datasets/carla.yaml')
            with open(config_file, 'r') as f:
                carla_config = yaml.safe_load(f)
            self.color_map = carla_config["remap_color_map"]
            self.remap = None
            LABEL_TO_NAMES = carla_config["label_to_names"]
            self.label_to_names = np.asarray(list(LABEL_TO_NAMES.values()))


        # Initialize
        self.current_epoch = 0
        self.train_metrics, self.eval_metrics, self.ssc_metrics, self.seg_metrics = {}, {}, {}, {}
        self.eval_epochs = []
        self.completion_epochs = []

        # Store data loaders
        self.test_loader = test_loader

        # Store args
        create_folders(args)
        save_args(args)
        self.args = args

        # Init logging
        args_dict = clean_dict(vars(args), keys=self.no_log_keys)
        if args.log_tb:
            self.writer = SummaryWriter(os.path.join(self.log_path, 'tb'))
            self.writer.add_text("args", get_args_table(args_dict).get_html_string(), global_step=0)

    def run(self, epochs):
        self.checkpoint_load(self.args.resume_path)
        for epoch in range(self.current_epoch, epochs): 
            self.sample()

    def sample(self):
        self.model.eval()
        with torch.no_grad():
            for iterate, (voxel_input, output, counts) in enumerate(self.test_loader):
                voxel_input = torch.from_numpy(np.asarray(voxel_input)).squeeze(1).cuda() 
                output = torch.from_numpy(np.asarray(output)).long().cuda()            
                if self.args.distribution:
                    _, intermediate = self.model.module.sample(voxel_input, intermediate=True)
                else:
                    _, intermediate = self.model.sample(voxel_input, intermediate=True)
                inter_vis(self.args, intermediate)
                break
                   
    def checkpoint_load(self, resume_path):
        checkpoint = torch.load(resume_path)
        
        if self.args.distribution:
            self.model.module.load_state_dict(checkpoint['model'])
        else :
            self.model.load_state_dict(checkpoint['model'])
        
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        if self.scheduler_iter: self.scheduler_iter.load_state_dict(checkpoint['scheduler_iter'])
        if self.scheduler_epoch: self.scheduler_epoch.load_state_dict(checkpoint['scheduler_epoch'])

        self.current_epoch = checkpoint['current_epoch']
        self.train_metrics = checkpoint['train_metrics']
        self.eval_metrics = checkpoint['eval_metrics']
        self.eval_epochs = checkpoint['eval_epochs']


================================================
FILE: utils/loss.py
================================================
import math
import torch
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
try:
    from itertools import ifilterfalse
except ImportError:  # Python 3
    from itertools import filterfalse as ifilterfalse



# author: Xinge

def dice_coef(y_true, y_pred, smooth=1e-6):
    y_true_f = y_true.view(-1)
    y_pred_f = y_pred.view(-1)
    intersection = (y_true_f * y_pred_f).sum()
    return (2. * intersection + smooth) / (y_true_f.sum() + y_pred_f.sum() + smooth)

def dice_coef_multilabel(y_true, y_pred, numLabels=11):
    dice=0
    for index in range(1, numLabels):
        dice += dice_coef(y_true[:,index,:,:,:], y_pred[:,index,:,:,:])
    return (numLabels-1) - dice

"""
Lovasz-Softmax and Jaccard hinge loss in PyTorch
Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License)
"""

def lovasz_grad(gt_sorted):
    """
    Computes gradient of the Lovasz extension w.r.t sorted errors
    See Alg. 1 in paper
    """
    p = len(gt_sorted)
    gts = gt_sorted.sum()
    intersection = gts - gt_sorted.float().cumsum(0)
    union = gts + (1 - gt_sorted).float().cumsum(0)
    jaccard = 1. - intersection / union
    if p > 1: # cover 1-pixel case
        jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
    return jaccard

# --------------------------- MULTICLASS LOSSES ---------------------------


def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=None):
    """
    Multi-class Lovasz-Softmax loss
      probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1).
              Interpreted as binary (sigmoid) output with outputs of size [B, H, W].
      labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1)
      classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
      per_image: compute the loss per image instead of per batch
      ignore: void class labels
    """
    if per_image:
        loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), classes=classes)
                          for prob, lab in zip(probas, labels))
    else:
        loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), classes=classes)
    return loss


def lovasz_softmax_flat(probas, labels, classes='present'):
    """
    Multi-class Lovasz-Softmax loss
      probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1)
      labels: [P] Tensor, ground truth labels (between 0 and C - 1)
      classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average.
    """
    if probas.numel() == 0:
        # only void pixels, the gradients should be 0
        return probas * 0.
    C = probas.size(1)
    losses = []
    class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
    for c in class_to_sum:
        fg = (labels == c).float() # foreground for class c
        if classes == 'present' and fg.sum() == 0:
            continue
        if C == 1:
            if len(classes) > 1:
                raise ValueError('Sigmoid output possible only with 1 class')
            class_pred = probas[:, 0]
        else:
            class_pred = probas[:, c]
        errors = (Variable(fg) - class_pred).abs()
        errors_sorted, perm = torch.sort(errors, 0, descending=True)
        perm = perm.data
        fg_sorted = fg[perm]
        losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted))))
    return mean(losses)


def flatten_probas(probas, labels, ignore=None):
    """
    Flattens predictions in the batch
    """
    if probas.dim() == 3:
        # assumes output of a sigmoid layer
        B, H, W = probas.size()
        probas = probas.view(B, 1, H, W)
    elif probas.dim() == 5:
        #3D segmentation
        B, C, L, H, W = probas.size()
        probas = probas.contiguous().view(B, C, L, H*W)
    B, C, H, W = probas.size()
    probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C)  # B * H * W, C = P, C
    labels = labels.view(-1)
    if ignore is None:
        return probas, labels
    valid = (labels != ignore)
    vprobas = probas[valid.nonzero().squeeze()]
    vlabels = labels[valid]
    return vprobas, vlabels


# --------------------------- HELPER FUNCTIONS ---------------------------
def isnan(x):
    return x != x
    
    
def mean(l, ignore_nan=False, empty=0):
    """
    nanmean compatible with generators.
    """
    l = iter(l)
    if ignore_nan:
        l = ifilterfalse(isnan, l)
    try:
        n = 1
        acc = next(l)
    except StopIteration:
        if empty == 'raise':
            raise ValueError('Empty mean')
        return empty
    for n, v in enumerate(l, 2):
        acc += v
    if n == 1:
        return acc
    return acc / n
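`lovasz_grad` turns a descending-sorted 0/1 ground-truth vector into per-position Jaccard increments whose prefix sums trace the Jaccard error. A pure-Python sketch of the same computation, usable for sanity-checking without torch (names are illustrative):

```python
def lovasz_grad_list(gt_sorted):
    # Mirror lovasz_grad above on a plain list of 0/1 labels sorted by
    # descending error: jaccard[i] = 1 - intersection/union over the
    # first i+1 predictions, then differenced to get the gradient.
    gts = sum(gt_sorted)
    cum_fg, cum_bg = 0, 0
    jaccard = []
    for g in gt_sorted:
        cum_fg += g
        cum_bg += 1 - g
        jaccard.append(1.0 - (gts - cum_fg) / (gts + cum_bg))
    return [jaccard[0]] + [jaccard[i] - jaccard[i - 1] for i in range(1, len(jaccard))]
```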


================================================
FILE: utils/multistep.py
================================================
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR
from torch.optim.lr_scheduler import _LRScheduler

class LinearWarmupScheduler(_LRScheduler):
    """ Linearly warm-up (increasing) learning rate, starting from zero.
    Args:
        optimizer (Optimizer): Wrapped optimizer.
        total_epoch: target learning rate is reached at total_epoch.
    """

    def __init__(self, optimizer, total_epoch, last_epoch=-1):
        self.total_epoch = total_epoch
        super(LinearWarmupScheduler, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        return [base_lr * min(1, (self.last_epoch / self.total_epoch)) for base_lr in self.base_lrs]
        
optim_choices = {'sgd', 'adam', 'adamax'}

def get_optim(args, model):
    assert args.optimizer in optim_choices

    if args.optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
    elif args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(args.momentum, args.momentum_sqr))
    elif args.optimizer == 'adamax':
        optimizer = optim.Adamax(model.parameters(), lr=args.lr, betas=(args.momentum, args.momentum_sqr))

    if args.warmup is not None:
        scheduler_iter = LinearWarmupScheduler(optimizer, total_epoch=args.warmup)
    else:
        scheduler_iter = None

    if len(args.milestones)>0:
        scheduler_epoch = MultiStepLR(optimizer, milestones=args.milestones, gamma=args.gamma)
    else:
        scheduler_epoch = None

    return optimizer, scheduler_iter, scheduler_epoch
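`LinearWarmupScheduler.get_lr` scales every base learning rate by `min(1, last_epoch / total_epoch)`, so the LR ramps linearly from zero and then holds at the base value. The scaling factor in isolation (function name is illustrative):

```python
def warmup_factor(last_epoch, total_epoch):
    # Linear ramp used by LinearWarmupScheduler: 0 at step 0,
    # base LR reached at total_epoch, flat afterwards.
    return min(1, last_epoch / total_epoch)
```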

================================================
FILE: utils/tables.py
================================================
from prettytable import PrettyTable
import torch
import os
import pickle
import numpy as np
import torch.nn.functional as F
import open3d as o3d

def get_args_table(args_dict):
    table = PrettyTable(['Arg', 'Value'])
    for arg, val in args_dict.items():
        table.add_row([arg, val])
    return table

def get_miou_table(args, label_to_names, miou):
    table = PrettyTable(['Label', 'mIoU'])
    for i in range(args.num_classes):
        table.add_row([label_to_names[i], 100 * miou[i]])
    return table

def get_metric_table(metric_dict, epochs):
    table = PrettyTable()
    table.add_column('Epoch', epochs)
    if len(metric_dict)>0:
        for metric_name, metric_values in metric_dict.items():
            table.add_column(metric_name, metric_values)
    return table

def create_folders(args):
    # Create log folder
    os.makedirs(args.log_path, exist_ok=True)
    os.makedirs(args.log_path+'/Completion', exist_ok=True)
    os.makedirs(args.log_path+'/Input', exist_ok=True)
    os.makedirs(args.log_path+'/Output', exist_ok=True)
    os.makedirs(args.log_path+'/Invalid', exist_ok=True)
    print("Storing logs in:", args.log_path)

def inter_vis(args, recons):
    iter_dir = os.path.join(args.log_path, 'Completion', 'iteration')
    os.makedirs(iter_dir, exist_ok=True)
    for r in range(len(recons)):
        for batch, samples_i in enumerate(recons[r]):
            color_index = []
            for i in range(1, args.num_classes):
                index = torch.nonzero(samples_i == i, as_tuple=False)
                color_index.append(F.pad(index, (1, 0), 'constant', value=i))
            colors_indexs = torch.cat(color_index, dim=0).cpu().numpy()
            np.savetxt(os.path.join(iter_dir, 'batch{}_{}.txt'.format(batch, r)), colors_indexs)


def visualization(args, recons, input_data, output, invalid, iteration):

    for batch, (samples_i, input_i, output_i, invalid_i) in enumerate(zip(recons, input_data, output, invalid)):
        color_index = []
        output_index = []
        input_points = torch.nonzero(input_i == 1, as_tuple=False).cpu().numpy()
        if args.dataset =='carla':
            invalid_points = torch.nonzero(invalid_i == 0, as_tuple=False).cpu().numpy() 
        elif args.dataset =='kitti':
            invalid_points = torch.nonzero(invalid_i == 1, as_tuple=False).cpu().numpy() 

        for i in range(1, args.num_classes):
            index = torch.nonzero(samples_i == i, as_tuple=False)
            out_color = torch.nonzero(output_i == i, as_tuple=False)
            color_index.append(F.pad(index, (1, 0), 'constant', value=i))
            output_index.append(F.pad(out_color, (1, 0), 'constant', value=i))
        colors_indexs = torch.cat(color_index, dim=0).cpu().numpy()
        out_indexs = torch.cat(output_index, dim=0).cpu().numpy()
        np.savetxt(args.log_path+'/Completion/result_{}.txt'.format((iteration * args.batch_size) + batch), colors_indexs)

        '''np.savetxt(args.log_path+'/Input/input_{}.txt'.format((iteration * args.batch_size) + batch), input_points)
        np.savetxt(args.log_path+'/Invalid/invalid_{}.txt'.format((iteration * args.batch_size) + batch), invalid_points)
        np.savetxt(args.log_path+'/Output/gt_{}.txt'.format((iteration * args.batch_size) + batch), out_indexs)'''
        
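Each row that `visualization` writes is `[label, x, y, z]`: the class id is prepended to the voxel index by `F.pad`. A minimal round-trip of that layout, using an in-memory buffer in place of a real `result_*.txt` file:

```python
import io
import numpy as np

# Two example voxels in the saved [label, x, y, z] layout.
rows = np.array([[3., 10., 20., 1.],
                 [5., 11., 20., 2.]])
buf = io.StringIO()
np.savetxt(buf, rows)      # same writer visualization() uses
buf.seek(0)
loaded = np.loadtxt(buf)
labels = loaded[:, 0]      # semantic class per voxel
points = loaded[:, 1:]     # voxel coordinates, as get_voxel() later slices them
```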

def completion_vis(args, input_p, recons):
    for batch, (recon_i, input_i) in enumerate(zip(recons, input_p)):
        recon_points = torch.nonzero(recon_i == 1, as_tuple=False).cpu().numpy()
        input_points = torch.nonzero(input_i == 1, as_tuple=False).cpu().numpy()
        np.savetxt(args.log_path+'/Completion/completion_{}.txt'.format(batch), recon_points)
        np.savetxt(args.log_path+'/Input/input_{}.txt'.format(batch), input_points)


def iou_one_frame(pred, target, n_classes=23):
    pred = pred.view(-1).detach().cpu().numpy()
    target = target.view(-1).detach().cpu().numpy()
    intersection = np.zeros(n_classes)
    union = np.zeros(n_classes)

    for cls in range(n_classes):
        intersection[cls] = np.sum((pred == cls) & (target == cls))
        union[cls] = np.sum((pred == cls) | (target == cls))
    return intersection, union
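The per-class Python loop above can also be expressed with `np.bincount`, which counts all classes in one pass. A sketch of an equivalent vectorized version (`iou_one_frame_vectorized` is a hypothetical name, not part of this repo; it takes NumPy arrays rather than tensors):

```python
import numpy as np

def iou_one_frame_vectorized(pred, target, n_classes=23):
    # Per-class intersection: count labels where prediction matches target.
    pred = np.asarray(pred).reshape(-1)
    target = np.asarray(target).reshape(-1)
    intersection = np.bincount(pred[pred == target], minlength=n_classes)[:n_classes]
    # Union = |pred == c| + |target == c| - intersection.
    pred_counts = np.bincount(pred, minlength=n_classes)[:n_classes]
    target_counts = np.bincount(target, minlength=n_classes)[:n_classes]
    union = pred_counts + target_counts - intersection
    return intersection, union
```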


def get_result(args, for_mask, output, preds, SSC=True):
    for_mask = for_mask.contiguous().view(-1)
    output = output.contiguous().view(-1)
    preds = preds.contiguous().view(-1)
    
    if SSC:
        if args.dataset == 'kitti':
            mask = for_mask == 0
        elif args.dataset == 'carla':
            mask = for_mask > 0
    else:
        mask = for_mask == 1

    output_masked = output[mask]
    iou_output_masked = output_masked.cpu().numpy()
    iou_output_masked[iou_output_masked != 0] = 1

    preds_masked = preds[mask]
    iou_preds_masked = preds_masked.cpu().numpy()
    iou_preds_masked[iou_preds_masked != 0] = 1

    # I, U for a frame
    correct = np.sum(output_masked.cpu().numpy() == preds_masked.cpu().numpy())
    total = preds_masked.shape[0]

    pred_TP = np.sum((iou_preds_masked == 1) & (iou_output_masked == 1))
    pred_FP = np.sum((iou_preds_masked == 1) & (iou_output_masked == 0))
    pred_TN = np.sum((iou_preds_masked == 0) & (iou_output_masked == 0))
    pred_FN = np.sum((iou_preds_masked == 0) & (iou_output_masked == 1))

    intersection, union = iou_one_frame(preds_masked, output_masked, n_classes=args.num_classes)
    return correct, total, pred_TP, pred_FP, pred_TN, pred_FN, intersection, union
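`get_result` is evidently called per batch, with its intersection/union outputs summed over the dataset before dividing (as `print_result` does). A hedged sketch of that aggregation with a guard for classes that never appear (`accumulate_miou` is an illustrative helper, not from this repo):

```python
import numpy as np

def accumulate_miou(frames, n_classes):
    # frames: iterable of (intersection, union) arrays, one pair per frame.
    total_i = np.zeros(n_classes)
    total_u = np.zeros(n_classes)
    for intersection, union in frames:
        total_i += intersection
        total_u += union
    # Guard against classes with zero union before dividing.
    valid = total_u > 0
    miou = np.zeros(n_classes)
    miou[valid] = total_i[valid] / total_u[valid]
    return miou, valid
```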

def save_args(args):
    # Save args
    with open(os.path.join(args.log_path, 'args.pickle'), "wb") as f:
        pickle.dump(args, f)

    # Save args table
    args_table = get_args_table(vars(args))
    with open(os.path.join(args.log_path,'args_table.txt'), "w") as f:
        f.write(str(args_table))
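The `args.pickle` written here can be restored with `pickle.load` when re-creating a run's configuration. A minimal round-trip using an in-memory buffer and an illustrative `Namespace` (the fields shown are examples, not the repo's full argument set):

```python
import argparse
import io
import pickle

args = argparse.Namespace(lr=1e-3, num_classes=11)  # illustrative fields
buf = io.BytesIO()                                  # stands in for args.pickle
pickle.dump(args, buf)
buf.seek(0)
restored = pickle.load(buf)
```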

def print_completion(num_correct, num_total, TP, FP, FN):
    print("\n=========================================\n")
    accuracy = num_correct/num_total
    print("\nAccuracy : ", accuracy)

    precision = 100 * TP / (TP + FP)
    recall = 100 * TP / (TP + FN)
    iou = 100 * TP / (TP + FP + FN)

    print("\nCompleteness")
    print("precision:", precision)
    print("recall:", recall)
    print("iou:", iou)

    print("\n=========================================\n")
    return iou

def print_result(args, label_to_names, num_correct, num_total, all_intersections, all_unions, TP, FP, FN, SSC=True):
    if SSC:
        print("\n========== Semantic Scene Completion =============\n")
    else:
        print("\n============ Semantic Segmentation ===============\n")
    accuracy = num_correct/num_total
    print("\nAccuracy : ", accuracy)

    precision = 100 * TP / (TP + FP)
    recall = 100 * TP / (TP + FN)
    iou = 100 * TP / (TP + FP + FN)

    print("\nCompleteness")
    print("precision:", precision)
    print("recall:", recall)
    print("iou:", iou)

    print("\nSemantic IoU Per Class")
    miou = all_intersections / all_unions
    for i in range(args.num_classes):
        print(label_to_names[i], ':', 100 * miou[i])
    print("\n====================================================\n")
    return iou, miou

================================================
FILE: visualization.py
================================================
import os
import open3d as o3d
import open3d.visualization.gui as gui
import open3d.visualization.rendering as rendering
import argparse
import numpy as np
import yaml
import struct

parser = argparse.ArgumentParser()
parser.add_argument('--M', default='scene-scale-diffusion') # VQVAE, multinomial_diffusion
parser.add_argument('--Driver', default='D')
parser.add_argument('--frame', default='0')
parser.add_argument('--file', default='result_')
parser.add_argument('--folder', default='Completion')
parser.add_argument('--model', default='image_init8_concat_att')
parser.add_argument('--name', default='Semantic Scene Completion')
parser.add_argument('--invalid', default=False)
parser.add_argument('--dataset', default='carla')  # referenced by get_voxel to pick the color map

class SpheresApp:
    MENU_SCENE = 1
    MENU_BEFORE = 2
    MENU_QUIT = 3

    def __init__(self, opt):
        self._id = 0
        self.opt = opt
        
        self.window = gui.Application.instance.create_window("Semantic Scene Completion", 1500, 1000)
        self.scene = gui.SceneWidget()
        self.scene.scene = rendering.Open3DScene(self.window.renderer)
        self.scene.scene.set_background([1, 1, 1, 1])
        self.scene.scene.scene.set_sun_light(
            [-0.577, 0.577, -0.577],  # direction
            [1, 1, 1],  # color
            60000)  # intensity
        
        self.scene.scene.scene.enable_sun_light(True)
        bbox = o3d.geometry.AxisAlignedBoundingBox([64, 64, -60], [64, 64, 60])
        
        self.scene.setup_camera(60, bbox, [0, 0, 1])
        self.window.add_child(self.scene)


        if gui.Application.instance.menubar is None:
            
            debug_menu = gui.Menu()
            debug_menu.add_item("Next Scene", SpheresApp.MENU_SCENE)
            debug_menu.add_separator()
            debug_menu.add_item("Before Scene", SpheresApp.MENU_BEFORE)
            debug_menu.add_separator()
            debug_menu.add_item("Quit", SpheresApp.MENU_QUIT)
            menu = gui.Menu()
            menu.add_menu("SSC", debug_menu)
            gui.Application.instance.menubar = menu

        # The menubar is global, but we need to connect the menu items to the
        # window, so that the window can call the appropriate function when the menu item is activated.
        self.window.set_on_menu_item_activated(SpheresApp.MENU_SCENE,self._on_menu_scene)
        self.window.set_on_menu_item_activated(SpheresApp.MENU_QUIT,self._on_menu_quit)
        self.window.set_on_menu_item_activated(SpheresApp.MENU_BEFORE,self._on_menu_before)

    def _on_menu_before(self):
        self._id -= 1
        mat = rendering.MaterialRecord()
        mat.shader = "defaultLit"

        if self.opt.file == 'input_':
            points = get_input(self.opt)
        else :
            points, colors = get_voxel(self.opt)
        
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(points)

        if (self.opt.file != 'input_'):
            pcd.colors = o3d.utility.Vector3dVector(colors/255)
        self.scene.scene.clear_geometry()
        voxel_grid = o3d.geometry.VoxelGrid.create_from_point_cloud(pcd, voxel_size=1)
        self.scene.scene.add_geometry("scene" + str(self._id), voxel_grid, mat)
        print(self.opt.frame)
        self.opt.frame = str(int(self.opt.frame)-1)

    def _on_menu_quit(self):
        gui.Application.instance.quit()

    def _on_menu_scene(self):
        self._id += 1
        mat = rendering.MaterialRecord()
        mat.shader = "defaultLit"

        if self.opt.file == 'input_':
            points = get_input(self.opt)
        else :
            points, colors = get_voxel(self.opt)
        
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(points)

        if (self.opt.file != 'input_'):
            pcd.colors = o3d.utility.Vector3dVector(colors/255)
        self.scene.scene.clear_geometry()
        voxel_grid = o3d.geometry.VoxelGrid.create_from_point_cloud(pcd, voxel_size=1)
        self.scene.scene.add_geometry("scene" + str(self._id), voxel_grid, mat)
        print(self.opt.frame)
        self.opt.frame = str(int(self.opt.frame)+1)

def get_voxel(opt):

    if opt.invalid :
        invalid_path = opt.Driver+':/'+opt.M+ '/result/' + opt.model +'/Invalid/invalid_'+ opt.frame +'.txt'
        invalid_points = np.loadtxt(invalid_path, delimiter=' ')
        invalid_colors = np.full(len(invalid_points), 0)
                
        point_cloud_path = opt.Driver+':/'+opt.M+ '/result/' + opt.model +'/' + opt.folder +'/'+ opt.file + opt.frame +'.txt'
        points_colors = np.loadtxt(point_cloud_path, delimiter=' ')
        points = points_colors[:, 1:]
        colors = points_colors[:, 0]
        
        points = np.concatenate((invalid_points, points), axis=0)
        colors = np.concatenate((invalid_colors, colors), axis=0)
        
        points, index = np.unique(points, return_index=True, axis=0)
        colors = colors[index, ...]
        
    else :
        point_cloud_path = 'C:/Users/jumin/Dataset/result_319_110.txt'
        points_colors = np.loadtxt(point_cloud_path, delimiter=' ')
        points = points_colors[:, 1:]
        colors = points_colors[:, 0]

    if opt.dataset == 'carla' : 
        base_dir = os.path.dirname(__file__)
        config_file = os.path.join(base_dir, 'datasets/carla.yaml')
        config = yaml.safe_load(open(config_file, 'r'))
        color_map = config["remap_color_map"]
    
    color = np.asarray([color_map[c] for c in colors])

    return points, color

def get_input(opt):
    point_cloud_path = opt.Driver+':/'+opt.M+'/result/' + opt.model +'/Invalid/invalid_' + opt.frame +'.txt'
    points = np.loadtxt(point_cloud_path, delimiter=' ')
    return points

def main(opt):
    gui.Application.instance.initialize()
    SpheresApp(opt)
    gui.Application.instance.run()

if __name__ == "__main__":
    opt = parser.parse_args()
    main(opt)
SYMBOL INDEX (295 symbols across 22 files)

FILE: SSC_train.py
  function get_args (line 28) | def get_args():
  function main (line 92) | def main():
  function start (line 112) | def start(local_rank, args):

FILE: datasets/carla_dataset.py
  class CarlaDataset (line 21) | class CarlaDataset(Dataset):
    method __init__ (line 25) | def __init__(self, directory,
    method __len__ (line 104) | def __len__(self):
    method collate_fn (line 107) | def collate_fn(self, data):
    method points_to_voxels (line 113) | def points_to_voxels(self, voxel_grid, points, t_i):
    method get_pose (line 129) | def get_pose(self, idx):
    method __getitem__ (line 135) | def __getitem__(self, idx):
    method find_horizon (line 193) | def find_horizon(self, idx):

FILE: datasets/data.py
  function get_data_id (line 11) | def get_data_id(args):
  function get_class_weights (line 14) | def get_class_weights(freq):
  function get_data (line 23) | def get_data(args):

FILE: layers/Ablation/wo_diffusion.py
  class wo_diff (line 7) | class wo_diff(torch.nn.Module):
    method __init__ (line 8) | def __init__(self, args, multi_criterion) -> None:
    method device (line 22) | def device(self):
    method forward (line 25) | def forward(self, x, input_ten):
    method sample (line 31) | def sample(self, x):

FILE: layers/Latent_Level/stage1/model.py
  function conv3x3x3 (line 10) | def conv3x3x3(in_planes, out_planes, stride=1):
  function conv1x3x3 (line 13) | def conv1x3x3(in_planes, out_planes, stride=1):
  function conv1x1x3 (line 16) | def conv1x1x3(in_planes, out_planes, stride=1):
  function conv1x3x1 (line 19) | def conv1x3x1(in_planes, out_planes, stride=1):
  function conv3x1x1 (line 22) | def conv3x1x1(in_planes, out_planes, stride=1):
  function conv3x1x3 (line 25) | def conv3x1x3(in_planes, out_planes, stride=1):
  function conv1x1 (line 28) | def conv1x1(in_planes, out_planes, stride=1):
  class Asymmetric_Residual_Block (line 32) | class Asymmetric_Residual_Block(nn.Module):
    method __init__ (line 33) | def __init__(self, in_filters, out_filters):
    method forward (line 60) | def forward(self, x):
  class DownBlock (line 81) | class DownBlock(nn.Module):
    method __init__ (line 82) | def __init__(self, in_filters, out_filters, pooling=True, drop_out=Tru...
    method forward (line 93) | def forward(self, x):
  class UpBlock (line 102) | class UpBlock(nn.Module):
    method __init__ (line 103) | def __init__(self, in_filters, out_filters, height_pooling):
    method forward (line 135) | def forward(self, x, skip=False):
  class DDCM (line 159) | class DDCM(nn.Module):
    method __init__ (line 160) | def __init__(self, in_filters, out_filters, kernel_size=(3, 3, 3), str...
    method forward (line 180) | def forward(self, x):
  function l2norm (line 197) | def l2norm(t):
  class Attention (line 200) | class Attention(nn.Module):
    method __init__ (line 201) | def __init__(self, dim, heads = 4, scale = 10):
    method forward (line 208) | def forward(self, x):
  class C_Encoder (line 221) | class C_Encoder(nn.Module):
    method __init__ (line 222) | def __init__(self, args,  nclasses=20, init_size=16, l_size='882', att...
    method forward (line 256) | def forward(self, x, out_conv=True):
  class C_Decoder (line 278) | class C_Decoder(nn.Module):
    method __init__ (line 279) | def __init__(self, args, nclasses=20, init_size=16, l_size='882', atte...
    method forward (line 309) | def forward(self, x, in_conv=True):
  class Completion (line 333) | class Completion(nn.Module):
    method __init__ (line 334) | def __init__(self, args, num_class = 11, init_size=32):
    method forward (line 362) | def forward(self, x):

FILE: layers/Latent_Level/stage1/vector_quantizer.py
  class VectorQuantizer (line 5) | class VectorQuantizer(nn.Module):
    method __init__ (line 7) | def __init__(self,
    method forward (line 19) | def forward(self, z: torch.tensor, point=False) -> torch.tensor: # lat...
    method codebook_to_embedding (line 46) | def codebook_to_embedding(self, encoding_inds, latents_shape): # laten...

FILE: layers/Latent_Level/stage1/vqvae.py
  class vqvae (line 10) | class vqvae(torch.nn.Module):
    method __init__ (line 11) | def __init__(self, args, multi_criterion) -> None:
    method device (line 28) | def device(self):
    method encode (line 31) | def encode(self, x):
    method vector_quantize (line 36) | def vector_quantize(self, latent):
    method coodbook (line 40) | def coodbook(self,quantized_latent_ind, latents_shape):
    method decode (line 44) | def decode(self, quantized_latent):
    method forward (line 49) | def forward(self, x, input_ten):
    method sample (line 58) | def sample(self, x):

FILE: layers/Latent_Level/stage2/Gen_diffusion.py
  function sum_except_batch (line 14) | def sum_except_batch(x, num_dims=1):
  function log_1_min_a (line 18) | def log_1_min_a(a):
  function log_add_exp (line 22) | def log_add_exp(a, b):
  function exists (line 27) | def exists(x):
  function extract (line 31) | def extract(a, t, x_shape):
  function default (line 37) | def default(val, d):
  function log_categorical (line 43) | def log_categorical(log_x_start, log_prob):
  function index_to_log_onehot (line 47) | def index_to_log_onehot(x, num_classes):
  function log_onehot_to_index (line 58) | def log_onehot_to_index(log_x):
  function cosine_beta_schedule (line 62) | def cosine_beta_schedule(timesteps, s = 0.008):
  class latent_diffusion (line 78) | class latent_diffusion(torch.nn.Module):
    method __init__ (line 79) | def __init__(self, args, VAE_DENSE, multi_criterion,
    method device (line 115) | def device(self):
    method multinomial_kl (line 118) | def multinomial_kl(self, log_prob1, log_prob2):
    method q_pred_one_timestep (line 122) | def q_pred_one_timestep(self, log_x_t, t):
    method q_pred (line 135) | def q_pred(self, log_x_start, t):
    method predict_start (line 146) | def predict_start(self, log_x_t, t):
    method q_posterior (line 158) | def q_posterior(self, log_x_start, log_x_t, t):
    method p_pred (line 182) | def p_pred(self, log_x, t):
    method log_sample_categorical (line 187) | def log_sample_categorical(self, logits):
    method q_sample (line 194) | def q_sample(self, log_x_start, t):
    method kl_prior (line 199) | def kl_prior(self, log_x_start):
    method sample_time (line 210) | def sample_time(self, b, device, method='uniform'):
    method forward (line 233) | def forward(self, x, input_data):
    method sample (line 287) | def sample(self, x):

FILE: layers/Latent_Level/stage2/gen_denoise.py
  function conv3x3x3 (line 11) | def conv3x3x3(in_planes, out_planes, stride=1):
  function conv1x3x3 (line 14) | def conv1x3x3(in_planes, out_planes, stride=1):
  function conv1x1x3 (line 18) | def conv1x1x3(in_planes, out_planes, stride=1):
  function conv1x3x1 (line 22) | def conv1x3x1(in_planes, out_planes, stride=1):
  function conv3x1x1 (line 26) | def conv3x1x1(in_planes, out_planes, stride=1):
  function conv3x1x3 (line 30) | def conv3x1x3(in_planes, out_planes, stride=1):
  function conv1x1 (line 34) | def conv1x1(in_planes, out_planes, stride=1):
  class Asymmetric_Residual_Block (line 38) | class Asymmetric_Residual_Block(nn.Module):
    method __init__ (line 39) | def __init__(self, in_filters, out_filters, time_filters=128):
    method forward (line 64) | def forward(self, x, t):
  class DDCM (line 91) | class DDCM(nn.Module):
    method __init__ (line 92) | def __init__(self, in_filters, out_filters, kernel_size=(3, 3, 3), str...
    method forward (line 106) | def forward(self, x):
  function l2norm (line 124) | def l2norm(t):
  class Attention (line 127) | class Attention(nn.Module):
    method __init__ (line 128) | def __init__(self, dim, heads = 4, scale = 10):
    method forward (line 135) | def forward(self, x):
  class Cross_Attention (line 148) | class Cross_Attention(nn.Module):
    method __init__ (line 149) | def __init__(self, dim, heads = 4, scale = 10):
    method forward (line 159) | def forward(self, x, cond_x):
  class DownBlock (line 175) | class DownBlock(nn.Module):
    method __init__ (line 176) | def __init__(self, in_filters, out_filters, dropout_rate, kernel_size=...
    method forward (line 193) | def forward(self, x, t):
  class UpBlock (line 201) | class UpBlock(nn.Module):
    method __init__ (line 202) | def __init__(self, in_filters, out_filters, height_pooling, time_filte...
    method forward (line 231) | def forward(self, x, residual, t):
  function timestep_embedding (line 258) | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=Fal...
  class Denoise (line 280) | class Denoise(nn.Module):
    method __init__ (line 281) | def __init__(self, args, num_class = 11, init_size=32, discrete=True):
    method forward (line 324) | def forward(self, x, t):

FILE: layers/Voxel_Level/Con_Diffusion.py
  function sum_except_batch (line 14) | def sum_except_batch(x, num_dims=1):
  function log_1_min_a (line 18) | def log_1_min_a(a):
  function log_add_exp (line 22) | def log_add_exp(a, b):
  function exists (line 27) | def exists(x):
  function extract (line 31) | def extract(a, t, x_shape):
  function default (line 37) | def default(val, d):
  function log_categorical (line 43) | def log_categorical(log_x_start, log_prob):
  function index_to_log_onehot (line 47) | def index_to_log_onehot(x, num_classes):
  function log_onehot_to_index (line 57) | def log_onehot_to_index(log_x):
  function cosine_beta_schedule (line 61) | def cosine_beta_schedule(timesteps, s = 0.008):
  class Con_Diffusion (line 77) | class Con_Diffusion(torch.nn.Module):
    method __init__ (line 78) | def __init__(self, args, multi_criterion,auxiliary_loss_weight=0.05, a...
    method device (line 117) | def device(self):
    method multinomial_kl (line 120) | def multinomial_kl(self, log_prob1, log_prob2):
    method q_pred_one_timestep (line 124) | def q_pred_one_timestep(self, log_x_t, t):
    method q_pred (line 137) | def q_pred(self, log_x_start, t):
    method predict_start (line 148) | def predict_start(self, log_x_t, t, cond):
    method q_posterior (line 156) | def q_posterior(self, log_x_start, log_x_t, t):
    method p_pred (line 177) | def p_pred(self, log_x, t, cond):
    method log_sample_categorical (line 182) | def log_sample_categorical(self, logits):
    method q_sample (line 189) | def q_sample(self, log_x_start, t):
    method kl_prior (line 194) | def kl_prior(self, log_x_start):
    method sample_time (line 205) | def sample_time(self, b, device, method='uniform'):
    method forward (line 228) | def forward(self, x, voxel_input):
    method sample (line 280) | def sample(self, voxel_input, intermediate=False):

FILE: layers/Voxel_Level/Gen_Diffusion.py
  function sum_except_batch (line 14) | def sum_except_batch(x, num_dims=1):
  function log_1_min_a (line 18) | def log_1_min_a(a):
  function log_add_exp (line 22) | def log_add_exp(a, b):
  function exists (line 27) | def exists(x):
  function extract (line 31) | def extract(a, t, x_shape):
  function default (line 37) | def default(val, d):
  function log_categorical (line 43) | def log_categorical(log_x_start, log_prob):
  function index_to_log_onehot (line 47) | def index_to_log_onehot(x, num_classes):
  function log_onehot_to_index (line 57) | def log_onehot_to_index(log_x):
  function cosine_beta_schedule (line 61) | def cosine_beta_schedule(timesteps, s = 0.008):
  class Diffusion (line 77) | class Diffusion(torch.nn.Module):
    method __init__ (line 78) | def __init__(self, args, multi_criterion,auxiliary_loss_weight=0.05, a...
    method device (line 114) | def device(self):
    method multinomial_kl (line 117) | def multinomial_kl(self, log_prob1, log_prob2):
    method q_pred_one_timestep (line 121) | def q_pred_one_timestep(self, log_x_t, t):
    method q_pred (line 134) | def q_pred(self, log_x_start, t):
    method predict_start (line 145) | def predict_start(self, log_x_t, t):
    method q_posterior (line 153) | def q_posterior(self, log_x_start, log_x_t, t):
    method p_pred (line 174) | def p_pred(self, log_x, t):
    method log_sample_categorical (line 179) | def log_sample_categorical(self, logits):
    method q_sample (line 186) | def q_sample(self, log_x_start, t):
    method kl_prior (line 191) | def kl_prior(self, log_x_start):
    method sample_time (line 202) | def sample_time(self, b, device, method='uniform'):
    method forward (line 225) | def forward(self, x, voxel_input):
    method sample (line 277) | def sample(self, voxel_input):

FILE: layers/Voxel_Level/denoise.py
  function conv3x3x3 (line 11) | def conv3x3x3(in_planes, out_planes, stride=1):
  function conv1x3x3 (line 14) | def conv1x3x3(in_planes, out_planes, stride=1):
  function conv1x1x3 (line 18) | def conv1x1x3(in_planes, out_planes, stride=1):
  function conv1x3x1 (line 22) | def conv1x3x1(in_planes, out_planes, stride=1):
  function conv3x1x1 (line 26) | def conv3x1x1(in_planes, out_planes, stride=1):
  function conv3x1x3 (line 30) | def conv3x1x3(in_planes, out_planes, stride=1):
  function conv1x1 (line 34) | def conv1x1(in_planes, out_planes, stride=1):
  class Asymmetric_Residual_Block (line 38) | class Asymmetric_Residual_Block(nn.Module):
    method __init__ (line 39) | def __init__(self, in_filters, out_filters, time_filters=32*4):
    method forward (line 71) | def forward(self, x, t):
  class DDCM (line 98) | class DDCM(nn.Module):
    method __init__ (line 99) | def __init__(self, in_filters, out_filters, kernel_size=(3, 3, 3), str...
    method forward (line 118) | def forward(self, x):
  function l2norm (line 136) | def l2norm(t):
  class Attention (line 139) | class Attention(nn.Module):
    method __init__ (line 140) | def __init__(self, dim, heads = 4, scale = 10):
    method forward (line 147) | def forward(self, x):
  class Cross_Attention (line 160) | class Cross_Attention(nn.Module):
    method __init__ (line 161) | def __init__(self, dim, heads = 4, scale = 10):
    method forward (line 171) | def forward(self, x, cond_x):
  class DownBlock (line 187) | class DownBlock(nn.Module):
    method __init__ (line 188) | def __init__(self, in_filters, out_filters, time_filters=32*4, kernel_...
    method forward (line 201) | def forward(self, x, t):
  class UpBlock (line 209) | class UpBlock(nn.Module):
    method __init__ (line 210) | def __init__(self, in_filters, out_filters, height_pooling, time_filte...
    method forward (line 245) | def forward(self, x, residual, t):
  function timestep_embedding (line 272) | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=Fal...
  class Denoise (line 286) | class Denoise(nn.Module):
    method __init__ (line 287) | def __init__(self, args, num_class = 11, init_size=32, discrete=True):
    method forward (line 322) | def forward(self, x, x_cond, t):

FILE: layers/Voxel_Level/gen_denoise.py
  function conv3x3x3 (line 11) | def conv3x3x3(in_planes, out_planes, stride=1):
  function conv1x3x3 (line 14) | def conv1x3x3(in_planes, out_planes, stride=1):
  function conv1x1x3 (line 18) | def conv1x1x3(in_planes, out_planes, stride=1):
  function conv1x3x1 (line 22) | def conv1x3x1(in_planes, out_planes, stride=1):
  function conv3x1x1 (line 26) | def conv3x1x1(in_planes, out_planes, stride=1):
  function conv3x1x3 (line 30) | def conv3x1x3(in_planes, out_planes, stride=1):
  function conv1x1 (line 34) | def conv1x1(in_planes, out_planes, stride=1):
  class Asymmetric_Residual_Block (line 38) | class Asymmetric_Residual_Block(nn.Module):
    method __init__ (line 39) | def __init__(self, in_filters, out_filters, time_filters=128):
    method forward (line 70) | def forward(self, x, t):
  class DDCM (line 97) | class DDCM(nn.Module):
    method __init__ (line 98) | def __init__(self, in_filters, out_filters, kernel_size=(3, 3, 3), str...
    method forward (line 115) | def forward(self, x):
  function l2norm (line 132) | def l2norm(t):
  class Attention (line 135) | class Attention(nn.Module):
    method __init__ (line 136) | def __init__(self, dim, heads = 4, scale = 10):
    method forward (line 143) | def forward(self, x):
  class Cross_Attention (line 156) | class Cross_Attention(nn.Module):
    method __init__ (line 157) | def __init__(self, dim, heads = 4, scale = 10):
    method forward (line 167) | def forward(self, x, cond_x):
  class DownBlock (line 183) | class DownBlock(nn.Module):
    method __init__ (line 184) | def __init__(self, in_filters, out_filters, time_filters, kernel_size=...
    method forward (line 200) | def forward(self, x, t):
  class UpBlock (line 208) | class UpBlock(nn.Module):
    method __init__ (line 209) | def __init__(self, in_filters, out_filters, height_pooling, time_filte...
    method forward (line 244) | def forward(self, x, residual, t):
  function timestep_embedding (line 271) | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=Fal...
  class Denoise (line 293) | class Denoise(nn.Module):
    method __init__ (line 294) | def __init__(self, args, num_class = 11, init_size=32, discrete=True):
    method forward (line 329) | def forward(self, x, t):

FILE: simple_visualize.py
  function load_config (line 7) | def load_config(yaml_path):
  function load_pointcloud (line 12) | def load_pointcloud(filepath, learning_map, color_map):
  function main (line 29) | def main():

FILE: train.py
  class Experiment (line 19) | class Experiment(object):
    method __init__ (line 22) | def __init__(self, args, model, optimizer, scheduler_iter, scheduler_e...
    method run (line 65) | def run(self, epochs):
    method train_fn (line 112) | def train_fn(self, epoch):
    method eval_fn (line 142) | def eval_fn(self, epoch):
    method sample (line 162) | def sample(self):
    method resume (line 215) | def resume(self):
    method log_metrics (line 246) | def log_metrics(self, dict, type):
    method save_metrics (line 254) | def save_metrics(self):
    method checkpoint_save (line 270) | def checkpoint_save(self, epoch):
    method checkpoint_load (line 292) | def checkpoint_load(self, resume_path):

FILE: utils/cuda.py
  function find_free_port (line 11) | def find_free_port():
  function launch (line 23) | def launch(fn, n_gpu_per_machine, n_machine=1, machine_rank=0, dist_url=...
  function distributed_worker (line 54) | def distributed_worker(local_rank, fn, world_size, n_gpu_per_machine, ma...
  function set_cuda_vd (line 94) | def set_cuda_vd(gpu_ids, verbose=True):

FILE: utils/dicts.py
  function is_primary (line 13) | def is_primary():
  function get_rank (line 17) | def get_rank():
  function get_local_rank (line 27) | def get_local_rank():
  function synchronize (line 40) | def synchronize():
  function get_world_size (line 55) | def get_world_size():
  function is_distributed (line 65) | def is_distributed():
  function all_reduce (line 70) | def all_reduce(tensor, op=dist.ReduceOp.SUM, async_op=False):
  function all_gather (line 80) | def all_gather(data):
  function reduce_dict (line 115) | def reduce_dict(input_dict, average=True):
  function data_sampler (line 140) | def data_sampler(dataset, shuffle, distributed):
  function clean_dict (line 150) | def clean_dict(d, keys):

FILE: utils/intermediate_vis.py
  class Vis_iter (line 19) | class Vis_iter(object):
    method __init__ (line 22) | def __init__(self, args, model, optimizer, scheduler_iter, scheduler_e...
    method run (line 69) | def run(self, epochs):
    method sample (line 74) | def sample(self):
    method checkpoint_load (line 84) | def checkpoint_load(self, resume_path):

FILE: utils/loss.py
  function dice_coef (line 16) | def dice_coef(y_true, y_pred, smooth=1e-6):
  function dice_coef_multilabel (line 22) | def dice_coef_multilabel(y_true, y_pred, numLabels=11):
  function lovasz_grad (line 33) | def lovasz_grad(gt_sorted):
  function lovasz_softmax (line 50) | def lovasz_softmax(probas, labels, classes='present', per_image=False, i...
  function lovasz_softmax_flat (line 68) | def lovasz_softmax_flat(probas, labels, classes='present'):
  function flatten_probas (line 99) | def flatten_probas(probas, labels, ignore=None):
  function isnan (line 123) | def isnan(x):
  function mean (line 127) | def mean(l, ignore_nan=False, empty=0):
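The `dice_coef(y_true, y_pred, smooth=1e-6)` signature above is the standard smoothed Dice coefficient. A pure-Python sketch over flattened binary masks (the repository's version presumably operates on tensors, but the formula is the same):

```python
def dice_coef(y_true, y_pred, smooth=1e-6):
    # Dice = 2|A ∩ B| / (|A| + |B|); the smooth term keeps the
    # ratio well-defined when both masks are empty.
    inter = sum(t * p for t, p in zip(y_true, y_pred))
    return (2.0 * inter + smooth) / (sum(y_true) + sum(y_pred) + smooth)
```

Identical masks score ~1.0 and disjoint masks score ~0, which is why `dice_coef_multilabel` can simply average this over classes.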

FILE: utils/multistep.py
  class LinearWarmupScheduler (line 5) | class LinearWarmupScheduler(_LRScheduler):
    method __init__ (line 12) | def __init__(self, optimizer, total_epoch, last_epoch=-1):
    method get_lr (line 16) | def get_lr(self):
  function get_optim (line 21) | def get_optim(args, model):
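`LinearWarmupScheduler(optimizer, total_epoch)` suggests a linear ramp of the learning rate over the first `total_epoch` epochs. A sketch of that `get_lr` logic as a plain function; the exact ramp (e.g. whether epoch 0 starts at zero or at one warmup step) is an assumption here:

```python
def warmup_lr(base_lr, epoch, total_epoch):
    # Scale the base learning rate by epoch / total_epoch during
    # warmup, then clamp at base_lr once warmup is finished.
    return base_lr * min(1.0, epoch / total_epoch)
```

After warmup, `get_optim` would typically hand control to a decay schedule such as the `MultiStepLR` imported in this module.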

FILE: utils/tables.py
  function get_args_table (line 9) | def get_args_table(args_dict):
  function get_miou_table (line 15) | def get_miou_table(args, label_to_names, miou):
  function get_metric_table (line 21) | def get_metric_table(metric_dict, epochs):
  function create_folders (line 29) | def create_folders(args):
  function inter_vis (line 38) | def inter_vis(args, recons):
  function visualization (line 49) | def visualization(args, recons, input_data, output, invalid, iteration):
  function completion_vis (line 74) | def completion_vis(args, input_p, recons):
  function iou_one_frame (line 82) | def iou_one_frame(pred, target, n_classes=23):
  function get_result (line 94) | def get_result(args, for_mask, output, preds, SSC=True):
  function save_args (line 127) | def save_args(args):
  function print_completion (line 137) | def print_completion(num_correct, num_total, TP, FP, FN):
  function print_result (line 154) | def print_result(args, label_to_names, num_correct, num_total, all_inter...

FILE: visualization.py
  class SpheresApp (line 20) | class SpheresApp:
    method __init__ (line 25) | def __init__(self, opt):
    method _on_menu_before (line 63) | def _on_menu_before(self):
    method _on_menu_quit (line 84) | def _on_menu_quit(self):
    method _on_menu_scene (line 87) | def _on_menu_scene(self):
  function get_voxel (line 108) | def get_voxel(opt):
  function get_input (line 142) | def get_input(opt):
  function main (line 150) | def main(opt):

